from collections import defaultdict
from datetime import datetime, timedelta
from sqlalchemy.orm import Session
from app.models.entities import Product, Retailer


def dashboard(db: Session) -> dict:
    """Summarise catalogue size and field coverage for the dashboard view.

    Args:
        db: Active SQLAlchemy session.

    Returns:
        A dict with:
        - ``total_products``: number of Product rows.
        - ``retailers_monitored``: number of Retailer rows.
        - ``latest_capture_date``: the most recent ``date_captured`` as an
          ISO-8601 string, or ``None`` when no product has a capture date.
        - ``price_coverage_pct`` / ``barcode_coverage_pct`` /
          ``category_coverage_pct``: share of products that have a price
          (regular or promo), a barcode/GTIN, and a level-1 category.
        - ``data_quality_score``: the average of those three (unrounded)
          coverage ratios, as a percentage.
        Percentages are rounded to one decimal place and are ``0`` when
        there are no products.
    """
    products = db.query(Product).all()
    # Only the number of retailers is reported, so let the database count
    # them instead of materialising every Retailer row.
    retailers_monitored = db.query(Retailer).count()
    total = len(products)

    # Either a regular price or a promo price counts as "priced".
    with_price = sum(1 for p in products if p.price is not None or p.promo_price is not None)
    with_barcode = sum(1 for p in products if p.barcode_or_gtin)
    with_category = sum(1 for p in products if p.category_level_1)

    # Skip rows without a capture date: max() cannot compare None with a datetime.
    latest = max(
        (p.date_captured for p in products if p.date_captured is not None),
        default=None,
    )

    def _pct(count: int) -> float:
        # Share of all products as a percentage; 0 for an empty catalogue.
        return round((count / total) * 100, 1) if total else 0

    # Combined score over all three checks, computed from the raw counts
    # (not from the rounded per-field percentages).
    quality_score = 0 if total == 0 else round(((with_price + with_category + with_barcode) / (total * 3)) * 100, 1)

    return {
        "total_products": total,
        "retailers_monitored": retailers_monitored,
        "latest_capture_date": latest.isoformat() if latest is not None else None,
        "price_coverage_pct": _pct(with_price),
        "barcode_coverage_pct": _pct(with_barcode),
        "category_coverage_pct": _pct(with_category),
        "data_quality_score": quality_score,
    }


def data_quality(db: Session) -> dict:
    """Report missing fields, stale rows and duplicate groups across all products.

    Args:
        db: Active SQLAlchemy session.

    Returns:
        A dict with:
        - ``total_products``: number of Product rows.
        - ``missing_*``: counts of products lacking a price (neither regular
          nor promo), pack size, level-1 category, image URL, or barcode/GTIN.
        - ``stale_products``: products captured more than 30 days ago.
          Products with no capture date are not counted.
        - ``duplicate_groups``: number of (retailer, canonical_key) pairs
          shared by more than one product. Products without a canonical key
          are ignored.
        - ``category_coverage_by_retailer``: distinct level-1 categories per
          retailer. Only retailers with at least one categorised product appear.
    """
    products = db.query(Product).all()
    total = len(products)
    # NOTE(review): naive UTC cutoff. This assumes date_captured is stored as a
    # naive UTC datetime; comparing it with a tz-aware value would raise TypeError.
    # Confirm against the model before switching to datetime.now(timezone.utc).
    stale_cutoff = datetime.utcnow() - timedelta(days=30)

    duplicates = defaultdict(int)
    retailer_category = defaultdict(set)
    for p in products:
        # Products without a canonical key cannot be matched to anything. Counting
        # them would lump them into a spurious (retailer, None) "duplicate" group.
        # This mirrors product_matches, which skips them as well.
        if p.canonical_key:
            duplicates[(p.retailer_name, p.canonical_key)] += 1
        if p.category_level_1:
            retailer_category[p.retailer_name].add(p.category_level_1)

    return {
        "total_products": total,
        "missing_price": sum(1 for p in products if p.price is None and p.promo_price is None),
        "missing_pack_size": sum(1 for p in products if not p.pack_size),
        "missing_category": sum(1 for p in products if not p.category_level_1),
        "missing_image": sum(1 for p in products if not p.image_url),
        "missing_barcode": sum(1 for p in products if not p.barcode_or_gtin),
        # Rows with no capture date cannot be aged; skip them instead of raising TypeError.
        "stale_products": sum(
            1 for p in products
            if p.date_captured is not None and p.date_captured < stale_cutoff
        ),
        "duplicate_groups": sum(1 for count in duplicates.values() if count > 1),
        "category_coverage_by_retailer": {k: len(v) for k, v in retailer_category.items()},
    }


def product_matches(db: Session) -> list[dict]:
    """Group products that share a canonical key across two or more retailers.

    Args:
        db: Active SQLAlchemy session.

    Returns:
        One dict per matched group with:
        - ``canonical_key``: the shared key.
        - ``confidence_score``: 0.65 plus 0.05 per distinct retailer, capped at 0.95.
        - ``retailers``: sorted distinct retailer names in the group.
        - ``products``: a summary of every product in the group.
        Products without a canonical key are never matched. Keys seen at only
        one retailer produce no group.
    """
    buckets = defaultdict(list)
    for p in db.query(Product).all():
        # Products without a canonical key have nothing to be matched on.
        if p.canonical_key:
            buckets[p.canonical_key].append(p)

    groups = []
    for key, items in buckets.items():
        retailers = sorted({i.retailer_name for i in items})
        # Two distinct retailers implies at least two items, so this one check
        # also covers the "more than one item" requirement.
        if len(retailers) < 2:
            continue
        # Work in integer hundredths so the reported score is the clean decimal.
        # Adding floats directly leaks noise, e.g. 4 retailers -> 0.8500000000000001.
        confidence = min(95, 65 + 5 * len(retailers)) / 100
        groups.append({
            "canonical_key": key,
            "confidence_score": confidence,
            "retailers": retailers,
            "products": [
                {
                    "id": i.id,
                    "retailer_name": i.retailer_name,
                    "product_name": i.product_name,
                    "brand": i.brand,
                    "pack_size": i.pack_size,
                    "price": i.price,
                    "promo_price": i.promo_price,
                }
                for i in items
            ],
        })
    return groups
