"""PancIsletMass v0 — Streamlit 진입점"""

from __future__ import annotations

import csv
import io
import json
import os
from typing import Dict, List

import numpy as np

try:
    import streamlit as st
except Exception as e:  # pragma: no cover - streamlit는 실행 시 필요
    st = None  # type: ignore
    _ST_IMPORT_ERR = e
else:
    _ST_IMPORT_ERR = None

try:
    import plotly.express as px
    import plotly.graph_objects as go
except Exception:
    px = None  # type: ignore
    go = None  # type: ignore

from lib import synth_wsi, islet_detect, cell_classify, mass_compute, stats as cstats


# Absolute directory of this script; every data/output path below hangs off it.
HERE = os.path.dirname(os.path.abspath(__file__))
# Input data root and the demo slide folder inside it.
DATA_DIR = os.path.join(HERE, "data")
SAMPLE_DIR = os.path.join(HERE, "data", "sample_wsi")
# Destination for uploaded slides and exported CSV files.
OUTPUT_DIR = os.path.join(HERE, "output")

# Bilingual (Korean / English) research-use disclaimer shown in the UI.
DISCLAIMER_KO = (
    "본 도구는 연구·참고용입니다. 임상 진단 사용 금지. "
    "(For research use only — not for clinical diagnosis.)"
)


def _ensure_demo_cohort() -> Dict:
    """Return the demo cohort metadata, generating the cohort when it is missing.

    If ``SAMPLE_DIR/demo_cohort.json`` exists it is parsed and returned;
    otherwise the result of ``synth_wsi.build_demo_cohort(SAMPLE_DIR)`` is
    returned.
    """
    cohort_json = os.path.join(SAMPLE_DIR, "demo_cohort.json")
    if not os.path.exists(cohort_json):
        # No cached metadata yet: let the synthetic generator build the cohort.
        return synth_wsi.build_demo_cohort(SAMPLE_DIR)
    with open(cohort_json, "r", encoding="utf-8") as handle:
        return json.load(handle)


def _channels_to_rgb(channels: np.ndarray, channel_names) -> np.ndarray:
    """Compose insulin(R), glucagon(G), somatostatin(B) RGB preview."""
    idx_ins = channel_names.index("insulin")
    idx_glu = channel_names.index("glucagon")
    idx_sst = channel_names.index("somatostatin")
    rgb = np.stack(
        [channels[idx_ins], channels[idx_glu], channels[idx_sst]], axis=-1
    )
    rgb = np.clip(rgb, 0, 1)
    return (rgb * 255).astype(np.uint8)


def _overlay_islets(rgb: np.ndarray, label_image: np.ndarray) -> np.ndarray:
    """Draw simple cyan boundary over RGB."""
    out = rgb.copy()
    h, w = label_image.shape
    edges = np.zeros((h, w), dtype=bool)
    edges[:-1, :] |= label_image[:-1, :] != label_image[1:, :]
    edges[:, :-1] |= label_image[:, :-1] != label_image[:, 1:]
    out[edges] = [0, 255, 255]
    return out


def run_pipeline_for_slide(
    slide_path: str,
    slide_meta: Dict,
    *,
    image_um_per_px: float = 1.0,
) -> Dict:
    """Run islet detection, cell classification and mass bookkeeping on one slide.

    Args:
        slide_path: Path of the slide ``.npz`` file, loaded with
            ``synth_wsi.load_slide_npz``.
        slide_meta: Slide metadata; ``slide_id``, ``group``, ``mouse_id`` and
            ``pancreas_weight_mg`` are forwarded to ``mass_compute.slide_record``.
        image_um_per_px: Pixel size (micrometres per pixel) forwarded to
            ``mass_compute.slide_record``. Keyword-only; the default 1.0 is the
            value that used to be hard-coded, so existing callers see no change.

    Returns:
        Dict with the loaded ``channels``, the islet ``label_image`` and
        ``rois`` from ``islet_detect.detect_islets``, the per-islet ``comps``
        from ``cell_classify.classify_islet_cells``, and the slide ``record``.

    Raises:
        ValueError: If ``image_um_per_px`` is not a positive number (NaN included).
    """
    if not image_um_per_px > 0:
        raise ValueError(f"image_um_per_px must be positive, got {image_um_per_px!r}")

    channels, _centers = synth_wsi.load_slide_npz(slide_path)
    label_image, rois = islet_detect.detect_islets(channels, synth_wsi.CHANNELS)
    comps = cell_classify.classify_islet_cells(channels, synth_wsi.CHANNELS, label_image)

    # The full image extent (axes 1 and 2, presumably H x W of a C x H x W
    # stack; TODO confirm) is used as the pancreas section area.
    # NOTE(review): this includes any background pixels around the tissue.
    section_area_px = int(channels.shape[1] * channels.shape[2])

    # Sum the per-islet pixel areas of each endocrine cell type.
    sum_b = sum(c.beta_area_px for c in comps)
    sum_a = sum(c.alpha_area_px for c in comps)
    sum_d = sum(c.delta_area_px for c in comps)

    rec = mass_compute.slide_record(
        slide_id=slide_meta["slide_id"],
        group=slide_meta["group"],
        mouse_id=slide_meta["mouse_id"],
        pancreas_weight_mg=slide_meta["pancreas_weight_mg"],
        image_um_per_px=image_um_per_px,
        section_area_px=section_area_px,
        beta_px=sum_b,
        alpha_px=sum_a,
        delta_px=sum_d,
    )
    return {
        "channels": channels,
        "label_image": label_image,
        "rois": rois,
        "comps": comps,
        "record": rec,
    }


