"""ERAS Society bariatric protocol compliance KPI calculator.

Computes per-patient and ward-aggregated ERAS bariatric pathway compliance
across 4 phases (preop / intraop / POD0-3 / POD4-30) and produces:
- per-patient bundle compliance scores: % of elements met in each phase,
  plus an overall score (unweighted mean of the four phase scores)
- ward-level radar values (mean compliance per phase)
- MBSAQIP-analog quality measures
- KASMBS Korean-localized quality summary
"""
from __future__ import annotations

from dataclasses import dataclass, field
from typing import Dict, List, Tuple

from .ingest import (
    Patient,
    POD03Row,
    POD430Row,
    IntraopRow,
)

# ERAS element flag names, grouped by pathway phase. Each name is looked up
# as an attribute (via getattr in _bundle_pct) on the matching per-patient
# record, so the spellings must match the constants/fields in modules.ingest.

# Pre-operative elements (read from the Patient record).
PREOP_KEYS = [
    "eras_vte_prophy_preop", "eras_smoking_cess_4wk",
    "eras_nutrition_assess", "eras_osa_screen",
    "eras_carb_loading", "eras_preop_edu",
]

# Intra-operative elements (read from IntraopRow).
INTRAOP_KEYS = [
    "eras_lap_or_robot", "eras_glycemia_lt180", "eras_ppv_lung",
    "eras_normothermia", "eras_antiemetic",
]

# Early post-operative elements, POD0-3 (read from POD03Row).
POD03_KEYS = [
    "eras_mob_pod1", "eras_oral_pod1", "eras_vte_postop",
    "eras_mm_analgesia", "eras_no_routine_ngt",
]

# Late post-operative elements, POD4-30 (read from POD430Row).
POD430_KEYS = [
    "eras_protein_60g", "eras_multivitamin", "eras_iron_b12_pod30",
    "eras_weight_fu_pod30", "eras_ewl_5pct",
]


@dataclass
class ERASBundle:
    """Per-patient ERAS compliance scores, one percentage per pathway phase.

    As populated by ``compute_patient_bundles``, each ``*_pct`` field is a
    0-100 percentage rounded to one decimal place.
    """
    patient_id: str
    ward: str
    procedure: str
    surgeon: str
    preop_pct: float      # % of PREOP_KEYS flags met (read from the Patient)
    intraop_pct: float    # % of INTRAOP_KEYS flags met
    pod03_pct: float      # % of POD03_KEYS flags met
    pod430_pct: float     # % of POD430_KEYS flags met
    overall_pct: float    # unweighted mean of the four phase percentages


@dataclass
class WardERASRadar:
    """Ward-level ERAS radar values, as built by ``ward_radar``.

    Each ``*_pct`` field is the mean of the corresponding ERASBundle field
    over the ward's patients, rounded to one decimal place.
    """
    ward: str
    n_patients: int       # number of ERASBundle rows aggregated for this ward
    preop_pct: float
    intraop_pct: float
    pod03_pct: float
    pod430_pct: float
    overall_pct: float    # mean of the per-patient overall_pct values


@dataclass
class MBSAQIPMeasure:
    """One quality measure row, as built by ``mbsaqip_measures``.

    ``rate_pct`` is ``100 * n_events / n_denominator`` rounded to one decimal,
    or 0.0 when the denominator is zero.
    """
    measure: str          # e.g. "30-day mortality"
    rate_pct: float
    n_events: int
    n_denominator: int
    # ASMBS/MBSAQIP target. NOTE(review): all measures built in this module
    # are adverse events, so this is presumably an upper bound (rate_pct <=
    # target_pct is "on target") — confirm with the consumer of this value.
    target_pct: float


def _pct(num: int, denom: int) -> float:
    return round(100.0 * num / denom, 1) if denom else 0.0


def _bundle_pct(obj, keys: List[str]) -> float:
    if not keys:
        return 0.0
    hits = sum(1 for k in keys if getattr(obj, k, False))
    return round(100.0 * hits / len(keys), 1)


def compute_patient_bundles(patients: List[Patient],
                            intraop: List[IntraopRow],
                            pod03: List[POD03Row],
                            pod430: List[POD430Row]) -> List[ERASBundle]:
    """Score every patient's ERAS compliance per phase and overall.

    Phase rows are joined to patients on ``patient_id``. When a list holds
    several rows for the same patient, the last one wins. Patients missing an
    intra-op, POD0-3 or POD4-30 row are skipped entirely; they are not scored
    as 0%.

    Each phase score is the % of that phase's element flags that are truthy
    (see ``_bundle_pct``). Pre-op flags are read from the Patient itself.
    ``overall_pct`` is the unweighted mean of the four phase scores, so every
    phase counts equally regardless of how many elements it contains.

    Returns:
        One ERASBundle per fully-documented patient, in ``patients`` order.
    """
    intraop_map = {r.patient_id: r for r in intraop}
    pod03_map = {r.patient_id: r for r in pod03}
    pod430_map = {r.patient_id: r for r in pod430}

    out: List[ERASBundle] = []
    for p in patients:
        ir = intraop_map.get(p.patient_id)
        p3 = pod03_map.get(p.patient_id)
        p4 = pod430_map.get(p.patient_id)
        # dict.get signals "no row" with None; test for that explicitly
        # instead of relying on the row objects' truthiness.
        if ir is None or p3 is None or p4 is None:
            continue
        pre = _bundle_pct(p, PREOP_KEYS)
        intra = _bundle_pct(ir, INTRAOP_KEYS)
        p03 = _bundle_pct(p3, POD03_KEYS)
        p430 = _bundle_pct(p4, POD430_KEYS)
        # Equal weight per phase (not total elements met / total elements).
        overall = round((pre + intra + p03 + p430) / 4.0, 1)
        out.append(ERASBundle(
            patient_id=p.patient_id, ward=p.ward,
            procedure=p.procedure, surgeon=p.surgeon,
            preop_pct=pre, intraop_pct=intra,
            pod03_pct=p03, pod430_pct=p430,
            overall_pct=overall,
        ))
    return out


