#!/usr/bin/env python3
"""CirrDecompUnit-Kor CLI.

Generate synthetic data, run analyses, and produce KPI reports for
MASLD/MASH cirrhosis decompensation inpatient cohorts.

Examples:
    python3 main.py --gen                     # regenerate synthetic CSVs
    python3 main.py --summary                 # KPI summary by decomp type
    python3 main.py --summary --top 8         # KPI summary + top-8 wards by adherence
    python3 main.py --report                  # build docx/PDF in reports/
    python3 main.py --episode E00123          # detail a single episode
"""
from __future__ import annotations

import argparse
import csv
import json
import os
import random
import sys
from typing import Dict, List, Optional

# Directory containing this file. It is prepended to sys.path so that the
# `modules` import below resolves against this directory; that import is
# deliberately placed after this statement (hence its E402 suppression).
HERE = os.path.dirname(os.path.abspath(__file__))
sys.path.insert(0, HERE)

from modules import scoring, aclf, protocols, lt_candidacy, kpi_report, ingest  # noqa: E402

# Input/output locations, anchored to this file's directory rather than the
# current working directory.
DATA_DIR = os.path.join(HERE, "data")
REPORTS_DIR = os.path.join(HERE, "reports")

# Etiology categories sampled by gen_synthetic(). ETIOLOGY_WEIGHTS is matched
# to ETIOLOGIES by position (rng.choices) and sums to 1.0.
ETIOLOGIES = ["MASLD", "HBV", "HCV", "alcohol", "autoimmune", "other", "overlap"]
ETIOLOGY_WEIGHTS = [0.42, 0.22, 0.08, 0.18, 0.04, 0.03, 0.03]  # Korea-like, MASLD-heavy
# Decompensation event types. gen_synthetic() keeps its own positional weight
# list and per-type lab lookup tables keyed by exactly these names.
DECOMP_TYPES = ["ascites", "VB", "HE", "HRS-AKI", "SBP", "ACLF", "other"]
# Ward and PI labels (presumably principal investigators); each episode gets
# one of each, chosen uniformly at random by gen_synthetic().
WARDS = ["HepWard-A", "HepWard-B", "HepWard-C", "MICU-1", "MICU-2"]
PIS = ["PI-Kim", "PI-Lee", "PI-Park", "PI-Choi", "PI-Jung"]


# ---------------------------------------------------------------------------
# Synthetic data generation
# ---------------------------------------------------------------------------
def gen_synthetic(n_patients: int = 380, seed: int = 20260527) -> None:
    """Generate a seeded synthetic cohort and write it as CSV files in DATA_DIR.

    Writes six files: ``patients.csv``, ``episodes.csv``,
    ``lab_trajectory.csv``, ``protocol_adherence.csv``, ``lt_waitlist.csv``
    and ``outcomes.csv``. Each patient gets 1-3 decompensation episodes.
    Severity scores (CLIF-SOFA, ACLF grade, MELD variants, Child-Pugh),
    protocol adherence and LT candidacy are computed by the project
    ``scoring`` / ``aclf`` / ``protocols`` / ``lt_candidacy`` modules.

    Output is deterministic for a given ``(n_patients, seed)``. Every random
    draw comes from a single ``random.Random(seed)`` instance, consumed in a
    fixed order, so reordering or adding draws changes the whole dataset.

    Args:
        n_patients: number of synthetic patients; must be >= 1.
        seed: seed for the private RNG.

    Raises:
        ValueError: if ``n_patients`` < 1. An empty cohort would write
            nothing (``_write_csv`` skips empty row lists) and silently leave
            any previously generated CSVs in place.
    """
    if n_patients < 1:
        raise ValueError(f"n_patients must be >= 1, got {n_patients}")
    os.makedirs(DATA_DIR, exist_ok=True)
    rng = random.Random(seed)

    patients = []
    for i in range(n_patients):
        pid = f"P{i+1:04d}"
        etio = rng.choices(ETIOLOGIES, weights=ETIOLOGY_WEIGHTS, k=1)[0]
        sex = rng.choice(["M", "F"])
        # MASLD patients are sampled older and with higher BMI; age is
        # clamped to [28, 86].
        age = int(rng.gauss(63 if etio == "MASLD" else 58, 11))
        age = max(28, min(86, age))
        bmi = round(rng.gauss(28 if etio == "MASLD" else 24, 4), 1)
        patients.append({
            "patient_id": pid,
            "age": age,
            "sex": sex,
            "bmi": bmi,
            "etiology": etio,
            "prev_decomp_count": rng.choice([0, 0, 0, 1, 1, 2, 3]),
            "prev_tips": int(rng.random() < 0.05),
            "konos_listed_prior": int(rng.random() < 0.07),
        })

    _write_csv(os.path.join(DATA_DIR, "patients.csv"), patients)

    episodes = []
    labs = []
    adherence_rows = []
    lt_rows = []
    outcomes = []
    eid = 0  # global episode counter -> IDs E00001, E00002, ...
    for p in patients:
        n_epi = rng.choices([1, 2, 3], weights=[0.65, 0.25, 0.10])[0]
        for k in range(n_epi):
            eid += 1
            eid_str = f"E{eid:05d}"
            decomp = rng.choices(
                DECOMP_TYPES,
                weights=[0.32, 0.16, 0.18, 0.08, 0.10, 0.10, 0.06],
                k=1,
            )[0]
            ward = rng.choice(WARDS)
            pi = rng.choice(PIS)
            los = max(2, int(rng.gauss(12 if decomp in ("ACLF", "HRS-AKI") else 8, 4)))

            # labs sampled around decomp severity (ACLF gets multi-organ severity)
            bili = max(0.5, rng.gauss(
                {"ACLF": 22, "HE": 6, "HRS-AKI": 10, "VB": 4, "SBP": 5,
                 "ascites": 3, "other": 2}[decomp], 5))
            inr = max(1.0, rng.gauss(
                {"ACLF": 3.2, "HE": 1.7, "HRS-AKI": 2.0, "VB": 1.6, "SBP": 1.7,
                 "ascites": 1.4, "other": 1.3}[decomp], 0.6))
            creat = max(0.6, rng.gauss(
                {"ACLF": 3.8, "HRS-AKI": 4.0, "HE": 1.2, "VB": 1.0, "SBP": 1.4,
                 "ascites": 1.0, "other": 0.9}[decomp], 1.0))
            na = round(rng.gauss(133, 4), 1)
            alb = round(rng.gauss(2.8, 0.5), 2)
            # Untruncated Gaussians occasionally go negative, which is
            # impossible for these labs, so clamp them to small positive
            # floors. max() draws nothing from the RNG, so the random stream
            # (and every other generated value) is unchanged.
            ast = max(5.0, round(rng.gauss(80, 30), 1))
            alt = max(5.0, round(rng.gauss(55, 20), 1))
            plt = max(10, int(rng.gauss(90, 30)))
            map_v = round(rng.gauss(60 if decomp == "ACLF" else 72, 10), 1)
            if decomp == "HE":
                he_grade = rng.choices([1, 2, 3, 4], weights=[0.15, 0.45, 0.3, 0.1])[0]
            elif decomp == "ACLF":
                he_grade = rng.choices([2, 3, 4], weights=[0.3, 0.45, 0.25])[0]
            else:
                he_grade = rng.choices([0, 1, 2, 3, 4], weights=[0.5, 0.25, 0.15, 0.07, 0.03])[0]
            # The short-circuit means the RNG is consumed only when creat >= 3.5.
            dialysis = int(creat >= 3.5 and rng.random() < 0.4)
            ascites_sev = rng.choices(
                ["none", "mild", "moderate-severe"],
                weights=[0.2, 0.3, 0.5] if decomp == "ascites" else [0.4, 0.4, 0.2],
                k=1,
            )[0]

            sub = scoring.clif_sofa(bili, creat, inr, he_grade, map_v,
                                    spo2_fio2=rng.gauss(330, 50))
            grade_info = aclf.aclf_grade(sub, creat, he_grade)
            grade = grade_info["grade"]
            meld_v = scoring.meld(bili, inr, creat, bool(dialysis))
            meld_na_v = scoring.meld_na(bili, inr, creat, na, bool(dialysis))
            meld3_v = scoring.meld_3(bili, inr, creat, na, alb,
                                     p["sex"] == "F", bool(dialysis))
            cp = scoring.child_pugh(bili, alb, inr, ascites_sev, he_grade)

            # protocol adherence (type-specific, then weighted average)
            adh_items = {}
            if decomp == "ascites":
                adh_items["ascites"] = protocols.check_ascites(
                    diuretic=rng.random() < 0.92,
                    albumin_given=rng.random() < 0.78,
                    large_volume_paracentesis=rng.random() < 0.5,
                    refractory=ascites_sev == "moderate-severe",
                    tips_considered=rng.random() < 0.45)
            elif decomp == "VB":
                high_risk = cp["class"] in ("B", "C") and rng.random() < 0.6
                adh_items["VB"] = protocols.check_vb(
                    EBL_done=rng.random() < 0.94,
                    terlipressin=rng.random() < 0.7,
                    octreotide=rng.random() < 0.6,
                    nsbb_started=rng.random() < 0.82,
                    baveno7_pretips=rng.random() < 0.4,
                    high_risk=high_risk)
            elif decomp == "HE":
                adh_items["HE"] = protocols.check_he(
                    west_haven=he_grade,
                    lactulose=rng.random() < 0.97,
                    rifaximin=rng.random() < 0.82,
                    precipitant_identified=rng.random() < 0.75)
            elif decomp == "HRS-AKI":
                adh_items["HRS-AKI"] = protocols.check_hrs_aki(
                    terlipressin=rng.random() < 0.7,
                    albumin=rng.random() < 0.9,
                    vasoactive_alt=rng.random() < 0.5,
                    kdigo_stage=rng.choice([1, 2, 3]))
            elif decomp == "SBP":
                adh_items["SBP"] = protocols.check_sbp(
                    empirical_abx=rng.random() < 0.95,
                    albumin=rng.random() < 0.78,
                    prophylaxis_started=rng.random() < 0.85)
            elif decomp == "ACLF":
                adh_items["ACLF"] = protocols.check_aclf(
                    grade=grade,
                    lt_evaluation=rng.random() < 0.85,
                    plasma_exchange=rng.random() < 0.25,
                    escalation_icu=rng.random() < 0.7)
            # "other" has no protocol checklist; it gets a random adherence.
            adherence = (protocols.overall_adherence(adh_items)
                         if adh_items else round(rng.uniform(0.6, 0.95), 3))

            # LT candidacy
            hcc = rng.random() < 0.07
            hcc_milan = hcc and rng.random() < 0.6
            cand = lt_candidacy.lt_candidate(meld3_v, cp["class"], grade,
                                             p["age"], hcc, hcc_milan)
            lt_listed = int(cand["eligible"] and rng.random() < 0.55)
            lt_tx = int(lt_listed and rng.random() < 0.18)

            # outcomes (probability shaped by ACLF grade & MELD)
            p_mort = (aclf.expected_28d_mortality(grade)
                      + 0.005 * max(0, meld3_v - 20))
            mort = int(rng.random() < min(0.92, p_mort))
            in_hosp_mort = int(mort and rng.random() < 0.7)
            readmit = int(not mort and rng.random() < (0.32 if grade != "no ACLF" else 0.18))

            episodes.append({
                "episode_id": eid_str,
                "patient_id": p["patient_id"],
                "etiology": p["etiology"],
                "decomp_type": decomp,
                "ward": ward,
                "pi": pi,
                "admit_day": k + 1,
                "los_days": los,
                "ascites_severity": ascites_sev,
                "he_grade": he_grade,
                "dialysis": dialysis,
                "bili": round(bili, 2),
                "inr": round(inr, 2),
                "creat": round(creat, 2),
                "na": na,
                "alb": alb,
                "ast": ast,
                "alt": alt,
                "plt": plt,
                "map_mmhg": map_v,
                "meld": meld_v,
                "meld_na": meld_na_v,
                "meld3": meld3_v,
                "child_score": cp["score"],
                "child_class": cp["class"],
                "clif_sofa_total": sub["total"],
                "aclf_grade": grade,
                "n_organ_failures": grade_info["n_failures"],
                "protocol_adherence": adherence,
                "lt_listed": lt_listed,
                "lt_transplanted": lt_tx,
                "in_hospital_mortality": in_hosp_mort,
                "mortality_30d": mort,
                "readmission_90d": readmit,
            })

            # Daily lab trajectory: one row per day for min(LOS, 7) days.
            # Survivors improve by 4 %/day and deaths worsen by 6 %/day,
            # applied to bili/INR/creatinine.
            for day in range(1, min(los, 7) + 1):
                drift = 1 - 0.04 * day if mort == 0 else 1 + 0.06 * day
                labs.append({
                    "episode_id": eid_str,
                    "day": day,
                    "bili": round(bili * drift, 2),
                    "inr": round(inr * drift, 2),
                    "creat": round(creat * drift, 2),
                    "na": round(na + rng.gauss(0, 1), 1),
                    "alb": round(alb + rng.gauss(0, 0.1), 2),
                    "clif_sofa": scoring.clif_sofa(
                        bili * drift, creat * drift, inr * drift,
                        he_grade, map_v)["total"],
                })

            adherence_rows.append({
                "episode_id": eid_str,
                "decomp_type": decomp,
                "adherence": adherence,
                "items_json": json.dumps(adh_items, ensure_ascii=False),
            })

            lt_rows.append({
                "episode_id": eid_str,
                "patient_id": p["patient_id"],
                "meld3": meld3_v,
                "aclf_grade": grade,
                "child_class": cp["class"],
                "konos_band": cand["konos_priority_band"],
                "urgency": cand["urgency"],
                "eligible": int(cand["eligible"]),
                "listed": lt_listed,
                "transplanted": lt_tx,
                "hcc": int(hcc),
                "hcc_within_milan": int(hcc_milan),
            })

            outcomes.append({
                "episode_id": eid_str,
                "patient_id": p["patient_id"],
                "los_days": los,
                "in_hospital_mortality": in_hosp_mort,
                "mortality_30d": mort,
                "readmission_90d": readmit,
                "pod7_visit": int(not mort and rng.random() < 0.78),
                "pod30_visit": int(not mort and rng.random() < 0.66),
                "pod90_visit": int(not mort and rng.random() < 0.54),
                "rifaximin_on_discharge": int(rng.random() < 0.5),
                "nsbb_on_discharge": int(rng.random() < 0.6),
                "diuretic_on_discharge": int(rng.random() < 0.7),
                "sbp_prophylaxis_on_discharge": int(rng.random() < 0.35),
                "resmetirom_on_discharge": int(p["etiology"] == "MASLD" and rng.random() < 0.18),
            })

    _write_csv(os.path.join(DATA_DIR, "episodes.csv"), episodes)
    _write_csv(os.path.join(DATA_DIR, "lab_trajectory.csv"), labs)
    _write_csv(os.path.join(DATA_DIR, "protocol_adherence.csv"), adherence_rows)
    _write_csv(os.path.join(DATA_DIR, "lt_waitlist.csv"), lt_rows)
    _write_csv(os.path.join(DATA_DIR, "outcomes.csv"), outcomes)

    print(f"Generated: {n_patients} patients, {len(episodes)} episodes, "
          f"{len(labs)} lab rows -> {DATA_DIR}")


