"""
DMConfLateBreaker-Kor (디엠콘프레이트브레이커코어)
DM 학회 late-breaker / oral / plenary / symposium / poster abstract
자동 수집·정규화·관심영역 필터·embargo lift 카운트다운·KOL 추적·일정 export.

참고용·연구용 도구. 임상 의사결정 대체 금지.
모든 데이터는 오프라인 mock data (외부 네트워크 호출 0).
"""

from __future__ import annotations

import json
import os
import sqlite3
import sys
from collections import defaultdict
from datetime import datetime, timedelta, timezone
from io import BytesIO
from pathlib import Path
from typing import Any, Dict, Iterable, List, Optional, Tuple

# Filesystem layout: JSON fixtures and the watchlist database live in ./data
# next to this module.
ROOT = Path(__file__).resolve().parent
DATA_DIR = ROOT / "data"
DB_PATH = DATA_DIR / "watchlist.sqlite"

# Disclaimer texts shown in the Streamlit UI and the DOCX digest (Korean / English).
DISCLAIMER = (
    "참고용·연구용 도구입니다. 본 도구는 학회 abstract 추적·요약을 보조할 뿐이며, "
    "임상 진료 의사결정·환자 치료의 근거로 사용해서는 안 됩니다. "
    "Embargo·저자·소속 정보는 검증되지 않은 mock data이며, 실제 학회 정보와 다를 수 있습니다."
)
DISCLAIMER_EN = (
    "For reference and research use only. This tool assists conference abstract tracking "
    "and must not be used as the basis for clinical decisions. All data shown is offline "
    "mock data and may differ from actual conference content."
)

# Lower-case substrings that mark an affiliation as Korean. They are matched by
# plain substring search (see has_korean_affiliation).
# NOTE(review): substring matching is loose. "ku " also hits e.g. "baku ", and
# "national university hospital" is not Korea-specific. Confirm this is intended.
KOREAN_AFFILIATION_KEYS = [
    "korea",
    "seoul",
    "kda",
    "yonsei",
    "snu",
    "asan",
    "samsung",
    "soonchunhyang",
    "sungkyunkwan",
    "catholic university of korea",
    "kyungpook",
    "hallym",
    "korea university",
    "hanyang",
    "ku ",
    "national university hospital",
]


# ---------------------------------------------------------------------------
# Data loading
# ---------------------------------------------------------------------------
def _read_json(path: Path) -> Any:
    with path.open("r", encoding="utf-8") as f:
        return json.load(f)


def load_conferences() -> List[Dict[str, Any]]:
    """Return the ``conferences`` list stored in ``data/conferences.json``."""
    document = _read_json(DATA_DIR / "conferences.json")
    return document["conferences"]


def load_abstracts() -> List[Dict[str, Any]]:
    """Return the ``abstracts`` list stored in ``data/abstracts.json``."""
    document = _read_json(DATA_DIR / "abstracts.json")
    return document["abstracts"]


def load_publications() -> List[Dict[str, Any]]:
    """Return the ``publications`` list stored in ``data/publications.json``."""
    document = _read_json(DATA_DIR / "publications.json")
    return document["publications"]


def load_kols() -> List[Dict[str, Any]]:
    """Return the ``kols`` seed list stored in ``data/kol_seed.json``."""
    document = _read_json(DATA_DIR / "kol_seed.json")
    return document["kols"]


def index_conferences(conferences: List[Dict[str, Any]]) -> Dict[str, Dict[str, Any]]:
    """Map each conference's ``code`` to its record.

    If two records share a code, the later one wins. The key keeps the position
    of its first occurrence.
    """
    by_code: Dict[str, Dict[str, Any]] = {}
    for record in conferences:
        by_code[record["code"]] = record
    return by_code


# ---------------------------------------------------------------------------
# Domain helpers
# ---------------------------------------------------------------------------
def has_korean_affiliation(affiliations: Iterable[str]) -> bool:
    """Return True if any KOREAN_AFFILIATION_KEYS entry occurs in the affiliations.

    The affiliations are lower-cased and joined with single spaces into one
    string before the substring search.
    """
    haystack = " ".join(entry.lower() for entry in affiliations)
    for needle in KOREAN_AFFILIATION_KEYS:
        if needle in haystack:
            return True
    return False


