"""DSC quarterly 시그널 리포트 + RMP/PSUR/DSUR docx 초안.

참고용·연구용 — Not for clinical decision. 본 리포트는 합성 데이터 기반 초안이다.
"""
from __future__ import annotations

import json
import math
import os
from dataclasses import asdict, is_dataclass
from datetime import date
from typing import Any

try:
    from docx import Document  # type: ignore
    from docx.shared import Pt as DocxPt  # type: ignore
    _DOCX_OK = True
except Exception:  # pragma: no cover - optional dep
    _DOCX_OK = False

from .panels import PanelResult, PANEL_LABEL_KO, compute_all_panels, panel_incidence
from .subgroup import IRR, all_panel_irrs, onset_distribution, dose_response, subgroup_incidence
from .disproportion import Disproportion, disproportionality_all, filter_glp1ra, load_faers


# Default output directory: a "reports" folder located in the parent of the
# directory that contains this module.
REPORTS_DIR = os.path.join(os.path.dirname(os.path.dirname(__file__)), "reports")

# Panels that receive a block in the quarterly summary, in the order they are
# iterated when the summary is built.
SIGNAL_PANEL_LIST = [
    "pancreatitis", "thyroid", "gallbladder", "gi",
    "injection_site", "retinopathy", "suicidality",
]


def _classify_strength(irr: IRR | None, faers_signal: bool, faers_prr: float | None) -> str:
    """Grade a signal as "strong", "emerging" or "no_signal".

    The grade combines the lower confidence bound of the IRR with the FAERS
    PRR; both must clear a tier's thresholds for that tier to apply.

    Args:
        irr: IRR estimate whose ``lci`` attribute is read; a falsy value or a
            NaN ``lci`` is treated as a lower bound of 0.0.
        faers_signal: Accepted for interface compatibility; the current
            thresholds do not read it.
        faers_prr: FAERS PRR; ``None`` (or any falsy value) is treated as 0.0.

    Returns:
        The name of the first tier whose thresholds are met, otherwise
        "no_signal".
    """
    lower_bound = 0.0
    if irr and not math.isnan(irr.lci):
        lower_bound = irr.lci
    prr_value = faers_prr if faers_prr else 0.0

    # Tiers are checked from the most to the least demanding.
    tiers = (
        ("strong", 2.0, 4.0),
        ("emerging", 1.0, 2.0),
    )
    for label, min_lci, min_prr in tiers:
        if lower_bound >= min_lci and prr_value >= min_prr:
            return label
    return "no_signal"


def build_quarterly_summary(
    ae_rows: list[dict],
    lab_rows: list[dict],
    faers_path: str | None = None,
    drug_arms: tuple[str, ...] = ("semaglutide_2.4", "tirzepatide_15"),
    ref_arm: str = "placebo",
    quarter_label: str | None = None,
) -> dict[str, Any]:
    """Assemble the per-panel quarterly signal summary.

    For every panel in ``SIGNAL_PANEL_LIST`` the summary collects the panel
    result, incidence, onset distribution, per-arm IRR and dose response,
    subgroup incidence, the best-PRR FAERS record per drug, and SUSAR
    candidate counts. A panel's ``signal_strength`` is the strongest grade
    reached by any of its arms.

    Args:
        ae_rows: Adverse-event rows as dicts. Each row is indexed by
            ``"panel"``; ``"serious"`` is read with a default of 0.
        lab_rows: Lab rows, passed through to ``compute_all_panels``.
        faers_path: Passed to ``load_faers``.
        drug_arms: Treatment arms compared against ``ref_arm``.
        ref_arm: Comparator arm used for the IRR computation.
        quarter_label: Label for the report; defaults to the current calendar
            quarter formatted like ``"2025Q1"``.

    Returns:
        A dict with keys ``quarter``, ``generated``, ``ref_arm``,
        ``n_events`` and ``panels`` (one block per signal panel).

    Raises:
        KeyError: If a row in ``ae_rows`` has no ``"panel"`` key.
    """
    # Read the clock once so the quarter label and the "generated" stamp
    # cannot disagree when the call straddles midnight.
    today = date.today()
    if quarter_label is None:
        quarter_label = f"{today.year}Q{(today.month - 1) // 3 + 1}"

    panels = compute_all_panels(ae_rows, lab_rows)
    irrs = all_panel_irrs(ae_rows, SIGNAL_PANEL_LIST, list(drug_arms), ref_arm)
    # Index IRRs by (panel, arm); setdefault keeps the first match.
    irr_by_key: dict[tuple[str, str], IRR] = {}
    for item in irrs:
        irr_by_key.setdefault((item.panel, item.drug_arm), item)

    faers = load_faers(faers_path)
    disps = disproportionality_all(faers)
    glp1_disps = filter_glp1ra(disps)

    # Trial arm name -> FAERS drug name; unmapped arms are looked up as-is.
    arm_to_faers_drug = {
        "semaglutide_2.4": "semaglutide",
        "tirzepatide_15": "tirzepatide",
    }
    # Per (drug, panel), keep the FAERS record with the highest PRR.
    best_faers: dict[tuple[str, str], Disproportion] = {}
    for d in glp1_disps:
        key = (d.drug, d.panel)
        cur = best_faers.get(key)
        if cur is None or (d.prr or 0) > (cur.prr or 0):
            best_faers[key] = d

    strength_rank = {"no_signal": 0, "emerging": 1, "strong": 2}
    seven_day_panels = ("pancreatitis", "suicidality")
    subgroup_keys = ("baseline_bmi", "age", "sex", "t2dm", "prior_pancreatitis")

    panel_blocks: list[dict[str, Any]] = []
    for panel in SIGNAL_PANEL_LIST:
        block: dict[str, Any] = {
            "panel": panel,
            "label_ko": PANEL_LABEL_KO[panel],
            "panel_result": _serialize(panels[panel]),
            "incidence": panel_incidence(ae_rows, panel),
            "onset_dist": onset_distribution(ae_rows, panel),
            "irrs": [],
            "dose_response": {},
            "subgroup": {},
            "faers": [],
            "signal_strength": "no_signal",
        }
        for arm in drug_arms:
            irr = irr_by_key.get((panel, arm))
            faers_d = best_faers.get((arm_to_faers_drug.get(arm, arm), panel))
            strength = _classify_strength(
                irr,
                faers_d.signal if faers_d else False,
                faers_d.prr if faers_d else None,
            )
            block["irrs"].append({
                "arm": arm,
                "irr": irr.to_dict() if irr else None,
                "strength": strength,
            })
            block["dose_response"][arm] = dose_response(ae_rows, panel, arm)
            if faers_d:
                block["faers"].append(faers_d.to_dict())
            # The panel takes the strongest grade seen across its arms.
            if strength_rank[strength] > strength_rank[block["signal_strength"]]:
                block["signal_strength"] = strength

        for key in subgroup_keys:
            block["subgroup"][key] = subgroup_incidence(ae_rows, panel, key)

        # SUSAR candidates: every serious event of the panel counts toward the
        # 15-day figure; the 7-day figure is the same count, but only for the
        # panels listed in seven_day_panels (0 otherwise).
        n_serious = sum(
            1 for r in ae_rows
            if r["panel"] == panel and int(r.get("serious", 0) or 0) == 1
        )
        block["susar_7day_candidate"] = n_serious if panel in seven_day_panels else 0
        block["susar_15day_candidate"] = n_serious
        panel_blocks.append(block)

    return {
        "quarter": quarter_label,
        "generated": str(today),
        # Recorded so renderers can name the comparator that was actually used.
        "ref_arm": ref_arm,
        "n_events": len(ae_rows),
        "panels": panel_blocks,
    }


def _serialize(obj: Any) -> Any:
    """Recursively convert *obj* into plain dicts and lists.

    Dataclass instances become dicts via ``asdict``; dict values and
    list/tuple items are converted recursively (tuples come back as lists).
    Anything else, including a dataclass *class*, is returned unchanged.

    Args:
        obj: Arbitrary value, possibly nested.

    Returns:
        The converted structure.
    """
    # is_dataclass() is also true for dataclass classes, on which asdict()
    # raises TypeError, so only instances are converted.
    if is_dataclass(obj) and not isinstance(obj, type):
        return asdict(obj)
    if isinstance(obj, dict):
        return {k: _serialize(v) for k, v in obj.items()}
    if isinstance(obj, (list, tuple)):
        return [_serialize(v) for v in obj]
    return obj


def write_dsc_quarterly_json(summary: dict, out_dir: str | None = None) -> str:
    """Write the quarterly summary as ``DSC_<quarter>.json`` and return its path.

    Args:
        summary: Summary dict; ``summary["quarter"]`` names the file.
        out_dir: Target directory, created if missing. A falsy value selects
            ``REPORTS_DIR``.

    Returns:
        Path of the JSON file that was written.
    """
    target_dir = out_dir if out_dir else REPORTS_DIR
    os.makedirs(target_dir, exist_ok=True)

    file_name = f"DSC_{summary['quarter']}.json"
    json_path = os.path.join(target_dir, file_name)
    with open(json_path, "w", encoding="utf-8") as handle:
        # default=str stringifies whatever _serialize left non-JSON-native.
        json.dump(
            _serialize(summary),
            handle,
            ensure_ascii=False,
            indent=2,
            default=str,
        )
    return json_path


def _signal_priority_table(summary: dict) -> list[dict]:
    """Flatten the panel blocks into priority rows, strongest signal first.

    Each row carries ``panel``, ``label_ko``, ``strength``, ``susar_7d`` and
    ``susar_15d``. Rows are ordered strong, emerging, no_signal; any other
    strength value sorts last. Rows of equal strength keep the order they
    have in ``summary["panels"]``.
    """
    rank = {"strong": 0, "emerging": 1, "no_signal": 2}
    table = [
        {
            "panel": blk["panel"],
            "label_ko": blk["label_ko"],
            "strength": blk["signal_strength"],
            "susar_7d": blk["susar_7day_candidate"],
            "susar_15d": blk["susar_15day_candidate"],
        }
        for blk in summary["panels"]
    ]
    return sorted(table, key=lambda row: rank.get(row["strength"], 9))


def _fmt_incidence(incidence: dict, sep: str) -> str:
    """Join per-arm ``arm: cases/subjects (pct)`` fragments with *sep*."""
    return sep.join(
        f"{arm}: {v['n_cases']}/{v['n_subjects']} ({v['incidence']:.1%})"
        for arm, v in incidence.items()
    )


def _fmt_irr(irr_block: dict, ref_arm: str) -> str:
    """Render one arm's IRR, its 95% CI and its signal strength."""
    irr = irr_block["irr"]
    return (
        f"IRR ({irr_block['arm']} vs {ref_arm}): "
        f"{irr['IRR']} (95% CI {irr['IRR_LCI']}-{irr['IRR_UCI']}); "
        f"strength={irr_block['strength']}"
    )


def _fmt_faers(rec: dict) -> str:
    """Render one FAERS disproportionality record as an indented bullet."""
    return (
        f"  - {rec['drug']} × {rec['pt_term']}: "
        f"PRR={rec['PRR']} χ²={rec['chi2']} ROR={rec['ROR']} "
        f"(95% CI {rec['ROR_LCI']}-{rec['ROR_UCI']}) "
        f"EBGM={rec['EBGM']} signal={rec['signal']}"
    )


