"""Core analytics for BodyCompMouseDXA.

All inputs are pandas DataFrames following the canonical schema:
    animal_id, model, group, sex, time_point_wk, day_label,
    body_weight_g, fat_g, lean_g, water_g, fat_pct, lean_pct,
    BMD, BMC, visceral_fat_proxy_g, appendicular_lean_g,
    drug_phase, source_format

Heavy dependencies (statsmodels, plotly, pydicom) are optional. The
module falls back to scipy-only implementations where possible.

This file does no plotting itself; it returns DataFrames / dicts. The
Streamlit layer in `main.py` is responsible for rendering.
"""

from __future__ import annotations

import io
import math
import os
from dataclasses import dataclass, field
from typing import Dict, Iterable, List, Optional, Tuple

import numpy as np
import pandas as pd
from scipy import stats

try:
    import statsmodels.api as sm  # type: ignore
    import statsmodels.formula.api as smf  # type: ignore
    HAS_STATSMODELS = True
except ImportError:
    sm = None
    smf = None
    HAS_STATSMODELS = False

try:
    import pydicom  # type: ignore
    HAS_PYDICOM = True
except ImportError:
    pydicom = None
    HAS_PYDICOM = False


# Canonical column order shared by every loader; `_ensure_schema` projects an
# ingested frame onto exactly these columns, in this order.
SCHEMA_COLUMNS = [
    # identity / study design
    "animal_id",
    "model",
    "group",
    "sex",
    "time_point_wk",
    "day_label",
    # whole-body composition
    "body_weight_g",
    "fat_g",
    "lean_g",
    "water_g",
    "fat_pct",
    "lean_pct",
    # bone and regional compartments
    "BMD",
    "BMC",
    "visceral_fat_proxy_g",
    "appendicular_lean_g",
    # treatment phase label and ingest provenance
    "drug_phase",
    "source_format",
]


# ============================================================================
#  Multi-format ingest (feature 1)
# ============================================================================

def _ensure_schema(df: pd.DataFrame) -> pd.DataFrame:
    """Return a copy of *df* projected onto the canonical SCHEMA_COLUMNS.

    Absent text columns are created as empty strings and absent numeric
    columns as NaN. Where body weight is positive, missing percentages are
    back-filled from gram values (rounded to 2 dp), then missing gram values
    from percentages (rounded to 3 dp), for both fat and lean. Columns that
    are not in SCHEMA_COLUMNS are dropped.
    """
    text_columns = ("group", "sex", "day_label", "model", "drug_phase",
                    "source_format", "animal_id")
    out = df.copy()
    for name in SCHEMA_COLUMNS:
        if name not in out.columns:
            out[name] = "" if name in text_columns else np.nan

    def _num(column: str) -> pd.Series:
        return pd.to_numeric(out[column], errors="coerce")

    # Take the numeric snapshots once, before any back-filling, so the
    # percentage pass and the gram pass never read each other's output.
    weight = _num("body_weight_g")
    has_weight = weight.gt(0)
    pairs = [
        ("fat_g", "fat_pct", _num("fat_g"), _num("fat_pct")),
        ("lean_g", "lean_pct", _num("lean_g"), _num("lean_pct")),
    ]
    for _, pct_name, grams, pct in pairs:
        pick = pct.isna() & has_weight & grams.notna()
        out.loc[pick, pct_name] = (grams[pick] / weight[pick] * 100.0).round(2)
    for grams_name, _, grams, pct in pairs:
        pick = grams.isna() & has_weight & pct.notna()
        out.loc[pick, grams_name] = (weight[pick] * pct[pick] / 100.0).round(3)
    return out[SCHEMA_COLUMNS]


def load_piximus_dicom(path_or_dir: str) -> pd.DataFrame:
    """Load a PIXImus DICOM export (one file or a directory) into the schema.

    Each ``.dcm``/``.dicom`` file becomes one row. Metadata comes from the
    DICOM attributes PatientID, PatientSpeciesDescription, PatientSex,
    StudyDescription (week parsed by `_parse_time_point`), StudyDate and
    PatientWeight (kg, converted to grams). BMD/BMC come from dataset
    attributes named ``BMD``/``BMC`` when present. Fat, lean and water are
    left NaN.

    Files in a directory are read in sorted filename order so the row order
    is deterministic. Files that pydicom cannot parse are skipped. A missing
    PatientSex is left empty rather than assumed, and an empty PatientID
    falls back to the file name.

    Raises:
        RuntimeError: if pydicom is not installed; use
            `load_dxa_csv_export` on a pre-decoded CSV instead.
    """
    if not HAS_PYDICOM:
        raise RuntimeError(
            "pydicom is not installed. Export PIXImus to CSV and use "
            "load_dxa_csv_export() instead."
        )
    if os.path.isdir(path_or_dir):
        files = [os.path.join(path_or_dir, f) for f in sorted(os.listdir(path_or_dir))
                 if f.lower().endswith((".dcm", ".dicom"))]
    else:
        files = [path_or_dir]
    rows: List[Dict] = []
    for f in files:
        try:
            ds = pydicom.dcmread(f, stop_before_pixels=True)
        except Exception:  # pragma: no cover - best effort: skip unreadable files
            continue
        weight_kg = _safe_float(getattr(ds, "PatientWeight", None))
        rows.append({
            "animal_id": str(getattr(ds, "PatientID", "") or os.path.basename(f)),
            "model": str(getattr(ds, "PatientSpeciesDescription", "")),
            "group": "",
            # Do not invent a sex when the tag is absent.
            "sex": str(getattr(ds, "PatientSex", "")),
            "time_point_wk": _parse_time_point(getattr(ds, "StudyDescription", "")),
            "day_label": str(getattr(ds, "StudyDate", "")),
            # kg -> g; zero, negative or unparsable weights become NaN.
            "body_weight_g": weight_kg * 1000.0 if weight_kg > 0 else np.nan,
            "fat_g": np.nan,
            "lean_g": np.nan,
            "water_g": np.nan,
            "fat_pct": np.nan,
            "lean_pct": np.nan,
            "BMD": _safe_float(_get_private(ds, "BMD")),
            "BMC": _safe_float(_get_private(ds, "BMC")),
            "visceral_fat_proxy_g": np.nan,
            "appendicular_lean_g": np.nan,
            "drug_phase": "",
            "source_format": "PIXImus_DICOM",
        })
    return _ensure_schema(pd.DataFrame(rows))


