"""KPI aggregation + KASL/KLTS-style report generation (docx, optional)."""
from __future__ import annotations
from collections import Counter, defaultdict
from typing import Dict, List, Optional


# Decompensation-type codes (the ``decomp_type`` values counted by
# ``episode_kpis``) mapped to the Korean labels that ``render_text_report``
# prints when ``lang == "kor"``. Codes missing from this table are printed
# unchanged (the renderer looks them up with ``.get(code, code)``).
# NOTE(review): the "ascites" label reads "refractory ascites" (복수 난치성)
# while the code is plain "ascites" -- confirm which meaning is intended.
DECOMP_LABELS_KOR: Dict[str, str] = {
    "ascites": "복수 난치성",
    "VB": "정맥류 출혈",
    "HE": "간성뇌증",
    "HRS-AKI": "간신증후군-AKI",
    "SBP": "자발성세균성복막염",
    "ACLF": "급성악화성 간부전",
    "other": "기타",
}

# EASL-CLIF ACLF grades in report order. ``episode_kpis`` zero-fills its
# grade counts for every entry here, and ``render_text_report`` prints the
# ACLF section in exactly this order.
ACLF_ORDER: List[str] = ["no ACLF", "ACLF-1", "ACLF-2", "ACLF-3"]


def _mean(xs: List[float]) -> float:
    xs = [x for x in xs if x is not None]
    return round(sum(xs) / len(xs), 2) if xs else 0.0


def _rate(num: int, den: int) -> float:
    return round(num / den, 3) if den else 0.0


def episode_kpis(episodes: List[Dict[str, object]]) -> Dict[str, object]:
    """Compute episode-level KPIs.

    Expected episode dict keys:
      patient_id, etiology, decomp_type, aclf_grade, meld3, child_class,
      los_days, in_hospital_mortality, mortality_30d,
      readmission_90d, lt_listed, lt_transplanted,
      protocol_adherence (0-1), ward.

    ``patient_id``, ``etiology``, ``decomp_type``, ``aclf_grade`` and ``ward``
    are required (a missing one raises ``KeyError``). ``child_class`` and
    ``in_hospital_mortality`` are accepted but not aggregated here.

    ``mortality_30d``, ``los_days``, ``readmission_90d``,
    ``protocol_adherence`` and ``meld3`` may be absent or ``None``. Such
    values are left out of the corresponding mean rather than counted as 0,
    so each mean is taken over the episodes where the field was recorded.
    LT rates count truthy ``lt_listed`` / ``lt_transplanted`` over all ``n``
    episodes.

    Returns:
      ``{"n": 0}`` for an empty list. Otherwise the dict holds:
      - counts by decompensation type, ACLF grade, etiology and ward;
      - per-type and per-grade means;
      - overall means;
      - LT rates.
      ``by_aclf_grade`` lists every grade in ``ACLF_ORDER`` (0 when unseen),
      followed by any other grade labels found, so its counts sum to ``n``.
    """
    n = len(episodes)
    if n == 0:
        return {"n": 0}

    by_type = Counter(e["decomp_type"] for e in episodes)
    by_aclf = Counter(e["aclf_grade"] for e in episodes)
    by_etio = Counter(e["etiology"] for e in episodes)
    by_ward = Counter(e["ward"] for e in episodes)

    # Per-group value lists, all filled in a single pass over the episodes.
    type_mortality = defaultdict(list)
    type_los = defaultdict(list)
    type_readmit = defaultdict(list)
    type_adherence = defaultdict(list)
    aclf_mortality = defaultdict(list)
    for e in episodes:
        t = e["decomp_type"]
        # .get() without a default yields None for a missing field, which
        # _mean skips. Imputing 0 would bias mortality, LOS, readmission and
        # adherence toward "better" values.
        mortality = e.get("mortality_30d")
        type_mortality[t].append(mortality)
        type_los[t].append(e.get("los_days"))
        type_readmit[t].append(e.get("readmission_90d"))
        type_adherence[t].append(e.get("protocol_adherence"))
        aclf_mortality[e["aclf_grade"]].append(mortality)

    # Canonical grades first (zero-filled), then any unexpected grade labels so
    # the grade counts still add up to n.
    aclf_counts = {g: by_aclf.get(g, 0) for g in ACLF_ORDER}
    aclf_counts.update((g, c) for g, c in by_aclf.items() if g not in aclf_counts)

    return {
        "n": n,
        "n_patients": len({e["patient_id"] for e in episodes}),
        "by_decomp_type": dict(by_type),
        "by_aclf_grade": aclf_counts,
        "by_etiology": dict(by_etio),
        "by_ward": dict(by_ward),
        "type_mortality_30d": {t: _mean(v) for t, v in type_mortality.items()},
        "type_los_mean": {t: _mean(v) for t, v in type_los.items()},
        "type_readmission_90d": {t: _mean(v) for t, v in type_readmit.items()},
        "type_protocol_adherence": {t: _mean(v) for t, v in type_adherence.items()},
        "aclf_mortality_30d": {g: _mean(v) for g, v in aclf_mortality.items()},
        "overall_mortality_30d": _mean([e.get("mortality_30d") for e in episodes]),
        "overall_readmission_90d": _mean([e.get("readmission_90d") for e in episodes]),
        "overall_adherence": _mean([e.get("protocol_adherence") for e in episodes]),
        "meld3_mean": _mean([e.get("meld3") for e in episodes]),
        "lt_listed_rate": _rate(sum(1 for e in episodes if e.get("lt_listed")), n),
        "lt_transplanted_rate": _rate(sum(1 for e in episodes if e.get("lt_transplanted")), n),
    }