def write_module_docx(summary: dict, module: str = "RMP",
                      out_dir: str | None = None) -> str | None:
    """Write an RMP / PSUR / DSUR module draft and return its path.

    A ``.docx`` file is produced when python-docx could be imported;
    otherwise the same content is written as a Markdown ``.md`` file.

    Args:
        summary: Quarterly summary dict. Reads ``quarter``, ``generated``,
            ``n_events`` and ``panels``; an optional ``ref_arm`` names the
            comparator shown next to each IRR (defaults to "placebo").
        module: Module name used for the title and the file name. Names other
            than RMP/PSUR/DSUR get the generic title "<module> (DRAFT)".
        out_dir: Target directory, created if missing. A falsy value selects
            ``REPORTS_DIR``.

    Returns:
        Path of the file that was written.
    """
    out_dir = out_dir or REPORTS_DIR
    os.makedirs(out_dir, exist_ok=True)
    title_map = {
        "RMP": "Risk Management Plan (RMP) — GLP-1RA Class Safety Module (DRAFT)",
        "PSUR": "Periodic Safety Update Report (PSUR) — GLP-1RA (DRAFT)",
        "DSUR": "Development Safety Update Report (DSUR) — GLP-1RA (DRAFT)",
    }
    title = title_map.get(module, f"{module} (DRAFT)")
    priority_rows = _signal_priority_table(summary)
    # Summaries that do not record the comparator keep the historical label.
    ref_arm = summary.get("ref_arm", "placebo")
    total_7d = sum(p["susar_7day_candidate"] for p in summary["panels"])
    total_15d = sum(p["susar_15day_candidate"] for p in summary["panels"])

    if _DOCX_OK:
        doc = Document()
        doc.add_heading(title, level=1)
        # Italic is a run-level property; assigning it on the Paragraph object
        # has no effect on the rendered text.
        notice = doc.add_paragraph().add_run(
            "참고용·연구용 — Not for clinical decision. 합성 데이터 기반 초안."
        )
        notice.italic = True
        doc.add_paragraph(f"Quarter: {summary['quarter']}    Generated: {summary['generated']}    "
                          f"N AE events: {summary['n_events']}")

        doc.add_heading("1. Signal Priority Summary", level=2)
        t = doc.add_table(rows=1, cols=5)
        t.style = "Light Grid"
        hdr = t.rows[0].cells
        for i, h in enumerate(["Panel", "Label", "Strength", "SUSAR 7d", "SUSAR 15d"]):
            hdr[i].text = h
        for row in priority_rows:
            cells = t.add_row().cells
            cells[0].text = row["panel"]
            cells[1].text = row["label_ko"]
            cells[2].text = row["strength"]
            cells[3].text = str(row["susar_7d"])
            cells[4].text = str(row["susar_15d"])

        doc.add_heading("2. Panel Detail", level=2)
        for p in summary["panels"]:
            doc.add_heading(f"{p['label_ko']} ({p['panel']})", level=3)
            doc.add_paragraph(f"Signal strength: {p['signal_strength']}")
            doc.add_paragraph("Incidence by arm: " + _fmt_incidence(p["incidence"], ", "))
            for irr_block in p["irrs"]:
                # Arms without an IRR estimate are omitted.
                if irr_block["irr"]:
                    doc.add_paragraph(_fmt_irr(irr_block, ref_arm))
            if p["faers"]:
                doc.add_paragraph("FAERS disproportionality (best PRR per drug):")
                for rec in p["faers"]:
                    doc.add_paragraph(_fmt_faers(rec))

        doc.add_heading("3. SUSAR Reporting Candidates", level=2)
        doc.add_paragraph(f"7-day expedited candidates: {total_7d}")
        doc.add_paragraph(f"15-day expedited candidates: {total_15d}")

        doc.add_heading("4. Disclaimer", level=2)
        doc.add_paragraph(
            "본 문서는 합성 데이터 기반 자동생성 초안이다. 실제 RMP/PSUR/DSUR 제출 전에는 "
            "DSC, sponsor, regulatory affairs, 통계학자, 임상의의 검증이 반드시 필요하다."
        )

        out_path = os.path.join(out_dir, f"{module}_{summary['quarter']}_DRAFT.docx")
        doc.save(out_path)
        return out_path

    # Fallback: markdown
    lines = [
        f"# {title}",
        "",
        "*참고용·연구용 — Not for clinical decision. 합성 데이터 기반 초안.*",
        "",
        f"Quarter: **{summary['quarter']}** | Generated: {summary['generated']} | N AE events: {summary['n_events']}",
        "",
        "## 1. Signal Priority Summary",
        "",
        "| Panel | Label | Strength | SUSAR 7d | SUSAR 15d |",
        "|---|---|---|---|---|",
    ]
    for row in priority_rows:
        lines.append(f"| {row['panel']} | {row['label_ko']} | {row['strength']} | "
                     f"{row['susar_7d']} | {row['susar_15d']} |")
    lines += ["", "## 2. Panel Detail", ""]
    for p in summary["panels"]:
        lines.append(f"### {p['label_ko']} ({p['panel']})")
        lines.append(f"- Signal strength: {p['signal_strength']}")
        lines.append("- Incidence: " + _fmt_incidence(p["incidence"], "; "))
        for irr_block in p["irrs"]:
            if irr_block["irr"]:
                lines.append("- " + _fmt_irr(irr_block, ref_arm))
        if p["faers"]:
            lines.append("- FAERS disproportionality:")
            lines.extend(_fmt_faers(rec) for rec in p["faers"])
        lines.append("")
    lines += [
        "## 3. SUSAR Reporting Candidates",
        f"- 7-day expedited candidates: {total_7d}",
        f"- 15-day expedited candidates: {total_15d}",
        "",
        "## 4. Disclaimer",
        "본 문서는 합성 데이터 기반 자동생성 초안이다. 실제 제출 전 DSC/regulatory 검증 필수.",
    ]
    out_path = os.path.join(out_dir, f"{module}_{summary['quarter']}_DRAFT.md")
    with open(out_path, "w", encoding="utf-8") as fh:
        fh.write("\n".join(lines))
    return out_path


