Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add BlockMask to batch_to_device #1651

Merged
merged 2 commits into from
Sep 23, 2024

Conversation

felipemello1
Copy link
Contributor

@felipemello1 felipemello1 commented Sep 22, 2024

Context

What is the purpose of this PR? Is it to

  • add a new feature
  • fix a bug
  • update tests and/or documentation
  • other (please add here)

batch_to_device only checks dict and tensors. When input is BlockMask (packed=True), it triggers the RaiseError condition.

Changelog

  • Improve error message
  • Add BlockMask as one of the isinstance checks
  • Add file _import_guard.py to check for flex attention, since we have two different files checking it, and we should have only one source of truth

Test plan

After this PR, the code below runs well
tune run full_finetune_single_device --config llama3_1/8B_full_single_device dataset.packed=True tokenizer.max_seq_len=4096

If I remove the check for BlockMask, the error message is improved:

ValueError: To use batch_to_device, all elements in the batch must be a dict or Tensor.
Got key "mask" with value of type <class 'torch.nn.attention.flex_attention.BlockMask'>

Copy link

pytorch-bot bot commented Sep 22, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchtune/1651

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 0dd8ec5 with merge base bf93806 (image):
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot facebook-github-bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Sep 22, 2024
@codecov-commenter
Copy link

Codecov Report

Attention: Patch coverage is 80.00000% with 2 lines in your changes missing coverage. Please review.

Project coverage is 69.09%. Comparing base (bf93806) to head (0dd8ec5).

Files with missing lines Patch % Lines
torchtune/utils/_device.py 66.66% 2 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##             main    #1651      +/-   ##
==========================================
- Coverage   71.04%   69.09%   -1.95%     
==========================================
  Files         296      297       +1     
  Lines       15112    15120       +8     
==========================================
- Hits        10736    10447     -289     
- Misses       4376     4673     +297     
Flag Coverage Δ
69.09% <80.00%> (-1.95%) ⬇️

Flags with carried forward coverage won't be shown. Click here to find out more.

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

@@ -142,7 +149,10 @@ def batch_to_device(batch: dict, device: torch.device) -> None:
batch_to_device(v, device)
elif isinstance(v, torch.Tensor):
batch[k] = v.to(device)
elif _SUPPORTS_FLEX_ATTENTION and isinstance(v, BlockMask):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you just combine this case and the above with or?

@felipemello1 felipemello1 merged commit 50b24e5 into pytorch:main Sep 23, 2024
17 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed.
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants