"""Event packet ingest, de-identification, and charter mapping.

This module loads suspected MACE event packets (eCRF + troponin time series +
ECG metadata + imaging + discharge summary), performs SHA-256 hashing of patient
identifiers, applies per-site date shifts, and maps fields against a CEC charter
YAML so downstream classification can run against charter-specific definitions.

For research / synthetic data only.
"""

from __future__ import annotations

import copy
import csv
import hashlib
import math
import os
import random
from dataclasses import dataclass, field, asdict
from datetime import datetime, timedelta
from typing import Any, Dict, Iterable, List, Optional

try:
    import yaml  # type: ignore
except ImportError:  # pragma: no cover
    yaml = None


# ---------------------------------------------------------------------------
# Data classes
# ---------------------------------------------------------------------------


@dataclass
class TroponinPoint:
    """One troponin measurement together with the reference limit it is judged against.

    Attributes are positional in the order declared below; ``urln_99th``
    defaults to 14.0 when the caller does not supply a limit.
    """

    timepoint_h: float  # sampling time; hours, going by the field name
    value_ng_l: float  # measured troponin value; ng/L, going by the field name
    urln_99th: float = 14.0  # site-specific 99th percentile URL (ng/L)

    @property
    def above_url(self) -> bool:
        """True when the measured value is strictly greater than the limit."""
        return self.urln_99th < self.value_ng_l

    @property
    def multiple_of_urln(self) -> float:
        """Measured value expressed as a multiple of the limit.

        A zero or negative limit cannot be used as a divisor, so 0.0 is
        returned in that case instead of raising.
        """
        limit = self.urln_99th
        return 0.0 if limit <= 0 else self.value_ng_l / limit


@dataclass
class EventPacket:
    """Flat record for one suspected MACE event plus its troponin series.

    The first eight fields are required; every other field has a default so
    a packet can be built from a partially populated source row.
    """

    # -- required: identifiers, demographics, event header --
    event_id: str
    patient_id: str
    site_id: str
    age: int
    sex: str  # M / F
    treatment_arm: str  # A / B / placebo (blinded)
    event_date: str  # ISO date (already shifted after de-id)
    suspected_category: str  # MI / Stroke / HF / CV_death / other

    # -- optional: supporting evidence --
    troponin_series: List[TroponinPoint] = field(default_factory=list)
    ecg_findings: str = ""
    angiography: str = ""
    imaging_brain: str = ""
    bnp_pg_ml: Optional[float] = None
    ntprobnp_pg_ml: Optional[float] = None
    nihss: Optional[int] = None
    symptom_duration_h: Optional[float] = None
    iv_diuretic: bool = False
    hospitalization_los_days: Optional[int] = None
    death_witnessed: Optional[bool] = None
    death_precursors: str = ""
    discharge_summary: str = ""

    def to_dict(self) -> Dict[str, Any]:
        """Return the packet as a plain dict, troponin points included as dicts."""
        payload = asdict(self)
        # Converted explicitly so every series entry must itself be a dataclass.
        payload["troponin_series"] = list(map(asdict, self.troponin_series))
        return payload


# ---------------------------------------------------------------------------
# De-identification
# ---------------------------------------------------------------------------


def hash_identifier(raw: str, salt: str = "cvot-mace-2026") -> str:
    """Return a salted, truncated SHA-256 pseudonym for ``raw``.

    The digest is computed over ``salt``, the separator ``::`` and ``raw``
    (all UTF-8 encoded), and only the first 16 hex characters are kept.
    """
    material = salt.encode("utf-8") + b"::" + raw.encode("utf-8")
    return hashlib.sha256(material).hexdigest()[:16]


def shift_date(iso_date: str, site_id: str, seed: int = 42) -> str:
    """Shift ``iso_date`` by a per-site offset of -180 .. +180 days.

    The offset is drawn from a generator seeded with ``"{seed}-{site_id}"``,
    so every date for the same site and seed moves by the same number of days.
    Any time-of-day component in the input is dropped from the result.

    A string that ``datetime.fromisoformat`` cannot parse is returned exactly
    as given, i.e. NOT shifted.
    """
    offset_days = random.Random(f"{seed}-{site_id}").randint(-180, 180)
    try:
        parsed = datetime.fromisoformat(iso_date)
    except ValueError:
        return iso_date
    return (parsed + timedelta(days=offset_days)).date().isoformat()


def mask_site(site_id: str) -> str:
    """Mask a raw site identifier as ``"S-"`` plus the first 6 hex chars of its hash.

    The hash is taken with the dedicated ``"site-mask"`` salt rather than the
    default salt of ``hash_identifier``.
    """
    digest = hash_identifier(site_id, salt="site-mask")
    return f"S-{digest[:6]}"


