"""Ductular Reaction (DR) Type1/2/3 분류 휴리스틱.

각 portal tract polygon + 200μm buffer 내 CK19+ structure를 검출하고
크기/lumen 유무/모양 비율로 룰베이스 분류.

- Type1 (reactive ductule): 작고 lumen 있는 환형
- Type2 (intermediate hepatocyte): 중간 크기, hepatocyte morphology + 약한 CK19
- Type3 (ductular metaplasia / mass-like): 큰 cluster
"""
from __future__ import annotations

from dataclasses import dataclass
from typing import List, Tuple

import numpy as np


@dataclass
class DRStructure:
    """One CK19+ structure near a portal tract, together with its rule-based DR type.

    Instances are produced by ``classify_dr_structures``; fields are listed in
    constructor order.
    """

    cx: int  # x of the structure's centroid (truncated mean of positive-pixel columns)
    cy: int  # y of the structure's centroid (truncated mean of positive-pixel rows)
    size: int  # int(sqrt(positive-pixel area)), in pixels
    intensity: float  # mean CK19 value over the structure's positive pixels
    lumen_score: float  # 0..1; relative darkening of the patch centre vs. its rim
    dr_type: str  # "T1" / "T2" / "T3"
    portal_idx: int  # index of the nearest portal tract this structure was assigned to


