"""합성 paired WSI rigid/affine 정렬.

실제 paired serial section은 작은 translation/rotation/scale 차이가 있음.
합성에서는 같은 좌표계지만, 임의 offset을 주고 rigid 정렬을 시연.
"""
from __future__ import annotations

import numpy as np
from typing import Tuple


def apply_translation(img: np.ndarray, dx: int, dy: int) -> np.ndarray:
    """Shift ``img`` by a whole number of pixels, zero-filling uncovered areas.

    The content moves ``dx`` pixels along axis 1 (positive = right) and
    ``dy`` pixels along axis 0 (positive = down). Only the first two axes
    are shifted; any trailing axes (e.g. channels) are carried along as-is.

    Args:
        img: Array with at least two dimensions, laid out as (rows, cols, ...).
        dx: Integer shift along the column axis.
        dy: Integer shift along the row axis.

    Returns:
        A new array with the same shape and dtype as ``img``. Pixels with no
        source inside ``img`` are 0. If the shift moves the whole image out
        of view, the result is all zeros.
    """
    rows, cols = img.shape[:2]
    shifted = np.zeros_like(img)

    # Size of the region that is inside the frame both before and after the shift.
    span_y = rows - abs(dy)
    span_x = cols - abs(dx)
    if span_x <= 0 or span_y <= 0:
        return shifted

    # A positive shift reads from the origin and writes at the offset;
    # a negative shift reads from the offset and writes at the origin.
    read_y = -dy if dy < 0 else 0
    read_x = -dx if dx < 0 else 0
    write_y = dy if dy > 0 else 0
    write_x = dx if dx > 0 else 0

    shifted[write_y:write_y + span_y, write_x:write_x + span_x] = img[
        read_y:read_y + span_y, read_x:read_x + span_x
    ]
    return shifted


def estimate_rigid_translation(ref: np.ndarray, mov: np.ndarray, max_shift: int = 20) -> Tuple[int, int]:
    """Estimate the integer translation that best aligns ``mov`` onto ``ref``.

    A simple stand-in for phase correlation: every candidate ``(dx, dy)`` on
    a 2-pixel grid within ``[-max_shift, max_shift]`` is applied to ``mov``
    (zero-filling uncovered pixels) and scored by the mean absolute
    difference against ``ref`` over the central half of the frame
    (rows ``h//4 .. 3*h//4``, columns ``w//4 .. 3*w//4`` of ``ref``).

    Args:
        ref: Reference image, 2D or 3D. A 3D input is averaged over axis 2.
        mov: Moving image, 2D or 3D, handled the same way. It is expected to
            have the same height/width as ``ref``; this is not checked.
        max_shift: Largest absolute shift (in pixels) to try on each axis.
            A negative value disables the search.

    Returns:
        ``(dx, dy)`` with the lowest score, as Python ints. ``dx`` is the
        shift along columns and ``dy`` along rows, in the sense that moving
        ``mov`` by that amount brings it onto ``ref``. Ties keep the first
        candidate in scan order (``dy`` outer, ``dx`` inner, ascending).
        ``(0, 0)`` is returned when no candidate beats an infinite score
        (e.g. empty search range or an empty comparison window).

    Notes:
        Because the grid step is 2 px and the grid is anchored at 0, only
        even shifts can be reported.
    """
    if ref.ndim == 3:
        ref = ref.mean(axis=2)
    if mov.ndim == 3:
        mov = mov.mean(axis=2)

    # Integer images (uint8 in particular) would wrap around in the
    # subtraction below and corrupt the score, so promote them to float.
    # Float inputs are left untouched to keep their original precision.
    if not np.issubdtype(ref.dtype, np.inexact):
        ref = ref.astype(np.float64)
    if not np.issubdtype(mov.dtype, np.inexact):
        mov = mov.astype(np.float64)

    best = (0, 0)
    if max_shift < 0:
        return best

    # The comparison window depends only on ``ref``; compute it once.
    h, w = ref.shape
    cy0, cy1 = h // 4, 3 * h // 4
    cx0, cx1 = w // 4, 3 * w // 4
    ref_crop = ref[cy0:cy1, cx0:cx1]

    # Zero-pad ``mov`` once. A window of the padded array starting at
    # (max_shift - dy, max_shift - dx) is exactly ``mov`` translated by
    # (dx, dy) with zero fill, obtained as a view instead of a new frame.
    mov_h, mov_w = mov.shape
    padded = np.pad(mov, max_shift, mode="constant", constant_values=0)

    # Anchor the 2-px grid at 0 so the identity shift is always a candidate,
    # including when ``max_shift`` is odd.
    offsets = range(-(max_shift // 2) * 2, max_shift + 1, 2)

    best_score = float("inf")
    for dy in offsets:
        y0 = max_shift - dy
        for dx in offsets:
            x0 = max_shift - dx
            shifted = padded[y0:y0 + mov_h, x0:x0 + mov_w]
            score = float(np.mean(np.abs(ref_crop - shifted[cy0:cy1, cx0:cx1])))
            if score < best_score:
                best_score = score
                best = (dx, dy)
    return best


def register_pair(ref: np.ndarray, mov: np.ndarray, max_shift: int = 20) -> Tuple[np.ndarray, Tuple[int, int]]:
    """Align ``mov`` to ``ref`` with an integer translation.

    The shift is estimated with ``estimate_rigid_translation`` and then
    applied to ``mov`` with ``apply_translation``.

    Args:
        ref: Reference image (2D, or 3D with channels on the last axis).
        mov: Image to be moved onto ``ref``.
        max_shift: Largest absolute shift, in pixels, searched per axis.

    Returns:
        ``(aligned, (dx, dy))`` where ``aligned`` has the same shape and
        dtype as ``mov`` with uncovered pixels set to 0, and ``(dx, dy)`` is
        the translation that was applied.
    """
    dx, dy = estimate_rigid_translation(ref, mov, max_shift=max_shift)
    # apply_translation slices only the first two axes, so a multi-channel
    # image is shifted in a single call; no per-channel loop is required.
    aligned = apply_translation(mov, dx, dy)
    return aligned, (dx, dy)
