#!/usr/bin/env python3
"""CLI entry point for BariERASRecov-Kor (바리이라스리커브코어).

비만수술 perioperative ERAS recovery cohort dashboard.

Usage:
    python3 main.py --help
    python3 main.py --gen-data --n 400 --seed 42
    python3 main.py --analyze
    python3 main.py --analyze --top 8
    python3 main.py --report --lang ko
    python3 main.py --all        # gen-data -> analyze -> report (both langs)

For research / synthetic data only. NOT for clinical decision making.
"""
from __future__ import annotations

import argparse
import os
import sys
import textwrap

# Banner that main() prints right after argument parsing. It states that the
# tool is for reference/research use only and not for clinical decisions.
DISCLAIMER = (
    "[BariERASRecov-Kor] 본 도구는 참고용·연구용입니다 — "
    "임상 의사결정용 아님 (Not for clinical decision)."
)


def _build_parser() -> argparse.ArgumentParser:
    """Construct the command-line parser for the dashboard CLI.

    The parser exposes four boolean mode switches (``--gen-data``,
    ``--analyze``, ``--report``, ``--all``) and six valued options
    (``--n``, ``--seed``, ``--data-dir``, ``--top``, ``--lang``, ``--out``).
    Description and epilog are shown verbatim (no re-wrapping) because the
    raw-description formatter is used.

    Returns:
        A fully configured ``argparse.ArgumentParser``.
    """
    summary = (
        "비만수술 시행 병원의 preop·intraop·POD0~30 perioperative metabolic "
        "recovery raw를 받아 ERAS Society bariatric protocol·MBSAQIP·ASMBS·"
        "KASMBS 호환 KPI(VTE·anastomotic leak·marginal ulcer·dumping·"
        "post-bariatric hypoglycemia·protein/vitamin/iron deficiency·"
        "early weight loss·LOS·30-day readmission·30-day mortality)를 5+ 술식 "
        "stratification(RYGB/SG/OAGB/SADI/DJB)으로 추적하고 POD0-90 outpatient "
        "transition까지 한 화면에서 perioperative QI로 다루는 standalone dashboard."
    )
    # Epilog spelled out line by line; example commands are indented two
    # spaces below the "Examples:" heading.
    usage_notes = "\n".join((
        "Examples:",
        "  python3 main.py --gen-data --n 400 --seed 42",
        "  python3 main.py --analyze --top 8",
        "  python3 main.py --report --lang ko --out reports/bari_eras_recov_kor_ko.docx",
        "  python3 main.py --all",
        "",
        "For research / synthetic data only. NOT for clinical decision making.",
    ))

    parser = argparse.ArgumentParser(
        prog="bari-eras-recov-kor",
        description=summary,
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog=usage_notes,
    )

    # Boolean mode switches, registered in the order they appear in --help.
    mode_switches = (
        ("--gen-data",
         "data/ 에 synthetic CSV (patients / intraop / pod03 / pod430 "
         "/ pod90 / hypo_events) 생성"),
        ("--analyze",
         "ERAS bundle · MBSAQIP · 술식 stratification · hypo 분포 · "
         "POD0-90 outpatient · KM 30d 계산"),
        ("--report",
         "KASMBS/MBSAQIP-호환 한국어/영문 KPI 리포트 생성 (md + docx)"),
        ("--all",
         "gen-data → analyze → report (ko + en) 일괄 실행"),
    )
    for flag, help_text in mode_switches:
        parser.add_argument(flag, action="store_true", help=help_text)

    # Valued options.
    parser.add_argument(
        "--n", type=int, default=400,
        help="합성 데이터 환자 수 (default: 400)",
    )
    parser.add_argument(
        "--seed", type=int, default=42,
        help="합성 데이터 seed",
    )
    parser.add_argument(
        "--data-dir", default="data",
        help="CSV 입출력 디렉토리",
    )
    parser.add_argument(
        "--top", type=int, default=8,
        help="ranking 출력 행 수",
    )
    parser.add_argument(
        "--lang", choices=["ko", "en"], default="ko",
        help="리포트 언어 (default: ko)",
    )
    parser.add_argument(
        "--out", default="reports/bari_eras_recov_kor.docx",
        help="리포트 출력 경로 (.docx; python-docx 없으면 .md 폴백)",
    )
    return parser


