diff --git a/server/src/main/java/org/elasticsearch/action/search/TransportMultiSearchAction.java b/server/src/main/java/org/elasticsearch/action/search/TransportMultiSearchAction.java index f03e1fd4dd1c9..66aa15c569279 100644 --- a/server/src/main/java/org/elasticsearch/action/search/TransportMultiSearchAction.java +++ b/server/src/main/java/org/elasticsearch/action/search/TransportMultiSearchAction.java @@ -86,6 +86,7 @@ protected void doExecute(Task task, MultiSearchRequest request, ActionListener searchRequestSlots = new ConcurrentLinkedQueue<>(); for (int i = 0; i < request.requests().size(); i++) { SearchRequest searchRequest = request.requests().get(i); + searchRequest.setParentTask(client.getLocalNodeId(), task.getId()); searchRequestSlots.add(new SearchRequestSlot(searchRequest, i)); } diff --git a/server/src/test/java/org/elasticsearch/action/search/MultiSearchActionTookTests.java b/server/src/test/java/org/elasticsearch/action/search/MultiSearchActionTookTests.java index 01f1109ef3bed..19b53e2f8d380 100644 --- a/server/src/test/java/org/elasticsearch/action/search/MultiSearchActionTookTests.java +++ b/server/src/test/java/org/elasticsearch/action/search/MultiSearchActionTookTests.java @@ -159,6 +159,11 @@ public void search(final SearchRequest request, final ActionListener DiscoveryNode.createLocal(settings, boundAddress.publishAddress(), UUIDs.randomBase64UUID()), null, + Collections.emptySet()) { + @Override + public TaskManager getTaskManager() { + return taskManager; + } + }; + ClusterService clusterService = mock(ClusterService.class); + when(clusterService.state()).thenReturn(ClusterState.builder(new ClusterName("test")).build()); + + String localNodeId = randomAlphaOfLengthBetween(3, 10); + int numSearchRequests = randomIntBetween(1, 100); + MultiSearchRequest multiSearchRequest = new MultiSearchRequest(); + for (int i = 0; i < numSearchRequests; i++) { + multiSearchRequest.add(new SearchRequest()); + } + AtomicInteger counter = new AtomicInteger(0); + Task task = multiSearchRequest.createTask(randomLong(), "type", "action", null, Collections.emptyMap()); + NodeClient client = new NodeClient(settings, threadPool) { + @Override + public void search(final SearchRequest request, final ActionListener listener) { + assertEquals(task.getId(), request.getParentTask().getId()); + assertEquals(localNodeId, request.getParentTask().getNodeId()); + counter.incrementAndGet(); + listener.onResponse(SearchResponse.empty(() -> 1L, SearchResponse.Clusters.EMPTY)); + } + + @Override + public String getLocalNodeId() { + return localNodeId; + } + }; + TransportMultiSearchAction action = + new TransportMultiSearchAction(threadPool, actionFilters, transportService, clusterService, 10, System::nanoTime, client); + + PlainActionFuture future = newFuture(); + action.execute(task, multiSearchRequest, future); + future.get(); + assertEquals(numSearchRequests, counter.get()); + } finally { + assertTrue(ESTestCase.terminate(threadPool)); + } + } + + public void testBatchExecute() { // Initialize dependencies of TransportMultiSearchAction Settings settings = Settings.builder() .put("node.name", TransportMultiSearchActionTests.class.getSimpleName()) @@ -123,6 +181,11 @@ public void search(final SearchRequest request, final ActionListener