"""synthetic_generator.py

Generate synthetic Seahorse-style plate CSVs covering 4 protocols.

Design choices (seed=42, reproducible):
- 4 protocols: Mito Stress, Glycolysis Stress, FAO Assay, ATP Rate Assay
- 5 cell types: PMH (primary mouse hep), HepG2, AML12, HepaRG, LX-2 (HSC)
- 5 drugs: resmetirom, pegozafermin, efruxifermin, lanifibranor, vehicle
- 24 wells per plate (6 conditions x 4 replicate wells)
- 12 measurements per well (4 phases x 3) for Mito Stress / Glycolysis Stress / FAO; 9 (3 phases x 3) for ATP Rate
- per-well signal = cell-type baseline x drug multiplier x injection-phase response, with multiplicative noise

Outputs CSVs to the synthetic/ directory next to this file:
- mito_stress.csv
- glycolysis_stress.csv
- fao_assay.csv
- atp_rate.csv

The schema (long-form):
    plate_id, well, group, cell_type, drug, dose, substrate,
    measurement, time_min, injection, ocr, ecar, ppr
"""

from __future__ import annotations

import os
import csv
import math
import random
from dataclasses import dataclass
from typing import Dict, List, Tuple

import numpy as np

# Base RNG seed; used as the default `seed` argument of generate_plate_csv.
SEED = 42

# Absolute path of the directory containing this script.
HERE = os.path.dirname(os.path.abspath(__file__))
# Output directory for the generated CSVs: a "synthetic" folder next to this
# script. main() creates it if it does not exist.
OUT_DIR = os.path.join(HERE, "synthetic")


# ---------------------------------------------------------------------------
# baseline matrices
# ---------------------------------------------------------------------------

# Basal rates per cell type, stored as (OCR, ECAR) pairs.
CELL_BASELINES: Dict[str, Tuple[float, float]] = {
    # cell_type: (basal_ocr_pmol_O2_min, basal_ecar_mpH_min)
    "PMH": (180.0, 25.0),  # primary mouse hepatocytes, high OXPHOS
    "AML12": (140.0, 35.0),  # mouse hepatocyte line
    "HepG2": (90.0, 75.0),  # Warburg-like
    "HepaRG": (160.0, 45.0),  # differentiated hepatocyte line
    "LX-2": (70.0, 95.0),  # hepatic stellate cells, glycolytic
}

# Per-drug multipliers applied to the basal (OCR, ECAR) pair.
DRUG_EFFECT: Dict[str, Tuple[float, float]] = {
    # drug: (ocr_mult, ecar_mult)
    "vehicle": (1.00, 1.00),
    "resmetirom": (1.15, 0.95),  # THR-beta -> beta-ox up
    "pegozafermin": (1.20, 0.90),  # FGF21 analog -> FAO up
    "efruxifermin": (1.18, 0.92),  # FGF21 analog
    "lanifibranor": (1.10, 1.05),  # pan-PPAR
}

# Names in table (insertion) order.
CELL_TYPES = list(CELL_BASELINES)
DRUGS = list(DRUG_EFFECT)


# ---------------------------------------------------------------------------
# protocol templates (injection sequences + multiplicative response factors)
# ---------------------------------------------------------------------------


@dataclass
class Protocol:
    """Template for one assay protocol: phase labels plus per-phase responses."""

    name: str  # human-readable protocol name
    injections: List[str]  # phase labels per measurement
    # phase -> (ocr_factor, ecar_factor) multiplier relative to baseline
    response: Dict[str, Tuple[float, float]]


def _make_phase_labels(phase_sequence: List[Tuple[str, int]]) -> List[str]:
    """Flatten ``[(label, count), ...]`` into one label per measurement.

    Each label is repeated ``count`` times, in the order given.
    """
    return [label for label, count in phase_sequence for _ in range(count)]


