"""
analysis.py
===========
RodentCGMTel 핵심 metric 계산 모듈 (pure pandas/numpy/scipy).
Streamlit 없이 import 가능. CLI/배치 분석에서도 그대로 사용.

기능:
- Multi-format raw ingest (Ponemah/Empatica/Medtronic/Eversense/Dexcom rodent)
- Animal-adapted metrics: TIR/TAR/TBR, MAGE, CV%, MODD, GMI, dawn phenomenon,
  nocturnal/diurnal split, drug action time-window
- Challenge analysis: AUC/iAUC/peak BG/Tmax/Tlate/recovery slope
- Cohort summary + group comparison (one-way ANOVA + Bonferroni-corrected pairwise Welch t-tests)

DISCLAIMER: 연구·참고용. 임상의사결정용 아님.
"""

from __future__ import annotations

import math
from dataclasses import dataclass, field, asdict
from typing import Optional, Sequence

import numpy as np
import pandas as pd
from scipy import stats

try:
    from species_reference import (
        SPECIES_REFERENCE,
        DEFAULT_LIGHT_CYCLE,
        get_reference,
    )
except ImportError:  # 패키지화 시도
    from .species_reference import (  # type: ignore
        SPECIES_REFERENCE,
        DEFAULT_LIGHT_CYCLE,
        get_reference,
    )


# ------------------------------------------------------------------
# 1. INGEST: multi-format raw → 표준 schema
# ------------------------------------------------------------------
# Standard long-format schema produced by every loader in this section, in this
# column order: timestamp, animal_id, model, bg_mgdl (glucose in mg/dL; loaders
# convert mmol/L input), group, source_format.

STANDARD_COLUMNS = ["timestamp", "animal_id", "model", "bg_mgdl", "group", "source_format"]


def _coerce_bg_to_mgdl(values: pd.Series, unit: str) -> pd.Series:
    """Return *values* as float mg/dL, converting from mmol/L when needed.

    *unit* is matched case-insensitively after stripping whitespace. Accepted
    spellings are mg/dl, mgdl, mg_dl (no conversion) and mmol/l, mmoll,
    mmol_l (multiplied by 18.0).

    Raises:
        ValueError: for any other unit string.
    """
    key = unit.lower().strip()
    if key in {"mmol/l", "mmoll", "mmol_l"}:
        return values.astype(float) * 18.0
    if key in {"mg/dl", "mgdl", "mg_dl"}:
        return values.astype(float)
    raise ValueError(f"Unknown BG unit: {key}")


def load_ponemah(path: str, animal_id: str, model: str, group: str = "exp",
                 bg_col: str = "Glucose_mgdl", ts_col: str = "Time") -> pd.DataFrame:
    """Load a Ponemah (DSI telemetry) export into the standard schema.

    The file is read tab-separated when its name ends in ``.txt`` or ``.tsv``
    (matched case-insensitively, so ``RAT1.TXT`` works), comma-separated
    otherwise.

    Args:
        path: File to read (``str`` or path-like).
        animal_id: Label stamped on every row.
        model: Model label stamped on every row.
        group: Group label stamped on every row.
        bg_col: Column holding glucose in mg/dL.
        ts_col: Column holding the sample timestamp.

    Returns:
        DataFrame with columns timestamp, animal_id, model, bg_mgdl, group,
        source_format (= "ponemah").
    """
    name = str(path).lower()
    sep = "\t" if name.endswith((".txt", ".tsv")) else ","
    df = pd.read_csv(path, sep=sep)
    out = pd.DataFrame({
        "timestamp": pd.to_datetime(df[ts_col]),
        "animal_id": animal_id,
        "model": model,
        "bg_mgdl": df[bg_col].astype(float),
        "group": group,
        "source_format": "ponemah",
    })
    return out


def load_empatica(path: str, animal_id: str, model: str, group: str = "exp") -> pd.DataFrame:
    """Load an Empatica E4-style CSV into the standard schema.

    Layout: line 1 holds the session start (unix seconds), line 2 the sampling
    rate in Hz, and every following line one reading. Only the first
    comma-separated field of each line is used. A non-positive rate falls back
    to a 300 s spacing between readings.
    """
    with open(path, "r") as handle:
        start_s, rate_hz = (float(handle.readline().strip().split(",")[0]) for _ in range(2))
    readings = pd.read_csv(path, skiprows=2, header=None).iloc[:, 0].astype(float)
    if rate_hz > 0:
        step_s = 1.0 / rate_hz
    else:
        step_s = 300.0
    offsets_s = np.arange(len(readings)) * step_s
    timestamps = pd.to_datetime(start_s + offsets_s, unit="s")
    return pd.DataFrame({
        "timestamp": timestamps,
        "animal_id": animal_id,
        "model": model,
        "bg_mgdl": readings,
        "group": group,
        "source_format": "empatica",
    })


def load_medtronic(path: str, animal_id: str, model: str, group: str = "exp") -> pd.DataFrame:
    """Load a rodent-adapted Medtronic Guardian CSV into the standard schema.

    Expected columns: ``Date``, ``Time`` (joined with a space to build the
    timestamp) and ``Sensor Glucose (mg/dL)``.
    """
    raw = pd.read_csv(path)
    date_part = raw["Date"].astype(str)
    time_part = raw["Time"].astype(str)
    stamps = pd.to_datetime(date_part + " " + time_part)
    glucose = raw["Sensor Glucose (mg/dL)"].astype(float)
    return pd.DataFrame({
        "timestamp": stamps,
        "animal_id": animal_id,
        "model": model,
        "bg_mgdl": glucose,
        "group": group,
        "source_format": "medtronic",
    })


def load_eversense(path: str, animal_id: str, model: str, group: str = "exp") -> pd.DataFrame:
    """Load a rodent Eversense CSV (``Timestamp``, ``Glucose_mmol_per_L``).

    Glucose is converted from mmol/L to mg/dL via ``_coerce_bg_to_mgdl``.
    """
    raw = pd.read_csv(path)
    stamps = pd.to_datetime(raw["Timestamp"])
    glucose = _coerce_bg_to_mgdl(raw["Glucose_mmol_per_L"], "mmol/L")
    return pd.DataFrame({
        "timestamp": stamps,
        "animal_id": animal_id,
        "model": model,
        "bg_mgdl": glucose,
        "group": group,
        "source_format": "eversense",
    })


def load_dexcom(path: str, animal_id: str, model: str, group: str = "exp") -> pd.DataFrame:
    """Load a rodent Dexcom G6 CSV into the standard schema.

    Prefers the Clarity-style headers ``Timestamp (YYYY-MM-DDThh:mm:ss)`` and
    ``Glucose Value (mg/dL)``, falling back to ``Timestamp`` and
    ``Glucose_mgdl`` when those are absent.
    """
    raw = pd.read_csv(path)
    available = set(raw.columns)
    ts_col = "Timestamp (YYYY-MM-DDThh:mm:ss)"
    if ts_col not in available:
        ts_col = "Timestamp"
    bg_col = "Glucose Value (mg/dL)"
    if bg_col not in available:
        bg_col = "Glucose_mgdl"
    stamps = pd.to_datetime(raw[ts_col])
    glucose = raw[bg_col].astype(float)
    return pd.DataFrame({
        "timestamp": stamps,
        "animal_id": animal_id,
        "model": model,
        "bg_mgdl": glucose,
        "group": group,
        "source_format": "dexcom",
    })


def load_standard_csv(path: str) -> pd.DataFrame:
    """Load a CSV already stored in the standard schema (RodentCGMTel synthetic/export format).

    Returns a copy restricted to ``STANDARD_COLUMNS`` in that order.

    Raises:
        ValueError: if any standard column is absent.
    """
    frame = pd.read_csv(path, parse_dates=["timestamp"])
    absent = [name for name in STANDARD_COLUMNS if name not in frame.columns]
    if absent:
        raise ValueError(f"Standard CSV missing columns: {absent}")
    return frame.loc[:, STANDARD_COLUMNS].copy()


# Loader registry keyed by the format names returned by ``detect_format``.
# Note: "standard" takes only ``path``; the other loaders also require
# ``animal_id`` and ``model`` (and accept an optional ``group``).
FORMAT_LOADERS = dict(
    ponemah=load_ponemah,
    empatica=load_empatica,
    medtronic=load_medtronic,
    eversense=load_eversense,
    dexcom=load_dexcom,
    standard=load_standard_csv,
)


def detect_format(path: str) -> str:
    """Guess the raw-file format from the first (up to) three lines.

    Signatures are checked in this order:
    - Medtronic ("Sensor Glucose")
    - Eversense ("Glucose_mmol_per_L")
    - Dexcom ("Glucose Value" or "Dexcom")
    - the standard schema (both "animal_id" and "bg_mgdl")
    - Empatica (the file starts with a number)

    Anything else falls back to "ponemah".

    Returns:
        One of the keys of ``FORMAT_LOADERS``.
    """
    from itertools import islice

    # islice stops cleanly at EOF, so files shorter than three lines keep the
    # lines that were read (a next()-based list comprehension raises
    # StopIteration there). errors="replace" keeps stray non-text bytes from
    # aborting detection.
    with open(path, "r", errors="replace") as f:
        head = "".join(islice(f, 3))
    if "Sensor Glucose" in head:
        return "medtronic"
    if "Glucose_mmol_per_L" in head:
        return "eversense"
    if "Glucose Value" in head or "Dexcom" in head:
        return "dexcom"
    if "animal_id" in head and "bg_mgdl" in head:
        return "standard"
    # Empatica files start with a bare unix timestamp / number.
    if head.startswith(("0,", "1,")) or head[:8].replace(".", "").replace(",", "").isdigit():
        return "empatica"
    return "ponemah"


# ------------------------------------------------------------------
# 2. CORE METRICS
# ------------------------------------------------------------------
@dataclass
class GlycemicMetrics:
    """Per-animal glycemic summary produced by ``compute_metrics_for_animal``.

    Glucose-valued fields are in the units of the ``bg_mgdl`` column (mg/dL).
    The TIR/TAR/TBR and severe_* percentages are computed over the non-NaN
    readings only.
    """

    animal_id: str
    model: str                       # model key used for the species-reference thresholds
    group: str                       # experimental group label ("exp" if no group column)
    n_samples: int                   # number of rows, NaN readings included
    duration_hr: float               # last minus first timestamp, in hours
    mean_bg: float
    median_bg: float
    sd_bg: float                     # sample SD (ddof=1)
    cv_pct: float                    # 100 * sd_bg / mean_bg
    tir_pct: float                   # % of readings inside the target range (inclusive bounds)
    tar_pct: float                   # % of readings above the target range
    tbr_pct: float                   # % of readings below the target range
    severe_hyper_pct: float          # % of readings above the severe-hyperglycemia threshold
    severe_hypo_pct: float           # % of readings below the severe-hypoglycemia threshold
    mage: float                      # see calculate_mage (simplified Service 1970)
    modd: float                      # see calculate_modd
    gmi: float                       # human GMI regression (gmi_from_mean); reference only
    dawn_delta: float                # mean BG 04-08 h minus mean BG 00-04 h
    nocturnal_mean: float            # mean BG in the dark (active) phase
    diurnal_mean: float              # mean BG in the light (rest) phase
    nocturnal_diurnal_ratio: float   # nocturnal_mean / diurnal_mean; 1.0 when diurnal_mean <= 0

    def to_dict(self) -> dict:
        """Return all fields as a plain dict (via ``dataclasses.asdict``)."""
        return asdict(self)


def _basic_stats(bg: np.ndarray) -> tuple:
    """Return ``(mean, median, sd, cv_pct)`` of the non-NaN readings in *bg*.

    ``sd`` is the sample standard deviation (ddof=1). It is reported as 0.0
    when fewer than two valid readings exist, because a sample SD is undefined
    there. The count excludes NaNs, so ``[x, nan]`` does not produce a NaN SD.
    ``cv_pct`` is ``100 * sd / mean``, or 0.0 when ``mean`` is not positive.
    With no valid readings, mean and median are NaN and sd/cv are 0.0, and no
    RuntimeWarning is emitted.
    """
    n_valid = int(np.count_nonzero(~np.isnan(bg)))
    if n_valid == 0:
        return float("nan"), float("nan"), 0.0, 0.0
    mean = float(np.nanmean(bg))
    median = float(np.nanmedian(bg))
    sd = float(np.nanstd(bg, ddof=1)) if n_valid > 1 else 0.0
    cv = 100.0 * sd / mean if mean > 0 else 0.0
    return mean, median, sd, cv


def time_in_range(bg: np.ndarray, low: float, high: float) -> tuple:
    """Return ``(TIR, TAR, TBR)`` as percentages of the non-NaN readings.

    In-range uses inclusive bounds ``low <= bg <= high``. Returns zeros when
    no valid reading exists.
    """
    finite = bg[~np.isnan(bg)]
    total = finite.size
    if total == 0:
        return 0.0, 0.0, 0.0
    counts = (
        np.count_nonzero((finite >= low) & (finite <= high)),
        np.count_nonzero(finite > high),
        np.count_nonzero(finite < low),
    )
    tir, tar, tbr = (float(100.0 * count / total) for count in counts)
    return tir, tar, tbr


def severe_fractions(bg: np.ndarray, severe_hypo: float, severe_hyper: float) -> tuple:
    """Return ``(% below severe_hypo, % above severe_hyper)`` over non-NaN readings.

    Both thresholds are strict. Returns zeros when no valid reading exists.
    """
    finite = bg[~np.isnan(bg)]
    if not finite.size:
        return 0.0, 0.0
    total = finite.size
    below = np.count_nonzero(finite < severe_hypo)
    above = np.count_nonzero(finite > severe_hyper)
    return float(100.0 * below / total), float(100.0 * above / total)


def calculate_mage(bg: np.ndarray) -> float:
    """Simplified Mean Amplitude of Glycemic Excursions (after Service 1970).

    NaNs are dropped. The series is cut at every turning point of the
    first-difference sign, where flat steps count as rising. The endpoints are
    included as anchors. The result is the mean absolute rise/fall between
    consecutive anchors, over those legs larger than one sample SD (ddof=1).
    Returns 0.0 for fewer than 5 valid readings, a constant series, or when no
    leg exceeds 1 SD.
    """
    clean = bg[~np.isnan(bg)]
    if clean.size < 5:
        return 0.0
    threshold = float(np.std(clean, ddof=1))
    if threshold == 0:
        return 0.0
    direction = np.sign(np.diff(clean))
    direction = np.where(direction == 0, 1, direction)
    pivots = np.flatnonzero(np.diff(direction)) + 1
    anchors = np.r_[0, pivots, clean.size - 1]
    legs = np.abs(np.diff(clean[anchors]))
    large = legs[legs > threshold]
    return float(large.mean()) if large.size else 0.0


def calculate_modd(df: pd.DataFrame) -> float:
    """Mean Of Daily Differences (MODD), in the units of ``bg_mgdl``.

    Every valid reading at time ``t`` is paired with the valid reading closest
    to ``t - 24 h``. A pair is accepted only within half of the median sampling
    interval. The result is the mean absolute difference over all pairs.

    Pairing by time rather than by a fixed row lag keeps the metric correct
    when the recording has gaps, duplicate timestamps or an unrepresentative
    first interval. On gap-free regular data (e.g. 5 min -> 288-row lag) it
    yields the same pairs as a positional lag.

    Args:
        df: Readings of one animal with ``timestamp`` (datetime) and
            ``bg_mgdl`` columns. Row order does not matter.

    Returns:
        MODD. Returns 0.0 with fewer than two readings, no positive sampling
        interval, or no pair of readings about 24 h apart.
    """
    if len(df) < 2:
        return 0.0
    work = (df[["timestamp", "bg_mgdl"]]
            .dropna()
            .sort_values("timestamp")
            .reset_index(drop=True))
    steps = work["timestamp"].diff().dt.total_seconds()
    steps = steps[steps > 0]
    if steps.empty:
        return 0.0
    tolerance = pd.Timedelta(seconds=float(steps.median()) / 2.0)
    # Shift every reading forward by one day; an as-of join on the original
    # times then finds, for each t, the reading taken ~24 h earlier.
    day_later = work.assign(timestamp=work["timestamp"] + pd.Timedelta(days=1))
    paired = pd.merge_asof(work, day_later, on="timestamp", direction="nearest",
                           tolerance=tolerance, suffixes=("", "_prev"))
    diffs = (paired["bg_mgdl"] - paired["bg_mgdl_prev"]).dropna().to_numpy()
    return float(np.mean(np.abs(diffs))) if diffs.size else 0.0


def gmi_from_mean(mean_bg: float) -> float:
    """Glucose Management Indicator from mean BG in mg/dL (Bergenstal 2018).

    The regression was derived in humans; for animals treat it as a reference
    value only.
    """
    slope, intercept = 0.02392, 3.31
    return slope * mean_bg + intercept


