"""
MASLDRetractionWatch-Kor — core data loading & analysis (offline, rule-based, no LLM).

All inputs are mock/synthetic JSON files under ./data/.
Disclaimer: 본 도구는 연구·참고용이며, retraction 정보는 실제 source(Retraction Watch DB·PubMed) 직접 확인 필수.
"""
from __future__ import annotations

import json
import math
import os
import re
from collections import Counter, defaultdict
from datetime import datetime
from typing import Any, Dict, Iterable, List, Optional, Tuple

# User-facing disclaimer; emitted as a Markdown quote by weekly_digest_markdown()
# and as the first paragraph of the report written by export_docx_report().
DISCLAIMER = (
    "본 도구는 연구·참고용이며, retraction 정보는 실제 source"
    "(Retraction Watch DB·PubMed·PubPeer·Crossref·학술지 공지) 직접 확인 필수. "
    "모든 데이터는 오프라인 데모용 mock data 입니다."
)

# Terms that matches_masld() looks for, case-insensitively, in a record's
# title and keywords: disease names, drug names and drug targets.
# NOTE(review): generic entries such as "fibrosis", "cirrhosis", "HCC",
# "semaglutide" or "vitamin E" will also match papers that are not about
# MASLD (e.g. pulmonary fibrosis, obesity trials) — confirm this breadth is
# intended.
MASLD_KEYWORDS = [
    "MASLD", "MASH", "NAFLD", "NASH", "MAFLD",
    "hepatic fibrosis", "fibrosis", "cirrhosis", "HCC",
    "resmetirom", "obeticholic", "pegozafermin", "pegbelfermin",
    "FGF21", "THR-β", "ACC", "FXR", "PPAR", "elafibranor",
    "firsocostat", "tropifexor", "semaglutide", "vitamin E", "pioglitazone",
]

# MeSH headings associated with the MASLD domain.
# NOTE(review): not referenced by the code in this module, and "Hepatitis" is
# much broader than MASLD — confirm where and how this list is used.
MESH_TERMS = [
    "Non-alcoholic Fatty Liver Disease",
    "Liver Cirrhosis",
    "Hepatitis",
    "Fatty Liver",
]

# Upstream data sources the tool refers to.
# NOTE(review): presumably the expected values of a record's "source" field
# (tallied by source_distribution()); not referenced directly here — verify.
SOURCES = [
    "Retraction Watch DB",
    "PubMed",
    "PubPeer",
    "Crossref event",
    "Journal notice",
]

# Inclusive lower bounds on a per-author / per-affiliation record count:
# count >= RED_FLAG_THRESHOLD -> "red", count >= YELLOW_FLAG_THRESHOLD ->
# "yellow", otherwise "green". YELLOW_FLAG_THRESHOLD is also the minimum count
# for guideline_member_overlap() to report a committee member.
YELLOW_FLAG_THRESHOLD = 2
RED_FLAG_THRESHOLD = 5


# ---------- data loading ----------

def _data_dir() -> str:
    """Return the absolute path of the ``data`` folder that sits beside this module."""
    return os.path.join(os.path.dirname(os.path.abspath(__file__)), "data")


def load_retractions(path: Optional[str] = None) -> List[Dict[str, Any]]:
    """Read the mock retraction / correction / EoC records.

    Args:
        path: JSON file to read; when omitted, ``retractions.json`` inside the
            module's data directory is used.

    Returns:
        The ``"records"`` list of the top-level JSON object, or ``[]`` when
        that key is absent.
    """
    target = path if path is not None else os.path.join(_data_dir(), "retractions.json")
    with open(target, encoding="utf-8") as fh:
        payload = json.load(fh)
    return payload.get("records", [])


def load_sample_systematic_review(path: Optional[str] = None) -> Dict[str, Any]:
    """Read the mock systematic-review fixture and return the parsed JSON as-is.

    Args:
        path: JSON file to read; when omitted,
            ``sample_systematic_review.json`` in the module's data directory.
    """
    source = path
    if source is None:
        source = os.path.join(_data_dir(), "sample_systematic_review.json")
    with open(source, encoding="utf-8") as handle:
        parsed = json.load(handle)
    return parsed


