#!/usr/bin/env python3
"""HepatoFabric3D - 3D ex vivo liver fibrosis network topology & zonation quantifier.

Domain: MASLD / MASH (metabolic dysfunction-associated steatotic liver disease).
Category: animal-experiment tool (ex vivo 3D image quantification).

One-liner:
  Quantify, from a cleared (light-sheet/confocal/SHG) liver lobe volume, the
  fibrosis network connectivity & bridging topology, ductular-reaction 3D
  branching, and steatosis zonation gradient - in 3D, not just 2D area%.

MEDICAL DISCLAIMER:
  Research / educational use only. This is NOT a clinical diagnosis or
  decision-making tool. Connectivity thresholds are illustrative defaults for
  the procedural synthetic demo and are not validated against histology.

Dependencies (graceful degradation - `python3 main.py` MUST always run):
  required : numpy
  optional : scipy        (labeling, distance transform, gaussian)  -> pure-numpy fallback
              scikit-image (3D skeletonize)                          -> distance-map ridge (medial-axis approximation) fallback
              networkx     (graph topology)                          -> manual union-find component/degree/cycle counting
             tifffile     (--input TIFF stack loading)              -> --input disabled if missing

No network access. Synthetic data only by default.
"""

from __future__ import annotations

import argparse
import json
import os
import sys

# ---------------------------------------------------------------------------
# Dependency probing (everything optional except numpy).
# ---------------------------------------------------------------------------
try:
    import numpy as np
except Exception as exc:  # pragma: no cover - numpy is the one hard requirement
    sys.stderr.write("FATAL: numpy is required (pip install numpy). %s\n" % exc)
    sys.exit(2)

HAVE_SCIPY = False
HAVE_SKIMAGE = False
HAVE_NETWORKX = False
HAVE_TIFFFILE = False

try:
    import scipy.ndimage as ndi  # type: ignore
    HAVE_SCIPY = True
except Exception:
    ndi = None

try:
    from skimage.morphology import skeletonize as _sk_skeletonize  # type: ignore
    HAVE_SKIMAGE = True
except Exception:
    _sk_skeletonize = None

try:
    import networkx as nx  # type: ignore
    HAVE_NETWORKX = True
except Exception:
    nx = None

try:
    import tifffile  # type: ignore
    HAVE_TIFFFILE = True
except Exception:
    tifffile = None

# Script-relative locations so the tool can be run from any working directory.
HERE = os.path.dirname(os.path.abspath(__file__))
DATA_DIR = os.path.join(HERE, "data")
# Optional reference-parameter file; load_reference() falls back to the
# built-in copy from _builtin_reference() when it cannot be loaded.
REF_PATH = os.path.join(DATA_DIR, "fibrosis_reference.json")

# Short disclaimer echoed in the report header, comparison and summary output.
DISCLAIMER = (
    "RESEARCH/EDUCATIONAL USE ONLY - NOT a clinical diagnosis or "
    "decision-making tool. Connectivity thresholds are illustrative."
)

# Fibrosis stages of the synthetic demo, in increasing severity. The list index
# also offsets the per-stage RNG seed and scales the collagen speckle density.
STAGES = ["F1", "F2", "F3", "F4"]


# ---------------------------------------------------------------------------
# Reference loading
# ---------------------------------------------------------------------------
def load_reference():
    """Load the reference parameters, falling back to the built-in copy.

    Returns the dict parsed from REF_PATH when the file can be opened, parsed
    as UTF-8 JSON, and contains the sections the analyzers index directly
    (``synthetic_stage_params`` and ``connectivity_index_thresholds``).
    Otherwise, a warning is written to stderr and ``_builtin_reference()`` is
    returned. ``volume_defaults`` is optional because it is read with ``.get``.
    """
    required = ("synthetic_stage_params", "connectivity_index_thresholds")
    try:
        with open(REF_PATH, "r", encoding="utf-8") as fh:
            ref = json.load(fh)
    except (OSError, ValueError) as exc:  # missing/unreadable file, bad JSON/encoding
        sys.stderr.write("WARN: could not load reference JSON (%s); using built-in.\n" % exc)
        return _builtin_reference()
    # A structurally wrong file would otherwise crash much later with a KeyError.
    missing = [key for key in required if not isinstance(ref, dict) or key not in ref]
    if missing:
        sys.stderr.write("WARN: reference JSON lacks %s; using built-in.\n" % ", ".join(missing))
        return _builtin_reference()
    return ref


def _builtin_reference():
    """Return the built-in reference parameters (fallback for the JSON file).

    Sections:
      synthetic_stage_params        per-stage (F1-F4) knobs read by SyntheticLiver
      connectivity_index_thresholds bands tested as ``min <= idx < max`` by
                                    classify_connectivity
      volume_defaults               synthetic volume shape (z, y, x order),
                                    ``voxel_size_um`` and the base RNG seed
    A fresh dict is built on every call, so callers may mutate the result.
    """
    return {
        "synthetic_stage_params": {
            # bridge_completeness = fraction of each septum path actually painted
            "F1": {"label": "F1 - early", "n_portal_tracts": 6, "n_bridges": 0,
                   "fiber_thickness": 1.2, "bridge_completeness": 0.0,
                   "collagen_area_fraction_target": 0.06,
                   "ductular_reaction_intensity": 0.15, "steatosis_periportal_bias": 0.55},
            "F2": {"label": "F2 - portal", "n_portal_tracts": 7, "n_bridges": 2,
                   "fiber_thickness": 1.6, "bridge_completeness": 0.35,
                   "collagen_area_fraction_target": 0.10,
                   "ductular_reaction_intensity": 0.35, "steatosis_periportal_bias": 0.45},
            "F3": {"label": "F3 - bridging", "n_portal_tracts": 8, "n_bridges": 6,
                   "fiber_thickness": 2.0, "bridge_completeness": 0.7,
                   "collagen_area_fraction_target": 0.16,
                   "ductular_reaction_intensity": 0.6, "steatosis_periportal_bias": 0.4},
            "F4": {"label": "F4 - cirrhosis", "n_portal_tracts": 9, "n_bridges": 12,
                   "fiber_thickness": 2.4, "bridge_completeness": 0.95,
                   "collagen_area_fraction_target": 0.22,
                   "ductular_reaction_intensity": 0.85, "steatosis_periportal_bias": 0.35},
        },
        # Half-open [min, max) ranges; 99.0 acts as an open upper bound.
        "connectivity_index_thresholds": {"bands": [
            {"band": "minimal", "min": 0.0, "max": 0.25, "suggests_stage": "F1"},
            {"band": "low", "min": 0.25, "max": 0.6, "suggests_stage": "F2"},
            {"band": "moderate", "min": 0.6, "max": 1.3, "suggests_stage": "F3"},
            {"band": "high", "min": 1.3, "max": 99.0, "suggests_stage": "F4"}]},
        "volume_defaults": {"shape_zyx": [48, 96, 96], "voxel_size_um": [2.0, 1.0, 1.0],
                            "random_seed": 20260601},
    }


