"""demo_data.py — 합성 데모 데이터 절차적 생성.

외부 네트워크/실데이터 없이 OFTT·¹³C 호기 곡선을 절차적으로 생성한다.
- 정상 OFTT: peak ~3-4h, 6-8h 기저 복귀
- 지연청소 OFTT: peak 높고 늦으며, 8h 까지 기저 미복귀
- 호기 PDR 곡선: 정상(빠른 산화) / 저하(느린 산화)

생성 함수는 결정론적 시드를 받아 재현 가능. data/ CSV 생성 스크립트로도 동작.

면책: 합성 데이터, 연구용·참고용.
"""
from __future__ import annotations

import os
from typing import Dict, List

import numpy as np
import pandas as pd

# Default output directory for the demo CSVs: a ``data`` folder next to this
# module file, resolved to an absolute path once at import time.
DATA_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "data")


def _oftt_curve(times_h, baseline, amp, tmax, width, decay_floor, noise_sd, rng):
    """Postprandial TG curve: asymmetric gamma-like rise/fall plus Gaussian noise.

    The deterministic part is::

        baseline + amp * (t/tmax)**k * exp(k * (1 - t/tmax)) + decay_floor * t / t_max

    with ``k = max(width, 0.5)`` and ``t_max`` the largest sample time. The
    gamma-like term equals 1 at ``t == tmax`` and 0 for ``t <= 0``.

    Args:
        times_h: sample times in hours (1-D sequence of numbers).
        baseline: fasting level. A first sample at ``t <= 0`` is set to this
            value plus reduced (0.3x) noise.
        amp: excursion above baseline at ``tmax`` (before drift and noise).
        tmax: time of the gamma-like peak, in hours (must be non-zero).
        width: shape exponent ``k``, floored at 0.5. Larger values give a
            narrower peak.
        decay_floor: linear drift that reaches its full value at the largest
            sample time. Negative values pull the tail below baseline;
            positive values keep it elevated.
        noise_sd: standard deviation of the additive Gaussian noise.
        rng: ``numpy.random.Generator`` used for every noise draw.

    Returns:
        Float ndarray with one value per sample time, clipped at 0. An empty
        input yields an empty array.
    """
    t = np.asarray(times_h, dtype=float)
    if t.size == 0:
        # Nothing to sample; also avoids t.max() raising on an empty array.
        return np.empty(0, dtype=float)
    k = max(width, 0.5)
    # Replace t <= 0 with a tiny positive value so the power term is defined;
    # those points are zeroed right after.
    safe = np.where(t <= 0, 1e-6, t)
    shape = (safe / tmax) ** k * np.exp(k * (1.0 - safe / tmax))
    shape[t <= 0] = 0.0
    y = baseline + amp * shape + decay_floor * (t / max(t.max(), 1e-6))
    y = y + rng.normal(0.0, noise_sd, size=y.shape)
    # Only a genuine fasting (t <= 0) sample is pinned to baseline with small
    # noise. A series that starts after 0 h keeps its modelled value.
    if t[0] <= 0:
        y[0] = baseline + rng.normal(0.0, noise_sd * 0.3)
    return np.maximum(y, 0.0)


def make_oftt_normal(n_subjects: int = 6, seed: int = 11) -> pd.DataFrame:
    """Build a synthetic OFTT panel with normal (fast) triglyceride clearance.

    Each subject (IDs ``N01``, ``N02``, ...) is sampled at 0, 2, 4, 6 and 8 h.
    The TG peak time is drawn per subject from 3-4 h, with a linear drift of
    -5% of baseline by the last sample. apoB48 is a linear function of the
    TG excursion. Retinyl palmitate peaks about 0.7 h after TG.

    Args:
        n_subjects: number of synthetic subjects.
        seed: seed for ``numpy.random.default_rng``; the same seed reproduces
            the same table.

    Returns:
        Long-format DataFrame with columns
        ``subject_id, time_h, tg, apob48, retinyl`` (one row per subject/time).
    """
    rng = np.random.default_rng(seed)
    sample_h = [0, 2, 4, 6, 8]
    records: List[Dict] = []
    for idx in range(n_subjects):
        subject = f"N{idx + 1:02d}"
        fasting = rng.uniform(80, 120)
        excursion = rng.uniform(70, 110)
        peak_h = rng.uniform(3.0, 4.0)
        tg = _oftt_curve(sample_h, fasting, excursion, peak_h, width=2.0,
                         decay_floor=-fasting * 0.05, noise_sd=6.0, rng=rng)
        # apoB48 (μg/mL) tracks the TG excursion on a much smaller scale.
        # Draw order (slope, then intercept) is part of the seed contract.
        slope = rng.uniform(0.02, 0.03)
        intercept = rng.uniform(3, 6)
        apob48 = np.maximum((tg - fasting) * slope + intercept, 0.1)
        # Retinyl palmitate (μg/L): chylomicron marker that peaks later than TG.
        retinyl_amp = rng.uniform(2.5, 4.0)
        retinyl = _oftt_curve(sample_h, 0.5, retinyl_amp, peak_h + 0.7, 2.5,
                              0.0, 0.3, rng)
        records.extend(
            {"subject_id": subject, "time_h": th,
             "tg": round(float(tg_i), 1),
             "apob48": round(float(ap_i), 2),
             "retinyl": round(float(rt_i), 2)}
            for th, tg_i, ap_i, rt_i in zip(sample_h, tg, apob48, retinyl)
        )
    return pd.DataFrame(records)


