#!/usr/bin/env python3
"""
app.py — optional Streamlit UI for WtLossSurrogate-Kor.

Mirrors the CLI: upload a trials CSV (or use the bundled demo), then view the
surrogacy table, the R²_trial scatter (surrogate vs hard-outcome effect with
WLS fit + prediction band + STE), the dose–response curve, the PTE bar, the
paradox flags, and the validation-hypothesis cards.

Run (optional — not required for verification):
    streamlit run app.py

OFFLINE only. ⚠️  연구용·참고용 (research/reference use only) —
not for clinical decision-making.
"""

from __future__ import annotations

import io
import math

import matplotlib
matplotlib.use("Agg")  # headless backend; safe under Streamlit
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import streamlit as st

import dataio
import surrogacy as sg


# Configure the browser-tab title and use the wide page layout.
_PAGE_TITLE = "WtLossSurrogate-Kor"
st.set_page_config(layout="wide", page_title=_PAGE_TITLE)


# --------------------------------------------------------------------------- #
# Header + disclaimer
# --------------------------------------------------------------------------- #

# Research-only notice shown in an error-styled box directly under the title.
_RESEARCH_ONLY_NOTICE = (
    "⚠️ 연구용·참고용 (research/reference use only) — not for clinical "
    "decision-making. Bundled effect sizes are illustrative/synthetic, NOT "
    "official trial readouts."
)

st.title("WtLossSurrogate-Kor (웨이트로스서로게이트코어)")
st.caption("Obesity (비만대사질환) · 연구 아이디어 생성 (research-hypothesis generation)")
st.error(_RESEARCH_ONLY_NOTICE)


# --------------------------------------------------------------------------- #
# Data input
# --------------------------------------------------------------------------- #

@st.cache_data(show_spinner=False)
def _load_default():
    """Return the demo trials table produced by ``dataio.load_trials()``.

    Decorated with ``st.cache_data`` (spinner hidden) so Streamlit can reuse
    the loaded result instead of calling the loader on every script rerun.
    """
    demo_trials = dataio.load_trials()
    return demo_trials


# Sidebar controls: the CSV uploader (``up``) and the cap on how many
# hypotheses are listed (``top``). Written with ``st.sidebar.<element>``
# object notation, which places each widget in the sidebar.
st.sidebar.header("Data")
up = st.sidebar.file_uploader("Upload trials CSV", type=["csv"])
st.sidebar.caption(
    "Schema: trial, drug_class, pct_weight_loss, pct_weight_loss_se, "
    "hard_outcome, loghr, loghr_se"
)
top = st.sidebar.number_input(
    "Limit (top N hypotheses)",
    min_value=1,
    max_value=100,
    value=10,
    step=1,
)

# Resolve the active dataset ``df``: a cleaned upload when one is provided and
# usable, otherwise the demo data. Every failure path warns the user and falls
# back to the demo data so the rest of the page always has rows to work with.
if up is not None:
    try:
        # "utf-8-sig" decodes plain UTF-8 too, but also strips a leading
        # byte-order mark. With plain "utf-8" a BOM stays glued to the first
        # header ("\ufefftrial"), which then looks like a missing column.
        raw = up.getvalue().decode("utf-8-sig")
        df = pd.read_csv(io.StringIO(raw), comment="#", skipinitialspace=True)
        # Validation is re-applied by hand here; per the original note this is
        # meant to mirror dataio's loader logic.
        df.columns = [c.strip() for c in df.columns]
        missing = [c for c in dataio.REQUIRED_COLS if c not in df.columns]
        if missing:
            st.warning(f"Missing columns {missing}; falling back to demo data.")
            df = _load_default()
        else:
            # Unparseable numbers become NaN and are dropped below.
            for c in dataio.NUMERIC_COLS:
                df[c] = pd.to_numeric(df[c], errors="coerce")
            df["trial"] = df["trial"].astype(str).str.strip()
            key_cols = ["drug_class", "hard_outcome"]
            for c in key_cols:
                was_present = df[c].notna()
                text = df[c].astype(str).str.strip()
                # astype(str) turns a missing cell into the literal "nan",
                # which dropna() cannot see. Restore missing-ness (and treat
                # blank labels as missing) so such rows are really dropped.
                df[c] = text.where(was_present & (text != ""))
            df = df.dropna(
                subset=list(dataio.NUMERIC_COLS) + key_cols
            ).reset_index(drop=True)
            # Non-positive standard errors are replaced by a small positive
            # floor (same value the original code used).
            for c in ["pct_weight_loss_se", "loghr_se"]:
                df.loc[df[c] <= 0, c] = 1e-3
            if df.empty:
                # Nothing survived cleaning: use the same fallback policy as
                # the other failure paths instead of passing on an empty frame.
                st.warning("No usable rows in upload after cleaning; "
                           "falling back to demo data.")
                df = _load_default()
    except Exception as e:
        # Deliberately broad: this is the UI boundary for arbitrary user
        # files; the error is surfaced to the user rather than swallowed.
        st.warning(f"Could not parse upload ({e}); using demo data.")
        df = _load_default()
