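"""Heuristic auto-detection of portal tracts.

Portal-triad candidates are found by combining round empty (bright) regions in
the H&E channel (vessel lumens) with adjacent CK19+ structures. The target is
the three-element triad (portal vein + hepatic artery + bile duct);
``detect_portal_tracts`` pairs each portal-vein candidate with nearby CK19+
signal (bile duct) and computes no separate hepatic-artery term.
"""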
from __future__ import annotations

from dataclasses import dataclass
from typing import List, Tuple

import numpy as np


@dataclass
class DetectedPortalTract:
    """A single portal-tract candidate produced by ``detect_portal_tracts``.

    All coordinates are pixel positions in the input image, with x along the
    width (column) axis and y along the height (row) axis.
    """

    # Portal-vein candidate: centre of the bright (empty) region and its
    # approximate radius.
    cx: int
    cy: int
    pv_radius: int
    # Bile-duct estimate: centroid of the CK19+ pixels found near the vein and
    # an approximate radius derived from their pixel count.
    bd_x: int
    bd_y: int
    bd_radius: int
    # Confidence score (capped at 1.0 by the detector).
    score: float


def _find_empty_regions(
    he_gray: np.ndarray,
    threshold: float = 0.93,
    min_size: int = 30,
    grid: int = 16,
    dedupe_dist: float = 25.0,
) -> List[Tuple[int, int, int]]:
    """Find bright ("empty"), roughly round regions; return ``(cx, cy, radius)`` tuples.

    Instead of true connected-component labelling (no scipy, and a hand-written
    flood fill would be too slow), the image is tiled into ``grid x grid``
    cells and regions are approximated cell by cell:

    * a not-yet-visited cell whose bright-pixel fraction exceeds 0.6 seeds a
      region;
    * every cell in the seed's 3x3 neighbourhood whose bright fraction exceeds
      0.5 is merged into the region's bounding box and marked visited, so it
      will not seed a region of its own (it may still be merged into a later
      region's box);
    * the region area is the bounding-box area times the bright-pixel fraction
      inside it, and the radius is that of a circle of equal area, clamped to
      at least 5.

    Partial cells at the right/bottom border are ignored. A candidate whose
    centre lies within ``dedupe_dist`` pixels of an earlier (raster-order)
    candidate is dropped.

    Args:
        he_gray: 2-D grayscale image. The default ``threshold`` presumably
            assumes intensities in [0, 1] — confirm against callers.
        threshold: pixels strictly greater than this count as bright.
        min_size: minimum estimated region area (pixels) for a candidate.
        grid: cell edge length in pixels (default 16, formerly hard-coded).
        dedupe_dist: candidates must be strictly farther apart than this
            (default 25, formerly hard-coded).

    Returns:
        List of ``(cx, cy, radius)`` int tuples, ordered by seed cell in
        raster order.

    Raises:
        ValueError: if ``grid`` is smaller than 1.
    """
    if grid < 1:
        raise ValueError(f"grid must be >= 1, got {grid}")
    h, w = he_gray.shape
    bright = he_gray > threshold
    gh, gw = h // grid, w // grid

    # Bright fraction of every full cell, computed once instead of re-running
    # ``.mean()`` on each neighbour cell up to nine times. Sums of booleans are
    # exact, so these values equal the per-cell means exactly.
    cell_frac = (
        bright[: gh * grid, : gw * grid]
        .reshape(gh, grid, gw, grid)
        .sum(axis=(1, 3))
        / float(grid * grid)
    )

    visited = np.zeros((gh, gw), dtype=bool)
    candidates: List[Tuple[int, int, int]] = []
    for gy in range(gh):
        for gx in range(gw):
            if visited[gy, gx] or not cell_frac[gy, gx] > 0.6:
                continue
            y0, y1 = gy * grid, (gy + 1) * grid
            x0, x1 = gx * grid, (gx + 1) * grid
            # Grow the bounding box over the (grid-clipped) 3x3 neighbourhood.
            for ny in range(max(gy - 1, 0), min(gy + 2, gh)):
                for nx in range(max(gx - 1, 0), min(gx + 2, gw)):
                    if cell_frac[ny, nx] > 0.5:
                        visited[ny, nx] = True
                        y0 = min(y0, ny * grid)
                        y1 = max(y1, (ny + 1) * grid)
                        x0 = min(x0, nx * grid)
                        x1 = max(x1, (nx + 1) * grid)
            cx = (x0 + x1) // 2
            cy = (y0 + y1) // 2
            area = (y1 - y0) * (x1 - x0) * float(bright[y0:y1, x0:x1].mean())
            if area < min_size:
                continue
            radius = int(np.sqrt(area / np.pi))
            candidates.append((cx, cy, max(radius, 5)))

    # Keep only candidates that are far enough from every already-kept one.
    deduped: List[Tuple[int, int, int]] = []
    for c in candidates:
        if all(np.hypot(c[0] - d[0], c[1] - d[1]) > dedupe_dist for d in deduped):
            deduped.append(c)
    return deduped


def detect_portal_tracts(
    he: np.ndarray,
    ck19: np.ndarray,
    pv_threshold: float = 0.93,
    ck19_threshold: float = 0.4,
    proximity: int = 30,
    *,
    min_ck19_fraction: float = 0.005,
    score_gain: float = 3.0,
    pv_min_size: int = 30,
) -> List[DetectedPortalTract]:
    """Pair portal-vein candidates with nearby CK19+ signal to form portal tracts.

    Portal-vein candidates are bright, roughly round regions of the grayscale
    H&E image (found by ``_find_empty_regions``). For each candidate, a square
    window reaching ``proximity + radius`` pixels around its centre (clipped to
    the image) is examined in ``ck19``. If the fraction of pixels above
    ``ck19_threshold`` is at least ``min_ck19_fraction``, the centroid of those
    pixels is reported as the bile duct.

    Args:
        he: H&E image, either (H, W, 3) — channels are averaged to grayscale —
            or an already-grayscale (H, W) array. The default ``pv_threshold``
            presumably assumes intensities in [0, 1]; confirm against callers.
        ck19: CK19 channel, (H, W). Assumed to be pixel-aligned with ``he``;
            this is not checked here (NOTE(review): confirm with callers).
        pv_threshold: brightness threshold for portal-vein lumen pixels.
        ck19_threshold: CK19 intensity above which a pixel counts as positive.
        proximity: search margin in pixels added to the vein radius.
        min_ck19_fraction: minimum fraction of CK19+ pixels in the window
            needed to accept a candidate (default 0.005, formerly hard-coded).
        score_gain: multiplier on the window's mean CK19 intensity; the
            product is capped at 1.0 (default 3.0, formerly hard-coded).
        pv_min_size: minimum estimated portal-vein area in pixels, forwarded
            as ``min_size`` to the region finder (default 30, its default).

    Returns:
        One ``DetectedPortalTract`` per accepted candidate, in candidate order.
    """
    he_gray = he.mean(axis=2) if he.ndim == 3 else he
    pv_candidates = _find_empty_regions(
        he_gray, threshold=pv_threshold, min_size=pv_min_size
    )

    detected: List[DetectedPortalTract] = []
    for cx, cy, r in pv_candidates:
        # Square search window around the vein, clipped to the CK19 image.
        reach = proximity + r
        y0 = max(0, cy - reach)
        y1 = min(ck19.shape[0], cy + reach)
        x0 = max(0, cx - reach)
        x1 = min(ck19.shape[1], cx + reach)
        local = ck19[y0:y1, x0:x1]
        if local.size == 0:
            continue
        ck_pos = local > ck19_threshold
        if ck_pos.mean() < min_ck19_fraction:
            continue
        ys, xs = np.nonzero(ck_pos)
        # Reachable when min_ck19_fraction <= 0: avoid a NaN centroid.
        if len(ys) == 0:
            continue
        # Centroid of CK19+ pixels, converted back to image coordinates.
        bd_y = int(ys.mean()) + y0
        bd_x = int(xs.mean()) + x0
        # Radius of a circle with the same area as the CK19+ pixel count.
        bd_r = max(int(np.sqrt(len(ys) / np.pi)), 2)
        score = float(min(local.mean() * score_gain, 1.0))
        detected.append(
            DetectedPortalTract(
                cx=cx,
                cy=cy,
                pv_radius=r,
                bd_x=bd_x,
                bd_y=bd_y,
                bd_radius=bd_r,
                score=score,
            )
        )
    return detected