def _write_csv(path: str, rows: List[Dict[str, object]]) -> None:
    if not rows:
        return
    cols = list(rows[0].keys())
    with open(path, "w", newline="", encoding="utf-8") as fh:
        w = csv.DictWriter(fh, fieldnames=cols)
        w.writeheader()
        w.writerows(rows)


# ---------------------------------------------------------------------------
# CLI commands
# ---------------------------------------------------------------------------
def _load_episodes() -> List[Dict[str, object]]:
    """Load episode rows from DATA_DIR and convert numeric columns.

    Rows come from ``ingest.load_episodes(DATA_DIR)``. Each one is copied
    into a fresh dict. Integer columns are parsed via ``int(float(...))`` so
    values such as "3.0" are accepted; float columns are parsed via
    ``float(...)``. A column that is absent, or holds an empty string, is
    left as-is.
    """
    int_columns = ("los_days", "in_hospital_mortality", "mortality_30d",
                   "readmission_90d", "lt_listed", "lt_transplanted",
                   "child_score", "clif_sofa_total", "n_organ_failures",
                   "dialysis", "he_grade", "plt")
    float_columns = ("bili", "inr", "creat", "na", "alb", "ast", "alt",
                     "map_mmhg", "meld", "meld_na", "meld3",
                     "protocol_adherence")

    def _coerce(raw):
        # Copy so the caller-visible rows from ingest are never mutated.
        row = dict(raw)
        for col in int_columns:
            val = row.get(col, "")
            if val != "":
                row[col] = int(float(val))
        for col in float_columns:
            val = row.get(col, "")
            if val != "":
                row[col] = float(val)
        return row

    return [_coerce(raw) for raw in ingest.load_episodes(DATA_DIR)]


