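"""Multivariable MR — BMI / WHR / body fat % as simultaneous instruments.

A simplified version of the Sanderson et al. (2019) MVMR framework.
Separates the effects of three adiposity phenotypes (BMI, WHR, body fat %)
using a weighted decomposition.

Pure Python: the 3x3 correlation-matrix inversion is written out by hand, so
neither statsmodels nor numpy is required (no OLS regression is performed).
"""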
from __future__ import annotations

import math
from typing import Dict, List, Optional, Tuple


def _mat_inverse_3x3(m: List[List[float]]) -> Optional[List[List[float]]]:
    """Invert a 3x3 matrix through its adjugate.

    Only the first three rows of ``m`` are read, and each must unpack into
    exactly three numbers. Returns ``None`` when ``|det| < 1e-12``, which is
    treated as singular.
    """
    a, b, c = m[0]
    d, e, f = m[1]
    g, h, i = m[2]
    # Adjugate = transpose of the cofactor matrix.
    adj = [
        [e*i - f*h, c*h - b*i, b*f - c*e],
        [f*g - d*i, a*i - c*g, c*d - a*f],
        [d*h - e*g, b*g - a*h, a*e - b*d],
    ]
    # Laplace expansion along the first row; the needed cofactors are the
    # first column of the adjugate.
    det = a*adj[0][0] + b*adj[1][0] + c*adj[2][0]
    if abs(det) < 1e-12:
        return None
    scale = 1.0 / det
    return [[entry * scale for entry in row] for row in adj]


def _matvec_3(m: List[List[float]], v: List[float]) -> List[float]:
    """Return the product of a 3x3 matrix and a 3-vector as a new list.

    Each output entry is the dot product of one of the first three rows of
    ``m`` with the first three entries of ``v``, summed left to right.
    """
    return [
        row[0]*v[0] + row[1]*v[1] + row[2]*v[2]
        for row in (m[0], m[1], m[2])
    ]


def mvmr_decomposition(
    bmi_effect: float,
    whr_effect: float,
    bf_effect: float,
    bmi_whr_corr: float = 0.55,
    bmi_bf_corr: float = 0.85,
    whr_bf_corr: float = 0.40,
) -> Dict[str, object]:
    """Split three adiposity IV effects on an outcome into conditional effects.

    The inputs are odds ratios from three separate univariable MR analyses
    (e.g. BMI MR, WHR MR, BF% MR). The outputs are simulated multivariable
    "conditional" contributions.

    Simplification: on the log-OR scale, the vector of univariable log-ORs is
    multiplied by the inverse of the 3x3 exposure correlation matrix. This is
    an inverse-correlation-weighted decomposition, and the correlation matrix
    is used directly in place of a covariance matrix.

    Args:
        bmi_effect: univariable OR for BMI. ``None`` or a non-positive value
            is treated as a null effect (log-OR 0).
        whr_effect: univariable OR for waist-hip ratio (same handling).
        bf_effect: univariable OR for body fat % (same handling).
        bmi_whr_corr: correlation between the BMI and WHR exposures.
        bmi_bf_corr: correlation between the BMI and BF% exposures.
        whr_bf_corr: correlation between the WHR and BF% exposures.

    Returns:
        On success, a dict with these entries:

        - ``bmi_conditional_OR``, ``whr_conditional_OR`` and
          ``bf_conditional_OR``, each rounded to 3 decimals.
        - ``method`` = ``"MVMR_correlation_weighted"``.
        - An ``interpretation`` string.

        The inputs are returned unchanged under the same ``*_conditional_OR``
        keys, with ``method`` = ``"univariable_fallback"``, in two cases:

        - The correlations do not form a positive-definite correlation
          matrix. This includes the singular case and NaN correlations.
        - The matrix is so ill-conditioned that back-transforming a
          conditional log-OR overflows.
    """

    def lo(x):
        # Missing / non-positive OR -> null effect on the log scale.
        if x is None or x <= 0:
            return 0.0
        return math.log(x)

    def univariable_fallback():
        # A fresh dict on every call, so callers can safely mutate the result.
        return {
            "bmi_conditional_OR": bmi_effect,
            "whr_conditional_OR": whr_effect,
            "bf_conditional_OR": bf_effect,
            "method": "univariable_fallback",
        }

    y = [lo(bmi_effect), lo(whr_effect), lo(bf_effect)]

    # Sylvester's criterion for a symmetric matrix with unit diagonal: it is
    # positive definite iff 1 - r12**2 > 0 and det(R) > 0. Any other input is
    # not a valid correlation matrix, and its inverse would give meaningless
    # weights. The check is written as "not (... > 0)" so that NaN
    # correlations also fall back.
    r12, r13, r23 = bmi_whr_corr, bmi_bf_corr, whr_bf_corr
    leading_minor_2 = 1.0 - r12 * r12
    det_r = 1.0 + 2.0 * r12 * r13 * r23 - r12 * r12 - r13 * r13 - r23 * r23
    if not (leading_minor_2 > 0.0 and det_r > 0.0):
        return univariable_fallback()

    # The correlation matrix stands in for the covariance matrix.
    R = [
        [1.0, bmi_whr_corr, bmi_bf_corr],
        [bmi_whr_corr, 1.0, whr_bf_corr],
        [bmi_bf_corr, whr_bf_corr, 1.0],
    ]
    Rinv = _mat_inverse_3x3(R)
    if Rinv is None:
        # |det| is below the inversion helper's singularity tolerance.
        return univariable_fallback()
    conditional_log = _matvec_3(Rinv, y)
    try:
        bmi_or, whr_or, bf_or = (round(math.exp(v), 3) for v in conditional_log)
    except OverflowError:
        # A near-singular R can inflate the weights beyond exp()'s float range.
        return univariable_fallback()
    return {
        "bmi_conditional_OR": bmi_or,
        "whr_conditional_OR": whr_or,
        "bf_conditional_OR": bf_or,
        "method": "MVMR_correlation_weighted",
        "interpretation": (
            "BMI conditional OR <1.0 & WHR conditional ≈1.0 → general adiposity dominant; "
            "WHR conditional <1.0 & BMI conditional ≈1.0 → central adiposity dominant; "
            "BF% conditional <1.0 & others ≈1.0 → adipose mass dominant (vs lean)."
        ),
    }


def mvmr_for_pair(intervention: str, outcome: str) -> Dict:
    """Simulate an MVMR decomposition for one intervention/outcome pair.

    The BMI OR comes from the first ``BMI_MR`` grid row for the pair. Demo
    WHR / BF% ORs are derived from it, and the three are passed to
    ``mvmr_decomposition``. A ``multivariable_MR`` row, if present, is
    attached for comparison.

    An error dict is returned when there is no ``BMI_MR`` row, or when its
    effect does not parse as a positive number.
    """
    from .grid import get_pair_rows, _to_float

    pair_rows = get_pair_rows(intervention, outcome)

    def _first_with_design(design):
        # First row whose "design" field matches, or None.
        return next((row for row in pair_rows if row["design"] == design), None)

    bmi_row = _first_with_design("BMI_MR")
    literature_row = _first_with_design("multivariable_MR")
    if not bmi_row:
        return {"error": "no BMI_MR data", "intervention": intervention, "outcome": outcome}

    bmi_or = _to_float(bmi_row["effect_estimate"])
    if not (bmi_or and bmi_or > 0):
        return {"error": "missing effects", "intervention": intervention, "outcome": outcome}

    # Demo only: WHR / BF% ORs are assumed to be conventional powers of the
    # BMI OR. A real analysis needs separate GWAS instruments.
    whr_or = bmi_or ** 0.6
    bf_or = bmi_or ** 1.1

    result = mvmr_decomposition(bmi_or, whr_or, bf_or)
    result.update(
        intervention=intervention,
        outcome=outcome,
        bmi_univariable_OR=bmi_or,
        whr_assumed_OR=round(whr_or, 3),
        bf_assumed_OR=round(bf_or, 3),
    )
    if literature_row:
        result["literature_mvmr_OR"] = _to_float(literature_row["effect_estimate"])
    return result


if __name__ == "__main__":
    # Demo run: print the simulated decomposition for two example pairs.
    for _intervention, _outcome in (
        ("semaglutide", "CV death"),
        ("bariatric_RYGB", "obesity_related_cancer_IARC12"),
    ):
        print(mvmr_for_pair(_intervention, _outcome))