def _safe_float(x) -> float:
    try:
        return float(x)
    except (TypeError, ValueError):
        return float("nan")


def _get_private(ds, tag_name: str):
    try:
        return getattr(ds, tag_name)
    except AttributeError:
        return None


def _parse_time_point(s: str) -> float:
    s = str(s).lower()
    for tok in s.replace("_", " ").split():
        if tok.endswith("wk"):
            try:
                return float(tok[:-2])
            except ValueError:
                pass
    return np.nan


def load_echomri_csv(path_or_buffer) -> pd.DataFrame:
    """Load an EchoMRI ``.csv`` export (whole-body composition).

    Recognized headers (surrounding whitespace ignored): ``Label`` ->
    animal_id, ``Weight`` -> body_weight_g, ``Fat`` -> fat_g, ``Lean`` ->
    lean_g, and one water column -> water_g. ``Total Water`` is preferred;
    ``Free Water`` is used only when ``Total Water`` is absent, so the
    result never carries two ``water_g`` columns. ``source_format``
    defaults to ``"EchoMRI_csv"`` when the file does not provide it.
    """
    df = pd.read_csv(path_or_buffer)
    df.columns = [str(c).strip() for c in df.columns]
    rename = {
        "Label": "animal_id",
        "Weight": "body_weight_g",
        "Fat": "fat_g",
        "Lean": "lean_g",
    }
    # Both water headers target water_g; renaming both would create duplicate
    # column labels, so map exactly one of them (Total Water first).
    for water_col in ("Total Water", "Free Water"):
        if water_col in df.columns:
            rename[water_col] = "water_g"
            break
    df = df.rename(columns={k: v for k, v in rename.items() if k in df.columns})
    if "source_format" not in df.columns:
        df["source_format"] = "EchoMRI_csv"
    return _ensure_schema(df)


def load_qnmr_txt(path_or_buffer) -> pd.DataFrame:
    """Load a Bruker minispec qNMR ``.txt`` export into the canonical schema.

    The delimiter is sniffed by pandas' python engine (``sep=None``); byte
    buffers are decoded as UTF-8 with undecodable bytes dropped.

    Header aliases, in priority order:
    SampleID / Sample ID -> animal_id, Weight(g) / Weight -> body_weight_g,
    Fat(%) / Fat% -> fat_pct, Lean(%) / Lean% -> lean_pct,
    FreeFluid(%) -> water_pct. When several aliases for one target are
    present, only the first is mapped (renaming all of them would create
    duplicate column labels).

    The canonical schema has no water-percentage column, so ``water_pct`` is
    converted to ``water_g`` (body weight x % / 100, 3 dp; NaN when weight
    is missing or <= 0) instead of being silently dropped. This mirrors the
    EchoMRI loader's Free Water fallback for ``water_g``. Fat/lean grams are
    derived from the percentages by `_ensure_schema`.
    """
    if hasattr(path_or_buffer, "read"):
        raw = path_or_buffer.read()
        if isinstance(raw, bytes):
            raw = raw.decode("utf-8", errors="ignore")
        df = pd.read_csv(io.StringIO(raw), sep=None, engine="python")
    else:
        df = pd.read_csv(path_or_buffer, sep=None, engine="python")
    df.columns = [str(c).strip() for c in df.columns]

    aliases = (
        ("SampleID", "animal_id"),
        ("Sample ID", "animal_id"),
        ("Weight(g)", "body_weight_g"),
        ("Weight", "body_weight_g"),
        ("Fat(%)", "fat_pct"),
        ("Fat%", "fat_pct"),
        ("Lean(%)", "lean_pct"),
        ("Lean%", "lean_pct"),
        ("FreeFluid(%)", "water_pct"),
    )
    rename: Dict[str, str] = {}
    taken = set(df.columns)
    for src, dst in aliases:
        if src in df.columns and dst not in taken:
            rename[src] = dst
            taken.add(dst)
    df = df.rename(columns=rename)

    if "water_pct" in df.columns and "water_g" not in df.columns:
        if "body_weight_g" in df.columns:
            weight = pd.to_numeric(df["body_weight_g"], errors="coerce")
        else:
            weight = pd.Series(np.nan, index=df.index)
        water_pct = pd.to_numeric(df["water_pct"], errors="coerce")
        df["water_g"] = (weight.where(weight > 0) * water_pct / 100.0).round(3)

    df["source_format"] = "qNMR_Bruker_minispec_txt"
    return _ensure_schema(df)


def load_skyscan_csv(path_or_buffer,
                     fat_density_g_per_ml: float = 0.92,
                     lean_density_g_per_ml: float = 1.06) -> pd.DataFrame:
    """Load a pre-aggregated SkyScan microCT CSV into the canonical schema.

    SkyScan exports per-slice volumes; this loader expects them already
    aggregated per scan. Recognized columns (any canonical columns such as
    animal_id / time_point_wk pass through `_ensure_schema`):

    - ``total_fat_volume_mm3``    -> ``fat_g``
    - ``visceral_fat_volume_mm3`` -> ``visceral_fat_proxy_g``
    - ``lean_volume_mm3``         -> ``lean_g``
    - ``cortical_BMD``            -> ``BMD`` (``trabecular_BMD`` is not
      mapped; the schema has a single BMD column)

    Volumes are converted as ``mm^3 x density / 1000`` (1 mL = 1000 mm^3).
    The density defaults are fat 0.92 g/mL and lean 1.06 g/mL. Non-numeric
    volume cells are treated as missing. ``source_format`` is always set
    to ``"SkyScan_microCT_csv"``.
    """
    df = pd.read_csv(path_or_buffer)
    df.columns = [str(c).strip() for c in df.columns]
    conversions = (
        ("total_fat_volume_mm3", "fat_g", fat_density_g_per_ml),
        ("lean_volume_mm3", "lean_g", lean_density_g_per_ml),
        ("visceral_fat_volume_mm3", "visceral_fat_proxy_g", fat_density_g_per_ml),
    )
    for volume_col, mass_col, density in conversions:
        if volume_col in df.columns:
            volume_mm3 = pd.to_numeric(df[volume_col], errors="coerce")
            df[mass_col] = volume_mm3 * density / 1000.0
    if "cortical_BMD" in df.columns:
        df["BMD"] = df["cortical_BMD"]
    df["source_format"] = "SkyScan_microCT_csv"
    return _ensure_schema(df)


