"""
PostLTMASHMetabolicKor 합성 cohort 생성기.

생성 파일:
  - data/synthetic_post_lt_cohort.csv  : n=300 환자 cross-sectional
  - data/synthetic_longitudinal.csv    : 환자별 LT=0 기준 ~5y 추적

재현 가능: SEED=20260520 고정.
의학적 디스클레이머: 본 데이터는 합성 데이터이며 실제 환자 데이터가 아님.
KONOS / KLTF / KASL / AASLD / EASL 분포를 광범위하게 참조.
"""

import os
import numpy as np
import pandas as pd

# Fixed RNG seed so the generated datasets are reproducible run to run.
SEED = 20260520
N_PATIENTS = 300

# Transplant centers (gen_cohort samples them with equal probability).
CENTERS = ["Seoul-A", "Seoul-B", "Busan-C", "Daegu-D", "Incheon-E", "Gwangju-F"]

# LT indication -> sampling probability (mix loosely based on KONOS figures).
_INDICATION_MIX = (
    ("HBV-LC", 0.28),
    ("HBV-HCC", 0.22),
    ("ALC", 0.18),
    ("MASH", 0.12),
    ("CRYPTO", 0.06),
    ("AIH", 0.05),
    ("PBC", 0.05),
    ("HCV", 0.04),
)
INDICATIONS = [label for label, _ in _INDICATION_MIX]
INDICATION_P = [weight for _, weight in _INDICATION_MIX]

# Immunosuppression regimen -> sampling probability.
_REGIMEN_MIX = (
    ("TAC+MMF+Steroid", 0.42),
    ("TAC+MMF", 0.22),
    ("TAC+Sirolimus", 0.12),
    ("CsA+MMF+Steroid", 0.10),
    ("CsA+MMF", 0.08),
    ("TAC mono", 0.06),
)
IS_REGIMENS = [label for label, _ in _REGIMEN_MIX]
IS_P = [weight for _, weight in _REGIMEN_MIX]

# Output directory for the CSVs: the directory that contains this script.
DATA_DIR = os.path.dirname(os.path.abspath(__file__))


def _rng():
    """Return a new NumPy ``Generator`` seeded with the module-level ``SEED``.

    Every call starts a fresh stream, so two calls produce identical draws.
    """
    seeded_generator = np.random.default_rng(seed=SEED)
    return seeded_generator


