"""
app.py — ObesityTrialProtocolAmend-Kor Streamlit UI.

5개 registry 항비만 trial protocol amendment audit trail을 시각화한다.
실행: streamlit run app.py
"""
from __future__ import annotations

import io
import json
import os
import re
from datetime import date

import pandas as pd
import streamlit as st

from core import (
    DISCLAIMER,
    REGISTRY_ORDER,
    aggregate_by,
    build_weekly_digest,
    filter_by_registry,
    filter_obesity,
    korean_sponsor_highlight,
    leading_indicator_alerts,
    load_rules,
    load_trials,
    rob2_view,
    score_all_trials,
)

# Browser-tab title and icon, plus a wide layout for the multi-column tabs.
_PAGE_CONFIG = dict(
    page_title="ObesityTrialProtocolAmend-Kor",
    page_icon=":pill:",
    layout="wide",
)
st.set_page_config(**_PAGE_CONFIG)


# ---------------------------------------------------------------------------
# 데이터 캐시
# ---------------------------------------------------------------------------
@st.cache_data(show_spinner=False)
def _load():
    """Load trial records and scoring rules through Streamlit's data cache.

    ``st.cache_data`` lets reruns reuse the result instead of reloading.

    Returns:
        ``(trials, rules)`` exactly as returned by ``load_trials()`` and
        ``load_rules()`` (loaded in that order).
    """
    return load_trials(), load_rules()


def _scores_df(scores) -> pd.DataFrame:
    """Tabulate score objects, one row per ``to_dict()`` result.

    Args:
        scores: iterable of objects exposing ``to_dict()``.

    Returns:
        DataFrame whose columns are the keys of those dicts; empty when
        ``scores`` is empty.
    """
    return pd.DataFrame([score.to_dict() for score in scores])


# ---------------------------------------------------------------------------
# Sidebar
# ---------------------------------------------------------------------------
def sidebar_filters(trials, rules):
    """Render the sidebar controls and return the current selections.

    Each categorical multiselect offers the sorted distinct values found in
    ``trials`` and defaults to selecting all of them.

    Args:
        trials: trial dicts; the ``registry``, ``drug_class``, ``country`` and
            ``status`` keys are read to build the options.
        rules: rules dict; only ``crisis_score_thresholds`` is read, for the
            threshold caption (falls back to 6 / 3 when keys are absent).

    Returns:
        dict with ``registries``, ``classes``, ``countries``, ``statuses``
        (selected value lists), ``only_korea`` (bool) and ``min_crisis``
        (slider value in 0..15).
    """
    sidebar = st.sidebar
    sidebar.title("ObesityTrialProtocolAmend-Kor")
    sidebar.caption("5개 registry 항비만 trial amendment 감시 (mock data)")
    sidebar.warning(DISCLAIMER)

    sidebar.subheader("필터")

    def distinct(field):
        # Sorted unique values of one trial field, used as widget options.
        return sorted({trial[field] for trial in trials})

    registry_options = distinct("registry")
    registries = sidebar.multiselect(
        "Registry",
        registry_options,
        default=registry_options,
        help="ClinicalTrials.gov / EudraCT / jRCT / CRIS-Korea / ANZCTR",
    )
    class_options = distinct("drug_class")
    classes = sidebar.multiselect("Drug class", class_options, default=class_options)
    country_options = distinct("country")
    countries = sidebar.multiselect("Country", country_options, default=country_options)
    status_options = distinct("status")
    statuses = sidebar.multiselect("Status", status_options, default=status_options)
    only_korea = sidebar.checkbox("한국 sponsor / 한국 site만 표시", value=False)

    min_crisis = sidebar.slider("최소 위기점수", 0, 15, 0)

    thresholds = rules.get("crisis_score_thresholds", {})
    alert_at = thresholds.get("alert", 6)
    watch_at = thresholds.get("watch", 3)
    sidebar.markdown(f"**임계치** · ALERT ≥ {alert_at}점 / WATCH ≥ {watch_at}점")

    return dict(
        registries=registries,
        classes=classes,
        countries=countries,
        statuses=statuses,
        only_korea=only_korea,
        min_crisis=min_crisis,
    )


