"""Ingestion + de-identification module.

병원 EDW/REDCap export(CSV) -> 표준화된 in-memory dataclass 리스트 (Patient / POCTReading / InsulinOrder / Episode).
- 환자 attribute, ward, LOS
- POCT BG timestamp + value
- insulin order (basal / bolus / correction / sliding-scale / IV infusion)
- oral 당뇨약 / steroid / TPN / PN / perioperative status
- DKA / HHS dx
- discharge insulin order
- 30-day readmission

de-identification audit trail:
  - 환자 ID → hashed surrogate (SHA-256 truncated)
  - 입원/퇴원일 → study day index (0,1,2,...) (date-shift)
  - free-text 필드는 keep only coded enum
  - audit log: row 수, drop 수, shift offset 기록 (IngestReport 로 반환)
"""
from __future__ import annotations

import csv
import hashlib
import math
import os
from dataclasses import dataclass, field
from datetime import date, datetime, timedelta
from typing import Any, Dict, List, Optional, Tuple

# --------------------------------------------------------------------------- #
# Constants
# --------------------------------------------------------------------------- #

# Ward codes. MICU / SICU / CCU / NICU are treated as ICUs by the generator.
WARDS = [
    "MICU", "SICU", "CCU", "NICU",
    "ER",
    "GW-IM", "GW-SG", "GW-OG",
]

# Diabetes / hyperglycaemia classification codes.
DM_TYPES = [
    "T1DM",
    "T2DM",
    "GDM",
    "Steroid-DM",
    "Stress-Hyper",
    "No-DM",
]

# Insulin regimen codes (also used for discharge / transition regimens).
INSULIN_REGIMENS = [
    "basal-bolus",
    "sliding-scale-only",
    "IV-infusion",
    "basal-only",
    "no-insulin",
]

# Glycaemic episode codes; the string "None" means "no episode".
EPISODE_TYPES = [
    "DKA",
    "HHS",
    "Hypoglycemia",
    "Hyperglycemia-Persistent",
    "None",
]


# --------------------------------------------------------------------------- #
# Data classes
# --------------------------------------------------------------------------- #

@dataclass
class Patient:
    """Admission-level patient record after de-identification.

    Day fields are study-day indices, not calendar dates. Text fields carry
    coded values (see the module constants) and are not validated here.
    """

    # --- identity / demographics ---
    patient_id: str  # hashed surrogate ID, never the raw hospital ID
    age: int
    sex: str  # "M" or "F"
    dm_type: str
    ward: str

    # --- stay timeline (date-shifted study-day indices) ---
    admit_day: int
    discharge_day: int
    los_days: int

    # --- exposures during the stay ---
    perioperative: bool
    steroid_exposure: bool
    tpn_pn: bool

    # --- discharge and follow-up ---
    discharge_insulin: str  # expected to be one of INSULIN_REGIMENS
    readmit_30d: bool
    readmit_day: Optional[int] = field(default=None)
    # e.g. DKA / HHS / Hypoglycemia / Hyperglycemia / Other
    readmit_reason: Optional[str] = field(default=None)


@dataclass
class POCTReading:
    """A single point-of-care blood-glucose measurement."""

    # Hashed surrogate patient ID.
    patient_id: str
    # Date-shifted study-day index of the sample.
    study_day: int
    # Hour of day at which the sample was taken.
    hour: int
    # Glucose value in mg/dL.
    glucose_mg_dl: float


@dataclass
class InsulinOrder:
    """One day's insulin order for one patient.

    Dose fields that do not apply to the regimen hold 0.0 (both the CSV
    loader and the synthetic generator default them to 0.0).
    """

    # Hashed surrogate patient ID.
    patient_id: str
    # Date-shifted study-day index the order applies to.
    study_day: int
    # Expected to be one of INSULIN_REGIMENS.
    regimen: str
    # Scheduled / correction doses (insulin units, per the field names).
    basal_units: float
    bolus_units: float
    correction_units: float
    # IV infusion rate; the name suggests units/hour.
    iv_rate_uhr: float


