#!/usr/bin/env python3
"""
GlyceSurrogate-Kor -- optional Streamlit UI (mirrors the CLI in main.py).

Run with:   streamlit run app.py
This file is OPTIONAL and OFFLINE. It imports cleanly without launching; all heavy
logic lives in surrogacy.py so the CLI and UI share one engine.

  RESEARCH / REFERENCE USE ONLY -- NOT FOR CLINICAL DECISION-MAKING.
"""
from __future__ import annotations

import io
import os
import sys

import numpy as np
import pandas as pd

import matplotlib
matplotlib.use("Agg")          # headless / offline-safe backend
import matplotlib.pyplot as plt  # noqa: E402
from matplotlib.figure import Figure  # noqa: E402

import streamlit as st         # noqa: E402

# Make the local engine importable when launched via `streamlit run app.py`:
# prepend this file's own directory to sys.path (only once) so that the
# `import surrogacy` below resolves to the sibling surrogacy.py engine.
# Order matters: the sys.path insertion must stay above the import.
_HERE = os.path.dirname(os.path.abspath(__file__))
if _HERE not in sys.path:
    sys.path.insert(0, _HERE)

import surrogacy as S          # noqa: E402

# Path of the bundled synthetic demo dataset, used when no CSV is uploaded.
DEFAULT_DATA = os.path.join(_HERE, "data", "demo_trials.csv")


# --------------------------------------------------------------------------------------
# Cached engine calls
# --------------------------------------------------------------------------------------
@st.cache_data(show_spinner=False)
def _load(path_or_buffer):
    """Parse trial-level data with ``S.load_data``, memoised by Streamlit.

    Callers in this file pass either a filesystem path (the bundled demo CSV)
    or an in-memory text buffer holding uploaded CSV content; the argument is
    forwarded to the engine unchanged.
    """
    frame = S.load_data(path_or_buffer)
    return frame


@st.cache_data(show_spinner=False)
def _analyze(df):
    """Run the full surrogacy analysis on ``df`` (memoised by Streamlit).

    Returns a pair ``(table, results)``: the tabular summary built by
    ``S.results_to_frame`` and the raw results from ``S.analyze_all`` that the
    summary was derived from.
    """
    cells = S.analyze_all(df)
    summary = S.results_to_frame(cells)
    return summary, cells


# --------------------------------------------------------------------------------------
# Plot helpers
# --------------------------------------------------------------------------------------
def scatter_fig(df, drug_class, surrogate, hard_outcome):
    """Draw surrogate effect vs. log-HR for one class/surrogate/outcome cell.

    Each trial in the cell is one point, with an error bar of
    ±1.959964·SE on log-HR. Marker area grows with the inverse-variance
    weight 1/SE². When the cell has at least ``S.MIN_TRIALS_FOR_REGRESSION``
    trials and the surrogate effects are not all (nearly) equal, the plot also
    shows a weighted least-squares line and its 95% prediction band.

    The figure is created with the object-oriented ``Figure`` API instead of
    ``pyplot``. It is therefore never registered in pyplot's global figure
    manager, so it does not accumulate there across Streamlit reruns. It also
    does not touch pyplot's shared, non-thread-safe state.

    Returns:
        ``(fig, fit)``: the Figure, and the fit object from ``S._wls_fit``
        (indexed by ``"b0"``, ``"b1"``, ``"r2"`` below). ``fit`` is ``None``
        when no regression line was drawn.
    """
    cell = df[
        (df["drug_class"] == drug_class)
        & (df["surrogate"] == surrogate)
        & (df["hard_outcome"] == hard_outcome)
    ]
    fig = Figure(figsize=(6, 4.2))
    ax = fig.add_subplot()
    if cell.empty:
        ax.text(0.5, 0.5, "no trials for this cell", ha="center", va="center")
        ax.axis("off")
        return fig, None

    x = cell["delta_surrogate"].to_numpy(float)
    y = cell["loghr"].to_numpy(float)
    se = cell["loghr_se"].to_numpy(float)
    yerr = 1.959964 * se      # half-width of a normal-theory 95% CI
    w = 1.0 / se ** 2         # inverse-variance weights

    # Bubble size ~ inverse-variance weight (largest trial gets 440 pt²).
    sizes = 40 + 400 * (w / w.max())
    ax.errorbar(x, y, yerr=yerr, fmt="none", ecolor="#bbbbbb", capsize=3, zorder=1)
    ax.scatter(x, y, s=sizes, alpha=0.7, edgecolor="k", zorder=2)
    for _, r in cell.iterrows():
        ax.annotate(r["trial"], (r["delta_surrogate"], r["loghr"]),
                    fontsize=7, xytext=(4, 4), textcoords="offset points")

    fit = None
    # A regression needs enough trials and some spread in the surrogate effect.
    if len(cell) >= S.MIN_TRIALS_FOR_REGRESSION and not np.allclose(x, x[0]):
        fit = S._wls_fit(x, y, w)
        gx = np.linspace(x.min(), x.max(), 50)
        gy = fit["b0"] + fit["b1"] * gx
        ax.plot(gx, gy, color="C3", lw=2, label=f"WLS fit (R²={fit['r2']:.2f})")
        # Prediction band: S._pred_band returns (pred, lo, hi) per grid point.
        bands = [S._pred_band(fit, xv) for xv in gx]
        lo = [b[1] for b in bands]
        hi = [b[2] for b in bands]
        ax.fill_between(gx, lo, hi, color="C3", alpha=0.12, label="95% pred. band")
        ax.legend(fontsize=8)

    ax.axhline(0.0, color="k", lw=0.8, ls="--")
    ax.set_xlabel(f"Δ {surrogate} (treatment effect on surrogate)")
    ax.set_ylabel("log-HR (treatment effect on hard outcome)")
    ax.set_title(f"{drug_class} | {surrogate} → {hard_outcome}", fontsize=10)
    fig.tight_layout()
    return fig, fit


def ste_curve_fig(df, drug_class, surrogate, hard_outcome):
    """Plot the WLS-predicted log-HR curve and the STE for one cell.

    The regression is fitted with ``S._wls_fit``. It is evaluated on a grid
    that extends the observed surrogate-effect range by ``span`` on each side,
    where ``span = max(range, 0.5)``. The plot shows the curve with its 95%
    prediction band from ``S._pred_band`` and a dashed line at log-HR = 0.
    When ``S.compute_ste`` returns a value, a dotted vertical line marks the
    surrogate threshold effect.

    A placeholder figure is returned instead when the cell is empty, has fewer
    than ``S.MIN_TRIALS_FOR_REGRESSION`` trials, or has no spread in the
    surrogate effect.

    The figure uses the object-oriented ``Figure`` API, so it stays out of
    pyplot's global figure registry across Streamlit reruns.
    """
    cell = df[
        (df["drug_class"] == drug_class)
        & (df["surrogate"] == surrogate)
        & (df["hard_outcome"] == hard_outcome)
    ]
    fig = Figure(figsize=(6, 3.6))
    ax = fig.add_subplot()
    x = cell["delta_surrogate"].to_numpy(float)
    y = cell["loghr"].to_numpy(float)
    w = 1.0 / cell["loghr_se"].to_numpy(float) ** 2
    # The explicit empty check guarantees x[0] is never evaluated on an empty
    # cell, independent of the value of MIN_TRIALS_FOR_REGRESSION.
    if (
        cell.empty
        or len(cell) < S.MIN_TRIALS_FOR_REGRESSION
        or np.allclose(x, x[0])
    ):
        ax.text(0.5, 0.5, "insufficient data for STE curve", ha="center", va="center")
        ax.axis("off")
        return fig
    fit = S._wls_fit(x, y, w)
    span = max(x.max() - x.min(), 0.5)
    gx = np.linspace(x.min() - span, x.max() + span, 200)
    # S._pred_band yields (pred, lo, hi) per grid point; transpose into 3 series.
    yh, lo, hi = zip(*[S._pred_band(fit, xv) for xv in gx])
    ax.plot(gx, yh, color="C0", label="predicted log-HR")
    ax.fill_between(gx, lo, hi, color="C0", alpha=0.15, label="95% pred. band")
    ax.axhline(0.0, color="k", lw=0.8, ls="--")
    ste = S.compute_ste(fit, surrogate)
    if ste is not None:
        ax.axvline(ste, color="C3", lw=1.5, ls=":", label=f"STE = {ste:.3f}")
    ax.set_xlabel(f"Δ {surrogate}")
    ax.set_ylabel("predicted log-HR")
    ax.set_title("Surrogate Threshold Effect (STE)", fontsize=10)
    ax.legend(fontsize=8)
    fig.tight_layout()
    return fig


