Source code for kornia.x.trainers

from typing import Any, Callable, Dict, List, Optional, Tuple

import torch
import torch.nn as nn
from torch.utils.data import DataLoader

from kornia.metrics import accuracy, mean_average_precision, mean_iou

from .trainer import Trainer
from .utils import Configuration


class ImageClassifierTrainer(Trainer):
    """Module to be used for image classification purposes.

    The module subclasses :py:class:`~kornia.x.Trainer` and overrides the
    :py:func:`~kornia.x.Trainer.compute_metrics` function implementing a standard
    :py:func:`~kornia.metrics.accuracy` topk@[1, 5].

    .. seealso::
        Learn how to use this class in the following
        `example <https://github.com/kornia/kornia/blob/master/examples/train/image_classifier/>`__.
    """

    def compute_metrics(self, *args: torch.Tensor) -> Dict[str, float]:
        """Compute top-1 and top-5 accuracy.

        Args:
            *args: exactly two tensors ``(out, target)``: the model output and the ground-truth labels,
                in whatever layout :py:func:`~kornia.metrics.accuracy` accepts (presumably logits of
                shape ``(B, C)`` and class indices of shape ``(B,)`` -- TODO confirm).

        Returns:
            A dict with keys ``top1`` and ``top5`` holding Python floats.

        Raises:
            AssertionError: if not exactly two arguments are given.
        """
        if len(args) != 2:
            raise AssertionError(f"Expected 2 arguments (output, target), got {len(args)}.")
        out, target = args
        acc1, acc5 = accuracy(out, target, topk=(1, 5))
        return dict(top1=acc1.item(), top5=acc5.item())
class SemanticSegmentationTrainer(Trainer):
    """Module to be used for semantic segmentation purposes.

    The module subclasses :py:class:`~kornia.x.Trainer` and overrides the
    :py:func:`~kornia.x.Trainer.compute_metrics` function implementing IoU
    :py:func:`~kornia.metrics.mean_iou`.

    .. seealso::
        Learn how to use this class in the following
        `example <https://github.com/kornia/kornia/blob/master/examples/train/semantic_segmentation/>`__.
    """

    def compute_metrics(self, *args: torch.Tensor) -> Dict[str, float]:
        """Compute the mean intersection-over-union of the predicted segmentation.

        Args:
            *args: exactly two tensors ``(out, target)``. ``out`` holds per-class scores along
                dimension 1; its argmax over that dimension is the predicted label map and
                ``out.shape[1]`` is used as the number of classes. ``target`` is the ground-truth
                label map in the layout :py:func:`~kornia.metrics.mean_iou` expects.

        Returns:
            A dict with key ``iou`` holding the IoU averaged over all entries returned by
            :py:func:`~kornia.metrics.mean_iou`, as a Python float.

        Raises:
            AssertionError: if not exactly two arguments are given.
        """
        if len(args) != 2:
            raise AssertionError(f"Expected 2 arguments (output, target), got {len(args)}.")
        out, target = args
        # Collapse the class-score dimension into hard label predictions.
        iou = mean_iou(out.argmax(1), target, out.shape[1]).mean()
        return dict(iou=iou.item())