@dataclass
class Episode:
    """An acute glycaemic episode and its management summary."""

    # Hashed surrogate patient ID.
    patient_id: str
    # DKA / HHS / Hypoglycemia / Hyperglycemia-Persistent
    episode_type: str
    # Date-shifted study-day index on which the episode began.
    start_day: int
    duration_h: float

    # --- presentation labs ---
    anion_gap_init: float
    bicarb_init: float
    glucose_init: float

    # --- management ---
    k_supplement: bool
    fluid_l_24h: float
    iv_insulin_rate_uhr: float

    # --- outcome ---
    resolved: bool
    # Hours to resolution; None when unresolved / not recorded.
    resolution_h: Optional[float]
    # Regimen after resolution.
    transition_regimen: str


@dataclass
class IngestReport:
    """Audit summary of one ingestion or generation run."""

    # --- row counts per table ---
    n_patients: int
    n_poct: int
    n_orders: int
    n_episodes: int

    # --- de-identification description ---
    date_shift_days: int
    deid_method: str

    # --- optional bookkeeping ---
    dropped_rows: int = field(default=0)
    notes: List[str] = field(default_factory=list)  # fresh list per instance


# --------------------------------------------------------------------------- #
# De-identification helpers
# --------------------------------------------------------------------------- #

def hash_patient_id(raw_id: str, salt: str = "INHOSP-GLY-WARD-KOR") -> str:
    """Map a raw hospital ID to a surrogate of the form ``"P" + 10 hex chars``.

    The surrogate is the first 10 upper-cased hex digits of
    SHA-256(``salt + "|" + str(raw_id)``). Non-string IDs are stringified
    first, so ``123`` and ``"123"`` map to the same surrogate.

    NOTE(review): irreversibility rests entirely on the salt, and the default
    salt is a literal in this source file. For a small or structured ID space
    anyone holding the code can brute-force the mapping. Pass a secret salt
    for real patient data.
    """
    payload = "|".join((salt, str(raw_id))).encode("utf-8")
    digest = hashlib.sha256(payload).hexdigest().upper()
    return "P" + digest[:10]


def shift_date_to_day(d: Any, anchor: datetime, offset_days: int,
                      default: Optional[int] = 0) -> Optional[int]:
    """Convert a ``yyyy-mm-dd`` date to a shifted study-day index.

    The result is the number of calendar days from *anchor* to *d*, plus
    *offset_days*. Only the calendar date of *anchor* is used. An anchor
    carrying a time of day (e.g. an admission timestamp) therefore does not
    push same-day dates to -1, as a raw ``datetime`` subtraction would.

    Args:
        d: Date string in ``%Y-%m-%d`` form; surrounding whitespace is
            ignored. ``date`` / ``datetime`` objects are accepted as-is.
        anchor: Reference point (study day 0 before shifting). May be a
            ``datetime`` or a plain ``date``.
        offset_days: Constant shift added to every index.
        default: Returned when *d* is missing or unparseable. It stays 0 for
            backward compatibility. Pass ``None`` to tell bad dates apart
            from a genuine day-0 value.

    Returns:
        The shifted study-day index, or *default* if *d* cannot be parsed.
    """
    if isinstance(d, datetime):  # check before ``date``: datetime subclasses date
        day = d.date()
    elif isinstance(d, date):
        day = d
    else:
        try:
            day = datetime.strptime(d.strip(), "%Y-%m-%d").date()
        except (AttributeError, TypeError, ValueError):
            # None / non-string / malformed date string
            return default
    anchor_day = anchor.date() if isinstance(anchor, datetime) else anchor
    return (day - anchor_day).days + offset_days


# --------------------------------------------------------------------------- #
# CSV loaders
# --------------------------------------------------------------------------- #

def _read_csv(path: str) -> List[Dict[str, str]]:
    with open(path, "r", encoding="utf-8") as f:
        return list(csv.DictReader(f))


