"""MASLDTriangulate-Kor — Streamlit standalone UI.

Run:
    streamlit run app.py
"""
from __future__ import annotations

import io
import json
import os

try:
    import streamlit as st
except ImportError:
    raise SystemExit("Streamlit 가 설치되지 않았습니다. `pip install -r requirements.txt`")

from triangulation import DISCLAIMER
from triangulation.bias import diagnose_discordance, BIAS_TAXONOMY
from triangulation.designs import proposed_designs
from triangulation.genotype import subgroup_treatment_effect
from triangulation.grid import build_grid, design_summary, concordance_score, grid_as_table
from triangulation.lawlor import score_pair, korean_ancestry_layer_score
from triangulation.mediation import masld_mediated_fraction
from triangulation.mvmr import mvmr_decompose
from triangulation.ontology import (
    load_effects, load_outcomes, load_stages, load_instruments,
    load_korean_af, filter_effects, DESIGN_ORDER, STAGE_ORDER,
)
from triangulation.report import build_full_report, report_to_markdown


# Page-level configuration (browser tab title + wide layout); issued before any
# element below is rendered.
st.set_page_config(
    layout="wide",
    page_title="MASLDTriangulate-Kor",
)


@st.cache_data
def _load(default=True, custom_bytes=None):
    """Return effect rows from an uploaded CSV, or from the bundled default set.

    Args:
        default: Not read by this function; retained so existing keyword
            callers keep working.
        custom_bytes: Raw bytes of a user-uploaded CSV. If truthy, they are
            written to a temporary ``.csv`` file and parsed with
            ``load_effects(path)``; otherwise ``load_effects()`` is called with
            no arguments.

    Returns:
        Whatever ``load_effects`` returns for the chosen source.
    """
    if not custom_bytes:
        return load_effects()

    import contextlib
    import tempfile

    # load_effects() takes a file path, so spill the upload to disk. mkstemp
    # leaves deletion to the caller; the previous delete=False temp file was
    # never removed, leaking one file per upload/cache miss.
    fd, csv_path = tempfile.mkstemp(suffix=".csv")
    try:
        # Close the handle before parsing so the file can be reopened by name
        # (required on Windows).
        with os.fdopen(fd, "wb") as fh:
            fh.write(custom_bytes)
        return load_effects(csv_path)
    finally:
        # Best-effort cleanup: a failed delete must not mask the loaded data.
        with contextlib.suppress(OSError):
            os.remove(csv_path)


# Page header: app title, one-line description, and the disclaimer, which is
# shown above every tab.
_APP_TITLE = "MASLDTriangulate-Kor"
st.title(_APP_TITLE)
st.caption("MASLD 단계 × multi-outcome 5-design 인과추론 triangulation")
st.warning(DISCLAIMER)

# ─── Sidebar ──────────────────────────────────────────────────────────────────
# Sentinel option meaning "do not filter on this field".
_ALL_LABEL = "(전체)"


def _distinct_values(rows, field):
    """Sorted set of the truthy values of ``field`` across ``rows``."""
    return sorted({row.get(field) for row in rows if row.get(field)})


def _unless_all(choice):
    """Map the "(전체)" sentinel to None (no filter); pass other choices through."""
    return None if choice == _ALL_LABEL else choice


with st.sidebar:
    st.header("데이터")
    uploaded_csv = st.file_uploader("effects_sample.csv 업로드 (선택)", type=["csv"])
    if uploaded_csv:
        upload_payload = uploaded_csv.getvalue()
    else:
        upload_payload = None
    effects = _load(custom_bytes=upload_payload)
    st.caption(f"effects rows: {len(effects)}")

    st.header("필터")
    # These option lists are also reused by the per-tab selectors further down.
    stages_unique = _distinct_values(effects, "masld_stage")
    outcomes_unique = _distinct_values(effects, "outcome")
    instruments_unique = _distinct_values(effects, "mr_instrument")

    sel_stage = st.selectbox("MASLD stage", [_ALL_LABEL, *stages_unique])
    sel_outcome = st.selectbox("Outcome", [_ALL_LABEL, *outcomes_unique])
    sel_design = st.selectbox("Design", [_ALL_LABEL, *DESIGN_ORDER])
    sel_ancestry = st.selectbox(
        "Ancestry", [_ALL_LABEL, "Korean", "European", "global", "Asian"]
    )
    sel_instr = st.selectbox("MR instrument", [_ALL_LABEL, *instruments_unique])

