"""MASH-DILISurveil-Kor Streamlit app.

참고용·연구용. Not for clinical decision.
"""
from __future__ import annotations

import csv
import io
import json
import sys
from pathlib import Path

# Directory containing this file. main() resolves data/ and reports/ against it,
# and it is put first on sys.path so the sibling `modules` package imported
# below resolves regardless of the current working directory.
ROOT = Path(__file__).resolve().parent
sys.path.insert(0, str(ROOT))

try:
    import streamlit as st
except Exception:   # pragma: no cover
    st = None

import pandas as pd

from modules import ingest, hys_law, rucam, class_panel, report

# Warning banner shown at the top of the app by main(). In English: "This tool
# is for reference/research use only and must not be used as-is for clinical
# decision-making or regulatory submissions. The placebo arm reference values
# are mock values."
DISCLAIMER = (
    "본 도구는 참고용·연구용이며 임상 의사결정·규제 제출에 그대로 사용해서는 안 됩니다. "
    "placebo arm reference는 mock 값입니다."
)


# Lab columns copied into every timepoint, in this order. Required columns are
# indexed directly (a missing column raises KeyError); optional columns become
# None when the column is absent from the table.
_REQUIRED_LABS = ("ALT", "AST", "ALP", "TBL")
_OPTIONAL_LABS = ("INR", "ALB", "PLT")


def _lab_value(row: pd.Series, column: str, *, required: bool) -> float | None:
    """Return ``row[column]`` as a float, or ``None`` when the cell is missing/NaN."""
    value = row[column] if required else row.get(column, None)
    return float(value) if pd.notna(value) else None


def _label(first_row: pd.Series, column: str) -> str:
    """Return a text label from *first_row*; absent or NaN cells become ``"unknown"``."""
    value = first_row.get(column, None)
    # Without the NaN check an empty CSV cell would surface as the label "nan".
    return "unknown" if pd.isna(value) else str(value)


def _patients_from_df(df: pd.DataFrame):
    """Group a long-format LFT table into ``ingest.Patient`` objects.

    Args:
        df: One row per (pid, week). Columns ``pid``, ``week``, ``ALT``, ``AST``,
            ``ALP`` and ``TBL`` are required. ``arm``, ``drug_class``, ``INR``,
            ``ALB`` and ``PLT`` are optional.

    Returns:
        Dict mapping ``str(pid)`` to a Patient. The Patient's ``timepoints`` are
        in ascending week order. Its ``baseline`` holds the non-missing lab
        values of the earliest week.
    """
    patients = {}
    for pid, sub in df.groupby("pid"):
        sub = sub.sort_values("week")
        first = sub.iloc[0]
        p = ingest.Patient(pid=str(pid), arm=_label(first, "arm"),
                           drug_class=_label(first, "drug_class"))
        for _, row in sub.iterrows():
            timepoint = {"week": int(row["week"])}
            for lab in _REQUIRED_LABS:
                timepoint[lab] = _lab_value(row, lab, required=True)
            for lab in _OPTIONAL_LABS:
                timepoint[lab] = _lab_value(row, lab, required=False)
            p.timepoints.append(timepoint)
        if p.timepoints:
            p.baseline = {k: v for k, v in p.timepoints[0].items()
                          if k != "week" and v is not None}
        patients[str(pid)] = p
    return patients


def _select_lft_frame(data_dir: Path) -> pd.DataFrame | None:
    """Render the sidebar data-source controls and return the chosen LFT table.

    Returns ``None`` after showing an info message when neither the bundled demo
    CSV nor an uploaded file is available.
    """
    st.sidebar.header("Data source")
    use_demo = st.sidebar.checkbox("Use bundled demo data", value=True)
    if use_demo:
        demo_csv = data_dir / "synthetic_lft.csv"
        if demo_csv.exists():
            return pd.read_csv(demo_csv)
        st.info("Demo data missing. Run `python3 main.py --demo` first or upload CSV.")
        return None
    uploaded = st.sidebar.file_uploader("Upload LFT CSV", type="csv")
    if uploaded is not None:
        return pd.read_csv(uploaded)
    # Demo mode is off, so "demo data missing" would be misleading here.
    st.info("Upload an LFT CSV in the sidebar to continue.")
    return None