def _missing_deps_message(err: Exception) -> str:
    """Build the user-facing hint shown when an import fails.

    The fixed template is dedented *before* the exception text is inserted.
    Interpolating first would let a multi-line exception message (whose
    continuation lines start at column 0) shrink the common margin to
    nothing, leaving the rest of the message indented by the source
    indentation.

    Args:
        err: The exception raised by the failed import; its ``str()`` form is
            embedded in the first line of the message.

    Returns:
        A multi-line message with no leading/trailing blank lines: the error,
        the virtualenv installation steps, and a closing note.
    """
    template = textwrap.dedent("""
        [BariERASRecov-Kor] 의존성 누락: {err}

        가상환경에서 설치 후 다시 시도하세요:
            python3 -m venv .venv
            source .venv/bin/activate
            pip install -r requirements.txt
            python3 main.py --all

        (--help 와 모든 분석 CLI는 표준 라이브러리만으로 동작합니다.)
    """).strip()
    # str.format only interprets braces in the template, never in the value,
    # so exception text containing "{" or "}" is inserted verbatim.
    return template.format(err=err)


def _resolve(here: str, p: str) -> str:
    """Anchor a possibly-relative path under a base directory.

    Args:
        here: Base directory used for relative paths.
        p: Path to resolve.

    Returns:
        ``p`` unchanged when it is absolute, otherwise ``here`` joined with
        ``p``.
    """
    if os.path.isabs(p):
        return p
    return os.path.join(here, p)


def cmd_gen_data(args, here: str) -> int:
    """Generate the synthetic CSV set and print a short summary.

    Args:
        args: Parsed CLI namespace; ``data_dir``, ``n`` and ``seed`` are read.
        here: Base directory against which a relative ``data_dir`` is resolved.

    Returns:
        Always 0.
    """
    # Imported lazily so that --help works even if the module cannot load.
    from modules.ingest import generate_synthetic

    target_dir = _resolve(here, args.data_dir)
    summary = generate_synthetic(
        n_patients=args.n,
        out_dir=target_dir,
        seed=args.seed,
    )

    counts_line = (
        f"[gen-data] N={summary.n_patients}  intraop={summary.n_intraop}  "
        f"POD0-3={summary.n_pod03}  POD4-30={summary.n_pod430}  "
        f"POD90={summary.n_pod90}  hypo={summary.n_hypo}"
    )
    print(counts_line)
    print(f"[gen-data] dir={target_dir}")

    deid_line = (
        f"[gen-data] de-id: {summary.deid_method}  "
        f"shift={summary.date_shift_offset_days}d"
    )
    print(deid_line)

    # One bullet per generator note.
    for remark in summary.notes:
        print(f"  - {remark}")
    return 0


