"""Rule-based assist classifiers for MI / Stroke / HF / CV death.

These are *adjudication-support* heuristics — they do not replace the human
CEC (clinical events committee) adjudicator. Each classifier returns a
ClassificationResult holding the suggested label, rationale strings, supporting
flags, and a confidence score in [0, 1].

Definitions implemented:
- Fourth Universal Definition of MI (UDM 2018; Thygesen et al. 2018), Types 1–5
- AHA/ASA 2013 stroke subtypes
- ESC 2021 hospitalization for heart failure
- CV death causal categories per Hicks 2017 / CDISC

For research / synthetic data only. Not for clinical use.
"""

from __future__ import annotations

import re
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional

from .ingest import EventPacket, TroponinPoint


# ---------------------------------------------------------------------------
# Data classes
# ---------------------------------------------------------------------------


@dataclass
class ClassificationResult:
    """Suggested adjudication outcome for a single event packet.

    Every classifier in this module returns one of these; ``to_dict`` turns it
    into a plain dictionary.
    """

    # Identifier of the event the suggestion refers to.
    event_id: str
    # Outcome family: MI / Stroke / HF / CV_death / non_MACE.
    domain: str
    # Specific subtype label within the domain.
    label: str
    # Heuristic confidence score, intended range 0..1.
    confidence: float
    # Ordered, human-readable reasons for the suggestion.
    rationale: List[str] = field(default_factory=list)
    # Supporting structured evidence (free-form key/value pairs).
    flags: Dict[str, Any] = field(default_factory=dict)

    def to_dict(self) -> Dict[str, Any]:
        """Return a plain-dict snapshot of this result.

        ``confidence`` is coerced to float and rounded to 3 decimals.
        ``rationale`` and ``flags`` are shallow copies, so nested values inside
        ``flags`` are still shared with this instance.
        """
        snapshot: Dict[str, Any] = {}
        snapshot["event_id"] = self.event_id
        snapshot["domain"] = self.domain
        snapshot["label"] = self.label
        snapshot["confidence"] = round(float(self.confidence), 3)
        snapshot["rationale"] = [*self.rationale]
        snapshot["flags"] = dict(self.flags)
        return snapshot


# ---------------------------------------------------------------------------
# Troponin pattern detection
# ---------------------------------------------------------------------------


def troponin_pattern(series: List[TroponinPoint]) -> Dict[str, Any]:
    """Return rise/fall pattern descriptors for a serial troponin series.

    The first point's ``urln_99th`` is used as the 99th-percentile upper
    reference limit (URL) for the whole series. A "rise" is an increase from
    the first value to the peak, and a "fall" is a decrease from the peak to
    the last value. Each must exceed 20% of the URL in absolute terms.
    UDM 2018 uses a rise *and/or* fall, which is exposed as ``rise_or_fall``.
    ``rise_and_fall`` is kept for callers that need both.

    NOTE(review): the delta threshold is 20% of the URL (absolute), not a
    relative change between serial values — confirm against the charter.

    Args:
        series: Troponin measurements. Assumed to be in chronological order,
            since the first/last values are treated as baseline/latest — TODO
            confirm against the ingest layer.

    Returns:
        ``{"available": False}`` for an empty series. Otherwise a dict with
        ``available``, ``above_url_at_any_point``, ``rise_pattern``,
        ``fall_pattern``, ``rise_and_fall``, ``rise_or_fall``, ``peak_x_urln``,
        ``urln_99th`` and ``n_points``. ``peak_x_urln`` is peak / URL rounded
        to 2 dp, or 0 when the URL is missing or zero.
    """
    if not series:
        return {"available": False}
    values = [p.value_ng_l for p in series]
    urln = series[0].urln_99th
    peak = max(values)
    # A missing/zero URL yields a zero delta threshold instead of a TypeError,
    # consistent with the guard on ``multiple`` below.
    delta = 0.2 * (urln or 0)
    rise = peak - values[0] > delta
    # A fall requires the peak to occur strictly before the final sample.
    fall = values.index(peak) < len(values) - 1 and peak - values[-1] > delta
    multiple = peak / urln if urln else 0
    return {
        "available": True,
        "above_url_at_any_point": any(p.above_url for p in series),
        "rise_pattern": bool(rise),
        "fall_pattern": bool(fall),
        "rise_and_fall": bool(rise and fall),
        "rise_or_fall": bool(rise or fall),
        "peak_x_urln": round(multiple, 2),
        "urln_99th": urln,
        "n_points": len(series),
    }


