"""Synthetic body composition generator for BodyCompMouseDXA.

Produces 6 CSVs in the output directory (``synthetic/`` next to this file when run as a script):
    - cohort_C57BL6_DIO.csv          DIO 60% HFD, 8 mice
    - cohort_obob.csv                ob/ob hyperphagic obese, 8 mice
    - cohort_dbdb.csv                db/db obese-diabetic, 8 mice
    - cohort_STAM.csv                STZ + HFD NASH, 6 mice
    - cohort_control_chow.csv        C57BL/6 chow control, 8 mice
    - cohort_GLP1RA_STEP4_mimic.csv  STEP-4 mimic, drug-on + drug-off, 12 mice

All values are synthetic. seed=42. Schema columns:
    animal_id, model, group, sex, time_point_wk, day_label,
    body_weight_g, fat_g, lean_g, water_g, fat_pct, lean_pct,
    BMD, BMC, visceral_fat_proxy_g, appendicular_lean_g,
    drug_phase, source_format

The data are intended for offline analytics validation only.
"""

from __future__ import annotations

import csv
import os
import random
from typing import List, Dict

try:
    # local import so this script works whether invoked directly or as a module.
    from dio_reference import DIO_REFERENCE, get_reference
except ImportError:  # pragma: no cover
    import sys
    HERE = os.path.dirname(os.path.abspath(__file__))
    sys.path.insert(0, os.path.dirname(HERE))
    from dio_reference import DIO_REFERENCE, get_reference  # type: ignore

# Base RNG seed; each cohort generator seeds its own random.Random with SEED + k.
SEED = 42
TIME_POINTS_STD = [0, 4, 8, 12, 16, 20]  # sampling weeks; 6 time points
# Sampling weeks for the drug cohort: wk <= 12 induction, wk 16 drug-on /
# placebo-on, wk 20 drug-off / placebo-off (see gen_cohort_glp1ra_step4).
TIME_POINTS_DRUG = [0, 4, 8, 12, 16, 20]

# Column order of every CSV written by _write_csv (passed as DictWriter fieldnames).
SCHEMA_COLUMNS = [
    "animal_id", "model", "group", "sex", "time_point_wk", "day_label",
    "body_weight_g", "fat_g", "lean_g", "water_g", "fat_pct", "lean_pct",
    "BMD", "BMC", "visceral_fat_proxy_g", "appendicular_lean_g",
    "drug_phase", "source_format",
]


def _jitter(rng: random.Random, mid: float, spread: float) -> float:
    """Draw one Gaussian sample around ``mid`` and floor it at zero.

    rng: the random.Random instance to draw from (one ``gauss`` call).
    mid: centre of the distribution.
    spread: standard deviation handed to ``rng.gauss``.
    Returns the sample, or 0.0 when the sample came out negative.
    """
    sample = rng.gauss(mid, spread)
    if sample < 0.0:
        # Measurements in this file are never allowed to go below zero.
        return 0.0
    return sample


def _round(x: float, ndigits: int = 3) -> float:
    """Coerce ``x`` to float and round it to ``ndigits`` decimal places."""
    as_float = float(x)
    return round(as_float, ndigits)


