From 1f6305478145b5aa6b8ce8f087d5e79b8663756b Mon Sep 17 00:00:00 2001 From: Eugene Gusarov Date: Wed, 13 Mar 2019 09:45:13 +0000 Subject: [PATCH 1/2] CpuMath Enhancement: Make bound checking of loops in hardware intrinsics more efficient --- src/Microsoft.ML.CpuMath/AvxIntrinsics.cs | 40 +++++++++++------------ src/Microsoft.ML.CpuMath/SseIntrinsics.cs | 34 +++++++++---------- 2 files changed, 37 insertions(+), 37 deletions(-) diff --git a/src/Microsoft.ML.CpuMath/AvxIntrinsics.cs b/src/Microsoft.ML.CpuMath/AvxIntrinsics.cs index 195b2fc1ca..fc998e3c15 100644 --- a/src/Microsoft.ML.CpuMath/AvxIntrinsics.cs +++ b/src/Microsoft.ML.CpuMath/AvxIntrinsics.cs @@ -425,7 +425,7 @@ public static unsafe void AddScalarU(float scalar, Span dst) Vector256 scalarVector256 = Vector256.Create(scalar); - while (pDstCurrent + 8 <= pDstEnd) + while (pDstCurrent <= pDstEnd - 8) { Vector256 dstVector = Avx.LoadVector256(pDstCurrent); dstVector = Avx.Add(dstVector, scalarVector256); @@ -577,7 +577,7 @@ public static unsafe void ScaleSrcU(float scale, ReadOnlySpan src, Span scaleVector256 = Vector256.Create(scale); - while (pDstCurrent + 8 <= pDstEnd) + while (pDstCurrent <= pDstEnd - 8) { Vector256 srcVector = Avx.LoadVector256(pSrcCurrent); srcVector = Avx.Multiply(srcVector, scaleVector256); @@ -623,7 +623,7 @@ public static unsafe void ScaleAddU(float a, float b, Span dst) Vector256 a256 = Vector256.Create(a); Vector256 b256 = Vector256.Create(b); - while (pDstCurrent + 8 <= pDstEnd) + while (pDstCurrent <= pDstEnd - 8) { Vector256 dstVector = Avx.LoadVector256(pDstCurrent); dstVector = Avx.Add(dstVector, b256); @@ -671,7 +671,7 @@ public static unsafe void AddScaleU(float scale, ReadOnlySpan src, Span scaleVector256 = Vector256.Create(scale); - while (pDstCurrent + 8 <= pEnd) + while (pDstCurrent <= pEnd - 8) { Vector256 dstVector = Avx.LoadVector256(pDstCurrent); @@ -728,7 +728,7 @@ public static unsafe void AddScaleCopyU(float scale, ReadOnlySpan src, Re Vector256 scaleVector256 = Vector256.Create(scale); - while (pResCurrent + 8 <= pResEnd) + while (pResCurrent <= pResEnd - 8) { Vector256 dstVector = Avx.LoadVector256(pDstCurrent); dstVector = MultiplyAdd(pSrcCurrent, scaleVector256, dstVector); @@ -785,7 +785,7 @@ public static unsafe void AddScaleSU(float scale, ReadOnlySpan src, ReadO Vector256 scaleVector256 = Vector256.Create(scale); - while (pIdxCurrent + 8 <= pEnd) + while (pIdxCurrent <= pEnd - 8) { Vector256 dstVector = Load8(pDstCurrent, pIdxCurrent); dstVector = MultiplyAdd(pSrcCurrent, scaleVector256, dstVector); @@ -831,7 +831,7 @@ public static unsafe void AddU(ReadOnlySpan src, Span dst, int cou float* pDstCurrent = pdst; float* pEnd = psrc + count; - while (pSrcCurrent + 8 <= pEnd) + while (pSrcCurrent <= pEnd - 8) { Vector256 srcVector = Avx.LoadVector256(pSrcCurrent); Vector256 dstVector = Avx.LoadVector256(pDstCurrent); @@ -883,7 +883,7 @@ public static unsafe void AddSU(ReadOnlySpan src, ReadOnlySpan idx, float* pDstCurrent = pdst; int* pEnd = pidx + count; - while (pIdxCurrent + 8 <= pEnd) + while (pIdxCurrent <= pEnd - 8) { Vector256 dstVector = Load8(pDstCurrent, pIdxCurrent); Vector256 srcVector = Avx.LoadVector256(pSrcCurrent); @@ -931,7 +931,7 @@ public static unsafe void MulElementWiseU(ReadOnlySpan src1, ReadOnlySpan float* pDstCurrent = pdst; float* pEnd = pdst + count; - while (pDstCurrent + 8 <= pEnd) + while (pDstCurrent <= pEnd - 8) { Vector256 src1Vector = Avx.LoadVector256(pSrc1Current); Vector256 src2Vector = Avx.LoadVector256(pSrc2Current); @@ -1066,7 +1066,7 @@ public static unsafe float SumSqU(ReadOnlySpan src) Vector256 result256 = Vector256.Zero; - while (pSrcCurrent + 8 <= pSrcEnd) + while (pSrcCurrent <= pSrcEnd - 8) { Vector256 srcVector = Avx.LoadVector256(pSrcCurrent); result256 = MultiplyAdd(srcVector, srcVector, result256); @@ -1111,7 +1111,7 @@ public static unsafe float SumSqDiffU(float mean, ReadOnlySpan src) Vector256 result256 = Vector256.Zero; Vector256 meanVector256 = Vector256.Create(mean); - while (pSrcCurrent + 8 <= pSrcEnd) + while (pSrcCurrent <= pSrcEnd - 8) { Vector256 srcVector = Avx.LoadVector256(pSrcCurrent); srcVector = Avx.Subtract(srcVector, meanVector256); @@ -1158,7 +1158,7 @@ public static unsafe float SumAbsU(ReadOnlySpan src) Vector256 result256 = Vector256.Zero; - while (pSrcCurrent + 8 <= pSrcEnd) + while (pSrcCurrent <= pSrcEnd - 8) { Vector256 srcVector = Avx.LoadVector256(pSrcCurrent); result256 = Avx.Add(result256, Avx.And(srcVector, _absMask256)); @@ -1203,7 +1203,7 @@ public static unsafe float SumAbsDiffU(float mean, ReadOnlySpan src) Vector256 result256 = Vector256.Zero; Vector256 meanVector256 = Vector256.Create(mean); - while (pSrcCurrent + 8 <= pSrcEnd) + while (pSrcCurrent <= pSrcEnd - 8) { Vector256 srcVector = Avx.LoadVector256(pSrcCurrent); srcVector = Avx.Subtract(srcVector, meanVector256); @@ -1251,7 +1251,7 @@ public static unsafe float MaxAbsU(ReadOnlySpan src) Vector256 result256 = Vector256.Zero; - while (pSrcCurrent + 8 <= pSrcEnd) + while (pSrcCurrent <= pSrcEnd - 8) { Vector256 srcVector = Avx.LoadVector256(pSrcCurrent); result256 = Avx.Max(result256, Avx.And(srcVector, _absMask256)); @@ -1296,7 +1296,7 @@ public static unsafe float MaxAbsDiffU(float mean, ReadOnlySpan src) Vector256 result256 = Vector256.Zero; Vector256 meanVector256 = Vector256.Create(mean); - while (pSrcCurrent + 8 <= pSrcEnd) + while (pSrcCurrent <= pSrcEnd - 8) { Vector256 srcVector = Avx.LoadVector256(pSrcCurrent); srcVector = Avx.Subtract(srcVector, meanVector256); @@ -1348,7 +1348,7 @@ public static unsafe float DotU(ReadOnlySpan src, ReadOnlySpan dst Vector256 result256 = Vector256.Zero; - while (pSrcCurrent + 8 <= pSrcEnd) + while (pSrcCurrent <= pSrcEnd - 8) { Vector256 dstVector = Avx.LoadVector256(pDstCurrent); result256 = MultiplyAdd(pSrcCurrent, dstVector, result256); @@ -1405,7 +1405,7 @@ public static unsafe float DotSU(ReadOnlySpan src, ReadOnlySpan ds Vector256 result256 = Vector256.Zero; - while (pIdxCurrent + 8 <= pIdxEnd) + while (pIdxCurrent <= pIdxEnd - 8) { Vector256 srcVector = Load8(pSrcCurrent, pIdxCurrent); result256 = MultiplyAdd(pDstCurrent, srcVector, result256); @@ -1459,7 +1459,7 @@ public static unsafe float Dist2(ReadOnlySpan src, ReadOnlySpan ds Vector256 sqDistanceVector256 = Vector256.Zero; - while (pSrcCurrent + 8 <= pSrcEnd) + while (pSrcCurrent <= pSrcEnd - 8) { Vector256 distanceVector = Avx.Subtract(Avx.LoadVector256(pSrcCurrent), Avx.LoadVector256(pDstCurrent)); @@ -1514,7 +1514,7 @@ public static unsafe void SdcaL1UpdateU(float primalUpdate, int count, ReadOnlyS Vector256 xPrimal256 = Vector256.Create(primalUpdate); Vector256 xThreshold256 = Vector256.Create(threshold); - while (pSrcCurrent + 8 <= pSrcEnd) + while (pSrcCurrent <= pSrcEnd - 8) { Vector256 xDst1 = Avx.LoadVector256(pDst1Current); xDst1 = MultiplyAdd(pSrcCurrent, xPrimal256, xDst1); @@ -1574,7 +1574,7 @@ public static unsafe void SdcaL1UpdateSU(float primalUpdate, int count, ReadOnly Vector256 xPrimal256 = Vector256.Create(primalUpdate); Vector256 xThreshold = Vector256.Create(threshold); - while (pIdxCurrent + 8 <= pIdxEnd) + while (pIdxCurrent <= pIdxEnd - 8) { Vector256 xDst1 = Load8(pdst1, pIdxCurrent); xDst1 = MultiplyAdd(pSrcCurrent, xPrimal256, xDst1); diff --git a/src/Microsoft.ML.CpuMath/SseIntrinsics.cs b/src/Microsoft.ML.CpuMath/SseIntrinsics.cs index b569fba353..69ee7d6e88 100644 --- a/src/Microsoft.ML.CpuMath/SseIntrinsics.cs +++ b/src/Microsoft.ML.CpuMath/SseIntrinsics.cs @@ -565,7 +565,7 @@ public static unsafe void AddScaleU(float scale, ReadOnlySpan src, Span scaleVector = Vector128.Create(scale); - while (pDstCurrent + 4 <= pEnd) + while (pDstCurrent <= pEnd - 4) { Vector128 srcVector = Sse.LoadVector128(pSrcCurrent); Vector128 dstVector = Sse.LoadVector128(pDstCurrent); @@ -609,7 +609,7 @@ public static unsafe void AddScaleCopyU(float scale, ReadOnlySpan src, Re Vector128 scaleVector = Vector128.Create(scale); - while (pResCurrent + 4 <= pResEnd) + while (pResCurrent <= pResEnd - 4) { Vector128 srcVector = Sse.LoadVector128(pSrcCurrent); Vector128 dstVector = Sse.LoadVector128(pDstCurrent); @@ -653,7 +653,7 @@ public static unsafe void AddScaleSU(float scale, ReadOnlySpan src, ReadO Vector128 scaleVector = Vector128.Create(scale); - while (pIdxCurrent + 4 <= pEnd) + while (pIdxCurrent <= pEnd - 4) { Vector128 srcVector = Sse.LoadVector128(pSrcCurrent); Vector128 dstVector = Load4(pDstCurrent, pIdxCurrent); @@ -687,7 +687,7 @@ public static unsafe void AddU(ReadOnlySpan src, Span dst, int cou float* pDstCurrent = pdst; float* pEnd = psrc + count; - while (pSrcCurrent + 4 <= pEnd) + while (pSrcCurrent <= pEnd - 4) { Vector128 srcVector = Sse.LoadVector128(pSrcCurrent); Vector128 dstVector = Sse.LoadVector128(pDstCurrent); @@ -727,7 +727,7 @@ public static unsafe void AddSU(ReadOnlySpan src, ReadOnlySpan idx, float* pDstCurrent = pdst; int* pEnd = pidx + count; - while (pIdxCurrent + 4 <= pEnd) + while (pIdxCurrent <= pEnd - 4) { Vector128 dstVector = Load4(pDstCurrent, pIdxCurrent); Vector128 srcVector = Sse.LoadVector128(pSrcCurrent); @@ -763,7 +763,7 @@ public static unsafe void MulElementWiseU(ReadOnlySpan src1, ReadOnlySpan float* pDstCurrent = pdst; float* pEnd = pdst + count; - while (pDstCurrent + 4 <= pEnd) + while (pDstCurrent <= pEnd - 4) { Vector128 src1Vector = Sse.LoadVector128(pSrc1Current); Vector128 src2Vector = Sse.LoadVector128(pSrc2Current); @@ -883,7 +883,7 @@ public static unsafe float SumSqU(ReadOnlySpan src) Vector128 result = Vector128.Zero; - while (pSrcCurrent + 4 <= pSrcEnd) + while (pSrcCurrent <= pSrcEnd - 4) { Vector128 srcVector = Sse.LoadVector128(pSrcCurrent); result = Sse.Add(result, Sse.Multiply(srcVector, srcVector)); @@ -915,7 +915,7 @@ public static unsafe float SumSqDiffU(float mean, ReadOnlySpan src) Vector128 result = Vector128.Zero; Vector128 meanVector = Vector128.Create(mean); - while (pSrcCurrent + 4 <= pSrcEnd) + while (pSrcCurrent <= pSrcEnd - 4) { Vector128 srcVector = Sse.LoadVector128(pSrcCurrent); srcVector = Sse.Subtract(srcVector, meanVector); @@ -948,7 +948,7 @@ public static unsafe float SumAbsU(ReadOnlySpan src) Vector128 result = Vector128.Zero; - while (pSrcCurrent + 4 <= pSrcEnd) + while (pSrcCurrent <= pSrcEnd - 4) { Vector128 srcVector = Sse.LoadVector128(pSrcCurrent); result = Sse.Add(result, Sse.And(srcVector, AbsMask128)); @@ -980,7 +980,7 @@ public static unsafe float SumAbsDiffU(float mean, ReadOnlySpan src) Vector128 result = Vector128.Zero; Vector128 meanVector = Vector128.Create(mean); - while (pSrcCurrent + 4 <= pSrcEnd) + while (pSrcCurrent <= pSrcEnd - 4) { Vector128 srcVector = Sse.LoadVector128(pSrcCurrent); srcVector = Sse.Subtract(srcVector, meanVector); @@ -1013,7 +1013,7 @@ public static unsafe float MaxAbsU(ReadOnlySpan src) Vector128 result = Vector128.Zero; - while (pSrcCurrent + 4 <= pSrcEnd) + while (pSrcCurrent <= pSrcEnd - 4) { Vector128 srcVector = Sse.LoadVector128(pSrcCurrent); result = Sse.Max(result, Sse.And(srcVector, AbsMask128)); @@ -1045,7 +1045,7 @@ public static unsafe float MaxAbsDiffU(float mean, ReadOnlySpan src) Vector128 result = Vector128.Zero; Vector128 meanVector = Vector128.Create(mean); - while (pSrcCurrent + 4 <= pSrcEnd) + while (pSrcCurrent <= pSrcEnd - 4) { Vector128 srcVector = Sse.LoadVector128(pSrcCurrent); srcVector = Sse.Subtract(srcVector, meanVector); @@ -1082,7 +1082,7 @@ public static unsafe float DotU(ReadOnlySpan src, ReadOnlySpan dst Vector128 result = Vector128.Zero; - while (pSrcCurrent + 4 <= pSrcEnd) + while (pSrcCurrent <= pSrcEnd - 4) { Vector128 srcVector = Sse.LoadVector128(pSrcCurrent); Vector128 dstVector = Sse.LoadVector128(pDstCurrent); @@ -1126,7 +1126,7 @@ public static unsafe float DotSU(ReadOnlySpan src, ReadOnlySpan ds Vector128 result = Vector128.Zero; - while (pIdxCurrent + 4 <= pIdxEnd) + while (pIdxCurrent <= pIdxEnd - 4) { Vector128 srcVector = Load4(pSrcCurrent, pIdxCurrent); Vector128 dstVector = Sse.LoadVector128(pDstCurrent); @@ -1167,7 +1167,7 @@ public static unsafe float Dist2(ReadOnlySpan src, ReadOnlySpan ds Vector128 sqDistanceVector = Vector128.Zero; - while (pSrcCurrent + 4 <= pSrcEnd) + while (pSrcCurrent <= pSrcEnd - 4) { Vector128 distanceVector = Sse.Subtract(Sse.LoadVector128(pSrcCurrent), Sse.LoadVector128(pDstCurrent)); @@ -1210,7 +1210,7 @@ public static unsafe void SdcaL1UpdateU(float primalUpdate, int count, ReadOnlyS Vector128 signMask = Vector128.Create(-0.0f); // 0x8000 0000 Vector128 xThreshold = Vector128.Create(threshold); - while (pSrcCurrent + 4 <= pSrcEnd) + while (pSrcCurrent <= pSrcEnd - 4) { Vector128 xSrc = Sse.LoadVector128(pSrcCurrent); @@ -1255,7 +1255,7 @@ public static unsafe void SdcaL1UpdateSU(float primalUpdate, int count, ReadOnly Vector128 signMask = Vector128.Create(-0.0f); // 0x8000 0000 Vector128 xThreshold = Vector128.Create(threshold); - while (pIdxCurrent + 4 <= pIdxEnd) + while (pIdxCurrent <= pIdxEnd - 4) { Vector128 xSrc = Sse.LoadVector128(pSrcCurrent); From c07e0b2974f3e54115e9830c4e6c799b5404acc6 Mon Sep 17 00:00:00 2001 From: Eugene Gusarov Date: Sat, 30 Mar 2019 13:46:20 +0300 Subject: [PATCH 2/2] CpuMath Enhancement: Make bound checking of loops in hardware intrinsics more efficient - store vectorization end --- src/Microsoft.ML.CpuMath/AvxIntrinsics.cs | 60 +++++++++++++++-------- src/Microsoft.ML.CpuMath/SseIntrinsics.cs | 51 ++++++++++++------- 2 files changed, 74 insertions(+), 37 deletions(-) diff --git a/src/Microsoft.ML.CpuMath/AvxIntrinsics.cs b/src/Microsoft.ML.CpuMath/AvxIntrinsics.cs index fc998e3c15..46c80704a1 100644 --- a/src/Microsoft.ML.CpuMath/AvxIntrinsics.cs +++ b/src/Microsoft.ML.CpuMath/AvxIntrinsics.cs @@ -422,10 +422,11 @@ public static unsafe void AddScalarU(float scalar, Span dst) float* pDstEnd = pdst + dst.Length; float* pDstCurrent = pdst; float* pVectorizationEnd = pDstEnd - 4; + float* pAvxVectorizationEnd = pDstEnd - 8; Vector256 scalarVector256 = Vector256.Create(scalar); - while (pDstCurrent <= pDstEnd - 8) + while (pDstCurrent <= pAvxVectorizationEnd) { Vector256 dstVector = Avx.LoadVector256(pDstCurrent); dstVector = Avx.Add(dstVector, scalarVector256); @@ -574,10 +575,11 @@ public static unsafe void ScaleSrcU(float scale, ReadOnlySpan src, Span scaleVector256 = Vector256.Create(scale); - while (pDstCurrent <= pDstEnd - 8) + while (pDstCurrent <= pAvxVectorizationEnd) { Vector256 srcVector = Avx.LoadVector256(pSrcCurrent); srcVector = Avx.Multiply(srcVector, scaleVector256); @@ -619,11 +621,12 @@ public static unsafe void ScaleAddU(float a, float b, Span dst) float* pDstEnd = pdst + dst.Length; float* pDstCurrent = pdst; float* pVectorizationEnd = pDstEnd - 4; + float* pAvxVectorizationEnd = pDstEnd - 8; Vector256 a256 = Vector256.Create(a); Vector256 b256 = Vector256.Create(b); - while (pDstCurrent <= pDstEnd - 8) + while (pDstCurrent <= pAvxVectorizationEnd) { Vector256 dstVector = Avx.LoadVector256(pDstCurrent); dstVector = Avx.Add(dstVector, b256); @@ -668,10 +671,11 @@ public static unsafe void AddScaleU(float scale, ReadOnlySpan src, Span scaleVector256 = Vector256.Create(scale); - while (pDstCurrent <= pEnd - 8) + while (pDstCurrent <= pVectorizationEnd) { Vector256 dstVector = Avx.LoadVector256(pDstCurrent); @@ -722,13 +726,14 @@ public static unsafe void AddScaleCopyU(float scale, ReadOnlySpan src, Re fixed (float* pres = &MemoryMarshal.GetReference(result)) { float* pResEnd = pres + count; + float* pVectorizationEnd = pResEnd - 8; float* pSrcCurrent = psrc; float* pDstCurrent = pdst; float* pResCurrent = pres; Vector256 scaleVector256 = Vector256.Create(scale); - while (pResCurrent <= pResEnd - 8) + while (pResCurrent <= pVectorizationEnd) { Vector256 dstVector = Avx.LoadVector256(pDstCurrent); dstVector = MultiplyAdd(pSrcCurrent, scaleVector256, dstVector); @@ -782,10 +787,11 @@ public static unsafe void AddScaleSU(float scale, ReadOnlySpan src, ReadO int* pIdxCurrent = pidx; float* pDstCurrent = pdst; int* pEnd = pidx + count; + int* pVectorizationEnd = pEnd - 8; Vector256 scaleVector256 = Vector256.Create(scale); - while (pIdxCurrent <= pEnd - 8) + while (pIdxCurrent <= pVectorizationEnd) { Vector256 dstVector = Load8(pDstCurrent, pIdxCurrent); dstVector = MultiplyAdd(pSrcCurrent, scaleVector256, dstVector); @@ -830,8 +836,9 @@ public static unsafe void AddU(ReadOnlySpan src, Span dst, int cou float* pSrcCurrent = psrc; float* pDstCurrent = pdst; float* pEnd = psrc + count; + float* pVectorizationEnd = pEnd - 8; - while (pSrcCurrent <= pEnd - 8) + while (pSrcCurrent <= pVectorizationEnd) { Vector256 srcVector = Avx.LoadVector256(pSrcCurrent); Vector256 dstVector = Avx.LoadVector256(pDstCurrent); @@ -882,8 +889,9 @@ public static unsafe void AddSU(ReadOnlySpan src, ReadOnlySpan idx, int* pIdxCurrent = pidx; float* pDstCurrent = pdst; int* pEnd = pidx + count; + int* pVectorizationEnd = pEnd - 8; - while (pIdxCurrent <= pEnd - 8) + while (pIdxCurrent <= pVectorizationEnd) { Vector256 dstVector = Load8(pDstCurrent, pIdxCurrent); Vector256 srcVector = Avx.LoadVector256(pSrcCurrent); @@ -930,8 +938,9 @@ public static unsafe void MulElementWiseU(ReadOnlySpan src1, ReadOnlySpan float* pSrc2Current = psrc2; float* pDstCurrent = pdst; float* pEnd = pdst + count; + float* pVectorizationEnd = pEnd - 8; - while (pDstCurrent <= pEnd - 8) + while (pDstCurrent <= pVectorizationEnd) { Vector256 src1Vector = Avx.LoadVector256(pSrc1Current); Vector256 src2Vector = Avx.LoadVector256(pSrc2Current); @@ -1062,11 +1071,12 @@ public static unsafe float SumSqU(ReadOnlySpan src) fixed (float* psrc = &MemoryMarshal.GetReference(src)) { float* pSrcEnd = psrc + src.Length; + float* pVectorizationEnd = pSrcEnd - 8; float* pSrcCurrent = psrc; Vector256 result256 = Vector256.Zero; - while (pSrcCurrent <= pSrcEnd - 8) + while (pSrcCurrent <= pVectorizationEnd) { Vector256 srcVector = Avx.LoadVector256(pSrcCurrent); result256 = MultiplyAdd(srcVector, srcVector, result256); @@ -1106,12 +1116,13 @@ public static unsafe float SumSqDiffU(float mean, ReadOnlySpan src) fixed (float* psrc = &MemoryMarshal.GetReference(src)) { float* pSrcEnd = psrc + src.Length; + float* pVectorizationEnd = pSrcEnd - 8; float* pSrcCurrent = psrc; Vector256 result256 = Vector256.Zero; Vector256 meanVector256 = Vector256.Create(mean); - while (pSrcCurrent <= pSrcEnd - 8) + while (pSrcCurrent <= pVectorizationEnd) { Vector256 srcVector = Avx.LoadVector256(pSrcCurrent); srcVector = Avx.Subtract(srcVector, meanVector256); @@ -1154,11 +1165,12 @@ public static unsafe float SumAbsU(ReadOnlySpan src) fixed (float* psrc = &MemoryMarshal.GetReference(src)) { float* pSrcEnd = psrc + src.Length; + float* pVectorizationEnd = pSrcEnd - 8; float* pSrcCurrent = psrc; Vector256 result256 = Vector256.Zero; - while (pSrcCurrent <= pSrcEnd - 8) + while (pSrcCurrent <= pVectorizationEnd) { Vector256 srcVector = Avx.LoadVector256(pSrcCurrent); result256 = Avx.Add(result256, Avx.And(srcVector, _absMask256)); @@ -1198,12 +1210,13 @@ public static unsafe float SumAbsDiffU(float mean, ReadOnlySpan src) fixed (float* psrc = &MemoryMarshal.GetReference(src)) { float* pSrcEnd = psrc + src.Length; + float* pVectorizationEnd = pSrcEnd - 8; float* pSrcCurrent = psrc; Vector256 result256 = Vector256.Zero; Vector256 meanVector256 = Vector256.Create(mean); - while (pSrcCurrent <= pSrcEnd - 8) + while (pSrcCurrent <= pVectorizationEnd) { Vector256 srcVector = Avx.LoadVector256(pSrcCurrent); srcVector = Avx.Subtract(srcVector, meanVector256); @@ -1247,11 +1260,12 @@ public static unsafe float MaxAbsU(ReadOnlySpan src) fixed (float* psrc = &MemoryMarshal.GetReference(src)) { float* pSrcEnd = psrc + src.Length; + float* pVectorizationEnd = pSrcEnd - 8; float* pSrcCurrent = psrc; Vector256 result256 = Vector256.Zero; - while (pSrcCurrent <= pSrcEnd - 8) + while (pSrcCurrent <= pVectorizationEnd) { Vector256 srcVector = Avx.LoadVector256(pSrcCurrent); result256 = Avx.Max(result256, Avx.And(srcVector, _absMask256)); @@ -1291,12 +1305,13 @@ public static unsafe float MaxAbsDiffU(float mean, ReadOnlySpan src) fixed (float* psrc = &MemoryMarshal.GetReference(src)) { float* pSrcEnd = psrc + src.Length; + float* pVectorizationEnd = pSrcEnd - 8; float* pSrcCurrent = psrc; Vector256 result256 = Vector256.Zero; Vector256 meanVector256 = Vector256.Create(mean); - while (pSrcCurrent <= pSrcEnd - 8) + while (pSrcCurrent <= pVectorizationEnd) { Vector256 srcVector = Avx.LoadVector256(pSrcCurrent); srcVector = Avx.Subtract(srcVector, meanVector256); @@ -1345,10 +1360,11 @@ public static unsafe float DotU(ReadOnlySpan src, ReadOnlySpan dst float* pSrcCurrent = psrc; float* pDstCurrent = pdst; float* pSrcEnd = psrc + count; + float* pVectorizationEnd = pSrcEnd - 8; Vector256 result256 = Vector256.Zero; - while (pSrcCurrent <= pSrcEnd - 8) + while (pSrcCurrent <= pVectorizationEnd) { Vector256 dstVector = Avx.LoadVector256(pDstCurrent); result256 = MultiplyAdd(pSrcCurrent, dstVector, result256); @@ -1402,10 +1418,11 @@ public static unsafe float DotSU(ReadOnlySpan src, ReadOnlySpan ds float* pDstCurrent = pdst; int* pIdxCurrent = pidx; int* pIdxEnd = pidx + count; + int* pVectorizationEnd = pIdxEnd - 8; Vector256 result256 = Vector256.Zero; - while (pIdxCurrent <= pIdxEnd - 8) + while (pIdxCurrent <= pVectorizationEnd) { Vector256 srcVector = Load8(pSrcCurrent, pIdxCurrent); result256 = MultiplyAdd(pDstCurrent, srcVector, result256); @@ -1456,10 +1473,11 @@ public static unsafe float Dist2(ReadOnlySpan src, ReadOnlySpan ds float* pSrcCurrent = psrc; float* pDstCurrent = pdst; float* pSrcEnd = psrc + count; + float* pVectorizationEnd = pSrcEnd - 8; Vector256 sqDistanceVector256 = Vector256.Zero; - while (pSrcCurrent <= pSrcEnd - 8) + while (pSrcCurrent <= pVectorizationEnd) { Vector256 distanceVector = Avx.Subtract(Avx.LoadVector256(pSrcCurrent), Avx.LoadVector256(pDstCurrent)); @@ -1507,6 +1525,7 @@ public static unsafe void SdcaL1UpdateU(float primalUpdate, int count, ReadOnlyS fixed (float* pdst2 = &MemoryMarshal.GetReference(w)) { float* pSrcEnd = psrc + count; + float* pVectorizationEnd = pSrcEnd - 8; float* pSrcCurrent = psrc; float* pDst1Current = pdst1; float* pDst2Current = pdst2; @@ -1514,7 +1533,7 @@ public static unsafe void SdcaL1UpdateU(float primalUpdate, int count, ReadOnlyS Vector256 xPrimal256 = Vector256.Create(primalUpdate); Vector256 xThreshold256 = Vector256.Create(threshold); - while (pSrcCurrent <= pSrcEnd - 8) + while (pSrcCurrent <= pVectorizationEnd) { Vector256 xDst1 = Avx.LoadVector256(pDst1Current); xDst1 = MultiplyAdd(pSrcCurrent, xPrimal256, xDst1); @@ -1568,13 +1587,14 @@ public static unsafe void SdcaL1UpdateSU(float primalUpdate, int count, ReadOnly fixed (float* pdst2 = &MemoryMarshal.GetReference(w)) { int* pIdxEnd = pidx + count; + int* pVectorizationEnd = pIdxEnd - 8; float* pSrcCurrent = psrc; int* pIdxCurrent = pidx; Vector256 xPrimal256 = Vector256.Create(primalUpdate); Vector256 xThreshold = Vector256.Create(threshold); - while (pIdxCurrent <= pIdxEnd - 8) + while (pIdxCurrent <= pVectorizationEnd) { Vector256 xDst1 = Load8(pdst1, pIdxCurrent); xDst1 = MultiplyAdd(pSrcCurrent, xPrimal256, xDst1); diff --git a/src/Microsoft.ML.CpuMath/SseIntrinsics.cs b/src/Microsoft.ML.CpuMath/SseIntrinsics.cs index 69ee7d6e88..6060bd50d0 100644 --- a/src/Microsoft.ML.CpuMath/SseIntrinsics.cs +++ b/src/Microsoft.ML.CpuMath/SseIntrinsics.cs @@ -562,10 +562,11 @@ public static unsafe void AddScaleU(float scale, ReadOnlySpan src, Span scaleVector = Vector128.Create(scale); - while (pDstCurrent <= pEnd - 4) + while (pDstCurrent <= pVectorizationEnd) { Vector128 srcVector = Sse.LoadVector128(pSrcCurrent); Vector128 dstVector = Sse.LoadVector128(pDstCurrent); @@ -606,10 +607,11 @@ public static unsafe void AddScaleCopyU(float scale, ReadOnlySpan src, Re float* pSrcCurrent = psrc; float* pDstCurrent = pdst; float* pResCurrent = pres; + float* pVectorizationEnd = pResEnd - 4; Vector128 scaleVector = Vector128.Create(scale); - while (pResCurrent <= pResEnd - 4) + while (pResCurrent <= pVectorizationEnd) { Vector128 srcVector = Sse.LoadVector128(pSrcCurrent); Vector128 dstVector = Sse.LoadVector128(pDstCurrent); @@ -650,10 +652,11 @@ public static unsafe void AddScaleSU(float scale, ReadOnlySpan src, ReadO int* pIdxCurrent = pidx; float* pDstCurrent = pdst; int* pEnd = pidx + count; + int* pVectorizationEnd = pEnd - 4; Vector128 scaleVector = Vector128.Create(scale); - while (pIdxCurrent <= pEnd - 4) + while (pIdxCurrent <= pVectorizationEnd) { Vector128 srcVector = Sse.LoadVector128(pSrcCurrent); Vector128 dstVector = Load4(pDstCurrent, pIdxCurrent); @@ -686,8 +689,9 @@ public static unsafe void AddU(ReadOnlySpan src, Span dst, int cou float* pSrcCurrent = psrc; float* pDstCurrent = pdst; float* pEnd = psrc + count; + float* pVectorizationEnd = pEnd - 4; - while (pSrcCurrent <= pEnd - 4) + while (pSrcCurrent <= pVectorizationEnd) { Vector128 srcVector = Sse.LoadVector128(pSrcCurrent); Vector128 dstVector = Sse.LoadVector128(pDstCurrent); @@ -726,8 +730,9 @@ public static unsafe void AddSU(ReadOnlySpan src, ReadOnlySpan idx, int* pIdxCurrent = pidx; float* pDstCurrent = pdst; int* pEnd = pidx + count; + int* pVectorizationEnd = pEnd - 4; - while (pIdxCurrent <= pEnd - 4) + while (pIdxCurrent <= pVectorizationEnd) { Vector128 dstVector = Load4(pDstCurrent, pIdxCurrent); Vector128 srcVector = Sse.LoadVector128(pSrcCurrent); @@ -762,8 +767,9 @@ public static unsafe void MulElementWiseU(ReadOnlySpan src1, ReadOnlySpan float* pSrc2Current = psrc2; float* pDstCurrent = pdst; float* pEnd = pdst + count; + float* pVectorizationEnd = pEnd - 4; - while (pDstCurrent <= pEnd - 4) + while (pDstCurrent <= pVectorizationEnd) { Vector128 src1Vector = Sse.LoadVector128(pSrc1Current); Vector128 src2Vector = Sse.LoadVector128(pSrc2Current); @@ -879,11 +885,12 @@ public static unsafe float SumSqU(ReadOnlySpan src) fixed (float* psrc = &MemoryMarshal.GetReference(src)) { float* pSrcEnd = psrc + src.Length; + float* pVectorizationEnd = pSrcEnd - 4; float* pSrcCurrent = psrc; Vector128 result = Vector128.Zero; - while (pSrcCurrent <= pSrcEnd - 4) + while (pSrcCurrent <= pVectorizationEnd) { Vector128 srcVector = Sse.LoadVector128(pSrcCurrent); result = Sse.Add(result, Sse.Multiply(srcVector, srcVector)); @@ -910,12 +917,13 @@ public static unsafe float SumSqDiffU(float mean, ReadOnlySpan src) fixed (float* psrc = &MemoryMarshal.GetReference(src)) { float* pSrcEnd = psrc + src.Length; + float* pVectorizationEnd = pSrcEnd - 4; float* pSrcCurrent = psrc; Vector128 result = Vector128.Zero; Vector128 meanVector = Vector128.Create(mean); - while (pSrcCurrent <= pSrcEnd - 4) + while (pSrcCurrent <= pVectorizationEnd) { Vector128 srcVector = Sse.LoadVector128(pSrcCurrent); srcVector = Sse.Subtract(srcVector, meanVector); @@ -945,10 +953,11 @@ public static unsafe float SumAbsU(ReadOnlySpan src) { float* pSrcEnd = psrc + src.Length; float* pSrcCurrent = psrc; + float* pVectorizationEnd = pSrcEnd - 4; Vector128 result = Vector128.Zero; - while (pSrcCurrent <= pSrcEnd - 4) + while (pSrcCurrent <= pVectorizationEnd) { Vector128 srcVector = Sse.LoadVector128(pSrcCurrent); result = Sse.Add(result, Sse.And(srcVector, AbsMask128)); @@ -975,12 +984,13 @@ public static unsafe float SumAbsDiffU(float mean, ReadOnlySpan src) fixed (float* psrc = &MemoryMarshal.GetReference(src)) { float* pSrcEnd = psrc + src.Length; + float* pVectorizationEnd = pSrcEnd - 4; float* pSrcCurrent = psrc; Vector128 result = Vector128.Zero; Vector128 meanVector = Vector128.Create(mean); - while (pSrcCurrent <= pSrcEnd - 4) + while (pSrcCurrent <= pVectorizationEnd) { Vector128 srcVector = Sse.LoadVector128(pSrcCurrent); srcVector = Sse.Subtract(srcVector, meanVector); @@ -1009,11 +1019,12 @@ public static unsafe float MaxAbsU(ReadOnlySpan src) fixed (float* psrc = &MemoryMarshal.GetReference(src)) { float* pSrcEnd = psrc + src.Length; + float* pVectorizationEnd = pSrcEnd - 4; float* pSrcCurrent = psrc; Vector128 result = Vector128.Zero; - while (pSrcCurrent <= pSrcEnd - 4) + while (pSrcCurrent <= pVectorizationEnd) { Vector128 srcVector = Sse.LoadVector128(pSrcCurrent); result = Sse.Max(result, Sse.And(srcVector, AbsMask128)); @@ -1040,12 +1051,13 @@ public static unsafe float MaxAbsDiffU(float mean, ReadOnlySpan src) fixed (float* psrc = &MemoryMarshal.GetReference(src)) { float* pSrcEnd = psrc + src.Length; + float* pVectorizationEnd = pSrcEnd - 4; float* pSrcCurrent = psrc; Vector128 result = Vector128.Zero; Vector128 meanVector = Vector128.Create(mean); - while (pSrcCurrent <= pSrcEnd - 4) + while (pSrcCurrent <= pVectorizationEnd) { Vector128 srcVector = Sse.LoadVector128(pSrcCurrent); srcVector = Sse.Subtract(srcVector, meanVector); @@ -1079,10 +1091,11 @@ public static unsafe float DotU(ReadOnlySpan src, ReadOnlySpan dst float* pSrcCurrent = psrc; float* pDstCurrent = pdst; float* pSrcEnd = psrc + count; + float* pVectorizationEnd = pSrcEnd - 4; Vector128 result = Vector128.Zero; - while (pSrcCurrent <= pSrcEnd - 4) + while (pSrcCurrent <= pVectorizationEnd) { Vector128 srcVector = Sse.LoadVector128(pSrcCurrent); Vector128 dstVector = Sse.LoadVector128(pDstCurrent); @@ -1123,10 +1136,11 @@ public static unsafe float DotSU(ReadOnlySpan src, ReadOnlySpan ds float* pDstCurrent = pdst; int* pIdxCurrent = pidx; int* pIdxEnd = pidx + count; + int* pVectorizationEnd = pIdxEnd - 4; Vector128 result = Vector128.Zero; - while (pIdxCurrent <= pIdxEnd - 4) + while (pIdxCurrent <= pVectorizationEnd) { Vector128 srcVector = Load4(pSrcCurrent, pIdxCurrent); Vector128 dstVector = Sse.LoadVector128(pDstCurrent); @@ -1164,10 +1178,11 @@ public static unsafe float Dist2(ReadOnlySpan src, ReadOnlySpan ds float* pSrcCurrent = psrc; float* pDstCurrent = pdst; float* pSrcEnd = psrc + count; + float* pVectorizationEnd = pSrcEnd - 4; Vector128 sqDistanceVector = Vector128.Zero; - while (pSrcCurrent <= pSrcEnd - 4) + while (pSrcCurrent <= pVectorizationEnd) { Vector128 distanceVector = Sse.Subtract(Sse.LoadVector128(pSrcCurrent), Sse.LoadVector128(pDstCurrent)); @@ -1201,6 +1216,7 @@ public static unsafe void SdcaL1UpdateU(float primalUpdate, int count, ReadOnlyS fixed (float* pdst2 = &MemoryMarshal.GetReference(w)) { float* pSrcEnd = psrc + count; + float* pVectorizationEnd = pSrcEnd - 4; float* pSrcCurrent = psrc; float* pDst1Current = pdst1; float* pDst2Current = pdst2; @@ -1210,7 +1226,7 @@ public static unsafe void SdcaL1UpdateU(float primalUpdate, int count, ReadOnlyS Vector128 signMask = Vector128.Create(-0.0f); // 0x8000 0000 Vector128 xThreshold = Vector128.Create(threshold); - while (pSrcCurrent <= pSrcEnd - 4) + while (pSrcCurrent <= pVectorizationEnd) { Vector128 xSrc = Sse.LoadVector128(pSrcCurrent); @@ -1249,13 +1265,14 @@ public static unsafe void SdcaL1UpdateSU(float primalUpdate, int count, ReadOnly int* pIdxEnd = pidx + count; float* pSrcCurrent = psrc; int* pIdxCurrent = pidx; + int* pVectorizationEnd = pIdxEnd - 4; Vector128 xPrimal = Vector128.Create(primalUpdate); Vector128 signMask = Vector128.Create(-0.0f); // 0x8000 0000 Vector128 xThreshold = Vector128.Create(threshold); - while (pIdxCurrent <= pIdxEnd - 4) + while (pIdxCurrent <= pVectorizationEnd) { Vector128 xSrc = Sse.LoadVector128(pSrcCurrent);