"""
PostLTMASHMetabolicKor (포스트엘티매시메타볼릭코어)
간이식 후 cohort 의 MASH 재발 / NODAT / 면역억제제 / 항당뇨제 RWE Streamlit dashboard.

실행: streamlit run app.py
"""

import json
import os
from io import BytesIO

import numpy as np
import pandas as pd
import streamlit as st

# ---- 선택적 의존성 (런타임 import) ----------------------------------------
try:
    import plotly.express as px
    import plotly.graph_objects as go
    _HAS_PLOTLY = True
except Exception:  # pragma: no cover
    _HAS_PLOTLY = False

try:
    from lifelines import KaplanMeierFitter
    _HAS_LIFELINES = True
except Exception:  # pragma: no cover
    _HAS_LIFELINES = False

try:
    import statsmodels.api as sm
    _HAS_STATSMODELS = True
except Exception:  # pragma: no cover
    _HAS_STATSMODELS = False

try:
    from docx import Document
    _HAS_DOCX = True
except Exception:  # pragma: no cover
    _HAS_DOCX = False


# ---- Paths -----------------------------------------------------------------
# All input files are resolved relative to the directory holding this script.
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
COHORT_CSV, LONGI_CSV = (
    os.path.join(BASE_DIR, "data", csv_name)
    for csv_name in ("synthetic_post_lt_cohort.csv", "synthetic_longitudinal.csv")
)
GUIDE_JSON = os.path.join(BASE_DIR, "assets", "guidelines.json")


# ---- 데이터 로더 ---------------------------------------------------------
@st.cache_data(show_spinner=False)
def load_cohort():
    """Read the cohort table from ``COHORT_CSV``.

    Wrapped in ``st.cache_data`` with the loading spinner disabled.

    Returns:
        The DataFrame that ``pd.read_csv`` parses from the file.
    """
    cohort_table = pd.read_csv(COHORT_CSV)
    return cohort_table


@st.cache_data(show_spinner=False)
def load_longi():
    """Read the longitudinal table from ``LONGI_CSV``.

    Wrapped in ``st.cache_data`` with the loading spinner disabled.

    Returns:
        The DataFrame that ``pd.read_csv`` parses from the file.
    """
    longitudinal_table = pd.read_csv(LONGI_CSV)
    return longitudinal_table


@st.cache_data(show_spinner=False)
def load_guidelines():
    """Parse the guideline catalogue stored at ``GUIDE_JSON``.

    The file is read as UTF-8 text and decoded with the ``json`` module; the
    call is wrapped in ``st.cache_data`` with the loading spinner disabled.

    Returns:
        The decoded JSON document.
    """
    with open(GUIDE_JSON, mode="r", encoding="utf-8") as guide_file:
        raw_text = guide_file.read()
    return json.loads(raw_text)


# ---- Page configuration ----------------------------------------------------
# Module-level call, so it executes on every script run before main() renders
# any widget.
# NOTE(review): Streamlit has historically required set_page_config to be the
# first Streamlit command on a page — confirm for the pinned version before
# moving this call.
st.set_page_config(
    page_title="PostLTMASHMetabolicKor",
    page_icon="LT",
    layout="wide",
    initial_sidebar_state="expanded",
)

# ---- Medical disclaimer ----------------------------------------------------
# Markdown text that header() displays through st.warning under the page title.
DISCLAIMER = ", ".join(
    (
        "본 대시보드는 **이식 RWE 연구·quality improvement 보조용** 합성 데이터 도구이며",
        "실제 임상 의사결정은 이식센터 다학제팀 판단 하에 이루어져야 합니다.",
    )
)


def header():
    """Render the page title, the scope caption and the disclaimer banner."""
    page_title = "PostLTMASHMetabolicKor — 포스트엘티매시메타볼릭코어"
    scope_caption = (
        "간이식 후 MASH 재발 · NODAT/PTDM · 면역억제제 trough/dose vs metabolic AE · "
        "GLP-1RA/SGLT2i/metformin RWE · KLTF/KASL 가이드 호환 리포트"
    )
    st.title(page_title)
    st.caption(scope_caption)
    # The disclaimer is shown as a warning box so it stands out from the caption.
    st.warning(DISCLAIMER)


def sidebar_filters(cohort: pd.DataFrame):
    """Draw the sidebar filter widgets and return the matching cohort rows.

    Four multiselects (centre, indication, regimen, sex) default to every value
    present in ``cohort``; two range sliders (age at LT, follow-up months)
    default to the observed min/max. A row is kept only when it satisfies all
    six filters.

    Args:
        cohort: Cohort table with ``center``, ``indication``, ``is_regimen``,
            ``sex``, ``age_at_lt`` and ``followup_months`` columns.

    Returns:
        A copy of the rows passing every filter. The resulting row count is
        also written to the sidebar.
    """
    sidebar = st.sidebar
    sidebar.header("필터")

    # Categorical filters, in the order they appear in the sidebar.
    categorical_filters = (
        ("center", "이식센터"),
        ("indication", "LT 적응증"),
        ("is_regimen", "면역억제제 regimen"),
        ("sex", "성별"),
    )
    row_masks = []
    for column, widget_label in categorical_filters:
        choices = sorted(cohort[column].unique().tolist())
        picked = sidebar.multiselect(widget_label, choices, default=choices)
        row_masks.append(cohort[column].isin(picked))

    # Age slider works on floats, follow-up slider on whole months.
    age_lo = float(cohort["age_at_lt"].min())
    age_hi = float(cohort["age_at_lt"].max())
    age_window = sidebar.slider("이식 시 연령", age_lo, age_hi, (age_lo, age_hi))
    row_masks.append(cohort["age_at_lt"].between(age_window[0], age_window[1]))

    fu_lo = int(cohort["followup_months"].min())
    fu_hi = int(cohort["followup_months"].max())
    fu_window = sidebar.slider("추적기간(개월)", fu_lo, fu_hi, (fu_lo, fu_hi))
    row_masks.append(cohort["followup_months"].between(fu_window[0], fu_window[1]))

    # AND the individual masks together, left to right.
    keep = row_masks[0]
    for extra_mask in row_masks[1:]:
        keep = keep & extra_mask

    selected = cohort[keep].copy()
    sidebar.markdown(f"**선택 cohort: n = {len(selected)}**")
    return selected


# ============================================================================
# Tab 1: post-LT timeline + MASH 재발 NIT
# ============================================================================
def tab_mash_nit(cohort: pd.DataFrame, longi: pd.DataFrame):
    """Render tab 1: NIT trajectories and MASH recurrence overview.

    Shows four headline percentages, one selected patient's longitudinal
    VCTE LSM / MRI-PDFF / FIB-4 curves, a centre x time-bin heatmap of
    recurrence timing, and a fibrosis-stage histogram.

    Args:
        cohort: Filtered cohort table. Columns read: ``mash_recurrence``,
            ``vcte_lsm_kPa``, ``mri_pdff_pct``, ``fib4``, ``patient_id``,
            ``mash_recur_month``, ``center``, ``biopsy_stage``, ``indication``.
        longi: Longitudinal table with ``patient_id``, ``month_from_lt`` and
            the three NIT columns.

    Returns:
        None. Output is written to the active Streamlit container. Returns
        early (after the metrics) when ``cohort`` has no patients.
    """
    st.subheader("1) post-LT timeline + MASH 재발 NIT 모듈")
    st.markdown(
        "LT=0 기준 longitudinal **VCTE LSM · MRI-PDFF · FIB-4** trajectory와 "
        "재발 fibrosis stage(F0~F4)·환자별 재발 timing heatmap."
    )

    c1, c2, c3, c4 = st.columns(4)
    c1.metric("MASH 재발 (%)", f"{cohort['mash_recurrence'].mean() * 100:.1f}")
    c2.metric("VCTE LSM ≥ 8 kPa (%)", f"{(cohort['vcte_lsm_kPa'] >= 8).mean() * 100:.1f}")
    c3.metric("MRI-PDFF ≥ 5 % (%)", f"{(cohort['mri_pdff_pct'] >= 5).mean() * 100:.1f}")
    c4.metric("FIB-4 ≥ 2.67 (%)", f"{(cohort['fib4'] >= 2.67).mean() * 100:.1f}")

    st.markdown("---")
    ids = cohort["patient_id"].tolist()
    if not ids:
        st.info("선택된 cohort 없음.")
        return

    # --- single-patient longitudinal trajectory ---
    sel_pid = st.selectbox("환자 선택 (longitudinal)", ids, index=0)
    pl = longi[longi["patient_id"] == sel_pid].sort_values("month_from_lt")
    if _HAS_PLOTLY and not pl.empty:
        fig = go.Figure()
        for col, label in [
            ("vcte_lsm_kPa", "VCTE LSM (kPa)"),
            ("mri_pdff_pct", "MRI-PDFF (%)"),
            ("fib4", "FIB-4"),
        ]:
            fig.add_trace(
                go.Scatter(
                    x=pl["month_from_lt"], y=pl[col], mode="lines+markers", name=label
                )
            )
        fig.update_layout(
            height=380,
            xaxis_title="개월 (LT=0)",
            yaxis_title="NIT 값",
            legend=dict(orientation="h"),
        )
        st.plotly_chart(fig, use_container_width=True)
    else:
        # Tabular fallback when plotly is unavailable or the patient has no rows.
        st.dataframe(pl)

    # --- recurrence timing heatmap ---
    st.markdown("#### 재발 timing heatmap (1y / 3y / 5y bin)")
    recur = cohort[cohort["mash_recurrence"] == 1].copy()
    if len(recur):
        # include_lowest keeps month 0 inside the first bin, and the open-ended
        # last edge keeps very late recurrences inside ">5y". With the previous
        # closed [0, 200] range and pd.cut's default right-closed intervals,
        # both fell outside every bin and were silently dropped by the groupby.
        bins = pd.cut(
            recur["mash_recur_month"],
            bins=[0, 12, 36, 60, float("inf")],
            labels=["≤1y", "1-3y", "3-5y", ">5y"],
            include_lowest=True,
        )
        heat = (
            recur.assign(bin=bins)
            .groupby(["center", "bin"], observed=True)
            .size()
            .reset_index(name="n")
        )
        if _HAS_PLOTLY:
            pv = heat.pivot(index="center", columns="bin", values="n").fillna(0)
            fig2 = px.imshow(
                pv, text_auto=True, aspect="auto", color_continuous_scale="Reds"
            )
            fig2.update_layout(height=350, xaxis_title="재발 시점", yaxis_title="센터")
            st.plotly_chart(fig2, use_container_width=True)
        else:
            st.dataframe(heat)
    else:
        st.info("선택 cohort에서 재발 사례 없음.")

    # --- biopsy fibrosis stage distribution ---
    st.markdown("#### biopsy 시행 환자 fibrosis stage 분포")
    # NOTE(review): presumably a negative biopsy_stage encodes "no biopsy" —
    # confirm against the data generator.
    bx = cohort[cohort["biopsy_stage"] >= 0]
    if len(bx) and _HAS_PLOTLY:
        fig3 = px.histogram(
            bx, x="biopsy_stage", color="indication", barmode="stack",
            nbins=5, labels={"biopsy_stage": "Fibrosis stage (F0~F4)"},
        )
        fig3.update_layout(height=330)
        st.plotly_chart(fig3, use_container_width=True)
    else:
        st.write(f"biopsy 시행 n = {len(bx)}")


