-
Notifications
You must be signed in to change notification settings - Fork 323
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* min req * imports * imports * split * imports * imports
- Loading branch information
Showing
10 changed files
with
80 additions
and
90 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,59 @@ | ||
import math | ||
|
||
from pytorch_lightning import Callback | ||
|
||
|
||
class BYOLMAWeightUpdate(Callback):
    """Weight update rule from BYOL.

    Your model should have:

        - ``self.online_network``
        - ``self.target_network``

    Updates the target_network params using an exponential moving average
    update rule weighted by tau. BYOL claims this keeps the online_network
    from collapsing.

    .. note:: Automatically increases tau from ``initial_tau`` to 1.0 with
        every training step (cosine schedule, see :meth:`update_tau`).

    Example::

        # model must have 2 attributes
        model = Model()
        model.online_network = ...
        model.target_network = ...

        trainer = Trainer(callbacks=[BYOLMAWeightUpdate()])
    """

    def __init__(self, initial_tau=0.996):
        """
        Args:
            initial_tau: starting tau. Auto-updates with every training step
        """
        super().__init__()
        self.initial_tau = initial_tau
        self.current_tau = initial_tau

    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx=0):
        """Apply the moving-average update, then advance tau.

        ``dataloader_idx`` defaults to 0 so the hook can be invoked both with
        and without that trailing argument; it is not used by this callback.
        """
        # get networks
        online_net = pl_module.online_network
        target_net = pl_module.target_network

        # update weights using the tau computed after the previous step
        self.update_weights(online_net, target_net)

        # update tau after
        self.current_tau = self.update_tau(pl_module, trainer)

    def update_tau(self, pl_module, trainer):
        """Return tau for the current ``pl_module.global_step``.

        Tau follows a half-cosine from ``initial_tau`` (step 0) to 1.0 at
        ``len(trainer.train_dataloader) * trainer.max_epochs`` steps.
        """
        max_steps = len(trainer.train_dataloader) * trainer.max_epochs
        if max_steps <= 0:
            # No schedule length to normalise by; keep tau unchanged rather
            # than dividing by zero.
            return self.current_tau

        # Clamp so tau stays at 1.0 instead of swinging back down the cosine
        # if global_step ever runs past the computed schedule length.
        progress = min(pl_module.global_step / max_steps, 1.0)
        tau = 1 - (1 - self.initial_tau) * (math.cos(math.pi * progress) + 1) / 2
        return tau

    def update_weights(self, online_net, target_net):
        """Blend every online parameter into the matching target parameter.

        Parameters are paired positionally, so both networks are expected to
        yield their parameters in the same order. All parameters are averaged
        (not only those whose name contains ``'weight'``), so biases and
        normalisation shifts in the target network track the online network.
        """
        # apply MA weight update; working on .data keeps this out of autograd
        for online_p, target_p in zip(online_net.parameters(), target_net.parameters()):
            target_p.data = self.current_tau * target_p.data + (1 - self.current_tau) * online_p.data
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,2 +1,2 @@ | ||
torch>=1.6 | ||
pytorch-lightning>=1.0.2 | ||
pytorch-lightning>=1.0 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters