"""
backend.py — MASHePRO-Kor 진입점.

서브커맨드:
  python3 backend.py schedule --protocol weekly --duration 52
  python3 backend.py power    --n-current 200 --compliance 0.85 --target-n 600
  python3 backend.py export   --format redcap|sdtm|dsmb-kr|sponsor-en|all
  python3 backend.py demo     (전체 시연: 합성데이터 → 통계 → export)
  python3 backend.py gendata  --N 50 --weeks 12 (합성데이터 재생성)
"""
from __future__ import annotations
import argparse
import json
import os
import random
import sys
from datetime import date, timedelta

import notifications as N
import power as P
import exports as E
from scales import ALL_SCALES


# Filesystem layout: every path is anchored at this script's directory.
HERE = os.path.dirname(os.path.abspath(__file__))
DATA_DIR = os.path.join(HERE, "data")  # patients.json, dashboard_data.json
EXPORT_DIR = os.path.join(HERE, "exports")  # REDCap / SDTM / report files
PATIENTS_JSON = os.path.join(DATA_DIR, "patients.json")


# -------------------------------------------------------------------
# Synthetic data generation
# -------------------------------------------------------------------
# Candidate trial sites; _gen_patient draws one per patient with rng.choice.
SITES = ["SCH-부천", "서울대병원", "삼성서울", "세브란스", "분당서울대"]
# Chronotype draw pool for rng.choice: "neutral" is listed three times, so it
# is drawn with probability 3/5. Element order is part of seed reproducibility.
CHRONOTYPES = ["lark"] + ["neutral"] * 3 + ["owl"]


def _gen_patient(idx: int, rng: random.Random, weeks: int,
                 start_date: date) -> dict:
    """Generate one synthetic patient with demographics and weekly ePRO entries.

    Every random value is drawn from *rng* in a fixed order, so the same
    seeded generator always reproduces the same patient. Changing the number
    or order of draws changes every dataset generated from an existing seed.

    Args:
        idx: patient number, formatted as ``P{idx:03d}``.
        rng: seeded generator shared by all patients of a dataset.
        weeks: planned number of weekly visits (week indices ``0..weeks-1``).
        start_date: date of week 0; entry for week ``k`` is dated
            ``start_date + k weeks``.

    Returns:
        dict with demographics, baseline PHQ-9, compliance group,
        ``dropout_week`` (``None`` unless the patient stops before the last
        planned week) and the list of recorded entries.
    """
    age = rng.randint(35, 70)
    sex = rng.choice(["M", "F"])
    site = rng.choice(SITES)
    chrono = rng.choice(CHRONOTYPES)
    phq9 = rng.choices([2, 4, 6, 8, 12, 18],
                       weights=[3, 5, 4, 3, 2, 1])[0]
    # Compliance group (HIGH / MID / LOW mix) fixes the weekly probability
    # that the patient submits an entry.
    grp = rng.choices(["HIGH", "MID", "LOW"], weights=[5, 3, 2])[0]
    if grp == "HIGH":
        base_p = rng.uniform(0.85, 0.98)
    elif grp == "MID":
        base_p = rng.uniform(0.55, 0.84)
    else:
        base_p = rng.uniform(0.20, 0.54)
    # Dropout: 25% chance for LOW-compliance patients, 5% otherwise.
    dropout_week = None
    if rng.random() < (0.25 if grp == "LOW" else 0.05):
        candidate = rng.randint(2, max(3, weeks - 1))
        # With weeks <= 3 the draw can land on or after the last planned
        # week, i.e. the patient actually completes follow-up. Recording it
        # anyway produced phantom dropouts (flagged downstream and used as a
        # compliance denominator larger than `weeks`). The draw itself is
        # kept so the RNG stream -- and every later patient -- is unchanged.
        if candidate < weeks:
            dropout_week = candidate
    pid = f"P{idx:03d}"
    entries = []
    for wk in range(weeks):
        if dropout_week is not None and wk >= dropout_week:
            break
        # Submit this week with probability base_p; skipped weeks draw nothing else.
        if rng.random() > base_p:
            continue
        entry_date = start_date + timedelta(weeks=wk)
        # Entry hour depends on chronotype. All three rng.choice calls run on
        # every entry (the dict is fully built before indexing); this is kept
        # because dropping the unused draws would change seeded datasets.
        hr = {"lark": rng.choice([7, 8, 9, 11]),
              "neutral": rng.choice([9, 12, 13, 19, 20]),
              "owl": rng.choice([11, 14, 15, 21, 22])}[chrono]
        ans = _gen_answers(rng, wk, weeks, grp)
        entries.append({
            "week": wk,
            "date": entry_date.isoformat(),
            "hour": hr,
            "answers": ans,
        })
    return {
        "patient_id": pid,
        "site": site,
        "age": age,
        "sex": sex,
        "chronotype": chrono,
        "phq9_baseline": phq9,
        "compliance_group": grp,
        "dropout_week": dropout_week,
        "entries": entries,
    }


