# Source code for kornia.contrib.models.base

from __future__ import annotations

import os
from abc import ABC, abstractmethod
from typing import Any, Generic, Optional, TypeVar, cast

import torch

from kornia.core import Module

ModelConfig = TypeVar("ModelConfig")


class ModelBase(ABC, Module, Generic[ModelConfig]):
    """Abstract model class with some utility functions.

    Subclasses must implement :meth:`from_config`; they inherit checkpoint loading and a
    ``torch.compile`` convenience wrapper.
    """

    def load_checkpoint(self, checkpoint: str, device: Optional[torch.device] = None) -> None:
        """Load checkpoint from a given url or file.

        Args:
            checkpoint: The url or filepath for the respective checkpoint. If it names an existing
                file it is read from disk; otherwise it is treated as a URL and fetched via
                ``torch.hub.load_state_dict_from_url``.
            device: The desired device to load the weights onto. When given, the model is also
                moved to this device after the weights are loaded.

        Note:
            Both loading paths unpickle the checkpoint, which can execute arbitrary code. Only load
            checkpoints from trusted sources.
        """
        if os.path.isfile(checkpoint):
            with open(checkpoint, "rb") as f:
                state_dict = torch.load(f, map_location=device)
        else:
            state_dict = torch.hub.load_state_dict_from_url(checkpoint, map_location=device)

        self.load_state_dict(state_dict)

        # `map_location` only decides where the *loaded* tensors live; `load_state_dict` copies them
        # into the module's existing parameters in place, so the module itself must be moved
        # explicitly to honour the documented `device` contract.
        if device is not None:
            self.to(device)

    @staticmethod
    @abstractmethod
    def from_config(config: ModelConfig) -> ModelBase[ModelConfig]:
        """Build/load the model.

        Args:
            config: The specifications for the model to be built/loaded.
        """
        raise NotImplementedError

    def compile(
        self,
        *,
        fullgraph: bool = False,
        dynamic: bool = False,
        backend: str = "inductor",
        mode: Optional[str] = None,
        options: Optional[dict[Any, Any]] = None,
        disable: bool = False,
    ) -> ModelBase[ModelConfig]:
        """Compile this model with :func:`torch.compile` and return the compiled object.

        All keyword arguments are forwarded to :func:`torch.compile`. The model is not compiled
        in place; use the returned object.

        Args:
            fullgraph: Forwarded to ``torch.compile``.
            dynamic: Forwarded to ``torch.compile``.
            backend: Forwarded to ``torch.compile``.
            mode: Forwarded to ``torch.compile``. Cannot be combined with a non-empty ``options``.
            options: Forwarded to ``torch.compile``. ``None`` (the default) and an empty dict are
                both treated as "no options".
            disable: Forwarded to ``torch.compile``.

        Returns:
            The object returned by ``torch.compile``, typed as a ``ModelBase``.
        """
        compiled = torch.compile(
            self,
            fullgraph=fullgraph,
            dynamic=dynamic,
            backend=backend,
            mode=mode,
            # `torch.compile` rejects `mode` together with any non-None `options` (even `{}`), so an
            # empty/absent options dict must be passed as None for `mode=` to be usable.
            options=options or None,
            disable=disable,
        )
        return cast(ModelBase[ModelConfig], compiled)