"""Sirius Red collagen overlay + bridging vs ductular bridge 분류.

- Sirius Red 양성 area를 portal tract polygon과 overlay
- 두 portal tract 사이 collagen bridge가 있고 그 위에 CK19+ T3 structure가
  뿌려져 있으면 'ductular bridge', 없으면 'pure bridging fibrosis'.
"""
from __future__ import annotations

from dataclasses import dataclass
from typing import List, Tuple

import numpy as np


@dataclass
class BridgeReport:
    """Bridge analysis result for one pair of portal tracts, as built by ``detect_bridges``."""

    pt_a: int  # index of the first portal tract in the list given to detect_bridges
    pt_b: int  # index of the second portal tract (always greater than pt_a)
    distance: float  # Euclidean distance between the two tract centres (cx, cy)
    sirius_along_score: float  # 0..1 — fraction of sampled line points above the Sirius threshold
    ck19_along_score: float    # 0..1 — fraction of sampled line points above the CK19 threshold
    bridge_type: str           # 'none' / 'pure_fibrosis' / 'ductular_bridge'


def collagen_area_fraction(sirius: np.ndarray, threshold: float = 0.3) -> float:
    """Return the fraction of pixels whose Sirius Red value exceeds ``threshold``.

    Args:
        sirius: Sirius Red intensity map (any shape).
        threshold: strict lower bound for a pixel to count as collagen-positive.

    Returns:
        Positive-pixel count divided by total pixel count, as a Python float.
    """
    positive = np.greater(sirius, threshold)
    return float(np.mean(positive))


def _line_samples(p0: Tuple[int, int], p1: Tuple[int, int], n: int = 30) -> List[Tuple[int, int]]:
    xs = np.linspace(p0[0], p1[0], n).astype(int)
    ys = np.linspace(p0[1], p1[1], n).astype(int)
    return list(zip(xs, ys))


def detect_bridges(
    portal_tracts: List,
    sirius: np.ndarray,
    ck19: np.ndarray,
    sirius_threshold: float = 0.3,
    ck19_threshold: float = 0.5,
    max_distance: int = 200,
    min_bridge_score: float = 0.3,
    min_distance: float = 30,
    n_samples: int = 30,
    end_trim: int = 3,
    ductular_score: float = 0.15,
) -> List[BridgeReport]:
    """Classify the connection between every pair of portal tracts.

    For each pair ``(i, j)`` with ``i < j`` whose centre distance lies in
    ``[min_distance, max_distance]``:

    1. Take ``n_samples`` points on the straight line between the two centres.
    2. Drop ``end_trim`` points at each end, since they fall inside the portal
       tracts themselves.
    3. Score the remaining in-image points:

       - ``sirius_along_score``: fraction with ``sirius > sirius_threshold``.
       - ``ck19_along_score``: fraction with ``ck19 > ck19_threshold``.

    The pair is classified as follows:

    - ``'none'`` when the Sirius score is below ``min_bridge_score``.
    - Otherwise ``'ductular_bridge'`` when the CK19 score exceeds
      ``ductular_score``.
    - Otherwise ``'pure_fibrosis'``.

    Args:
        portal_tracts: objects exposing ``cx`` / ``cy`` centre coordinates
            (DetectedPortalTract / PortalTract).
        sirius: Sirius Red map indexed as ``[y, x]``.
        ck19: CK19 map indexed as ``[y, x]``. Only the region covered by both
            ``sirius`` and ``ck19`` is sampled.
        sirius_threshold: per-pixel Sirius positivity threshold (strict ``>``).
        ck19_threshold: per-pixel CK19 positivity threshold (strict ``>``).
        max_distance: pairs farther apart than this are skipped.
        min_bridge_score: minimum Sirius score to count as any bridge.
        min_distance: pairs closer than this are skipped.
        n_samples: number of points sampled along each connecting line.
        end_trim: number of samples discarded at each end of the line
            (expected ``>= 0``).
        ductular_score: CK19 score above which a bridge is ductular.

    Returns:
        One BridgeReport per evaluated pair, including ``'none'`` pairs. Pairs
        outside the distance range, or without any in-image interior sample,
        are omitted.
    """
    reports: List[BridgeReport] = []
    n = len(portal_tracts)
    for i in range(n):
        a = portal_tracts[i]
        for j in range(i + 1, n):
            b = portal_tracts[j]
            # Cast before subtracting: unsigned-integer coordinates would wrap around.
            d = float(np.hypot(float(a.cx) - float(b.cx), float(a.cy) - float(b.cy)))
            if d > max_distance or d < min_distance:
                continue
            # Restrict to the extent covered by BOTH maps so a smaller ck19 cannot raise IndexError.
            height = min(sirius.shape[0], ck19.shape[0])
            width = min(sirius.shape[1], ck19.shape[1])
            samples = _line_samples((a.cx, a.cy), (b.cx, b.cy), n=n_samples)
            # Explicit stop index so end_trim=0 keeps every sample (samples[0:-0] would be empty).
            interior = samples[end_trim:len(samples) - end_trim]
            points = [(x, y) for x, y in interior if 0 <= x < width and 0 <= y < height]
            if not points:
                continue
            sir_vals = np.array([sirius[y, x] for x, y in points])
            ck_vals = np.array([ck19[y, x] for x, y in points])
            sirius_along = float(np.mean(sir_vals > sirius_threshold))
            ck19_along = float(np.mean(ck_vals > ck19_threshold))
            if sirius_along < min_bridge_score:
                bridge_type = "none"
            elif ck19_along > ductular_score:
                bridge_type = "ductular_bridge"
            else:
                bridge_type = "pure_fibrosis"
            reports.append(
                BridgeReport(
                    pt_a=i,
                    pt_b=j,
                    distance=d,
                    sirius_along_score=sirius_along,
                    ck19_along_score=ck19_along,
                    bridge_type=bridge_type,
                )
            )
    return reports


def make_overlay_rgb(he: np.ndarray, sirius: np.ndarray, sirius_threshold: float = 0.3) -> np.ndarray:
    """Overlay Sirius Red positivity on an H&E image as a red highlight.

    Where ``sirius > sirius_threshold``:

    - the red channel becomes ``max(red, sirius)``;
    - the green and blue channels are halved.

    All other pixels keep their original color.

    Args:
        he: grayscale ``(H, W)`` image (replicated to 3 channels) or a
            channel-last colour image. Values are presumably in ``[0, 1]``,
            since the result is clipped to that range.
        sirius: Sirius Red map with the same ``(H, W)`` extent as ``he``.
        sirius_threshold: strict positivity threshold.

    Returns:
        A new floating-point array clipped to ``[0, 1]``. Float inputs keep
        their dtype; other dtypes are promoted to float64. The input is never
        modified.
    """
    src = np.asarray(he)
    rgb = np.stack([src, src, src], axis=-1) if src.ndim == 2 else src
    # Work in floating point: halving integer channels in place would silently truncate.
    dtype = rgb.dtype if np.issubdtype(rgb.dtype, np.floating) else np.float64
    # Single copy, and only when rgb may still alias the caller's array.
    base = rgb.astype(dtype, copy=rgb is src)
    mask = sirius > sirius_threshold
    base[..., 0] = np.where(mask, np.maximum(base[..., 0], sirius), base[..., 0])
    base[..., 1] = np.where(mask, base[..., 1] * 0.5, base[..., 1])
    base[..., 2] = np.where(mask, base[..., 2] * 0.5, base[..., 2])
    return np.clip(base, 0.0, 1.0)
