"""IRR + 95% CI + subgroup / dose-response / onset 분포.

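For reference/research use only — not for clinical decision-making.

IRR (Incidence Rate Ratio):
   IRR = (cases_drug / PT_drug) / (cases_ref / PT_ref)
   SE(log IRR) ≈ sqrt(1/cases_drug + 1/cases_ref)   (Poisson approximation;
   a zero case count is replaced by 0.5 before computing IRR and SE)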
"""
from __future__ import annotations

import math
from collections import defaultdict
from dataclasses import dataclass


@dataclass
class IRR:
    """Incidence rate ratio of one AE panel, drug arm vs reference arm.

    ``irr``, ``lci`` and ``uci`` may be NaN (``compute_irr`` stores NaN when an
    arm has no person-time); ``to_dict`` reports those as ``None``.
    """

    panel: str
    drug_arm: str
    ref_arm: str
    cases_drug: int  # number of case rows counted for the drug arm
    pt_drug_weeks: float  # person-weeks of follow-up in the drug arm
    cases_ref: int  # number of case rows counted for the reference arm
    pt_ref_weeks: float  # person-weeks of follow-up in the reference arm
    irr: float  # point estimate
    lci: float  # lower confidence bound
    uci: float  # upper confidence bound

    def to_dict(self) -> dict:
        """Return a plain, rounded dict; NaN estimates become ``None``."""

        def _ratio(value):
            # Person-time is rounded to 0.1, ratios to 3 decimals.
            if math.isnan(value):
                return None
            return round(value, 3)

        record = {
            "panel": self.panel,
            "drug_arm": self.drug_arm,
            "ref_arm": self.ref_arm,
            "cases_drug": self.cases_drug,
            "pt_drug_weeks": round(self.pt_drug_weeks, 1),
            "cases_ref": self.cases_ref,
            "pt_ref_weeks": round(self.pt_ref_weeks, 1),
            "IRR": _ratio(self.irr),
            "IRR_LCI": _ratio(self.lci),
            "IRR_UCI": _ratio(self.uci),
        }
        return record


def _arm_person_weeks(ae_rows: list[dict]) -> dict[str, float]:
    """arm별 총 person-weeks (subject 중복 제거)."""
    seen: dict[str, set[str]] = defaultdict(set)
    weeks: dict[str, float] = defaultdict(float)
    for r in ae_rows:
        sid = r["subject_id_hash"]
        if sid in seen[r["arm"]]:
            continue
        seen[r["arm"]].add(sid)
        try:
            weeks[r["arm"]] += float(r.get("followup_weeks", 52) or 52)
        except (ValueError, TypeError):
            weeks[r["arm"]] += 52
    return weeks


def compute_irr(panel: str, ae_rows: list[dict],
                drug_arm: str, ref_arm: str = "placebo",
                z: float = 1.96) -> IRR:
    """Compute the incidence rate ratio of ``drug_arm`` vs ``ref_arm`` for one panel.

    Cases are counted per row: every row whose ``panel`` and ``arm`` match
    counts once, so repeated rows for a subject all count. Person-time comes
    from ``_arm_person_weeks`` over all rows.

    Args:
        panel: AE panel whose rows are counted as cases.
        ae_rows: row dicts with at least ``panel``, ``arm`` and
            ``subject_id_hash`` keys.
        drug_arm: numerator arm.
        ref_arm: denominator (reference) arm.
        z: normal quantile for the Wald interval on log(IRR); the default
            1.96 gives a 95% CI.

    Returns:
        An ``IRR`` record. ``irr``/``lci``/``uci`` are NaN when either arm
        lacks positive, finite person-time.
    """
    pw = _arm_person_weeks(ae_rows)
    cases_drug = sum(1 for r in ae_rows if r["panel"] == panel and r["arm"] == drug_arm)
    cases_ref = sum(1 for r in ae_rows if r["panel"] == panel and r["arm"] == ref_arm)
    pw_drug = pw.get(drug_arm, 0.0)
    pw_ref = pw.get(ref_arm, 0.0)
    # The ratio is undefined unless both denominators are positive and finite.
    # The chained comparison is also False for NaN. Negative or infinite
    # person-time would otherwise crash in math.log / raise ZeroDivisionError below.
    if not (0 < pw_drug < math.inf and 0 < pw_ref < math.inf):
        nan = float("nan")
        return IRR(panel, drug_arm, ref_arm, cases_drug, pw_drug, cases_ref, pw_ref,
                   nan, nan, nan)
    # Continuity correction: a zero count is replaced by 0.5 so the ratio and
    # its log-SE stay finite. Only the zero arm is corrected.
    # NOTE(review): the Haldane variant adds 0.5 to both arms — confirm intent.
    a = cases_drug if cases_drug > 0 else 0.5
    b = cases_ref if cases_ref > 0 else 0.5
    irr = (a / pw_drug) / (b / pw_ref)
    # Wald interval on the log scale using the Poisson SE of log(IRR).
    log_irr = math.log(irr)
    half_width = z * math.sqrt(1 / a + 1 / b)
    lci = math.exp(log_irr - half_width)
    uci = math.exp(log_irr + half_width)
    return IRR(panel, drug_arm, ref_arm, cases_drug, pw_drug, cases_ref, pw_ref,
               irr, lci, uci)


def all_panel_irrs(ae_rows: list[dict], panels: list[str],
                   drug_arms: list[str], ref_arm: str = "placebo") -> list[IRR]:
    """Compute one IRR per (panel, drug arm) pair against ``ref_arm``.

    Results are ordered panel-major: every drug arm for the first panel, then
    every drug arm for the next panel, and so on.
    """
    return [
        compute_irr(panel_name, ae_rows, arm_name, ref_arm)
        for panel_name in panels
        for arm_name in drug_arms
    ]


def onset_distribution(ae_rows: list[dict], panel: str,
                       early_cutoff_wk: int = 12) -> dict[str, dict[str, int]]:
    """Count early vs late onsets of one AE panel, per arm.

    A row of ``panel`` is "early" when its ``onset_week`` is strictly below
    ``early_cutoff_wk``, otherwise "late". Missing, unparseable or NaN onset
    weeks are treated as week 0 (i.e. counted as early).

    Returns:
        ``{arm: {"early": n_early, "late": n_late}}`` for every arm that has
        at least one row of ``panel``.
    """
    out: dict[str, dict[str, int]] = defaultdict(lambda: {"early": 0, "late": 0})
    for r in ae_rows:
        if r["panel"] != panel:
            continue
        try:
            # float(), not int(): int("12.0") raises, which previously sent
            # decimal-formatted weeks to the week-0 fallback (always "early").
            ow = float(r.get("onset_week", 0) or 0)
        except (ValueError, TypeError):
            ow = 0.0
        if math.isnan(ow):
            ow = 0.0
        bucket = "early" if ow < early_cutoff_wk else "late"
        out[r["arm"]][bucket] += 1
    return dict(out)


def dose_response(ae_rows: list[dict], panel: str, arm: str) -> dict:
    """Subject-level incidence of one AE panel by ``dose_mg`` within one arm.

    Subjects are identified by ``subject_id_hash``. A subject is counted in a
    dose's denominator for every dose its rows report, and as a case if any
    of its rows at that dose belongs to ``panel``. Missing, unparseable or NaN
    doses are grouped under 0.0.

    Returns:
        ``{dose: {"n": subjects, "cases": case subjects, "incidence": cases / n}}``
        with doses in ascending order.
    """
    sub_by_dose: dict[float, set[str]] = defaultdict(set)
    case_by_dose: dict[float, set[str]] = defaultdict(set)
    for r in ae_rows:
        if r["arm"] != arm:
            continue
        try:
            dose = float(r.get("dose_mg") or 0)
        except (ValueError, TypeError):
            dose = 0.0
        # NaN != NaN, so each NaN dose would otherwise become its own dict key
        # (splitting subjects into n=1 groups) and make the sort order undefined.
        if math.isnan(dose):
            dose = 0.0
        sub_by_dose[dose].add(r["subject_id_hash"])
        if r["panel"] == panel:
            case_by_dose[dose].add(r["subject_id_hash"])
    out = {}
    for dose, subs in sorted(sub_by_dose.items()):
        n = len(subs)
        cases = len(case_by_dose.get(dose, set()))
        out[dose] = {"n": n, "cases": cases, "incidence": cases / n if n else 0.0}
    return out


def subgroup_incidence(ae_rows: list[dict], panel: str,
                       subgroup_key: str) -> dict[str, dict]:
    """Subject-level incidence of one AE panel, stratified by arm and subgroup.

    subgroup_key examples:
        baseline_bmi -> binned (<35, 35-40, >=40)
        age          -> binned (<50, 50-64, >=65)
        sex          -> M/F
        t2dm         -> 0/1
        prior_pancreatitis -> 0/1
    Any other key buckets on the raw string value of that field ("?" if absent).

    Missing, unparseable or NaN BMI/age values are treated as 0 and therefore
    fall into the lowest bin. Flag values such as "1.0" or True are read as
    integers; unrecognised flag text is kept verbatim in the bucket label.

    Returns:
        ``{"<arm>|<bucket>": {"arm", "bucket", "n", "cases", "incidence"}}``
        where ``n``/``cases`` count distinct ``subject_id_hash`` values.
    """
    def _num(value) -> float:
        # Lenient numeric parse: falsy, unparseable or NaN -> 0.0.
        # float() (not int()) so strings like "70.5" are not sent to 0.
        try:
            v = float(value or 0)
        except (ValueError, TypeError):
            return 0.0
        return 0.0 if math.isnan(v) else v

    def _flag(value) -> str:
        # 0/1 indicator label; "1.0"/True parse as 1. Text that is not numeric
        # is kept as-is rather than crashing the whole stratification.
        if not value:
            return "0"
        try:
            return str(int(float(value)))
        except (ValueError, TypeError, OverflowError):
            return str(value)

    def _bmi(r: dict) -> str:
        v = _num(r.get("baseline_bmi"))
        if v < 35:
            return "BMI<35"
        if v < 40:
            return "BMI 35-40"
        return "BMI>=40"

    def _age(r: dict) -> str:
        v = _num(r.get("age"))
        if v < 50:
            return "age<50"
        if v < 65:
            return "age 50-64"
        return "age>=65"

    # Choose the bucketing rule once instead of re-testing subgroup_key per row.
    bucketers = {
        "baseline_bmi": _bmi,
        "age": _age,
        "sex": lambda r: f"sex={r.get('sex','?')}",
        "t2dm": lambda r: f"t2dm={_flag(r.get('t2dm', 0))}",
        "prior_pancreatitis": lambda r: f"priorPanc={_flag(r.get('prior_pancreatitis', 0))}",
    }
    _bucket = bucketers.get(subgroup_key, lambda r: str(r.get(subgroup_key, "?")))

    subs: dict[tuple[str, str], set[str]] = defaultdict(set)
    cases: dict[tuple[str, str], set[str]] = defaultdict(set)
    for r in ae_rows:
        key = (r["arm"], _bucket(r))
        subs[key].add(r["subject_id_hash"])
        if r["panel"] == panel:
            cases[key].add(r["subject_id_hash"])
    out: dict[str, dict] = {}
    for (arm, bucket), s in subs.items():
        n = len(s)
        c = len(cases.get((arm, bucket), set()))
        out[f"{arm}|{bucket}"] = {
            "arm": arm, "bucket": bucket,
            "n": n, "cases": c, "incidence": c / n if n else 0.0,
        }
    return out
