"""AE ingest + MedDRA mini 매핑 + de-identification.

참고용·연구용 — Not for clinical decision. 합성 mini-MedDRA 사전을 사용한다.
"""
from __future__ import annotations

import csv
import hashlib
import os
import re
from dataclasses import dataclass
from typing import Iterable


# Default location of the CSV inputs: a "data" directory that sits next to the
# directory containing this module (i.e. one level above this file's folder).
DATA_DIR = os.path.join(os.path.dirname(os.path.dirname(__file__)), "data")


@dataclass
class MeddraEntry:
    """One entry of the mini-MedDRA dictionary.

    ``load_meddra_mini`` fills every field with the raw string taken from the
    CSV column of the same name; nothing is validated or converted.
    """
    # Code reported by map_freetext for a match (presumably the MedDRA
    # Preferred Term code — confirm against the dictionary file).
    pt_code: str
    # Term text compared for an exact match (presumably the Preferred Term).
    pt_term: str
    # Second term text compared for an exact match (presumably a Lowest Level
    # Term that rolls up to pt_term — confirm).
    llt_term: str
    # Presumably the System Organ Class; not interpreted in the visible code.
    soc: str
    # Grouping label; its meaning is not shown in the visible code — confirm
    # with callers.
    panel: str


def load_meddra_mini(path: str | None = None) -> list[MeddraEntry]:
    """Load the synthetic mini-MedDRA dictionary from a CSV file.

    Args:
        path: CSV file to read. When ``None``, ``meddra_mini.csv`` inside
            ``DATA_DIR`` is used.

    Returns:
        One ``MeddraEntry`` per data row, in file order. Field values are the
        raw CSV strings.

    Raises:
        OSError: if the file cannot be opened.
        KeyError: if the header lacks one of ``pt_code``, ``pt_term``,
            ``llt_term``, ``soc`` or ``panel`` (raised on the first data row).
    """
    if path is None:
        path = os.path.join(DATA_DIR, "meddra_mini.csv")
    # newline="" hands line-ending handling to the csv module, as its
    # documentation requires; otherwise line breaks inside quoted fields are
    # rewritten by the text layer before the parser sees them.
    with open(path, encoding="utf-8", newline="") as f:
        return [
            MeddraEntry(
                pt_code=row["pt_code"],
                pt_term=row["pt_term"],
                llt_term=row["llt_term"],
                soc=row["soc"],
                panel=row["panel"],
            )
            for row in csv.DictReader(f)
        ]


# Matches any run of characters other than lowercase ASCII letters, digits or
# the space character.
_NORM_RE = re.compile(r"[^a-z0-9 ]+")


def _normalize(text: str) -> str:
    """Return *text* lowercased and reduced to words joined by single spaces.

    Every run of characters outside ``[a-z0-9 ]`` (punctuation, tabs, newlines,
    non-ASCII letters) becomes a separator. The result never has leading or
    trailing spaces, so ``"Nausea!"`` and ``"(nausea)"`` both give ``"nausea"``
    and text made only of such characters gives ``""``.
    """
    cleaned = _NORM_RE.sub(" ", text.lower())
    # Trim and collapse *after* the substitution: punctuation at either end of
    # the input turns into a space that an up-front strip() would leave behind,
    # which breaks equality comparisons between normalized strings.
    return " ".join(cleaned.split())


# Free-text synonym -> canonical PT term hints, consulted by map_freetext once
# the exact PT/LLT comparisons have failed.
# - Keys are written the way _normalize() emits text (lowercase letters/digits
#   separated by single spaces); "n v" is what "N/V" normalizes to.
# - Values are compared with MeddraEntry.pt_term using ==, so they must be
#   spelled exactly as in the dictionary.
# - Hints are scanned in insertion order and the first usable one wins, so a
#   longer phrase must come before any shorter phrase it contains
#   ("thoughts of self harm" before "self harm").
_SYNONYM_HINTS: dict[str, str] = {
    "n v": "Nausea",
    "feeling sick": "Nausea",
    "queasy": "Nausea",
    "throwing up": "Vomiting",
    "emesis": "Vomiting",
    "loose stools": "Diarrhoea",
    "watery stools": "Diarrhoea",
    "diarrhea": "Diarrhoea",
    "upper abd pain": "Abdominal pain upper",
    "stomach pain upper": "Abdominal pain upper",
    "epigastric pain": "Epigastric discomfort",
    "acute panc": "Acute pancreatitis",
    "acute cholecystitis": "Cholecystitis acute",
    "gallbladder inflammation": "Cholecystitis acute",
    "gallstone": "Cholelithiasis",
    "gallstones": "Cholelithiasis",
    "isr at thigh": "Injection site reaction",
    "site reaction": "Injection site reaction",
    "shot hurts": "Injection site pain",
    "dr worsening": "Diabetic retinopathy",
    "diabetic retinopathy progression": "Diabetic retinopathy",
    "thoughts of self harm": "Suicidal ideation",
    "self harm": "Self injurious behaviour",
    "ha": "Headache",
    "tired": "Fatigue",
    "appetite loss": "Decreased appetite",
    "high calcitonin": "Calcitonin increased",
    "elevated calcitonin": "Calcitonin increased",
    "lap chole": "Cholecystectomy",
    "gallbladder removal": "Cholecystectomy",
}