def _to_int(v: Any, default: int = 0) -> int:
    try:
        return int(float(v))
    except (TypeError, ValueError):
        return default


def _to_float(v: Any, default: float = 0.0) -> float:
    try:
        return float(v)
    except (TypeError, ValueError):
        return default


def _to_bool(v: Any) -> bool:
    if isinstance(v, bool):
        return v
    return str(v).strip().lower() in ("1", "true", "y", "yes", "t")


def load_patients(path: str) -> List[Patient]:
    """Load ``patients.csv`` into :class:`Patient` records.

    ``patient_id`` is copied verbatim. It is expected to already be a hashed
    surrogate; this loader does not hash it.

    Text/enum columns (``sex``, ``dm_type``, ``ward``, ``discharge_insulin``,
    ``readmit_reason``) are whitespace-stripped. They fall back to their
    default when the column is absent, the cell is blank, or the row is
    short. ``csv.DictReader`` yields ``""`` for blank cells and ``None`` for
    missing trailing fields, and ``dict.get(col, default)`` only substitutes
    the default for an absent key, so it is not enough on its own. Numeric
    columns default to 0 and flag columns to False.

    Args:
        path: Path to the patients CSV.

    Returns:
        One :class:`Patient` per data row, in file order.

    Raises:
        KeyError: if the file has no ``patient_id`` column.
    """
    def _text(row: Dict[str, str], col: str, default: Optional[str]) -> Optional[str]:
        # Blank, whitespace-only, or missing cell -> ``default``.
        value = (row.get(col) or "").strip()
        return value or default

    out: List[Patient] = []
    for r in _read_csv(path):
        readmit_day_raw = _text(r, "readmit_day", None)
        out.append(Patient(
            patient_id=r["patient_id"],
            age=_to_int(r.get("age")),
            sex=_text(r, "sex", "M"),
            dm_type=_text(r, "dm_type", "T2DM"),
            ward=_text(r, "ward", "GW-IM"),
            admit_day=_to_int(r.get("admit_day")),
            discharge_day=_to_int(r.get("discharge_day")),
            los_days=_to_int(r.get("los_days")),
            perioperative=_to_bool(r.get("perioperative")),
            steroid_exposure=_to_bool(r.get("steroid_exposure")),
            tpn_pn=_to_bool(r.get("tpn_pn")),
            discharge_insulin=_text(r, "discharge_insulin", "no-insulin"),
            readmit_30d=_to_bool(r.get("readmit_30d")),
            # A blank readmit_day means "no readmission", not study day 0.
            readmit_day=(_to_int(readmit_day_raw) if readmit_day_raw is not None else None),
            readmit_reason=_text(r, "readmit_reason", None),
        ))
    return out


def load_poct(path: str) -> List[POCTReading]:
    """Load ``poct_bg.csv`` into :class:`POCTReading` records.

    Rows whose ``glucose_mg_dl`` cell is blank, non-numeric (e.g. meter
    ``LO``/``HI`` flags) or non-finite (``nan``/``inf``) are skipped. They
    are not coerced to 0.0, because a fabricated 0 mg/dL reading would be
    counted as severe hypoglycaemia by any downstream glycaemic metric.
    Skipped rows are not counted by this function. ``study_day`` and
    ``hour`` default to 0 when blank.

    Args:
        path: Path to the POCT CSV.

    Returns:
        One :class:`POCTReading` per row with a usable glucose value, in
        file order.

    Raises:
        KeyError: if the file has no ``patient_id`` column.
    """
    out: List[POCTReading] = []
    for r in _read_csv(path):
        # NaN sentinel marks "missing / unparseable" so it can be told apart
        # from a real numeric value.
        glucose = _to_float(r.get("glucose_mg_dl"), float("nan"))
        if not math.isfinite(glucose):
            continue
        out.append(POCTReading(
            patient_id=r["patient_id"],
            study_day=_to_int(r.get("study_day")),
            hour=_to_int(r.get("hour")),
            glucose_mg_dl=glucose,
        ))
    return out


