#!/usr/bin/env python3
"""DKDComboMiner - DKD drug-combo x stage x biomarker x population gap miner.

Demonstrates the full pipeline on synthetic ontology + synthetic evidence data.
NO external network calls. NO real PubMed/CTG/FAERS access. ALL data is mock.

Usage:
    python3 main.py --top 50
    python3 main.py --top 100 --output outputs/run.md
    python3 main.py --help

Disclaimer:
    Research-use only. Not for clinical decision-making. Hypotheses must
    pass IRB review and expert evaluation before any application.
"""
from __future__ import annotations

import argparse
import csv
import hashlib
import itertools
import json
import math
import sys
from dataclasses import dataclass, field, asdict
from pathlib import Path
from typing import Any, Iterable

# All data/output locations are resolved relative to this script, so the
# tool behaves the same regardless of the current working directory.
ROOT = Path(__file__).resolve().parent
DATA_DIR = ROOT.joinpath("data")
OUTPUTS_DIR = ROOT.joinpath("outputs")

# ----- deterministic synthetic-evidence engine -----------------------------

def _hash_int(*parts: str, mod: int) -> int:
    """Map string parts to a stable integer in ``[0, mod)``.

    The parts are joined with ``"|"``, UTF-8 encoded and hashed with SHA-256.
    The first 12 hex digits (48 bits) of the digest are then reduced modulo
    ``mod``. Unlike the built-in ``hash``, the result is identical across
    interpreter runs.
    """
    payload = "|".join(parts)
    hex_prefix = hashlib.sha256(payload.encode("utf-8")).hexdigest()[:12]
    return int(hex_prefix, 16) % mod


def synth_pubmed_count(combo_key: str, stage: str, biomarker: str, pop: str) -> int:
    """Synthetic PubMed hit count for a cell, in ``[0, 200)``."""
    return _hash_int(combo_key, stage, biomarker, pop, "pubmed", mod=200)


def synth_pubmed_trend(combo_key: str, stage: str, biomarker: str, pop: str) -> int:
    """Synthetic 5-year publication-trend percent change, in ``[-20, 60)``."""
    raw = _hash_int(combo_key, stage, biomarker, pop, "trend", mod=80)
    return raw - 20


def synth_kr_active_trials(combo_key: str, stage: str, pop: str) -> int:
    """Synthetic count of active Korean trials, in ``[0, 8)``.

    The biomarker is not part of the hash key. All biomarkers of the same
    (combo, stage, population) therefore share one value.
    """
    return _hash_int(combo_key, stage, pop, "ctg-kr", mod=8)


# ----- ontology loading ----------------------------------------------------

@dataclass
class Drug:
    """One drug entry from the ontology, as built by ``flatten_drugs``."""
    id: str  # short drug identifier; used as the key of drug lookup dicts
    name: str  # display name used in rendered reports
    moa: str  # mechanism-of-action text copied from the ontology entry
    drug_class: str  # ontology class key the drug was listed under
    korea_approved: bool  # copied from the ontology's "korea_approved" field
    kdigo_2024_strength: str  # grade string such as "1A"; mapped via KDIGO_STRENGTH_MAP

@dataclass
class Cell:
    """One scored (drug-combo x stage x biomarker x population) cell.

    Instances are produced by ``score_cell``. The numeric fields hold the
    component signals and the weighted composite ``score`` used for ranking.
    """
    drug_combo: tuple[str, ...]      # sorted drug IDs, e.g. ("DAPA","FINE")
    drug_combo_classes: tuple[str, ...]  # sorted distinct drug classes of the combo
    stage: str                        # e.g. "G3aA2"
    biomarker: str                    # e.g. "UACR"
    population: str                   # e.g. "T2DM_NONALB"
    pubmed_n: int = 0  # synthetic PubMed hit count
    trend_5y_pct: int = 0  # synthetic 5-year publication trend, percent
    active_trials_kr: int = 0  # synthetic number of active Korean trials
    faers_caution: float = 0.0  # FAERS-style caution (higher = more concern), capped at 1.0
    kdigo_strength_score: float = 0.0  # mean mapped KDIGO strength of the combo's drugs
    mechanism_plausibility: float = 0.0  # class-set mechanism prior
    kr_iis_feasibility: float = 0.0  # best-cohort Korean IIS feasibility, capped at 1.0
    score: float = 0.0  # weighted composite score used for ranking
    rationale: list[str] = field(default_factory=list)  # human-readable evidence summary lines


