
1/n - remove TiedEmbeddingTransformerDecoder from qwen #1547

Merged
merged 14 commits into pytorch:main on Sep 12, 2024

Conversation

Contributor

@felipemello1 felipemello1 commented Sep 11, 2024

Context

What is the purpose of this PR? Is it to

  • add a new feature
  • fix a bug
  • update tests and/or documentation
  • other (please add here)

We don't need TiedEmbeddingTransformerDecoder if we pass output_proj as a lambda (a minimal sketch of the idea is shown after the changelog below).

Changelog

  • pass output_proj as a lambda to the model's TransformerDecoder
  • deprecate TiedEmbeddingTransformerDecoder, as Qwen is the only model using it. Gemma has its own transformer and will be deprecated in a follow-up PR
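
To make the idea concrete, here is a minimal, self-contained sketch (not the actual torchtune builder code; the TiedLinear name matches the module this PR adds in torchtune/modules/tied_linear.py, but the exact API and shapes below are assumptions):

```python
# Illustrative sketch only: the output projection reuses the token embedding weights,
# so no separate TiedEmbeddingTransformerDecoder class is needed.
import torch
import torch.nn as nn
import torch.nn.functional as F


class TiedLinear:
    """Callable output projection that reuses the token embedding weights."""

    def __init__(self, tied_module: nn.Embedding):
        self.tied_module = tied_module

    def __call__(self, x: torch.Tensor) -> torch.Tensor:
        # Equivalent to the lambda version: lambda h: F.linear(h, tok_embeddings.weight)
        return F.linear(x, self.tied_module.weight)


vocab_size, embed_dim = 1000, 64  # made-up sizes for illustration
tok_embeddings = nn.Embedding(vocab_size, embed_dim)

# Either a plain lambda, as in this PR's title...
output_proj = lambda h: F.linear(h, tok_embeddings.weight)  # noqa: E731
# ...or the small callable wrapper above, which is easier to reuse:
output_proj = TiedLinear(tok_embeddings)

h = torch.randn(2, 8, embed_dim)
print(output_proj(h).shape)  # torch.Size([2, 8, 1000])
```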

Test plan

tune run full_finetune_single_device --config qwen2/0.5B_full_single_device batch_size=64 max_steps_per_epoch=30 metric_logger=torchtune.training.metric_logging.WandBLogger

tune run --nnodes 1 --nproc_per_node 8 full_finetune_distributed --config qwen2/1.5B_full batch_size=8 max_steps_per_epoch=20 metric_logger=torchtune.training.metric_logging.WandBLogger gradient_accumulation_steps=1 epochs=1

Resume from checkpoint:

Ran the following command twice, the second time resuming from the checkpoint saved by the first run.

tune run --nnodes 1 --nproc_per_node 8 lora_finetune_distributed --config qwen2/0.5B_lora batch_size=8 max_steps_per_epoch=20 metric_logger=torchtune.training.metric_logging.WandBLogger gradient_accumulation_steps=1 epochs=2 compile=True

Also added the following code to the transformer forward to check that the weights were still tied:

```python
# Sanity check: the output projection should match a linear layer that uses the
# token embedding weights directly, i.e. the weights are still tied.
all_close = torch.allclose(
    self.output(h),
    torch.nn.functional.linear(h, self.tok_embeddings.weight),
    atol=1e-5,
)
print("all close" if all_close else "not all close")
```


pytorch-bot bot commented Sep 11, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchtune/1547

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit b26b4fc with merge base d7fae96:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot facebook-github-bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Sep 11, 2024
Contributor

@RdoubleA RdoubleA left a comment


Great to see this. I remember the lambda had some weird interactions with FSDP in the past, but that may not be the case with FSDP2. As long as you're able to test on a distributed recipe and get identical loss, I have no concerns. Stamping to unblock.

@@ -10,8 +10,8 @@
import torch.nn.functional as F
from torch import nn
from torchtune.modules import MultiHeadAttention

Contributor Author


PS: I forgot to update the type hint of TransformerDecoder to say that output can now be a callable. To avoid rerunning tests, this will come in the follow-up Gemma PR (a sketch of the annotation is below).
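
For reference, a sketch of what the updated annotation could look like (assumed, not the actual torchtune signature):

```python
# Assumed sketch of the follow-up type-hint change: `output` would accept either an
# nn.Linear module or any callable mapping hidden states to logits.
from typing import Callable, Union

import torch
from torch import nn

OutputProjection = Union[nn.Linear, Callable[[torch.Tensor], torch.Tensor]]
```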

Contributor

@ebsmothers ebsmothers left a comment


This looks great. One other request for a sanity check prior to landing: please make sure that you're able to save the checkpoint and then resume training properly (especially for a distributed run). Other than that, no concerns from me!

@codecov-commenter

Codecov Report

Attention: Patch coverage is 76.00000% with 6 lines in your changes missing coverage. Please review.

Project coverage is 73.35%. Comparing base (221031a) to head (a0bd26b).
Report is 1 commit behind head on main.

Files with missing lines Patch % Lines
torchtune/modules/tied_linear.py 60.00% 4 Missing ⚠️
torchtune/models/qwen2/_component_builders.py 75.00% 2 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##             main    #1547      +/-   ##
==========================================
- Coverage   73.36%   73.35%   -0.02%     
==========================================
  Files         287      288       +1     
  Lines       14142    14151       +9     
==========================================
+ Hits        10375    10380       +5     
- Misses       3767     3771       +4     


@felipemello1 felipemello1 changed the title remove TiedEmbeddingTransformerDecoder from qwen 1/n - remove TiedEmbeddingTransformerDecoder from qwen Sep 12, 2024
@@ -32,7 +33,7 @@
"KVCache",
"RotaryPositionalEmbeddings",
"RMSNorm",
"Fp32LayerNorm",
"TiedLinear" "Fp32LayerNorm",
Contributor


formatting weird?

Contributor

@joecummings joecummings left a comment


two nits

@@ -516,6 +516,10 @@ def forward(
return output


@deprecated(
Contributor


Tell them to use TransformerDecoder WITH TiedLinear.
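
A rough sketch of that suggestion (the decorator below is a minimal stand-in, not torchtune's actual deprecation helper, and the message wording is assumed):

```python
# Stand-in illustration only: warn users of the deprecated class and point them to
# TransformerDecoder with a TiedLinear output projection.
import warnings


def deprecated(msg: str = ""):
    """Minimal deprecation decorator that warns on instantiation."""

    def decorator(cls):
        orig_init = cls.__init__

        def wrapped_init(self, *args, **kwargs):
            warnings.warn(
                f"{cls.__name__} is deprecated. {msg}", FutureWarning, stacklevel=2
            )
            orig_init(self, *args, **kwargs)

        cls.__init__ = wrapped_init
        return cls

    return decorator


@deprecated(msg="Please use TransformerDecoder with TiedLinear as the output projection instead.")
class TiedEmbeddingTransformerDecoder:
    pass


TiedEmbeddingTransformerDecoder()  # warns: "TiedEmbeddingTransformerDecoder is deprecated. ..."
```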

@felipemello1 felipemello1 merged commit 7c51100 into pytorch:main Sep 12, 2024
17 checks passed
@felipemello1 felipemello1 deleted the remove_tied_embeddings branch September 12, 2024 18:20
ebsmothers pushed a commit that referenced this pull request Sep 17, 2024