Skip to content

Commit

Permalink
Refactor Vision DataModules (#400)
Browse files Browse the repository at this point in the history
* Add BaseDataModule

* Add pre-commit hooks

* Refactor cifar10_datamodule

* Move torchvision warning

* Refactor binary_mnist_datamodule

* Refactor fashion_mnist_datamodule

* Fix errors

* Remove VisionDataset type hint so CI base testing does not fail (torchvision is not installed there)

* Implement Nate's suggestions

* Remove train and eval batch size because it brakes a lot of tests

* Properly add transforms to train and val dataset

* Add num_samples property to cifar10 dm

* Add tesats and docs

* Fix flake8 and codafactor issue

* Update changelog

* Fix isort

* Add typing

* Rename to VisionDataModule

* Remove transform_lib type annotation

* suggestions

* Apply suggestions from code review

* Apply suggestions from code review

Co-authored-by: Akihiro Nitta <[email protected]>

* Add flags from #388 to API

* Make tests work

* Move _TORCHVISION_AVAILABLE check

* Update changelog

* Fix CI base testing

* Fix CI base testing

* Apply suggestions from code review

Co-authored-by: Jirka Borovec <[email protected]>
Co-authored-by: Jirka Borovec <[email protected]>
Co-authored-by: Akihiro Nitta <[email protected]>
  • Loading branch information
4 people authored Dec 17, 2020
1 parent 347a63d commit 02684ef
Show file tree
Hide file tree
Showing 9 changed files with 363 additions and 458 deletions.
1 change: 1 addition & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -5,5 +5,6 @@ repos:
- id: isort

- repo: https://github.com/pre-commit/mirrors-mypy
rev: v0.790
hooks:
- id: mypy
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

- Added data monitor callbacks `ModuleDataMonitor` and `TrainingDataMonitor` ([#285](https://github.com/PyTorchLightning/pytorch-lightning-bolts/pull/285))

- Added `VisionDataModule` as parent class for `BinaryMNISTDataModule`, `CIFAR10DataModule`, `FashionMNISTDataModule`,
and `MNISTDataModule` ([#400](https://github.com/PyTorchLightning/pytorch-lightning-bolts/pull/400))

### Changed

- Decoupled datamodules from models ([#332](https://github.com/PyTorchLightning/pytorch-lightning-bolts/pull/332), [#270](https://github.com/PyTorchLightning/pytorch-lightning-bolts/pull/270))
Expand Down
155 changes: 44 additions & 111 deletions pl_bolts/datamodules/binary_mnist_datamodule.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import torch
from pytorch_lightning import LightningDataModule
from torch.utils.data import DataLoader, random_split
from typing import Any, Optional, Union

from pl_bolts.datamodules.vision_datamodule import VisionDataModule
from pl_bolts.datasets.mnist_dataset import BinaryMNIST
from pl_bolts.utils import _TORCHVISION_AVAILABLE
from pl_bolts.utils.warnings import warn_missing_pkg
Expand All @@ -12,7 +11,7 @@
warn_missing_pkg('torchvision')


class BinaryMNISTDataModule(LightningDataModule):
class BinaryMNISTDataModule(VisionDataModule):
"""
.. figure:: https://miro.medium.com/max/744/1*AO2rIhzRYzFVQlFLx9DM9A.png
:width: 400
Expand Down Expand Up @@ -41,136 +40,70 @@ class BinaryMNISTDataModule(LightningDataModule):
"""

name = "binary_mnist"
dataset_cls = BinaryMNIST
dims = (1, 28, 28)

def __init__(
self,
data_dir: str,
val_split: int = 5000,
num_workers: int = 16,
normalize: bool = False,
batch_size: int = 32,
seed: int = 42,
shuffle: bool = False,
pin_memory: bool = False,
drop_last: bool = False,
*args,
**kwargs,
):
self,
data_dir: Optional[str] = None,
val_split: Union[int, float] = 0.2,
num_workers: int = 16,
normalize: bool = False,
batch_size: int = 32,
seed: int = 42,
shuffle: bool = False,
pin_memory: bool = False,
drop_last: bool = False,
*args: Any,
**kwargs: Any,
) -> None:
"""
Args:
data_dir: where to save/load the data
val_split: how many of the training images to use for the validation split
num_workers: how many workers to use for loading data
data_dir: Where to save/load the data
val_split: Percent (float) or number (int) of samples to use for the validation split
num_workers: How many workers to use for loading data
normalize: If true applies image normalize
batch_size: size of batch
seed: random seed to be used for train/val/test splits
shuffle: If true shuffles the data every epoch
batch_size: How many samples per batch to load
seed: Random seed to be used for train/val/test splits
shuffle: If true shuffles the train data every epoch
pin_memory: If true, the data loader will copy Tensors into CUDA pinned memory before
returning them
drop_last: If true drops the last incomplete batch
"""
super().__init__(*args, **kwargs)

if not _TORCHVISION_AVAILABLE:
raise ModuleNotFoundError( # pragma: no-cover
'You want to use MNIST dataset loaded from `torchvision` which is not installed yet.'
"You want to use transforms loaded from `torchvision` which is not installed yet."
)

self.dims = (1, 28, 28)
self.data_dir = data_dir
self.val_split = val_split
self.num_workers = num_workers
self.normalize = normalize
self.batch_size = batch_size
self.seed = seed
self.shuffle = shuffle
self.pin_memory = pin_memory
self.drop_last = drop_last
super().__init__(
data_dir=data_dir,
val_split=val_split,
num_workers=num_workers,
normalize=normalize,
batch_size=batch_size,
seed=seed,
shuffle=shuffle,
pin_memory=pin_memory,
drop_last=drop_last,
*args,
**kwargs,
)

@property
def num_classes(self):
def num_classes(self) -> int:
"""
Return:
10
"""
return 10

def prepare_data(self):
"""
Saves MNIST files to data_dir
"""
BinaryMNIST(self.data_dir, train=True, download=True, transform=transform_lib.ToTensor())
BinaryMNIST(self.data_dir, train=False, download=True, transform=transform_lib.ToTensor())

def train_dataloader(self):
"""
MNIST train set removes a subset to use for validation
"""
transforms = self._default_transforms() if self.train_transforms is None else self.train_transforms

dataset = BinaryMNIST(self.data_dir, train=True, download=False, transform=transforms)
train_length = len(dataset)
dataset_train, _ = random_split(
dataset,
[train_length - self.val_split, self.val_split],
generator=torch.Generator().manual_seed(self.seed)
)
loader = DataLoader(
dataset_train,
batch_size=self.batch_size,
shuffle=self.shuffle,
num_workers=self.num_workers,
drop_last=self.drop_last,
pin_memory=self.pin_memory
)
return loader

def val_dataloader(self):
"""
MNIST val set uses a subset of the training set for validation
"""
transforms = self._default_transforms() if self.val_transforms is None else self.val_transforms
dataset = BinaryMNIST(self.data_dir, train=True, download=False, transform=transforms)
train_length = len(dataset)
_, dataset_val = random_split(
dataset,
[train_length - self.val_split, self.val_split],
generator=torch.Generator().manual_seed(self.seed)
)
loader = DataLoader(
dataset_val,
batch_size=self.batch_size,
shuffle=False,
num_workers=self.num_workers,
drop_last=self.drop_last,
pin_memory=self.pin_memory
)
return loader

def test_dataloader(self):
"""
MNIST test set uses the test split
"""
transforms = self._default_transforms() if self.test_transforms is None else self.test_transforms

dataset = BinaryMNIST(self.data_dir, train=False, download=False, transform=transforms)
loader = DataLoader(
dataset,
batch_size=self.batch_size,
shuffle=False,
num_workers=self.num_workers,
drop_last=self.drop_last,
pin_memory=self.pin_memory
)
return loader

def _default_transforms(self):
def default_transforms(self):
if self.normalize:
mnist_transforms = transform_lib.Compose([
transform_lib.ToTensor(),
transform_lib.Normalize(mean=(0.5,), std=(0.5,)),
])
mnist_transforms = transform_lib.Compose(
[transform_lib.ToTensor(), transform_lib.Normalize(mean=(0.5,), std=(0.5,))]
)
else:
mnist_transforms = transform_lib.ToTensor()
mnist_transforms = transform_lib.Compose([transform_lib.ToTensor()])

return mnist_transforms
Loading

0 comments on commit 02684ef

Please sign in to comment.