InitialSearchPhase.java
@@ -50,6 +50,8 @@ abstract class InitialSearchPhase<FirstResult extends SearchPhaseResult> extends
private final Logger logger;
private final int expectedTotalOps;
private final AtomicInteger totalOps = new AtomicInteger();
private final AtomicInteger shardExecutionIndex = new AtomicInteger(0);
private final int maxConcurrentShardRequests;

InitialSearchPhase(String name, SearchRequest request, GroupShardsIterator<SearchShardIterator> shardsIts, Logger logger) {
super(name);
@@ -61,6 +63,7 @@ abstract class InitialSearchPhase<FirstResult extends SearchPhaseResult> extends
// on a per shards level we use shardIt.remaining() to increment the totalOps pointer but add 1 for the current shard result
// we process hence we add one for the non active partition here.
this.expectedTotalOps = shardsIts.totalSizeWith1ForEmpty();
maxConcurrentShardRequests = Math.min(request.getMaxConcurrentShardRequests(), shardsIts.size());
}

private void onShardFailure(final int shardIndex, @Nullable ShardRouting shard, @Nullable String nodeId,
@@ -105,6 +108,7 @@ private void onShardFailure(final int shardIndex, @Nullable ShardRouting shard,
onShardFailure(shardIndex, shard, shard.currentNodeId(), shardIt, inner);
}
} else {
maybeExecuteNext(); // move to the next execution if needed
// no more shards active, add a failure
if (logger.isDebugEnabled() && !logger.isTraceEnabled()) { // do not double log this exception
if (e != null && !TransportActions.isShardNotAvailableException(e)) {
@@ -124,23 +128,25 @@ private void onShardFailure(final int shardIndex, @Nullable ShardRouting shard,

@Override
public final void run() throws IOException {
int shardIndex = -1;
for (final SearchShardIterator shardIt : shardsIts) {
shardIndex++;
final ShardRouting shard = shardIt.nextOrNull();
if (shard != null) {
performPhaseOnShard(shardIndex, shardIt, shard);
} else {
// really, no shards active in this group
onShardFailure(shardIndex, null, null, shardIt, new NoShardAvailableActionException(shardIt.shardId()));
}
boolean success = shardExecutionIndex.compareAndSet(0, maxConcurrentShardRequests);
assert success;
for (int i = 0; i < maxConcurrentShardRequests; i++) {
SearchShardIterator shardRoutings = shardsIts.get(i);
performPhaseOnShard(i, shardRoutings, shardRoutings.nextOrNull());
}
}

private void maybeExecuteNext() {
final int index = shardExecutionIndex.getAndIncrement();
if (index < shardsIts.size()) {
SearchShardIterator shardRoutings = shardsIts.get(index);
performPhaseOnShard(index, shardRoutings, shardRoutings.nextOrNull());
}
}


private void performPhaseOnShard(final int shardIndex, final SearchShardIterator shardIt, final ShardRouting shard) {
if (shard == null) {
// TODO upgrade this to an assert...
// no more active shards... (we should not really get here, but just for safety)
onShardFailure(shardIndex, null, null, shardIt, new NoShardAvailableActionException(shardIt.shardId()));
} else {
try {
@@ -166,6 +172,7 @@ public void onFailure(Exception t) {
}

private void onShardResult(FirstResult result, ShardIterator shardIt) {
maybeExecuteNext();
assert result.getShardIndex() != -1 : "shard index is not set";
assert result.getSearchShardTarget() != null : "search shard target must not be null";
onShardSuccess(result);
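
Taken together, the changes in this file replace the old all-at-once fan-out with a sliding window: run() seeds up to maxConcurrentShardRequests shard requests, and each shard result or exhausted-shard failure calls maybeExecuteNext() to pull in the next shard iterator. A minimal, self-contained sketch of that pattern (all names hypothetical, not the PR's actual types):

import java.util.List;
import java.util.concurrent.atomic.AtomicInteger;

// Illustrative sliding-window fan-out, mirroring the structure above.
class BoundedFanOut {
    private final List<Runnable> tasks;
    private final AtomicInteger nextIndex;   // plays the role of shardExecutionIndex
    private final int maxConcurrent;         // plays the role of maxConcurrentShardRequests

    BoundedFanOut(List<Runnable> tasks, int maxConcurrent) {
        this.tasks = tasks;
        this.maxConcurrent = Math.min(maxConcurrent, tasks.size());
        this.nextIndex = new AtomicInteger(this.maxConcurrent);
    }

    void run() { // seed the initial window, like InitialSearchPhase#run
        for (int i = 0; i < maxConcurrent; i++) {
            execute(tasks.get(i));
        }
    }

    private void maybeExecuteNext() { // invoked once per completed task
        final int index = nextIndex.getAndIncrement();
        if (index < tasks.size()) {
            execute(tasks.get(index));
        }
    }

    private void execute(Runnable task) {
        new Thread(() -> {
            try {
                task.run();
            } finally {
                maybeExecuteNext(); // keep the window full until all tasks have run
            }
        }).start();
    }
}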
SearchRequest.java
@@ -19,6 +19,7 @@

package org.elasticsearch.action.search;

import org.elasticsearch.Version;
import org.elasticsearch.action.ActionRequest;
import org.elasticsearch.action.ActionRequestValidationException;
import org.elasticsearch.action.IndicesRequest;
@@ -74,6 +75,8 @@ public final class SearchRequest extends ActionRequest implements IndicesRequest

private int batchedReduceSize = 512;

private int maxConcurrentShardRequests = 0;

private String[] types = Strings.EMPTY_ARRAY;

public static final IndicesOptions DEFAULT_INDICES_OPTIONS = IndicesOptions.strictExpandOpenAndForbidClosed();
@@ -302,6 +305,34 @@ public int getBatchedReduceSize() {
return batchedReduceSize;
}

/**
* Returns the number of shard requests that should be executed concurrently. This value should be used as a protection mechanism to
* reduce the number of shard requests fired per high level search request. Searches that hit the entire cluster can be throttled
* with this number to reduce the cluster load. The default grows with the number of nodes in the cluster but is at most <tt>256</tt>.
*/
public int getMaxConcurrentShardRequests() {
return maxConcurrentShardRequests == 0 ? 256 : maxConcurrentShardRequests;
}

/**
* Sets the number of shard requests that should be executed concurrently. This value should be used as a protection mechanism to
* reduce the number of shard requests fired per high level search request. Searches that hit the entire cluster can be throttled
* with this number to reduce the cluster load. The default grows with the number of nodes in the cluster but is at most <tt>256</tt>.
*/
public void setMaxConcurrentShardRequests(int maxConcurrentShardRequests) {
if (maxConcurrentShardRequests < 1) {
throw new IllegalArgumentException("maxConcurrentShardRequests must be >= 1");
}
this.maxConcurrentShardRequests = maxConcurrentShardRequests;
}

/**
* Returns <code>true</code> iff maxConcurrentShardRequests is set.
*/
boolean isMaxConcurrentShardRequestsSet() {
return maxConcurrentShardRequests != 0;
}

/**
* @return true if the request only has suggest
*/
@@ -349,6 +380,9 @@ public void readFrom(StreamInput in) throws IOException {
indicesOptions = IndicesOptions.readIndicesOptions(in);
requestCache = in.readOptionalBoolean();
batchedReduceSize = in.readVInt();
if (in.getVersion().onOrAfter(Version.V_6_0_0_beta1)) {
maxConcurrentShardRequests = in.readVInt();
}
}

@Override
@@ -367,6 +401,9 @@ public void writeTo(StreamOutput out) throws IOException {
indicesOptions.writeIndicesOptions(out);
out.writeOptionalBoolean(requestCache);
out.writeVInt(batchedReduceSize);
if (out.getVersion().onOrAfter(Version.V_6_0_0_beta1)) {
out.writeVInt(maxConcurrentShardRequests);
}
}

@Override
@@ -386,13 +423,15 @@ public boolean equals(Object o) {
Objects.equals(requestCache, that.requestCache) &&
Objects.equals(scroll, that.scroll) &&
Arrays.equals(types, that.types) &&
Objects.equals(batchedReduceSize, that.batchedReduceSize) &&
Objects.equals(maxConcurrentShardRequests, that.maxConcurrentShardRequests) &&
Objects.equals(indicesOptions, that.indicesOptions);
}

@Override
public int hashCode() {
return Objects.hash(searchType, Arrays.hashCode(indices), routing, preference, source, requestCache,
scroll, Arrays.hashCode(types), indicesOptions);
scroll, Arrays.hashCode(types), indicesOptions, maxConcurrentShardRequests);
}

@Override
@@ -406,6 +445,8 @@ public String toString() {
", preference='" + preference + '\'' +
", requestCache=" + requestCache +
", scroll=" + scroll +
", maxConcurrentShardRequests=" + maxConcurrentShardRequests +
", batchedReduceSize=" + batchedReduceSize +
", source=" + source + '}';
}
}
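
A hedged usage sketch of the new request-level setting (the index pattern and limit are made up for illustration):

SearchRequest request = new SearchRequest("logs-*");  // hypothetical index pattern
request.setMaxConcurrentShardRequests(12);            // at most 12 shard requests in flight
// request.setMaxConcurrentShardRequests(0);          // would throw IllegalArgumentException
// if left unset, getMaxConcurrentShardRequests() falls back to 256, and the
// coordinating node substitutes its node-count-based default (see TransportSearchAction below)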
SearchRequestBuilder.java
@@ -525,4 +525,14 @@ public SearchRequestBuilder setBatchedReduceSize(int batchedReduceSize) {
this.request.setBatchedReduceSize(batchedReduceSize);
return this;
}

/**
* Sets the number of shard requests that should be executed concurrently. This value should be used as a protection mechanism to
* reduce the number of shard requests fired per high level search request. Searches that hit the entire cluster can be throttled
* with this number to reduce the cluster load. The default grows with the number of nodes in the cluster but is at most <tt>256</tt>.
*/
public SearchRequestBuilder setMaxConcurrentShardRequests(int maxConcurrentShardRequests) {
this.request.setMaxConcurrentShardRequests(maxConcurrentShardRequests);
return this;
}
}
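
The same knob through the builder API, sketched under the assumption of an existing org.elasticsearch.client.Client instance named client and a made-up index pattern:

SearchResponse response = client.prepareSearch("logs-*")   // hypothetical index pattern
    .setQuery(QueryBuilders.matchAllQuery())
    .setMaxConcurrentShardRequests(12)                     // throttle the shard fan-out
    .get();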
TransportSearchAction.java
@@ -27,6 +27,7 @@
import org.elasticsearch.action.support.HandledTransportAction;
import org.elasticsearch.cluster.ClusterState;
import org.elasticsearch.cluster.block.ClusterBlockLevel;
import org.elasticsearch.cluster.metadata.IndexMetaData;
import org.elasticsearch.cluster.metadata.IndexNameExpressionResolver;
import org.elasticsearch.cluster.node.DiscoveryNode;
import org.elasticsearch.cluster.node.DiscoveryNodes;
@@ -184,16 +185,19 @@ protected void doExecute(Task task, SearchRequest searchRequest, ActionListener<
OriginalIndices localIndices = remoteClusterIndices.remove(RemoteClusterAware.LOCAL_CLUSTER_GROUP_KEY);
if (remoteClusterIndices.isEmpty()) {
executeSearch((SearchTask)task, timeProvider, searchRequest, localIndices, remoteClusterIndices, Collections.emptyList(),
(clusterName, nodeId) -> null, clusterState, Collections.emptyMap(), listener);
(clusterName, nodeId) -> null, clusterState, Collections.emptyMap(), listener, clusterState.getNodes()
.getDataNodes().size());
} else {
remoteClusterService.collectSearchShards(searchRequest.indicesOptions(), searchRequest.preference(), searchRequest.routing(),
remoteClusterIndices, ActionListener.wrap((searchShardsResponses) -> {
List<SearchShardIterator> remoteShardIterators = new ArrayList<>();
Map<String, AliasFilter> remoteAliasFilters = new HashMap<>();
BiFunction<String, String, DiscoveryNode> clusterNodeLookup = processRemoteShards(searchShardsResponses,
remoteClusterIndices, remoteShardIterators, remoteAliasFilters);
int numNodesInvolved = searchShardsResponses.values().stream().mapToInt(r -> r.getNodes().length).sum()
+ clusterState.getNodes().getDataNodes().size();
executeSearch((SearchTask) task, timeProvider, searchRequest, localIndices, remoteClusterIndices, remoteShardIterators,
clusterNodeLookup, clusterState, remoteAliasFilters, listener);
clusterNodeLookup, clusterState, remoteAliasFilters, listener, numNodesInvolved);
}, listener::onFailure));
}
}
@@ -247,7 +251,7 @@ static BiFunction<String, String, DiscoveryNode> processRemoteShards(Map<String,
private void executeSearch(SearchTask task, SearchTimeProvider timeProvider, SearchRequest searchRequest, OriginalIndices localIndices,
Map<String, OriginalIndices> remoteClusterIndices, List<SearchShardIterator> remoteShardIterators,
BiFunction<String, String, DiscoveryNode> remoteConnections, ClusterState clusterState,
Map<String, AliasFilter> remoteAliasMap, ActionListener<SearchResponse> listener) {
Map<String, AliasFilter> remoteAliasMap, ActionListener<SearchResponse> listener, int nodeCount) {

clusterState.blocks().globalBlockedRaiseException(ClusterBlockLevel.READ);
// TODO: I think startTime() should become part of ActionRequest and that should be used both for index name
@@ -300,7 +304,15 @@ private void executeSearch(SearchTask task, SearchTimeProvider timeProvider, Sea
}
return searchTransportService.getConnection(clusterName, discoveryNode);
};

if (searchRequest.isMaxConcurrentShardRequestsSet() == false) {
// we try to set a default for max concurrent shard requests based on the node count,
// upper-bounded by 256 to keep it sane. A single search request that fans out to lots
// of shards shouldn't hit a cluster too hard, and 256 is already a lot. We multiply
// the node count by the default number of shards per index so that a single request
// in a one-node cluster would still hit all shards of a default index.
searchRequest.setMaxConcurrentShardRequests(Math.min(256, nodeCount
* IndexMetaData.INDEX_NUMBER_OF_SHARDS_SETTING.getDefault(Settings.EMPTY)));
}
searchAsyncAction(task, searchRequest, shardIterators, timeProvider, connectionLookup, clusterState.version(),
Collections.unmodifiableMap(aliasFilter), concreteIndexBoosts, listener).start();
}
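
For illustration, a hedged worked example of that default (assuming 5 is the out-of-the-box shards-per-index default in this version; the real value comes from IndexMetaData.INDEX_NUMBER_OF_SHARDS_SETTING):

int dataNodes = 3;                                                 // data nodes in the local cluster
int defaultShardsPerIndex = 5;                                     // assumed setting default
int defaultMax = Math.min(256, dataNodes * defaultShardsPerIndex); // -> 15
// a 60-node cluster would get min(256, 60 * 5) -> 256, i.e. the cap kicks in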
GroupShardsIterator.java
@@ -78,4 +78,8 @@ public int size() {
public Iterator<ShardIt> iterator() {
return iterators.iterator();
}

public ShardIt get(int index) {
return iterators.get(index);
}
}
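
Design note: this positional accessor is what allows InitialSearchPhase.maybeExecuteNext() above to pull the next shard iterator by index instead of iterating the whole group.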
RestSearchAction.java
@@ -99,6 +99,14 @@ public static void parseSearchRequest(SearchRequest searchRequest, RestRequest r
final int batchedReduceSize = request.paramAsInt("batched_reduce_size", searchRequest.getBatchedReduceSize());
searchRequest.setBatchedReduceSize(batchedReduceSize);

if (request.hasParam("max_concurrent_shard_requests")) {
// only set if we have the parameter since we auto-adjust the max concurrency on the coordinator
// based on the number of nodes in the cluster
final int maxConcurrentShardRequests = request.paramAsInt("max_concurrent_shard_requests",
searchRequest.getMaxConcurrentShardRequests());
searchRequest.setMaxConcurrentShardRequests(maxConcurrentShardRequests);
}
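// hedged usage note: a request like "GET /_search?max_concurrent_shard_requests=12"
// caps the shard fan-out for that single search only; omitting the parameter keeps
// the coordinator's node-count-based default described above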

// do not allow 'query_and_fetch' or 'dfs_query_and_fetch' search types
// from the REST layer. these modes are an internal optimization and should
// not be specified explicitly by the user.
SearchAsyncActionTests.java
@@ -48,14 +48,113 @@
import java.util.HashSet;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicReference;

public class SearchAsyncActionTests extends ESTestCase {

public void testLimitConcurrentShardRequests() throws InterruptedException {
SearchRequest request = new SearchRequest();
int numConcurrent = randomIntBetween(1, 5);
request.setMaxConcurrentShardRequests(numConcurrent);
CountDownLatch latch = new CountDownLatch(1);
AtomicReference<TestSearchResponse> response = new AtomicReference<>();
ActionListener<SearchResponse> responseListener = new ActionListener<SearchResponse>() {
@Override
public void onResponse(SearchResponse searchResponse) {
response.set((TestSearchResponse) searchResponse);
}

@Override
public void onFailure(Exception e) {
logger.warn("test failed", e);
fail(e.getMessage());
}
};
DiscoveryNode primaryNode = new DiscoveryNode("node_1", buildNewFakeTransportAddress(), Version.CURRENT);
DiscoveryNode replicaNode = new DiscoveryNode("node_2", buildNewFakeTransportAddress(), Version.CURRENT);

AtomicInteger contextIdGenerator = new AtomicInteger(0);
GroupShardsIterator<SearchShardIterator> shardsIter = getShardsIter("idx",
new OriginalIndices(new String[]{"idx"}, IndicesOptions.strictExpandOpenAndForbidClosed()),
10, randomBoolean(), primaryNode, replicaNode);
SearchTransportService transportService = new SearchTransportService(Settings.EMPTY, null);
Map<String, Transport.Connection> lookup = new HashMap<>();
Map<ShardId, Boolean> seenShard = new ConcurrentHashMap<>();
lookup.put(primaryNode.getId(), new MockConnection(primaryNode));
lookup.put(replicaNode.getId(), new MockConnection(replicaNode));
Map<String, AliasFilter> aliasFilters = Collections.singletonMap("_na_", new AliasFilter(null, Strings.EMPTY_ARRAY));
CountDownLatch awaitInitialRequests = new CountDownLatch(1);
AtomicInteger numRequests = new AtomicInteger(0);
AtomicInteger numResponses = new AtomicInteger(0);
AbstractSearchAsyncAction asyncAction =
new AbstractSearchAsyncAction<TestSearchPhaseResult>(
"test",
logger,
transportService,
(cluster, node) -> {
assert cluster == null : "cluster was not null: " + cluster;
return lookup.get(node); },
aliasFilters,
Collections.emptyMap(),
null,
request,
responseListener,
shardsIter,
new TransportSearchAction.SearchTimeProvider(0, 0, () -> 0),
0,
null,
new InitialSearchPhase.SearchPhaseResults<>(shardsIter.size())) {

@Override
protected void executePhaseOnShard(SearchShardIterator shardIt, ShardRouting shard,
SearchActionListener<TestSearchPhaseResult> listener) {
seenShard.computeIfAbsent(shard.shardId(), (i) -> {
numRequests.incrementAndGet(); // only count this once per replica
return Boolean.TRUE;
});

new Thread(() -> {
try {
awaitInitialRequests.await();
} catch (InterruptedException e) {
throw new AssertionError(e);
}
Transport.Connection connection = getConnection(null, shard.currentNodeId());
TestSearchPhaseResult testSearchPhaseResult = new TestSearchPhaseResult(contextIdGenerator.incrementAndGet(),
connection.getNode());
if (numResponses.getAndIncrement() > 0 && randomBoolean()) { // at least one response otherwise the entire
// request fails
listener.onFailure(new RuntimeException());
} else {
listener.onResponse(testSearchPhaseResult);
}

}).start();
}

@Override
protected SearchPhase getNextPhase(SearchPhaseResults<TestSearchPhaseResult> results, SearchPhaseContext context) {
return new SearchPhase("test") {
@Override
public void run() throws IOException {
latch.countDown();
}
};
}
};
asyncAction.start();
assertEquals(numConcurrent, numRequests.get());
awaitInitialRequests.countDown();
latch.await();
assertEquals(10, numRequests.get());
}

public void testFanOutAndCollect() throws InterruptedException {
SearchRequest request = new SearchRequest();
request.setMaxConcurrentShardRequests(randomIntBetween(1, 100));
CountDownLatch latch = new CountDownLatch(1);
AtomicReference<TestSearchResponse> response = new AtomicReference<>();
ActionListener<SearchResponse> responseListener = new ActionListener<SearchResponse>() {