kornia.x
Package with the utilities to train kornia models.
- class kornia.x.Trainer(model, train_dataloader, valid_dataloader, criterion, optimizer, scheduler, config, callbacks=None)
Base class to train the different models in kornia.
Warning
The API is experimental and subject to change based on the needs of kornia models.
- Parameters:
model (Module) – the nn.Module to be optimized.
train_dataloader (DataLoader[Any]) – the data loader used in the training loop.
valid_dataloader (DataLoader[Any]) – the data loader used in the validation loop.
criterion (Optional[Module]) – the nn.Module with the function that computes the loss.
optimizer (Optimizer) – the torch optimizer object to be used during the optimization.
scheduler (_LRScheduler) – the torch scheduler object defining the scheduling strategy.
accelerator – the Accelerator object used to distribute the training.
config (Configuration) – a TrainerConfiguration structure containing the experiment hyperparameters.
callbacks (Optional[Dict[str, Callable[..., None]]], optional) – a dictionary mapping hook names to the functions that override them. The main supported hooks are evaluate, preprocess, augmentations and fit. Default: None
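The callbacks dictionary maps a hook name to the function that replaces it. The following is a minimal sketch of that idea in plain Python, not kornia's implementation: MiniTrainer, fit_step and scale_input are hypothetical names introduced only for illustration, and binding the override on the instance is an assumption of the sketch.

```python
from typing import Any, Callable, Dict, Optional


class MiniTrainer:
    """Hypothetical stand-in for a trainer; it only illustrates hook overriding."""

    def __init__(self, callbacks: Optional[Dict[str, Callable[..., Any]]] = None) -> None:
        # Each key names the hook to replace; each value is the replacement.
        for name, fn in (callbacks or {}).items():
            if not hasattr(self, name):
                raise ValueError(f"unknown hook: {name}")
            # Stored on the instance, so the function is called without `self`.
            setattr(self, name, fn)

    def preprocess(self, sample: Dict[str, Any]) -> Dict[str, Any]:
        # Default hook: pass the sample through unchanged.
        return sample

    def fit_step(self, sample: Dict[str, Any]) -> Dict[str, Any]:
        # The loop always calls the hook by name, whatever it is bound to.
        return self.preprocess(sample)


def scale_input(sample: Dict[str, Any]) -> Dict[str, Any]:
    # Hypothetical override: rescale 8-bit values to the [0, 1] range.
    return {**sample, "input": [x / 255.0 for x in sample["input"]]}


default_trainer = MiniTrainer()
custom_trainer = MiniTrainer(callbacks={"preprocess": scale_input})
```

Because the hook is looked up by name at call time, the training loop itself does not need to know whether a default or a user-supplied function is in place.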
Important
The API heavily relies on accelerate. In order to use it, you must:
pip install kornia[x]
See also
Learn how to use the API in our documentation here.
Domain trainers
- class kornia.x.ImageClassifierTrainer(model, train_dataloader, valid_dataloader, criterion, optimizer, scheduler, config, callbacks=None)
Module to be used for image classification purposes.
The module subclasses Trainer and overrides the evaluate() function, implementing a standard accuracy() with topk@[1, 5].
See also
Learn how to use this class in the following example.
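To make the topk@[1, 5] metric concrete, here is a minimal pure-Python sketch of top-k accuracy. It is an illustration rather than the library's accuracy() function: the name topk_accuracy, the list-based inputs and the choice to report percentages are assumptions of this sketch.

```python
from typing import List, Sequence


def topk_accuracy(
    scores: Sequence[Sequence[float]],
    targets: Sequence[int],
    topk: Sequence[int] = (1, 5),
) -> List[float]:
    """Return, for each k, the percentage of samples whose target is among the k highest scores."""
    results = []
    for k in topk:
        hits = 0
        for row, target in zip(scores, targets):
            # Indices of the k largest scores for this sample.
            best = sorted(range(len(row)), key=lambda i: row[i], reverse=True)[:k]
            hits += int(target in best)
        results.append(100.0 * hits / len(targets))
    return results


# Two samples, six classes. The first sample is right at top-1,
# the second only once the five best guesses are considered.
example_scores = [
    [0.10, 0.50, 0.20, 0.05, 0.10, 0.05],
    [0.40, 0.10, 0.20, 0.15, 0.10, 0.05],
]
example_targets = [1, 3]
top1, top5 = topk_accuracy(example_scores, example_targets)
```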
- class kornia.x.SemanticSegmentationTrainer(model, train_dataloader, valid_dataloader, criterion, optimizer, scheduler, config, callbacks=None)
Module to be used for semantic segmentation purposes.
The module subclasses Trainer and overrides the evaluate() function, implementing IoU via mean_iou().
See also
Learn how to use this class in the following example.
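As a rough illustration of the metric, the sketch below computes a per-class intersection over union on flat label lists and averages it. It is not the library's mean_iou(): the function names, the flat-list inputs and the handling of classes that appear in neither list are assumptions made for this example.

```python
from typing import List, Sequence


def iou_per_class(pred: Sequence[int], target: Sequence[int], num_classes: int) -> List[float]:
    """Intersection over union for each class, computed over paired labels."""
    ious = []
    for c in range(num_classes):
        intersection = sum(1 for p, t in zip(pred, target) if p == c and t == c)
        union = sum(1 for p, t in zip(pred, target) if p == c or t == c)
        # Assumption of this sketch: a class absent from both inputs scores 0.
        ious.append(intersection / union if union else 0.0)
    return ious


def mean_iou_sketch(pred: Sequence[int], target: Sequence[int], num_classes: int) -> float:
    """Average the per-class IoU values."""
    ious = iou_per_class(pred, target, num_classes)
    return sum(ious) / len(ious)


# Six pixels, three classes.
example_pred = [0, 0, 1, 1, 2, 2]
example_target = [0, 1, 1, 1, 2, 0]
per_class = iou_per_class(example_pred, example_target, num_classes=3)
average = mean_iou_sketch(example_pred, example_target, num_classes=3)
```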
- class kornia.x.ObjectDetectionTrainer(model, train_dataloader, valid_dataloader, criterion, optimizer, scheduler, config, num_classes, callbacks=None, loss_computed_by_model=None)
Module to be used for object detection purposes.
The module subclasses Trainer and overrides the evaluate() function, implementing IoU via mean_iou().
See also
Learn how to use this class in the following example.
Callbacks
- class kornia.x.ModelCheckpoint(filepath, monitor, filename_fcn=None, max_mode=False)
Callback that saves the model at the end of every epoch.
- Parameters:
Usage example:
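    from kornia.x import ImageClassifierTrainer, ModelCheckpoint

    model_checkpoint = ModelCheckpoint(
        filepath="./outputs",
        monitor="loss",
    )
    # callbacks is a dictionary that maps the hook name to the callback
    trainer = ImageClassifierTrainer(
        ...,
        callbacks={"on_checkpoint": model_checkpoint},
    )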
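For readers who want to see the general shape of such a callback, here is a self-contained sketch that writes one file per epoch. It is not the library's ModelCheckpoint: SketchCheckpoint, default_filename, the call signature and the JSON serialization are all hypothetical, and the assumption that filename_fcn maps an epoch and a metric value to a file name is a guess that this page does not confirm.

```python
import json
import tempfile
from pathlib import Path
from typing import Callable, Dict, Optional


def default_filename(epoch: int, value: float) -> str:
    # Hypothetical naming scheme: epoch number plus the monitored value.
    return f"model_epoch={epoch}_metric={value:.4f}.json"


class SketchCheckpoint:
    """Hypothetical per-epoch checkpoint writer, for illustration only."""

    def __init__(
        self,
        filepath: str,
        monitor: str,
        filename_fcn: Optional[Callable[[int, float], str]] = None,
    ) -> None:
        self.filepath = Path(filepath)
        self.monitor = monitor
        self.filename_fcn = filename_fcn or default_filename

    def __call__(self, state: Dict[str, float], epoch: int, metrics: Dict[str, float]) -> Path:
        # Create the output directory on first use, then write the state.
        self.filepath.mkdir(parents=True, exist_ok=True)
        target = self.filepath / self.filename_fcn(epoch, metrics[self.monitor])
        target.write_text(json.dumps(state))
        return target


# Simulate three epochs with a made-up state and a decreasing loss.
with tempfile.TemporaryDirectory() as tmp:
    checkpoint = SketchCheckpoint(filepath=tmp, monitor="loss")
    saved = [
        checkpoint({"weight": 0.1 * epoch}, epoch, {"loss": 1.0 / (epoch + 1)}).name
        for epoch in range(3)
    ]
```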
- class kornia.x.EarlyStopping(monitor, min_delta=0.0, patience=8, max_mode=False)
Callback that evaluates whether there is improvement in the loss function.
The module tracks the losses and, once the patience is exhausted, sends a termination signal to the trainer.
- Parameters:
monitor (str) – the name of the value to track.
min_delta (float, optional) – the minimum difference between losses to increase the patience counter. Default: 0.0
patience (int, optional) – the number of times to wait before the trainer terminates. Default: 8
max_mode (bool, optional) – if true, the metric is multiplied by -1; set this flag when an increasing metric value is expected, for example accuracy. Default: False
Usage example:
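    from kornia.x import EarlyStopping, ImageClassifierTrainer

    early_stop = EarlyStopping(
        monitor="loss",
        patience=10,
    )
    # callbacks is a dictionary that maps the hook name to the callback
    trainer = ImageClassifierTrainer(
        ...,
        callbacks={"on_epoch_end": early_stop},
    )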
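The interplay of min_delta, patience and max_mode can be sketched in a few lines of plain Python. This is an illustration under stated assumptions, not the library's code: SketchEarlyStopping and first_stop_epoch are hypothetical names, the callback here receives a plain dictionary of floats, and stopping as soon as the counter reaches patience (rather than exceeds it) is a choice made for the sketch.

```python
from typing import Dict, Optional, Sequence


class SketchEarlyStopping:
    """Hypothetical patience counter, for illustration only."""

    def __init__(self, monitor: str, min_delta: float = 0.0, patience: int = 8, max_mode: bool = False) -> None:
        self.monitor = monitor
        self.min_delta = min_delta
        self.patience = patience
        self.max_mode = max_mode
        self.best = float("inf")
        self.counter = 0

    def __call__(self, metrics: Dict[str, float]) -> bool:
        value = metrics[self.monitor]
        if self.max_mode:
            # Flip the sign so that "higher is better" becomes "lower is better".
            value = -value
        if value < self.best - self.min_delta:
            # Improvement larger than min_delta: remember it and reset the counter.
            self.best = value
            self.counter = 0
        else:
            self.counter += 1
        # True means the trainer should terminate.
        return self.counter >= self.patience


def first_stop_epoch(values: Sequence[float], stopper: SketchEarlyStopping) -> Optional[int]:
    """Return the first epoch index at which the stopper asks to terminate."""
    for epoch, value in enumerate(values):
        if stopper({stopper.monitor: value}):
            return epoch
    return None


loss_stop = first_stop_epoch(
    [1.0, 0.8, 0.81, 0.79, 0.8, 0.8],
    SketchEarlyStopping(monitor="loss", min_delta=0.05, patience=2),
)
accuracy_stop = first_stop_epoch(
    [0.5, 0.6, 0.6, 0.6],
    SketchEarlyStopping(monitor="accuracy", patience=2, max_mode=True),
)
```

In the loss example, the drop from 0.8 to 0.79 is smaller than min_delta, so it counts as no improvement and the counter keeps growing.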