def parse_iso(dt: str) -> datetime:
    """Parse an ISO-8601 timestamp string into a ``datetime``.

    A trailing ``Z``/``z`` (UTC designator) is rewritten to ``+00:00``, because
    ``datetime.fromisoformat`` only accepts ``Z`` from Python 3.11 onwards. The
    same input therefore parses identically on every supported version.
    Surrounding whitespace is ignored.

    Strings without an offset still produce a *naive* datetime. Comparing or
    subtracting such a value against an aware one (e.g. ``now_utc()``) raises
    ``TypeError``.

    Raises:
        ValueError: if *dt* is not a valid ISO-8601 timestamp.
    """
    text = dt.strip()
    if text[-1:] in ("Z", "z"):
        text = text[:-1] + "+00:00"
    return datetime.fromisoformat(text)


def now_utc() -> datetime:
    """Return the current instant as a timezone-aware datetime in UTC."""
    current = datetime.now(timezone.utc)
    return current


def time_to_embargo(embargo_iso: str, ref: Optional[datetime] = None) -> Dict[str, Any]:
    """Describe the distance between an embargo timestamp and a reference time.

    Args:
        embargo_iso: ISO-8601 embargo timestamp (parsed with ``parse_iso``).
        ref: reference instant; defaults to ``now_utc()``.

    Returns:
        A dict with these keys:

        - ``lifted``: True once the embargo time is at or before *ref*.
        - ``label``: ``"T-<d>d <h>h <m>m"`` while still embargoed, or
          ``"LIFTED -<d>d <h>h <m>m"`` afterwards. The value is the magnitude
          of the difference, floored to whole minutes.
        - ``target``: the parsed embargo datetime.
        - ``seconds``: signed whole seconds from *ref* to the target
          (truncated toward zero).

    Raises:
        TypeError: if one of the two datetimes is naive and the other aware.
    """
    if ref is None:
        ref = now_utc()
    target = parse_iso(embargo_iso)
    delta = target - ref
    # Decide "lifted" on the exact delta. int() truncates toward zero, so any
    # embargo less than one second away truncates to 0. It was then reported as
    # LIFTED while next_embargo_lift() (strict ``> ref``) still listed it as upcoming.
    lifted = delta.total_seconds() <= 0
    total_seconds = int(delta.total_seconds())
    days, rem = divmod(abs(total_seconds), 86400)
    hours, rem = divmod(rem, 3600)
    minutes = rem // 60
    label = f"{days}d {hours}h {minutes}m"
    return {
        "lifted": lifted,
        "label": ("LIFTED -" + label) if lifted else ("T-" + label),
        "target": target,
        "seconds": total_seconds,
    }


def next_embargo_lift(abstracts: List[Dict[str, Any]], ref: Optional[datetime] = None) -> Optional[Dict[str, Any]]:
    """Return the abstract whose ``embargo_time`` is the earliest one strictly after *ref*.

    *ref* defaults to ``now_utc()``. Returns None when no embargo lies in the
    future. On ties, the abstract that appears first in *abstracts* wins.
    """
    if ref is None:
        ref = now_utc()
    # Parse every embargo once and keep it paired with its abstract.
    parsed = [(parse_iso(item["embargo_time"]), item) for item in abstracts]
    pending = [pair for pair in parsed if pair[0] > ref]
    if not pending:
        return None
    earliest = min(pending, key=lambda pair: pair[0])
    return earliest[1]


def filter_abstracts(
    abstracts: List[Dict[str, Any]],
    *,
    conferences: Optional[List[str]] = None,
    session_types: Optional[List[str]] = None,
    topics: Optional[List[str]] = None,
    korea_only: bool = False,
    date_from: Optional[datetime] = None,
    date_to: Optional[datetime] = None,
) -> List[Dict[str, Any]]:
    """Return the abstracts that pass every supplied filter, in input order.

    Args:
        abstracts: abstract records to filter.
        conferences: keep only these conference codes. None or empty means no filter.
        session_types: keep only these session types. None or empty means no filter.
        topics: keep abstracts sharing at least one ``topic_tags`` entry.
            None or empty means no filter.
        korea_only: keep only abstracts whose ``korea_author`` is truthy.
        date_from / date_to: inclusive bounds on the parsed ``session_slot``.
            The slot is parsed only when at least one bound is given.
    """
    # Build the lookup sets once rather than once per abstract.
    conf_set = set(conferences) if conferences else None
    type_set = set(session_types) if session_types else None
    topic_set = set(topics) if topics else None
    check_dates = date_from is not None or date_to is not None

    out = []
    for a in abstracts:
        if conf_set is not None and a["conference"] not in conf_set:
            continue
        if type_set is not None and a["session_type"] not in type_set:
            continue
        if topic_set is not None and topic_set.isdisjoint(a.get("topic_tags", [])):
            continue
        if korea_only and not a.get("korea_author"):
            continue
        if check_dates:
            slot = parse_iso(a["session_slot"])
            if date_from is not None and slot < date_from:
                continue
            if date_to is not None and slot > date_to:
                continue
        out.append(a)
    return out