# Sponsor-name fragments that mark a trial as Korean-sponsored for the
# "only_korea" sidebar filter.
_KOREAN_SPONSOR_KEYWORDS = (
    "LG",
    "종근당",
    "Chong Kun Dang",
    "Dong-A",
    "동아",
    "Ildong",
    "일동",
    "Severance",
    "Yuhan",
)
# Case-insensitive. A keyword must not be glued to a preceding ASCII letter or
# digit: a plain substring test let the two-letter "LG" fire on sponsors such
# as "UCB Belgium" or "Medical University of Bulgaria".
_KOREAN_SPONSOR_RE = re.compile(
    r"(?<![0-9A-Za-z])(?:"
    + "|".join(re.escape(kw) for kw in _KOREAN_SPONSOR_KEYWORDS)
    + r")",
    re.IGNORECASE,
)


def _is_korean_sponsor(sponsor) -> bool:
    """Return True when ``sponsor`` contains a known Korean sponsor keyword.

    ``None`` or empty sponsors count as non-Korean instead of raising.
    """
    return bool(sponsor) and _KOREAN_SPONSOR_RE.search(sponsor) is not None


def apply_filters(trials, scores, f):
    """Apply the sidebar selections to trials and their scores.

    Args:
        trials: trial dicts, each with a ``trial_id`` key.
        scores: score objects exposing ``trial_id``, ``registry``,
            ``drug_class``, ``country``, ``status``, ``crisis_score``,
            ``korean_site`` and ``sponsor``.
        f: selection dict as returned by ``sidebar_filters`` (keys
            ``registries``, ``classes``, ``countries``, ``statuses``,
            ``only_korea``, ``min_crisis``).

    Returns:
        ``(trials_f, scores_f)``: the trials and the scores whose score passed
        every filter, each list kept in its original order.
    """
    # Build the membership sets once instead of scanning lists per score.
    registries = set(f["registries"])
    classes = set(f["classes"])
    countries = set(f["countries"])
    statuses = set(f["statuses"])
    min_crisis = f["min_crisis"]
    only_korea = f["only_korea"]

    sel_ids = set()
    for s in scores:
        if (
            s.registry not in registries
            or s.drug_class not in classes
            or s.country not in countries
            or s.status not in statuses
            or s.crisis_score < min_crisis
        ):
            continue
        # A Korean site OR a Korean sponsor name qualifies the trial.
        if only_korea and not (s.korean_site or _is_korean_sponsor(s.sponsor)):
            continue
        sel_ids.add(s.trial_id)

    trials_f = [t for t in trials if t["trial_id"] in sel_ids]
    scores_f = [s for s in scores if s.trial_id in sel_ids]
    return trials_f, scores_f


# ---------------------------------------------------------------------------
# Tabs
# ---------------------------------------------------------------------------
def tab_ingest(trials, scores):
    """Render tab 1: headline metrics, the filtered trial table and a
    per-registry trial-count chart.

    Args:
        trials: filtered trial dicts (not read by this tab).
        scores: filtered score objects; drive every metric, table and chart.
    """
    st.header("1. Registry ingest + 항비만 필터")
    st.caption(
        "ClinicalTrials.gov v2 + EudraCT + jRCT + CRIS-Korea + ANZCTR 통합 ingest 시뮬레이션. "
        "condition/intervention/drug class 기반으로 항비만 trial 자동 필터링."
    )

    n_trials = len(scores)
    n_amendments = sum(len(sc.amendment_scores) for sc in scores)
    # max(1, ...) keeps the average defined (0.00) for an empty selection.
    mean_crisis = sum(sc.crisis_score for sc in scores) / max(1, n_trials)
    n_korean = sum(1 for sc in scores if sc.korean_site)

    m_trials, m_amend, m_crisis, m_korea = st.columns(4)
    m_trials.metric("필터 후 trial 수", n_trials)
    m_amend.metric("총 amendment 수", n_amendments)
    m_crisis.metric("평균 위기점수", f"{mean_crisis:.2f}")
    m_korea.metric("한국 site 포함 trial", n_korean)

    table = _scores_df(scores)
    if table.empty:
        st.info("필터 결과가 없습니다. 사이드바를 조정해 보세요.")
        return

    shown_cols = [
        "trial_id",
        "name",
        "sponsor",
        "country",
        "registry",
        "drug_class",
        "status",
        "crisis_score",
        "amendment_count",
        "korean_site",
    ]
    st.dataframe(
        table[shown_cols].sort_values("crisis_score", ascending=False),
        use_container_width=True,
        hide_index=True,
    )

    st.subheader("Registry별 trial / amendment 분포")
    per_registry = aggregate_by(scores, "registry")
    if not per_registry:
        return
    registry_frame = pd.DataFrame(per_registry, columns=["registry", "trial 수", "crisis 합"])
    st.bar_chart(registry_frame.set_index("registry")["trial 수"])


