"""Drug/outcome ontology and effects-table loaders, plus (drug_class, outcome) pair helpers.

Uses only the stdlib ``csv`` module, so it works without pandas installed.
"""
from __future__ import annotations

import csv
import os
from typing import Any

# Default location of the CSV inputs: a "data" directory one level above the
# directory containing this module (i.e. a sibling of this module's directory).
DATA_DIR_DEFAULT = os.path.join(os.path.dirname(os.path.dirname(__file__)), "data")


def _read_csv(path: str) -> list[dict[str, str]]:
    with open(path, newline="", encoding="utf-8") as f:
        return list(csv.DictReader(f))


def load_drugs(data_dir: str = DATA_DIR_DEFAULT) -> list[dict[str, str]]:
    """Read the drug ontology table ``drug_ontology.csv`` from *data_dir*.

    Returns one dict per data row, keyed by the CSV header names.
    """
    ontology_path = os.path.join(data_dir, "drug_ontology.csv")
    return _read_csv(ontology_path)


def load_outcomes(data_dir: str = DATA_DIR_DEFAULT) -> list[dict[str, str]]:
    """Read the outcome ontology table ``outcome_ontology.csv`` from *data_dir*.

    Returns one dict per data row, keyed by the CSV header names.
    """
    ontology_path = os.path.join(data_dir, "outcome_ontology.csv")
    return _read_csv(ontology_path)


def _to_float(value: Any) -> float | None:
    """Parse *value* with ``float()``; return None if it is missing/blank/non-numeric.

    Note: ``float()`` accepts "nan"/"inf", so those strings become float NaN/inf.
    """
    try:
        return float(value)
    except (TypeError, ValueError):
        return None


def _to_int(value: Any) -> int | None:
    """Parse *value* as an int; return None if it is missing/blank/non-numeric.

    Plain integer strings are parsed exactly (no round-trip through float).
    Decimal or scientific forms such as "120.0" or "1e3" fall back to
    ``float`` and are truncated toward zero. Non-finite values ("nan", "inf",
    or overflow like "1e400") yield None instead of raising.
    """
    try:
        return int(value)
    except (TypeError, ValueError):
        pass
    try:
        return int(float(value))
    except (TypeError, ValueError, OverflowError):
        # OverflowError: int(float("inf")) cannot be represented.
        return None


def load_effects(data_dir: str = DATA_DIR_DEFAULT) -> list[dict[str, Any]]:
    """Load ``effects_sample.csv`` from *data_dir* and coerce its numeric columns.

    ``effect_estimate``, ``ci_low``, ``ci_high`` and ``follow_up_years`` are
    converted to ``float`` and ``sample_size`` to ``int``. A value that is
    missing, blank or unparseable becomes ``None``, and these keys are always
    present in every returned row. Other columns are left as read by the CSV
    reader.
    """
    rows = _read_csv(os.path.join(data_dir, "effects_sample.csv"))
    for r in rows:
        for k in ("effect_estimate", "ci_low", "ci_high", "follow_up_years"):
            r[k] = _to_float(r.get(k, ""))
        r["sample_size"] = _to_int(r.get("sample_size", ""))
    return rows


def normalize_pair_key(drug_class: str, outcome: str) -> tuple[str, str]:
    """Build the canonical (drug_class, outcome) key used to group effect rows.

    Canonicalization strips leading/trailing whitespace from both parts;
    letter case and inner whitespace are left untouched.
    """
    drug = drug_class.strip()
    result = outcome.strip()
    return drug, result


def list_pairs(effects: list[dict[str, Any]]) -> list[tuple[str, str]]:
    """Return each distinct (drug_class, outcome) pair in *effects*, in first-seen order.

    Keys are produced by ``normalize_pair_key``; every row must have
    ``drug_class`` and ``outcome`` entries (a missing one raises KeyError).
    """
    # dicts keep insertion order, so fromkeys() de-duplicates while
    # preserving the order in which each pair first appears.
    unique_keys = dict.fromkeys(
        normalize_pair_key(row["drug_class"], row["outcome"]) for row in effects
    )
    return list(unique_keys)


def filter_pair(effects: list[dict[str, Any]], drug_class: str, outcome: str) -> list[dict[str, Any]]:
    """Return the rows of *effects* matching the given (drug_class, outcome) pair.

    Both the query and each row's ``drug_class``/``outcome`` are passed
    through ``normalize_pair_key`` before comparison. Matching rows are
    returned in their original order as the same dict objects (not copies).
    """
    target = normalize_pair_key(drug_class, outcome)

    def _matches(row: dict[str, Any]) -> bool:
        return normalize_pair_key(row["drug_class"], row["outcome"]) == target

    return [row for row in effects if _matches(row)]
