"""bioenergetics.py

Compute the 11 standard Seahorse bioenergetic parameters and derived
substrate-dependence / phenotype map outputs.

Parameters (per well or per group):
  Mito Stress (6):
    1. basal_ocr             = baseline_ocr - non_mito_ocr   (baseline_ocr: raw pre-injection mean)
    2. atp_linked_ocr        = baseline_ocr - oligo_ocr
    3. maximal_ocr           = fccp_ocr - non_mito_ocr
    4. spare_capacity        = maximal_ocr - basal_ocr
    5. proton_leak           = oligo_ocr - non_mito_ocr
    6. non_mito_ocr          = rotaa_ocr
  Glycolysis Stress (3):
    7. basal_ecar (glycolysis)
    8. glycolytic_capacity   = oligo_ecar - 2dg_ecar
    9. glycolytic_reserve    = glycolytic_capacity - glycolysis
  ATP Rate Assay (2):
   10. basal_atp_rate        = atp_linked + glyco-ATP proxy
   11. max_atp_rate

Substrate flexibility (Houtkooper / Mootha framework):
  - palmitate-BSA vs BSA (FAO contribution)
  - ± etomoxir (CPT1-dependent FAO)
  - ± BPTES (glutamine)
  - ± UK5099 (pyruvate)

Phenotype map quadrant: energetic / glycolytic / aerobic / quiescent.
"""

from __future__ import annotations

from dataclasses import dataclass, asdict
from typing import Dict, List, Optional, Tuple

import numpy as np
import pandas as pd
from scipy import stats


# ---------------------------------------------------------------------------
# phase extractors
# ---------------------------------------------------------------------------


def _phase_mean(df_well: pd.DataFrame, col: str, key_substrings: List[str]) -> float:
    """Mean of `col` in all measurements whose 'injection' contains any key."""
    if "injection" not in df_well.columns:
        return float("nan")
    inj = df_well["injection"].astype(str).str.lower()
    mask = pd.Series(False, index=df_well.index)
    for k in key_substrings:
        mask = mask | inj.str.contains(k, na=False)
    sub = df_well.loc[mask, col]
    sub = pd.to_numeric(sub, errors="coerce").dropna()
    return float(sub.mean()) if len(sub) else float("nan")


def _baseline_mean(df_well: pd.DataFrame, col: str) -> float:
    if "injection" in df_well.columns:
        inj = df_well["injection"].astype(str).str.lower()
        mask = inj.isin(["baseline", "", "nan"]) | inj.isna()
        sub = df_well.loc[mask, col]
    else:
        sub = df_well.sort_values("measurement").head(3)[col]
    sub = pd.to_numeric(sub, errors="coerce").dropna()
    return float(sub.mean()) if len(sub) else float("nan")


# ---------------------------------------------------------------------------
# 11 parameters
# ---------------------------------------------------------------------------


@dataclass
class WellParams:
    """Per-well record: seven identifying text fields plus 11 numeric parameters.

    Every numeric parameter defaults to NaN, so a parameter that was not
    supplied stays NaN.
    """

    # --- identifying metadata (required, positional) ---
    well: str
    group: str
    cell_type: str
    drug: str
    dose: str
    substrate: str
    protocol: str

    # --- mito stress ---
    basal_ocr: float = np.nan
    atp_linked_ocr: float = np.nan
    maximal_ocr: float = np.nan
    spare_capacity: float = np.nan
    proton_leak: float = np.nan
    non_mito_ocr: float = np.nan

    # --- glycolysis ---
    basal_ecar: float = np.nan
    glycolytic_capacity: float = np.nan
    glycolytic_reserve: float = np.nan

    # --- ATP rate ---
    basal_atp_rate: float = np.nan
    max_atp_rate: float = np.nan


def _safe(x: float) -> float:
    try:
        f = float(x)
    except Exception:
        return float("nan")
    if np.isnan(f):
        return float("nan")
    return f


