"""
auc.py — 면적(AUC) 엔진
- trapezoidal 적분 (불균등 시간 간격 지원)
- total / incremental(기저 차감) / positive-incremental 선택
- 임의 시간창(window) 분리 (예: 0-120, 0-180)

모든 함수는 numpy만 사용 (scipy 불필요).
단위: 시간은 분(min), 농도는 임의 단위. 결과 단위 = 농도단위·min.
"""
from __future__ import annotations
import numpy as np


def _clip_window(times, values, t0, t1):
    """[t0, t1] 구간으로 자르되, 경계가 측정점 사이면 선형보간으로 경계점을 추가."""
    times = np.asarray(times, dtype=float)
    values = np.asarray(values, dtype=float)
    order = np.argsort(times)
    times, values = times[order], values[order]

    def interp_at(t):
        return float(np.interp(t, times, values))

    keep = (times >= t0) & (times <= t1)
    wt = list(times[keep])
    wv = list(values[keep])

    # 좌측 경계 보간 추가
    if t0 < times.min() or t0 not in wt:
        if times.min() <= t0 <= times.max():
            if t0 not in wt:
                wt.insert(0, t0)
                wv.insert(0, interp_at(t0))
    # 우측 경계 보간 추가
    if times.min() <= t1 <= times.max():
        if t1 not in wt:
            wt.append(t1)
            wv.append(interp_at(t1))

    pair = sorted(zip(wt, wv), key=lambda p: p[0])
    wt = np.array([p[0] for p in pair], dtype=float)
    wv = np.array([p[1] for p in pair], dtype=float)
    return wt, wv


def trapz_auc(times, values, t0=None, t1=None, mode="total", baseline=None):
    """
    사다리꼴 AUC.

    times    : 시점 리스트(분), 불균등 간격 허용
    values   : 동일 길이 농도 리스트
    t0, t1   : 적분 구간(미지정시 전체)
    mode     : 'total' | 'incremental' | 'positive_incremental'
               - total: 곡선 아래 전체 면적
               - incremental: (값 - baseline) 적분 (음/양 모두 합산, iAUC)
               - positive_incremental: max(값 - baseline, 0) 적분 (양의 증분만, pAUC)
    baseline : 기준선(미지정시 첫 시점값=공복값 사용)

    반환: float AUC. 유효 데이터 부족시 np.nan.
    """
    times = np.asarray(times, dtype=float)
    values = np.asarray(values, dtype=float)
    # NaN 제거
    ok = ~(np.isnan(times) | np.isnan(values))
    times, values = times[ok], values[ok]
    if times.size < 2:
        return float("nan")

    order = np.argsort(times)
    times, values = times[order], values[order]

    if t0 is None:
        t0 = float(times.min())
    if t1 is None:
        t1 = float(times.max())

    wt, wv = _clip_window(times, values, t0, t1)
    if wt.size < 2:
        return float("nan")

    if mode == "total":
        return float(np.trapezoid(wv, wt))

    if baseline is None:
        baseline = float(wv[0])

    inc = wv - baseline
    if mode == "incremental":
        return float(np.trapezoid(inc, wt))

    if mode == "positive_incremental":
        # 양의 증분만: 음수 구간을 0으로 클립하기 전에 영점 교차점을 보간 추가하여 정확도 향상
        xt, xv = _insert_zero_crossings(wt, inc)
        xv_pos = np.clip(xv, 0.0, None)
        return float(np.trapezoid(xv_pos, xt))

    raise ValueError(f"알 수 없는 mode: {mode}")


def _insert_zero_crossings(t, y):
    """y가 0을 교차하는 지점에 보간점을 삽입(positive iAUC 정확도용)."""
    t = np.asarray(t, dtype=float)
    y = np.asarray(y, dtype=float)
    nt, ny = [t[0]], [y[0]]
    for i in range(1, len(t)):
        y0, y1 = y[i - 1], y[i]
        if y0 == 0 or y1 == 0 or (y0 > 0) == (y1 > 0):
            nt.append(t[i]); ny.append(y[i])
        else:
            # 부호가 바뀜 -> 교차점 t* 보간
            frac = y0 / (y0 - y1)
            tc = t[i - 1] + frac * (t[i] - t[i - 1])
            nt.append(tc); ny.append(0.0)
            nt.append(t[i]); ny.append(y[i])
    return np.array(nt), np.array(ny)


def mean_over_curve(times, values, t0=None, t1=None):
    """구간 시간가중 평균 = total_AUC / (t1 - t0). Matsuda의 meanG/meanI 등에 사용."""
    times = np.asarray(times, dtype=float)
    values = np.asarray(values, dtype=float)
    ok = ~(np.isnan(times) | np.isnan(values))
    times, values = times[ok], values[ok]
    if times.size < 2:
        # 단순 산술평균 fallback
        return float(np.nanmean(values)) if values.size else float("nan")
    if t0 is None:
        t0 = float(times.min())
    if t1 is None:
        t1 = float(times.max())
    a = trapz_auc(times, values, t0, t1, mode="total")
    span = t1 - t0
    if span <= 0:
        return float("nan")
    return a / span


def all_auc_panel(times, values, windows):
    """
    지정된 window 목록에 대해 total/incremental/positive_incremental 3종을 모두 산출.
    반환: {(t0,t1): {'total':.., 'incremental':.., 'positive_incremental':..}}
    """
    out = {}
    for (t0, t1) in windows:
        out[(t0, t1)] = {
            "total": trapz_auc(times, values, t0, t1, "total"),
            "incremental": trapz_auc(times, values, t0, t1, "incremental"),
            "positive_incremental": trapz_auc(times, values, t0, t1, "positive_incremental"),
        }
    return out
