"""Ward-level ADA / KDA inpatient TIR & hypo/hyper KPI.

Reference ranges (참고):
  - ADA Standards of Care 2025 inpatient: TIR 100-180 (일반병동), 140-180 (ICU)
  - hypo: L1 <70, L2 <54, L3 <40 (severe)
  - hyper: >180, >250, >300, >400 (persistent >180 ≥6h flag)
  - KDA 2023 입원당뇨: 유사한 범위, ICU 140-180 권고

이 모듈은 의학적 의사결정 도구가 아님 — 참고용·연구용.
"""
from __future__ import annotations

from collections import defaultdict
from dataclasses import dataclass, field
from typing import Dict, Iterable, List, Tuple

from .ingest import Patient, POCTReading


# Ward codes that use the ICU TIR window in _tir_window.
ICU_WARDS = {"CCU", "MICU", "NICU", "SICU"}


def _tir_window(ward: str) -> Tuple[int, int]:
    """Return the inclusive ``(low, high)`` TIR target for ``ward``.

    Wards listed in ``ICU_WARDS`` get ``(140, 180)``; every other ward gets
    ``(100, 180)``.
    """
    low = 140 if ward in ICU_WARDS else 100
    return (low, 180)


@dataclass
class WardTIR:
    """Per-ward glucose KPI summary row, as built by compute_ward_tir.

    Percentage fields are 0..100 shares of the ward's readings. Hypo and
    hyper levels are cumulative: a <54 reading also counts toward <70, and a
    >250 reading also counts toward >180.
    """
    ward: str
    n_patients: int           # distinct patients with readings; census count if the ward has none
    n_readings: int           # number of POCT readings attributed to this ward
    tir_low: int              # inclusive lower bound of the TIR window (mg/dL)
    tir_high: int             # inclusive upper bound of the TIR window (mg/dL)
    mean_bg: float            # mean glucose of the ward's readings (0.0 if none)
    median_bg: float          # median glucose of the ward's readings (0.0 if none)
    pct_in_range: float       # % of readings with tir_low <= glucose <= tir_high
    pct_hypo_l1: float        # % of readings <70
    pct_hypo_l2: float        # % of readings <54
    pct_hypo_l3: float        # % of readings <40
    pct_hyper_180: float      # % of readings >180
    pct_hyper_250: float      # % of readings >250
    pct_hyper_300: float      # % of readings >300
    pct_hyper_400: float      # % of readings >400
    n_persistent_hyper: int   # ward total of persistent >180 flags from _count_persistent_hyper
    kda_compliance_score: float  # heuristic compliance score, clamped to 0..100


def _mean(xs: List[float]) -> float:
    return sum(xs) / len(xs) if xs else 0.0


