Source code for kornia.augmentation._2d.intensity.random_snow

from typing import Any, Dict, Optional, Tuple

import torch

from kornia.augmentation import random_generator as rg
from kornia.augmentation._2d.intensity.base import IntensityAugmentationBase2D
from kornia.color import hls_to_rgb, rgb_to_hls
from kornia.core import Tensor
from kornia.core.check import KORNIA_CHECK


class RandomSnow(IntensityAugmentationBase2D):
    r"""Generate a snow effect on a tensor image or a batch of tensor images.

    The image is converted to HLS; every pixel whose lightness is below the sampled
    ``snow_coefficient`` has its lightness multiplied by the sampled ``brightness``
    (clamped to :math:`[0, 1]`), and the result is converted back to RGB.

    Args:
        snow_coefficient: A tuple of floats (lower and upper bound) between 0 and 1 that
            control the amount of snow to add to the image, the larger value corresponds
            to the more snow.
        brightness: A tuple of floats (lower and upper bound) greater than or equal to 1
            that controls the brightness of the snow.
        same_on_batch: If True, apply the same transformation to each image in a batch.
            Default: False.
        p: Probability of applying the transformation. Default: 1.0.
        keepdim: Keep the output tensor with the same shape as input. Default: False.

    Shape:
        - Input: :math:`(C, H, W)` or :math:`(B, C, H, W)`
        - Output: :math:`(B, C, H, W)`

    Examples:
        >>> inputs = torch.rand(2, 3, 4, 4)
        >>> snow = kornia.augmentation.RandomSnow(p=1.0, snow_coefficient=(0.1, 0.6), brightness=(1.0, 5.0))
        >>> output = snow(inputs)
        >>> output.shape
        torch.Size([2, 3, 4, 4])
    """

    def __init__(
        self,
        snow_coefficient: Tuple[float, float] = (0.5, 0.5),
        brightness: Tuple[float, float] = (2, 2),
        same_on_batch: bool = False,
        p: float = 1.0,
        keepdim: bool = False,
    ) -> None:
        super().__init__(p=p, same_on_batch=same_on_batch, keepdim=keepdim)

        # Validate the sampling bounds eagerly so a bad configuration fails at construction time.
        KORNIA_CHECK(all(0 <= el <= 1 for el in snow_coefficient), "Snow coefficient values must be between 0 and 1.")
        KORNIA_CHECK(all(1 <= el for el in brightness), "Brightness values must be greater than 1.")

        self._param_generator = rg.PlainUniformGenerator(
            (snow_coefficient, "snow_coefficient", 0.5, (0.0, 1.0)), (brightness, "brightness", None, None)
        )

    def apply_transform(
        self, input: Tensor, params: Dict[str, Tensor], flags: Dict[str, Any], transform: Optional[Tensor] = None
    ) -> Tensor:
        """Brighten the dark regions of an RGB image to imitate snow.

        Args:
            input: RGB image tensor of shape (C, H, W) or (B, C, H, W) with C == 3.
            params: Sampled parameters; reads the per-sample ``"snow_coefficient"`` and
                ``"brightness"`` tensors.
            flags: Unused by this transform.
            transform: Unused by this transform.

        Returns:
            The transformed RGB tensor of shape (B, C, H, W).
        """
        # Validate the rank first: the channel axis can only be located once the layout is known.
        KORNIA_CHECK(len(input.shape) in (3, 4), "Wrong input dimension.")
        if len(input.shape) == 3:
            input = input[None, :, :, :]
        # After the optional batch axis is added, channels always live on dim 1.
        KORNIA_CHECK(input.shape[1] == 3, "Number of color channels should be 3.")

        input_hls = rgb_to_hls(input)

        # Retrieve generated parameters, matched to the input's device/dtype and made broadcastable.
        snow_coefficient = params["snow_coefficient"].to(input)[:, None, None, None]
        brightness = params["brightness"].to(input)[:, None, None, None]

        # The mask is non-zero only on the lightness channel (index 1), and only where the
        # pixel is darker than the snow threshold; hue and saturation are never modified.
        mask = torch.zeros_like(input_hls)
        is_snow = input_hls[:, 1, :, :] < snow_coefficient[:, 0, :, :]
        mask[:, 1, :, :] = is_snow.to(mask.dtype)

        # Increase Light channel of the image by given brightness for areas based on snow coefficient.
        new_light = (input_hls * mask * brightness).clamp(min=0.0, max=1.0)
        input_hls = input_hls * (1 - mask) + new_light

        output = hls_to_rgb(input_hls)
        return output