Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix compile error in get_causal_mask_from_padding_mask #1627

Merged

Conversation

SalmanMohammadi
Copy link
Collaborator

Context

What is the purpose of this PR? Is it to

  • add a new feature
  • fix a bug
  • update tests and/or documentation
  • other (please add here)

Changelog

What are the changes made in this PR?

Fixing a compile bug in get_causal_mask_from_padding_mask.

Repro on main

import torch
from torchtune.generation import get_causal_mask_from_padding_mask
get_causal_mask_from_padding_mask = torch.compile(get_causal_mask_from_padding_mask, fullgraph=True, backend="aot_eager")
get_causal_mask_from_padding_mask(torch.ones((10)).unsqueeze(0).bool())
torch._dynamo.exc.BackendCompilerFailed: backend='aot_eager' raised:
RuntimeError: aten::copy() Expected a value of type 'Tensor' for argument 'src' but instead found type 'bool'.
Position: 1
Value: True
Declaration: aten::copy(Tensor self, Tensor src, bool non_blocking=False) -> Tensor
Cast error details: Unable to cast True to Tensor

While executing %copy_ : [num_users=0] = call_method[target=copy_](args = (%diagonal, True), kwargs = {})
Original traceback:
  File "/Users/salmanmohammadi/projects/torchtune/torchtune/generation/_generation.py", line 150, in get_causal_mask_from_padding_mask
    mask.diagonal(dim1=1, dim2=2).copy_(True)


Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information


You can suppress this exception and fall back to eager by setting:
    import torch._dynamo
    torch._dynamo.config.suppress_errors = True

This error does not occur on this branch.

Test plan

Please make sure to do each of the following if applicable to your PR. If you're unsure about any one of these just ask and we will happily help. We also have a contributing page for some guidance on contributing.

  • run pre-commit hooks and linters (make sure you've first installed via pre-commit install)
  • add unit tests for any new functionality
  • update docstrings for any new or updated methods or classes
  • run unit tests via pytest tests
  • run recipe tests via pytest tests -m integration_test
  • manually run any new or modified recipes with sufficient proof of correctness
  • include relevant commands and any other artifacts in this summary (pastes of loss curves, eval results, etc.)

UX

If your function changed a public API, please add a dummy example of what the user experience will look like when calling it.
Here is a docstring example
and a tutorial example

  • I did not change any public API
  • I have added an example to docs or docstrings

Copy link

pytorch-bot bot commented Sep 19, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchtune/1627

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 8ea0473 with merge base c5db813 (image):
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot facebook-github-bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Sep 19, 2024
@codecov-commenter
Copy link

Codecov Report

Attention: Patch coverage is 0% with 1 line in your changes missing coverage. Please review.

Project coverage is 26.23%. Comparing base (dd348ce) to head (8ea0473).
Report is 7 commits behind head on main.

Files with missing lines Patch % Lines
torchtune/generation/_generation.py 0.00% 1 Missing ⚠️

❗ There is a different number of reports uploaded between BASE (dd348ce) and HEAD (8ea0473). Click for more details.

HEAD has 1 upload less than BASE
Flag BASE (dd348ce) HEAD (8ea0473)
3 2
Additional details and impacted files
@@             Coverage Diff             @@
##             main    #1627       +/-   ##
===========================================
- Coverage   72.26%   26.23%   -46.03%     
===========================================
  Files         290      295        +5     
  Lines       14554    15079      +525     
===========================================
- Hits        10517     3956     -6561     
- Misses       4037    11123     +7086     

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

@SalmanMohammadi SalmanMohammadi merged commit e3718e8 into pytorch:main Sep 20, 2024
17 checks passed
@SalmanMohammadi SalmanMohammadi deleted the fix_compile_get_causal_mask branch September 27, 2024 13:16
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed.
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants