# Source code for kornia.losses.depth_smooth

from __future__ import annotations

import torch
import torch.nn as nn

# Based on
# https://github.com/tensorflow/models/blob/master/research/struct2depth/model.py#L625-L641

def _gradient_x(img: torch.Tensor) -> torch.Tensor:
    """Return the difference between horizontally adjacent pixels.

    Each output element is ``img[..., w] - img[..., w + 1]``, so the last
    axis (width in the BxCxHxW layout used by this module) shrinks by one.

    Args:
        img: 4-dimensional tensor.

    Return:
        tensor with the same leading three dimensions as ``img`` and a last
        dimension one element shorter.

    Raises:
        AssertionError: if ``img`` is not 4-dimensional. The offending shape
            is passed as the exception argument.
    """
    if img.dim() != 4:
        raise AssertionError(img.shape)
    # Current-minus-next ordering; the loss in this module only consumes the
    # absolute value of the result, so the sign is irrelevant there.
    return img[:, :, :, :-1] - img[:, :, :, 1:]

def _gradient_y(img: torch.Tensor) -> torch.Tensor:
    """Return the difference between vertically adjacent pixels.

    Each output element is ``img[..., h, :] - img[..., h + 1, :]``, so the
    second-to-last axis (height in the BxCxHxW layout used by this module)
    shrinks by one.

    Args:
        img: 4-dimensional tensor.

    Return:
        tensor shaped like ``img`` except that the second-to-last dimension
        is one element shorter.

    Raises:
        AssertionError: if ``img`` is not 4-dimensional. The offending shape
            is passed as the exception argument.
    """
    if img.dim() != 4:
        raise AssertionError(img.shape)
    # Current-minus-next ordering; the loss in this module only consumes the
    # absolute value of the result, so the sign is irrelevant there.
    return img[:, :, :-1, :] - img[:, :, 1:, :]

def inverse_depth_smoothness_loss(idepth: torch.Tensor, image: torch.Tensor) -> torch.Tensor:
    r"""Criterion that computes image-aware inverse depth smoothness loss.

    .. math::

        \text{loss} = \left | \partial_x d_{ij} \right | e^{-\left \|
        \partial_x I_{ij} \right \|} + \left |
        \partial_y d_{ij} \right | e^{-\left \| \partial_y I_{ij} \right \|}

    Depth gradients are penalised less where the image itself has strong
    gradients, so depth discontinuities are tolerated at image edges.

    Args:
        idepth: tensor with the inverse depth with shape :math:`(N, 1, H, W)`.
        image: tensor with the input image with shape :math:`(N, 3, H, W)`.

    Return:
        a scalar with the computed loss.

    Raises:
        TypeError: if ``idepth`` or ``image`` is not a ``torch.Tensor``.
        ValueError: if either input is not 4-dimensional, or if the two
            inputs differ in spatial size (last two dimensions), device or
            dtype.

    Examples:
        >>> idepth = torch.rand(1, 1, 4, 5)
        >>> image = torch.rand(1, 3, 4, 5)
        >>> loss = inverse_depth_smoothness_loss(idepth, image)
    """
    if not isinstance(idepth, torch.Tensor):
        raise TypeError(f"Input idepth type is not a torch.Tensor. Got {type(idepth)}")

    if not isinstance(image, torch.Tensor):
        raise TypeError(f"Input image type is not a torch.Tensor. Got {type(image)}")

    if idepth.dim() != 4:
        raise ValueError(f"Invalid idepth shape, we expect BxCxHxW. Got: {idepth.shape}")

    if image.dim() != 4:
        raise ValueError(f"Invalid image shape, we expect BxCxHxW. Got: {image.shape}")

    # Only the spatial dimensions have to agree; channel counts differ by design.
    if idepth.shape[-2:] != image.shape[-2:]:
        raise ValueError(f"idepth and image shapes must be the same. Got: {idepth.shape} and {image.shape}")

    if idepth.device != image.device:
        raise ValueError(f"idepth and image must be in the same device. Got: {idepth.device} and {image.device}")

    if idepth.dtype != image.dtype:
        raise ValueError(f"idepth and image must be in the same dtype. Got: {idepth.dtype} and {image.dtype}")

    # compute the gradients as differences between neighbouring pixels along
    # width (x) and height (y); done inline so this function is self-contained.
    idepth_dx: torch.Tensor = idepth[:, :, :, :-1] - idepth[:, :, :, 1:]
    idepth_dy: torch.Tensor = idepth[:, :, :-1, :] - idepth[:, :, 1:, :]
    image_dx: torch.Tensor = image[:, :, :, :-1] - image[:, :, :, 1:]
    image_dy: torch.Tensor = image[:, :, :-1, :] - image[:, :, 1:, :]

    # compute image weights: the channel-averaged gradient magnitude, kept as
    # a singleton channel so it broadcasts against the depth gradients.
    weights_x: torch.Tensor = torch.exp(-torch.mean(torch.abs(image_dx), dim=1, keepdim=True))
    weights_y: torch.Tensor = torch.exp(-torch.mean(torch.abs(image_dy), dim=1, keepdim=True))

    # apply image weights to depth
    smoothness_x: torch.Tensor = torch.abs(idepth_dx * weights_x)
    smoothness_y: torch.Tensor = torch.abs(idepth_dy * weights_y)

    # reduce each direction to a scalar and sum them
    return torch.mean(smoothness_x) + torch.mean(smoothness_y)

class InverseDepthSmoothnessLoss(nn.Module):
    r"""Criterion that computes image-aware inverse depth smoothness loss.

    .. math::

        \text{loss} = \left | \partial_x d_{ij} \right | e^{-\left \|
        \partial_x I_{ij} \right \|} + \left |
        \partial_y d_{ij} \right | e^{-\left \| \partial_y I_{ij} \right \|}

    Module wrapper around :func:`inverse_depth_smoothness_loss`; it defines
    no parameters or state of its own and simply forwards its inputs.

    Shape:
        - Inverse Depth: :math:`(N, 1, H, W)`
        - Image: :math:`(N, 3, H, W)`
        - Output: scalar

    Examples:
        >>> idepth = torch.rand(1, 1, 4, 5)
        >>> image = torch.rand(1, 3, 4, 5)
        >>> smooth = InverseDepthSmoothnessLoss()
        >>> loss = smooth(idepth, image)
    """

    def forward(self, idepth: torch.Tensor, image: torch.Tensor) -> torch.Tensor:
        """Compute the loss by delegating to ``inverse_depth_smoothness_loss``.

        Args:
            idepth: tensor with the inverse depth.
            image: tensor with the input image.

        Return:
            the value returned by ``inverse_depth_smoothness_loss``.
        """
        return inverse_depth_smoothness_loss(idepth, image)