From ad22dca42480a7842dcf46151a59f0d0436981d5 Mon Sep 17 00:00:00 2001
From: Nicolas Hug
Date: Wed, 25 Oct 2023 16:55:40 +0100
Subject: [PATCH] Fix v2 transforms in spawn mp context

---
 test/datasets_utils.py                     |  9 ++++++---
 test/test_datasets.py                      |  4 +++-
 torchvision/tv_tensors/_dataset_wrapper.py | 12 +++++++++++-
 3 files changed, 20 insertions(+), 5 deletions(-)

diff --git a/test/datasets_utils.py b/test/datasets_utils.py
index bd9f7ea3a0f..c5b03e56ba5 100644
--- a/test/datasets_utils.py
+++ b/test/datasets_utils.py
@@ -5,7 +5,6 @@
 import itertools
 import os
 import pathlib
-import platform
 import random
 import shutil
 import string
@@ -713,8 +712,8 @@ def check_transforms_v2_wrapper_spawn(dataset):
     # On Linux and Windows, the DataLoader forks the main process by default. This is not available on macOS, so new
     # subprocesses are spawned. This requires the whole pipeline including the dataset to be pickleable, which is what
     # we are enforcing here.
-    if platform.system() != "Darwin":
-        pytest.skip("Multiprocessing spawning is only checked on macOS.")
+    # if platform.system() != "Darwin":
+    #     pytest.skip("Multiprocessing spawning is only checked on macOS.")
 
     from torch.utils.data import DataLoader
     from torchvision import tv_tensors
@@ -728,6 +727,10 @@ def check_transforms_v2_wrapper_spawn(dataset):
     assert tree_any(
         lambda item: isinstance(item, (tv_tensors.Image, tv_tensors.Video, PIL.Image.Image)), wrapped_sample
     )
+    from torchvision.datasets import VOCDetection
+
+    if isinstance(dataset, VOCDetection):
+        assert wrapped_sample[0][0].size == (321, 123)
 
 
 def create_image_or_video_tensor(size: Sequence[int]) -> torch.Tensor:
diff --git a/test/test_datasets.py b/test/test_datasets.py
index 1270201d53e..8a510e6932f 100644
--- a/test/test_datasets.py
+++ b/test/test_datasets.py
@@ -741,7 +741,9 @@ def test_annotations(self):
         assert object == info["annotation"]
 
     def test_transforms_v2_wrapper_spawn(self):
-        with self.create_dataset() as (dataset, _):
+        from torchvision.transforms import v2
+
+        with self.create_dataset(transform=v2.Resize(size=(123, 321))) as (dataset, _):
             datasets_utils.check_transforms_v2_wrapper_spawn(dataset)
 
 
diff --git a/torchvision/tv_tensors/_dataset_wrapper.py b/torchvision/tv_tensors/_dataset_wrapper.py
index ef9260ebde9..cdf8d1c01da 100644
--- a/torchvision/tv_tensors/_dataset_wrapper.py
+++ b/torchvision/tv_tensors/_dataset_wrapper.py
@@ -6,6 +6,7 @@
 
 import contextlib
 from collections import defaultdict
+from copy import copy
 
 import torch
 
@@ -199,7 +200,16 @@ def __len__(self):
         return len(self._dataset)
 
     def __reduce__(self):
-        return wrap_dataset_for_transforms_v2, (self._dataset, self._target_keys)
+        # __reduce__ is called when the dataset gets pickled. In a DataLoader
+        # with a spawn context, this happens `num_workers` times from the main
+        # process. We have to restore the [target_]transform[s] attributes on
+        # the dataset before passing it back to wrap_dataset_for_transforms_v2,
+        # because __init__() set them to None.
+        dataset = copy(self._dataset)
+        dataset.transform = self.transform
+        dataset.transforms = self.transforms
+        dataset.target_transform = self.target_transform
+        return wrap_dataset_for_transforms_v2, (dataset, self._target_keys)
 
 
 def raise_not_supported(description):
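
Reviewer note (not part of the patch): the fix hinges on the pickle protocol's
__reduce__ hook. Below is a minimal, self-contained sketch of the mechanism
under stand-in names (Dataset, Wrapper, rebuild, and to_upper are hypothetical;
rebuild plays the role of wrap_dataset_for_transforms_v2):

    import pickle
    from copy import copy


    def rebuild(dataset, target_keys):
        # Stand-in for wrap_dataset_for_transforms_v2: re-wraps the dataset.
        return Wrapper(dataset, target_keys)


    class Dataset:
        def __init__(self, transform=None):
            self.transform = transform


    class Wrapper:
        def __init__(self, dataset, target_keys=None):
            # The wrapper takes over the transform and nulls it on the inner
            # dataset, mirroring what the real wrapper's __init__ does.
            self.transform = dataset.transform
            dataset.transform = None
            self._dataset = dataset
            self._target_keys = target_keys

        def __reduce__(self):
            # The fix: restore the transform on a shallow copy before handing
            # the dataset back to the rebuild function, so that unpickling
            # does not silently drop it.
            dataset = copy(self._dataset)
            dataset.transform = self.transform
            return rebuild, (dataset, self._target_keys)


    def to_upper(x):  # module-level so it is picklable by reference
        return x.upper()


    wrapper = Wrapper(Dataset(transform=to_upper))
    clone = pickle.loads(pickle.dumps(wrapper))
    assert clone.transform is to_upper  # transform survived the round trip

Without the copy-and-restore step in __reduce__, the clone would be rebuilt
from a dataset whose transform is None, which is exactly the regression the
new VOCDetection assertion in check_transforms_v2_wrapper_spawn guards against.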
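
For context, an end-to-end sketch of the scenario the patch fixes. This is
illustrative only: "path/to/VOC" is a placeholder root and assumes the VOC
data is already on disk. With the old __reduce__, the final assertion would
fail because spawned workers rebuilt the wrapper without its transforms:

    from torch.utils.data import DataLoader
    from torchvision import tv_tensors
    from torchvision.datasets import VOCDetection
    from torchvision.transforms import v2


    def collate_fn(batch):
        # Module-level (not a lambda) so it can be pickled into spawned workers.
        return batch


    if __name__ == "__main__":
        dataset = VOCDetection(root="path/to/VOC", transform=v2.Resize(size=(123, 321)))
        dataset = tv_tensors.wrap_dataset_for_transforms_v2(dataset)
        loader = DataLoader(dataset, num_workers=2, multiprocessing_context="spawn", collate_fn=collate_fn)
        sample = next(iter(loader))
        # Each worker rebuilt the wrapper via __reduce__ with the Resize intact,
        # so PIL reports the resized (width, height).
        assert sample[0][0].size == (321, 123)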