"""PancIsletMass v0 — cell classification.

Classifies the cell types inside each detected islet ROI from marker
intensities, at the pixel or patch level (α, β, δ, β∩Ki67, β∩TUNEL).
"""

from __future__ import annotations

from dataclasses import dataclass
from typing import Dict, List, Tuple

import numpy as np


# Endocrine cell type -> name of the marker stain associated with it.
CELL_MARKERS = dict(
    alpha="glucagon",
    beta="insulin",
    delta="somatostatin",
)


@dataclass
class IsletCellComposition:
    """Pixel counts of endocrine cell types inside one labelled islet."""

    label: int             # islet id this record describes
    alpha_area_px: int     # alpha-cell pixel count
    beta_area_px: int      # beta-cell pixel count
    delta_area_px: int     # delta-cell pixel count
    ki67_pos_px: int       # β ∩ Ki67 pixel count
    tunel_pos_px: int      # β ∩ TUNEL pixel count
    total_islet_px: int    # total pixel count of the islet

    @staticmethod
    def _ratio(numerator: int, denominator: int) -> float:
        """Return ``numerator / denominator`` as a float, or 0.0 unless denominator > 0."""
        if denominator > 0:
            return float(numerator) / denominator
        return 0.0

    def beta_fraction(self) -> float:
        """Beta pixels as a fraction of all alpha + beta + delta pixels (0.0 if none)."""
        endocrine_px = self.alpha_area_px + self.beta_area_px + self.delta_area_px
        return self._ratio(self.beta_area_px, endocrine_px)

    def proliferation_rate(self) -> float:
        """Ki67-positive beta pixels per beta pixel (0.0 if there are no beta pixels)."""
        return self._ratio(self.ki67_pos_px, self.beta_area_px)

    def apoptosis_rate(self) -> float:
        """TUNEL-positive beta pixels per beta pixel (0.0 if there are no beta pixels)."""
        return self._ratio(self.tunel_pos_px, self.beta_area_px)


def _channel_index(channel_names: Tuple[str, ...], name: str) -> int:
    """Return the position of ``name`` in ``channel_names``, with a descriptive error."""
    try:
        return channel_names.index(name)
    except ValueError:
        raise ValueError(
            f"required channel {name!r} not found in channel_names={tuple(channel_names)!r}"
        ) from None


