-
Notifications
You must be signed in to change notification settings - Fork 683
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add torchao quant (int4/int8/fp8) to llama models #1341
Conversation
Summary: We want to hack before we work on a proper solution proper solution will be rewrite llama model with tensor parallelism: https://pytorch.org/docs/stable/distributed.tensor.parallel.html (using DTensor underneath), trying to do it here: pytorch/ao#785 Test Plan: change `ENABLE_TORCHAO` to True/False in `python/sglang/srt/models/llama.py` to test the baseline v.s. torchao int4 weight only quant performance python3 -m sglang.bench_latency --model meta-llama/Meta-Llama-3-8B --batch-size 1 --input 128 --output 8 ``` max_total_num_tokens=432196 Warmup ... Prefill. latency: 0.03214 s, throughput: 3983.19 token/s Decode. latency: 0.01383 s, throughput: 72.31 token/s Decode. latency: 0.01354 s, throughput: 73.88 token/s Decode. latency: 0.01338 s, throughput: 74.75 token/s Decode. latency: 0.01330 s, throughput: 75.17 token/s Decode. median latency: 0.01346 s, median throughput: 74.31 token/s Total. latency: 0.086 s, throughput: 1531.66 token/s Benchmark ... Prefill. latency: 0.02514 s, throughput: 5092.40 token/s Decode. latency: 0.01337 s, throughput: 74.80 token/s Decode. latency: 0.01338 s, throughput: 74.74 token/s Decode. latency: 0.01339 s, throughput: 74.68 token/s Decode. latency: 0.01321 s, throughput: 75.68 token/s Decode. latency: 0.01295 s, throughput: 77.23 token/s Decode. median latency: 0.01337 s, median throughput: 74.77 token/s Total. latency: 0.132 s, throughput: 1032.13 token/s max_total_num_tokens=505188 Warmup ... Prefill. latency: 0.10929 s, throughput: 1171.18 token/s Decode. latency: 0.00790 s, throughput: 126.57 token/s Decode. latency: 0.00738 s, throughput: 135.54 token/s Decode. latency: 0.00724 s, throughput: 138.16 token/s Decode. latency: 0.00726 s, throughput: 137.71 token/s Decode. median latency: 0.00732 s, median throughput: 136.62 token/s Total. latency: 0.139 s, throughput: 949.17 token/s Benchmark ... Prefill. latency: 0.10405 s, throughput: 1230.13 token/s Decode. latency: 0.00769 s, throughput: 129.96 token/s Decode. latency: 0.00725 s, throughput: 137.85 token/s Decode. latency: 0.00724 s, throughput: 138.11 token/s Decode. latency: 0.00731 s, throughput: 136.72 token/s Decode. latency: 0.00744 s, throughput: 134.47 token/s Decode. median latency: 0.00730 s, median throughput: 136.97 token/s Total. latency: 0.163 s, throughput: 834.99 token/s Warmup ... Prefill. latency: 0.05868 s, throughput: 2181.51 token/s Decode. latency: 0.04475 s, throughput: 22.35 token/s Decode. latency: 0.04463 s, throughput: 22.41 token/s Decode. latency: 0.04467 s, throughput: 22.39 token/s Decode. latency: 0.04478 s, throughput: 22.33 token/s Decode. median latency: 0.04471 s, median throughput: 22.37 token/s Total. latency: 0.238 s, throughput: 555.78 token/s Benchmark ... Prefill. latency: 0.05274 s, throughput: 2427.22 token/s Decode. latency: 0.04463 s, throughput: 22.41 token/s Decode. latency: 0.04456 s, throughput: 22.44 token/s Decode. latency: 0.04453 s, throughput: 22.45 token/s Decode. latency: 0.04469 s, throughput: 22.38 token/s Decode. latency: 0.04457 s, throughput: 22.44 token/s Decode. median latency: 0.04457 s, median throughput: 22.44 token/s Total. latency: 0.409 s, throughput: 332.13 token/s ``` Reviewers: Subscribers: Tasks: Tags:
Hi @msaroufim @jerryzh168 Nice work! It looks like the CI failures are due to a missing |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Given the changes are very minimal. I feel that we can even merge this and start iteration. Here are a few thoughts on making it more clean.
- Add a new argument
--enable-torchao
similar to--enable-torch-compile
heresglang/python/sglang/srt/server_args.py
Line 432 in ab4a83b
"--enable-torch-compile", - Alternatively, instead of adding
--enable-torchao
, we can add a new quantization formattorchao-int4
heresglang/python/sglang/srt/server_args.py
Line 224 in ab4a83b
"--quantization", - You can pass the above arguments to this global variable
sglang/python/sglang/srt/model_executor/model_runner.py
Lines 93 to 100 in ab4a83b
global_server_args_dict.update( { "disable_flashinfer": server_args.disable_flashinfer, "disable_flashinfer_sampling": server_args.disable_flashinfer_sampling, "triton_attention_reduce_in_fp32": server_args.triton_attention_reduce_in_fp32, "enable_mla": server_args.enable_mla, } ) llama.py
@merrymercy thanks for the suggestions, I can start with https://github.com/pytorch/ao/blob/1ce7da941b311ef6ab416d9c7a22be20e65b7495/torchao/_models/llama/generate.py#L462 I think ideally we'd want to be able to specify different types of quantization like I could be using sglang/python/sglang/srt/models/llama.py Line 248 in ab4a83b
|
Feel free to copy over the code as we will gradually remove the dependency of vLLM (likely in one month or so) |
@jerryzh168 Thanks for the contribution. It is merged. I pushed some updates:
|
I verified the performance on H100. Here are the results with a few comments:
I also got some errors with the current torchao release on Pypi when trying the following settings. Which version should I use?
|
Thanks @merrymercy for the fixes, tests and merging the PR!
not yet I think, but we can take a look to understand why
yeah int8 requires torch.compile to speedup, using fp8 sounds good to us as well.
yeah currently we have torchao 0.4 in pypi, we are doing a new 0.5 release this week so fp8 should be supported this week
I've seen this before, this is because of some issue in pytorch dynamo/inductor code I think, and it should be fixed in pytorch nightly, maybe you can try installing pytorch nightly first to verify that the issue can be fix first. also do you have a version requirement for pytorch in sglang? |
@jerryzh168 We use torch==2.4.0 |
I see, for 2.4.0 and below, we'd need to call https://github.com/pytorch/ao/blob/e1039abac7f429a8d7f489d047d9b34d6ac6afe2/torchao/utils.py#L269 for the model to be compilable I think, I'll try to verify as well when I got a chance today |
for 2.4.0 there seems to be some issue even with work around, trying to debug now |
for the second issue "torch.compile + torchao is not supported" I can only repro in 2.4.0 and not sure how to fix, also the fix could just be needed in pytorch itself I think and we can't really back port the fix to 2.4.0. I'm wondering if it's OK for you to use torch nightly for now for testing? pytorch 2.5.0 is going to be released in one month. |
Summary: Similar to sgl-project#1341 we add torchao quantization to mixtral model Test Plan: Note: compile is not working yet, and I can't install torchnightly locally and make it work either. I'll wait for pytorch 2.5 release which happens in mid Oct, or check that again later python3 -m sglang.bench_latency --model Qwen/Qwen1.5-MoE-A2.7B --batch-size 1 --input 128 --output 8 Warmup ... Prefill. latency: 0.05532 s, throughput: 2313.73 token/s Decode. latency: 0.00896 s, throughput: 111.65 token/s Decode. latency: 0.00833 s, throughput: 120.04 token/s Decode. latency: 0.00869 s, throughput: 115.06 token/s Decode. latency: 0.00842 s, throughput: 118.79 token/s Decode. median latency: 0.00855 s, median throughput: 116.89 token/s Total. latency: 0.090 s, throughput: 1471.26 token/s Benchmark ... Prefill. latency: 0.04294 s, throughput: 2980.61 token/s Decode. latency: 0.00839 s, throughput: 119.12 token/s Decode. latency: 0.00828 s, throughput: 120.78 token/s Decode. latency: 0.00857 s, throughput: 116.64 token/s Decode. latency: 0.00853 s, throughput: 117.19 token/s Decode. latency: 0.00859 s, throughput: 116.39 token/s Decode. median latency: 0.00853 s, median throughput: 117.17 token/s Total. latency: 0.111 s, throughput: 1226.84 token/s python3 -m sglang.bench_latency --model Qwen/Qwen1.5-MoE-A2.7B --batch-size 1 --input 128 --output 8 --torchao-config int4wo-128 Warmup ... Prefill. latency: 0.06413 s, throughput: 1996.05 token/s Decode. latency: 0.00764 s, throughput: 130.84 token/s Decode. latency: 0.00748 s, throughput: 133.73 token/s Decode. latency: 0.00725 s, throughput: 137.84 token/s Decode. latency: 0.00721 s, throughput: 138.74 token/s Decode. median latency: 0.00737 s, median throughput: 135.76 token/s Total. latency: 0.094 s, throughput: 1408.61 token/s Benchmark ... Prefill. latency: 0.05239 s, throughput: 2443.43 token/s Decode. latency: 0.00739 s, throughput: 135.25 token/s Decode. latency: 0.00720 s, throughput: 138.90 token/s Decode. latency: 0.00718 s, throughput: 139.21 token/s Decode. latency: 0.00722 s, throughput: 138.42 token/s Decode. latency: 0.00745 s, throughput: 134.30 token/s Decode. median latency: 0.00731 s, median throughput: 136.82 token/s Total. latency: 0.111 s, throughput: 1223.51 token/s A100, no compile python3 -m sglang.bench_latency --model Qwen/Qwen1.5-MoE-A2.7B --batch-size 1 --input 128 --output 8 --torchao-config fp8wo max_total_num_tokens=199454 Warmup ... Prefill. latency: 0.06958 s, throughput: 1839.60 token/s Decode. latency: 0.02343 s, throughput: 42.68 token/s Decode. latency: 0.02342 s, throughput: 42.70 token/s Decode. latency: 0.02368 s, throughput: 42.23 token/s Decode. latency: 0.02337 s, throughput: 42.80 token/s Decode. median latency: 0.02342 s, median throughput: 42.69 token/s Total. latency: 0.163 s, throughput: 807.48 token/s Benchmark ... Prefill. latency: 0.05767 s, throughput: 2219.36 token/s Decode. latency: 0.02293 s, throughput: 43.61 token/s Decode. latency: 0.02026 s, throughput: 49.36 token/s Decode. latency: 0.02029 s, throughput: 49.29 token/s Decode. latency: 0.02024 s, throughput: 49.41 token/s Decode. latency: 0.02026 s, throughput: 49.36 token/s Decode. median latency: 0.02025 s, median throughput: 49.39 token/s Total. latency: 0.222 s, throughput: 611.87 token/s Reviewers: Subscribers: Tasks: Tags:
@jerryzh168 I think we can give it a try, I will verify asap. |
interesting |
@jerryzh168 The main branch currently supports torch 2.5.1, and the installation method is as follows:
If you have any questions, feel free to contact me anytime. Thanks! |
Summary:
We want to hack before we work on a proper solution
proper solution will be rewrite llama model with tensor parallelism: https://pytorch.org/docs/stable/distributed.tensor.parallel.html
(using DTensor underneath), trying to do it here: pytorch/ao#785
Test Plan:
with compile
python3 -m sglang.bench_latency --model meta-llama/Meta-Llama-3-8B --batch-size 1 --input 128 --output 8 --enable-torch-compile
Reviewers:
Subscribers:
Tasks:
Tags: