"""
app.py — GlucoDynamics-Kor Streamlit 진입점

OGTT/MMTT/clamp 시점별 혈당·인슐린·C-peptide 데이터를 입력받아
인슐린 감수성·분비 지표 전 패널을 단위 자동 보정·AUC 방법 선택과 함께
per-participant·코호트로 산출하고 QC 플래그를 제공하는 오프라인 분석기.

실행:  streamlit run app.py

주의: 본 도구는 연구용·교육용·참고용이며 임상 의사결정에 직접 사용 금지.

설계 메모: 모든 계산 로직은 metrics/auc/deconvolution 모듈에 있고
streamlit UI 함수는 그 위의 얇은 표현 계층이다(스트림릿 미설치 환경에서도
계산 모듈은 독립 테스트 가능).
"""
from __future__ import annotations
import io
import math
import numpy as np
import pandas as pd

import auc as auc_mod
import metrics as M
import deconvolution as dc
import demo_data

try:
    import yaml
    with open("reference.yaml", "r", encoding="utf-8") as _f:
        REF = yaml.safe_load(_f)
except Exception:
    REF = None

DISCLAIMER = ("⚠️ 본 도구는 연구용·교육용·참고용이며 임상 의사결정에 직접 사용을 "
              "금지합니다. 모든 산출값은 입력 데이터·가정·근사식에 의존하며 진단·"
              "치료 목적으로 사용할 수 없습니다.")

PROTOCOLS = {
    "OGTT_75g": {"label": "75g OGTT (0-120)", "times": [0, 30, 60, 90, 120],
                 "windows": [(0, 120)], "kind": "oral"},
    "OGTT_75g_ext": {"label": "확장 OGTT (0-180)", "times": [0, 30, 60, 90, 120, 150, 180],
                     "windows": [(0, 120), (0, 180)], "kind": "oral"},
    "MMTT": {"label": "MMTT (0-180)", "times": [0, 15, 30, 60, 90, 120, 180],
             "windows": [(0, 120), (0, 180)], "kind": "oral"},
    "CLAMP": {"label": "Euglycemic clamp", "times": [0, 10, 20, 30, 60, 90, 120],
              "windows": [], "kind": "clamp", "steady": (90, 120)},
}


# ===========================================================================
# 비-UI 핵심 로직 (단위테스트 가능)
# ===========================================================================
def normalize_units(df, g_unit, i_unit, c_unit):
    """입력 단위를 canonical(mg/dL, uU/mL, ng/mL)로 변환한 새 DataFrame 반환."""
    out = df.copy()
    if "glucose" in out:
        out["glucose"] = M.glucose_to_mgdl(out["glucose"].to_numpy(float), g_unit)
    if "insulin" in out:
        out["insulin"] = M.insulin_to_uU(out["insulin"].to_numpy(float), i_unit)
    if "cpeptide" in out and out["cpeptide"].notna().any():
        out["cpeptide"] = M.cpeptide_to_ngml(out["cpeptide"].to_numpy(float), c_unit)
    return out


def qc_flags(df, kind="oral"):
    """
    QC 플래그 산출. canonical 단위 DataFrame 가정.
    반환: list[str] 경고 메시지.
    """
    flags = []
    t = df["time_min"].to_numpy(float)
    if 0 not in t:
        flags.append("공복(0분) 시점 누락 — 공복 기반 지표(HOMA 등) 계산 불가.")
    if df["time_min"].duplicated().any():
        flags.append("중복 시점 존재.")
    if not np.all(np.diff(np.sort(t)) > 0):
        flags.append("시점이 단조증가하지 않음.")

    g = df["glucose"].to_numpy(float) if "glucose" in df else np.array([])
    ins = df["insulin"].to_numpy(float) if "insulin" in df else np.array([])

    if g.size:
        if np.any(g < 30) or np.any(g > 600):
            flags.append("혈당 이상치(<30 또는 >600 mg/dL) — 단위 오입력 의심.")
        if np.any(np.isnan(g)):
            flags.append("혈당 결측값 존재.")
    if ins.size:
        if np.any(ins < 0):
            flags.append("음의 인슐린값 존재(역생리).")
        if np.any(ins > 1000):
            flags.append("인슐린 이상치(>1000 uU/mL) — 단위 오입력 의심(pmol/L?).")
    if "cpeptide" in df and df["cpeptide"].notna().any():
        c = df["cpeptide"].to_numpy(float)
        if np.any(c < 0):
            flags.append("음의 C-peptide값 존재.")
        if np.nanmax(c) > 30:
            flags.append("C-peptide 이상치(>30 ng/mL) — 단위 오입력 의심(pmol/L?).")

    if kind == "oral" and g.size >= 2 and 0 in t:
        g0 = g[np.where(t == 0)[0][0]]
        if np.nanmax(g) <= g0 + 1:
            flags.append("경구부하 후 혈당 상승 없음 — 부하 실패 또는 입력오류 의심.")
    return flags


def analyze_participant(df, kind="oral", windows=None, steady=(90, 120),
                        age=45, weight_kg=75, height_cm=170, bmi=None,
                        diabetic=False, obese=False, run_isr=True):
    """
    canonical 단위 DataFrame 1인분을 분석.
    반환 dict: panel(지표), auc(중첩dict), isr(dict or None), flags(list).
    """
    if windows is None:
        windows = [(0, 120)]
    t = df["time_min"].to_numpy(float)
    g = df["glucose"].to_numpy(float)
    ins = df["insulin"].to_numpy(float) if "insulin" in df else np.full_like(t, np.nan)
    has_c = "cpeptide" in df and df["cpeptide"].notna().any()
    cpep = df["cpeptide"].to_numpy(float) if has_c else None

    result = {"flags": qc_flags(df, kind), "auc": {}, "isr": None, "panel": {}}

    if kind == "clamp":
        gir = df["gir"].to_numpy(float) if "gir" in df else np.full_like(t, np.nan)
        mval = M.clamp_m_value(gir, t, steady, weight_kg)
        ss_mask = (t >= steady[0]) & (t <= steady[1])
        ss_ins = float(np.nanmean(ins[ss_mask])) if ss_mask.any() else float("nan")
        result["panel"] = {
            "Clamp M-value (mg/kg/min)": mval,
            "안정상태 인슐린 (uU/mL)": ss_ins,
            "M/I": M.clamp_m_over_i(mval, ss_ins),
            "HOMA-IR": M.homa_ir(g[np.where(t == 0)[0][0]], ins[np.where(t == 0)[0][0]])
                        if 0 in t else float("nan"),
        }
        return result

    # oral (OGTT/MMTT)
    result["panel"] = M.compute_panel(t, g, ins, cpep, weight_kg=weight_kg,
                                       bmi=bmi, auc_module=auc_mod)

    # AUC 패널 (glucose, insulin, cpeptide)
    result["auc"]["glucose"] = auc_mod.all_auc_panel(t, g, windows)
    result["auc"]["insulin"] = auc_mod.all_auc_panel(t, ins, windows)
    if has_c:
        result["auc"]["cpeptide"] = auc_mod.all_auc_panel(t, cpep, windows)

    # ISR deconvolution
    if run_isr and has_c:
        bsa = dc.bsa_dubois(weight_kg, height_cm)
        params = dc.standard_kinetics(age, bsa, diabetic=diabetic, obese=obese)
        cpep_pmol = M.cpeptide_ngml_to_pmol_L(cpep)
        isr = dc.deconvolve_isr(t, cpep_pmol, params, dt=1.0)
        isr["params"] = params
        isr["bsa_m2"] = bsa
        result["isr"] = isr
    return result


def footnotes(kind="oral"):
    """결과에 첨부할 공식·파라미터 footnote 텍스트."""
    base = [
        "HOMA-IR = FPG[mmol/L]·FPI[µU/mL]/22.5 (Matthews 1985).",
        "HOMA-β = 20·FPI/(FPG[mmol/L]−3.5) (Matthews 1985).",
        "HOMA2-IR(근사): 비선형 HOMA2 모델의 로그선형 근사(참고용; 정밀값은 공식 HOMA2 계산기 권장).",
        "QUICKI = 1/(log10 FPI[µU/mL]+log10 FPG[mg/dL]) (Katz 2000).",
        "Matsuda/ISI = 10000/√(FPG·FPI·meanG·meanI), mg/dL·µU/mL (Matsuda-DeFronzo 1999); meanG/meanI는 시간가중 평균.",
        "Gutt ISI(0,120) (Gutt 2000); Stumvoll ISI 단순형(Stumvoll 2000, 참고용).",
        "Insulinogenic index = (I30−I0)/(G30−G0).",
        "CIR30 = 100·I30/(G30·(G30−70)) (Sluiter 1976).",
        "Oral Disposition Index = Insulinogenic index × Matsuda (Utzschneider 2009).",
        "ISR: C-peptide 2-구획 디컨볼루션, van Cauter 1992 표준 kinetics(연령·BSA 보정); BSA=Du Bois.",
        "AUC: 사다리꼴(불균등 간격 지원). incremental=기저차감, positive=양의 증분만.",
    ]
    clamp = [
        "Clamp M-value = 안정상태 GIR 평균(mg/kg/min) (DeFronzo 1979).",
        "M/I = M-value / 안정상태 혈장 인슐린.",
    ]
    return clamp + base if kind == "clamp" else base