# ============================================================================
# Tab 2: NODAT/PTDM screening
# ============================================================================
def tab_nodat(cohort: pd.DataFrame):
    """Render tab 2: NODAT/PTDM screening compliance and incidence.

    All figures are computed on the patients with ``pre_lt_dm == 0``:
    headline metrics, screening compliance per timepoint, a Kaplan-Meier
    based cumulative incidence curve (needs lifelines and at least 10
    patients) and the NODAT rate per immunosuppression regimen.

    Args:
        cohort: Filtered cohort table. Columns read: ``pre_lt_dm``, ``nodat``,
            ``hba1c_pct``, ``fpg_mgdl``, ``scr_1m``, ``scr_3m``, ``scr_6m``,
            ``scr_1y``, ``scr_annual``, ``nodat_month``, ``followup_months``,
            ``is_regimen``.

    Returns:
        None. Output is written to the active Streamlit container. Returns
        early with a notice when no patient is free of pre-LT diabetes.
    """
    st.subheader("2) NODAT/PTDM screening 부합률")
    st.markdown(
        "WHO/ADA 진단 기준 자동 분류, screening 시행률 KPI(1m/3m/6m/1y/매년), "
        "5y cumulative incidence Kaplan-Meier, regimen 별 NODAT 위험 비교."
    )
    eligible = cohort[cohort["pre_lt_dm"] == 0].copy()
    if eligible.empty:
        # Every mean below would be NaN and render as "nan"; say so explicitly.
        st.info("pre-LT DM 제외 후 분석 대상 환자 없음.")
        return

    c1, c2, c3, c4 = st.columns(4)
    c1.metric("Pre-LT DM 제외 n", len(eligible))
    c2.metric("NODAT 누적 (%)", f"{eligible['nodat'].mean() * 100:.1f}")
    c3.metric("HbA1c ≥ 6.5 (%)", f"{(eligible['hba1c_pct'] >= 6.5).mean() * 100:.1f}")
    c4.metric("FPG ≥ 126 (%)", f"{(eligible['fpg_mgdl'] >= 126).mean() * 100:.1f}")

    # --- screening compliance per timepoint ---
    st.markdown("#### screening 시행률 (KPI)")
    rates = {
        "1m": eligible["scr_1m"].mean(),
        "3m": eligible["scr_3m"].mean(),
        "6m": eligible["scr_6m"].mean(),
        "1y": eligible["scr_1y"].mean(),
        "annual": eligible["scr_annual"].mean(),
    }
    sr = pd.DataFrame({"timepoint": list(rates.keys()), "compliance": list(rates.values())})
    sr["compliance_pct"] = (sr["compliance"] * 100).round(1)
    if _HAS_PLOTLY:
        fig = px.bar(sr, x="timepoint", y="compliance_pct", text="compliance_pct")
        fig.update_layout(yaxis_title="시행률 (%)", height=320)
        st.plotly_chart(fig, use_container_width=True)
    else:
        st.dataframe(sr)

    # --- cumulative incidence (1 - KM survival) ---
    st.markdown("#### 5y cumulative incidence (Kaplan-Meier)")
    if _HAS_LIFELINES and len(eligible) >= 10:
        kmf = KaplanMeierFitter()
        # Time to event for NODAT cases, otherwise censored at last follow-up.
        T = np.where(eligible["nodat"] == 1, eligible["nodat_month"], eligible["followup_months"])
        E = eligible["nodat"].values
        kmf.fit(T, event_observed=E, label="NODAT cumulative incidence")
        cif = 1 - kmf.survival_function_["NODAT cumulative incidence"]
        if _HAS_PLOTLY:
            fig2 = go.Figure()
            fig2.add_trace(
                go.Scatter(x=cif.index, y=cif.values * 100, mode="lines", name="CIF")
            )
            fig2.update_layout(
                xaxis_title="개월 (LT=0)", yaxis_title="누적 NODAT (%)", height=340
            )
            st.plotly_chart(fig2, use_container_width=True)
        else:
            st.line_chart((cif * 100).rename("NODAT %"))
    else:
        st.info("lifelines 미설치 또는 n 부족.")

    # --- NODAT rate by regimen ---
    st.markdown("#### regimen 별 NODAT 위험 비교")
    grp = (
        eligible.groupby("is_regimen")["nodat"]
        .agg(["mean", "count"])
        .rename(columns={"mean": "NODAT_rate", "count": "n"})
        .reset_index()
    )
    grp["NODAT_rate_pct"] = (grp["NODAT_rate"] * 100).round(1)
    st.dataframe(grp, use_container_width=True)