def load_guidelines(path: Optional[str] = None) -> List[Dict[str, Any]]:
    """Read the mock guideline citation lists.

    Args:
        path: JSON file to read; when omitted, ``guidelines_cited.json`` in
            the module's data directory.

    Returns:
        The ``"guidelines"`` list of the top-level JSON object, or ``[]`` when
        that key is absent.
    """
    if path is not None:
        source = path
    else:
        source = os.path.join(_data_dir(), "guidelines_cited.json")
    with open(source, encoding="utf-8") as handle:
        document = json.load(handle)
    return document.get("guidelines", [])


# ---------- Feature 1: MASLD filter + dedup ----------

def matches_masld(
    record: Dict[str, Any],
    *,
    keywords: Optional[Iterable[str]] = None,
) -> bool:
    """Return True when a record's title or keywords mention a MASLD/MASH term.

    Matching is case-insensitive, and each term must stand on its own: it may
    not be glued to other ASCII letters or digits on either side. A trailing
    plural "s" is tolerated (e.g. "HCCs"). This stops short acronyms such as
    "ACC" from firing inside ordinary words like "accuracy" or "according",
    and "vitamin E" from firing inside "vitamin effects". Non-ASCII neighbours
    do not block a match, so "PPARγ" and Korean text such as "MASLD환자"
    still match.

    Args:
        record: retraction record; reads the optional "title" (str) and
            "keywords" (list of str, or a single str) fields.
        keywords: terms to search for; defaults to MASLD_KEYWORDS.

    Returns:
        True if any term matches, otherwise False.
    """
    terms = MASLD_KEYWORDS if keywords is None else keywords

    record_keywords = record.get("keywords") or []
    if isinstance(record_keywords, str):
        # A bare string would otherwise be joined character by character.
        record_keywords = [record_keywords]
    text_blob = " ".join([
        str(record.get("title") or ""),
        " ".join(str(k) for k in record_keywords),
    ]).lower()

    for kw in terms:
        # ASCII-only boundaries (not \w) so Greek/Hangul neighbours still match.
        pattern = r"(?<![a-z0-9])" + re.escape(kw.lower()) + r"s?(?![a-z0-9])"
        if re.search(pattern, text_blob):
            return True
    return False


def dedup_records(records: Iterable[Dict[str, Any]]) -> List[Dict[str, Any]]:
    """Drop duplicate records, keeping the first occurrence of each.

    Two records are duplicates when they share a PMID or a DOI. PMIDs are
    compared as stripped strings, so 123 and "123" collide. DOIs are compared
    case-insensitively after stripping a "https://doi.org/"-style or "doi:"
    prefix. Only records carrying neither identifier fall back to their "id"
    field. The three identifier kinds live in separate namespaces, so a PMID
    never collides with an id. A record with no identifier at all cannot be
    compared and is always kept, rather than being collapsed onto a shared
    ``None`` key.

    Args:
        records: any iterable of record dicts; consumed once.

    Returns:
        A new list of the retained records in their original order.
    """

    def _pmid_key(value: Any) -> Optional[str]:
        # Normalise PMIDs so int and str spellings compare equal.
        if not value:
            return None
        key = str(value).strip()
        return key or None

    def _doi_key(value: Any) -> Optional[str]:
        # DOI names are case-insensitive; URL/"doi:" prefixes are presentation only.
        if not value:
            return None
        key = str(value).strip().lower()
        for prefix in ("https://doi.org/", "http://doi.org/",
                       "https://dx.doi.org/", "http://dx.doi.org/", "doi:"):
            if key.startswith(prefix):
                key = key[len(prefix):].strip()
                break
        return key or None

    seen: set = set()
    out: List[Dict[str, Any]] = []
    for r in records:
        keys = set()
        pmid = _pmid_key(r.get("pmid"))
        if pmid is not None:
            keys.add(("pmid", pmid))
        doi = _doi_key(r.get("doi"))
        if doi is not None:
            keys.add(("doi", doi))
        if not keys and r.get("id"):
            keys.add(("id", str(r.get("id"))))

        if not keys:
            # Nothing to compare on: keep rather than silently drop.
            out.append(r)
            continue

        is_duplicate = bool(keys & seen)
        # Remember all identifiers, even of duplicates, so matching is transitive.
        seen |= keys
        if not is_duplicate:
            out.append(r)
    return out


