"""
ObesityEnrollFunnelOps-Kor — 합성 데모 데이터 생성 스크립트
================================================================
비만 RCT 진행 중 등록(enrollment) funnel 운영 분석용 합성 데이터를 생성한다.
모든 데이터는 합성(synthetic)이며 실제 환자/시험 데이터가 아니다.

생성 CSV:
  - enrollment_funnel.csv : site x week x funnel 단계별 카운트 (long format)
  - screen_fail_reasons.csv: screen-fail 사유 코드별 카운트 (site별)
  - site_enrollment.csv    : site별 메타(목표 등록수, 활성 주차, 누적 등록)
  - demographics.csv       : 등록된 피험자 인구통계(성/연령/BMI/인종)
  - retention_visits.csv   : 방문(visit)별 잔류/이탈 추적

실행:
  python3 generate_demo_data.py
"""
import os
import numpy as np
import pandas as pd

# Module-wide random generator with a fixed seed so every run reproduces the
# same synthetic data.
RNG = np.random.default_rng(20260531)
# Directory containing this script; all CSV outputs are written here.
HERE = os.path.dirname(os.path.abspath(__file__))

# Six-stage funnel (based on the CONSORT enrollment flow).
FUNNEL_STAGES = ["referral", "prescreen", "consent", "screen", "randomize", "retain"]

# Typical stage-to-stage pass rates (synthetic assumptions based on the
# obesity RCT literature). Keys are "<from stage>-><to stage>".
STAGE_PASS = {
    "referral->prescreen": 0.70,
    "prescreen->consent": 0.55,
    "consent->screen": 0.85,
    "screen->randomize": 0.60,   # key interval where screen failures occur
    "randomize->retain": 0.82,   # interval where dropout occurs
}

# Screen-fail reason taxonomy (typical for obesity trials), including an
# avoidable / unavoidable classification.
SCREEN_FAIL_REASONS = [
    # code, label, weight, avoidable (whether the failure is avoidable)
    ("BMI_BELOW", "BMI 하한 미달", 0.12, True),
    ("BMI_ABOVE", "BMI 상한 초과", 0.06, True),
    ("HBA1C_OOR", "HbA1c 범위 이탈", 0.18, False),
    ("COMORBID", "동반질환 배제기준", 0.16, False),
    ("PROHIB_MED", "금기약물 복용", 0.14, True),
    ("LAB_ABN", "검사실 수치 이상", 0.11, False),
    ("CONSENT_WD", "동의 철회", 0.10, True),
    ("NO_SHOW", "스크리닝 방문 불참", 0.08, True),
    ("OTHER", "기타", 0.05, False),
]

# One tuple per site: (site id, site name, target randomized count,
# speed multiplier applied to the weekly referral inflow).
SITES = [
    ("S01", "서울중앙", 60, 1.25),
    ("S02", "부천순천향", 45, 1.00),
    ("S03", "대구메디", 50, 0.80),
    ("S04", "광주제일", 35, 0.55),
    ("S05", "부산해운대", 40, 0.95),
    ("S06", "대전한밭", 30, 0.70),
]
# Number of simulated enrollment weeks per site.
N_WEEKS = 16
TARGET_TOTAL = sum(s[2] for s in SITES)  # target total number of randomizations


def gen_enrollment_funnel():
    """Simulate weekly funnel counts for every site, in long format.

    For each site and week, the referral count is drawn from a Poisson
    distribution whose mean scales with the site's target and speed
    multiplier (with a floor of 4). Each later stage is a binomial thinning
    of the stage before it, using the matching ``STAGE_PASS`` rate.

    Returns:
        pandas.DataFrame with one row per (site, week, stage) and the columns
        ``site_id``, ``site_name``, ``week``, ``stage``, ``count``.
    """
    # Consecutive (upstream, downstream) stage pairs, in funnel order.
    transitions = list(zip(FUNNEL_STAGES, FUNNEL_STAGES[1:]))

    records = []
    for site_id, site_name, target_n, speed_factor in SITES:
        # Expected referrals per week for this site.
        weekly_referral_mean = max(
            4, int(round(target_n / N_WEEKS * 2.4 * speed_factor))
        )
        for week in range(1, N_WEEKS + 1):
            stage_counts = {
                FUNNEL_STAGES[0]: max(0, int(RNG.poisson(weekly_referral_mean))),
            }
            # Draw order matters for reproducibility: one binomial per
            # transition, walking down the funnel.
            for upstream, downstream in transitions:
                pass_rate = STAGE_PASS[f"{upstream}->{downstream}"]
                stage_counts[downstream] = int(
                    RNG.binomial(stage_counts[upstream], pass_rate)
                )
            records.extend(
                {
                    "site_id": site_id,
                    "site_name": site_name,
                    "week": week,
                    "stage": stage,
                    "count": stage_counts[stage],
                }
                for stage in FUNNEL_STAGES
            )
    return pd.DataFrame(records)


