Create header for packed weight ops (pytorch#1072)
Summary:

This diff defines a packed weights header in torchao/experimental/ops/packed_weights_header.h.

The header is 16 bytes and has 2 fields:
* format: PackedWeightsFormat (a 16-bit enum)
* params: std::array<unsigned short, 7> (14 bytes of format-specific params)

Whenever we have a new format type, we can add a value to the enum. Currently there is a value for the format the universal kernels use, but MPS can have a different format, and KleidiAI also has its own format.

I modified the pack ops to write this header at the start of the packed weights. When the linear op runs, it reads the header to understand how the weights were packed.
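
To make the flow concrete, here is a minimal C++ sketch (not part of this diff) of what a pack op and a linear op do with the new header; the function name sketch, the buffer packed_buffer, and the payload size weight_bytes are hypothetical, while the PackedWeightsHeader and get_packed_weights_header_universal APIs are the ones added below:

#include <torchao/experimental/ops/linear_8bit_act_xbit_weight/packed_weights_header.h>
#include <torchao/experimental/ops/packed_weights_header.h>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

void sketch() {
  // Build the 16-byte header describing how the universal kernel packs weights.
  auto header = torchao::ops::linear_8bit_act_xbit_weight::
      get_packed_weights_header_universal(
          /*weight_nbit=*/4,
          /*has_weight_zeros=*/true,
          /*has_bias=*/false,
          /*nr=*/8,
          /*kr=*/16);

  // Pack op side: reserve 16 bytes up front, then the packed weight payload.
  size_t weight_bytes = 1024;  // hypothetical payload size
  std::vector<int8_t> packed_buffer(
      torchao::ops::PackedWeightsHeader::size() + weight_bytes);
  header.write(packed_buffer.data());
  // ... packed weights go at packed_buffer.data() + PackedWeightsHeader::size() ...

  // Linear op side: read the header back and check it before dispatching.
  auto read_back =
      torchao::ops::PackedWeightsHeader::read(packed_buffer.data());
  assert(read_back == header);
}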

Reviewed By: digantdesai

Differential Revision: D63498956
metascroy authored and facebook-github-bot committed Oct 15, 2024
1 parent afc0a02 commit c121875
Showing 4 changed files with 185 additions and 25 deletions.
@@ -7,6 +7,7 @@
#pragma once
#include <stdint.h>
#include <stddef.h>
#include <torchao/experimental/ops/packed_weights_header.h>

namespace torchao::ops::linear_8bit_act_xbit_weight {

@@ -59,6 +60,8 @@ struct UKernelConfig {
kernel_fn_type kernel_fn{nullptr};
int mr{0};
int nr{0};

torchao::ops::PackedWeightsHeader packed_weights_header;
};

// Pack weight functions
@@ -11,6 +11,8 @@
#endif // defined(__aarch64__) || defined(__ARM_NEON)

#include <torchao/experimental/ops/linear_8bit_act_xbit_weight/linear_8bit_act_xbit_weight.h>
#include <torchao/experimental/ops/linear_8bit_act_xbit_weight/packed_weights_header.h>
#include <torchao/experimental/ops/packed_weights_header.h>
#include <optional>
#include <vector>

@@ -35,31 +37,63 @@ using RuntimeContext = torch::executor::KernelRuntimeContext;

namespace {

// This selects a UKernelConfig based on the packed weights header
template <int weight_nbit, bool has_weight_zeros, bool has_bias, bool has_clamp>
inline torchao::ops::linear_8bit_act_xbit_weight::UKernelConfig
get_ukernel_config() {
get_ukernel_config(torchao::ops::PackedWeightsHeader header) {
torchao::ops::linear_8bit_act_xbit_weight::UKernelConfig config;

switch (header.format) {
#if defined(__aarch64__) || defined(__ARM_NEON)
namespace ukernel = torchao::kernels::cpu::aarch64::linear::
channelwise_8bit_activation_groupwise_lowbit_weight_1x8x16_f32_neondot;
config.mr = 1;
config.nr = 8;
config.activation_data_size_fn =
&ukernel::activation_data_size<has_weight_zeros>;
config.preferred_activation_data_alignment = 16; // size of neon register
config.prepare_activation_data_fn =
&ukernel::prepare_activation_data<has_weight_zeros>;
config.weight_data_size_fn =
&ukernel::weight_data_size<weight_nbit, has_weight_zeros>;
config.preferred_weight_data_alignment = 16; // size of neon register
config.prepare_weight_data_fn =
&ukernel::prepare_weight_data<weight_nbit, has_weight_zeros>;
config.kernel_fn =
&ukernel::kernel<weight_nbit, has_weight_zeros, has_bias, has_clamp>;
case torchao::ops::PackedWeightsFormat::
linear_8bit_act_xbit_weight_universal:
namespace ukernel
= torchao::kernels::cpu::aarch64::linear::
channelwise_8bit_activation_groupwise_lowbit_weight_1x8x16_f32_neondot;

// Check packing params match the kernel
CHECK_MSG(
header ==
torchao::ops::linear_8bit_act_xbit_weight::
get_packed_weights_header_universal(
weight_nbit,
has_weight_zeros,
has_bias,
/*nr=*/8,
/*kr=*/16),
"Packing params do not match what kernel supports");

config.packed_weights_header = header;
config.mr = 1;
config.nr = 8;
config.activation_data_size_fn =
&ukernel::activation_data_size<has_weight_zeros>;
config.preferred_activation_data_alignment = 16; // size of neon register
config.prepare_activation_data_fn =
&ukernel::prepare_activation_data<has_weight_zeros>;
config.weight_data_size_fn =
&ukernel::weight_data_size<weight_nbit, has_weight_zeros>;
config.preferred_weight_data_alignment = 16; // size of neon register
config.prepare_weight_data_fn =
&ukernel::prepare_weight_data<weight_nbit, has_weight_zeros>;
config.kernel_fn =
&ukernel::kernel<weight_nbit, has_weight_zeros, has_bias, has_clamp>;
return config;
break;
default:
CHECK_MSG(false, "Unsupported packed weights format");
#endif // defined(__aarch64__) || defined(__ARM_NEON)
}
}

return config;
template <int weight_nbit, bool has_weight_zeros, bool has_bias, bool has_clamp>
inline torchao::ops::linear_8bit_act_xbit_weight::UKernelConfig
get_ukernel_config() {
auto header = torchao::ops::linear_8bit_act_xbit_weight::
get_packed_weights_header_universal(
weight_nbit, has_weight_zeros, has_bias, /*nr=*/8, /*kr=*/16);
return get_ukernel_config<weight_nbit, has_weight_zeros, has_bias, has_clamp>(
header);
}

#ifdef USE_ATEN
@@ -114,13 +148,17 @@ Tensor pack_weights_cpu(
auto pack_weight_tiling_params = get_default_pack_weight_data_tiling_params(
ukernel_config, n, /*target_panels_per_thread=*/1);

auto packed_weight_data_size =
auto packed_weight_data_size = torchao::ops::PackedWeightsHeader::size() +
get_packed_weight_data_size(ukernel_config, n, k, group_size);
Tensor packed_weights = torch::empty({static_cast<int64_t>(packed_weight_data_size)}, torch::kInt8);
Tensor packed_weights = torch::empty(
{static_cast<int64_t>(packed_weight_data_size)}, torch::kInt8);
ukernel_config.packed_weights_header.write(
packed_weights.mutable_data_ptr<int8_t>());
pack_weight_data_operator(
ukernel_config,
pack_weight_tiling_params,
packed_weights.mutable_data_ptr<int8_t>(),
packed_weights.mutable_data_ptr<int8_t>() +
torchao::ops::PackedWeightsHeader::size(),
n,
k,
group_size,
@@ -180,9 +218,10 @@ Tensor pack_weights_meta(
false /*has_bias*/,
false /*has_clamp*/>();

auto packed_weight_data_size =
auto packed_weight_data_size = torchao::ops::PackedWeightsHeader::size() +
get_packed_weight_data_size(ukernel_config, n, k, group_size);
return torch::empty({static_cast<int64_t>(packed_weight_data_size)}).to("meta");
return torch::empty({static_cast<int64_t>(packed_weight_data_size)})
.to("meta");
}
#endif // USE_ATEN

@@ -260,11 +299,23 @@ Tensor linear_out_cpu(

using namespace torchao::ops::linear_8bit_act_xbit_weight;

CHECK_MSG(packed_weights.dim() == 1, "packed_weights must be 1D");
#ifdef USE_ATEN
CHECK_MSG(
packed_weights.dtype() == torch::kInt8, "packed_weights must be int8");
#endif // USE_ATEN
CHECK_MSG(
packed_weights.size(0) >= torchao::ops::PackedWeightsHeader::size(),
"packed_weights is not big enough to read the header.");
auto header = torchao::ops::PackedWeightsHeader::read(
packed_weights.const_data_ptr());

auto ukernel_config = get_ukernel_config<
weight_nbit,
has_weight_zeros /*has_weight_zeros*/,
false /*has_bias*/,
false /*has_clamp*/>();
false /*has_clamp*/>(header);

auto linear_tiling_params = get_default_linear_tiling_params(
ukernel_config,
m,
@@ -292,7 +343,8 @@ Tensor linear_out_cpu(
n,
k,
group_size,
packed_weights.const_data_ptr<int8_t>(),
packed_weights.const_data_ptr<int8_t>() +
torchao::ops::PackedWeightsHeader::size(),
activations.const_data_ptr<float>(),
/*bias=*/nullptr,
// Clamp parameters are ignored because config is created from
39 changes: 39 additions & 0 deletions torchao/experimental/ops/linear_8bit_act_xbit_weight/packed_weights_header.h
@@ -0,0 +1,39 @@
// Copyright (c) Meta Platforms, Inc. and affiliates.
// All rights reserved.
//
// This source code is licensed under the license found in the
// LICENSE file in the root directory of this source tree.

#pragma once
#include <torchao/experimental/ops/macro.h>
#include <torchao/experimental/ops/packed_weights_header.h>

namespace torchao::ops::linear_8bit_act_xbit_weight {

torchao::ops::PackedWeightsHeader get_packed_weights_header_universal(
int weight_nbit,
bool has_weight_zeros,
bool has_bias,
int nr,
int kr,
int version = 1) {
TORCHAO_CHECK(
version >= 0 && version < 256, "version must be between 0 and 255");
TORCHAO_CHECK(
weight_nbit >= 1 && weight_nbit < 256,
"weight_nbit must be between 1 and 255");
return torchao::ops::PackedWeightsHeader(
torchao::ops::PackedWeightsFormat::linear_8bit_act_xbit_weight_universal,
{((static_cast<unsigned short>(version) << 8) |
static_cast<unsigned short>(weight_nbit)),
((static_cast<unsigned short>(has_weight_zeros) << 8) |
static_cast<unsigned short>(has_bias)),
static_cast<unsigned short>(nr),
static_cast<unsigned short>(kr),
0,
0,
0,
0});
}

} // namespace torchao::ops::linear_8bit_act_xbit_weight
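
As a usage note (not part of this diff), the params packed above can be decoded with the inverse bit operations. The struct and helper below are hypothetical sketches that simply mirror the layout written by get_packed_weights_header_universal:

#include <torchao/experimental/ops/packed_weights_header.h>

// Hypothetical decode helper, mirroring the layout used above:
//   params[0] = (version << 8) | weight_nbit
//   params[1] = (has_weight_zeros << 8) | has_bias
//   params[2] = nr, params[3] = kr, remaining params are 0.
struct UniversalPackingParams {
  int version;
  int weight_nbit;
  bool has_weight_zeros;
  bool has_bias;
  int nr;
  int kr;
};

inline UniversalPackingParams decode_universal_packing_params(
    const torchao::ops::PackedWeightsHeader& header) {
  UniversalPackingParams p;
  p.version = header.params[0] >> 8;
  p.weight_nbit = header.params[0] & 0xFF;
  p.has_weight_zeros = (header.params[1] >> 8) != 0;
  p.has_bias = (header.params[1] & 0xFF) != 0;
  p.nr = header.params[2];
  p.kr = header.params[3];
  return p;
}
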
66 changes: 66 additions & 0 deletions torchao/experimental/ops/packed_weights_header.h
@@ -0,0 +1,66 @@
// Copyright (c) Meta Platforms, Inc. and affiliates.
// All rights reserved.
//
// This source code is licensed under the license found in the
// LICENSE file in the root directory of this source tree.

#pragma once
#include <array>

#include <cassert>
namespace torchao::ops {

enum PackedWeightsFormat : unsigned short {
unknown = 0,
linear_8bit_act_xbit_weight_universal = 1
};

class PackedWeightsHeader {
public:
using params_type = std::array<unsigned short, 7>;
PackedWeightsFormat format;

// 14 bytes of format specific params
params_type params;

PackedWeightsHeader(
PackedWeightsFormat format = PackedWeightsFormat::unknown,
params_type params = {0, 0, 0, 0, 0, 0, 0})
: format{format}, params{params} {}

inline static constexpr int size() {
static_assert(sizeof(format) + sizeof(params) == 16);
return 16;
}

inline void write(void* packed_weights) const {
auto header = (unsigned short*)(packed_weights);
header[0] = (unsigned short)format;
for (int i = 0; i < params.size(); i++) {
header[i + 1] = params[i];
}
}

static PackedWeightsHeader read(const void* packed_weights) {
auto header = (unsigned short*)(packed_weights);
params_type params;
for (int i = 0; i < params.size(); i++) {
params[i] = header[i + 1];
}
return PackedWeightsHeader((PackedWeightsFormat)header[0], params);
}

bool operator==(const PackedWeightsHeader& other) const {
if (format != other.format) {
return false;
}
for (int i = 0; i < params.size(); i++) {
if (params[i] != other.params[i]) {
return false;
}
}
return true;
}
};

} // namespace torchao::ops
