"""
GLP1DiscontRebound-Kor — Streamlit 대시보드 진입점
GLP-1RA·tirzepatide·orforglipron 중단·rebound·재시작 cohort 모니터링.

⚠️ 의학 디스클레이머:
본 대시보드는 연구·post-marketing 보조용 도구이며,
실제 임상 의사결정은 담당 의사 판단 하에 이루어져야 함.
EMR 연동 0, 외부 API 0, 오프라인 동작.
"""
import io
import json
import os
from datetime import datetime

import numpy as np
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
import streamlit as st

try:
    from lifelines import KaplanMeierFitter
    LIFELINES_OK = True
except Exception:
    LIFELINES_OK = False

try:
    from docx import Document
    DOCX_OK = True
except Exception:
    DOCX_OK = False


# All paths are resolved relative to this file, so the dashboard finds its
# data regardless of the current working directory.
BASE = os.path.dirname(os.path.abspath(__file__))
_DATA_DIR = os.path.join(BASE, "data")
_ASSETS_DIR = os.path.join(BASE, "assets")

# Synthetic patient-level cohort table (read by load_cohort).
COHORT_CSV = os.path.join(_DATA_DIR, "synthetic_glp1_cohort.csv")
# Synthetic longitudinal measurements (read by load_longi).
LONGI_CSV = os.path.join(_DATA_DIR, "synthetic_longitudinal.csv")
# Reference curve JSON used for the rebound-envelope overlay (read by load_ref).
REF_JSON = os.path.join(_ASSETS_DIR, "reference_curves.json")


@st.cache_data
def load_cohort() -> pd.DataFrame:
    """Read the synthetic cohort CSV at ``COHORT_CSV`` (memoised by Streamlit)."""
    cohort_df = pd.read_csv(COHORT_CSV)
    return cohort_df


@st.cache_data
def load_longi() -> pd.DataFrame:
    """Read the synthetic longitudinal CSV at ``LONGI_CSV`` (memoised by Streamlit)."""
    longi_df = pd.read_csv(LONGI_CSV)
    return longi_df


@st.cache_data
def load_ref() -> dict:
    """Parse the reference-curve JSON at ``REF_JSON`` (memoised by Streamlit).

    Raises whatever ``open``/``json.load`` raise if the file is missing or
    malformed.
    """
    with open(REF_JSON, encoding="utf-8") as handle:
        payload = json.load(handle)
    return payload


def disclaimer_box():
    """Render the medical-disclaimer warning banner."""
    message = (
        "⚠️ **의학 디스클레이머** — 본 대시보드는 연구·post-marketing 보조용이며, "
        "실제 임상 의사결정은 담당 의사 판단 하에 이루어져야 함. "
        "합성 데이터로 동작하며 실제 환자 정보를 포함하지 않음."
    )
    st.warning(message)


def kpi_row(cohort: pd.DataFrame):
    """Render the five headline KPI metrics for the (filtered) cohort.

    Metrics:
    - total patients;
    - discontinued patients, with a delta of % of total;
    - restarted patients, with a delta of % of discontinued;
    - holiday patients;
    - patients whose ``surgical_conversion`` is not ``"none"``.

    Missing ``restarted``/``holiday`` values count as False. An empty cohort
    (e.g. the sidebar filters exclude everyone) renders zeros instead of
    raising ZeroDivisionError.
    """
    total = len(cohort)
    disc = int(cohort["discontinued"].sum())
    restart = int(cohort["restarted"].fillna(False).sum())
    holiday = int(cohort["holiday"].fillna(False).sum())
    surg = int((cohort["surgical_conversion"] != "none").sum())

    # Guard both denominators: filters can leave zero patients or zero
    # discontinuations.
    disc_pct = disc / total * 100 if total else 0.0
    restart_pct = restart / disc * 100 if disc else 0.0

    c1, c2, c3, c4, c5 = st.columns(5)
    c1.metric("총 환자 수", f"{total:,}")
    c2.metric("중단 환자", f"{disc:,}", f"{disc_pct:.1f}%")
    c3.metric("재시작 환자", f"{restart:,}", f"{restart_pct:.1f}% of 중단")
    c4.metric("Holiday(<12주)", f"{holiday:,}")
    c5.metric("외과·내시경 전환", f"{surg:,}")


def tab1_ingest_and_reasons(cohort: pd.DataFrame):
    """Tab ①: discontinuation-event ingest and reason classification.

    Renders three views of ``cohort``:

    * a per-drug table of patients discontinued within a user-selected
      horizon (weeks) and the corresponding rate (%);
    * a stacked bar chart of discontinuation reasons per drug;
    * a timeline chart for a reproducible random sample of up to 15 patients.

    An empty cohort skips the timeline, which cannot be built from an empty
    frame.
    """
    st.subheader("① 중단 이벤트 인제스트 + 사유 분류")
    st.caption("환자별 약물 timeline, 약물별·사유별 cohort 중단율 KPI.")

    horizon_w = st.selectbox(
        "중단율 산정 horizon (주)",
        options=[12, 26, 52, 104],
        index=2,
        # Weeks -> months with 52 weeks ≈ 12 months; the former ``x // 4``
        # labelled 52주 as "13m" and 104주 as "26m".
        format_func=lambda x: f"{x}주 ({round(x * 12 / 52)}m)",
    )

    rows = []
    for d in sorted(cohort["drug"].unique()):
        sub = cohort[cohort["drug"] == d]
        n = len(sub)
        # Discontinued within the horizon: discontinued AND duration <= horizon.
        # A missing duration is treated as lying beyond every horizon.
        within = sub["duration_weeks"].fillna(99999) <= horizon_w
        disc_by_h = int((sub["discontinued"] & within).sum())
        rows.append({
            "drug": d,
            "N": n,
            "discontinued_by_horizon": disc_by_h,
            "rate_pct": round(disc_by_h / max(n, 1) * 100, 1),
        })
    st.dataframe(pd.DataFrame(rows), use_container_width=True)

    st.markdown("**약물별 × 중단사유 stack chart**")
    disc_rows = cohort[cohort["discontinued"]]
    agg = disc_rows.groupby(["drug", "discontinuation_reason"]).size().reset_index(name="N")
    fig = px.bar(
        agg, x="drug", y="N", color="discontinuation_reason",
        barmode="stack", title="약물별 중단 사유 분포",
    )
    fig.update_layout(height=420)
    st.plotly_chart(fig, use_container_width=True)

    st.markdown("**환자 timeline 샘플**")
    if cohort.empty:
        st.info("선택된 필터에 해당하는 환자 없음.")
        return
    sample = cohort.sample(min(15, len(cohort)), random_state=1).sort_values("patient_id")
    # Arbitrary calendar anchor: px.timeline needs datetimes, so week 0 maps here.
    origin = pd.Timestamp("2023-01-01")
    # NOTE(review): rows with missing duration_weeks get a NaT end and no bar
    # is drawn — presumably patients still on treatment; confirm the intended
    # display for them.
    tl_df = pd.DataFrame({
        "patient_id": sample["patient_id"].to_numpy(),
        "drug": sample["drug"].to_numpy(),
        "reason": sample["discontinuation_reason"].to_numpy(),
        "restarted": sample["restarted"].to_numpy(),
        "start_dt": origin,
        "end_dt": origin + pd.to_timedelta(sample["duration_weeks"].to_numpy() * 7, unit="d"),
    })
    fig2 = px.timeline(
        tl_df,
        x_start="start_dt", x_end="end_dt", y="patient_id", color="drug",
        hover_data=["reason", "restarted"],
        title=f"환자별 약물 사용 timeline (샘플 {len(sample)}명)",
    )
    fig2.update_layout(height=480)
    st.plotly_chart(fig2, use_container_width=True)


def tab2_rebound(cohort: pd.DataFrame, longi: pd.DataFrame, ref: dict):
    """Tab ②: post-discontinuation trajectory of one metric, aligned at week 0.

    Plots the per-week median and IQR of the selected metric from ``longi``
    (grouped by ``week_from_disc``) for patients on the selected drugs.

    For ``weight_kg`` it also overlays the reference regain envelope from
    ``ref["envelope_summary"]``. Each regain fraction ``f`` is mapped onto
    this cohort as ``nadir + (baseline - nadir) * f``, using the median
    baseline and nadir weights of the selected patients. The overlay is
    skipped when ``ref`` has no ``envelope_summary``.
    """
    st.subheader("② Rebound trajectory 시각화")
    st.caption("중단 시점=0 정렬. STEP-1 ext / SURMOUNT-4 reference envelope overlay.")

    metric = st.selectbox(
        "지표",
        ["weight_kg", "hba1c", "sbp", "ldl", "hdl", "tg"],
        index=0,
    )
    drug_options = sorted(cohort["drug"].unique())
    drug_filter = st.multiselect("약물 필터", options=drug_options, default=drug_options)

    pid_in = cohort.loc[cohort["drug"].isin(drug_filter), "patient_id"].tolist()
    longi_f = longi[longi["patient_id"].isin(pid_in)]
    if longi_f.empty:
        st.info("선택된 약물에 해당하는 longitudinal 데이터 없음.")
        return

    agg = longi_f.groupby("week_from_disc")[metric].agg(
        median="median",
        # Series.quantile skips NaN just like "median"; np.percentile would
        # return NaN for any week with a missing measurement, blanking the
        # IQR band there.
        p25=lambda x: x.quantile(0.25),
        p75=lambda x: x.quantile(0.75),
    ).reset_index()

    fig = go.Figure()
    # Upper IQR edge first (invisible), then the lower edge filled "tonexty"
    # up to it, producing the shaded band.
    fig.add_trace(go.Scatter(
        x=agg["week_from_disc"], y=agg["p75"],
        fill=None, mode="lines",
        line=dict(width=0), showlegend=False,
    ))
    fig.add_trace(go.Scatter(
        x=agg["week_from_disc"], y=agg["p25"],
        fill="tonexty", mode="lines",
        line=dict(width=0), name="IQR (cohort)",
        fillcolor="rgba(31,119,180,0.18)",
    ))
    fig.add_trace(go.Scatter(
        x=agg["week_from_disc"], y=agg["median"],
        mode="lines+markers", name="Cohort median",
        line=dict(color="#1f77b4", width=2.5),
    ))

    # Reference envelope (overlay only for weight_kg).
    env = ref.get("envelope_summary", {}) if metric == "weight_kg" else {}
    if env:
        sel_cohort = cohort[cohort["patient_id"].isin(pid_in)]
        base_med = float(sel_cohort["baseline_weight_kg"].median())
        nad_med = float(sel_cohort["nadir_weight_kg"].median())
        weeks = env["weeks"]
        for key, color, name in [
            ("regain_lower", "rgba(255,127,14,0.25)", "Ref envelope (low)"),
            ("regain_central", "#ff7f0e", "Ref envelope (central)"),
            ("regain_upper", "rgba(255,127,14,0.25)", "Ref envelope (high)"),
        ]:
            vals = [nad_med + (base_med - nad_med) * f for f in env[key]]
            dash = "solid" if "central" in key else "dot"
            fig.add_trace(go.Scatter(
                x=weeks, y=vals, mode="lines",
                name=name,
                line=dict(color=color, width=1.6, dash=dash),
            ))

    fig.add_vline(x=0, line=dict(color="gray", dash="dash"))
    fig.update_layout(
        title=f"{metric} trajectory (중단 시점=0)",
        xaxis_title="중단 후 주차 (week_from_disc)",
        yaxis_title=metric,
        height=500,
    )
    st.plotly_chart(fig, use_container_width=True)


def tab3_restart(cohort: pd.DataFrame):
    """Tab ③: restart dose-titration tracking.

    Restarters are rows whose ``restarted`` value equals True; all other rows
    (including missing values) form the non-restart comparison group.

    Shows:
    - restart KPIs: count, GI AE recurrence %, and median additional loss;
    - per-drug median peak weight loss, restart vs non-restart;
    - per-drug median start/max dose among restarters.
    """
    st.subheader("③ 재시작 dose titration 추적")
    st.caption("재시작 환자의 dose escalation, GI AE 재발률, 추가 감량.")

    restart_mask = cohort["restarted"].eq(True)
    restarters = cohort[restart_mask].copy()
    if restarters.empty:
        st.info("재시작 환자 없음.")
        return
    naive = cohort[~restart_mask].copy()

    gi_ae_pct = restarters["restart_gi_ae"].mean() * 100
    extra_loss_median = restarters["restart_additional_loss_pct"].median()
    col_n, col_gi, col_loss = st.columns(3)
    col_n.metric("재시작 환자 수", len(restarters))
    col_gi.metric("재시작 GI AE 재발률", f"{gi_ae_pct:.1f}%")
    col_loss.metric("재시작 후 추가 감량 중앙값(%)", f"{extra_loss_median:.1f}")

    st.markdown("**약물별 재시작 vs 신규(non-restart) 최대 감량 비교**")
    labelled = pd.concat([
        restarters.assign(group="restart"),
        naive.assign(group="non-restart"),
    ])
    peak_by_group = (
        labelled.groupby(["drug", "group"])["peak_weight_loss_pct"]
        .median()
        .reset_index()
    )
    fig = px.bar(
        peak_by_group, x="drug", y="peak_weight_loss_pct", color="group",
        barmode="group", title="약물별 최대 감량(%) — 재시작 vs 비재시작",
    )
    fig.update_layout(height=420)
    st.plotly_chart(fig, use_container_width=True)

    st.markdown("**재시작 환자 dose escalation 분포 (시작 → 최대)**")
    dose_summary = (
        restarters.groupby("drug")
        .agg(
            start_dose_median=("start_dose_mg", "median"),
            max_dose_median=("max_dose_mg", "median"),
            N=("patient_id", "count"),
        )
        .reset_index()
    )
    st.dataframe(dose_summary, use_container_width=True)


def tab4_holiday(cohort: pd.DataFrame):
    """Tab ④: drug holiday vs permanent discontinuation outcomes.

    Discontinued patients are split into three groups:
    - ``holiday``: holiday flag set;
    - ``delayed_restart``: restarted but not flagged as a holiday;
    - ``permanent_discont``: neither.

    Shows per-group 52-week outcome medians and rates (%). When lifelines is
    installed, it also plots Kaplan-Meier curves for a proxy "time to 60%
    rebound".
    """
    st.subheader("④ 약물 holiday vs 영구 중단 outcome")
    st.caption("의도적 holiday(<12주 중단 후 재시작) vs 영구 중단(≥12주 미재시작). KM survival.")

    disc = cohort[cohort["discontinued"]].copy()
    if disc.empty:
        st.info("중단 환자 없음.")
        return
    # Holiday takes precedence over any other restart; missing flags count as False.
    disc["group"] = np.where(
        disc["holiday"].fillna(False),
        "holiday",
        np.where(disc["restarted"].fillna(False), "delayed_restart", "permanent_discont"),
    )

    st.markdown("**그룹별 52주 시점 outcome 비교**")
    grp = disc.groupby("group").agg(
        N=("patient_id", "count"),
        median_rebound_frac_52w=("rebound_frac_52w", "median"),
        median_wt_52w_post=("weight_at_52w_post_disc_kg", "median"),
        median_hba1c_52w_post=("hba1c_at_52w_post_disc", "median"),
        cv_event_rate=("cv_event", "mean"),
        surgical_rate=("surgical_conversion", lambda s: (s != "none").mean()),
    ).reset_index()
    grp["cv_event_rate"] = (grp["cv_event_rate"] * 100).round(2)
    grp["surgical_rate"] = (grp["surgical_rate"] * 100).round(2)
    st.dataframe(grp, use_container_width=True)

    st.markdown("**KM survival — '60% rebound 도달까지의 시간' 모델 (proxy)**")
    if not LIFELINES_OK:
        st.info("lifelines 미설치 — `pip install lifelines` 후 표시됨.")
        return

    threshold, horizon_w = 0.6, 52
    # Proxy time-to-event built from the single 52-week rebound fraction r,
    # assuming roughly linear regain after discontinuation:
    #   r >= threshold -> event at horizon_w * threshold / r weeks, clipped to [1, horizon_w];
    #   r <  threshold -> censored at horizon_w.
    # Patients with no 52-week value are excluded. Previously NaN compared as
    # False, so they were counted as rebound-free through week 52.
    km = disc.dropna(subset=["rebound_frac_52w"]).copy()
    if km.empty:
        st.info("52주 rebound 데이터가 있는 중단 환자 없음.")
        return
    km["event"] = (km["rebound_frac_52w"] >= threshold).astype(int)
    km["time_w"] = np.where(
        km["event"] == 1,
        np.clip(horizon_w * (threshold / np.maximum(km["rebound_frac_52w"], 0.01)), 1, horizon_w),
        horizon_w,
    )

    fig = go.Figure()
    # Sorted so legend order does not depend on row order in the data.
    for g in sorted(km["group"].unique()):
        sub = km[km["group"] == g]
        kmf = KaplanMeierFitter()
        kmf.fit(sub["time_w"], event_observed=sub["event"], label=g)
        sf = kmf.survival_function_.reset_index()
        sf.columns = ["t", "S"]
        fig.add_trace(go.Scatter(x=sf["t"], y=sf["S"], mode="lines", name=f"{g} (N={len(sub)})"))
    fig.update_layout(
        title="60% rebound 미도달 생존곡선 (proxy)",
        xaxis_title="중단 후 주차",
        yaxis_title="S(t) = P(rebound<60% 유지)",
        height=480,
    )
    st.plotly_chart(fig, use_container_width=True)


def make_postmarket_docx(cohort: pd.DataFrame) -> bytes:
    """Build the post-marketing surveillance draft report and return .docx bytes.

    Sections:
    1. Cohort summary counts.
    2. AE rate table (column mean × 100, two decimals).
    3. Distribution of ``surgical_conversion`` values.

    A closing disclaimer paragraph follows. Requires python-docx
    (``Document``).
    """
    n_total = len(cohort)
    n_disc = int(cohort["discontinued"].sum())
    n_restart = int(cohort["restarted"].fillna(False).sum())
    n_surg = int((cohort["surgical_conversion"] != "none").sum())

    doc = Document()
    doc.add_heading("GLP-1RA Post-Marketing Surveillance Report (KOR draft)", level=1)
    today = datetime.now().strftime("%Y-%m-%d")
    doc.add_paragraph(
        f"작성일: {today} | 본 문서는 연구·post-marketing 보조 초안이며, 실제 보고는 담당자 검수 필수."
    )

    doc.add_heading("1. Cohort summary", level=2)
    doc.add_paragraph(
        f"총 N = {n_total} | 중단 = {n_disc} | 재시작 = {n_restart} | 외과·내시경 전환 = {n_surg}"
    )

    doc.add_heading("2. AE 빈도 (post-marketing 항목)", level=2)
    # (row label, cohort flag column); the rate is the column mean in percent.
    ae_columns = [
        ("췌장염(pancreatitis)", "pancreatitis"),
        ("담석(gallstone)", "gallstone"),
        ("Sarcopenia flag", "sarcopenia_flag"),
        ("갑상샘 C-cell signal", "thyroid_c_cell"),
        ("CV 사건", "cv_event"),
    ]
    table = doc.add_table(rows=1, cols=2)
    table.style = "Light Grid"
    header = table.rows[0].cells
    header[0].text = "항목"
    header[1].text = "발생률(%)"
    for label, column in ae_columns:
        cells = table.add_row().cells
        cells[0].text = label
        cells[1].text = f"{float(cohort[column].mean()) * 100:.2f}"

    doc.add_heading("3. 외과·내시경 전환 분포", level=2)
    for kind, count in cohort["surgical_conversion"].value_counts().items():
        doc.add_paragraph(f"- {kind}: {int(count)} ({count/n_total*100:.1f}%)")
    doc.add_paragraph("\n[디스클레이머] 본 보고서 초안은 합성 데이터 기반이며, 실제 식약처/제약사 제출 전 임상연구자 검수 필수.")

    with io.BytesIO() as buf:
        doc.save(buf)
        return buf.getvalue()


def tab5_postmarket(cohort: pd.DataFrame):
    """Tab ⑤: post-marketing AE rates and surgical/endoscopic conversion report.

    Shows:
    - cohort-wide rates (%) for five AE flag columns;
    - the per-drug rate of ``surgical_conversion != "none"``;
    - when python-docx is installed, a button that builds a downloadable
      .docx draft report.
    """
    st.subheader("⑤ Post-marketing AE · 외과적 전환 리포트")
    st.caption("췌장염·담석·sarcopenia·갑상샘 C-cell·CV 사건, 외과/내시경 전환률.")

    # Display label -> cohort flag column; the rate is the column mean in percent.
    ae_columns = {
        "췌장염": "pancreatitis",
        "담석": "gallstone",
        "Sarcopenia flag": "sarcopenia_flag",
        "갑상샘 C-cell": "thyroid_c_cell",
        "CV 사건": "cv_event",
    }
    ae_table = pd.DataFrame({
        "항목": list(ae_columns.keys()),
        "발생률(%)": [cohort[column].mean() * 100 for column in ae_columns.values()],
    })
    ae_table["발생률(%)"] = ae_table["발생률(%)"].round(2)
    st.dataframe(ae_table, use_container_width=True)

    ae_fig = px.bar(ae_table, x="항목", y="발생률(%)", title="Post-marketing AE 발생률 (%)")
    ae_fig.update_layout(height=380)
    st.plotly_chart(ae_fig, use_container_width=True)

    st.markdown("**약물별 외과/내시경 전환률**")
    conv_rate = (
        cohort["surgical_conversion"].ne("none")
        .groupby(cohort["drug"])
        .mean()
        .rename("converted")
        .reset_index()
    )
    conv_rate["converted"] = (conv_rate["converted"] * 100).round(2)
    conv_fig = px.bar(conv_rate, x="drug", y="converted", title="약물별 외과/내시경 전환률 (%)")
    conv_fig.update_layout(height=380)
    st.plotly_chart(conv_fig, use_container_width=True)

    st.markdown("**리포트 export**")
    if not DOCX_OK:
        st.info("python-docx 미설치 — `pip install python-docx`")
        return
    # The report is only built on the rerun triggered by the generate button,
    # not on every rerun of the page.
    if st.button("Post-marketing 리포트 (docx) 생성"):
        report_bytes = make_postmarket_docx(cohort)
        st.download_button(
            "다운로드: postmarketing_report.docx",
            data=report_bytes,
            file_name=f"postmarketing_report_{datetime.now().strftime('%Y%m%d')}.docx",
            mime="application/vnd.openxmlformats-officedocument.wordprocessingml.document",
        )


def main():
    """Streamlit entry point: page config, data loading, sidebar filters, tabs.

    Halts with an error if the synthetic CSVs are missing. The reference-curve
    JSON is optional: without it an empty dict is passed to tab ②, which then
    omits the reference-envelope overlay. When the sidebar filters match no
    patients, a warning is shown instead of the KPI row and tabs.
    """
    st.set_page_config(
        page_title="GLP1DiscontRebound-Kor",
        layout="wide",
        initial_sidebar_state="expanded",
    )
    st.title("GLP1DiscontRebound-Kor")
    st.markdown("**지엘피원디스컨티뉴리바운드코어** — GLP-1RA·tirzepatide·orforglipron 중단·rebound·재시작 cohort dashboard")
    disclaimer_box()

    if not (os.path.exists(COHORT_CSV) and os.path.exists(LONGI_CSV)):
        st.error("합성 데이터가 없습니다. `python data/generate_synthetic.py` 실행 후 새로고침하세요.")
        st.stop()

    cohort = load_cohort()
    longi = load_longi()
    # Reference curves only feed the weight overlay in tab ②, which reads them
    # via ref.get("envelope_summary", {}); degrade instead of crashing if absent.
    ref = load_ref() if os.path.exists(REF_JSON) else {}

    with st.sidebar:
        st.header("필터")
        drug_sel = st.multiselect(
            "약물", sorted(cohort["drug"].unique()),
            default=sorted(cohort["drug"].unique()),
        )
        age_min, age_max = st.slider("연령", 18, 90, (18, 90))
        sex_sel = st.multiselect("성별", ["F", "M"], default=["F", "M"])

    mask = (
        cohort["drug"].isin(drug_sel)
        & (cohort["age"].between(age_min, age_max))
        & (cohort["sex"].isin(sex_sel))
    )
    cohort_f = cohort[mask].copy()

    if cohort_f.empty:
        # Nothing meaningful to show: skip KPIs and tabs rather than render
        # empty / NaN views (or divide by a zero patient count).
        st.warning("선택한 필터 조건에 해당하는 환자가 없습니다. 사이드바 필터를 조정하세요.")
    else:
        pids = cohort_f["patient_id"].tolist()
        longi_f = longi[longi["patient_id"].isin(pids)].copy()

        kpi_row(cohort_f)
        st.divider()

        tabs = st.tabs([
            "① 중단 인제스트·사유",
            "② Rebound trajectory",
            "③ 재시작 titration",
            "④ Holiday vs 영구 중단",
            "⑤ Post-marketing AE",
        ])
        with tabs[0]:
            tab1_ingest_and_reasons(cohort_f)
        with tabs[1]:
            tab2_rebound(cohort_f, longi_f, ref)
        with tabs[2]:
            tab3_restart(cohort_f)
        with tabs[3]:
            tab4_holiday(cohort_f)
        with tabs[4]:
            tab5_postmarket(cohort_f)

    st.divider()
    st.caption(
        "Reference: STEP-1 extension (Wilding 2022) · SURMOUNT-4 (Aronne 2024) · SELECT (Lincoff 2023) · "
        "SURMOUNT-MMO (NCT05556512) · ATTAIN-1 (orforglipron Phase 3). "
        "본 대시보드는 합성 데이터 기반 연구·post-marketing 보조용. 임상 의사결정은 담당 의사 판단."
    )


# Script entry point: build the dashboard when this file is executed as the
# main module (Streamlit executes the app file as __main__).
if __name__ == "__main__":
    main()
