"""Kinetic-parameter computation for perifusion traces.

Conventions (re: the standard 60-min protocol used in this MVP):
    0  - 10  min : low glucose 2.8 mM   (basal / pre-stimulus)
    10 - 40  min : high glucose 16.7 mM (1st phase 10-20, 2nd phase 20-40)
    40 - 50  min : KCl 30 mM
    50 - 60  min : low glucose 2.8 mM   (recovery)

All inputs are assumed to be time-aligned and lag-corrected, and either
baseline-subtracted or raw (which one must be made explicit by the calling
context). 1st phase = 10-20 min (first 10 min of stimulus), 2nd phase = 20-40 min.
AUC windows: 0-10 (basal), 10-30, 30-60.
"""
from __future__ import annotations

import math
from typing import Dict, List, Optional, Tuple

import numpy as np


# ---------------------------------------------------------------------------
# helpers
# ---------------------------------------------------------------------------
def _slice(t: List[float], v: List[float], t0: float, t1: float):
    """Return the samples whose time lies in the closed window [t0, t1].

    Times and readings are paired element-wise (pairing stops at the shorter
    sequence). A pair is kept only when its time is inside the window and its
    reading is not NaN. The result is a ``(times, readings)`` tuple of float
    arrays in the original sample order; both are empty when nothing matches.
    """
    kept = [
        (stamp, reading)
        for stamp, reading in zip(t, v)
        # the window test comes first so out-of-window readings are never inspected
        if t0 <= stamp <= t1 and not math.isnan(reading)
    ]
    times = np.array([stamp for stamp, _ in kept], dtype=float)
    readings = np.array([reading for _, reading in kept], dtype=float)
    return times, readings


def _trapz(t: np.ndarray, v: np.ndarray) -> float:
    """Trapezoidal area under ``v`` sampled at times ``t``.

    Returns 0.0 when fewer than two samples are available, because no
    interval exists to integrate over.
    """
    if t.size < 2:
        return 0.0
    # NumPy 2.0 deprecated ``np.trapz`` in favour of ``np.trapezoid``; prefer
    # the new name when it exists and fall back to the old one on NumPy 1.x.
    integrate = getattr(np, "trapezoid", None)
    if integrate is None:
        integrate = np.trapz
    return float(integrate(v, t))


def _linregress_slope(t: np.ndarray, v: np.ndarray) -> float:
    """Slope of the least-squares line ``v ~ m * t + c``.

    Returns 0.0 when fewer than two samples are given.
    """
    if t.size >= 2:
        # design matrix: one column of times, one column of ones (intercept)
        design = np.column_stack((t, np.ones_like(t)))
        slope, _intercept = np.linalg.lstsq(design, v, rcond=None)[0]
        return float(slope)
    return 0.0


# ---------------------------------------------------------------------------
# kinetic parameters per channel
# ---------------------------------------------------------------------------
def kinetic_params(time_min: List[float], values: List[float]) -> Dict[str, float]:
    """Compute the kinetic parameters of one perifusion channel.

    Every window is cut with ``_slice``, which is inclusive at both bounds and
    drops NaN readings, so a sample lying exactly on a boundary (e.g. t = 10)
    is counted in both neighbouring windows.

    Args:
        time_min: Sample times in minutes.
        values: Readings, paired element-wise with ``time_min``.

    Returns:
        A dict holding, in insertion order:

        * ``basal_mean`` - mean over 0-10 (0.0 when the window is empty).
        * ``phase1_peak``, ``phase1_time_to_peak_min``, ``phase1_slope`` -
          maximum over 10-20, the time stamp at which it occurs (absolute
          time, not minutes since stimulus onset) and the least-squares slope.
        * ``phase2_plateau``, ``phase2_slope``, ``phase2_steady_state`` -
          median over 20-40, least-squares slope, and the mean of the last
          (up to) five retained samples of the window.
        * ``kcl_peak``, ``kcl_time_to_peak_min`` - maximum over 40-50 and its
          absolute time stamp.
        * ``auc_0_10``, ``auc_10_30``, ``auc_30_60`` - trapezoidal areas
          (0.0 when a window holds fewer than two samples).
        * ``fold_change_phase1``, ``fold_change_phase2``, ``fold_change_kcl`` -
          phase-1 peak, phase-2 plateau and KCl peak divided by the basal mean.
        * ``gsis_ratio_16p7_2p8`` - mean over 10-40 divided by the basal mean.
        * ``ksis_ratio_kcl_glucose`` - KCl peak divided by the phase-1 peak.

        Peak, slope, plateau and ratio entries are NaN whenever the data they
        need is missing or their denominator is not strictly positive.
    """
    nan = float("nan")
    out: Dict[str, float] = {}

    # basal (0-10); the slice is kept so auc_0_10 can reuse it below
    t_base, v_base = _slice(time_min, values, 0, 10)
    basal = float(v_base.mean()) if v_base.size else 0.0
    out["basal_mean"] = basal

    # 1st phase (10-20)
    t1, v1 = _slice(time_min, values, 10, 20)
    if v1.size:
        peak_idx = int(np.argmax(v1))
        out["phase1_peak"] = float(v1[peak_idx])
        out["phase1_time_to_peak_min"] = float(t1[peak_idx])
        out["phase1_slope"] = _linregress_slope(t1, v1)
    else:
        out.update(phase1_peak=nan, phase1_time_to_peak_min=nan, phase1_slope=nan)

    # 2nd phase (20-40)
    t2, v2 = _slice(time_min, values, 20, 40)
    if v2.size:
        out["phase2_plateau"] = float(np.median(v2))
        out["phase2_slope"] = _linregress_slope(t2, v2)
        out["phase2_steady_state"] = float(v2[-min(5, v2.size) :].mean())
    else:
        out.update(phase2_plateau=nan, phase2_slope=nan, phase2_steady_state=nan)

    # KCl phase (40-50)
    tk, vk = _slice(time_min, values, 40, 50)
    if vk.size:
        kcl_idx = int(np.argmax(vk))
        out["kcl_peak"] = float(vk[kcl_idx])
        out["kcl_time_to_peak_min"] = float(tk[kcl_idx])
    else:
        out.update(kcl_peak=nan, kcl_time_to_peak_min=nan)

    # AUC trapezoidal
    out["auc_0_10"] = _trapz(t_base, v_base)
    out["auc_10_30"] = _trapz(*_slice(time_min, values, 10, 30))
    out["auc_30_60"] = _trapz(*_slice(time_min, values, 30, 60))

    def _vs_basal(level: float) -> float:
        # A NaN numerator propagates through the division, so only the
        # denominator needs guarding.
        return level / basal if basal > 0 else nan

    # Fold-change vs basal
    out["fold_change_phase1"] = _vs_basal(out["phase1_peak"])
    out["fold_change_phase2"] = _vs_basal(out["phase2_plateau"])
    out["fold_change_kcl"] = _vs_basal(out["kcl_peak"])

    # GSIS ratio = high glucose mean / basal mean.
    # With no usable high-glucose samples the ratio is undefined; report NaN
    # (as every other missing-data metric does) rather than a 0.0 that would
    # read as a complete loss of glucose response.
    _, v_high = _slice(time_min, values, 10, 40)
    if basal > 0 and v_high.size:
        out["gsis_ratio_16p7_2p8"] = float(v_high.mean()) / basal
    else:
        out["gsis_ratio_16p7_2p8"] = nan

    # KSIS ratio = KCl peak / glucose (1st-phase) peak.
    # ``> 0`` is False for NaN, and a NaN KCl peak propagates on its own.
    glucose_peak = out["phase1_peak"]
    out["ksis_ratio_kcl_glucose"] = (
        out["kcl_peak"] / glucose_peak if glucose_peak > 0 else nan
    )

    return out


# ---------------------------------------------------------------------------
# Cross-condition derived metrics
# ---------------------------------------------------------------------------
def glp1_potentiation_index(
    vehicle_params: Dict[str, float], glp1_params: Dict[str, float]
) -> float:
    """Return AUC10-30(GLP-1) / AUC10-30(vehicle).

    A missing AUC counts as 0.0. The result is NaN when the vehicle AUC is
    zero or negative.
    """
    auc_key = "auc_10_30"
    vehicle_auc = vehicle_params.get(auc_key, 0.0)
    # Written as ``not (x <= 0)`` so a NaN vehicle AUC still reaches the
    # division and propagates as NaN.
    if not (vehicle_auc <= 0):
        return glp1_params.get(auc_key, 0.0) / vehicle_auc
    return float("nan")


def lipotoxicity_delta(
    vehicle_params: Dict[str, float], palmitate_params: Dict[str, float]
) -> Dict[str, float]:
    """Compare a palmitate-treated channel against its vehicle control.

    Returns a dict with:
        delta_gsis_ratio: palmitate GSIS ratio minus vehicle GSIS ratio
            (NaN when either one is missing).
        ratio_phase1_peak: palmitate 1st-phase peak divided by the vehicle
            1st-phase peak; NaN unless the vehicle peak is strictly positive.
            A missing palmitate peak counts as 0.0.
    """
    nan = float("nan")
    gsis_key = "gsis_ratio_16p7_2p8"
    peak_key = "phase1_peak"

    gsis_shift = palmitate_params.get(gsis_key, nan) - vehicle_params.get(gsis_key, nan)

    vehicle_peak = vehicle_params.get(peak_key, 0.0)
    if vehicle_peak > 0:
        peak_ratio = palmitate_params.get(peak_key, 0.0) / vehicle_peak
    else:
        peak_ratio = nan

    return {"delta_gsis_ratio": gsis_shift, "ratio_phase1_peak": peak_ratio}


def proinsulin_insulin_ratio(
    proinsulin_params: Dict[str, float], insulin_params: Dict[str, float]
) -> Dict[str, float]:
    """Return proinsulin / insulin ratios for a fixed set of kinetic metrics.

    Each result key is ``proinsulin_insulin_<metric>``. A metric missing from
    either dict counts as 0.0, and the ratio is NaN unless the insulin value is
    strictly positive.
    """
    metrics = ("basal_mean", "phase1_peak", "phase2_plateau", "auc_10_30")

    def _ratio(metric: str) -> float:
        insulin = insulin_params.get(metric, 0.0)
        if insulin > 0:
            return proinsulin_params.get(metric, 0.0) / insulin
        return float("nan")

    return {f"proinsulin_insulin_{metric}": _ratio(metric) for metric in metrics}