# ---------------------------------------------------------------------------
# UDM 2018 MI Type classifier
# ---------------------------------------------------------------------------


_PCI_KEYWORDS = ("PCI", "percutaneous", "stent", "balloon")
_CABG_KEYWORDS = ("CABG", "bypass")
_STENT_THROMBOSIS = ("stent thrombosis", "stent-thrombosis", "in-stent thrombus")
_RESTENOSIS = ("restenosis", "in-stent restenosis", "ISR")
_PLAQUE_RUPTURE = ("plaque rupture", "plaque-rupture", "ruptured plaque", "thrombus on plaque")
_STABLE_STENOSIS = ("stable stenosis", "fixed stenosis", "no acute culprit")
_DEMAND_ISCHEMIA = ("supply-demand", "demand ischemia", "anemia", "tachyarrhythmia", "hypotension", "sepsis")
_STEMI = ("STEMI", "ST elevation")
_NSTEMI = ("NSTEMI", "ST depression", "T-wave inversion")


def _has_any(text: str, keywords) -> bool:
    if not text:
        return False
    lower = text.lower()
    return any(k.lower() in lower for k in keywords)


def classify_mi(packet: EventPacket, include_type2: bool = True) -> ClassificationResult:
    """Suggest a UDM 2018 MI type for *packet*.

    Rules are tried in order; the first one that fires determines the label:

    1. "Type 3": suspected category is CV death, the death precursors mention
       ischemic symptoms / chest pain / ST elevation, and no troponin series
       is available.
    2. "Type 4b" / "Type 4c": stent-thrombosis / in-stent-restenosis keywords
       in the discharge summary or angiography report.
    3. "Type 4a": peri-PCI cue in the discharge summary, peak troponin
       >= 5x URL, and no plaque-rupture / STEMI-ECG evidence.
    4. "Type 5": CABG cue in the discharge summary, peak troponin >= 10x URL,
       and no plaque-rupture / STEMI-ECG evidence.
    5. "not MI": no troponin rise and/or fall, or no value above the URL.
    6. "Type 1": plaque rupture on angiography or STEMI pattern on ECG.
    7. "Type 2" ("Type 2 (excluded)" when *include_type2* is False):
       supply-demand keywords in the summary or stable-stenosis angiography.
    8. "Type 1" (low confidence): any STEMI/NSTEMI ECG keyword.
    9. "indeterminate" otherwise.

    Args:
        packet: Event packet. Reads troponin_series, ecg_findings, angiography,
            discharge_summary, suspected_category, death_precursors and
            event_id.
        include_type2: Whether the trial charter counts Type 2 MI as an event.

    Returns:
        ClassificationResult with domain "MI". ``flags["troponin"]`` holds the
        troponin_pattern() descriptors.
    """
    rationale: List[str] = []
    trop = troponin_pattern(packet.troponin_series)
    flags: Dict[str, Any] = {"troponin": trop}

    def _result(label: str, confidence: float, reason: str) -> ClassificationResult:
        # Record why the rule fired and package the fields shared by all rules.
        rationale.append(reason)
        return ClassificationResult(
            event_id=packet.event_id, domain="MI", label=label,
            confidence=confidence, rationale=rationale, flags=flags,
        )

    ecg = packet.ecg_findings or ""
    angio = packet.angiography or ""
    summary = packet.discharge_summary or ""
    precursors = (packet.death_precursors or "").lower()

    # Type 3: cardiac death with ischemic precursors before any biomarker draw.
    if (packet.suspected_category == "CV_death"
            and any(cue in precursors
                    for cue in ("ischemic symptoms", "chest pain", "st elevation"))
            and not trop.get("available", False)):
        return _result("Type 3", 0.7,
                       "Death before biomarker draw with ischemic precursors")

    # Procedural types must be PERI-procedural (UDM 2018: MI occurring during /
    # caused by the procedure). They also require no plaque-rupture evidence;
    # otherwise this is a Type 1 MI that happened to be treated with PCI.
    plaque_evidence = _has_any(angio, _PLAQUE_RUPTURE) or _has_any(ecg, _STEMI)
    peak_x_urln = trop.get("peak_x_urln", 0)

    if _has_any(summary, _STENT_THROMBOSIS) or _has_any(angio, _STENT_THROMBOSIS):
        return _result("Type 4b", 0.8, "Stent thrombosis documented")
    if _has_any(summary, _RESTENOSIS) or _has_any(angio, _RESTENOSIS):
        return _result("Type 4c", 0.7, "In-stent restenosis documented")

    peri_pci = (_has_any(summary, ("peri-PCI", "during PCI", "post-procedural",
                                   "procedural MI", "stent placement"))
                or "post-pci elevation" in summary.lower())
    if peri_pci and peak_x_urln >= 5 and not plaque_evidence:
        return _result(
            "Type 4a", 0.75,
            "Peri-PCI MI: troponin >=5x URL without spontaneous plaque rupture",
        )
    if _has_any(summary, _CABG_KEYWORDS) and peak_x_urln >= 10 and not plaque_evidence:
        return _result("Type 5", 0.75, "Post-CABG troponin >=10x URL")

    # UDM 2018 acute myocardial injury: a rise AND/OR fall of cTn with at least
    # one value above the 99th-percentile URL. Requiring both rise and fall
    # would reject, e.g., a series that is still rising when sampling stops.
    rise_or_fall = bool(trop.get("rise_pattern") or trop.get("fall_pattern"))
    if not rise_or_fall or not trop.get("above_url_at_any_point"):
        return _result("not MI", 0.6,
                       "Troponin rise and/or fall above 99th URL not satisfied")

    # Type 1 vs Type 2: plaque rupture vs supply-demand mismatch.
    if plaque_evidence:
        return _result("Type 1", 0.85, "Plaque rupture or STEMI ECG pattern present")

    if _has_any(summary, _DEMAND_ISCHEMIA) or _has_any(angio, _STABLE_STENOSIS):
        if include_type2:
            return _result("Type 2", 0.75,
                           "Supply-demand mismatch context; charter includes Type 2")
        return _result("Type 2 (excluded)", 0.7, "Type 2 MI excluded per charter")

    # Fall back to low-confidence Type 1 when any ischemic ECG pattern is seen.
    if _has_any(ecg, _STEMI + _NSTEMI):
        return _result("Type 1", 0.6,
                       "ECG ischemic changes, no plaque info; assume Type 1")

    return _result("indeterminate", 0.4,
                   "Insufficient information; recommend manual review")


