Semantic segmentation


Semantic segmentation, sometimes loosely called image segmentation, is the task of grouping together the parts of an image that belong to the same object class. It is a form of pixel-level prediction: every pixel in an image is assigned a category. Common benchmarks for this task include Cityscapes, PASCAL VOC and ADE20K. Models are usually evaluated with the mean Intersection over Union (mIoU) and pixel accuracy metrics.
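As a rough illustration of the two metrics, the NumPy sketch below builds a confusion matrix from a predicted and a ground-truth label map, then derives pixel accuracy (correct pixels over all pixels) and mean IoU (per-class TP / (TP + FP + FN), averaged over the classes that appear). The helper names and toy arrays are hypothetical; benchmark toolkits also handle details such as ignore labels (for example the 255 "void" label in PASCAL VOC), which this sketch omits.

```python
import numpy as np


def confusion_matrix(pred: np.ndarray, target: np.ndarray, num_classes: int) -> np.ndarray:
    """Count (target, pred) pairs: rows are ground-truth classes, columns are predictions."""
    idx = target.reshape(-1) * num_classes + pred.reshape(-1)
    counts = np.bincount(idx, minlength=num_classes * num_classes)
    return counts.reshape(num_classes, num_classes)


def pixel_accuracy(cm: np.ndarray) -> float:
    # correctly classified pixels (the diagonal) over all pixels
    return float(np.trace(cm) / cm.sum())


def mean_iou(cm: np.ndarray) -> float:
    tp = np.diag(cm)
    fp = cm.sum(axis=0) - tp  # predicted as class c, but labelled as another class
    fn = cm.sum(axis=1) - tp  # labelled as class c, but predicted as another class
    union = tp + fp + fn
    valid = union > 0  # skip classes absent from both prediction and target
    return float((tp[valid] / union[valid]).mean())


# hypothetical 2x3 label maps with 3 classes
target = np.array([[0, 0, 1],
                   [1, 1, 2]])
pred = np.array([[0, 1, 1],
                 [1, 1, 2]])
cm = confusion_matrix(pred, target, num_classes=3)
print(pixel_accuracy(cm))  # 5 of 6 pixels correct -> 0.8333...
print(mean_iou(cm))        # IoUs 0.5, 0.75, 1.0 -> 0.75
```

Note that mIoU penalises the single mistake more than pixel accuracy does, because it weights every class equally regardless of how many pixels it covers.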

Learn more: https://paperswithcode.com/task/semantic-segmentation

Finetuning

To fine-tune a model on your own data, you can use our Training API (experimental).

We provide SemanticSegmentationTrainer, which implements a default training structure for semantic segmentation problems. The API can be used with the models provided by Kornia or with models from existing libraries in the PyTorch ecosystem, such as torchvision.

Create the dataloaders and transforms:

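    import numpy as np
    import torch
    import torch.nn as nn
    import torchvision
    import kornia as K
    from hydra.utils import to_absolute_path

    # `config` holds the experiment settings (image_size, data_path, batch_size, ...),
    # e.g. a Hydra config object

    class Transform(nn.Module):
        """Convert an (image, target) pair to tensors and resize both to a fixed size."""

        def __init__(self, image_size):
            super().__init__()
            # nearest-neighbour interpolation keeps mask values discrete
            self.resize = K.geometry.Resize(image_size, interpolation='nearest')

        @torch.no_grad()
        def forward(self, x, y):
            x = K.utils.image_to_tensor(np.array(x))  # PIL image -> CxHxW tensor
            x, y = x.float() / 255., torch.from_numpy(y)  # scale the image to [0, 1]
            return self.resize(x), self.resize(y)

    # make image size homogeneous
    transform = Transform(tuple(config.image_size))

    # create the datasets (SBD must already be present under config.data_path)
    train_dataset = torchvision.datasets.SBDataset(
        root=to_absolute_path(config.data_path), image_set='train', download=False, transforms=transform)

    valid_dataset = torchvision.datasets.SBDataset(
        root=to_absolute_path(config.data_path), image_set='val', download=False, transforms=transform)

    # create the dataloaders
    train_dataloader = torch.utils.data.DataLoader(
        train_dataset, batch_size=config.batch_size, shuffle=True, num_workers=8, pin_memory=True)

    # the validation set does not need to be shuffled
    valid_daloader = torch.utils.data.DataLoader(
        valid_dataset, batch_size=config.batch_size, shuffle=False, num_workers=8, pin_memory=True)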
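The Transform above resizes both the image and the mask with nearest-neighbour interpolation. For the mask this matters: a smooth resize averages neighbouring labels and can produce values that are not valid classes. The NumPy sketch below uses hypothetical helpers (not Kornia's implementation) to contrast a nearest-neighbour resize with a box-filter average on a toy label map.

```python
import numpy as np


def resize_nearest(mask: np.ndarray, out_h: int, out_w: int) -> np.ndarray:
    """Nearest-neighbour resize: every output pixel copies exactly one input pixel."""
    in_h, in_w = mask.shape
    rows = (np.arange(out_h) * in_h) // out_h
    cols = (np.arange(out_w) * in_w) // out_w
    return mask[rows[:, None], cols[None, :]]


def resize_average(mask: np.ndarray, factor: int) -> np.ndarray:
    """Box-filter downsampling, standing in for a smooth (e.g. bilinear or area) resize."""
    h, w = mask.shape
    blocks = mask.reshape(h // factor, factor, w // factor, factor).astype(float)
    return blocks.mean(axis=(1, 3))


# hypothetical 4x4 label map with classes 0 and 3
mask = np.array([[0, 0, 3, 3],
                 [0, 0, 3, 3],
                 [0, 3, 3, 3],
                 [3, 3, 3, 3]])
print(resize_nearest(mask, 2, 2))  # [[0 3] [0 3]]: only labels that already existed
print(resize_average(mask, 2))     # [[0. 3.] [2.25 3.]]: 2.25 is not a class
```

For the RGB image on its own, a smoother mode such as bilinear is often preferred; the example simply reuses the same Resize module for both tensors.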

Define your model, loss, optimizer and scheduler:

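    import torch
    import torch.nn as nn
    import torchvision
    from kornia.x import Lambda

    # create the loss function: takes (B, C, H, W) logits and (B, H, W) integer targets
    criterion = nn.CrossEntropyLoss()

    # create the model; torchvision segmentation models return a dict, so keep only 'out'
    model = nn.Sequential(
        torchvision.models.segmentation.fcn_resnet50(weights=None),  # torchvision >= 0.13
        Lambda(lambda x: x['out']),
    )

    # instantiate the optimizer and scheduler
    optimizer = torch.optim.AdamW(model.parameters(), lr=config.lr)
    # T_max is expressed in batches, which assumes the scheduler is stepped once per iteration
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer, config.num_epochs * len(train_dataloader))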
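The scheduler's period is config.num_epochs * len(train_dataloader), i.e. the total number of batches, which only makes sense if the scheduler is stepped once per iteration (an assumption about the trainer, not something this page states). When CosineAnnealingLR is the only thing changing the learning rate, PyTorch documents the closed form eta_t = eta_min + (eta_max - eta_min) * (1 + cos(pi * t / T_max)) / 2, sketched below with hypothetical numbers.

```python
import math


def cosine_lr(step: int, total_steps: int, base_lr: float, eta_min: float = 0.0) -> float:
    """Closed-form cosine annealing, as documented for CosineAnnealingLR."""
    return eta_min + 0.5 * (base_lr - eta_min) * (1 + math.cos(math.pi * step / total_steps))


# hypothetical settings: 10 epochs x 100 batches per epoch, lr = 1e-3
num_epochs, steps_per_epoch, base_lr = 10, 100, 1e-3
total = num_epochs * steps_per_epoch
for step in (0, total // 2, total):
    # roughly 1e-3 at the start, 5e-4 halfway, 0.0 at the end
    print(step, cosine_lr(step, total, base_lr))
```

If the scheduler were instead stepped once per epoch, T_max would have to be config.num_epochs for the learning rate to reach its minimum at the end of training.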

Create your preprocessing and augmentations pipeline:

    import kornia as K

    # define some augmentations; data_keys makes Kornia apply the same
    # geometric transform to the image ('input') and to the mask ('mask')
    _augmentations = K.augmentation.AugmentationSequential(
        K.augmentation.RandomHorizontalFlip(p=0.75),
        K.augmentation.RandomVerticalFlip(p=0.75),
        K.augmentation.RandomAffine(degrees=10.),
        data_keys=['input', 'mask']
    )

    # the callbacks below receive the trainer they are attached to as `self`
    def preprocess(self, sample: dict) -> dict:
        # (B, C, H, W) per-class target -> (B, 1, H, W) class-index map
        target = sample["target"].argmax(1).unsqueeze(1).float()
        return {"input": sample["input"], "target": target}

    def augmentations(self, sample: dict) -> dict:
        x, y = _augmentations(sample["input"], sample["target"])
        # NOTE: use matplotlib to visualise before/after samples
        return {"input": x, "target": y}

    def on_before_model(self, sample: dict) -> dict:
        # (B, 1, H, W) float mask -> (B, H, W) long, as nn.CrossEntropyLoss expects
        target = sample["target"].squeeze(1).long()
        return {"input": sample["input"], "target": target}
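These callbacks reshape the target in a round trip: preprocess turns the per-class target into a single-channel index map, the augmentations transform image and mask together, and on_before_model drops the channel axis and casts to integers, because nn.CrossEntropyLoss expects class indices of shape (B, H, W). The NumPy sketch below traces these shapes on a hypothetical random batch, with a plain horizontal flip standing in for the Kornia augmentations.

```python
import numpy as np

# hypothetical batch: B=2 samples, C=4 classes, 3x3 pixels, one channel per class
rng = np.random.default_rng(0)
labels = rng.integers(0, 4, size=(2, 3, 3))
one_hot = np.eye(4)[labels].transpose(0, 3, 1, 2)  # (B, C, H, W)

# like preprocess: class axis -> index map, kept as a 1-channel float mask (B, 1, H, W)
target = one_hot.argmax(1)[:, None].astype(np.float32)

# like augmentations: the same geometric op must hit image and mask (here a width flip)
image = rng.random((2, 3, 3, 3))
image_aug, target_aug = image[..., ::-1], target[..., ::-1]

# like on_before_model: drop the channel axis and cast to integer class indices (B, H, W)
target_final = target_aug[:, 0].astype(np.int64)

print(one_hot.shape, target.shape, target_final.shape)
# (2, 4, 3, 3) (2, 1, 3, 3) (2, 3, 3)
```

Keeping the mask as a float channel during augmentation presumably lets it pass through the same floating-point pipeline as the image; the integer cast happens only right before the loss.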

Finally, instantiate the SemanticSegmentationTrainer and execute your training pipeline.

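    from kornia.x import SemanticSegmentationTrainer

    # `model_checkpoint` is a checkpoint callback (e.g. a kornia.x.ModelCheckpoint
    # instance) created beforehand; it is not defined in this snippet
    trainer = SemanticSegmentationTrainer(
        model, train_dataloader, valid_daloader, criterion, optimizer, scheduler, config,
        callbacks={
            "preprocess": preprocess,
            "augmentations": augmentations,
            "on_before_model": on_before_model,
            # "on_after_model": on_after_model,
            "on_checkpoint": model_checkpoint,
        }
    )
    trainer.fit()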
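The callbacks dictionary hooks user functions into the training loop by name. The sketch below is a hypothetical, heavily simplified stand-in (ToyTrainer is not part of Kornia) that runs the hooks in the order their names suggest (preprocess, then augmentations, then on_before_model, then the model) and passes the trainer as the first argument, matching the callback signatures above. The real loop in kornia.x also computes the loss, backpropagates and runs the remaining callbacks; check its source for the exact behaviour.

```python
from typing import Callable, Dict, List, Tuple

Sample = Dict[str, object]
Hook = Callable[["ToyTrainer", Sample], Sample]


class ToyTrainer:
    """Hypothetical stand-in; NOT the kornia SemanticSegmentationTrainer source."""

    def __init__(self, model: Callable[[object], object], callbacks: Dict[str, Hook]):
        self.model = model
        self.callbacks = callbacks
        self.log: List[str] = []

    def _hook(self, name: str, sample: Sample) -> Sample:
        fn = self.callbacks.get(name)
        if fn is None:
            return sample  # in this toy, a missing hook leaves the sample unchanged
        self.log.append(name)
        return fn(self, sample)  # hooks receive the trainer as `self`

    def step(self, batch: Tuple[object, object]) -> object:
        sample: Sample = {"input": batch[0], "target": batch[1]}
        for name in ("preprocess", "augmentations", "on_before_model"):
            sample = self._hook(name, sample)
        self.log.append("model")
        return self.model(sample["input"])


def double_input(self: ToyTrainer, sample: Sample) -> Sample:
    return {"input": sample["input"] * 2, "target": sample["target"]}


trainer = ToyTrainer(model=lambda x: x + 1,
                     callbacks={"preprocess": double_input, "on_before_model": double_input})
print(trainer.step((3, 0)))  # (3 * 2 * 2) + 1 = 13
print(trainer.log)           # ['preprocess', 'on_before_model', 'model']
```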

See also

Play with the full example here