def filter_and_dedup(records: Iterable[Dict[str, Any]]) -> List[Dict[str, Any]]:
    """Keep only MASLD/MASH-related records (matches_masld), then drop duplicates (dedup_records)."""
    masld_only = [rec for rec in records if matches_masld(rec)]
    return dedup_records(masld_only)


# ---------- Feature 2: systematic review cross-reference ----------

def cross_reference_review(
    included: List[Dict[str, Any]],
    records: List[Dict[str, Any]],
) -> Dict[str, Any]:
    """Flag systematic-review papers that appear in the retraction records.

    Each included paper is matched first by PMID, then by DOI. PMIDs are
    compared as stripped strings, so an int PMID in the review matches a
    string PMID in the DB. DOIs are compared case-insensitively after
    stripping a "https://doi.org/"-style or "doi:" prefix. When several
    records share an identifier, the last one wins.

    Args:
        included: papers included in the review; reads "pmid" and "doi".
        records: retraction/correction/EoC records; reads "pmid", "doi" and
            "type".

    Returns:
        Dict with n_included, n_affected, affected (list of
        {"included", "record"}), clean (unmatched included papers),
        severity_breakdown (count per record type, missing type -> "Unknown")
        and suggestions (Korean/English guidance strings).
    """

    def _pmid_key(value: Any) -> Optional[str]:
        if not value:
            return None
        key = str(value).strip()
        return key or None

    def _doi_key(value: Any) -> Optional[str]:
        if not value:
            return None
        key = str(value).strip().lower()
        for prefix in ("https://doi.org/", "http://doi.org/",
                       "https://dx.doi.org/", "http://dx.doi.org/", "doi:"):
            if key.startswith(prefix):
                key = key[len(prefix):].strip()
                break
        return key or None

    by_pmid: Dict[str, Dict[str, Any]] = {}
    by_doi: Dict[str, Dict[str, Any]] = {}
    for r in records:
        pmid = _pmid_key(r.get("pmid"))
        if pmid is not None:
            by_pmid[pmid] = r
        doi = _doi_key(r.get("doi"))
        if doi is not None:
            by_doi[doi] = r

    affected: List[Dict[str, Any]] = []
    clean: List[Dict[str, Any]] = []

    for inc in included:
        # None keys are never stored, so .get(None) simply misses.
        hit = by_pmid.get(_pmid_key(inc.get("pmid"))) or by_doi.get(_doi_key(inc.get("doi")))
        if hit:
            affected.append({"included": inc, "record": hit})
        else:
            clean.append(inc)

    n = len(included)
    n_aff = len(affected)
    severity_break = Counter(a["record"].get("type") or "Unknown" for a in affected)

    suggestions = []
    if n_aff:
        suggestions.append(
            f"Sensitivity analysis: {n_aff}/{n} included paper가 retraction/correction/EoC 상태. "
            "이들 제외 시 pooled effect 재계산 권고."
        )
        if severity_break.get("Retraction", 0):
            suggestions.append(
                f"Cochrane RoB2: Retraction {severity_break['Retraction']}건 — "
                "'Overall: High risk' 재평가, narrative synthesis 전환 검토."
            )
        if severity_break.get("Expression of Concern", 0):
            suggestions.append(
                f"EoC {severity_break['Expression of Concern']}건 — "
                "PRISMA flow에서 'Reports excluded after EoC review' 단계 추가 제안."
            )
        if severity_break.get("Correction", 0):
            suggestions.append(
                f"Correction {severity_break['Correction']}건 — "
                "원논문 vs 수정본 effect estimate 차이 확인 필요."
            )
    else:
        suggestions.append("Included paper 중 retraction DB hit 없음. (mock data 기준)")

    return {
        "n_included": n,
        "n_affected": n_aff,
        "affected": affected,
        "clean": clean,
        "severity_breakdown": dict(severity_break),
        "suggestions": suggestions,
    }


# ---------- Feature 3: author/lab cumulative flags ----------

def author_cumulative(
    records: List[Dict[str, Any]],
    *,
    yellow_threshold: Optional[int] = None,
    red_threshold: Optional[int] = None,
) -> List[Dict[str, Any]]:
    """Aggregate records per author and assign a green/yellow/red flag.

    Each record contributes at most once per author, even if a name is
    repeated in its "authors" list, so "count" is the number of distinct
    records. Empty or None author names are ignored. Rows are ordered by
    descending count; ties keep first-seen order.

    Args:
        records: retraction/correction/EoC records. Reads the optional
            "authors", "orcid" and "affiliations" lists, plus "type", "year"
            and "id".
        yellow_threshold: inclusive minimum count for "yellow"; defaults to
            YELLOW_FLAG_THRESHOLD.
        red_threshold: inclusive minimum count for "red"; defaults to
            RED_FLAG_THRESHOLD.

    Returns:
        One dict per author with keys author, count, flag, orcid (sorted),
        affiliations (sorted), type_breakdown, time_cluster
        ("first–last (span Ny)" when at least two years are known, else "")
        and record_ids.
    """
    if yellow_threshold is None:
        yellow_threshold = YELLOW_FLAG_THRESHOLD
    if red_threshold is None:
        red_threshold = RED_FLAG_THRESHOLD

    counts: Counter = Counter()
    by_author_records: Dict[str, List[Dict[str, Any]]] = defaultdict(list)
    orcids: Dict[str, set] = defaultdict(set)
    affs: Dict[str, set] = defaultdict(set)
    types: Dict[str, Counter] = defaultdict(Counter)
    years: Dict[str, List[int]] = defaultdict(list)

    for r in records:
        rec_type = r.get("type") or "Unknown"
        rec_orcids = r.get("orcid") or []
        if isinstance(rec_orcids, str):
            rec_orcids = [rec_orcids]
        rec_affs = r.get("affiliations") or []
        if isinstance(rec_affs, str):
            rec_affs = [rec_affs]

        rec_year: Optional[int] = None
        if r.get("year"):
            try:
                rec_year = int(r["year"])
            except (TypeError, ValueError):
                # An unparseable year only drops out of time_cluster; it must
                # not abort the whole aggregation.
                rec_year = None

        # dict.fromkeys removes repeated names while preserving their order.
        for a in dict.fromkeys(r.get("authors") or []):
            if not a:
                continue
            counts[a] += 1
            by_author_records[a].append(r)
            # NOTE(review): every ORCID/affiliation on the record is attached
            # to every author of that record. If these lists are parallel to
            # "authors", this over-attributes — confirm the record schema.
            orcids[a].update(rec_orcids)
            affs[a].update(rec_affs)
            types[a][rec_type] += 1
            if rec_year is not None:
                years[a].append(rec_year)

    rows = []
    for a, c in counts.most_common():
        if c >= red_threshold:
            flag = "red"
        elif c >= yellow_threshold:
            flag = "yellow"
        else:
            flag = "green"
        yrs = sorted(years[a])
        time_cluster = ""
        if len(yrs) >= 2:
            span = yrs[-1] - yrs[0]
            time_cluster = f"{yrs[0]}–{yrs[-1]} (span {span}y)"
        rows.append({
            "author": a,
            "count": c,
            "flag": flag,
            "orcid": sorted(orcids[a]),
            "affiliations": sorted(affs[a]),
            "type_breakdown": dict(types[a]),
            "time_cluster": time_cluster,
            "record_ids": [r.get("id") for r in by_author_records[a]],
        })
    return rows


def affiliation_cumulative(
    records: List[Dict[str, Any]],
    *,
    yellow_threshold: Optional[int] = None,
    red_threshold: Optional[int] = None,
) -> List[Dict[str, Any]]:
    """Aggregate records per affiliation and assign a green/yellow/red flag.

    Each record contributes at most once per affiliation, even if the same
    affiliation is listed several times on it. Without this, a paper whose
    authors all share one institution would inflate that institution's count.
    Empty or None affiliations are ignored. Rows are ordered by descending
    count; ties keep first-seen order.

    Args:
        records: records with an optional "affiliations" list and an "id".
        yellow_threshold: inclusive minimum count for "yellow"; defaults to
            YELLOW_FLAG_THRESHOLD.
        red_threshold: inclusive minimum count for "red"; defaults to
            RED_FLAG_THRESHOLD.

    Returns:
        One dict per affiliation with keys affiliation, count, flag and
        record_ids.
    """
    if yellow_threshold is None:
        yellow_threshold = YELLOW_FLAG_THRESHOLD
    if red_threshold is None:
        red_threshold = RED_FLAG_THRESHOLD

    counts: Counter = Counter()
    by_aff_records: Dict[str, List[Any]] = defaultdict(list)
    for r in records:
        rec_affs = r.get("affiliations") or []
        if isinstance(rec_affs, str):
            rec_affs = [rec_affs]
        # dict.fromkeys removes repeats within this record, preserving order.
        for af in dict.fromkeys(rec_affs):
            if not af:
                continue
            counts[af] += 1
            by_aff_records[af].append(r.get("id"))

    rows = []
    for af, c in counts.most_common():
        if c >= red_threshold:
            flag = "red"
        elif c >= yellow_threshold:
            flag = "yellow"
        else:
            flag = "green"
        rows.append({
            "affiliation": af,
            "count": c,
            "flag": flag,
            "record_ids": by_aff_records[af],
        })
    return rows