def run_pipeline_for_cohort(cohort: Dict) -> Dict[str, List[mass_compute.SlideMassRecord]]:
    """Run the single-slide pipeline over every slide in ``cohort["slides"]``.

    Each slide entry supplies an ``npz`` path relative to ``SAMPLE_DIR`` and
    the metadata consumed by ``run_pipeline_for_slide``.

    Returns:
        Mapping from ``mouse_id`` to that mouse's slide records, kept in the
        order the slides appear in the cohort.
    """
    records_by_mouse: Dict[str, List[mass_compute.SlideMassRecord]] = {}
    for slide in cohort["slides"]:
        outcome = run_pipeline_for_slide(os.path.join(SAMPLE_DIR, slide["npz"]), slide)
        mouse_id = slide["mouse_id"]
        if mouse_id not in records_by_mouse:
            records_by_mouse[mouse_id] = []
        records_by_mouse[mouse_id].append(outcome["record"])
    return records_by_mouse


def _save_uploaded_npz(up) -> tuple[str, Dict]:  # pragma: no cover - UI
    """Persist a Streamlit-uploaded ``.npz`` into OUTPUT_DIR and build its metadata.

    Returns ``(path, meta)``. The metadata uses the placeholder ``"user"``
    group/mouse ids and a fixed 180.0 mg pancreas weight.
    """
    os.makedirs(OUTPUT_DIR, exist_ok=True)
    # NOTE(review): a single fixed filename is shared by every session, so
    # concurrent uploads overwrite each other.
    tmp_path = os.path.join(OUTPUT_DIR, "user_upload.npz")
    with open(tmp_path, "wb") as f:
        # getvalue() returns the whole buffer regardless of the current stream
        # position; read() would return b"" if the file had already been read.
        f.write(up.getvalue())
    meta = {
        "slide_id": "user-upload",
        "group": "user",
        "mouse_id": "user",
        "pancreas_weight_mg": 180.0,
        "npz": os.path.relpath(tmp_path, SAMPLE_DIR),
    }
    return tmp_path, meta


def _render_slide_view(result: Dict, channel_names) -> None:  # pragma: no cover - UI
    """Show previews, the per-islet table and the mass summary for one slide result."""
    col1, col2 = st.columns(2)
    rgb = _channels_to_rgb(result["channels"], channel_names)
    with col1:
        st.markdown("**채널 합성 RGB (R=insulin, G=glucagon, B=somatostatin)**")
        st.image(rgb, use_column_width=True)
    with col2:
        st.markdown("**검출 islet overlay (cyan boundary)**")
        st.image(_overlay_islets(rgb, result["label_image"]), use_column_width=True)

    with st.expander("개별 채널"):
        ch_cols = st.columns(3)
        for i, name in enumerate(channel_names):
            with ch_cols[i % 3]:
                arr = (np.clip(result["channels"][i], 0, 1) * 255).astype(np.uint8)
                st.markdown(f"`{name}`")
                st.image(arr, use_column_width=True, clamp=True)

    rows = cell_classify.composition_to_dicts(result["comps"])
    st.markdown("**검출/분류 결과 테이블**")
    st.dataframe(rows, use_container_width=True)

    rec = result["record"]
    st.markdown("**슬라이드 mass 요약 (Cavalieri 단일 절편)**")
    st.json(
        {
            "pancreas_section_area_um2": rec.pancreas_section_area_um2,
            "beta_area_um2": rec.beta_area_um2,
            "alpha_area_um2": rec.alpha_area_um2,
            "delta_area_um2": rec.delta_area_um2,
            "beta_fraction": round(rec.beta_fraction, 5),
        }
    )


def _beta_mass_by_group(summary) -> Dict[str, List[float]]:
    """Collect each summary row's ``beta_mass_mg`` under its ``group``, preserving row order."""
    values_by_group: Dict[str, List[float]] = {}
    for row in summary:
        values_by_group.setdefault(row["group"], []).append(row["beta_mass_mg"])
    return values_by_group


def _write_cohort_csv(summary, csv_path: str) -> None:
    """Write the per-mouse cohort summary rows to ``csv_path``.

    The ``csv`` module quotes fields that contain commas, quotes or newlines,
    so free-text ids or group names can no longer shift columns. The header,
    column order and Unix line endings match the original hand-written format.
    """
    columns = (
        "mouse_id",
        "group",
        "pancreas_weight_mg",
        "n_slides",
        "beta_fraction",
        "beta_mass_mg",
    )
    # newline="" lets csv.writer control line endings itself.
    with open(csv_path, "w", encoding="utf-8", newline="") as f:
        writer = csv.writer(f, lineterminator="\n")
        writer.writerow(columns)
        writer.writerows([row[col] for col in columns] for row in summary)


