#!/usr/bin/env python3
"""
AdipoGlowIVIS - Streamlit 앱
rodent IVIS 광학영상 지방 depot ROI 정량 도구 (Obesity 도메인 / 동물실험 in vivo 이미지 정량).

실행:
    pip install -r requirements.txt   # (가상환경에서, 전역설치 금지)
    streamlit run app.py

본 도구는 연구용·참고용이며, 정량 결과는 사용자 검증이 필요합니다.
샘플 데이터는 합성(synthetic) 데이터입니다.
"""

import os
import json
import io

import streamlit as st
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from scipy import stats

# Reference acquisition settings: normalize() rescales every ROI to these.
REF_EXPOSURE = 30.0
REF_BINNING = 8.0

# Columns an input ROI table must provide; main() rejects tables missing any.
REQUIRED_COLS = [
    "animal_id",
    "group",
    "timepoint",
    "depot",
    "total_flux_ps",
    "avg_radiance",
    "exposure_sec",
    "binning",
    "f_stop",
    "substrate_post_min",
    "is_background",
]

# Bundled data files live in a "data" directory next to this script.
HERE = os.path.dirname(os.path.abspath(__file__))
DATA_DIR = os.path.join(HERE, "data")
SAMPLE_CSV = os.path.join(DATA_DIR, "sample_roi.csv")
REF_JSON = os.path.join(DATA_DIR, "reporter_reference.json")

# Banner shown at the top of the page (research-use / synthetic-data notice).
DISCLAIMER = " ".join((
    "⚠️ 본 도구는 연구용·참고용이며, 정량 결과는 사용자 검증이 필요합니다.",
    "기본 샘플 데이터는 합성(synthetic) 데이터입니다.",
))


# --------------------------------------------------------------------------
# 계산 로직
# --------------------------------------------------------------------------
def normalize(df):
    """Return a copy of *df* with signals rescaled to the reference acquisition.

    Each row is multiplied by
    ``(REF_EXPOSURE / exposure_sec) * (REF_BINNING / binning)`` and the result
    is stored in two new columns, ``norm_total_flux`` (from ``total_flux_ps``)
    and ``norm_avg_radiance`` (from ``avg_radiance``). The input frame is not
    modified.
    """
    exposure_scale = REF_EXPOSURE / df["exposure_sec"]
    binning_scale = REF_BINNING / df["binning"]
    scale = exposure_scale * binning_scale

    scaled = df.copy()
    for raw_col, norm_col in (("total_flux_ps", "norm_total_flux"),
                              ("avg_radiance", "norm_avg_radiance")):
        scaled[norm_col] = scaled[raw_col] * scale
    return scaled


def qc(df):
    """Collect quality-control flags for an ROI table.

    Four checks run in a fixed order and their findings are appended in that
    order:

    1. rows whose ``avg_radiance`` is at or above ``1.0e6`` (saturation),
    2. non-background rows with no ``substrate_post_min`` value,
    3. more than one distinct ``exposure_sec`` value in the table,
    4. (animal, depot) pairs among non-background rows that have no
       ``"baseline"`` timepoint.

    Returns:
        A list of ``(flag_type, detail)`` string tuples; empty when nothing
        was flagged.
    """
    saturation_limit = 1.0e6
    findings = []

    def _roi_key(row):
        # Identifies one ROI measurement in the flag text.
        return (row["animal_id"], row["timepoint"], row["depot"])

    # 1) Saturated ROIs.
    saturated = df.loc[df["avg_radiance"] >= saturation_limit]
    for _, row in saturated.iterrows():
        detail = ("%s/%s/%s avg_radiance>=%.0e"
                  % (_roi_key(row) + (saturation_limit,)))
        findings.append(("SATURATION", detail))

    # 2) Signal ROIs without a recorded substrate post-injection time.
    signal_mask = df["is_background"] == 0
    unrecorded = df.loc[df["substrate_post_min"].isna() & signal_mask]
    for _, row in unrecorded.iterrows():
        findings.append(("기질시점 미기록", "%s/%s/%s" % _roi_key(row)))

    # 3) Mixed exposure times across the table.
    exposures = sorted(df["exposure_sec"].dropna().unique())
    if len(exposures) > 1:
        listed = ", ".join("%gs" % value for value in exposures)
        findings.append(("노출 불일치",
                         "노출시간 %s — normalize 적용됨" % listed))

    # 4) Longitudinal series that lack a baseline measurement.
    per_series = df.loc[signal_mask].groupby(["animal_id", "depot"])
    for (animal, depot), rows in per_series:
        seen = set(rows["timepoint"])
        if seen and "baseline" not in seen:
            findings.append(("baseline 누락", "%s/%s" % (animal, depot)))

    return findings


