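"""Adipocyte segmentation (perilipin ring + dead-adipocyte detection).

Heuristic: threshold the perilipin channel into rings -> fill their interiors
-> connected components -> equivalent-circle fit (radius from component area)
to extract adipocytes. Perilipin-negative regions (no ring, large empty area)
enclosed by F4/80 signal are dead-adipocyte candidates.

Uses scipy.ndimage morphology when scipy is available; otherwise falls back to
numpy (pure-Python component labeling, no dilation or hole filling).
"""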
from __future__ import annotations

from dataclasses import dataclass
from typing import List, Tuple

import numpy as np

try:
    from scipy import ndimage as ndi
    _HAS_SCIPY = True
except Exception:  # pragma: no cover
    ndi = None
    _HAS_SCIPY = False


@dataclass()
class AdipocyteResult:
    """Segmentation output; index ``i`` of every array describes one cell.

    Attributes:
        centers: (N, 2) array of ``(cx, cy)`` pixel coordinates.
        radii: (N,) equivalent-circle radii in pixels.
        dead_mask: (N,) bool array; True marks a dead-adipocyte candidate.
        diameters_um: (N,) diameters in micrometres.
        n_total: number of cells, N.
        n_dead: number of True entries in ``dead_mask``.

    NOTE: the dataclass-generated ``__eq__`` compares the ndarray fields, so
    ``==`` between two distinct instances holding multi-element arrays raises
    ``ValueError`` (ambiguous truth value).
    """

    centers: np.ndarray
    radii: np.ndarray
    dead_mask: np.ndarray
    diameters_um: np.ndarray
    n_total: int
    n_dead: int


def _label_components(mask: np.ndarray) -> Tuple[np.ndarray, int]:
    """Label 4-connected components of a 2-D binary mask.

    Uses ``scipy.ndimage.label`` when scipy is importable. Otherwise it falls
    back to an iterative (stack-based, depth-first) flood fill with
    4-connectivity. In the fallback, labels are numbered in raster order of
    each component's first pixel.

    Args:
        mask: 2-D array; nonzero pixels are foreground.

    Returns:
        ``(labeled, n)``: ``labeled`` holds 0 for background and 1..n for the
        components, and ``n`` is the number of components.
    """
    if _HAS_SCIPY:
        labeled, n = ndi.label(mask)
        return labeled, int(n)

    fg = np.asarray(mask, dtype=bool)
    h, w = fg.shape
    labeled = np.zeros(fg.shape, dtype=np.int32)
    n = 0
    # Only foreground pixels can seed a component. np.nonzero yields them in
    # row-major order, so numbering matches a full raster scan while skipping
    # the (usually dominant) background in Python.
    for i, j in zip(*np.nonzero(fg)):
        if labeled[i, j]:
            continue
        n += 1
        labeled[i, j] = n
        stack = [(i, j)]
        while stack:
            y, x = stack.pop()
            for ny, nx in ((y + 1, x), (y - 1, x), (y, x + 1), (y, x - 1)):
                # Bounds/visited check before pushing, and label on push, so
                # each pixel enters the stack at most once.
                if (0 <= ny < h and 0 <= nx < w
                        and fg[ny, nx] and not labeled[ny, nx]):
                    labeled[ny, nx] = n
                    stack.append((ny, nx))
    return labeled, n


def _component_stats(labeled: np.ndarray,
                     n: int) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Return ``(areas, cy, cx)`` for labels ``1..n`` of ``labeled``.

    One ``np.bincount`` pass per statistic replaces a full-image
    ``labeled == lbl`` comparison for every label (O(H*W) instead of
    O(n*H*W)). Areas are pixel counts; centroids are mean row/column indices.
    """
    flat = labeled.ravel()
    rows, cols = np.indices(labeled.shape)
    areas = np.bincount(flat, minlength=n + 1)[1:n + 1].astype(np.float64)
    sum_y = np.bincount(flat, weights=rows.ravel(), minlength=n + 1)[1:n + 1]
    sum_x = np.bincount(flat, weights=cols.ravel(), minlength=n + 1)[1:n + 1]
    # Labels produced by a labeling pass are never empty; errstate only
    # silences a divide warning should that ever stop holding.
    with np.errstate(invalid="ignore", divide="ignore"):
        cy = sum_y / areas
        cx = sum_x / areas
    return areas, cy, cx


def _circles_from_labels(labeled: np.ndarray, n: int, min_area: float,
                         r_min: float, r_max: float
                         ) -> List[Tuple[float, float, float]]:
    """Return ``(cx, cy, r)`` for each component that passes the filters.

    A component is kept, in label order, when it has at least ``min_area``
    pixels and its equivalent-circle radius ``r = sqrt(area / pi)`` lies in
    ``[r_min, r_max]``.
    """
    areas, cy, cx = _component_stats(labeled, n)
    radii = np.sqrt(areas / np.pi)
    keep = (areas >= min_area) & (radii >= r_min) & (radii <= r_max)
    return [(float(cx[k]), float(cy[k]), float(radii[k]))
            for k in np.flatnonzero(keep)]


def segment_adipocytes(image: np.ndarray, perilipin_idx: int = 0,
                       pixel_um: float = 0.5,
                       min_radius_px: float = 18.0,
                       max_radius_px: float = 110.0,
                       *,
                       f480_idx: int = 1,
                       perilipin_thresh: float = 0.4,
                       f480_thresh: float = 0.35) -> AdipocyteResult:
    """Segment adipocytes from the perilipin channel and flag dead candidates.

    Live adipocytes:
        1. Threshold perilipin (> ``perilipin_thresh``).
        2. With scipy, apply 3 dilation iterations and fill holes.
        3. Label connected components.
        4. Keep components with >= 200 px and an equivalent-circle radius
           in ``[min_radius_px, max_radius_px]``.

    Dead-adipocyte candidates (only if channel ``f480_idx`` exists):
        1. Threshold F4/80 (> ``f480_thresh``).
        2. With scipy, apply 8 dilation iterations, fill holes, then
           4 erosion iterations.
        3. Drop perilipin-ring pixels and label the rest.
        4. Keep components with >= 800 px and a radius in
           ``[0.7 * min_radius_px, 1.5 * max_radius_px]``.

    Args:
        image: (H, W, C) channel-last array. Presumably intensities are
            normalised to roughly [0, 1], given the default thresholds;
            confirm with the caller.
        perilipin_idx: channel index of perilipin.
        pixel_um: micrometres per pixel; only used for ``diameters_um``.
        min_radius_px, max_radius_px: accepted radius range for live cells.
        f480_idx: channel index of F4/80. The default of 1 follows the
            CHANNELS ordering defined elsewhere in the project.
        perilipin_thresh, f480_thresh: binarisation thresholds.

    Returns:
        AdipocyteResult with live cells first, then dead candidates.

    Raises:
        ValueError: if ``image`` is not 3-D.
    """
    image = np.asarray(image)
    if image.ndim != 3:
        raise ValueError(
            f"image must be a (H, W, C) array; got shape {image.shape}")

    bin_ring = image[..., perilipin_idx] > perilipin_thresh

    if _HAS_SCIPY:
        # Dilate to close small gaps in the ring, then fill its interior.
        dilated = ndi.binary_dilation(bin_ring, iterations=3)
        filled = ndi.binary_fill_holes(dilated)
    else:  # pragma: no cover
        # No hole filling without scipy: components are the ring pixels only.
        filled = bin_ring.copy()

    labeled, n = _label_components(filled)
    # Each entry: (cx, cy, r, is_dead)
    cells = [(cx, cy, r, False) for cx, cy, r in
             _circles_from_labels(labeled, n, 200,
                                  min_radius_px, max_radius_px)]

    if image.shape[-1] > f480_idx:
        f480_bin = image[..., f480_idx] > f480_thresh
        if _HAS_SCIPY:
            # Dilate generously to close the F4/80 shell, fill it, then erode
            # back so that only genuinely enclosed areas remain.
            f480_dilated = ndi.binary_dilation(f480_bin, iterations=8)
            f480_filled = ndi.binary_fill_holes(f480_dilated)
            f480_solid = ndi.binary_erosion(f480_filled, iterations=4)
        else:  # pragma: no cover
            f480_solid = f480_bin
        # Candidates are F4/80-enclosed pixels that are not perilipin ring.
        # NOTE(review): only ring pixels are excluded, so the lumen of a live
        # adipocyte that sits inside a filled F4/80 area can also be reported
        # as dead. Excluding `filled` instead would drop dead cells enclosed
        # by neighbouring rings. Confirm the intended trade-off.
        dead_region = f480_solid & ~bin_ring
        labeled_d, nd = _label_components(dead_region)
        # Dead adipocytes should be large, roughly adipocyte-sized components.
        cells.extend((cx, cy, r, True) for cx, cy, r in
                     _circles_from_labels(labeled_d, nd, 800,
                                          min_radius_px * 0.7,
                                          max_radius_px * 1.5))

    if not cells:
        return AdipocyteResult(
            centers=np.zeros((0, 2), dtype=np.float32),
            radii=np.zeros((0,), dtype=np.float32),
            dead_mask=np.zeros((0,), dtype=bool),
            diameters_um=np.zeros((0,), dtype=np.float32),
            n_total=0, n_dead=0,
        )

    centers_arr = np.array([(cx, cy) for cx, cy, _, _ in cells],
                           dtype=np.float32)
    radii_arr = np.array([r for _, _, r, _ in cells], dtype=np.float32)
    dead_arr = np.array([dead for _, _, _, dead in cells], dtype=bool)
    return AdipocyteResult(
        centers=centers_arr,
        radii=radii_arr,
        dead_mask=dead_arr,
        diameters_um=2.0 * radii_arr * pixel_um,
        n_total=len(cells),
        n_dead=int(dead_arr.sum()),
    )
