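"""Crown-Like Structure (CLS) auto-detection.

Heuristic:
1) Extract nuclei candidates from the DAPI channel as centroids of thresholded
   connected components.
2) A dead-adipocyte candidate is called a CLS when at least 3 nuclei lie within
   ~1.4 × its radius of its center AND F4/80+ shell coverage reaches at least
   50% of its perimeter (defaults: ``min_nuclei=3``, ``radius_factor=1.4``,
   ``shell_min=0.5``).
3) CLS density = count / mm² (converted from pixels via ``pixel_um``).
"""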
from __future__ import annotations

from dataclasses import dataclass
from typing import List

import math
import numpy as np

try:
    from scipy import ndimage as ndi
    _HAS_SCIPY = True
except Exception:  # pragma: no cover
    ndi = None
    _HAS_SCIPY = False


@dataclass
class CLSResult:
    """Result of :func:`detect_cls`.

    The per-CLS arrays share a leading length K (== ``n_cls``); row k of each
    array describes the k-th detected crown-like structure.
    """
    cls_centers: np.ndarray    # (K, 2) float32 cx, cy: adipocyte center in pixels (x = column, y = row)
    cls_radii: np.ndarray      # (K,) float32 adipocyte radius in pixels
    nuclei_per_cls: np.ndarray # (K,) int32 nuclei counted within radius_factor * radius of the center
    shell_coverage: np.ndarray # (K,) float32 fraction of perimeter samples that are F4/80+, 0..1
    n_cls: int                 # number of detected CLS (K)
    area_mm2: float            # whole-image area in mm², from image H x W and pixel_um
    density_per_mm2: float     # n_cls / area_mm2 (0.0 when area_mm2 <= 0)


def _label_components_4conn(mask: np.ndarray) -> np.ndarray:
    """Label 4-connected foreground components of a 2-D boolean mask (NumPy only).

    SciPy-free stand-in for ``ndi.label``, whose default 2-D structuring element
    is also 4-connected. Every foreground pixel starts with its raster index + 1
    and repeatedly takes the minimum label among itself and its 4 neighbours
    until nothing changes. Each component therefore ends up labelled by the
    raster index of its first pixel, which gives the same component order as
    ``ndi.label``.

    Background pixels get 0. Labels are positive but not consecutive. The number
    of iterations grows with component diameter, so this is slower than SciPy
    on large blobs.
    """
    h, w = mask.shape
    sentinel = h * w + 1  # larger than any foreground label
    labels = np.where(mask,
                      np.arange(1, h * w + 1, dtype=np.int64).reshape(h, w),
                      sentinel)
    while True:
        new = labels.copy()
        np.minimum(new[1:, :], labels[:-1, :], out=new[1:, :])
        np.minimum(new[:-1, :], labels[1:, :], out=new[:-1, :])
        np.minimum(new[:, 1:], labels[:, :-1], out=new[:, 1:])
        np.minimum(new[:, :-1], labels[:, 1:], out=new[:, :-1])
        new[~mask] = sentinel  # background must never adopt a neighbour's label
        if np.array_equal(new, labels):
            break
        labels = new
    return np.where(mask, labels, 0)


def _detect_nuclei(dapi: np.ndarray, threshold: float = 0.4,
                   min_size: int = 5) -> np.ndarray:
    """Detect nuclei in a DAPI channel as centroids of bright connected blobs.

    Pixels with ``dapi > threshold`` are grouped into 4-connected components.
    ``scipy.ndimage.label`` is used when SciPy is available, otherwise
    ``_label_components_4conn``. Components smaller than ``min_size`` pixels
    are discarded.

    Args:
        dapi: 2-D DAPI intensity image.
        threshold: strict (``>``) intensity cutoff for foreground pixels.
        min_size: minimum component area, in pixels, to count as a nucleus.

    Returns:
        ``(N, 2)`` float32 array of ``(cx, cy)`` centroids (x = column,
        y = row), in label order; shape ``(0, 2)`` when nothing is found.
    """
    bin_mask = np.asarray(dapi) > threshold
    if _HAS_SCIPY:
        labeled, _ = ndi.label(bin_mask)
    else:
        # A real labelling (not one global centroid) so the fallback can still
        # report several nuclei and honour min_size.
        labeled = _label_components_4conn(bin_mask)

    ys, xs = np.nonzero(labeled)
    if ys.size == 0:
        return np.zeros((0, 2), dtype=np.float32)

    # Map (possibly sparse) labels to 0..K-1, preserving label order.
    _, comp = np.unique(labeled[ys, xs], return_inverse=True)
    sizes = np.bincount(comp)
    cx = np.bincount(comp, weights=xs) / sizes
    cy = np.bincount(comp, weights=ys) / sizes
    keep = sizes >= min_size
    return np.column_stack((cx[keep], cy[keep])).astype(np.float32)


def _shell_coverage(f480: np.ndarray, cx: float, cy: float, r: float,
                    n_samples: int = 24, threshold: float = 0.3,
                    radial_factors: tuple = (0.85, 1.0, 1.15)) -> float:
    """Fraction of perimeter directions in which an F4/80+ signal is found.

    ``n_samples`` equally spaced directions around ``(cx, cy)`` are probed. A
    direction counts as a hit if ``f480 > threshold`` at ANY of the radii
    ``r * rf`` for ``rf`` in ``radial_factors``. Probing several radii absorbs
    the mismatch between the detected and the true adipocyte radius. Sample
    points are rounded to the nearest pixel (round-half-to-even), and points
    falling outside the image count as misses.

    Args:
        f480: 2-D F4/80 intensity image, indexed ``f480[y, x]``.
        cx, cy: circle center in pixels (x = column, y = row).
        r: circle radius in pixels.
        n_samples: number of angular directions sampled; must be >= 1.
        threshold: strict (``>``) intensity cutoff for an F4/80+ pixel.
        radial_factors: multiples of ``r`` probed along each direction.

    Returns:
        Hit fraction in ``[0, 1]``.

    Raises:
        ValueError: if ``n_samples < 1``.
    """
    if n_samples < 1:
        raise ValueError(f"n_samples must be >= 1, got {n_samples}")
    h, w = f480.shape
    theta = 2.0 * np.pi * np.arange(n_samples) / n_samples          # (n_samples,)
    radii = (r * np.asarray(radial_factors, dtype=np.float64))[:, None]  # (R, 1)
    xs = np.rint(cx + radii * np.cos(theta)).astype(np.int64)      # (R, n_samples)
    ys = np.rint(cy + radii * np.sin(theta)).astype(np.int64)
    inside = (xs >= 0) & (xs < w) & (ys >= 0) & (ys < h)
    positive = np.zeros(xs.shape, dtype=bool)
    positive[inside] = f480[ys[inside], xs[inside]] > threshold
    # A direction is a hit if any radial position along it is positive.
    return float(np.count_nonzero(positive.any(axis=0))) / n_samples


def detect_cls(image: np.ndarray,
               adipo_centers: np.ndarray,
               adipo_radii: np.ndarray,
               dead_mask: np.ndarray,
               channels: tuple = ("perilipin", "F480", "CD11c", "CD206", "DAPI"),
               pixel_um: float = 0.5,
               min_nuclei: int = 3,
               shell_min: float = 0.5,
               radius_factor: float = 1.4,
               *,
               nuclei_threshold: float = 0.4,
               nuclei_min_size: int = 5,
               f480_threshold: float = 0.3) -> CLSResult:
    """Heuristically detect crown-like structures (CLS) around dead adipocytes.

    Each candidate ``i`` with a truthy ``dead_mask[i]`` is examined in turn:

    1. Count the DAPI nuclei whose centroid lies within
       ``radius_factor * adipo_radii[i]`` of ``adipo_centers[i]``.
    2. If at least ``min_nuclei`` nuclei are found, measure the F4/80 shell
       coverage around the adipocyte perimeter.
    3. If that coverage is at least ``shell_min``, report the candidate as a
       CLS.

    If no nuclei are detected anywhere in the image, no CLS is reported.

    Args:
        image: ``(H, W, C)`` multichannel image; channel order is ``channels``.
        adipo_centers: ``(N, >=2)`` array; columns 0/1 are ``cx, cy`` in pixels.
        adipo_radii: ``(N,)`` adipocyte radii in pixels.
        dead_mask: per-candidate flags; only truthy entries are examined.
        channels: channel names; must contain ``"DAPI"`` and ``"F480"``.
        pixel_um: pixel size in micrometres, used to convert area to mm².
        min_nuclei: minimum nuclei count near the adipocyte.
        shell_min: minimum F4/80+ perimeter coverage (0..1).
        radius_factor: nuclei search radius as a multiple of the adipocyte radius.
        nuclei_threshold: DAPI intensity cutoff forwarded to ``_detect_nuclei``.
        nuclei_min_size: minimum nucleus area (pixels) forwarded to ``_detect_nuclei``.
        f480_threshold: F4/80 intensity cutoff forwarded to ``_shell_coverage``.

    Returns:
        CLSResult with per-CLS arrays and the CLS density over the whole image.

    Raises:
        ValueError: if ``channels`` lacks ``"DAPI"`` or ``"F480"``.
    """
    try:
        dapi_idx = channels.index("DAPI")
        f480_idx = channels.index("F480")
    except ValueError as err:
        raise ValueError(
            f"channels must include 'DAPI' and 'F480', got {channels!r}"
        ) from err

    h, w = image.shape[:2]
    dapi = image[..., dapi_idx]
    f480 = image[..., f480_idx]
    adipo_centers = np.asarray(adipo_centers)

    nuclei_pts = _detect_nuclei(dapi, threshold=nuclei_threshold,
                                min_size=nuclei_min_size)

    cls_centers = []
    cls_radii = []
    nuclei_counts = []
    shell_covs = []

    # No detected nuclei in the image => no CLS, regardless of min_nuclei.
    if nuclei_pts.shape[0] > 0:
        for i in range(adipo_centers.shape[0]):
            if not dead_mask[i]:
                continue
            cx, cy = float(adipo_centers[i, 0]), float(adipo_centers[i, 1])
            r = float(adipo_radii[i])
            d2 = (nuclei_pts[:, 0] - cx) ** 2 + (nuclei_pts[:, 1] - cy) ** 2
            n_within = int(np.count_nonzero(d2 <= (r * radius_factor) ** 2))
            # Cheap nuclei test first; perimeter sampling only for survivors.
            if n_within < min_nuclei:
                continue
            cov = _shell_coverage(f480, cx, cy, r, threshold=f480_threshold)
            if cov < shell_min:
                continue
            cls_centers.append((cx, cy))
            cls_radii.append(r)
            nuclei_counts.append(n_within)
            shell_covs.append(cov)

    n_cls = len(cls_centers)
    # reshape keeps the (0, 2) shape when nothing was detected.
    cls_centers_arr = np.array(cls_centers, dtype=np.float32).reshape(n_cls, 2)
    cls_radii_arr = np.array(cls_radii, dtype=np.float32)
    nuclei_counts_arr = np.array(nuclei_counts, dtype=np.int32)
    shell_covs_arr = np.array(shell_covs, dtype=np.float32)

    # Area of the whole image (H x W pixels) in mm².
    area_mm2 = (h * pixel_um / 1000.0) * (w * pixel_um / 1000.0)
    density = n_cls / area_mm2 if area_mm2 > 0 else 0.0

    return CLSResult(
        cls_centers=cls_centers_arr,
        cls_radii=cls_radii_arr,
        nuclei_per_cls=nuclei_counts_arr,
        shell_coverage=shell_covs_arr,
        n_cls=n_cls,
        area_mm2=float(area_mm2),
        density_per_mm2=float(density),
    )
