#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
NITSurrogate-Kor (엔아이티서로게이트코어)
=========================================================================
Domain   : MASLD / MASH (대사성간질환)
Category : 연구 아이디어 생성 (research-hypothesis generation)

A standalone, OFFLINE tool that ingests trial-level effect-size pairs along
the chain

        NIT (surrogate)  ->  Histology (intermediate)  ->  Hard hepatic outcome

from MASH RCTs, computes stage-by-stage TRIAL-LEVEL SURROGACY
(R^2_trial, surrogate threshold effect [STE], proportion-of-treatment-effect
[PTE] mediation) via inverse-variance weighted meta-regression, flags which
chain stage is unvalidated (especially the sparse histology->hard stage), and
emits validation-study hypotheses with required outcome-trial size / follow-up.

Methodology: Buyse-Molenberghs / Daniels-Hughes trial-level surrogacy with
weighted least squares -- implemented here in pure numpy/scipy because
statsmodels is intentionally NOT a dependency.

*** 연구용·참고용 (research / reference use only) -- NOT for clinical
    decision-making. Demo data are ILLUSTRATIVE / SYNTHETIC, not trial readouts. ***
=========================================================================
"""

import argparse
import os
import sys
import textwrap

import numpy as np
import pandas as pd
from scipy import stats

# --------------------------------------------------------------------------- #
# Constants / configuration
# --------------------------------------------------------------------------- #

# Research-use disclaimer; printed inside the banner and again at the end of
# main().
DISCLAIMER = (
    "⚠️  연구용·참고용 (research/reference use only) — NOT for clinical "
    "decision-making.\n"
    "    Demo effect sizes are ILLUSTRATIVE / SYNTHETIC, not official trial "
    "readouts."
)

# Bundled demo CSV, resolved relative to this file's directory; used as the
# default value of the --data CLI option.
DEFAULT_DATA = os.path.join(
    os.path.dirname(os.path.abspath(__file__)), "data", "masld_surrogacy_demo.csv"
)

# Surrogacy strength grade thresholds on R^2_trial (configurable via CLI).
# main() overwrites these module globals from --grade-strong/--grade-moderate.
GRADE_STRONG = 0.70
GRADE_MODERATE = 0.50

# NITs (surrogates) the tool knows about; per-NIT analyses iterate this list.
NIT_METRICS = ["FIB4", "LSM_VCTE", "MRI_PDFF", "ELF", "MRE"]

# The three surrogacy stages of the chain, each given as
# (display label, upstream layer key, downstream layer key). The layer keys
# are the ones accepted by stage_columns().
STAGES = [
    ("NIT->histology", "nit", "histo"),
    ("histology->hard", "histo", "hard"),
    ("NIT->hard", "nit", "hard"),
]

# Minimum number of trials below which a stage is flagged "data-sparse".
SPARSE_MIN_TRIALS = 4

# Default assumptions used to size a proposed validation outcome trial.
DEFAULT_TARGET_HR = 0.75          # hazard ratio we'd want to detect
DEFAULT_BASELINE_EVENT_RATE = 0.20  # cumulative hard-event rate over follow-up
DEFAULT_ALPHA = 0.05              # two-sided type-I error for the sizing formula
DEFAULT_POWER = 0.80              # target power for the sizing formula
DEFAULT_FOLLOWUP_YEARS = 4.0      # follow-up printed for outcome-trial hypotheses


# --------------------------------------------------------------------------- #
# Small helpers
# --------------------------------------------------------------------------- #

def _hr(width=74):
    return "-" * width


def banner():
    """Print the tool header: title, domain line and research-use disclaimer."""
    frame = "=" * 74
    header_lines = (
        frame,
        "NITSurrogate-Kor  |  MASLD/MASH 3-stage trial-level surrogacy",
        "Domain: MASLD (대사성간질환)   Category: 연구 아이디어 생성",
        _hr(),
        DISCLAIMER,
        frame,
    )
    for line in header_lines:
        print(line)


def grade_from_r2(r2, paradox=False):
    """
    Translate an R^2_trial value into a surrogacy grade label.

    A truthy ``paradox`` flag wins over everything and yields
    "INVALID(paradox)". A missing (None / NaN) R^2 yields "n/a". Otherwise the
    value is compared against the module thresholds GRADE_STRONG and
    GRADE_MODERATE (read at call time, so CLI overrides apply) and graded
    "strong", "moderate" or "weak".
    """
    if paradox:
        return "INVALID(paradox)"
    if r2 is None or np.isnan(r2):
        return "n/a"
    # Thresholds are checked from the most demanding grade downwards.
    for cutoff, label in ((GRADE_STRONG, "strong"), (GRADE_MODERATE, "moderate")):
        if r2 >= cutoff:
            return label
    return "weak"


# --------------------------------------------------------------------------- #
# Data loading + normalization
# --------------------------------------------------------------------------- #

def load_data(path):
    """
    Load the tidy surrogacy CSV and derive sign-normalized benefit columns.

    Lines starting with '#' are treated as comments and blank lines are
    skipped.

    Parameters
    ----------
    path : path of the CSV file to read.

    Returns
    -------
    pandas.DataFrame with the original columns plus
    ``nit_benefit`` / ``nit_benefit_se``, ``histo_benefit`` /
    ``histo_benefit_se`` and ``hard_benefit`` / ``hard_benefit_se``, where for
    every layer "more positive = more benefit".

    Raises
    ------
    ValueError if any required column is absent. Errors raised by
    ``pandas.read_csv`` (e.g. FileNotFoundError) propagate unchanged.
    """
    df = pd.read_csv(path, comment="#", skip_blank_lines=True)
    expected = {
        "trial", "drug", "drug_class", "nit_metric",
        "delta_nit", "delta_nit_se", "histo_metric", "histo_effect",
        "histo_se", "hard_outcome", "loghr", "loghr_se",
    }
    missing = expected - set(df.columns)
    if missing:
        raise ValueError(
            "CSV is missing required columns: %s" % ", ".join(sorted(missing))
        )

    # Coerce numerics; empty / unparsable cells become NaN. The explicit float
    # cast matters: a column holding only integers would otherwise stay int64,
    # and later in-place standardization would try to write floats into it.
    for col in ("delta_nit", "delta_nit_se", "histo_effect", "histo_se",
                "loghr", "loghr_se"):
        df[col] = pd.to_numeric(df[col], errors="coerce").astype(float)

    # ----- sign normalization -----------------------------------------------
    #   delta_nit   : negative = improvement          -> flip sign
    #   histo_effect: positive log-OR = improvement (kept as is); rows whose
    #                 histo_metric is 'dNAS' are negative = good -> flip sign
    #   loghr       : negative = benefit              -> flip sign
    df["nit_benefit"] = -df["delta_nit"]
    df["nit_benefit_se"] = df["delta_nit_se"]

    # Case-insensitive match that also tolerates padding such as " dNAS".
    is_dnas = (
        df["histo_metric"].astype(str).str.strip().str.upper().eq("DNAS")
    )
    histo = df["histo_effect"].copy()
    histo[is_dnas] = -histo[is_dnas]
    df["histo_benefit"] = histo
    df["histo_benefit_se"] = df["histo_se"]

    df["hard_benefit"] = -df["loghr"]
    df["hard_benefit_se"] = df["loghr_se"]

    return df


def stage_columns(stage_key):
    """
    Look up the (effect column, SE column) names for a chain layer.

    ``stage_key`` is one of "nit", "histo" or "hard"; any other key raises
    KeyError.
    """
    lookup = {
        layer: ("%s_benefit" % layer, "%s_benefit_se" % layer)
        for layer in ("nit", "histo", "hard")
    }
    return lookup[stage_key]


def stage_pairs(df, upstream, downstream, nit_filter=None):
    """
    Build the regression inputs for one surrogacy stage.

    For stages that involve the NIT layer, the unit of analysis is a
    (trial x NIT-metric) row. For histology->hard we collapse to one row per
    (trial, histo_metric), because those effects are trial-level rather than
    NIT-specific.

    When the upstream layer is the NIT layer and we are POOLING multiple NIT
    metrics (nit_filter is None), the different NITs live on incommensurable
    scales (MRI-PDFF relative % vs FIB-4 points vs LSM kPa). We therefore
    z-standardize the upstream effect WITHIN each nit_metric before pooling,
    so the surrogacy regression measures the cross-trial relationship rather
    than an artifact of scale. NOTE: after this centering the sign of ``x`` no
    longer means benefit/harm. Single-NIT analyses are left on their native
    scale so the STE is interpretable in that NIT's units.

    Parameters
    ----------
    df         : frame produced by load_data().
    upstream   : layer key of the predictor ("nit", "histo" or "hard").
    downstream : layer key of the response.
    nit_filter : optional nit_metric value; when given only rows of that NIT
                 are used.

    Returns
    -------
    None when no row has all four effect/SE values present, otherwise the
    6-tuple ``(x, x_se, y, y_se, w, meta)`` where the first five are float
    arrays (``w`` = inverse downstream variance) and ``meta`` is a DataFrame
    with columns trial / drug / nit_metric aligned row-by-row with the arrays.
    """
    ux, uxse = stage_columns(upstream)
    dy, dyse = stage_columns(downstream)
    needed = [ux, uxse, dy, dyse]

    work = df
    if nit_filter is not None:
        work = work[work["nit_metric"] == nit_filter]

    # Drop incomplete rows BEFORE de-duplicating: otherwise a trial whose
    # first row happens to lack a value would be discarded even though one of
    # its later rows is complete.
    work = work.dropna(subset=needed)

    # histology->hard does not depend on the NIT metric: dedupe per trial.
    if upstream == "histo" and downstream == "hard":
        work = work.drop_duplicates(subset=["trial", "histo_metric"])

    sub = work[needed + ["trial", "drug", "nit_metric"]].reset_index(drop=True)
    if sub.empty:
        return None

    # Guarantee float storage so the standardized values below can be written
    # back in place even if the caller supplied integer-typed effects.
    sub = sub.astype({ux: float, uxse: float})

    # Within-NIT standardization for the pooled NIT-upstream case.
    if upstream == "nit" and nit_filter is None:
        for _, idx in sub.groupby("nit_metric").groups.items():
            vals = sub.loc[idx, ux].to_numpy(float)
            sd = vals.std(ddof=0)
            if sd > 1e-9:
                sub.loc[idx, ux] = (vals - vals.mean()) / sd
                sub.loc[idx, uxse] = sub.loc[idx, uxse].to_numpy(float) / sd
            else:
                # Degenerate group (single row / identical values): center
                # only, there is no spread to scale by.
                sub.loc[idx, ux] = vals - vals.mean()

    x = sub[ux].to_numpy(dtype=float)
    xse = sub[uxse].to_numpy(dtype=float)
    y = sub[dy].to_numpy(dtype=float)
    yse = sub[dyse].to_numpy(dtype=float)

    # Inverse-variance weights based on the *downstream* (response) uncertainty
    # -- standard for trial-level surrogacy WLS (Buyse-Molenberghs). The clip
    # keeps a zero SE from producing an infinite weight.
    w = 1.0 / np.clip(yse ** 2, 1e-9, None)
    meta = sub[["trial", "drug", "nit_metric"]].reset_index(drop=True)
    return x, xse, y, yse, w, meta


# --------------------------------------------------------------------------- #
# Core statistics: weighted regression, R^2_trial, STE, PTE
# --------------------------------------------------------------------------- #

def weighted_linfit(x, y, w):
    """
    Fit y = b0 + b1*x by weighted least squares.

    Returns a dict holding the coefficients ("beta"), their covariance
    ("cov_beta"), the weighted R^2 clipped to [0, 1] ("r2", NaN when the
    weighted total sum of squares is not positive), "n", "dof" (= max(n-2, 1)),
    the weighted residual variance ("sigma2"), the inputs "x"/"y"/"w" and the
    weighted mean of y ("ybar_w").
    """
    n_obs = len(x)
    design = np.column_stack([np.ones(n_obs), x])
    design_t_w = design.T @ np.diag(w)

    # Pseudo-inverse with rcond so a (near-)singular design (e.g. all-equal x)
    # does not blow up; any non-finite coefficient is zeroed.
    normal_inv = np.linalg.pinv(design_t_w @ design, rcond=1e-10)
    coef = np.nan_to_num(
        normal_inv @ (design_t_w @ y), nan=0.0, posinf=0.0, neginf=0.0
    )

    residuals = y - design @ coef
    ss_res = float((w * residuals ** 2).sum())
    dof = max(n_obs - 2, 1)
    sigma2 = ss_res / dof  # mean squared weighted residual

    # Weighted R^2 around the weighted mean of y.
    ybar_w = float((w * y).sum() / w.sum())
    ss_tot = float((w * (y - ybar_w) ** 2).sum())
    if ss_tot > 0:
        r2 = float(np.clip(1.0 - ss_res / ss_tot, 0.0, 1.0))
    else:
        r2 = float("nan")

    cov_beta = np.nan_to_num(
        sigma2 * normal_inv, nan=0.0, posinf=0.0, neginf=0.0
    )

    return {
        "beta": coef,
        "cov_beta": cov_beta,
        "r2": r2,
        "n": n_obs,
        "dof": dof,
        "sigma2": sigma2,
        "x": x,
        "y": y,
        "w": w,
        "ybar_w": ybar_w,
    }


def r2_confidence_interval(r2, n, alpha=0.05):
    """
    Approximate (1 - alpha) CI for R^2 via the Fisher z-transform of
    r = sqrt(R^2); the correlation is taken as non-negative.

    Returns (NaN, NaN) when R^2 is missing or n < 4. If the interval for r
    extends below zero the lower R^2 bound is reported as 0.
    """
    if r2 is None or np.isnan(r2) or n < 4:
        nan = float("nan")
        return (nan, nan)

    # Cap r just below 1 so arctanh stays finite.
    r = min(np.sqrt(np.clip(r2, 0, 1)), 0.9999)
    z = np.arctanh(r)
    se = 1.0 / np.sqrt(n - 3)
    zc = stats.norm.ppf(1 - alpha / 2)
    lo_r, hi_r = np.tanh(z - zc * se), np.tanh(z + zc * se)

    if lo_r < 0:
        lower = 0.0
    else:
        lower = max(0.0, float(lo_r ** 2))
    upper = min(1.0, float(np.clip(hi_r, 0, 1) ** 2))
    return (lower, upper)


def predict_with_band(fit, x_grid, alpha=0.05):
    """
    Evaluate the fitted line on ``x_grid`` and return
    (mean, lower band, upper band).

    The band half-width is the t critical value (``fit["dof"]`` degrees of
    freedom) times the square root of [variance of the fitted mean +
    ``fit["sigma2"]``]. Non-finite means / half-widths are replaced by 0.
    """
    t_crit = stats.t.ppf(1 - alpha / 2, fit["dof"])
    grid_design = np.column_stack([np.ones(len(x_grid)), x_grid])

    with np.errstate(over="ignore", invalid="ignore", divide="ignore"):
        center = grid_design @ fit["beta"]
        # Row-wise quadratic form g' Cov g gives the variance of the fitted
        # mean; adding sigma2 accounts for residual scatter of a new point.
        var_new = (
            np.einsum("ij,jk,ik->i", grid_design, fit["cov_beta"], grid_design)
            + fit["sigma2"]
        )
        half_width = t_crit * np.sqrt(np.clip(var_new, 0, None))

    center, half_width = (
        np.nan_to_num(arr, nan=0.0, posinf=0.0, neginf=0.0)
        for arr in (center, half_width)
    )
    return center, center - half_width, center + half_width


def surrogate_threshold_effect(fit, alpha=0.05):
    """
    STE = the minimum *upstream* benefit at which the LOWER prediction band of
    the downstream benefit reaches zero (i.e. predicts a real downstream
    benefit with confidence). Smaller STE = better/more useful surrogate.

    Parameters
    ----------
    fit   : dict returned by weighted_linfit() (uses "x" and "beta" here, and
            whatever predict_with_band() reads).
    alpha : two-sided level of the prediction band.

    Returns
    -------
    (ste, achievable_flag). The flag is False -- and ste is NaN -- when
    fewer than 3 points were fitted, when the fitted slope is not positive, or
    when the lower band never reaches zero on the search grid, which spans
    from 0.25 observed ranges below the smallest x to 1.5 ranges above the
    largest x.
    """
    not_achievable = (float("nan"), False)

    x = fit["x"]
    if len(x) < 3:
        return not_achievable

    # A threshold only makes sense when more upstream benefit predicts more
    # downstream benefit. With a flat or negative slope the region where the
    # lower band is >= 0 (if any) lies at the LOW-benefit end of the grid, and
    # reporting its first point as a "minimum required benefit" would be
    # misleading.
    if not float(fit["beta"][1]) > 0.0:
        return not_achievable

    x_min, x_max = float(np.min(x)), float(np.max(x))
    span = max(x_max - x_min, 1e-6)  # avoid a zero-width grid
    grid = np.linspace(x_min - 0.25 * span, x_max + 1.5 * span, 2000)
    _, lower, _ = predict_with_band(fit, grid, alpha=alpha)

    # Smallest grid x where the lower band is >= 0.
    pos = np.where(lower >= 0.0)[0]
    if len(pos) == 0:
        return not_achievable
    return (float(grid[pos[0]]), True)


def proportion_treatment_effect(df):
    """
    PTE mediation across the chain: of the total NIT->hard relationship, how
    much is mediated through histology?

    Difference-of-coefficients decomposition on trial-level weighted
    regressions (weights = inverse variance of the response):

        total  : hard_benefit ~ nit_benefit                  -> b_total
        a-path : histo_benefit ~ nit_benefit                 -> a_path
        direct : hard_benefit ~ nit_benefit + histo_benefit  -> b_direct
                 (coefficient on nit_benefit)

        PTE = (b_total - b_direct) / b_total  = indirect / total

    Only trials having all three layers are used, collapsed to one row per
    trial: the FIB4 row when the trial has one, otherwise the trial's first
    row in file order.

    Returns a dict with keys n_full_trials, pte_raw, pte_clamped (raw value
    clipped to [0, 1]), b_total, b_direct, a_path, plausible (True when the
    raw PTE lies in [0, 1]) and note (explanatory text, empty when the
    estimate is plausible). With fewer than 3 usable trials all estimates stay
    NaN and the note explains why.
    """
    # Rows with all three layers present.
    full = df.dropna(subset=["nit_benefit", "nit_benefit_se",
                             "histo_benefit", "histo_benefit_se",
                             "hard_benefit", "hard_benefit_se"]).copy()

    # Collapse to one row per trial, because hard + histo are trial-level.
    # FIB4 rows are ranked first; the stable sort keeps the remaining rows in
    # their original order so "else first NIT row" holds. (A plain
    # alphabetical sort on nit_metric would pick ELF ahead of FIB4.)
    full["_fib4_rank"] = (~full["nit_metric"].eq("FIB4")).astype(int)
    full = (
        full.sort_values("_fib4_rank", kind="mergesort")
        .drop_duplicates(subset=["trial"])
        .drop(columns="_fib4_rank")
    )

    n = len(full)
    nan = float("nan")
    result = {
        "n_full_trials": n,
        "pte_raw": nan,
        "pte_clamped": nan,
        "b_total": nan,
        "b_direct": nan,
        "a_path": nan,
        "plausible": False,
        "note": "",
    }
    if n < 3:
        result["note"] = (
            "Too few trials with NIT+histology+hard simultaneously (%d). "
            "Mediation not estimable -- this is the structural data gap in "
            "MASLD surrogacy." % n
        )
        return result

    x = full["nit_benefit"].to_numpy(float)
    m = full["histo_benefit"].to_numpy(float)
    yv = full["hard_benefit"].to_numpy(float)
    w = 1.0 / np.clip(full["hard_benefit_se"].to_numpy(float) ** 2, 1e-9, None)

    # total: y ~ x
    b_total = float(weighted_linfit(x, yv, w)["beta"][1])

    # a-path: m ~ x  (weighted by the histology SE)
    wa = 1.0 / np.clip(full["histo_benefit_se"].to_numpy(float) ** 2, 1e-9, None)
    a_path = float(weighted_linfit(x, m, wa)["beta"][1])

    # direct: y ~ x + m  (weighted). With n == 3 this three-coefficient model
    # has no residual degrees of freedom.
    Xd = np.column_stack([np.ones(n), x, m])
    XtW = Xd.T @ np.diag(w)
    beta_d = np.linalg.pinv(XtW @ Xd) @ (XtW @ yv)
    b_direct = float(beta_d[1])

    # Guard against dividing by a (near-)zero total effect.
    pte_raw = (b_total - b_direct) / b_total if abs(b_total) > 1e-9 else nan
    estimable = not np.isnan(pte_raw)
    plausible = estimable and (0.0 <= pte_raw <= 1.0)
    pte_clamped = float(np.clip(pte_raw, 0.0, 1.0)) if estimable else nan

    result.update({
        "pte_raw": pte_raw,
        "pte_clamped": pte_clamped,
        "b_total": b_total,
        "b_direct": b_direct,
        "a_path": a_path,
        "plausible": plausible,
        "note": ("" if plausible else
                 "PTE outside [0,1] -> mediation estimate unstable "
                 "(expected with sparse hard-outcome data); interpret with "
                 "caution."),
    })
    return result


def detect_paradox(meta_x, meta_y):
    """
    Trial-level surrogate-paradox filter.

    ``meta_x`` is an iterable of 5-tuples (x, y, trial, drug, nit). Every
    tuple whose upstream effect x is > 0 (benefit) while its downstream effect
    y is < 0 (harm) is returned as (trial, drug, nit, x, y), in input order.
    The comparison is on the point estimates only. ``meta_y`` is accepted but
    not used.
    """
    return [
        (trial, drug, nit, xi, yi)
        for xi, yi, trial, drug, nit in meta_x
        if xi > 0 and yi < 0
    ]


# --------------------------------------------------------------------------- #
# Stage analysis orchestration
# --------------------------------------------------------------------------- #

def _paradox_rows(x, y, meta):
    """Return (trial, drug, nit_metric, x, y) for rows with x > 0 and y < 0."""
    return [
        (meta.loc[i, "trial"], meta.loc[i, "drug"], meta.loc[i, "nit_metric"],
         float(xi), float(yi))
        for i, (xi, yi) in enumerate(zip(x, y))
        if xi > 0 and yi < 0
    ]


def _native_scale_paradox(df, up, down):
    """Scan for paradox rows one NIT metric at a time, on native-scale effects."""
    rows = []
    for metric in df["nit_metric"].dropna().unique():
        packed = stage_pairs(df, up, down, nit_filter=metric)
        if packed is None:
            continue
        x_m, _, y_m, _, _, meta_m = packed
        rows.extend(_paradox_rows(x_m, y_m, meta_m))
    return rows


def analyze_stage(df, stage_name, up, down, nit_filter=None, alpha=0.05):
    """
    Run the full trial-level surrogacy analysis for one stage.

    Parameters
    ----------
    df         : frame produced by load_data().
    stage_name : label stored in the result (e.g. "NIT->histology").
    up, down   : upstream / downstream layer keys ("nit", "histo", "hard").
    nit_filter : optional nit_metric value restricting the analysis to one NIT.
    alpha      : two-sided level for the R^2 CI and the STE prediction band.

    Returns
    -------
    dict with keys stage, nit_filter, n, r2, r2_ci, slope, ste,
    ste_achievable, grade, paradox (list of (trial, drug, nit_metric, x, y)
    rows with upstream benefit but downstream harm), sparse, fit and meta.
    With no usable rows everything is NaN / "n/a"; with fewer than 3 rows the
    paradox list is still filled but no regression is fitted.
    """
    nan = float("nan")
    result = {
        "stage": stage_name, "nit_filter": nit_filter, "n": 0,
        "r2": nan, "r2_ci": (nan, nan),
        "slope": nan, "ste": nan,
        "ste_achievable": False, "grade": "n/a", "paradox": [],
        "sparse": True, "fit": None, "meta": None,
    }

    packed = stage_pairs(df, up, down, nit_filter=nit_filter)
    if packed is None:
        return result
    x, _, y, _, w, meta = packed
    n = len(x)

    # In the pooled NIT-upstream case stage_pairs() z-scores x within each NIT,
    # so x > 0 only means "above that NIT's average" and no longer "benefit".
    # The paradox test needs the sign of the actual effect, hence it is run on
    # the native-scale per-NIT pairs instead.
    if up == "nit" and nit_filter is None:
        paradox = _native_scale_paradox(df, up, down)
    else:
        paradox = _paradox_rows(x, y, meta)

    result.update({"n": n, "paradox": paradox, "meta": meta})
    if n < 3:
        result["grade"] = "n/a(insufficient n)"
        return result

    fit = weighted_linfit(x, y, w)
    r2 = fit["r2"]
    ste, ste_ok = surrogate_threshold_effect(fit, alpha=alpha)
    result.update({
        "r2": r2,
        "r2_ci": r2_confidence_interval(r2, n, alpha=alpha),
        "slope": float(fit["beta"][1]),
        "ste": ste,
        "ste_achievable": ste_ok,
        # Any paradox row overrides the R^2-based grade.
        "grade": grade_from_r2(r2, paradox=bool(paradox)),
        "sparse": n < SPARSE_MIN_TRIALS,
        "fit": fit,
    })
    return result


# --------------------------------------------------------------------------- #
# Sample-size / trial-design math for validation hypotheses
# --------------------------------------------------------------------------- #

def required_events(target_hr, alpha=DEFAULT_ALPHA, power=DEFAULT_POWER):
    """
    Schoenfeld event count for a two-sided log-rank test with 1:1 allocation.

    Returns the number of events rounded up to an int, or float("inf") when
    the target hazard ratio is (numerically) 1.
    """
    log_hr = np.log(target_hr)
    if abs(log_hr) < 1e-9:
        return float("inf")
    z_alpha = stats.norm.ppf(1 - alpha / 2)
    z_beta = stats.norm.ppf(power)
    events = 4 * (z_alpha + z_beta) ** 2 / (log_hr ** 2)
    return int(np.ceil(events))


def required_sample_size(target_hr, baseline_event_rate=DEFAULT_BASELINE_EVENT_RATE,
                         alpha=DEFAULT_ALPHA, power=DEFAULT_POWER):
    """
    Total participants needed to observe the required number of events.

    ``baseline_event_rate`` is the control-arm cumulative event rate over
    follow-up. Returns an int (rounded up), or float("inf") when the event
    requirement is infinite or the average event proportion is not positive.
    """
    events = required_events(target_hr, alpha, power)
    if events == float("inf"):
        return float("inf")

    # Treated-arm cumulative incidence implied by applying the HR to the
    # control-arm survival: 1 - S_ctrl ** HR.
    treated_rate = 1 - (1 - baseline_event_rate) ** target_hr
    # Average event proportion across the two equally sized arms.
    mean_rate = (baseline_event_rate + treated_rate) / 2.0
    if mean_rate <= 0:
        return float("inf")
    return int(np.ceil(events / mean_rate))


# --------------------------------------------------------------------------- #
# Subcommand renderers
# --------------------------------------------------------------------------- #

def cmd_chain(df, alpha=0.05):
    """
    Print the pooled stage-by-stage table (n, R^2, CI, slope, grade), the STE
    per stage and the PTE mediation summary. Returns the list of per-stage
    result dicts from analyze_stage().
    """
    row_fmt = "{:<16} {:>4} {:>7} {:>16} {:>9} {:>12}"

    def num(value):
        # Three-decimal rendering with "n/a" for missing values.
        return "n/a" if np.isnan(value) else "%.3f" % value

    print("\n[ STAGE-BY-STAGE TRIAL-LEVEL SURROGACY ]  (pooled, all NITs)")
    print(_hr())
    print(row_fmt.format("stage", "n", "R2", "R2 95% CI", "slope", "grade"))
    print(_hr())

    results = []
    for stage_name, up, down in STAGES:
        res = analyze_stage(df, stage_name, up, down, alpha=alpha)
        results.append(res)
        interval = res["r2_ci"]
        if np.isnan(interval[0]):
            ci_text = "[ n/a ]"
        else:
            ci_text = "[%.2f, %.2f]" % interval
        print(row_fmt.format(res["stage"], res["n"], num(res["r2"]), ci_text,
                             num(res["slope"]), res["grade"]))

    print(_hr())
    print("STE (surrogate threshold effect = minimum upstream benefit whose")
    print("lower 95% prediction band still predicts a downstream benefit):")
    for res in results:
        if res["fit"] is None:
            line = ("  - %-16s : n/a (insufficient/sparse data, n=%d)"
                    % (res["stage"], res["n"]))
        elif res["ste_achievable"]:
            line = ("  - %-16s : STE = %.3f (upstream benefit units)"
                    % (res["stage"], res["ste"]))
        else:
            line = ("  - %-16s : NOT achievable (band never excludes null) -> "
                    "surrogate weak/unusable" % res["stage"])
        print(line)

    print(_hr())
    pte = proportion_treatment_effect(df)
    print("PTE mediation (does histology mediate the NIT->hard effect?):")
    print("  trials with NIT+histology+hard simultaneously : %d" % pte["n_full_trials"])
    if not np.isnan(pte["pte_raw"]):
        detail_lines = (
            ("  PTE (raw)     : %.3f", "pte_raw"),
            ("  PTE (clamped) : %.3f", "pte_clamped"),
            ("  total NIT->hard slope  b_total = %.3f", "b_total"),
            ("  direct (NIT|histo)     b_direct= %.3f", "b_direct"),
            ("  plausible (0<=PTE<=1)? : %s", "plausible"),
        )
        for template, key in detail_lines:
            print(template % pte[key])
    if pte["note"]:
        wrapped = textwrap.fill(pte["note"], 66, subsequent_indent="        ")
        print("  note: " + wrapped)
    return results


def cmd_compare_nit(df, alpha=0.05, top=None):
    """
    Print the per-NIT ranking on the NIT->histology stage (optionally limited
    to the first ``top`` rows) followed by the per-NIT NIT->hard results.
    Returns the ranked NIT->histology result dicts.
    """
    print("\n[ PER-NIT SURROGACY RANKING ]  (NIT->histology stage)")
    print(_hr())

    histo_rows = [
        analyze_stage(df, "NIT->histology", "nit", "histo",
                      nit_filter=nit, alpha=alpha)
        for nit in NIT_METRICS
    ]
    # also rank on NIT->hard where data exist
    hard_rows = [
        analyze_stage(df, "NIT->hard", "nit", "hard",
                      nit_filter=nit, alpha=alpha)
        for nit in NIT_METRICS
    ]

    def rank_key(res):
        # Highest R^2 first; missing R^2 sorts last; ties broken by NIT name.
        score = -1 if np.isnan(res["r2"]) else res["r2"]
        return (-score, res["nit_filter"])

    def r2_text(res):
        return "n/a" if np.isnan(res["r2"]) else "%.3f" % res["r2"]

    ranked = sorted(histo_rows, key=rank_key)
    if top:
        ranked = ranked[:top]

    table_fmt = "{:<10} {:>4} {:>7} {:>16} {:>12}"
    print("Ranked by NIT->histology R^2_trial:")
    print(table_fmt.format("NIT", "n", "R2", "R2 95% CI", "grade"))
    print(_hr())
    for res in ranked:
        interval = res["r2_ci"]
        if np.isnan(interval[0]):
            ci_text = "[ n/a ]"
        else:
            ci_text = "[%.2f, %.2f]" % interval
        print(table_fmt.format(res["nit_filter"], res["n"], r2_text(res),
                               ci_text, res["grade"]))

    print(_hr())
    print("NIT->hard stage (sparse: only cirrhosis-stage outcome trials):")
    for res in sorted(hard_rows, key=rank_key):
        suffix = "  <- too few trials to estimate" if res["n"] < 3 else ""
        print("  %-10s n=%-2d R2=%s%s"
              % (res["nit_filter"], res["n"], r2_text(res), suffix))
    return ranked


def cmd_paradox(df):
    """
    Print, for every chain stage, the rows flagged as surrogate paradoxes
    (upstream benefit with downstream harm), or a single line saying none
    were found, followed by a short interpretation note.
    """
    print("\n[ SURROGATE-PARADOX SCAN ]")
    print(_hr())

    stages_with_hits = 0
    for stage_name, up, down in STAGES:
        flagged = analyze_stage(df, stage_name, up, down)["paradox"]
        if not flagged:
            continue
        stages_with_hits += 1
        print("Stage %s -- paradox rows (upstream benefit, downstream harm):"
              % stage_name)
        for trial, drug, nit, xi, yi in flagged:
            print("  - %-22s %-12s NIT=%-9s  up=+%.3f  down=%.3f"
                  % (trial, drug, nit, xi, yi))

    if stages_with_hits == 0:
        print("No surrogate-paradox rows detected in current dataset.")
    print("\nInterpretation: a paradox (surrogate improves, hard outcome worsens)")
    print("automatically downgrades that stage's surrogacy grade to INVALID.")


def mine_gaps(df, alpha=0.05):
    """
    Collect unvalidated stages as a list of gap dicts.

    Each dict has keys scope, by, n, reasons (list of strings) and stage.
    Pooled per-stage gaps come first (sparse data, weak R^2, wide R^2 CI,
    paradox present), followed by one histology->hard gap per drug class that
    has fewer than two trials with a hard-outcome effect.
    """
    found = []

    # 1) Per-stage pooled weakness / sparsity
    for stage_name, up, down in STAGES:
        res = analyze_stage(df, stage_name, up, down, alpha=alpha)
        n_rows, r2 = res["n"], res["r2"]
        ci_lo, ci_hi = res["r2_ci"]

        why = []
        if n_rows < SPARSE_MIN_TRIALS:
            why.append("data-sparse (n=%d < %d)" % (n_rows, SPARSE_MIN_TRIALS))
        if not np.isnan(r2) and r2 < GRADE_MODERATE:
            why.append("weak R2_trial=%.2f" % r2)
        if not np.isnan(ci_lo) and (ci_hi - ci_lo) > 0.5:
            why.append("wide R2 CI [%.2f,%.2f]" % (ci_lo, ci_hi))
        if res["paradox"]:
            why.append("surrogate paradox present")

        if why:
            found.append({"scope": stage_name, "by": "(pooled)",
                          "n": n_rows, "reasons": why, "stage": stage_name})

    # 2) Per drug class on the histology->hard structural gap
    for cls in sorted(df["drug_class"].dropna().unique()):
        in_class = df[df["drug_class"] == cls]
        n_hard = in_class.dropna(subset=["hard_benefit"])["trial"].nunique()
        if n_hard >= 2:
            continue
        found.append({
            "scope": "histology->hard",
            "by": "drug_class=%s" % cls,
            "n": n_hard,
            "reasons": ["no/too-few hard-outcome trials for this class "
                        "(n=%d) -> NIT/histology surrogacy UNVALIDATED for "
                        "hard hepatic outcomes" % n_hard],
            "stage": "histology->hard",
        })

    return found


def cmd_gaps(df, alpha=0.05, top=None):
    """
    Print the mined gaps (optionally only the first ``top``) as a numbered
    list with their reasons, and return the printed list of gap dicts.
    """
    print("\n[ UNVALIDATED-STAGE MINING ]")
    print(_hr())

    gaps = mine_gaps(df, alpha=alpha)
    if top:
        gaps = gaps[:top]

    if gaps:
        for rank, gap in enumerate(gaps, start=1):
            print("%2d. stage=%-16s scope=%-22s n=%d" %
                  (rank, gap["scope"], gap["by"], gap["n"]))
            for reason in gap["reasons"]:
                print("      - " + reason)
    else:
        print("No unvalidated stages flagged.")
    return gaps


def cmd_hypotheses(df, alpha=0.05, top=None):
    """
    Turn mined gaps into printed validation-study hypotheses.

    Every histology->hard gap becomes an outcome-trial hypothesis sized with
    the module default HR / event-rate assumptions; every NIT whose
    NIT->histology R^2 is below GRADE_STRONG, or that has fewer than 3 paired
    rows, gets a NIT-qualification hypothesis. Returns the list of hypothesis
    dicts that were printed (first ``top`` only when ``top`` is set).
    """
    print("\n[ VALIDATION-STUDY HYPOTHESES ]  (auto-generated from mined gaps)")
    print(_hr())
    gaps = mine_gaps(df, alpha=alpha)
    events_needed = required_events(DEFAULT_TARGET_HR)
    participants_needed = required_sample_size(DEFAULT_TARGET_HR)

    hyps = []

    # Structural histology->hard gaps (pooled and per drug class).
    for gap in gaps:
        if gap["stage"] != "histology->hard":
            continue
        if gap["by"].startswith("drug_class="):
            drug_class = gap["by"].split("=", 1)[1]
            question = ("Is a histologic response (MASH resolution / fibrosis "
                        "improvement) a VALID trial-level surrogate for hard "
                        "hepatic outcomes (decompensation / liver-related death) "
                        "in %s agents?" % drug_class)
        else:
            question = ("Is the histology->hard surrogacy relationship validated "
                        "across the MASH drug landscape, or is it driven only by "
                        "the few cirrhosis-stage outcome trials?")
        hyps.append({
            "question": question,
            "design": ("Event-driven RCT or pooled IPD meta-analysis with "
                       "baseline biopsy/NIT + adjudicated hard endpoints."),
            "events": events_needed, "n": participants_needed,
            "followup": DEFAULT_FOLLOWUP_YEARS,
            "rationale": "; ".join(gap["reasons"]),
        })

    # NIT-specific hypotheses where NIT->histology is only moderate/weak
    for nit in NIT_METRICS:
        res = analyze_stage(df, "NIT->histology", "nit", "histo",
                            nit_filter=nit, alpha=alpha)
        n_rows, r2 = res["n"], res["r2"]
        if n_rows < 3:
            hyps.append({
                "question": ("Is %s even estimable as a surrogate? Too few "
                             "trials report paired %s + histology." % (nit, nit)),
                "design": "Prospective paired NIT+biopsy registry across programs.",
                "events": "n/a", "n": "≥10 trials w/ paired data",
                "followup": 1.0,
                "rationale": "only n=%d paired rows" % n_rows,
            })
        elif not np.isnan(r2) and r2 < GRADE_STRONG:
            hyps.append({
                "question": ("Can %s change validly replace biopsy as the "
                             "intermediate surrogate (NIT->histology) — its "
                             "current trial-level R2 is only %.2f?"
                             % (nit, r2)),
                "design": ("Co-primary biopsy + %s sub-study nested in MASH "
                           "RCTs; LITMUS/NIMBLE-style qualification." % nit),
                "events": "n/a (histology endpoint)",
                "n": "≈250–400 paired biopsies (qualification-grade)",
                "followup": 1.0,
                "rationale": "NIT->histology R2_trial=%.2f (sub-strong)" % r2,
            })

    if top:
        hyps = hyps[:top]

    for number, hyp in enumerate(hyps, start=1):
        question_text = textwrap.fill(hyp["question"], 68,
                                      subsequent_indent="    ")
        design_text = textwrap.fill(hyp["design"], 60,
                                    subsequent_indent="               ")
        # Numeric requirements get a unit label; textual ones print verbatim.
        if isinstance(hyp["events"], str):
            events_text = hyp["events"]
        else:
            events_text = "%d adjudicated events" % hyp["events"]
        if isinstance(hyp["n"], str):
            size_text = hyp["n"]
        else:
            size_text = "≈%d participants" % hyp["n"]

        print("H%-2d %s" % (number, question_text))
        print("    design   : " + design_text)
        print("    required : %s | %s | follow-up ≈ %.1f yr"
              % (events_text, size_text, hyp["followup"]))
        print("    rationale: " + hyp["rationale"])
        print()

    print(_hr())
    print("Trial-size assumptions: target HR=%.2f, baseline %d%% cumulative "
          "event rate," % (DEFAULT_TARGET_HR, int(DEFAULT_BASELINE_EVENT_RATE * 100)))
    print("alpha=%.2f two-sided, power=%d%% (Schoenfeld). Adjust in code/CLI."
          % (DEFAULT_ALPHA, int(DEFAULT_POWER * 100)))
    return hyps


def cmd_summary(df, alpha=0.05):
    """
    Print the default overview shown when no subcommand flag is given: one
    line per stage, the stage with the lowest estimable R^2, the number of
    trials covering the full chain and dataset counts. Returns the per-stage
    result dicts.
    """
    print("\n[ SUMMARY ]  3-stage MASLD/MASH surrogacy snapshot")
    print(_hr())

    stage_results = []
    for stage_name, up, down in STAGES:
        res = analyze_stage(df, stage_name, up, down, alpha=alpha)
        stage_results.append(res)
        r2_text = "n/a" if np.isnan(res["r2"]) else "%.3f" % res["r2"]
        print("  %-16s n=%-2d  R2=%-6s  grade=%s"
              % (res["stage"], res["n"], r2_text, res["grade"]))

    print(_hr())
    # Weakest stage = lowest R^2 among stages where R^2 could be estimated
    # (the first such stage wins a tie).
    estimable = [res for res in stage_results if not np.isnan(res["r2"])]
    if estimable:
        weakest = min(estimable, key=lambda res: res["r2"])
        print("  Weakest / most-uncertain stage : %s (R2=%.3f, grade=%s)"
              % (weakest["stage"], weakest["r2"], weakest["grade"]))

    # structural gap
    full_chain_trials = proportion_treatment_effect(df)["n_full_trials"]
    print("  Trials with full NIT+histology+hard chain : %d  "
          "(histology->hard is the structural data gap)" % full_chain_trials)
    print("  Dataset: %d trials, %d drugs, %d NIT metrics."
          % (df["trial"].nunique(), df["drug"].nunique(),
             df["nit_metric"].nunique()))
    print(_hr())
    print("Run with --chain | --compare-nit | --paradox | --gaps | "
          "--hypotheses for detail.")
    print("Use --help for all options.")
    return stage_results


# --------------------------------------------------------------------------- #
# CLI
# --------------------------------------------------------------------------- #

def build_parser():
    """Create the argparse parser holding every CLI option of the tool."""
    description = textwrap.dedent("""\
            NITSurrogate-Kor — MASLD/MASH 3-stage trial-level surrogacy tool.

            Chain analyzed:  NIT (surrogate) -> Histology (intermediate)
                             -> Hard hepatic outcome.

            Computes per-stage R2_trial (+95% CI), STE, PTE mediation, surrogate
            grade, paradox flags, NIT ranking, and mines unvalidated stages to
            generate validation-study hypotheses with required trial size /
            follow-up.

            ⚠️  연구용·참고용 (research/reference use only) — NOT for clinical
                decision-making. Demo data are illustrative / synthetic.
            """)
    example_commands = (
        "python3 main.py",
        "python3 main.py --chain",
        "python3 main.py --compare-nit",
        "python3 main.py --gaps --hypotheses",
        "python3 main.py --data mytrials.csv --chain --top 10",
    )
    epilog = "Examples:\n" + "".join("  %s\n" % cmd for cmd in example_commands)

    parser = argparse.ArgumentParser(
        prog="main.py",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        description=description,
        epilog=epilog,
    )
    parser.add_argument("--data", default=DEFAULT_DATA,
                        help="path to surrogacy CSV (default: bundled demo)")

    # Boolean switches selecting which report sections to print.
    report_switches = (
        ("--chain", "stage-by-stage R2_trial / STE / PTE / grade table"),
        ("--compare-nit", "rank NITs (FIB-4 / VCTE / MRI-PDFF / ELF / MRE)"),
        ("--paradox", "list surrogate-paradox rows"),
        ("--gaps", "mine unvalidated / data-sparse stages"),
        ("--hypotheses",
         "emit validation-study hypotheses + required N/follow-up"),
    )
    for flag, help_text in report_switches:
        parser.add_argument(flag, action="store_true", help=help_text)

    parser.add_argument("--top", type=int, default=None,
                        help="limit number of rows/items in lists")
    parser.add_argument("--alpha", type=float, default=0.05,
                        help="two-sided alpha for CIs / bands (default 0.05)")
    parser.add_argument("--grade-strong", type=float, default=GRADE_STRONG,
                        help="R2_trial threshold for 'strong' (default 0.70)")
    parser.add_argument("--grade-moderate", type=float, default=GRADE_MODERATE,
                        help="R2_trial threshold for 'moderate' (default 0.50)")
    parser.add_argument("--no-banner", action="store_true",
                        help="suppress the header banner")
    return parser


def main(argv=None):
    """
    CLI entry point.

    Parses ``argv`` (None = sys.argv[1:]), applies the grade thresholds,
    loads the data and runs every requested report in a fixed order; with no
    report flag the summary is printed instead. Returns the process exit
    code: 0 on success, 2 when the data cannot be loaded.
    """
    global GRADE_STRONG, GRADE_MODERATE
    args = build_parser().parse_args(argv)
    GRADE_STRONG, GRADE_MODERATE = args.grade_strong, args.grade_moderate

    if not args.no_banner:
        banner()

    try:
        df = load_data(args.data)
    except FileNotFoundError:
        print("ERROR: data file not found: %s" % args.data, file=sys.stderr)
        return 2
    except Exception as exc:  # pragma: no cover
        print("ERROR loading data: %s" % exc, file=sys.stderr)
        return 2

    print("\nData: %s  (%d analyzable rows)" % (args.data, len(df)))

    # (flag, report) pairs in the order the reports are printed.
    reports = (
        (args.chain, lambda: cmd_chain(df, alpha=args.alpha)),
        (args.compare_nit,
         lambda: cmd_compare_nit(df, alpha=args.alpha, top=args.top)),
        (args.paradox, lambda: cmd_paradox(df)),
        (args.gaps, lambda: cmd_gaps(df, alpha=args.alpha, top=args.top)),
        (args.hypotheses,
         lambda: cmd_hypotheses(df, alpha=args.alpha, top=args.top)),
    )
    selected = [run for enabled, run in reports if enabled]
    for run in selected:
        run()

    if not selected:
        cmd_summary(df, alpha=args.alpha)

    print("\n" + DISCLAIMER)
    return 0


if __name__ == "__main__":
    sys.exit(main())