def _median(xs: List[float]) -> float:
    if not xs:
        return 0.0
    s = sorted(xs)
    n = len(s)
    return s[n // 2] if n % 2 else 0.5 * (s[n // 2 - 1] + s[n // 2])


def _count_persistent_hyper(readings: List[POCTReading], threshold: int = 180,
                            hours: int = 6) -> int:
    """A coarse flag — number of (patient, study_day) windows with ≥`hours`
    consecutive hours of glucose > threshold."""
    if not readings:
        return 0
    by_pat_day: Dict[Tuple[str, int], List[POCTReading]] = defaultdict(list)
    for r in readings:
        by_pat_day[(r.patient_id, r.study_day)].append(r)
    count = 0
    for key, rows in by_pat_day.items():
        rows.sort(key=lambda r: r.hour)
        run = 0
        prev_hour = None
        for r in rows:
            if r.glucose_mg_dl > threshold:
                # treat any reading >180 as covering 1 hour;
                # require run of hours covered ≥ `hours`
                if prev_hour is None or r.hour - prev_hour <= 2:
                    run += max(1, r.hour - prev_hour if prev_hour is not None else 1)
                else:
                    run = 1
                if run >= hours:
                    count += 1
                    run = 0
                    prev_hour = None
                    continue
                prev_hour = r.hour
            else:
                run = 0
                prev_hour = None
    return count


def _kda_compliance_score(pct_in: float, pct_hypo_l1: float,
                          n_persistent: int, n_patients: int) -> float:
    """Heuristic KDA compliance score, clamped to 0..100.

    The score starts from a base of 40:
      - plus 0.6 per TIR percentage point;
      - minus 1.5 per percentage point of <70 readings;
      - minus 5 per persistent-hyper flag per patient.

    ``n_patients`` is floored at 1 to avoid division by zero.
    """
    raw = (0.6 * pct_in
           - 1.5 * pct_hypo_l1
           - 0.05 * (n_persistent / max(1, n_patients)) * 100
           + 40)
    return max(0.0, min(100.0, raw))


def compute_ward_tir(patients: Iterable[Patient],
                     poct: Iterable[POCTReading]) -> List[WardTIR]:
    """Aggregate POCT glucose readings into one KPI row per ward.

    Attribution:
      - Each reading is attributed to the ward of its patient, looked up by
        ``patient_id`` in ``patients``.
      - Readings whose patient is not in ``patients`` are skipped.
      - If a ``patient_id`` appears in several ``patients`` records, the last
        record's ward is used.

    Every ward that has a patient record or an attributed reading gets a
    row, in sorted ward order.

    ``n_patients`` is the number of distinct patients with at least one
    reading in the ward. For a ward without readings it falls back to the
    number of ``patients`` records assigned to that ward.

    Percentages are shares of the ward's readings (all 0 when it has none).
    Hypo and hyper levels are cumulative.

    Args:
        patients: patient records exposing ``patient_id`` and ``ward``.
        poct: POCT readings exposing ``patient_id``, ``study_day``, ``hour``
            and ``glucose_mg_dl``.

    Returns:
        One ``WardTIR`` per ward, sorted by ward name.
    """
    patients = list(patients)  # iterated more than once below
    pid_to_ward = {p.patient_id: p.ward for p in patients}
    # Patient records per ward, counted once up front (instead of rescanning
    # `patients` per ward); only used as the n_patients fallback.
    census: Dict[str, int] = defaultdict(int)
    for p in patients:
        census[p.ward] += 1

    by_ward_readings: Dict[str, List[POCTReading]] = defaultdict(list)
    by_ward_patients: Dict[str, set] = defaultdict(set)
    for r in poct:
        w = pid_to_ward.get(r.patient_id)
        if w is None:
            continue  # reading for a patient not present in `patients`
        by_ward_readings[w].append(r)
        by_ward_patients[w].add(r.patient_id)

    out: List[WardTIR] = []
    for w in sorted(set(by_ward_readings) | set(census)):
        rs = by_ward_readings.get(w, [])
        n_pat = len(by_ward_patients.get(w, ())) or census.get(w, 0)
        vals = [r.glucose_mg_dl for r in rs]
        low, high = _tir_window(w)
        n = len(vals) or 1  # avoid division by zero; every pct is 0 then
        in_range = sum(1 for v in vals if low <= v <= high)
        hypo1 = sum(1 for v in vals if v < 70)
        hypo2 = sum(1 for v in vals if v < 54)
        hypo3 = sum(1 for v in vals if v < 40)
        hy180 = sum(1 for v in vals if v > 180)
        hy250 = sum(1 for v in vals if v > 250)
        hy300 = sum(1 for v in vals if v > 300)
        hy400 = sum(1 for v in vals if v > 400)

        pers = _count_persistent_hyper(rs)
        pct_in = 100 * in_range / n
        pct_hypo1 = 100 * hypo1 / n
        pct_hy180 = 100 * hy180 / n

        # NOTE(review): a ward with no readings gets 0% TIR and therefore the
        # base score of 40, which can rank it above wards with real but poor
        # data in ward_ranking — confirm this is intended.
        kda = _kda_compliance_score(pct_in, pct_hypo1, pers, n_pat)
        out.append(WardTIR(
            ward=w,
            n_patients=n_pat,
            n_readings=len(vals),
            tir_low=low,
            tir_high=high,
            mean_bg=round(_mean(vals), 1),
            median_bg=round(_median(vals), 1),
            pct_in_range=round(pct_in, 1),
            pct_hypo_l1=round(pct_hypo1, 2),
            pct_hypo_l2=round(100 * hypo2 / n, 2),
            pct_hypo_l3=round(100 * hypo3 / n, 2),
            pct_hyper_180=round(pct_hy180, 1),
            pct_hyper_250=round(100 * hy250 / n, 1),
            pct_hyper_300=round(100 * hy300 / n, 1),
            pct_hyper_400=round(100 * hy400 / n, 2),
            n_persistent_hyper=pers,
            kda_compliance_score=round(kda, 1),
        ))
    return out


def ward_ranking(rows: List[WardTIR]) -> List[Tuple[str, float]]:
    """Return ``(ward, kda_compliance_score)`` pairs, highest score first.

    Wards with equal scores keep their input order (the sort is stable).
    """
    pairs = [(row.ward, row.kda_compliance_score) for row in rows]
    pairs.sort(key=lambda pair: -pair[1])
    return pairs
