"""ObesityPreprintRadar — Streamlit standalone 연구 알림 도구.

참고용·연구용 / 임상결정 대체 금지.
모든 데이터는 합성(synthetic). 외부 네트워크 호출 없음.
"""

from __future__ import annotations

import argparse
import io
import json
import os
import re
import sqlite3
import sys
from collections import Counter, defaultdict
from datetime import date, datetime, timezone
from itertools import combinations
from typing import Any, Dict, Iterable, List, Optional, Tuple

# Paths are resolved relative to this file, so the app works from any CWD.
ROOT = os.path.dirname(os.path.abspath(__file__))
# Directory holding the synthetic JSON fixtures read by the load_* functions
# (preprints.json, publications.json, topics.json, kol_seed.json).
DATA_DIR = os.path.join(ROOT, "data")
# SQLite file backing the user watchlist (opened by watchlist_conn).
WATCHLIST_DB = os.path.join(ROOT, "watchlist.sqlite")

# Research-use / synthetic-data notice shown in the CLI summary, the Streamlit
# UI, the Korean digest and the trend docx export. English meaning:
# "This is a reference/research tool. Do not use it in place of clinical
#  decisions. All data are synthetic and are not real preprints or papers."
DISCLAIMER = (
    "참고용·연구용 도구입니다. 임상 결정 대체 금지. "
    "모든 데이터는 합성(synthetic) 자료이며 실제 preprint·논문이 아닙니다."
)

# ---------------------------------------------------------------------------
# Data loading
# ---------------------------------------------------------------------------

def _load_json(name: str) -> Any:
    """Read ``DATA_DIR/name`` as UTF-8 text and return the parsed JSON value."""
    full_path = os.path.join(DATA_DIR, name)
    with open(full_path, encoding="utf-8") as handle:
        parsed = json.load(handle)
    return parsed


def load_preprints() -> List[Dict[str, Any]]:
    """Return the ``preprints`` array from ``preprints.json`` (empty list if the key is absent)."""
    payload = _load_json("preprints.json")
    items = payload.get("preprints", [])
    return [item for item in items]


def load_publications() -> List[Dict[str, Any]]:
    """Return the ``publications`` array from ``publications.json`` (empty list if the key is absent)."""
    document = _load_json("publications.json")
    entries = document.get("publications", [])
    return list(entries)


def load_topics() -> List[Dict[str, Any]]:
    """Return the topic ``label_rules`` array from ``topics.json`` (empty list if the key is absent)."""
    config = _load_json("topics.json")
    return [rule for rule in config.get("label_rules", [])]


def load_kols() -> List[Dict[str, Any]]:
    """Return the ``kols`` seed array from ``kol_seed.json`` (empty list if the key is absent)."""
    seed = _load_json("kol_seed.json")
    kols = seed.get("kols", [])
    return list(kols)


# ---------------------------------------------------------------------------
# Scoring / labeling
# ---------------------------------------------------------------------------

def _version_number(raw: Any) -> int:
    """Best-effort integer version from values like ``2``, ``"2"`` or ``"v3"``; defaults to 1."""
    if raw is None:
        return 1
    try:
        return int(raw)
    except (TypeError, ValueError):
        # Accept "v2"-style strings instead of crashing the whole enrich() pass.
        match = re.match(r"\s*v?(\d+)", str(raw), re.IGNORECASE)
        return int(match.group(1)) if match else 1


def availability_score(pp: Dict[str, Any]) -> int:
    """Return a 0-5 data/code availability score for one preprint.

    One point each for:

    * at least one data link,
    * at least one code link,
    * at least one protocol link,
    * a data link containing a Zenodo or OSF persistent identifier
      (``"zenodo"`` / ``"osf.io"`` substring, case-insensitive),
    * version 2 or later (evidence of refinement).

    Missing or ``None`` link lists count as empty. ``version`` may be an int,
    a numeric string, or a ``"v2"``-style string. A missing or unparseable
    version is treated as 1 instead of raising.
    """
    data_links = pp.get("data_links") or []
    code_links = pp.get("code_links") or []
    protocol_links = pp.get("protocol_links") or []

    score = sum(1 for links in (data_links, code_links, protocol_links) if links)
    # str() guards against non-string entries (e.g. null) in the JSON list.
    if any(
        "zenodo" in str(link).lower() or "osf.io" in str(link).lower()
        for link in data_links
    ):
        score += 1
    if _version_number(pp.get("version", 1)) >= 2:
        score += 1
    # Five independent one-point checks; the clamp is purely defensive.
    return min(score, 5)


