"""Novelty scoring for DKDPulse.

score = (1 + guideline_boost) * citation_velocity_proxy * coverage_bonus * source_weight

Records whose identifier is already in the seen set score 0. All components are
heuristic (offline) — no external ML or live citation API.
"""

import math

from . import vocab


def _guideline_boost(record, keywords=None):
    """Return an additive boost for records mentioning guideline keywords.

    Each keyword found as a case-insensitive substring of the record's title
    or abstract adds 0.3; the total is capped at 0.6.

    Args:
        record: dict; ``"title"`` and ``"abstract"`` are read, and missing or
            None values are treated as empty strings.
        keywords: iterable of keyword strings. Defaults to
            ``vocab.GUIDELINE_KEYWORDS``.

    Returns:
        float in ``[0.0, 0.6]``.
    """
    if keywords is None:
        keywords = vocab.GUIDELINE_KEYWORDS
    text = ((record.get("title") or "") + " " + (record.get("abstract") or "")).lower()
    # The text is lowercased, so the keyword must be too; otherwise a keyword
    # containing capitals (e.g. an acronym) could never match.
    hits = sum(1 for kw in keywords if kw.lower() in text)
    return min(0.6, 0.3 * hits)


def _citation_velocity_proxy(record):
    """Estimate citation velocity from a mock impact factor and article age.

    velocity = jif / max(days_since_pub, 1) * 100, rounded to 3 decimals.

    Args:
        record: dict; reads ``"jif"`` (journal impact factor surrogate) and
            ``"days_since_pub"``.

    Returns:
        float. Returns 0.5 when ``"jif"`` is missing, None, not numeric, or
        non-finite (NaN/inf). A missing or unparseable ``"days_since_pub"``
        (including NaN/inf) is treated as 30 days. Values below 1 are
        clamped to 1.
    """
    jif = record.get("jif")
    if jif is None:
        return 0.5
    try:
        jif = float(jif)
    except (TypeError, ValueError, OverflowError):
        # OverflowError: float() of an int too large to represent.
        return 0.5
    # NaN/inf (e.g. a missing cell loaded from a DataFrame) would otherwise
    # propagate into the score and leave the ranking sort order undefined.
    if not math.isfinite(jif):
        return 0.5
    days = record.get("days_since_pub", 30)
    try:
        days = max(1, int(days))
    except (TypeError, ValueError, OverflowError):
        # int(float("nan")) raises ValueError; int(float("inf")) raises OverflowError.
        days = 30
    return round(jif / days * 100, 3)


def _coverage_bonus(record):
    """Return a multiplier that grows with the number of tag dimensions present.

    Each of ``drug_classes``, ``outcomes`` and ``phenotypes`` that is present
    and truthy under ``record["tags"]`` adds 0.15. The result ranges from 1.0
    (no dimensions) to 1.45 (all three).
    """
    tag_map = record.get("tags") or {}
    covered = sum(
        1 for dim in ("drug_classes", "outcomes", "phenotypes") if tag_map.get(dim)
    )
    return 1.0 + 0.15 * covered


def score_record(record, seen_pmids):
    """Return the novelty score for a single record.

    The record identifier is the first truthy value among ``"pmid"``,
    ``"nct_id"`` and ``"doi"``, converted to ``str``. If it is in
    ``seen_pmids``, the record counts as already seen and scores 0.0.
    Otherwise::

        score = (1 + guideline_boost) * citation_velocity * coverage * source_weight

    rounded to 3 decimals.

    Args:
        record: dict describing one record.
        seen_pmids: container of identifier strings already seen, or None
            (treated as nothing seen).

    Returns:
        float.
    """
    record_id = str(record.get("pmid") or record.get("nct_id") or record.get("doi") or "")
    # None means "no history"; without this check the membership test
    # would raise TypeError.
    if record_id and seen_pmids is not None and record_id in seen_pmids:
        return 0.0

    guideline = _guideline_boost(record)
    velocity = _citation_velocity_proxy(record)
    coverage = _coverage_bonus(record)

    # Source-type weight: pubmed 1.0 > medrxiv (preprint) 0.8 > any other
    # source 0.7 > ctg (registry entries) 0.6.
    source = (record.get("source") or "").lower()
    src_w = {"pubmed": 1.0, "medrxiv": 0.8, "ctg": 0.6}.get(source, 0.7)

    score = (1 + guideline) * velocity * coverage * src_w
    return round(score, 3)


def rank(records, seen_pmids, top_n=10):
    """Score every record and return the highest-novelty ones first.

    Each record is shallow-copied, and the copy gets a ``"novelty"`` key
    holding ``score_record(record, seen_pmids)``. The input dicts are not
    modified.

    Args:
        records: iterable of record dicts.
        seen_pmids: identifiers already seen; forwarded to ``score_record``.
        top_n: maximum number of records to return. ``None`` or a value
            ``<= 0`` returns every scored record.

    Returns:
        List of annotated copies sorted by ``"novelty"`` descending. Ties
        keep input order because Python's sort is stable.
    """
    ranked = sorted(
        (dict(rec, novelty=score_record(rec, seen_pmids)) for rec in records),
        key=lambda item: item["novelty"],
        reverse=True,
    )
    keep_all = top_n is None or top_n <= 0
    return ranked if keep_all else ranked[:top_n]
