|
10 | 10 | #include <cassert> |
11 | 11 | namespace torchao::ops { |
12 | 12 |
|
13 | | -enum PackedWeightsFormat : unsigned short { |
| 13 | +enum PackedWeightsFormat : int { |
14 | 14 | unknown = 0, |
15 | 15 | linear_8bit_act_xbit_weight_universal = 1 |
16 | 16 | }; |
17 | 17 |
|
18 | 18 | class PackedWeightsHeader { |
19 | 19 | public: |
20 | | - using params_type = std::array<unsigned short, 7>; |
| 20 | + using params_type = std::array<int, 14>; |
| 21 | + const static int magic = 6712; |
21 | 22 | PackedWeightsFormat format; |
22 | 23 |
|
23 | 24 | // 14 bytes of format specific params |
24 | 25 | params_type params; |
25 | 26 |
|
26 | 27 | PackedWeightsHeader( |
27 | 28 | PackedWeightsFormat format = PackedWeightsFormat::unknown, |
28 | | - params_type params = {0, 0, 0, 0, 0, 0, 0}) |
| 29 | + params_type params = {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}) |
29 | 30 | : format{format}, params{params} {} |
30 | 31 |
|
31 | 32 | inline static constexpr int size() { |
32 | | - static_assert(sizeof(format) + sizeof(params) == 16); |
33 | | - return 16; |
| 33 | + static_assert(sizeof(magic) + sizeof(format) + sizeof(params) == 64); |
| 34 | + return 64; |
34 | 35 | } |
35 | 36 |
|
36 | 37 | inline void write(void* packed_weights) const { |
37 | | - auto header = (unsigned short*)(packed_weights); |
38 | | - header[0] = (unsigned short)format; |
| 38 | + auto header = (int*)(packed_weights); |
| 39 | + header[0] = magic; |
| 40 | + header[1] = (int)format; |
39 | 41 | for (int i = 0; i < params.size(); i++) { |
40 | | - header[i + 1] = params[i]; |
| 42 | + header[i + 2] = params[i]; |
41 | 43 | } |
42 | 44 | } |
43 | 45 |
|
44 | 46 | static PackedWeightsHeader read(const void* packed_weights) { |
45 | | - auto header = (unsigned short*)(packed_weights); |
| 47 | + auto header = (int*)(packed_weights); |
| 48 | + assert(header[0] == PackedWeightsHeader::magic); |
46 | 49 | params_type params; |
47 | 50 | for (int i = 0; i < params.size(); i++) { |
48 | | - params[i] = header[i + 1]; |
| 51 | + params[i] = header[i + 2]; |
49 | 52 | } |
50 | | - return PackedWeightsHeader((PackedWeightsFormat)header[0], params); |
| 53 | + return PackedWeightsHeader((PackedWeightsFormat)header[1], params); |
51 | 54 | } |
52 | 55 |
|
53 | 56 | bool operator==(const PackedWeightsHeader& other) const { |
|
0 commit comments