# ---------------------------------------------------------------------------
# AHA/ASA stroke classifier
# ---------------------------------------------------------------------------


# Negation cue plus at most 40 characters of the same clause (clause
# punctuation ends the window), e.g. the "no acute " in "no acute hemorrhage".
_IMAGING_NEGATION = r"\b(?:no|not|without|negative for|absence of)\b[^.;,\n]{0,40}?"
# Lower-cased imaging-report terms. Short abbreviations are word-bounded so
# that, e.g., "ich" does not fire inside "which".
_SAH_TERMS = r"subarachnoid|\bsah\b"
_ICH_TERMS = r"intracerebral hemorrhage|\bich\b|hemorrhage"
_INFARCT_TERMS = r"infarct|diffusion restriction|dwi"


def _imaging_affirms(text: str, terms: str) -> bool:
    """Return True if regex *terms* matches *text* outside a simple negated phrase.

    Spans of the form "<negation cue> ... <term>" within one clause are blanked
    out before searching. For "no hemorrhage. acute infarct." this affirms
    "infarct" but not "hemorrhage". This is a heuristic, not a negation parser.
    """
    cleaned = re.sub(_IMAGING_NEGATION + "(?:" + terms + ")", " ", text)
    return re.search(terms, cleaned) is not None


def classify_stroke(packet: EventPacket) -> ClassificationResult:
    """Suggest an AHA/ASA 2013 stroke subtype for *packet*.

    Rules are tried in order; the first that fires wins:

    1. "SAH": subarachnoid hemorrhage affirmed on brain imaging.
    2. "Hemorrhagic": intracerebral hemorrhage / ICH / hemorrhage affirmed.
    3. "Ischemic": infarct affirmed on imaging or symptoms lasting >= 24 h.
       Confidence is 0.85 when both hold, else 0.7.
    4. "TIA": known symptom duration < 24 h with no infarct.
    5. "indeterminate": discharge summary mentions stroke.
    6. "not stroke" otherwise.

    Imaging findings preceded by a negation cue in the same clause
    ("no hemorrhage", "without acute infarct") do not count as positive.

    Args:
        packet: Event packet. Reads imaging_brain, discharge_summary, nihss,
            symptom_duration_h and event_id.

    Returns:
        ClassificationResult with domain "Stroke". Flags carry NIHSS and
        symptom duration.
    """
    rationale: List[str] = []
    flags: Dict[str, Any] = {"nihss": packet.nihss,
                             "symptom_duration_h": packet.symptom_duration_h}

    img = (packet.imaging_brain or "").lower()
    summary = (packet.discharge_summary or "").lower()

    # SAH first, since SAH reports also contain the word "hemorrhage".
    if _imaging_affirms(img, _SAH_TERMS):
        rationale.append("Imaging shows subarachnoid hemorrhage")
        return ClassificationResult(
            event_id=packet.event_id, domain="Stroke", label="SAH",
            confidence=0.85, rationale=rationale, flags=flags,
        )

    if _imaging_affirms(img, _ICH_TERMS):
        rationale.append("Imaging shows intracerebral hemorrhage")
        return ClassificationResult(
            event_id=packet.event_id, domain="Stroke", label="Hemorrhagic",
            confidence=0.85, rationale=rationale, flags=flags,
        )

    # Ischemic vs TIA — AHA/ASA 2013 is tissue-based: infarct on imaging, or
    # symptoms >= 24 h.
    has_infarct = _imaging_affirms(img, _INFARCT_TERMS)
    duration = packet.symptom_duration_h
    long_symptom = (duration or 0) >= 24
    if has_infarct or long_symptom:
        rationale.append("Infarct on imaging or symptoms ≥ 24h")
        conf = 0.85 if (has_infarct and long_symptom) else 0.7
        return ClassificationResult(
            event_id=packet.event_id, domain="Stroke", label="Ischemic",
            confidence=conf, rationale=rationale, flags=flags,
        )

    # has_infarct is necessarily False here (the branch above returned).
    if duration is not None and duration < 24:
        rationale.append("Symptoms < 24h and no infarct → TIA")
        return ClassificationResult(
            event_id=packet.event_id, domain="Stroke", label="TIA",
            confidence=0.75, rationale=rationale, flags=flags,
        )

    if "stroke" in summary:
        rationale.append("Discharge summary mentions stroke; insufficient detail")
        return ClassificationResult(
            event_id=packet.event_id, domain="Stroke", label="indeterminate",
            confidence=0.4, rationale=rationale, flags=flags,
        )

    rationale.append("No imaging/symptom evidence for stroke")
    return ClassificationResult(
        event_id=packet.event_id, domain="Stroke", label="not stroke",
        confidence=0.5, rationale=rationale, flags=flags,
    )


