#!/usr/bin/env python3
"""
PancreaClear3D - ex vivo 3D islet morphometry from tissue-cleared light-sheet volumes.

Domain : DM (diabetes mellitus) | Category: animal-experiment tool (ex vivo 3D image quantitation)

One-liner:
    From a cleared (tissue-clearing) pancreas light-sheet volume, automatically quantify
    islet beta-cell mass and islet vascular/neural innervation as 3D volumetrics.

Core features
    (1) insulin+(beta) / glucagon+(alpha) 3D segmentation -> islet object separation,
        count, volume, beta/alpha ratio
    (2) islet size-distribution histogram + head/body/tail zonal distribution map
    (3) per-islet vascular (CD31) density + neural (TH+/CGRP+) axon proximity (3D)
    (4) design-based stereology comparison mode: 2D section estimate vs 3D ground truth
        bias report (shows 2D sampling biases beta-mass by ~+/-20-40%)
    (5) synthetic demo volume (procedural islet field; head/body/tail gradient +
        ~4-decade size distribution) for a no-data trial run

Dependencies
    Required : numpy
    Optional : scipy (scipy.ndimage) -> morphological cleanup, distance transforms and
               watershed-style splitting of touching islets. If scipy is absent, a pure-numpy
               fallback (iterative flood-fill labeling + capped dilation-based distance map,
               no watershed splitting) is used so that `python3 main.py` ALWAYS succeeds.
    Optional : tifffile -> load user TIFF/OME-TIFF stacks via --input. If absent, only the
               synthetic demo is available.
    NOT used : napari / cellpose / StarDist / pyclesperanto (heavy GPU deps). In production
               the segmentation step can be swapped for cellpose-3D / StarDist-3D.

NO network calls. Offline, synthetic-data demonstration.

Research / reference use only - NOT a clinical decision or diagnostic device.
"""

import argparse
import json
import math
import os
import sys

import numpy as np

# ----------------------------------------------------------------------------- optional deps
try:
    import scipy.ndimage as ndi
    HAVE_SCIPY = True
except Exception:
    ndi = None
    HAVE_SCIPY = False

try:
    import tifffile
    HAVE_TIFFFILE = True
except Exception:
    tifffile = None
    HAVE_TIFFFILE = False

# Research-use notice: printed at the top and bottom of the text report and embedded in
# the JSON export.
DISCLAIMER = ("Research / reference use only - NOT a clinical decision or diagnostic "
              "device. Figures are illustrative quantitation of the supplied/synthetic volume.")

# Directory holding this script; the optional reference table is resolved relative to it.
HERE = os.path.dirname(os.path.abspath(__file__))
# Optional JSON reference table read by load_reference(); built-in defaults are used
# when it cannot be loaded.
REF_PATH = os.path.join(HERE, "data", "stereology_reference.json")


def load_reference():
    """Load the stereology reference table, falling back to built-in defaults.

    Reads REF_PATH as UTF-8 JSON. If the file is missing, unreadable, not valid JSON, or
    not a JSON object, the built-in table is returned unchanged. When a file loads but
    lacks one of the top-level sections, that section is filled in from the built-in
    table so callers can index the required sections unconditionally.

    Returns:
        dict with at least the keys ``islet_size_classes_um_diameter``,
        ``zonal_regions``, ``vascular_neural`` and ``cavalieri``.
    """
    defaults = {
        "islet_size_classes_um_diameter": [
            {"name": "very_small", "min": 20.0, "max": 50.0},
            {"name": "small", "min": 50.0, "max": 100.0},
            {"name": "medium", "min": 100.0, "max": 200.0},
            {"name": "large", "min": 200.0, "max": 350.0},
            {"name": "very_large", "min": 350.0, "max": 700.0},
        ],
        "zonal_regions": [
            {"name": "head", "axis_fraction_min": 0.0, "axis_fraction_max": 0.34},
            {"name": "body", "axis_fraction_min": 0.34, "axis_fraction_max": 0.67},
            {"name": "tail", "axis_fraction_min": 0.67, "axis_fraction_max": 1.0},
        ],
        "vascular_neural": {"proximity_threshold_um_default": 15.0},
        "cavalieri": {"section_spacing_um_default": 50.0},
    }
    try:
        with open(REF_PATH, "r", encoding="utf-8") as fh:
            loaded = json.load(fh)
    except (OSError, ValueError):
        # OSError: file missing/unreadable. ValueError: malformed JSON or bad encoding
        # (json.JSONDecodeError and UnicodeDecodeError are both ValueError subclasses).
        return defaults
    if not isinstance(loaded, dict):
        return defaults
    # Back-fill absent sections instead of letting analysis fail later with KeyError.
    for key, value in defaults.items():
        loaded.setdefault(key, value)
    return loaded