def dedup_abstracts(abstracts: List[Dict[str, Any]], *, threshold: float = 0.55) -> List[Dict[str, Any]]:
    """Group abstracts whose titles are near-duplicates by token overlap.

    Each title becomes a set of lower-cased words, with surrounding
    ``.,:;()[]`` stripped. Only words longer than 3 characters *after*
    stripping are kept. An abstract joins the first existing group whose
    founding title has a Jaccard overlap >= *threshold* with it. Otherwise it
    starts a new group.

    Args:
        abstracts: records with a ``title`` key.
        threshold: minimum Jaccard similarity (default 0.55).

    Returns:
        Only the groups with more than one member (each group is a list of
        abstracts, in input order).
    """
    def normalize(s: str) -> set:
        # Strip first, then measure length. Otherwise "(ab)" (length 4) passes
        # the filter as "ab", and "(..)" becomes an empty-string token.
        stripped = (w.strip(".,:;()[]").lower() for w in s.split())
        return {w for w in stripped if len(w) > 3}

    groups: List[List[Dict[str, Any]]] = []
    # Token set of each group's founding title, computed once per group.
    group_tokens: List[set] = []
    for a in abstracts:
        tokens = normalize(a["title"])
        for idx, base_tokens in enumerate(group_tokens):
            if not base_tokens:
                continue
            overlap = len(tokens & base_tokens) / max(1, len(tokens | base_tokens))
            if overlap >= threshold:
                groups[idx].append(a)
                break
        else:
            groups.append([a])
            group_tokens.append(tokens)
    return [g for g in groups if len(g) > 1]


def join_publication_lag(
    abstracts: List[Dict[str, Any]],
    publications: List[Dict[str, Any]],
) -> List[Dict[str, Any]]:
    """Join abstracts to their publications and sort the rows by ``lag_days``.

    Abstracts without a matching publication (by ``abstract_id``) are dropped.
    If several publications point at the same abstract, the last one wins.
    The sort is stable, so rows with equal lag keep abstract order.
    """
    publication_for: Dict[Any, Dict[str, Any]] = {}
    for pub in publications:
        publication_for[pub["abstract_id"]] = pub

    joined = [
        {
            "abstract_id": item["id"],
            "conference": item["conference"],
            "title": item["title"],
            "first_author": match["first_author"],
            "journal": match["journal"],
            "pubmed_id": match["pubmed_id"],
            "session_slot": item["session_slot"],
            "publication_date": match["publication_date"],
            "lag_days": match["lag_days"],
            "korea_author": item.get("korea_author", False),
        }
        for item in abstracts
        if (match := publication_for.get(item["id"]))
    ]
    return sorted(joined, key=lambda row: row["lag_days"])


# ---------------------------------------------------------------------------
# Watchlist (sqlite)
# ---------------------------------------------------------------------------
def _db() -> sqlite3.Connection:
    """Open the watchlist SQLite database, creating its directory and table if needed.

    The caller owns the returned connection and must close it. If creating the
    table fails, the connection is closed here before the error propagates, so
    it does not leak.
    """
    DB_PATH.parent.mkdir(parents=True, exist_ok=True)
    conn = sqlite3.connect(DB_PATH)
    try:
        conn.execute(
            "CREATE TABLE IF NOT EXISTS watchlist ("
            "kol_id TEXT PRIMARY KEY,"
            "name TEXT NOT NULL,"
            "affiliation TEXT,"
            "added_at TEXT NOT NULL)"
        )
    except Exception:
        conn.close()
        raise
    return conn


def watchlist_add(kol_id: str, name: str, affiliation: str) -> None:
    """Add (or overwrite) a KOL in the watchlist, stamping ``added_at`` with the current UTC time.

    Uses ``INSERT OR REPLACE`` keyed on ``kol_id``. Re-adding an existing KOL
    therefore refreshes its name, affiliation and ``added_at``.
    """
    conn = _db()
    try:
        # Name the columns explicitly so the insert keeps working if the
        # table ever gains extra columns.
        conn.execute(
            "INSERT OR REPLACE INTO watchlist (kol_id, name, affiliation, added_at) "
            "VALUES (?, ?, ?, ?)",
            (kol_id, name, affiliation, now_utc().isoformat()),
        )
        conn.commit()
    finally:
        conn.close()


def watchlist_remove(kol_id: str) -> None:
    """Delete the watchlist row for *kol_id*. Deleting an absent id is a no-op."""
    connection = _db()
    try:
        # The connection context manager commits on success and rolls back on error.
        with connection:
            connection.execute("DELETE FROM watchlist WHERE kol_id = ?", (kol_id,))
    finally:
        connection.close()


def watchlist_list() -> List[Dict[str, str]]:
    """Return every watchlist row as a dict, newest ``added_at`` first."""
    field_names = ("kol_id", "name", "affiliation", "added_at")
    connection = _db()
    try:
        cursor = connection.execute(
            "SELECT kol_id, name, affiliation, added_at FROM watchlist ORDER BY added_at DESC"
        )
        records = cursor.fetchall()
    finally:
        connection.close()
    return [dict(zip(field_names, record)) for record in records]


# ---------------------------------------------------------------------------
# ICS export
# ---------------------------------------------------------------------------
def _ics_escape(value: str) -> str:
    """Escape *value* as an RFC 5545 TEXT value (backslash, ``;``, ``,`` and newlines)."""
    text = value.replace("\r\n", "\n").replace("\r", "\n")
    return (
        text.replace("\\", "\\\\")  # must run first so later escapes are not doubled
        .replace(";", "\\;")
        .replace(",", "\\,")
        .replace("\n", "\\n")
    )


def _ics_fold(line: str, limit: int = 75) -> str:
    """Fold one content line into CRLF-separated chunks of at most *limit* UTF-8 octets.

    Continuation chunks start with a single space, and that space counts toward
    the limit. Characters are never split across chunks, so multi-byte (e.g.
    Korean) text stays valid UTF-8.
    """
    chunks: List[str] = []
    current = ""
    size = 0
    for ch in line:
        width = len(ch.encode("utf-8"))
        if size + width > limit and current:
            chunks.append(current)
            current = " "
            size = 1
        current += ch
        size += width
    chunks.append(current)
    return "\r\n".join(chunks)


def build_ics(
    abstracts: List[Dict[str, Any]],
    conferences: Dict[str, Dict[str, Any]],
    *,
    duration_minutes: int = 30,
) -> str:
    """Build a minimal RFC 5545 ICS calendar (no network access) with one VEVENT per abstract.

    Args:
        abstracts: records with ``id``, ``conference``, ``session_type``,
            ``title``, ``authors``, ``affiliations``, ``embargo_time`` and
            ``session_slot`` (plus optional ``topic_tags``).
        conferences: conference records keyed by code. ``city``/``country``
            fill LOCATION.
        duration_minutes: length of each event starting at ``session_slot``
            (default 30).

    Returns:
        The calendar text, with CRLF line endings and lines folded at 75 octets.

    Times are emitted in UTC. A naive ``session_slot`` is converted with
    ``datetime.astimezone``, which interprets naive values as system local time.
    """
    def fmt(dt: datetime) -> str:
        return dt.astimezone(timezone.utc).strftime("%Y%m%dT%H%M%SZ")

    stamp = fmt(now_utc())
    duration = timedelta(minutes=duration_minutes)
    lines = [
        "BEGIN:VCALENDAR",
        "VERSION:2.0",
        "PRODID:-//DMConfLateBreaker-Kor//EN",
        "CALSCALE:GREGORIAN",
    ]
    for a in abstracts:
        start = parse_iso(a["session_slot"])
        # timedelta arithmetic rolls over hours and dates correctly. The old
        # replace(hour=hour + 1) approach raised ValueError for slots at 23:30 or later.
        end = start + duration
        conf = conferences.get(a["conference"], {})
        summary = f"[{a['conference']}/{a['session_type']}] {a['title']}".replace("\n", " ")
        desc = "\n".join([
            f"Authors: {', '.join(a['authors'])}",
            f"Affiliations: {', '.join(a['affiliations'])}",
            f"Embargo: {a['embargo_time']}",
            f"Topics: {', '.join(a.get('topic_tags', []))}",
            "참고용·연구용. 임상 의사결정 금지.",
        ])
        location = f"{conf.get('city', '')}, {conf.get('country', '')}".strip(", ")
        lines.extend([
            "BEGIN:VEVENT",
            f"UID:{a['id']}@dmconflatebreakerkor",
            f"DTSTAMP:{stamp}",
            f"DTSTART:{fmt(start)}",
            f"DTEND:{fmt(end)}",
            f"SUMMARY:{_ics_escape(summary)}",
            f"DESCRIPTION:{_ics_escape(desc)}",
            f"LOCATION:{_ics_escape(location)}",
            "END:VEVENT",
        ])
    lines.append("END:VCALENDAR")
    return "".join(_ics_fold(line) + "\r\n" for line in lines)