def compute_well_params(df_well: pd.DataFrame, protocol: str) -> WellParams:
    """Compute the 11 parameters for a single well.

    Phase values are plain means: the baseline comes from `_baseline_mean`
    and every injection phase from `_phase_mean`, selected by substrings of
    the 'injection' label. Which parameters are filled depends on `protocol`:

    * "Mito Stress" / "FAO Assay": the six OCR parameters. Basal, maximal and
      proton-leak OCR have the rot/AA phase subtracted when it was measured
      and are left uncorrected otherwise.
    * "Glycolysis Stress": `basal_ecar` is redefined as the glucose-phase
      ECAR; it, and the oligo-phase capacity, have the 2-DG phase subtracted
      when it was measured.
    * "ATP Rate Assay": simplified proxies built from ATP-linked OCR and the
      baseline ECAR.

    For every other protocol `basal_ocr` is the uncorrected baseline and
    `basal_ecar` the baseline ECAR; parameters that do not apply are NaN.

    Args:
        df_well: Measurements of one well ('ocr', 'ecar', 'injection', plus
            optional metadata columns such as 'well', 'group', 'drug').
        protocol: Assay protocol name, stored verbatim in the result.

    Returns:
        A populated `WellParams` record.
    """

    def first_text(column: str) -> str:
        # First non-null entry as text; "" when the column is absent or empty.
        # Skipping nulls avoids returning "nan" when only the first row is blank.
        if column not in df_well.columns:
            return ""
        present = df_well[column].dropna()
        return str(present.iloc[0]) if len(present) else ""

    def minus_background(value: float, background: float) -> float:
        # Subtract the background phase when it was measured; otherwise keep
        # the uncorrected value rather than discarding the parameter.
        return _safe(value) if np.isnan(background) else _safe(value - background)

    nan = float("nan")
    twodg_keys = ["2-dg", "2dg", "deoxy"]

    # The "glucose" key is also a substring of labels like "2-deoxyglucose",
    # so 2-DG rows are removed before the glucose phase is averaged.
    glucose_rows = df_well
    if "injection" in df_well.columns:
        labels = df_well["injection"].astype(str).str.lower()
        is_twodg = pd.Series(False, index=df_well.index)
        for key in twodg_keys:
            is_twodg |= labels.str.contains(key, regex=False, na=False)
        glucose_rows = df_well.loc[~is_twodg]

    base_ocr = _baseline_mean(df_well, "ocr")
    base_ecar = _baseline_mean(df_well, "ecar")

    oligo_ocr = _phase_mean(df_well, "ocr", ["oligo"])
    oligo_ecar = _phase_mean(df_well, "ecar", ["oligo"])
    fccp_ocr = _phase_mean(df_well, "ocr", ["fccp", "bam15"])
    rotaa_ocr = _phase_mean(df_well, "ocr", ["rot", "aa", "antimycin"])
    glucose_ecar = _phase_mean(glucose_rows, "ecar", ["glucose"])
    twodg_ecar = _phase_mean(df_well, "ecar", twodg_keys)

    # mito stress / FAO -> all 6 (NaN inputs propagate through the arithmetic)
    if protocol in ("Mito Stress", "FAO Assay"):
        non_mito = _safe(rotaa_ocr)
        basal_corr = minus_background(base_ocr, non_mito)
        atp_linked = _safe(base_ocr - oligo_ocr)
        maximal = minus_background(fccp_ocr, non_mito)
        spare = _safe(maximal - basal_corr)
        proton = minus_background(oligo_ocr, non_mito)
    else:
        non_mito = nan
        basal_corr = _safe(base_ocr)
        atp_linked = _safe(base_ocr - oligo_ocr)
        maximal = nan
        spare = nan
        proton = nan

    # glycolysis stress -> 3 (basal_ecar is re-defined as glycolysis after glucose)
    if protocol == "Glycolysis Stress":
        # The ECAR remaining after 2-DG is treated as non-glycolytic acidification.
        non_glyc = _safe(twodg_ecar)
        glyc = minus_background(glucose_ecar, non_glyc)
        glyco_cap = minus_background(oligo_ecar, non_glyc)
        glyco_res = _safe(glyco_cap - glyc)
        basal_ecar_out = glyc
    else:
        basal_ecar_out = _safe(base_ecar)
        glyco_cap = nan
        glyco_res = nan

    # ATP rate assay -> basal_atp_rate, max_atp_rate
    if protocol == "ATP Rate Assay":
        # simplified: mitoATP proxy = base_ocr - oligo_ocr; glycoATP proxy ~= basal ECAR
        mito_atp = _safe(base_ocr - oligo_ocr)
        glyco_atp = _safe(base_ecar)
        basal_atp = _safe(mito_atp + glyco_atp)
        max_atp = _safe(mito_atp * 1.5 + glyco_atp * 1.2)  # rough headroom proxy
    else:
        basal_atp = nan
        max_atp = nan

    return WellParams(
        well=first_text("well"),
        group=first_text("group"),
        cell_type=first_text("cell_type"),
        drug=first_text("drug"),
        dose=first_text("dose"),
        substrate=first_text("substrate"),
        protocol=protocol,
        basal_ocr=basal_corr,
        atp_linked_ocr=atp_linked,
        maximal_ocr=maximal,
        spare_capacity=spare,
        proton_leak=proton,
        non_mito_ocr=non_mito,
        basal_ecar=basal_ecar_out,
        glycolytic_capacity=glyco_cap,
        glycolytic_reserve=glyco_res,
        basal_atp_rate=basal_atp,
        max_atp_rate=max_atp,
    )