def _load_json(filename: str) -> dict[str, Any]:
    """Parse ``DATA_DIR / filename`` as UTF-8 JSON and return the result.

    Raises:
        FileNotFoundError: if the data file does not exist.
        json.JSONDecodeError: if the file is not valid JSON.
    """
    with open(DATA_DIR / filename, "r", encoding="utf-8") as f:
        return json.load(f)


def load_ontology() -> dict[str, Any]:
    """Load the synthetic DKD therapeutic ontology (``ontology.json``)."""
    return _load_json("ontology.json")


def load_faers() -> dict[str, Any]:
    """Load the synthetic FAERS-style safety signals (``faers_signals.json``)."""
    return _load_json("faers_signals.json")


def load_cohorts() -> dict[str, Any]:
    """Load the synthetic Korean cohort catalogue (``korea_cohorts.json``)."""
    return _load_json("korea_cohorts.json")


def flatten_drugs(onto: dict[str, Any]) -> list[Drug]:
    """Flatten the class-grouped ``onto["drugs"]`` mapping into Drug objects.

    The output keeps the mapping's class order, then each class's item order.
    Every drug's ``drug_class`` is the class key it was listed under.
    """
    return [
        Drug(
            id=entry["id"],
            name=entry["name"],
            moa=entry["moa"],
            drug_class=class_name,
            korea_approved=entry["korea_approved"],
            kdigo_2024_strength=entry["kdigo_2024_strength"],
        )
        for class_name, entries in onto["drugs"].items()
        for entry in entries
    ]


# ----- scoring -------------------------------------------------------------

# KDIGO-style recommendation grade (strength 1/2 plus evidence level A-D)
# mapped to a numeric weight in [0, 1]. Grades not listed here (e.g. "1D")
# fall back to 0.3 in kdigo_score_for_combo.
KDIGO_STRENGTH_MAP = {
    "1A": 1.00,
    "1B": 0.85,
    "1C": 0.70,
    "2A": 0.65,
    "2B": 0.55,
    "2C": 0.40,
    "2D": 0.25,
}

# Mechanism-plausibility prior keyed by the frozenset of drug classes in a
# combo. Two-class keys are pair priors (higher = more biologically plausible
# synergy). One-class keys are monotherapy baselines. Class sets not listed
# here are handled by mechanism_plausibility's pairwise-average fallback.
CLASS_PAIR_MECH = {
    frozenset(classes): prior
    for classes, prior in (
        (("SGLT2i", "MRA"), 0.92),     # CONFIDENCE-CKD-style hemodynamic + anti-fibrotic
        (("SGLT2i", "GLP1RA"), 0.88),  # FLOW + EMPA-KIDNEY complementary
        (("GLP1RA", "MRA"), 0.80),     # inflammation + mineralocorticoid
        (("SGLT2i", "RAS"), 0.72),     # standard backbone
        (("MRA", "RAS"), 0.55),        # high hyperK risk
        (("GLP1RA", "RAS"), 0.65),
        (("SGLT2i", "ADJ"), 0.60),
        (("GLP1RA", "ADJ"), 0.55),
        (("MRA", "ADJ"), 0.50),
        (("RAS", "ADJ"), 0.45),
        # Single-class keys: monotherapy baselines.
        (("SGLT2i",), 0.70),
        (("GLP1RA",), 0.70),
        (("MRA",), 0.65),
        (("RAS",), 0.65),
        (("ADJ",), 0.40),
    )
}


def kdigo_score_for_combo(drugs: tuple[Drug, ...]) -> float:
    """Return the mean KDIGO strength over the drugs in a combo.

    Each drug's ``kdigo_2024_strength`` grade is mapped through
    ``KDIGO_STRENGTH_MAP``. Unrecognised grades count as 0.3. An empty combo
    scores 0.0.
    """
    count = len(drugs)
    if count == 0:
        return 0.0
    total = sum(KDIGO_STRENGTH_MAP.get(drug.kdigo_2024_strength, 0.3) for drug in drugs)
    return total / count


def mechanism_plausibility(class_set: frozenset[str],
                           prior: dict[frozenset[str], float] | None = None) -> float:
    """Return the mechanism-plausibility prior for a set of drug classes.

    Args:
        class_set: distinct drug classes present in the combo.
        prior: lookup table keyed by class frozensets. Defaults to
            ``CLASS_PAIR_MECH``.

    Returns:
        The table entry when ``class_set`` is listed. Otherwise, for 2+
        classes, the mean of the pairwise entries (unlisted pairs count as
        0.5) minus 0.10 per class beyond two, floored at 0.1. For a single
        unlisted class or an empty set, 0.5.
    """
    table = CLASS_PAIR_MECH if prior is None else prior
    if class_set in table:
        return table[class_set]
    # Sort so the pairwise sum is accumulated in a fixed order. Iteration
    # order of a frozenset of str depends on PYTHONHASHSEED, and float
    # addition is not associative, so unsorted input could change the last
    # bits of the score (and hence rounding/ranking) between runs.
    classes = sorted(class_set)
    if len(classes) < 2:
        return 0.5
    pairs = list(itertools.combinations(classes, 2))
    avg = sum(table.get(frozenset(p), 0.5) for p in pairs) / len(pairs)
    # Crowding penalty: each class beyond a pair costs 0.10.
    return max(0.1, avg - 0.10 * (len(classes) - 2))