def cmd_summary(top: Optional[int] = None) -> None:
    """Print the Korean KPI text report and, optionally, a ward ranking.

    Args:
        top: if a positive integer, also print the ``top`` wards ranked by
            mean ``protocol_adherence`` (highest first; ties keep first-seen
            ward order). ``None``, 0 or a negative value skips the ranking.
    """
    from collections import defaultdict

    episodes = _load_episodes()
    kpis = kpi_report.episode_kpis(episodes)
    print(kpi_report.render_text_report(kpis, lang="kor"))
    if not top or top < 0:
        return

    ward_adh: Dict[str, List[float]] = defaultdict(list)
    for e in episodes:
        adh = e.get("protocol_adherence")
        # _load_episodes leaves blank CSV cells as "", so skip anything
        # non-numeric instead of crashing in sum() below.
        if isinstance(adh, (int, float)):
            ward_adh[e["ward"]].append(float(adh))
    ranked = sorted(((w, sum(v) / len(v)) for w, v in ward_adh.items()),
                    key=lambda x: -x[1])
    print(f"\n## Top-{top} ward protocol adherence")
    if not ranked:
        print("- (no protocol adherence data)")
    for w, a in ranked[:top]:
        print(f"- {w}: {a:.2%}")


def cmd_report() -> None:
    """Write Korean and English docx KPI reports plus ``kpis.json``.

    All outputs go to REPORTS_DIR, which is created if missing. The paths
    returned by ``kpi_report.write_docx`` and the JSON path are printed, one
    per line.
    """
    os.makedirs(REPORTS_DIR, exist_ok=True)
    kpis = kpi_report.episode_kpis(_load_episodes())

    # Insertion order fixes the write order: Korean first, then English.
    docx_names = {
        "kor": "CirrDecompUnit_QI_kor.docx",
        "eng": "CirrDecompUnit_QI_eng.docx",
    }
    written = [
        kpi_report.write_docx(kpis, os.path.join(REPORTS_DIR, fname), lang=lang)
        for lang, fname in docx_names.items()
    ]

    json_path = os.path.join(REPORTS_DIR, "kpis.json")
    with open(json_path, "w", encoding="utf-8") as fh:
        json.dump(kpis, fh, ensure_ascii=False, indent=2)

    # Continuation lines are indented to align under the text after "Wrote: ".
    print("Wrote: " + "\n       ".join(f"{path}" for path in (*written, json_path)))


