"""IsletPerifusionAnalyzer (이슬렛퍼퓨젼애널라이저) — CLI entrypoint.

Pipeline:
    1. Parse vendor-specific perifusion CSV (BioRep / Brandel / in-house)
    2. Apply dead-volume lag, degradation drift, baseline subtraction,
       per-IEQ normalization, KCl-peak batch normalization
    3. Compute kinetic parameters (1st/2nd phase, AUC, fold-change,
       GSIS / KSIS, GLP-1 potentiation, lipotoxicity, proinsulin/insulin)
    4. Multi-analyte co-analysis (insulin / C-peptide / glucagon / proinsulin)
    5. Emit PNG + PDF report, kinetic CSV/XLSX, GraphPad-ready CSV,
       Diabetes/Diabetologia reproducibility checklist, methods MD draft

Usage:
    python3 main.py --demo --report-dir ./output
    python3 main.py --input ./data --analyte all --graphpad --checklist

Disclaimer:
    Research and reference use only; not for direct use in clinical decision-making.
    (본 도구는 연구·참고용이며 임상 의사결정에 직접 사용 금지.)
"""
from __future__ import annotations

import argparse
import math
import os
import statistics
import sys
from typing import Any, Dict, List

# The standard library plus numpy/scipy/pandas/matplotlib are required; other
# packages are optional.
# This file's own directory is put first on sys.path so that ``lib`` is
# resolved next to this script regardless of the current working directory.
try:
    sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
    from lib import parsers, corrections, kinetics, multianalyte, report, synth
except Exception as e:  # pragma: no cover
    # Any failure while importing the pipeline modules is fatal: report it on
    # stderr and exit with status 2.
    print(f"[fatal] failed to import lib modules: {e}", file=sys.stderr)
    sys.exit(2)


# ---------------------------------------------------------------------------
# CLI
# ---------------------------------------------------------------------------
def build_parser() -> argparse.ArgumentParser:
    """Create the argument parser for the ``islet-perifusion-analyzer`` CLI.

    Returns:
        A parser exposing ``--input``, ``--demo``, ``--analyte``,
        ``--report-dir``, ``--graphpad``, ``--checklist``, ``--ieq``,
        ``--flow``, ``--dead-volume`` and ``--storage-hours``.
    """
    parser = argparse.ArgumentParser(
        prog="islet-perifusion-analyzer",
        description="Perifusion CSV import + GSIS kinetics + multi-analyte co-analysis. Research use only.",
        formatter_class=argparse.RawTextHelpFormatter,
    )

    # (flag, add_argument keyword options); list order is the --help order.
    option_table = [
        ("--input", {
            "help": "Path to a perifusion CSV file or a directory of CSVs",
        }),
        ("--demo", {
            "action": "store_true",
            "help": "Generate synthetic perifusion data and run end-to-end",
        }),
        ("--analyte", {
            "default": "insulin",
            "choices": ["insulin", "c-peptide", "glucagon", "proinsulin", "all"],
            "help": "Analyte channel to process (default: insulin)",
        }),
        ("--report-dir", {
            "default": "./output",
            "help": "Output directory for plots / tables / report (default: ./output)",
        }),
        ("--graphpad", {
            "action": "store_true",
            "help": "Emit GraphPad-Prism-ready CSV (and .pzfx stub)",
        }),
        ("--checklist", {
            "action": "store_true",
            "help": "Emit Diabetes/Diabetologia reproducibility checklist (MD + PDF)",
        }),
        ("--ieq", {
            "type": float,
            "default": 100.0,
            "help": "IEQ basis for normalization (default: 100)",
        }),
        ("--flow", {
            "type": float,
            "default": 0.1,
            "help": "Flow rate mL/min (default: 0.1)",
        }),
        ("--dead-volume", {
            "type": float,
            "default": 100.0,
            "help": "Dead volume uL (default: 100)",
        }),
        ("--storage-hours", {
            "type": float,
            "default": 4.0,
            "help": "Sample storage hours for degradation correction (default: 4)",
        }),
    ]
    for flag, options in option_table:
        parser.add_argument(flag, **options)
    return parser


