"""Signal-correction utilities for perifusion traces.

Implements:
  - dead-volume transit-time lag correction (volume / flow rate)
  - degradation drift correction (reverses linear decay over storage hours + elapsed run time)
  - baseline subtraction (low-glucose pre-stimulus mean +/- 2SD)
  - inter-channel normalization (per-IEQ / per-cell / per-protein)
  - multi-batch normalization with KCl-peak CV gating
"""
from __future__ import annotations

import math
from typing import Dict, List, Tuple

import numpy as np


# ---------------------------------------------------------------------------
# 1. dead-volume lag
# ---------------------------------------------------------------------------
def transit_time_min(dead_volume_ul: float, flow_rate_ml_min: float) -> float:
    """Return the dead-volume transit time in minutes.

    The flow rate is converted from mL/min to uL/min (x1000) and the dead
    volume (uL) is divided by it.

    Args:
        dead_volume_ul: Dead volume in microlitres.
        flow_rate_ml_min: Flow rate in millilitres per minute.

    Returns:
        Transit time in minutes, or 0.0 when the flow rate is zero or
        negative (i.e. no lag is applied).
    """
    if flow_rate_ml_min <= 0:
        # No usable flow: report "no lag" instead of dividing by zero.
        return 0.0
    flow_rate_ul_min = flow_rate_ml_min * 1000.0
    return dead_volume_ul / flow_rate_ul_min


def shift_time_axis(time_min: List[float], lag_min: float) -> List[float]:
    """Return a new time axis with ``lag_min`` subtracted from every entry.

    Shifting the observed times back by the lag lines the signal up with the
    stimulus axis.  The input list is left untouched.
    """
    shifted: List[float] = []
    for observed in time_min:
        shifted.append(observed - lag_min)
    return shifted


# ---------------------------------------------------------------------------
# 2. degradation drift
# ---------------------------------------------------------------------------
def degradation_correct(
    values: List[float], time_min: List[float], storage_hours: float, decay_per_hour: float = 0.005
) -> List[float]:
    """Reverse linear decay of a trace.

    Each sample is corrected as::

        corrected = raw / (1 - decay_per_hour * (storage_hours + t / 60))

    where ``t`` is the sample time in minutes.

    Args:
        values: Raw signal samples.
        time_min: Sample times in minutes, paired element-wise with ``values``
            (pairing stops at the shorter of the two lists).
        storage_hours: Hours of storage before the run.
        decay_per_hour: Fractional signal loss per hour.

    Returns:
        A new list of corrected samples.  Conservative: when ``storage_hours``
        is zero or negative the input is returned unchanged (as a copy).
        The remaining-signal fraction is clamped to a minimum of 1e-6 so a
        fully decayed sample never divides by zero or flips sign.  If the
        fraction cannot be computed (NaN time point, storage time or decay
        rate) the corrected sample is NaN.
    """
    if storage_hours <= 0:
        return list(values)
    corrected: List[float] = []
    for raw, t in zip(values, time_min):
        remaining = 1.0 - decay_per_hour * (storage_hours + t / 60.0)
        if math.isnan(remaining):
            # max(1e-6, nan) evaluates to 1e-6, which would silently scale the
            # sample by a factor of 1e6; propagate the missing value instead.
            corrected.append(float("nan"))
            continue
        corrected.append(raw / max(1e-6, remaining))
    return corrected


# ---------------------------------------------------------------------------
# 3. baseline subtraction (low-glucose pre-stimulus)
# ---------------------------------------------------------------------------
def baseline_window(
    time_min: List[float], values: List[float], pre_t_end_min: float = 10.0
) -> Tuple[float, float]:
    """Summarise the pre-stimulus window of a trace.

    Args:
        time_min: Sample times in minutes, paired element-wise with ``values``.
        values: Signal samples; NaN samples are ignored.
        pre_t_end_min: Exclusive upper bound of the window; only samples with
            ``t < pre_t_end_min`` are used.

    Returns:
        ``(mean, sd)`` of the samples in the window, where ``sd`` is the
        population standard deviation (``ddof=0``).  ``(0.0, 0.0)`` when no
        sample qualifies.
    """
    window: List[float] = []
    for t, v in zip(time_min, values):
        if t < pre_t_end_min and not math.isnan(v):
            window.append(v)
    if not window:
        return 0.0, 0.0
    samples = np.asarray(window, dtype=float)
    mean = float(np.mean(samples))
    sd = float(np.std(samples, ddof=0))
    return mean, sd


def subtract_baseline(values: List[float], baseline: float) -> List[float]:
    """Return a new list with ``baseline`` subtracted from every sample."""
    shifted: List[float] = []
    for sample in values:
        shifted.append(sample - baseline)
    return shifted