def compute_plate_params(df_plate: pd.DataFrame, protocol: str) -> pd.DataFrame:
    """Run `compute_well_params` on every well of a plate.

    The plate is split on its 'well' column; each well's `WellParams` record
    becomes one row of the returned frame.
    """
    records = [
        asdict(compute_well_params(well_df, protocol))
        for _, well_df in df_plate.groupby("well")
    ]
    return pd.DataFrame(records)


# ---------------------------------------------------------------------------
# substrate dependence (Houtkooper / Mootha framework)
# ---------------------------------------------------------------------------


def substrate_dependence(params_df: pd.DataFrame) -> pd.DataFrame:
    """Compute fuel-flexibility per cell_type x drug pair.

    Wells are assigned to a condition by their 'substrate' label, compared
    after trimming whitespace and lower-casing. Recognised labels:
       - "BSA", "Palmitate-BSA"
       - "Glutamine", "Glutamine+BPTES"
       - "Glucose", "Glucose+UK5099"
       - "Palmitate+Etomoxir"

    For each (cell_type, drug) group the mean 'basal_ocr' of every condition
    is taken and the following differences are reported (NaN whenever one of
    the two conditions has no numeric value):
       - fao_contribution    = Palmitate-BSA - BSA
       - cpt1_dependent_fao  = Palmitate-BSA - Palmitate+Etomoxir
       - gln_dependence      = Glutamine - Glutamine+BPTES
       - pyruvate_dependence = Glucose - Glucose+UK5099
       - flexibility_index   = population std / |mean| of the available
         values among fao_contribution, gln_dependence and
         pyruvate_dependence; NaN with fewer than two values or a zero mean.

    Args:
        params_df: Per-well parameters with 'cell_type', 'drug', 'substrate'
            and 'basal_ocr' columns.

    Returns:
        One row per (cell_type, drug) group, or an empty frame when any
        required column is missing.
    """
    required = ("substrate", "cell_type", "drug", "basal_ocr")
    if any(column not in params_df.columns for column in required):
        return pd.DataFrame()

    # Normalise labels and coerce OCR once instead of once per comparison.
    labels = params_df["substrate"].astype(str).str.strip().str.lower()
    ocr = pd.to_numeric(params_df["basal_ocr"], errors="coerce")
    work = params_df[["cell_type", "drug"]].assign(
        _label=labels.to_numpy(), _ocr=ocr.to_numpy()
    )

    nan = float("nan")
    out_rows = []
    for (ct, dr), sub in work.groupby(["cell_type", "drug"]):
        # Mean OCR per condition; the mean skips NaN, and a condition with no
        # numeric value at all comes out as NaN.
        by_label = sub.groupby("_label")["_ocr"].mean()

        def condition_mean(label: str) -> float:
            return float(by_label.get(label, nan))

        bsa = condition_mean("bsa")
        palm = condition_mean("palmitate-bsa")
        palm_eto = condition_mean("palmitate+etomoxir")
        gln = condition_mean("glutamine")
        gln_bptes = condition_mean("glutamine+bptes")
        glc = condition_mean("glucose")
        glc_uk = condition_mean("glucose+uk5099")

        # contributions (delta OCR attributable to substrate); NaN propagates
        rec: Dict[str, object] = {
            "cell_type": ct,
            "drug": dr,
            "fao_contribution": palm - bsa,
            "cpt1_dependent_fao": palm - palm_eto,
            "gln_dependence": gln - gln_bptes,
            "pyruvate_dependence": glc - glc_uk,
        }

        # flexibility index = std / |mean| of [fao, gln, pyr]
        contribs = [
            x
            for x in (
                rec["fao_contribution"],
                rec["gln_dependence"],
                rec["pyruvate_dependence"],
            )
            if not np.isnan(x)
        ]
        if len(contribs) >= 2 and np.mean(contribs) != 0:
            rec["flexibility_index"] = float(np.std(contribs) / abs(np.mean(contribs)))
        else:
            rec["flexibility_index"] = nan

        out_rows.append(rec)
    return pd.DataFrame(out_rows)


