"""qc.py

Per-well quality control for Seahorse XF plates.

Implements:
- Basal OCR / ECAR CV% per group
- Oligomycin response (>=30% OCR drop)
- FCCP response (>=50% OCR increase)
- Rotenone/Antimycin response (>=80% OCR drop from basal)
- Baseline drift (linear regression slope significance)
- IQR outlier + Mahalanobis distance outlier
- Per-well auto flag: Pass / Borderline / Excludable
"""

from __future__ import annotations

from dataclasses import dataclass, fields
from typing import Dict, List, Optional, Tuple

import numpy as np
import pandas as pd
from scipy import stats


# ---------------------------------------------------------------------------
# helpers
# ---------------------------------------------------------------------------


def _phase_boundaries(df_well: pd.DataFrame) -> Dict[str, List[int]]:
    """Group measurement indices by injection phase using the 'injection' col."""
    phases: Dict[str, List[int]] = {}
    if "injection" not in df_well.columns:
        # fallback: split into 3 equal phases
        meas = sorted(df_well["measurement"].unique())
        n = len(meas) // 4 or 1
        phases["baseline"] = meas[:n]
        phases["phase1"] = meas[n : 2 * n]
        phases["phase2"] = meas[2 * n : 3 * n]
        phases["phase3"] = meas[3 * n :]
        return phases
    for _, row in df_well.iterrows():
        key = str(row.get("injection", "")).strip().lower() or "baseline"
        phases.setdefault(key, []).append(int(row["measurement"]))
    return phases


def _mean(arr: List[float]) -> float:
    a = np.asarray(arr, dtype=float)
    a = a[~np.isnan(a)]
    return float(a.mean()) if len(a) else float("nan")


# ---------------------------------------------------------------------------
# per-well QC
# ---------------------------------------------------------------------------


@dataclass
class WellQC:
    """QC metrics and auto flag for one well, as produced by ``run_well_qc``.

    ``run_well_qc`` fills the per-well metrics in a first pass. A second pass
    sets the outlier fields, ``flag`` and ``reasons``. Each instance then
    becomes a DataFrame row via its ``__dict__``.
    """

    well: str
    group: str  # "default" when no group value is available for the well
    cell_type: str  # "" when unavailable
    drug: str  # "" when unavailable
    basal_ocr: float  # mean baseline OCR, NaNs ignored
    basal_ecar: float  # mean baseline ECAR, NaNs ignored
    basal_cv_ocr: float  # baseline OCR SD / mean, in percent (0.0 if not computable)
    oligo_drop_pct: float  # % OCR decrease from basal in the oligomycin phase; NaN if absent
    fccp_rise_pct: float  # % OCR increase over basal in the FCCP/BAM15 phase; NaN if absent
    rotaa_drop_pct: float  # % OCR decrease from basal in the Rot/AA phase; NaN if absent
    drift_slope: float  # OLS slope of baseline OCR vs time_min (or measurement index)
    drift_p: float  # p-value of drift_slope (1.0 when the fit is skipped)
    is_outlier_iqr: bool  # basal OCR outside the group's 1.5*IQR fences
    mahalanobis_d2: float  # squared Mahalanobis distance of (basal OCR, basal ECAR) within the group
    is_outlier_maha: bool  # mahalanobis_d2 above run_well_qc's chi-square cutoff
    flag: str          # "Pass" | "Borderline" | "Excludable"
    reasons: str  # "; "-joined reasons behind the flag ("" when none)


def _drift_check(times: np.ndarray, vals: np.ndarray) -> Tuple[float, float]:
    """Linear regression slope significance on basal phase OCR."""
    if len(vals) < 3 or np.all(np.isnan(vals)):
        return 0.0, 1.0
    mask = ~np.isnan(vals) & ~np.isnan(times)
    if mask.sum() < 3:
        return 0.0, 1.0
    res = stats.linregress(times[mask], vals[mask])
    return float(res.slope), float(res.pvalue)


def _iqr_outlier(values: np.ndarray, x: float, k: float = 1.5) -> bool:
    if len(values) < 4 or np.isnan(x):
        return False
    q1, q3 = np.nanpercentile(values, [25, 75])
    iqr = q3 - q1
    return bool((x < q1 - k * iqr) or (x > q3 + k * iqr))


def _mahalanobis(x: np.ndarray, X: np.ndarray) -> float:
    """Mahalanobis distance^2 from a 2D point x against a set X (n x 2)."""
    if X.shape[0] < 3:
        return 0.0
    try:
        cov = np.cov(X, rowvar=False)
        if np.linalg.matrix_rank(cov) < cov.shape[0]:
            cov = cov + np.eye(cov.shape[0]) * 1e-6
        inv = np.linalg.inv(cov)
        mu = X.mean(axis=0)
        d = x - mu
        return float(d @ inv @ d.T)
    except Exception:
        return 0.0