def faers_caution_for_combo(combo_classes: tuple[str, ...], faers: dict[str, Any]) -> float:
    """Return a [0,1]-capped caution score; higher = more safety overlap concern.

    The score has two equally weighted halves:
      * the summed ``weight`` of every ``combo_synergy_caution`` entry whose
        ``combo`` classes are all present in ``combo_classes``;
      * the largest value found in ``drug_class_signals`` for any of the
        combo's classes, divided by 8 (presumably a PRR-style value; confirm
        against the data file).
    The total is capped at 1.0. Missing sections contribute 0.
    """
    present = set(combo_classes)
    synergy = 0.0
    for rule in faers.get("combo_synergy_caution", []):
        if present.issuperset(rule["combo"]):
            synergy += rule["weight"]

    signals = faers.get("drug_class_signals", {})
    strongest = 0.0
    for cls in combo_classes:
        per_event = signals.get(cls)
        if not per_event:
            continue
        strongest = max(strongest, max(per_event.values(), default=0.0) / 8.0)

    return min(1.0, 0.5 * synergy + 0.5 * strongest)


def kr_iis_feasibility(stage: str, biomarker: str, population: str,
                       cohorts: dict[str, Any], onto: dict[str, Any]) -> tuple[float, dict]:
    """Estimate KR IIS feasibility as the best-scoring cohort match.

    Each cohort scores ``feasibility_score * biomarker * sites * population *
    stage``, where the factors are:
      * biomarker: 1.0 if the cohort covers ``biomarker``, else 0.4;
      * sites: ``min(1, (sites_kr + 1) / 30)``, or 0.5 when ``sites_kr`` is
        falsy (0 or None);
      * population: ``max(0.2, (6 - korea_prevalence_rank) / 5)``, so rank 1
        gives 1.0;
      * stage: 0.6 for "G4*", 0.35 for "G5*", 0.85 for "G3b*", else 1.0.

    NOTE(review): a cohort with ``sites_kr == 0`` gets 0.5, which is more
    than cohorts with 1-13 sites receive. Confirm that 0 means "not
    site-based" (e.g. a registry) rather than "no Korean sites".

    Returns:
        ``(min(1.0, best_score), best_cohort)``. ``best_cohort`` is ``{}``
        when no cohort scores above 0.

    Raises:
        KeyError: if ``population`` is not an id in ``onto["populations"]``.
    """
    by_id = {entry["id"]: entry for entry in onto["populations"]}
    rank = by_id[population]["korea_prevalence_rank"]
    # Lower rank number = more common population; floor rare ones at 0.2.
    pop_factor = max(0.2, (6 - rank) / 5.0)

    # Late stages make recruitment harder; the prefixes are mutually exclusive.
    for prefix, factor in (("G4", 0.6), ("G5", 0.35), ("G3b", 0.85)):
        if stage.startswith(prefix):
            stage_factor = factor
            break
    else:
        stage_factor = 1.0

    top_score, top_cohort = 0.0, None
    for cohort in cohorts["cohorts"]:
        bm_factor = 1.0 if biomarker in cohort["biomarker_coverage"] else 0.4
        n_sites = cohort["sites_kr"]
        sites_factor = min(1.0, (n_sites + 1) / 30.0) if n_sites else 0.5
        candidate = cohort["feasibility_score"] * bm_factor * sites_factor * pop_factor * stage_factor
        if candidate > top_score:
            top_score, top_cohort = candidate, cohort
    return min(1.0, top_score), (top_cohort or {})


# ----- pipeline ------------------------------------------------------------