def dawn_phenomenon(df: pd.DataFrame, window_start_hr: int = 4, window_end_hr: int = 8) -> float:
    """Dawn delta in mg/dL: mean BG in [start, end) hours minus mean BG in [0, start) hours.

    Rodents feed at night, but a dawn-like rise is reported in some models.
    Returns 0.0 for empty input or when either window has no rows.
    """
    if df.empty:
        return 0.0
    hour_of_day = df["timestamp"].dt.hour
    glucose = df["bg_mgdl"]
    pre_dawn = glucose[(hour_of_day >= 0) & (hour_of_day < window_start_hr)]
    dawn_window = glucose[(hour_of_day >= window_start_hr) & (hour_of_day < window_end_hr)]
    if pre_dawn.empty or dawn_window.empty:
        return 0.0
    return float(dawn_window.mean() - pre_dawn.mean())


def nocturnal_diurnal_split(df: pd.DataFrame, lights_on: int = 7, lights_off: int = 19) -> tuple:
    """Return ``(dark-phase mean, light-phase mean, dark/light ratio)`` of BG.

    The light (rest) phase is [lights_on, lights_off) in clock hours and wraps
    past midnight when lights_on >= lights_off. Everything else is the dark
    (active) phase. An empty phase has mean 0.0. The ratio is 1.0 when the
    light-phase mean is not positive. Empty input gives ``(0.0, 0.0, 1.0)``.
    """
    if df.empty:
        return 0.0, 0.0, 1.0
    hour_of_day = df["timestamp"].dt.hour
    if lights_on < lights_off:
        is_light = hour_of_day.ge(lights_on) & hour_of_day.lt(lights_off)
    else:
        is_light = hour_of_day.ge(lights_on) | hour_of_day.lt(lights_off)
    glucose = df["bg_mgdl"]
    dark_bg = glucose[~is_light]
    light_bg = glucose[is_light]
    nocturnal = 0.0 if dark_bg.empty else float(dark_bg.mean())
    diurnal = 0.0 if light_bg.empty else float(light_bg.mean())
    ratio = nocturnal / diurnal if diurnal > 0 else 1.0
    return nocturnal, diurnal, ratio


def compute_metrics_for_animal(df_animal: pd.DataFrame,
                                model_key: Optional[str] = None,
                                tir_range: Optional[tuple] = None,
                                lights_on: int = 7,
                                lights_off: int = 19) -> GlycemicMetrics:
    """Compute every glycemic metric for a single animal.

    Args:
        df_animal: Standard-schema rows of one animal, in any order. They are
            sorted by ``timestamp`` here.
        model_key: Model used for the ``SPECIES_REFERENCE`` lookup. Defaults
            to the ``model`` value of the earliest row.
        tir_range: ``(low, high)`` target range. Defaults to the reference
            ``tir_target``, or ``(80, 180)`` when the model has no reference
            entry.
        lights_on: First hour of the light phase.
        lights_off: Hour at which the light phase ends.

    Raises:
        ValueError: if ``df_animal`` is empty.
    """
    if df_animal.empty:
        raise ValueError("Empty animal dataframe")

    ordered = df_animal.sort_values("timestamp").reset_index(drop=True)
    glucose = ordered["bg_mgdl"].to_numpy(dtype=float)

    model = model_key if model_key is not None else ordered["model"].iloc[0]
    reference = SPECIES_REFERENCE.get(model, None)
    if reference:
        target = tir_range if tir_range is not None else reference["tir_target"]
        hypo_cut = reference["severe_hypo"]
        hyper_cut = reference["severe_hyper"]
    else:
        # No species reference: generic fallback thresholds.
        target = tir_range if tir_range is not None else (80, 180)
        hypo_cut, hyper_cut = 50, 300

    mean_bg, median_bg, sd_bg, cv_pct = _basic_stats(glucose)
    tir, tar, tbr = time_in_range(glucose, target[0], target[1])
    severe_hypo_pct, severe_hyper_pct = severe_fractions(glucose, hypo_cut, hyper_cut)
    mage = calculate_mage(glucose)
    modd = calculate_modd(ordered)
    gmi = gmi_from_mean(mean_bg)
    dawn = dawn_phenomenon(ordered)
    nocturnal, diurnal, nd_ratio = nocturnal_diurnal_split(ordered, lights_on, lights_off)

    first_ts = ordered["timestamp"].iloc[0]
    last_ts = ordered["timestamp"].iloc[-1]
    span_hr = (last_ts - first_ts).total_seconds() / 3600.0

    return GlycemicMetrics(
        animal_id=str(ordered["animal_id"].iloc[0]),
        model=str(model),
        group=str(ordered["group"].iloc[0]) if "group" in ordered.columns else "exp",
        n_samples=int(glucose.size),
        duration_hr=float(span_hr),
        mean_bg=mean_bg,
        median_bg=median_bg,
        sd_bg=sd_bg,
        cv_pct=cv_pct,
        tir_pct=tir,
        tar_pct=tar,
        tbr_pct=tbr,
        severe_hyper_pct=severe_hyper_pct,
        severe_hypo_pct=severe_hypo_pct,
        mage=mage,
        modd=modd,
        gmi=gmi,
        dawn_delta=dawn,
        nocturnal_mean=nocturnal,
        diurnal_mean=diurnal,
        nocturnal_diurnal_ratio=nd_ratio,
    )


def compute_metrics_cohort(df: pd.DataFrame, **kwargs) -> pd.DataFrame:
    """Compute per-animal metrics for a whole cohort, one row per animal.

    Rows are grouped by ``animal_id`` and passed to
    ``compute_metrics_for_animal`` with ``kwargs`` forwarded. Animals for which
    that call raises ``ValueError`` are skipped.
    """
    records = []
    for _animal_id, animal_rows in df.groupby("animal_id"):
        try:
            metrics = compute_metrics_for_animal(animal_rows, **kwargs)
        except ValueError:
            # Best effort: an animal that cannot be summarised is left out.
            continue
        records.append(metrics.to_dict())
    return pd.DataFrame(records)


# ------------------------------------------------------------------
# 3. CHALLENGE (GTT/ITT/MTT) ANALYSIS
# ------------------------------------------------------------------
@dataclass
class ChallengeResult:
    """Summary of one tolerance-test curve (GTT/ITT/MTT) from ``analyze_challenge``."""

    animal_id: str
    model: str
    challenge_type: str          # free-text label, e.g. "GTT"
    baseline_bg: float           # BG of the earliest sample
    peak_bg: float               # maximum BG
    t_peak_min: float            # time of the (first) maximum, minutes
    t_late_min: float            # last sampling time, minutes
    auc_total: float             # trapezoidal AUC of the raw curve (BG units x min)
    auc_incremental: float       # area above baseline only (iAUC)
    recovery_slope: float        # (last BG - peak) / (t_last - t_peak); negative = recovering
    returned_to_baseline: bool   # |last BG - baseline| <= baseline * tolerance %


def _incremental_auc(excess: np.ndarray, t: np.ndarray) -> float:
    """Area of the part of a sampled curve that lies above zero (trapezoid rule).

    A segment that crosses zero is split at the linearly interpolated crossing
    point, and only its positive triangle is counted. This is the usual
    geometric iAUC convention. Clipping the samples at zero before integrating
    overestimates the area whenever the curve dips below baseline between
    samples. Returns NaN if any value is NaN.
    """
    if np.isnan(excess).any():
        return float("nan")
    y0, y1 = excess[:-1], excess[1:]
    dt = np.diff(t)
    segment = np.zeros_like(dt, dtype=float)
    above = (y0 >= 0) & (y1 >= 0)
    segment[above] = dt[above] * (y0[above] + y1[above]) / 2.0
    crossing = (y0 >= 0) != (y1 >= 0)
    hi = np.maximum(y0[crossing], y1[crossing])   # >= 0
    lo = np.minimum(y0[crossing], y1[crossing])   # < 0, so hi - lo > 0
    segment[crossing] = dt[crossing] * hi * hi / (2.0 * (hi - lo))
    return float(segment.sum())