def load_dxa_csv_export(path_or_buffer) -> pd.DataFrame:
    """Read a CSV that already uses the canonical column names.

    This is the interchange format for the MVP; the bundled synthetic data
    files follow it. The frame is passed through `_ensure_schema`, so absent
    columns are filled and extra columns are dropped.
    """
    return _ensure_schema(pd.read_csv(path_or_buffer))


def load_any(path: str) -> pd.DataFrame:
    """Pick a loader for *path* from its type, file name and extension.

    Dispatch order (first match wins; name checks are case-insensitive):
      1. a directory                   -> load_piximus_dicom (DICOM series)
      2. ``*.dcm`` / ``*.dicom``       -> load_piximus_dicom
      3. name contains "echomri"       -> load_echomri_csv
      4. name contains "qnmr"/"minispec" -> load_qnmr_txt
      5. name contains "skyscan"/"microct" -> load_skyscan_csv
      6. ``*.csv``                     -> load_dxa_csv_export
      7. ``*.txt``                     -> load_qnmr_txt

    Raises:
        ValueError: if none of the rules match.
    """
    # PIXImus exports arrive as a folder of DICOM files, which
    # load_piximus_dicom accepts directly.
    if os.path.isdir(path):
        return load_piximus_dicom(path)
    name = os.path.basename(path).lower()
    if name.endswith((".dcm", ".dicom")):
        return load_piximus_dicom(path)
    if "echomri" in name:
        return load_echomri_csv(path)
    if "qnmr" in name or "minispec" in name:
        return load_qnmr_txt(path)
    if "skyscan" in name or "microct" in name:
        return load_skyscan_csv(path)
    if name.endswith(".csv"):
        return load_dxa_csv_export(path)
    if name.endswith(".txt"):
        return load_qnmr_txt(path)
    raise ValueError(f"Unrecognized file type for: {path}")


# ============================================================================
#  Auto ROI / segmentation heuristic (feature 2)
# ============================================================================

def derive_roi_metrics(df: pd.DataFrame,
                       default_visceral_ratio: float = 0.40,
                       appendicular_fraction: float = 0.60) -> pd.DataFrame:
    """Apply metadata-driven ROI heuristics (no image segmentation).

    Real image segmentation is intentionally stubbed for this MVP. Instead,
    per row of a copy of *df*:

    - ``visceral_fat_proxy_g``: kept when already present and > 0; otherwise
      ``fat_g x ratio`` rounded to 3 dp. The ratio comes from the ``model``
      lookup table below (e.g. 0.40 C57BL_6J_HFD60, 0.55 ob_ob/db_db,
      0.30 C57BL_6J_chow). Unknown or missing models use
      *default_visceral_ratio*.
    - ``subcutaneous_fat_g``: ``fat_g - visceral``, clipped at 0, 3 dp.
    - ``appendicular_lean_g``: kept when present and > 0; otherwise
      ``lean_g x appendicular_fraction``, 3 dp.
    - ``muscle_compartment_g``: a copy of ``appendicular_lean_g``.

    The ratios are heuristic central tendencies, not per-animal
    measurements. Non-numeric or absent inputs are treated as NaN. The
    computation is vectorized, so an empty frame yields an empty frame
    with the derived columns.
    """
    out = df.copy()

    vis_ratio_map = {
        "C57BL_6J_HFD60": 0.40,
        "GAN_DIO_NASH": 0.42,
        "CDAA_HFD": 0.38,
        "ob_ob": 0.55,
        "db_db": 0.55,
        "NZO": 0.48,
        "KK_Ay": 0.45,
        "STAM": 0.35,
        "C57BL_6J_chow": 0.30,
    }

    def _num(col: str) -> pd.Series:
        # Missing columns and non-numeric cells both become NaN.
        if col not in out.columns:
            return pd.Series(np.nan, index=out.index, dtype=float)
        return pd.to_numeric(out[col], errors="coerce").astype(float)

    fat = _num("fat_g")
    lean = _num("lean_g")

    if "model" in out.columns:
        ratio = out["model"].astype(str).map(vis_ratio_map)
    else:
        ratio = pd.Series(np.nan, index=out.index)
    ratio = ratio.astype(float).fillna(default_visceral_ratio)

    # A supplied visceral value wins only if it is present and positive.
    given_vis = _num("visceral_fat_proxy_g")
    visceral = given_vis.where(given_vis > 0, (fat * ratio).round(3))
    out["visceral_fat_proxy_g"] = visceral
    out["subcutaneous_fat_g"] = (fat - visceral).clip(lower=0).round(3)

    given_app = _num("appendicular_lean_g")
    appendicular = given_app.where(given_app > 0,
                                   (lean * appendicular_fraction).round(3))
    out["appendicular_lean_g"] = appendicular
    out["muscle_compartment_g"] = appendicular
    return out


# ============================================================================
#  Body composition indices (feature 4)
# ============================================================================

