Skip to content

Commit f0f7684

Browse files
author
Saurabh Singh
committed
Support for traversing BKD tree with prefetching
1 parent 8e8e37d commit f0f7684

File tree

7 files changed

+299
-4
lines changed

7 files changed

+299
-4
lines changed

lucene/CHANGES.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -984,6 +984,8 @@ API Changes
984984
* GITHUB#13820, GITHUB#13825, GITHUB#13830: Corrects DataInput.readGroupVInts to be public and not-final, removes the protected
985985
DataInput.readGroupVInt method. (Zhang Chao, Robert Muir, Uwe Schindler, Dawid Weiss)
986986

987+
* GITHUB#15376, GITHUB#15197: Added prefetching to BKD tree traversal, along with two new APIs on PointValues.PointTree: visitDocIDs(long, IntersectVisitor), which visits doc IDs starting from a given position, and prepareOrVisitDocIDs, which prefetches the IO before visiting doc IDs. (Saurabh Singh)
988+
987989
New Features
988990
---------------------
989991

lucene/core/src/java/org/apache/lucene/index/PointValues.java

Lines changed: 66 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,8 @@
2020
import java.io.UncheckedIOException;
2121
import java.math.BigInteger;
2222
import java.net.InetAddress;
23+
import java.util.ArrayList;
24+
import java.util.List;
2325
import org.apache.lucene.document.BinaryPoint;
2426
import org.apache.lucene.document.DoublePoint;
2527
import org.apache.lucene.document.Field;
@@ -274,6 +276,19 @@ public interface PointTree extends Cloneable {
274276

275277
/** Visit all the docs and values below the current node. */
276278
void visitDocValues(IntersectVisitor visitor) throws IOException;
279+
280+
/** Visit all the docs below the node at position pos */
281+
default void visitDocIDs(long pos, IntersectVisitor visitor) throws IOException {}
282+
;
283+
284+
/**
285+
* call prefetch for docs below the current node if vistor supports prefetching otherwise visit
286+
* docIds
287+
*/
288+
default void prepareOrVisitDocIDs(IntersectVisitor visitor) throws IOException {
289+
visitDocIDs(visitor);
290+
}
291+
;
277292
}
278293

279294
/**
@@ -341,13 +356,63 @@ default void visit(DocIdSetIterator iterator, byte[] packedValue) throws IOExcep
341356
default void grow(int count) {}
342357
}
343358

359+
/**
360+
* We can recurse the {@link PointTree} using {@TwoPhaseIntersectVisitor}. This visitor caches the
361+
* blocks during recursion, calling prefetch on required blocks. This should potentially trigger
362+
* IO for these blocks asynchronously in the first phase. In the second phase, the cached blocks
363+
* are visited one by one.
364+
*
365+
* @lucene.experimental
366+
*/
367+
public abstract static class TwoPhaseIntersectVisitor implements IntersectVisitor {
368+
369+
int lastDeferredBlockOrdinal = -1;
370+
List<Long> deferredBlocks = new ArrayList<>();
371+
372+
/**
373+
* return the last deferred block ordinal - this is used to avoid prefetching call for
374+
* contiguous ordinals assuming contiguous ordinals prefetching can be taken care by readaheads.
375+
*/
376+
public int lastDeferredBlockOrdinal() {
377+
return lastDeferredBlockOrdinal;
378+
}
379+
380+
/** set last deferred block ordinal * */
381+
public void setLastDeferredBlockOrdinal(int leafNodeOrdinal) {
382+
lastDeferredBlockOrdinal = leafNodeOrdinal;
383+
}
384+
385+
/** Defer this block for processing in the second phase. */
386+
public void deferBlock(long leafFp) {
387+
deferredBlocks.add(leafFp);
388+
}
389+
390+
/** Returns a snapshot of the currently deferred blocks. */
391+
public List<Long> deferredBlocks() {
392+
return new ArrayList<>(deferredBlocks);
393+
}
394+
395+
/** Mark the given block as processed and remove it from the deferred set. */
396+
public void onProcessingDeferredBlock(long leafFp) {
397+
deferredBlocks.remove(leafFp);
398+
}
399+
}
400+
344401
/**
345402
* Finds all documents and points matching the provided visitor. This method does not enforce live
346403
* documents, so it's up to the caller to test whether each document is deleted, if necessary.
347404
*/
348405
public final void intersect(IntersectVisitor visitor) throws IOException {
349406
final PointTree pointTree = getPointTree();
350407
intersect(visitor, pointTree);
408+
if (visitor instanceof TwoPhaseIntersectVisitor twoPhaseVisitor) {
409+
List<Long> fps = twoPhaseVisitor.deferredBlocks();
410+
for (int i = 0; i < fps.size(); ++i) {
411+
long fp = fps.get(i);
412+
pointTree.visitDocIDs(fp, visitor);
413+
twoPhaseVisitor.onProcessingDeferredBlock(fp);
414+
}
415+
}
351416
assert pointTree.moveToParent() == false;
352417
}
353418

@@ -358,7 +423,7 @@ private static void intersect(IntersectVisitor visitor, PointTree pointTree) thr
358423
if (compare == Relation.CELL_INSIDE_QUERY) {
359424
// This cell is fully inside the query shape: recursively add all points in this cell
360425
// without filtering
361-
pointTree.visitDocIDs(visitor);
426+
pointTree.prepareOrVisitDocIDs(visitor);
362427
} else if (compare == Relation.CELL_CROSSES_QUERY) {
363428
// The cell crosses the shape boundary, or the cell fully contains the query, so we fall
364429
// through and do full filtering:

lucene/core/src/java/org/apache/lucene/search/PointRangeQuery.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -147,7 +147,7 @@ private boolean matches(byte[] packedValue) {
147147
}
148148

149149
private IntersectVisitor getIntersectVisitor(DocIdSetBuilder result) {
150-
return new IntersectVisitor() {
150+
return new PointValues.TwoPhaseIntersectVisitor() {
151151

152152
DocIdSetBuilder.BulkAdder adder;
153153

@@ -194,7 +194,7 @@ public Relation compare(byte[] minPackedValue, byte[] maxPackedValue) {
194194

195195
/** Create a visitor that sets documents that do NOT match the range. */
196196
private IntersectVisitor getInverseIntersectVisitor(FixedBitSet result, long[] cost) {
197-
return new IntersectVisitor() {
197+
return new PointValues.TwoPhaseIntersectVisitor() {
198198

199199
@Override
200200
public void visit(int docID) {

lucene/core/src/java/org/apache/lucene/util/bkd/BKDReader.java

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -589,6 +589,69 @@ public void visitDocIDs(PointValues.IntersectVisitor visitor) throws IOException
589589
addAll(visitor, false);
590590
}
591591

592+
/**
 * Resets the node data position and then runs {@code prefetchAll} on the current node, with
 * {@code grown} set to {@code false}.
 */
@Override
593+
public void prepareOrVisitDocIDs(IntersectVisitor visitor) throws IOException {
594+
resetNodeDataPosition();
595+
prefetchAll(visitor, false);
596+
}
597+
598+
/**
 * Visits the doc IDs stored at {@code position}. Delegates to the three-argument overload with
 * {@code grown == false}, so {@code visitor.grow} is called before the doc IDs are read.
 */
@Override
599+
public void visitDocIDs(long position, IntersectVisitor visitor) throws IOException {
600+
visitDocIDs(position, visitor, false);
601+
}
602+
603+
private void visitDocIDs(long position, IntersectVisitor visitor, boolean grown)
604+
throws IOException {
605+
leafNodes.seek(position);
606+
int count = leafNodes.readVInt();
607+
if (!grown) {
608+
visitor.grow(count);
609+
}
610+
docIdsWriter.readInts(leafNodes, count, visitor, scratchIterator.docIDs);
611+
}
612+
613+
/**
 * Returns {@code nodeID - leafNodeOffset} for the current node. The current node must be a leaf;
 * this is only checked by an assertion.
 */
private int getLeafNodeOrdinal() {
614+
assert isLeafNode() : "nodeID=" + nodeID + " is not a leaf";
615+
return nodeID - leafNodeOffset;
616+
}
617+
618+
public void prefetchAll(IntersectVisitor visitor, boolean grown) throws IOException {
619+
if (grown == false) {
620+
final long size = size();
621+
if (size <= Integer.MAX_VALUE) {
622+
visitor.grow((int) size);
623+
grown = true;
624+
}
625+
}
626+
if (isLeafNode()) {
627+
// int count = isLastLeaf() ? config.maxPointsInLeafNode() : lastLeafNodePointCount;
628+
long leafFp = getLeafBlockFP();
629+
int leafNodeOrdinal = getLeafNodeOrdinal();
630+
if (visitor instanceof TwoPhaseIntersectVisitor twoPhaseIntersectVisitor) {
631+
// Only call prefetch is this is the first leaf node ordinal or the first match in
632+
// contigiuous sequence of matches for leaf nodes
633+
// boolean prefetched = false;
634+
if (twoPhaseIntersectVisitor.lastDeferredBlockOrdinal() == -1
635+
|| twoPhaseIntersectVisitor.lastDeferredBlockOrdinal() + 1 < leafNodeOrdinal) {
636+
// System.out.println("Prefetched called on " + leafNodeOrdinal);
637+
leafNodes.prefetch(leafFp, 1);
638+
// prefetched = true;
639+
}
640+
twoPhaseIntersectVisitor.setLastDeferredBlockOrdinal(leafNodeOrdinal);
641+
twoPhaseIntersectVisitor.deferBlock(leafFp);
642+
} else {
643+
visitDocIDs(getLeafBlockFP(), visitor, true);
644+
}
645+
} else {
646+
pushLeft();
647+
prefetchAll(visitor, grown);
648+
pop();
649+
pushRight();
650+
prefetchAll(visitor, grown);
651+
pop();
652+
}
653+
}
654+
592655
public void addAll(PointValues.IntersectVisitor visitor, boolean grown) throws IOException {
593656
if (grown == false) {
594657
final long size = size();

lucene/core/src/test/org/apache/lucene/codecs/lucene90/TestLucene90PointsFormat.java

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818

1919
import java.io.IOException;
2020
import java.util.Arrays;
21+
import java.util.List;
2122
import org.apache.lucene.codecs.Codec;
2223
import org.apache.lucene.codecs.FilterCodec;
2324
import org.apache.lucene.codecs.PointsFormat;
@@ -355,4 +356,67 @@ public Relation compare(byte[] minPackedValue, byte[] maxPackedValue) {
355356
r.close();
356357
dir.close();
357358
}
359+
360+
public void testBasicWithPrefetchVisitor() throws Exception {
361+
Directory dir = newDirectory();
362+
IndexWriterConfig iwc = newIndexWriterConfig();
363+
// Avoid mockRandomMP since it may cause non-optimal merges that make the
364+
// number of points per leaf hard to predict
365+
while (iwc.getMergePolicy() instanceof MockRandomMergePolicy) {
366+
iwc.setMergePolicy(newMergePolicy());
367+
}
368+
IndexWriter w = new IndexWriter(dir, iwc);
369+
byte[] pointValue = new byte[3];
370+
byte[] uniquePointValue = new byte[3];
371+
random().nextBytes(uniquePointValue);
372+
final int numDocs =
373+
TEST_NIGHTLY ? atLeast(10000) : atLeast(500); // at night, make sure we have several leaves
374+
final boolean multiValues = random().nextBoolean();
375+
int totalValues = 0;
376+
for (int i = 0; i < numDocs; ++i) {
377+
Document doc = new Document();
378+
if (i == numDocs / 2) {
379+
totalValues++;
380+
doc.add(new BinaryPoint("f", uniquePointValue));
381+
} else {
382+
final int numValues = (multiValues) ? TestUtil.nextInt(random(), 2, 100) : 1;
383+
for (int j = 0; j < numValues; j++) {
384+
do {
385+
random().nextBytes(pointValue);
386+
} while (Arrays.equals(pointValue, uniquePointValue));
387+
doc.add(new BinaryPoint("f", pointValue));
388+
totalValues++;
389+
}
390+
}
391+
w.addDocument(doc);
392+
}
393+
w.forceMerge(1);
394+
final IndexReader r = DirectoryReader.open(w);
395+
w.close();
396+
397+
final LeafReader lr = getOnlyLeafReader(r);
398+
PointValues points = lr.getPointValues("f");
399+
400+
PointValues.TwoPhaseIntersectVisitor allPointsVisitor =
401+
new PointValues.TwoPhaseIntersectVisitor() {
402+
@Override
403+
public void visit(int docID, byte[] packedValue) throws IOException {}
404+
405+
@Override
406+
public void visit(int docID) throws IOException {}
407+
408+
@Override
409+
public Relation compare(byte[] minPackedValue, byte[] maxPackedValue) {
410+
return Relation.CELL_INSIDE_QUERY;
411+
}
412+
};
413+
414+
List<Long> savedBlocks = allPointsVisitor.deferredBlocks();
415+
assertEquals(0, savedBlocks.size()); // Test that all deferred blocks were processed
416+
assertEquals(totalValues, points.estimatePointCount(allPointsVisitor));
417+
assertEquals(numDocs, points.estimateDocCount(allPointsVisitor));
418+
419+
r.close();
420+
dir.close();
421+
}
358422
}

lucene/test-framework/src/java/org/apache/lucene/tests/index/AssertingLeafReader.java

Lines changed: 35 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,10 @@
1717
package org.apache.lucene.tests.index;
1818

1919
import java.io.IOException;
20+
import java.util.ArrayList;
2021
import java.util.Arrays;
2122
import java.util.Iterator;
23+
import java.util.List;
2224
import java.util.Objects;
2325
import org.apache.lucene.index.BinaryDocValues;
2426
import org.apache.lucene.index.DocValues;
@@ -1597,13 +1599,23 @@ public void visitDocValues(IntersectVisitor visitor) throws IOException {
15971599
pointValues.getBytesPerDimension(),
15981600
visitor));
15991601
}
1602+
1603+
/** Forwards the call unchanged to the wrapped tree {@code in}. */
@Override
1604+
public void visitDocIDs(long pos, IntersectVisitor visitor) throws IOException {
1605+
in.visitDocIDs(pos, visitor);
1606+
}
1607+
1608+
/** Forwards the call unchanged to the wrapped tree {@code in}. */
@Override
1609+
public void prepareOrVisitDocIDs(IntersectVisitor visitor) throws IOException {
1610+
in.prepareOrVisitDocIDs(visitor);
1611+
}
16001612
}
16011613

16021614
/**
16031615
* Validates in the 1D case that all points are visited in order, and point values are in bounds
16041616
* of the last cell checked
16051617
*/
1606-
static class AssertingIntersectVisitor implements IntersectVisitor {
1618+
static class AssertingIntersectVisitor extends PointValues.TwoPhaseIntersectVisitor {
16071619
final IntersectVisitor in;
16081620
final int numDataDims;
16091621
final int numIndexDims;
@@ -1614,6 +1626,8 @@ static class AssertingIntersectVisitor implements IntersectVisitor {
16141626
private Relation lastCompareResult;
16151627
private int lastDocID = -1;
16161628
private int docBudget;
1629+
int lastMatchedBlock;
1630+
private List<Long> prefetchedBlocks;
16171631

16181632
AssertingIntersectVisitor(
16191633
int numDataDims, int numIndexDims, int bytesPerDim, IntersectVisitor in) {
@@ -1716,6 +1730,26 @@ public Relation compare(byte[] minPackedValue, byte[] maxPackedValue) {
17161730
lastCompareResult = in.compare(minPackedValue, maxPackedValue);
17171731
return lastCompareResult;
17181732
}
1733+
1734+
@Override
1735+
public int lastDeferredBlockOrdinal() {
1736+
return lastMatchedBlock;
1737+
}
1738+
1739+
@Override
1740+
public void setLastDeferredBlockOrdinal(int leafNodeOrdinal) {
1741+
lastMatchedBlock = leafNodeOrdinal;
1742+
}
1743+
1744+
@Override
1745+
public void deferBlock(long leafFp) {
1746+
prefetchedBlocks.add(leafFp);
1747+
}
1748+
1749+
@Override
1750+
public List<Long> deferredBlocks() {
1751+
return new ArrayList<>(prefetchedBlocks);
1752+
}
17191753
}
17201754

17211755
@Override

0 commit comments

Comments
 (0)