"""Per-cell aggregation, LD subtype classification."""

from __future__ import annotations

import numpy as np
import pandas as pd


# Subtype thresholds (diameter, microns). The classifiers in this module use
# half-open intervals, so a diameter exactly equal to a threshold falls into
# the larger class.
MICRO_MAX = 1.0   # diameter < 1 µm -> "micro" (microsteatosis)
MEDIUM_MAX = 5.0  # 1 µm <= diameter < 5 µm -> "medium"
# diameter >= 5 µm -> "macro" (macrosteatosis)


def classify_ld_subtype(diameter_um: float) -> str:
    """Classify a single LD by diameter as "micro", "medium" or "macro".

    Intervals are half-open: ``diameter_um < MICRO_MAX`` is "micro",
    ``MICRO_MAX <= diameter_um < MEDIUM_MAX`` is "medium", and anything at
    or above ``MEDIUM_MAX`` is "macro".

    Args:
        diameter_um: LD diameter in microns.

    Returns:
        One of the strings "micro", "medium" or "macro".

    Raises:
        ValueError: If ``diameter_um`` is NaN.
    """
    # Every ordering comparison against NaN is False, so without this guard a
    # missing measurement would fall through both checks below and be
    # reported as "macro".
    if np.isnan(diameter_um):
        raise ValueError("diameter_um is NaN; cannot classify LD subtype")
    if diameter_um < MICRO_MAX:
        return "micro"
    if diameter_um < MEDIUM_MAX:
        return "medium"
    return "macro"


def add_subtype_column(df: pd.DataFrame) -> pd.DataFrame:
    """Return a copy of ``df`` with a ``subtype`` column derived from ``diameter_um``.

    Diameters are binned into the left-closed intervals
    ``[-inf, MICRO_MAX)``, ``[MICRO_MAX, MEDIUM_MAX)`` and
    ``[MEDIUM_MAX, inf)``, labelled "micro", "medium" and "macro"
    respectively. The input frame is not modified.
    """
    edges = (-np.inf, MICRO_MAX, MEDIUM_MAX, np.inf)
    names = ("micro", "medium", "macro")

    result = df.copy()
    # right=False makes each bin include its lower edge and exclude its upper
    # edge, so a diameter equal to a threshold lands in the larger class.
    subtype = pd.cut(
        result["diameter_um"],
        bins=list(edges),
        labels=list(names),
        right=False,
    )
    result["subtype"] = subtype
    return result


def per_cell_summary(df: pd.DataFrame) -> pd.DataFrame:
    """Aggregate LD objects → one row per (well, cell_id).

    Args:
        df: One row per LD. Must contain the columns ``well``, ``cell_id``,
            ``ld_id``, ``area_um2``, ``diameter_um``, ``circularity``,
            ``integrated_intensity``, ``distance_to_mito_um``, ``manders_m1``,
            ``manders_m2``, ``pearson`` and ``plin{1,2,3,5}_intensity``.
            The frame is not modified.

    Returns:
        A frame with one row per observed (well, cell_id) pair holding:

        * ``ld_count`` – number of non-null ``ld_id`` values,
        * ``total_area_um2`` – summed LD area,
        * ``mean_*`` – per-cell means of the per-LD measurements,
        * ``micro_count`` / ``medium_count`` / ``macro_count`` – integer
          number of LDs in each subtype, always present and always in this
          order,
        * ``macro_pct`` / ``medium_pct`` / ``micro_pct`` – subtype count
          divided by ``ld_count`` (a fraction, not scaled by 100); NaN when
          ``ld_count`` is 0.
    """
    subtype_labels = ("micro", "medium", "macro")

    df = add_subtype_column(df)
    # One 0/1 indicator column per subtype. Summing these inside the same
    # groupby as the other statistics keeps every cell in the result with an
    # integer count (0 when a subtype is absent, or when the LD has no
    # subtype because its diameter is missing). This avoids a separate
    # pivot + merge, which left NaN counts for cells with no classified LD
    # and made the count-column order depend on the data.
    df = df.assign(
        **{
            f"_is_{label}": (df["subtype"] == label).astype("int64")
            for label in subtype_labels
        }
    )

    grp = df.groupby(["well", "cell_id"], observed=True)
    out = grp.agg(
        ld_count=("ld_id", "count"),
        total_area_um2=("area_um2", "sum"),
        mean_diameter_um=("diameter_um", "mean"),
        mean_circularity=("circularity", "mean"),
        mean_integrated_intensity=("integrated_intensity", "mean"),
        mean_distance_to_mito_um=("distance_to_mito_um", "mean"),
        mean_manders_m1=("manders_m1", "mean"),
        mean_manders_m2=("manders_m2", "mean"),
        mean_pearson=("pearson", "mean"),
        mean_plin1=("plin1_intensity", "mean"),
        mean_plin2=("plin2_intensity", "mean"),
        mean_plin3=("plin3_intensity", "mean"),
        mean_plin5=("plin5_intensity", "mean"),
        micro_count=("_is_micro", "sum"),
        medium_count=("_is_medium", "sum"),
        macro_count=("_is_macro", "sum"),
    ).reset_index()

    # Mask a zero denominator so the fractions become NaN rather than inf;
    # an inf here would poison any downstream mean.
    denominator = out["ld_count"].where(out["ld_count"] > 0)
    for label in ("macro", "medium", "micro"):
        out[f"{label}_pct"] = out[f"{label}_count"] / denominator
    return out


def per_well_summary(
    per_cell: pd.DataFrame, plate_map: pd.DataFrame | None = None
) -> pd.DataFrame:
    """Aggregate per-cell → per-well fingerprint vector.

    Args:
        per_cell: One row per cell, with the columns produced by
            ``per_cell_summary`` (``well``, ``cell_id``, ``ld_count``,
            ``total_area_um2``, ``mean_*`` and ``*_pct``).
        plate_map: Optional per-well annotations with a ``well`` column.
            When given and non-empty it is left-joined onto the result, so
            wells missing from the map keep NaN annotations. A map with
            several rows for one well repeats that well in the output.

    Returns:
        One row per well present in ``per_cell``. ``cell_count`` is the
        number of non-null ``cell_id`` values; every other statistic is the
        unweighted mean of the corresponding per-cell value.
    """
    # observed=True matches per_cell_summary: if ``well`` is categorical,
    # categories with no cells must not appear as phantom wells with
    # cell_count 0 and NaN statistics.
    grp = per_cell.groupby("well", observed=True)
    out = grp.agg(
        cell_count=("cell_id", "count"),
        ld_count_per_cell=("ld_count", "mean"),
        total_LD_area_per_cell=("total_area_um2", "mean"),
        mean_diameter_um=("mean_diameter_um", "mean"),
        macro_pct=("macro_pct", "mean"),
        medium_pct=("medium_pct", "mean"),
        micro_pct=("micro_pct", "mean"),
        manders_m1=("mean_manders_m1", "mean"),
        manders_m2=("mean_manders_m2", "mean"),
        pearson=("mean_pearson", "mean"),
        plin1_mean=("mean_plin1", "mean"),
        plin2_mean=("mean_plin2", "mean"),
        plin3_mean=("mean_plin3", "mean"),
        plin5_mean=("mean_plin5", "mean"),
        distance_to_mito_um=("mean_distance_to_mito_um", "mean"),
    ).reset_index()

    if plate_map is not None and not plate_map.empty:
        out = out.merge(plate_map, on="well", how="left")
    return out
