#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
LivingSR-Staleness-Kor / 리빙에스알스테일니스코어
-------------------------------------------------
비만(Obesity) 항비만 체계적 문헌고찰(SR)/PROSPERO 프로토콜의
'최신성 노후화(staleness)'를 평가하는 오프라인 연구알림 MVP.

기능:
  1) SR/PROSPERO 레코드의 PICO 구조화 (룰 기반)
  2) 검색일(search date) 이후 색인된 신규 적격 RCT 매칭
  3) 노후화 지표 (경과일 / 신규 적격 트라이얼 수 / 신규 참가자 N) + 신호등 등급
  4) 미출판 SR 추적 + 중복 PROSPERO 등록(같은 PICO, 다른 팀) 탐지 → 연구낭비/협업 알림
  5) PICO 갭 맵 (어떤 SR도 다루지 않은 항비만 PICO 조합)

표준 라이브러리(json/csv/argparse/re/datetime/math)만 사용. 네트워크/외부패키지 불필요.

⚠️ 참고용·연구용 (research/reference only, not for clinical decisions).
   합성(synthetic) 데모 데이터 기반이며 실제 PROSPERO/PubMed/CT.gov 레코드가 아님.
"""

import argparse
import json
import math
import os
import re
import sys
from datetime import datetime, date

ROOT = os.path.dirname(os.path.abspath(__file__))
DATA_DIR = os.path.join(ROOT, "data")

DISCLAIMER = (
    "⚠️  참고용·연구용 (research/reference only, NOT for clinical decisions). "
    "합성 데모 데이터 기반 — 실제 PROSPERO/PubMed/CT.gov 레코드 아님."
)

# 신호등 등급 임계값 (신규 적격 트라이얼 수 기준)
GRADE_FRESH = 0      # 0건 → GREEN (최신)
GRADE_WATCH = 1      # 1~2건 → YELLOW (주의)
# 3건 이상 → RED (노후, 업데이트 권고)

# PICO 매칭 점수 가중치
WEIGHTS = {"population": 0.25, "intervention": 0.35, "comparator": 0.15, "outcome": 0.25}
MATCH_THRESHOLD = 0.45  # 적격으로 간주할 종합 점수 임계값


# ----------------------------------------------------------------------
# 데이터 로딩
# ----------------------------------------------------------------------
def load_json(path):
    with open(path, "r", encoding="utf-8") as f:
        return json.load(f)


def load_data():
    srs = load_json(os.path.join(DATA_DIR, "srs.json"))["srs"]
    trials = load_json(os.path.join(DATA_DIR, "trials.json"))["trials"]
    return srs, trials


def parse_date(s):
    return datetime.strptime(s, "%Y-%m-%d").date()


# ----------------------------------------------------------------------
# 1) PICO 구조화 (룰 기반)
#    이미 구조화된 레코드는 정규화(소문자/토큰화)만 수행.
#    자유 텍스트 abstract가 들어오면 키워드 룰로 필드를 추출.
# ----------------------------------------------------------------------
PICO_LEXICON = {
    "population": ["obesity", "overweight", "bmi", "type 2 diabetes", "t2dm",
                   "sarcopenic obesity", "adult", "comorbidity", "masld", "nafld"],
    "intervention": ["glp-1", "semaglutide", "liraglutide", "tirzepatide", "drug",
                     "bariatric surgery", "gastric bypass", "sleeve gastrectomy", "surgery",
                     "behavioral", "diet", "exercise", "lifestyle"],
    "comparator": ["placebo", "usual care", "no intervention", "medical therapy",
                   "pharmacotherapy"],
    "outcome": ["body weight", "weight loss", "bmi", "waist circumference",
                "diabetes remission", "hba1c", "glycemic control", "muscle mass",
                "lean mass"],
}


def normalize_tokens(values):
    """리스트 값들을 소문자 토큰 집합으로 정규화."""
    out = set()
    for v in values:
        v = v.strip().lower()
        if v:
            out.add(v)
            # 다중 단어는 개별 단어 토큰도 추가해 부분 겹침 허용
            for w in re.split(r"[\s/]+", v):
                if len(w) > 1:
                    out.add(w)
    return out


def extract_pico_from_text(text):
    """자유 텍스트 abstract에서 PICO 키워드를 룰 기반으로 추출."""
    t = text.lower()
    found = {k: [] for k in PICO_LEXICON}
    for field, terms in PICO_LEXICON.items():
        for term in terms:
            if term in t:
                found[field].append(term)
    return found


def structured_pico(record):
    """SR/trial 레코드의 pico 필드를 정규화된 토큰 집합 dict로 반환."""
    pico = record.get("pico", {})
    return {field: normalize_tokens(pico.get(field, [])) for field in WEIGHTS}


# ----------------------------------------------------------------------
# 2) PICO 매칭 점수 (투명한 토큰 겹침 기반)
# ----------------------------------------------------------------------
def field_overlap(sr_tokens, trial_tokens):
    """한 PICO 필드의 Jaccard-유사 겹침 비율 (0~1)."""
    if not sr_tokens:
        return 0.0
    inter = sr_tokens & trial_tokens
    return len(inter) / len(sr_tokens)


def pico_match_score(sr_pico, trial_pico):
    """필드별 겹침 × 가중치 합산 → 종합 점수 + 필드별 내역."""
    breakdown = {}
    total = 0.0
    for field, w in WEIGHTS.items():
        ov = field_overlap(sr_pico[field], trial_pico[field])
        breakdown[field] = round(ov, 3)
        total += ov * w
    return round(total, 3), breakdown


def is_eligible_design(sr, trial):
    """SR inclusion 기준 대비 트라이얼 설계 적격성(설계/기간) 확인."""
    inc = [c.lower() for c in sr.get("inclusion", [])]
    # RCT 요구 여부
    if any("randomized" in c for c in inc):
        if "randomized" not in trial.get("study_type", "").lower():
            return False, "not an RCT"
    # 기간 요구 (duration>=N weeks)
    for c in inc:
        m = re.search(r"duration>=(\d+)\s*weeks", c)
        if m:
            need = int(m.group(1))
            if trial.get("duration_weeks", 0) < need:
                return False, "duration<{}wk".format(need)
        m2 = re.search(r"follow-up>=(\d+)\s*months", c)
        if m2:
            need_wk = int(m2.group(1)) * 4
            if trial.get("duration_weeks", 0) < need_wk:
                return False, "followup<{}mo".format(m2.group(1))
    return True, "ok"


def new_eligible_trials(sr, trials):
    """SR 검색일 이후 색인 + PICO 적격 + 설계 적격인 트라이얼 목록 반환."""
    search_dt = parse_date(sr["search_date"])
    sr_pico = structured_pico(sr)
    results = []
    for tr in trials:
        idx_dt = parse_date(tr["index_date"])
        if idx_dt <= search_dt:
            continue  # 검색일 이전 → 이미 포함 가능, 노후화에 기여 안 함
        design_ok, design_reason = is_eligible_design(sr, tr)
        if not design_ok:
            continue
        score, breakdown = pico_match_score(sr_pico, structured_pico(tr))
        if score >= MATCH_THRESHOLD:
            results.append({
                "trial_id": tr["id"],
                "title": tr["title_en"],
                "index_date": tr["index_date"],
                "n": tr.get("n_participants", 0),
                "score": score,
                "breakdown": breakdown,
            })
    results.sort(key=lambda r: (-r["score"], r["index_date"]))
    return results


# ----------------------------------------------------------------------
# 3) 노후화 등급 (신호등)
# ----------------------------------------------------------------------
def staleness_grade(n_new):
    if n_new <= GRADE_FRESH:
        return "GREEN", "🟢 최신 (up-to-date)"
    if n_new <= 2:
        return "YELLOW", "🟡 주의 (monitor — 1~2 new eligible trials)"
    return "RED", "🔴 노후 (STALE — update recommended)"


def days_since(d):
    return (date.today() - parse_date(d)).days


def evaluate_sr(sr, trials):
    news = new_eligible_trials(sr, trials)
    n_new = len(news)
    total_new_n = sum(t["n"] for t in news)
    grade, label = staleness_grade(n_new)
    return {
        "sr": sr,
        "new_trials": news,
        "n_new": n_new,
        "total_new_n": total_new_n,
        "days_since_search": days_since(sr["search_date"]),
        "grade": grade,
        "grade_label": label,
    }


# ----------------------------------------------------------------------
# 4) 미출판 추적 + 중복 등록 탐지
# ----------------------------------------------------------------------
def pico_signature(sr):
    """PICO 4필드 토큰을 정렬·결합한 시그니처 (중복 비교용)."""
    sp = structured_pico(sr)
    parts = []
    for field in ("population", "intervention", "comparator", "outcome"):
        parts.append(",".join(sorted(sp[field])))
    return " | ".join(parts)


def pico_similarity(a, b):
    """두 SR의 PICO 토큰 Jaccard 유사도 (중복 판정용)."""
    sa, sb = structured_pico(a), structured_pico(b)
    inter = total = 0
    for field in WEIGHTS:
        ta, tb = sa[field], sb[field]
        inter += len(ta & tb)
        total += len(ta | tb)
    return inter / total if total else 0.0


def find_duplicates(srs, sim_threshold=0.7):
    """같은/유사 PICO를 가진 SR 쌍(서로 다른 팀) 탐지."""
    dups = []
    for i in range(len(srs)):
        for j in range(i + 1, len(srs)):
            sim = pico_similarity(srs[i], srs[j])
            if sim >= sim_threshold:
                dups.append((srs[i], srs[j], round(sim, 3)))
    return dups


def unpublished_report(srs):
    out = []
    for sr in srs:
        if not sr.get("published", False):
            out.append({
                "id": sr["id"],
                "registry_id": sr.get("registry_id"),
                "team": sr.get("team"),
                "days_registered": days_since(sr["search_date"]),
                "title": sr.get("title_en"),
            })
    return out


# ----------------------------------------------------------------------
# 5) PICO 갭 맵
#    관심 PICO 조합 격자를 정의하고, 어떤 SR도 다루지 않는 조합을 표시.
# ----------------------------------------------------------------------
GAP_GRID = [
    # (label, required intervention tokens, required population tokens, required outcome tokens)
    ("GLP-1 × 비만 × 체중감량", {"glp-1"}, {"obesity"}, {"weight loss"}),
    ("비만대사수술 × T2DM × 당뇨관해", {"surgery"}, {"t2dm"}, {"diabetes remission"}),
    ("행동중재 × 과체중 × 체중감량", {"behavioral"}, {"overweight"}, {"weight loss"}),
    ("GLP-1 × 근감소성비만 × 근육량",
     {"glp-1"}, {"sarcopenic obesity"}, {"muscle mass"}),
    ("비만대사수술 × MASLD × 간섬유화",
     {"surgery"}, {"masld"}, {"fibrosis"}),
]


def combo_covered(srs, intv, pop, outc):
    """어떤 SR이 이 조합(교집합)을 모두 포함하면 covered."""
    for sr in srs:
        sp = structured_pico(sr)
        if intv <= sp["intervention"] and pop <= sp["population"] and outc <= sp["outcome"]:
            return sr["id"]
    return None


def pico_gaps(srs):
    rows = []
    for label, intv, pop, outc in GAP_GRID:
        covered_by = combo_covered(srs, intv, pop, outc)
        rows.append({"label": label, "covered_by": covered_by})
    return rows


# ----------------------------------------------------------------------
# 출력 헬퍼
# ----------------------------------------------------------------------
def header():
    print("=" * 72)
    print("LivingSR-Staleness-Kor / 리빙에스알스테일니스코어  (Obesity · 연구알림)")
    print(DISCLAIMER)
    print("=" * 72)


def print_sr_eval(ev, verbose=True):
    sr = ev["sr"]
    pub = "출판됨 (published)" if sr.get("published") else "미출판 (PROSPERO only)"
    print("\n[{}] {}".format(sr["id"], sr.get("title_kor", sr.get("title_en"))))
    print("  registry={}  team={}  status={}".format(
        sr.get("registry_id"), sr.get("team"), pub))
    print("  검색일(search date)={}  ({}일 경과)".format(
        sr["search_date"], ev["days_since_search"]))
    print("  staleness: {}".format(ev["grade_label"]))
    print("  신규 적격 트라이얼 = {}건 | 신규 참가자 N = {}".format(
        ev["n_new"], ev["total_new_n"]))
    if ev["n_new"] > 0:
        print("  → ALERT: 검색일 이후 {}건의 신규 적격 연구로 인해 이 SR은 노후화되었습니다."
              .format(ev["n_new"]))
        if verbose:
            for t in ev["new_trials"]:
                print("     • {} (idx {}, N={}, score={}) {}".format(
                    t["trial_id"], t["index_date"], t["n"], t["score"], t["title"]))
                print("        breakdown P/I/C/O = {}/{}/{}/{}".format(
                    t["breakdown"]["population"], t["breakdown"]["intervention"],
                    t["breakdown"]["comparator"], t["breakdown"]["outcome"]))


# ----------------------------------------------------------------------
# 모드
# ----------------------------------------------------------------------
def mode_single(sr_id, srs, trials):
    header()
    match = next((s for s in srs if s["id"] == sr_id), None)
    if not match:
        print("ERROR: SR id '{}' 를 찾을 수 없습니다. 사용 가능: {}".format(
            sr_id, ", ".join(s["id"] for s in srs)))
        return 2
    ev = evaluate_sr(match, trials)
    print_sr_eval(ev, verbose=True)
    print("")
    return 0


def mode_gaps(srs):
    header()
    print("\n## PICO 갭 맵 (uncovered anti-obesity PICO combinations)")
    any_gap = False
    for row in pico_gaps(srs):
        if row["covered_by"]:
            print("  [COVERED]  {:38s} ← {}".format(row["label"], row["covered_by"]))
        else:
            any_gap = True
            print("  [GAP ⚠️ ]  {:38s} ← 등록/출판된 SR 없음 (research opportunity)"
                  .format(row["label"]))
    if not any_gap:
        print("  (모든 정의된 조합이 커버됨)")
    print("")
    return 0


def mode_demo(srs, trials):
    header()
    print("\n## 데모: 합성 SR 세트 전체 노후화 스캔")
    stale_count = 0
    for sr in srs:
        ev = evaluate_sr(sr, trials)
        print_sr_eval(ev, verbose=False)
        if ev["n_new"] > 0:
            stale_count += 1

    print("\n## 미출판 SR 추적 (unpublished tracking)")
    unp = unpublished_report(srs)
    if unp:
        for u in unp:
            print("  • {} [{}] {} — 검색일 이후 {}일 미출판 (team {})".format(
                u["id"], u["registry_id"], u["title"], u["days_registered"], u["team"]))
    else:
        print("  (미출판 SR 없음)")

    print("\n## 중복 PROSPERO 등록 탐지 (duplicate registrations)")
    dups = find_duplicates(srs)
    if dups:
        for a, b, sim in dups:
            print("  ⚠️  중복 의심: {} ({}) ↔ {} ({})  PICO 유사도={}".format(
                a["id"], a.get("team"), b["id"], b.get("team"), sim))
            print("      → 연구낭비 위험 / 협업 권고 (research-waste / collaboration alert)")
    else:
        print("  (중복 등록 없음)")

    print("\n## PICO 갭 맵 요약")
    for row in pico_gaps(srs):
        tag = "COVERED" if row["covered_by"] else "GAP ⚠️"
        print("  [{:8s}] {}".format(tag, row["label"]))

    print("\n" + "-" * 72)
    print("요약: SR {}개 중 {}개가 노후(staleness alert) 상태입니다.".format(
        len(srs), stale_count))
    print("-" * 72 + "\n")
    return 0


# ----------------------------------------------------------------------
# CLI
# ----------------------------------------------------------------------
def build_parser():
    p = argparse.ArgumentParser(
        prog="main.py",
        description="LivingSR-Staleness-Kor — 항비만 SR 노후화 연구알림 (offline MVP). "
                    "참고용·연구용 only.",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog=(
            "예시:\n"
            "  python3 main.py                 # 데모/기본: 전체 SR 노후화 스캔\n"
            "  python3 main.py --demo          # 동일\n"
            "  python3 main.py --sr SR-001     # 단일 SR 노후화 평가\n"
            "  python3 main.py --gaps          # PICO 갭 맵 출력\n"
            "  python3 main.py --list          # 사용 가능한 SR id 목록\n"
        ),
    )
    p.add_argument("--sr", metavar="ID",
                   help="단일 SR의 노후화 평가 (신규 적격 트라이얼 + 신호등 등급)")
    p.add_argument("--gaps", action="store_true",
                   help="커버되지 않은 항비만 PICO 조합(갭) 출력")
    p.add_argument("--demo", action="store_true",
                   help="데모 모드: 합성 SR 세트 전체 노후화 알림 (기본값)")
    p.add_argument("--list", action="store_true",
                   help="데이터에 포함된 SR id 목록 출력")
    return p


def main(argv=None):
    parser = build_parser()
    args = parser.parse_args(argv)

    try:
        srs, trials = load_data()
    except Exception as e:
        print("ERROR: 데이터 로딩 실패: {}".format(e), file=sys.stderr)
        return 1

    if args.list:
        header()
        print("\n사용 가능한 SR:")
        for s in srs:
            print("  {} — {}".format(s["id"], s.get("title_kor", s.get("title_en"))))
        print("")
        return 0

    if args.sr:
        return mode_single(args.sr, srs, trials)
    if args.gaps:
        return mode_gaps(srs)
    # 기본 = 데모
    return mode_demo(srs, trials)


if __name__ == "__main__":
    sys.exit(main())