def enumerate_cells(onto: dict[str, Any], max_combo_size: int = 2,
                    biomarkers_subset: list[str] | None = None,
                    populations_subset: list[str] | None = None,
                    stages_subset: list[str] | None = None) -> Iterable[tuple]:
    """Yield every (combo, stage, biomarker, population) cell to be scored.

    Combos are generated as follows:
      * every single drug;
      * for each size ``k`` from 2 up to ``max_combo_size``, every ``k``-drug
        combination whose drugs all belong to different classes.
    Singles come first, then pairs, then triples, and so on. Each group is
    in ``itertools.combinations`` order over the ontology's drug order.

    Args:
        onto: ontology dict with ``drugs``, ``biomarkers``, ``populations``
            and ``kdigo_stages`` sections.
        max_combo_size: largest number of drugs per combo. ``1`` yields
            monotherapies only.
        biomarkers_subset, populations_subset, stages_subset: optional ID
            filters. ``None`` or an empty list means "all from the ontology".

    Yields:
        ``(combo, stage, biomarker_id, population_id)`` tuples, where
        ``combo`` is a tuple of Drug objects.
    """
    drugs = flatten_drugs(onto)
    bms = biomarkers_subset or [b["id"] for b in onto["biomarkers"]]
    pops = populations_subset or [p["id"] for p in onto["populations"]]
    stages = stages_subset or onto["kdigo_stages"]
    drug_ids = [d.id for d in drugs]
    drug_by_id = {d.id: d for d in drugs}

    combos: list[tuple[Drug, ...]] = [(d,) for d in drugs]
    # Multi-drug combos are only generated up to max_combo_size, so
    # max_combo_size=1 (the CLI's `--max-combo 1`) means monotherapy only.
    for size in range(2, max_combo_size + 1):
        for id_group in itertools.combinations(drug_ids, size):
            members = tuple(drug_by_id[drug_id] for drug_id in id_group)
            # Skip combos that repeat a class (intra-class duplicates).
            if len({m.drug_class for m in members}) < size:
                continue
            combos.append(members)

    for combo in combos:
        for stage in stages:
            for bm in bms:
                for pop in pops:
                    yield combo, stage, bm, pop


def score_cell(combo: tuple[Drug, ...], stage: str, biomarker: str, population: str,
               faers: dict[str, Any], cohorts: dict[str, Any], onto: dict[str, Any],
               weights: dict[str, float]) -> Cell:
    """Score one (combo, stage, biomarker, population) cell.

    Seven normalised signals are combined, each multiplied by its ``weights``
    entry:
      * added: literature gap ``1 - pubmed_n/200`` (floored at 0), weight "gap";
      * added: 5-year trend ``(trend + 20) / 80``, weight "trend";
      * added: trial gap ``1 - kr_trials/8`` (floored at 0), weight "trial_gap";
      * added: mean KDIGO strength, weight "kdigo";
      * added: mechanism prior, weight "mech";
      * added: KR IIS feasibility, weight "feasibility";
      * subtracted: FAERS caution, weight "faers_penalty".

    Returns:
        A populated Cell. Component scores are rounded to 3 decimals and the
        composite to 4. The rationale strings use the unrounded values.
    """
    ids = tuple(sorted(drug.id for drug in combo))
    classes = tuple(sorted({drug.drug_class for drug in combo}))
    key = "+".join(ids)

    # Synthetic evidence: deterministic hashes of the cell key.
    n_pubs = synth_pubmed_count(key, stage, biomarker, population)
    trend_pct = synth_pubmed_trend(key, stage, biomarker, population)
    n_kr_trials = synth_kr_active_trials(key, stage, population)

    # Priors and feasibility.
    caution = faers_caution_for_combo(classes, faers)
    strength = kdigo_score_for_combo(combo)
    plausibility = mechanism_plausibility(frozenset(classes))
    feasibility, cohort = kr_iis_feasibility(stage, biomarker, population, cohorts, onto)

    lit_gap = max(0.0, 1.0 - n_pubs / 200.0)
    trend_unit = (trend_pct + 20) / 80.0
    trial_gap = max(0.0, 1.0 - n_kr_trials / 8.0)

    # Keep the terms in this exact order: float addition is not associative,
    # and ranking compares the rounded total.
    total = (
        weights["gap"] * lit_gap
        + weights["trend"] * trend_unit
        + weights["trial_gap"] * trial_gap
        + weights["kdigo"] * strength
        + weights["mech"] * plausibility
        + weights["feasibility"] * feasibility
        - weights["faers_penalty"] * caution
    )

    notes = [
        f"Synthetic PubMed n={n_pubs} (gap={lit_gap:.2f}), 5y trend {trend_pct:+d}%",
        f"Active KR trials (synthetic) = {n_kr_trials}; KDIGO 2024 strength avg = {strength:.2f}",
        f"Mechanism plausibility prior = {plausibility:.2f}; FAERS caution = {caution:.2f}",
        f"Best matching KR cohort: {cohort.get('id','-')} (feasibility {feasibility:.2f})",
    ]

    return Cell(
        drug_combo=ids,
        drug_combo_classes=classes,
        stage=stage,
        biomarker=biomarker,
        population=population,
        pubmed_n=n_pubs,
        trend_5y_pct=trend_pct,
        active_trials_kr=n_kr_trials,
        faers_caution=round(caution, 3),
        kdigo_strength_score=round(strength, 3),
        mechanism_plausibility=round(plausibility, 3),
        kr_iis_feasibility=round(feasibility, 3),
        score=round(total, 4),
        rationale=notes,
    )


