"""
GLP1DiscontRebound-Kor 합성 데이터 생성기
- 재현 가능 (seed 고정)
- STEP-1 extension / SURMOUNT-4 rebound 패턴 반영
- 외부 네트워크 호출 0
"""
import os
import numpy as np
import pandas as pd

# Seed handed to np.random.default_rng in main(), so every run draws the same stream.
SEED = 20260520
N_PATIENTS = 450  # requirement: at least 400 patients
# Directory containing this script; main() writes its CSV outputs here.
OUT_DIR = os.path.dirname(os.path.abspath(__file__))


def gen_cohort(rng: np.random.Generator, n_patients: "int | None" = None) -> pd.DataFrame:
    """Generate the synthetic patient-level cohort table.

    Each row describes one patient: demographics, baseline anthropometrics and
    labs, drug and dose, discontinuation / restart status, on-treatment nadir,
    the 52-week post-discontinuation rebound, adverse-event flags and restart
    outcomes.

    Random numbers are drawn in a fixed order per patient, so the output is
    fully determined by the generator state and the patient count. Do not
    reorder or remove draws (including the ones whose value is unused) unless a
    change to the seeded output is intended.

    NOTE(review): the post-discontinuation columns
    (``weight_at_52w_post_disc_kg``, ``hba1c_at_52w_post_disc``) are populated
    for ongoing patients too (with a rebound fraction of 0 for weight);
    consumers should filter on ``discontinued``.

    Args:
        rng: NumPy random generator; its state is advanced in place.
        n_patients: Number of patients to generate. When omitted, the
            module-level ``N_PATIENTS`` is used.

    Returns:
        DataFrame with one row per patient; ``patient_id`` runs from ``P0001``
        upward.

    Raises:
        ValueError: If ``n_patients`` is smaller than 1.
    """
    if n_patients is None:
        n_patients = N_PATIENTS
    if n_patients < 1:
        raise ValueError(f"n_patients must be >= 1, got {n_patients}")

    drugs = ["semaglutide", "liraglutide", "tirzepatide", "orforglipron"]
    drug_p = [0.40, 0.10, 0.40, 0.10]
    reasons = [
        "AE",           # adverse event (GI etc.)
        "cost",         # cost
        "shortage",     # supply shortage
        "goal_met",     # treatment goal reached
        "pregnancy",    # pregnancy
        "CV_event",     # cardiovascular event
        "other",        # anything else
    ]
    reason_p = [0.28, 0.22, 0.10, 0.18, 0.03, 0.05, 0.14]
    # Probability of restarting the drug, by discontinuation reason.
    restart_p_by_reason = {
        "AE": 0.18,
        "cost": 0.42,
        "shortage": 0.55,
        "goal_met": 0.25,
        "pregnancy": 0.30,
        "CV_event": 0.05,
        "other": 0.20,
    }
    # Per drug: (starting dose, candidate maximum doses) in mg.
    dose_schedule = {
        "semaglutide": (0.25, [1.0, 1.7, 2.4]),
        "liraglutide": (0.6, [1.8, 3.0]),
        "tirzepatide": (2.5, [5, 10, 15]),
        "orforglipron": (3.0, [12, 24, 36]),
    }
    higher_gi_risk_drugs = ("semaglutide", "tirzepatide", "orforglipron")

    rows = []
    for i in range(n_patients):
        pid = f"P{i+1:04d}"
        age = int(np.clip(rng.normal(52, 11), 19, 84))
        sex = rng.choice(["F", "M"], p=[0.62, 0.38])
        drug = rng.choice(drugs, p=drug_p)

        # Baseline weight / BMI / waist circumference.
        baseline_wt = float(np.clip(rng.normal(94, 17), 60, 165))
        height_m = float(np.clip(rng.normal(1.65, 0.08), 1.45, 1.92))
        baseline_bmi = baseline_wt / (height_m ** 2)
        baseline_wc = float(np.clip(baseline_wt * 1.05 + rng.normal(0, 6), 70, 160))

        # Baseline metabolic markers.
        baseline_hba1c = float(np.clip(rng.normal(7.4, 1.2), 5.4, 11.2))
        baseline_sbp = float(np.clip(rng.normal(133, 13), 100, 175))
        baseline_dbp = float(np.clip(rng.normal(82, 9), 60, 110))
        baseline_ldl = float(np.clip(rng.normal(118, 28), 50, 220))
        baseline_hdl = float(np.clip(rng.normal(48, 11), 25, 90))
        baseline_tg = float(np.clip(rng.normal(165, 60), 50, 480))
        baseline_alt = float(np.clip(rng.normal(32, 14), 8, 120))
        baseline_ast = float(np.clip(rng.normal(28, 11), 8, 110))

        # Drug start offset, drawn from [120, 900). The value is not written to
        # the output, but the draw is kept so the random stream (and therefore
        # every later column) stays identical for a given seed.
        rng.integers(120, 900)
        # Treatment duration in weeks.
        duration_w = int(np.clip(rng.normal(58, 22), 8, 156))

        # Still on treatment, or discontinued?
        discontinued = rng.random() < 0.62
        if discontinued:
            reason = rng.choice(reasons, p=reason_p)
            discont_w = duration_w
            restarted = rng.random() < restart_p_by_reason[reason]
            # holiday: restarted after a gap shorter than 12 weeks;
            # permanent: not restarted, gap of at least 12 weeks.
            if restarted:
                gap_w = int(np.clip(rng.exponential(10), 1, 80))
            else:
                gap_w = int(np.clip(rng.exponential(40), 12, 260))
            holiday = restarted and gap_w < 12
        else:
            reason = "ongoing"
            discont_w = None
            restarted = False
            gap_w = None
            holiday = False

        # Dose titration: fixed starting dose, randomly chosen maximum dose.
        start_dose, max_dose_options = dose_schedule[drug]
        max_dose = float(rng.choice(max_dose_options))

        # On-treatment peak weight loss (%). All four distributions are sampled
        # for every patient (the dict literal evaluates each entry) and only the
        # one for the patient's drug is used; this keeps the draw count fixed.
        peak_loss_pct = {
            "semaglutide": rng.normal(14.5, 4.5),
            "liraglutide": rng.normal(7.5, 3.0),
            "tirzepatide": rng.normal(20.0, 5.0),
            "orforglipron": rng.normal(13.0, 4.5),
        }[drug]
        peak_loss_pct = float(np.clip(peak_loss_pct, 1.0, 32.0))
        nadir_wt = baseline_wt * (1 - peak_loss_pct / 100.0)
        nadir_bmi = nadir_wt / (height_m ** 2)
        nadir_hba1c = float(np.clip(baseline_hba1c - rng.normal(1.5, 0.7), 4.8, baseline_hba1c))

        # Rebound: fraction of the lost weight regained by 52 weeks after
        # discontinuation (about 2/3, in line with STEP-1 extension /
        # SURMOUNT-4). Holiday patients get a smaller rebound; ongoing patients
        # get none. `holiday` implies `restarted`, so "not holiday" covers both
        # never-restarted and late-restarted patients.
        if not discontinued:
            rebound_frac_52w = 0.0
        elif holiday:
            rebound_frac_52w = float(np.clip(rng.normal(0.30, 0.15), 0.0, 0.90))
        else:
            rebound_frac_52w = float(np.clip(rng.normal(0.67, 0.18), 0.0, 1.15))
        wt_at_52w_post = nadir_wt + (baseline_wt - nadir_wt) * rebound_frac_52w
        # HbA1c follows the weight rebound fraction plus a little noise.
        hba1c_rebound_frac = float(np.clip(rebound_frac_52w + rng.normal(0, 0.08), 0, 1.1))
        hba1c_at_52w_post = nadir_hba1c + (baseline_hba1c - nadir_hba1c) * hba1c_rebound_frac

        # Adverse events.
        gi_ae = rng.random() < (0.55 if drug in higher_gi_risk_drugs else 0.45)
        pancreatitis = rng.random() < 0.012
        gallstone = rng.random() < 0.028
        sarcopenia_flag = rng.random() < (0.18 if peak_loss_pct > 15 else 0.06)
        thyroid_c_cell = rng.random() < 0.004
        # Short-circuit is intentional: no draw when the CV event was the
        # discontinuation reason.
        cv_event = (reason == "CV_event") or (rng.random() < 0.020)

        # Conversion to surgical / endoscopic therapy after a large rebound.
        if discontinued and rebound_frac_52w > 0.55 and rng.random() < 0.10:
            surgical_conversion = rng.choice(["bariatric_surgery", "endoscopic"], p=[0.65, 0.35])
        else:
            surgical_conversion = "none"

        # GI adverse-event recurrence and extra weight loss after a restart.
        if restarted:
            restart_gi_ae = rng.random() < 0.40
            restart_additional_loss_pct = float(np.clip(rng.normal(4.5, 2.5), -2.0, 12.0))
        else:
            restart_gi_ae = False
            restart_additional_loss_pct = None

        rows.append({
            "patient_id": pid,
            "age": age,
            "sex": sex,
            "height_m": round(height_m, 2),
            "drug": drug,
            "start_dose_mg": start_dose,
            "max_dose_mg": max_dose,
            "duration_weeks": duration_w,
            "discontinued": discontinued,
            "discontinuation_reason": reason,
            "discontinuation_week": discont_w,
            "restarted": restarted,
            "restart_gap_weeks": gap_w,
            "holiday": holiday,
            "baseline_weight_kg": round(baseline_wt, 1),
            "baseline_bmi": round(baseline_bmi, 1),
            "baseline_waist_cm": round(baseline_wc, 1),
            "baseline_hba1c": round(baseline_hba1c, 2),
            "baseline_sbp": round(baseline_sbp, 0),
            "baseline_dbp": round(baseline_dbp, 0),
            "baseline_ldl": round(baseline_ldl, 0),
            "baseline_hdl": round(baseline_hdl, 0),
            "baseline_tg": round(baseline_tg, 0),
            "baseline_alt": round(baseline_alt, 0),
            "baseline_ast": round(baseline_ast, 0),
            "peak_weight_loss_pct": round(peak_loss_pct, 2),
            "nadir_weight_kg": round(nadir_wt, 1),
            "nadir_bmi": round(nadir_bmi, 1),
            "nadir_hba1c": round(nadir_hba1c, 2),
            "rebound_frac_52w": round(rebound_frac_52w, 3),
            "weight_at_52w_post_disc_kg": round(wt_at_52w_post, 1),
            "hba1c_at_52w_post_disc": round(hba1c_at_52w_post, 2),
            "gi_ae": gi_ae,
            "pancreatitis": pancreatitis,
            "gallstone": gallstone,
            "sarcopenia_flag": sarcopenia_flag,
            "thyroid_c_cell": thyroid_c_cell,
            "cv_event": cv_event,
            "surgical_conversion": surgical_conversion,
            "restart_gi_ae": restart_gi_ae,
            "restart_additional_loss_pct": restart_additional_loss_pct,
        })

    return pd.DataFrame(rows)


