Semantic segmentation¶

Semantic segmentation, or image segmentation, is the task of clustering together the parts of an image that belong to the same object class. It is a form of pixel-level prediction because each pixel in an image is assigned to a category. Example benchmarks for this task include Cityscapes, PASCAL VOC and ADE20K. Models are usually evaluated with the Mean Intersection-over-Union (Mean IoU) and Pixel Accuracy metrics.
Learn more: https://paperswithcode.com/task/semantic-segmentation
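As a rough illustration of the two metrics named above, the following sketch computes pixel accuracy and mean IoU on toy label maps stored as nested lists. The helper names `pixel_accuracy` and `mean_iou` are hypothetical and are not part of Kornia; benchmark implementations typically accumulate a confusion matrix over a whole dataset and may treat ignored labels differently.

```python
from collections import Counter


def pixel_accuracy(pred, target):
    """Fraction of pixels whose predicted class equals the ground-truth class."""
    pairs = [(p, t) for p_row, t_row in zip(pred, target) for p, t in zip(p_row, t_row)]
    return sum(p == t for p, t in pairs) / len(pairs)


def mean_iou(pred, target, num_classes):
    """Average of the per-class IoU, skipping classes absent from both maps."""
    # Count how often each (predicted, true) class pair occurs.
    pairs = Counter((p, t) for p_row, t_row in zip(pred, target) for p, t in zip(p_row, t_row))
    ious = []
    for c in range(num_classes):
        intersection = pairs[(c, c)]
        pred_area = sum(n for (p, _), n in pairs.items() if p == c)
        target_area = sum(n for (_, t), n in pairs.items() if t == c)
        union = pred_area + target_area - intersection
        if union > 0:
            ious.append(intersection / union)
    return sum(ious) / len(ious)


# Toy 2x3 label maps with three classes (0, 1, 2).
target = [[0, 0, 1],
          [2, 2, 1]]
pred = [[0, 1, 1],
        [2, 2, 2]]

print(pixel_accuracy(pred, target))            # 4 of 6 pixels match
print(mean_iou(pred, target, num_classes=3))   # mean of 1/2, 1/3 and 2/3
```

Mean IoU penalises errors on small classes more than pixel accuracy does, which is one reason the two are usually reported together.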
Finetuning¶
To customize a model with your own data, you can fine-tune it with our Training API (experimental).
We provide SemanticSegmentationTrainer, which has a default training structure for semantic segmentation problems. You can use this API with the models provided by Kornia or with existing libraries from the PyTorch ecosystem such as torchvision.
Create the dataloaders and transforms:
import hydra
import numpy as np
import torch
import torch.nn as nn
import torchvision
from hydra.utils import to_absolute_path

import kornia as K
from kornia.x import Configuration, Lambda, SemanticSegmentationTrainer


class Transform(nn.Module):
    def __init__(self, image_size):
        super().__init__()
        self.resize = K.geometry.Resize(image_size, interpolation='nearest')

    @torch.no_grad()
    def forward(self, x, y):
        # image: HxWxC array -> CxHxW float tensor scaled to [0, 1]
        x = K.utils.image_to_tensor(np.array(x))
        # target: expected to arrive as a NumPy array
        x, y = x.float() / 255.0, torch.from_numpy(y)
        return self.resize(x), self.resize(y)


@hydra.main(config_path=".", config_name="config.yaml")
def my_app(config: Configuration) -> None:
    # make image size homogeneous
    transform = Transform(tuple(config.image_size))

    # create the dataset
    train_dataset = torchvision.datasets.SBDataset(
        root=to_absolute_path(config.data_path), image_set='train', download=False, transforms=transform
    )
    valid_dataset = torchvision.datasets.SBDataset(
        root=to_absolute_path(config.data_path), image_set='val', download=False, transforms=transform
    )

    # create the dataloaders
    train_dataloader = torch.utils.data.DataLoader(
        train_dataset, batch_size=config.batch_size, shuffle=True, num_workers=8, pin_memory=True
    )
    valid_daloader = torch.utils.data.DataLoader(
        valid_dataset, batch_size=config.batch_size, shuffle=True, num_workers=8, pin_memory=True
    )
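The transform above resizes with `interpolation='nearest'`. A likely motivation is that segmentation targets are categorical, so any interpolation that blends neighbouring values can produce labels that do not correspond to a real class. The sketch below illustrates the idea on a toy label map using plain Python lists; `resize_nearest` and `downsample_mean` are hypothetical helpers written for this illustration and do not reproduce how Kornia implements resizing.

```python
def resize_nearest(mask, out_h, out_w):
    """Nearest-neighbour resize of a 2D label map stored as nested lists."""
    in_h, in_w = len(mask), len(mask[0])
    return [
        [mask[(i * in_h) // out_h][(j * in_w) // out_w] for j in range(out_w)]
        for i in range(out_h)
    ]


def downsample_mean(mask, factor):
    """Average each factor x factor block, a stand-in for any smoothing interpolation."""
    out_h, out_w = len(mask) // factor, len(mask[0]) // factor
    return [
        [
            sum(mask[i * factor + di][j * factor + dj]
                for di in range(factor) for dj in range(factor)) / factor ** 2
            for j in range(out_w)
        ]
        for i in range(out_h)
    ]


# Toy 4x4 label map containing only classes 0 and 4.
mask = [[0, 4, 4, 4],
        [0, 4, 4, 4],
        [0, 0, 0, 4],
        [0, 0, 0, 4]]

print(resize_nearest(mask, 2, 2))   # only labels that already exist in the mask
print(downsample_mean(mask, 2))     # averaging invents the value 2.0, a different class
```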
Define your model, losses, optimizers and schedulers:
    # (continues inside my_app)
    # create the loss function
    criterion = nn.CrossEntropyLoss()

    # create the model
    model = nn.Sequential(torchvision.models.segmentation.fcn_resnet50(pretrained=False), Lambda(lambda x: x['out']))

    # instantiate the optimizer and scheduler
    optimizer = torch.optim.AdamW(model.parameters(), lr=config.lr)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, config.num_epochs * len(train_dataloader))
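The scheduler above receives `config.num_epochs * len(train_dataloader)` as its period, which suggests the schedule is meant to span every optimisation step of the run; whether the trainer actually steps the scheduler once per batch is an assumption worth checking. The sketch below evaluates the closed-form cosine annealing curve documented for PyTorch's `CosineAnnealingLR` (without restarts). The function name and the numeric values are hypothetical stand-ins for the real configuration.

```python
import math


def cosine_annealing_lr(step, t_max, base_lr, eta_min=0.0):
    """Closed-form cosine annealing: base_lr at step 0, eta_min at step t_max."""
    return eta_min + 0.5 * (base_lr - eta_min) * (1.0 + math.cos(math.pi * step / t_max))


# Hypothetical values standing in for config.num_epochs, len(train_dataloader) and config.lr.
num_epochs, steps_per_epoch, base_lr = 10, 100, 1e-3
t_max = num_epochs * steps_per_epoch

# Learning rate at the start, the midpoint and the end of training.
for step in (0, t_max // 2, t_max):
    print(step, cosine_annealing_lr(step, t_max, base_lr))
```

The rate starts at the base value, passes through roughly half of it at the midpoint, and decays to `eta_min` at the final step.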
Create your preprocessing and augmentations pipeline:
    # define some augmentations
    _augmentations = K.augmentation.AugmentationSequential(
        K.augmentation.RandomHorizontalFlip(p=0.75),
        K.augmentation.RandomVerticalFlip(p=0.75),
        K.augmentation.RandomAffine(degrees=10.0),
        data_keys=['input', 'mask'],
    )

    def preprocess(self, sample: dict) -> dict:
        target = sample["target"].argmax(1).unsqueeze(1).float()
        return {"input": sample["input"], "target": target}

    def augmentations(self, sample: dict) -> dict:
        x, y = _augmentations(sample["input"], sample["target"])
        # NOTE: use matplotlib to visualise before/after samples
        return {"input": x, "target": y}

    def on_before_model(self, sample: dict) -> dict:
        target = sample["target"].squeeze(1).long()
        return {"input": sample["input"], "target": target}
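The augmentation pipeline above declares `data_keys=['input', 'mask']`, so the image and its mask are passed through together. The point of doing so is that geometric augmentations must use the same random decision for both, otherwise the labels no longer line up with the pixels. The following sketch shows that idea for a horizontal flip on toy nested lists; `hflip` and `random_hflip_pair` are hypothetical helpers and not Kornia functions.

```python
import random


def hflip(grid):
    """Mirror a 2D grid (nested lists) left to right."""
    return [list(reversed(row)) for row in grid]


def random_hflip_pair(image, mask, p, rng):
    """Flip image and mask together with probability p, using ONE random draw."""
    if rng.random() < p:
        return hflip(image), hflip(mask)
    return image, mask


image = [[10, 20, 30],
         [40, 50, 60]]
mask = [[0, 0, 1],
        [0, 1, 1]]

rng = random.Random(0)  # seeded so the example is reproducible
aug_image, aug_mask = random_hflip_pair(image, mask, p=0.75, rng=rng)
print(aug_image)
print(aug_mask)
```

Drawing the random number once and applying the result to both tensors keeps every mask pixel attached to the image pixel it labels.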
Finally, instantiate the SemanticSegmentationTrainer and execute your training pipeline.
    trainer = SemanticSegmentationTrainer(
        model,
        train_dataloader,
        valid_daloader,
        criterion,
        optimizer,
        scheduler,
        config,
        callbacks={"preprocess": preprocess, "augmentations": augmentations, "on_before_model": on_before_model},
    )
    # run the training loop
    trainer.fit()
See also
Play with the full example here