# ============================================================================
# Tab 3: 면역억제제 trough/dose vs metabolic AE
# ============================================================================
def tab_immuno(cohort: pd.DataFrame, longi: pd.DataFrame):
    """Render tab 3: immunosuppressant exposure vs. metabolic outcomes.

    Lets the user pick one drug exposure column and one metabolic outcome,
    draws their scatter plot (with an OLS trendline when statsmodels is
    importable), prints a univariate OLS summary when at least 10 exposed
    patients exist, and plots tacrolimus trough timelines for the first 12
    patients that have trough values.

    Args:
        cohort: Filtered cohort table holding the exposure columns, the
            outcome columns, ``is_regimen``, ``patient_id``, ``indication``
            and ``center``.
        longi: Longitudinal table with ``patient_id``, ``month_from_lt``,
            ``tac_trough_ng_mL`` and ``is_regimen``.

    Returns:
        None. Output is written to the active Streamlit container.
    """
    st.subheader("3) 면역억제제 trough/dose vs metabolic AE")
    st.markdown(
        "tacrolimus / sirolimus / cyclosporine / steroid trough/dose 와 "
        "HbA1c · 체중 · BP · LDL/TG · 요산의 cross-sectional 상관, 환자별 trough timeline."
    )

    drug = st.radio(
        "약물 선택",
        ["tacrolimus", "cyclosporine", "sirolimus", "steroid"],
        horizontal=True,
    )
    drug_col = {
        "tacrolimus": "tac_trough_ng_mL",
        "cyclosporine": "csa_trough_ng_mL",
        "sirolimus": "siro_trough_ng_mL",
        "steroid": "steroid_dose_mg",
    }[drug]

    metric = st.selectbox(
        "metabolic outcome",
        [
            "hba1c_pct",
            "bmi_current",
            "weight_regain_pct",
            "sbp_mmHg",
            "ldl_mgdl",
            "tg_mgdl",
            "uric_acid_mgdl",
        ],
        index=0,
    )

    # Only patients with a recorded value for the chosen exposure.
    sub = cohort.dropna(subset=[drug_col]).copy()
    if len(sub) < 3:
        st.info("선택 cohort에서 해당 약물 노출 환자 부족.")
    elif _HAS_PLOTLY:
        fig = px.scatter(
            sub, x=drug_col, y=metric, color="is_regimen",
            trendline="ols" if _HAS_STATSMODELS else None,
            hover_data=["patient_id", "indication", "center"],
        )
        fig.update_layout(height=420)
        st.plotly_chart(fig, use_container_width=True)
    else:
        # plotly missing: show the raw pairs instead of wrongly reporting
        # "not enough exposed patients" (same fallback the other tabs use).
        st.dataframe(sub[["patient_id", drug_col, metric]])

    if _HAS_STATSMODELS and len(sub) >= 10:
        # Best-effort summary: any fitting problem is reported inline rather
        # than breaking the tab.
        try:
            X = sm.add_constant(sub[drug_col].astype(float))
            y = sub[metric].astype(float)
            res = sm.OLS(y, X, missing="drop").fit()
            st.write(
                f"OLS 회귀: {metric} ~ {drug_col}  "
                f"β={res.params.iloc[1]:.3f},  p={res.pvalues.iloc[1]:.4f},  "
                f"R²={res.rsquared:.3f},  n={int(res.nobs)}"
            )
        except Exception as e:  # pragma: no cover
            st.write(f"(회귀 오류: {e})")

    st.markdown("#### 환자별 tacrolimus trough timeline (월별)")
    long_tac = longi.dropna(subset=["tac_trough_ng_mL"])
    if len(long_tac) and _HAS_PLOTLY:
        # Cap the plot at the first 12 patients to keep the legend readable.
        sample_pids = long_tac["patient_id"].drop_duplicates().head(12).tolist()
        tac_sample = long_tac[long_tac["patient_id"].isin(sample_pids)]
        fig2 = px.line(
            tac_sample, x="month_from_lt", y="tac_trough_ng_mL", color="patient_id",
            hover_data=["is_regimen"],
        )
        fig2.update_layout(height=380, xaxis_title="개월 (LT=0)", yaxis_title="TAC trough (ng/mL)")
        # Shade the 3-8 ng/mL band as the target range.
        fig2.add_hrect(y0=3, y1=8, fillcolor="green", opacity=0.08, line_width=0)
        st.plotly_chart(fig2, use_container_width=True)
    elif len(long_tac):
        # plotly missing: fall back to the raw trough rows.
        st.dataframe(long_tac)


