"""
surrogacy.py — core trial-level surrogacy engine for WtLossSurrogate-Kor.

Pure numpy/scipy implementation (statsmodels is NOT available) of a
Buyse–Molenberghs / Daniels–Hughes style *trial-level* surrogacy analysis:

  * R²_trial  : inverse-variance weighted least squares (WLS) regression of the
                hard-outcome treatment effect (y, e.g. log-HR) on the surrogate
                treatment effect (x, % weight loss) across trials within a class.
  * STE       : surrogate threshold effect — the smallest |%weight-loss| at which
                the upper 95% prediction band of the (log-HR) crosses the null
                (i.e. the predicted benefit becomes statistically credible).
  * dose–response surrogacy : linear vs quadratic WLS fit + plateau detection.
  * PTE       : weight-mediated fraction via a simple mediation framing
                PTE = 1 - beta_adjusted / beta_unadjusted.
  * surrogate-paradox flag  : surrogate improves but hard outcome worsens.
  * surrogacy grade         : strong / moderate / weak / invalid.

⚠️  연구용·참고용 (research/reference use only) — not for clinical decision-making.
"""

from __future__ import annotations

import math
from dataclasses import dataclass, field
from typing import Optional

import numpy as np
from scipy import stats

# --------------------------------------------------------------------------- #
# Configuration
# --------------------------------------------------------------------------- #

# Surrogacy strength thresholds on R²_trial (configurable): a cell grades
# "strong" at or above R2_STRONG and "moderate" at or above R2_MODERATE.
R2_STRONG = 0.70
R2_MODERATE = 0.50

# Minimum trials to even attempt a class×outcome regression.
MIN_TRIALS_REGRESSION = 3

# Minimum trials below which a surrogate–outcome–class cell is considered
# "under-validated" regardless of R².
MIN_TRIALS_VALIDATED = 4

# Two-sided 95% critical value of the standard normal distribution.
Z95 = stats.norm.ppf(0.975)  # ~1.95996

# Research-use disclaimer text, kept as a module-level constant.
DISCLAIMER = (
    "⚠️  연구용·참고용 (research/reference use only) — not for clinical decision-making. "
    "Demo effect sizes are illustrative/synthetic, NOT official trial readouts."
)


# --------------------------------------------------------------------------- #
# Result containers
# --------------------------------------------------------------------------- #

@dataclass
class SurrogacyResult:
    """Trial-level surrogacy metrics for one (drug class, hard outcome) cell."""
    drug_class: str
    hard_outcome: str
    n_trials: int                   # number of trials in the cell
    r2_trial: float                 # weighted R² of the WLS fit; NaN when no fit was made
    r2_ci_low: float                # lower limit of the approximate 95% CI for R²
    r2_ci_high: float               # upper limit of the approximate 95% CI for R²
    slope: float                    # WLS slope of log-HR on % weight loss
    slope_se: float                 # standard error of the slope
    intercept: float                # WLS intercept (predicted log-HR at zero weight change)
    ste: Optional[float]            # % weight loss at which benefit becomes credible
    pte: Optional[float]            # weight-mediated fraction (0..1, may be flagged)
    pte_flag: str                   # "ok" or a warning/insufficiency code from compute_pte
    paradox: bool                   # True when weight improves but the hard outcome worsens
    grade: str                      # "strong" / "moderate" / "weak" / "invalid" / "insufficient"
    notes: str = ""                 # free-text remark (e.g. why no regression was run)
    # raw arrays kept for plotting / dose-response
    x: Optional[np.ndarray] = field(default=None, repr=False)   # % weight loss per trial
    y: Optional[np.ndarray] = field(default=None, repr=False)   # log-HR per trial
    w: Optional[np.ndarray] = field(default=None, repr=False)   # inverse-variance weights


@dataclass
class DoseResponseResult:
    """Dose-response summary for one (drug class, hard outcome) cell."""
    drug_class: str
    hard_outcome: str
    bins: list          # list of (label, mean_wl, mean_loghr, n)
    linear_slope: float  # slope of the linear WLS fit of log-HR on % weight loss
    quad_coef: float    # coefficient on x^2 in quadratic WLS (curvature)
    nonlinearity_p: float  # two-sided t-test p-value for quad_coef; NaN if not computed
    plateau: bool       # True when the last bin-to-bin increment in benefit flattens
    verdict: str        # human-readable conclusion derived from the fields above


