"""Report generation: PNG figures and a Korean/English PDF report, all rendered with matplotlib."""

from __future__ import annotations

import os
from typing import Iterable

import matplotlib

matplotlib.use("Agg")  # noqa: E402

import matplotlib.pyplot as plt  # noqa: E402
from matplotlib.backends.backend_pdf import PdfPages  # noqa: E402
import numpy as np  # noqa: E402
import pandas as pd  # noqa: E402


# ---------------------------------------------------------------------------
# Localised text — minimal i18n
# Maps a language code ("ko" / "en") to {message key: display string}.
# Both languages define the same set of keys; write_pdf_report looks them up
# via TXT.get(language, TXT["en"]), so an unknown language code falls back to
# English. The values are rendered verbatim into the PDF (runtime text).
# ---------------------------------------------------------------------------
TXT = {
    # Korean strings; rendering them needs a Korean-capable font
    # (see _maybe_set_korean_font).
    "ko": {
        "title": "MASHLipidDroplet — 지질방울 정량 보고서",
        "summary": "요약",
        "subtype_heatmap": "약물별 LD 서브타입 분포 (%)",
        "dose_response": "용량-반응 곡선",
        "moa": "MOA 핑거프린트 매칭 (top-1)",
        "plin": "PLIN coat 조성 변화 (vehicle 대비 log2 FC)",
        "disclaimer": "본 보고서는 연구·참고용입니다. 임상 의사결정에 직접 사용하지 마십시오.",
        "n_wells": "분석 well 수",
        "n_cells": "총 세포 수",
        "n_lds": "총 LD object 수",
        "ic50": "IC50 / EC50 추정값",
    },
    # English strings; also the fallback table for unknown language codes.
    "en": {
        "title": "MASHLipidDroplet — Lipid Droplet Quantification Report",
        "summary": "Summary",
        "subtype_heatmap": "LD subtype distribution by drug (%)",
        "dose_response": "Dose-response curves",
        "moa": "MOA fingerprint match (top-1)",
        "plin": "PLIN coat composition shift (log2 FC vs vehicle)",
        "disclaimer": "Research / reference use only. Do not use for direct clinical decisions.",
        "n_wells": "Analysed wells",
        "n_cells": "Total cells",
        "n_lds": "Total LD objects",
        "ic50": "IC50 / EC50 estimates",
    },
}


# Best-effort Korean font selection, used when a report is rendered with language="ko".
def _maybe_set_korean_font():
    """Switch matplotlib to the first installed Korean-capable font, if any.

    Candidates are checked in a fixed preference order against the fonts that
    matplotlib's font manager knows about. On a match, the process-global
    ``matplotlib.rcParams`` are modified: ``font.family`` is set to the chosen
    font and ``axes.unicode_minus`` is disabled (presumably because such fonts
    may lack the Unicode minus glyph — NOTE(review): confirm).

    Returns:
        The name of the font that was selected, or ``None`` when no candidate
        is installed (rcParams are then left untouched).
    """
    preferred = (
        "AppleGothic",
        "AppleSDGothicNeo",
        "Apple SD Gothic Neo",
        "NanumGothic",
        "Malgun Gothic",
        "Noto Sans CJK KR",
        "Noto Sans KR",
        "STHeitiTC-Light",
    )
    installed = {entry.name for entry in matplotlib.font_manager.fontManager.ttflist}
    chosen = next((font for font in preferred if font in installed), None)
    if chosen is None:
        return None
    matplotlib.rcParams["font.family"] = chosen
    matplotlib.rcParams["axes.unicode_minus"] = False
    return chosen


def write_subtype_heatmap_png(per_well: pd.DataFrame, path: str) -> str:
    """Render a per-drug heatmap of mean LD subtype percentages to a PNG.

    Args:
        per_well: Per-well table with columns ``drug``, ``macro_pct``,
            ``medium_pct`` and ``micro_pct``. Subtype values are multiplied by
            100 for display, so they are expected as fractions in [0, 1]
            (NOTE(review): confirm against the producer of ``per_well``).
        path: Output PNG path.

    Returns:
        *path* when the figure was written; ``""`` when the input is empty,
        lacks a required column, or has no row with a drug label.
    """
    subtype_cols = ["macro_pct", "medium_pct", "micro_pct"]
    if per_well is None or per_well.empty:
        return ""
    # Same "return empty path" contract as the missing-drug case, instead of a KeyError.
    if any(col not in per_well.columns for col in ["drug", *subtype_cols]):
        return ""
    pivot = per_well.groupby("drug")[subtype_cols].mean().sort_index()
    if pivot.empty:  # every row had a missing drug label (groupby drops NaN keys)
        return ""
    pct = pivot.to_numpy(dtype=float) * 100.0

    fig, ax = plt.subplots(figsize=(6, max(2.5, 0.4 * len(pivot) + 1.0)))
    try:
        im = ax.imshow(pct, aspect="auto", cmap="magma", vmin=0, vmax=100)
        ax.set_yticks(range(len(pivot.index)))
        ax.set_yticklabels(pivot.index)
        ax.set_xticks(range(len(subtype_cols)))
        ax.set_xticklabels([col.replace("_pct", "") for col in subtype_cols])
        for (i, j), value in np.ndenumerate(pct):
            # magma is dark at the low end: white text below 50 %, black above.
            ax.text(
                j,
                i,
                f"{value:.0f}",
                ha="center",
                va="center",
                color="white" if value < 50 else "black",
                fontsize=9,
            )
        fig.colorbar(im, ax=ax, label="%")
        fig.tight_layout()
        fig.savefig(path, dpi=160)
    finally:
        # Always release the figure, even if drawing or saving fails.
        plt.close(fig)
    return path


def write_dose_response_png(per_well: pd.DataFrame, fits: pd.DataFrame, response_col: str, path: str) -> str:
    """Plot the mean response per dose for every drug on a log-scaled dose axis.

    Args:
        per_well: Per-well table with ``drug``, ``dose_uM`` and *response_col* columns.
        fits: Fit table. Not used by this plot; accepted for interface compatibility.
        response_col: Column of *per_well* plotted on the y axis (averaged per dose).
        path: Output PNG path.

    Returns:
        *path* when the figure was written; ``""`` when the input is empty,
        a required column is missing, or no drug has a positive dose.
    """
    required = ("drug", "dose_uM", response_col)
    if per_well is None or per_well.empty or any(col not in per_well.columns for col in required):
        return ""
    # A log axis cannot place dose <= 0 (e.g. vehicle wells at 0 uM); matplotlib
    # would clip such points off the left edge, so they are excluded here.
    doses = pd.to_numeric(per_well["dose_uM"], errors="coerce")
    plotted = per_well.assign(dose_uM=doses)[doses > 0]
    drugs = sorted(plotted["drug"].dropna().unique().tolist())
    if not drugs:
        return ""

    fig, ax = plt.subplots(figsize=(7, 4.5))
    try:
        cmap = plt.colormaps.get_cmap("tab10")
        for i, drug in enumerate(drugs):
            curve = (
                plotted.loc[plotted["drug"] == drug]
                .groupby("dose_uM")[response_col]
                .mean()
                .sort_index()
            )
            ax.plot(curve.index, curve.to_numpy(), "o-", color=cmap(i % 10), label=drug, alpha=0.85)
        ax.set_xscale("log")
        ax.set_xlabel("dose (uM)")
        ax.set_ylabel(response_col)
        ax.legend(loc="best", fontsize=8)
        ax.grid(True, alpha=0.3, which="both")
        fig.tight_layout()
        fig.savefig(path, dpi=160)
    finally:
        # Always release the figure, even if drawing or saving fails.
        plt.close(fig)
    return path


def write_moa_png(moa_long: pd.DataFrame, path: str) -> str:
    """Render a heatmap counting top-1 MOA matches (observed drug x reference drug).

    Args:
        moa_long: Long-format match table with ``top1``, ``drug_observed`` and
            ``ref_drug`` columns. Only rows whose ``top1`` equals True are
            counted; missing or non-True values are treated as "not top-1".
        path: Output PNG path.

    Returns:
        *path* when the figure was written; ``""`` when the input is None/empty,
        a required column is missing, or there is no top-1 row.
    """
    required = ("top1", "drug_observed", "ref_drug")
    if moa_long is None or moa_long.empty or any(col not in moa_long.columns for col in required):
        return ""
    # .eq(True) tolerates NaN / object dtype where plain boolean indexing would fail.
    top = moa_long.loc[moa_long["top1"].eq(True)]
    if top.empty:
        return ""
    counts = top.groupby(["drug_observed", "ref_drug"]).size().unstack(fill_value=0)
    values = counts.to_numpy()
    lo, hi = float(values.min()), float(values.max())
    mid = (lo + hi) / 2.0

    fig, ax = plt.subplots(figsize=(7, max(3.0, 0.4 * len(counts.index) + 1)))
    try:
        im = ax.imshow(values, aspect="auto", cmap="viridis")
        ax.set_xticks(range(counts.shape[1]))
        ax.set_xticklabels(counts.columns, rotation=45, ha="right")
        ax.set_yticks(range(counts.shape[0]))
        ax.set_yticklabels(counts.index)
        for (i, j), v in np.ndenumerate(values):
            if v:
                # viridis turns bright yellow at the top of its range, where white is unreadable.
                color = "black" if hi > lo and v > mid else "white"
                ax.text(j, i, str(int(v)), ha="center", va="center", color=color, fontsize=8)
        fig.colorbar(im, ax=ax, label="# wells matched")
        fig.tight_layout()
        fig.savefig(path, dpi=160)
    finally:
        # Always release the figure, even if drawing or saving fails.
        plt.close(fig)
    return path


