"""Synthetic perifusion data generator (offline, deterministic with seed).

Generates 60-min, 1-min-resolution traces for 5 channels per file:
    vehicle / GLP-1 100 nM / Exendin-4 10 nM / Glibenclamide 1 uM / Palmitate 0.5 mM 24h

Stimulus protocol:
    0-10  min : low glucose 2.8 mM
    10-40 min : high glucose 16.7 mM
    40-50 min : 30 mM KCl
    50-60 min : low glucose 2.8 mM

Optional analytes: insulin (default), c-peptide, glucagon, proinsulin.
"""
from __future__ import annotations

import csv
import os
from typing import Dict, List, Tuple

import numpy as np


def _insulin_trace(t_grid: np.ndarray, condition: str, rng: np.random.Generator) -> np.ndarray:
    """Synthesize a typical insulin response trace in ng/mL.

    The signal is a 5 ng/mL baseline plus three condition-dependent parts:
    a Gaussian first-phase peak centred at 12 min, a second-phase plateau
    gated by two logistic curves (rising around 14 min, falling around
    38 min), and a Gaussian KCl peak centred at 43 min that is only active
    for 40 <= t <= 50.  For t > 50 the sum is replaced by the baseline plus
    a small exponential tail.  Gaussian noise (sd 1.0) is added, the noisy
    signal is scaled by a linear drift from 1.0 at t=0 to 0.95 at t=60, and
    negative values are clipped to zero.

    Args:
        t_grid: Time points in minutes (any array-like convertible to float).
        condition: One of "vehicle", "glp1", "exendin", "glibenclamide",
            "palmitate".
        rng: NumPy random generator; exactly one ``normal`` draw with the
            shape of ``t_grid`` is consumed.

    Returns:
        Float array with the same shape as ``t_grid``.

    Raises:
        KeyError: If ``condition`` is not one of the known keys.
    """
    t = np.asarray(t_grid, dtype=float)
    base = 5.0

    # high glucose 1st phase peak ~12 min, ~25 ng/mL (vehicle)
    p1_amp = {"vehicle": 25.0, "glp1": 38.0, "exendin": 36.0, "glibenclamide": 32.0, "palmitate": 12.0}[condition]
    p1_peak_t = 12.0
    sigma1 = 1.6
    p1 = p1_amp * np.exp(-((t - p1_peak_t) ** 2) / (2 * sigma1 ** 2))

    # 2nd phase plateau (16 ng/mL for vehicle), gated by two logistic curves
    plateau = {"vehicle": 16.0, "glp1": 22.0, "exendin": 20.0, "glibenclamide": 28.0, "palmitate": 7.0}[condition]
    onset = 1.0 / (1.0 + np.exp(-(t - 14.0)))  # rises around 14 min
    offset = 1.0 / (1.0 + np.exp((t - 38.0)))  # falls before KCl
    p2 = plateau * onset * offset

    # KCl peak ~40 ng/mL at ~43 min, zeroed outside the 40-50 min window
    kcl_amp = {"vehicle": 40.0, "glp1": 42.0, "exendin": 41.0, "glibenclamide": 38.0, "palmitate": 28.0}[condition]
    kcl_peak_t = 43.0
    sigma_k = 2.0
    in_kcl_window = (t >= 40) & (t <= 50)
    k = kcl_amp * np.exp(-((t - kcl_peak_t) ** 2) / (2 * sigma_k ** 2)) * in_kcl_window

    # After 50 min the component sum is discarded in favour of the baseline
    # plus an exponential tail (time constant 2 min).
    sig = base + p1 + p2 + k
    sig = np.where(t > 50, base + 2 * np.exp(-(t - 50) / 2.0), sig)

    # noise is added before the drift, so it is scaled by the drift as well
    sig = sig + rng.normal(0, 1.0, size=t.shape)

    # degradation drift (~5% over 60 min)
    sig = sig * (1.0 - 0.05 * (t / 60.0))

    sig[sig < 0] = 0
    return sig


