COCO dataset example

Set Up

Imports

import pathlib
import random

import matplotlib.pyplot as plt
import numpy as np
import torch
import torchvision.transforms.v2.functional as TVF
from torch.utils.data import DataLoader
from torchvision import datasets, tv_tensors
from torchvision.transforms import v2 as tv_transforms
from torchvision.tv_tensors._dataset_wrapper import wrap_dataset_for_transforms_v2 as tv_wrap_dataset_for_transforms_v2
from torchvision.utils import draw_bounding_boxes, draw_segmentation_masks

from torchaug import ta_tensors
from torchaug import transforms as ta_transforms
from torchaug.data.dataloader import default_collate
from torchaug.data.dataset import wrap_dataset_for_transforms_v2

Utility functions

#### visualization function tools
def plot(imgs, row_title=None, **imshow_kwargs):
    """Display images on a grid, overlaying boxes and masks when targets are given.

    Args:
        imgs: Either a flat list (rendered as a single row) or a list of lists
            (one inner list per row). Each entry is an image accepted by
            ``TVF.to_image`` or an ``(image, target)`` tuple. ``target`` is
            either a dict, whose optional ``"boxes"`` and ``"masks"`` entries
            are drawn, or a ``BoundingBoxes`` instance from torchvision or
            torchaug.
        row_title: Optional sequence holding one y-axis label per row.
        **imshow_kwargs: Forwarded to ``imshow`` for every image.

    Raises:
        ValueError: If a tuple entry carries a target of an unsupported type.
    """
    if not isinstance(imgs[0], list):
        # Make a 2d grid even if there's just 1 row
        imgs = [imgs]

    num_rows = len(imgs)
    num_cols = len(imgs[0])
    _, axs = plt.subplots(nrows=num_rows, ncols=num_cols, figsize=(10, 5), squeeze=False)
    for row_idx, row in enumerate(imgs):
        for col_idx, img in enumerate(row):
            boxes = None
            masks = None
            if isinstance(img, tuple):
                img, target = img
                if isinstance(target, dict):
                    boxes = target.get("boxes")
                    masks = target.get("masks")
                elif isinstance(target, (tv_tensors.BoundingBoxes, ta_tensors.BoundingBoxes)):
                    boxes = target
                else:
                    raise ValueError(f"Unexpected target type: {type(target)}")
            img = TVF.to_image(img)
            if img.dtype.is_floating_point and img.min() < 0:
                # Poor man's re-normalization for the colors to be OK-ish. This
                # is useful for images coming out of Normalize().
                # Computed out of place so that a tensor passed in by the
                # caller is never modified by plotting it.
                img = img - img.min()
                peak = img.max()
                if peak > 0:
                    # Skip the division for a constant image: it would be 0 / 0.
                    img = img / peak

            img = TVF.to_dtype(img, torch.uint8, scale=True)
            if boxes is not None:
                img = draw_bounding_boxes(img, boxes, colors="yellow", width=3)
            if masks is not None:
                img = draw_segmentation_masks(img, masks.to(torch.bool), colors=["green"] * masks.shape[0], alpha=0.65)

            ax = axs[row_idx, col_idx]
            # Move channels last, the layout imshow expects.
            ax.imshow(img.permute(1, 2, 0).numpy(), **imshow_kwargs)
            ax.set(xticklabels=[], yticklabels=[], xticks=[], yticks=[])

    if row_title is not None:
        for row_idx in range(num_rows):
            axs[row_idx, 0].set(ylabel=row_title[row_idx])

    plt.tight_layout()


def uncollate_batch(batch):
    """Split a collated ``(images, targets)`` batch into per-sample pairs.

    Args:
        batch: Sequence whose first item is the collated images and whose
            second item is a dict of collated targets. The images and every
            handled target entry must expose ``to_samples()``.

    Returns:
        A list with one ``(image, target)`` tuple per sample, where ``target``
        is a dict holding that sample's ``"boxes"``, ``"masks"`` and
        ``"labels"`` entries. A key that is absent from the collated targets is
        simply absent from each per-sample dict; any other key is ignored.
    """
    imgs, targets = batch[0], batch[1]
    sample_imgs = imgs.to_samples()

    # Only split the entries that are actually present, so that a dataset
    # wrapped without e.g. "masks" does not fail on an unbound variable.
    per_key_samples = {
        key: targets[key].to_samples() for key in ("boxes", "masks", "labels") if key in targets
    }

    return [
        (img, {key: samples[i] for key, samples in per_key_samples.items()})
        for i, img in enumerate(sample_imgs)
    ]

