"""baseline-adjusted Hy's law + eDISH plot 계산.

기준:
  - 일반 Hy's law (FDA 2009): ALT >=3x ULN AND TBL >=2x ULN AND ALP <2x ULN
  - MASH-baseline-adjusted: ALT >=3x baseline AND TBL >=2x ULN

eDISH plot: peak ALT/ULN (x) vs peak TBL/ULN (y), log-log 4-quadrant.
"""
from __future__ import annotations

from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple

from .ingest import Patient, ULN, baseline_adjusted_uln


@dataclass
class HysCase:
    """Per-patient Hy's-law / eDISH evaluation result.

    Instances are built by ``evaluate_patients``; the field comments describe
    how that function fills each field.
    """
    pid: str                      # patient id (key of the ``patients`` mapping)
    arm: str                      # copied from ``Patient.arm``
    drug_class: str               # copied from ``Patient.drug_class``
    peak_alt: float               # largest non-missing ALT across the patient's timepoints
    peak_tbl: float               # largest non-missing TBL across the patient's timepoints
    peak_alp: float               # largest non-missing ALP across the patient's timepoints
    alt_ratio_uln: float          # peak ALT / standard ULN
    alt_ratio_baseline: float     # peak ALT / baseline ALT (standard ULN used if baseline missing; denominator floored at 1.0)
    tbl_ratio_uln: float          # peak TBL / standard ULN
    alp_ratio_uln: float          # peak ALP / standard ULN
    classical_hys: bool           # ALT >= 3x ULN and TBL >= 2x ULN and ALP < 2x ULN
    baseline_adj_hys: bool        # ALT >= 3x baseline and TBL >= 2x ULN
    quadrant: str                 # eDISH quadrant: Hy_zone / Temple_corollary / Cholestatic_zone / Normal_zone
    r_ratio: Optional[float]      # CIOMS R-ratio (ALT/ULN) / (ALP/ULN); None when it cannot be computed
    pattern: str                  # hepatocellular / cholestatic / mixed / indeterminate


def _peak(vals: List[Optional[float]]) -> Optional[float]:
    cleaned = [v for v in vals if v is not None]
    return max(cleaned) if cleaned else None


def edish_quadrant(alt_x_uln: float, tbl_x_uln: float) -> str:
    """Return the eDISH quadrant label for a (ALT/ULN, TBL/ULN) point (Watkins 2008).

    Cut-offs are ALT at 3x ULN (x axis) and TBL at 2x ULN (y axis):

    - "Hy_zone":          ALT >= 3 and TBL >= 2 (upper right)
    - "Temple_corollary": ALT >= 3 and TBL <  2 (lower right: ALT raised, TBL normal)
    - "Cholestatic_zone": ALT <  3 and TBL >= 2 (upper left)
    - "Normal_zone":      anything else, including points where a comparison
                          is never true (e.g. NaN)
    """
    # Checked in order; the first matching zone wins.
    zones = (
        ("Hy_zone", alt_x_uln >= 3 and tbl_x_uln >= 2),
        ("Temple_corollary", alt_x_uln >= 3 and tbl_x_uln < 2),
        ("Cholestatic_zone", alt_x_uln < 3 and tbl_x_uln >= 2),
    )
    for label, hit in zones:
        if hit:
            return label
    return "Normal_zone"


def cioms_r_ratio(alt_x_uln: Optional[float], alp_x_uln: Optional[float]) -> Optional[float]:
    """Compute the CIOMS R-ratio, (ALT/ULN) / (ALP/ULN).

    Args:
        alt_x_uln: ALT expressed as a multiple of its ULN, or None if unknown.
        alp_x_uln: ALP expressed as a multiple of its ULN, or None if unknown.

    Returns:
        The ratio, or None when either input is missing or ``alp_x_uln`` is
        not positive (the ratio would be undefined or meaningless).
    """
    if alt_x_uln is None or alp_x_uln is None or alp_x_uln <= 0:
        return None
    return alt_x_uln / alp_x_uln


