"""ObesityTriangulate-Kor — Streamlit UI.

사이드바: intervention/outcome 카테고리/도메인 필터.
메인:
- 6-design grid heatmap
- concordance + Lawlor 5-criterion
- weight-loss-mediated 비율 차트
- MVMR forest
- discordance bias ranked
- 8 design 카드
- CSV upload
"""
from __future__ import annotations

import io
import json
import re
import sys
import tempfile
from pathlib import Path

sys.path.insert(0, str(Path(__file__).resolve().parent))

import pandas as pd

try:
    import streamlit as st
except ImportError:
    print("streamlit not installed. install: pip install streamlit")
    sys.exit(1)

from triangulation import ontology, grid, lawlor, mvmr, mediation, bias, designs, report

# User-facing research-use disclaimer (Korean), shown via st.warning at the top
# of the app in main(). In English, roughly: "This is a research/educational tool
# for hypothesis generation and literature-gap analysis. It must not be used
# directly for clinical decision-making or patient care, and all results need
# further validation. Choice of obesity drugs or surgery must be made in
# consultation with a qualified clinician."
DISCLAIMER = (
    "⚠️ 본 도구는 연구 가설 생성·문헌 갭 분석 목적의 연구용·교육용 도구입니다. "
    "임상의사결정·환자 진료에 직접 사용해서는 안 되며, 모든 분석 결과는 추가 검증이 필요합니다. "
    "비만 약물·수술 선택은 반드시 자격 있는 임상의와 상담 후 결정해야 합니다."
)


def _load_effects_from_upload(upload):
    """Parse an uploaded effects CSV into a DataFrame.

    Args:
        upload: A readable file-like object or path accepted by
            ``pandas.read_csv`` (here, the value of ``st.file_uploader``),
            or ``None`` when nothing was uploaded.

    Returns:
        The parsed ``pandas.DataFrame``, or ``None`` if ``upload`` is ``None``.

    Raises:
        Whatever ``pandas.read_csv`` raises for unreadable input, e.g.
        ``pandas.errors.ParserError`` or ``pandas.errors.EmptyDataError``.
    """
    if upload is None:
        return None
    # Rewind first: if this buffer was already read once, read_csv would start
    # at EOF and raise EmptyDataError instead of re-parsing the same data.
    if hasattr(upload, "seek"):
        upload.seek(0)
    return pd.read_csv(upload)


def _select_pair(interventions: list, outcomes: list, key_prefix: str):
    """Render side-by-side Intervention / Outcome selectboxes and return the picks.

    Widget keys are ``f"{key_prefix}_iv"`` and ``f"{key_prefix}_oc"`` so each
    tab keeps its own independent selection.
    """
    left, right = st.columns(2)
    with left:
        iv = st.selectbox("Intervention", interventions, key=f"{key_prefix}_iv")
    with right:
        oc = st.selectbox("Outcome", outcomes, key=f"{key_prefix}_oc")
    return iv, oc


def _render_upload_sidebar() -> None:
    """Sidebar CSV uploader: parse the file, report its size, keep a copy on disk.

    The uploaded data is not fed into the analysis (ontology.load_effects reads
    its own file); the user is told how to activate it manually.
    """
    upload = st.sidebar.file_uploader("CSV 업로드 (effects)", type=["csv"])
    try:
        custom_df = _load_effects_from_upload(upload)
    except (pd.errors.ParserError, pd.errors.EmptyDataError, UnicodeDecodeError) as exc:
        # A malformed upload should not take down the whole app.
        st.sidebar.error(f"CSV 읽기 실패: {exc}")
        return
    if custom_df is None:
        return

    st.sidebar.success(f"업로드: {len(custom_df)} rows")
    # Save a copy next to the app; the data/ directory may not exist yet, and
    # the install location may be read-only, so failures are reported, not raised.
    saved_path = Path(__file__).resolve().parent / "data" / "_uploaded.csv"
    try:
        saved_path.parent.mkdir(parents=True, exist_ok=True)
        custom_df.to_csv(saved_path, index=False)
    except OSError as exc:
        st.sidebar.warning(f"업로드 파일 저장 실패: {exc}")
    # ontology.load_effects reads effects_sample.csv, so tell the user to swap
    # the file explicitly if they want the upload to be used.
    st.sidebar.caption("업로드 사용 시 effects_sample.csv를 백업 후 교체하세요.")