# Protocol templates keyed by protocol id (the same key is written as
# `plate_id` in the generated CSVs). Every phase contributes 3 measurements;
# `response` maps each phase label to the (OCR, ECAR) multipliers that are
# applied to a well's baseline during that phase.
PROTOCOLS: Dict[str, Protocol] = {
    "mito_stress": Protocol(
        name="Mito Stress",
        injections=_make_phase_labels([
            ("baseline", 3),
            ("Oligomycin", 3),
            ("FCCP", 3),
            ("Rot/AA", 3),
        ]),
        response={
            "baseline": (1.00, 1.00),
            "Oligomycin": (0.45, 1.30),   # ATP-linked OCR drops, ECAR comp up
            "FCCP": (1.95, 0.95),         # maximal OCR
            "Rot/AA": (0.10, 0.70),       # non-mito floor
        },
    ),
    "glycolysis_stress": Protocol(
        name="Glycolysis Stress",
        injections=_make_phase_labels([
            ("baseline", 3),
            ("Glucose", 3),
            ("Oligomycin", 3),
            ("2-DG", 3),
        ]),
        response={
            "baseline": (1.00, 0.20),     # no glucose -> ECAR floor
            "Glucose": (1.00, 1.00),      # glucose triggers glycolysis
            "Oligomycin": (0.50, 1.60),   # glycolytic capacity
            "2-DG": (0.50, 0.15),         # ECAR back to non-glyc floor
        },
    ),
    "fao_assay": Protocol(
        name="FAO Assay",
        injections=_make_phase_labels([
            ("baseline", 3),
            ("Oligomycin", 3),
            ("FCCP", 3),
            ("Rot/AA", 3),
        ]),
        # FAO assay: the substrate-specific OCR scaling is not encoded here; it
        # is applied per well via SUBSTRATE_OCR_MULT in generate_plate_csv.
        response={
            "baseline": (1.00, 1.00),
            "Oligomycin": (0.40, 1.25),
            "FCCP": (1.85, 0.90),
            "Rot/AA": (0.12, 0.70),
        },
    ),
    # Three phases only (no FCCP step), i.e. 9 measurements per well.
    "atp_rate": Protocol(
        name="ATP Rate Assay",
        injections=_make_phase_labels([
            ("baseline", 3),
            ("Oligomycin", 3),
            ("Rot/AA", 3),
        ]),
        response={
            "baseline": (1.00, 1.00),
            "Oligomycin": (0.45, 1.30),
            "Rot/AA": (0.10, 0.70),
        },
    ),
}


# ---------------------------------------------------------------------------
# substrate matrices for FAO assay
# ---------------------------------------------------------------------------

# FAO assay only: multiplier applied to a well's basal OCR according to its
# substrate (ECAR is not adjusted). Substrates missing from this table fall
# back to 1.0 in generate_plate_csv.
# NOTE(review): "Glucose" has an entry here but is not among the substrates
# that _build_conditions assigns to FAO plates.
SUBSTRATE_OCR_MULT: Dict[str, float] = {
    "BSA":                 1.00,
    "Palmitate-BSA":       1.30,
    "Palmitate+Etomoxir":  0.85,
    "Glutamine":           1.05,
    "Glutamine+BPTES":     0.80,
    "Glucose":             1.10,
    "Glucose+UK5099":      0.90,
}


# ---------------------------------------------------------------------------
# plate generator
# ---------------------------------------------------------------------------


def _well_name(i: int) -> str:
    """Return the 96-well plate label (``A1`` .. ``H12``) for a 0-based index.

    Wells are numbered row-major with 12 columns per row, so 0 -> ``A1``,
    11 -> ``A12``, 12 -> ``B1`` and 95 -> ``H12``. Plates generated in this
    module only use the first 24 wells (A1-B12).

    Raises:
        IndexError: If ``i`` is outside ``0..95``. Negative indices are
            rejected explicitly; floor division combined with negative string
            indexing would otherwise silently map them to wells at the end of
            the plate (e.g. -1 -> ``H12``).
    """
    rows = "ABCDEFGH"
    n_cols = 12
    if not 0 <= i < len(rows) * n_cols:
        raise IndexError(f"well index {i} out of range for a 96-well plate")
    row, col = divmod(i, n_cols)
    return f"{rows[row]}{col + 1}"


def _build_conditions(protocol_key: str) -> List[Dict[str, str]]:
    """Describe the six well groups laid out on a plate for ``protocol_key``.

    ``"fao_assay"`` yields HepG2 + vehicle under six different substrates.
    Any other key yields a fixed cohort of six (cell type, drug, dose)
    combinations -- the full 5 x 5 cell/drug grid would not fit on a plate --
    whose substrate is ``"Glucose"`` for ``"glycolysis_stress"`` and
    ``"Standard"`` otherwise.

    Every returned dict has the keys ``cell_type``, ``drug``, ``dose``,
    ``substrate`` and ``group``.
    """
    conditions: List[Dict[str, str]] = []

    if protocol_key == "fao_assay":
        fao_substrates = (
            "BSA",
            "Palmitate-BSA",
            "Palmitate+Etomoxir",
            "Glutamine",
            "Glutamine+BPTES",
            "Glucose+UK5099",
        )
        for substrate in fao_substrates:
            conditions.append({
                "cell_type": "HepG2",
                "drug": "vehicle",
                "dose": "0",
                "substrate": substrate,
                "group": f"HepG2|vehicle|{substrate}",
            })
        return conditions

    # Non-FAO protocols share one cohort; only the substrate label differs.
    if protocol_key == "glycolysis_stress":
        substrate = "Glucose"
    else:
        substrate = "Standard"

    cohort = (
        ("HepG2", "vehicle", "0"),
        ("HepG2", "resmetirom", "1uM"),
        ("AML12", "vehicle", "0"),
        ("AML12", "pegozafermin", "100ng/mL"),
        ("PMH", "lanifibranor", "1uM"),
        ("LX-2", "efruxifermin", "100ng/mL"),
    )
    for cell_type, drug, dose in cohort:
        conditions.append({
            "cell_type": cell_type,
            "drug": drug,
            "dose": dose,
            "substrate": substrate,
            "group": f"{cell_type}|{drug}",
        })
    return conditions


