"""Lawlor 2016 5-criterion triangulation scoring + Korean ancestry layer."""
from __future__ import annotations

import math
from typing import Dict, List, Optional

from .ontology import DESIGN_ORDER

# Lawlor DA et al, IJE 2016 — triangulation criteria

# The five triangulation criteria as (criterion key, short description) pairs.
# The keys match those returned under "criteria" by score_pair. The
# descriptions are in Korean; English equivalents:
#   multiple_designs      - at least 2 design types used
#   orthogonal_bias       - bias structures across designs are orthogonal
#   convergent_direction  - effect direction is consistent
#   convergent_magnitude  - effect size is consistent (order of magnitude)
#   explicit_assumptions  - the key assumptions of each design are stated
CRITERIA = [
    ("multiple_designs", "≥2 design type 사용"),
    ("orthogonal_bias", "design 간 bias 구조가 직교"),
    ("convergent_direction", "효과 방향 일관"),
    ("convergent_magnitude", "효과 크기 일관 (order of magnitude)"),
    ("explicit_assumptions", "각 design 의 핵심 가정 명시"),
]


def _direction(effect: Optional[float]) -> Optional[int]:
    if effect is None:
        return None
    if effect > 1.10:
        return 1
    if effect < 0.90:
        return -1
    return 0


# Heuristic bias family each design is exposed to, used by the orthogonal_bias
# criterion. Designs not listed here contribute no family.
_BIAS_FAMILY: Dict[str, str] = {
    "observational": "residual_confounding",
    "MR": "pleiotropy",
    "RCT": "internal_validity_high",
    "ex_vivo_pcls": "external_validity",
    "within_subject_lifestyle": "compliance_adherence",
}


def _usable_effect(entry: Dict) -> Optional[float]:
    """Return ``entry["effect"]`` as a float, or ``None`` if it is unusable.

    Missing, non-numeric (including ``bool``) and NaN effects are treated as
    absent, so the direction and magnitude criteria filter effects the same way.
    """
    effect = entry.get("effect")
    if isinstance(effect, bool) or not isinstance(effect, (int, float)):
        return None
    if math.isnan(effect):
        return None
    return float(effect)


def score_pair(summary: Dict[str, Dict]) -> Dict[str, object]:
    """Score one (stage, outcome) summary by the Lawlor 5 criteria.

    Args:
        summary: mapping from design name to a per-design summary dict. Each
            entry is read for ``"n"`` (a design counts as having data when
            ``n > 0``) and ``"effect"`` (point estimate; presumably a ratio
            measure with null = 1, given the 0.90/1.10 direction band and the
            log-scale magnitude check -- confirm with callers). Designs of
            ``DESIGN_ORDER`` that are absent, or whose entry is empty/``None``,
            are treated as having no data.

    Returns:
        dict with ``"criteria"`` (each criterion scored 0 or 1), ``"total"``
        (sum of the criteria), ``"n_designs"`` (number of designs with data)
        and ``"max"`` (always 5).
    """
    entries = {d: summary.get(d) or {} for d in DESIGN_ORDER}
    designs_with_data = [
        d for d in DESIGN_ORDER if (entries[d].get("n") or 0) > 0
    ]
    n_designs = len(designs_with_data)
    effects = [
        e
        for e in (_usable_effect(entries[d]) for d in designs_with_data)
        if e is not None
    ]

    # 1. multiple_designs
    c1 = int(n_designs >= 2)

    # 2. orthogonal_bias (heuristic): at least two distinct bias families are
    # spanned by the designs with data.
    bias_families = {
        _BIAS_FAMILY[d] for d in designs_with_data if d in _BIAS_FAMILY
    }
    c2 = int(len(bias_families) >= 2)

    # 3. convergent_direction: a single estimate cannot "converge", so require
    # >= 2 estimates (as criterion 4 does). A null (0) next to a +/-1 still
    # counts as convergent; only opposing directions (+1 and -1) break it.
    directions = {_direction(e) for e in effects}
    c3 = int(len(effects) >= 2 and not {1, -1} <= directions)

    # 4. convergent_magnitude: the log-scale span of the positive estimates is
    # < 2, i.e. the largest/smallest ratio is < e**2 (~7.4x).
    positive = [e for e in effects if e > 0]
    if len(positive) >= 2:
        logs = [math.log(v) for v in positive]
        c4 = int(max(logs) - min(logs) < 2.0)
    else:
        c4 = 0

    # 5. explicit_assumptions: awarded whenever at least one design has data,
    # on the premise that each row carries its source and design.
    c5 = int(n_designs >= 1)

    total = c1 + c2 + c3 + c4 + c5
    return {
        "criteria": {
            "multiple_designs": c1,
            "orthogonal_bias": c2,
            "convergent_direction": c3,
            "convergent_magnitude": c4,
            "explicit_assumptions": c5,
        },
        "total": total,
        "n_designs": n_designs,
        "max": 5,
    }


def korean_ancestry_layer_score(rows: List[Dict[str, object]]) -> Dict[str, object]:
    """Korean ancestry layer score for triangulation external validity.

    Args:
        rows: evidence rows; each is read for ``"ancestry"`` and ``"design"``.
            Ancestry labels are matched case-insensitively after stripping
            surrounding whitespace and treating spaces/hyphens as underscores,
            so ``" Korean "`` matches ``"korean"`` and ``"East Asian"`` /
            ``"east-asian"`` match ``"east_asian"``.

    Returns:
        dict with ``n_korean_rows``, ``n_asian_rows`` (rows labelled Korean,
        Asian or East Asian), ``korean_designs_covered`` (sorted distinct,
        non-empty designs among Korean rows) and
        ``external_validity_for_korean``, graded by the number of covered
        designs: >= 3 STRONG, 2 MODERATE, 1 WEAK, 0 ABSENT.
    """
    asian_labels = frozenset(("korean", "asian", "east_asian"))

    def _norm(label: object) -> str:
        # str() tolerates non-string labels; a falsy label becomes "".
        text = str(label or "").strip().lower()
        return text.replace("-", "_").replace(" ", "_")

    korean_rows = [r for r in rows if _norm(r.get("ancestry")) == "korean"]
    n_asian = sum(1 for r in rows if _norm(r.get("ancestry")) in asian_labels)

    # Ignore missing/empty designs so that the grade is based on exactly the
    # designs reported in korean_designs_covered.
    covered = sorted({r.get("design") for r in korean_rows if r.get("design")})
    n_covered = len(covered)
    if n_covered >= 3:
        grade = "STRONG"
    elif n_covered == 2:
        grade = "MODERATE"
    elif n_covered == 1:
        grade = "WEAK"
    else:
        grade = "ABSENT"

    return {
        "n_korean_rows": len(korean_rows),
        "n_asian_rows": n_asian,
        "korean_designs_covered": covered,
        "external_validity_for_korean": grade,
    }
