"""DMCausalTriangulate-Kor — Streamlit UI.

Run:
    python3 -m streamlit run app.py
or:
    streamlit run app.py

연구·교육용 도구. 임상의사결정에 직접 사용 금지.
"""
from __future__ import annotations

import os
import sys
from typing import Any

# Ensure the local package is importable when run via `streamlit run`:
# prepend this file's own directory to sys.path so the `triangulation`
# package imported below resolves from the script location rather than the
# current working directory. HERE is also reused later to locate `data/`.
HERE = os.path.dirname(os.path.abspath(__file__))
if HERE not in sys.path:
    sys.path.insert(0, HERE)

import streamlit as st  # noqa: E402

from triangulation.ontology import (  # noqa: E402
    load_drugs,
    load_outcomes,
    load_effects,
    list_pairs,
    filter_pair,
)
from triangulation.grid import (  # noqa: E402
    build_grid,
    concordance_score,
    grid_summary,
    DESIGNS,
)
from triangulation.lawlor import score_lawlor_criteria  # noqa: E402
from triangulation.bias import diagnose_discordance  # noqa: E402
from triangulation.designs import recommend_designs  # noqa: E402
from triangulation.report import build_markdown_report  # noqa: E402

# Research/education-only warning, wrapped in Markdown bold, shown on every page.
DISCLAIMER = "".join(
    (
        "**",
        "⚠️ 본 도구는 연구 가설 생성·문헌 갭 분석 목적의 연구용·교육용 도구입니다. ",
        "임상의사결정·환자 진료에 직접 사용해서는 안 되며, 모든 분석 결과는 추가 검증이 필요합니다.",
        "**",
    )
)


def _color_for_concordance(c: float) -> str:
    """Map a concordance score to a traffic-light hex colour.

    ``c >= 0.70`` gives green, ``0.50 <= c < 0.70`` gives amber, and anything
    else (including NaN) gives red.
    """
    thresholds = (
        (0.70, "#1b7d3b"),  # green
        (0.50, "#caa033"),  # amber
    )
    for floor, colour in thresholds:
        if c >= floor:
            return colour
    return "#b3331c"  # red


def _try_plotly_heatmap(rows: list[dict[str, Any]]):
    """Build a drug-class × outcome concordance heatmap, or return None.

    Args:
        rows: summary records carrying at least ``drug_class``, ``outcome``
            and ``concordance`` keys.

    Returns:
        A plotly ``Figure`` whose cells are the mean concordance per
        (drug_class, outcome) on a fixed 0–1 colour scale, or ``None`` when
        plotly/pandas is not installed, *rows* is empty, or no row has a
        numeric concordance. Callers fall back to a plain table on ``None``.
    """
    try:
        import pandas as pd
        import plotly.graph_objects as go
    except ImportError:
        return None
    df = pd.DataFrame(rows)
    if df.empty:
        return None
    # Coerce first: None/blank concordance would otherwise leave an
    # object-dtype column that the "mean" aggregation may refuse to reduce.
    df["concordance"] = pd.to_numeric(df["concordance"], errors="coerce")
    df = df.dropna(subset=["concordance"])
    if df.empty:
        # Nothing plottable; an all-blank heatmap would only mislead.
        return None
    pivot = df.pivot_table(
        index="drug_class", columns="outcome", values="concordance", aggfunc="mean"
    )
    fig = go.Figure(
        data=go.Heatmap(
            z=pivot.values,
            x=list(pivot.columns),
            y=list(pivot.index),
            colorscale="RdYlGn",
            zmin=0,
            zmax=1,
            colorbar=dict(title="concordance"),
        )
    )
    fig.update_layout(
        title="Concordance heatmap — drug class × outcome",
        xaxis_title="outcome",
        yaxis_title="drug class",
        height=420,
    )
    return fig


def _try_lawlor_radar(lawlor: dict[str, Any]):
    """Build a radar (polar) chart of per-criterion Lawlor scores, or return None.

    Args:
        lawlor: mapping with a ``criteria`` list; each entry needs ``name``
            and ``score`` and may carry a numeric ``max``.

    Returns:
        A plotly ``Figure``, or ``None`` when there are no criteria to draw
        or plotly is not installed.
    """
    criteria = lawlor.get("criteria") or []
    if not criteria:
        # Nothing to plot; this also avoids indexing [0] on an empty list.
        return None
    try:
        import plotly.graph_objects as go
    except ImportError:
        return None
    names = [c["name"] for c in criteria]
    scores = [c["score"] for c in criteria]
    # Keep the 0–2 axis, widening it only if a criterion declares a larger
    # maximum, so that score is not drawn off-chart.
    declared_max = [
        c["max"] for c in criteria if isinstance(c.get("max"), (int, float))
    ]
    axis_max = max([2, *declared_max])
    fig = go.Figure()
    fig.add_trace(
        # Repeat the first point so the polygon closes.
        go.Scatterpolar(r=scores + [scores[0]], theta=names + [names[0]], fill="toself")
    )
    fig.update_layout(
        polar=dict(radialaxis=dict(range=[0, axis_max])),
        showlegend=False,
        title="Lawlor 5-criterion (0–2 each)",
        height=420,
    )
    return fig