def generate_plate_csv(protocol_key: str, out_path: str, seed: int = SEED) -> int:
    """Simulate one plate for ``protocol_key`` and write it to ``out_path``.

    Each condition from ``_build_conditions`` fills 4 replicate wells. A
    well's basal OCR/ECAR is the cell-type baseline scaled by the drug
    multipliers, by the substrate multiplier (FAO assay only) and by per-well
    noise. Every measurement then applies the phase response factors plus
    per-measurement noise. Values are clamped at zero and written in long
    form, one row per well and measurement.

    Args:
        protocol_key: Key into ``PROTOCOLS``; also written as ``plate_id``.
        out_path: Destination CSV path (overwritten if it exists).
        seed: Base RNG seed. It is combined with a stable digest of
            ``protocol_key`` so that each protocol gets its own reproducible
            random stream.

    Returns:
        Number of data rows written (the header is not counted).

    Raises:
        KeyError: If ``protocol_key`` is not in ``PROTOCOLS``.
    """
    # Function-scope import keeps this change self-contained.
    import zlib

    # The built-in hash() of a str is salted per interpreter process
    # (PYTHONHASHSEED), so deriving the offset from it produced different data
    # on every run. CRC32 is process-independent, which makes the output
    # reproducible for a given (seed, protocol_key).
    key_offset = zlib.crc32(protocol_key.encode("utf-8")) % 10000
    rng = np.random.default_rng(seed + key_offset)

    proto = PROTOCOLS[protocol_key]
    conditions = _build_conditions(protocol_key)
    replicates = 4
    is_fao = protocol_key == "fao_assay"

    # CSV column order, declared up front so the header does not depend on at
    # least one row having been generated.
    fields = [
        "plate_id", "well", "group", "cell_type", "drug", "dose", "substrate",
        "measurement", "time_min", "injection", "ocr", "ecar", "ppr",
    ]

    rows: List[Dict[str, object]] = []
    well_idx = 0
    for cond in conditions:
        cell = cond["cell_type"]
        drug = cond["drug"]
        sub = cond["substrate"]
        cell_ocr, cell_ecar = CELL_BASELINES[cell]
        ocr_mult, ecar_mult = DRUG_EFFECT[drug]

        for _ in range(replicates):
            well = _well_name(well_idx)
            well_idx += 1

            base_ocr = cell_ocr * ocr_mult
            base_ecar = cell_ecar * ecar_mult

            # substrate effect (FAO assay); unknown substrates leave OCR as is
            if is_fao:
                base_ocr *= SUBSTRATE_OCR_MULT.get(sub, 1.0)

            # per-well biological noise
            base_ocr *= rng.normal(1.0, 0.06)
            base_ecar *= rng.normal(1.0, 0.08)

            for m_idx, phase in enumerate(proto.injections, start=1):
                ocr_factor, ecar_factor = proto.response[phase]
                # per-measurement noise; clamp so rates are never negative
                ocr_val = max(0.0, base_ocr * ocr_factor * rng.normal(1.0, 0.04))
                ecar_val = max(0.0, base_ecar * ecar_factor * rng.normal(1.0, 0.05))
                time_min = m_idx * 7.0  # ~7 min per measurement
                rows.append({
                    "plate_id": protocol_key,
                    "well": well,
                    "group": cond["group"],
                    "cell_type": cell,
                    "drug": drug,
                    "dose": cond["dose"],
                    "substrate": sub,
                    "measurement": m_idx,
                    "time_min": round(time_min, 2),
                    "injection": phase,
                    "ocr": round(ocr_val, 3),
                    "ecar": round(ecar_val, 3),
                    "ppr": round(ecar_val * 0.8, 3),  # rough proxy
                })

    # newline="" lets the csv module control line endings itself.
    with open(out_path, "w", newline="", encoding="utf-8") as f:
        writer = csv.DictWriter(f, fieldnames=fields)
        writer.writeheader()
        writer.writerows(rows)
    return len(rows)


def main() -> None:
    """Generate one CSV per protocol into ``OUT_DIR`` and print row counts."""
    os.makedirs(OUT_DIR, exist_ok=True)

    # protocol key -> output file name, in generation order
    outputs = {
        "mito_stress": "mito_stress.csv",
        "glycolysis_stress": "glycolysis_stress.csv",
        "fao_assay": "fao_assay.csv",
        "atp_rate": "atp_rate.csv",
    }

    row_counts = []
    for protocol_key, file_name in outputs.items():
        n_rows = generate_plate_csv(protocol_key, os.path.join(OUT_DIR, file_name))
        row_counts.append(n_rows)
        print(f"[gen] {file_name}: {n_rows} rows")

    print(f"[done] wrote {sum(row_counts)} rows to {OUT_DIR}")


# Script entry point: generate all plates only when run directly, not on import.
if __name__ == "__main__":
    main()
