Skip to content

Commit d448444

Browse files
author
Saurabh Singh
committed
Support for traversing BKD tree with prefetching
1 parent 8e8e37d commit d448444

File tree

6 files changed

+328
-5
lines changed

6 files changed

+328
-5
lines changed

lucene/CHANGES.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -984,6 +984,8 @@ API Changes
984984
* GITHUB#13820, GITHUB#13825, GITHUB#13830: Corrects DataInput.readGroupVInts to be public and not-final, removes the protected
985985
DataInput.readGroupVInt method. (Zhang Chao, Robert Muir, Uwe Schindler, Dawid Weiss)
986986

987+
* GITHUB#15376, GITHUB#15197: Added prefetching in bkd tree traversal, couple of new api in PointValues visitDocIDs from a position and prepareOrVisitDocIDs to prefetch the IO before visiting docIds (Saurabh Singh)
988+
987989
New Features
988990
---------------------
989991

lucene/core/src/java/org/apache/lucene/index/PointValues.java

Lines changed: 94 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,8 @@
2020
import java.io.UncheckedIOException;
2121
import java.math.BigInteger;
2222
import java.net.InetAddress;
23+
import java.util.ArrayList;
24+
import java.util.List;
2325
import org.apache.lucene.document.BinaryPoint;
2426
import org.apache.lucene.document.DoublePoint;
2527
import org.apache.lucene.document.Field;
@@ -274,6 +276,19 @@ public interface PointTree extends Cloneable {
274276

275277
/** Visit all the docs and values below the current node. */
276278
void visitDocValues(IntersectVisitor visitor) throws IOException;
279+
280+
/** Visit all the docs below the node at position pos */
281+
default void visitDocIDs(long pos, IntersectVisitor visitor) throws IOException {}
282+
;
283+
284+
/**
285+
* call prefetch for docs below the current node if vistor supports prefetching otherwise visit
286+
* docIds
287+
*/
288+
default void prepareOrVisitDocIDs(IntersectVisitor visitor) throws IOException {
289+
visitDocIDs(visitor);
290+
}
291+
;
277292
}
278293

279294
/**
@@ -341,13 +356,91 @@ default void visit(DocIdSetIterator iterator, byte[] packedValue) throws IOExcep
341356
default void grow(int count) {}
342357
}
343358

359+
/*
360+
* We can recurse the {@link PointTree} using {@link TwoPhaseIntersectVisitor}. This visitor
361+
* travere {@link PointTree} in two phases. In the first phase, it recurses over the {@link PointTree} optionally
362+
* triggering IO for some of the blocks and caching them. In the second phase, once the recursion is over it visits
363+
* the cached blocks one by one.
364+
*
365+
* @lucene.experimental
366+
*/
367+
public interface TwoPhaseIntersectVisitor extends IntersectVisitor {
368+
/** return the last deferred block ordinal during recursion. */
369+
public int lastDeferredBlockOrdinal();
370+
371+
/** set last deferred block ordinal */
372+
public void setLastDeferredBlockOrdinal(int leafNodeOrdinal);
373+
374+
/** Defer this block for processing in the second phase. */
375+
public void deferBlock(long leafFp);
376+
377+
/** Returns a snapshot of the currently deferred blocks. */
378+
public List<Long> deferredBlocks();
379+
380+
/** Mark the given block as processed and remove it from the deferred set. */
381+
public void onProcessingDeferredBlock(long leafFp);
382+
}
383+
384+
/**
385+
* Base implementation of {@link TwoPhaseIntersectVisitor} that maintains a list of deferred
386+
* blocks from first phase of traversal and visits them in the second phase.
387+
*
388+
* @lucene.experimental
389+
*/
390+
public abstract static class BaseTwoPhaseIntersectVisitor implements TwoPhaseIntersectVisitor {
391+
392+
int lastDeferredBlockOrdinal = -1;
393+
List<Long> deferredBlocks = new ArrayList<>();
394+
395+
/**
396+
* return the last deferred block ordinal - this is used to avoid prefetching call for
397+
* contiguous ordinals assuming contiguous ordinals prefetching can be taken care by readaheads.
398+
*/
399+
@Override
400+
public int lastDeferredBlockOrdinal() {
401+
return lastDeferredBlockOrdinal;
402+
}
403+
404+
/** set last deferred block ordinal * */
405+
@Override
406+
public void setLastDeferredBlockOrdinal(int leafNodeOrdinal) {
407+
lastDeferredBlockOrdinal = leafNodeOrdinal;
408+
}
409+
410+
/** Defer this block for processing in the second phase. */
411+
@Override
412+
public void deferBlock(long leafFp) {
413+
deferredBlocks.add(leafFp);
414+
}
415+
416+
/** Returns a snapshot of the currently deferred blocks. */
417+
@Override
418+
public List<Long> deferredBlocks() {
419+
return new ArrayList<>(deferredBlocks);
420+
}
421+
422+
/** Mark the given block as processed and remove it from the deferred set. */
423+
@Override
424+
public void onProcessingDeferredBlock(long leafFp) {
425+
deferredBlocks.remove(leafFp);
426+
}
427+
}
428+
344429
/**
345430
* Finds all documents and points matching the provided visitor. This method does not enforce live
346431
* documents, so it's up to the caller to test whether each document is deleted, if necessary.
347432
*/
348433
public final void intersect(IntersectVisitor visitor) throws IOException {
349434
final PointTree pointTree = getPointTree();
350435
intersect(visitor, pointTree);
436+
if (visitor instanceof TwoPhaseIntersectVisitor twoPhaseVisitor) {
437+
List<Long> fps = twoPhaseVisitor.deferredBlocks();
438+
for (int i = 0; i < fps.size(); ++i) {
439+
long fp = fps.get(i);
440+
pointTree.visitDocIDs(fp, visitor);
441+
twoPhaseVisitor.onProcessingDeferredBlock(fp);
442+
}
443+
}
351444
assert pointTree.moveToParent() == false;
352445
}
353446

