"""CholangioDuctReact v0 — Streamlit 진입점.

MASLD/MASH/PBC/PSC/BDL/CDAA/CDAHFD 마우스 모델 간 CK19/SOX9/EpCAM IHC +
paired H&E/Sirius Red WSI → portal tract 기반 ductular reaction 정량 +
type 분류 standalone 도구.

본 도구는 연구·참고용. 임상 진단 사용 금지.
"""
from __future__ import annotations

import io
import json
from pathlib import Path
from typing import List

import numpy as np

try:
    import streamlit as st
except Exception as e:  # pragma: no cover
    st = None  # type: ignore

from lib import synth_liver, registration, portal_detect, dr_classify, fibrosis_overlay, stats


# Directory that contains this script; data paths below are resolved from it.
PROJECT_DIR = Path(__file__).parent
# Location of the bundled demo cohort JSON (not read by this module itself).
DEMO_PATH = PROJECT_DIR.joinpath("data", "demo_cohort.json")


def _slide_to_rgb_he(slide: synth_liver.SynthSlide) -> np.ndarray:
    """Convert the slide's pseudo-H&E array into a displayable uint8 image.

    ``slide.he`` is multiplied by 255, clipped to [0, 255] and then truncated
    (not rounded) to ``np.uint8``.
    """
    pixels = slide.he * 255
    pixels = pixels.clip(0, 255)
    return pixels.astype(np.uint8)


def _gray_to_rgb(gray: np.ndarray, color=(0.0, 0.6, 0.0)) -> np.ndarray:
    """Tint a single-channel map with a pseudo-IHC colour (e.g. CK19 = brown, SOX9 = blue).

    Each output channel ``c`` is the linear blend ``(1 - gray) + gray * color[c]``:
    white where ``gray == 0`` and ``color`` where ``gray == 1``. Only the first
    three entries of ``color`` are used (the default is a green tint). The
    blend is stored as float32, clipped to [0, 1], scaled by 255 and truncated
    to a uint8 RGB image with shape ``(*gray.shape, 3)``.
    """
    planes = [(1 - gray) + gray * color[channel] for channel in (0, 1, 2)]
    blended = np.stack(planes, axis=-1).astype(np.float32)
    return (np.clip(blended, 0, 1) * 255).astype(np.uint8)


def _ck19_colored(ck19: np.ndarray) -> np.ndarray:
    """Render a CK19 intensity map as a brown, DAB-like RGB uint8 image.

    Channel ``c`` is ``1 - k[c] * ck19`` with ``k = (0.4, 0.7, 0.9)`` for
    R, G, B: zero intensity stays white, and higher intensity darkens blue the
    most and red the least, producing a brown tint. Values are stored as
    float32, clipped to [0, 1], scaled by 255 and truncated to uint8.
    """
    attenuation = (0.4, 0.7, 0.9)  # per-channel darkening for R, G, B
    planes = [1 - k * ck19 for k in attenuation]
    stained = np.stack(planes, axis=-1).astype(np.float32)
    return (np.clip(stained, 0, 1) * 255).astype(np.uint8)


def _sirius_colored(sirius: np.ndarray) -> np.ndarray:
    """Render a Sirius Red intensity map as a red-stain RGB uint8 image.

    Channel ``c`` is ``1 - k[c] * sirius`` with ``k = (0.05, 0.7, 0.7)`` for
    R, G, B, so stained regions keep most of their red and lose green/blue.
    Values are stored as float32, clipped to [0, 1], scaled by 255 and
    truncated to uint8.
    """
    attenuation = (0.05, 0.7, 0.7)  # per-channel darkening for R, G, B
    planes = [1 - k * sirius for k in attenuation]
    stained = np.stack(planes, axis=-1).astype(np.float32)
    return (np.clip(stained, 0, 1) * 255).astype(np.uint8)