def compute_indices(df: pd.DataFrame, mouse_length_cm: float = 9.5) -> pd.DataFrame:
    """Add body-composition indices to a copy of *df*.

    Columns added:
      FMI  = (fat_g / 1000) / length_m^2          kg/m^2, 3 dp
      LMI  = (lean_g / 1000) / length_m^2         kg/m^2, 3 dp
      ALM_over_BW = appendicular_lean_g / body_weight_g        4 dp
      sarcopenic_obesity_index = lean_g / (lean_g + fat_g)      4 dp
          (higher = leaner; lower = relatively more fat per total mass)
      fat_to_lean_ratio = fat_g / lean_g                        4 dp

    Every zero denominator yields NaN rather than +/-inf. Missing or
    non-numeric inputs are NaN, so without an ``appendicular_lean_g``
    column ALM_over_BW is all-NaN.

    Args:
        mouse_length_cm: nose-anus body length. The default of 9.5 cm is
            described in the original as typical for adult C57BL/6.

    Raises:
        ValueError: if *mouse_length_cm* is not a positive number.
    """
    if not mouse_length_cm > 0:
        raise ValueError(f"mouse_length_cm must be positive, got {mouse_length_cm!r}")
    out = df.copy()
    length_m = mouse_length_cm / 100.0
    length_m2 = length_m * length_m

    def _num(col: str) -> pd.Series:
        if col not in out.columns:
            return pd.Series(np.nan, index=out.index, dtype=float)
        return pd.to_numeric(out[col], errors="coerce").astype(float)

    def _nonzero(s: pd.Series) -> pd.Series:
        # Zero denominators become NaN instead of producing +/-inf.
        return s.where(s != 0)

    bw = _num("body_weight_g")
    fat = _num("fat_g")
    lean = _num("lean_g")
    app_lean = _num("appendicular_lean_g")
    out["FMI"] = (fat / 1000.0 / length_m2).round(3)
    out["LMI"] = (lean / 1000.0 / length_m2).round(3)
    out["ALM_over_BW"] = (app_lean / _nonzero(bw)).round(4)
    out["sarcopenic_obesity_index"] = (lean / _nonzero(lean + fat)).round(4)
    out["fat_to_lean_ratio"] = (fat / _nonzero(lean)).round(4)
    return out


def flag_sarcopenic_obesity(df: pd.DataFrame,
                            soi_cutoff: float = 0.45,
                            alm_cutoff: float = 0.12,
                            moderate_margin: float = 0.07,
                            severe_margin: float = 0.15) -> pd.DataFrame:
    """Flag sarcopenic obesity per row.

    Combined criterion: ``sarcopenic_obesity_index < soi_cutoff`` AND
    ``ALM_over_BW < alm_cutoff``; a NaN in either input means "not
    flagged". The indices are computed via `compute_indices` when
    ``sarcopenic_obesity_index`` is absent.

    Returns a copy with:
      - ``sarcopenic_obesity_flag`` (bool)
      - ``so_severity``: "none" when not flagged; otherwise "severe" if
        SOI < soi_cutoff - severe_margin, "moderate" if
        SOI < soi_cutoff - moderate_margin, else "mild".

    Vectorized, so an empty frame is handled.
    """
    out = df.copy()
    if "sarcopenic_obesity_index" not in out.columns:
        out = compute_indices(out)
    soi = pd.to_numeric(out["sarcopenic_obesity_index"], errors="coerce").astype(float)
    if "ALM_over_BW" in out.columns:
        alm = pd.to_numeric(out["ALM_over_BW"], errors="coerce").astype(float)
    else:
        alm = pd.Series(np.nan, index=out.index, dtype=float)

    flag = (soi < soi_cutoff) & (alm < alm_cutoff)
    out["sarcopenic_obesity_flag"] = flag
    # np.select takes the first true condition, so "none" has priority.
    severity = np.select(
        [
            (~flag).to_numpy(),
            (soi < soi_cutoff - severe_margin).to_numpy(),
            (soi < soi_cutoff - moderate_margin).to_numpy(),
        ],
        ["none", "severe", "moderate"],
        default="mild",
    )
    out["so_severity"] = pd.Series(severity, index=out.index, dtype=object)
    return out


# ============================================================================
#  Trajectory metrics (feature 3)
# ============================================================================

@dataclass
class TrajectorySummary:
    """Per-animal longitudinal summary built by `trajectory_for_animal`.

    Percent changes compare the last observation with the first. Slopes are
    linear fits against ``time_point_wk``. ``by_time`` holds one dict per
    observed row.
    """

    animal_id: str
    n_time_points: int
    # first / last observed time point (weeks)
    baseline_wk: float
    final_wk: float
    # body weight at the first / last observed time point (g)
    baseline_bw: float
    final_bw: float
    # percent change, last vs first observation
    pct_change_bw: float
    pct_change_fat: float
    pct_change_lean: float
    # linear-fit slopes per week
    slope_bw_per_wk: float
    slope_fat_per_wk: float
    # minimum body weight and the week at which it occurred
    nadir_bw: float
    nadir_bw_wk: float
    # final body weight minus nadir body weight
    regain_after_nadir_g: float
    # drug_phase labels joined with " → " ("none" when there are none)
    drug_phase_summary: str
    by_time: List[Dict] = field(default_factory=list)

    def as_dict(self) -> Dict:
        """Return a shallow copy of the instance's fields as a plain dict."""
        return dict(vars(self))


def _linreg(xs: np.ndarray, ys: np.ndarray) -> Tuple[float, float, float]:
    """Plain OLS slope/intercept/r2 using numpy. Returns (slope, intercept, r2)."""
    if len(xs) < 2:
        return float("nan"), float("nan"), float("nan")
    slope, intercept = np.polyfit(xs, ys, 1)
    pred = slope * xs + intercept
    ss_res = float(np.sum((ys - pred) ** 2))
    ss_tot = float(np.sum((ys - ys.mean()) ** 2))
    r2 = 1.0 - ss_res / ss_tot if ss_tot > 0 else float("nan")
    return float(slope), float(intercept), float(r2)


