"""Mediation analysis: MASLD-mediated vs metabolic-confounder-mediated."""
from __future__ import annotations

import math
from typing import Dict, List, Optional

from .grid import design_summary
from .mvmr import mvmr_decompose, LIVER_SPECIFICITY


def _safe_log(x: Optional[float]) -> Optional[float]:
    if x is None:
        return None
    try:
        if x <= 0:
            return None
        return math.log(x)
    except (TypeError, ValueError):
        return None


def masld_mediated_fraction(
    effects: List[Dict[str, object]],
    stage: str,
    outcome: str,
) -> Dict[str, object]:
    """Split an effect into liver-specific and metabolic-confounder shares.

    Inputs are taken from ``design_summary`` (the observational effect, used
    as the total effect) and ``mvmr_decompose`` (liver-specific and
    metabolic-confounder effects plus their log-scale counterparts). Per the
    module's notes, the liver-specific arm is meant to be liver-restricted MR
    (PNPLA3/HSD17B13/MBOAT7) and the metabolic arm BMI/T2DM polygenic MR,
    where present.

    The result always carries ``stage``, ``outcome`` and the three raw
    effects, plus ``mediation_available``. When the log of the observational
    effect or the liver-specific log-effect is missing, or both MR log-effects
    are zero, ``mediation_available`` is ``False`` and ``reason`` explains
    why. Otherwise the result also holds:

    * ``frac_masld_mediated`` / ``frac_metabolic_mediated`` -- each arm's
      share of the summed absolute log-effects, rounded to 3 decimals;
    * ``residual_log_effect_unmeasured`` / ``residual_effect_unmeasured`` --
      the observational log-effect minus the sum of both MR log-effects, and
      its exponential, rounded to 3 decimals;
    * ``interpretation`` -- a label chosen from the shares and effect signs.

    A missing metabolic log-effect is treated as ``0.0``.
    """
    design = design_summary(effects, stage, outcome)
    decomposition = mvmr_decompose(effects, stage, outcome)

    observed = design["observational"]["effect"]
    observed_log = _safe_log(observed)

    # Log-scale MR effects are only trusted when the decomposition says so.
    if decomposition.get("available"):
        liver_log = decomposition.get("liver_specific_log_effect")
        metabolic_log = decomposition.get("metabolic_confounder_log_effect")
    else:
        liver_log = None
        metabolic_log = None

    result: Dict[str, object] = {
        "stage": stage,
        "outcome": outcome,
        "total_effect_observational": observed,
        "liver_specific_effect_mr": decomposition.get("liver_specific_effect"),
        "metabolic_confounder_effect_mr": decomposition.get(
            "metabolic_confounder_effect"
        ),
    }

    if observed_log is None or liver_log is None:
        result.update(
            {
                "mediation_available": False,
                "reason": "관찰 효과 또는 liver-specific MR 부재.",
            }
        )
        return result

    # An absent metabolic arm contributes nothing to the split.
    if metabolic_log is None:
        metabolic_log = 0.0

    liver_magnitude = abs(liver_log)
    metabolic_magnitude = abs(metabolic_log)
    combined_magnitude = liver_magnitude + metabolic_magnitude
    if combined_magnitude == 0:
        result.update(
            {
                "mediation_available": False,
                "reason": "MR 효과 0 — mediation 미산정.",
            }
        )
        return result

    liver_share = liver_magnitude / combined_magnitude
    metabolic_share = metabolic_magnitude / combined_magnitude
    # What the observational log-effect has left once both MR arms are
    # subtracted; reported as unmeasured confounding.
    unexplained = observed_log - (liver_log + metabolic_log)

    result.update(
        {
            "mediation_available": True,
            "frac_masld_mediated": round(liver_share, 3),
            "frac_metabolic_mediated": round(metabolic_share, 3),
            "residual_log_effect_unmeasured": round(unexplained, 3),
            "residual_effect_unmeasured": round(math.exp(unexplained), 3),
        }
    )

    # Dominance (> 0.66 share) is checked before sign discordance, so the
    # discordant label only applies when neither arm dominates.
    result["interpretation"] = (
        "MASLD-driven causal mechanism (liver-restricted MR 우세)."
        if liver_share > 0.66
        else "Metabolic-confounder-mediated (BMI/T2DM 공통 원인 의심)."
        if metabolic_share > 0.66
        else "Discordant — TM6SF2 패턴 의심 (MASLD 증가·metabolic 감소)."
        if liver_log * metabolic_log < 0
        else "Mixed mediation."
    )
    return result