def load_orders(path: str) -> List[InsulinOrder]:
    """Load ``insulin_orders.csv`` into :class:`InsulinOrder` records.

    Dose and rate columns default to 0.0 when blank or unparseable.
    ``regimen`` is whitespace-stripped and falls back to ``"no-insulin"``
    when the column is absent, the cell is blank, or the row is short.
    ``r.get("regimen", "no-insulin")`` alone would yield ``""``/``None`` in
    the latter two cases, and neither is one of ``INSULIN_REGIMENS``.

    Args:
        path: Path to the insulin orders CSV.

    Returns:
        One :class:`InsulinOrder` per data row, in file order.

    Raises:
        KeyError: if the file has no ``patient_id`` column.
    """
    out: List[InsulinOrder] = []
    for r in _read_csv(path):
        out.append(InsulinOrder(
            patient_id=r["patient_id"],
            study_day=_to_int(r.get("study_day")),
            regimen=(r.get("regimen") or "").strip() or "no-insulin",
            basal_units=_to_float(r.get("basal_units")),
            bolus_units=_to_float(r.get("bolus_units")),
            correction_units=_to_float(r.get("correction_units")),
            iv_rate_uhr=_to_float(r.get("iv_rate_uhr")),
        ))
    return out


def load_episodes(path: str) -> List[Episode]:
    """Load ``episodes.csv`` into :class:`Episode` records.

    ``episode_type`` and ``transition_regimen`` are whitespace-stripped and
    fall back to ``"None"`` / ``"no-insulin"`` when the column is absent, the
    cell is blank, or the row is short. ``resolution_h`` becomes ``None``
    when its cell is blank or whitespace-only, so an unresolved episode is
    not reported as resolving at 0 h. Other numeric columns default to 0.0
    and flag columns to False.

    Args:
        path: Path to the episodes CSV.

    Returns:
        One :class:`Episode` per data row, in file order.

    Raises:
        KeyError: if the file has no ``patient_id`` column.
    """
    out: List[Episode] = []
    for r in _read_csv(path):
        res_h_raw = (r.get("resolution_h") or "").strip()
        out.append(Episode(
            patient_id=r["patient_id"],
            episode_type=(r.get("episode_type") or "").strip() or "None",
            start_day=_to_int(r.get("start_day")),
            duration_h=_to_float(r.get("duration_h")),
            anion_gap_init=_to_float(r.get("anion_gap_init")),
            bicarb_init=_to_float(r.get("bicarb_init")),
            glucose_init=_to_float(r.get("glucose_init")),
            k_supplement=_to_bool(r.get("k_supplement")),
            fluid_l_24h=_to_float(r.get("fluid_l_24h")),
            iv_insulin_rate_uhr=_to_float(r.get("iv_insulin_rate_uhr")),
            resolved=_to_bool(r.get("resolved")),
            resolution_h=(_to_float(res_h_raw) if res_h_raw else None),
            transition_regimen=(r.get("transition_regimen") or "").strip() or "no-insulin",
        ))
    return out


def load_all(data_dir: str) -> Tuple[List[Patient], List[POCTReading],
                                      List[InsulinOrder], List[Episode],
                                      IngestReport]:
    """Load every table from *data_dir* and summarise the run.

    Reads, in this order, ``patients.csv``, ``poct_bg.csv``,
    ``insulin_orders.csv`` and ``episodes.csv``. The report's
    de-identification fields are fixed descriptive text. This function does
    not itself hash IDs or shift dates: ``date_shift_days`` is always 0 and
    ``dropped_rows`` keeps its default.

    Returns:
        ``(patients, poct, orders, episodes, report)``.
    """
    join = os.path.join
    patients = load_patients(join(data_dir, "patients.csv"))
    poct = load_poct(join(data_dir, "poct_bg.csv"))
    orders = load_orders(join(data_dir, "insulin_orders.csv"))
    episodes = load_episodes(join(data_dir, "episodes.csv"))

    deid_notes = [
        "환자 ID는 SHA-256 으로 hash 된 surrogate (P + 10 hex) 형식",
        "모든 날짜는 study-day index (0=admit anchor)",
        "free-text 필드 제거, 코드화된 enum 만 보존",
    ]
    summary = IngestReport(
        n_patients=len(patients),
        n_poct=len(poct),
        n_orders=len(orders),
        n_episodes=len(episodes),
        date_shift_days=0,  # inputs are expected to be date-shifted upstream
        deid_method="SHA-256 surrogate ID + date-shift to study-day index",
        notes=deid_notes,
    )
    return patients, poct, orders, episodes, summary


# --------------------------------------------------------------------------- #
# Synthetic data generator
# --------------------------------------------------------------------------- #

def _rand_state(seed: int):
    """Tiny deterministic LCG so we don't depend on numpy here."""
    state = [seed & 0xFFFFFFFF]
    def rand() -> float:
        state[0] = (1103515245 * state[0] + 12345) & 0x7FFFFFFF
        return state[0] / 0x7FFFFFFF
    return rand


def generate_synthetic(n_patients: int = 320,
                       out_dir: str = "data",
                       seed: int = 42) -> IngestReport:
    """Write a reproducible synthetic cohort as four CSV files.

    Creates ``patients.csv``, ``poct_bg.csv``, ``insulin_orders.csv`` and
    ``episodes.csv`` in *out_dir*. The directory is created if missing and
    existing files are overwritten. Patient IDs are
    :func:`hash_patient_id` surrogates of fabricated raw IDs (``H2026000``,
    ``H2026001``, ...). Every patient is admitted on study day 0.

    All randomness comes from one :func:`_rand_state` stream seeded with
    *seed*. The dataset for a given seed therefore depends on the exact
    order and number of ``rand()`` calls below, including draws that are
    skipped by ``and``/``or`` short-circuiting and by ``elif`` chains.
    Reordering statements changes the generated data.

    Args:
        n_patients: Number of patients to generate.
        out_dir: Output directory for the CSV files.
        seed: Seed for the deterministic generator.

    Returns:
        An :class:`IngestReport` with the row count of each generated table.
    """
    os.makedirs(out_dir, exist_ok=True)
    rand = _rand_state(seed)

    def pick(xs, weights=None):
        # Choose one element of ``xs``: uniformly when ``weights`` is None,
        # otherwise by cumulative weight (weights need not sum to 1).
        # Consumes exactly one rand() call either way.
        if weights is None:
            # modulo guards against rand() returning exactly 1.0
            return xs[int(rand() * len(xs)) % len(xs)]
        total = sum(weights)
        r = rand() * total
        acc = 0.0
        for x, w in zip(xs, weights):
            acc += w
            if r <= acc:
                return x
        # defensive fallback if the loop ends without a match
        return xs[-1]

    # Sampling weights, aligned index-by-index with WARDS and DM_TYPES.
    ward_w = [0.10, 0.07, 0.05, 0.02, 0.18, 0.30, 0.18, 0.10]
    dm_w   = [0.05, 0.55, 0.04, 0.10, 0.16, 0.10]

    patients: List[Patient] = []
    poct: List[POCTReading] = []
    orders: List[InsulinOrder] = []
    episodes: List[Episode] = []

    for i in range(n_patients):
        # Fabricated raw ID -> hashed surrogate; no real identifiers are used.
        pid_raw = f"H{2026000 + i}"
        pid = hash_patient_id(pid_raw)
        age = 25 + int(rand() * 60)
        sex = "M" if rand() < 0.55 else "F"
        ward = pick(WARDS, ward_w)
        dm_type = pick(DM_TYPES, dm_w)

        # LOS: base 10 days in ICU, 5 in general wards (GW-*), 3 otherwise (ER),
        # plus 0-8 random days; at least 1 day.
        base_los = 5 if ward.startswith("GW") else (10 if ward in ("MICU","SICU","CCU","NICU") else 3)
        los = max(1, int(base_los + rand() * 8))
        admit = 0
        discharge = admit + los

        # NOTE: `and` short-circuits -- the rand() draw for perioperative only
        # happens for SICU / GW-SG / GW-OG, which shifts every later draw.
        perioperative = ward in ("SICU", "GW-SG", "GW-OG") and rand() < 0.7
        steroid = rand() < 0.18
        tpn = rand() < 0.10 and ward in ("MICU","SICU","GW-IM","GW-SG")

        # discharge insulin assignment depends on dm_type only
        if dm_type in ("T1DM",):
            discharge_ins = "basal-bolus"
        elif dm_type == "T2DM":
            discharge_ins = pick(
                ["basal-bolus", "basal-only", "sliding-scale-only", "no-insulin"],
                [0.45, 0.25, 0.10, 0.20],
            )
        elif dm_type == "Steroid-DM":
            discharge_ins = pick(["basal-bolus", "basal-only", "no-insulin"], [0.4, 0.4, 0.2])
        else:
            # GDM / Stress-Hyper / No-DM
            discharge_ins = pick(["no-insulin", "sliding-scale-only"], [0.85, 0.15])

        # 30-day readmission probability: 5% base, +5% for T1DM/T2DM,
        # +8% when discharged on sliding-scale-only insulin.
        readmit_p = 0.05
        if dm_type in ("T1DM", "T2DM"):
            readmit_p += 0.05
        if discharge_ins == "sliding-scale-only":
            readmit_p += 0.08
        readmit = rand() < readmit_p
        readmit_day = None
        readmit_reason = None
        if readmit:
            # 1-30 days after the discharge study day
            readmit_day = discharge + 1 + int(rand() * 29)
            readmit_reason = pick(
                ["DKA", "HHS", "Hypoglycemia", "Hyperglycemia", "Other"],
                [0.18, 0.12, 0.20, 0.35, 0.15],
            )

        p = Patient(
            patient_id=pid, age=age, sex=sex, dm_type=dm_type, ward=ward,
            admit_day=admit, discharge_day=discharge, los_days=los,
            perioperative=perioperative, steroid_exposure=steroid, tpn_pn=tpn,
            discharge_insulin=discharge_ins,
            readmit_30d=readmit, readmit_day=readmit_day, readmit_reason=readmit_reason,
        )
        patients.append(p)

        # ----- POCT BG readings per day: 1 for No-DM, ~8-11 in ICU
        # (MICU/SICU/CCU/NICU), ~4-6 elsewhere -----
        if dm_type == "No-DM":
            per_day = 1
        elif ward in ("MICU", "SICU", "CCU", "NICU"):
            per_day = 8 + int(rand() * 4)
        else:
            per_day = 4 + int(rand() * 3)

        # baseline glycemic profile: mean BG (mg/dL) by diabetes type
        if dm_type == "T1DM":
            mu = 175
        elif dm_type == "T2DM":
            mu = 165
        elif dm_type == "Steroid-DM":
            mu = 200
        elif dm_type == "Stress-Hyper":
            mu = 190
        elif dm_type == "GDM":
            mu = 140
        else:
            mu = 115

        if steroid:
            mu += 20
        if tpn:
            mu += 15
        if ward in ("MICU","SICU","CCU"):
            mu += 10  # ICU-acuity hyperglycemia (NICU not included)

        for d in range(los):
            for k in range(per_day):
                # evenly spaced slots across the day plus 0-2 h of jitter
                hour = int((24 / per_day) * k + rand() * 2)
                # noise (+/-30 mg/dL) + occasional hypo / hyper
                jitter = (rand() - 0.5) * 60
                g = mu + jitter
                # tail events: ~2.5% hypo (50-68), ~3% severe hyper (280-400)
                u = rand()
                if u < 0.025:
                    g = 50 + rand() * 18  # hypo
                elif u > 0.97:
                    g = 280 + rand() * 120  # severe hyper
                # clamp to 28-550 mg/dL
                g = max(28, min(550, g))
                poct.append(POCTReading(pid, d, hour, round(g, 0)))

        # ----- insulin orders (one row per day; same regimen every day) -----
        if dm_type == "T1DM":
            regimen = "basal-bolus"
        elif dm_type == "T2DM":
            # weights follow INSULIN_REGIMENS order
            regimen = pick(INSULIN_REGIMENS, [0.40, 0.25, 0.10, 0.15, 0.10])
        elif dm_type in ("Steroid-DM", "Stress-Hyper", "GDM"):
            regimen = pick(["basal-bolus", "sliding-scale-only", "basal-only", "no-insulin"],
                           [0.35, 0.25, 0.20, 0.20])
        else:
            regimen = "no-insulin"

        # ICU stress: IV infusion more common (MICU/SICU/CCU only)
        # NOTE: the rand() draw happens only when both ward and dm_type match.
        if ward in ("MICU","SICU","CCU") and dm_type in ("T1DM","T2DM","Steroid-DM","Stress-Hyper") and rand() < 0.35:
            regimen = "IV-infusion"

        for d in range(los):
            # doses not used by the regimen stay 0.0 ("no-insulin" -> all zeros)
            basal = 0.0
            bolus = 0.0
            corr = 0.0
            ivrate = 0.0
            if regimen == "basal-bolus":
                basal = round(8 + rand() * 24, 1)
                bolus = round(4 + rand() * 16, 1)
                corr = round(rand() * 6, 1)
            elif regimen == "basal-only":
                basal = round(8 + rand() * 20, 1)
            elif regimen == "sliding-scale-only":
                corr = round(2 + rand() * 8, 1)
            elif regimen == "IV-infusion":
                ivrate = round(1 + rand() * 5, 1)
                corr = round(rand() * 2, 1)
            orders.append(InsulinOrder(pid, d, regimen, basal, bolus, corr, ivrate))

        # ----- episodes (at most one per patient) -----
        # DKA more in T1DM; HHS more in T2DM elderly; hypo across the board
        # NOTE: each elif draws its own rand() only if earlier branches failed.
        if dm_type == "T1DM" and rand() < 0.15:
            etype = "DKA"
        elif dm_type == "T2DM" and age > 60 and rand() < 0.05:
            etype = "HHS"
        elif rand() < 0.06:
            etype = "Hypoglycemia"
        elif rand() < 0.10:
            etype = "Hyperglycemia-Persistent"
        else:
            etype = "None"

        if etype != "None":
            res_h = None
            resolved = True
            if etype == "DKA":
                anion = 18 + rand() * 12
                bicarb = 8 + rand() * 8
                glu_init = 350 + rand() * 250
                iv = round(0.1 + rand() * 0.1, 2)
                k_sup = rand() < 0.85
                fluid = round(4 + rand() * 3, 1)
                res_h = round(10 + rand() * 20, 1)
                trans = "basal-bolus"
            elif etype == "HHS":
                anion = 10 + rand() * 8
                bicarb = 18 + rand() * 6
                glu_init = 600 + rand() * 250
                iv = round(0.05 + rand() * 0.08, 2)
                k_sup = rand() < 0.8
                fluid = round(6 + rand() * 4, 1)
                res_h = round(18 + rand() * 24, 1)
                trans = "basal-bolus"
            elif etype == "Hypoglycemia":
                anion = 8 + rand() * 4
                bicarb = 22 + rand() * 4
                glu_init = 40 + rand() * 25
                iv = 0.0
                k_sup = False
                fluid = round(0.5 + rand() * 0.5, 2)
                res_h = round(0.25 + rand() * 0.5, 2)
                trans = pick(INSULIN_REGIMENS, [0.3,0.25,0.05,0.2,0.2])
            else:  # Hyperglycemia-Persistent
                anion = 9 + rand() * 4
                bicarb = 22 + rand() * 4
                glu_init = 220 + rand() * 100
                iv = round(rand() * 1.5, 2)
                k_sup = False
                fluid = round(1 + rand() * 2, 1)
                res_h = round(6 + rand() * 18, 1)
                trans = pick(["basal-bolus","sliding-scale-only","basal-only"], [0.5,0.3,0.2])

            # ~8% unresolved
            if rand() < 0.08:
                resolved = False
                res_h = None

            episodes.append(Episode(
                patient_id=pid,
                episode_type=etype,
                start_day=int(rand() * max(1, los - 1)),
                # unresolved episodes get a random 12-48 h duration; that rand()
                # is only drawn when res_h is None
                duration_h=round((res_h or (12 + rand()*36)), 1),
                anion_gap_init=round(anion, 1),
                bicarb_init=round(bicarb, 1),
                glucose_init=round(glu_init, 0),
                k_supplement=k_sup,
                fluid_l_24h=fluid,
                iv_insulin_rate_uhr=iv,
                resolved=resolved,
                resolution_h=res_h,
                transition_regimen=trans,
            ))

    # ----- write CSVs -----
    def _write(path: str, header: List[str], rows: List[List[Any]]):
        # Overwrites ``path``; newline="" as the csv module requires.
        with open(path, "w", encoding="utf-8", newline="") as f:
            w = csv.writer(f)
            w.writerow(header)
            w.writerows(rows)

    # Booleans are written as 0/1 and missing optional values as "".
    _write(
        os.path.join(out_dir, "patients.csv"),
        ["patient_id","age","sex","dm_type","ward","admit_day","discharge_day","los_days",
         "perioperative","steroid_exposure","tpn_pn","discharge_insulin",
         "readmit_30d","readmit_day","readmit_reason"],
        [[p.patient_id,p.age,p.sex,p.dm_type,p.ward,p.admit_day,p.discharge_day,p.los_days,
          int(p.perioperative),int(p.steroid_exposure),int(p.tpn_pn),p.discharge_insulin,
          int(p.readmit_30d),(p.readmit_day if p.readmit_day is not None else ""),
          (p.readmit_reason or "")]
         for p in patients],
    )
    _write(
        os.path.join(out_dir, "poct_bg.csv"),
        ["patient_id","study_day","hour","glucose_mg_dl"],
        [[r.patient_id,r.study_day,r.hour,r.glucose_mg_dl] for r in poct],
    )
    _write(
        os.path.join(out_dir, "insulin_orders.csv"),
        ["patient_id","study_day","regimen","basal_units","bolus_units","correction_units","iv_rate_uhr"],
        [[o.patient_id,o.study_day,o.regimen,o.basal_units,o.bolus_units,o.correction_units,o.iv_rate_uhr]
         for o in orders],
    )
    _write(
        os.path.join(out_dir, "episodes.csv"),
        ["patient_id","episode_type","start_day","duration_h",
         "anion_gap_init","bicarb_init","glucose_init","k_supplement","fluid_l_24h",
         "iv_insulin_rate_uhr","resolved","resolution_h","transition_regimen"],
        [[e.patient_id,e.episode_type,e.start_day,e.duration_h,
          e.anion_gap_init,e.bicarb_init,e.glucose_init,int(e.k_supplement),e.fluid_l_24h,
          e.iv_insulin_rate_uhr,int(e.resolved),
          (e.resolution_h if e.resolution_h is not None else ""), e.transition_regimen]
         for e in episodes],
    )

    # Synthetic data is generated directly as study-day indices, so no shift.
    return IngestReport(
        n_patients=len(patients),
        n_poct=len(poct),
        n_orders=len(orders),
        n_episodes=len(episodes),
        date_shift_days=0,
        deid_method="synthetic — surrogate IDs only (no real PHI ever present)",
        notes=[
            f"seed={seed}",
            "patients/POCT/orders/episodes CSV 생성 완료",
            "한국 종합병원 ward 분포 + DM type 분포 가정",
        ],
    )
