
Allow BF16 dtype support on CPU #1218

Merged 1 commit into pytorch:main on Jul 26, 2024

Conversation

@sanchitintel (Contributor) commented Jul 24, 2024

Description

PyTorch supports the BF16 dtype on CPUs. If a CPU lacks BF16-related ISAs such as AVX512_BF16 and AMX_BF16, PyTorch falls back to BF16 <-> FP32 conversions, and the compute happens in FP32.
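For illustration (not part of this PR), BF16 tensors already work on CPU in stock PyTorch; the snippet below is a minimal sketch of that behavior. On CPUs without AVX512_BF16/AMX_BF16, the matmul is computed via FP32 internally:

    import torch

    # BF16 tensors can be created and used on CPU; without AVX512_BF16/AMX_BF16
    # support, PyTorch upconverts to FP32 for the actual compute.
    a = torch.randn(4, 4, dtype=torch.bfloat16, device="cpu")
    b = torch.randn(4, 4, dtype=torch.bfloat16, device="cpu")
    print((a @ b).dtype)  # torch.bfloat16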

Changelog

The BF16 dtype support check now returns True for CPUs.
Previously, if PyTorch was installed without CUDA support and the device type was set to cpu in a config YAML file, the quantize.py recipe failed: the calls to utils.get_dtype() did not pass the device argument, so it defaulted to None and utils.get_dtype() raised an error saying the device does not support BFloat16.
Added the device argument to the calls to utils.get_dtype() (a rough sketch of the pattern follows below).
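For illustration only — torchtune's actual utils.get_dtype implementation may differ — the failure mode and the fix described above roughly follow this pattern:

    import torch

    def get_dtype(dtype=None, device=None):
        # Hypothetical sketch, not torchtune's real code: resolve a string dtype
        # and validate BF16 support for the target device.
        torch_dtype = {"fp32": torch.float32, "bf16": torch.bfloat16}.get(dtype, torch.float32)
        if torch_dtype is torch.bfloat16:
            if device is not None and device.type == "cpu":
                # With this PR, BF16 is accepted on CPU (compute falls back to FP32).
                return torch_dtype
            if not (torch.cuda.is_available() and torch.cuda.is_bf16_supported()):
                raise RuntimeError("Device does not support BFloat16.")
        return torch_dtype

    # Before the fix: resolving "bf16" with device=None raised on CPU-only builds.
    # After the fix: the recipe passes the device explicitly.
    dtype = get_dtype("bf16", device=torch.device("cpu"))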

Test plan

Verified manually that the quantization recipe on CPU no longer fails in utils.get_dtype().
I could add a unit test for the CPU device, if required (a possible sketch follows below). Thanks!
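A possible unit test for the CPU path (a sketch only; the exact torchtune utils API is an assumption here) might look like:

    import torch
    from torchtune import utils

    def test_get_dtype_bf16_on_cpu():
        # With this change, BF16 should be accepted when the device is explicitly CPU.
        assert utils.get_dtype("bf16", device=torch.device("cpu")) == torch.bfloat16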

pytorch-bot bot commented Jul 24, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchtune/1218

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit c832dd6 with merge base 6e4809a:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot

Hi @sanchitintel!

Thank you for your pull request.

We require contributors to sign our Contributor License Agreement, and yours needs attention.

You currently have a record in our system, but the CLA is no longer valid, and will need to be resubmitted.

Process

In order for us to review and merge your suggested changes, please sign at https://code.facebook.com/cla. If you are contributing on behalf of someone else (e.g. your employer), the individual CLA may not be sufficient and your employer may need to sign the corporate CLA.

Once the CLA is signed, our tooling will perform checks and validations. Afterwards, the pull request will be tagged with CLA signed. The tagging process may take up to 1 hour after signing. Please give it that time before contacting us about it.

If you have received this in error or have any questions, please contact us at [email protected]. Thanks!

@facebook-github-bot

Thank you for signing our Contributor License Agreement. We can now accept your code for this (and any) Meta Open Source project. Thanks!

@facebook-github-bot added the CLA Signed label on Jul 24, 2024
@joecummings (Contributor)

@ebsmothers Any reason not to do this?

@joecummings (Contributor) left a comment

@sanchitintel Great PR! Could you just run lint for these files?

@sanchitintel (Contributor, Author) commented Jul 24, 2024

Hi @joecummings, sorry, I haven't verified this change (and it doesn't look correct). I'll request review after verifying it, and will remove Draft mode for this PR. Thanks!

@sanchitintel force-pushed the patch-1 branch 3 times, most recently from 55d9518 to 12c2dad, on July 24, 2024 at 21:18
@sanchitintel (Contributor, Author) commented Jul 24, 2024

Hi @joecummings, the PR is now ready for review. Thanks!

On my end, though, running pre-commit install to set up the linter fails with importlib.metadata.PackageNotFoundError: No package metadata was found for pre-commit, even though I installed pre-commit:

@sanchitintel marked this pull request as ready for review on July 24, 2024 at 22:22
@SalmanMohammadi (Collaborator) commented Jul 24, 2024

Thanks for adding this @sanchitintel : ) pretty neat change

Perhaps one suggestion while we're here. RE this comment:

    # TODO (rohan-varma): prefer to use get_default_device() here to figure out whether user is training on
    # CPU or GPU, but it is not supported in versions of torch we test.

It looks like get_default_device has landed in stable - not sure whether it's worth just switching to it and omitting the device param?

@ebsmothers (Contributor)

Thanks for the PR! I partially agree with @SalmanMohammadi -- get_default_device is probably the cleanest way to handle this. But I also think we should still pass the device explicitly once we infer it in the recipe (otherwise we have to infer it in the dtype utilities which creates needless separation on where defaults are defined).

So I think we can delete _get_device_type_from_env and just replace its usage here with a call to get_default_device. Then when we call get_device in the recipe (e.g. here for the generate recipe) we can pass that explicitly to get_dtype as you've done here.
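Roughly, the recipe-side pattern being proposed here would be the following (a sketch only; the exact torchtune call sites and config fields are assumptions):

    import torch
    from torchtune import utils
    from omegaconf import DictConfig

    def setup_dtype(cfg: DictConfig) -> torch.dtype:
        # Resolve the device once in the recipe, then pass it explicitly to
        # get_dtype so the dtype utility does not have to infer it.
        device = utils.get_device(device=cfg.device)
        return utils.get_dtype(cfg.dtype, device=device)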

@ebsmothers (Contributor)

Ah sorry @sanchitintel I did not look at the code for get_default_device closely enough. I actually don't think we should use this after all (this is why our GPU unit tests are failing). This just gives the default device that tensors will be allocated to (which is generally CPU, even if there are GPUs available) rather than telling us if there are CUDA devices available. For example, on my machine with GPUs:

>>> import torch
>>> torch.get_default_device()
device(type='cpu')
>>> torch.cuda.is_available()
True

So this is actually doing something different than our existing _get_device_type_from_env. Sorry for the thrash on this, but I think you should revert to the previous version of the PR and we can just land that. Please also remove the TODO referenced by @SalmanMohammadi as this turned out to be misleading.
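For context, availability-based detection (the role _get_device_type_from_env plays) is closer to this sketch than to torch.get_default_device():

    import torch

    def _infer_device_type() -> str:
        # Availability-based detection: reports whether CUDA devices exist,
        # unlike torch.get_default_device(), which only reports the current
        # default allocation device (usually CPU).
        return "cuda" if torch.cuda.is_available() else "cpu"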

@sanchitintel (Contributor, Author) commented Jul 25, 2024

Thanks for your prompt feedback, @SalmanMohammadi & @ebsmothers!

I reverted the change pertaining to torch.get_default_device, and also removed the related comment.

@ebsmothers (Contributor) left a comment

Looks good, thanks for fixing this!

@ebsmothers merged commit e101420 into pytorch:main on Jul 26, 2024
29 checks passed