"""Co-localization & coat-protein analysis utilities.

Operates on the per-LD object table. Manders M1/M2 and Pearson are already
provided per-LD upstream (in the MVP we trust the precomputed columns) — this
module aggregates them and characterises perilipin coat composition shifts.
"""

from __future__ import annotations

import numpy as np
import pandas as pd


# Per-LD perilipin channel intensity columns (PLIN1/2/3/5; PLIN4 is not part
# of this list). Coat-composition fractions are computed over these channels.
PLIN_COLS = ["plin1_intensity", "plin2_intensity", "plin3_intensity", "plin5_intensity"]


def per_well_coloc(df_ld: pd.DataFrame) -> pd.DataFrame:
    """Aggregate precomputed per-LD co-localization metrics to one row per well.

    Args:
        df_ld: Per-LD table with columns ``well``, ``manders_m1``,
            ``manders_m2``, ``pearson`` and ``distance_to_mito_um`` (the
            ``_um`` suffix suggests micrometres).

    Returns:
        A frame with a ``well`` column followed by per-well means (and, for
        Manders M1 and LD-mito distance, sample standard deviations; these are
        NaN for wells with a single LD).
    """
    # output column -> (source column, reducer); dict order fixes column order
    spec = {
        "manders_m1_mean": ("manders_m1", "mean"),
        "manders_m1_sd": ("manders_m1", "std"),
        "manders_m2_mean": ("manders_m2", "mean"),
        "pearson_mean": ("pearson", "mean"),
        "ld_mito_dist_mean_um": ("distance_to_mito_um", "mean"),
        "ld_mito_dist_sd_um": ("distance_to_mito_um", "std"),
    }
    summary = df_ld.groupby("well").agg(**spec)
    return summary.reset_index()


def plin_coat_composition(df_ld: pd.DataFrame) -> pd.DataFrame:
    """Per-LD coat composition fractions then per-well mean."""
    work = df_ld.copy()
    work[PLIN_COLS] = work[PLIN_COLS].clip(lower=0)
    total = work[PLIN_COLS].sum(axis=1).replace(0, np.nan)
    for col in PLIN_COLS:
        work[col + "_frac"] = work[col] / total
    fracs = [c + "_frac" for c in PLIN_COLS]
    grp = work.groupby("well")[fracs].mean().reset_index()
    grp = grp.rename(
        columns={
            "plin1_intensity_frac": "plin1_frac",
            "plin2_intensity_frac": "plin2_frac",
            "plin3_intensity_frac": "plin3_frac",
            "plin5_intensity_frac": "plin5_frac",
        }
    )
    return grp


def detect_coat_shift(df_well_plin: pd.DataFrame, vehicle_wells: list[str]) -> pd.DataFrame:
    """For each well, compute log2 fold-change of each PLIN fraction vs vehicle mean."""
    if df_well_plin.empty:
        return df_well_plin
    veh = df_well_plin[df_well_plin["well"].isin(vehicle_wells)]
    if veh.empty:
        veh_means = df_well_plin[["plin1_frac", "plin2_frac", "plin3_frac", "plin5_frac"]].mean()
    else:
        veh_means = veh[["plin1_frac", "plin2_frac", "plin3_frac", "plin5_frac"]].mean()

    eps = 1e-6
    out = df_well_plin.copy()
    for col in ["plin1_frac", "plin2_frac", "plin3_frac", "plin5_frac"]:
        out[col + "_log2fc"] = np.log2((out[col] + eps) / (veh_means[col] + eps))
    return out
