Skip to content

Commit

Permalink
Fix v2 transforms in spawn mp context
Browse files Browse the repository at this point in the history
  • Loading branch information
NicolasHug committed Oct 25, 2023
1 parent 3fb88b3 commit ad22dca
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 5 deletions.
9 changes: 6 additions & 3 deletions test/datasets_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
import itertools
import os
import pathlib
import platform
import random
import shutil
import string
Expand Down Expand Up @@ -713,8 +712,8 @@ def check_transforms_v2_wrapper_spawn(dataset):
# On Linux and Windows, the DataLoader forks the main process by default. This is not available on macOS, so new
# subprocesses are spawned. This requires the whole pipeline including the dataset to be pickleable, which is what
# we are enforcing here.
if platform.system() != "Darwin":
pytest.skip("Multiprocessing spawning is only checked on macOS.")
# if platform.system() != "Darwin":
# pytest.skip("Multiprocessing spawning is only checked on macOS.")

from torch.utils.data import DataLoader
from torchvision import tv_tensors
Expand All @@ -728,6 +727,10 @@ def check_transforms_v2_wrapper_spawn(dataset):
assert tree_any(
lambda item: isinstance(item, (tv_tensors.Image, tv_tensors.Video, PIL.Image.Image)), wrapped_sample
)
from torchvision.datasets import VOCDetection

if isinstance(dataset, VOCDetection):
assert wrapped_sample[0][0].size == (321, 123)


def create_image_or_video_tensor(size: Sequence[int]) -> torch.Tensor:
Expand Down
4 changes: 3 additions & 1 deletion test/test_datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -741,7 +741,9 @@ def test_annotations(self):
assert object == info["annotation"]

def test_transforms_v2_wrapper_spawn(self):
with self.create_dataset() as (dataset, _):
from torchvision.transforms import v2

with self.create_dataset(transform=v2.Resize(size=(123, 321))) as (dataset, _):
datasets_utils.check_transforms_v2_wrapper_spawn(dataset)


Expand Down
12 changes: 11 additions & 1 deletion torchvision/tv_tensors/_dataset_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

import contextlib
from collections import defaultdict
from copy import copy

import torch

Expand Down Expand Up @@ -199,7 +200,16 @@ def __len__(self):
return len(self._dataset)

def __reduce__(self):
return wrap_dataset_for_transforms_v2, (self._dataset, self._target_keys)
# __reduce__ gets called when we try to pickle the dataset.
# In a DataLoader with spawn context, this gets called `num_workers` times from the main process.
# We have to reset the [target_]transform[s] attribute of the dataset
# before we pass it back to wrap_dataset_for_transforms_v2, because we
# set them to None in __init__().
dataset = copy(self._dataset)
dataset.transform = self.transform
dataset.transforms = self.transforms
dataset.target_transform = self.target_transform
return wrap_dataset_for_transforms_v2, (dataset, self._target_keys)


def raise_not_supported(description):
Expand Down

0 comments on commit ad22dca

Please sign in to comment.