def _row_from_reference(rng: random.Random, model_key: str,
                        t: int, animal_idx: int,
                        group: str, sex: str,
                        drug_phase: str,
                        source_format: str,
                        modifier: Dict[str, float] | None = None) -> Dict:
    """Build one schema row for a single animal at a single time point.

    Every measurement is sampled around entry [1] of the band returned by
    ``get_reference(model_key, t)``, with a Gaussian spread equal to one
    quarter of the distance between the band's entries [0] and [2].

    modifier: optional per-measurement multipliers applied after sampling
    (the drug cohort passes them for its drug-on / drug-off rows).
    Recognised keys: 'bw_mult', 'fat_mult', 'lean_mult', 'vis_mult',
    'app_lean_mult'; a missing key leaves that measurement unchanged.
    None or an empty dict skips the scaling step entirely.

    Returns a dict keyed by the CSV schema column names, with derived
    fat/lean/water masses and all numeric values rounded.
    """
    band = get_reference(model_key, t)

    def draw(key: str) -> float:
        triple = band[key]
        return _jitter(rng, triple[1], (triple[2] - triple[0]) / 4)

    # Insertion order here is the RNG draw order; changing it would change
    # every generated value for a given seed.
    sampled = {
        key: draw(key)
        for key in (
            "body_weight_g", "fat_pct", "lean_pct", "BMD", "BMC",
            "visceral_fat_proxy_g", "appendicular_lean_g",
        )
    }

    if modifier:
        multiplier_targets = (
            ("bw_mult", "body_weight_g"),
            ("fat_mult", "fat_pct"),
            ("lean_mult", "lean_pct"),
            ("vis_mult", "visceral_fat_proxy_g"),
            ("app_lean_mult", "appendicular_lean_g"),
        )
        for mult_key, field in multiplier_targets:
            sampled[field] *= modifier.get(mult_key, 1.0)

    weight = sampled["body_weight_g"]
    fat_share = sampled["fat_pct"]
    lean_share = sampled["lean_pct"]
    bone_content = sampled["BMC"]

    # Masses are derived from the (possibly scaled) weight and percentages.
    fat_mass = weight * fat_share / 100.0
    lean_mass = weight * lean_share / 100.0
    # Crude water proxy: 85% of whatever mass is left over, never negative.
    leftover = max(weight - fat_mass - lean_mass - bone_content, 0.0)
    water_mass = leftover * 0.85

    return {
        "animal_id": f"{model_key}_M{animal_idx:02d}",
        "model": model_key,
        "group": group,
        "sex": sex,
        "time_point_wk": t,
        "day_label": f"D{t * 7}",
        "body_weight_g": _round(weight, 2),
        "fat_g": _round(fat_mass, 3),
        "lean_g": _round(lean_mass, 3),
        "water_g": _round(water_mass, 3),
        "fat_pct": _round(fat_share, 2),
        "lean_pct": _round(lean_share, 2),
        "BMD": _round(sampled["BMD"], 4),
        "BMC": _round(bone_content, 4),
        "visceral_fat_proxy_g": _round(sampled["visceral_fat_proxy_g"], 3),
        "appendicular_lean_g": _round(sampled["appendicular_lean_g"], 3),
        "drug_phase": drug_phase,
        "source_format": source_format,
    }


def _write_csv(path: str, rows: List[Dict]) -> None:
    """Write ``rows`` to ``path`` as a UTF-8 CSV.

    The header and column order come from SCHEMA_COLUMNS; an existing file
    at ``path`` is overwritten.
    """
    with open(path, "w", newline="", encoding="utf-8") as handle:
        out = csv.DictWriter(handle, fieldnames=SCHEMA_COLUMNS)
        out.writeheader()
        out.writerows(rows)


def gen_cohort_dio(out_dir: str, n_mice: int = 8) -> str:
    """Write cohort_C57BL6_DIO.csv into ``out_dir`` and return its path.

    Animals are numbered 1..n_mice and each gets one row per entry of
    TIME_POINTS_STD, drawn from the "C57BL_6J_HFD60" reference with an RNG
    seeded at SEED + 1.
    """
    cohort_rng = random.Random(SEED + 1)
    # Animal-major, week-minor order keeps the RNG stream reproducible.
    records = [
        _row_from_reference(cohort_rng, "C57BL_6J_HFD60", week, mouse_no,
                            group="DIO_HFD60",
                            sex="M",
                            drug_phase="none",
                            source_format="PIXImus_DICOM_sim")
        for mouse_no in range(1, n_mice + 1)
        for week in TIME_POINTS_STD
    ]
    csv_path = os.path.join(out_dir, "cohort_C57BL6_DIO.csv")
    _write_csv(csv_path, records)
    return csv_path


