"""
surrogacy.py  --  Trial-level surrogacy meta-regression engine for GlyceSurrogate-Kor.

Pure numpy/scipy/pandas implementation of the Buyse-Molenberghs / Daniels-Hughes
trial-level surrogacy framework. statsmodels is intentionally NOT used (not installed);
weighted least squares, prediction intervals, and the surrogate threshold effect are
implemented from first principles.

Core quantities (computed per drug_class x surrogate x hard_outcome cell):
  - R2_trial : weighted R^2 of the regression of the hard-outcome treatment effect
               (log-HR) on the surrogate treatment effect across trials, with each
               trial inverse-variance weighted by 1/Var(log-HR).
  - STE      : Surrogate Threshold Effect -- the least favorable surrogate effect at
               which the upper 95% prediction limit of the predicted log-HR is still
               below the null (log-HR = 0). Surrogate effects more favorable than the
               STE predict a real hard-outcome benefit for a new trial.
  - PTE      : Proportion of Treatment Effect explained, via a simple mediation framing
               PTE = 1 - beta_adjusted / beta_unadjusted.
  - paradox  : surrogate-paradox flag (surrogate improves but hard outcome worsens).
  - grade    : strong / moderate / weak / invalid from configurable R2 thresholds.

RESEARCH / REFERENCE USE ONLY -- NOT FOR CLINICAL DECISION-MAKING.
"""
from __future__ import annotations

import math
from dataclasses import dataclass, field, asdict
from typing import Optional

import numpy as np
import pandas as pd
from scipy import stats

# --------------------------------------------------------------------------------------
# Configurable constants (thresholds / policy knobs)
# --------------------------------------------------------------------------------------
# Grading policy on R2_trial:
#   any surrogate paradox          -> "invalid"
#   R2 >= R2_STRONG                -> "strong"
#   R2_MODERATE <= R2 < R2_STRONG  -> "moderate"
#   anything lower                 -> "weak"
R2_STRONG = 0.70
R2_MODERATE = 0.50

# Minimum-evidence policy.
MIN_TRIALS_FOR_REGRESSION = 3   # a line plus a residual spread needs at least 3 points
MIN_TRIALS_FOR_VALIDATION = 3   # cells with fewer trials are reported as data gaps
WIDE_CI_WIDTH = 0.50            # an R2 95% CI wider than this counts as "wide / unstable"

# Sign of delta_surrogate that represents a *favorable* surrogate change:
#   -1 -> lower is better (HbA1c, FPG);  +1 -> higher is better (TIR).
SURROGATE_FAVORABLE_SIGN = dict(HbA1c=-1, FPG=-1, TIR=+1)

# Hard-outcome null on the log-HR scale (log-HR = 0 <=> HR = 1).
NULL_LOGHR = 0.0

DISCLAIMER = (
    "RESEARCH / REFERENCE USE ONLY -- NOT FOR CLINICAL DECISION-MAKING. "
    "Demo numbers are synthetic/illustrative, not official trial readouts."
)

# Columns every input CSV must provide, in schema order.
REQUIRED_COLUMNS = (
    "trial drug_class surrogate "
    "delta_surrogate delta_surrogate_se "
    "hard_outcome loghr loghr_se"
).split()