def write_plin_shift_png(per_well_plin_fc: pd.DataFrame, per_well: pd.DataFrame, path: str) -> str:
    """Render per-drug mean PLIN coat log2 fold-changes as a diverging heatmap PNG.

    Args:
        per_well_plin_fc: Per-well table with a ``well`` column and any of
            ``plin{1,2,3,5}_frac_log2fc``.
        per_well: Per-well table providing the ``well`` -> ``drug`` labels.
        path: Output PNG path.

    Returns:
        *path* when the figure was written; ``""`` when inputs are empty,
        required columns are missing, or no FC row maps to a labelled drug.
    """
    if per_well_plin_fc is None or per_well_plin_fc.empty:
        return ""
    cols = ["plin1_frac_log2fc", "plin2_frac_log2fc", "plin3_frac_log2fc", "plin5_frac_log2fc"]
    have = [c for c in cols if c in per_well_plin_fc.columns]
    if (
        not have
        or per_well is None
        or "well" not in per_well_plin_fc.columns
        or not {"well", "drug"}.issubset(per_well.columns)
    ):
        return ""
    # Drug labels come only from per_well: a "drug" column already present in the FC
    # table would otherwise be suffixed to drug_x/drug_y by merge and break groupby.
    # One label per well keeps duplicated per_well rows from duplicating FC rows.
    labels = per_well[["well", "drug"]].drop_duplicates(subset="well")
    merged = per_well_plin_fc.drop(columns="drug", errors="ignore").merge(labels, on="well", how="left")
    pivot = merged.groupby("drug")[have].mean()
    if pivot.empty:  # no FC well matched a labelled drug
        return ""

    values = pivot.to_numpy(dtype=float)
    finite = np.abs(values[np.isfinite(values)])
    vmax = float(finite.max()) if finite.size else 0.0
    if vmax == 0.0:
        vmax = 1.0  # keep a non-degenerate symmetric colour range around 0

    fig, ax = plt.subplots(figsize=(7, max(2.5, 0.4 * len(pivot) + 1)))
    try:
        im = ax.imshow(values, aspect="auto", cmap="bwr", vmin=-vmax, vmax=vmax)
        ax.set_yticks(range(len(pivot.index)))
        ax.set_yticklabels(pivot.index)
        ax.set_xticks(range(len(pivot.columns)))
        ax.set_xticklabels([c.replace("_frac_log2fc", "") for c in pivot.columns])
        for (i, j), value in np.ndenumerate(values):
            ax.text(j, i, f"{value:+.2f}", ha="center", va="center", color="black", fontsize=8)
        fig.colorbar(im, ax=ax, label="log2 FC vs vehicle")
        fig.tight_layout()
        fig.savefig(path, dpi=160)
    finally:
        # Always release the figure, even if drawing or saving fails.
        plt.close(fig)
    return path


def _pdf_fmt_num(value, spec: str) -> str:
    """Format a numeric table cell with *spec*; None or non-numeric cells become "nan"."""
    try:
        return format(float(value), spec)
    except (TypeError, ValueError):
        return "nan"


def _pdf_fmt_int(value) -> str:
    """Format a count cell as an integer; None / NaN / inf cells become "nan"."""
    try:
        return str(int(value))
    except (TypeError, ValueError, OverflowError):
        return "nan"


def _pdf_fit_line(row: pd.Series) -> str:
    """Build the one-line text summary of a single dose-response fit row."""
    ok = row.get("ok", False)
    # bool(NaN) is True, so a missing flag must be mapped to False explicitly.
    ok_flag = False if pd.isna(ok) else bool(ok)
    return (
        f"- {row.get('drug', '')} ({row.get('response', '')}): "
        f"EC50={_pdf_fmt_num(row.get('ec50'), '.3g')} uM, "
        f"hill={_pdf_fmt_num(row.get('hill'), '.2f')}, "
        f"R^2={_pdf_fmt_num(row.get('r2'), '.2f')}, n={_pdf_fmt_int(row.get('n', 0))}, "
        f"ok={ok_flag}"
    )


