kornia.x
Package with the utilities to train kornia models.
- class kornia.x.Trainer(model, train_dataloader, valid_dataloader, criterion, optimizer, scheduler, config, callbacks={})
Base class to train the different models in kornia.
Warning
The API is experimental and subject to change based on the needs of kornia models.
- Parameters:
model (Module) – the nn.Module to be optimized.
train_dataloader (DataLoader[Any]) – the data loader used in the training loop.
valid_dataloader (DataLoader[Any]) – the data loader used in the validation loop.
criterion (Optional[Module]) – the nn.Module with the function that computes the loss.
optimizer (Optimizer) – the torch optimizer object to be used during the optimization.
scheduler (_LRScheduler) – the torch scheduler object defining the scheduling strategy.
accelerator – the Accelerator object to distribute the training.
config (Configuration) – a TrainerConfiguration structure containing the experiment hyperparameters.
callbacks (Dict[str, Callable[..., None]], optional) – a dictionary mapping hook names to the functions that override them. The main supported hooks are evaluate, preprocess, augmentations and fit. Default: {}
Important
The API heavily relies on accelerate. In order to use it, you must:
pip install kornia[x]
See also
Learn how to use the API in our documentation here.
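The callbacks argument described above maps hook names to functions that replace the trainer's default behaviour. The following is a minimal standard-library sketch of that override pattern. MiniTrainer, scale_to_unit and the hook bodies are hypothetical stand-ins chosen for illustration, not kornia's implementation.

```python
from typing import Any, Callable, Dict, List, Optional


class MiniTrainer:
    """Hypothetical stand-in for a trainer with overridable hooks."""

    def __init__(self, callbacks: Optional[Dict[str, Callable[..., Any]]] = None) -> None:
        # Each entry replaces the method of the same name on this instance.
        for name, fn in (callbacks or {}).items():
            if not hasattr(self, name):
                raise ValueError(f"unknown hook: {name}")
            setattr(self, name, fn)

    def preprocess(self, sample: List[float]) -> List[float]:
        # Default hook: pass the sample through unchanged.
        return sample

    def evaluate(self, sample: List[float]) -> float:
        # Default hook: report the mean of the preprocessed sample.
        data = self.preprocess(sample)
        return sum(data) / len(data)


def scale_to_unit(sample: List[float]) -> List[float]:
    # Replacement hook: divide by the largest value (assumed positive).
    peak = max(sample)
    return [value / peak for value in sample]


default_trainer = MiniTrainer()
custom_trainer = MiniTrainer(callbacks={"preprocess": scale_to_unit})

default_score = default_trainer.evaluate([2.0, 4.0])
custom_score = custom_trainer.evaluate([2.0, 4.0])
```

Passing a dictionary keeps the base class unchanged while letting each experiment swap only the hooks it needs. Rejecting unknown names early is a design choice of this sketch.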
Domain trainers
- class kornia.x.ImageClassifierTrainer(model, train_dataloader, valid_dataloader, criterion, optimizer, scheduler, config, callbacks={})
Module to be used for image classification purposes.
The module subclasses Trainer and overrides the evaluate() function, implementing a standard accuracy() topk@[1, 5].
See also
Learn how to use this class in the following example.
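To make the topk@[1, 5] metric concrete, here is a minimal pure-Python sketch of top-k accuracy. The function name topk_accuracy, the list-based inputs and the choice to report percentages are assumptions made for this illustration. It is not the accuracy() function used by the trainer.

```python
from typing import Dict, Sequence


def topk_accuracy(
    scores: Sequence[Sequence[float]],
    targets: Sequence[int],
    topk: Sequence[int] = (1, 5),
) -> Dict[int, float]:
    """Return the percentage of samples whose target is among the k best scores."""
    hits = {k: 0 for k in topk}
    for row, target in zip(scores, targets):
        # Class indices sorted from highest to lowest score.
        ranked = sorted(range(len(row)), key=lambda idx: row[idx], reverse=True)
        for k in topk:
            if target in ranked[:k]:
                hits[k] += 1
    return {k: 100.0 * count / len(targets) for k, count in hits.items()}


scores = [
    [0.1, 0.7, 0.2],  # highest score: class 1
    [0.5, 0.3, 0.2],  # highest score: class 0
    [0.2, 0.3, 0.5],  # highest score: class 2
    [0.6, 0.3, 0.1],  # highest score: class 0
]
targets = [1, 1, 0, 0]
result = topk_accuracy(scores, targets, topk=(1, 2))
```

With only three classes the example uses k = 1 and k = 2. A top-5 score needs at least five classes to be informative.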
- class kornia.x.SemanticSegmentationTrainer(model, train_dataloader, valid_dataloader, criterion, optimizer, scheduler, config, callbacks={})
Module to be used for semantic segmentation purposes.
The module subclasses Trainer and overrides the evaluate() function, implementing IoU mean_iou().
See also
Learn how to use this class in the following example.
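As an illustration of the IoU metric named above, the sketch below computes the per-class intersection over union on flattened label maps and then averages it. The helper per_class_iou and the handling of absent classes are assumptions made for this example. It is a plain-Python sketch, not the mean_iou() function itself.

```python
from typing import List, Sequence


def per_class_iou(pred: Sequence[int], target: Sequence[int], num_classes: int) -> List[float]:
    """Intersection over union for each class over flattened label maps."""
    ious = []
    for cls in range(num_classes):
        intersection = sum(1 for p, t in zip(pred, target) if p == cls and t == cls)
        union = sum(1 for p, t in zip(pred, target) if p == cls or t == cls)
        # A class absent from both maps is scored 0.0 (a choice made for this sketch).
        ious.append(intersection / union if union else 0.0)
    return ious


pred = [0, 0, 1, 1, 2, 2]    # predicted label per pixel
target = [0, 1, 1, 1, 2, 0]  # ground-truth label per pixel
ious = per_class_iou(pred, target, num_classes=3)
mean = sum(ious) / len(ious)
```

Averaging per class rather than per pixel keeps small classes from being swamped by large ones, which is the usual reason for reporting a mean IoU in segmentation.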
- class kornia.x.ObjectDetectionTrainer(model, train_dataloader, valid_dataloader, criterion, optimizer, scheduler, config, num_classes, callbacks=None, loss_computed_by_model=None)
Module to be used for object detection purposes.
The module subclasses Trainer and overrides the evaluate() function, implementing IoU mean_iou().
See also
Learn how to use this class in the following example.
Callbacks
- class kornia.x.ModelCheckpoint(filepath, monitor, filename_fcn=None)
Callback that saves the model at the end of every epoch.
- Parameters:
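Usage example:
from kornia.x import ImageClassifierTrainer, ModelCheckpoint

model_checkpoint = ModelCheckpoint(
    filepath="./outputs",
    monitor="top5",
)
# callbacks is a dict that maps the hook name to the callback object
trainer = ImageClassifierTrainer(
    ...,
    callbacks={"on_checkpoint": model_checkpoint},
)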
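The sketch below illustrates, with the standard library only, how a callback taking filepath, monitor and filename_fcn could write one file per epoch. SketchCheckpoint, its call signature, the default file naming and the use of JSON in place of real model serialization are all assumptions made for this example. None of them describe the internals of ModelCheckpoint.

```python
import json
import tempfile
from pathlib import Path
from typing import Callable, Dict, Optional


class SketchCheckpoint:
    """Hypothetical callback that writes one file per epoch."""

    def __init__(
        self,
        filepath: str,
        monitor: str,
        filename_fcn: Optional[Callable[[int, float], str]] = None,
    ) -> None:
        self.filepath = Path(filepath)
        self.monitor = monitor
        # The default naming scheme is an assumption made for this sketch.
        self.filename_fcn = filename_fcn or (
            lambda epoch, value: f"epoch{epoch:03d}_{monitor}_{value:.2f}.json"
        )

    def __call__(self, state: Dict[str, float], epoch: int, metrics: Dict[str, float]) -> Path:
        self.filepath.mkdir(parents=True, exist_ok=True)
        target = self.filepath / self.filename_fcn(epoch, metrics[self.monitor])
        # JSON stands in for the real model serialization.
        target.write_text(json.dumps(state))
        return target


with tempfile.TemporaryDirectory() as tmp:
    checkpoint = SketchCheckpoint(filepath=tmp, monitor="top5")
    saved = [
        checkpoint({"weight": 0.5 * epoch}, epoch, {"top5": 80.0 + epoch})
        for epoch in range(3)
    ]
    saved_names = [path.name for path in saved]
```

Putting the monitored value in the file name makes it possible to pick the best checkpoint by inspecting the output directory, and a custom filename_fcn can replace that scheme.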
- class kornia.x.EarlyStopping(monitor, min_delta=0.0, patience=8)
Callback that evaluates whether there is improvement in the loss function.
The module tracks the losses and, once the patience is exhausted, sends a termination signal to the trainer.
- Parameters:
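Usage example:
from kornia.x import EarlyStopping, ImageClassifierTrainer

# EarlyStopping accepts monitor, min_delta and patience (no filepath argument)
early_stop = EarlyStopping(
    monitor="top5",
)
# callbacks is a dict that maps the hook name to the callback object
trainer = ImageClassifierTrainer(
    ...,
    callbacks={"on_checkpoint": early_stop},
)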
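The interplay of min_delta and patience can be sketched as a small counter. SketchEarlyStopping below is a hypothetical, standard-library illustration that assumes a lower monitored value is better (as for a loss) and returns a boolean instead of signalling a trainer. It does not reproduce the internals of EarlyStopping.

```python
from typing import Dict, Optional


class SketchEarlyStopping:
    """Hypothetical patience counter; returns True when training should stop."""

    def __init__(self, monitor: str, min_delta: float = 0.0, patience: int = 8) -> None:
        self.monitor = monitor
        self.min_delta = min_delta
        self.patience = patience
        self.best: Optional[float] = None
        self.stale_epochs = 0

    def __call__(self, metrics: Dict[str, float]) -> bool:
        value = metrics[self.monitor]
        # Lower is better in this sketch, as for a loss value.
        if self.best is None or value < self.best - self.min_delta:
            self.best = value
            self.stale_epochs = 0
        else:
            # The improvement was smaller than min_delta: count a stale epoch.
            self.stale_epochs += 1
        return self.stale_epochs >= self.patience


stopper = SketchEarlyStopping(monitor="loss", min_delta=0.01, patience=2)
losses = [1.0, 0.8, 0.795, 0.81, 0.5]
stop_epoch = None
for epoch, loss in enumerate(losses):
    if stopper({"loss": loss}):
        stop_epoch = epoch
        break
```

In this run the loss stops improving by at least min_delta after the second epoch, so the counter reaches the patience limit before the final value of 0.5 is ever seen. A larger patience trades longer training for tolerance of such plateaus.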