# --------------------------------------------------------------------------------------
# Data loading
# --------------------------------------------------------------------------------------
def load_data(path: str) -> pd.DataFrame:
    """
    Load a trial-level CSV and return a cleaned DataFrame.

    Lines beginning with '#' are treated as comments, and column names are stripped of
    surrounding whitespace. Every name in REQUIRED_COLUMNS must be present.

    Cleaning rules:
      - delta_surrogate, delta_surrogate_se, loghr and loghr_se are coerced to numeric;
        unparseable entries become NaN.
      - Rows whose delta_surrogate, loghr or loghr_se is missing or non-finite
        (NaN, +inf, -inf) are dropped.
      - Rows with loghr_se <= 0 are dropped. A non-positive standard error carries no
        usable precision, and clipping it to a tiny positive value would give that
        row an enormous inverse-variance weight that dominates every fit.
      - Remaining loghr_se values are floored at 1e-6, so the weight is at most 1e12.
      - A missing delta_surrogate_se becomes 0; negative values are clipped to 0.
      - trial / drug_class / surrogate / hard_outcome are cast to stripped strings.

    Returns the cleaned frame with a fresh 0..n-1 index.

    Raises:
        ValueError: if any required column is missing.
    """
    df = pd.read_csv(path, comment="#", skipinitialspace=True)
    df.columns = [str(c).strip() for c in df.columns]
    missing = [c for c in REQUIRED_COLUMNS if c not in df.columns]
    if missing:
        raise ValueError(
            f"CSV is missing required columns: {missing}. "
            f"Required schema: {REQUIRED_COLUMNS}"
        )

    for c in ["delta_surrogate", "delta_surrogate_se", "loghr", "loghr_se"]:
        df[c] = pd.to_numeric(df[c], errors="coerce")

    # dropna() alone keeps +/-inf; turn those into NaN so that "non-finite" really
    # covers every value that would poison the regression.
    essentials = ["delta_surrogate", "loghr", "loghr_se"]
    df[essentials] = df[essentials].replace([np.inf, -np.inf], np.nan)
    df = df.dropna(subset=essentials)

    # Zero/negative SEs are invalid; drop them rather than inventing a huge weight.
    df = df[df["loghr_se"] > 0].copy()
    df["loghr_se"] = df["loghr_se"].clip(lower=1e-6)
    df["delta_surrogate_se"] = df["delta_surrogate_se"].fillna(0.0).clip(lower=0.0)

    for c in ["trial", "drug_class", "surrogate", "hard_outcome"]:
        df[c] = df[c].astype(str).str.strip()
    return df.reset_index(drop=True)


# --------------------------------------------------------------------------------------
# Weighted least squares helpers
# --------------------------------------------------------------------------------------
def _wls_fit(x: np.ndarray, y: np.ndarray, w: np.ndarray):
    """
    Weighted least squares for y = b0 + b1*x.

    Returns dict with intercept b0, slope b1, the (X'WX)^-1 covariance scaffold,
    residual variance s2, weighted R^2, and degrees of freedom.
    """
    x = np.asarray(x, float)
    y = np.asarray(y, float)
    w = np.asarray(w, float)
    n = x.size

    X = np.column_stack([np.ones(n), x])           # design matrix [1, x]
    W = np.diag(w)
    XtW = X.T @ W
    XtWX = XtW @ X
    XtWX_inv = np.linalg.inv(XtWX)
    beta = XtWX_inv @ (XtW @ y)                     # [b0, b1]
    b0, b1 = beta

    yhat = X @ beta
    resid = y - yhat
    dof = n - 2                                     # 2 params
    # Weighted residual variance (scale factor for the parameter covariance).
    s2 = float((w * resid**2).sum() / dof) if dof > 0 else float("nan")

    # Weighted R^2: 1 - SS_res / SS_tot, using weights and the weighted mean.
    ybar_w = float((w * y).sum() / w.sum())
    ss_res = float((w * resid**2).sum())
    ss_tot = float((w * (y - ybar_w) ** 2).sum())
    r2 = 1.0 - ss_res / ss_tot if ss_tot > 0 else float("nan")
    r2 = float(min(max(r2, 0.0), 1.0)) if np.isfinite(r2) else float("nan")

    return {
        "b0": float(b0), "b1": float(b1),
        "XtWX_inv": XtWX_inv, "s2": s2, "dof": dof, "n": n,
        "x": x, "y": y, "w": w, "w_med": float(np.median(w)),
        "yhat": yhat, "resid": resid,
        "r2": r2,
    }


