Skip to content

Commit abb7810

Browse files
metascroyfacebook-github-bot
authored andcommitted
Header bug fix (#1079)
Summary: A last minute change created a compile error on the header. This fixes the issue. I also make the header 64 bytes and add a magic number at the start to make it safer in future. Differential Revision: D64370707
1 parent b53694a commit abb7810

File tree

2 files changed

+24
-22
lines changed

2 files changed

+24
-22
lines changed

torchao/experimental/ops/linear_8bit_act_xbit_weight/packed_weights_header.h

Lines changed: 10 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -17,19 +17,18 @@ torchao::ops::PackedWeightsHeader get_packed_weights_header_universal(
1717
int nr,
1818
int kr,
1919
int version = 1) {
20-
TORCHAO_CHECK(
21-
version >= 0 && version < 256, "version must be between 0 and 255");
22-
TORCHAO_CHECK(
23-
weight_nbit >= 1 && weight_nbit < 256,
24-
"weight_nbit must be between 1 and 255");
2520
return torchao::ops::PackedWeightsHeader(
2621
torchao::ops::PackedWeightsFormat::linear_8bit_act_xbit_weight_universal,
27-
{((static_cast<unsigned short>(version) << 8) |
28-
static_cast<unsigned short>(weight_nbit)),
29-
((static_cast<unsigned short>(has_weight_zeros) << 8) |
30-
static_cast<unsigned short>(has_bias)),
31-
static_cast<unsigned short>(nr),
32-
static_cast<unsigned short>(kr),
22+
{version,
23+
weight_nbit,
24+
has_weight_zeros,
25+
has_bias,
26+
nr,
27+
kr,
28+
0,
29+
0,
30+
0,
31+
0,
3332
0,
3433
0,
3534
0,

torchao/experimental/ops/packed_weights_header.h

Lines changed: 14 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -10,44 +10,47 @@
1010
#include <cassert>
1111
namespace torchao::ops {
1212

13-
enum PackedWeightsFormat : unsigned short {
13+
enum PackedWeightsFormat : int {
1414
unknown = 0,
1515
linear_8bit_act_xbit_weight_universal = 1
1616
};
1717

1818
class PackedWeightsHeader {
1919
public:
20-
using params_type = std::array<unsigned short, 7>;
20+
using params_type = std::array<int, 14>;
21+
const static int magic = 6712;
2122
PackedWeightsFormat format;
2223

2324
// 14 bytes of format specific params
2425
params_type params;
2526

2627
PackedWeightsHeader(
2728
PackedWeightsFormat format = PackedWeightsFormat::unknown,
28-
params_type params = {0, 0, 0, 0, 0, 0, 0})
29+
params_type params = {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0})
2930
: format{format}, params{params} {}
3031

3132
inline static constexpr int size() {
32-
static_assert(sizeof(format) + sizeof(params) == 16);
33-
return 16;
33+
static_assert(sizeof(magic) + sizeof(format) + sizeof(params) == 64);
34+
return 64;
3435
}
3536

3637
inline void write(void* packed_weights) const {
37-
auto header = (unsigned short*)(packed_weights);
38-
header[0] = (unsigned short)format;
38+
auto header = (int*)(packed_weights);
39+
header[0] = magic;
40+
header[1] = (int)format;
3941
for (int i = 0; i < params.size(); i++) {
40-
header[i + 1] = params[i];
42+
header[i + 2] = params[i];
4143
}
4244
}
4345

4446
static PackedWeightsHeader read(const void* packed_weights) {
45-
auto header = (unsigned short*)(packed_weights);
47+
auto header = (int*)(packed_weights);
48+
assert(header[0] == PackedWeightsHeader::magic);
4649
params_type params;
4750
for (int i = 0; i < params.size(); i++) {
48-
params[i] = header[i + 1];
51+
params[i] = header[i + 2];
4952
}
50-
return PackedWeightsHeader((PackedWeightsFormat)header[0], params);
53+
return PackedWeightsHeader((PackedWeightsFormat)header[1], params);
5154
}
5255

5356
bool operator==(const PackedWeightsHeader& other) const {

0 commit comments

Comments
 (0)