# ---------------------------------------------------------------------------
# phenotype quadrant
# ---------------------------------------------------------------------------


def phenotype_quadrant(params_df: pd.DataFrame, ocr_thr: Optional[float] = None,
                       ecar_thr: Optional[float] = None) -> pd.DataFrame:
    """Classify each well into energetic / glycolytic / aerobic / quiescent.

    'basal_ocr' and 'basal_ecar' are coerced to numbers, and the same coerced
    values are used both for the default thresholds and for classification,
    so text or None entries cannot crash or mis-rank a well.

    Quadrants (a value equal to its threshold counts as high):
        high OCR, high ECAR -> "Energetic"
        low OCR,  high ECAR -> "Glycolytic"
        high OCR, low ECAR  -> "Aerobic"
        low OCR,  low ECAR  -> "Quiescent"
    Wells with a missing / non-numeric OCR or ECAR are "Unclassified".

    Args:
        params_df: Per-well parameters with 'basal_ocr' and 'basal_ecar'.
        ocr_thr: OCR threshold; defaults to the cohort median (NaN ignored).
        ecar_thr: ECAR threshold; defaults to the cohort median (NaN ignored).

    Returns:
        A copy of `params_df` with a 'phenotype' column; the thresholds used
        are stored in `attrs["ocr_threshold"]` / `attrs["ecar_threshold"]`.
        An empty input is returned as is.
    """
    if params_df.empty:
        return params_df
    df = params_df.copy()
    ocr = pd.to_numeric(df["basal_ocr"], errors="coerce")
    ecar = pd.to_numeric(df["basal_ecar"], errors="coerce")
    ocr_thr = float(np.nanmedian(ocr)) if ocr_thr is None else ocr_thr
    ecar_thr = float(np.nanmedian(ecar)) if ecar_thr is None else ecar_thr

    hi_o = (ocr >= ocr_thr).to_numpy()
    hi_e = (ecar >= ecar_thr).to_numpy()
    missing = (ocr.isna() | ecar.isna()).to_numpy()

    # Start from the low/low quadrant and overwrite; "Unclassified" is applied
    # last so it wins over the NaN comparisons, which all evaluate to False.
    phenotype = np.full(len(df), "Quiescent", dtype=object)
    phenotype[hi_o & hi_e] = "Energetic"
    phenotype[~hi_o & hi_e] = "Glycolytic"
    phenotype[hi_o & ~hi_e] = "Aerobic"
    phenotype[missing] = "Unclassified"

    df["phenotype"] = phenotype
    df.attrs["ocr_threshold"] = ocr_thr
    df.attrs["ecar_threshold"] = ecar_thr
    return df


# ---------------------------------------------------------------------------
# cohort statistics
# ---------------------------------------------------------------------------


