"""PancIsletMass v0 — islet auto-detection (heuristic).

Instead of running Cellpose 2, segments synthetic multi-channel IHC images with
an intensity threshold followed by connected-component labelling. Uses
scikit-image when it is installed; otherwise falls back to a pure-numpy
4-connectivity flood fill.
"""

from __future__ import annotations

from dataclasses import dataclass
from typing import List, Tuple

import numpy as np

try:
    from skimage import measure, morphology  # type: ignore

    _HAS_SKIMAGE = True
except Exception:  # pragma: no cover - optional
    _HAS_SKIMAGE = False


@dataclass
class IsletROI:
    """Measurements of one detected islet region, in pixel units.

    ``bbox`` is ``(ymin, xmin, ymax, xmax)``; ``detect_islets`` fills the
    upper bounds as exclusive (one past the last row/column of the region).
    """

    label: int  # id of this region in the label image returned by detect_islets
    cy: float  # centroid row
    cx: float  # centroid column
    area_px: int  # number of pixels in the region
    perimeter_px: float  # perimeter estimate in pixels
    circularity: float  # 4*pi*area / perimeter**2 (small epsilon in the denominator)
    bbox: Tuple[int, int, int, int]  # ymin, xmin, ymax, xmax


def _islet_score(channels: np.ndarray, channel_names: Tuple[str, ...]) -> np.ndarray:
    """Build a smoothed islet-likelihood map from the endocrine marker channels.

    The insulin, glucagon and somatostatin channels are combined with weights
    1.0, 0.6 and 0.4 respectively, then averaged over a 5x5 box via
    ``_box_smooth``.

    Args:
        channels: per-channel images indexed as ``channels[i]`` in the same
            order as ``channel_names``.
        channel_names: channel labels; must include "insulin", "glucagon" and
            "somatostatin".

    Raises:
        ValueError: if one of the three marker names is not in ``channel_names``.
    """
    # Resolve every marker position up front, in insulin/glucagon/somatostatin order.
    ins, glu, sst = (
        channel_names.index(marker)
        for marker in ("insulin", "glucagon", "somatostatin")
    )
    weighted = channels[ins] * 1.0 + channels[glu] * 0.6 + channels[sst] * 0.4
    return _box_smooth(weighted, k=5)


def _box_smooth(img: np.ndarray, k: int = 5) -> np.ndarray:
    """Cheap box filter (separable) without scipy."""
    if k <= 1:
        return img
    pad = k // 2
    padded = np.pad(img, pad, mode="reflect")
    csum = padded.cumsum(axis=0).cumsum(axis=1)
    h, w = img.shape
    # integral image trick
    a = csum[k - 1 : k - 1 + h, k - 1 : k - 1 + w]
    b = csum[: h, k - 1 : k - 1 + w] if pad > 0 else 0
    c = csum[k - 1 : k - 1 + h, :w] if pad > 0 else 0
    d = csum[:h, :w] if pad > 0 else 0
    out = (a - b - c + d) / (k * k)
    return out.astype(np.float32)


def _label_cc_numpy(mask: np.ndarray) -> np.ndarray:
    """4-connectivity connected components in pure numpy (iterative)."""
    h, w = mask.shape
    labels = np.zeros((h, w), dtype=np.int32)
    cur = 0
    for y in range(h):
        for x in range(w):
            if mask[y, x] and labels[y, x] == 0:
                cur += 1
                stack = [(y, x)]
                while stack:
                    yy, xx = stack.pop()
                    if yy < 0 or yy >= h or xx < 0 or xx >= w:
                        continue
                    if not mask[yy, xx] or labels[yy, xx] != 0:
                        continue
                    labels[yy, xx] = cur
                    stack.extend([(yy + 1, xx), (yy - 1, xx), (yy, xx + 1), (yy, xx - 1)])
    return labels


def _perimeter_numpy(mask: np.ndarray) -> float:
    if mask.sum() == 0:
        return 0.0
    # 4-neighbor boundary count
    m = mask.astype(np.int32)
    boundary = (
        (m - np.roll(m, 1, axis=0) == 1).sum()
        + (m - np.roll(m, -1, axis=0) == 1).sum()
        + (m - np.roll(m, 1, axis=1) == 1).sum()
        + (m - np.roll(m, -1, axis=1) == 1).sum()
    )
    return float(boundary)


