Source code for giant.image_processing.feature_matchers.roma_matcher

from dataclasses import dataclass, field

import numpy as np
from numpy.typing import NDArray, DTypeLike

from giant.image_processing.feature_matchers.feature_matcher import FeatureMatcher
from giant.image_processing.utilities.image_validation_mixin import ImageValidationMixin
from giant._typing import DOUBLE_ARRAY
from giant.utilities.options import UserOptions
from giant.utilities.mixin_classes.attribute_equality_comparison import AttributeEqualityComparison
from giant.utilities.mixin_classes.attribute_printing import AttributePrinting
from giant.utilities.mixin_classes.user_option_configured import UserOptionConfigured


try:
    from romatch.models import roma_outdoor
    import torch
    from PIL import Image
    
    def _determine_default_device() -> torch.device:
        if torch.backends.mps.is_available():
            return torch.device('mps')
        elif torch.cuda.is_available():
            return torch.device('cuda')
        else:
            return torch.device('cpu')
    
    @dataclass
    class RoMaFeatureMatcherOptions(UserOptions):
        
        device: torch.device = field(default_factory=_determine_default_device)
        """
        What device to work on
        """
        
        coarse_res: int | tuple[int, int] = 560
        """
        The initial coarse resolution of the image (must be a multiple of 14)
        """
        
        upsample_res: tuple[int, int] = (864, 864)
        """
        The resolution to upsample the image to
        """
        
        sample_thresh: float = 0.05
        """
        Controls the thresholding used when sampling matches for estimation.  In certain cases a lower or higher
        threshold may improve results.
        """
    
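    # Illustrative sketch (not part of the module): a non-default configuration might look like the
    # following, assuming a CUDA-capable machine.  The specific values shown here are arbitrary choices,
    # not recommendations; coarse_res must remain a multiple of 14.
    #
    #     options = RoMaFeatureMatcherOptions(device=torch.device('cuda'),
    #                                         coarse_res=560,
    #                                         upsample_res=(1152, 1152),
    #                                         sample_thresh=0.1)
    #     matcher = RoMaFeatureMatcher(options)
    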
    class RoMaFeatureMatcher(UserOptionConfigured[RoMaFeatureMatcherOptions], RoMaFeatureMatcherOptions, FeatureMatcher,
                             AttributeEqualityComparison, AttributePrinting, ImageValidationMixin):
        """
        Implementation of a feature matcher using the RoMa dense feature matching model.
        """
        
        allowed_dtypes: list[DTypeLike] = [np.uint8]
        
        def __init__(self, options: RoMaFeatureMatcherOptions | None = None):
            """
            Initialize the RoMaFeatureMatcher.
            
            :param options: The options to configure with
            """
            
            super().__init__(RoMaFeatureMatcherOptions, options=options)
            
            # build the RoMa outdoor model on the configured device with the requested resolutions
            self.roma_model = roma_outdoor(device=self.device, coarse_res=self.coarse_res, upsample_res=self.upsample_res)
            self.roma_model.sample_thresh = self.sample_thresh
            
            print(f"RoMa model loaded on {self.device}")
    
        def match_images(self, image1: NDArray, image2: NDArray) -> DOUBLE_ARRAY:
            """
            Matches keypoints by overriding the base class method to use RoMa's end-to-end matching process.
            
            :param image1: The first image to match (as a NumPy array).
            :param image2: The second image to match (as a NumPy array).
            :returns: An array of the matched keypoint locations of shape (N, 2, 2).
            """
            
            # Convert the validated images to PIL RGB since RoMa's default processing expects RGB input
            image1_pil = Image.fromarray(self._validate_and_convert_image(image1)).convert('RGB')
            image2_pil = Image.fromarray(self._validate_and_convert_image(image2)).convert('RGB')
            
            # Use RoMa to get the dense warp and per-pixel certainty
            warp, certainty = self.roma_model.match(image1_pil, image2_pil, device=self.device)
            
            # sample sparse correspondences from the dense warp
            matches, certainty = self.roma_model.sample(warp, certainty)
            
            # convert the sampled correspondences to pixel coordinates in each image (only the first two
            # axes of the image shapes are needed, in case a multi-channel image was provided)
            kpts1, kpts2 = self.roma_model.to_pixel_coordinates(matches, *image1.shape[:2], *image2.shape[:2])
            
            # Stack the paired keypoints into an (N, 2, 2) array
            matched_keypoints_array = np.concatenate([kpts1.cpu().numpy().reshape(-1, 1, 2),
                                                      kpts2.cpu().numpy().reshape(-1, 1, 2)], axis=1)
            
            return matched_keypoints_array.astype(np.float64)
    
except ImportError:
    raise ImportError('RoMa is not installed. Please clone and follow the install instructions from '
                      'https://github.com/Parskatt/RoMa/tree/main')
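

# Minimal usage sketch (illustrative only, not part of the GIANT API): build a matcher with the default
# options and match two uint8 images.  The random images below are stand-ins for real imagery and will
# not produce meaningful matches; they simply exercise the interface.
if __name__ == "__main__":
    rng = np.random.default_rng(0)
    image_a = rng.integers(0, 256, size=(480, 640), dtype=np.uint8)
    image_b = rng.integers(0, 256, size=(480, 640), dtype=np.uint8)
    
    matcher = RoMaFeatureMatcher(RoMaFeatureMatcherOptions())
    matched_keypoints = matcher.match_images(image_a, image_b)  # (N, 2, 2) array of matched pixel locations
    print(f"found {matched_keypoints.shape[0]} matches")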