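"""Schemas for cohort / treatment / modality data.

The record types defined here are plain dataclasses. pydantic is only
imported opportunistically: the module-level ``PYDANTIC`` flag records
whether that import succeeded, but no schema below depends on it.
"""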
from __future__ import annotations

import math
from dataclasses import dataclass, field
from typing import List, Optional

try:
    from pydantic import BaseModel, Field  # type: ignore

    PYDANTIC = True
except Exception:  # pragma: no cover
    PYDANTIC = False


# Allowed vocabularies, kept as plain sets of strings (rather than Enum
# types) for simplicity.

# Identifiers accepted for ``CohortMeta.model``.
ALLOWED_MODELS = {
    "HFD+aging",
    "HFD+OVX",
    "HFD+dexa",
    "HFD+cachexia",
    "db/db+aging",
    "MSTN-/-_HFD",
}

# Identifiers accepted for ``CohortMeta.treatment``: the single agents,
# unioned with the supported combination regimens (spelled "a+b").
ALLOWED_TREATMENTS = {
    "vehicle",
    "semaglutide",
    "tirzepatide",
    "retatrutide",
    "bimagrumab",
    "ActRIIB-Fc",
    "trevogrumab",
    "apitegromab",
} | {
    "tirzepatide+bimagrumab",
    "semaglutide+bimagrumab",
    "retatrutide+bimagrumab",
    "tirzepatide+ActRIIB-Fc",
}

# Presumably the vocabularies for the "fiber_type", "muscle" and "analyte"
# columns of the modality CSVs -- confirm against the loaders.
ALLOWED_FIBER_TYPES = {"IIB", "IIX", "IIA", "I"}
ALLOWED_MUSCLES = {"quad", "TA", "soleus", "gastroc"}
ALLOWED_MYOKINES = {"BAIBA", "decorin", "irisin", "activin_A", "myostatin"}


@dataclass
class CohortMeta:
    """Metadata for a single mouse in a cohort.

    ``model`` and ``treatment`` are free-form strings here; ``validate()``
    checks them against ``ALLOWED_MODELS`` / ``ALLOWED_TREATMENTS``.
    """

    mouse_id: str
    model: str
    treatment: str
    dose_mg_kg: float
    randomization_seed: int
    sex: str = "M"
    starting_age_weeks: int = 16
    cohort_id: str = "demo_cohort"

    def validate(self) -> List[str]:
        """Check this record against the allowed vocabularies and dose sanity.

        Problems are collected rather than raised.

        Returns:
            Human-readable error messages, one per problem found; an empty
            list means the record passed every check.
        """
        errs: List[str] = []
        if self.model not in ALLOWED_MODELS:
            errs.append(f"unknown model: {self.model}")
        if self.treatment not in ALLOWED_TREATMENTS:
            errs.append(f"unknown treatment: {self.treatment}")
        # NaN compares False against everything, so a bare ``< 0`` test lets
        # NaN (and +inf) doses through unnoticed; reject them explicitly.
        if not math.isfinite(self.dose_mg_kg):
            errs.append("non-finite dose")
        elif self.dose_mg_kg < 0:
            errs.append("negative dose")
        return errs


@dataclass
class PreRegistration:
    """Pre-registered design of a cohort study.

    Captures the arms, the per-arm sample size, the primary endpoint and
    the randomization seed, plus protocol / reporting identifiers that
    carry default values when not supplied.
    """

    cohort_id: str
    n_per_arm: int
    arms: List[str]
    primary_endpoint: str
    randomization_seed: int
    # Defaults to a "pending" marker until a protocol number is supplied.
    iacuc_protocol: str = field(default="IACUC-PENDING")
    # Version of the ARRIVE reporting guidelines this registration cites.
    arrive_version: str = field(default="ARRIVE 2.0")


@dataclass
class CompositeEndpoint:
    """Percent changes that feed the composite efficacy endpoint."""

    bw_pct_change: float
    lean_pct_change: float
    grip_per_lean_pct_change: float

    def passes(
        self,
        bw_threshold: float = -5.0,
        lean_threshold: float = -3.0,
        grip_threshold: float = 5.0,
    ) -> bool:
        """Return True only if all three strict criteria hold.

        With the defaults: body weight changed by less than -5 %, lean mass
        changed by more than -3 %, and grip per lean mass changed by more
        than +5 %. Criteria are tested in that order and evaluation stops
        at the first failure.
        """
        if not (self.bw_pct_change < bw_threshold):
            return False
        if not (self.lean_pct_change > lean_threshold):
            return False
        return self.grip_per_lean_pct_change > grip_threshold


# Standard column manifests for modality CSVs. Every per-timepoint modality
# starts with the same ("mouse_id", "week") key columns followed by its own
# measurement columns; the cohort metadata manifest has no "week" column
# and is appended separately (last, preserving key order).
MODALITY_COLUMNS = {
    modality: ["mouse_id", "week", *measures]
    for modality, measures in (
        ("body_weight", ("bw_g",)),
        ("body_composition", ("lean_mass_g", "fat_mass_g", "bw_g")),
        ("grip_strength", ("force_g",)),
        ("treadmill", ("distance_m", "duration_s")),
        ("running_wheel", ("rev_per_day",)),
        ("microct_muscle", ("muscle", "csa_mm2", "imat_pct")),
        (
            "myofiber_hcs",
            (
                "fiber_type",
                "csa_mean_um2",
                "pct",
                "pax7_density_per_mm2",
                "centronucleated_pct",
            ),
        ),
        ("exvivo_force", ("muscle", "tetanic_mN", "specific_force_kPa")),
        ("myokine", ("analyte", "value_pgmL")),
    )
}
# NOTE(review): CohortMeta also has ``starting_age_weeks``, which is not
# listed here -- confirm whether that omission is intentional.
MODALITY_COLUMNS["cohort_meta"] = [
    "mouse_id",
    "model",
    "treatment",
    "dose_mg_kg",
    "randomization_seed",
    "sex",
    "cohort_id",
]
