# @Copyright: CEA-LIST/DIASI/SIALV/ (2023- )
# @Author: CEA-LIST/DIASI/SIALV/ <julien.denize@cea.fr>
# @License: CECILL-C
#
# Code partially based on Torchvision (BSD 3-Clause License), available at:
# https://github.com/pytorch/vision
from __future__ import annotations
import collections.abc
import contextlib
from collections import defaultdict
from copy import copy
import PIL.Image
import torch
from torchvision import datasets
from torchvision.transforms.v2.functional._type_conversion import pil_to_tensor
from torchaug import ta_tensors
from torchaug.transforms import functional as F
class WrapperFactories(dict):
    """Registry that maps a dataset class to the factory building its sample wrapper.

    Entries are meant to be added with :meth:`register`, typically used as a decorator.
    """

    def register(self, dataset_cls):
        """Return a decorator that stores the decorated factory under ``dataset_cls``.

        The decorated factory itself is returned unchanged so it remains callable as a plain function.
        """

        def _store(factory):
            self[dataset_cls] = factory
            return factory

        return _store
# A two-stage design is needed: this registry maps a dataset class to a wrapper *factory*, and the factory produces
# the actual wrapper. Some wrappers depend on the dataset instance rather than just its class, because they read
# user-defined instance attributes. Thus, the class-to-factory mapping can be provided here, but the wrapper itself
# can only be instantiated at runtime, once the dataset instance is available.
WRAPPER_FACTORIES = WrapperFactories()
class VisionDatasetTATensorWrapper:
    """Wrap a ``torchvision.datasets.VisionDataset`` so its samples are converted to ta_tensors.

    A wrapper factory is looked up in ``WRAPPER_FACTORIES`` by walking the dataset class MRO. The dataset's own
    ``transform``/``target_transform``/``transforms`` are detached from the dataset and stored on this wrapper.
    Each raw sample is wrapped first, and then ``transforms`` is applied to the result.

    Args:
        dataset: A ``torchvision.datasets.VisionDataset`` instance.
        target_keys: Optional target-key selection. It is only accepted for ``CocoDetection``, ``VOCDetection``,
            ``Kitti`` and ``WIDERFace``.

    Raises:
        TypeError: If ``dataset`` is not a ``VisionDataset``, or if no wrapper is registered for its class.
        ValueError: If ``target_keys`` is given for a dataset class that does not support it.
    """

    def __init__(self, dataset, target_keys):
        dataset_cls = type(dataset)

        if not isinstance(dataset, datasets.VisionDataset):
            raise TypeError(
                f"This wrapper is meant for subclasses of `torchvision.datasets.VisionDataset`, "
                f"but got a '{dataset_cls.__name__}' instead.\n"
                f"For an example of how to perform the wrapping for custom datasets, see\n\n"
                "https://pytorch.org/vision/main/auto_examples/plot_tv_tensors.html#do-i-have-to-wrap-the-output-of-the-datasets-myself"
            )

        # Dataset classes whose wrapper factories understand `target_keys`.
        keyed_datasets = {
            datasets.CocoDetection,
            datasets.VOCDetection,
            datasets.Kitti,
            datasets.WIDERFace,
        }

        # The first registered class in the MRO wins. Reaching `VisionDataset` itself means nothing is registered.
        for cls in dataset_cls.mro():
            if cls in WRAPPER_FACTORIES:
                wrapper_factory = WRAPPER_FACTORIES[cls]
                if target_keys is not None and cls not in keyed_datasets:
                    raise ValueError(
                        f"`target_keys` is currently only supported for `CocoDetection`, `VOCDetection`, `Kitti`, "
                        f"and `WIDERFace`, but got {cls.__name__}."
                    )
                break
            if cls is datasets.VisionDataset:
                # TODO: If we have documentation on how to do that, put a link in the error message.
                msg = f"No wrapper exists for dataset class {dataset_cls.__name__}. Please wrap the output yourself."
                if dataset_cls in datasets.__dict__.values():
                    msg += (
                        " If an automated wrapper for this dataset would be useful for you, "
                        "please open an issue at https://github.com/pytorch/vision/issues."
                    )
                raise TypeError(msg)

        self._dataset = dataset
        self._target_keys = target_keys
        self._wrapper = wrapper_factory(dataset, target_keys)

        # The transforms must be disabled on the dataset so the wrapping can run before they are applied.
        # `datasets.VisionDataset` merges `transform` and `target_transform` into the joint `transforms`
        # https://github.com/pytorch/vision/blob/135a0f9ea9841b6324b4fe8974e2543cbb95709a/torchvision/datasets/vision.py#L52-L54
        # but many datasets still use `transform` and `target_transform` individually. All three are therefore
        # moved onto the wrapper and cleared on the dataset, so that indexing it yields the untransformed sample.
        for attr_name in ("transform", "target_transform", "transforms"):
            setattr(self, attr_name, getattr(dataset, attr_name))
            setattr(dataset, attr_name, None)

    def __getattr__(self, item):
        # Only reached when normal lookup failed. Retry it once, then fall back to the wrapped dataset.
        try:
            return object.__getattribute__(self, item)
        except AttributeError:
            pass
        return getattr(self._dataset, item)

    def __getitem__(self, idx):
        # Indexing the dataset yields the raw sample, because its transforms were disabled in `__init__`.
        raw_sample = self._dataset[idx]
        wrapped_sample = self._wrapper(idx, raw_sample)

        # `transforms` exposes the full functionality whether the user passed joint or individual transforms.
        if self.transforms is None:
            return wrapped_sample
        return self.transforms(*wrapped_sample)

    def __len__(self):
        return len(self._dataset)

    # TODO: maybe we should use __getstate__ and __setstate__ instead of __reduce__, as recommended in the docs.
    def __reduce__(self):
        # Called when the wrapper is pickled (e.g. `num_workers` times in a DataLoader with a spawn context).
        # A shallow copy of the dataset gets its original [target_]transform[s] back, since `__init__` set them
        # to None on the wrapped instance. The wrapper is then rebuilt from that copy.
        restored = copy(self._dataset)
        restored.transform = self.transform
        restored.transforms = self.transforms
        restored.target_transform = self.target_transform
        return wrap_dataset_for_transforms_v2, (restored, self._target_keys)
