Skip to content

Commit b076a03

Browse files
author
Saurabh Singh
committed
Support for traversing BKD tree with prefetching
1 parent 8e8e37d commit b076a03

File tree

6 files changed

+562
-6
lines changed

6 files changed

+562
-6
lines changed

lucene/CHANGES.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -984,6 +984,8 @@ API Changes
984984
* GITHUB#13820, GITHUB#13825, GITHUB#13830: Corrects DataInput.readGroupVInts to be public and not-final, removes the protected
985985
DataInput.readGroupVInt method. (Zhang Chao, Robert Muir, Uwe Schindler, Dawid Weiss)
986986

987+
* GITHUB#15376, GITHUB#15197: Added prefetching in bkd tree traversal, couple of new api in PointValues visitDocIDs from a position and prepareOrVisitDocIDs to prefetch the IO before visiting docIds (Saurabh Singh)
988+
987989
New Features
988990
---------------------
989991

lucene/core/src/java/org/apache/lucene/index/CheckIndex.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3194,7 +3194,7 @@ private static void checkByteVectorValues(
31943194
*
31953195
* @lucene.internal
31963196
*/
3197-
public static class VerifyPointsVisitor implements PointValues.IntersectVisitor {
3197+
public static class VerifyPointsVisitor implements IntersectVisitor {
31983198
private long pointCountSeen;
31993199
private int lastDocID = -1;
32003200
private final FixedBitSet docsSeen;

lucene/core/src/java/org/apache/lucene/index/PointValues.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -345,7 +345,7 @@ default void grow(int count) {}
345345
* Finds all documents and points matching the provided visitor. This method does not enforce live
346346
* documents, so it's up to the caller to test whether each document is deleted, if necessary.
347347
*/
348-
public final void intersect(IntersectVisitor visitor) throws IOException {
348+
public void intersect(IntersectVisitor visitor) throws IOException {
349349
final PointTree pointTree = getPointTree();
350350
intersect(visitor, pointTree);
351351
assert pointTree.moveToParent() == false;

lucene/core/src/java/org/apache/lucene/util/bkd/BKDReader.java

Lines changed: 180 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,9 @@
1717
package org.apache.lucene.util.bkd;
1818

1919
import java.io.IOException;
20+
import java.util.ArrayList;
2021
import java.util.Arrays;
22+
import java.util.List;
2123
import org.apache.lucene.codecs.CodecUtil;
2224
import org.apache.lucene.index.CorruptIndexException;
2325
import org.apache.lucene.index.PointValues;
@@ -589,6 +591,65 @@ public void visitDocIDs(PointValues.IntersectVisitor visitor) throws IOException
589591
addAll(visitor, false);
590592
}
591593

594+
/** prefetch DocIds below current node */
595+
public void prefetchDocIDs(TwoPhaseIntersectVisitor visitor) throws IOException {
596+
resetNodeDataPosition();
597+
prefetchAll(visitor, false);
598+
}
599+
600+
/** visit Doc Ids for a leafNode at provided input position */
601+
public void visitDocIDs(long position, IntersectVisitor visitor) throws IOException {
602+
visitDocIDs(position, visitor, false);
603+
}
604+
605+
private void visitDocIDs(long position, IntersectVisitor visitor, boolean grown)
606+
throws IOException {
607+
leafNodes.seek(position);
608+
int count = leafNodes.readVInt();
609+
if (!grown) {
610+
visitor.grow(count);
611+
}
612+
docIdsWriter.readInts(leafNodes, count, visitor, scratchIterator.docIDs);
613+
}
614+
615+
private int getLeafNodeOrdinal() {
616+
assert isLeafNode() : "nodeID=" + nodeID + " is not a leaf";
617+
return nodeID - leafNodeOffset;
618+
}
619+
620+
public void prefetchAll(TwoPhaseIntersectVisitor visitor, boolean grown) throws IOException {
621+
if (grown == false) {
622+
final long size = size();
623+
if (size <= Integer.MAX_VALUE) {
624+
visitor.grow((int) size);
625+
grown = true;
626+
}
627+
}
628+
if (isLeafNode()) {
629+
// int count = isLastLeaf() ? config.maxPointsInLeafNode() : lastLeafNodePointCount;
630+
long leafFp = getLeafBlockFP();
631+
int leafNodeOrdinal = getLeafNodeOrdinal();
632+
// Only call prefetch is this is the first leaf node ordinal or the first match in
633+
// contigiuous sequence of matches for leaf nodes
634+
// boolean prefetched = false;
635+
if (visitor.lastDeferredBlockOrdinal() == -1
636+
|| visitor.lastDeferredBlockOrdinal() + 1 < leafNodeOrdinal) {
637+
// System.out.println("Prefetched called on " + leafNodeOrdinal);
638+
leafNodes.prefetch(leafFp, 1);
639+
// prefetched = true;
640+
}
641+
visitor.setLastDeferredBlockOrdinal(leafNodeOrdinal);
642+
visitor.deferBlock(leafFp);
643+
} else {
644+
pushLeft();
645+
prefetchAll(visitor, grown);
646+
pop();
647+
pushRight();
648+
prefetchAll(visitor, grown);
649+
pop();
650+
}
651+
}
652+
592653
public void addAll(PointValues.IntersectVisitor visitor, boolean grown) throws IOException {
593654
if (grown == false) {
594655
final long size = size();
@@ -1076,4 +1137,123 @@ public long cost() {
10761137
return length;
10771138
}
10781139
}
1140+
1141+
/**
1142+
* We can recurse the {@link BKDPointTree} using {@link TwoPhaseIntersectVisitor}. This visitor
1143+
* travere {@link BKDPointTree} in two phases. In the first phase, it recurses over the {@link
1144+
* BKDPointTree} optionally triggering IO for some of the blocks and caching them. In the second
1145+
* phase, once the recursion is over it visits the cached blocks one by one.
1146+
*
1147+
* @lucene.experimental
1148+
*/
1149+
public interface TwoPhaseIntersectVisitor extends IntersectVisitor {
1150+
/** return the last deferred block ordinal during recursion. */
1151+
public int lastDeferredBlockOrdinal();
1152+
1153+
/** set last deferred block ordinal */
1154+
public void setLastDeferredBlockOrdinal(int leafNodeOrdinal);
1155+
1156+
/** Defer this block for processing in the second phase. */
1157+
public void deferBlock(long leafFp);
1158+
1159+
/** Returns a snapshot of the currently deferred blocks. */
1160+
public List<Long> deferredBlocks();
1161+
1162+
/** Mark the given block as processed and remove it from the deferred set. */
1163+
public void onProcessingDeferredBlock(long leafFp);
1164+
}
1165+
1166+
/**
1167+
* Base implementation of {@link TwoPhaseIntersectVisitor} that maintains a list of deferred
1168+
* blocks from first phase of traversal and visits them in the second phase.
1169+
*
1170+
* @lucene.experimental
1171+
*/
1172+
public abstract static class BaseTwoPhaseIntersectVisitor implements TwoPhaseIntersectVisitor {
1173+
1174+
int lastDeferredBlockOrdinal = -1;
1175+
List<Long> deferredBlocks = new ArrayList<>();
1176+
1177+
/**
1178+
* return the last deferred block ordinal - this is used to avoid prefetching call for
1179+
* contiguous ordinals assuming contiguous ordinals prefetching can be taken care by readaheads.
1180+
*/
1181+
@Override
1182+
public int lastDeferredBlockOrdinal() {
1183+
return lastDeferredBlockOrdinal;
1184+
}
1185+
1186+
/** set last deferred block ordinal * */
1187+
@Override
1188+
public void setLastDeferredBlockOrdinal(int leafNodeOrdinal) {
1189+
lastDeferredBlockOrdinal = leafNodeOrdinal;
1190+
}
1191+
1192+
/** Defer this block for processing in the second phase. */
1193+
@Override
1194+
public void deferBlock(long leafFp) {
1195+
deferredBlocks.add(leafFp);
1196+
}
1197+
1198+
/** Returns a snapshot of the currently deferred blocks. */
1199+
@Override
1200+
public List<Long> deferredBlocks() {
1201+
return new ArrayList<>(deferredBlocks);
1202+
}
1203+
1204+
/** Mark the given block as processed and remove it from the deferred set. */
1205+
@Override
1206+
public void onProcessingDeferredBlock(long leafFp) {
1207+
deferredBlocks.remove(leafFp);
1208+
}
1209+
}
1210+
1211+
/**
1212+
* Finds all documents and points matching the provided visitor. This method does not enforce live
1213+
* documents, so it's up to the caller to test whether each document is deleted, if necessary.
1214+
*/
1215+
@Override
1216+
public final void intersect(IntersectVisitor visitor) throws IOException {
1217+
final BKDPointTree pointTree = (BKDPointTree) getPointTree();
1218+
if (visitor instanceof TwoPhaseIntersectVisitor twoPhaseIntersectVisitor) {
1219+
intersect(twoPhaseIntersectVisitor, pointTree);
1220+
List<Long> fps = twoPhaseIntersectVisitor.deferredBlocks();
1221+
for (int i = 0; i < fps.size(); ++i) {
1222+
long fp = fps.get(i);
1223+
pointTree.visitDocIDs(fp, visitor);
1224+
twoPhaseIntersectVisitor.onProcessingDeferredBlock(fp);
1225+
}
1226+
} else {
1227+
intersect(visitor, pointTree);
1228+
}
1229+
assert pointTree.moveToParent() == false;
1230+
}
1231+
1232+
private static void intersect(TwoPhaseIntersectVisitor visitor, BKDPointTree pointTree)
1233+
throws IOException {
1234+
while (true) {
1235+
Relation compare =
1236+
visitor.compare(pointTree.getMinPackedValue(), pointTree.getMaxPackedValue());
1237+
if (compare == Relation.CELL_INSIDE_QUERY) {
1238+
// This cell is fully inside the query shape: recursively prefetch all points in this cell
1239+
// without filtering
1240+
pointTree.prefetchDocIDs(visitor);
1241+
} else if (compare == Relation.CELL_CROSSES_QUERY) {
1242+
// The cell crosses the shape boundary, or the cell fully contains the query, so we fall
1243+
// through and do full filtering:
1244+
if (pointTree.moveToChild()) {
1245+
continue;
1246+
}
1247+
// TODO: we can assert that the first value here in fact matches what the pointTree
1248+
// claimed?
1249+
// Leaf node; scan and filter all points in this block:
1250+
pointTree.visitDocValues(visitor);
1251+
}
1252+
while (pointTree.moveToSibling() == false) {
1253+
if (pointTree.moveToParent() == false) {
1254+
return;
1255+
}
1256+
}
1257+
}
1258+
}
10791259
}

lucene/core/src/test/org/apache/lucene/codecs/lucene90/TestLucene90PointsFormat.java

Lines changed: 135 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,8 @@
1818

1919
import java.io.IOException;
2020
import java.util.Arrays;
21+
import java.util.BitSet;
22+
import java.util.List;
2123
import org.apache.lucene.codecs.Codec;
2224
import org.apache.lucene.codecs.FilterCodec;
2325
import org.apache.lucene.codecs.PointsFormat;
@@ -39,7 +41,10 @@
3941
import org.apache.lucene.tests.index.BasePointsFormatTestCase;
4042
import org.apache.lucene.tests.index.MockRandomMergePolicy;
4143
import org.apache.lucene.tests.util.TestUtil;
44+
import org.apache.lucene.util.IOUtils;
45+
import org.apache.lucene.util.NumericUtils;
4246
import org.apache.lucene.util.bkd.BKDConfig;
47+
import org.apache.lucene.util.bkd.BKDReader;
4348

4449
public class TestLucene90PointsFormat extends BasePointsFormatTestCase {
4550

@@ -355,4 +360,134 @@ public Relation compare(byte[] minPackedValue, byte[] maxPackedValue) {
355360
r.close();
356361
dir.close();
357362
}
363+
364+
public void testBasicWithPrefetchVisitor() throws Exception {
365+
Directory dir = newDirectory();
366+
IndexWriterConfig iwc = newIndexWriterConfig();
367+
// Avoid mockRandomMP since it may cause non-optimal merges that make the
368+
// number of points per leaf hard to predict
369+
while (iwc.getMergePolicy() instanceof MockRandomMergePolicy) {
370+
iwc.setMergePolicy(newMergePolicy());
371+
}
372+
IndexWriter w = new IndexWriter(dir, iwc);
373+
byte[] pointValue = new byte[3];
374+
byte[] uniquePointValue = new byte[3];
375+
random().nextBytes(uniquePointValue);
376+
final int numDocs =
377+
TEST_NIGHTLY ? atLeast(10000) : atLeast(500); // at night, make sure we have several leaves
378+
final boolean multiValues = random().nextBoolean();
379+
int totalValues = 0;
380+
for (int i = 0; i < numDocs; ++i) {
381+
Document doc = new Document();
382+
if (i == numDocs / 2) {
383+
totalValues++;
384+
doc.add(new BinaryPoint("f", uniquePointValue));
385+
} else {
386+
final int numValues = (multiValues) ? TestUtil.nextInt(random(), 2, 100) : 1;
387+
for (int j = 0; j < numValues; j++) {
388+
do {
389+
random().nextBytes(pointValue);
390+
} while (Arrays.equals(pointValue, uniquePointValue));
391+
doc.add(new BinaryPoint("f", pointValue));
392+
totalValues++;
393+
}
394+
}
395+
w.addDocument(doc);
396+
}
397+
w.forceMerge(1);
398+
final IndexReader r = DirectoryReader.open(w);
399+
w.close();
400+
401+
final LeafReader lr = getOnlyLeafReader(r);
402+
PointValues points = lr.getPointValues("f");
403+
404+
BKDReader.BaseTwoPhaseIntersectVisitor allPointsVisitor =
405+
new BKDReader.BaseTwoPhaseIntersectVisitor() {
406+
@Override
407+
public void visit(int docID, byte[] packedValue) throws IOException {}
408+
409+
@Override
410+
public void visit(int docID) throws IOException {}
411+
412+
@Override
413+
public Relation compare(byte[] minPackedValue, byte[] maxPackedValue) {
414+
return Relation.CELL_INSIDE_QUERY;
415+
}
416+
};
417+
418+
List<Long> savedBlocks = allPointsVisitor.deferredBlocks();
419+
assertEquals(0, savedBlocks.size()); // Test that all deferred blocks were processed
420+
assertEquals(totalValues, points.estimatePointCount(allPointsVisitor));
421+
assertEquals(numDocs, points.estimateDocCount(allPointsVisitor));
422+
423+
r.close();
424+
dir.close();
425+
}
426+
427+
public void testBasicWithPrefetchCapableVisitor() throws Exception {
428+
Directory dir = newDirectory();
429+
IndexWriterConfig iwc = newIndexWriterConfig();
430+
iwc.setMergePolicy(newLogMergePolicy());
431+
IndexWriter w = new IndexWriter(dir, iwc);
432+
byte[] point = new byte[4];
433+
for (int i = 0; i < 20; i++) {
434+
Document doc = new Document();
435+
NumericUtils.intToSortableBytes(i, point, 0);
436+
doc.add(new BinaryPoint("dim", point));
437+
w.addDocument(doc);
438+
}
439+
w.forceMerge(1);
440+
w.close();
441+
442+
DirectoryReader r = DirectoryReader.open(dir);
443+
LeafReader sub = getOnlyLeafReader(r);
444+
PointValues values = sub.getPointValues("dim");
445+
446+
// Simple test: make sure prefetch capable visitor can visit every doc when cell crosses query:
447+
BitSet seen = new BitSet();
448+
values.intersect(
449+
new BKDReader.BaseTwoPhaseIntersectVisitor() {
450+
@Override
451+
public Relation compare(byte[] minPacked, byte[] maxPacked) {
452+
return Relation.CELL_CROSSES_QUERY;
453+
}
454+
455+
@Override
456+
public void visit(int docID) {
457+
throw new IllegalStateException();
458+
}
459+
460+
@Override
461+
public void visit(int docID, byte[] packedValue) {
462+
seen.set(docID);
463+
assertEquals(docID, NumericUtils.sortableBytesToInt(packedValue, 0));
464+
}
465+
});
466+
assertEquals(20, seen.cardinality());
467+
// Make sure prefetch capable visitor can visit all docs when all docs are inside query
468+
// Also test we are not visiting documents twice based on whether PointTree has a prefetch
469+
// implementation of
470+
// prepareOrVisit or uses the default implementation
471+
seen.clear();
472+
final int[] docCount = {0};
473+
values.intersect(
474+
new BKDReader.BaseTwoPhaseIntersectVisitor() {
475+
@Override
476+
public void visit(int docID) throws IOException {
477+
seen.set(docID);
478+
docCount[0]++;
479+
}
480+
481+
@Override
482+
public void visit(int docID, byte[] packedValue) throws IOException {}
483+
484+
@Override
485+
public Relation compare(byte[] minPackedValue, byte[] maxPackedValue) {
486+
return Relation.CELL_INSIDE_QUERY;
487+
}
488+
});
489+
assertEquals(20, seen.cardinality());
490+
assertEquals(20, docCount[0]);
491+
IOUtils.close(r, dir);
492+
}
358493
}

0 commit comments

Comments
 (0)