"""Korean PDF report + figure-ready PNG/PDF + OpenClaw schema export.

Heavy stack avoided: matplotlib figures + PdfPages. No external network.
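For Korean text we try to register a Korean-capable system font at import time;
if none is found matplotlib's default font is kept, and the English label set
(``EN_LABELS``, selected by any ``language`` other than ``"ko"``) is the ASCII fallback.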
"""
from __future__ import annotations

import json
import os
from typing import Dict, List, Optional

import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
from matplotlib.backends.backend_pdf import PdfPages
from matplotlib import font_manager
import numpy as np
import pandas as pd


# ------------- Korean font handling -------------
def _find_korean_font() -> Optional[str]:
    """Locate the first installed font file from a fixed list of Korean fonts.

    The candidate paths are probed in order with ``os.path.exists``; the first
    one that exists is returned, or ``None`` when none of them do.
    """
    known_locations = (
        "/System/Library/Fonts/AppleSDGothicNeo.ttc",
        "/System/Library/Fonts/Supplemental/AppleGothic.ttf",
        "/Library/Fonts/AppleGothic.ttf",
        "/System/Library/Fonts/Supplemental/NanumGothic.ttf",
        "/Library/Fonts/NanumGothic.ttf",
        "/usr/share/fonts/truetype/nanum/NanumGothic.ttf",
        "/usr/share/fonts/opentype/noto/NotoSansCJK-Regular.ttc",
    )
    # next() stops at the first hit, preserving the priority order of the list.
    return next((path for path in known_locations if os.path.exists(path)), None)


# Resolved once at import time; None when no candidate font file exists.
_FONT_PATH = _find_korean_font()
if _FONT_PATH:
    try:
        # Register the font file with matplotlib, then make its family name
        # the global default for all text drawn afterwards.
        font_manager.fontManager.addfont(_FONT_PATH)
        prop = font_manager.FontProperties(fname=_FONT_PATH)
        plt.rcParams["font.family"] = prop.get_name()
    except Exception:
        # Best-effort: any failure while loading the font is ignored and the
        # module keeps whatever font settings were already active.
        pass
# Set regardless of whether a Korean font was found.
# NOTE(review): presumably disabled because CJK fonts may lack the U+2212
# minus glyph — confirm.
plt.rcParams["axes.unicode_minus"] = False


# ------------- figure helpers -------------
def fig_pct_change_spaghetti(
    locked_df: pd.DataFrame,
    pct_col: str,
    title: str,
    ylabel: str,
    meta: pd.DataFrame,
) -> plt.Figure:
    """Draw one thin line per mouse plus a bold per-week median line per arm.

    Args:
        locked_df: Long-format data with ``mouse_id``, ``week`` and ``pct_col``
            columns.
        pct_col: Name of the column in ``locked_df`` plotted on the y axis.
        title: Axes title.
        ylabel: Y-axis label.
        meta: Frame with ``mouse_id`` and ``treatment`` columns; ``treatment``
            is left-joined onto ``locked_df`` by ``mouse_id``.

    Returns:
        The matplotlib figure (not closed, not saved).
    """
    df = pd.merge(locked_df, meta[["mouse_id", "treatment"]], on="mouse_id", how="left")
    fig, ax = plt.subplots(figsize=(8, 5))
    # Mice without a treatment after the left join are not plotted.
    arms = sorted(df["treatment"].dropna().unique().tolist())
    # ``plt.cm.get_cmap`` was deprecated in matplotlib 3.7 and removed in 3.9;
    # ``plt.get_cmap`` works on both old and new releases.
    cmap = plt.get_cmap("tab10")
    for i, arm in enumerate(arms):
        sub = df[df["treatment"] == arm]
        # Wrap around the palette so arms beyond its size do not all collapse
        # onto the last colour (integer lookups past the end are clamped).
        color = cmap(i % cmap.N)
        for _, m in sub.groupby("mouse_id"):
            m = m.sort_values("week")
            ax.plot(m["week"], m[pct_col], color=color, alpha=0.25, linewidth=0.8)
        med = sub.groupby("week")[pct_col].median().reset_index()
        ax.plot(
            med["week"],
            med[pct_col],
            color=color,
            linewidth=2.5,
            label=f"{arm} (n={sub['mouse_id'].nunique()})",
        )
    ax.axhline(0, color="grey", linewidth=0.5, linestyle="--")
    ax.set_xlabel("Week")
    ax.set_ylabel(ylabel)
    ax.set_title(title)
    if arms:
        # Skip the legend when nothing was labelled; matplotlib would only
        # emit a "no artists with labels" warning.
        ax.legend(fontsize=8, loc="best")
    fig.tight_layout()
    return fig


def fig_decompose(decompose: Dict[str, object]) -> plt.Figure:
    """Side-by-side bars of expected vs. actual lean-mass change per mouse.

    When ``decompose`` has an ``"error"`` key, its message is written onto an
    otherwise empty axes and the figure is returned immediately. Otherwise the
    function reads ``per_mouse`` (columns ``mouse_id``, ``expected_lean_pct``,
    ``actual_lean_pct``), ``glp_arm``, ``combo_arm``, ``mean_sparing_pct`` and
    ``median_msi`` from ``decompose``.
    """
    fig, ax = plt.subplots(figsize=(7, 5))
    if "error" in decompose:
        ax.text(0.5, 0.5, decompose["error"], ha="center")
        return fig

    per_mouse = decompose["per_mouse"]  # type: ignore
    slots = np.arange(len(per_mouse))
    bar_width = 0.35
    shift = 0.18  # each series is moved this far left/right of the tick
    ax.bar(
        slots - shift,
        per_mouse["expected_lean_pct"],
        width=bar_width,
        label="Expected lean Δ% (mono model)",
    )
    ax.bar(
        slots + shift,
        per_mouse["actual_lean_pct"],
        width=bar_width,
        label="Actual lean Δ% (combo)",
    )
    ax.set_xticks(slots)
    ax.set_xticklabels(per_mouse["mouse_id"].tolist(), rotation=45, fontsize=7)
    ax.axhline(0, color="grey", linewidth=0.5)
    ax.set_ylabel("Lean mass % change")

    sparing_pp = decompose["mean_sparing_pct"]
    msi_median = decompose["median_msi"]
    headline = f"Muscle-sparing decompose: {decompose['glp_arm']} → {decompose['combo_arm']}"
    summary = f"mean sparing = {sparing_pp:+.2f}pp | median MSI = {msi_median:+.2f}"
    ax.set_title(f"{headline}\n{summary}")
    ax.legend(fontsize=8)
    fig.tight_layout()
    return fig


def fig_fiber_pattern(pattern: Dict[str, object]) -> plt.Figure:
    """Bar chart of CSA % change for the four fiber types of one arm.

    When ``pattern`` has an ``"error"`` key, its message is written onto an
    otherwise empty axes and the figure is returned immediately. Otherwise the
    function reads ``ordered_pattern`` (a mapping that must contain the keys
    ``IIB``, ``IIX``, ``IIA`` and ``I``), ``arm`` and ``fast_twitch_dominant``.

    Raises:
        KeyError: if a required key or fiber type is missing.
    """
    fig, ax = plt.subplots(figsize=(6, 4))
    if "error" in pattern:
        ax.text(0.5, 0.5, pattern["error"], ha="center")
        return fig
    # Fixed display order of the bars.
    order = ["IIB", "IIX", "IIA", "I"]
    vals = [pattern["ordered_pattern"][f] for f in order]  # type: ignore
    # The returned bar container was never used, so it is not bound to a name.
    ax.bar(order, vals)
    ax.axhline(0, color="grey", linewidth=0.5)
    ax.set_ylabel("CSA % change vs baseline")
    flag = "Y" if pattern["fast_twitch_dominant"] else "N"
    ax.set_title(
        f"Fiber type protection — {pattern['arm']}\nfast-twitch dominant: {flag}"
    )
    fig.tight_layout()
    return fig


def fig_composite_pass(ce_summary: pd.DataFrame) -> plt.Figure:
    """Bar chart of the composite-endpoint pass rate for each treatment.

    Reads the ``treatment`` and ``pass_rate_pct`` columns of ``ce_summary``.
    """
    heading = (
        "Sarcopenic obesity composite endpoint\n"
        "(BW <-5% AND lean Δ >-3% AND grip/lean Δ >+5%)"
    )
    fig, ax = plt.subplots(figsize=(6, 4))
    ax.bar(ce_summary["treatment"], ce_summary["pass_rate_pct"])
    ax.set_ylabel("Composite endpoint pass rate (%)")
    ax.set_title(heading)
    # Tilt the arm names so neighbouring tick labels are less likely to collide.
    for tick_label in ax.get_xticklabels():
        tick_label.set_rotation(20)
        tick_label.set_horizontalalignment("right")
        tick_label.set_fontsize(8)
    fig.tight_layout()
    return fig


# ------------- PDF report -------------
# Korean label set: report title/subtitle, section headings and disclaimer.
KO_LABELS = {
    "title": "SarcopeniaMuscleTrack 보고서",
    "subtitle": "사코페니아 비만 동물모델 — Cohort 분석",
    "section_overview": "1. Cohort 개요",
    "section_traj": "2. % Change Trajectory",
    "section_composite": "3. Composite Endpoint",
    "section_decompose": "4. Muscle-Sparing Decompose",
    "section_fiber": "5. Fiber Type Protection 패턴",
    "section_methods": "6. Methods (ARRIVE 2.0 호환 초안)",
    # Implicit string concatenation keeps the long disclaimer readable.
    "disclaimer": (
        "본 도구는 동물실험 데이터 분석용 — 임상 의사결정 직접 사용 금지. "
        "ARRIVE 2.0 reporting guideline 호환. IACUC 심의 / sample size justification은 "
        "사용자 책임."
    ),
}

# English label set: report title/subtitle, section headings and disclaimer.
EN_LABELS = {
    "title": "SarcopeniaMuscleTrack Report",
    "subtitle": "Sarcopenic obesity animal model — cohort analysis",
    "section_overview": "1. Cohort overview",
    "section_traj": "2. % change trajectory",
    "section_composite": "3. Composite endpoint",
    "section_decompose": "4. Muscle-sparing decompose",
    "section_fiber": "5. Fiber type protection pattern",
    "section_methods": "6. Methods (ARRIVE 2.0 draft)",
    # Implicit string concatenation keeps the long disclaimer readable.
    "disclaimer": (
        "Animal-research analysis tool — not for direct clinical decision-making. "
        "ARRIVE 2.0 compatible. IACUC approval / sample size justification is the user's responsibility."
    ),
}


def _labels(language: str) -> Dict[str, str]:
    """Pick the label table: Korean for ``"ko"``, English for any other value."""
    if language == "ko":
        return KO_LABELS
    return EN_LABELS


def methods_arrive2_draft(
    meta: pd.DataFrame, n_per_arm: Dict[str, int], language: str = "ko"
) -> str:
    """Build the markdown "Methods" draft that is embedded in the report.

    Args:
        meta: Cohort metadata with ``treatment`` and ``model`` columns.
        n_per_arm: Mapping of arm name to animal count, rendered as ``k=v``
            pairs in the mapping's own order.
        language: ``"ko"`` selects the Korean text; any other value selects
            the English text.

    Returns:
        The markdown text, ending with a newline.
    """
    # Missing values are dropped and every name is stringified before
    # sorting/joining: a NaN/None (or numeric) entry would otherwise make
    # ``sorted`` or ``str.join`` raise TypeError.
    arms_list = sorted(str(a) for a in meta["treatment"].dropna().unique().tolist())
    models_list = sorted(str(m) for m in meta["model"].dropna().unique().tolist())
    arms_txt = ", ".join(arms_list)
    models_txt = ", ".join(models_list)
    n_txt = ", ".join(f"{k}={v}" for k, v in n_per_arm.items())

    if language == "ko":
        return (
            "# Methods (ARRIVE 2.0 호환 초안)\n\n"
            f"- 동물 모델: {models_txt}\n"
            f"- 처치 군 ({len(arms_list)}): {arms_txt}\n"
            f"- 군별 n: {n_txt}\n"
            "- Randomization seed: cohort_meta.csv 기록\n"
            "- Endpoint: BW, EchoMRI/qNMR body composition, grip strength force-meter, "
            "treadmill exhaustion, micro-CT muscle CSA, myofiber HCS imaging summary, "
            "myokine ELISA(myostatin/activin A/irisin/decorin/BAIBA)\n"
            "- 분석: baseline lock(W0), per-mouse % change, median±IQR per arm, "
            "sarcopenic obesity composite endpoint(BW<-5% AND lean>-3% AND grip/lean>+5%), "
            "muscle-sparing decompose(monotherapy slope 기반 expected vs actual), "
            "fiber type protection 패턴(IIB>IIX>IIA>I).\n"
            "- 통계: scipy.stats.linregress, median/IQR; α=0.05 (multiple comparisons는 future scope).\n"
            "- 디스클레이머: 합성 demo 데이터는 약물 효과 시그널을 단순화한 합성치이며 "
            "실제 전임상/임상 반응을 대체하지 않음.\n"
        )
    return (
        "# Methods (ARRIVE 2.0 draft)\n\n"
        f"- Animal models: {models_txt}\n"
        f"- Treatment arms ({len(arms_list)}): {arms_txt}\n"
        f"- n per arm: {n_txt}\n"
        "- Randomization seed: recorded in cohort_meta.csv\n"
        "- Endpoints: BW, EchoMRI/qNMR body composition, grip strength, treadmill, "
        "micro-CT CSA, myofiber HCS summary, myokine ELISA.\n"
        "- Analysis: baseline lock at W0, per-mouse %change, median±IQR per arm, "
        "composite endpoint, muscle-sparing decompose, fiber-type pattern.\n"
        "- Stats: scipy.stats.linregress, median/IQR; α=0.05.\n"
    )


def write_report(
    out_dir: str,
    cohort_id: str,
    meta: pd.DataFrame,
    figures: List[plt.Figure],
    ce_summary: pd.DataFrame,
    decompose: Optional[Dict[str, object]],
    fiber_pattern: Optional[Dict[str, object]],
    language: str = "ko",
) -> str:
    """Write per-figure files, the combined PDF report and ``methods_arrive2.md``.

    Outputs, all under ``out_dir`` (created if missing):
      * ``figures/figure_NN.png`` and ``figures/figure_NN.pdf`` per figure,
      * an (empty) ``tables/`` directory,
      * ``report_<cohort_id>_<language>.pdf`` with a cover page, a composite
        endpoint summary page, every figure, and a methods page,
      * ``methods_arrive2.md`` with the same methods text.

    Args:
        out_dir: Output directory.
        cohort_id: Identifier shown on the cover and used in the PDF name.
        meta: Cohort metadata with ``treatment`` and ``model`` columns; its
            row count is printed as "n total".
        figures: Figures to export. Each one is closed after it has been
            added to the combined PDF.
        ce_summary: Frame with ``treatment``, ``n``, ``n_pass`` and
            ``pass_rate_pct`` columns.
        decompose: Accepted for interface compatibility; not used here.
        fiber_pattern: Accepted for interface compatibility; not used here.
        language: ``"ko"`` for Korean labels, anything else for English.

    Returns:
        Path of the combined PDF report.
    """
    labels = _labels(language)
    fig_dir = os.path.join(out_dir, "figures")
    tbl_dir = os.path.join(out_dir, "tables")
    os.makedirs(fig_dir, exist_ok=True)
    os.makedirs(tbl_dir, exist_ok=True)

    # Build the methods text once; it is used for both the PDF page and the
    # markdown file. Doing it before the PDF is opened also means a failure
    # here cannot leave a truncated report behind.
    n_per_arm = meta.groupby("treatment").size().to_dict()
    methods_text = methods_arrive2_draft(meta, n_per_arm, language)

    # Drop missing values and stringify before sorting/joining so a NaN arm
    # or model does not raise TypeError on the cover page.
    arm_names = sorted(str(a) for a in meta["treatment"].dropna().unique())
    model_names = sorted(str(m) for m in meta["model"].dropna().unique())

    # individual PNG/PDF
    for i, fig in enumerate(figures):
        fig.savefig(os.path.join(fig_dir, f"figure_{i:02d}.png"), dpi=150)
        fig.savefig(os.path.join(fig_dir, f"figure_{i:02d}.pdf"))

    a4_portrait = (8.27, 11.69)  # inches

    # combined PDF report
    pdf_path = os.path.join(out_dir, f"report_{cohort_id}_{language}.pdf")
    with PdfPages(pdf_path) as pdf:
        # cover
        cover = plt.figure(figsize=a4_portrait)
        cover.text(0.5, 0.85, labels["title"], ha="center", fontsize=22, fontweight="bold")
        cover.text(0.5, 0.78, labels["subtitle"], ha="center", fontsize=14)
        cover.text(0.5, 0.70, f"cohort_id: {cohort_id}", ha="center", fontsize=10)
        cover.text(0.5, 0.66, f"n total: {meta.shape[0]}", ha="center", fontsize=10)
        cover.text(0.5, 0.62, "arms: " + ", ".join(arm_names), ha="center", fontsize=9)
        cover.text(0.5, 0.58, "models: " + ", ".join(model_names), ha="center", fontsize=9)
        cover.text(0.5, 0.10, labels["disclaimer"], ha="center", fontsize=8, wrap=True)
        pdf.savefig(cover)
        plt.close(cover)

        # composite summary page
        ce_fig = plt.figure(figsize=a4_portrait)
        ce_fig.text(0.05, 0.95, labels["section_composite"], fontsize=14, fontweight="bold")
        # Explicit casts: the ``s``/``d`` format codes reject non-str names
        # and float-typed counts (e.g. 3.0 coming out of an aggregation).
        text_lines = [
            f"- {str(r['treatment']):<24s}  n={int(r['n']):>3d}  "
            f"pass={int(r['n_pass']):>3d}  rate={float(r['pass_rate_pct']):>5.1f}%"
            for r in ce_summary.to_dict("records")
        ]
        ce_fig.text(
            0.05,
            0.05,
            "\n".join(text_lines),
            fontsize=9,
            family="monospace",
            verticalalignment="bottom",
        )
        pdf.savefig(ce_fig)
        plt.close(ce_fig)

        # all figures (closed here, so callers cannot reuse them afterwards)
        for fig in figures:
            pdf.savefig(fig)
            plt.close(fig)

        # methods page
        mfig = plt.figure(figsize=a4_portrait)
        mfig.text(0.05, 0.95, labels["section_methods"], fontsize=14, fontweight="bold")
        mfig.text(0.05, 0.05, methods_text, fontsize=8, verticalalignment="bottom", wrap=True)
        pdf.savefig(mfig)
        plt.close(mfig)

    # also dump methods.md
    with open(os.path.join(out_dir, "methods_arrive2.md"), "w", encoding="utf-8") as f:
        f.write(methods_text)

    return pdf_path


# ------------- OpenClaw export -------------
# Schema tag stamped on every exported row and on the JSON bundle.
OPENCLAW_SCHEMA_VERSION = "openclaw.bio.v1"


def openclaw_export(
    out_dir: str,
    meta: pd.DataFrame,
    ce_df: pd.DataFrame,
    decompose: Optional[Dict[str, object]],
    fiber_pattern: Optional[Dict[str, object]],
) -> str:
    """Emit the OpenClaw export (obesity / sarcopenia drug-recombination DB).

    One record is built per row of ``meta`` whose ``mouse_id`` also appears in
    ``ce_df``; when ``ce_df`` lists a mouse more than once, its first row is
    used. The records are written to ``openclaw_export.parquet`` when pandas
    can write Parquet, and always to ``openclaw_export.json`` together with
    the decompose and fiber-pattern summaries.

    Args:
        out_dir: Output directory; created if it does not exist.
        meta: Per-mouse metadata with ``mouse_id``, ``model`` and
            ``treatment``; ``cohort_id`` and ``dose_mg_kg`` are optional and
            default to ``"demo_cohort"`` and ``0.0``.
        ce_df: Per-mouse composite-endpoint frame with ``mouse_id``,
            ``bw_pct_change``, ``lean_pct_change``,
            ``grip_per_lean_pct_change`` and ``composite_pass``.
        decompose: Decompose result; exported as ``null`` when it is ``None``
            or contains an ``"error"`` key.
        fiber_pattern: Fiber-pattern result; same ``null`` rule.

    Returns:
        The Parquet path when that file was written, otherwise the JSON path.
    """
    # Without this the Parquet failure was swallowed and the JSON write then
    # raised FileNotFoundError for a fresh output directory.
    os.makedirs(out_dir, exist_ok=True)

    # Position of the first ce_df row for each mouse, built in one pass so the
    # loop below does not rescan ce_df for every mouse. Missing ids are left
    # out, so they never match.
    first_pos: Dict[object, int] = {}
    for pos, mouse_id in enumerate(ce_df["mouse_id"].tolist()):
        if not pd.isna(mouse_id):
            first_pos.setdefault(mouse_id, pos)

    rows = []
    for _, m in meta.iterrows():
        pos = first_pos.get(m["mouse_id"])
        if pos is None:
            continue
        ce_row = ce_df.iloc[pos]
        rows.append(
            dict(
                schema=OPENCLAW_SCHEMA_VERSION,
                cohort_id=m.get("cohort_id", "demo_cohort"),
                mouse_id=m["mouse_id"],
                model=m["model"],
                treatment=m["treatment"],
                dose_mg_kg=float(m.get("dose_mg_kg", 0.0)),
                bw_pct_change=float(ce_row["bw_pct_change"]),
                lean_pct_change=float(ce_row["lean_pct_change"]),
                grip_per_lean_pct_change=float(ce_row["grip_per_lean_pct_change"]),
                composite_pass=int(ce_row["composite_pass"]),
            )
        )
    df = pd.DataFrame(rows)

    parquet_path = os.path.join(out_dir, "openclaw_export.parquet")
    json_path = os.path.join(out_dir, "openclaw_export.json")
    used_path = json_path
    try:
        df.to_parquet(parquet_path, index=False)
        used_path = parquet_path
    except Exception:
        # Deliberate best-effort: Parquet is optional (it fails e.g. when no
        # Parquet engine is installed), and JSON is always written below.
        pass

    # always also emit JSON for portability
    if decompose is None or "error" in decompose:
        decompose_summary = None
    else:
        decompose_keys = (
            "glp_arm",
            "combo_arm",
            "n_mono",
            "n_combo",
            "mean_expected_lean_pct",
            "mean_actual_lean_pct",
            "mean_sparing_pct",
            "median_msi",
            "mono_model",
        )
        decompose_summary = {key: decompose[key] for key in decompose_keys}

    if fiber_pattern is None or "error" in fiber_pattern:
        fiber_summary = None
    else:
        fiber_summary = dict(
            arm=fiber_pattern["arm"],
            ordered_pattern=fiber_pattern["ordered_pattern"],
            fast_twitch_dominant=fiber_pattern["fast_twitch_dominant"],
        )

    bundle = dict(
        schema=OPENCLAW_SCHEMA_VERSION,
        per_mouse=df.to_dict("records"),
        decompose_summary=decompose_summary,
        fiber_pattern=fiber_summary,
    )
    with open(json_path, "w", encoding="utf-8") as f:
        # default=float converts values the json module cannot encode
        # natively (e.g. numpy scalars) to plain floats.
        json.dump(bundle, f, indent=2, default=float, ensure_ascii=False)

    return used_path
