#!/usr/bin/env python3
"""ObesityWithdrawnRevive-Kor -- CLI

Domain: Obesity
Category: Research idea generation (hypothesis generation, literature-gap
          analysis, ontology-driven derivation).

What it does (standard-library only, fully offline):
  1. Withdrawn / discontinued anti-obesity drug registry
  2. Harm-mechanism ontology mapping (5-HT2B->valvulopathy, CB1->CNS,
     sympathetic->CV, non-selective 5-HT->PPH, ...)
  3. De-risking hypothesis generation (peripheral restriction, biased/partial
     agonism, subtype selectivity, tissue-selective delivery, vulnerable-group
     exclusion, dose redesign), each with a plausibility score and precedent
  4. Historical-harm inheritance tracking for next-gen candidates
  5. Hypothesis ranking plus a minimal validation-design draft

DISCLAIMER: Reference / research hypothesis-generation tool only.
NOT clinical decision support. NOT medical advice. Curated from public
historical facts; scored fields are SYNTHETIC DEMO content.
"""

import argparse
import collections
import json
import os
import sys

# Bundled registry file: <directory containing this script>/data/withdrawn_registry.json
DATA_PATH = os.path.join(
    os.path.dirname(os.path.abspath(__file__)),
    "data",
    "withdrawn_registry.json",
)

# Printed in the banner and used as the argparse epilog.
DISCLAIMER = (
    "Reference / research hypothesis-generation tool only. "
    "NOT clinical decision support. NOT medical advice."
)


# --------------------------------------------------------------------------- #
# Data loading
# --------------------------------------------------------------------------- #
def load_data(path=DATA_PATH):
    """Load the registry JSON at *path* and attach lookup indexes.

    Two private keys are added to the returned dict:
      ``_drug_by_id`` -- each entry of ``data["drugs"]`` keyed by its ``id``
      ``_harm_by_id`` -- each entry of ``data["harm_mechanisms"]`` keyed by ``id``

    If the file is missing, unreadable (e.g. a directory or a permission
    error), or not valid UTF-8 JSON, an error line is written to stderr and
    the process exits with status 2. This keeps CLI failures to one clear
    message instead of a traceback.
    """
    # EAFP: open directly rather than checking os.path.exists() first, which
    # was racy and let directories / bad JSON escape as raw tracebacks.
    try:
        with open(path, "r", encoding="utf-8") as fh:
            data = json.load(fh)
    except FileNotFoundError:
        sys.stderr.write("ERROR: data file not found: %s\n" % path)
        sys.exit(2)
    except (OSError, ValueError) as exc:
        # ValueError covers json.JSONDecodeError and UnicodeDecodeError.
        sys.stderr.write("ERROR: cannot read data file %s: %s\n" % (path, exc))
        sys.exit(2)
    # index helpers
    data["_drug_by_id"] = {d["id"]: d for d in data["drugs"]}
    data["_harm_by_id"] = {h["id"]: h for h in data["harm_mechanisms"]}
    return data


# --------------------------------------------------------------------------- #
# Core logic
# --------------------------------------------------------------------------- #
def harm_names(data, harm_ids):
    """Map harm ids to display names, in input order.

    Ids not present in ``data["_harm_by_id"]`` are dropped silently.
    """
    found = []
    for hid in harm_ids:
        index = data["_harm_by_id"]
        if hid in index:
            found.append(index[hid]["name"])
    return found


def build_hypotheses(data, drug):
    """Generate de-risking hypotheses for one drug record, ranked best-first.

    For every lever key in ``drug["de_risking_levers"]`` that exists in
    ``data["de_risking_levers"]``, one hypothesis dict is produced. Unknown
    lever keys are skipped. Its score is::

        plausibility * (0.5 + 0.5 * relevance) * grade_weight

    rounded to 3 decimals, where

    * ``relevance`` is the fraction of the drug's ``harm_nodes`` listed in the
      lever's ``best_for``, floored at 0.25 (the value also used when the
      lever targets none of them);
    * ``grade_weight`` comes from the drug's ``evidence_grade``, or 0.75 when
      the grade is missing or unrecognised.

    Returns the list sorted by ``score`` descending (ties keep lever order).
    """
    levers = data["de_risking_levers"]
    out = []
    grade_w = {"established (historical)": 1.0,
               "reported (historical)": 0.85,
               "illustrative (synthetic demo)": 0.6}
    g = grade_w.get(drug.get("evidence_grade", ""), 0.75)
    for lk in drug.get("de_risking_levers", []):
        lv = levers.get(lk)
        if not lv:
            continue
        harm_nodes = drug["harm_nodes"]
        best_for = lv.get("best_for", [])
        # relevance: how many of the drug's harm nodes this lever targets
        relevant = [h for h in harm_nodes if h in best_for]
        # Floor at the "targets nothing" baseline of 0.25. Previously a lever
        # hitting 1 of >=5 harm nodes got 1/5 = 0.2 and so ranked *below* a
        # lever hitting none. Taking the max keeps the score monotonic in
        # relevance and leaves every fraction >= 0.25 unchanged.
        relevance = max(0.25, len(relevant) / max(1, len(harm_nodes)))
        score = round(lv["plausibility"] * (0.5 + 0.5 * relevance) * g, 3)
        out.append({
            "drug": drug["name"],
            "drug_id": drug["id"],
            "lever_key": lk,
            "lever": lv["name"],
            "targets_kept": ", ".join(drug["targets"]),
            # Fall back to all of the drug's harms when the lever targets none.
            "harm_addressed": harm_names(data, relevant) or harm_names(data, harm_nodes),
            "description": lv["description"],
            "plausibility": lv["plausibility"],
            "precedent": lv["precedent"],
            "validation_design": lv["validation_design"],
            "score": score,
        })
    out.sort(key=lambda x: x["score"], reverse=True)
    return out


def all_hypotheses(data):
    """Pool the hypotheses of every registry drug and rank them together.

    Drugs are visited in registry order, so tied scores keep that order
    (the sort is stable). Returns a new list, highest score first.
    """
    pooled = [hyp
              for entry in data["drugs"]
              for hyp in build_hypotheses(data, entry)]
    return sorted(pooled, key=lambda hyp: hyp["score"], reverse=True)


def inheritance_risk(data, query):
    """Flag historical harm classes that a next-gen candidate may inherit.

    *query* is a free-text descriptor: class/target terms plus, optionally,
    whether the candidate is central or peripherally restricted. A class
    matches when any of its ``match_terms`` occurs as a case-insensitive
    substring of the query.

    Risk level per matched class:
      * harm not requiring central action -> HIGH (mechanism-intrinsic);
      * centrally-mediated harm -> MITIGATED if the query claims peripheral
        restriction, HIGH if it states central action, otherwise FLAG.
        A peripheral claim takes precedence when both hints appear.

    Returns a list of dicts (class, risk_level, harm_nodes, ancestor_drugs,
    monitoring_hypothesis) in ``data["inheritance_classes"]`` order.
    """
    q = query.lower()
    central_hint = any(t in q for t in
                       ["central", "cns", "brain-penetrant", "centrally"])
    peripheral_hint = any(t in q for t in
                          ["peripheral", "bbb-impermeant", "peripherally",
                           "gut-restricted"])
    flags = []
    for cls in data["inheritance_classes"]:
        # q is lower-cased, so each term must be lower-cased too; otherwise a
        # registry term written as e.g. "CB1" or "MetAP2" could never match.
        if not any(term.lower() in q for term in cls["match_terms"]):
            continue
        if not cls["central_required_for_harm"]:
            level = "HIGH (mechanism-intrinsic, not central-dependent)"
        elif peripheral_hint:
            # central-dependent harm gets downgraded if peripherally restricted
            level = "MITIGATED (peripheral restriction claimed -- verify Kp,uu)"
        elif central_hint:
            level = "HIGH (central action stated)"
        else:
            level = "FLAG (central action unknown -- assume risk until proven)"
        flags.append({
            "class": cls["class_key"],
            "risk_level": level,
            "harm_nodes": harm_names(data, cls["harm_nodes"]),
            "ancestor_drugs": [data["_drug_by_id"][a]["name"]
                               for a in cls["ancestor_drugs"]
                               if a in data["_drug_by_id"]],
            "monitoring_hypothesis": cls["monitoring_hypothesis"],
        })
    return flags


# --------------------------------------------------------------------------- #
# Rendering
# --------------------------------------------------------------------------- #
def banner(data):
    """Print the tool header: name, domain, category and the disclaimer."""
    meta = data["_meta"]
    heavy_rule = "=" * 72
    print(heavy_rule)
    print("ObesityWithdrawnRevive-Kor  |  Domain: %s" % meta["domain"])
    print(meta["category"])
    print("-" * 72)
    print("!! " + DISCLAIMER)
    print(heavy_rule)


