"""Synthetic sarcopenic obesity cohort generator.

32 mice (4 arms × 8 mice per arm) × 7 timepoints (weeks 0–24, every 4 weeks). Spec (nominal change, W24 vs W0):
- vehicle: BW slight up, lean ~stable, grip ~stable
- tirzepatide: BW -20%, fat -30%, lean -5%, grip per lean -2%
- bimagrumab: BW +5%, fat -10%, lean +12%, grip +15%
- combo (tirzepatide+bimagrumab): BW -18%, fat -35%, lean +5% (spared), grip per lean +8%
"""
from __future__ import annotations

import os
from typing import Dict

import numpy as np
import pandas as pd

from .loaders import write_modality

# Sampling schedule: week 0 (baseline) to week 24 in 4-week steps — 7 timepoints.
WEEKS = list(range(0, 25, 4))
# Mice randomized to each treatment arm.
N_PER_ARM = 8

# Nominal end-of-trial % change (W24 vs W0) for each arm and readout; these
# targets drive the synthetic curves built by _trajectory.
ARM_TARGETS = {
    "vehicle": {
        "bw": 2.0, "lean": -1.0, "fat": 3.0, "grip": -1.0, "csa": -3.0,
        "myostatin": 5.0, "activin_A": 3.0, "irisin": 0.0,
        "decorin": -2.0, "BAIBA": -3.0,
    },
    "tirzepatide": {
        "bw": -20.0, "lean": -5.0, "fat": -30.0, "grip": -7.0, "csa": -7.0,
        "myostatin": 25.0, "activin_A": 8.0, "irisin": 5.0,
        "decorin": -10.0, "BAIBA": 10.0,
    },
    "bimagrumab": {
        "bw": 5.0, "lean": 12.0, "fat": -10.0, "grip": 15.0, "csa": 18.0,
        "myostatin": 0.0, "activin_A": -40.0, "irisin": 5.0,
        "decorin": 10.0, "BAIBA": 12.0,
    },
    "tirzepatide+bimagrumab": {
        "bw": -18.0, "lean": 5.0, "fat": -35.0, "grip": 12.0, "csa": 10.0,
        "myostatin": 10.0, "activin_A": -35.0, "irisin": 8.0,
        "decorin": 5.0, "BAIBA": 18.0,
    },
}

# Treatment arms in cohort order — exactly the keys of ARM_TARGETS.
ARMS = list(ARM_TARGETS)

# % change targets for mean fiber CSA, per arm, in fiber-type order I, IIA, IIX, IIB.
# Fiber-type protection: the bimagrumab effect is largest on IIB > IIX > IIA > I.
FIBER_TARGETS = {
    arm: dict(zip(("I", "IIA", "IIX", "IIB"), pct_changes))
    for arm, pct_changes in (
        ("vehicle", (-2.0, -3.0, -4.0, -5.0)),
        ("tirzepatide", (-3.0, -5.0, -8.0, -12.0)),
        ("bimagrumab", (3.0, 10.0, 18.0, 28.0)),
        ("tirzepatide+bimagrumab", (1.0, 6.0, 12.0, 20.0)),
    )
}


def _trajectory(
    target_pct: float,
    weeks: np.ndarray,
    rng: np.random.Generator,
    *,
    tau: float = 8.0,
    horizon: float = 24.0,
    noise_sd: float = 0.6,
) -> np.ndarray:
    """Return a noisy multiplier curve whose mean reaches ``target_pct`` at ``horizon``.

    The mean follows a saturating exponential ``1 - exp(-week / tau)``,
    rescaled so that it equals exactly ``target_pct`` percent at
    ``week == horizon``. Without the rescaling it would only reach
    ``1 - exp(-3)`` (~95%) of the target by week 24. That would undershoot the
    W24-vs-W0 targets the arms are specified with. Independent Gaussian
    noise (``noise_sd`` percentage points) is added at every timepoint,
    baseline included.

    Args:
        target_pct: Mean % change at ``horizon`` (e.g. ``-20.0`` for -20%).
        weeks: Timepoints in weeks (any array-like; the output has its shape).
        rng: Generator for the noise; consumes one normal draw per timepoint.
        tau: Time constant of the exponential approach, in weeks.
        horizon: Week at which the mean curve equals ``target_pct``.
        noise_sd: Standard deviation of the additive noise, in % points.

    Returns:
        Multipliers (mean 1.0 at week 0) to apply to a baseline value.

    Raises:
        ValueError: If ``tau`` or ``horizon`` is not positive.
    """
    if tau <= 0 or horizon <= 0:
        raise ValueError(
            f"tau and horizon must be positive, got tau={tau}, horizon={horizon}"
        )
    weeks = np.asarray(weeks, dtype=float)
    # Normalise so frac == 0 at week 0 and frac == 1 at the horizon.
    frac = (1.0 - np.exp(-weeks / tau)) / (1.0 - np.exp(-horizon / tau))
    pct = target_pct * frac
    # size=weeks.shape draws the same numbers as size=len(weeks) for 1-D input.
    noise = rng.normal(0.0, noise_sd, size=weeks.shape)
    return 1.0 + (pct + noise) / 100.0


