"""5-design effect grid + concordance score (stdlib only)."""
from __future__ import annotations

import math
from typing import Dict, List, Optional, Tuple

from .ontology import DESIGN_ORDER, unique_pairs


def _direction(effect: Optional[float]) -> Optional[str]:
    """log-scale → +1 / -1 / 0; effect on ratio scale (>1 increase, <1 decrease)."""
    if effect is None or not isinstance(effect, (int, float)):
        return None
    if effect > 1.10:
        return "+"
    if effect < 0.90:
        return "-"
    return "0"


def _ci_excludes_null(ci_low: Optional[float], ci_high: Optional[float]) -> bool:
    if ci_low is None or ci_high is None:
        return False
    return (ci_low > 1.0) or (ci_high < 1.0)


def _real_or_none(value: object) -> Optional[float]:
    """Return *value* unchanged if it is a real, non-NaN int/float; otherwise None."""
    # bool is a subclass of int but is never a meaningful numeric estimate.
    if isinstance(value, bool) or not isinstance(value, (int, float)):
        return None
    if math.isnan(value):
        return None
    return value


def design_summary(effects: List[Dict[str, object]], stage: str, outcome: str,
                   *, require_all_significant: bool = False) -> Dict[str, Dict]:
    """Summarise the evidence for one (stage, outcome) pair, design by design.

    For every design in ``DESIGN_ORDER``, rows of *effects* are matched on
    ``masld_stage == stage``, ``outcome == outcome`` and ``design == design``.
    The matching rows are collapsed into a dict with these keys:

    - ``n``: number of matching rows.
    - ``effect``: ``exp`` of the unweighted arithmetic mean of the rows'
      numeric ``log_effect`` values, or None when no row has one.
    - ``ci_low`` / ``ci_high``: the envelope of the rows' numeric CI bounds
      (smallest low, largest high). This is not a pooled confidence interval.
    - ``direction``: ``_direction(effect)``.
    - ``significant``: whether the rows' CIs exclude the null (see below).
    - ``rows``: the matching rows themselves, in input order.

    Designs with no matching rows get an empty entry with ``n == 0``.
    Non-numeric or NaN ``log_effect`` / CI values are ignored rather than
    raising.

    Args:
        effects: effect rows (dicts).
        stage: value compared with each row's ``masld_stage``.
        outcome: value compared with each row's ``outcome``.
        require_all_significant: if False (the default), a design is
            significant when ANY of its rows has a CI excluding the null. If
            True, ALL of its rows must.

    Returns:
        Mapping of design name to summary dict, in ``DESIGN_ORDER`` order.
    """
    # Narrow to this (stage, outcome) once instead of rescanning every effect per design.
    matching = [e for e in effects
                if e.get("masld_stage") == stage and e.get("outcome") == outcome]
    out: Dict[str, Dict] = {}
    for d in DESIGN_ORDER:
        rows = [e for e in matching if e.get("design") == d]
        if not rows:
            out[d] = {"n": 0, "effect": None, "ci_low": None, "ci_high": None,
                      "direction": None, "significant": False, "rows": []}
            continue
        # Naive pooling: an UNWEIGHTED mean on the log scale, back-transformed.
        # It is not inverse-variance weighted.
        log_eff = [x for x in (_real_or_none(r.get("log_effect")) for r in rows)
                   if x is not None]
        eff: Optional[float] = None
        if log_eff:
            try:
                eff = math.exp(sum(log_eff) / len(log_eff))
            except OverflowError:  # mean log effect beyond ~709
                eff = math.inf
        # Bad bounds become None, so they are skipped for the envelope and count
        # as "not significant" in the per-row test below.
        bounds = [(_real_or_none(r.get("ci_low")), _real_or_none(r.get("ci_high")))
                  for r in rows]
        ci_lows = [lo for lo, _ in bounds if lo is not None]
        ci_highs = [hi for _, hi in bounds if hi is not None]
        sigs = [_ci_excludes_null(lo, hi) for lo, hi in bounds]
        significant = all(sigs) if require_all_significant else any(sigs)
        out[d] = {
            "n": len(rows),
            "effect": eff,
            "ci_low": min(ci_lows) if ci_lows else None,
            "ci_high": max(ci_highs) if ci_highs else None,
            "direction": _direction(eff),
            "significant": significant,
            "rows": rows,
        }
    return out