def raise_not_supported(description):
    """Signal that a dataset configuration is not handled by this wrapper.

    Args:
        description: Human-readable description of the unsupported configuration.

    Raises:
        RuntimeError: Always.
    """
    message = (
        f"{description} is currently not supported by this wrapper. "
        "If this would be helpful for you, please open an issue at https://github.com/pytorch/vision/issues."
    )
    raise RuntimeError(message)
def identity(item):
    """Return ``item`` unchanged; this is the fallback converter used by ``wrap_target_by_type``."""
    return item
def identity_wrapper_factory(dataset, target_keys):
    """Build a wrapper that returns every sample untouched.

    Both arguments are accepted for interface compatibility with the other factories and are ignored.
    """

    def _passthrough(idx, sample):
        return sample

    return _passthrough
def pil_image_to_mask(pil_image):
    """Convert a segmentation image into a ``ta_tensors.Mask``.

    PIL images are first turned into a tensor with ``pil_to_tensor``. This mirrors how every image input in this
    module is converted before being wrapped in a ta_tensor. Non-PIL inputs are handed to ``ta_tensors.Mask`` as-is.

    Args:
        pil_image: The mask, usually a ``PIL.Image.Image`` loaded by the dataset.

    Returns:
        A ``ta_tensors.Mask`` holding the mask values.
    """
    if isinstance(pil_image, PIL.Image.Image):
        pil_image = pil_to_tensor(pil_image)
    return ta_tensors.Mask(pil_image)