else:
    df = _load_default()


# Headline counts for the active dataset, rendered as three side-by-side
# metric tiles (label shown, key read from the ``summary_counts`` result).
counts = dataio.summary_counts(df)
_metric_specs = (
    ("Trials", "n_trials"),
    ("Drug classes", "n_classes"),
    ("Hard outcomes", "n_outcomes"),
)
for _slot, (_label, _key) in zip(st.columns(3), _metric_specs):
    _slot.metric(_label, counts[_key])


# --------------------------------------------------------------------------- #
# Tabs
# --------------------------------------------------------------------------- #

# One tab per analysis view; the order here fixes the order of the handles
# unpacked below.
_TAB_LABELS = [
    "Surrogacy table",
    "R²_trial scatter / STE",
    "Dose–response",
    "PTE (weight-mediated)",
    "Paradox",
    "Hypotheses",
]
(tab_surr, tab_scatter, tab_dose,
 tab_pte, tab_paradox, tab_hyp) = st.tabs(_TAB_LABELS)


# ---- 1. Surrogacy table -------------------------------------------------- #
with tab_surr:
    st.subheader("Trial-level surrogacy by class × outcome")
    # One table row per class × outcome result that has at least one trial.
    # NaN R² values and absent STE/PTE values are shown as empty cells (None);
    # everything else is rounded for display.
    table_records = [
        {
            "class": res.drug_class,
            "outcome": res.hard_outcome,
            "k": res.n_trials,
            "R2_trial": round(res.r2_trial, 3) if not math.isnan(res.r2_trial) else None,
            "R2_CI_low": round(res.r2_ci_low, 3) if not math.isnan(res.r2_ci_low) else None,
            "R2_CI_high": round(res.r2_ci_high, 3) if not math.isnan(res.r2_ci_high) else None,
            "STE_%": round(res.ste, 1) if res.ste is not None else None,
            "PTE": round(res.pte, 3) if res.pte is not None else None,
            "PTE_flag": res.pte_flag,
            "grade": res.grade.upper(),
            "paradox": res.paradox,
        }
        for res in sg.all_surrogacy(df)
        if res.n_trials != 0
    ]
    surrogacy_table = pd.DataFrame(table_records)
    st.dataframe(surrogacy_table, width="stretch")