# Muscles imaged by micro-CT at every timepoint, and the subset tested ex vivo.
_MICROCT_MUSCLES = ("gastroc", "soleus", "TA", "quad")
_EXVIVO_MUSCLES = ("gastroc", "soleus")
# Weeks at which ex vivo force is measured.
_EXVIVO_WEEKS = (12, 24)
# Per-fiber-type baselines: mean CSA (um^2) and share of fibers (%).
_FIBER_BASE_CSA = {"I": 1700.0, "IIA": 2200.0, "IIX": 2800.0, "IIB": 3300.0}
_FIBER_BASE_PCT = {"I": 12, "IIA": 28, "IIX": 30, "IIB": 30}
# Baseline myokine levels (written to the value_pgmL column).
_MYOKINE_BASE = {
    "myostatin": 80.0,
    "activin_A": 25.0,
    "irisin": 7.5,
    "decorin": 30.0,
    "BAIBA": 12.0,
}


def _build_meta(seed: int) -> pd.DataFrame:
    """Return the cohort table: ``N_PER_ARM`` mice per arm, IDs ``M000``… in arm order.

    The two disease models alternate within each arm (HFD+aging, HFD+OVX, …).
    Treated arms get ``dose_mg_kg=1.0`` and vehicle gets ``0.0``. No random
    draws are made here; ``seed`` is only recorded.
    """
    model_pool = ("HFD+aging", "HFD+OVX")
    rows = []
    for arm_idx, arm in enumerate(ARMS):
        for j in range(N_PER_ARM):
            rows.append(
                dict(
                    mouse_id=f"M{arm_idx * N_PER_ARM + j:03d}",
                    model=model_pool[j % len(model_pool)],
                    treatment=arm,
                    dose_mg_kg=0.0 if arm == "vehicle" else 1.0,
                    randomization_seed=seed,
                    sex="M",
                    cohort_id="demo_cohort_2026_05_02",
                )
            )
    return pd.DataFrame(rows)


