"""ADA 2023 / ATTD consensus CGM outcome calculator.

Outcomes:
  1. TIR (70-180)            %
  2. TBR Level 1 (<70)        %
  3. TBR Level 2 (<54)        %
  4. TAR Level 1 (>180)       %
  5. TAR Level 2 (>250)       %
  6. Mean glucose             mg/dL
  7. GMI                      %  = 3.31 + 0.02392 * mean_mg_dl
  8. CV%                      %  = SD / mean * 100
  9. SD                       mg/dL
 10. MAGE (simplified)        mg/dL  (mean of |peak-trough| > 1*SD)
 11. J-index                  = 0.001 * (mean + SD)^2
 12. Hypo events              count of contiguous <70 spans >= 15min
 13. Sensor wear time %       (#actual_5min_samples / expected) * 100
 14. Days with data           int
"""

from __future__ import annotations
import math
import pandas as pd
import numpy as np


def _expected_samples(start, end, interval_min: int = 5) -> int:
    """Return how many readings a sensor should produce between two instants.

    Both endpoints are counted, so a window of exactly ``interval_min``
    minutes expects two readings and a zero-length window expects one.

    Args:
        start: First timestamp of the window (anything supporting ``end - start``
            with a ``total_seconds()`` result).
        end: Last timestamp of the window.
        interval_min: Nominal sampling interval in minutes.

    Returns:
        0 when either endpoint is missing (NaT/NaN/None); otherwise the number
        of sampling slots in the window, never less than 1.
    """
    if any(pd.isna(edge) for edge in (start, end)):
        return 0
    elapsed_min = (end - start).total_seconds() / 60.0
    # Whole intervals that fit in the window, plus the reading at `start`.
    slots = int(elapsed_min / interval_min) + 1
    return slots if slots > 1 else 1


def _mage_simple(g: np.ndarray, sd: float) -> float:
    """Simplified MAGE: mean peak-to-trough swing that exceeds one SD.

    The series is reduced to its turning points: the first reading, every
    local peak or trough, and the last reading.  The absolute difference
    between each pair of consecutive turning points is one excursion, and only
    excursions strictly larger than ``sd`` contribute to the mean.

    This is a simplification: every local extremum is treated as a turning
    point, so small wiggles inside a long rise or fall split it into shorter
    swings.

    Args:
        g: Glucose readings in chronological order.
        sd: Standard deviation of the readings, used as the excursion threshold.

    Returns:
        The mean qualifying excursion amplitude, or NaN when there are fewer
        than three readings, ``sd`` is not positive, or no excursion exceeds
        ``sd``.
    """
    if len(g) < 3 or sd <= 0:
        return float("nan")

    turning_points = [g[0]]
    previous = g[0]
    trend = 0  # +1 rising, -1 falling, 0 not yet known
    for x in g[1:]:
        if x == previous:
            # A plateau carries no direction; wait for the next real move.
            continue
        step = 1 if x > previous else -1
        if trend and step != trend:
            # Direction reversed, so the previous reading was a peak or trough.
            turning_points.append(previous)
        trend = step
        previous = x
    # The final reading closes the last swing.
    turning_points.append(previous)

    swings = np.abs(np.diff(np.asarray(turning_points, dtype=float)))
    excursions = swings[swings > sd]
    if excursions.size == 0:
        return float("nan")
    return float(excursions.mean())