def deidentify_packet(packet: EventPacket) -> EventPacket:
    """Build a de-identified copy of ``packet``; the input is left untouched.

    Transformed fields:
      * ``patient_id``  -> ``hash_identifier`` output
      * ``site_id``     -> ``mask_site`` output
      * ``event_date``  -> ``shift_date`` output, keyed on the RAW site id
      * ``troponin_series`` -> new list holding the same point objects

    Every other field, including the free-text ones (``discharge_summary``,
    ``death_precursors``, ...), is copied across unchanged.
    """
    passthrough = (
        "event_id",
        "age",
        "sex",
        "treatment_arm",
        "suspected_category",
        "ecg_findings",
        "angiography",
        "imaging_brain",
        "bnp_pg_ml",
        "ntprobnp_pg_ml",
        "nihss",
        "symptom_duration_h",
        "iv_diuretic",
        "hospitalization_los_days",
        "death_witnessed",
        "death_precursors",
        "discharge_summary",
    )
    kwargs: Dict[str, Any] = {name: getattr(packet, name) for name in passthrough}

    raw_site = packet.site_id
    kwargs["patient_id"] = hash_identifier(packet.patient_id)
    kwargs["site_id"] = mask_site(raw_site)
    # The shift is derived from the raw site id, not the masked one.
    kwargs["event_date"] = shift_date(packet.event_date, raw_site)
    kwargs["troponin_series"] = list(packet.troponin_series)
    return EventPacket(**kwargs)


# ---------------------------------------------------------------------------
# CSV loaders
# ---------------------------------------------------------------------------


def _to_float(s: str) -> Optional[float]:
    if s is None or s == "" or s.lower() == "na":
        return None
    try:
        return float(s)
    except ValueError:
        return None


def _to_int(s: str) -> Optional[int]:
    """Parse a CSV cell as an int, or return None when it is missing.

    The cell is first parsed with ``_to_float`` and then truncated toward
    zero by ``int()`` (so ``"3.7"`` becomes 3). A value that cannot be turned
    into an int yields None.
    """
    number = _to_float(s)
    if number is None:
        return None
    try:
        return int(number)
    except (ValueError, OverflowError):
        # int() raises ValueError for NaN and OverflowError for +/-infinity.
        return None


def _to_bool(s: str) -> Optional[bool]:
    if s is None or s == "":
        return None
    return s.strip().lower() in ("1", "true", "y", "yes")


def _text_field(row: Dict[str, Any], key: str) -> str:
    """Return ``row[key]`` as text, mapping an absent column or None cell to ""."""
    # csv.DictReader fills cells missing from a short row with None, which a
    # plain ``row.get(key, "")`` would pass straight through.
    return row.get(key) or ""


def _load_troponin_index(troponin_csv: Optional[str]) -> Dict[str, List[TroponinPoint]]:
    """Group troponin rows by event_id, each group sorted by ``timepoint_h``."""
    index: Dict[str, List[TroponinPoint]] = {}
    # A missing path or non-existent file yields an empty index, not an error.
    if not troponin_csv or not os.path.exists(troponin_csv):
        return index
    with open(troponin_csv, newline="", encoding="utf-8") as f:
        for row in csv.DictReader(f):
            eid = row["event_id"]
            point = TroponinPoint(
                timepoint_h=float(row["timepoint_h"]),
                value_ng_l=float(row["value_ng_l"]),
                # Absent column or blank cell falls back to 14.0.
                urln_99th=float(row.get("urln_99th") or 14.0),
            )
            index.setdefault(eid, []).append(point)
    for series in index.values():
        series.sort(key=lambda p: p.timepoint_h)
    return index


def load_packets_csv(packets_csv: str, troponin_csv: Optional[str] = None) -> List[EventPacket]:
    """Load event packets from CSV and optionally attach troponin time series.

    Args:
        packets_csv: Path to the packet CSV. The columns ``event_id``,
            ``patient_id``, ``site_id``, ``age``, ``sex``, ``treatment_arm``,
            ``event_date`` and ``suspected_category`` are required; all other
            columns are optional.
        troponin_csv: Optional path to a CSV with ``event_id``,
            ``timepoint_h``, ``value_ng_l`` and optionally ``urln_99th``.
            If the path is empty or the file does not exist, packets are
            loaded with empty troponin series.

    Returns:
        One ``EventPacket`` per row of ``packets_csv``, in file order. Each
        packet gets its own troponin list, sorted by ``timepoint_h``.
        Optional text fields are always ``str`` (``""`` when the column or
        cell is missing).

    Raises:
        OSError: if ``packets_csv`` cannot be opened.
        KeyError: if a required column is absent.
        ValueError: if ``age`` or a troponin ``timepoint_h`` / ``value_ng_l``
            is not numeric.
    """
    trop_index = _load_troponin_index(troponin_csv)

    packets: List[EventPacket] = []
    with open(packets_csv, newline="", encoding="utf-8") as f:
        for row in csv.DictReader(f):
            event_id = row["event_id"]
            packets.append(
                EventPacket(
                    event_id=event_id,
                    patient_id=row["patient_id"],
                    site_id=row["site_id"],
                    age=int(row["age"]),
                    sex=row["sex"],
                    treatment_arm=row["treatment_arm"],
                    event_date=row["event_date"],
                    suspected_category=row["suspected_category"],
                    # Copy so rows sharing an event_id never share one list.
                    troponin_series=list(trop_index.get(event_id, ())),
                    ecg_findings=_text_field(row, "ecg_findings"),
                    angiography=_text_field(row, "angiography"),
                    imaging_brain=_text_field(row, "imaging_brain"),
                    bnp_pg_ml=_to_float(row.get("bnp_pg_ml")),
                    ntprobnp_pg_ml=_to_float(row.get("ntprobnp_pg_ml")),
                    nihss=_to_int(row.get("nihss")),
                    symptom_duration_h=_to_float(row.get("symptom_duration_h")),
                    # iv_diuretic is a plain bool: "not recorded" collapses to False.
                    iv_diuretic=bool(_to_bool(row.get("iv_diuretic"))),
                    hospitalization_los_days=_to_int(row.get("hospitalization_los_days")),
                    death_witnessed=_to_bool(row.get("death_witnessed")),
                    death_precursors=_text_field(row, "death_precursors"),
                    discharge_summary=_text_field(row, "discharge_summary"),
                )
            )
    return packets


