"""Synthetic LD-object dataset generator for demo runs.

Builds a 96-well plate (8 drugs x 6 doses x 2 replicates) with
~25 cells per well and ~10 LDs per cell. Drug profiles are encoded
as fingerprint targets that vary monotonically with dose.
"""

from __future__ import annotations

import os
import numpy as np
import pandas as pd


# ---------------------------------------------------------------------------
# Reference drug fingerprints: the profile each drug approaches at high dose.
#
# Every value is a 7-tuple that is read by position:
#   [0] macro_pct        share of LDs of subtype "macro"
#   [1] medium_pct       share of LDs of subtype "medium"
#   [2] micro_pct        share of LDs of subtype "micro"
#   [3] ld_count_factor  multiplier on the per-cell LD count
#   [4] plin1_factor     multiplier on the PLIN1 intensity mean
#   [5] plin5_factor     multiplier on the PLIN5 intensity mean
#   [6] manders_m1       mean of the Manders M1 coefficient
# ---------------------------------------------------------------------------
DRUG_PROFILES = {
    "vehicle": (0.60, 0.30, 0.10, 1.00, 1.00, 0.30, 0.45),
    "resmetirom": (0.25, 0.35, 0.40, 0.65, 0.80, 1.40, 0.55),
    "DGAT2_inhibitor": (0.15, 0.30, 0.55, 0.40, 0.60, 0.50, 0.50),
    "SCD1_inhibitor": (0.50, 0.30, 0.20, 0.85, 1.10, 0.60, 0.45),
    "FGF21_analog": (0.40, 0.35, 0.25, 0.75, 0.65, 1.50, 0.55),
    "ACC_inhibitor": (0.30, 0.35, 0.35, 0.55, 0.70, 0.90, 0.50),
    "oleic_palmitic_loading_only": (0.70, 0.25, 0.05, 1.20, 1.10, 0.30, 0.40),
    "perilipin5_inducer": (0.45, 0.30, 0.25, 0.85, 0.85, 1.80, 0.60),
}

# Untreated baseline that every drug is interpolated away from.
VEHICLE_PROFILE = DRUG_PROFILES["vehicle"]

# Plate axes: drugs (in dict insertion order) x doses x replicates.
DRUGS = list(DRUG_PROFILES)
DOSES_UM = [0.01, 0.1, 1.0, 10.0, 100.0, 1000.0]
REPLICATES = [1, 2]


def _wells(n: int) -> list[str]:
    rows = "ABCDEFGH"
    cols = list(range(1, 13))
    out = []
    for r in rows:
        for c in cols:
            out.append(f"{r}{c:02d}")
    return out[:n]


def _dose_response_factor(dose_uM: float, ec50: float = 5.0, hill: float = 1.0) -> float:
    """0 at very low dose, 1 at very high dose."""
    if dose_uM <= 0:
        return 0.0
    return 1.0 / (1.0 + (dose_uM / ec50) ** -hill)


def _interp(vehicle_val: float, target_val: float, frac: float) -> float:
    return vehicle_val + (target_val - vehicle_val) * frac


