1/n - remove TiedEmbeddingTransformerDecoder from qwen #1547
Conversation
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchtune/1547
Note: Links to docs will display an error until the docs builds have been completed.
✅ No failures as of commit b26b4fc with merge base d7fae96.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
Great to see this. I remember the lambda having some weird interactions with FSDP in the past, but that may not be the case with FSDP2. As long as you're able to test on a distributed recipe and get identical loss, I have no concerns. Stamping to unblock.
@@ -10,8 +10,8 @@
import torch.nn.functional as F
from torch import nn
from torchtune.modules import MultiHeadAttention
PS: I forgot to update the type hint of TransformerDecoder to say that output can now be a callable. To avoid rerunning tests, this will come in a follow-up gemma PR.
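For reference, a rough sketch of what the updated annotation could look like (the alias name and exact Union form are assumptions, not the final signature):

```python
from typing import Callable, Union

import torch
from torch import nn

# Sketch only: the decoder's `output` argument would accept either a standard
# linear head or any callable that maps hidden states to logits.
OutputProjection = Union[nn.Linear, Callable[[torch.Tensor], torch.Tensor]]
```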
This looks great. One other request for a sanity check prior to landing: please make sure that you're able to save the checkpoint and then resume training properly (especially for a distributed run). Other than that, no concerns from me!
Codecov Report
Attention: Patch coverage is
Additional details and impacted files

@@            Coverage Diff             @@
##             main    #1547      +/-   ##
==========================================
- Coverage   73.36%   73.35%    -0.02%
==========================================
  Files         287      288        +1
  Lines       14142    14151        +9
==========================================
+ Hits        10375    10380        +5
- Misses       3767     3771        +4

☔ View full report in Codecov by Sentry.
torchtune/modules/__init__.py (Outdated)
@@ -32,7 +33,7 @@
     "KVCache",
     "RotaryPositionalEmbeddings",
     "RMSNorm",
-    "Fp32LayerNorm",
+    "TiedLinear" "Fp32LayerNorm",
formatting weird?
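For context: without a comma, Python concatenates the adjacent string literals, so `__all__` would export a single merged name instead of two. A minimal illustration in plain Python (not torchtune code):

```python
# Adjacent string literals are concatenated at parse time, so the missing comma
# silently produces one bogus export instead of two separate ones.
exports_with_bug = ["RMSNorm", "TiedLinear" "Fp32LayerNorm"]
exports_fixed = ["RMSNorm", "TiedLinear", "Fp32LayerNorm"]

print(exports_with_bug)  # ['RMSNorm', 'TiedLinearFp32LayerNorm']
print(exports_fixed)     # ['RMSNorm', 'TiedLinear', 'Fp32LayerNorm']
```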
two nits
@@ -516,6 +516,10 @@ def forward(
        return output


@deprecated(
Tell them to use TransformerDecoder WITH TiedLinear.
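Agreed; the notice could point at the replacement explicitly, something along these lines (the `msg` keyword and the exact wording are assumptions, not the final text):

```python
from torch import nn

# `deprecated` is the decorator applied in the diff above; the message wording
# here is illustrative only.
@deprecated(
    msg="TiedEmbeddingTransformerDecoder is deprecated. Use TransformerDecoder "
    "with TiedLinear(tok_embeddings) as the output projection instead."
)
class TiedEmbeddingTransformerDecoder(nn.Module):
    ...
```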
Co-authored-by: Felipe Mello <[email protected]>
Context
What is the purpose of this PR?
We don't need TiedEmbeddingTransformerDecoder if we pass output_proj as a lambda.
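As a rough sketch of the idea (the `TiedLinearSketch` name and exact API are illustrative, not the torchtune implementation added in this PR):

```python
import torch
import torch.nn.functional as F
from torch import nn


class TiedLinearSketch(nn.Module):
    """Illustrative only: reuse the token-embedding weight as the output head."""

    def __init__(self, tok_embeddings: nn.Embedding):
        super().__init__()
        self.tok_embeddings = tok_embeddings

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # No separate weight is allocated; the embedding matrix is applied as the
        # output projection, so it stays tied to the embeddings by construction.
        return F.linear(x, self.tok_embeddings.weight)


tok_embeddings = nn.Embedding(num_embeddings=1000, embedding_dim=64)
output_proj = TiedLinearSketch(tok_embeddings)
# Equivalently, the same tying can be expressed as a lambda:
# output_proj = lambda x: F.linear(x, tok_embeddings.weight)
```

Either form keeps a single copy of the weight, which is the guarantee the dedicated TiedEmbeddingTransformerDecoder class existed to provide.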
Changelog
Test plan
tune run --nnodes 1 --nproc_per_node 8 full_finetune_distributed --config qwen2/1.5B_full batch_size=8 max_steps_per_epoch=20 metric_logger=torchtune.training.metric_logging.WandBLogger gradient_accumulation_steps=1 epochs=1
Resume from checkpoint: ran it twice, the second time resuming from the first run's checkpoint.
tune run --nnodes 1 --nproc_per_node 8 lora_finetune_distributed --config qwen2/0.5B_lora batch_size=8 max_steps_per_epoch=20 metric_logger=torchtune.training.metric_logging.WandBLogger gradient_accumulation_steps=1 epochs=2 compile=True
Also added a check to the transformer to verify that the weights were still tied:
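A minimal version of such a check (the function and parameter names here are assumptions, not necessarily the exact code used in the run) might look like:

```python
import torch
from torch import nn


def assert_weights_tied(tok_embeddings: nn.Embedding, output_weight: torch.Tensor) -> None:
    """Fail loudly if the output projection no longer shares the embedding weight."""
    # Comparing data_ptr() checks that both tensors share the same storage, which
    # also catches a silent copy (e.g. after a bad checkpoint load).
    assert output_weight.data_ptr() == tok_embeddings.weight.data_ptr(), (
        "output projection is no longer tied to the token embeddings"
    )
```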