Add SRGAN and datamodules for super resolution #466

Merged
merged 140 commits into from
Nov 26, 2021
Changes from 132 commits
d8abd6d
Add components
chris-clem Nov 28, 2020
36344b4
Add stl10_sr dataset and datamodule
chris-clem Nov 28, 2020
7e0f552
Add SRGANModule
chris-clem Nov 29, 2020
3c9dbce
Add pretraining script
chris-clem Nov 30, 2020
cc79753
Add more type annotations and todos
chris-clem Dec 20, 2020
8c21b69
Prepare srresnet training
chris-clem Dec 20, 2020
01981c2
Merge branch 'master' into feature/412_srgan
chris-clem Dec 20, 2020
777f0c6
Update torchvision imports
chris-clem Dec 20, 2020
1817c40
fix flake8
chris-clem Dec 20, 2020
ea7f602
Fix PIL import
chris-clem Dec 20, 2020
cf75ac6
Fix make docs
chris-clem Dec 20, 2020
e75f6af
Add docs and changelog
chris-clem Dec 20, 2020
315e5df
Add tests
chris-clem Dec 20, 2020
02aaaf6
Add lr scheduler
Dec 26, 2020
7c5d407
Add train_split arg to dm
Dec 26, 2020
c0edcad
Update tests
chris-clem Dec 27, 2020
b40c07b
Merge remote-tracking branch 'origin/feature/412_srgan' into feature/…
chris-clem Dec 27, 2020
d969d09
Merge branch 'master' into feature/412_srgan
chris-clem Dec 27, 2020
34af7c7
Add CIFAR10_SR for testing purposes and update tests
chris-clem Dec 27, 2020
5bbf254
fix import when torchvision not available
chris-clem Dec 27, 2020
f401ffa
Add checkpoint
Dec 27, 2020
ee89159
Merge branch 'feature/412_srgan' of https://github.com/chris-clem/pyt…
Dec 27, 2020
3d025ec
Update tests
Dec 27, 2020
cff05d2
Rename methods
Dec 27, 2020
e0c7fd7
Update args
Dec 28, 2020
cf32bb3
Change atol for stl10_sr datamodule
Dec 28, 2020
85a5f90
Run isort
Dec 29, 2020
fbad4cf
Change order of images
Dec 29, 2020
b95ab22
Set requires_grad to False for feature extractor
Dec 29, 2020
1ec9915
Update docs and checkpoint name
Dec 29, 2020
9d79660
Remove get in function name, add content and perceptual loss
Dec 29, 2020
cbc7269
Add SRDataModule
Dec 31, 2020
e435ba4
Remove CIFAR10_SR
Dec 31, 2020
68f071c
Add SRDatasetMixin and SR CelebA, STL10, MNIST
Dec 31, 2020
cbb0101
Add train, val, test, step and datasets
Dec 31, 2020
0a2c08a
Add parse_args function
Dec 31, 2020
351fa56
Update srgan module
Dec 31, 2020
a8a069e
Update tests
Dec 31, 2020
9fd88ea
Fix flake8
Dec 31, 2020
e9a5b59
Make tests work
Dec 31, 2020
55839ae
Add new checkpoints
Dec 31, 2020
09c89b9
Update init alls
Jan 1, 2021
194ff9d
Update type hints
Jan 1, 2021
894e834
Update docs
Jan 1, 2021
8e6b721
Update type hints
Jan 1, 2021
28c8b8e
Update tests
Jan 1, 2021
56a94d7
Run tests locally
Jan 1, 2021
5cd437c
Remove torchvision type hint
Jan 1, 2021
3c07582
Update checkpoints
Jan 1, 2021
c1f0f58
Warn if no generator checkpoint found
Jan 1, 2021
c6b0d44
add prepare_dataset function
Jan 1, 2021
47ef5b1
Update docs
Jan 1, 2021
daea7e0
Revert black changes
Jan 3, 2021
4ab26c6
Delete model checkpoints
Jan 3, 2021
3104ad4
Merge branch 'master' into feature/412_srgan
chris-clem Jan 4, 2021
efd8cc1
Merge branch 'master' into feature/412_srgan
chris-clem Jan 4, 2021
fea021e
Merge remote-tracking branch 'origin/feature/412_srgan' into feature/…
chris-clem Jan 4, 2021
1ff22c4
Merge branch 'master' into feature/412_srgan
chris-clem Jan 7, 2021
38c7cd7
Apply suggestions from review
chris-clem Jan 8, 2021
8719a4b
Apply suggestions from review
chris-clem Jan 8, 2021
7c284c1
Apply suggestions from review
chris-clem Jan 8, 2021
3a9e447
Fix test_gans
Jan 8, 2021
bc20e31
Fixing tests second try
Jan 8, 2021
e5ed1bd
Fix tests
Jan 8, 2021
c1d236a
Update test_scripts
Jan 8, 2021
9ad817c
Fix fast_dev_run arg
Jan 10, 2021
f11b4ae
Remove fast_dev_run from test
Jan 10, 2021
ea8f572
Merge branch 'master' into feature/412_srgan
chris-clem Jan 17, 2021
342bd41
Apply yapf
Jan 17, 2021
94e4e8b
Update CHANGELOG.md
Borda Jan 18, 2021
5b0adba
Merge branch 'master' into feature/412_srgan
akihironitta Jan 19, 2021
f4b0a0d
Merge branch 'master' into chris-clem-feature/412_srgan
akihironitta Jan 19, 2021
c49afcd
Apply suggestions from code review
Borda Jan 19, 2021
b6b0681
Make tests work again
Jan 24, 2021
18132e7
Merge branch 'master' into feature/412_srgan
Jan 24, 2021
5b0bf93
Add docs for SRImageLoggerCallback
Jan 24, 2021
2ab355b
Apply suggestions from code review
akihironitta Feb 13, 2021
88ef342
Move "pragma: no cover" to clause beginnings
akihironitta Feb 13, 2021
6aa4fa9
Merge branch 'master' into feature/412_srgan
chris-clem Mar 11, 2021
a8430ec
Update components.py
akihironitta Mar 13, 2021
d56f6c2
Add types
akihironitta Mar 13, 2021
7cbcf30
Add SRResNet to the docs
akihironitta Mar 13, 2021
07aa771
Add the original implementation link
akihironitta Mar 13, 2021
690e7b9
Simplify the dataset
akihironitta Mar 13, 2021
895803f
Link abs of the paper
akihironitta Mar 13, 2021
ece8940
Revert CHANGELOG
akihironitta Mar 13, 2021
c302544
Update CHANGELOG.md
akihironitta Mar 13, 2021
a4e990c
Simplify and yapf
akihironitta Mar 13, 2021
15e6973
Rermove parse_args from utils
Mar 21, 2021
18f663a
Fix flake8
Mar 21, 2021
60f7b36
Merge branch 'master' into feature/412_srgan
akihironitta Jul 31, 2021
f0c7c13
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 31, 2021
c9e6aeb
Fix docs warning
akihironitta Aug 10, 2021
40d3770
Temporarily disable GPU testing
akihironitta Aug 10, 2021
9b05e72
Merge branch 'master' into feature/412_srgan
akihironitta Aug 10, 2021
6fba9b5
Remove irrelevant tests
akihironitta Aug 10, 2021
8936273
Refactor test args for consistency
akihironitta Aug 10, 2021
6e3c2c6
Remove unused local assignments
akihironitta Aug 10, 2021
8beba4b
Instantiate transforms only once
akihironitta Aug 10, 2021
b21d1f4
Decouple SRMNISTDataset from mnist_dataset.py
akihironitta Aug 10, 2021
0941f37
Revert "Instantiate transforms only once"
akihironitta Aug 10, 2021
a9ee84a
Rename datasets for consistency
akihironitta Aug 10, 2021
d19a1de
Remove duplicate properties of SRMNIST
akihironitta Aug 10, 2021
7cdb75c
Always try to download datasets
akihironitta Aug 12, 2021
065b505
Always try to download datasets
akihironitta Aug 12, 2021
7874761
Always try to download datasets
akihironitta Aug 12, 2021
cf54cc9
Always try to download datasets
akihironitta Aug 12, 2021
f2c1f98
Move utils
Aug 20, 2021
814fa82
Rename SRDataModule to TVTDataModule
Aug 20, 2021
c1574fd
Initial hr and lr transforms in init
Aug 20, 2021
8b2b9fd
fix transforms initialization
Aug 20, 2021
031adc2
Merge branch 'master' into feature/412_srgan
chris-clem Aug 20, 2021
3880da7
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 20, 2021
ab632a4
Fix resizing
Aug 20, 2021
fccb234
Merge branch 'master' into feature/412_srgan
chris-clem Sep 3, 2021
fcf7e45
Merge branch 'master' into feature/412_srgan
akihironitta Sep 9, 2021
6bf356b
Update CHANGELOG.md
akihironitta Sep 9, 2021
376828d
Merge branch 'master' into feature/412_srgan
chris-clem Sep 10, 2021
e28e654
Fix CHANGELOG
akihironitta Sep 10, 2021
4c1d805
Merge branch 'master' into feature/412_srgan
akihironitta Sep 10, 2021
4ee0262
Merge branch 'master' into feature/412_srgan
akihironitta Sep 13, 2021
7977edc
Set pl>=1.4.8
akihironitta Oct 2, 2021
77d1025
Revert "Set pl>=1.4.8"
akihironitta Oct 2, 2021
1044d8c
Revert "Temporarily disable GPU testing"
akihironitta Oct 2, 2021
42546b5
Merge branch 'master' into feature/412_srgan
akihironitta Oct 2, 2021
3a83ecb
Fix azurepiplines.yml
akihironitta Oct 2, 2021
bfd2054
.
akihironitta Oct 2, 2021
a99f92b
Exclude gym==0.20.0 to avoid attributeerror
akihironitta Oct 2, 2021
e93ab15
Fix image links
akihironitta Oct 3, 2021
a30c129
Merge branch 'master' into feature/412_srgan
mergify[bot] Oct 14, 2021
bc38c84
Merge branch 'master' into feature/412_srgan
mergify[bot] Oct 14, 2021
500c113
Merge branch 'master' into feature/412_srgan
mergify[bot] Oct 14, 2021
4515b36
Merge branch 'master' into feature/412_srgan
mergify[bot] Oct 15, 2021
a4edc40
Merge branch 'master' into feature/412_srgan
mergify[bot] Oct 20, 2021
5c330e9
Merge branch 'master' into feature/412_srgan
Borda Nov 8, 2021
e2b1249
Merge branch 'master' into feature/412_srgan
mergify[bot] Nov 8, 2021
eaa7379
Merge branch 'master' into feature/412_srgan
mergify[bot] Nov 8, 2021
4d247e0
Merge branch 'master' into feature/412_srgan
mergify[bot] Nov 8, 2021
168e845
Merge branch 'master' into feature/412_srgan
mergify[bot] Nov 15, 2021
f030023
Merge branch 'master' into feature/412_srgan
Borda Nov 26, 2021
7 changes: 5 additions & 2 deletions CHANGELOG.md
@@ -10,6 +10,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

- Added YOLO model ([#552](https://github.com/PyTorchLightning/pytorch-lightning-bolts/pull/552))


- Added `SRGAN`, `SRImageLoggerCallback`, `TVTDataModule`, `SRCelebA`, `SRMNIST`, `SRSTL10` ([#466](https://github.com/PyTorchLightning/pytorch-lightning-bolts/pull/466))


### Changed


@@ -115,8 +119,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
[#323](https://github.com/PyTorchLightning/lightning-bolts/pull/323))
- Added data monitor callbacks `ModuleDataMonitor` and `TrainingDataMonitor` ([#285](https://github.com/PyTorchLightning/lightning-bolts/pull/285))
- Added DCGAN module ([#403](https://github.com/PyTorchLightning/lightning-bolts/pull/403))
- Added `VisionDataModule` as parent class for `BinaryMNISTDataModule`, `CIFAR10DataModule`, `FashionMNISTDataModule`,
and `MNISTDataModule` ([#400](https://github.com/PyTorchLightning/lightning-bolts/pull/400))
- Added `VisionDataModule` as parent class for `BinaryMNISTDataModule`, `CIFAR10DataModule`, `FashionMNISTDataModule`, and `MNISTDataModule` ([#400](https://github.com/PyTorchLightning/lightning-bolts/pull/400))
- Added GIoU loss ([#347](https://github.com/PyTorchLightning/lightning-bolts/pull/347))
- Added IoU loss ([#469](https://github.com/PyTorchLightning/lightning-bolts/pull/469))
- Added semantic segmentation model `SemSegment` with `UNet` backend ([#259](https://github.com/PyTorchLightning/lightning-bolts/pull/259))
83 changes: 83 additions & 0 deletions docs/source/deprecated/models/gans.rst
@@ -86,3 +86,86 @@ LSUN Loss curves:

.. autoclass:: pl_bolts.models.gans.DCGAN
    :noindex:


SRGAN
---------
SRGAN implementation from the paper `Photo-Realistic Single Image Super-Resolution Using a Generative Adversarial
Network <https://arxiv.org/pdf/1609.04802.pdf>`_. The implementation is based on the version from
`deeplearning.ai <https://github.com/https-deeplearning-ai/GANs-Public/blob/master/C3W2_SRGAN_(Optional).ipynb>`_.

Implemented by:

- `Christoph Clement <https://github.com/chris-clem>`_

MNIST results:

SRGAN MNIST with scale factor of 2 (left: low res, middle: generated high res, right: ground truth high res):

.. image:: ../../_images/gans/srgan-mnist-scale_factor=2.png
    :width: 200
    :alt: SRGAN MNIST with scale factor of 2

SRGAN MNIST with scale factor of 4:

.. image:: ../../_images/gans/srgan-mnist-scale_factor=4.png
    :width: 200
    :alt: SRGAN MNIST with scale factor of 4

SRResNet pretraining command used::
    >>> python srresnet_module.py --dataset=mnist --data_dir=~/Data --scale_factor=4 --save_model_checkpoint \
    --batch_size=16 --num_workers=2 --gpus=4 --accelerator=ddp --precision=16 --max_steps=25000

SRGAN training command used::
    >>> python srgan_module.py --dataset=mnist --data_dir=~/Data --scale_factor=4 --batch_size=16 \
    --num_workers=2 --scheduler_step=29 --gpus=4 --accelerator=ddp --precision=16 --max_steps=50000

STL10 results:

SRGAN STL10 with scale factor of 2:

.. image:: ../../_images/gans/srgan-stl10-scale_factor=2.png
    :width: 200
    :alt: SRGAN STL10 with scale factor of 2

SRGAN STL10 with scale factor of 4:

.. image:: ../../_images/gans/srgan-stl10-scale_factor=4.png
    :width: 200
    :alt: SRGAN STL10 with scale factor of 4

SRResNet pretraining command used::
    >>> python srresnet_module.py --dataset=stl10 --data_dir=~/Data --scale_factor=4 --save_model_checkpoint \
    --batch_size=16 --num_workers=2 --gpus=4 --accelerator=ddp --precision=16 --max_steps=25000

SRGAN training command used::
    >>> python srgan_module.py --dataset=stl10 --data_dir=~/Data --scale_factor=4 --batch_size=16 \
    --num_workers=2 --scheduler_step=29 --gpus=4 --accelerator=ddp --precision=16 --max_steps=50000

CelebA results:

SRGAN CelebA with scale factor of 2:

.. image:: ../../_images/gans/srgan-celeba-scale_factor=2.png
    :width: 200
    :alt: SRGAN CelebA with scale factor of 2

SRGAN CelebA with scale factor of 4:

.. image:: ../../_images/gans/srgan-celeba-scale_factor=4.png
    :width: 200
    :alt: SRGAN CelebA with scale factor of 4

SRResNet pretraining command used::
    >>> python srresnet_module.py --dataset=celeba --data_dir=~/Data --scale_factor=4 --save_model_checkpoint \
    --batch_size=16 --num_workers=2 --gpus=4 --accelerator=ddp --precision=16 --max_steps=25000

SRGAN training command used::
    >>> python srgan_module.py --dataset=celeba --data_dir=~/Data --scale_factor=4 --batch_size=16 \
    --num_workers=2 --scheduler_step=29 --gpus=4 --accelerator=ddp --precision=16 --max_steps=50000

.. autoclass:: pl_bolts.models.gans.SRGAN
    :noindex:

.. autoclass:: pl_bolts.models.gans.SRResNet
    :noindex:
2 changes: 2 additions & 0 deletions pl_bolts/callbacks/__init__.py
@@ -9,6 +9,7 @@
from pl_bolts.callbacks.verification.batch_gradient import BatchGradientVerificationCallback  # type: ignore
from pl_bolts.callbacks.vision.confused_logit import ConfusedLogitCallback
from pl_bolts.callbacks.vision.image_generation import TensorboardGenerativeModelImageSampler
from pl_bolts.callbacks.vision.sr_image_logger import SRImageLoggerCallback

__all__ = [
    "BatchGradientVerificationCallback",
@@ -20,6 +21,7 @@
    "LatentDimInterpolator",
    "ConfusedLogitCallback",
    "TensorboardGenerativeModelImageSampler",
    "SRImageLoggerCallback",
    "ORTCallback",
    "SparseMLCallback",
]
67 changes: 67 additions & 0 deletions pl_bolts/callbacks/vision/sr_image_logger.py
@@ -0,0 +1,67 @@
from typing import Tuple

import pytorch_lightning as pl
import torch
import torch.nn.functional as F
from pytorch_lightning import Callback

from pl_bolts.utils import _TORCHVISION_AVAILABLE
from pl_bolts.utils.warnings import warn_missing_pkg

if _TORCHVISION_AVAILABLE:
    from torchvision.utils import make_grid
else:  # pragma: no cover
    warn_missing_pkg("torchvision")


class SRImageLoggerCallback(Callback):
    """Logs low-res, generated high-res, and ground truth high-res images to TensorBoard. Your model must implement
    the ``forward`` function for generation.

    Requirements::

        # model forward must generate a high-res image from a low-res image
        hr_fake = pl_module(lr_image)

    Example::

        from pl_bolts.callbacks import SRImageLoggerCallback

        trainer = Trainer(callbacks=[SRImageLoggerCallback()])
    """

    def __init__(self, log_interval: int = 1000, scale_factor: int = 4, num_samples: int = 5) -> None:
        """
        Args:
            log_interval: Number of steps between logging. Default: ``1000``.
            scale_factor: Scale factor used for downsampling the high-res images. Default: ``4``.
            num_samples: Number of images displayed in the grid. Default: ``5``.
        """
        super().__init__()
        self.log_interval = log_interval
        self.scale_factor = scale_factor
        self.num_samples = num_samples

    def on_train_batch_end(
        self,
        trainer: pl.Trainer,
        pl_module: pl.LightningModule,
        outputs: torch.Tensor,
        batch: Tuple[torch.Tensor, torch.Tensor],
        batch_idx: int,
        dataloader_idx: int,
    ) -> None:
        global_step = trainer.global_step
        if global_step % self.log_interval == 0:
            hr_image, lr_image = batch
            hr_image, lr_image = hr_image.to(pl_module.device), lr_image.to(pl_module.device)
            hr_fake = pl_module(lr_image)
            # upsample the low-res input so it matches the high-res size in the grid
            lr_image = F.interpolate(lr_image, scale_factor=self.scale_factor)

            # one single-column grid per image type, one row per sample
            lr_image_grid = make_grid(lr_image[: self.num_samples], nrow=1, normalize=True)
            hr_fake_grid = make_grid(hr_fake[: self.num_samples], nrow=1, normalize=True)
            hr_image_grid = make_grid(hr_image[: self.num_samples], nrow=1, normalize=True)

            # place the columns side by side: low-res | generated | ground truth
            grid = torch.cat((lr_image_grid, hr_fake_grid, hr_image_grid), -1)
            title = "sr_images"
            trainer.logger.experiment.add_image(title, grid, global_step=global_step)
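
The callback upsamples the low-res batch with `F.interpolate` before building the grids, so that all three columns have the same spatial size. `F.interpolate` uses nearest-neighbor mode by default, which for an integer scale factor amounts to repeating each pixel. The following is a minimal pure-Python sketch of that behavior and of the logging check, using nested lists in place of tensors. `upsample_nearest` and `should_log` are hypothetical helpers written for illustration only; they are not part of `pl_bolts`.

```python
from typing import List


def upsample_nearest(image: List[List[float]], scale_factor: int) -> List[List[float]]:
    """Repeat every pixel scale_factor times along both axes (hypothetical helper)."""
    upsampled = []
    for row in image:
        wide_row = [pixel for pixel in row for _ in range(scale_factor)]
        # Copy the widened row so the repeated rows do not alias each other.
        upsampled.extend(list(wide_row) for _ in range(scale_factor))
    return upsampled


def should_log(global_step: int, log_interval: int = 1000) -> bool:
    """Same check as the callback: log when the step is a multiple of the interval."""
    return global_step % log_interval == 0


lr_image = [[0.0, 1.0], [0.5, 0.25]]
hr_sized = upsample_nearest(lr_image, scale_factor=2)
for row in hr_sized:
    print(row)

print([step for step in (0, 500, 1000, 2000) if should_log(step)])
```

Note that a step value of 0 also satisfies the modulo check.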
2 changes: 2 additions & 0 deletions pl_bolts/datamodules/__init__.py
@@ -10,6 +10,7 @@
from pl_bolts.datamodules.kitti_datamodule import KittiDataModule
from pl_bolts.datamodules.mnist_datamodule import MNISTDataModule
from pl_bolts.datamodules.sklearn_datamodule import SklearnDataModule, SklearnDataset, TensorDataset
from pl_bolts.datamodules.sr_datamodule import TVTDataModule
from pl_bolts.datamodules.ssl_imagenet_datamodule import SSLImagenetDataModule
from pl_bolts.datamodules.stl10_datamodule import STL10DataModule
from pl_bolts.datamodules.vocdetection_datamodule import VOCDetectionDataModule
@@ -31,6 +32,7 @@
    "SklearnDataModule",
    "SklearnDataset",
    "TensorDataset",
    "TVTDataModule",
    "SSLImagenetDataModule",
    "STL10DataModule",
    "VOCDetectionDataModule",
73 changes: 73 additions & 0 deletions pl_bolts/datamodules/sr_datamodule.py
@@ -0,0 +1,73 @@
from typing import Any

from pytorch_lightning import LightningDataModule
from torch.utils.data import DataLoader, Dataset


class TVTDataModule(LightningDataModule):
    """Simple DataModule creating train, val, and test dataloaders from given train, val, and test datasets.

    Example::

        from torch.utils.data import random_split

        from pl_bolts.datamodules import TVTDataModule
        from pl_bolts.datasets.sr_mnist_dataset import SRMNIST

        dataset_dev = SRMNIST(scale_factor=4, root=".", train=True)
        dataset_train, dataset_val = random_split(dataset_dev, lengths=[55_000, 5_000])
        dataset_test = SRMNIST(scale_factor=4, root=".", train=False)
        dm = TVTDataModule(dataset_train, dataset_val, dataset_test)
    """

    def __init__(
        self,
        dataset_train: Dataset,
        dataset_val: Dataset,
        dataset_test: Dataset,
        batch_size: int = 16,
        shuffle: bool = True,
        num_workers: int = 8,
        pin_memory: bool = True,
        drop_last: bool = True,
        *args: Any,
        **kwargs: Any,
    ) -> None:
        """
        Args:
            dataset_train: Train dataset
            dataset_val: Val dataset
            dataset_test: Test dataset
            batch_size: How many samples per batch to load
            shuffle: If true, shuffles the train data every epoch
            num_workers: How many workers to use for loading data
            pin_memory: If true, the data loader will copy Tensors into CUDA pinned memory before
                returning them
            drop_last: If true, drops the last incomplete batch
        """
        super().__init__()

        self.dataset_train = dataset_train
        self.dataset_val = dataset_val
        self.dataset_test = dataset_test
        self.num_workers = num_workers
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.pin_memory = pin_memory
        self.drop_last = drop_last

    def train_dataloader(self) -> DataLoader:
        return self._dataloader(self.dataset_train, shuffle=self.shuffle)

    def val_dataloader(self) -> DataLoader:
        return self._dataloader(self.dataset_val, shuffle=False)

    def test_dataloader(self) -> DataLoader:
        return self._dataloader(self.dataset_test, shuffle=False)

    def _dataloader(self, dataset: Dataset, shuffle: bool = True) -> DataLoader:
        return DataLoader(
            dataset,
            batch_size=self.batch_size,
            shuffle=shuffle,
            num_workers=self.num_workers,
            drop_last=self.drop_last,
            pin_memory=self.pin_memory,
        )
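
`_dataloader` passes the same `drop_last` value to all three loaders, and the default is `True`, so the validation and test loaders also skip a final incomplete batch. The arithmetic below is a small sketch of what that means for a 5,000-sample validation split with the default batch size of 16. `num_batches` is a hypothetical helper that mirrors how `DataLoader` counts batches for a map-style dataset; it is not part of this PR.

```python
import math


def num_batches(num_samples: int, batch_size: int, drop_last: bool) -> int:
    """Number of batches yielded for a map-style dataset (hypothetical helper)."""
    if drop_last:
        # The trailing incomplete batch is discarded.
        return num_samples // batch_size
    # The trailing incomplete batch is kept as a smaller batch.
    return math.ceil(num_samples / batch_size)


val_samples, batch_size = 5_000, 16
kept = num_batches(val_samples, batch_size, drop_last=True)
full = num_batches(val_samples, batch_size, drop_last=False)
skipped = val_samples - kept * batch_size

print(f"drop_last=True:  {kept} batches, {skipped} samples never evaluated")
print(f"drop_last=False: {full} batches")
```

If every validation or test sample should contribute to the metrics, passing `drop_last=False` when constructing the datamodule is one way to get that.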
33 changes: 33 additions & 0 deletions pl_bolts/datasets/sr_celeba_dataset.py
@@ -0,0 +1,33 @@
import os
from typing import Any

from pl_bolts.datasets.sr_dataset_mixin import SRDatasetMixin
from pl_bolts.utils import _PIL_AVAILABLE, _TORCHVISION_AVAILABLE
from pl_bolts.utils.warnings import warn_missing_pkg

if _PIL_AVAILABLE:
    from PIL import Image
else:  # pragma: no cover
    warn_missing_pkg("PIL", pypi_name="Pillow")

if _TORCHVISION_AVAILABLE:
    from torchvision.datasets import CelebA
else:  # pragma: no cover
    warn_missing_pkg("torchvision")
    CelebA = object


class SRCelebA(SRDatasetMixin, CelebA):
    """CelebA dataset that can be used to train Super Resolution models.

    Function __getitem__ (implemented in SRDatasetMixin) returns tuple of high and low resolution image.
    """

    def __init__(self, scale_factor: int, *args: Any, **kwargs: Any) -> None:
        hr_image_size = 128
        lr_image_size = hr_image_size // scale_factor
        self.image_channels = 3
        super().__init__(hr_image_size, lr_image_size, self.image_channels, *args, **kwargs)

    def _get_image(self, index: int):
        return Image.open(os.path.join(self.root, self.base_folder, "img_align_celeba", self.filename[index]))
52 changes: 52 additions & 0 deletions pl_bolts/datasets/sr_dataset_mixin.py
@@ -0,0 +1,52 @@
"""Adapted from: https://github.com/https-deeplearning-ai/GANs-Public."""
from typing import Any, Tuple

import torch

from pl_bolts.utils import _PIL_AVAILABLE, _TORCHVISION_AVAILABLE
from pl_bolts.utils.warnings import warn_missing_pkg

if _PIL_AVAILABLE:
    from PIL import Image
else:  # pragma: no cover
    warn_missing_pkg("PIL", pypi_name="Pillow")

if _TORCHVISION_AVAILABLE:
    from torchvision import transforms as transform_lib
else:  # pragma: no cover
    warn_missing_pkg("torchvision")


class SRDatasetMixin:
    """Mixin for Super Resolution datasets.

    Scales range of high resolution images to [-1, 1] and range of low resolution images to [0, 1].
    """

    def __init__(self, hr_image_size: int, lr_image_size: int, image_channels: int, *args: Any, **kwargs: Any) -> None:
        super().__init__(*args, **kwargs)

        self.hr_transforms = transform_lib.Compose(
            [
                transform_lib.RandomCrop(hr_image_size),
                transform_lib.ToTensor(),
                # [0, 1] -> [-1, 1]
                transform_lib.Normalize(mean=(0.5,) * image_channels, std=(0.5,) * image_channels),
            ]
        )

        self.lr_transforms = transform_lib.Compose(
            [
                # undo the high-res normalization: [-1, 1] -> [0, 1]
                transform_lib.Normalize(mean=(-1.0,) * image_channels, std=(2.0,) * image_channels),
                transform_lib.ToPILImage(),
                transform_lib.Resize(lr_image_size, Image.BICUBIC),
                transform_lib.ToTensor(),
            ]
        )

    def __getitem__(self, index: int) -> Tuple[torch.Tensor, torch.Tensor]:
        image = self._get_image(index)

        hr_image = self.hr_transforms(image)
        lr_image = self.lr_transforms(hr_image)

        return hr_image, lr_image
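
The two `Normalize` calls in the mixin are inverses of each other: the high-res pipeline maps pixel values from [0, 1] to [-1, 1], and the low-res pipeline first maps them back to [0, 1] before converting to a PIL image for resizing. A short worked example of the per-channel arithmetic, `(value - mean) / std`, using plain floats instead of tensors (the `normalize` function here is a stand-in written for illustration, not the torchvision class):

```python
def normalize(value: float, mean: float, std: float) -> float:
    """Per-channel arithmetic of a Normalize transform: (value - mean) / std."""
    return (value - mean) / std


for pixel in (0.0, 0.25, 1.0):
    hr_value = normalize(pixel, mean=0.5, std=0.5)      # [0, 1] -> [-1, 1]
    lr_value = normalize(hr_value, mean=-1.0, std=2.0)  # [-1, 1] -> [0, 1]
    print(f"pixel={pixel:.2f}  hr={hr_value:+.2f}  back_to_unit_range={lr_value:.2f}")
```

So the low-res tensor returned by `__getitem__` stays in [0, 1], while the high-res target is in [-1, 1], as the class docstring states.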
27 changes: 27 additions & 0 deletions pl_bolts/datasets/sr_mnist_dataset.py
@@ -0,0 +1,27 @@
from typing import Any

from pl_bolts.datasets.mnist_dataset import MNIST
from pl_bolts.datasets.sr_dataset_mixin import SRDatasetMixin
from pl_bolts.utils import _PIL_AVAILABLE
from pl_bolts.utils.warnings import warn_missing_pkg

if _PIL_AVAILABLE:
    from PIL import Image
else:  # pragma: no cover
    warn_missing_pkg("PIL", pypi_name="Pillow")


class SRMNIST(SRDatasetMixin, MNIST):
    """MNIST dataset that can be used to train Super Resolution models.

    Function __getitem__ (implemented in SRDatasetMixin) returns tuple of high and low resolution image.
    """

    def __init__(self, scale_factor: int, *args: Any, **kwargs: Any) -> None:
        hr_image_size = 28
        lr_image_size = hr_image_size // scale_factor
        self.image_channels = 1
        super().__init__(hr_image_size, lr_image_size, self.image_channels, *args, **kwargs)

    def _get_image(self, index: int):
        return Image.fromarray(self.data[index].numpy(), mode="L")
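
Both `SRMNIST` and `SRCelebA` derive the low-res size with integer division, `hr_image_size // scale_factor`. The sketch below tabulates the resulting sizes for the two high-res sizes hard-coded in this PR (28 and 128) and checks whether multiplying back by the scale factor recovers the high-res size. That round trip matters because `SRImageLoggerCallback` upsamples the low-res batch by `scale_factor` and then concatenates it with the high-res grids, so a scale factor that does not divide the image size would appear to produce mismatched shapes. `lr_size` and `round_trips` are hypothetical helpers written for illustration.

```python
def lr_size(hr_image_size: int, scale_factor: int) -> int:
    """Low-res edge length as computed in the dataset constructors."""
    return hr_image_size // scale_factor


def round_trips(hr_image_size: int, scale_factor: int) -> bool:
    """True if upsampling the low-res size by scale_factor gives back the high-res size."""
    return lr_size(hr_image_size, scale_factor) * scale_factor == hr_image_size


for name, hr_image_size in (("SRMNIST", 28), ("SRCelebA", 128)):
    for scale_factor in (2, 4, 8):
        print(
            name,
            f"scale_factor={scale_factor}",
            f"lr_size={lr_size(hr_image_size, scale_factor)}",
            f"round_trips={round_trips(hr_image_size, scale_factor)}",
        )
```

Under these assumptions, scale factors of 2 and 4 (the ones used in the docs) round-trip for both datasets, while a scale factor of 8 would not for the 28-pixel MNIST images.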