Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enable float8 attention support (q/k/v) #1382

Open
wants to merge 11 commits into
base: main
Choose a base branch
from

Conversation

jerryzh168
Copy link
Contributor

@jerryzh168 jerryzh168 commented Dec 5, 2024

Summary:
This PR integrates flashattention 3 kernel: https://github.com/Dao-AILab/flash-attention/blob/1feb711f46563960fc10a8e659c93c300619504b/flash_attn/flash_attn_interface.py#L1102 to float8 affine quantized tensor.

To use the kernel, right now we need to manually add quantize call for q/k/v before sdpa op, but we can explore other APIs in the future

@sijiac is working on adding new variations of attention implementation in the future (per row, per column, per block scaling etc.).

Test Plan:
python test/dtypes/test_affine_quantized_float.py -k test_float8_attention

SAM2

tested on sam2 and seems to be a bit slower than before, this is reasonable because sam2 is using 16 and 32 head dimension, but fa3 requires 64 being the minimum size, we need to do some padding to make this work (pad 32 to 64) which is expected to increase runtime significantly.

llama2

llama2 without fallback: doesn't work because attn_mask is not supported.

llama2 numerics only

(just for testing, code is not checked in) tested on llama2 (with fallback to test numerics):

since attn_mask is not supported in flashattention 3 kernel, it's using the fallback path: https://github.com/pytorch/ao/pull/1382/files#diff-3019e8f38b0919dbaba5aa1329a697e89fc98749e35a7bdc274c71a0d3738ec2R285

no quantize attn:

wikitext: {'alias': 'wikitext', 'word_perplexity,none': 12.2451228989592, 'word_perplexity_stderr,none': 'N/A', 'byte_perplexity,none': 1.5975503222785694, 'byte_perplexity_stderr,none': 'N/A', 'bits_per_byte,none': 0.675861376279183, 'bits_per_byte_stderr,none': 'N/A'}

quantize attn:

wikitext: {'alias': 'wikitext', 'word_perplexity,none': 12.294668124182962, 'word_perplexity_stderr,none': 'N/A', 'byte_perplexity,none': 1.5987571166545238, 'byte_perplexity_stderr,none': 'N/A', 'bits_per_byte,none': 0.6769507810861697, 'bits_per_byte_stderr,none': 'N/A'}

Reviewers:

Subscribers:

Tasks:

Tags:

Summary:
att, right now we need to manually add quantize call for q/k/v before
sdpa op, but we can explore other APIs in the future

Test Plan:
TBD

Reviewers:

Subscribers:

Tasks:

Tags:
Copy link

pytorch-bot bot commented Dec 5, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/1382

Note: Links to docs will display an error until the docs builds have been completed.

❌ 1 New Failure

As of commit 2895626 with merge base 04d611a (image):

NEW FAILURE - The following job has failed:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot facebook-github-bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Dec 5, 2024
@jerryzh168 jerryzh168 added the topic: not user facing Use this tag if you don't want this PR to show up in release notes label Dec 5, 2024
@jerryzh168 jerryzh168 changed the title [WIP] Enable float8 attention support (q/k/v) Enable float8 attention support (q/k/v) Dec 5, 2024
q_float8_data = q_tensor_impl.float8_data
# change from scalar to tensor of size [1]
q_scale = q_tensor_impl.scale
q_scale = torch.tensor([q_scale], device=q_scale.device)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

are the scales on host?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

you mean q_scale before we call torch.tensor? they should be using the same device as original weight I think, so should be on cuda before

Copy link
Contributor

@drisspg drisspg left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Overall looks good, left a few comments. I think we can add the int8 kernel when it gets added as well for CPU

@jerryzh168 jerryzh168 requested review from drisspg and vkuzo December 5, 2024 23:51
from torchao.quantization.quant_api import _float8_symmetric_per_tensor_quant
original_dtype = v.dtype
if q.shape[-1] in [64, 128, 256]:
q = _float8_symmetric_per_tensor_quant(q)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We also likely need/want to apply the hadamard transform. I don't remember off hand if this is include in fav3 api

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

didn't see this, maybe we can add it after spinquant is integrated

Copy link

@bhack bhack Dec 8, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We also likely need/want to apply the hadamard transform. I don't remember off hand if this is include in fav3 api

https://pytorch.org/blog/hadacore/

Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we update the current SAM2 readme with all the ao optimizations we have introduced?

yanbing-j pushed a commit to yanbing-j/ao that referenced this pull request Dec 9, 2024
* Update multimodal.md

Complete markup for testing

* Update run-docs

Add ability to run on docs/multimodal.md

* Update run-readme-pr.yml
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. topic: not user facing Use this tag if you don't want this PR to show up in release notes
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants