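"""Quantify M1/M2 polarization of ATMs (adipose tissue macrophages).

CD11c is used as the M1 marker and CD206 as the M2 marker.
- Per CLS: mean CD11c and CD206 intensity inside a ring-shaped shell around
  each CLS, and their ratio (CD11c / CD206) as the M1/M2 ratio.
- Global: ratio of whole-image CD11c and CD206 intensity sums, plus the
  polarization index (M1 - M2) / (M1 + M2) computed from those sums.
"""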
from __future__ import annotations

from dataclasses import dataclass
from typing import List

import numpy as np


@dataclass
class PolarResult:
    """Result of ``quantify_polarization`` (M1 = CD11c, M2 = CD206).

    Per-CLS arrays hold one entry per CLS whose shell ring contains at least
    one pixel; CLS with an empty ring are skipped, so these arrays may be
    shorter than the number of input CLS.
    """

    # Mean CD11c (M1) intensity inside each measured CLS shell ring (float32).
    per_cls_m1_intensity: np.ndarray
    # Mean CD206 (M2) intensity inside each measured CLS shell ring (float32).
    per_cls_m2_intensity: np.ndarray
    # Per-CLS ratio: CD11c mean / max(CD206 mean, 1e-6) (float32).
    per_cls_m1m2_ratio: np.ndarray
    # Average of per_cls_m1_intensity; 0.0 when no CLS was measured.
    mean_m1: float
    # Average of per_cls_m2_intensity; 0.0 when no CLS was measured.
    mean_m2: float
    # Whole-image CD11c sum / max(whole-image CD206 sum, 1e-6).
    global_m1m2_ratio: float
    polarization_index: float  # (M1 - M2) / (M1 + M2) from whole-image sums; within [-1, 1] for non-negative intensities


def _circular_mask(shape, cx, cy, r):
    h, w = shape
    yy, xx = np.ogrid[:h, :w]
    return (yy - cy) ** 2 + (xx - cx) ** 2 <= r * r


def _axis_window(center: float, extent: float, size: int):
    """Return a half-open index range ``(lo, hi)`` along one axis covering a disk.

    The window spans ``center - extent`` .. ``center + extent`` padded by one
    pixel on each side, so floating-point rounding can never exclude a pixel
    that the exact distance test would accept, and is clipped to
    ``[0, size)``. Non-finite inputs fall back to the whole axis and leave the
    decision to the caller's distance test.
    """
    if not (np.isfinite(center) and np.isfinite(extent)):
        return 0, size
    lo = max(0, int(np.floor(center - extent)) - 1)
    hi = min(size, int(np.ceil(center + extent)) + 2)
    # A disk entirely off the image yields an empty (lo, lo) window.
    return lo, max(lo, hi)


def _ring_means(m1_img, m2_img, cx, cy, inner, outer):
    """Mean M1/M2 intensity over pixels with ``inner**2 < d**2 <= outer**2``.

    Only the bounding box of the outer circle is evaluated, but distances use
    absolute pixel coordinates, so the selected pixels (and their row-major
    order) match a full-image mask exactly. Returns ``None`` when the ring
    contains no pixel.
    """
    h, w = m1_img.shape
    extent = abs(outer)  # the outer test uses outer * outer, so the sign is irrelevant
    y0, y1 = _axis_window(cy, extent, h)
    x0, x1 = _axis_window(cx, extent, w)
    yy = np.arange(y0, y1)[:, None]
    xx = np.arange(x0, x1)[None, :]
    d2 = (yy - cy) ** 2 + (xx - cx) ** 2
    ring = (d2 <= outer * outer) & ~(d2 <= inner * inner)
    if not ring.any():
        return None
    m1_int = float(m1_img[y0:y1, x0:x1][ring].mean())
    m2_int = float(m2_img[y0:y1, x0:x1][ring].mean())
    return m1_int, m2_int


def quantify_polarization(image: np.ndarray,
                          cls_centers: np.ndarray,
                          cls_radii: np.ndarray,
                          channels: tuple = ("perilipin", "F480", "CD11c", "CD206", "DAPI"),
                          shell_factor: float = 1.5,
                          inner_factor: float = 0.95) -> PolarResult:
    """Compare mean CD11c (M1) and CD206 (M2) intensity in a shell around each CLS.

    For CLS ``i`` with radius ``r``, the shell is the ring of pixels whose
    distance ``d`` from the centre satisfies
    ``r * inner_factor < d <= r * shell_factor``. CLS whose ring contains no
    pixel (e.g. entirely off-image) are skipped, so the per-CLS arrays may be
    shorter than ``cls_centers``.

    Args:
        image: Channel-last array; ``image[..., k]`` is channel ``channels[k]``.
            Per-CLS measurement requires each channel slice to be 2-D.
        cls_centers: Array indexed as ``cls_centers[i, 0]`` = x (column) and
            ``cls_centers[i, 1]`` = y (row), in pixels.
        cls_radii: Radius of each CLS, in the same pixel units.
        channels: Channel names; must contain ``"CD11c"`` and ``"CD206"``.
        shell_factor: Outer ring radius as a multiple of the CLS radius.
        inner_factor: Inner ring radius as a multiple of the CLS radius
            (0.95 reproduces the historical behaviour).

    Returns:
        PolarResult with per-CLS float32 arrays, their means (0.0 if none),
        and whole-image ratio / polarization index computed from channel sums.

    Raises:
        ValueError: if ``"CD11c"`` or ``"CD206"`` is missing from ``channels``.
    """
    cd11c = image[..., channels.index("CD11c")]
    cd206 = image[..., channels.index("CD206")]

    m1_per, m2_per, ratios = [], [], []
    for i in range(cls_centers.shape[0]):
        cx, cy = float(cls_centers[i, 0]), float(cls_centers[i, 1])
        r = float(cls_radii[i])
        means = _ring_means(cd11c, cd206, cx, cy,
                            inner=r * inner_factor, outer=r * shell_factor)
        if means is None:
            continue
        m1_int, m2_int = means
        m1_per.append(m1_int)
        m2_per.append(m2_int)
        ratios.append(m1_int / max(m2_int, 1e-6))

    # np.array of an empty list already has shape (0,), so no special case is needed.
    m1_arr = np.array(m1_per, dtype=np.float32)
    m2_arr = np.array(m2_per, dtype=np.float32)
    ratios_arr = np.array(ratios, dtype=np.float32)

    # Global: whole-channel sums (background assumed small). Accumulate in
    # float64 so low-precision images (e.g. float16) cannot overflow to inf,
    # which would make the ratio NaN.
    g_m1 = float(cd11c.sum(dtype=np.float64))
    g_m2 = float(cd206.sum(dtype=np.float64))
    g_ratio = g_m1 / max(g_m2, 1e-6)
    pol_index = (g_m1 - g_m2) / max(g_m1 + g_m2, 1e-6)

    return PolarResult(
        per_cls_m1_intensity=m1_arr,
        per_cls_m2_intensity=m2_arr,
        per_cls_m1m2_ratio=ratios_arr,
        mean_m1=float(m1_arr.mean()) if m1_arr.size else 0.0,
        mean_m2=float(m2_arr.mean()) if m2_arr.size else 0.0,
        global_m1m2_ratio=float(g_ratio),
        polarization_index=float(pol_index),
    )