Dataset creation

# Root of the local COCO copy: replace "your_path" with your own location.
ROOT = pathlib.Path("your_path", "coco")
# Plain strings for the image directory and the annotation file under ROOT.
IMAGES_PATH = str(ROOT.joinpath("val2017"))
ANNOTATIONS_PATH = str(ROOT.joinpath("annotations", "instances_val2017.json"))

Torchaug

# COCO detection samples resized to 224x224 by torchaug, wrapped with torchaug's
# dataset wrapper and restricted to the boxes / labels / masks target entries.
torchaug_dataset = wrap_dataset_for_transforms_v2(
    datasets.CocoDetection(
        root=IMAGES_PATH,
        annFile=ANNOTATIONS_PATH,
        transforms=ta_transforms.Resize([224, 224]),
    ),
    target_keys=("boxes", "labels", "masks"),
)
# Batches of 5 samples, assembled with torchaug's collate function.
torchaug_dataloader = DataLoader(
    dataset=torchaug_dataset,
    collate_fn=default_collate,
    batch_size=5,
)

Torchvision

# Same COCO samples, resized by torchvision and wrapped with torchvision's own
# dataset wrapper, keeping the boxes / labels / masks target entries.
torchvision_dataset = tv_wrap_dataset_for_transforms_v2(
    datasets.CocoDetection(
        root=IMAGES_PATH,
        annFile=ANNOTATIONS_PATH,
        transforms=tv_transforms.Resize([224, 224]),
    ),
    target_keys=("boxes", "labels", "masks"),
)

Transform batch

seed = 203  # fixed value used to seed the torch, Python and NumPy RNGs for reproducibility

Torchaug

# Seed the torch, Python and NumPy RNGs before anything random happens.
torch.manual_seed(seed)
random.seed(seed)
np.random.seed(seed)

# Chain a horizontal flip, a color jitter and a resized crop to 224.
# NOTE(review): presumably `transforms_attributes_override` sets the listed
# attributes on each child transform (in-place, batch-wise, 2 chunks, no chunk
# permutation) -- confirm against the torchaug documentation.
torchaug_transform = ta_transforms.SequentialTransform(
    [
        ta_transforms.RandomHorizontalFlip(p=0.5),
        ta_transforms.ColorJitter(0.4, 0.4, 0.4, 0.2),
        ta_transforms.RandomResizedCrop(224),
    ],
    transforms_attributes_override={
        "inplace": True,
        "batch_inplace": True,
        "batch_transform": True,
        "num_chunks": 2,
        "permute_chunks": False,
    },
)

# Take the first batch of the loader and run the whole pipeline on it.
batch = next(iter(torchaug_dataloader))
torchaug_transformed_batch = torchaug_transform(batch)
# Split the transformed batch back into per-sample (image, target) pairs.
torchaug_transformed_batch_uncollated = uncollate_batch(torchaug_transformed_batch)
# Run SanitizeBoundingBoxes on the transformed batch, then split it the same way.
sanitized_torchaug_batch = ta_transforms.SanitizeBoundingBoxes()(torchaug_transformed_batch)
sanitized_torchaug_batch_uncollated = uncollate_batch(sanitized_torchaug_batch)

Torchvision

# Re-seed with the same value as the Torchaug section so the random draws made
# below start from the same RNG state.
torch.manual_seed(seed)
random.seed(seed)
np.random.seed(seed)

# Fetch the batch again right after seeding, mirroring the Torchaug section.
# Keep this before the sampling calls below: their order is part of the replay.
batch = next(iter(torchaug_dataloader))