def subtract_background(df):
    """Subtract the per-image background ROI from every signal ROI.

    Background rows (``is_background == 1``) are matched to signal rows
    (``is_background == 0``) on ``(animal_id, timepoint)``. When several
    background ROIs exist for the same key their normalized values are
    averaged first, so each signal row is matched at most once and the result
    has exactly one row per signal ROI. Signal rows without any matching
    background are corrected by 0.

    Args:
        df: Normalized ROI table containing ``animal_id``, ``timepoint``,
            ``is_background``, ``norm_total_flux`` and ``norm_avg_radiance``.

    Returns:
        A copy of the signal rows with ``bg_flux``, ``bg_rad``,
        ``corr_total_flux`` and ``corr_avg_radiance`` columns added.
    """
    key = ["animal_id", "timepoint"]

    # Collapse replicate background ROIs to one value per key; joining on a
    # non-unique index would otherwise duplicate the matching signal rows.
    bg = (df[df["is_background"] == 1]
          .groupby(key)[["norm_total_flux", "norm_avg_radiance"]]
          .mean()
          .rename(columns={"norm_total_flux": "bg_flux",
                           "norm_avg_radiance": "bg_rad"}))

    sig = df[df["is_background"] == 0].copy()
    sig = sig.join(bg, on=key)

    # No background recorded for this image: subtract nothing.
    sig["bg_flux"] = sig["bg_flux"].fillna(0.0)
    sig["bg_rad"] = sig["bg_rad"].fillna(0.0)

    sig["corr_total_flux"] = sig["norm_total_flux"] - sig["bg_flux"]
    sig["corr_avg_radiance"] = sig["norm_avg_radiance"] - sig["bg_rad"]
    return sig


def depot_summary(sig):
    """Summarize background-corrected signal per depot.

    For each depot the row count (``n``), the mean ``corr_total_flux`` and the
    mean ``corr_avg_radiance`` are computed. Only the depots iBAT, iWAT and
    eWAT are reported, in that order; any of them absent from *sig* is left
    out, and other depot labels are not included.

    Returns:
        DataFrame with columns ``depot``, ``n``, ``mean_total_flux`` and
        ``mean_avg_radiance``.
    """
    depot_order = ["iBAT", "iWAT", "eWAT"]

    per_depot = sig.groupby("depot").agg(
        n=("corr_avg_radiance", "size"),
        mean_total_flux=("corr_total_flux", "mean"),
        mean_avg_radiance=("corr_avg_radiance", "mean"),
    )
    # Reindexing fixes the display order; depots with no data become all-NaN
    # rows, which are then discarded.
    ordered = per_depot.reindex(depot_order)
    present = ordered.dropna(how="all")
    return present.reset_index()


def within_fold(sig, baseline="baseline"):
    """Compute within-animal fold change of corrected radiance vs. baseline.

    For every row of *sig* whose timepoint is not *baseline*, the row's
    ``corr_avg_radiance`` is divided by the baseline ``corr_avg_radiance`` of
    the same ``(animal_id, depot)``. If several baseline rows exist for one
    pair their mean is used. The fold is NaN when the pair has no baseline or
    the baseline value is zero.

    Args:
        sig: Background-corrected signal table with ``animal_id``, ``group``,
            ``depot``, ``timepoint`` and ``corr_avg_radiance`` columns.
        baseline: Timepoint label that marks the reference measurement.

    Returns:
        DataFrame with columns ``animal_id``, ``group``, ``depot``,
        ``timepoint`` and ``fold``. The columns are present even when there
        are no follow-up rows, so callers can group on them safely.
    """
    columns = ["animal_id", "group", "depot", "timepoint", "fold"]

    is_baseline = sig["timepoint"] == baseline
    # One scalar per (animal, depot); averaging keeps replicate baseline rows
    # from producing an ambiguous multi-valued lookup.
    base = (sig[is_baseline]
            .groupby(["animal_id", "depot"])["corr_avg_radiance"]
            .mean()
            .to_dict())

    rows = []
    for _, r in sig[~is_baseline].iterrows():
        b = base.get((r["animal_id"], r["depot"]))
        if b is None or b == 0:
            fold = np.nan
        else:
            fold = r["corr_avg_radiance"] / b
        rows.append({"animal_id": r["animal_id"], "group": r["group"],
                     "depot": r["depot"], "timepoint": r["timepoint"],
                     "fold": fold})

    result = pd.DataFrame(rows, columns=columns)
    # Keep the fold column numeric even when the table is empty.
    result["fold"] = result["fold"].astype(float)
    return result