# --------------------------------------------------------------------------------------
# App
# --------------------------------------------------------------------------------------
def _read_upload(up):
    """Decode an uploaded CSV into an in-memory text buffer for ``_load``.

    ``utf-8-sig`` decodes exactly like UTF-8 but also strips a leading
    byte-order mark. Spreadsheet "CSV UTF-8" exports often prepend one, and
    depending on how ``S.load_data`` parses, it could otherwise stick to the
    first column header. Bytes that are not valid UTF-8 produce a sidebar
    error and halt this script run instead of surfacing a raw traceback.
    """
    try:
        text = up.getvalue().decode("utf-8-sig")
    except UnicodeDecodeError as err:
        st.sidebar.error(
            f"Could not read '{up.name}' as UTF-8 text ({err.reason} at byte "
            f"{err.start}). Please re-save the file as a UTF-8 CSV and upload it again."
        )
        st.stop()  # raises internally; nothing below runs on this rerun
    return io.StringIO(text)


def _load_source_data():
    """Render the sidebar data picker and return the trial-level DataFrame.

    Uses the uploaded CSV when one is present; otherwise falls back to the
    bundled demo dataset at ``DEFAULT_DATA``.
    """
    st.sidebar.header("Data")
    up = st.sidebar.file_uploader("Upload trial-level CSV", type=["csv"])
    if up is None:
        df = _load(DEFAULT_DATA)
        st.sidebar.info("Using bundled demo dataset (synthetic).")
    else:
        df = _load(_read_upload(up))
        st.sidebar.success(f"Loaded {len(df)} rows from upload.")
    return df


def _render_surrogacy_tab(table, results):
    """Render the per-cell surrogacy table and its headline metrics."""
    st.subheader("Trial-level surrogacy by class × surrogate × outcome")
    show = table[[
        "drug_class", "surrogate", "hard_outcome", "n_trials",
        "r2_trial", "r2_ci_lo", "r2_ci_hi", "ste", "pte", "grade",
        "has_paradox",
    ]]
    st.dataframe(show, use_container_width=True)
    c1, c2, c3 = st.columns(3)
    c1.metric("Surrogacy cells", len(results))
    c2.metric("Paradox flags", int(table["has_paradox"].sum()))
    c3.metric("Strong cells", int((table["grade"] == "strong").sum()))


def _render_plots_tab(df):
    """Render the cell pickers plus the scatter and STE figures for the chosen cell."""
    st.subheader("R²_trial scatter + WLS fit, and STE curve")
    classes = sorted(df["drug_class"].unique())
    col = st.columns(3)
    dc = col[0].selectbox("Drug class", classes)
    surr = col[1].selectbox(
        "Surrogate", sorted(df[df["drug_class"] == dc]["surrogate"].unique())
    )
    outs = sorted(
        df[(df["drug_class"] == dc) & (df["surrogate"] == surr)]["hard_outcome"].unique()
    )
    ho = col[2].selectbox("Hard outcome", outs)
    left, right = st.columns(2)
    # Close each figure once rendered: Streamlit reruns this script on every
    # interaction, and figures held by pyplot would otherwise accumulate.
    # plt.close is a no-op for figures pyplot does not manage.
    with left:
        fig, _ = scatter_fig(df, dc, surr, ho)
        st.pyplot(fig)
        plt.close(fig)
    with right:
        fig = ste_curve_fig(df, dc, surr, ho)
        st.pyplot(fig)
        plt.close(fig)


