"""합성 지방조직 multi-channel WSI 생성 모듈.

채널 구성: (perilipin, F4/80, CD11c, CD206, DAPI)
그룹별 dead adipocyte 비율과 M1/M2 분극 비율을 다르게 설정해
heuristic 기반 CLS detection / polarization 분석을 시연한다.

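Reproducible: every random draw comes from an explicitly seeded numpy Generator (default seed = 42; build_demo_cohort uses base_seed + slide index).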
"""
from __future__ import annotations

import math
from dataclasses import dataclass, field
from typing import Dict, List, Tuple

import numpy as np


# Marker channels, in the order they are stacked along the image's last axis.
CHANNELS: Tuple[str, ...] = (
    "perilipin",
    "F480",
    "CD11c",
    "CD206",
    "DAPI",
)


# Per-group simulation parameters:
#   dead_frac         - probability that a given adipocyte is drawn as dead
#   m1_ratio          - probability that a nucleus around a dead adipocyte is
#                       tagged M1 (CD11c) rather than M2 (CD206)
#   adipo_radius_mean - mean adipocyte radius in pixels, before depot scaling
GROUP_PARAMS: Dict[str, Dict[str, float]] = {
    "control": dict(dead_frac=0.02, m1_ratio=0.30, adipo_radius_mean=55.0),
    "HFD": dict(dead_frac=0.12, m1_ratio=0.70, adipo_radius_mean=75.0),
    "HFD+drug": dict(dead_frac=0.05, m1_ratio=0.50, adipo_radius_mean=60.0),
}

# Supported fat depots.
DEPOTS: Tuple[str, ...] = ("eWAT", "iWAT", "BAT")


@dataclass
class SyntheticWSI:
    """Container for one synthetic adipose-tissue mini WSI plus its ground truth.

    ``synthesize_wsi`` in this module fills every field; the annotation lists
    hold plain dicts whose coordinates are in pixels of ``image``.
    """

    # Stacked marker image, shape (H, W, len(CHANNELS)), float32 clipped to
    # [0, 1]; channel order follows CHANNELS.
    image: np.ndarray
    # One dict per adipocyte: cx, cy, r, dead (bool).
    adipocytes: List[Dict]
    # One dict per crown-like structure around a dead adipocyte:
    # cx, cy, r, n_nuclei (as populated by synthesize_wsi).
    cls_truth: List[Dict]
    # One dict per drawn nucleus: cx, cy, m1 (bool), near_dead (bool).
    nuclei: List[Dict]
    group: str  # group name; synthesize_wsi validates it against GROUP_PARAMS
    depot: str  # depot name; synthesize_wsi validates it against DEPOTS
    mouse_id: str
    pixel_um: float = 0.5  # assumed physical pixel size: 1 px = 0.5 µm
    # Free-form extras; synthesize_wsi stores size, seed and the
    # depot-adjusted radius_mean here.
    meta: Dict = field(default_factory=dict)


def _seeded_rng(seed: int) -> np.random.Generator:
    return np.random.default_rng(seed)


def _draw_disk(canvas: np.ndarray, cx: float, cy: float, r: float, value: float) -> None:
    h, w = canvas.shape[:2]
    y0 = max(int(cy - r) - 1, 0)
    y1 = min(int(cy + r) + 2, h)
    x0 = max(int(cx - r) - 1, 0)
    x1 = min(int(cx + r) + 2, w)
    if y1 <= y0 or x1 <= x0:
        return
    yy, xx = np.ogrid[y0:y1, x0:x1]
    mask = (yy - cy) ** 2 + (xx - cx) ** 2 <= r * r
    canvas[y0:y1, x0:x1][mask] = value


def _draw_ring(canvas: np.ndarray, cx: float, cy: float, r_inner: float,
               r_outer: float, value: float) -> None:
    h, w = canvas.shape[:2]
    y0 = max(int(cy - r_outer) - 1, 0)
    y1 = min(int(cy + r_outer) + 2, h)
    x0 = max(int(cx - r_outer) - 1, 0)
    x1 = min(int(cx + r_outer) + 2, w)
    if y1 <= y0 or x1 <= x0:
        return
    yy, xx = np.ogrid[y0:y1, x0:x1]
    d2 = (yy - cy) ** 2 + (xx - cx) ** 2
    mask = (d2 >= r_inner * r_inner) & (d2 <= r_outer * r_outer)
    canvas[y0:y1, x0:x1][mask] = np.maximum(canvas[y0:y1, x0:x1][mask], value)


def _pack_circles(rng: np.random.Generator, h: int, w: int,
                  radius_mean: float, radius_std: float = 8.0,
                  jitter: float = 4.0, padding: int = 2) -> List[Tuple[float, float, float]]:
    """격자 + 흔들림으로 비겹침 원형 adipocyte 배치를 생성."""
    spacing = radius_mean * 2 + padding
    out: List[Tuple[float, float, float]] = []
    y = radius_mean
    row = 0
    while y <= h + radius_mean * 0.3:
        offset = (row % 2) * (spacing / 2)
        x = radius_mean + offset
        while x <= w + radius_mean * 0.3:
            r = max(15.0, float(rng.normal(radius_mean, radius_std)))
            cx = x + float(rng.uniform(-jitter, jitter))
            cy = y + float(rng.uniform(-jitter, jitter))
            out.append((cx, cy, r))
            x += spacing
        y += spacing * math.sqrt(3) / 2
        row += 1
    return out


def _paint_dead_adipocyte(rng: np.random.Generator, views: Dict[str, np.ndarray],
                          cx: float, cy: float, r: float, m1_ratio: float,
                          cls_truth: List[Dict], nuclei: List[Dict]) -> None:
    """Render one crown-like structure around a dead adipocyte, in place.

    Draws an F4/80 shell plus 5-10 nuclei on a ring just outside the cell.
    Each nucleus is tagged M1 (CD11c) with probability *m1_ratio*, otherwise
    M2 (CD206). Appends the CLS to *cls_truth* and every nucleus to *nuclei*.
    The RNG draw order here is part of the reproducibility contract.
    """
    # Dead cell: no perilipin membrane is drawn, only the macrophage crown.
    _draw_ring(views["F480"], cx, cy, r * 1.0, r * 1.25, value=0.85)
    n_nuc = int(rng.integers(5, 11))
    cls_truth.append({"cx": cx, "cy": cy, "r": r, "n_nuclei": n_nuc})
    for k in range(n_nuc):
        # Evenly spaced angles with small angular and radial jitter.
        theta = 2 * math.pi * k / n_nuc + float(rng.uniform(-0.2, 0.2))
        nr = r * 1.12 + float(rng.uniform(-3, 3))
        nx, ny = cx + nr * math.cos(theta), cy + nr * math.sin(theta)
        _draw_disk(views["DAPI"], nx, ny, 4.5, 0.95)
        is_m1 = bool(rng.uniform() < m1_ratio)
        # Polarization marker co-localized with the nucleus.
        _draw_disk(views["CD11c" if is_m1 else "CD206"], nx, ny, 5.0, 0.85)
        nuclei.append({"cx": nx, "cy": ny, "m1": is_m1, "near_dead": True})


def _paint_live_adipocyte(rng: np.random.Generator, views: Dict[str, np.ndarray],
                          cx: float, cy: float, r: float,
                          nuclei: List[Dict]) -> None:
    """Render a viable adipocyte's perilipin+ membrane and maybe one stromal nucleus.

    With probability 0.25 a single DAPI nucleus is placed just outside the
    membrane and appended to *nuclei* (not M1, not near a dead cell).
    """
    _draw_ring(views["perilipin"], cx, cy, r * 0.92, r * 1.0, value=0.9)
    if not rng.uniform() < 0.25:
        return
    theta = float(rng.uniform(0, 2 * math.pi))
    nr = r * 1.08
    nx, ny = cx + nr * math.cos(theta), cy + nr * math.sin(theta)
    _draw_disk(views["DAPI"], nx, ny, 4.0, 0.7)
    nuclei.append({"cx": nx, "cy": ny, "m1": False, "near_dead": False})


def synthesize_wsi(group: str, depot: str, mouse_id: str,
                   size: int = 512, seed: int = 42) -> SyntheticWSI:
    """Generate one synthetic multi-channel adipose-tissue mini WSI.

    Args:
        group: key of GROUP_PARAMS; selects dead-cell fraction, M1 ratio and
            mean adipocyte radius.
        depot: one of DEPOTS; iWAT scales the mean radius by 0.9, BAT by 0.55,
            eWAT keeps it unchanged.
        mouse_id: identifier stored on the result.
        size: side length in pixels of the square image.
        seed: seed for the NumPy generator; identical arguments reproduce the
            same image and annotations.

    Returns:
        SyntheticWSI whose float32 image has shape (size, size, len(CHANNELS))
        and is clipped to [0, 1], together with the ground-truth adipocytes,
        crown-like structures and nuclei.

    Raises:
        ValueError: if *group* or *depot* is unknown.
    """
    if group not in GROUP_PARAMS:
        raise ValueError(f"unknown group: {group}")
    if depot not in DEPOTS:
        raise ValueError(f"unknown depot: {depot}")

    params = GROUP_PARAMS[group]
    rng = _seeded_rng(seed)

    # Depot-dependent size correction. BAT is multilocular, so its cells are
    # deliberately rendered much smaller; eWAT keeps the group mean.
    radius_mean = params["adipo_radius_mean"]
    depot_scale = {"iWAT": 0.9, "BAT": 0.55}.get(depot)
    if depot_scale is not None:
        radius_mean *= depot_scale

    # Faint Gaussian background in every channel.
    img = np.zeros((size, size, len(CHANNELS)), dtype=np.float32)
    img += rng.normal(0.05, 0.01, img.shape).astype(np.float32)
    img = np.clip(img, 0.0, 1.0)

    # Per-channel 2-D views; painting into them writes into img.
    views = {name: img[..., CHANNELS.index(name)] for name in CHANNELS}

    adipocytes: List[Dict] = []
    cls_truth: List[Dict] = []
    nuclei: List[Dict] = []
    dead_frac = params["dead_frac"]
    m1_ratio = params["m1_ratio"]

    for cx, cy, r in _pack_circles(rng, size, size, radius_mean=radius_mean, radius_std=8.0):
        is_dead = bool(rng.uniform() < dead_frac)
        adipocytes.append({"cx": cx, "cy": cy, "r": r, "dead": is_dead})
        if is_dead:
            _paint_dead_adipocyte(rng, views, cx, cy, r, m1_ratio, cls_truth, nuclei)
        else:
            _paint_live_adipocyte(rng, views, cx, cy, r, nuclei)

    return SyntheticWSI(
        image=np.clip(img, 0.0, 1.0),
        adipocytes=adipocytes,
        cls_truth=cls_truth,
        nuclei=nuclei,
        group=group,
        depot=depot,
        mouse_id=mouse_id,
        meta={"size": size, "seed": seed, "radius_mean": radius_mean},
    )


def build_demo_cohort(size: int = 512, base_seed: int = 42) -> List[SyntheticWSI]:
    """Build the demo cohort: every group × 3 mice × every depot.

    With the module defaults this is 3 × 3 × 3 = 27 mini WSIs. Slides are
    ordered group → mouse → depot, and slide ``i`` in that order is generated
    with ``seed = base_seed + i``, so each slide differs yet stays reproducible.
    """
    layout = [
        (group, f"{group}_m{mouse_no}", depot)
        for group in GROUP_PARAMS
        for mouse_no in range(1, 4)
        for depot in DEPOTS
    ]
    return [
        synthesize_wsi(group=group, depot=depot, mouse_id=mouse_id,
                       size=size, seed=base_seed + i)
        for i, (group, mouse_id, depot) in enumerate(layout)
    ]