def _r2_ci_fisher(r2: float, n: int) -> tuple:
    """
    95% CI for R^2 via the Fisher-z transform of r = sqrt(R^2).

    Sign of r is taken to match a positive coefficient of determination (>=0).
    Returns (lo, hi) on the R^2 scale, clamped to [0, 1].
    """
    if not np.isfinite(r2) or n < 4:
        return (float("nan"), float("nan"))
    r = math.sqrt(max(min(r2, 1.0), 0.0))
    r = min(r, 0.999999)
    z = math.atanh(r)
    se = 1.0 / math.sqrt(n - 3)
    zlo, zhi = z - 1.959964 * se, z + 1.959964 * se
    rlo, rhi = math.tanh(zlo), math.tanh(zhi)
    rlo = max(rlo, 0.0)
    return (float(rlo**2), float(rhi**2))


def _pred_band(fit: dict, x0: float, alpha: float = 0.05) -> tuple:
    """
    Two-sided (1-alpha) prediction interval for a *new* trial's hard-outcome effect
    at surrogate value x0, from a WLS fit.

    In WLS the residual sampling variance of observation i is s2 / w_i. A future trial
    is assumed to carry a representative precision = the median observed weight, so its
    own variance is s2 / w_med. The prediction variance is therefore
        s2 * ( x0' (X'WX)^-1 x0  +  1/w_med ).
    Using 1/w_med (rather than a unit weight) keeps the band on the same scale as the
    inverse-variance-weighted fit.

    Returns (yhat, lo, hi).
    """
    XtWX_inv = fit["XtWX_inv"]
    s2 = fit["s2"]
    dof = fit["dof"]
    w_med = fit.get("w_med", 1.0)
    x0v = np.array([1.0, x0])
    yhat = float(fit["b0"] + fit["b1"] * x0)
    if not np.isfinite(s2) or dof <= 0 or w_med <= 0:
        return (yhat, float("nan"), float("nan"))
    var_mean = float(x0v @ XtWX_inv @ x0v) * s2          # variance of the fitted mean
    var_new = s2 / w_med                                 # new-trial sampling variance
    var_pred = var_mean + var_new
    se_pred = math.sqrt(max(var_pred, 0.0))
    tcrit = stats.t.ppf(1 - alpha / 2, dof)
    return (yhat, yhat - tcrit * se_pred, yhat + tcrit * se_pred)


