"""
ObesityEnrollFunnelOps-Kor (오베시티인롤펀널옵스코어)
================================================================
비만 RCT *진행 중* 등록(enrollment) funnel 운영 분석기 (standalone Streamlit).
도메인: Obesity | 카테고리: 인체실험 도구(임상시험 운영 물류)

실행: pip install -r requirements.txt && streamlit run app.py
오프라인 전용 — 외부 네트워크/API 호출 없음. 모든 데이터는 합성/사용자입력.
"""
import os
import io
import numpy as np
import pandas as pd
import streamlit as st
import plotly.express as px
import plotly.graph_objects as go

import funnel_logic as fl

HERE = os.path.dirname(os.path.abspath(__file__))
DATA = os.path.join(HERE, "data")

st.set_page_config(page_title="ObesityEnrollFunnelOps-Kor", layout="wide")

DISCLAIMER = (
    "본 도구는 연구·운영 보조용 참고 도구이며 실제 임상시험 규제 의사결정을 "
    "대체하지 않는다. 모든 데모 데이터는 합성(synthetic)이다."
)


# ---------------------------------------------------------------------------
# 데이터 로딩 (CSV 업로드 없이도 합성 데모로 즉시 동작)
# ---------------------------------------------------------------------------
@st.cache_data
def load_demo():
    def rd(name):
        return pd.read_csv(os.path.join(DATA, name))
    return {
        "funnel": rd("enrollment_funnel.csv"),
        "screen_fail": rd("screen_fail_reasons.csv"),
        "site": rd("site_enrollment.csv"),
        "demo": rd("demographics.csv"),
        "retention": rd("retention_visits.csv"),
    }


def maybe_user_csv(label, expected_cols):
    up = st.file_uploader(label, type=["csv"], key=label)
    if up is not None:
        try:
            df = pd.read_csv(up)
            missing = [c for c in expected_cols if c not in df.columns]
            if missing:
                st.warning(f"필수 컬럼 누락: {missing} — 데모 데이터로 대체합니다.")
                return None
            st.success(f"업로드 사용: {df.shape[0]}행")
            return df
        except Exception as e:
            st.error(f"CSV 로드 실패: {e} — 데모 데이터로 대체합니다.")
            return None
    return None


# ---------------------------------------------------------------------------
# 헤더 + 디스클레이머
# ---------------------------------------------------------------------------
st.title("ObesityEnrollFunnelOps-Kor")
st.caption("비만 RCT 진행 중 등록(enrollment) funnel 운영 분석기 · 도메인: Obesity · 카테고리: 인체실험 도구")
st.warning("디스클레이머: " + DISCLAIMER)

data = load_demo()

with st.sidebar:
    st.header("데이터 소스")
    st.write("기본은 합성 데모 데이터. 필요 시 CSV 업로드로 대체 가능.")
    use_upload = st.checkbox("내 CSV 업로드 사용", value=False)
    st.divider()
    st.subheader("Bayesian prior 설정")
    prior_rate = st.number_input("사전 평균 등록률(주당, site당)", 0.1, 10.0, 1.0, 0.1)
    st.caption("Poisson-Gamma 켤레 사전. 데이터가 쌓이면 사후가 사전을 압도.")

tabs = st.tabs([
    "1. Funnel 분해/병목",
    "2. Screen-fail 근본원인",
    "3. 등록곡선 재예측",
    "4. Retention/Dropout",
    "5. 다양성/대표성",
])

# ---------------------------------------------------------------------------
# 탭 1: Funnel 분해 + bottleneck
# ---------------------------------------------------------------------------
with tabs[0]:
    st.subheader("6단계 funnel 분해 및 최대 병목 자동 식별")
    funnel_df = data["funnel"]
    if use_upload:
        u = maybe_user_csv("등록 funnel CSV (site_id,site_name,week,stage,count)",
                           ["site_id", "week", "stage", "count"])
        if u is not None:
            funnel_df = u

    totals = fl.funnel_totals(funnel_df)
    totals["stage_kor"] = totals["stage"].map(fl.STAGE_LABELS_KOR)
    pass_df = fl.stage_pass_rates(totals)
    bn = fl.identify_bottleneck(pass_df)

    c1, c2 = st.columns([2, 1])
    with c1:
        fig = go.Figure(go.Funnel(
            y=totals["stage_kor"], x=totals["count"],
            textposition="inside", textinfo="value+percent initial"))
        fig.update_layout(title="전체 등록 funnel", height=420)
        st.plotly_chart(fig, use_container_width=True)
    with c2:
        st.metric("최대 병목 구간",
                  f"{fl.STAGE_LABELS_KOR.get(bn['from_stage'],'')}→{fl.STAGE_LABELS_KOR.get(bn['to_stage'],'')}",
                  f"통과율 {bn['pass_rate']*100:.1f}%")
        st.metric("해당 구간 이탈 수", f"{bn['dropped']:,}명")
        pr_view = pass_df.copy()
        pr_view["전이"] = pr_view.apply(
            lambda r: f"{fl.STAGE_LABELS_KOR[r['from_stage']]}→{fl.STAGE_LABELS_KOR[r['to_stage']]}", axis=1)
        pr_view["통과율(%)"] = (pr_view["pass_rate"] * 100).round(1)
        st.dataframe(pr_view[["전이", "통과율(%)", "dropped"]].rename(columns={"dropped": "이탈수"}),
                     hide_index=True, use_container_width=True)

    st.markdown("**Site × 단계 매트릭스**")
    mat = fl.site_stage_matrix(funnel_df)
    mat_kor = mat.rename(columns=fl.STAGE_LABELS_KOR)
    st.dataframe(mat_kor, use_container_width=True)

    st.markdown("**주차별 단계 추이**")
    wk = funnel_df.groupby(["week", "stage"])["count"].sum().reset_index()
    wk["stage_kor"] = wk["stage"].map(fl.STAGE_LABELS_KOR)
    figw = px.line(wk, x="week", y="count", color="stage_kor", markers=True,
                   labels={"week": "주차", "count": "건수", "stage_kor": "단계"})
    st.plotly_chart(figw, use_container_width=True)

# ---------------------------------------------------------------------------
# 탭 2: screen-fail 근본원인
# ---------------------------------------------------------------------------
with tabs[1]:
    st.subheader("Screen-fail 근본원인 taxonomy")
    sf_df = data["screen_fail"]
    if use_upload:
        u = maybe_user_csv("screen-fail CSV (site_id,reason_code,reason_label,avoidable,count)",
                           ["reason_code", "count"])
        if u is not None:
            sf_df = u

    summ = fl.screen_fail_summary(sf_df)
    split = fl.avoidable_split(sf_df)

    c1, c2 = st.columns(2)
    with c1:
        fig = px.bar(summ, x="count", y="reason_label", color="avoidable",
                     orientation="h", labels={"count": "건수", "reason_label": "사유"},
                     title="사유 코드별 screen-fail 분포")
        fig.update_layout(yaxis={"categoryorder": "total ascending"})
        st.plotly_chart(fig, use_container_width=True)
    with c2:
        st.metric("회피가능 screen-fail 비율", f"{split['avoidable_pct']*100:.1f}%",
                  f"회피가능 {split['avoidable']} / 불가피 {split['inevitable']}")
        figp = px.pie(values=[split["avoidable"], split["inevitable"]],
                      names=["회피가능", "불가피"], title="회피 가능 vs 불가피",
                      color_discrete_sequence=["#EF553B", "#636EFA"])
        st.plotly_chart(figp, use_container_width=True)

    st.markdown("**Site별 사유 분포 편차 (전체 평균 대비 절대비율차 합, 클수록 이상)**")
    dev = fl.site_reason_deviation(sf_df)
    st.dataframe(dev.rename(columns={"deviation": "편차", "n_fail": "screen-fail수"}),
                 hide_index=True, use_container_width=True)
    st.caption("편차가 큰 site 는 특정 사유에 치우쳐 있어 프로토콜 해석/홍보 타겟팅 점검 대상이 될 수 있다.")

# ---------------------------------------------------------------------------
# 탭 3: 등록곡선 Bayesian 재예측
# ---------------------------------------------------------------------------
with tabs[2]:
    st.subheader("Site별 등록곡선 Bayesian 재예측")
    site_df = data["site"]
    funnel_df = data["funnel"]
    weeks_elapsed = int(funnel_df["week"].max())
    st.caption(f"경과 주차(데이터 기준): {weeks_elapsed}주 · prior 등록률 {prior_rate:.1f}/주(site)")

    rows = []
    for _, r in site_df.iterrows():
        pred = fl.predict_completion(int(r["target_n"]), int(r["cum_randomized"]),
                                     weeks_elapsed, prior_rate=prior_rate)
        rows.append({
            "site": f"{r['site_id']} {r.get('site_name','')}",
            "목표": int(r["target_n"]),
            "현재등록": int(r["cum_randomized"]),
            "달성률(%)": round(100 * r["cum_randomized"] / r["target_n"], 1) if r["target_n"] else 0,
            "등록률/주(평균)": round(pred["rate_mean"], 2),
            "잔여": pred["remaining"],
            "완료까지(주,평균)": round(pred["weeks_mean"], 1) if np.isfinite(pred["weeks_mean"]) else "∞",
            "낙관(주)": round(pred["weeks_optimistic"], 1) if np.isfinite(pred["weeks_optimistic"]) else "∞",
            "비관(주)": round(pred["weeks_pessimistic"], 1) if np.isfinite(pred["weeks_pessimistic"]) else "∞",
        })
    pred_df = pd.DataFrame(rows)
    st.dataframe(pred_df, hide_index=True, use_container_width=True)

    # 전체 시험 수준 예측
    tot_target = int(site_df["target_n"].sum())
    tot_cum = int(site_df["cum_randomized"].sum())
    tot_pred = fl.predict_completion(tot_target, tot_cum, weeks_elapsed, prior_rate=prior_rate * len(site_df))
    mean_site_rate = tot_pred["rate_mean"] / max(1, len(site_df))
    c1, c2, c3 = st.columns(3)
    c1.metric("시험 전체 목표", f"{tot_target}", f"현재 {tot_cum} ({100*tot_cum/tot_target:.0f}%)")
    c2.metric("완료까지(주, 평균)",
              f"{tot_pred['weeks_mean']:.1f}" if np.isfinite(tot_pred["weeks_mean"]) else "∞",
              f"낙관 {tot_pred['weeks_optimistic']:.1f} / 비관 {tot_pred['weeks_pessimistic']:.1f}")

    st.markdown("**목표 기한 내 달성에 필요한 (평균속도) site 수**")
    weeks_left = st.slider("남은 허용 주차", 4, 52, 12)
    n_needed = fl.sites_needed(tot_pred["remaining"], mean_site_rate, weeks_left)
    c3.metric("필요 site 수(평균속도)",
              f"{n_needed:.1f}" if np.isfinite(n_needed) else "∞",
              f"현재 {len(site_df)}개")

    # 예측 곡선 시각화
    cum_now = tot_cum
    future_weeks = np.arange(0, int(min(60, max(8, (tot_pred["weeks_pessimistic"] if np.isfinite(tot_pred["weeks_pessimistic"]) else 52)))) + 1)
    fig = go.Figure()
    for label, rate in [("평균", tot_pred["rate_mean"]),
                        ("낙관", tot_pred["rate_hi"]),
                        ("비관", tot_pred["rate_lo"])]:
        proj = np.minimum(cum_now + rate * future_weeks, tot_target)
        fig.add_trace(go.Scatter(x=future_weeks + weeks_elapsed, y=proj, mode="lines", name=label))
    fig.add_hline(y=tot_target, line_dash="dash", annotation_text="목표")
    fig.update_layout(title="시험 전체 누적 등록 예측", xaxis_title="주차", yaxis_title="누적 무작위배정", height=420)
    st.plotly_chart(fig, use_container_width=True)

