"""DKA / HHS / Hypoglycemia episode trajectory + protocol compliance.

ADA DKA/HHS protocol (abridged):
  - IV insulin 0.1 U/kg/h (DKA) / 0.05-0.1 (HHS)
  - K+ supplementation when K < 5.2 mEq/L
  - resolution: pH > 7.3 + bicarb > 18 + anion gap < 12 (DKA)
  - HHS resolution: glucose < 300 + osm normalized + mental status restored
  - then transition to SC basal-bolus + 1-2h overlap

This module is for reference and research use only.
"""
from __future__ import annotations

from collections import Counter, defaultdict
from dataclasses import dataclass
from typing import Dict, Iterable, List, Optional, Tuple

from .ingest import Episode, Patient


@dataclass
class EpisodeTrajectory:
    """Aggregate summary of all episodes that share one ``episode_type``.

    Instances are produced by ``summarize_episodes``. Percentage fields are on
    a 0-100 scale rounded to one decimal place.
    """
    # Grouping key, copied from ``Episode.episode_type``.
    episode_type: str
    # Number of episodes in the group.
    n_total: int
    # Number of episodes in the group whose ``resolved`` flag is truthy.
    n_resolved: int
    # ``n_resolved`` as a percentage of ``n_total``.
    resolution_rate: float
    # Median of the non-None ``resolution_h`` values (1 decimal); 0.0 when the
    # group has no such value. Presumably hours, judging by the field name.
    median_resolution_h: float
    # Percentage of episodes with a truthy ``k_supplement`` flag.
    pct_k_supplement: float
    # Median of ``fluid_l_24h`` across the group, rounded to 1 decimal.
    median_fluid_l_24h: float
    # Median of ``iv_insulin_rate_uhr`` across the group, rounded to 2 decimals.
    median_iv_rate_uhr: float
    # Percentage of episodes whose ``transition_regimen`` is "basal-bolus".
    pct_transition_basal_bolus: float
    # Percentage of episodes for which ``_compliance`` returns a truthy value.
    protocol_compliance_pct: float


def _median(xs: List[float]) -> float:
    """Return the median of ``xs``, or 0.0 when ``xs`` is empty.

    For an even number of values the result is the mean of the two middle
    values; for an odd number it is the middle value itself.
    """
    if not xs:
        return 0.0
    ordered = sorted(xs)
    mid, has_single_middle = divmod(len(ordered), 2)
    if has_single_middle:
        return ordered[mid]
    # Even count: average the two values straddling the centre.
    return 0.5 * (ordered[mid - 1] + ordered[mid])


def _compliance(ep: Episode) -> bool:
    """Return whether an episode passes the heuristic protocol check.

    The check is all-or-none: every criterion listed for the episode type
    must hold.

    - ``"DKA"``: ``iv_insulin_rate_uhr`` within [0.05, 0.2], ``k_supplement``
      truthy, ``fluid_l_24h`` >= 3.0 and ``transition_regimen`` equal to
      ``"basal-bolus"``.
    - ``"HHS"``: ``iv_insulin_rate_uhr`` within [0.03, 0.15], ``fluid_l_24h``
      >= 4.0 and ``transition_regimen`` in ``("basal-bolus", "basal-only")``.
    - ``"Hypoglycemia"``: ``resolved`` truthy and ``resolution_h`` either
      ``None`` or <= 1.0.
    - ``"Hyperglycemia-Persistent"``: ``transition_regimen`` in
      ``("basal-bolus", "IV-infusion")``.
    - Any other episode type is treated as compliant.

    The result is always a real ``bool``. A bare ``and`` chain returns one of
    its operands, so a falsy non-bool flag (e.g. ``0`` or ``None``) would
    otherwise leak out as the return value.
    """
    kind = ep.episode_type
    if kind == "DKA":
        return bool(
            0.05 <= ep.iv_insulin_rate_uhr <= 0.2
            and ep.k_supplement
            and ep.fluid_l_24h >= 3.0
            and ep.transition_regimen == "basal-bolus"
        )
    if kind == "HHS":
        return bool(
            0.03 <= ep.iv_insulin_rate_uhr <= 0.15
            and ep.fluid_l_24h >= 4.0
            and ep.transition_regimen in ("basal-bolus", "basal-only")
        )
    if kind == "Hypoglycemia":
        # Protocol: D50 rescue + glucose recheck in 15 min; "resolved with a
        # short (or unrecorded) resolution time" is used as a proxy.
        return bool(
            ep.resolved and (ep.resolution_h is None or ep.resolution_h <= 1.0)
        )
    if kind == "Hyperglycemia-Persistent":
        return ep.transition_regimen in ("basal-bolus", "IV-infusion")
    # Episode types without a defined rule are not penalised.
    return True


def summarize_episodes(episodes: Iterable[Episode]) -> List[EpisodeTrajectory]:
    """Aggregate episodes into one ``EpisodeTrajectory`` per episode type.

    Episodes whose ``episode_type`` is the string ``"None"`` are ignored.
    The returned list is ordered by episode type. Percentages are on a 0-100
    scale rounded to one decimal; medians come from ``_median`` and are
    rounded to one decimal (two for the IV insulin rate).
    """
    grouped: Dict[str, List[Episode]] = defaultdict(list)
    for episode in episodes:
        if episode.episode_type != "None":
            grouped[episode.episode_type].append(episode)

    def pct(hits: int, total: int) -> float:
        # max(1, total) guards the division.
        return round(100 * hits / max(1, total), 1)

    summaries = []
    for kind in sorted(grouped):
        members = grouped[kind]
        total = len(members)

        resolved_count = len([m for m in members if m.resolved])
        # Only episodes with a recorded resolution time feed the median.
        known_resolution_h = [
            m.resolution_h for m in members if m.resolution_h is not None
        ]
        k_count = len([m for m in members if m.k_supplement])
        fluids = [m.fluid_l_24h for m in members]
        iv_rates = [m.iv_insulin_rate_uhr for m in members]
        basal_bolus_count = len(
            [m for m in members if m.transition_regimen == "basal-bolus"]
        )
        compliant_count = len([m for m in members if _compliance(m)])

        summary = EpisodeTrajectory(
            episode_type=kind,
            n_total=total,
            n_resolved=resolved_count,
            resolution_rate=pct(resolved_count, total),
            median_resolution_h=round(_median(known_resolution_h), 1),
            pct_k_supplement=pct(k_count, total),
            median_fluid_l_24h=round(_median(fluids), 1),
            median_iv_rate_uhr=round(_median(iv_rates), 2),
            pct_transition_basal_bolus=pct(basal_bolus_count, total),
            protocol_compliance_pct=pct(compliant_count, total),
        )
        summaries.append(summary)
    return summaries


def per_patient_episode_table(patients: Iterable[Patient],
                              episodes: Iterable[Episode]) -> List[Dict[str, object]]:
    """Build one flat row per episode, annotated with ward and compliance.

    Episodes whose ``episode_type`` is the string ``"None"`` are skipped.

    Args:
        patients: Patients used to look up each episode's ward by
            ``patient_id``; if an id occurs more than once the last ward wins.
        episodes: Episodes to tabulate, kept in input order.

    Returns:
        A list of dicts with the keys ``patient_id``, ``ward``,
        ``episode_type``, ``duration_h``, ``resolved``, ``resolution_h``,
        ``transition_regimen`` and ``protocol_compliant``. ``ward`` is ``"?"``
        when the episode's patient id is not among ``patients``. Values keep
        the types stored on the episode (``resolution_h`` may be ``None``) and
        ``protocol_compliant`` is the result of ``_compliance``, so the rows
        are not string-only.
    """
    ward_by_patient = {p.patient_id: p.ward for p in patients}
    return [
        {
            "patient_id": e.patient_id,
            "ward": ward_by_patient.get(e.patient_id, "?"),
            "episode_type": e.episode_type,
            "duration_h": e.duration_h,
            "resolved": e.resolved,
            "resolution_h": e.resolution_h,
            "transition_regimen": e.transition_regimen,
            "protocol_compliant": _compliance(e),
        }
        for e in episodes
        if e.episode_type != "None"
    ]
