"""
funnel_logic.py — 순수 로직 모듈 (Streamlit/plotly 비의존)
================================================================
ObesityEnrollFunnelOps-Kor 의 분석 로직을 streamlit 컨텍스트 없이
호출/단위테스트 가능하도록 분리한 모듈.

본 도구는 연구·운영 보조용 참고 도구이며 실제 임상시험 규제 의사결정을 대체하지 않는다.
"""
from __future__ import annotations
import numpy as np
import pandas as pd
from scipy import stats

# Ordered funnel stage keys. Order matters: stage_pass_rates() pairs adjacent
# entries as transitions, and funnel_totals()/site_stage_matrix() emit stages in this order.
FUNNEL_STAGES = ["referral", "prescreen", "consent", "screen", "randomize", "retain"]
# Korean display label for each stage key in FUNNEL_STAGES.
STAGE_LABELS_KOR = {
    "referral": "의뢰", "prescreen": "사전선별", "consent": "동의",
    "screen": "스크리닝", "randomize": "무작위배정", "retain": "잔류",
}


# ---------------------------------------------------------------------------
# 기능 1: funnel 분해 + bottleneck 식별
# ---------------------------------------------------------------------------
def funnel_totals(funnel_df: pd.DataFrame) -> pd.DataFrame:
    """Return the summed ``count`` per stage, one row per stage in FUNNEL_STAGES order.

    Stages that never appear in ``funnel_df`` are reported with a count of 0.
    Returns columns [stage, count].
    """
    per_stage = funnel_df.groupby("stage")["count"].sum()
    records = []
    for stage_name in FUNNEL_STAGES:
        records.append({"stage": stage_name, "count": int(per_stage.get(stage_name, 0))})
    return pd.DataFrame(records)


def stage_pass_rates(stage_counts: pd.DataFrame) -> pd.DataFrame:
    """Compute the pass rate and drop-off for every adjacent pair of funnel stages.

    stage_counts: columns [stage, count], in FUNNEL_STAGES order (a stage that
    is missing counts as 0).
    Returns columns from_stage, to_stage, pass_rate, dropped. The pass rate is
    reported as 0.0 when the upstream stage has no one in it.
    """
    lookup = {}
    for _, record in stage_counts.iterrows():
        lookup[record["stage"]] = record["count"]

    transitions = []
    for upstream, downstream in zip(FUNNEL_STAGES, FUNNEL_STAGES[1:]):
        n_up = lookup.get(upstream, 0)
        n_down = lookup.get(downstream, 0)
        if n_up > 0:
            pass_rate = n_down / n_up
        else:
            pass_rate = 0.0
        transitions.append(dict(from_stage=upstream, to_stage=downstream,
                                pass_rate=pass_rate, dropped=int(n_up - n_down)))
    return pd.DataFrame(transitions)


def identify_bottleneck(pass_df: pd.DataFrame) -> dict:
    """Return the stage transition with the lowest pass rate (the worst bottleneck).

    ``pass_df`` is expected to have columns from_stage, to_stage, pass_rate,
    dropped. Returns an empty dict for an empty frame; on ties the row chosen
    by ``Series.idxmin`` (the first minimum) wins.
    """
    if not pass_df.empty:
        worst = pass_df.loc[pass_df["pass_rate"].idxmin()]
        return dict(
            from_stage=worst["from_stage"],
            to_stage=worst["to_stage"],
            pass_rate=float(worst["pass_rate"]),
            dropped=int(worst["dropped"]),
        )
    return {}


def site_stage_matrix(funnel_df: pd.DataFrame) -> pd.DataFrame:
    """Pivot funnel counts into a site_id x stage table of summed counts.

    Columns follow FUNNEL_STAGES; stages absent from the data are filled with 0.
    """
    matrix = pd.pivot_table(
        funnel_df,
        values="count",
        index="site_id",
        columns="stage",
        aggfunc="sum",
        fill_value=0,
    )
    return matrix.reindex(columns=FUNNEL_STAGES, fill_value=0)


# ---------------------------------------------------------------------------
# 기능 2: screen-fail 근본원인 taxonomy
# ---------------------------------------------------------------------------
def screen_fail_summary(sf_df: pd.DataFrame) -> pd.DataFrame:
    """Total screen failures per reason and each reason's share of all failures.

    Groups by (reason_code, reason_label, avoidable) and returns those columns
    plus ``count`` and ``pct``, sorted by count descending. ``pct`` is 0.0 when
    the grand total is not positive.
    """
    keys = ["reason_code", "reason_label", "avoidable"]
    summary = sf_df.groupby(keys)["count"].sum().reset_index()
    grand_total = summary["count"].sum()
    if grand_total > 0:
        summary["pct"] = summary["count"] / grand_total
    else:
        summary["pct"] = 0.0
    summary = summary.sort_values("count", ascending=False)
    return summary.reset_index(drop=True)


def avoidable_split(sf_df: pd.DataFrame) -> dict:
    """Split screen-fail totals into avoidable ("회피가능") and inevitable ("불가피").

    Rows whose ``avoidable`` value is neither label are ignored. Returns the two
    counts and the avoidable share of their sum (0.0 when the sum is not positive).
    """
    by_flag = sf_df.groupby("avoidable")["count"].sum()
    n_avoidable = int(by_flag.get("회피가능", 0))
    n_inevitable = int(by_flag.get("불가피", 0))
    denom = n_avoidable + n_inevitable
    share = 0.0 if denom <= 0 else n_avoidable / denom
    return {
        "avoidable": n_avoidable,
        "inevitable": n_inevitable,
        "avoidable_pct": share,
    }


def site_reason_deviation(sf_df: pd.DataFrame) -> pd.DataFrame:
    """Measure how far each site's screen-fail reason mix departs from the pooled mix.

    For every site, its reason_code distribution (share of that site's failures)
    is compared with the distribution pooled over all rows; ``deviation`` is the
    sum of absolute share differences. ``n_fail`` is the site's total count.

    Returns columns site_id, deviation, n_fail sorted by deviation descending.
    An input with no rows yields an empty frame with those same columns.
    Rows with a missing site_id are not assigned to any site.
    """
    columns = ["site_id", "deviation", "n_fail"]
    overall = sf_df.groupby("reason_code")["count"].sum()
    overall_total = overall.sum()
    overall_p = overall / overall_total if overall_total > 0 else overall

    rows = []
    # One grouping pass (first-appearance order, like unique()) instead of
    # re-filtering the whole frame for every site.
    for sid, site_rows in sf_df.groupby("site_id", sort=False):
        s = site_rows.groupby("reason_code")["count"].sum()
        site_total = s.sum()
        sp = s / site_total if site_total > 0 else s
        # Reasons the site never recorded contribute a 0 share.
        sp = sp.reindex(overall_p.index, fill_value=0.0)
        dev = float(np.abs(sp - overall_p).sum())
        rows.append({"site_id": sid, "deviation": dev, "n_fail": int(site_total)})

    if not rows:
        # pd.DataFrame([]) has no "deviation" column, so sorting it would raise KeyError.
        return pd.DataFrame(columns=columns)
    out = pd.DataFrame(rows, columns=columns)
    return out.sort_values("deviation", ascending=False).reset_index(drop=True)


# ---------------------------------------------------------------------------
# 기능 3: site별 등록곡선 Bayesian 재예측
# ---------------------------------------------------------------------------
def bayesian_rate_posterior(n_enrolled: int, weeks_elapsed: float,
                            prior_rate: float = 1.0, prior_weeks: float = 1.0):
    """Gamma posterior for a weekly Poisson enrollment rate (conjugate update).

    Prior: Gamma(shape=prior_rate * prior_weeks, rate=prior_weeks), i.e. the
    equivalent of ``prior_weeks`` weeks observed at ``prior_rate`` per week.
    Observing ``n_enrolled`` events over ``weeks_elapsed`` weeks gives
    Gamma(shape + n_enrolled, rate + weeks_elapsed).
    Prior shape/rate and weeks_elapsed are floored at 1e-6 and n_enrolled at 0
    so the distribution stays proper.

    Returns ``(alpha_post, beta_post, mean_rate, lo95, hi95)``; lo95/hi95 are
    the 2.5% and 97.5% posterior quantiles.
    """
    shape = max(1e-6, prior_rate * prior_weeks) + max(0, n_enrolled)
    rate = max(1e-6, prior_weeks) + max(1e-6, weeks_elapsed)
    # scipy parameterizes Gamma by scale, which is 1 / rate.
    lower, upper = stats.gamma.ppf([0.025, 0.975], a=shape, scale=1.0 / rate)
    return shape, rate, float(shape / rate), float(lower), float(upper)


