
import argparse
import logging

import pandas as pd
from tqdm import tqdm

from coordinates_extraction.geocoding import CoordinateRetriever
from database_operations.db_connect import (
    load_credentials_from_json,
    insert_data_db,
    DatabaseConnector,
    insert_Keyword_predictions_to_db,
    update_progress_db,
)
from error_detection.ErrorChecker import ErrorChecker
from image_captioning.generator import ImageCaptionGenerator
from iso_code_extraction.isocode_from_location import GeoLocator
from keyword_matching.text_classifier import TextClassifier
from location_extraction.location_extraction import LocationExtractor
from relatedreport_extraction.relatedreport_extraction import SimilarityFinder
from report_id_generator.report_id_generator import ReportIDGenerator
from state_stock_classifier.state_stock_binary_classifier import StateStockIdentifier
from string_matcher.string_matching import IncidentTypeMatcher
from text_summarization.summarization import TextSummarization
from text_translation.TextCleaner import TextCleaner
from text_translation.translator import TextTranslator
from tweet_title_generation.title_generation import TweetTitleGeneration

class Pipeline:
    """
    Data processing pipeline for scraped incident reports.

    For every unprocessed row in the scraping-result table it runs: error
    checking, text cleaning, translation, summarization, location /
    coordinate / ISO-code extraction, image captioning, title generation
    or translation, incident-type matching, report-ID generation,
    state-stock classification, related-report lookup, and keyword
    classification — then persists the results back to the database.
    """

    def __init__(self, credentials_file="db_credentials.json",
                 query="SELECT * FROM public.scraping_result"):
        """
        Initialize the Pipeline object.

        Args:
            credentials_file (str): Path to the database credentials file.
            query (str): SQL query to retrieve data from the database.
        """
        self.credentials_file = credentials_file
        self.query = query
        self.logger = self._create_logger()

        # Initialize pipeline components.
        self.text_classifier = TextClassifier()
        self.text_cleaner = TextCleaner()
        self.translator = TextTranslator()
        self.summarizer = TextSummarization()
        self.location_extractor = LocationExtractor()
        self.coordinates_retriever = CoordinateRetriever()
        self.iso_code_locator = GeoLocator()
        self.image_caption_generator = ImageCaptionGenerator()
        self.incident_type_matcher = IncidentTypeMatcher()
        self.state_stock_identifier = StateStockIdentifier()
        self.similarity_finder = SimilarityFinder()
        self.ErrorChecker = ErrorChecker()
        self.tweet_title_generator = TweetTitleGeneration()
        # Opened lazily in run_pipeline(); kept on the instance so the
        # finally-block can always disconnect it.
        self.db_connection = None

    def run_pipeline(self):
        """
        Execute the data processing pipeline.

        Connects to the database, reads all rows not yet flagged as
        model-processed, runs every pipeline stage per row, and writes the
        results back. A failure in one row is logged and skipped so it does
        not abort the whole batch; the connection is always closed.
        """
        credentials = load_credentials_from_json(self.credentials_file)
        self.db_connection = DatabaseConnector(
            credentials["host"],
            credentials["port"],
            credentials["database"],
            credentials["user"],
            credentials["password"],
        )

        try:
            self.db_connection.connect()
            # Report-ID generation needs a live DB connection, so it is
            # created here rather than in __init__.
            self.report_id_generator = ReportIDGenerator(self.db_connection)

            # Add WHERE clause to filter NULL values: only rows the models
            # have not processed yet.
            query = self.query + " WHERE model_processed IS NULL"
            df = pd.read_sql(query, self.db_connection.connection)
            df.dropna(subset=['maintext'], inplace=True)

            with tqdm(total=len(df), desc="Progress") as pbar:
                for index, row in df.iterrows():
                    report_id = str(row['id'])  # Convert report ID to string
                    original_text = row["maintext"]
                    image_url = row['image_url']
                    title = row['title']
                    source = row['source']
                    screenshot_path = row['screenshot_path']
                    incident_date = row['date_publish']
                    url = row['url']
                    source_domain = row['source_domain']

                    self.logger.info("Processing ID: %s", report_id)

                    try:
                        # Checking errors in text
                        self.logger.info("error checking ...")
                        text = self.ErrorChecker.check_for_errors(original_text)

                        # Clean the text
                        self.logger.info("Text Cleaning...")
                        cleaned_text = self.text_cleaner.get_cleaned_text(text)
                        self.logger.info("Text Cleaning completed.")

                        # Translate the text
                        self.logger.info("Translating text...")
                        translation = self.translator.translate_text(cleaned_text)
                        self.logger.info("Translation completed.")

                        # Summarize the text
                        self.logger.info("Summarizing text...")
                        summary = self.summarizer.summarize(translation)
                        self.logger.info("Summarization completed.")

                        # Extract locations from the translated text
                        self.logger.info("Extracting locations...")
                        locations = self.location_extractor.extract_locations(translation)
                        self.logger.info("Location extraction completed.")

                        # Extract latitude and longitude from locations
                        self.logger.info("Getting coordinates...")
                        latitude, longitude = self.coordinates_retriever.get_coordinates(locations)
                        self.logger.info("Coordinate extraction completed.")

                        # Get ISO code from latitude and longitude
                        self.logger.info("Getting ISO code...")
                        iso_code = self.iso_code_locator.get_iso_code(latitude, longitude)

                        # Generate image caption
                        self.logger.info("Generating image caption...")
                        caption = self.image_caption_generator.generate_captions(image_url)
                        self.logger.info("Image caption generation completed.")

                        # Tweets get a generated title; other sources get
                        # their scraped title translated.
                        if 'twitter.com' in source_domain:
                            self.logger.info("Generating title...")
                            translated_title = self.tweet_title_generator.title_generator(translation)
                            self.logger.info("Title generation completed.")
                        else:
                            self.logger.info("Translating title...")
                            translated_title = self.translator.translate_text(title)
                            self.logger.info("Title translation completed.")

                        # Find incident types
                        self.logger.info("Matching incident types...")
                        primary_incident_type = self.incident_type_matcher.find_similar_primary_incident_type(translation)
                        secondary_incident_type = self.incident_type_matcher.find_similar_secondary_incident_type(translation)
                        associated_group = self.incident_type_matcher.find_similar_associated_group(translation)
                        other_associated_group = self.incident_type_matcher.find_similar_other_associated_group(translation)
                        self.logger.info("Incident matching completed.")

                        # Generate report ID
                        self.logger.info("Generating report ID...")
                        generated_report_id = self.report_id_generator.generate_report_id(iso_code, incident_date)
                        self.logger.info("Report ID generation completed.")

                        # Check state stock
                        self.logger.info("Checking state stock...")
                        state_diverted = self.state_stock_identifier.check_state_stock(summary)
                        self.logger.info("State stock checking completed.")

                        # Checking related reports
                        self.logger.info("Checking related incidents...")
                        related_reports = self.similarity_finder.find_similar_incidents_pipeline(
                            self.db_connection, summary, top_n=5
                        )
                        self.logger.info("Related incident checking completed.")

                        # Log the intermediate results
                        self.logger.info("Translation complete")
                        self.logger.info("Summary complete")
                        self.logger.info("Locations complete: %s", locations)
                        self.logger.info("Coordinates complete: %s, %s", latitude, longitude)
                        self.logger.info("ISO code: %s", iso_code)
                        self.logger.info("Caption: %s", caption)
                        self.logger.info("Translated title complete")
                        self.logger.info("Incident date: %s", str(incident_date))
                        self.logger.info("Related incident IDs: %s", related_reports)

                        # Save the results in the database; coordinates are
                        # passed only when both were found.
                        if latitude is not None and longitude is not None:
                            success = insert_data_db(
                                self.logger, summary, report_id, self.db_connection,
                                locations, caption, translated_title, iso_code, source,
                                screenshot_path, incident_date, url,
                                primary_incident_type, secondary_incident_type,
                                associated_group, other_associated_group,
                                generated_report_id, state_diverted, related_reports,
                                latitude, longitude
                            )
                        else:
                            success = insert_data_db(
                                self.logger, summary, report_id, self.db_connection,
                                locations, caption, translated_title, iso_code, source,
                                screenshot_path, incident_date, url,
                                primary_incident_type, secondary_incident_type,
                                associated_group, other_associated_group,
                                generated_report_id, state_diverted, related_reports
                            )

                        # Keywords prediction from the summary
                        keyword_prediction = self.text_classifier.predict_from_text(summary)
                        # Save the keyword predictions in the database
                        insert_Keyword_predictions_to_db(
                            self.logger, report_id, keyword_prediction, self.db_connection
                        )

                        # Update the progress in the database
                        update_progress_db(report_id, self.db_connection, model_processed=True)

                    except Exception as e:
                        # Log and continue: one bad row must not stop the batch.
                        self.logger.error(
                            "Error processing report ID %s: %s",
                            report_id, str(e), exc_info=True,
                        )
                    pbar.update(1)

            self.logger.info("Translation and summarization completed.")

        except Exception as e:
            self.logger.error("Failed to connect to the database: %s", str(e))

        finally:
            if self.db_connection:
                self.db_connection.disconnect()

    def _create_logger(self):
        """
        Create a logger instance for logging pipeline events and messages.

        Returns:
            logging.Logger: The logger instance.
        """
        logger = logging.getLogger("Pipeline")
        logger.setLevel(logging.INFO)
        # logging.getLogger returns a shared object, so only attach the
        # console handler once — otherwise each Pipeline instantiation
        # would duplicate every log line.
        if not logger.handlers:
            formatter = logging.Formatter("%(asctime)s - %(levelname)s - %(message)s")
            console_handler = logging.StreamHandler()
            console_handler.setFormatter(formatter)
            logger.addHandler(console_handler)
        return logger

if __name__ == "__main__":
    # Command-line entry point: wire CLI arguments into the Pipeline.
    parser = argparse.ArgumentParser(description="Main")
    parser.add_argument(
        "--credentials-file",
        type=str,
        default="db_credentials.json",
        help="Database credentials file",
    )
    parser.add_argument(
        "--query",
        type=str,
        default="SELECT * FROM public.scraping_result",
        help="Database query",
    )
    args = parser.parse_args()

    pipeline = Pipeline(args.credentials_file, args.query)
    pipeline.run_pipeline()