# ---- 2. Scatter + WLS fit + band + STE ----------------------------------- #
with tab_scatter:
    st.subheader("Surrogate (% weight loss) vs hard-outcome effect (log-HR)")
    classes = sorted(df["drug_class"].unique())
    outcomes = sorted(df["hard_outcome"].unique())
    sc1, sc2 = st.columns(2)
    # Preselect the GLP1RA / MACE cell when present, else the first option.
    cls = sc1.selectbox("Drug class", classes,
                        index=classes.index("GLP1RA") if "GLP1RA" in classes else 0)
    out = sc2.selectbox("Hard outcome", outcomes,
                        index=outcomes.index("MACE") if "MACE" in outcomes else 0)

    r = sg.surrogacy_for(df, cls, out)
    if r is None or r.n_trials < sg.MIN_TRIALS_REGRESSION:
        # Report the threshold actually used in the check above rather than a
        # hard-coded number that could drift from it.
        st.info(f"Need ≥{sg.MIN_TRIALS_REGRESSION} trials in this cell "
                "to fit a regression.")
    else:
        fit = sg.wls(r.x, r.y, r.w)
        # Grid for the fitted line: from 15% beyond the largest loss (at least
        # -1.15) up to 0, extended to the right when a trial has a positive
        # weight change so the line and band always span every plotted point.
        x_lo = min(r.x.min(), -1) * 1.15
        x_hi = max(r.x.max(), 0.0)
        xs = np.linspace(x_lo, x_hi, 120)
        yhat = fit["intercept"] + fit["slope"] * xs
        # Element [1] of each predict_with_band() result is used as the band
        # half-width around the fitted value.
        half = np.array([sg.predict_with_band(fit, xv)[1] for xv in xs])

        fig, ax = plt.subplots(figsize=(6.5, 4.5))
        try:
            # Marker area scales with each trial's weight relative to the
            # largest weight.
            sizes = 1200 * (r.w / r.w.max())
            ax.scatter(r.x, r.y, s=sizes, alpha=0.55, label="trials (size ∝ weight)")
            ax.plot(xs, yhat, color="#c0392b", label="WLS fit")
            ax.fill_between(xs, yhat - half, yhat + half, color="#c0392b", alpha=0.12,
                            label="95% prediction band")
            ax.axhline(0, color="gray", lw=1, ls="--")
            if r.ste is not None:
                ax.axvline(r.ste, color="#2980b9", ls=":", lw=1.5,
                           label=f"STE = {r.ste:.1f}%")
            ax.set_xlabel("% body-weight change (surrogate; negative = loss)")
            ax.set_ylabel("hard-outcome effect (log-HR; negative = benefit)")
            ax.set_title(f"{cls} → {out}   R²_trial={r.r2_trial:.2f}  ({r.grade.upper()})")
            ax.legend(fontsize=8)
            st.pyplot(fig)
        finally:
            # The script reruns on every interaction; without an explicit
            # close, pyplot keeps each rerun's figure alive and they pile up.
            plt.close(fig)
        st.write(f"**R²_trial** = {r.r2_trial:.3f} "
                 f"(95% CI {r.r2_ci_low:.2f}–{r.r2_ci_high:.2f}) · "
                 f"**STE** = {('%.1f%%' % r.ste) if r.ste is not None else '—'} · "
                 f"**grade** = {r.grade.upper()}")


# ---- 3. Dose–response ---------------------------------------------------- #
with tab_dose:
    st.subheader("Dose–response surrogacy (more weight loss → more hard benefit?)")
    classes = sorted(df["drug_class"].unique())
    outcomes = sorted(df["hard_outcome"].unique())
    d1, d2 = st.columns(2)
    # Explicit keys (and trailing-space labels) keep these widgets distinct
    # from the same-purpose selectors on the scatter tab.
    cls2 = d1.selectbox("Drug class ", classes,
                        index=classes.index("GLP1RA") if "GLP1RA" in classes else 0,
                        key="dose_cls")
    out2 = d2.selectbox("Hard outcome ", outcomes,
                        index=outcomes.index("INCIDENT_T2D") if "INCIDENT_T2D" in outcomes else 0,
                        key="dose_out")
    dr = sg.dose_response(df, cls2, out2)
    if dr is None:
        st.info("Need ≥3 trials in this cell.")
    else:
        st.write(f"**Verdict:** {dr.verdict}")
        # Each entry of dr.bins is laid out as (bin label, mean weight loss,
        # mean log-HR, trial count), matching the column names given here.
        bdf = pd.DataFrame(dr.bins, columns=["wl_bin", "mean_wl", "mean_logHR", "k"])
        st.dataframe(bdf, width="stretch")
        fig2, ax2 = plt.subplots(figsize=(6.5, 4))
        try:
            ax2.plot(bdf["mean_wl"], bdf["mean_logHR"], "o-", color="#27ae60")
            ax2.axhline(0, color="gray", ls="--", lw=1)
            ax2.set_xlabel("mean % weight loss (per bin)")
            ax2.set_ylabel("mean log-HR (negative = benefit)")
            ax2.set_title(f"Dose–response: {cls2} → {out2}")
            st.pyplot(fig2)
        finally:
            # The script reruns on every interaction; without an explicit
            # close, pyplot keeps each rerun's figure alive and they pile up.
            plt.close(fig2)