# ---------------------------------------------------------------------------
# ESC HF (hospitalization for heart failure)
# ---------------------------------------------------------------------------


def classify_hf(packet: EventPacket) -> ClassificationResult:
    """Score *packet* against three hospitalization-for-heart-failure criteria.

    One point each for:
      * natriuretic peptide elevated: BNP >= 100 pg/mL or NT-proBNP >= 300 pg/mL;
      * IV diuretic given (truthy ``packet.iv_diuretic``);
      * hospital length of stay >= 1 day.

    Labels: >= 2 points -> "HHF" (confidence 0.7, plus 0.1 when all three are
    met); exactly 1 -> "possible HHF" (0.5); 0 -> "not HHF" (0.55).

    NOTE(review): a score of 2 can be reached without the length-of-stay
    criterion (peptide + IV diuretic); confirm this matches the charter's
    HHF definition.
    """
    bnp = packet.bnp_pg_ml
    nt_probnp = packet.ntprobnp_pg_ml
    los_days = packet.hospitalization_los_days
    flags: Dict[str, Any] = {
        "bnp": bnp,
        "ntprobnp": nt_probnp,
        "iv_diuretic": packet.iv_diuretic,
        "los": los_days,
    }

    # (criterion met?, rationale text), in the order the rationale is reported.
    criteria = (
        ((bnp is not None and bnp >= 100) or (nt_probnp is not None and nt_probnp >= 300),
         "Natriuretic peptide elevated above ESC threshold"),
        (bool(packet.iv_diuretic), "IV diuretic administered"),
        (los_days is not None and los_days >= 1, "Overnight hospitalization documented"),
    )
    rationale: List[str] = [reason for met, reason in criteria if met]
    score = sum(1 for met, _ in criteria if met)

    if score >= 2:
        label, confidence = "HHF", 0.7 + 0.1 * (score - 2)
    elif score == 1:
        label, confidence = "possible HHF", 0.5
    else:
        rationale.append("HF criteria not met")
        label, confidence = "not HHF", 0.55

    return ClassificationResult(
        event_id=packet.event_id, domain="HF", label=label,
        confidence=confidence, rationale=rationale, flags=flags,
    )


