import os
import traceback
from datetime import datetime

import torch  # type: ignore
from dotenv import load_dotenv
from PIL import Image, ImageDraw, ImageFont
from sqlalchemy.ext.asyncio import AsyncSession
from transformers import GroundingDinoProcessor, GroundingDinoForObjectDetection

from src.models.prediction import Prediction
from src.utils.files_utils import preprocess_caption, get_label_box_dict, compute_iou, generate_output_image_url

load_dotenv()
 
class ImageComparison:
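    """Compare 'before' and 'after' images with Grounding DINO and annotate the differences."""
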
    def __init__(self, db: AsyncSession, model_directory=os.getenv("MODEL_DIR")):
        self.db = db
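        # Load the Grounding DINO processor and model from the local weights directory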
        self.processor = GroundingDinoProcessor.from_pretrained(model_directory)
        self.model = GroundingDinoForObjectDetection.from_pretrained(model_directory)
 
    async def compare_and_annotate(self, before_path, after_path, candidate_labels, prediction_id: int, iou_threshold=0.2, score_threshold=0.3):
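        """Detect `candidate_labels` in both images and annotate the differences on a copy
        of the after image: green boxes for new objects, red for removed, yellow for moved.
        Saves the annotated image, updates the Prediction row, and returns the image URL
        (or None if anything fails)."""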
        try:
            timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
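            # Load both images as RGB so drawing and model preprocessing behave consistently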
            before_image = Image.open(before_path).convert("RGB")
            after_image = Image.open(after_path).convert("RGB")
 
            # Grounding DINO expects lowercase phrases separated by periods, e.g. "a cat. a dog."
            text_prompt = ". ".join(candidate_labels)
            processed_text = preprocess_caption(text_prompt)
 
            # Run Grounding DINO on the "before" image; inference only, so no gradients are needed
            before_inputs = self.processor(images=before_image, text=processed_text, return_tensors="pt")
            with torch.no_grad():
                before_outputs = self.model(**before_inputs)
 
            # Map predicted class indices back to the caller-supplied label strings
            id2label = dict(enumerate(candidate_labels))
 
            # Convert raw outputs to scored boxes in image coordinates; target_sizes
            # expects (height, width), hence the reversed PIL (width, height) size.
            before_results = self.processor.image_processor.post_process_object_detection(
                before_outputs,
                target_sizes=[before_image.size[::-1]],
                threshold=score_threshold
            )[0]
 
            # Repeat the same detection pass on the "after" image
            after_inputs = self.processor(images=after_image, text=processed_text, return_tensors="pt")
            with torch.no_grad():
                after_outputs = self.model(**after_inputs)
 
            after_results = self.processor.image_processor.post_process_object_detection(
                after_outputs,
                target_sizes=[after_image.size[::-1]],
                threshold=score_threshold
            )[0]
 
            # Group detections by label, then diff the label sets between the two images
            before_labels = get_label_box_dict(before_results, score_threshold, id2label)
            after_labels = get_label_box_dict(after_results, score_threshold, id2label)

            common_labels = before_labels.keys() & after_labels.keys()
            new_labels = after_labels.keys() - before_labels.keys()
            removed_labels = before_labels.keys() - after_labels.keys()
 
            # Annotate a copy of the after image: green = added, red = removed, yellow = moved
            diff_image = after_image.copy()
            draw_diff = ImageDraw.Draw(diff_image)
 
            try:
                font = ImageFont.truetype("arial.ttf", 15)
            except IOError:
                font = ImageFont.load_default()
 
            # Draw green boxes for new objects
            for label in new_labels:
                for box, score in after_labels[label]:
                    draw_diff.rectangle(box, outline="green", width=2)
                    draw_diff.text((box[0], box[1] - 15), f"{label}: {score:.2f}", fill="green", font=font)
 
            # Draw red boxes for removed objects
            for label in removed_labels:
                for box, score in before_labels[label]:
                    draw_diff.rectangle(box, outline="red", width=2)
                    draw_diff.text((box[0], box[1] - 15), f"{label}: {score:.2f}", fill="red", font=font)
 
            # Draw yellow boxes for moved objects: the label exists in both images,
            # but no after-image box overlaps the before-image box above the IoU threshold
            for label in common_labels:
                before_box_scores = before_labels[label]
                after_box_scores = after_labels[label]

                # Track which after-image boxes have been matched to avoid duplicate markings
                matched_after_boxes = set()

                for before_box, before_score in before_box_scores:
                    # A high-IoU match means this object stayed in place; consume it and move on
                    stationary_idx = next(
                        (i for i, (after_box, _) in enumerate(after_box_scores)
                         if i not in matched_after_boxes and compute_iou(before_box, after_box) >= iou_threshold),
                        None
                    )
                    if stationary_idx is not None:
                        matched_after_boxes.add(stationary_idx)
                        continue

                    # No overlapping box found: mark the first unmatched after-image box as moved
                    for i, (after_box, after_score) in enumerate(after_box_scores):
                        if i in matched_after_boxes:
                            continue
                        draw_diff.rectangle(after_box, outline="yellow", width=2)
                        draw_diff.text((after_box[0], after_box[1] - 15), f"{label}: moved", fill="yellow", font=font)
                        matched_after_boxes.add(i)
                        break
 
            # Ensure the output directory exists (absolute path under the current working directory)
            output_image_folder = os.path.join(os.getcwd(), "Output_Images")
            os.makedirs(output_image_folder, exist_ok=True)

            image_file_name = f"difference_{timestamp}.jpg"
            output_image_path = os.path.join(output_image_folder, image_file_name)
 
            # Save the image
            diff_image.save(output_image_path)
 
            output_image_url = await generate_output_image_url(image_file_name)
 
            # Record the result on the Prediction row and mark it as completed
            prediction = await self.db.get(Prediction, prediction_id)
            if prediction is None:
                raise ValueError(f"Prediction {prediction_id} not found")

            prediction.output_image_url = output_image_url
            prediction.status = "completed"
            await self.db.commit()
            await self.db.refresh(prediction)
 
            return output_image_url
 
        except Exception as e:
            print(f"Error during comparison: {e}")
            traceback.print_exc()
            # Roll back any partially applied database changes so the session remains usable
            await self.db.rollback()
            return None
 