def trajectory_for_animal(df: pd.DataFrame, animal_id: str) -> TrajectorySummary:
    """Summarize one animal's longitudinal body-composition trajectory.

    The animal's rows are ordered by ``time_point_wk``. Baseline and final
    values come from the first and last ordered rows. Percent changes are
    NaN when the baseline is 0 or NaN. Slopes are `_linreg` fits over rows
    where both the time and the value are present, so missing measurements
    no longer break the fit.

    The nadir is the minimum non-NaN body weight. When every body weight is
    NaN, the nadir fields and the regain are NaN instead of raising.
    ``drug_phase_summary`` joins the distinct non-blank ``drug_phase``
    labels (in time order) with " → ", or is "none".

    Raises:
        KeyError: if *df* has no rows for *animal_id*.
    """
    sub = df[df["animal_id"] == animal_id].sort_values("time_point_wk")
    if sub.empty:
        raise KeyError(f"No rows for animal_id={animal_id}")
    nan = float("nan")

    def _values(col: str) -> np.ndarray:
        return pd.to_numeric(sub[col], errors="coerce").to_numpy(dtype=float)

    bw = _values("body_weight_g")
    fat = _values("fat_g")
    lean = _values("lean_g")
    t = _values("time_point_wk")

    def _pct_change(series: np.ndarray) -> float:
        # Last vs first observation; undefined for a 0 or NaN baseline.
        first, last = float(series[0]), float(series[-1])
        if math.isnan(first) or first == 0:
            return nan
        return (last - first) / first * 100

    def _slope(series: np.ndarray) -> float:
        ok = ~np.isnan(t) & ~np.isnan(series)
        return _linreg(t[ok], series[ok])[0]

    def _opt(v: float) -> Optional[float]:
        return None if np.isnan(v) else float(v)

    if np.isnan(bw).all():
        nadir_bw = nadir_wk = regain = nan
    else:
        nadir_idx = int(np.nanargmin(bw))
        nadir_bw = float(bw[nadir_idx])
        nadir_wk = float(t[nadir_idx])
        # Regain is defined as 0.0 when the nadir is the last observation.
        regain = float(bw[-1] - bw[nadir_idx]) if nadir_idx < len(bw) - 1 else 0.0

    labels = sub["drug_phase"].astype(str).tolist()
    phases = [p for p in sub["drug_phase"].dropna().astype(str).unique() if p.strip()]
    by_time = [
        {
            "time_point_wk": _opt(ti),
            "body_weight_g": _opt(bi),
            "fat_g": _opt(fi),
            "lean_g": _opt(li),
            "drug_phase": label,
        }
        for ti, bi, fi, li, label in zip(t, bw, fat, lean, labels)
    ]
    # round(nan, n) returns nan, so NaN fields pass through unchanged.
    return TrajectorySummary(
        animal_id=animal_id,
        n_time_points=int(len(sub)),
        baseline_wk=float(t[0]),
        final_wk=float(t[-1]),
        baseline_bw=float(bw[0]),
        final_bw=float(bw[-1]),
        pct_change_bw=round(_pct_change(bw), 2),
        pct_change_fat=round(_pct_change(fat), 2),
        pct_change_lean=round(_pct_change(lean), 2),
        slope_bw_per_wk=round(_slope(bw), 3),
        slope_fat_per_wk=round(_slope(fat), 3),
        nadir_bw=round(nadir_bw, 2),
        nadir_bw_wk=nadir_wk,
        regain_after_nadir_g=round(regain, 3),
        drug_phase_summary=" → ".join(phases) if phases else "none",
        by_time=by_time,
    )


def trajectory_table(df: pd.DataFrame) -> pd.DataFrame:
    """Build a table with one trajectory-summary row per distinct animal.

    Non-null ``animal_id`` values are visited in order of first appearance.
    Best effort: any animal whose summary raises an ``Exception`` is left
    out rather than aborting the whole table.
    """
    def _summaries():
        for animal in df["animal_id"].dropna().unique():
            try:
                yield trajectory_for_animal(df, animal).as_dict()
            except Exception:
                continue

    return pd.DataFrame(list(_summaries()))


# ============================================================================
#  ANCOVA + cohort statistics (feature 5)
# ============================================================================

def _ancova_lstsq(merged: pd.DataFrame, base_col: str, final_col: str,
                  group_col: str, groups: List) -> Tuple[float, float, Dict]:
    """Fit ``final ~ group + baseline`` by least squares and test the group term.

    Returns ``(F, p, adjusted_means)``. F compares the full model with a
    baseline-only model. Because the model has no interaction, this is the
    Type II test of the group term. Degrees of freedom come from the
    design-matrix ranks, so a constant baseline reduces this to a one-way
    ANOVA. Adjusted means are model predictions at the grand-mean baseline.
    F and p are NaN when either degree of freedom is not positive.
    """
    nan = float("nan")
    x = merged[base_col].to_numpy(dtype=float)
    y = merged[final_col].to_numpy(dtype=float)
    labels = merged[group_col].to_numpy()
    n = y.shape[0]
    # Treatment coding: groups[0] is the reference level.
    dummies = [(labels == g).astype(float) for g in groups[1:]]
    design_full = np.column_stack([np.ones(n), *dummies, x])
    design_reduced = np.column_stack([np.ones(n), x])

    def _fit(design: np.ndarray):
        beta, _, rank, _ = np.linalg.lstsq(design, y, rcond=None)
        resid = y - design @ beta
        return beta, float(resid @ resid), int(rank)

    beta, ss_full, rank_full = _fit(design_full)
    _, ss_reduced, rank_reduced = _fit(design_reduced)
    df_num = rank_full - rank_reduced
    df_den = n - rank_full
    ss_group = max(ss_reduced - ss_full, 0.0)  # guard tiny negative round-off
    if df_num <= 0 or df_den <= 0:
        f_stat, p_val = nan, nan
    elif ss_full <= 0.0:
        # Zero residual error: F is unbounded if groups explain anything.
        f_stat, p_val = (float("inf"), 0.0) if ss_group > 0 else (nan, nan)
    else:
        f_stat = (ss_group / df_num) / (ss_full / df_den)
        p_val = float(stats.f.sf(f_stat, df_num, df_den))

    grand_mean_baseline = float(x.mean())
    adjusted = {}
    for idx, g in enumerate(groups):
        row = np.zeros(design_full.shape[1])
        row[0] = 1.0
        if idx > 0:
            row[idx] = 1.0
        row[-1] = grand_mean_baseline
        adjusted[g] = float(row @ beta)
    return float(f_stat), float(p_val), adjusted