def _render_cohort_analysis(cohort: Dict) -> None:  # pragma: no cover - UI
    """Run the whole cohort, then show summary tables and stats, plot, and export CSV."""
    st.divider()
    st.subheader("코호트 분석")
    with st.spinner("전체 슬라이드 처리 중..."):
        by_mouse = run_pipeline_for_cohort(cohort)
    summary = mass_compute.cohort_summary(by_mouse)
    st.dataframe(summary, use_container_width=True)

    # Group-level β-cell mass distribution.
    values_by_group = _beta_mass_by_group(summary)
    gstats = cstats.group_stats(values_by_group)
    st.markdown("**그룹 통계 (β-cell mass mg)**")
    st.dataframe(
        [
            {"group": g.group, "n": g.n, "mean": round(g.mean, 4), "sd": round(g.sd, 4), "sem": round(g.sem, 4)}
            for g in gstats
        ],
        use_container_width=True,
    )
    f_stat, p_value = cstats.one_way_anova(values_by_group)
    st.markdown(f"**One-way ANOVA**: F={f_stat:.3f}, p={p_value}")
    pairs = cstats.pairwise_ttests(values_by_group)
    st.dataframe(pairs, use_container_width=True)

    # plotly is optional; skip the violin plot when it failed to import.
    if go is not None:
        fig = go.Figure()
        for g, vals in values_by_group.items():
            fig.add_trace(go.Violin(y=vals, name=g, box_visible=True, meanline_visible=True))
        fig.update_layout(title="β-cell mass per mouse (mg)", yaxis_title="β-cell mass (mg)")
        st.plotly_chart(fig, use_container_width=True)

    os.makedirs(OUTPUT_DIR, exist_ok=True)
    csv_path = os.path.join(OUTPUT_DIR, "cohort_summary.csv")
    _write_cohort_csv(summary, csv_path)
    st.success(f"CSV 저장: {csv_path}")


def _render_streamlit() -> None:  # pragma: no cover - UI
    """Render the PancIsletMass Streamlit page.

    The sidebar offers a data source (demo cohort or uploaded ``.npz``), a
    group filter, a slide selector and a button that triggers whole-cohort
    analysis. The main area shows the selected slide, followed by the cohort
    analysis when the button was pressed in this run.
    """
    assert st is not None
    st.set_page_config(page_title="PancIsletMass v0", layout="wide")
    st.title("PancIsletMass v0 — β-cell mass morphometry")
    st.caption(DISCLAIMER_KO)

    cohort = _ensure_demo_cohort()
    slides = cohort["slides"]
    groups = cohort["groups"]
    channel_names = tuple(cohort["channels"])

    with st.sidebar:
        st.header("입력 / Input")
        mode = st.radio(
            "데이터 소스",
            ["데모 코호트 (Demo cohort)", "사용자 업로드 (.npz)"],
            index=0,
        )
        st.markdown("**Group filter**")
        sel_groups = st.multiselect("그룹 선택", groups, default=list(groups))
        # Fall back to all slides when the filter matches none, so the
        # selectbox is never empty.
        slide_ids = [
            s["slide_id"] for s in slides if s["group"] in sel_groups
        ] or [s["slide_id"] for s in slides]
        sel_slide_id = st.selectbox("슬라이드 선택", slide_ids)
        run_cohort = st.button("코호트 전체 분석 실행", type="primary")

    sel_meta = next((s for s in slides if s["slide_id"] == sel_slide_id), slides[0])

    slide_path = None
    if mode.startswith("사용자"):
        up = st.sidebar.file_uploader(
            "channels CxHxW float32 npz 업로드 (key='channels')", type=["npz"]
        )
        if up is not None:
            slide_path, sel_meta = _save_uploaded_npz(up)
    if slide_path is None:
        # Demo mode, or upload mode before a file is provided.
        slide_path = os.path.join(SAMPLE_DIR, sel_meta["npz"])

    st.subheader(f"슬라이드: {sel_meta['slide_id']}  (group={sel_meta['group']})")
    result = run_pipeline_for_slide(slide_path, sel_meta)
    _render_slide_view(result, channel_names)

    if run_cohort:
        _render_cohort_analysis(cohort)

    st.divider()
    st.caption("출처: 2026-05-10 daily-ideas / PancIsletMass v0 — " + DISCLAIMER_KO)


def main() -> None:
    """Render the app, or print install hints when Streamlit could not be imported."""
    if st is not None:
        _render_streamlit()
        return
    print("Streamlit이 설치되어 있지 않습니다. pip install streamlit numpy scipy plotly scikit-image")
    # Surface the original import failure when one was captured.
    if _ST_IMPORT_ERR is not None:
        print("import error:", _ST_IMPORT_ERR)


# Script entry point. main() prints install hints when Streamlit is not
# importable, and renders the app otherwise.
# NOTE(review): relies on `streamlit run` executing this module as __main__;
# confirm for the deployed Streamlit version.
if __name__ == "__main__":
    main()