def default_weights() -> dict[str, float]:
    """Return a fresh dict of the default composite-score weights.

    The six additive weights (gap, trend, trial_gap, kdigo, mech,
    feasibility) sum to 0.95. ``faers_penalty`` is subtracted in score_cell.
    """
    return dict(
        gap=0.20,
        trend=0.10,
        trial_gap=0.10,
        kdigo=0.15,
        mech=0.20,
        feasibility=0.20,
        faers_penalty=0.15,
    )


# ----- output renderers ----------------------------------------------------

def render_hypothesis_card(cell: Cell, drug_lookup: dict[str, Drug]) -> str:
    """Render one ranked cell as a Markdown hypothesis card.

    Drugs are listed as "Name (ID)" in ``cell.drug_combo`` order. The card is
    a "### Hypothesis" heading, a blank line and one bullet per field, and it
    ends with a newline.

    Raises:
        KeyError: if a drug ID in ``cell.drug_combo`` is not in ``drug_lookup``.
    """
    combo_label = ", ".join(f"{drug_lookup[drug_id].name} ({drug_id})" for drug_id in cell.drug_combo)
    signature = " + ".join(cell.drug_combo_classes)
    card_lines = [
        f"### Hypothesis: {combo_label} | KDIGO {cell.stage} | {cell.biomarker} | {cell.population}",
        "",
        f"- Drug combo: {combo_label}",
        f"- Class signature: {signature}",
        f"- KDIGO stage: {cell.stage}",
        f"- Primary biomarker: {cell.biomarker}",
        f"- Target population: {cell.population}",
        f"- Composite score: **{cell.score}**",
        f"- Synthetic literature gap (pubmed_n / trend_5y%): {cell.pubmed_n} / {cell.trend_5y_pct:+d}%",
        f"- Synthetic KR active trials: {cell.active_trials_kr}",
        f"- KDIGO 2024 strength avg: {cell.kdigo_strength_score} | Mechanism: {cell.mechanism_plausibility} | FAERS caution: {cell.faers_caution} | KR IIS feasibility: {cell.kr_iis_feasibility}",
        f"- Rationale: {' / '.join(cell.rationale)}",
    ]
    return "\n".join(card_lines) + "\n"


def render_grant_aims(cell: Cell, drug_lookup: dict[str, Drug]) -> str:
    """Render a "Grant-Ready Specific Aims" Markdown draft for one cell.

    Args:
        cell: the cell to describe (main() passes the top-ranked one).
        drug_lookup: maps each drug ID in ``cell.drug_combo`` to its Drug.
            Only ``.name`` is read.

    Returns:
        Markdown text containing: title, disclaimer, three aims, Innovation,
        Approach, the cell's rationale lines as a "Preliminary Data" bullet
        list, and Deliverables.

    Raises:
        KeyError: if a drug ID in ``cell.drug_combo`` is missing from
            ``drug_lookup``.

    NOTE(review): the following look unintended; confirm:
      * Aim 1 labels ``cell.score`` (the weighted composite) as "gap-score".
      * The lean / non-albuminuric Innovation bullet and the 52-week design
        are fixed text, emitted regardless of ``cell.population``.
    """
    drugs_named = " + ".join(drug_lookup[i].name for i in cell.drug_combo)
    # chr(10) is a newline: a backslash is not allowed inside f-string
    # expressions before Python 3.12.
    return f"""# Grant-Ready Specific Aims: {drugs_named} in DKD {cell.stage}

> Research-use only. IRB approval and expert review required prior to any clinical application. Synthetic data only.

## Specific Aims
**Aim 1.** Determine the effect of {drugs_named} on {cell.biomarker} trajectory in {cell.population} with KDIGO {cell.stage} DKD over 52 weeks (synthetic preliminary data: gap-score {cell.score}; current Korea active trials = {cell.active_trials_kr}).

**Aim 2.** Identify mechanistic mediators (mechanism plausibility prior {cell.mechanism_plausibility}) using paired plasma + urine multi-omics in a Korean IIS cohort with feasibility {cell.kr_iis_feasibility}.

**Aim 3.** Build a Korean-cohort-validated risk-stratification tool combining {cell.biomarker} with KDIGO 2024 (combo strength {cell.kdigo_strength_score}) for treatment selection.

## Innovation
- Class signature {' + '.join(cell.drug_combo_classes)} addresses an unexplored cell of the DKD therapeutic ontology.
- Korean lean / non-albuminuric DKD phenotype is underrepresented in pivotal trials (FLOW, EMPA-KIDNEY, FIDELIO/FIGARO).

## Approach
- Design: prospective IIS, parallel-arm, target N derived from KR cohort coverage.
- Outcomes: primary {cell.biomarker} change at 52w; secondary eGFR slope, hard renal endpoints.
- Safety monitoring informed by synthetic FAERS caution score = {cell.faers_caution}.

## Preliminary Data (synthetic)
{chr(10).join('- ' + r for r in cell.rationale)}

## Deliverables
- KDA / KSAD abstracts, peer-reviewed manuscript, KR cohort-derived RWE evidence package.
"""


