|
7 | 7 | #pragma once |
8 | 8 | #include <array> |
9 | 9 |
|
| 10 | +#include <stdint.h> |
10 | 11 | #include <cassert> |
| 12 | + |
11 | 13 | namespace torchao::ops { |
12 | 14 |
|
13 | | -enum PackedWeightsFormat : unsigned short { |
| 15 | +enum class PackedWeightsFormat : uint32_t { |
14 | 16 | unknown = 0, |
15 | 17 | linear_8bit_act_xbit_weight_universal = 1 |
16 | 18 | }; |
17 | 19 |
|
18 | 20 | class PackedWeightsHeader { |
19 | 21 | public: |
20 | | - using params_type = std::array<unsigned short, 7>; |
| 22 | + using params_type = std::array<int, 14>; |
| 23 | + const static int magic = 6712; |
21 | 24 | PackedWeightsFormat format; |
22 | 25 |
|
23 | 26 | // 14 bytes of format specific params |
24 | 27 | params_type params; |
25 | 28 |
|
26 | 29 | PackedWeightsHeader( |
27 | 30 | PackedWeightsFormat format = PackedWeightsFormat::unknown, |
28 | | - params_type params = {0, 0, 0, 0, 0, 0, 0}) |
| 31 | + params_type params = {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}) |
29 | 32 | : format{format}, params{params} {} |
30 | 33 |
|
31 | 34 | inline static constexpr int size() { |
32 | | - static_assert(sizeof(format) + sizeof(params) == 16); |
33 | | - return 16; |
| 35 | + static_assert(sizeof(magic) + sizeof(format) + sizeof(params) == 64); |
| 36 | + return 64; |
34 | 37 | } |
35 | 38 |
|
36 | 39 | inline void write(void* packed_weights) const { |
37 | | - auto header = (unsigned short*)(packed_weights); |
38 | | - header[0] = (unsigned short)format; |
| 40 | + auto header = reinterpret_cast<int*>(packed_weights); |
| 41 | + header[0] = magic; |
| 42 | + header[1] = static_cast<int>(format); |
39 | 43 | for (int i = 0; i < params.size(); i++) { |
40 | | - header[i + 1] = params[i]; |
| 44 | + header[i + 2] = params[i]; |
41 | 45 | } |
42 | 46 | } |
43 | 47 |
|
44 | 48 | static PackedWeightsHeader read(const void* packed_weights) { |
45 | | - auto header = (unsigned short*)(packed_weights); |
| 49 | + auto header = reinterpret_cast<const int*>(packed_weights); |
| 50 | + assert(header[0] == PackedWeightsHeader::magic); |
46 | 51 | params_type params; |
47 | 52 | for (int i = 0; i < params.size(); i++) { |
48 | | - params[i] = header[i + 1]; |
| 53 | + params[i] = header[i + 2]; |
49 | 54 | } |
50 | | - return PackedWeightsHeader((PackedWeightsFormat)header[0], params); |
| 55 | + return PackedWeightsHeader( |
| 56 | + static_cast<PackedWeightsFormat>(header[1]), params); |
51 | 57 | } |
52 | 58 |
|
53 | 59 | bool operator==(const PackedWeightsHeader& other) const { |
|
0 commit comments