kornia.x#

Package with the utilities to train kornia models.

class kornia.x.Trainer(model, train_dataloader, valid_dataloader, criterion, optimizer, scheduler, config, callbacks={})#

Base class to train the different models in kornia.

Warning

The API is experimental and may change based on the needs of kornia models.

Parameters:
  • model (Module) – the nn.Module to be optimized.

  • train_dataloader (DataLoader[Any]) – the data loader used in the training loop.

  • valid_dataloader (DataLoader[Any]) – the data loader used in the validation loop.

  • criterion (Optional[Module]) – the nn.Module with the function that computes the loss.

  • optimizer (Optimizer) – the torch optimizer object to be used during the optimization.

  • scheduler (_LRScheduler) – the torch scheduler object defining the scheduling strategy.

  • accelerator – the Accelerator object to distribute the training.

  • config (Configuration) – a TrainerConfiguration structure containing the experiment hyperparameters.

  • callbacks (Dict[str, Callable[..., None]], optional) – a dictionary mapping hook names to the functions that override them. The main supported hooks are evaluate, preprocess, augmentations and fit. Default: {}
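The callbacks dictionary maps a hook name to a replacement function. The sketch below illustrates that idea in plain Python and is not kornia's implementation: MiniTrainer, run_step and scale_input are hypothetical names, and attaching the functions with setattr is an assumption about how such overrides can be wired.

```python
from typing import Any, Callable, Dict, Optional


class MiniTrainer:
    """Hypothetical stand-in for a trainer with overridable hooks."""

    def __init__(self, callbacks: Optional[Dict[str, Callable[..., Any]]] = None) -> None:
        # Assumption: each entry replaces the hook of the same name on this instance.
        for name, fn in (callbacks or {}).items():
            setattr(self, name, fn)

    def preprocess(self, sample: Dict[str, Any]) -> Dict[str, Any]:
        # Default hook: pass the sample through unchanged.
        return sample

    def run_step(self, sample: Dict[str, Any]) -> Dict[str, Any]:
        # Calls whichever preprocess is currently attached.
        return self.preprocess(sample)


def scale_input(sample: Dict[str, Any]) -> Dict[str, Any]:
    # User-supplied hook: stored as a plain instance attribute, so it receives no `self`.
    return {**sample, "input": [x / 255.0 for x in sample["input"]]}


default_out = MiniTrainer().run_step({"input": [0, 255]})
custom_out = MiniTrainer(callbacks={"preprocess": scale_input}).run_step({"input": [0, 255]})
print(default_out)
print(custom_out)
```

Note that the dictionary uses `name: function` pairs. Writing `{"preprocess", scale_input}` with a comma would build a set instead, and the override would not be attached.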

Important

The API heavily relies on accelerate. In order to use it, you must: pip install kornia[x]

See also

Learn how to use the API in our documentation here.

Domain trainers#

class kornia.x.ImageClassifierTrainer(model, train_dataloader, valid_dataloader, criterion, optimizer, scheduler, config, callbacks={})#

Module to be used for image classification purposes.

The module subclasses Trainer and overrides the evaluate() function to compute the standard top-k accuracy() for k in [1, 5].
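To make the top-k metric concrete, here is a minimal pure-Python sketch of what accuracy at k = 1 and k = 5 measures. The helper topk_accuracy is hypothetical and written for illustration only. It is not the kornia accuracy() function, and returning a fraction rather than a percentage is an assumption.

```python
from typing import List, Sequence, Tuple


def topk_accuracy(
    scores: Sequence[Sequence[float]],
    targets: Sequence[int],
    topk: Tuple[int, ...] = (1, 5),
) -> List[float]:
    """Fraction of samples whose target is among the k highest-scoring classes."""
    results = []
    for k in topk:
        hits = 0
        for row, target in zip(scores, targets):
            # Indices of the k largest scores in this row.
            best = sorted(range(len(row)), key=lambda i: row[i], reverse=True)[:k]
            hits += int(target in best)
        results.append(hits / len(targets))
    return results


# Two samples, six classes (illustrative scores).
scores = [
    [0.1, 0.5, 0.2, 0.05, 0.1, 0.05],  # highest score: class 1
    [0.3, 0.1, 0.25, 0.2, 0.1, 0.05],  # highest score: class 0
]
targets = [1, 2]
top1, top5 = topk_accuracy(scores, targets)
print(top1, top5)
```

The second sample is wrong at k = 1 but counted as correct at k = 5, because its target class is the second-highest score.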

See also

Learn how to use this class in the following example.

class kornia.x.SemanticSegmentationTrainer(model, train_dataloader, valid_dataloader, criterion, optimizer, scheduler, config, callbacks={})#

Module to be used for semantic segmentation purposes.

The module subclasses Trainer and overrides the evaluate() function implementing IoU mean_iou().
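As a rough illustration of the IoU metric named above, the following pure-Python sketch computes the per-class intersection over union on flat label lists and averages it. The helpers per_class_iou and mean_of_valid are hypothetical, and skipping classes that are absent from both prediction and target is an assumption. The kornia mean_iou() function operates on tensors and may differ in shape and conventions.

```python
import math
from typing import List, Sequence


def per_class_iou(pred: Sequence[int], target: Sequence[int], num_classes: int) -> List[float]:
    """IoU of each class, computed from two flat label sequences of equal length."""
    ious = []
    for c in range(num_classes):
        inter = sum(1 for p, t in zip(pred, target) if p == c and t == c)
        union = sum(1 for p, t in zip(pred, target) if p == c or t == c)
        # Assumption: a class absent from both inputs gets NaN and is skipped later.
        ious.append(inter / union if union else float("nan"))
    return ious


def mean_of_valid(values: Sequence[float]) -> float:
    """Average the entries that are not NaN."""
    valid = [v for v in values if not math.isnan(v)]
    return sum(valid) / len(valid)


pred = [0, 0, 1, 1, 2, 2]
target = [0, 1, 1, 1, 2, 0]
ious = per_class_iou(pred, target, num_classes=3)
miou = mean_of_valid(ious)
print(ious, miou)
```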

See also

Learn how to use this class in the following example.

class kornia.x.ObjectDetectionTrainer(model, train_dataloader, valid_dataloader, criterion, optimizer, scheduler, config, num_classes, callbacks=None, loss_computed_by_model=None)#

Module to be used for object detection purposes.

The module subclasses Trainer and overrides the evaluate() function implementing IoU mean_iou().

See also

Learn how to use this class in the following example.

Callbacks#

class kornia.x.ModelCheckpoint(filepath, monitor, filename_fcn=None, max_mode=False)#

Callback that saves the model at the end of every epoch.

Parameters:
  • filepath (str) – the path where the model is saved.

  • monitor (str) – the name of the value to track.

  • max_mode (bool, optional) – if true, the metric is multiplied by -1. Enable this flag when an increasing metric value is expected, for example accuracy. Default: False
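The effect of max_mode can be sketched as a sign flip that lets a single "lower is better" comparison serve both losses and accuracy-like metrics. BestMetricTracker below is a hypothetical helper written for illustration. It is not the ModelCheckpoint implementation and does not save anything.

```python
class BestMetricTracker:
    """Hypothetical helper: remembers the best monitored value seen so far."""

    def __init__(self, max_mode: bool = False) -> None:
        self.max_mode = max_mode
        # After the optional sign flip, a lower score is always better.
        self.best = float("inf")

    def is_improvement(self, value: float) -> bool:
        # Multiply by -1 when a larger raw value is the better one.
        score = -value if self.max_mode else value
        if score < self.best:
            self.best = score
            return True
        return False


# Monitoring a loss: smaller is better, so max_mode stays False.
loss_tracker = BestMetricTracker(max_mode=False)
loss_flags = [loss_tracker.is_improvement(v) for v in [0.9, 0.7, 0.8]]

# Monitoring an accuracy: larger is better, so max_mode is True.
acc_tracker = BestMetricTracker(max_mode=True)
acc_flags = [acc_tracker.is_improvement(v) for v in [0.6, 0.8, 0.7]]
print(loss_flags, acc_flags)
```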

Usage example:

model_checkpoint = ModelCheckpoint(
    filepath="./outputs", monitor="loss",
)

trainer = ImageClassifierTrainer(...,
    callbacks={"on_checkpoint": model_checkpoint}
)

class kornia.x.EarlyStopping(monitor, min_delta=0.0, patience=8, max_mode=False)#

Callback that evaluates whether there is improvement in the loss function.

The module tracks the losses and, once the patience runs out, sends a termination signal to the trainer.

Parameters:
  • monitor (str) – the name of the value to track.

  • min_delta (float, optional) – the minimum difference between losses to increase the patience counter. Default: 0.0

  • patience (int, optional) – the number of times to wait without improvement before the trainer terminates. Default: 8

  • max_mode (bool, optional) – if true, the metric is multiplied by -1. Enable this flag when an increasing metric value is expected, for example accuracy. Default: False
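The interplay of min_delta and patience can be sketched as a counter that resets on improvement and triggers a stop once it reaches the patience. PatienceCounter is a hypothetical helper, not the EarlyStopping implementation. It assumes that lower values are better (max_mode left at False) and that stopping happens when the counter reaches the patience.

```python
class PatienceCounter:
    """Hypothetical helper: counts consecutive checks without improvement."""

    def __init__(self, min_delta: float = 0.0, patience: int = 8) -> None:
        self.min_delta = min_delta
        self.patience = patience
        self.counter = 0
        self.best = float("inf")

    def should_stop(self, value: float) -> bool:
        # Only a drop larger than min_delta counts as an improvement.
        if value < self.best - self.min_delta:
            self.best = value
            self.counter = 0
        else:
            self.counter += 1
        # Assumption: stop once the counter reaches the patience.
        return self.counter >= self.patience


stopper = PatienceCounter(patience=2)
stopped_at = None
for epoch, loss in enumerate([1.0, 0.9, 0.95, 0.92, 0.5]):
    if stopper.should_stop(loss):
        stopped_at = epoch
        break
print(stopped_at)
```

With a patience of 2, the loop stops at the second consecutive epoch without improvement, so the final value 0.5 is never seen.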

Usage example:

early_stop = EarlyStopping(
    monitor="loss", patience=10
)

trainer = ImageClassifierTrainer(...,
    callbacks={"on_epoch_end": early_stop}
)