# ---------------------------------------------------------------------------
# pipeline helpers
# ---------------------------------------------------------------------------
def _process_file(parsed: Dict[str, Any], args: argparse.Namespace) -> Dict[str, Any]:
    """Run the correction and kinetics steps for one parsed perifusion file.

    Args:
        parsed: Parser output. ``time_min`` and ``channels`` are required;
            ``source_file``, ``vendor`` and ``meta_hint`` are optional.
        args: CLI namespace; ``dead_volume``, ``flow``, ``storage_hours`` and
            ``ieq`` are read here.

    Returns:
        A dict with the lag-shifted time axis, the traces before and after
        scaling, one kinetic-parameter row per channel, the KCl CV and its
        pass flag, the vehicle-vs-treatment derived metrics, and the
        vendor / source file / meta hint / lag carried through for reporting.
    """
    t_raw = list(parsed["time_min"])
    traces: Dict[str, List[float]] = {
        name: list(values) for name, values in parsed["channels"].items()
    }

    # 1. dead-volume lag: one shifted time axis shared by every channel
    lag = corrections.transit_time_min(args.dead_volume, args.flow)
    t_axis = corrections.shift_time_axis(t_raw, lag)

    # 2. degradation drift, skipped when no storage time is given
    if args.storage_hours > 0:
        traces = {
            name: corrections.degradation_correct(values, t_axis, args.storage_hours)
            for name, values in traces.items()
        }

    # 3. snapshot the traces, then run the baseline step
    snapshot = {name: list(values) for name, values in traces.items()}
    for name, values in list(traces.items()):
        # The window statistics are computed but deliberately unused here:
        # a zero baseline is subtracted so the traces keep their raw scale.
        _baseline, _sd = corrections.baseline_window(t_axis, values, 10.0)
        traces[name] = corrections.subtract_baseline(values, 0.0)

    # 4. per-IEQ normalization with the same IEQ basis for every channel
    traces = {
        name: corrections.per_ieq(values, args.ieq)
        for name, values in traces.items()
    }

    # 5. batch normalization against the KCl window
    scaled, scale_factors, cv_pct = corrections.batch_normalize_to_kcl(
        traces, t_axis, kcl_t_start=40.0, kcl_t_end=50.0, target_peak=40.0
    )

    # one kinetic-parameter row per scaled channel
    source_name = parsed.get("source_file", "")
    vendor_name = parsed.get("vendor", "")
    kin_rows = []
    for name, values in scaled.items():
        row = kinetics.kinetic_params(t_axis, values)
        row["channel"] = name
        row["source_file"] = source_name
        row["vendor"] = vendor_name
        row["scale_factor"] = scale_factors.get(name, 1.0)
        kin_rows.append(row)

    # Vehicle channel: the first whose name starts with "vehicle" (any case),
    # otherwise the first channel, otherwise None.
    names = list(scaled.keys())
    fallback = names[0] if names else None
    vehicle_key = next(
        (name for name in names if name.lower().startswith("vehicle")), fallback
    )

    derived = {}
    if vehicle_key and names:
        veh = next(row for row in kin_rows if row["channel"] == vehicle_key)
        treatments = [row for row in kin_rows if row["channel"] != vehicle_key]
        for row in treatments:
            label = row["channel"]
            if any(tag in label for tag in ("GLP-1", "Exendin")):
                derived[f"{label}_GLP1_potentiation"] = kinetics.glp1_potentiation_index(veh, row)
            if "Palmitate" in label:
                derived[f"{label}_lipotoxicity"] = kinetics.lipotoxicity_delta(veh, row)

    return {
        "time_min": t_axis,
        "raw_channels": snapshot,
        "scaled_channels": scaled,
        "kinetic_rows": kin_rows,
        "kcl_cv_pct": cv_pct,
        "kcl_pass": corrections.kcl_pass(cv_pct),
        "derived": derived,
        "vendor": parsed.get("vendor", ""),
        "source_file": parsed.get("source_file", ""),
        "meta_hint": parsed.get("meta_hint", {}),
        "lag_min": lag,
    }