# Mahalanobis cutoff for 2-D points (basal OCR, basal ECAR): the 95th
# percentile of the chi-square distribution with 2 degrees of freedom
# (~5.991), i.e. p < 0.05.
_MAHA_CHI2_CUTOFF = float(stats.chi2.ppf(0.95, df=2))


def _first_non_null(sub: pd.DataFrame, col: str, default: str) -> str:
    """Return the first non-null value of column *col* as ``str``.

    Falls back to *default* when the column is missing or entirely null.
    """
    if col not in sub.columns:
        return default
    present = sub[col].dropna()
    return str(present.iloc[0]) if len(present) else default


def _phase_ocr_mean(
    sub: pd.DataFrame, phases: Dict[str, List[int]], name_keys: List[str]
) -> float:
    """Mean OCR of the first phase whose key contains one of *name_keys*.

    Keys are tried in the given order. For each key, the phases are scanned
    in insertion order and the first substring match wins. Returns NaN if
    nothing matches.
    """
    for key in name_keys:
        for phase_key, meas_idx in phases.items():
            if key in phase_key:
                return _mean(sub.loc[sub["measurement"].isin(meas_idx), "ocr"].tolist())
    return float("nan")


def _pct_change_from_basal(basal: float, value: float, *, decrease: bool = False) -> float:
    """Percent change of *value* relative to *basal* OCR.

    Returns ``(value - basal) / basal * 100``. When *decrease* is True it
    returns ``(basal - value) / basal * 100`` instead, so drops are positive.

    Returns NaN when *basal* is 0 or *value* is NaN. A NaN *basal* also
    propagates to a NaN result.
    """
    if not basal or np.isnan(value):
        return float("nan")
    diff = basal - value if decrease else value - basal
    return diff / basal * 100.0


def _qc_reasons(r: "WellQC", protocol: str) -> List[str]:
    """Collect the human-readable QC failures for well *r* under *protocol*."""
    reasons: List[str] = []
    # response thresholds are skipped when the phase was not found (NaN)
    if protocol in ("Mito Stress", "FAO Assay"):
        if not np.isnan(r.oligo_drop_pct) and r.oligo_drop_pct < 30:
            reasons.append(f"oligo response weak ({r.oligo_drop_pct:.1f}%<30%)")
        if not np.isnan(r.fccp_rise_pct) and r.fccp_rise_pct < 50:
            reasons.append(f"FCCP response weak ({r.fccp_rise_pct:.1f}%<50%)")
    if protocol in ("Mito Stress", "FAO Assay", "ATP Rate Assay"):
        if not np.isnan(r.rotaa_drop_pct) and r.rotaa_drop_pct < 80:
            reasons.append(f"Rot/AA response weak ({r.rotaa_drop_pct:.1f}%<80%)")
    if r.basal_cv_ocr > 20:
        reasons.append(f"basal OCR CV high ({r.basal_cv_ocr:.1f}%>20%)")
    if r.drift_p < 0.05 and abs(r.drift_slope) > 0:
        reasons.append(f"baseline drift (slope={r.drift_slope:.2f},p={r.drift_p:.3f})")
    if r.is_outlier_iqr:
        reasons.append("IQR outlier")
    if r.is_outlier_maha:
        reasons.append(f"Mahalanobis outlier (d2={r.mahalanobis_d2:.2f})")
    return reasons


def run_well_qc(plate_df: pd.DataFrame, protocol: str) -> pd.DataFrame:
    """Run per-well QC on a long-format plate table.

    Args:
        plate_df: one row per (well, measurement).
            * Required columns: ``well``, ``measurement``, ``ocr``, ``ecar``.
            * Optional: ``injection`` (phase label, see ``_phase_boundaries``)
              and ``time_min`` (x axis of the drift test; falls back to
              ``measurement``).
            * Optional metadata columns: ``group``, ``cell_type``, ``drug``.
            Non-numeric values in the numeric columns are coerced to NaN.
            The input frame is not modified.
        protocol: assay name. The oligo and FCCP checks run only for
            ``"Mito Stress"`` and ``"FAO Assay"``. The Rot/AA check also runs
            for ``"ATP Rate Assay"``.

    Returns:
        DataFrame with one row per well and one column per ``WellQC`` field.
        The columns are present even when *plate_df* contains no wells.

    Flag policy:
        * "Excludable": any outlier (IQR or Mahalanobis within the well's
          group), or at least 3 reasons.
        * "Borderline": 1-2 reasons.
        * "Pass": no reasons.
    """
    df = plate_df.copy()
    for col in ("measurement", "ocr", "ecar"):
        df[col] = pd.to_numeric(df[col], errors="coerce")
    if "time_min" in df.columns:
        # to_numpy(dtype=float) below would raise on non-numeric strings
        df["time_min"] = pd.to_numeric(df["time_min"], errors="coerce")

    # (basal OCR, basal ECAR) per well, keyed by group, for the outlier checks
    basal_by_group: Dict[str, List[Tuple[float, float]]] = {}
    rows: List[WellQC] = []

    # first pass: per-well metrics
    for well, sub in df.groupby("well"):
        sub = sub.sort_values("measurement")
        phases = _phase_boundaries(sub)

        # baseline = the phase keyed "baseline"; if absent, first 3 measurements
        baseline_idx = phases.get("baseline", [])
        if not baseline_idx:
            baseline_idx = sorted(sub["measurement"].unique())[:3]
        baseline = sub[sub["measurement"].isin(baseline_idx)]
        base_ocr = baseline["ocr"]

        basal_ocr = _mean(base_ocr.tolist())
        basal_ecar = _mean(baseline["ecar"].tolist())
        # CV% uses the population SD (np.nanstd default ddof=0)
        basal_cv = (
            float(np.nanstd(base_ocr) / np.nanmean(base_ocr) * 100.0)
            if len(baseline) > 1 and np.nanmean(base_ocr) > 0
            else 0.0
        )

        # injection responses relative to basal (NaN when the phase is absent)
        oligo_drop = _pct_change_from_basal(
            basal_ocr, _phase_ocr_mean(sub, phases, ["oligo"]), decrease=True
        )
        fccp_rise = _pct_change_from_basal(
            basal_ocr, _phase_ocr_mean(sub, phases, ["fccp", "bam15"])
        )
        rotaa_drop = _pct_change_from_basal(
            basal_ocr,
            _phase_ocr_mean(sub, phases, ["rot", "aa", "antimycin"]),
            decrease=True,
        )

        # drift on baseline OCR
        t = baseline.get("time_min", baseline["measurement"]).to_numpy(dtype=float)
        slope, pval = _drift_check(t, base_ocr.to_numpy(dtype=float))

        # first non-null value, not iloc[0]: the first row may be null
        group = _first_non_null(sub, "group", "default")
        cell_type = _first_non_null(sub, "cell_type", "")
        drug = _first_non_null(sub, "drug", "")

        basal_by_group.setdefault(group, []).append((basal_ocr, basal_ecar))

        rows.append(
            WellQC(
                well=str(well),
                group=group,
                cell_type=cell_type,
                drug=drug,
                basal_ocr=basal_ocr,
                basal_ecar=basal_ecar,
                basal_cv_ocr=basal_cv,
                oligo_drop_pct=oligo_drop,
                fccp_rise_pct=fccp_rise,
                rotaa_drop_pct=rotaa_drop,
                drift_slope=slope,
                drift_p=pval,
                is_outlier_iqr=False,
                mahalanobis_d2=0.0,
                is_outlier_maha=False,
                flag="Pass",
                reasons="",
            )
        )

    # second pass: outliers against the group's basal distribution, then flags
    for r in rows:
        group_pts = np.array(basal_by_group.get(r.group, []))
        if group_pts.size > 0:
            r.is_outlier_iqr = _iqr_outlier(group_pts[:, 0], r.basal_ocr)
            # NOTE: the well itself is part of group_pts. That caps its d2 at
            # (n - 1)^2 / n for a group of n wells, so groups of fewer than 8
            # wells can never exceed the cutoff.
            r.mahalanobis_d2 = _mahalanobis(
                np.array([r.basal_ocr, r.basal_ecar]), group_pts
            )
            r.is_outlier_maha = bool(r.mahalanobis_d2 > _MAHA_CHI2_CUTOFF)

        reasons = _qc_reasons(r, protocol)
        if r.is_outlier_iqr or r.is_outlier_maha or len(reasons) >= 3:
            r.flag = "Excludable"
        elif reasons:
            r.flag = "Borderline"
        else:
            r.flag = "Pass"
        r.reasons = "; ".join(reasons)

    # explicit columns so an empty plate still yields the expected schema
    columns = [f.name for f in fields(WellQC)]
    return pd.DataFrame([r.__dict__ for r in rows], columns=columns)


def summarize_qc(qc_df: pd.DataFrame) -> Dict[str, int]:
    """Count wells per QC flag.

    Returns a dict with keys in this order:

    * ``"Pass"``, ``"Borderline"``, ``"Excludable"`` -- always present;
    * any other flag values found in ``qc_df["flag"]``;
    * ``"total"`` -- the number of rows in *qc_df*.

    Rows with a missing flag count toward ``total`` only. All counts are
    plain Python ``int``. The key order is the same for empty and non-empty
    input.
    """
    counts: Dict[str, int] = {"Pass": 0, "Borderline": 0, "Excludable": 0}
    if not qc_df.empty:
        for flag, n in qc_df["flag"].value_counts().items():
            counts[flag] = int(n)
    counts["total"] = int(len(qc_df))
    return counts
