"""BodyCompMouseDXA — Streamlit MVP entry point.

Run with:
    streamlit run main.py

Or for a quick offline CLI sanity check:
    python3 main.py --demo

The CLI path requires only numpy / pandas / scipy. Streamlit and plotly
are optional and only needed for the interactive UI.
"""

from __future__ import annotations

import argparse
import os
import sys

import pandas as pd

# Make sibling modules importable both as script and via -m.
# HERE is the directory containing this file; it is prepended to sys.path
# (only if not already present) so the `analysis` and `dio_reference`
# imports below resolve. This must run before those imports, which is why
# they carry `# noqa: E402`.
HERE = os.path.dirname(os.path.abspath(__file__))
if HERE not in sys.path:
    sys.path.insert(0, HERE)

import analysis  # noqa: E402
from analysis import (  # noqa: E402
    ancova_endpoint,
    compute_indices,
    derive_roi_metrics,
    drug_on_off_interaction,
    flag_sarcopenic_obesity,
    korean_report,
    load_any,
    load_demo_data,
    mixed_effects_repeated_measures,
    regenerate_synthetic_if_missing,
    trajectory_for_animal,
    trajectory_table,
)
from dio_reference import (  # noqa: E402
    AVAILABLE_MODELS,
    DIO_REFERENCE,
    classify_band,
    get_reference,
    reference_table,
)


def _prepare(df: pd.DataFrame) -> pd.DataFrame:
    """Run the standard derivation pipeline on a cohort frame.

    Applies ``derive_roi_metrics``, then ``compute_indices``, then
    ``flag_sarcopenic_obesity`` (with that function's default arguments),
    feeding each step's result into the next.

    Args:
        df: Frame to process.

    Returns:
        The frame returned by the last step of the pipeline.
    """
    with_roi = derive_roi_metrics(df)
    with_indices = compute_indices(with_roi)
    return flag_sarcopenic_obesity(with_indices)


# ---------------------------------------------------------------------------
#  CLI demo mode (no Streamlit needed)
# ---------------------------------------------------------------------------

def run_demo_cli() -> int:
    """Run the offline demo on the bundled synthetic cohorts and print a report.

    The demo regenerates any missing synthetic CSVs, loads and prepares the
    demo data, and then prints: a cohort overview, the combined Korean
    report (single-animal trajectory, ANCOVA and drug on/off on ``fat_g``),
    the mixed-effects result, the per-group sarcopenic-obesity rate and a
    wk12 reference-band comparison for up to three DIO_HFD60 animals.

    Returns:
        ``0`` on success, ``1`` when no demo data could be loaded.
    """
    banner = "=" * 72
    print(banner)
    print(" BodyCompMouseDXA — CLI demo mode")
    print(" (이 도구는 IACUC 승인 동물실험 데이터의 사후 분석 보조용입니다)")
    print(banner)

    regen = regenerate_synthetic_if_missing()
    if regen:
        print(f"[info] regenerated {len(regen)} synthetic cohort CSV(s):")
        for p in regen:
            print(f"        {p}")

    df = load_demo_data()
    if df.empty:
        print("[error] no synthetic data found. Run data/synthetic_generator.py first.")
        return 1
    df = _prepare(df)

    print(f"\n[info] loaded {len(df)} rows across {df['animal_id'].nunique()} animals "
          f"and {df['group'].nunique()} groups.")
    # dropna() keeps sorted() from comparing NaN (float) with str labels,
    # matching how the time points are listed on the next line.
    print(f"[info] groups: {sorted(df['group'].dropna().unique())}")
    print(f"[info] time points: {sorted(df['time_point_wk'].dropna().unique())}")

    # Single-animal trajectory (first animal in the frame).
    aid = df["animal_id"].iloc[0]
    traj = trajectory_for_animal(df, aid)

    # Cohort ANCOVA on fat_g for DIO vs. chow control.
    dio_vs_ctrl = df[df["group"].isin(["DIO_HFD60", "control_chow"])]
    ancova_res = ancova_endpoint(dio_vs_ctrl, endpoint="fat_g")

    # Drug on/off interaction for the GLP1RA arms.
    drug = df[df["group"].isin(["GLP1RA_treated", "GLP1RA_placebo"])]
    drug_res = drug_on_off_interaction(drug, endpoint="fat_g")

    # NOTE(review): the Streamlit ANCOVA tab checks ancova_endpoint()'s result
    # for an "error" key; confirm korean_report() tolerates such a result.
    print()
    print(korean_report(traj, ancova_res, drug_res))

    # Mixed-effects summary; the result shape depends on the backend.
    mer = mixed_effects_repeated_measures(dio_vs_ctrl, endpoint="fat_g")
    print()
    print(f"[Mixed-effects RM, backend={mer.get('backend')}]")
    if "per_time_results" in mer:
        for t, v in sorted(mer["per_time_results"].items()):
            print(f"  wk{t:g}: F={v['f']:.3f}, p={v['p']:.4g}, n_per_group={v['n_per_group']}")
    elif "pvalues" in mer:
        for k, v in mer["pvalues"].items():
            print(f"  {k}: p={v:.4g}")
    elif "summary_text" in mer:
        # Same fallback the Streamlit UI uses, so the section is never blank
        # when only a text summary is available.
        print(mer["summary_text"])

    # Fraction of flagged rows per group.
    so_flags = df.groupby("group")["sarcopenic_obesity_flag"].mean()
    print("\n[근감소성 비만 비율 (group별)]")
    for g, frac in so_flags.items():
        print(f"  · {g}: {frac * 100:.1f}%")

    # Reference comparison: wk12 DIO_HFD60 animals vs. the reference band.
    print("\n[참조 밴드 vs 합성 데이터 (wk12, DIO_HFD60, fat_pct)]")
    band = get_reference("C57BL_6J_HFD60", 12)["fat_pct"]
    is_dio = dio_vs_ctrl["group"] == "DIO_HFD60"
    is_wk12 = dio_vs_ctrl["time_point_wk"] == 12
    sample = dio_vs_ctrl[is_dio & is_wk12]
    if sample.empty:
        # Say so explicitly instead of leaving a bare header.
        print("  · 해당 조건의 행이 없습니다.")
    for _, row in sample.head(3).iterrows():
        val = float(row["fat_pct"])
        print(f"  · {row['animal_id']}: fat_pct={val:.2f} — {classify_band(val, band)}")
    return 0


