"""6개 class별 safety panel 로직.

각 class에 대해 efficacy/toxicity marker별 baseline 대비 변화율과
class-effect 시그널(예: THR-β TSH 억제, FGF21 IGF-1 감소,
ACC TG 상승 >=30%, FXR LDL 상승+pruritus) 자동 flag.
"""
from __future__ import annotations

from dataclasses import dataclass, field
from typing import Dict, List, Optional

from .ingest import Patient, CLASS_PANELS


# Per-class signal thresholds (for reference; conservative, literature-based values).
# Outer key: drug class name; inner key: rule name -> numeric threshold.
# NOTE(review): within this module evaluate_class reads only
# TSH_suppression_pct, HR_increase_bpm, IGF1_decline_pct, P1NP_decline_pct,
# CTX_rise_pct, TG_rise_pct and LDL_rise_pct. The remaining entries (including
# all GLP1RA and GIPglucagon rules) are not consulted anywhere in this file.
CLASS_SIGNAL_RULES = {
    "THRb": {
        "TSH_suppression_pct": -30,         # TSH down 30% vs. baseline: thyroid-axis signal
        "HR_increase_bpm": 8,               # absolute rise (extreme - baseline), not a percentage
        "SHBG_rise_pct": 30,                # central thyromimetic on-target
    },
    "FGF21": {
        "IGF1_decline_pct": -25,            # IGF-1 decline
        "P1NP_decline_pct": -20,            # bone formation suppression
        "CTX_rise_pct": 25,                 # bone resorption
        "uric_acid_rise_pct": 15,
    },
    "ACC": {
        "TG_rise_pct": 30,                  # class-effect hypertriglyceridemia
        "HDL_decline_pct": -10,
    },
    "FXR": {
        "LDL_rise_pct": 15,
        "pruritus_VAS_threshold": 4,        # 0-10
        "ALP_rise_pct": 20,
    },
    "GLP1RA": {
        # presumably multiples of the upper limit of normal -- TODO confirm
        "lipase_rise_xULN": 3,
        "amylase_rise_xULN": 3,
        "calcitonin_threshold_pgmL": 35,
    },
    "GIPglucagon": {
        "fasting_glu_rise_pct": 10,
        "urea_rise_pct": 20,
    },
}


@dataclass
class PanelSignal:
    """Result of evaluating one panel marker for one patient.

    evaluate_class creates one instance per (patient, marker) pair, whether
    or not a class-effect rule exists for that marker.
    """

    pid: str            # copied from Patient.pid
    arm: str            # copied from Patient.arm; "placebo" (any case) is counted as placebo in summaries
    drug_class: str     # drug class the patient was evaluated under
    marker: str         # marker name as listed in CLASS_PANELS
    role: str           # efficacy / toxicity
    baseline: Optional[float]   # value at the earliest timepoint; None when the series is empty
    # Most extreme value of the series. Despite the name, this is the minimum
    # (trough) for markers evaluated for a decline and the maximum otherwise.
    peak: Optional[float]
    pct_change: Optional[float]  # percent change from baseline to `peak`; None if not computable
    flag: bool          # True when a class-effect rule fired for this marker
    rule: str           # label of the rule that was evaluated; "" when no rule applied


def _pct_change(baseline: Optional[float], peak: Optional[float]) -> Optional[float]:
    if baseline is None or peak is None or baseline == 0:
        return None
    return (peak - baseline) / baseline * 100.0


def _baseline_value(series: List[tuple]) -> Optional[float]:
    if not series:
        return None
    series_sorted = sorted(series, key=lambda x: x[0])
    return series_sorted[0][1]


def _peak_value(series: List[tuple]) -> Optional[float]:
    if not series:
        return None
    return max(v for _, v in series)


def _trough_value(series: List[tuple]) -> Optional[float]:
    if not series:
        return None
    return min(v for _, v in series)


# Percent-change rules keyed by (drug class, marker). Each entry is
# (key in CLASS_SIGNAL_RULES[class], fallback threshold,
#  True when the rule looks for a decline, marker name used in the label).
_PCT_SIGNAL_RULES = {
    ("THRb", "TSH"): ("TSH_suppression_pct", -30, True, "TSH"),
    ("FGF21", "IGF1"): ("IGF1_decline_pct", -25, True, "IGF-1"),
    ("ACC", "TG"): ("TG_rise_pct", 30, False, "TG"),
    ("FXR", "LDL"): ("LDL_rise_pct", 15, False, "LDL"),
    ("FGF21", "P1NP"): ("P1NP_decline_pct", -20, True, "P1NP"),
    ("FGF21", "CTX"): ("CTX_rise_pct", 25, False, "CTX"),
}


def _apply_class_rule(drug_class, marker, rules, baseline, extreme, pct) -> tuple:
    """Return ``(flag, rule_label)`` for one marker of one patient.

    No rule is evaluated (result ``(False, "")``) when *pct* is None or when
    the (class, marker) pair has no rule. The label is always built from the
    threshold actually applied, so it cannot drift from the configuration.
    """
    if pct is None:
        return False, ""
    if drug_class == "THRb" and marker == "HR":
        # Heart rate is judged on the absolute rise, not on percent change.
        # pct is not None here, so baseline and extreme are both present.
        thr = rules.get("HR_increase_bpm", 8)
        return (extreme - baseline) >= thr, f"HR rise >= {thr} bpm"
    spec = _PCT_SIGNAL_RULES.get((drug_class, marker))
    if spec is None:
        return False, ""
    rule_key, fallback, is_decline, label = spec
    thr = rules.get(rule_key, fallback)
    if is_decline:
        return pct <= thr, f"{label} dec >= {abs(thr)}%"
    return pct >= thr, f"{label} rise >= {thr}%"


def evaluate_class(patients: Dict[str, Patient], drug_class: str) -> List[PanelSignal]:
    """Evaluate every panel marker of *drug_class* for each matching patient.

    Args:
        patients: patients keyed by id; only those whose ``drug_class``
            equals *drug_class* are evaluated.
        drug_class: class name used to look up CLASS_PANELS and
            CLASS_SIGNAL_RULES.

    Returns:
        One PanelSignal per (patient, marker) pair, in patient order and then
        panel-definition order. Empty when the class has no panel definition.
        The ``peak`` field carries the trough for markers judged on a decline
        (TSH, IGF1, HDL, P1NP) and the maximum for all others. ``flag`` is
        False and ``rule`` is "" when no class rule applies to the marker or
        the percent change cannot be computed.
    """
    panel_def = CLASS_PANELS.get(drug_class)
    if panel_def is None:
        return []
    rules = CLASS_SIGNAL_RULES.get(drug_class, {})

    # Loop invariants, built once per call.
    decline_markers = {"TSH", "IGF1", "HDL", "P1NP"}
    lft_markers = {"ALT", "AST", "ALP", "TBL"}

    signals: List[PanelSignal] = []
    members = [p for p in patients.values() if p.drug_class == drug_class]
    for p in members:
        for role, markers in panel_def.items():
            for m in markers:
                series = p.panel.get(m, [])
                # Liver function tests may live only in the per-visit
                # timepoints, so fall back to them when the panel is empty.
                if not series and m in lft_markers:
                    series = [
                        (tp["week"], tp[m])
                        for tp in p.timepoints
                        if tp.get(m) is not None
                    ]
                baseline = _baseline_value(series)
                # Direction-aware extreme: trough for markers expected to
                # fall, peak for everything else.
                if m in decline_markers:
                    extreme = _trough_value(series)
                else:
                    extreme = _peak_value(series)
                pct = _pct_change(baseline, extreme)
                flag, rule_applied = _apply_class_rule(
                    drug_class, m, rules, baseline, extreme, pct
                )
                signals.append(PanelSignal(
                    pid=p.pid, arm=p.arm, drug_class=drug_class,
                    marker=m, role=role, baseline=baseline,
                    peak=extreme, pct_change=pct,
                    flag=flag, rule=rule_applied,
                ))
    return signals


def summarize_class_signals(signals: List[PanelSignal]) -> Dict[str, Dict[str, int]]:
    """Count flagged signals per "<drug_class>/<marker>" key.

    Unflagged signals are ignored. Each entry holds the total number of
    flagged signals plus a split into "placebo" (arm equal to "placebo",
    case-insensitive) and "drug" (every other arm).
    """
    summary: Dict[str, Dict[str, int]] = {}
    for sig in signals:
        if sig.flag:
            label = f"{sig.drug_class}/{sig.marker}"
            if label not in summary:
                summary[label] = {"total": 0, "drug": 0, "placebo": 0}
            counts = summary[label]
            counts["total"] += 1
            bucket = "placebo" if sig.arm.lower() == "placebo" else "drug"
            counts[bucket] += 1
    return summary


def evaluate_all(patients: Dict[str, Patient]) -> Dict[str, List[PanelSignal]]:
    """Run evaluate_class for every drug class found among *patients*.

    Classes are processed in sorted order; classes without an entry in
    CLASS_PANELS are left out of the result.
    """
    present = {patient.drug_class for patient in patients.values()}
    return {
        name: evaluate_class(patients, name)
        for name in sorted(present)
        if name in CLASS_PANELS
    }