def parse_target_keys(target_keys, *, available, default):
    """Resolve the user-requested target keys into a validated set.

    Args:
        target_keys: ``None`` to use ``default``, ``"all"`` to select every available key, a single key name,
            or an iterable of key names.
        available: Set of keys the dataset wrapper is able to produce.
        default: Value used when ``target_keys`` is ``None``. It may itself be ``"all"`` or a set of keys.

    Returns:
        The set of selected keys (``available`` itself when all keys are selected).

    Raises:
        ValueError: If any requested key is not in ``available``.
    """
    if target_keys is None:
        target_keys = default

    if target_keys == "all":
        return available

    if isinstance(target_keys, str):
        # A bare string would otherwise be split into its characters by `set()`.
        target_keys = {target_keys}
    else:
        target_keys = set(target_keys)

    extra = target_keys - available
    if extra:
        raise ValueError(f"Target keys {sorted(extra)} are not available")

    return target_keys
def list_of_dicts_to_dict_of_lists(list_of_dicts):
    """Transpose a list of dicts into a plain dict of lists.

    Keys appear in first-seen order, and each list collects that key's values in input order. A dict that lacks
    a key simply contributes nothing to that key's list.
    """
    grouped = {}
    for record in list_of_dicts:
        for key, value in record.items():
            grouped.setdefault(key, []).append(value)
    return grouped
def wrap_target_by_type(target, *, target_types, type_wrappers):
    """Convert each target item according to its declared target type.

    Args:
        target: A single target, or a tuple/list of targets paired positionally with ``target_types``.
        target_types: Sequence of type names, one per target item.
        type_wrappers: Mapping from type name to converter. Types missing from it are passed through unchanged.

    Returns:
        The single converted item when exactly one item results, otherwise a tuple of converted items.
    """
    items = target if isinstance(target, (tuple, list)) else [target]

    converted = []
    for target_type, item in zip(target_types, items):
        convert = type_wrappers.get(target_type, identity)
        converted.append(convert(item))

    if len(converted) == 1:
        return converted[0]
    return tuple(converted)
def classification_wrapper_factory(dataset, target_keys):
    """Return the pass-through wrapper used for classification datasets.

    The result comes from ``identity_wrapper_factory``, so samples are returned unchanged.
    """
    passthrough = identity_wrapper_factory(dataset, target_keys)
    return passthrough


# Classification datasets whose samples are passed through unchanged.
for dataset_cls in (
    datasets.Caltech256,
    datasets.CIFAR10,
    datasets.CIFAR100,
    datasets.ImageNet,
    datasets.MNIST,
    datasets.FashionMNIST,
    datasets.GTSRB,
    datasets.DatasetFolder,
    datasets.ImageFolder,
    datasets.Imagenette,
):
    WRAPPER_FACTORIES.register(dataset_cls)(classification_wrapper_factory)
def segmentation_wrapper_factory(dataset, target_keys):
    """Build the wrapper for ``(image, mask)`` segmentation samples.

    Args:
        dataset: The segmentation dataset instance (unused).
        target_keys: Unused. Segmentation wrappers do not support target-key selection.

    Returns:
        A ``wrapper(idx, sample)`` callable returning ``(image, mask)``. ``image`` is a ``ta_tensors.Image``,
        and ``mask`` is converted with ``pil_image_to_mask``.
    """

    def wrapper(idx, sample):
        image, mask = sample
        # Convert the image like the other wrappers in this module do (e.g. OxfordIIITPet, Cityscapes). Otherwise
        # a raw PIL image would be returned next to a tensor mask.
        if isinstance(image, PIL.Image.Image):
            image = pil_to_tensor(image)
        image = F.to_image(image)
        return image, pil_image_to_mask(mask)

    return wrapper


for dataset_cls in [
    datasets.VOCSegmentation,
]:
    WRAPPER_FACTORIES.register(dataset_cls)(segmentation_wrapper_factory)