def analyze_challenge(bg_series: Sequence[float],
                      times_min: Sequence[float],
                      animal_id: str = "",
                      model: str = "",
                      challenge_type: str = "GTT",
                      baseline_tolerance_pct: float = 10.0) -> ChallengeResult:
    """Analyse a challenge (GTT/ITT/MTT) curve.

    Args:
        bg_series: BG at each sampling point, e.g. at 0, 15, 30, ..., 120 min.
        times_min: Sampling times in minutes. They need not be sorted; the
            pairs are ordered by time before analysis.
        animal_id: Label copied into the result.
        model: Label copied into the result.
        challenge_type: Label copied into the result.
        baseline_tolerance_pct: The last sample counts as "returned to
            baseline" when it is within this percentage of the baseline.

    Returns:
        ChallengeResult. The baseline is the BG at the earliest time. AUCs are
        trapezoidal in BG units x min. iAUC counts only the area above baseline.

    Raises:
        ValueError: if the sequences differ in length or have fewer than two points.
    """
    bg = np.asarray(bg_series, dtype=float)
    t = np.asarray(times_min, dtype=float)
    if bg.size < 2 or bg.size != t.size:
        raise ValueError("bg_series and times_min must be same length >=2")

    # Chronological order is required for baseline, AUC sign and slope.
    order = np.argsort(t, kind="stable")
    bg, t = bg[order], t[order]

    baseline = float(bg[0])
    peak_idx = int(np.argmax(bg))
    peak_bg = float(bg[peak_idx])
    t_peak = float(t[peak_idx])

    # Trapezoidal AUC written out explicitly (np.trapz is deprecated in NumPy 2).
    auc_total = float(np.sum(np.diff(t) * (bg[1:] + bg[:-1]) / 2.0))
    auc_incr = _incremental_auc(bg - baseline, t)

    # Late time point: the last sample (e.g. 120 min).
    t_late = float(t[-1])

    # Recovery slope from peak to end, BG units per min (negative = recovering).
    if t[-1] > t_peak:
        recovery_slope = (bg[-1] - peak_bg) / (t[-1] - t_peak)
    else:
        recovery_slope = 0.0

    returned = abs(bg[-1] - baseline) <= baseline * baseline_tolerance_pct / 100.0

    return ChallengeResult(
        animal_id=animal_id,
        model=model,
        challenge_type=challenge_type,
        baseline_bg=baseline,
        peak_bg=peak_bg,
        t_peak_min=t_peak,
        t_late_min=t_late,
        auc_total=auc_total,
        auc_incremental=auc_incr,
        recovery_slope=float(recovery_slope),
        returned_to_baseline=bool(returned),
    )


def cohort_challenge_summary(challenges: list) -> pd.DataFrame:
    """Flatten a list of ``ChallengeResult`` dataclasses into a DataFrame, one row each."""
    records = list(map(asdict, challenges))
    return pd.DataFrame(records)


# ------------------------------------------------------------------
# 4. GROUP COMPARISON
# ------------------------------------------------------------------
def compare_groups(metrics_df: pd.DataFrame, metric: str,
                   group_col: str = "group") -> dict:
    """Compare one per-animal metric across groups (scipy only, no statsmodels).

    Runs a one-way ANOVA over every group that has at least one non-NaN value.
    It then runs pairwise Welch t-tests (unequal variances) with Bonferroni
    correction.

    Args:
        metrics_df: One row per animal, e.g. the output of ``compute_metrics_cohort``.
        metric: Column to compare.
        group_col: Column holding the group label.

    Returns:
        dict with these keys:
        - ``method``
        - ``f_stat`` and ``p_value`` (``None`` when not computable)
        - ``groups``: every label found, including groups with no usable values
        - ``pairwise``: one entry per pair of usable groups, with
          ``p_bonferroni`` left NaN when the raw p-value is NaN

        With fewer than two usable groups the result is
        ``method="insufficient_groups"`` with ``p_value=nan`` and an empty
        ``pairwise`` list.

    Raises:
        KeyError: if ``metric`` is not a column of ``metrics_df``.
    """
    if metric not in metrics_df.columns:
        raise KeyError(f"Metric {metric} not in metrics_df")
    groups = metrics_df[group_col].unique().tolist()

    # Keep each label paired with its own data. Filtering the arrays apart
    # from the labels would shift later groups onto the wrong names.
    samples = []
    for label in groups:
        values = metrics_df.loc[metrics_df[group_col] == label, metric].dropna().to_numpy()
        if values.size > 0:
            samples.append((label, values))
    if len(samples) < 2:
        return {"method": "insufficient_groups", "p_value": float("nan"),
                "groups": groups, "pairwise": []}

    # one-way ANOVA
    try:
        f_stat, p_val = stats.f_oneway(*(values for _, values in samples))
        method = "one_way_anova"
    except Exception:  # noqa: BLE001 - scipy raises various errors on degenerate input
        f_stat, p_val = float("nan"), float("nan")
        method = "anova_failed"

    # pairwise Welch t-tests (Bonferroni)
    pairwise = []
    for i, (label_a, values_a) in enumerate(samples):
        for label_b, values_b in samples[i + 1:]:
            t_stat, t_p = stats.ttest_ind(values_a, values_b, equal_var=False)
            pairwise.append({
                "group_a": label_a,
                "group_b": label_b,
                "mean_a": float(np.mean(values_a)),
                "mean_b": float(np.mean(values_b)),
                "t_stat": float(t_stat),
                "p_value": float(t_p),
            })
    n_compare = max(1, len(pairwise))
    for row in pairwise:
        p_raw = row["p_value"]
        # min(1.0, nan) would silently report 1.0; keep an uncomputable p visible.
        row["p_bonferroni"] = p_raw if math.isnan(p_raw) else min(1.0, p_raw * n_compare)

    return {
        "method": method,
        "f_stat": float(f_stat) if not math.isnan(f_stat) else None,
        "p_value": float(p_val) if not math.isnan(p_val) else None,
        "groups": groups,
        "pairwise": pairwise,
    }


