"""Streamlit dashboard — BariERASRecov-Kor.

비만수술 perioperative ERAS recovery cohort dashboard.

Run:
    pip install -r requirements.txt
    streamlit run app.py

For research / synthetic data only. NOT for clinical decision making.
"""
from __future__ import annotations

import os
import sys

# User-facing disclaimer (Korean, followed by an English summary) that main()
# shows through st.warning() directly below the page title.
DISCLAIMER = ("본 대시보드는 합성 데이터 기반 참고용·연구용입니다. "
              "임상 의사결정에 직접 사용하지 마십시오.  "
              "(For research / synthetic data only — NOT for clinical decision.)")


def _missing_deps_panel(err: Exception):
    """Report a missing dashboard dependency on stderr and abort the process.

    Args:
        err: The import failure caught while loading the dashboard's
            third-party packages; its message is echoed to the user.

    Raises:
        SystemExit: Always, with exit status 1, so that shells and CI jobs
            see the failed start-up as a failure rather than a success.
    """
    print(f"[BariERASRecov-Kor] streamlit 의존성 누락: {err}", file=sys.stderr)
    print("가상환경에서 `pip install -r requirements.txt` 후 다시 실행하세요.",
          file=sys.stderr)
    # The dashboard could not start; a zero status would wrongly signal success.
    sys.exit(1)