# ---------------------------------------------------------------------------
# Lightweight numpy fallbacks (used when scipy/skimage absent)
# ---------------------------------------------------------------------------
def _gaussian(vol, sigma):
    """Smooth `vol` with a Gaussian-like kernel of width `sigma` (in voxels).

    Uses ``scipy.ndimage.gaussian_filter`` when scipy is available. Otherwise it
    applies a crude separable box blur of radius ``max(1, round(sigma))`` along
    every axis, with implicit zero padding at the borders. The output always
    has the same shape as `vol`.
    """
    if HAVE_SCIPY:
        return ndi.gaussian_filter(vol, sigma)
    out = vol.astype(np.float32)
    r = max(1, int(round(sigma)))
    k = np.ones(2 * r + 1, dtype=np.float32)
    k /= k.sum()

    def _blur_line(line):
        # The full convolution cropped to the centred len(line) samples matches
        # mode="same" when len(line) >= len(k). It also keeps the length on
        # short axes, where mode="same" would return len(k) samples and break
        # the volume's shape (e.g. a single-slice (1, Y, X) stack).
        return np.convolve(line, k, mode="full")[r:r + line.shape[0]]

    for ax in range(vol.ndim):
        out = np.apply_along_axis(_blur_line, ax, out)
    return out


def _label(mask):
    """Label connected True regions of `mask` using full (corner-touching) connectivity.

    Returns ``(labels, n)``, using 26-connectivity in 3D and 8-connectivity in
    2D. It delegates to ``scipy.ndimage.label`` when scipy is present, and to
    the pure-numpy union-find otherwise.
    """
    if not HAVE_SCIPY:
        return _label_numpy(mask)
    full_structure = np.ones(tuple(3 for _ in range(mask.ndim)), dtype=bool)
    return ndi.label(mask, structure=full_structure)


def _label_numpy(mask):
    """Connected-component labelling of `mask` in pure numpy/Python (union-find).

    Voxels touching by a face, edge or corner belong to the same component.
    Labels are int32 and numbered 1..n in the order components are first met
    in row-major (argwhere) order; background stays 0. Returns ``(labels, n)``.
    """
    points = np.argwhere(mask)
    if points.size == 0:
        return np.zeros_like(mask, dtype=np.int32), 0

    slot = {tuple(point): k for k, point in enumerate(points)}
    owner = list(range(len(points)))

    def root(k):
        # path halving keeps the union-find trees shallow
        while owner[k] != k:
            owner[k] = owner[owner[k]]
            k = owner[k]
        return k

    # every offset in {-1, 0, 1}^ndim except the all-zero one
    shifts = [tuple(int(v) - 1 for v in cell)
              for cell in np.ndindex(*((3,) * mask.ndim))
              if any(v != 1 for v in cell)]

    for point, k in slot.items():
        for shift in shifts:
            j = slot.get(tuple(p + s for p, s in zip(point, shift)))
            if j is None:
                continue
            rk, rj = root(k), root(j)
            if rk != rj:
                owner[rk] = rj

    labels = np.zeros_like(mask, dtype=np.int32)
    numbering = {}
    for point, k in slot.items():
        labels[point] = numbering.setdefault(root(k), len(numbering) + 1)
    return labels, len(numbering)


def _distance_to_mask(mask):
    """Distance from each voxel to the nearest True voxel in `mask`.

    With scipy this is the exact Euclidean distance
    (``distance_transform_edt(~mask)``). The numpy fallback instead returns an
    integer city-block distance, grown one 6-neighbour dilation layer at a time
    until no new voxel is reached. Voxels that are never reached (only possible
    when `mask` has no True voxel) are set to ``max(mask.shape)``.
    """
    if HAVE_SCIPY:
        return ndi.distance_transform_edt(~mask)
    seeds = np.asarray(mask, dtype=bool)
    dist = np.full(seeds.shape, np.inf, dtype=np.float32)
    dist[seeds] = 0.0
    reached = seeds.copy()
    frontier = seeds.copy()
    layer = 0
    # No depth cap: city-block distances can exceed max(mask.shape). Clamping
    # them would make distant voxels tie when distances are compared.
    while frontier.any():
        layer += 1
        frontier = _dilate6(frontier) & ~reached
        dist[frontier] = layer
        reached |= frontier
    dist[~reached] = float(max(seeds.shape))
    return dist


def _dilate6(mask):
    """Binary dilation of `mask` by its face neighbours (6-connectivity in 3D).

    Works for any number of dimensions. Neighbours are found by slice shifts
    rather than ``np.roll``, so nothing wraps around the array: voxels on
    opposite faces are not treated as adjacent.
    """
    out = mask.copy()
    ndim = mask.ndim
    for ax in range(ndim):
        n = mask.shape[ax]
        if n < 2:
            continue  # a length-1 axis has no neighbours along it
        head = [slice(None)] * ndim
        tail = [slice(None)] * ndim
        head[ax] = slice(0, n - 1)
        tail[ax] = slice(1, n)
        head, tail = tuple(head), tuple(tail)
        out[tail] |= mask[head]  # grow towards higher indices
        out[head] |= mask[tail]  # grow towards lower indices
    return out


def skeletonize3d(mask):
    """3D skeleton of a boolean mask.

    With scikit-image, this is its ``skeletonize``. Otherwise it is an
    approximation: the voxels of `mask` whose distance-to-background is
    positive and not exceeded by any axis-aligned neighbour (a ridge of the
    distance map). If that selects nothing, it returns the voxels deeper than
    half the maximum distance.
    """
    if HAVE_SKIMAGE:
        return _sk_skeletonize(mask).astype(bool)
    depth = _distance_to_mask(~mask)  # distance from inside voxels to background
    if not mask.any():
        return np.zeros_like(mask, dtype=bool)
    ridge = mask & (depth > 0) & (depth >= _neighbor_max(depth) - 1e-6)
    if ridge.any():
        return ridge
    # degenerate ridge: keep the thick core so the skeleton is never empty
    return mask & (depth > 0.5 * depth.max())


def _neighbor_max(vol):
    """Per-voxel maximum over the axis-aligned (face) neighbours of `vol`.

    Works for any number of dimensions. Borders are not wrapped (unlike
    ``np.roll``), so a border voxel only sees neighbours that exist. A voxel
    with no neighbours gets ``-inf``. The result dtype is
    ``result_type(vol, float32)``.
    """
    ndim = vol.ndim
    mx = np.full(vol.shape, -np.inf, dtype=np.result_type(vol.dtype, np.float32))
    for ax in range(ndim):
        n = vol.shape[ax]
        if n < 2:
            continue
        head = [slice(None)] * ndim
        tail = [slice(None)] * ndim
        head[ax] = slice(0, n - 1)
        tail[ax] = slice(1, n)
        head, tail = tuple(head), tuple(tail)
        mx[tail] = np.maximum(mx[tail], vol[head])  # neighbour at index - 1
        mx[head] = np.maximum(mx[head], vol[tail])  # neighbour at index + 1
    return mx


