import os
import base64
import asyncio
from src.common.logger import logger
from collections import defaultdict
from PIL import Image
from io import BytesIO
from dotenv import load_dotenv
load_dotenv()
 
# URL path segment used when building input-image URLs (see save_input_images_async).
BASE_INPUT_DIR = "Input_Images"
# URL path segment used when building output-image URLs (see generate_output_image_url).
BASE_OUTPUT_DIR = "Output_Images"
# Host name placed after "https://" in every generated URL. It is read from the
# environment (after load_dotenv()) and is None when the variable is not set.
DOMAIN = os.getenv("DOMAIN")
 
async def save_input_images_async(timestamp: str):
    """Return the public URLs of the "before" and "after" input images for *timestamp*.

    Despite its name, this coroutine does not write any files. It only builds
    ``https://{DOMAIN}/{BASE_INPUT_DIR}/{timestamp}/before.jpg`` and the matching
    ``after.jpg`` URL. Presumably the images are stored at that location by
    other code; verify against the caller.

    Args:
        timestamp: Path segment identifying the image pair. It is inserted
            into the URL verbatim.

    Returns:
        tuple[str, str]: ``(before_image_url, after_image_url)``.

    Raises:
        Exception: If the ``DOMAIN`` environment variable is unset or empty, or
            if URL construction otherwise fails. The underlying error is chained
            as ``__cause__``.
    """
    try:
        # Without this guard the URLs would silently read "https://None/...".
        if not DOMAIN:
            raise ValueError("DOMAIN environment variable is not set")

        base_url = f"https://{DOMAIN}/{BASE_INPUT_DIR}/{timestamp}"
        before_image_url = f"{base_url}/before.jpg"
        after_image_url = f"{base_url}/after.jpg"

        logger.info(f"Generated image URLs: before_url={before_image_url}, after_url={after_image_url}")
        return before_image_url, after_image_url
    except Exception as e:
        logger.exception(f"Error while generating image URLs for timestamp {timestamp}.")
        # Keep the original error reachable for debugging via __cause__.
        raise Exception(f"Error while generating image URLs: {str(e)}") from e
 
   
async def generate_output_image_url(output_image_path: str):
    """Return the public URL ``https://{DOMAIN}/{BASE_OUTPUT_DIR}/{output_image_path}``.

    Args:
        output_image_path: Path appended verbatim after ``BASE_OUTPUT_DIR``.
            Presumably it is relative to that directory; confirm with callers.

    Returns:
        str: The generated URL.

    Raises:
        Exception: If the ``DOMAIN`` environment variable is unset or empty, or
            if URL construction otherwise fails. The underlying error is chained
            as ``__cause__``.
    """
    try:
        # Without this guard the URL would silently read "https://None/...".
        if not DOMAIN:
            raise ValueError("DOMAIN environment variable is not set")

        output_url = f"https://{DOMAIN}/{BASE_OUTPUT_DIR}/{output_image_path}"

        logger.info(f"Generated output image URL: {output_url}")
        return output_url
    except Exception as e:
        logger.exception(f"Error while generating output image URL for {output_image_path}.")
        # Keep the original error reachable for debugging via __cause__.
        raise Exception(f"Error while generating output image URL: {str(e)}") from e
   
async def encode_image(image_path: str) -> str:
    """Read the file at *image_path* and return its bytes as base64 text.

    The existence check runs on the calling coroutine. The file read itself is
    offloaded to a worker thread with ``asyncio.to_thread``.

    Raises:
        FileNotFoundError: If ``os.path.exists`` reports nothing at *image_path*.
    """
    missing = not os.path.exists(image_path)
    if missing:
        raise FileNotFoundError(f"Image not found: {image_path}")

    def _slurp() -> bytes:
        # Runs in a worker thread; binary mode so the raw bytes are returned.
        with open(image_path, "rb") as handle:
            return handle.read()

    raw = await asyncio.to_thread(_slurp)
    encoded = base64.b64encode(raw)
    return encoded.decode("utf-8")
 
def preprocess_caption(caption: str) -> str:
    """Normalize a caption: lowercase it, trim surrounding whitespace, and end it with a period.

    A caption that already ends with "." (after trimming) gets no extra period.
    An empty or whitespace-only caption becomes ".".
    """
    normalized = caption.lower().strip()
    if not normalized.endswith("."):
        normalized = normalized + "."
    return normalized
 
def _to_list(values):
    """Return *values* as a plain Python list, using ``.tolist()`` when the object provides it."""
    # Tensors and NumPy arrays expose ``.tolist()``. Plain sequences (e.g. a
    # list of string labels) do not, so they fall back to ``list()``.
    return values.tolist() if hasattr(values, "tolist") else list(values)


def get_label_box_dict(predictions, score_threshold=0.2, id2label=None):
    """Group detections whose score clears a threshold by their label.

    Args:
        predictions: A dict with ``"scores"``, ``"labels"`` and ``"boxes"``
            entries. Each entry is either an object exposing ``.tolist()``
            (e.g. a tensor or NumPy array) or a plain sequence. Any other
            input yields an empty result.
        score_threshold: Detections with ``score >= score_threshold`` are kept.
        id2label: Optional mapping from label id to label name. Ids missing
            from the mapping, or all ids when no mapping is given, fall back
            to ``str(label_id)``.

    Returns:
        defaultdict(list): Maps each label to a list of ``(box, score)`` tuples,
        in input order.
    """
    label_dict = defaultdict(list)

    if not isinstance(predictions, dict) or not all(key in predictions for key in ["scores", "labels", "boxes"]):
        # Unrecognized prediction format: return an empty grouping.
        return label_dict

    # Treat a missing mapping as empty so the documented ``None`` default works
    # instead of failing on ``None.get``.
    mapping = {} if id2label is None else id2label

    scores = _to_list(predictions["scores"])
    label_ids = _to_list(predictions["labels"])
    boxes = _to_list(predictions["boxes"])

    for score, label_id, box in zip(scores, label_ids, boxes):
        if score >= score_threshold:
            label = mapping.get(label_id, str(label_id))
            label_dict[label].append((box, score))

    return label_dict
 
def compute_iou(box1, box2):
    """Return the Intersection-over-Union of two boxes given as (x1, y1, x2, y2).

    Each box's area is computed as (x2 - x1) * (y2 - y1). Overlap width and
    height are clamped at 0, so disjoint boxes yield 0. If the union area is
    exactly zero (e.g. both boxes are degenerate), 0 is returned instead of
    dividing by zero.
    """
    ax1, ay1, ax2, ay2 = box1
    bx1, by1, bx2, by2 = box2

    overlap_w = max(0, min(ax2, bx2) - max(ax1, bx1))
    overlap_h = max(0, min(ay2, by2) - max(ay1, by1))
    inter = overlap_w * overlap_h

    area_a = (ax2 - ax1) * (ay2 - ay1)
    area_b = (bx2 - bx1) * (by2 - by1)
    union = area_a + area_b - inter

    if union == 0:
        return 0
    return inter / union
 
 
 
def image_to_base64(image: "Image.Image", image_format: str = "PNG") -> str:
    """Serialize a PIL image in memory and return the encoded bytes as base64 text.

    Args:
        image: The PIL image to encode. It must provide ``save(fp, format=...)``.
        image_format: Format name forwarded to ``image.save``. The default
            ``"PNG"`` matches the previous hard-coded behavior.

    Returns:
        str: The serialized image, base64-encoded and decoded as UTF-8 text.
    """
    # The annotation is ``Image.Image`` (the class); the bare ``Image`` name
    # refers to the ``PIL.Image`` module. Quoted so the hint is not evaluated
    # at definition time.
    with BytesIO() as buffered:
        image.save(buffered, format=image_format)
        return base64.b64encode(buffered.getvalue()).decode("utf-8")