def label_preprint(pp: Dict[str, Any], rules: List[Dict[str, Any]]) -> List[str]:
    """Assign topic labels to a preprint by case-insensitive keyword substring match.

    The searched text is the title, abstract, author names and affiliations
    joined together. Each rule looks like ``{"label": str, "keywords": [str, ...]}``.
    A rule fires when any of its keywords occurs in that text.

    Returns:
        The matching labels in rule order without duplicates, or
        ``["uncategorized"]`` when no rule fires.

    Blank keywords are ignored. ``"" in text`` is always true, so a single
    empty keyword would otherwise tag every preprint with that rule's label.
    ``None`` title/abstract/author/affiliation fields are treated as empty.
    """
    text = " ".join(
        [
            pp.get("title") or "",
            pp.get("abstract") or "",
            " ".join(pp.get("authors") or []),
            " ".join(pp.get("affiliations") or []),
        ]
    ).lower()
    labels: List[str] = []
    for rule in rules:
        keywords = [
            k.lower()
            for k in (rule.get("keywords") or [])
            if isinstance(k, str) and k.strip()
        ]
        # De-duplicate so two rules sharing a label don't double-count it downstream.
        if any(k in text for k in keywords) and rule["label"] not in labels:
            labels.append(rule["label"])
    return labels or ["uncategorized"]