def render_text_report(kpis: Dict[str, object], lang: str = "kor") -> str:
    """Render the KPI dict produced by ``episode_kpis`` as Markdown-style text.

    Args:
      kpis: Output of ``episode_kpis``.
      lang: ``"kor"`` selects Korean headings and Korean decompensation-type
        labels (from ``DECOMP_LABELS_KOR``); any other value selects English.

    Returns:
      The newline-joined report, or ``"(no data)"`` when ``kpis`` has no
      ``"n"`` entry or ``n`` is 0.
    """
    if kpis.get("n", 0) == 0:
        return "(no data)"

    korean = lang == "kor"

    def pick(kor: str, eng: str) -> str:
        # Select the string for the requested report language.
        return kor if korean else eng

    def type_label(code):
        # Decompensation-type codes are translated only in the Korean report.
        return DECOMP_LABELS_KOR.get(code, code) if korean else code

    out: List[str] = [
        "# " + pick("시르데컴프유닛코어 QI 리포트", "CirrDecompUnit-Kor QI Report"),
        "",
        pick("환자 수: ", "Patients: ") + str(kpis["n_patients"]),
        pick("episode 수: ", "Episodes: ") + str(kpis["n"]),
        "",
        "## " + pick("Decompensation type 분포", "Decompensation type distribution"),
    ]
    type_counts = kpis["by_decomp_type"]
    out.extend(f"- {type_label(code)}: {count}" for code, count in type_counts.items())

    # ACLF section: fixed grade order, each with its 30-day mortality.
    out += ["", "## " + pick("EASL-CLIF ACLF grade 분포", "EASL-CLIF ACLF grade")]
    grade_counts = kpis["by_aclf_grade"]
    grade_mortality = kpis["aclf_mortality_30d"]
    mortality_word = pick("30일 사망률", "30d mortality")
    for grade in ACLF_ORDER:
        out.append(f"- {grade}: {grade_counts.get(grade, 0)} "
                   f"({mortality_word} {grade_mortality.get(grade, 0):.2%})")

    out += ["", "## " + pick("원인 질환", "Etiology")]
    out.extend(f"- {etio}: {count}" for etio, count in kpis["by_etiology"].items())

    # Per-type QI as a Markdown table.
    out += [
        "",
        "## " + pick("Decompensation type별 QI", "Per-type QI"),
        pick("| 유형 | n | 30일 사망 | 입원일 | 90일 재입원 | protocol 부합 |",
             "| type | n | mortality_30d | LOS | readmission_90d | adherence |"),
        "|---|---|---|---|---|---|",
    ]
    for code, count in type_counts.items():
        cells = [
            f"{type_label(code)}",
            f"{count}",
            f"{kpis['type_mortality_30d'].get(code, 0):.2%}",
            f"{kpis['type_los_mean'].get(code, 0):.1f}",
            f"{kpis['type_readmission_90d'].get(code, 0):.2%}",
            f"{kpis['type_protocol_adherence'].get(code, 0):.2%}",
        ]
        out.append("| " + " | ".join(cells) + " |")

    out += [
        "",
        "## " + pick("전체 요약", "Overall"),
        pick("- 평균 MELD 3.0: ", "- Mean MELD 3.0: ") + f"{kpis['meld3_mean']}",
    ]
    # Overall summary rows that are all rendered as percentages.
    percent_rows = [
        ("- 전체 30일 사망률: ", "- Overall 30d mortality: ", "overall_mortality_30d"),
        ("- 전체 90일 재입원율: ", "- Overall 90d readmission: ", "overall_readmission_90d"),
        ("- 전체 protocol 부합률: ", "- Overall protocol adherence: ", "overall_adherence"),
        ("- LT 등록률: ", "- LT listed rate: ", "lt_listed_rate"),
        ("- LT 시행률: ", "- LT transplant rate: ", "lt_transplanted_rate"),
    ]
    out.extend(pick(kor, eng) + f"{kpis[key]:.2%}" for kor, eng, key in percent_rows)

    out += [
        "",
        "---",
        pick("주: 본 리포트는 합성 데이터 기반이며 참고용·연구용입니다. "
             "임상의사결정에 사용해서는 안 됩니다.",
             "Note: This report uses synthetic data and is for research / QI use only."),
    ]
    return "\n".join(out)


