"""FAERS background disproportionality (ROR/PRR/EBGM).

For reference and research use only — not for clinical decision-making. Uses a synthetic FAERS-like mini-CSV.

References: van Puijenbroek 2002, DuMouchel 1999.
"""
from __future__ import annotations

import csv
import math
import os
from collections import defaultdict
from dataclasses import dataclass


# Default data directory: a "data" folder located one level above the
# directory that contains this module.
DATA_DIR = os.path.join(os.path.dirname(os.path.dirname(__file__)), "data")


def load_faers(path: str | None = None) -> list[dict]:
    """Load FAERS-like rows from a CSV file.

    Args:
        path: CSV file to read. When ``None``, ``synthetic_faers.csv`` inside
            ``DATA_DIR`` is used.

    Returns:
        One dict per CSV row, keyed by the header names. The
        ``count_drug_event`` and ``total_drug_reports`` values are converted
        to ``int``; every other value is left exactly as ``csv`` read it.

    Raises:
        OSError: If the file cannot be opened.
        KeyError: If either count column is missing from the header.
        ValueError: If a count value is not a valid integer literal.
    """
    if path is None:
        path = os.path.join(DATA_DIR, "synthetic_faers.csv")
    rows: list[dict] = []
    # newline="" hands raw line endings to the csv module, so line breaks
    # embedded in quoted fields are preserved instead of being rewritten by
    # universal-newline translation.
    # "utf-8-sig" decodes plain UTF-8 identically but drops a leading BOM,
    # which would otherwise be glued onto the first header name.
    with open(path, encoding="utf-8-sig", newline="") as f:
        for row in csv.DictReader(f):
            row["count_drug_event"] = int(row["count_drug_event"])
            row["total_drug_reports"] = int(row["total_drug_reports"])
            rows.append(row)
    return rows


@dataclass
class Disproportion:
    """Disproportionality statistics for one (drug, preferred term) pair."""

    drug: str
    pt_term: str
    panel: str
    a: int  # 2x2 cell: drug + event
    b: int  # 2x2 cell: drug + not event
    c: int  # 2x2 cell: other drug + event
    d: int  # 2x2 cell: other drug + not event
    ror: float
    ror_lci: float
    ror_uci: float
    prr: float
    chi2: float
    ebgm: float
    signal: bool

    def to_dict(self) -> dict:
        """Return the record as a plain dict with rounded statistics.

        Identity fields, the four cell counts and ``signal`` are copied
        unchanged. ``chi2`` is rounded to 2 decimals and every other
        statistic to 3 decimals. Statistic keys are upper-case
        (``ROR``, ``ROR_LCI``, ``ROR_UCI``, ``PRR``, ``EBGM``) except
        ``chi2``.
        """
        # (output key, value, number of decimals) in output order.
        rounded_stats = (
            ("ROR", self.ror, 3),
            ("ROR_LCI", self.ror_lci, 3),
            ("ROR_UCI", self.ror_uci, 3),
            ("PRR", self.prr, 3),
            ("chi2", self.chi2, 2),
            ("EBGM", self.ebgm, 3),
        )
        record: dict = {
            "drug": self.drug,
            "pt_term": self.pt_term,
            "panel": self.panel,
        }
        for cell_name in ("a", "b", "c", "d"):
            record[cell_name] = getattr(self, cell_name)
        for key, value, digits in rounded_stats:
            record[key] = round(value, digits)
        record["signal"] = self.signal
        return record


def build_2x2(faers: list[dict]) -> dict[tuple[str, str], dict[str, int]]:
    """Build the 2x2 contingency cells for every (drug, pt_term) pair.

    Args:
        faers: Rows providing ``drug``, ``pt_term``, ``count_drug_event`` and
            ``total_drug_reports``. The two counts must already be numeric.

    Returns:
        A dict keyed by ``(drug, pt_term)`` in first-seen order. Each value
        is ``{"a", "b", "c", "d"}`` where
        ``a`` = drug + event, ``b`` = drug + not event,
        ``c`` = other drug + event, ``d`` = other drug + not event.
        ``b``, ``c`` and ``d`` are clamped at 0 when the input totals are
        inconsistent. An empty input yields an empty dict.

    Rows that repeat a (drug, pt_term) pair are summed into ``a``. The
    per-PT total is a running sum over all rows, so overwriting ``a`` with
    only the last row (as before) would leak the drug's own earlier counts
    into ``c``.
    """
    event_total: dict[str, int] = defaultdict(int)             # PT total across all drugs
    drug_total: dict[str, int] = defaultdict(int)              # total reports per drug
    drug_event: dict[tuple[str, str], int] = defaultdict(int)  # cell a per pair
    for r in faers:
        count = r["count_drug_event"]
        drug_event[(r["drug"], r["pt_term"])] += count
        event_total[r["pt_term"]] += count
        # total_drug_reports repeats on every row of a drug; take the max
        # rather than the sum so it is not counted once per PT.
        drug_total[r["drug"]] = max(drug_total[r["drug"]], r["total_drug_reports"])

    grand = sum(drug_total.values())
    out: dict[tuple[str, str], dict[str, int]] = {}
    for (drug, pt), a in drug_event.items():
        b = max(0, drug_total[drug] - a)
        c = max(0, event_total[pt] - a)
        d = max(0, grand - a - b - c)
        out[(drug, pt)] = {"a": a, "b": b, "c": c, "d": d}
    return out


def compute_ror(a: int, b: int, c: int, d: int) -> tuple[float, float, float]:
    """Return the Reporting Odds Ratio and its 95% CI (log-normal).

    ``ROR = (a * d) / (b * c)`` with the interval
    ``exp(ln(ROR) ± 1.96 * sqrt(1/a + 1/b + 1/c + 1/d))``.

    Args:
        a: Reports with the drug and the event.
        b: Reports with the drug and without the event.
        c: Reports with other drugs and the event.
        d: Reports with other drugs and without the event.

    Returns:
        ``(ror, lower_ci, upper_ci)``. If any cell is 0, 0.5 is added to all
        four cells first (continuity correction). If any cell is negative the
        table is invalid and ``(nan, nan, nan)`` is returned.
    """
    cells = (a, b, c, d)
    nan = float("nan")
    # A negative count cannot form a valid table. Guarding here keeps
    # sqrt/log below from raising on a negative argument and matches the
    # NaN convention used by the other statistics.
    if min(cells) < 0:
        return (nan, nan, nan)
    if min(cells) == 0:
        a, b, c, d = (cell + 0.5 for cell in cells)
    # From here every cell is strictly positive, so no division by zero.
    ror = (a * d) / (b * c)
    se_log = math.sqrt(1 / a + 1 / b + 1 / c + 1 / d)
    log_ror = math.log(ror)
    lci = math.exp(log_ror - 1.96 * se_log)
    uci = math.exp(log_ror + 1.96 * se_log)
    return (ror, lci, uci)


def compute_prr(a: int, b: int, c: int, d: int) -> float:
    """Return the Proportional Reporting Ratio ``(a/(a+b)) / (c/(c+d))``.

    NaN is returned when the drug row total ``a + b`` is 0, when the
    comparator row total ``c + d`` is 0, when ``c`` is 0, or when the
    comparator proportion evaluates to 0.
    """
    undefined = float("nan")
    drug_reports, other_reports = a + b, c + d
    if 0 in (drug_reports, other_reports, c):
        return undefined
    drug_rate = a / drug_reports
    other_rate = c / other_reports
    # Guard the final division in case the comparator rate came out as zero.
    return undefined if other_rate == 0 else drug_rate / other_rate


def compute_chi2(a: int, b: int, c: int, d: int) -> float:
    """Return the Pearson chi-square statistic of the 2x2 table.

    No continuity correction is applied. The result is 0.0 when the table
    is empty or when any row or column total is 0.
    """
    total = a + b + c + d
    if total == 0:
        return 0.0
    row_totals = (a + b, c + d)
    col_totals = (a + c, b + d)
    if 0 in row_totals or 0 in col_totals:
        return 0.0

    observed_rows = ((a, b), (c, d))
    statistic = 0.0
    # Visit the cells in a, b, c, d order; the expected count of a cell is
    # row_total * col_total / total.
    for row_total, observed_row in zip(row_totals, observed_rows):
        for col_total, observed in zip(col_totals, observed_row):
            expected = row_total * col_total / total
            if expected > 0:
                statistic += (observed - expected) ** 2 / expected
    return statistic


def compute_ebgm(a: int, b: int, c: int, d: int, *, alpha: float = 0.5) -> float:
    """Return a simplified EBGM-style shrinkage estimate.

    The full DuMouchel 1999 GPS method fits a gamma-Poisson mixture. This
    function uses a simple shrinkage of the observed/expected ratio toward 1:

        expected = (a + b) * (a + c) / N
        result   = (a + alpha) / (expected + alpha)

    The value is returned on the ratio scale; no logarithm is taken.

    Args:
        a: Reports with the drug and the event.
        b: Reports with the drug and without the event.
        c: Reports with other drugs and the event.
        d: Reports with other drugs and without the event.
        alpha: Shrinkage constant added to both the observed and the expected
            count. Larger values pull the ratio harder toward 1. Must be
            non-negative.

    Returns:
        The shrunk ratio, or NaN when the table is empty or the expected
        count is not positive.

    Raises:
        ValueError: If ``alpha`` is negative.
    """
    if alpha < 0:
        raise ValueError("alpha must be non-negative")
    n = a + b + c + d
    if n == 0:
        return float("nan")
    expected = (a + b) * (a + c) / n
    if expected <= 0:
        return float("nan")
    return (a + alpha) / (expected + alpha)


def evaluate_signal(prr: float, chi2: float, count: int,
                    prr_min: float = 2.0, chi2_min: float = 4.0,
                    count_min: int = 3) -> bool:
    """Decide whether a drug-event pair meets all three signal thresholds.

    A signal needs ``prr >= prr_min``, ``chi2 >= chi2_min`` and
    ``count >= count_min`` at the same time. A NaN ``prr`` or ``chi2``
    never counts as a signal.
    """
    for statistic in (prr, chi2):
        if math.isnan(statistic):
            return False
    # (observed value, minimum required), checked in this order.
    requirements = (
        (prr, prr_min),
        (chi2, chi2_min),
        (count, count_min),
    )
    return all(observed >= minimum for observed, minimum in requirements)


def disproportionality_all(faers: list[dict]) -> list[Disproportion]:
    """Compute every disproportionality statistic for each (drug, PT) pair.

    The 2x2 tables come from ``build_2x2``. Each table is passed to the
    ROR, PRR, chi-square and EBGM helpers and to ``evaluate_signal``. The
    panel is looked up by PT term from the input rows and defaults to
    ``"other"`` when the term is not found.

    Returns:
        One ``Disproportion`` per pair, in the order ``build_2x2`` yields them.
    """
    tables = build_2x2(faers)
    # When a PT term appears on several rows, the last row's panel wins.
    panel_by_pt: dict = {}
    for row in faers:
        panel_by_pt[row["pt_term"]] = row["panel"]

    results: list[Disproportion] = []
    for (drug, pt), table in tables.items():
        counts = (table["a"], table["b"], table["c"], table["d"])
        ror, ror_low, ror_high = compute_ror(*counts)
        prr = compute_prr(*counts)
        chi2 = compute_chi2(*counts)
        ebgm = compute_ebgm(*counts)
        is_signal = evaluate_signal(prr, chi2, counts[0])
        record = Disproportion(
            drug=drug,
            pt_term=pt,
            panel=panel_by_pt.get(pt, "other"),
            a=counts[0],
            b=counts[1],
            c=counts[2],
            d=counts[3],
            ror=ror,
            ror_lci=ror_low,
            ror_uci=ror_high,
            prr=prr,
            chi2=chi2,
            ebgm=ebgm,
            signal=is_signal,
        )
        results.append(record)
    return results


def filter_glp1ra(disps: list[Disproportion]) -> list[Disproportion]:
    """Return the records whose ``drug`` is in the hard-coded drug-name set.

    Matching is exact (case-sensitive) and the input order is preserved.
    """
    wanted_drugs = frozenset((
        "semaglutide",
        "liraglutide",
        "tirzepatide",
        "dulaglutide",
        "exenatide",
        "orforglipron",
        "retatrutide",
        "cagrilintide_semaglutide",
    ))
    kept = []
    for record in disps:
        if record.drug in wanted_drugs:
            kept.append(record)
    return kept