def tab_amendments(trials, scores):
    """Render tab 2: the amendment history of one user-selected trial.

    Args:
        trials: filtered trial dicts; the one whose ``trial_id`` matches the
            selected score supplies the sponsor/country/N/completion line.
        scores: filtered score objects offered in the trial selectbox.
    """
    st.header("2. Amendment delta 자동 추출")
    st.caption(
        "각 trial의 history version diff를 timestamp별 자동 parse — primary/secondary endpoint, "
        "sample size, exclusion, comparator, blinding, analysis plan, target completion date, status, sponsor, funding."
    )

    if not scores:
        st.info("표시할 trial이 없습니다.")
        return

    labels = [f"{sc.name} ({sc.trial_id})" for sc in scores]
    picked = st.selectbox("Trial 선택", range(len(labels)), format_func=lambda i: labels[i])
    chosen = scores[picked]
    trial = next(t for t in trials if t["trial_id"] == chosen.trial_id)

    # Non-negative offsets render as D-n, negative ones as D+n.
    dday = chosen.readout_dday
    if dday is None:
        dday_label = "n/a"
    elif dday >= 0:
        dday_label = f"D-{dday}"
    else:
        dday_label = f"D+{abs(dday)}"

    col_score, col_count, col_dday = st.columns(3)
    col_score.metric("위기점수", chosen.crisis_score)
    col_count.metric("Amendment 수", len(chosen.amendment_scores))
    col_dday.metric("Readout D-day", dday_label)

    header_parts = [
        f"**Sponsor:** {trial['sponsor']}",
        f"**Country:** {trial['country']}",
        f"**Drug class:** {trial['drug_class']}",
        f"**Status:** {trial['status']}",
        f"**N:** {trial['current_n']}",
        f"**Primary completion:** {trial['primary_completion']}",
    ]
    st.markdown("  |  ".join(header_parts))

    if chosen.flags:
        st.warning("플래그: " + ", ".join(chosen.flags))
    if chosen.termination_reason:
        st.error(f"종료/중단 사유 분류: {chosen.termination_reason}")

    def _row(entry):
        # Flatten one scored amendment into a display row.
        amendment = entry["amendment"]
        row = {key: amendment.get(key) for key in ("date", "type", "before", "after", "reason", "note")}
        row["score"] = entry["score"]
        row["matched_rules"] = ", ".join(
            rule.get("label", rule.get("id")) for rule in entry["matched_rules"]
        )
        return row

    history = [_row(entry) for entry in chosen.amendment_scores]
    if not history:
        st.info("해당 trial에 amendment history가 없습니다.")
        return
    st.dataframe(
        pd.DataFrame(history).sort_values("date"),
        use_container_width=True,
        hide_index=True,
    )


