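"""Synthetic paired liver-tissue WSI generator.

Works on 512x512 numpy arrays.
- portal tract = 3-component cluster (portal vein + hepatic artery + bile duct)
- DR Type 1: small ductule with a lumen
- DR Type 2: hepatocyte-like morphology with scattered CK19+ signal
- DR Type 3: large mass-like cluster
- channels: pseudo H&E, CK19, SOX9, Sirius Red

Per-group patterns (counts are per portal tract):
- control: little DR, at most one T1 and no T2/T3 (portal tracts mostly normal)
- CDAHFD: many T1 and T2, with an occasional T3
- CDAHFD+resmetirom: a few T1, occasionally one T2, no T3 (mimics recovery)
"""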
from __future__ import annotations

import math
from dataclasses import dataclass, field
from typing import Dict, List, Tuple

import numpy as np


# Side length, in pixels, of every square synthetic slide (WSI_SIZE x WSI_SIZE).
WSI_SIZE: int = 512
# Default seed used by make_slide / make_demo_cohort for reproducible output.
DEFAULT_SEED: int = 42


@dataclass
class PortalTract:
    """Geometry of one synthetic portal tract and its ductular reaction.

    Centre coordinates and radii are in pixels of the slide image.
    """

    cx: int  # column of the portal-vein centre
    cy: int  # row of the portal-vein centre
    pv_radius: int  # portal vein radius
    ha_radius: int  # hepatic artery radius
    bd_radius: int  # bile duct radius
    # Ductular-reaction (DR) structures, one dict each; this module fills the
    # keys "type" ("T1"/"T2"/"T3"), "x", "y", "size" and "lumen".
    dr_structures: List[Dict[str, object]] = field(default_factory=list)


@dataclass
class SynthSlide:
    """One synthetic multi-channel slide together with its annotations.

    The image channels are expected to share the same height and width.
    """

    # Image channels, values in [0, 1].
    he: np.ndarray  # pseudo H&E RGB image, (H, W, 3) float32
    ck19: np.ndarray  # CK19 intensity, (H, W) float32
    sox9: np.ndarray  # SOX9 intensity, (H, W) float32
    sirius: np.ndarray  # Sirius Red (collagen) intensity, (H, W) float32
    # Portal tracts (including their DR structures) rendered into the channels.
    portal_tracts: List[PortalTract]
    # Identifiers.
    group: str
    mouse_id: str
    slide_id: str


def _draw_disk(arr: np.ndarray, cx: int, cy: int, r: int, value: float, blur: float = 0.0) -> None:
    """disk 그리기 (정수 좌표). blur > 0 이면 가장자리 부드럽게."""
    h, w = arr.shape[:2]
    y, x = np.ogrid[:h, :w]
    dist = np.sqrt((x - cx) ** 2 + (y - cy) ** 2)
    if blur > 0:
        soft = np.clip(1.0 - (dist - r) / blur, 0.0, 1.0)
        soft[dist > r + blur] = 0.0
        soft[dist <= r] = 1.0
        if arr.ndim == 3:
            for c in range(arr.shape[2]):
                arr[..., c] = np.maximum(arr[..., c], soft * value if isinstance(value, float) else soft * value[c])
        else:
            arr[:] = np.maximum(arr, soft * value)
    else:
        mask = dist <= r
        if arr.ndim == 3:
            for c in range(arr.shape[2]):
                arr[..., c][mask] = value[c] if not isinstance(value, float) else value
        else:
            arr[mask] = value


def _draw_ellipse(arr: np.ndarray, cx: int, cy: int, ra: int, rb: int, angle: float, value: float) -> None:
    h, w = arr.shape[:2]
    y, x = np.ogrid[:h, :w]
    ca, sa = math.cos(angle), math.sin(angle)
    xr = (x - cx) * ca + (y - cy) * sa
    yr = -(x - cx) * sa + (y - cy) * ca
    mask = (xr / max(ra, 1)) ** 2 + (yr / max(rb, 1)) ** 2 <= 1.0
    if arr.ndim == 3:
        for c in range(arr.shape[2]):
            arr[..., c][mask] = value[c] if not isinstance(value, float) else value
    else:
        arr[mask] = value


def _hex_to_rgb(hexstr: str) -> Tuple[float, float, float]:
    h = hexstr.lstrip("#")
    return tuple(int(h[i : i + 2], 16) / 255.0 for i in (0, 2, 4))


def _make_portal_tract(rng: np.random.Generator, group: str) -> PortalTract:
    """Sample one portal tract and its ductular-reaction (DR) structures.

    The centre is drawn uniformly at least 60 px from every image edge,
    followed by the portal-vein, hepatic-artery and bile-duct radii. The number
    of T1/T2/T3 DR structures depends on ``group``; unrecognised groups get
    none. Each DR structure is placed at a random angle and distance from the
    tract centre. Random numbers are drawn in a fixed order, so a generator in
    the same state always yields the same tract.
    """
    margin = 60
    centre_x = int(rng.integers(margin, WSI_SIZE - margin))
    centre_y = int(rng.integers(margin, WSI_SIZE - margin))
    # Keyword arguments are evaluated left to right: pv, then ha, then bd.
    tract = PortalTract(
        cx=centre_x,
        cy=centre_y,
        pv_radius=int(rng.integers(10, 18)),
        ha_radius=int(rng.integers(4, 8)),
        bd_radius=int(rng.integers(3, 6)),
    )

    # (low, high) bounds of rng.integers for the T1, T2 and T3 counts per group.
    # None means the count is fixed at 0 and no random number is consumed.
    count_bounds_by_group = {
        "control": ((0, 2), None, None),
        "CDAHFD": ((2, 5), (1, 4), (0, 2)),
        "CDAHFD+resmetirom": ((0, 3), (0, 2), None),
    }
    count_bounds = count_bounds_by_group.get(group, (None, None, None))
    counts = [
        0 if bounds is None else int(rng.integers(*bounds)) for bounds in count_bounds
    ]

    # Placement window around the centre: a ~200 um buffer (~30-40 px) is
    # assumed, and 35 px is used.
    buf = 35
    # (label, distance bounds, size bounds, has lumen) for T1, T2, T3.
    dr_specs = (
        ("T1", (20, buf), (3, 6), True),  # small ductule with a lumen
        ("T2", (15, buf), (5, 9), False),  # hepatocyte-sized, filled
        ("T3", (10, buf - 5), (10, 16), False),  # large mass-like cluster
    )
    for (label, dist_bounds, size_bounds, has_lumen), count in zip(dr_specs, counts):
        for _ in range(count):
            theta = float(rng.uniform(0, 2 * math.pi))
            offset = int(rng.integers(*dist_bounds))
            size = int(rng.integers(*size_bounds))
            tract.dr_structures.append(
                {
                    "type": label,
                    "x": centre_x + int(offset * math.cos(theta)),
                    "y": centre_y + int(offset * math.sin(theta)),
                    "size": size,
                    "lumen": has_lumen,
                }
            )
    return tract


def _paint_parenchyma(rng: np.random.Generator) -> np.ndarray:
    """Return the base pseudo-H&E image: noisy pink-purple parenchyma with nuclei."""
    he = np.ones((WSI_SIZE, WSI_SIZE, 3), dtype=np.float32)
    he[..., 0] = 0.92 + rng.normal(0, 0.02, (WSI_SIZE, WSI_SIZE))  # R
    he[..., 1] = 0.78 + rng.normal(0, 0.02, (WSI_SIZE, WSI_SIZE))  # G
    he[..., 2] = 0.88 + rng.normal(0, 0.02, (WSI_SIZE, WSI_SIZE))  # B
    he = np.clip(he, 0.0, 1.0)

    # Hepatocyte nuclei: small dark dots of radius 1.
    n_nuclei = 800
    nx = rng.integers(0, WSI_SIZE, n_nuclei)
    ny = rng.integers(0, WSI_SIZE, n_nuclei)
    for x, y in zip(nx, ny):
        _draw_disk(he, int(x), int(y), 1, (0.4, 0.2, 0.5))
    return he


def _paint_dr_structure(
    he: np.ndarray, ck19: np.ndarray, sox9: np.ndarray, ds: Dict
) -> None:
    """Paint one DR structure (T1/T2/T3) into the H&E, CK19 and SOX9 channels."""
    x, y, size, kind = ds["x"], ds["y"], ds["size"], ds["type"]
    if kind == "T1":
        # Small ductule with a lumen: filled disk, then the centre cleared.
        inner = max(size - 2, 0)
        _draw_disk(ck19, x, y, size, 0.9)
        _draw_disk(ck19, x, y, inner, 0.0)
        _draw_disk(sox9, x, y, size, 0.7)
        _draw_disk(sox9, x, y, inner, 0.0)
        _draw_disk(he, x, y, size, (0.92, 0.85, 0.92))
        _draw_disk(he, x, y, inner, (0.98, 0.96, 0.98))
    elif kind == "T2":
        # Hepatocyte morphology with scattered (filled) CK19 signal.
        _draw_disk(ck19, x, y, size, 0.55)
        _draw_disk(sox9, x, y, size, 0.45)
        # H&E stays hepatocyte-like, slightly pink.
        _draw_disk(he, x, y, size, (0.88, 0.7, 0.85))
    elif kind == "T3":
        # Large mass-like cluster.
        _draw_ellipse(ck19, x, y, size, int(size * 0.7), 0.0, 0.92)
        _draw_ellipse(sox9, x, y, size, int(size * 0.7), 0.0, 0.78)
        _draw_disk(he, x, y, size, (0.85, 0.7, 0.88))


def _paint_portal_tract(
    he: np.ndarray,
    ck19: np.ndarray,
    sox9: np.ndarray,
    sirius: np.ndarray,
    pt: PortalTract,
) -> None:
    """Paint the vessels, bile duct, collagen and DR structures of one portal tract."""
    # H&E: portal vein lumen, hepatic artery and bile duct as pale, empty circles
    # at fixed offsets from the portal vein.
    _draw_disk(he, pt.cx, pt.cy, pt.pv_radius, (0.98, 0.96, 0.97))
    ha_x = pt.cx + int(pt.pv_radius * 1.3)
    ha_y = pt.cy + int(pt.pv_radius * 0.4)
    bd_x = pt.cx - int(pt.pv_radius * 0.4)
    bd_y = pt.cy + int(pt.pv_radius * 1.4)
    _draw_disk(he, ha_x, ha_y, pt.ha_radius, (0.98, 0.96, 0.97))
    _draw_disk(he, bd_x, bd_y, pt.bd_radius, (0.96, 0.92, 0.94))

    # CK19: bile-duct epithelium as a thin ring (fill, then clear the lumen).
    _draw_disk(ck19, bd_x, bd_y, pt.bd_radius + 1, 0.85)
    _draw_disk(ck19, bd_x, bd_y, pt.bd_radius - 1, 0.0)

    # SOX9: also positive in the bile duct (overlaps CK19).
    _draw_disk(sox9, bd_x, bd_y, pt.bd_radius + 1, 0.7)
    _draw_disk(sox9, bd_x, bd_y, pt.bd_radius - 1, 0.0)

    # Sirius Red. _draw_disk overwrites its whole disk, so the layers are drawn
    # from the outermost radius inward; otherwise each larger disk would erase
    # the smaller, brighter ones.
    if any(ds["type"] == "T3" for ds in pt.dr_structures):
        # A T3 structure suggests possible bridging fibrosis: a 0.35 band out to
        # pv+17. It is painted first so the inner collagen gradient stays on top,
        # leaving the band as an annulus (pv+8 .. pv+17).
        _draw_disk(sirius, pt.cx, pt.cy, pt.pv_radius + 17, 0.35)
    # Portal collagen gradient: 0.55 out to pv+2, fading by 0.05 per pixel
    # to 0.30 at pv+7.
    for r in range(pt.pv_radius + 7, pt.pv_radius + 1, -1):
        _draw_disk(sirius, pt.cx, pt.cy, r, 0.55 - (r - pt.pv_radius - 2) * 0.05)

    for ds in pt.dr_structures:
        _paint_dr_structure(he, ck19, sox9, ds)


def make_slide(
    group: str,
    mouse_id: str,
    slide_id: str,
    n_portal_tracts: int = 6,
    seed: int = DEFAULT_SEED,
) -> SynthSlide:
    """Generate one synthetic slide; the same ``seed`` reproduces it exactly.

    Args:
        group: experimental group name. It selects the DR pattern sampled for
            each portal tract (see ``_make_portal_tract``).
        mouse_id: identifier copied onto the result.
        slide_id: identifier copied onto the result.
        n_portal_tracts: number of portal tracts to place.
        seed: seed for ``np.random.default_rng``.

    Returns:
        A ``SynthSlide`` whose channels are float32 arrays clipped to [0, 1].
    """
    rng = np.random.default_rng(seed)

    he = _paint_parenchyma(rng)
    ck19 = np.zeros((WSI_SIZE, WSI_SIZE), dtype=np.float32)
    sox9 = np.zeros((WSI_SIZE, WSI_SIZE), dtype=np.float32)
    sirius = np.zeros((WSI_SIZE, WSI_SIZE), dtype=np.float32)

    # Sirius Red baseline noise.
    sirius += np.clip(rng.normal(0.05, 0.02, (WSI_SIZE, WSI_SIZE)), 0.0, 0.2)

    pts: List[PortalTract] = []
    for _ in range(n_portal_tracts):
        pt = _make_portal_tract(rng, group)
        pts.append(pt)
        _paint_portal_tract(he, ck19, sox9, sirius, pt)

    # Light noise on the marker channels, then clamp everything to [0, 1].
    ck19 = np.clip(ck19 + rng.normal(0, 0.02, ck19.shape), 0.0, 1.0).astype(np.float32)
    sox9 = np.clip(sox9 + rng.normal(0, 0.02, sox9.shape), 0.0, 1.0).astype(np.float32)
    sirius = np.clip(sirius, 0.0, 1.0).astype(np.float32)
    he = np.clip(he, 0.0, 1.0).astype(np.float32)

    return SynthSlide(
        he=he,
        ck19=ck19,
        sox9=sox9,
        sirius=sirius,
        portal_tracts=pts,
        group=group,
        mouse_id=mouse_id,
        slide_id=slide_id,
    )


def make_demo_cohort(seed: int = DEFAULT_SEED) -> List[SynthSlide]:
    """Build the demo cohort: 3 groups x 3 mice x 1 representative slide.

    In the real study each mouse has three serial sections (H&E / CK19 /
    Sirius Red); here one synthetic slide carries all channels at once.
    Slides are generated group by group, and the k-th slide uses
    ``seed + k``, so the whole cohort is reproducible from ``seed``.
    """
    plan = [
        (grp, f"{grp}-m{mouse_no}")
        for grp in ("control", "CDAHFD", "CDAHFD+resmetirom")
        for mouse_no in range(1, 4)
    ]
    return [
        make_slide(group=grp, mouse_id=mid, slide_id=f"{mid}-s1", seed=seed + offset)
        for offset, (grp, mid) in enumerate(plan)
    ]
