"""
synthetic_generator.py
======================
RodentCGMTel 합성 telemetry 데이터 생성기. seed=42 고정 → 재현 가능.

- 5 model + 1 control = 6 CSV 생성
- 각 4-6마리, 7일 × 5분 interval = 2016 samples/animal
- 모델별 baseline + nocturnal feeder 야간 BG 상승 + dawn phenomenon + noise
- 표준 schema: timestamp, animal_id, model, bg_mgdl, group, source_format

사용:
    python3 synthetic_generator.py [--out_dir synthetic]
"""

from __future__ import annotations

import argparse
import os
import sys
from datetime import datetime, timedelta

import numpy as np
import pandas as pd

# Support running this file directly as a script: put the parent directory on
# sys.path so the sibling module species_reference can be imported.
_HERE = os.path.dirname(os.path.abspath(__file__))  # directory containing this file
_PARENT = os.path.dirname(_HERE)
if _PARENT not in sys.path:
    # Prepended (index 0), so it is searched before the other sys.path entries.
    sys.path.insert(0, _PARENT)
from species_reference import SPECIES_REFERENCE  # noqa: E402


# Master RNG seed; main() derives every cohort and animal seed from it, so the
# generated CSVs are reproducible.
SEED = 42

# Sampling grid for every animal: one reading every INTERVAL_MIN minutes for DAYS days.
INTERVAL_MIN = 5
DAYS = 7

# Cohorts to generate, one CSV per entry:
#   (model_key, group_label, n_animals, animal_prefix)
# model_key must be a key of SPECIES_REFERENCE; animal IDs are "<prefix>-NN".
COHORTS = [
    ("C57BL/6J", "control", 6, "WT"),                     # control group
    ("db/db (BKS.Cg-Dock7m+/+ Leprdb)", "T2DM_db", 5, "DB"),
    ("ob/ob (B6.Cg-Lepob)", "T2DM_ob", 5, "OB"),
    ("STZ T1DM mouse", "T1DM_stz", 5, "STZ"),
    ("DIO HFD (C57BL/6J 60% HFD)", "DIO_HFD", 6, "HFD"),
    ("ZDF (Zucker Diabetic Fatty)", "T2DM_zdf", 4, "ZDF"),
]


def _diurnal_offset(hour_decimal: float, species: str) -> float:
    """야행성 rodent: 19-07h(다크) 활동기 BG 상승, 0~+25 mg/dL. dawn 4-8h 추가 +5~+15."""
    # 12:12 cycle, lights on 07:00
    if hour_decimal >= 19 or hour_decimal < 7:
        dark = 18.0
    else:
        dark = 0.0
    # dawn surge 4-8h
    if 4 <= hour_decimal < 8:
        dawn = 10.0 * np.sin(np.pi * (hour_decimal - 4) / 4)
    else:
        dawn = 0.0
    # rat은 mouse보다 진폭 약함
    scale = 0.7 if species == "rat" else 1.0
    return scale * (dark + dawn)