def _render_paradox_tab(df):
    """Render trials flagged by ``S.detect_paradox``, or a success note if none."""
    st.subheader("Surrogate-paradox flags (ACCORD-style)")
    st.write("Surrogate improved but hard outcome worsened (HR > 1).")
    flags = S.detect_paradox(df)
    if flags:
        st.dataframe(pd.DataFrame(flags), use_container_width=True)
    else:
        st.success("No paradox flags detected.")


def _render_hypotheses_tab(df, results):
    """Render the top-N validation hypotheses mined by ``S.mine_gaps``.

    The "Show top N" slider is only shown when there are at least two
    hypotheses. With zero hypotheses its default (0) would sit below its
    minimum (1), and with one its minimum would equal its maximum.
    """
    st.subheader("Mined unvalidated pairs → validation hypotheses")
    hyps = S.mine_gaps(df, results)
    n_hyps = len(hyps)
    if n_hyps == 0:
        st.info("No unvalidated pairs were mined from this dataset.")
        return
    if n_hyps == 1:
        topn = 1
    else:
        topn = st.slider("Show top N", 1, n_hyps, min(10, n_hyps))
    for i, h in enumerate(hyps[:topn], 1):
        ss = h["suggestion"]
        with st.expander(
            f"H{i}. [{h['kind']}] {h['drug_class']} | {h['surrogate']} → "
            f"{h['hard_outcome']}  (grade={h['grade']}, n={h['n_trials']})"
        ):
            st.markdown(f"**Hypothesis:** {h['hypothesis']}")
            st.markdown(f"**Why flagged:** {', '.join(h['reasons'])}")
            st.markdown(
                f"**Suggested study:** +{ss['n_additional_trials']} trial(s); "
                f"per-arm n ≈ {ss['approx_per_arm_n']:,} "
                f"(target |logHR| = {ss['target_loghr']}, assumed event rate = "
                f"{ss['assumed_event_rate']}, ~{ss['approx_total_events']} events)."
            )


def _render_raw_tab(df, table):
    """Render the loaded data and a CSV download of the surrogacy table."""
    st.subheader("Loaded trial-level data")
    st.dataframe(df, use_container_width=True)
    st.download_button(
        "Download surrogacy table (CSV)",
        data=table.to_csv(index=False).encode("utf-8"),
        file_name="glyce_surrogacy_table.csv",
        mime="text/csv",
    )


def main():
    """Build the Streamlit page: header and disclaimer, data source, analysis, five tabs."""
    st.set_page_config(page_title="GlyceSurrogate-Kor", layout="wide")
    st.title("GlyceSurrogate-Kor 글라이스서로게이트코어")
    st.caption("Domain: DM (당뇨병)  |  Category: 연구 아이디어 생성  |  "
               "Trial-level glycemic-surrogate validity meta-regression")
    st.error("⚠️ " + S.DISCLAIMER)

    df = _load_source_data()
    table, results = _analyze(df)

    tab_sur, tab_scatter, tab_par, tab_hyp, tab_raw = st.tabs(
        ["Surrogacy table", "R²/STE plots", "Paradox", "Hypotheses", "Raw data"]
    )
    with tab_sur:
        _render_surrogacy_tab(table, results)
    with tab_scatter:
        _render_plots_tab(df)
    with tab_par:
        _render_paradox_tab(df)
    with tab_hyp:
        _render_hypotheses_tab(df, results)
    with tab_raw:
        _render_raw_tab(df, table)


# Entry point. The guard means that merely importing this module (e.g. from
# tests) does not build the UI. `streamlit run app.py` presumably executes the
# script with __name__ == "__main__", which is what triggers main() here.
if __name__ == "__main__":
    main()
