SimSiam #407
Hello @zlapp! Thanks for updating this PR.
Comment last updated at 2021-01-17 20:56:04 UTC
Codecov Report
@@            Coverage Diff             @@
##           master     #407      +/-   ##
==========================================
- Coverage   79.49%   78.95%   -0.55%
==========================================
  Files         102      105       +3
  Lines        5912     6121     +209
==========================================
+ Hits         4700     4833     +133
- Misses       1212     1288      +76
# Image 1 to image 2 loss
_, z1, h1 = self.online_network(img_1)
_, z2, h2 = self.target_network(img_2)
loss_a = -1.0 * self.cosine_similarity(h1, z2)

# Image 2 to image 1 loss
_, z1, h1 = self.online_network(img_2)
_, z2, h2 = self.target_network(img_1)
loss_b = -1.0 * self.cosine_similarity(h1, z2)
I believe it is equivalent. For example, the BYOL implementation in pl_bolts uses a deep copy; here there isn't one: https://github.com/lucidrains/byol-pytorch
Nonetheless, great job on the implementation. It might be worth investigating in the future whether the two methods differ in performance. For instance, might this approach use more memory?
With a deep copy instead of using the same network twice, you use ~3 GB of extra GPU memory with resnet18. The two versions are about equally fast on CIFAR10, but I think that's because the GPU is memory-bound on the task.
Given that the final performance ends up the same, I guess removing the copied network would be a good idea, as it's less wasteful and scales better.
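For readers following the thread, here is a minimal sketch of the single-network variant discussed above: one network processes both views, and `detach()` plays the role of the stop-gradient that the copied target network would otherwise provide. `TinySiamese`, `neg_cosine`, and `simsiam_loss` are hypothetical names for illustration and are not the PR's actual classes; the 0.5 weighting follows the symmetric loss as defined in the SimSiam paper and may differ from the code in this PR.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class TinySiamese(nn.Module):
    """Hypothetical stand-in for the real network (not the PR's class)."""

    def __init__(self, in_dim=8, dim=4):
        super().__init__()
        self.projector = nn.Linear(in_dim, dim)  # stands in for encoder + projection head
        self.predictor = nn.Linear(dim, dim)     # prediction head

    def forward(self, x):
        z = self.projector(x)  # projection
        h = self.predictor(z)  # prediction
        return z, h


def neg_cosine(h, z):
    # detach() is the stop-gradient: z is treated as a constant target.
    return -F.cosine_similarity(h, z.detach(), dim=-1).mean()


def simsiam_loss(net, img_1, img_2):
    # One shared network sees both views; no deep copy is needed.
    z1, h1 = net(img_1)
    z2, h2 = net(img_2)
    return 0.5 * neg_cosine(h1, z2) + 0.5 * neg_cosine(h2, z1)


torch.manual_seed(0)
net = TinySiamese()
img_1, img_2 = torch.randn(16, 8), torch.randn(16, 8)
loss = simsiam_loss(net, img_1, img_2)
loss.backward()
```

Because both views go through the same module, each training step needs two forward passes instead of four, and no second set of weights is kept in memory.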
Thanks @MikkelAntonsen.
Could you please share the version you ran, with a plot of the results on CIFAR10?
Thanks for sharing @MikkelAntonsen.
I just pushed a commit with the changes you suggested in the gist.
Great job making the improvements in efficiency while maintaining accuracy.
I noticed that the SSLOnlineEvaluator does an additional forward pass per train_batch to get embeddings. It would be possible to return the encoder output from training_step(), which would then be accessible in the outputs argument in SSLOnlineEvaluator, AFAIK. This is unfortunate because it couples the implementation of SSLOnlineEvaluator with the networks that use it. But if we are to reproduce results on ImageNet, is reducing the number of forward passes by 1/3 negligible?
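The saving described above can be illustrated with a framework-free sketch that simply counts encoder calls. Everything here (`CountingEncoder`, `training_step`, `online_eval`, the `"embeddings"` key) is hypothetical and does not reproduce the pl_bolts or PyTorch Lightning API. It also assumes the evaluator can work with the embeddings of one of the training views, which may not hold if the evaluator expects a differently augmented input.

```python
class CountingEncoder:
    """Hypothetical encoder that records how many forward passes it performs."""

    def __init__(self):
        self.calls = 0

    def __call__(self, batch):
        self.calls += 1
        return [float(x) * 2.0 for x in batch]  # dummy "embedding"


def training_step(encoder, view_1, view_2, return_embeddings):
    # Two forward passes, one per augmented view.
    emb_1 = encoder(view_1)
    emb_2 = encoder(view_2)
    outputs = {"loss": 0.0}  # placeholder; a real step would compute the SSL loss
    if return_embeddings:
        outputs["embeddings"] = emb_1  # hand the embeddings to the evaluator
    return outputs


def online_eval(encoder, outputs, batch):
    # Reuse embeddings when the training step provided them; otherwise recompute.
    if "embeddings" in outputs:
        return outputs["embeddings"]
    return encoder(batch)


def passes_per_batch(return_embeddings):
    encoder = CountingEncoder()
    batch = [1, 2, 3]
    outputs = training_step(encoder, batch, batch, return_embeddings)
    online_eval(encoder, outputs, batch)
    return encoder.calls


separate = passes_per_batch(return_embeddings=False)  # evaluator recomputes
reused = passes_per_batch(return_embeddings=True)     # evaluator reuses outputs
```

In this toy setup the count drops from three encoder calls per batch to two, which is the 1/3 reduction mentioned in the comment; the price is the coupling between the training step's return value and the evaluator.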
Hi @MikkelAntonsen, I believe this relates more to the SSL capabilities of pl_bolts in general, so it might be better to open a separate issue, since this isn't only affecting SimSiam (tagging @ananyahjha93).
@MikkelAntonsen Maybe the BYOL implementation could also benefit from using only one network (without deepcopy) and using detach() to control the gradient flow?
I just read the BYOL paper abstract and it seems like the online network and target network use different sets of weights, so I'm not sure how you could share a network. If you look at figure 2 in the paper, it does seem like they use the stop-gradient trick, but only for the target network. Do you see any way to incorporate SimSiam ideas into BYOL without the implementation just ending up as SimSiam?
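To make the difference in weights concrete: in BYOL the target weights are an exponential moving average (EMA) of the online weights, and the SimSiam paper describes SimSiam as BYOL without the momentum encoder. The sketch below uses plain Python lists as stand-ins for parameter tensors; `ema_update` and the `tau` values are illustrative assumptions, not code from pl_bolts or from either paper.

```python
def ema_update(target, online, tau):
    """Return tau * target + (1 - tau) * online, element-wise."""
    return [tau * t + (1.0 - tau) * o for t, o in zip(target, online)]


online = [1.0, 2.0]  # stand-in for the online network's weights
target = [0.0, 0.0]  # stand-in for the target network's weights

# With tau close to 1 the target trails the online weights slowly (BYOL-style).
slow = ema_update(target, online, tau=0.99)

# With tau = 0 the target is identical to the online weights, so a separate
# copy is redundant; only the stop-gradient on the target branch remains.
shared = ema_update(target, online, tau=0.0)
```

Under this reading, a single shared network with `detach()` corresponds to the `tau = 0` case, which is why it is hard to share the network in BYOL without effectively arriving at SimSiam.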
@zlapp how is it going here, is it still WIP?
Hi @Borda, based on the results here #407 (comment) I believe the PR is ready to be merged.
@zlapp Would you mind having a look at https://github.com/zlapp/pytorch-lightning-bolts/pull/1?
@zlapp as a potential user thanks so much for your (+ the PR reviewers!) work, this looks great. I do have one question: the paper notes in appendix B that one difference between SimSiam and BYOL is the bottleneck structure in the predictor. Namely, the hidden dimension of the predictor MLP should be 1/4 of the output dimension. So, for example, with the default prediction space dimension of 256, I think the hidden dim should be 64. This apparently helps with training stability. Do you agree with my reading of the paper? If so, I'm wondering if this should be the default in the bolt as well?
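A minimal sketch of the bottleneck predictor described in this question, assuming a two-layer MLP whose hidden width is a quarter of its input/output width, with batch norm and ReLU on the hidden layer only. `bottleneck_predictor` and its `ratio` argument are hypothetical names and are not taken from the PR.

```python
import torch
import torch.nn as nn


def bottleneck_predictor(dim=256, ratio=4):
    """Two-layer MLP with hidden width dim // ratio (hypothetical helper)."""
    hidden = dim // ratio
    return nn.Sequential(
        nn.Linear(dim, hidden),
        nn.BatchNorm1d(hidden),
        nn.ReLU(inplace=True),
        nn.Linear(hidden, dim),  # no batch norm or ReLU on the output layer
    )


predictor = bottleneck_predictor()        # dim=256 gives a hidden width of 64
out = predictor(torch.randn(8, 256))      # output keeps the input shape
hidden_dim = predictor[0].out_features
```

With the default of 256 used in the comment, the hidden layer has 64 units, matching the 1/4 ratio the commenter reads from the paper.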
Could you also have a look at https://github.com/zlapp/pytorch-lightning-bolts/pull/2?
👍
The same question as @wjn0.
What does this PR do?
Implement https://arxiv.org/pdf/2011.10566v1.pdf
Largely based on the extension of BYOL to support SimSiam in https://github.com/lucidrains/byol-pytorch.
I used pl-bolts BYOL implementation as a reference.
Colab gist for testing on cifar-10 https://gist.github.com/zlapp/c35b8c97d4f6537f21aa07bbc37959c9
Discussed on slack channel https://pytorch-lightning.slack.com/archives/C010PRC9M2R/p1606329394008100
Also adds KNN online evaluation callback
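As a rough illustration of what a KNN evaluation of embeddings involves (this is not the callback added in this PR), the sketch below classifies query embeddings by a majority vote over their k most cosine-similar neighbours in a labelled feature bank. All names and the toy data are assumptions made for the example.

```python
import math
from collections import Counter


def cosine(a, b):
    # Cosine similarity between two equal-length vectors.
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / (norm_a * norm_b)


def knn_predict(query, bank, labels, k=3):
    # Rank bank entries by similarity to the query, then take a majority vote.
    ranked = sorted(range(len(bank)), key=lambda i: cosine(query, bank[i]), reverse=True)
    votes = Counter(labels[i] for i in ranked[:k])
    return votes.most_common(1)[0][0]


def knn_accuracy(queries, query_labels, bank, labels, k=3):
    hits = sum(knn_predict(q, bank, labels, k) == y for q, y in zip(queries, query_labels))
    return hits / len(queries)


# Toy feature bank: class 0 points near the x-axis, class 1 near the y-axis.
bank = [[1.0, 0.0], [0.9, 0.1], [0.8, 0.3], [0.0, 1.0], [0.1, 0.9], [0.2, 0.8]]
labels = [0, 0, 0, 1, 1, 1]
acc = knn_accuracy([[1.0, 0.2], [0.1, 1.0]], [0, 1], bank, labels)
```

An online version would build the bank from training embeddings and score validation embeddings periodically, which needs no extra trainable classifier head.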
PR review
Anyone in the community is free to review the PR once the tests have passed.
If we didn't discuss your PR in Github issues there's a high chance it will not be merged.
Did you have fun?
Make sure you had fun coding 🙃