#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
CumulMetaDelta-Kor  (큐뮬메타델타워치코어)
============================================================
Domain    : DM (당뇨)
Category  : 연구 알림 (research-alert / cumulative-meta curation)
Entry     : Python 3 CLI, FULLY OFFLINE on synthetic data.

One-line MVP
------------
For each (drug-class x pre-registered outcome) CELL, maintain a CUMULATIVE
meta-analysis of diabetes drug-class RCTs. As each new trial is ingested in
chronological order, recompute the pooled effect and emit an alert ONLY when
the new trial causes a *threshold transition* in the pooled state:
  - significance cross (95% CI of pooled effect crosses the null = 1.0)
  - I^2 heterogeneity band crossed (low <-> moderate <-> substantial <-> considerable)
  - 95% prediction interval starts/stops crossing 1.0
  - GRADE certainty grade up/down (inconsistency, imprecision, publication bias)

Meta-analysis math is implemented here in PURE STDLIB (math module only):
  - Inverse-variance FIXED effect
  - DerSimonian-Laird RANDOM effects (tau^2, Q, I^2)
  - Hartung-Knapp-Sidik-Jonkman (HKSJ) variance adjustment for the RE pooled CI
  - 95% prediction interval (random effects)
  - Egger's regression test (publication-bias signal, needs >=4 trials)

NO network. NO paid API. NO statsmodels/scipy required (numpy NOT used).