def map_freetext(text: str, meddra: list[MeddraEntry]) -> tuple[str | None, str | None, float]:
    """Map a free-text AE description to an entry of the mini-MedDRA dictionary.

    A deliberately simple rule-based mapper for the synthetic dictionary. Rules
    are tried in order and the first hit is returned:

    1. normalized text equals a normalized PT term      -> confidence 1.0
    2. normalized text equals a normalized LLT term     -> confidence 0.95
    3. a ``_SYNONYM_HINTS`` phrase occurs in the text as whole words and its
       target PT term exists in *meddra*                -> confidence 0.85
    4. the entry whose PT+LLT tokens are best covered by the text, provided at
       least half of them are present                   -> 0.5 + 0.3 * coverage

    Args:
        text: Free-text description; normalized with ``_normalize`` first.
        meddra: Dictionary entries to match against.

    Returns:
        ``(pt_code, pt_term, confidence)``, or ``(None, None, 0.0)`` when the
        text is empty after normalization or no rule matches.
    """
    norm = _normalize(text)
    if not norm:
        return (None, None, 0.0)

    # Rule 1: exact PT term match.
    for e in meddra:
        if norm == _normalize(e.pt_term):
            return (e.pt_code, e.pt_term, 1.0)

    # Rule 2: exact LLT match, reported under the entry's PT.
    for e in meddra:
        if norm == _normalize(e.llt_term):
            return (e.pt_code, e.pt_term, 0.95)

    # Rule 3: synonym hints. Both sides are padded with spaces so a hint only
    # matches complete words; a bare substring test would let "ha" fire inside
    # "rash" or "n v" inside "pain vomiting" and report them at 0.85.
    padded = f" {norm} "
    for hint, target_pt in _SYNONYM_HINTS.items():
        if f" {hint} " not in padded:
            continue
        for e in meddra:
            if e.pt_term == target_pt:
                return (e.pt_code, e.pt_term, 0.85)
        # Target PT absent from this dictionary: keep trying later hints.

    # Rule 4: token coverage. The score is the share of the candidate's tokens
    # that appear in the text; ties keep the earliest entry.
    tokens = set(norm.split())
    best_score = 0.0
    best_entry: MeddraEntry | None = None
    for e in meddra:
        cand_tokens = set(_normalize(e.pt_term).split()) | set(_normalize(e.llt_term).split())
        if not cand_tokens:
            continue
        score = len(tokens & cand_tokens) / len(cand_tokens)
        if score > best_score:
            best_score, best_entry = score, e
    if best_entry is not None and best_score >= 0.5:
        return (best_entry.pt_code, best_entry.pt_term, round(0.5 + best_score * 0.3, 2))
    return (None, None, 0.0)


def hash_subject_id(raw_id: str, salt: str = "GLP1KOR") -> str:
    """Return a 10-character pseudonym for *raw_id*.

    The value is the first 10 hex digits of the SHA-1 digest of
    ``"<salt>-<raw_id>"``, so the same id and salt always give the same result.
    """
    salted = "{}-{}".format(salt, raw_id)
    digest = hashlib.sha1(salted.encode())
    return digest.hexdigest()[:10]


def mask_site(site: str) -> str:
    """Return the first 6 hex digits of the unsalted SHA-1 digest of *site*."""
    hex_digest = hashlib.sha1(site.encode()).hexdigest()
    return hex_digest[:6]


def load_ae(path: str | None = None) -> list[dict]:
    """Load adverse-event rows from a CSV file and coerce the numeric columns.

    ``dose_mg`` and ``baseline_bmi`` are converted to ``float``; ``age``,
    ``t2dm``, ``prior_pancreatitis``, ``ckd``, ``followup_weeks``,
    ``onset_week``, ``severity_ctcae``, ``serious`` and ``related`` are
    converted to ``int``. Conversion is best effort: empty or missing values
    are skipped and text that does not parse is left as the original string.

    Args:
        path: CSV file to read. When ``None``, ``synthetic_ae.csv`` inside
            ``DATA_DIR`` is used.

    Returns:
        One dict per data row, keyed by the CSV header.

    Raises:
        OSError: if the file cannot be opened.
    """
    if path is None:
        path = os.path.join(DATA_DIR, "synthetic_ae.csv")
    # newline="" hands line-ending handling to the csv module, as its
    # documentation requires; otherwise line breaks inside quoted fields (for
    # example multi-line free text) are rewritten before the parser sees them.
    with open(path, encoding="utf-8", newline="") as f:
        rows = list(csv.DictReader(f))

    float_cols = ("dose_mg", "baseline_bmi")
    int_cols = ("age", "t2dm", "prior_pancreatitis", "ckd", "followup_weeks",
                "onset_week", "severity_ctcae", "serious", "related")
    for row in rows:
        for cols, cast in ((float_cols, float), (int_cols, int)):
            for col in cols:
                raw = row.get(col)
                if not raw:
                    # Absent column, short row (None) or empty cell.
                    continue
                try:
                    row[col] = cast(raw)
                except ValueError:
                    # Deliberate best effort: keep the unparsable text as is.
                    pass
    return rows