def _hypo_events(df_subj: pd.DataFrame, threshold: float = 70.0,
                 min_minutes: int = 15) -> int:
    """Count hypoglycemia events for one subject.

    An event starts with a contiguous run of readings below ``threshold`` that
    lasts at least ``min_minutes``.  It ends only once readings stay at or
    above ``threshold`` for at least ``min_minutes``; shorter recoveries do not
    split an event, and further low runs before the event has ended are not
    counted again.

    A run's duration is measured from its first reading to the first reading
    of the following run (for the final run: to the last reading).  Time
    without readings between two samples therefore counts toward the run that
    precedes it.

    Args:
        df_subj: Readings of a single subject with columns ``timestamp_KST``
            (datetime-like values supporting subtraction) and
            ``glucose_mg_dl``.  Rows need not be sorted.
        threshold: Readings strictly below this value are hypoglycemic.
        min_minutes: Minimum duration, in minutes, both for a low run to start
            an event and for a recovery run to end it.

    Returns:
        The number of events; 0 for an empty frame or when no reading is below
        ``threshold``.
    """
    if len(df_subj) == 0:
        return 0
    ordered = df_subj.sort_values("timestamp_KST")
    times = ordered["timestamp_KST"].tolist()
    below = (ordered["glucose_mg_dl"] < threshold).tolist()
    if not any(below):
        return 0

    # Positions where the below/above state flips, i.e. where a new run begins.
    run_starts = [0] + [
        pos
        for pos, (prev, cur) in enumerate(zip(below, below[1:]), start=1)
        if cur != prev
    ]
    # A run is considered to last until the next run's first reading; the final
    # run lasts until the last reading.
    run_ends = run_starts[1:] + [len(below) - 1]

    events = 0
    in_event = False
    for first, last in zip(run_starts, run_ends):
        duration_min = (times[last] - times[first]).total_seconds() / 60.0
        if not duration_min >= min_minutes:
            # Too short to start an event or to end one.
            continue
        if below[first]:
            if not in_event:
                events += 1
                in_event = True
        else:
            in_event = False
    return events


def compute_outcomes(df: pd.DataFrame, interval_min: int = 5) -> pd.DataFrame:
    """Return per-subject outcome table (one row per subject_id).

    Rows whose timestamp or glucose value is missing are ignored, including
    timestamps that parse to NaT.  Each subject's readings are put in
    chronological order before the metrics are computed, because MAGE depends
    on the order of the readings.

    Args:
        df: Readings with columns ``subject_id``, ``timestamp_KST`` and
            ``glucose_mg_dl``.  The input frame is not modified.
        interval_min: Nominal sampling interval in minutes, used for the
            expected-sample count behind ``sensor_wear_pct``.

    Returns:
        One row per subject with the columns ``subject_id``, ``n_readings``,
        ``TIR_70_180_pct``, ``TBR_lt70_pct``, ``TBR_lt54_pct``,
        ``TAR_gt180_pct``, ``TAR_gt250_pct``, ``mean_glucose_mg_dl``,
        ``GMI_pct``, ``CV_pct``, ``SD_mg_dl``, ``MAGE_mg_dl``, ``J_index``,
        ``hypo_events_ge15min``, ``sensor_wear_pct`` and ``days_with_data``.
        ``days_with_data`` is derived from the span between the first and last
        reading, so days inside a gap are not subtracted.  An empty,
        column-less frame is returned when there is nothing to summarise.
    """
    if df.empty:
        return pd.DataFrame()
    df = df.copy()
    # Parse before filtering so that values which become NaT (for example empty
    # strings) are dropped together with the already-missing ones.
    df["timestamp_KST"] = pd.to_datetime(df["timestamp_KST"])
    df = df.dropna(subset=["timestamp_KST", "glucose_mg_dl"])
    # groupby keeps the row order inside each group, so one stable sort here
    # hands every subject's readings to the metrics in chronological order.
    df = df.sort_values("timestamp_KST", kind="mergesort")

    rows = []
    for sid, g in df.groupby("subject_id"):
        glu = g["glucose_mg_dl"].astype(float).to_numpy()
        n = len(glu)
        if n == 0:
            continue
        mean = float(np.mean(glu))
        # Sample SD; a single reading has no spread.
        sd = float(np.std(glu, ddof=1)) if n > 1 else 0.0
        cv = (sd / mean * 100.0) if mean > 0 else float("nan")
        gmi = 3.31 + 0.02392 * mean
        tir = float(np.mean((glu >= 70) & (glu <= 180)) * 100.0)
        tbr1 = float(np.mean(glu < 70) * 100.0)
        tbr2 = float(np.mean(glu < 54) * 100.0)
        tar1 = float(np.mean(glu > 180) * 100.0)
        tar2 = float(np.mean(glu > 250) * 100.0)
        mage = _mage_simple(glu, sd)
        j_idx = 0.001 * (mean + sd) ** 2
        hypo = _hypo_events(g)
        start, end = g["timestamp_KST"].min(), g["timestamp_KST"].max()
        expected = _expected_samples(start, end, interval_min)
        # Capped at 100 because more readings than expected slots can occur.
        wear = min(100.0, (n / expected) * 100.0) if expected > 0 else float("nan")
        days = max(1, int(((end - start).total_seconds() / 86400.0) + 1))

        rows.append({
            "subject_id": sid,
            "n_readings": n,
            "TIR_70_180_pct": round(tir, 2),
            "TBR_lt70_pct": round(tbr1, 2),
            "TBR_lt54_pct": round(tbr2, 2),
            "TAR_gt180_pct": round(tar1, 2),
            "TAR_gt250_pct": round(tar2, 2),
            "mean_glucose_mg_dl": round(mean, 1),
            "GMI_pct": round(gmi, 2),
            "CV_pct": round(cv, 2),
            "SD_mg_dl": round(sd, 2),
            "MAGE_mg_dl": round(mage, 2) if not math.isnan(mage) else None,
            "J_index": round(j_idx, 2),
            "hypo_events_ge15min": hypo,
            "sensor_wear_pct": round(wear, 2) if not math.isnan(wear) else None,
            "days_with_data": days,
        })

    return pd.DataFrame(rows)