def gen_cohort(rng: np.random.Generator) -> pd.DataFrame:
    """Generate the cross-sectional synthetic post-LT cohort (n = N_PATIENTS = 300).

    Most columns are independent draws from ``rng``, clipped to plausible
    ranges with ``np.clip``. The dependent columns are:
      * ``followup_months`` - derived from the sampled LT year/month;
      * ``bmi_current`` / ``weight_regain_pct`` - derived from ``bmi_pre_lt``;
      * drug troughs and ``steroid_dose_mg`` - masked by ``is_regimen``;
      * ``mash_recurrence`` - risk from LSM, PDFF, indication and current BMI;
      * ``nodat`` - risk from regimen, steroid dose, current BMI and age,
        forced to 0 for patients with pre-LT diabetes.

    Reproducibility: the output is fully determined by the generator state
    and the exact sequence of ``rng`` calls below. Adding, removing or
    reordering any draw shifts every column drawn after it, so keep the order
    intact if the seeded dataset must stay the same.

    Args:
        rng: NumPy random generator; ``main`` passes the one from ``_rng()``.

    Returns:
        One row per patient. Sentinel conventions:
          * ``tac/csa/siro_trough_ng_mL`` are NaN when the regimen does not
            include that drug; ``steroid_dose_mg`` is 0.0 for steroid-free
            regimens.
          * ``biopsy_stage`` is -1 when no biopsy was performed.
          * ``mash_recur_month`` / ``nodat_month`` are -1 when the event did
            not occur (these columns are float, so the CSV shows -1.0).
    """
    n = N_PATIENTS
    # Sequential IDs PT1000 ... PT(1000 + n - 1).
    pid = [f"PT{1000 + i:04d}" for i in range(n)]
    # Explicit uniform probabilities over the centers.
    center = rng.choice(CENTERS, size=n, p=[1 / len(CENTERS)] * len(CENTERS))
    sex = rng.choice(["M", "F"], size=n, p=[0.68, 0.32])
    age_lt = np.clip(rng.normal(55, 9, n), 22, 76).round(1)
    # LT date: year uniform over 2019-2024 (the upper bound of integers() is
    # exclusive), month uniform 1-12, day fixed to the 15th.
    lt_year = rng.integers(2019, 2025, n)
    lt_month = rng.integers(1, 13, n)
    lt_date = [f"{y}-{m:02d}-15" for y, m in zip(lt_year, lt_month)]
    # Months from the LT month to a May 2026 reference point, jittered by
    # -3..+3 months and clipped to 3-84.
    fu_months = np.clip(
        (2026 - lt_year) * 12 + (5 - lt_month) + rng.integers(-3, 4, n), 3, 84
    )

    indication = rng.choice(INDICATIONS, size=n, p=INDICATION_P)
    is_regimen = rng.choice(IS_REGIMENS, size=n, p=IS_P)

    # Baseline anthropometrics: current BMI = pre-LT BMI + N(2.1, 2.3) gain.
    bmi_pre = np.clip(rng.normal(24.5, 3.6, n), 16, 39).round(1)
    bmi_cur = np.clip(bmi_pre + rng.normal(2.1, 2.3, n), 16, 41).round(1)
    weight_regain_pct = np.round((bmi_cur - bmi_pre) / bmi_pre * 100, 1)

    # Trough levels: each draw is made for all n patients and np.where then
    # masks non-users with NaN, so RNG consumption does not depend on the
    # regimen mix.
    # Tacrolimus trough (target 3-8 ng/mL); the N(6.3, 1.9) draw lets some
    # patients exceed the target.
    tac_trough = np.where(
        np.isin(is_regimen, ["TAC+MMF+Steroid", "TAC+MMF", "TAC+Sirolimus", "TAC mono"]),
        np.clip(rng.normal(6.3, 1.9, n), 1.5, 14.0).round(2),
        np.nan,
    )
    csa_trough = np.where(
        np.isin(is_regimen, ["CsA+MMF+Steroid", "CsA+MMF"]),
        np.clip(rng.normal(110, 32, n), 40, 240).round(0),
        np.nan,
    )
    siro_trough = np.where(
        is_regimen == "TAC+Sirolimus",
        np.clip(rng.normal(7.5, 2.4, n), 2.0, 16.0).round(2),
        np.nan,
    )
    # Steroid dose only for steroid-containing regimens; 0.0 otherwise.
    steroid_dose = np.where(
        np.isin(is_regimen, ["TAC+MMF+Steroid", "CsA+MMF+Steroid"]),
        np.clip(rng.normal(5.0, 2.2, n), 0, 20).round(1),
        0.0,
    )

    # Metabolic outcomes: independent normal draws, clipped and rounded.
    hba1c = np.clip(rng.normal(6.1, 0.95, n), 4.7, 11.2).round(2)
    fpg = np.clip(rng.normal(108, 24, n), 70, 260).round(0)
    sbp = np.clip(rng.normal(131, 14, n), 95, 185).round(0)
    dbp = np.clip(rng.normal(82, 9, n), 55, 110).round(0)
    ldl = np.clip(rng.normal(112, 32, n), 35, 230).round(0)
    hdl = np.clip(rng.normal(48, 12, n), 22, 95).round(0)
    tg = np.clip(rng.normal(168, 78, n), 45, 580).round(0)
    uric = np.clip(rng.normal(6.4, 1.5, n), 2.5, 12.5).round(1)

    # Liver enzymes and MELD score.
    alt = np.clip(rng.normal(34, 22, n), 7, 220).round(0)
    ast = np.clip(rng.normal(30, 18, n), 8, 195).round(0)
    ggt = np.clip(rng.normal(58, 48, n), 10, 420).round(0)
    meld = np.clip(rng.normal(9, 3.5, n), 6, 28).round(0)

    # Non-invasive tests (NIT): VCTE liver stiffness (kPa), MRI-PDFF (%),
    # FIB-4 and ELF score.
    vcte_lsm = np.clip(rng.normal(6.4, 3.8, n), 2.5, 35.0).round(1)
    mri_pdff = np.clip(rng.normal(5.6, 4.6, n), 0.6, 28.0).round(1)
    fib4 = np.clip(rng.normal(1.6, 1.2, n), 0.3, 8.5).round(2)
    elf = np.clip(rng.normal(9.2, 1.1, n), 6.5, 13.5).round(2)

    # Biopsy fibrosis stage (F0-F4), performed in ~42% of patients only;
    # -1 marks "no biopsy". The stage is drawn for everyone, then masked.
    has_biopsy = rng.random(n) < 0.42
    biopsy_stage = np.where(
        has_biopsy,
        rng.choice([0, 1, 2, 3, 4], size=n, p=[0.34, 0.30, 0.18, 0.12, 0.06]),
        -1,
    )

    # MASH recurrence (the original note targets ~22% of the cohort).
    # Additive risk from LSM >= 8 kPa, PDFF >= 5%, MASH as the LT indication
    # and current BMI >= 28. The maximum possible risk is 0.61, so the 0.85
    # cap below never binds.
    # NOTE(review): with the marginals drawn above, the expected risk works
    # out closer to ~0.29 than 0.22 - confirm the intended target.
    mash_recur_risk = (
        0.10
        + 0.18 * (vcte_lsm >= 8).astype(float)
        + 0.16 * (mri_pdff >= 5).astype(float)
        + 0.10 * (indication == "MASH").astype(float)
        + 0.07 * (bmi_cur >= 28).astype(float)
    )
    mash_recurrence = (rng.random(n) < np.clip(mash_recur_risk, 0, 0.85)).astype(int)
    # Recurrence onset in months from LT: ~N(28, 14), clipped to
    # [3, followup_months] so it never exceeds the patient's follow-up.
    mash_recur_month = np.where(
        mash_recurrence == 1,
        np.clip(rng.normal(28, 14, n), 3, fu_months).round(0),
        -1,
    )

    # NODAT/PTDM (new-onset diabetes after transplant). ~18% have pre-LT
    # diabetes and are excluded (nodat forced to 0 below, although their
    # random draw is still consumed). Risk adds terms for any TAC-based
    # regimen, steroid dose >= 5, current BMI >= 28 and age at LT >= 60; the
    # maximum is 0.50, so the 0.7 cap never binds.
    pre_dm = rng.random(n) < 0.18
    nodat_risk = (
        0.05
        + 0.18 * np.isin(is_regimen, ["TAC+MMF+Steroid", "TAC+MMF", "TAC mono", "TAC+Sirolimus"]).astype(float)
        + 0.10 * (steroid_dose >= 5).astype(float)
        + 0.10 * (bmi_cur >= 28).astype(float)
        + 0.07 * (age_lt >= 60).astype(float)
    )
    nodat = np.where(pre_dm, 0, (rng.random(n) < np.clip(nodat_risk, 0, 0.7)).astype(int))
    # NODAT onset ~N(9, 7) months, clipped to [1.5, followup_months], rounded.
    nodat_month = np.where(
        nodat == 1,
        np.clip(rng.normal(9, 7, n), 1.5, fu_months).round(0),
        -1,
    )

    # Whether screening was performed at 1m/3m/6m/1y/annually
    # (independent Bernoulli flags).
    scr_1m = (rng.random(n) < 0.92).astype(int)
    scr_3m = (rng.random(n) < 0.86).astype(int)
    scr_6m = (rng.random(n) < 0.79).astype(int)
    scr_1y = (rng.random(n) < 0.72).astype(int)
    scr_annual = (rng.random(n) < 0.61).astype(int)

    # Post-LT antidiabetic use (drawn independently of NODAT status here).
    use_glp1 = (rng.random(n) < 0.18).astype(int)
    use_sglt2 = (rng.random(n) < 0.14).astype(int)
    use_metformin = (rng.random(n) < 0.31).astype(int)

    # Rejection / CMV infection / death: independent Bernoulli flags, not
    # tied to follow-up length.
    rejection = (rng.random(n) < 0.13).astype(int)
    infection_cmv = (rng.random(n) < 0.17).astype(int)
    death = (rng.random(n) < 0.08).astype(int)

    # Dict insertion order defines the CSV column order.
    df = pd.DataFrame(
        {
            "patient_id": pid,
            "center": center,
            "sex": sex,
            "age_at_lt": age_lt,
            "lt_date": lt_date,
            "followup_months": fu_months,
            "indication": indication,
            "is_regimen": is_regimen,
            "tac_trough_ng_mL": tac_trough,
            "csa_trough_ng_mL": csa_trough,
            "siro_trough_ng_mL": siro_trough,
            "steroid_dose_mg": steroid_dose,
            "bmi_pre_lt": bmi_pre,
            "bmi_current": bmi_cur,
            "weight_regain_pct": weight_regain_pct,
            "hba1c_pct": hba1c,
            "fpg_mgdl": fpg,
            "sbp_mmHg": sbp,
            "dbp_mmHg": dbp,
            "ldl_mgdl": ldl,
            "hdl_mgdl": hdl,
            "tg_mgdl": tg,
            "uric_acid_mgdl": uric,
            "alt_uL": alt,
            "ast_uL": ast,
            "ggt_uL": ggt,
            "meld_score": meld,
            "vcte_lsm_kPa": vcte_lsm,
            "mri_pdff_pct": mri_pdff,
            "fib4": fib4,
            "elf_score": elf,
            "biopsy_stage": biopsy_stage,
            "mash_recurrence": mash_recurrence,
            "mash_recur_month": mash_recur_month,
            "pre_lt_dm": pre_dm.astype(int),
            "nodat": nodat,
            "nodat_month": nodat_month,
            "scr_1m": scr_1m,
            "scr_3m": scr_3m,
            "scr_6m": scr_6m,
            "scr_1y": scr_1y,
            "scr_annual": scr_annual,
            "use_glp1ra": use_glp1,
            "use_sglt2i": use_sglt2,
            "use_metformin": use_metformin,
            "rejection_event": rejection,
            "infection_cmv": infection_cmv,
            "death_event": death,
        }
    )
    return df


