Skip to content

Commit 4841b22

Browse files
committed
Merge branch 'main' into users/alexpeck/neon
2 parents 94b8968 + d8ac6f4 commit 4841b22

File tree

3 files changed

+21
-47
lines changed

3 files changed

+21
-47
lines changed

BitFaster.Caching.Benchmarks/Lfu/SketchFrequency.cs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,9 @@
77

88
namespace BitFaster.Caching.Benchmarks.Lfu
99
{
10+
#if Windows
11+
[DisassemblyDiagnoser(printSource: true, maxDepth: 4)]
12+
#endif
1013
[SimpleJob(RuntimeMoniker.Net60)]
1114
[SimpleJob(RuntimeMoniker.Net80)]
1215
[SimpleJob(RuntimeMoniker.Net90)]

BitFaster.Caching.Benchmarks/Lfu/SketchIncrement.cs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,9 @@
77

88
namespace BitFaster.Caching.Benchmarks.Lfu
99
{
10+
#if Windows
11+
[DisassemblyDiagnoser(printSource: true, maxDepth: 4)]
12+
#endif
1013
[SimpleJob(RuntimeMoniker.Net60)]
1114
[SimpleJob(RuntimeMoniker.Net80)]
1215
[SimpleJob(RuntimeMoniker.Net90)]

BitFaster.Caching/Lfu/CmSketchCore.cs

Lines changed: 15 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -272,39 +272,26 @@ private void Reset()
272272
}
273273

274274
#if !NETSTANDARD2_0
275+
[MethodImpl(MethodImplOptions.AggressiveInlining)]
275276
private unsafe int EstimateFrequencyAvx(T value)
276277
{
277278
int blockHash = Spread(comparer.GetHashCode(value));
278279
int counterHash = Rehash(blockHash);
279280
int block = (blockHash & blockMask) << 3;
280281

281-
Vector128<int> h = Vector128.Create(counterHash);
282-
h = Avx2.ShiftRightLogicalVariable(h.AsUInt32(), Vector128.Create(0U, 8U, 16U, 24U)).AsInt32();
282+
Vector128<int> h = Avx2.ShiftRightLogicalVariable(Vector128.Create(counterHash).AsUInt32(), Vector128.Create(0U, 8U, 16U, 24U)).AsInt32();
283+
Vector128<int> index = Avx2.ShiftLeftLogical(Avx2.And(Avx2.ShiftRightLogical(h, 1), Vector128.Create(15)), 2);
284+
Vector128<int> blockOffset = Avx2.Add(Avx2.Add(Vector128.Create(block), Avx2.And(h, Vector128.Create(1))), Vector128.Create(0, 2, 4, 6));
283285

284-
var index = Avx2.ShiftRightLogical(h, 1);
285-
index = Avx2.And(index, Vector128.Create(15)); // j - counter index
286-
Vector128<int> offset = Avx2.And(h, Vector128.Create(1));
287-
Vector128<int> blockOffset = Avx2.Add(Vector128.Create(block), offset); // i - table index
288-
blockOffset = Avx2.Add(blockOffset, Vector128.Create(0, 2, 4, 6)); // + (i << 1)
286+
Vector256<ulong> indexLong = Avx2.PermuteVar8x32(Vector256.Create(index, Vector128<int>.Zero), Vector256.Create(0, 4, 1, 5, 2, 5, 3, 7)).AsUInt64();
289287

290288
#if NET6_0_OR_GREATER
291289
long* tablePtr = tableAddr;
292290
#else
293291
fixed (long* tablePtr = table)
294292
#endif
295293
{
296-
Vector256<long> tableVector = Avx2.GatherVector256(tablePtr, blockOffset, 8);
297-
index = Avx2.ShiftLeftLogical(index, 2);
298-
299-
// convert index from int to long via permute
300-
Vector256<long> indexLong = Vector256.Create(index, Vector128<int>.Zero).AsInt64();
301-
Vector256<int> permuteMask2 = Vector256.Create(0, 4, 1, 5, 2, 5, 3, 7);
302-
indexLong = Avx2.PermuteVar8x32(indexLong.AsInt32(), permuteMask2).AsInt64();
303-
tableVector = Avx2.ShiftRightLogicalVariable(tableVector, indexLong.AsUInt64());
304-
tableVector = Avx2.And(tableVector, Vector256.Create(0xfL));
305-
306-
Vector256<int> permuteMask = Vector256.Create(0, 2, 4, 6, 1, 3, 5, 7);
307-
Vector128<ushort> count = Avx2.PermuteVar8x32(tableVector.AsInt32(), permuteMask)
294+
Vector128<ushort> count = Avx2.PermuteVar8x32(Avx2.And(Avx2.ShiftRightLogicalVariable(Avx2.GatherVector256(tablePtr, blockOffset, 8), indexLong), Vector256.Create(0xfL)).AsInt32(), Vector256.Create(0, 2, 4, 6, 1, 3, 5, 7))
308295
.GetLower()
309296
.AsUInt16();
310297

@@ -319,52 +306,33 @@ private unsafe int EstimateFrequencyAvx(T value)
319306
}
320307
}
321308

309+
[MethodImpl(MethodImplOptions.AggressiveInlining)]
322310
private unsafe void IncrementAvx(T value)
323311
{
324312
int blockHash = Spread(comparer.GetHashCode(value));
325313
int counterHash = Rehash(blockHash);
326314
int block = (blockHash & blockMask) << 3;
327315

328-
Vector128<int> h = Vector128.Create(counterHash);
329-
h = Avx2.ShiftRightLogicalVariable(h.AsUInt32(), Vector128.Create(0U, 8U, 16U, 24U)).AsInt32();
316+
Vector128<int> h = Avx2.ShiftRightLogicalVariable(Vector128.Create(counterHash).AsUInt32(), Vector128.Create(0U, 8U, 16U, 24U)).AsInt32();
317+
Vector128<int> index = Avx2.ShiftLeftLogical(Avx2.And(Avx2.ShiftRightLogical(h, 1), Vector128.Create(15)), 2);
318+
Vector128<int> blockOffset = Avx2.Add(Avx2.Add(Vector128.Create(block), Avx2.And(h, Vector128.Create(1))), Vector128.Create(0, 2, 4, 6));
330319

331-
Vector128<int> index = Avx2.ShiftRightLogical(h, 1);
332-
index = Avx2.And(index, Vector128.Create(15)); // j - counter index
333-
Vector128<int> offset = Avx2.And(h, Vector128.Create(1));
334-
Vector128<int> blockOffset = Avx2.Add(Vector128.Create(block), offset); // i - table index
335-
blockOffset = Avx2.Add(blockOffset, Vector128.Create(0, 2, 4, 6)); // + (i << 1)
320+
Vector256<ulong> offsetLong = Avx2.PermuteVar8x32(Vector256.Create(index, Vector128<int>.Zero), Vector256.Create(0, 4, 1, 5, 2, 5, 3, 7)).AsUInt64();
321+
Vector256<long> mask = Avx2.ShiftLeftLogicalVariable(Vector256.Create(0xfL), offsetLong);
336322

337323
#if NET6_0_OR_GREATER
338324
long* tablePtr = tableAddr;
339325
#else
340326
fixed (long* tablePtr = table)
341327
#endif
342328
{
343-
Vector256<long> tableVector = Avx2.GatherVector256(tablePtr, blockOffset, 8);
344-
345-
// j == index
346-
index = Avx2.ShiftLeftLogical(index, 2);
347-
Vector256<long> offsetLong = Vector256.Create(index, Vector128<int>.Zero).AsInt64();
348-
349-
Vector256<int> permuteMask = Vector256.Create(0, 4, 1, 5, 2, 5, 3, 7);
350-
offsetLong = Avx2.PermuteVar8x32(offsetLong.AsInt32(), permuteMask).AsInt64();
351-
352-
// mask = (0xfL << offset)
353-
Vector256<long> fifteen = Vector256.Create(0xfL);
354-
Vector256<long> mask = Avx2.ShiftLeftLogicalVariable(fifteen, offsetLong.AsUInt64());
355-
356-
// (table[i] & mask) != mask)
357329
// Note: 'masked' is all-ones in lanes where (table[i] & mask) == mask, i.e. the 4-bit counter is already 0xf - therefore use AndNot below
358-
Vector256<long> masked = Avx2.CompareEqual(Avx2.And(tableVector, mask), mask);
359-
360-
// 1L << offset
361-
Vector256<long> inc = Avx2.ShiftLeftLogicalVariable(Vector256.Create(1L), offsetLong.AsUInt64());
330+
Vector256<long> masked = Avx2.CompareEqual(Avx2.And(Avx2.GatherVector256(tablePtr, blockOffset, 8), mask), mask);
362331

363332
// Zero the increment in lanes where 'masked' is set (so zero is added below). AndNot computes (~first) & second, so operand order matters.
364-
inc = Avx2.AndNot(masked, inc);
333+
Vector256<long> inc = Avx2.AndNot(masked, Avx2.ShiftLeftLogicalVariable(Vector256.Create(1L), offsetLong));
365334

366-
Vector256<byte> result = Avx2.CompareEqual(masked.AsByte(), Vector256<byte>.Zero);
367-
bool wasInc = Avx2.MoveMask(result.AsByte()) == unchecked((int)(0b1111_1111_1111_1111_1111_1111_1111_1111));
335+
bool wasInc = Avx2.MoveMask(Avx2.CompareEqual(masked.AsByte(), Vector256<byte>.Zero).AsByte()) == unchecked((int)(0b1111_1111_1111_1111_1111_1111_1111_1111));
368336

369337
tablePtr[blockOffset.GetElement(0)] += inc.GetElement(0);
370338
tablePtr[blockOffset.GetElement(1)] += inc.GetElement(1);

0 commit comments

Comments
 (0)