1414import org .elasticsearch .common .io .stream .StreamInput ;
1515import org .elasticsearch .common .io .stream .StreamOutput ;
1616import org .elasticsearch .common .settings .Settings ;
17+ import org .elasticsearch .core .AbstractRefCounted ;
1718import org .elasticsearch .core .IOUtils ;
19+ import org .elasticsearch .core .RefCounted ;
1820import org .elasticsearch .tasks .CancellableTask ;
1921import org .elasticsearch .tasks .Task ;
2022import org .elasticsearch .tasks .TaskId ;
2729import java .io .IOException ;
2830import java .util .Map ;
2931import java .util .concurrent .CountDownLatch ;
32+ import java .util .concurrent .atomic .AtomicReference ;
3033
3134import static org .hamcrest .Matchers .equalTo ;
35+ import static org .hamcrest .Matchers .notNullValue ;
3236
3337public class TransportActionProxyTests extends ESTestCase {
3438 protected ThreadPool threadPool ;
@@ -76,8 +80,9 @@ private MockTransportService buildService(final Version version) {
7680 public void testSendMessage () throws InterruptedException {
7781 serviceA .registerRequestHandler ("internal:test" , ThreadPool .Names .SAME , SimpleTestRequest ::new , (request , channel , task ) -> {
7882 assertEquals (request .sourceNode , "TS_A" );
79- SimpleTestResponse response = new SimpleTestResponse ("TS_A" );
83+ final SimpleTestResponse response = new SimpleTestResponse ("TS_A" );
8084 channel .sendResponse (response );
85+ assertThat (response .hasReferences (), equalTo (false ));
8186 });
8287 final boolean cancellable = randomBoolean ();
8388 TransportActionProxy .registerProxyAction (serviceA , "internal:test" , cancellable , SimpleTestResponse ::new );
@@ -86,21 +91,24 @@ public void testSendMessage() throws InterruptedException {
8691 serviceB .registerRequestHandler ("internal:test" , ThreadPool .Names .SAME , SimpleTestRequest ::new , (request , channel , task ) -> {
8792 assertThat (task instanceof CancellableTask , equalTo (cancellable ));
8893 assertEquals (request .sourceNode , "TS_A" );
89- SimpleTestResponse response = new SimpleTestResponse ("TS_B" );
94+ final SimpleTestResponse response = new SimpleTestResponse ("TS_B" );
9095 channel .sendResponse (response );
96+ assertThat (response .hasReferences (), equalTo (false ));
9197 });
9298 TransportActionProxy .registerProxyAction (serviceB , "internal:test" , cancellable , SimpleTestResponse ::new );
9399 AbstractSimpleTransportTestCase .connectToNode (serviceB , nodeC );
94100 serviceC .registerRequestHandler ("internal:test" , ThreadPool .Names .SAME , SimpleTestRequest ::new , (request , channel , task ) -> {
95101 assertThat (task instanceof CancellableTask , equalTo (cancellable ));
96102 assertEquals (request .sourceNode , "TS_A" );
97- SimpleTestResponse response = new SimpleTestResponse ("TS_C" );
103+ final SimpleTestResponse response = new SimpleTestResponse ("TS_C" );
98104 channel .sendResponse (response );
105+ assertThat (response .hasReferences (), equalTo (false ));
99106 });
100107
101108 TransportActionProxy .registerProxyAction (serviceC , "internal:test" , cancellable , SimpleTestResponse ::new );
102109
103- CountDownLatch latch = new CountDownLatch (1 );
110+ final CountDownLatch latch = new CountDownLatch (1 );
111+ // Node A -> Node B -> Node C
104112 serviceA .sendRequest (
105113 nodeB ,
106114 TransportActionProxy .getProxyAction ("internal:test" ),
@@ -133,6 +141,61 @@ public void handleException(TransportException exp) {
133141 latch .await ();
134142 }
135143
144+ public void testSendLocalRequest () throws InterruptedException {
145+ final AtomicReference <SimpleTestResponse > response = new AtomicReference <>();
146+ final boolean cancellable = randomBoolean ();
147+ serviceB .registerRequestHandler (
148+ "internal:test" ,
149+ randomFrom (ThreadPool .Names .SAME , ThreadPool .Names .GENERIC ),
150+ SimpleTestRequest ::new ,
151+ (request , channel , task ) -> {
152+ assertThat (task instanceof CancellableTask , equalTo (cancellable ));
153+ assertEquals (request .sourceNode , "TS_A" );
154+ final SimpleTestResponse responseB = new SimpleTestResponse ("TS_B" );
155+ channel .sendResponse (responseB );
156+ response .set (responseB );
157+ }
158+ );
159+ TransportActionProxy .registerProxyAction (serviceB , "internal:test" , cancellable , SimpleTestResponse ::new );
160+ AbstractSimpleTransportTestCase .connectToNode (serviceA , nodeB );
161+
162+ final CountDownLatch latch = new CountDownLatch (1 );
163+ // Node A -> Proxy Node B (Local execution)
164+ serviceA .sendRequest (
165+ nodeB ,
166+ TransportActionProxy .getProxyAction ("internal:test" ),
167+ TransportActionProxy .wrapRequest (nodeB , new SimpleTestRequest ("TS_A" , cancellable )), // Request
168+ new TransportResponseHandler <SimpleTestResponse >() {
169+ @ Override
170+ public SimpleTestResponse read (StreamInput in ) throws IOException {
171+ return new SimpleTestResponse (in );
172+ }
173+
174+ @ Override
175+ public void handleResponse (SimpleTestResponse response ) {
176+ try {
177+ assertEquals ("TS_B" , response .targetNode );
178+ } finally {
179+ latch .countDown ();
180+ }
181+ }
182+
183+ @ Override
184+ public void handleException (TransportException exp ) {
185+ try {
186+ throw new AssertionError (exp );
187+ } finally {
188+ latch .countDown ();
189+ }
190+ }
191+ }
192+ );
193+ latch .await ();
194+
195+ assertThat (response .get (), notNullValue ());
196+ assertThat (response .get ().hasReferences (), equalTo (false ));
197+ }
198+
136199 public void testException () throws InterruptedException {
137200 boolean cancellable = randomBoolean ();
138201 serviceA .registerRequestHandler ("internal:test" , ThreadPool .Names .SAME , SimpleTestRequest ::new , (request , channel , task ) -> {
@@ -230,7 +293,12 @@ public boolean shouldCancelChildrenOnCancellation() {
230293 }
231294
232295 public static class SimpleTestResponse extends TransportResponse {
296+
233297 final String targetNode ;
298+ final RefCounted refCounted = new AbstractRefCounted () {
299+ @ Override
300+ protected void closeInternal () {}
301+ };
234302
235303 SimpleTestResponse (String targetNode ) {
236304 this .targetNode = targetNode ;
@@ -245,6 +313,26 @@ public static class SimpleTestResponse extends TransportResponse {
245313 public void writeTo (StreamOutput out ) throws IOException {
246314 out .writeString (targetNode );
247315 }
316+
317+ @ Override
318+ public void incRef () {
319+ refCounted .incRef ();
320+ }
321+
322+ @ Override
323+ public boolean tryIncRef () {
324+ return refCounted .tryIncRef ();
325+ }
326+
327+ @ Override
328+ public boolean decRef () {
329+ return refCounted .decRef ();
330+ }
331+
332+ @ Override
333+ public boolean hasReferences () {
334+ return refCounted .hasReferences ();
335+ }
248336 }
249337
250338 public void testGetAction () {
0 commit comments