"""PancIsletMass v0 — 합성 췌장 WSI 생성기.

실제 IHC 슬라이드 대신 numpy 기반 합성 multichannel 영상을 만든다.
채널: insulin (β), glucagon (α), somatostatin (δ), Ki67, TUNEL, DAPI.
재현성: numpy seed=42 고정.
"""

from __future__ import annotations

import json
import os
from dataclasses import dataclass, field, asdict
from typing import List, Dict, Tuple

import numpy as np

# Stain channels, listed in the order they are stacked along the first axis
# of a slide's image array:
#   insulin       beta-cell marker
#   glucagon      alpha-cell marker
#   somatostatin  delta-cell marker
#   Ki67          replication marker
#   TUNEL         apoptosis marker
#   DAPI          nuclei
CHANNELS: Tuple[str, ...] = (
    "insulin", "glucagon", "somatostatin", "Ki67", "TUNEL", "DAPI",
)

# Group labels used when a caller does not supply its own.
DEFAULT_GROUPS: Tuple[str, ...] = ("control", "HFD", "drug")


@dataclass
class SyntheticSlide:
    """One synthetic whole-slide image together with its metadata."""

    slide_id: str
    group: str
    mouse_id: str
    pancreas_weight_mg: float
    image_um_per_px: float
    # Image stack laid out as (C, H, W); float32 with values in [0, 1].
    channels: np.ndarray
    # Ground-truth islet records in pixel coordinates, kept for QA.
    islet_centers: List[Tuple[int, int, int]] = field(default_factory=list)

    def channel_dict(self) -> Dict[str, np.ndarray]:
        """Return a mapping from each name in ``CHANNELS`` to its image plane."""
        planes: Dict[str, np.ndarray] = {}
        for position, label in enumerate(CHANNELS):
            planes[label] = self.channels[position]
        return planes


def _disk_mask(h: int, w: int, cy: int, cx: int, r: int) -> np.ndarray:
    """Return an (h, w) boolean array that is True inside a filled disk.

    The disk is centred on row ``cy`` / column ``cx``. A pixel belongs to it
    when its squared distance from the centre is at most ``r * r``, so the
    boundary itself is included.
    """
    row_offsets = np.arange(h).reshape(-1, 1) - cy
    col_offsets = np.arange(w).reshape(1, -1) - cx
    # Column vector + row vector broadcasts to the full (h, w) grid.
    dist_sq = row_offsets ** 2 + col_offsets ** 2
    return dist_sq <= r * r


def _gaussian_blob(h: int, w: int, cy: int, cx: int, sigma: float) -> np.ndarray:
    """Return an (h, w) float32 image of an isotropic Gaussian bump.

    The value at each pixel is ``exp(-d2 / (2 * sigma**2))`` where ``d2`` is
    the squared pixel distance to (``cy``, ``cx``); the peak value is 1.0 at
    the centre pixel.
    """
    rows = np.arange(h).reshape(-1, 1)
    cols = np.arange(w).reshape(1, -1)
    # Broadcasting the two vectors yields the full grid of squared distances.
    dist_sq = (rows - cy) ** 2 + (cols - cx) ** 2
    two_variance = 2.0 * sigma * sigma
    weights = np.exp(-dist_sq / two_variance)
    return weights.astype(np.float32)