def classify_islet_cells(
    channels: np.ndarray,
    channel_names: Tuple[str, ...],
    label_image: np.ndarray,
    threshold: float = 0.25,
    *,
    ki67_threshold: float | None = None,
    tunel_threshold: float | None = None,
) -> List[IsletCellComposition]:
    """For each islet label, count pixels positive for each cell-type marker.

    Pixel-level rule (mutually exclusive at pixel level): among insulin,
    glucagon and somatostatin, the marker with the highest intensity wins, and
    the pixel is kept only if that maximum exceeds ``threshold``:
      - beta:  insulin is the maximum
      - alpha: glucagon is the maximum
      - delta: somatostatin is the maximum
    Exact ties resolve in the order beta > alpha > delta, because ``np.argmax``
    returns the first maximum in (insulin, glucagon, somatostatin) order.
    Ki67+ / TUNEL+ pixels are counted only on beta pixels.

    Args:
        channels: Intensity images indexable as ``channels[i]``, aligned with
            ``channel_names``. Every channel must have the same shape as
            ``label_image``.
        channel_names: Channel names; must contain "insulin", "glucagon",
            "somatostatin", "Ki67" and "TUNEL".
        label_image: Islet label image. Integral values >= 1 are islet ids;
            0, negative, non-integral and non-finite values are ignored.
        threshold: Cut-off for the cell-type markers, and the default cut-off
            for Ki67/TUNEL.
        ki67_threshold: Optional separate Ki67 cut-off (default: ``threshold``).
        tunel_threshold: Optional separate TUNEL cut-off (default: ``threshold``).

    Returns:
        One ``IsletCellComposition`` per islet id present in ``label_image``,
        in ascending id order.

    Raises:
        ValueError: If a required channel name is missing, or if
            ``label_image`` does not have the channels' shape.
    """
    idx_ins = _channel_index(channel_names, "insulin")
    idx_glu = _channel_index(channel_names, "glucagon")
    idx_sst = _channel_index(channel_names, "somatostatin")
    idx_ki = _channel_index(channel_names, "Ki67")
    idx_tu = _channel_index(channel_names, "TUNEL")

    # Stack order fixes the argmax codes: 0=insulin, 1=glucagon, 2=somatostatin.
    stack = np.stack([channels[idx_ins], channels[idx_glu], channels[idx_sst]], axis=0)
    winner = stack.argmax(axis=0)
    above = stack.max(axis=0) > threshold

    beta_mask = (winner == 0) & above
    alpha_mask = (winner == 1) & above
    delta_mask = (winner == 2) & above

    ki_thr = threshold if ki67_threshold is None else ki67_threshold
    tu_thr = threshold if tunel_threshold is None else tunel_threshold
    ki_beta_mask = beta_mask & (channels[idx_ki] > ki_thr)
    tu_beta_mask = beta_mask & (channels[idx_tu] > tu_thr)

    labels = np.asarray(label_image)
    if labels.shape != beta_mask.shape:
        raise ValueError(
            f"label_image shape {labels.shape} does not match channel shape {beta_mask.shape}"
        )

    # Select the foreground once: integral label values >= 1.
    flat = labels.ravel()
    fg = flat >= 1
    if np.issubdtype(flat.dtype, np.floating):
        # A non-integral float can never equal an integer islet id.
        fg &= np.isfinite(flat) & (flat == np.floor(flat))

    # A single pass over the image replaces one full-image scan per label.
    # ids holds only the labels that are present, sorted ascending.
    # inverse maps each foreground pixel to its position in ids.
    ids, inverse = np.unique(flat[fg], return_inverse=True)
    n_ids = ids.size

    def count_in(mask: np.ndarray) -> np.ndarray:
        # Tally the mask's foreground pixels per islet id.
        return np.bincount(inverse[mask.ravel()[fg]], minlength=n_ids)

    alpha_px = count_in(alpha_mask)
    beta_px = count_in(beta_mask)
    delta_px = count_in(delta_mask)
    ki_px = count_in(ki_beta_mask)
    tu_px = count_in(tu_beta_mask)
    total_px = np.bincount(inverse, minlength=n_ids)

    return [
        IsletCellComposition(
            label=int(lab_id),
            alpha_area_px=int(alpha_px[i]),
            beta_area_px=int(beta_px[i]),
            delta_area_px=int(delta_px[i]),
            ki67_pos_px=int(ki_px[i]),
            tunel_pos_px=int(tu_px[i]),
            total_islet_px=int(total_px[i]),
        )
        for i, lab_id in enumerate(ids)
    ]


def composition_to_dicts(comps: List[IsletCellComposition]) -> List[Dict]:
    """Convert each composition into a flat dict row, preserving input order.

    The raw pixel counts are copied as-is. ``beta_fraction``,
    ``proliferation_rate`` and ``apoptosis_rate`` are rounded to 4 decimals.
    """
    return [
        {
            "islet_label": comp.label,
            "islet_area_px": comp.total_islet_px,
            "alpha_px": comp.alpha_area_px,
            "beta_px": comp.beta_area_px,
            "delta_px": comp.delta_area_px,
            "beta_fraction": round(comp.beta_fraction(), 4),
            "ki67_beta_px": comp.ki67_pos_px,
            "tunel_beta_px": comp.tunel_pos_px,
            "proliferation_rate": round(comp.proliferation_rate(), 4),
            "apoptosis_rate": round(comp.apoptosis_rate(), 4),
        }
        for comp in comps
    ]