@dataclass
class Hypothesis:
    """A proposed validation study for one surrogate–outcome–class cell."""
    drug_class: str
    surrogate: str                       # name of the surrogate endpoint
    hard_outcome: str
    reason: str                          # why the cell was flagged as a gap
    n_trials: int                        # trials currently available in the cell
    r2_trial: Optional[float]            # R²_trial if a regression was possible, else None
    suggested_n_per_arm: Optional[int]   # per-arm size from suggest_sample_size, if computable
    suggested_trials: int                # number of additional trials proposed
    priority: float                      # ranking score; higher is listed first
    statement: str                       # the hypothesis phrased as a question


# --------------------------------------------------------------------------- #
# Weighted regression primitives
# --------------------------------------------------------------------------- #

def wls(x: np.ndarray, y: np.ndarray, w: np.ndarray):
    """Weighted least squares for y = a + b*x.

    Args:
        x: predictor values, one per observation (1-D).
        y: response values, same length as ``x``.
        w: regression weights, same length as ``x`` (callers in this module
            pass inverse sampling variances).

    Returns:
        dict with keys ``slope``, ``intercept``, ``slope_se``,
        ``intercept_se``, ``r2`` (weighted, clipped to [0, 1]), ``sigma2``
        (weighted residual variance), ``phi`` (dispersion factor, equal to
        ``sigma2``), ``samp_var_new`` (``1 / max(w)``), ``dof``,
        ``XtWX_inv`` (needed for prediction bands), ``n``, ``x`` and
        ``resid``.

    Raises:
        ValueError: if ``x``, ``y`` and ``w`` are not 1-D arrays of equal
            length.
        numpy.linalg.LinAlgError: if a line cannot be identified — fewer than
            two observations, all-zero weights, or (numerically) no spread in
            ``x``. Previously such designs either raised or silently returned
            the inverse of a nearly singular matrix, depending on rounding.
    """
    x = np.asarray(x, dtype=float)
    y = np.asarray(y, dtype=float)
    w = np.asarray(w, dtype=float)
    if x.ndim != 1 or y.shape != x.shape or w.shape != x.shape:
        raise ValueError("x, y and w must be 1-D arrays of equal length")
    n = x.size

    if n < 2:
        raise np.linalg.LinAlgError("wls needs at least 2 observations to fit a line")
    w_sum = float(np.sum(w))
    if w_sum == 0.0:
        raise np.linalg.LinAlgError("wls cannot fit a line: all weights are zero")
    # Weighted spread of x around its weighted mean, judged relative to the
    # overall magnitude of x. NaN inputs make the comparison False and fall
    # through, so they propagate as NaN exactly as before.
    rel_tol = 1e-12
    x_spread = float(np.sum(w * (x - np.sum(w * x) / w_sum) ** 2))
    x_scale = float(np.sum(w * x ** 2))
    if x_spread <= rel_tol * x_scale:
        raise np.linalg.LinAlgError(
            "wls cannot fit a line: x has no variation across observations"
        )

    X = np.column_stack([np.ones(n), x])      # (n,2)
    # Row-scaling by w equals X.T @ diag(w) without building an n×n matrix.
    XtW = X.T * w
    XtWX_inv = np.linalg.inv(XtW @ X)
    beta = XtWX_inv @ (XtW @ y)               # [intercept, slope]
    intercept, slope = beta[0], beta[1]

    resid = y - X @ beta
    dof = max(n - 2, 1)
    # Weighted residual variance (dispersion); >1 => over-dispersion beyond
    # the reported SEs, which inflates prediction bands appropriately.
    wrss = float(np.sum(w * resid ** 2))
    sigma2 = wrss / dof
    cov_beta = sigma2 * XtWX_inv
    intercept_se = math.sqrt(max(cov_beta[0, 0], 0.0))
    slope_se = math.sqrt(max(cov_beta[1, 1], 0.0))

    # Weighted R² around the weighted mean of y.
    ybar_w = np.sum(w * y) / w_sum
    ss_tot = float(np.sum(w * (y - ybar_w) ** 2))
    r2 = 1.0 - wrss / ss_tot if ss_tot > 0 else 0.0
    r2 = float(np.clip(r2, 0.0, 1.0))

    # Representative new-trial sampling variance: the smallest observed
    # sampling variance, i.e. the most precise trial in the data.
    samp_var_new = float(1.0 / np.max(w))

    return {
        "slope": slope,
        "intercept": intercept,
        "slope_se": slope_se,
        "intercept_se": intercept_se,
        "r2": r2,
        "sigma2": sigma2,
        # Dispersion factor: ≈1 if the weights fully explain the residuals,
        # >1 signals heterogeneity beyond sampling error. Same quantity as sigma2.
        "phi": sigma2,
        "samp_var_new": samp_var_new,
        "dof": dof,
        "XtWX_inv": XtWX_inv,
        "n": n,
        "x": x,
        "resid": resid,
    }


def r2_ci_fisher(r2: float, n: int, z_crit: Optional[float] = None):
    """Approximate CI for R² via the Fisher z-transform on r = sqrt(R²).

    Only the magnitude of r matters for R², so the transform is applied to
    the non-negative root.

    Args:
        r2: observed R² (values outside [0, 1] are clamped before the
            transform).
        n: number of observations behind ``r2``.
        z_crit: standard-normal critical value; defaults to the module-level
            ``Z95`` (two-sided 95%).

    Returns:
        ``(low, high)`` with both limits in [0, 1]. For ``n <= 3`` the Fisher
        standard error is undefined and a fixed heuristic interval
        ``(r2 - 0.4, r2 + 0.2)`` clipped to [0, 1] is returned instead.
    """
    if n <= 3:
        return (max(0.0, r2 - 0.4), min(1.0, r2 + 0.2))
    if math.isnan(r2):
        nan = float("nan")
        return (nan, nan)
    if z_crit is None:
        z_crit = Z95

    # Cap r below 1 so atanh stays finite.
    r = min(math.sqrt(max(r2, 0.0)), 0.999)
    z = math.atanh(r)
    half_width = z_crit / math.sqrt(n - 3)
    lo_z = z - half_width
    hi_z = z + half_width

    # If the interval for r reaches zero or below, the smallest R² inside it
    # is 0. Squaring the negative limit would yield a "lower" bound that can
    # exceed the point estimate.
    lo_r2 = math.tanh(lo_z) ** 2 if lo_z > 0.0 else 0.0
    # z >= 0, so hi_z has the largest magnitude and gives the largest R².
    hi_r2 = math.tanh(hi_z) ** 2
    return (float(np.clip(lo_r2, 0.0, 1.0)), float(np.clip(hi_r2, 0.0, 1.0)))


def predict_with_band(fit: dict, x_new: float):
    """Return the fitted value at ``x_new`` and its 95% prediction half-width.

    The variance assigned to the observed effect of a NEW trial at ``x_new``
    is the sum of two parts:

        var_pred = phi * x0' (X'WX)^-1 x0   +   samp_var_new

    The first part is the dispersion-scaled uncertainty of the regression
    line at ``x_new``; the second is the sampling variance attributed to the
    new trial itself (``fit["samp_var_new"]``). The half-width uses the
    Student-t 97.5% quantile with ``fit["dof"]`` degrees of freedom.

    Args:
        fit: result dict produced by ``wls``.
        x_new: predictor value at which to predict.

    Returns:
        ``(yhat, half)`` — point prediction and band half-width.
    """
    design_row = np.array([1.0, x_new])
    line_variance = float(design_row @ fit["XtWX_inv"] @ design_row) * fit["phi"]
    total_variance = line_variance + fit["samp_var_new"]
    t_crit = stats.t.ppf(0.975, fit["dof"])
    prediction = float(fit["intercept"] + fit["slope"] * x_new)
    # Guard against a tiny negative variance from rounding before the root.
    return prediction, t_crit * math.sqrt(max(total_variance, 0.0))


def compute_ste(fit: dict, x_grid: np.ndarray):
    """Surrogate Threshold Effect.

    The hard outcome lives on the log-HR scale (NEGATIVE = benefit,
    null = 0) and % weight loss is encoded as a negative number. The grid is
    walked from the largest value (least loss) down to the smallest (most
    loss); the first grid point whose UPPER 95% prediction band lies below 0
    is returned — the most modest weight loss at which the predicted benefit
    is still credible at the pessimistic edge of the band.

    Args:
        fit: result dict produced by ``wls``.
        x_grid: candidate % weight-loss values, in any order.

    Returns:
        The threshold as a float, or ``None`` if no grid point qualifies.
    """
    # np.sort is ascending; reversing it visits the least loss first.
    least_to_most_loss = np.sort(x_grid)[::-1]
    for candidate in least_to_most_loss:
        predicted, half_width = predict_with_band(fit, candidate)
        if predicted + half_width < 0:   # whole band below the null
            return float(candidate)
    return None


# --------------------------------------------------------------------------- #
# Surrogacy per class × outcome
# --------------------------------------------------------------------------- #

def _grade(r2: float, paradox: bool, n_trials: int) -> str:
    """Map R²_trial, the paradox flag and the trial count to a grade label.

    A paradox always yields "invalid"; otherwise too few trials yields
    "insufficient"; otherwise the label follows the R² thresholds.
    """
    if paradox:
        return "invalid"
    if n_trials < MIN_TRIALS_REGRESSION:
        return "insufficient"
    # Thresholds are checked from the most to the least demanding.
    for label, threshold in (("strong", R2_STRONG), ("moderate", R2_MODERATE)):
        if r2 >= threshold:
            return label
    return "weak"


def compute_pte(x: np.ndarray, y: np.ndarray, w: np.ndarray):
    """Weight-mediated fraction (PTE) via a simple mediation framing.

    The total (unadjusted) effect is the weighted mean of ``y``. The direct
    (adjusted) effect is the WLS intercept, i.e. the predicted hard-outcome
    effect at zero surrogate change — the part of the benefit not explained
    by movement in the surrogate.

        PTE = 1 - direct / total

    Returns:
        ``(pte, flag)``. ``pte`` is clipped to [0, 1]; ``flag`` is "ok" or
        names the implausible raw value (<0 or >1) that was clipped. With too
        few trials or a ~zero total effect, ``pte`` is ``None`` and ``flag``
        gives the reason.
    """
    if len(x) < MIN_TRIALS_REGRESSION:
        return None, "insufficient_trials"

    total_effect = float(np.sum(w * y) / np.sum(w))
    direct_effect = wls(x, y, w)["intercept"]
    if abs(total_effect) < 1e-9:
        return None, "null_total_effect"

    raw_pte = 1.0 - direct_effect / total_effect
    if raw_pte < 0:
        flag = "implausible_negative(direct>total)"
    elif raw_pte > 1:
        flag = "implausible_over1(clamped)"
    else:
        flag = "ok"
    return float(np.clip(raw_pte, 0.0, 1.0)), flag