def render_kda_abstract(cell: Cell, drug_lookup: dict[str, Drug]) -> str:
    """Render a bilingual (English / Korean) KDA / KSAD abstract draft for one cell.

    Args:
        cell: the cell to describe (main() passes the top-ranked one).
        drug_lookup: maps each drug ID in ``cell.drug_combo`` to its Drug.
            Only ``.name`` is read.

    Returns:
        Markdown text with an English abstract followed by a Korean one. Both
        have Title / Background / Methods / Results / Conclusion sections.

    Raises:
        KeyError: if a drug ID in ``cell.drug_combo`` is missing from
            ``drug_lookup``.

    NOTE(review): the following are fixed text, not computed from this run,
    and may not match filtered or ``--max-combo`` runs; confirm:
      * the Methods paragraphs hard-code the ontology sizes (23 drugs, 18
        stages, 7 biomarkers, 8 populations) and "~70,000 cells";
      * the Results paragraphs always say "top-tier".
    """
    drugs_named = " + ".join(drug_lookup[i].name for i in cell.drug_combo)
    # NOTE(review): identical to drugs_named. The Korean section therefore
    # shows the same (non-localized) drug names. Presumably a placeholder
    # for Korean drug names; confirm.
    drugs_kr = " + ".join(drug_lookup[i].name for i in cell.drug_combo)
    return f"""# KDA / KSAD Abstract Draft (synthetic preliminary)

> Research-use only. Synthetic mock data; not an actual study result.

## English (target ~250 words)
**Title.** {drugs_named} for KDIGO {cell.stage} diabetic kidney disease in {cell.population}: a Korean investigator-initiated study proposal.

**Background.** Despite SGLT2i, GLP-1RA, and finerenone-class advances, the (drug-combo x stage x biomarker x population) cell defined by {drugs_named} / {cell.stage} / {cell.biomarker} / {cell.population} remains underexplored (synthetic literature n={cell.pubmed_n}, KR active trials={cell.active_trials_kr}).

**Methods.** Using a DKD therapeutic ontology of 23 drugs, 18 KDIGO stages, 7 biomarkers, and 8 populations, we systematically screened ~70,000 cells. We integrated synthetic literature counts, KR cohort feasibility, KDIGO 2024 recommendation strength, mechanism plausibility, and FAERS-style safety priors into a composite score.

**Results.** This cell scored {cell.score} (top-tier). KDIGO 2024 strength {cell.kdigo_strength_score}, mechanism plausibility {cell.mechanism_plausibility}, KR IIS feasibility {cell.kr_iis_feasibility}, FAERS caution {cell.faers_caution}.

**Conclusion.** A KR IIS evaluating {drugs_named} on {cell.biomarker} in {cell.population} with {cell.stage} DKD is high-priority; preliminary feasibility supports KDA / KSAD pursuit.

## 한국어 (약 250 단어)
**제목.** {cell.population} 환자에서 KDIGO {cell.stage} 당뇨병성 신증에 대한 {drugs_kr} 병용 효과: 한국형 연구자 주도 임상시험 제안.

**배경.** SGLT2 억제제·GLP-1 수용체 작용제·non-steroidal MRA의 발전에도 불구하고, {drugs_kr} / {cell.stage} / {cell.biomarker} / {cell.population} 조합은 미탐색 영역으로 남아있다 (합성 문헌 n={cell.pubmed_n}, 한국 활성 시험 {cell.active_trials_kr}).

**방법.** 23개 약물·18개 KDIGO 단계·7개 바이오마커·8개 인구학을 포함하는 DKD 치료 온톨로지를 구성하고 약 70,000개 cell을 체계적으로 스크리닝하였다. 합성 문헌 수, 한국 코호트 가용성, KDIGO 2024 권고 강도, 기전 타당성, FAERS 유사 안전성 prior를 가중합하여 복합 점수를 산출하였다.

**결과.** 본 cell의 점수 {cell.score} (상위권). KDIGO 강도 {cell.kdigo_strength_score}, 기전 타당성 {cell.mechanism_plausibility}, 한국 IIS 가능성 {cell.kr_iis_feasibility}, FAERS 주의 {cell.faers_caution}.

**결론.** {cell.population}의 {cell.stage} 단계에서 {drugs_kr} 병용이 {cell.biomarker}에 미치는 영향을 평가하는 한국형 IIS는 우선순위가 높으며, 예비 가용성 평가가 KDA·KSAD 추진을 뒷받침한다.
"""


