"""2-reader paired blinded adjudication workflow and drift dashboard.

Provides:
- deterministic (seeded), workload-balanced assignment of two blinded readers per event
- discordance detection + 3rd reader / panel routing
- Cohen's kappa overall and per-quarter
- adjudicator drift (label distribution by quarter)
- workload balancing and calibration-round scheduler
- turnaround statistics

The module depends only on the standard library (no pandas), so it stays
importable even when pandas is not installed (the CLI demo will degrade gracefully).
"""

from __future__ import annotations

import csv
import random
from collections import Counter, defaultdict
from dataclasses import dataclass, field
from datetime import datetime
from typing import Any, Dict, Iterable, List, Optional, Tuple


# ---------------------------------------------------------------------------
# Data classes
# ---------------------------------------------------------------------------


@dataclass
class AdjudicatorJudgement:
    """A single adjudicator's label for one event.

    Use ``from_row`` to build an instance from a CSV row such as those yielded
    by ``csv.DictReader``.
    """

    event_id: str
    adjudicator_id: str
    label: str
    quarter: str  # e.g. "2025Q1"; "Q?" when the source row has no quarter
    # Turnaround time, presumably in hours (field suffix "_h"); 24.0 if missing.
    turnaround_h: float = 24.0

    @classmethod
    def from_row(cls, row: Dict[str, str]) -> "AdjudicatorJudgement":
        """Build a judgement from a mapping of column name -> cell value.

        ``event_id``, ``adjudicator_id`` and ``label`` are required. ``quarter``
        falls back to ``"Q?"`` and ``turnaround_h`` to 24.0 when the column is
        absent, empty (or whitespace-only for ``turnaround_h``), or ``None``.
        ``csv.DictReader`` fills missing trailing fields with ``None`` by default.

        Raises:
            KeyError: if a required column is missing.
            ValueError: if ``turnaround_h`` is present but not numeric.
        """
        # `or` (not a .get default) so that "" and None also fall back,
        # consistent with how turnaround_h is treated below.
        quarter = row.get("quarter") or "Q?"

        raw_turnaround = row.get("turnaround_h")
        if isinstance(raw_turnaround, str):
            # A whitespace-only cell is "missing", not a parse error.
            raw_turnaround = raw_turnaround.strip()
        turnaround_h = float(raw_turnaround) if raw_turnaround else 24.0

        return cls(
            event_id=row["event_id"],
            adjudicator_id=row["adjudicator_id"],
            label=row["label"],
            quarter=quarter,
            turnaround_h=turnaround_h,
        )


@dataclass
class PairedDecision:
    """Outcome of comparing two readers' labels for a single event."""

    event_id: str
    reader_a: str  # adjudicator_id of the first reader
    label_a: str
    reader_b: str  # adjudicator_id of the second reader
    label_b: str
    concordant: bool  # whether label_a and label_b agree
    # Routing for discordant pairs: "3rd_reader", "panel", or None.
    routed_to: Optional[str] = field(default=None)


# ---------------------------------------------------------------------------
# Loaders
# ---------------------------------------------------------------------------


def load_judgements_csv(path: str) -> List[AdjudicatorJudgement]:
    """Read a header-row CSV file and return one judgement per data row.

    The file is opened as UTF-8 with ``newline=""`` (as the csv module
    expects), and each row dict is converted with
    ``AdjudicatorJudgement.from_row``.
    """
    with open(path, newline="", encoding="utf-8") as handle:
        return [AdjudicatorJudgement.from_row(record)
                for record in csv.DictReader(handle)]


# ---------------------------------------------------------------------------
# Assignment
# ---------------------------------------------------------------------------