def _find_ck19_blobs(ck19: np.ndarray, threshold: float = 0.4, min_area: int = 6) -> List[Tuple[int, int, int, float]]:
    """간단한 grid 기반 CK19+ 클러스터 검출.

    return (cx, cy, area_sqrt, mean_intensity).
    """
    h, w = ck19.shape
    pos = ck19 > threshold
    grid = 6
    candidates: List[Tuple[int, int, int, float]] = []
    visited = np.zeros((h // grid + 1, w // grid + 1), dtype=bool)
    for gy in range(h // grid):
        for gx in range(w // grid):
            if visited[gy, gx]:
                continue
            cell = pos[gy * grid : (gy + 1) * grid, gx * grid : (gx + 1) * grid]
            if cell.mean() <= 0.2:
                continue
            # 인접 셀 확장 (BFS-lite)
            y0, y1 = gy * grid, (gy + 1) * grid
            x0, x1 = gx * grid, (gx + 1) * grid
            stack = [(gy, gx)]
            visited[gy, gx] = True
            while stack:
                cy_, cx_ = stack.pop()
                for dy in (-1, 0, 1):
                    for dx in (-1, 0, 1):
                        ny, nx = cy_ + dy, cx_ + dx
                        if 0 <= ny < h // grid and 0 <= nx < w // grid and not visited[ny, nx]:
                            ncell = pos[ny * grid : (ny + 1) * grid, nx * grid : (nx + 1) * grid]
                            if ncell.mean() > 0.15:
                                visited[ny, nx] = True
                                y0 = min(y0, ny * grid)
                                y1 = max(y1, (ny + 1) * grid)
                                x0 = min(x0, nx * grid)
                                x1 = max(x1, (nx + 1) * grid)
                                stack.append((ny, nx))
            patch = ck19[y0:y1, x0:x1]
            patch_pos = pos[y0:y1, x0:x1]
            area = int(patch_pos.sum())
            if area < min_area:
                continue
            ys, xs = np.where(patch_pos)
            cx = int(xs.mean()) + x0
            cy = int(ys.mean()) + y0
            mean_int = float(patch[patch_pos].mean()) if area > 0 else 0.0
            candidates.append((cx, cy, int(np.sqrt(area)), mean_int))
    return candidates


def _lumen_score(ck19_patch: np.ndarray) -> float:
    """patch 내부의 빈 공간(중심부 강도 낮음) 비율.

    중심 1/4 영역 평균이 가장자리 평균보다 많이 낮으면 lumen 있음.
    """
    h, w = ck19_patch.shape
    if h < 4 or w < 4:
        return 0.0
    cy0, cy1 = h // 3, 2 * h // 3
    cx0, cx1 = w // 3, 2 * w // 3
    center = float(ck19_patch[cy0:cy1, cx0:cx1].mean())
    border_mask = np.ones((h, w), dtype=bool)
    border_mask[cy0:cy1, cx0:cx1] = False
    border = float(ck19_patch[border_mask].mean())
    if border <= 1e-6:
        return 0.0
    return float(max(0.0, (border - center) / border))


def classify_dr_structures(
    ck19: np.ndarray,
    portal_tracts: List,
    buffer: int = 35,
    slack: float = 30,
) -> List[DRStructure]:
    """Classify CK19+ structures near portal tracts into DR types T1/T2/T3.

    Every blob found by ``_find_ck19_blobs(ck19)`` is assigned to its nearest
    portal tract (Euclidean distance from the blob centroid to the tract's
    ``cx``/``cy``). Blobs farther than ``buffer + slack`` pixels from every
    tract are dropped, as are blobs within 5 px of the tract's ``bd_x``/``bd_y``
    when the tract has those attributes. Each kept blob is scored for a lumen
    on a window around it and labelled by size / lumen / intensity rules.

    Args:
        ck19: 2-D CK19 intensity map.
        portal_tracts: iterable of tract objects (e.g. DetectedPortalTract or
            PortalTract); only ``cx``/``cy`` (and optionally ``bd_x``/``bd_y``)
            are read. Materialised once, so one-shot iterables work.
        buffer: distance (pixels) around a tract centre within which blobs count.
        slack: extra tolerance (pixels) added to ``buffer``; the default keeps
            the long-standing ``buffer + 30`` cut-off.

    Returns:
        One DRStructure per kept blob, in blob detection order. ``portal_idx``
        is the index of the assigned tract in iteration order of ``portal_tracts``.
    """
    # Materialise once: the nearest-tract search walks the tracts for every blob
    # and then indexes the winner, which would fail on a generator.
    tracts = list(portal_tracts)
    blobs = _find_ck19_blobs(ck19)
    max_dist = buffer + slack
    out: List[DRStructure] = []
    for bx, by, size, mean_int in blobs:
        # Nearest portal tract (first one wins on ties).
        best_idx = -1
        best_dist = float("inf")
        for i, pt in enumerate(tracts):
            d = float(np.hypot(bx - pt.cx, by - pt.cy))
            if d < best_dist:
                best_dist = d
                best_idx = i
        if best_idx < 0 or best_dist > max_dist:
            continue

        # Skip the tract's own bile duct: a blob within 5 px of bd_x/bd_y.
        pt = tracts[best_idx]
        if hasattr(pt, "bd_x") and hasattr(pt, "bd_y"):
            if np.hypot(bx - pt.bd_x, by - pt.bd_y) < 5:
                continue

        # Window of half-width max(size, 3) + 1 around the centroid, clipped to the image.
        ps = max(size, 3)
        y0 = max(0, by - ps - 1)
        y1 = min(ck19.shape[0], by + ps + 1)
        x0 = max(0, bx - ps - 1)
        x1 = min(ck19.shape[1], bx + ps + 1)
        ls = _lumen_score(ck19[y0:y1, x0:x1])

        # Rule-based labelling:
        #   small (size <= 6): T1 if it has a lumen, otherwise T2;
        #   medium (size <= 11): T1 if strongly CK19+ (mean >= 0.7), otherwise T2;
        #   large: T3.
        if size <= 6:
            dr_type = "T1" if ls > 0.15 else "T2"
        elif size <= 11:
            dr_type = "T2" if mean_int < 0.7 else "T1"
        else:
            dr_type = "T3"

        out.append(
            DRStructure(
                cx=bx,
                cy=by,
                size=size,
                intensity=mean_int,
                lumen_score=ls,
                dr_type=dr_type,
                portal_idx=best_idx,
            )
        )
    return out