# ---------------------------------------------------------------------------
# Digest
# ---------------------------------------------------------------------------
def build_digest_text(
    abstracts: List[Dict[str, Any]],
    lang: str = "ko",
    *,
    max_per_conf: Optional[int] = 8,
) -> str:
    """Render a plain-text digest grouped by conference (conferences sorted by code).

    Args:
        abstracts: abstract records to summarize.
        lang: ``"ko"`` for the Korean header and notes. Any other value uses English.
        max_per_conf: maximum abstracts listed per conference (default 8).
            None lists all of them. When items are omitted, a trailing line
            states how many, so the "(N abstracts)" count matches what the
            reader can account for.

    Returns:
        The digest as a single newline-joined string.
    """
    by_conf: Dict[str, List[Dict[str, Any]]] = defaultdict(list)
    for a in abstracts:
        by_conf[a["conference"]].append(a)
    header = (
        "DMConfLateBreaker-Kor 일일 다이제스트\n참고용·연구용. 임상 의사결정 금지.\n"
        if lang == "ko"
        else "DMConfLateBreaker-Kor Daily Digest\nFor reference/research only. Not for clinical decisions.\n"
    )
    sections = [header]
    for conf, items in sorted(by_conf.items()):
        sections.append(f"\n[{conf}] ({len(items)} abstracts)")
        shown = items if max_per_conf is None else items[:max(0, max_per_conf)]
        for a in shown:
            korea_mark = " [KR]" if a.get("korea_author") else ""
            line = f" - {a['session_type'].upper()}{korea_mark} {a['title']}"
            sections.append(line)
            sections.append(f"    Authors: {', '.join(a['authors'])}")
            sections.append(f"    Embargo: {a['embargo_time']}  Topics: {', '.join(a.get('topic_tags', []))}")
        hidden = len(items) - len(shown)
        if hidden > 0:
            sections.append(
                f"    ... 외 {hidden}건 생략" if lang == "ko" else f"    ... {hidden} more not shown"
            )
    return "\n".join(sections)