def gen_screen_fail(funnel_df):
    """Break each site's screen failures down by reason code.

    A site's total screen-fail count is its summed ``screen`` count minus its
    summed ``randomize`` count (clamped at 0). That total is split across the
    reason codes with a multinomial draw; the base weights from
    ``SCREEN_FAIL_REASONS`` are perturbed per site by a uniform factor in
    [0.7, 1.3) and renormalized so sites differ slightly.

    Args:
        funnel_df: Long-format funnel table with at least the columns
            ``site_id``, ``stage`` and ``count``.

    Returns:
        pandas.DataFrame with the columns ``site_id``, ``site_name``,
        ``reason_code``, ``reason_label``, ``avoidable``, ``count``. Only
        reasons with a non-zero count are emitted, and sites without any
        screen failure contribute no rows. The columns are present even when
        the result is empty, so the written CSV always has a header.
    """
    columns = ["site_id", "site_name", "reason_code", "reason_label",
               "avoidable", "count"]
    rows = []
    codes = [r[0] for r in SCREEN_FAIL_REASONS]
    labels = {r[0]: r[1] for r in SCREEN_FAIL_REASONS}
    weights = np.array([r[2] for r in SCREEN_FAIL_REASONS])
    weights = weights / weights.sum()
    avoidable = {r[0]: r[3] for r in SCREEN_FAIL_REASONS}
    for sid, sname, _target, _speed in SITES:
        sdf = funnel_df[funnel_df.site_id == sid]
        n_screen = sdf[sdf.stage == "screen"]["count"].sum()
        n_rand = sdf[sdf.stage == "randomize"]["count"].sum()
        n_fail = max(0, int(n_screen - n_rand))
        if n_fail == 0:
            continue
        # Give each site a slightly different reason distribution.
        site_w = weights * RNG.uniform(0.7, 1.3, size=len(weights))
        site_w = site_w / site_w.sum()
        draws = RNG.multinomial(n_fail, site_w)
        for code, n in zip(codes, draws):
            if n > 0:
                rows.append(dict(site_id=sid, site_name=sname, reason_code=code,
                                 reason_label=labels[code],
                                 avoidable=("회피가능" if avoidable[code] else "불가피"),
                                 count=int(n)))
    # Explicit columns keep the schema (and the CSV header) when no site has
    # any screen failure; otherwise the frame would have no columns at all.
    return pd.DataFrame(rows, columns=columns)


def gen_site_enrollment(funnel_df):
    """Build the per-site summary table.

    Args:
        funnel_df: Long-format funnel table with the columns ``site_id``,
            ``stage`` and ``count``.

    Returns:
        pandas.DataFrame with one row per entry of ``SITES`` and the columns
        ``site_id``, ``site_name``, ``target_n``, ``active_weeks`` (always
        ``N_WEEKS``) and ``cum_randomized`` (the site's summed ``randomize``
        count, 0 when the site has no such rows).
    """
    randomized_rows = funnel_df[funnel_df["stage"] == "randomize"]
    totals_by_site = randomized_rows.groupby("site_id")["count"].sum()
    records = [
        {
            "site_id": site_id,
            "site_name": site_name,
            "target_n": target_n,
            "active_weeks": N_WEEKS,
            "cum_randomized": int(totals_by_site.get(site_id, 0)),
        }
        for site_id, site_name, target_n, _speed in SITES
    ]
    return pd.DataFrame(records)