# ---------------------------------------------------------------------------
# 4. inter-channel normalization
# ---------------------------------------------------------------------------
def per_ieq(values: List[float], ieq: float) -> List[float]:
    """Normalize a trace by its IEQ count.

    Returns a new list with every sample divided by ``ieq``.  When ``ieq`` is
    ``None``, zero or negative the samples are returned unchanged (as a copy).
    """
    unusable = ieq is None or ieq <= 0
    if unusable:
        return list(values)
    normalized: List[float] = []
    for sample in values:
        normalized.append(sample / ieq)
    return normalized


def per_cell(values: List[float], cells: float) -> List[float]:
    """Normalize a trace to signal per 1000 cells.

    Returns a new list with every sample divided by ``cells / 1000``.  When
    ``cells`` is ``None``, zero or negative the samples are returned unchanged
    (as a copy).
    """
    if cells is None or cells <= 0:
        return list(values)
    # Express the cell count in thousands so results read as "per 1000 cells".
    thousands_of_cells = cells / 1000.0
    return [sample / thousands_of_cells for sample in values]


def per_protein(values: List[float], protein_ug: float) -> List[float]:
    """Normalize a trace by protein content.

    Returns a new list with every sample divided by ``protein_ug``.  When
    ``protein_ug`` is ``None``, zero or negative the samples are returned
    unchanged (as a copy).
    """
    skip_normalization = protein_ug is None or protein_ug <= 0
    return (
        list(values)
        if skip_normalization
        else [sample / protein_ug for sample in values]
    )


# ---------------------------------------------------------------------------
# 5. multi-batch normalization via KCl peak
# ---------------------------------------------------------------------------
def kcl_peak(time_min: List[float], values: List[float], kcl_t_start: float, kcl_t_end: float) -> float:
    """Return the maximum signal inside the KCl window.

    Args:
        time_min: Sample times in minutes, paired element-wise with ``values``.
        values: Signal samples; NaN samples are ignored.
        kcl_t_start: Inclusive start of the window.
        kcl_t_end: Inclusive end of the window.

    Returns:
        The largest sample with ``kcl_t_start <= t <= kcl_t_end``, or NaN when
        no sample qualifies.
    """
    in_window: List[float] = []
    for t, v in zip(time_min, values):
        if kcl_t_start <= t <= kcl_t_end and not math.isnan(v):
            in_window.append(v)
    if not in_window:
        return float("nan")
    return float(np.max(np.asarray(in_window, dtype=float)))


def kcl_cv_pct(peaks: List[float]) -> float:
    """Return the coefficient of variation of the peaks, in percent.

    NaN peaks are dropped first.  The CV is ``100 * sd / mean`` with the
    sample standard deviation (``ddof=1``).  The mean is used with its sign,
    so a negative mean produces a negative CV.

    Returns:
        The CV in percent, or NaN when fewer than two peaks remain or their
        mean is exactly zero.
    """
    finite_peaks = [p for p in peaks if not math.isnan(p)]
    arr = np.array(finite_peaks, dtype=float)
    if arr.size < 2:
        return float("nan")
    mean = arr.mean()
    if mean == 0:
        return float("nan")
    sd = arr.std(ddof=1)
    return float(100.0 * sd / mean)


def batch_normalize_to_kcl(
    channels: Dict[str, List[float]],
    time_min: List[float],
    kcl_t_start: float,
    kcl_t_end: float,
    target_peak: float = 40.0,
) -> Tuple[Dict[str, List[float]], Dict[str, float], float]:
    """Rescale each channel so that its KCl peak equals ``target_peak``.

    A channel whose KCl peak is NaN, zero or negative cannot be rescaled and
    is passed through with a scale factor of 1.0.

    Args:
        channels: Mapping of channel name to signal samples.
        time_min: Sample times in minutes shared by all channels.
        kcl_t_start: Start of the KCl window passed to ``kcl_peak``.
        kcl_t_end: End of the KCl window passed to ``kcl_peak``.
        target_peak: Peak value every rescaled channel should reach.

    Returns:
        ``(scaled_channels, scale_factors, observed_cv_pct)`` where the CV is
        computed by ``kcl_cv_pct`` over the unscaled peaks of all channels.
    """
    scaled: Dict[str, List[float]] = {}
    scale_factors: Dict[str, float] = {}
    observed_peaks: List[float] = []
    for name, trace in channels.items():
        peak = kcl_peak(time_min, trace, kcl_t_start, kcl_t_end)
        observed_peaks.append(peak)
        usable = (not math.isnan(peak)) and peak > 0
        factor = target_peak / peak if usable else 1.0
        scale_factors[name] = factor
        scaled[name] = [sample * factor for sample in trace]
    return scaled, scale_factors, kcl_cv_pct(observed_peaks)


def kcl_pass(cv_pct: float, threshold_pct: float = 15.0) -> bool:
    """Return True when the KCl-peak CV is within the allowed threshold.

    A NaN CV (not computable) always fails the gate; otherwise the gate
    passes when ``cv_pct <= threshold_pct``.
    """
    cv_is_known = not math.isnan(cv_pct)
    return cv_is_known and cv_pct <= threshold_pct
