Week 6 - Imagery Alignment: Introduction and Simple Examples

Week 6 materials can be accessed here.


1. Introduction and Context

In remote sensing, aligning images from different sensors or different times is a critical step in many applications such as sea ice monitoring, land cover change, or multi-sensor data fusion. For instance, aligning a Sentinel-2 image to a Sentinel-3 image involves co-locating spatial points so that each pixel in one dataset corresponds to the same ground location (or sea/ice location) in the other dataset.

However, imagery alignment is often complicated by:

  • Sensor differences (different resolutions, spectral bands, angles).

  • Scene dynamics (moving sea ice, vegetation changes, etc.).

  • Acquisition timing differences (the surface itself, e.g., drifting sea ice, may move between acquisitions).

1.1 Four Methods for Image Alignment

  1. Auto-correlation (Cross-correlation)

    • A straightforward spatial-domain approach that searches for the shift (translation) maximizing the correlation between two images.

    • Often implemented by brute-force correlation (rolling or shifting the image) or by using 2D convolution methods.

  2. Phase Correlation

    • A frequency-domain method leveraging the Fourier shift theorem.

    • Typically faster for large images and can provide sub-pixel accuracy.

    • Commonly used in remote sensing for global registration of two images that differ only by translation.

  3. ECC (Enhanced Correlation Coefficient)

    • An iterative, gradient-based approach (available in OpenCV’s findTransformECC) that maximizes a correlation measure between two images.

    • Supports translation, Euclidean (rotation + translation), affine, and homography models; here we focus on translation, with one rotation example.

    • Often converges quickly and handles small or moderate shifts well.

  4. SEA-RAFT (Optical Flow–Based Estimation). We thank Dr François Chadebecq for implementing this method in our Jupyter Book.

    • A state-of-the-art deep learning approach that predicts dense pixel-wise motion (optical flow) between frames.

    • While computationally intensive and complex to train, it achieves state-of-the-art accuracy on standard optical-flow benchmarks, and pretrained weights from large-scale computer-vision datasets are publicly available.

    • For this work, we utilize the authors’ pretrained models, though performance could be further optimised through fine-tuning specifically on satellite imagery.

In this Notebook, we will:

  1. Introduce the basic concepts behind these four methods.

  2. Demonstrate them on simple, synthetic examples (e.g., 2D arrays or toy images).

  3. Lay the groundwork for applying these methods to real satellite data in Notebook 2.

from google.colab import drive
drive.mount('/content/drive')
Mounted at /content/drive
import numpy as np
import matplotlib.pyplot as plt
from scipy.signal import correlate2d
import cv2

%matplotlib inline

Create synthetic images for the examples

In this cell, we create a synthetic pair of images (an original and a shifted copy), which we use to illustrate how each algorithm performs.

def create_synthetic_image(size=(100, 100), centers=[(50, 50)], radius=10, intensity=1.0):
    """
    Create a synthetic image with circular bright spots.
    size: (height, width)
    centers: list of (row, col) circle centers
    radius: radius of circles
    intensity: pixel value inside the circle
    """
    img = np.zeros(size, dtype=np.float32)
    for c in centers:
        rr, cc = np.ogrid[:size[0], :size[1]]
        mask = (rr - c[0])**2 + (cc - c[1])**2 <= radius**2
        img[mask] = intensity
    return img

imageA = create_synthetic_image(
    size=(100, 100),
    centers=[(30, 30), (70, 70)],
    radius=8,
    intensity=1.0
)

# Define a known shift (dy, dx)
known_shift = (15, 25)

def shift_image(img, shift):
    """
    Shift the image by (dy, dx) using np.roll for simplicity.
    This is a naive wrap-around shift, but it's enough to demonstrate the concept.
    """
    dy, dx = shift
    shifted = np.roll(img, dy, axis=0)
    shifted = np.roll(shifted, dx, axis=1)
    return shifted

imageB = shift_image(imageA, known_shift)

# We visualise them here
fig, axes = plt.subplots(1, 2, figsize=(10, 4))
axes[0].imshow(imageA, cmap='gray')
axes[0].set_title('Image A')
axes[1].imshow(imageB, cmap='gray')
axes[1].set_title('Image B (shifted)')
plt.show()

Auto-correlation

Auto-correlation (more precisely, spatial cross-correlation, since two different images are compared) is a direct method to find the shift between two images by measuring how strongly one image correlates with shifted versions of the other. Concretely, you can think of “sliding” one image over the other in small increments and computing a correlation score at each step; the shift that maximizes this score is taken as the best alignment.

  • Pros: Easy to grasp conceptually, works well for moderate image sizes, and doesn’t require special libraries beyond basic 2D operations.

  • Cons: Can be slow if you do a large brute-force search, and it only handles pure translation (no rotation or scaling).
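To make the “sliding” idea concrete, here is a minimal brute-force sketch. It is not the function used in the next cell (which relies on scipy.signal.correlate2d); the helper name brute_force_shift and the search radius max_shift are illustrative assumptions. It rolls Image A over a window of candidate shifts and keeps the one with the highest correlation score, using the same wrap-around np.roll convention as shift_image.

```python
import numpy as np

def brute_force_shift(img_a, img_b, max_shift=30):
    """Hypothetical helper: exhaustive search for the (dy, dx) that maximises
    the correlation between img_a rolled by (dy, dx) and img_b.

    Only shifts with |dy|, |dx| <= max_shift are tried. np.roll (wrap-around)
    is used so the sign convention matches shift_image.
    """
    best_score = -np.inf
    best_shift = (0, 0)
    for dy in range(-max_shift, max_shift + 1):
        for dx in range(-max_shift, max_shift + 1):
            candidate = np.roll(img_a, (dy, dx), axis=(0, 1))  # slide A by (dy, dx)
            score = float(np.sum(candidate * img_b))           # unnormalised correlation
            if score > best_score:
                best_score, best_shift = score, (dy, dx)
    return best_shift

# Toy example: one bright square, moved by a known (dy, dx)
img_a = np.zeros((64, 64))
img_a[20:28, 10:18] = 1.0
img_b = np.roll(img_a, (7, -4), axis=(0, 1))
print(brute_force_shift(img_a, img_b, max_shift=10))  # (7, -4)
```

The nested loop computes one full-image product per candidate shift, so the run time grows with the square of max_shift; this is the cost that FFT-based approaches such as phase correlation avoid.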

import numpy as np
import matplotlib.pyplot as plt
from matplotlib.patches import FancyArrow
from scipy.signal import correlate2d

def find_shift_via_cross_correlation(imgA, imgB):
    """
    Compute the shift between imgA and imgB via cross-correlation.
    BUT then negate the result to match the (dy, dx) we used in np.roll(imageA, (dy, dx)).

    Returns (dy, dx) in the same sign convention as the 'shift_image' function.
    """
    corr = correlate2d(imgA, imgB, boundary='fill', mode='full')
    max_idx = np.unravel_index(np.argmax(corr), corr.shape)

    # By default, this returns how to move B-->A. We'll call those (shift_y, shift_x).
    shift_y = max_idx[0] - imgA.shape[0] + 1
    shift_x = max_idx[1] - imgA.shape[1] + 1

    # Negate them to get A-->B
    return (-shift_y, -shift_x)

def visualize_shift_arrows(imgA, imgB, shift, step=10, method_name="ECC"):
    (dy, dx) = shift

    fig, (axA, axB) = plt.subplots(1, 2, figsize=(12, 5))

    # Plot Image A with arrows
    axA.imshow(imgA, cmap='gray')
    axA.set_title("Image A")

    h, w = imgA.shape
    rows = np.arange(0, h, step)
    cols = np.arange(0, w, step)

    for r in rows:
        for c in cols:
            axA.add_patch(FancyArrow(c, r, dx, dy, color="yellow",
                                     width=0.5, head_width=2, head_length=2, alpha=0.8))

    # Plot Image B without arrows
    axB.imshow(imgB, cmap='gray')
    axB.set_title("Image B (Shifted)")

    plt.suptitle(f"{method_name} shift = (dy={dy:.2f}, dx={dx:.2f})")
    plt.show()

def create_synthetic_image(size=(100, 100), centers=[(30, 30), (70, 70)], radius=8, intensity=1.0):
    img = np.zeros(size, dtype=np.float32)
    rr, cc = np.ogrid[:size[0], :size[1]]
    for c in centers:
        mask = (rr - c[0])**2 + (cc - c[1])**2 <= radius**2
        img[mask] = intensity
    return img

def shift_image(img, shift):
    (dy, dx) = shift
    shifted = np.roll(img, dy, axis=0)
    shifted = np.roll(shifted, dx, axis=1)
    return shifted

# Create two synthetic images
imageA = create_synthetic_image()
known_shift = (26, 5)
imageB = shift_image(imageA, known_shift)

# Estimate the shift; it should match known_shift
estimated_shift_cc = find_shift_via_cross_correlation(imageA, imageB)
print("Known shift:       ", known_shift)
print("Estimated shift (CC):", estimated_shift_cc)

# Visualize with arrows on a sampled grid
visualize_shift_arrows(imageA, imageB, estimated_shift_cc, step=10, method_name="Auto Correlation Translation")
Known shift:        (26, 5)
Estimated shift (CC): (np.int64(26), np.int64(5))

Phase Correlation

Phase correlation moves the cross-correlation problem into the frequency domain. It takes the Fourier transform of each image, multiplies one by the complex conjugate of the other, normalizes the product to unit magnitude so that only the phase difference remains, and locates the peak of the inverse transform, whose position gives the shift. This relies on the Fourier shift theorem: a translation in the spatial domain appears as a linear phase difference in the frequency domain.

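  • Pros: Faster than brute-force auto-correlation (especially for large images), can reach subpixel accuracy when the correlation peak is refined (e.g., by upsampling around it, as scikit-image’s phase_cross_correlation does with upsample_factor), and is fairly robust to global intensity changes because only phase information is used.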

  • Cons: Requires computing the Fourier transform; can become tricky if the images have major differences other than translation (e.g., rotation, scale changes).
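The steps above can be written in a few lines of NumPy. This is a minimal sketch for integer, wrap-around shifts only (no windowing, no subpixel refinement); the function name phase_correlation_shift and the small eps guard against division by zero are our own choices, and scikit-image’s phase_cross_correlation, used in the next cell, is the more complete tool.

```python
import numpy as np

def phase_correlation_shift(img_a, img_b, eps=1e-12):
    """Integer (dy, dx) such that img_b is approximately np.roll(img_a, (dy, dx), axis=(0, 1))."""
    F_a = np.fft.fft2(img_a)
    F_b = np.fft.fft2(img_b)
    # Normalised cross-power spectrum: unit magnitude, only the phase difference remains
    cross_power = F_b * np.conj(F_a)
    cross_power /= np.abs(cross_power) + eps
    # Inverse FFT: ideally a single sharp peak located at the translation
    response = np.real(np.fft.ifft2(cross_power))
    peak = np.unravel_index(np.argmax(response), response.shape)
    # Peak indices past the midpoint correspond to negative shifts (circular wrap-around)
    return tuple(int(p) if p <= n // 2 else int(p) - n
                 for p, n in zip(peak, response.shape))

rng = np.random.default_rng(42)
img_a = rng.random((64, 64))                  # textured test image
img_b = np.roll(img_a, (7, -4), axis=(0, 1))  # known shift (dy, dx) = (7, -4)
print(phase_correlation_shift(img_a, img_b))  # (7, -4)
```

The sign convention matches shift_image: the returned (dy, dx) moves Image A onto Image B, so no negation is needed here.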

!pip install scikit-image

import numpy as np
import matplotlib.pyplot as plt
from scipy.ndimage import fourier_shift
from skimage.registration import phase_cross_correlation
from matplotlib.patches import FancyArrow


def phase_corr_estimate_shift(imgA, imgB):
    """
    phase_cross_correlation(imgA, imgB) returns the shift that aligns B to A.
    If we want the shift that moves A to B (same sign as shift_image), we negate it.
    """
    shift, error, diffphase = phase_cross_correlation(imgA, imgB)
    # shift is (row_shift, col_shift) => let's rename (dyBtoA, dxBtoA)
    (dyBtoA, dxBtoA) = shift

    # We want the shift from A->B, so just negate:
    dyAtoB = -dyBtoA
    dxAtoB = -dxBtoA
    return (dyAtoB, dxAtoB)


def visualize_shift_arrows(imgA, imgB, shift, step=10, method_name="ECC"):
    (dy, dx) = shift

    fig, (axA, axB) = plt.subplots(1, 2, figsize=(12, 5))

    # Plot Image A with arrows
    axA.imshow(imgA, cmap='gray')
    axA.set_title("Image A")

    h, w = imgA.shape
    rows = np.arange(0, h, step)
    cols = np.arange(0, w, step)

    for r in rows:
        for c in cols:
            axA.add_patch(FancyArrow(c, r, dx, dy, color="yellow",
                                     width=0.5, head_width=2, head_length=2, alpha=0.8))

    # Plot Image B without arrows
    axB.imshow(imgB, cmap='gray')
    axB.set_title("Image B (Shifted)")

    plt.suptitle(f"{method_name} shift = (dy={dy:.2f}, dx={dx:.2f})")
    plt.show()


imageA = create_synthetic_image()

known_shift = (15, 10)

imageB = shift_image(imageA, known_shift)

estimated_shift_pc = phase_corr_estimate_shift(imageA, imageB)
print("Known shift:", known_shift)
print("Estimated shift (Phase Corr):", estimated_shift_pc)
visualize_shift_arrows(imageA, imageB, estimated_shift_pc, step=10, method_name="Phase Correlation Translation")

Known shift: (15, 10)
Estimated shift (Phase Corr): (np.float32(15.0), np.float32(10.0))

ECC (Enhanced Correlation Coefficient) Alignment

ECC is an iterative approach that maximizes the enhanced correlation coefficient, a zero-mean, normalized correlation between the template image and the warped input image, which makes it insensitive to global brightness and contrast differences. It is available in OpenCV’s findTransformECC function and supports translation, Euclidean, affine, and homography motion models. For translation-only alignment, ECC updates the shift parameters step by step until it converges to the shift that yields the highest correlation.

  • Pros: Often more robust and faster than naive brute force; can yield subpixel accuracy; can be extended to more complex transformations if needed.

  • Cons: Requires a decent initial guess for bigger shifts, and it’s more of a “black box” iterative method than a simple correlation map.
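The quantity ECC maximizes can be sketched directly. Following the enhanced correlation coefficient of Evangelidis and Psarakis, both images are made zero-mean and their dot product is divided by the product of their norms. This toy function (the name enhanced_correlation_coefficient is ours) only evaluates the criterion for a fixed alignment; cv2.findTransformECC additionally searches over the warp parameters, smooths the images and can apply a mask, so its reported cc will not exactly match this function.

```python
import numpy as np

def enhanced_correlation_coefficient(template, warped):
    """Zero-mean normalised correlation between two equally sized images.

    Returns a value in [-1, 1]; 1 means the images agree up to a positive
    global gain and an offset in intensity.
    """
    t = template.astype(np.float64).ravel()
    w = warped.astype(np.float64).ravel()
    t_zm = t - t.mean()  # zero-mean: removes any brightness offset
    w_zm = w - w.mean()
    # Normalising by both norms removes any global contrast (gain) change
    return float(t_zm @ w_zm / (np.linalg.norm(t_zm) * np.linalg.norm(w_zm)))

rng = np.random.default_rng(0)
template = rng.random((50, 50))
brighter = 2.5 * template + 10.0   # same content, different gain and offset
unrelated = rng.random((50, 50))   # independent random image

print(round(enhanced_correlation_coefficient(template, brighter), 6))  # 1.0
print(round(enhanced_correlation_coefficient(template, unrelated), 3))
```

Because the score is unchanged by a global gain and offset, ECC suits image pairs whose overall brightness differs, which is common between acquisitions.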

import cv2
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.patches import FancyArrow

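def align_ecc_translation(imgA, imgB, num_iterations=300, termination_eps=1e-6, gaussFiltSize=1,
                          init_shift=(15, 10)):
    """
    Align imgB to imgA using ECC with a translation model.

    init_shift: initial guess (dy, dx) for the warp. A reasonable initial guess
    helps ECC converge when the shift is large relative to the image features.
    Note: gaussFiltSize is not passed to OpenCV below, so OpenCV's default smoothing applies.
    """
    # Convert to float32. ECC is insensitive to a global intensity scale, so the /255 is harmless.
    imgA_f = imgA.astype(np.float32) / 255.0
    imgB_f = imgB.astype(np.float32) / 255.0

    # Translation warp [[1, 0, dx], [0, 1, dy]], initialised from init_shift
    init_dy, init_dx = init_shift
    warp_matrix = np.array([[1, 0, init_dx], [0, 1, init_dy]], dtype=np.float32)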

    criteria = (cv2.TERM_CRITERIA_EPS | cv2.TERM_CRITERIA_COUNT, num_iterations, termination_eps)

    try:
        cc, warp_matrix = cv2.findTransformECC(
            templateImage=imgA_f,
            inputImage=imgB_f,
            warpMatrix=warp_matrix,
            motionType=cv2.MOTION_TRANSLATION,
            criteria=criteria
        )
        dx = warp_matrix[0, 2]
        dy = warp_matrix[1, 2]
        return (dy, dx), cc

    except cv2.error as e:
        print("ECC failed:", e)
        return (None, None), None  # Return None if ECC fails


def shift_image(img, shift):
    """
    Shift the image by (dy, dx) using np.roll.
    """
    dy, dx = shift
    shifted = np.roll(img, dy, axis=0)
    shifted = np.roll(shifted, dx, axis=1)
    return shifted


imageA = create_synthetic_image()
known_shift = (5, 16)

imageB = shift_image(imageA, known_shift)

ecc_shift, cc_value = align_ecc_translation(imageA, imageB)

print("Known shift:", known_shift)
print("Estimated shift (ECC):", ecc_shift, "CC:", cc_value)


def visualize_ecc_shift_arrows(imgA, imgB, shift, step=10, method_name="ECC"):
    (dy, dx) = shift

    fig, (axA, axB) = plt.subplots(1, 2, figsize=(12, 5))

    # Plot Image A with arrows
    axA.imshow(imgA, cmap='gray')
    axA.set_title("Image A")

    h, w = imgA.shape
    rows = np.arange(0, h, step)
    cols = np.arange(0, w, step)

    for r in rows:
        for c in cols:
            axA.add_patch(FancyArrow(c, r, dx, dy, color="yellow",
                                     width=0.5, head_width=2, head_length=2, alpha=0.8))

    # Plot Image B without arrows
    axB.imshow(imgB, cmap='gray')
    axB.set_title("Image B (Shifted)")

    plt.suptitle(f"{method_name} shift = (dy={dy:.2f}, dx={dx:.2f}), CC={cc_value:.4f}")
    plt.show()


# Example Usage
visualize_ecc_shift_arrows(imageA, imageB, ecc_shift, step=10, method_name="ECC Translation")

Known shift: (5, 16)
Estimated shift (ECC): (np.float32(4.9955196), np.float32(15.99729)) CC: 0.999999672404866

For ECC, we also include an example of aligning a rotated image, using a Euclidean (rotation + translation) motion model.

import cv2
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.patches import FancyArrow

def create_synthetic_image(size=(100, 100), shape="rectangle"):
    """
    Creates a synthetic image with a simple geometric shape.
    """
    img = np.zeros(size, dtype=np.uint8)
    h, w = img.shape

    if shape == "rectangle":
        cv2.rectangle(img, (w//4, h//4), (3*w//4, 3*h//4), 255, -1)
    elif shape == "circle":
        cv2.circle(img, (w//2, h//2), w//4, 255, -1)

    return img


def rotate_image(img, angle, center=None):
    """
    Rotate the image around its center.
    """
    h, w = img.shape
    if center is None:
        center = (w//2, h//2)

    rotation_matrix = cv2.getRotationMatrix2D(center, angle, 1.0)
    rotated = cv2.warpAffine(img, rotation_matrix, (w, h))
    return rotated


def align_ecc_rotation(imgA, imgB, num_iterations=300, termination_eps=1e-6):
    """
    Align imgB to imgA using ECC with a Euclidean transformation model (rotation + translation).
    """
    imgA_f = imgA.astype(np.float32) / 255.0
    imgB_f = imgB.astype(np.float32) / 255.0

    warp_matrix = np.eye(2, 3, dtype=np.float32)  # Identity matrix for Euclidean transform

    criteria = (cv2.TERM_CRITERIA_EPS | cv2.TERM_CRITERIA_COUNT, num_iterations, termination_eps)

    try:
        cc, warp_matrix = cv2.findTransformECC(
            templateImage=imgA_f,
            inputImage=imgB_f,
            warpMatrix=warp_matrix,
            motionType=cv2.MOTION_EUCLIDEAN,
            criteria=criteria
        )

        dx = warp_matrix[0, 2]
        dy = warp_matrix[1, 2]
        theta = np.arctan2(warp_matrix[1, 0], warp_matrix[0, 0]) * (180.0 / np.pi)

        return (dy, dx, theta), cc

    except cv2.error as e:
        print("ECC failed:", e)
        return (None, None, None), None


# Create Synthetic Image and Rotate
imageA = create_synthetic_image()
known_rotation = 30  # Degrees
imageB = rotate_image(imageA, known_rotation)

# Apply ECC Rotation Alignment
ecc_shift, cc_value = align_ecc_rotation(imageA, imageB)

print("Known rotation:", known_rotation)
print("Estimated shift and rotation (ECC):", ecc_shift, "CC:", cc_value)


def visualize_ecc_rotation_arrows(imgA, imgB, shift, step=10, method_name="ECC"):
    (dy, dx, theta) = shift

    fig, (axA, axB) = plt.subplots(1, 2, figsize=(12, 5))

    axA.imshow(imgA, cmap='gray')
    axA.set_title("Image A")

    h, w = imgA.shape
    rows = np.arange(0, h, step)
    cols = np.arange(0, w, step)

    for r in rows:
        for c in cols:
            axA.add_patch(FancyArrow(c, r, dx, dy, color="yellow",
                                     width=0.5, head_width=2, head_length=2, alpha=0.8))

    axB.imshow(imgB, cmap='gray')
    axB.set_title(f"Image B (Rotated {theta:.2f}°)")

    plt.suptitle(f"{method_name} shift = (dy={dy:.2f}, dx={dx:.2f}), Rotation={theta:.2f}°, CC={cc_value:.4f}")
    plt.show()


# Visualize the Rotation Alignment
visualize_ecc_rotation_arrows(imageA, imageB, ecc_shift, step=10, method_name="ECC Rotation")
Known rotation: 30
Estimated shift and rotation (ECC): (np.float32(31.69622), np.float32(-18.302961), np.float32(-29.998287)) CC: 0.9998337715512696
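The estimate above reports a rotation of about -30° and a non-zero translation, even though Image B was produced by rotating Image A by +30° with no shift. Both follow from conventions rather than an error. The sketch below re-implements, in NumPy, the 2x3 matrix documented for cv2.getRotationMatrix2D (the function name opencv_rotation_matrix is hypothetical). Positive angles rotate counter-clockwise as displayed, which places -sin θ in element [1, 0], and rotating about the image centre rather than the origin introduces a translation column. Since findTransformECC returns a warp that maps template (Image A) coordinates to input (Image B) coordinates, the recovered warp_matrix should approximate this matrix, which is why align_ecc_rotation reports θ ≈ -30° and (dy, dx) ≈ (31.7, -18.3).

```python
import numpy as np

def opencv_rotation_matrix(center, angle_deg, scale=1.0):
    """NumPy re-implementation of the 2x3 matrix documented for cv2.getRotationMatrix2D.

    Positive angle_deg rotates counter-clockwise about `center` (x, y),
    with the origin at the top-left corner of the image.
    """
    cx, cy = center
    alpha = scale * np.cos(np.deg2rad(angle_deg))
    beta = scale * np.sin(np.deg2rad(angle_deg))
    return np.array([
        [alpha, beta, (1 - alpha) * cx - beta * cy],
        [-beta, alpha, beta * cx + (1 - alpha) * cy],
    ])

# Same setting as above: 100 x 100 image, rotation of +30 degrees about (50, 50)
M = opencv_rotation_matrix(center=(50, 50), angle_deg=30)

# Same angle formula as in align_ecc_rotation
theta = np.degrees(np.arctan2(M[1, 0], M[0, 0]))
print(round(float(theta), 2))                              # -30.0
print(round(float(M[1, 2]), 2), round(float(M[0, 2]), 2))  # 31.7 -18.3  (dy, dx)
```

To express the ECC result in the same sign convention as known_rotation, one would therefore negate the recovered angle.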

SEA-RAFT

SEA-RAFT (Simple, Efficient, Accurate RAFT) is a modern deep-learning method for optical flow that can also be used for image registration. It refines RAFT (Recurrent All-Pairs Field Transforms), a deep-learning approach for estimating dense motion between frames.

SEA-RAFT paper: https://arxiv.org/abs/2405.14793
RAFT paper: https://arxiv.org/abs/2003.12039

SEA-RAFT iteratively updates a dense motion field (optical flow) between an image pair using a recurrent update module, improving the alignment step by step.
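At the core of RAFT, and therefore SEA-RAFT, is an all-pairs correlation volume: the dot product between the feature vector of every pixel in the first image and that of every pixel in the second. The toy sketch below builds such a volume in NumPy, using random unit-length vectors as stand-ins for the learned features (an assumption made purely for illustration), and reads off one pixel’s displacement as the position of its best match. The real models compute features with a neural network, build a multi-scale pyramid over this volume, and refine the flow iteratively with a recurrent update unit, none of which is shown here.

```python
import numpy as np

def all_pairs_correlation(f1, f2):
    """4D correlation volume C[i, j, k, l] = <f1[i, j], f2[k, l]> / sqrt(D).

    f1, f2: feature maps of shape (H, W, D).
    """
    d = f1.shape[-1]
    return np.einsum("ijd,kld->ijkl", f1, f2) / np.sqrt(d)

# Toy "features": random unit-length descriptors; frame 2 is a shifted copy of frame 1
rng = np.random.default_rng(0)
H, W, D = 12, 12, 64
f1 = rng.standard_normal((H, W, D))
f1 /= np.linalg.norm(f1, axis=-1, keepdims=True)
f2 = np.roll(f1, (2, 3), axis=(0, 1))  # every pixel moves by (dy, dx) = (2, 3)

corr = all_pairs_correlation(f1, f2)   # shape (H, W, H, W)

# For pixel (i, j) in frame 1, its best match in frame 2 is the argmax over (k, l)
i, j = 5, 5
k, l = np.unravel_index(np.argmax(corr[i, j]), (H, W))
print((int(k - i), int(l - j)))  # (2, 3)
```

The volume has H x W x H x W entries, which is one reason these models are memory- and compute-hungry on large satellite scenes.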

# Install required deep learning, image processing, and visualization libraries
!pip install torch torchvision opencv-python matplotlib

# OS utilities (file system checks, paths)
import os
# System utilities (used to modify Python path)
import sys
# Deep learning framework
import torch
# Numerical computations
import numpy as np
# Computer vision library
import cv2
# Visualization
import matplotlib.pyplot as plt
# Image loading / manipulation
from PIL import Image

# Download SEA-RAFT if it is not already available
if not os.path.exists("SEA-RAFT"):
    !git clone https://github.com/princeton-vl/SEA-RAFT.git
else:
    print("SEA-RAFT folder exists — skipping clone")

# Simple container for configuration objects
from types import SimpleNamespace
# Allows Python to find RAFT modules inside the repository
sys.path.append("SEA-RAFT/core")
# Main RAFT model class
from raft import RAFT
# Utility for padding images to dimensions compatible with the network
from utils.utils import InputPadder
Cloning into 'SEA-RAFT'...
remote: Enumerating objects: 115, done.
remote: Counting objects: 100% (40/40), done.
remote: Compressing objects: 100% (18/18), done.
remote: Total 115 (delta 28), reused 22 (delta 22), pack-reused 75 (from 1)
Receiving objects: 100% (115/115), 8.16 MiB | 16.20 MiB/s, done.
Resolving deltas: 100% (51/51), done.
def create_synthetic_image(size=(100, 100), centers=[(50, 50)], radius=10, intensity=1.0):
    """
    Create a synthetic image with circular bright spots.
    size: (height, width)
    centers: list of (row, col) circle centers
    radius: radius of circles
    intensity: pixel value inside the circle
    """
    img = np.zeros(size, dtype=np.float32)
    for c in centers:
        rr, cc = np.ogrid[:size[0], :size[1]]
        mask = (rr - c[0])**2 + (cc - c[1])**2 <= radius**2
        img[mask] = intensity
    return img

imageA = create_synthetic_image(
    size=(1000, 1000),
    centers=[(300, 300), (700, 700)],
    radius=80,
    intensity=1.0
)


def shift_image(img, shift):
    """
    Shift the image by (dy, dx) using np.roll for simplicity.
    This is a naive wrap-around shift, but it's enough to demonstrate the concept.
    """
    dy, dx = shift
    shifted = np.roll(img, dy, axis=0)
    shifted = np.roll(shifted, dx, axis=1)
    return shifted

# Define known shifts (dy, dx): (150, 250) for image B and (100, 200) for image C
known_shift = (150, 250)
imageB = shift_image(imageA, known_shift)
known_shift = (100, 200)
imageC = shift_image(imageA, known_shift)

# We visualise them here
fig, axes = plt.subplots(1, 3, figsize=(10, 4))
axes[0].imshow(imageA, cmap='gray')
axes[0].set_title('Image A')
axes[1].imshow(imageB, cmap='gray')
axes[1].set_title('Image B (shifted)')
axes[2].imshow(imageC, cmap='gray')
axes[2].set_title('Image C (shifted)')
plt.show()

Optical Flow–Based Displacement Estimation with SEA-RAFT

SEA-RAFT achieves state-of-the-art optical flow accuracy on standard benchmarks. Compared with the classical methods above, however, it is computationally expensive, slower at inference, and complex to train.

In this work, we therefore rely on pretrained model weights provided by the SEA-RAFT authors, which were trained on standard computer-vision optical-flow datasets of natural and synthetic scenes rather than on satellite imagery. Dedicated training or fine-tuning on satellite imagery would likely further improve performance for our application.
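Because SEA-RAFT returns a displacement for every pixel rather than a single shift, comparing its output with known_shift needs a summary step. One simple option, sketched here with a hypothetical helper summarize_flow, is the median of the horizontal (u_disp) and vertical (v_disp) fields, optionally restricted by a mask to pixels where the image has structure, since the flow is poorly constrained in featureless background. The synthetic flow field below is made up for the demonstration; it is not SEA-RAFT output.

```python
import numpy as np

def summarize_flow(u_disp, v_disp, mask=None):
    """Hypothetical helper: reduce a dense flow field to one (dy, dx) estimate.

    u_disp, v_disp: horizontal (dx) and vertical (dy) displacement, shape (H, W),
    as returned by sea_raft_flow. mask: optional boolean (H, W) array selecting
    the pixels to trust. The median is used because it resists outlier vectors.
    """
    if mask is None:
        mask = np.ones(u_disp.shape, dtype=bool)
    dy = float(np.median(v_disp[mask]))
    dx = float(np.median(u_disp[mask]))
    return dy, dx

# Made-up flow field: a uniform (dy, dx) = (150, 250) translation plus noise...
rng = np.random.default_rng(0)
v = 150.0 + rng.normal(0.0, 0.5, size=(100, 100))
u = 250.0 + rng.normal(0.0, 0.5, size=(100, 100))
# ...and a small patch of corrupted vectors, e.g. from a featureless region
v[:5, :5] = -40.0
u[:5, :5] = 0.0

dy, dx = summarize_flow(u, v)
print(round(dy), round(dx))  # 150 250
```

With the real outputs one might call, for example, summarize_flow(dx_raft_AB, dy_raft_AB, mask=imageA[..., 0] > 0); whether the result matches the applied shift depends on how well the pretrained model handles this synthetic, wrap-around case.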

def grayscale_to_rgb(img, normalize=True):
    """
    Convert a single-channel grayscale image to a 3-channel RGB image.

    SEA‑RAFT expects RGB images as input.

    This function converts a grayscale image
    to RGB by replicating the grayscale channel across the three RGB channels.

    Parameters
    ----------
    img : np.ndarray
        Input grayscale image of shape (H, W).
    normalize : bool, optional (default=True)
        If True, the image will be normalized to the [0, 255] range.
        For real data, additional preprocessing (e.g., contrast enhancement)
        may be required.

    Returns
    -------
    rgb_img : np.ndarray
        RGB image of shape (H, W, 3), dtype=np.uint8.
    """
    if img.ndim != 2:
        raise ValueError("Input image must be single-channel (H, W)")

    img_out = img.copy()

    if normalize:
        img_min, img_max = img_out.min(), img_out.max()
        if img_max > img_min:
            img_out = (img_out - img_min) / (img_max - img_min)
        img_out = (img_out * 255).astype(np.uint8)
    else:
        img_out = img_out.astype(np.uint8)

    # Stack into 3 identical channels
    rgb_img = np.stack([img_out, img_out, img_out], axis=-1)

    return rgb_img

def sea_raft_flow(img1, img2, model_path=None, device='cpu'):
    """
    Estimate dense optical flow (displacement field) between two RGB images using SEA‑RAFT.

    Parameters
    ----------
    img1 : np.ndarray
        Source RGB image of shape (H, W, 3), dtype=np.uint8 or float.
    img2 : np.ndarray
        Target RGB image of shape (H, W, 3), dtype=np.uint8 or float.
    model_path : str, optional
        Path to pre-trained SEA‑RAFT weights. If None, defaults to the standard model path.
    device : str, optional
        Device to run the model on ('cpu' or 'cuda').

    Returns
    -------
    u_disp : np.ndarray
        Horizontal displacement (dx) for each pixel, shape (H, W).
    v_disp : np.ndarray
        Vertical displacement (dy) for each pixel, shape (H, W).
    """

    # Default dictionary storing standard SEARaft parameters/arguments
    args_dict = {
        "name": "spring-M",
        "dataset": "spring",
        "gpus": [0,1,2,3,4,5,6,7],
        "use_var": True,
        "var_min": 0,
        "var_max": 10,
        "pretrain": "resnet34",
        "initial_dim": 64,
        "block_dims": [64,128,256],
        "radius": 4,
        "dim": 128,
        "num_blocks": 2,
        "iters": 4,
        "image_size": [540,960],
        "scale": -1,
        "batch_size": 32,
        "epsilon": 1e-08,
        "lr": 0.0004,
        "wdecay": 1e-05,
        "dropout": 0,
        "clip": 1.0,
        "gamma": 0.85,
        "num_steps": 120000,
        "restore_ckpt": None,
        "coarse_config": None,
        "cfg": "SEA-RAFT/config/eval/spring-M.json",
        "path": "/content/drive/MyDrive/PhD Year 3/GEOL0069_test_2026/Week 6/Tartan-C-T-TSKH-spring540x960-M.pth",
        "url": None,
        "device": device
    }

    # Convert the dictionary to an arg namespace (compatibility)
    args = SimpleNamespace(**args_dict)

    # Instantiate the RAFT model
    model = RAFT(args)
    model_path = model_path or args.path  # use the caller's path if given, else the default checkpoint
    if model_path:
        model.load_state_dict(torch.load(model_path,map_location=torch.device(device)))
    model.to(device)
    model.eval()

    # Convert input images to tensors
    img1_t = torch.from_numpy(img1).permute(2,0,1).float()[None].to(device)
    img2_t = torch.from_numpy(img2).permute(2,0,1).float()[None].to(device)

    # Pad to multiple of 8 (constraint related to the architecture design)
    padder = InputPadder(img1_t.shape)
    img1_t, img2_t = padder.pad(img1_t, img2_t)

    # Evaluation
    with torch.no_grad():
        outputs = model(img1_t, img2_t, iters=20, test_mode=True)
        flow = outputs['flow'][-1]  # last iteration

    # Extract the horizontal and vertical flow (displacement) dx, dy
    flow = flow.squeeze(0)         # remove batch -> shape (2, H, W)
    flow = flow.permute(1,2,0)     # (H, W, 2)

    u_disp = flow[...,0].cpu().numpy()  # horizontal
    v_disp = flow[...,1].cpu().numpy()  # vertical

    return u_disp, v_disp

# If needed: Helper function to plot the optical flow (displacement vectors)
def plot_quiver(img, dx, dy, skip=10, title="Displacement Field"):
    """
    Overlay a quiver plot (arrows) showing pixel-wise displacement vectors on top of an image.

    Parameters
    ----------
    img : np.ndarray
        Background image to display, typically RGB.
    dx : np.ndarray
        Horizontal displacement (flow) for each pixel.
    dy : np.ndarray
        Vertical displacement (flow) for each pixel.
    skip : int, optional (default=10)
        Only every `skip`-th pixel in both x and y directions will be plotted.
    title : str, optional
        Title of the plot.

    Notes
    -----
    - The displacement vectors are always drawn in red.
    - The quiver arrows are scaled to correspond to pixel displacements (scale=1).
    """
    h, w = dx.shape
    Y, X = np.mgrid[0:h, 0:w]
    plt.figure(figsize=(8,8))
    plt.imshow(img)
    plt.quiver(
        X[::skip, ::skip], Y[::skip, ::skip],
        dx[::skip, ::skip], dy[::skip, ::skip],
        angles='xy', scale_units='xy', scale=1, color='red'
    )
    plt.title(title)
    plt.axis('off')
    plt.show()

# "Convert" images to RGB
imageA = grayscale_to_rgb(imageA)
imageB = grayscale_to_rgb(imageB)
imageC = grayscale_to_rgb(imageC)

# Estimate optical flow using SEA-RAFT
dx_raft_AB, dy_raft_AB = sea_raft_flow(imageA, imageB, model_path=None)
dx_raft_AC, dy_raft_AC = sea_raft_flow(imageA, imageC, model_path=None)

# Display
skip = 10
fig, axes = plt.subplots(1, 2, figsize=(16,8))
h1, w1 = dx_raft_AB.shape
Y1, X1 = np.mgrid[0:h1, 0:w1]
axes[0].imshow(imageA)
axes[0].quiver(X1[::skip, ::skip], Y1[::skip, ::skip],
               dx_raft_AB[::skip, ::skip], dy_raft_AB[::skip, ::skip],
               angles='xy', scale_units='xy', scale=1, color='red')
axes[0].set_title("Mass splitting")
axes[0].axis('off')

h2, w2 = dx_raft_AC.shape
Y2, X2 = np.mgrid[0:h2, 0:w2]
axes[1].imshow(imageA)
axes[1].quiver(X2[::skip, ::skip], Y2[::skip, ::skip],
               dx_raft_AC[::skip, ::skip], dy_raft_AC[::skip, ::skip],
               angles='xy', scale_units='xy', scale=1, color='red')
axes[1].set_title("No mass splitting")
axes[1].axis('off')
plt.tight_layout()
plt.show()
[Figure: SEA-RAFT displacement vectors (red) over image A; left panel "Mass splitting" (A to B), right panel "No mass splitting" (A to C).]
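`sea_raft_flow` pads both images so that their height and width are multiples of 8 (a constraint of the architecture, as its comment notes), and any flow predicted on the padded grid has to be cropped back to the original size before it can be overlaid on the input. The NumPy sketch below illustrates that round trip. `pad_to_multiple` and `unpad` are hypothetical helpers, and the centred, edge-replicating padding is an assumption, not a description of SEA-RAFT's own `InputPadder`. For the 1000 × 1000 synthetic images in the rotation example no padding is needed at all, since 1000 is already divisible by 8.

```python
import numpy as np

def pad_to_multiple(img, multiple=8):
    """Hypothetical helper: edge-pad an (H, W) or (H, W, C) array so that
    H and W become multiples of `multiple`.

    Returns the padded array and the (top, bottom, left, right) pad widths.
    The padding is split as evenly as possible between opposite sides.
    """
    h, w = img.shape[:2]
    pad_h = (-h) % multiple          # rows needed to reach the next multiple
    pad_w = (-w) % multiple          # columns needed to reach the next multiple
    top, left = pad_h // 2, pad_w // 2
    bottom, right = pad_h - top, pad_w - left
    # Pad only the two spatial axes; leave any channel axis untouched
    widths = [(top, bottom), (left, right)] + [(0, 0)] * (img.ndim - 2)
    return np.pad(img, widths, mode="edge"), (top, bottom, left, right)


def unpad(arr, pads):
    """Crop the spatial padding back off an (H, W, ...) array, e.g. a flow field."""
    top, bottom, left, right = pads
    h, w = arr.shape[:2]
    return arr[top:h - bottom, left:w - right]


# Example: a 500 x 750 RGB image needs 4 extra rows and 2 extra columns
img = np.zeros((500, 750, 3), dtype=np.float32)
padded, pads = pad_to_multiple(img)
print(padded.shape, pads)              # (504, 752, 3) (2, 2, 1, 1)
print(unpad(padded, pads).shape)       # (500, 750, 3)
```

Returning the pad widths alongside the padded array keeps the crop exact, which matters because a flow field that is off by even one row no longer lines up pixel-for-pixel with the image it describes.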

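A dense flow field only aligns two images once it is used to resample one of them onto the other's pixel grid. The sketch below does this with bilinear interpolation in plain NumPy. It assumes the flow returned by `sea_raft_flow` follows the usual RAFT convention, in which the pixel at (x, y) in the first image corresponds to (x + dx, y + dy) in the second. `warp_with_flow` is a hypothetical helper; OpenCV's `cv2.remap` with `map_x = X + dx` and `map_y = Y + dy` (as float32 arrays) should do the same job.

```python
import numpy as np

def warp_with_flow(img, dx, dy):
    """Hypothetical helper: resample `img` (the second image) onto the pixel
    grid of the first image using a flow from the first to the second image:
        out[y, x] = img[y + dy[y, x], x + dx[y, x]]
    Bilinear interpolation; sample positions outside `img` are clamped to its
    border. `img` may be (H, W) or (H, W, C) with the same H, W as the flow.
    """
    h, w = dx.shape
    Y, X = np.mgrid[0:h, 0:w].astype(np.float64)
    # Where each first-image pixel lands in the second image (clamped)
    xs = np.clip(X + dx, 0, w - 1)
    ys = np.clip(Y + dy, 0, h - 1)
    # Integer corners of the surrounding 2 x 2 neighbourhood
    x0 = np.floor(xs).astype(int)
    y0 = np.floor(ys).astype(int)
    x1 = np.minimum(x0 + 1, w - 1)
    y1 = np.minimum(y0 + 1, h - 1)
    # Fractional offsets act as bilinear weights
    fx, fy = xs - x0, ys - y0
    src = img.astype(np.float64)
    if src.ndim == 3:                     # broadcast weights over channels
        fx, fy = fx[..., None], fy[..., None]
    top = src[y0, x0] * (1 - fx) + src[y0, x1] * fx
    bottom = src[y1, x0] * (1 - fx) + src[y1, x1] * fx
    return top * (1 - fy) + bottom * fy


# Example: B is A shifted 3 px to the right, so the A -> B flow is dx = 3, dy = 0
A = np.zeros((20, 30))
A[5:10, 5:12] = 1.0
B = np.roll(A, 3, axis=1)
A_hat = warp_with_flow(B, np.full(A.shape, 3.0), np.zeros(A.shape))
print(np.allclose(A_hat[:, :27], A[:, :27]))   # True
```

Applied to the example above, `warp_with_flow(imageB, dx_raft_AB, dy_raft_AB)` should resemble `imageA` wherever the flow is reliable. Samples that fall outside image B are clamped to its border, so values near the image edges deserve less trust.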
SEA-RAFT Performance on Rotational Motion
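Because the rotations in this example are applied with `cv2.getRotationMatrix2D` and `cv2.warpAffine`, the true displacement of every pixel is known analytically. For a rotation by θ about the centre (cx, cy), OpenCV moves the point (x, y) to (cx + cos θ (x - cx) + sin θ (y - cy), cy - sin θ (x - cx) + cos θ (y - cy)). Because the image y axis points down, positive θ appears counter-clockwise on screen. Subtracting (x, y) gives the reference field sketched below. `rotation_displacement_field` is a hypothetical helper for comparison with SEA-RAFT's output, not part of the original notebook.

```python
import numpy as np

def rotation_displacement_field(shape, angle_deg, center=None):
    """Hypothetical helper: exact per-pixel displacement (dx, dy) produced by
    rotating an image of size `shape` = (H, W) by `angle_deg` about `center`,
    using OpenCV's getRotationMatrix2D convention (image y axis points down,
    positive angles rotate counter-clockwise on screen).
    """
    h, w = shape
    if center is None:
        center = (w // 2, h // 2)          # same default as rotate_image
    cx, cy = center
    theta = np.deg2rad(angle_deg)
    c, s = np.cos(theta), np.sin(theta)
    Y, X = np.mgrid[0:h, 0:w].astype(np.float64)
    xr, yr = X - cx, Y - cy                # coordinates relative to the centre
    # New position minus old position, i.e. (R - I) applied to (xr, yr)
    dx = (c - 1.0) * xr + s * yr
    dy = -s * xr + (c - 1.0) * yr
    return dx, dy


# Example: a 90-degree rotation sends the pixel just right of the centre to just above it
dx90, dy90 = rotation_displacement_field((101, 101), 90)
print(round(dx90[50, 51], 6), round(dy90[50, 51], 6))   # -1.0 -1.0
```

Once `dx_raft_AB` and `dy_raft_AB` are computed, `np.hypot(dx_raft_AB - dx_true, dy_raft_AB - dy_true)` with `dx_true, dy_true = rotation_displacement_field(imageA.shape[:2], 30)` gives a per-pixel endpoint error. Inside the uniform rectangle and in the black background, the motion cannot be observed from image content. The error is therefore most meaningful near the rectangle's edges and corners.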

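def create_another_synthetic_image(size=(1000, 1000), shape="rectangle"):
    """
    Create a synthetic grayscale (uint8) image of the given (H, W) size: a filled
    white rectangle or circle centred on a black background. Any other `shape`
    value returns an all-black image.
    """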
    img = np.zeros(size, dtype=np.uint8)
    h, w = img.shape

    if shape == "rectangle":
        cv2.rectangle(img, (w//4, h//4), (3*w//4, 3*h//4), 255, -1)
    elif shape == "circle":
        cv2.circle(img, (w//2, h//2), w//4, 255, -1)

    return img


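def rotate_image(img, angle, center=None):
    """
    Rotate a 2-D image by `angle` degrees about `center` (default: the image centre).
    Following OpenCV, positive angles rotate counter-clockwise on screen. The output
    keeps the input size, so content rotated past the borders is cropped.
    """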
    h, w = img.shape
    if center is None:
        center = (w//2, h//2)

    rotation_matrix = cv2.getRotationMatrix2D(center, angle, 1.0)
    rotated = cv2.warpAffine(img, rotation_matrix, (w, h))
    return rotated

# Create a synthetic image and two rotated copies of it
imageA = create_another_synthetic_image()

known_rotation = 30  # degrees, A -> B
imageB = rotate_image(imageA, known_rotation)
known_rotation = 5  # degrees, A -> C
imageC = rotate_image(imageA, known_rotation)


# "Convert" images to RGB
imageA = grayscale_to_rgb(imageA)
imageB = grayscale_to_rgb(imageB)
imageC = grayscale_to_rgb(imageC)

# Estimate optical flow using SEA-RAFT
dx_raft_AB, dy_raft_AB = sea_raft_flow(imageA, imageB, model_path=None)
dx_raft_AC, dy_raft_AC = sea_raft_flow(imageA, imageC, model_path=None)

# Display: flow on image A (left column) and the rotated target image (right column)
skip = 30
fig, axes = plt.subplots(2, 2, figsize=(16, 16))

# Row 1: A -> B (30 degree rotation)
h1, w1 = dx_raft_AB.shape
Y1, X1 = np.mgrid[0:h1, 0:w1]
axes[0, 0].imshow(imageA)
axes[0, 0].quiver(X1[::skip, ::skip], Y1[::skip, ::skip],
                  dx_raft_AB[::skip, ::skip], dy_raft_AB[::skip, ::skip],
                  angles='xy', scale_units='xy', scale=1, color='red')
axes[0, 0].set_title("Displacement A → B (30° rotation)")
axes[0, 0].axis('off')
axes[0, 1].imshow(imageB)
axes[0, 1].set_title("Target image B (rotated 30°)")
axes[0, 1].axis('off')

# Row 2: A -> C (5 degree rotation)
h2, w2 = dx_raft_AC.shape
Y2, X2 = np.mgrid[0:h2, 0:w2]
axes[1, 0].imshow(imageA)
axes[1, 0].quiver(X2[::skip, ::skip], Y2[::skip, ::skip],
                  dx_raft_AC[::skip, ::skip], dy_raft_AC[::skip, ::skip],
                  angles='xy', scale_units='xy', scale=1, color='red')
axes[1, 0].set_title("Displacement A → C (5° rotation)")
axes[1, 0].axis('off')
axes[1, 1].imshow(imageC)
axes[1, 1].set_title("Target image C (rotated 5°)")
axes[1, 1].axis('off')

plt.tight_layout()
plt.show()
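[Figure: 2 × 2 grid; the left column shows SEA-RAFT displacement vectors over image A and the right column the rotated target image, for the 30° rotation (top row) and the 5° rotation (bottom row).]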
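The quiver plots give a qualitative picture, but a single number is easier to compare with the known 30° and 5° rotations. One option, sketched here as an assumption rather than part of the original workflow, is a least-squares (2-D Procrustes) fit. Treat each pixel p and its flow target q = p + d as a point pair, remove both centroids, and choose the angle θ that best maps one set onto the other. With OpenCV's rotation convention, the optimum has the closed form θ = atan2(Σ (q_x p_y - q_y p_x), Σ (q_x p_x + q_y p_y)). `estimate_rotation_deg` and its optional `mask` argument are hypothetical.

```python
import numpy as np

def estimate_rotation_deg(dx, dy, mask=None):
    """Hypothetical helper: least-squares (2-D Procrustes) rotation angle, in
    degrees, that best explains a displacement field. Uses OpenCV's convention
    (y axis down, positive = counter-clockwise on screen). Centroids are
    removed, so the rotation centre does not need to be known.
    """
    h, w = dx.shape
    Y, X = np.mgrid[0:h, 0:w].astype(np.float64)
    if mask is None:
        mask = np.ones((h, w), dtype=bool)
    # Source points p and the points q = p + d they move to
    px, py = X[mask], Y[mask]
    qx, qy = px + dx[mask], py + dy[mask]
    # Remove centroids (this also absorbs any pure translation)
    px, py = px - px.mean(), py - py.mean()
    qx, qy = qx - qx.mean(), qy - qy.mean()
    # Maximise sum q . (R p) with R = [[cos, sin], [-sin, cos]]
    num = np.sum(qx * py - qy * px)
    den = np.sum(qx * px + qy * py)
    return float(np.rad2deg(np.arctan2(num, den)))


# Example: an exact 30-degree rotation about (40, 32) on a 64 x 80 grid
Yg, Xg = np.mgrid[0:64, 0:80].astype(np.float64)
t = np.deg2rad(30)
xr, yr = Xg - 40, Yg - 32
dx_true = (np.cos(t) - 1) * xr + np.sin(t) * yr
dy_true = -np.sin(t) * xr + (np.cos(t) - 1) * yr
print(round(estimate_rotation_deg(dx_true, dy_true), 6))   # 30.0
```

For the experiment above, a call such as `estimate_rotation_deg(dx_raft_AB, dy_raft_AB, mask=imageA[..., 0] > 0)` restricts the fit to the rectangle's pixels in image A, which keeps the featureless background from dominating. How close the result comes to 30° (or to 5° for `dx_raft_AC`, `dy_raft_AC`) is then a direct check of SEA-RAFT on rotational motion. The fitted angle is only as good as the flow it is given, so any systematic smoothing of the predicted field will bias it accordingly.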