2424import org .elasticsearch .action .support .ActionFilter ;
2525import org .elasticsearch .action .support .ActionFilters ;
2626import org .elasticsearch .action .support .ActionTestUtils ;
27+ import org .elasticsearch .action .support .PlainActionFuture ;
2728import org .elasticsearch .client .node .NodeClient ;
2829import org .elasticsearch .cluster .ClusterName ;
2930import org .elasticsearch .cluster .ClusterState ;
3536import org .elasticsearch .common .UUIDs ;
3637import org .elasticsearch .common .settings .Settings ;
3738import org .elasticsearch .search .internal .InternalSearchResponse ;
39+ import org .elasticsearch .tasks .Task ;
3840import org .elasticsearch .tasks .TaskManager ;
3941import org .elasticsearch .test .ESTestCase ;
4042import org .elasticsearch .threadpool .TestThreadPool ;
5355import java .util .concurrent .atomic .AtomicInteger ;
5456import java .util .concurrent .atomic .AtomicReference ;
5557
58+ import static org .elasticsearch .action .support .PlainActionFuture .newFuture ;
5659import static org .hamcrest .Matchers .equalTo ;
5760import static org .hamcrest .Matchers .nullValue ;
5861import static org .mockito .Mockito .mock ;
@@ -76,7 +79,62 @@ public void tearDown() throws Exception {
7679 super .tearDown ();
7780 }
7881
79- public void testBatchExecute () throws Exception {
82+ public void testParentTaskId () throws Exception {
83+ // Initialize dependencies of TransportMultiSearchAction
84+ Settings settings = Settings .builder ()
85+ .put ("node.name" , TransportMultiSearchActionTests .class .getSimpleName ())
86+ .build ();
87+ ActionFilters actionFilters = mock (ActionFilters .class );
88+ when (actionFilters .filters ()).thenReturn (new ActionFilter [0 ]);
89+ ThreadPool threadPool = new ThreadPool (settings );
90+ try {
91+ TransportService transportService = new TransportService (Settings .EMPTY , mock (Transport .class ), threadPool ,
92+ TransportService .NOOP_TRANSPORT_INTERCEPTOR ,
93+ boundAddress -> DiscoveryNode .createLocal (settings , boundAddress .publishAddress (), UUIDs .randomBase64UUID ()), null ,
94+ Collections .emptySet ()) {
95+ @ Override
96+ public TaskManager getTaskManager () {
97+ return taskManager ;
98+ }
99+ };
100+ ClusterService clusterService = mock (ClusterService .class );
101+ when (clusterService .state ()).thenReturn (ClusterState .builder (new ClusterName ("test" )).build ());
102+
103+ String localNodeId = randomAlphaOfLengthBetween (3 , 10 );
104+ int numSearchRequests = randomIntBetween (1 , 100 );
105+ MultiSearchRequest multiSearchRequest = new MultiSearchRequest ();
106+ for (int i = 0 ; i < numSearchRequests ; i ++) {
107+ multiSearchRequest .add (new SearchRequest ());
108+ }
109+ AtomicInteger counter = new AtomicInteger (0 );
110+ Task task = multiSearchRequest .createTask (randomLong (), "type" , "action" , null , Collections .emptyMap ());
111+ NodeClient client = new NodeClient (settings , threadPool ) {
112+ @ Override
113+ public void search (final SearchRequest request , final ActionListener <SearchResponse > listener ) {
114+ assertEquals (task .getId (), request .getParentTask ().getId ());
115+ assertEquals (localNodeId , request .getParentTask ().getNodeId ());
116+ counter .incrementAndGet ();
117+ listener .onResponse (SearchResponse .empty (() -> 1L , SearchResponse .Clusters .EMPTY ));
118+ }
119+
120+ @ Override
121+ public String getLocalNodeId () {
122+ return localNodeId ;
123+ }
124+ };
125+ TransportMultiSearchAction action =
126+ new TransportMultiSearchAction (threadPool , actionFilters , transportService , clusterService , 10 , System ::nanoTime , client );
127+
128+ PlainActionFuture <MultiSearchResponse > future = newFuture ();
129+ action .execute (task , multiSearchRequest , future );
130+ future .get ();
131+ assertEquals (numSearchRequests , counter .get ());
132+ } finally {
133+ assertTrue (ESTestCase .terminate (threadPool ));
134+ }
135+ }
136+
137+ public void testBatchExecute () {
80138 // Initialize dependencies of TransportMultiSearchAction
81139 Settings settings = Settings .builder ()
82140 .put ("node.name" , TransportMultiSearchActionTests .class .getSimpleName ())
@@ -123,6 +181,11 @@ public void search(final SearchRequest request, final ActionListener<SearchRespo
123181 ShardSearchFailure .EMPTY_ARRAY , SearchResponse .Clusters .EMPTY ));
124182 });
125183 }
184+
185+ @ Override
186+ public String getLocalNodeId () {
187+ return "local_node_id" ;
188+ }
126189 };
127190
128191 TransportMultiSearchAction action =
0 commit comments