def _gen_answers(rng, wk, weeks, grp):
    """간단한 답변 생성. 시간 진행에 따라 약간씩 호전."""
    progress = wk / max(weeks - 1, 1)
    # treatment 효과 시뮬레이션: HIGH 그룹은 호전 강함
    eff = {"HIGH": 0.4, "MID": 0.2, "LOW": 0.05}[grp] * progress
    cldq = {f"cldq{i+1}": max(1, min(7, round(rng.gauss(4 + eff * 2, 1))))
            for i in range(8)}
    facit = {f"facit{i+1}": max(0, min(4, round(rng.gauss(2 - eff, 1))))
             for i in range(13)}
    promis = {f"promis{i+1}": max(1, min(5, round(rng.gauss(3 - eff, 1))))
              for i in range(8)}
    itch_now = max(0, min(10, round(rng.gauss(5 - eff * 3, 2))))
    itch_worst = min(10, itch_now + rng.randint(0, 3))
    itch = {"itch_now": itch_now, "itch_worst": itch_worst}
    eq = {
        "eq_mobility": max(1, min(5, round(rng.gauss(2 - eff, 0.8)))),
        "eq_selfcare": max(1, min(5, round(rng.gauss(1.5 - eff, 0.7)))),
        "eq_usual":    max(1, min(5, round(rng.gauss(2 - eff, 0.8)))),
        "eq_pain":     max(1, min(5, round(rng.gauss(2.5 - eff, 0.9)))),
        "eq_anxiety":  max(1, min(5, round(rng.gauss(2 - eff, 0.9)))),
        "eq_vas":      max(0, min(100, round(rng.gauss(60 + eff * 25, 12)))),
    }
    return {
        "CLDQ_NASH_KOR_DEMO": cldq,
        "FACIT_FATIGUE_KOR": facit,
        "PROMIS_FATIGUE_KOR_SF8A": promis,
        "HEPATIC_ITCH_NRS": itch,
        "EQ5D5L_KOR": eq,
    }


def gen_dataset(n: int = 50, weeks: int = 12, seed: int = 20260428) -> dict:
    """Build a reproducible synthetic study dataset.

    Args:
        n: number of patients (ids ``P001`` upward).
        weeks: planned weekly visits per patient.
        seed: seed of the single ``random.Random`` that drives every draw.

    Returns:
        dict with study metadata, the ids of ``ALL_SCALES`` and the patients.
    """
    # NOTE(review): 2026-01-06 is a Tuesday, although an earlier comment
    # called it a Monday. Confirm the intended start day before changing
    # the date, because it shifts every generated entry date.
    start = date(2026, 1, 6)
    generator = random.Random(seed)
    cohort = []
    for idx in range(1, n + 1):
        cohort.append(_gen_patient(idx, generator, weeks, start))
    return {
        "study_id": "MASHePROKOR-001",
        "n_patients": n,
        "weeks": weeks,
        "start_date": start.isoformat(),
        "generated_seed": seed,
        "scales_used": [scale["id"] for scale in ALL_SCALES],
        "patients": cohort,
    }


def save_dataset(ds: dict, path: str = PATIENTS_JSON) -> str:
    """Write *ds* to *path* as indented UTF-8 JSON and return *path*.

    Missing parent directories are created. The JSON is first written to a
    sibling ``<path>.tmp`` file and then moved over *path* with
    ``os.replace``. An interrupted or failed write therefore never leaves a
    truncated dataset behind. This matters because ``_ensure_dataset`` only
    checks that the file exists, so a half-written file would break every
    later load.

    Args:
        ds: JSON-serialisable dataset dict.
        path: destination file; a bare filename is written to the CWD.

    Returns:
        *path*, unchanged.

    Raises:
        TypeError / ValueError: from ``json.dump`` for unserialisable data.
        OSError: for filesystem failures. The temp file is removed on error.
    """
    parent = os.path.dirname(path)
    # os.makedirs("") raises FileNotFoundError, so skip it for bare filenames.
    if parent:
        os.makedirs(parent, exist_ok=True)
    tmp_path = os.fspath(path) + ".tmp"
    try:
        with open(tmp_path, "w", encoding="utf-8") as f:
            json.dump(ds, f, ensure_ascii=False, indent=2)
        os.replace(tmp_path, path)
    except BaseException:
        # Best-effort removal of the partial temp file. The original error is
        # what the caller needs to see, so a cleanup failure is ignored.
        try:
            os.remove(tmp_path)
        except OSError:
            pass
        raise
    return path