def surrogacy_for(df, drug_class: str, hard_outcome: str) -> Optional[SurrogacyResult]:
    """Compute surrogacy metrics for one (class, outcome) cell.

    Args:
        df: pandas DataFrame with columns ``drug_class``, ``hard_outcome``,
            ``pct_weight_loss``, ``loghr`` and ``loghr_se``.
        drug_class: value of ``drug_class`` selecting the cell.
        hard_outcome: value of ``hard_outcome`` selecting the cell.

    Returns:
        ``None`` when the cell has no rows. Otherwise a ``SurrogacyResult``;
        when the regression cannot be run (too few trials, or no usable
        variation in weight loss across trials) the numeric fields are NaN /
        ``None`` and ``notes`` explains why, so one bad cell never aborts a
        scan over all cells.
    """
    sub = df[(df["drug_class"] == drug_class) & (df["hard_outcome"] == hard_outcome)]
    n = len(sub)
    if n == 0:
        return None

    x = sub["pct_weight_loss"].to_numpy(dtype=float)
    y = sub["loghr"].to_numpy(dtype=float)
    yse = sub["loghr_se"].to_numpy(dtype=float)
    w = 1.0 / np.clip(yse, 1e-6, None) ** 2          # inverse-variance weights

    # Paradox: surrogate improves (weight loss, x < -2) yet the weighted-mean
    # log-HR over those trials sits clearly above the null.
    meaningful = x < -2.0
    paradox = False
    if meaningful.sum() >= 1:
        wm = np.sum(w[meaningful] * y[meaningful]) / np.sum(w[meaningful])
        paradox = wm > 0.02

    def _stub(note: str, pte_flag: str, grade: str) -> SurrogacyResult:
        # Placeholder result for cells where no regression is available; the
        # raw arrays are kept so the gap miner and plots can still use them.
        nan = float("nan")
        return SurrogacyResult(
            drug_class=drug_class, hard_outcome=hard_outcome, n_trials=n,
            r2_trial=nan, r2_ci_low=nan, r2_ci_high=nan,
            slope=nan, slope_se=nan, intercept=nan,
            ste=None, pte=None, pte_flag=pte_flag,
            paradox=paradox, grade=grade,
            notes=note, x=x, y=y, w=w,
        )

    if n < MIN_TRIALS_REGRESSION:
        # cannot regress; report a stub so the gap miner can pick it up
        return _stub("too few trials to regress", "insufficient_trials",
                     _grade(0.0, paradox, n))

    not_estimable_grade = "invalid" if paradox else "insufficient"
    if x.max() == x.min():
        # Identical surrogate effect in every trial: the slope is undefined
        # and the normal equations are singular.
        return _stub("no variation in weight loss across trials; slope not estimable",
                     "regression_not_estimable", not_estimable_grade)
    try:
        fit = wls(x, y, w)
    except np.linalg.LinAlgError:
        return _stub("singular regression design; slope not estimable",
                     "regression_not_estimable", not_estimable_grade)

    r2 = fit["r2"]
    r2_lo, r2_hi = r2_ci_fisher(r2, n)

    # STE over a grid from 0 to the most extreme loss observed (a bit beyond)
    x_min = min(x.min(), -1.0)
    grid = np.linspace(0.0, x_min * 1.1, 240)
    ste = compute_ste(fit, grid)

    pte, pte_flag = compute_pte(x, y, w)

    grade = _grade(r2, paradox, n)

    return SurrogacyResult(
        drug_class=drug_class, hard_outcome=hard_outcome, n_trials=n,
        r2_trial=r2, r2_ci_low=r2_lo, r2_ci_high=r2_hi,
        slope=fit["slope"], slope_se=fit["slope_se"], intercept=fit["intercept"],
        ste=ste, pte=pte, pte_flag=pte_flag, paradox=paradox, grade=grade,
        x=x, y=y, w=w,
    )


def all_surrogacy(df) -> list:
    """Return surrogacy results for every populated (class, outcome) cell.

    Cells are visited in sorted class order, then sorted outcome order;
    combinations with no rows in ``df`` are left out.
    """
    cell_results = (
        surrogacy_for(df, cls, out)
        for cls in sorted(df["drug_class"].unique())
        for out in sorted(df["hard_outcome"].unique())
    )
    return [res for res in cell_results if res is not None]


# --------------------------------------------------------------------------- #
# Dose–response surrogacy
# --------------------------------------------------------------------------- #

