#!/usr/bin/env python3
"""InsuTrialKit-Kor — CLI entry point.

Subcommands:
  ingest      Scan a directory of raw exports → write standardized long-form CSV
              (24 h per-subject break-in cut applied)
  outcomes    Compute ADA 2023 14-metric outcomes → write Markdown report
              (plus a side-by-side <report>_outcomes.csv table)
  samplesize  Compute per-arm N for a two-arm RCT on a continuous outcome
  gendata     Regenerate synthetic raw exports
  demo        Run the full pipeline on synthetic data
"""

from __future__ import annotations
import argparse
import os
import sys
import glob
import json

import pandas as pd

# Absolute directory containing this script; `demo` also writes its outputs here.
HERE = os.path.dirname(os.path.abspath(__file__))
# Put this directory first on the import path so the sibling modules imported
# below (adapters, outcomes, power, report) and the `data` package used by
# `gendata` resolve regardless of the current working directory.
sys.path.insert(0, HERE)

from adapters import dispatch, ADAPTERS  # noqa: E402
from outcomes import compute_outcomes, detect_gaps, apply_breakin_cut  # noqa: E402
from power import sample_size_two_arm, load_effect_size_db, by_hba1c_tertile  # noqa: E402
from report import build_markdown_report  # noqa: E402


# --------------------------------------------------------------------------
# ingest
# --------------------------------------------------------------------------
def cmd_ingest(args):
    in_dir = args.input
    out_path = args.output
    if not os.path.isdir(in_dir):
        print(f"[ingest] ERROR: input dir not found: {in_dir}", file=sys.stderr)
        return 2

    # gather candidate files (csv, xlsx, json)
    files = []
    for ext in ("*.csv", "*.CSV", "*.xlsx", "*.json"):
        files.extend(glob.glob(os.path.join(in_dir, ext)))
    files = sorted(set(files))
    if not files:
        print(f"[ingest] no candidate files found in {in_dir}")
        return 1

    frames = []
    for f in files:
        adapter, df = dispatch(f)
        if adapter is None:
            print(f"[ingest] skip (no adapter): {os.path.basename(f)}")
            continue
        print(f"[ingest] {os.path.basename(f)} → {adapter.NAME}: {len(df)} rows")
        # Tag source for traceability
        df = df.copy()
        df["_source_file"] = os.path.basename(f)
        df["_adapter"] = adapter.NAME
        frames.append(df)

    if not frames:
        print("[ingest] nothing parsed")
        return 1
    merged = pd.concat(frames, ignore_index=True)
    # KST is assumed; if tz-aware, convert. Otherwise pass-through.
    try:
        merged["timestamp_KST"] = pd.to_datetime(merged["timestamp_KST"])
    except Exception:
        pass

    # apply break-in cut (24h) per subject
    merged_cut = apply_breakin_cut(merged, hours=24)
    print(f"[ingest] after 24h break-in cut: {len(merged_cut)} rows "
          f"(was {len(merged)})")

    merged_cut.to_csv(out_path, index=False)
    print(f"[ingest] wrote {out_path}: {len(merged_cut)} rows, "
          f"{merged_cut['subject_id'].nunique()} subjects")
    return 0


# --------------------------------------------------------------------------
# outcomes
# --------------------------------------------------------------------------
def cmd_outcomes(args):
    in_path = args.input
    out_path = args.output
    if not os.path.isfile(in_path):
        print(f"[outcomes] ERROR: input not found: {in_path}", file=sys.stderr)
        return 2

    df = pd.read_csv(in_path)
    df["timestamp_KST"] = pd.to_datetime(df["timestamp_KST"], errors="coerce")
    print(f"[outcomes] loaded {len(df)} rows, {df['subject_id'].nunique()} subjects")

    outcomes = compute_outcomes(df)
    gaps = detect_gaps(df, gap_min=30)

    # also compute default sample size for the report
    db = load_effect_size_db()
    default = db.get("default", {"delta_TIR_pct": 5.0, "sd_TIR_pct": 12.0})
    power_res = sample_size_two_arm(default["delta_TIR_pct"], default["sd_TIR_pct"])

    md = build_markdown_report(outcomes, gaps, power_res)
    with open(out_path, "w", encoding="utf-8") as f:
        f.write(md)

    # also dump outcomes csv side-by-side
    outcomes_csv = os.path.splitext(out_path)[0] + "_outcomes.csv"
    outcomes.to_csv(outcomes_csv, index=False)

    print(f"[outcomes] wrote report: {out_path}")
    print(f"[outcomes] wrote outcomes CSV: {outcomes_csv}")
    print(f"[outcomes] gaps detected: {len(gaps)}")
    if not outcomes.empty:
        print("[outcomes] sanity check:")
        print(outcomes[["subject_id", "TIR_70_180_pct", "GMI_pct", "CV_pct"]].to_string(index=False))
    return 0


