#!/usr/bin/env python3
"""
AdipoNeuroMap - ex vivo 3D adipose depot quantification.

Quantifies, from a (synthetic or user) cleared adipose depot volume:
  (1) TH+ sympathetic axon network: density (mm/mm^3), branch points, mean segment length
  (2) adipocyte 3D size distribution (equivalent-sphere diameter histogram)
  (3) crown-like structures (CLS): detection, density, spatial clustering index
  (4) neuro-adipocyte proximity: fraction of adipocytes innervated within a contact radius
  (5) depot/condition comparison against built-in reference profiles

Domain   : Obesity
Category : animal-experiment tool (ex vivo 3D image quantification)

DISCLAIMER: Research / reference use only - NOT a clinical diagnostic or
decision-making tool. Synthetic demo data are procedurally generated and do
not represent any real animal or patient.

Hard dependencies : numpy only.
Optional (auto-detected, graceful fallback): scipy, scikit-image, skan, tifffile.
Entry point       : python3 main.py   (always runs; no network access).
"""

import argparse
import json
import math
import os
import sys

import numpy as np

# ----------------------------------------------------------------------------
# Optional dependency detection (never hard-fail; always provide a fallback)
# ----------------------------------------------------------------------------
# Directory containing this script; reference files (reference_params.json,
# depot_profiles.csv) are looked up in its "data" sub-directory.
HERE = os.path.dirname(os.path.abspath(__file__))
DATA_DIR = os.path.join(HERE, "data")

try:
    from scipy import ndimage as ndi
    HAVE_SCIPY = True
except Exception:
    ndi = None
    HAVE_SCIPY = False

try:
    import skimage  # noqa: F401
    try:
        from skimage.morphology import skeletonize_3d as _sk_skeletonize_3d
    except ImportError:
        # skeletonize_3d was deprecated in scikit-image 0.23 and has since
        # been removed; the generic skeletonize() accepts 3D input (Lee's
        # method). Without this fallback a modern scikit-image install would
        # be reported as missing and the crude fallback skeleton used.
        from skimage.morphology import skeletonize as _sk_skeletonize_3d
    HAVE_SKIMAGE = True
except Exception:
    _sk_skeletonize_3d = None
    HAVE_SKIMAGE = False

try:
    import skan  # noqa: F401
    HAVE_SKAN = True
except Exception:
    HAVE_SKAN = False

try:
    import tifffile  # noqa: F401
    HAVE_TIFFFILE = True
except Exception:
    HAVE_TIFFFILE = False

# Short research-use-only notice printed in the report banner.
DISCLAIMER = ("Research / reference use only - NOT a clinical diagnostic "
              "or decision-making tool.")

# Isotropic voxel edge length (microns) of the synthetic demo volume.
DEMO_VOXEL_UM = 5.0


# ============================================================================
# Reference data
# ============================================================================
def load_reference(path=None):
    """Load the reference-parameter JSON used to tune the analysis.

    Args:
        path: JSON file to read. Defaults to
            ``DATA_DIR/reference_params.json``.

    Returns:
        The parsed mapping. Returns ``{}`` if the file is missing,
        unreadable, or not valid JSON, or if its top level is not a JSON
        object. Callers treat the result as a mapping (``ref.get(...)``), so
        a list or scalar payload must not be passed through.
    """
    if path is None:
        path = os.path.join(DATA_DIR, "reference_params.json")
    try:
        with open(path, "r", encoding="utf-8") as f:
            data = json.load(f)
    except (OSError, ValueError):
        # Best effort: a missing or corrupt reference file means
        # "use the built-in defaults" (JSONDecodeError and
        # UnicodeDecodeError are both ValueError subclasses).
        return {}
    return data if isinstance(data, dict) else {}


def load_depot_profiles(path=None):
    """Read the per-depot reference profiles CSV (no pandas dependency).

    Uses the standard-library ``csv`` module, so quoted fields (for example,
    a free-text ``note`` that contains commas) and CRLF line endings parse
    correctly. A naive ``line.split(",")`` would shift every column after
    such a comma and leave a trailing ``\\r`` in the last field.

    Args:
        path: CSV file to read. Defaults to ``DATA_DIR/depot_profiles.csv``.

    Returns:
        One ``{column: value}`` dict (string values) per non-blank data
        row, keyed by the header line. Returns ``[]`` if the file is
        missing, unreadable, or malformed. This is best effort: the
        reference comparison in the report is optional.
    """
    import csv

    if path is None:
        path = os.path.join(DATA_DIR, "depot_profiles.csv")
    try:
        with open(path, "r", newline="", encoding="utf-8") as f:
            # Skip whitespace-only lines, as the original reader did.
            reader = csv.DictReader(ln for ln in f if ln.strip())
            return list(reader)
    except (OSError, ValueError, csv.Error):
        return []


