"""Sensitivity analysis and DSMB / CSR-ready reporting.

Supports:
- 3-point (CV death + non-fatal MI + non-fatal stroke)
- 4-point (3p + HHF)
- 5-point (4p + unstable-angina hospitalization; the charter may also call for
  coronary revascularization, which is not yet counted here)
- Type 2 MI include / exclude variants
- DSMB-ready interim summary
- CSR/manuscript supplementary .docx export
"""

from __future__ import annotations

import os
from collections import Counter, defaultdict
from dataclasses import dataclass
from datetime import datetime
from typing import Any, Dict, Iterable, List, Optional

try:
    from docx import Document  # type: ignore
    from docx.shared import Pt  # type: ignore
except ImportError:  # pragma: no cover
    Document = None
    Pt = None


# ---------------------------------------------------------------------------
# MACE composite helpers
# ---------------------------------------------------------------------------


# MI type labels that count toward the MI component. The narrow pool is the
# broad pool with Type 2 removed; it is used when Type 2 MI is excluded.
_MI_LABELS_BROAD = {
    "Type 1",
    "Type 2",
    "Type 3",
    "Type 4a",
    "Type 4b",
    "Type 4c",
    "Type 5",
}
_MI_LABELS_NARROW = _MI_LABELS_BROAD - {"Type 2"}

# Stroke labels counted as non-fatal stroke.
_STROKE_NONFATAL = {"Ischemic", "Hemorrhagic", "SAH"}
# Heart-failure labels counted toward the HHF component.
_HF = {"HHF"}
# Death-classification labels that count as cardiovascular death.
_CV_DEATH_LABELS = {
    "fatal_MI",
    "fatal_HF",
    "fatal_stroke",
    "sudden_cardiac_death",
    "arrhythmia",
    "undetermined",
    "other_CV",
}


def is_mi(label: str, include_type2: bool) -> bool:
    """Return True if *label* counts as an MI under the requested Type 2 policy.

    When ``include_type2`` is False, only non-Type-2 labels match. When it is
    True, Type 2 also matches. This includes the ``"Type 2 (excluded)"`` label
    the classifier emits when the charter excludes Type 2, so the sensitivity
    grid can re-include those events.
    """
    if not include_type2:
        return label in _MI_LABELS_NARROW
    return label in _MI_LABELS_BROAD or label == "Type 2 (excluded)"


def is_cv_death(label: str) -> bool:
    """Return True if *label* is one of the cardiovascular-death categories."""
    return label in _CV_DEATH_LABELS


def is_non_fatal_stroke(label: str) -> bool:
    """Return True if *label* is a stroke subtype counted as non-fatal stroke."""
    return label in _STROKE_NONFATAL


def is_hhf(label: str) -> bool:
    """Return True if *label* denotes a heart-failure hospitalization (HHF)."""
    return label in _HF


# ---------------------------------------------------------------------------
# Endpoint counters
# ---------------------------------------------------------------------------


@dataclass
class EndpointSummary:
    """Event counts for one MACE composite variant of the sensitivity grid."""

    name: str  # display label, e.g. "3p-MACE (Type2 MI out)"
    n_total: int  # total events counted across all components
    by_component: Dict[str, int]  # component name -> event count
    by_arm: Dict[str, int]  # treatment-arm label -> event count

    def to_dict(self) -> Dict[str, Any]:
        """Return a plain-dict snapshot; the count mappings are shallow-copied."""
        snapshot: Dict[str, Any] = {"name": self.name, "n_total": self.n_total}
        snapshot["by_component"] = dict(self.by_component)
        snapshot["by_arm"] = dict(self.by_arm)
        return snapshot


def count_mace(
    classifications: List[Dict[str, Any]],
    packet_arms: Dict[str, str],
    flavor: str,
    include_type2_mi: bool,
) -> EndpointSummary:
    """Count adjudicated events that contribute to a MACE composite.

    Each classification is a dict with keys ``event_id``, ``domain`` and
    ``label``. A classification contributes to at most one component. Checks
    are tried in this order: MI, stroke, CV death, HHF (4p/5p only), then UA
    hospitalization (5p only). Counts are per classification record; no
    per-patient de-duplication is done here.

    Args:
        classifications: adjudicated event classifications.
        packet_arms: maps ``event_id`` to its (blinded) treatment-arm label.
        flavor: one of ``"3p"``, ``"4p"``, ``"5p"``.
        include_type2_mi: whether Type 2 MI labels count toward the MI
            component.

    Returns:
        EndpointSummary with per-component and per-arm counts. ``by_arm``
        always sums to ``n_total``. Included events without an ``event_id``,
        or without an entry in ``packet_arms``, are counted under
        ``"unknown"``.

    Raises:
        ValueError: if ``flavor`` is not a supported composite.
    """
    supported = ("3p", "4p", "5p")
    if flavor not in supported:
        # An unknown flavor previously fell through to 3p counting while being
        # labelled with the bogus flavor name in the report.
        raise ValueError(
            f"unsupported MACE flavor {flavor!r}; expected one of {supported}"
        )
    with_hhf = flavor in ("4p", "5p")
    with_ua = flavor == "5p"

    comp: Counter = Counter()
    arm: Counter = Counter()
    for c in classifications:
        domain, label = c.get("domain"), c.get("label")
        if domain == "MI" and is_mi(label, include_type2_mi):
            component = "non_fatal_MI"
        elif domain == "Stroke" and is_non_fatal_stroke(label):
            component = "non_fatal_stroke"
        elif domain == "CV_death" and is_cv_death(label):
            component = "CV_death"
        elif with_hhf and domain == "HF" and is_hhf(label):
            component = "HHF"
        elif with_ua and label == "UA_hospitalization":
            # NOTE(review): unlike the other components this matches on the
            # label alone, regardless of domain — confirm against the charter.
            component = "UA_hospitalization"
        else:
            continue
        comp[component] += 1
        # Always attribute an included event to some arm. Otherwise by_arm
        # would undercount relative to n_total when event_id is missing.
        eid = c.get("event_id")
        arm["unknown" if eid is None else packet_arms.get(eid, "unknown")] += 1

    return EndpointSummary(
        name=f"{flavor}-MACE (Type2 MI {'in' if include_type2_mi else 'out'})",
        n_total=sum(comp.values()), by_component=dict(comp), by_arm=dict(arm),
    )


def run_sensitivity_grid(
    classifications: List[Dict[str, Any]],
    packet_arms: Dict[str, str],
) -> List[EndpointSummary]:
    """Evaluate every composite flavor with Type 2 MI both included and excluded.

    Returns six summaries, in the order 3p/in, 3p/out, 4p/in, 4p/out,
    5p/in, 5p/out.
    """
    return [
        count_mace(classifications, packet_arms, composite, type2_in)
        for composite in ("3p", "4p", "5p")
        for type2_in in (True, False)
    ]


# ---------------------------------------------------------------------------
# DSMB interim summary
# ---------------------------------------------------------------------------