def assign_paired_readers(
    event_ids: List[str],
    adjudicators: List[str],
    seed: int = 7,
) -> Dict[str, Tuple[str, str]]:
    """Assign two distinct adjudicators to every event, balancing workload.

    Events are processed in order. For each one, the two adjudicators with
    the fewest assignments so far are chosen. Ties are broken by a seeded
    shuffle of the adjudicator pool, so the same inputs always yield the
    same plan.

    Args:
        event_ids: events to assign, processed in input order.
        adjudicators: candidate adjudicator ids; duplicate ids are ignored.
        seed: seed for the tie-breaking shuffle.

    Returns:
        Mapping ``event_id -> (reader_a, reader_b)`` with
        ``reader_a != reader_b``.

    Raises:
        ValueError: if fewer than 2 distinct adjudicators are supplied.
    """
    # Order-preserving dedupe so nobody can be paired with themselves. For
    # already-distinct input this is the identity, so the shuffle (and plan)
    # is exactly what it was before.
    rotation = list(dict.fromkeys(adjudicators))
    if len(rotation) < 2:
        raise ValueError("Need at least 2 adjudicators")
    rng = random.Random(seed)
    rng.shuffle(rotation)
    # Precompute tie-break ranks; list.index() inside the sort key is O(k).
    position = {adj: i for i, adj in enumerate(rotation)}

    workload: Counter = Counter()
    plan: Dict[str, Tuple[str, str]] = {}
    for eid in event_ids:
        # Two least-loaded adjudicators, ties broken by shuffled position.
        a, b = sorted(rotation, key=lambda adj: (workload[adj], position[adj]))[:2]
        plan[eid] = (a, b)
        workload[a] += 1
        workload[b] += 1
    return plan


def schedule_calibration(
    event_ids: List[str], interval: int = 20
) -> List[str]:
    """Select every ``interval``-th event (counting from 1) as a calibration case.

    With the default interval of 20, the events at positions 20, 40, 60, ...
    are returned, in input order.
    """
    picked: List[str] = []
    for position, eid in enumerate(event_ids, start=1):
        if position % interval == 0:
            picked.append(eid)
    return picked


# ---------------------------------------------------------------------------
# Pairing / discordance
# ---------------------------------------------------------------------------


def pair_judgements(judgements: Iterable[AdjudicatorJudgement]
                    ) -> List[PairedDecision]:
    """Pair the first two distinct readers of each event; route discordant pairs.

    Judgements are grouped by ``event_id`` in input order. Reader A is the
    event's first judgement. Reader B is the first later judgement made by a
    *different* adjudicator. Events that never reach two distinct
    adjudicators are skipped, because pairing a reader with themselves would
    count as agreement and inflate concordance and kappa.

    A discordant pair is routed to ``"3rd_reader"`` when the event also has a
    judgement from an adjudicator other than A and B, and to ``"panel"``
    otherwise. Concordant pairs keep ``routed_to=None``.

    Returns:
        One PairedDecision per pairable event, in first-seen event order.
    """
    by_event: Dict[str, List[AdjudicatorJudgement]] = defaultdict(list)
    for j in judgements:
        by_event[j.event_id].append(j)

    out: List[PairedDecision] = []
    for eid, lst in by_event.items():
        first = lst[0]
        second = next(
            (j for j in lst[1:] if j.adjudicator_id != first.adjudicator_id),
            None,
        )
        if second is None:
            # Only one adjudicator (possibly repeated) judged this event.
            continue
        concordant = first.label == second.label
        routed: Optional[str] = None
        if not concordant:
            pair_ids = (first.adjudicator_id, second.adjudicator_id)
            has_third = any(j.adjudicator_id not in pair_ids for j in lst)
            routed = "3rd_reader" if has_third else "panel"
        out.append(PairedDecision(
            event_id=eid,
            reader_a=first.adjudicator_id, label_a=first.label,
            reader_b=second.adjudicator_id, label_b=second.label,
            concordant=concordant, routed_to=routed,
        ))
    return out


# ---------------------------------------------------------------------------
# Cohen's kappa (no sklearn dependency)
# ---------------------------------------------------------------------------


def cohens_kappa(rater_a: List[str], rater_b: List[str]) -> float:
    """Return Cohen's kappa between two equal-length lists of category labels.

    kappa = (p_o - p_e) / (1 - p_e), where p_o is the observed agreement rate
    and p_e the agreement expected from each rater's marginal label
    frequencies. Returns 0.0 for empty or length-mismatched input, and 1.0
    when p_e == 1 (both raters used a single, identical label throughout).
    """
    n = len(rater_a)
    if n == 0 or n != len(rater_b):
        return 0.0

    agreements = sum(1 for left, right in zip(rater_a, rater_b) if left == right)
    observed = agreements / n

    freq_a = Counter(rater_a)
    freq_b = Counter(rater_b)
    categories = sorted(set(rater_a) | set(rater_b))
    expected = sum((freq_a[cat] / n) * (freq_b[cat] / n) for cat in categories)

    if expected == 1:
        return 1.0
    return (observed - expected) / (1 - expected)


# ---------------------------------------------------------------------------
# Drift & turnaround
# ---------------------------------------------------------------------------


def label_distribution_by_quarter(
    judgements: Iterable[AdjudicatorJudgement]
) -> Dict[str, Counter]:
    """Count how often each label was assigned in each quarter.

    Returns a plain dict ``quarter -> Counter(label -> count)``. Quarters
    appear in the order they are first seen.
    """
    tally: Dict[str, Counter] = {}
    for judgement in judgements:
        bucket = tally.get(judgement.quarter)
        if bucket is None:
            bucket = tally[judgement.quarter] = Counter()
        bucket[judgement.label] += 1
    return tally


def kappa_by_quarter(paired: List[PairedDecision],
                     judgements: List[AdjudicatorJudgement]) -> Dict[str, float]:
    """Compute Cohen's kappa separately for the paired decisions of each quarter.

    Each event is assigned the quarter of its *last* judgement in
    ``judgements``. Paired events with no judgement fall under ``"Q?"``.
    Quarters appear in the order they are first seen in ``paired``.
    """
    quarter_of: Dict[str, str] = {}
    for judgement in judgements:
        # Later judgements overwrite earlier ones for the same event.
        quarter_of[judgement.event_id] = judgement.quarter

    labels_a: Dict[str, List[str]] = defaultdict(list)
    labels_b: Dict[str, List[str]] = defaultdict(list)
    for decision in paired:
        quarter = quarter_of.get(decision.event_id, "Q?")
        labels_a[quarter].append(decision.label_a)
        labels_b[quarter].append(decision.label_b)

    return {quarter: cohens_kappa(labels_a[quarter], labels_b[quarter])
            for quarter in labels_a}


def turnaround_summary(judgements: Iterable[AdjudicatorJudgement]) -> Dict[str, float]:
    """Return each adjudicator's mean ``turnaround_h``, rounded to 2 decimals.

    Adjudicators appear in the order they are first seen.
    """
    hours: Dict[str, List[float]] = {}
    for judgement in judgements:
        hours.setdefault(judgement.adjudicator_id, []).append(judgement.turnaround_h)
    # Every list holds at least one value, so the division is always safe.
    return {adj: round(sum(values) / len(values), 2) for adj, values in hours.items()}


def workload_summary(judgements: Iterable[AdjudicatorJudgement]) -> Dict[str, int]:
    """Return the number of judgements per adjudicator, in first-seen order."""
    return dict(Counter(judgement.adjudicator_id for judgement in judgements))


# ---------------------------------------------------------------------------
# Top-level summary
# ---------------------------------------------------------------------------


def summarize(judgements: List[AdjudicatorJudgement]) -> Dict[str, Any]:
    """Build the adjudication dashboard summary for a set of judgements.

    Args:
        judgements: all judgements to summarise. Any iterable is accepted; it
            is materialised once because several helpers below each iterate
            over it.

    Returns:
        Dict with keys:
        ``n_events_paired``, ``n_discordant``;
        ``discordance_rate`` (rounded to 3 dp, 0.0 when nothing is paired);
        ``overall_kappa`` and ``kappa_by_quarter`` (rounded to 3 dp);
        ``label_distribution_by_quarter`` (quarter -> {label: count});
        ``routing`` (routing target -> number of discordant pairs);
        ``turnaround_h_by_adjudicator`` and ``workload_by_adjudicator``.
    """
    # A generator would otherwise be exhausted by pair_judgements and every
    # later statistic would silently be computed over nothing.
    judgements = list(judgements)

    paired = pair_judgements(judgements)
    overall_kappa = cohens_kappa(
        [p.label_a for p in paired],
        [p.label_b for p in paired],
    )
    discordant = [p for p in paired if not p.concordant]
    routed = Counter(p.routed_to for p in discordant if p.routed_to)
    return {
        "n_events_paired": len(paired),
        "n_discordant": len(discordant),
        # Always a float, so the field's type does not depend on the data.
        "discordance_rate": round(len(discordant) / len(paired), 3) if paired else 0.0,
        "overall_kappa": round(overall_kappa, 3),
        "kappa_by_quarter": {q: round(v, 3)
                              for q, v in kappa_by_quarter(paired, judgements).items()},
        "label_distribution_by_quarter": {
            q: dict(c) for q, c in label_distribution_by_quarter(judgements).items()
        },
        "routing": dict(routed),
        "turnaround_h_by_adjudicator": turnaround_summary(judgements),
        "workload_by_adjudicator": workload_summary(judgements),
    }
