#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
DMAESentinel-Kor (디엠에이이센티넬코어)
한국 당뇨 약물 post-marketing AE / safety signal disproportionality 알림 도구.

서브커맨드:
  ingest    - mock data 적재 + SQLite DB 구축
  analyze   - ROR / PRR / IC / EBGM disproportionality 산출
  digest    - 한국어 weekly safety digest 생성
  psur      - MFDS PSUR Form draft 출력
  alert     - cross-discipline / cross-source alert 시뮬레이션

참고용·연구용. 임상 의사결정 단독 근거 금지.
실제 PSUR 제출 전 식약처/한국의약품안전관리원 검토 필요.
"""

from __future__ import annotations

import argparse
import csv
import json
import math
import os
import sqlite3
import sys
import textwrap
from collections import defaultdict
from datetime import datetime
from pathlib import Path

# Directory that contains this script; generated artifacts live beside it.
ROOT = Path(__file__).resolve().parent
# Folder holding the synthetic CSV/JSON files written by the `ingest` command.
DATA = ROOT / "data"
# SQLite database file that `ingest` rebuilds and the other commands read.
DB_PATH = ROOT / "dmae_sentinel.db"

# ----------------------------------------------------------------------------
# 1) 약물 / AE / 한국어 매핑 정의
# ----------------------------------------------------------------------------

# Drug master list. Each tuple is
#   (english_name, korean_name, MFDS_code, ATC_code, class)
# which is the same column order as the header emitted by
# generate_kfda_drug_codes().
# NOTE(review): the MFDS codes and some ATC codes (e.g. A10BX99) look like
# placeholders for this mock tool — confirm before any real-world use.
DRUGS = [
    # SGLT2i
    ("empagliflozin",  "엠파글리플로진",  "MFDS-EMPA-001",  "A10BK03", "SGLT2i"),
    ("dapagliflozin",  "다파글리플로진",  "MFDS-DAPA-002",  "A10BK01", "SGLT2i"),
    ("canagliflozin",  "카나글리플로진",  "MFDS-CANA-003",  "A10BK02", "SGLT2i"),
    ("sotagliflozin",  "소타글리플로진",  "MFDS-SOTA-004",  "A10BK06", "SGLT1/2i"),
    # GLP-1RA / dual / triple
    ("semaglutide",    "세마글루타이드",  "MFDS-SEMA-010",  "A10BJ06", "GLP-1RA"),
    ("tirzepatide",    "티르제파타이드",  "MFDS-TIRZ-011",  "A10BX16", "GIP/GLP-1"),
    ("retatrutide",    "레타트루타이드",  "MFDS-RETA-012",  "A10BX99", "GLP-1/GIP/GCG"),
    ("orforglipron",   "오르포글리프론",  "MFDS-ORFO-013",  "A10BJ99", "GLP-1RA(oral)"),
    ("survodutide",    "서르보두타이드",  "MFDS-SURV-014",  "A10BX98", "GLP-1/GCG"),
    ("dulaglutide",    "둘라글루타이드",  "MFDS-DULA-015",  "A10BJ05", "GLP-1RA"),
    ("liraglutide",    "리라글루타이드",  "MFDS-LIRA-016",  "A10BJ02", "GLP-1RA"),
    # MRA
    ("finerenone",     "피네레논",        "MFDS-FINE-020",  "C03DA05", "ns-MRA"),
    # Insulin
    ("insulin glargine", "인슐린 글라진",  "MFDS-IGLA-030",  "A10AE04", "Insulin"),
    ("insulin icodec",   "인슐린 이코덱",  "MFDS-IICO-031",  "A10AE07", "Insulin"),
    # DPP4i
    ("sitagliptin",    "시타글립틴",      "MFDS-SITA-040",  "A10BH01", "DPP4i"),
]

# Adverse-event master list. Each tuple is
#   (english_PT, korean_PT, SOC, HLT, severity_class)
# which is the same column order as the header emitted by
# generate_meddra_korean_terms().
AES = [
    # SGLT2i-related
    ("DKA",                      "당뇨병성 케토산증",          "Metabolism",   "DKA HLT",          "Critical"),
    ("euglycemic DKA",           "정상혈당성 케토산증",        "Metabolism",   "DKA HLT",          "Critical"),
    ("Fournier gangrene",        "푸르니에 괴저",             "Infection",    "Soft tissue",      "Critical"),
    ("lower limb amputation",    "하지 절단",                "Vascular",     "Limb",             "Critical"),
    ("genital mycotic infection","생식기 진균감염",            "Infection",    "Genital",          "Moderate"),
    # GLP-1RA / dual-agonist related
    ("pancreatitis",             "췌장염",                  "GI",           "Pancreas HLT",     "Serious"),
    ("biliary disease",          "담도질환",                "GI",           "Biliary",          "Serious"),
    ("gastric stasis",           "위마비",                  "GI",           "Gastric motility", "Moderate"),
    ("NAION",                    "비동맥염성 전허혈성 시신경병증", "Eye",       "Optic nerve",      "Serious"),
    ("suicidal ideation",        "자살사고",                "Psychiatric",  "Suicide HLT",      "Serious"),
    ("sarcopenia",               "근감소증",                "Musculoskel",  "Muscle",           "Moderate"),
    ("thyroid C-cell tumor",     "갑상선 C세포 종양",          "Endocrine",    "Thyroid",          "Critical"),
    ("nausea",                   "오심",                    "GI",           "Nausea",           "Mild"),
    # MRA
    ("hyperkalemia",             "고칼륨혈증",               "Metabolism",   "Electrolyte",      "Serious"),
    # Insulin
    ("severe hypoglycemia",      "중증 저혈당",              "Metabolism",   "Hypoglycemia HLT", "Critical"),
    ("injection site reaction",  "주사부위 반응",            "Skin",         "Inj site",         "Mild"),
]

# Source identifiers accepted by `ingest --source` and iterated by cmd_ingest
# when loading everything; each one has a matching branch in etl_load_source.
SOURCES = ["faers", "eudra", "vigibase", "kaers", "cohort", "pubmed", "abstract"]


# ----------------------------------------------------------------------------
# 2) 합성 데이터 생성
# ----------------------------------------------------------------------------

def _seeded_random():
    import random
    r = random.Random(20260507)
    return r


def _signal_weight(drug: str, ae: str) -> float:
    """약물–AE 쌍별 prior signal weight (mock, 합성 데이터 분포 제어용)."""
    table = {
        ("empagliflozin", "DKA"): 4.0,
        ("empagliflozin", "euglycemic DKA"): 5.0,
        ("dapagliflozin", "euglycemic DKA"): 4.5,
        ("canagliflozin", "lower limb amputation"): 5.0,
        ("canagliflozin", "Fournier gangrene"): 3.5,
        ("dapagliflozin", "Fournier gangrene"): 3.0,
        ("sotagliflozin", "DKA"): 4.0,
        ("semaglutide", "pancreatitis"): 3.5,
        ("semaglutide", "biliary disease"): 3.5,
        ("semaglutide", "NAION"): 4.0,
        ("semaglutide", "suicidal ideation"): 2.5,
        ("semaglutide", "sarcopenia"): 3.0,
        ("semaglutide", "gastric stasis"): 4.5,
        ("tirzepatide", "gastric stasis"): 4.5,
        ("tirzepatide", "pancreatitis"): 3.0,
        ("tirzepatide", "sarcopenia"): 3.5,
        ("retatrutide", "nausea"): 3.0,
        ("orforglipron", "nausea"): 3.0,
        ("survodutide", "nausea"): 3.0,
        ("dulaglutide", "pancreatitis"): 2.5,
        ("liraglutide", "thyroid C-cell tumor"): 3.0,
        ("finerenone", "hyperkalemia"): 5.0,
        ("insulin glargine", "severe hypoglycemia"): 4.5,
        ("insulin icodec",   "severe hypoglycemia"): 4.0,
        ("sitagliptin", "pancreatitis"): 2.5,
    }
    return table.get((drug, ae), 1.0)


def generate_meddra_korean_terms():
    """Build the AE term mapping table as a list of rows.

    Returns:
        A list whose first element is the column header and whose remaining
        elements are the ``AES`` tuples converted to lists, in ``AES`` order.
    """
    header = ["english_PT", "korean_PT", "SOC", "HLT", "severity_class"]
    return [header, *(list(term) for term in AES)]


def generate_kfda_drug_codes():
    """Build the drug code mapping table as a list of rows.

    Returns:
        A list whose first element is the column header and whose remaining
        elements are the ``DRUGS`` tuples converted to lists, in ``DRUGS``
        order.
    """
    header = ["english_name", "korean_name", "MFDS_code", "ATC_code", "class"]
    return [header, *(list(drug) for drug in DRUGS)]


def generate_source_csv(source: str, target_rows: int):
    """Generate synthetic spontaneous AE report rows for one source.

    Args:
        source: One of ``"faers"``, ``"eudra"``, ``"vigibase"`` or ``"kaers"``.
        target_rows: Generation stops once this many data rows exist. The
            check runs after a row is appended, so at least one data row is
            produced whenever any drug–AE pair yields a non-zero count.

    Returns:
        A list of rows; the first row is the header. KAERS rows carry Korean
        and English names plus the MFDS code; the other sources carry the
        English names and a dose column.

    Raises:
        ValueError: If ``source`` is not one of the four supported sources.
    """
    country_by_source = {
        "faers": "US",
        "eudra": "EU",
        "vigibase": "WHO-Global",
        "kaers": "KR",
    }
    if source not in country_by_source:
        raise ValueError(source)
    country = country_by_source[source]
    is_kaers = source == "kaers"

    if is_kaers:
        header = ["report_id", "drug_korean", "drug_english", "AE_term_korean",
                  "AE_term_english", "kfda_code", "age_group", "sex",
                  "country", "n_reports", "year"]
    else:
        header = ["report_id", "drug", "AE_term", "age_group", "sex",
                  "country", "n_reports", "dose_mg", "year"]

    # Every call starts from a freshly seeded generator, so the draw order
    # below determines the output and must not be rearranged.
    rng = _seeded_random()
    table = [header]
    age_bands = ["18-44", "45-64", "65-74", "75+"]
    sex_codes = ["M", "F"]
    serial = 0

    for drug_en, drug_ko, mfds_code, _atc, _cls in DRUGS:
        for ae_en, ae_ko, _soc, _hlt, _sev in AES:
            weight = _signal_weight(drug_en, ae_en)
            count = max(1, int(rng.gauss(2 * weight, 1.2)))
            # With 20% probability knock one report off; a count of 0 drops
            # the pair entirely.
            if rng.random() < 0.20:
                count = max(0, count - 1)
            if count == 0:
                continue

            age_band = rng.choice(age_bands)
            sex = rng.choice(sex_codes)
            year = rng.choice([2023, 2024, 2025, 2026])
            # Drawn for every source (KAERS rows do not store it) so the
            # random stream stays aligned across sources.
            dose = rng.choice([0.25, 0.5, 1.0, 2.0, 5, 10, 25, 100])

            serial += 1
            if is_kaers:
                record = [f"KR-{serial:05d}", drug_ko, drug_en, ae_ko, ae_en,
                          mfds_code, age_band, sex, country, count, year]
            else:
                record = [f"{source.upper()}-{serial:05d}",
                          drug_en, ae_en, age_band, sex, country, count,
                          dose, year]
            table.append(record)

            if len(table) - 1 >= target_rows:
                return table
    return table


def generate_site_cohort():
    """Generate one synthetic site-cohort summary row per drug.

    Returns:
        A list of rows; the first row is the header. Each data row holds the
        English and Korean drug name, patient count, total AE events, mean
        BMI, mean age and a 0–100 "Korean extrapolation" score.
    """
    rng = _seeded_random()
    table = [["drug", "drug_korean", "n_patients", "n_AE_events",
              "BMI_mean", "age_mean", "kor_extrapolation_score"]]
    for entry in DRUGS:
        drug_en, drug_ko = entry[0], entry[1]
        patients = rng.randint(120, 1500)

        # One Gaussian draw per AE, floored at zero and accumulated.
        events = 0
        for term in AES:
            weight = _signal_weight(drug_en, term[0])
            draw = int(rng.gauss(weight * 1.5, 0.8))
            events += draw if draw > 0 else 0

        bmi_mean = round(rng.uniform(24.0, 31.0), 1)
        age_mean = round(rng.uniform(45.0, 70.0), 1)

        # Compatibility score (0-100): the penalty grows with the distance of
        # the cohort BMI from 25.5, minus a small random deduction.
        penalty = abs(bmi_mean - 25.5) / 6.0
        raw_score = 95 - penalty * 30 - rng.uniform(0, 5)
        clipped = max(0.0, min(100.0, raw_score))
        table.append([drug_en, drug_ko, patients, events,
                      bmi_mean, age_mean, round(clipped, 1)])
    return table


def generate_pubmed_case_reports():
    """Generate synthetic PubMed-style case report metadata.

    Returns:
        A list of dicts, one per mock case report, with sequential string
        PMIDs starting at ``"39000000"``. The number of reports per drug–AE
        pair is a Gaussian draw centred on the pair's signal weight minus 0.5,
        floored at zero.
    """
    rng = _seeded_random()
    journals = [
        "Diabetes Care", "JKMS", "Endocrinol Metab",
        "Diabetes Obes Metab", "Korean J Intern Med"
    ]
    years = [2023, 2024, 2025, 2026]
    reports = []
    next_pmid = 39000000
    for drug_en, drug_ko, *_unused_drug in DRUGS:
        for ae_en, ae_ko, *_unused_ae in AES:
            weight = _signal_weight(drug_en, ae_en)
            how_many = max(0, int(rng.gauss(weight - 0.5, 0.7)))
            for _ in range(how_many):
                # Year is drawn before journal; keep this order so the random
                # stream matches.
                year = rng.choice(years)
                journal = rng.choice(journals)
                reports.append({
                    "PMID": str(next_pmid),
                    "title": f"Case report: {ae_en} associated with {drug_en} in a Korean patient",
                    "drug": drug_en,
                    "drug_korean": drug_ko,
                    "AE_term": ae_en,
                    "AE_term_korean": ae_ko,
                    "year": year,
                    "journal": journal,
                })
                next_pmid += 1
    return reports


def generate_abstract_reports():
    """Generate synthetic conference-abstract metadata.

    A drug–AE pair gets an abstract with probability ``0.25 * (weight / 5)``,
    where ``weight`` is the pair's signal weight.

    Returns:
        A list of dicts with sequential ids ``ABS-0001``, ``ABS-0002``, ...
    """
    rng = _seeded_random()
    conferences = ["ADA-2025", "EASD-2025", "KDA-2025", "ATTD-2025"]
    abstracts = []
    for drug_en, drug_ko, *_unused_drug in DRUGS:
        for ae_en, ae_ko, *_unused_ae in AES:
            weight = _signal_weight(drug_en, ae_en)
            if not rng.random() < 0.25 * (weight / 5):
                continue
            # Conference is drawn before the subject count; keep this order
            # so the random stream matches.
            conference = rng.choice(conferences)
            subjects = rng.randint(40, 4000)
            abstracts.append({
                "abstract_id": f"ABS-{len(abstracts) + 1:04d}",
                "drug": drug_en,
                "drug_korean": drug_ko,
                "AE_term": ae_en,
                "AE_term_korean": ae_ko,
                "conference": conference,
                "year": 2025,
                "n_subjects": subjects,
            })
    return abstracts


def write_csv(path: Path, rows):
    """Write ``rows`` to ``path`` as a UTF-8 CSV file, replacing any old file.

    Args:
        path: Destination file path.
        rows: Iterable of row sequences, written in order.
    """
    with path.open("w", encoding="utf-8", newline="") as handle:
        csv.writer(handle).writerows(rows)


def write_json(path: Path, obj):
    """Serialize ``obj`` to ``path`` as UTF-8 JSON with a two-space indent.

    Non-ASCII characters are written as-is rather than as ``\\u`` escapes.

    Args:
        path: Destination file path; an existing file is overwritten.
        obj: Any JSON-serializable object.
    """
    with path.open("w", encoding="utf-8") as handle:
        json.dump(obj, fp=handle, indent=2, ensure_ascii=False)


# ----------------------------------------------------------------------------
# 3) ETL → SQLite
# ----------------------------------------------------------------------------

def db_connect():
    """Open the SQLite database at ``DB_PATH``.

    Returns:
        A ``sqlite3.Connection`` whose rows are ``sqlite3.Row`` objects, so
        columns can be read by index or by name.
    """
    connection = sqlite3.connect(str(DB_PATH))
    connection.row_factory = sqlite3.Row
    return connection


def etl_init(conn):
    """Drop and recreate every table and index used by the tool.

    Any existing data in ``reports``, ``cohort``, ``literature``,
    ``drug_map`` and ``ae_map`` is discarded.

    Args:
        conn: An open ``sqlite3`` connection.
    """
    schema = """
        DROP TABLE IF EXISTS reports;
        DROP TABLE IF EXISTS cohort;
        DROP TABLE IF EXISTS literature;
        DROP TABLE IF EXISTS drug_map;
        DROP TABLE IF EXISTS ae_map;
        CREATE TABLE reports (
            source TEXT, drug TEXT, ae TEXT, n_reports INTEGER,
            country TEXT, year INTEGER
        );
        CREATE TABLE cohort (
            drug TEXT, drug_korean TEXT,
            n_patients INTEGER, n_AE_events INTEGER,
            bmi_mean REAL, age_mean REAL, kor_score REAL
        );
        CREATE TABLE literature (
            source TEXT, drug TEXT, ae TEXT, ref_id TEXT,
            year INTEGER, title TEXT
        );
        CREATE TABLE drug_map (
            english TEXT PRIMARY KEY, korean TEXT, mfds TEXT, atc TEXT, class TEXT
        );
        CREATE TABLE ae_map (
            english TEXT PRIMARY KEY, korean TEXT, soc TEXT, hlt TEXT, severity TEXT
        );
        CREATE INDEX idx_reports_drug_ae ON reports(drug, ae);
        CREATE INDEX idx_lit_drug_ae ON literature(drug, ae);
    """
    conn.executescript(schema)
    conn.commit()


def etl_load_maps(conn):
    """Load the drug and AE mapping CSV files into ``drug_map`` / ``ae_map``.

    Reads ``kfda_drug_codes.csv`` and ``meddra_korean_terms.csv`` from the
    ``DATA`` directory, skips each file's header row, inserts the remaining
    rows and commits.

    Args:
        conn: An open ``sqlite3`` connection whose schema already contains
            the two five-column mapping tables.
    """
    cur = conn.cursor()
    targets = (
        ("kfda_drug_codes.csv", "INSERT INTO drug_map VALUES (?,?,?,?,?)"),
        ("meddra_korean_terms.csv", "INSERT INTO ae_map VALUES (?,?,?,?,?)"),
    )
    for filename, insert_sql in targets:
        # newline="" lets the csv module do its own newline handling, as its
        # documentation requires for file objects.
        with (DATA / filename).open(encoding="utf-8", newline="") as f:
            reader = csv.reader(f)
            # Skip the header through the reader (CSV-aware); the default
            # keeps an empty file from raising StopIteration.
            next(reader, None)
            # Blank lines parse as empty rows and would fail parameter
            # binding, so drop them.
            cur.executemany(insert_sql, (row for row in reader if row))
    conn.commit()


def etl_load_source(conn, source: str):
    """Load one synthetic source file from ``DATA`` into SQLite and commit.

    * ``faers`` / ``eudra`` / ``vigibase`` / ``kaers`` -> ``reports`` table
      (KAERS uses the ``drug_english`` / ``AE_term_english`` columns).
    * ``cohort`` -> ``cohort`` table.
    * ``pubmed`` / ``abstract`` -> ``literature`` table.

    Args:
        conn: An open ``sqlite3`` connection with the schema already created.
        source: Source identifier (one of the values listed above).

    Raises:
        ValueError: If ``source`` is not a recognised identifier. This used
            to be a silent no-op; it now fails loudly, matching
            ``generate_source_csv``.
    """
    cur = conn.cursor()
    report_sql = (
        "INSERT INTO reports(source,drug,ae,n_reports,country,year) "
        "VALUES (?,?,?,?,?,?)"
    )
    literature_sql = "INSERT INTO literature VALUES (?,?,?,?,?,?)"

    if source in ("faers", "eudra", "vigibase", "kaers"):
        # KAERS files carry both Korean and English names; the reports table
        # stores the English ones.
        if source == "kaers":
            drug_col, ae_col = "drug_english", "AE_term_english"
        else:
            drug_col, ae_col = "drug", "AE_term"
        path = DATA / f"{source}_synthetic.csv"
        # newline="" per the csv module documentation.
        with path.open(encoding="utf-8", newline="") as f:
            cur.executemany(
                report_sql,
                (
                    (source, row[drug_col], row[ae_col],
                     int(row["n_reports"]), row["country"], int(row["year"]))
                    for row in csv.DictReader(f)
                ),
            )
    elif source == "cohort":
        path = DATA / "site_cohort_synthetic.csv"
        with path.open(encoding="utf-8", newline="") as f:
            cur.executemany(
                "INSERT INTO cohort VALUES (?,?,?,?,?,?,?)",
                (
                    (row["drug"], row["drug_korean"],
                     int(row["n_patients"]), int(row["n_AE_events"]),
                     float(row["BMI_mean"]), float(row["age_mean"]),
                     float(row["kor_extrapolation_score"]))
                    for row in csv.DictReader(f)
                ),
            )
    elif source == "pubmed":
        path = DATA / "pubmed_case_reports_synthetic.json"
        items = json.loads(path.read_text(encoding="utf-8"))
        cur.executemany(
            literature_sql,
            (
                ("pubmed", it["drug"], it["AE_term"], it["PMID"],
                 int(it["year"]), it["title"])
                for it in items
            ),
        )
    elif source == "abstract":
        path = DATA / "abstract_synthetic.json"
        items = json.loads(path.read_text(encoding="utf-8"))
        cur.executemany(
            literature_sql,
            (
                ("abstract", it["drug"], it["AE_term"], it["abstract_id"],
                 int(it["year"]),
                 f"{it['conference']} abstract on {it['AE_term']} with {it['drug']}")
                for it in items
            ),
        )
    else:
        raise ValueError(f"unknown source: {source!r}")
    conn.commit()


# ----------------------------------------------------------------------------
# 4) Disproportionality 통계
# ----------------------------------------------------------------------------

def disprop_2x2(a: int, b: int, c: int, d: int):
    """Compute disproportionality statistics from a 2x2 contingency table.

    Cell layout: ``a`` = drug + AE, ``b`` = other drugs + AE,
    ``c`` = drug + other AEs, ``d`` = other drugs + other AEs.

    When the smallest cell equals 0, 0.5 is added to every cell
    (Haldane–Anscombe correction) before anything is computed. The IC uses an
    additional +0.5 on both observed and expected counts, EBGM is simply
    ``2 ** IC``, and the IC interval relies on a rough variance approximation.

    Returns:
        A dict with the raw (uncorrected) cells ``a``–``d`` and the keys
        ``ROR``, ``ROR_lo``, ``ROR_hi``, ``PRR``, ``IC``, ``IC_lo``,
        ``IC_hi`` and ``EBGM``.
    """
    needs_correction = min(a, b, c, d) == 0
    if needs_correction:
        ca, cb, cc, cd = a + 0.5, b + 0.5, c + 0.5, d + 0.5
    else:
        ca, cb, cc, cd = a, b, c, d

    # Reporting odds ratio with a Wald 95% interval on the log scale.
    odds_ratio = (ca * cd) / (cb * cc)
    log_se = math.sqrt(1/ca + 1/cb + 1/cc + 1/cd)
    log_or = math.log(odds_ratio)
    or_low = math.exp(log_or - 1.96 * log_se)
    or_high = math.exp(log_or + 1.96 * log_se)

    # Proportional reporting ratio.
    drug_rate = ca / (ca + cc)
    other_rate = cb / (cb + cd)
    reporting_ratio = drug_rate / other_rate

    # Information component: log2 of shrunk observed/expected.
    total = ca + cb + cc + cd
    expected = (ca + cb) * (ca + cc) / total
    info_comp = math.log2((ca + 0.5) / (expected + 0.5))
    ebgm_approx = 2 ** info_comp  # simple approximation, not a fitted EBGM

    # Approximate 95% interval for the IC.
    variance = 1.0 / max(ca, 0.5) + 1.0 / max(expected, 0.5)
    ic_se = math.sqrt(variance) / math.log(2) * 0.5
    ic_low = info_comp - 1.96 * ic_se
    ic_high = info_comp + 1.96 * ic_se

    return {
        "a": a, "b": b, "c": c, "d": d,
        "ROR": odds_ratio, "ROR_lo": or_low, "ROR_hi": or_high,
        "PRR": reporting_ratio,
        "IC": info_comp, "IC_lo": ic_low, "IC_hi": ic_high,
        "EBGM": ebgm_approx,
    }


def signal_flag(stat: dict, n_min: int = 3, ror_min: float = 2.0) -> bool:
    """Decide whether a statistics dict qualifies as a signal.

    A signal requires all three conditions, checked in this order with
    short-circuiting: ``stat["a"] >= n_min``, ``stat["ROR"] >= ror_min`` and
    ``stat["ROR_lo"] >= 1.0``.

    Args:
        stat: Dict with at least the keys ``a``, ``ROR`` and ``ROR_lo``.
        n_min: Minimum number of drug + AE reports.
        ror_min: Minimum reporting odds ratio.
    """
    if not stat["a"] >= n_min:
        return False
    if not stat["ROR"] >= ror_min:
        return False
    return stat["ROR_lo"] >= 1.0


def compute_2x2_for_drug_ae(conn, drug: str, ae: str, sources=None):
    """Return the 2x2 contingency counts for a drug–AE pair.

    All four cells are summed from ``reports.n_reports`` in one pass using
    conditional aggregation, instead of four separate queries.

    Args:
        conn: An open ``sqlite3`` connection with a ``reports`` table.
        drug: English drug name.
        ae: English AE term.
        sources: Optional sequence of source names to restrict the counts to.
            ``None`` or an empty sequence means no restriction.

    Returns:
        ``(a, b, c, d)`` as ints, where ``a`` = drug + AE, ``b`` = other
        drugs + AE, ``c`` = drug + other AEs, ``d`` = other drugs + other AEs.
        Every cell is 0 when no row matches.
    """
    source_list = list(sources) if sources else []
    where = ""
    if source_list:
        where = " WHERE source IN (" + ",".join("?" * len(source_list)) + ")"

    # Each CASE spells out both comparisons so that rows with NULL drug/ae
    # are excluded from every cell, exactly as `=` / `<>` in a WHERE would.
    sql = (
        "SELECT "
        "COALESCE(SUM(CASE WHEN drug=? AND ae=? THEN n_reports ELSE 0 END),0), "
        "COALESCE(SUM(CASE WHEN drug<>? AND ae=? THEN n_reports ELSE 0 END),0), "
        "COALESCE(SUM(CASE WHEN drug=? AND ae<>? THEN n_reports ELSE 0 END),0), "
        "COALESCE(SUM(CASE WHEN drug<>? AND ae<>? THEN n_reports ELSE 0 END),0) "
        "FROM reports" + where
    )
    params = [drug, ae] * 4 + source_list
    a, b, c, d = conn.cursor().execute(sql, params).fetchone()
    return int(a), int(b), int(c), int(d)


# ----------------------------------------------------------------------------
# 5) 서브커맨드 구현
# ----------------------------------------------------------------------------

def cmd_ingest(args):
    """Regenerate all synthetic data files and rebuild the SQLite database.

    Every mock file is always rewritten; ``args.source`` only controls which
    sources are loaded into the database (``"all"`` or a missing value loads
    every entry of ``SOURCES``). A summary is printed at the end.

    Args:
        args: Parsed CLI namespace providing ``source``.

    Exits with status 2 when ``args.source`` is not a known source. This is
    checked before any file or database is touched, so a bad value no longer
    destroys the existing database.
    """
    target = args.source or "all"
    if target != "all" and target not in SOURCES:
        print(f"[!] unknown source: {target}")
        sys.exit(2)

    DATA.mkdir(exist_ok=True, parents=True)

    # 1) Mapping tables
    write_csv(DATA / "meddra_korean_terms.csv", generate_meddra_korean_terms())
    write_csv(DATA / "kfda_drug_codes.csv", generate_kfda_drug_codes())

    # 2) Per-source mock data
    write_csv(DATA / "faers_synthetic.csv", generate_source_csv("faers", 500))
    write_csv(DATA / "eudra_synthetic.csv", generate_source_csv("eudra", 300))
    write_csv(DATA / "vigibase_synthetic.csv", generate_source_csv("vigibase", 300))
    write_csv(DATA / "kaers_synthetic.csv", generate_source_csv("kaers", 200))
    write_csv(DATA / "site_cohort_synthetic.csv", generate_site_cohort())
    write_json(DATA / "pubmed_case_reports_synthetic.json", generate_pubmed_case_reports())
    write_json(DATA / "abstract_synthetic.json", generate_abstract_reports())

    # 3) SQLite ETL, always starting from an empty database file
    if DB_PATH.exists():
        DB_PATH.unlink()
    conn = db_connect()
    try:
        etl_init(conn)
        etl_load_maps(conn)
        for s in (SOURCES if target == "all" else [target]):
            etl_load_source(conn, s)

        cur = conn.cursor()
        n_rep = cur.execute(
            "SELECT COUNT(*), COALESCE(SUM(n_reports),0) FROM reports").fetchone()
        n_lit = cur.execute("SELECT COUNT(*) FROM literature").fetchone()[0]
        n_coh = cur.execute("SELECT COUNT(*) FROM cohort").fetchone()[0]
    finally:
        # Release the connection even if loading fails part-way.
        conn.close()

    print("=" * 60)
    print(" DMAESentinel-Kor ingest 완료")
    print("=" * 60)
    print(f"  - reports rows         : {n_rep[0]} (n_reports total = {n_rep[1]})")
    print(f"  - literature rows      : {n_lit}")
    print(f"  - cohort rows          : {n_coh}")
    print(f"  - DB                   : {DB_PATH}")
    print(f"  - data dir             : {DATA}")
    print("  source filter          :", target)


def cmd_analyze(args):
    """Print a disproportionality table for every drug–AE pair in ``reports``.

    For each pair the pooled 2x2 statistics are computed, together with the
    number of spontaneous-report sources that individually flag the pair
    (``n_min=1``, ROR >= 2.0) and the number of matching literature rows.
    Rows are sorted by cross-source count, then ROR, both descending.

    Args:
        args: Parsed CLI namespace providing ``drug`` (optional filter),
            ``top`` (row limit; a falsy value disables the limit) and
            ``threshold`` (ROR cut-off for the ``sig`` column).

    Exits with status 2 when the database file does not exist.
    """
    if not DB_PATH.exists():
        print("[!] DB가 없습니다. 먼저 'ingest'를 실행하세요.")
        sys.exit(2)

    # `is None` rather than `or`: an explicit `--threshold 0` must be honoured
    # instead of being silently replaced by the default.
    ror_min = 2.0 if args.threshold is None else args.threshold

    conn = db_connect()
    try:
        cur = conn.cursor()

        if args.drug:
            drugs = [args.drug]
        else:
            drugs = [r[0] for r in cur.execute(
                "SELECT DISTINCT drug FROM reports ORDER BY drug").fetchall()]

        rows = []
        for drug in drugs:
            aes = [r[0] for r in cur.execute(
                "SELECT DISTINCT ae FROM reports WHERE drug=?", (drug,)).fetchall()]
            for ae in aes:
                a, b, c, d = compute_2x2_for_drug_ae(conn, drug, ae)
                stat = disprop_2x2(a, b, c, d)
                # Cross-source: sources that flag the pair on their own.
                sources_pos = []
                for s in ("faers", "eudra", "vigibase", "kaers"):
                    a_s, b_s, c_s, d_s = compute_2x2_for_drug_ae(conn, drug, ae, [s])
                    if a_s >= 1 and signal_flag(
                            disprop_2x2(a_s, b_s, c_s, d_s), n_min=1, ror_min=2.0):
                        sources_pos.append(s)
                lit_count = cur.execute(
                    "SELECT COUNT(*) FROM literature WHERE drug=? AND ae=?",
                    (drug, ae)).fetchone()[0]
                stat.update({
                    "drug": drug, "ae": ae,
                    "cross_source_n": len(sources_pos),
                    "cross_sources":  ",".join(sources_pos),
                    "literature_n":   lit_count,
                    "signal":         signal_flag(stat, ror_min=ror_min),
                })
                rows.append(stat)
    finally:
        # Release the connection even if a query fails.
        conn.close()

    rows.sort(key=lambda r: (-r["cross_source_n"], -r["ROR"]))
    if args.top:
        rows = rows[: args.top]

    print("=" * 110)
    print(" DMAESentinel-Kor disproportionality analysis")
    print(" 참고용·연구용. ROR/PRR/IC/EBGM은 spontaneous report bias가 큼. 임상 단독 근거 금지.")
    print("=" * 110)
    print(f"{'drug':22s}{'AE':28s}{'a':>4s}{'b':>6s}{'c':>6s}{'d':>7s}"
          f"{'ROR':>8s}{'95%CI':>16s}{'PRR':>8s}{'IC':>8s}{'EBGM':>8s}"
          f"{'src':>5s}{'lit':>5s}{'sig':>5s}")
    for r in rows:
        ci = f"{r['ROR_lo']:.2f}-{r['ROR_hi']:.2f}"
        sig = "*" if r["signal"] else " "
        print(f"{r['drug']:22.22s}{r['ae']:28.28s}"
              f"{r['a']:>4d}{r['b']:>6d}{r['c']:>6d}{r['d']:>7d}"
              f"{r['ROR']:>8.2f}{ci:>16s}{r['PRR']:>8.2f}"
              f"{r['IC']:>8.2f}{r['EBGM']:>8.2f}"
              f"{r['cross_source_n']:>5d}{r['literature_n']:>5d}{sig:>5s}")
    print(f"\n  signal threshold: ROR >= {ror_min}, n>=3, lower 95% CI >= 1")
    print(f"  rows displayed   : {len(rows)}")


def _kor_extrapolation(conn, drug: str):
    cur = conn.cursor()
    row = cur.execute(
        "SELECT bmi_mean, age_mean, kor_score FROM cohort WHERE drug=?", (drug,)
    ).fetchone()
    if not row:
        return None
    bmi, age, score = row
    note = []
    if bmi >= 28:
        note.append(f"BMI {bmi} (한국 평균 약 25 대비 높음 → 외삽 시 약효/이상반응 과대평가 우려)")
    elif bmi <= 23:
        note.append(f"BMI {bmi} (한국 평균보다 낮음)")
    if age >= 65:
        note.append(f"평균 연령 {age}세 (고령군 비중 ↑, 신기능·DDI 주의)")
    return {
        "BMI_mean": bmi, "age_mean": age,
        "kor_extrapolation_score": score,
        "notes": "; ".join(note) if note else "한국 외삽 호환성 양호 추정",
    }


def cmd_digest(args):
    """Print the weekly safety digest (Korean or English).

    Collects every drug–AE pair that passes ``signal_flag`` with default
    thresholds, ranks the pairs by number of individually positive sources
    and then by ROR, keeps the top 15 and prints them. The Korean variant
    also prints cohort extrapolation notes and cross-discipline referral
    suggestions.

    Args:
        args: Parsed CLI namespace providing the boolean ``korean``.

    Exits with status 2 when the database file does not exist.
    """
    if not DB_PATH.exists():
        print("[!] DB가 없습니다. 먼저 'ingest'를 실행하세요.")
        sys.exit(2)

    conn = db_connect()
    try:
        cur = conn.cursor()

        drugs = [r[0] for r in cur.execute(
            "SELECT DISTINCT drug FROM reports ORDER BY drug").fetchall()]

        findings = []
        for drug in drugs:
            aes = [r[0] for r in cur.execute(
                "SELECT DISTINCT ae FROM reports WHERE drug=?", (drug,)).fetchall()]
            for ae in aes:
                a, b, c, d = compute_2x2_for_drug_ae(conn, drug, ae)
                stat = disprop_2x2(a, b, c, d)
                if not signal_flag(stat):
                    continue
                # Sources that flag the pair on their own data.
                sources_pos = []
                for s in ("faers", "eudra", "vigibase", "kaers"):
                    a_s, b_s, c_s, d_s = compute_2x2_for_drug_ae(conn, drug, ae, [s])
                    if a_s >= 1 and signal_flag(
                            disprop_2x2(a_s, b_s, c_s, d_s), n_min=1, ror_min=2.0):
                        sources_pos.append(s)
                lit_count = cur.execute(
                    "SELECT COUNT(*) FROM literature WHERE drug=? AND ae=?",
                    (drug, ae)).fetchone()[0]
                findings.append((drug, ae, stat, sources_pos, lit_count))

        findings.sort(key=lambda x: (-len(x[3]), -x[2]["ROR"]))
        findings = findings[:15]

        drug_kor = {r[0]: r[1] for r in cur.execute("SELECT english, korean FROM drug_map")}
        ae_kor = {r[0]: r[1] for r in cur.execute("SELECT english, korean FROM ae_map")}

        today = datetime.now().strftime("%Y-%m-%d")

        out = []
        if args.korean:
            out.append("=" * 70)
            out.append(f" DMAESentinel-Kor 주간 안전성 다이제스트 ({today})")
            out.append("=" * 70)
            out.append("")
            out.append("[디스클레이머] 본 보고서는 합성 데이터 기반 연구·참고용 산출물입니다.")
            out.append("실제 임상 의사결정에 단독 근거로 사용하지 마십시오.")
            out.append("")
            out.append(f"이번 주 신규/주목 시그널 상위 {len(findings)}건")
            out.append("-" * 70)
            for i, (drug, ae, stat, srcs, lit) in enumerate(findings, 1):
                kdrug = drug_kor.get(drug, drug)
                kae = ae_kor.get(ae, ae)
                out.append(f"{i:>2d}. {kdrug} ({drug}) – {kae} ({ae})")
                out.append(f"     ROR={stat['ROR']:.2f} (95%CI {stat['ROR_lo']:.2f}-{stat['ROR_hi']:.2f}), "
                           f"PRR={stat['PRR']:.2f}, IC={stat['IC']:.2f}, EBGM={stat['EBGM']:.2f}")
                out.append(f"     n(보고)={stat['a']}, 교차 source={len(srcs)}개 [{', '.join(srcs) or '없음'}], "
                           f"문헌 n={lit}")
                kex = _kor_extrapolation(conn, drug)
                if kex:
                    out.append(f"     한국 외삽: 점수 {kex['kor_extrapolation_score']:.1f}/100 – {kex['notes']}")
                out.append("")
            out.append("[Cross-discipline 알림 후보]")
            for drug, ae, stat, srcs, lit in findings:
                kdrug = drug_kor.get(drug, drug)
                if ae in ("NAION", "suicidal ideation"):
                    out.append(f"  - 안과/정신과 협진 권고: {kdrug} ↔ {ae_kor.get(ae,ae)}")
                if ae == "sarcopenia":
                    out.append(f"  - 재활/노인의학 협진 권고: {kdrug} ↔ 근감소증")
                if ae in ("DKA", "euglycemic DKA"):
                    out.append(f"  - 응급/마취과 알림: {kdrug} ↔ DKA (수술 전 휴약 검토)")
                if ae == "Fournier gangrene":
                    out.append(f"  - 비뇨기과/외과 알림: {kdrug} ↔ 푸르니에 괴저")
                if ae == "hyperkalemia":
                    out.append(f"  - 신장내과 협진: {kdrug} ↔ 고칼륨혈증")
            out.append("")
            out.append("문의: pharmacovigilance@hospital.local")
        else:
            out.append("=" * 70)
            out.append(f" DMAESentinel-Kor weekly safety digest ({today})")
            out.append("=" * 70)
            out.append("DISCLAIMER: research/reference use only on synthetic data.")
            out.append("")
            for i, (drug, ae, stat, srcs, lit) in enumerate(findings, 1):
                out.append(f"{i:>2d}. {drug} – {ae} | ROR={stat['ROR']:.2f} "
                           f"(CI {stat['ROR_lo']:.2f}-{stat['ROR_hi']:.2f}) "
                           f"IC={stat['IC']:.2f} sources={len(srcs)} lit={lit}")
    finally:
        # Release the connection even if a query or formatting step fails.
        conn.close()

    print("\n".join(out))


def cmd_psur(args):
    """Print an automatically generated PSUR draft (Forms 1–3) for one drug.

    Form 1 summarises report totals and the site cohort, Form 2 lists the ten
    AEs with the highest ROR, and Form 3 lists those of the ten that pass
    ``signal_flag`` with default thresholds.

    Args:
        args: Parsed CLI namespace providing ``drug`` (English drug name).

    Exits with status 2 when the database file does not exist or the drug is
    not present in ``drug_map``.
    """
    if not DB_PATH.exists():
        print("[!] DB가 없습니다. 먼저 'ingest'를 실행하세요.")
        sys.exit(2)

    conn = db_connect()
    try:
        cur = conn.cursor()
        drug = args.drug
        drow = cur.execute(
            "SELECT english, korean, mfds, atc, class FROM drug_map WHERE english=?", (drug,)
        ).fetchone()
        if not drow:
            print(f"[!] 약물 '{drug}'을(를) 찾을 수 없습니다. (예: empagliflozin, semaglutide)")
            # SystemExit propagates through the finally block, which closes
            # the connection.
            sys.exit(2)

        aes = [r[0] for r in cur.execute(
            "SELECT DISTINCT ae FROM reports WHERE drug=?", (drug,)).fetchall()]
        rows = []
        total_n = 0
        for ae in aes:
            a, b, c, d = compute_2x2_for_drug_ae(conn, drug, ae)
            rows.append((ae, disprop_2x2(a, b, c, d)))
            total_n += a
        rows.sort(key=lambda x: -x[1]["ROR"])

        cohort = cur.execute(
            "SELECT n_patients, n_AE_events, bmi_mean, age_mean, kor_score "
            "FROM cohort WHERE drug=?", (drug,)).fetchone()
        ae_kor = {r[0]: r[1] for r in cur.execute("SELECT english, korean FROM ae_map")}

        today = datetime.now().strftime("%Y-%m-%d")
        print("=" * 78)
        print(f" 식품의약품안전처 정기적 약물감시보고서(PSUR) 자동 초안 - {drow['korean']} ({drug})")
        print("=" * 78)
        print(f"  생성일                  : {today}")
        print(f"  MFDS code               : {drow['mfds']}")
        print(f"  ATC                     : {drow['atc']}")
        print(f"  계열                    : {drow['class']}")
        print()
        print("[Form 1] 보고기간 요약")
        print("-" * 78)
        print("  보고기간                : 2025-Q1 ~ 2026-Q1 (mock)")
        print(f"  국내·외 자발보고 합계   : {total_n} (FAERS+EudraVigilance+VigiBase+KAERS)")
        if cohort:
            print(f"  사이트 cohort 환자 수   : {cohort[0]}")
            print(f"  사이트 cohort AE 건수   : {cohort[1]}")
            print(f"  cohort BMI mean         : {cohort[2]}")
            print(f"  cohort age mean         : {cohort[3]}")
            print(f"  한국 외삽 호환성 점수   : {cohort[4]:.1f}/100")
        print()
        print("[Form 2] 신규/중요 안전성 정보 (top 10 by ROR)")
        print("-" * 78)
        print(f"  {'AE_term':30s}{'a':>4s}{'ROR':>8s}{'95%CI':>16s}{'IC':>8s}{'signal':>8s}")
        flagged = []
        for ae, st in rows[:10]:
            ci = f"{st['ROR_lo']:.2f}-{st['ROR_hi']:.2f}"
            sig = "Y" if signal_flag(st) else "N"
            if sig == "Y":
                flagged.append((ae, ae_kor.get(ae, ae), st))
            print(f"  {ae:30.30s}{st['a']:>4d}{st['ROR']:>8.2f}{ci:>16s}"
                  f"{st['IC']:>8.2f}{sig:>8s}")
        print()
        print("[Form 3] 시그널 평가 및 권고 (한국어)")
        print("-" * 78)
        if flagged:
            for ae, kae, st in flagged:
                print(f"  - {kae} ({ae}): ROR {st['ROR']:.2f} (95%CI "
                      f"{st['ROR_lo']:.2f}-{st['ROR_hi']:.2f}), 추가 평가 권고.")
        else:
            print("  - 본 기간 내 임계치(ROR≥2, n≥3, 하한 CI≥1) 초과 시그널 없음.")
        print()
        print("[디스클레이머]")
        print("  본 PSUR 초안은 합성 데이터 및 disproportionality 통계 기반 자동 생성물입니다.")
        print("  실제 식약처 제출 전 한국의약품안전관리원·약물역학팀 검토를 거쳐야 합니다.")
        print("  연구·참고용. 임상 단독 의사결정 근거 금지.")
    finally:
        # Release the connection on every exit path, including errors.
        conn.close()


def cmd_alert(args):
    """Print cross-source alerts, optionally with referral departments.

    Every drug–AE pair that passes ``signal_flag`` with default thresholds
    becomes an alert candidate. Candidates are ranked by the number of
    individually positive sources and then by ROR; the first 25 are printed.

    Args:
        args: Parsed CLI namespace providing the boolean
            ``cross_discipline``; when true each line also names the
            department mapped to the AE (``-`` when unmapped).

    Exits with status 2 when the database file does not exist.
    """
    if not DB_PATH.exists():
        print("[!] DB가 없습니다. 먼저 'ingest'를 실행하세요.")
        sys.exit(2)

    # AE term -> department to consult (user-facing Korean labels).
    discipline_map = {
        "NAION": "안과",
        "suicidal ideation": "정신건강의학과",
        "sarcopenia": "재활의학/노인의학",
        "DKA": "응급/마취",
        "euglycemic DKA": "응급/마취",
        "Fournier gangrene": "비뇨기/일반외과",
        "hyperkalemia": "신장내과",
        "lower limb amputation": "혈관외과",
        "thyroid C-cell tumor": "내분비외과",
        "pancreatitis": "소화기내과",
        "biliary disease": "소화기내과",
        "gastric stasis": "소화기내과",
        "severe hypoglycemia": "응급/내분비",
    }

    conn = db_connect()
    try:
        cur = conn.cursor()

        drug_kor = {r[0]: r[1] for r in cur.execute("SELECT english, korean FROM drug_map")}
        ae_kor = {r[0]: r[1] for r in cur.execute("SELECT english, korean FROM ae_map")}

        drugs = [r[0] for r in cur.execute("SELECT DISTINCT drug FROM reports").fetchall()]
        alerts = []
        for drug in drugs:
            aes = [r[0] for r in cur.execute(
                "SELECT DISTINCT ae FROM reports WHERE drug=?", (drug,)).fetchall()]
            for ae in aes:
                a, b, c, d = compute_2x2_for_drug_ae(conn, drug, ae)
                stat = disprop_2x2(a, b, c, d)
                if not signal_flag(stat):
                    continue
                # Sources that flag the pair on their own data.
                srcs = []
                for s in ("faers", "eudra", "vigibase", "kaers"):
                    a_s, b_s, c_s, d_s = compute_2x2_for_drug_ae(conn, drug, ae, [s])
                    if a_s >= 1 and signal_flag(disprop_2x2(a_s, b_s, c_s, d_s),
                                                n_min=1, ror_min=2.0):
                        srcs.append(s)
                alerts.append((drug, ae, stat, srcs))
    finally:
        # Release the connection even if a query fails.
        conn.close()

    alerts.sort(key=lambda x: (-len(x[3]), -x[2]["ROR"]))

    print("=" * 78)
    print(" DMAESentinel-Kor cross-source / cross-discipline alert")
    print("=" * 78)
    if args.cross_discipline:
        print(" [모드] cross-discipline (협진 의뢰 후보 표시)")
    print(f" 검출된 시그널 후보       : {len(alerts)}")
    print()
    for drug, ae, st, srcs in alerts[:25]:
        kdrug = drug_kor.get(drug, drug)
        kae = ae_kor.get(ae, ae)
        disc = discipline_map.get(ae, "-")
        line = (f" - {kdrug} ({drug}) ↔ {kae} ({ae}) | "
                f"ROR {st['ROR']:.2f} (95%CI {st['ROR_lo']:.2f}-{st['ROR_hi']:.2f}) "
                f"| sources={len(srcs)} [{','.join(srcs) or '-'}]")
        if args.cross_discipline:
            line += f" | 협진: {disc}"
        print(line)
    print()
    print("[디스클레이머] 합성 데이터 기반 연구·참고용. 실 임상 알림으로 사용 금지.")


# ----------------------------------------------------------------------------
# 6) argparse
# ----------------------------------------------------------------------------

def build_parser():
    """Create the top-level argument parser with all five subcommands.

    Returns:
        An ``argparse.ArgumentParser`` whose parsed namespace carries the
        chosen subcommand in ``cmd`` (a subcommand is required).
    """
    overview = textwrap.dedent("""\
            DMAESentinel-Kor — 한국 당뇨 약물 post-marketing AE / safety signal 도구.

            FAERS / EudraVigilance / VigiBase / KAERS / 사이트 cohort / PubMed case report /
            학회 abstract → ROR/PRR/IC/EBGM disproportionality + 한국어 weekly digest +
            MFDS PSUR Form 자동 초안.

            참고용·연구용. 임상 단독 의사결정 근거로 사용 금지.
        """)
    parser = argparse.ArgumentParser(
        prog="dmae-sentinel-kor",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        description=overview,
    )
    commands = parser.add_subparsers(dest="cmd", required=True)

    # ingest
    ingest_cmd = commands.add_parser("ingest", help="합성 mock data 생성 + SQLite 적재")
    ingest_cmd.add_argument(
        "--source",
        choices=SOURCES + ["all"],
        default="all",
        help="적재할 source (default: all)",
    )

    # analyze
    analyze_cmd = commands.add_parser(
        "analyze", help="ROR/PRR/IC/EBGM disproportionality 산출")
    analyze_cmd.add_argument("--drug", help="특정 약물(영문) 한정")
    analyze_cmd.add_argument("--top", type=int, default=20, help="상위 N (default 20)")
    analyze_cmd.add_argument(
        "--threshold", type=float, default=2.0,
        help="ROR signal threshold (default 2.0)",
    )

    # digest
    digest_cmd = commands.add_parser("digest", help="weekly safety digest 출력")
    digest_cmd.add_argument(
        "--korean", action="store_true", help="한국어 출력 (기본 켜기 권장)")

    # psur
    psur_cmd = commands.add_parser("psur", help="MFDS PSUR Form 1/2/3 자동 draft 출력")
    psur_cmd.add_argument(
        "--drug", required=True, help="약물 영문명 (예: empagliflozin)")

    # alert
    alert_cmd = commands.add_parser(
        "alert", help="cross-source / cross-discipline alert 시뮬레이션")
    alert_cmd.add_argument(
        "--cross-discipline", action="store_true",
        help="cross-discipline 협진 후보 함께 표시",
    )

    return parser


def main(argv=None):
    """Parse the command line and run the selected subcommand.

    Args:
        argv: Argument list to parse; ``None`` makes argparse read
            ``sys.argv[1:]``.
    """
    args = build_parser().parse_args(argv)
    handlers = {
        "ingest": cmd_ingest,
        "analyze": cmd_analyze,
        "digest": cmd_digest,
        "psur": cmd_psur,
        "alert": cmd_alert,
    }
    handler = handlers.get(args.cmd)
    if handler is not None:
        handler(args)


# Script entry point: run the CLI only when executed directly, not on import.
if __name__ == "__main__":
    main()