# Columns read directly from every effect row by the sidebar filters and the
# pair selection below; an uploaded CSV lacking them cannot be used.
_REQUIRED_EFFECT_COLUMNS = frozenset({"drug_class", "outcome"})
# Uploaded-CSV columns coerced to float (None when blank, unparsable, or NaN).
_FLOAT_EFFECT_COLUMNS = ("effect_estimate", "ci_low", "ci_high", "follow_up_years")
# Maximum number of report characters echoed in the on-page preview.
_REPORT_PREVIEW_CHARS = 4000


def _parse_effects_csv(raw: bytes) -> list[dict[str, Any]]:
    """Parse raw CSV bytes into effect rows with numeric columns coerced.

    Columns in ``_FLOAT_EFFECT_COLUMNS`` become ``float``, and
    ``sample_size`` becomes ``int``. Any value that is missing, blank,
    unparsable, NaN, or (for ``sample_size``) infinite becomes ``None``.

    Raises:
        UnicodeDecodeError: if *raw* is neither UTF-8 (with or without BOM)
            nor CP949.
    """
    import csv
    import io
    import math

    try:
        # "utf-8-sig" strips the BOM that Excel prepends; plain "utf-8" would
        # glue it onto the first header name (e.g. "\ufeffdrug_class").
        text = raw.decode("utf-8-sig")
    except UnicodeDecodeError:
        # Korean-locale Excel saves "CSV" as CP949 by default.
        text = raw.decode("cp949")

    rows = list(csv.DictReader(io.StringIO(text)))
    for row in rows:
        for key in _FLOAT_EFFECT_COLUMNS:
            try:
                value = float(row.get(key, ""))
            except (TypeError, ValueError):
                value = None
            # NaN is a "missing" marker, not a usable estimate.
            row[key] = None if value is None or math.isnan(value) else value
        try:
            row["sample_size"] = int(float(row.get("sample_size", "")))
        except (TypeError, ValueError, OverflowError):
            # OverflowError: int(float("inf")).
            row["sample_size"] = None
    return rows


def _effects_from_upload(uploaded: Any) -> list[dict[str, Any]] | None:
    """Turn the sidebar upload into effect rows, reporting problems in the sidebar.

    Returns None (after showing an error) when the file cannot be decoded or
    lacks a required column, so the caller can fall back to bundled data.
    """
    try:
        rows = _parse_effects_csv(uploaded.getvalue())
    except UnicodeDecodeError:
        st.sidebar.error(
            "업로드 CSV 인코딩을 읽을 수 없습니다 (UTF-8 또는 CP949 필요) — "
            "기본 데이터를 사용합니다."
        )
        return None
    # DictReader gives every row the full header, so row 0 reveals the schema.
    missing = sorted(_REQUIRED_EFFECT_COLUMNS.difference(rows[0])) if rows else []
    if missing:
        st.sidebar.error(
            f"업로드 CSV에 필수 컬럼이 없습니다: {', '.join(missing)} — "
            "기본 데이터를 사용합니다."
        )
        return None
    st.sidebar.success(f"업로드 CSV {len(rows)}행 사용 중")
    return rows


def _analyze_pair(
    filtered_effects: list[dict[str, Any]],
    drugs: Any,
    drug_class: str,
    outcome: str,
) -> tuple:
    """Run the full triangulation pipeline for one (drug_class, outcome) pair.

    The drug target comes from the first ``drugs`` entry whose ``drug_class``
    matches, or "the canonical drug target" if none matches.

    Returns:
        ``(pair_effects, conc, lawlor, biases, designs)`` in that order.
    """
    pair_effects = filter_pair(filtered_effects, drug_class, outcome)
    conc = concordance_score(pair_effects)
    lawlor = score_lawlor_criteria(pair_effects)
    biases = diagnose_discordance(pair_effects)
    target = next(
        (d["target"] for d in drugs if d["drug_class"] == drug_class),
        "the canonical drug target",
    )
    designs = recommend_designs(biases, drug_class, outcome, target=target)
    return pair_effects, conc, lawlor, biases, designs