# ===========================================================================
# Streamlit UI
# ===========================================================================
def _fmt(x):
    if x is None or (isinstance(x, float) and (math.isnan(x) or math.isinf(x))):
        return "—"
    return f"{x:,.3f}"


def run_streamlit():
    import streamlit as st
    import matplotlib
    matplotlib.use("Agg")
    import matplotlib.pyplot as plt

    st.set_page_config(page_title="GlucoDynamics-Kor", layout="wide")
    st.title("GlucoDynamics-Kor · 대사 동적검사 분석기")
    st.caption("OGTT / MMTT / Euglycemic clamp — 인슐린 감수성·분비 지표 전 패널 (완전 오프라인)")
    st.warning(DISCLAIMER)

    with st.sidebar:
        st.header("1. 검사 모드")
        proto_key = st.selectbox("프로토콜", list(PROTOCOLS.keys()),
                                 format_func=lambda k: PROTOCOLS[k]["label"])
        proto = PROTOCOLS[proto_key]
        kind = proto["kind"]

        st.header("2. 단위")
        g_unit = st.radio("혈당", ["mg/dL", "mmol/L"], horizontal=True)
        i_unit = st.radio("인슐린", ["uU/mL", "pmol/L"], horizontal=True)
        c_unit = st.radio("C-peptide", ["ng/mL", "nmol/L"], horizontal=True)

        st.header("3. 피험자 정보")
        age = st.number_input("연령(세)", 1, 110, 45)
        weight_kg = st.number_input("체중(kg)", 20.0, 250.0, 75.0)
        height_cm = st.number_input("신장(cm)", 100.0, 220.0, 170.0)
        bmi = round(weight_kg / (height_cm / 100) ** 2, 1)
        st.caption(f"BMI = {bmi}")
        diabetic = st.checkbox("당뇨군(ISR kinetics)", value=False)
        obese = st.checkbox("비만군(ISR kinetics)", value=bmi >= 30)

        st.header("4. AUC 옵션")
        auc_mode = st.selectbox("표시 AUC 방법",
                                ["total", "incremental", "positive_incremental"])

    st.subheader("데이터 입력")
    src = st.radio("데이터 소스", ["합성 데모", "CSV 업로드"], horizontal=True)

    cohort = {}  # name -> canonical df
    if src == "합성 데모":
        demos = demo_data.all_demos()
        if kind == "clamp":
            chosen = st.multiselect("데모 선택", ["CLAMP"], default=["CLAMP"])
        else:
            opts = ["OGTT_normal", "OGTT_IGT", "OGTT_T2DM"]
            chosen = st.multiselect("데모 선택(다중=코호트)", opts, default=opts)
        for name in chosen:
            # 데모는 이미 canonical 단위
            cohort[name] = demos[name]
        st.info("합성 데모는 절차적 생성값으로 실제 환자 데이터가 아닙니다.")
    else:
        st.caption("CSV 스키마: time_min, glucose, insulin, [cpeptide]  (clamp는 gir 열 추가). "
                   "여러 명을 올리려면 participant 열을 추가하세요.")
        up = st.file_uploader("CSV 업로드 (다중 가능)", type=["csv"],
                              accept_multiple_files=True)
        if up:
            for f in up:
                raw = pd.read_csv(f)
                if "participant" in raw.columns:
                    for pid, sub in raw.groupby("participant"):
                        cohort[f"{f.name}:{pid}"] = normalize_units(
                            sub.drop(columns=["participant"]), g_unit, i_unit, c_unit)
                else:
                    cohort[f.name] = normalize_units(raw, g_unit, i_unit, c_unit)

    if not cohort:
        st.stop()

    st.divider()
    st.subheader("결과")

    panels = {}
    for name, df in cohort.items():
        res = analyze_participant(
            df, kind=kind, windows=proto.get("windows", [(0, 120)]),
            steady=proto.get("steady", (90, 120)),
            age=age, weight_kg=weight_kg, height_cm=height_cm, bmi=bmi,
            diabetic=diabetic, obese=obese, run_isr=True)
        panels[name] = res

        with st.expander(f"▸ {name}", expanded=(len(cohort) == 1)):
            c1, c2 = st.columns([1, 1])
            with c1:
                st.markdown("**입력 데이터 (canonical 단위)**")
                st.dataframe(df, use_container_width=True)
                if res["flags"]:
                    st.error("QC 플래그:\n" + "\n".join(f"- {x}" for x in res["flags"]))
                else:
                    st.success("QC 플래그 없음")
            with c2:
                st.markdown("**지표 패널**")
                pretty = {k: _fmt(v) for k, v in res["panel"].items()
                          if not k.startswith("_")}
                st.table(pd.DataFrame.from_dict(pretty, orient="index",
                                                columns=["값"]))

            # AUC
            if res["auc"]:
                st.markdown(f"**AUC ({auc_mode})**")
                rows = []
                for analyte, wdict in res["auc"].items():
                    for (t0, t1), modes in wdict.items():
                        rows.append({"측정물질": analyte, "구간(min)": f"{t0}-{t1}",
                                     "AUC": _fmt(modes[auc_mode])})
                st.table(pd.DataFrame(rows))

            # 곡선 플롯
            if kind != "clamp":
                fig, ax = plt.subplots(1, 2, figsize=(9, 3))
                ax[0].plot(df["time_min"], df["glucose"], "o-", color="#c0392b")
                ax[0].set_title("Glucose (mg/dL)"); ax[0].set_xlabel("min")
                ax[1].plot(df["time_min"], df["insulin"], "s-", color="#2980b9")
                ax[1].set_title("Insulin (µU/mL)"); ax[1].set_xlabel("min")
                fig.tight_layout()
                st.pyplot(fig); plt.close(fig)
            else:
                fig, ax = plt.subplots(figsize=(6, 3))
                ax.plot(df["time_min"], df["gir"], "o-", color="#27ae60")
                ax.axvspan(*proto["steady"], alpha=0.15, color="green")
                ax.set_title("GIR (mg/kg/min) — 음영=안정상태")
                ax.set_xlabel("min")
                st.pyplot(fig); plt.close(fig)

            # ISR
            if res["isr"] and res["isr"].get("isr_pmol_min") is not None \
                    and len(res["isr"]["isr_pmol_min"]) > 0:
                isr = res["isr"]
                st.markdown(f"**ISR (디컨볼루션, {isr['method']})**  "
                            f"기저={_fmt(isr['isr_basal'])} pmol/min, "
                            f"총분비≈{_fmt(isr['isr_total_auc'])} pmol")
                fig2, ax2 = plt.subplots(figsize=(6, 2.6))
                ax2.plot(isr["grid_min"], isr["isr_pmol_min"], color="#8e44ad")
                ax2.set_title("Insulin Secretion Rate"); ax2.set_xlabel("min")
                ax2.set_ylabel("pmol/min")
                st.pyplot(fig2); plt.close(fig2)

    # 코호트 비교
    if len(cohort) > 1 and kind != "clamp":
        st.divider()
        st.subheader("코호트 비교")
        keys = ["HOMA-IR", "Matsuda/ISI", "QUICKI", "HOMA-β",
                "Insulinogenic index", "Oral Disposition Index"]
        comp = {name: {k: panels[name]["panel"].get(k, float("nan")) for k in keys}
                for name in cohort}
        comp_df = pd.DataFrame(comp).T
        st.dataframe(comp_df.style.format("{:.3f}"), use_container_width=True)

        # 박스플롯 (군이 여럿일 때 분포)
        fig3, axes = plt.subplots(1, 2, figsize=(10, 3.2))
        for ax, metric in zip(axes, ["HOMA-IR", "Matsuda/ISI"]):
            vals = [comp[name][metric] for name in cohort
                    if math.isfinite(comp[name][metric])]
            labels = [name for name in cohort if math.isfinite(comp[name][metric])]
            ax.bar(range(len(vals)), vals, color="#34495e")
            ax.set_xticks(range(len(labels)))
            ax.set_xticklabels(labels, rotation=30, ha="right", fontsize=7)
            ax.set_title(metric)
        fig3.tight_layout()
        st.pyplot(fig3); plt.close(fig3)

    # footnotes
    st.divider()
    st.markdown("**산출 공식·파라미터 (footnote)**")
    for fn in footnotes(kind):
        st.caption("• " + fn)
    st.caption(DISCLAIMER)


if __name__ == "__main__":
    run_streamlit()
