"""5-design triangulation grid + concordance metric.

Pure stdlib so the CLI does not depend on numpy/scipy.
"""
from __future__ import annotations

from typing import Any

# Canonical study-design labels; concordance_score walks them in this order
# when picking one representative effect per design.
DESIGNS = [
    "RCT",
    "observational",
    "target-MR",
    "ex vivo",
    "within-subject",
]


def direction_of(effect: dict[str, Any]) -> int:
    """Classify an effect as +1 (increase), -1 (decrease) or 0 (null).

    The null value is 1.0 on the supplied scale. Ratio measures use it
    directly. Curators must encode any other outcome (e.g. STEP/SURMOUNT
    percentage changes) so that 1.0 means "no effect".

    Rules, applied in order:
      - ``effect_estimate`` missing           -> 0
      - ``ci_low`` or ``ci_high`` missing     -> sign of (estimate - 1)
      - ``ci_low > 1``  (CI wholly above 1)   -> +1 (risk increase)
      - ``ci_high < 1`` (CI wholly below 1)   -> -1 (risk decrease / benefit)
      - otherwise the CI brackets 1           -> 0
    """
    estimate = effect.get("effect_estimate")
    low, high = effect.get("ci_low"), effect.get("ci_high")
    if estimate is None:
        return 0
    if low is None or high is None:
        # No interval available: fall back to the point estimate alone.
        if estimate > 1:
            return 1
        return -1 if estimate < 1 else 0
    if low > 1.0:
        return 1
    return -1 if high < 1.0 else 0


def _ci_overlap(a: dict[str, Any], b: dict[str, Any]) -> bool:
    """Do CIs of two effects overlap?"""
    a_lo, a_hi = a.get("ci_low"), a.get("ci_high")
    b_lo, b_hi = b.get("ci_low"), b.get("ci_high")
    if None in (a_lo, a_hi, b_lo, b_hi):
        return False
    return not (a_hi < b_lo or b_hi < a_lo)


def build_grid(pair_effects: list[dict[str, Any]]) -> dict[str, list[dict[str, Any]]]:
    """Bucket effect rows by their ``design`` field.

    Rows with no ``design`` key are filed under ``"unknown"``. Input order is
    kept within each bucket, and buckets appear in first-seen order. A design
    with no rows gets no key at all.
    """
    buckets: dict[str, list[dict[str, Any]]] = {}
    for row in pair_effects:
        key = row.get("design", "unknown")
        if key not in buckets:
            buckets[key] = []
        buckets[key].append(row)
    return buckets


def _mean(xs: list[float]) -> float:
    return sum(xs) / len(xs) if xs else 0.0


def concordance_score(pair_effects: list[dict[str, Any]]) -> dict[str, Any]:
    """Compute a concordance dict for a single (drug_class, outcome) pair.

    One representative effect is chosen per design in ``DESIGNS``: the row
    with the largest ``sample_size`` (a missing or falsy value counts as 0;
    on ties the first such row wins). Rows whose ``design`` is not listed in
    ``DESIGNS`` do not contribute.

    Returns:
      {
        'designs_present': int,                  # designs with a representative
        'direction_agreement': float in [0,1],   # share of designs in the modal direction
        'ci_overlap_fraction': float in [0,1],   # pairwise CI overlap rate (1.0 if one design)
        'concordance': float in [0,1],           # weighted composite
        'majority_direction': int (-1, 0, +1),   # 0 when no direction wins outright
      }
    """
    from itertools import combinations

    grid = build_grid(pair_effects)
    # build_grid never stores empty lists, so a truthy lookup means "present".
    reps: list[dict[str, Any]] = [
        max(grid[d], key=lambda r: r.get("sample_size") or 0)
        for d in DESIGNS
        if grid.get(d)
    ]

    n = len(reps)
    if n == 0:
        return {
            "designs_present": 0,
            "direction_agreement": 0.0,
            "ci_overlap_fraction": 0.0,
            "concordance": 0.0,
            "majority_direction": 0,
        }

    dirs = [direction_of(r) for r in reps]
    counts = {k: dirs.count(k) for k in (-1, 0, 1)}
    top = max(counts.values())
    leaders = [k for k, c in counts.items() if c == top]
    # A tie for the top count means no direction wins outright. Report null (0)
    # instead of letting dict order silently prefer -1 over 0 over +1.
    majority = leaders[0] if len(leaders) == 1 else 0
    direction_agreement = top / n

    # Pairwise CI overlap across representatives; a lone design is trivially
    # consistent with itself.
    if n == 1:
        overlap_frac = 1.0
    else:
        pairs = list(combinations(reps, 2))
        overlapping = sum(1 for a, b in pairs if _ci_overlap(a, b))
        overlap_frac = overlapping / len(pairs)

    # Coverage bonus: fraction of the canonical designs represented.
    coverage = n / len(DESIGNS)

    # Weighted composite: heavy on direction agreement, then CI overlap, then
    # coverage. The weights sum to 1.0; min() guards against float round-up.
    score = 0.55 * direction_agreement + 0.30 * overlap_frac + 0.15 * coverage

    return {
        "designs_present": n,
        "direction_agreement": round(direction_agreement, 3),
        "ci_overlap_fraction": round(overlap_frac, 3),
        "concordance": round(min(score, 1.0), 3),
        "majority_direction": majority,
    }


def grid_summary(effects: list[dict[str, Any]]) -> list[dict[str, Any]]:
    """Summarize every (drug_class, outcome) pair in *effects*, one row per pair.

    Pairs are enumerated with ``ontology.list_pairs``, in its order. Each row
    holds the pair identifiers, selected fields from ``concordance_score``,
    and ``n_effects``, the number of rows ``ontology.filter_pair`` returned
    for that pair.
    """
    from .ontology import list_pairs, filter_pair

    copied_fields = (
        "designs_present",
        "concordance",
        "majority_direction",
        "direction_agreement",
    )
    rows: list[dict[str, Any]] = []
    for drug_class, outcome in list_pairs(effects):
        subset = filter_pair(effects, drug_class, outcome)
        score = concordance_score(subset)
        row: dict[str, Any] = {"drug_class": drug_class, "outcome": outcome}
        for field in copied_fields:
            row[field] = score[field]
        row["n_effects"] = len(subset)
        rows.append(row)
    return rows