def generate(seed: int = 42) -> Dict[str, pd.DataFrame]:
    """Simulate the whole demo cohort.

    Every mouse gets random baselines. Each readout then follows a noisy
    trajectory toward its arm's target in ``ARM_TARGETS`` / ``FIBER_TARGETS``
    (see ``_trajectory``). One trajectory is drawn per (mouse, readout), and
    likewise per (mouse, muscle), (mouse, fiber type) and (mouse, analyte).
    It is reused for all weeks, so per-animal deviations persist over time.

    Args:
        seed: Seed for ``np.random.default_rng``; also recorded in
            ``cohort_meta.randomization_seed``. The same seed reproduces the
            same bundle.

    Returns:
        Dict of modality name -> long-format DataFrame keyed by ``mouse_id``
        and ``week`` (plus ``muscle`` / ``fiber_type`` / ``analyte`` where
        relevant): ``cohort_meta``, ``body_weight``, ``body_composition``,
        ``grip_strength``, ``treadmill``, ``running_wheel``,
        ``microct_muscle``, ``myofiber_hcs``, ``exvivo_force``, ``myokine``.
    """
    rng = np.random.default_rng(seed)
    meta = _build_meta(seed)
    weeks_arr = np.array(WEEKS)
    horizon = WEEKS[-1]

    bw_rows = []
    comp_rows = []
    grip_rows = []
    treadmill_rows = []
    csa_rows = []
    fiber_rows = []
    myokine_rows = []
    exvivo_rows = []
    wheel_rows = []

    for mid, arm in zip(meta["mouse_id"], meta["treatment"]):
        targets = ARM_TARGETS[arm]
        ftargets = FIBER_TARGETS[arm]

        # Individual baselines (sarcopenic obese mice ~35-45 g, lean ~22 g,
        # fat ~14 g: BW minus lean minus ~4 g of bone+water).
        bw0 = float(rng.normal(40.0, 2.5))
        lean0 = float(rng.normal(22.0, 1.2))
        fat0 = float(rng.normal(bw0 - lean0 - 4.0, 1.0))
        grip0 = float(rng.normal(180.0, 12.0))  # grams force
        treadmill_dist0 = float(rng.normal(380.0, 35.0))

        bw_traj = _trajectory(targets["bw"], weeks_arr, rng) * bw0
        lean_traj = _trajectory(targets["lean"], weeks_arr, rng) * lean0
        fat_traj = _trajectory(targets["fat"], weeks_arr, rng) * fat0
        grip_traj = _trajectory(targets["grip"], weeks_arr, rng) * grip0
        # Endurance readouts follow the grip target: amplified on the
        # treadmill, damped for voluntary wheel running.
        treadmill_traj = (
            _trajectory(targets["grip"] * 1.2, weeks_arr, rng) * treadmill_dist0
        )
        wheel_traj = _trajectory(targets["grip"] * 0.5, weeks_arr, rng) * 8000.0

        # One relative-CSA curve per muscle. The per-muscle jitter around the
        # arm-level CSA target is drawn once, so it is stable across weeks.
        csa_trajs = {
            muscle: _trajectory(targets["csa"] + rng.normal(0, 1.0), weeks_arr, rng)
            for muscle in _MICROCT_MUSCLES
        }
        fiber_trajs = {
            ft: _trajectory(ftargets[ft], weeks_arr, rng) for ft in _FIBER_BASE_CSA
        }
        myokine_trajs = {
            analyte: _trajectory(targets[analyte], weeks_arr, rng)
            for analyte in _MYOKINE_BASE
        }
        # Pax7 density is scaled by the arm's type-IIB target for every fiber type.
        pax7_scale = 1 + ftargets["IIB"] / 200.0

        for i, w in enumerate(WEEKS):
            bw_i = round(bw_traj[i], 2)
            bw_rows.append(dict(mouse_id=mid, week=w, bw_g=bw_i))
            comp_rows.append(
                dict(
                    mouse_id=mid,
                    week=w,
                    lean_mass_g=round(lean_traj[i], 2),
                    fat_mass_g=round(fat_traj[i], 2),
                    bw_g=bw_i,
                )
            )
            grip_rows.append(dict(mouse_id=mid, week=w, force_g=round(grip_traj[i], 1)))
            treadmill_rows.append(
                dict(
                    mouse_id=mid,
                    week=w,
                    distance_m=round(treadmill_traj[i], 1),
                    # implies a constant running speed of 0.25 m/s
                    duration_s=round(treadmill_traj[i] / 0.25, 1),
                )
            )
            wheel_rows.append(
                dict(mouse_id=mid, week=w, rev_per_day=round(float(wheel_traj[i]), 0))
            )

            # micro-CT muscles: IMAT starts at 5% and moves 2 points per 1%
            # CSA change in the opposite direction, floored at 0.
            for muscle, traj in csa_trajs.items():
                rel = traj[i]
                csa_rows.append(
                    dict(
                        mouse_id=mid,
                        week=w,
                        muscle=muscle,
                        csa_mm2=round(8.0 * rel, 3),
                        imat_pct=round(max(0.0, 5.0 - 2.0 * (rel - 1.0) * 100), 2),
                    )
                )

            # ex vivo force at W12 and W24 only; the arm's grip effect is
            # phased in linearly with time.
            if w in _EXVIVO_WEEKS:
                force_scale = 1 + (targets["grip"] / 100.0) * (w / horizon)
                for muscle in _EXVIVO_MUSCLES:
                    tet = float(rng.normal(280.0, 25.0)) * force_scale
                    sf = float(rng.normal(220.0, 18.0)) * force_scale
                    exvivo_rows.append(
                        dict(
                            mouse_id=mid,
                            week=w,
                            muscle=muscle,
                            tetanic_mN=round(tet, 1),
                            specific_force_kPa=round(sf, 1),
                        )
                    )

            # myofiber HCS
            for ft, traj in fiber_trajs.items():
                fiber_rows.append(
                    dict(
                        mouse_id=mid,
                        week=w,
                        fiber_type=ft,
                        csa_mean_um2=round(_FIBER_BASE_CSA[ft] * traj[i], 1),
                        pct=round(_FIBER_BASE_PCT[ft] + rng.normal(0, 1.5), 2),
                        pax7_density_per_mm2=round(
                            float(rng.normal(45, 6)) * pax7_scale, 1
                        ),
                        centronucleated_pct=round(
                            max(0.0, float(rng.normal(2.0, 0.5))), 2
                        ),
                    )
                )

            # myokine ELISA
            for analyte, traj in myokine_trajs.items():
                myokine_rows.append(
                    dict(
                        mouse_id=mid,
                        week=w,
                        analyte=analyte,
                        value_pgmL=round(float(_MYOKINE_BASE[analyte] * traj[i]), 2),
                    )
                )

    return dict(
        cohort_meta=meta,
        body_weight=pd.DataFrame(bw_rows),
        body_composition=pd.DataFrame(comp_rows),
        grip_strength=pd.DataFrame(grip_rows),
        treadmill=pd.DataFrame(treadmill_rows),
        running_wheel=pd.DataFrame(wheel_rows),
        microct_muscle=pd.DataFrame(csa_rows),
        myofiber_hcs=pd.DataFrame(fiber_rows),
        exvivo_force=pd.DataFrame(exvivo_rows),
        myokine=pd.DataFrame(myokine_rows),
    )


def write_demo(cohort_dir: str, seed: int = 42) -> Dict[str, str]:
    """Generate the demo cohort and persist every modality under ``cohort_dir``.

    The directory is created first if it does not exist. Then each
    DataFrame produced by ``generate(seed)`` is passed to ``write_modality``,
    in bundle order.

    Returns:
        Mapping of modality name to whatever ``write_modality`` returned for
        it (presumably the written file path — confirm in ``loaders``).
    """
    os.makedirs(cohort_dir, exist_ok=True)
    return {
        modality: write_modality(cohort_dir, modality, frame)
        for modality, frame in generate(seed).items()
    }
