Merge branch 'main' into 8da4w
jerryzh168 authored May 6, 2024 · 2 parents 7f1662a + ce78e79 · commit ff1212d
Showing 11 changed files with 705 additions and 179 deletions.
6 changes: 3 additions & 3 deletions .github/workflows/regression_test.yml
@@ -31,9 +31,9 @@ jobs:
torch-spec: 'torch==2.3.0'
gpu-arch-type: "cuda"
gpu-arch-version: "12.1"
- name: CUDA 2.4.0.dev20240428
- name: CUDA Nightly
runs-on: linux.g5.12xlarge.nvidia.gpu
torch-spec: '--pre torch==2.4.0.dev20240428+cu121 --index-url https://download.pytorch.org/whl/nightly/cu121'
torch-spec: '--pre torch --index-url https://download.pytorch.org/whl/nightly/cu121'
gpu-arch-type: "cuda"
gpu-arch-version: "12.1"
- name: CPU 2.2.2
@@ -46,7 +46,7 @@ jobs:
torch-spec: 'torch==2.3.0 --index-url https://download.pytorch.org/whl/cpu'
gpu-arch-type: "cpu"
gpu-arch-version: ""
- name: Nightly CPU
- name: CPU Nightly
runs-on: linux.4xlarge
torch-spec: '--pre torch --index-url https://download.pytorch.org/whl/nightly/cpu'
gpu-arch-type: "cpu"
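The CUDA and CPU nightly jobs now track the latest nightly wheel rather than a pinned dev build. As a side note (not part of the workflow itself), a minimal sketch of a post-install sanity check a runner could use to confirm which build it actually resolved:

```python
# Hypothetical post-install check; assumes it runs after the
# `--pre torch --index-url https://download.pytorch.org/whl/nightly/cu121` install step.
import torch

print(torch.__version__)          # nightly dev tag, e.g. "2.4.0.dev20240428+cu121"
print(torch.version.cuda)         # "12.1" for the cu121 wheel, None for CPU-only wheels
print(torch.cuda.is_available())  # True on the GPU runners, False on the CPU jobs
```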
6 changes: 3 additions & 3 deletions README.md
@@ -123,9 +123,9 @@ torchao has been integrated with other libraries including
## Success stories
Our kernels have been used to achieve SOTA inference performance on

* Image segmentation models with [sam-fast](pytorch.org/blog/accelerating-generative-ai)
* Language models with [gpt-fast](pytorch.org/blog/accelerating-generative-ai-2)
* Diffusion models with [sd-fast](pytorch.org/blog/accelerating-generative-ai-3)
* Image segmentation models with [sam-fast](https://pytorch.org/blog/accelerating-generative-ai)
* Language models with [gpt-fast](https://pytorch.org/blog/accelerating-generative-ai-2)
* Diffusion models with [sd-fast](https://pytorch.org/blog/accelerating-generative-ai-3)

## License

85 changes: 56 additions & 29 deletions benchmarks/benchmark_hqq.py
@@ -1,22 +1,21 @@

try:
import triton
import hqq
import triton

if int(triton.__version__.split(".")[0]) < 3:
raise "triton >= 3.0.0 is required to run this test"
except ImportError:
raise "triton and hqq required to run this benchmark"

import torch
from io import StringIO

import pandas as pd
from hqq.core.quantize import HQQLinear, BaseQuantizeConfig
from torchao.prototype.hqq.hqq_tinygemm_linear import HQQLinearTorchWeightOnlyInt4
from torchao.prototype.hqq import triton_mixed_mm, pack_2xint4

import torch
from hqq.core.quantize import BaseQuantizeConfig, HQQLinear
from triton.testing import do_bench

from torchao.prototype.hqq import pack_2xint4, triton_mixed_mm
from torchao.prototype.hqq.hqq_tinygemm_linear import HQQLinearTorchWeightOnlyInt4

BASE_QUANT_CONFIG = {
"optimize": True,
@@ -27,7 +26,16 @@
}


def bench_custom_kernel(x, W_q, scales, zeros, group_size, kernel_type="max_autotune", fp8_fast_accum=False):
def bench_custom_kernel(
x,
W_q,
scales,
zeros,
group_size,
transposed=False,
kernel_type="max_autotune",
fp8_fast_accum=False,
):
packed_w = pack_2xint4(W_q.T)

def fn():
@@ -36,6 +44,7 @@ def fn():
packed_w,
scales.T,
zeros.T,
transposed=transposed,
group_size=group_size,
fp8_fast_accum=fp8_fast_accum,
kernel_type=kernel_type,
@@ -45,22 +54,30 @@ def fn():
return t


def bench_hqq(x, hqq_linear: HQQLinear):
def fn():
_ = hqq_linear.forward(x)
def bench_hqq(x, hqq_linear: HQQLinear | HQQLinearTorchWeightOnlyInt4, transposed=False, tinygemm=False):
def reference_fn():
W_dq = hqq_linear.dequantize()
_ = x @ W_dq.T if not transposed else x @ W_dq
fn = reference_fn if not tinygemm else lambda: hqq_linear(x)

t = do_bench(fn)
return t


def run_benchmark(shape, group_size, dtype, axis=1, quant_dtype=torch.uint8):
def run_benchmark(
shape, group_size, dtype, axis=1, transposed=False, quant_dtype=torch.uint8
):
qcfg = {
**BASE_QUANT_CONFIG,
**dict(group_size=group_size, axis=axis),
}
M, N, K = shape

x = torch.randn(M, K, dtype=dtype, device="cuda")
x = (
torch.randn(M, K, dtype=dtype, device="cuda")
if not transposed
else torch.randn(M, N, dtype=dtype, device="cuda")
)
linear = torch.nn.Linear(K, N, bias=False, dtype=dtype, device="cuda")

quant_config = BaseQuantizeConfig(
@@ -71,7 +88,7 @@ def run_benchmark(shape, group_size, dtype, axis=1, quant_dtype=torch.uint8):
hqq_linear = HQQLinear(linear, quant_config, compute_dtype=dtype, del_orig=False)

# Reference
ref_time = bench_hqq(x, hqq_linear)
ref_time = bench_hqq(x, hqq_linear, transposed=transposed)

# Custom kernel
W_q, meta = hqq_linear.W_q, hqq_linear.meta
@@ -85,26 +102,31 @@ def run_benchmark(shape, group_size, dtype, axis=1, quant_dtype=torch.uint8):
W_q = W_q.to(dtype=quant_dtype)
scales = scales.reshape(N, -1)
zeros = zeros.reshape(N, -1)
tt_time = bench_custom_kernel(x, W_q, scales, zeros, group_size)
tt_time = bench_custom_kernel(
x, W_q, scales, zeros, group_size, transposed=transposed
)

if dtype == torch.bfloat16:
should_run_tinygemm = dtype == torch.bfloat16 and not transposed
if should_run_tinygemm:
_ = quant_config["weight_quant_params"].pop("bitpack")
hqq_int4mm = HQQLinearTorchWeightOnlyInt4(
linear, quant_config, compute_dtype=dtype, del_orig=False
)
int4_time = bench_hqq(x, hqq_int4mm)
int4_time = bench_hqq(x, hqq_int4mm, transposed=transposed, tinygemm=True)

print(f"{shape=} {group_size=} {dtype=}:")
print(f"{shape=}, {group_size=}, {dtype=}, {transposed=}:")

print(
f"Ref: {ref_time:.4f}",
f"Triton: {tt_time:.4f}",
f"Torch int4mm: {int4_time:.4f}"
if dtype == torch.bfloat16
else "",
f"Ref: {ref_time:.4f}ms",
f"Triton: {tt_time:.4f}ms",
f"Torch int4mm: {int4_time:.4f}ms" if should_run_tinygemm else "",
)
print()
return ref_time, tt_time, int4_time if dtype == torch.bfloat16 else None
return (
ref_time,
tt_time,
int4_time if should_run_tinygemm else -1,
)


SHAPES = [
@@ -116,16 +138,17 @@ def run_benchmark(shape, group_size, dtype, axis=1, quant_dtype=torch.uint8):
[1024, 4096, 4096],
]

DTYPES = [torch.bfloat16] # , torch.float16]
DTYPES = [torch.bfloat16] #[torch.float16, torch.bfloat16]
GROUP_SIZES = [128]

TRANSPOSED = [True] #[False, True]

HEADERS = [
"M",
"N",
"K",
"group_size",
"dtype",
"transposed",
"ref",
"triton",
"tinygemm",
@@ -138,10 +161,14 @@ def run_benchmark(shape, group_size, dtype, axis=1, quant_dtype=torch.uint8):
for shape in SHAPES:
for group_size in GROUP_SIZES:
for dtype in DTYPES:
timings = run_benchmark(shape, group_size, dtype)
data.append((*shape, group_size, dtype, *timings))
for transposed in TRANSPOSED:
timings = run_benchmark(
shape, group_size, dtype, transposed=transposed
)
data.append((*shape, group_size, dtype, transposed, *timings))

output = StringIO()
df = pd.DataFrame(data, columns=HEADERS)
df.to_csv(output, index=False)
print(output.getvalue())
print(output.getvalue())
# df.to_csv("benchmark_hqq_tinygemm.csv", index=False)