def tab_scoring(trials, scores, rules):
    """Render tab 3: alert counts and list, rule cards, termination reasons.

    Args:
        trials: filtered trial dicts (not read by this tab).
        scores: filtered score objects, passed to ``leading_indicator_alerts``.
        rules: rules dict; passed whole to ``leading_indicator_alerts``, and
            its ``rules`` and ``termination_reason_categories`` lists are
            displayed.
    """
    st.header("3. Amendment 중요도 채점 + leading indicator alert")
    st.caption(
        "rule-based 채점. primary endpoint 변경(★★★), sample size 감소(★★★, futility), "
        "termination(★★★) 등 composite 'trial 위기 점수'와 alert."
    )

    alerts = leading_indicator_alerts(scores, rules)

    c1, c2 = st.columns(2)
    c1.metric("ALERT", sum(1 for a in alerts if a["level"] == "ALERT"))
    c2.metric("WATCH", sum(1 for a in alerts if a["level"] == "WATCH"))

    st.subheader("Alert 목록")
    if not alerts:
        st.info("현재 alert 없음.")
    else:
        adf = pd.DataFrame(alerts)

        def _fmt_flags(fs):
            # Lists/tuples are joined; a plain string is kept as-is (joining it
            # would split it into characters); None/NaN (alert without flags)
            # becomes "".
            if isinstance(fs, str):
                return fs
            if isinstance(fs, (list, tuple)):
                return ", ".join(fs)
            return ""

        # Guarded: alerts without any "flags" key used to raise KeyError.
        if "flags" in adf.columns:
            adf["flags"] = adf["flags"].apply(_fmt_flags)
        st.dataframe(adf, use_container_width=True, hide_index=True)

    st.subheader("채점 rule 카드")
    rcols = st.columns(2)
    for i, r in enumerate(rules.get("rules", [])):
        # Same tolerant lookup as the amendment table (label falls back to id),
        # so one incomplete rule does not crash the whole tab.
        rule_id = r.get("id", "")
        label = r.get("label", rule_id)
        with rcols[i % 2]:
            st.markdown(f"**[{r.get('score', '?')}점] {label}** — `{rule_id}`")
            st.caption(r.get("description", ""))

    st.subheader("종료/중단 사유 분류 (참고)")
    cats = rules.get("termination_reason_categories", [])
    cdf = pd.DataFrame(cats)
    if not cdf.empty:
        st.dataframe(cdf, use_container_width=True, hide_index=True)


def _crisis_frame(rows, key):
    """Build a DataFrame from ``aggregate_by`` rows, indexed by ``key``.

    Columns are ``[key, "trial 수", "crisis 합"]``, the same labels the other
    tabs use for ``aggregate_by`` output.
    """
    return pd.DataFrame(rows, columns=[key, "trial 수", "crisis 합"]).set_index(key)


def _fmt_dday(dday):
    """Format a readout offset: ``D-n`` when ``dday >= 0``, else ``D+|n|``."""
    return f"D-{dday}" if dday >= 0 else f"D+{abs(dday)}"


def tab_dashboard(scores):
    """Render tab 4: crisis-score landscape charts, readout countdown and the
    Korean sponsor/site highlight table.

    Args:
        scores: filtered score objects.
    """
    st.header("4. Trial-level dashboard + 경쟁 landscape")
    st.caption("약물 class · sponsor · 국가별 amendment 빈도 / trial 위기 점수 비교.")

    if not scores:
        st.info("표시할 trial이 없습니다.")
        return

    c1, c2 = st.columns(2)
    with c1:
        st.subheader("Drug class별 위기점수 합")
        by_class = aggregate_by(scores, "drug_class")
        if by_class:
            st.bar_chart(_crisis_frame(by_class, "drug_class")["crisis 합"])
    with c2:
        st.subheader("Sponsor별 위기점수 합 (상위 10)")
        by_spon = aggregate_by(scores, "sponsor")
        if by_spon:
            # Rank explicitly by the plotted metric so "상위 10" really is the
            # top 10 by crisis sum, whatever order aggregate_by returns. The
            # stable sort keeps aggregate_by's order among ties.
            top = (
                _crisis_frame(by_spon, "sponsor")
                .sort_values("crisis 합", ascending=False, kind="stable")
                .head(10)
            )
            st.bar_chart(top["crisis 합"])

    st.subheader("국가별 위기점수 합")
    by_country = aggregate_by(scores, "country")
    if by_country:
        st.bar_chart(_crisis_frame(by_country, "country")["crisis 합"])

    st.subheader("Readout D-day 카운트다운 (예정 readout 기준)")
    upcoming = sorted(
        (s for s in scores if s.readout_dday is not None),
        key=lambda s: s.readout_dday,
    )
    if upcoming:
        rows = [
            {
                "D-day": _fmt_dday(s.readout_dday),
                "name": s.name,
                "trial_id": s.trial_id,
                "sponsor": s.sponsor,
                "status": s.status,
                "crisis_score": s.crisis_score,
            }
            for s in upcoming
        ]
        st.dataframe(pd.DataFrame(rows), use_container_width=True, hide_index=True)
    else:
        st.info("readout 예정일이 등록된 trial이 없습니다.")

    st.subheader("한국 sponsor / 한국 site 하이라이트")
    krs = korean_sponsor_highlight(scores)
    if krs:
        st.dataframe(_scores_df(krs), use_container_width=True, hide_index=True)
    else:
        st.info("필터 결과에 해당 trial이 없습니다.")