def anova_by_group(
    params_df: pd.DataFrame, value_col: str, group_col: str = "drug"
) -> Dict[str, object]:
    """One-way ANOVA across groups; returns F, p, group means, n.

    Values are coerced to numbers, rows with a missing value or group are
    dropped, and only groups with at least two remaining observations enter
    the test.

    Args:
        params_df: Per-well parameters.
        value_col: Column holding the measurement to compare.
        group_col: Column defining the groups.

    Returns:
        On success ``{"ok": True, "F", "p", "means", "n_per_group"}``, the
        last two keyed by group name. Otherwise ``{"ok": False, "reason"}``
        when the frame is empty, a column is missing, no value is numeric, or
        fewer than two groups qualify.
    """
    if params_df.empty or group_col not in params_df.columns:
        return {"ok": False, "reason": "empty or missing group col"}
    if value_col not in params_df.columns:
        return {"ok": False, "reason": "missing value col"}

    df = params_df.copy()
    df[value_col] = pd.to_numeric(df[value_col], errors="coerce")
    df = df.dropna(subset=[value_col, group_col])
    if df.empty:
        return {"ok": False, "reason": "no valid values"}

    # Group once; the dict keeps names and samples aligned in groupby order.
    samples = {
        name: grp[value_col].values
        for name, grp in df.groupby(group_col)
        if len(grp) >= 2
    }
    if len(samples) < 2:
        return {"ok": False, "reason": "need >=2 groups"}

    F, p = stats.f_oneway(*samples.values())
    return {
        "ok": True,
        "F": float(F),
        "p": float(p),
        "means": {name: float(np.mean(vals)) for name, vals in samples.items()},
        "n_per_group": {name: int(len(vals)) for name, vals in samples.items()},
    }


def tukey_hsd_fallback(
    params_df: pd.DataFrame, value_col: str, group_col: str = "drug"
) -> pd.DataFrame:
    """Tukey HSD pairwise comparison with graceful fallbacks.

    Backends are tried in order and the first one that succeeds is used:
      1. statsmodels ``pairwise_tukeyhsd`` (optional dependency),
      2. ``scipy.stats.tukey_hsd``,
      3. Welch t-tests with a Bonferroni correction (last resort).

    All backends see the same data: values coerced to numbers, rows with a
    missing value or group dropped, groups with fewer than two observations
    excluded, and group names converted to text. ``mean_diff`` is always
    computed here as mean(group1) - mean(group2), independent of the backend.

    Args:
        params_df: Per-well parameters.
        value_col: Column holding the measurement to compare.
        group_col: Column defining the groups.

    Returns:
        A long-form DataFrame (group1, group2, mean_diff, p_adj) with one row
        per pair of groups, in groupby order. It is empty (with those
        columns) when a column is missing or fewer than two groups qualify.
    """
    columns = ["group1", "group2", "mean_diff", "p_adj"]
    if value_col not in params_df.columns or group_col not in params_df.columns:
        return pd.DataFrame(columns=columns)

    df = params_df.copy()
    df[value_col] = pd.to_numeric(df[value_col], errors="coerce")
    df = df.dropna(subset=[value_col, group_col])

    samples = {
        str(name): grp[value_col].to_numpy(dtype=float)
        for name, grp in df.groupby(group_col)
        if len(grp) >= 2
    }
    names = list(samples)
    if len(names) < 2:
        return pd.DataFrame(columns=columns)

    means = [float(np.mean(samples[name])) for name in names]
    first, second = np.triu_indices(len(names), 1)
    pairs = list(zip(first.tolist(), second.tolist()))

    def build(pvalue_of) -> pd.DataFrame:
        # Assemble the documented schema from a (i, j) -> p-value callable.
        rows = [
            {
                "group1": names[i],
                "group2": names[j],
                "mean_diff": means[i] - means[j],
                "p_adj": float(pvalue_of(i, j)),
            }
            for i, j in pairs
        ]
        return pd.DataFrame(rows, columns=columns)

    # 1) statsmodels, when installed
    try:
        from statsmodels.stats.multicomp import pairwise_tukeyhsd  # type: ignore
    except ImportError:
        pairwise_tukeyhsd = None
    if pairwise_tukeyhsd is not None:
        try:
            values = np.concatenate([samples[name] for name in names])
            labels = np.repeat(names, [len(samples[name]) for name in names])
            res = pairwise_tukeyhsd(values, labels)
            # NOTE(review): assumes `res.pvalues` follows the upper-triangle
            # pair order of `res.groupsunique` - confirm against the
            # installed statsmodels version.
            sm_names = [str(name) for name in res.groupsunique]
            a, b = np.triu_indices(len(sm_names), 1)
            by_pair = {
                frozenset((sm_names[i], sm_names[j])): float(p)
                for i, j, p in zip(a.tolist(), b.tolist(), np.asarray(res.pvalues))
            }
            return build(lambda i, j: by_pair[frozenset((names[i], names[j]))])
        except Exception:
            # Deliberate best effort: any statsmodels problem (old version
            # without `pvalues`, numerical failure) falls through to scipy.
            pass

    # 2) scipy
    try:
        res = stats.tukey_hsd(*[samples[name] for name in names])
        pvalues = np.asarray(res.pvalue)
        return build(lambda i, j: pvalues[i, j])
    except Exception:
        # Deliberate best effort: e.g. a scipy build without tukey_hsd.
        pass

    # 3) absolute last resort: pairwise Welch t-tests with Bonferroni
    n_tests = len(pairs)

    def bonferroni(i: int, j: int) -> float:
        _, p = stats.ttest_ind(samples[names[i]], samples[names[j]], equal_var=False)
        return min(1.0, p * n_tests)

    return build(bonferroni)