def load_dataset(path: str = PATIENTS_JSON) -> dict:
    """Read the UTF-8 JSON dataset stored at *path* and return it."""
    with open(path, encoding="utf-8") as fh:
        data = json.load(fh)
    return data


# -------------------------------------------------------------------
# 통계 집계 (dashboard·report 공용)
# -------------------------------------------------------------------
def compute_stats(ds: dict) -> dict:
    """Aggregate compliance, dropout-risk and PRO summaries over a dataset.

    Shared by the dashboard and the export reports.

    Args:
        ds: dataset dict as produced by ``gen_dataset`` / ``load_dataset``
            (reads ``patients``, ``weeks`` and ``start_date``).

    Returns:
        dict with:
        - enrolment and dropout counts;
        - mean total compliance and mean 4-week compliance
          (via ``N.score_compliance_window``);
        - counts of patients at >=80% and <50% compliance;
        - mean expected follow-up weeks;
        - demo PROMIS T-score mean and itch mean (``"NA"`` when no data);
        - up to 20 HIGH-risk patients, sorted by descending risk score;
        - the per-patient table, a power table and a per-site summary.
    """
    pts = ds["patients"]
    weeks = ds["weeks"]
    n_enrolled = len(pts)
    n_dropout = sum(1 for p in pts if p.get("dropout_week") is not None)
    n_active = n_enrolled - n_dropout

    # Reference "today" for the 4-week compliance window is the end of the
    # planned follow-up. It is the same for every patient, so compute it once
    # (and only when there are patients to score).
    window_today = (date.fromisoformat(ds["start_date"]) + timedelta(weeks=weeks)
                    if pts else None)

    compliances = []
    compliances_4w = []
    promis_t = []
    itch_4w = []
    high_risk = []
    n_80 = 0
    n_below50 = 0
    per_patient = []

    for p in pts:
        entries = p["entries"]
        actual = len(entries)
        dropout_week = p.get("dropout_week")
        # Expected entries: the weeks before dropout, otherwise the full
        # schedule. Floored at 1 to avoid division by zero.
        denom = max(dropout_week if dropout_week is not None else weeks, 1)
        comp = actual / denom
        compliances.append(comp)
        if comp >= 0.80:
            n_80 += 1
        if comp < 0.50:
            n_below50 += 1

        c4w = N.score_compliance_window(entries, weeks_back=4, today=window_today)
        compliances_4w.append(c4w)
        risk = N.dropout_risk(c4w, p.get("phq9_baseline"))
        last_entry = entries[-1] if entries else None
        per_patient.append({
            "patient_id": p["patient_id"],
            "site": p["site"],
            "compliance_total": round(comp, 3),
            "compliance_4w": round(c4w, 3),
            "n_entries": actual,
            "expected": denom,
            "dropout_week": dropout_week,
            "phq9_baseline": p.get("phq9_baseline"),
            "risk_score": risk["score"],
            "risk_level": risk["level"],
            "last_date": last_entry["date"] if last_entry else None,
            "chronotype": p["chronotype"],
        })
        if risk["level"] == "HIGH":
            high_risk.append({
                "patient_id": p["patient_id"], "site": p["site"],
                "compliance_4w": c4w, "risk": risk["score"],
            })

        if last_entry is not None:
            # Demo PROMIS T-score from the most recent entry: a linear rescale
            # of the raw sum (8..40) onto 30..80. Entries without PROMIS
            # answers are skipped. A raw sum of 0 would map to 17.5, below
            # the mapped range, and drag the cohort mean down.
            promis = last_entry["answers"].get("PROMIS_FATIGUE_KOR_SF8A", {})
            if promis:
                raw = sum(promis.values())
                promis_t.append(30 + (raw - 8) * (80 - 30) / (40 - 8))
            # Itch: current-itch NRS over the last four *recorded* entries.
            # This is not necessarily four calendar weeks when weeks were
            # skipped.
            for e in entries[-4:]:
                v = e["answers"].get("HEPATIC_ITCH_NRS", {}).get("itch_now")
                if v is not None:
                    itch_4w.append(v)

    mean_comp = sum(compliances) / len(compliances) if compliances else 0
    mean_comp_4w = sum(compliances_4w) / len(compliances_4w) if compliances_4w else 0
    mean_followup = sum(p["expected"] for p in per_patient) / max(len(per_patient), 1)

    # Power table under default assumptions: control 10%, treatment 25%.
    power_table = P.compliance_scenarios(0.10, 0.25, 0.05, 0.80, 0.10)

    return {
        "n_enrolled": n_enrolled,
        "n_active": n_active,
        "n_dropout": n_dropout,
        "dropout_rate": n_dropout / max(n_enrolled, 1),
        "mean_compliance": mean_comp,
        "mean_compliance_4w": mean_comp_4w,
        "n_compliance_80plus": n_80,
        "n_compliance_below50": n_below50,
        "mean_followup_weeks": mean_followup,
        "promis_t_mean": round(sum(promis_t) / len(promis_t), 1) if promis_t else "NA",
        "itch_mean_4w": round(sum(itch_4w) / len(itch_4w), 2) if itch_4w else "NA",
        "high_risk_patients": sorted(high_risk, key=lambda x: -x["risk"])[:20],
        "per_patient": per_patient,
        "power_table": power_table,
        "by_site": _site_summary(per_patient),
    }


def _site_summary(per_patient):
    sites = {}
    for p in per_patient:
        sites.setdefault(p["site"], []).append(p)
    out = []
    for site, ps in sites.items():
        comps = [x["compliance_total"] for x in ps]
        out.append({
            "site": site,
            "n": len(ps),
            "mean_compliance": round(sum(comps) / len(comps), 3),
            "n_high_risk": sum(1 for x in ps if x["risk_level"] == "HIGH"),
        })
    return sorted(out, key=lambda x: -x["mean_compliance"])


# -------------------------------------------------------------------
# CLI 서브커맨드
# -------------------------------------------------------------------
def cmd_schedule(args):
    """CLI handler for ``schedule``.

    Prints the first ``--limit`` time points as JSON, then the total count.
    """
    start = None
    if args.start:
        start = date.fromisoformat(args.start)
    points = N.generate_schedule(args.protocol, args.duration, start)
    preview = points[:args.limit]
    print(json.dumps(preview, ensure_ascii=False, indent=2))
    print(f"... (총 {len(points)}개 시점)")


def cmd_power(args):
    """CLI handler for ``power``: sample-size and power summary for two proportions.

    Prints, in order:
    - the nominal per-arm N;
    - the compliance/dropout-adjusted N, when ``--compliance`` is given;
    - a compliance scenario table;
    - the estimated power at ``--n-current`` and/or ``--target-n``, when
      those are given and non-zero.
    """
    pc, pt, alpha = args.p_control, args.p_treatment, args.alpha
    nominal = P.sample_size_two_proportions(pc, pt, alpha, args.power)
    print(f"군당 nominal N (alpha={alpha}, power={args.power}): {nominal}")

    if args.compliance is not None:
        per_arm = P.adjust_for_compliance(nominal, args.compliance, args.dropout)
        print(f"컴플라이언스 {args.compliance:.0%}, dropout {args.dropout:.0%} → 군당 {per_arm} (총 {per_arm*2})")

    print("\n다단 시나리오:")
    scenarios = P.compliance_scenarios(pc, pt, alpha, args.power, args.dropout)
    for row in scenarios:
        print(f"  comp {row['compliance']:.0%}: n/arm={row['n_per_arm_required']}, total={row['total_required']}")

    # (output prefix, per-arm N) -- falsy N values are skipped.
    for prefix, n_arm in (("\n현재", args.n_current), ("목표", args.target_n)):
        if n_arm:
            est = P.power_at_n(n_arm, pc, pt, alpha)
            print(f"{prefix} N={n_arm}/arm → estimated power = {est:.3f}")


def cmd_gendata(args):
    """CLI handler for ``gendata``: regenerate and save the synthetic dataset."""
    dataset = gen_dataset(n=args.N, weeks=args.weeks, seed=args.seed)
    saved_to = save_dataset(dataset)
    for line in (f"saved synthetic dataset: {saved_to}",
                 f"  n_patients={dataset['n_patients']}, weeks={dataset['weeks']}"):
        print(line)


