"""Muscle-sparing decompose framework.

For a combo treatment (e.g. tirzepatide+bimagrumab):
- expected lean-mass change: predicted from the monotherapy cohort (linear
  regression of lean pct change on BW pct change in the GLP-1RA mono arm)
- actual lean-mass change: observed in the combo arm
- sparing effect = actual - expected lean pct change (equivalently, expected
  loss - actual loss; positive = combo spares lean mass)
- MSI = lean change [g] / fat lost [g]  (positive = lean gained while fat lost)
- fiber type protection: detect a monotonic IIB >= IIX >= IIA >= I CSA gain
"""
from __future__ import annotations

from typing import Dict, Tuple

import numpy as np
import pandas as pd
from scipy import stats

from .trajectory import baseline_lock


def parse_combo(combo: str) -> Tuple[str, str]:
    """Split a combo label such as ``"A+B"`` into its two stripped parts.

    Only the first ``+`` acts as the separator, so ``"A+B+C"`` yields
    ``("A", "B+C")``.

    Raises:
        ValueError: if ``combo`` contains no ``+``.
    """
    first, separator, second = combo.partition("+")
    if not separator:
        raise ValueError(f"combo must be of form A+B, got {combo}")
    return first.strip(), second.strip()


def monotherapy_lean_loss_model(
    bw_pct_changes: np.ndarray, lean_pct_changes: np.ndarray
) -> Dict[str, float]:
    """Fit the OLS line ``lean_pct = intercept + slope * bw_pct``.

    Args:
        bw_pct_changes: per-mouse body-weight percent change (regressor).
        lean_pct_changes: per-mouse lean-mass percent change (response),
            aligned element-wise with ``bw_pct_changes``.

    Returns:
        Dict with ``slope``, ``intercept``, ``r2`` (squared Pearson r) and
        ``n``, the number of pairs actually used in the fit. Pairs where either
        value is NaN or infinite are dropped first. When fewer than 3 usable
        pairs remain, or all usable body-weight values are identical (slope
        undefined), ``slope``/``intercept``/``r2`` are NaN instead of raising.

    Raises:
        ValueError: if the two inputs have different shapes.
    """
    x = np.asarray(bw_pct_changes, dtype=float)
    y = np.asarray(lean_pct_changes, dtype=float)
    if x.shape != y.shape:
        raise ValueError(
            "bw_pct_changes and lean_pct_changes must have the same shape, "
            f"got {x.shape} and {y.shape}"
        )

    # A single missing measurement would otherwise turn the whole fit into NaN.
    usable = np.isfinite(x) & np.isfinite(y)
    x, y = x[usable], y[usable]
    n = int(x.size)

    # Fewer than 3 points leaves no residual degrees of freedom, and
    # stats.linregress raises ValueError when all x values are identical;
    # report an undefined model rather than crashing the caller.
    if n < 3 or np.ptp(x) == 0:
        return dict(slope=np.nan, intercept=np.nan, r2=np.nan, n=n)

    res = stats.linregress(x, y)
    return dict(
        slope=float(res.slope),
        intercept=float(res.intercept),
        r2=float(res.rvalue ** 2),
        n=n,
    )


def _nan_stat(reducer, values) -> float:
    """Apply a NaN-aware ``reducer``; return NaN (without a warning) if nothing is non-NaN."""
    arr = np.asarray(values, dtype=float)
    if arr.size == 0 or np.isnan(arr).all():
        return float("nan")
    return float(reducer(arr))


def _final_week_frame(
    bw_df: pd.DataFrame, comp_df: pd.DataFrame, final_week: int
) -> pd.DataFrame:
    """Inner-join (on mouse_id) the final-week value, baseline and pct change for BW, lean and fat."""
    frames = []
    for source, col in (
        (bw_df, "bw_g"),
        (comp_df, "lean_mass_g"),
        (comp_df, "fat_mass_g"),
    ):
        locked = baseline_lock(source, col)
        frames.append(
            locked.loc[
                locked["week"] == final_week,
                ["mouse_id", f"pct_change_{col}", col, f"baseline_{col}"],
            ]
        )
    bw_final, lean_final, fat_final = frames
    return bw_final.merge(lean_final, on="mouse_id").merge(fat_final, on="mouse_id")


