COCO dataset example using batch tensors¶
Introduction¶
In this example, we will show you how to perform transformations for object detection using Torchaug through batch tensors that stacks images, boxes and masks. This can only be used when images are of the same shape which requires former resizing.
We also make qualitative comparison with Torchvision that cannot handle batch transforms or nested tensors natively. However, Torchaug follows Torchvision’s implementations which lets you chose the transform library of your choice.
Set Up¶
Imports¶
Here we import all the required modules for the notebook.
import pathlib
import random
import matplotlib.pyplot as plt
import numpy as np
import torch
import torchvision.transforms.v2.functional as TVF
from torch.utils.data import DataLoader
from torchvision import datasets, tv_tensors
from torchvision.transforms import v2 as tv_transforms
from torchvision.tv_tensors._dataset_wrapper import wrap_dataset_for_transforms_v2 as tv_wrap_dataset_for_transforms_v2
from torchvision.utils import draw_bounding_boxes, draw_segmentation_masks
from torchaug import ta_tensors
from torchaug import transforms as ta_transforms
from torchaug.data.dataloader import default_collate
from torchaug.data.dataset import wrap_dataset_for_transforms_v2
Utils functions¶
Here we define utilities functions for plotting and handling various data structure.
#### visualization function tools
def plot(imgs, row_title=None, **imshow_kwargs):
if not isinstance(imgs[0], list):
# Make a 2d grid even if there's just 1 row
imgs = [imgs]
num_rows = len(imgs)
num_cols = len(imgs[0])
_, axs = plt.subplots(nrows=num_rows, ncols=num_cols, figsize=(10, 5), squeeze=False)
for row_idx, row in enumerate(imgs):
for col_idx, img in enumerate(row):
boxes = None
masks = None
if isinstance(img, tuple):
img, target = img
if isinstance(target, dict):
boxes = target.get("boxes")
masks = target.get("masks")
elif isinstance(target, (tv_tensors.BoundingBoxes, ta_tensors.BoundingBoxes)):
boxes = target
else:
raise ValueError(f"Unexpected target type: {type(target)}")
img = TVF.to_image(img)
if img.dtype.is_floating_point and img.min() < 0:
# Poor man's re-normalization for the colors to be OK-ish. This
# is useful for images coming out of Normalize()
img -= img.min()
img /= img.max()
img = TVF.to_dtype(img, torch.uint8, scale=True)
if boxes is not None:
img = draw_bounding_boxes(img, boxes, colors="yellow", width=3)
if masks is not None:
img = draw_segmentation_masks(img, masks.to(torch.bool), colors=["green"] * masks.shape[0], alpha=0.65)
ax = axs[row_idx, col_idx]
ax.imshow(img.permute(1, 2, 0).numpy(), **imshow_kwargs)
ax.set(xticklabels=[], yticklabels=[], xticks=[], yticks=[])
if row_title is not None:
for row_idx in range(num_rows):
axs[row_idx, 0].set(ylabel=row_title[row_idx])
plt.tight_layout()
def uncollate_batch(batch):
imgs = batch[0]
targets = batch[1:][0]
decollate_imgs = imgs.to_samples()
for key, elems in targets.items():
if key == "boxes":
decollate_bbox = elems.to_samples()
elif key == "masks":
decollate_mask = elems.to_samples()
elif key == "labels":
decollate_labels = elems.to_samples()
decollate = [
(
decollate_imgs[i],
{
"boxes": decollate_bbox[i],
"masks": decollate_mask[i],
"labels": decollate_labels[i],
},
)
for i in range(len(decollate_imgs))
]
return decollate
Dataset creations¶
Configurate paths for the COCO dataset.
ROOT = pathlib.Path("your_path") / "coco" # replace by your path
IMAGES_PATH = str(ROOT / "val2017")
ANNOTATIONS_PATH = str(ROOT / "annotations" / "instances_val2017.json")
Torchaug¶
We define the dataloaders for batched tensors that requires resizing in the dataset in order to stack the tensors:
torchaug_batch_dataset = datasets.CocoDetection(IMAGES_PATH, ANNOTATIONS_PATH, transforms=ta_transforms.Resize([224, 224]))
torchaug_batch_dataset = wrap_dataset_for_transforms_v2(torchaug_batch_dataset, target_keys=("boxes", "labels", "masks"))
torchaug_batch_dataloader = DataLoader(torchaug_batch_dataset, batch_size=5, collate_fn=default_collate)
Torchvision¶
torchvision_dataset = datasets.CocoDetection(
IMAGES_PATH, ANNOTATIONS_PATH, transforms=tv_transforms.Resize([224, 224])
)
torchvision_dataset = tv_wrap_dataset_for_transforms_v2(torchvision_dataset, target_keys=("boxes", "labels", "masks"))
Transform batch¶
seed = 203 # set seed for reproducibility
Torchaug¶
We define the following transformations:
Horizontal flip
Color jittering
RandomResizedCrop
torch.manual_seed(seed)
random.seed(seed)
np.random.seed(seed)
torchaug_transform = ta_transforms.SequentialTransform(
[
ta_transforms.RandomHorizontalFlip(p=0.5),
ta_transforms.ColorJitter(0.4, 0.4, 0.4, 0.2),
ta_transforms.RandomResizedCrop(224),
],
inplace=True,
batch_inplace=True,
batch_transform=True,
num_chunks=2,
permute_chunks=False,
)
batch = next(iter(torchaug_batch_dataloader))
torchaug_transformed_batch = torchaug_transform(batch)
torchaug_transformed_batch_uncollated = uncollate_batch(torchaug_transformed_batch)
sanitized_torchaug_batch = ta_transforms.SanitizeBoundingBoxes()(torchaug_transformed_batch)
sanitized_torchaug_batch_uncollated = uncollate_batch(sanitized_torchaug_batch)
Torchvision¶
To perform the same random transforms as Torchaug we override the sampling of parameters for the data augmentations.
torch.manual_seed(seed)
random.seed(seed)
np.random.seed(seed)
batch = next(iter(torchaug_batch_dataloader))
indices_flip = torchaug_transform.transforms[0]._get_indices_transform(5, "cpu")
params_color = torchaug_transform.transforms[1]._get_params(
batch[0], 2, (torch.tensor([0, 1, 2]), torch.tensor([3, 4]))
)
params_crop = torchaug_transform.transforms[2]._get_params(batch[0], 2, None)
torchvision_batch = [torchvision_dataset[i] for i in range(5)]
for idx in indices_flip:
torchvision_batch[idx] = tv_transforms.RandomHorizontalFlip(1)(torchvision_batch[idx])
for idxs, param_color in zip([[0, 1, 2], [3, 4]], params_color):
tv_transform_color = tv_transforms.ColorJitter(0.4, 0.4, 0.4, 0.2)
tv_transform_color._get_params = lambda x: param_color
for i, idx in enumerate(idxs):
tv_transform_color._get_params = lambda x: {
"fn_idx": param_color["fn_idx"],
"brightness_factor": float(param_color["brightness_factor"][i]),
"contrast_factor": float(param_color["contrast_factor"][i]),
"saturation_factor": float(param_color["saturation_factor"][i]),
"hue_factor": float(param_color["hue_factor"][i]),
}
torchvision_batch[idx] = tv_transform_color(torchvision_batch[idx])
for idxs, param_crop in zip([[0, 1, 2], [3, 4]], params_crop):
tv_transform_resized_crop = tv_transforms.RandomResizedCrop([224, 224])
tv_transform_resized_crop._get_params = lambda x: param_crop
for idx in idxs:
torchvision_batch[idx] = tv_transform_resized_crop(torchvision_batch[idx])
sanitized_torchvision_batch = [tv_transforms.SanitizeBoundingBoxes()(img) for img in torchvision_batch]
Visualization¶
Torchaug¶
Here we qualitatively visualize the results of Torchaug after transformations. Sanitized batch removed boxes and masks that are not valid.
plot(
[
[torchaug_batch_dataset[i] for i in range(5)],
torchaug_transformed_batch_uncollated,
sanitized_torchaug_batch_uncollated,
],
row_title=["torchaug dataset", "torchaug transformed", "torchaug sanitized"],
)

Torchaug and Torchvision¶
plot(
[torchaug_transformed_batch_uncollated, torchvision_batch],
row_title=["torchaug transformed", "torchvision transformed"],
)
plot(
[sanitized_torchaug_batch_uncollated, sanitized_torchvision_batch],
row_title=["torchaug sanitized", "torchvision sanitized"],
)

As we can see Torchaug provides the same augmentations as Torchvision which allows you to switch between frameworks !