"""Sigmoidal 4-parameter logistic dose-response fit and IC50/EC50 extraction."""

from __future__ import annotations

import math
import numpy as np
import pandas as pd
from scipy.optimize import curve_fit


def fourpl(x, top, bottom, hill, ec50):
    """Evaluate the 4-parameter logistic curve at dose(s) ``x``.

    Formula: y = bottom + (top - bottom) / (1 + (x / ec50) ** -hill)

    ``x`` is converted to a float array first; any dose that is zero or
    negative is substituted with 1e-12 before the ratio to ``ec50`` is taken.
    """
    dose = np.asarray(x, dtype=float)
    # Swap out non-positive doses so the power term sees a positive dose.
    dose = np.where(dose <= 0, 1e-12, dose)
    ratio = dose / ec50
    denominator = 1.0 + ratio ** (-hill)
    span = top - bottom
    return bottom + span / denominator


def _safe_log_doses(doses):
    """Return ``doses`` as a float array with non-positive entries set to 1e-6.

    Entries for which ``value <= 0`` is false (positive values, and NaN) are
    passed through unchanged.
    """
    values = np.asarray(doses, dtype=float)
    nonpositive = values <= 0
    return np.where(nonpositive, 1e-6, values)


def fit_dose_response(doses, responses) -> dict:
    """Fit a 4-parameter logistic curve to paired dose/response data.

    Doses are first passed through ``_safe_log_doses``; pairs whose dose or
    response is not finite are then dropped before fitting.

    Returns a dict with keys ``top``, ``bottom``, ``hill``, ``ec50``, ``r2``,
    ``n`` (number of points actually used) and ``ok``.

    If fewer than four usable points remain, or the fit raises, the numeric
    fit fields are NaN and ``ok`` is False. When an exception was raised its
    message is additionally stored under the ``error`` key.
    """
    # Limits on the fitted parameters; shared by the bounds and the clamp of
    # the initial EC50 guess below.
    hill_lo, hill_hi = -10.0, 10.0
    ec50_lo, ec50_hi = 1e-9, 1e9

    doses = _safe_log_doses(doses)
    y = np.asarray(responses, dtype=float)
    mask = np.isfinite(doses) & np.isfinite(y)
    doses = doses[mask]
    y = y[mask]
    n = len(y)
    blank = {
        "top": float("nan"),
        "bottom": float("nan"),
        "hill": float("nan"),
        "ec50": float("nan"),
        "r2": float("nan"),
        "n": int(n),
        "ok": False,
    }
    if n < 4:
        return blank
    try:
        # Initial guesses: response extremes for the plateaus, median dose
        # for the midpoint, unit slope.
        top0 = float(np.nanmax(y))
        bot0 = float(np.nanmin(y))
        ec0 = float(np.nanmedian(doses))
        if not np.isfinite(ec0) or ec0 <= 0:
            ec0 = 1.0
        # curve_fit rejects a starting point that lies outside the bounds, so
        # keep the EC50 guess inside them; otherwise data whose median dose
        # falls outside [ec50_lo, ec50_hi] could never be fitted.
        ec0 = min(max(ec0, ec50_lo), ec50_hi)
        hill0 = 1.0
        p0 = [top0, bot0, hill0, ec0]
        # Bound hill in a reasonable range so fit stays sigmoidal
        bounds = (
            [-np.inf, -np.inf, hill_lo, ec50_lo],
            [np.inf, np.inf, hill_hi, ec50_hi],
        )
        popt, _ = curve_fit(fourpl, doses, y, p0=p0, bounds=bounds, maxfev=20000)
        top, bottom, hill, ec50 = (float(v) for v in popt)
        yhat = fourpl(doses, top, bottom, hill, ec50)
        ss_res = float(np.sum((y - yhat) ** 2))
        ss_tot = float(np.sum((y - np.mean(y)) ** 2))
        # R^2 is undefined when the responses have no variance.
        r2 = 1.0 - ss_res / ss_tot if ss_tot > 0 else float("nan")
        return {
            "top": top,
            "bottom": bottom,
            "hill": hill,
            "ec50": ec50,
            "r2": r2,
            "n": int(n),
            "ok": True,
        }
    except Exception as exc:  # pragma: no cover
        # Deliberate best-effort: report the failure instead of raising.
        blank["error"] = str(exc)
        return blank


def fit_per_drug(per_well: pd.DataFrame, response_col: str) -> pd.DataFrame:
    """Fit 4PL across doses for each drug in a per-well summary table.

    Requires columns: drug, dose_uM, <response_col>. If any of them is
    missing, an empty DataFrame is returned.

    For each drug the response is averaged per ``dose_uM`` value and the
    averaged curve is passed to ``fit_dose_response``. The result has one row
    per drug: the fit dict plus ``drug`` and ``response`` (the name of the
    fitted column).
    """
    # Check every required column up front, including the response column,
    # so a missing one yields the empty result rather than a KeyError from
    # inside the loop.
    required = ("drug", "dose_uM", response_col)
    if any(col not in per_well.columns for col in required):
        return pd.DataFrame()
    rows = []
    for drug, sub in per_well.groupby("drug"):
        # Collapse replicate wells to one mean response per dose.
        agg = sub.groupby("dose_uM")[response_col].mean().reset_index()
        res = fit_dose_response(agg["dose_uM"].values, agg[response_col].values)
        res["drug"] = drug
        res["response"] = response_col
        rows.append(res)
    return pd.DataFrame(rows)