# --------------------------------------------------------------------------------------
# Surrogate Threshold Effect (STE)
# --------------------------------------------------------------------------------------
def compute_ste(fit: dict, surrogate: str, alpha: float = 0.05,
                grid_pad: float = 1.5, npts: int = 2001) -> Optional[float]:
    """
    Surrogate Threshold Effect (STE): the least favorable surrogate effect at which
    the upper (1 - alpha) prediction limit of a new trial's log-HR still sits below
    the null (log-HR = 0). Any surrogate effect more favorable than the STE predicts a
    hard-outcome benefit with (1 - alpha) confidence.

    Direction guard: an STE only exists when the fitted slope points toward benefit
    (a *more favorable* surrogate predicts a *lower* log-HR). For HbA1c/FPG (lower is
    better) that means a positive b1; for TIR (higher is better) a negative b1.
    Surrogates missing from SURROGATE_FAVORABLE_SIGN are treated as lower-is-better.
    When the slope points the wrong way (surrogate-paradox region, or no real
    surrogate-outcome relationship) no surrogate effect confidently buys benefit, so
    None is returned.

    Search:
      1. Scan `npts` grid points covering the observed surrogate range, padded by
         `grid_pad` x span on each side, from most to least favorable. This brackets
         the least-favorable end of the confident-benefit run.
      2. If that run still holds at the least-favorable grid edge, the crossing lies
         beyond the padding (e.g. every trial shows benefit). In that case, keep
         stepping in the unfavorable direction with a doubling step until the limit
         reaches the null. With a beneficial slope the point prediction, and hence the
         upper limit, grows without bound in that direction. Without this step the
         grid edge itself would be reported as an artificial STE.
      3. Refine the bracketed crossing with Brent's method instead of reporting a
         grid point.

    Args:
        fit: result dict from `_wls_fit`.
        surrogate: surrogate name, used to look up its favorable direction.
        alpha: two-sided prediction-interval level (0.05 -> 95% band).
        grid_pad: grid padding beyond the observed range, in multiples of the x span.
        npts: number of grid points used for bracketing.

    Returns:
        The STE on the surrogate scale, or None if no confident-benefit threshold
        exists (wrong-direction slope, undefined band, or no crossing found).
    """
    from scipy.optimize import brentq   # only needed to refine the bracketed crossing

    x = fit["x"]
    if x.size == 0:
        return None
    sign = SURROGATE_FAVORABLE_SIGN.get(surrogate, -1)
    # Beneficial slope: d(logHR)/d(surrogate) must be opposite to the favorable sign,
    # i.e. sign * b1 < 0.  (favorable=-1 -> need b1>0;  favorable=+1 -> need b1<0)
    if not np.isfinite(fit["b1"]) or sign * fit["b1"] >= 0:
        return None

    def gap(xv: float) -> float:
        # Upper prediction limit relative to the null; < 0 means confident benefit.
        return _pred_band(fit, float(xv), alpha=alpha)[2] - NULL_LOGHR

    span = float(x.max() - x.min())
    pad = grid_pad * (span if span > 0 else (abs(float(x.mean())) + 1.0))
    lo = float(x.min()) - pad
    hi = float(x.max()) + pad
    grid = np.linspace(lo, hi, npts)
    # Favorability is sign * x; sweep from most to least favorable.
    grid_sorted = grid[np.argsort(sign * grid)[::-1]]

    inside = None    # least-favorable point seen so far with a confident benefit
    outside = None   # first evaluable point past it where the band reaches the null
    for xv in grid_sorted:
        g = gap(xv)
        if not np.isfinite(g):
            continue
        if g < 0:
            inside = float(xv)
        elif inside is not None:
            outside = float(xv)
            break

    if inside is None:
        return None

    if outside is None:
        # The confident run reached the least-favorable grid edge: walk further in the
        # unfavorable direction (-sign) with a doubling step until the band crosses.
        step = max(hi - lo, abs(inside), 1.0) / max(npts - 1, 1)
        for _ in range(64):
            probe = inside - sign * step
            g = gap(probe)
            if not np.isfinite(g):
                return None
            if g >= 0:
                outside = probe
                break
            inside = probe
            step *= 2.0
        else:
            return None

    a, b = sorted((inside, outside))
    return float(brentq(gap, a, b))


# --------------------------------------------------------------------------------------
# Proportion of Treatment Effect explained (PTE)
# --------------------------------------------------------------------------------------
def compute_pte(df_cell: pd.DataFrame) -> dict:
    """
    Proportion of Treatment Effect explained (PTE) via a simple mediation framing:
        beta_unadjusted : inverse-variance weighted mean hard-outcome effect
                          (the intercept-only WLS estimate).
        beta_adjusted   : residual hard-outcome effect after regressing out the
                          surrogate -- the intercept of the WLS of loghr on
                          delta_surrogate, i.e. the predicted effect at zero
                          surrogate change.
        PTE = 1 - beta_adjusted / beta_unadjusted.

    Args:
        df_cell: rows of one cell with columns loghr, delta_surrogate, loghr_se.

    Returns a dict that carries the same keys on every path:
        pte         : PTE clamped to [0, 1] for reporting; NaN when undetermined,
        pte_raw     : unclamped PTE; NaN when undetermined,
        beta_unadj  : weighted mean log-HR; NaN for an empty cell,
        beta_adj    : WLS intercept; NaN when the regression is not identifiable,
        implausible : True when PTE is undetermined or falls outside [0, 1].
    PTE is undetermined for an empty cell, fewer than MIN_TRIALS_FOR_REGRESSION
    trials, no variation in delta_surrogate, or |beta_unadjusted| < 1e-9.
    """
    nan = float("nan")
    y = df_cell["loghr"].to_numpy(float)
    x = df_cell["delta_surrogate"].to_numpy(float)
    w = 1.0 / df_cell["loghr_se"].to_numpy(float) ** 2

    if y.size == 0:
        # Same keys as the other return paths so callers can index uniformly.
        return {
            "pte": nan, "pte_raw": nan,
            "beta_unadj": nan, "beta_adj": nan,
            "implausible": True,
        }

    # Unadjusted: weighted mean of the hard-outcome effect.
    beta_unadj = float((w * y).sum() / w.sum())

    if df_cell.shape[0] < MIN_TRIALS_FOR_REGRESSION or np.allclose(x, x[0]):
        # Cannot identify the mediated portion without surrogate variation.
        return {
            "pte": nan, "pte_raw": nan,
            "beta_unadj": beta_unadj, "beta_adj": nan,
            "implausible": True,
        }

    fit = _wls_fit(x, y, w)
    beta_adj = float(fit["b0"])     # predicted log-HR at zero surrogate change
    if abs(beta_unadj) < 1e-9:
        pte_raw = nan               # ratio undefined when there is no overall effect
    else:
        pte_raw = 1.0 - beta_adj / beta_unadj

    finite = bool(np.isfinite(pte_raw))
    implausible = (not finite) or pte_raw < 0.0 or pte_raw > 1.0
    return {
        "pte": float(min(max(pte_raw, 0.0), 1.0)) if finite else nan,
        "pte_raw": float(pte_raw) if finite else nan,
        "beta_unadj": beta_unadj, "beta_adj": beta_adj,
        "implausible": bool(implausible),
    }


