"""Baseline lock + % change trajectory + composite endpoint."""
from __future__ import annotations

from typing import Dict, List, Tuple

import numpy as np
import pandas as pd

from .schemas import CompositeEndpoint


def baseline_lock(df: pd.DataFrame, value_col: str, baseline_week: int = 0) -> pd.DataFrame:
    """Return a copy of df with extra columns: baseline_<col>, pct_change_<col>.

    Baseline is per mouse_id: the first row at ``baseline_week``. Mice without
    a row at ``baseline_week`` fall back to their earliest observation. Ties at
    that week resolve to the first row in the input order.

    ``pct_change_<col>`` is ``(value - baseline) / baseline * 100``. When the
    baseline is 0, the percent change is undefined and is reported as NaN
    rather than ±inf. The input frame is not modified.
    """
    out = df.copy()
    bl_col = f"baseline_{value_col}"
    pct_col = f"pct_change_{value_col}"

    bl_map: Dict[object, float] = {}
    # observed=True: with a categorical mouse_id, skip categories that have no
    # rows. Such groups would be empty and crash on .iloc[0].
    for mid, sub in out.groupby("mouse_id", observed=True):
        at_baseline = sub.loc[sub["week"] == baseline_week, value_col]
        if not at_baseline.empty:
            bl = at_baseline.iloc[0]
        else:
            # Stable sort, so that duplicate rows at the earliest week resolve
            # deterministically to the first-recorded one.
            bl = sub.sort_values("week", kind="mergesort")[value_col].iloc[0]
        bl_map[mid] = float(bl)

    out[bl_col] = out["mouse_id"].map(bl_map)
    # A zero baseline makes the percent change undefined. Masking it to NaN
    # stops ±inf from leaking into downstream threshold comparisons.
    denom = out[bl_col].where(out[bl_col] != 0)
    out[pct_col] = (out[value_col] - out[bl_col]) / denom * 100.0
    return out


def grip_per_lean(grip_df: pd.DataFrame, comp_df: pd.DataFrame) -> pd.DataFrame:
    """Pair grip and body-composition visits and normalise force by lean mass.

    Rows are matched on (mouse_id, week) with an inner join, so only visits
    present in both inputs are kept. Output columns are mouse_id, week,
    force_g, lean_mass_g and force_per_g_lean (force_g / lean_mass_g).
    """
    keys = ["mouse_id", "week"]
    grip_part = grip_df[keys + ["force_g"]]
    comp_part = comp_df[keys + ["lean_mass_g"]]
    joined = grip_part.merge(comp_part, on=keys, how="inner")
    joined["force_per_g_lean"] = joined["force_g"] / joined["lean_mass_g"]
    return joined


def per_arm_summary(
    df: pd.DataFrame, meta: pd.DataFrame, value_col: str
) -> pd.DataFrame:
    """Median + IQR per (treatment, week) for value_col.

    ``treatment`` is attached from ``meta`` by mouse_id with a left join. Rows
    whose mouse has no treatment drop out of the groupby.

    Returns columns treatment, week, median, q25, q75 and n. ``n`` counts
    non-NaN values. All statistics skip NaN, so quartiles describe the same
    observations as the median and n.
    """
    m = pd.merge(df, meta[["mouse_id", "treatment"]], on="mouse_id", how="left")
    g = (
        m.groupby(["treatment", "week"])[value_col]
        .agg(
            median="median",
            # Series.quantile skips NaN, matching "median"/"count". np.percentile
            # would return NaN for any group containing a missing value.
            q25=lambda x: x.quantile(0.25),
            q75=lambda x: x.quantile(0.75),
            n="count",
        )
        .reset_index()
    )
    return g


def composite_endpoints(
    bw_df: pd.DataFrame,
    comp_df: pd.DataFrame,
    grip_per_lean_df: pd.DataFrame,
    meta: pd.DataFrame,
    final_week: int,
) -> pd.DataFrame:
    """Per-mouse composite endpoint at final_week.

    Intended criterion (the decision itself is ``CompositeEndpoint.passes()``):
    BW pct change <-5% AND lean pct change >-3% AND grip-per-lean pct change >+5%

    Percent changes come from ``baseline_lock`` with its default baseline week.
    A mouse listed in ``meta`` is scored only if all three measures have a row
    at ``final_week``. If several rows exist, the first one is used.

    Returns:
        One row per scored mouse with columns mouse_id, bw_pct_change,
        lean_pct_change, grip_per_lean_pct_change,
        composite_pass (``int(CompositeEndpoint.passes())``), treatment, model.
        If no mouse can be scored, an empty frame with those columns is
        returned instead of raising.
    """
    columns = [
        "mouse_id",
        "bw_pct_change",
        "lean_pct_change",
        "grip_per_lean_pct_change",
        "composite_pass",
    ]

    # Baseline-lock each measure to get per-row percent change from baseline.
    bw_lock = baseline_lock(bw_df, "bw_g")
    lean_lock = baseline_lock(comp_df, "lean_mass_g")
    gpl_lock = baseline_lock(grip_per_lean_df, "force_per_g_lean")

    # Restrict to final_week once, instead of re-scanning full frames per mouse.
    bw_fin = bw_lock[bw_lock["week"] == final_week]
    lean_fin = lean_lock[lean_lock["week"] == final_week]
    gpl_fin = gpl_lock[gpl_lock["week"] == final_week]

    rows = []
    for mid in meta["mouse_id"].unique():
        bw_row = bw_fin[bw_fin["mouse_id"] == mid]
        lean_row = lean_fin[lean_fin["mouse_id"] == mid]
        gpl_row = gpl_fin[gpl_fin["mouse_id"] == mid]
        if bw_row.empty or lean_row.empty or gpl_row.empty:
            continue
        bw_pct = float(bw_row["pct_change_bw_g"].iloc[0])
        lean_pct = float(lean_row["pct_change_lean_mass_g"].iloc[0])
        gpl_pct = float(gpl_row["pct_change_force_per_g_lean"].iloc[0])
        ce = CompositeEndpoint(bw_pct, lean_pct, gpl_pct)
        rows.append(
            dict(
                mouse_id=mid,
                bw_pct_change=bw_pct,
                lean_pct_change=lean_pct,
                grip_per_lean_pct_change=gpl_pct,
                composite_pass=int(ce.passes()),
            )
        )

    # Explicit columns: with no scored mice, pd.DataFrame([]) would have no
    # "mouse_id" column and the merge below would raise KeyError.
    out = pd.DataFrame(rows, columns=columns)
    if out.empty:
        # Align the (object-typed) empty key with meta's dtype so pandas accepts the merge.
        out["mouse_id"] = out["mouse_id"].astype(meta["mouse_id"].dtype)
    out = pd.merge(out, meta[["mouse_id", "treatment", "model"]], on="mouse_id", how="left")
    return out


def composite_pass_rate(ce_df: pd.DataFrame) -> pd.DataFrame:
    """Summarise composite-endpoint pass rates per treatment arm.

    Returns one row per treatment with:
      * ``n``: number of non-null ``composite_pass`` values,
      * ``n_pass``: sum of ``composite_pass``,
      * ``pass_rate_pct``: ``n_pass / n * 100``.
    Rows with a missing treatment are dropped by the groupby.
    """
    summary = (
        ce_df.groupby("treatment")["composite_pass"]
        .agg(n="count", n_pass="sum")
        .reset_index()
    )
    ratio = summary["n_pass"] / summary["n"]
    summary["pass_rate_pct"] = ratio * 100.0
    return summary