# ----- main ---------------------------------------------------------------

def _non_negative_int(text: str) -> int:
    """argparse ``type=`` callable: parse ``text`` as an integer >= 0.

    Raises:
        argparse.ArgumentTypeError: if ``text`` is not an integer or is negative.
    """
    try:
        value = int(text)
    except ValueError:
        raise argparse.ArgumentTypeError(f"expected a non-negative integer, got {text!r}") from None
    if value < 0:
        raise argparse.ArgumentTypeError(f"expected a non-negative integer, got {value}")
    return value


def parse_args(argv: list[str] | None = None) -> argparse.Namespace:
    """Parse the command line (``sys.argv[1:]`` when ``argv`` is None).

    ``--top`` must be a non-negative integer. A negative value would
    otherwise slice the ranked list from the end (``cells[:-k]``) and emit
    almost every cell.
    """
    p = argparse.ArgumentParser(
        prog="dkd-combo-miner",
        description=(
            "DKDComboMiner - rank under-explored DKD (drug-combo x stage x biomarker x "
            "population) cells using synthetic ontology + synthetic evidence. "
            "Research-use only."
        ),
    )
    p.add_argument("--top", type=_non_negative_int, default=50, help="Number of top cells to emit (default: 50)")
    p.add_argument("--max-combo", type=int, default=2, choices=[1, 2, 3], help="Max drugs per combo (default: 2)")
    p.add_argument("--output", type=str, default=None, help="Markdown output file (default: outputs/run_top{N}.md)")
    p.add_argument("--csv", type=str, default=None, help="Optional CSV dump of all ranked cells")
    p.add_argument("--biomarkers", type=str, default=None, help="Comma-separated biomarker IDs to filter")
    p.add_argument("--populations", type=str, default=None, help="Comma-separated population IDs to filter")
    p.add_argument("--stages", type=str, default=None, help="Comma-separated KDIGO stage IDs to filter")
    p.add_argument("--no-grant", action="store_true", help="Skip grant-aims and abstract for top-1 cell")
    p.add_argument("--seed-info", action="store_true", help="Print synthetic-data seed info and exit")
    return p.parse_args(argv)


def _split_ids(raw: str | None) -> list[str] | None:
    """Split a comma-separated CLI filter value into trimmed, non-empty IDs.

    Example: ``"UACR, eGFR,"`` -> ``["UACR", "eGFR"]``. Returns None when no
    IDs remain, which enumerate_cells treats as "no filter".
    """
    if not raw:
        return None
    ids = [part.strip() for part in raw.split(",") if part.strip()]
    return ids or None


def _render_report(total_cells: int, top: list[Cell], max_combo: int,
                   weights: dict[str, float], drug_lookup: dict[str, Drug],
                   include_drafts: bool) -> str:
    """Build the Markdown report.

    The report has a header, ranked hypothesis cards and, when
    ``include_drafts`` is set and ``top`` is non-empty, grant-aims and
    abstract drafts for the top-1 cell.
    """
    lines: list[str] = [
        f"# DKDComboMiner — top {len(top)} unexplored cells",
        "",
        "> Research-use only. Synthetic mock data. Not for clinical decision-making. IRB review required.",
        "",
        f"- Total cells scored: {total_cells}",
        f"- Combo size up to: {max_combo}",
        f"- Weights: {weights}",
        "",
        "## Ranked hypothesis cards",
        "",
    ]
    for rank, cell in enumerate(top, 1):
        lines.extend([f"## #{rank} (score {cell.score})", "", render_hypothesis_card(cell, drug_lookup), ""])
    if include_drafts and top:
        lines.extend([
            "---",
            "",
            "## Grant-ready specific aims (top-1)",
            "",
            render_grant_aims(top[0], drug_lookup),
            "",
            "## KDA / KSAD abstract draft (top-1)",
            "",
            render_kda_abstract(top[0], drug_lookup),
        ])
    return "\n".join(lines)