def cohort_paired(sig, depot, timepoint, baseline="baseline"):
    """Run a per-group paired t-test of *timepoint* against *baseline*.

    Rows of *sig* for the given *depot* are split by ``group``. Within each
    group, ``corr_avg_radiance`` is pivoted to one row per animal and one
    column per timepoint (replicate rows are averaged by ``pivot_table``);
    animals missing either value are dropped. Groups that lack either
    timepoint entirely are skipped.

    Args:
        sig: Background-corrected signal table.
        depot: Depot label to analyse.
        timepoint: Follow-up timepoint label compared with the baseline.
        baseline: Timepoint label of the reference measurement.

    Returns:
        DataFrame with one row per analysed group and columns ``group``,
        ``n``, ``mean_<baseline>``, ``mean_<timepoint>``, ``paired_t`` and
        ``p_value``. The test statistics are NaN when fewer than two complete
        pairs exist. An empty DataFrame is returned when no group qualifies
        or when *timepoint* equals *baseline*.
    """
    # Comparing the baseline with itself would select the same pivot column
    # twice and yield a meaningless table.
    if timepoint == baseline:
        return pd.DataFrame()

    out = []
    for g, gdf in sig[sig["depot"] == depot].groupby("group"):
        piv = gdf.pivot_table(index="animal_id", columns="timepoint",
                              values="corr_avg_radiance")
        if baseline not in piv.columns or timepoint not in piv.columns:
            continue

        pair = piv[[baseline, timepoint]].dropna()
        n = len(pair)
        if n >= 2:
            t, p = stats.ttest_rel(pair[timepoint], pair[baseline])
        else:
            t, p = np.nan, np.nan

        out.append({
            "group": g,
            "n": n,
            "mean_%s" % baseline: pair[baseline].mean() if n else np.nan,
            "mean_%s" % timepoint: pair[timepoint].mean() if n else np.nan,
            "paired_t": t,
            "p_value": p,
        })
    return pd.DataFrame(out)


def classify(fold_df):
    """Assign a coarse response label from the mean fold change per depot.

    Mean fold is computed per depot over all rows of *fold_df*. A depot that
    is absent, or whose mean is NaN (no usable fold values), is treated as
    unchanged (fold 1) when the rules are evaluated. The rules are checked in
    order and the first match wins:

    1. iBAT >= 1.5 and iBAT > 1.2 * eWAT,
    2. eWAT <= 0.7,
    3. iBAT >= 1.2 and (iWAT >= 1.2 or eWAT <= 0.9),
    4. otherwise: minimal change / unclassified.

    Args:
        fold_df: Table with ``depot`` and ``fold`` columns. A table without
            those columns (e.g. an empty frame) is accepted and yields the
            fallback label.

    Returns:
        ``(means, label)`` where *means* maps depot name to its mean fold
        (NaN values are kept as computed) and *label* is the display string.
    """
    if {"depot", "fold"}.issubset(fold_df.columns):
        means = fold_df.groupby("depot")["fold"].mean().to_dict()
    else:
        # Nothing to aggregate, e.g. a dataset with baseline rows only.
        means = {}

    def _level(depot):
        # Missing and NaN both mean "no evidence of change" for the rules.
        value = means.get(depot, 1)
        return 1 if pd.isna(value) else value

    ibat, iwat, ewat = _level("iBAT"), _level("iWAT"), _level("eWAT")

    if ibat >= 1.5 and ibat > ewat * 1.2:
        label = "BAT 활성형 (iBAT 신호 우세 증가)"
    elif ewat <= 0.7:
        label = "지방염증 감소형 (eWAT 신호 감소)"
    elif ibat >= 1.2 and (iwat >= 1.2 or ewat <= 0.9):
        label = "depot 선택형 (depot별 차등 반응)"
    else:
        label = "변화 미미 / 미분류"
    return means, label


