#!/usr/bin/env python3
"""
app.py — PETGlucoFlux Streamlit 앱.

rodent micro-PET 18F-FDG 정량 도구. 정량 계산 로직은 petgluco_core (표준
라이브러리) 를 그대로 재사용하고, 본 파일은 업로드 / 시각화 / 인터랙션만 담당한다.

실행:  streamlit run app.py
(streamlit/numpy/pandas/matplotlib 미설치 환경에서는 오프라인 CLI 인 main.py 사용)

핵심 기능 5개:
  1) TAC·input function·meta ingest + 18F decay 보정 + QC
  2) SUV 정량 (BW/lean/glucose-corrected)
  3) Patlak Ki + LC 보정 MRGlu
  4) 종단 within-animal paired 변화 + 모델 참고 범위/누락 경고
  5) cohort 통계(paired t-test) + mechanism 분류 + 국·영문 리포트

본 도구는 연구용·참고용이며, 정량 결과는 사용자 검증이 필요하다.
원시 DICOM 재구성은 범위 밖(정량 CSV ingest 전제). 기본 샘플은 합성 데이터.
"""

import io
import os
import sys

sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))

import streamlit as st
import pandas as pd
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt

import petgluco_core as core

# Bundled sample data lives next to this script, independent of the launch cwd.
HERE = os.path.dirname(os.path.abspath(__file__))
DATA = os.path.join(HERE, "data")

# Research-use disclaimer rendered at the top (st.warning) and bottom (st.caption).
DISCLAIMER = "".join((
    "본 도구는 **연구용·참고용**이며, 정량 결과는 사용자 검증이 필요합니다. ",
    "원시 DICOM 재구성은 범위 밖(정량 CSV ingest 전제)이며, 기본 샘플은 ",
    "**합성(synthetic) 데이터**입니다.",
))

# Page-wide settings: browser tab title and wide layout.
st.set_page_config(layout="wide", page_title="PETGlucoFlux")


# ---------------------------------------------------------------------------
# ingest helpers (streamlit upload -> core dict 구조)
# ---------------------------------------------------------------------------
def _df_to_tac(df):
    out = []
    for _, r in df.iterrows():
        out.append({
            "animal_id": str(r["animal_id"]), "group": str(r["group"]),
            "timepoint": str(r["timepoint"]), "tissue": str(r["tissue"]),
            "time_min": float(r["time_min"]), "activity_bqml": float(r["activity_bqml"]),
        })
    return out


def _df_to_meta(df):
    out = {}
    for _, r in df.iterrows():
        out[(str(r["animal_id"]), str(r["timepoint"]))] = {
            "animal_id": str(r["animal_id"]), "group": str(r["group"]),
            "timepoint": str(r["timepoint"]), "dose_mbq": float(r["dose_mbq"]),
            "body_weight_g": float(r["body_weight_g"]),
            "lean_mass_g": float(r["lean_mass_g"]),
            "glucose_mgdl": float(r["glucose_mgdl"]),
            "scan_temp_c": (None if pd.isna(r.get("scan_temp_c")) else float(r.get("scan_temp_c"))),
            "anesthesia": ("" if pd.isna(r.get("anesthesia")) else str(r.get("anesthesia"))),
        }
    return out


def _df_to_input(df):
    out = {}
    for _, r in df.iterrows():
        key = (str(r["animal_id"]), str(r["timepoint"]))
        out.setdefault(key, []).append({
            "time_min": float(r["time_min"]),
            "plasma_activity_bqml": float(r["plasma_activity_bqml"]),
        })
    for k in out:
        out[k].sort(key=lambda d: d["time_min"])
    return out


@st.cache_data
def load_samples():
    """Read the bundled synthetic sample CSVs from ``DATA``.

    Returns a ``(tac, meta, input_function)`` tuple of DataFrames; the
    ``st.cache_data`` decorator memoizes the result across reruns.
    """
    names = ("sample_tac.csv", "sample_meta.csv", "sample_input_function.csv")
    tac_df, meta_df, inp_df = (pd.read_csv(os.path.join(DATA, name)) for name in names)
    return tac_df, meta_df, inp_df


# ---------------------------------------------------------------------------
# UI
# ---------------------------------------------------------------------------
st.title("PETGlucoFlux — rodent micro-PET ¹⁸F-FDG 정량")
st.caption("DM 도메인 · 동물실험 도구 (in vivo 이미지 정량) · standalone 정량 엔진")
st.warning(DISCLAIMER)

with st.sidebar:
    # Data source: either the bundled sample or three user-supplied CSVs.
    st.header("데이터 입력")
    use_sample = st.button("샘플 데이터 로드 (합성)")
    st.markdown("---")
    up_tac, up_meta, up_inp = (
        st.file_uploader(label, type="csv", key=widget_key)
        for label, widget_key in (
            ("TAC CSV", "tac"),
            ("meta CSV", "meta"),
            ("input function CSV", "inp"),
        )
    )
    st.markdown("---")
    # Reference model whose normal ranges are checked in tab 4.
    model = st.selectbox("참고 모델", list(core.MODEL_REFERENCE))
    st.markdown("---")
    # Per-tissue lumped constants, pre-filled from core.DEFAULT_LUMPED_CONSTANT.
    st.subheader("Lumped constant (조직별)")
    lc = {
        tissue_name: st.number_input("LC: %s" % tissue_name, value=float(default_lc),
                                     step=0.05, format="%.2f")
        for tissue_name, default_lc in core.DEFAULT_LUMPED_CONSTANT.items()
    }

# state
if "loaded" not in st.session_state:
    st.session_state.loaded = False

_uploads = (up_tac, up_meta, up_inp)
if use_sample:
    st.session_state.df_tac, st.session_state.df_meta, st.session_state.df_inp = load_samples()
    st.session_state.loaded = True
elif all(_uploads):
    # Re-parse the uploads only when their content actually changes. Re-reading them on
    # every rerun meant that, after "load sample", any widget interaction silently
    # swapped the displayed data back to the uploaded files.
    _upload_sig = tuple((f.name, hash(f.getvalue())) for f in _uploads)
    if st.session_state.get("upload_sig") != _upload_sig:
        try:
            # getvalue() returns the whole buffer regardless of the read position.
            _frames = [pd.read_csv(io.BytesIO(f.getvalue())) for f in _uploads]
        except (pd.errors.ParserError, pd.errors.EmptyDataError, UnicodeDecodeError) as exc:
            st.error("CSV 읽기 실패: %s" % exc)
            st.stop()
        st.session_state.df_tac, st.session_state.df_meta, st.session_state.df_inp = _frames
        st.session_state.upload_sig = _upload_sig
        st.session_state.loaded = True

if not st.session_state.loaded:
    st.info("좌측에서 '샘플 데이터 로드' 를 누르거나 TAC / meta / input function "
            "CSV 3개를 업로드하세요.")
    st.stop()

try:
    tac = _df_to_tac(st.session_state.df_tac)
    meta = _df_to_meta(st.session_state.df_meta)
    inp = _df_to_input(st.session_state.df_inp)
except (KeyError, ValueError, TypeError) as exc:
    # KeyError: a required column is missing; ValueError/TypeError: float() was given a
    # non-numeric cell. Show a readable message instead of a raw traceback.
    st.error("입력 CSV 변환 실패 — 필수 컬럼 / 숫자 형식을 확인하세요: %s" % exc)
    st.stop()
# Adds "decay_corrected_bqml" to each TAC row in place (read by the tabs below).
core.decay_correct_tac(tac)

tab1, tab2, tab3, tab4, tab5 = st.tabs(
    ["1. Ingest+Decay+QC", "2. SUV", "3. Patlak Ki+MRGlu",
     "4. 종단+참고범위", "5. Cohort+리포트"])

