diff --git a/server/src/main/java/org/elasticsearch/action/bulk/TransportShardBulkAction.java b/server/src/main/java/org/elasticsearch/action/bulk/TransportShardBulkAction.java
index e709c24bc64c7..c371061561614 100644
--- a/server/src/main/java/org/elasticsearch/action/bulk/TransportShardBulkAction.java
+++ b/server/src/main/java/org/elasticsearch/action/bulk/TransportShardBulkAction.java
@@ -53,6 +53,8 @@
 import org.elasticsearch.common.io.stream.StreamInput;
 import org.elasticsearch.common.settings.Settings;
 import org.elasticsearch.common.unit.TimeValue;
+import org.elasticsearch.common.util.concurrent.EsRejectedExecutionException;
+import org.elasticsearch.common.util.concurrent.RunOnce;
 import org.elasticsearch.common.xcontent.ToXContent;
 import org.elasticsearch.common.xcontent.XContentHelper;
 import org.elasticsearch.common.xcontent.XContentType;
@@ -75,6 +77,7 @@
 import java.io.IOException;
 import java.util.Map;
 import java.util.concurrent.Executor;
+import java.util.concurrent.atomic.AtomicInteger;
 import java.util.function.Consumer;
 import java.util.function.LongSupplier;
@@ -88,6 +91,7 @@ public class TransportShardBulkAction extends TransportWriteAction<BulkShardRequest, BulkShardRequest, BulkShardResponse> {
+    private final AtomicInteger activeOperations = new AtomicInteger();
+
+    @Override
+    protected Runnable beforePrimary() {
+        RunOnce release = new RunOnce(activeOperations::decrementAndGet);
+        int active = activeOperations.incrementAndGet();
+        if (active > max) {
+            release.run();
+            throw new EsRejectedExecutionException("rejected executing primary bulk operation on " + ThreadPool.Names.WRITE +
+                " has " + (active - 1) + " active operations");
+        }
+        return release;
+    }
+
+    @Override
+    protected Runnable beforeReplica() {
+        RunOnce release = new RunOnce(activeOperations::decrementAndGet);
+        activeOperations.incrementAndGet();
+        return release;
+    }
+
     public static void performOnPrimary(
         BulkShardRequest request,
         IndexShard primary,
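For reference, here is a minimal standalone sketch of the accounting that beforePrimary/beforeReplica introduce above: bump a shared counter before the operation runs, hand back a run-once release that decrements it exactly once on completion, and reject on the primary side when the in-flight count exceeds a limit. Everything in the sketch is illustrative rather than the PR's actual wiring: ActiveOperationTracker, acquire and maxActiveOperations are invented names, plain JDK types stand in for RunOnce and EsRejectedExecutionException, and the limit of pool size plus queue size is an assumption that mirrors the maxActive value computed by the test below.

import java.util.concurrent.RejectedExecutionException;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger;

// Illustrative sketch only (not Elasticsearch code): bounded accounting of in-flight write operations.
class ActiveOperationTracker {
    private final AtomicInteger activeOperations = new AtomicInteger();
    private final int maxActiveOperations; // assumed limit: write pool size + queue size

    ActiveOperationTracker(int poolSize, int queueSize) {
        this.maxActiveOperations = poolSize + queueSize;
    }

    // Reserve a slot before running an operation; run the returned Runnable exactly once when it completes.
    Runnable acquire() {
        AtomicBoolean released = new AtomicBoolean();          // stand-in for RunOnce
        Runnable release = () -> {
            if (released.compareAndSet(false, true)) {
                activeOperations.decrementAndGet();
            }
        };
        int active = activeOperations.incrementAndGet();
        if (active > maxActiveOperations) {
            release.run();                                     // undo the reservation before rejecting
            throw new RejectedExecutionException(
                "rejected operation, " + (active - 1) + " already active (limit " + maxActiveOperations + ")");
        }
        return release;
    }

    public static void main(String[] args) {
        ActiveOperationTracker tracker = new ActiveOperationTracker(2, 1); // limit = 3
        Runnable r1 = tracker.acquire();
        Runnable r2 = tracker.acquire();
        Runnable r3 = tracker.acquire();
        try {
            tracker.acquire();                                 // fourth concurrent operation is rejected
        } catch (RejectedExecutionException e) {
            System.out.println(e.getMessage());
        }
        r1.run();
        r2.run();
        r3.run();                                              // releasing frees the slots again
    }
}

Making the release a run-once wrapper keeps the counter correct even if it happens to be invoked more than once, as in the rejection path where it is run eagerly before the exception is thrown.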
diff --git a/server/src/main/java/org/elasticsearch/action/support/replication/TransportReplicationAction.java b/server/src/main/java/org/elasticsearch/action/support/replication/TransportReplicationAction.java
index 81402ffead304..ba2a4bd2b45d1 100644
--- a/server/src/main/java/org/elasticsearch/action/support/replication/TransportReplicationAction.java
+++ b/server/src/main/java/org/elasticsearch/action/support/replication/TransportReplicationAction.java
@@ -250,9 +250,31 @@ protected void handleOperationRequest(final Request request, final TransportChannel channel, final Task task) {
         execute(task, request, new ChannelActionListener<>(channel, actionName, request));
     }
 
-    protected void handlePrimaryRequest(final ConcreteShardRequest<Request> request, final TransportChannel channel, final Task task) {
-        new AsyncPrimaryAction(
-            request, new ChannelActionListener<>(channel, transportPrimaryAction, request), (ReplicationTask) task).run();
+    final void handlePrimaryRequest(final ConcreteShardRequest<Request> request, final TransportChannel channel, final Task task) {
+        ActionListener<Response> onCompletionListener = new ChannelActionListener<>(channel, transportPrimaryAction, request);
+        try {
+            Runnable afterPrimary = beforePrimary();
+            if (afterPrimary != null) {
+                onCompletionListener = ActionListener.runAfter(onCompletionListener, afterPrimary);
+            }
+        } catch (RuntimeException e) {
+            onCompletionListener.onFailure(e);
+            return;
+        }
+        try {
+            new AsyncPrimaryAction(
+                request, onCompletionListener, (ReplicationTask) task).run();
+        } catch (RuntimeException e) {
+            onCompletionListener.onFailure(e);
+        }
+    }
+
+    /**
+     * Override this to execute code before (and after) the primary action.
+     * @return runnable to invoke after completion of the primary action. Returning null means there is nothing to run afterwards.
+     */
+    protected Runnable beforePrimary() {
+        return null;
     }
 
     class AsyncPrimaryAction extends AbstractRunnable {
@@ -467,8 +489,30 @@ public void runPostReplicaActions(ActionListener<Void> listener) {
     protected void handleReplicaRequest(final ConcreteReplicaRequest<ReplicaRequest> replicaRequest, final TransportChannel channel, final Task task) {
-        new AsyncReplicaAction(
-            replicaRequest, new ChannelActionListener<>(channel, transportReplicaAction, replicaRequest), (ReplicationTask) task).run();
+        ActionListener<ReplicaResponse> onCompletionListener = new ChannelActionListener<>(channel, transportReplicaAction, replicaRequest);
+        try {
+            Runnable afterReplica = beforeReplica();
+            if (afterReplica != null) {
+                onCompletionListener = ActionListener.runAfter(onCompletionListener, afterReplica);
+            }
+        } catch (RuntimeException e) {
+            onCompletionListener.onFailure(e);
+            return;
+        }
+        try {
+            new AsyncReplicaAction(
+                replicaRequest, onCompletionListener, (ReplicationTask) task).run();
+        } catch (RuntimeException e) {
+            onCompletionListener.onFailure(e);
+        }
+    }
+
+    /**
+     * Override this to execute code before (and after) the replica action.
+     * @return runnable to invoke after completion of the replica action. Returning null means there is nothing to run afterwards.
+     */
+    protected Runnable beforeReplica() {
+        return null;
     }
 
     public static class RetryOnReplicaException extends ElasticsearchException {
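The wiring in handlePrimaryRequest/handleReplicaRequest follows one pattern: run the hook first (it may veto the request by throwing, in which case the channel listener is failed and the action never starts), and if the hook returns a non-null Runnable, wrap the completion listener so that the Runnable runs exactly once after the response or the failure has been delivered, including when dispatching the async action itself throws. A compact sketch of that pattern using plain JDK types follows; HookedAction, beforeAction, handle and run are illustrative names, not the Elasticsearch API, and BiConsumer stands in for ActionListener.

import java.util.function.BiConsumer;

// Illustrative sketch only: a before-hook that can veto, plus a guaranteed after-hook on completion.
abstract class HookedAction<Req, Resp> {

    // Default hook: nothing to do before, nothing to run after (mirrors the base class returning null).
    protected Runnable beforeAction(Req request) {
        return null;
    }

    final void handle(Req request, BiConsumer<Resp, Exception> onCompletion) {
        BiConsumer<Resp, Exception> listener = onCompletion;
        final Runnable after;
        try {
            after = beforeAction(request);
        } catch (RuntimeException e) {
            onCompletion.accept(null, e);      // hook vetoed the request; the action never runs
            return;
        }
        if (after != null) {
            listener = (resp, err) -> {        // run the after-hook on success and on failure alike
                try {
                    onCompletion.accept(resp, err);
                } finally {
                    after.run();
                }
            };
        }
        try {
            run(request, listener);
        } catch (RuntimeException e) {
            listener.accept(null, e);          // a failed dispatch still releases through the wrapper
        }
    }

    protected abstract void run(Req request, BiConsumer<Resp, Exception> listener);
}

In the real code, ActionListener.runAfter plays the role of the try/finally wrapper: the after-hook runs whether onResponse or onFailure is delivered.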
diff --git a/server/src/test/java/org/elasticsearch/action/bulk/BulkRejectionSingleNodeIT.java b/server/src/test/java/org/elasticsearch/action/bulk/BulkRejectionSingleNodeIT.java
new file mode 100644
index 0000000000000..8b04c2de84cf9
--- /dev/null
+++ b/server/src/test/java/org/elasticsearch/action/bulk/BulkRejectionSingleNodeIT.java
@@ -0,0 +1,150 @@
+/*
+ * Licensed to Elasticsearch under one or more contributor
+ * license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright
+ * ownership. Elasticsearch licenses this file to you under
+ * the Apache License, Version 2.0 (the "License"); you may
+ * not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.elasticsearch.action.bulk;
+
+import org.elasticsearch.ExceptionsHelper;
+import org.elasticsearch.action.ActionListener;
+import org.elasticsearch.action.index.IndexRequest;
+import org.elasticsearch.cluster.ClusterState;
+import org.elasticsearch.cluster.ClusterStateUpdateTask;
+import org.elasticsearch.cluster.service.ClusterService;
+import org.elasticsearch.common.settings.Settings;
+import org.elasticsearch.common.unit.TimeValue;
+import org.elasticsearch.common.util.concurrent.EsRejectedExecutionException;
+import org.elasticsearch.test.ESSingleNodeTestCase;
+import org.elasticsearch.threadpool.ThreadPool;
+import org.elasticsearch.threadpool.ThreadPoolStats;
+import org.hamcrest.Matchers;
+
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.concurrent.CountDownLatch;
+import java.util.concurrent.TimeUnit;
+import java.util.concurrent.atomic.AtomicInteger;
+import java.util.stream.StreamSupport;
+
+import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertAcked;
+import static org.hamcrest.Matchers.greaterThanOrEqualTo;
+
+public class BulkRejectionSingleNodeIT extends ESSingleNodeTestCase {
+    @Override
+    protected Settings nodeSettings() {
+        return Settings.builder()
+            .put("thread_pool.write.queue_size", randomIntBetween(2, 20))
+            .build();
+    }
+
+    private class NoOverflowCountDownLatch {
+        private CountDownLatch latch = new CountDownLatch(1);
+        private AtomicInteger counter;
+
+        private NoOverflowCountDownLatch(int count) {
+            this.counter = new AtomicInteger(count);
+        }
+
+        public void countDown() {
+            int value = counter.decrementAndGet();
+            assertThat(value, greaterThanOrEqualTo(0));
+            if (value == 0) {
+                latch.countDown();
+            }
+        }
+
+        public void await(TimeValue waitTime) throws InterruptedException {
+            assertTrue(latch.await(waitTime.millis(), TimeUnit.MILLISECONDS));
+        }
+    }
+
+    public void testBulkRejectionOnWaitingForClusterStateUpdate() throws Exception {
+        final String index = "test";
+        assertAcked(client().admin().indices().prepareCreate(index));
+        ThreadPool threadPool = getInstanceFromNode(ThreadPool.class);
+        ThreadPool.Info info = threadPool.info(ThreadPool.Names.WRITE);
+        int maxActive = Math.toIntExact(info.getMax() + info.getQueueSize().getSingles());
+        int requests = maxActive + randomIntBetween(1, 100);
+        logger.info("maxActive {}, requests {}", maxActive, requests);
+        NoOverflowCountDownLatch completed = new NoOverflowCountDownLatch(requests);
+        NoOverflowCountDownLatch rejected = new NoOverflowCountDownLatch(requests - maxActive);
+        CountDownLatch masterWaiting = new CountDownLatch(1);
+        CountDownLatch releaseMaster = new CountDownLatch(1);
+        ClusterService clusterService = getInstanceFromNode(ClusterService.class);
+        clusterService.submitStateUpdateTask("test", new ClusterStateUpdateTask() {
+            @Override
+            public ClusterState execute(ClusterState currentState) throws Exception {
+                masterWaiting.countDown();
+                assertTrue(releaseMaster.await(10, TimeUnit.SECONDS));
+                return currentState;
+            }
+
+            @Override
+            public void onFailure(String source, Exception e) {
+                fail();
+            }
+        });
+        try {
+            ActionListener<BulkResponse> listener = new ActionListener<>() {
+                @Override
+                public void onResponse(BulkResponse bulkItemResponses) {
+                    completed.countDown();
+                    Arrays.stream(bulkItemResponses.getItems()).filter(BulkItemResponse::isFailed)
+                        .forEach(r -> {
+                            assertThat(ExceptionsHelper.unwrapCause(r.getFailure().getCause()),
+                                Matchers.instanceOf(EsRejectedExecutionException.class));
+                            rejected.countDown();
+                        });
+                }
+
+                @Override
+                public void onFailure(Exception e) {
+                    completed.countDown();
+                    assertThat(ExceptionsHelper.unwrapCause(e), Matchers.instanceOf(EsRejectedExecutionException.class));
+                    rejected.countDown();
+                }
+            };
+            for (int i = 0; i < requests; ++i) {
+                final BulkRequest request = new BulkRequest();
+                request.add(new IndexRequest(index).source(Collections.singletonMap("key", "valuea" + i)));
+                waitEmpty(threadPool, ThreadPool.Names.WRITE);
+                client().bulk(request, listener);
+            }
+
+            rejected.await(TimeValue.timeValueSeconds(10));
+        } finally {
+            releaseMaster.countDown();
+        }
+        completed.await(TimeValue.timeValueSeconds(10));
+    }
+
+    private void waitEmpty(ThreadPool threadPool, String name) throws InterruptedException {
+        long begin = System.currentTimeMillis();
+        long sleep = 0;
+        while (true) {
+            ThreadPoolStats stats = threadPool.stats();
+            if (StreamSupport.stream(stats.spliterator(), false).filter(s -> s.getName().equals(name))
+                .anyMatch(s -> s.getQueue() == 0)) {
+                return;
+            }
+            if (System.currentTimeMillis() > (begin + 10000)) {
+                fail("Waiting for empty queue timed out: " + name);
+            }
+            Thread.sleep(sleep);
+            sleep += 10;
+        }
+    }
+}
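As a worked example of the test's arithmetic (the write pool size is node-dependent, so these numbers are only illustrative): with 8 write threads and a randomized thread_pool.write.queue_size of 10, maxActive is 18. If the random overshoot is 7, the test sends 25 single-item bulk requests while the master is blocked and expects exactly 25 - 18 = 7 of them to fail with EsRejectedExecutionException, whether the rejection surfaces as a failed item inside a BulkResponse or as an onFailure of the whole request; the rejected latch is sized to requests - maxActive, and its no-overflow check trips if any extra rejection arrives. All 25 listeners must still complete once releaseMaster lets the blocked cluster state update through. The per-request waitEmpty call keeps the write queue drained between submissions, so the rejections exercised here come from the new active-operation accounting rather than from ordinary queue overflow.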