def detect_gaps(df: pd.DataFrame, gap_min: int = 30) -> pd.DataFrame:
    """Return per-subject list of gaps >= gap_min minutes.

    A gap is the interval between two consecutive readings of the same subject
    (after sorting by time) that is at least ``gap_min`` minutes long.  Rows
    with a missing timestamp or glucose value are ignored.

    Args:
        df: Readings with columns ``subject_id``, ``timestamp_KST`` and
            ``glucose_mg_dl``.  The input frame is not modified.
        gap_min: Minimum spacing, in minutes, reported as a gap.

    Returns:
        A frame with the columns ``subject_id``, ``gap_start``, ``gap_end`` and
        ``gap_minutes`` (rounded to one decimal), one row per gap.  The columns
        are present even when no gap is found.
    """
    columns = ["subject_id", "gap_start", "gap_end", "gap_minutes"]
    if df.empty:
        return pd.DataFrame(columns=columns)
    df = df.dropna(subset=["timestamp_KST", "glucose_mg_dl"]).copy()
    df["timestamp_KST"] = pd.to_datetime(df["timestamp_KST"])

    records = []
    for sid, g in df.groupby("subject_id"):
        stamps = g["timestamp_KST"].sort_values().reset_index(drop=True)
        minutes = stamps.diff().dt.total_seconds() / 60.0
        # Pair every reading with its predecessor; the first pair has no
        # predecessor and a NaN spacing, so it is never reported.
        for before, after, width in zip(stamps.shift(1), stamps, minutes):
            if pd.notna(width) and width >= gap_min:
                records.append({
                    "subject_id": sid,
                    "gap_start": before,
                    "gap_end": after,
                    "gap_minutes": round(float(width), 1),
                })
    # Passing the column list keeps the schema stable when nothing was found.
    return pd.DataFrame(records, columns=columns)


def apply_breakin_cut(df: pd.DataFrame, hours: int = 24) -> pd.DataFrame:
    """Discard each subject's readings from the sensor break-in period.

    For every ``subject_id`` the earliest ``timestamp_KST`` marks the start of
    the timeline; only rows at or after ``start + hours`` are kept.

    Args:
        df: Readings with columns ``subject_id`` and ``timestamp_KST``.
        hours: Length of the break-in window in hours.

    Returns:
        The surviving rows, grouped subject by subject, with a fresh 0..n-1
        index and ``timestamp_KST`` converted to datetimes.  An empty input is
        returned as is; the input frame is never modified.
    """
    if df.empty:
        return df
    work = df.copy()
    work["timestamp_KST"] = pd.to_datetime(work["timestamp_KST"])
    warmup = pd.Timedelta(hours=hours)
    survivors = [
        grp.loc[grp["timestamp_KST"] >= grp["timestamp_KST"].min() + warmup]
        for _, grp in work.groupby("subject_id")
    ]
    if survivors:
        return pd.concat(survivors).reset_index(drop=True)
    # No groups were formed, so there is nothing to keep.
    return work.iloc[0:0]