def cmd_analyze(args, here: str):
    """Load the CSV data set, compute every KPI table and print a console summary.

    Args:
        args: Parsed CLI namespace; ``data_dir`` and ``top`` are read. ``top``
            caps the number of rows printed for most ranked tables.
        here: Base directory against which a relative ``data_dir`` is resolved.

    Returns:
        ``None`` when ``patients.csv`` is missing from the data directory
        (a hint is written to stderr). Otherwise a dict holding the ingest
        report, the six loaded tables and every computed result, keyed by
        ``"ingest"``, ``"patients"``, ``"intraop"``, ``"pod03"``, ``"pod430"``,
        ``"pod90"``, ``"hypo"``, ``"bundles"``, ``"radar"``, ``"rank_ward"``,
        ``"rank_proc"``, ``"mbsaqip"``, ``"procedure"``, ``"hypo_dist"``,
        ``"hypo_proc"``, ``"outpt_proc"``, ``"transition"`` and ``"km"``.
    """
    # Project modules are imported lazily so that --help never needs them.
    from modules import ingest, eras, procedure, hypo, outpatient
    data_dir = _resolve(here, args.data_dir)
    # Only patients.csv is probed up front; other missing files surface from
    # ingest.load_all (main() handles FileNotFoundError).
    if not os.path.exists(os.path.join(data_dir, "patients.csv")):
        print(f"[!] data 없음: {data_dir}/patients.csv — 먼저 --gen-data 실행",
              file=sys.stderr)
        return None

    # --- ingest -----------------------------------------------------------
    patients, intraop, pod03, pod430, pod90, hypo_ev, irep = ingest.load_all(data_dir)
    print(f"[ingest] N={irep.n_patients}  intraop={irep.n_intraop}  "
          f"POD0-3={irep.n_pod03}  POD4-30={irep.n_pod430}  "
          f"POD90={irep.n_pod90}  hypo={irep.n_hypo}")
    print(f"[ingest] de-id: {irep.deid_method}")

    # --- ERAS bundle compliance: per-ward radar ----------------------------
    bundles = eras.compute_patient_bundles(patients, intraop, pod03, pod430)
    radar = eras.ward_radar(bundles)
    print(f"\n[ERAS bundle — ward radar, top {args.top}]")
    for w in radar[: args.top]:
        print(f"  {w.ward:6s}  N={w.n_patients:3d}  pre={w.preop_pct}%  "
              f"intra={w.intraop_pct}%  POD03={w.pod03_pct}%  "
              f"POD430={w.pod430_pct}%  overall={w.overall_pct}%")

    # --- overall compliance ranking, grouped by ward -----------------------
    rank_ward = eras.overall_compliance_ranking(bundles, by="ward")
    print(f"\n[overall compliance ranking — ward, top {args.top}]")
    for k, s, n in rank_ward[: args.top]:
        print(f"  {k:6s}  overall={s}%  N={n}")

    # --- overall compliance ranking, grouped by procedure ------------------
    rank_proc = eras.overall_compliance_ranking(bundles, by="procedure")
    print(f"\n[overall compliance ranking — procedure, top {args.top}]")
    for k, s, n in rank_proc[: args.top]:
        print(f"  {k:6s}  overall={s}%  N={n}")

    # --- MBSAQIP/ASMBS analog measures (all rows printed, not capped) ------
    mbsaqip = eras.mbsaqip_measures(patients, pod03, pod430)
    print("\n[MBSAQIP/ASMBS analog measures]")
    for m in mbsaqip:
        # PASS when the observed rate is at or below the target rate.
        flag = "PASS" if m.rate_pct <= m.target_pct else "WATCH"
        print(f"  {m.measure:42s}  {m.rate_pct:5.2f}%  "
              f"(n={m.n_events}/{m.n_denominator})  "
              f"target<={m.target_pct}%  [{flag}]")

    # --- per-procedure stratification --------------------------------------
    proc_rows = procedure.stratify_by_procedure(patients, pod03, pod430)
    print(f"\n[procedure stratification — top {args.top}]")
    for p in proc_rows[: args.top]:
        print(f"  {p.procedure:5s}  N={p.n:3d}  age={p.mean_age}  "
              f"BMI={p.mean_bmi}  leak={p.leak_rate_pct}%  "
              f"margUlcer={p.marginal_ulcer_pct}%  "
              f"readmit={p.readmit_30d_pct}%  mort={p.mortality_30d_pct}%  "
              f"O/E_leak={p.oe_leak}  O/E_readmit={p.oe_readmit}")

    # --- hypoglycemia / dumping events by time bucket -----------------------
    hypo_dist = hypo.time_distribution(hypo_ev)
    print(f"\n[post-bariatric hypoglycemia / dumping — time distribution, "
          f"top {args.top}]")
    for h in hypo_dist[: args.top]:
        # "ACT" marks rows whose kda_action_recommended attribute is truthy.
        action = "ACT" if h.kda_action_recommended else "—"
        print(f"  {h.time_bucket:8s} {h.hypo_type:22s}  ev={h.n_events:3d}  "
              f"uniq={h.n_unique_patients:3d}  meanBG={h.mean_glucose}  "
              f"L2+={h.pct_L2_or_worse}%  [{action}]")

    # --- hypoglycemia by procedure (all rows printed, not capped) ----------
    hypo_proc = hypo.by_procedure(patients, hypo_ev)
    print("\n[hypoglycemia by procedure]")
    for h in hypo_proc:
        print(f"  {h.procedure:5s}  N={h.n_patients:3d}  imm={h.n_immediate_events}  "
              f"delayed={h.n_delayed_events}  dumping={h.n_dumping_events}  "
              f"imm%={h.immediate_rate_pct}  delayed%={h.delayed_rate_pct}  "
              f"dump%={h.dumping_rate_pct}")

    # --- POD0-90 outpatient follow-up adherence, per procedure --------------
    outpt_proc = outpatient.adherence_by("procedure", patients, pod90)
    print(f"\n[POD0-90 outpatient adherence (per procedure) — top {args.top}]")
    for o in outpt_proc[: args.top]:
        print(f"  {o.key:5s}  N={o.n_total:3d}  POD7={o.pct_pod7}%  "
              f"POD30={o.pct_pod30}%  POD60={o.pct_pod60}%  "
              f"POD90={o.pct_pod90}%  TWL={o.mean_twl_pct}%  "
              f"HbA1c={o.mean_hba1c}  GLP1+={o.glp1_added_pct}%  "
              f"reop={o.reop_trigger_pct}%")

    # --- 30-day readmission / mortality summary -----------------------------
    trans = outpatient.transition_summary(patients, pod430)
    print("\n[30-day readmission / mortality summary]")
    print(f"  total={trans.n_total}  readmit={trans.n_readmit_30d} "
          f"({trans.readmit_rate_pct}%)  "
          f"medTime={trans.median_time_to_readmit_d}d  "
          f"mort={trans.n_mort_30d} ({trans.mort_rate_pct}%)")
    print(f"  reason mix={trans.readmit_reason_mix}")

    # --- Kaplan-Meier step rows for 30-day readmission ----------------------
    km = outpatient.kaplan_meier_step(patients, pod430,
                                      endpoint="readmit", horizon_d=30)
    print("\n[KM 30-day readmission (head)]")
    # Only the first 10 rows are printed here, independent of --top.
    for d, s, n in km[:10]:
        print(f"  day={d:3d}  S(t)={s:.4f}  n_at_risk={n}")

    # Everything is handed back so cmd_report can reuse it without recomputing.
    return {
        "ingest": irep,
        "patients": patients, "intraop": intraop,
        "pod03": pod03, "pod430": pod430, "pod90": pod90, "hypo": hypo_ev,
        "bundles": bundles, "radar": radar,
        "rank_ward": rank_ward, "rank_proc": rank_proc,
        "mbsaqip": mbsaqip,
        "procedure": proc_rows,
        "hypo_dist": hypo_dist, "hypo_proc": hypo_proc,
        "outpt_proc": outpt_proc, "transition": trans, "km": km,
    }