# ---------------------------------------------------------------------------
# Synthetic volume generation (procedural fibrosis network F1..F4)
# ---------------------------------------------------------------------------
class SyntheticLiver:
    """Procedurally generated cleared-liver-lobe volume with multiple channels.

    After ``generate()``, the instance holds three float32 (Z, Y, X) channels:
    ``collagen``, ``ductular`` and ``steatosis``. It also holds the node lists
    ``portal_nodes`` and ``central_nodes`` ((z, y, x) ints) and the ground-truth
    ``bridge_edges`` painted into the collagen channel, stored as
    ``((type, idx), (type, idx), "PP"|"PC", completeness)``. All randomness
    comes from one numpy Generator seeded with ``seed + STAGES.index(stage)``.
    """

    def __init__(self, stage, ref, seed=None):
        self.stage = stage
        sp = ref["synthetic_stage_params"][stage]
        self.params = sp
        vd = ref.get("volume_defaults", {})
        self.shape = tuple(vd.get("shape_zyx", [48, 96, 96]))
        self.voxel_um = tuple(vd.get("voxel_size_um", [2.0, 1.0, 1.0]))
        s = seed if seed is not None else vd.get("random_seed", 20260601)
        # stage-dependent seed so each stage is reproducible yet distinct
        self.rng = np.random.default_rng(int(s) + STAGES.index(stage))
        self.portal_nodes = []   # (z,y,x)
        self.central_nodes = []
        self.collagen = None
        self.ductular = None
        self.steatosis = None
        # defined up front so analyzers never hit AttributeError before generate()
        self.bridge_edges = []

    def generate(self):
        """Build all channels in place and return ``self`` (for chaining)."""
        Z, Y, X = self.shape
        sp = self.params
        rng = self.rng

        # place portal tracts (graph nodes) on a jittered lattice
        n_portal = sp["n_portal_tracts"]
        ny = int(np.ceil(np.sqrt(n_portal)))
        coords = []
        for i in range(n_portal):
            gy = (i % ny + 0.5) / ny
            gx = (i // ny + 0.5) / ny
            y = int(np.clip(gy * Y + rng.normal(0, Y * 0.04), 2, Y - 3))
            x = int(np.clip(gx * X + rng.normal(0, X * 0.04), 2, X - 3))
            z = int(np.clip(Z * 0.5 + rng.normal(0, Z * 0.12), 2, Z - 3))
            coords.append((z, y, x))
        self.portal_nodes = coords

        # central veins: interleaved, offset from portal lattice (classic lobule geometry)
        cvs = []
        n_cv = max(2, n_portal - 2)
        for i in range(n_cv):
            y = int(np.clip(((i + 0.5) / n_cv) * Y + rng.normal(0, Y * 0.05), 2, Y - 3))
            x = int(np.clip((((i * 2 + 1) % n_cv) / n_cv) * X + rng.normal(0, X * 0.05), 2, X - 3))
            z = int(np.clip(Z * 0.5 + rng.normal(0, Z * 0.12), 2, Z - 3))
            cvs.append((z, y, x))
        self.central_nodes = cvs

        # ---- collagen channel ----
        collagen = np.zeros(self.shape, dtype=np.float32)
        thick = sp["fiber_thickness"]

        # periportal collagen halos around every portal tract (always present, even F1)
        for (z, y, x) in coords:
            self._paint_blob(collagen, (z, y, x), radius=thick + 1.2, amp=1.0)

        # perisinusoidal speckle (chicken-wire, zone 3) - scales mildly with stage
        speckle_n = int(200 * (0.4 + STAGES.index(self.stage) * 0.25))
        for _ in range(speckle_n):
            z = rng.integers(1, Z - 1); y = rng.integers(1, Y - 1); x = rng.integers(1, X - 1)
            collagen[z, y, x] = max(collagen[z, y, x], rng.uniform(0.3, 0.7))

        # ---- bridging septa: the topology-defining edges ----
        self.bridge_edges = []  # (idx_a, idx_b, kind, completeness)
        n_bridges = sp["n_bridges"]
        completeness = sp["bridge_completeness"]
        pairs = self._candidate_bridge_pairs()
        rng.shuffle(pairs)
        for k in range(min(n_bridges, len(pairs))):
            (a_type, a_idx, an), (b_type, b_idx, bn) = pairs[k]
            comp = float(np.clip(rng.normal(completeness, 0.08), 0.05, 1.0))
            self._paint_septum(collagen, an, bn, thick, comp)
            kind = "PP" if (a_type == "P" and b_type == "P") else "PC"
            self.bridge_edges.append(((a_type, a_idx), (b_type, b_idx), kind, comp))

        collagen = _gaussian(collagen, 0.7)
        self.collagen = collagen

        # ---- ductular reaction channel (CK19+ biliary), grows along portal tracts ----
        duct = np.zeros(self.shape, dtype=np.float32)
        dr = sp["ductular_reaction_intensity"]
        for (z, y, x) in coords:
            n_branch = int(2 + dr * 8)
            self._paint_branching(duct, (z, y, x), n_branch, length=int(4 + dr * 10), rng=rng)
        self.ductular = _gaussian(duct, 0.6)

        # ---- steatosis droplets with zonation bias ----
        stea = np.zeros(self.shape, dtype=np.float32)
        bias = sp["steatosis_periportal_bias"]  # >0.5 => periportal-leaning
        n_drop = 900
        for _ in range(n_drop):
            if rng.random() < bias:
                anchor = coords[rng.integers(0, len(coords))]
            else:
                anchor = cvs[rng.integers(0, len(cvs))]
            z = int(np.clip(anchor[0] + rng.normal(0, Z * 0.13), 0, Z - 1))
            y = int(np.clip(anchor[1] + rng.normal(0, Y * 0.10), 0, Y - 1))
            x = int(np.clip(anchor[2] + rng.normal(0, X * 0.10), 0, X - 1))
            self._paint_blob(stea, (z, y, x), radius=rng.uniform(0.8, 2.0), amp=rng.uniform(0.5, 1.0))
        self.steatosis = stea
        return self

    # -- geometry helpers --
    def _candidate_bridge_pairs(self):
        """All portal-portal and portal-central node pairs, nearest 60% (min 4) kept."""
        pairs = []
        P, C = self.portal_nodes, self.central_nodes
        for i in range(len(P)):
            for j in range(i + 1, len(P)):
                pairs.append((("P", i, P[i]), ("P", j, P[j])))
        for i in range(len(P)):
            for j in range(len(C)):
                pairs.append((("P", i, P[i]), ("C", j, C[j])))
        # prefer nearby pairs (shorter septa first) for realism
        pairs.sort(key=lambda pr: _euclid(pr[0][2], pr[1][2]))
        # keep nearest 60% as candidates
        keep = max(4, int(len(pairs) * 0.6))
        return pairs[:keep]

    def _paint_blob(self, vol, center, radius, amp):
        """Max-composite a linear-falloff sphere into `vol` in place.

        Voxels within `radius` of `center` receive ``amp * (1 - d / radius)``
        and keep the larger of that and their current value. Parts of the
        sphere outside ``self.shape`` are clipped. The work is vectorized over
        the clipped bounding box; this is the generator's hot path.
        """
        Z, Y, X = self.shape
        r = int(np.ceil(radius)) + 1
        cz, cy, cx = (int(c) for c in center)
        z0, z1 = max(0, cz - r), min(Z, cz + r + 1)
        y0, y1 = max(0, cy - r), min(Y, cy + r + 1)
        x0, x1 = max(0, cx - r), min(X, cx + r + 1)
        if z0 >= z1 or y0 >= y1 or x0 >= x1:
            return  # bounding box entirely outside the volume
        dz = np.arange(z0, z1) - cz
        dy = np.arange(y0, y1) - cy
        dx = np.arange(x0, x1) - cx
        d = np.sqrt(dz[:, None, None] ** 2 + dy[None, :, None] ** 2 + dx[None, None, :] ** 2)
        inside = d <= radius
        if not inside.any():
            return
        falloff = amp * (1.0 - d / (radius + 1e-6))
        box = vol[z0:z1, y0:y1, x0:x1]  # basic slice -> a view, so writes reach `vol`
        box[inside] = np.maximum(box[inside], falloff[inside])

    def _paint_septum(self, vol, a, b, thick, completeness):
        """Paint blobs of radius `thick` along a->b, stopping after `completeness` of the path."""
        steps = int(max(abs(b[0] - a[0]), abs(b[1] - a[1]), abs(b[2] - a[2]))) * 2 + 1
        stop = completeness  # fraction of the path actually fibrosed
        for s in range(steps):
            t = s / (steps - 1) if steps > 1 else 0.0
            if t > stop:
                break
            z = a[0] + (b[0] - a[0]) * t
            y = a[1] + (b[1] - a[1]) * t
            x = a[2] + (b[2] - a[2]) * t
            self._paint_blob(vol, (int(round(z)), int(round(y)), int(round(x))),
                             radius=thick, amp=1.0)

    def _paint_branching(self, vol, origin, n_branch, length, rng):
        """Draw `n_branch` 1-voxel random walks of up to `length` steps from `origin`."""
        Z, Y, X = self.shape
        for _ in range(n_branch):
            direction = rng.normal(0, 1, size=3)
            nrm = np.linalg.norm(direction) + 1e-9
            direction = direction / nrm
            pos = np.array(origin, dtype=float)
            for _ in range(length):
                pos = pos + direction
                direction = direction + rng.normal(0, 0.3, size=3)
                direction /= (np.linalg.norm(direction) + 1e-9)
                z, y, x = int(round(pos[0])), int(round(pos[1])), int(round(pos[2]))
                if 0 <= z < Z and 0 <= y < Y and 0 <= x < X:
                    vol[z, y, x] = 1.0
                else:
                    break


def _euclid(a, b):
    """Euclidean distance between the first three coordinates of `a` and `b`, as float."""
    squared = sum((a[axis] - b[axis]) ** 2 for axis in range(3))
    return float(squared ** 0.5)


# ---------------------------------------------------------------------------
# Quantification
# ---------------------------------------------------------------------------
def threshold_collagen(collagen, frac=None):
    """Threshold collagen signal on the POSITIVE-voxel distribution.

    Using only voxels >0 avoids a degenerate all-empty mask when the volume is
    sparse (background-dominated), which is common for real cleared-tissue
    channels. Falls back to a max-relative cut (> 0.5 * max) if the quantile cut
    selects nothing (e.g. binary-ish input).

    Parameters
    ----------
    collagen : ndarray
        Collagen intensity volume.
    frac : float in [0, 1] or None
        Quantile of the positive voxels used as the cut; voxels strictly above
        it are kept. ``None`` keeps the historical default of 0.35 (35th
        percentile).

    Returns
    -------
    ndarray of bool, same shape as `collagen`.

    Raises
    ------
    ValueError
        If `frac` lies outside [0, 1].
    """
    pct = 35.0 if frac is None else 100.0 * float(frac)
    if not 0.0 <= pct <= 100.0:
        raise ValueError("frac must be within [0, 1], got %r" % (frac,))
    positive = collagen[collagen > 0]
    if positive.size == 0:
        return np.zeros(collagen.shape, dtype=bool)
    thr = float(np.percentile(positive, pct))
    mask = collagen > thr
    if not mask.any():  # degenerate quantile (e.g. binary-ish input) -> relative cut
        mask = collagen > (0.5 * float(collagen.max()))
    return mask


def count_branch_points(skel):
    """Count branch points and end points of a skeleton.

    A branch point is a skeleton voxel with >= 3 skeleton neighbours. An end
    point has exactly one. Neighbours use full connectivity (26 in 3D, 8 in 2D;
    any ndim works). The volume is zero-padded, so nothing wraps around:
    voxels on opposite faces are never neighbours (``np.roll`` would make them
    so).

    Returns ``(branch_points, end_points)`` as Python ints.
    """
    import itertools

    skel = np.asarray(skel, dtype=bool)
    if not skel.any():
        return 0, 0
    padded = np.pad(skel.astype(np.int16), 1)  # constant zero border
    neighbor_count = np.zeros(skel.shape, dtype=np.int16)
    for off in itertools.product((-1, 0, 1), repeat=skel.ndim):
        if not any(off):
            continue  # skip the voxel itself
        window = tuple(slice(1 + o, 1 + o + n) for o, n in zip(off, skel.shape))
        neighbor_count += padded[window]
    nc = neighbor_count[skel]
    return int(np.sum(nc >= 3)), int(np.sum(nc == 1))


def _offsets3():
    """List the 26 non-zero (dz, dy, dx) unit steps of a 3x3x3 neighbourhood, z-major order."""
    steps = (-1, 0, 1)
    return [(dz, dy, dx)
            for dz in steps
            for dy in steps
            for dx in steps
            if (dz, dy, dx) != (0, 0, 0)]


def analyze_fibrosis(sl, ref):
    """Quantify the collagen network of `sl` and classify its bridging connectivity.

    `sl` must expose ``collagen`` (intensity volume), ``portal_nodes``,
    ``central_nodes`` and ``bridge_edges``. Each edge's element [2] is "PP" or
    "PC" and element [3] is its completeness fraction.

    Voxel-level metrics (volume fraction above threshold, 3D components,
    skeleton branch/end points) come from the thresholded collagen channel.
    The connectivity index, band and graph summary are computed from
    ``bridge_edges`` only. The index is the number of bridges per portal node,
    weighted by ``0.4 + 0.6 * mean completeness``, so an empty edge list
    yields 0.0 regardless of the image content.
    """
    mask = threshold_collagen(sl.collagen)
    area_fraction = float(mask.mean())  # fraction of *volume* voxels above threshold
    _, n_comp = _label(mask)

    skel = skeletonize3d(mask)
    skel_voxels = int(skel.sum())
    branch_pts, end_pts = count_branch_points(skel)

    # bridging topology from the edge list attached to `sl`
    edges = sl.bridge_edges
    pp = [e for e in edges if e[2] == "PP"]
    pc = [e for e in edges if e[2] == "PC"]
    mean_comp = float(np.mean([e[3] for e in edges])) if edges else 0.0

    n_portal = len(sl.portal_nodes)
    n_central = len(sl.central_nodes)

    # connectivity index: bridging edges per portal node, weighted by completeness
    conn_index = (len(edges) / max(1, n_portal)) * (0.4 + 0.6 * mean_comp)

    # graph topology (networkx if available, else manual)
    graph_summary = build_graph_summary(sl)

    band = classify_connectivity(conn_index, ref)

    return {
        "collagen_area_fraction": area_fraction,
        "n_connected_components": int(n_comp),
        "skeleton_voxels": skel_voxels,
        "branch_points": branch_pts,
        "end_points": end_pts,
        "n_portal_nodes": n_portal,
        "n_central_nodes": n_central,
        "n_bridges_total": len(edges),
        "n_bridges_PP": len(pp),
        "n_bridges_PC": len(pc),
        "mean_bridge_completeness": mean_comp,
        "connectivity_index": conn_index,
        "connectivity_band": band,
        "graph": graph_summary,
        "skeleton_backend": "skimage" if HAVE_SKIMAGE else "numpy-medial-fallback",
        "label_backend": "scipy.ndimage" if HAVE_SCIPY else "numpy-unionfind-fallback",
    }


def build_graph_summary(sl):
    """Summarize the portal/central node graph whose edges are ``sl.bridge_edges``.

    Nodes are ("P", i) for every portal node and ("C", j) for every central
    node. Each bridge contributes the edge (bridge[0], bridge[1]). Returns node
    and edge counts, component count, largest component size, maximum degree
    and the circuit rank (E - V + components, floored at 0), plus the backend
    used (networkx or the manual union-find).
    """
    portal = [("P", i) for i in range(len(sl.portal_nodes))]
    central = [("C", j) for j in range(len(sl.central_nodes))]
    nodes = portal + central
    edges = [(bridge[0], bridge[1]) for bridge in sl.bridge_edges]

    if not HAVE_NETWORKX:
        n_cc, largest, max_deg, n_cycles = _graph_summary_manual(nodes, edges)
        backend = "manual-unionfind"
    else:
        graph = nx.Graph()
        graph.add_nodes_from(nodes)
        graph.add_edges_from(edges)
        components = list(nx.connected_components(graph))
        n_cc = len(components)
        largest = max((len(comp) for comp in components), default=0)
        max_deg = max((deg for _, deg in graph.degree()), default=0)
        n_cycles = graph.number_of_edges() - graph.number_of_nodes() + n_cc  # circuit rank
        backend = "networkx"

    summary = {}
    summary["n_graph_nodes"] = len(nodes)
    summary["n_graph_edges"] = len(edges)
    summary["n_graph_components"] = int(n_cc)
    summary["largest_component_nodes"] = int(largest)
    summary["max_node_degree"] = int(max_deg)
    summary["circuit_rank_cycles"] = int(max(0, n_cycles))
    summary["graph_backend"] = backend
    return summary


def _graph_summary_manual(nodes, edges):
    """Graph statistics without networkx, via union-find.

    `nodes` is a sequence of hashable node ids and `edges` a sequence of
    (u, v) pairs over them. Returns ``(n_components, largest_component_size,
    max_degree, circuit_rank)``, where circuit rank = E - V + components. It
    may be negative only if parallel edges were not counted; the caller clamps
    it.
    """
    position = {node: k for k, node in enumerate(nodes)}
    root_of = list(range(len(nodes)))
    degree = [0 for _ in nodes]

    def _root(k):
        # path halving
        while root_of[k] != k:
            root_of[k] = root_of[root_of[k]]
            k = root_of[k]
        return k

    for u, v in edges:
        iu = position[u]
        iv = position[v]
        degree[iu] += 1
        degree[iv] += 1
        ru, rv = _root(iu), _root(iv)
        if ru != rv:
            root_of[ru] = rv

    sizes = {}
    for k in range(len(nodes)):
        rk = _root(k)
        sizes[rk] = sizes.get(rk, 0) + 1

    n_components = len(sizes)
    largest = max(sizes.values(), default=0)
    max_degree = max(degree, default=0)
    circuit_rank = len(edges) - len(nodes) + n_components
    return n_components, largest, max_degree, circuit_rank


def classify_connectivity(idx, ref):
    """Map a connectivity index to its reference band.

    Bands in ``ref["connectivity_index_thresholds"]["bands"]`` are half-open
    ``[min, max)`` intervals checked in order. An index outside every band goes
    to the lowest band when it lies below that band's ``min``. Otherwise
    (above the top ``max``, inside a gap, or NaN) it goes to the highest band.
    The result therefore always names a band that exists in `ref`. Only an
    empty band list falls back to the fixed ``{"band": "high",
    "suggests_stage": "F4"}``.

    Returns ``{"band": ..., "suggests_stage": ...}``.
    """
    bands = ref["connectivity_index_thresholds"]["bands"]
    for b in bands:
        if b["min"] <= idx < b["max"]:
            return {"band": b["band"], "suggests_stage": b["suggests_stage"]}
    if not bands:
        return {"band": "high", "suggests_stage": "F4"}
    lowest = min(bands, key=lambda b: b["min"])
    highest = max(bands, key=lambda b: b["max"])
    chosen = lowest if idx < lowest["min"] else highest
    return {"band": chosen["band"], "suggests_stage": chosen["suggests_stage"]}


def analyze_ductular(sl):
    """Quantify the ductular-reaction channel ``sl.ductular``.

    The duct mask is the set of voxels strictly above the median of the
    positive signal. When that cut selects nothing (e.g. a binary-valued
    channel, where the median equals the maximum), it falls back to
    ``> 0.5 * max``, the same fallback threshold_collagen uses. An all-zero
    channel gives an empty mask.

    Returns voxel count, volume fraction, 3D component count, skeleton
    branch/end points, and branch points per component (0.0 when there are no
    components).
    """
    duct = sl.ductular
    positive = duct[duct > 0]
    if positive.size == 0:
        mask = np.zeros(duct.shape, dtype=bool)
    else:
        mask = duct > float(np.percentile(positive, 50.0))
        if not mask.any():  # degenerate median -> relative cut
            mask = duct > 0.5 * float(duct.max())
    _, n = _label(mask)
    skel = skeletonize3d(mask)
    branch, ends = count_branch_points(skel)
    vox = int(mask.sum())
    return {
        "ductular_voxels": vox,
        "ductular_area_fraction": float(mask.mean()),
        "ductular_components": int(n),
        "ductular_branch_points": branch,
        "ductular_end_points": ends,
        "branches_per_component": (branch / n) if n else 0.0,
    }


def analyze_zonation(sl):
    """Steatosis zonation: periportal vs pericentral droplet density gradient.

    Droplet voxels are those with ``sl.steatosis > 0.3``. Each one is assigned
    to the portal side when its distance to the nearest portal node is <= its
    distance to the nearest central node, and to the central side otherwise.
    The gradient is ``periportal_frac - pericentral_frac``. Values above +0.08
    are reported as zone-1-predominant, values below -0.08 as
    zone-3-predominant, and anything in between as panlobular.

    The returned dict always has the same keys, including
    ``n_droplet_voxels``, even when no droplets are present.
    """
    stea = sl.steatosis
    drop_mask = stea > 0.3
    n_drop = int(np.count_nonzero(drop_mask))
    if n_drop == 0:
        # checked first so the two distance maps are not computed for nothing
        return {"periportal_droplet_frac": 0.0, "pericentral_droplet_frac": 0.0,
                "zonation_gradient": 0.0, "zonation_direction": "none",
                "n_droplet_voxels": 0}

    portal_mask = np.zeros(stea.shape, dtype=bool)
    for (z, y, x) in sl.portal_nodes:
        portal_mask[z, y, x] = True
    central_mask = np.zeros(stea.shape, dtype=bool)
    for (z, y, x) in sl.central_nodes:
        central_mask[z, y, x] = True

    d_portal = _distance_to_mask(portal_mask)
    d_central = _distance_to_mask(central_mask)

    # vectorized nearest-side classification of every droplet voxel (ties -> portal)
    near_portal = int(np.count_nonzero(d_portal[drop_mask] <= d_central[drop_mask]))
    near_central = n_drop - near_portal
    pp_frac = near_portal / n_drop
    pc_frac = near_central / n_drop
    grad = pp_frac - pc_frac
    direction = ("periportal-predominant (zone 1)" if grad > 0.08 else
                 "pericentral-predominant (zone 3)" if grad < -0.08 else
                 "panlobular / no strong gradient")
    return {
        "periportal_droplet_frac": pp_frac,
        "pericentral_droplet_frac": pc_frac,
        "zonation_gradient": grad,
        "zonation_direction": direction,
        "n_droplet_voxels": n_drop,
    }


def analyze_stage(stage, ref, seed=None):
    """Generate the synthetic volume for `stage` and run the fibrosis, ductular and zonation analyzers."""
    liver = SyntheticLiver(stage, ref, seed=seed)
    liver.generate()
    fibrosis = analyze_fibrosis(liver, ref)
    ductular = analyze_ductular(liver)
    zonation = analyze_zonation(liver)
    return dict(stage=stage, label=liver.params["label"],
                fibrosis=fibrosis, ductular=ductular, zonation=zonation)


# ---------------------------------------------------------------------------
# TIFF input (optional)
# ---------------------------------------------------------------------------
def analyze_input_volume(path, ref):
    """Analyze a user-supplied TIFF stack; return a result dict, or None on error.

    The stack is converted to float32 and divided by its maximum. A 2D image is
    treated as a single-slice volume. A 4D stack is reduced to channel 0,
    taken from the last axis when that axis has <= 4 entries (Z,Y,X,C) and
    from the first axis otherwise (C,Z,Y,X). Any other dimensionality is
    rejected.

    Errors (tifffile missing, file absent or unreadable, unsupported shape) are
    printed to stdout, and the function returns None instead of raising.
    """
    if not HAVE_TIFFFILE:
        print("ERROR: --input requires the 'tifffile' package, which is not installed.")
        print("       Install with: pip install tifffile  (or run the built-in --demo)")
        return None
    if not os.path.exists(path):
        print("ERROR: input file not found: %s" % path)
        return None
    try:
        raw = tifffile.imread(path)
    except Exception as exc:  # tifffile raises assorted types for corrupt/unsupported files
        print("ERROR: could not read TIFF stack %s: %s" % (path, exc))
        return None
    vol = raw.astype(np.float32)
    if vol.ndim == 2:
        vol = vol[None, ...]
    if vol.ndim == 4:  # assume (Z,Y,X,C) or (C,Z,Y,X) -> take channel 0
        vol = vol[..., 0] if vol.shape[-1] <= 4 else vol[0]
    if vol.ndim != 3 or vol.size == 0:
        # the analyzers index (z, y, x) throughout; fail here with a clear message
        print("ERROR: unsupported TIFF shape %s (expected 2D, 3D, or 4D with one channel axis)"
              % (tuple(raw.shape),))
        return None
    vol = vol / (vol.max() + 1e-9)
    # treat the single channel as collagen; build a minimal SyntheticLiver-like holder
    holder = _ExternalVolume(vol, ref)
    fib = analyze_fibrosis(holder, ref)
    duct = analyze_ductular(holder)
    zon = analyze_zonation(holder)
    return {"stage": "USER", "label": "user-supplied volume: %s" % os.path.basename(path),
            "fibrosis": fib, "ductular": duct, "zonation": zon}


class _ExternalVolume:
    """Wrap a user TIFF so the same analyzers work. Nodes inferred crudely.

    The single input channel is reused as the collagen, ductular and steatosis
    channels, because no separate stains are available. The ductular and
    zonation metrics computed on it are therefore only proxies. The bridging
    ground truth is unknown, so ``bridge_edges`` is empty.
    """

    def __init__(self, vol, ref):
        # `ref` is accepted for symmetry with SyntheticLiver; it is not used here.
        self.collagen = vol
        self.ductular = vol  # without separate channels, reuse the single channel
        self.steatosis = vol
        self.shape = vol.shape
        # Pseudo portal nodes: the brightest *mutually separated* peaks of the
        # blurred signal. Plain top-k voxels of a smooth map are almost always
        # adjacent voxels of a single blob.
        blurred = _gaussian(vol, 2.0)
        k = min(8, max(2, blurred.size // 50000))
        min_sep = max(2.0, 0.1 * max(vol.shape))
        self.portal_nodes = self._pick_peaks(blurred, k, min_sep)
        Z = vol.shape[0]
        n_cv = max(2, len(self.portal_nodes) - 2)
        # pseudo central veins: portal peaks shifted a quarter of the stack in z
        self.central_nodes = [(min(Z - 1, p[0] + Z // 4), p[1], p[2])
                              for p in self.portal_nodes[:n_cv]]
        self.bridge_edges = []  # unknown ground truth for user data
        self.params = {"label": "user volume"}

    @staticmethod
    def _pick_peaks(blurred, k, min_sep):
        """Return up to `k` voxel coordinates of `blurred`, brightest first.

        Each returned coordinate is at least `min_sep` voxels (Euclidean) from
        every previously picked one. Only the brightest ``max(4096, 64 * k)``
        voxels are considered, which keeps large volumes cheap. Fewer than `k`
        peaks are returned when the candidates cannot be separated that far.
        """
        arr = np.asarray(blurred)
        flat = arr.ravel()
        if flat.size == 0 or k <= 0:
            return []
        n_cand = min(flat.size, max(4096, 64 * k))
        cand = np.argpartition(flat, flat.size - n_cand)[flat.size - n_cand:]
        cand = cand[np.argsort(-flat[cand], kind="stable")]
        min_sep2 = float(min_sep) ** 2
        picked = []
        for idx in cand:
            c = tuple(int(v) for v in np.unravel_index(int(idx), arr.shape))
            if all(sum((a - b) ** 2 for a, b in zip(c, p)) >= min_sep2 for p in picked):
                picked.append(c)
                if len(picked) >= k:
                    break
        return picked


# ---------------------------------------------------------------------------
# Reporting
# ---------------------------------------------------------------------------
def fmt_pct(x):
    """Render a fraction as a two-decimal percentage string, e.g. 0.1234 -> '12.34%'."""
    return "{:.2f}%".format(x * 100.0)


def print_header():
    """Print the banner: tool name, domain, disclaimer and the active optional backends."""
    rule = "=" * 74
    backends = "  Backends: skel=%s | label=%s | graph=%s | tiff=%s" % (
        "skimage" if HAVE_SKIMAGE else "numpy-fallback",
        "scipy" if HAVE_SCIPY else "numpy-fallback",
        "networkx" if HAVE_NETWORKX else "manual-fallback",
        "tifffile" if HAVE_TIFFFILE else "absent",
    )
    banner = [
        rule,
        "  HepatoFabric3D  -  3D liver fibrosis network topology & zonation",
        "  Domain: MASLD / MASH   |   Category: ex vivo 3D image quantification",
        rule,
        "  " + DISCLAIMER,
        backends,
        rule,
    ]
    for line in banner:
        print(line)


def print_stage_report(res, top=None):
    """Print the four-section human-readable report for one result dict.

    `res` has the shape returned by analyze_stage / analyze_input_volume: keys
    "stage", "label", "fibrosis", "ductular" and "zonation". Sections: [1]
    collagen network, [2] bridging-septa topology and graph summary, [3]
    ductular reaction, [4] steatosis zonation. `top` is accepted but not used
    by this function.
    """
    f = res["fibrosis"]; d = res["ductular"]; z = res["zonation"]
    g = f["graph"]
    print()
    print("STAGE %s  -  %s" % (res["stage"], res["label"]))
    print("-" * 74)
    # [1] voxel-level collagen metrics plus the connectivity index/band
    print(" [1] Collagen / fibrosis network")
    print("     2D-equivalent area fraction ...... %s" % fmt_pct(f["collagen_area_fraction"]))
    print("     3D connected components ........... %d" % f["n_connected_components"])
    print("     skeleton voxels ................... %d   (backend: %s)"
          % (f["skeleton_voxels"], f["skeleton_backend"]))
    print("     skeleton branch / end points ...... %d / %d" % (f["branch_points"], f["end_points"]))
    print("     >> CONNECTIVITY INDEX ............. %.3f  -> band '%s' (suggests %s)"
          % (f["connectivity_index"], f["connectivity_band"]["band"],
             f["connectivity_band"]["suggests_stage"]))
    # [2] bridge counts by kind and the node-graph summary
    print(" [2] Bridging septa topology (portal-central axis)")
    print("     portal nodes / central nodes ...... %d / %d" % (f["n_portal_nodes"], f["n_central_nodes"]))
    print("     bridging septa total .............. %d  (PP=%d  portal-portal | PC=%d portal-central)"
          % (f["n_bridges_total"], f["n_bridges_PP"], f["n_bridges_PC"]))
    print("     mean bridge completeness .......... %s" % fmt_pct(f["mean_bridge_completeness"]))
    print("     graph: components=%d  largest=%d nodes  maxDeg=%d  cycles=%d  (%s)"
          % (g["n_graph_components"], g["largest_component_nodes"],
             g["max_node_degree"], g["circuit_rank_cycles"], g["graph_backend"]))
    # [3] ductular channel morphology
    print(" [3] Ductular reaction (CK19+ biliary 3D branching)")
    print("     ductular area fraction ............ %s" % fmt_pct(d["ductular_area_fraction"]))
    print("     components / branch points ........ %d / %d" % (d["ductular_components"], d["ductular_branch_points"]))
    print("     branches per component ............ %.2f" % d["branches_per_component"])
    # [4] droplet split between portal and central sides
    print(" [4] Steatosis zonation gradient")
    print("     periportal droplet fraction ....... %s" % fmt_pct(z["periportal_droplet_frac"]))
    print("     pericentral droplet fraction ...... %s" % fmt_pct(z["pericentral_droplet_frac"]))
    print("     zonation gradient (PP - PC) ....... %+.3f  -> %s"
          % (z["zonation_gradient"], z["zonation_direction"]))


def print_comparison(results):
    """Print a cross-stage comparison table and, when informative, a key-insight note.

    The table shows one row per result (area fraction, components, connectivity
    index, bridge counts, band). The insight pair is the one that maximizes
    ``conn_gap * (1 - area_gap / area_span)``, i.e. a large difference in
    connectivity index together with a small, span-normalized difference in
    collagen area fraction. The note is printed only when that pair really
    differs in connectivity index (conn_gap > 0); otherwise its text would
    contradict the numbers.
    """
    print()
    print("=" * 74)
    print("  COMPARATIVE REPORT  -  '2D area% can match while 3D connectivity differs'")
    print("=" * 74)
    hdr = "  %-5s | %-9s | %-7s | %-9s | %-7s | %-8s | %-6s" % (
        "stage", "area%", "comps", "connIdx", "bridges", "PP/PC", "band")
    print(hdr)
    print("  " + "-" * 70)
    for r in results:
        f = r["fibrosis"]
        print("  %-5s | %-9s | %-7d | %-9.3f | %-7d | %-8s | %-6s" % (
            r["stage"], fmt_pct(f["collagen_area_fraction"]),
            f["n_connected_components"], f["connectivity_index"],
            f["n_bridges_total"], "%d/%d" % (f["n_bridges_PP"], f["n_bridges_PC"]),
            f["connectivity_band"]["band"]))
    print("  " + "-" * 70)

    # find the most instructive pair: as-close-as-possible 2D area%, but the
    # largest divergence in 3D connectivity. Heavily reward area similarity so
    # the chosen pair genuinely illustrates "same area%, different topology".
    best = None
    if len(results) > 1:
        # normalize area gap by the observed area spread so the penalty is scale-free
        areas = [r["fibrosis"]["collagen_area_fraction"] for r in results]
        area_span = (max(areas) - min(areas)) or 1.0
        for i in range(len(results)):
            for j in range(i + 1, len(results)):
                fi, fj = results[i]["fibrosis"], results[j]["fibrosis"]
                area_gap = abs(fi["collagen_area_fraction"] - fj["collagen_area_fraction"])
                conn_gap = abs(fi["connectivity_index"] - fj["connectivity_index"])
                # smaller normalized area gap + bigger connectivity gap = better
                score = conn_gap * (1.0 - area_gap / area_span)
                if best is None or score > best[0]:
                    best = (score, results[i], results[j], area_gap, conn_gap)
    # best[4] is conn_gap: with no connectivity difference there is no insight to show
    if best is not None and best[4] > 0:
        _, ra, rb, area_gap, conn_gap = best
        bridge_gap = abs(ra["fibrosis"]["n_bridges_total"] - rb["fibrosis"]["n_bridges_total"])
        print()
        print("  KEY INSIGHT (smallest 2D-area gap with largest 3D-topology divergence):")
        print("    %s vs %s differ by only %s in 2D collagen area fraction," % (
            ra["stage"], rb["stage"], fmt_pct(area_gap)))
        print("    yet their 3D connectivity index differs by %.3f and bridging" % conn_gap)
        print("    septa count by %d. The same 2D area%% maps to very different" % bridge_gap)
        print("    network architecture (isolated periportal collagen vs. spanning septa).")
        print("    This is exactly what 2D histomorphometry (collagen area%) misses, and")
        print("    why 3D bridging topology is informative in MASH staging research.")
    print()
    print("  " + DISCLAIMER)


def print_summary(results):
    """Print a compact report: a title line, one row per analyzed stage, then the disclaimer.

    Each row reports the stage label, 2D collagen area fraction, number of
    connected components, 3D connectivity index, total bridge count,
    connectivity band label and signed zonation gradient, read from the
    ``"fibrosis"`` and ``"zonation"`` sub-dicts of each result.
    """
    row_template = ("  %-3s area=%-7s comps=%-3d connIdx=%-6.3f bridges=%-3d"
                    " band=%-9s zonation=%+.2f")
    print()
    print("SUMMARY (one line per stage):")
    for record in results:
        fibrosis = record["fibrosis"]
        zonation = record["zonation"]
        row_values = (
            record["stage"],
            fmt_pct(fibrosis["collagen_area_fraction"]),
            fibrosis["n_connected_components"],
            fibrosis["connectivity_index"],
            fibrosis["n_bridges_total"],
            fibrosis["connectivity_band"]["band"],
            zonation["zonation_gradient"],
        )
        print(row_template % row_values)
    print("  " + DISCLAIMER)


# ---------------------------------------------------------------------------
# CLI
# ---------------------------------------------------------------------------
def build_parser():
    """Return the argparse parser for the HepatoFabric3D command line.

    Options (registered in this order, which is also their ``--help`` order):
    ``--demo``, ``--input PATH``, ``--stage {STAGES}``, ``--top N``,
    ``--summary``, ``--seed``, ``--json``.
    """
    description = ("HepatoFabric3D: 3D ex vivo liver fibrosis network topology, "
                   "bridging classification, ductular-reaction branching, and "
                   "steatosis zonation quantifier (MASLD/MASH research tool).")
    epilog = ("RESEARCH/EDUCATIONAL USE ONLY - not a clinical diagnosis tool. "
              "Default runs a synthetic F1-F4 demo offline (no network).")
    parser = argparse.ArgumentParser(
        prog="main.py",
        description=description,
        epilog=epilog,
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )

    # (flag, keyword arguments for add_argument) - order matters for --help.
    option_table = (
        ("--demo", dict(
            action="store_true",
            help="Run synthetic fibrosis demo (DEFAULT if no input given).")),
        ("--input", dict(
            metavar="PATH",
            help="Analyze a user TIFF stack (needs optional 'tifffile').")),
        ("--stage", dict(
            choices=STAGES,
            help="Restrict demo to a single fibrosis stage (F1-F4). "
                 "Default: analyze & compare all four.")),
        ("--top", dict(
            type=int, metavar="N", default=None,
            help="Show only the top-N stages by connectivity index.")),
        ("--summary", dict(
            action="store_true",
            help="Print compact one-line-per-stage summary only.")),
        ("--seed", dict(
            type=int, default=None,
            help="Override random seed for synthetic generation.")),
        ("--json", dict(
            action="store_true",
            help="Emit machine-readable JSON instead of text report.")),
    )
    for flag, options in option_table:
        parser.add_argument(flag, **options)
    return parser


def main(argv=None):
    """Command-line entry point for HepatoFabric3D.

    Parameters
    ----------
    argv : list of str or None
        Arguments to parse; ``None`` lets argparse read ``sys.argv[1:]``.

    Returns
    -------
    int
        Process exit status: 0 on success, 1 when ``--input`` analysis
        produced no result.

    Modes
    -----
    * ``--input PATH``: analyze one user volume and print either a text report
      or (with ``--json``) a pure JSON document.
    * default (demo): analyze the synthetic stages (all of them, or only
      ``--stage``). ``--top N`` restricts every demo output mode (text report,
      ``--summary`` and ``--json``) to the N stages with the highest
      connectivity index. Without it, stages keep their natural order.
    """
    args = build_parser().parse_args(argv)
    ref = load_reference()

    # ---- user input path ----
    if args.input:
        # The header is human-readable text. Printing it in --json mode would
        # place non-JSON lines in front of the document and break parsers.
        if not args.json:
            print_header()
        res = analyze_input_volume(args.input, ref)
        if res is None:
            return 1
        if args.json:
            print(json.dumps(res, indent=2, default=float))
            return 0
        print_stage_report(res)
        print("\n  NOTE: user volume analyzed as single-channel collagen proxy; "
              "ductular/zonation use the same channel and are approximate.")
        print("  " + DISCLAIMER)
        return 0

    # ---- demo path (default) ----
    stages = [args.stage] if args.stage else STAGES
    results = [analyze_stage(s, ref, seed=args.seed) for s in stages]

    # Selection shared by every output mode so --top is honoured consistently.
    if args.top is not None:
        ranked = sorted(results, key=lambda r: r["fibrosis"]["connectivity_index"],
                        reverse=True)
        # clamp so `--top 0` or a negative N still shows the best stage
        shown = ranked[:max(1, args.top)]
    else:
        shown = results  # natural F1..F4 order

    if args.json:
        print(json.dumps({"results": shown}, indent=2, default=float))
        return 0

    print_header()
    if args.summary:
        print_summary(shown)
        return 0

    for r in shown:
        print_stage_report(r)
    # Without --top, compare all analyzed stages, which is only meaningful with
    # more than one. With --top, always compare the selected subset.
    if args.top is not None or len(shown) > 1:
        print_comparison(shown)
    return 0


if __name__ == "__main__":
    # Use main()'s integer return value as the process exit status.
    raise SystemExit(main())