def _mass_sparing_index(
    lean_change_g: np.ndarray, fat_change_g: np.ndarray, min_fat_loss_g: float = 1e-6
) -> np.ndarray:
    """Per-mouse lean change [g] / fat lost [g]; NaN unless fat fell by more than ``min_fat_loss_g``."""
    lean = np.asarray(lean_change_g, dtype=float)
    fat_lost = -np.asarray(fat_change_g, dtype=float)  # positive when fat decreased
    msi = np.full(lean.shape, np.nan)
    # np.divide(where=...) skips the masked-out entries entirely, so no
    # divide-by-zero warnings; NaN fat changes compare False and stay NaN.
    np.divide(lean, fat_lost, out=msi, where=fat_lost > min_fat_loss_g)
    return msi


def decompose_muscle_sparing(
    bw_df: pd.DataFrame,
    comp_df: pd.DataFrame,
    meta: pd.DataFrame,
    glp_arm: str,
    combo_arm: str,
    final_week: int,
) -> Dict[str, object]:
    """Compute expected vs actual lean loss for the combo arm.

    Steps:
      1. Compute mouse-level pct change at ``final_week`` for BW, lean and fat
         (via ``baseline_lock``), and attach each mouse's ``treatment``.
      2. Fit lean_pct ~ bw_pct on the glp monotherapy mice
         (``monotherapy_lean_loss_model``).
      3. For each combo mouse, predict expected lean_pct from its bw_pct.
      4. Sparing effect = actual_lean_pct - expected_lean_pct.
         Positive = combo retained MORE lean than predicted by the monotherapy
         slope (i.e. actual is higher / less negative than expected).
      5. MSI = (lean change [g]) / (fat lost [g]) per mouse, computed only for
         mice whose fat mass actually decreased (NaN otherwise, because a
         negative denominator would flip the sign and make lean loss with fat
         gain look "ideal"). Positive = lean gained while fat lost. Between -1
         and 0 = less lean lost than fat lost. Below -1 = more lean lost than fat.

    Returns:
        A dict with ``glp_arm``, ``combo_arm``, ``n_mono``, ``n_combo``,
        ``mono_model`` (dict from ``monotherapy_lean_loss_model``),
        ``mean_expected_lean_pct``, ``mean_actual_lean_pct``,
        ``mean_sparing_pct``, ``median_msi`` (NaN when no finite values) and a
        ``per_mouse`` DataFrame for the combo arm. If either arm has no mice at
        ``final_week``, returns ``{"error": ...}`` instead.
    """
    df = _final_week_frame(bw_df, comp_df, final_week)
    df = df.merge(meta[["mouse_id", "treatment"]], on="mouse_id", how="left")

    mono = df[df["treatment"] == glp_arm]
    combo = df[df["treatment"] == combo_arm]

    if mono.empty or combo.empty:
        return dict(
            error=f"missing arm: mono={glp_arm} (n={len(mono)}), combo={combo_arm} (n={len(combo)})"
        )

    model = monotherapy_lean_loss_model(
        mono["pct_change_bw_g"].to_numpy(),
        mono["pct_change_lean_mass_g"].to_numpy(),
    )

    # Expected lean pct change for each combo mouse under the monotherapy fit.
    bw_pct = combo["pct_change_bw_g"].to_numpy()
    expected_lean_pct = model["intercept"] + model["slope"] * bw_pct
    actual_lean_pct = combo["pct_change_lean_mass_g"].to_numpy()
    # Positive = combo retained more lean than the monotherapy slope predicted.
    sparing_pct = actual_lean_pct - expected_lean_pct

    # Convention: *_change_g = final - baseline (negative = lost).
    lean_change_g = combo["lean_mass_g"].to_numpy() - combo["baseline_lean_mass_g"].to_numpy()
    fat_change_g = combo["fat_mass_g"].to_numpy() - combo["baseline_fat_mass_g"].to_numpy()
    msi = _mass_sparing_index(lean_change_g, fat_change_g)

    return dict(
        glp_arm=glp_arm,
        combo_arm=combo_arm,
        n_mono=int(len(mono)),
        n_combo=int(len(combo)),
        mono_model=model,
        mean_expected_lean_pct=_nan_stat(np.nanmean, expected_lean_pct),
        mean_actual_lean_pct=_nan_stat(np.nanmean, actual_lean_pct),
        mean_sparing_pct=_nan_stat(np.nanmean, sparing_pct),
        median_msi=_nan_stat(np.nanmedian, msi),
        per_mouse=pd.DataFrame(
            dict(
                mouse_id=combo["mouse_id"].to_numpy(),
                bw_pct=bw_pct,
                expected_lean_pct=np.asarray(expected_lean_pct, dtype=float),
                actual_lean_pct=actual_lean_pct,
                sparing_pct=sparing_pct,
                lean_change_g=lean_change_g,
                fat_change_g=fat_change_g,
                msi=msi,
            )
        ),
    )