def classify_pattern(r: Optional[float]) -> str:
    """Map a CIOMS R-ratio to a liver-injury pattern label.

    R >= 5 -> "hepatocellular"; R <= 2 -> "cholestatic"; any other value
    (2 < R < 5) -> "mixed"; a missing R (None) -> "indeterminate".
    """
    if r is not None:
        if r >= 5:
            return "hepatocellular"
        return "cholestatic" if r <= 2 else "mixed"
    return "indeterminate"


def evaluate_patients(patients: Dict[str, Patient]) -> List[HysCase]:
    """Evaluate each patient against classical and baseline-adjusted Hy's law.

    For every patient, the peak (largest non-missing) ALT, TBL and ALP over
    ``p.timepoints`` are expressed as multiples of the standard ULN (and, for
    ALT, of the patient's own baseline). Each evaluable patient also receives
    an eDISH quadrant, a CIOMS R-ratio and the derived injury pattern.

    Args:
        patients: mapping of patient id -> Patient.

    Returns:
        One HysCase per patient with at least one ALT, TBL and ALP value, in
        the iteration order of ``patients``. Patients lacking any of the three
        analytes are skipped.
    """
    cases: List[HysCase] = []
    for pid, p in patients.items():
        # Timepoints are presumably dict-like (analyte -> value or None) — TODO confirm.
        peak_alt = _peak([tp.get("ALT") for tp in p.timepoints])
        peak_tbl = _peak([tp.get("TBL") for tp in p.timepoints])
        peak_alp = _peak([tp.get("ALP") for tp in p.timepoints])
        # NOTE(review): a missing ALP excludes the patient entirely, even though
        # the baseline-adjusted criterion and the eDISH axes do not use ALP —
        # confirm this exclusion is intended.
        if peak_alt is None or peak_tbl is None or peak_alp is None:
            continue

        # A missing (or falsy) baseline falls back to the standard ALT ULN.
        baseline_alt = p.baseline_value("ALT") or ULN["ALT"]
        alt_x_uln = peak_alt / ULN["ALT"]
        # Floor the denominator at 1.0 so a zero/tiny baseline cannot blow up the ratio.
        alt_x_base = peak_alt / max(baseline_alt, 1.0)
        tbl_x_uln = peak_tbl / ULN["TBL"]
        alp_x_uln = peak_alp / ULN["ALP"]

        classical = (alt_x_uln >= 3) and (tbl_x_uln >= 2) and (alp_x_uln < 2)
        baseline_adj = (alt_x_base >= 3) and (tbl_x_uln >= 2)

        quad = edish_quadrant(alt_x_uln, tbl_x_uln)
        # NOTE(review): R is built from peak ALT and peak ALP, which may come from
        # different visits; CIOMS defines R from values at injury onset.
        r = cioms_r_ratio(alt_x_uln, alp_x_uln)
        pat = classify_pattern(r)

        cases.append(HysCase(
            pid=pid, arm=p.arm, drug_class=p.drug_class,
            peak_alt=peak_alt, peak_tbl=peak_tbl, peak_alp=peak_alp,
            alt_ratio_uln=alt_x_uln, alt_ratio_baseline=alt_x_base,
            tbl_ratio_uln=tbl_x_uln, alp_ratio_uln=alp_x_uln,
            classical_hys=classical, baseline_adj_hys=baseline_adj,
            quadrant=quad, r_ratio=r, pattern=pat,
        ))
    return cases


def summarize_hys(cases: List[HysCase]) -> Dict[str, int]:
    """Count cases by Hy's-law criterion, eDISH quadrant and injury pattern.

    Returns a dict with a fixed set of keys (always present, zero if no case
    matches): the total, both Hy's-law flags, three eDISH quadrants and three
    injury patterns.
    """
    n_classical = 0
    n_baseline_adj = 0
    by_quadrant: Dict[str, int] = {}
    by_pattern: Dict[str, int] = {}
    for case in cases:
        if case.classical_hys:
            n_classical += 1
        if case.baseline_adj_hys:
            n_baseline_adj += 1
        by_quadrant[case.quadrant] = by_quadrant.get(case.quadrant, 0) + 1
        by_pattern[case.pattern] = by_pattern.get(case.pattern, 0) + 1

    return {
        "n_total": len(cases),
        "n_classical_hys": n_classical,
        "n_baseline_adj_hys": n_baseline_adj,
        "n_hy_zone": by_quadrant.get("Hy_zone", 0),
        "n_temple": by_quadrant.get("Temple_corollary", 0),
        "n_cholestatic": by_quadrant.get("Cholestatic_zone", 0),
        "n_hepatocellular_pattern": by_pattern.get("hepatocellular", 0),
        "n_mixed_pattern": by_pattern.get("mixed", 0),
        "n_cholestatic_pattern": by_pattern.get("cholestatic", 0),
    }