def cmd_list(data):
    """Print the banner, then one indented stanza per registry entry."""
    banner(data)
    registry = data["drugs"]
    print("\nWITHDRAWN / DISCONTINUED ANTI-OBESITY REGISTRY (%d entries)\n"
          % len(registry))
    for entry in registry:
        print("- %-26s [%s]" % (entry["name"], entry["class"]))
        print("    status   : %s (%s)" % (entry["status"], entry["event_year"]))
        print("    targets  : %s" % ", ".join(entry["targets"]))
        print("    reason   : %s" % entry["withdrawal_reason"])
        harm_labels = harm_names(data, entry["harm_nodes"])
        print("    harm     : %s" % "; ".join(harm_labels))
        print("    grade    : %s" % entry["evidence_grade"])
        print()


def cmd_summary(data):
    """Print registry/ontology counts plus status and harm-frequency tables.

    Both tables are ordered by count descending; equal counts keep the
    order in which they were first encountered in ``data["drugs"]``.
    A harm id referenced by a drug but missing from ``harm_mechanisms``
    is shown by its raw id rather than aborting with KeyError.
    """
    banner(data)
    drugs = data["drugs"]
    by_status = collections.Counter(d["status"] for d in drugs)
    by_harm = collections.Counter(h for d in drugs for h in d["harm_nodes"])
    print("\nSUMMARY")
    print("  total drugs/candidates : %d" % len(drugs))
    print("  harm mechanisms        : %d" % len(data["harm_mechanisms"]))
    print("  de-risking levers      : %d" % len(data["de_risking_levers"]))
    print("  inheritance classes    : %d" % len(data["inheritance_classes"]))
    print("\n  by status:")
    # most_common() orders ties by first occurrence, matching the previous
    # stable sort over insertion-ordered dicts.
    for status, count in by_status.most_common():
        print("    %2d  %s" % (count, status))
    print("\n  harm-mechanism frequency:")
    harm_index = data["_harm_by_id"]
    for hid, count in by_harm.most_common():
        harm = harm_index.get(hid)
        label = harm["name"] if harm is not None else "%s (unknown harm id)" % hid
        print("    %2d  %s" % (count, label))
    print("\n  open sources: %s" % "; ".join(data["_meta"]["open_sources"]))


def print_hypotheses(hyps, top):
    """Print up to *top* hypotheses from *hyps* (assumed already ranked).

    Args:
        hyps: list of hypothesis dicts with the keys ``score``, ``lever``,
            ``drug``, ``targets_kept``, ``harm_addressed`` (list of str),
            ``description``, ``precedent`` and ``validation_design``.
        top: maximum number of entries to show. ``None`` shows all of them.
            Values <= 0 show none. Previously a negative value sliced items
            off the *end* of the list via Python's negative indexing.
    """
    shown = hyps if top is None else hyps[:max(0, top)]
    print("\nRANKED DE-RISKING HYPOTHESES (top %d of %d)\n"
          % (len(shown), len(hyps)))
    for i, h in enumerate(shown, 1):
        print("[%2d] score=%.3f  %s  <-  %s" %
              (i, h["score"], h["lever"], h["drug"]))
        print("     keep target(s): %s" % h["targets_kept"])
        print("     remove harm   : %s" % "; ".join(h["harm_addressed"]))
        print("     hypothesis    : %s" % h["description"])
        print("     precedent     : %s" % h["precedent"])
        print("     validate      : %s" % h["validation_design"])
        print("")


def cmd_top(data, top):
    """Print the banner and the *top* best hypotheses across all drugs."""
    banner(data)
    ranked = all_hypotheses(data)
    print_hypotheses(ranked, top)


def _find_drug(data, name):
    """Return the registry entry that best matches *name*, or None.

    Matching is case-insensitive. An exact id or name match always wins over
    a substring-of-name match, so a query naming one drug cannot resolve to
    an earlier entry whose name merely contains it. A blank query matches
    nothing, since the empty string is a substring of every name.
    """
    key = name.strip().lower()
    if not key:
        return None
    drugs = data["drugs"]
    for d in drugs:
        if key == str(d["id"]).lower() or key == d["name"].lower():
            return d
    for d in drugs:
        if key in d["name"].lower():
            return d
    return None


