"""
app.py — GlycemicNIMargin-Kor Streamlit UI

당뇨 RCT 비열등성 마진·MMRM 표본수 설계 계산기.
실행:  streamlit run app.py

핵심 계산은 core.py(순수 함수)에 분리. 이 파일은 UI 래퍼.
"""

import json
import os

import streamlit as st

import core

try:
    import numpy as np
except Exception:
    np = None

try:
    import pandas as pd
    HAS_PANDAS = True
except Exception:
    HAS_PANDAS = False


st.set_page_config(page_title="GlycemicNIMargin-Kor",
                   page_icon="🩸", layout="wide")


@st.cache_data
def get_presets():
    try:
        return core.load_presets()
    except Exception:
        return None


PRESETS = get_presets()

st.title("GlycemicNIMargin-Kor — 당뇨 RCT 비열등성 설계 계산기")
st.caption("도메인: 당뇨병(DM) · 카테고리: 인체실험 도구(RCT 설계·계산기)")

st.warning("⚠️ " + core.DISCLAIMER)

tab1, tab2, tab3, tab4, tab5 = st.tabs([
    "① NI 마진 정당화",
    "② MMRM 표본수",
    "③ 공동 1차 다중성",
    "④ 탈락 시뮬레이터",
    "⑤ 규제 문서 생성",
])


# ----- 사이드바: 프리셋 선택 -----
st.sidebar.header("Endpoint 프리셋")
ep_key = None
ep_preset = None
if PRESETS:
    eps = PRESETS["endpoints"]
    labels = {k: v["label"] for k, v in eps.items()}
    ep_key = st.sidebar.selectbox(
        "평가변수", list(labels.keys()),
        format_func=lambda k: labels[k])
    ep_preset = eps[ep_key]
    st.sidebar.markdown(f"**단위**: {ep_preset['unit']}")
    st.sidebar.markdown(f"**전형적 SD**: {ep_preset['typical_sd']}")
    st.sidebar.markdown(f"**관례 마진**: {ep_preset['conventional_margin']}")
    st.sidebar.markdown(f"**마진 범위**: {ep_preset['margin_range']}")
    with st.sidebar.expander("출처 문헌"):
        for s in ep_preset["sources"]:
            st.sidebar.write("- " + s)
else:
    st.sidebar.error("margin_presets.json 로드 실패 — 수동 입력 사용")

# 기본값 헬퍼
def _d(key, fallback):
    return ep_preset[key] if ep_preset else fallback

unit = ep_preset["unit"] if ep_preset else "unit"
default_sd = float(_d("typical_sd", 1.0))
default_margin = float(_d("conventional_margin", 0.4))
default_pf = float(_d("preserved_fraction_default", 0.5)) if ep_preset else 0.5


# ========================================================================
# 탭 1: NI 마진 정당화
# ========================================================================
with tab1:
    st.header("① 비열등성(NI) 마진 정당화 엔진")
    st.markdown("과거 위약대조 시험의 활성약-위약 효과로 NI 마진 δ를 산출하고 "
                "ICH E10 용어로 정당화 문단을 자동 작성합니다.")
    method = st.radio("방법", ["fixed-margin (M2)", "synthesis (95-95)"],
                      horizontal=True)
    c1, c2, c3 = st.columns(3)
    with c1:
        effect = st.number_input(
            f"활성약-위약 효과 M1 ({unit})", min_value=0.0,
            value=0.8 if ep_key == "hba1c" else default_margin * 2,
            step=0.05, format="%.3f")
    with c2:
        pf = st.slider("보존비율 (활성약 효과 중 지킬 비율)", 0.0, 0.9,
                       default_pf, 0.05)
    with c3:
        se_hist = st.number_input("과거시험 효과 SE (synthesis용)",
                                  min_value=0.0, value=0.12, step=0.01,
                                  format="%.3f")

    lang = st.radio("문서 언어", ["국문", "영문"], horizontal=True)
    lang_code = "ko" if lang == "국문" else "en"

    if method.startswith("fixed"):
        mres = core.ni_margin_fixed(effect, pf)
        st.metric("산출 NI 마진 δ", f"{mres['NI_margin_M2']:g} {unit}")
    else:
        mres = core.ni_margin_synthesis(effect, se_hist, pf)
        st.metric("산출 NI 마진 δ", f"{mres['NI_margin']:g} {unit}")
        st.caption(f"과거시험 {mres['conf']*100:.0f}% 효과 신뢰하한: "
                   f"{mres['historical_lower_CI']:g} {unit}")

    st.json(mres)
    ep_label = ep_preset["label"] if ep_preset else "Endpoint"
    margin_val = mres.get("NI_margin_M2") or mres.get("NI_margin")
    st.subheader("자동 생성 정당화 문단")
    st.write(core.generate_margin_justification_text(
        ep_label, unit, margin_val, mres, lang_code))

    st.session_state["margin_result"] = mres
    st.session_state["margin_val"] = margin_val
    st.session_state["ep_label"] = ep_label
    st.session_state["unit"] = unit


# ========================================================================
# 탭 2: MMRM 표본수
# ========================================================================
with tab2:
    st.header("② MMRM 기반 탈락보정 표본수")
    st.markdown("종단(baseline·12·24·52주) 측정, AR(1)/unstructured 공분산, "
                "MAR 탈락보정. 단순 ANCOVA 대비 효율을 비교합니다.")
    c1, c2, c3 = st.columns(3)
    with c1:
        margin = st.number_input(f"NI 마진 δ ({unit})", min_value=0.001,
                                 value=float(st.session_state.get(
                                     "margin_val", default_margin)),
                                 step=0.05, format="%.3f")
        sd = st.number_input(f"공통 표준편차 SD ({unit})", min_value=0.001,
                             value=default_sd, step=0.1, format="%.3f")
        true_diff = st.number_input(f"가정한 실제 군간차이 ({unit})",
                                    value=0.0, step=0.05, format="%.3f")
    with c2:
        alpha = st.selectbox("단측 유의수준 α", [0.025, 0.05], index=0)
        power = st.slider("검정력 (1-β)", 0.70, 0.95, 0.80, 0.05)
        ratio = st.selectbox("배정비 (시험:대조)", [1.0, 1.5, 2.0], index=0)
    with c3:
        cov = st.selectbox("공분산구조", ["ar1", "unstructured"])
        rho = st.slider("자기상관 ρ (AR1)", 0.0, 0.9, 0.6, 0.05)
        dropout = st.slider("탈락률", 0.0, 0.35, 0.15, 0.05)
        mech = st.selectbox("결측 메커니즘", ["MAR", "MCAR"])

    cov_arg = "ar1" if cov == "ar1" else "unstructured"
    try:
        ss = core.sample_size_mmrm_ni(
            margin=margin, sd=sd, true_diff=true_diff, alpha=alpha,
            power=power, allocation_ratio=ratio,
            visit_weeks=(0, 12, 24, 52), rho=rho, cov_structure=cov_arg,
            dropout_rate=dropout, missing_mechanism=mech)
        st.session_state["ss_result"] = ss

        m1, m2, m3 = st.columns(3)
        m1.metric("MMRM 총 표본수", f"{ss['n_total']}명")
        m2.metric("단순 ANCOVA 대비 절감",
                  f"{ss['efficiency_gain_subjects']}명",
                  f"{ss['efficiency_gain_pct']:.1f}%")
        m3.metric("총 설계효과 (DE)",
                  f"{ss['mmrm_design_effect']['total_design_effect']:.3f}")
        st.info(ss["interpretation"])

        de = ss["mmrm_design_effect"]
        st.write(f"- 종단 설계효과(longitudinal DE): "
                 f"**{de['longitudinal_DE']:.3f}** "
                 f"(< 1 이면 MMRM 효율 이득)")
        st.write(f"- 결측 inflation: **{de['missing_inflation']:.3f}** "
                 f"({mech} 가정)")

        # 검정력 곡선
        st.subheader("검정력 곡선 (arm별 표본수 vs 검정력)")
        n_ctrl = ss["n_control"]
        ns = list(range(max(5, n_ctrl // 3), n_ctrl * 2 + 1,
                         max(1, n_ctrl // 20)))
        pws = [core.power_curve(margin, sd, n, true_diff, alpha, ratio)
               for n in ns]
        if HAS_PANDAS:
            df = pd.DataFrame({"arm별 표본수": ns, "검정력": pws})
            st.line_chart(df, x="arm별 표본수", y="검정력")
        else:
            st.write({n: round(p, 3) for n, p in zip(ns, pws)})
        st.caption(f"목표 검정력 {power:g}, 계획 arm별 n={n_ctrl}")
    except ValueError as e:
        st.error(f"입력 오류: {e}")


# ========================================================================
# 탭 3: 공동 1차 다중성
# ========================================================================
with tab3:
    st.header("③ 공동 1차 endpoint 다중성 처리")
    st.markdown("HbA1c + CGM-TIR 공동 1차 시 상관 r을 입력하면 "
                "절차별 표본수와 FWER 통제를 계산합니다.")
    c1, c2 = st.columns(2)
    with c1:
        st.subheader("Endpoint A")
        a_label = st.text_input("A 이름", "HbA1c")
        a_margin = st.number_input("A 마진", value=0.4, step=0.05,
                                   format="%.3f")
        a_sd = st.number_input("A SD", value=1.0, step=0.1, format="%.3f")
    with c2:
        st.subheader("Endpoint B")
        b_label = st.text_input("B 이름", "CGM TIR")
        b_margin = st.number_input("B 마진", value=5.0, step=0.5,
                                   format="%.3f")
        b_sd = st.number_input("B SD", value=18.0, step=1.0, format="%.3f")

    default_r = 0.7
    if PRESETS and "joint_correlations" in PRESETS:
        default_r = abs(PRESETS["joint_correlations"]["hba1c_cgm_tir"]
                        ["abs_r_for_power"])
    rcorr = st.slider("Endpoint 간 상관 |r|", 0.0, 0.95, float(default_r),
                      0.05)
    proc = st.selectbox("다중성 절차",
                        ["unadjusted", "hochberg", "gatekeeping",
                         "bonferroni"])
    j_alpha = st.selectbox("전체 단측 α ", [0.025, 0.05], index=0)
    j_power = st.slider("목표 공동 검정력", 0.70, 0.95, 0.80, 0.05)

    jp = core.joint_primary_sample_size(
        endpoints=[
            {"label": a_label, "margin": a_margin, "sd": a_sd,
             "true_diff": 0.0},
            {"label": b_label, "margin": b_margin, "sd": b_sd,
             "true_diff": 0.0},
        ],
        r_correlation=rcorr, alpha=j_alpha, power=j_power,
        procedure=proc, allocation_ratio=1.0)
    st.session_state["joint_result"] = jp

    m1, m2 = st.columns(2)
    m1.metric("공동 1차 총 표본수", f"{jp['n_total']}명")
    m2.metric("달성 공동 검정력",
              f"{jp['achieved_joint_power']*100:.1f}%")
    st.info(jp["interpretation"])
    st.caption(jp["procedure_note"])
    with st.expander("Endpoint별 단독 표본수"):
        for pe in jp["per_endpoint"]:
            st.write(f"- **{pe['label']}**: 총 {pe['n_total']}명 "
                     f"(대조 {pe['n_control']} / 시험 {pe['n_treatment']})")


# ========================================================================
# 탭 4: 탈락 시뮬레이터
# ========================================================================
with tab4:
    st.header("④ 탈락 시나리오 시뮬레이터")
    st.markdown("탈락률 5~25%, 탈락 패턴별 검정력 민감도 히트맵.")
    c1, c2, c3 = st.columns(3)
    with c1:
        d_margin = st.number_input("마진 δ", value=float(
            st.session_state.get("margin_val", 0.4)), step=0.05,
            format="%.3f", key="d_margin")
    with c2:
        d_sd = st.number_input("SD", value=default_sd, step=0.1,
                               format="%.3f", key="d_sd")
    with c3:
        # 탈락 민감도는 단일시점 계획 표본수 기준으로 평가하는 것이 적절
        _ssr = st.session_state.get("ss_result")
        _default_n = 200
        if isinstance(_ssr, dict):
            _default_n = _ssr.get("base_single_timepoint", {}).get(
                "n_control", 200)
        d_n = st.number_input("계획 arm별 표본수 (단일시점 기준 권장)",
                              min_value=5, value=int(_default_n), step=10)
    d_alpha = st.selectbox("단측 α", [0.025, 0.05], index=0, key="d_alpha")

    drop = core.dropout_sensitivity(
        margin=d_margin, sd=d_sd, n_per_arm=d_n, true_diff=0.0,
        alpha=d_alpha)
    st.session_state["dropout_result"] = drop

    st.subheader("검정력 히트맵 (탈락률 × 패턴)")
    pats = drop["patterns"]
    if HAS_PANDAS:
        rows = []
        idx = []
        for r in drop["heatmap"]:
            idx.append(f"{r['dropout_rate']*100:.0f}%")
            rows.append([r["cells"][p]["power"] for p in pats])
        df = pd.DataFrame(rows, columns=pats, index=idx)
        st.dataframe(df.style.format("{:.3f}").background_gradient(
            cmap="RdYlGn", vmin=0.5, vmax=1.0))
    else:
        for r in drop["heatmap"]:
            st.write(f"탈락 {r['dropout_rate']*100:.0f}%: "
                     + ", ".join(f"{p}={r['cells'][p]['power']:.3f}"
                                 for p in pats))
    st.caption(drop["note"])
    st.write("**유효손실 계수**:", drop["effective_loss_factors"])


# ========================================================================
# 탭 5: 규제 문서 생성
# ========================================================================
with tab5:
    st.header("⑤ 규제 제출용 문서 생성")
    st.markdown("표본수 근거·마진 정당화·가정·참고문헌을 IND/프로토콜 통계 "
                "섹션 형식으로 생성합니다.")

    have_margin = "margin_result" in st.session_state
    have_ss = "ss_result" in st.session_state
    if not (have_margin and have_ss):
        st.warning("먼저 탭 ①(마진)과 탭 ②(표본수)를 실행하세요.")
    else:
        doc_lang = st.radio("언어", ["국문", "영문"], horizontal=True,
                            key="doc_lang")
        lc = "ko" if doc_lang == "국문" else "en"
        include_joint = st.checkbox("공동 1차(탭③) 포함",
                                    value="joint_result" in st.session_state)
        include_drop = st.checkbox("탈락 민감도(탭④) 포함",
                                   value="dropout_result" in st.session_state)
        txt = core.generate_regulatory_text(
            st.session_state.get("ep_label", "Endpoint"),
            st.session_state.get("unit", "unit"),
            st.session_state["margin_result"],
            st.session_state["ss_result"],
            lang=lc,
            joint_result=(st.session_state.get("joint_result")
                          if include_joint else None),
            dropout_result=(st.session_state.get("dropout_result")
                            if include_drop else None))
        st.code(txt, language="text")
        st.download_button("문서 다운로드 (.txt)", txt,
                           file_name="ni_margin_statistical_section.txt")

st.divider()
st.caption("GlycemicNIMargin-Kor v1.0 · 오프라인 동작 · 외부 API 호출 없음 · "
           "출처: ICH E9/E10, Chow-Shao-Wang(2008), Lu-Luo-Chen(2008), "
           "Battelino et al.(2019)")