# ============================================================================= synthetic data
def make_synthetic_volume(shape=(60, 160, 160), voxel_um=8.0, n_islets=70, seed=7):
    """Procedurally generate a cleared-pancreas-like volume.

    Parameters:
        shape    : (nz, ny, nx) size of the volume in voxels.
        voxel_um : voxel edge length in um; only recorded in the returned dict (islet
                   radii are drawn in voxel units).
        n_islets : target number of islets; fewer may be placed because placement gives
                   up after ``n_islets * 40`` attempts.
        seed     : seed for the numpy random generator (same seed -> same volume).

    Returns dict with float channels (beta, alpha, vessel, nerve) and ground-truth records.
    The long axis is Y (head=low Y -> tail=high Y) to emulate head/body/tail anatomy.
    Islet sizes span almost 4 decades of volume via a log-uniform radius draw.
    Density follows a head<body<tail gradient.
    """
    rng = np.random.default_rng(seed)
    nz, ny, nx = shape
    beta = np.zeros(shape, dtype=np.float32)
    alpha = np.zeros(shape, dtype=np.float32)
    vessel = np.zeros(shape, dtype=np.float32)
    nerve = np.zeros(shape, dtype=np.float32)

    # head/body/tail density gradient -> bias islet Y-positions toward the tail
    grad = np.array([0.7, 1.0, 1.6], dtype=np.float64)  # head, body, tail
    grad = grad / grad.sum()
    # radius range chosen so volume (~r^3) spans almost 4 decades. With voxel=8um:
    #   r=1.3vox -> d~21um (very_small) ... r=25vox -> d~400um (very_large).
    # (25/1.3)^3 ~ 7100 -> ~3.9 decades of volume.
    r_min, r_max = 1.3, 25.0  # voxels

    truth = []
    # full-volume index grids, reused for every islet's squared-distance field
    zz, yy, xx = np.meshgrid(np.arange(nz), np.arange(ny), np.arange(nx), indexing="ij")

    placed = 0
    attempts = 0
    centers = []
    # rejection sampling: keep drawing candidates until enough islets are placed or the
    # attempt budget is exhausted
    while placed < n_islets and attempts < n_islets * 40:
        attempts += 1
        # choose zone by gradient, then Y within that zone
        zone = rng.choice(3, p=grad)
        y0 = int(ny * zone / 3.0)
        y1 = int(ny * (zone + 1) / 3.0)
        # keep a 4-voxel margin from the zone edges; max() keeps the range non-empty
        cy = rng.integers(y0 + 4, max(y0 + 5, y1 - 4))
        cx = rng.integers(8, nx - 8)
        cz = rng.integers(6, nz - 6)
        # log-uniform radius over the full range -> almost 4 decades of volume, covering
        # the size classes from very_small to very_large.
        r = math.exp(rng.uniform(math.log(r_min), math.log(r_max)))
        # avoid heavy overlap with already-placed islets (keep objects separable);
        # up to 15% interpenetration of the summed radii is tolerated
        ok = True
        for (pz, py, px, pr) in centers:
            d = math.sqrt((cz - pz) ** 2 + (cy - py) ** 2 + (cx - px) ** 2)
            if d < (r + pr) * 0.85:
                ok = False
                break
        if not ok:
            continue
        centers.append((cz, cy, cx, r))

        # spherical islet mask; voxels outside the volume are simply not counted
        dist2 = (zz - cz) ** 2 + (yy - cy) ** 2 + (xx - cx) ** 2
        mask = dist2 <= r * r
        nvox = int(mask.sum())
        # defensive guard: skip an islet that covers no voxel
        if nvox == 0:
            continue

        # beta:alpha composition ~ 3-4:1 with per-islet jitter. Alpha cells form a mantle
        # (mouse-like) - here an outer shell whose thickness is set so the alpha VOLUME
        # fraction lands near alpha_frac (shell volume fraction ~ 1 - (rin/r)^3).
        alpha_frac = float(np.clip(rng.normal(0.22, 0.05), 0.08, 0.40))
        rin_ratio = (1.0 - alpha_frac) ** (1.0 / 3.0)  # inner radius of the mantle
        rim = (dist2 <= r * r) & (dist2 > (r * rin_ratio) ** 2)
        core = mask & ~rim
        # intensities (clearing/light-sheet style: smooth blobs). Both channels are placed
        # well above the noise floor (sigma=0.03) so Otsu recovers them faithfully.
        falloff = np.exp(-dist2 / (2 * (r * 0.9) ** 2 + 1e-6))
        beta[core] += 0.9 + 0.5 * falloff[core]
        alpha[rim] += 0.9 + 0.5 * falloff[rim]
        # faint cross-presence (realistic intermixing), kept below threshold
        beta[rim] += 0.10 * falloff[rim]
        alpha[core] += 0.10 * falloff[core]

        # ground-truth record for this islet (positions and radius in voxel units)
        truth.append({
            "z": cz, "y": cy, "x": cx, "r_vox": r,
            "vol_vox_truth": nvox,
            "alpha_frac": alpha_frac,
            "zone_idx": int(zone),
        })
        placed += 1

    # vasculature: dense near islets (islets are highly vascularized) + sparse background
    # modelled as isolated positive voxels: uniform background seeds plus extra seeds
    # scattered in a box around each islet center
    n_vessel = nz * ny * nx // 400
    vy = rng.integers(0, ny, n_vessel)
    vx = rng.integers(0, nx, n_vessel)
    vz = rng.integers(0, nz, n_vessel)
    vessel[vz, vy, vx] = 1.0
    for (cz, cy, cx, r) in centers:
        k = max(3, int(r))
        for _ in range(k * 3):
            # jitter within +/-(int(r)+2) voxels of the center, clipped to the volume
            jz = int(np.clip(cz + rng.integers(-int(r) - 2, int(r) + 3), 0, nz - 1))
            jy = int(np.clip(cy + rng.integers(-int(r) - 2, int(r) + 3), 0, ny - 1))
            jx = int(np.clip(cx + rng.integers(-int(r) - 2, int(r) + 3), 0, nx - 1))
            vessel[jz, jy, jx] = 1.0

    # nerve: sparse axons; only a subset of islets are well innervated (proximity will vary)
    n_nerve = nz * ny * nx // 1500
    ny_ = rng.integers(0, ny, n_nerve)
    nx_ = rng.integers(0, nx, n_nerve)
    nz_ = rng.integers(0, nz, n_nerve)
    nerve[nz_, ny_, nx_] = 1.0
    # each placed islet gets extra nearby nerve voxels with probability 0.6
    innervated = rng.random(len(centers)) < 0.6
    for innerv, (cz, cy, cx, r) in zip(innervated, centers):
        if not innerv:
            continue
        for _ in range(max(2, int(r))):
            jz = int(np.clip(cz + rng.integers(-int(r) - 1, int(r) + 2), 0, nz - 1))
            jy = int(np.clip(cy + rng.integers(-int(r) - 1, int(r) + 2), 0, ny - 1))
            jx = int(np.clip(cx + rng.integers(-int(r) - 1, int(r) + 2), 0, nx - 1))
            nerve[jz, jy, jx] = 1.0

    # add light noise so thresholding is non-trivial
    beta += rng.normal(0, 0.03, shape).astype(np.float32)
    alpha += rng.normal(0, 0.03, shape).astype(np.float32)
    # clamp negative noise excursions to zero in place
    np.clip(beta, 0, None, out=beta)
    np.clip(alpha, 0, None, out=alpha)

    return {
        "shape": shape, "voxel_um": float(voxel_um),
        "beta": beta, "alpha": alpha, "vessel": vessel, "nerve": nerve,
        "truth": truth, "long_axis": 1,  # Y
    }