def _draw_overlay(rgb: np.ndarray, dets, drs) -> np.ndarray:
    """Draw portal-tract and DR-structure ring markers on a copy of ``rgb``.

    The input array is not modified. ``rgb`` must have a trailing channel axis
    of length 3, because markers are written as RGB triples.

    Args:
        rgb: base image (in this module it comes from ``_slide_to_rgb_he``).
        dets: portal tract detections; each needs ``cx``, ``cy`` and
            ``pv_radius``. Drawn as yellow rings of radius
            ``max(pv_radius + 4, 12)``.
        drs: DR structures; each needs ``cx``, ``cy``, ``size`` and
            ``dr_type``. Drawn as 1-px-thick rings of radius
            ``max(size + 1, 4)``, coloured green (T1), blue (T2), red (T3)
            or grey (any other type).

    Returns:
        The annotated copy of ``rgb``.
    """
    out = rgb.copy()
    h, w = out.shape[:2]

    def ring(cx, cy, r, color, thick=2):
        # Paint every pixel whose centre lies within [r - thick, r] (+/-0.5 px)
        # of (cx, cy), restricted to the image. A distance mask has no holes,
        # whereas sampling the circumference every 4 degrees left gaps once the
        # arc step exceeded ~1 px (r > ~14 px). Float centres/radii are accepted.
        r_out = float(r)
        r_in = max(r_out - thick, 0.0)
        x0 = max(int(np.floor(cx - r_out)) - 1, 0)
        x1 = min(int(np.ceil(cx + r_out)) + 2, w)
        y0 = max(int(np.floor(cy - r_out)) - 1, 0)
        y1 = min(int(np.ceil(cy + r_out)) + 2, h)
        if x0 >= x1 or y0 >= y1:
            return  # ring lies entirely outside the image
        yy, xx = np.mgrid[y0:y1, x0:x1]
        dist = np.hypot(xx - cx, yy - cy)
        mask = (dist >= r_in - 0.5) & (dist <= r_out + 0.5)
        # Basic slicing yields a view, so the masked assignment writes into `out`.
        out[y0:y1, x0:x1][mask] = color

    for d in dets:
        ring(d.cx, d.cy, max(d.pv_radius + 4, 12), (255, 200, 0), thick=2)  # PT outline
    color_map = {"T1": (0, 200, 0), "T2": (0, 100, 220), "T3": (220, 30, 30)}
    for s in drs:
        ring(s.cx, s.cy, max(s.size + 1, 4), color_map.get(s.dr_type, (180, 180, 180)), thick=1)
    return out


def _build_cohort_rows(slides: List[synth_liver.SynthSlide]) -> List[stats.CohortRow]:
    """Run the per-slide analysis pipeline and return one summary row per slide.

    For each slide, in input order:
      1. detect portal tracts from the H&E and CK19 channels;
      2. classify DR structures around them (empty list when no tract found);
      3. measure the Sirius Red collagen area fraction;
      4. detect portal-to-portal bridges (empty list when no tract found);
    and pass everything to ``stats.per_slide_summary``.
    """

    def summarize(slide: synth_liver.SynthSlide) -> stats.CohortRow:
        tracts = portal_detect.detect_portal_tracts(slide.he, slide.ck19)
        if tracts:
            structures = dr_classify.classify_dr_structures(slide.ck19, tracts)
        else:
            structures = []
        collagen = fibrosis_overlay.collagen_area_fraction(slide.sirius)
        if tracts:
            links = fibrosis_overlay.detect_bridges(tracts, slide.sirius, slide.ck19)
        else:
            links = []
        return stats.per_slide_summary(
            slide_id=slide.slide_id,
            mouse_id=slide.mouse_id,
            group=slide.group,
            portal_tracts=tracts,
            dr_structures=structures,
            sirius_area_pct=collagen,
            bridges=links,
        )

    return [summarize(slide) for slide in slides]


def _render_sidebar() -> tuple[List[synth_liver.SynthSlide], int]:
    """Render the cohort sidebar; return the session's slides and the selected index."""
    with st.sidebar:
        st.header("코호트")
        st.write("합성 데모 코호트: 3그룹 × 3마우스")
        # The button stores a freshly generated cohort list (same seed);
        # otherwise the cohort is generated once per session.
        if st.button("데모 코호트 로드/재생성"):
            st.session_state["slides"] = synth_liver.make_demo_cohort(seed=42)
        if "slides" not in st.session_state:
            st.session_state["slides"] = synth_liver.make_demo_cohort(seed=42)
        slides: List[synth_liver.SynthSlide] = st.session_state["slides"]
        slide_labels = [f"{s.group} | {s.mouse_id} | {s.slide_id}" for s in slides]
        idx = st.selectbox("슬라이드 선택", range(len(slides)), format_func=lambda i: slide_labels[i])
        st.markdown("---")
        st.subheader("디스클레이머")
        st.warning("본 도구는 연구·참고용. 임상 진단 사용 금지.")
    return slides, idx