def _render_overview(summary: list[dict[str, Any]]) -> None:
    """Render the heatmap (when plotly is available) plus the summary table."""
    st.subheader("Concordance heatmap")
    if not summary:
        st.info("필터 조건에 일치하는 쌍이 없습니다.")
        return
    fig = _try_plotly_heatmap(summary)
    if fig is not None:
        st.plotly_chart(fig, use_container_width=True)
    else:
        st.info("plotly 미설치 — 표만 표시합니다.")
    st.dataframe(summary, use_container_width=True)


def _render_pair_analysis(filtered_effects: list[dict[str, Any]], drugs: Any) -> None:
    """Render the single-pair triangulation tab (metrics, grid, Lawlor, biases, designs)."""
    st.subheader("단일 약물-outcome 쌍 triangulation")
    pairs = list_pairs(filtered_effects)
    if not pairs:
        st.info("필터 조건에 쌍이 없습니다.")
        # Returning from this helper, not from main(), keeps other tabs alive.
        return
    labels = [f"{d} × {o}" for d, o in pairs]
    idx = st.selectbox("쌍 선택", range(len(labels)), format_func=lambda i: labels[i])
    drug_class, outcome = pairs[idx]
    pair_effects, conc, lawlor, biases, designs = _analyze_pair(
        filtered_effects, drugs, drug_class, outcome
    )

    c1, c2, c3 = st.columns(3)
    c1.metric("Concordance", conc["concordance"])
    c2.metric("Designs present", f"{conc['designs_present']}/{len(DESIGNS)}")
    c3.metric(
        "Lawlor total",
        f"{lawlor['total']}/{lawlor['max']}",
        help=f"normalized {lawlor['normalized']}",
    )

    color = _color_for_concordance(conc["concordance"])
    st.markdown(
        f"<div style='padding:6px;border-left:6px solid {color};background:#f5f5f5'>"
        f"방향 일치 {conc['direction_agreement']} · CI 중첩 {conc['ci_overlap_fraction']} · "
        f"다수 방향 {conc['majority_direction']}</div>",
        unsafe_allow_html=True,
    )

    st.markdown("#### 5-design effect grid")
    grid = build_grid(pair_effects)
    # Render in canonical design order, with missing designs shown as gaps.
    for design in DESIGNS:
        records = grid.get(design) or []
        with st.expander(f"{design} — {len(records)} record(s)", expanded=False):
            if not records:
                st.write("_(no effects available for this design)_")
            else:
                # .get: user-uploaded CSVs may omit purely descriptive columns.
                st.table(
                    [
                        {
                            "drug": r.get("drug"),
                            "effect": r.get("effect_estimate"),
                            "ci_low": r.get("ci_low"),
                            "ci_high": r.get("ci_high"),
                            "n": r.get("sample_size"),
                            "f/u (y)": r.get("follow_up_years"),
                            "population": r.get("population"),
                            "source": r.get("source_citation"),
                            "url": r.get("source_url"),
                        }
                        for r in records
                    ]
                )

    st.markdown("#### Lawlor 2016 5-criterion")
    radar = _try_lawlor_radar(lawlor)
    if radar is not None:
        st.plotly_chart(radar, use_container_width=True)
    for c in lawlor["criteria"]:
        st.write(f"- **{c['name']}**: {c['score']}/{c['max']} — {c['rationale']}")

    st.markdown("#### Discordance bias direction (ranked)")
    if not biases:
        st.success("Discordance pattern triggered: 없음 (designs concordant).")
    else:
        for b in biases:
            st.write(f"- **{b['bias_type']}** (score {b['score']}): {b['rationale']}")
            st.caption(b.get("definition", ""))

    st.markdown("#### Triangulation-targeted follow-up design 카드")
    for i, card in enumerate(designs, 1):
        with st.expander(
            f"{i}. {card['name']}  (match score {card['match_score']})",
            expanded=(i <= 2),
        ):
            st.write(f"**primary endpoint**: {card['primary_endpoint']}")
            st.write(f"**estimated n**: {card['estimated_sample_size']}")
            st.write(f"**follow-up**: {card['follow_up_years']} years")
            st.write(f"**addresses biases**: {', '.join(card['addresses_biases'])}")
            st.write(f"**key hypothesis**: {card['key_hypothesis']}")
            st.caption(card["rationale"])


def _render_discordance_ranking(
    filtered_effects: list[dict[str, Any]], summary: list[dict[str, Any]]
) -> None:
    """Rank pairs by ascending concordance, then by descending bias evidence."""
    st.subheader("가장 discordant한 쌍")
    enriched = []
    for row in summary:
        biases = diagnose_discordance(
            filter_pair(filtered_effects, row["drug_class"], row["outcome"])
        )
        # New dict per row so the shared summary records are not mutated.
        enriched.append(
            {
                **row,
                "top_bias": biases[0]["bias_type"] if biases else None,
                "top_bias_score": biases[0]["score"] if biases else 0,
                "n_bias_hits": len(biases),
            }
        )
    enriched.sort(
        key=lambda r: (r["concordance"], -r["top_bias_score"], -r["n_bias_hits"])
    )
    st.dataframe(enriched, use_container_width=True)


