@@ -114,15 +114,20 @@ TORCHAO_ALWAYS_INLINE inline void vec_unpack_32_uint6_values(
114114 uint8x8_t packed1 = vld1_u8 (packed + 8 );
115115 uint8x8_t packed2 = vld1_u8 (packed + 16 );
116116
117- // unpacked[3] = ((packed[0] & 0b1100'0000u) >> 6) |
118- // ((packed[1] & 0b1100'0000u) >> 4) |
119- // ((packed[2] & 0b1100'0000u) >> 2);
120- const uint8x8_t high = vdup_n_u8 (0b1100'0000u );
121117 uint8x8_t unpacked3;
122- unpacked3 = vorr_u8 (
123- vshr_n_u8 (vand_u8 (packed0, high), 6 ),
124- vshr_n_u8 (vand_u8 (packed1, high), 4 ));
125- unpacked3 = vorr_u8 (unpacked3, vshr_n_u8 (vand_u8 (packed2, high), 2 ));
118+ // Gather the top two bits of each packed byte (labelled 56, 34 and 12 below) into the 6-bit value 123456 in unpacked3.
119+ // Packed structure is:
120+ //
121+ // packed0: 56 | abcdef
122+ // packed1: 34 | ghijkl
123+ // packed2: 12 | mnopqr
124+ //
125+ // unpacked3 = 1234 ghij (VSRI keeps the top 2 bits of packed2 and inserts packed1 >> 2 below them)
126+ unpacked3 = vsri_n_u8 (packed2, packed1, 2 );
127+ // unpacked3 = 1234 56ab
128+ unpacked3 = vsri_n_u8 (unpacked3, packed0, 4 );
129+ // unpacked3 = 0012 3456
130+ unpacked3 = vshr_n_u8 (unpacked3, 2 );
126131
127132 // unpacked[i] = packed[i] & 0b11'1111u;
128133 const uint8x8_t mask = vdup_n_u8 (0b11'1111u );
@@ -183,14 +188,19 @@ TORCHAO_ALWAYS_INLINE inline void vec_unpack_64_uint6_values(
183188 unpacked1 = vld1q_u8 (packed + 16 );
184189 unpacked2 = vld1q_u8 (packed + 32 );
185190
186- // unpacked[3] = ((packed[0] & 0b1100'0000u) >> 6) |
187- // ((packed[1] & 0b1100'0000u) >> 4) |
188- // ((packed[2] & 0b1100'0000u) >> 2);
189- const uint8x16_t high = vdupq_n_u8 (0b1100'0000u );
190- unpacked3 = vorrq_u8 (
191- vshrq_n_u8 (vandq_u8 (unpacked0, high), 6 ),
192- vshrq_n_u8 (vandq_u8 (unpacked1, high), 4 ));
193- unpacked3 = vorrq_u8 (unpacked3, vshrq_n_u8 (vandq_u8 (unpacked2, high), 2 ));
191+ // Gather the top two bits of each packed byte (labelled 56, 34 and 12 below) into the 6-bit value 123456 in unpacked3.
192+ // Packed structure is:
193+ //
194+ // packed0: 56 | abcdef
195+ // packed1: 34 | ghijkl
196+ // packed2: 12 | mnopqr
197+ //
198+ // unpacked3 = 1234 ghij (VSRI keeps the top 2 bits of unpacked2 and inserts unpacked1 >> 2 below them)
199+ unpacked3 = vsriq_n_u8 (unpacked2, unpacked1, 2 );
200+ // unpacked3 = 1234 56ab
201+ unpacked3 = vsriq_n_u8 (unpacked3, unpacked0, 4 );
202+ // unpacked3 = 0012 3456
203+ unpacked3 = vshrq_n_u8 (unpacked3, 2 );
194204
195205 // unpacked[i] = packed[i] & 0b11'1111u;
196206 const uint8x16_t mask = vdupq_n_u8 (0b11'1111u );
0 commit comments