def tab_digest(scores, rules):
    """Render tab 5: weekly Markdown digest, RoB2 helper view and the
    competitor-intelligence table with CSV export.

    Args:
        scores: filtered score objects.
        rules: rules dict, passed to ``build_weekly_digest``.
    """
    st.header("5. Weekly digest + meta-analyst RoB2 보조")
    st.caption(
        "중요도 ★★ 이상 amendment, 위기점수 상위 trial, readout D-day, RoB2 보조 view, "
        "competitor intelligence Markdown export."
    )

    md = build_weekly_digest(scores, rules)
    st.download_button(
        "weekly_digest.md 다운로드",
        data=md.encode("utf-8"),
        file_name=f"obesity_trial_digest_{date.today().isoformat()}.md",
        mime="text/markdown",
    )
    with st.expander("Markdown 미리보기"):
        st.markdown(md)

    st.subheader("RoB2 보조 view")
    rob = rob2_view(scores)
    if rob:
        rob_rows = [
            {
                "trial_id": r["trial_id"],
                "name": r["name"],
                "sponsor": r["sponsor"],
                "crisis_score": r["crisis_score"],
                # One column per RoB2 signal, prefixed to avoid key clashes.
                **{f"RoB:{k}": v for k, v in r["rob_signals"].items()},
            }
            for r in rob
        ]
        st.dataframe(pd.DataFrame(rob_rows), use_container_width=True, hide_index=True)
    else:
        st.info("표시할 trial이 없습니다.")

    st.subheader("Competitor intelligence (간이 표)")
    ci_rows = [
        {
            "name": s.name,
            "sponsor": s.sponsor,
            "drug_class": s.drug_class,
            "country": s.country,
            "status": s.status,
            "crisis_score": s.crisis_score,
            "amendment_count": len(s.amendment_scores),
            "korean_site": s.korean_site,
            "readout_dday": s.readout_dday,
        }
        # Highest crisis score first; stable sort keeps input order on ties.
        for s in sorted(scores, key=lambda x: x.crisis_score, reverse=True)
    ]
    if ci_rows:
        ci_df = pd.DataFrame(ci_rows)
        st.dataframe(ci_df, use_container_width=True, hide_index=True)
        st.download_button(
            "competitor_intelligence.csv 다운로드",
            # "utf-8-sig" prepends a BOM so Excel detects UTF-8 and shows the
            # Korean sponsor/trial names instead of mojibake.
            data=ci_df.to_csv(index=False).encode("utf-8-sig"),
            file_name="competitor_intelligence.csv",
            mime="text/csv",
        )


# ---------------------------------------------------------------------------
# Main
# ---------------------------------------------------------------------------
def main():
    """Load data, apply the sidebar filters and render the five tabs."""
    all_trials, rules = _load()
    candidate_trials = filter_obesity(all_trials)
    candidate_scores = score_all_trials(candidate_trials, rules)

    selection = sidebar_filters(candidate_trials, rules)
    trials, scores = apply_filters(candidate_trials, candidate_scores, selection)

    st.title("ObesityTrialProtocolAmend-Kor")
    st.caption(
        "ClinicalTrials.gov · EudraCT · jRCT · CRIS-Korea · ANZCTR 통합 항비만 trial "
        "protocol amendment audit trail · 일일 surveillance · 한국어 weekly digest."
    )
    st.info(DISCLAIMER)

    # (tab label, renderer) pairs, rendered in order into their tab containers.
    tab_specs = [
        ("1. Registry ingest", lambda: tab_ingest(trials, scores)),
        ("2. Amendment delta", lambda: tab_amendments(trials, scores)),
        ("3. 채점 · Alert", lambda: tab_scoring(trials, scores, rules)),
        ("4. Dashboard · Landscape", lambda: tab_dashboard(scores)),
        ("5. Digest · RoB2", lambda: tab_digest(scores, rules)),
    ]
    containers = st.tabs([label for label, _ in tab_specs])
    for container, (_, render) in zip(containers, tab_specs):
        with container:
            render()


# Script entry point. `streamlit run app.py` (see the module docstring)
# executes this file with __name__ == "__main__", so main() runs on each
# Streamlit script run.
if __name__ == "__main__":
    main()