def load_labs(path: str | None = None) -> list[dict]:
    """Load lab rows from a CSV file and coerce the numeric columns.

    ``visit_week`` is converted to ``int`` strictly. ``amylase_U_L``,
    ``amylase_uln``, ``lipase_U_L``, ``lipase_uln``, ``calcitonin_pg_mL`` and
    ``calcitonin_uln`` are converted to ``float`` on a best-effort basis: text
    that does not parse is left as the original string. Empty or missing values
    are skipped in both cases.

    Args:
        path: CSV file to read. When ``None``, ``synthetic_labs.csv`` inside
            ``DATA_DIR`` is used.

    Returns:
        One dict per data row, keyed by the CSV header.

    Raises:
        OSError: if the file cannot be opened.
        ValueError: if a non-empty ``visit_week`` is not an integer literal.
    """
    if path is None:
        path = os.path.join(DATA_DIR, "synthetic_labs.csv")
    # newline="" hands line-ending handling to the csv module, as its
    # documentation requires; otherwise line breaks inside quoted fields are
    # rewritten by the text layer before the parser sees them.
    with open(path, encoding="utf-8", newline="") as f:
        rows = list(csv.DictReader(f))

    float_cols = ("amylase_U_L", "amylase_uln", "lipase_U_L", "lipase_uln",
                  "calcitonin_pg_mL", "calcitonin_uln")
    for row in rows:
        week = row.get("visit_week")
        if week not in (None, ""):
            # Strict on purpose: a malformed visit week raises ValueError.
            row["visit_week"] = int(week)
        for col in float_cols:
            raw = row.get(col)
            if raw in (None, ""):
                continue
            try:
                row[col] = float(raw)
            except ValueError:
                # Deliberate best effort: keep the unparsable text as is.
                pass
    return rows


def deidentify_rows(rows: Iterable[dict], drop_keys: tuple[str, ...] = ("subject_id", "site")) -> list[dict]:
    """Return de-identified shallow copies of *rows*; the inputs are not modified.

    For each row, ``subject_id`` (if present) is replaced by
    ``subject_id_hash`` and ``site`` (if present) by ``site_mask``; the values
    are passed through ``str()`` before hashing. Afterwards every key listed in
    *drop_keys* is removed if it exists.
    """

    def _scrub(raw: dict) -> dict:
        row = dict(raw)
        if "subject_id" in row:
            row["subject_id_hash"] = hash_subject_id(str(row.pop("subject_id")))
        if "site" in row:
            row["site_mask"] = mask_site(str(row.pop("site")))
        for key in drop_keys:
            row.pop(key, None)
        return row

    return [_scrub(raw) for raw in rows]


def auto_map_ae_freetext(ae_rows: list[dict], meddra: list[MeddraEntry] | None = None) -> list[dict]:
    """Map each row's ``freetext`` column to the mini-MedDRA dictionary.

    Args:
        ae_rows: AE rows; they are copied, not modified.
        meddra: Dictionary entries to match against. When ``None``, the default
            dictionary is loaded with ``load_meddra_mini()``.

    Returns:
        Shallow copies of the rows with three added keys:
        ``suggested_pt_code`` and ``suggested_pt_term`` (``""`` when nothing
        matched) and ``suggested_pt_conf`` (the mapper's confidence).
    """
    if meddra is None:
        meddra = load_meddra_mini()
    mapped = []
    for row in ae_rows:
        new = dict(row)
        # A key that is present but None (csv.DictReader does this for short
        # rows) must be treated like a missing one; .get(key, "") would pass
        # None on to the mapper, which expects a string.
        ft = new.get("freetext") or ""
        pt_code, pt_term, conf = map_freetext(ft, meddra)
        new["suggested_pt_code"] = pt_code or ""
        new["suggested_pt_term"] = pt_term or ""
        new["suggested_pt_conf"] = conf
        mapped.append(new)
    return mapped


def ingest_summary(ae_rows: list[dict]) -> dict:
    """Summarise AE rows.

    Returns a dict with ``n_events`` (row count), ``by_arm`` and ``by_panel``
    (row counts keyed by each row's ``arm`` / ``panel`` value) and ``serious``
    (rows whose ``serious`` value converts to the integer 1; a missing or
    empty value counts as 0).

    Raises:
        KeyError: if a row has no ``arm`` or ``panel`` key.
    """
    arm_counts: dict[str, int] = {}
    panel_counts: dict[str, int] = {}
    n_serious = 0
    for row in ae_rows:
        arm = row["arm"]
        arm_counts[arm] = arm_counts.get(arm, 0) + 1
        panel = row["panel"]
        panel_counts[panel] = panel_counts.get(panel, 0) + 1
        flag = row.get("serious", 0)
        if int(flag or 0) == 1:
            n_serious += 1
    summary = {
        "n_events": len(ae_rows),
        "by_arm": arm_counts,
        "by_panel": panel_counts,
        "serious": n_serious,
    }
    return summary
