Transforms2D
Set of operators to perform data augmentation on 2D image tensors.
-
class CenterCrop(size: Union[int, Tuple[int, int]], align_corners: bool = True, resample: Union[str, int, kornia.Resample] = 'BILINEAR', return_transform: bool = False, p: float = 1.0, keepdim: bool = False)
Crops a given image tensor at the center.
- Parameters
p (float) – probability of applying the transformation for the whole batch. Default value is 1.
size (Tuple[int, int] or int) – Desired output size (out_h, out_w) of the crop. If integer, out_h = out_w = size. If Tuple[int, int], out_h = size[0], out_w = size[1].
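return_transform (bool) – if True, return the matrix describing the transformation applied to each input tensor. Default: False.
resample (int, str or kornia.Resample) – Default: Resample.BILINEAR.
align_corners (bool) – interpolation flag. Default: True.
keepdim (bool) – whether to keep the output shape the same as input (True) or broadcast it to the batch form (False). Default: False.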
- Shape:
Input: \((C, H, W)\) or \((B, C, H, W)\), Optional: \((B, 3, 3)\)
Output: \((B, C, out_h, out_w)\)
Note
Input tensor must be float and normalized into [0, 1] for the best differentiability support. Additionally, this function accepts another transformation tensor of shape \((B, 3, 3)\); the applied transformation is then merged into the input transformation tensor and returned.
Examples
>>> import torch
>>> from kornia.augmentation import CenterCrop
>>> rng = torch.manual_seed(0)
>>> inputs = torch.randn(1, 1, 4, 4)
>>> inputs
tensor([[[[-1.1258, -1.1524, -0.2506, -0.4339],
          [ 0.8487,  0.6920, -0.3160, -2.1152],
          [ 0.3223, -1.2633,  0.3500,  0.3081],
          [ 0.1198,  1.2377,  1.1168, -0.2473]]]])
>>> aug = CenterCrop(2, p=1.)
>>> aug(inputs)
tensor([[[[ 0.6920, -0.3160],
          [-1.2633,  0.3500]]]])
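The crop window is placed so that equal margins are removed from each side. The pure-Python sketch below computes the top-left offset and slices a nested list; it assumes the offsets are floored when the size difference is odd (the library's rounding in that case is not stated here), and `center_crop` is a hypothetical helper, not part of kornia. On the 4×4 input from the example it yields the same 2×2 window.

```python
from typing import List, Sequence, Tuple, Union


def center_crop(image: Sequence[Sequence[float]],
                size: Union[int, Tuple[int, int]]) -> List[List[float]]:
    """Crop the central (out_h, out_w) window of a 2D nested list (hypothetical helper)."""
    out_h, out_w = (size, size) if isinstance(size, int) else size
    in_h, in_w = len(image), len(image[0])
    if out_h > in_h or out_w > in_w:
        raise ValueError("crop size must not exceed the image size")
    # Assumption: offsets are floored when (in - out) is odd.
    top = (in_h - out_h) // 2
    left = (in_w - out_w) // 2
    return [list(row[left:left + out_w]) for row in image[top:top + out_h]]


image = [[-1.1258, -1.1524, -0.2506, -0.4339],
         [ 0.8487,  0.6920, -0.3160, -2.1152],
         [ 0.3223, -1.2633,  0.3500,  0.3081],
         [ 0.1198,  1.2377,  1.1168, -0.2473]]
print(center_crop(image, 2))  # [[0.692, -0.316], [-1.2633, 0.35]]
```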
-
class ColorJitter(brightness: Union[torch.Tensor, float, Tuple[float, float], List[float]] = 0.0, contrast: Union[torch.Tensor, float, Tuple[float, float], List[float]] = 0.0, saturation: Union[torch.Tensor, float, Tuple[float, float], List[float]] = 0.0, hue: Union[torch.Tensor, float, Tuple[float, float], List[float]] = 0.0, return_transform: bool = False, same_on_batch: bool = False, p: float = 1.0, keepdim: bool = False)
Applies a random transformation to the brightness, contrast, saturation and hue of a tensor image.
- Parameters
p (float) – probability of applying the transformation. Default value is 1.
return_transform (bool) – if True, return the matrix describing the transformation applied to each input tensor. If False and the input is a tuple, the applied transformation won't be concatenated. Default: False.
same_on_batch (bool) – apply the same transformation across the batch. Default: False.
keepdim (bool) – whether to keep the output shape the same as input (True) or broadcast it to the batch form (False). Default: False.
- Shape:
Input: \((C, H, W)\) or \((B, C, H, W)\), Optional: \((B, 3, 3)\)
Output: \((B, C, H, W)\)
Note
Input tensor must be float and normalized into [0, 1] for the best differentiability support. Additionally, this function accepts another transformation tensor of shape \((B, 3, 3)\); the applied transformation is then merged into the input transformation tensor and returned.
Examples
>>> import torch
>>> from kornia.augmentation import ColorJitter
>>> rng = torch.manual_seed(0)
>>> inputs = torch.ones(1, 3, 3, 3)
>>> aug = ColorJitter(0.1, 0.1, 0.1, 0.1, p=1.)
>>> aug(inputs)
tensor([[[[0.9993, 0.9993, 0.9993],
          [0.9993, 0.9993, 0.9993],
          [0.9993, 0.9993, 0.9993]],
<BLANKLINE>
         [[0.9993, 0.9993, 0.9993],
          [0.9993, 0.9993, 0.9993],
          [0.9993, 0.9993, 0.9993]],
<BLANKLINE>
         [[0.9993, 0.9993, 0.9993],
          [0.9993, 0.9993, 0.9993],
          [0.9993, 0.9993, 0.9993]]]])
-
class GaussianBlur(kernel_size: Tuple[int, int], sigma: Tuple[float, float], border_type: str = 'reflect', return_transform: bool = False, same_on_batch: bool = False, p: float = 0.5)
Applies Gaussian blur to a tensor image or a batch of tensor images randomly.
- Parameters
kernel_size (Tuple[int, int]) – the size of the kernel.
sigma (Tuple[float, float]) – the standard deviation of the kernel.
border_type (str) – the padding mode to be applied before convolving. The expected modes are: 'constant', 'reflect', 'replicate' or 'circular'. Default: 'reflect'.
return_transform (bool) – if True, return the matrix describing the transformation applied to each input tensor. If False and the input is a tuple, the applied transformation won't be concatenated. Default: False.
same_on_batch (bool) – apply the same transformation across the batch. Default: False.
p (float) – probability of applying the transformation. Default value is 0.5.
- Shape:
Input: \((C, H, W)\) or \((B, C, H, W)\), Optional: \((B, 3, 3)\)
Output: \((B, C, H, W)\)
Note
Input tensor must be float and normalized into [0, 1] for the best differentiability support. Additionally, this function accepts another transformation tensor of shape \((B, 3, 3)\); the applied transformation is then merged into the input transformation tensor and returned.
Examples
>>> import torch
>>> from kornia.augmentation import GaussianBlur
>>> rng = torch.manual_seed(0)
>>> input = torch.rand(1, 1, 5, 5)
>>> blur = GaussianBlur((3, 3), (0.1, 2.0), p=1.)
>>> blur(input)
tensor([[[[0.6699, 0.4645, 0.3193, 0.1741, 0.1955],
          [0.5422, 0.6657, 0.6261, 0.6527, 0.5195],
          [0.3826, 0.2638, 0.1902, 0.1620, 0.2141],
          [0.6329, 0.6732, 0.5634, 0.4037, 0.2049],
          [0.8307, 0.6753, 0.7147, 0.5768, 0.7097]]]])
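To illustrate what kernel_size and sigma control, the sketch below builds a normalized, separable Gaussian kernel in pure Python. It is a textbook construction, not kornia's implementation; it assumes sigma holds one standard deviation per axis and treats the first tuple entry as the vertical axis (which entry maps to which axis in the library is an assumption). The helper names are hypothetical.

```python
import math
from typing import List, Tuple


def gaussian_1d(size: int, sigma: float) -> List[float]:
    """Normalized 1D Gaussian weights centred on the middle tap."""
    if size <= 0 or size % 2 == 0:
        raise ValueError("kernel size must be odd and positive")
    center = size // 2
    weights = [math.exp(-((i - center) ** 2) / (2.0 * sigma ** 2)) for i in range(size)]
    total = sum(weights)
    return [w / total for w in weights]


def gaussian_2d(kernel_size: Tuple[int, int],
                sigma: Tuple[float, float]) -> List[List[float]]:
    """Separable 2D kernel: outer product of a vertical and a horizontal 1D kernel.

    Assumption: kernel_size[0] and sigma[0] describe the vertical axis.
    """
    ky = gaussian_1d(kernel_size[0], sigma[0])
    kx = gaussian_1d(kernel_size[1], sigma[1])
    return [[wy * wx for wx in kx] for wy in ky]


kernel = gaussian_2d((3, 3), (1.0, 1.0))
print(round(sum(map(sum, kernel)), 6))  # 1.0
```

Because the kernel sums to one, blurring a constant image leaves it unchanged away from the borders; border_type only decides which values are read past the image edge.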
-
class RandomAffine(degrees: Union[torch.Tensor, float, Tuple[float, float]], translate: Union[torch.Tensor, Tuple[float, float], None] = None, scale: Union[torch.Tensor, Tuple[float, float], Tuple[float, float, float, float], None] = None, shear: Union[torch.Tensor, float, Tuple[float, float], None] = None, resample: Union[str, int, kornia.Resample] = 'BILINEAR', return_transform: bool = False, same_on_batch: bool = False, align_corners: bool = False, padding_mode: Union[str, int, kornia.SamplePadding] = 'ZEROS', p: float = 0.5, keepdim: bool = False)
Applies a random 2D affine transformation to a tensor image.
The transformation is computed so that the image center is kept invariant.
- Parameters
p (float) – probability of applying the transformation. Default value is 0.5.
degrees (float or tuple) – Range of degrees to select from. If degrees is a number instead of a sequence like (min, max), the range of degrees will be (-degrees, +degrees). Set to 0 to deactivate rotations.
translate (tuple, optional) – tuple of maximum absolute fractions for horizontal and vertical translations. For example, with translate=(a, b), the horizontal shift is randomly sampled in the range -img_width * a < dx < img_width * a and the vertical shift is randomly sampled in the range -img_height * b < dy < img_height * b. Will not translate by default.
scale (tuple, optional) – scaling factor interval. If (a, b), isotropic scaling is applied with the scale randomly sampled from the range a <= scale <= b. If (a, b, c, d), the scale is randomly sampled from the ranges a <= scale_x <= b and c <= scale_y <= d. Will keep original scale by default.
shear (sequence or float, optional) – Range of degrees to select from. If float, a shear parallel to the x axis in the range (-shear, +shear) will be applied. If (a, b), a shear parallel to the x axis in the range (a, b) will be applied. If (a, b, c, d), then x-axis shear in (shear[0], shear[1]) and y-axis shear in (shear[2], shear[3]) will be applied. Will not apply shear by default.
resample (int, str or kornia.Resample) – resample mode from “nearest” (0) or “bilinear” (1). Default: Resample.BILINEAR.
padding_mode (int, str or kornia.SamplePadding) – padding mode from “zeros” (0), “border” (1) or “reflection” (2). Default: SamplePadding.ZEROS.
return_transform (bool) – if True, return the matrix describing the transformation applied to each input tensor. Default: False.
same_on_batch (bool) – apply the same transformation across the batch. Default: False.
align_corners (bool) – interpolation flag. Default: False.
keepdim (bool) – whether to keep the output shape the same as input (True) or broadcast it to the batch form (False). Default: False.
- Shape:
Input: \((C, H, W)\) or \((B, C, H, W)\), Optional: \((B, 3, 3)\)
Output: \((B, C, H, W)\)
Note
Input tensor must be float and normalized into [0, 1] for the best differentiability support. Additionally, this function accepts another transformation tensor of shape \((B, 3, 3)\); the applied transformation is then merged into the input transformation tensor and returned.
Examples
>>> import torch
>>> from kornia.augmentation import RandomAffine
>>> rng = torch.manual_seed(0)
>>> input = torch.rand(1, 1, 3, 3)
>>> aug = RandomAffine((-15., 20.), return_transform=True, p=1.)
>>> aug(input)
(tensor([[[[0.3961, 0.7310, 0.1574],
          [0.1781, 0.3074, 0.5648],
          [0.4804, 0.8379, 0.4234]]]]), tensor([[[ 0.9923, -0.1241,  0.1319],
         [ 0.1241,  0.9923, -0.1164],
         [ 0.0000,  0.0000,  1.0000]]]))
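The parameter descriptions above define sampling ranges for each component of the affine transform. The pure-Python sketch below turns degrees, translate, scale and shear into those ranges and draws one set of values. It is an illustration of the documented rules under stated assumptions (uniform sampling, shear kept in degrees), not kornia's sampler; `degree_range` and `sample_affine_params` are hypothetical names.

```python
import random
from typing import Dict, Optional, Sequence, Tuple, Union

Number = Union[int, float]


def degree_range(degrees: Union[Number, Sequence[Number]]) -> Tuple[float, float]:
    """A number d means (-d, +d); a (min, max) pair is used as-is."""
    if isinstance(degrees, (int, float)):
        return (-float(degrees), float(degrees))
    lo, hi = degrees
    return (float(lo), float(hi))


def sample_affine_params(height: int, width: int, degrees, translate=None, scale=None,
                         shear=None, rng: Optional[random.Random] = None) -> Dict[str, object]:
    rng = random.Random() if rng is None else rng
    angle = rng.uniform(*degree_range(degrees))
    dx = dy = 0.0
    if translate is not None:  # fractions of the image size
        a, b = translate
        dx = rng.uniform(-width * a, width * a)
        dy = rng.uniform(-height * b, height * b)
    sx = sy = 1.0
    if scale is not None:
        if len(scale) == 2:  # isotropic: one factor for both axes
            sx = sy = rng.uniform(scale[0], scale[1])
        else:  # anisotropic: (a, b) for x, (c, d) for y
            sx = rng.uniform(scale[0], scale[1])
            sy = rng.uniform(scale[2], scale[3])
    shx = shy = 0.0
    if shear is not None:
        if isinstance(shear, (int, float)):
            shx = rng.uniform(-shear, shear)
        elif len(shear) == 2:
            shx = rng.uniform(shear[0], shear[1])
        else:
            shx = rng.uniform(shear[0], shear[1])
            shy = rng.uniform(shear[2], shear[3])
    return {"angle": angle, "translate": (dx, dy), "scale": (sx, sy), "shear": (shx, shy)}
```

For example, `sample_affine_params(100, 200, (-15., 20.), translate=(0.1, 0.2))` draws an angle in [-15, 20] and shifts of at most 20 pixels on each axis.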
-
class RandomCrop(size: Tuple[int, int], padding: Union[int, Tuple[int, int], Tuple[int, int, int, int], None] = None, pad_if_needed: Optional[bool] = False, fill: int = 0, padding_mode: str = 'constant', resample: Union[str, int, kornia.Resample] = 'BILINEAR', return_transform: bool = False, same_on_batch: bool = False, align_corners: bool = False, p: float = 1.0, keepdim: bool = False)
Crops random patches of a tensor image on a given size.
- Parameters
p (float) – probability of applying the transformation for the whole batch. Default value is 1.0.
size (Tuple[int, int]) – Desired output size (out_h, out_w) of the crop. Must be Tuple[int, int], then out_h = size[0], out_w = size[1].
padding (int or sequence, optional) – Optional padding on each border of the image. Default is None, i.e. no padding. If a sequence of length 4 is provided, it is used to pad left, top, right, bottom borders respectively. If a sequence of length 2 is provided, it is used to pad left/right, top/bottom borders, respectively.
pad_if_needed (boolean) – pad the image if it is smaller than the desired size to avoid raising an exception. Since cropping is done after padding, the padding seems to be done at a random offset.
fill – Pixel fill value for constant fill. Default is 0. If a tuple of length 3, it is used to fill R, G, B channels respectively. This value is only used when the padding_mode is constant.
padding_mode – Type of padding. Should be: constant, edge, reflect or symmetric. Default is constant.
resample (int, str or kornia.Resample) – Default: Resample.BILINEAR
return_transform (bool) – if True, return the matrix describing the transformation applied to each input tensor. If False and the input is a tuple, the applied transformation won't be concatenated. Default: False.
same_on_batch (bool) – apply the same transformation across the batch. Default: False.
align_corners (bool) – interpolation flag. Default: False.
keepdim (bool) – whether to keep the output shape the same as input (True) or broadcast it to the batch form (False). Default: False.
- Shape:
Input: \((C, H, W)\) or \((B, C, H, W)\), Optional: \((B, 3, 3)\)
Output: \((B, C, out_h, out_w)\)
Note
Input tensor must be float and normalized into [0, 1] for the best differentiability support. Additionally, this function accepts another transformation tensor of shape \((B, 3, 3)\); the applied transformation is then merged into the input transformation tensor and returned.
Examples
>>> import torch
>>> from kornia.augmentation import RandomCrop
>>> rng = torch.manual_seed(0)
>>> inputs = torch.randn(1, 1, 3, 3)
>>> aug = RandomCrop((2, 2), p=1.)
>>> aug(inputs)
tensor([[[[-0.6562, -1.0009],
          [ 0.2223, -0.5507]]]])
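The padding rules above (an int pads every border, a length-2 sequence pads left/right and top/bottom, a length-4 sequence pads left, top, right, bottom) can be captured in a small normalization step, after which the crop offset is drawn uniformly inside the padded image. This is a pure-Python sketch under those documented rules; how pad_if_needed distributes the extra padding is an assumption here, and both helpers are hypothetical.

```python
import random
from typing import Optional, Sequence, Tuple, Union

Padding = Union[None, int, Sequence[int]]


def normalize_padding(padding: Padding) -> Tuple[int, int, int, int]:
    """Return (left, top, right, bottom) following the rules described for `padding`."""
    if padding is None:
        return (0, 0, 0, 0)
    if isinstance(padding, int):
        return (padding, padding, padding, padding)
    if len(padding) == 2:
        lr, tb = padding
        return (lr, tb, lr, tb)
    if len(padding) == 4:
        left, top, right, bottom = padding
        return (left, top, right, bottom)
    raise ValueError("padding must be an int or a sequence of length 2 or 4")


def random_crop_box(height: int, width: int, size: Tuple[int, int], padding: Padding = None,
                    pad_if_needed: bool = False, rng: Optional[random.Random] = None):
    """Return ((top, left), (padded_h, padded_w)) for one random crop."""
    rng = random.Random() if rng is None else rng
    left, top, right, bottom = normalize_padding(padding)
    h, w = height + top + bottom, width + left + right
    out_h, out_w = size
    if pad_if_needed:
        # Assumption: grow the padded image just enough to fit the crop.
        h, w = max(h, out_h), max(w, out_w)
    if out_h > h or out_w > w:
        raise ValueError("crop size exceeds the (padded) image size")
    return (rng.randint(0, h - out_h), rng.randint(0, w - out_w)), (h, w)
```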
-
class RandomErasing(scale: Union[torch.Tensor, Tuple[float, float]] = (0.02, 0.33), ratio: Union[torch.Tensor, Tuple[float, float]] = (0.3, 3.3), value: float = 0.0, return_transform: bool = False, same_on_batch: bool = False, p: float = 0.5, keepdim: bool = False)
Erases a random rectangle of a tensor image according to a probability p value.
The operator removes image parts and fills them with value (zero by default) at a selected rectangle for each of the images in the batch.
The rectangle will have an area equal to the original image area multiplied by a value uniformly sampled from the range [scale[0], scale[1]) and an aspect ratio sampled from [ratio[0], ratio[1]).
- Parameters
p (float) – probability that the random erasing operation will be performed. Default value is 0.5.
scale (Tuple[float, float]) – range of proportion of erased area against input image.
ratio (Tuple[float, float]) – range of aspect ratio of erased area.
value (float) – the value used to fill the erased area. Default: 0.0.
same_on_batch (bool) – apply the same transformation across the batch. Default: False.
keepdim (bool) – whether to keep the output shape the same as input (True) or broadcast it to the batch form (False). Default: False.
- Shape:
Input: \((C, H, W)\) or \((B, C, H, W)\), Optional: \((B, 3, 3)\)
Output: \((B, C, H, W)\)
Note
Input tensor must be float and normalized into [0, 1] for the best differentiability support. Additionally, this function accepts another transformation tensor of shape \((B, 3, 3)\); the applied transformation is then merged into the input transformation tensor and returned.
Examples
>>> import torch
>>> from kornia.augmentation import RandomErasing
>>> rng = torch.manual_seed(0)
>>> inputs = torch.ones(1, 1, 3, 3)
>>> rec_er = RandomErasing((.4, .8), (.3, 1/.3), p=0.5)
>>> rec_er(inputs)
tensor([[[[1., 0., 0.],
          [1., 0., 0.],
          [1., 0., 0.]]]])
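The rectangle size follows from the two sampled quantities described above: area = H * W * s with s drawn from scale, and an aspect ratio r drawn from ratio. A common way to turn these into a height and width is h = sqrt(area * r), w = sqrt(area / r). The sketch below assumes that convention (aspect ratio = h / w), uniform sampling, and clamping to the image; kornia's exact rounding and clamping may differ, and the helper names are hypothetical.

```python
import math
import random
from typing import List, Optional, Sequence, Tuple


def sample_erase_box(height: int, width: int,
                     scale: Tuple[float, float] = (0.02, 0.33),
                     ratio: Tuple[float, float] = (0.3, 3.3),
                     rng: Optional[random.Random] = None) -> Tuple[int, int, int, int]:
    """Return (top, left, h, w) of a rectangle to erase."""
    rng = random.Random() if rng is None else rng
    area = height * width * rng.uniform(*scale)
    aspect = rng.uniform(*ratio)  # assumption: aspect ratio = h / w
    h = min(height, max(1, int(round(math.sqrt(area * aspect)))))
    w = min(width, max(1, int(round(math.sqrt(area / aspect)))))
    top = rng.randint(0, height - h)
    left = rng.randint(0, width - w)
    return top, left, h, w


def erase(image: Sequence[Sequence[float]], box: Tuple[int, int, int, int],
          value: float = 0.0) -> List[List[float]]:
    """Return a copy of a 2D nested list with the box filled by `value`."""
    top, left, h, w = box
    out = [list(row) for row in image]
    for i in range(top, top + h):
        for j in range(left, left + w):
            out[i][j] = value
    return out
```

Filling the box (0, 1, 3, 2) of a 3×3 image of ones gives the same pattern as the example output.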
-
class RandomGrayscale(return_transform: bool = False, same_on_batch: bool = False, p: float = 0.1, keepdim: bool = False)
Randomly converts a tensor image to grayscale according to a probability p value.
- Parameters
p (float) – probability of the image to be transformed to grayscale. Default value is 0.1.
return_transform (bool) – if True, return the matrix describing the transformation applied to each input tensor. If False and the input is a tuple, the applied transformation won't be concatenated. Default: False.
same_on_batch (bool) – apply the same transformation across the batch. Default: False.
keepdim (bool) – whether to keep the output shape the same as input (True) or broadcast it to the batch form (False). Default: False.
- Shape:
Input: \((C, H, W)\) or \((B, C, H, W)\), Optional: \((B, 3, 3)\)
Output: \((B, C, H, W)\)
Note
Input tensor must be float and normalized into [0, 1] for the best differentiability support. Additionally, this function accepts another transformation tensor of shape \((B, 3, 3)\); the applied transformation is then merged into the input transformation tensor and returned.
Examples
>>> import torch
>>> from kornia.augmentation import RandomGrayscale
>>> rng = torch.manual_seed(0)
>>> inputs = torch.randn((1, 3, 3, 3))
>>> rec_er = RandomGrayscale(p=1.0)
>>> rec_er(inputs)
tensor([[[[-1.1344, -0.1330,  0.1517],
          [-0.0791,  0.6711, -0.1413],
          [-0.1717, -0.9023,  0.0819]],
<BLANKLINE>
         [[-1.1344, -0.1330,  0.1517],
          [-0.0791,  0.6711, -0.1413],
          [-0.1717, -0.9023,  0.0819]],
<BLANKLINE>
         [[-1.1344, -0.1330,  0.1517],
          [-0.0791,  0.6711, -0.1413],
          [-0.1717, -0.9023,  0.0819]]]])
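As the example output shows, the grayscale result keeps three channels, each holding the same gray image. The pure-Python sketch below reproduces that structure with a weighted sum of R, G and B. The weights are the common ITU-R BT.601 luma coefficients, which is an assumption: the exact coefficients kornia uses are not stated in this document. `to_grayscale_3ch` is a hypothetical helper.

```python
from typing import List, Sequence

Channel = List[List[float]]

# Assumption: ITU-R BT.601 luma weights; the library's coefficients may differ.
WEIGHTS = (0.299, 0.587, 0.114)


def to_grayscale_3ch(rgb: Sequence[Channel]) -> List[Channel]:
    """Convert a (3, H, W) nested list to gray and repeat it over 3 channels."""
    r, g, b = rgb
    gray = [[WEIGHTS[0] * rv + WEIGHTS[1] * gv + WEIGHTS[2] * bv
             for rv, gv, bv in zip(r_row, g_row, b_row)]
            for r_row, g_row, b_row in zip(r, g, b)]
    # Independent copies so that editing one channel does not change the others.
    return [[list(row) for row in gray] for _ in range(3)]
```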
-
class RandomHorizontalFlip(return_transform: bool = False, same_on_batch: bool = False, p: float = 0.5, p_batch: float = 1.0, keepdim: bool = False)
Applies a random horizontal flip to a tensor image or a batch of tensor images with a given probability.
Input should be a tensor of shape \((C, H, W)\) or a batch of tensors \((B, C, H, W)\). If the input is a tuple, it is assumed that the first element contains the aforementioned tensors and the second the corresponding transformation matrix that has been applied to them. In this case the module will horizontally flip the tensors and concatenate the corresponding transformation matrix to the previous one. This is especially useful when using this functionality as part of an nn.Sequential module.
- Parameters
p (float) – probability of the image being flipped. Default value is 0.5
return_transform (bool) – if True, return the matrix describing the transformation applied to each input tensor. If False and the input is a tuple, the applied transformation won't be concatenated. Default: False.
same_on_batch (bool) – apply the same transformation across the batch. Default: False.
keepdim (bool) – whether to keep the output shape the same as input (True) or broadcast it to the batch form (False). Default: False.
- Shape:
Input: \((C, H, W)\) or \((B, C, H, W)\), Optional: \((B, 3, 3)\)
Output: \((B, C, H, W)\)
Note
Input tensor must be float and normalized into [0, 1] for the best differentiability support. Additionally, this function accepts another transformation tensor of shape \((B, 3, 3)\); the applied transformation is then merged into the input transformation tensor and returned.
Examples
>>> import torch
>>> from torch import nn
>>> from kornia.augmentation import RandomHorizontalFlip
>>> input = torch.tensor([[[[0., 0., 0.],
...                         [0., 0., 0.],
...                         [0., 1., 1.]]]])
>>> seq = nn.Sequential(RandomHorizontalFlip(p=1.0, return_transform=True),
...                     RandomHorizontalFlip(p=1.0, return_transform=True))
>>> seq(input)
(tensor([[[[0., 0., 0.],
          [0., 0., 0.],
          [0., 1., 1.]]]]), tensor([[[1., 0., 0.],
         [0., 1., 0.],
         [0., 0., 1.]]]))
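The identity matrix in the example follows from concatenating two flip matrices. In homogeneous pixel coordinates a horizontal flip of an image of width W maps column x to (W - 1) - x, and applying it twice returns every pixel to its place. The pure-Python sketch below shows that composition; the matrix layout is an assumption consistent with the vertical-flip matrix documented for RandomVerticalFlip, and the multiplication order used by the library is not stated here (for two identical flips it does not matter). The helpers are hypothetical.

```python
from typing import List

Matrix = List[List[float]]


def hflip_matrix(width: int) -> Matrix:
    """3x3 homogeneous matrix mapping pixel column x to (width - 1) - x."""
    return [[-1.0, 0.0, width - 1.0],
            [0.0, 1.0, 0.0],
            [0.0, 0.0, 1.0]]


def matmul(a: Matrix, b: Matrix) -> Matrix:
    """Plain 3x3 matrix product a @ b."""
    return [[sum(a[i][k] * b[k][j] for k in range(3)) for j in range(3)] for i in range(3)]


# Chaining two flips, as nn.Sequential does in the example above.
composed = matmul(hflip_matrix(3), hflip_matrix(3))
print(composed)  # [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]]
```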
-
class RandomVerticalFlip(return_transform: bool = False, same_on_batch: bool = False, p: float = 0.5, p_batch: float = 1.0, keepdim: bool = False)
Applies a random vertical flip to a tensor image or a batch of tensor images with a given probability.
- Parameters
p (float) – probability of the image being flipped. Default value is 0.5
return_transform (bool) – if True, return the matrix describing the transformation applied to each input tensor. If False and the input is a tuple, the applied transformation won't be concatenated. Default: False.
same_on_batch (bool) – apply the same transformation across the batch. Default: False.
keepdim (bool) – whether to keep the output shape the same as input (True) or broadcast it to the batch form (False). Default: False.
- Shape:
Input: \((C, H, W)\) or \((B, C, H, W)\), Optional: \((B, 3, 3)\)
Output: \((B, C, H, W)\)
Note
Input tensor must be float and normalized into [0, 1] for the best differentiability support. Additionally, this function accepts another transformation tensor of shape \((B, 3, 3)\); the applied transformation is then merged into the input transformation tensor and returned.
Examples
>>> import torch
>>> from kornia.augmentation import RandomVerticalFlip
>>> input = torch.tensor([[[[0., 0., 0.],
...                         [0., 0., 0.],
...                         [0., 1., 1.]]]])
>>> seq = RandomVerticalFlip(p=1.0, return_transform=True)
>>> seq(input)
(tensor([[[[0., 1., 1.],
          [0., 0., 0.],
          [0., 0., 0.]]]]), tensor([[[ 1.,  0.,  0.],
         [ 0., -1.,  2.],
         [ 0.,  0.,  1.]]]))
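The returned matrix in the example maps pixel row y to (H - 1) - y, which for H = 3 gives the translation entry 2. The pure-Python sketch below builds that matrix and the matching row reversal on a nested list, so the two can be checked against each other; the helper names are hypothetical.

```python
from typing import List


def vflip_matrix(height: int) -> List[List[float]]:
    """Homogeneous matrix mapping pixel row y to (height - 1) - y."""
    return [[1.0, 0.0, 0.0],
            [0.0, -1.0, height - 1.0],
            [0.0, 0.0, 1.0]]


def vflip_image(image: List[List[float]]) -> List[List[float]]:
    """Flip the rows of an (H, W) nested list: row y moves to row (H - 1) - y."""
    return [list(row) for row in reversed(image)]


image = [[0., 0., 0.], [0., 0., 0.], [0., 1., 1.]]
print(vflip_image(image))  # [[0.0, 1.0, 1.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]
print(vflip_matrix(3))     # [[1.0, 0.0, 0.0], [0.0, -1.0, 2.0], [0.0, 0.0, 1.0]]
```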
-
class RandomMotionBlur(kernel_size: Union[int, Tuple[int, int]], angle: Union[torch.Tensor, float, Tuple[float, float]], direction: Union[torch.Tensor, float, Tuple[float, float]], border_type: Union[int, str, kornia.BorderType] = 'CONSTANT', resample: Union[str, int, kornia.Resample] = 'NEAREST', return_transform: bool = False, same_on_batch: bool = False, p: float = 0.5, keepdim: bool = False)
Performs motion blur on 2D images (4D tensor).
- Parameters
p (float) – probability of applying the transformation. Default value is 0.5.
kernel_size (int or Tuple[int, int]) – motion kernel size (odd and positive). If int, the kernel will have a fixed size. If Tuple[int, int], it will randomly generate the value from the range batch-wisely.
angle (float or Tuple[float, float]) – angle of the motion blur in degrees (anti-clockwise rotation). If float, it will generate the value from (-angle, angle).
direction (float or Tuple[float, float]) – forward/backward direction of the motion blur. Lower values towards -1.0 will point the motion blur towards the back (with angle provided via angle), while higher values towards 1.0 will point the motion blur forward. A value of 0.0 leads to a uniform (but still angled) motion blur. If float, it will generate the value from (-direction, direction). If Tuple[float, float], it will randomly generate the value from the range.
border_type (int, str or kornia.BorderType) – the padding mode to be applied before convolving. CONSTANT = 0, REFLECT = 1, REPLICATE = 2, CIRCULAR = 3. Default: BorderType.CONSTANT.
resample (int, str or kornia.Resample) – Default: Resample.NEAREST.
keepdim (bool) – whether to keep the output shape the same as input (True) or broadcast it to the batch form (False). Default: False.
- Shape:
Input: \((C, H, W)\) or \((B, C, H, W)\), Optional: \((B, 3, 3)\)
Output: \((B, C, H, W)\)
Note
Input tensor must be float and normalized into [0, 1] for the best differentiability support. Additionally, this function accepts another transformation tensor of shape \((B, 3, 3)\); the applied transformation is then merged into the input transformation tensor and returned.
Please set resample to 'bilinear' if more meaningful gradients are wanted.
Examples
>>> import torch
>>> from kornia.augmentation import RandomMotionBlur
>>> rng = torch.manual_seed(0)
>>> input = torch.ones(1, 1, 5, 5)
>>> motion_blur = RandomMotionBlur(3, 35., 0.5, p=1.)
>>> motion_blur(input)
tensor([[[[0.5773, 1.0000, 1.0000, 1.0000, 0.7561],
          [0.5773, 1.0000, 1.0000, 1.0000, 0.7561],
          [0.5773, 1.0000, 1.0000, 1.0000, 0.7561],
          [0.5773, 1.0000, 1.0000, 1.0000, 0.7561],
          [0.5773, 1.0000, 1.0000, 1.0000, 0.7561]]]])
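One way to realize the documented direction semantics (uniform weights at 0.0, weight shifted to the front towards 1.0 and to the back towards -1.0) is a linear ramp of weights along the blur line. The pure-Python sketch below builds such a 1D profile; it is a hypothetical construction consistent with the description, not kornia's kernel, and it omits placing the line in a 2D kernel and rotating it by angle. It also requires kernel_size >= 3 so the ramp is defined.

```python
from typing import List


def motion_weights(kernel_size: int, direction: float) -> List[float]:
    """Hypothetical 1D weights along the blur line for a direction in [-1, 1].

    direction = 0 gives uniform weights; towards +1 the weight moves to the end
    of the line ("forward"), towards -1 to the start ("backward").
    """
    if kernel_size < 3 or kernel_size % 2 == 0:
        raise ValueError("kernel_size must be odd and >= 3")
    d = (max(-1.0, min(1.0, direction)) + 1.0) / 2.0  # map [-1, 1] to [0, 1]
    # Linear ramp from (1 - d) at the first tap to d at the last tap.
    raw = [(1.0 - d) + (2.0 * d - 1.0) * i / (kernel_size - 1) for i in range(kernel_size)]
    total = sum(raw)
    return [w / total for w in raw]
```

For instance, `motion_weights(3, 1.0)` gives approximately [0, 1/3, 2/3], while `motion_weights(3, 0.0)` gives three equal weights.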
-
class RandomPerspective(distortion_scale: Union[torch.Tensor, float] = 0.5, resample: Union[str, int, kornia.Resample] = 'BILINEAR', return_transform: bool = False, same_on_batch: bool = False, align_corners: bool = False, p: float = 0.5, keepdim: bool = False)
Applies a random perspective transformation to an image tensor with a given probability.
- Parameters
p (float) – probability of the image being perspectively transformed. Default value is 0.5.
distortion_scale (float) – it controls the degree of distortion and ranges from 0 to 1. Default value is 0.5.
resample (int, str or kornia.Resample) – Default: Resample.BILINEAR.
return_transform (bool) – if True, return the matrix describing the transformation applied to each input tensor. Default: False.
same_on_batch (bool) – apply the same transformation across the batch. Default: False.
align_corners (bool) – interpolation flag. Default: False.
keepdim (bool) – whether to keep the output shape the same as input (True) or broadcast it to the batch form (False). Default: False.
- Shape:
Input: \((C, H, W)\) or \((B, C, H, W)\), Optional: \((B, 3, 3)\)
Output: \((B, C, H, W)\)
Note
Input tensor must be float and normalized into [0, 1] for the best differentiability support. Additionally, this function accepts another transformation tensor of shape \((B, 3, 3)\); the applied transformation is then merged into the input transformation tensor and returned.
Examples
>>> import torch
>>> from kornia.augmentation import RandomPerspective
>>> rng = torch.manual_seed(0)
>>> inputs = torch.tensor([[[[1., 0., 0.],
...                          [0., 1., 0.],
...                          [0., 0., 1.]]]])
>>> aug = RandomPerspective(0.5, p=0.5)
>>> aug(inputs)
tensor([[[[0.0000, 0.2289, 0.0000],
          [0.0000, 0.4800, 0.0000],
          [0.0000, 0.0000, 0.0000]]]])
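A typical way for distortion_scale to control the degree of distortion is to move each image corner inwards by a random amount of at most distortion_scale times half the image size; the warp is then the homography mapping the original corners to the moved ones. The pure-Python sketch below shows only the corner-sampling step under that assumption (it is modeled on common implementations, not taken from kornia), and `perspective_corners` is a hypothetical helper.

```python
import random
from typing import List, Optional, Tuple

Point = Tuple[float, float]


def perspective_corners(height: int, width: int, distortion_scale: float = 0.5,
                        rng: Optional[random.Random] = None) -> Tuple[List[Point], List[Point]]:
    """Return (start, end) corners as (x, y): top-left, top-right, bottom-right, bottom-left."""
    if not 0.0 <= distortion_scale <= 1.0:
        raise ValueError("distortion_scale must be in [0, 1]")
    rng = random.Random() if rng is None else rng
    start = [(0.0, 0.0), (width - 1.0, 0.0), (width - 1.0, height - 1.0), (0.0, height - 1.0)]
    max_dx = distortion_scale * width / 2.0
    max_dy = distortion_scale * height / 2.0
    # Each corner moves inwards: +1 pushes right/down, -1 pushes left/up.
    inward = [(1, 1), (-1, 1), (-1, -1), (1, -1)]
    end = [(x + sx * rng.uniform(0.0, max_dx), y + sy * rng.uniform(0.0, max_dy))
           for (x, y), (sx, sy) in zip(start, inward)]
    return start, end
```

With distortion_scale = 0 the corners do not move and the warp is the identity; larger values allow stronger distortion.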
-
class RandomResizedCrop(size: Tuple[int, int], scale: Union[torch.Tensor, Tuple[float, float]] = (0.08, 1.0), ratio: Union[torch.Tensor, Tuple[float, float]] = (0.75, 1.3333333333333333), resample: Union[str, int, kornia.Resample] = 'BILINEAR', return_transform: bool = False, same_on_batch: bool = False, align_corners: bool = False, p: float = 1.0, keepdim: bool = False)
Crops random patches in an image tensor and resizes to a given size.
- Parameters
size (Tuple[int, int]) – Desired output size (out_h, out_w) of each edge. Must be Tuple[int, int], then out_h = size[0], out_w = size[1].
scale – range of the area of the crop relative to the area of the original image, before resizing.
ratio – range of the aspect ratio of the crop, before resizing.
resample (int, str or kornia.Resample) – Default: Resample.BILINEAR.
return_transform (bool) – if True, return the matrix describing the transformation applied to each input tensor. If False and the input is a tuple, the applied transformation won't be concatenated. Default: False.
same_on_batch (bool) – apply the same transformation across the batch. Default: False.
align_corners (bool) – interpolation flag. Default: False.
keepdim (bool) – whether to keep the output shape the same as input (True) or broadcast it to the batch form (False). Default: False.
- Shape:
Input: \((C, H, W)\) or \((B, C, H, W)\), Optional: \((B, 3, 3)\)
Output: \((B, C, out_h, out_w)\)
Note
Input tensor must be float and normalized into [0, 1] for the best differentiability support. Additionally, this function accepts another transformation tensor of shape \((B, 3, 3)\); the applied transformation is then merged into the input transformation tensor and returned.
Example
>>> import torch
>>> from kornia.augmentation import RandomResizedCrop
>>> rng = torch.manual_seed(0)
>>> inputs = torch.tensor([[[0., 1., 2.],
...                         [3., 4., 5.],
...                         [6., 7., 8.]]])
>>> aug = RandomResizedCrop(size=(3, 3), scale=(3., 3.), ratio=(2., 2.), p=1.)
>>> aug(inputs)
tensor([[[[1.2500, 1.7500, 1.5000],
          [4.2500, 4.7500, 3.7500],
          [7.2500, 7.7500, 6.0000]]]])
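The crop size before resizing follows from the sampled area fraction s (from scale) and aspect ratio r (from ratio): area = H * W * s, w = sqrt(area * r), h = sqrt(area / r). The pure-Python sketch below assumes that convention (aspect ratio = w / h, as in common RandomResizedCrop implementations), uniform sampling, and simple clamping to the image when the crop does not fit, for instance with scale values above 1 as in the example; the library may retry or fall back differently. `sample_crop_size` is a hypothetical helper.

```python
import math
import random
from typing import Optional, Tuple


def sample_crop_size(height: int, width: int,
                     scale: Tuple[float, float] = (0.08, 1.0),
                     ratio: Tuple[float, float] = (3 / 4, 4 / 3),
                     rng: Optional[random.Random] = None) -> Tuple[int, int]:
    """Return (crop_h, crop_w) before the crop is resized to `size`."""
    rng = random.Random() if rng is None else rng
    area = height * width * rng.uniform(*scale)
    aspect = rng.uniform(*ratio)  # assumption: aspect ratio = w / h
    crop_w = int(round(math.sqrt(area * aspect)))
    crop_h = int(round(math.sqrt(area / aspect)))
    # Assumption: clamp to the image instead of retrying when the crop does not fit.
    return max(1, min(height, crop_h)), max(1, min(width, crop_w))
```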
-
class RandomRotation(degrees: Union[torch.Tensor, float, Tuple[float, float], List[float]], resample: Union[str, int, kornia.Resample] = 'BILINEAR', return_transform: bool = False, same_on_batch: bool = False, align_corners: bool = True, p: float = 0.5, keepdim: bool = False)
Applies a random rotation to a tensor image or a batch of tensor images given an amount of degrees.
- Parameters
p (float) – probability of applying the transformation. Default value is 0.5.
degrees (sequence or float or tensor) – range of degrees to select from. If degrees is a number the range of degrees to select from will be (-degrees, +degrees).
resample (int, str or kornia.Resample) – Default: Resample.BILINEAR.
return_transform (bool) – if True, return the matrix describing the transformation applied to each input tensor. If False and the input is a tuple, the applied transformation won't be concatenated. Default: False.
same_on_batch (bool) – apply the same transformation across the batch. Default: False.
align_corners (bool) – interpolation flag. Default: True.
keepdim (bool) – whether to keep the output shape the same as input (True) or broadcast it to the batch form (False). Default: False.
- Shape:
Input: \((C, H, W)\) or \((B, C, H, W)\), Optional: \((B, 3, 3)\)
Output: \((B, C, H, W)\)
Note
Input tensor must be float and normalized into [0, 1] for the best differentiability support. Additionally, this function accepts another transformation tensor of shape \((B, 3, 3)\); the applied transformation is then merged into the input transformation tensor and returned.
Examples
>>> import torch
>>> from kornia.augmentation import RandomRotation
>>> rng = torch.manual_seed(0)
>>> input = torch.tensor([[1., 0., 0., 2.],
...                       [0., 0., 0., 0.],
...                       [0., 1., 2., 0.],
...                       [0., 0., 1., 2.]])
>>> seq = RandomRotation(degrees=45.0, return_transform=True, p=1.)
>>> seq(input)
(tensor([[[[0.9824, 0.0088, 0.0000, 1.9649],
          [0.0000, 0.0029, 0.0000, 0.0176],
          [0.0029, 1.0000, 1.9883, 0.0000],
          [0.0000, 0.0088, 1.0117, 1.9649]]]]), tensor([[[ 1.0000, -0.0059,  0.0088],
         [ 0.0059,  1.0000, -0.0088],
         [ 0.0000,  0.0000,  1.0000]]]))
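The translation column of the returned matrix comes from rotating about the image center rather than the origin: with center c and rotation R, the matrix is T(c) R T(-c), so t = c - R c. The pure-Python sketch below builds that matrix assuming c = ((W - 1) / 2, (H - 1) / 2) and the [[cos, -sin], [sin, cos]] layout visible in the example; under these assumptions a small angle with sin ≈ 0.0059 on a 4×4 image gives translations of about ±0.0088, matching the example to within rounding. The sign convention of the angle in kornia is not asserted here, and `rotation_about_center` is a hypothetical helper.

```python
import math
from typing import List


def rotation_about_center(angle_deg: float, height: int, width: int) -> List[List[float]]:
    """3x3 matrix rotating pixel coordinates about the image center.

    Equivalent to T(c) @ R @ T(-c), i.e. translation t = c - R c.
    """
    a = math.radians(angle_deg)
    cos_a, sin_a = math.cos(a), math.sin(a)
    cx, cy = (width - 1) / 2.0, (height - 1) / 2.0  # assumed pixel-center convention
    tx = cx - (cos_a * cx - sin_a * cy)
    ty = cy - (sin_a * cx + cos_a * cy)
    return [[cos_a, -sin_a, tx],
            [sin_a, cos_a, ty],
            [0.0, 0.0, 1.0]]
```

By construction the center pixel maps to itself for any angle, which is what keeps the rotated content centered.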
-
class RandomSolarize(thresholds: Union[torch.Tensor, float, Tuple[float, float], List[float]] = 0.1, additions: Union[torch.Tensor, float, Tuple[float, float], List[float]] = 0.1, same_on_batch: bool = False, return_transform: bool = False, p: float = 0.5, keepdim: bool = False)
Solarizes a tensor image or a batch of tensor images randomly.
- Parameters
p (float) – probability of applying the transformation. Default value is 0.5.
thresholds (float or tuple) – Default value is 0.1. If float x, threshold will be generated from (0.5 - x, 0.5 + x). If tuple (x, y), threshold will be generated from (x, y).
additions (float or tuple) – Default value is 0.1. If float x, addition will be generated from (-x, x). If tuple (x, y), addition will be generated from (x, y).
same_on_batch (bool) – apply the same transformation across the batch. Default: False.
return_transform (bool) – if True, return the matrix describing the transformation applied to each input tensor. If False and the input is a tuple, the applied transformation won't be concatenated. Default: False.
keepdim (bool) – whether to keep the output shape the same as input (True) or broadcast it to the batch form (False). Default: False.
- Shape:
Input: \((C, H, W)\) or \((B, C, H, W)\), Optional: \((B, 3, 3)\)
Output: \((B, C, H, W)\)
Note
Input tensor must be float and normalized into [0, 1] for the best differentiability support. Additionally, this function accepts another transformation tensor of shape \((B, 3, 3)\); the applied transformation is then merged into the input transformation tensor and returned.
Examples
>>> import torch
>>> from kornia.augmentation import RandomSolarize
>>> rng = torch.manual_seed(0)
>>> input = torch.rand(1, 1, 5, 5)
>>> solarize = RandomSolarize(0.1, 0.1, p=1.)
>>> solarize(input)
tensor([[[[0.4132, 0.1412, 0.1790, 0.2226, 0.3980],
          [0.2754, 0.4194, 0.0130, 0.4538, 0.2771],
          [0.4394, 0.4923, 0.1129, 0.2594, 0.3844],
          [0.3909, 0.2118, 0.1094, 0.2516, 0.3728],
          [0.2278, 0.0000, 0.4876, 0.0353, 0.5100]]]])
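The two parameters above define sampling ranges: a float x for thresholds means a threshold drawn from (0.5 - x, 0.5 + x), and a float x for additions means an offset drawn from (-x, x). The pure-Python sketch below implements those ranges together with a common solarize rule (shift by the addition, clamp to [0, 1], then invert every value at or above the threshold). The rule itself is an assumption, not taken from kornia's source, and the helper names are hypothetical.

```python
import random
from typing import List, Sequence, Tuple, Union

Range = Union[float, Tuple[float, float]]


def threshold_range(thresholds: Range) -> Tuple[float, float]:
    """Float x -> (0.5 - x, 0.5 + x); a tuple is used as-is."""
    if isinstance(thresholds, (int, float)):
        return (0.5 - thresholds, 0.5 + thresholds)
    lo, hi = thresholds
    return (float(lo), float(hi))


def addition_range(additions: Range) -> Tuple[float, float]:
    """Float x -> (-x, x); a tuple is used as-is."""
    if isinstance(additions, (int, float)):
        return (-additions, additions)
    lo, hi = additions
    return (float(lo), float(hi))


def solarize(values: Sequence[float], threshold: float, addition: float = 0.0) -> List[float]:
    """Shift by `addition`, clamp to [0, 1], then invert values at or above `threshold`."""
    out = []
    for v in values:
        v = min(1.0, max(0.0, v + addition))
        out.append(v if v < threshold else 1.0 - v)
    return out


rng = random.Random(0)
t = rng.uniform(*threshold_range(0.1))  # somewhere in (0.4, 0.6)
a = rng.uniform(*addition_range(0.1))   # somewhere in (-0.1, 0.1)
pixels = solarize([0.25, 0.75], t, a)
print(solarize([0.25, 0.75], threshold=0.5))  # [0.25, 0.25]
```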
-
class RandomPosterize(bits: Union[int, Tuple[int, int], torch.Tensor] = 3, same_on_batch: bool = False, return_transform: bool = False, p: float = 0.5, keepdim: bool = False)
Posterizes a tensor image or a batch of tensor images randomly.
- Parameters
p (float) – probability of applying the transformation. Default value is 0.5.
bits (int or tuple) – an integer in the range (0, 8], where 0 gives a black image and 8 gives the original image. If int x, bits will be generated from (x, 8). If tuple (x, y), bits will be generated from (x, y). Default value is 3.
same_on_batch (bool) – apply the same transformation across the batch. Default: False.
return_transform (bool) – if True, return the matrix describing the transformation applied to each input tensor. If False and the input is a tuple, the applied transformation won't be concatenated. Default: False.
keepdim (bool) – whether to keep the output shape the same as input (True) or broadcast it to the batch form (False). Default: False.
- Shape:
Input: \((C, H, W)\) or \((B, C, H, W)\), Optional: \((B, 3, 3)\)
Output: \((B, C, H, W)\)
Note
Input tensor must be float and normalized into [0, 1] for the best differentiability support. Additionally, this function accepts another transformation tensor of shape \((B, 3, 3)\); the applied transformation is then merged into the input transformation tensor and returned.
Examples
>>> import torch
>>> from kornia.augmentation import RandomPosterize
>>> rng = torch.manual_seed(0)
>>> input = torch.rand(1, 1, 5, 5)
>>> posterize = RandomPosterize(3, p=1.)
>>> posterize(input)
tensor([[[[0.4706, 0.7529, 0.0627, 0.1255, 0.2824],
          [0.6275, 0.4706, 0.8784, 0.4392, 0.6275],
          [0.3451, 0.3765, 0.0000, 0.1569, 0.2824],
          [0.5020, 0.6902, 0.7843, 0.1569, 0.2510],
          [0.6588, 0.9098, 0.3765, 0.8471, 0.4078]]]])
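Posterizing keeps only the most significant `bits` bits of each 8-bit pixel value, which is why 8 leaves the image unchanged and 0 produces black. The pure-Python sketch below applies that bit mask to float pixels in [0, 1]; it assumes values are truncated to 8-bit integers before masking, which may differ from the library's rounding. `posterize` here is a hypothetical helper, not the kornia function.

```python
from typing import List, Sequence


def posterize(values: Sequence[float], bits: int) -> List[float]:
    """Keep the `bits` most significant bits of each 8-bit pixel value."""
    if not 0 <= bits <= 8:
        raise ValueError("bits must be in [0, 8]")
    mask = (0xFF << (8 - bits)) & 0xFF  # e.g. bits=3 -> 0b11100000
    # Assumption: floats in [0, 1] are truncated to 8-bit integers before masking.
    return [(int(v * 255) & mask) / 255.0 for v in values]
```

For example, with bits=5 a pixel of 0.4963 becomes the 8-bit value 126 (0b01111110), is masked to 120 (0b01111000), and maps back to 120 / 255 ≈ 0.4706.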
-
class RandomSharpness(sharpness: Union[torch.Tensor, float, Tuple[float, float]] = 0.5, same_on_batch: bool = False, return_transform: bool = False, p: float = 0.5, keepdim: bool = False)
Sharpens a tensor image or a batch of tensor images randomly.
- Parameters
p (float) – probability of applying the transformation. Default value is 0.5.
sharpness (float or tuple) – factor of sharpness strength. Must be above 0. Default value is 0.5.
same_on_batch (bool) – apply the same transformation across the batch. Default: False.
return_transform (bool) – if True, return the matrix describing the transformation applied to each input tensor. If False and the input is a tuple, the applied transformation won't be concatenated. Default: False.
keepdim (bool) – whether to keep the output shape the same as input (True) or broadcast it to the batch form (False). Default: False.
- Shape:
Input: \((C, H, W)\) or \((B, C, H, W)\), Optional: \((B, 3, 3)\)
Output: \((B, C, H, W)\)
Note
Input tensor must be float and normalized into [0, 1] for the best differentiability support. Additionally, this function accepts another transformation tensor of shape \((B, 3, 3)\); the applied transformation is then merged into the input transformation tensor and returned.
Examples
>>> import torch
>>> from kornia.augmentation import RandomSharpness
>>> rng = torch.manual_seed(0)
>>> input = torch.rand(1, 1, 5, 5)
>>> sharpness = RandomSharpness(1., p=1.)
>>> sharpness(input)
tensor([[[[0.4963, 0.7682, 0.0885, 0.1320, 0.3074],
          [0.6341, 0.4810, 0.7367, 0.4177, 0.6323],
          [0.3489, 0.4428, 0.1562, 0.2443, 0.2939],
          [0.5185, 0.6462, 0.7050, 0.2288, 0.2823],
          [0.6816, 0.9152, 0.3971, 0.8742, 0.4194]]]])
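A common definition of sharpness adjustment blends the image with a smoothed copy: out = blurred + factor * (image - blurred), so a factor of 1 returns the original, 0 returns the smoothed image, and larger values exaggerate edges. The pure-Python sketch below uses the 3×3 smoothing kernel [[1, 1, 1], [1, 5, 1], [1, 1, 1]] / 13 (the kernel of Pillow's SMOOTH filter) and leaves the one-pixel border untouched; both choices are assumptions, not statements about kornia's implementation. `sharpen` is a hypothetical helper.

```python
from typing import List

Image = List[List[float]]


def sharpen(image: Image, factor: float) -> Image:
    """Blend an (H, W) image with a smoothed copy: out = blurred + factor * (image - blurred).

    Assumptions: smoothing kernel [[1, 1, 1], [1, 5, 1], [1, 1, 1]] / 13,
    border pixels copied unchanged, results clamped to [0, 1].
    """
    h, w = len(image), len(image[0])
    out = [list(row) for row in image]
    for i in range(1, h - 1):
        for j in range(1, w - 1):
            window = sum(image[i + di][j + dj] for di in (-1, 0, 1) for dj in (-1, 0, 1))
            # The window already counts the centre once; add 4 more for weight 5.
            blurred = (window + 4.0 * image[i][j]) / 13.0
            value = blurred + factor * (image[i][j] - blurred)
            out[i][j] = min(1.0, max(0.0, value))
    return out
```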
-
class RandomEqualize(same_on_batch: bool = False, return_transform: bool = False, p: float = 0.5, keepdim: bool = False)
Equalizes a tensor image or a batch of tensor images randomly.
- Parameters
p (float) – Probability to equalize an image. Default value is 0.5.
same_on_batch (bool) – apply the same transformation across the batch. Default: False.
return_transform (bool) – if True, return the matrix describing the transformation applied to each input tensor. If False and the input is a tuple, the applied transformation won't be concatenated. Default: False.
keepdim (bool) – whether to keep the output shape the same as input (True) or broadcast it to the batch form (False). Default: False.
- Shape:
Input: \((C, H, W)\) or \((B, C, H, W)\), Optional: \((B, 3, 3)\)
Output: \((B, C, H, W)\)
Note
Input tensor must be float and normalized into [0, 1] for the best differentiability support. Additionally, this function accepts another transformation tensor (\((B, 3, 3)\)); the applied transformation will then be merged into the input transformation tensor and returned.
Examples
>>> rng = torch.manual_seed(0)
>>> input = torch.rand(1, 1, 5, 5)
>>> equalize = RandomEqualize(p=1.)
>>> equalize(input)
tensor([[[[0.4963, 0.7682, 0.0885, 0.1320, 0.3074],
          [0.6341, 0.4901, 0.8964, 0.4556, 0.6323],
          [0.3489, 0.4017, 0.0223, 0.1689, 0.2939],
          [0.5185, 0.6977, 0.8000, 0.1610, 0.2823],
          [0.6816, 0.9152, 0.3971, 0.8742, 0.4194]]]])
-
class
RandomMixUp
(lambda_val: Union[torch.Tensor, Tuple[float, float], None] = None, same_on_batch: bool = False, p: float = 1.0, keepdim: bool = False)[source]¶ Apply MixUp augmentation to a batch of tensor images.
Implementation of mixup: Beyond Empirical Risk Minimization [ZnYNDLP18].
The function returns (inputs, labels), where inputs is the tensor containing the mixed-up images and labels is a \((B, 3)\) tensor containing (label_batch, label_permuted_batch, lambda) for each image.
The implementation is on top of the following repository: https://github.com/hongyi-zhang/mixup/blob/master/cifar/utils.py.
The loss and accuracy are computed as:
import torch
import torch.nn.functional as F

# y: (B, 3) rows of (label, permuted label, lambda) as returned by RandomMixUp
def loss_mixup(y, logits):
    criterion = F.cross_entropy
    loss_a = criterion(logits, y[:, 0].long(), reduction='none')
    loss_b = criterion(logits, y[:, 1].long(), reduction='none')
    return ((1 - y[:, 2]) * loss_a + y[:, 2] * loss_b).mean()

def acc_mixup(y, logits):
    pred = torch.argmax(logits, dim=1).to(y.device)
    return (1 - y[:, 2]) * pred.eq(y[:, 0]).float() + y[:, 2] * pred.eq(y[:, 1]).float()
- Parameters
p (float) – probability of applying the augmentation to a batch. This parameter controls the augmentation probability batch-wise.
lambda_val (Tuple[float, float] or torch.Tensor, optional) – min-max range of the mixup strength. Default is (0, 1).
same_on_batch (bool) – apply the same transformation across the batch. This flag will not maintain permutation order. Default: False.
keepdim (bool) – whether to keep the output shape the same as input (True) or broadcast it to the batch form (False). Default: False
- Inputs:
Input image tensors, shape of \((B, C, H, W)\).
Label: raw labels, shape of \((B)\).
- Returns
Adjusted image, shape of \((B, C, H, W)\).
Raw labels, permuted labels and lambdas for each mix, shape of \((B, 3)\).
- Return type
Tuple[torch.Tensor, torch.Tensor]
Note
This implementation randomly mixes up images within a batch, so larger batch sizes are preferred.
Examples
>>> rng = torch.manual_seed(1)
>>> input = torch.rand(2, 1, 3, 3)
>>> label = torch.tensor([0, 1])
>>> mixup = RandomMixUp()
>>> mixup(input, label)
(tensor([[[[0.7576, 0.2793, 0.4031],
          [0.7347, 0.0293, 0.7999],
          [0.3971, 0.7544, 0.5695]]],
<BLANKLINE>
<BLANKLINE>
        [[[0.4388, 0.6387, 0.5247],
          [0.6826, 0.3051, 0.4635],
          [0.4550, 0.5725, 0.4980]]]]), tensor([[0.0000, 0.0000, 0.1980],
        [1.0000, 1.0000, 0.4162]]))
-
class
RandomCutMix
(height: int, width: int, num_mix: int = 1, cut_size: Union[torch.Tensor, Tuple[float, float], None] = None, beta: Union[float, torch.Tensor, None] = None, same_on_batch: bool = False, p: float = 1.0, keepdim: bool = False)[source]¶ Apply CutMix augmentation to a batch of tensor images.
Implementation of CutMix: Regularization Strategy to Train Strong Classifiers with Localizable Features [YHO+19].
The function returns (inputs, labels), where inputs is the tensor containing the CutMix images and labels is a \((\text{num\_mixes}, B, 3)\) tensor containing (label_batch, label_permuted_batch, lambda) for each cutmix.
The implementation is based on the following repository: https://github.com/clovaai/CutMix-PyTorch.
The one-hot labels may be computed as follows, where lambda is the fraction of the image area pasted from the permuted image:
import torch
import torch.nn.functional as F

def onehot(num_classes, target):
    # (B,) integer labels -> (B, num_classes) float one-hot matrix
    return F.one_hot(target.long(), num_classes).float()

def cutmix_label(labels, out_labels, num_classes):
    # out_labels: (num_mixes, B, 3) rows of (label, permuted label, lambda)
    lb_onehot = onehot(num_classes, labels)
    for out_label in out_labels:
        label_permuted_batch, lam = out_label[:, 1], out_label[:, 2]
        label_permuted_onehot = onehot(num_classes, label_permuted_batch)
        lam = lam.unsqueeze(-1)  # (B, 1), broadcasts over the class axis
        lb_onehot = lb_onehot * (1. - lam) + label_permuted_onehot * lam
    return lb_onehot
- Parameters
height (int) – the height of the input image.
width (int) – the width of the input image.
p (float) – probability of applying the augmentation to a batch. This parameter controls the augmentation probability batch-wise.
num_mix (int) – number of CutMix operations applied to each image. Default is 1.
beta (float or torch.Tensor, optional) – hyperparameter of the beta distribution from which the cut size is sampled. Beta cannot be set to 0 after torch 1.8.0. If None, it will be set to 1.
cut_size ((float, float) or torch.Tensor, optional) – controlling the minimum and maximum cut ratio from [0, 1]. If None, it will be set to [0, 1], which means no restriction.
same_on_batch (bool) – apply the same transformation across the batch. This flag will not maintain permutation order. Default: False.
keepdim (bool) – whether to keep the output shape the same as input (True) or broadcast it to the batch form (False). Default: False
- Inputs:
Input image tensors, shape of \((B, C, H, W)\).
Raw labels, shape of \((B)\).
- Returns
Adjusted image, shape of \((B, C, H, W)\).
Raw labels, permuted labels and lambdas for each mix, shape of \((\text{num\_mix}, B, 3)\).
- Return type
Tuple[torch.Tensor, torch.Tensor]
Note
This implementation randomly cuts and mixes images within a batch, so larger batch sizes are preferred.
Examples
>>> rng = torch.manual_seed(3)
>>> input = torch.rand(2, 1, 3, 3)
>>> input[0] = torch.ones((1, 3, 3))
>>> label = torch.tensor([0, 1])
>>> cutmix = RandomCutMix(3, 3)
>>> cutmix(input, label)
(tensor([[[[0.8879, 0.4510, 1.0000],
          [0.1498, 0.4015, 1.0000],
          [1.0000, 1.0000, 1.0000]]],
<BLANKLINE>
<BLANKLINE>
        [[[1.0000, 1.0000, 0.7995],
          [1.0000, 1.0000, 0.0542],
          [0.4594, 0.1756, 0.9492]]]]), tensor([[[0.0000, 1.0000, 0.4444],
         [1.0000, 0.0000, 0.4444]]]))
-
apply_adjust_brightness
(input: torch.Tensor, params: Dict[str, torch.Tensor]) → torch.Tensor[source]¶ Apply brightness adjustment.
Wrapper for adjust_brightness for Torchvision-like param settings.
- Parameters
input (torch.Tensor) – Tensor to be transformed with shape (H, W), (C, H, W), (B, C, H, W).
params (Dict[str, torch.Tensor]) –
params[‘brightness_factor’]: Brightness adjust factor per element in the batch. 0 gives a black image, 1 does not modify the input image and 2 gives a white image, while any other number modifies the brightness.
- Returns
Adjusted image.
- Return type
torch.Tensor
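As a rough illustration of the stated endpoints (0 gives black, 1 leaves the image unchanged, 2 gives white), the pure-Python sketch below assumes the factor acts as an additive shift of `factor - 1` followed by clamping to [0, 1]. The helper `adjust_brightness_pixel` is hypothetical and works on a single float pixel, not on Kornia tensors.

```python
def adjust_brightness_pixel(value: float, brightness_factor: float) -> float:
    """Hypothetical scalar sketch: shift by (factor - 1), then clamp to [0, 1]."""
    shifted = value + (brightness_factor - 1.0)
    return min(max(shifted, 0.0), 1.0)


# Factors 0, 1 and 2 applied to a mid-gray pixel.
for factor in (0.0, 1.0, 2.0):
    print(factor, adjust_brightness_pixel(0.4, factor))
```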
-
apply_adjust_contrast
(input: torch.Tensor, params: Dict[str, torch.Tensor]) → torch.Tensor[source]¶ Apply contrast adjustment.
Wrapper for adjust_contrast for Torchvision-like param settings.
- Parameters
input (torch.Tensor) – Tensor to be transformed with shape (H, W), (C, H, W), (B, C, H, W).
params (Dict[str, torch.Tensor]) –
params[‘contrast_factor’]: Contrast adjust factor per element in the batch. 0 generates a completely black image, 1 does not modify the input image, while any other non-negative number scales the pixel intensities by this factor.
- Returns
Adjusted image.
- Return type
torch.Tensor
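A minimal scalar sketch of the behaviour described above, assuming the contrast factor multiplies each pixel and the result is clamped to [0, 1]; the helper name is hypothetical and this is not Kornia's tensor implementation.

```python
def adjust_contrast_pixel(value: float, contrast_factor: float) -> float:
    """Hypothetical scalar sketch: scale by the factor, then clamp to [0, 1]."""
    if contrast_factor < 0:
        raise ValueError("contrast_factor must be non-negative")
    return min(max(value * contrast_factor, 0.0), 1.0)


black = adjust_contrast_pixel(0.5, 0.0)      # factor 0 -> black
unchanged = adjust_contrast_pixel(0.5, 1.0)  # factor 1 -> identity
scaled = adjust_contrast_pixel(0.5, 1.5)     # other factors scale the value
```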
-
apply_adjust_gamma
(input: torch.Tensor, params: Dict[str, torch.Tensor]) → torch.Tensor[source]¶ Perform gamma correction on an image.
- Parameters
input (torch.Tensor) – Tensor to be transformed with shape (H, W), (C, H, W), (B, C, H, W).
params (Dict[str, torch.Tensor]) –
params[‘gamma_factor’]: Non-negative real number, the exponent \(\gamma\) of the gamma correction. Gamma larger than 1 makes the shadows darker, while gamma smaller than 1 makes dark regions lighter.
- Returns
Adjusted image.
- Return type
torch.Tensor
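To make the direction of the effect concrete, the sketch below assumes a plain power law \(out = in^{\gamma}\) on values in [0, 1] with a gain of 1 and clamping; the helper is hypothetical and scalar-only.

```python
def adjust_gamma_pixel(value: float, gamma: float) -> float:
    """Hypothetical scalar sketch of gamma correction: value ** gamma on [0, 1]."""
    if gamma < 0:
        raise ValueError("gamma must be non-negative")
    return min(max(value ** gamma, 0.0), 1.0)


darker = adjust_gamma_pixel(0.5, 2.0)   # gamma > 1: shadows get darker
lighter = adjust_gamma_pixel(0.5, 0.5)  # gamma < 1: dark regions get lighter
```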
-
apply_adjust_hue
(input: torch.Tensor, params: Dict[str, torch.Tensor]) → torch.Tensor[source]¶ Apply hue adjustment.
Wrapper for adjust_hue for Torchvision-like param settings.
- Parameters
input (torch.Tensor) – Tensor to be transformed with shape (H, W), (C, H, W), (B, C, H, W).
params (Dict[str, torch.Tensor]) –
params[‘hue_factor’]: How much to shift the hue channel. Should be in [-0.5, 0.5]. 0.5 and -0.5 give complete reversal of hue channel in HSV space in positive and negative direction respectively. 0 means no shift. Therefore, both -0.5 and 0.5 will give an image with complementary colors while 0 gives the original image.
- Returns
Adjusted image.
- Return type
torch.Tensor
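The claim that both -0.5 and 0.5 produce complementary colors can be checked per pixel with Python's standard `colorsys` module. This sketch assumes `hue_factor` is a fraction of a full turn of the hue circle; the helper `shift_hue_rgb` is hypothetical and is not how Kornia processes tensors.

```python
import colorsys


def shift_hue_rgb(rgb, hue_factor):
    """Hypothetical per-pixel sketch: rotate the HSV hue by hue_factor turns."""
    if not -0.5 <= hue_factor <= 0.5:
        raise ValueError("hue_factor must be in [-0.5, 0.5]")
    h, s, v = colorsys.rgb_to_hsv(*rgb)
    h = (h + hue_factor) % 1.0  # colorsys hue is periodic on [0, 1)
    return colorsys.hsv_to_rgb(h, s, v)


red = (1.0, 0.0, 0.0)
print(shift_hue_rgb(red, 0.5))   # complementary color (cyan)
print(shift_hue_rgb(red, -0.5))  # a half turn the other way: also cyan
```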
-
apply_adjust_saturation
(input: torch.Tensor, params: Dict[str, torch.Tensor]) → torch.Tensor[source]¶ Apply saturation adjustment.
Wrapper for adjust_saturation for Torchvision-like param settings.
- Parameters
input (torch.Tensor) – Tensor to be transformed with shape (H, W), (C, H, W), (B, C, H, W).
params (Dict[str, torch.Tensor]) –
params[‘saturation_factor’]: How much to adjust the saturation. 0 will give a black and white image, 1 will give the original image while 2 will enhance the saturation by a factor of 2.
- Returns
Adjusted image.
- Return type
torch.Tensor
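A per-pixel sketch of the saturation factor, assuming it scales the HSV saturation channel with clamping to [0, 1]; the helper is hypothetical and uses the standard `colorsys` module rather than Kornia.

```python
import colorsys


def adjust_saturation_rgb(rgb, saturation_factor):
    """Hypothetical per-pixel sketch: scale HSV saturation, clamp to [0, 1]."""
    h, s, v = colorsys.rgb_to_hsv(*rgb)
    s = min(max(s * saturation_factor, 0.0), 1.0)
    return colorsys.hsv_to_rgb(h, s, v)


orange = (1.0, 0.5, 0.0)
gray = adjust_saturation_rgb(orange, 0.0)  # factor 0: equal channels (gray)
same = adjust_saturation_rgb(orange, 1.0)  # factor 1: original color
```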
-
apply_affine
(input: torch.Tensor, params: Dict[str, torch.Tensor], flags: Dict[str, torch.Tensor]) → torch.Tensor[source]¶ Random affine transformation of the image keeping center invariant.
- Parameters
input (torch.Tensor) – Tensor to be transformed with shape (H, W), (C, H, W), (B, C, H, W).
params (Dict[str, torch.Tensor]) –
params[‘angle’]: Degrees of rotation.
params[‘translations’]: Horizontal and vertical translations.
params[‘center’]: Rotation center.
params[‘scale’]: Scaling params.
params[‘sx’]: Shear param toward x-axis.
params[‘sy’]: Shear param toward y-axis.
flags (Dict[str, torch.Tensor]) –
flags[‘resample’]: Integer tensor. NEAREST = 0, BILINEAR = 1.
flags[‘padding_mode’]: Integer tensor, see SamplePadding enum.
flags[‘align_corners’]: Boolean tensor.
- Returns
The transformed input
- Return type
torch.Tensor
-
apply_color_jitter
(input: torch.Tensor, params: Dict[str, torch.Tensor]) → torch.Tensor[source]¶ Apply Color Jitter on a tensor image or a batch of tensor images with given random parameters.
Input should be a tensor of shape (H, W), (C, H, W) or a batch of tensors \((B, C, H, W)\).
- Parameters
input (torch.Tensor) – Tensor to be transformed with shape (H, W), (C, H, W), (B, C, H, W).
params (Dict[str, torch.Tensor]) –
params[‘brightness_factor’]: The brightness factor.
params[‘contrast_factor’]: The contrast factor.
params[‘hue_factor’]: The hue factor.
params[‘saturation_factor’]: The saturation factor.
params[‘order’]: The order of applying color transforms. 0 is brightness, 1 is contrast, 2 is saturation, 3 is hue.
- Returns
The color jittered input
- Return type
torch.Tensor
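Because params[‘order’] is a permutation of the four transform indices and the transforms do not commute, the order changes the result. The sketch below shows this on a single gray value; the brightness and contrast steps (a +0.2 shift and a 0.5 scale) are hypothetical placeholders, and saturation and hue are identities here since a single gray value has no color.

```python
from typing import Callable, Dict, List


def brightness(v: float) -> float:
    return min(max(v + 0.2, 0.0), 1.0)  # hypothetical additive shift


def contrast(v: float) -> float:
    return min(max(v * 0.5, 0.0), 1.0)  # hypothetical multiplicative scale


def saturation(v: float) -> float:
    return v  # no-op placeholder for a single gray value


def hue(v: float) -> float:
    return v  # no-op placeholder for a single gray value


# Index convention from params['order']: 0 brightness, 1 contrast, 2 saturation, 3 hue.
TRANSFORMS: Dict[int, Callable[[float], float]] = {
    0: brightness, 1: contrast, 2: saturation, 3: hue,
}


def jitter(value: float, order: List[int]) -> float:
    """Apply the transforms one after another in the given order."""
    for idx in order:
        value = TRANSFORMS[idx](value)
    return value


a = jitter(0.5, [0, 1, 2, 3])  # brightness first: (0.5 + 0.2) * 0.5
b = jitter(0.5, [1, 0, 2, 3])  # contrast first: 0.5 * 0.5 + 0.2
```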
-
apply_crop
(input: torch.Tensor, params: Dict[str, torch.Tensor], flags: Dict[str, torch.Tensor]) → torch.Tensor[source]¶ Apply cropping by src bounding box and dst bounding box.
Order: top-left, top-right, bottom-right and bottom-left. The coordinates must be in the x, y order.
- Parameters
input (torch.Tensor) – Tensor to be transformed with shape (H, W), (C, H, W), (B, C, H, W).
params (Dict[str, torch.Tensor]) –
params[‘src’]: The applied cropping src matrix \((*, 4, 2)\).
params[‘dst’]: The applied cropping dst matrix \((*, 4, 2)\).
flags (Dict[str, torch.Tensor]) –
flags[‘interpolation’]: Integer tensor. NEAREST = 0, BILINEAR = 1.
flags[‘align_corners’]: Boolean tensor.
- Returns
The cropped input.
- Return type
torch.Tensor
-
apply_cutmix
(input: torch.Tensor, labels: torch.Tensor, params: Dict[str, torch.Tensor]) → Tuple[torch.Tensor, torch.Tensor][source]¶ Apply cutmix to images in a batch.
CutMix augmentation strategy: patches are cut and pasted among training images where the ground truth labels are also mixed proportionally to the area of the patches.
- Parameters
input (torch.Tensor) – Tensor to be transformed with shape (H, W), (C, H, W), (B, C, H, W).
labels (torch.Tensor) – Label tensor with shape (B,).
params (Dict[str, torch.Tensor]) –
params[‘mix_pairs’]: Indices of the images to paste from, with shape (num_mixes, B).
params[‘crop_src’]: Source crop boxes given by their four (x, y) corners, with shape (num_mixes, B, 4, 2).
- Returns
Adjusted image, shape of \((B, C, H, W)\).
Raw labels, permuted labels and lambdas for each mix, shape of \((\text{num\_mixes}, B, 3)\).
- Return type
Tuple[torch.Tensor, torch.Tensor]
Examples
>>> input = torch.stack([torch.zeros(1, 5, 5), torch.ones(1, 5, 5)], dim=0)
>>> labels = torch.tensor([0, 1])
>>> params = {'mix_pairs': torch.tensor([[1, 0]]), 'crop_src': torch.tensor([[[
...     [1., 1.],
...     [2., 1.],
...     [2., 2.],
...     [1., 2.]],
...    [[1., 1.],
...     [3., 1.],
...     [3., 2.],
...     [1., 2.]]]])}
>>> apply_cutmix(input, labels, params)
(tensor([[[[0., 0., 0., 0., 0.],
          [0., 1., 1., 0., 0.],
          [0., 1., 1., 0., 0.],
          [0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0.]]],
<BLANKLINE>
<BLANKLINE>
        [[[1., 1., 1., 1., 1.],
          [1., 0., 0., 0., 1.],
          [1., 0., 0., 0., 1.],
          [1., 1., 1., 1., 1.],
          [1., 1., 1., 1., 1.]]]]), tensor([[[0.0000, 1.0000, 0.1600],
         [1.0000, 0.0000, 0.2400]]]))
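The lambdas in the example above (0.16 and 0.24) equal the pasted box area divided by the image area. The sketch below reproduces them, assuming the corner coordinates are inclusive pixel indices, which is what the example output implies (corners 1..2 paste a 2x2 patch). The helper `cutmix_lambda` is hypothetical.

```python
from typing import List, Tuple

Corner = Tuple[float, float]


def cutmix_lambda(crop_src: List[Corner], height: int, width: int) -> float:
    """Fraction of the image covered by a pasted box given its four (x, y) corners.

    Corners follow the documented order (top-left, top-right, bottom-right,
    bottom-left); coordinates are treated as inclusive pixel indices.
    """
    xs = [x for x, _ in crop_src]
    ys = [y for _, y in crop_src]
    box_w = max(xs) - min(xs) + 1
    box_h = max(ys) - min(ys) + 1
    return (box_w * box_h) / (height * width)


# The two boxes from the apply_cutmix example on a 5x5 image.
lam0 = cutmix_lambda([(1., 1.), (2., 1.), (2., 2.), (1., 2.)], 5, 5)
lam1 = cutmix_lambda([(1., 1.), (3., 1.), (3., 2.), (1., 2.)], 5, 5)
```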
-
apply_equalize
(input: torch.Tensor, params: Dict[str, torch.Tensor]) → torch.Tensor[source]¶ Equalize an image.
- Parameters
input (torch.Tensor) – Tensor to be transformed with shape (H, W), (C, H, W), (B, C, H, W).
- Returns
Adjusted image.
- Return type
torch.Tensor
-
apply_erase_rectangles
(input: torch.Tensor, params: Dict[str, torch.Tensor]) → torch.Tensor[source]¶ Apply rectangle erase by params.
Generates a {0, 1} mask containing the rectangles defined by params, with the same size as input.size().
- Parameters
input (torch.Tensor) – Tensor to be transformed with shape (H, W), (C, H, W), (B, C, H, W).
params (Dict[str, torch.Tensor]) –
params[‘widths’]: widths tensor
params[‘heights’]: heights tensor
params[‘xs’]: x positions tensor
params[‘ys’]: y positions tensor
params[‘values’]: the value to fill in
- Returns
Erased image.
- Return type
torch.Tensor
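A minimal sketch of filling one rectangle on a nested-list image, assuming params[‘xs’] and params[‘ys’] give the top-left corner of each rectangle and widths/heights its extent; the helper `erase_rectangle` is hypothetical and stands in for the tensor mask logic.

```python
from typing import List

Image = List[List[float]]


def erase_rectangle(image: Image, x: int, y: int, w: int, h: int,
                    value: float = 0.0) -> Image:
    """Fill a w-by-h rectangle whose top-left pixel is (x, y) with `value`.

    Rows and columns falling outside the image are ignored. Returns a new image.
    """
    height, width = len(image), len(image[0])
    out = [row[:] for row in image]  # copy so the input stays untouched
    for r in range(max(y, 0), min(y + h, height)):
        for c in range(max(x, 0), min(x + w, width)):
            out[r][c] = value
    return out


img = [[1.0] * 4 for _ in range(4)]
erased = erase_rectangle(img, x=1, y=2, w=2, h=1)
```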
-
apply_grayscale
(input: torch.Tensor) → torch.Tensor[source]¶ Apply Gray Scale on a tensor image or a batch of tensor images with given random parameters.
Input should be a tensor of shape (3, H, W) or a batch of tensors \((*, 3, H, W)\).
- Parameters
input (torch.Tensor) – Tensor to be transformed with shape (3, H, W) or (*, 3, H, W).
- Returns
The grayscaled input
- Return type
torch.Tensor
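Grayscale conversion is a weighted sum of the three channels. The sketch below assumes the ITU-R BT.601 luma weights (0.299, 0.587, 0.114); treat the exact coefficients as an assumption rather than a statement about Kornia's implementation.

```python
def rgb_to_gray(r: float, g: float, b: float) -> float:
    """Weighted channel sum using BT.601 luma coefficients (assumed here)."""
    return 0.299 * r + 0.587 * g + 0.114 * b


white = rgb_to_gray(1.0, 1.0, 1.0)  # weights sum to 1, so white stays ~1.0
pure_red = rgb_to_gray(1.0, 0.0, 0.0)
```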
-
apply_hflip
(input: torch.Tensor) → torch.Tensor[source]¶ Apply a horizontal flip to a tensor image or a batch of tensor images with given random parameters.
Input should be a tensor of shape (H, W), (C, H, W) or a batch of tensors \((B, C, H, W)\).
- Parameters
input (torch.Tensor) – Tensor to be transformed with shape (H, W), (C, H, W), (B, C, H, W).
- Returns
The horizontally flipped input
- Return type
torch.Tensor
-
apply_mixup
(input: torch.Tensor, labels: torch.Tensor, params: Dict[str, torch.Tensor]) → Tuple[torch.Tensor, torch.Tensor][source]¶ Apply mixup to images in a batch.
MixUp augmentation strategy: blend each image with another image from the batch using a per-sample mixing weight (lambda).
- Parameters
input (torch.Tensor) – Tensor to be transformed with shape (H, W), (C, H, W), (B, C, H, W).
labels (torch.Tensor) – Label tensor with shape (B,).
params (Dict[str, torch.Tensor]) –
params[‘mixup_pairs’]: Indices of the images to mix with, with shape (B,).
params[‘mixup_lambdas’]: Lambda for the mixup strength, with shape (B,).
- Returns
Adjusted image, shape of \((B, C, H, W)\).
Raw labels, corresponding labels and lambdas for each mix, shape of \((B, 3)\).
- Return type
Tuple[torch.Tensor, torch.Tensor]
Examples
>>> input = torch.stack([torch.eye(5).unsqueeze(dim=0), torch.ones(5, 5).unsqueeze(dim=0)])
>>> labels = torch.tensor([0, 1])
>>> params = dict(mixup_pairs=torch.tensor([1, 0]), mixup_lambdas=torch.tensor([0.5, 0.9]))
>>> out_img, out_label = apply_mixup(input, labels, params)
>>> out_img
tensor([[[[1.0000, 0.5000, 0.5000, 0.5000, 0.5000],
          [0.5000, 1.0000, 0.5000, 0.5000, 0.5000],
          [0.5000, 0.5000, 1.0000, 0.5000, 0.5000],
          [0.5000, 0.5000, 0.5000, 1.0000, 0.5000],
          [0.5000, 0.5000, 0.5000, 0.5000, 1.0000]]],
<BLANKLINE>
<BLANKLINE>
        [[[1.0000, 0.1000, 0.1000, 0.1000, 0.1000],
          [0.1000, 1.0000, 0.1000, 0.1000, 0.1000],
          [0.1000, 0.1000, 1.0000, 0.1000, 0.1000],
          [0.1000, 0.1000, 0.1000, 1.0000, 0.1000],
          [0.1000, 0.1000, 0.1000, 0.1000, 1.0000]]]])
>>> out_label
tensor([[0.0000, 1.0000, 0.5000],
        [1.0000, 0.0000, 0.9000]])
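The example output above is consistent with the blend \((1 - \lambda) x_i + \lambda x_{\text{pair}(i)}\): the identity image mixed with the all-ones image at lambda 0.5 gives 0.5 off the diagonal, and the reverse at lambda 0.9 gives 0.1. The pure-Python sketch below reproduces this on flattened rows; both helpers are hypothetical stand-ins for the tensor code.

```python
from typing import List, Tuple


def mixup_pair(img_a: List[float], img_b: List[float], lam: float) -> List[float]:
    """Blend two flattened images: (1 - lam) * a + lam * b."""
    return [(1.0 - lam) * a + lam * b for a, b in zip(img_a, img_b)]


def mixup_labels(labels: List[int], pairs: List[int],
                 lambdas: List[float]) -> List[Tuple[float, float, float]]:
    """Rows of (raw label, paired label, lambda), matching the (B, 3) output."""
    return [(float(labels[i]), float(labels[pairs[i]]), lambdas[i])
            for i in range(len(labels))]


# First row of each 5x5 image from the example: eye -> [1, 0, ...], ones -> [1, 1, ...].
row0 = mixup_pair([1.0, 0.0], [1.0, 1.0], 0.5)  # image 0 paired with image 1
row1 = mixup_pair([1.0, 1.0], [1.0, 0.0], 0.9)  # image 1 paired with image 0
out_label = mixup_labels([0, 1], [1, 0], [0.5, 0.9])
```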
-
apply_motion_blur
(input: torch.Tensor, params: Dict[str, torch.Tensor], flags: Dict[str, torch.Tensor]) → torch.Tensor[source]¶ Perform motion blur on an image.
The input image is expected to be in the range of [0, 1].
- Parameters
input (torch.Tensor) – Tensor to be transformed with shape (H, W), (C, H, W), (B, C, H, W).
params (Dict[str, torch.Tensor]) –
params[‘ksize_factor’]: motion kernel width and height (odd and positive).
params[‘angle_factor’]: angle of the motion blur in degrees (anti-clockwise rotation).
params[‘direction_factor’]: forward/backward direction of the motion blur. Lower values towards -1.0 will point the motion blur towards the back (with angle provided via angle), while higher values towards 1.0 will point the motion blur forward. A value of 0.0 leads to a uniformly (but still angled) motion blur.
flags (Dict[str, torch.Tensor]) –
flags[‘border_type’]: the padding mode to be applied before convolving. CONSTANT = 0, REFLECT = 1, REPLICATE = 2, CIRCULAR = 3. Default: BorderType.CONSTANT.
- Returns
Adjusted image with the same shape as the input (*, C, H, W).
- Return type
torch.Tensor
-
apply_perspective
(input: torch.Tensor, params: Dict[str, torch.Tensor], flags: Dict[str, torch.Tensor]) → torch.Tensor[source]¶ Perform perspective transform of the given torch.Tensor or batch of tensors.
- Parameters
input (torch.Tensor) – Tensor to be transformed with shape (H, W), (C, H, W), (B, C, H, W).
params (Dict[str, torch.Tensor]) –
params[‘start_points’]: Tensor containing [top-left, top-right, bottom-right, bottom-left] of the original image with shape Bx4x2.
params[‘end_points’]: Tensor containing [top-left, top-right, bottom-right, bottom-left] of the transformed image with shape Bx4x2.
flags (Dict[str, torch.Tensor]) –
flags[‘interpolation’]: Integer tensor. NEAREST = 0, BILINEAR = 1.
flags[‘align_corners’]: Boolean tensor.
- Returns
Perspectively transformed tensor.
- Return type
torch.Tensor
-
apply_posterize
(input: torch.Tensor, params: Dict[str, torch.Tensor]) → torch.Tensor[source]¶ Posterize an image.
- Parameters
input (torch.Tensor) – Tensor to be transformed with shape (H, W), (C, H, W), (B, C, H, W).
params (Dict[str, torch.Tensor]) –
params[‘bits_factor’]: number of bits to keep for each uint8 value, ranging from 0 to 8.
- Returns
Adjusted image.
- Return type
torch.Tensor
-
apply_rotation
(input: torch.Tensor, params: Dict[str, torch.Tensor], flags: Dict[str, torch.Tensor]) → torch.Tensor[source]¶ Rotate a tensor image or a batch of tensor images a random amount of degrees.
Input should be a tensor of shape (C, H, W) or a batch of tensors \((B, C, H, W)\).
- Parameters
input (torch.Tensor) – Tensor to be transformed with shape (H, W), (C, H, W), (B, C, H, W).
params (Dict[str, torch.Tensor]) –
params[‘degrees’]: degree to be applied.
flags (Dict[str, torch.Tensor]) –
flags[‘interpolation’]: Integer tensor. NEAREST = 0, BILINEAR = 1.
flags[‘align_corners’]: Boolean tensor.
- Returns
The rotated input
- Return type
torch.Tensor
-
apply_sharpness
(input: torch.Tensor, params: Dict[str, torch.Tensor]) → torch.Tensor[source]¶ Sharpen an image.
- Parameters
input (torch.Tensor) – Tensor to be transformed with shape (H, W), (C, H, W), (B, C, H, W).
params (Dict[str, torch.Tensor]) –
params[‘sharpness_factor’]: Sharpness strength. Must be above 0.
- Returns
Adjusted image.
- Return type
torch.Tensor
-
apply_solarize
(input: torch.Tensor, params: Dict[str, torch.Tensor]) → torch.Tensor[source]¶ Solarize an image.
- Parameters
input (torch.Tensor) – Tensor to be transformed with shape (H, W), (C, H, W), (B, C, H, W).
params (Dict[str, torch.Tensor]) –
params[‘thresholds_factor’]: thresholds in the range [0, 1].
params[‘additions_factor’]: values added to the image before solarizing.
- Returns
Adjusted image.
- Return type
torch.Tensor
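Solarizing inverts pixels at or above a threshold. The scalar sketch below assumes the addition is applied first, the result is clamped to [0, 1], and values that are not below the threshold are replaced by `1 - value`; the helper is hypothetical.

```python
def solarize_pixel(value: float, threshold: float, addition: float = 0.0) -> float:
    """Hypothetical sketch: add, clamp to [0, 1], then invert values >= threshold."""
    value = min(max(value + addition, 0.0), 1.0)
    return value if value < threshold else 1.0 - value


kept = solarize_pixel(0.2, 0.5)               # below threshold: unchanged
inverted = solarize_pixel(0.8, 0.5)           # above threshold: 1 - 0.8
shifted = solarize_pixel(0.4, 0.5, addition=0.3)  # 0.7 after addition, then inverted
```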
-
apply_vflip
(input: torch.Tensor) → torch.Tensor[source]¶ Apply a vertical flip to a tensor image or a batch of tensor images with given random parameters.
Input should be a tensor of shape (H, W), (C, H, W) or a batch of tensors \((B, C, H, W)\).
- Parameters
input (torch.Tensor) – Tensor to be transformed with shape (H, W), (C, H, W), (B, C, H, W).
- Returns
The vertically flipped input
- Return type
torch.Tensor
-
color_jitter
(input: torch.Tensor, brightness: Union[torch.Tensor, float, Tuple[float, float], List[float]] = 0.0, contrast: Union[torch.Tensor, float, Tuple[float, float], List[float]] = 0.0, saturation: Union[torch.Tensor, float, Tuple[float, float], List[float]] = 0.0, hue: Union[torch.Tensor, float, Tuple[float, float], List[float]] = 0.0, return_transform: bool = False) → Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]][source]¶ Generate params and apply operation on input tensor. See
random_color_jitter_generator()
for details. See apply_color_jitter()
for details.
-
compute_affine_transformation
(input: torch.Tensor, params: Dict[str, torch.Tensor]) → torch.Tensor[source]¶ Compute the applied transformation matrix \((*, 3, 3)\).
- Parameters
input (torch.Tensor) – Tensor to be transformed with shape (H, W), (C, H, W), (B, C, H, W).
params (Dict[str, torch.Tensor]) –
params[‘angle’]: Degrees of rotation.
params[‘translations’]: Horizontal and vertical translations.
params[‘center’]: Rotation center.
params[‘scale’]: Scaling params.
params[‘sx’]: Shear param toward x-axis.
params[‘sy’]: Shear param toward y-axis.
- Returns
The applied transformation matrix \((*, 3, 3)\)
- Return type
torch.Tensor
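For intuition about the matrix returned here, the sketch below builds a 3x3 matrix that rotates and scales about `center` and then translates, with no shear. It assumes the OpenCV `getRotationMatrix2D` convention for the rotation part; Kornia's own composition (including shear) may differ, so treat this as an illustration only. Both helpers are hypothetical.

```python
import math
from typing import List, Tuple

Matrix = List[List[float]]


def affine_matrix(angle_deg: float, center: Tuple[float, float], scale: float,
                  translation: Tuple[float, float]) -> Matrix:
    """3x3 matrix rotating/scaling about `center`, then translating (no shear).

    The rotation part follows the OpenCV getRotationMatrix2D convention.
    """
    a = scale * math.cos(math.radians(angle_deg))
    b = scale * math.sin(math.radians(angle_deg))
    cx, cy = center
    tx, ty = translation
    return [
        [a, b, (1 - a) * cx - b * cy + tx],
        [-b, a, b * cx + (1 - a) * cy + ty],
        [0.0, 0.0, 1.0],
    ]


def apply_to_point(m: Matrix, x: float, y: float) -> Tuple[float, float]:
    """Map a point (x, y, 1) through the matrix."""
    return (m[0][0] * x + m[0][1] * y + m[0][2],
            m[1][0] * x + m[1][1] * y + m[1][2])


# The rotation center is a fixed point of the rotation/scale part.
m = affine_matrix(30.0, (2.0, 3.0), 1.5, (0.0, 0.0))
cx, cy = apply_to_point(m, 2.0, 3.0)
```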
-
compute_crop_transformation
(input: torch.Tensor, params: Dict[str, torch.Tensor], flags: Dict[str, torch.Tensor])[source]¶ Compute the applied transformation matrix \((*, 3, 3)\).
- Parameters
input (torch.Tensor) – Tensor to be transformed with shape (H, W), (C, H, W), (B, C, H, W).
params (Dict[str, torch.Tensor]) –
params[‘src’]: The applied cropping src matrix \((*, 4, 2)\).
params[‘dst’]: The applied cropping dst matrix \((*, 4, 2)\).
- Returns
The applied transformation matrix \((*, 3, 3)\)
- Return type
-
compute_hflip_transformation
(input: torch.Tensor) → torch.Tensor[source]¶ Compute the applied transformation matrix \((*, 3, 3)\).
- Parameters
input (torch.Tensor) – Tensor to be transformed with shape (H, W), (C, H, W), (B, C, H, W).
- Returns
The applied transformation matrix \((*, 3, 3)\)
- Return type
torch.Tensor
-
compute_intensity_transformation
(input: torch.Tensor)[source]¶ Compute the applied transformation matrix \((*, 3, 3)\).
- Parameters
input (torch.Tensor) – Tensor to be transformed with shape (H, W), (C, H, W), (B, C, H, W).
- Returns
The applied transformation matrix \((*, 3, 3)\), which is always the identity transformation.
- Return type
-
compute_perspective_transformation
(input: torch.Tensor, params: Dict[str, torch.Tensor]) → torch.Tensor[source]¶ Compute the applied transformation matrix \((*, 3, 3)\).
- Parameters
input (torch.Tensor) – Tensor to be transformed with shape (H, W), (C, H, W), (B, C, H, W).
params (Dict[str, torch.Tensor]) –
params[‘start_points’]: Tensor containing [top-left, top-right, bottom-right, bottom-left] of the original image with shape Bx4x2.
params[‘end_points’]: Tensor containing [top-left, top-right, bottom-right, bottom-left] of the transformed image with shape Bx4x2.
- Returns
The applied transformation matrix \((*, 3, 3)\)
- Return type
torch.Tensor
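The perspective matrix is the homography mapping the four start points onto the four end points. The pure-Python sketch below solves the standard 8-unknown linear system (fixing the bottom-right entry to 1) with Gaussian elimination. It illustrates the technique and is not Kornia's implementation; the function name is hypothetical.

```python
from typing import List, Sequence, Tuple

Point = Tuple[float, float]


def perspective_transform(src: Sequence[Point], dst: Sequence[Point]) -> List[List[float]]:
    """Solve for the 3x3 homography H (with H[2][2] = 1) mapping src[i] -> dst[i]."""
    # Two equations per correspondence give an 8x8 system A h = rhs.
    a: List[List[float]] = []
    rhs: List[float] = []
    for (x, y), (u, v) in zip(src, dst):
        a.append([x, y, 1.0, 0.0, 0.0, 0.0, -x * u, -y * u])
        rhs.append(u)
        a.append([0.0, 0.0, 0.0, x, y, 1.0, -x * v, -y * v])
        rhs.append(v)
    n = 8
    # Forward elimination with partial pivoting.
    for col in range(n):
        pivot = max(range(col, n), key=lambda r: abs(a[r][col]))
        if abs(a[pivot][col]) < 1e-12:
            raise ValueError("degenerate point configuration")
        a[col], a[pivot] = a[pivot], a[col]
        rhs[col], rhs[pivot] = rhs[pivot], rhs[col]
        for r in range(col + 1, n):
            factor = a[r][col] / a[col][col]
            for c in range(col, n):
                a[r][c] -= factor * a[col][c]
            rhs[r] -= factor * rhs[col]
    # Back substitution.
    h = [0.0] * n
    for r in range(n - 1, -1, -1):
        s = sum(a[r][c] * h[c] for c in range(r + 1, n))
        h[r] = (rhs[r] - s) / a[r][r]
    return [h[0:3], h[3:6], [h[6], h[7], 1.0]]


# Unit square (top-left, top-right, bottom-right, bottom-left) onto a trapezoid.
square = [(0.0, 0.0), (1.0, 0.0), (1.0, 1.0), (0.0, 1.0)]
trapezoid = [(0.0, 0.0), (2.0, 0.0), (1.5, 1.0), (0.5, 1.0)]
H = perspective_transform(square, trapezoid)
```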
-
compute_rotate_tranformation
(input: torch.Tensor, params: Dict[str, torch.Tensor]) → torch.Tensor[source]¶ Compute the applied transformation matrix \((*, 3, 3)\).
- Parameters
input (torch.Tensor) – Tensor to be transformed with shape (H, W), (C, H, W), (B, C, H, W).
params (Dict[str, torch.Tensor]) –
params[‘degrees’]: degree to be applied.
- Returns
The applied transformation matrix \((*, 3, 3)\)
- Return type
torch.Tensor
-
compute_vflip_transformation
(input: torch.Tensor) → torch.Tensor[source]¶ Compute the applied transformation matrix \((*, 3, 3)\).
- Parameters
input (torch.Tensor) – Tensor to be transformed with shape (H, W), (C, H, W), (B, C, H, W).
- Returns
The applied transformation matrix \((*, 3, 3)\)
- Return type
torch.Tensor
-
random_affine
(input: torch.Tensor, degrees: Union[float, Tuple[float, float]], translate: Optional[Tuple[float, float]] = None, scale: Optional[Tuple[float, float]] = None, shear: Optional[Union[T, Tuple[T, T]]] = None, resample: Union[str, int, <unknown>.Resample] = 'BILINEAR', return_transform: bool = False) → Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]][source]¶ Generate params and apply operation on input tensor. See
random_affine_generator()
for details. See apply_affine()
for details.
-
random_grayscale
(input: torch.Tensor, p: float = 0.5, return_transform: bool = False)[source]¶ Generate params and apply operation on input tensor. See
random_prob_generator()
for details. See apply_grayscale()
for details.
-
random_hflip
(input: torch.Tensor, p: float = 0.5, return_transform: bool = False) → Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]][source]¶ Generate params and apply operation on input tensor. See
random_prob_generator()
for details. See apply_hflip()
for details.
-
random_perspective
(input: torch.Tensor, distortion_scale: Union[torch.Tensor, float] = 0.5, p: float = 0.5, return_transform: bool = False) → Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]][source]¶ Generate params and apply operation on input tensor. See
random_perspective_generator()
for details. See apply_perspective()
for details.
-
random_rectangle_erase
(input: torch.Tensor, p: float = 0.5, scale: Tuple[float, float] = (0.02, 0.33), ratio: Tuple[float, float] = (0.3, 3.3), return_transform: bool = False) → Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]][source]¶ Function that erases a randomly selected rectangle in each image of the batch, setting its values to zero. The rectangle has an area equal to the original image area multiplied by a value uniformly sampled from [scale[0], scale[1]) and an aspect ratio sampled from [ratio[0], ratio[1]).
- Parameters
input (torch.Tensor) – input images.
scale (Tuple[float, float]) – range of the proportion of erased area against the input image.
ratio (Tuple[float, float]) – range of the aspect ratio of the erased area.
See
random_rectangles_params_generator()
for details. See apply_erase_rectangles()
for details.
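The area and aspect-ratio sampling described above can be sketched as follows. This assumes the torchvision-style convention that the aspect ratio is height / width, so \(h = \sqrt{A r}\) and \(w = \sqrt{A / r}\); the helper `sample_erase_size` is hypothetical and not part of Kornia.

```python
import math
import random
from typing import Optional, Tuple


def sample_erase_size(height: int, width: int,
                      scale: Tuple[float, float] = (0.02, 0.33),
                      ratio: Tuple[float, float] = (0.3, 3.3),
                      rng: Optional[random.Random] = None) -> Tuple[int, int]:
    """Sample (h, w) of an erase rectangle from an area fraction and aspect ratio.

    Assumes aspect ratio means h / w, so h * w ~ area and h / w ~ aspect.
    """
    rng = rng if rng is not None else random.Random()
    area = height * width * rng.uniform(scale[0], scale[1])
    aspect = rng.uniform(ratio[0], ratio[1])
    h = int(round(math.sqrt(area * aspect)))
    w = int(round(math.sqrt(area / aspect)))
    # Keep the rectangle at least one pixel and no larger than the image.
    return min(max(h, 1), height), min(max(w, 1), width)


# A quarter of a 100x100 image: square for ratio 1, tall for ratio 4.
print(sample_erase_size(100, 100, scale=(0.25, 0.25), ratio=(1.0, 1.0)))
print(sample_erase_size(100, 100, scale=(0.25, 0.25), ratio=(4.0, 4.0)))
```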
-
random_rotation
(input: torch.Tensor, degrees: Union[torch.Tensor, float, Tuple[float, float], List[float]], return_transform: bool = False) → Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]][source]¶ Generate params and apply operation on input tensor. See
random_rotation_generator()
for details. See apply_rotation()
for details.
Transforms3D¶
Set of operators to perform data augmentation on 3D volumetric tensors.
-
apply_affine3d
(input: torch.Tensor, params: Dict[str, torch.Tensor], flags: Dict[str, torch.Tensor]) → torch.Tensor[source]¶ Random affine transformation of the image keeping center invariant.
- Parameters
input (torch.Tensor) – Tensor to be transformed with shape (D, H, W), (C, D, H, W), (B, C, D, H, W).
params (Dict[str, torch.Tensor]) –
params[‘angles’]: Degrees of rotation with the shape of \((*, 3)\) for yaw, pitch, roll.
params[‘translations’]: Horizontal, vertical and depth translations (dx, dy, dz).
params[‘center’]: Rotation center (x,y,z).
params[‘scale’]: Isotropic scaling params.
params[‘sxy’]: Shear param toward x-y-axis.
params[‘sxz’]: Shear param toward x-z-axis.
params[‘syx’]: Shear param toward y-x-axis.
params[‘syz’]: Shear param toward y-z-axis.
params[‘szx’]: Shear param toward z-x-axis.
params[‘szy’]: Shear param toward z-y-axis.
flags (Dict[str, torch.Tensor]) –
flags[‘resample’]: Integer tensor. NEAREST = 0, BILINEAR = 1.
flags[‘align_corners’]: Boolean tensor.
- Returns
The transformed input
- Return type
torch.Tensor
-
apply_crop3d
(input: torch.Tensor, params: Dict[str, torch.Tensor], flags: Dict[str, torch.Tensor]) → torch.Tensor[source]¶ Apply cropping by src bounding box and dst bounding box.
Order: front-top-left, front-top-right, front-bottom-right, front-bottom-left, back-top-left, back-top-right, back-bottom-right, back-bottom-left. The coordinates must be in x, y, z order.
- Parameters
input (torch.Tensor) – Tensor to be transformed with shape (D, H, W), (C, D, H, W), (B, C, D, H, W).
params (Dict[str, torch.Tensor]) –
params[‘src’]: The applied cropping src matrix \((*, 8, 3)\).
params[‘dst’]: The applied cropping dst matrix \((*, 8, 3)\).
flags (Dict[str, torch.Tensor]) –
flags[‘interpolation’]: Integer tensor. NEAREST = 0, BILINEAR = 1.
flags[‘align_corners’]: Boolean tensor.
- Returns
The cropped input.
- Return type
torch.Tensor
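To make the eight-corner ordering concrete, the sketch below lists the corners of an axis-aligned box in the documented order. It assumes "front" means the smaller z and "top" the smaller y, which the documentation does not state explicitly; the helper `box3d_corners` is hypothetical.

```python
from typing import List, Tuple

Point3 = Tuple[float, float, float]


def box3d_corners(x1: float, y1: float, z1: float,
                  x2: float, y2: float, z2: float) -> List[Point3]:
    """Eight (x, y, z) corners in the documented order (front face, then back)."""
    return [
        (x1, y1, z1), (x2, y1, z1), (x2, y2, z1), (x1, y2, z1),  # front face
        (x1, y1, z2), (x2, y1, z2), (x2, y2, z2), (x1, y2, z2),  # back face
    ]


corners = box3d_corners(0.0, 0.0, 0.0, 2.0, 3.0, 4.0)
```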
-
apply_dflip3d
(input: torch.Tensor) → torch.Tensor[source]¶ Apply a depth-wise flip to a 3D tensor volume or a batch of tensor volumes with given random parameters.
Input should be a tensor of shape \((D, H, W)\), \((C, D, H, W)\) or \((*, C, D, H, W)\).
- Parameters
input (torch.Tensor) – Tensor to be transformed with shape \((D, H, W)\), \((C, D, H, W)\), \((*, C, D, H, W)\).
- Returns
The depth-flipped input.
- Return type
torch.Tensor
-
apply_equalize3d
(input: torch.Tensor, params: Dict[str, torch.Tensor]) → torch.Tensor[source]¶ Equalize a tensor volume or a batch of tensor volumes with given random parameters.
- Parameters
input (torch.Tensor) – Tensor to be transformed with shape \((D, H, W)\), \((C, D, H, W)\), \((*, C, D, H, W)\).
params (Dict[str, torch.Tensor]) – shall be empty.
- Returns
The equalized input. \((D, H, W)\), \((C, D, H, W)\), \((*, C, D, H, W)\).
- Return type
torch.Tensor
-
apply_hflip3d
(input: torch.Tensor) → torch.Tensor[source]¶ Apply a horizontal flip to a 3D tensor volume or a batch of tensor volumes with given random parameters.
Input should be a tensor of shape \((D, H, W)\), \((C, D, H, W)\) or \((*, C, D, H, W)\).
- Parameters
input (torch.Tensor) – Tensor to be transformed with shape \((D, H, W)\), \((C, D, H, W)\), \((*, C, D, H, W)\).
- Returns
The horizontally flipped input
- Return type
torch.Tensor
-
apply_motion_blur3d
(input: torch.Tensor, params: Dict[str, torch.Tensor], flags: Dict[str, torch.Tensor]) → torch.Tensor[source]¶ Perform motion blur on a volume.
The input volume is expected to be in the range of [0, 1].
- Parameters
input (torch.Tensor) – Tensor to be transformed with shape (D, H, W), (C, D, H, W), (B, C, D, H, W).
params (Dict[str, torch.Tensor]) –
params[‘ksize_factor’]: motion kernel width and height (odd and positive).
params[‘angle_factor’]: yaw, pitch and roll range of the motion blur in degrees \((B, 3)\).
params[‘direction_factor’]: forward/backward direction of the motion blur. Lower values towards -1.0 will point the motion blur towards the back (with angle provided via angle), while higher values towards 1.0 will point the motion blur forward. A value of 0.0 leads to a uniformly (but still angled) motion blur.
flags (Dict[str, torch.Tensor]) –
flags[‘border_type’]: the padding mode to be applied before convolving. CONSTANT = 0, REFLECT = 1, REPLICATE = 2, CIRCULAR = 3. Default: BorderType.CONSTANT.
- Returns
Adjusted volume with the same shape as the input (*, C, D, H, W).
- Return type
torch.Tensor
-
apply_perspective3d
(input: torch.Tensor, params: Dict[str, torch.Tensor], flags: Dict[str, torch.Tensor]) → torch.Tensor[source]¶ Perform perspective transform of the given torch.Tensor or batch of tensors.
- Parameters
input (torch.Tensor) – Tensor to be transformed with shape (D, H, W), (C, D, H, W), (B, C, D, H, W).
params (Dict[str, torch.Tensor]) –
params[‘start_points’]: Tensor containing the eight corners [front-top-left, front-top-right, front-bottom-right, front-bottom-left, back-top-left, back-top-right, back-bottom-right, back-bottom-left] of the original volume with shape Bx8x3.
params[‘end_points’]: Tensor containing the eight corners, in the same order, of the transformed volume with shape Bx8x3.
flags (Dict[str, torch.Tensor]) –
flags[‘interpolation’]: Integer tensor. NEAREST = 0, BILINEAR = 1.
flags[‘align_corners’]: Boolean tensor.
- Returns
Perspectively transformed tensor.
- Return type
torch.Tensor
-
apply_rotation3d
(input: torch.Tensor, params: Dict[str, torch.Tensor], flags: Dict[str, torch.Tensor]) → torch.Tensor[source]¶ Rotate a tensor volume or a batch of tensor volumes by the given number of degrees.
Input should be a tensor of shape \((C, D, H, W)\) or a batch of tensors \((B, C, D, H, W)\).
- Parameters
input (torch.Tensor) – Tensor to be transformed with shape (D, H, W), (C, D, H, W), (B, C, D, H, W).
params (Dict[str, torch.Tensor]) –
params[‘degrees’]: degree to be applied.
flags (Dict[str, torch.Tensor]) –
flags[‘resample’]: Integer tensor. NEAREST = 0, BILINEAR = 1.
flags[‘align_corners’]: Boolean tensor.
- Returns
The rotated input.
- Return type
torch.Tensor
-
apply_vflip3d
(input: torch.Tensor) → torch.Tensor[source]¶ Apply a vertical flip to a 3D tensor volume or a batch of tensor volumes.
Input should be a tensor of shape \((D, H, W)\), \((C, D, H, W)\) or \((*, C, D, H, W)\).
- Parameters
input (torch.Tensor) – Tensor to be transformed with shape \((D, H, W)\), \((C, D, H, W)\), \((*, C, D, H, W)\).
- Returns
The vertically flipped input.
- Return type
torch.Tensor
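A vertical flip reverses the H axis and leaves D and W untouched; on a tensor the same operation is `torch.flip(input, dims=(-2,))`. The pure-Python sketch below shows the index reversal on nested lists so the axis bookkeeping is visible; `vflip3d` is a hypothetical name used only for illustration.

```python
def vflip3d(volume):
    """Reverse the H axis of a (D, H, W), (C, D, H, W) or (B, C, D, H, W) nested list."""
    if not isinstance(volume[0][0][0], list):
        # (D, H, W): each depth slice is a list of H rows; reverse the rows.
        return [[row[:] for row in reversed(slice_)] for slice_ in volume]
    # Leading channel/batch dimensions: recurse until we reach (D, H, W).
    return [vflip3d(v) for v in volume]


vol = [[[1, 2], [3, 4], [5, 6]]]  # D=1, H=3, W=2
print(vflip3d(vol))  # [[[5, 6], [3, 4], [1, 2]]]
```

Because the flip is its own inverse, applying it twice returns the original volume.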
-
compute_affine_transformation3d
(input: torch.Tensor, params: Dict[str, torch.Tensor]) → torch.Tensor[source]¶ Compute the applied transformation matrix \((*, 4, 4)\).
- Parameters
input (torch.Tensor) – Tensor to be transformed with shape (D, H, W), (C, D, H, W), (B, C, D, H, W).
params (Dict[str, torch.Tensor]) –
params[‘angles’]: Degrees of rotation with shape \((*, 3)\) for yaw, pitch and roll.
params[‘translations’]: Horizontal, vertical and depth translations (dx, dy, dz).
params[‘center’]: Rotation center (x, y, z).
params[‘scale’]: Isotropic scaling params.
params[‘sxy’]: Shear param toward x-y-axis.
params[‘sxz’]: Shear param toward x-z-axis.
params[‘syx’]: Shear param toward y-x-axis.
params[‘syz’]: Shear param toward y-z-axis.
params[‘szx’]: Shear param toward z-x-axis.
params[‘szy’]: Shear param toward z-y-axis.
- Returns
The applied transformation matrix \((*, 4, 4)\).
- Return type
torch.Tensor
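One common way such parameters combine is a centered affine transform: translate the center to the origin, apply a 3x3 linear part (rotation, scale, shear), translate back, then add the translation. The exact factor order and shear convention Kornia uses are not stated here, so this pure-Python sketch is an assumption; it takes any 3x3 linear part and, for brevity, demonstrates only an isotropic scale. All function names are hypothetical.

```python
def matmul4(a, b):
    """Multiply two 4x4 matrices stored as nested lists."""
    return [[sum(a[i][k] * b[k][j] for k in range(4)) for j in range(4)] for i in range(4)]


def translation4(tx, ty, tz):
    """4x4 homogeneous translation matrix."""
    return [[1.0, 0.0, 0.0, tx],
            [0.0, 1.0, 0.0, ty],
            [0.0, 0.0, 1.0, tz],
            [0.0, 0.0, 0.0, 1.0]]


def centered_affine4(linear, center, translation):
    """Return T(center + translation) @ L @ T(-center) as a 4x4 matrix.

    linear: 3x3 matrix (rotation, scale and/or shear already combined).
    """
    cx, cy, cz = center
    dx, dy, dz = translation
    # Embed the 3x3 linear part in a 4x4 homogeneous matrix.
    lin4 = [list(map(float, linear[i])) + [0.0] for i in range(3)] + [[0.0, 0.0, 0.0, 1.0]]
    to_origin = translation4(-cx, -cy, -cz)
    back = translation4(cx + dx, cy + dy, cz + dz)
    return matmul4(back, matmul4(lin4, to_origin))


def apply4(m, point):
    """Apply a 4x4 affine matrix to an (x, y, z) point."""
    p = (*point, 1.0)
    return tuple(sum(m[i][k] * p[k] for k in range(4)) for i in range(3))


# Isotropic scale of 2 about center (1, 2, 3), then shift by (0.5, 0, -1).
scale2 = [[2, 0, 0], [0, 2, 0], [0, 0, 2]]
m = centered_affine4(scale2, center=(1.0, 2.0, 3.0), translation=(0.5, 0.0, -1.0))
print(apply4(m, (1.0, 2.0, 3.0)))  # the center moves only by the translation
```

Conjugating by the center translation is what makes the center a fixed point of the linear part, so rotations and scalings happen "about" it.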
-
compute_crop_transformation3d
(input: torch.Tensor, params: Dict[str, torch.Tensor], flags: Dict[str, torch.Tensor])[source]¶ Compute the applied transformation matrix \((*, 4, 4)\).
- Parameters
input (torch.Tensor) – Tensor to be transformed with shape (D, H, W), (C, D, H, W), (B, C, D, H, W).
params (Dict[str, torch.Tensor]) –
params[‘src’]: The applied cropping src matrix \((*, 8, 3)\).
params[‘dst’]: The applied cropping dst matrix \((*, 8, 3)\).
- Returns
The applied transformation matrix \((*, 4, 4)\).
- Return type
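For an axis-aligned crop (optionally resized), the `src` and `dst` corner sets reduce to an independent scale and translation per axis; a general, non-box pair of corner sets would need a full perspective solve instead. The pure-Python sketch below assumes axis-aligned boxes; `crop_matrix3d` is a hypothetical helper, and because it uses per-axis min/max it does not depend on corner ordering.

```python
def crop_matrix3d(src, dst):
    """Hypothetical helper: 4x4 matrix mapping an axis-aligned src box onto a dst box.

    src, dst: sequences of 8 (x, y, z) corner points. Each axis gets an
    independent scale (dst extent / src extent) plus a translation.
    """
    m = [[0.0] * 4 for _ in range(4)]
    m[3][3] = 1.0
    for axis in range(3):
        s_lo = min(p[axis] for p in src)
        s_hi = max(p[axis] for p in src)
        d_lo = min(p[axis] for p in dst)
        d_hi = max(p[axis] for p in dst)
        scale = (d_hi - d_lo) / (s_hi - s_lo)
        m[axis][axis] = scale
        # Chosen so that s_lo maps exactly onto d_lo.
        m[axis][3] = d_lo - scale * s_lo
    return m


# Crop x in [2, 5], y in [1, 3], z in [0, 2] into an output box starting at the origin.
src = [(x, y, z) for x in (2, 5) for y in (1, 3) for z in (0, 2)]
dst = [(x, y, z) for x in (0, 3) for y in (0, 2) for z in (0, 2)]
print(crop_matrix3d(src, dst))  # pure translation by (-2, -1, 0)
```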
-
compute_dflip_transformation3d
(input: torch.Tensor) → torch.Tensor[source]¶ Compute the applied transformation matrix \((*, 4, 4)\).
- Parameters
input (torch.Tensor) – Tensor to be transformed with shape \((D, H, W)\), \((C, D, H, W)\), \((*, C, D, H, W)\).
- Returns
The applied transformation matrix \((*, 4, 4)\).
- Return type
torch.Tensor
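The depth, horizontal and vertical flip matrices share one shape: the identity with a -1 on the flipped axis and an offset that maps voxel index i to size - 1 - i. The sketch below assumes voxel-index coordinates ordered (x, y, z) = (W, H, D); whether Kornia uses size - 1 or size as the offset should be checked against its source. `flip_matrix3d` is a hypothetical helper.

```python
def flip_matrix3d(axis: str, size: int):
    """Hypothetical helper: 4x4 matrix mirroring voxel index coordinates along one axis.

    axis: 'x' (width, hflip), 'y' (height, vflip) or 'z' (depth, dflip).
    size: number of voxels along that axis; index i maps to size - 1 - i.
    """
    idx = {"x": 0, "y": 1, "z": 2}[axis]
    m = [[1.0 if i == j else 0.0 for j in range(4)] for i in range(4)]
    m[idx][idx] = -1.0
    m[idx][3] = float(size - 1)
    return m


# Depth flip for a volume with D = 4 slices: slice 0 maps to slice 3.
m = flip_matrix3d("z", 4)
p = (1.0, 2.0, 0.0, 1.0)
print([sum(m[i][k] * p[k] for k in range(4)) for i in range(4)])
```

Each flip matrix is its own inverse, matching the fact that flipping twice restores the input.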
-
compute_hflip_transformation3d
(input: torch.Tensor) → torch.Tensor[source]¶ Compute the applied transformation matrix \((*, 4, 4)\).
- Parameters
input (torch.Tensor) – Tensor to be transformed with shape \((D, H, W)\), \((C, D, H, W)\), \((*, C, D, H, W)\).
- Returns
The applied transformation matrix \((*, 4, 4)\).
- Return type
torch.Tensor
-
compute_intensity_transformation3d
(input: torch.Tensor)[source]¶ Compute the applied transformation matrix \((*, 4, 4)\).
- Parameters
input (torch.Tensor) – Tensor to be transformed with shape (D, H, W), (C, D, H, W), (B, C, D, H, W).
- Returns
The applied transformation matrix \((*, 4, 4)\). Intensity transforms do not move voxels, so this is always the identity.
- Return type
-
compute_perspective_transformation3d
(input: torch.Tensor, params: Dict[str, torch.Tensor]) → torch.Tensor[source]¶ Compute the applied transformation matrix \((*, 4, 4)\).
- Parameters
input (torch.Tensor) – Tensor to be transformed with shape (D, H, W), (C, D, H, W), (B, C, D, H, W).
params (Dict[str, torch.Tensor]) –
params[‘start_points’]: Tensor containing the eight corner points of the original volume, with shape Bx8x3.
params[‘end_points’]: Tensor containing the eight corner points of the transformed volume, with shape Bx8x3.
- Returns
The applied transformation matrix \((*, 4, 4)\).
- Return type
torch.Tensor
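A 3D projective matrix with its bottom-right entry fixed to 1 has 15 unknowns, and each of the 8 corner correspondences contributes 3 linear equations, so the 24-equation system is overdetermined. The pure-Python sketch below solves it by least squares through the normal equations (a direct linear transform). Kornia's internal solver is not shown on this page, so this is an assumed technique, and all function names are hypothetical.

```python
def solve_linear(a, b):
    """Solve the square system a @ x = b by Gaussian elimination with partial pivoting."""
    n = len(a)
    m = [list(row) + [b[i]] for i, row in enumerate(a)]  # augmented matrix
    for col in range(n):
        pivot = max(range(col, n), key=lambda r: abs(m[r][col]))
        if abs(m[pivot][col]) < 1e-12:
            raise ValueError("singular system: points not in general position")
        m[col], m[pivot] = m[pivot], m[col]
        for r in range(col + 1, n):
            f = m[r][col] / m[col][col]
            for c in range(col, n + 1):
                m[r][c] -= f * m[col][c]
    x = [0.0] * n
    for r in range(n - 1, -1, -1):
        x[r] = (m[r][n] - sum(m[r][c] * x[c] for c in range(r + 1, n))) / m[r][r]
    return x


def perspective_transform3d(src, dst):
    """Least-squares 4x4 projective matrix (with h44 = 1) mapping src points to dst points."""
    rows, rhs = [], []
    for (x, y, z), target in zip(src, dst):
        for axis in range(3):
            u = target[axis]
            # u * (h41 x + h42 y + h43 z + 1) = h_axis . (x, y, z, 1), linear in the 15 unknowns.
            row = [0.0] * 15
            row[4 * axis:4 * axis + 4] = [x, y, z, 1.0]
            row[12:15] = [-u * x, -u * y, -u * z]
            rows.append(row)
            rhs.append(u)
    # Normal equations (A^T A) h = A^T b; 8 points give 24 equations for 15 unknowns.
    ata = [[sum(r[i] * r[j] for r in rows) for j in range(15)] for i in range(15)]
    atb = [sum(r[i] * v for r, v in zip(rows, rhs)) for i in range(15)]
    h = solve_linear(ata, atb) + [1.0]
    return [h[0:4], h[4:8], h[8:12], h[12:16]]


def apply_projective3d(m, point):
    """Map an (x, y, z) point through a 4x4 matrix, including the perspective divide."""
    p = (*point, 1.0)
    v = [sum(m[i][k] * p[k] for k in range(4)) for i in range(4)]
    return tuple(v[i] / v[3] for i in range(3))
```

The eight corners of a box contain five points with no four coplanar, which is enough to pin down the transform uniquely; degenerate (e.g. flattened) corner sets make the system singular.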
-
compute_rotate_tranformation3d
(input: torch.Tensor, params: Dict[str, torch.Tensor])[source]¶ Compute the applied transformation matrix \((*, 4, 4)\).
- Parameters
input (torch.Tensor) – Tensor to be transformed with shape (D, H, W), (C, D, H, W), (B, C, D, H, W).
params (Dict[str, torch.Tensor]) –
params[‘yaw’]: yaw angle, in degrees, to be applied.
params[‘pitch’]: pitch angle, in degrees, to be applied.
params[‘roll’]: roll angle, in degrees, to be applied.
- Returns
The applied transformation matrix \((*, 4, 4)\).
- Return type
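The yaw, pitch and roll angles combine into a single 3x3 rotation, which then forms the upper-left block of the \((*, 4, 4)\) matrix. This page does not state the axis convention or composition order, so the pure-Python sketch below assumes yaw about z, pitch about y and roll about x, composed as Rz·Ry·Rx; `rotation_matrix3d` is a hypothetical helper.

```python
import math


def rotation_matrix3d(yaw: float, pitch: float, roll: float):
    """3x3 rotation R = Rz(yaw) @ Ry(pitch) @ Rx(roll); angles in degrees (assumed convention)."""
    a, b, c = (math.radians(v) for v in (yaw, pitch, roll))
    rz = [[math.cos(a), -math.sin(a), 0.0], [math.sin(a), math.cos(a), 0.0], [0.0, 0.0, 1.0]]
    ry = [[math.cos(b), 0.0, math.sin(b)], [0.0, 1.0, 0.0], [-math.sin(b), 0.0, math.cos(b)]]
    rx = [[1.0, 0.0, 0.0], [0.0, math.cos(c), -math.sin(c)], [0.0, math.sin(c), math.cos(c)]]

    def mul(p, q):
        # 3x3 matrix product on nested lists.
        return [[sum(p[i][k] * q[k][j] for k in range(3)) for j in range(3)] for i in range(3)]

    return mul(rz, mul(ry, rx))


# A 90-degree yaw should send the x axis onto the y axis (first column ~ (0, 1, 0)).
r = rotation_matrix3d(90.0, 0.0, 0.0)
print([round(r[i][0], 6) for i in range(3)])
```

Whatever the convention, the result is orthonormal with determinant 1, so its inverse is simply its transpose.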
-
compute_vflip_transformation3d
(input: torch.Tensor) → torch.Tensor[source]¶ Compute the applied transformation matrix \((*, 4, 4)\).
- Parameters
input (torch.Tensor) – Tensor to be transformed with shape \((D, H, W)\), \((C, D, H, W)\), \((*, C, D, H, W)\).
- Returns
The applied transformation matrix \((*, 4, 4)\).
- Return type
torch.Tensor
-
random_dflip3d
(input: torch.Tensor, p: float = 0.5, return_transform: bool = False) → Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]][source]¶ Generate parameters and apply the operation to the input tensor. See random_prob_generator() and apply_dflip3d() for details.
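The two steps described above, a per-sample random draw followed by a deterministic depth flip, can be sketched in pure Python. This is a hypothetical illustration, not Kornia's code: the Bernoulli draw stands in for random_prob_generator(), whose exact signature is not shown here, the flip reverses the D axis (equivalent to `torch.flip(input, dims=(-3,))` on a tensor), and `return_transform` is omitted.

```python
import random


def random_dflip3d_sketch(batch, p=0.5, rng=None):
    """Hypothetical sketch: flip each (C, D, H, W) sample along depth with probability p.

    batch: list of samples; each sample is a list of C channels, each a list of D slices.
    Returns (flipped_batch, mask) where mask[i] says whether sample i was flipped.
    """
    rng = rng or random.Random()
    # Per-sample Bernoulli(p) draw: rng.random() is in [0, 1), so p=0 never flips, p=1 always flips.
    mask = [rng.random() < p for _ in batch]
    out = []
    for sample, flip in zip(batch, mask):
        if flip:
            # Reverse the order of depth slices in every channel.
            out.append([list(reversed(channel)) for channel in sample])
        else:
            out.append([list(channel) for channel in sample])
    return out, mask


batch = [[[[[0]], [[1]], [[2]]]]]  # B=1, C=1, D=3, H=W=1
flipped, mask = random_dflip3d_sketch(batch, p=1.0)
print(mask)     # [True]
print(flipped)  # [[[[[2]], [[1]], [[0]]]]]
```

Separating parameter generation from application is what lets the same random draw be replayed, for example on a segmentation mask paired with the volume.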
Normalizations¶
Normalization operations are shape-agnostic for both 2D and 3D tensors.