from pathlib import Path
from typing import Callable, Dict, Optional, Union
import torch
from kornia.core import Module
from kornia.metrics import AverageMeter
from .utils import TrainerState
# default function to generate the filename in the model checkpoint
def default_filename_fcn(x: Union[str, int]) -> str:
    """Build the default checkpoint file name for the identifier ``x``.

    Args:
        x: the value embedded in the file name (the checkpoint callback passes the epoch).

    Returns:
        The string ``model_<x>.pt``.
    """
    template = "model_{}.pt"
    return template.format(x)
class EarlyStopping:
    """Callback that asks the trainer to stop once the monitored metric stops improving.

    On every call the running average of the monitored metric is compared against the best
    value seen so far. A value that fails to beat the best one by at least ``min_delta``
    increments an internal counter; any improvement resets it. When the counter reaches
    ``patience`` the callback returns ``TrainerState.TERMINATE`` and keeps doing so on
    every subsequent call.

    Args:
        monitor: the name of the value to track.
        min_delta: the minimum amount by which a new value must beat the best one to count
          as an improvement.
        patience: the number of consecutive non-improving calls tolerated before the
          termination signal is sent.
        max_mode: if ``True`` (default, the historical behaviour) larger values are better,
          e.g. an accuracy. If ``False`` smaller values are better, e.g. a loss.

    **Usage example:**

    .. code:: python

        early_stop = EarlyStopping(monitor="top5")

        trainer = ImageClassifierTrainer(...,
            callbacks={"on_checkpoint": early_stop}
        )
    """

    def __init__(self, monitor: str, min_delta: float = 0.0, patience: int = 8, max_mode: bool = True) -> None:
        self.monitor = monitor
        self.min_delta = min_delta
        self.patience = patience
        self.max_mode = max_mode

        # number of consecutive calls without improvement
        self.counter: int = 0
        # best value observed so far; ``None`` until the first call
        self.best_score: Optional[float] = None
        # latched to ``True`` once patience is exhausted
        self.early_stop: bool = False

    def _has_stalled(self, score: float, best: float) -> bool:
        """Return ``True`` when ``score`` does not beat ``best`` by at least ``min_delta``."""
        if self.max_mode:
            return score < best + self.min_delta
        return score > best - self.min_delta

    def __call__(self, model: Module, epoch: int, valid_metric: Dict[str, AverageMeter]) -> TrainerState:
        """Update the patience counter and report whether training should continue.

        Args:
            model: unused; accepted to match the callback signature.
            epoch: the current epoch, only used in the termination message.
            valid_metric: mapping from metric name to its meter; ``valid_metric[monitor].avg`` is read.

        Returns:
            ``TrainerState.TERMINATE`` once patience is exhausted, ``TrainerState.TRAINING`` otherwise.
        """
        score: float = valid_metric[self.monitor].avg

        if self.best_score is None:
            # first observation: nothing to compare against yet
            self.best_score = score
        elif self._has_stalled(score, self.best_score):
            self.counter += 1
            if self.counter >= self.patience:
                self.early_stop = True
        else:
            self.best_score = score
            self.counter = 0

        if self.early_stop:
            print(f"[INFO] Early-Stopping the training process. Epoch: {epoch}.")
            return TrainerState.TERMINATE

        return TrainerState.TRAINING
class ModelCheckpoint:
    """Callback that saves the model whenever the monitored metric improves.

    The checkpoint is written with ``torch.save(model, ...)``, i.e. the whole module object
    is serialized, into the directory ``filepath`` under the name returned by ``filename_fcn``.

    Args:
        filepath: the directory where to save the model; it is created (with parents) if missing.
        monitor: the name of the value to track.
        filename_fcn: callable that receives the epoch and returns the checkpoint file name.
          Defaults to :func:`default_filename_fcn`.
        max_mode: if ``True`` (default, the historical behaviour) larger values are better,
          e.g. an accuracy. If ``False`` smaller values are better, e.g. a loss.

    **Usage example:**

    .. code:: python

        model_checkpoint = ModelCheckpoint(
            filepath="./outputs", monitor="top5",
        )

        trainer = ImageClassifierTrainer(...,
            callbacks={"on_checkpoint": model_checkpoint}
        )
    """

    def __init__(
        self,
        filepath: str,
        monitor: str,
        filename_fcn: Optional[Callable[..., str]] = None,
        max_mode: bool = True,
    ) -> None:
        self.filepath = filepath
        self.monitor = monitor
        self._filename_fcn = filename_fcn or default_filename_fcn
        self.max_mode = max_mode

        # track best model; start from the worst possible value so that the first valid
        # measurement is always saved, including metrics that are zero or negative
        self.best_metric: float = float("-inf") if max_mode else float("inf")

        # create directory
        Path(self.filepath).mkdir(parents=True, exist_ok=True)

    def __call__(self, model: Module, epoch: int, valid_metric: Dict[str, AverageMeter]) -> None:
        """Save ``model`` if the monitored metric improved on the best value seen so far.

        Args:
            model: the module to serialize.
            epoch: the current epoch, forwarded to ``filename_fcn`` to build the file name.
            valid_metric: mapping from metric name to its meter; ``valid_metric[monitor].avg`` is read.
        """
        valid_metric_value: float = valid_metric[self.monitor].avg

        if self.max_mode:
            improved = valid_metric_value > self.best_metric
        else:
            improved = valid_metric_value < self.best_metric
        if not improved:
            return

        # remember the new best value and save the corresponding model
        self.best_metric = valid_metric_value
        filename = Path(self.filepath) / self._filename_fcn(epoch)
        torch.save(model, filename)