@@ -46,25 +46,6 @@ internal static class AvxIntrinsics
 
     private static readonly Vector256<float> _absMask256 = Avx.StaticCast<int, float>(Avx.SetAllVector256(0x7FFFFFFF));
 
-    private const int Vector256Alignment = 32;
-
-    [MethodImplAttribute(MethodImplOptions.AggressiveInlining)]
-    private static bool HasCompatibleAlignment(AlignedArray alignedArray)
-    {
-        Contracts.AssertValue(alignedArray);
-        Contracts.Assert(alignedArray.Size > 0);
-        return (alignedArray.CbAlign % Vector256Alignment) == 0;
-    }
-
-    [MethodImplAttribute(MethodImplOptions.AggressiveInlining)]
-    private static unsafe float* GetAlignedBase(AlignedArray alignedArray, float* unalignedBase)
-    {
-        Contracts.AssertValue(alignedArray);
-        float* alignedBase = unalignedBase + alignedArray.GetBase((long)unalignedBase);
-        Contracts.Assert(((long)alignedBase % Vector256Alignment) == 0);
-        return alignedBase;
-    }
-
     [MethodImplAttribute(MethodImplOptions.AggressiveInlining)]
     private static Vector128<float> GetHigh(in Vector256<float> x)
         => Avx.ExtractVector128(x, 1);
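
The two deleted helpers enforced the old contract that every `AlignedArray` buffer is 32-byte aligned. The span-based kernels below drop that contract: they branch on the run-time alignment of the destination pointer and mask out-of-range lanes with the `LeadingAlignmentMask` and `TrailingAlignmentMask` tables pinned in each `fixed` block. The file defines those tables as hand-written constant arrays; the hypothetical builder below is only a sketch of the layout the kernels index with `misalignment * 8`:

```csharp
// Row i of the leading table enables the first i lanes of a Vector256<float>;
// row i of the trailing table enables the last i lanes. 8 rows of 8 lanes.
private static uint[] BuildAlignmentMasks(bool leading)
{
    var table = new uint[8 * 8];
    for (int count = 0; count < 8; count++)
        for (int lane = 0; lane < 8; lane++)
            table[count * 8 + lane] =
                (leading ? lane < count : lane >= 8 - count) ? 0xFFFFFFFFu : 0u;
    return table;
}
```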
@@ -170,19 +151,19 @@ private static Vector256<float> MultiplyAdd(Vector256<float> src1, Vector256<flo
     }
 
     // Multiply matrix times vector into vector.
-    public static unsafe void MatMulX(AlignedArray mat, AlignedArray src, AlignedArray dst, int crow, int ccol)
+    public static unsafe void MatMul(AlignedArray mat, AlignedArray src, AlignedArray dst, int crow, int ccol)
     {
-        Contracts.Assert(crow % 4 == 0);
-        Contracts.Assert(ccol % 4 == 0);
-
-        MatMulX(mat.Items, src.Items, dst.Items, crow, ccol);
+        MatMul(mat.Items, src.Items, dst.Items, crow, ccol);
     }
 
-    public static unsafe void MatMulX(float[] mat, float[] src, float[] dst, int crow, int ccol)
+    public static unsafe void MatMul(ReadOnlySpan<float> mat, ReadOnlySpan<float> src, Span<float> dst, int crow, int ccol)
     {
-        fixed (float* psrc = &src[0])
-        fixed (float* pdst = &dst[0])
-        fixed (float* pmat = &mat[0])
+        Contracts.Assert(crow % 4 == 0);
+        Contracts.Assert(ccol % 4 == 0);
+
+        fixed (float* psrc = &MemoryMarshal.GetReference(src))
+        fixed (float* pdst = &MemoryMarshal.GetReference(dst))
+        fixed (float* pmat = &MemoryMarshal.GetReference(mat))
         fixed (uint* pLeadingAlignmentMask = &LeadingAlignmentMask[0])
         fixed (uint* pTrailingAlignmentMask = &TrailingAlignmentMask[0])
         {
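
Both `MatMul` overloads compute a dense row-major matrix-times-vector product: `dst[i]` is the dot product of row `i` of `mat` with `src`, for `crow` rows of `ccol` columns. A scalar reference (hypothetical `MatMulScalar` helper, shown only to pin down the layout and the meaning of `crow` and `ccol`):

```csharp
// dst[row] = dot(row `row` of mat, src); mat is row-major, crow x ccol.
private static void MatMulScalar(ReadOnlySpan<float> mat, ReadOnlySpan<float> src,
    Span<float> dst, int crow, int ccol)
{
    for (int row = 0; row < crow; row++)
    {
        float sum = 0;
        for (int col = 0; col < ccol; col++)
            sum += mat[row * ccol + col] * src[col];
        dst[row] = sum;
    }
}
```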
@@ -312,32 +293,134 @@ public static unsafe void MatMulX(float[] mat, float[] src, float[] dst, int cro
     }
 
     // Partial sparse source vector.
-    public static unsafe void MatMulPX(AlignedArray mat, int[] rgposSrc, AlignedArray src,
-        int posMin, int iposMin, int iposEnd, AlignedArray dst, int crow, int ccol)
+    public static unsafe void MatMulP(AlignedArray mat, ReadOnlySpan<int> rgposSrc, AlignedArray src,
+        int posMin, int iposMin, int iposEnd, AlignedArray dst, int crow, int ccol)
     {
-        Contracts.Assert(HasCompatibleAlignment(mat));
-        Contracts.Assert(HasCompatibleAlignment(src));
-        Contracts.Assert(HasCompatibleAlignment(dst));
+        MatMulP(mat.Items, rgposSrc, src.Items, posMin, iposMin, iposEnd, dst.Items, crow, ccol);
+    }
+
+    public static unsafe void MatMulP(ReadOnlySpan<float> mat, ReadOnlySpan<int> rgposSrc, ReadOnlySpan<float> src,
+        int posMin, int iposMin, int iposEnd, Span<float> dst, int crow, int ccol)
+    {
+        Contracts.Assert(crow % 8 == 0);
+        Contracts.Assert(ccol % 8 == 0);
 
         // REVIEW: For extremely sparse inputs, interchanging the loops would
         // likely be more efficient.
-        fixed (float* pSrcStart = &src.Items[0])
-        fixed (float* pDstStart = &dst.Items[0])
-        fixed (float* pMatStart = &mat.Items[0])
-        fixed (int* pposSrc = &rgposSrc[0])
+        fixed (float* psrc = &MemoryMarshal.GetReference(src))
+        fixed (float* pdst = &MemoryMarshal.GetReference(dst))
+        fixed (float* pmat = &MemoryMarshal.GetReference(mat))
+        fixed (int* pposSrc = &MemoryMarshal.GetReference(rgposSrc))
+        fixed (uint* pLeadingAlignmentMask = &LeadingAlignmentMask[0])
+        fixed (uint* pTrailingAlignmentMask = &TrailingAlignmentMask[0])
         {
-            float* psrc = GetAlignedBase(src, pSrcStart);
-            float* pdst = GetAlignedBase(dst, pDstStart);
-            float* pmat = GetAlignedBase(mat, pMatStart);
-
             int* pposMin = pposSrc + iposMin;
             int* pposEnd = pposSrc + iposEnd;
             float* pDstEnd = pdst + crow;
             float* pm0 = pmat - posMin;
             float* pSrcCurrent = psrc - posMin;
             float* pDstCurrent = pdst;
 
-            while (pDstCurrent < pDstEnd)
+            nuint address = (nuint)(pDstCurrent);
+            int misalignment = (int)(address % 32);
+            int length = crow;
+            int remainder = 0;
+
+            if ((misalignment & 3) != 0)
+            {
+                while (pDstCurrent < pDstEnd)
+                {
+                    Avx.Store(pDstCurrent, SparseMultiplicationAcrossRow());
+                    pDstCurrent += 8;
+                    pm0 += 8 * ccol;
+                }
+            }
+            else
+            {
+                if (misalignment != 0)
+                {
+                    misalignment >>= 2;
+                    misalignment = 8 - misalignment;
+
+                    Vector256<float> mask = Avx.LoadVector256(((float*)(pLeadingAlignmentMask)) + (misalignment * 8));
+
+                    float* pm1 = pm0 + ccol;
+                    float* pm2 = pm1 + ccol;
+                    float* pm3 = pm2 + ccol;
+                    Vector256<float> result = Avx.SetZeroVector256<float>();
+
+                    int* ppos = pposMin;
+
+                    while (ppos < pposEnd)
+                    {
+                        int col1 = *ppos;
+                        int col2 = col1 + 4 * ccol;
+                        Vector256<float> x1 = Avx.SetVector256(pm3[col2], pm2[col2], pm1[col2], pm0[col2],
+                                                               pm3[col1], pm2[col1], pm1[col1], pm0[col1]);
+
+                        x1 = Avx.And(mask, x1);
+                        Vector256<float> x2 = Avx.SetAllVector256(pSrcCurrent[col1]);
+                        result = MultiplyAdd(x2, x1, result);
+                        ppos++;
+                    }
+
+                    Avx.Store(pDstCurrent, result);
+                    pDstCurrent += misalignment;
+                    pm0 += misalignment * ccol;
+                    length -= misalignment;
+                }
+
+                if (length > 7)
+                {
+                    remainder = length % 8;
+                    while (pDstCurrent + 8 <= pDstEnd)
+                    {
+                        Avx.Store(pDstCurrent, SparseMultiplicationAcrossRow());
+                        pDstCurrent += 8;
+                        pm0 += 8 * ccol;
+                    }
+                }
+                else
+                {
+                    remainder = length;
+                }
+
+                if (remainder != 0)
+                {
+                    pDstCurrent -= (8 - remainder);
+                    pm0 -= (8 - remainder) * ccol;
+                    Vector256<float> trailingMask = Avx.LoadVector256(((float*)(pTrailingAlignmentMask)) + (remainder * 8));
+                    Vector256<float> leadingMask = Avx.LoadVector256(((float*)(pLeadingAlignmentMask)) + ((8 - remainder) * 8));
+
+                    float* pm1 = pm0 + ccol;
+                    float* pm2 = pm1 + ccol;
+                    float* pm3 = pm2 + ccol;
+                    Vector256<float> result = Avx.SetZeroVector256<float>();
+
+                    int* ppos = pposMin;
+
+                    while (ppos < pposEnd)
+                    {
+                        int col1 = *ppos;
+                        int col2 = col1 + 4 * ccol;
+                        Vector256<float> x1 = Avx.SetVector256(pm3[col2], pm2[col2], pm1[col2], pm0[col2],
+                                                               pm3[col1], pm2[col1], pm1[col1], pm0[col1]);
+                        x1 = Avx.And(x1, trailingMask);
+
+                        Vector256<float> x2 = Avx.SetAllVector256(pSrcCurrent[col1]);
+                        result = MultiplyAdd(x2, x1, result);
+                        ppos++;
+                    }
+
+                    result = Avx.Add(result, Avx.And(leadingMask, Avx.LoadVector256(pDstCurrent)));
+
+                    Avx.Store(pDstCurrent, result);
+                    pDstCurrent += 8;
+                    pm0 += 8 * ccol;
+                }
+            }
+
+            Vector256<float> SparseMultiplicationAcrossRow()
             {
                 float* pm1 = pm0 + ccol;
                 float* pm2 = pm1 + ccol;
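
The new `MatMulP` splits the `crow` destination elements into three phases. If `pdst` is not even 4-byte aligned (`(misalignment & 3) != 0`), every store is unaligned and no masking is needed, since `crow % 8 == 0` guarantees whole vectors. Otherwise a leading-masked pass advances `pDstCurrent` to the next 32-byte boundary, the main loop stores full 8-float vectors, and a trailing-masked pass backs up so the last vector ends exactly at `pDstEnd`, blending in the lanes already written. The same pattern reduced to a trivial write-only kernel (hypothetical `Fill256` helper, illustration only; API names follow the file above, and `length % 8 == 0` is assumed on entry, as `crow % 8 == 0` ensures above):

```csharp
// dst[i] = value for `length` floats, using the head/body/tail masking
// scheme of MatMulP. Illustrative sketch, not part of AvxIntrinsics.
private static unsafe void Fill256(float* pdst, int length, Vector256<float> value,
    uint* pLeadingAlignmentMask, uint* pTrailingAlignmentMask)
{
    float* pDstCurrent = pdst;
    float* pDstEnd = pdst + length;
    int misalignment = (int)((ulong)pDstCurrent % 32);
    int remainder = 0;

    if ((misalignment & 3) != 0)
    {
        // Not even float-aligned: use unaligned stores throughout.
        while (pDstCurrent < pDstEnd)
        {
            Avx.Store(pDstCurrent, value);
            pDstCurrent += 8;
        }
        return;
    }

    if (misalignment != 0)
    {
        // Head: only the first (8 - misalignment/4) lanes matter; the
        // zeroed lanes are overwritten by the body loop below.
        misalignment = 8 - (misalignment >> 2);
        Vector256<float> lead = Avx.LoadVector256(((float*)pLeadingAlignmentMask) + (misalignment * 8));
        Avx.Store(pDstCurrent, Avx.And(lead, value));
        pDstCurrent += misalignment;
        length -= misalignment;
    }

    if (length > 7)
    {
        // Body: full 8-float stores from a 32-byte-aligned address.
        remainder = length % 8;
        while (pDstCurrent + 8 <= pDstEnd)
        {
            Avx.Store(pDstCurrent, value);
            pDstCurrent += 8;
        }
    }
    else
    {
        remainder = length;
    }

    if (remainder != 0)
    {
        // Tail: back up so the final vector ends at pDstEnd, keeping the
        // (8 - remainder) lanes the head or body already wrote.
        pDstCurrent -= (8 - remainder);
        Vector256<float> trail = Avx.LoadVector256(((float*)pTrailingAlignmentMask) + (remainder * 8));
        Vector256<float> lead = Avx.LoadVector256(((float*)pLeadingAlignmentMask) + ((8 - remainder) * 8));
        Avx.Store(pDstCurrent, Avx.Add(Avx.And(trail, value),
                                       Avx.And(lead, Avx.LoadVector256(pDstCurrent))));
    }
}
```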
@@ -351,33 +434,30 @@ public static unsafe void MatMulPX(AlignedArray mat, int[] rgposSrc, AlignedArra
                     int col1 = *ppos;
                     int col2 = col1 + 4 * ccol;
                     Vector256<float> x1 = Avx.SetVector256(pm3[col2], pm2[col2], pm1[col2], pm0[col2],
-                        pm3[col1], pm2[col1], pm1[col1], pm0[col1]);
+                                                           pm3[col1], pm2[col1], pm1[col1], pm0[col1]);
                     Vector256<float> x2 = Avx.SetAllVector256(pSrcCurrent[col1]);
                     result = MultiplyAdd(x2, x1, result);
-
                     ppos++;
                 }
 
-                Avx.StoreAligned(pDstCurrent, result);
-                pDstCurrent += 8;
-                pm0 += 8 * ccol;
+                return result;
             }
         }
     }
 
-    public static unsafe void MatMulTranX(AlignedArray mat, AlignedArray src, AlignedArray dst, int crow, int ccol)
+    public static unsafe void MatMulTran(AlignedArray mat, AlignedArray src, AlignedArray dst, int crow, int ccol)
     {
-        Contracts.Assert(crow % 4 == 0);
-        Contracts.Assert(ccol % 4 == 0);
-
-        MatMulTranX(mat.Items, src.Items, dst.Items, crow, ccol);
+        MatMulTran(mat.Items, src.Items, dst.Items, crow, ccol);
     }
 
-    public static unsafe void MatMulTranX(float[] mat, float[] src, float[] dst, int crow, int ccol)
+    public static unsafe void MatMulTran(ReadOnlySpan<float> mat, ReadOnlySpan<float> src, Span<float> dst, int crow, int ccol)
     {
-        fixed (float* psrc = &src[0])
-        fixed (float* pdst = &dst[0])
-        fixed (float* pmat = &mat[0])
+        Contracts.Assert(crow % 4 == 0);
+        Contracts.Assert(ccol % 4 == 0);
+
+        fixed (float* psrc = &MemoryMarshal.GetReference(src))
+        fixed (float* pdst = &MemoryMarshal.GetReference(dst))
+        fixed (float* pmat = &MemoryMarshal.GetReference(mat))
         fixed (uint* pLeadingAlignmentMask = &LeadingAlignmentMask[0])
         fixed (uint* pTrailingAlignmentMask = &TrailingAlignmentMask[0])
         {
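
With the `X`-suffixed array overloads renamed and span-ified, arrays convert implicitly to spans, so callers no longer need to build an `AlignedArray` for these entry points. A hypothetical call to the new dense overload (ignoring the class's internal accessibility for illustration, and relying on the `crow % 4 == 0` / `ccol % 4 == 0` contract asserted above together with the masked remainder handling for sizes below one full vector):

```csharp
// 4x4 row-major matrix times a 4-vector; crow = ccol = 4 satisfies the asserts.
float[] mat =
{
    1, 0, 0, 0,
    0, 2, 0, 0,
    0, 0, 3, 0,
    0, 0, 0, 4,
};
float[] src = { 1, 1, 1, 1 };
float[] dst = new float[4];

AvxIntrinsics.MatMul(mat, src, dst, crow: 4, ccol: 4);
// dst should now hold { 1, 2, 3, 4 }.
```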