"""CirrDecompUnit-Kor Streamlit dashboard.

Run:
    streamlit run app.py

Standalone offline. Uses synthetic CSVs from data/.
For research / QI use only. NOT for clinical decision making.
"""
from __future__ import annotations
import os
import sys
from collections import defaultdict
from typing import Dict, List

# Absolute directory of this script; the paths below are built from it.
HERE = os.path.dirname(os.path.abspath(__file__))
# Put the script directory first on the import path before the project-local
# `modules` import further down (presumably `modules/` sits next to this file).
sys.path.insert(0, HERE)
# Directory holding the CSV inputs read by load_data().
DATA_DIR = os.path.join(HERE, "data")

# If any third-party dependency fails to import, print an install hint plus
# the underlying error and exit with status 1.
try:
    import streamlit as st
    import pandas as pd
    import plotly.express as px
except Exception as e:  # pragma: no cover
    print("Streamlit/pandas/plotly are required. Install requirements.txt.")
    print(e)
    sys.exit(1)

# Project-local helpers; imported after the sys.path tweak above, which is why
# the late-import lint (E402) is silenced.
from modules import kpi_report, lt_candidacy  # noqa: E402


# ---------------------------------------------------------------------------
@st.cache_data
def load_data() -> Dict[str, pd.DataFrame]:
    """Load every dashboard CSV from ``DATA_DIR`` into a dict of DataFrames.

    Returns:
        A dict keyed by ``"patients"``, ``"episodes"``, ``"labs"``,
        ``"adherence"``, ``"lt"`` and ``"outcomes"``. A table whose file is
        missing or contains no data at all is returned as an empty
        ``DataFrame`` so callers can test ``.empty`` instead of handling
        exceptions.
    """
    def _read(name: str) -> pd.DataFrame:
        """Read one CSV from DATA_DIR, or return an empty frame if unusable."""
        path = os.path.join(DATA_DIR, name)
        try:
            return pd.read_csv(path)
        except (FileNotFoundError, pd.errors.EmptyDataError):
            # Missing file, or a zero-byte file that pandas cannot parse:
            # both mean "no data yet", which the caller reports as a warning.
            return pd.DataFrame()

    return {
        "patients": _read("patients.csv"),
        "episodes": _read("episodes.csv"),
        "labs": _read("lab_trajectory.csv"),
        "adherence": _read("protocol_adherence.csv"),
        "lt": _read("lt_waitlist.csv"),
        "outcomes": _read("outcomes.csv"),
    }


# ACLF grades in ascending severity. One shared list drives both the sidebar
# filter options and the chart category order so the two cannot drift apart.
_ACLF_GRADES = ["no ACLF", "ACLF-1", "ACLF-2", "ACLF-3"]


def _show(fig) -> None:
    """Render a Plotly figure with ``use_container_width=True``."""
    st.plotly_chart(fig, use_container_width=True)


def _count_table(frame: pd.DataFrame, column: str,
                 order: List[str] | None = None) -> pd.DataFrame:
    """Return a two-column (*column*, ``"n"``) frequency table for *column*.

    When *order* is given, the table lists exactly those categories in that
    order, with a count of 0 for categories absent from *frame*.
    """
    counts = frame[column].value_counts()
    if order is not None:
        counts = counts.reindex(order, fill_value=0)
    table = counts.reset_index()
    table.columns = [column, "n"]
    return table


def _rows_for_episodes(table: pd.DataFrame, flt: pd.DataFrame) -> pd.DataFrame:
    """Return the rows of *table* whose ``episode_id`` occurs in *flt*.

    An empty *table* (e.g. its CSV was not found) is returned unchanged,
    because it has no ``episode_id`` column to join on.
    """
    if table.empty:
        return table
    return table.merge(flt[["episode_id"]], on="episode_id")


def _apply_sidebar_filters(epi: pd.DataFrame) -> pd.DataFrame:
    """Draw the sidebar filter widgets and return the episodes they select."""
    def _select_all(label: str, column: str):
        # Options are the distinct non-null values; everything is selected
        # by default.
        options = sorted(epi[column].dropna().unique())
        return st.multiselect(label, options, default=options)

    with st.sidebar:
        st.header("Filters")
        etio_sel = _select_all("Etiology", "etiology")
        type_sel = _select_all("Decompensation type", "decomp_type")
        aclf_sel = st.multiselect("ACLF grade", _ACLF_GRADES,
                                  default=_ACLF_GRADES)
        ward_sel = _select_all("Ward", "ward")

    return epi[
        epi["etiology"].isin(etio_sel)
        & epi["decomp_type"].isin(type_sel)
        & epi["aclf_grade"].isin(aclf_sel)
        & epi["ward"].isin(ward_sel)
    ]


