Source code for kornia.core.tensor_wrapper

# type: ignore
# pytorch tensor wrapper class
# insipired by:
# https://github.com/pytorch/pytorch/blob/591dfffa38848de54b7f5f4e49260847024c9281/test/test_overrides.py#L748
import collections

import torch
from torch import Tensor


# TODO: promote to KORNIA_WRAP
def wrap(v, cls):
    # wrap inputs if necessary
    if type(v) in {tuple, list}:
        return type(v)(wrap(vi, cls) for vi in v)

    return cls(v) if isinstance(v, Tensor) else v


# TODO: promote to KORNIA_UNWRAP
def unwrap(v):
    if type(v) in {tuple, list}:
        return type(v)(unwrap(vi) for vi in v)

    return v._data if isinstance(v, TensorWrapper) else v


[docs]class TensorWrapper: def __init__(self, data: Tensor) -> None: self.__dict__["_data"] = data self.__dict__["used_attrs"] = set() self.__dict__["used_calls"] = set()
[docs] def unwrap(self): return unwrap(self)
def __getstate__(self): return self.__dict__ def __setstate__(self, state): self.__dict__.update(state) def __repr__(self) -> str: return f"{self._data}" def __getattr__(self, name): if name == "data": return self._data elif name in self.__dict__: return self.__dict__[name] self.used_attrs.add(name) val = getattr(self._data, name) # NOTE: not clear is really needed # If it's a method # if callable(val): # c = getattr(type(self._data), name) # # Don't append self to args if classmethod/staticmethod # if c is val: # return lambda *a, **kw: wrap(self.__torch_function__(c, (type(self),), args=a, kwargs=kw), type(self)) # # Otherwise append self to args # return lambda *a, **kw: wrap( # #self.__torch_function__(c, (type(self),), args=(self,) + a, kwargs=kw), type(self) # ) return wrap(val, type(self)) def __setattr__(self, name, value) -> None: if name in self.__dict__: self.__dict__[name] = value self.used_attrs.add(name) setattr(self._data, name, value) def __setitem__(self, key, value) -> None: self._data[key] = value def __getitem__(self, key): return wrap(self._data[key], type(self)) @classmethod def __torch_function__(cls, func, types, args=(), kwargs=None): if kwargs is None: kwargs = {} # Find an instance of this class in the arguments args_of_this_cls = [] for a in args: if isinstance(a, cls): args_of_this_cls.append(a) elif isinstance(a, collections.abc.Sequence): args_of_this_cls.extend(el for el in a if isinstance(el, cls)) # assert len(args_of_this_cls) > 0 for a in args_of_this_cls: a.used_calls.add(func) args = unwrap(tuple(args)) kwargs = {k: unwrap(v) for k, v in kwargs.items()} return wrap(func(*args, **kwargs), cls) # TODO: `def __add__(self, other) -> Self:` when mypy release >0.991 def __add__(self, other): return self.__unary_op__(torch.add, other) def __radd__(self, other): return self.__unary_op__(torch.add, other) def __mul__(self, other): return self.__unary_op__(torch.mul, other) def __rmul__(self, other): return self.__unary_op__(torch.mul, other) def __sub__(self, other): return self.__unary_op__(torch.sub, other) def __rsub__(self, other): return self.__unary_op__(torch.sub, other) def __truediv__(self, other): return self.__unary_op__(torch.true_divide, other) def __floordiv__(self, other): return self.__unary_op__(torch.floor_divide, other) def __ge__(self, other): return self.__unary_op__(torch.ge, other) def __gt__(self, other): return self.__unary_op__(torch.gt, other) def __lt__(self, other): return self.__unary_op__(torch.lt, other) def __le__(self, other): return self.__unary_op__(torch.le, other) def __eq__(self, other): return self.__unary_op__(torch.eq, other) def __ne__(self, other): return self.__unary_op__(torch.ne, other) def __bool__(self) -> bool: return self.__unary_op__(Tensor.__bool__) def __int__(self) -> int: return self.__unary_op__(Tensor.__int__) def __neg__(self): return self.__unary_op__(Tensor.negative) def __unary_op__(self, func, other=None): args = (self, other) if other is not None else (self,) return self.__torch_function__(func, (type(self),), args) def __len__(self) -> int: return len(self._data)