# ---------------------------------------------------------------------------
# 탭 4: retention / dropout
# ---------------------------------------------------------------------------
with tabs[3]:
    st.subheader("Retention / Dropout 추적 및 early-responder 잔류 편향")
    ret_df = data["retention"]
    rc = fl.retention_curve(ret_df)
    eb = fl.early_responder_bias(ret_df)

    c1, c2 = st.columns(2)
    with c1:
        fig = px.line(rc, x="visit_week", y="retention_rate", markers=True,
                      labels={"visit_week": "방문 주차", "retention_rate": "잔류율"},
                      title="방문별 잔류율(전체)")
        fig.update_yaxes(tickformat=".0%", range=[0, 1.05])
        st.plotly_chart(fig, use_container_width=True)
        st.dataframe(rc.rename(columns={"visit_week": "방문주차", "n_present": "잔류수",
                                        "n_enrolled": "등록수", "retention_rate": "잔류율"}),
                     hide_index=True, use_container_width=True)
    with c2:
        fig2 = px.line(eb, x="visit_week", y="bmi_shift_vs_baseline", markers=True,
                       labels={"visit_week": "방문 주차", "bmi_shift_vs_baseline": "잔류군 평균 BMI 변화"},
                       title="잔류군 평균 BMI 이동 (생존편향 모니터)")
        fig2.add_hline(y=0, line_dash="dash")
        st.plotly_chart(fig2, use_container_width=True)
        st.caption("baseline 대비 잔류군 평균 BMI 가 양/음으로 이동하면 특정 BMI 군의 "
                   "선택적 잔류(체중감량 시험 특유의 early-responder 편향)를 시사한다.")

# ---------------------------------------------------------------------------
# 탭 5: 다양성 / representativeness
# ---------------------------------------------------------------------------
with tabs[4]:
    st.subheader("등록 인구 다양성 / 대표성 (질환 역학 reference 비교)")
    demo_df = data["demo"]
    rep = fl.representativeness(demo_df)
    st.caption("reference 분포는 합성·문헌 기반 가정치다. 카이제곱 적합도 검정 p<0.05 면 "
               "등록 인구가 역학 분포와 통계적으로 다름을 시사(주의: 표본/가정 의존).")

    for dim, kor in [("sex", "성별"), ("age_band", "연령대"), ("race", "인종")]:
        r = rep[dim]
        t = r["table"].copy()
        t["등록(%)"] = (t["enrolled_pct"] * 100).round(1)
        t["역학기준(%)"] = (t["reference_pct"] * 100).round(1)
        t["절대차(%p)"] = (t["abs_diff"] * 100).round(1)
        st.markdown(f"**{kor}** — 카이제곱 p = "
                    + (f"{r['chisq_p']:.3f}" if r["chisq_p"] == r["chisq_p"] else "N/A")
                    + f" (n={r['n']})")
        figc = go.Figure()
        figc.add_trace(go.Bar(x=t["category"], y=t["enrolled_pct"], name="등록"))
        figc.add_trace(go.Bar(x=t["category"], y=t["reference_pct"], name="역학기준"))
        figc.update_layout(barmode="group", yaxis_tickformat=".0%", height=300)
        cc1, cc2 = st.columns([2, 1])
        cc1.plotly_chart(figc, use_container_width=True)
        cc2.dataframe(t[["category", "등록(%)", "역학기준(%)", "절대차(%p)"]].rename(
            columns={"category": "구분"}), hide_index=True, use_container_width=True)

    st.markdown("**등록 BMI 분포**")
    figb = px.histogram(demo_df, x="bmi", nbins=20, labels={"bmi": "BMI"}, title="등록 피험자 BMI 분포")
    st.plotly_chart(figb, use_container_width=True)

st.divider()
st.caption("출처/근거: CONSORT 2010 flow diagram(enrollment→allocation→follow-up→analysis), "
           "비만 RCT screen-fail 사유 분류는 일반 문헌 기반 합성 taxonomy. 본 도구는 참고용이다.")