def make_oftt_delayed(n_subjects: int = 6, seed: int = 23) -> pd.DataFrame:
    """Build a synthetic OFTT panel with delayed triglyceride clearance.

    Each subject (IDs ``D01``, ``D02``, ...) is sampled at 0, 2, 4, 6 and 8 h.
    Compared with the normal panel:
    - baseline and peak excursion are higher;
    - the peak time is later (drawn from 4-5 h);
    - a positive linear drift (+10% of baseline by the last sample) keeps the
      tail elevated.

    Retinyl palmitate peaks about 0.8 h after TG.

    Args:
        n_subjects: number of synthetic subjects.
        seed: seed for ``numpy.random.default_rng``; the same seed reproduces
            the same table.

    Returns:
        Long-format DataFrame with columns
        ``subject_id, time_h, tg, apob48, retinyl`` (one row per subject/time).
    """
    rng = np.random.default_rng(seed)
    sample_h = [0, 2, 4, 6, 8]
    records: List[Dict] = []
    for idx in range(n_subjects):
        subject = f"D{idx + 1:02d}"
        fasting = rng.uniform(110, 160)
        excursion = rng.uniform(140, 200)
        peak_h = rng.uniform(4.0, 5.0)
        tg = _oftt_curve(sample_h, fasting, excursion, peak_h, width=3.0,
                         decay_floor=fasting * 0.10, noise_sd=8.0, rng=rng)
        # apoB48 scales with the TG excursion; slope is drawn before intercept.
        slope = rng.uniform(0.03, 0.045)
        intercept = rng.uniform(5, 9)
        apob48 = np.maximum((tg - fasting) * slope + intercept, 0.1)
        retinyl_amp = rng.uniform(4.0, 6.0)
        retinyl = _oftt_curve(sample_h, 0.6, retinyl_amp, peak_h + 0.8, 3.0,
                              0.2, 0.4, rng)
        records.extend(
            {"subject_id": subject, "time_h": th,
             "tg": round(float(tg_i), 1),
             "apob48": round(float(ap_i), 2),
             "retinyl": round(float(rt_i), 2)}
            for th, tg_i, ap_i, rt_i in zip(sample_h, tg, apob48, retinyl)
        )
    return pd.DataFrame(records)


def _breath_delta(times_min, baseline_delta, peak_dob, tmax_min, width_min,
                  kel_per_h, noise, rng):
    """δ¹³C breath curve: DOB rise then exponential decay, on top of a baseline.

    DOB is modelled as::

        peak_dob * (1 - exp(-t / w)) * exp(-(kel_per_h / 60) * max(t - tmax_min, 0))

    with ``w = max(width_min, 1)``. ``peak_dob`` is the plateau amplitude of the
    rise term. The realised maximum is smaller, because the rise term is below
    1 at every finite ``t``.

    Args:
        times_min: sample times in minutes (1-D sequence of numbers).
        baseline_delta: pre-dose δ¹³C value added to every DOB sample.
        peak_dob: plateau amplitude of the DOB rise.
        tmax_min: time (minutes) after which the exponential decay starts.
        width_min: rise time constant in minutes, floored at 1.
        kel_per_h: decay rate per hour (converted to per minute internally).
        noise: standard deviation of the additive Gaussian noise on DOB.
        rng: ``numpy.random.Generator`` used for every noise draw.

    Returns:
        Float ndarray ``baseline_delta + DOB`` with one value per sample time.
        A first sample at ``t <= 0`` has DOB set to 0 plus reduced (0.2x) noise.
        An empty input yields an empty array.
    """
    t = np.asarray(times_min, dtype=float)
    if t.size == 0:
        # Nothing to sample; also avoids indexing dob[0] on an empty array.
        return np.empty(0, dtype=float)
    # Rise: 1 - exp(-t/w). Fall: exp(-kel * (t - tmax)) once past tmax_min.
    rise = 1.0 - np.exp(-t / max(width_min, 1.0))
    decay = np.exp(-(kel_per_h / 60.0) * np.maximum(t - tmax_min, 0.0))
    dob = peak_dob * rise * decay
    dob = dob + rng.normal(0.0, noise, size=dob.shape)
    # Only a genuine pre-dose (t <= 0) sample is pinned to DOB≈0. A series that
    # starts after dosing keeps its modelled value.
    if t[0] <= 0:
        dob[0] = 0.0 + rng.normal(0.0, noise * 0.2)
    return baseline_delta + dob


