"""Discharge transition + 30-day readmission Kaplan-Meier.

- discharge insulin / oral 당뇨약 transition
- 30-day readmission KM (DKA / HHS / Hypo / Hyper)
- lifelines 가 설치되어 있으면 정식 KM fit, 아니면 단순 step-wise estimate

이 모듈은 참고용·연구용.
"""
from __future__ import annotations

import logging
from collections import Counter
from dataclasses import dataclass, field
from typing import Dict, Iterable, List, Optional, Tuple

from .ingest import Patient


@dataclass
class TransitionSummary:
    """Cohort-level summary of discharge regimens and 30-day readmissions.

    Instances are produced by ``summarize_transition``; the field comments
    below describe how that function fills them in.
    """

    # Number of patients per ``discharge_insulin`` value.
    discharge_regimen_mix: Dict[str, int]
    # Total number of patients in the cohort.
    n_total: int
    # Number of patients whose ``readmit_30d`` flag is truthy.
    n_readmit_30d: int
    # 100 * n_readmit_30d / n_total, rounded to 2 decimals (0.0 for an empty cohort).
    readmit_rate_pct: float
    # Number of readmitted patients per ``readmit_reason``; a missing/empty
    # reason is counted under "Unknown".
    readmit_reason_mix: Dict[str, int]
    # Median of (readmit_day - discharge_day) over readmitted patients that
    # have a readmit day recorded, rounded to 1 decimal; 0.0 when there are none.
    median_time_to_readmit_d: float


def _median(xs: List[float]) -> float:
    if not xs:
        return 0.0
    s = sorted(xs)
    n = len(s)
    return s[n // 2] if n % 2 else 0.5 * (s[n // 2 - 1] + s[n // 2])


def summarize_transition(patients: Iterable[Patient]) -> TransitionSummary:
    """Summarise discharge regimens and 30-day readmissions for a cohort.

    Args:
        patients: Patients to summarise. The iterable is materialised once, so
            a generator is acceptable.

    Returns:
        A ``TransitionSummary`` holding the regimen mix, the readmission count
        and rate (percent, 2 decimals), the readmission-reason mix (missing
        reasons counted as "Unknown") and the median days from discharge to
        readmission (1 decimal). Rate and median are 0.0 when there is nothing
        to compute them from.
    """
    patients = list(patients)
    mix = Counter(p.discharge_insulin for p in patients)
    readmits = [p for p in patients if p.readmit_30d]
    reasons = Counter((p.readmit_reason or "Unknown") for p in readmits)
    # Test against None explicitly: a plain truthiness check would also drop a
    # recorded readmit_day of 0, and kaplan_meier_30d already treats only None
    # as "no readmission day recorded".
    times = [
        float(p.readmit_day - p.discharge_day)
        for p in readmits
        if p.readmit_day is not None
    ]
    n = len(patients)
    return TransitionSummary(
        discharge_regimen_mix=dict(mix),
        n_total=n,
        n_readmit_30d=len(readmits),
        # max(1, n) keeps an empty cohort from dividing by zero.
        readmit_rate_pct=round(100 * len(readmits) / max(1, n), 2),
        readmit_reason_mix=dict(reasons),
        median_time_to_readmit_d=round(_median(times), 1),
    )


def kaplan_meier_30d(patients: Iterable[Patient]) -> List[Tuple[int, float, int]]:
    """Step-wise KM estimate (day, S(t), n_at_risk) for 30-day readmission.

    Each patient contributes one observation:

    * readmitted (``readmit_30d`` truthy and ``readmit_day`` not None): an
      event at ``readmit_day - discharge_day``, clamped to the range 1..30;
    * otherwise: censored at day 30.

    Uses lifelines when available; falls back to a hand-rolled estimator if
    lifelines cannot be imported or its fit fails for any reason.

    Args:
        patients: Patients to include. The iterable is materialised once.

    Returns:
        A list of ``(day, survival, n_at_risk)`` tuples in increasing day
        order. The two backends are not byte-identical: the lifelines path
        reports survival unrounded on whatever timeline lifelines produces,
        while the fallback rounds survival to 4 decimals and emits only days
        that occur in the data. An empty cohort yields ``[]`` from the
        fallback.
    """
    patients = list(patients)
    times: List[float] = []
    events: List[int] = []
    for p in patients:
        if p.readmit_30d and p.readmit_day is not None:
            t = max(1, p.readmit_day - p.discharge_day)
            times.append(min(30.0, t))
            events.append(1)
        else:
            times.append(30.0)
            events.append(0)

    try:
        from lifelines import KaplanMeierFitter  # type: ignore
        kmf = KaplanMeierFitter()
        kmf.fit(times, event_observed=events)
        rows = []
        for d, s in zip(kmf.survival_function_.index, kmf.survival_function_.iloc[:, 0]):
            rows.append((int(round(float(d))), float(s), int(sum(1 for t in times if t >= float(d)))))
        return rows
    except Exception:
        # Deliberate best-effort: any lifelines problem (not installed, broken
        # install, fit error) falls through to the manual estimator. Log it so
        # the reason is discoverable instead of being silently discarded.
        logging.getLogger(__name__).debug(
            "lifelines KM unavailable or failed; using manual estimator",
            exc_info=True,
        )

    # ---- fallback: manual KM ----
    # Aggregate per integer day in one pass: how many subjects leave the risk
    # set that day (events + censored) and how many of those are events.
    leaving_by_day: Counter = Counter()
    events_by_day: Counter = Counter()
    for t, e in zip(times, events):
        day = int(t)
        leaving_by_day[day] += 1
        if e == 1:
            events_by_day[day] += 1

    out = []
    surv = 1.0
    at_risk = len(times)
    for day in sorted(leaving_by_day):
        d_events = events_by_day[day]
        if at_risk > 0 and d_events > 0:
            surv *= 1 - d_events / at_risk
        # n_at_risk is reported as of the start of the day, i.e. before this
        # day's events and censorings are removed.
        out.append((day, round(surv, 4), at_risk))
        at_risk -= leaving_by_day[day]
    return out


def regimen_vs_readmit(patients: Iterable[Patient]) -> List[Dict[str, float]]:
    """Tabulate the 30-day readmission rate per discharge regimen.

    Returns one dict per distinct ``discharge_insulin`` value, ordered by that
    value, with keys ``regimen``, ``n`` (patients on the regimen),
    ``n_readmit`` (those with a truthy ``readmit_30d``) and
    ``readmit_rate_pct`` (percent, rounded to 1 decimal).
    """
    group_sizes: Counter = Counter()
    group_readmits: Counter = Counter()
    for patient in patients:
        regimen = patient.discharge_insulin
        group_sizes[regimen] += 1
        if patient.readmit_30d:
            group_readmits[regimen] += 1

    table = []
    for regimen in sorted(group_sizes):
        size = group_sizes[regimen]
        readmitted = group_readmits[regimen]
        row = {
            "regimen": regimen,
            "n": size,
            "n_readmit": readmitted,
            "readmit_rate_pct": round(100 * readmitted / size, 1),
        }
        table.append(row)
    return table