# --------------------------------------------------------------------------
# samplesize
# --------------------------------------------------------------------------
def cmd_samplesize(args):
    """Print the per-arm N for a two-arm parallel RCT on a continuous outcome.

    ``args`` supplies ``delta``, ``sd``, ``alpha`` and ``power`` for
    ``sample_size_two_arm``. Every key/value of its result is printed. When
    ``args.tertiles`` is truthy, the HbA1c tertile-stratified rows from the
    effect-size DB are printed as well. Always returns 0.
    """
    result = sample_size_two_arm(args.delta, args.sd, args.alpha, args.power)
    print("[samplesize] two-arm parallel RCT, continuous outcome (e.g., ΔTIR)")
    for key, value in result.items():
        print(f"  {key}: {value}")

    if not args.tertiles:
        return 0

    effect_db = load_effect_size_db()
    tertile_rows = by_hba1c_tertile(effect_db, args.alpha, args.power)
    print("\n[samplesize] HbA1c tertile-stratified (Korean reference DB):")
    for row in tertile_rows:
        label = row.get('tertile', '?')
        hba1c_span = row.get('hba1c_range', '?')
        delta_pct = row.get('delta_TIR_pct')
        sd_pct = row.get('sd_TIR_pct')
        n_arm = row.get('n_per_arm', '?')
        print(f"  - {label} ({hba1c_span}): Δ={delta_pct}%, SD={sd_pct}%, N/arm={n_arm}")
    return 0


# --------------------------------------------------------------------------
# gendata
# --------------------------------------------------------------------------
def cmd_gendata(args):
    """Regenerate the synthetic raw exports and echo each path written.

    ``args`` exists only to match the subcommand-handler signature and is
    not read. Always returns 0.
    """
    # Imported lazily so the other subcommands don't need the data package.
    from data.gen_synth import generate_all

    for written in generate_all():
        print(f"[gendata] wrote {written}")
    return 0


# --------------------------------------------------------------------------
# demo (end-to-end)
# --------------------------------------------------------------------------
def cmd_demo(args):
    """Run gendata → ingest → outcomes → samplesize end-to-end on synthetic data.

    All artefacts are written under ``HERE``. The pipeline stops at the first
    step that returns a non-zero code and propagates that code. Otherwise it
    lists the produced files and returns 0.
    """
    data_dir = os.path.join(HERE, "data")
    ingest_csv = os.path.join(HERE, "out.csv")
    report_md = os.path.join(HERE, "report.md")
    # Derived exactly as cmd_outcomes derives its side-by-side CSV, so the
    # path listed at the end cannot drift from the one actually written.
    outcomes_csv = os.path.splitext(report_md)[0] + "_outcomes.csv"
    # Single source for the demo's sample-size inputs and their banner text.
    delta, sd, alpha, power = 5.0, 12.0, 0.05, 0.80

    print("=== InsuTrialKit-Kor demo: end-to-end ===\n")

    steps = (
        ("[1/4] generating synthetic raw exports...",
         cmd_gendata, args),
        ("\n[2/4] ingesting raw exports → out.csv",
         cmd_ingest, argparse.Namespace(input=data_dir, output=ingest_csv)),
        ("\n[3/4] computing outcomes + writing report.md",
         cmd_outcomes, argparse.Namespace(input=ingest_csv, output=report_md)),
        (f"\n[4/4] sample size calculation (default: ΔTIR={delta:g}%, SD={sd:g}%, "
         f"α={alpha:g}, power={power:g})",
         cmd_samplesize, argparse.Namespace(delta=delta, sd=sd, alpha=alpha,
                                            power=power, tertiles=True)),
    )
    for banner, handler, step_args in steps:
        print(banner)
        rc = handler(step_args)
        if rc != 0:
            return rc

    print("\n=== demo complete ===")
    print("Outputs:")
    for path in (ingest_csv, report_md, outcomes_csv):
        print(f"  - {path}")
    return 0


# --------------------------------------------------------------------------
def build_parser() -> argparse.ArgumentParser:
    p = argparse.ArgumentParser(
        prog="insutrialkit-kor",
        description="InsuTrialKit-Kor: multi-device CGM/pump standardizer + ADA 2023 "
                    "outcomes + Korean RCT report generator. RESEARCH USE ONLY.",
    )
    sp = p.add_subparsers(dest="cmd", required=True)

    p_ing = sp.add_parser("ingest", help="standardize raw exports → long-form CSV")
    p_ing.add_argument("--input", required=True, help="directory of raw exports")
    p_ing.add_argument("--output", required=True, help="output CSV path")
    p_ing.set_defaults(func=cmd_ingest)

    p_out = sp.add_parser("outcomes", help="compute ADA 2023 14-metric + report.md")
    p_out.add_argument("--input", required=True, help="long-form CSV")
    p_out.add_argument("--output", required=True, help="report.md output path")
    p_out.set_defaults(func=cmd_outcomes)

    p_ss = sp.add_parser("samplesize", help="two-arm parallel RCT N/arm")
    p_ss.add_argument("--delta", type=float, required=True, help="effect (e.g., ΔTIR %)")
    p_ss.add_argument("--sd", type=float, required=True, help="SD (e.g., 12 for TIR)")
    p_ss.add_argument("--alpha", type=float, default=0.05)
    p_ss.add_argument("--power", type=float, default=0.80)
    p_ss.add_argument("--tertiles", action="store_true",
                      help="also show HbA1c tertile-stratified N from KR DB")
    p_ss.set_defaults(func=cmd_samplesize)

    p_gd = sp.add_parser("gendata", help="(re)generate synthetic raw exports")
    p_gd.set_defaults(func=cmd_gendata)

    p_dm = sp.add_parser("demo", help="end-to-end demo on synthetic data")
    p_dm.set_defaults(func=cmd_demo)

    return p


def main(argv=None):
    """Parse *argv* (``None`` means ``sys.argv[1:]``) and run the chosen subcommand.

    Returns whatever exit code the subcommand handler returns.
    """
    namespace = build_parser().parse_args(argv)
    handler = namespace.func
    return handler(namespace)


if __name__ == "__main__":
    # Use the subcommand's return value as the process exit status.
    raise SystemExit(main())