def build_synthetic(seed: int = 7, n_cells_per_well: int = 25, n_lds_per_cell: int = 10):
    """Generate the synthetic plate and return its four tables.

    Wells are filled in ``_wells`` order, iterating drugs, then doses, then
    replicates. For each (drug, dose) the vehicle profile is blended towards
    the drug's profile by the dose-response fraction (EC50 5.0, Hill 1.2);
    the vehicle itself always stays at the vehicle profile. Each well gets
    ``n_cells_per_well`` cells, and each cell gets a normally distributed
    number of LDs (at least 1) whose features are drawn from the blended
    profile.

    Args:
        seed: Seed for ``numpy.random.default_rng``. With the same seed and
            arguments the returned tables are identical across runs.
        n_cells_per_well: Number of cells simulated in every well.
        n_lds_per_cell: Mean LD count per cell before the blended LD-count
            factor is applied.

    Returns:
        A tuple ``(ld_df, plate_df, cell_df, ref_df)`` of DataFrames:
        one row per LD object, one row per well (drug, dose, replicate),
        one row per well (cell count, cell type), and one row per drug in
        ``DRUG_PROFILES`` (reference fingerprint).
    """
    rng = np.random.default_rng(seed)
    n_wells_total = len(DRUGS) * len(DOSES_UM) * len(REPLICATES)  # 8*6*2 = 96
    well_names = _wells(n_wells_total)

    plate_rows = []
    cell_meta_rows = []
    ld_rows = []

    # Cell-type assignment cycle: drug i gets cell_types[i % len(cell_types)].
    cell_types = [
        "HepG2",
        "Huh7",
        "primary_mouse",
        "iPSC",
        "HepaRG",
        "spheroid",
        "organoid",
        "STAM_cryosection",
    ]

    # (mean, sd) of the diameter distribution for each LD subtype.
    diameter_params = {
        "macro": (7.0, 1.5),
        "medium": (2.5, 0.8),
        "micro": (0.6, 0.2),
    }

    (
        veh_macro,
        veh_medium,
        veh_micro,
        veh_count,
        veh_plin1,
        veh_plin5,
        veh_manders,
    ) = VEHICLE_PROFILE

    idx = 0
    for drug_idx, drug in enumerate(DRUGS):
        target = DRUG_PROFILES[drug]
        # Index by the drug's position rather than hash(drug): str hashes are
        # salted per process, which made cell_type differ between runs even
        # with a fixed seed.
        ct = cell_types[drug_idx % len(cell_types)]

        for dose in DOSES_UM:
            # Everything down to manders_target depends only on (drug, dose),
            # so it is computed once and shared by all replicates.
            frac = _dose_response_factor(dose, ec50=5.0, hill=1.2)
            if drug == "vehicle":
                frac = 0.0  # vehicle stays at vehicle profile regardless of dose

            # Mix vehicle <-> target, then renormalise the subtype shares to 1.
            macro_pct = _interp(veh_macro, target[0], frac)
            medium_pct = _interp(veh_medium, target[1], frac)
            micro_pct = _interp(veh_micro, target[2], frac)
            tot = macro_pct + medium_pct + micro_pct
            subtype_probs = [macro_pct / tot, medium_pct / tot, micro_pct / tot]

            # Blend from the vehicle's own factors, as for every other
            # channel. A fixed baseline of 1.0 gave vehicle wells a PLIN5
            # mean of 40 while the reference table below reports 40 * 0.30.
            ld_count_factor = _interp(veh_count, target[3], frac)
            plin1_factor = _interp(veh_plin1, target[4], frac)
            plin5_factor = _interp(veh_plin5, target[5], frac)
            manders_target = _interp(veh_manders, target[6], frac)

            for rep in REPLICATES:
                well = well_names[idx]
                idx += 1
                plate_rows.append(
                    {"well": well, "drug": drug, "dose_uM": dose, "replicate": rep}
                )
                cell_meta_rows.append(
                    {"well": well, "cell_count": n_cells_per_well, "cell_type": ct}
                )

                # Build LD objects. The order of rng calls below is kept
                # fixed so a given seed keeps producing the same draws.
                for c in range(n_cells_per_well):
                    cell_id = f"{well}_c{c:03d}"
                    n_lds = max(1, int(rng.normal(n_lds_per_cell * ld_count_factor, 1.5)))
                    # Choose LD subtype per droplet
                    subtypes = rng.choice(
                        ["macro", "medium", "micro"],
                        size=n_lds,
                        p=subtype_probs,
                    )
                    for k, s in enumerate(subtypes):
                        diam_mean, diam_sd = diameter_params[s]
                        diam = max(0.1, float(rng.normal(diam_mean, diam_sd)))
                        area = float(np.pi * (diam / 2.0) ** 2)
                        circ = float(np.clip(rng.normal(0.92, 0.05), 0.5, 1.0))
                        max_int = float(max(0.0, rng.normal(180 + 30 * (1 - frac), 25)))
                        int_int = float(max_int * area * rng.normal(1.0, 0.1))
                        # LD-mito distance: closer when manders_target is higher
                        dist = float(max(0.05, rng.normal(2.5 - manders_target * 1.5, 0.6)))
                        m1 = float(np.clip(rng.normal(manders_target, 0.07), 0.0, 1.0))
                        m2 = float(np.clip(rng.normal(manders_target * 0.95, 0.07), 0.0, 1.0))
                        pearson = float(np.clip(rng.normal(manders_target - 0.05, 0.08), -0.2, 1.0))
                        # PLIN ring intensities
                        plin1 = float(max(0.0, rng.normal(120 * plin1_factor, 18)))
                        plin2 = float(max(0.0, rng.normal(80, 14)))
                        plin3 = float(max(0.0, rng.normal(60, 12)))
                        plin5 = float(max(0.0, rng.normal(40 * plin5_factor, 12)))

                        ld_rows.append(
                            {
                                "well": well,
                                "cell_id": cell_id,
                                "ld_id": f"{cell_id}_l{k:03d}",
                                "area_um2": area,
                                "diameter_um": diam,
                                "circularity": circ,
                                "max_intensity": max_int,
                                "integrated_intensity": int_int,
                                "distance_to_mito_um": dist,
                                "manders_m1": m1,
                                "manders_m2": m2,
                                "pearson": pearson,
                                "plin1_intensity": plin1,
                                "plin2_intensity": plin2,
                                "plin3_intensity": plin3,
                                "plin5_intensity": plin5,
                            }
                        )

    plate_df = pd.DataFrame(plate_rows)
    cell_df = pd.DataFrame(cell_meta_rows)
    ld_df = pd.DataFrame(ld_rows)

    # Reference MOA fingerprint table (one row per reference drug)
    ref_rows = []
    for drug, prof in DRUG_PROFILES.items():
        macro, med, micro, ld_factor, p1, p5, m1 = prof
        ref_rows.append(
            {
                "drug": drug,
                "macro_pct": macro,
                "medium_pct": med,
                "micro_pct": micro,
                "total_LD_area_per_cell": 80.0 * ld_factor,
                "ld_count_per_cell": 10.0 * ld_factor,
                "manders_m1": m1,
                "plin1_mean": 120.0 * p1,
                "plin5_mean": 40.0 * p5,
            }
        )
    ref_df = pd.DataFrame(ref_rows)

    return ld_df, plate_df, cell_df, ref_df


def write_synthetic_to(data_dir: str, seed: int = 7) -> dict:
    """Build the synthetic tables and save each one as a CSV in ``data_dir``.

    The directory is created if it does not exist. Files are written without
    the DataFrame index.

    Args:
        data_dir: Output directory for the four CSV files.
        seed: Seed forwarded to ``build_synthetic``.

    Returns:
        A dict mapping ``"ld_objects"``, ``"plate_map"``, ``"cell_meta"`` and
        ``"reference_moa"`` to the path of the file written for that table.
    """
    os.makedirs(data_dir, exist_ok=True)
    ld_df, plate_df, cell_df, ref_df = build_synthetic(seed=seed)

    # (result key, file name, table) in the order the files are written.
    outputs = (
        ("ld_objects", "ld_objects.csv", ld_df),
        ("plate_map", "plate_map.csv", plate_df),
        ("cell_meta", "cell_meta.csv", cell_df),
        ("reference_moa", "reference_moa_fingerprints.csv", ref_df),
    )

    written = {}
    for key, filename, frame in outputs:
        destination = os.path.join(data_dir, filename)
        frame.to_csv(destination, index=False)
        written[key] = destination
    return written
