"""LFT + class별 safety panel ingest + baseline 보정.

MASH 모집단은 baseline ALT/AST가 ULN을 상회하는 경우가 흔하므로
일반 ULN 기준과 MASH-baseline 두 기준을 병기한다.
"""
from __future__ import annotations

import csv
import math
from dataclasses import dataclass, field
from pathlib import Path
from typing import Dict, Iterable, List, Optional, Tuple

# General upper limits of normal (ULN): conservative values pooled across adult
# men and women, which the original author attributes to the FDA 2009 DILI Guidance.
# Note that ALB and PLT are lower limits: for those two a LOW value is abnormal.
ULN = {
    "ALT": 40.0,    # U/L
    "AST": 40.0,
    "ALP": 130.0,
    "TBL": 1.2,     # mg/dL
    "INR": 1.2,
    "ALB": 3.5,     # lower limit (g/dL); low values are the abnormal direction
    "PLT": 150.0,   # x10^3/uL; low values are the abnormal direction
}

# Per-drug-class safety panel definition: for each class, the marker names are
# split into an "efficacy" list and a "toxicity" list.
CLASS_PANELS: Dict[str, Dict[str, List[str]]] = {
    "THRb": {
        "efficacy": ["LDL", "TG", "HDL"],
        "toxicity": ["TSH", "T3", "T4", "rT3", "SHBG", "BMD", "HR", "prolactin"],
    },
    "FGF21": {
        "efficacy": ["ALT", "PRO_C3", "ELF"],
        "toxicity": ["IGF1", "GH", "P1NP", "CTX", "sclerostin", "uric_acid", "prolactin"],
    },
    "ACC": {
        "efficacy": ["LFC", "ALT"],
        "toxicity": ["TG", "HDL"],   # TG elevation is a class effect of ACC inhibitors
    },
    "FXR": {
        # NOTE(review): "ALP" appears in both the efficacy and the toxicity list
        # below -- confirm that the double listing is intentional.
        "efficacy": ["ALP", "GGT", "fibroscan"],
        "toxicity": ["LDL", "pruritus_VAS", "ALP"],
    },
    "GLP1RA": {
        "efficacy": ["weight", "HbA1c"],
        "toxicity": ["amylase", "lipase", "calcitonin"],
    },
    "GIPglucagon": {
        "efficacy": ["weight", "ALT"],
        "toxicity": ["fasting_glu", "urea", "BUN"],
    },
}


@dataclass
class Patient:
    """LFT and safety-panel time series for a single patient.

    All three accessors (``latest``, ``peak``, ``baseline_value``) treat both
    ``None`` and float NaN as "no measurement", so a NaN cell can never be
    reported as a latest, peak or baseline value.
    """
    pid: str
    arm: str
    drug_class: str
    # marker name -> baseline value.
    baseline: Dict[str, float] = field(default_factory=dict)
    # One dict per visit (marker name -> value). The accessors below rely on
    # list order: index 0 is taken as the first visit, the end as the latest.
    timepoints: List[Dict[str, float]] = field(default_factory=list)
    # marker name -> list of (week, value) pairs.
    panel: Dict[str, List[Tuple[int, float]]] = field(default_factory=dict)

    @staticmethod
    def _is_missing(value: Optional[float]) -> bool:
        """Return True when *value* is ``None`` or a float NaN."""
        return value is None or (isinstance(value, float) and math.isnan(value))

    def latest(self, marker: str) -> Optional[float]:
        """Return the last non-missing value of *marker*, scanning from the end.

        Returns ``None`` when no timepoint carries a usable value.
        """
        for tp in reversed(self.timepoints):
            v = tp.get(marker)
            if not self._is_missing(v):
                return v
        return None

    def peak(self, marker: str) -> Optional[float]:
        """Return the maximum non-missing value of *marker* over all timepoints.

        NaN entries are skipped: ``max()`` over a list containing NaN gives an
        order-dependent result, because every comparison with NaN is False.
        Returns ``None`` when no timepoint carries a usable value.
        """
        vals = []
        for tp in self.timepoints:
            v = tp.get(marker)  # single lookup per timepoint
            if not self._is_missing(v):
                vals.append(v)
        return max(vals) if vals else None

    def baseline_value(self, marker: str) -> Optional[float]:
        """Return the baseline value of *marker*.

        The explicit ``baseline`` entry wins; when it is absent or missing
        (``None``/NaN) the first timepoint's value is used instead. Returns
        ``None`` when neither source has a usable value.
        """
        recorded = self.baseline.get(marker)
        if not self._is_missing(recorded):
            return recorded
        if self.timepoints:
            first = self.timepoints[0].get(marker)
            if not self._is_missing(first):
                return first
        return None


def _safe_float(v: str) -> Optional[float]:
    if v is None or v == "" or v.lower() == "na":
        return None
    try:
        return float(v)
    except (TypeError, ValueError):
        return None


def load_lft_csv(path: str | Path) -> Dict[str, Patient]:
    """Load an LFT time-series CSV into ``Patient`` objects keyed by pid.

    The file is long-format with one row per patient visit and the columns
    pid, arm, drug_class, week, ALT, AST, ALP, TBL, INR, ALB, PLT.

    Args:
        path: Location of the UTF-8 encoded CSV file.

    Returns:
        Mapping of pid to ``Patient``. Each patient's timepoints are sorted by
        week, and ``baseline`` holds the non-missing lab values of the
        earliest timepoint.

    Raises:
        KeyError: If a row has no ``pid`` or ``week`` column.
        ValueError: If a ``week`` cell is not numeric.
    """
    lab_columns = ("ALT", "AST", "ALP", "TBL", "INR", "ALB", "PLT")
    by_pid: Dict[str, Patient] = {}

    with Path(path).open(newline="", encoding="utf-8") as handle:
        for record in csv.DictReader(handle):
            pid = record["pid"]
            if pid not in by_pid:
                # arm / drug_class come from the first row seen for this pid.
                by_pid[pid] = Patient(
                    pid=pid,
                    arm=record.get("arm", "unknown"),
                    drug_class=record.get("drug_class", "unknown"),
                )
            # Week goes through float() first so cells such as "4.0" are
            # accepted; int() then drops any fractional part.
            visit = {"week": int(float(record["week"]))}
            for column in lab_columns:
                visit[column] = _safe_float(record.get(column, ""))
            by_pid[pid].timepoints.append(visit)

    for patient in by_pid.values():
        patient.timepoints.sort(key=lambda visit: visit["week"])
        if not patient.timepoints:
            continue
        earliest = patient.timepoints[0]
        patient.baseline = {
            name: value
            for name, value in earliest.items()
            if name != "week" and value is not None
        }
    return by_pid


def load_panel_csv(path: str | Path, patients: Dict[str, Patient], drug_class: str) -> None:
    """Merge a per-class panel CSV into the ``panel`` series of known patients.

    The file is long-format with the columns pid, week, marker, value.

    Args:
        path: Location of the UTF-8 encoded CSV file. A path that does not
            exist is silently ignored and nothing is loaded.
        patients: Mapping of pid to patient; updated in place. Rows whose pid
            is not in this mapping are skipped.
        drug_class: Currently not used by this function.

    Rows with a missing or unparseable value are skipped. After loading, every
    panel series of every patient is sorted by week.

    Raises:
        KeyError: If a row for a known pid has no ``marker`` or ``week`` column.
        ValueError: If a ``week`` cell of such a row is not numeric.
    """
    source = Path(path)
    if not source.exists():
        return

    with source.open(newline="", encoding="utf-8") as handle:
        for record in csv.DictReader(handle):
            pid = record["pid"]
            if pid in patients:
                marker_name = record["marker"]
                week = int(float(record["week"]))
                reading = _safe_float(record.get("value", ""))
                if reading is not None:
                    series = patients[pid].panel.setdefault(marker_name, [])
                    series.append((week, reading))

    for patient in patients.values():
        for series in patient.panel.values():
            series.sort(key=lambda point: point[0])


def baseline_adjusted_uln(patients: Dict[str, Patient], marker: str = "ALT") -> Dict[str, float]:
    """Return a per-patient "personal ULN" for *marker*, adjusted for baseline.

    The original author describes this as the FDA 2009 baseline adjustment for
    MASH. For each patient the result is the larger of the patient's baseline
    and the general ULN: a baseline below the general ULN yields the general
    ULN, a baseline at or above it yields the baseline.

    Args:
        patients: Mapping of pid to patient.
        marker: Marker to adjust; its general ULN is looked up in ``ULN``,
            falling back to 40.0 for markers not listed there.

    Returns:
        Mapping of pid to the ULN to apply for that patient. Patients without
        a usable baseline (``None`` or NaN) get the general ULN.
    """
    out: Dict[str, float] = {}
    standard = ULN.get(marker, 40.0)
    for p in patients.values():
        b = p.baseline_value(marker)
        if b is None or (isinstance(b, float) and math.isnan(b)):
            # No usable baseline: fall back to the general ULN. The NaN check
            # matters because max(nan, x) would return NaN.
            out[p.pid] = standard
        else:
            # Never go below the general ULN, so a normal baseline does not
            # tighten the thresholds derived from this value.
            out[p.pid] = max(b, standard)
    return out


def summarize_baseline(patients: Dict[str, Patient]) -> Dict[str, Dict[str, float]]:
    """Summarise the population's baseline values for ALT, AST, ALP and TBL.

    Args:
        patients: Mapping of pid to patient.

    Returns:
        Mapping of marker to a dict with ``n`` (patients with a usable
        baseline), ``mean``, ``median`` and ``over_uln_ratio`` (fraction of
        those baselines strictly above the marker's general ULN). Markers for
        which no patient has a usable baseline are omitted. Baselines that
        are ``None`` or NaN are not counted.
    """
    out: Dict[str, Dict[str, float]] = {}
    for marker in ("ALT", "AST", "ALP", "TBL"):
        vals = []
        for p in patients.values():
            v = p.baseline_value(marker)  # looked up once per patient
            # NaN would make the sort order, and therefore the median,
            # unpredictable and would turn the mean into NaN.
            if v is None or (isinstance(v, float) and math.isnan(v)):
                continue
            vals.append(v)
        if not vals:
            continue
        vals.sort()
        n = len(vals)
        mid = n // 2
        mean = sum(vals) / n
        # Odd count: middle element; even count: average of the two middle ones.
        median = vals[mid] if n % 2 else (vals[mid - 1] + vals[mid]) / 2
        limit = ULN.get(marker, 1.0)
        over = sum(1 for v in vals if v > limit) / n
        out[marker] = {"n": n, "mean": mean, "median": median, "over_uln_ratio": over}
    return out