참고용 / 연구용 (research / reference only — NOT for clinical decisions).
============================================================
"""

import argparse
import json
import math
import os
import sys

# ----------------------------------------------------------------------
# Constants / thresholds (documented, tunable)
# ----------------------------------------------------------------------
NULL = 1.0                       # ratio-measure null value (HR/RR = 1)
# I^2 band edges, as fractions (not percent). i2_band() compares with a strict
# "<", so each edge belongs to the band ABOVE it:
#   [0, .25) low | [.25, .50) moderate | [.50, .75) substantial | >= .75 considerable
I2_LOW = 0.25                    # I^2 < 25% -> low heterogeneity
I2_MODERATE = 0.50               # 25-50% moderate, 50-75% substantial, >=75% considerable
I2_SUBSTANTIAL = 0.75            # also the GRADE inconsistency-downgrade threshold
EGGER_P_FLAG = 0.10              # Egger's p < 0.10 -> small-study/pub-bias signal
# Notice printed in the banner, CLI help and footers.
DISCLAIMER = ("참고용/연구용 (research/reference only — NOT for clinical decisions). "
              "Data are SYNTHETIC demo trials.")

# t / z critical values without scipy (95% two-sided).
# _Z_975: 0.975 quantile of the standard normal (1.959964).
# _T_975: {degrees of freedom: 0.975 quantile of Student's t}, rounded to
# 3 decimals; df values not listed are filled in by t_crit_975().
# The HKSJ pooled CI uses df = k-1.
_Z_975 = 1.959964
_T_975 = {1: 12.706, 2: 4.303, 3: 3.182, 4: 2.776, 5: 2.571, 6: 2.447,
          7: 2.365, 8: 2.306, 9: 2.262, 10: 2.228, 11: 2.201, 12: 2.179,
          15: 2.131, 20: 2.086, 30: 2.042, 40: 2.021, 60: 2.000}


def t_crit_975(df):
    """
    Two-sided 95% Student-t critical value (the 0.975 quantile of t(df)).

    - df <= 0                 -> normal quantile _Z_975
    - df listed in _T_975     -> tabulated value
    - df inside the table range but not listed -> linear interpolation
      between the bracketing table entries
    - df beyond the last table entry -> Cornish-Fisher expansion of the t
      quantile in powers of 1/df (Abramowitz & Stegun 26.7.5). This is
      accurate to about 4 decimals there and decreases smoothly toward _Z_975.
    """
    if df <= 0:
        return _Z_975
    if df in _T_975:
        return _T_975[df]
    keys = sorted(_T_975)
    if df > keys[-1]:
        # Returning _Z_975 here would drop the critical value from 2.000
        # (df=60) to 1.960 (df=61), making large-k CIs ~2% too narrow.
        z = _Z_975
        z2 = z * z
        z3 = z2 * z
        z5 = z3 * z2
        z7 = z5 * z2
        z9 = z7 * z2
        g1 = (z3 + z) / 4.0
        g2 = (5.0 * z5 + 16.0 * z3 + 3.0 * z) / 96.0
        g3 = (3.0 * z7 + 19.0 * z5 + 17.0 * z3 - 15.0 * z) / 384.0
        g4 = (79.0 * z9 + 776.0 * z7 + 1482.0 * z5 - 1920.0 * z3 - 945.0 * z) / 92160.0
        return z + g1 / df + g2 / df ** 2 + g3 / df ** 3 + g4 / df ** 4
    # linear interpolation between bracketing table entries
    lo = max(k for k in keys if k <= df)
    hi = min(k for k in keys if k >= df)
    if lo == hi:
        return _T_975[lo]
    frac = (df - lo) / (hi - lo)
    return _T_975[lo] + frac * (_T_975[hi] - _T_975[lo])


def norm_sf(z):
    """Upper-tail probability P(Z > z) of the standard normal distribution.

    Relies on the identity P(Z > z) = erfc(z / sqrt(2)) / 2, so no scipy is needed.
    """
    scaled = z / math.sqrt(2.0)
    return math.erfc(scaled) / 2.0


# ----------------------------------------------------------------------
# Effect-size conversion: ratio measure + 95% CI  ->  log scale + SE
# ----------------------------------------------------------------------
def to_log_se(effect, ci_low, ci_high):
    """
    Convert a ratio measure (HR/RR) with its 95% CI to (log_effect, se_log).

    se_log = |ln(ci_high) - ln(ci_low)| / (2 * 1.959964): the log-scale CI
    width divided by twice the normal 97.5% quantile.

    The absolute value makes swapped bounds (ci_low > ci_high) harmless.
    Without it, such a row gets a negative width and is clamped to the 1e-6
    floor. It then receives an inverse-variance weight of ~1e12 and dominates
    every pool it enters. A zero-width CI is floored to se = 1e-6 so the
    weight stays finite.

    Raises ValueError if the effect or either CI bound is not strictly
    positive, because ratio measures are only defined on (0, inf).
    """
    if effect <= 0 or ci_low <= 0 or ci_high <= 0:
        raise ValueError(
            "ratio effect and CI bounds must be > 0 (got effect=%r, ci=%r-%r)"
            % (effect, ci_low, ci_high))
    log_eff = math.log(effect)
    se = abs(math.log(ci_high) - math.log(ci_low)) / (2.0 * _Z_975)
    if se <= 0:
        se = 1e-6  # zero-width CI: floor the SE instead of dividing by zero later
    return log_eff, se


# ----------------------------------------------------------------------
# Core meta-analysis (pure stdlib)
# ----------------------------------------------------------------------
def meta_analyze(trials):
    """
    Pool ratio-measure trials with fixed- and random-effects models.

    trials: non-empty list of dicts, each with 'effect', 'ci_low', 'ci_high'
    (a ratio measure and its 95% CI).

    Returns a dict with:
      k                          number of trials pooled
      fixed_effect, fixed_ci     inverse-variance fixed effect and z-based 95% CI
      random_effect, random_ci   DerSimonian-Laird RE estimate with HKSJ 95% CI
      se_re_hksj                 log-scale SE used for random_ci
      tau2, Q, I2                heterogeneity statistics (I2 as a fraction)
      pred_interval              95% prediction interval (t with k-2 df),
                                 or None when k < 3
      egger_p                    Egger intercept p-value, or None
    All ratio results are back-transformed from the log scale.

    Raises ValueError if `trials` is empty.
    """
    k = len(trials)
    if k == 0:
        raise ValueError("meta_analyze() needs at least one trial")

    ys, ws, ses = [], [], []
    for t in trials:
        y, se = to_log_se(t["effect"], t["ci_low"], t["ci_high"])
        ys.append(y)
        ses.append(se)
        ws.append(1.0 / (se * se))   # inverse-variance weights (fixed)

    sum_w = sum(ws)
    # ---- Fixed effect (inverse variance) ----
    mu_fixed = sum(w * y for w, y in zip(ws, ys)) / sum_w
    var_fixed = 1.0 / sum_w
    se_fixed = math.sqrt(var_fixed)

    # ---- Heterogeneity: Cochran's Q, DerSimonian-Laird tau^2, I^2 ----
    Q = sum(w * (y - mu_fixed) ** 2 for w, y in zip(ws, ys))
    df = k - 1
    if df > 0:
        sum_w2 = sum(w * w for w in ws)
        C = sum_w - (sum_w2 / sum_w)
        tau2 = max(0.0, (Q - df) / C) if C > 0 else 0.0
        I2 = max(0.0, (Q - df) / Q) if Q > 0 else 0.0
    else:
        tau2 = 0.0
        I2 = 0.0

    # ---- Random effects (DerSimonian-Laird) ----
    ws_re = [1.0 / (se * se + tau2) for se in ses]
    sum_wre = sum(ws_re)
    mu_re = sum(w * y for w, y in zip(ws_re, ys)) / sum_wre
    var_re = 1.0 / sum_wre
    se_re = math.sqrt(var_re)

    # ---- HKSJ adjustment for RE pooled CI (more honest with few trials) ----
    if df > 0:
        q_hksj = sum(w * (y - mu_re) ** 2 for w, y in zip(ws_re, ys)) / df
        se_hksj = math.sqrt(q_hksj / sum_wre)
        # guard: HKSJ can be anticonservative when q<1; use max with model SE
        se_hksj = max(se_hksj, se_re)
        tcrit = t_crit_975(df)
    else:
        se_hksj = se_re
        tcrit = _Z_975

    ci_re_low = math.exp(mu_re - tcrit * se_hksj)
    ci_re_high = math.exp(mu_re + tcrit * se_hksj)
    ci_fx_low = math.exp(mu_fixed - _Z_975 * se_fixed)
    ci_fx_high = math.exp(mu_fixed + _Z_975 * se_fixed)

    # ---- 95% prediction interval (random effects) ----
    # Higgins, Thompson & Spiegelhalter (2009):
    #   mu_RE +/- t_{k-2} * sqrt(tau^2 + SE(mu_RE)^2)
    # One df each is spent on estimating mu and tau^2. The t quantile
    # therefore uses k-2 df, and the interval needs at least 3 trials.
    pi_df = k - 2
    if pi_df >= 1:
        pi_se = math.sqrt(tau2 + se_re * se_re)
        pi_t = t_crit_975(pi_df)
        pred_interval = (math.exp(mu_re - pi_t * pi_se),
                         math.exp(mu_re + pi_t * pi_se))
    else:
        pred_interval = None

    # ---- Egger's regression test for small-study effects ----
    egger_p = egger_test(ys, ses)

    return {
        "k": k,
        "fixed_effect": math.exp(mu_fixed),
        "fixed_ci": (ci_fx_low, ci_fx_high),
        "random_effect": math.exp(mu_re),
        "random_ci": (ci_re_low, ci_re_high),
        "se_re_hksj": se_hksj,
        "tau2": tau2,
        "Q": Q,
        "I2": I2,
        "pred_interval": pred_interval,
        "egger_p": egger_p,
    }


def _t_two_sided_p(t, df):
    """
    Two-sided p-value P(|T| >= |t|) for Student's t with integer df >= 1.

    Uses the exact closed form for integer df (Abramowitz & Stegun
    26.7.3 / 26.7.4), with theta = atan(|t| / sqrt(df)):
      df odd : A = (2/pi) * (theta + sin(theta) * sum_j a_j cos(theta)^(2j+1)),
               a_0 = 1, a_j = a_{j-1} * 2j / (2j+1),   j = 0..(df-3)/2
      df even: A = sin(theta) * sum_j b_j cos(theta)^(2j),
               b_0 = 1, b_j = b_{j-1} * (2j-1) / (2j), j = 0..(df-2)/2
    and p = 1 - A, clamped to [0, 1].
    """
    theta = math.atan(abs(t) / math.sqrt(df))
    s = math.sin(theta)
    c = math.cos(theta)
    c2 = c * c
    series = 0.0
    if df % 2 == 1:
        coef, power = 1.0, c
        for j in range((df - 1) // 2):      # empty for df == 1 (Cauchy)
            if j > 0:
                coef *= (2.0 * j) / (2.0 * j + 1.0)
                power *= c2
            series += coef * power
        a = (2.0 / math.pi) * (theta + s * series)
    else:
        coef, power = 1.0, 1.0
        for j in range(df // 2):
            if j > 0:
                coef *= (2.0 * j - 1.0) / (2.0 * j)
                power *= c2
            series += coef * power
        a = s * series
    return min(1.0, max(0.0, 1.0 - a))


def egger_test(ys, ses):
    """
    Egger's regression test for small-study effects.

    Regresses the standardized effect (y/se) on precision (1/se) by ordinary
    least squares. An intercept different from 0 signals funnel-plot
    asymmetry. The intercept t-statistic is referred to Student's t with k-2
    degrees of freedom, as in Egger et al. (1997). A normal approximation
    would give p-values that are far too small with few trials.

    ys : log-scale effect estimates; ses : their standard errors (same length).
    Returns the two-sided p-value for the intercept. Returns None when k < 4,
    or when the regression is degenerate (no spread in precision, or zero
    residual variance).
    """
    k = len(ys)
    # Egger's intercept test needs df = k-2 >= 2 to be even minimally
    # interpretable; with k=3 (df=1) the p-value is unreliable and
    # degenerate, so we require k>=4 before emitting a publication-bias signal.
    if k < 4:
        return None
    x = [1.0 / se for se in ses]            # precision
    z = [y / se for y, se in zip(ys, ses)]  # standardized effect
    n = float(k)
    sx = sum(x)
    sz = sum(z)
    sxx = sum(xi * xi for xi in x)
    sxz = sum(xi * zi for xi, zi in zip(x, z))
    denom = (n * sxx - sx * sx)
    if denom == 0:
        return None
    slope = (n * sxz - sx * sz) / denom
    intercept = (sz - slope * sx) / n
    # residual variance
    resid = [zi - (intercept + slope * xi) for xi, zi in zip(x, z)]
    dfres = k - 2
    if dfres <= 0:
        return None
    s2 = sum(r * r for r in resid) / dfres
    se_intercept = math.sqrt(s2 * (sxx / denom))
    if se_intercept == 0:
        return None
    return _t_two_sided_p(intercept / se_intercept, dfres)


# ----------------------------------------------------------------------
# Threshold-state derivation + GRADE
# ----------------------------------------------------------------------
def i2_band(i2):
    """Name the heterogeneity band for an I^2 value given as a fraction (0.3 == 30%)."""
    cutoffs = (
        (I2_LOW, "low"),
        (I2_MODERATE, "moderate"),
        (I2_SUBSTANTIAL, "substantial"),
    )
    for upper, label in cutoffs:
        if i2 < upper:
            return label
    return "considerable"


def ci_significant(ci):
    """Return True when a ratio-measure 95% CI lies strictly on one side of the null (1.0)."""
    low, high = ci
    entirely_below = high < NULL
    entirely_above = low > NULL
    return entirely_below or entirely_above


def pi_crosses_null(pi):
    """
    Report whether a 95% prediction interval for a ratio measure includes the null.

    pi: (low, high) tuple, or None when the interval was not estimable.
    Returns None for a missing interval; otherwise True when
    low <= NULL <= high. The bounds are inclusive so the result is the exact
    complement of ci_significant(), which treats an interval touching 1.0 as
    including the null.
    """
    if pi is None:
        return None  # not estimable
    lo, hi = pi
    return lo <= NULL <= hi


def grade_certainty(res):
    """
    Simplified GRADE for a body of RCT evidence.

    Starts at HIGH and downgrades one level each for:
      - inconsistency (I^2 >= I2_SUBSTANTIAL)
      - imprecision (pooled RE 95% CI includes the null)
      - publication bias (Egger p < EGGER_P_FLAG)
    The level never drops below 1 ("Very low").

    res: a meta_analyze() result dict.
    Returns (grade_label, level_int, reasons list). level_int is 4..1
    (High..Very low), or 0 with label "insufficient" when fewer than 2 trials
    have been pooled.
    """
    if res["k"] < 2:
        return ("insufficient", 0, ["single trial — cumulative pooling not yet meaningful"])
    level = 4  # 4=High,3=Moderate,2=Low,1=Very low
    reasons = []
    if res["I2"] >= I2_SUBSTANTIAL:
        level -= 1
        reasons.append("inconsistency (I^2 %.0f%%)" % (res["I2"] * 100))
    if not ci_significant(res["random_ci"]):
        level -= 1
        # Plain literal, never %-formatted: a single "%" is correct here
        # ("95%%" would be printed verbatim).
        reasons.append("imprecision (pooled 95% CI crosses null)")
    ep = res["egger_p"]
    if ep is not None and ep < EGGER_P_FLAG:
        level -= 1
        reasons.append("publication bias signal (Egger p=%.3f)" % ep)
    level = max(1, level)
    label = {4: "High", 3: "Moderate", 2: "Low", 1: "Very low"}[level]
    return (label, level, reasons)


def derive_state(res):
    """
    Reduce a meta_analyze() result to the discrete fields the alert engine diffs.

    Keys: significant, direction ('benefit' when the pooled RE ratio is below
    NULL, otherwise 'harm'), i2_band, pi_crosses_null, and the GRADE
    label/level/reasons produced by grade_certainty().
    """
    label, level, notes = grade_certainty(res)
    if res["random_effect"] < NULL:
        direction = "benefit"
    else:
        direction = "harm"
    state = {}
    state["significant"] = ci_significant(res["random_ci"])
    state["direction"] = direction
    state["i2_band"] = i2_band(res["I2"])
    state["pi_crosses_null"] = pi_crosses_null(res["pred_interval"])
    state["grade_label"] = label
    state["grade_level"] = level
    state["grade_reasons"] = notes
    return state


# ----------------------------------------------------------------------
# Transition detection (the "delta alert" engine)
# ----------------------------------------------------------------------
def detect_transitions(prev_state, prev_res, new_state, new_res):
    """
    Diff two consecutive threshold states and describe every boundary crossed.

    prev_state / new_state are dicts shaped like derive_state() output, and
    new_res is the meta_analyze() result behind new_state. prev_res is
    accepted but not read. Returns alert strings in a fixed order:
    significance, heterogeneity band, prediction interval, GRADE. An empty
    list means nothing crossed. Without a previous state (first trial in a
    cell) the call only establishes a baseline and returns [].
    """
    if prev_state is None:
        return []  # baseline only: nothing to compare against yet

    alerts = []

    # (1) pooled random-effects CI moved across the null
    if new_state["significant"] != prev_state["significant"]:
        ci = new_res["random_ci"]
        if new_state["significant"]:
            template = ("SIGNIFICANCE CROSS → pooled 95%% CI now EXCLUDES null "
                        "(RR/HR %.2f, CI %.2f-%.2f, %s)")
            values = (new_res["random_effect"], ci[0], ci[1], new_state["direction"])
        else:
            template = ("SIGNIFICANCE LOST → pooled 95%% CI now INCLUDES null "
                        "(RR/HR %.2f, CI %.2f-%.2f)")
            values = (new_res["random_effect"], ci[0], ci[1])
        alerts.append(template % values)

    # (2) I^2 moved into a different heterogeneity band
    old_band, new_band = prev_state["i2_band"], new_state["i2_band"]
    if new_band != old_band:
        alerts.append("HETEROGENEITY SHIFT → I^2 band %s → %s (I^2 %.0f%%, tau^2 %.4f)"
                      % (old_band, new_band, new_res["I2"] * 100, new_res["tau2"]))

    # (3) prediction interval started/stopped covering the null;
    #     skipped whenever either side is None (PI not estimable)
    old_pi, new_pi = prev_state["pi_crosses_null"], new_state["pi_crosses_null"]
    if old_pi is not None and new_pi is not None and old_pi != new_pi:
        bounds = new_res["pred_interval"]
        if new_pi:
            template = ("PREDICTION INTERVAL → now CROSSES null "
                        "(95%% PI %.2f-%.2f): future-trial effect uncertain")
        else:
            template = ("PREDICTION INTERVAL → now EXCLUDES null "
                        "(95%% PI %.2f-%.2f): effect consistent across settings")
        alerts.append(template % (bounds[0], bounds[1]))

    # (4) GRADE level changed; the level-0 "insufficient" state never alerts
    old_level, new_level = prev_state["grade_level"], new_state["grade_level"]
    if old_level != new_level and old_level > 0 and new_level > 0:
        arrow = "UP" if new_level > old_level else "DOWN"
        why = "; ".join(new_state["grade_reasons"]) or "fewer limitations"
        alerts.append("GRADE CERTAINTY %s → %s → %s (%s)"
                      % (arrow, prev_state["grade_label"], new_state["grade_label"], why))

    return alerts


# ----------------------------------------------------------------------
# Data loading + cell organization
# ----------------------------------------------------------------------
def load_trials(path):
    """Read the UTF-8 JSON file at *path* and return the list under its "trials" key."""
    with open(path, encoding="utf-8") as handle:
        payload = json.load(handle)
    trial_list = payload["trials"]
    return trial_list


def group_into_cells(trials):
    """
    Bucket trials by (drug_class, outcome) and order each bucket chronologically.

    Returns {(drug_class, outcome): [trials sorted by year, then trial_id]}.
    Buckets appear in first-seen order.
    """
    buckets = {}
    for trial in trials:
        cell = (trial["drug_class"], trial["outcome"])
        if cell not in buckets:
            buckets[cell] = []
        buckets[cell].append(trial)
    return {
        cell: sorted(members, key=lambda tr: (tr["year"], tr["trial_id"]))
        for cell, members in buckets.items()
    }


# ----------------------------------------------------------------------
# Output helpers
# ----------------------------------------------------------------------
def header():
    """Print the program banner (names, domain/category, disclaimer) between '=' rules."""
    rule = "=" * 68
    banner = (
        rule,
        " CumulMetaDelta-Kor / 큐뮬메타델타워치코어",
        " 누적 메타분석 기반 당뇨 약물군 RCT 임계전이 알림",
        " Domain: DM (당뇨)  |  Category: 연구 알림 (research-alert)",
        " " + DISCLAIMER,
        rule,
    )
    for line in banner:
        print(line)


def fmt_ci(ci):
    """Render a (low, high) interval as 'low–high' with 2 decimals, or 'n/a' for None."""
    if ci is not None:
        return "{:.2f}–{:.2f}".format(ci[0], ci[1])
    return "n/a"


def print_cell_forest(key, cell_trials):
    """
    Print the cumulative forest, final pooled summary and alert log for one cell.

    key: (drug_class, outcome) tuple.
    cell_trials: that cell's trials in ingestion order (group_into_cells()
    sorts them by year, then trial_id).
    Each forest row shows the pooled random-effects estimate after adding
    that trial. '*' marks rows whose pooled CI excludes the null.
    Returns the alert log as a list of (trial_dict, event_text) tuples.
    """
    dclass, outcome = key
    print("\n[CELL] %s × %s   (k=%d trials)" % (dclass, outcome, len(cell_trials)))
    print("-" * 68)
    if not cell_trials:
        # meta_analyze() cannot pool an empty cell; report instead of crashing.
        print(" (no trials in this cell)")
        return []
    print(" CUMULATIVE FOREST (each row = pooled estimate after adding that trial)")
    print(" %-3s %-6s %-30s %-16s %-16s" % ("yr", "id", "trial (source-traced)", "this-trial HR", "cumul RE [HKSJ CI]"))

    prev_state, prev_res = None, None
    alert_log = []
    for i, t in enumerate(cell_trials):
        res = meta_analyze(cell_trials[:i + 1])
        state = derive_state(res)
        this_hr = "%.2f (%.2f-%.2f)" % (t["effect"], t["ci_low"], t["ci_high"])
        cumul = "%.2f [%s]" % (res["random_effect"], fmt_ci(res["random_ci"]))
        sig_mark = "*" if state["significant"] else " "
        print(" %-3d %-6s %-30s %-16s %s%-15s" % (
            t["year"], t["trial_id"], t["trial_name"][:30], this_hr, sig_mark, cumul))

        for ev in detect_transitions(prev_state, prev_res, state, res):
            alert_log.append((t, ev))
        prev_state, prev_res = state, res

    # The last cumulative step already pooled every trial in the cell; reuse
    # it instead of re-running meta_analyze() on the same list.
    final, fstate = prev_res, prev_state
    print("-" * 68)
    print(" FINAL POOLED (random-effects, DL + HKSJ):")
    print("   RR/HR = %.3f   95%% CI %s   (%s)" % (
        final["random_effect"], fmt_ci(final["random_ci"]),
        "significant" if fstate["significant"] else "NS"))
    print("   I^2 = %.0f%% (%s)   tau^2 = %.4f   Q = %.2f" % (
        final["I2"] * 100, fstate["i2_band"], final["tau2"], final["Q"]))
    pi = final["pred_interval"]
    print("   95%% prediction interval = %s%s" % (
        fmt_ci(pi),
        "" if pi is None else ("  (crosses null)" if fstate["pi_crosses_null"] else "  (excludes null)")))
    ep = final["egger_p"]
    if ep is not None:
        egger_txt = "%.3f" % ep
    elif final["k"] < 4:
        # egger_test() returns None below 4 trials (needs residual df >= 2)
        egger_txt = "n/a (<4 trials)"
    else:
        # enough trials, but the regression had no spread or zero residuals
        egger_txt = "n/a (degenerate regression)"
    print("   Egger's p = %s   |   GRADE certainty = %s" % (egger_txt, fstate["grade_label"]))
    if fstate["grade_reasons"]:
        print("     ↳ GRADE notes: %s" % "; ".join(fstate["grade_reasons"]))

    # alert log
    print("-" * 68)
    if alert_log:
        print(" 🔔 TRANSITION ALERT LOG (delta events only):")
        for t, ev in alert_log:
            print("   • after %s (%s, %d): %s" % (t["trial_id"], t["trial_name"], t["year"], ev))
            print("       ↳ source-trace: %s | n=%s | effect=%.2f (%.2f-%.2f)" % (
                t["source"], t.get("n", "?"), t["effect"], t["ci_low"], t["ci_high"]))
    else:
        print(" (no threshold transition fired in this cell)")
    return alert_log


def run_demo(cells):
    """
    Default mode: replay each cell's trials in order and print only the alerts.

    Cells are visited in sorted (drug_class, outcome) order. Cells in which no
    transition fired are not printed. The output ends with the total alert
    count, a usage hint and the disclaimer.
    """
    header()
    print("\n[DEMO] Streaming synthetic trial sequence across all cells…")
    print("Emitting ONLY cumulative-meta threshold-transition alerts.\n")
    total_alerts = 0
    for dclass, outcome in sorted(cells):
        trials_in_cell = cells[(dclass, outcome)]
        fired = []
        last_state = last_res = None
        for n_seen, newest in enumerate(trials_in_cell, start=1):
            snapshot = meta_analyze(trials_in_cell[:n_seen])
            snap_state = derive_state(snapshot)
            fired.extend((newest, ev) for ev in
                         detect_transitions(last_state, last_res, snap_state, snapshot))
            last_state, last_res = snap_state, snapshot
        if not fired:
            continue
        print("■ %s × %s" % (dclass, outcome))
        for trial, ev in fired:
            total_alerts += 1
            print("   🔔 %d %s: %s" % (trial["year"], trial["trial_id"], ev))
            print("        ↳ trace: %s | %s | n=%s" % (
                trial["trial_name"], trial["source"], trial.get("n", "?")))
        print()
    print("-" * 68)
    print("Total transition alerts fired: %d" % total_alerts)
    print("Use:  python3 main.py --class SGLT2i --outcome HHF   for one cell's full forest.")
    print(DISCLAIMER)


def list_cells(cells):
    """Print the banner plus a table of every (drug_class, outcome) cell and its trial count."""
    header()
    print("\nAvailable cells (drug_class × outcome):")
    print("-" * 68)
    print(" %-22s %-18s %s" % ("drug_class", "outcome", "k (trials)"))
    for cell in sorted(cells):
        drug, outcome = cell[0], cell[1]
        print(" %-22s %-18s %d" % (drug, outcome, len(cells[cell])))
    print("\nQuery one cell:  python3 main.py --class <class> --outcome <outcome>")
    print(DISCLAIMER)


# ----------------------------------------------------------------------
# CLI
# ----------------------------------------------------------------------
def build_parser():
    """Build the argparse CLI: --data, --demo, --list, --class (dest drug_class), --outcome."""
    description = ("CumulMetaDelta-Kor (큐뮬메타델타워치코어): cumulative meta-analysis "
                   "of diabetes drug-class RCTs with threshold-transition alerts. "
                   "FULLY OFFLINE on synthetic data. " + DISCLAIMER)
    epilog = (
        "Examples:\n"
        "  python3 main.py                         # default: stream demo + transition alerts\n"
        "  python3 main.py --demo                  # same as default\n"
        "  python3 main.py --list                  # list all (class x outcome) cells\n"
        "  python3 main.py --class SGLT2i --outcome HHF   # one cell: full cumulative forest + alerts\n"
    )
    parser = argparse.ArgumentParser(
        prog="main.py",
        description=description,
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog=epilog,
    )
    option_specs = (
        ("--data", {"default": None,
                    "help": "path to trials JSON (default: data/trials.json next to main.py)"}),
        ("--demo", {"action": "store_true",
                    "help": "stream the synthetic trial sequence and print transition alerts (default mode)"}),
        ("--list", {"action": "store_true",
                    "help": "list all available (drug_class x outcome) cells"}),
        ("--class", {"dest": "drug_class", "default": None,
                     "help": "drug class for single-cell query (e.g. SGLT2i, GLP-1RA, finerenone)"}),
        ("--outcome", {"default": None,
                       "help": "outcome for single-cell query (e.g. HHF, MACE, renal_composite, mortality)"}),
    )
    for flag, options in option_specs:
        parser.add_argument(flag, **options)
    return parser


def main(argv=None):
    """
    CLI entry point.

    argv: argument list; None lets argparse read sys.argv[1:].
    Returns the process exit code: 0 on success, 2 on a usage error or when
    the trial data cannot be loaded (the message goes to stderr).
    """
    parser = build_parser()
    args = parser.parse_args(argv)

    data_path = args.data or os.path.join(os.path.dirname(os.path.abspath(__file__)),
                                          "data", "trials.json")
    # EAFP: open directly rather than os.path.exists() + open (no check/use
    # gap). Malformed input then ends with a clean exit code instead of a
    # traceback.
    try:
        trials = load_trials(data_path)
        cells = group_into_cells(trials)
    except FileNotFoundError:
        print("ERROR: data file not found: %s" % data_path, file=sys.stderr)
        return 2
    except (OSError, ValueError, KeyError, TypeError) as err:
        # unreadable path, invalid JSON (JSONDecodeError is a ValueError),
        # or trial records missing/mistyping required fields
        print("ERROR: could not load trials from %s: %s" % (data_path, err), file=sys.stderr)
        return 2

    if args.list:
        list_cells(cells)
        return 0

    if args.drug_class or args.outcome:
        if not (args.drug_class and args.outcome):
            print("ERROR: --class and --outcome must be given together.", file=sys.stderr)
            return 2
        key = (args.drug_class, args.outcome)
        if key not in cells:
            print("ERROR: no cell for %s × %s. Try --list." % key, file=sys.stderr)
            return 2
        header()
        print_cell_forest(key, cells[key])
        print("\n" + DISCLAIMER)
        return 0

    # default == demo
    run_demo(cells)
    return 0


if __name__ == "__main__":
    # main() returns the exit status (0 ok, 2 error); hand it to the interpreter.
    raise SystemExit(main())