# ---------------------------------------------------------------------------
#  Streamlit UI
# ---------------------------------------------------------------------------

def _try_import_streamlit():
    try:
        import streamlit as st  # type: ignore
        return st
    except ImportError:
        return None


def _try_import_plotly():
    try:
        import plotly.express as px  # type: ignore
        import plotly.graph_objects as go  # type: ignore
        return px, go
    except ImportError:
        return None, None


def _select_upload_loader(name: str):
    """Pick the ``analysis`` loader for an uploaded file from its lower-cased name.

    Explicit instrument tokens are tested before the generic ``.txt``
    fallback, so a name such as ``skyscan_run1.txt`` goes to the SkyScan
    loader instead of the qNMR one.
    """
    if "echomri" in name:
        return analysis.load_echomri_csv
    if "qnmr" in name or "minispec" in name:
        return analysis.load_qnmr_txt
    if "skyscan" in name or "microct" in name:
        return analysis.load_skyscan_csv
    if name.endswith(".txt"):
        return analysis.load_qnmr_txt
    # NOTE(review): ".dcm" uploads are accepted by the uploader but also end
    # up here; confirm load_dxa_csv_export() can read them (the imported
    # `load_any` may have been intended for this dispatch).
    return analysis.load_dxa_csv_export


def _load_uploaded_frames(st, uploads) -> pd.DataFrame:
    """Parse every uploaded file and concatenate the results.

    A file that fails to load is reported in the sidebar and skipped.
    Returns an empty frame when nothing was uploaded or nothing parsed.
    """
    frames = []
    for up in uploads or []:
        name = up.name.lower()
        try:
            frames.append(_select_upload_loader(name)(up))
        except Exception as e:  # best effort: one bad file must not abort the rest
            st.sidebar.error(f"{name}: {e}")
    if not frames:
        return pd.DataFrame()
    return pd.concat(frames, ignore_index=True)


def _load_sidebar_data(st) -> pd.DataFrame:
    """Render the data-source sidebar and return the selected raw frame."""
    st.sidebar.header("1) 데이터 로드")
    src = st.sidebar.radio("데이터 소스", ["내장 합성 데이터", "파일 업로드"], index=0)

    if src == "내장 합성 데이터":
        regen = regenerate_synthetic_if_missing()
        if regen:
            st.sidebar.info(f"합성 데이터 {len(regen)}개 파일 자동 생성됨")
        return load_demo_data()

    uploads = st.sidebar.file_uploader(
        "PIXImus DICOM / EchoMRI / qNMR / SkyScan / canonical CSV",
        type=["csv", "txt", "dcm"],
        accept_multiple_files=True,
    )
    return _load_uploaded_frames(st, uploads)


def _render_data_tab(st, df: pd.DataFrame) -> None:
    """Show the first rows of the prepared frame and the row count per source format."""
    st.subheader("Raw schema (canonical 18 columns)")
    st.dataframe(df.head(200))
    st.subheader("Source format 분포")
    st.dataframe(df.groupby("source_format").size().rename("rows").reset_index())


def _render_trajectory_tab(st, px, df: pd.DataFrame) -> None:
    """Show one animal's trajectory metrics and chart, then the cohort summary table."""
    st.subheader("개체별 멀티 타임포인트 트라젝토리")
    animal_ids = sorted(df["animal_id"].dropna().unique().tolist())
    if animal_ids:
        aid = st.selectbox("개체 선택", animal_ids)
        traj = trajectory_for_animal(df, aid)
        c1, c2, c3, c4 = st.columns(4)
        c1.metric("기저 체중 (g)", f"{traj.baseline_bw:.2f}")
        c2.metric("최종 체중 (g)", f"{traj.final_bw:.2f}",
                  delta=f"{traj.pct_change_bw:+.2f}%")
        c3.metric("Nadir 체중 (g)", f"{traj.nadir_bw:.2f}",
                  delta=f"wk{traj.nadir_bw_wk:g}")
        c4.metric("Nadir 이후 재증가 (g)", f"{traj.regain_after_nadir_g:+.3f}")

        sub = df[df["animal_id"] == aid].sort_values("time_point_wk")
        if px is not None:
            fig = px.line(sub, x="time_point_wk",
                          y=["body_weight_g", "fat_g", "lean_g"],
                          markers=True,
                          title=f"{aid} — 체중·지방·제지방 (g)")
            st.plotly_chart(fig, use_container_width=True)
        else:
            # Plotly is optional; fall back to Streamlit's built-in chart.
            st.line_chart(sub.set_index("time_point_wk")[["body_weight_g", "fat_g", "lean_g"]])
    else:
        # Without any id the selectbox would yield None; skip the per-animal view.
        st.info("animal_id 값이 있는 행이 없습니다.")

    st.subheader("코호트 전체 트라젝토리 요약 표")
    st.dataframe(trajectory_table(df))


def _render_sarcopenic_tab(st, px, df: pd.DataFrame) -> None:
    """Re-flag sarcopenic obesity with user-chosen cutoffs and show the flagged rows."""
    st.subheader("근감소성 비만 지수")
    st.caption("Sarcopenic Obesity Index = lean / (lean + fat); ALM/BW 동시 cutoff 적용")
    c1, c2 = st.columns(2)
    soi_cut = c1.slider("SOI cutoff", 0.30, 0.70, 0.45, 0.01)
    alm_cut = c2.slider("ALM/BW cutoff", 0.05, 0.25, 0.12, 0.01)
    df_so = flag_sarcopenic_obesity(df, soi_cutoff=soi_cut, alm_cutoff=alm_cut)
    flagged = df_so[df_so["sarcopenic_obesity_flag"]]
    st.write(f"양성 개체-타임 포인트: {len(flagged)} / {len(df_so)}")
    st.dataframe(flagged[["animal_id", "model", "group", "time_point_wk",
                          "sarcopenic_obesity_index", "ALM_over_BW",
                          "so_severity"]])
    if px is not None and not df_so.empty:
        fig = px.violin(df_so, x="group", y="sarcopenic_obesity_index",
                        box=True, points="all", color="group",
                        title="Group별 sarcopenic obesity index 분포")
        st.plotly_chart(fig, use_container_width=True)


def _format_band(band, spec: str) -> str:
    """Format a reference band as ``"<first>–<third>"`` using format spec *spec*."""
    return f"{band[0]:{spec}}–{band[2]:{spec}}"


def _render_reference_tab(st) -> None:
    """Show the reference bands of the selected DIO model as a table."""
    st.subheader("DIO 모델 참조 밴드")
    preferred = "C57BL_6J_HFD60"
    # Fall back to the first model rather than raising ValueError when the
    # preferred default is not among the available models.
    default_idx = AVAILABLE_MODELS.index(preferred) if preferred in AVAILABLE_MODELS else 0
    model = st.selectbox("모델", AVAILABLE_MODELS, index=default_idx)
    ref = DIO_REFERENCE[model]
    ref_rows = [
        {
            "time_point_wk": t,
            "body_weight_g": _format_band(b["body_weight_g"], ".1f"),
            "fat_pct": _format_band(b["fat_pct"], ".1f"),
            "lean_pct": _format_band(b["lean_pct"], ".1f"),
            "BMD": _format_band(b["BMD"], ".3f"),
            "visceral_fat_proxy_g": _format_band(b["visceral_fat_proxy_g"], ".2f"),
        }
        for t, b in ref.items()
    ]
    st.dataframe(pd.DataFrame(ref_rows))
    st.caption("(low, high)는 분포 25–75 percentile에 해당하는 근사 범위입니다.")