# --------------------------------------------------------------------------
# UI
# --------------------------------------------------------------------------
def _load_roi_table(uploaded, sample_requested):
    """Return the ROI table to analyse, or None when nothing can be loaded.

    An uploaded CSV takes precedence. Without an upload the bundled sample
    CSV is used on every run: Streamlit re-executes the script on each widget
    interaction, so loading the sample only on the first run or on a button
    click would leave later reruns without data.
    """
    if uploaded is not None:
        try:
            table = pd.read_csv(uploaded)
        except (pd.errors.EmptyDataError, pd.errors.ParserError,
                UnicodeDecodeError) as exc:
            st.sidebar.error("CSV 를 읽을 수 없습니다: %s" % exc)
            return None
        st.sidebar.success("업로드 CSV 로드됨")
        return table

    if not os.path.exists(SAMPLE_CSV):
        return None

    table = pd.read_csv(SAMPLE_CSV)
    st.session_state["loaded"] = True
    if sample_requested:
        st.sidebar.success("샘플(합성) 데이터 로드됨")
    return table


def main():
    """Render the AdipoGlowIVIS Streamlit page.

    Loads an ROI table (upload or bundled sample), validates the required
    columns, normalizes and background-corrects the signals, and renders five
    tabs: ingest/normalize, depot comparison, longitudinal fold change,
    reporter reference with QC flags, and cohort statistics with a
    downloadable text report.
    """
    st.set_page_config(page_title="AdipoGlowIVIS", layout="wide")
    st.title("AdipoGlowIVIS — rodent IVIS 지방 depot 광학영상 정량")
    st.caption("Obesity 도메인 · 동물실험 도구 (in vivo 이미지 정량)")
    st.warning(DISCLAIMER)

    st.sidebar.header("데이터 입력")
    up = st.sidebar.file_uploader("ROI CSV 업로드", type=["csv"])
    use_sample = st.sidebar.button("샘플 데이터 로드")

    df = _load_roi_table(up, use_sample)
    if df is None:
        st.info("좌측에서 CSV 를 업로드하거나 '샘플 데이터 로드' 를 누르세요.")
        st.stop()

    missing = [c for c in REQUIRED_COLS if c not in df.columns]
    if missing:
        st.error("필수 컬럼 누락: %s" % ", ".join(missing))
        st.stop()

    df = normalize(df)
    sig = subtract_background(df)
    # Shared by the fold tab and the report tab; compute once per run.
    fold_df = within_fold(sig)

    tab1, tab2, tab3, tab4, tab5 = st.tabs(
        ["1. Ingest+Normalize", "2. 배경차감·depot", "3. 종단 fold",
         "4. Reference·QC", "5. Cohort·리포트"])

    # ---- Tab 1: raw vs. normalized values ----
    with tab1:
        st.subheader("ROI ingest + 노출/binning normalize")
        st.write("기준 노출 %gs · binning %g 로 normalize (이미지 간 비교 가능 단위)."
                 % (REF_EXPOSURE, REF_BINNING))
        st.dataframe(df[["animal_id", "group", "timepoint", "depot",
                         "exposure_sec", "binning", "avg_radiance",
                         "norm_avg_radiance", "total_flux_ps", "norm_total_flux",
                         "substrate_post_min", "is_background"]],
                     use_container_width=True)

    # ---- Tab 2: background-corrected depot comparison ----
    with tab2:
        st.subheader("배경/자가형광 차감 + depot 비교")
        ds = depot_summary(sig)
        st.dataframe(ds, use_container_width=True)
        st.caption("total flux=전체 신호 / avg radiance=밀도 → ROI 크기 의존성 분리")
        if not ds.empty:
            fig, ax = plt.subplots(1, 2, figsize=(9, 3))
            ax[0].bar(ds["depot"], ds["mean_total_flux"], color="#c44")
            ax[0].set_title("mean total flux")
            ax[0].set_ylabel("photons/s")
            ax[1].bar(ds["depot"], ds["mean_avg_radiance"], color="#48c")
            ax[1].set_title("mean avg radiance")
            ax[1].set_ylabel("p/s/cm2/sr")
            for a in ax:
                a.tick_params(axis="x", rotation=0)
            st.pyplot(fig)
            # pyplot keeps every figure alive until closed; the script reruns
            # on each interaction, so release it explicitly.
            plt.close(fig)

    # ---- Tab 3: longitudinal fold change ----
    with tab3:
        st.subheader("종단 within-animal fold-change (paired)")
        st.dataframe(fold_df, use_container_width=True)
        if not fold_df.empty:
            fsum = fold_df.groupby(["group", "depot"])["fold"].mean().reset_index()
            st.write("군×depot 평균 fold")
            st.dataframe(fsum, use_container_width=True)
        sub = df[(df["is_background"] == 0)]["substrate_post_min"].dropna()
        if len(sub):
            st.info("기질(luciferin) 주입 후 경과시간: %g~%g분 (평균 %.1f분). "
                    "peak(통상 10-15분) 기준 시점 정렬 권장."
                    % (sub.min(), sub.max(), sub.mean()))

    # ---- Tab 4: reporter reference + QC flags ----
    with tab4:
        st.subheader("reporter / model reference + QC flag")
        if os.path.exists(REF_JSON):
            # Explicit encoding so the read does not depend on the platform
            # default.
            with open(REF_JSON, encoding="utf-8") as f:
                ref = json.load(f)
            for name, info in ref["reporters"].items():
                with st.expander("%s — %s" % (name, info["full_name"])):
                    st.write("primary depot: %s · substrate: %s · peak %s분"
                             % (info["primary_depot"], info["substrate"],
                                info["typical_peak_min"]))
                    st.json(info["models"])
        st.markdown("**QC flags**")
        flags = qc(df)
        if flags:
            st.table(pd.DataFrame(flags, columns=["유형", "내용"]))
        else:
            st.success("플래그 없음")

    # ---- Tab 5: cohort statistics, classification, report ----
    with tab5:
        st.subheader("cohort 통계 + 약력 분류 + 리포트")
        depots = [d for d in ["iBAT", "iWAT", "eWAT"] if d in set(sig["depot"])]
        tps = [t for t in sig["timepoint"].unique() if t != "baseline"]
        c1, c2 = st.columns(2)
        depot = c1.selectbox("depot", depots, index=0 if depots else None)
        tp = c2.selectbox("비교 시점", tps, index=0 if tps else None)
        if depot and tp:
            ct = cohort_paired(sig, depot, tp)
            st.dataframe(ct, use_container_width=True)
            st.caption("paired t-test (scipy.stats.ttest_rel): 군 내 baseline vs 선택 시점")

        if fold_df.empty:
            # Without follow-up timepoints there is nothing to classify or
            # report.
            st.info("baseline 이외 시점 데이터가 없어 약력 분류·리포트를 생성할 수 없습니다.")
        else:
            means, label = classify(fold_df)
            fold_text = ", ".join("%s=%.2f" % (k, v) for k, v in means.items())
            st.success("약력 분류: %s" % label)
            st.write("depot별 평균 fold: " + fold_text)

            ko = ("[방법] depot ROI total flux·average radiance 를 노출(%gs)·binning(%g) "
                  "normalize 후 배경 ROI 차감, depot 비교·within-animal fold·군별 paired "
                  "t-test 산출. [결과] depot 평균 fold %s. 약력 분류: %s."
                  % (REF_EXPOSURE, REF_BINNING, fold_text, label))
            en = ("[Methods] Depot ROI total flux/average radiance normalized to %gs "
                  "exposure & binning %g, background-subtracted; depot comparison, "
                  "within-animal fold, per-group paired t-tests computed. "
                  "[Results] Mean fold %s. Class: %s."
                  % (REF_EXPOSURE, REF_BINNING, fold_text, label))
            st.markdown("**리포트 (국문)**")
            st.code(ko)
            st.markdown("**Report (English)**")
            st.code(en)
            st.download_button("리포트 텍스트 다운로드",
                               data=(ko + "\n\n" + en).encode("utf-8"),
                               file_name="adipoglowivis_report.txt")


if __name__ == "__main__":
    main()