def render_text_summary(summary: dict, max_lines: int = 28) -> str:
    """Render a short plain-text digest of the summary for CLI demos.

    The digest lists one priority row per panel (strength, number of arms
    graded "strong", 15-day SUSAR candidates) followed by the first available
    IRR of each panel.

    Args:
        summary: Quarterly summary dict. Reads ``quarter``, ``generated``,
            ``n_events`` and ``panels``; an optional ``ref_arm`` names the
            comparator in the IRR heading (defaults to "placebo").
        max_lines: Maximum number of lines kept; extra lines are dropped from
            the end.

    Returns:
        The digest as a single newline-joined string.
    """
    ref_arm = summary.get("ref_arm", "placebo")
    # The header advertises an n_strong_arm column, so compute it per panel:
    # the number of arms whose individual grade is "strong".
    n_strong_arm = {
        p["panel"]: sum(1 for blk in p["irrs"] if blk["strength"] == "strong")
        for p in summary["panels"]
    }

    out = [
        f"=== GLP1AESignal-Kor DSC {summary['quarter']} ===",
        f"Generated: {summary['generated']} | N events: {summary['n_events']}",
        "",
        "Signal priority (panel | strength | n_strong_arm | SUSAR15d):",
    ]
    for r in _signal_priority_table(summary):
        out.append(
            f"  - {r['panel']:14s} | {r['strength']:10s} | "
            f"n_strong_arm={n_strong_arm[r['panel']]} | susar15={r['susar_15d']}"
        )
    out.append("")
    out.append(f"Per-panel IRR (drug vs {ref_arm}) — first arm only:")
    for p in summary["panels"]:
        # Show only the first arm that actually has an IRR estimate.
        first = next((blk for blk in p["irrs"] if blk["irr"]), None)
        if first:
            irr = first["irr"]
            out.append(f"  - {p['panel']:14s} {first['arm']:18s} IRR={irr['IRR']} "
                       f"(95% CI {irr['IRR_LCI']}-{irr['IRR_UCI']})")
    return "\n".join(out[:max_lines])
