"""Keyword + light regex 3D tagger: drug × outcome × phenotype.

No external ML model. Pure stdlib. Operates on a record dict with at least
{title, abstract} string fields.
"""

import re
from . import vocab


# Keyword lookup tables, one per tagging dimension, built once at import time
# from the vocab module. _scan() iterates each table with .items(), treating
# every key as a keyword to search for and every value as the label to emit.
# NOTE(review): classify() lowercases the text before scanning, so keywords
# presumably need to be lowercase to ever match — confirm against vocab.
_DRUG_MAP = vocab.all_drug_keywords()
_OUTCOME_MAP = vocab.all_outcome_keywords()
_PHENOTYPE_MAP = vocab.all_phenotype_keywords()


def _normalize(text):
    if not text:
        return ""
    # Collapse whitespace, lowercase, strip non-essential punctuation that breaks token edges.
    t = text.lower()
    t = re.sub(r"\s+", " ", t)
    return t


def _scan(text, keyword_map):
    """Return set of labels whose any keyword appears as a substring/word in text."""
    found = set()
    for kw, label in keyword_map.items():
        # Bound short keywords (<=4 chars) to word boundaries to reduce noise.
        if len(kw) <= 4:
            pattern = r"(?<![a-z0-9])" + re.escape(kw) + r"(?![a-z0-9])"
            if re.search(pattern, text):
                found.add(label)
        else:
            if kw in text:
                found.add(label)
    return found


def classify(record):
    """Tag one record along the drug, outcome and phenotype dimensions.

    record: mapping exposing .get(); its "title" and "abstract" values are
    read (a missing or falsy value counts as empty text) and every other key
    is ignored.

    Returns a dict with keys "drug_classes", "outcomes" and "phenotypes",
    each holding a sorted list of the labels matched in the combined,
    normalized title + abstract text.
    """
    title = record.get("title") or ""
    abstract = record.get("abstract") or ""
    # The separator keeps the last title word and the first abstract word
    # from fusing into one token.
    haystack = _normalize(title + " \n " + abstract)

    # Run all three scans before sorting anything.
    drug_hits, outcome_hits, phenotype_hits = [
        _scan(haystack, table)
        for table in (_DRUG_MAP, _OUTCOME_MAP, _PHENOTYPE_MAP)
    ]
    return {
        "drug_classes": sorted(drug_hits),
        "outcomes": sorted(outcome_hits),
        "phenotypes": sorted(phenotype_hits),
    }


def annotate(records):
    """Return a new list holding a shallow copy of each record plus its tags.

    Each copy carries a "tags" key set to the result of classify() for that
    record (replacing any existing "tags" value in the copy). The input
    records themselves are not modified.
    """
    # classify(rec) is evaluated before dict() builds the copy, so tags are
    # always computed from the original record.
    return [dict(rec, tags=classify(rec)) for rec in records]