# ---------------------------------------------------------------------------
# CV death causal classifier
# ---------------------------------------------------------------------------


# Precursor rules for classify_cv_death, checked in order (first match wins)
# against the lower-cased death_precursors text. Short abbreviations are
# word-bounded: "mi" must not fire inside "hemiparesis" or "admitted", and
# "ich" must not fire inside "which".
# NOTE(review): "infarction" also matches "cerebral infarction"; confirm
# whether those should route to fatal_stroke instead.
_CV_DEATH_PRECURSOR_RULES = (
    (re.compile(r"\ba?mi\b|\bn?stemi|infarction"),
     "fatal_MI", "Documented MI precursor"),
    (re.compile(r"\b(?:ad|c)?hf|heart failure|pulmonary edema"),
     "fatal_HF", "HF decompensation prior to death"),
    (re.compile(r"stroke|\bich\b|intracranial"),
     "fatal_stroke", "Stroke precursor"),
    (re.compile(r"\b(?:ns|s)?v[tf]|arrhythmia|torsades"),
     "arrhythmia", "Arrhythmic precursor documented"),
)
# Substrings in the discharge summary indicating a documented non-CV cause.
_NON_CV_CAUSES = ("cancer", "sepsis", "pneumonia", "trauma")


def classify_cv_death(packet: EventPacket) -> ClassificationResult:
    """Suggest a CV-death causal category for *packet*.

    Packets whose suspected category is not "CV_death" get label "n/a" with
    confidence 0. Otherwise rules are tried in order:

    1. Precursor text mentioning MI, HF, stroke or arrhythmia ->
       "fatal_MI" / "fatal_HF" / "fatal_stroke" / "arrhythmia" (0.75).
    2. Discharge summary documenting cancer / sepsis / pneumonia / trauma ->
       "non_CV" (0.7).
    3. Unwitnessed death ("death_witnessed is False", or "unwitnessed" /
       "found dead" in the summary) -> "sudden_cardiac_death" (0.6).
    4. Otherwise "undetermined" (0.45).

    The non-CV check runs before the sudden-death rule so that the
    sudden-death label really means "without obvious non-CV cause".

    Args:
        packet: Event packet. Reads suspected_category, death_precursors,
            death_witnessed, discharge_summary and event_id.

    Returns:
        ClassificationResult with domain "CV_death".
    """
    flags: Dict[str, Any] = {
        "witnessed": packet.death_witnessed,
        "precursors": packet.death_precursors,
    }
    if packet.suspected_category != "CV_death":
        return ClassificationResult(
            event_id=packet.event_id, domain="CV_death", label="n/a",
            confidence=0.0, rationale=["Suspected category is not CV death"],
            flags=flags,
        )

    pre = (packet.death_precursors or "").lower()
    summary = (packet.discharge_summary or "").lower()

    for pattern, label, reason in _CV_DEATH_PRECURSOR_RULES:
        if pattern.search(pre):
            return ClassificationResult(
                event_id=packet.event_id, domain="CV_death", label=label,
                confidence=0.75, rationale=[reason], flags=flags,
            )

    if any(cause in summary for cause in _NON_CV_CAUSES):
        return ClassificationResult(
            event_id=packet.event_id, domain="CV_death", label="non_CV",
            confidence=0.7, rationale=["Non-CV cause documented"], flags=flags,
        )
    if packet.death_witnessed is False or "unwitnessed" in summary or "found dead" in summary:
        return ClassificationResult(
            event_id=packet.event_id, domain="CV_death", label="sudden_cardiac_death",
            confidence=0.6, rationale=["Unwitnessed death without obvious non-CV cause"],
            flags=flags,
        )
    return ClassificationResult(
        event_id=packet.event_id, domain="CV_death", label="undetermined",
        confidence=0.45,
        rationale=["Insufficient causal data; default to undetermined CV"],
        flags=flags,
    )


