fix swav to run on imagenet (#348)
* imagenet train update

* revert

* fix

* swav for imagenet

* fixes

* gaussian kernel

* gaussian kernel

* gaussian kernel

* batch_size fix

* batch_size fix

* test fix

* rename logging var for online eval

* finetuner

* finetuner

* test fix
ananyahjha93 authored Nov 10, 2020
1 parent b697407 commit 7dabfae
Showing 6 changed files with 65 additions and 15 deletions.
8 changes: 4 additions & 4 deletions pl_bolts/callbacks/ssl_online.py
@@ -102,8 +102,8 @@ def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):

         # log metrics
         train_acc = accuracy(mlp_preds, y)
-        pl_module.log('train_acc', train_acc, on_step=True, on_epoch=False)
-        pl_module.log('train_mlp_loss', mlp_loss, on_step=True, on_epoch=False)
+        pl_module.log('online_train_acc', train_acc, on_step=True, on_epoch=False)
+        pl_module.log('online_train_loss', mlp_loss, on_step=True, on_epoch=False)

     def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
         x, y = self.to_device(batch, pl_module.device)
@@ -119,5 +119,5 @@ def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):

         # log metrics
         val_acc = accuracy(mlp_preds, y)
-        pl_module.log('val_acc', val_acc, on_step=False, on_epoch=True, sync_dist=True)
-        pl_module.log('val_mlp_loss', mlp_loss, on_step=False, on_epoch=True, sync_dist=True)
+        pl_module.log('online_val_acc', val_acc, on_step=False, on_epoch=True, sync_dist=True)
+        pl_module.log('online_val_loss', mlp_loss, on_step=False, on_epoch=True, sync_dist=True)
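
Note on this rename: prefixing the evaluator's metrics with online_ keeps them from colliding with a model's own train_acc/val_acc logs. A minimal sketch of tracking the new metric name, assuming a standard PyTorch Lightning ModelCheckpoint:

from pytorch_lightning.callbacks import ModelCheckpoint

# monitor the renamed online-eval metric rather than the old 'val_acc'
checkpoint_callback = ModelCheckpoint(monitor='online_val_acc', mode='max')
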
2 changes: 2 additions & 0 deletions pl_bolts/datamodules/imagenet_datamodule.py
@@ -90,6 +90,7 @@ def __init__(
         self.meta_dir = meta_dir
         self.num_imgs_per_val_class = num_imgs_per_val_class
         self.batch_size = batch_size
+        self.num_samples = 1281167 - self.num_imgs_per_val_class * self.num_classes

     @property
     def num_classes(self):
@@ -144,6 +145,7 @@ def train_dataloader(self):

         dataset = UnlabeledImagenet(self.data_dir,
                                     num_imgs_per_class=-1,
+                                    num_imgs_per_class_val_split=self.num_imgs_per_val_class,
                                     meta_dir=self.meta_dir,
                                     split='train',
                                     transform=transforms)
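
Note: the new num_samples attribute lets the SwAV module size its LR schedule straight from the datamodule. A worked check, assuming the datamodule's default num_imgs_per_val_class=50 and the 1000 ImageNet classes:

# 1,281,167 ImageNet train images minus the images held out for validation
num_samples = 1281167 - 50 * 1000  # = 1231167 images left for pretraining
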
5 changes: 3 additions & 2 deletions pl_bolts/models/self_supervised/swav/swav_finetuner.py
@@ -92,12 +92,12 @@ def cli_main(): # pragma: no-cover

     backbone = SwAV(
         gpus=args.gpus,
+        nodes=1,
         num_samples=args.num_samples,
         batch_size=args.batch_size,
-        datamodule=dm,
         maxpool1=args.maxpool1,
         first_conv=args.first_conv,
-        dataset='imagenet',
+        dataset=args.dataset,
     ).load_from_checkpoint(args.ckpt_path, strict=False)

     tuner = SSLFineTuner(
@@ -117,6 +117,7 @@

     trainer = pl.Trainer(
         gpus=args.gpus,
+        num_nodes=1,
         precision=16,
         max_epochs=args.num_epochs,
         distributed_backend='ddp',
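
Note: fine-tuning always runs on a single node, so nodes=1 and num_nodes=1 are fixed here rather than read from the CLI, and strict=False lets checkpoints saved before the nodes hyperparameter existed restore without key errors. A hypothetical standalone version of the same pattern, with placeholder values:

# placeholder constructor args; strict=False tolerates hparam/state mismatches
backbone = SwAV(gpus=1, nodes=1, num_samples=1, batch_size=1,
                dataset='imagenet').load_from_checkpoint('swav.ckpt', strict=False)
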
58 changes: 50 additions & 8 deletions pl_bolts/models/self_supervised/swav/swav_module.py
@@ -16,13 +16,18 @@

 from pl_bolts.models.self_supervised.swav.swav_resnet import resnet18, resnet50
 from pl_bolts.optimizers.lars_scheduling import LARSWrapper
-from pl_bolts.transforms.dataset_normalizations import cifar10_normalization, stl10_normalization
+from pl_bolts.transforms.dataset_normalizations import (
+    stl10_normalization,
+    cifar10_normalization,
+    imagenet_normalization
+)


 class SwAV(pl.LightningModule):
     def __init__(
         self,
         gpus: int,
+        nodes: int,
         num_samples: int,
         batch_size: int,
         dataset: str,
@@ -54,8 +59,9 @@ def __init__(
     ):
         """
         Args:
-            gpus: number of gpus used in training, passed to SwAV module
+            gpus: number of gpus per node used in training, passed to SwAV module
                 to manage the queue and select distributed sinkhorn
+            nodes: number of nodes to train on
             num_samples: number of image samples used for training
             batch_size: batch size per GPU in ddp
             dataset: dataset being used for train/val
@@ -94,6 +100,7 @@ def __init__(
         self.save_hyperparameters()

         self.gpus = gpus
+        self.nodes = nodes
         self.arch = arch
         self.dataset = dataset
         self.num_samples = num_samples
@@ -127,15 +134,15 @@ def __init__(
         self.warmup_epochs = warmup_epochs
         self.max_epochs = max_epochs

-        if self.gpus > 1:
+        if self.gpus * self.nodes > 1:
             self.get_assignments = self.distributed_sinkhorn
         else:
             self.get_assignments = self.sinkhorn

         self.model = self.init_model()

         # compute iters per epoch
-        global_batch_size = self.gpus * self.batch_size if self.gpus > 0 else self.batch_size
+        global_batch_size = self.nodes * self.gpus * self.batch_size if self.gpus > 0 else self.batch_size
         self.train_iters_per_epoch = self.num_samples // global_batch_size

         # define LR schedule
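
Note: the world size is now gpus * nodes, so the distributed Sinkhorn path is taken whenever more than one process participates, and the LR schedule is derived from the true global batch. A worked check against the ImageNet defaults set below, assuming the 1,231,167-image train split:

# schedule arithmetic for the ImageNet defaults (8 nodes x 8 GPUs x 64 per GPU)
global_batch_size = 8 * 8 * 64            # = 4096
train_iters_per_epoch = 1231167 // 4096   # = 300 optimizer steps per epoch
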
@@ -435,6 +442,7 @@ def add_model_specific_args(parent_parser):

         # training params
         parser.add_argument("--fast_dev_run", action='store_true')
+        parser.add_argument("--nodes", default=1, type=int, help="number of nodes for training")
         parser.add_argument("--gpus", default=1, type=int, help="number of gpus to train on")
         parser.add_argument("--num_workers", default=16, type=int, help="num of workers per GPU")
         parser.add_argument("--optimizer", default="adam", type=str, help="choose between adam/sgd")
@@ -471,8 +479,8 @@

 def cli_main():
     from pl_bolts.callbacks.ssl_online import SSLOnlineEvaluator
-    from pl_bolts.datamodules import CIFAR10DataModule, STL10DataModule
-    from pl_bolts.models.self_supervised.swav.transforms import SwAVEvalDataTransform, SwAVTrainDataTransform
+    from pl_bolts.models.self_supervised.swav.transforms import SwAVTrainDataTransform, SwAVEvalDataTransform
+    from pl_bolts.datamodules import STL10DataModule, CIFAR10DataModule, ImagenetDataModule

     parser = ArgumentParser()

@@ -515,11 +523,44 @@ def cli_main():
         args.size_crops = [32, 16]
         args.nmb_crops = [2, 1]
         args.gaussian_blur = False
+    elif args.dataset == 'imagenet':
+        args.maxpool1 = True
+        args.first_conv = True
+        normalization = imagenet_normalization()
+
+        args.size_crops = [224, 96]
+        args.min_scale_crops = [0.14, 0.05]
+        args.max_scale_crops = [1., 0.14]
+        args.gaussian_blur = True
+        args.jitter_strength = 1.
+
+        args.batch_size = 64
+        args.nodes = 8
+        args.gpus = 8  # per-node
+        args.max_epochs = 800
+
+        args.optimizer = 'sgd'
+        args.lars_wrapper = True
+        args.learning_rate = 4.8
+        args.final_lr = 0.0048
+        args.start_lr = 0.3
+
+        args.nmb_prototypes = 3000
+        args.online_ft = True
+
+        dm = ImagenetDataModule(
+            data_dir=args.data_path,
+            batch_size=args.batch_size,
+            num_workers=args.num_workers
+        )
+
+        args.num_samples = dm.num_samples
+        args.input_height = dm.size()[-1]
     else:
         raise NotImplementedError("other datasets have not been implemented till now")

     dm.train_transforms = SwAVTrainDataTransform(
-        normalize=stl10_normalization(),
+        normalize=normalization,
         size_crops=args.size_crops,
         nmb_crops=args.nmb_crops,
         min_scale_crops=args.min_scale_crops,
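
Note: the learning rates above are consistent with the linear scaling rule used in the SwAV paper (an inference here, not stated in the diff): peak lr = 0.3 * global_batch / 256.

# sanity check of the scaling-rule reading against the defaults above
global_batch = 8 * 8 * 64           # nodes * gpus * batch_size = 4096
peak_lr = 0.3 * global_batch / 256  # = 4.8, matching args.learning_rate
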
@@ -529,7 +570,7 @@
     )

     dm.val_transforms = SwAVEvalDataTransform(
-        normalize=stl10_normalization(),
+        normalize=normalization,
         size_crops=args.size_crops,
         nmb_crops=args.nmb_crops,
         min_scale_crops=args.min_scale_crops,
@@ -556,6 +597,7 @@
         max_epochs=args.max_epochs,
         max_steps=None if args.max_steps == -1 else args.max_steps,
         gpus=args.gpus,
+        num_nodes=args.nodes,
         distributed_backend='ddp' if args.gpus > 1 else None,
         sync_batchnorm=True if args.gpus > 1 else False,
         precision=32 if args.fp32 else 16,
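
Note: for readers wiring this up outside the CLI, a minimal sketch of the new ImageNet path (assumed data root, transforms omitted for brevity; cli_main() above installs SwAVTrainDataTransform/SwAVEvalDataTransform and remains the canonical flow):

import pytorch_lightning as pl
from pl_bolts.datamodules import ImagenetDataModule
from pl_bolts.models.self_supervised import SwAV

# assumed ImageNet root; num_samples comes from the datamodule as in cli_main()
dm = ImagenetDataModule(data_dir='/data/imagenet', batch_size=64, num_workers=16)
model = SwAV(gpus=8, nodes=8, num_samples=dm.num_samples,
             batch_size=64, dataset='imagenet')
trainer = pl.Trainer(gpus=8, num_nodes=8, distributed_backend='ddp', precision=16)
trainer.fit(model, datamodule=dm)
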
6 changes: 5 additions & 1 deletion pl_bolts/models/self_supervised/swav/transforms.py
@@ -55,8 +55,12 @@ def __init__(
         ]

         if self.gaussian_blur:
+            kernel_size = int(0.1 * self.size_crops[0])
+            if kernel_size % 2 == 0:
+                kernel_size += 1
+
             color_transform.append(
-                GaussianBlur(kernel_size=int(0.1 * self.size_crops[0]), p=0.5)
+                GaussianBlur(kernel_size=kernel_size, p=0.5)
             )

         self.color_transform = transforms.Compose(color_transform)
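
Note: blur kernels need a center pixel, and common implementations (OpenCV among them) reject even kernel sizes; int(0.1 * crop) is even for the 224-pixel ImageNet crops, which is what this fixes. A quick check across the crop sizes used in the configs above:

# kernel sizes produced by the fix for the crop sizes seen in this PR
for crop in (224, 96, 32):
    kernel_size = int(0.1 * crop)  # 22, 9, 3
    if kernel_size % 2 == 0:       # even kernels have no center pixel,
        kernel_size += 1           # so bump to the next odd size
    print(crop, kernel_size)       # prints: 224 23 / 96 9 / 32 3
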
1 change: 1 addition & 0 deletions tests/models/self_supervised/test_models.py
@@ -117,6 +117,7 @@ def test_swav(tmpdir, datadir):
         arch='resnet18',
         hidden_mlp=512,
         gpus=0,
+        nodes=1,
         num_samples=datamodule.num_samples,
         batch_size=batch_size,
         nmb_crops=[2, 1],