# ---- 4. PTE bar ---------------------------------------------------------- #
with tab_pte:
    st.subheader("PTE — weight-mediated fraction of the hard-outcome benefit")
    # (tick label, PTE) for every class × outcome result that reports a PTE.
    pte_rows = [
        (f"{res.drug_class}\n{res.hard_outcome}", res.pte)
        for res in sg.all_surrogacy(df)
        if res.pte is not None
    ]
    if not pte_rows:
        st.info("No cell had enough trials to estimate PTE.")
    else:
        labels = [label for label, _ in pte_rows]
        vals = [value for _, value in pte_rows]
        # Default axis is [0, 1]; widen it if any reported PTE falls outside
        # that range so the bar is drawn in full instead of silently clipped.
        # Non-finite values are ignored because axis limits must be finite.
        finite_vals = [v for v in vals if math.isfinite(v)]
        y_lo = min([0.0, *finite_vals])
        y_hi = max([1.0, *finite_vals])

        fig3, ax3 = plt.subplots(figsize=(7, 4))
        try:
            # Orange marks bars under the 0.6 reference line, teal the rest.
            colors = ["#e67e22" if v < 0.6 else "#16a085" for v in vals]
            ax3.bar(range(len(vals)), vals, color=colors)
            ax3.set_xticks(range(len(vals)))
            ax3.set_xticklabels(labels, fontsize=8)
            ax3.set_ylim(y_lo, y_hi)
            ax3.axhline(0.6, color="gray", ls="--", lw=1, label="0.6 (large direct effect below)")
            ax3.set_ylabel("PTE (fraction mediated by weight)")
            ax3.set_title("Weight-mediated vs weight-independent effect")
            ax3.legend(fontsize=8)
            st.pyplot(fig3)
        finally:
            # The script reruns on every interaction; without an explicit
            # close, pyplot keeps each rerun's figure alive and they pile up.
            plt.close(fig3)
        st.caption("Bars below 0.6 (orange) imply a substantial weight-INDEPENDENT "
                   "(direct) component — the SELECT debate.")


# ---- 5. Paradox ---------------------------------------------------------- #
with tab_paradox:
    st.subheader("Surrogate-paradox flags")
    # Every result returned by the paradox scan gets its own warning box; an
    # empty scan result is reported as a success message instead.
    flagged_cells = sg.paradox_scan(df)
    if flagged_cells:
        for cell in flagged_cells:
            paradox_msg = (
                f"**{cell.drug_class} → {cell.hard_outcome}** (k={cell.n_trials}): "
                "weight improves but hard outcome trends adverse → surrogacy "
                "INVALID here. High-priority validation target."
            )
            st.warning(paradox_msg)
    else:
        st.success("No surrogate paradox detected.")


# ---- 6. Hypotheses ------------------------------------------------------- #
with tab_hyp:
    st.subheader("Validation-study hypotheses (mined gaps)")
    # Show at most ``top`` mined hypotheses (the sidebar limit), one bordered
    # card each, numbered from H1.
    mined = sg.mine_gaps(df)
    shown = mined[: int(top)]
    for rank, hyp in enumerate(shown, start=1):
        # A falsy per-arm size (None or 0) is displayed as "n/a".
        if hyp.suggested_n_per_arm:
            per_arm = f"{hyp.suggested_n_per_arm:,}/arm"
        else:
            per_arm = "n/a"
        # A missing R² is shown as an em dash, otherwise rounded to 2 places.
        r2_shown = "—" if hyp.r2_trial is None else round(hyp.r2_trial, 2)
        headline = f"**H{rank} · priority {hyp.priority:.2f}** — {hyp.statement}"
        details = (
            f"class={hyp.drug_class} · outcome={hyp.hard_outcome} · "
            f"k(existing)={hyp.n_trials} · "
            f"R²_trial={r2_shown} · "
            f"suggest ~{hyp.suggested_trials} more trial(s), per-arm ≈ {per_arm} "
            "(Schoenfeld, 6% events, 80% power)"
        )
        with st.container(border=True):
            st.markdown(headline)
            st.caption(details)


# Page footer: a divider followed by the tooling note and the shared
# disclaimer text exposed by the surrogacy module.
st.divider()
_footer_note = (
    "OFFLINE tool · pure numpy/scipy meta-regression (statsmodels not "
    "required). "
)
st.caption(_footer_note + sg.DISCLAIMER)