# ---------------------------------------------------------------------------
# Convenience batch API
# ---------------------------------------------------------------------------


def classify_all(packets: List[EventPacket], include_type2_mi: bool = True
                 ) -> List[ClassificationResult]:
    """Route each packet to the classifier for its suspected category.

    Categories "MI", "Stroke", "HF" and "CV_death" go to classify_mi,
    classify_stroke, classify_hf and classify_cv_death. Anything else yields a
    "non_MACE" / "other" result with confidence 0.5.

    Args:
        packets: Packets to classify, processed in order.
        include_type2_mi: Forwarded to classify_mi as ``include_type2``.

    Returns:
        One ClassificationResult per packet, in input order.
    """
    # Ordered (category, handler) pairs, compared with == in this order.
    routes = (
        ("MI", lambda pkt: classify_mi(pkt, include_type2=include_type2_mi)),
        ("Stroke", classify_stroke),
        ("HF", classify_hf),
        ("CV_death", classify_cv_death),
    )

    def _route(pkt: EventPacket) -> ClassificationResult:
        category = pkt.suspected_category
        for name, handler in routes:
            if category == name:
                return handler(pkt)
        return ClassificationResult(
            event_id=pkt.event_id, domain="non_MACE", label="other",
            confidence=0.5, rationale=["Suspected category not in MACE set"],
        )

    return [_route(pkt) for pkt in packets]