def _ensure_dataset():
    """Return the dataset at PATIENTS_JSON, creating a default one first if absent.

    When the file does not exist, ``gen_dataset()`` (default size and seed)
    is generated and saved. The return value is always what is read back
    from disk.
    """
    already_there = os.path.exists(PATIENTS_JSON)
    if not already_there:
        fresh = gen_dataset()
        save_dataset(fresh)
    return load_dataset()


def cmd_export(args):
    """CLI handler for ``export``: write the requested artefact(s) into EXPORT_DIR.

    ``args.format`` is one of ``redcap``, ``sdtm``, ``dsmb-kr``,
    ``sponsor-en`` or ``all``. The dataset is always loaded (and generated
    on first use). The flattened SDTM rows and the aggregate statistics are
    computed only when a selected export needs them. Each file's content is
    built before the file is opened, so a failing builder does not leave an
    empty or truncated output file.

    Raises:
        ValueError: if ``args.format`` is not a known export format.
    """
    fmt = args.format
    # (format key, output file name, console prefix), in write order.
    targets = (
        ("redcap", "redcap_dictionary.csv", "  REDCap dict → "),
        ("sdtm", "sdtm_qs.csv", "  SDTM-QS    → "),
        ("dsmb-kr", "dsmb_kr_report.md", "  DSMB-KR    → "),
        ("sponsor-en", "sponsor_en_report.md", "  Sponsor-EN → "),
    )
    selected = [t for t in targets if fmt in (t[0], "all")]
    if not selected:
        # argparse restricts --format, but programmatic callers (e.g. cmd_demo)
        # build the Namespace by hand; fail loudly instead of writing nothing.
        raise ValueError(f"unknown export format: {fmt!r}")
    keys = {t[0] for t in selected}

    ds = _ensure_dataset()
    stats = compute_stats(ds) if keys & {"dsmb-kr", "sponsor-en"} else None
    # SDTM needs one flat row per (patient, entry).
    flat = None
    if "sdtm" in keys:
        flat = [
            {
                "patient_id": p["patient_id"],
                "week": e["week"],
                "date": e["date"],
                "answers": e["answers"],
            }
            for p in ds["patients"]
            for e in p["entries"]
        ]

    builders = {
        "redcap": lambda: E.build_redcap_dict(),
        "sdtm": lambda: E.build_sdtm_qs(flat),
        "dsmb-kr": lambda: E.build_dsmb_kr_report(stats),
        "sponsor-en": lambda: E.build_sponsor_en_report(stats),
    }

    os.makedirs(EXPORT_DIR, exist_ok=True)
    for key, filename, prefix in selected:
        out_path = os.path.join(EXPORT_DIR, filename)
        content = builders[key]()
        with open(out_path, "w", encoding="utf-8") as f:
            f.write(content)
        print(f"{prefix}{out_path}")