# ============================================================================= segmentation
def otsu_threshold(arr):
    """Return an Otsu threshold for the strictly positive values of `arr` (pure numpy).

    The positive values are binned into a 128-bin histogram and the center of the bin
    that maximises the between-class variance is returned. If `arr` has no positive
    values, 0.5 is returned.
    """
    positive = arr[arr > 0]
    if positive.size == 0:
        return 0.5

    counts, bin_edges = np.histogram(positive, bins=128)
    bin_centers = (bin_edges[:-1] + bin_edges[1:]) / 2.0
    n_total = counts.sum()
    if n_total == 0:
        return float(bin_centers[len(bin_centers) // 2])

    # cumulative weight of the "below" class, and complementary "above" class
    w_below = np.cumsum(counts).astype(np.float64)
    w_above = n_total - w_below
    # cumulative first moment; its last entry is the moment of the whole histogram
    moment = np.cumsum(counts * bin_centers)
    moment_all = moment[-1]

    with np.errstate(divide="ignore", invalid="ignore"):
        # an empty class has an undefined mean -> NaN, zeroed out below
        mean_below = moment / np.where(w_below == 0, np.nan, w_below)
        mean_above = (moment_all - moment) / np.where(w_above == 0, np.nan, w_above)
        between_var = w_below * w_above * (mean_below - mean_above) ** 2

    best_bin = int(np.argmax(np.nan_to_num(between_var)))
    return float(bin_centers[best_bin])


# --- pure-numpy connected components fallback (6-connectivity, iterative) -----
def _label_numpy(mask):
    """6-connectivity 3D labeling without scipy. Iterative BFS via stack."""
    labels = np.zeros(mask.shape, dtype=np.int32)
    nz, ny, nx = mask.shape
    cur = 0
    neigh = [(-1, 0, 0), (1, 0, 0), (0, -1, 0), (0, 1, 0), (0, 0, -1), (0, 0, 1)]
    coords = np.argwhere(mask)
    visited = np.zeros(mask.shape, dtype=bool)
    coord_set = set(map(tuple, coords))
    for c in coords:
        cz, cy, cx = int(c[0]), int(c[1]), int(c[2])
        if visited[cz, cy, cx]:
            continue
        cur += 1
        stack = [(cz, cy, cx)]
        visited[cz, cy, cx] = True
        while stack:
            z, y, x = stack.pop()
            labels[z, y, x] = cur
            for dz, dy, dx in neigh:
                nz2, ny2, nx2 = z + dz, y + dy, x + dx
                if 0 <= nz2 < nz and 0 <= ny2 < ny and 0 <= nx2 < nx:
                    if mask[nz2, ny2, nx2] and not visited[nz2, ny2, nx2]:
                        visited[nz2, ny2, nx2] = True
                        stack.append((nz2, ny2, nx2))
    return labels, cur


def segment_islets(beta, alpha, voxel_um, min_vol_vox=8):
    """Segment islet objects from the summed beta+alpha (endocrine) signal.

    The summed signal is thresholded with Otsu. With scipy available the mask is opened,
    hole-filled and split with a marker-controlled watershed on its distance transform;
    if no marker is found, or the seeding/watershed step raises, plain scipy labeling is
    used instead. Without scipy the pure-numpy labeler is used on the raw mask. Objects
    smaller than `min_vol_vox` voxels are dropped and the rest renumbered compactly.

    `voxel_um` is accepted for interface compatibility and is not used here.

    Returns:
        (labels, used_watershed, thr): the label array (0 = background), whether the
        watershed split was applied, and the Otsu threshold of the endocrine signal.
    """
    endocrine = beta + alpha
    thr = otsu_threshold(endocrine)
    mask = endocrine > thr

    if not HAVE_SCIPY:
        raw_labels, _ = _label_numpy(mask)
        return _filter_small(raw_labels, min_vol_vox), False, thr

    # light morphological cleanup, then the distance map that drives the split
    mask = ndi.binary_fill_holes(ndi.binary_opening(mask, iterations=1))
    dist = ndi.distance_transform_edt(mask)

    used_watershed = False
    try:
        # Smooth the distance map so one islet yields ONE peak (avoids over-splitting
        # large islets), then seed only at substantial local maxima.
        smoothed = ndi.gaussian_filter(dist, sigma=1.2)
        window = np.ones((5, 5, 5))
        is_peak = ndi.maximum_filter(smoothed, footprint=window) == smoothed
        seeds = is_peak & (smoothed > 1.5)
        markers, _ = ndi.label(seeds)
        if markers.max() < 1:
            raw_labels, _ = ndi.label(mask)
        else:
            raw_labels = _watershed_scipy(-dist, markers, mask)
            used_watershed = True
    except Exception:
        raw_labels, _ = ndi.label(mask)

    # drop tiny objects (noise)
    return _filter_small(raw_labels, min_vol_vox), used_watershed, thr


def _watershed_scipy(landscape, markers, mask):
    """Marker-controlled watershed using scipy only (priority-flood).

    A compact priority-queue flood from markers over `landscape`, restricted to mask.
    """
    import heapq
    out = np.zeros(landscape.shape, dtype=np.int32)
    out[markers > 0] = markers[markers > 0]
    nz, ny, nx = landscape.shape
    heap = []
    seeds = np.argwhere(markers > 0)
    for s in seeds:
        z, y, x = int(s[0]), int(s[1]), int(s[2])
        heapq.heappush(heap, (float(landscape[z, y, x]), z, y, x))
    neigh = [(-1, 0, 0), (1, 0, 0), (0, -1, 0), (0, 1, 0), (0, 0, -1), (0, 0, 1)]
    while heap:
        val, z, y, x = heapq.heappop(heap)
        lab = out[z, y, x]
        for dz, dy, dx in neigh:
            z2, y2, x2 = z + dz, y + dy, x + dx
            if 0 <= z2 < nz and 0 <= y2 < ny and 0 <= x2 < nx:
                if mask[z2, y2, x2] and out[z2, y2, x2] == 0:
                    out[z2, y2, x2] = lab
                    heapq.heappush(heap, (float(landscape[z2, y2, x2]), z2, y2, x2))
    out[~mask] = 0
    return out


def _filter_small(labels, min_vol_vox):
    ids, counts = np.unique(labels, return_counts=True)
    out = labels.copy()
    for i, c in zip(ids, counts):
        if i == 0:
            continue
        if c < min_vol_vox:
            out[labels == i] = 0
    # relabel compactly
    remaining = [i for i in np.unique(out) if i != 0]
    remap = {old: new for new, old in enumerate(remaining, start=1)}
    res = np.zeros_like(out)
    for old, new in remap.items():
        res[out == old] = new
    return res


# ============================================================================= analysis
def equiv_diameter_um(vol_vox, voxel_um):
    """Return the diameter (um) of the sphere whose volume equals `vol_vox` voxels.

    `voxel_um` is the voxel edge length in um, so one voxel occupies voxel_um**3 um^3.
    """
    sphere_vol_um3 = vol_vox * (voxel_um ** 3)
    # invert V = 4/3 * pi * r^3 for the radius
    radius_um = (3.0 * sphere_vol_um3 / (4.0 * math.pi)) ** (1.0 / 3.0)
    return 2.0 * radius_um


def _zone_for_fraction(frac, zones):
    """Return the name of the zone whose [min, max) long-axis fraction range holds `frac`."""
    for zone in zones:
        if zone["axis_fraction_min"] <= frac < zone["axis_fraction_max"]:
            return zone["name"]
    # frac == 1.0 (centroid on the last plane) or a gap in the table -> last zone
    return zones[-1]["name"]


def _size_class_for_diameter(d_um, size_classes):
    """Return the size-class name for an equivalent diameter, clamping out-of-range values."""
    for cls in size_classes:
        if cls["min"] <= d_um < cls["max"]:
            return cls["name"]
    # Outside every range: clamp to the nearest end of the table, so a tiny islet is
    # never reported as the largest class.
    smallest = min(size_classes, key=lambda c: c["min"])
    if d_um < smallest["min"]:
        return smallest["name"]
    return max(size_classes, key=lambda c: c["max"])["name"]


def analyze(volume, ref, prox_thr_um=None):
    """Segment islets in `volume` and compute per-islet and whole-volume morphometry.

    Parameters:
        volume      : dict as returned by make_synthetic_volume / load_input; the keys
                      "beta", "alpha", "vessel", "nerve", "voxel_um" and "shape" are
                      read, plus the optional "long_axis" (array axis of the head->tail
                      direction, default 1 = Y).
        ref         : reference table; "islet_size_classes_um_diameter" and
                      "zonal_regions" are required, "vascular_neural" is optional.
        prox_thr_um : vessel/nerve proximity threshold in um; None -> reference default
                      (15.0 if the reference has none).

    Returns:
        dict with the label array, the per-islet records (sorted by volume, largest
        first), beta/alpha volume totals and ratio, zone and size-class tallies, the
        parameters used and the 2D-vs-3D stereology bias report.
    """
    beta = volume["beta"]
    alpha = volume["alpha"]
    vessel = volume["vessel"]
    nerve = volume["nerve"]
    voxel_um = volume["voxel_um"]
    shape = volume["shape"]
    # honour the volume's declared long axis instead of assuming Y
    long_axis = int(volume.get("long_axis", 1))
    long_axis_len = shape[long_axis]

    if prox_thr_um is None:
        prox_thr_um = ref.get("vascular_neural", {}).get("proximity_threshold_um_default", 15.0)
    # proximity is evaluated in whole voxels, never less than one
    prox_vox = max(1, int(round(prox_thr_um / voxel_um)))

    labels, used_ws, thr = segment_islets(beta, alpha, voxel_um)
    ids = [i for i in np.unique(labels) if i != 0]

    # per-channel positive masks, used for the beta/alpha composition of each islet
    bmask = beta > otsu_threshold(beta)
    amask = alpha > otsu_threshold(alpha)

    # distance (in voxels) to the nearest vessel / nerve voxel: scipy EDT if present,
    # else a dilation-based approximation capped just beyond the proximity threshold
    if HAVE_SCIPY:
        vdist = ndi.distance_transform_edt(vessel == 0)
        ndist = ndi.distance_transform_edt(nerve == 0)
    else:
        vdist = _approx_distance(vessel > 0, prox_vox + 2)
        ndist = _approx_distance(nerve > 0, prox_vox + 2)

    size_classes = ref["islet_size_classes_um_diameter"]
    zones = ref["zonal_regions"]
    zone_counts = {z["name"]: 0 for z in zones}
    zone_vol = {z["name"]: 0.0 for z in zones}
    class_counts = {c["name"]: 0 for c in size_classes}

    voxel_vol_um3 = voxel_um ** 3
    islets = []
    total_beta_vox = 0
    total_alpha_vox = 0
    for i in ids:
        m = labels == i
        vol_vox = int(m.sum())
        vol_um3 = vol_vox * voxel_vol_um3
        d_um = equiv_diameter_um(vol_vox, voxel_um)

        # beta/alpha composition inside this islet
        b_in = int((m & bmask).sum())
        a_in = int((m & amask).sum())
        total_beta_vox += b_in
        total_alpha_vox += a_in

        # centroid position along the long axis -> zone
        centroid = float(np.argwhere(m)[:, long_axis].mean())
        frac = centroid / max(1, long_axis_len - 1)
        zone_name = _zone_for_fraction(frac, zones)
        zone_counts[zone_name] += 1
        zone_vol[zone_name] += vol_um3

        cls = _size_class_for_diameter(d_um, size_classes)
        class_counts[cls] += 1

        # vascular density: fraction of islet voxels within prox of a vessel voxel
        vasc_density = int((vdist[m] <= prox_vox).sum()) / max(1, vol_vox)
        # neural proximity: min distance (um) from any islet voxel to a nerve voxel
        nerve_d = ndist[m]
        nmin = float(nerve_d.min()) * voxel_um
        nerve_near = bool((nerve_d <= prox_vox).any())

        islets.append({
            "id": int(i), "vol_vox": vol_vox,
            "vol_um3": vol_um3,
            "diam_um": d_um, "zone": zone_name, "size_class": cls,
            "beta_vox": b_in, "alpha_vox": a_in,
            "vasc_density": vasc_density,
            "nerve_min_dist_um": nmin, "innervated": nerve_near,
        })

    islets.sort(key=lambda r: r["vol_um3"], reverse=True)

    beta_mass_3d = total_beta_vox * voxel_vol_um3  # um^3 (proxy for beta-cell mass)
    alpha_mass_3d = total_alpha_vox * voxel_vol_um3
    ba_ratio = (total_beta_vox / total_alpha_vox) if total_alpha_vox else float("inf")

    # ---- 2D vs 3D stereology bias ----
    bias = stereology_bias(beta, voxel_um, ref)

    return {
        "labels": labels, "islets": islets, "n_islets": len(ids),
        "used_watershed": used_ws, "endocrine_threshold": thr,
        "beta_mass_3d_um3": beta_mass_3d, "alpha_mass_3d_um3": alpha_mass_3d,
        "beta_alpha_ratio": ba_ratio,
        "total_beta_vox": total_beta_vox, "total_alpha_vox": total_alpha_vox,
        "zone_counts": zone_counts, "zone_vol_um3": zone_vol,
        "class_counts": class_counts, "prox_thr_um": prox_thr_um,
        "voxel_um": voxel_um, "shape": shape, "bias": bias,
    }


def _approx_distance(binary, max_d):
    """Pure-numpy approximate EDT via iterative 6-conn dilation (chamfer-ish, capped)."""
    dist = np.full(binary.shape, max_d + 1.0, dtype=np.float32)
    dist[binary] = 0.0
    cur = binary.copy()
    for d in range(1, max_d + 1):
        nxt = cur.copy()
        nxt[1:, :, :] |= cur[:-1, :, :]
        nxt[:-1, :, :] |= cur[1:, :, :]
        nxt[:, 1:, :] |= cur[:, :-1, :]
        nxt[:, :-1, :] |= cur[:, 1:, :]
        nxt[:, :, 1:] |= cur[:, :, :-1]
        nxt[:, :, :-1] |= cur[:, :, 1:]
        new = nxt & ~cur
        dist[new] = d
        cur = nxt
    return dist


def stereology_bias(beta, voxel_um, ref, seed=7):
    """Compare 3D ground-truth beta volume vs sparse-2D-section estimates.

    Real-world morphometry papers often estimate islet/beta-cell mass from a SMALL number
    of arbitrarily positioned 2D sections, then extrapolate to the whole organ:

        est_vol = (mean positive area over the few sampled sections) x total thickness

    With few sections the sampled positive area is a noisy estimate of the true mean
    cross-sectional area (islets are sparse and unevenly sized), so the extrapolated
    beta-mass swings by tens of percent. Several section counts (3, 5, 8, 12, each capped
    at the number of available planes) are each repeated over 60 independent random
    placements of z-planes.

    Parameters:
        beta     : 3D beta-channel array; axis 0 is the sectioning axis.
        voxel_um : voxel edge length in um.
        ref      : reference table; accepted for interface compatibility, not used here.
        seed     : base seed for the random section placement.

    Returns:
        {"true_vol_um3": ..., "estimates": [...]}. The first estimate is the full-3D
        reference (bias 0). Each sparse entry reports the mean estimate, ``bias_pct``
        (mean ABSOLUTE bias over placements), ``bias_signed_pct`` (mean signed bias),
        ``bias_std_pct`` and ``bias_worst_pct`` (the signed bias of largest magnitude).
    """
    bmask = beta > otsu_threshold(beta)
    voxel_vol = voxel_um ** 3
    true_vol = float(bmask.sum()) * voxel_vol
    nz = beta.shape[0]
    total_thickness = nz * voxel_um
    plane_area_um2 = voxel_um ** 2
    rng = np.random.default_rng(seed + 101)

    # Positive-voxel count of every z-plane, computed once instead of inside each trial.
    plane_areas_vox = bmask.sum(axis=tuple(range(1, bmask.ndim)))

    # full 3D reference (every plane) -> recovers true volume exactly
    results = [{
        "scheme": "full 3D (all planes)",
        "n_sections": nz,
        "est_vol_um3": true_vol,
        "bias_pct": 0.0,
    }]

    # Cap each requested count at the number of planes so the reported section count is
    # the one actually sampled; duplicates produced by the cap are collapsed.
    section_counts = sorted({min(n, nz) for n in (3, 5, 8, 12)} - {0})

    # sparse 2D snapshot stereology: few random sections, extrapolated.
    for n_sec in section_counts:
        biases = []
        ests = []
        for _ in range(60):  # many random placements -> characterise the error
            planes = rng.choice(nz, size=n_sec, replace=False)
            mean_area_vox = plane_areas_vox[planes].mean()
            est_vol = mean_area_vox * plane_area_um2 * total_thickness
            ests.append(est_vol)
            biases.append(100.0 * (est_vol - true_vol) / true_vol if true_vol else 0.0)
        biases = np.asarray(biases)
        results.append({
            "scheme": f"{n_sec} random 2D sections",
            "n_sections": n_sec,
            "est_vol_um3": float(np.mean(ests)),
            "bias_pct": float(np.mean(np.abs(biases))),   # typical (mean-absolute) bias
            "bias_signed_pct": float(np.mean(biases)),
            "bias_std_pct": float(np.std(biases)),
            "bias_worst_pct": float(biases[np.argmax(np.abs(biases))]),
        })
    return {"true_vol_um3": true_vol, "estimates": results}


# ============================================================================= TIFF input
def load_input(path, voxel_um=8.0):
    """Load a user multi-channel TIFF/OME-TIFF as a synthetic-compatible volume dict.

    Expectation: a 3D (z, y, x) single-channel stack, treated as beta only, or a 4D
    stack whose channels are beta, alpha, vessel, nerve in that order. For 4D input the
    first axis of length 2..4 is taken as the channel axis, else the shortest axis.
    Missing channels are zero-filled; vessel and nerve are binarized with Otsu.

    Parameters:
        path     : path of the TIFF file to read.
        voxel_um : isotropic voxel edge length in um recorded in the result (default 8.0).

    Returns:
        dict with the same keys as make_synthetic_volume() ("truth" is empty).

    Raises:
        RuntimeError : tifffile is not installed.
        ValueError   : the stack is neither 3D nor 4D, or `voxel_um` is not positive.
    """
    if not HAVE_TIFFFILE:
        raise RuntimeError(
            "tifffile is not installed; --input is unavailable. Run the synthetic demo "
            "(default / --demo), or install tifffile to load your own stacks.")
    if not voxel_um > 0:
        raise ValueError(f"voxel_um must be positive, got {voxel_um!r}")

    arr = np.asarray(tifffile.imread(path), dtype=np.float32)
    # The analysis unpacks a (z, y, x) shape; reject anything else up front instead of
    # failing later with an opaque unpacking or indexing error.
    if arr.ndim not in (3, 4):
        raise ValueError(
            f"expected a 3D (z,y,x) or 4D (channels + z,y,x) stack, got shape {arr.shape}")

    if arr.ndim == 3:
        # single channel -> treat as beta only
        beta = arr
        alpha = np.zeros_like(beta)
        vessel = np.zeros_like(beta)
        nerve = np.zeros_like(beta)
    else:
        # channel axis: first axis of length 2..4 (beta[,alpha[,vessel[,nerve]]]),
        # otherwise the shortest axis
        ch_axis = next((ax for ax, n in enumerate(arr.shape) if 2 <= n <= 4), None)
        if ch_axis is None:
            ch_axis = int(np.argmin(arr.shape))
        arr = np.moveaxis(arr, ch_axis, 0)
        chans = [arr[i] for i in range(arr.shape[0])]
        while len(chans) < 4:
            chans.append(np.zeros_like(chans[0]))
        beta, alpha, vessel, nerve = chans[0], chans[1], chans[2], chans[3]
        # binarize the structural channels at their Otsu threshold
        vessel = (vessel > otsu_threshold(vessel)).astype(np.float32)
        nerve = (nerve > otsu_threshold(nerve)).astype(np.float32)

    return {
        "shape": beta.shape, "voxel_um": float(voxel_um),
        "beta": beta, "alpha": alpha, "vessel": vessel, "nerve": nerve,
        "truth": [], "long_axis": 1,
    }


# ============================================================================= reporting
def bar(frac, width=24):
    """Render `frac` as a fixed-width text bar of '#' (filled) and '.' (empty).

    `frac` is clamped to [0, 1] so the result is always exactly `width` characters and
    table columns stay aligned even for out-of-range input.
    """
    frac = min(1.0, max(0.0, frac))
    n = int(round(frac * width))
    return "#" * n + "." * (width - n)


def print_report(res, summary=False, top=10):
    """Print the human-readable morphometry report to stdout.

    Parameters:
        res     : result dict produced by analyze().
        summary : when True, omit the per-islet table (section [5]).
        top     : number of largest islets listed in section [5]; negative values are
                  treated as 0.
    """
    top = max(0, top)
    print("=" * 74)
    print(" PancreaClear3D - ex vivo 3D islet morphometry (DM / animal-experiment tool)")
    print("=" * 74)
    print(" " + DISCLAIMER)
    print("-" * 74)
    eng = "scipy.ndimage (watershed split)" if res["used_watershed"] else (
        "scipy.ndimage labeling" if HAVE_SCIPY else "pure-numpy fallback labeling")
    print(f" segmentation engine : {eng}")
    print(f" volume shape (z,y,x): {res['shape']}   voxel size: {res['voxel_um']} um")
    print(f" endocrine threshold : {res['endocrine_threshold']:.3f} (Otsu on insulin+glucagon)")
    print("-" * 74)

    # (1) counts / mass / ratio
    print(" [1] Islet objects, beta-cell mass, beta/alpha composition")
    print(f"     islet count            : {res['n_islets']}")
    print(f"     total beta-cell volume : {res['beta_mass_3d_um3']:,.0f} um^3  (3D beta-mass proxy)")
    print(f"     total alpha-cell volume: {res['alpha_mass_3d_um3']:,.0f} um^3")
    bar_ratio = res["beta_alpha_ratio"]
    rtxt = f"{bar_ratio:.2f} : 1" if math.isfinite(bar_ratio) else "inf (no alpha)"
    print(f"     beta : alpha ratio     : {rtxt}")

    # (2) size distribution + zonal map
    print(" [2] Islet size-class distribution & head/body/tail zonal map")
    total = max(1, res["n_islets"])  # avoid division by zero when nothing was segmented
    for cls, c in res["class_counts"].items():
        print(f"     size {cls:<11}: {c:3d}  {bar(c / total)}")
    print("     zonal distribution (count | volume um^3):")
    # List the zones actually present in the results (they come from the reference
    # table) rather than a fixed head/body/tail list.
    for zone, zc in res["zone_counts"].items():
        zv = res["zone_vol_um3"].get(zone, 0.0)
        print(f"       {zone:<5}: {zc:3d} islets  {bar(zc / total)}  vol={zv:,.0f}")

    # (3) vascular / neural
    print(" [3] Per-islet vascular (CD31) density & neural (TH+/CGRP+) proximity")
    print(f"     proximity threshold    : {res['prox_thr_um']} um")
    if res["islets"]:
        vds = [i["vasc_density"] for i in res["islets"]]
        innerv = sum(1 for i in res["islets"] if i["innervated"])
        print(f"     mean vascular density  : {np.mean(vds):.3f} "
              f"(islet-voxel fraction within threshold of a CD31+ voxel)")
        print(f"     innervated islets      : {innerv}/{res['n_islets']} "
              f"({100.0 * innerv / total:.0f}%)")

    # (4) stereology bias
    print(" [4] Design-based stereology check: 2D section estimate vs 3D ground truth")
    b = res["bias"]
    print(f"     3D ground-truth beta volume: {b['true_vol_um3']:,.0f} um^3")
    print("     sampling scheme         sec  mean est(um^3)  typ.|bias|  (+/-SD)  worst")
    for e in b["estimates"]:
        sd = e.get("bias_std_pct")
        if sd is None:
            # the full-3D reference row carries no spread statistics
            print(f"     {e['scheme']:<23} {e['n_sections']:>3d}  {e['est_vol_um3']:>13,.0f}  "
                  f"   exact")
        else:
            print(f"     {e['scheme']:<23} {e['n_sections']:>3d}  {e['est_vol_um3']:>13,.0f}  "
                  f"  {e['bias_pct']:5.0f}%   +/-{sd:3.0f}%  {e['bias_worst_pct']:+5.0f}%")
    # headline: range of the typical (mean-absolute) bias across the sparse schemes
    sparse = [e for e in b["estimates"] if e.get("bias_std_pct") is not None]
    if sparse:
        lo = min(e["bias_pct"] for e in sparse)
        hi = max(e["bias_pct"] for e in sparse)
        print(f"     => sparse 2D sampling typically biases beta-mass by ~{lo:.0f}-{hi:.0f}% "
              f"vs 3D ground truth.")

    if not summary:
        # top islets table (res["islets"] is already sorted largest-first)
        print(f" [5] Top {top} islets by volume")
        print("     id   vol_um3        diam_um  zone  class        vasc   nerve_dist_um")
        for i in res["islets"][:top]:
            print(f"     {i['id']:>3}  {i['vol_um3']:>12,.0f}  {i['diam_um']:>7.1f}  "
                  f"{i['zone']:<5} {i['size_class']:<11} {i['vasc_density']:>4.2f}  "
                  f"{i['nerve_min_dist_um']:>8.1f}")

    print("-" * 74)
    print(" " + DISCLAIMER)
    print("=" * 74)


# ============================================================================= CLI
def build_parser():
    """Build the argparse command-line interface.

    `--demo` and `--input` are mutually exclusive data sources; the remaining options
    control report detail, the synthetic volume, the proximity threshold and JSON export.
    """
    epilog_lines = [
        "Research / reference use only - NOT a clinical diagnostic device.",
        "Examples:",
        "  python3 main.py                 # synthetic demo, full report",
        "  python3 main.py --summary       # condensed report",
        "  python3 main.py --top 5         # show top 5 islets",
        "  python3 main.py --input vol.tif # load your own TIFF stack (needs tifffile)",
    ]
    parser = argparse.ArgumentParser(
        prog="main.py",
        description=("PancreaClear3D - 3D islet beta-cell mass & innervation morphometry "
                     "from tissue-cleared light-sheet pancreas volumes (DM, research use only)."),
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="\n".join(epilog_lines) + "\n",
    )

    # data source: built-in synthetic volume or a user-supplied stack
    source = parser.add_mutually_exclusive_group()
    source.add_argument(
        "--demo", action="store_true",
        help="use the built-in synthetic cleared-pancreas volume (default)")
    source.add_argument(
        "--input", metavar="PATH", default=None,
        help=("load a user multi-channel TIFF/OME-TIFF stack "
              "(beta,alpha,vessel,nerve channels); requires tifffile"))

    parser.add_argument(
        "--summary", action="store_true",
        help="print condensed summary (omit per-islet table)")
    parser.add_argument(
        "--top", type=int, default=10, metavar="N",
        help="number of largest islets to list (default 10)")
    parser.add_argument(
        "--n-islets", type=int, default=70, metavar="N",
        help="synthetic islet count (demo only, default 70)")
    parser.add_argument(
        "--seed", type=int, default=7,
        help="synthetic RNG seed (default 7)")
    parser.add_argument(
        "--proximity-um", type=float, default=None, metavar="UM",
        help="vascular/neural proximity threshold in um (default from reference)")
    parser.add_argument(
        "--json", metavar="PATH", default=None,
        help="also write the numeric results to a JSON file")
    return parser


def results_to_json(res):
    """Return a shallow copy of the analysis results prepared for JSON export.

    The label volume is left out, and the optional-dependency flags plus the research-use
    disclaimer are appended. `res` itself is not modified.
    """
    exportable = {key: value for key, value in res.items() if key != "labels"}
    exportable.update(
        have_scipy=HAVE_SCIPY,
        have_tifffile=HAVE_TIFFFILE,
        disclaimer=DISCLAIMER,
    )
    return exportable


def main(argv=None):
    """Command-line entry point: pick a volume, analyze it, print the report.

    A stack given with --input is loaded when possible; if loading fails the error is
    reported on stderr and the synthetic demo volume is used instead. With --json the
    numeric results are also written to a file; a failed write only produces a warning.
    Returns 0 as the process exit status.
    """
    args = build_parser().parse_args(argv)
    ref = load_reference()

    scipy_state = "yes" if HAVE_SCIPY else "NO (numpy fallback)"
    tiff_state = "yes" if HAVE_TIFFFILE else "no"
    print(f"[env] numpy={np.__version__}  scipy={scipy_state}  tifffile={tiff_state}",
          file=sys.stderr)

    volume = None
    if args.input:
        try:
            volume = load_input(args.input)
            print(f"[input] loaded {args.input} shape={volume['shape']}", file=sys.stderr)
        except Exception as exc:
            print(f"[error] could not load --input: {exc}", file=sys.stderr)
            print("[info] falling back to synthetic demo volume.", file=sys.stderr)
            volume = None
    if volume is None:
        # default path, and the fallback when the user stack could not be loaded
        volume = make_synthetic_volume(n_islets=args.n_islets, seed=args.seed)

    res = analyze(volume, ref, prox_thr_um=args.proximity_um)
    print_report(res, summary=args.summary, top=args.top)

    if not args.json:
        return 0
    try:
        with open(args.json, "w") as fh:
            # default=float converts any remaining numpy scalars
            json.dump(results_to_json(res), fh, indent=2, default=float)
        print(f"[json] wrote {args.json}", file=sys.stderr)
    except Exception as exc:
        print(f"[warn] could not write json: {exc}", file=sys.stderr)
    return 0


if __name__ == "__main__":
    sys.exit(main())