def _pdf_summary_page(pdf, txt: dict, summary_stats: dict, fit_table) -> None:
    """Append the cover page (title, summary statistics, EC50 fit table) to *pdf*."""
    fig = plt.figure(figsize=(8.27, 11.69))  # A4 portrait, inches
    try:
        fig.text(0.08, 0.93, txt["title"], fontsize=16, weight="bold")
        y = 0.87
        fig.text(0.08, y, txt["summary"], fontsize=13, weight="bold")
        y -= 0.04
        for key, value in (summary_stats or {}).items():
            fig.text(0.08, y, f"- {key}: {value}", fontsize=11)
            y -= 0.025
        y -= 0.02
        fig.text(0.08, y, txt["ic50"], fontsize=13, weight="bold")
        y -= 0.04
        if fit_table is not None and not fit_table.empty:
            total = len(fit_table)
            for shown, (_, row) in enumerate(fit_table.iterrows()):
                if y < 0.18:
                    # Out of room above the disclaimer: say how many rows were omitted.
                    fig.text(0.08, y, f"... (+{total - shown})", fontsize=9)
                    break
                fig.text(0.08, y, _pdf_fit_line(row), fontsize=9)
                y -= 0.022
        fig.text(0.08, 0.07, txt["disclaimer"], fontsize=9, style="italic", color="firebrick")
        pdf.savefig(fig)
    finally:
        plt.close(fig)


def _pdf_figure_page(pdf, txt: dict, path: str) -> None:
    """Append one image file as a full A4 page to *pdf*; empty or missing paths are skipped."""
    if not path or not os.path.exists(path):
        return
    fig = plt.figure(figsize=(8.27, 11.69))
    try:
        ax = fig.add_axes([0.06, 0.1, 0.88, 0.82])
        ax.axis("off")
        ax.imshow(plt.imread(path))
        ax.set_title(os.path.basename(path), fontsize=11)
        fig.text(0.08, 0.05, txt["disclaimer"], fontsize=8, style="italic", color="firebrick")
        pdf.savefig(fig)
    finally:
        plt.close(fig)


def _pdf_moa_page(pdf, txt: dict, moa_top) -> None:
    """Append a text page listing up to 40 MOA top-1 match rows to *pdf*."""
    fig = plt.figure(figsize=(8.27, 11.69))
    try:
        fig.text(0.08, 0.93, txt["moa"], fontsize=14, weight="bold")
        y = 0.88
        shown = 0
        for _, row in moa_top.head(40).iterrows():
            line = (
                f"- {row.get('well', '')}: obs={row.get('drug_observed', '')} "
                f"-> ref={row.get('ref_drug', '')} "
                f"(sim={_pdf_fmt_num(row.get('similarity'), '.3f')})"
            )
            fig.text(0.08, y, line, fontsize=9)
            shown += 1
            y -= 0.02
            if y < 0.08:
                break
        remaining = len(moa_top) - shown
        if remaining > 0:
            # Make the truncation visible instead of silently dropping rows.
            fig.text(0.08, y, f"... (+{remaining})", fontsize=9)
        fig.text(0.08, 0.05, txt["disclaimer"], fontsize=8, style="italic", color="firebrick")
        pdf.savefig(fig)
    finally:
        plt.close(fig)


def write_pdf_report(
    out_pdf: str,
    summary_stats: dict,
    fit_table: pd.DataFrame,
    moa_top: pd.DataFrame,
    figure_paths: Iterable[str],
    language: str = "ko",
) -> str:
    """Write a multi-page A4 PDF report.

    Pages, in order: a cover page (title, *summary_stats*, EC50 fit lines from
    *fit_table*), one page per existing image in *figure_paths*, and, when
    *moa_top* is non-empty, a page listing MOA top-1 matches. Every page carries
    the localised research-use disclaimer.

    Args:
        out_pdf: Destination PDF path.
        summary_stats: Key/value pairs printed as "- key: value" lines.
        fit_table: Fit results; columns drug, response, ec50, hill, r2, n, ok
            are read when present (missing cells print as "nan").
        moa_top: MOA match rows; columns well, drug_observed, ref_drug, similarity.
        figure_paths: Image files to embed; empty or non-existent paths are skipped.
        language: "ko" or "en"; unknown codes use English. If "ko" is requested
            but no Korean font is installed, English text is used (with a warning).

    Returns:
        *out_pdf*.
    """
    txt = TXT.get(language, TXT["en"])
    # rc_context restores rcParams on exit, so the Korean font choice does not
    # leak into figures drawn later in the same process.
    with plt.rc_context():
        if language == "ko" and _maybe_set_korean_font() is None:
            import warnings

            # Without a Korean-capable font the Hangul labels would likely render
            # as missing-glyph boxes; English keeps the report readable.
            warnings.warn("No Korean-capable font found; writing the PDF report in English.")
            txt = TXT["en"]
        with PdfPages(out_pdf) as pdf:
            _pdf_summary_page(pdf, txt, summary_stats, fit_table)
            for path in figure_paths:
                _pdf_figure_page(pdf, txt, path)
            if moa_top is not None and not moa_top.empty:
                _pdf_moa_page(pdf, txt, moa_top)
    return out_pdf