# ---- Tab 1 ----
with tab1:
    st.subheader("Ingest + ¹⁸F decay 보정")
    st.write("¹⁸F 반감기 %.2f min, lambda=%.5f /min. TAC 행수=%d, meta=%d, "
             "input 곡선=%d." % (core.F18_HALF_LIFE_MIN, core.F18_DECAY_LAMBDA,
                                 len(tac), len(meta), len(inp)))
    dft = pd.DataFrame(tac)
    st.dataframe(dft.head(50), use_container_width=True)

    st.markdown("**TAC 곡선 (animal/timepoint 선택)**")
    aids = sorted({r["animal_id"] for r in tac})
    c1, c2 = st.columns(2)
    aid = c1.selectbox("animal", aids)
    tps = sorted({r["timepoint"] for r in tac if r["animal_id"] == aid})
    tp = c2.selectbox("timepoint", tps)
    # Filter the selected scan once rather than rescanning all of `tac` per tissue.
    scan_rows = [r for r in tac if r["animal_id"] == aid and r["timepoint"] == tp]
    fig, ax = plt.subplots(figsize=(7, 4))
    for tissue in sorted({r["tissue"] for r in tac}):
        sel = sorted((r for r in scan_rows if r["tissue"] == tissue),
                     key=lambda d: d["time_min"])
        if sel:
            ax.plot([r["time_min"] for r in sel],
                    [r["decay_corrected_bqml"] for r in sel], marker="o", label=tissue)
    ip = inp.get((aid, tp))
    if ip:
        ax.plot([d["time_min"] for d in ip], [d["plasma_activity_bqml"] for d in ip],
                "k--", label="input function")
    ax.set_xlabel("time (min)")
    ax.set_ylabel("decay-corrected activity (Bq/mL)")
    if ax.get_legend_handles_labels()[0]:  # skip legend when nothing was plotted
        ax.legend(fontsize=7)
    ax.set_title("%s / %s" % (aid, tp))
    st.pyplot(fig)
    # pyplot keeps every figure alive until closed; without this each rerun leaks one.
    plt.close(fig)

    st.markdown("**QC 플래그**")
    qc = core.qc_flags(tac, inp, meta)
    if qc:
        for f in qc:
            st.warning(f)
    else:
        st.success("QC 플래그 없음")

# ---- Tab 2 ----
suv_rows = core.suv_table(tac, meta)
with tab2:
    st.subheader("SUV 정량 (BW / lean / glucose-corrected)")
    st.latex(r"SUV = \frac{\text{tissue activity}}{\text{dose}} \times \text{mass}")
    dfs = pd.DataFrame(suv_rows)
    st.dataframe(dfs, use_container_width=True)
    if dfs.empty:
        # An empty result has no "tissue"/"group" columns; indexing them would raise KeyError.
        st.info("SUV 결과 없음")
    else:
        norm = st.radio("정규화 기준", ["SUV_bw", "SUV_lean", "SUV_glu"], horizontal=True)
        fig, ax = plt.subplots(figsize=(8, 4))
        tissues = sorted(dfs["tissue"].unique())
        for grp in sorted(dfs["group"].unique()):
            for tpn in sorted(dfs["timepoint"].unique()):
                cell = dfs[(dfs.group == grp) & (dfs.timepoint == tpn)]
                vals = [cell[cell.tissue == t][norm].mean() for t in tissues]
                if all(pd.isna(v) for v in vals):
                    continue  # combination absent from the data: no phantom legend entry
                ax.plot(tissues, vals, marker="s", label="%s/%s" % (grp, tpn))
        ax.set_ylabel(norm)
        if ax.get_legend_handles_labels()[0]:
            ax.legend(fontsize=7)
        # Rotate this axes' labels explicitly instead of relying on pyplot's "current axes".
        plt.setp(ax.get_xticklabels(), rotation=30, ha="right")
        st.pyplot(fig)
        plt.close(fig)  # release the figure from pyplot's registry

# ---- Tab 3 ----
kin_rows = core.kinetic_table(tac, inp, meta, lumped_constants=lc)
with tab3:
    st.subheader("Patlak graphical analysis + MRGlu")
    st.latex(r"MRGlu = \frac{K_i \times \text{glucose}}{LC}")
    st.dataframe(pd.DataFrame(kin_rows), use_container_width=True)
    st.markdown("**Patlak plot (선택)**")
    c1, c2, c3 = st.columns(3)
    a2 = c1.selectbox("animal ", sorted({r["animal_id"] for r in tac}), key="pa")
    t2 = c2.selectbox("timepoint ", sorted({r["timepoint"] for r in tac if r["animal_id"] == a2}), key="pt")
    ti2 = c3.selectbox("tissue ", sorted({r["tissue"] for r in tac}), key="pti")
    # (time, activity) pairs in time order, matching the ordering used for the tab-1 plot.
    # NOTE(review): whether core.patlak itself requires ordered input is not visible here.
    tac_sel = sorted(
        ((r["time_min"], r.get("decay_corrected_bqml")) for r in tac
         if r["animal_id"] == a2 and r["timepoint"] == t2 and r["tissue"] == ti2),
        key=lambda p: p[0])
    ip2 = inp.get((a2, t2))
    if not ip2:
        st.info("선택한 animal/timepoint 에 input function 이 없습니다.")
    else:
        res = core.patlak(tac_sel, [(d["time_min"], d["plasma_activity_bqml"]) for d in ip2])
        points = res["points"] if res else []
        if not points:
            # Also guards min()/max() below, which raise on an empty sequence.
            st.info("Patlak 결과를 계산할 수 없습니다.")
        else:
            xs = [p[0] for p in points]
            ys = [p[1] for p in points]
            ts = [p[2] for p in points]
            fig, ax = plt.subplots(figsize=(6, 4))
            # Points at/after t* (the fitted linear phase) in red, earlier points in gray.
            ax.scatter(xs, ys, c=["red" if tt >= res["tstar"] else "gray" for tt in ts])
            xr = [min(xs), max(xs)]
            ax.plot(xr, [res["Ki"] * x + res["intercept"] for x in xr], "b-")
            ax.set_xlabel(r"$\int_0^t C_p\,dt' / C_p(t)$ (min)")
            ax.set_ylabel(r"$C_{tissue}(t)/C_p(t)$")
            ax.set_title("Ki=%.4f  r2=%.3f  t*=%.1f min" % (res["Ki"], res["r2"], res["tstar"]))
            st.pyplot(fig)
            plt.close(fig)  # release the figure from pyplot's registry

