Source code for kornia.enhance.core

import torch
import torch.nn as nn

__all__ = ["add_weighted", "AddWeighted"]


[docs]def add_weighted(src1: torch.Tensor, alpha: float, src2: torch.Tensor, beta: float, gamma: float) -> torch.Tensor: r"""Calculate the weighted sum of two Tensors. .. image:: _static/img/add_weighted.png The function calculates the weighted sum of two Tensors as follows: .. math:: out = src1 * alpha + src2 * beta + gamma Args: src1: Tensor of shape :math:`(*, H, W)`. alpha: weight of the src1 elements. src2: Tensor of same size and channel number as src1 :math:`(*, H, W)`. beta: weight of the src2 elements. gamma: scalar added to each sum. Returns: Weighted Tensor of shape :math:`(B, C, H, W)`. Example: >>> input1 = torch.rand(1, 1, 5, 5) >>> input2 = torch.rand(1, 1, 5, 5) >>> output = add_weighted(input1, 0.5, input2, 0.5, 1.0) >>> output.shape torch.Size([1, 1, 5, 5]) """ if not isinstance(src1, torch.Tensor): raise TypeError(f"src1 should be a tensor. Got {type(src1)}") if not isinstance(src2, torch.Tensor): raise TypeError(f"src2 should be a tensor. Got {type(src2)}") if not isinstance(alpha, float): raise TypeError(f"alpha should be a float. Got {type(alpha)}") if not isinstance(beta, float): raise TypeError(f"beta should be a float. Got {type(beta)}") if not isinstance(gamma, float): raise TypeError(f"gamma should be a float. Got {type(gamma)}") return src1 * alpha + src2 * beta + gamma
[docs]class AddWeighted(nn.Module): r"""Calculate the weighted sum of two Tensors. The function calculates the weighted sum of two Tensors as follows: .. math:: out = src1 * alpha + src2 * beta + gamma Args: alpha: weight of the src1 elements. beta: weight of the src2 elements. gamma: scalar added to each sum. Shape: - Input1: Tensor of shape :math:`(B, C, H, W)`. - Input2: Tensor of shape :math:`(B, C, H, W)`. - Output: Weighted tensor of shape :math:`(B, C, H, W)`. Example: >>> input1 = torch.rand(1, 1, 5, 5) >>> input2 = torch.rand(1, 1, 5, 5) >>> output = AddWeighted(0.5, 0.5, 1.0)(input1, input2) >>> output.shape torch.Size([1, 1, 5, 5]) """ def __init__(self, alpha: float, beta: float, gamma: float) -> None: super().__init__() self.alpha = alpha self.beta = beta self.gamma = gamma def forward(self, src1: torch.Tensor, src2: torch.Tensor) -> torch.Tensor: # type: ignore return add_weighted(src1, self.alpha, src2, self.beta, self.gamma)