"""GLP-1RA class-특이 signal panel 7개 로직.

참고용·연구용 — Not for clinical decision. Atlanta classification은 단순 mock 구현이며
실제 임상 판정은 영상/검사/조직 종합 후 가능하다.
"""
from __future__ import annotations

from collections import defaultdict
from dataclasses import dataclass, field
from typing import Any


# Display label for each signal panel, keyed by panel identifier.
# The insertion order of this mapping is also the panel order used by PANELS.
PANEL_LABEL_KO = {
    "pancreatitis": "췌장염",
    "thyroid": "갑상선",
    "gallbladder": "담낭염/담석",
    "gi": "GI AE",
    "injection_site": "주사부위",
    "retinopathy": "망막병증",
    "suicidality": "Suicidality",
}

# Panel identifiers, in the same order as the label mapping above.
PANELS = list(PANEL_LABEL_KO)


@dataclass
class PanelResult:
    """Aggregated adverse-event summary for a single signal panel.

    Instances are created and filled in by the ``panel_*`` functions of this
    module; ``detail`` carries whatever extra, panel-specific figures each of
    those functions computes.
    """

    # Panel identifier, e.g. "pancreatitis".
    panel: str
    # Display label for the panel.
    label_ko: str
    # arm -> number of AE rows belonging to this panel.
    counts_by_arm: dict[str, int] = field(default_factory=dict)
    # Number of panel rows whose `serious` flag is truthy.
    n_serious: int = 0
    early_count: int = 0  # onset_week < 12
    late_count: int = 0   # onset_week >= 12
    # Panel-specific extras (free-form keys set by the panel functions).
    detail: dict[str, Any] = field(default_factory=dict)


def _subject_followup_total(ae_rows: list[dict]) -> dict[str, float]:
    """Sum follow-up person-weeks per arm, counting each subject only once.

    Only the first row encountered for a given (arm, subject) pair
    contributes; a missing or falsy ``followup_weeks`` value counts as 52.

    Args:
        ae_rows: Row dicts providing ``subject_id_hash`` and ``arm`` and,
            optionally, ``followup_weeks``.

    Returns:
        A ``defaultdict(float)`` mapping arm -> total follow-up weeks.
    """
    totals: dict[str, float] = defaultdict(float)
    counted: set[tuple[str, str]] = set()
    for row in ae_rows:
        subject = row["subject_id_hash"]
        key = (row["arm"], subject)
        if key not in counted:
            counted.add(key)
            totals[row["arm"]] += float(row.get("followup_weeks", 52) or 52)
    return totals


def panel_pancreatitis(ae_rows: list[dict], lab_rows: list[dict]) -> PanelResult:
    """Build the pancreatitis panel: mock Atlanta grading plus a lipase signal.

    Reference/research use only - not for clinical decisions.

    Mock Atlanta classification rules applied to every pancreatitis AE row:
      severe   : CTCAE severity >= 5, or outcome is 'fatal' or 'sequelae'
      moderate : CTCAE severity 3 or 4 and not severe
      mild     : everything else

    Args:
        ae_rows: AE row dicts; only rows whose ``panel`` is ``"pancreatitis"``
            are tallied. ``severity_ctcae``, ``outcome``, ``serious`` and
            ``onset_week`` are optional; ``arm`` is required.
        lab_rows: Lab row dicts; ``lipase_U_L`` and (optionally)
            ``lipase_uln`` are read. Rows without a lipase value, or with
            values that cannot be parsed as numbers, are skipped.

    Returns:
        A PanelResult with per-arm counts, serious/early/late tallies,
        ``detail["atlanta"]`` (grade -> count) and
        ``detail["lipase_3xULN_subjects"]`` (arm -> number of distinct
        subjects with lipase >= 3x ULN at any timepoint).
    """
    res = PanelResult(panel="pancreatitis", label_ko="췌장염")
    atlanta = {"mild": 0, "moderate": 0, "severe": 0}
    res.detail["atlanta"] = atlanta
    for r in ae_rows:
        if r["panel"] != "pancreatitis":
            continue
        sev = int(r.get("severity_ctcae", 1) or 1)
        outcome = (r.get("outcome") or "").lower()
        if sev >= 5 or outcome in ("fatal", "sequelae"):
            grade = "severe"
        elif sev in (3, 4):
            grade = "moderate"
        else:
            grade = "mild"
        atlanta[grade] += 1
        res.counts_by_arm[r["arm"]] = res.counts_by_arm.get(r["arm"], 0) + 1
        if int(r.get("serious", 0) or 0):
            res.n_serious += 1
        # A missing onset_week is treated as week 0, i.e. counted as early.
        if int(r.get("onset_week", 0) or 0) < 12:
            res.early_count += 1
        else:
            res.late_count += 1

    # Lab signal: subjects with lipase >= 3x ULN at any timepoint.
    default_uln = 160.0
    over_uln: dict[str, set[str]] = defaultdict(set)
    for lab in lab_rows:
        if not lab.get("lipase_U_L"):
            continue
        try:
            lip = float(lab["lipase_U_L"])
            uln = float(lab.get("lipase_uln") or default_uln)
        except (TypeError, ValueError):
            continue
        # A numeric 0 ULN already falls back to the default via `or`; a
        # string such as "0" or a negative value would otherwise make every
        # positive lipase reading look like >= 3x ULN, so treat any
        # non-positive ULN as missing as well.
        if uln <= 0:
            uln = default_uln
        if lip >= 3 * uln:
            over_uln[lab["arm"]].add(lab["subject_id_hash"])
    res.detail["lipase_3xULN_subjects"] = {arm: len(s) for arm, s in over_uln.items()}
    return res


def panel_thyroid(ae_rows: list[dict], lab_rows: list[dict]) -> PanelResult:
    """Build the thyroid panel: AE tallies plus calcitonin threshold flags.

    Args:
        ae_rows: AE row dicts; only rows whose ``panel`` is ``"thyroid"``
            are tallied.
        lab_rows: Lab row dicts; ``calcitonin_pg_mL`` is read. Rows where it
            is ``None``, empty, or not parseable as a number are skipped.

    Returns:
        A PanelResult with per-arm counts and serious/early/late tallies.
        ``detail["calcitonin_ge50"]`` and ``detail["us_trigger_ge20"]`` map
        arm -> number of distinct subjects with any calcitonin value >= 50
        and >= 20 respectively.
    """
    result = PanelResult(panel="thyroid", label_ko="갑상선")
    for row in ae_rows:
        if row["panel"] != "thyroid":
            continue
        arm = row["arm"]
        result.counts_by_arm[arm] = result.counts_by_arm.get(arm, 0) + 1
        if int(row.get("serious", 0) or 0):
            result.n_serious += 1
        onset = int(row.get("onset_week", 0) or 0)
        if onset >= 12:
            result.late_count += 1
        else:
            result.early_count += 1

    # Distinct subjects per arm reaching each calcitonin threshold.
    # NOTE(review): "us_trigger" presumably means an ultrasound trigger - confirm.
    at_least_50: dict[str, set[str]] = defaultdict(set)
    at_least_20: dict[str, set[str]] = defaultdict(set)
    for lab in lab_rows:
        raw = lab.get("calcitonin_pg_mL")
        if raw in (None, ""):
            continue
        try:
            value = float(raw)
        except (TypeError, ValueError):
            continue
        if value >= 20:
            at_least_20[lab["arm"]].add(lab["subject_id_hash"])
            # Any value >= 50 also satisfies the >= 20 threshold.
            if value >= 50:
                at_least_50[lab["arm"]].add(lab["subject_id_hash"])
    result.detail["calcitonin_ge50"] = {
        arm: len(subjects) for arm, subjects in at_least_50.items()
    }
    result.detail["us_trigger_ge20"] = {
        arm: len(subjects) for arm, subjects in at_least_20.items()
    }
    return result