class ObjectDetectionTrainer(Trainer):
    """Module to be used for object detection purposes.

    The module subclasses :py:class:`~kornia.x.Trainer` and overrides the
    :py:func:`~kornia.x.Trainer.compute_metrics` function implementing mean average precision
    :py:func:`~kornia.metrics.mean_average_precision`.

    .. seealso::
        Learn how to use this class in the following
        `example <https://github.com/kornia/kornia/blob/master/examples/train/object_detection/>`__.
    """

    def __init__(
        self,
        model: nn.Module,
        train_dataloader: DataLoader,
        valid_dataloader: DataLoader,
        criterion: Optional[nn.Module],
        optimizer: torch.optim.Optimizer,
        scheduler: torch.optim.lr_scheduler.CosineAnnealingLR,
        config: Configuration,
        num_classes: int,
        callbacks: Optional[Dict[str, Callable]] = None,
        loss_computed_by_model: Optional[bool] = None,
    ) -> None:
        """Create the trainer.

        Args:
            num_classes: number of object classes, forwarded to the mAP computation.
            callbacks: optional mapping of callback name to callable; ``None`` means no callbacks.
            loss_computed_by_model: if truthy, the model is called with targets during training and
                is expected to return its own losses (TorchVision-style detectors); ``criterion``
                may then be ``None``.

        The remaining arguments are forwarded unchanged to :py:class:`~kornia.x.Trainer`.
        """
        if callbacks is None:
            callbacks = {}
        super().__init__(model, train_dataloader, valid_dataloader, criterion, optimizer, scheduler, config, callbacks)
        # TODO: auto-detect if the model is from TorchVision
        self.loss_computed_by_model = loss_computed_by_model
        self.num_classes = num_classes

    def on_model(self, model: nn.Module, sample: dict):
        """Run the model on a batch ``sample`` with keys ``input`` and (when training) ``target``.

        When the model computes its own loss and is in training mode, the targets are passed too.
        """
        if self.loss_computed_by_model and model.training:
            return model(sample["input"], sample["target"])
        return model(sample["input"])

    def compute_loss(self, *args: Any) -> torch.Tensor:
        """Reduce the model's own losses, or apply ``criterion`` to ``args``.

        Raises:
            RuntimeError: if the loss is not computed by the model and ``criterion`` is ``None``.
        """
        if self.loss_computed_by_model:
            losses = args[0]
            if isinstance(losses, dict):
                # NOTE(review): dict losses are averaged while sequence losses below are summed;
                # TorchVision's reference detection scripts sum the loss dict. Kept as-is to avoid
                # silently rescaling gradients for existing users -- confirm intended reduction.
                return torch.stack(list(losses.values())).mean()
            return torch.stack(list(losses)).sum()
        if self.criterion is None:
            raise RuntimeError("`criterion` should not be None if `loss_computed_by_model` is False.")
        return self.criterion(*args)

    @staticmethod
    def _as_detections(value: Any, keys: Tuple[str, ...]) -> Optional[List[Dict[str, torch.Tensor]]]:
        """Normalize detections into a list of per-image dicts, or ``None`` if not in that format.

        ``value`` may be a single dict or a non-empty list/tuple of dicts; every dict must contain
        all of ``keys``.
        """
        items = [value] if isinstance(value, dict) else value
        if not isinstance(items, (list, tuple)) or len(items) == 0:
            return None
        if not all(isinstance(item, dict) and all(key in item for key in keys) for item in items):
            return None
        return list(items)

    def compute_metrics(self, *args: Any) -> Dict[str, float]:
        """Compute mean average precision for detection outputs.

        Args:
            *args: ``(predictions, targets)``. ``predictions`` is a per-image dict, or a list/tuple of
                them, with ``boxes``, ``labels`` and ``scores``; ``targets`` likewise with ``boxes``
                and ``labels``.

        Returns:
            ``{'mAP': float}`` for detection inputs; otherwise defers to the base class.
        """
        preds = self._as_detections(args[0], ("boxes", "labels", "scores")) if len(args) == 2 else None
        targets = self._as_detections(args[1], ("boxes", "labels")) if preds is not None else None
        if preds is None or targets is None:
            return super().compute_metrics(*args)

        # Iterate the per-image dicts (the original iterated a dict, which yields its keys).
        mAP, _ = mean_average_precision(
            [p["boxes"] for p in preds],
            [p["labels"] for p in preds],
            [p["scores"] for p in preds],
            [t["boxes"] for t in targets],
            [t["labels"] for t in targets],
            n_classes=self.num_classes,
            threshold=0.000001,
        )
        return {'mAP': mAP.item()}