kornia.x

Package with the utilities to train kornia models.

class kornia.x.Trainer(model, train_dataloader, valid_dataloader, criterion, optimizer, scheduler, config, callbacks={})[source]

Base class to train the different models in kornia.

Warning

The API is experimental and may change based on the needs of kornia models.

Parameters
  • model (Module) – the nn.Module to be optimized.

  • train_dataloader (DataLoader) – the data loader used in the training loop.

  • valid_dataloader (DataLoader) – the data loader used in the validation loop.

  • criterion (Module) – the nn.Module with the function that computes the loss.

  • optimizer (Optimizer) – the torch optimizer object to be used during the optimization.

  • scheduler (CosineAnnealingLR) – the torch scheduler object defining the scheduling strategy.

  • accelerator – the Accelerator object to distribute the training.

  • config (Configuration) – a TrainerConfiguration structure containing the experiment hyperparameters.

  • callbacks (Dict[str, Callable], optional) – a dictionary mapping hook names to the functions that override them. The main supported hooks are evaluate, preprocess, augmentations and fit. Default: {}
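
The callbacks parameter can be pictured as a table of named overrides. The sketch below is a dependency-free illustration of that pattern only: MiniTrainer and scale_to_unit are hypothetical names, and the way kornia attaches the functions and the arguments each hook receives are assumptions, not the library's implementation.

```python
from typing import Callable, Dict, List, Optional


class MiniTrainer:
    """Hypothetical stand-in used for illustration; this is not kornia.x.Trainer."""

    def __init__(self, callbacks: Optional[Dict[str, Callable]] = None) -> None:
        # Each entry replaces the hook with the same name on this instance.
        for name, fn in (callbacks or {}).items():
            setattr(self, name, fn)

    def preprocess(self, sample: float) -> float:
        # Default hook: return the sample unchanged.
        return sample

    def fit(self, samples: List[float]) -> List[float]:
        # The loop always calls self.preprocess, whichever version is attached.
        return [self.preprocess(s) for s in samples]


def scale_to_unit(sample: float) -> float:
    # Hypothetical override: map 8-bit intensities to the range [0, 1].
    return sample / 255.0


default_out = MiniTrainer().fit([0, 255])
custom_out = MiniTrainer(callbacks={"preprocess": scale_to_unit}).fit([0, 255])
```

Passing a dictionary keeps the training loop untouched while a single step is swapped out, which is the role the hooks listed above play.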

Important

The API relies heavily on accelerate. To use it, you must install the extra dependencies: pip install kornia[x]

See also

Learn how to use the API in our documentation here.

Domain trainers

class kornia.x.ImageClassifierTrainer(model, train_dataloader, valid_dataloader, criterion, optimizer, scheduler, config, callbacks={})[source]

Module to be used for Image Classification purposes.

The module subclasses Trainer and overrides the evaluate() function to report the standard top-k accuracy() for k in [1, 5].
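
As a rough illustration of what a top-k accuracy measures, here is a dependency-free sketch. The function name topk_accuracy, the list-based inputs, and the choice to return percentages are assumptions made for this example; it is not kornia's accuracy() implementation.

```python
from typing import List, Sequence


def topk_accuracy(
    scores: Sequence[Sequence[float]],
    targets: Sequence[int],
    topk: Sequence[int] = (1, 5),
) -> List[float]:
    """Return, for each k, the percentage of samples whose target is in the top k."""
    results = []
    for k in topk:
        hits = 0
        for row, target in zip(scores, targets):
            # Class indices sorted by descending score, truncated to the best k.
            ranked = sorted(range(len(row)), key=lambda i: row[i], reverse=True)[:k]
            hits += int(target in ranked)
        results.append(100.0 * hits / len(targets))
    return results


# Four samples over three classes; rows are per-class scores.
scores = [
    [0.10, 0.60, 0.30],
    [0.50, 0.20, 0.30],
    [0.20, 0.30, 0.50],
    [0.40, 0.35, 0.25],
]
targets = [1, 2, 2, 1]

top1, top2 = topk_accuracy(scores, targets, topk=(1, 2))
print(top1, top2)  # 50.0 100.0
```

With only three classes the example uses k in (1, 2); the trainer described above reports k in (1, 5).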

See also

Learn how to use this class in the following example.

class kornia.x.SemanticSegmentationTrainer(model, train_dataloader, valid_dataloader, criterion, optimizer, scheduler, config, callbacks={})[source]

Module to be used for Semantic segmentation purposes.

The module subclasses Trainer and overrides the evaluate() function to report the mean intersection over union (IoU) computed with mean_iou().
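
To make the metric concrete, the following dependency-free sketch computes a mean IoU over flat lists of class labels. The name mean_iou_sketch, the flat-list inputs, and the rule of skipping classes that never appear are assumptions for illustration; kornia's mean_iou() operates on tensors and may handle these details differently.

```python
from typing import Sequence


def mean_iou_sketch(pred: Sequence[int], target: Sequence[int], num_classes: int) -> float:
    """Average the per-class intersection over union across classes that occur."""
    ious = []
    for c in range(num_classes):
        # Pixels labelled c in both prediction and ground truth.
        inter = sum(1 for p, t in zip(pred, target) if p == c and t == c)
        # Pixels labelled c in either prediction or ground truth.
        union = sum(1 for p, t in zip(pred, target) if p == c or t == c)
        if union > 0:  # skip classes absent from both (a choice of this sketch)
            ious.append(inter / union)
    return sum(ious) / len(ious)


# Six "pixels" with three classes.
pred = [0, 0, 1, 1, 2, 2]
target = [0, 1, 1, 1, 2, 0]

# Per-class IoU is 1/3, 2/3 and 1/2, so the mean is 0.5.
miou = mean_iou_sketch(pred, target, num_classes=3)
```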

See also

Learn how to use this class in the following example.

Callbacks

class kornia.x.ModelCheckpoint(filepath, monitor)[source]

Callback that saves the model at the end of every epoch.

Parameters
  • filepath (str) – the path where the model is saved.

  • monitor (str) – the name of the value to track.

Usage example:

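from kornia.x import ImageClassifierTrainer, ModelCheckpoint

# Save the model under ./outputs, tracking the "top5" metric.
model_checkpoint = ModelCheckpoint(
    filepath="./outputs", monitor="top5",
)

# callbacks is a dictionary, so the hook name maps to the callback with a colon.
trainer = ImageClassifierTrainer(...,
    callbacks={"checkpoint": model_checkpoint}
)
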
class kornia.x.EarlyStopping(monitor, min_delta=0.0, patience=8)[source]

Callback that evaluates whether there is improvement in the loss function.

The module tracks the monitored value and, once the patience is exhausted, sends a termination signal to the trainer. On termination, the module saves the last model.

Parameters
  • monitor (str) – the name of the value to track.

  • min_delta (float, optional) – the minimum change between losses that counts as an improvement; smaller changes increase the patience counter. Default: 0.0

  • patience (int, optional) – the number of evaluations without improvement to wait before the trainer terminates. Default: 8

  • filepath – a backup filename used to save the model in case of termination.
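
The interaction between min_delta and patience can be sketched without any dependencies. PatienceTracker below is a hypothetical helper that assumes a lower monitored value is better; it illustrates the counting logic only and is not the EarlyStopping implementation.

```python
class PatienceTracker:
    """Hypothetical helper showing how a patience counter can work."""

    def __init__(self, min_delta: float = 0.0, patience: int = 8) -> None:
        self.min_delta = min_delta
        self.patience = patience
        self.best = float("inf")
        self.counter = 0

    def step(self, value: float) -> bool:
        """Record one evaluation and return True when training should stop."""
        if value < self.best - self.min_delta:
            # Improvement of more than min_delta: remember it and reset the counter.
            self.best = value
            self.counter = 0
        else:
            # No sufficient improvement: spend one unit of patience.
            self.counter += 1
        return self.counter >= self.patience


losses = [1.0, 0.8, 0.79, 0.81, 0.80]
tracker = PatienceTracker(min_delta=0.05, patience=2)

stopped_at = None
for epoch, loss in enumerate(losses):
    if tracker.step(loss):
        stopped_at = epoch
        break

print(stopped_at)  # 3
```

Here the drop from 0.8 to 0.79 is smaller than min_delta, so it counts against the patience budget, and the second consecutive miss triggers the stop.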

Usage example:

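from kornia.x import EarlyStopping, ImageClassifierTrainer

# Track the "top5" metric and keep a backup of the model on termination.
early_stop = EarlyStopping(
    monitor="top5", filepath="early_stop_model.pt"
)

# callbacks is a dictionary, so the hook name maps to the callback with a colon.
trainer = ImageClassifierTrainer(...,
    callbacks={"terminate": early_stop}
)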