def ancova_endpoint(df: pd.DataFrame,
                    endpoint: str = "fat_g",
                    final_wk: Optional[float] = None,
                    baseline_wk: float = 0.0,
                    group_col: str = "group") -> Dict:
    """Baseline-corrected ANCOVA on a single endpoint.

    Model: ``endpoint_final ~ group + endpoint_baseline``. It is fitted on
    animals with non-missing values at both *baseline_wk* and *final_wk*
    (default: the latest time point in *df*).

    Backends:
      * ``"statsmodels"``: OLS plus Type II ANOVA, used when statsmodels is
        importable.
      * ``"scipy_fallback"``: the same model fitted with numpy least
        squares. The group F-test compares the full model against a
        baseline-only model, with the p-value from ``scipy.stats.f``. Used
        when statsmodels is missing, or when its fit raises; in that case
        the message is kept under ``"statsmodels_error"``.

    Returns a dict with ``endpoint``, ``groups``, ``n_per_group``,
    ``means_baseline``, ``means_final``, ``means_adjusted`` (predicted at
    the grand-mean baseline), ``f_statistic``, ``p_value``,
    ``baseline_corrected``, ``backend`` and ``final_wk``. When fewer than
    two groups have paired data, it returns only ``endpoint``, ``error``
    and ``n_per_group``.
    """
    df = df.copy()
    df["time_point_wk"] = pd.to_numeric(df["time_point_wk"], errors="coerce")
    if final_wk is None:
        final_wk = float(df["time_point_wk"].max())
    base_col = f"{endpoint}_baseline"
    final_col = f"{endpoint}_final"

    base = df[df["time_point_wk"] == baseline_wk][["animal_id", endpoint]].rename(
        columns={endpoint: base_col})
    final = df[df["time_point_wk"] == final_wk][["animal_id", group_col, endpoint]].rename(
        columns={endpoint: final_col})
    merged = base.merge(final, on="animal_id", how="inner").dropna()
    if merged.empty or merged[group_col].nunique() < 2:
        return {
            "endpoint": endpoint,
            "error": "Insufficient data: need ≥2 groups with paired baseline/final.",
            "n_per_group": {},
        }

    groups = sorted(merged[group_col].unique())
    grouped = merged.groupby(group_col)
    n_per_group = grouped.size().to_dict()
    means_final = grouped[final_col].mean().to_dict()
    means_baseline = grouped[base_col].mean().to_dict()

    def _result(adj_means: Dict, f_stat: float, p_val: float, backend: str) -> Dict:
        return {
            "endpoint": endpoint,
            "groups": groups,
            "n_per_group": n_per_group,
            "means_baseline": means_baseline,
            "means_final": means_final,
            "means_adjusted": adj_means,
            "f_statistic": f_stat,
            "p_value": p_val,
            "baseline_corrected": True,
            "backend": backend,
            "final_wk": final_wk,
        }

    statsmodels_error = None
    if HAS_STATSMODELS:
        formula = f"Q('{final_col}') ~ C({group_col}) + Q('{base_col}')"
        try:
            model = smf.ols(formula, data=merged).fit()
            anova = sm.stats.anova_lm(model, typ=2)
            row = anova.loc[f"C({group_col})"]
            # adjusted means: predictions at the grand-mean baseline
            gm = merged[base_col].mean()
            adj_means = {}
            for g in groups:
                tmp = merged.iloc[:1].copy()
                tmp[group_col] = g
                tmp[base_col] = gm
                adj_means[g] = float(model.predict(tmp).iloc[0])
            return _result(adj_means, float(row["F"]), float(row["PR(>F)"]), "statsmodels")
        except Exception as e:  # pragma: no cover
            # Keep the message, but still produce an estimate below.
            statsmodels_error = f"statsmodels fit failed: {e}"

    f_stat, p_val, adj_means = _ancova_lstsq(merged, base_col, final_col, group_col, groups)
    result = _result(adj_means, f_stat, p_val, "scipy_fallback")
    if statsmodels_error is not None:
        result["statsmodels_error"] = statsmodels_error
    return result


def mixed_effects_repeated_measures(df: pd.DataFrame,
                                    endpoint: str = "fat_g",
                                    group_col: str = "group") -> Dict:
    """Repeated-measures model with a random intercept per animal.

    Rows missing the endpoint, time, group or animal id are dropped first.

    With statsmodels: ``smf.mixedlm`` of ``endpoint ~ C(group) *
    time_point_wk`` grouped by ``animal_id`` and fitted with L-BFGS. It
    returns the summary text, params, p-values, AIC and row count, or an
    ``error`` entry if anything in the fit raises.

    Without statsmodels: a one-way ANOVA of *endpoint* across groups at
    each time point. Time points with fewer than two groups, or with any
    group holding fewer than two observations, are skipped.
    """
    data = df.copy()
    data["time_point_wk"] = pd.to_numeric(data["time_point_wk"], errors="coerce")
    data = data.dropna(subset=[endpoint, "time_point_wk", group_col, "animal_id"])

    if not HAS_STATSMODELS:
        per_time = {}
        for wk, at_time in data.groupby("time_point_wk"):
            labels = sorted(at_time[group_col].unique())
            if len(labels) < 2:
                continue
            arrays = [
                at_time.loc[at_time[group_col] == lab, endpoint].to_numpy(dtype=float)
                for lab in labels
            ]
            if min(len(arr) for arr in arrays) < 2:
                continue
            f_val, p_val = stats.f_oneway(*arrays)
            per_time[float(wk)] = {
                "f": float(f_val),
                "p": float(p_val),
                "n_per_group": {lab: int(len(arr)) for lab, arr in zip(labels, arrays)},
            }
        return {
            "endpoint": endpoint,
            "backend": "scipy_per_time_anova_fallback",
            "per_time_results": per_time,
        }

    try:
        model = smf.mixedlm(f"Q('{endpoint}') ~ C({group_col}) * time_point_wk",
                            data, groups=data["animal_id"])
        fitted = model.fit(method="lbfgs", disp=False)
        return {
            "endpoint": endpoint,
            "backend": "statsmodels_mixedlm",
            "summary_text": fitted.summary().as_text(),
            "params": fitted.params.to_dict(),
            "pvalues": fitted.pvalues.to_dict(),
            "aic": float(fitted.aic) if hasattr(fitted, "aic") else None,
            "n": int(len(data)),
        }
    except Exception as e:  # pragma: no cover
        return {"endpoint": endpoint, "error": f"mixedlm failed: {e}"}


