Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -32,11 +32,9 @@
import org.elasticsearch.search.sort.SortOrder;
import org.elasticsearch.transport.Transport;

import java.util.Arrays;
import java.util.Comparator;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.concurrent.Executor;
import java.util.function.BiFunction;
Expand Down Expand Up @@ -128,7 +126,18 @@ private static List<SearchShardIterator> sortShards(GroupShardsIterator<SearchSh
}

private static boolean shouldSortShards(MinAndMax<?>[] minAndMaxes) {
return Arrays.stream(minAndMaxes).anyMatch(Objects::nonNull);
Class<?> clazz = null;
for (MinAndMax<?> minAndMax : minAndMaxes) {
if (clazz == null) {
clazz = minAndMax == null ? null : minAndMax.getMin().getClass();
} else if (minAndMax != null && clazz != minAndMax.getMin().getClass()) {
// we don't support sort values that mix different types (e.g.: long/double, numeric/keyword).
// TODO: we could fail the request because there is a high probability
// that the merging of topdocs will fail later for the same reason ?
return false;
}
}
return clazz != null;
}

private static Comparator<Integer> shardComparator(GroupShardsIterator<SearchShardIterator> shardsIts,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -55,14 +55,14 @@ public void writeTo(StreamOutput out) throws IOException {
/**
* Return the minimum value.
*/
T getMin() {
public T getMin() {
return minValue;
}

/**
* Return the maximum value.
*/
T getMax() {
public T getMax() {
return maxValue;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
*/
package org.elasticsearch.action.search;

import org.apache.lucene.util.BytesRef;
import org.elasticsearch.Version;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.OriginalIndices;
Expand Down Expand Up @@ -54,6 +55,8 @@
import java.util.concurrent.atomic.AtomicReference;
import java.util.stream.IntStream;

import static org.hamcrest.Matchers.equalTo;

public class CanMatchPreFilterSearchPhaseTests extends ESTestCase {

public void testFilterShards() throws InterruptedException {
Expand Down Expand Up @@ -350,4 +353,76 @@ public void run() {
}
}
}

public void testInvalidSortShards() throws InterruptedException {
final TransportSearchAction.SearchTimeProvider timeProvider =
new TransportSearchAction.SearchTimeProvider(0, System.nanoTime(), System::nanoTime);

Map<String, Transport.Connection> lookup = new ConcurrentHashMap<>();
DiscoveryNode primaryNode = new DiscoveryNode("node_1", buildNewFakeTransportAddress(), Version.CURRENT);
DiscoveryNode replicaNode = new DiscoveryNode("node_2", buildNewFakeTransportAddress(), Version.CURRENT);
lookup.put("node1", new SearchAsyncActionTests.MockConnection(primaryNode));
lookup.put("node2", new SearchAsyncActionTests.MockConnection(replicaNode));

for (SortOrder order : SortOrder.values()) {
int numShards = randomIntBetween(2, 20);
List<ShardId> shardIds = new ArrayList<>();
Set<ShardId> shardToSkip = new HashSet<>();

SearchTransportService searchTransportService = new SearchTransportService(null, null) {
@Override
public void sendCanMatch(Transport.Connection connection, ShardSearchRequest request, SearchTask task,
ActionListener<SearchService.CanMatchResponse> listener) {
final MinAndMax<?> minMax;
if (request.shardId().id() == numShards-1) {
minMax = new MinAndMax<>(new BytesRef("bar"), new BytesRef("baz"));
} else {
Long min = randomLong();
Long max = randomLongBetween(min, Long.MAX_VALUE);
minMax = new MinAndMax<>(min, max);
}
boolean canMatch = frequently();
synchronized (shardIds) {
shardIds.add(request.shardId());
if (canMatch == false) {
shardToSkip.add(request.shardId());
}
}
new Thread(() -> listener.onResponse(new SearchService.CanMatchResponse(canMatch, minMax))).start();
}
};

AtomicReference<GroupShardsIterator<SearchShardIterator>> result = new AtomicReference<>();
CountDownLatch latch = new CountDownLatch(1);
GroupShardsIterator<SearchShardIterator> shardsIter = SearchAsyncActionTests.getShardsIter("logs",
new OriginalIndices(new String[]{"logs"}, SearchRequest.DEFAULT_INDICES_OPTIONS),
numShards, randomBoolean(), primaryNode, replicaNode);
final SearchRequest searchRequest = new SearchRequest();
searchRequest.source(new SearchSourceBuilder().sort(SortBuilders.fieldSort("timestamp").order(order)));
searchRequest.allowPartialSearchResults(true);

CanMatchPreFilterSearchPhase canMatchPhase = new CanMatchPreFilterSearchPhase(logger,
searchTransportService,
(clusterAlias, node) -> lookup.get(node),
Collections.singletonMap("_na_", new AliasFilter(null, Strings.EMPTY_ARRAY)),
Collections.emptyMap(), Collections.emptyMap(), EsExecutors.newDirectExecutorService(),
searchRequest, null, shardsIter, timeProvider, ClusterState.EMPTY_STATE, null,
(iter) -> new SearchPhase("test") {
@Override
public void run() {
result.set(iter);
latch.countDown();
}
}, SearchResponse.Clusters.EMPTY);

canMatchPhase.start();
latch.await();
int shardId = 0;
for (SearchShardIterator i : result.get()) {
assertThat(i.shardId().id(), equalTo(shardId++));
assertEquals(shardToSkip.contains(i.shardId()), i.skip());
}
assertThat(result.get().size(), equalTo(numShards));
}
}
}