"""POD0-90 outpatient transition + 30-day readmission/mortality KM.

Outpatient follow-up adherence (grouped by ward, surgeon, or procedure):
  - POD7 / POD30 / POD60 / POD90 visit adherence
  - mean %TWL (total weight loss) and mean HbA1c at POD90
    (SBP, LDL, OSA recovery and PHQ-9 are not yet aggregated by adherence_by)
  - GLP-1RA add-on rate, reoperation trigger rate

Survival (Kaplan-Meier-like step):
  - 30-day readmission KM curve (failure = readmission)
  - 30-day mortality KM curve (failure = death)

Pure-stdlib implementation (no lifelines / no scipy dependency for CLI/report).
"""
from __future__ import annotations

import statistics
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple

from .ingest import Patient, POD430Row, POD90Row


@dataclass
class OutptAdherence:
    """Outpatient follow-up adherence and POD90 outcomes for one group.

    One instance per distinct group key, as built by adherence_by.
    Percentage fields are on a 0-100 scale, with n_total as denominator.
    """
    grouping: str            # "ward" / "surgeon" / "procedure"
    key: str                 # group value: the Patient attribute named by `grouping`
    n_total: int             # patients in the group (denominator of every *_pct field)
    pct_pod7: float          # % of group with a POD7 visit recorded
    pct_pod30: float         # % of group with a POD30 visit recorded
    pct_pod60: float         # % of group with a POD60 visit recorded
    pct_pod90: float         # % of group with a POD90 visit recorded
    mean_twl_pct: float      # mean %TWL at POD90; 0.0 when no patient has both weights
    mean_hba1c: Optional[float]  # mean POD90 HbA1c; None when no value was recorded
    glp1_added_pct: float    # % of group with a GLP-1RA added
    reop_trigger_pct: float  # % of group with a reoperation trigger flagged


@dataclass
class TransitionSummary:
    """30-day readmission and mortality summary for a patient cohort.

    Built by transition_summary; rate fields are percentages of n_total.
    """
    n_total: int                 # patients in the cohort
    n_readmit_30d: int           # readmission records flagged readmit_30d
    readmit_rate_pct: float      # n_readmit_30d / n_total * 100, 1 decimal
    median_time_to_readmit_d: Optional[float]  # median readmit_day; None if no day is known
    n_mort_30d: int              # patients flagged died_30d
    mort_rate_pct: float         # n_mort_30d / n_total * 100, 2 decimals
    readmit_reason_mix: Dict[str, int]  # readmit_reason -> count among readmissions


def _avg(xs: List[float]) -> float:
    return round(sum(xs) / len(xs), 2) if xs else 0.0


def _pct(n: int, d: int) -> float:
    return round(100.0 * n / d, 1) if d else 0.0


def adherence_by(grouping: str,
                 patients: List[Patient],
                 pod90: List[POD90Row]) -> List[OutptAdherence]:
    """Summarise POD7-90 outpatient adherence and POD90 outcomes per group.

    Args:
        grouping: Patient attribute to group by: "ward", "surgeon" or
            "procedure".
        patients: Cohort. Every patient counts toward its group's
            denominator, whether or not it has a POD90 row.
        pod90: POD90 follow-up rows, matched to patients by patient_id
            (if several rows share an id, the last one wins).

    Returns:
        One OutptAdherence per distinct group key, sorted by key.

    Raises:
        ValueError: if ``grouping`` is not a supported attribute.
    """
    # Explicit check rather than `assert`, which `python -O` strips.
    if grouping not in ("ward", "surgeon", "procedure"):
        raise ValueError(f"unsupported grouping: {grouping!r}")

    pod90_map = {r.patient_id: r for r in pod90}
    grp: Dict[str, List[Patient]] = {}
    for p in patients:
        grp.setdefault(getattr(p, grouping), []).append(p)

    out: List[OutptAdherence] = []
    for key in sorted(grp):
        plist = grp[key]
        n = len(plist)  # includes patients without any POD90 row
        v7 = v30 = v60 = v90 = 0
        n_glp = n_reop = 0
        twls: List[float] = []
        a1cs: List[float] = []
        # Single pass per group: one dict lookup per patient.
        for p in plist:
            row = pod90_map.get(p.patient_id)
            if not row:
                continue
            if row.visit_pod7:
                v7 += 1
            if row.visit_pod30:
                v30 += 1
            if row.visit_pod60:
                v60 += 1
            if row.visit_pod90:
                v90 += 1
            # %TWL = (weight_pre - weight_pod90) / weight_pre * 100.
            # Gated on both weights being present and non-zero (not on
            # visit_pod90), so a missing weight is excluded from the mean
            # rather than averaged in as 0.
            if row.weight_pod90_kg and p.weight_pre_kg:
                twls.append(
                    100.0 * (p.weight_pre_kg - row.weight_pod90_kg)
                    / p.weight_pre_kg
                )
            if row.hba1c_pod90 is not None:
                a1cs.append(row.hba1c_pod90)
            if row.glp1_ra_added:
                n_glp += 1
            if row.reop_trigger:
                n_reop += 1

        out.append(OutptAdherence(
            grouping=grouping, key=key, n_total=n,
            pct_pod7=_pct(v7, n), pct_pod30=_pct(v30, n),
            pct_pod60=_pct(v60, n), pct_pod90=_pct(v90, n),
            mean_twl_pct=_avg(twls),
            # None (not 0.0) signals "no HbA1c measured in this group".
            mean_hba1c=(round(sum(a1cs) / len(a1cs), 2) if a1cs else None),
            glp1_added_pct=_pct(n_glp, n),
            reop_trigger_pct=_pct(n_reop, n),
        ))
    return out