def _generic_panel(ae_rows: list[dict], panel: str, label_ko: str) -> PanelResult:
    """Tally the AE rows of one panel into a fresh PanelResult.

    Args:
        ae_rows: AE row dicts; only rows whose ``panel`` equals *panel* count.
        panel: Panel identifier to select.
        label_ko: Display label stored on the result.

    Returns:
        A PanelResult with per-arm row counts, the number of rows with a
        truthy ``serious`` flag, and early (< 12) / late (>= 12) onset-week
        counts. A missing ``onset_week`` is treated as week 0.
    """
    summary = PanelResult(panel=panel, label_ko=label_ko)
    per_arm = summary.counts_by_arm
    for row in (candidate for candidate in ae_rows if candidate["panel"] == panel):
        arm = row["arm"]
        per_arm[arm] = per_arm.get(arm, 0) + 1
        if int(row.get("serious", 0) or 0):
            summary.n_serious += 1
        onset = int(row.get("onset_week", 0) or 0)
        if onset >= 12:
            summary.late_count += 1
        else:
            summary.early_count += 1
    return summary


def panel_gallbladder(ae_rows: list[dict]) -> PanelResult:
    """Build the gallbladder panel and count cholecystectomy rows.

    Returns:
        The generic gallbladder tallies, with ``detail["cholecystectomy"]``
        set to the number of gallbladder rows whose ``pt_term`` is exactly
        ``"Cholecystectomy"``.
    """
    result = _generic_panel(ae_rows, "gallbladder", "담낭염/담석")
    surgeries = 0
    for row in ae_rows:
        if row["panel"] == "gallbladder" and row.get("pt_term") == "Cholecystectomy":
            surgeries += 1
    result.detail["cholecystectomy"] = surgeries
    return result


def panel_gi(ae_rows: list[dict]) -> PanelResult:
    """Build the GI panel with dose-interruption rates by arm and by site.

    A GI row counts as an interruption when its ``action_taken`` is one of
    ``dose_reduced``, ``dose_interrupted`` or ``withdrawn``.

    Returns:
        The generic GI tallies plus ``detail["interruption_rate_by_arm"]``
        (arm -> interrupted rows / GI rows). When at least one site has 5 or
        more GI rows, ``site_interruption_min`` / ``_max`` / ``_range`` hold
        the spread of per-site rates across those sites, rounded to 3 places.
    """
    summary = _generic_panel(ae_rows, "gi", "GI AE")
    stop_actions = {"dose_reduced", "dose_interrupted", "withdrawn"}

    arm_events: dict[str, int] = defaultdict(int)
    arm_stops: dict[str, int] = defaultdict(int)
    site_events: dict[str, int] = defaultdict(int)
    site_stops: dict[str, int] = defaultdict(int)
    for row in ae_rows:
        if row["panel"] != "gi":
            continue
        arm = row["arm"]
        site = row["site_mask"]
        arm_events[arm] += 1
        site_events[site] += 1
        if row.get("action_taken") in stop_actions:
            arm_stops[arm] += 1
            site_stops[site] += 1

    rate_by_arm = {}
    for arm, n_events in arm_events.items():
        rate_by_arm[arm] = arm_stops[arm] / n_events if n_events else 0.0
    summary.detail["interruption_rate_by_arm"] = rate_by_arm

    # Site heterogeneity: only sites with at least 5 GI rows are compared.
    eligible_rates = [
        site_stops[site] / n_events
        for site, n_events in site_events.items()
        if n_events >= 5
    ]
    if eligible_rates:
        lowest = min(eligible_rates)
        highest = max(eligible_rates)
        summary.detail["site_interruption_min"] = round(lowest, 3)
        summary.detail["site_interruption_max"] = round(highest, 3)
        summary.detail["site_interruption_range"] = round(highest - lowest, 3)
    return summary


def panel_injection_site(ae_rows: list[dict]) -> PanelResult:
    """Build the injection-site panel (generic tallies only, no extra detail)."""
    result = _generic_panel(ae_rows, panel="injection_site", label_ko="주사부위")
    return result