# ------------------------------------------------------------------
# 5. COHORT HEATMAP DATA (24h × animal)
# ------------------------------------------------------------------
def cohort_diurnal_matrix(df: pd.DataFrame, bin_minutes: int = 30) -> pd.DataFrame:
    """Build an animal x time-of-day matrix of mean BG.

    Each reading falls into the ``bin_minutes``-wide bin of its clock time.
    Bins are labelled by their first minute after midnight, and dates are
    ignored, so all days are pooled. Empty input gives an empty DataFrame.
    """
    if df.empty:
        return pd.DataFrame()
    clock = df["timestamp"].dt
    minute_of_day = clock.hour * 60 + clock.minute
    binned = df.assign(minute_of_day=minute_of_day,
                       bin=(minute_of_day // bin_minutes) * bin_minutes)
    return binned.pivot_table(index="animal_id", columns="bin",
                              values="bg_mgdl", aggfunc="mean")


# ------------------------------------------------------------------
# 6. DRUG ACTION TIME-WINDOW
# ------------------------------------------------------------------
def drug_action_window(df: pd.DataFrame, dose_time: pd.Timestamp,
                       window_hr: float = 6.0) -> dict:
    """Summarise BG response within ``window_hr`` hours after a dose.

    Args:
        df: Readings with ``timestamp`` and ``bg_mgdl`` columns. Presumably a
            single animal; rows are sorted here, so input order does not matter.
        dose_time: Dose time. Anything ``pd.Timestamp`` accepts is allowed.
        window_hr: Length of the post-dose window, in hours (inclusive bounds).

    Returns:
        dict with keys baseline, nadir, peak, t_nadir_min, t_peak_min,
        delta_nadir, delta_peak:
        - ``baseline``: mean of the last three readings before the dose.
        - ``t_*_min``: minutes from the dose to the first occurrence of the
          nadir/peak.
        - ``delta_*``: difference from baseline.

        All values are ``None`` when the window has no valid reading. The
        baseline and delta values are ``None`` when there is no pre-dose
        reading.
    """
    dose_time = pd.Timestamp(dose_time)
    end = dose_time + pd.Timedelta(hours=window_hr)
    result = dict.fromkeys(("baseline", "nadir", "peak", "t_nadir_min",
                            "t_peak_min", "delta_nadir", "delta_peak"))

    # Chronological order is required for "last 3 pre-dose readings" and for
    # "first time the nadir/peak is reached".
    ordered = df.sort_values("timestamp", kind="stable")
    stamps = ordered["timestamp"]
    win = ordered[(stamps >= dose_time) & (stamps <= end)].dropna(subset=["bg_mgdl"])
    if win.empty:
        return result

    pre = ordered.loc[stamps < dose_time, "bg_mgdl"].tail(3).mean()
    has_baseline = not np.isnan(pre)

    # Positional lookups avoid problems with duplicate index labels.
    bg = win["bg_mgdl"].to_numpy(dtype=float)
    win_ts = win["timestamp"]
    i_nadir = int(np.argmin(bg))
    i_peak = int(np.argmax(bg))
    nadir = float(bg[i_nadir])
    peak = float(bg[i_peak])
    result.update(
        baseline=float(pre) if has_baseline else None,
        nadir=nadir,
        peak=peak,
        t_nadir_min=float((win_ts.iloc[i_nadir] - dose_time).total_seconds() / 60.0),
        t_peak_min=float((win_ts.iloc[i_peak] - dose_time).total_seconds() / 60.0),
        delta_nadir=(nadir - float(pre)) if has_baseline else None,
        delta_peak=(peak - float(pre)) if has_baseline else None,
    )
    return result