def ward_radar(bundles: List[ERASBundle]) -> List[WardERASRadar]:
    """Aggregate per-patient bundles into one radar row per ward.

    Every phase value (and the overall value) is the arithmetic mean of the
    corresponding per-patient field, rounded to 1 decimal. Rows are returned
    in ascending ward order; wards with no bundles do not appear.
    """
    by_ward: Dict[str, List[ERASBundle]] = {}
    for bundle in bundles:
        by_ward.setdefault(bundle.ward, []).append(bundle)

    def mean_of(members: List[ERASBundle], attr: str) -> float:
        # Mean of one float field across a ward's bundles.
        values = [getattr(member, attr) for member in members]
        if not values:
            return 0.0
        return round(sum(values) / len(values), 1)

    return [
        WardERASRadar(
            ward=ward,
            n_patients=len(by_ward[ward]),
            preop_pct=mean_of(by_ward[ward], "preop_pct"),
            intraop_pct=mean_of(by_ward[ward], "intraop_pct"),
            pod03_pct=mean_of(by_ward[ward], "pod03_pct"),
            pod430_pct=mean_of(by_ward[ward], "pod430_pct"),
            overall_pct=mean_of(by_ward[ward], "overall_pct"),
        )
        for ward in sorted(by_ward)
    ]


def mbsaqip_measures(patients: List[Patient],
                     pod03: List[POD03Row],
                     pod430: List[POD430Row]) -> List[MBSAQIPMeasure]:
    """A small panel of MBSAQIP-analog quality measures with ASMBS targets.

    The denominator is the number of distinct patient ids in ``patients``.
    Each event count is the number of distinct cohort patients with at least
    one flagged row. Rows for patients outside ``patients`` therefore do not
    inflate a rate (e.g. when ``patients`` is a ward or surgeon subset while
    ``pod03``/``pod430`` are the full tables), and neither do duplicate rows
    for the same patient. Without this, a rate could exceed 100%.

    Returns:
        One MBSAQIPMeasure per measure, in a fixed order. Every rate is 0.0
        when the cohort is empty.
    """
    cohort = {p.patient_id for p in patients}
    n = len(cohort)

    def n_with(rows, flagged) -> int:
        # Distinct in-cohort patients having at least one row where flagged().
        return len({r.patient_id for r in rows
                    if r.patient_id in cohort and flagged(r)})

    n_leak = n_with(pod03, lambda r: r.leak or r.staple_leak)
    n_bleed = n_with(pod03, lambda r: r.bleeding)
    n_vte = n_with(pod03, lambda r: r.vte)
    n_reop = n_with(pod03, lambda r: r.reop_acute)
    n_readmit = n_with(pod430, lambda r: r.readmit_30d)
    n_mort = n_with(patients, lambda p: p.died_30d)
    n_marg_ulcer = n_with(pod430, lambda r: r.marginal_ulcer)

    return [
        MBSAQIPMeasure("Anastomotic / staple leak (30d)",
                       _pct(n_leak, n), n_leak, n, target_pct=1.5),
        MBSAQIPMeasure("Bleeding requiring transfusion (30d)",
                       _pct(n_bleed, n), n_bleed, n, target_pct=2.5),
        MBSAQIPMeasure("VTE (30d)",
                       _pct(n_vte, n), n_vte, n, target_pct=1.0),
        MBSAQIPMeasure("Acute reoperation (30d)",
                       _pct(n_reop, n), n_reop, n, target_pct=2.0),
        MBSAQIPMeasure("30-day readmission",
                       _pct(n_readmit, n), n_readmit, n, target_pct=5.0),
        MBSAQIPMeasure("30-day mortality",
                       _pct(n_mort, n), n_mort, n, target_pct=0.3),
        MBSAQIPMeasure("Marginal ulcer (30d)",
                       _pct(n_marg_ulcer, n), n_marg_ulcer, n, target_pct=3.0),
    ]


def overall_compliance_ranking(bundles: List[ERASBundle],
                               by: str = "ward") -> List[Tuple[str, float, int]]:
    """Rank groups by their mean per-patient ``overall_pct``.

    ``by`` names the ERASBundle attribute to group on (e.g. "ward",
    "surgeon", "procedure"). Each result row is ``(group, mean, count)``, with
    the mean rounded to 1 decimal. Rows are sorted by descending mean; ties
    keep first-appearance order.
    """
    scores: Dict[str, List[float]] = {}
    for bundle in bundles:
        group_key = getattr(bundle, by)
        if group_key not in scores:
            scores[group_key] = []
        scores[group_key].append(bundle.overall_pct)

    ranking: List[Tuple[str, float, int]] = []
    for group_key, values in scores.items():
        count = len(values)
        ranking.append((group_key, round(sum(values) / count, 1), count))
    return sorted(ranking, key=lambda row: -row[1])