def fiber_protection_pattern(
    fiber_df: pd.DataFrame,
    meta: pd.DataFrame,
    arm: str,
    final_week: int,
) -> Dict[str, object]:
    """Summarize per-fiber-type CSA change at ``final_week`` for one treatment arm.

    Each (mouse_id, fiber_type) series is baseline-locked on ``csa_mean_um2``:
    the baseline is the first week-0 row if one exists, otherwise the earliest
    recorded week. Percent change is ``(csa - baseline) / baseline * 100``; a
    zero baseline yields NaN rather than an infinite change.

    Returns:
        ``{"error": ...}`` if the arm has no fiber rows; otherwise a dict with
        ``arm``, ``per_fiber_mean_pct_change`` (fiber_type -> mean pct change at
        ``final_week``), ``ordered_pattern`` (the IIB, IIX, IIA, I means in that
        order, NaN when absent) and ``fast_twitch_dominant``, True only when all
        four means are present and IIB >= IIX >= IIA >= I.
    """
    df = pd.merge(fiber_df, meta[["mouse_id", "treatment"]], on="mouse_id", how="left")
    df = df[df["treatment"] == arm].copy()
    if df.empty:
        return dict(error=f"no fiber data for arm {arm}")

    keys = ["mouse_id", "fiber_type"]
    # Rank candidate baseline rows so the first row per key is the baseline:
    # week-0 rows before all others (even pre-treatment negative weeks), then
    # the earliest week, ties broken by original row order.
    ranked = (
        df.assign(_not_week0=df["week"] != 0, _pos=np.arange(len(df)))
        .dropna(subset=keys)  # groupby would have skipped NaN keys too
        .sort_values(["_not_week0", "week", "_pos"])
    )
    baselines = (
        ranked.drop_duplicates(subset=keys, keep="first")
        .set_index(keys)["csa_mean_um2"]
        .astype(float)
    )
    df["baseline"] = baselines.reindex(pd.MultiIndex.from_frame(df[keys])).to_numpy()

    # A zero baseline would give +/-inf and poison the per-fiber mean.
    safe_baseline = df["baseline"].where(df["baseline"] != 0)
    df["pct_change_csa"] = (df["csa_mean_um2"] - safe_baseline) / safe_baseline * 100.0

    final = df[df["week"] == final_week]
    means = final.groupby("fiber_type")["pct_change_csa"].mean().to_dict()

    order = ["IIB", "IIX", "IIA", "I"]
    seq = [means.get(ft, np.nan) for ft in order]
    monotonic = all(
        (not np.isnan(a) and not np.isnan(b) and a >= b)
        for a, b in zip(seq[:-1], seq[1:])
    )
    return dict(
        arm=arm,
        per_fiber_mean_pct_change=means,
        ordered_pattern=dict(zip(order, seq)),
        fast_twitch_dominant=bool(monotonic),
    )
