"""
notifications.py — 시점·알림 일정 + push timing 휴리스틱.

기능:
- generate_schedule(protocol, duration_weeks, start_date) → 방문/입력 시점 list
- suggest_push_time(chronotype, last_input_hours) → 시간대 제안
- score_compliance_window(entries, weeks_back, today) → 직전 N주 입력률
- dropout_risk(compliance_4w, phq9_baseline) → 탈락 위험 점수 (0–1) 및 등급
"""
from __future__ import annotations
from datetime import date, datetime, timedelta


# Scheduling protocols accepted by generate_schedule(), keyed by protocol id.
# Each entry carries a human-readable "name"; "weeks" / "milestone_weeks" list
# the milestone week offsets (ascending) from the start date.
# NOTE(review): "interval_days" is not read by generate_schedule() in this
# module (it steps in whole weeks) — confirm whether other callers use it.
PROTOCOLS = dict(
    weekly=dict(
        name="매주 ePRO",
        interval_days=7,
    ),
    milestone=dict(
        name="Milestone 방문 (baseline, W4, W12, W24, W52)",
        weeks=[0, 4, 12, 24, 52],
    ),
    **{
        # Key contains "+", so it cannot be spelled as a keyword argument.
        "weekly+milestone": dict(
            name="매주 ePRO + milestone 강화",
            interval_days=7,
            milestone_weeks=[0, 4, 12, 24, 52],
        ),
    },
)


def generate_schedule(protocol: str = "weekly",
                      duration_weeks: int = 52,
                      start_date: date | None = None) -> list:
    if start_date is None:
        start_date = date.today()
    out = []
    if protocol == "weekly":
        for w in range(duration_weeks + 1):
            d = start_date + timedelta(weeks=w)
            out.append({"week": w, "date": d.isoformat(), "type": "weekly_epro"})
    elif protocol == "milestone":
        for w in PROTOCOLS["milestone"]["weeks"]:
            if w > duration_weeks:
                break
            d = start_date + timedelta(weeks=w)
            out.append({"week": w, "date": d.isoformat(), "type": "milestone_visit"})
    elif protocol == "weekly+milestone":
        for w in range(duration_weeks + 1):
            d = start_date + timedelta(weeks=w)
            entry = {"week": w, "date": d.isoformat(), "type": "weekly_epro"}
            if w in PROTOCOLS["weekly+milestone"]["milestone_weeks"]:
                entry["type"] = "weekly_epro+milestone"
            out.append(entry)
    else:
        raise ValueError(f"unknown protocol: {protocol}")
    return out


# Default push-time candidates per chronotype: three "HH:MM" strings,
# earliest to latest. suggest_push_time() falls back to "neutral" for unknown
# chronotypes and picks the middle candidate when there is no input history.
CHRONOTYPE_DEFAULT_HOURS = dict(
    lark=("07:30", "12:00", "18:00"),
    neutral=("09:00", "13:00", "20:00"),
    owl=("11:00", "15:00", "21:30"),
)


def suggest_push_time(chronotype: str = "neutral",
                      last_input_hours: list[int] | None = None) -> dict:
    """Suggest a primary push-notification time plus backup times.

    Simple heuristic:
    - take the three default "HH:MM" candidates for ``chronotype``
      (unknown chronotypes fall back to "neutral");
    - if recent input hours are given, rank the candidates by distance from
      the mean of those hours and recommend the closest one first;
    - otherwise recommend the middle candidate (around lunch time).

    Args:
        chronotype: key into CHRONOTYPE_DEFAULT_HOURS; echoed back unchanged
            in the result, even when the "neutral" fallback was used.
        last_input_hours: hours of day of the patient's recent inputs
            (presumably 0-23 — not validated here).

    Returns:
        dict with "chronotype", "primary_push" ("HH:MM"), "backup_pushes"
        (list of the remaining "HH:MM" candidates) and "based_on_history".

    NOTE(review): the mean is a plain arithmetic mean on a 24 h clock, so
    inputs straddling midnight (e.g. 23 and 1) average to midday.
    """
    candidates = CHRONOTYPE_DEFAULT_HOURS.get(chronotype,
                                              CHRONOTYPE_DEFAULT_HOURS["neutral"])
    if last_input_hours:
        avg = sum(last_input_hours) / len(last_input_hours)

        def _clock_hours(hhmm: str) -> float:
            # Keep the minutes: truncating "07:30" to 7 mis-ranked candidates
            # whose distances differ by less than the dropped half hour.
            hh, _, mm = hhmm.partition(":")
            return int(hh) + int(mm or 0) / 60

        # sorted() is stable, so ties keep the chronotype's default order.
        ranked = sorted(candidates, key=lambda c: abs(_clock_hours(c) - avg))
        primary = ranked[0]
        backup = ranked[1:]
    else:
        primary = candidates[1]  # middle candidate, around lunch time
        backup = [c for c in candidates if c != primary]
    return {
        "chronotype": chronotype,
        "primary_push": primary,
        "backup_pushes": backup,
        "based_on_history": bool(last_input_hours),
    }


def score_compliance_window(entries: list, weeks_back: int = 4,
                            today: date | None = None) -> float:
    """직전 weeks_back 주 동안 weekly entry 입력률 (0–1)."""
    if today is None:
        today = date.today()
    cutoff = today - timedelta(weeks=weeks_back)
    expected = weeks_back
    actual = 0
    seen_weeks = set()
    for e in entries:
        try:
            d = datetime.fromisoformat(e["date"]).date()
        except Exception:
            continue
        if cutoff <= d <= today:
            wk = (d - cutoff).days // 7
            if wk not in seen_weeks:
                seen_weeks.add(wk)
                actual += 1
    return min(1.0, actual / max(expected, 1))


def dropout_risk(compliance_4w: float, phq9_baseline: int | None = None) -> dict:
    """단순 risk score (0–1).
    - 컴플라이언스 낮을수록 risk↑
    - PHQ-9 baseline 높을수록 risk↑ (≥10 = depression)
    """
    base = 1.0 - compliance_4w
    phq_w = 0.0
    if phq9_baseline is not None:
        if phq9_baseline >= 15:
            phq_w = 0.3
        elif phq9_baseline >= 10:
            phq_w = 0.2
        elif phq9_baseline >= 5:
            phq_w = 0.1
    risk = min(1.0, 0.7 * base + phq_w + (0.0 if compliance_4w >= 0.75 else 0.05))
    if risk >= 0.6:
        level = "HIGH"
    elif risk >= 0.35:
        level = "MEDIUM"
    else:
        level = "LOW"
    return {"score": round(risk, 3), "level": level,
            "compliance_4w": round(compliance_4w, 3),
            "phq9_baseline": phq9_baseline}


if __name__ == "__main__":
    # Manual smoke run: sample schedule, push-time suggestion, dropout risk.
    import json

    demo_schedule = generate_schedule("weekly+milestone", 12)
    first_five = demo_schedule[:5]
    print(json.dumps(first_five, ensure_ascii=False, indent=2))

    push_plan = suggest_push_time("owl", last_input_hours=[10, 11, 14, 12])
    print(push_plan)

    risk_report = dropout_risk(0.5, phq9_baseline=12)
    print(risk_report)
