"""
main.py
=======
RodentCGMTel — Streamlit 진입점 + CLI demo 모드.

실행:
    streamlit run main.py            # 웹 UI
    python3 main.py --demo           # CLI 분석 (서버 preview/캡처용 stdout)
    python3 main.py --list-models    # 등록된 species/model 키 목록
    python3 main.py --help

DISCLAIMER: 연구·참고용. 임상의사결정용 아님. IACUC 승인 동물실험 데이터 사후 분석용.
"""

from __future__ import annotations

import argparse
import io
import os
import sys
import textwrap
from typing import Optional

import numpy as np
import pandas as pd

_HERE = os.path.dirname(os.path.abspath(__file__))
if _HERE not in sys.path:
    sys.path.insert(0, _HERE)

import analysis  # noqa: E402
from analysis import (  # noqa: E402
    FORMAT_LOADERS,
    compute_metrics_for_animal,
    compute_metrics_cohort,
    analyze_challenge,
    cohort_challenge_summary,
    compare_groups,
    cohort_diurnal_matrix,
    drug_action_window,
)
from species_reference import (  # noqa: E402
    SPECIES_REFERENCE,
    CHALLENGE_PROTOCOLS,
    DEFAULT_LIGHT_CYCLE,
    list_models,
)


# ------------------------------------------------------------------
# Streamlit lazy import
# ------------------------------------------------------------------
def _load_raw_upload(uploaded, fmt: str, animal_id: str, model: str,
                     group: str) -> pd.DataFrame:
    """Parse an uploaded raw telemetry file with the loader registered for *fmt*.

    The ``FORMAT_LOADERS`` entries take a filesystem path, so the upload is
    written to a private temporary file that is always deleted afterwards.
    A fresh temp file per call (instead of a fixed ``.upload.tmp`` next to
    this script) keeps concurrent Streamlit sessions from overwriting each
    other's uploads and leaves nothing behind in the source tree.

    Args:
        uploaded: object returned by ``st.file_uploader``; its full payload is
            read with ``getvalue()``.
        fmt: key into ``FORMAT_LOADERS``. The ``"standard"`` loader is called
            with the path only; every other loader also receives the animal
            id, model and group entered in the UI.
        animal_id: forwarded to non-standard loaders.
        model: forwarded to non-standard loaders.
        group: forwarded to non-standard loaders.

    Returns:
        Whatever the selected loader returns (used as the cohort DataFrame).

    Raises:
        Any exception raised by the loader propagates to the caller.
    """
    import tempfile

    # Keep the upload's extension in case a loader looks at it; default ".tmp".
    suffix = os.path.splitext(getattr(uploaded, "name", "") or "")[1] or ".tmp"
    fd, tmp_path = tempfile.mkstemp(prefix="rodentcgmtel_", suffix=suffix)
    try:
        with os.fdopen(fd, "wb") as fh:
            # getvalue() returns the whole payload regardless of the current
            # read position, unlike read().
            fh.write(uploaded.getvalue())
        loader = FORMAT_LOADERS[fmt]
        if fmt == "standard":
            return loader(tmp_path)
        return loader(tmp_path, animal_id, model, group)
    finally:
        try:
            os.remove(tmp_path)
        except OSError:
            pass  # best-effort cleanup; never mask the loader's result or error


def _nearest_samples(window: pd.DataFrame, dose_ts: pd.Timestamp,
                     sampling) -> tuple[list, list]:
    """For each prescribed sampling minute, take the BG of the closest sample.

    Args:
        window: rows with ``timestamp`` and ``bg_mgdl`` columns, already
            restricted to the challenge window.
        dose_ts: dose time; the offsets in *sampling* are relative to it.
        sampling: prescribed sampling offsets in minutes.

    Returns:
        ``(bgs, minutes)`` as parallel lists. Both are empty when *window* is
        empty. With a sparse trace the same sample may be chosen for several
        offsets.
    """
    if window.empty:
        return [], []
    stamps = window["timestamp"]
    values = window["bg_mgdl"]
    bgs, minutes = [], []
    for m in sampling:
        target = dose_ts + pd.Timedelta(minutes=m)
        # Positional index of the sample with the smallest |t - target|.
        pos = int((stamps - target).abs().argmin())
        bgs.append(values.iloc[pos])
        minutes.append(m)
    return bgs, minutes


def _streamlit_app() -> None:
    """Render the RodentCGMTel Streamlit UI for one script run.

    Layout: sidebar (data source, reference-range override, light cycle),
    cohort loading, then five tabs — overview, per-animal metrics, challenge
    (GTT/ITT/MTT) analysis, group comparison, and manuscript export.
    Plots are drawn only when matplotlib is importable. Returns early, after
    showing a message, when no cohort is loaded.
    """
    import streamlit as st

    try:
        import matplotlib.pyplot as plt
        has_mpl = True
    except ImportError:
        has_mpl = False

    def render(fig) -> None:
        # pyplot keeps every figure in a global registry until it is closed;
        # without plt.close() each Streamlit rerun would leak its figures.
        st.pyplot(fig)
        plt.close(fig)

    st.set_page_config(page_title="RodentCGMTel", page_icon="🐭", layout="wide")
    st.title("RodentCGMTel — 로덴트씨지엠텔")
    st.caption(
        "Continuous glucose telemetry analyzer for rodent metabolic studies. "
        "Ponemah / Empatica / Medtronic Guardian / Eversense / Dexcom G6 rodent raw → "
        "species-adapted metrics + challenge analysis + cohort statistics."
    )
    st.warning(
        "⚠️ 연구·참고용. 임상의사결정용 아님. IACUC 승인 동물실험 데이터 사후 분석용.",
    )

    # ---- Sidebar: data source, reference range, light cycle ----
    with st.sidebar:
        st.header("① Data ingest")
        source_mode = st.radio(
            "Source",
            ["Synthetic demo data", "Upload standard CSV", "Upload raw (multi-format)"],
            index=0,
        )
        st.divider()
        st.header("② Reference range")
        override_model = st.selectbox(
            "Force model reference (override file column)",
            ["(auto from data)"] + list_models(),
            index=0,
        )
        custom_tir = st.checkbox("Custom TIR override")
        if custom_tir:
            tir_low = st.number_input("TIR low (mg/dL)", 30.0, 300.0, 80.0)
            tir_high = st.number_input("TIR high (mg/dL)", 100.0, 500.0, 180.0)
        else:
            tir_low, tir_high = None, None

        st.divider()
        st.header("③ Light cycle")
        lights_on = st.slider("Lights ON (h)", 0, 23, DEFAULT_LIGHT_CYCLE["lights_on_hr"])
        lights_off = st.slider("Lights OFF (h)", 0, 23, DEFAULT_LIGHT_CYCLE["lights_off_hr"])

    # ---- Load data ----
    df = None
    if source_mode == "Synthetic demo data":
        synth_dir = os.path.join(_HERE, "data", "synthetic")
        files = sorted([f for f in os.listdir(synth_dir) if f.endswith(".csv")]) if os.path.isdir(synth_dir) else []
        if not files:
            st.error("No synthetic data. Run `python3 data/synthetic_generator.py` first.")
            return
        picks = st.multiselect("Select cohort CSV(s)", files, default=files)
        if picks:
            df = pd.concat([pd.read_csv(os.path.join(synth_dir, p), parse_dates=["timestamp"]) for p in picks],
                           ignore_index=True)

    elif source_mode == "Upload standard CSV":
        ups = st.file_uploader("Upload standard schema CSV(s)", type=["csv"], accept_multiple_files=True)
        if ups:
            df = pd.concat([pd.read_csv(u, parse_dates=["timestamp"]) for u in ups], ignore_index=True)

    else:  # raw multi-format
        fmt = st.selectbox("Raw format", list(FORMAT_LOADERS.keys()))
        aid = st.text_input("animal_id", "Animal-01")
        mdl = st.selectbox("model", list_models())
        grp = st.text_input("group", "exp")
        up = st.file_uploader(f"Upload {fmt} file", type=["csv", "txt", "pnm"])
        if up:
            try:
                df = _load_raw_upload(up, fmt, aid, mdl, grp)
            except Exception as e:  # noqa: BLE001 — any loader failure is reported in the UI
                st.error(f"Load failed: {e}")
                df = None

    if df is None or df.empty:
        st.info("Load a cohort to begin.")
        return

    st.success(f"Loaded {len(df):,} samples · {df['animal_id'].nunique()} animals · "
               f"{df['model'].nunique()} models · {df['group'].nunique()} groups")

    # Apply model override
    if override_model != "(auto from data)":
        df["model"] = override_model

    # Same TIR override for every tab; only used when both bounds are set.
    tir_range = (tir_low, tir_high) if (tir_low and tir_high) else None

    # ---- Tabs ----
    tab_overview, tab_metrics, tab_challenge, tab_cohort, tab_export = st.tabs(
        ["Overview", "Per-animal metrics", "Challenge (GTT/ITT/MTT)", "Cohort comparison", "Export"]
    )

    # ----- Overview -----
    with tab_overview:
        st.subheader("Cohort overview")
        c1, c2, c3, c4 = st.columns(4)
        c1.metric("Animals", df["animal_id"].nunique())
        c2.metric("Samples", f"{len(df):,}")
        duration = (df["timestamp"].max() - df["timestamp"].min()).total_seconds() / 3600.0
        c3.metric("Span (h)", f"{duration:.1f}")
        c4.metric("Models", df["model"].nunique())

        if has_mpl:
            st.subheader("Trace (first 3 animals)")
            fig, ax = plt.subplots(figsize=(10, 4))
            # "First 3" = the first three ids in order of appearance; computed
            # once instead of re-listing the unique ids for every group.
            first_ids = df["animal_id"].unique()[:3]
            shown = df[df["animal_id"].isin(first_ids)]
            for animal, trace in shown.groupby("animal_id"):
                trace = trace.sort_values("timestamp")
                ax.plot(trace["timestamp"], trace["bg_mgdl"], label=str(animal), linewidth=0.7)
            ax.set_xlabel("Time")
            ax.set_ylabel("BG (mg/dL)")
            ax.legend(fontsize=7)
            render(fig)

    # ----- Metrics -----
    with tab_metrics:
        st.subheader("Per-animal glycemic metrics")
        # Computed once here and reused by the cohort and export tabs, which
        # previously recomputed it with identical arguments.
        metrics_df = compute_metrics_cohort(df, tir_range=tir_range,
                                            lights_on=lights_on, lights_off=lights_off)
        st.dataframe(metrics_df, use_container_width=True)

        if has_mpl and not metrics_df.empty:
            st.subheader("TIR/TAR/TBR by animal")
            fig, ax = plt.subplots(figsize=(10, 4))
            mids = metrics_df.set_index("animal_id")
            x = np.arange(len(mids))
            # Stacked bars: TBR at the bottom, TIR on top of it, TAR on top of both.
            ax.bar(x, mids["tbr_pct"], label="TBR", color="#3b82f6")
            ax.bar(x, mids["tir_pct"], bottom=mids["tbr_pct"], label="TIR", color="#10b981")
            ax.bar(x, mids["tar_pct"], bottom=mids["tbr_pct"] + mids["tir_pct"],
                   label="TAR", color="#ef4444")
            ax.set_xticks(x)
            ax.set_xticklabels(mids.index, rotation=60, fontsize=7)
            ax.set_ylabel("% time")
            ax.legend()
            render(fig)

    # ----- Challenge -----
    with tab_challenge:
        st.subheader("Challenge analysis (per-animal slice)")
        protocol = st.selectbox("Protocol", list(CHALLENGE_PROTOCOLS.keys()))
        proto = CHALLENGE_PROTOCOLS[protocol]
        st.caption(f"{proto['name']} · {proto.get('notes', '')}")
        ch_animal = st.selectbox("Animal", sorted(df["animal_id"].unique()))
        sub = df[df["animal_id"] == ch_animal].sort_values("timestamp").reset_index(drop=True)
        dose_ts = None  # stays None only when the selected animal has no samples
        if not sub.empty:
            dose_default = sub["timestamp"].iloc[len(sub) // 4]
            dose_time = st.text_input("Dose timestamp (ISO)", value=str(dose_default))
            try:
                dose_ts = pd.to_datetime(dose_time)
            except (ValueError, TypeError, OverflowError):
                dose_ts = dose_default
            if pd.isna(dose_ts):
                # e.g. an emptied text box parses to NaT instead of raising.
                dose_ts = dose_default
            sampling = proto["sampling_minutes"]
            window_end = dose_ts + pd.Timedelta(minutes=max(sampling))
            window = sub[(sub["timestamp"] >= dose_ts) & (sub["timestamp"] <= window_end)]
            bgs, t_used = _nearest_samples(window, dose_ts, sampling)
            if len(bgs) >= 2:
                res = analyze_challenge(bgs, t_used, animal_id=ch_animal,
                                        model=str(sub["model"].iloc[0]),
                                        challenge_type=protocol)
                st.json(res.__dict__)
                if has_mpl:
                    fig, ax = plt.subplots(figsize=(7, 3))
                    ax.plot(t_used, bgs, "o-", color="#0ea5e9")
                    ax.axhline(res.baseline_bg, color="gray", linestyle="--", label="baseline")
                    ax.set_xlabel("Time (min)")
                    ax.set_ylabel("BG (mg/dL)")
                    ax.set_title(f"{protocol} — {ch_animal}")
                    ax.legend()
                    render(fig)
            else:
                st.warning("Not enough samples in challenge window.")

        st.subheader("Drug action time-window (window=6h)")
        if dose_ts is not None:
            st.json(drug_action_window(sub, dose_ts))

    # ----- Cohort comparison -----
    with tab_cohort:
        st.subheader("Group comparison")
        metric_choice = st.selectbox(
            "Metric",
            ["mean_bg", "tir_pct", "tar_pct", "tbr_pct", "cv_pct", "mage", "modd",
             "dawn_delta", "nocturnal_mean", "diurnal_mean"],
        )
        # An empty metrics frame may lack the "group" column entirely.
        if "group" not in metrics_df.columns or metrics_df["group"].nunique() < 2:
            st.warning("Need ≥2 groups for comparison.")
        else:
            st.json(compare_groups(metrics_df, metric_choice))
        st.subheader("Cohort diurnal heat map (30-min bins)")
        mat = cohort_diurnal_matrix(df, bin_minutes=30)
        if has_mpl and not mat.empty:
            fig, ax = plt.subplots(figsize=(10, max(3, 0.25 * len(mat))))
            im = ax.imshow(mat.values, aspect="auto", cmap="RdYlBu_r")
            # Columns are formatted as minute-of-day bin starts (HH:MM); every 4th is labelled.
            ax.set_xticks(np.arange(len(mat.columns))[::4])
            ax.set_xticklabels([f"{int(c) // 60:02d}:{int(c) % 60:02d}" for c in mat.columns[::4]],
                               rotation=45, fontsize=7)
            ax.set_yticks(np.arange(len(mat.index)))
            ax.set_yticklabels(mat.index, fontsize=7)
            ax.set_xlabel("hour of day")
            fig.colorbar(im, ax=ax, label="BG (mg/dL)")
            render(fig)

    # ----- Export -----
    with tab_export:
        st.subheader("Manuscript-ready export (Diabetes / Cell Metabolism format)")
        # Korean-language markdown summary of the cohort.
        buf = io.StringIO()
        buf.write("# RodentCGMTel 코호트 요약\n\n")
        buf.write(f"- 동물 수: {df['animal_id'].nunique()}\n")
        # astype(str) so non-string model values cannot break sorted()/join().
        buf.write(f"- 모델: {', '.join(sorted(df['model'].astype(str).unique()))}\n")
        buf.write(f"- 관측 기간: {df['timestamp'].min()} ~ {df['timestamp'].max()}\n\n")
        if not metrics_df.empty:
            buf.write("## 그룹별 평균 ± SD\n\n")
            for group_name, members in metrics_df.groupby("group"):
                buf.write(f"### {group_name} (n={len(members)})\n")
                for col in ["mean_bg", "tir_pct", "tar_pct", "tbr_pct", "cv_pct", "mage", "modd"]:
                    if col in members:
                        buf.write(f"- {col}: {members[col].mean():.2f} ± {members[col].std():.2f}\n")
                buf.write("\n")
        report = buf.getvalue()
        st.text_area("Korean manuscript-ready report", report, height=300)
        st.download_button("Download report.md", report, file_name="rodentcgmtel_report.md")
        st.download_button("Download metrics.csv", metrics_df.to_csv(index=False),
                           file_name="rodentcgmtel_metrics.csv")


# ------------------------------------------------------------------
# CLI demo mode
# ------------------------------------------------------------------
def run_demo_cli() -> int:
    """Run the synthetic-cohort analysis and print a summary to stdout.

    Loads every ``*.csv`` under ``data/synthetic`` next to this file. If the
    directory exists but holds no CSVs, ``data.synthetic_generator.main`` is
    run once first. Then prints:

    - the per-animal metric row count;
    - group means of the core metrics;
    - a simulated IP-GTT for the first animal;
    - a TIR% group comparison.

    Returns:
        Process exit code: 0 on success, 2 when no synthetic data is available.
    """
    print("=" * 70)
    print("RodentCGMTel — CLI demo (synthetic cohort analysis)")
    print("DISCLAIMER: 연구·참고용. 임상의사결정용 아님. IACUC 데이터 사후 분석용.")
    print("=" * 70)

    synth_dir = os.path.join(_HERE, "data", "synthetic")
    if not os.path.isdir(synth_dir):
        print(f"[!] No synthetic dir at {synth_dir}")
        print("    Run: python3 data/synthetic_generator.py")
        return 2

    def _list_csvs() -> list:
        # Single definition of "which files count as cohort data".
        return sorted(f for f in os.listdir(synth_dir) if f.endswith(".csv"))

    files = _list_csvs()
    if not files:
        print("[!] No synthetic CSVs. Running generator first.")
        from data import synthetic_generator as sg  # type: ignore
        sg.main(synth_dir)
        files = _list_csvs()
        if not files:
            # Without this guard pd.concat([]) below raises "No objects to concatenate".
            print(f"[!] Generator produced no CSVs in {synth_dir}")
            return 2

    df = pd.concat(
        [pd.read_csv(os.path.join(synth_dir, f), parse_dates=["timestamp"]) for f in files],
        ignore_index=True,
    )
    print(f"[+] Loaded {len(df):,} samples · {df['animal_id'].nunique()} animals · "
          f"{df['model'].nunique()} models")

    metrics_df = compute_metrics_cohort(df)
    print()
    print(f"[+] Per-animal metrics: {len(metrics_df)} rows")
    print()
    print("=== 그룹별 평균 (Korean summary) ===")
    grouped = metrics_df.groupby("group")[
        ["mean_bg", "tir_pct", "tar_pct", "tbr_pct", "cv_pct", "mage", "modd", "dawn_delta"]
    ].mean().round(2)
    print(grouped.to_string())

    # Simulated IP-GTT (2 g/kg) for the first animal by sorted id. The baseline
    # is one observed BG value. The excursion is synthetic: +110 mg/dL at
    # 15 min, a +160 peak at 30 min, and +10 by 180 min.
    print()
    print("=== 모의 GTT(IP, 2g/kg) — 첫 동물 ===")
    first_aid = sorted(df["animal_id"].unique())[0]
    sub = df[df["animal_id"] == first_aid].sort_values("timestamp").reset_index(drop=True)
    # Sample #100 serves as the baseline. Clamp to the last sample so a short
    # trace does not raise IndexError.
    base = float(sub["bg_mgdl"].iloc[min(100, len(sub) - 1)])
    sim = [base, base + 110, base + 160, base + 130, base + 80, base + 40, base + 10]
    res = analyze_challenge(sim, [0, 15, 30, 60, 90, 120, 180],
                            animal_id=first_aid, model=str(sub["model"].iloc[0]),
                            challenge_type="GTT_ip")
    for k, v in res.__dict__.items():
        print(f"  {k}: {v}")

    # Group comparison on TIR%.
    print()
    print("=== Group comparison: TIR% ===")
    if metrics_df["group"].nunique() >= 2:
        cmp_res = compare_groups(metrics_df, "tir_pct")
        print(f"  method: {cmp_res['method']}")
        print(f"  ANOVA p-value: {cmp_res['p_value']}")
        print(f"  groups: {cmp_res['groups']}")
        print("  pairwise (top 3):")
        for r in cmp_res["pairwise"][:3]:
            print(f"    {r['group_a']} vs {r['group_b']}: p={r['p_value']:.4f} (bonf={r['p_bonferroni']:.4f})")
    else:
        print("  insufficient groups")

    print()
    print("[OK] demo complete")
    return 0


# ------------------------------------------------------------------
# Entry
# ------------------------------------------------------------------
def _is_streamlit_runtime() -> bool:
    """streamlit run 으로 실행되었는지 감지."""
    try:
        from streamlit.runtime.scriptrunner import get_script_run_ctx
        return get_script_run_ctx() is not None
    except Exception:  # noqa: BLE001
        return False


def main(argv: Optional[list] = None) -> int:
    """Command-line entry point.

    Flags:
        --demo         run the synthetic-cohort CLI analysis.
        --list-models  print each registered species/model key, one per line.

    Unrecognised arguments are ignored (``parse_known_args``). With neither
    flag, the Streamlit app is rendered if a Streamlit script run is active;
    otherwise a short usage message is printed.

    Args:
        argv: argument list to parse; ``None`` means ``sys.argv[1:]``.

    Returns:
        Process exit code.
    """
    cli = argparse.ArgumentParser(
        prog="RodentCGMTel",
        description=textwrap.dedent(__doc__ or ""),
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )
    cli.add_argument(
        "--demo",
        action="store_true",
        help="CLI 모드: 합성 데이터 분석 후 stdout 한국어 요약",
    )
    cli.add_argument(
        "--list-models",
        action="store_true",
        help="등록된 species/model 키 목록 출력",
    )
    opts = cli.parse_known_args(argv)[0]

    if opts.list_models:
        for key in list_models():
            print(key)
        return 0

    if opts.demo:
        return run_demo_cli()

    # No flag: render the app inside Streamlit, otherwise explain how to run it.
    if _is_streamlit_runtime():
        _streamlit_app()
        return 0

    usage_lines = (
        "RodentCGMTel — use one of:",
        "  streamlit run main.py",
        "  python3 main.py --demo",
        "  python3 main.py --list-models",
    )
    for line in usage_lines:
        print(line)
    return 0


# Entry dispatch. `streamlit run main.py` executes this file inside the
# Streamlit runtime, where the app is rendered directly. A plain
# `python3 main.py ...` goes through the argparse CLI instead.
if not _is_streamlit_runtime():
    if __name__ == "__main__":
        sys.exit(main())
else:
    _streamlit_app()
