Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

update swav to override optimizer_step with optimizer.step(closure=op… #323

Merged
merged 7 commits into from
Nov 2, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 25 additions & 18 deletions pl_bolts/models/self_supervised/swav/swav_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,12 @@
import torch
import torch.distributed as dist
from torch import nn
from torch.optim.optimizer import Optimizer

from pl_bolts.models.self_supervised.swav.swav_resnet import resnet50, resnet18
from typing import Callable, Optional
from pytorch_lightning.utilities import AMPType

from pl_bolts.models.self_supervised.swav.swav_resnet import resnet50, resnet18
from pl_bolts.transforms.dataset_normalizations import stl10_normalization, cifar10_normalization
from pl_bolts.optimizers.lars_scheduling import LARSWrapper

Expand Down Expand Up @@ -321,15 +324,15 @@ def configure_optimizers(self):

def optimizer_step(
self,
epoch,
batch_idx,
optimizer,
optimizer_idx,
second_order_closure=None,
on_tpu=False,
using_native_amp=False,
using_lbfgs=False
):
epoch: int,
batch_idx: int,
optimizer: Optimizer,
optimizer_idx: int,
optimizer_closure: Optional[Callable] = None,
on_tpu: bool = False,
using_native_amp: bool = False,
using_lbfgs: bool = False,
) -> None:
# warm-up + decay schedule placed here since LARSWrapper is not optimizer class
# adjust LR of optim contained within LARSWrapper
if self.lars_wrapper:
Expand All @@ -340,14 +343,18 @@ def optimizer_step(
param_group["lr"] = self.lr_schedule[self.trainer.global_step]

# log LR (LearningRateLogger callback doesn't work with LARSWrapper)
learning_rate = {'learning_rate': self.lr_schedule[self.trainer.global_step]}
self.logger.log_metrics(learning_rate, step=self.trainer.global_step)

# from lightning implementation
if using_native_amp:
self.trainer.scaler.step(optimizer)
else:
optimizer.step()
self.log('learning_rate', self.lr_schedule[self.trainer.global_step], on_step=True, on_epoch=False)

super().optimizer_step(
epoch=epoch,
batch_idx=batch_idx,
optimizer=optimizer,
optimizer_idx=optimizer_idx,
optimizer_closure=optimizer_closure,
on_tpu=on_tpu,
using_native_amp=using_native_amp,
using_lbfgs=using_lbfgs,
)

def sinkhorn(self, Q, nmb_iters):
with torch.no_grad():
Expand Down
6 changes: 4 additions & 2 deletions tests/models/self_supervised/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,5 +131,7 @@ def test_swav(tmpdir):
gpus=0, fast_dev_run=True, max_epochs=1, default_root_dir=tmpdir, max_steps=3
)

results = trainer.fit(model, datamodule)
assert results == 1
trainer.fit(model, datamodule)
loss = trainer.progress_bar_dict['loss']

assert float(loss) > 0