def _render_grid_tab(iv_filter: list, cat_filter: list) -> None:
    """Concordance heatmap (intervention x outcome) plus the filtered grid table."""
    st.subheader("6-Design Grid — Concordance Heatmap")
    gdf = pd.DataFrame(grid.build_grid())
    # An empty grid has no columns, so only filter when there is data.
    if not gdf.empty:
        gdf = gdf[gdf["intervention"].isin(iv_filter)
                  & gdf["outcome_category"].isin(cat_filter)]
    if gdf.empty:
        st.info("필터 조건에 맞는 데이터 없음.")
        return

    pivot = gdf.pivot_table(index="intervention", columns="outcome",
                            values="score", aggfunc="first")
    try:
        import matplotlib  # noqa: F401
    except ImportError:
        # Styler.background_gradient needs matplotlib when the table is rendered;
        # without it, show the uncoloured pivot instead of crashing the tab.
        heatmap = pivot
    else:
        heatmap = pivot.style.background_gradient(cmap="RdYlGn", vmin=0, vmax=1)
    st.dataframe(heatmap, use_container_width=True)
    st.dataframe(gdf, use_container_width=True)


def _render_pair_tab(interventions: list, outcomes: list) -> None:
    """Per-design rows, concordance metrics and Lawlor criteria for one pair."""
    sel_iv, sel_oc = _select_pair(interventions, outcomes, "pair")
    rows = grid.get_pair_rows(sel_iv, sel_oc)
    if not rows:
        st.info(f"({sel_iv}, {sel_oc}) 데이터 없음.")
        return

    rdf = pd.DataFrame(rows)
    st.dataframe(rdf[["design", "effect_estimate", "ci_low", "ci_high",
                      "sample_size", "follow_up_years", "population",
                      "source_citation"]], use_container_width=True)
    conc = grid.concordance_score(rows)
    law = lawlor.lawlor_score(sel_iv, sel_oc)
    colA, colB = st.columns(2)
    with colA:
        st.metric("Concordance score", f"{conc['score']:.3f}",
                  help="0~1; direction agreement, CI overlap, design diversity 가중합")
        st.metric("Direction agreement", f"{conc['direction_agreement']:.2f}")
        st.metric("Majority direction", conc["majority_direction"])
    with colB:
        st.metric("Lawlor 5-criterion total", f"{law['total']}/{law['max']}")
        st.metric("# designs", conc["n_designs"])
        st.metric("Design diversity", f"{conc['design_diversity']:.2f}")
    st.json(law["criteria"])


def _render_mediation_tab(interventions: list, outcomes: list) -> None:
    """Mediation result for one pair, with a mediated-vs-direct bar chart."""
    st.subheader("Weight-loss-mediated vs direct effect")
    iv, oc = _select_pair(interventions, outcomes, "med")
    med = mediation.mediation_for_pair(iv, oc)
    st.json(med)
    pct = med.get("curated_mediated_pct")
    if pct is not None:
        # pct is a percentage; the chart shows the two complementary fractions.
        share = pct / 100
        cdf = pd.DataFrame({
            "component": ["weight-loss mediated", "direct (weight-independent)"],
            "fraction": [share, 1 - share],
        })
        st.bar_chart(cdf.set_index("component"))
    st.caption(med.get("interpretation", ""))


def _render_mvmr_tab(interventions: list, outcomes: list) -> None:
    """Multivariable-MR result for one pair, with a conditional-OR bar chart."""
    st.subheader("Multivariable MR (BMI / WHR / body fat %)")
    iv, oc = _select_pair(interventions, outcomes, "mvmr")
    m = mvmr.mvmr_for_pair(iv, oc)
    st.json(m)
    if "bmi_conditional_OR" not in m:
        return
    # Only the BMI key gates the chart; .get() keeps a result lacking the WHR
    # or BF% key from raising KeyError, and missing values are dropped.
    phenotypes = (
        ("BMI conditional", "bmi_conditional_OR"),
        ("WHR conditional", "whr_conditional_OR"),
        ("BF% conditional", "bf_conditional_OR"),
    )
    fdf = pd.DataFrame({
        "phenotype": [label for label, _ in phenotypes],
        "OR": [m.get(key) for _, key in phenotypes],
    }).dropna()
    if not fdf.empty:
        st.bar_chart(fdf.set_index("phenotype"))


def _render_discordance_tab() -> None:
    """Top-N discordant pairs, each with an expandable bias diagnosis."""
    st.subheader("Discordant pairs (direction agreement < 0.7)")
    top_n = st.slider("Top N", 1, 30, 10)
    d = grid.find_discordant(top_n=top_n)
    if not d:
        st.info("discordant pair 없음 (모든 grid 합의)")
        return
    st.dataframe(pd.DataFrame(d), use_container_width=True)
    for item in d:
        with st.expander(f"{item['intervention']} × {item['outcome']}"):
            diag = bias.diagnose_discordance(item["intervention"], item["outcome"])
            st.json(diag)