# ============================================================================
# Tab 4: GLP-1RA / SGLT2i / metformin RWE
# ============================================================================
def tab_antidm(cohort: pd.DataFrame):
    """Render tab 4: users vs. non-users of one antidiabetic drug class.

    For the selected ``use_*`` flag, compares the mean of each metabolic /
    liver / trough metric between users (flag == 1) and non-users (flag == 0),
    skipping metrics where either group has fewer than 3 non-missing values,
    then shows rejection, CMV infection and death rates per group.

    Args:
        cohort: Filtered cohort table holding ``use_glp1ra``, ``use_sglt2i``,
            ``use_metformin``, the compared metric columns, and
            ``rejection_event``, ``infection_cmv``, ``death_event``.

    Returns:
        None. Output is written to the active Streamlit container.
    """
    st.subheader("4) GLP-1RA / SGLT2i / metformin post-LT 사용 RWE")
    st.markdown(
        "사용 vs 미사용군의 metabolic outcome · 간기능 · 면역억제제 trough · 이식 거부 · 감염 비교."
    )

    drug = st.radio("약물", ["use_glp1ra", "use_sglt2i", "use_metformin"], horizontal=True)
    # Work on a copy so the caller's frame is not mutated by the cast below.
    cohort = cohort.copy()
    cohort[drug] = cohort[drug].astype(int)

    metrics = [
        "hba1c_pct",
        "bmi_current",
        "weight_regain_pct",
        "ldl_mgdl",
        "tg_mgdl",
        "alt_uL",
        "ast_uL",
        "ggt_uL",
        "meld_score",
        "tac_trough_ng_mL",
    ]
    is_user = cohort[drug] == 1
    is_nonuser = cohort[drug] == 0
    rows = []
    for m in metrics:
        a = cohort.loc[is_user, m].dropna()
        b = cohort.loc[is_nonuser, m].dropna()
        if len(a) >= 3 and len(b) >= 3:
            user_mean = a.mean()
            nonuser_mean = b.mean()
            rows.append(
                {
                    "metric": m,
                    "user_mean": round(user_mean, 2),
                    "user_n": int(len(a)),
                    "nonuser_mean": round(nonuser_mean, 2),
                    "nonuser_n": int(len(b)),
                    "delta": round(user_mean - nonuser_mean, 2),
                }
            )
    cmp = pd.DataFrame(rows)
    st.dataframe(cmp, use_container_width=True)

    if _HAS_PLOTLY and len(cmp):
        fig = px.bar(cmp, x="metric", y="delta", title="사용군 - 미사용군 (mean delta)")
        fig.update_layout(height=360)
        st.plotly_chart(fig, use_container_width=True)

    st.markdown("#### 이식 거부 / CMV 감염 비교 (rate %)")
    safety = (
        cohort.groupby(drug)[["rejection_event", "infection_cmv", "death_event"]].mean() * 100
    ).round(1)
    # Label rows from the group keys actually present. Assigning a fixed
    # two-item list raised a length-mismatch ValueError whenever the filtered
    # cohort contained only users or only non-users.
    group_labels = {0: "미사용", 1: "사용"}
    safety.index = [group_labels.get(int(key), str(key)) for key in safety.index]
    st.dataframe(safety, use_container_width=True)


# ============================================================================
# Tab 5: KLTF/KASL 가이드 호환 리포트 + multicenter 비교
# ============================================================================
def tab_report(cohort: pd.DataFrame, guidelines: dict):
    """Render tab 5: per-centre KPI table, guideline list and docx export.

    Args:
        cohort: Filtered cohort table. Columns read: ``center``,
            ``patient_id``, ``mash_recurrence``, ``nodat``, ``scr_1y``,
            ``scr_annual``, ``vcte_lsm_kPa``, ``mri_pdff_pct``,
            ``tac_trough_ng_mL``.
        guidelines: Parsed guideline document. Its optional
            ``"recommendations"`` list holds dicts with ``source``,
            ``domain``, ``id``, ``text`` and an optional ``kpi`` key.

    Returns:
        None. Output is written to the active Streamlit container; a download
        button for the generated .docx appears after "리포트 생성" is pressed.
    """
    st.subheader("5) KLTF·KASL post-LT 가이드 호환 리포트 + multicenter 비교")
    st.markdown(
        "선택된 cohort 의 KPI 를 KLTF · KASL · AASLD · EASL 권고와 매핑하고, "
        "6대 이식센터 de-identified 비교 + 국문 docx 리포트를 생성한다."
    )

    # --- per-centre KPI table ---
    def share_pct(flags):
        # Mean of a column expressed as a percentage with one decimal.
        return round(flags.mean() * 100, 1)

    def share_at_least(cutoff):
        # Percentage of values at or above `cutoff`, one decimal.
        return lambda values: round((values >= cutoff).mean() * 100, 1)

    per_center = cohort.groupby("center").agg(
        n=("patient_id", "count"),
        mash_recur_pct=("mash_recurrence", share_pct),
        nodat_pct=("nodat", share_pct),
        scr_1y=("scr_1y", share_pct),
        scr_annual=("scr_annual", share_pct),
        vcte_ge8=("vcte_lsm_kPa", share_at_least(8)),
        pdff_ge5=("mri_pdff_pct", share_at_least(5)),
        tac_mean=("tac_trough_ng_mL", lambda troughs: round(troughs.mean(), 2)),
    )
    kpi = per_center.reset_index()
    st.dataframe(kpi, use_container_width=True)

    if _HAS_PLOTLY:
        long_kpi = kpi.melt(
            id_vars="center", value_vars=["mash_recur_pct", "nodat_pct", "scr_annual"]
        )
        fig = px.bar(long_kpi, x="center", y="value", color="variable", barmode="group")
        fig.update_layout(height=380, yaxis_title="%")
        st.plotly_chart(fig, use_container_width=True)

    # --- guideline recommendations ---
    st.markdown("#### KLTF / KASL / AASLD / EASL 권고 매핑")
    recommendations = guidelines.get("recommendations", [])
    for rec in recommendations:
        with st.expander(f"[{rec['source']}] {rec['domain']} — {rec['id']}"):
            st.write(rec["text"])
            st.caption(f"KPI hook: {rec.get('kpi', 'N/A')}")

    # --- docx export ---
    st.markdown("#### 국문 docx 리포트 생성")
    lang = st.radio("리포트 언어", ["국문", "영문"], horizontal=True)
    if not st.button("리포트 생성"):
        return
    if not _HAS_DOCX:
        st.error("python-docx 미설치. `pip install python-docx` 필요.")
        return

    # Wording per language: (title, disclaimer, summary heading, total label,
    # recurrence label, guideline heading). Anything other than "국문" uses
    # the English wording.
    if lang == "국문":
        wording = (
            "PostLTMASHMetabolicKor 리포트",
            "본 리포트는 합성 데이터 기반이며, 임상 의사결정은 이식센터 다학제팀 판단 하에 이루어져야 합니다.",
            "1. cohort 요약",
            "전체 n",
            "MASH 재발률",
            "2. 가이드 권고",
        )
    else:
        wording = (
            "PostLTMASHMetabolicKor Report",
            "Synthetic data dashboard. Clinical decisions must be made by transplant team.",
            "1. Cohort Summary",
            "Total n",
            "MASH recurrence",
            "2. Guideline Recommendations",
        )
    title, notice, summary_heading, total_label, recur_label, guide_heading = wording

    recur_pct = cohort["mash_recurrence"].mean() * 100
    nodat_pct = cohort["nodat"].mean() * 100

    doc = Document()
    doc.add_heading(title, 0)
    doc.add_paragraph(notice)
    doc.add_heading(summary_heading, 1)
    doc.add_paragraph(f"{total_label} = {len(cohort)}")
    doc.add_paragraph(f"{recur_label}: {recur_pct:.1f}%, NODAT: {nodat_pct:.1f}%")
    doc.add_heading(guide_heading, 1)
    for item in recommendations:
        doc.add_paragraph(f"[{item['source']}] {item['domain']}: {item['text']}")

    # Serialise in memory and rewind so the download starts at byte 0.
    buf = BytesIO()
    doc.save(buf)
    buf.seek(0)
    st.download_button(
        "리포트 다운로드 (.docx)",
        buf,
        file_name=f"post_lt_report_{lang}.docx",
        mime="application/vnd.openxmlformats-officedocument.wordprocessingml.document",
    )


