"""
test_logic.py — funnel_logic 순수 로직 단위 테스트 (streamlit/plotly 비의존)
실행: python3 test_logic.py
"""
import os
import math
import numpy as np
import pandas as pd

import funnel_logic as fl

# Absolute directory of this test file; CSV fixtures are read from its data/ subfolder.
HERE = os.path.dirname(os.path.abspath(__file__))
DATA = os.path.join(HERE, "data")
# Running tallies updated by check() and reported by the __main__ runner.
PASSED, FAILED = 0, 0


def check(name, cond):
    """Record one named assertion and print its status line.

    A truthy ``cond`` increments the module-level PASSED counter; a falsy one
    increments FAILED.
    """
    global PASSED, FAILED
    if not cond:
        FAILED += 1
        print(f"  FAIL - {name}")
        return
    PASSED += 1
    print(f"  ok   - {name}")


def test_funnel():
    """Check funnel totals, per-stage pass rates, the bottleneck and the site matrix."""
    funnel_df = pd.read_csv(os.path.join(DATA, "enrollment_funnel.csv"))
    totals = fl.funnel_totals(funnel_df)
    stage_counts = dict(zip(totals["stage"], totals["count"]))

    # The funnel must never widen: each stage is <= the stage before it.
    ordered = [stage_counts[stage] for stage in fl.FUNNEL_STAGES]
    check(
        "funnel monotonically non-increasing",
        all(prev >= nxt for prev, nxt in zip(ordered, ordered[1:])),
    )

    rates = fl.stage_pass_rates(totals)
    check("pass rates in [0,1]", rates["pass_rate"].between(0, 1).all())

    worst = fl.identify_bottleneck(rates)
    check(
        "bottleneck has lowest pass_rate",
        math.isclose(worst["pass_rate"], rates["pass_rate"].min(), rel_tol=1e-9),
    )

    matrix = fl.site_stage_matrix(funnel_df)
    check("site matrix has all stages", list(matrix.columns) == fl.FUNNEL_STAGES)


def test_screen_fail():
    """Check the screen-fail reason summary, avoidable split and per-site deviation."""
    reasons = pd.read_csv(os.path.join(DATA, "screen_fail_reasons.csv"))

    summary = fl.screen_fail_summary(reasons)
    check("screen-fail pct sums to ~1", math.isclose(summary["pct"].sum(), 1.0, abs_tol=1e-6))
    counts = summary["count"].values
    check("screen-fail sorted desc", (counts[:-1] >= counts[1:]).all())

    split = fl.avoidable_split(reasons)
    grand_total = int(reasons["count"].sum())
    check("avoidable + inevitable = total", split["avoidable"] + split["inevitable"] == grand_total)
    check("avoidable_pct in [0,1]", 0 <= split["avoidable_pct"] <= 1)

    deviation = fl.site_reason_deviation(reasons)
    check("deviation non-negative", (deviation["deviation"] >= 0).all())


def test_bayesian():
    """Check the enrollment-rate posterior, completion forecast and sites-needed math."""
    # Deterministic case: prior of 1/wk held for 1 week, then 10 enrolled in
    # 10 weeks, so the posterior mean should be about 1.0/wk.
    _alpha, _beta, post_mean, lower, upper = fl.bayesian_rate_posterior(
        10, 10, prior_rate=1.0, prior_weeks=1.0
    )
    check("posterior mean between lo and hi", lower <= post_mean <= upper)
    check("posterior mean ~ (1+10)/(1+10)=1.0", math.isclose(post_mean, 11 / 11, rel_tol=1e-9))

    # With weeks held fixed, more enrollments must raise the posterior rate.
    _, _, rate_few, _, _ = fl.bayesian_rate_posterior(5, 10)
    _, _, rate_many, _, _ = fl.bayesian_rate_posterior(20, 10)
    check("more enrolled -> higher rate", rate_many > rate_few)

    forecast = fl.predict_completion(target_n=50, cum_randomized=20, weeks_elapsed=10)
    check("remaining computed", forecast["remaining"] == 30)
    check(
        "optimistic <= mean <= pessimistic weeks",
        forecast["weeks_optimistic"] <= forecast["weeks_mean"] <= forecast["weeks_pessimistic"],
    )

    # Target already exceeded: nothing remains and no further weeks are needed.
    done = fl.predict_completion(target_n=20, cum_randomized=25, weeks_elapsed=10)
    check("already met -> remaining 0", done["remaining"] == 0)
    check("already met -> weeks_mean 0", done["weeks_mean"] == 0)

    needed = fl.sites_needed(remaining_target=60, mean_site_rate=2.0, weeks_left=10)
    check("sites_needed = 60/(2*10)=3", math.isclose(needed, 3.0, rel_tol=1e-9))


def test_retention():
    """Check the retention curve shape and the early-responder BMI-shift baseline."""
    visits = pd.read_csv(os.path.join(DATA, "retention_visits.csv"))

    curve = fl.retention_curve(visits)
    check("retention starts at 1.0", math.isclose(curve["retention_rate"].iloc[0], 1.0, abs_tol=1e-6))
    rates = curve["retention_rate"].values
    # Allow 1e-9 of float noise when comparing consecutive visits.
    check("retention non-increasing", (rates[1:] <= rates[:-1] + 1e-9).all())

    bias = fl.early_responder_bias(visits)
    check(
        "bias baseline shift = 0",
        math.isclose(bias.iloc[0]["bmi_shift_vs_baseline"], 0.0, abs_tol=1e-6),
    )


def test_representativeness():
    """Check enrolled vs. reference demographic mixes and the age-band classifier."""
    demo = pd.read_csv(os.path.join(DATA, "demographics.csv"))
    report = fl.representativeness(demo)
    for dim in ("sex", "age_band", "race"):
        table = report[dim]["table"]
        # Both the enrolled mix and the reference mix should be proportions summing to 1.
        for col in ("enrolled_pct", "reference_pct"):
            check(f"{dim} {col} sums ~1", math.isclose(table[col].sum(), 1.0, abs_tol=1e-6))

    expected_bands = {25: "19-39", 50: "40-59", 70: "60+"}
    check(
        "age_band classifier",
        all(fl.age_band(age) == band for age, band in expected_bands.items()),
    )


if __name__ == "__main__":
    import traceback

    print("== ObesityEnrollFunnelOps-Kor 단위 테스트 ==")
    # Run every suite even if an earlier one crashes (e.g. a missing CSV
    # fixture), so the summary line and exit code always reflect the full run
    # instead of aborting at the first exception.
    for suite in (
        test_funnel,
        test_screen_fail,
        test_bayesian,
        test_retention,
        test_representativeness,
    ):
        try:
            suite()
        except Exception:
            # Top-level test boundary: record the crash as a failure and show
            # the traceback rather than hiding it.
            FAILED += 1
            print(f"  FAIL - {suite.__name__} raised an exception")
            traceback.print_exc()
    print(f"\n결과: {PASSED} passed, {FAILED} failed")
    # Non-zero exit status signals failure to shell / CI callers.
    raise SystemExit(1 if FAILED else 0)
