Renaming fpx to floatx #877

Merged (2 commits) on Sep 12, 2024
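
In short, the PR renames the `fpx` dtype machinery to `floatx` across module paths, classes, and helper functions. A condensed before/after view of the import surface, assembled from the hunks below as an illustrative sketch (it is not itself a file in the PR):

```python
# Before this PR (old names, removed):
# from torchao.dtypes import to_affine_quantized_fpx
# from torchao.dtypes.fpx import FpxTensorCoreAQTLayout, FpxTensorCoreLayoutType

# After this PR (new names, added):
from torchao.dtypes import to_affine_quantized_floatx
from torchao.dtypes.floatx import (
    FloatxTensorCoreAQTLayout,
    FloatxTensorCoreLayoutType,
    to_scaled_tc_floatx,
    from_scaled_tc_floatx,
)
```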
README.md (2 changes: 1 addition & 1 deletion)
@@ -128,7 +128,7 @@ The best example we have combining the composability of lower bit dtype with com

We've added support for authoring and releasing [custom ops](./torchao/csrc/) that do not graph break with `torch.compile()`, so if you love writing kernels but hate packaging them to work on all operating systems and CUDA versions, we'd love to accept contributions for your custom ops. We have a few examples you can follow:

-1. [fp6](torchao/prototype/quant_llm/) for 2x faster inference over fp16 with an easy to use API `quantize_(model, fp6_llm_weight_only())`
+1. [fp6](torchao/dtypes/floatx) for 2x faster inference over fp16 with an easy to use API `quantize_(model, fpx_weight_only(3, 2))`
2. [2:4 Sparse Marlin GEMM](https://github.com/pytorch/ao/pull/733) 2x speedups for FP16xINT4 kernels even at batch sizes up to 256
3. [int4 tinygemm unpacker](https://github.com/pytorch/ao/pull/415) which makes it easier to switch quantized backends for inference
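
For context on item 1, a minimal usage sketch of the fp6 API named in the updated README line; the toy model and device setup are assumptions for illustration, not part of this PR:

```python
# Hypothetical toy model; only the quantize_(model, fpx_weight_only(3, 2)) call
# comes from the README line above. ebits=3, mbits=2 -> fp6 (1 sign + 3 exponent + 2 mantissa bits).
import torch
from torchao.quantization import quantize_, fpx_weight_only

model = torch.nn.Sequential(torch.nn.Linear(64, 256)).half().cuda()
quantize_(model, fpx_weight_only(3, 2))  # swaps Linear weights to the fp6 tensor-core layout in place
```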

benchmarks/benchmark_fp6.py (6 changes: 3 additions & 3 deletions)
@@ -1,15 +1,15 @@
import torch
import pandas as pd
import torch.nn.functional as F
-from torchao.dtypes import to_affine_quantized_fpx
-from torchao.dtypes.fpx import FpxTensorCoreAQTLayout, FpxTensorCoreLayoutType
+from torchao.dtypes import to_affine_quantized_floatx
+from torchao.dtypes.floatx import FloatxTensorCoreAQTLayout, FloatxTensorCoreLayoutType
from torchao.utils import benchmark_torch_function_in_microseconds
from tqdm import tqdm


def benchmark(m: int, k: int, n: int):
float_data = torch.randn(n, k, dtype=torch.half, device="cuda")
-fp6_weight = to_affine_quantized_fpx(float_data, FpxTensorCoreLayoutType(3, 2))
+fp6_weight = to_affine_quantized_floatx(float_data, FloatxTensorCoreLayoutType(3, 2))
fp16_weight = fp6_weight.dequantize(torch.half)

fp16_act = torch.randn(m, k, dtype=torch.half, device="cuda")
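
As a reference point for the renamed calls in this benchmark, a short sketch of the fp6 quantize/dequantize round trip; the weight shape is an assumed example:

```python
# Mirrors the benchmark above: pack an fp16 weight to fp6 (ebits=3, mbits=2), then dequantize.
import torch
from torchao.dtypes import to_affine_quantized_floatx
from torchao.dtypes.floatx import FloatxTensorCoreLayoutType

w = torch.randn(1024, 512, dtype=torch.half, device="cuda")  # assumed shape
w_fp6 = to_affine_quantized_floatx(w, FloatxTensorCoreLayoutType(3, 2))
w_back = w_fp6.dequantize(torch.half)  # approximate fp16 reconstruction of w
```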
test/dtypes/test_fpx.py → test/dtypes/test_floatx.py (72 changes: 36 additions & 36 deletions)
@@ -8,14 +8,14 @@
parametrize,
run_tests,
)
-from torchao.dtypes.fpx import (
-FpxTensorCoreAQTLayout,
-FpxTensorCoreLayoutType,
-to_scaled_tc_fpx,
-from_scaled_tc_fpx,
+from torchao.dtypes.floatx import (
+FloatxTensorCoreAQTLayout,
+FloatxTensorCoreLayoutType,
+to_scaled_tc_floatx,
+from_scaled_tc_floatx,
)
-from torchao.dtypes.fpx.fpx import _pack_tc_fpx, _pack_tc_fp6
-from torchao.prototype.custom_fp_utils import _f32_to_fpx_unpacked, _fpx_unpacked_to_f32
+from torchao.dtypes.floatx.floatx import _pack_tc_floatx, _pack_tc_fp6
+from torchao.prototype.custom_fp_utils import _f32_to_floatx_unpacked, _floatx_unpacked_to_f32
from torchao.quantization import (
quantize_,
fpx_weight_only,
@@ -25,71 +25,71 @@


_DEVICES = ["cpu"] + (["cuda"] if torch.cuda.is_available() else [])
-_FPx_DTYPES = [(3, 2), (2, 2)]
+_Floatx_DTYPES = [(3, 2), (2, 2)]


-class TestFpxTensorCoreAQTLayout(TestCase):
+class TestFloatxTensorCoreAQTLayout(TestCase):
@parametrize("device", _DEVICES)
def test_pack_tc_fp6_correctness(self, device):
x = torch.randint(256, size=(256, 64), dtype=torch.uint8, device=device)

-expected = _pack_tc_fpx(x, 6)
+expected = _pack_tc_floatx(x, 6)
actual = _pack_tc_fp6(x)
torch.testing.assert_close(actual, expected)

@parametrize("ebits,mbits", _FPx_DTYPES)
@parametrize("ebits,mbits", _Floatx_DTYPES)
@parametrize("device", _DEVICES)
-def test_to_scaled_tc_fpx_compile(self, ebits, mbits, device):
+def test_to_scaled_tc_floatx_compile(self, ebits, mbits, device):
x = torch.randn(256, 64, device=device)

-expected = to_scaled_tc_fpx(x, ebits, mbits)
-actual = torch.compile(to_scaled_tc_fpx, fullgraph=True)(x, ebits, mbits)
+expected = to_scaled_tc_floatx(x, ebits, mbits)
+actual = torch.compile(to_scaled_tc_floatx, fullgraph=True)(x, ebits, mbits)
torch.testing.assert_close(actual, expected)

@parametrize("ebits,mbits", _FPx_DTYPES)
@parametrize("ebits,mbits", _Floatx_DTYPES)
@parametrize("device", _DEVICES)
-def test_from_tc_fpx_correctness(self, ebits, mbits, device):
+def test_from_tc_floatx_correctness(self, ebits, mbits, device):
x = torch.randn(256, 64, device=device) * 100

-# quantize and dequantize so that the values are exactly representable in FPx
-x = _fpx_unpacked_to_f32(_f32_to_fpx_unpacked(x, ebits, mbits), ebits, mbits)
+# quantize and dequantize so that the values are exactly representable in Floatx
+x = _floatx_unpacked_to_f32(_f32_to_floatx_unpacked(x, ebits, mbits), ebits, mbits)

-tc_fpx, scale = to_scaled_tc_fpx(x, ebits, mbits)
-actual = from_scaled_tc_fpx(tc_fpx, ebits, mbits, scale=scale)
+tc_floatx, scale = to_scaled_tc_floatx(x, ebits, mbits)
+actual = from_scaled_tc_floatx(tc_floatx, ebits, mbits, scale=scale)
torch.testing.assert_close(actual, x)

@parametrize("ebits,mbits", _FPx_DTYPES)
@parametrize("ebits,mbits", _Floatx_DTYPES)
@parametrize("device", _DEVICES)
-def test_from_scaled_tc_fpx_compile(self, ebits, mbits, device):
+def test_from_scaled_tc_floatx_compile(self, ebits, mbits, device):
M, N = 256, 64
nbits = 1 + ebits + mbits
x = torch.randint(256, size=(M, N // 8 * nbits), dtype=torch.uint8, device=device)
scale = torch.randn(M, device=device)

-expected = from_scaled_tc_fpx(x, ebits, mbits, scale)
-actual = torch.compile(from_scaled_tc_fpx, fullgraph=True)(x, ebits, mbits, scale)
+expected = from_scaled_tc_floatx(x, ebits, mbits, scale)
+actual = torch.compile(from_scaled_tc_floatx, fullgraph=True)(x, ebits, mbits, scale)
torch.testing.assert_close(actual, expected)

@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@parametrize("ebits,mbits", _FPx_DTYPES)
@parametrize("ebits,mbits", _Floatx_DTYPES)
def test_to_copy_device(self, ebits, mbits):
from torchao.quantization.quant_primitives import (
-choose_qparams_affine_fpx,
-quantize_affine_fpx,
+choose_qparams_affine_floatx,
+quantize_affine_floatx,
)

x = torch.randn(256, 64)
-scale = choose_qparams_affine_fpx(x, ebits, mbits)
-x = quantize_affine_fpx(x, scale, ebits, mbits)
-layout_type = FpxTensorCoreLayoutType(ebits, mbits)
-fpx_layout_tensor = FpxTensorCoreAQTLayout.from_plain(x, scale, None, layout_type).cuda()
-assert fpx_layout_tensor.device.type == "cuda"
-fpx_layout_tensor = fpx_layout_tensor.cpu()
-assert fpx_layout_tensor.device.type == "cpu"
+scale = choose_qparams_affine_floatx(x, ebits, mbits)
+x = quantize_affine_floatx(x, scale, ebits, mbits)
+layout_type = FloatxTensorCoreLayoutType(ebits, mbits)
+floatx_layout_tensor = FloatxTensorCoreAQTLayout.from_plain(x, scale, None, layout_type).cuda()
+assert floatx_layout_tensor.device.type == "cuda"
+floatx_layout_tensor = floatx_layout_tensor.cpu()
+assert floatx_layout_tensor.device.type == "cpu"

@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@pytest.mark.skipif(not TORCH_VERSION_AT_LEAST_2_5, reason="quantization only works with torch.compile for 2.5+")
@parametrize("ebits,mbits", _FPx_DTYPES)
@parametrize("ebits,mbits", _Floatx_DTYPES)
@parametrize("bias", [False, True])
def test_fpx_weight_only(self, ebits, mbits, bias):
N, OC, IC = 4, 256, 64
@@ -106,7 +106,7 @@ def test_fpx_weight_only(self, ebits, mbits, bias):
torch.testing.assert_close(actual, expected)


-instantiate_parametrized_tests(TestFpxTensorCoreAQTLayout)
+instantiate_parametrized_tests(TestFloatxTensorCoreAQTLayout)


if __name__ == "__main__":
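A standalone sketch of the scaled tensor-core packing round trip exercised by `test_from_tc_floatx_correctness` above, using the renamed helpers (tensor sizes follow the test; the script itself is illustrative, not part of the PR):

```python
import torch
from torchao.dtypes.floatx import to_scaled_tc_floatx, from_scaled_tc_floatx
from torchao.prototype.custom_fp_utils import _f32_to_floatx_unpacked, _floatx_unpacked_to_f32

ebits, mbits = 3, 2  # fp6
x = torch.randn(256, 64) * 100
# snap x to exactly representable floatx values so the round trip is lossless
x = _floatx_unpacked_to_f32(_f32_to_floatx_unpacked(x, ebits, mbits), ebits, mbits)
packed, scale = to_scaled_tc_floatx(x, ebits, mbits)
restored = from_scaled_tc_floatx(packed, ebits, mbits, scale=scale)
torch.testing.assert_close(restored, x)
```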
test/dtypes/test_uintx.py (2 changes: 1 addition & 1 deletion)
@@ -4,7 +4,7 @@

import torch

-from torchao.dtypes.uintx.Uintx import to_uintx
+from torchao.dtypes.uintx.uintx import to_uintx
from torchao.quantization.quant_api import quantize_, uintx_weight_only
from torchao.utils import (
TORCH_VERSION_AT_LEAST_2_3,
test/test_ops.py (26 changes: 13 additions & 13 deletions)
@@ -11,7 +11,7 @@
)
from torch.testing._internal.optests import opcheck
from torchao.utils import is_fbcode, TORCH_VERSION_AT_LEAST_2_5, compute_max_diff
-from torchao.dtypes.fpx import from_scaled_tc_fpx
+from torchao.dtypes.floatx import from_scaled_tc_floatx
from torchao.sparsity.marlin import marlin_24_workspace, pack_to_marlin_24, inject_24
import pytest

@@ -33,13 +33,13 @@


class TestOps(TestCase):
-def _create_fpx_inputs(self, ebits: int, mbits: int, BS: int, OC: int, IC: int, device):
+def _create_floatx_inputs(self, ebits: int, mbits: int, BS: int, OC: int, IC: int, device):
# Randomly initialize each byte
nbits = 1 + ebits + mbits
-fpx_weight = torch.randint(256, (OC, IC // 8 * nbits), dtype=torch.uint8)
+floatx_weight = torch.randint(256, (OC, IC // 8 * nbits), dtype=torch.uint8)
scale = torch.rand(OC).half() + 0.5
fp16_act = torch.rand(BS, IC).half() + 0.5
-return fpx_weight.to(device), scale.to(device), fp16_act.to(device)
+return floatx_weight.to(device), scale.to(device), fp16_act.to(device)

@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@parametrize("ebits,mbits", [(3, 2), (2, 2)])
@@ -48,28 +48,28 @@ def test_quant_llm_linear(self, ebits, mbits):
OC = 256
IC = 256
splitK = 1
-fpx_weight, scale, fp16_act = self._create_fpx_inputs(ebits, mbits, BS, OC, IC, "cuda")
+floatx_weight, scale, fp16_act = self._create_floatx_inputs(ebits, mbits, BS, OC, IC, "cuda")

# smoke test
-torchao.ops.quant_llm_linear(ebits, mbits, fp16_act, fpx_weight, scale, splitK)
+torchao.ops.quant_llm_linear(ebits, mbits, fp16_act, floatx_weight, scale, splitK)

# comprehensive testing
test_utils = ["test_schema", "test_autograd_registration", "test_faketensor", "test_aot_dispatch_dynamic"]
-opcheck(torch.ops.torchao.quant_llm_linear, (ebits, mbits, fp16_act, fpx_weight, scale, splitK), test_utils=test_utils)
+opcheck(torch.ops.torchao.quant_llm_linear, (ebits, mbits, fp16_act, floatx_weight, scale, splitK), test_utils=test_utils)

@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@parametrize("BS,OC,IC,splitK", [(1, 2048, 4096, 5), (2, 8192, 8192, 6)])
@parametrize("ebits,mbits", [(3, 2), (2, 2)])
def test_quant_llm_linear_correctness(self, ebits, mbits, BS, OC, IC, splitK):
# adapted from https://github.com/usyd-fsalab/fp6_llm/blob/5df6737cca32f604e957e3f63f03ccc2e4d1df0d/tests/python/kernel_test_fpx.py
-fpx_weight, scale, fp16_act = self._create_fpx_inputs(ebits, mbits, BS, OC, IC, "cuda")
+floatx_weight, scale, fp16_act = self._create_floatx_inputs(ebits, mbits, BS, OC, IC, "cuda")

-results_fpx = torchao.ops.quant_llm_linear(ebits, mbits, fp16_act, fpx_weight, scale, splitK)
+results_floatx = torchao.ops.quant_llm_linear(ebits, mbits, fp16_act, floatx_weight, scale, splitK)

-fp16_weight = from_scaled_tc_fpx(fpx_weight, ebits, mbits, scale).half()
+fp16_weight = from_scaled_tc_floatx(floatx_weight, ebits, mbits, scale).half()
results_fp16 = fp16_act @ fp16_weight.T

-error = (results_fpx - results_fp16).abs().mean()
+error = (results_floatx - results_fp16).abs().mean()
gt = results_fp16.abs().mean()
relative_error = error / gt
assert relative_error < 1e-3
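
The correctness test above also doubles as a usage recipe for the renamed kernel entry points; a trimmed sketch (sizes and splitK are assumed values, CUDA required):

```python
import torch
import torchao
from torchao.dtypes.floatx import from_scaled_tc_floatx

ebits, mbits = 3, 2                  # fp6
BS, OC, IC, splitK = 2, 256, 256, 1  # assumed sizes
nbits = 1 + ebits + mbits
floatx_weight = torch.randint(256, (OC, IC // 8 * nbits), dtype=torch.uint8, device="cuda")
scale = torch.rand(OC, device="cuda").half() + 0.5
fp16_act = torch.rand(BS, IC, device="cuda").half() + 0.5

out_floatx = torchao.ops.quant_llm_linear(ebits, mbits, fp16_act, floatx_weight, scale, splitK)
out_fp16 = fp16_act @ from_scaled_tc_floatx(floatx_weight, ebits, mbits, scale).half().T
rel_err = (out_floatx - out_fp16).abs().mean() / out_fp16.abs().mean()  # expected < 1e-3 per the test
```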
@@ -319,7 +319,7 @@ def test_dequantize_tensor_core_tiled_layout_op(shape, inner_k_tiles, group_size
MARLIN_24_SUPPORTED_GROUP_SIZES = [-1, 128]

MARLIN_TEST_PARAMS = list(itertools.product(
-MARLIN_24_BATCH_SIZE, MARLIN_24_K_CHUNKS, MARLIN_24_N_CHUNKS,
+MARLIN_24_BATCH_SIZE, MARLIN_24_K_CHUNKS, MARLIN_24_N_CHUNKS,
MARLIN_24_SUPPORTED_NUM_BITS, MARLIN_24_SUPPORTED_GROUP_SIZES, MNK_FACTORS
))

@@ -405,7 +405,7 @@ def test_marlin_24(batch_size, k_chunk, n_chunk, num_bits, group_size, mnk_facto
workspace_24 = marlin_24_workspace(size_n)

fn_inputs = (
-input_2d, marlin_24_q_w_comp, meta, marlin_24_scale, workspace_24,
+input_2d, marlin_24_q_w_comp, meta, marlin_24_scale, workspace_24,
num_bits, a_input_in, marlin_24_scale.shape[1], a_input_out,
)
output = torchao.ops.marlin_24_gemm(*fn_inputs)