filtered = filter_effects(
    effects,
    stage=_unless_all(sel_stage),
    outcome=_unless_all(sel_outcome),
    design=_unless_all(sel_design),
    ancestry=_unless_all(sel_ancestry),
    instrument=_unless_all(sel_instr),
)

# Tab labels, in display order; unpacked into one handle per tab below.
_TAB_LABELS = [
    "Grid", "Pair", "Mediation", "Genotype",
    "Korean layer", "Design cards", "Bias taxonomy", "Report",
]
(
    tab_grid, tab_pair, tab_med, tab_geno,
    tab_korean, tab_designs, tab_taxonomy, tab_report,
) = st.tabs(_TAB_LABELS)

# ─── Tab: Grid ────────────────────────────────────────────────────────────────
with tab_grid:
    # When the sidebar filters match nothing, the grid falls back to the full
    # dataset so it is never blank. Say so explicitly; otherwise the user sees
    # unfiltered results while believing the filters are applied.
    if effects and not filtered:
        st.info("선택한 필터와 일치하는 행이 없어 전체 데이터로 grid 를 표시합니다.")
    grid = build_grid(filtered if filtered else effects)
    st.subheader(f"5-design grid ({len(grid)} pairs)")

    # grid_as_table is used here as [header, *data_rows]; transpose it into the
    # {column: values} mapping passed to st.dataframe. Guard against an empty
    # table, which would otherwise raise IndexError on the header lookup.
    rows = grid_as_table(grid)
    if rows:
        header, *body = rows
        st.dataframe(
            {col: [row[i] for row in body] for i, col in enumerate(header)},
            use_container_width=True,
        )

    labels = [g["concordance"]["label"] for g in grid]
    c1, c2, c3 = st.columns(3)
    c1.metric("HIGH concordance", labels.count("HIGH_CONCORDANCE"))
    c2.metric("DISCORDANT", labels.count("DISCORDANT"))
    c3.metric("Pairs analyzed", len(grid))

# ─── Tab: Pair ────────────────────────────────────────────────────────────────
with tab_pair:
    # Single (stage, outcome) pair: per-design summary, concordance label, and
    # the Lawlor 5-criterion score, all computed from the unfiltered effects.
    stage_col, outcome_col = st.columns(2)
    pair_stage = stage_col.selectbox("Stage", stages_unique, key="pair_s")
    pair_outcome = outcome_col.selectbox("Outcome", outcomes_unique, key="pair_o")

    summary = design_summary(effects, pair_stage, pair_outcome)
    conc = concordance_score(summary)
    law = score_pair(summary)

    st.markdown(f"### Concordance: **{conc['label']}** (score {conc['score']})")
    st.markdown(f"Lawlor 5-criterion: **{law['total']}/5**")
    st.json(law["criteria"])

    st.markdown("#### Per-design")
    for design_name, stats in summary.items():
        line = (
            f"**{design_name}** — n={stats['n']}, effect={stats['effect']}, "
            f"CI=({stats['ci_low']}, {stats['ci_high']}), dir={stats['direction']}"
        )
        st.write(line)

# ─── Tab: Mediation ──────────────────────────────────────────────────────────
with tab_med:
    # Mediated-fraction metrics plus the raw MVMR decomposition for the
    # selected (stage, outcome) pair.
    stage_col, outcome_col = st.columns(2)
    ms = stage_col.selectbox("Stage", stages_unique, key="med_s")
    mo = outcome_col.selectbox("Outcome", outcomes_unique, key="med_o")

    med = masld_mediated_fraction(effects, ms, mo)
    mv = mvmr_decompose(effects, ms, mo)

    if not med.get("mediation_available"):
        st.warning(med.get("reason", "Mediation 미산정"))
    else:
        masld_col, metab_col, resid_col = st.columns(3)
        masld_col.metric("MASLD-mediated", f"{med['frac_masld_mediated'] * 100:.0f}%")
        metab_col.metric("Metabolic-mediated", f"{med['frac_metabolic_mediated'] * 100:.0f}%")
        resid_col.metric("Residual (unmeasured)", med["residual_effect_unmeasured"])
        st.info(med.get("interpretation", ""))

    st.markdown("### MVMR decomposition")
    st.json(mv)