def cmd_drug(data, name, top):
    """Print one drug's registry details followed by its ranked hypotheses.

    *name* may be the drug id or (part of) its name; the lookup rules are in
    ``_find_drug``. When nothing matches, a hint is printed and the function
    returns. *top* caps how many hypotheses are shown.
    """
    banner(data)
    match = _find_drug(data, name)
    if match is None:
        print("\nNo drug matching '%s'. Use --list to see entries." % name)
        return
    print("\nDRUG: %s [%s]" % (match["name"], match["class"]))
    print("  status : %s (%s)" % (match["status"], match["event_year"]))
    print("  targets: %s" % ", ".join(match["targets"]))
    print("  reason : %s" % match["withdrawal_reason"])
    print("  harm   : %s" % "; ".join(harm_names(data, match["harm_nodes"])))
    print("  central-acting: %s" % match.get("central_acting"))
    print_hypotheses(build_hypotheses(data, match), top)


def cmd_inherit(data, query):
    """Print inherited historical-harm flags for a candidate descriptor.

    When no class matches, prints a tip listing the kind of terms to include.
    """
    banner(data)
    print('\nHISTORICAL-HARM INHERITANCE QUERY: "%s"\n' % query)
    matches = inheritance_risk(data, query)
    if matches:
        for flag in matches:
            print("  CLASS %s  ->  inherited risk: %s"
                  % (flag["class"], flag["risk_level"]))
            print("    harm nodes : %s" % "; ".join(flag["harm_nodes"]))
            print("    ancestors  : %s" % ", ".join(flag["ancestor_drugs"]))
            print("    monitoring : %s" % flag["monitoring_hypothesis"])
            print("")
        return
    tip_lines = (
        "  No historical harm class matched the descriptor.",
        "  Tip: include class/target terms like 'CB1', '5-HT2B',",
        "       'sympathomimetic', 'MetAP2', 'uncoupler', and whether",
        "       the candidate is 'central' or 'peripherally restricted'.",
    )
    for line in tip_lines:
        print(line)


# --------------------------------------------------------------------------- #
# CLI
# --------------------------------------------------------------------------- #
def build_parser():
    """Build the CLI argument parser; action flags may be combined."""
    parser = argparse.ArgumentParser(
        prog="main.py",
        description=("ObesityWithdrawnRevive-Kor: harm-mechanism ontology + "
                     "de-risking hypothesis generator for withdrawn / "
                     "discontinued anti-obesity drugs. Offline, stdlib-only."),
        epilog=DISCLAIMER,
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )
    # (flag, add_argument kwargs) in the order they appear in --help.
    option_specs = [
        ("--list", {"action": "store_true",
                    "help": "list the withdrawn/discontinued drug registry"}),
        ("--summary", {"action": "store_true",
                       "help": "show registry + ontology summary statistics"}),
        ("--top", {"type": int, "metavar": "N",
                   "help": "show top N ranked de-risking hypotheses (all drugs)"}),
        ("--drug", {"type": str, "metavar": "NAME",
                    "help": "show one drug's harm map + ranked de-risking hypotheses"}),
        ("--inherit-risk", {"type": str, "metavar": "CLASS/TARGET",
                            "help": ("free-text next-gen candidate descriptor "
                                     "(e.g. 'central CB1 inverse agonist' or "
                                     "'peripherally restricted CB1 agent'); returns "
                                     "inherited historical-harm flags")}),
    ]
    for flag, spec in option_specs:
        parser.add_argument(flag, **spec)
    return parser


def main(argv=None):
    """CLI entry point; returns 0 once the requested output is printed.

    Every requested action runs, always in the order list, summary, top,
    drug, inherit-risk. With no action flags, it prints the summary, the top 5
    hypotheses and a usage hint instead.
    """
    args = build_parser().parse_args(argv)
    data = load_data()

    # (was it requested?, action) pairs, executed in this order.
    requested = (
        (args.list, lambda: cmd_list(data)),
        (args.summary, lambda: cmd_summary(data)),
        (args.top is not None, lambda: cmd_top(data, max(1, args.top))),
        (args.drug, lambda: cmd_drug(data, args.drug, top=6)),
        (args.inherit_risk, lambda: cmd_inherit(data, args.inherit_risk)),
    )
    ran_something = False
    for wanted, action in requested:
        if wanted:
            action()
            ran_something = True

    if ran_something:
        return 0
    # Bare invocation: summary + top 5 so `python3 main.py` is useful.
    cmd_summary(data)
    print_hypotheses(all_hypotheses(data), 5)
    print("\n(Use --help for options: --list --summary --top N --drug NAME "
          "--inherit-risk \"...\")")
    return 0


if __name__ == "__main__":
    # Raising SystemExit with main()'s status is exactly what sys.exit() does.
    raise SystemExit(main())