def gen_demographics(funnel_df, rng=None):
    """Generate one synthetic demographics row per randomized subject.

    The number of subjects per site equals that site's summed ``randomize``
    count in ``funnel_df``, so per-site head counts here agree with the
    funnel and site-summary tables (previously sites were assigned uniformly
    at random, which made the tables disagree).

    Args:
        funnel_df: Long-format funnel table with the columns ``site_id``,
            ``stage`` and ``count``.
        rng: Optional ``numpy.random.Generator``. Defaults to the module-level
            ``RNG`` so the script's behavior is unchanged when omitted.

    Returns:
        pandas.DataFrame with the columns ``subject_idx`` (1-based),
        ``site_id``, ``sex``, ``age`` (integer, clipped to 19-75), ``bmi``
        (one decimal, clipped to 27-50) and ``race``. Subjects are ordered by
        site, in the order sites first appear in ``funnel_df``.
    """
    rng = RNG if rng is None else rng

    # Randomized head count per site, keeping the funnel's site order.
    randomized = funnel_df[funnel_df["stage"] == "randomize"]
    per_site = randomized.groupby("site_id", sort=False)["count"].sum()
    total_rand = int(per_site.sum())

    # Female share is set above 50% (women over-represented in obesity trials).
    sex = rng.choice(["여성", "남성"], size=total_rand, p=[0.62, 0.38])
    age = np.clip(rng.normal(46, 11, total_rand), 19, 75).round().astype(int)
    bmi = np.clip(rng.normal(34.5, 3.8, total_rand), 27, 50).round(1)
    race = rng.choice(["동아시아", "동남아시아", "기타"], size=total_rand, p=[0.88, 0.08, 0.04])
    # Repeat each site id by its own randomized count instead of sampling
    # sites at random, so the per-site totals match the funnel exactly.
    site_ids = np.repeat(per_site.index.to_numpy(),
                         per_site.to_numpy().astype(int))
    return pd.DataFrame(dict(subject_idx=np.arange(1, total_rand + 1),
                             site_id=site_ids, sex=sex, age=age, bmi=bmi, race=race))


def gen_retention(demo_df):
    """Simulate per-visit retention and summarize it per site.

    Visits happen at weeks 0 (baseline), 4, 12, 24 and 52. Everyone is
    present at baseline. At each later visit every subject is kept with a
    probability of 0.93 (visits up to week 12) or 0.96 (later visits),
    shifted by 0.015 per standard deviation of baseline BMI and clipped to
    [0.80, 0.995]; a subject who drops out stays out.

    Args:
        demo_df: Demographics table with at least the columns ``site_id``
            and ``bmi``.

    Returns:
        pandas.DataFrame with one row per (visit, site) and the columns
        ``site_id``, ``visit_week``, ``n_present``, ``n_enrolled`` and
        ``mean_bmi_present`` (rounded to 2 decimals; ``None`` when nobody
        from the site is present).
    """
    schedule = (0, 4, 12, 24, 52)
    n_subjects = len(demo_df)

    bmi_values = demo_df["bmi"].to_numpy()
    # Standardized BMI; the small epsilon avoids division by zero.
    z_scores = (bmi_values - bmi_values.mean()) / (bmi_values.std() + 1e-9)

    # Site membership does not change between visits, so build masks once.
    site_column = demo_df["site_id"].to_numpy()
    site_info = []
    for site in demo_df["site_id"].unique():
        in_site = site_column == site
        site_info.append((site, in_site, int(in_site.sum())))

    still_in = np.ones(n_subjects, dtype=bool)
    records = []
    for week in schedule:
        if week != 0:
            baseline_keep = 0.96 if week > 12 else 0.93
            keep_prob = np.clip(baseline_keep + 0.015 * z_scores, 0.80, 0.995)
            dropped = RNG.random(n_subjects) > keep_prob
            still_in = still_in & ~dropped
        for site, in_site, n_site in site_info:
            attending = still_in & in_site
            n_attending = int(attending.sum())
            if n_attending:
                mean_bmi = round(float(bmi_values[attending].mean()), 2)
            else:
                mean_bmi = None
            records.append({
                "site_id": site,
                "visit_week": week,
                "n_present": n_attending,
                "n_enrolled": n_site,
                "mean_bmi_present": mean_bmi,
            })
    return pd.DataFrame(records)


def main():
    """Generate every synthetic table, write the CSVs, and print their shapes.

    The tables are generated in a fixed order because they share one random
    generator; the CSV files are written next to this script.
    """
    funnel = gen_enrollment_funnel()
    screen_fail = gen_screen_fail(funnel)
    site_enr = gen_site_enrollment(funnel)
    demo = gen_demographics(funnel)
    retention = gen_retention(demo)

    outputs = [
        ("enrollment_funnel.csv", funnel),
        ("screen_fail_reasons.csv", screen_fail),
        ("site_enrollment.csv", site_enr),
        ("demographics.csv", demo),
        ("retention_visits.csv", retention),
    ]

    # Write everything first, then report.
    for filename, frame in outputs:
        frame.to_csv(os.path.join(HERE, filename), index=False)

    print("생성 완료:")
    for filename, frame in outputs:
        print(f"  {filename}: {frame.shape}")


if __name__ == "__main__":
    main()
