#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
NITSurrogate-Kor — Streamlit UI (optional)
==========================================
Mirrors the main.py CLI: upload data, per-stage surrogacy table, R^2_trial
scatter per stage, STE curve, PTE chain diagram, NIT ranking, hypothesis cards.

Run (optional — not required for grading):
    streamlit run app.py

Parses cleanly (checked with ast.parse) and reuses the analysis engine in
main.py so the two interfaces never diverge.

⚠️  연구용·참고용 (research/reference use only) — NOT for clinical
    decision-making. Demo data are illustrative / synthetic.
"""

import io
import os

import matplotlib
matplotlib.use("Agg")  # headless backend; safe even without a display
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import streamlit as st

import main as engine  # reuse the exact same analysis functions


# Markdown blockquote rendered at both the top and the bottom of the page.
DISCLAIMER_MD = (
    "> ⚠️ **연구용·참고용 (research / reference use only) — NOT for "
    "clinical decision-making.** Demo effect sizes are ILLUSTRATIVE / "
    "SYNTHETIC, not official trial readouts."
)


def _load(uploaded, alpha):
    """Load the surrogacy table from the sidebar upload or the bundled demo.

    Parameters
    ----------
    uploaded : object returned by ``st.file_uploader`` or None
        When None, ``engine.DEFAULT_DATA`` is loaded instead.
    alpha : float
        Unused here; kept so existing callers' signature stays valid.

    Returns
    -------
    tuple
        ``(df, src)``: whatever ``engine.load_data`` returns, plus a short
        human-readable description of where the data came from.
    """
    if uploaded is None:
        df = engine.load_data(engine.DEFAULT_DATA)
        return df, "bundled demo (%s)" % os.path.basename(engine.DEFAULT_DATA)
    # "utf-8-sig" drops the byte-order mark that Excel-style "CSV UTF-8"
    # exports prepend; plain "utf-8" would leave "\ufeff" glued to the first
    # header name, so that column could no longer be found by its name.
    # Files without a BOM decode identically under both codecs.
    raw = uploaded.getvalue().decode("utf-8-sig")
    return engine.load_data(io.StringIO(raw)), "uploaded file"


def _stage_scatter(df, stage_name, up, down, alpha):
    """Scatter per-trial upstream vs downstream benefit for one chain stage.

    Runs ``engine.analyze_stage`` for the stage and, when a fit is available,
    overlays the fitted line and its prediction band.

    Parameters
    ----------
    df : table returned by ``engine.load_data``.
    stage_name, up, down : one entry of ``engine.STAGES``.
    alpha : float
        Passed to the engine for the fit and the prediction band.

    Returns
    -------
    tuple
        ``(fig, r)``: the matplotlib Figure (the caller should close it after
        rendering) and the ``engine.analyze_stage`` result dict.
    """
    r = engine.analyze_stage(df, stage_name, up, down, alpha=alpha)
    fig, ax = plt.subplots(figsize=(5, 4))
    fit = r["fit"]
    if fit is not None:
        x, y = fit["x"], fit["y"]
        ax.scatter(x, y, s=40, alpha=0.8, label="trials")
        xs = np.linspace(float(np.min(x)), float(np.max(x)), 100)
        mean, lo, hi = engine.predict_with_band(fit, xs, alpha=alpha)
        ax.plot(xs, mean, color="C1", label="WLS fit")
        # The band is computed at the sidebar alpha, so the label follows it
        # (assumes a (1 - alpha) band, matching the old alpha=0.05 -> "95%").
        band_label = "%g%% pred. band" % round(100 * (1 - alpha), 1)
        ax.fill_between(xs, lo, hi, color="C1", alpha=0.15, label=band_label)
        ax.axhline(0, color="grey", lw=0.8, ls="--")
        ax.set_title("%s  (R²=%.3f, %s)" % (stage_name, r["r2"], r["grade"]))
        # Only this branch has labelled artists; calling legend() without any
        # makes matplotlib emit a "No artists with labels" warning.
        ax.legend(fontsize=7, loc="best")
    else:
        ax.text(0.5, 0.5, "insufficient/sparse data\n(n=%d)" % r["n"],
                ha="center", va="center")
        ax.set_title(stage_name)
    ax.set_xlabel("upstream treatment benefit")
    ax.set_ylabel("downstream treatment benefit")
    fig.tight_layout()
    return fig, r


def _ste_curve(df, stage_name, up, down, alpha):
    """Plot the predicted downstream benefit curve and STE for one stage.

    Parameters
    ----------
    df : table returned by ``engine.load_data``.
    stage_name, up, down : one entry of ``engine.STAGES``.
    alpha : float
        Passed to the engine for the fit and the prediction band.

    Returns
    -------
    matplotlib Figure; the caller should close it after rendering.
    """
    r = engine.analyze_stage(df, stage_name, up, down, alpha=alpha)
    fig, ax = plt.subplots(figsize=(5, 4))
    title = "STE curve — %s" % stage_name
    fit = r["fit"]
    if fit is not None:
        x = fit["x"]
        x_min, x_max = float(np.min(x)), float(np.max(x))
        # Guard against a zero span when all trials share one x value.
        span = max(x_max - x_min, 1e-6)
        # Extend the grid past the observed trials (mostly to the right),
        # presumably so an STE crossing beyond the data is still visible.
        grid = np.linspace(x_min - 0.25 * span, x_max + 1.5 * span, 300)
        mean, lo, hi = engine.predict_with_band(fit, grid, alpha=alpha)
        ax.plot(grid, mean, color="C0", label="predicted downstream")
        # Label follows the sidebar alpha (assumes a (1 - alpha) band).
        band_label = "%g%% pred. band" % round(100 * (1 - alpha), 1)
        ax.fill_between(grid, lo, hi, color="C0", alpha=0.15, label=band_label)
        ax.axhline(0, color="grey", lw=0.8, ls="--")
        if r["ste_achievable"]:
            ax.axvline(r["ste"], color="C3", ls=":",
                       label="STE = %.3f" % r["ste"])
        ax.set_title(title)
        ax.set_xlabel("upstream treatment benefit")
        ax.set_ylabel("downstream treatment benefit")
        ax.legend(fontsize=7, loc="best")
    else:
        ax.text(0.5, 0.5, "STE not estimable\n(n=%d)" % r["n"],
                ha="center", va="center")
        # Keep the panel identifiable even without a fit.
        ax.set_title(title)
    fig.tight_layout()
    return fig


def _pte_diagram(pte):
    """Draw the NIT -> Histology -> Hard outcome mediation chain.

    Parameters
    ----------
    pte : dict
        Result of ``engine.proportion_treatment_effect``. Only the
        ``"a_path"`` and ``"pte_clamped"`` entries are read; missing, None,
        or non-finite values are shown as "n/a".

    Returns
    -------
    matplotlib Figure; the caller should close it after rendering.
    """
    def _fmt(value):
        # "%.2f" % nan would print the bare string "nan".
        if value is None or not np.isfinite(value):
            return "n/a"
        return "%.2f" % value

    fig, ax = plt.subplots(figsize=(6, 2.4))
    ax.axis("off")
    boxes = {"NIT": (0.08, 0.5), "Histology": (0.5, 0.5),
             "Hard outcome": (0.9, 0.5)}
    for label, (xx, yy) in boxes.items():
        ax.text(xx, yy, label, ha="center", va="center",
                bbox=dict(boxstyle="round,pad=0.4", fc="#e8f0fe", ec="C0"))
    # Mediated path: NIT -> Histology -> Hard outcome.
    ax.annotate("", xy=(0.40, 0.5), xytext=(0.16, 0.5),
                arrowprops=dict(arrowstyle="->", color="C0"))
    ax.annotate("", xy=(0.80, 0.5), xytext=(0.60, 0.5),
                arrowprops=dict(arrowstyle="->", color="C0"))
    # Direct path: NIT -> Hard outcome (dashed, drawn underneath).
    ax.annotate("", xy=(0.80, 0.30), xytext=(0.16, 0.30),
                arrowprops=dict(arrowstyle="->", color="C3", ls="--"))
    ax.text(0.28, 0.58, "a-path=%s" % _fmt(pte.get("a_path")),
            ha="center", fontsize=8)
    ax.text(0.5, 0.22, "direct (NIT→hard) — PTE mediated ≈ %s"
            % _fmt(pte.get("pte_clamped")),
            ha="center", fontsize=8, color="C3")
    fig.tight_layout()
    return fig


def _round3(value):
    """Round *value* to 3 decimals for display; None if it is None or NaN."""
    if value is None or np.isnan(value):
        return None
    return round(value, 3)


def _show_figure(fig):
    """Render a matplotlib figure in the current Streamlit container, then close it."""
    st.pyplot(fig)
    # Figures created with plt.subplots stay registered with pyplot until
    # closed. Streamlit re-executes this script on every widget change, so
    # without this every rerun would leak all of the page's figures.
    plt.close(fig)


def _render_chain_tab(df, alpha):
    """Tab 1: per-stage surrogacy table plus one scatter plot per stage."""
    rows = []
    for stage_name, up, down in engine.STAGES:
        r = engine.analyze_stage(df, stage_name, up, down, alpha=alpha)
        rows.append({
            "stage": r["stage"], "n": r["n"],
            "R²_trial": _round3(r["r2"]),
            "R² CI low": _round3(r["r2_ci"][0]),
            "R² CI high": _round3(r["r2_ci"][1]),
            "slope": _round3(r["slope"]),
            "grade": r["grade"],
            "paradox rows": len(r["paradox"]),
        })
    st.subheader("Stage-by-stage surrogacy")
    st.dataframe(pd.DataFrame(rows), use_container_width=True)
    # zip() with three columns plots at most the first three stages.
    for col, (stage_name, up, down) in zip(st.columns(3), engine.STAGES):
        with col:
            fig, _ = _stage_scatter(df, stage_name, up, down, alpha)
            _show_figure(fig)


def _render_ste_tab(df, alpha):
    """Tab 2: one surrogate-threshold-effect curve per stage."""
    st.subheader("Surrogate threshold effect (STE) curves")
    for col, (stage_name, up, down) in zip(st.columns(3), engine.STAGES):
        with col:
            _show_figure(_ste_curve(df, stage_name, up, down, alpha))


def _render_pte_tab(df):
    """Tab 3: PTE mediation diagram and its numeric summary."""
    st.subheader("PTE mediation: does histology mediate NIT→hard?")
    pte = engine.proportion_treatment_effect(df)
    _show_figure(_pte_diagram(pte))
    st.json({
        "trials_full_chain": pte["n_full_trials"],
        "PTE_raw": _round3(pte["pte_raw"]),
        "PTE_clamped": _round3(pte["pte_clamped"]),
        "b_total": _round3(pte["b_total"]),
        "b_direct": _round3(pte["b_direct"]),
        "plausible": pte["plausible"],
    })
    if pte["note"]:
        st.warning(pte["note"])


def _render_ranking_tab(df, alpha):
    """Tab 4: NIT->histology surrogacy per NIT metric, best R² first."""
    st.subheader("Per-NIT surrogacy ranking (NIT→histology)")
    rrows = []
    for nit in engine.NIT_METRICS:
        r = engine.analyze_stage(df, "NIT->histology", "nit", "histo",
                                 nit_filter=nit, alpha=alpha)
        rrows.append({
            "NIT": nit, "n": r["n"],
            "R²_trial": _round3(r["r2"]),
            "grade": r["grade"],
        })
    if not rrows:
        # An empty frame has no "R²_trial" column, so sort_values would raise.
        st.info("No NIT metrics configured.")
        return
    rdf = pd.DataFrame(rrows).sort_values(
        "R²_trial", ascending=False, na_position="last")
    st.dataframe(rdf, use_container_width=True)


def _render_hypotheses_tab(df, alpha, top):
    """Tab 5: mined gap table plus up to *top* hypothesis cards."""
    st.subheader("Mined unvalidated stages")
    gaps = engine.mine_gaps(df, alpha=alpha)
    # NOTE(review): the table shows g["scope"] under "stage" and g["by"]
    # under "scope", while the card filter below reads g["stage"]; confirm
    # this column mapping against engine.mine_gaps.
    st.dataframe(pd.DataFrame([
        {"stage": g["scope"], "scope": g["by"], "n": g["n"],
         "reasons": "; ".join(g["reasons"])}
        for g in gaps]), use_container_width=True)

    st.subheader("Auto-generated validation hypotheses")
    # The power figures depend only on engine defaults, so compute them once.
    ev = engine.required_events(engine.DEFAULT_TARGET_HR)
    n_total = engine.required_sample_size(engine.DEFAULT_TARGET_HR)
    candidates = [g for g in gaps if g["stage"] == "histology->hard"]
    if not candidates:
        st.info("No histology->hard gaps were mined; no hypotheses to show.")
        return
    for g in candidates[:int(top)]:
        with st.container(border=True):
            st.markdown("**%s**" % g["by"])
            st.write("Hypothesis: Is histology a valid trial-level "
                     "surrogate for hard hepatic outcomes here?")
            st.write("Required ≈ **%d events**, **≈%d participants**, "
                     "**~%.1f yr** follow-up (HR=%.2f, Schoenfeld)."
                     % (ev, n_total, engine.DEFAULT_FOLLOWUP_YEARS,
                        engine.DEFAULT_TARGET_HR))
            st.caption("rationale: " + "; ".join(g["reasons"]))


def render():
    """Build the Streamlit page: sidebar settings, data load, and five tabs."""
    st.set_page_config(page_title="NITSurrogate-Kor", layout="wide")
    st.title("NITSurrogate-Kor — MASLD/MASH trial-level surrogacy")
    st.caption("Domain: MASLD (대사성간질환)  |  Category: 연구 아이디어 생성 "
               "(research-hypothesis generation)")
    st.markdown(DISCLAIMER_MD)

    with st.sidebar:
        st.header("Data & settings")
        uploaded = st.file_uploader("Upload surrogacy CSV (optional)",
                                    type=["csv"])
        alpha = st.slider("alpha (CIs / bands)", 0.01, 0.20, 0.05, 0.01)
        # These overwrite module-level thresholds on main.py; presumably the
        # engine's grading reads them at call time — TODO confirm.
        engine.GRADE_STRONG = st.slider("R² 'strong' threshold", 0.5, 0.95,
                                        0.70, 0.05)
        engine.GRADE_MODERATE = st.slider("R² 'moderate' threshold", 0.2, 0.7,
                                          0.50, 0.05)
        # The two slider ranges overlap, so an inverted pair is reachable.
        if engine.GRADE_MODERATE >= engine.GRADE_STRONG:
            st.warning("The 'moderate' threshold should be below the "
                       "'strong' threshold; grades may be inconsistent.")
        top = st.number_input("Top-N for hypotheses", 1, 50, 10)

    try:
        df, src = _load(uploaded, alpha)
    except (ValueError, KeyError) as err:
        # UnicodeDecodeError and pandas parse errors are ValueError
        # subclasses; show them to the user instead of a raw traceback.
        st.error("Could not load data (%s): %s" % (type(err).__name__, err))
        st.stop()
    st.success("Loaded: %s — %d analyzable rows, %d trials, %d drugs."
               % (src, len(df), df["trial"].nunique(), df["drug"].nunique()))

    tab1, tab2, tab3, tab4, tab5 = st.tabs(
        ["Chain (per-stage)", "STE curves", "PTE mediation",
         "NIT ranking", "Gaps & hypotheses"])
    with tab1:
        _render_chain_tab(df, alpha)
    with tab2:
        _render_ste_tab(df, alpha)
    with tab3:
        _render_pte_tab(df)
    with tab4:
        _render_ranking_tab(df, alpha)
    with tab5:
        _render_hypotheses_tab(df, alpha, top)

    st.markdown("---")
    st.markdown(DISCLAIMER_MD)


# Entry point for `streamlit run app.py` (see the module docstring). A plain
# `import app` does not render anything.
if __name__ == "__main__":
    render()