def to_edish_points(cases: List[HysCase]) -> List[Tuple[float, float, str, str]]:
    """Project cases onto eDISH plot points.

    Returns one ``(ALT/ULN, TBL/ULN, arm, drug_class)`` tuple per case, in the
    same order as ``cases``; the first two items are the x and y coordinates.
    """
    def _as_point(case: HysCase) -> Tuple[float, float, str, str]:
        x, y = case.alt_ratio_uln, case.tbl_ratio_uln
        return (x, y, case.arm, case.drug_class)

    return [_as_point(case) for case in cases]


def render_edish_ascii(cases: List[HysCase], width: int = 50, height: int = 16) -> str:
    """Render an ASCII eDISH scatter plot (log-log) for environments without matplotlib.

    x axis: ALT/ULN; y axis: TBL/ULN. Both axes start at 0.3x ULN and extend
    to at least 30x (ALT) / 10x (TBL) ULN, further if the data require it.
    Points below the axis minimum are clamped onto the plot edge.

    Markers: ``H`` = classical Hy's law case, ``*`` = ALT >= 3x ULN,
    ``.`` = any other point. When several points land in the same cell the
    most significant marker (H > * > .) is kept, so a Hy's case is never
    hidden by a later, less significant point.

    Args:
        cases: evaluated cases. Points whose ALT or TBL ratio is non-positive
            or non-finite (NaN/inf) are skipped: they cannot be placed on a
            log scale.
        width: number of character columns in the plot area (>= 1).
        height: number of character rows in the plot area (>= 1).

    Returns:
        Two legend lines followed by the grid rows, or ``"(no data)"`` when no
        case can be plotted.

    Raises:
        ValueError: if ``width`` or ``height`` is smaller than 1.
    """
    import math

    if width < 1 or height < 1:
        raise ValueError("width and height must be >= 1")

    def _plottable(v: float) -> bool:
        # log10 needs a finite, strictly positive value; NaN fails both checks.
        return math.isfinite(v) and v > 0

    points = [c for c in cases if _plottable(c.alt_ratio_uln) and _plottable(c.tbl_ratio_uln)]
    if not points:
        return "(no data)"

    xmin, xmax = math.log10(0.3), math.log10(max(max(c.alt_ratio_uln for c in points), 30))
    ymin, ymax = math.log10(0.3), math.log10(max(max(c.tbl_ratio_uln for c in points), 10))

    # Higher rank wins when two points map to the same cell.
    rank = {" ": 0, ".": 1, "*": 2, "H": 3}
    grid = [[" "] * width for _ in range(height)]
    for c in points:
        xi = int((math.log10(c.alt_ratio_uln) - xmin) / (xmax - xmin) * (width - 1))
        yi = int((math.log10(c.tbl_ratio_uln) - ymin) / (ymax - ymin) * (height - 1))
        xi = min(max(xi, 0), width - 1)
        yi = min(max(yi, 0), height - 1)
        marker = "H" if c.classical_hys else ("*" if c.alt_ratio_uln >= 3 else ".")
        row = grid[height - 1 - yi]  # row 0 is the top of the plot
        if rank[marker] > rank[row[xi]]:
            row[xi] = marker

    lines = ["".join(row) for row in grid]
    legend = (
        "eDISH (ALT/ULN x  TBL/ULN y, log-log)  H=Hy's case  *=ALT>=3xULN  .=other\n"
        "Quadrants: Hy zone (ALT>=3, TBL>=2) | Temple (ALT>=3, TBL<2)\n"
    )
    return legend + "\n".join(lines)