def _render_designs_tab(interventions: list, outcomes: list) -> None:
    """Recommended triangulation study-design cards for one pair."""
    st.subheader("8 Triangulation-targeted design cards")
    iv, oc = _select_pair(interventions, outcomes, "des")
    cards = designs.recommend_designs(iv, oc)
    for i, c in enumerate(cards, 1):
        with st.expander(f"{i}. {c['type']}"):
            st.write(f"**Rationale**: {c['rationale']}")
            st.write(f"**Primary endpoint**: {c.get('primary_endpoint')}")
            st.write(f"**Sample size**: {c.get('sample_size_estimate')}")
            st.write(f"**Follow-up**: {c.get('follow_up_years')}년")
            st.write(f"**Hypothesis**: {c.get('key_hypothesis')}")
            st.write(f"**Key threats**: {c.get('key_threats')}")
            instruments = c.get("multivariable_MR_instruments")
            if instruments and instruments != "N/A":
                st.write(f"**MR instruments**: {instruments}")


def _offer_docx_report(r, stem: str) -> None:
    """Write report ``r`` with ``report.to_docx`` and offer the file for download.

    The return value of ``report.to_docx`` is treated as the path of the file it
    wrote; a path not ending in ``.docx`` is treated as the markdown fallback.
    """
    # Names can contain characters that are illegal in file names (e.g. '/'),
    # and "/tmp" does not exist on every platform, so sanitize the stem and use
    # the platform temp directory.
    safe_stem = re.sub(r"[^\w.-]+", "_", stem)
    out_path = Path(tempfile.gettempdir()) / f"{safe_stem}.docx"
    saved = report.to_docx(r, str(out_path))
    if not saved:
        st.error("리포트 파일 생성 실패 (report.to_docx가 경로를 반환하지 않음)")
        return

    saved_path = Path(saved)
    if str(saved).endswith(".docx"):
        st.download_button(
            "Download DOCX", saved_path.read_bytes(),
            file_name=out_path.name,
            mime="application/vnd.openxmlformats-officedocument.wordprocessingml.document",
        )
    else:
        st.warning("python-docx 미설치, markdown으로 fallback")
        st.download_button("Download MD (fallback)",
                           saved_path.read_text(encoding="utf-8"),
                           file_name=saved_path.name)


def _render_report_tab(interventions: list, outcomes: list) -> None:
    """Build an integrated report for one pair and offer JSON / Markdown / DOCX."""
    st.subheader("통합 리포트 export (HTA · 규제 · KASMBS · OpenClaw)")
    iv, oc = _select_pair(interventions, outcomes, "rep")
    fmt = st.radio("Format", ["json", "markdown", "docx (python-docx 필요)"], horizontal=True)
    if not st.button("리포트 생성"):
        return

    r = report.build_report_dict(iv, oc)
    stem = f"report_{iv}_{oc}"
    if fmt == "json":
        payload = json.dumps(r, ensure_ascii=False, indent=2)
        st.download_button("Download JSON", payload,
                           file_name=f"{stem}.json",
                           mime="application/json")
        st.json(r)
    elif fmt.startswith("markdown"):
        md = report.to_markdown(r)
        st.download_button("Download MD", md,
                           file_name=f"{stem}.md",
                           mime="text/markdown")
        st.markdown(md)
    else:
        _offer_docx_report(r, stem)


def main():
    """Streamlit entry point: page header, sidebar filters, and the seven tabs.

    The sidebar category / intervention filters apply to the Grid tab; the
    other tabs offer every intervention and outcome from the effects data.
    """
    st.set_page_config(page_title="ObesityTriangulate-Kor", layout="wide")
    st.title("ObesityTriangulate-Kor")
    st.caption("항비만 intervention 인과추론 triangulation 도구 (오베시티트라이앵귤레이트코어)")
    st.warning(DISCLAIMER)

    st.sidebar.header("필터")
    _render_upload_sidebar()

    effects_df = pd.DataFrame(ontology.load_effects())
    interventions = sorted(effects_df["intervention"].unique().tolist())
    outcomes = sorted(effects_df["outcome"].unique().tolist())
    categories = sorted(effects_df["outcome_category"].unique().tolist())

    cat_filter = st.sidebar.multiselect("Outcome 카테고리", categories, default=categories)
    iv_filter = st.sidebar.multiselect("Intervention", interventions, default=interventions)

    tab_grid, tab_pair, tab_med, tab_mvmr, tab_disc, tab_design, tab_report = st.tabs([
        "6-Design Grid", "Pair Detail", "Mediation",
        "MVMR", "Discordance", "Designs", "Report"
    ])

    with tab_grid:
        _render_grid_tab(iv_filter, cat_filter)
    with tab_pair:
        _render_pair_tab(interventions, outcomes)
    with tab_med:
        _render_mediation_tab(interventions, outcomes)
    with tab_mvmr:
        _render_mvmr_tab(interventions, outcomes)
    with tab_disc:
        _render_discordance_tab()
    with tab_design:
        _render_designs_tab(interventions, outcomes)
    with tab_report:
        _render_report_tab(interventions, outcomes)


# Entry point when this file is executed as the main script.
if __name__ == "__main__":
    main()