# A small mock map of guideline-committee members.
# Keys are guideline display names; values are author-name strings that
# guideline_member_overlap() compares, by exact string equality, against the
# "author" field of the rows it is given. Spelling variants such as
# "Kim JH" vs "Kim J-H" will therefore not match.
GUIDELINE_COMMITTEE_MEMBERS = {
    "KASL MASLD CPG (mock)": ["Kim J-H", "Park S-J", "Choi Y-W", "Lee M-K"],
    "AASLD MASH Guidance (mock)": ["Smith J", "Brown A", "Dupont L"],
    "EASL MASLD CPG (mock)": ["Dupont L", "Martin C", "Garcia M", "Rossi G"],
}


def guideline_member_overlap(author_rows: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
    """List, per mock guideline committee, the members who carry a yellow-or-worse flag.

    Args:
        author_rows: per-author rows with at least "author" and "count" keys.

    Returns:
        One {"guideline", "flagged_members"} dict per entry of
        GUIDELINE_COMMITTEE_MEMBERS, in that mapping's order. A member is
        included when their name is in the committee list and their count is
        at least YELLOW_FLAG_THRESHOLD.
    """
    return [
        {
            "guideline": committee,
            "flagged_members": [
                row
                for row in author_rows
                if row["author"] in roster and row["count"] >= YELLOW_FLAG_THRESHOLD
            ],
        }
        for committee, roster in GUIDELINE_COMMITTEE_MEMBERS.items()
    ]


# ---------- Feature 4: journal / publisher / reason distributions + timeseries ----------

def journal_distribution(records: List[Dict[str, Any]]) -> Counter:
    """Count records per journal.

    A missing, None or empty "journal" is bucketed as "Unknown". A plain
    ``.get(key, "Unknown")`` would let an explicit None through as its own key.
    """
    return Counter(r.get("journal") or "Unknown" for r in records)


def publisher_distribution(records: List[Dict[str, Any]]) -> Counter:
    """Count records per publisher.

    A missing, None or empty "publisher" is bucketed as "Unknown" instead of
    producing a separate None key.
    """
    return Counter(r.get("publisher") or "Unknown" for r in records)


def reason_distribution(records: List[Dict[str, Any]]) -> Counter:
    """Count records per retraction/correction reason.

    A missing, None or empty "reason" is bucketed as "Unknown" instead of
    producing a separate None key.
    """
    return Counter(r.get("reason") or "Unknown" for r in records)


def type_distribution(records: List[Dict[str, Any]]) -> Counter:
    """Count records per notice type (e.g. Retraction, Correction).

    A missing, None or empty "type" is bucketed as "Unknown" instead of
    producing a separate None key.
    """
    return Counter(r.get("type") or "Unknown" for r in records)


def source_distribution(records: List[Dict[str, Any]]) -> Counter:
    """Count records per upstream data source.

    A missing, None or empty "source" is bucketed as "Unknown" instead of
    producing a separate None key.
    """
    return Counter(r.get("source") or "Unknown" for r in records)


def yearly_timeseries(records: List[Dict[str, Any]]) -> List[Tuple[int, int]]:
    """Count records per year of "retraction_date", sorted by year.

    The year is taken from the first four characters of the stripped
    ``str()`` form of the date. ISO strings ("2021-03-01"), bare years (2021
    or "2021") and date objects all work. Values whose prefix is not an
    integer are skipped. Previously only ``ValueError``-free strings were
    counted, and every error, including programming errors, was silently
    swallowed.

    Returns:
        A list of (year, count) tuples in ascending year order.
    """
    yrs: Counter = Counter()
    for r in records:
        raw = r.get("retraction_date")
        if not raw:
            continue
        try:
            year = int(str(raw).strip()[:4])
        except ValueError:
            continue
        yrs[year] += 1
    return sorted(yrs.items())


def lag_distribution(records: List[Dict[str, Any]]) -> Dict[str, float]:
    """Summarise the numeric "lag_days" field of the records.

    Only real numeric lags are used. Booleans are excluded because they are a
    subclass of int but are not day counts. Non-finite floats (NaN/±inf,
    which ``json.load`` accepts by default) are excluded because they would
    poison the mean and make the sort order undefined.

    Returns:
        Dict with n, mean (rounded to 1 decimal), median (rounded to 1
        decimal; the average of the two middle values for even n), min and
        max. All statistics are 0.0 and n is 0 when no usable value exists.
    """
    lags: List[float] = []
    for r in records:
        value = r.get("lag_days")
        if isinstance(value, bool) or not isinstance(value, (int, float)):
            continue
        # Only floats can be non-finite; math.isfinite on a huge int would overflow.
        if isinstance(value, float) and not math.isfinite(value):
            continue
        lags.append(value)

    if not lags:
        return {"n": 0, "mean": 0.0, "median": 0.0, "min": 0.0, "max": 0.0}

    lags.sort()
    n = len(lags)
    mid = n // 2
    median = lags[mid] if n % 2 else (lags[mid - 1] + lags[mid]) / 2
    return {
        "n": n,
        "mean": round(sum(lags) / n, 1),
        "median": round(median, 1),
        "min": float(lags[0]),
        "max": float(lags[-1]),
    }


def journal_yearly_rate(records: List[Dict[str, Any]]) -> Dict[str, List[Tuple[int, int]]]:
    """Count records per journal and per year of "retraction_date".

    Years are parsed from the first four characters of the stripped
    ``str()`` form of the date. Records whose date is missing or unparseable
    are skipped entirely, so a journal with no parseable dates gets no entry.
    A missing, None or empty journal is bucketed as "Unknown".

    Returns:
        Mapping journal -> list of (year, count) tuples sorted by year.
    """
    bucket: Dict[str, Counter] = defaultdict(Counter)
    for r in records:
        raw = r.get("retraction_date")
        if not raw:
            continue
        try:
            year = int(str(raw).strip()[:4])
        except ValueError:
            continue
        bucket[r.get("journal") or "Unknown"][year] += 1
    return {j: sorted(ctr.items()) for j, ctr in bucket.items()}


# ---------- Feature 5: weekly digest + guideline sanity ----------

def recent_records(records: List[Dict[str, Any]], top_n: int = 15) -> List[Dict[str, Any]]:
    """Return at most ``top_n`` records, newest "retraction_date" first.

    Dates are compared as raw values (string order for ISO dates). Records
    without a date sort as "0000-00-00", i.e. last. Ties keep their input order.
    """
    newest_first = sorted(
        records,
        key=lambda rec: rec.get("retraction_date") or "0000-00-00",
        reverse=True,
    )
    return newest_first[:top_n]


def weekly_digest_markdown(records: List[Dict[str, Any]], top_n: int = 15) -> str:
    """Render the weekly digest as a Markdown document.

    Sections:
    - A type summary table and reason list for the ``top_n`` most recent records.
    - Red- and yellow-flag authors, computed over *all* records, at most five each.
    - A table of the recent records.
    - Fixed recommendations.

    Values placed into Markdown table cells have "|" escaped and line breaks
    replaced by spaces, so a title, journal or reason containing them cannot
    break the table. A missing or None value renders as an empty cell. The
    flag-section headings are built from RED_FLAG_THRESHOLD and
    YELLOW_FLAG_THRESHOLD, so they always describe the thresholds actually
    used.

    Args:
        records: all retraction records.
        top_n: how many of the most recent records to summarise.

    Returns:
        The digest as a single newline-joined string. The heading contains
        today's local date.
    """

    def _cell(value: Any) -> str:
        # "|" delimits Markdown table columns and a newline ends the row.
        text = "" if value is None else str(value)
        text = text.replace("\r\n", " ").replace("\n", " ").replace("\r", " ")
        return text.replace("|", "\\|")

    recent = recent_records(records, top_n=top_n)
    type_ct = type_distribution(recent)
    reason_ct = reason_distribution(recent)
    author_rows = author_cumulative(records)
    red = [a for a in author_rows if a["flag"] == "red"][:5]
    yellow = [a for a in author_rows if a["flag"] == "yellow"][:5]

    today = datetime.now().strftime("%Y-%m-%d")
    lines: List[str] = []
    lines.append(f"# MASLDRetractionWatch-Kor 주간 다이제스트 ({today})")
    lines.append("")
    lines.append(f"> {DISCLAIMER}")
    lines.append("")
    lines.append(f"## 요약 — 최근 {len(recent)}건 (총 DB {len(records)}건)")
    lines.append("")
    lines.append("| 유형 | 건수 |")
    lines.append("|---|---:|")
    for k, v in type_ct.most_common():
        lines.append(f"| {_cell(k)} | {v} |")
    lines.append("")
    lines.append("## 사유 분포 (최근)")
    lines.append("")
    for k, v in reason_ct.most_common():
        lines.append(f"- {k}: {v}")
    lines.append("")
    lines.append(f"## Red-flag author (누적 ≥ {RED_FLAG_THRESHOLD}건)")
    lines.append("")
    if not red:
        lines.append("- 없음")
    for a in red:
        entry = f"- **{a['author']}** — {a['count']}건, 소속: {', '.join(a['affiliations'])}"
        # Avoid a dangling ", " when fewer than two years are known.
        if a["time_cluster"]:
            entry += f", {a['time_cluster']}"
        lines.append(entry)
    lines.append("")
    lines.append(f"## Yellow-flag author (누적 {YELLOW_FLAG_THRESHOLD}-{RED_FLAG_THRESHOLD - 1}건)")
    lines.append("")
    if not yellow:
        lines.append("- 없음")
    for a in yellow:
        lines.append(f"- {a['author']} — {a['count']}건")
    lines.append("")
    lines.append("## 신규 / 최근 record")
    lines.append("")
    lines.append("| 날짜 | 유형 | 저널 | 제목 | 사유 |")
    lines.append("|---|---|---|---|---|")
    for r in recent:
        # Truncate before escaping so an escape sequence is never cut in half.
        title = (r.get("title", "") or "")[:80]
        lines.append(
            f"| {_cell(r.get('retraction_date'))} | {_cell(r.get('type'))} | "
            f"{_cell(r.get('journal'))} | {_cell(title)} | {_cell(r.get('reason'))} |"
        )
    lines.append("")
    lines.append("## 권고")
    lines.append("")
    lines.append("- Systematic review/meta-analysis 진행 중인 팀은 최근 retraction 반영 sensitivity analysis 권고.")
    lines.append("- Red-flag author의 다른 publication 추가 audit 검토.")
    lines.append("- KASL/AASLD/EASL 가이드라인 referencing 시, retraction status 사전 확인.")
    lines.append("")
    return "\n".join(lines)


def guideline_sanity_report(
    guidelines: List[Dict[str, Any]],
    records: List[Dict[str, Any]],
) -> List[Dict[str, Any]]:
    """For each guideline, list cited PMIDs that are now retracted/EoC/corrected.

    PMIDs are normalised to stripped strings on *both* sides. Previously the
    index kept the records' raw PMIDs while lookups used ``str(pmid)``, so
    records with integer PMIDs could never match. When several records share
    a PMID, the last one wins.

    Args:
        guidelines: dicts with "name", "society", "year" and an optional
            "cited_pmids" list (None is treated as empty).
        records: retraction records; reads "pmid", "type", "title",
            "journal", "reason" and "retraction_date".

    Returns:
        One dict per guideline with guideline, society, year, n_cited,
        n_affected and affected (one entry per matching cited PMID, with the
        PMID as cited).
    """

    def _pmid_key(value: Any) -> Optional[str]:
        if not value:
            return None
        key = str(value).strip()
        return key or None

    pmid_index: Dict[str, Dict[str, Any]] = {}
    for r in records:
        key = _pmid_key(r.get("pmid"))
        if key is not None:
            pmid_index[key] = r

    out = []
    for g in guidelines:
        cited = list(g.get("cited_pmids") or [])
        affected = []
        for pmid in cited:
            rec = pmid_index.get(_pmid_key(pmid))
            if rec:
                affected.append({
                    "pmid": pmid,
                    "type": rec.get("type"),
                    "title": rec.get("title"),
                    "journal": rec.get("journal"),
                    "reason": rec.get("reason"),
                    "retraction_date": rec.get("retraction_date"),
                })
        out.append({
            "guideline": g.get("name"),
            "society": g.get("society"),
            "year": g.get("year"),
            "n_cited": len(cited),
            "n_affected": len(affected),
            "affected": affected,
        })
    return out


def export_docx_report(
    digest_md: str,
    guideline_report: List[Dict[str, Any]],
    out_path: str,
) -> str:
    """Export the weekly digest and guideline sanity check to a .docx file.

    The digest Markdown is converted line by line using a deliberately small
    subset of Markdown:
    - "# " and "## " become headings.
    - "- " becomes a bullet.
    - "> " becomes a plain paragraph.
    - Pipe-table rows become plain paragraphs. Column-alignment rows such as
      "|---|---:|" are dropped.

    Blank lines are skipped rather than producing empty paragraphs. "**" bold
    markers and Markdown-escaped pipes ("\\|") are reduced to plain text.

    Args:
        digest_md: Markdown produced by weekly_digest_markdown().
        guideline_report: rows produced by guideline_sanity_report().
        out_path: destination .docx path (overwritten if it exists).

    Returns:
        ``out_path``.

    Raises:
        RuntimeError: if python-docx is not installed.
    """
    try:
        from docx import Document
    except ImportError as e:
        raise RuntimeError("python-docx not installed; add to requirements.txt") from e

    def _plain(text: str) -> str:
        # Strip inline Markdown that Word would otherwise show literally.
        return text.replace("**", "").replace("\\|", "|")

    def _is_table_separator(row: str) -> bool:
        # e.g. "|---|---:|" — column alignment, not content.
        return "-" in row and set(row) <= set("|-: ")

    doc = Document()
    doc.add_heading("MASLDRetractionWatch-Kor — 주간 리포트", level=0)
    doc.add_paragraph(DISCLAIMER)

    doc.add_heading("1. 주간 다이제스트", level=1)
    for line in digest_md.splitlines():
        stripped = line.strip()
        if not stripped:
            continue
        if line.startswith("# "):
            doc.add_heading(_plain(line[2:]), level=1)
        elif line.startswith("## "):
            doc.add_heading(_plain(line[3:]), level=2)
        elif line.startswith("- "):
            doc.add_paragraph(_plain(line[2:]), style="List Bullet")
        elif line.startswith("> "):
            doc.add_paragraph(_plain(line[2:]))
        elif stripped.startswith("|"):
            if _is_table_separator(stripped):
                continue
            doc.add_paragraph(_plain(stripped))
        else:
            doc.add_paragraph(_plain(line))

    doc.add_heading("2. KASL/AASLD/EASL 가이드라인 sanity check", level=1)
    for g in guideline_report:
        doc.add_heading(f"{g['society']} — {g['guideline']} ({g['year']})", level=2)
        doc.add_paragraph(
            f"인용 {g['n_cited']}건 중 retraction/correction/EoC 영향 {g['n_affected']}건."
        )
        for a in g["affected"]:
            doc.add_paragraph(
                f"PMID {a['pmid']} | {a['type']} | {a['journal']} | "
                f"{a['retraction_date']} | 사유: {a['reason']} — {a['title']}",
                style="List Bullet",
            )

    doc.save(out_path)
    return out_path


# ---------- summary ----------

def overall_summary(records: List[Dict[str, Any]]) -> Dict[str, Any]:
    """Collect headline statistics for the whole record set.

    Returns:
        Dict with the total count; per-type, per-source and per-reason counts;
        the five most frequent journals and publishers; lag statistics; and
        the yearly time series, in that key order.
    """
    summary: Dict[str, Any] = {"n_total": len(records)}
    summary["by_type"] = dict(type_distribution(records))
    summary["by_source"] = dict(source_distribution(records))
    summary["by_reason"] = dict(reason_distribution(records))
    summary["top_journals"] = dict(journal_distribution(records).most_common(5))
    summary["top_publishers"] = dict(publisher_distribution(records).most_common(5))
    summary["lag"] = lag_distribution(records)
    summary["timeseries"] = yearly_timeseries(records)
    return summary