def make_breath_curves(seed: int = 37) -> pd.DataFrame:
    """Build synthetic ¹³C breath-test δ¹³C curves for two groups.

    The two groups are:
    - three ``normal_oxidation`` subjects (``BN01``-``BN03``): higher plateau
      DOB, faster decay, earlier tmax;
    - three ``impaired_oxidation`` subjects (``BL01``-``BL03``): lower DOB,
      slower decay, later tmax.

    Every subject is sampled at 0-240 min.

    Args:
        seed: seed for ``numpy.random.default_rng``; the same seed reproduces
            the same table.

    Returns:
        Long-format DataFrame with columns ``subject_id, group, time_min,
        delta13c, weight_kg, height_cm, dose_mmol_13c``.
    """
    rng = np.random.default_rng(seed)
    sample_min = [0, 10, 20, 30, 40, 60, 90, 120, 150, 180, 240]
    # mmol of 13C given to every subject.
    # NOTE(review): the earlier comment equated this with ~100 mg of
    # 1-13C-octanoate. 100 mg of 1-13C-octanoic acid is only ~0.7 mmol of 13C,
    # so confirm the intended dose.
    dose_mmol = 4.0
    # (id prefix, group label, then uniform (low, high) ranges for: weight kg,
    # height cm, baseline δ¹³C, plateau DOB, tmax min, rise width min, kel /h).
    # Draws happen in exactly this column order so a seed reproduces the table.
    cohorts = (
        ("BN", "normal_oxidation", (60, 78), (162, 178), (-22, -19),
         (14, 20), (50, 70), (15, 25), (0.6, 0.9)),
        ("BL", "impaired_oxidation", (72, 95), (160, 176), (-23, -20),
         (6, 10), (90, 120), (30, 45), (0.20, 0.35)),
    )
    records: List[Dict] = []
    for (prefix, label, wt_rng, ht_rng, base_rng, peak_rng, tmax_rng,
         width_rng, kel_rng) in cohorts:
        for num in range(1, 4):
            weight = rng.uniform(*wt_rng)
            height = rng.uniform(*ht_rng)
            delta = _breath_delta(
                sample_min,
                baseline_delta=rng.uniform(*base_rng),
                peak_dob=rng.uniform(*peak_rng),
                tmax_min=rng.uniform(*tmax_rng),
                width_min=rng.uniform(*width_rng),
                kel_per_h=rng.uniform(*kel_rng),
                noise=0.4,
                rng=rng,
            )
            subject = f"{prefix}{num:02d}"
            records.extend(
                {"subject_id": subject, "group": label,
                 "time_min": tm, "delta13c": round(float(val), 3),
                 "weight_kg": round(weight, 1), "height_cm": round(height, 1),
                 "dose_mmol_13c": dose_mmol}
                for tm, val in zip(sample_min, delta)
            )
    return pd.DataFrame(records)


def write_demo_csvs(data_dir: str = DATA_DIR) -> Dict[str, str]:
    """Write the three demo CSVs into ``data_dir`` and return their paths.

    Creates ``data_dir`` if it is missing. Then, in order, it writes
    ``oftt_normal.csv``, ``oftt_delayed.csv`` and ``breath_delta13c.csv``,
    each produced by its generator's default arguments and saved without
    the DataFrame index.

    Returns:
        Dict mapping ``"oftt_normal"``, ``"oftt_delayed"`` and
        ``"breath_delta13c"`` to the written file paths, in write order.
    """
    os.makedirs(data_dir, exist_ok=True)
    outputs = (
        ("oftt_normal", "oftt_normal.csv", make_oftt_normal),
        ("oftt_delayed", "oftt_delayed.csv", make_oftt_delayed),
        ("breath_delta13c", "breath_delta13c.csv", make_breath_curves),
    )
    written: Dict[str, str] = {}
    for key, filename, build in outputs:
        frame = build()
        target = os.path.join(data_dir, filename)
        frame.to_csv(target, index=False)
        written[key] = target
    return written


if __name__ == "__main__":
    # Script entry point: regenerate the demo CSVs and report each path.
    for name, path in write_demo_csvs().items():
        print(f"wrote {name}: {path}")