# ─── Tab: Genotype ──────────────────────────────────────────────────────────
with tab_geno:
    # Genotype-stratified treatment-effect hypothesis for one (drug, gene, outcome).
    gene = st.selectbox("Gene", ["PNPLA3", "HSD17B13", "TM6SF2", "MBOAT7", "GCKR"])
    drug = st.selectbox("Drug", ["resmetirom", "semaglutide_2.4mg", "tirzepatide_15mg"])
    geno_outcome = st.selectbox("Outcome (hypothesis target)", outcomes_unique, key="geno_o")
    res = subgroup_treatment_effect(drug, gene, geno_outcome)

    if not res.get("available"):
        # "reason" may be missing or empty; fall back to a message instead of
        # passing None to st.warning (mirrors the Mediation tab's fallback).
        st.warning(res.get("reason") or "Subgroup 효과 미산정")
    else:
        st.write(f"**Rationale:** {res['rationale']}")
        st.json(res["subgroup_modifiers"])
        st.write(f"한국 allele frequency: **{res['korean_allele_frequency']}**")
        st.write("한국 genotype 빈도 (Hardy-Weinberg):")
        st.json(res["korean_genotype_frequencies"])
        if res.get("discordance_flag"):
            st.warning(res["discordance_flag"])
        st.caption(res["study_design_implication"])

# ─── Tab: Korean layer ──────────────────────────────────────────────────────
with tab_korean:
    # Korean-ancestry coverage summary, the Korean rows themselves, and the
    # reference allele-frequency table.
    layer = korean_ancestry_layer_score(effects)
    metric_specs = (
        ("Korean rows", "n_korean_rows"),
        ("Asian rows (포함)", "n_asian_rows"),
        ("External validity", "external_validity_for_korean"),
    )
    for column, (label, field) in zip(st.columns(3), metric_specs):
        column.metric(label, layer[field])
    st.write("Korean designs covered:", layer["korean_designs_covered"])

    # Ancestry matching is case-insensitive; rows with a missing ancestry are skipped.
    korean_rows = [
        row for row in effects
        if (row.get("ancestry") or "").lower() == "korean"
    ]
    if korean_rows:
        st.markdown("### Korean ancestry rows")
        for row in korean_rows:
            st.write(
                f"- **{row.get('masld_stage')}** × **{row.get('outcome')}** "
                f"[{row.get('design')}] effect={row.get('effect_estimate')} "
                f"— {row.get('source_citation')}"
            )

    st.markdown("### 한국 ancestry allele frequency")
    af = load_korean_af()
    af_columns = [
        "gene", "rsid", "risk_allele", "af_global",
        "af_east_asian", "af_korean", "note",
    ]
    af_table = {name: [record.get(name) for record in af] for name in af_columns}
    st.dataframe(af_table, use_container_width=True)

# ─── Tab: Designs ──────────────────────────────────────────────────────────
with tab_designs:
    # One expander per proposed follow-up study design card.
    proposals = proposed_designs(effects)
    st.subheader(f"{proposals['n_cards']} triangulation-targeted 후속 design")
    for card in proposals["cards"]:
        with st.expander(f"[{card['tier']}] {card['type']}"):
            st.write(f"**Rationale:** {card['rationale']}")
            st.write(f"**Primary endpoint:** {card['primary_endpoint']}")
            st.write(f"**Sample size:** {card['sample_size']}  /  Follow-up: {card['follow_up']}")
            st.write(f"**Hypothesis:** {card['key_hypothesis']}")
            # Optional fields: only rendered when present and truthy.
            required_ivs = card.get("mr_instruments_required")
            if required_ivs:
                st.write(f"**IVs:** {card['mr_instruments_required']}")
            if card.get("genotype_stratification"):
                st.info("genotype-stratified")

# ─── Tab: Bias taxonomy ─────────────────────────────────────────────────────
with tab_taxonomy:
    # Static reference list: one expander per bias entry with its remedy.
    st.subheader("Bias taxonomy (MASLD-specific)")
    for entry in BIAS_TAXONOMY:
        title = f"{entry['code']} — {entry['name']}"
        with st.expander(title):
            st.write(entry["description"])
            st.caption(f"Remedy: {entry['remedy']}")

# ─── Tab: Report ───────────────────────────────────────────────────────────
with tab_report:
    # Full report over the unfiltered effects, offered as Markdown and JSON
    # downloads and rendered inline.
    rep = build_full_report(effects)
    md = report_to_markdown(rep)

    # Declare MIME types explicitly; without them both payloads are served as
    # generic text regardless of the .md / .json file extension.
    st.download_button(
        "Download Markdown report",
        md,
        file_name="masld_triangulate_report.md",
        mime="text/markdown",
    )
    # default=str stringifies any value json cannot encode natively, so the
    # JSON export never fails on report objects (they become their str()).
    report_json = json.dumps(rep, ensure_ascii=False, indent=2, default=str)
    st.download_button(
        "Download JSON report",
        report_json,
        file_name="masld_triangulate_report.json",
        mime="application/json",
    )
    st.markdown(md)