# ============================================================================
# Synthetic demo volume generator
# ============================================================================
def generate_demo_volume(depot="iWAT", shape=(60, 120, 120), seed=42,
                         voxel_um=DEMO_VOXEL_UM, verbose=False):
    """Procedurally build a 4-channel adipose depot volume.

    Channels returned in a dict of 3D float arrays (z, y, x):
      'th'        : TH+ sympathetic axon network (tubular branching tree)
      'membrane'  : adipocyte membrane / perilipin signal (packed cell shells)
      'f480'      : F4/80+ macrophages (CLS rings + scattered)
      'auto'      : broadband autofluorescence background (returned as its
                    own channel; it is NOT added to the others here)

    Depot tuning changes axon density, adipocyte size and CLS load so the
    downstream metrics differ meaningfully between iWAT / eWAT / BAT.

    Args:
        depot: "BAT", "iWAT" or "eWAT". Any other value uses a default
            parameter set (iWAT-like cells, 5 axon seeds).
        shape: (Z, Y, X) volume size in voxels.
        seed: base RNG seed. It is combined with a fixed per-depot offset,
            so the same (depot, seed) pair always yields the same volume.
        voxel_um: isotropic voxel edge length in microns. It is used to
            convert the micron-valued adipocyte radii to voxels.
        verbose: when True, print a one-line summary of what was generated.

    Returns:
        ``(channels, ground_truth)``. ``ground_truth`` holds:
          - ``adipo_centers``: (z, y, x) voxel coordinates of placed cells.
          - ``adipo_radii``: cell radii in voxels.
          - ``adipo_labels``: int32 label volume of cell interiors.
          - ``cls_centers``: centres of the adipocytes given a macrophage
            crown.
          - ``voxel_um``, ``shape`` and ``depot``.

    Note: the volume produced for a given seed depends on the exact order
    of the ``rng`` calls below; reordering them changes the output.
    """
    # deterministic per-depot offset (avoid PYTHONHASHSEED-dependent hash())
    depot_offset = {"BAT": 101, "iWAT": 202, "eWAT": 303}.get(depot, 404)
    rng = np.random.default_rng(seed + depot_offset)
    Z, Y, X = shape

    # Depot-specific procedural parameters --------------------------------
    #   adipo_r_um   : (min, max) adipocyte radius in microns
    #   n_axon_seeds : number of independent root axons
    #   axon_steps   : random-walk steps of a root axon (each branch gets half)
    #   branch_p     : per-step branching probability (x0.7 per branch level)
    #   n_cls        : number of adipocytes surrounded by a macrophage crown
    #   pack_jitter  : position jitter as a fraction of half the grid spacing
    cfg = {
        "BAT":  dict(adipo_r_um=(15, 22), n_axon_seeds=13, axon_steps=80,
                     branch_p=0.06, n_cls=2, pack_jitter=0.85),
        "iWAT": dict(adipo_r_um=(30, 42), n_axon_seeds=7,  axon_steps=60,
                     branch_p=0.03, n_cls=5, pack_jitter=0.75),
        "eWAT": dict(adipo_r_um=(48, 62), n_axon_seeds=3,  axon_steps=45,
                     branch_p=0.015, n_cls=9, pack_jitter=0.65),
    }.get(depot, None)
    if cfg is None:
        cfg = dict(adipo_r_um=(30, 42), n_axon_seeds=5, axon_steps=60,
                   branch_p=0.03, n_cls=5, pack_jitter=0.75)

    th = np.zeros(shape, dtype=np.float32)
    membrane = np.zeros(shape, dtype=np.float32)
    f480 = np.zeros(shape, dtype=np.float32)

    # --- Adipocytes: place non-overlapping-ish spheres on a jittered grid ---
    # radius bounds converted from microns to voxels
    r_lo = cfg["adipo_r_um"][0] / voxel_um
    r_hi = cfg["adipo_r_um"][1] / voxel_um
    step = int(max(3, (r_lo + r_hi)))            # grid spacing in voxels
    centers = []
    radii = []
    label_vol = np.zeros(shape, dtype=np.int32)  # ground-truth adipocyte labels
    next_label = 1
    # full-volume coordinate grids, reused per cell and for the haze below
    zz_g, yy_g, xx_g = np.mgrid[0:Z, 0:Y, 0:X]
    for cz in range(step // 2, Z, step):
        for cy in range(step // 2, Y, step):
            for cx in range(step // 2, X, step):
                jit = cfg["pack_jitter"] * step * 0.5
                pz = cz + rng.uniform(-jit, jit)
                py = cy + rng.uniform(-jit, jit)
                px = cx + rng.uniform(-jit, jit)
                r = rng.uniform(r_lo, r_hi)
                # skip cells that would extend past the volume border
                if not (r < pz < Z - r and r < py < Y - r and r < px < X - r):
                    continue
                # NOTE: d2 is evaluated over the whole volume for every cell,
                # so cost scales with (number of cells x number of voxels).
                d2 = (zz_g - pz) ** 2 + (yy_g - py) ** 2 + (xx_g - px) ** 2
                # ~1.4-voxel-thick shell; overlapping shells accumulate (+=)
                shell = (d2 < r ** 2) & (d2 > (r - 1.4) ** 2)
                membrane[shell] += 1.0
                interior = d2 < (r - 1.4) ** 2
                # only label voxels not already taken (cheap collision avoid)
                free = interior & (label_vol == 0)
                if free.sum() > 0:
                    label_vol[free] = next_label
                    centers.append((pz, py, px))
                    radii.append(r)
                    next_label += 1

    # --- TH+ sympathetic axon network: random-walk branching tubes ---------
    def stamp_segment(p0, p1, rad=1.2):
        # Paint a tube of radius `rad` voxels into `th` (closure) by stamping
        # spheres at ~segment-length (min 2) evenly spaced points p0 -> p1.
        n = int(max(2, np.linalg.norm(np.array(p1) - np.array(p0))))
        for t in np.linspace(0, 1, n):
            cz, cy, cx = (np.array(p0) * (1 - t) + np.array(p1) * t)
            z0, z1 = int(cz - rad - 1), int(cz + rad + 2)
            y0, y1 = int(cy - rad - 1), int(cy + rad + 2)
            x0, x1 = int(cx - rad - 1), int(cx + rad + 2)
            z0, y0, x0 = max(0, z0), max(0, y0), max(0, x0)
            z1, y1, x1 = min(Z, z1), min(Y, y1), min(X, x1)
            if z1 <= z0 or y1 <= y0 or x1 <= x0:
                continue
            sub = (np.mgrid[z0:z1, y0:y1, x0:x1].astype(np.float32))
            d2 = ((sub[0] - cz) ** 2 + (sub[1] - cy) ** 2 + (sub[2] - cx) ** 2)
            th[z0:z1, y0:y1, x0:x1][d2 < rad ** 2] = 1.0

    def grow_axon(start, direction, steps, branch_p, depth=0):
        # Random-walk one axon from `start`. Step length is 2-4 voxels and
        # the heading is perturbed each step; positions are clipped one
        # voxel inside the border. Branches recurse with half the steps
        # and 0.7x the branching probability; depths 0..3 are grown.
        if depth > 3 or steps <= 0:
            return
        p = np.array(start, dtype=np.float32)
        d = np.array(direction, dtype=np.float32)
        d = d / (np.linalg.norm(d) + 1e-9)
        for _ in range(steps):
            nxt = p + d * rng.uniform(2.0, 4.0)
            nxt = np.clip(nxt, [1, 1, 1], [Z - 2, Y - 2, X - 2])
            stamp_segment(p, nxt)
            p = nxt
            d = d + rng.normal(0, 0.35, 3)
            d = d / (np.linalg.norm(d) + 1e-9)
            if rng.random() < branch_p:
                bd = d + rng.normal(0, 0.9, 3)
                grow_axon(p, bd, int(steps * 0.5), branch_p * 0.7, depth + 1)

    for _ in range(cfg["n_axon_seeds"]):
        start = (rng.uniform(2, Z - 2), rng.uniform(2, Y - 2),
                 rng.uniform(2, X - 2))
        grow_axon(start, rng.normal(0, 1, 3), cfg["axon_steps"],
                  cfg["branch_p"])

    # --- CLS: pick adipocytes, surround with a ring of F4/80+ macrophages ---
    cls_centers = []
    if centers:
        n_cls = min(cfg["n_cls"], len(centers))
        chosen = rng.choice(len(centers), size=n_cls, replace=False)
        for idx in chosen:
            cz, cy, cx = centers[idx]
            r = radii[idx]
            # 6..11 macrophages per crown (upper bound is exclusive)
            n_mac = rng.integers(6, 12)
            for _ in range(int(n_mac)):
                # macrophage sits just outside the adipocyte shell
                u = rng.normal(0, 1, 3)
                u = u / (np.linalg.norm(u) + 1e-9)
                mp = np.array([cz, cy, cx]) + u * (r + 1.5)
                mz, my, mx = mp
                if not (1 < mz < Z - 1 and 1 < my < Y - 1 and 1 < mx < X - 1):
                    continue
                mr = 1.3
                z0, z1 = int(mz - mr - 1), int(mz + mr + 2)
                y0, y1 = int(my - mr - 1), int(my + mr + 2)
                x0, x1 = int(mx - mr - 1), int(mx + mr + 2)
                z0, y0, x0 = max(0, z0), max(0, y0), max(0, x0)
                z1, y1, x1 = min(Z, z1), min(Y, y1), min(X, x1)
                if z1 <= z0 or y1 <= y0 or x1 <= x0:
                    continue
                sub = np.mgrid[z0:z1, y0:y1, x0:x1].astype(np.float32)
                d2 = ((sub[0] - mz) ** 2 + (sub[1] - my) ** 2
                      + (sub[2] - mx) ** 2)
                f480[z0:z1, y0:y1, x0:x1][d2 < mr ** 2] = 1.0
            cls_centers.append((cz, cy, cx))

    # scattered (non-CLS) macrophages: single bright voxels, count ~15% of
    # the number of placed adipocytes (labels are 1..n contiguous) plus 5
    for _ in range(int(label_vol.max() * 0.15) + 5):
        mz, my, mx = (rng.uniform(1, Z - 1), rng.uniform(1, Y - 1),
                      rng.uniform(1, X - 1))
        f480[int(mz), int(my), int(mx)] = 1.0

    # --- Autofluorescence background + noise -------------------------------
    # uniform 0.05-0.18 voxel noise ...
    auto = rng.uniform(0.05, 0.18, shape).astype(np.float32)
    # smooth low-frequency haze
    auto += 0.12 * np.sin(zz_g / 9.0) * np.cos(yy_g / 11.0)
    auto = np.clip(auto, 0, None).astype(np.float32)

    ground_truth = dict(
        adipo_centers=centers, adipo_radii=radii, adipo_labels=label_vol,
        cls_centers=cls_centers, voxel_um=voxel_um, shape=shape, depot=depot,
    )
    channels = dict(th=th, membrane=membrane, f480=f480, auto=auto)
    if verbose:
        print(f"  [demo] {depot}: {len(centers)} adipocytes, "
              f"{len(cls_centers)} CLS, axon voxels={int(th.sum())}")
    return channels, ground_truth


# ============================================================================
# Preprocessing: autofluorescence background suppression
# ============================================================================
def suppress_autofluorescence(img, ref):
    """Remove slowly varying background and rescale the result to [0, 1].

    With scipy, a per-slice grey opening (structuring element
    ``(1, size, size)``) is subtracted from the image, i.e. a white top-hat.
    If that call fails, or scipy is unavailable, a single global percentile
    of the image is subtracted instead. Negative values are clipped to 0 and
    the result is divided by its maximum when that maximum is positive.

    Args:
        img: 3D intensity array (cast to float32).
        ref: reference-parameter mapping, possibly empty or None. Reads
            ``ref["autofluorescence"]["background_percentile"]`` (default 25)
            and ``["tophat_radius_voxels"]`` (default 12).

    Returns:
        float32 array with the same shape as ``img``, values in [0, 1].
    """
    settings = (ref or {}).get("autofluorescence", {})
    percentile = float(settings.get("background_percentile", 25.0))
    tophat_radius = int(settings.get("tophat_radius_voxels", 12))
    data = img.astype(np.float32)

    def _minus_percentile():
        return data - np.percentile(data, percentile)

    if not HAVE_SCIPY:
        residual = _minus_percentile()
    else:
        # Opening footprint is a third of the configured radius (floor 3)
        # to keep the filter cheap; applied within each z-slice.
        footprint = max(3, tophat_radius // 3)
        try:
            background = ndi.grey_opening(data, size=(1, footprint, footprint))
            residual = data - background
        except Exception:
            residual = _minus_percentile()

    residual = np.clip(residual, 0, None)
    peak = residual.max()
    return residual / peak if peak > 0 else residual


# ============================================================================
# Connected-components labeling (scipy if present, else numpy BFS fallback)
# ============================================================================
def label_components(mask):
    """Label the 6-connected foreground components of a 3D boolean mask.

    Delegates to ``scipy.ndimage.label`` when scipy is importable. Otherwise
    it runs an explicit depth-first flood fill, seeding components in
    ``np.argwhere`` order.

    Returns:
        ``(labels, n)``. ``labels`` is an int32 array shaped like ``mask``,
        with 0 for background and 1..n for components. ``n`` is the
        component count.
    """
    if HAVE_SCIPY:
        labelled, count = ndi.label(mask)
        return labelled.astype(np.int32), int(count)

    # Pure-Python fallback (no scipy).
    out = np.zeros(mask.shape, dtype=np.int32)
    depth, height, width = mask.shape
    offsets = ((-1, 0, 0), (1, 0, 0), (0, -1, 0),
               (0, 1, 0), (0, 0, -1), (0, 0, 1))
    count = 0
    for seed in np.argwhere(mask):
        seed = tuple(seed)
        if out[seed]:
            continue  # already reached from an earlier seed
        count += 1
        out[seed] = count
        pending = [seed]
        while pending:
            z, y, x = pending.pop()
            for dz, dy, dx in offsets:
                nz, ny, nx = z + dz, y + dy, x + dx
                inside = (0 <= nz < depth and 0 <= ny < height
                          and 0 <= nx < width)
                if inside and mask[nz, ny, nx] and not out[nz, ny, nx]:
                    out[nz, ny, nx] = count
                    pending.append((nz, ny, nx))
    return out, count


# ============================================================================
# 3D skeletonization (skimage if present, else iterative numpy thinning)
# ============================================================================
def skeletonize_volume(mask):
    """Return a boolean skeleton (thin centre-line proxy) of a 3D binary mask.

    Backends are tried in order:
      1. scikit-image 3D skeletonization, when importable and it succeeds.
      2. scipy: the ridge of the Euclidean distance transform. Keeps voxels
         whose distance equals the 3x3x3 local maximum (within 1e-6) and
         is >= 1.
      3. numpy only: keeps mask voxels with at least 3 of their 6
         face-neighbours also in the mask (a crude interior-ridge filter).

    Non-zero input values are treated as foreground.
    """
    mask = np.asarray(mask, dtype=bool)
    if HAVE_SKIMAGE:
        try:
            return _sk_skeletonize_3d(mask.astype(np.uint8)) > 0
        except Exception:
            pass  # fall through to the approximations below
    if HAVE_SCIPY:
        dt = ndi.distance_transform_edt(mask)
        local_max = ndi.maximum_filter(dt, size=3)
        # dt >= 1 already implies dt > 0 (inside the mask)
        return (dt >= 1.0) & (dt >= local_max - 1e-6)
    return mask & (_face_neighbour_count(mask) >= 3)


def _face_neighbour_count(mask):
    """Count, for every voxel, how many of its 6 face-neighbours are set.

    A zero-padded copy is sliced instead of using ``np.roll``. This prevents
    voxels on opposite faces of the volume (and, along size-1 axes, a voxel
    and itself) from being counted as neighbours.
    """
    mask = np.asarray(mask, dtype=bool)
    padded = np.pad(mask.astype(np.int16), 1)
    core = [slice(1, n + 1) for n in mask.shape]
    count = np.zeros(mask.shape, dtype=np.int16)
    for axis, n in enumerate(mask.shape):
        for shift in (-1, 1):
            window = list(core)
            window[axis] = slice(1 + shift, 1 + shift + n)
            count += padded[tuple(window)]
    return count


def analyze_skeleton(skel, voxel_um):
    """Quantify a 3D skeleton: total length, branch points and segments.

    Branch points (>= 3 neighbours) and end points (exactly 1 neighbour) are
    always counted with a 26-neighbour degree computed in numpy. Total and
    mean segment length come from skan's skeleton graph when skan is
    available. Otherwise they use a voxel-count estimate: voxel count x
    voxel size x 1.10 tortuosity factor, with the segment count derived
    from branch and end points.

    Args:
        skel: 3D skeleton mask; non-zero values are skeleton voxels.
        voxel_um: isotropic voxel size in microns.

    Returns:
        dict with ``length_um``, ``branch_points``, ``end_points``,
        ``n_segments``, ``mean_segment_um`` and the ``method`` used.
    """
    skel = np.asarray(skel, dtype=bool)
    n_vox = int(skel.sum())
    if n_vox == 0:
        return dict(length_um=0.0, branch_points=0, end_points=0,
                    n_segments=0, mean_segment_um=0.0, method="empty")

    # degree is 0 off the skeleton, so no extra masking is needed below
    deg = _skeleton_degree(skel)
    branch_points = int((deg >= 3).sum())
    end_points = int((deg == 1).sum())

    if HAVE_SKAN:
        try:
            from skan import Skeleton, summarize
            sk = Skeleton(skel, spacing=voxel_um)
            summ = summarize(sk, separator="-")
            length_um = float(summ["branch-distance"].sum())
            n_segments = int(len(summ))
            mean_seg = float(length_um / n_segments) if n_segments else 0.0
            return dict(length_um=length_um, branch_points=branch_points,
                        end_points=end_points, n_segments=n_segments,
                        mean_segment_um=mean_seg, method="skan")
        except Exception:
            pass  # fall back to the numpy estimate

    # fallback length estimate: each skeleton voxel ~ one voxel of path,
    # with a 1.1 tortuosity correction for diagonal continuity.
    length_um = n_vox * voxel_um * 1.10
    # segments are separated by branch points: n_segments ~ branch + end / 2
    n_segments = max(1, branch_points + max(1, end_points // 2))
    mean_seg = length_um / n_segments
    return dict(length_um=length_um, branch_points=branch_points,
                end_points=end_points, n_segments=n_segments,
                mean_segment_um=mean_seg,
                method="skan-fallback(numpy degree)")


def _skeleton_degree(skel):
    """Return each skeleton voxel's count of 26-connected skeleton neighbours.

    Off-skeleton voxels get 0. A zero-padded copy is sliced instead of
    using ``np.roll``. This prevents voxels on opposite faces of the volume
    (and, along size-1 axes, a voxel and itself) from being counted as
    neighbours, which would inflate branch-point counts at the border.
    """
    skel = np.asarray(skel, dtype=bool)
    Z, Y, X = skel.shape
    padded = np.pad(skel.astype(np.int16), 1)
    deg = np.zeros(skel.shape, dtype=np.int16)
    for dz in (-1, 0, 1):
        for dy in (-1, 0, 1):
            for dx in (-1, 0, 1):
                if dz == 0 and dy == 0 and dx == 0:
                    continue
                deg += padded[1 + dz:1 + dz + Z,
                              1 + dy:1 + dy + Y,
                              1 + dx:1 + dx + X]
    deg[~skel] = 0
    return deg


# ============================================================================
# Feature 1: sympathetic axon quantification
# ============================================================================
def quantify_axons(th_img, ref, voxel_um, shape):
    """Segment and skeletonize the TH+ channel, then report network metrics.

    Pipeline:
      1. Background suppression.
      2. Threshold at max(0.3, 60th percentile of the positive suppressed
         signal).
      3. 3D skeleton.
      4. Skeleton statistics.

    Densities are normalised by the full imaged volume ``shape`` (z, y, x),
    converted to mm^3 with the isotropic ``voxel_um``.

    Returns:
        dict of scalar metrics plus the boolean ``skeleton`` and ``mask``
        arrays and the skeleton-analysis ``method`` label.
    """
    suppressed = suppress_autofluorescence(th_img, ref)
    positive = suppressed[suppressed > 0]
    threshold = 0.3
    if positive.size:
        threshold = max(threshold, float(np.percentile(positive, 60)))
    mask = suppressed >= threshold
    skeleton = skeletonize_volume(mask)
    stats = analyze_skeleton(skeleton, voxel_um)

    def _edge_mm(n_vox):
        return n_vox * voxel_um / 1000.0

    nz, ny, nx = shape
    volume_mm3 = _edge_mm(nz) * _edge_mm(ny) * _edge_mm(nx)
    length_mm = stats["length_um"] / 1000.0
    if volume_mm3 > 0:
        density = length_mm / volume_mm3
        branch_density = stats["branch_points"] / volume_mm3
    else:
        density = branch_density = 0.0
    return {
        "axon_density_mm_per_mm3": density,
        "branch_point_density_per_mm3": branch_density,
        "branch_points": stats["branch_points"],
        "mean_segment_length_um": stats["mean_segment_um"],
        "n_segments": stats["n_segments"],
        "total_length_mm": length_mm,
        "tissue_volume_mm3": volume_mm3,
        "skeleton": skeleton,
        "mask": mask,
        "method": stats["method"],
    }


# ============================================================================
# Feature 2: adipocyte 3D size distribution
# ============================================================================
def quantify_adipocytes(membrane_img, ref, voxel_um, ground_truth=None):
    """Segment adipocytes from the membrane channel and summarise their sizes.

    Segmentation:
      1. Bright membrane/perilipin shells are thresholded at max(0.25,
         median of the positive background-suppressed signal). Everything
         else is treated as candidate cell interior.
      2. With scipy, the interior mask is hole-filled and eroded by one
         voxel.
      3. Connected components are labelled.
      4. Each component of >= 4 voxels becomes an equivalent-sphere
         diameter, kept if it lies in [8, 300] um.

    When ``ground_truth`` carries ``adipo_radii`` (synthetic demo), the
    ground-truth diameters replace the segmentation if it looks unreliable:
    fewer than max(5, half the true cell count) cells, or a median diameter
    below 60 % of the true median.

    Args:
        membrane_img: 3D membrane intensity volume (z, y, x).
        ref: reference-parameter mapping, possibly empty or None. Reads
            ``adipocyte_size_classes_um.classes`` for size-class counts.
        voxel_um: isotropic voxel size in microns.
        ground_truth: optional demo ground-truth dict (radii in voxels).

    Returns:
        dict with:
          - count and mean/median/std/min/max diameter (um; 0.0 when no
            cells were found);
          - the ``diameters_um`` array;
          - per-class counts;
          - a ``method`` label that says whether ground truth was used.
    """
    suppressed = suppress_autofluorescence(membrane_img, ref)
    # membrane shells are bright, lipid interiors dark -> invert the shells
    positive = suppressed[suppressed > 0]
    shell_thr = 0.25
    if positive.size:
        shell_thr = max(shell_thr, float(np.percentile(positive, 50)))
    interior = ~(suppressed >= shell_thr)
    if HAVE_SCIPY:
        interior = ndi.binary_fill_holes(interior)
        interior = ndi.binary_erosion(interior, iterations=1)
    labels, n_regions = label_components(interior)

    diameters = []
    if n_regions > 0:
        if HAVE_SCIPY:
            voxel_counts = ndi.sum(np.ones_like(labels), labels,
                                   index=np.arange(1, n_regions + 1))
        else:
            voxel_counts = np.bincount(labels.ravel())[1:]
        voxel_vol_um3 = voxel_um ** 3
        for count in np.atleast_1d(voxel_counts):
            if count < 4:  # discard specks
                continue
            region_vol_um3 = count * voxel_vol_um3
            sphere_d = 2.0 * (3.0 * region_vol_um3 / (4.0 * math.pi)) ** (1.0 / 3.0)
            # keep only biologically plausible sizes
            if 8.0 <= sphere_d <= 300.0:
                diameters.append(sphere_d)
    diams = np.array(diameters) if diameters else np.array([])

    # Demo-mode safety net: fall back to ground-truth radii when the
    # threshold segmentation under-detects or over-fragments the cells.
    used_gt = False
    gt_radii = ground_truth.get("adipo_radii", []) if ground_truth else []
    if gt_radii:
        gt_diams = np.array([2.0 * r * voxel_um for r in gt_radii])
        too_few = len(diams) < max(5, 0.5 * len(gt_radii))
        too_small = (len(diams) > 0
                     and np.median(diams) < 0.6 * np.median(gt_diams))
        if too_few or too_small:
            diams = gt_diams
            used_gt = True

    classes = (ref.get("adipocyte_size_classes_um", {}).get("classes", [])
               if ref else [])
    class_counts = {}
    for cdef in classes:
        in_class = (diams >= cdef["min_um"]) & (diams < cdef["max_um"])
        class_counts[cdef["label"]] = int(in_class.sum())

    if len(diams):
        mean_d, median_d = float(diams.mean()), float(np.median(diams))
        std_d = float(diams.std())
        min_d, max_d = float(diams.min()), float(diams.max())
    else:
        mean_d = median_d = std_d = min_d = max_d = 0.0

    if used_gt:
        method = "ground-truth-radii(demo, seg under-detected)"
    else:
        method = "membrane-segmentation"
    return {
        "n_adipocytes": int(len(diams)),
        "mean_diam_um": mean_d,
        "median_diam_um": median_d,
        "std_diam_um": std_d,
        "min_diam_um": min_d,
        "max_diam_um": max_d,
        "diameters_um": diams,
        "class_counts": class_counts,
        "method": method,
    }


def histogram_text(diams, bins=8, width=40):
    """Render an ASCII histogram of diameters, one line per bin.

    Bins span ``[min, max]`` evenly (widened to ``[min, min + 1]`` when all
    values are equal). Bars are scaled so the fullest bin is ``width``
    characters long.

    Args:
        diams: 1D numpy array of diameters in microns.
        bins: number of histogram bins.
        width: bar length of the fullest bin, in characters.

    Returns:
        The histogram as a newline-joined string, or a placeholder line
        when ``diams`` is empty.
    """
    if len(diams) == 0:
        return "    (no adipocytes detected)"
    low = float(diams.min())
    high = float(diams.max())
    if high <= low:
        high = low + 1.0
    edges = np.linspace(low, high, bins + 1)
    counts, _ = np.histogram(diams, bins=edges)
    tallest = counts.max() or 1  # avoid dividing by zero
    rows = []
    for left, right, n in zip(edges[:-1], edges[1:], counts):
        bar = "#" * int(round(width * n / tallest))
        rows.append(f"    {left:6.1f}-{right:6.1f} um | {bar:<{width}} {n}")
    return "\n".join(rows)


# ============================================================================
# Feature 3: CLS detection + clustering index
# ============================================================================
def quantify_cls(f480_img, ref, voxel_um, shape, adipo_seg=None):
    """Detect crown-like structures (CLS) and score their spatial clustering.

    A CLS is a ring/shell cluster of >= N F4/80+ macrophages. Detection:
      1. Background suppression.
      2. Threshold at max(0.3, 70th percentile of the positive signal).
      3. With scipy, a 2-iteration dilation so the macrophages of one crown
         merge into one component.
      4. Connected-component labelling.

    A component counts as a CLS when it has at least ``30 * min_mac``
    voxels and an RMS radius of at least 2 voxels.

    The clustering index is the Clarke-Evans ratio R: the observed mean
    nearest-neighbour distance divided by the expected distance for a 3D
    Poisson process of the same density (0.554 * density^(-1/3)).

    Args:
        f480_img: 3D F4/80 intensity volume (z, y, x).
        ref: reference-parameter mapping, possibly empty or None. Reads
            ``cls_definition.min_macrophages_in_ring`` (default 3).
        voxel_um: isotropic voxel size in microns.
        shape: (Z, Y, X) volume size used for density normalisation.
        adipo_seg: unused; kept for interface compatibility.

    Returns:
        dict with:
          - ``n_cls`` and ``cls_density_per_mm3``;
          - ``clustering_index`` (NaN when fewer than 2 CLS) and
            ``clustering_interp``;
          - ``centroids`` ((n, 3) voxel coordinates);
          - ``min_macrophages`` (the int threshold actually used).
    """
    cls_def = ref.get("cls_definition", {}) if ref else {}
    min_mac = int(cls_def.get("min_macrophages_in_ring", 3))
    pre = suppress_autofluorescence(f480_img, ref)
    thr = max(0.3, float(np.percentile(pre[pre > 0], 70))
              if (pre > 0).any() else 0.3)
    mask = pre >= thr
    # Connect nearby macrophages of a single crown into one cluster, so a CLS
    # (ring of several adjacent macrophages) becomes one connected component
    # while isolated scattered macrophages stay separate.
    if HAVE_SCIPY:
        mask = ndi.binary_dilation(mask, iterations=2)
    labels, n = label_components(mask)

    Z, Y, X = shape
    # A single macrophage (~1.3 vox radius, dilated) occupies a known voxel
    # mass; a CLS needs >= min_mac macrophages forming a ring. Require both a
    # minimum mass AND a minimum spatial spread (radius) to reject specks and
    # solitary macrophages.
    single_mac_mass = 30  # approx voxels for one dilated macrophage blob
    min_mass = single_mac_mass * min_mac
    centroids = []
    if n > 0:
        # Get all component sizes in one pass. Only components large enough
        # to be a CLS are then gathered with argwhere, instead of rescanning
        # the whole volume for every label.
        sizes = np.bincount(labels.ravel(), minlength=n + 1)
        candidates = np.flatnonzero(sizes[1:] >= min_mass) + 1
        for lab in candidates:
            pts = np.argwhere(labels == lab)
            # spatial spread: ring-like crowns are extended, not compact specks
            spread = float(np.sqrt(((pts - pts.mean(axis=0)) ** 2)
                                   .sum(axis=1).mean()))
            if spread < 2.0:
                continue
            centroids.append(pts.mean(axis=0))
    centroids = np.array(centroids) if centroids else np.zeros((0, 3))

    vol_mm3 = (Z * voxel_um / 1000.0) * (Y * voxel_um / 1000.0) * \
              (X * voxel_um / 1000.0)
    density_mm3 = len(centroids) / vol_mm3 if vol_mm3 > 0 else 0.0

    # Clarke-Evans nearest-neighbour clustering index R = mean_NN / expected_NN
    cluster_index = float("nan")
    if len(centroids) >= 2:
        nn = []
        for i in range(len(centroids)):
            d = np.linalg.norm(centroids - centroids[i], axis=1)
            d[i] = np.inf  # exclude self-distance
            nn.append(d.min())
        mean_nn_vox = float(np.mean(nn))
        density_per_vox = len(centroids) / float(Z * Y * X)
        # expected NN for random 3D Poisson: 0.554 * density^(-1/3)
        if density_per_vox > 0:
            expected = 0.554 * density_per_vox ** (-1.0 / 3.0)
            cluster_index = (mean_nn_vox / expected if expected > 0
                             else float("nan"))

    interp = "n/a"
    if not math.isnan(cluster_index):
        if cluster_index < 0.85:
            interp = "clustered (focal inflammation)"
        elif cluster_index > 1.15:
            interp = "dispersed"
        else:
            interp = "random"

    return dict(
        n_cls=int(len(centroids)),
        cls_density_per_mm3=density_mm3,
        clustering_index=cluster_index,
        clustering_interp=interp,
        centroids=centroids,
        min_macrophages=min_mac,
    )


# ============================================================================
# Feature 4: neuro-adipocyte proximity
# ============================================================================
def quantify_proximity(axon_res, adipo_res, ref, voxel_um, ground_truth):
    """Fraction of adipocytes whose centre lies within a contact radius of an axon.

    Distance runs from each ground-truth adipocyte centre to the nearest
    TH+ skeleton voxel. With scipy, it is read from the Euclidean distance
    transform of the skeleton background at the rounded (and clamped)
    centre voxel. Without scipy, it is a brute-force minimum over all
    skeleton voxels.

    NOTE(review): the contact radius is applied to the cell *centre*, not
    its surface. Adipocytes larger than the radius therefore need an axon
    inside the cell body to count. Confirm this matches the intended
    definition of ``contact_radius_um``.

    Args:
        axon_res: result of ``quantify_axons``; only ``skeleton`` is read.
        adipo_res: unused; positions come from ``ground_truth``.
        ref: reference-parameter mapping, possibly empty or None. Reads
            ``neuro_adipocyte_proximity.contact_radius_um`` (default 25).
        voxel_um: isotropic voxel size in microns.
        ground_truth: dict with ``adipo_centers`` ((z, y, x) voxel
            coordinates), or None.

    Returns:
        dict with ``innervated_fraction``, ``contact_radius_um``,
        ``n_innervated``, ``n_total`` and ``method``. ``method`` is
        "no-data" when the skeleton or centres are missing, "no-axons" when
        the skeleton is empty, and otherwise "edt" or "brute-force".
    """
    prox = ref.get("neuro_adipocyte_proximity", {}) if ref else {}
    contact_um = float(prox.get("contact_radius_um", 25.0))
    contact_vox = contact_um / voxel_um

    skel = axon_res.get("skeleton")
    centers = ground_truth.get("adipo_centers", []) if ground_truth else []
    if skel is None or len(centers) == 0:
        return dict(innervated_fraction=0.0, contact_radius_um=contact_um,
                    n_innervated=0, n_total=len(centers),
                    method="no-data")

    skel = np.asarray(skel, dtype=bool)
    if not skel.any():
        # No axon at all: nothing can be innervated. Skip the EDT, which has
        # no background voxel to measure from when the skeleton is empty.
        return dict(innervated_fraction=0.0, contact_radius_um=contact_um,
                    n_innervated=0, n_total=len(centers),
                    method="no-axons")

    if HAVE_SCIPY:
        # distance transform of the *background* of skeleton = dist to skeleton
        dt = ndi.distance_transform_edt(~skel)
        n_inn = 0
        for (cz, cy, cx) in centers:
            iz, iy, ix = (int(round(cz)), int(round(cy)), int(round(cx)))
            iz = min(max(iz, 0), skel.shape[0] - 1)
            iy = min(max(iy, 0), skel.shape[1] - 1)
            ix = min(max(ix, 0), skel.shape[2] - 1)
            if dt[iz, iy, ix] <= contact_vox:
                n_inn += 1
        method = "edt"
    else:
        skel_pts = np.argwhere(skel)
        n_inn = 0
        for (cz, cy, cx) in centers:
            d = np.linalg.norm(skel_pts - np.array([cz, cy, cx]), axis=1)
            if d.min() <= contact_vox:
                n_inn += 1
        method = "brute-force"

    # centers is non-empty here; len() also works if it is an ndarray
    frac = n_inn / len(centers)
    return dict(innervated_fraction=frac, contact_radius_um=contact_um,
                n_innervated=n_inn, n_total=len(centers), method=method)


# ============================================================================
# Full pipeline for one depot
# ============================================================================
def run_depot(depot, ref, seed=42, shape=(60, 120, 120), voxel_um=DEMO_VOXEL_UM,
              channels=None, ground_truth=None, verbose=False):
    """Run all four quantifications on one depot volume.

    If ``channels`` is None, a synthetic volume and its ground truth are
    generated for ``depot`` from ``seed``, ``shape`` and ``voxel_um``.
    Otherwise ``channels`` must provide 'th', 'membrane' and 'f480' arrays.
    The 'auto' (autofluorescence) channel is optional and treated as 0 when
    absent.

    ``ground_truth`` may be None for real data. The volume shape is then
    taken from the TH channel, and the ground-truth-based steps (adipocyte
    radius fallback, proximity centres) receive no ground-truth entries.

    Returns:
        dict with ``depot``, the per-feature result dicts ``axon``,
        ``adipo``, ``cls`` and ``prox``, plus ``voxel_um`` and ``shape``.
    """
    if channels is None:
        channels, ground_truth = generate_demo_volume(
            depot=depot, shape=shape, seed=seed, voxel_um=voxel_um,
            verbose=verbose)
    # combine autofluorescence into each signal channel (realistic)
    auto = channels.get("auto", 0.0)
    th = channels["th"] + auto
    membrane = channels["membrane"] + auto
    f480 = channels["f480"] + auto

    # Copy so the caller's dict is never mutated; only fill in the shape.
    gt = dict(ground_truth) if ground_truth else {}
    gt.setdefault("shape", tuple(np.shape(th)))
    vol_shape = gt["shape"]

    axon = quantify_axons(th, ref, voxel_um, vol_shape)
    adipo = quantify_adipocytes(membrane, ref, voxel_um, gt)
    cls = quantify_cls(f480, ref, voxel_um, vol_shape)
    prox = quantify_proximity(axon, adipo, ref, voxel_um, gt)
    return dict(depot=depot, axon=axon, adipo=adipo, cls=cls, prox=prox,
                voxel_um=voxel_um, shape=vol_shape)


# ============================================================================
# Reporting
# ============================================================================
def banner():
    """Return the multi-line report header: title, domain line, disclaimer."""
    rule = "=" * 70
    parts = [
        rule,
        "  AdipoNeuroMap - ex vivo 3D adipose depot quantification",
        "  Obesity | animal-experiment tool (3D image quantification)",
        f"  {DISCLAIMER}",
        rule,
    ]
    return "\n".join(parts)


def env_line():
    """Return a one-line summary of which optional backends were detected."""
    backends = (
        (HAVE_SCIPY, "scipy", "scipy:MISSING(fallback)"),
        (HAVE_SKIMAGE, "scikit-image", "scikit-image:MISSING(fallback)"),
        (HAVE_SKAN, "skan", "skan:optional-off"),
        (HAVE_TIFFFILE, "tifffile", "tifffile:optional-off"),
    )
    labels = [present if found else absent
              for found, present, absent in backends]
    return "  engine: numpy + [" + ", ".join(labels) + "]"


def print_depot_report(res, ref, full=True):
    """Print the human-readable report for one ``run_depot`` result.

    Args:
        res: result dict from ``run_depot`` (keys depot, axon, adipo, cls,
            prox, voxel_um, shape).
        ref: reference-parameter mapping; not read here, kept for interface
            compatibility.
        full: when True, also print the ASCII diameter histogram.

    The reference-profile comparison (from ``load_depot_profiles``) is best
    effort. A missing column or a non-numeric innervated fraction is shown
    as ``n/a`` instead of aborting the report.
    """
    d = res["depot"]
    ax, ad, cl, pr = res["axon"], res["adipo"], res["cls"], res["prox"]
    print(f"\n----- DEPOT: {d}  "
          f"(volume {res['shape']} vox @ {res['voxel_um']:.1f} um/vox = "
          f"{ax['tissue_volume_mm3']:.4f} mm^3) -----")

    print("\n[1] TH+ SYMPATHETIC AXON NETWORK  (method: %s)" % ax["method"])
    print(f"    axon density          : {ax['axon_density_mm_per_mm3']:9.1f} "
          f"mm/mm^3")
    print(f"    branch-point density  : "
          f"{ax['branch_point_density_per_mm3']:9.1f} /mm^3  "
          f"({ax['branch_points']} branch points)")
    print(f"    mean segment length   : {ax['mean_segment_length_um']:9.1f} um"
          f"  ({ax['n_segments']} segments)")
    print(f"    total axon length     : {ax['total_length_mm']:9.3f} mm")

    print("\n[2] ADIPOCYTE 3D SIZE DISTRIBUTION  (method: %s)" % ad["method"])
    print(f"    adipocytes quantified : {ad['n_adipocytes']}")
    print(f"    mean diameter         : {ad['mean_diam_um']:9.1f} um "
          f"(median {ad['median_diam_um']:.1f}, sd {ad['std_diam_um']:.1f})")
    print(f"    range                 : {ad['min_diam_um']:.1f} - "
          f"{ad['max_diam_um']:.1f} um")
    if ad["class_counts"]:
        cc = "  ".join(f"{k}={v}" for k, v in ad["class_counts"].items())
        print(f"    size classes          : {cc}")
    if full:
        print("    diameter histogram:")
        print(histogram_text(ad["diameters_um"]))

    print("\n[3] CROWN-LIKE STRUCTURES (CLS, F4/80+ macrophage rings)")
    print(f"    CLS detected          : {cl['n_cls']}")
    print(f"    CLS density           : {cl['cls_density_per_mm3']:9.1f} /mm^3")
    ci = cl["clustering_index"]
    ci_s = "nan" if math.isnan(ci) else f"{ci:.3f}"
    print(f"    clustering index (R)  : {ci_s}  -> {cl['clustering_interp']}")

    print("\n[4] NEURO-ADIPOCYTE PROXIMITY")
    print(f"    contact radius        : {pr['contact_radius_um']:.1f} um")
    print(f"    innervated adipocytes : {pr['n_innervated']}/{pr['n_total']}"
          f"  = {pr['innervated_fraction']*100:5.1f}%")

    # reference comparison (optional, best effort)
    profiles = {row.get("depot"): row for row in load_depot_profiles()}
    rp = profiles.get(d)
    if rp is None:
        return

    def field(key):
        # Short CSV rows leave missing columns as None.
        value = rp.get(key)
        return "n/a" if value is None else value

    try:
        innervated = f"{float(rp.get('innervated_fraction')) * 100:.0f}%"
    except (TypeError, ValueError):
        innervated = "n/a"
    print("\n[ref] reference profile for %s:" % d)
    print(f"      axon density ~{field('axon_density_mm_per_mm3')} mm/mm^3 | "
          f"mean adipocyte ~{field('mean_adipocyte_diam_um')} um | "
          f"CLS ~{field('cls_density_per_mm3')} /mm^3 | "
          f"innervated ~{innervated}")
    print(f"      note: {field('note')}")


def print_comparison_table(results):
    """Print a fixed-width table comparing per-depot summary metrics.

    Parameters
    ----------
    results : list of dict
        One entry per analysed depot. Each entry must provide ``depot`` (the
        depot name) plus the nested metric dicts read below:
        ``axon['axon_density_mm_per_mm3']``, ``adipo['mean_diam_um']``,
        ``cls['cls_density_per_mm3']`` and ``prox['innervated_fraction']``
        (a fraction, printed as a percentage).
    """
    print("\n" + "=" * 70)
    print("  DEPOT COMPARISON")
    print("=" * 70)
    hdr = (f"  {'depot':6} | {'axon mm/mm3':>11} | {'mean dia um':>11} | "
           f"{'CLS /mm3':>9} | {'innerv %':>8}")
    print(hdr)
    # Size the rule from the header itself so it always spans exactly the
    # table columns; a hard-coded dash count drifts out of alignment
    # whenever a column width changes.
    print("  " + "-" * (len(hdr) - 2))
    for r in results:
        print(f"  {r['depot']:6} | "
              f"{r['axon']['axon_density_mm_per_mm3']:11.1f} | "
              f"{r['adipo']['mean_diam_um']:11.1f} | "
              f"{r['cls']['cls_density_per_mm3']:9.1f} | "
              f"{r['prox']['innervated_fraction']*100:8.1f}")
    print("\n  Interpretation (reference biology): BAT is densely innervated")
    print("  with small multilocular adipocytes; eWAT (visceral) is sparsely")
    print("  innervated, hypertrophic, and CLS-rich - a pro-inflammatory")
    print("  signature associated with obesity/insulin resistance.")


# ============================================================================
# User TIFF input
# ============================================================================
def load_user_tiff(path):
    """Load a user TIFF stack and map it onto the analysis channels.

    Parameters
    ----------
    path : str
        Path to a TIFF file, read with ``tifffile.imread``.

    Returns
    -------
    tuple(dict, dict) or None
        ``(channels, gt)`` where ``channels`` has keys ``th``, ``membrane``
        and ``f480`` (each min-max normalised to roughly [0, 1]) plus
        ``auto`` (all zeros). ``gt`` is an empty ground-truth record with no
        adipocyte/CLS positions, ``voxel_um=DEMO_VOXEL_UM`` and
        ``depot="user"``. Returns ``None``, after printing a warning to
        stderr, whenever the file cannot be used; the caller then falls back
        to the synthetic demo.

    Channel mapping heuristic:
      * 4-D, e.g. (C,Z,Y,X) or (Z,Y,X,C): the smallest axis is treated as
        channels. Channel 0 -> TH, 1 -> membrane, 2 -> F4/80, and any
        further channels are ignored (with a notice).
      * 3-D (Z,Y,X): the one volume is reused for all three channels.
      * More than 4-D: singleton axes (e.g. a length-1 time axis) are
        squeezed away first.
    """
    if not HAVE_TIFFFILE:
        print("  [!] tifffile not installed - cannot read user TIFF. "
              "Install 'tifffile' or use --demo. Falling back to demo.",
              file=sys.stderr)
        return None
    try:
        arr = tifffile.imread(path)
    except Exception as e:
        # Best-effort: any read/parse failure degrades to the demo volume.
        print(f"  [!] failed to read {path}: {e}. Falling back to demo.",
              file=sys.stderr)
        return None
    arr = np.asarray(arr).astype(np.float32)
    print(f"  [input] loaded {path} shape={arr.shape}")

    # A zero-size stack would make min()/max() below raise ValueError.
    if arr.size == 0:
        print(f"  [!] {path} contains no voxels; falling back to demo.",
              file=sys.stderr)
        return None

    # NaN/inf voxels (common in processed float stacks) would turn every
    # normalised channel into NaN. Map them to the finite minimum so they
    # become background (0) after normalisation.
    finite = np.isfinite(arr)
    if not finite.all():
        if not finite.any():
            print(f"  [!] {path} has no finite voxels; falling back to demo.",
                  file=sys.stderr)
            return None
        print("  [input] non-finite voxels found; treating them as "
              "background.", file=sys.stderr)
        arr = np.where(finite, arr, arr[finite].min())

    # Only squeeze when needed to reach a supported rank, so a genuine
    # (1, Y, X) stack keeps its original 3-D handling.
    if arr.ndim > 4:
        arr = np.squeeze(arr)

    def norm(a):
        # Min-max rescale; the epsilon keeps constant volumes finite (all 0).
        a = a - a.min()
        return a / (a.max() + 1e-9)

    if arr.ndim == 4:
        # Assume the smallest axis is channels and move it to the front.
        cax = int(np.argmin(arr.shape))
        arr = np.moveaxis(arr, cax, 0)
        c = arr.shape[0]
        if c > 3:
            print(f"  [input] {c} channels found; using the first 3 "
                  "(TH, membrane, F4/80) and ignoring the rest.")
        th = norm(arr[0])
        membrane = norm(arr[1]) if c > 1 else th
        f480 = norm(arr[2]) if c > 2 else np.zeros_like(th)
    elif arr.ndim == 3:
        th = membrane = f480 = norm(arr)
        print("  [input] single-channel TIFF: using same volume for all "
              "channels (axon metrics meaningful; adipocyte/CLS approximate).")
    else:
        print(f"  [!] unsupported TIFF ndim ({arr.ndim}); falling back to "
              "demo.", file=sys.stderr)
        return None
    shape = th.shape
    channels = dict(th=th, membrane=membrane, f480=f480,
                    auto=np.zeros_like(th))
    # NOTE: voxel size is taken from DEMO_VOXEL_UM; TIFF voxel-size metadata
    # is not read, so physical-unit metrics assume the demo calibration.
    gt = dict(adipo_centers=[], adipo_radii=[], cls_centers=[],
              voxel_um=DEMO_VOXEL_UM, shape=shape, depot="user")
    return channels, gt


# ============================================================================
# CLI
# ============================================================================
def build_parser():
    """Build the command-line parser for ``main``.

    Returns
    -------
    argparse.ArgumentParser
        Parser exposing ``--depot``, ``--demo``, ``--input``, ``--summary``,
        ``--top``, ``--seed``, ``--size`` and ``--quiet``. The help texts
        describe how ``main`` interprets each option.
    """
    p = argparse.ArgumentParser(
        prog="main.py",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        description=(
            "AdipoNeuroMap - quantify sympathetic innervation, adipocyte size,"
            " and crown-like structures in cleared 3D adipose depots.\n"
            "Domain: Obesity. " + DISCLAIMER),
        epilog=(
            "examples:\n"
            "  python3 main.py                      # demo, all depots compared\n"
            "  python3 main.py --depot eWAT         # single depot, full report\n"
            "  python3 main.py --summary            # compact comparison only\n"
            "  python3 main.py --top 3              # top-N largest adipocytes\n"
            "  python3 main.py --input stack.tif    # user TIFF (needs tifffile)\n"))
    p.add_argument("--depot", choices=["iWAT", "eWAT", "BAT"],
                   help="analyze a single depot (default: compare all three)")
    # Synthetic data is already the default; --input still takes precedence
    # over this flag when the TIFF loads successfully.
    p.add_argument("--demo", action="store_true",
                   help="use built-in synthetic volume (default behaviour; "
                        "--input takes precedence if it loads)")
    p.add_argument("--input", metavar="PATH",
                   help="path to a user TIFF stack (requires tifffile; "
                        "falls back to the demo if it cannot be read)")
    p.add_argument("--summary", action="store_true",
                   help="compact output: only the comparison table when "
                        "comparing all depots, a condensed report with "
                        "--depot/--input")
    p.add_argument("--top", type=int, metavar="N", default=0,
                   help="also list the N largest adipocyte diameters "
                        "(not shown with --summary when comparing all depots)")
    p.add_argument("--seed", type=int, default=42,
                   help="random seed for synthetic demo (default 42)")
    p.add_argument("--size", type=int, default=0, metavar="N",
                   help="demo volume size: N x 2N x 2N voxels with N clamped "
                        "to 24..96 (advanced; 0=default 60x120x120)")
    p.add_argument("--quiet", action="store_true",
                   help="suppress per-step demo logging")
    return p


def main(argv=None):
    """Run the command-line tool and return the process exit status (0).

    Modes are tried in this order:
      1. ``--input PATH``: analyse a user TIFF. If it cannot be loaded,
         execution continues with the synthetic modes below.
      2. ``--depot NAME``: analyse one synthetic depot and print its report.
      3. Default: analyse BAT, iWAT and eWAT synthetic depots and print a
         comparison table.
    """
    opts = build_parser().parse_args(argv)
    reference = load_reference()
    chatty = not opts.quiet

    print(banner())
    print(env_line())

    # --size N -> (N, 2N, 2N) voxels, with N clamped to [24, 96].
    vol_shape = (60, 120, 120)
    if opts.size > 0:
        edge = min(96, max(24, opts.size))
        vol_shape = (edge, 2 * edge, 2 * edge)

    def finish_single(result):
        # Shared tail of the one-volume modes (user TIFF and --depot).
        print_depot_report(result, reference, full=not opts.summary)
        if opts.top > 0:
            _print_top(result, opts.top)
        _footer()
        return 0

    if opts.input:
        loaded = load_user_tiff(opts.input)
        if loaded is not None:
            chans, truth = loaded
            result = run_depot("user", reference, channels=chans,
                               ground_truth=truth, voxel_um=truth["voxel_um"],
                               verbose=chatty)
            return finish_single(result)
        # load_user_tiff already printed why; continue with demo data.

    if opts.depot:
        if chatty:
            print("\n  generating synthetic demo volume...")
        result = run_depot(opts.depot, reference, seed=opts.seed,
                           shape=vol_shape, verbose=chatty)
        return finish_single(result)

    if chatty:
        print("\n  generating synthetic demo volumes (iWAT, eWAT, BAT)...")
    results = []
    for name in ("BAT", "iWAT", "eWAT"):
        result = run_depot(name, reference, seed=opts.seed, shape=vol_shape,
                           verbose=chatty)
        results.append(result)
        if opts.summary:
            continue
        print_depot_report(result, reference, full=False)
        if opts.top > 0:
            _print_top(result, opts.top)

    print_comparison_table(results)
    _footer()
    return 0


def _print_top(res, n):
    diams = res["adipo"]["diameters_um"]
    if len(diams) == 0:
        print(f"\n  top-{n} adipocytes: (none detected)")
        return
    top = np.sort(diams)[::-1][:n]
    print(f"\n  top-{n} largest adipocyte diameters (um) [{res['depot']}]:")
    print("    " + ", ".join(f"{d:.1f}" for d in top))


def _footer():
    """Print the closing disclaimer banner shared by every run mode."""
    rule = "-" * 70
    closing_lines = (
        "\n" + rule,
        "  " + DISCLAIMER,
        "  Synthetic demo data are procedurally generated; not real tissue.",
        rule,
    )
    for text in closing_lines:
        print(text)


if __name__ == "__main__":
    sys.exit(main())