def gen_longitudinal(cohort: pd.DataFrame, rng: np.random.Generator) -> pd.DataFrame:
    """Expand the cross-sectional cohort into sparse per-visit longitudinal rows.

    Each patient contributes one row per scheduled visit (months 1, 3, 6, 12,
    24, 36, 48 and 60 after LT) that falls within their ``followup_months``.
    Visit values are the patient's cross-sectional values perturbed by
    Gaussian noise from ``rng`` plus small time-dependent drifts. Non-invasive
    tests (VCTE, MRI-PDFF, FIB-4) are produced only from month 6 onward and
    are NaN at months 1 and 3.

    Args:
        cohort: Output of ``gen_cohort`` (one row per patient).
        rng: Random generator. Draws are consumed in a fixed per-visit order
            (the tacrolimus draw is skipped for patients without a TAC trough),
            so results are reproducible for a given generator state.

    Returns:
        Long-format DataFrame with one row per (patient, visit).
    """
    visit_months = (1, 3, 6, 12, 24, 36, 48, 60)
    # (column, floor, noise SD) for values that are simply "baseline + noise".
    jitter_spec = (
        ("sbp_mmHg", 95, 6),
        ("ldl_mgdl", 35, 14),
        ("tg_mgdl", 40, 28),
        ("alt_uL", 7, 8),
        ("ast_uL", 7, 7),
    )
    records = []
    for _, row in cohort.iterrows():
        horizon = int(row["followup_months"])
        for month in (m for m in visit_months if m <= horizon):
            decay = np.exp(-month / 36.0)
            years = month / 12

            # Tacrolimus trough drifts from ~1.2x toward 0.8x of baseline.
            base_tac = row["tac_trough_ng_mL"]
            if pd.isna(base_tac):
                tac_value = np.nan
            else:
                tac_value = round(max(1.0, base_tac * (0.8 + 0.4 * decay) + rng.normal(0, 0.6)), 2)

            hba1c_value = round(
                max(4.6, row["hba1c_pct"] - 0.15 * decay + rng.normal(0, 0.25) + 0.04 * years),
                2,
            )
            # BMI -> kg assuming a fixed 1.65 m height; weight climbs toward
            # the current value by month 60.
            weight_value = round(
                max(40.0, row["bmi_current"] * 1.65 ** 2 * (0.92 + 0.08 * (month / 60)) + rng.normal(0, 1.2)),
                1,
            )
            jittered = {
                column: round(max(floor, row[column] + rng.normal(0, sd)), 0)
                for column, floor, sd in jitter_spec
            }

            # NIT are only measured from month 6 on (missing at 1m / 3m).
            if month >= 6:
                vcte_value, pdff_value, fib4_value = (
                    round(max(2.5, row["vcte_lsm_kPa"] + rng.normal(0, 1.0) + 0.15 * years), 1),
                    round(max(0.5, row["mri_pdff_pct"] + rng.normal(0, 1.0) + 0.10 * years), 1),
                    round(max(0.3, row["fib4"] + rng.normal(0, 0.18) + 0.02 * years), 2),
                )
            else:
                vcte_value = pdff_value = fib4_value = np.nan

            records.append(
                {
                    "patient_id": row["patient_id"],
                    "center": row["center"],
                    "is_regimen": row["is_regimen"],
                    "month_from_lt": month,
                    "tac_trough_ng_mL": tac_value,
                    "hba1c_pct": hba1c_value,
                    "weight_kg": weight_value,
                    **jittered,
                    "vcte_lsm_kPa": vcte_value,
                    "mri_pdff_pct": pdff_value,
                    "fib4": fib4_value,
                }
            )
    return pd.DataFrame(records)


def main():
    """Build both synthetic tables from the fixed seed and write them to DATA_DIR.

    A single generator is shared: the longitudinal table continues the random
    stream left by ``gen_cohort``, so the two files are reproducible as a pair
    only when generated in this order.
    """
    generator = _rng()
    cohort_df = gen_cohort(generator)
    longitudinal_df = gen_longitudinal(cohort_df, generator)

    cohort_csv, longitudinal_csv = (
        os.path.join(DATA_DIR, filename)
        for filename in ("synthetic_post_lt_cohort.csv", "synthetic_longitudinal.csv")
    )
    for frame, target in ((cohort_df, cohort_csv), (longitudinal_df, longitudinal_csv)):
        frame.to_csv(target, index=False)

    print(f"[OK] cohort: {cohort_df.shape} -> {cohort_csv}")
    print(f"[OK] longitudinal: {longitudinal_df.shape} -> {longitudinal_csv}")
    print(f"[seed] {SEED}")


if __name__ == "__main__":
    main()