def gen_longitudinal(cohort: pd.DataFrame, rng: np.random.Generator) -> pd.DataFrame:
    """Build per-patient trajectories aligned so that discontinuation is week 0.

    Only rows of ``cohort`` whose ``discontinued`` flag is set are used. Each
    such patient gets one row every 4 weeks from week -52 to week +104.

    * Weeks <= 0 (on treatment): every marker moves linearly from its baseline
      at week -52 to its nadir / treated value at week 0. The same 52-week ramp
      is used for every patient.
    * Weeks > 0 (after discontinuation): markers drift back along a saturating
      ``1 - exp(-1.2 * t)`` curve (``t`` in years) scaled by the patient's
      ``rebound_frac_52w``. For holiday patients the curve is halved and, after
      week 12, reduced further to mimic the effect of restarting.
    * Gaussian noise is added to every marker at every time point.

    NOTE(review): the curve equals ``1 - exp(-1.2)`` (about 0.70) at week 52, so
    the modelled week-52 weight recovers roughly 70% of ``rebound_frac_52w``
    rather than all of it, and therefore does not match a 52-week weight
    computed directly from ``rebound_frac_52w``. Left unchanged here because
    fixing it alters the generated data; confirm the intended behaviour.

    Args:
        cohort: Patient table providing ``patient_id``, ``discontinued``,
            ``holiday``, ``rebound_frac_52w`` and the baseline / nadir columns
            read below.
        rng: NumPy random generator used for the measurement noise; its state
            is advanced in place.

    Returns:
        Long-format DataFrame with columns ``patient_id``, ``week_from_disc``,
        ``weight_kg``, ``hba1c``, ``sbp``, ``ldl``, ``hdl`` and ``tg``. The
        columns are present even when no patient is discontinued.
    """
    out_columns = [
        "patient_id", "week_from_disc", "weight_kg", "hba1c",
        "sbp", "ldl", "hdl", "tg",
    ]
    in_columns = [
        "patient_id", "baseline_weight_kg", "nadir_weight_kg",
        "baseline_hba1c", "nadir_hba1c", "baseline_sbp", "baseline_ldl",
        "baseline_hdl", "baseline_tg", "rebound_frac_52w", "holiday",
    ]
    weeks = list(range(-52, 105, 4))

    # The curve factors depend only on the week (and the holiday flag), so
    # compute them once instead of once per patient.
    pre_frac = {}    # week <= 0 -> fraction of the way from baseline to nadir (0..1)
    post_shape = {}  # week > 0  -> (regular shape, holiday shape)
    for w in weeks:
        if w <= 0:
            pre_frac[w] = max(0.0, min(1.0, (w + 52) / 52.0))
            continue
        t = w / 52.0  # years since discontinuation, 0..2
        regular = 1 - np.exp(-1.2 * t)  # saturating, mean-reverting curve
        # Short holiday: slower recovery, then a decline again after restart.
        damped = regular * 0.5
        if w > 12:
            damped = max(0.0, damped - 0.3 * (1 - np.exp(-0.8 * (t - 12 / 52))))
        post_shape[w] = (regular, damped)

    rows = []
    sub = cohort.loc[cohort["discontinued"], in_columns]
    for (pid, bw, nw, bw_a1c, nw_a1c, bw_sbp, bw_ldl, bw_hdl, bw_tg,
         rebound, holiday) in sub.itertuples(index=False, name=None):
        for w in weeks:
            if w <= 0:
                # On treatment: linear move towards the nadir reached at week 0.
                f = pre_frac[w]
                wt = bw - (bw - nw) * f
                a1c = bw_a1c - (bw_a1c - nw_a1c) * f
                sbp = bw_sbp - 5 * f
                ldl = bw_ldl - 8 * f
                hdl = bw_hdl + 2 * f
                tg = bw_tg - 25 * f
            else:
                # After discontinuation: rebound along the saturating curve.
                regular, damped = post_shape[w]
                shape = damped if holiday else regular
                wt = nw + (bw - nw) * rebound * shape
                a1c = nw_a1c + (bw_a1c - nw_a1c) * rebound * shape
                sbp = bw_sbp - 5 + 7 * rebound * shape
                ldl = bw_ldl - 8 + 12 * rebound * shape
                hdl = bw_hdl + 2 - 3 * rebound * shape
                tg = bw_tg - 25 + 35 * rebound * shape

            # Measurement noise; the draw order is part of the seeded output.
            wt += rng.normal(0, 0.6)
            a1c += rng.normal(0, 0.08)
            sbp += rng.normal(0, 3)
            ldl += rng.normal(0, 4)
            hdl += rng.normal(0, 2)
            tg += rng.normal(0, 10)

            rows.append({
                "patient_id": pid,
                "week_from_disc": w,
                "weight_kg": round(float(wt), 2),
                "hba1c": round(float(a1c), 2),
                "sbp": round(float(sbp), 0),
                "ldl": round(float(ldl), 0),
                "hdl": round(float(hdl), 0),
                "tg": round(float(tg), 0),
            })
    # Explicit columns keep the schema (and the CSV header) stable even when
    # there are no discontinued patients.
    return pd.DataFrame(rows, columns=out_columns)


def main():
    """Generate both synthetic tables from the fixed seed and write them as CSV.

    One generator seeded with ``SEED`` feeds the cohort table first and the
    longitudinal table second. Both files are written into ``OUT_DIR`` before a
    one-line summary (shape and path) is printed for each.
    """
    generator = np.random.default_rng(SEED)
    cohort_df = gen_cohort(generator)
    longitudinal_df = gen_longitudinal(cohort_df, generator)

    cohort_csv = os.path.join(OUT_DIR, "synthetic_glp1_cohort.csv")
    longitudinal_csv = os.path.join(OUT_DIR, "synthetic_longitudinal.csv")

    # Write everything first so the summary only appears once both files exist.
    for frame, destination in (
        (cohort_df, cohort_csv),
        (longitudinal_df, longitudinal_csv),
    ):
        frame.to_csv(destination, index=False)

    print(f"cohort: {cohort_df.shape} -> {cohort_csv}")
    print(f"longitudinal: {longitudinal_df.shape} -> {longitudinal_csv}")


if __name__ == "__main__":
    main()