def cmd_demo(args):
    """CLI handler for ``demo``: run the whole pipeline end to end.

    Steps:
    1. Regenerate the synthetic dataset (N=50, 12 weeks, default seed) and
       save it.
    2. Build a 12-week ``weekly+milestone`` schedule.
    3. Compute compliance statistics.
    4. Print sample sizes for a 10% vs 25% scenario.
    5. Write all exports.

    Finally, ``dashboard_data.json`` is written under DATA_DIR. ``args`` is
    accepted for the CLI dispatch signature but is not read.
    """
    rule = "=" * 60
    for line in (rule, "MASHePRO-Kor end-to-end DEMO", rule):
        print(line)

    # Step 1: synthetic data.
    print("\n[1/5] 합성 데이터 생성 (N=50, 12주)")
    dataset = gen_dataset(50, 12)
    save_dataset(dataset)
    print(f"  saved → {PATIENTS_JSON}")

    # Step 2: schedule.
    print("\n[2/5] 매주 일정 생성 (12주)")
    schedule = N.generate_schedule("weekly+milestone", 12)
    first_weeks = [point["week"] for point in schedule[:3]]
    print(f"  생성된 시점 수: {len(schedule)}, 첫 3개: {first_weeks}")

    # Step 3: statistics.
    print("\n[3/5] 컴플라이언스 통계")
    stats = compute_stats(dataset)
    print(f"  사이트 평균 입력률: {stats['mean_compliance']:.1%}")
    print(f"  ≥80% 환자: {stats['n_compliance_80plus']}, <50%: {stats['n_compliance_below50']}")
    print(f"  HIGH risk: {len(stats['high_risk_patients'])}명, dropout: {stats['n_dropout']}명")

    # Step 4: sample-size sanity check.
    print("\n[4/5] 검정력 시뮬레이션 (control 10% vs treatment 25%)")
    nominal = P.sample_size_two_proportions(0.10, 0.25, 0.05, 0.80)
    print(f"  nominal n/arm = {nominal}")
    for rate in (0.90, 0.80, 0.70):
        required = P.adjust_for_compliance(nominal, rate, 0.10)
        print(f"  compliance {rate:.0%} → required n/arm = {required}")

    # Step 5: every export format.
    print("\n[5/5] Export 4종 생성")
    cmd_export(argparse.Namespace(format="all"))

    # Dashboard payload: per-patient rows are split out of the summary stats.
    summary = dict(stats)
    summary.pop("per_patient", None)
    payload = {
        "generated_at": date.today().isoformat(),
        "study_id": dataset["study_id"],
        "stats": summary,
        "per_patient": stats["per_patient"],
    }
    out_path = os.path.join(DATA_DIR, "dashboard_data.json")
    with open(out_path, "w", encoding="utf-8") as fh:
        json.dump(payload, fh, ensure_ascii=False, indent=2)
    print(f"\n  dashboard data → {out_path}")
    print("\nDEMO 완료. dashboard.html 을 브라우저로 열어 확인하세요.")


def main(argv=None):
    """Build the CLI parser, parse *argv* and run the selected subcommand.

    Args:
        argv: argument list; ``None`` lets argparse read ``sys.argv[1:]``.

    Returns:
        Whatever the subcommand handler returns.
    """
    parser = argparse.ArgumentParser(
        prog="mashepro-kor",
        description="MASHePRO-Kor: 한국어 MASH ePRO + 컴플라이언스 추적 MVP",
    )
    commands = parser.add_subparsers(dest="cmd", required=True)

    # schedule: ePRO time-point generation.
    schedule_cmd = commands.add_parser("schedule", help="ePRO 일정 생성")
    schedule_cmd.add_argument("--protocol", choices=list(N.PROTOCOLS.keys()),
                              default="weekly+milestone")
    schedule_cmd.add_argument("--duration", type=int, default=52, help="총 주차")
    schedule_cmd.add_argument("--start", type=str, default=None, help="YYYY-MM-DD")
    schedule_cmd.add_argument("--limit", type=int, default=8)
    schedule_cmd.set_defaults(func=cmd_schedule)

    # power: sample size / power calculation.
    power_cmd = commands.add_parser("power", help="검정력 시뮬레이션")
    for flag, default in (("--p-control", 0.10), ("--p-treatment", 0.25),
                          ("--alpha", 0.05), ("--power", 0.80),
                          ("--dropout", 0.10)):
        power_cmd.add_argument(flag, type=float, default=default)
    power_cmd.add_argument("--compliance", type=float, default=None)
    for flag in ("--n-current", "--target-n"):
        power_cmd.add_argument(flag, type=int, default=None)
    power_cmd.set_defaults(func=cmd_power)

    # gendata: synthetic dataset regeneration.
    gen_cmd = commands.add_parser("gendata", help="합성 데이터 생성")
    gen_cmd.add_argument("--N", type=int, default=50)
    gen_cmd.add_argument("--weeks", type=int, default=12)
    gen_cmd.add_argument("--seed", type=int, default=20260428)
    gen_cmd.set_defaults(func=cmd_gendata)

    # export: REDCap / SDTM / report files.
    export_cmd = commands.add_parser("export", help="REDCap·SDTM·DSMB·Sponsor 보고")
    export_cmd.add_argument(
        "--format",
        choices=["redcap", "sdtm", "dsmb-kr", "sponsor-en", "all"],
        default="all",
    )
    export_cmd.set_defaults(func=cmd_export)

    # demo: end-to-end run.
    demo_cmd = commands.add_parser("demo", help="end-to-end 시연")
    demo_cmd.set_defaults(func=cmd_demo)

    parsed = parser.parse_args(argv)
    handler = parsed.func
    return handler(parsed)


if __name__ == "__main__":
    # Handlers return None on success; map any falsy result to exit status 0.
    status = main()
    sys.exit(status if status else 0)