# ---------------------------------------------------------------------------
# Charter mapping
# ---------------------------------------------------------------------------

# Built-in charter. load_charter() returns these values when no charter file
# can be used, and uses them as the base onto which a loaded charter's
# top-level keys are applied.
DEFAULT_CHARTER: Dict[str, Any] = {
    "trial_id": "DEMO-CVOT-001",
    # The next five keys are the ones map_packet_to_charter() reads.
    "primary_endpoint": "3p-MACE",  # 3p-MACE | 4p-MACE | 5p-MACE
    "include_type2_mi": False,
    "mi_definition": "UDM-2018",
    "stroke_definition": "AHA-ASA-2013",
    "hf_definition": "ESC-2021-HHF",
    # Not read by any function visible in this module; presumably consumed
    # by downstream classification code (confirm against callers).
    "cv_death_categories": [
        "fatal_MI",
        "fatal_HF",
        "sudden_cardiac_death",
        "fatal_stroke",
        "other_CV",
    ],
    # Not read by any function visible in this module either.
    "adjudication": {
        "n_readers": 2,
        "blinded": True,
        "discordance_routing": "3rd_reader_then_panel",
        "calibration_interval_n_cases": 20,
    },
}


def load_charter(path: Optional[str]) -> Dict[str, Any]:
    """Load a CEC charter YAML file on top of ``DEFAULT_CHARTER``.

    Args:
        path: Path to the charter YAML, or None / "" to use the defaults.

    Returns:
        A new dict that is independent of ``DEFAULT_CHARTER`` (nested values
        are copied too, so callers may mutate it freely). Top-level keys
        from the file replace the defaults wholesale; nested sections are
        not merged key by key. When the file cannot be used, the defaults are
        returned with a ``"_note"`` entry explaining why: PyYAML missing,
        read/parse error, or a document whose top level is not a mapping.
        A missing path or file returns the plain defaults without a note.
    """
    # Deep copy: the defaults contain a nested dict and list that must not be
    # shared with (and mutated through) the returned charter.
    charter = copy.deepcopy(DEFAULT_CHARTER)
    if not path or not os.path.exists(path):
        return charter
    if yaml is None:
        # YAML not installed — surface default with note
        charter["_note"] = "PyYAML not installed; default charter applied"
        return charter
    try:
        with open(path, "r", encoding="utf-8") as f:
            data = yaml.safe_load(f) or {}
    except Exception as e:  # noqa: BLE001 - deliberate best-effort fallback
        charter["_note"] = f"charter parse error: {e}"
        return charter
    if not isinstance(data, dict):
        # Valid YAML whose top level is a list or scalar cannot be merged.
        charter["_note"] = (
            "charter parse error: expected a mapping at top level, "
            f"got {type(data).__name__}"
        )
        return charter
    charter.update(data)
    return charter


def map_packet_to_charter(packet: EventPacket, charter: Dict[str, Any]) -> Dict[str, Any]:
    """Collect the charter settings that apply to ``packet`` into one dict.

    Only ``packet.event_id`` is read from the packet. Each charter setting
    falls back to a fixed default when the key is absent from ``charter``;
    ``include_type2_mi`` is additionally coerced to ``bool``.
    """
    flags: Dict[str, Any] = {"event_id": packet.event_id}
    flags["primary_endpoint"] = charter.get("primary_endpoint", "3p-MACE")
    flags["include_type2_mi"] = bool(charter.get("include_type2_mi", False))
    definition_defaults = (
        ("mi_definition", "UDM-2018"),
        ("stroke_definition", "AHA-ASA-2013"),
        ("hf_definition", "ESC-2021-HHF"),
    )
    for key, fallback in definition_defaults:
        flags[key] = charter.get(key, fallback)
    return flags


# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------


def deidentify_all(packets: Iterable[EventPacket]) -> List[EventPacket]:
    """Apply ``deidentify_packet`` to every packet, returning a list in input order."""
    return list(map(deidentify_packet, packets))
