@@ -234,6 +234,183 @@ TORCHAO_ALWAYS_INLINE inline void vec_unpack_64_uint6_values(
234234 unpacked3 = vorrq_u8 (unpacked3, vshrq_n_u8 (b3210, 4 ));
235235}
236236
237+ TORCHAO_ALWAYS_INLINE inline void pack_4_uint6_values_v2 (
238+ uint8_t * packed,
239+ const uint8_t * unpacked) {
240+ // Given 4 unpacked uint6 values: abcdef, ghijkl, mnopqr, 123456
241+ // this function packs them as:
242+ // packed[0]: 56 | abcdef
243+ // packed[1]: 34 | ghijkl
244+ // packed[2]: 12 | mnopqr
245+ //
246+ // Input is 4 bytes
247+ // Output is 6 * 4 bits/8 = 3 bytes
248+ packed[0 ] = unpacked[0 ];
249+ packed[1 ] = unpacked[1 ];
250+ packed[2 ] = unpacked[2 ];
251+ // Last value is packed in the upper 2 bits of the three bytes
252+ packed[0 ] |= ((unpacked[3 ] & 0b00'0011u ) << 6 );
253+ packed[1 ] |= ((unpacked[3 ] & 0b00'1100u ) << 4 );
254+ packed[2 ] |= ((unpacked[3 ] & 0b11'0000u ) << 2 );
255+ }
256+
257+ TORCHAO_ALWAYS_INLINE inline void unpack_4_uint6_values_v2 (
258+ uint8_t * unpacked,
259+ const uint8_t * packed) {
260+ // Unpacks data packed by pack_4_uint6_values_v2
261+ //
262+ // Input is 24 bits = 3 bytes
263+ // Output is 4 bytes
264+ unpacked[0 ] = packed[0 ] & 0b111111u ;
265+ unpacked[1 ] = packed[1 ] & 0b111111u ;
266+ unpacked[2 ] = packed[2 ] & 0b111111u ;
267+ // Last value is packed in the upper 2 bits of the three bytes
268+ unpacked[3 ] = ((packed[0 ] & 0b1100'0000u ) >> 6 ) |
269+ ((packed[1 ] & 0b1100'0000u ) >> 4 ) |
270+ ((packed[2 ] & 0b1100'0000u ) >> 2 );
271+ }
272+
273+ TORCHAO_ALWAYS_INLINE inline void vec_pack_32_uint6_values_v2 (
274+ uint8_t * packed,
275+ const uint8x16_t & unpacked0,
276+ const uint8x16_t & unpacked1) {
277+ // This function is a vectorized version of pack_4_uint6_values_v2.
278+ // To understand the following code, please see pack_4_uint6_values_v2 first and
279+ // consider the following mapping for the unpacked parameter of that function:
280+ //
281+ // unpacked[0] -> vget_low_u8(unpacked0)
282+ // unpacked[1] -> vget_high_u8(unpacked0)
283+ // unpacked[2] -> vget_low_u8(unpacked1)
284+ // unpacked[3] -> vget_high_u8(unpacked1)
285+ //
286+ // Before each code section, there is a comment indicating the
287+ // code in pack_4_uint6_values_v2 that is being vectorized.
288+ //
289+ // Input is 32 bytes.
290+ // Output is 6*32= 192 bits = 24 bytes.
291+ uint8x8_t r;
292+
293+ // packed[0] = unpacked[0]
294+ // packed[0] |= ((unpacked[3] & 0b00'0011u) << 6)
295+ r = vget_low_u8 (unpacked0);
296+ r = vorr_u8 (r, vshl_n_u8 (vand_u8 (vget_high_u8 (unpacked1), vdup_n_u8 (0b00'0011u )), 6 ));
297+ vst1_u8 (packed, r);
298+
299+ // packed[1] = unpacked[1]
300+ // packed[1] |= ((unpacked[3] & 0b00'1100u) << 4)
301+ r = vget_high_u8 (unpacked0);
302+ r = vorr_u8 (r, vshl_n_u8 (vand_u8 (vget_high_u8 (unpacked1), vdup_n_u8 (0b00'1100u )), 4 ));
303+ vst1_u8 (packed + 8 , r);
304+
305+ // packed[2] = unpacked[2]
306+ // packed[2] |= ((unpacked[3] & 0b11'0000u) << 2)
307+ r = vget_low_u8 (unpacked1);
308+ r = vorr_u8 (r, vshl_n_u8 (vand_u8 (vget_high_u8 (unpacked1), vdup_n_u8 (0b11'0000u )), 2 ));
309+ vst1_u8 (packed + 16 , r);
310+ }
311+
312+ TORCHAO_ALWAYS_INLINE inline void vec_unpack_32_uint6_values_v2 (
313+ uint8x16_t & unpacked0,
314+ uint8x16_t & unpacked1,
315+ const uint8_t * packed) {
316+ // Unpacks data packed by vec_pack_32_uint6_values_v2.
317+ //
318+ // This function vectorizes unpack_4_uint6_values_v2.
319+ // To understand it, please see unpack_4_uint6_values_v2 first.
320+ // Before each code section, there is a comment indicating the
321+ // code in unpack_4_uint6_values_v2 that is being vectorized.
322+ //
323+ // Input is 24 bytes.
324+ // Output is 32 bytes.
325+ uint8x8_t packed0 = vld1_u8 (packed);
326+ uint8x8_t packed1 = vld1_u8 (packed + 8 );
327+ uint8x8_t packed2 = vld1_u8 (packed + 16 );
328+
329+ // unpacked[3] = ((packed[0] & 0b1100'0000u) >> 6) |
330+ // ((packed[1] & 0b1100'0000u) >> 4) |
331+ // ((packed[2] & 0b1100'0000u) >> 2);
332+ const uint8x8_t high = vdup_n_u8 (0b1100'0000u );
333+ uint8x8_t unpacked3;
334+ unpacked3 = vorr_u8 (vshr_n_u8 (vand_u8 (packed0, high), 6 ),
335+ vshr_n_u8 (vand_u8 (packed1, high), 4 ));
336+ unpacked3 = vorr_u8 (unpacked3,
337+ vshr_n_u8 (vand_u8 (packed2, high), 2 ));
338+
339+ // unpacked[i] = packed[i] & 0b11'1111u;
340+ const uint8x8_t mask = vdup_n_u8 (0b11'1111u );
341+ unpacked0 = vcombine_u8 (vand_u8 (packed0, mask), vand_u8 (packed1, mask));
342+ unpacked1 = vcombine_u8 (vand_u8 (packed2, mask), unpacked3);
343+ }
344+
345+ TORCHAO_ALWAYS_INLINE inline void vec_pack_64_uint6_values_v2 (
346+ uint8_t * packed,
347+ const uint8x16_t & unpacked0,
348+ const uint8x16_t & unpacked1,
349+ const uint8x16_t & unpacked2,
350+ const uint8x16_t & unpacked3) {
351+ // This function is a vectorized version of pack_4_uint6_values_v2.
352+ // To understand the following code, please see pack_4_uint6_values_v2 first.
353+ // Before each code section, there is a comment indicating the
354+ // code in pack_4_uint6_values_v2 that is being vectorized.
355+ //
356+ // Input is 48 bytes.
357+ // Output is 64 bytes.
358+ uint8x16_t r;
359+
360+ // packed[0] = unpacked[0]
361+ // packed[0] |= ((unpacked[3] & 0b00'0011u) << 6)
362+ r = unpacked0;
363+ r = vorrq_u8 (r, vshlq_n_u8 (vandq_u8 (unpacked3, vdupq_n_u8 (0b00'0011u )), 6 ));
364+ vst1q_u8 (packed, r);
365+
366+ // packed[1] = unpacked[1]
367+ // packed[1] |= ((unpacked[3] & 0b00'1100u) << 4)
368+ r = unpacked1;
369+ r = vorrq_u8 (r, vshlq_n_u8 (vandq_u8 (unpacked3, vdupq_n_u8 (0b00'1100u )), 4 ));
370+ vst1q_u8 (packed + 16 , r);
371+
372+ // packed[2] = unpacked[2]
373+ // packed[2] |= ((unpacked[3] & 0b11'0000u) << 2)
374+ r = unpacked2;
375+ r = vorrq_u8 (r, vshlq_n_u8 (vandq_u8 (unpacked3, vdupq_n_u8 (0b11'0000u )), 2 ));
376+ vst1q_u8 (packed + 32 , r);
377+ }
378+
379+ TORCHAO_ALWAYS_INLINE inline void vec_unpack_64_uint6_values_v2 (
380+ uint8x16_t & unpacked0,
381+ uint8x16_t & unpacked1,
382+ uint8x16_t & unpacked2,
383+ uint8x16_t & unpacked3,
384+ const uint8_t * packed) {
385+ // Unpacks data packed by vec_pack_64_uint6_values_v2.
386+ //
387+ // This function vectorizes unpack_4_uint6_values_v2.
388+ // To understand it, please see unpack_4_uint6_values_v2 first.
389+ // Before each code section, there is a comment indicating the
390+ // code in unpack_4_uint6_values that is being vectorized
391+
392+ // Input is 48 bytes.
393+ // Output is 64 bytes.
394+ unpacked0 = vld1q_u8 (packed);
395+ unpacked1 = vld1q_u8 (packed + 16 );
396+ unpacked2 = vld1q_u8 (packed + 32 );
397+
398+ // unpacked[3] = ((packed[0] & 0b1100'0000u) >> 6) |
399+ // ((packed[1] & 0b1100'0000u) >> 4) |
400+ // ((packed[2] & 0b1100'0000u) >> 2);
401+ const uint8x16_t high = vdupq_n_u8 (0b1100'0000u );
402+ unpacked3 = vorrq_u8 (vshrq_n_u8 (vandq_u8 (unpacked0, high), 6 ),
403+ vshrq_n_u8 (vandq_u8 (unpacked1, high), 4 ));
404+ unpacked3 = vorrq_u8 (unpacked3,
405+ vshrq_n_u8 (vandq_u8 (unpacked2, high), 2 ));
406+
407+ // unpacked[i] = packed[i] & 0b11'1111u;
408+ const uint8x16_t mask = vdupq_n_u8 (0b11'1111u );
409+ unpacked0 = vandq_u8 (unpacked0, mask);
410+ unpacked1 = vandq_u8 (unpacked1, mask);
411+ unpacked2 = vandq_u8 (unpacked2, mask);
412+ }
413+
237414} // namespace internal
238415} // namespace bitpacking
239416} // namespace torchao
0 commit comments