def transition_summary(patients: List[Patient],
                       pod430: List[POD430Row]) -> TransitionSummary:
    """Summarise 30-day readmission and mortality for a patient cohort.

    Args:
        patients: Cohort; defines n_total and the mortality numerator.
        pod430: Readmission rows. Only rows whose patient_id belongs to
            ``patients`` are counted. Passing a sub-cohort (e.g. one ward)
            together with the full row list therefore still yields a rate
            over that sub-cohort.

    Returns:
        TransitionSummary containing:
        - counts, and rates as percentages of n_total;
        - the median readmission day, which averages the two middle days
          for an even count and is None when no readmission has a day;
        - a readmit_reason -> count mix.
    """
    n = len(patients)
    cohort_ids = {p.patient_id for p in patients}
    # The numerator must come from the same population as the denominator.
    readmits = [r for r in pod430
                if r.readmit_30d and r.patient_id in cohort_ids]
    rd_days = [r.readmit_day for r in readmits if r.readmit_day is not None]
    # statistics.median averages the two middle values for an even count,
    # instead of picking the upper-middle element.
    median = statistics.median(rd_days) if rd_days else None
    n_mt = sum(1 for p in patients if p.died_30d)
    reason_mix: Dict[str, int] = {}
    for r in readmits:
        if r.readmit_reason:
            reason_mix[r.readmit_reason] = reason_mix.get(r.readmit_reason, 0) + 1
    return TransitionSummary(
        n_total=n,
        n_readmit_30d=len(readmits),
        readmit_rate_pct=_pct(len(readmits), n),
        median_time_to_readmit_d=(float(median) if median is not None else None),
        n_mort_30d=n_mt,
        # Mortality is reported at 2 decimals (finer than readmission).
        mort_rate_pct=round(100.0 * n_mt / n, 2) if n else 0.0,
        readmit_reason_mix=reason_mix,
    )


def kaplan_meier_step(patients: List[Patient],
                      pod430: List[POD430Row],
                      endpoint: str = "readmit",
                      horizon_d: int = 30,
                      death_day_if_unknown: int = 20,
                      ) -> List[Tuple[int, float, int]]:
    """Pure-stdlib Kaplan-Meier-style step curve for readmission or mortality.

    Args:
        patients: Cohort; its size is the initial number at risk.
        pod430: Readmission rows matched to patients by patient_id (if
            several rows share an id, the last one wins). Used only for
            ``endpoint="readmit"``.
        endpoint: "readmit" (event = readmit_30d with a known readmit_day)
            or "mortality" (event = died_30d).
        horizon_d: Last day of the curve. Events dated after it are
            censored at the horizon, not piled onto the last day.
        death_day_if_unknown: Day assigned to every death, because died_30d
            carries no date. The default is POD20.

    Returns:
        ``[(day, S(day), n_at_risk), ...]`` for day = 0..horizon_d. Here
        n_at_risk is the count remaining *after* that day's events.
        Same-day (day 0) events are applied to the first row.

    Raises:
        ValueError: if ``endpoint`` is not "readmit" or "mortality".

    Censoring rule: a patient without an in-horizon event is followed for
    the full horizon_d days (no loss to follow-up).
    """
    # Validate up front so a bad endpoint fails even for an empty cohort.
    if endpoint not in ("readmit", "mortality"):
        raise ValueError(f"unknown endpoint: {endpoint}")

    events: Dict[int, int] = {}    # day -> n_events
    rd_map = {r.patient_id: r for r in pod430}
    for p in patients:
        if endpoint == "readmit":
            r = rd_map.get(p.patient_id)
            if not (r and r.readmit_30d and r.readmit_day is not None):
                continue
            raw_day = int(r.readmit_day)
        else:
            if not p.died_30d:
                continue
            raw_day = death_day_if_unknown
        # Negative days are treated as same-day (day 0) events.
        day = max(raw_day, 0)
        if day > horizon_d:
            continue  # event after the horizon: censored, not an event
        events[day] = events.get(day, 0) + 1

    s = 1.0
    at_risk = len(patients)
    out: List[Tuple[int, float, int]] = []
    # Start at day 0 so same-day events are not silently dropped; with no
    # day-0 events the first row is (0, 1.0, n).
    for day in range(max(horizon_d, 0) + 1):
        d_events = events.get(day, 0)
        if at_risk > 0 and d_events > 0:
            s *= 1.0 - d_events / at_risk
        at_risk -= d_events
        out.append((day, round(s, 4), max(at_risk, 0)))
    return out