def predict_completion(target_n: int, cum_randomized: int, weeks_elapsed: float,
                       prior_rate: float = 1.0, prior_weeks: float = 1.0):
    """Forecast the weeks left until ``target_n`` randomizations via a Bayesian rate update.

    The weekly-rate posterior comes from ``bayesian_rate_posterior`` using
    ``cum_randomized`` events over ``weeks_elapsed`` weeks. ``prior_weeks``
    (default 1.0, the value previously hard-coded) sets how much weight the
    prior carries.

    Returns a dict with:
      - rate_mean / rate_lo / rate_hi: posterior mean and 95% interval (per week)
      - remaining: randomizations still needed (never negative)
      - weeks_mean: remaining / rate_mean
      - weeks_optimistic: remaining / rate_hi (faster rate -> shorter)
      - weeks_pessimistic: remaining / rate_lo (slower rate -> longer)
    All week estimates are 0.0 once the target is reached; otherwise they are
    inf when the corresponding rate is effectively zero.
    """
    _, _, rate, lo, hi = bayesian_rate_posterior(cum_randomized, weeks_elapsed,
                                                 prior_rate=prior_rate,
                                                 prior_weeks=prior_weeks)
    remaining = max(0, target_n - cum_randomized)

    def weeks_at(r: float) -> float:
        if remaining == 0:
            # Target already met: no time is needed whatever the rate, even if
            # the posterior lower bound is ~0 (which used to yield inf).
            return 0.0
        return remaining / r if r > 1e-9 else float("inf")

    # A lower rate means a longer wait, so the rate bounds swap roles for weeks.
    return {
        "rate_mean": rate, "rate_lo": lo, "rate_hi": hi,
        "remaining": remaining,
        "weeks_mean": weeks_at(rate),
        "weeks_optimistic": weeks_at(hi),
        "weeks_pessimistic": weeks_at(lo),
    }


def sites_needed(remaining_target: int, mean_site_rate: float,
                 weeks_left: float) -> float:
    """Number of average-speed sites needed to enroll ``remaining_target`` within ``weeks_left``.

    Each site is assumed to enroll ``mean_site_rate`` participants per week.
    ``weeks_left`` is floored at 1e-9. Returns:
      - 0.0 when nothing remains (also when the rate is 0 or the remainder is negative);
      - inf when participants remain but per-site capacity is not positive;
      - otherwise the fractional site count (round up for a whole number).
    """
    if remaining_target <= 0:
        # Nothing left to enroll: zero sites regardless of rate, never a negative count.
        return 0.0
    capacity_per_site = mean_site_rate * max(1e-9, weeks_left)
    if capacity_per_site <= 0:
        return float("inf")
    return remaining_target / capacity_per_site


# ---------------------------------------------------------------------------
# 기능 4: retention / dropout 추적
# ---------------------------------------------------------------------------
def retention_curve(ret_df: pd.DataFrame) -> pd.DataFrame:
    """Overall retention per visit week, relative to the earliest visit week.

    n_present and n_enrolled are summed over all rows sharing a visit_week.
    The baseline is the summed n_present at the smallest visit_week (falling
    back to the largest summed n_enrolled if no such row exists);
    retention_rate = n_present / baseline, or 0.0 when the baseline is not positive.
    """
    per_week = (
        ret_df.groupby("visit_week")
        .agg(n_present=("n_present", "sum"), n_enrolled=("n_enrolled", "sum"))
        .reset_index()
    )
    first_week = per_week["visit_week"].min()
    baseline_rows = per_week.loc[per_week["visit_week"] == first_week, "n_present"]
    if len(baseline_rows) > 0:
        baseline = int(baseline_rows.iloc[0])
    else:
        baseline = per_week["n_enrolled"].max()
    if baseline > 0:
        per_week["retention_rate"] = per_week["n_present"] / baseline
    else:
        per_week["retention_rate"] = 0.0
    return per_week