def cmd_report(args, here: str, pre=None) -> int:
    """Render the KPI report as Markdown plus a .docx next to it.

    Args:
        args: Parsed CLI namespace; ``lang`` and ``out`` are read here (and
            whatever ``cmd_analyze`` reads when ``pre`` is not supplied).
        here: Base directory against which a relative ``out`` is resolved.
        pre: Optional result dict from a previous ``cmd_analyze`` call. When
            falsy, the analysis is run here first.

    Returns:
        0 on success, 2 when no analysis result is available.
    """
    from modules import report
    bundle = pre or cmd_analyze(args, here)
    if not bundle:
        return 2
    md = report.build_markdown(
        ingest_report=bundle["ingest"],
        ward_radar_rows=bundle["radar"],
        mbsaqip_rows=bundle["mbsaqip"],
        procedure_rows=bundle["procedure"],
        hypo_bucket_rows=bundle["hypo_dist"],
        hypo_proc_rows=bundle["hypo_proc"],
        outpt_proc_rows=bundle["outpt_proc"],
        transition=bundle["transition"],
        km_readmit_rows=bundle["km"],
        language=args.lang,
    )
    out = _resolve(here, args.out)
    # splitext only considers the final path component, so a dot in a parent
    # directory (e.g. "/srv/proj.v1/reports/out") can no longer truncate the
    # path the way a plain rsplit(".") on the whole string did.
    md_path = os.path.splitext(out)[0] + ".md"
    report.write_markdown(md_path, md)
    written = report.write_docx(out, md,
                                title="BariERASRecov-Kor KPI 리포트")
    print(f"[report] markdown: {md_path}")
    # Prints whatever write_docx returned (presumably the path actually
    # written — confirm against modules.report).
    print(f"[report] docx:     {written}")
    return 0


def main(argv=None) -> int:
    """Parse the command line and dispatch to the requested sub-commands.

    Args:
        argv: Argument list to parse; ``None`` lets argparse read
            ``sys.argv[1:]``.

    Returns:
        Process exit status:
        0 - success, or no mode flag given (a usage hint is printed);
        1 - a required module could not be imported;
        2 - input data / a file is missing;
        otherwise the non-zero status returned by a sub-command.
    """
    args = _build_parser().parse_args(argv)
    print(DISCLAIMER)
    here = os.path.dirname(os.path.abspath(__file__))

    if not any([args.gen_data, args.analyze, args.report, args.all]):
        print("\n사용 예: python3 main.py --all  |  python3 main.py --help")
        return 0

    try:
        if args.all:
            rc = cmd_gen_data(args, here)
            if rc != 0:
                return rc
            bundle = cmd_analyze(args, here)
            if not bundle:
                return 2
            # Both languages are rendered from the same analysis result. The
            # output path is fixed per language here, so a user-supplied
            # --out / --lang is not used in this mode.
            for lang in ("ko", "en"):
                args.lang = lang
                args.out = f"reports/bari_eras_recov_kor_{lang}.docx"
                rc = cmd_report(args, here, pre=bundle)
                if rc != 0:
                    return rc
            return 0
        if args.gen_data:
            # Propagate a failure status exactly as the --all path does.
            rc = cmd_gen_data(args, here)
            if rc != 0:
                return rc
        bundle = None
        if args.analyze:
            bundle = cmd_analyze(args, here)
            if not bundle:
                return 2
        if args.report:
            # Reuses the analysis from --analyze when both flags were given;
            # otherwise cmd_report runs the analysis itself.
            return cmd_report(args, here, pre=bundle)
        return 0
    except ImportError as e:
        # The requested command did not run, so report it like the other
        # failure paths: message on stderr and a non-zero exit status.
        print(_missing_deps_message(e), file=sys.stderr)
        return 1
    except FileNotFoundError as e:
        print(f"[!] file missing: {e}", file=sys.stderr)
        return 2


# Script entry point: run the CLI and use main()'s return value as the
# process exit status.
if __name__ == "__main__":
    sys.exit(main())