def _render_ancova_tab(st, px, df: pd.DataFrame) -> None:
    """Run the endpoint ANCOVA for the selected groups, plus an on-demand mixed-effects model."""
    st.subheader("Baseline-corrected ANCOVA")
    endpoint = st.selectbox("Endpoint", ["fat_g", "lean_g", "body_weight_g",
                                         "fat_pct", "BMD", "visceral_fat_proxy_g",
                                         "appendicular_lean_g"])
    groups = sorted(df["group"].dropna().unique())
    sel_groups = st.multiselect("그룹 (≥2)", groups, default=groups[:2])
    enough_groups = len(sel_groups) >= 2

    if not enough_groups:
        st.info("최소 2개 그룹을 선택하세요.")
    else:
        sub = df[df["group"].isin(sel_groups)]
        res = ancova_endpoint(sub, endpoint=endpoint)
        if "error" in res:
            st.error(res["error"])
        else:
            st.write(f"**Backend:** {res['backend']}")
            st.write(f"**F = {res['f_statistic']:.3f}, p = {res['p_value']:.4g}**  "
                     f"(final_wk = {res['final_wk']})")
            # One row per group, in the order of the adjusted-means mapping.
            order = list(res["means_adjusted"])
            cols = pd.DataFrame({
                "group": order,
                "n": [res["n_per_group"][g] for g in order],
                "baseline_mean": [res["means_baseline"][g] for g in order],
                "final_mean": [res["means_final"][g] for g in order],
                "adjusted_mean": [res["means_adjusted"][g] for g in order],
            })
            st.dataframe(cols)
            if px is not None:
                fig = px.bar(cols, x="group", y="adjusted_mean",
                             title=f"{endpoint} — baseline-corrected adjusted mean")
                st.plotly_chart(fig, use_container_width=True)

    st.divider()
    st.subheader("Mixed-effects Repeated Measures")
    if st.button("Run mixed-effects RM"):
        if enough_groups:
            mer = mixed_effects_repeated_measures(
                df[df["group"].isin(sel_groups)], endpoint=endpoint)
            st.write(f"Backend: {mer.get('backend')}")
            if "per_time_results" in mer:
                rows = [{"time_point_wk": t, **v}
                        for t, v in sorted(mer["per_time_results"].items())]
                st.dataframe(pd.DataFrame(rows))
            elif "summary_text" in mer:
                st.code(mer["summary_text"])
        else:
            st.info("최소 2개 그룹을 선택하세요.")


def _render_drug_tab(st, px, df: pd.DataFrame) -> None:
    """Show the drug on/off interaction for groups whose label mentions GLP or drug."""
    st.subheader("Drug-on / Drug-off interaction")
    # str() keeps the substring test from raising on non-string group labels.
    groups_drug = [g for g in df["group"].dropna().unique()
                   if "GLP" in str(g) or "drug" in str(g).lower()]
    if not groups_drug:
        st.info("현재 데이터셋에 drug_phase 라벨이 표시된 그룹이 없습니다.")
        return

    sub = df[df["group"].isin(groups_drug)]
    ep = st.selectbox("Endpoint ", ["fat_g", "body_weight_g", "lean_g", "visceral_fat_proxy_g"],
                      key="drug_ep")
    res = drug_on_off_interaction(sub, endpoint=ep)
    st.json(res)
    if px is not None and "per_animal" in res:
        pa = pd.DataFrame(res["per_animal"])
        fig = px.scatter(pa, x="delta_on", y="delta_off", color="group",
                         hover_data=["animal_id"],
                         title=f"{ep}: Δon vs Δoff (개체별)")
        st.plotly_chart(fig, use_container_width=True)