def _c_peptide_from_insulin(insulin_trace: np.ndarray, rng: np.random.Generator) -> np.ndarray:
    """Derive a c-peptide trace from an insulin trace.

    The insulin trace is delayed by one sample (the first sample is repeated
    rather than wrapped around from the end), scaled by a single factor drawn
    uniformly from [1.0, 1.2), and perturbed with Gaussian noise (sd 0.5).
    Negative values are clipped to zero, matching the other trace builders
    in this module.

    Args:
        insulin_trace: Insulin values; expected to be one-dimensional.
        rng: NumPy random generator; one ``uniform`` draw followed by one
            ``normal`` draw of the trace's shape are consumed, in that order.

    Returns:
        Float array with the same shape as ``insulin_trace``; an empty input
        yields an empty output.
    """
    ins = np.asarray(insulin_trace, dtype=float)
    # 1.0-1.2x with small lag, plus noise
    factor = rng.uniform(1.0, 1.2)
    cp = np.roll(ins, 1) * factor
    if cp.size:
        # np.roll wraps the last sample to the front; overwrite it so the
        # trace starts from its own first value instead.
        cp[0] = ins[0] * factor
    cp = cp + rng.normal(0, 0.5, size=ins.shape)
    # A concentration cannot be negative.
    cp[cp < 0] = 0
    return cp


def _glucagon_trace(t_grid: np.ndarray, condition: str, rng: np.random.Generator) -> np.ndarray:
    """Synthesize a piecewise-constant glucagon trace with Gaussian noise.

    The level is 60 (pg/mL) outside the stimulus windows, ``60 * fraction``
    for 10 <= t <= 40 (high glucose) and ``60 * 1.2`` for 40 <= t <= 50 (KCl).
    The KCl level takes precedence at t = 40, where both windows apply.
    Noise with sd 3.0 is added and negative values are clipped to zero.

    Raises:
        KeyError: If ``condition`` is not one of the known keys.
    """
    resting_level = 60.0  # pg/mL
    # Fraction of the resting level that remains under high glucose
    # (palmitate impairs the suppression, so more remains).
    remaining_fraction = {
        "vehicle": 0.5,
        "glp1": 0.4,
        "exendin": 0.4,
        "glibenclamide": 0.55,
        "palmitate": 0.85,
    }[condition]

    in_high_glucose = (t_grid >= 10) & (t_grid <= 40)
    in_kcl = (t_grid >= 40) & (t_grid <= 50)

    # KCl mildly stimulates glucagon and is checked first so it wins at t = 40.
    level = np.where(
        in_kcl,
        resting_level * 1.2,
        np.where(in_high_glucose, resting_level * remaining_fraction, resting_level),
    )

    trace = level + rng.normal(0, 3.0, size=t_grid.shape)
    below_zero = trace < 0
    trace[below_zero] = 0
    return trace


def _proinsulin_from_insulin(insulin_trace: np.ndarray, rng: np.random.Generator) -> np.ndarray:
    """Derive a proinsulin trace as a small fraction of an insulin trace.

    A single ratio is drawn uniformly from [0.05, 0.10), the insulin trace is
    scaled by it, Gaussian noise (sd 0.05) is added and negative values are
    clipped to zero.  The ratio is drawn before the noise.
    """
    ratio = rng.uniform(0.05, 0.10)
    jitter = rng.normal(0, 0.05, size=insulin_trace.shape)
    proinsulin = insulin_trace * ratio + jitter
    negative = proinsulin < 0
    proinsulin[negative] = 0
    return proinsulin


# Maps each internal condition key (the same keys used by the per-condition
# amplitude tables in the trace builders) to the column label written in the
# CSV header.  Insertion order matters: make_file iterates this dict both to
# generate the traces and to lay out the CSV columns.
CHANNEL_LABELS: Dict[str, str] = {
    "vehicle": "Vehicle",
    "glp1": "GLP-1_100nM",
    "exendin": "Exendin-4_10nM",
    "glibenclamide": "Glibenclamide_1uM",
    "palmitate": "Palmitate_0.5mM_24h",
}