# ---------------------------------------------------------------------------
# manuscript-ready Korean summary
# ---------------------------------------------------------------------------


def korean_summary(params_df: pd.DataFrame, protocol: str) -> str:
    """Render a Korean, manuscript-style markdown summary of the parameters.

    The report contains a header with the protocol name, the number of wells,
    the distinct cell types and drugs (when those columns exist), the
    mean ± SD of every standard parameter that has at least one numeric
    value, the phenotype counts (when a 'phenotype' column exists) and a
    closing disclaimer. An empty frame yields a short "no data" placeholder.
    """
    if params_df.empty:
        return "(데이터 없음)"

    # Column name -> Korean label, in reporting order.
    labelled_params = (
        ("basal_ocr", "기저 OCR (pmol O2/min)"),
        ("atp_linked_ocr", "ATP-linked OCR"),
        ("maximal_ocr", "최대 OCR (FCCP)"),
        ("spare_capacity", "예비호흡능 (spare respiratory capacity)"),
        ("proton_leak", "양성자 누출 (proton leak)"),
        ("non_mito_ocr", "비미토콘드리아 OCR"),
        ("basal_ecar", "기저 ECAR (mpH/min)"),
        ("glycolytic_capacity", "해당 능력 (glycolytic capacity)"),
        ("glycolytic_reserve", "해당 예비능 (glycolytic reserve)"),
        ("basal_atp_rate", "기저 ATP 생성률"),
        ("max_atp_rate", "최대 ATP 생성률"),
    )

    report: List[str] = [
        f"## Bioenergetic 결과 요약 — {protocol}",
        f"- 분석 well 수: {len(params_df)}",
    ]

    # Distinct cell types / drugs, sorted alphabetically.
    for column, prefix in (("cell_type", "- 세포주: "), ("drug", "- 약물 처리: ")):
        if column in params_df.columns:
            distinct = sorted(set(params_df[column].astype(str)))
            report.append(prefix + ", ".join(distinct))

    report.extend(["", "### 표준 parameter (평균 ± SD)"])
    for column, label in labelled_params:
        if column not in params_df.columns:
            continue
        numeric = pd.to_numeric(params_df[column], errors="coerce").dropna()
        if len(numeric) == 0:
            continue
        stats_pair: Tuple[float, float] = (float(numeric.mean()), float(numeric.std()))
        avg, spread = stats_pair
        if np.isnan(avg):
            continue
        report.append(f"- {label}: {avg:.2f} ± {spread:.2f}")

    # quadrant distribution
    if "phenotype" in params_df.columns:
        report.extend(["", "### Bioenergetic 표현형 분포"])
        counts = params_df["phenotype"].value_counts()
        report.extend(f"- {name}: n={count}" for name, count in counts.items())

    report.append("")
    report.append("> 본 결과는 Mootha/Houtkooper framework에 기반한 사후 분석입니다. 임상 의사결정에 사용해서는 안 됩니다.")
    return "\n".join(report)