def enrich(preprints: List[Dict[str, Any]], rules: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
    """Return shallow copies of *preprints* with ``availability`` and ``labels`` added.

    The input dicts themselves are not modified.
    """
    return [
        {
            **pp,
            "availability": availability_score(pp),
            "labels": label_preprint(pp, rules),
        }
        for pp in preprints
    ]


# ---------------------------------------------------------------------------
# Dedup
# ---------------------------------------------------------------------------

def _norm_title(t: str) -> str:
    t = t.lower()
    t = re.sub(r"[^a-z0-9 ]+", " ", t)
    return re.sub(r"\s+", " ", t).strip()


def _token_set(t: str) -> set:
    """Return the distinct whitespace-separated tokens of the normalized title."""
    normalized = _norm_title(t)
    return {token for token in normalized.split()}


def detect_duplicates(preprints: List[Dict[str, Any]], threshold: float = 0.72) -> List[Tuple[str, str, float]]:
    """Find likely cross-server duplicates by title-token Jaccard similarity.

    Every pair of preprints posted on *different* servers is compared. A pair
    whose title token sets have a Jaccard index ``>= threshold`` is reported
    as ``(id_a, id_b, jaccard)`` with the score rounded to 3 decimals. Pairs
    follow input order (``id_a`` precedes ``id_b`` in *preprints*). Titles
    that normalize to no tokens (or are missing) are never reported.
    """
    # Tokenize each title once up front instead of twice per pair.
    token_sets = [_token_set(p.get("title") or "") for p in preprints]
    dup: List[Tuple[str, str, float]] = []
    for i, j in combinations(range(len(preprints)), 2):
        a, b = preprints[i], preprints[j]
        if a["server"] == b["server"]:
            continue
        sa, sb = token_sets[i], token_sets[j]
        if not sa or not sb:
            continue
        jac = len(sa & sb) / len(sa | sb)
        if jac >= threshold:
            dup.append((a["id"], b["id"], round(jac, 3)))
    return dup


# ---------------------------------------------------------------------------
# Publication lag
# ---------------------------------------------------------------------------

def _parse_date(s: str) -> Optional[date]:
    try:
        return datetime.strptime(s, "%Y-%m-%d").date()
    except Exception:
        return None


def compute_publication_lag(
    preprints: List[Dict[str, Any]], publications: List[Dict[str, Any]]
) -> List[Dict[str, Any]]:
    """Join preprints to journal publications and compute the posting-to-publication lag.

    Publications are matched on ``preprint_id == preprint["id"]``. If several
    share an id, the last one wins. Every preprint yields exactly one row:

    * ``status`` is ``"published"`` or ``"unpublished"``.
    * ``lag_days`` is ``published_date - posted_date`` in days. It is ``None``
      when the preprint is unpublished or when either date fails to parse.
    * ``published_date`` / ``journal`` / ``pmid`` are ``None`` when unpublished.
    """
    pub_lookup = {pub["preprint_id"]: pub for pub in publications}
    result: List[Dict[str, Any]] = []
    for item in preprints:
        match = pub_lookup.get(item["id"])
        posted_on = _parse_date(item["posted_date"])
        lag = None
        if match:
            published_on = _parse_date(match["published_date"])
            if posted_on and published_on:
                lag = (published_on - posted_on).days
        result.append(
            {
                "preprint_id": item["id"],
                "title": item["title"],
                "server": item["server"],
                "posted_date": item["posted_date"],
                "published_date": match["published_date"] if match else None,
                "journal": match["journal"] if match else None,
                "pmid": match.get("pmid") if match else None,
                "lag_days": lag,
                "status": "published" if match else "unpublished",
            }
        )
    return result


# ---------------------------------------------------------------------------
# Watchlist (sqlite)
# ---------------------------------------------------------------------------

def _init_watchlist(conn: sqlite3.Connection) -> None:
    conn.execute(
        "CREATE TABLE IF NOT EXISTS watchlist ("
        "id INTEGER PRIMARY KEY AUTOINCREMENT, "
        "kind TEXT NOT NULL, "
        "value TEXT NOT NULL UNIQUE, "
        "created_at TEXT NOT NULL)"
    )
    conn.commit()


def watchlist_conn(path: str = WATCHLIST_DB) -> sqlite3.Connection:
    """Open the watchlist SQLite database at *path* and ensure its table exists.

    If schema creation fails (e.g. *path* is not a SQLite database), the
    connection is closed before the error propagates, so it is not leaked.
    """
    conn = sqlite3.connect(path)
    try:
        _init_watchlist(conn)
    except BaseException:
        conn.close()
        raise
    return conn


def add_watch(conn: sqlite3.Connection, kind: str, value: str) -> bool:
    """Insert a watchlist entry and commit.

    *value* is stripped before storage. ``created_at`` is a naive ISO-8601
    UTC timestamp.

    Returns:
        True if the row was added. False if it was rejected: blank values
        (``watch_match_score`` skips them, so they could never match), values
        that already exist (``value`` is UNIQUE), or other constraint
        violations.
    """
    cleaned = value.strip()
    if not cleaned:
        return False
    # datetime.utcnow() is deprecated (3.12+). Dropping tzinfo keeps the same
    # naive "YYYY-MM-DDTHH:MM:SS.ffffff" format as rows written previously.
    created_at = datetime.now(timezone.utc).replace(tzinfo=None).isoformat()
    try:
        conn.execute(
            "INSERT INTO watchlist (kind, value, created_at) VALUES (?, ?, ?)",
            (kind, cleaned, created_at),
        )
    except sqlite3.IntegrityError:
        return False
    conn.commit()
    return True


def remove_watch(conn: sqlite3.Connection, watch_id: int) -> None:
    """Delete the watchlist row with primary key *watch_id* and commit (no-op if absent)."""
    statement = "DELETE FROM watchlist WHERE id = ?"
    params = (watch_id,)
    conn.execute(statement, params)
    conn.commit()


def list_watches(conn: sqlite3.Connection) -> List[Tuple[int, str, str]]:
    """Return every watchlist entry as an ``(id, kind, value)`` tuple, ordered by id."""
    fetched = conn.execute("SELECT id, kind, value FROM watchlist ORDER BY id").fetchall()
    return [(watch_id, kind, value) for watch_id, kind, value in fetched]


def watch_match_score(pp: Dict[str, Any], watches: Iterable[Tuple[int, str, str]]) -> int:
    """Score how strongly a preprint matches the watchlist.

    Matching is a case-insensitive substring test. Points awarded per watch:

    * ``keyword`` found in the title, abstract or labels: 2 points.
    * ``author`` found in the author names: 3 points.
    * ``affiliation`` found in the affiliations: 3 points.

    Blank watch values and unknown kinds contribute nothing.
    """
    haystacks = {
        "keyword": " ".join(
            [pp.get("title", ""), pp.get("abstract", ""), " ".join(pp.get("labels", []) or [])]
        ).lower(),
        "author": " ".join(pp.get("authors", []) or []).lower(),
        "affiliation": " ".join(pp.get("affiliations", []) or []).lower(),
    }
    weights = {"keyword": 2, "author": 3, "affiliation": 3}
    total = 0
    for _watch_id, kind, value in watches:
        needle = value.strip().lower()
        if needle and kind in weights and needle in haystacks[kind]:
            total += weights[kind]
    return total


# ---------------------------------------------------------------------------
# Trend
# ---------------------------------------------------------------------------

def monthly_trend(
    preprints: List[Dict[str, Any]], publications: List[Dict[str, Any]]
) -> List[Dict[str, Any]]:
    """Aggregate preprints per (posting month, topic label).

    The month is the first 7 characters of ``posted_date``. A preprint with
    several labels is counted once under each. One without a ``labels`` key
    is counted under ``"uncategorized"``.

    Each row holds, for its ``(month, label)`` pair:

    * the preprint count,
    * the mean availability (2 dp),
    * how many of those preprints have a matching publication,
    * that published share (2 dp).

    Rows are sorted by month, then label.
    """
    published_ids = {pub["preprint_id"] for pub in publications}
    counts: Dict[Tuple[str, str], int] = defaultdict(int)
    avail_totals: Dict[Tuple[str, str], Any] = defaultdict(int)
    pub_counts: Dict[Tuple[str, str], int] = defaultdict(int)

    for item in preprints:
        month = item["posted_date"][:7]
        for label in item.get("labels", ["uncategorized"]):
            key = (month, label)
            counts[key] += 1
            avail_totals[key] += item.get("availability", 0)
            if item["id"] in published_ids:
                pub_counts[key] += 1

    result: List[Dict[str, Any]] = []
    for month, label in sorted(counts):
        key = (month, label)
        n = counts[key]
        n_pub = pub_counts[key]
        result.append(
            {
                "month": month,
                "label": label,
                "preprint_count": n,
                "avg_availability": round(avail_totals[key] / n, 2) if n else 0.0,
                "published_count": n_pub,
                "published_ratio": round(n_pub / n, 2) if n else 0.0,
            }
        )
    return result


# ---------------------------------------------------------------------------
# Digest
# ---------------------------------------------------------------------------

def build_digest(
    preprints: List[Dict[str, Any]], top_n: int = 10, lang: str = "ko"
) -> str:
    """Render a Markdown digest of the *top_n* highest-ranked preprints.

    Ranking is by ``availability``, then ``posted_date``, both descending.
    ``lang == "ko"`` produces the Korean layout with the module disclaimer.
    Any other value produces the English layout. Items are separated by a
    blank line.
    """
    ranked = sorted(
        preprints,
        key=lambda item: (item.get("availability", 0), item.get("posted_date", "")),
        reverse=True,
    )[:top_n]

    if lang == "ko":
        title_line = f"# ObesityPreprintRadar 일일 디제스트 (상위 {top_n}건)"
        notice = DISCLAIMER
        authors_tag, posted_tag, topics_tag, avail_tag = "저자", "게시일", "토픽", "Availability score"
    else:
        title_line = f"# ObesityPreprintRadar Daily Digest (Top {top_n})"
        notice = "For research reference only. Not a substitute for clinical decisions. All data are synthetic."
        authors_tag, posted_tag, topics_tag, avail_tag = "Authors", "Posted", "Topics", "Availability"

    entries = []
    for rank, item in enumerate(ranked, start=1):
        topics = ", ".join(item.get("labels", []))
        entries.append(
            f"{rank}. [{item['server']}] {item['title']}\n"
            f"   - {authors_tag}: {', '.join(item['authors'])}\n"
            f"   - {posted_tag}: {item['posted_date']} (v{item.get('version', 1)})\n"
            f"   - {topics_tag}: {topics}\n"
            f"   - {avail_tag}: {item['availability']}/5\n"
            f"   - DOI: {item['doi']}\n"
        )
    return f"{title_line}\n\n_{notice}_\n\n" + "\n".join(entries)


def digest_to_docx(text: str) -> bytes:
    """Convert Markdown-ish digest text into a .docx document and return its bytes.

    Each line is handled on its own:

    * ``#`` to ``#########`` followed by a space becomes a heading of that
      level (1-9).
    * A line wrapped in single underscores (``_text_``) becomes an italic
      paragraph without the underscores. This is how ``build_digest`` emits
      its disclaimer.
    * Anything else, including blank lines, becomes a plain paragraph. Blank
      lines are kept so the spacing between items survives.

    Requires ``python-docx``. It is imported lazily so the CLI works without it.
    """
    from docx import Document

    doc = Document()
    for line in text.split("\n"):
        heading = re.match(r"(#{1,9}) (.*)", line)
        italic = re.fullmatch(r"_([^_](?:.*[^_])?)_", line)
        if heading:
            doc.add_heading(heading.group(2).strip(), level=len(heading.group(1)))
        elif italic:
            doc.add_paragraph().add_run(italic.group(1)).italic = True
        else:
            doc.add_paragraph(line)
    buf = io.BytesIO()
    doc.save(buf)
    return buf.getvalue()


def trend_to_docx(rows: List[Dict[str, Any]]) -> bytes:
    """Render monthly trend rows as a Word table and return the .docx bytes.

    Exported columns are month, label, preprint count, average availability
    and published ratio. ``published_count`` is not exported. Requires
    ``python-docx``.
    """
    from docx import Document

    columns = (
        ("Month", "month"),
        ("Label", "label"),
        ("Preprint #", "preprint_count"),
        ("Avg Avail", "avg_availability"),
        ("Published Ratio", "published_ratio"),
    )
    document = Document()
    document.add_heading("ObesityPreprintRadar 월간 Trend 리포트", level=1)
    document.add_paragraph(DISCLAIMER)
    if rows:
        table = document.add_table(rows=1, cols=len(columns))
        for cell, (heading, _key) in zip(table.rows[0].cells, columns):
            cell.text = heading
        for record in rows:
            cells = table.add_row().cells
            for cell, (_heading, key) in zip(cells, columns):
                cell.text = str(record[key])
    else:
        document.add_paragraph("데이터 없음.")
    out = io.BytesIO()
    document.save(out)
    return out.getvalue()


# ---------------------------------------------------------------------------
# CLI summary
# ---------------------------------------------------------------------------

def cli_summary() -> str:
    """Build the plain-text summary printed for ``--summary`` or a plain ``python`` run.

    Loads the fixtures and enriches the preprints. Reports:

    * the preprint total and per-server counts,
    * mean availability,
    * the five most common topics,
    * matched publications and average lag,
    * up to five cross-server duplicate candidates.
    """
    preprints = load_preprints()
    rules = load_topics()
    pubs = load_publications()
    items = enrich(preprints, rules)
    n_items = len(items)

    by_server = Counter(item["server"] for item in items)
    if n_items:
        mean_avail = round(sum(item["availability"] for item in items) / n_items, 2)
    else:
        mean_avail = 0
    topic_counts: Counter = Counter(lab for item in items for lab in item["labels"])

    lag_rows = compute_publication_lag(items, pubs)
    lags = [row["lag_days"] for row in lag_rows if row["lag_days"] is not None]
    mean_lag = round(sum(lags) / len(lags), 1) if lags else None
    n_published = sum(row["status"] == "published" for row in lag_rows)

    dup_pairs = detect_duplicates(items)

    out = [
        "=== ObesityPreprintRadar CLI Summary ===",
        DISCLAIMER,
        "",
        f"Total preprints: {n_items}",
        "Server distribution:",
    ]
    out.extend(f"  - {srv}: {cnt}" for srv, cnt in sorted(by_server.items()))
    out.append(f"Average availability score: {mean_avail}/5")
    out.append("Top 5 topics:")
    out.extend(f"  - {lab}: {cnt}" for lab, cnt in topic_counts.most_common(5))
    out.append(f"Matched PubMed publications: {n_published}/{n_items}")
    lag_text = mean_lag if mean_lag is not None else "N/A"
    out.append(f"Average preprint->publication lag (days): {lag_text}")
    out.append(f"Cross-server dedup candidates: {len(dup_pairs)}")
    out.extend(f"  - {a} <=> {b} (jaccard={jac})" for a, b, jac in dup_pairs[:5])
    return "\n".join(out)


# ---------------------------------------------------------------------------
# Streamlit UI
# ---------------------------------------------------------------------------

def _render_disclaimer(st_module) -> None:
    """Show the research-use disclaimer as a caption prefixed with a warning emoji."""
    warning_text = ":warning: " + DISCLAIMER
    st_module.caption(warning_text)


def _tab_feed(st: Any, enriched: List[Dict[str, Any]], watches: List[Tuple[int, str, str]]) -> None:
    """Feed tab: preprint cards ranked by watch score, then posting date (first 40 shown)."""
    st.subheader("최신 preprint 피드 (일별 카드)")
    # watch-aware ranking; the score is stored on the per-run enriched dicts.
    for p in enriched:
        p["_watch_score"] = watch_match_score(p, watches)
    feed = sorted(
        enriched, key=lambda p: (p["_watch_score"], p["posted_date"]), reverse=True
    )
    servers = sorted({p["server"] for p in feed})
    chosen_server = st.multiselect("server 필터", servers, default=servers)
    feed = [p for p in feed if p["server"] in chosen_server]
    st.write(f"총 {len(feed)}건")
    for p in feed[:40]:
        with st.container(border=True):
            cols = st.columns([3, 1])
            with cols[0]:
                st.markdown(f"**[{p['server']}] {p['title']}**")
                st.caption(
                    f"{', '.join(p['authors'])} · {p['posted_date']} (v{p.get('version', 1)})"
                )
                st.write(p["abstract"])
                st.caption("토픽: " + ", ".join(p["labels"]))
            with cols[1]:
                st.metric("Availability", f"{p['availability']}/5")
                if p["_watch_score"] > 0:
                    st.success(f"watch score {p['_watch_score']}")
                st.code(p["doi"])
    _render_disclaimer(st)


def _tab_availability(st: Any, pd: Any, enriched: List[Dict[str, Any]]) -> None:
    """Availability tab: filter by minimum score, keyword and topic; show a table."""
    st.subheader("availability scoring 필터")
    min_score = st.slider("최소 availability score", 0, 5, 3)
    keyword = st.text_input("키워드/서브토픽 (선택)")
    all_labels = sorted({lab for p in enriched for lab in p["labels"]})
    chosen_labels = st.multiselect("토픽 필터", all_labels)
    filtered = [p for p in enriched if p["availability"] >= min_score]
    if keyword:
        kw = keyword.lower()
        filtered = [
            p
            for p in filtered
            if kw in p["title"].lower() or kw in p["abstract"].lower()
        ]
    if chosen_labels:
        filtered = [p for p in filtered if any(lab in chosen_labels for lab in p["labels"])]
    df = pd.DataFrame(
        [
            {
                "id": p["id"],
                "server": p["server"],
                "title": p["title"],
                "posted_date": p["posted_date"],
                "version": p.get("version", 1),
                "availability": p["availability"],
                "labels": ", ".join(p["labels"]),
                "doi": p["doi"],
            }
            for p in filtered
        ]
    )
    st.write(f"필터 결과 {len(df)}건")
    st.dataframe(df, use_container_width=True)
    _render_disclaimer(st)


def _tab_pub_tracking(
    st: Any, pd: Any, enriched: List[Dict[str, Any]], pubs: List[Dict[str, Any]]
) -> None:
    """Publication tab: match counts, lag histogram, published/unpublished tables."""
    st.subheader("동일 저자 PubMed 출판 추적")
    # NOTE(review): matching is by preprint id (compute_publication_lag), not by author.
    rows = compute_publication_lag(enriched, pubs)
    # Explicit columns keep "status"/"lag_days" present even with zero preprints;
    # otherwise df["status"] raises KeyError on an empty frame.
    df = pd.DataFrame(
        rows,
        columns=[
            "preprint_id", "title", "server", "posted_date", "published_date",
            "journal", "pmid", "lag_days", "status",
        ],
    )
    published = df[df["status"] == "published"].copy()
    unpublished = df[df["status"] == "unpublished"].copy()
    # lag_days is None when a date failed to parse; exclude those from the stats.
    lags = pd.to_numeric(published["lag_days"], errors="coerce").dropna()

    c1, c2, c3 = st.columns(3)
    c1.metric("총 preprint", len(df))
    c2.metric("매칭 publication", len(published))
    c3.metric(
        "평균 lag (일)",
        f"{lags.mean():.1f}" if not lags.empty else "N/A",
    )

    st.markdown("**Lag 분포 (히스토그램)**")
    if not lags.empty:
        hist = (
            pd.cut(
                lags,
                # Open-ended outer bins: lags > 120 days (common for journals)
                # and negative lags (data errors) used to fall outside the
                # bins and silently vanish from the chart.
                bins=[float("-inf"), -1, 7, 14, 21, 30, 45, 60, 120, float("inf")],
                labels=["<0", "0-7", "8-14", "15-21", "22-30", "31-45", "46-60", "61-120", "121+"],
            )
            .value_counts()
            .sort_index()
        )
        st.bar_chart(hist)
    else:
        st.info("매칭된 publication 없음")

    st.markdown("**Published preprints**")
    st.dataframe(published, use_container_width=True)
    st.markdown("**Unpublished preprints (출판 안 됨 flag)**")
    st.dataframe(unpublished, use_container_width=True)
    _render_disclaimer(st)


def _tab_watchlist(
    st: Any, pd: Any, conn: sqlite3.Connection, kols: List[Dict[str, Any]]
) -> None:
    """Watchlist tab: add form, KOL seed table, and current entries with delete buttons."""
    st.subheader("watchlist (sqlite 저장)")
    with st.form("add_watch", clear_on_submit=True):
        kind = st.selectbox("종류", ["keyword", "author", "affiliation"])
        value = st.text_input("값")
        submitted = st.form_submit_button("추가")
        if submitted and value.strip():
            ok = add_watch(conn, kind, value.strip())
            if ok:
                st.success(f"{kind}: {value} 추가됨")
            else:
                st.warning("이미 존재합니다")

    st.markdown("**KOL seed (참고)**")
    st.dataframe(pd.DataFrame(kols), use_container_width=True)

    st.markdown("**현재 watchlist**")
    cur_watches = list_watches(conn)
    if not cur_watches:
        st.info("등록된 항목 없음")
    else:
        for wid, k, v in cur_watches:
            cols = st.columns([1, 2, 4, 1])
            cols[0].write(wid)
            cols[1].write(k)
            cols[2].write(v)
            if cols[3].button("삭제", key=f"del-{wid}"):
                remove_watch(conn, wid)
                st.rerun()
    _render_disclaimer(st)


def _tab_trend(
    st: Any, pd: Any, enriched: List[Dict[str, Any]], pubs: List[Dict[str, Any]]
) -> None:
    """Trend tab: month x topic table, count/availability charts, and docx export."""
    st.subheader("월간 trend 리포트")
    rows = monthly_trend(enriched, pubs)
    df = pd.DataFrame(rows)
    st.dataframe(df, use_container_width=True)
    if not df.empty:
        pivot_cnt = df.pivot_table(
            index="month", columns="label", values="preprint_count", fill_value=0
        )
        st.markdown("**월별 토픽 preprint 수**")
        st.bar_chart(pivot_cnt)
        pivot_avail = df.pivot_table(
            index="month", columns="label", values="avg_availability", fill_value=0
        )
        st.markdown("**월별 토픽 평균 availability**")
        st.line_chart(pivot_avail)
    docx_bytes = trend_to_docx(rows)
    st.download_button(
        "trend 리포트 docx 다운로드",
        docx_bytes,
        file_name="obesity_preprint_trend.docx",
    )
    _render_disclaimer(st)


def _tab_digest(st: Any, enriched: List[Dict[str, Any]]) -> None:
    """Digest tab: Markdown preview of the top-N digest and docx export."""
    st.subheader("digest 미리보기")
    top_n = st.slider("상위 N건", 5, 30, 10)
    lang = st.radio("언어", ["ko", "en"], horizontal=True)
    text = build_digest(enriched, top_n=top_n, lang=lang)
    st.markdown(text)
    docx_bytes = digest_to_docx(text)
    st.download_button(
        "digest docx 다운로드",
        docx_bytes,
        file_name=f"obesity_preprint_digest_{lang}.docx",
    )
    _render_disclaimer(st)


def run_streamlit() -> None:  # pragma: no cover - UI entry
    """Render the Streamlit UI: header, disclaimer and six feature tabs.

    Fixtures are reloaded and enriched on every script run. The watchlist
    SQLite connection opened here is closed in a ``finally`` block, so it is
    also released when the run is cut short (e.g. by ``st.rerun()``).
    """
    import pandas as pd
    import streamlit as st

    st.set_page_config(page_title="ObesityPreprintRadar", layout="wide")
    st.title("ObesityPreprintRadar")
    st.caption(
        "bioRxiv·medRxiv·Research Square·ChemRxiv 비만/대사 preprint 일일 큐레이션 · 재현성 평가 · 출판 추적"
    )
    _render_disclaimer(st)

    preprints = load_preprints()
    rules = load_topics()
    pubs = load_publications()
    kols = load_kols()
    enriched = enrich(preprints, rules)

    conn = watchlist_conn()
    try:
        # Read once per run. A watch added in the watchlist tab (rendered after
        # the feed) only changes the feed ranking on the next rerun.
        watches = list_watches(conn)

        tabs = st.tabs(
            [
                "최신 preprint 피드",
                "availability scoring 필터",
                "동일 저자 PubMed 추적",
                "watchlist",
                "월간 trend 리포트",
                "digest 미리보기",
            ]
        )
        with tabs[0]:
            _tab_feed(st, enriched, watches)
        with tabs[1]:
            _tab_availability(st, pd, enriched)
        with tabs[2]:
            _tab_pub_tracking(st, pd, enriched, pubs)
        with tabs[3]:
            _tab_watchlist(st, pd, conn, kols)
        with tabs[4]:
            _tab_trend(st, pd, enriched, pubs)
        with tabs[5]:
            _tab_digest(st, enriched)
    finally:
        conn.close()


# ---------------------------------------------------------------------------
# Entry
# ---------------------------------------------------------------------------

def _in_streamlit_runtime() -> bool:
    """Return True when this module is executing inside a Streamlit script run."""
    try:
        from streamlit.runtime.scriptrunner import get_script_run_ctx  # type: ignore

        return get_script_run_ctx() is not None
    except Exception:
        # Best-effort probe: streamlit missing, broken, or its API moved ->
        # treat this as a plain Python run and fall back to the CLI summary.
        return False


def main(argv: Optional[List[str]] = None) -> int:
    """CLI / Streamlit entry point; returns a process exit code.

    ``--summary`` prints the CLI summary. Otherwise, if running under
    ``streamlit run``, the UI is rendered. In any other case the CLI summary
    is printed. Unknown arguments are ignored (``parse_known_args``).
    """
    parser = argparse.ArgumentParser(description="ObesityPreprintRadar")
    parser.add_argument(
        "--summary",
        action="store_true",
        help="Print CLI summary (preprint count, server distribution, avg availability, top5 topics, avg pub lag)",
    )
    args, _ = parser.parse_known_args(argv)
    if args.summary:
        print(cli_summary())
        return 0
    if _in_streamlit_runtime():
        # Deliberately outside any try/except. Previously any UI error was
        # swallowed here and the app silently fell back to the CLI summary.
        run_streamlit()
        return 0
    # Plain python: print short help-ish summary.
    print(cli_summary())
    return 0


if __name__ == "__main__":
    # Entry for both `python <file> [--summary]` and `streamlit run <file>`.
    # main() picks the CLI summary or the Streamlit UI. The UI path assumes
    # Streamlit executes this file as __main__ (see main's runtime check).
    sys.exit(main())