def video_classification_wrapper_factory(dataset, target_keys):
    """Build the wrapper for ``(video, audio, label)`` video-classification samples.

    Only the video is converted, into a ``ta_tensors.Video``; audio and label are passed through.

    Raises:
        RuntimeError: If the dataset's clips use ``output_format='THWC'``.
    """
    output_format = dataset.video_clips.output_format
    if output_format == "THWC":
        raise RuntimeError(
            f"{type(dataset).__name__} with `output_format='THWC'` is not supported by this wrapper, "
            f"since it is not compatible with the transformations. Please use `output_format='TCHW'` instead."
        )

    def _wrap_video_sample(idx, sample):
        video, audio, label = sample
        return ta_tensors.Video(video), audio, label

    return _wrap_video_sample
# Video-classification datasets share the video wrapper.
for dataset_cls in (
    datasets.HMDB51,
    datasets.Kinetics,
    datasets.UCF101,
):
    WRAPPER_FACTORIES.register(dataset_cls)(video_classification_wrapper_factory)
@WRAPPER_FACTORIES.register(datasets.Caltech101)
def caltech101_wrapper_factory(dataset, target_keys):
    """Build the ``Caltech101`` wrapper. The dataset is handled like any classification dataset.

    Raises:
        RuntimeError: If ``"annotation"`` is among the dataset's target types.
    """
    wants_annotation = "annotation" in dataset.target_type
    if wants_annotation:
        raise_not_supported("Caltech101 dataset with `target_type=['annotation', ...]`")
    return classification_wrapper_factory(dataset, target_keys)
@WRAPPER_FACTORIES.register(datasets.CocoDetection)
def coco_dectection_wrapper_factory(dataset, target_keys):
    """Build the wrapper for ``CocoDetection`` samples.

    Args:
        dataset: The ``CocoDetection`` instance. Its ``ids`` attribute maps a sample index to the image id.
        target_keys: Requested target keys (see ``parse_target_keys``). Defaults to
            ``{"image_id", "boxes", "labels"}``.

    Returns:
        A ``wrapper(idx, sample)`` callable returning ``(image, target)``. ``image`` is a ``ta_tensors.Image``.
        ``target`` is a dict with the selected keys, or just ``{"image_id": ...}`` when the sample has no
        annotations.
    """
    target_keys = parse_target_keys(
        target_keys,
        available={
            # native
            "segmentation",
            "area",
            "iscrowd",
            "image_id",
            "bbox",
            "category_id",
            # added by the wrapper
            "boxes",
            "masks",
            "labels",
        },
        default={"image_id", "boxes", "labels"},
    )

    def segmentation_to_mask(segmentation, *, canvas_size):
        from pycocotools import mask  # noqa: I001

        if isinstance(segmentation, dict):
            rle = mask.frPyObjects(segmentation, *canvas_size)
        else:
            # Non-dict segmentations (presumably COCO polygon lists) yield several RLEs that are merged into one.
            rle = mask.merge(mask.frPyObjects(segmentation, *canvas_size))
        return torch.from_numpy(mask.decode(rle))

    def wrapper(idx, sample):
        image_id = dataset.ids[idx]

        image, target = sample

        if isinstance(image, PIL.Image.Image):
            image = pil_to_tensor(image)
        image = F.to_image(image)

        if not target:
            return image, {"image_id": image_id}

        canvas_size = tuple(F.get_size(image))
        annotations = list_of_dicts_to_dict_of_lists(target)

        # Keys are inserted in a fixed order: image_id, boxes, masks, labels, then any native keys requested.
        wrapped = {}

        if "image_id" in target_keys:
            wrapped["image_id"] = image_id

        if "boxes" in target_keys:
            xywh_boxes = ta_tensors.BoundingBoxes(
                annotations["bbox"],
                format=ta_tensors.BoundingBoxFormat.XYWH,
                canvas_size=canvas_size,
            )
            wrapped["boxes"] = F.convert_bounding_box_format(
                xywh_boxes,
                new_format=ta_tensors.BoundingBoxFormat.XYXY,
            )

        if "masks" in target_keys:
            per_instance_masks = [
                segmentation_to_mask(segmentation, canvas_size=canvas_size)
                for segmentation in annotations["segmentation"]
            ]
            wrapped["masks"] = ta_tensors.Mask(torch.stack(per_instance_masks))

        if "labels" in target_keys:
            wrapped["labels"] = ta_tensors.Labels(torch.tensor(annotations["category_id"]))

        for native_key in target_keys - {"image_id", "boxes", "masks", "labels"}:
            wrapped[native_key] = annotations[native_key]

        return image, wrapped

    return wrapper