def main() -> None:
    """Render the BariERASRecov-Kor Streamlit dashboard.

    Steps, in order:

    1. Import the third-party UI stack lazily; on ImportError delegate to
       ``_missing_deps_panel`` (which terminates the process).
    2. Make the script directory importable and load the project ``modules``
       package (ingest, eras, procedure, hypo, outpatient).
    3. Draw the page header, the disclaimer, and the sidebar controls
       (CSV directory, synthetic-data regeneration, procedure/ward filters).
    4. Load all cohort tables with ``ingest.load_all`` and restrict every
       per-patient table to the patients that pass the sidebar filters.
    5. Render the KPI cards and the seven analysis tabs.

    Returns early, after showing an error, when ``patients.csv`` is not
    present in the selected data directory.
    """
    try:
        import streamlit as st
        import pandas as pd
        import plotly.express as px
        import plotly.graph_objects as go
    except ImportError as e:
        _missing_deps_panel(e)
        return

    here = os.path.dirname(os.path.abspath(__file__))
    # main() runs again on every Streamlit rerun; inserting unconditionally
    # would append one more duplicate sys.path entry per interaction.
    if here not in sys.path:
        sys.path.insert(0, here)

    from modules import ingest, eras, procedure, hypo, outpatient

    st.set_page_config(
        page_title="BariERASRecov-Kor",
        page_icon="🏥", layout="wide",
    )
    st.title("BariERASRecov-Kor — 비만수술 perioperative ERAS recovery dashboard")
    st.caption("RYGB · SG · OAGB · SADI · DJB | ERAS Society · MBSAQIP · "
               "ASMBS · KASMBS · KDA")
    st.warning(DISCLAIMER)

    # -- Sidebar: data load -----------------------------------------------
    st.sidebar.header("데이터")
    data_dir = st.sidebar.text_input("CSV 디렉토리", value=os.path.join(here, "data"))
    if st.sidebar.button("Synthetic data 재생성 (n=400, seed=42)"):
        with st.spinner("생성 중..."):
            ingest.generate_synthetic(n_patients=400, out_dir=data_dir, seed=42)
        st.sidebar.success("data/ 재생성 완료")

    # Without the patient table nothing below can be rendered; stop here.
    if not os.path.exists(os.path.join(data_dir, "patients.csv")):
        st.error("patients.csv 없음. 사이드바에서 synthetic data 생성 또는 "
                 "`python3 main.py --gen-data` 실행하세요.")
        return

    patients, intraop, pod03, pod430, pod90, hypo_ev, irep = ingest.load_all(data_dir)

    # -- Sidebar: filters -------------------------------------------------
    st.sidebar.header("필터")
    procs = sorted({p.procedure for p in patients})
    sel_procs = st.sidebar.multiselect("술식", procs, default=procs)
    wards = sorted({p.ward for p in patients})
    sel_wards = st.sidebar.multiselect("Ward", wards, default=wards)

    patients_f = [p for p in patients
                  if p.procedure in sel_procs and p.ward in sel_wards]
    # Every other table is filtered by membership in the surviving patient ids.
    keep = {p.patient_id for p in patients_f}
    intraop_f = [r for r in intraop if r.patient_id in keep]
    pod03_f = [r for r in pod03 if r.patient_id in keep]
    pod430_f = [r for r in pod430 if r.patient_id in keep]
    pod90_f = [r for r in pod90 if r.patient_id in keep]
    hypo_f = [r for r in hypo_ev if r.patient_id in keep]

    # -- KPI cards --------------------------------------------------------
    n_patients = len(patients_f)
    # Clamp to 1 so the cards show 0 instead of raising ZeroDivisionError
    # when the filters exclude every patient.
    denom = max(1, n_patients)
    c1, c2, c3, c4, c5 = st.columns(5)
    c1.metric("환자 N", n_patients)
    c2.metric("Mean BMI", round(sum(p.bmi_pre for p in patients_f) / denom, 1))
    n_leak = sum(1 for r in pod03_f if r.leak or r.staple_leak)
    c3.metric("30d Leak %", round(100 * n_leak / denom, 2))
    n_readmit = sum(1 for r in pod430_f if r.readmit_30d)
    c4.metric("30d Readmission %", round(100 * n_readmit / denom, 2))
    n_mort = sum(1 for p in patients_f if p.died_30d)
    c5.metric("30d Mortality %", round(100 * n_mort / denom, 2))

    # -- Tabs -------------------------------------------------------------
    tabs = st.tabs([
        "1) ERAS bundle radar",
        "2) MBSAQIP measures",
        "3) Procedure stratification",
        "4) Hypo + dumping",
        "5) POD0-90 outpatient",
        "6) KM survival",
        "7) Raw data",
    ])

    # 1) ERAS radar: per-ward compliance table plus one polar trace per ward.
    with tabs[0]:
        st.subheader("ERAS Society 비만수술 protocol — ward별 radar")
        bundles = eras.compute_patient_bundles(
            patients_f, intraop_f, pod03_f, pod430_f)
        radar = eras.ward_radar(bundles)
        rdf = pd.DataFrame([{
            "ward": w.ward, "N": w.n_patients,
            "preop": w.preop_pct, "intraop": w.intraop_pct,
            "POD0-3": w.pod03_pct, "POD4-30": w.pod430_pct,
            "overall": w.overall_pct,
        } for w in radar])
        st.dataframe(rdf, use_container_width=True)

        fig = go.Figure()
        categories = ["preop", "intraop", "POD0-3", "POD4-30"]
        for w in radar:
            fig.add_trace(go.Scatterpolar(
                r=[w.preop_pct, w.intraop_pct, w.pod03_pct, w.pod430_pct],
                theta=categories, fill="toself", name=w.ward,
            ))
        fig.update_layout(
            polar=dict(radialaxis=dict(visible=True, range=[0, 100])),
            showlegend=True, height=500,
        )
        st.plotly_chart(fig, use_container_width=True)

    # 2) MBSAQIP: measure table and bar chart coloured by PASS/WATCH flag.
    with tabs[1]:
        st.subheader("MBSAQIP / ASMBS analog quality measures")
        m = eras.mbsaqip_measures(patients_f, pod03_f, pod430_f)
        mdf = pd.DataFrame([{
            "measure": x.measure, "rate %": x.rate_pct,
            "events": x.n_events, "N": x.n_denominator,
            "target %": x.target_pct,
            # A measure passes when its observed rate is at or below target.
            "flag": ("PASS" if x.rate_pct <= x.target_pct else "WATCH"),
        } for x in m])
        st.dataframe(mdf, use_container_width=True)
        # An empty frame has no "measure" column, which would make px.bar
        # raise; skip the chart like the other tabs do.
        if not mdf.empty:
            fig = px.bar(mdf, x="measure", y="rate %", color="flag",
                         text="rate %",
                         color_discrete_map={"PASS": "#4CAF50", "WATCH": "#E53935"})
            fig.update_layout(xaxis_tickangle=-30, height=420)
            st.plotly_chart(fig, use_container_width=True)

    # 3) Procedure stratification: endpoint rates and O/E ratios per procedure.
    with tabs[2]:
        st.subheader("5+ 술식별 stratification")
        rows = procedure.stratify_by_procedure(patients_f, pod03_f, pod430_f)
        df = pd.DataFrame([r.__dict__ for r in rows])
        st.dataframe(df, use_container_width=True)
        if not df.empty:
            fig = px.bar(df, x="procedure",
                         y=["leak_rate_pct", "marginal_ulcer_pct",
                            "dumping_any_pct", "readmit_30d_pct"],
                         barmode="group", height=420,
                         labels={"value": "% of N", "variable": "endpoint"})
            st.plotly_chart(fig, use_container_width=True)
            oe_cols = ["oe_leak", "oe_readmit", "oe_mortality"]
            oe_df = df[["procedure"] + oe_cols].copy()
            st.markdown("**Risk-adjusted O/E ratio (vs. literature anchor)**")
            st.dataframe(oe_df, use_container_width=True)

    # 4) Hypoglycemia / dumping: event counts by time bucket and by procedure.
    with tabs[3]:
        st.subheader("post-bariatric hypoglycemia + dumping syndrome 시간 분포")
        bucket = hypo.time_distribution(hypo_f)
        bdf = pd.DataFrame([r.__dict__ for r in bucket])
        st.dataframe(bdf, use_container_width=True)
        if not bdf.empty:
            fig = px.bar(bdf, x="time_bucket", y="n_events",
                         color="hypo_type", barmode="stack", height=380)
            st.plotly_chart(fig, use_container_width=True)
        pdf = pd.DataFrame([r.__dict__ for r in hypo.by_procedure(
            patients_f, hypo_f)])
        st.markdown("**By procedure**")
        st.dataframe(pdf, use_container_width=True)

    # 5) POD0-90 outpatient: visit adherence per procedure + transition summary.
    with tabs[4]:
        st.subheader("POD0-90 outpatient follow-up & transition")
        adh = outpatient.adherence_by("procedure", patients_f, pod90_f)
        adh_df = pd.DataFrame([r.__dict__ for r in adh])
        st.dataframe(adh_df, use_container_width=True)
        if not adh_df.empty:
            # Reshape wide visit columns to long form for a grouped bar chart.
            long_df = adh_df.melt(
                id_vars=["key"],
                value_vars=["pct_pod7", "pct_pod30", "pct_pod60", "pct_pod90"],
                var_name="visit", value_name="adherence_pct")
            fig = px.bar(long_df, x="key", y="adherence_pct",
                         color="visit", barmode="group", height=420,
                         labels={"key": "procedure"})
            st.plotly_chart(fig, use_container_width=True)
        trans = outpatient.transition_summary(patients_f, pod430_f)
        st.markdown(f"**30-day readmission**: {trans.n_readmit_30d}/"
                    f"{trans.n_total} ({trans.readmit_rate_pct}%)  |  "
                    f"**mortality**: {trans.n_mort_30d} "
                    f"({trans.mort_rate_pct}%)  |  "
                    f"median readmit day: {trans.median_time_to_readmit_d}")
        st.write("readmit reason mix:", trans.readmit_reason_mix)

    # 6) KM survival: step table and curve for 30-day readmission.
    with tabs[5]:
        st.subheader("Kaplan-Meier (30-day readmission)")
        km = outpatient.kaplan_meier_step(
            patients_f, pod430_f, endpoint="readmit", horizon_d=30)
        kdf = pd.DataFrame(km, columns=["day", "S(t)", "n_at_risk"])
        st.dataframe(kdf, use_container_width=True)
        fig = px.line(kdf, x="day", y="S(t)", height=380,
                      title="30-day survival free of readmission")
        fig.update_layout(yaxis=dict(range=[0.0, 1.0]))
        st.plotly_chart(fig, use_container_width=True)

    # 7) Raw: first 50 filtered patients and hypo events, plus the de-id method.
    with tabs[6]:
        st.subheader("Raw data (de-identified)")
        st.markdown(f"**de-id**: {irep.deid_method}")
        st.markdown("**patients** (head)")
        st.dataframe(pd.DataFrame([p.__dict__ for p in patients_f[:50]]),
                     use_container_width=True)
        st.markdown("**hypo events** (head)")
        st.dataframe(pd.DataFrame([e.__dict__ for e in hypo_f[:50]]),
                     use_container_width=True)


# Build the dashboard only when this file is executed as a script (see the
# module docstring for the launch command), not when it is imported.
if __name__ == "__main__":
    main()
