"""Lawlor, Tilling, Davey Smith 2016 (IJE) triangulation 5-criterion scoring.

Each criterion is scored 0 / 1 / 2 with brief rationale strings.
This is a rule-of-thumb operationalization; not a formal QUIPS/ROBINS scale.
"""
from __future__ import annotations

from typing import Any

from .grid import build_grid, direction_of, concordance_score

CRITERIA = [
    "internal_validity",
    "external_validity",
    "bias_direction_known",
    "bias_direction_differs",
    "triangulation_strength",
]

# How strong is each design's internal validity by default?
_INTERNAL_VALIDITY_BY_DESIGN = {
    "RCT": 2,
    "target-MR": 2,
    "within-subject": 1,
    "observational": 1,
    "ex vivo": 1,
}

# Per-design known bias direction template
BIAS_DIRECTION_BY_DESIGN = {
    "RCT": "internal validity strong; external validity may be limited (trial population, short follow-up)",
    "observational": "confounding by indication, healthy-user, detection bias typically inflate benefit",
    "target-MR": "pleiotropy / canalization may bias toward null; lifelong exposure differs from short-term Rx",
    "ex vivo": "dose translation problems; mechanism may not survive whole-organism PK/PD",
    "within-subject": "carryover, period effects; limited generalizability beyond participant",
}


def _score_internal_validity(grid: dict[str, list[dict[str, Any]]]) -> tuple[int, str]:
    """Return the best internal-validity tier available among *grid*'s designs.

    Each design key of *grid* is looked up in _INTERNAL_VALIDITY_BY_DESIGN.
    Unknown designs count as tier 0, and an empty grid scores 0.

    Returns:
        ``(score, rationale)`` where *score* is the highest tier found.
    """
    tiers = [_INTERNAL_VALIDITY_BY_DESIGN.get(design_name, 0) for design_name in grid]
    # Seed with 0 so an empty grid (or one with only unknown designs) yields 0.
    top = max([0, *tiers])
    return top, (
        f"strongest available design tier = {top} "
        f"(RCT/target-MR=2; observational/within-subject/ex vivo=1)"
    )


def _score_external_validity(pair_effects: list[dict[str, Any]]) -> tuple[int, str]:
    # Use population diversity: number of distinct population strings + observational presence
    pops = {(r.get("population") or "").strip() for r in pair_effects if r.get("population")}
    has_observational = any(r.get("design") == "observational" for r in pair_effects)
    score = 0
    if len(pops) >= 2:
        score += 1
    if has_observational:
        score += 1
    score = min(score, 2)
    rationale = f"{len(pops)} distinct populations; observational present={has_observational}"
    return score, rationale


def _score_bias_direction_known(grid: dict[str, list[dict[str, Any]]]) -> tuple[int, str]:
    """Score the share of present designs that have a bias-direction template.

    A design has a template when it is a key of BIAS_DIRECTION_BY_DESIGN.
    A share >= 0.8 scores 2, a share >= 0.4 scores 1, and anything lower
    scores 0. An empty grid scores 0.

    Returns:
        ``(score, rationale)`` with *score* in ``{0, 1, 2}``.
    """
    design_count = len(grid)
    templated = len([d for d in grid if d in BIAS_DIRECTION_BY_DESIGN])
    # Avoid dividing by zero; with no designs the share is 0 / 1 == 0.
    share = templated / (design_count or 1)
    # The first cutoff the share reaches decides the points; fall back to 0.
    score = next(
        (points for cutoff, points in ((0.8, 2), (0.4, 1)) if share >= cutoff),
        0,
    )
    return score, f"{templated}/{design_count} designs have a known bias direction template"


def _score_bias_direction_differs(grid: dict[str, list[dict[str, Any]]]) -> tuple[int, str]:
    # Strength of triangulation is higher when biases of present designs are
    # *expected to act in different directions*. We approximate by counting
    # how many bias families are represented across present designs.
    bias_families: set[str] = set()
    if "RCT" in grid:
        bias_families.add("internal-validity-strong")
    if "observational" in grid:
        bias_families.add("confounding-by-indication")
    if "target-MR" in grid:
        bias_families.add("pleiotropy")
    if "ex vivo" in grid:
        bias_families.add("dose-translation")
    if "within-subject" in grid:
        bias_families.add("carryover")
    n = len(bias_families)
    if n >= 4:
        score = 2
    elif n >= 2:
        score = 1
    else:
        score = 0
    rationale = f"{n} distinct bias families represented across designs"
    return score, rationale


def _score_triangulation_strength(pair_effects: list[dict[str, Any]]) -> tuple[int, str]:
    """Score triangulation strength as concordance weighted by design coverage.

    ``concordance_score`` (from ``.grid``) supplies ``"concordance"`` and
    ``"designs_present"``. The raw strength is
    ``concordance * (designs_present / 5)``. It scores 2 at >= 0.55, 1 at
    >= 0.25, and 0 otherwise.

    Returns:
        ``(score, rationale)`` with *score* in ``{0, 1, 2}``.
    """
    summary = concordance_score(pair_effects)
    concordance = summary["concordance"]
    n_designs = summary["designs_present"]
    # NOTE(review): the 5 presumably matches the five design types in
    # BIAS_DIRECTION_BY_DESIGN; confirm designs_present is a count in 0..5.
    raw = concordance * (n_designs / 5.0)
    score = 2 if raw >= 0.55 else (1 if raw >= 0.25 else 0)
    rationale = f"concordance={concordance} x design coverage={n_designs}/5 → raw={round(raw, 3)}"
    return score, rationale


def score_lawlor_criteria(pair_effects: list[dict[str, Any]]) -> dict[str, Any]:
    """Score *pair_effects* on the five triangulation criteria.

    Builds a design grid with ``build_grid`` and runs one scorer per entry of
    ``CRITERIA``. Each scorer returns a ``(score, rationale)`` pair.

    Args:
        pair_effects: Effect records for one comparison. This module reads
            the ``"design"`` and ``"population"`` keys directly. Any other
            keys are consumed by ``build_grid`` / ``concordance_score``
            (see ``.grid``).

    Returns:
        A dict with these keys:

        - ``"criteria"``: one ``{"name", "score", "max", "rationale"}`` dict
          per criterion, in ``CRITERIA`` order.
        - ``"total"``: the sum of the scores.
        - ``"max"``: 2 per criterion.
        - ``"normalized"``: ``total / max``, rounded to 3 decimals.

    Raises:
        RuntimeError: If ``CRITERIA`` and the scorer list below are out of sync.
    """
    per_criterion_max = 2
    grid = build_grid(pair_effects)

    # Scorer results, listed in CRITERIA order. Names are taken from CRITERIA
    # (the single source of truth) instead of a second hand-typed list.
    results = [
        _score_internal_validity(grid),
        _score_external_validity(pair_effects),
        _score_bias_direction_known(grid),
        _score_bias_direction_differs(grid),
        _score_triangulation_strength(pair_effects),
    ]
    # zip() would silently truncate on a mismatch; fail loudly instead.
    if len(results) != len(CRITERIA):
        raise RuntimeError(
            f"CRITERIA lists {len(CRITERIA)} names but {len(results)} scorers ran"
        )

    criteria = [
        {"name": name, "score": score, "max": per_criterion_max, "rationale": rationale}
        for name, (score, rationale) in zip(CRITERIA, results)
    ]
    total = sum(item["score"] for item in criteria)
    max_total = per_criterion_max * len(criteria)
    return {
        "criteria": criteria,
        "total": total,
        "max": max_total,
        "normalized": round(total / max_total, 3),
    }