def _render_paired_wsi_tab(sl: synth_liver.SynthSlide) -> None:
    """Tab 1: show the pseudo H&E, CK19 (DAB) and Sirius Red images side by side."""
    st.subheader(f"{sl.group} / {sl.mouse_id} / {sl.slide_id}")
    c1, c2, c3 = st.columns(3)
    with c1:
        st.image(_slide_to_rgb_he(sl), caption="H&E (pseudo)", use_container_width=True)
    with c2:
        st.image(_ck19_colored(sl.ck19), caption="CK19 IHC (DAB)", use_container_width=True)
    with c3:
        st.image(_sirius_colored(sl.sirius), caption="Sirius Red", use_container_width=True)
    st.caption("SOX9 채널은 CK19와 거의 겹쳐 보이지만 별도 채널로 보존됨.")


def _render_registration_tab(sl: synth_liver.SynthSlide) -> None:
    """Tab 2: shift CK19 by a user-chosen offset, re-register it and show all three images."""
    st.subheader("Rigid translation registration (합성 offset 시연)")
    offset_x = st.slider("CK19 offset X", -15, 15, 6)
    offset_y = st.slider("CK19 offset Y", -15, 15, -4)
    moved = registration.apply_translation(sl.ck19, offset_x, offset_y)
    aligned, (dx, dy) = registration.register_pair(sl.ck19, moved)
    c1, c2, c3 = st.columns(3)
    c1.image(_ck19_colored(sl.ck19), caption="Reference CK19", use_container_width=True)
    c2.image(_ck19_colored(moved), caption=f"Misaligned (dx={offset_x}, dy={offset_y})", use_container_width=True)
    c3.image(_ck19_colored(aligned), caption=f"Aligned (estimated dx={dx}, dy={dy})", use_container_width=True)


def _render_portal_dr_tab(sl: synth_liver.SynthSlide):
    """Tab 3: detect portal tracts, classify DR structures, draw the overlay and type counts.

    Returns the portal tract detections so the fibrosis tab can reuse them
    instead of running detection a second time on the same inputs.
    """
    import pandas as pd

    st.subheader("Portal tract auto-detection + DR Type 분류")
    dets = portal_detect.detect_portal_tracts(sl.he, sl.ck19)
    drs = dr_classify.classify_dr_structures(sl.ck19, dets) if dets else []
    overlay = _draw_overlay(_slide_to_rgb_he(sl), dets, drs)
    c1, c2 = st.columns([2, 1])
    with c1:
        st.image(overlay, caption=f"검출 portal tract: {len(dets)} / DR structure: {len(drs)}", use_container_width=True)
    with c2:
        t1 = sum(1 for d in drs if d.dr_type == "T1")
        t2 = sum(1 for d in drs if d.dr_type == "T2")
        t3 = sum(1 for d in drs if d.dr_type == "T3")
        st.metric("Type1 (reactive ductule)", t1)
        st.metric("Type2 (intermediate hepatocyte)", t2)
        st.metric("Type3 (mass-like metaplasia)", t3)
        # Index the bars by type so the x-axis reads T1/T2/T3 instead of 0/1/2.
        st.bar_chart(pd.DataFrame({"count": [t1, t2, t3]}, index=["T1", "T2", "T3"]))
        st.caption("색상: 노랑=portal tract, 초록=T1, 파랑=T2, 빨강=T3")
    return dets


def _render_fibrosis_tab(sl: synth_liver.SynthSlide, dets) -> None:
    """Tab 4: collagen area fraction, H&E + Sirius Red overlay and portal-to-portal bridge table."""
    st.subheader("Sirius Red Collagen + Bridging 분석")
    sir_pct = fibrosis_overlay.collagen_area_fraction(sl.sirius)
    st.metric("Collagen area (%)", f"{sir_pct * 100:.2f}")
    overlay_rgb = fibrosis_overlay.make_overlay_rgb(sl.he, sl.sirius)
    # Clip before the uint8 cast, as the other image helpers do: floats outside
    # [0, 1] would otherwise map to out-of-range values with no defined uint8 result.
    overlay_u8 = (np.clip(overlay_rgb, 0, 1) * 255).astype(np.uint8)
    st.image(overlay_u8, caption="H&E + Sirius Red overlay", use_container_width=True)
    bridges = fibrosis_overlay.detect_bridges(dets, sl.sirius, sl.ck19) if dets else []
    if not bridges:
        st.info("검출된 portal tract 쌍이 충분하지 않아 bridge 분석 결과 없음.")
        return

    import pandas as pd

    df = pd.DataFrame(
        [
            {
                "PT_a": b.pt_a,
                "PT_b": b.pt_b,
                "거리(px)": round(b.distance, 1),
                "Sirius_score": round(b.sirius_along_score, 3),
                "CK19_score": round(b.ck19_along_score, 3),
                "타입": b.bridge_type,
            }
            for b in bridges
        ]
    )
    st.dataframe(df, use_container_width=True)


def _cohort_rows_for(slides: List[synth_liver.SynthSlide]) -> List[stats.CohortRow]:
    """Return cohort rows for ``slides``, recomputing only when the slide list object changes."""
    # Streamlit reruns the whole script on every widget interaction; without
    # this memo, moving e.g. a registration slider re-analysed every slide.
    cached = st.session_state.get("_cohort_rows_cache")
    if cached is None or cached[0] is not slides:
        cached = (slides, _build_cohort_rows(slides))
        st.session_state["_cohort_rows_cache"] = cached
    return cached[1]


def _render_cohort_tab(slides: List[synth_liver.SynthSlide]) -> None:
    """Tab 5: per-slide summary table, DR/portal ANOVA, plots and CSV/JSON exports."""
    import pandas as pd

    st.subheader("Cohort 통계 + ANOVA")
    rows = _cohort_rows_for(slides)
    df = pd.DataFrame([r.__dict__ for r in rows])
    st.dataframe(df, use_container_width=True)

    groups: dict = {}
    for r in rows:
        groups.setdefault(r.group, []).append(r.dr_per_portal)
    anova = stats.one_way_anova(groups)
    st.write(f"DR per portal tract one-way ANOVA: F={anova.get('F', float('nan')):.3f}, p={anova.get('p', float('nan'))}")

    # Plotting is optional: a missing plotly (or a plotting failure) is reported, not raised.
    try:
        import plotly.express as px

        fig = px.violin(df, x="group", y="dr_per_portal", box=True, points="all", title="그룹별 DR / Portal Tract")
        st.plotly_chart(fig, use_container_width=True)
        fig2 = px.bar(
            df.melt(id_vars=["group", "mouse_id"], value_vars=["t1_count", "t2_count", "t3_count"]),
            x="mouse_id",
            y="value",
            color="variable",
            facet_col="group",
            title="DR Type 분포 (마우스별)",
        )
        st.plotly_chart(fig2, use_container_width=True)
    except Exception as e:
        st.info(f"plotly 미설치 또는 오류: {e}")

    # Export
    out_csv = df.to_csv(index=False).encode("utf-8")
    st.download_button("CSV 다운로드", out_csv, file_name="cholangio_cohort.csv", mime="text/csv")
    st.download_button(
        "JSON 다운로드",
        json.dumps(stats.cohort_to_dict(rows), ensure_ascii=False, indent=2).encode("utf-8"),
        file_name="cholangio_cohort.json",
        mime="application/json",
    )


def run_streamlit() -> None:
    """Build the CholangioDuctReact Streamlit UI: a cohort sidebar plus five analysis tabs.

    If Streamlit could not be imported (``st is None``), prints a hint and returns.
    """
    if st is None:
        print("Streamlit이 설치되어 있지 않습니다. requirements.txt 참조.")
        return
    st.set_page_config(page_title="CholangioDuctReact v0", layout="wide")
    st.title("CholangioDuctReact v0")
    st.caption("MASLD 동물실험 — Ductular Reaction 정량/분류 도구 (연구·참고용. 임상 진단 사용 금지)")

    slides, idx = _render_sidebar()
    sl = slides[idx]
    tabs = st.tabs(["Paired WSI", "Registration", "Portal & DR", "Fibrosis Overlay", "Cohort 통계"])

    with tabs[0]:
        _render_paired_wsi_tab(sl)
    with tabs[1]:
        _render_registration_tab(sl)
    with tabs[2]:
        dets = _render_portal_dr_tab(sl)
    with tabs[3]:
        # Reuse tab 3's detections (same he/ck19 inputs) rather than detecting twice.
        _render_fibrosis_tab(sl, dets)
    with tabs[4]:
        _render_cohort_tab(slides)


if __name__ == "__main__":
    # Script entry point: build the UI when this file is executed as the main
    # module rather than imported.
    run_streamlit()
