diff --git a/server/src/main/java/org/elasticsearch/indices/recovery/RecoverySourceHandler.java b/server/src/main/java/org/elasticsearch/indices/recovery/RecoverySourceHandler.java index d2d03156271cd..34434f50b456f 100644 --- a/server/src/main/java/org/elasticsearch/indices/recovery/RecoverySourceHandler.java +++ b/server/src/main/java/org/elasticsearch/indices/recovery/RecoverySourceHandler.java @@ -33,9 +33,9 @@ import org.elasticsearch.Version; import org.elasticsearch.action.ActionListener; import org.elasticsearch.action.StepListener; -import org.elasticsearch.action.support.PlainActionFuture; import org.elasticsearch.cluster.routing.IndexShardRoutingTable; import org.elasticsearch.cluster.routing.ShardRouting; +import org.elasticsearch.common.CheckedSupplier; import org.elasticsearch.common.StopWatch; import org.elasticsearch.common.bytes.BytesArray; import org.elasticsearch.common.collect.Tuple; @@ -71,7 +71,7 @@ import java.util.Locale; import java.util.concurrent.CompletableFuture; import java.util.concurrent.CopyOnWriteArrayList; -import java.util.concurrent.atomic.AtomicLong; +import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicReference; import java.util.function.Consumer; import java.util.function.Supplier; @@ -514,13 +514,6 @@ TimeValue prepareTargetForTranslog(final boolean fileBasedRecovery, final int to */ void phase2(long startingSeqNo, long requiredSeqNoRangeStart, long endingSeqNo, Translog.Snapshot snapshot, long maxSeenAutoIdTimestamp, long maxSeqNoOfUpdatesOrDeletes, ActionListener listener) throws IOException { - ActionListener.completeWith(listener, () -> sendSnapshotBlockingly( - startingSeqNo, requiredSeqNoRangeStart, endingSeqNo, snapshot, maxSeenAutoIdTimestamp, maxSeqNoOfUpdatesOrDeletes)); - } - - private SendSnapshotResult sendSnapshotBlockingly(long startingSeqNo, long requiredSeqNoRangeStart, long endingSeqNo, - Translog.Snapshot snapshot, long maxSeenAutoIdTimestamp, - long maxSeqNoOfUpdatesOrDeletes) throws IOException { assert requiredSeqNoRangeStart <= endingSeqNo + 1: "requiredSeqNoRangeStart " + requiredSeqNoRangeStart + " is larger than endingSeqNo " + endingSeqNo; assert startingSeqNo <= requiredSeqNoRangeStart : @@ -528,83 +521,87 @@ private SendSnapshotResult sendSnapshotBlockingly(long startingSeqNo, long requi if (shard.state() == IndexShardState.CLOSED) { throw new IndexShardClosedException(request.shardId()); } - - final StopWatch stopWatch = new StopWatch().start(); - logger.trace("recovery [phase2]: sending transaction log operations (seq# from [" + startingSeqNo + "], " + "required [" + requiredSeqNoRangeStart + ":" + endingSeqNo + "]"); - int ops = 0; - long size = 0; - int skippedOps = 0; - int totalSentOps = 0; - final AtomicLong targetLocalCheckpoint = new AtomicLong(SequenceNumbers.UNASSIGNED_SEQ_NO); - final List operations = new ArrayList<>(); + final AtomicInteger skippedOps = new AtomicInteger(); + final AtomicInteger totalSentOps = new AtomicInteger(); final LocalCheckpointTracker requiredOpsTracker = new LocalCheckpointTracker(endingSeqNo, requiredSeqNoRangeStart - 1); + final AtomicInteger lastBatchCount = new AtomicInteger(); // used to estimate the count of the subsequent batch. + final CheckedSupplier, IOException> readNextBatch = () -> { + // We need to synchronized Snapshot#next() because it's called by different threads through sendBatch. + // Even though those calls are not concurrent, Snapshot#next() uses non-synchronized state and is not multi-thread-compatible. + synchronized (snapshot) { + final List ops = lastBatchCount.get() > 0 ? new ArrayList<>(lastBatchCount.get()) : new ArrayList<>(); + long batchSizeInBytes = 0L; + Translog.Operation operation; + while ((operation = snapshot.next()) != null) { + if (shard.state() == IndexShardState.CLOSED) { + throw new IndexShardClosedException(request.shardId()); + } + cancellableThreads.checkForCancel(); + final long seqNo = operation.seqNo(); + if (seqNo < startingSeqNo || seqNo > endingSeqNo) { + skippedOps.incrementAndGet(); + continue; + } + ops.add(operation); + batchSizeInBytes += operation.estimateSize(); + totalSentOps.incrementAndGet(); + requiredOpsTracker.markSeqNoAsCompleted(seqNo); - final int expectedTotalOps = snapshot.totalOperations(); - if (expectedTotalOps == 0) { - logger.trace("no translog operations to send"); - } - - final CancellableThreads.IOInterruptible sendBatch = () -> { - // TODO: Make this non-blocking - final PlainActionFuture future = new PlainActionFuture<>(); - recoveryTarget.indexTranslogOperations( - operations, expectedTotalOps, maxSeenAutoIdTimestamp, maxSeqNoOfUpdatesOrDeletes, future); - targetLocalCheckpoint.set(future.actionGet()); - }; - - // send operations in batches - Translog.Operation operation; - while ((operation = snapshot.next()) != null) { - if (shard.state() == IndexShardState.CLOSED) { - throw new IndexShardClosedException(request.shardId()); - } - cancellableThreads.checkForCancel(); - - final long seqNo = operation.seqNo(); - if (seqNo < startingSeqNo || seqNo > endingSeqNo) { - skippedOps++; - continue; - } - operations.add(operation); - ops++; - size += operation.estimateSize(); - totalSentOps++; - requiredOpsTracker.markSeqNoAsCompleted(seqNo); - - // check if this request is past bytes threshold, and if so, send it off - if (size >= chunkSizeInBytes) { - cancellableThreads.executeIO(sendBatch); - logger.trace("sent batch of [{}][{}] (total: [{}]) translog operations", ops, new ByteSizeValue(size), expectedTotalOps); - ops = 0; - size = 0; - operations.clear(); + // check if this request is past bytes threshold, and if so, send it off + if (batchSizeInBytes >= chunkSizeInBytes) { + break; + } + } + lastBatchCount.set(ops.size()); + return ops; } - } - - if (!operations.isEmpty() || totalSentOps == 0) { - // send the leftover operations or if no operations were sent, request the target to respond with its local checkpoint - cancellableThreads.executeIO(sendBatch); - } + }; - assert expectedTotalOps == snapshot.skippedOperations() + skippedOps + totalSentOps - : String.format(Locale.ROOT, "expected total [%d], overridden [%d], skipped [%d], total sent [%d]", - expectedTotalOps, snapshot.skippedOperations(), skippedOps, totalSentOps); + final StopWatch stopWatch = new StopWatch().start(); + final ActionListener batchedListener = ActionListener.wrap( + targetLocalCheckpoint -> { + assert snapshot.totalOperations() == snapshot.skippedOperations() + skippedOps.get() + totalSentOps.get() + : String.format(Locale.ROOT, "expected total [%d], overridden [%d], skipped [%d], total sent [%d]", + snapshot.totalOperations(), snapshot.skippedOperations(), skippedOps.get(), totalSentOps.get()); + if (requiredOpsTracker.getCheckpoint() < endingSeqNo) { + throw new IllegalStateException("translog replay failed to cover required sequence numbers" + + " (required range [" + requiredSeqNoRangeStart + ":" + endingSeqNo + "). first missing op is [" + + (requiredOpsTracker.getCheckpoint() + 1) + "]"); + } + stopWatch.stop(); + final TimeValue tookTime = stopWatch.totalTime(); + logger.trace("recovery [phase2]: took [{}]", tookTime); + listener.onResponse(new SendSnapshotResult(targetLocalCheckpoint, totalSentOps.get(), tookTime)); + }, + listener::onFailure + ); + + sendBatch(readNextBatch, true, SequenceNumbers.UNASSIGNED_SEQ_NO, snapshot.totalOperations(), + maxSeenAutoIdTimestamp, maxSeqNoOfUpdatesOrDeletes, batchedListener); + } - if (requiredOpsTracker.getCheckpoint() < endingSeqNo) { - throw new IllegalStateException("translog replay failed to cover required sequence numbers" + - " (required range [" + requiredSeqNoRangeStart + ":" + endingSeqNo + "). first missing op is [" - + (requiredOpsTracker.getCheckpoint() + 1) + "]"); + private void sendBatch(CheckedSupplier, IOException> nextBatch, boolean firstBatch, + long targetLocalCheckpoint, int totalTranslogOps, long maxSeenAutoIdTimestamp, + long maxSeqNoOfUpdatesOrDeletes, ActionListener listener) throws IOException { + final List operations = nextBatch.get(); + // send the leftover operations or if no operations were sent, request the target to respond with its local checkpoint + if (operations.isEmpty() == false || firstBatch) { + cancellableThreads.execute(() -> { + recoveryTarget.indexTranslogOperations(operations, totalTranslogOps, maxSeenAutoIdTimestamp, maxSeqNoOfUpdatesOrDeletes, + ActionListener.wrap( + newCheckpoint -> { + sendBatch(nextBatch, false, SequenceNumbers.max(targetLocalCheckpoint, newCheckpoint), + totalTranslogOps, maxSeenAutoIdTimestamp, maxSeqNoOfUpdatesOrDeletes, listener); + }, + listener::onFailure + )); + }); + } else { + listener.onResponse(targetLocalCheckpoint); } - - logger.trace("sent final batch of [{}][{}] (total: [{}]) translog operations", ops, new ByteSizeValue(size), expectedTotalOps); - - stopWatch.stop(); - final TimeValue tookTime = stopWatch.totalTime(); - logger.trace("recovery [phase2]: took [{}]", tookTime); - return new SendSnapshotResult(targetLocalCheckpoint.get(), totalSentOps, tookTime); } void finalizeRecovery(final long targetLocalCheckpoint, final ActionListener listener) throws IOException { diff --git a/server/src/test/java/org/elasticsearch/indices/recovery/RecoverySourceHandlerTests.java b/server/src/test/java/org/elasticsearch/indices/recovery/RecoverySourceHandlerTests.java index 97f2cadfa3a5d..0cecc925b2488 100644 --- a/server/src/test/java/org/elasticsearch/indices/recovery/RecoverySourceHandlerTests.java +++ b/server/src/test/java/org/elasticsearch/indices/recovery/RecoverySourceHandlerTests.java @@ -76,6 +76,10 @@ import org.elasticsearch.test.DummyShardLock; import org.elasticsearch.test.ESTestCase; import org.elasticsearch.test.IndexSettingsModule; +import org.elasticsearch.threadpool.TestThreadPool; +import org.elasticsearch.threadpool.ThreadPool; +import org.junit.After; +import org.junit.Before; import java.io.IOException; import java.io.OutputStream; @@ -115,6 +119,18 @@ public class RecoverySourceHandlerTests extends ESTestCase { private final ShardId shardId = new ShardId(INDEX_SETTINGS.getIndex(), 1); private final ClusterSettings service = new ClusterSettings(Settings.EMPTY, ClusterSettings.BUILT_IN_CLUSTER_SETTINGS); + private ThreadPool threadPool; + + @Before + public void setUpThreadPool() { + threadPool = new TestThreadPool(getTestName()); + } + + @After + public void tearDownThreadPool() { + terminate(threadPool); + } + public void testSendFiles() throws Throwable { Settings settings = Settings.builder().put("indices.recovery.concurrent_streams", 1). put("indices.recovery.concurrent_small_file_streams", 1).build(); @@ -198,18 +214,17 @@ public StartRecoveryRequest getStartRecoveryRequest() throws IOException { } public void testSendSnapshotSendsOps() throws IOException { - final RecoverySettings recoverySettings = new RecoverySettings(Settings.EMPTY, service); - final int fileChunkSizeInBytes = recoverySettings.getChunkSize().bytesAsInt(); + final int fileChunkSizeInBytes = between(1, 4096); final StartRecoveryRequest request = getStartRecoveryRequest(); final IndexShard shard = mock(IndexShard.class); when(shard.state()).thenReturn(IndexShardState.STARTED); final List operations = new ArrayList<>(); - final int initialNumberOfDocs = randomIntBetween(16, 64); + final int initialNumberOfDocs = randomIntBetween(10, 1000); for (int i = 0; i < initialNumberOfDocs; i++) { final Engine.Index index = getIndex(Integer.toString(i)); operations.add(new Translog.Index(index, new Engine.IndexResult(1, 1, SequenceNumbers.UNASSIGNED_SEQ_NO, true))); } - final int numberOfDocsWithValidSequenceNumbers = randomIntBetween(16, 64); + final int numberOfDocsWithValidSequenceNumbers = randomIntBetween(10, 1000); for (int i = initialNumberOfDocs; i < initialNumberOfDocs + numberOfDocsWithValidSequenceNumbers; i++) { final Engine.Index index = getIndex(Integer.toString(i)); operations.add(new Translog.Index(index, new Engine.IndexResult(1, 1, i - initialNumberOfDocs, true))); @@ -219,12 +234,14 @@ public void testSendSnapshotSendsOps() throws IOException { final long endingSeqNo = randomIntBetween((int) requiredStartingSeqNo - 1, numberOfDocsWithValidSequenceNumbers - 1); final List shippedOps = new ArrayList<>(); + final AtomicLong checkpointOnTarget = new AtomicLong(SequenceNumbers.NO_OPS_PERFORMED); RecoveryTargetHandler recoveryTarget = new TestRecoveryTargetHandler() { @Override public void indexTranslogOperations(List operations, int totalTranslogOps, long timestamp, long msu, ActionListener listener) { shippedOps.addAll(operations); - listener.onResponse(SequenceNumbers.NO_OPS_PERFORMED); + checkpointOnTarget.set(randomLongBetween(checkpointOnTarget.get(), Long.MAX_VALUE)); + maybeExecuteAsync(() -> listener.onResponse(checkpointOnTarget.get())); } }; RecoverySourceHandler handler = new RecoverySourceHandler(shard, recoveryTarget, request, fileChunkSizeInBytes, between(1, 10)); @@ -239,6 +256,7 @@ public void indexTranslogOperations(List operations, int tot for (int i = 0; i < shippedOps.size(); i++) { assertThat(shippedOps.get(i), equalTo(operations.get(i + (int) startingSeqNo + initialNumberOfDocs))); } + assertThat(result.targetLocalCheckpoint, equalTo(checkpointOnTarget.get())); if (endingSeqNo >= requiredStartingSeqNo + 1) { // check that missing ops blows up List requiredOps = operations.subList(0, operations.size() - 1).stream() // remove last null marker @@ -253,6 +271,40 @@ public void indexTranslogOperations(List operations, int tot } } + public void testSendSnapshotStopOnError() throws Exception { + final int fileChunkSizeInBytes = between(1, 10 * 1024); + final StartRecoveryRequest request = getStartRecoveryRequest(); + final IndexShard shard = mock(IndexShard.class); + when(shard.state()).thenReturn(IndexShardState.STARTED); + final List ops = new ArrayList<>(); + for (int numOps = between(1, 256), i = 0; i < numOps; i++) { + final Engine.Index index = getIndex(Integer.toString(i)); + ops.add(new Translog.Index(index, new Engine.IndexResult(1, 1, i, true))); + } + final AtomicBoolean wasFailed = new AtomicBoolean(); + RecoveryTargetHandler recoveryTarget = new TestRecoveryTargetHandler() { + @Override + public void indexTranslogOperations(List operations, int totalTranslogOps, long timestamp, + long msu, ActionListener listener) { + if (randomBoolean()) { + maybeExecuteAsync(() -> listener.onResponse(SequenceNumbers.NO_OPS_PERFORMED)); + } else { + maybeExecuteAsync(() -> listener.onFailure(new RuntimeException("test - failed to index"))); + wasFailed.set(true); + } + } + }; + RecoverySourceHandler handler = new RecoverySourceHandler(shard, recoveryTarget, request, fileChunkSizeInBytes, between(1, 10)); + PlainActionFuture future = new PlainActionFuture<>(); + final long startingSeqNo = randomLongBetween(0, ops.size() - 1L); + final long endingSeqNo = randomLongBetween(startingSeqNo, ops.size() - 1L); + handler.phase2(startingSeqNo, startingSeqNo, endingSeqNo, newTranslogSnapshot(ops, Collections.emptyList()), + randomNonNegativeLong(), randomNonNegativeLong(), future); + if (wasFailed.get()) { + assertThat(expectThrows(RuntimeException.class, () -> future.actionGet()).getMessage(), equalTo("test - failed to index")); + } + } + private Engine.Index getIndex(final String id) { final String type = "test"; final ParseContext.Document document = new ParseContext.Document(); @@ -717,4 +769,12 @@ public void close() { } }; } + + private void maybeExecuteAsync(Runnable runnable) { + if (randomBoolean()) { + threadPool.generic().execute(runnable); + } else { + runnable.run(); + } + } }