#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
서로게이트회귀워치코어 / SurrogacyMetaReg-Kor
=============================================
MASLD/MASH 대리지표(surrogate) x 임상결과(hard outcome) 쌍별
trial-level surrogacy meta-regression 워치 도구 (OFFLINE MVP).

각 쌍에 대해:
  - 가중최소제곱(weighted least squares) trial-level meta-regression
  - trial-level R^2
  - 기울기(slope) + 표준오차(SE) + 95% CI
  - surrogate threshold effect (STE)
  - prediction interval
  - leave-one-trial-out (LOO) 민감도
신규 trial-level 데이터가 들어올 때 신뢰도(reliability)가
'강화/약화'되는 전이(transition)만 알림으로 emit 한다.

순수 stdlib(math)만으로 구현. numpy 가 있으면 써도 되지만 필수가 아니다.
참고용·연구용 (research/reference only). 대리지표 타당성은 불확실하며
규제/임상 판단에 사용하지 말 것.
"""

import argparse
import json
import math
import os
import sys

# numpy 는 선택적. 없어도 stdlib fallback 으로 동작.
try:
    import numpy as _np  # noqa: F401
    HAVE_NUMPY = True
except Exception:
    HAVE_NUMPY = False

DATA_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "data")
PAIRS_FILE = os.path.join(DATA_DIR, "pairs.json")

DISCLAIMER = (
    "참고용·연구용 (research/reference only). 대리지표(surrogate) 타당성은 본질적으로 불확실. "
    "규제/임상 결정에 사용 금지. (do NOT use for regulatory/clinical decisions)"
)

# 신뢰도 판정에 쓰는 임계값들
R2_THRESHOLD_DEFAULT = 0.70   # trial-level R^2 신뢰 임계
MIN_TRIALS = 4                # 이보다 적으면 '너무 sparse'
SPARSE_TRIALS = 5             # 이하면 과신 경고 (one point matters a lot)

# t 분포 양측 95% 임계값 (df: 자유도). df<=0 은 신뢰구간 산출 불가.
_T95 = {
    1: 12.706, 2: 4.303, 3: 3.182, 4: 2.776, 5: 2.571, 6: 2.447,
    7: 2.365, 8: 2.306, 9: 2.262, 10: 2.228, 11: 2.201, 12: 2.179,
    13: 2.160, 14: 2.145, 15: 2.131, 20: 2.086, 30: 2.042,
}


def t_crit_95(df):
    """양측 95% t 임계값. 표에 없으면 인접/근사값으로 보간."""
    if df <= 0:
        return float("nan")
    if df in _T95:
        return _T95[df]
    if df > 30:
        return 1.96 + (2.042 - 1.96) * (30.0 / df)  # 30 -> inf 근사
    keys = sorted(_T95.keys())
    lo = max(k for k in keys if k <= df)
    hi = min(k for k in keys if k >= df)
    if lo == hi:
        return _T95[lo]
    frac = (df - lo) / (hi - lo)
    return _T95[lo] + (_T95[hi] - _T95[lo]) * frac


# --------------------------------------------------------------------------
# Weighted least squares meta-regression  (y = a + b*x, 가중치 w)
# --------------------------------------------------------------------------
def weighted_meta_regression(points, ste_target_upper=0.0):
    """
    points: list of dict {x, y, w}
    반환 dict 키:
      n, slope, intercept, slope_se, slope_ci, r2,
      pred_interval_half (대표 prediction interval 반폭, x=mean 기준),
      ste (surrogate threshold effect), residual_sigma
    데이터가 부족하면 가능한 필드만 채우고 나머지는 None.
    """
    n = len(points)
    out = {
        "n": n, "slope": None, "intercept": None, "slope_se": None,
        "slope_ci": None, "r2": None, "ste": None,
        "pred_interval_half": None, "residual_sigma": None,
        "xbar": None, "sumw": None,
    }
    if n == 0:
        return out

    xs = [p["x"] for p in points]
    ys = [p["y"] for p in points]
    ws = [max(p.get("w", 1.0), 1e-9) for p in points]

    sumw = sum(ws)
    xbar = sum(w * x for w, x in zip(ws, xs)) / sumw
    ybar = sum(w * y for w, y in zip(ws, ys)) / sumw
    out["xbar"] = xbar
    out["sumw"] = sumw

    sxx = sum(w * (x - xbar) ** 2 for w, x in zip(ws, xs))
    sxy = sum(w * (x - xbar) * (y - ybar) for w, x, y in zip(ws, xs, ys))
    syy = sum(w * (y - ybar) ** 2 for w, y in zip(ws, ys))

    if sxx <= 1e-12:
        # x 변동이 없으면 기울기 추정 불가 (수직선/단일 x)
        out["intercept"] = ybar
        return out

    slope = sxy / sxx
    intercept = ybar - slope * xbar
    out["slope"] = slope
    out["intercept"] = intercept

    # 가중 R^2
    ss_res = sum(w * (y - (intercept + slope * x)) ** 2
                 for w, x, y in zip(ws, xs, ys))
    r2 = 1.0 - (ss_res / syy) if syy > 1e-12 else None
    if r2 is not None:
        r2 = max(0.0, min(1.0, r2))
    out["r2"] = r2

    df = n - 2
    if df >= 1:
        # 가중 잔차분산 -> 기울기 SE
        sigma2 = ss_res / df            # 가중 평균 잔차분산 (단위: w 가중)
        slope_var = sigma2 / sxx
        slope_se = math.sqrt(max(slope_var, 0.0))
        out["slope_se"] = slope_se
        out["residual_sigma"] = math.sqrt(max(sigma2, 0.0))
        tc = t_crit_95(df)
        out["slope_ci"] = (slope - tc * slope_se, slope + tc * slope_se)

        # prediction interval (x = xbar 에서의 반폭): 새 trial 예측 불확실성
        # se_pred = sqrt(sigma2 * (1 + 1/sumw_eff + (x-xbar)^2/sxx)), x=xbar
        se_pred = math.sqrt(max(sigma2 * (1.0 + 1.0 / sumw), 0.0))
        out["pred_interval_half"] = tc * se_pred

        # STE: 예측 상한(benefit 경계 = 0)이 ste_target_upper 를 만족하는 최소 x.
        # 모델 benefit 은 y<0. 예측선 upper bound 가 0 을 가로지르는 x 를 푼다.
        out["ste"] = _solve_ste(slope, intercept, xbar, sxx, sigma2,
                                sumw, tc, target=ste_target_upper)
    return out


def _solve_ste(slope, intercept, xbar, sxx, sigma2, sumw, tc, target=0.0):
    """
    surrogate threshold effect:
      yhat(x) + tc*se_pred(x) = target (=0) 를 만족하는 최소 x.
    benefit 방향(slope<0)에서만 의미. 격자 스캔 후 이분법으로 미세화.
    """
    if slope >= -1e-9:
        return None  # 대리지표 증가가 benefit 으로 연결되지 않음

    def upper(x):
        yhat = intercept + slope * x
        se = math.sqrt(max(sigma2 * (1.0 + 1.0 / sumw + (x - xbar) ** 2 / sxx), 0.0))
        return yhat + tc * se

    lo, hi = -1.0, 3.0
    prev_x = lo
    prev_v = upper(lo)
    step = 0.01
    x = lo + step
    while x <= hi:
        v = upper(x)
        if (prev_v - target) * (v - target) <= 0:  # 부호변화 = 교차
            a, b = prev_x, x
            for _ in range(40):
                m = 0.5 * (a + b)
                if (upper(a) - target) * (upper(m) - target) <= 0:
                    b = m
                else:
                    a = m
            return 0.5 * (a + b)
        prev_x, prev_v = x, v
        x += step
    return None


# --------------------------------------------------------------------------
# Reliability 상태 판정
# --------------------------------------------------------------------------
def assess(reg, r2_threshold, ste_range):
    """회귀 결과 -> 신뢰도 상태 dict."""
    n = reg["n"]
    too_sparse = n < MIN_TRIALS or reg["r2"] is None
    r2_ok = (reg["r2"] is not None) and (reg["r2"] >= r2_threshold)
    slope_ci = reg["slope_ci"]
    slope_excludes_zero = bool(
        slope_ci and (slope_ci[0] < 0 and slope_ci[1] < 0)
        or slope_ci and (slope_ci[0] > 0 and slope_ci[1] > 0)
    )
    ste = reg["ste"]
    ste_in_range = bool(
        ste is not None and ste_range and ste_range[0] <= ste <= ste_range[1]
    )
    sparse_warn = n <= SPARSE_TRIALS
    return {
        "too_sparse": too_sparse,
        "r2_ok": r2_ok,
        "slope_excludes_zero": slope_excludes_zero,
        "ste_in_range": ste_in_range,
        "sparse_warn": sparse_warn,
    }


def loo_sensitivity(points, r2_threshold, ste_range):
    """leave-one-trial-out: 각 trial 을 빼고 R^2/slope/STE 가 어떻게 변하나."""
    rows = []
    if len(points) < 3:
        return rows
    for i in range(len(points)):
        subset = points[:i] + points[i + 1:]
        reg = weighted_meta_regression(subset)
        rows.append({
            "dropped": points[i].get("trial", f"#{i}"),
            "r2": reg["r2"],
            "slope": reg["slope"],
            "ste": reg["ste"],
            "n": reg["n"],
        })
    return rows


# --------------------------------------------------------------------------
# 데이터 로드
# --------------------------------------------------------------------------
def load_pairs(path=PAIRS_FILE):
    with open(path, "r", encoding="utf-8") as f:
        data = json.load(f)
    return data["pairs"]


def find_pair(pairs, pid):
    for p in pairs:
        if p["id"].lower() == pid.lower():
            return p
    # 부분 일치 허용
    for p in pairs:
        if pid.lower() in p["id"].lower():
            return p
    return None


# --------------------------------------------------------------------------
# 포맷 헬퍼
# --------------------------------------------------------------------------
def fmt(v, nd=3):
    if v is None:
        return "  n/a"
    if isinstance(v, float) and (math.isnan(v) or math.isinf(v)):
        return "  n/a"
    return f"{v:.{nd}f}"


def fmt_ci(ci, nd=3):
    if not ci:
        return "n/a"
    return f"[{ci[0]:.{nd}f}, {ci[1]:.{nd}f}]"


def header():
    line = "=" * 74
    print(line)
    print("  서로게이트회귀워치코어 / SurrogacyMetaReg-Kor")
    print("  MASLD/MASH trial-level surrogacy meta-regression watcher (OFFLINE MVP)")
    print("  도메인: MASLD(대사성간질환)   카테고리: 연구 알림(research-alert)")
    print(line)
    print("  [!] " + DISCLAIMER)
    print("  [!] 과신 경고(OVERCONFIDENCE): 데이터가 sparse/시간지연(time-lagged)일 때")
    print("      단일 trial 이 결과를 크게 흔든다. R^2 가 높아도 LOO/sparsity 를 반드시 확인.")
    print(line)
    if HAVE_NUMPY:
        print("  (numpy 감지됨 — 단, 계산은 stdlib 경로로 동일하게 수행)")
    print()


# --------------------------------------------------------------------------
# 한 쌍 상세 출력
# --------------------------------------------------------------------------
def print_pair_detail(pair):
    pid = pair["id"]
    r2_thr = pair.get("r2_threshold", R2_THRESHOLD_DEFAULT)
    ste_range = pair.get("ste_achievable_range")
    pts = pair["trials"]

    reg = weighted_meta_regression(pts)
    st = assess(reg, r2_thr, ste_range)

    print("-" * 74)
    print(f"PAIR: {pid}")
    print(f"  surrogate    : {pair['surrogate']}")
    print(f"  hard outcome : {pair['hard_outcome']}")
    print(f"  trials (n)   : {reg['n']}   |  R^2 threshold = {r2_thr}")
    print("-" * 74)
    print("  현재 meta-regression (가중최소제곱 / weighted LS):")
    print(f"    slope (b)            : {fmt(reg['slope'])}   "
          f"(b<0 = 대리지표 호전 -> 임상결과 benefit)")
    print(f"    slope 95% CI         : {fmt_ci(reg['slope_ci'])}")
    print(f"    intercept (a)        : {fmt(reg['intercept'])}")
    print(f"    trial-level R^2      : {fmt(reg['r2'])}   "
          f"(>= {r2_thr} 면 신뢰 가능 후보)")
    print(f"    STE (surrogate thr.) : {fmt(reg['ste'])}   "
          f"achievable range = {ste_range}")
    print(f"    prediction interval  : +/- {fmt(reg['pred_interval_half'])} "
          f"(x=mean 기준 새 trial 예측 반폭)")

    print()
    print("  신뢰도 판정:")
    print(f"    R^2 임계 통과         : {'YES' if st['r2_ok'] else 'no'}")
    print(f"    slope CI 가 0 제외     : {'YES' if st['slope_excludes_zero'] else 'no'}")
    print(f"    STE 가 임상달성범위 내 : {'YES' if st['ste_in_range'] else 'no'}")
    if st["too_sparse"]:
        print(f"    >>> [TOO SPARSE] n={reg['n']} < {MIN_TRIALS}: "
              f"신뢰 판정 불가 — LOW CONFIDENCE.")
    if st["sparse_warn"]:
        print(f"    >>> [SPARSITY WARNING] n={reg['n']} <= {SPARSE_TRIALS}: "
              f"funnel 평가 불가, 단일 trial 영향 큼. 과신 금지.")

    # LOO 민감도
    print()
    print("  leave-one-trial-out (LOO) 민감도 (데이터가 sparse -> one point matters):")
    loo = loo_sensitivity(pts, r2_thr, ste_range)
    if not loo:
        print("    (trial 수가 너무 적어 LOO 산출 불가)")
    else:
        base_r2 = reg["r2"]
        print(f"    {'dropped trial':<22}{'R^2':>8}{'dR^2':>9}{'slope':>9}{'STE':>9}")
        max_swing = 0.0
        flip = []
        for row in loo:
            dr2 = (row["r2"] - base_r2) if (row["r2"] is not None and base_r2 is not None) else None
            if dr2 is not None:
                max_swing = max(max_swing, abs(dr2))
            # 임계 통과 여부가 뒤집히는지
            if row["r2"] is not None and base_r2 is not None:
                if (base_r2 >= r2_thr) != (row["r2"] >= r2_thr):
                    flip.append(row["dropped"])
            print(f"    {row['dropped']:<22}{fmt(row['r2']):>8}"
                  f"{('  '+fmt(dr2)) if dr2 is not None else '     n/a':>9}"
                  f"{fmt(row['slope']):>9}{fmt(row['ste']):>9}")
        print(f"    최대 R^2 변동(|dR^2|) : {fmt(max_swing)}")
        if flip:
            print(f"    >>> [FRAGILE] 다음 trial 제거 시 R^2 임계 통과 여부가 뒤집힘: "
                  f"{', '.join(flip)}")
        else:
            print("    R^2 임계 통과 여부는 단일 trial 제거에 강건(robust)함.")
    print()


# --------------------------------------------------------------------------
# 스트리밍 데모: trial 을 하나씩 누적하며 reliability 전이 알림만 emit
# --------------------------------------------------------------------------
def stream_pair_alerts(pair):
    pid = pair["id"]
    r2_thr = pair.get("r2_threshold", R2_THRESHOLD_DEFAULT)
    ste_range = pair.get("ste_achievable_range")
    pts = pair["trials"]

    print("-" * 74)
    print(f"STREAM: {pid}  ({pair['surrogate']} -> {pair['hard_outcome']})")

    prev = None
    alerts = 0
    acc = []
    for i, p in enumerate(pts, start=1):
        acc.append(p)
        reg = weighted_meta_regression(acc)
        st = assess(reg, r2_thr, ste_range)
        cur = (st["r2_ok"], st["slope_excludes_zero"], st["ste_in_range"])

        tag = p.get("trial", f"#{i}")
        if prev is None:
            # 초기 상태 보고 (전이는 아님)
            sparse = " [SPARSE-LOWCONF]" if st["too_sparse"] else ""
            print(f"  + {tag:<20} n={reg['n']}  R^2={fmt(reg['r2'])}  "
                  f"slope={fmt(reg['slope'])}{sparse}  (baseline)")
        else:
            events = []
            if cur[0] != prev[0]:
                events.append(
                    f"R^2 {'>=' if cur[0] else '<'} {r2_thr} 로 "
                    f"{'STRENGTHEN(신뢰강화)' if cur[0] else 'WEAKEN(신뢰약화)'} "
                    f"(R^2 {fmt(reg['r2'])})")
            if cur[1] != prev[1]:
                events.append(
                    f"slope 95% CI 가 0 을 "
                    f"{'제외(EXCLUDES 0 -> 유의)' if cur[1] else '포함(INCLUDES 0 -> 불확실)'}")
            if cur[2] != prev[2]:
                events.append(
                    f"STE 가 임상달성범위에 "
                    f"{'진입(ENTERS achievable)' if cur[2] else '이탈(LEAVES achievable)'} "
                    f"(STE {fmt(reg['ste'])})")
            if events:
                alerts += 1
                print(f"  * [ALERT] {tag} 추가 후 reliability 전이:")
                for e in events:
                    print(f"        - {e}")
                if st["too_sparse"] or st["sparse_warn"]:
                    print(f"        (주의: n={reg['n']} sparse — 이 전이는 단일 "
                          f"trial 에 취약, 과신 금지)")
            else:
                # 전이 없음: 조용히 진행 (한 줄 요약만)
                print(f"  + {tag:<20} n={reg['n']}  R^2={fmt(reg['r2'])}  "
                      f"slope={fmt(reg['slope'])}  (no transition)")
        prev = cur

    if alerts == 0:
        print("  (reliability 전이 없음 — 상태 변화 알림 없음)")
    # 최종 sparsity 경고
    final = weighted_meta_regression(acc)
    if final["n"] <= SPARSE_TRIALS or final["r2"] is None:
        print(f"  >>> [SPARSITY WARNING] 최종 n={final['n']}: "
              f"too sparse — LOW CONFIDENCE. funnel/over-confidence 위험.")
    print()


# --------------------------------------------------------------------------
# 레지스트리(매트릭스) 출력
# --------------------------------------------------------------------------
def print_registry(pairs):
    print("REGISTRY (surrogate x hard-outcome 매트릭스):")
    print(f"  {'pair id':<28}{'n':>3}{'R^2':>8}{'thr':>6}  status")
    for pair in pairs:
        r2_thr = pair.get("r2_threshold", R2_THRESHOLD_DEFAULT)
        reg = weighted_meta_regression(pair["trials"])
        st = assess(reg, r2_thr, pair.get("ste_achievable_range"))
        if st["too_sparse"]:
            status = "TOO SPARSE / low-confidence"
        elif st["r2_ok"] and st["slope_excludes_zero"]:
            status = "reliable-candidate"
        elif st["r2_ok"]:
            status = "R^2 ok / slope uncertain"
        else:
            status = "weak surrogate"
        print(f"  {pair['id']:<28}{reg['n']:>3}{fmt(reg['r2']):>8}"
              f"{r2_thr:>6}  {status}")
    print()


# --------------------------------------------------------------------------
# CLI
# --------------------------------------------------------------------------
def build_parser():
    p = argparse.ArgumentParser(
        prog="main.py",
        description=(
            "서로게이트회귀워치코어 / SurrogacyMetaReg-Kor — MASLD/MASH 대리지표 x "
            "임상결과 쌍의 trial-level surrogacy meta-regression 을 유지하고, 신규 "
            "trial 데이터가 대리지표 신뢰도를 강화/약화시키면 알림한다. "
            "참고용·연구용 (research/reference only)."
        ),
        epilog=(
            "예시:\n"
            "  python3 main.py                         # 데모: 전 쌍 스트리밍 + 전이 알림\n"
            "  python3 main.py --demo                  # 위와 동일\n"
            "  python3 main.py --registry              # 쌍 매트릭스 요약\n"
            "  python3 main.py --pair \"PDFF/cirrhosis\"  # 한 쌍 상세 + 알림로그 + LOO\n"
            "  python3 main.py --list                  # 사용 가능한 쌍 id 목록\n"
        ),
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )
    p.add_argument("--pair", metavar="ID",
                   help="한 쌍의 현재 meta-regression + 알림로그 + LOO 민감도 출력 "
                        "(예: \"PDFF/cirrhosis\")")
    p.add_argument("--demo", action="store_true",
                   help="합성 trial-level 시퀀스를 스트리밍하며 reliability 전이 알림 출력 (기본 모드)")
    p.add_argument("--registry", action="store_true",
                   help="surrogate x hard-outcome 레지스트리(매트릭스) 요약")
    p.add_argument("--list", action="store_true",
                   help="사용 가능한 쌍 id 목록만 출력")
    p.add_argument("--data", metavar="PATH", default=PAIRS_FILE,
                   help=f"pairs.json 경로 (기본: {PAIRS_FILE})")
    return p


def main(argv=None):
    parser = build_parser()
    args = parser.parse_args(argv)

    try:
        pairs = load_pairs(args.data)
    except FileNotFoundError:
        print(f"[ERROR] 데이터 파일을 찾을 수 없음: {args.data}", file=sys.stderr)
        return 2
    except (json.JSONDecodeError, KeyError) as e:
        print(f"[ERROR] 데이터 파싱 실패: {e}", file=sys.stderr)
        return 2

    header()

    if args.list:
        print("사용 가능한 쌍 id:")
        for p in pairs:
            print(f"  - {p['id']}  ({p['surrogate']} -> {p['hard_outcome']})")
        return 0

    if args.registry:
        print_registry(pairs)
        return 0

    if args.pair:
        pair = find_pair(pairs, args.pair)
        if pair is None:
            print(f"[ERROR] '{args.pair}' 와 일치하는 쌍 없음. --list 로 확인.",
                  file=sys.stderr)
            return 2
        print_pair_detail(pair)
        print("[알림 로그] 동일 쌍에 대한 스트리밍 전이 재생:")
        stream_pair_alerts(pair)
        return 0

    # 기본 = demo
    print("DEMO 모드: 전체 쌍에 대해 trial-level 시퀀스를 누적 스트리밍하며")
    print("           reliability 전이(강화/약화)만 알림으로 emit.\n")
    print_registry(pairs)
    for pair in pairs:
        stream_pair_alerts(pair)
    print("=" * 74)
    print("끝. 모든 수치는 합성 데이터 기반. " + DISCLAIMER)
    print("=" * 74)
    return 0


if __name__ == "__main__":
    sys.exit(main())