def _meal_bumps(n_samples: int, interval_min: int, rng: np.random.Generator,
                amplitude_range: tuple = (15, 35)) -> np.ndarray:
    """야간 식이 bout — 다크 phase 중 3-5번 식이당 +amplitude, 30-60min decay."""
    bumps = np.zeros(n_samples)
    samples_per_day = int(24 * 60 / interval_min)
    n_days = max(1, n_samples // samples_per_day)
    for day in range(n_days):
        n_bouts = rng.integers(3, 6)
        # 다크 phase: 19h-31h(=다음날 07h) → sample index
        dark_start = int((19 * 60) / interval_min) + day * samples_per_day
        dark_end = int((31 * 60) / interval_min) + day * samples_per_day
        dark_end = min(dark_end, n_samples - 1)
        if dark_end <= dark_start:
            continue
        bout_idx = rng.integers(dark_start, dark_end, size=n_bouts)
        for idx in bout_idx:
            amp = rng.uniform(*amplitude_range)
            decay_samples = int(rng.integers(6, 13))  # 30-65min
            end_idx = min(idx + decay_samples, n_samples)
            decay = np.exp(-np.arange(end_idx - idx) / (decay_samples / 2))
            bumps[idx:end_idx] += amp * decay
    return bumps


def _generate_animal_series(model_key: str, animal_id: str, group: str,
                             start_time: datetime, days: int, interval_min: int,
                             rng: np.random.Generator) -> pd.DataFrame:
    ref = SPECIES_REFERENCE[model_key]
    species = ref["species"]
    # baseline = random_bg 평균
    rb_low, rb_high = ref["random_bg"]
    baseline = rng.uniform(rb_low, rb_high)
    # CV%
    cv_low, cv_high = ref["expected_cv_pct"]
    target_cv = rng.uniform(cv_low, cv_high) / 100.0
    sd = baseline * target_cv * 0.6  # noise 일부는 식이/diurnal로 추가

    n_samples = int(days * 24 * 60 / interval_min)
    times = [start_time + timedelta(minutes=i * interval_min) for i in range(n_samples)]
    hours = np.array([t.hour + t.minute / 60.0 for t in times])

    # OU-like AR(1) noise (자기상관)
    noise = np.zeros(n_samples)
    phi = 0.85
    eps = rng.normal(0.0, sd * np.sqrt(1 - phi ** 2), n_samples)
    noise[0] = rng.normal(0.0, sd)
    for i in range(1, n_samples):
        noise[i] = phi * noise[i - 1] + eps[i]

    diurnal = np.array([_diurnal_offset(h, species) for h in hours])

    # T1DM/T2DM 강한 모델은 meal bump 진폭 큼
    if ref["strain_type"] in ("T1DM_chemical", "T2DM_obese", "polygenic_T2DM"):
        bump_amp = (25, 60)
    elif ref["strain_type"] == "diet_induced_obese":
        bump_amp = (20, 40)
    elif ref["strain_type"] == "non_obese_T2DM":
        bump_amp = (15, 35)
    else:
        bump_amp = (10, 25)
    bumps = _meal_bumps(n_samples, interval_min, rng, bump_amp)

    bg = baseline + noise + diurnal + bumps
    # physiological floor/ceiling
    bg = np.clip(bg, ref["severe_hypo"] * 0.6, ref["severe_hyper"] * 1.4)

    df = pd.DataFrame({
        "timestamp": times,
        "animal_id": animal_id,
        "model": model_key,
        "bg_mgdl": np.round(bg, 1),
        "group": group,
        "source_format": "synthetic",
    })
    return df


def generate_cohort(model_key: str, group: str, n_animals: int, prefix: str,
                    rng: np.random.Generator,
                    days: int = DAYS, interval_min: int = INTERVAL_MIN,
                    start_time: datetime | None = None) -> pd.DataFrame:
    """Simulate one cohort and return all of its animals stacked in long format.

    Args:
        model_key: key into SPECIES_REFERENCE for the simulated model.
        group: group label written to every row.
        n_animals: number of animals; IDs are ``f"{prefix}-01"``, ``-02``, ...
        prefix: animal-ID prefix.
        rng: cohort-level generator, used only to seed one generator per animal.
        days: recording length per animal.
        interval_min: sampling interval in minutes.
        start_time: timestamp of every animal's first sample. Defaults to
            2026-05-01 12:00:00.

    Returns:
        Concatenated frame with a fresh RangeIndex. When ``n_animals == 0`` it
        is an empty frame with the standard columns.

    Raises:
        ValueError: if ``n_animals`` is negative.
    """
    if n_animals < 0:
        raise ValueError(f"n_animals must be >= 0, got {n_animals}")
    if start_time is None:
        start_time = datetime(2026, 5, 1, 12, 0, 0)
    if n_animals == 0:
        # pd.concat([]) raises, so return the schema explicitly.
        return pd.DataFrame(columns=["timestamp", "animal_id", "model",
                                     "bg_mgdl", "group", "source_format"])
    parts = []
    for i in range(n_animals):
        aid = f"{prefix}-{i + 1:02d}"
        # One independent stream per animal: the cohort rng is consumed exactly
        # once per animal, so one animal's draws never shift another's data.
        sub_rng = np.random.default_rng(rng.integers(0, 2**31 - 1))
        parts.append(_generate_animal_series(model_key, aid, group, start_time,
                                             days, interval_min, sub_rng))
    return pd.concat(parts, ignore_index=True)


def main(out_dir: str) -> None:
    """Generate every cohort listed in COHORTS and write one CSV per cohort.

    Each cohort gets its own generator, seeded from a master generator
    initialised with SEED. The file is ``<out_dir>/cohort_<group>.csv``, with
    "/" in the group label replaced by "_". A line is printed per file, and a
    summary line at the end.

    Args:
        out_dir: destination directory; created if missing.
    """
    os.makedirs(out_dir, exist_ok=True)
    seeder = np.random.default_rng(SEED)
    row_counts: list[tuple[str, int]] = []
    for cohort_key, cohort_group, n_animals, id_prefix in COHORTS:
        # Draw this cohort's seed before generating, to keep the stream order fixed.
        child_seed = seeder.integers(0, 2**31 - 1)
        frame = generate_cohort(cohort_key, cohort_group, n_animals, id_prefix,
                                np.random.default_rng(child_seed))
        csv_path = os.path.join(
            out_dir, "cohort_" + cohort_group.replace("/", "_") + ".csv"
        )
        frame.to_csv(csv_path, index=False)
        n_rows = len(frame)
        row_counts.append((csv_path, n_rows))
        print(f"[+] {csv_path}  rows={n_rows}  animals={n_animals}")
    grand_total = sum(count for _, count in row_counts)
    print(f"\nTotal cohorts: {len(row_counts)}  total_rows={grand_total}")


if __name__ == "__main__":
    # CLI entry point. --help shows the module docstring, with its layout kept.
    parser = argparse.ArgumentParser(
        description=__doc__,
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )
    parser.add_argument(
        "--out_dir",
        default=os.path.join(_HERE, "synthetic"),
        help="directory that receives the cohort_<group>.csv files "
             "(default: %(default)s)",
    )
    args = parser.parse_args()
    main(args.out_dir)