# ---- main ----------------------------------------------------------------
def main():
    """Assemble the dashboard: header, data load, sidebar filters, five tabs.

    Stops after an error message when an input file is missing, and after a
    warning when the sidebar filters leave no patients.
    """
    header()

    try:
        cohort_all = load_cohort()
        longi_all = load_longi()
        guideline_doc = load_guidelines()
    except FileNotFoundError as missing:
        st.error(
            f"데이터 파일을 찾을 수 없습니다: {missing}\n"
            "먼저 `python3 data/generate_synthetic.py` 를 실행하세요."
        )
        return

    selected = sidebar_filters(cohort_all)
    if not len(selected):
        st.warning("선택된 cohort 가 없습니다. 필터를 조정하세요.")
        return

    # Restrict the longitudinal rows to the patients that passed the filters.
    longi_selected = longi_all[longi_all["patient_id"].isin(selected["patient_id"])]

    # Tab label paired with the renderer that fills it, in display order.
    pages = [
        ("1. MASH 재발 NIT", lambda: tab_mash_nit(selected, longi_selected)),
        ("2. NODAT/PTDM", lambda: tab_nodat(selected)),
        ("3. 면역억제제", lambda: tab_immuno(selected, longi_selected)),
        ("4. 항당뇨제 RWE", lambda: tab_antidm(selected)),
        ("5. 리포트/multicenter", lambda: tab_report(selected, guideline_doc)),
    ]
    containers = st.tabs([label for label, _ in pages])
    for container, (_, render) in zip(containers, pages):
        with container:
            render()

    st.markdown("---")
    st.caption(
        "출처: KLTF post-LT care · KASL MASLD CPG · AASLD 2023 MASLD Practice Guidance · "
        "EASL-EASD-EASO 2024 · ADA SoC 2024-2025 · ILTS-AASLD-NASPGHAN 2022 · KONOS 통계."
    )


# Script entry point. The module docstring gives `streamlit run app.py` as the
# launch command.
if __name__ == "__main__":
    main()