# CocoCaptions samples are returned unchanged by the identity wrapper.
WRAPPER_FACTORIES.register(datasets.CocoCaptions)(identity_wrapper_factory)
# VOC detection class names. A name's position in this list is the integer label emitted by the VOC detection
# wrapper, and index 0 is "__background__".
VOC_DETECTION_CATEGORIES = [
    "__background__",
    "aeroplane", "bicycle", "bird", "boat", "bottle",
    "bus", "car", "cat", "chair", "cow",
    "diningtable", "dog", "horse", "motorbike", "person",
    "pottedplant", "sheep", "sofa", "train", "tvmonitor",
]
VOC_DETECTION_CATEGORY_TO_IDX = {category: idx for idx, category in enumerate(VOC_DETECTION_CATEGORIES)}
@WRAPPER_FACTORIES.register(datasets.VOCDetection)
def voc_detection_wrapper_factory(dataset, target_keys):
    """Build the wrapper for ``VOCDetection`` samples.

    Args:
        dataset: The ``VOCDetection`` instance (unused by the wrapper).
        target_keys: Requested target keys (see ``parse_target_keys``). Defaults to ``{"boxes", "labels"}``.
            Selecting ``"annotation"`` keeps the raw target dict and adds the other selected keys to it.

    Returns:
        A ``wrapper(idx, sample)`` callable returning ``(image, target)``. ``image`` is a ``ta_tensors.Image``.
    """
    target_keys = parse_target_keys(
        target_keys,
        available={
            # native
            "annotation",
            # added by the wrapper
            "boxes",
            "labels",
        },
        default={"boxes", "labels"},
    )

    def wrapper(idx, sample):
        image, target = sample

        if isinstance(image, PIL.Image.Image):
            image = pil_to_tensor(image)
        image = F.to_image(image)

        batched_instances = list_of_dicts_to_dict_of_lists(target["annotation"]["object"])

        if "annotation" not in target_keys:
            target = {}

        if "boxes" in target_keys:
            target["boxes"] = ta_tensors.BoundingBoxes(
                [
                    [int(bndbox[part]) for part in ("xmin", "ymin", "xmax", "ymax")]
                    for bndbox in batched_instances["bndbox"]
                ],
                format=ta_tensors.BoundingBoxFormat.XYXY,
                # `image` is a tensor here, so it has no PIL `.height`/`.width`. Read the size with
                # `F.get_size`, as the COCO wrapper in this module does.
                canvas_size=tuple(F.get_size(image)),
            )

        if "labels" in target_keys:
            target["labels"] = torch.tensor(
                [VOC_DETECTION_CATEGORY_TO_IDX[category] for category in batched_instances["name"]]
            )

        return image, target

    return wrapper
@WRAPPER_FACTORIES.register(datasets.SBDataset)
def sbd_wrapper(dataset, target_keys):
    """Build the ``SBDataset`` wrapper, reusing the segmentation wrapper.

    Raises:
        RuntimeError: If the dataset is in ``mode='boundaries'``.
    """
    is_boundaries_mode = dataset.mode == "boundaries"
    if is_boundaries_mode:
        raise_not_supported("SBDataset with mode='boundaries'")
    return segmentation_wrapper_factory(dataset, target_keys)