def _render_report(filtered_effects: list[dict[str, Any]], drugs: Any) -> None:
    """Render the Markdown report tab: download button plus truncated preview."""
    import re

    st.subheader("Markdown 리포트 생성")
    pairs = list_pairs(filtered_effects)
    if not pairs:
        st.info("쌍 없음.")
        return
    labels = [f"{d} × {o}" for d, o in pairs]
    idx = st.selectbox(
        "리포트 대상 쌍",
        range(len(labels)),
        format_func=lambda i: labels[i],
        key="report_pair",
    )
    drug_class, outcome = pairs[idx]
    pair_effects, conc, lawlor, biases, designs = _analyze_pair(
        filtered_effects, drugs, drug_class, outcome
    )
    md = build_markdown_report(
        (drug_class, outcome), pair_effects, conc, lawlor, biases, designs
    )
    # Replace every character unsafe in a filename (spaces, "/", etc.) one
    # for one; \w is Unicode-aware so Korean names survive.
    stem = re.sub(r"[^\w.-]", "_", f"triangulation_{drug_class}_{outcome}")
    st.download_button(
        "리포트 .md 다운로드",
        data=md.encode("utf-8"),
        file_name=f"{stem}.md",
        mime="text/markdown",
    )
    truncated = len(md) > _REPORT_PREVIEW_CHARS
    st.code(
        md[:_REPORT_PREVIEW_CHARS] + ("\n...\n" if truncated else ""),
        language="markdown",
    )


def main() -> None:
    """Render the DMCausalTriangulate-Kor Streamlit page.

    Effect rows come from the sidebar CSV upload when it is usable, and from
    ``HERE/data`` otherwise. Drug and outcome metadata always come from
    ``HERE/data``. Each of the four tabs renders independently, so an empty
    selection in one tab never hides the others.
    """
    st.set_page_config(
        page_title="DMCausalTriangulate-Kor",
        layout="wide",
        page_icon="🔺",
    )
    st.title("DMCausalTriangulate-Kor — 5-design causal triangulation (T2DM)")
    st.warning(DISCLAIMER)
    st.caption(
        "T2DM 약물(SGLT2i · GLP-1RA · tirzepatide · DPP-4i · metformin 등) × outcome 격자에 대해 "
        "RCT · observational · target-MR · ex vivo · within-subject 5-design effect를 통합하고 "
        "concordance/discordance를 정량화하며, 후속 연구 가설·설계를 생성합니다."
    )

    # ----------------------- Data load -----------------------
    st.sidebar.header("데이터")
    uploaded = st.sidebar.file_uploader(
        "사용자 effects CSV (선택)", type=["csv"], help="동일 컬럼 스키마"
    )
    data_dir = os.path.join(HERE, "data")
    effects = _effects_from_upload(uploaded) if uploaded is not None else None
    if effects is None:
        effects = load_effects(data_dir)

    drugs = load_drugs(data_dir)
    outcomes_meta = load_outcomes(data_dir)

    # ----------------------- Sidebar filters -----------------------
    st.sidebar.header("필터")
    drug_classes = sorted({r["drug_class"] for r in effects})
    outcome_categories = sorted({m["outcome_category"] for m in outcomes_meta})

    sel_drugs = st.sidebar.multiselect("Drug class", drug_classes, default=drug_classes)
    sel_categories = st.sidebar.multiselect(
        "Outcome category", outcome_categories, default=outcome_categories
    )

    # Apply filters. Note: effect rows whose outcome is absent from the
    # outcome metadata never pass the category filter.
    selected_drugs = set(sel_drugs)
    selected_categories = set(sel_categories)
    outcomes_in_cat = {
        m["outcome"] for m in outcomes_meta if m["outcome_category"] in selected_categories
    }
    filtered_effects = [
        r
        for r in effects
        if r["drug_class"] in selected_drugs and r["outcome"] in outcomes_in_cat
    ]
    # Shared by the overview and discordance tabs; computed once.
    summary = grid_summary(filtered_effects)

    # ----------------------- Tabs -----------------------
    tab_overview, tab_pair, tab_discord, tab_report = st.tabs(
        ["격자 개요", "쌍 분석", "Discordance 랭킹", "리포트"]
    )
    with tab_overview:
        _render_overview(summary)
    with tab_pair:
        _render_pair_analysis(filtered_effects, drugs)
    with tab_discord:
        _render_discordance_ranking(filtered_effects, summary)
    with tab_report:
        _render_report(filtered_effects, drugs)


# Script entry point; see the module docstring for the supported launch commands.
if __name__ == "__main__":
    main()
