"""
demo_data.py — 합성 데모 데이터 절차적 생성

생성 데이터:
  - 정상(NGT) OGTT 곡선
  - 내당능장애(IGT) OGTT 곡선
  - 제2형당뇨(T2DM) OGTT 곡선
  - 고인슐린혈증-정상혈당 클램프 곡선

모든 값은 canonical 단위:
  glucose mg/dL, insulin uU/mL, cpeptide ng/mL.
CSV 스키마(OGTT/MMTT): time_min, glucose, insulin, cpeptide
CSV 스키마(CLAMP):     time_min, glucose, insulin, gir  (gir = mg/kg/min)

재현성을 위해 numpy 시드 고정.
주의: 절차적 합성값으로, 실제 환자 데이터가 아니며 생리학적 '그럴듯함'만 목표로 한다.
"""
from __future__ import annotations
import numpy as np
import pandas as pd


def _ogtt_curve(times, g_peak_t, g0, g_peak, g_end,
                i0, i_peak, i_peak_t, i_end,
                c0, c_peak, c_peak_t, c_end, rng, noise=0.03):
    """비대칭 상승/하강 곡선을 시점별로 보간 생성."""
    def shaped(t, x0, xpk, xpk_t, xend):
        out = np.empty_like(t, dtype=float)
        for k, tt in enumerate(t):
            if tt <= xpk_t:
                frac = 0 if xpk_t == 0 else tt / xpk_t
                out[k] = x0 + (xpk - x0) * (1 - np.cos(np.pi * frac)) / 2  # smooth rise
            else:
                span = (t[-1] - xpk_t)
                frac = 0 if span == 0 else (tt - xpk_t) / span
                out[k] = xpk + (xend - xpk) * frac  # linear-ish decay
        return out

    g = shaped(times, g0, g_peak, g_peak_t, g_end)
    i = shaped(times, i0, i_peak, i_peak_t, i_end)
    c = shaped(times, c0, c_peak, c_peak_t, c_end)
    # 곱셈 노이즈
    g *= (1 + rng.normal(0, noise, size=g.shape))
    i *= (1 + rng.normal(0, noise, size=i.shape))
    c *= (1 + rng.normal(0, noise, size=c.shape))
    return np.round(g, 1), np.round(i, 1), np.round(c, 2)


def make_ogtt_normal(seed=1):
    """Return a synthetic normal glucose tolerance (NGT) OGTT curve.

    Samples are taken at 0, 30, 60, 90 and 120 minutes. Glucose peaks at
    30 min and falls back toward baseline by 120 min.

    Parameters
    ----------
    seed : int
        Seed for the numpy Generator that drives the noise.

    Returns
    -------
    pandas.DataFrame
        Columns ``time_min``, ``glucose``, ``insulin``, ``cpeptide``.
    """
    generator = np.random.default_rng(seed)
    sample_times = np.arange(0, 121, 30, dtype=float)
    glucose, insulin, cpeptide = _ogtt_curve(
        sample_times,
        g_peak_t=30, g0=88, g_peak=150, g_end=110,
        i0=6, i_peak=60, i_peak_t=30, i_end=12,
        c0=1.2, c_peak=7.0, c_peak_t=60, c_end=2.5,
        rng=generator,
    )
    columns = {
        "time_min": sample_times.astype(int),
        "glucose": glucose,
        "insulin": insulin,
        "cpeptide": cpeptide,
    }
    return pd.DataFrame(columns)


def make_ogtt_igt(seed=2):
    """Return a synthetic impaired glucose tolerance (IGT) OGTT curve.

    Samples are taken at 0, 30, 60, 90 and 120 minutes. Compared with the
    normal profile, glucose peaks later (60 min) and stays elevated at 120 min.

    Parameters
    ----------
    seed : int
        Seed for the numpy Generator that drives the noise.

    Returns
    -------
    pandas.DataFrame
        Columns ``time_min``, ``glucose``, ``insulin``, ``cpeptide``.
    """
    generator = np.random.default_rng(seed)
    sample_times = np.arange(0, 121, 30, dtype=float)
    glucose, insulin, cpeptide = _ogtt_curve(
        sample_times,
        g_peak_t=60, g0=102, g_peak=190, g_end=170,
        i0=12, i_peak=75, i_peak_t=60, i_end=45,
        c0=1.8, c_peak=8.5, c_peak_t=90, c_end=5.0,
        rng=generator,
    )
    columns = {
        "time_min": sample_times.astype(int),
        "glucose": glucose,
        "insulin": insulin,
        "cpeptide": cpeptide,
    }
    return pd.DataFrame(columns)