def _render_overview(patients, dsc) -> None:
    """Overview tab: baseline table plus the Hy's law and RUCAM summaries in *dsc*."""
    st.subheader("Baseline summary")
    st.dataframe(pd.DataFrame(ingest.summarize_baseline(patients)).T)
    st.subheader("Hy's law summary")
    st.json(dsc["hys_summary"])
    st.subheader("RUCAM categories")
    st.json(dsc["rucam_category_counts"])


def _render_edish(cases) -> None:
    """eDISH tab: per-case table and a log-log peak ALT vs. TBL scatter.

    Falls back to ``hys_law.render_edish_ascii`` when matplotlib cannot be
    imported or plotting fails.
    """
    st.subheader("eDISH points")
    edish_df = pd.DataFrame([{
        "pid": c.pid, "arm": c.arm, "drug_class": c.drug_class,
        "ALT/ULN": round(c.alt_ratio_uln, 2),
        "TBL/ULN": round(c.tbl_ratio_uln, 2),
        "ALT/baseline": round(c.alt_ratio_baseline, 2),
        "quadrant": c.quadrant,
        "classical_hys": c.classical_hys,
        "baseline_adj_hys": c.baseline_adj_hys,
    } for c in cases])
    st.dataframe(edish_df, use_container_width=True)
    if edish_df.empty:
        # A frame built from no rows has no "arm" column, so there is nothing to plot.
        st.caption("(no cases to plot)")
        return

    try:
        import matplotlib.pyplot as plt
    except Exception as e:  # ImportError, or a backend that fails to load
        st.text(hys_law.render_edish_ascii(cases))
        st.caption(f"(matplotlib unavailable: {e})")
        return

    fig, ax = plt.subplots(figsize=(6, 5))
    try:
        # Placebo is drawn gray. color=None lets other arms use the default
        # colour cycle. NOTE(review): assumes the placebo arm label is exactly
        # "placebo".
        colors = {"placebo": "tab:gray"}
        for arm in edish_df["arm"].unique():
            sub = edish_df[edish_df["arm"] == arm]
            ax.scatter(sub["ALT/ULN"], sub["TBL/ULN"], label=arm, s=18, alpha=0.6,
                       color=colors.get(arm))
        ax.set_xscale("log")
        ax.set_yscale("log")
        # Dashed lines at TBL = 2x ULN and ALT = 3x ULN.
        ax.axhline(2, ls="--", c="red", lw=0.8)
        ax.axvline(3, ls="--", c="red", lw=0.8)
        ax.set_xlabel("peak ALT / ULN")
        ax.set_ylabel("peak TBL / ULN")
        ax.set_title("eDISH plot (log-log)")
        ax.legend(fontsize=8)
        st.pyplot(fig)
    except Exception as e:  # best-effort: keep the tab usable via the ASCII plot
        st.text(hys_law.render_edish_ascii(cases))
        st.caption(f"(eDISH plot failed: {e})")
    finally:
        # Streamlit reruns the script on every interaction. Close the figure so
        # pyplot does not accumulate open figures across reruns.
        plt.close(fig)


def _render_rucam(rucam_results) -> None:
    """RUCAM tab: one row per automatically scored case."""
    st.subheader("RUCAM scores (auto 1차 패스)")
    rdf = pd.DataFrame([{
        "pid": r.pid, "total": r.total, "category": r.category,
        # `is not None` (not truthiness) so a valid R ratio of 0.0 is still shown.
        "R_ratio": round(r.cioms_r_ratio, 2) if r.cioms_r_ratio is not None else None,
        "pattern": r.pattern, "Naranjo": r.naranjo,
        "Naranjo_cat": r.naranjo_category,
        "Maria-Victorino": r.maria_victorino,
    } for r in rucam_results])
    st.dataframe(rdf, use_container_width=True)


def _render_class_panel(signals) -> None:
    """Class-panel tab: one table of marker flags per drug class in *signals*."""
    st.subheader("Class panel flags")
    for cls, sigs in signals.items():
        st.markdown(f"#### {cls}")
        sdf = pd.DataFrame([{
            "pid": s.pid, "marker": s.marker, "role": s.role,
            "baseline": s.baseline, "peak": s.peak,
            "%change": round(s.pct_change, 1) if s.pct_change is not None else None,
            "flag": s.flag, "rule": s.rule,
        } for s in sigs])
        if sdf.empty:
            st.caption("(no panel data)")
        else:
            st.dataframe(sdf, use_container_width=True)


