fix imports (#288)
* min req

* imports

* imports

* split

* imports

* imports
Borda authored Oct 22, 2020
1 parent dc16e7b commit c21bfe2
Showing 10 changed files with 80 additions and 90 deletions.
4 changes: 2 additions & 2 deletions docs/source/self_supervised_callbacks.rst
@@ -11,7 +11,7 @@ BYOLMAWeightUpdate
------------------
The exponential moving average weight-update rule from Bring Your Own Latent (BYOL).

-.. autoclass:: pl_bolts.callbacks.self_supervised.BYOLMAWeightUpdate
+.. autoclass:: pl_bolts.callbacks.byol_updates.BYOLMAWeightUpdate
     :noindex:

----------------
@@ -20,5 +20,5 @@ SSLOnlineEvaluator
------------------
Appends a MLP for fine-tuning to the given model. Callback has its own mini-inner loop.

-.. autoclass:: pl_bolts.callbacks.self_supervised.SSLOnlineEvaluator
+.. autoclass:: pl_bolts.callbacks.ssl_online.SSLOnlineEvaluator
     :noindex:
59 changes: 59 additions & 0 deletions pl_bolts/callbacks/byol_updates.py
@@ -0,0 +1,59 @@
import math

from pytorch_lightning import Callback


class BYOLMAWeightUpdate(Callback):
    """
    Weight update rule from BYOL.
    Your model should have a:
    - self.online_network.
    - self.target_network.
    Updates the target_network params using an exponential moving average update rule weighted by tau.
    BYOL claims this keeps the online_network from collapsing.
    .. note:: Automatically increases tau from `initial_tau` to 1.0 with every training step
    Example::
        # model must have 2 attributes
        model = Model()
        model.online_network = ...
        model.target_network = ...
        trainer = Trainer(callbacks=[BYOLMAWeightUpdate()])
    """

    def __init__(self, initial_tau=0.996):
        """
        Args:
            initial_tau: starting tau. Auto-updates with every training step
        """
        super().__init__()
        self.initial_tau = initial_tau
        self.current_tau = initial_tau

    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
        # get networks
        online_net = pl_module.online_network
        target_net = pl_module.target_network

        # update weights
        self.update_weights(online_net, target_net)

        # update tau after
        self.current_tau = self.update_tau(pl_module, trainer)

    def update_tau(self, pl_module, trainer):
        max_steps = len(trainer.train_dataloader) * trainer.max_epochs
        tau = 1 - (1 - self.initial_tau) * (math.cos(math.pi * pl_module.global_step / max_steps) + 1) / 2
        return tau

    def update_weights(self, online_net, target_net):
        # apply MA weight update
        for (name, online_p), (_, target_p) in zip(online_net.named_parameters(), target_net.named_parameters()):
            if 'weight' in name:
                target_p.data = self.current_tau * target_p.data + (1 - self.current_tau) * online_p.data
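For reference, the cosine schedule in update_tau anneals tau from initial_tau toward 1.0 as global_step approaches max_steps, and update_weights only blends parameters whose names contain 'weight'. A standalone sketch of the schedule, assuming a hypothetical budget of 1000 total training steps (not a value from this commit):

import math

initial_tau = 0.996  # the callback's default starting tau
max_steps = 1000     # hypothetical total step budget, for illustration only

for step in (0, 250, 500, 750, 1000):
    # same formula as BYOLMAWeightUpdate.update_tau above
    tau = 1 - (1 - initial_tau) * (math.cos(math.pi * step / max_steps) + 1) / 2
    print(f"step {step:4d}: tau = {tau:.6f}")

# step    0: tau = 0.996000
# step  250: tau = 0.996586
# step  500: tau = 0.998000
# step  750: tau = 0.999414
# step 1000: tau = 1.000000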
pl_bolts/callbacks/{self_supervised.py → ssl_online.py}
@@ -1,20 +1,19 @@
-import math
 from typing import Optional

-import pytorch_lightning as pl
 import torch
+from pytorch_lightning import Callback
 from pytorch_lightning.metrics.functional import accuracy
 from torch.nn import functional as F

+from pl_bolts.models.self_supervised.evaluator import SSLEvaluator


-class SSLOnlineEvaluator(pl.Callback):  # pragma: no-cover
+class SSLOnlineEvaluator(Callback):  # pragma: no-cover
     """
-    Attaches a MLP for finetuning using the standard self-supervised protocol.
+    Attaches a MLP for fine-tuning using the standard self-supervised protocol.
     Example::
-        from pl_bolts.callbacks.self_supervised import SSLOnlineEvaluator
         # your model must have 2 attributes
         model = Model()
         model.z_dim = ...  # the representation dim
@@ -30,8 +29,8 @@ def __init__(
     ):
         """
         Args:
-            drop_p: (0.2) dropout probability
-            hidden_dim: (1024) the hidden dimension for the finetune MLP
+            drop_p: dropout probability
+            hidden_dim: the hidden dimension for the fine-tune MLP
         """
         super().__init__()
         self.hidden_dim = hidden_dim
@@ -41,8 +40,6 @@ def __init__(
         self.num_classes = num_classes

     def on_pretrain_routine_start(self, trainer, pl_module):
-        from pl_bolts.models.self_supervised.evaluator import SSLEvaluator
-
         # attach the evaluator to the module

         if hasattr(pl_module, 'z_dim'):
@@ -101,61 +98,3 @@ def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, data

         metrics = {'ft_callback_mlp_loss': mlp_loss, 'ft_callback_mlp_acc': acc}
         pl_module.logger.log_metrics(metrics, step=trainer.global_step)
-
-
-class BYOLMAWeightUpdate(pl.Callback):
-    """
-    Weight update rule from BYOL.
-    Your model should have a:
-    - self.online_network.
-    - self.target_network.
-    Updates the target_network params using an exponential moving average update rule weighted by tau.
-    BYOL claims this keeps the online_network from collapsing.
-    .. note:: Automatically increases tau from `initial_tau` to 1.0 with every training step
-    Example::
-        from pl_bolts.callbacks.self_supervised import BYOLMAWeightUpdate
-        # model must have 2 attributes
-        model = Model()
-        model.online_network = ...
-        model.target_network = ...
-        trainer = Trainer(callbacks=[BYOLMAWeightUpdate()])
-    """
-
-    def __init__(self, initial_tau=0.996):
-        """
-        Args:
-            initial_tau: starting tau. Auto-updates with every training step
-        """
-        super().__init__()
-        self.initial_tau = initial_tau
-        self.current_tau = initial_tau
-
-    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
-        # get networks
-        online_net = pl_module.online_network
-        target_net = pl_module.target_network
-
-        # update weights
-        self.update_weights(online_net, target_net)
-
-        # update tau after
-        self.current_tau = self.update_tau(pl_module, trainer)
-
-    def update_tau(self, pl_module, trainer):
-        max_steps = len(trainer.train_dataloader) * trainer.max_epochs
-        tau = 1 - (1 - self.initial_tau) * (math.cos(math.pi * pl_module.global_step / max_steps) + 1) / 2
-        return tau
-
-    def update_weights(self, online_net, target_net):
-        # apply MA weight update
-        for (name, online_p), (_, target_p) in zip(online_net.named_parameters(), target_net.named_parameters()):
-            if 'weight' in name:
-                target_p.data = self.current_tau * target_p.data + (1 - self.current_tau) * online_p.data
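With the split, SSLOnlineEvaluator stays in this module (now pl_bolts.callbacks.ssl_online) while BYOLMAWeightUpdate moves to pl_bolts.callbacks.byol_updates. A minimal usage sketch of the new import path, mirroring the docstring example above; Model and dm are hypothetical stand-ins, not defined in this commit:

import pytorch_lightning as pl
from pl_bolts.callbacks.ssl_online import SSLOnlineEvaluator  # post-split location

model = Model()          # hypothetical LightningModule
model.z_dim = 512        # representation dim read by the callback
model.num_classes = 10   # classes for the online MLP head

trainer = pl.Trainer(callbacks=[SSLOnlineEvaluator()])
trainer.fit(model, datamodule=dm)  # `dm`: a hypothetical LightningDataModule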
12 changes: 2 additions & 10 deletions pl_bolts/models/self_supervised/byol/byol_module.py
@@ -8,7 +8,7 @@
from pytorch_lightning import seed_everything
from torch.optim import Adam

-from pl_bolts.callbacks.self_supervised import BYOLMAWeightUpdate
+from pl_bolts.callbacks.byol_updates import BYOLMAWeightUpdate
from pl_bolts.models.self_supervised.byol.models import SiameseArm
from pl_bolts.optimizers.lars_scheduling import LARSWrapper
from pl_bolts.optimizers.lr_scheduler import LinearWarmupCosineAnnealingLR
@@ -35,16 +35,8 @@ class BYOL(pl.LightningModule):
    Example::
        import pytorch_lightning as pl
        from pl_bolts.models.self_supervised import BYOL
        from pl_bolts.datamodules import CIFAR10DataModule
        from pl_bolts.models.self_supervised.simclr.transforms import (
            SimCLREvalDataTransform, SimCLRTrainDataTransform)
        # model
        model = BYOL(num_classes=10)
        # data
        dm = CIFAR10DataModule(num_workers=0)
        dm.train_transforms = SimCLRTrainDataTransform(32)
        dm.val_transforms = SimCLREvalDataTransform(32)
@@ -184,7 +176,7 @@ def add_model_specific_args(parent_parser):


 def cli_main():
-    from pl_bolts.callbacks.self_supervised import SSLOnlineEvaluator
+    from pl_bolts.callbacks.ssl_online import SSLOnlineEvaluator
     from pl_bolts.datamodules import CIFAR10DataModule, STL10DataModule, ImagenetDataModule
     from pl_bolts.models.self_supervised.simclr import SimCLRTrainDataTransform, SimCLREvalDataTransform

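Putting the pieces together, an end-to-end BYOL run with the relocated callback might look like the sketch below, assembled from the docstring example above plus the new byol_updates import path; max_epochs=1 is an illustrative value, not from this commit:

import pytorch_lightning as pl
from pl_bolts.callbacks.byol_updates import BYOLMAWeightUpdate
from pl_bolts.datamodules import CIFAR10DataModule
from pl_bolts.models.self_supervised import BYOL
from pl_bolts.models.self_supervised.simclr.transforms import (
    SimCLREvalDataTransform, SimCLRTrainDataTransform)

# model
model = BYOL(num_classes=10)

# data
dm = CIFAR10DataModule(num_workers=0)
dm.train_transforms = SimCLRTrainDataTransform(32)
dm.val_transforms = SimCLREvalDataTransform(32)

trainer = pl.Trainer(max_epochs=1, callbacks=[BYOLMAWeightUpdate()])
trainer.fit(model, datamodule=dm)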
2 changes: 1 addition & 1 deletion pl_bolts/models/self_supervised/cpc/cpc_module.py
@@ -11,7 +11,6 @@
import torch.optim as optim
from pytorch_lightning.utilities import rank_zero_warn

-from pl_bolts.callbacks.self_supervised import SSLOnlineEvaluator
from pl_bolts.losses.self_supervised_learning import CPCTask
from pl_bolts.models.self_supervised.cpc.networks import cpc_resnet101
from pl_bolts.models.self_supervised.cpc.transforms import (
@@ -230,6 +229,7 @@ def add_model_specific_args(parent_parser):


 def cli_main():
+    from pl_bolts.callbacks.ssl_online import SSLOnlineEvaluator
     from pl_bolts.datamodules import CIFAR10DataModule, STL10DataModule
     from pl_bolts.datamodules.ssl_imagenet_datamodule import SSLImagenetDataModule

2 changes: 1 addition & 1 deletion pl_bolts/models/self_supervised/simclr/simclr_module.py
@@ -12,7 +12,6 @@
     warn('You want to use `torchvision` which is not installed yet,'  # pragma: no-cover
          ' install it with `pip install torchvision`.')

-from pl_bolts.callbacks.self_supervised import SSLOnlineEvaluator
from pl_bolts.losses.self_supervised_learning import nt_xent_loss
from pl_bolts.models.self_supervised.evaluator import Flatten
from pl_bolts.models.self_supervised.resnets import resnet50_bn
@@ -218,6 +217,7 @@ def add_model_specific_args(parent_parser):


 def cli_main():
+    from pl_bolts.callbacks.ssl_online import SSLOnlineEvaluator
     from pl_bolts.datamodules import CIFAR10DataModule, STL10DataModule, ImagenetDataModule

     parser = ArgumentParser()
8 changes: 4 additions & 4 deletions pl_bolts/models/self_supervised/swav/swav_online_eval.py
@@ -1,12 +1,14 @@
 from typing import Optional

-import pytorch_lightning as pl
 import torch
+from pytorch_lightning import Callback
 from pytorch_lightning.metrics.functional import accuracy
 from torch.nn import functional as F

+from pl_bolts.models.self_supervised.evaluator import SSLEvaluator


-class SwavOnlineEvaluator(pl.Callback):
+class SwavOnlineEvaluator(Callback):
     def __init__(
         self,
         drop_p: float = 0.2,
@@ -29,8 +31,6 @@ def __init__(
         self.acc = []

     def on_pretrain_routine_start(self, trainer, pl_module):
-        from pl_bolts.models.self_supervised.evaluator import SSLEvaluator
-
         pl_module.non_linear_evaluator = SSLEvaluator(
             n_input=self.z_dim,
             n_classes=self.num_classes,
2 changes: 1 addition & 1 deletion requirements.txt
@@ -1,2 +1,2 @@
 torch>=1.6
-pytorch-lightning>=1.0.2
+pytorch-lightning>=1.0
2 changes: 1 addition & 1 deletion tests/callbacks/test_param_update_callbacks.py
@@ -3,7 +3,7 @@
import torch
from torch import nn

-from pl_bolts.callbacks.self_supervised import BYOLMAWeightUpdate
+from pl_bolts.callbacks.byol_updates import BYOLMAWeightUpdate


def test_byol_ma_weight_update_callback(tmpdir):
4 changes: 2 additions & 2 deletions tests/models/self_supervised/test_models.py
@@ -3,14 +3,14 @@
import torch
from pytorch_lightning import seed_everything

-from pl_bolts.datamodules import CIFAR10DataModule, STL10DataModule
-from pl_bolts.transforms.dataset_normalizations import cifar10_normalization
+from pl_bolts.datamodules import CIFAR10DataModule
 from pl_bolts.models.self_supervised import CPCV2, AMDIM, MocoV2, SimCLR, BYOL, SwAV
 from pl_bolts.models.self_supervised.cpc import CPCTrainTransformsCIFAR10, CPCEvalTransformsCIFAR10
 from pl_bolts.models.self_supervised.moco.callbacks import MocoLRScheduler
 from pl_bolts.models.self_supervised.moco.transforms import (Moco2TrainCIFAR10Transforms, Moco2EvalCIFAR10Transforms)
 from pl_bolts.models.self_supervised.simclr.transforms import SimCLREvalDataTransform, SimCLRTrainDataTransform
 from pl_bolts.models.self_supervised.swav.transforms import SwAVTrainDataTransform, SwAVEvalDataTransform
+from pl_bolts.transforms.dataset_normalizations import cifar10_normalization


# TODO: this test is hanging (runs for more then 10min) so we need to use GPU or optimize it...