def _render_export_tab(st, df: pd.DataFrame) -> None:
    """Offer CSV downloads and show an example Korean summary report."""
    st.subheader("Manuscript-ready export")
    # "utf-8-sig" prepends a BOM, as the button label advertises.
    csv_bytes = df.to_csv(index=False).encode("utf-8-sig")
    st.download_button("코호트 CSV 다운로드 (BOM 포함)", csv_bytes,
                       file_name="bodycomp_cohort.csv", mime="text/csv")
    traj_tbl = trajectory_table(df)
    if not traj_tbl.empty:
        st.download_button("트라젝토리 요약 CSV", traj_tbl.to_csv(index=False).encode("utf-8-sig"),
                           file_name="bodycomp_trajectory_summary.csv", mime="text/csv")
    ref_tbl = pd.DataFrame(reference_table())
    st.download_button("DIO 참조 밴드 CSV", ref_tbl.to_csv(index=False).encode("utf-8-sig"),
                       file_name="dio_reference_bands.csv", mime="text/csv")
    st.markdown("---")
    st.markdown("**한국어 요약 리포트 예시**")

    # dropna() mirrors the trajectory tab and keeps sorted() from comparing
    # NaN with string ids.
    animal_ids = sorted(df["animal_id"].dropna().unique())
    if not animal_ids:
        st.info("animal_id 값이 있는 행이 없습니다.")
        return
    # NOTE(review): ancova_endpoint() may return an {"error": ...} result (the
    # ANCOVA tab checks for it); confirm korean_report() tolerates that.
    report = korean_report(trajectory_for_animal(df, animal_ids[0]),
                           ancova_endpoint(df, endpoint="fat_g"))
    st.code(report, language="text")


def run_streamlit_ui() -> None:
    """Build the Streamlit page: sidebar data loading followed by the analysis tabs.

    When Streamlit is not installed a hint is printed and the function
    returns. When no data is loaded an info message is shown and the tabs
    are not rendered. Plotly is optional; charts that need it are skipped or
    replaced by a built-in Streamlit chart.
    """
    st = _try_import_streamlit()
    if st is None:
        print("Streamlit is not installed. Use:  python3 main.py --demo")
        return

    px, _go = _try_import_plotly()

    st.set_page_config(page_title="BodyCompMouseDXA",
                       page_icon="🐭",
                       layout="wide")
    st.title("BodyCompMouseDXA — Mouse Body Composition Pipeline")
    st.caption("PIXImus DXA / EchoMRI / qNMR / SkyScan microCT — multi-timepoint, DIO model-specific reference, sarcopenic obesity, drug-on/off, ANCOVA.")
    st.warning("연구·참고용 — IACUC 승인 동물실험 데이터의 사후 분석 보조용입니다. 임상 의사결정에 사용하지 마십시오.")

    df = _load_sidebar_data(st)
    if df.empty:
        st.info("좌측에서 데이터를 로드하세요.")
        return
    df = _prepare(df)
    st.success(f"{len(df)} 행 · {df['animal_id'].nunique()} 개체 · {df['group'].nunique()} 그룹")

    tab_data, tab_traj, tab_so, tab_ref, tab_ancova, tab_drug, tab_export = st.tabs([
        "데이터", "트라젝토리", "근감소성 비만", "DIO 참조", "ANCOVA", "Drug on/off", "Export"
    ])

    with tab_data:
        _render_data_tab(st, df)
    with tab_traj:
        _render_trajectory_tab(st, px, df)
    with tab_so:
        _render_sarcopenic_tab(st, px, df)
    with tab_ref:
        _render_reference_tab(st)
    with tab_ancova:
        _render_ancova_tab(st, px, df)
    with tab_drug:
        _render_drug_tab(st, px, df)
    with tab_export:
        _render_export_tab(st, df)


# ---------------------------------------------------------------------------
#  Entry point
# ---------------------------------------------------------------------------

def main(argv=None) -> int:
    """Parse command-line flags and dispatch to the CLI demo or the Streamlit UI.

    Args:
        argv: Argument list to parse; ``None`` lets argparse read
            ``sys.argv[1:]``. Unrecognised arguments are ignored because
            ``parse_known_args`` is used.

    Returns:
        Process exit status. With ``--demo`` this is ``run_demo_cli()``'s
        status. Otherwise it is ``0`` after the UI function ran, or ``1``
        when Streamlit cannot be imported and the UI never started.
    """
    parser = argparse.ArgumentParser(description="BodyCompMouseDXA MVP")
    parser.add_argument("--demo", action="store_true",
                        help="Run a CLI demo (synthetic data + ANCOVA + trajectory).")
    args, _unknown = parser.parse_known_args(argv)
    if args.demo:
        return run_demo_cli()

    run_streamlit_ui()
    # run_streamlit_ui() only prints a hint and returns when Streamlit is
    # missing; surface that as a non-zero status instead of reporting success.
    return 0 if _try_import_streamlit() is not None else 1


# Script entry point: run main() and use its return value as the exit status.
if __name__ == "__main__":
    sys.exit(main())