def dose_response(df, drug_class: str, hard_outcome: str,
                  edges=(0.0, -7.0, -14.0, -21.0, -100.0)) -> Optional[DoseResponseResult]:
    """Assess (non-)linearity of hard benefit across weight-loss magnitude bins.

    Args:
        df: pandas DataFrame with columns ``drug_class``, ``hard_outcome``,
            ``pct_weight_loss``, ``loghr`` and ``loghr_se``.
        drug_class: value of ``drug_class`` selecting the cell.
        hard_outcome: value of ``hard_outcome`` selecting the cell.
        edges: bin boundaries on % weight loss, ordered from least loss
            (closest to 0) to most loss; bin i covers ``edges[i+1] < x <= edges[i]``.

    Returns:
        ``None`` when the cell has fewer than ``MIN_TRIALS_REGRESSION``
        trials, otherwise a ``DoseResponseResult``. ``linear_slope``,
        ``quad_coef`` and ``nonlinearity_p`` are NaN when the corresponding
        fit is not identifiable.
    """
    sub = df[(df["drug_class"] == drug_class) & (df["hard_outcome"] == hard_outcome)]
    if len(sub) < MIN_TRIALS_REGRESSION:
        return None

    x = sub["pct_weight_loss"].to_numpy(dtype=float)
    y = sub["loghr"].to_numpy(dtype=float)
    yse = sub["loghr_se"].to_numpy(dtype=float)
    w = 1.0 / np.clip(yse, 1e-6, None) ** 2

    # Inverse-variance weighted bin means, ordered from least to most loss.
    edges = list(edges)
    bins = []
    for hi, lo in zip(edges, edges[1:]):      # hi closer to 0, lo more negative
        mask = (x <= hi) & (x > lo)
        count = int(mask.sum())
        if count == 0:
            continue
        w_bin = w[mask]
        mean_wl = float(np.sum(w_bin * x[mask]) / np.sum(w_bin))
        mean_y = float(np.sum(w_bin * y[mask]) / np.sum(w_bin))
        bins.append((f"{lo:.0f}..{hi:.0f}%", mean_wl, mean_y, count))

    # Linear fit, guarded like the quadratic one: a singular design yields
    # NaN instead of aborting the whole analysis.
    try:
        linear_slope = wls(x, y, w)["slope"]
    except np.linalg.LinAlgError:
        linear_slope = float("nan")

    # Quadratic WLS: y = a + b x + c x^2 ; curvature = c.
    n = len(x)
    quad_coef = float("nan")
    nonlin_p = float("nan")
    # Needs >3 trials for a residual degree of freedom and at least three
    # distinct x values, otherwise curvature is not identifiable.
    if n >= 4 and np.unique(x).size >= 3:
        Xq = np.column_stack([np.ones(n), x, x ** 2])
        XtW = Xq.T * w                        # same as Xq.T @ diag(w)
        try:
            cov = np.linalg.inv(XtW @ Xq)
        except np.linalg.LinAlgError:
            cov = None
        if cov is not None:
            beta = cov @ (XtW @ y)
            resid = y - Xq @ beta
            dof = max(n - 3, 1)
            sigma2 = float(np.sum(w * resid ** 2)) / dof
            se_c = math.sqrt(max(sigma2 * cov[2, 2], 0.0))
            quad_coef = float(beta[2])
            if se_c > 0:
                tstat = quad_coef / se_c
                nonlin_p = float(2 * stats.t.sf(abs(tstat), dof))

    # Plateau detection: does incremental benefit shrink at the most-loss bin?
    plateau = False
    if len(bins) >= 3:
        # Benefit gained per unit of extra weight loss between consecutive
        # bins (positive when more loss brings a lower log-HR).
        incs = []
        for prev, cur in zip(bins, bins[1:]):
            d_wl = cur[1] - prev[1]           # more negative => loss increases
            if abs(d_wl) > 1e-6:
                incs.append((cur[2] - prev[2]) / d_wl)
        if len(incs) >= 2:
            early = float(np.mean(incs[:-1]))
            # A plateau presupposes an earlier benefit. Compare against the
            # threshold directly rather than through a ratio, because the
            # ratio test flips direction when the earlier increments are negative.
            plateau = early > 1e-9 and incs[-1] < 0.4 * early

    if not math.isnan(nonlin_p) and nonlin_p < 0.10 and abs(quad_coef) > 1e-4:
        verdict = "non-linear (curvature detected)"
    elif plateau:
        verdict = "possible plateau at high weight-loss"
    else:
        verdict = "approximately linear (dose-response consistent)"

    return DoseResponseResult(
        drug_class=drug_class, hard_outcome=hard_outcome, bins=bins,
        linear_slope=linear_slope, quad_coef=quad_coef,
        nonlinearity_p=nonlin_p, plateau=plateau, verdict=verdict,
    )


# --------------------------------------------------------------------------- #
# Sample-size helper for validation hypotheses
# --------------------------------------------------------------------------- #

def suggest_sample_size(loghr_target: float, event_rate: float = 0.06,
                        power: float = 0.80, alpha: float = 0.05):
    """Schoenfeld-style sample size for a time-to-event validation trial.

    With 1:1 allocation the required number of events is

        d = (z_{1-alpha/2} + z_{power})^2 / (0.25 * loghr_target^2)

    and the per-arm size is half of ``d / event_rate``, rounded up.
    ``event_rate`` is floored at 1e-3 before dividing.

    Returns:
        ``(events_required, n_per_arm)``, or ``(None, None)`` when
        ``|loghr_target| < 1e-3`` (effectively null, so un-powerable).
    """
    if abs(loghr_target) < 1e-3:
        return None, None

    z_total = stats.norm.ppf(1 - alpha / 2) + stats.norm.ppf(power)
    events = int(math.ceil((z_total ** 2) / (0.25 * loghr_target ** 2)))
    enrolled_total = events / max(event_rate, 1e-3)
    return events, int(math.ceil(enrolled_total / 2.0))