def synth_slide(
    slide_id: str,
    group: str,
    mouse_id: str,
    h: int = 256,
    w: int = 256,
    seed: int = 42,
    pancreas_weight_mg: float = 180.0,
    image_um_per_px: float = 1.0,
    n_islets: int = 7,
) -> SyntheticSlide:
    """Generate one synthetic pancreas section as a multi-channel image.

    Non-overlapping circular islets are placed by rejection sampling. Each
    islet is filled with small Gaussian "cells" that are stained in the
    channel matching a randomly assigned cell type. All randomness is drawn
    from a single ``np.random.default_rng(seed)`` stream, so identical
    arguments reproduce the same slide.

    Group presets (any other ``group`` value uses the control preset):

      - control: beta:alpha:delta = 70:20:10, mean islet radius 45 px,
        Ki67 rate 0.03, TUNEL rate 0.015
      - HFD:     beta:alpha:delta = 78:14:8, mean islet radius 55 px,
        Ki67 rate 0.04, TUNEL rate 0.02
      - drug:    beta:alpha:delta = 74:18:8, mean islet radius 50 px,
        Ki67 rate 0.08, TUNEL rate 0.01

    Args:
        slide_id: Identifier stored on the returned slide.
        group: Preset name, see above.
        mouse_id: Identifier stored on the returned slide.
        h: Image height in pixels.
        w: Image width in pixels.
        seed: Seed for the random generator.
        pancreas_weight_mg: Stored on the returned slide unchanged.
        image_um_per_px: Stored on the returned slide unchanged.
        n_islets: Number of islets to aim for. Fewer are produced when
            placement fails ``n_islets * 20`` times in total.

    Returns:
        A ``SyntheticSlide`` whose ``channels`` array has shape
        ``(len(CHANNELS), h, w)``, dtype float32, clipped to [0, 1], and whose
        ``islet_centers`` holds one ``(cy, cx, r)`` tuple per placed islet.
    """
    rng = np.random.default_rng(seed)

    stack = np.zeros((len(CHANNELS), h, w), dtype=np.float32)
    # Resolve every channel to a view of its plane once, instead of searching
    # CHANNELS for each painted cell. Writing to a view updates `stack`.
    plane = {name: stack[pos] for pos, name in enumerate(CHANNELS)}

    def stamp(channel: str, blob: np.ndarray, gain: float) -> None:
        """Blend ``blob * gain`` into a channel with a per-pixel maximum."""
        np.maximum(plane[channel], blob * gain, out=plane[channel])

    # Background DAPI: a weak sprinkle of nuclear signal.
    plane["DAPI"][...] = rng.uniform(0.02, 0.08, (h, w)).astype(np.float32)

    # (type probabilities, mean radius, Ki67 rate, TUNEL rate) per group.
    presets = {
        "HFD": ((0.78, 0.14, 0.08), 55, 0.04, 0.02),
        "drug": ((0.74, 0.18, 0.08), 50, 0.08, 0.01),
    }
    control_preset = ((0.70, 0.20, 0.10), 45, 0.03, 0.015)
    type_probs, radius_mu, ki67_rate, tunel_rate = presets.get(group, control_preset)

    # Marker channel and in-zone intensity range for each cell type.
    marker_of = {"beta": "insulin", "alpha": "glucagon", "delta": "somatostatin"}
    strong_gain = {"beta": (0.7, 1.0), "alpha": (0.6, 0.95), "delta": (0.5, 0.9)}
    weak_gain = (0.3, 0.6)

    islet_centers: List[Tuple[int, int, int]] = []
    margin = 60
    max_attempts = n_islets * 20
    attempts = 0
    while len(islet_centers) < n_islets and attempts < max_attempts:
        attempts += 1
        cy = int(rng.integers(margin, max(margin + 1, h - margin)))
        cx = int(rng.integers(margin, max(margin + 1, w - margin)))
        r = int(np.clip(rng.normal(radius_mu, radius_mu * 0.15), 25, min(h, w) // 3))

        # Reject candidates whose disk would intersect an accepted islet.
        if any(
            (py - cy) ** 2 + (px - cx) ** 2 < (pr + r) ** 2
            for py, px, pr in islet_centers
        ):
            continue

        ys, xs = np.nonzero(_disk_mask(h, w, cy, cx, r))
        n_pixels = int(ys.size)
        if n_pixels == 0:
            # On images smaller than the margin the candidate disk can lie
            # entirely outside the frame; there is nothing to draw, so do not
            # record it as ground truth.
            continue

        islet_centers.append((cy, cx, r))

        # Soft islet-wide glow on DAPI.
        glow = _gaussian_blob(h, w, cy, cx, sigma=r * 0.7) * 0.6
        np.maximum(plane["DAPI"], glow * 0.5, out=plane["DAPI"])

        # Roughly one cell per 25 disk pixels, at least 20. The count is
        # capped at the pixels available because sampling without
        # replacement raises ValueError when asked for more items than exist
        # (a disk clipped by the image border can have very few pixels).
        n_cells = min(n_pixels, max(20, n_pixels // 25))
        picks = rng.choice(n_pixels, size=n_cells, replace=False)
        cells_y = ys[picks]
        cells_x = xs[picks]

        types = rng.choice(
            ["beta", "alpha", "delta"],
            size=n_cells,
            p=list(type_probs),
        )
        dists = np.sqrt((cells_y - cy) ** 2 + (cells_x - cx) ** 2)

        # Core = beta, periphery = alpha/delta (mimics mouse islet
        # architecture). A cell outside its preferred zone still stains its
        # own marker, only more weakly.
        for cell_y, cell_x, kind, dist in zip(cells_y, cells_x, types, dists):
            blob = _gaussian_blob(h, w, int(cell_y), int(cell_x), sigma=2.5)
            is_beta = kind == "beta"
            in_zone = dist < r * 0.8 if is_beta else dist > r * 0.4
            low, high = strong_gain[kind] if in_zone else weak_gain
            stamp(marker_of[kind], blob, rng.uniform(low, high))

            # Sparse Ki67 / TUNEL signal, on beta cells only.
            if is_beta:
                if rng.random() < ki67_rate:
                    stamp("Ki67", blob, rng.uniform(0.6, 1.0))
                if rng.random() < tunel_rate:
                    stamp("TUNEL", blob, rng.uniform(0.5, 0.9))

        # Nuclei of the islet cells. Kept as a second pass so the order of
        # random draws (and therefore the output for a given seed) is stable.
        for cell_y, cell_x in zip(cells_y, cells_x):
            nucleus = _gaussian_blob(h, w, int(cell_y), int(cell_x), sigma=1.8)
            stamp("DAPI", nucleus, rng.uniform(0.4, 0.8))

    # Additive Gaussian pixel noise, then clip back into [0, 1].
    noise = rng.normal(0, 0.01, stack.shape).astype(np.float32)
    channels = np.clip(stack + noise, 0.0, 1.0)

    return SyntheticSlide(
        slide_id=slide_id,
        group=group,
        mouse_id=mouse_id,
        pancreas_weight_mg=pancreas_weight_mg,
        image_um_per_px=image_um_per_px,
        channels=channels,
        islet_centers=islet_centers,
    )


def build_demo_cohort(
    out_dir: str,
    groups: Tuple[str, ...] = DEFAULT_GROUPS,
    mice_per_group: int = 3,
    slides_per_mouse: int = 5,
    h: int = 256,
    w: int = 256,
    seed_base: int = 42,
) -> Dict:
    """Build a small demo cohort and persist metadata plus ``.npz`` files.

    For every group, mouse and slide index one slide is generated with
    ``synth_slide`` and written to ``<out_dir>/<slide_id>.npz`` with two
    arrays: ``channels`` (float32) and ``centers`` (int32, one row of three
    values per islet). The cohort description is written to
    ``<out_dir>/demo_cohort.json``.

    Args:
        out_dir: Output directory; created if missing.
        groups: Group labels, in order. The position of a label affects the
            simulated pancreas weight and the requested islet count.
        mice_per_group: Number of mice simulated per group.
        slides_per_mouse: Number of slides simulated per mouse.
        h: Slide height in pixels.
        w: Slide width in pixels.
        seed_base: Offset for per-slide seeds. The k-th generated slide
            (counting from 1) uses ``seed_base + k``.

    Returns:
        The cohort metadata dict (the same content that is written to JSON).
    """
    os.makedirs(out_dir, exist_ok=True)
    cohort: Dict = {
        "version": "v0",
        "groups": list(groups),
        "channels": list(CHANNELS),
        "image_um_per_px": 1.0,
        "h": h,
        "w": w,
        "slides": [],
    }
    slide_number = 0
    for g_idx, group in enumerate(groups):
        for m in range(mice_per_group):
            mouse_id = f"{group}-m{m+1:02d}"
            # Mouse-level pancreas weight perturbation.
            pw = float(180.0 + 5.0 * (g_idx - 1) + (m * 1.7))
            for s in range(slides_per_mouse):
                slide_number += 1
                slide_id = f"{mouse_id}-s{s+1:02d}"
                slide = synth_slide(
                    slide_id=slide_id,
                    group=group,
                    mouse_id=mouse_id,
                    h=h,
                    w=w,
                    seed=seed_base + slide_number,
                    pancreas_weight_mg=pw,
                    n_islets=int(6 + (g_idx * 0.4) + (s % 3)),
                )
                npz_path = os.path.join(out_dir, f"{slide_id}.npz")
                # reshape(-1, 3) keeps the saved array two-dimensional even
                # when no islet was placed (an empty list would otherwise be
                # stored with shape (0,)).
                centers = np.array(slide.islet_centers, dtype=np.int32).reshape(-1, 3)
                np.savez_compressed(
                    npz_path,
                    channels=slide.channels.astype(np.float32),
                    centers=centers,
                )
                cohort["slides"].append(
                    {
                        "slide_id": slide_id,
                        "group": group,
                        "mouse_id": mouse_id,
                        "pancreas_weight_mg": pw,
                        # Stored relative to out_dir so the folder can be moved.
                        "npz": os.path.relpath(npz_path, out_dir),
                        "n_islets_gt": len(slide.islet_centers),
                    }
                )
    meta_path = os.path.join(out_dir, "demo_cohort.json")
    with open(meta_path, "w", encoding="utf-8") as f:
        json.dump(cohort, f, ensure_ascii=False, indent=2)
    return cohort


def load_slide_npz(path: str) -> Tuple[np.ndarray, np.ndarray]:
    """Load a saved slide archive.

    Args:
        path: Path to an ``.npz`` file containing the arrays ``channels`` and
            ``centers``.

    Returns:
        ``(channels, centers)`` as stored in the archive; the writer in this
        module saves them as CxHxW and Nx3 respectively.

    Raises:
        KeyError: If either array is missing from the archive.
    """
    # np.load returns a lazily-read archive object that keeps the file open.
    # Using it as a context manager closes the handle once both arrays have
    # been read into memory.
    with np.load(path) as archive:
        return archive["channels"], archive["centers"]
