.. _training_api:

Training API (experimental)
===========================

Kornia provides a Training API with the specific purpose of training and fine-tuning the
supported deep learning algorithms within the library.

.. sidebar:: **Deep Alchemy**

   .. image:: https://github.com/kornia/data/raw/main/pixie_alchemist.png
      :width: 100%
      :align: center

   A seemingly magical process of transformation, creation, or combination of data into usable deep learning models.

.. important::

   To use our Training API you must install the extra dependencies: ``pip install kornia[x]``

Why a Training API?
-------------------

Kornia includes deep learning models that eventually need to be updated through fine-tuning.
Our aim is to have an API flexible enough to be used across our vision models, and to let you
override methods or dynamically pass callbacks to ease debugging and experimentation.

.. admonition:: **Disclaimer**
   :class: seealso

   We do not aim to be a general-purpose training library; instead, we let Kornia users
   experiment with the training of our models.

Design Principles
-----------------

- The `kornia` golden rule is to not have heavy dependencies.
- Our models are simple enough that a light training API can fulfill our needs.
- Flexible and full control over the training/validation loops to customize the pipeline.
- Decouple the model definition from the training pipeline.
- Use plain PyTorch abstractions and recipes to write your own routines.
- Rely on the `accelerate `_ library to scale the problem.

Trainer Usage
-------------

The entry point to start training with Kornia is the :py:class:`~kornia.x.Trainer` class.

The main API is a self-contained module that heavily relies on `accelerate `_
to easily scale the training over multi-GPUs/TPU/fp16 `(see more) `_
by following standard PyTorch recipes. Our API expects to consume standard PyTorch
components, and you decide if `kornia` makes the magic for you.

1. Define your model

   .. code:: python

      import kornia
      import torch.nn as nn

      model = nn.Sequential(
          kornia.contrib.VisionTransformer(image_size=32, patch_size=16),
          kornia.contrib.ClassificationHead(num_classes=10),
      )

2. Create the datasets and dataloaders for training and validation

   .. code:: python

      import torch
      import torchvision
      import torchvision.transforms as T

      # ``config`` holds your hyperparameters (data_path, batch_size, lr, num_epochs),
      # e.g. an instance of kornia.x.Configuration

      # datasets
      train_dataset = torchvision.datasets.CIFAR10(
          root=config.data_path, train=True, download=True, transform=T.ToTensor())

      valid_dataset = torchvision.datasets.CIFAR10(
          root=config.data_path, train=False, download=True, transform=T.ToTensor())

      # dataloaders (the validation set does not need to be shuffled)
      train_dataloader = torch.utils.data.DataLoader(
          train_dataset, batch_size=config.batch_size, shuffle=True)

      valid_dataloader = torch.utils.data.DataLoader(
          valid_dataset, batch_size=config.batch_size, shuffle=False)

3. Create your loss function, optimizer and scheduler

   .. code:: python

      # loss function
      criterion = nn.CrossEntropyLoss()

      # optimizer and scheduler
      optimizer = torch.optim.AdamW(model.parameters(), lr=config.lr)
      scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
          optimizer, config.num_epochs * len(train_dataloader)
      )

4. Create the Trainer and execute the training pipeline

   .. code:: python

      trainer = kornia.x.Trainer(
          model, train_dataloader, valid_dataloader, criterion, optimizer, scheduler, config,
      )
      trainer.fit()  # execute your training!

Customize [callbacks]
---------------------

At this point you might think: *is this API generic enough?*

Of course not! What is next? Let's have fun and **customize**.

The :py:class:`~kornia.x.Trainer` internals are defined in such a way that you can, e.g.,
subclass it and override just the :py:func:`~kornia.x.Trainer.evaluate` method to adapt it
to your needs. We provide predefined classes for generic problems such as
:py:class:`~kornia.x.ImageClassifierTrainer` and :py:class:`~kornia.x.SemanticSegmentationTrainer`.

.. note::

   More trainers will come as soon as we include more models.

You can easily customize by creating your own class, or even through ``callbacks`` as follows:

.. code:: python

   import torch
   import kornia as K
   from kornia.metrics import AverageMeter, accuracy

   @torch.no_grad()
   def my_evaluate(self) -> dict:
       self.model.eval()
       stats = {"loss": AverageMeter(), "top1": AverageMeter(), "top5": AverageMeter()}
       for sample_id, sample in enumerate(self.valid_dataloader):
           source, target = sample  # this might change with the new PyTorch dataset structure

           # perform the preprocess and augmentations in batch
           img = self.preprocess(source)
           # forward
           out = self.model(img)
           # loss computation
           val_loss = self.criterion(out, target)

           # measure accuracy and record loss
           acc1, acc5 = accuracy(out.detach(), target, topk=(1, 5))
           batch_size = len(target)
           stats["loss"].update(val_loss.item(), batch_size)
           stats["top1"].update(acc1.item(), batch_size)
           stats["top5"].update(acc5.item(), batch_size)
       return stats

   # create the trainer and pass the evaluate method as follows
   trainer = K.x.Trainer(..., callbacks={"evaluate": my_evaluate})

**Still not convinced?** You can even override the whole
:py:func:`~kornia.x.ImageClassifierTrainer.fit()` method and implement your own loops: the
trainer uses the Accelerator to set up the model and the data on the right device for you,
and the rest of the story is just PyTorch :)

.. code:: python

   def my_fit(self):
       # this is a custom PyTorch training loop
       for epoch in range(self.num_epochs):
           self.model.train()  # re-enable training mode, since evaluate() switches to eval mode
           for source, targets in self.train_dataloader:
               self.optimizer.zero_grad()

               output = self.model(source)
               loss = self.criterion(output, targets)

               self.backward(loss)
               self.optimizer.step()

           stats = self.evaluate()  # do whatever you want with validation

   # create the trainer and pass the fit method as follows
   trainer = K.x.Trainer(..., callbacks={"fit": my_fit})

.. note::

   The following hooks are available to override: ``preprocess``, ``augmentations``,
   ``evaluate``, ``fit``, ``on_checkpoint``, ``on_epoch_end``, ``on_before_model``

Preprocess and augmentations
----------------------------

Taking a pre-trained model from an external source and assuming that you can fine-tune it on
your data by changing just a few things in your model is usually a bad assumption in practice.
Fine-tuning a model needs a lot of tricks, which usually means designing a good augmentation or
preprocessing strategy before you execute the training pipeline.
For this reason, callbacks let you pass your own ``preprocess`` and ``augmentations``
functions, which makes debugging and experimentation easier.

.. code:: python

   import torch.nn as nn
   import kornia as K

   def preprocess(x):
       # assumes uint8 images in [0, 255]; drop the division if your dataset
       # already yields floats in [0, 1] (e.g. with ``T.ToTensor()``)
       return x.float() / 255.

   augmentations = nn.Sequential(
       K.augmentation.RandomHorizontalFlip(p=0.75),
       K.augmentation.RandomVerticalFlip(p=0.75),
       K.augmentation.RandomAffine(degrees=10.),
       K.augmentation.PatchSequential(
           K.augmentation.ColorJiggle(0.1, 0.1, 0.1, 0.1, p=0.8),
           grid_size=(2, 2),  # CIFAR-10 images are 32x32 and the ViT patch size is 16
           patchwise_apply=False,
       ),
   )

   # create the trainer and pass the augmentations and preprocess
   trainer = K.x.ImageClassifierTrainer(
       ..., callbacks={"preprocess": preprocess, "augmentations": augmentations},
   )

Callbacks utilities
-------------------

We also provide utilities to save model checkpoints and to stop the training early. To use
them, pass instances of :py:class:`~kornia.x.ModelCheckpoint` and
:py:class:`~kornia.x.EarlyStopping` as ``callbacks``. Here both monitor the ``"loss"`` entry
of the statistics returned by ``evaluate``:

.. code:: python

   import torch
   from kornia.metrics import AverageMeter
   from kornia.x import EarlyStopping, ModelCheckpoint, SemanticSegmentationTrainer

   @torch.no_grad()
   def my_evaluate(self) -> dict:
       self.model.eval()
       stats = {"loss": AverageMeter()}
       for sample in self.valid_dataloader:
           # ... rest of your evaluation loop
           out = self.on_model(self.model, sample)
           val_loss = self.compute_loss(out, sample["mask"])
           stats["loss"].update(val_loss.item(), len(sample["mask"]))
       return stats

   model_checkpoint = ModelCheckpoint(filepath="./outputs", monitor="loss")
   early_stop = EarlyStopping(monitor="loss", patience=2)

   trainer = SemanticSegmentationTrainer(
       ...,
       callbacks={"evaluate": my_evaluate, "on_epoch_end": early_stop, "on_checkpoint": model_checkpoint},
   )

.. note::

   Considering the code above, the training would look something like this:

   | best_loss = inf
   | epoch 1:
   |     avg_loss = 1.0
   |     best_loss = 1.0
   |     checkpoint = ./outputs/model_epoch=1_metricValue=1.0.pt
   | epoch 2:
   |     avg_loss = 0.9
   |     best_loss = 0.9
   |     checkpoint = ./outputs/model_epoch=2_metricValue=0.9.pt
   | epoch 3:
   |     avg_loss = 1.1
   |     early_stop.counter += 1
   | epoch 4:
   |     avg_loss = 1.2
   |     early_stop.counter += 1
   |     early_stop.patience == early_stop.counter
   |     training ends here

Hyperparameter sweeps
---------------------

Use `hydra `_ to implement an easy search strategy for your hyperparameters as follows:

.. note::

   Check out the toy example `here `__.

Override single values from the command line:

.. code:: bash

   python ./train/image_classifier/main.py num_epochs=50 batch_size=32

or sweep over several values with ``--multirun`` (one run per value):

.. code:: bash

   python ./train/image_classifier/main.py --multirun lr=1e-3,1e-4

Distributed Training
--------------------

The Kornia :py:class:`~kornia.x.Trainer` heavily relies on `accelerate `_ to decouple
your training scripts from the process of running them in a distributed environment.

.. note::

   We have not yet tested all the possibilities for distributed training. Expect some
   adventures or `join us `_ and help us iterate :)

The recipes below are taken from the `accelerate` library `here `__:

- single CPU:

  * from a server without GPU

    .. code:: bash

       python ./train/image_classifier/main.py

  * from any server by passing `cpu=True` to the `Accelerator`.

    .. code:: bash

       python ./train/image_classifier/main.py --data_path path_to_data --cpu

  * from any server with Accelerate launcher

    .. code:: bash

       accelerate launch --cpu ./train/image_classifier/main.py --data_path path_to_data

- single GPU:

  .. code:: bash

     python ./train/image_classifier/main.py  # from a server with a GPU

- with fp16 (mixed-precision)

  * from any server by passing `fp16=True` to the `Accelerator`.

    .. code:: bash

       python ./train/image_classifier/main.py --data_path path_to_data --fp16

  * from any server with Accelerate launcher

    .. code:: bash

       accelerate launch --fp16 ./train/image_classifier/main.py --data_path path_to_data

- multi GPUs (using PyTorch distributed mode)

  * With Accelerate config and launcher

    .. code:: bash

       accelerate config  # This will create a config file on your server
       accelerate launch ./train/image_classifier/main.py --data_path path_to_data  # This will run the script on your server

  * With traditional PyTorch launcher

    .. code:: bash

       python -m torch.distributed.launch --nproc_per_node 2 --use_env ./train/image_classifier/main.py --data_path path_to_data

- multi GPUs, multi node (several machines, using PyTorch distributed mode)

  * With Accelerate config and launcher, on each machine:

    .. code:: bash

       accelerate config  # This will create a config file on each server
       accelerate launch ./train/image_classifier/main.py --data_path path_to_data  # This will run the script on each server

  * With PyTorch launcher only

    .. code:: bash

       python -m torch.distributed.launch --nproc_per_node 2 \
           --nnodes 2 \
           --use_env \
           --node_rank 0 \
           --master_addr master_node_ip_address \
           ./train/image_classifier/main.py --data_path path_to_data  # On the first server

       python -m torch.distributed.launch --nproc_per_node 2 \
           --nnodes 2 \
           --use_env \
           --node_rank 1 \
           --master_addr master_node_ip_address \
           ./train/image_classifier/main.py --data_path path_to_data  # On the second server

- (multi) TPUs

  * With Accelerate config and launcher

    .. code:: bash

       accelerate config  # This will create a config file on your TPU server
       accelerate launch ./train/image_classifier/main.py --data_path path_to_data  # This will run the script on each server

  * In PyTorch: add an `xmp.spawn` line in your script as you usually do.
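
As a rough illustration of what relying on accelerate means, the sketch below adapts a plain
PyTorch loop with accelerate's ``Accelerator``: ``prepare`` moves the model, optimizer and
dataloader to the right device(s) and ``accelerator.backward`` replaces ``loss.backward()``.
It is a minimal sketch assuming the ``accelerate`` package is installed, not the actual
:py:class:`~kornia.x.Trainer` implementation; ``train_with_accelerate``, the random data and
the linear model are hypothetical stand-ins for your script, CIFAR-10 and the ViT classifier.

.. code:: python

   import torch
   import torch.nn as nn
   from accelerate import Accelerator
   from torch.utils.data import DataLoader, TensorDataset


   def train_with_accelerate(num_epochs: int = 2, cpu: bool = True):
       """Hypothetical helper: a plain PyTorch loop adapted to accelerate."""
       torch.manual_seed(0)
       # cpu=True forces CPU execution, like the ``--cpu`` recipes above
       accelerator = Accelerator(cpu=cpu)

       # toy stand-ins for the dataset, model, loss and optimizer
       dataset = TensorDataset(torch.randn(64, 8), torch.randint(0, 3, (64,)))
       dataloader = DataLoader(dataset, batch_size=16, shuffle=True)
       model = nn.Linear(8, 3)
       criterion = nn.CrossEntropyLoss()
       optimizer = torch.optim.AdamW(model.parameters(), lr=1e-2)

       # place everything on the right device(s); wrapped for DDP/TPU when launched that way
       model, optimizer, dataloader = accelerator.prepare(model, optimizer, dataloader)

       num_steps, last_loss = 0, float("nan")
       for _ in range(num_epochs):
           model.train()
           for source, target in dataloader:
               optimizer.zero_grad()
               loss = criterion(model(source), target)
               accelerator.backward(loss)  # replaces loss.backward()
               optimizer.step()
               num_steps += 1
               last_loss = loss.item()
       return last_loss, num_steps


   if __name__ == "__main__":
       loss, steps = train_with_accelerate()
       print(f"steps={steps} last_loss={loss:.4f}")

With ``cpu=False``, the idea is that the same function can be started with any of the launch
recipes above, and accelerate decides which devices and how many processes it runs on.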