def cmd_episode(eid: str) -> None:
    """Print a single episode as indented JSON.

    Args:
        eid: episode identifier, matched exactly against ``episode_id``.

    Exits with status 2 (after a message on stderr) if no episode matches.
    """
    found = next((ep for ep in _load_episodes() if ep["episode_id"] == eid), None)
    if found is None:
        print(f"episode {eid} not found", file=sys.stderr)
        sys.exit(2)
    print(json.dumps(found, ensure_ascii=False, indent=2))


def main(argv: Optional[List[str]] = None) -> int:
    """Parse command-line arguments and run the requested commands.

    Commands run in a fixed order: ``--gen``, summary, ``--report``,
    ``--episode``. This means ``--gen --summary`` summarizes freshly
    generated data. A positive ``--top N`` implies ``--summary``. With no
    command at all, the help text is printed.

    Args:
        argv: argument list to parse; ``None`` means ``sys.argv[1:]``.

    Returns:
        0 on success. Invalid arguments make argparse exit with status 2,
        and ``--episode`` with an unknown ID also exits with status 2.
    """
    ap = argparse.ArgumentParser(
        prog="cirr-decomp-unit-kor",
        description=("MASLD cirrhosis decompensation inpatient QI dashboard CLI. "
                     "참고용·연구용 (research / QI only)."),
    )
    ap.add_argument("--gen", action="store_true",
                    help="generate synthetic CSV data into data/")
    ap.add_argument("--summary", action="store_true",
                    help="print KPI summary by decomp type")
    ap.add_argument("--top", type=int, default=0,
                    help="show top-N wards by adherence (implies --summary)")
    ap.add_argument("--report", action="store_true",
                    help="build docx/PDF KPI reports into reports/")
    ap.add_argument("--episode", type=str, default=None,
                    help="print details of a single episode by ID")
    ap.add_argument("--n", type=int, default=380,
                    help="number of synthetic patients to generate")
    ap.add_argument("--seed", type=int, default=None,
                    help="RNG seed for --gen (default: gen_synthetic's built-in seed)")
    args = ap.parse_args(argv)

    if args.top < 0:
        ap.error("--top must be >= 0")
    if args.gen and args.n < 1:
        ap.error("--n must be >= 1")

    # The ward ranking is part of the summary output, so a positive --top
    # turns the summary on (e.g. `--top 8` on its own).
    run_summary = args.summary or args.top > 0

    if not any([args.gen, run_summary, args.report, args.episode]):
        ap.print_help()
        return 0

    if args.gen:
        # Only forward --seed when given, so gen_synthetic's own default
        # stays the single source of truth for the default seed.
        gen_kwargs: Dict[str, int] = {"n_patients": args.n}
        if args.seed is not None:
            gen_kwargs["seed"] = args.seed
        gen_synthetic(**gen_kwargs)
    if run_summary:
        cmd_summary(top=args.top or None)
    if args.report:
        cmd_report()
    if args.episode:
        cmd_episode(args.episode)
    return 0


if __name__ == "__main__":
    raise SystemExit(main())