def make_ogtt_t2dm(seed=3):
    """Return a synthetic type 2 diabetes (T2DM) OGTT curve.

    Samples are taken at 0, 30, 60, 90 and 120 minutes. The profile has high
    fasting glucose, a blunted first-phase insulin response (insulin peaks
    late, at 90 min) and sustained hyperglycemia through 120 min.

    Parameters
    ----------
    seed : int
        Seed for the numpy Generator that drives the noise.

    Returns
    -------
    pandas.DataFrame
        Columns ``time_min``, ``glucose``, ``insulin``, ``cpeptide``.
    """
    generator = np.random.default_rng(seed)
    sample_times = np.arange(0, 121, 30, dtype=float)
    glucose, insulin, cpeptide = _ogtt_curve(
        sample_times,
        g_peak_t=90, g0=145, g_peak=250, g_end=240,
        i0=14, i_peak=35, i_peak_t=90, i_end=40,
        c0=2.2, c_peak=4.5, c_peak_t=90, c_end=5.5,
        rng=generator,
    )
    columns = {
        "time_min": sample_times.astype(int),
        "glucose": glucose,
        "insulin": insulin,
        "cpeptide": cpeptide,
    }
    return pd.DataFrame(columns)


def make_clamp(seed=4):
    """Return a synthetic hyperinsulinemic-euglycemic clamp curve.

    Constant insulin infusion produces a plasma insulin plateau. Glucose is
    held near 90 mg/dL, and the glucose infusion rate (GIR) rises gradually
    before stabilizing.

    Parameters
    ----------
    seed : int
        Seed for the numpy Generator that drives the noise.

    Returns
    -------
    pandas.DataFrame
        Columns ``time_min``, ``glucose``, ``insulin``, ``gir`` (mg/kg/min).
        The first GIR sample is always exactly 0.0.
    """
    generator = np.random.default_rng(seed)

    def jitter(values, sd, decimals):
        # Multiplicative noise followed by rounding; called in the fixed order
        # glucose -> insulin -> gir so the seed reproduces the same values.
        base = np.asarray(values, dtype=float)
        return np.round(base * (1 + generator.normal(0, sd, size=base.shape)), decimals)

    sample_times = np.array([0, 10, 20, 30, 60, 90, 120], dtype=float)
    glucose = jitter([92, 90, 90, 90, 90, 90, 90], 0.01, 1)
    insulin = jitter([8, 60, 85, 95, 100, 100, 100], 0.03, 1)
    # GIR starts at 0 and settles near 7.5 mg/kg/min (normal insulin sensitivity).
    gir = jitter([0.0, 2.0, 4.5, 6.0, 7.2, 7.5, 7.6], 0.03, 2)
    gir[0] = 0.0

    columns = {
        "time_min": sample_times.astype(int),
        "glucose": glucose,
        "insulin": insulin,
        "gir": gir,
    }
    return pd.DataFrame(columns)


def all_demos():
    """Build every demo dataset with its default seed.

    Returns
    -------
    dict[str, pandas.DataFrame]
        Keys ``OGTT_normal``, ``OGTT_IGT``, ``OGTT_T2DM`` and ``CLAMP``, in that
        insertion order.
    """
    builders = (
        ("OGTT_normal", make_ogtt_normal),
        ("OGTT_IGT", make_ogtt_igt),
        ("OGTT_T2DM", make_ogtt_t2dm),
        ("CLAMP", make_clamp),
    )
    return {name: build() for name, build in builders}


def write_demo_csvs(out_dir):
    """Write every demo dataset to a CSV file inside ``out_dir``.

    The directory is created if needed. Files are written without the pandas
    index.

    Parameters
    ----------
    out_dir : str or path-like
        Destination directory.

    Returns
    -------
    dict[str, str]
        Demo key mapped to the path of the CSV that was written.
    """
    import os

    os.makedirs(out_dir, exist_ok=True)
    filenames = {
        "OGTT_normal": "demo_ogtt_normal.csv",
        "OGTT_IGT": "demo_ogtt_igt.csv",
        "OGTT_T2DM": "demo_ogtt_t2dm.csv",
        "CLAMP": "demo_clamp.csv",
    }
    frames = all_demos()
    written = {}
    for name, filename in filenames.items():
        target = os.path.join(out_dir, filename)
        frames[name].to_csv(target, index=False)
        written[name] = target
    return written


if __name__ == "__main__":
    import sys

    # Optional first CLI argument is the output directory (default: "data").
    cli_args = sys.argv[1:]
    target_dir = cli_args[0] if cli_args else "data"
    for key, path in write_demo_csvs(target_dir).items():
        print(f"wrote {key}: {path}")