def early_responder_bias(ret_df: pd.DataFrame) -> pd.DataFrame:
    """Per-visit weighted mean BMI of the present cohort (early-responder / retention-bias monitor).

    Shows how the retained group's mean BMI moves relative to baseline
    (survivorship bias). For each visit_week, mean_bmi_present is averaged
    with weights n_present (negatives clipped to 0, missing treated as 0, plus
    1e-9 so rows that all report 0 present still yield a plain mean).
    Rows with a missing BMI are left out of that visit's average; a visit with
    no BMI at all yields NaN.

    Returns columns visit_week, weighted_mean_bmi_present and
    bmi_shift_vs_baseline (difference from the earliest visit's value).
    """
    bmi = ret_df["mean_bmi_present"]
    has_bmi = bmi.notna()
    # Missing BMI used to be filled with 0, dragging the mean toward 0; give
    # those rows weight 0 instead so they drop out of both sums below.
    weights = (ret_df["n_present"].clip(lower=0).fillna(0) + 1e-9).where(has_bmi, 0.0)
    parts = pd.DataFrame({
        "visit_week": ret_df["visit_week"],
        "wsum": bmi.fillna(0.0) * weights,
        "w": weights,
    })
    sums = parts.groupby("visit_week")[["wsum", "w"]].sum()
    # 0 / 0 -> NaN for visits without any usable BMI value.
    weighted = sums["wsum"] / sums["w"]
    g = weighted.rename("weighted_mean_bmi_present").reset_index()
    base = g.loc[g["visit_week"] == g["visit_week"].min(), "weighted_mean_bmi_present"]
    base_bmi = float(base.iloc[0]) if len(base) else float("nan")
    g["bmi_shift_vs_baseline"] = g["weighted_mean_bmi_present"] - base_bmi
    return g


# ---------------------------------------------------------------------------
# 기능 5: 다양성 / representativeness
# ---------------------------------------------------------------------------
# Epidemiological reference distribution of the obese population (synthetic /
# literature-based assumptions; rough figures for Korean adults with obesity).
# Maps each demographic dimension -> {category: expected share}.
# representativeness() renormalizes each inner dict before use, so the shares
# need not sum to exactly 1. The "age_band" keys must match age_band()'s labels.
EPI_REFERENCE = {
    "sex": {"여성": 0.52, "남성": 0.48},
    "age_band": {"19-39": 0.30, "40-59": 0.45, "60+": 0.25},
    "race": {"동아시아": 0.93, "동남아시아": 0.05, "기타": 0.02},
}


def age_band(age: int) -> str:
    if age < 40:
        return "19-39"
    if age < 60:
        return "40-59"
    return "60+"


def representativeness(demo_df: pd.DataFrame) -> dict:
    """Compare the enrolled population's demographic mix with EPI_REFERENCE.

    For each dimension in EPI_REFERENCE, tabulate per category the enrolled
    share, the (renormalized) reference share and their absolute difference,
    and attach a chi-square goodness-of-fit p-value. The ``age_band`` column is
    derived from ``demo_df["age"]`` via ``age_band()``; the other dimensions
    are read from same-named columns.

    Only reference categories are counted: enrollees whose value is missing or
    not listed in the reference are excluded from ``n`` and from the test.

    Returns ``{dim: {"table": DataFrame[category, enrolled_pct, reference_pct,
    abs_diff], "chisq_p": float, "n": int}}``. ``chisq_p`` is NaN when n == 0
    or the test cannot be computed.
    """
    out = {}
    d = demo_df.copy()
    d["age_band"] = d["age"].apply(age_band)
    for dim, ref in EPI_REFERENCE.items():
        obs = d[dim].value_counts()
        cats = list(ref)
        obs_counts = np.array([int(obs.get(c, 0)) for c in cats], dtype=float)
        n = obs_counts.sum()
        ref_p = np.array([ref[c] for c in cats], dtype=float)
        # Renormalize so rounded reference tables still sum to exactly 1.
        ref_p = ref_p / ref_p.sum()
        exp_counts = ref_p * n
        enrolled_pct = obs_counts / n if n > 0 else obs_counts
        table = pd.DataFrame({
            "category": cats,
            "enrolled_pct": enrolled_pct,
            "reference_pct": ref_p,
        })
        table["abs_diff"] = (table["enrolled_pct"] - table["reference_pct"]).abs()
        pval = float("nan")
        # NOTE: the chi-square approximation is unreliable when expected counts
        # are small (< ~5), e.g. the 2% "기타" race category with few enrollees.
        if n > 0 and (exp_counts > 0).all():
            try:
                _, p = stats.chisquare(f_obs=obs_counts, f_exp=exp_counts)
                pval = float(p)
            except ValueError:
                # Observed/expected totals match by construction; scipy's input
                # validation errors are reported as "no p-value" rather than raised.
                pval = float("nan")
        out[dim] = {"table": table, "chisq_p": pval, "n": int(n)}
    return out