def panel_retinopathy(ae_rows: list[dict]) -> PanelResult:
    """Build the retinopathy panel with the T2DM share of events per arm.

    Returns:
        The generic retinopathy tallies plus
        ``detail["t2dm_proportion_by_arm"]`` (arm -> fraction of retinopathy
        rows with a truthy ``t2dm`` flag) and a fixed
        ``detail["historical_signal"]`` reference string.
    """
    result = _generic_panel(ae_rows, "retinopathy", "망막병증")

    # arm -> [rows flagged t2dm, all retinopathy rows]
    tallies: dict[str, list[int]] = {}
    for row in ae_rows:
        if row["panel"] != "retinopathy":
            continue
        pair = tallies.setdefault(row["arm"], [0, 0])
        pair[1] += 1
        if int(row.get("t2dm", 0) or 0):
            pair[0] += 1

    result.detail["t2dm_proportion_by_arm"] = {
        arm: (flagged / total if total else 0.0)
        for arm, (flagged, total) in tallies.items()
    }
    result.detail["historical_signal"] = "SUSTAIN-6 (Marso 2016)"
    return result


def panel_suicidality(ae_rows: list[dict]) -> PanelResult:
    """Build the suicidality panel and count SUSAR candidates.

    Returns:
        The generic suicidality tallies, with ``detail["susar_candidate"]``
        set to the number of suicidality rows whose ``serious`` flag equals 1.
    """
    result = _generic_panel(ae_rows, "suicidality", "Suicidality")
    # NOTE(review): only the serious flag is checked here; no onset-based
    # filter is applied - confirm whether one is intended for SUSAR candidates.
    candidates = [
        row
        for row in ae_rows
        if row["panel"] == "suicidality" and int(row.get("serious", 0) or 0) == 1
    ]
    result.detail["susar_candidate"] = len(candidates)
    return result


def compute_all_panels(ae_rows: list[dict], lab_rows: list[dict]) -> dict[str, PanelResult]:
    """Run every signal panel and return the results keyed by panel identifier.

    Args:
        ae_rows: AE row dicts passed to every panel.
        lab_rows: Lab row dicts, used by the pancreatitis and thyroid panels.

    Returns:
        A dict of panel identifier -> PanelResult, in the order pancreatitis,
        thyroid, gallbladder, gi, injection_site, retinopathy, suicidality.
    """
    results: dict[str, PanelResult] = {
        "pancreatitis": panel_pancreatitis(ae_rows, lab_rows),
        "thyroid": panel_thyroid(ae_rows, lab_rows),
    }
    # The remaining panels only need the AE rows.
    ae_only_builders = (
        ("gallbladder", panel_gallbladder),
        ("gi", panel_gi),
        ("injection_site", panel_injection_site),
        ("retinopathy", panel_retinopathy),
        ("suicidality", panel_suicidality),
    )
    for key, build in ae_only_builders:
        results[key] = build(ae_rows)
    return results


def panel_incidence(ae_rows: list[dict], panel: str) -> dict[str, dict]:
    """Compute per-arm incidence of a panel among subjects present in *ae_rows*.

    The denominator is a rough N: the number of distinct ``subject_id_hash``
    values that appear in *ae_rows* for the arm, not an enrolment count.

    Args:
        ae_rows: AE row dicts with ``arm``, ``subject_id_hash`` and ``panel``.
        panel: Panel identifier whose cases are counted.

    Returns:
        arm -> ``{"n_subjects", "n_cases", "incidence"}``, where ``n_cases``
        is the number of distinct subjects with at least one row in *panel*
        and ``incidence`` is ``n_cases / n_subjects``.
    """
    everyone: dict[str, set[str]] = {}
    affected: dict[str, set[str]] = {}
    for row in ae_rows:
        arm = row["arm"]
        subject = row["subject_id_hash"]
        everyone.setdefault(arm, set()).add(subject)
        if row["panel"] == panel:
            affected.setdefault(arm, set()).add(subject)

    summary: dict[str, dict] = {}
    for arm, subjects in everyone.items():
        total = len(subjects)
        hits = len(affected.get(arm, ()))
        summary[arm] = {
            "n_subjects": total,
            "n_cases": hits,
            "incidence": hits / total if total else 0.0,
        }
    return summary
