"""AdipoCLSPolarMap v0 — Streamlit 진입점.

지방조직 multi-channel IF mini WSI에서 adipocyte/CLS/ATM polarization을
heuristic 기반으로 정량 + depot/그룹 비교 통계를 시연하는 stand-alone 도구.

연구·참고용. 임상 진단 사용 금지.
"""
from __future__ import annotations

import io
import json
from pathlib import Path
from typing import Dict, List

import numpy as np

try:
    import streamlit as st
    _HAS_ST = True
except Exception:  # pragma: no cover
    st = None
    _HAS_ST = False

try:
    import matplotlib.pyplot as plt
    from matplotlib.patches import Circle
    _HAS_MPL = True
except Exception:  # pragma: no cover
    plt = None
    Circle = None
    _HAS_MPL = False

from lib import synth_adipose, adipo_seg, cls_detect, atm_polarize, stats


# ---------------------------------------------------------------------------
# 헤드리스 분석 헬퍼 (Streamlit 없이도 실행/검수 가능)
# ---------------------------------------------------------------------------

def analyze_wsi(wsi: synth_adipose.SyntheticWSI) -> Dict:
    """Run the full analysis pipeline on a single WSI.

    Three stages run in order: adipocyte segmentation, CLS detection (fed
    with the segmentation's centers, radii and dead mask), and M1/M2
    polarization quantification at the detected CLS locations.

    Args:
        wsi: Image container; ``image``, ``pixel_um``, ``group``, ``depot``
            and ``mouse_id`` are read from it.

    Returns:
        A flat record holding the identifying fields (``group``, ``depot``,
        ``mouse_id``), the summary metrics, and the raw stage results under
        the underscore-prefixed keys ``_seg``, ``_cls`` and ``_pol``.
    """
    channels = synth_adipose.CHANNELS
    # Resolve the perilipin channel by name rather than hard-coding index 0,
    # so segmentation stays in step with the channel list that the CLS and
    # polarization stages receive.
    seg = adipo_seg.segment_adipocytes(
        wsi.image,
        perilipin_idx=channels.index("perilipin"),
        pixel_um=wsi.pixel_um,
    )
    cls_res = cls_detect.detect_cls(
        image=wsi.image,
        adipo_centers=seg.centers,
        adipo_radii=seg.radii,
        dead_mask=seg.dead_mask,
        channels=channels,
        pixel_um=wsi.pixel_um,
    )
    pol = atm_polarize.quantify_polarization(
        image=wsi.image,
        cls_centers=cls_res.cls_centers,
        cls_radii=cls_res.cls_radii,
        channels=channels,
    )
    return {
        "group": wsi.group,
        "depot": wsi.depot,
        "mouse_id": wsi.mouse_id,
        "n_adipocytes": seg.n_total,
        "n_dead": seg.n_dead,
        # Guard the mean: with no segmented adipocytes report 0.0 instead of
        # averaging an empty collection.
        "mean_diameter_um": float(seg.diameters_um.mean()) if seg.n_total else 0.0,
        "n_cls": cls_res.n_cls,
        "cls_density_per_mm2": cls_res.density_per_mm2,
        "global_m1m2_ratio": pol.global_m1m2_ratio,
        "polarization_index": pol.polarization_index,
        # Raw stage outputs, kept for the rendering helpers; callers strip
        # underscore-prefixed keys before tabulating or exporting.
        "_seg": seg,
        "_cls": cls_res,
        "_pol": pol,
    }


def cohort_anovas(records: List[Dict]) -> Dict[str, stats.AnovaResult]:
    """Run a one-way ANOVA across groups for each cohort-level metric.

    The metrics are ``cls_density_per_mm2``, ``polarization_index`` and
    ``mean_diameter_um``. For every metric the records are bucketed by their
    ``"group"`` value (values coerced to ``float``) and the buckets are handed
    to ``stats.one_way_anova``.

    Args:
        records: Per-WSI records, each carrying ``"group"`` and the metrics.

    Returns:
        Mapping of metric name to the result of ``stats.one_way_anova``.
    """

    def values_by_group(metric: str) -> Dict[str, List[float]]:
        # Buckets keep first-seen group order and record order within a group.
        buckets: Dict[str, List[float]] = {}
        for rec in records:
            label = rec["group"]
            if label not in buckets:
                buckets[label] = []
            buckets[label].append(float(rec[metric]))
        return buckets

    metric_names = ("cls_density_per_mm2", "polarization_index", "mean_diameter_um")
    return {
        name: stats.one_way_anova(values_by_group(name))
        for name in metric_names
    }


# ---------------------------------------------------------------------------
# Streamlit UI
# ---------------------------------------------------------------------------

def _render_overlay(wsi: synth_adipose.SyntheticWSI, result: Dict):
    """Build a two-panel overlay figure for one analyzed WSI.

    Both panels show the same RGB composite (perilipin -> green, F480 -> red,
    DAPI -> blue, clipped to [0, 1]). The left panel outlines every segmented
    adipocyte (yellow when flagged in ``dead_mask``, lime otherwise); the
    right panel outlines every detected CLS in red at 1.4x its radius.

    Args:
        wsi: Source WSI; ``image``, ``group``, ``depot`` and ``mouse_id`` are read.
        result: Record carrying the ``_seg`` and ``_cls`` stage results.

    Returns:
        The matplotlib figure, or ``None`` when matplotlib is unavailable.
    """
    if not _HAS_MPL:
        return None

    fig, (ax_seg, ax_cls) = plt.subplots(1, 2, figsize=(10, 5))

    pixels = wsi.image
    channel_names = synth_adipose.CHANNELS
    composite = np.zeros((pixels.shape[0], pixels.shape[1], 3), dtype=np.float32)
    # (RGB slot, channel name): perilipin=green, F4/80=red, DAPI=blue.
    for rgb_slot, channel_name in ((1, "perilipin"), (0, "F480"), (2, "DAPI")):
        composite[..., rgb_slot] = pixels[..., channel_names.index(channel_name)]
    composite = np.clip(composite, 0, 1)

    # Left panel: adipocyte outlines.
    ax_seg.imshow(composite)
    ax_seg.set_title(f"{wsi.group} / {wsi.depot} / {wsi.mouse_id}")
    seg = result["_seg"]
    for idx, (cx, cy) in enumerate(seg.centers):
        radius = seg.radii[idx]
        outline = "yellow" if seg.dead_mask[idx] else "lime"
        ax_seg.add_patch(
            Circle((cx, cy), radius, fill=False, edgecolor=outline, lw=0.7)
        )
    ax_seg.axis("off")

    # Right panel: CLS outlines, drawn slightly larger than the CLS radius.
    cls_res = result["_cls"]
    ax_cls.imshow(composite)
    ax_cls.set_title(f"CLS = {cls_res.n_cls} | density = {cls_res.density_per_mm2:.1f}/mm²")
    for idx, (cx, cy) in enumerate(cls_res.cls_centers):
        radius = cls_res.cls_radii[idx]
        ax_cls.add_patch(
            Circle((cx, cy), radius * 1.4, fill=False, edgecolor="red", lw=1.2)
        )
    ax_cls.axis("off")

    fig.tight_layout()
    return fig


def _render_polar_donut(result: Dict):
    """Draw a donut chart of the mean M1 (CD11c) vs. M2 (CD206) signal.

    Wedge sizes come from ``mean_m1`` / ``mean_m2`` of the ``_pol`` entry.
    Non-finite (NaN/inf) or negative means are treated as zero, and when no
    positive signal remains the chart falls back to an even 50/50 split so
    the pie can always be drawn. The title reports ``global_m1m2_ratio``
    unchanged.

    Args:
        result: Record carrying the ``_pol`` polarization result.

    Returns:
        The matplotlib figure, or ``None`` when matplotlib is unavailable.
    """
    if not _HAS_MPL:
        return None
    pol = result["_pol"]
    # ax.pie rejects negative wedge sizes and cannot lay out NaN/inf wedges,
    # so sanitize each mean before it reaches the chart.
    m1, m2 = (
        value if np.isfinite(value) and value > 0 else 0.0
        for value in (float(pol.mean_m1), float(pol.mean_m2))
    )
    if m1 + m2 <= 0:
        # No usable signal on either channel: show an even split.
        m1, m2 = 0.5, 0.5
    fig, ax = plt.subplots(figsize=(4, 4))
    ax.pie([m1, m2], labels=["M1 (CD11c)", "M2 (CD206)"],
           colors=["#d9534f", "#5bc0de"], wedgeprops={"width": 0.4},
           startangle=90, autopct="%1.1f%%")
    ax.set_title(f"M1/M2 ratio = {pol.global_m1m2_ratio:.2f}")
    return fig


def _render_violin(records: List[Dict], metric: str, title: str):
    """Draw one violin per group for a single metric.

    Groups are ordered by sorted group label and placed at x = 1, 2, ...;
    each violin shows the group mean. With no data the axes are still
    returned, just without violins.

    Args:
        records: Per-WSI records carrying ``"group"`` and ``metric``.
        metric: Record key to plot; also used as the y-axis label.
        title: Axes title.

    Returns:
        The matplotlib figure, or ``None`` when matplotlib is unavailable.
    """
    if not _HAS_MPL:
        return None

    # Bucket the metric values per group in a single pass over the records.
    per_group: Dict = {}
    for rec in records:
        per_group.setdefault(rec["group"], []).append(rec[metric])
    labels = sorted(per_group)
    samples = [per_group[label] for label in labels]

    fig, ax = plt.subplots(figsize=(6, 4))
    if any(samples):
        ax.violinplot(samples, showmeans=True)
    ax.set_xticks(range(1, len(labels) + 1))
    ax.set_xticklabels(labels)
    ax.set_title(title)
    ax.set_ylabel(metric)
    ax.grid(alpha=0.3)
    return fig


def _json_default(obj):
    """Convert NumPy values that ``json`` cannot encode into plain Python objects.

    Intended as the ``default=`` hook of ``json.dumps``. Any other type raises
    ``TypeError`` so unexpected objects still fail loudly instead of being
    silently stringified.
    """
    if isinstance(obj, np.generic):
        return obj.item()
    if isinstance(obj, np.ndarray):
        return obj.tolist()
    raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable")


def _show_figure(fig) -> None:  # pragma: no cover
    """Render *fig* in the Streamlit page, then release it from pyplot.

    Figures created via ``plt.subplots`` stay registered with pyplot until
    closed; without the close they would accumulate on every script rerun.
    A ``None`` figure (matplotlib unavailable) is ignored.
    """
    if fig is None:
        return
    st.pyplot(fig)
    plt.close(fig)


def main():  # pragma: no cover
    """Streamlit entry point.

    Sidebar: choose between the built-in demo cohort and a single synthesized
    WSI. Single-WSI mode synthesizes one WSI from the sidebar parameters,
    analyzes it and shows the overlay, the M1/M2 donut and key metrics.
    Cohort mode analyzes every WSI of the demo cohort, shows the summary
    table, a violin plot plus ANOVA result per metric, and offers the table
    as a JSON download. Prints a hint and returns if Streamlit is missing.
    """
    if not _HAS_ST:
        print("Streamlit이 설치되지 않았습니다. requirements.txt 참조.")
        return

    # Radio options, bound once so the comparisons below cannot drift from
    # the labels shown in the UI.
    mode_cohort = "내장 demo cohort"
    mode_single = "단일 WSI 합성"

    st.set_page_config(page_title="AdipoCLSPolarMap v0", layout="wide")
    st.title("AdipoCLSPolarMap v0")
    st.caption("지방조직 CLS + ATM M1/M2 polarization 정량 (합성 mock 데이터 시연)")
    st.warning("본 도구는 연구·참고용입니다. 임상 진단 사용 금지.")

    with st.sidebar:
        st.header("입력 / 설정")
        mode = st.radio("데이터 소스", [mode_cohort, mode_single])
        if mode == mode_cohort:
            run_btn = st.button("Demo cohort 분석 실행", type="primary")
            wsi_size = 384
        else:
            group = st.selectbox("Group", list(synth_adipose.GROUP_PARAMS.keys()))
            depot = st.selectbox("Depot", list(synth_adipose.DEPOTS))
            mouse_id = st.text_input("Mouse ID", "demo_m1")
            wsi_size = st.slider("WSI size (px)", 256, 768, 384, step=64)
            seed = st.number_input("Seed", value=42, step=1)
            run_btn = st.button("단일 WSI 분석 실행", type="primary")

        st.markdown("---")
        st.markdown("**채널 매핑**")
        st.code(", ".join(synth_adipose.CHANNELS))

    if not run_btn:
        st.info("좌측에서 데이터 소스를 선택 후 '실행' 버튼을 눌러주세요.")
        return

    if mode == mode_single:
        wsi = synth_adipose.synthesize_wsi(group=group, depot=depot,
                                           mouse_id=mouse_id,
                                           size=int(wsi_size), seed=int(seed))
        result = analyze_wsi(wsi)
        col1, col2 = st.columns([2, 1])
        with col1:
            st.subheader("(a-c) WSI + adipocyte/CLS overlay")
            _show_figure(_render_overlay(wsi, result))
        with col2:
            st.subheader("(d) ATM M1/M2 polarization")
            _show_figure(_render_polar_donut(result))
            st.metric("Adipocytes (n)", result["n_adipocytes"])
            st.metric("Dead adipocytes", result["n_dead"])
            st.metric("CLS count", result["n_cls"])
            st.metric("CLS density (/mm²)", f"{result['cls_density_per_mm2']:.2f}")
            st.metric("Polarization index", f"{result['polarization_index']:+.3f}")
        return

    # Cohort mode
    cohort = synth_adipose.build_demo_cohort(size=wsi_size)
    total = len(cohort)
    records: List[Dict] = []
    progress = st.progress(0.0, text="cohort 분석 중...")
    for done, wsi in enumerate(cohort, start=1):
        records.append(analyze_wsi(wsi))
        progress.progress(done / total, text=f"{done}/{total}")
    progress.empty()

    st.subheader("Cohort summary table")
    # Underscore-prefixed keys hold raw stage objects; keep only the scalars.
    table_rows = [{k: v for k, v in r.items() if not k.startswith("_")} for r in records]
    st.dataframe(table_rows, use_container_width=True)

    anovas = cohort_anovas(records)
    st.subheader("(e) Group ANOVA")
    cols = st.columns(len(anovas))
    for col, (metric, ar) in zip(cols, anovas.items()):
        with col:
            _show_figure(_render_violin(records, metric, metric))
            st.markdown(f"**{metric}**\n\nF = {ar.f_stat:.3f}, p = {ar.p_value:.4f}")

    # Export
    st.subheader("Export (한국어 라벨)")
    # Pipeline metrics may arrive as NumPy scalars (e.g. np.int64), which the
    # json module cannot encode on its own; _json_default converts them.
    json_blob = json.dumps(table_rows, ensure_ascii=False, indent=2,
                           default=_json_default)
    st.download_button("결과 JSON 다운로드", json_blob,
                       file_name="adipo_cls_cohort.json", mime="application/json")


# Script entry point: launch the UI only when this file is executed directly,
# not when it is imported (e.g. to use the headless analysis helpers).
if __name__ == "__main__":  # pragma: no cover
    main()