# --------------------------------------------------------------------------- #
# Gap mining → validation hypotheses
# --------------------------------------------------------------------------- #

def mine_gaps(df, surrogate_name: str = "%weight-loss") -> list:
    """Scan the class × outcome grid and emit validation-study hypotheses.

    A cell is flagged when it has no trials, shows the surrogate paradox, has
    fewer than ``MIN_TRIALS_VALIDATED`` trials, has an R²_trial that could
    not be estimated, or has weak/moderate R²_trial. Each flagged cell gets a
    ``Hypothesis`` with a suggested per-arm sample size.

    Args:
        df: pandas DataFrame with columns ``drug_class``, ``hard_outcome``,
            ``pct_weight_loss``, ``loghr`` and ``loghr_se``.
        surrogate_name: label of the surrogate used in the hypothesis text.

    Returns:
        List of ``Hypothesis`` objects sorted by descending priority.
    """
    hyps = []
    classes = sorted(df["drug_class"].unique())
    outcomes = sorted(df["hard_outcome"].unique())

    # observed grand-mean benefit magnitude per outcome (for sizing absent cells)
    out_loghr = {}
    for out in outcomes:
        s = df[df["hard_outcome"] == out]
        if len(s):
            w = 1.0 / np.clip(s["loghr_se"].to_numpy(float), 1e-6, None) ** 2
            out_loghr[out] = float(np.sum(w * s["loghr"].to_numpy(float)) / np.sum(w))

    for cls in classes:
        for out in outcomes:
            res = surrogacy_for(df, cls, out)
            n = 0 if res is None else res.n_trials
            r2 = None if (res is None or math.isnan(res.r2_trial)) else res.r2_trial

            reason = None
            priority = 0.0

            if res is None or n == 0:
                reason = "no trials for this surrogate–outcome–class pair"
                priority = 0.9
            elif res.paradox:
                reason = "SURROGATE PARADOX: weight improves but hard outcome trends worse"
                priority = 1.0
            elif n < MIN_TRIALS_VALIDATED:
                reason = f"under-validated: only {n} trial(s) (<{MIN_TRIALS_VALIDATED})"
                priority = 0.8
            elif r2 is None:
                # Enough trials, yet no usable R²: the cell is not validated
                # and must not be skipped as if it were strong.
                reason = "trial-level surrogacy not estimable (R²_trial undefined)"
                priority = 0.8
            elif r2 < R2_MODERATE:
                reason = f"weak trial-level surrogacy (R²_trial={r2:.2f} < {R2_MODERATE})"
                priority = 0.7
            elif r2 < R2_STRONG:
                reason = f"moderate surrogacy not yet validated (R²_trial={r2:.2f})"
                priority = 0.5

            if reason is None:
                continue

            # target effect to power against: predicted hard-outcome effect at a
            # clinically meaningful surrogate level, else the outcome grand mean.
            target = out_loghr.get(out)
            if res is not None and not math.isnan(res.slope):
                pred = res.intercept + res.slope * (-15.0)  # ~15% loss reference
                if abs(pred) > 1e-3:
                    target = pred
            events, n_per_arm = (None, None)
            if target is not None:
                events, n_per_arm = suggest_sample_size(target)

            statement = (
                f"Is {surrogate_name} a valid trial-level surrogate for "
                f"{out} in {cls}? ({reason})"
            )
            # raise priority for the outcomes listed here
            if out in ("MACE", "ALL_CAUSE_DEATH", "HF_HOSP"):
                priority += 0.1

            hyps.append(Hypothesis(
                drug_class=cls, surrogate=surrogate_name, hard_outcome=out,
                reason=reason, n_trials=n, r2_trial=r2,
                suggested_n_per_arm=n_per_arm,
                suggested_trials=max(MIN_TRIALS_VALIDATED - n, 2),
                priority=round(priority, 3), statement=statement,
            ))

    hyps.sort(key=lambda h: h.priority, reverse=True)
    return hyps


# --------------------------------------------------------------------------- #
# Paradox scan
# --------------------------------------------------------------------------- #

def paradox_scan(df) -> list:
    """Return the surrogacy results of all cells flagged with the surrogate paradox."""
    return [result for result in all_surrogacy(df) if result.paradox]
