"""Multivariable Mendelian Randomization — PNPLA3 + BMI + T2DM as joint instruments.

Simplified, naive MVMR: instead of fitting a joint multivariable regression, it
re-weights univariate MR estimates by a heuristic per-instrument "liver
specificity" score. Liver-specific (e.g. PNPLA3) and metabolic-confounder
(BMI, T2DM) effects are reported separately.

References (curated, not called):
  Burgess S, Thompson SG. Multivariable MR. AJE 2015.
  Sanderson E et al. MVMR sensitivity. IJE 2019.
"""
from __future__ import annotations

import math
from typing import Dict, List, Optional

from .ontology import filter_effects


# Heuristic "liver specificity" weight for each MR instrument, on a 0..1 scale:
# 1.0 means the instrument is treated as fully liver-specific, 0.0 means it is
# treated as a pure metabolic confounder. Each instrument's log effect is split
# between the two components in proportion to this weight; mvmr_decompose falls
# back to 0.5 for instruments not listed here.
LIVER_SPECIFICITY: Dict[str, float] = dict(
    PNPLA3=0.95,
    PNPLA3_korean=0.95,
    TM6SF2=0.85,
    HSD17B13=0.95,
    MBOAT7=0.90,
    GCKR=0.30,
    BMI_polygenic=0.05,
    T2DM_polygenic=0.10,
    WC_polygenic=0.10,
)


def _safe_log(x: Optional[float]) -> Optional[float]:
    """Return the natural log of ``x``, or ``None`` when it has no usable log.

    ``None`` is returned for missing values, zero/negative values, NaN and
    +inf. Callers average these logs, so letting NaN or inf through would
    silently poison every aggregate computed from them.

    Non-numeric inputs (e.g. strings) still raise ``TypeError`` from the
    comparison, as before.
    """
    if x is None or x <= 0:  # -inf is also caught here
        return None
    # NaN slips past `x <= 0` because every NaN comparison is False; `x != x`
    # is True only for NaN. Plain comparisons are used instead of
    # math.isfinite so very large ints (which math.log accepts) never
    # overflow a float conversion.
    if x != x or x == math.inf:
        return None
    return math.log(x)


def _round_opt(value: Optional[float], ndigits: int) -> Optional[float]:
    """Round ``value`` to ``ndigits`` decimals; ``None`` passes through unchanged."""
    return None if value is None else round(value, ndigits)


def _exp_round_opt(log_value: Optional[float], ndigits: int) -> Optional[float]:
    """Back-transform a log-scale value with ``exp`` and round; ``None`` passes through."""
    return None if log_value is None else round(math.exp(log_value), ndigits)


def mvmr_decompose(effects: List[Dict[str, object]], stage: str, outcome: str) -> Dict[str, object]:
    """Heuristically split a MASLD→outcome MR signal into liver-specific vs metabolic parts.

    Selects rows of ``effects`` with ``masld_stage == stage``,
    ``outcome == outcome`` and ``design == "MR"``.

    Each row whose ``effect_estimate`` can be log-transformed by ``_safe_log``
    contributes its log effect to two weighted means:
      * a liver-specific mean, weighted by the instrument's
        ``LIVER_SPECIFICITY`` (0.5 for unlisted instruments);
      * a metabolic-confounder mean, weighted by ``1 - specificity``.

    This is a naive weighting scheme, not a fitted multivariable regression.
    ``effect_estimate`` is presumably a ratio measure (OR/HR), since it is
    log-transformed — confirm against the data source.

    Args:
        effects: effect records (dicts) read via ``masld_stage``, ``outcome``,
            ``design``, ``mr_instrument``, ``effect_estimate``, ``ci_low``,
            ``ci_high``.
        stage: MASLD stage to select.
        outcome: outcome to select.

    Returns:
        ``{"available": False, "reason": ...}`` when no MR row matches.

        Otherwise a dict with:
          * ``available`` (True), ``stage``, ``outcome``;
          * ``liver_specific_log_effect`` / ``liver_specific_effect``;
          * ``metabolic_confounder_log_effect`` / ``metabolic_confounder_effect``;
          * ``joint_log_effect`` / ``joint_effect`` — the unweighted mean over
            all usable rows;
          * ``mediated_fraction_liver`` — |liver| / (|liver| + |metabolic|);
          * ``per_instrument`` — keyed by instrument name; repeated instruments
            get ``#2``, ``#3``, ... suffixes;
          * ``interpretation``.

        Any effect value is ``None`` when it cannot be computed, e.g. no row
        had a usable estimate.
    """
    mr_rows = [
        row for row in effects
        if row.get("masld_stage") == stage
        and row.get("outcome") == outcome
        and row.get("design") == "MR"
    ]
    if not mr_rows:
        return {"available": False, "reason": "No MR estimate for pair."}

    per_iv: Dict[object, Dict[str, object]] = {}
    valid_logs: List[float] = []  # full-precision log effects of every usable row
    liver_num = liver_den = 0.0
    metab_num = metab_den = 0.0
    for r in mr_rows:
        iv = r.get("mr_instrument") or "unknown"
        log_e = _safe_log(r.get("effect_estimate"))
        if log_e is None:
            continue
        spec = LIVER_SPECIFICITY.get(iv, 0.5)

        # Several rows may share an instrument (or all fall back to "unknown").
        # Give repeats a distinct key so no estimate silently disappears from
        # per_instrument while still being counted in the aggregates.
        key = iv
        n_seen = 1
        while key in per_iv:
            n_seen += 1
            key = f"{iv}#{n_seen}"
        per_iv[key] = {
            "log_effect": round(log_e, 4),
            "effect_estimate": r.get("effect_estimate"),
            "liver_specificity_weight": spec,
            "ci_low": r.get("ci_low"),
            "ci_high": r.get("ci_high"),
        }

        valid_logs.append(log_e)
        liver_num += spec * log_e
        liver_den += spec
        metab_num += (1 - spec) * log_e
        metab_den += (1 - spec)

    liver_log = (liver_num / liver_den) if liver_den > 0 else None
    metab_log = (metab_num / metab_den) if metab_den > 0 else None
    # Naive joint effect: unweighted mean over the same rows that feed the
    # liver/metabolic means, using unrounded values.
    joint_log = sum(valid_logs) / len(valid_logs) if valid_logs else None

    # "Mediated" liver fraction: the liver share of the summed absolute log effects.
    if liver_log is not None and metab_log is not None:
        denom = abs(liver_log) + abs(metab_log)
        med_liver = abs(liver_log) / denom if denom > 0 else None
    else:
        med_liver = None

    return {
        "available": True,
        "stage": stage,
        "outcome": outcome,
        "liver_specific_log_effect": _round_opt(liver_log, 4),
        "liver_specific_effect": _exp_round_opt(liver_log, 3),
        "metabolic_confounder_log_effect": _round_opt(metab_log, 4),
        "metabolic_confounder_effect": _exp_round_opt(metab_log, 3),
        "joint_log_effect": _round_opt(joint_log, 4),
        "joint_effect": _exp_round_opt(joint_log, 3),
        "mediated_fraction_liver": _round_opt(med_liver, 3),
        "per_instrument": per_iv,
        "interpretation": _interpret_mvmr(liver_log, metab_log),
    }


def _interpret_mvmr(liver_log: Optional[float], metab_log: Optional[float]) -> str:
    """Turn the two decomposed log effects into a one-line verdict string.

    Rules are checked in priority order:
      1. either input missing -> insufficient instruments;
      2. |liver| more than twice |metabolic| -> strong liver-specific signal;
      3. |metabolic| more than twice |liver| -> metabolic confounding dominates,
         so MASLD may be a marker rather than a cause;
      4. opposite signs -> discordant directions;
      5. otherwise -> comparable effects, partial mediation possible.
    """
    if any(v is None for v in (liver_log, metab_log)):
        return "Insufficient instruments for decomposition."

    liver_mag = abs(liver_log)
    metab_mag = abs(metab_log)

    # Strong liver-specific causal signal (dominates the metabolic confounder).
    if liver_mag > metab_mag * 2:
        return "강한 liver-specific 인과 신호 (metabolic confounder 보다 우세)."
    # Metabolic confounder dominates; MASLD may be a marker rather than causal.
    if metab_mag > liver_mag * 2:
        return "Metabolic confounder 우세 — MASLD가 인과보다는 marker 가능성 시사."

    # Discordant directions (possible TM6SF2-like pattern) vs. similar effects
    # (partial mediation possible).
    opposite_signs = liver_log * metab_log < 0
    return (
        "Discordant: liver-specific 와 metabolic confounder 방향 반대 (TM6SF2 패턴 의심)."
        if opposite_signs
        else "Liver-specific 와 metabolic confounder 효과 유사 — 부분 mediation 가능."
    )