def _render_placebo(dsc) -> None:
    """Placebo tab: the placebo-arm attribution section of *dsc*."""
    st.subheader("Placebo arm attribution")
    st.json(dsc["placebo_attribution"])
    st.caption(
        "MAESTRO-NASH·ENLIGHTEN·SYMMETRY·CONTROL·ESSENCE placebo arm 공개값 mock."
    )


def _cases_csv(cases) -> str:
    """Return the Hy's law cases as CSV text (header row plus one row per case)."""
    buf = io.StringIO()
    writer = csv.writer(buf)
    writer.writerow(["pid", "arm", "drug_class",
                     "ALT/ULN", "TBL/ULN", "classical_hys", "quadrant"])
    writer.writerows(
        [c.pid, c.arm, c.drug_class,
         round(c.alt_ratio_uln, 2), round(c.tbl_ratio_uln, 2),
         c.classical_hys, c.quadrant]
        for c in cases
    )
    return buf.getvalue()


def _render_export(dsc, cases, rucam_results, signals, report_dir: Path) -> None:
    """Export tab: JSON/CSV downloads, plus on-demand docx/markdown files in *report_dir*."""
    st.subheader("Export")
    st.download_button(
        "Download DSC report (JSON)",
        data=json.dumps(dsc, ensure_ascii=False, indent=2),
        file_name="DSC_report.json",
        mime="application/json",
    )
    st.download_button(
        "Download cases CSV",
        data=_cases_csv(cases),
        file_name="cases.csv", mime="text/csv",
    )
    if not st.button("Generate docx (RMP/DSUR/PSUR/DSC)"):
        return
    report_dir.mkdir(parents=True, exist_ok=True)
    for kind in ("DSC", "RMP", "DSUR", "PSUR"):
        out = report.export_docx(dsc, report_dir / f"{kind}_Q1.docx", doc_type=kind)
        st.write(f"  - {kind}: {out}")
    report.build_manuscript_supplement(
        cases, rucam_results, signals,
        report_dir / "supplementary.md")
    st.success("Reports written to reports/.")


def main():
    """Streamlit entry point: load LFT data, run the DILI pipeline, render the tabs.

    Prints an install hint and returns if streamlit could not be imported.
    """
    if st is None:
        print("Streamlit not available. Run `pip install -r requirements.txt`.")
        return
    # set_page_config must stay the first Streamlit call of the run.
    st.set_page_config(page_title="MASH-DILISurveil-Kor", layout="wide")
    st.title("MASH-DILISurveil-Kor")
    st.warning(DISCLAIMER)

    data_dir = ROOT / "data"
    report_dir = ROOT / "reports"

    df = _select_lft_frame(data_dir)
    if df is None:
        return

    patients = _patients_from_df(df)
    # Attach the class-specific biomarker panels to the patients.
    # NOTE(review): panels are always read from the bundled data dir, even when
    # the LFT table was uploaded. Confirm this is intended.
    for cls in ("THRb", "FGF21", "ACC", "FXR"):
        ingest.load_panel_csv(data_dir / f"synthetic_{cls.lower()}_panel.csv",
                              patients, cls)

    cases = hys_law.evaluate_patients(patients)
    rucam_inputs = rucam.derive_inputs_from_cases(cases)
    rucam_results = rucam.evaluate_batch(rucam_inputs, cases)
    signals = class_panel.evaluate_all(patients)
    placebo_ref = report.load_placebo_reference(data_dir / "placebo_reference.csv")
    dsc = report.build_dsc_report(cases, rucam_results, signals, placebo_ref,
                                  quarter="Q1", drug_label="STUDY-DRUG-001")

    tabs = st.tabs([
        "Overview", "eDISH / Hy's law", "RUCAM cases",
        "Class panel", "Placebo signal", "Export",
    ])
    with tabs[0]:
        _render_overview(patients, dsc)
    with tabs[1]:
        _render_edish(cases)
    with tabs[2]:
        _render_rucam(rucam_results)
    with tabs[3]:
        _render_class_panel(signals)
    with tabs[4]:
        _render_placebo(dsc)
    with tabs[5]:
        _render_export(dsc, cases, rucam_results, signals, report_dir)


if __name__ == "__main__":
    main()