def dsmb_interim_summary(
    sensitivity: List[EndpointSummary],
    adjudication_summary: Dict[str, Any],
    primary_name: str = "3p-MACE (Type2 MI out)",
) -> Dict[str, Any]:
    """Build the aggregated interim summary handed to the DSMB.

    Args:
        sensitivity: endpoint variants (e.g. the output of
            ``run_sensitivity_grid``).
        adjudication_summary: adjudicator-performance metrics, included as-is.
        primary_name: name prefix of the endpoint reported as primary. The
            default is 3-point MACE with Type 2 MI excluded. Override it when
            the charter designates a different primary endpoint.

    Returns:
        A dict with the following keys:

        - ``generated_at``: UTC ISO-8601 timestamp with a trailing ``"Z"``.
        - ``primary_endpoint``: the matched endpoint's ``to_dict()``. If
          nothing matches, the first endpoint is used; if ``sensitivity`` is
          empty, this is ``None``.
        - ``all_endpoints``: ``to_dict()`` of every endpoint.
        - ``adjudication``: ``adjudication_summary``, unchanged.
        - ``notes``: fixed disclaimer strings.
    """
    from datetime import timezone

    primary = next(
        (s for s in sensitivity if s.name.startswith(primary_name)), None
    )
    if primary is None and sensitivity:
        primary = sensitivity[0]

    # datetime.utcnow() is deprecated since Python 3.12. Take an aware UTC
    # time and drop tzinfo so the text keeps the original "...Z" format.
    generated_at = datetime.now(timezone.utc).replace(tzinfo=None).isoformat() + "Z"

    return {
        "generated_at": generated_at,
        "primary_endpoint": primary.to_dict() if primary else None,
        "all_endpoints": [s.to_dict() for s in sensitivity],
        "adjudication": adjudication_summary,
        "notes": [
            "Treatment arms are blinded labels A/B/placebo.",
            "DSMB only sees aggregated counts; no patient-level identifiers.",
            "Generated by CVOT-MACEAdjudicate-Kor research tool (synthetic data).",
        ],
    }


# ---------------------------------------------------------------------------
# DOCX export
# ---------------------------------------------------------------------------


def export_csr_docx(
    output_path: str,
    sensitivity: List[EndpointSummary],
    adjudication_summary: Dict[str, Any],
    charter: Dict[str, Any],
) -> str:
    """Write a CSR/manuscript supplementary report and return the path written.

    If python-docx is not installed (module-level ``Document`` is None), a
    Markdown rendering is written instead. It goes to the same path with the
    extension replaced by ``.md``, and that path is returned.

    Args:
        output_path: target ``.docx`` path; parent directories are created.
        sensitivity: endpoint variants to tabulate, one table row each.
        adjudication_summary: adjudicator-performance metrics. The keys read
            are ``overall_kappa``, ``discordance_rate``, ``n_events_paired``,
            ``routing`` and ``kappa_by_quarter``. Missing scalar keys render
            as "None"; a missing or None ``kappa_by_quarter`` lists nothing.
        charter: charter settings, rendered as ``key: value`` lines.

    Returns:
        The path actually written: ``output_path``, or its ``.md`` sibling.
    """
    from datetime import timezone

    os.makedirs(os.path.dirname(output_path) or ".", exist_ok=True)

    if Document is None:
        md_path = os.path.splitext(output_path)[0] + ".md"
        with open(md_path, "w", encoding="utf-8") as f:
            f.write(_render_markdown(sensitivity, adjudication_summary, charter))
        return md_path

    # datetime.utcnow() is deprecated since Python 3.12. Take an aware UTC
    # time and drop tzinfo so the text keeps the original "...Z" format.
    generated = datetime.now(timezone.utc).replace(tzinfo=None).isoformat()

    doc = Document()
    doc.add_heading("CVOT-MACEAdjudicate-Kor CEC Supplementary Report", level=0)
    doc.add_paragraph(
        "Research / synthetic-data tool. Not for clinical decision making."
    )
    doc.add_paragraph(f"Generated: {generated}Z")

    doc.add_heading("1. Charter", level=1)
    for k, v in charter.items():
        doc.add_paragraph(f"{k}: {v}")

    doc.add_heading("2. Adjudicator Performance", level=1)
    doc.add_paragraph(f"Overall kappa: {adjudication_summary.get('overall_kappa')}")
    doc.add_paragraph(f"Discordance rate: {adjudication_summary.get('discordance_rate')}")
    doc.add_paragraph(f"Events paired: {adjudication_summary.get('n_events_paired')}")
    doc.add_paragraph(f"Routing on discordance: {adjudication_summary.get('routing')}")
    doc.add_paragraph("Kappa by quarter:")
    # `or {}` also covers an explicit None value, which .get(key, {}) does not.
    for q, k in (adjudication_summary.get("kappa_by_quarter") or {}).items():
        doc.add_paragraph(f"  {q}: {k}", style="List Bullet")

    doc.add_heading("3. Sensitivity Analysis", level=1)
    table = doc.add_table(rows=1, cols=4)
    hdr = table.rows[0].cells
    hdr[0].text = "Endpoint"
    hdr[1].text = "N total"
    hdr[2].text = "Components"
    hdr[3].text = "By arm"
    for s in sensitivity:
        row = table.add_row().cells
        row[0].text = s.name
        row[1].text = str(s.n_total)
        row[2].text = ", ".join(f"{k}={v}" for k, v in s.by_component.items())
        row[3].text = ", ".join(f"{k}={v}" for k, v in s.by_arm.items())

    doc.add_heading("4. Notes", level=1)
    doc.add_paragraph(
        "Outputs are charter-driven. Type 2 MI inclusion affects MI counts; "
        "the primary endpoint per current charter is shown first in the DSMB summary."
    )
    doc.save(output_path)
    return output_path


def _render_markdown(sensitivity: List[EndpointSummary],
                     adjudication_summary: Dict[str, Any],
                     charter: Dict[str, Any]) -> str:
    """Render the supplementary report as Markdown (the python-docx fallback).

    Mirrors the .docx sections: charter, adjudicator performance (including
    kappa by quarter), the sensitivity table, and notes. Table cells are
    escaped so that a "|" or a line break inside a value cannot break the
    table layout.

    Returns:
        The Markdown document, terminated by a newline.
    """
    from datetime import timezone

    def cell(value: Any) -> str:
        # In a Markdown pipe table, a literal "|" must be escaped and a row
        # must stay on one line.
        text = str(value).replace("|", "\\|")
        return text.replace("\r", " ").replace("\n", " ")

    # datetime.utcnow() is deprecated since Python 3.12. Take an aware UTC
    # time and drop tzinfo so the text keeps the original "...Z" format.
    generated = datetime.now(timezone.utc).replace(tzinfo=None).isoformat()

    lines: List[str] = [
        "# CVOT-MACEAdjudicate-Kor CEC Supplementary Report",
        "",
        "> Research / synthetic-data tool. Not for clinical decision making.",
        "",
        f"Generated: {generated}Z",
        "",
        "## 1. Charter",
    ]
    lines.extend(f"- **{k}**: {v}" for k, v in charter.items())
    lines += [
        "",
        "## 2. Adjudicator performance",
        f"- Overall kappa: {adjudication_summary.get('overall_kappa')}",
        f"- Discordance rate: {adjudication_summary.get('discordance_rate')}",
        f"- Events paired: {adjudication_summary.get('n_events_paired')}",
        f"- Routing: {adjudication_summary.get('routing')}",
    ]
    # `or {}` also covers an explicit None value stored under the key.
    kappa_by_quarter = adjudication_summary.get("kappa_by_quarter") or {}
    if kappa_by_quarter:
        lines.append("- Kappa by quarter:")
        lines.extend(f"  - {q}: {k}" for q, k in kappa_by_quarter.items())

    lines += [
        "",
        "## 3. Sensitivity analysis",
        "",
        "| Endpoint | N | Components | By arm |",
        "|---|---|---|---|",
    ]
    for s in sensitivity:
        comp = ", ".join(f"{k}={v}" for k, v in s.by_component.items())
        arm = ", ".join(f"{k}={v}" for k, v in s.by_arm.items())
        lines.append(
            f"| {cell(s.name)} | {cell(s.n_total)} | {cell(comp)} | {cell(arm)} |"
        )

    lines += [
        "",
        "## 4. Notes",
        "",
        "Outputs are charter-driven. Type 2 MI inclusion affects MI counts; "
        "the primary endpoint per current charter is shown first in the DSMB summary.",
    ]
    return "\n".join(lines) + "\n"
