# Source code for kornia.contrib.extract_patches

from typing import Optional, Tuple, Union

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.modules.utils import _pair


class ExtractTensorPatches(nn.Module):
    r"""Module that extract patches from tensors and stack them.

    In the simplest case, the output value of the operator with input size
    :math:`(B, C, H, W)` is :math:`(B, N, C, H_{out}, W_{out})`,

    where
      - :math:`B` is the batch size.
      - :math:`N` denotes the total number of extracted patches stacked in
        left-right and top-bottom order.
      - :math:`C` denotes the number of input channels.
      - :math:`H`, :math:`W` the input height and width of the input in pixels.
      - :math:`H_{out}`, :math:`W_{out}` denote the patch size defined in the
        function signature.

    * :attr:`window_size` is the size of the sliding window and controls the
      shape of the output tensor and defines the shape of the output patch.
    * :attr:`stride` controls the stride to apply to the sliding window and
      regulates the overlapping between the extracted patches.
    * :attr:`padding` controls the amount of implicit zero-padding on both
      sides of each spatial dimension.

    The parameters :attr:`window_size`, :attr:`stride` and :attr:`padding` can be either:

        - a single ``int`` -- in which case the same value is used for the
          height and width dimension.
        - a ``tuple`` of two ints -- in which case, the first `int` is used for
          the height dimension, and the second `int` for the width dimension.

    Args:
        window_size: the size of the sliding window and the output patch size.
        stride: stride of the sliding window. ``None`` is treated as ``1``.
        padding: Zero-padding added to both sides of the input. ``None`` is
            treated as ``0``.

    Shape:
        - Input: :math:`(B, C, H, W)`
        - Output: :math:`(B, N, C, H_{out}, W_{out})`

    Returns:
        the tensor with the extracted patches.

    Examples:
        >>> input = torch.arange(9.).view(1, 1, 3, 3)
        >>> patches = extract_tensor_patches(input, (2, 3))
        >>> input
        tensor([[[[0., 1., 2.],
                  [3., 4., 5.],
                  [6., 7., 8.]]]])
        >>> patches[:, -1]
        tensor([[[[3., 4., 5.],
                  [6., 7., 8.]]]])
    """

    def __init__(
        self,
        window_size: Union[int, Tuple[int, int]],
        stride: Optional[Union[int, Tuple[int, int]]] = 1,
        padding: Optional[Union[int, Tuple[int, int]]] = 0,
    ) -> None:
        super().__init__()
        self.window_size: Tuple[int, int] = _pair(window_size)
        # The signature admits ``None``; ``_pair(None)`` would store
        # ``(None, None)``, which cannot be used as a stride or pad amount
        # when ``forward`` runs, so fall back to the defaults instead.
        self.stride: Tuple[int, int] = _pair(1 if stride is None else stride)
        self.padding: Tuple[int, int] = _pair(0 if padding is None else padding)

    def forward(self, input: torch.Tensor) -> torch.Tensor:  # type: ignore
        """Extract patches from ``input`` using the configured window, stride and padding."""
        return extract_tensor_patches(input, self.window_size, stride=self.stride, padding=self.padding)
###################### # functional interface ###################### def _extract_tensor_patchesnd( input: torch.Tensor, window_sizes: Tuple[int, ...], strides: Tuple[int, ...] ) -> torch.Tensor: batch_size, num_channels = input.size()[:2] dims = range(2, input.dim()) for dim, patch_size, stride in zip(dims, window_sizes, strides): input = input.unfold(dim, patch_size, stride) input = input.permute(0, *dims, 1, *(dim + len(dims) for dim in dims)).contiguous() return input.view(batch_size, -1, num_channels, *window_sizes)
def extract_tensor_patches(
    input: torch.Tensor,
    window_size: Union[int, Tuple[int, int]],
    stride: Union[int, Tuple[int, int]] = 1,
    padding: Union[int, Tuple[int, int]] = 0,
) -> torch.Tensor:
    r"""Function that extract patches from tensors and stack them.

    The input is optionally zero-padded, then split into sliding-window patches.

    Args:
        input: a 4-D tensor laid out as BxCxHxW.
        window_size: patch size, an ``int`` or an ``(h, w)`` pair.
        stride: sliding-window step, an ``int`` or an ``(h, w)`` pair.
        padding: zero-padding applied symmetrically per spatial side, an ``int``
            or an ``(h, w)`` pair. A falsy value skips padding.

    Raises:
        TypeError: if ``input`` is not a tensor.
        ValueError: if ``input`` is not 4-dimensional.

    See :class:`~kornia.contrib.ExtractTensorPatches` for details.
    """
    if not torch.is_tensor(input):
        raise TypeError(f"Input input type is not a torch.Tensor. Got {type(input)}")
    if input.dim() != 4:
        raise ValueError(f"Invalid input shape, we expect BxCxHxW. Got: {input.shape}")

    source = input
    if padding:
        pad_h, pad_w = _pair(padding)
        # F.pad lists amounts from the last dim backwards: (left, right, top, bottom).
        source = F.pad(input, [pad_w, pad_w, pad_h, pad_h])

    window_hw = _pair(window_size)
    stride_hw = _pair(stride)
    return _extract_tensor_patchesnd(source, window_hw, stride_hw)