def _output_stem(source_file: str) -> str:
    """Return *source_file* without its directory part or extension.

    Output artefacts are named after their input; dropping the directory
    keeps them inside the report directory even when the recorded source
    name is a full or relative path.
    """
    return os.path.splitext(os.path.basename(source_file))[0]


def _maturation_lines(kin_rows: List[Dict[str, Any]]) -> List[str]:
    """Format a maturation-index line for every row whose source file name contains "iPSC"."""
    lines: List[str] = []
    for row in kin_rows:
        if "iPSC" not in (row.get("source_file", "") or ""):
            continue
        mi = multianalyte.maturation_index(row)
        day = multianalyte.closest_maturation_day(row)
        lines.append(
            f"- {row['source_file']} / {row['channel']}: "
            f"composite={mi['composite']:.2f} (closest={day})"
        )
    return lines


def _glucagon_lines(parsed_list: List[Dict[str, Any]]) -> List[str]:
    """Format a suppression line per channel of every file whose meta hint says glucagon."""
    lines: List[str] = []
    for parsed in parsed_list:
        if parsed.get("meta_hint", {}).get("analyte") != "glucagon":
            continue
        # Computed on the parsed (uncorrected) time axis and values.
        for ch, vals in parsed["channels"].items():
            s = multianalyte.glucagon_suppression(parsed["time_min"], vals)
            lines.append(
                f"- {parsed['source_file']} / {ch}: "
                f"basal={s['basal_mean']:.1f}, high={s['high_mean']:.1f}, "
                f"suppression={s['suppression_pct']:.1f}%"
            )
    return lines


def _cpi_ratio_lines(traces: Dict[str, Dict[str, List[float]]]) -> List[str]:
    """Format the median C-peptide:insulin ratio for each channel present in both analytes."""
    if "insulin" not in traces or "c-peptide" not in traces:
        return []
    insulin = traces["insulin"]
    c_peptide = traces["c-peptide"]
    lines: List[str] = []
    for ch in insulin:
        if ch not in c_peptide:
            continue
        ratios = multianalyte.c_peptide_insulin_ratio(c_peptide[ch], insulin[ch])
        usable = [x for x in ratios if not math.isnan(x)]
        if usable:
            # The line is labelled "median", so report the median and not the
            # arithmetic mean.
            lines.append(
                f"- {ch}: median C-pep:Ins ratio = {statistics.median(usable):.2f}"
            )
    return lines