# ---- Tab 4 ----
with tab4:
    st.subheader("종단 within-animal 변화 (paired)")
    metric = st.radio("metric", ["Ki", "SUV_bw"], horizontal=True)
    # Ki comes from the kinetic table, SUV_bw from the SUV table.
    longc = core.longitudinal_changes(kin_rows if metric == "Ki" else suv_rows, metric)
    st.dataframe(pd.DataFrame(longc), use_container_width=True)

    st.markdown("**모델 참고 범위 이탈 경고 (model=%s)**" % model)
    rw = core.reference_check(suv_rows, model=model)
    if not rw:
        st.success("이탈 없음")
    else:
        for msg in rw:
            st.warning(msg)

    st.markdown("**메타데이터 누락 경고**")
    mw = core.metadata_warnings(meta, inp, lc, tac)
    if not mw:
        st.success("누락 없음")
    else:
        for msg in mw:
            st.warning(msg)

# ---- Tab 5 ----
with tab5:
    st.subheader("Cohort 통계 (Ki paired t-test)")
    cs = core.cohort_summary(kin_rows, "Ki")
    st.markdown("descriptive (group x tissue x timepoint)")
    st.dataframe(pd.DataFrame(cs["descriptive"]), use_container_width=True)
    st.markdown("paired t-test (baseline vs post)")
    st.dataframe(pd.DataFrame(cs["paired_tests"]), use_container_width=True)

    longki = core.longitudinal_changes(kin_rows, "Ki")
    mech = core.classify_mechanism(longki)
    st.success("Mechanism 분류: **%s**" % mech["label"])
    st.json(mech["scores"])

    st.markdown("**요약 리포트**")
    n_animals = len({r["animal_id"] for r in suv_rows})
    n_tissues = len({r["tissue"] for r in suv_rows})
    # The headline result is the treatment-group skeletal-muscle paired test, if present.
    tx_m = next((t for t in cs["paired_tests"]
                 if t["group"] == "treatment" and t["tissue"] == "skeletal_muscle"), None)
    p_val = tx_m["p"] if tx_m else None
    d_val = tx_m["mean_delta"] if tx_m else None
    # "%.4f" would print a highly significant p (< 1e-4) as "0.0000", i.e. claim p=0;
    # report it as an upper bound instead.
    if p_val is None:
        pstr = "p=n/a"
    elif p_val < 1e-4:
        pstr = "p<0.0001"
    else:
        pstr = "p=%.4f" % p_val
    dstr = "n/a" if d_val is None else "%.4f" % d_val
    st.text("[국문] %d마리·%d조직 micro-PET 18F-FDG TAC 를 18F decay 보정 후 SUV"
            "(BW/lean/glucose)·Patlak Ki·LC 보정 MRGlu 로 정량. treatment 군 골격근 "
            "Ki 종단 변화 평균 %s (paired %s). 주된 기전: '%s' (참고모델 %s)."
            % (n_animals, n_tissues, dstr, pstr, mech["label"], model))
    st.text("[EN] Micro-PET 18F-FDG TACs (%d animals, %d tissues) were decay-corrected "
            "and quantified as SUV (BW/lean/glucose), Patlak Ki and LC-corrected MRGlu. "
            "Treatment-group skeletal-muscle Ki changed by %s on average (paired %s). "
            "Dominant mechanism: '%s' (reference model %s)."
            % (n_animals, n_tissues, dstr, pstr, mech["label"], model))

st.caption(DISCLAIMER)