def build_digest_docx(abstracts: List[Dict[str, Any]], lang: str = "ko") -> bytes:
    """Generate a Word (.docx) digest with python-docx and return its bytes.

    If ``docx`` cannot be imported, this deliberately falls back to the UTF-8
    bytes of ``build_digest_text``. The result is then plain text, not a .docx
    file.

    Args:
        abstracts: abstract records, grouped by conference (sorted by code).
        lang: ``"ko"`` uses the Korean disclaimer. Any other value uses the English one.
    """
    try:
        from docx import Document
    except Exception:  # best-effort: any import failure -> plain-text fallback
        return build_digest_text(abstracts, lang).encode("utf-8")
    doc = Document()
    doc.add_heading("DMConfLateBreaker-Kor Digest", level=1)
    # In python-docx, italics are a property of a Run, not a Paragraph. Setting
    # ``paragraph.italic`` only adds a stray Python attribute and has no effect.
    disclaimer = DISCLAIMER if lang == "ko" else DISCLAIMER_EN
    doc.add_paragraph().add_run(disclaimer).italic = True
    by_conf: Dict[str, List[Dict[str, Any]]] = defaultdict(list)
    for a in abstracts:
        by_conf[a["conference"]].append(a)
    for conf, items in sorted(by_conf.items()):
        doc.add_heading(f"{conf} ({len(items)} abstracts)", level=2)
        for a in items:
            korea_mark = " [KR]" if a.get("korea_author") else ""
            doc.add_heading(f"{a['session_type'].upper()}{korea_mark} - {a['title']}", level=3)
            doc.add_paragraph(f"Authors: {', '.join(a['authors'])}")
            doc.add_paragraph(f"Affiliations: {', '.join(a['affiliations'])}")
            doc.add_paragraph(f"Embargo: {a['embargo_time']}")
            doc.add_paragraph(f"Topics: {', '.join(a.get('topic_tags', []))}")
            doc.add_paragraph(a.get("abstract_text", ""))
    buf = BytesIO()
    doc.save(buf)
    return buf.getvalue()


# ---------------------------------------------------------------------------
# CLI summary (for preview / smoke test)
# ---------------------------------------------------------------------------
def cli_summary() -> int:
    """Print a plain-text overview of the mock dataset to stdout and return exit code 0.

    Reports conference, abstract, Korean-author and KOL counts, per-session-type
    and per-conference breakdowns, and the next upcoming embargo lift (if any).
    """
    abstracts = load_abstracts()
    conferences = load_conferences()
    kols = load_kols()
    korea_count = sum(1 for item in abstracts if item.get("korea_author"))
    upcoming = next_embargo_lift(abstracts)

    print("DMConfLateBreaker-Kor CLI summary")
    print("=" * 50)
    print(f"Conferences loaded     : {len(conferences)}")
    print(f"Total abstracts        : {len(abstracts)}")
    print(f"Korean-author abstracts: {korea_count}")
    print(f"KOL seed entries       : {len(kols)}")

    type_counts: Dict[str, int] = {}
    for item in abstracts:
        kind = item["session_type"]
        type_counts[kind] = type_counts.get(kind, 0) + 1
    print("Session type breakdown :", type_counts)

    conf_counts: Dict[str, int] = {}
    for item in abstracts:
        code = item["conference"]
        conf_counts[code] = conf_counts.get(code, 0) + 1
    print("Per-conference counts  :", conf_counts)

    if upcoming is None:
        print("Next embargo lift      : (none upcoming relative to now)")
    else:
        countdown = time_to_embargo(upcoming["embargo_time"])
        print(f"Next embargo lift      : {upcoming['embargo_time']} ({countdown['label']})")
        print(f"  -> {upcoming['title'][:80]}")
    print("Reminder               : 참고용·연구용. 임상 의사결정 금지.")
    return 0


