kornia.x
Package with the utilities to train kornia models.
- class kornia.x.Trainer(model, train_dataloader, valid_dataloader, criterion, optimizer, scheduler, config, callbacks={})
Base class to train the different models in kornia.
Warning
The API is experimental and subject to be modified based on the needs of kornia models.
- Parameters
  - model (Module) – the nn.Module to be optimized.
  - train_dataloader (DataLoader) – the data loader used in the training loop.
  - valid_dataloader (DataLoader) – the data loader used in the validation loop.
  - criterion (Optional[Module]) – the nn.Module with the function that computes the loss.
  - optimizer (Optimizer) – the torch optimizer object to be used during the optimization.
  - scheduler (CosineAnnealingLR) – the torch scheduler object defining the scheduling strategy.
  - accelerator – the Accelerator object used to distribute the training.
  - config (Configuration) – a TrainerConfiguration structure containing the experiment hyperparameters.
  - callbacks (Dict[str, Callable], optional) – a dictionary mapping hook names to the functions that override them. The main supported hooks are evaluate, preprocess, augmentations and fit. Default: {}
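Because callbacks is a dictionary keyed by hook name, overriding a hook amounts to passing a callable under that name. The dependency-free sketch below assumes the trainer simply rebinds each entry onto the method of the same name (here via setattr); MiniTrainer, step and scale_to_unit are hypothetical names used only for illustration, not kornia APIs.

```python
from typing import Callable, Dict, Optional


class MiniTrainer:
    """Hypothetical stand-in for a trainer, reduced to its hook mechanism."""

    def __init__(self, callbacks: Optional[Dict[str, Callable]] = None) -> None:
        # Assumption: each entry replaces the method with the same name.
        for name, fn in (callbacks or {}).items():
            setattr(self, name, fn)

    def preprocess(self, x):
        # Default hook: identity.
        return x

    def augmentations(self, x):
        # Default hook: identity.
        return x

    def evaluate(self) -> Dict[str, float]:
        # Default hook: no metrics.
        return {}

    def step(self, x):
        # Simplified data path of one training step: preprocess, then augment.
        return self.augmentations(self.preprocess(x))


def scale_to_unit(x):
    # Hypothetical preprocess hook: map 8-bit intensities to [0, 1].
    return [v / 255.0 for v in x]


trainer = MiniTrainer(callbacks={"preprocess": scale_to_unit})
print(trainer.step([0, 255]))  # -> [0.0, 1.0]
```

In this sketch the hook is stored as an instance attribute, so it is called without self and only receives the hook's own arguments.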
Important
The API heavily relies on accelerate. To use it, you must install the extra dependencies:
pip install kornia[x]
See also
Learn how to use the API in our documentation here.
Domain trainers
- class kornia.x.ImageClassifierTrainer(model, train_dataloader, valid_dataloader, criterion, optimizer, scheduler, config, callbacks={})
Module to be used for image classification purposes.
The module subclasses Trainer and overrides the evaluate() function, implementing a standard top-k accuracy() with k in [1, 5].
See also
Learn how to use this class in the following example.
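As a rough illustration of the metric, the sketch below computes top-1 and top-5 accuracy in plain Python: for each k, the percentage of samples whose target class is among the k highest-scoring classes. topk_accuracy is a hypothetical helper; kornia's accuracy() operates on tensors and may differ in details.

```python
from typing import Dict, List, Sequence


def topk_accuracy(
    scores: Sequence[Sequence[float]],
    targets: Sequence[int],
    ks: Sequence[int] = (1, 5),
) -> Dict[int, float]:
    """Return, for each k, the percentage of samples whose target is in the top-k scores."""
    hits = {k: 0 for k in ks}
    for row, target in zip(scores, targets):
        # Class indices ordered from highest to lowest score.
        ranked: List[int] = sorted(range(len(row)), key=lambda i: row[i], reverse=True)
        for k in ks:
            if target in ranked[:k]:
                hits[k] += 1
    return {k: 100.0 * hits[k] / len(targets) for k in ks}


# Two samples, six classes.
scores = [
    [0.10, 0.50, 0.20, 0.05, 0.10, 0.05],  # target 1 has the highest score
    [0.30, 0.25, 0.10, 0.15, 0.05, 0.20],  # target 2 is only ranked 5th
]
targets = [1, 2]
print(topk_accuracy(scores, targets))  # -> {1: 50.0, 5: 100.0}
```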
- class kornia.x.SemanticSegmentationTrainer(model, train_dataloader, valid_dataloader, criterion, optimizer, scheduler, config, callbacks={})
Module to be used for semantic segmentation purposes.
The module subclasses Trainer and overrides the evaluate() function, implementing the IoU metric mean_iou().
See also
Learn how to use this class in the following example.
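For intuition, IoU for a class c is the number of pixels labelled c in both the prediction and the target, divided by the number labelled c in at least one of them. The plain-Python sketch below averages that ratio over flattened label maps; mean_iou here is a hypothetical standalone helper, not kornia's mean_iou(), which works on batched tensors and may reduce the per-class values differently. This version skips classes absent from both maps.

```python
from typing import List, Sequence


def mean_iou(pred: Sequence[int], target: Sequence[int], num_classes: int) -> float:
    """Average intersection-over-union across classes present in pred or target."""
    ious: List[float] = []
    for c in range(num_classes):
        # Pixels labelled c in both maps, and in at least one of them.
        inter = sum(1 for p, t in zip(pred, target) if p == c and t == c)
        union = sum(1 for p, t in zip(pred, target) if p == c or t == c)
        if union > 0:  # skip classes absent from both maps
            ious.append(inter / union)
    return sum(ious) / len(ious) if ious else 0.0


# Flattened 2x2 label maps with two classes.
pred = [0, 0, 1, 1]
target = [0, 1, 1, 1]
print(mean_iou(pred, target, num_classes=2))  # class 0: 1/2, class 1: 2/3
```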
- class kornia.x.ObjectDetectionTrainer(model, train_dataloader, valid_dataloader, criterion, optimizer, scheduler, config, num_classes, callbacks=None, loss_computed_by_model=None)
Module to be used for object detection purposes.
The module subclasses Trainer and overrides the evaluate() function, implementing the IoU metric mean_iou().
See also
Learn how to use this class in the following example.
Callbacks
- class kornia.x.ModelCheckpoint(filepath, monitor)
Callback that saves the model at the end of every epoch.
- Parameters
  - filepath – the location where the model is saved.
  - monitor – the name of the value to track.
Usage example:
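    from kornia.x import ImageClassifierTrainer, ModelCheckpoint

    # Save the model under ./outputs, tracking the "top5" validation metric.
    model_checkpoint = ModelCheckpoint(
        filepath="./outputs",
        monitor="top5",
    )
    # callbacks is a dict mapping hook names to callables.
    trainer = ImageClassifierTrainer(
        ...,
        callbacks={"checkpoint": model_checkpoint},
    )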
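One plausible use of the monitor argument is to save only when the tracked value improves. The sketch below illustrates that idea under stated assumptions: the callback receives the validation metrics once per epoch, and saving is delegated to an injected save_fn (in real code this would typically wrap torch.save). BestValueCheckpoint, save_fn and higher_is_better are hypothetical names; this is not a description of kornia's implementation.

```python
from typing import Callable, Dict, List, Optional


class BestValueCheckpoint:
    """Hypothetical sketch: save whenever the monitored value improves."""

    def __init__(
        self,
        monitor: str,
        save_fn: Callable[[object, int], None],
        higher_is_better: bool = False,
    ) -> None:
        self.monitor = monitor
        self.save_fn = save_fn  # e.g. a wrapper around torch.save in real code
        self.higher_is_better = higher_is_better
        self.best: Optional[float] = None

    def __call__(self, model: object, epoch: int, valid_metric: Dict[str, float]) -> None:
        value = valid_metric[self.monitor]
        if self.higher_is_better:
            value = -value  # compare everything on a "lower is better" scale
        if self.best is None or value < self.best:
            self.best = value
            self.save_fn(model, epoch)


saved_epochs: List[int] = []
ckpt = BestValueCheckpoint("loss", lambda model, epoch: saved_epochs.append(epoch))
for epoch, loss in enumerate([0.9, 0.7, 0.8, 0.5]):
    ckpt(model=None, epoch=epoch, valid_metric={"loss": loss})
print(saved_epochs)  # -> [0, 1, 3]
```

For an accuracy-style metric such as "top5", where larger is better, the sketch would be constructed with higher_is_better=True.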
- class kornia.x.EarlyStopping(monitor, min_delta=0.0, patience=8)
Callback that evaluates whether there is improvement in the loss function.
The module tracks the losses and, once its patience runs out, sends a termination signal to the trainer. In case of termination, the module will save the last model.
- Parameters
  - monitor (str) – the name of the value to track.
  - min_delta (float, optional) – the minimum change in the tracked value that counts as an improvement; smaller changes increase the patience counter. Default: 0.0
  - patience (int, optional) – the number of consecutive checks without improvement to tolerate before the trainer is terminated. Default: 8
  - filepath – a backup filename to save the model to in case of termination.
Usage example:
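    from kornia.x import EarlyStopping, ImageClassifierTrainer

    # Track "top5"; terminate after 8 consecutive checks without improvement.
    early_stop = EarlyStopping(
        monitor="top5",
        patience=8,
    )
    # callbacks is a dict mapping hook names to callables.
    trainer = ImageClassifierTrainer(
        ...,
        callbacks={"terminate": early_stop},
    )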
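The interplay of min_delta and patience can be pictured as a small counter. The sketch below is a hypothetical, dependency-free illustration that assumes lower values of the monitored quantity are better (as for a loss); PatienceCounter and should_stop are illustrative names, not kornia APIs.

```python
from typing import Dict, Optional


class PatienceCounter:
    """Hypothetical sketch of early stopping with min_delta and patience."""

    def __init__(self, monitor: str, min_delta: float = 0.0, patience: int = 8) -> None:
        self.monitor = monitor
        self.min_delta = min_delta
        self.patience = patience
        self.best: Optional[float] = None
        self.counter = 0

    def should_stop(self, valid_metric: Dict[str, float]) -> bool:
        # Assumption: lower values of the monitored quantity are better.
        value = valid_metric[self.monitor]
        if self.best is None or value < self.best - self.min_delta:
            # Improvement larger than min_delta: remember it and reset the counter.
            self.best = value
            self.counter = 0
        else:
            # No sufficient improvement: spend one unit of patience.
            self.counter += 1
        return self.counter >= self.patience


stopper = PatienceCounter(monitor="loss", min_delta=0.05, patience=2)
stopped_at = None
for epoch, loss in enumerate([1.0, 0.8, 0.78, 0.77, 0.5]):
    if stopper.should_stop({"loss": loss}):
        stopped_at = epoch
        break
print(stopped_at)  # -> 3
```

Here the drops from 0.8 to 0.78 and 0.77 are smaller than min_delta, so each one costs a unit of patience and training stops at epoch 3, before the larger improvement at epoch 4 is ever seen.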