def make_file(
    out_path: str,
    sample_kind: str,
    sample_id: str,
    analyte: str = "insulin",
    seed: int = 42,
    vendor: str = "in-house",
) -> str:
    """Write one synthetic perifusion CSV file and return its path.

    The file starts with a single ``#``-prefixed metadata line, followed by a
    CSV header (``Time_min`` plus one column per entry of ``CHANNEL_LABELS``)
    and one row per minute from 0 to 60 inclusive.  Times are written with two
    decimals and values with three.

    Args:
        out_path: Destination path; its directory must already exist.
        sample_kind: Free-text sample kind, written to the metadata line.
        sample_id: Free-text sample identifier, written to the metadata line.
        analyte: "insulin", "c-peptide", "glucagon" or "proinsulin".  Any
            other value falls back to insulin with a warning, and the
            metadata line then records ``analyte=insulin`` so that the label
            matches the data actually written.
        seed: Seed for the random generator; the same seed and arguments
            reproduce the same file.
        vendor: Vendor name, written to the metadata line.

    Returns:
        ``out_path`` unchanged.
    """
    if analyte in ("insulin", "c-peptide", "glucagon", "proinsulin"):
        effective_analyte = analyte
    else:
        import warnings

        warnings.warn(
            f"unknown analyte {analyte!r}; generating insulin traces and "
            "labelling the file as analyte=insulin",
            stacklevel=2,
        )
        effective_analyte = "insulin"

    rng = np.random.default_rng(seed)
    # The 60.0001 stop makes the 60-min endpoint inclusive (61 samples).
    t_grid = np.arange(0, 60.0001, 1.0, dtype=float)

    # One shared generator is consumed channel by channel in CHANNEL_LABELS
    # order, so this loop order is part of the seeded output.
    traces: Dict[str, np.ndarray] = {}
    for cond in CHANNEL_LABELS:
        if effective_analyte == "glucagon":
            traces[cond] = _glucagon_trace(t_grid, cond, rng)
            continue
        insulin = _insulin_trace(t_grid, cond, rng)
        if effective_analyte == "c-peptide":
            traces[cond] = _c_peptide_from_insulin(insulin, rng)
        elif effective_analyte == "proinsulin":
            traces[cond] = _proinsulin_from_insulin(insulin, rng)
        else:
            traces[cond] = insulin

    header = ["Time_min", *CHANNEL_LABELS.values()]
    columns = [traces[cond] for cond in CHANNEL_LABELS]

    with open(out_path, "w", newline="", encoding="utf-8") as f:
        # comment metadata
        # NOTE(review): this line ends with "\n" while csv.writer terminates
        # rows with "\r\n" by default, so the file has mixed line endings.
        # Left unchanged to keep the output byte-compatible.
        f.write(
            f"# vendor={vendor}; sample_kind={sample_kind}; sample_id={sample_id}; "
            f"analyte={effective_analyte}; flow_rate_ml_min=0.1; dead_volume_ul=100; ieq=100\n"
        )
        w = csv.writer(f)
        w.writerow(header)
        w.writerows(
            [f"{t:.2f}", *(f"{column[i]:.3f}" for column in columns)]
            for i, t in enumerate(t_grid)
        )
    return out_path


def make_demo_set(out_dir: str) -> List[str]:
    """Write the five demo CSV files into ``out_dir`` and return their paths.

    ``out_dir`` is created if it does not exist.  The files are generated in
    the order listed below, each with its own fixed seed.
    """
    os.makedirs(out_dir, exist_ok=True)

    # (file name, sample kind, sample id, analyte, seed)
    demo_specs = (
        ("primary_mouse_islet_GSIS.csv", "primary_mouse_islet", "M-001", "insulin", 11),
        ("ins1_KCl_response.csv", "INS-1_832/13", "INS1-A", "insulin", 12),
        ("ipsc_scbeta_D21_maturation.csv", "iPSC-SC-beta_D21", "iPSC-D21", "insulin", 13),
        ("primary_human_cpeptide.csv", "primary_human_islet", "H-007", "c-peptide", 14),
        ("primary_mouse_glucagon.csv", "primary_mouse_islet", "M-001", "glucagon", 15),
    )

    return [
        make_file(os.path.join(out_dir, file_name), kind, ident, analyte, seed=seed)
        for file_name, kind, ident, analyte, seed in demo_specs
    ]