# ---------------------------------------------------------------------------
# Streamlit UI
# ---------------------------------------------------------------------------
def run_streamlit() -> None:
    """Render the six-tab Streamlit dashboard over the offline JSON mock data.

    Tabs: late-breaker cards, topic filter, KOL watchlist (SQLite-backed),
    embargo countdown + ICS export, dedup candidates + abstract-to-PubMed lag,
    and digest preview + DOCX download. Every tab ends with the disclaimer
    caption.

    Raises:
        ModuleNotFoundError: if ``streamlit`` is not installed (``main`` catches
            this and falls back to the CLI summary).
    """
    import streamlit as st  # imported lazily so CLI works without streamlit installed.

    st.set_page_config(page_title="DMConfLateBreaker-Kor", layout="wide")
    st.title("DMConfLateBreaker-Kor")
    st.caption("DM 학회 late-breaker abstract 추적기 (오프라인 mock data)")
    st.warning(DISCLAIMER)

    # Data is read from the local JSON files each time this function runs
    # (no caching is applied here).
    abstracts = load_abstracts()
    conferences_list = load_conferences()
    conferences = index_conferences(conferences_list)
    publications = load_publications()
    kols = load_kols()

    tabs = st.tabs([
        "오늘의 late-breaker",
        "관심영역 필터",
        "KOL watchlist",
        "embargo + ICS export",
        "dedup + PubMed lag",
        "digest 미리보기",
    ])

    # --- Tab 1: today's late-breakers ----------------------------------
    with tabs[0]:
        st.subheader("Late-breaker / plenary 카드")
        lb_only = st.checkbox("Late-breaker / plenary 만 표시", value=True)
        # Either all abstracts or only late-breaker/plenary ones, ordered by embargo time.
        items = [a for a in abstracts if (not lb_only) or a["session_type"] in {"late-breaker", "plenary"}]
        items.sort(key=lambda a: parse_iso(a["embargo_time"]))
        for a in items:
            tte = time_to_embargo(a["embargo_time"])
            conf = conferences.get(a["conference"], {})
            # NOTE(review): `border=` requires a Streamlit release that supports it; confirm the pinned version.
            with st.container(border=True):
                cols = st.columns([3, 1])
                with cols[0]:
                    st.markdown(f"**[{a['conference']} · {a['session_type']}]** {a['title']}")
                    st.caption(f"{', '.join(a['authors'])} | {', '.join(a['affiliations'])}")
                    st.caption(f"Topics: {', '.join(a.get('topic_tags', []))}"
                               f"{'  ·  [KOREA]' if a.get('korea_author') else ''}")
                with cols[1]:
                    st.metric("Embargo", tte["label"])
                    st.caption(conf.get("city", ""))
        st.caption(DISCLAIMER)

    # --- Tab 2: topic-of-interest filter -------------------------------
    with tabs[1]:
        st.subheader("필터")
        col1, col2, col3 = st.columns(3)
        all_topics = sorted({t for a in abstracts for t in a.get("topic_tags", [])})
        all_session_types = sorted({a["session_type"] for a in abstracts})
        all_confs = sorted({a["conference"] for a in abstracts})
        with col1:
            sel_confs = st.multiselect("학회", all_confs, default=all_confs)
        with col2:
            sel_types = st.multiselect("세션 타입", all_session_types, default=all_session_types)
        with col3:
            sel_topics = st.multiselect("토픽", all_topics)
        korea_only = st.toggle("한국 KOL 발표만", value=False)
        # NOTE(review): filter_abstracts treats an empty list as "no filter".
        # Deselecting every conference or session type therefore shows ALL
        # abstracts, not none. Topics are passed as `or None` explicitly.
        filtered = filter_abstracts(
            abstracts,
            conferences=sel_confs,
            session_types=sel_types,
            topics=sel_topics or None,
            korea_only=korea_only,
        )
        st.caption(f"결과: {len(filtered)} abstracts")
        # Table view when pandas is available. Otherwise (or on any other error
        # inside this try) fall back to a bullet list.
        try:
            import pandas as pd
            df = pd.DataFrame([
                {
                    "Conference": a["conference"],
                    "Type": a["session_type"],
                    "Title": a["title"],
                    "Authors": ", ".join(a["authors"]),
                    "Topics": ", ".join(a.get("topic_tags", [])),
                    "Embargo": a["embargo_time"],
                    "Korea": a.get("korea_author", False),
                }
                for a in filtered
            ])
            st.dataframe(df, use_container_width=True, hide_index=True)
        except Exception:
            for a in filtered:
                st.write(f"- [{a['conference']}/{a['session_type']}] {a['title']}")
        st.caption(DISCLAIMER)

    # --- Tab 3: KOL watchlist ------------------------------------------
    with tabs[2]:
        st.subheader("KOL watchlist")
        st.caption("Seed KOL 목록에서 추가/제거. SQLite에 저장됩니다.")
        wl = {w["kol_id"] for w in watchlist_list()}
        seed_cols = st.columns(2)
        # Each seed KOL is a toggle button laid out in two columns; ★ marks watchlisted ones.
        for i, k in enumerate(kols):
            with seed_cols[i % 2]:
                in_wl = k["id"] in wl
                label = ("★ " if in_wl else "  ") + f"{k['name']} – {k['affiliation']}"
                if st.button(label, key=f"kol-{k['id']}"):
                    if in_wl:
                        watchlist_remove(k["id"])
                    else:
                        watchlist_add(k["id"], k["name"], k["affiliation"])
                    # Rerun so the ★ labels reflect the updated watchlist.
                    st.rerun()
        st.markdown("### 현재 watchlist")
        for w in watchlist_list():
            st.write(f"- {w['name']} ({w['affiliation']}) — added {w['added_at']}")
        st.caption(DISCLAIMER)

    # --- Tab 4: embargo countdown + ICS --------------------------------
    with tabs[3]:
        st.subheader("Embargo 카운트다운 + ICS export")
        rows = []
        for a in abstracts:
            tte = time_to_embargo(a["embargo_time"])
            rows.append({"id": a["id"], "title": a["title"], "embargo_time": a["embargo_time"], "status": tte["label"]})
        try:
            import pandas as pd
            st.dataframe(pd.DataFrame(rows), use_container_width=True, hide_index=True)
        except Exception:
            for r in rows:
                st.write(f"- {r['status']}: {r['title']} ({r['embargo_time']})")
        # The download button only appears once at least one abstract is selected.
        sel_ids = st.multiselect("ICS export 대상", [a["id"] for a in abstracts])
        if sel_ids:
            chosen = [a for a in abstracts if a["id"] in sel_ids]
            ics_text = build_ics(chosen, conferences)
            st.download_button(
                "ICS 다운로드",
                data=ics_text.encode("utf-8"),
                file_name="dmconflatebreaker.ics",
                mime="text/calendar",
            )
        st.caption(DISCLAIMER)

    # --- Tab 5: dedup + PubMed lag -------------------------------------
    with tabs[4]:
        st.subheader("Dedup 후보")
        groups = dedup_abstracts(abstracts)
        if not groups:
            st.info("중복 의심 그룹 없음.")
        for g in groups:
            with st.container(border=True):
                st.write("Duplicate-candidate group:")
                for a in g:
                    st.write(f"  - [{a['conference']}/{a['session_type']}] {a['title']}")
        st.subheader("Abstract → PubMed lag")
        lag_rows = join_publication_lag(abstracts, publications)
        try:
            import pandas as pd
            st.dataframe(pd.DataFrame(lag_rows), use_container_width=True, hide_index=True)
            # Median lag is truncated to whole days by int(); 'NA' when nothing is joined.
            st.caption(
                f"median lag = "
                f"{int(pd.Series([r['lag_days'] for r in lag_rows]).median()) if lag_rows else 'NA'} days"
            )
        except Exception:
            for r in lag_rows:
                st.write(f"- {r['lag_days']}d  {r['title']}  -> {r['journal']}")
        st.caption(DISCLAIMER)

    # --- Tab 6: digest preview -----------------------------------------
    with tabs[5]:
        st.subheader("Digest 미리보기")
        lang = st.radio("언어", ["ko", "en"], horizontal=True)
        text = build_digest_text(abstracts, lang=lang)
        st.text_area("Digest", value=text, height=400)
        # NOTE(review): without python-docx, build_digest_docx returns plain-text
        # bytes, which are still offered here under a .docx name and MIME type.
        docx_bytes = build_digest_docx(abstracts, lang=lang)
        st.download_button(
            "DOCX 다운로드",
            data=docx_bytes,
            file_name="dmconflatebreaker_digest.docx",
            mime="application/vnd.openxmlformats-officedocument.wordprocessingml.document",
        )
        st.caption(DISCLAIMER)


# ---------------------------------------------------------------------------
# Entry point
# ---------------------------------------------------------------------------
def main() -> int:
    """Program entry point. Returns a process exit code.

    ``--cli-summary`` on the command line prints the CLI summary. Otherwise the
    Streamlit UI is rendered. If a required module is missing, a warning is
    printed and the CLI summary is used instead. The warning names the module
    that is actually missing rather than always blaming streamlit.
    """
    if "--cli-summary" in sys.argv:
        return cli_summary()
    try:
        run_streamlit()
    except ModuleNotFoundError as exc:
        missing = exc.name or ""
        if missing == "streamlit" or missing.startswith("streamlit."):
            print("[warn] streamlit not installed; falling back to CLI summary.\n")
        else:
            # Something imported inside the UI is missing, not streamlit itself.
            print(f"[warn] required module {missing!r} not installed; falling back to CLI summary.\n")
        return cli_summary()
    return 0


if __name__ == "__main__":
    # sys.exit raises SystemExit carrying main()'s return code as the process status.
    sys.exit(main())