def _circularity(area: float, perimeter: float) -> float:
    """Return 4*pi*area / perimeter**2; a tiny epsilon keeps it finite for perimeter 0."""
    return float(4 * np.pi * area / (perimeter * perimeter + 1e-9))


def _detect_skimage(
    mask: np.ndarray, min_area: int
) -> Tuple[np.ndarray, List[IsletROI]]:
    """scikit-image backend: clean *mask*, label 8-connected components, keep large ones.

    Kept components are renumbered 1..N in label order, so the returned label
    image agrees with ``roi.label`` and with the numpy backend.
    """
    # Remove components smaller than min_area // 4, then close small gaps.
    mask = morphology.remove_small_objects(mask, min_size=min_area // 4)
    mask = morphology.binary_closing(mask, morphology.disk(3))
    labels = measure.label(mask, connectivity=2)
    out_labels = np.zeros(labels.shape, dtype=np.int32)
    rois: List[IsletROI] = []
    for prop in measure.regionprops(labels):
        if prop.area < min_area:
            continue
        new_id = len(rois) + 1
        # Paint only this region's pixels instead of comparing the full image.
        out_labels[prop.coords[:, 0], prop.coords[:, 1]] = new_id
        cy, cx = prop.centroid
        perim = float(prop.perimeter) if prop.perimeter else 1.0
        ymin, xmin, ymax, xmax = prop.bbox
        rois.append(
            IsletROI(
                label=new_id,
                cy=float(cy),
                cx=float(cx),
                area_px=int(prop.area),
                perimeter_px=perim,
                circularity=_circularity(prop.area, perim),
                bbox=(int(ymin), int(xmin), int(ymax), int(xmax)),
            )
        )
    return out_labels, rois


def _detect_numpy(
    mask: np.ndarray, min_area: int
) -> Tuple[np.ndarray, List[IsletROI]]:
    """Pure-numpy backend: label 4-connected components of *mask*, keep large ones.

    Unlike the scikit-image backend, no small-object removal or closing is
    applied and connectivity is 4 rather than 8, so the two backends can
    segment the same mask differently. Kept components are renumbered 1..N.
    """
    labels = _label_cc_numpy(mask)
    out_labels = np.zeros(labels.shape, dtype=np.int32)
    rois: List[IsletROI] = []
    # Pixel count of every label in one pass (index 0 is background), so small
    # components are rejected without a full-image comparison each.
    areas = np.bincount(labels.ravel())
    for lab in range(1, areas.size):
        area = int(areas[lab])
        if area < min_area:
            continue
        region = labels == lab
        ys, xs = np.nonzero(region)
        new_id = len(rois) + 1
        out_labels[region] = new_id
        perim = _perimeter_numpy(region)
        rois.append(
            IsletROI(
                label=new_id,
                cy=float(ys.mean()),
                cx=float(xs.mean()),
                area_px=area,
                perimeter_px=perim,
                circularity=_circularity(area, perim),
                # Exclusive upper bounds, matching skimage's bbox convention.
                bbox=(
                    int(ys.min()),
                    int(xs.min()),
                    int(ys.max()) + 1,
                    int(xs.max()) + 1,
                ),
            )
        )
    return out_labels, rois


def detect_islets(
    channels: np.ndarray,
    channel_names: Tuple[str, ...],
    score_threshold: float = 0.18,
    min_area: int = 800,
) -> Tuple[np.ndarray, List[IsletROI]]:
    """Segment islets from a multi-channel IHC image.

    Args:
        channels: per-channel images indexed as ``channels[i]`` in the same
            order as ``channel_names`` (presumably shape (C, H, W) — confirm).
        channel_names: channel labels; must include "insulin", "glucagon" and
            "somatostatin".
        score_threshold: pixels whose smoothed islet score is strictly above
            this value are treated as foreground.
        min_area: components with fewer pixels than this are discarded.

    Returns:
        ``(label_image, rois)``: ``label_image`` is an int32 array with
        0 = background and 1..N = islets; ``rois[i].label == i + 1`` for
        both the scikit-image and the pure-numpy backend.

    Raises:
        ValueError: if a required marker name is missing from ``channel_names``.
    """
    score = _islet_score(channels, channel_names)
    mask = score > score_threshold
    if _HAS_SKIMAGE:
        return _detect_skimage(mask, min_area)
    return _detect_numpy(mask, min_area)