def _render_kpis(flt: pd.DataFrame) -> None:
    """Show the five headline metrics for the filtered episodes."""
    c1, c2, c3, c4, c5 = st.columns(5)
    c1.metric("환자", flt["patient_id"].nunique())
    c2.metric("Episodes", len(flt))
    c3.metric("MELD 3.0 평균", f"{flt['meld3'].mean():.1f}")
    c4.metric("30일 사망률", f"{flt['mortality_30d'].mean():.1%}")
    c5.metric("90일 재입원율", f"{flt['readmission_90d'].mean():.1%}")


def _render_cohort_tab(flt: pd.DataFrame) -> None:
    """Tab 1: etiology mix, decompensation types and per-ward load."""
    st.subheader("Etiology mix")
    _show(px.bar(_count_table(flt, "etiology"), x="etiology", y="n",
                 color="etiology"))

    st.subheader("Decompensation type")
    _show(px.bar(_count_table(flt, "decomp_type"), x="decomp_type", y="n",
                 color="decomp_type"))

    st.subheader("Per-ward episode load")
    ward_ct = flt.groupby(["ward", "decomp_type"]).size().reset_index(name="n")
    _show(px.bar(ward_ct, x="ward", y="n", color="decomp_type"))


def _render_aclf_tab(flt: pd.DataFrame, labs: pd.DataFrame) -> None:
    """Tab 2: ACLF grade charts, MELD score violins and one CLIF-SOFA curve."""
    st.subheader("EASL-CLIF ACLF grade distribution")
    g_ct = _count_table(flt, "aclf_grade", order=_ACLF_GRADES)
    _show(px.bar(g_ct, x="aclf_grade", y="n", color="aclf_grade",
                 category_orders={"aclf_grade": _ACLF_GRADES}))

    st.subheader("30-day mortality by ACLF grade")
    # Grades with no episodes in the selection come out of reindex() as NaN.
    mort = flt.groupby("aclf_grade")["mortality_30d"].mean().reindex(_ACLF_GRADES)
    mort_df = mort.reset_index().rename(columns={"mortality_30d": "rate"})
    _show(px.bar(mort_df, x="aclf_grade", y="rate",
                 labels={"rate": "30d mortality"}))

    st.subheader("MELD / MELD-Na / MELD 3.0 distribution")
    scores = flt[["meld", "meld_na", "meld3"]].melt(var_name="score",
                                                     value_name="value")
    _show(px.violin(scores, x="score", y="value", box=True))

    st.subheader("CLIF-SOFA daily trajectory (sample episode)")
    # The sample is simply the first filtered episode. Reading the column
    # first avoids building a mixed-dtype row just to pull out one value.
    sample_eid = flt["episode_id"].iloc[0]
    traj = labs if labs.empty else labs[labs["episode_id"] == sample_eid]
    if traj.empty:
        st.info(f"No lab trajectory available for episode {sample_eid}.")
    else:
        _show(px.line(traj, x="day", y="clif_sofa", markers=True,
                      title=f"Episode {sample_eid}"))


def _render_adherence_tab(flt: pd.DataFrame) -> None:
    """Tab 3: protocol adherence by type, against mortality, and by ward."""
    st.subheader("Protocol adherence by decomp type")
    adh = flt.groupby("decomp_type")["protocol_adherence"].mean().reset_index()
    _show(px.bar(adh, x="decomp_type", y="protocol_adherence",
                 color="decomp_type"))

    st.subheader("Adherence vs 30-day mortality")
    agg = flt.groupby("decomp_type").agg(
        adherence=("protocol_adherence", "mean"),
        mortality=("mortality_30d", "mean"),
        n=("episode_id", "count"),
    ).reset_index()
    _show(px.scatter(agg, x="adherence", y="mortality", size="n",
                     color="decomp_type", text="decomp_type"))

    st.subheader("Ward ranking (mean adherence)")
    ranking = (flt.groupby("ward")["protocol_adherence"].mean()
               .sort_values(ascending=False).reset_index())
    st.dataframe(ranking.style.format({"protocol_adherence": "{:.2%}"}),
                 use_container_width=True)


def _render_lt_tab(flt: pd.DataFrame, waitlist: pd.DataFrame) -> None:
    """Tab 4: KONOS band pie and the first 50 matching waitlist rows."""
    st.subheader("KONOS priority band")
    lt = _rows_for_episodes(waitlist, flt)
    if lt.empty:
        # Say so explicitly rather than drawing an empty pie and table.
        st.info("No LT waitlist records for the current selection.")
    else:
        _show(px.pie(_count_table(lt, "konos_band"), names="konos_band",
                     values="n"))
        st.dataframe(lt.head(50), use_container_width=True)
    st.markdown("MELD 3.0 ≥ 15 또는 ACLF-1/2/3 → LT 평가 대상. "
                "OPTN MELD 3.0 (2023-07) 및 KONOS 가이드 호환 (proxy).")


def _render_followup_tab(flt: pd.DataFrame, outcomes: pd.DataFrame) -> None:
    """Tab 5 (upper part): follow-up visit and discharge-medication rates."""
    st.subheader("Post-discharge follow-up rate")
    matched = _rows_for_episodes(outcomes, flt)
    if matched.empty:
        # Column means over zero rows are NaN, which would plot as blank bars.
        st.info("No outcome records for the current selection.")
        return

    visit_cols = {"POD7": "pod7_visit", "POD30": "pod30_visit",
                  "POD90": "pod90_visit"}
    v_df = pd.DataFrame({
        "visit": list(visit_cols),
        "rate": [matched[col].mean() for col in visit_cols.values()],
    })
    _show(px.bar(v_df, x="visit", y="rate"))

    st.subheader("Discharge medications")
    med_cols = ["nsbb_on_discharge", "diuretic_on_discharge",
                "rifaximin_on_discharge", "sbp_prophylaxis_on_discharge",
                "resmetirom_on_discharge"]
    med_df = pd.DataFrame({"med": med_cols,
                           "rate": [matched[c].mean() for c in med_cols]})
    _show(px.bar(med_df, x="med", y="rate"))


def _render_report(flt: pd.DataFrame) -> None:
    """Tab 5 (lower part): text report produced by the ``kpi_report`` module."""
    st.subheader("Generate KASL/KLTS-style QI report")
    kpis = kpi_report.episode_kpis(flt.to_dict("records"))
    lang = st.radio("언어", ["kor", "eng"], horizontal=True)
    st.code(kpi_report.render_text_report(kpis, lang=lang), language="markdown")


def main() -> None:
    """Build the dashboard page: filters, KPI row, five tabs and a footer.

    Stops after a warning when no episode data is loaded or when the sidebar
    filters leave no episodes; every chart below assumes at least one row.
    """
    st.set_page_config(page_title="CirrDecompUnit-Kor", layout="wide")
    st.title("시르데컴프유닛코어 (CirrDecompUnit-Kor)")
    st.caption("MASLD/MASH cirrhosis decompensation inpatient QI dashboard "
               "— 참고용·연구용 (research / QI only)")

    data = load_data()
    if data["episodes"].empty:
        st.warning("No data found. Run: `python3 main.py --gen`")
        return

    flt = _apply_sidebar_filters(data["episodes"])
    if flt.empty:
        st.warning("No episodes after filtering.")
        return

    _render_kpis(flt)

    tab1, tab2, tab3, tab4, tab5 = st.tabs([
        "1. Cohort / etiology",
        "2. ACLF / MELD trajectory",
        "3. Protocol adherence",
        "4. LT candidacy / KONOS",
        "5. Post-discharge & report",
    ])
    with tab1:
        _render_cohort_tab(flt)
    with tab2:
        _render_aclf_tab(flt, data["labs"])
    with tab3:
        _render_adherence_tab(flt)
    with tab4:
        _render_lt_tab(flt, data["lt"])
    with tab5:
        _render_followup_tab(flt, data["outcomes"])
        _render_report(flt)

    st.divider()
    st.caption("Synthetic data only. de-identification: PHI columns (name/RRN/MRN) "
               "dropped at ingest. 참고용·연구용 only — NOT for clinical decision making.")


# Build the page only when this file is executed as a script (the module
# docstring gives `streamlit run app.py`); importing it has no UI side effect.
if __name__ == "__main__":
    main()