def drug_on_off_interaction(df: pd.DataFrame,
                            endpoint: str = "fat_g",
                            group_col: str = "group") -> Dict:
    """Compare within-phase endpoint change between groups (drug on / off).

    ``drug_phase`` labels are matched case-insensitively on the tokens
    ``on`` and ``off``, where a token may not touch other letters. So
    ``"drug_on"``, ``"On"`` and ``"drug-off"`` match, while ``"induction"``
    and ``"none"`` (which merely contain the substring "on") do not.

    Per animal, with rows ordered by time:
      - on_start: the last row labeled neither on nor off *before* the first
        on row; if there is none, the first on value;
      - on_end: the last on value;
      - delta_on = on_end - on_start;
      - delta_off = last off value - on_end (NaN without off rows).

    Animals with no on row are skipped. When exactly two groups are present,
    Welch t-tests (``equal_var=False``) compare delta_on and delta_off
    between the groups; each test needs at least two animals per group.
    """
    on_token = r"(?:^|[^a-z])on(?:[^a-z]|$)"
    off_token = r"(?:^|[^a-z])off(?:[^a-z]|$)"

    df = df.copy()
    df["time_point_wk"] = pd.to_numeric(df["time_point_wk"], errors="coerce")
    df = df.dropna(subset=[endpoint, "time_point_wk", group_col, "animal_id"])
    rows = []
    for aid, sub in df.groupby("animal_id"):
        sub = sub.sort_values("time_point_wk")
        phases = sub["drug_phase"].astype(str).str.lower()
        is_on = phases.str.contains(on_token, regex=True)
        if not is_on.any():
            continue
        is_off = phases.str.contains(off_token, regex=True)
        on_rows = sub[is_on]
        off_rows = sub[is_off]
        first_on_wk = on_rows["time_point_wk"].iloc[0]
        # Baseline for the on-phase must precede drug-on; unlabeled rows
        # after it (e.g. a later recovery period) are not a valid start.
        pre_rows = sub[~is_on & ~is_off & (sub["time_point_wk"] < first_on_wk)]
        if pre_rows.empty:
            on_start = float(on_rows[endpoint].iloc[0])
        else:
            on_start = float(pre_rows[endpoint].iloc[-1])
        on_end = float(on_rows[endpoint].iloc[-1])
        delta_on = on_end - on_start
        delta_off = (float(off_rows[endpoint].iloc[-1]) - on_end
                     if not off_rows.empty else float("nan"))
        rows.append({
            "animal_id": aid,
            "group": str(sub[group_col].iloc[0]),
            "delta_on": delta_on,
            "delta_off": delta_off,
            "on_start": on_start,
            "on_end": on_end,
        })
    if not rows:
        return {"endpoint": endpoint, "error": "No drug-on/off labels detected."}
    res = pd.DataFrame(rows)
    groups = sorted(res["group"].unique())
    out = {"endpoint": endpoint, "groups": groups,
           "n_per_group": res.groupby("group").size().to_dict()}
    if len(groups) == 2:
        a, b = groups
        on_a = res.loc[res["group"] == a, "delta_on"].dropna().to_numpy()
        on_b = res.loc[res["group"] == b, "delta_on"].dropna().to_numpy()
        if len(on_a) > 1 and len(on_b) > 1:
            t_on, p_on = stats.ttest_ind(on_a, on_b, equal_var=False)
            out["delta_on_t"] = float(t_on)
            out["delta_on_p"] = float(p_on)
            out["delta_on_mean"] = {a: float(on_a.mean()), b: float(on_b.mean())}
        off_a = res.loc[res["group"] == a, "delta_off"].dropna().to_numpy()
        off_b = res.loc[res["group"] == b, "delta_off"].dropna().to_numpy()
        if len(off_a) > 1 and len(off_b) > 1:
            t_off, p_off = stats.ttest_ind(off_a, off_b, equal_var=False)
            out["delta_off_t"] = float(t_off)
            out["delta_off_p"] = float(p_off)
            out["delta_off_mean"] = {a: float(off_a.mean()), b: float(off_b.mean())}
    out["per_animal"] = rows
    return out


# ============================================================================
#  Korean-language reporting
# ============================================================================