def write_docx(kpis: Dict[str, object], out_path: str, lang: str = "kor") -> Optional[str]:
    """Write the KPI report as .docx, falling back to Markdown if python-docx is missing.

    The text comes from ``render_text_report``. Lines starting with ``"# "``
    and ``"## "`` become level-1 and level-2 headings. Every other line
    (including blank lines and Markdown table rows) becomes a plain paragraph.

    Args:
      kpis: Output of ``episode_kpis``.
      out_path: Destination path for the .docx file.
      lang: Report language, passed through to ``render_text_report``.

    Returns:
      The path actually written. On success this is ``out_path``. When
      ``python-docx`` cannot be imported, it is the Markdown fallback path:
      ``out_path`` with a trailing ``.docx`` extension (any case) replaced by
      ``.md``, or with ``.md`` appended when there is no such extension.

    Raises:
      Errors raised while building or saving the document (e.g. ``OSError``)
      propagate rather than being masked by the Markdown fallback.
    """
    import os

    text = render_text_report(kpis, lang=lang)
    try:
        from docx import Document  # type: ignore
    except ImportError:
        # Only a missing python-docx triggers the fallback. Derive the path
        # from the extension alone so directory names containing ".docx" are
        # left intact.
        root, ext = os.path.splitext(out_path)
        fallback = (root if ext.lower() == ".docx" else out_path) + ".md"
        with open(fallback, "w", encoding="utf-8") as fh:
            fh.write(text)
        return fallback

    doc = Document()
    for line in text.split("\n"):
        if line.startswith("# "):
            doc.add_heading(line[2:], level=1)
        elif line.startswith("## "):
            doc.add_heading(line[3:], level=2)
        else:
            doc.add_paragraph(line)
    doc.save(out_path)
    return out_path