# --------------------------------------------------------------------------------------
# Paradox detection
# --------------------------------------------------------------------------------------
def detect_paradox(df_cell: pd.DataFrame) -> list:
    """
    Find surrogate-paradox trials: the surrogate moved in its *favorable* direction
    (per SURROGATE_FAVORABLE_SIGN, defaulting to lower-is-better) while the hard
    outcome worsened (log-HR > 0, i.e. HR > 1).

    Returns one dict per offending trial, in row order, with keys trial, drug_class,
    surrogate, delta_surrogate, hard_outcome, loghr and hr (= exp(loghr)).
    """
    def is_paradox(row) -> bool:
        direction = SURROGATE_FAVORABLE_SIGN.get(row["surrogate"], -1)
        surrogate_improved = (direction * row["delta_surrogate"]) > 0
        outcome_harmed = row["loghr"] > 0
        return bool(surrogate_improved and outcome_harmed)

    def as_record(row) -> dict:
        return {
            "trial": row["trial"],
            "drug_class": row["drug_class"],
            "surrogate": row["surrogate"],
            "delta_surrogate": float(row["delta_surrogate"]),
            "hard_outcome": row["hard_outcome"],
            "loghr": float(row["loghr"]),
            "hr": float(math.exp(row["loghr"])),
        }

    return [as_record(row) for _, row in df_cell.iterrows() if is_paradox(row)]


# --------------------------------------------------------------------------------------
# Grading
# --------------------------------------------------------------------------------------
def grade_surrogacy(r2: float, has_paradox: bool) -> str:
    """
    Map a trial-level R^2 to a surrogacy grade.

    A paradox overrides everything ("invalid"). A non-finite R^2 is "indeterminate".
    Otherwise R^2 >= R2_STRONG is "strong", R^2 >= R2_MODERATE is "moderate", and
    anything lower is "weak".
    """
    if has_paradox:
        return "invalid"
    if not np.isfinite(r2):
        return "indeterminate"
    # Thresholds are checked from strictest to loosest; first match wins.
    for floor, label in ((R2_STRONG, "strong"), (R2_MODERATE, "moderate")):
        if r2 >= floor:
            return label
    return "weak"


# --------------------------------------------------------------------------------------
# Per-cell surrogacy result container
# --------------------------------------------------------------------------------------
@dataclass
class CellResult:
    """
    Surrogacy summary for one (drug_class, surrogate, hard_outcome) cell.

    As populated by analyze_cell, the regression-derived numbers are NaN (and ste is
    None) when the cell was too small or had no surrogate variation; `note` then
    explains why. `paradox_trials` holds the per-trial dicts from detect_paradox,
    each carrying at least a "trial" key.
    """
    drug_class: str
    surrogate: str
    hard_outcome: str
    n_trials: int
    r2_trial: float
    r2_ci_lo: float
    r2_ci_hi: float
    slope: float
    ste: Optional[float]
    pte: float
    pte_raw: float
    pte_implausible: bool
    grade: str
    has_paradox: bool
    paradox_trials: list = field(default_factory=list)  # list of dicts with a "trial" key
    note: str = ""

    def as_row(self) -> dict:
        """
        Flatten to a plain dict suitable for one DataFrame row.

        paradox_trials is collapsed into a comma-joined string of trial ids. Ids are
        stringified first, so non-string ids (e.g. integers) do not break the join.
        """
        row = asdict(self)
        row["paradox_trials"] = ",".join(str(t["trial"]) for t in self.paradox_trials)
        return row


def analyze_cell(df_cell: pd.DataFrame) -> CellResult:
    """
    Run the full surrogacy pipeline on one (class x surrogate x outcome) cell.

    Steps: paradox scan, then (when there are at least MIN_TRIALS_FOR_REGRESSION
    trials and delta_surrogate varies) an inverse-variance WLS fit of loghr on
    delta_surrogate giving R^2, its Fisher-z CI, the slope and the STE; then PTE and
    the grade. Cells that cannot be regressed keep NaN metrics and get an explanatory
    note. The identifying labels are read from the cell's first row, so df_cell must
    not be empty.
    """
    first = df_cell.iloc[0]
    dc, sg, ho = first["drug_class"], first["surrogate"], first["hard_outcome"]
    n = int(df_cell.shape[0])

    paradox = detect_paradox(df_cell)

    x = df_cell["delta_surrogate"].to_numpy(float)
    y = df_cell["loghr"].to_numpy(float)
    w = 1.0 / df_cell["loghr_se"].to_numpy(float) ** 2

    # Short-circuit keeps x[0] from being touched on tiny cells.
    can_regress = n >= MIN_TRIALS_FOR_REGRESSION and not np.allclose(x, x[0])
    if can_regress:
        fit = _wls_fit(x, y, w)
        r2, slope = fit["r2"], fit["b1"]
        r2lo, r2hi = _r2_ci_fisher(r2, n)
        ste = compute_ste(fit, sg)
        note = ""
    else:
        r2 = slope = float("nan")
        r2lo = r2hi = float("nan")
        ste = None
        note = f"insufficient trials/variation (n={n}) for regression"

    pte_d = compute_pte(df_cell)
    grade = grade_surrogacy(r2, len(paradox) > 0)

    return CellResult(
        drug_class=dc,
        surrogate=sg,
        hard_outcome=ho,
        n_trials=n,
        r2_trial=r2,
        r2_ci_lo=r2lo,
        r2_ci_hi=r2hi,
        slope=slope,
        ste=ste,
        pte=pte_d["pte"],
        pte_raw=pte_d["pte_raw"],
        pte_implausible=bool(pte_d["implausible"]),
        grade=grade,
        has_paradox=len(paradox) > 0,
        paradox_trials=paradox,
        note=note,
    )


def analyze_all(df: pd.DataFrame) -> list:
    """
    Analyze every (drug_class, surrogate, hard_outcome) combination present in df.

    Returns a list of CellResult sorted by (drug_class, surrogate, hard_outcome).
    """
    key_cols = ["drug_class", "surrogate", "hard_outcome"]
    combos = df[key_cols].drop_duplicates().itertuples(index=False, name=None)

    cells = []
    for dc, sg, ho in combos:
        in_cell = (
            (df["drug_class"] == dc)
            & (df["surrogate"] == sg)
            & (df["hard_outcome"] == ho)
        )
        cells.append(analyze_cell(df[in_cell]))

    # Stable, readable ordering.
    return sorted(cells, key=lambda res: (res.drug_class, res.surrogate, res.hard_outcome))


def results_to_frame(results: list) -> pd.DataFrame:
    """Tabulate results: one row per item, columns taken from each item's as_row()."""
    rows = list(map(lambda res: res.as_row(), results))
    return pd.DataFrame(rows)


# --------------------------------------------------------------------------------------
# Sample-size suggestion for a validation study
# --------------------------------------------------------------------------------------
def suggest_validation(cell: CellResult, df_cell: Optional[pd.DataFrame] = None) -> dict:
    """
    Suggest a validation study for an unvalidated/weak cell.

    Heuristics (transparent, illustrative -- not a formal power calculation):
      - n_additional_trials: trials still needed to reach 2 x MIN_TRIALS_FOR_VALIDATION
        (at least 1). The count is inflated by 50% when the R^2 CI is wider than
        WIDE_CI_WIDTH, and by another 50% when the grade is weak / invalid /
        indeterminate.
      - Sample size: total events needed to detect a hard-outcome log-HR of fixed
        magnitude 0.15 with 80% power, two-sided alpha 0.05 and 1:1 allocation, via
        events ~= 4 (z_{alpha/2} + z_beta)^2 / logHR^2. Events are converted to
        participants with a crude assumed event rate of 0.5 x median(loghr_se),
        clamped to [3%, 20%]. A median of 0.08 is assumed when df_cell is missing,
        empty, or has no non-missing loghr_se.

    Args:
        cell: the CellResult being assessed (n_trials, r2 CI and grade are used).
        df_cell: the cell's trial rows, used only for the loghr_se event-rate proxy.

    Returns:
        dict with n_additional_trials, target_loghr, assumed_event_rate (rounded to 3
        dp), approx_total_events, approx_total_n and approx_per_arm_n.
    """
    target_trials = 2 * MIN_TRIALS_FOR_VALIDATION
    deficit = max(target_trials - cell.n_trials, 1)

    # Inflate when the surrogacy signal is poor or uncertain.
    inflate = 1.0
    ci_width = (cell.r2_ci_hi - cell.r2_ci_lo) if np.isfinite(cell.r2_ci_hi) else float("nan")
    if np.isfinite(ci_width) and ci_width > WIDE_CI_WIDTH:
        inflate += 0.5
    if cell.grade in ("weak", "invalid", "indeterminate"):
        inflate += 0.5
    n_trials_needed = int(math.ceil(deficit * inflate))

    # Fixed target effect size on the log-HR scale.
    target_loghr = 0.15
    # Required events for a log-HR with 1:1 allocation:
    #   total_events ~ 4 * (z_a/2 + z_b)^2 / (loghr)^2
    za = 1.959964
    zb = 0.841621   # 80% power
    total_events = 4.0 * (za + zb) ** 2 / (target_loghr ** 2)

    # Event-rate proxy from the observed log-HR SEs. Fall back to 0.08 when there is
    # no usable SE: an all-NaN column would otherwise yield NaN here and make
    # math.ceil() below raise ValueError.
    se_med = 0.08
    if df_cell is not None and len(df_cell):
        se_values = df_cell["loghr_se"]
        if se_values.notna().any():
            se_med = float(np.nanmedian(se_values))
    assumed_event_rate = float(min(max(0.5 * se_med, 0.03), 0.20))  # 3%-20%
    total_n = int(math.ceil(total_events / assumed_event_rate))
    per_arm_n = int(math.ceil(total_n / 2))

    return {
        "n_additional_trials": n_trials_needed,
        "target_loghr": target_loghr,
        "assumed_event_rate": round(assumed_event_rate, 3),
        "approx_total_events": int(math.ceil(total_events)),
        "approx_total_n": total_n,
        "approx_per_arm_n": per_arm_n,
    }