def run_pipeline(input_files: List[str], args: argparse.Namespace) -> int:
    """Process every input file and write all report artefacts.

    Args:
        input_files: Paths of the perifusion CSV files to parse.
        args: Parsed CLI namespace (see ``build_parser``).

    Returns:
        ``0`` on success, ``1`` when no input file was parsed.

    Always written to ``args.report_dir`` (created if missing): one trace PNG
    per file, the kinetic parameter CSV/XLSX, ``methods_draft.md`` and
    ``summary.md``. GraphPad exports and the reproducibility checklist are
    written when the corresponding flags are set. A multi-analyte plot is
    written when at least two analytes share a channel name.
    """
    os.makedirs(args.report_dir, exist_ok=True)
    parsed_list = [parsers.parse_perifusion_csv(p) for p in input_files]
    if not parsed_list:
        print("[error] no input files parsed", file=sys.stderr)
        return 1

    # The y-axis label depends only on the CLI option, so resolve it once.
    ylabel = {"insulin": "Insulin (ng/mL)",
              "c-peptide": "C-peptide (ng/mL)",
              "glucagon": "Glucagon (pg/mL)",
              "proinsulin": "Proinsulin (ng/mL)",
              "all": "Analyte"}.get(args.analyte, "Analyte")

    all_kin_rows: List[Dict[str, Any]] = []
    all_kcl_cv: List[float] = []
    n_channels_total = 0
    multi_analyte_traces: Dict[str, Dict[str, List[float]]] = {}  # analyte -> channel -> trace
    first_time_axis = None  # lag-corrected axis of the first processed file

    for parsed in parsed_list:
        result = _process_file(parsed, args)
        if first_time_axis is None:
            first_time_axis = result["time_min"]

        all_kin_rows.extend(result["kinetic_rows"])
        if not math.isnan(result["kcl_cv_pct"]):
            all_kcl_cv.append(result["kcl_cv_pct"])
        n_channels_total += len(result["scaled_channels"])

        stem = _output_stem(result["source_file"])

        # spaghetti plot (per file, scaled traces)
        title = f"{result['source_file']}  ({result['vendor']})"
        png_path = os.path.join(args.report_dir, f"trace_{stem}.png")
        report.plot_spaghetti(
            result["time_min"], result["scaled_channels"], title, png_path, ylabel=ylabel
        )

        # graphpad export
        if args.graphpad:
            gp_csv = os.path.join(args.report_dir, f"graphpad_{stem}.csv")
            report.graphpad_export(result["time_min"], result["scaled_channels"], gp_csv)

        # collect for multi-analyte co-analysis
        analyte_label = result["meta_hint"].get("analyte", args.analyte)
        bucket = multi_analyte_traces.setdefault(analyte_label, {})
        for ch, vals in result["scaled_channels"].items():
            bucket[ch] = vals

    # -------- aggregate kinetic table --------
    report.write_kinetic_csv(
        all_kin_rows, os.path.join(args.report_dir, "kinetic_parameters.csv")
    )
    xlsx_ok = report.write_kinetic_xlsx(
        all_kin_rows, os.path.join(args.report_dir, "kinetic_parameters.xlsx")
    )

    # -------- multi-analyte plot for the first channel common to all analytes --------
    if len(multi_analyte_traces) >= 2:
        shared = set.intersection(*(set(t) for t in multi_analyte_traces.values()))
        common = min(shared) if shared else None
        if common:
            ma_png = os.path.join(args.report_dir, f"multi_analyte_{common}.png")
            # The collected traces are lag-corrected, so plot them against the
            # corrected time axis of the first file rather than its raw axis.
            report.plot_multi_analyte(first_time_axis, multi_analyte_traces, common, ma_png)

    # -------- optional report sections --------
    maturation_lines = _maturation_lines(all_kin_rows)
    gcg_lines = _glucagon_lines(parsed_list)
    cpi_lines = _cpi_ratio_lines(multi_analyte_traces)

    # -------- methods + checklist --------
    overall_cv = sum(all_kcl_cv) / len(all_kcl_cv) if all_kcl_cv else float("nan")
    cv_known = not math.isnan(overall_cv)
    lag_min = corrections.transit_time_min(args.dead_volume, args.flow)
    first_vendor = parsed_list[0].get("vendor", "in-house")
    ctx = {
        "n_files": len(parsed_list),
        "n_channels": n_channels_total,
        "kcl_cv_pct": f"{overall_cv:.2f}" if cv_known else "n/a",
        "kcl_pass": cv_known and overall_cv <= 15.0,
        "vendor": first_vendor,
        "flow_rate_ml_min": args.flow,
        "dead_volume_ul": args.dead_volume,
        "fill": {
            "Sample source": "; ".join(
                p.get("meta_hint", {}).get("sample_kind", "n/a") for p in parsed_list
            ),
            "Flow rate": f"{args.flow} mL/min",
            "Dead volume + transit time": f"{args.dead_volume} uL → {lag_min:.2f} min",
            "Reference KCl peak": f"30 mM, CV {overall_cv:.2f}%" if cv_known else "30 mM, CV n/a",
            "Stimulation protocol": "2.8/16.7 mM glucose 10/30 min + 30 mM KCl 10 min",
            "Perifusion system": first_vendor,
            "IEQ / cell number / protein": f"IEQ basis = {args.ieq}",
        },
    }

    report.methods_draft_md(ctx, os.path.join(args.report_dir, "methods_draft.md"))

    if args.checklist:
        chk_md = os.path.join(args.report_dir, "reproducibility_checklist.md")
        report.write_checklist_md(chk_md, ctx)
        chk_pdf = os.path.join(args.report_dir, "reproducibility_checklist.pdf")
        report.checklist_to_pdf(chk_md, chk_pdf)

    if args.graphpad:
        report.graphpad_pzfx_stub(os.path.join(args.report_dir, "kinetics.pzfx"))

    # -------- summary text --------
    gate = "PASS" if ctx["kcl_pass"] else "FAIL"
    summary_path = os.path.join(args.report_dir, "summary.md")
    with open(summary_path, "w", encoding="utf-8") as f:
        f.write("# IsletPerifusionAnalyzer summary\n\n")
        f.write(f"- files: {len(parsed_list)}\n")
        f.write(f"- channels: {n_channels_total}\n")
        f.write(f"- KCl intra-batch CV: {ctx['kcl_cv_pct']}\n")
        f.write(f"- KCl gate (<=15%): {gate}\n")
        f.write(f"- transit lag: {lag_min:.2f} min\n")
        f.write(f"- xlsx written: {xlsx_ok}\n")
        if maturation_lines:
            f.write("\n## iPSC-SC-beta maturation\n" + "\n".join(maturation_lines) + "\n")
        if gcg_lines:
            f.write("\n## Glucagon suppression\n" + "\n".join(gcg_lines) + "\n")
        if cpi_lines:
            f.write("\n## C-peptide : insulin ratio\n" + "\n".join(cpi_lines) + "\n")
        f.write("\n## Disclaimer\n본 도구는 연구·참고용이며 임상 의사결정에 직접 사용 금지.\n")

    print(f"[ok] wrote report to {args.report_dir}")
    print(f"     files={len(parsed_list)}, channels={n_channels_total}, "
          f"kcl_cv={ctx['kcl_cv_pct']}, gate={gate}")
    return 0