def _write_ranked_csv(csv_path: Path, cells: list[Cell]) -> None:
    """Write every ranked cell (not just the top N) to ``csv_path``.

    Missing parent directories are created first.
    """
    csv_path.parent.mkdir(parents=True, exist_ok=True)
    with csv_path.open("w", newline="", encoding="utf-8") as f:
        writer = csv.writer(f)
        writer.writerow(["rank", "score", "drug_combo", "classes", "stage", "biomarker", "population",
                         "pubmed_n", "trend_5y_pct", "active_trials_kr",
                         "kdigo_strength_score", "mechanism_plausibility",
                         "kr_iis_feasibility", "faers_caution"])
        for rank, c in enumerate(cells, 1):
            writer.writerow([rank, c.score, "+".join(c.drug_combo), "+".join(c.drug_combo_classes),
                             c.stage, c.biomarker, c.population,
                             c.pubmed_n, c.trend_5y_pct, c.active_trials_kr,
                             c.kdigo_strength_score, c.mechanism_plausibility,
                             c.kr_iis_feasibility, c.faers_caution])


def main(argv: list[str] | None = None) -> int:
    """Run the miner end to end and return a process exit code.

    Steps:
      1. Load the synthetic data and score every enumerated cell.
      2. Write the Markdown report (default ``outputs/run_top{N}.md``) and,
         with ``--csv``, a CSV of all ranked cells. Missing parent
         directories of either path are created.
      3. Print a top-5 summary to stdout.

    Returns:
        0 on success; 2 if ``--populations`` names an ID that is not in the
        ontology.
    """
    args = parse_args(argv)
    if args.seed_info:
        # NOTE: the seed value is informational only. _hash_int takes no
        # seed input, so no synthetic value depends on this number.
        print("Synthetic seed = 20260501; pubmed/trend/trials are deterministic SHA-256 derivatives.")
        print("All evidence is mock. No external network calls. NOT for clinical use.")
        return 0

    onto = load_ontology()
    faers = load_faers()
    cohorts = load_cohorts()
    drug_lookup = {d.id: d for d in flatten_drugs(onto)}
    weights = default_weights()

    bm_filter = _split_ids(args.biomarkers)
    pop_filter = _split_ids(args.populations)
    stage_filter = _split_ids(args.stages)

    # kr_iis_feasibility indexes populations by id. Without this check, an
    # unknown id would surface as a bare KeyError deep inside scoring.
    if pop_filter:
        known_pops = {p["id"] for p in onto["populations"]}
        unknown = [p for p in pop_filter if p not in known_pops]
        if unknown:
            print(f"[error] unknown population id(s): {', '.join(unknown)}", file=sys.stderr)
            return 2

    print(f"[info] Enumerating cells (max_combo={args.max_combo})...", file=sys.stderr)
    cells: list[Cell] = [
        score_cell(combo, stage, bm, pop, faers, cohorts, onto, weights)
        for combo, stage, bm, pop in enumerate_cells(
            onto, max_combo_size=args.max_combo,
            biomarkers_subset=bm_filter,
            populations_subset=pop_filter,
            stages_subset=stage_filter,
        )
    ]

    cells.sort(key=lambda c: c.score, reverse=True)
    # Clamp so a negative --top can never slice from the end of the list.
    top = cells[: max(0, args.top)]
    print(f"[info] Scored {len(cells)} cells; emitting top {len(top)}", file=sys.stderr)

    OUTPUTS_DIR.mkdir(parents=True, exist_ok=True)
    out_path = Path(args.output) if args.output else OUTPUTS_DIR / f"run_top{args.top}.md"
    out_path.parent.mkdir(parents=True, exist_ok=True)
    report = _render_report(len(cells), top, args.max_combo, weights, drug_lookup,
                            include_drafts=not args.no_grant)
    out_path.write_text(report, encoding="utf-8")
    print(f"[ok] wrote {out_path}", file=sys.stderr)

    if args.csv:
        csv_path = Path(args.csv)
        _write_ranked_csv(csv_path, cells)
        print(f"[ok] wrote {csv_path}", file=sys.stderr)

    # Quick stdout summary.
    print("\nTop 5 unexplored DKD cells (synthetic, research-use only):\n")
    for rank, c in enumerate(top[:5], 1):
        names = "+".join(drug_lookup[drug_id].name for drug_id in c.drug_combo)
        print(f"  #{rank} score={c.score} | {names} | {c.stage} | {c.biomarker} | {c.population}")
    print(f"\nFull report: {out_path}")
    return 0


if __name__ == "__main__":
    # sys.exit raises SystemExit with main()'s return value as the exit code.
    sys.exit(main())