def korean_report(traj: TrajectorySummary,
                  ancova_result: Optional[Dict] = None,
                  drug_result: Optional[Dict] = None) -> str:
    """Render a manuscript-ready, Korean-language summary block.

    Sections, in order:
      - the animal's trajectory (always); NaN fat/lean percent changes and
        a NaN fat slope are omitted;
      - ANCOVA results, when *ancova_result* contains ``"p_value"``;
      - drug-on/off t-tests, when *drug_result* contains ``"groups"``.
    A fixed disclaimer line closes the report. Lines are joined with
    newlines.
    """
    def _trajectory_lines(tr) -> List[str]:
        span = f"wk{tr.baseline_wk:g} → wk{tr.final_wk:g}"
        out = [
            f"[개체 {tr.animal_id} 트라젝토리 요약]",
            f"  - 관측 시점 수: {tr.n_time_points} ({span})",
            f"  - 기저 체중: {tr.baseline_bw:.2f} g → 최종 체중: {tr.final_bw:.2f} g "
            f"(변화율 {tr.pct_change_bw:+.2f}%)",
        ]
        if not math.isnan(tr.pct_change_fat):
            out.append(f"  - 지방량 변화율: {tr.pct_change_fat:+.2f}%")
        if not math.isnan(tr.pct_change_lean):
            out.append(f"  - 제지방량 변화율: {tr.pct_change_lean:+.2f}%")
        out.append(f"  - 체중 변화율(주당): {tr.slope_bw_per_wk:+.3f} g/wk")
        if not math.isnan(tr.slope_fat_per_wk):
            out.append(f"  - 지방량 변화율(주당): {tr.slope_fat_per_wk:+.3f} g/wk")
        out.append(f"  - 체중 최저점: wk{tr.nadir_bw_wk:g}, {tr.nadir_bw:.2f} g "
                   f"(이후 재증가 {tr.regain_after_nadir_g:+.3f} g)")
        out.append(f"  - 약물 단계: {tr.drug_phase_summary}")
        return out

    def _ancova_lines(res: Dict) -> List[str]:
        out = [
            "",
            f"[ANCOVA: {res['endpoint']} @ wk{res.get('final_wk', '?')}]",
            f"  - 백엔드: {res.get('backend')}",
            f"  - F = {res['f_statistic']:.3f}, p = {res['p_value']:.4g}",
            "  - 그룹별 조정 평균(기저값 보정):",
        ]
        out.extend(
            f"      · {g}: {m:.3f} (n={res['n_per_group'].get(g, 0)})"
            for g, m in res.get("means_adjusted", {}).items()
        )
        return out

    def _drug_lines(res: Dict) -> List[str]:
        out = ["", f"[Drug-on / drug-off 상호작용: {res['endpoint']}]"]
        for key, label in (("on", "Drug-on"), ("off", "Drug-off")):
            if f"delta_{key}_p" not in res:
                continue
            out.append(f"  - {label} 단계 변화 t = {res[f'delta_{key}_t']:.3f}, "
                       f"p = {res[f'delta_{key}_p']:.4g}")
            out.extend(f"      · {g} Δ = {m:+.3f}"
                       for g, m in res[f"delta_{key}_mean"].items())
        return out

    report = _trajectory_lines(traj)
    if "p_value" in (ancova_result or {}):
        report += _ancova_lines(ancova_result)
    if "groups" in (drug_result or {}):
        report += _drug_lines(drug_result)
    report += ["", "주의: 본 분석은 IACUC 승인 동물실험 사후 분석용이며, 임상 의사결정 도구가 아닙니다."]
    return "\n".join(report)


# ============================================================================
#  Helpers for the Streamlit UI
# ============================================================================

def load_demo_data() -> pd.DataFrame:
    """Concatenate every bundled synthetic CSV into one canonical frame.

    Files ending in ``.csv`` (case-insensitive) are read from
    ``data/synthetic`` next to this module, in sorted filename order, via
    `load_dxa_csv_export`. Files that fail to load are skipped. Returns an
    empty frame with SCHEMA_COLUMNS when the directory is missing or no
    file loads.
    """
    csv_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "data", "synthetic")
    if not os.path.isdir(csv_dir):
        return pd.DataFrame(columns=SCHEMA_COLUMNS)

    def _try_load(name: str):
        try:
            return load_dxa_csv_export(os.path.join(csv_dir, name))
        except Exception:
            return None

    candidates = (_try_load(name) for name in sorted(os.listdir(csv_dir))
                  if name.lower().endswith(".csv"))
    frames = [frame for frame in candidates if frame is not None]
    if frames:
        return pd.concat(frames, ignore_index=True)
    return pd.DataFrame(columns=SCHEMA_COLUMNS)


def regenerate_synthetic_if_missing() -> List[str]:
    """Generate the bundled synthetic cohort CSVs if any of them is missing.

    Returns ``[]`` when all six expected files already exist under
    ``data/synthetic`` next to this module. Otherwise it imports
    ``generate_all`` from ``data/synthetic_generator.py`` and returns
    ``generate_all(out_dir)``; presumably the written file paths, to be
    confirmed against the generator.

    ``data/`` is added to ``sys.path`` at most once, so repeated calls do
    not keep growing the import path.
    """
    import sys

    here = os.path.dirname(os.path.abspath(__file__))
    data_dir = os.path.join(here, "data")
    out_dir = os.path.join(data_dir, "synthetic")
    expected = [
        "cohort_C57BL6_DIO.csv",
        "cohort_obob.csv",
        "cohort_dbdb.csv",
        "cohort_STAM.csv",
        "cohort_control_chow.csv",
        "cohort_GLP1RA_STEP4_mimic.csv",
    ]
    if all(os.path.isfile(os.path.join(out_dir, f)) for f in expected):
        return []
    if data_dir not in sys.path:
        sys.path.insert(0, data_dir)
    from synthetic_generator import generate_all  # type: ignore
    return generate_all(out_dir)


if __name__ == "__main__":
    # Smoke run: load (or generate) the demo data, derive the metrics and
    # print one Korean trajectory report.
    df = load_demo_data()
    if df.empty:
        regenerate_synthetic_if_missing()
        df = load_demo_data()
    if df.empty:
        # Without this guard, df["animal_id"].iloc[0] would raise IndexError.
        raise SystemExit("No demo data found in data/synthetic and none could be generated.")
    df = derive_roi_metrics(df)
    df = compute_indices(df)
    df = flag_sarcopenic_obesity(df)
    aid = df["animal_id"].iloc[0]
    print(korean_report(trajectory_for_animal(df, aid)))