def main(argv=None) -> int:
    """CLI entry point: parse arguments, resolve input files, run the pipeline.

    Args:
        argv: Argument list to parse; ``None`` makes argparse read
            ``sys.argv[1:]``.

    Returns:
        The process exit status: ``2`` when neither ``--demo`` nor ``--input``
        is given or the ``--input`` path does not exist, otherwise whatever
        ``run_pipeline`` returns.
    """
    args = build_parser().parse_args(argv)
    if not args.demo and not args.input:
        print("[error] one of --demo or --input is required", file=sys.stderr)
        return 2

    if args.demo:
        # Synthetic inputs are generated in a ``data`` directory next to this file.
        demo_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "data")
        os.makedirs(demo_dir, exist_ok=True)
        files = synth.make_demo_set(demo_dir)
        # demo always emits checklist + graphpad
        args.checklist = True
        args.graphpad = True
        return run_pipeline(files, args)

    # --input: a directory contributes its *.csv files (sorted by name);
    # any other existing path is treated as a single input file.
    if os.path.isdir(args.input):
        files = [
            os.path.join(args.input, name)
            for name in sorted(os.listdir(args.input))
            if name.lower().endswith(".csv")
        ]
    elif os.path.exists(args.input):
        files = [args.input]
    else:
        # Fail fast with a clear message instead of handing a missing path
        # to the CSV parser.
        print(f"[error] input path not found: {args.input}", file=sys.stderr)
        return 2
    return run_pipeline(files, args)


# Script entry point: run the CLI and use main()'s return value as the exit status.
if __name__ == "__main__":
    raise SystemExit(main())