def concordance_score(summary: Dict[str, Dict]) -> Dict[str, object]:
    """Score how consistently the designs in *summary* agree on effect direction.

    Only designs whose ``direction`` is not None take part. The score is the
    share of those designs that hold the most common direction ("+", "-" or
    "0"), rounded to 3 decimals. Labels:

    - ``HIGH_CONCORDANCE``: at least 3 designs and score >= 0.75.
    - ``MODERATE_CONCORDANCE``: at least 3 designs and score >= 0.50, unless
      the top count is a tie between "+" and "-".
    - ``DISCORDANT``: at least 2 designs, with both "+" and "-" present.
    - ``LIMITED_DATA``: anything else.
    - ``NO_DATA``: no design has a direction (score is None).

    Returns:
        dict with ``score``, ``label``, ``n_designs``, the counts ``pos`` /
        ``neg`` / ``null``, and ``directions``. ``directions`` maps every
        design in *summary* to its direction; it is ``{}`` when NO_DATA.
    """
    dirs = [v["direction"] for v in summary.values() if v["direction"] is not None]
    n_designs = len(dirs)
    if n_designs == 0:
        # Carry the same count keys as the data case so callers can index uniformly.
        return {"score": None, "label": "NO_DATA", "n_designs": 0,
                "pos": 0, "neg": 0, "null": 0, "directions": {}}
    pos = dirs.count("+")
    neg = dirs.count("-")
    null = dirs.count("0")
    # concordance = share of the most common direction
    majority = max(pos, neg, null)
    score = majority / n_designs
    # An equal split between opposite directions at the top (e.g. 2 "+" vs 2 "-")
    # has no majority at all; it is discordance, not moderate concordance.
    opposed_tie = pos > 0 and pos == neg == majority

    if n_designs >= 3 and score >= 0.75:
        label = "HIGH_CONCORDANCE"
    elif n_designs >= 3 and score >= 0.50 and not opposed_tie:
        label = "MODERATE_CONCORDANCE"
    elif n_designs >= 2 and pos > 0 and neg > 0:
        label = "DISCORDANT"
    else:
        label = "LIMITED_DATA"

    return {
        "score": round(score, 3),
        "label": label,
        "n_designs": n_designs,
        "pos": pos, "neg": neg, "null": null,
        "directions": {d: summary[d]["direction"] for d in summary},
    }


def build_grid(effects: List[Dict[str, object]]) -> List[Dict[str, object]]:
    """Assemble one grid cell per (stage, outcome) pair found in *effects*.

    Each cell holds the pair, its per-design summary (``design_summary``) and
    its concordance verdict (``concordance_score``). Cells are ordered by
    label priority: DISCORDANT, then HIGH, MODERATE, LIMITED_DATA, NO_DATA,
    with any unknown label last. Ties are broken by descending number of
    designs with a direction. The sort is stable, so equal cells keep their
    ``unique_pairs`` order.
    """
    priority = {"DISCORDANT": 0, "HIGH_CONCORDANCE": 1, "MODERATE_CONCORDANCE": 2,
                "LIMITED_DATA": 3, "NO_DATA": 4}

    def _cell(stage, outcome):
        # Summarise first, then score that summary.
        designs = design_summary(effects, stage, outcome)
        return {
            "stage": stage,
            "outcome": outcome,
            "designs": designs,
            "concordance": concordance_score(designs),
        }

    def _order(cell):
        verdict = cell["concordance"]
        return (priority.get(verdict["label"], 9), -verdict["n_designs"])

    cells = [_cell(stage, outcome) for stage, outcome in unique_pairs(effects)]
    return sorted(cells, key=_order)


def discordant_top(grid: List[Dict[str, object]], top_n: int = 5) -> List[Dict[str, object]]:
    """Return up to *top_n* cells of *grid* labelled DISCORDANT, in grid order.

    Args:
        grid: cells whose ``["concordance"]["label"]`` is read.
        top_n: maximum number of cells to return. A zero or negative value
            yields an empty list. Plain slicing with a negative bound would
            instead drop cells from the end.

    Returns:
        A new list of the selected cells (the cell dicts themselves, not copies).
    """
    limit = max(top_n, 0)
    picked: List[Dict[str, object]] = []
    for g in grid:
        if len(picked) >= limit:
            break  # enough found; no need to scan the rest of the grid
        if g["concordance"]["label"] == "DISCORDANT":
            picked.append(g)
    return picked


def grid_as_table(grid: List[Dict[str, object]]) -> List[List[str]]:
    """Flatten *grid* cells into a list of printable string rows.

    The first row is the header. Each following row holds the stage, the
    outcome truncated to 24 characters, the direction for each of the five
    designs, the concordance score and the concordance label. A missing
    direction, a design absent from the cell, or a None score is shown as "·".
    """
    # (design key in cell["designs"], column header). This single table drives both
    # the header row and the cell lookups, so the two cannot drift apart.
    columns: Tuple[Tuple[str, str], ...] = (
        ("observational", "obs"),
        ("MR", "MR"),
        ("RCT", "RCT"),
        ("ex_vivo_pcls", "PCLS"),
        ("within_subject_lifestyle", "lifestyle"),
    )
    missing = "·"
    rows: List[List[str]] = [
        ["stage", "outcome", *(header for _, header in columns), "concord", "label"]
    ]
    for g in grid:
        ds = g["designs"]
        directions = [(ds.get(key) or {}).get("direction") or missing for key, _ in columns]
        score = g["concordance"]["score"]
        rows.append([
            str(g["stage"]),
            str(g["outcome"])[:24],
            *directions,
            # `is None` rather than `or` so a genuine 0.0 score is still printed.
            missing if score is None else str(score),
            g["concordance"]["label"],
        ])
    return rows