# --------------------------------------------------------------------------------------
# Gap mining + hypothesis generation
# --------------------------------------------------------------------------------------
def mine_gaps(df: pd.DataFrame, results: Optional[list] = None) -> list:
    """
    Scan the drug_class x surrogate x hard_outcome grid for evidence gaps and turn
    each one into a validation hypothesis with a sample-size suggestion.

    Two kinds of gap are reported:
      - "weak_cell": an analysed cell with fewer than MIN_TRIALS_FOR_VALIDATION
        trials, an R2 CI wider than WIDE_CI_WIDTH, a weak/invalid/indeterminate
        grade, or an implausible/undetermined PTE.
      - "missing_cell": a combination of observed class, surrogate and outcome values
        with no trials at all (a pure evidence gap).

    Args:
        df: trial-level frame.
        results: precomputed CellResult list; analyze_all(df) is used when None.

    Returns:
        List of hypothesis dicts: weak cells first (more actionable), then missing
        cells, each group ordered by (drug_class, surrogate, hard_outcome).
    """
    if results is None:
        results = analyze_all(df)
    observed = {(r.drug_class, r.surrogate, r.hard_outcome) for r in results}

    classes = sorted(df["drug_class"].unique())
    surrogates = sorted(df["surrogate"].unique())
    outcomes = sorted(df["hard_outcome"].unique())

    def rows_of(dc, sg, ho):
        # Trial rows belonging to one (class, surrogate, outcome) cell.
        hit = (df["drug_class"] == dc) & (df["surrogate"] == sg) & (df["hard_outcome"] == ho)
        return df[hit]

    def weaknesses(res) -> list:
        # Human-readable reasons why an observed cell needs more validation.
        found = []
        if res.n_trials < MIN_TRIALS_FOR_VALIDATION:
            found.append(f"only {res.n_trials} trial(s) (<{MIN_TRIALS_FOR_VALIDATION})")
        width = (res.r2_ci_hi - res.r2_ci_lo) if np.isfinite(res.r2_ci_hi) else float("nan")
        if np.isfinite(width) and width > WIDE_CI_WIDTH:
            found.append(f"wide R2 CI (width={width:.2f})")
        if res.grade in ("weak", "invalid", "indeterminate"):
            found.append(f"grade={res.grade}")
        if res.pte_implausible:
            found.append("implausible/undetermined PTE")
        return found

    hypotheses = []

    # (1) Observed-but-weak / under-powered cells.
    for res in results:
        reasons = weaknesses(res)
        if not reasons:
            continue
        advice = suggest_validation(res, rows_of(res.drug_class, res.surrogate, res.hard_outcome))
        hypotheses.append({
            "kind": "weak_cell",
            "drug_class": res.drug_class,
            "surrogate": res.surrogate,
            "hard_outcome": res.hard_outcome,
            "n_trials": res.n_trials,
            "grade": res.grade,
            "r2_trial": res.r2_trial,
            "reasons": reasons,
            "hypothesis": (
                f"Is {res.surrogate} a valid trial-level surrogate for "
                f"{res.hard_outcome} in {res.drug_class}? "
                f"(current evidence: {', '.join(reasons)})"
            ),
            "suggestion": advice,
        })

    # (2) Combinations never seen in the data at all.
    unseen = [
        (dc, sg, ho)
        for dc in classes
        for sg in surrogates
        for ho in outcomes
        if (dc, sg, ho) not in observed
    ]
    for dc, sg, ho in unseen:
        # Placeholder result so the sample-size heuristic can run on an empty cell.
        placeholder = CellResult(
            drug_class=dc, surrogate=sg, hard_outcome=ho, n_trials=0,
            r2_trial=float("nan"), r2_ci_lo=float("nan"), r2_ci_hi=float("nan"),
            slope=float("nan"), ste=None, pte=float("nan"), pte_raw=float("nan"),
            pte_implausible=True, grade="indeterminate", has_paradox=False,
        )
        advice = suggest_validation(placeholder, None)
        hypotheses.append({
            "kind": "missing_cell",
            "drug_class": dc,
            "surrogate": sg,
            "hard_outcome": ho,
            "n_trials": 0,
            "grade": "no data",
            "r2_trial": float("nan"),
            "reasons": ["no trials in dataset for this combination"],
            "hypothesis": (
                f"No trial-level data link {sg} to {ho} in {dc}. "
                f"Hypothesis: a validation programme could establish whether "
                f"{sg} predicts {ho} treatment benefit in {dc}."
            ),
            "suggestion": advice,
        })

    rank = {"weak_cell": 0, "missing_cell": 1}
    hypotheses.sort(
        key=lambda h: (rank[h["kind"]], h["drug_class"], h["surrogate"], h["hard_outcome"])
    )
    return hypotheses