@WRAPPER_FACTORIES.register(datasets.CelebA)
def celeba_wrapper_factory(dataset, target_keys):
    """Build the wrapper for ``CelebA`` samples.

    ``"bbox"`` targets are converted from XYWH into XYXY ``ta_tensors.BoundingBoxes``. Other supported target
    types are passed through unchanged.

    Raises:
        RuntimeError: If ``"attr"`` or ``"landmarks"`` is among the dataset's target types.
    """
    if any(target_type in dataset.target_type for target_type in ["attr", "landmarks"]):
        raise_not_supported("`CelebA` dataset with `target_type=['attr', 'landmarks', ...]`")

    def wrapper(idx, sample):
        image, target = sample

        if isinstance(image, PIL.Image.Image):
            image = pil_to_tensor(image)
        image = F.to_image(image)

        def wrap_bbox(item):
            return F.convert_bounding_box_format(
                ta_tensors.BoundingBoxes(
                    item,
                    format=ta_tensors.BoundingBoxFormat.XYWH,
                    # `image` is a tensor here, so it has no PIL `.height`/`.width`. Read the size with
                    # `F.get_size`, as the COCO wrapper in this module does.
                    canvas_size=tuple(F.get_size(image)),
                ),
                new_format=ta_tensors.BoundingBoxFormat.XYXY,
            )

        target = wrap_target_by_type(
            target,
            target_types=dataset.target_type,
            type_wrappers={"bbox": wrap_bbox},
        )

        return image, target

    return wrapper
# KITTI object types. A type's position in this list is the integer label emitted by the KITTI wrapper.
KITTI_CATEGORIES = [
    "Car", "Van", "Truck",
    "Pedestrian", "Person_sitting", "Cyclist",
    "Tram", "Misc", "DontCare",
]
KITTI_CATEGORY_TO_IDX = {category: idx for idx, category in enumerate(KITTI_CATEGORIES)}
@WRAPPER_FACTORIES.register(datasets.Kitti)
def kitti_wrapper_factory(dataset, target_keys):
    """Build the wrapper for ``Kitti`` samples.

    Args:
        dataset: The ``Kitti`` instance (unused by the wrapper).
        target_keys: Requested target keys (see ``parse_target_keys``). Defaults to ``{"boxes", "labels"}``.

    Returns:
        A ``wrapper(idx, sample)`` callable returning ``(image, target)``. ``image`` is a ``ta_tensors.Image``,
        and ``target`` is a dict with the selected keys, or ``None`` when the sample has no target.
    """
    target_keys = parse_target_keys(
        target_keys,
        available={
            # native
            "type",
            "truncated",
            "occluded",
            "alpha",
            "bbox",
            "dimensions",
            "location",
            "rotation_y",
            # added by the wrapper
            "boxes",
            "labels",
        },
        default={"boxes", "labels"},
    )

    def wrapper(idx, sample):
        image, target = sample

        if isinstance(image, PIL.Image.Image):
            image = pil_to_tensor(image)
        image = F.to_image(image)

        if target is None:
            return image, target

        batched_target = list_of_dicts_to_dict_of_lists(target)
        target = {}

        if "boxes" in target_keys:
            target["boxes"] = ta_tensors.BoundingBoxes(
                batched_target["bbox"],
                format=ta_tensors.BoundingBoxFormat.XYXY,
                # `image` is a tensor here, so it has no PIL `.height`/`.width`. Read the size with
                # `F.get_size`, as the COCO wrapper in this module does.
                canvas_size=tuple(F.get_size(image)),
            )

        if "labels" in target_keys:
            target["labels"] = torch.tensor([KITTI_CATEGORY_TO_IDX[category] for category in batched_target["type"]])

        for target_key in target_keys - {"boxes", "labels"}:
            target[target_key] = batched_target[target_key]

        return image, target

    return wrapper
@WRAPPER_FACTORIES.register(datasets.OxfordIIITPet)
def oxford_iiit_pet_wrapper_factory(dataset, target_keys):
    """Build the wrapper for ``OxfordIIITPet`` samples.

    The image becomes a ``ta_tensors.Image``. ``"segmentation"`` targets become masks via
    ``pil_image_to_mask``, and other target types are passed through. A ``None`` target is returned as-is.
    """

    def _wrap_sample(idx, sample):
        image, target = sample

        if isinstance(image, PIL.Image.Image):
            image = pil_to_tensor(image)
        image = F.to_image(image)

        if target is None:
            return image, target

        wrapped_target = wrap_target_by_type(
            target,
            target_types=dataset._target_types,
            type_wrappers={"segmentation": pil_image_to_mask},
        )
        return image, wrapped_target

    return _wrap_sample