@@ -358,7 +451,7 @@ private static void intersect(IntersectVisitor visitor, PointTree pointTree) thr
358451
if (compare == Relation.CELL_INSIDE_QUERY) {
359452
// This cell is fully inside the query shape: recursively add all points in this cell
360453
// without filtering
361-
pointTree.visitDocIDs(visitor);
454+
pointTree.prepareOrVisitDocIDs(visitor);
362455
} else if (compare == Relation.CELL_CROSSES_QUERY) {
363456
// The cell crosses the shape boundary, or the cell fully contains the query, so we fall
364457
// through and do full filtering:

lucene/core/src/java/org/apache/lucene/util/bkd/BKDReader.java

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -589,6 +589,69 @@ public void visitDocIDs(PointValues.IntersectVisitor visitor) throws IOException
589589
addAll(visitor, false);
590590
}
591591

592+
@Override
593+
public void prepareOrVisitDocIDs(IntersectVisitor visitor) throws IOException {
594+
resetNodeDataPosition();
595+
prefetchAll(visitor, false);
596+
}
597+
598+
@Override
599+
public void visitDocIDs(long position, IntersectVisitor visitor) throws IOException {
600+
visitDocIDs(position, visitor, false);
601+
}
602+
603+
private void visitDocIDs(long position, IntersectVisitor visitor, boolean grown)
604+
throws IOException {
605+
leafNodes.seek(position);
606+
int count = leafNodes.readVInt();
607+
if (!grown) {
608+
visitor.grow(count);
609+
}
610+
docIdsWriter.readInts(leafNodes, count, visitor, scratchIterator.docIDs);
611+
}
612+
613+
private int getLeafNodeOrdinal() {
614+
assert isLeafNode() : "nodeID=" + nodeID + " is not a leaf";
615+
return nodeID - leafNodeOffset;
616+
}
617+
618+
public void prefetchAll(IntersectVisitor visitor, boolean grown) throws IOException {
619+
if (grown == false) {
620+
final long size = size();
621+
if (size <= Integer.MAX_VALUE) {
622+
visitor.grow((int) size);
623+
grown = true;
624+
}
625+
}
626+
if (isLeafNode()) {
627+
// int count = isLastLeaf() ? config.maxPointsInLeafNode() : lastLeafNodePointCount;
628+
long leafFp = getLeafBlockFP();
629+
int leafNodeOrdinal = getLeafNodeOrdinal();
630+
if (visitor instanceof BaseTwoPhaseIntersectVisitor twoPhaseIntersectVisitor) {
631+
// Only call prefetch is this is the first leaf node ordinal or the first match in
632+
// contigiuous sequence of matches for leaf nodes
633+
// boolean prefetched = false;
634+
if (twoPhaseIntersectVisitor.lastDeferredBlockOrdinal() == -1
635+
|| twoPhaseIntersectVisitor.lastDeferredBlockOrdinal() + 1 < leafNodeOrdinal) {
636+
// System.out.println("Prefetched called on " + leafNodeOrdinal);
637+
leafNodes.prefetch(leafFp, 1);
638+
// prefetched = true;
639+
}
640+
twoPhaseIntersectVisitor.setLastDeferredBlockOrdinal(leafNodeOrdinal);
641+
twoPhaseIntersectVisitor.deferBlock(leafFp);
642+
} else {
643+
visitDocIDs(getLeafBlockFP(), visitor, true);
644+
}
645+
} else {
646+
pushLeft();
647+
prefetchAll(visitor, grown);
648+
pop();
649+
pushRight();
650+
prefetchAll(visitor, grown);
651+
pop();
652+
}
653+
}
654+
592655
public void addAll(PointValues.IntersectVisitor visitor, boolean grown) throws IOException {
593656
if (grown == false) {
594657
final long size = size();

lucene/core/src/test/org/apache/lucene/codecs/lucene90/TestLucene90PointsFormat.java

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818

1919
import java.io.IOException;
2020
import java.util.Arrays;
21+
import java.util.List;
2122
import org.apache.lucene.codecs.Codec;
2223
import org.apache.lucene.codecs.FilterCodec;
2324
import org.apache.lucene.codecs.PointsFormat;
@@ -355,4 +356,67 @@ public Relation compare(byte[] minPackedValue, byte[] maxPackedValue) {
355356
r.close();
356357
dir.close();
357358
}
359+
360+
public void testBasicWithPrefetchVisitor() throws Exception {
361+
Directory dir = newDirectory();
362+
IndexWriterConfig iwc = newIndexWriterConfig();
363+
// Avoid mockRandomMP since it may cause non-optimal merges that make the
364+
// number of points per leaf hard to predict
365+
while (iwc.getMergePolicy() instanceof MockRandomMergePolicy) {
366+
iwc.setMergePolicy(newMergePolicy());
367+
}
368+
IndexWriter w = new IndexWriter(dir, iwc);
369+
byte[] pointValue = new byte[3];
370+
byte[] uniquePointValue = new byte[3];
371+
random().nextBytes(uniquePointValue);
372+
final int numDocs =
373+
TEST_NIGHTLY ? atLeast(10000) : atLeast(500); // at night, make sure we have several leaves
374+
final boolean multiValues = random().nextBoolean();
375+
int totalValues = 0;
376+
for (int i = 0; i < numDocs; ++i) {
377+
Document doc = new Document();
378+
if (i == numDocs / 2) {
379+
totalValues++;
380+
doc.add(new BinaryPoint("f", uniquePointValue));
381+
} else {
382+
final int numValues = (multiValues) ? TestUtil.nextInt(random(), 2, 100) : 1;
383+
for (int j = 0; j < numValues; j++) {
384+
do {
385+
random().nextBytes(pointValue);
386+
} while (Arrays.equals(pointValue, uniquePointValue));
387+
doc.add(new BinaryPoint("f", pointValue));
388+
totalValues++;
389+
}
390+
}
391+
w.addDocument(doc);
392+
}
393+
w.forceMerge(1);
394+
final IndexReader r = DirectoryReader.open(w);
395+
w.close();
396+
397+
final LeafReader lr = getOnlyLeafReader(r);
398+
PointValues points = lr.getPointValues("f");
399+
400+
PointValues.BaseTwoPhaseIntersectVisitor allPointsVisitor =
401+
new PointValues.BaseTwoPhaseIntersectVisitor() {
402+
@Override
403+
public void visit(int docID, byte[] packedValue) throws IOException {}
404+
405+
@Override
406+
public void visit(int docID) throws IOException {}
407+
408+
@Override
409+
public Relation compare(byte[] minPackedValue, byte[] maxPackedValue) {
410+
return Relation.CELL_INSIDE_QUERY;
411+
}
412+
};
413+
414+
List<Long> savedBlocks = allPointsVisitor.deferredBlocks();
415+
assertEquals(0, savedBlocks.size()); // Test that all deferred blocks were processed
416+
assertEquals(totalValues, points.estimatePointCount(allPointsVisitor));
417+
assertEquals(numDocs, points.estimateDocCount(allPointsVisitor));
418+
419+
r.close();
420+
dir.close();
421+
}
358422
}

lucene/test-framework/src/java/org/apache/lucene/tests/index/AssertingLeafReader.java

Lines changed: 38 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,10 @@
1717
package org.apache.lucene.tests.index;
1818

1919
import java.io.IOException;
20+
import java.util.ArrayList;
2021
import java.util.Arrays;
2122
import java.util.Iterator;
23+
import java.util.List;
2224
import java.util.Objects;
2325
import org.apache.lucene.index.BinaryDocValues;
2426
import org.apache.lucene.index.DocValues;
@@ -1581,7 +1583,7 @@ public long size() {
15811583
@Override
15821584
public void visitDocIDs(IntersectVisitor visitor) throws IOException {
15831585
in.visitDocIDs(
1584-
new AssertingIntersectVisitor(
1586+
new AssertingIntersectVisitorBase(
15851587
pointValues.getNumDimensions(),
15861588
pointValues.getNumIndexDimensions(),
15871589
pointValues.getBytesPerDimension(),
@@ -1591,19 +1593,29 @@ public void visitDocIDs(IntersectVisitor visitor) throws IOException {
15911593
@Override
15921594
public void visitDocValues(IntersectVisitor visitor) throws IOException {
15931595
in.visitDocValues(
1594-
new AssertingIntersectVisitor(
1596+
new AssertingIntersectVisitorBase(
15951597
pointValues.getNumDimensions(),
15961598
pointValues.getNumIndexDimensions(),
15971599
pointValues.getBytesPerDimension(),
15981600
visitor));
15991601
}
1602+
1603+
@Override
1604+
public void visitDocIDs(long pos, IntersectVisitor visitor) throws IOException {
1605+
in.visitDocIDs(pos, visitor);
1606+
}
1607+
1608+
@Override
1609+
public void prepareOrVisitDocIDs(IntersectVisitor visitor) throws IOException {
1610+
in.prepareOrVisitDocIDs(visitor);
1611+
}
16001612
}
16011613

16021614
/**
16031615
* Validates in the 1D case that all points are visited in order, and point values are in bounds
16041616
* of the last cell checked
16051617
*/
1606-
static class AssertingIntersectVisitor implements IntersectVisitor {
1618+
static class AssertingIntersectVisitorBase extends PointValues.BaseTwoPhaseIntersectVisitor {
16071619
final IntersectVisitor in;
16081620
final int numDataDims;
16091621
final int numIndexDims;
@@ -1614,8 +1626,10 @@ static class AssertingIntersectVisitor implements IntersectVisitor {
16141626
private Relation lastCompareResult;
16151627
private int lastDocID = -1;
16161628
private int docBudget;
1629+
int lastMatchedBlock;
1630+
private List<Long> prefetchedBlocks;
16171631

1618-
AssertingIntersectVisitor(
1632+
AssertingIntersectVisitorBase(
16191633
int numDataDims, int numIndexDims, int bytesPerDim, IntersectVisitor in) {
16201634
this.in = in;
16211635
this.numDataDims = numDataDims;
@@ -1716,6 +1730,26 @@ public Relation compare(byte[] minPackedValue, byte[] maxPackedValue) {
17161730
lastCompareResult = in.compare(minPackedValue, maxPackedValue);
17171731
return lastCompareResult;
17181732
}
1733+
1734+
@Override
1735+
public int lastDeferredBlockOrdinal() {
1736+
return lastMatchedBlock;
1737+
}
1738+
1739+
@Override
1740+
public void setLastDeferredBlockOrdinal(int leafNodeOrdinal) {
1741+
lastMatchedBlock = leafNodeOrdinal;
1742+
}
1743+
1744+
@Override
1745+
public void deferBlock(long leafFp) {
1746+
prefetchedBlocks.add(leafFp);
1747+
}
1748+
1749+
@Override
1750+
public List<Long> deferredBlocks() {
1751+
return new ArrayList<>(prefetchedBlocks);
1752+
}
17191753
}
17201754

17211755
@Override

0 commit comments

Comments
 (0)