def gen_cohort_obob(out_dir: str, n_mice: int = 8) -> str:
    """Write cohort_obob.csv into ``out_dir`` and return its path.

    Animals are numbered 1..n_mice and each gets one row per entry of
    TIME_POINTS_STD, drawn from the "ob_ob" reference with an RNG seeded
    at SEED + 2.
    """
    cohort_rng = random.Random(SEED + 2)
    # Animal-major, week-minor order keeps the RNG stream reproducible.
    records = [
        _row_from_reference(cohort_rng, "ob_ob", week, mouse_no,
                            group="ob_ob",
                            sex="M",
                            drug_phase="none",
                            source_format="EchoMRI_csv")
        for mouse_no in range(1, n_mice + 1)
        for week in TIME_POINTS_STD
    ]
    csv_path = os.path.join(out_dir, "cohort_obob.csv")
    _write_csv(csv_path, records)
    return csv_path


def gen_cohort_dbdb(out_dir: str, n_mice: int = 8) -> str:
    """Write cohort_dbdb.csv into ``out_dir`` and return its path.

    Animals are numbered 1..n_mice and each gets one row per entry of
    TIME_POINTS_STD, drawn from the "db_db" reference with an RNG seeded
    at SEED + 3.
    """
    cohort_rng = random.Random(SEED + 3)
    # Animal-major, week-minor order keeps the RNG stream reproducible.
    records = [
        _row_from_reference(cohort_rng, "db_db", week, mouse_no,
                            group="db_db",
                            sex="M",
                            drug_phase="none",
                            source_format="qNMR_Bruker_minispec_txt")
        for mouse_no in range(1, n_mice + 1)
        for week in TIME_POINTS_STD
    ]
    csv_path = os.path.join(out_dir, "cohort_dbdb.csv")
    _write_csv(csv_path, records)
    return csv_path


def gen_cohort_stam(out_dir: str, n_mice: int = 6) -> str:
    """Write cohort_STAM.csv into ``out_dir`` and return its path.

    Animals are numbered 1..n_mice and each gets one row per entry of
    TIME_POINTS_STD, drawn from the "STAM" reference with an RNG seeded
    at SEED + 4.
    """
    cohort_rng = random.Random(SEED + 4)
    # Animal-major, week-minor order keeps the RNG stream reproducible.
    records = [
        _row_from_reference(cohort_rng, "STAM", week, mouse_no,
                            group="STAM_NASH",
                            sex="M",
                            drug_phase="none",
                            source_format="SkyScan_microCT_csv")
        for mouse_no in range(1, n_mice + 1)
        for week in TIME_POINTS_STD
    ]
    csv_path = os.path.join(out_dir, "cohort_STAM.csv")
    _write_csv(csv_path, records)
    return csv_path


def gen_cohort_control(out_dir: str, n_mice: int = 8) -> str:
    """Write cohort_control_chow.csv into ``out_dir`` and return its path.

    Animals are numbered 1..n_mice and each gets one row per entry of
    TIME_POINTS_STD, drawn from the "C57BL_6J_chow" reference with an RNG
    seeded at SEED + 5.
    """
    cohort_rng = random.Random(SEED + 5)
    # Animal-major, week-minor order keeps the RNG stream reproducible.
    records = [
        _row_from_reference(cohort_rng, "C57BL_6J_chow", week, mouse_no,
                            group="control_chow",
                            sex="M",
                            drug_phase="none",
                            source_format="PIXImus_DICOM_sim")
        for mouse_no in range(1, n_mice + 1)
        for week in TIME_POINTS_STD
    ]
    csv_path = os.path.join(out_dir, "cohort_control_chow.csv")
    _write_csv(csv_path, records)
    return csv_path


def gen_cohort_glp1ra_step4(out_dir: str, n_mice: int = 12) -> str:
    """Write cohort_GLP1RA_STEP4_mimic.csv into ``out_dir`` and return its path.

    Simulates a STEP-4-style design (DIO induction → drug-on → drug-off) on
    the "C57BL_6J_HFD60" reference, sampled at TIME_POINTS_DRUG with an RNG
    seeded at SEED + 6. The first ``n_mice // 2`` animals form the treated
    arm ('GLP1RA_treated'); the remainder are placebo ('GLP1RA_placebo').

    drug_phase by sampled week:
        wk <= 12:       'induction' for both arms, no modifier
        12 < wk <= 16:  'drug_on' (treated) / 'placebo_on'
        wk > 16:        'drug_off' (treated) / 'placebo_off'

    Only treated animals receive multipliers: the drug-on set lowers body
    weight, fat and the visceral proxy while slightly raising lean percent;
    the drug-off set moves those values part of the way back (partial
    rebound). Placebo rows are drawn unmodified. Each row's animal_id is
    rewritten to the form 'STEP4_Mxx'.
    """
    # Multipliers for the treated arm while on drug ...
    drug_on_mod = {"bw_mult": 0.82, "fat_mult": 0.65, "lean_mult": 1.05,
                   "vis_mult": 0.55, "app_lean_mult": 0.97}
    # ... and after withdrawal (partial rebound).
    drug_off_mod = {"bw_mult": 0.92, "fat_mult": 0.85, "lean_mult": 1.00,
                    "vis_mult": 0.78, "app_lean_mult": 0.96}

    cohort_rng = random.Random(SEED + 6)
    treated_count = n_mice // 2
    records = []
    for mouse_no in range(1, n_mice + 1):
        # Resolve everything that depends only on the arm once per animal.
        if mouse_no <= treated_count:
            arm = "GLP1RA_treated"
            on_label, on_mod = "drug_on", drug_on_mod
            off_label, off_mod = "drug_off", drug_off_mod
        else:
            arm = "GLP1RA_placebo"
            on_label, on_mod = "placebo_on", None
            off_label, off_mod = "placebo_off", None

        for week in TIME_POINTS_DRUG:
            if week <= 12:
                stage, tweak = "induction", None
            elif week <= 16:
                stage, tweak = on_label, on_mod
            else:
                stage, tweak = off_label, off_mod

            record = _row_from_reference(cohort_rng, "C57BL_6J_HFD60",
                                         week, mouse_no,
                                         group=arm, sex="M",
                                         drug_phase=stage,
                                         source_format="EchoMRI_csv",
                                         modifier=tweak)
            # Replace the model-based id with a cohort-specific one.
            record["animal_id"] = f"STEP4_M{mouse_no:02d}"
            records.append(record)

    csv_path = os.path.join(out_dir, "cohort_GLP1RA_STEP4_mimic.csv")
    _write_csv(csv_path, records)
    return csv_path


def generate_all(out_dir: str | None = None) -> List[str]:
    """Generate every cohort CSV and return the written paths.

    out_dir: destination directory; when None, the directory containing
    this file is used. The directory is created if it does not exist.
    Returns the six CSV paths in generation order (DIO, ob/ob, db/db,
    STAM, chow control, GLP1RA STEP-4 mimic).
    """
    target = out_dir
    if target is None:
        target = os.path.dirname(os.path.abspath(__file__))
    os.makedirs(target, exist_ok=True)

    builders = (
        gen_cohort_dio,
        gen_cohort_obob,
        gen_cohort_dbdb,
        gen_cohort_stam,
        gen_cohort_control,
        gen_cohort_glp1ra_step4,
    )
    return [build(target) for build in builders]


if __name__ == "__main__":
    # Script entry point: write all cohorts into a "synthetic" folder that
    # sits next to this file, then report each path that was written.
    script_dir = os.path.dirname(os.path.abspath(__file__))
    target_dir = os.path.join(script_dir, "synthetic")
    os.makedirs(target_dir, exist_ok=True)
    for written in generate_all(target_dir):
        print(f"wrote: {written}")