@WRAPPER_FACTORIES.register(datasets.Cityscapes)
def cityscapes_wrapper_factory(dataset, target_keys):
    """Build the wrapper for ``Cityscapes`` samples.

    ``"instance"`` targets become a dict with per-id ``masks`` and ``labels``. ``"semantic"`` targets become a
    mask via ``pil_image_to_mask``.

    Raises:
        RuntimeError: If ``"polygon"`` or ``"color"`` is among the dataset's target types.
    """
    unsupported_types = ["polygon", "color"]
    if any(target_type in dataset.target_type for target_type in unsupported_types):
        raise_not_supported("`Cityscapes` dataset with `target_type=['polygon', 'color', ...]`")

    def instance_segmentation_wrapper(mask):
        # See https://github.com/mcordts/cityscapesScripts/blob/8da5dd00c9069058ccc134654116aac52d4f6fa2/cityscapesscripts/preparation/json2instanceImg.py#L7-L21
        data = pil_image_to_mask(mask)
        masks, labels = [], []
        for instance_id in data.unique():
            # One boolean mask per distinct id present in the instance image.
            masks.append(data == instance_id)
            # Per the cityscapesScripts reference above, ids >= 1000 presumably encode
            # `class_id * 1000 + instance_index`; only the class part is kept as the label.
            label = instance_id
            if label >= 1_000:
                label //= 1_000
            labels.append(label)
        return {"masks": ta_tensors.Mask(torch.stack(masks)), "labels": torch.stack(labels)}

    def _wrap_sample(idx, sample):
        image, target = sample

        if isinstance(image, PIL.Image.Image):
            image = pil_to_tensor(image)
        image = F.to_image(image)

        wrapped_target = wrap_target_by_type(
            target,
            target_types=dataset.target_type,
            type_wrappers={
                "instance": instance_segmentation_wrapper,
                "semantic": pil_image_to_mask,
            },
        )
        return image, wrapped_target

    return _wrap_sample
@WRAPPER_FACTORIES.register(datasets.WIDERFace)
def widerface_wrapper(dataset, target_keys):
    """Build the wrapper for ``WIDERFace`` samples.

    Args:
        dataset: The ``WIDERFace`` instance (unused by the wrapper).
        target_keys: Requested target keys (see ``parse_target_keys``). Defaults to all available keys.

    Returns:
        A ``wrapper(idx, sample)`` callable returning ``(image, target)``. ``image`` is a ``ta_tensors.Image``,
        and ``target`` keeps only the selected keys, with ``"bbox"`` converted from XYWH to XYXY
        ``ta_tensors.BoundingBoxes``. A ``None`` target is returned as-is.
    """
    target_keys = parse_target_keys(
        target_keys,
        available={
            "bbox",
            "blur",
            "expression",
            "illumination",
            "occlusion",
            "pose",
            "invalid",
        },
        default="all",
    )

    def wrapper(idx, sample):
        image, target = sample

        if isinstance(image, PIL.Image.Image):
            image = pil_to_tensor(image)
        image = F.to_image(image)

        if target is None:
            return image, target

        target = {key: target[key] for key in target_keys}

        if "bbox" in target_keys:
            target["bbox"] = F.convert_bounding_box_format(
                ta_tensors.BoundingBoxes(
                    target["bbox"],
                    format=ta_tensors.BoundingBoxFormat.XYWH,
                    # `image` is a tensor here, so it has no PIL `.height`/`.width`. Read the size with
                    # `F.get_size`, as the COCO wrapper in this module does.
                    canvas_size=tuple(F.get_size(image)),
                ),
                new_format=ta_tensors.BoundingBoxFormat.XYXY,
            )

        return image, target

    return wrapper