# The two groups of sample indices handed to torchaug's color sampling below and
# reused when replaying the color and crop parameters with torchvision.
chunk_indices = [[0, 1, 2], [3, 4]]

# Ask the torchaug transforms (through their private sampling hooks) for the
# flip indices and the per-chunk color / crop parameters.
indices_flip = torchaug_transform.transforms[0]._get_indices_transform(5, "cpu")
params_color = torchaug_transform.transforms[1]._get_params(
    batch[0], 2, tuple(torch.tensor(chunk) for chunk in chunk_indices)
)
params_crop = torchaug_transform.transforms[2]._get_params(batch[0], 2, None)

# Replay those decisions sample by sample with torchvision.
# NOTE(review): `_get_params` is a private torchvision hook; newer torchvision
# releases may expose it under another name (e.g. `make_params`), in which case
# the overrides below would be ignored -- confirm against the installed version.
torchvision_batch = [torchvision_dataset[i] for i in range(5)]

# Flip exactly the samples selected above (p=1 makes the flip unconditional).
tv_transform_flip = tv_transforms.RandomHorizontalFlip(1)
for idx in indices_flip:
    torchvision_batch[idx] = tv_transform_flip(torchvision_batch[idx])

for idxs, param_color in zip(chunk_indices, params_color):
    tv_transform_color = tv_transforms.ColorJitter(0.4, 0.4, 0.4, 0.2)
    for i, idx in enumerate(idxs):
        # Pick the i-th factor of the chunk for this sample; the function order
        # ("fn_idx") is shared by the whole chunk.
        sample_params = {
            "fn_idx": param_color["fn_idx"],
            "brightness_factor": float(param_color["brightness_factor"][i]),
            "contrast_factor": float(param_color["contrast_factor"][i]),
            "saturation_factor": float(param_color["saturation_factor"][i]),
            "hue_factor": float(param_color["hue_factor"][i]),
        }
        # Bind the dict as a default argument so the override does not depend on
        # loop variables that change on the next iteration.
        tv_transform_color._get_params = lambda _flat_inputs, _params=sample_params: _params
        torchvision_batch[idx] = tv_transform_color(torchvision_batch[idx])

for idxs, param_crop in zip(chunk_indices, params_crop):
    tv_transform_resized_crop = tv_transforms.RandomResizedCrop([224, 224])
    # One set of crop parameters per chunk, applied to each sample of the chunk.
    tv_transform_resized_crop._get_params = lambda _flat_inputs, _params=param_crop: _params
    for idx in idxs:
        torchvision_batch[idx] = tv_transform_resized_crop(torchvision_batch[idx])

# Each element is an (image, target) sample, sanitized independently.
sanitized_torchvision_batch = [tv_transforms.SanitizeBoundingBoxes()(sample) for sample in torchvision_batch]

Visualization

Torchaug

# One row each: raw dataset samples, transformed batch, sanitized batch.
plot(
    imgs=[
        [torchaug_dataset[sample_idx] for sample_idx in range(5)],
        torchaug_transformed_batch_uncollated,
        sanitized_torchaug_batch_uncollated,
    ],
    row_title=[
        "torchaug dataset",
        "torchaug transformed",
        "torchaug sanitized",
    ],
)

Torchvision

# One row each: raw dataset samples, transformed samples, sanitized samples.
plot(
    imgs=[
        [torchvision_dataset[sample_idx] for sample_idx in range(5)],
        torchvision_batch,
        sanitized_torchvision_batch,
    ],
    row_title=[
        "torchvision dataset",
        "torchvision transformed",
        "torchvision sanitized",
    ],
)

Torchaug and Torchvision

# Side-by-side comparison of the transformed outputs of both libraries.
plot(
    imgs=[
        torchaug_transformed_batch_uncollated,
        torchvision_batch,
    ],
    row_title=["torchaug transformed", "torchvision transformed"],
)
# Same comparison after bounding-box sanitization.
plot(
    imgs=[
        sanitized_torchaug_batch_uncollated,
        sanitized_torchvision_batch,
    ],
    row_title=["torchaug sanitized", "torchvision sanitized"],
)