5454import org .junit .BeforeClass ;
5555
5656import java .util .ArrayList ;
57+ import java .util .Arrays ;
5758import java .util .Collections ;
5859import java .util .HashMap ;
5960import java .util .HashSet ;
@@ -177,6 +178,8 @@ public void testThreadContext() throws InterruptedException {
177178
178179 try (ThreadContext .StoredContext ignored = threadPool .getThreadContext ().stashContext ()) {
179180 final Map <String , String > expectedHeaders = Collections .singletonMap ("test" , "test" );
181+ final Map <String , List <String >> expectedResponseHeaders = Collections .singletonMap ("testResponse" ,
182+ Arrays .asList ("testResponse" ));
180183 threadPool .getThreadContext ().putHeader (expectedHeaders );
181184
182185 final TimeValue ackTimeout = randomBoolean () ? TimeValue .ZERO : TimeValue .timeValueMillis (randomInt (10000 ));
@@ -187,6 +190,8 @@ public void testThreadContext() throws InterruptedException {
187190 public ClusterState execute (ClusterState currentState ) {
188191 assertTrue (threadPool .getThreadContext ().isSystemContext ());
189192 assertEquals (Collections .emptyMap (), threadPool .getThreadContext ().getHeaders ());
193+ threadPool .getThreadContext ().addResponseHeader ("testResponse" , "testResponse" );
194+ assertEquals (expectedResponseHeaders , threadPool .getThreadContext ().getResponseHeaders ());
190195
191196 if (randomBoolean ()) {
192197 return ClusterState .builder (currentState ).build ();
@@ -201,13 +206,15 @@ public ClusterState execute(ClusterState currentState) {
201206 public void onFailure (String source , Exception e ) {
202207 assertFalse (threadPool .getThreadContext ().isSystemContext ());
203208 assertEquals (expectedHeaders , threadPool .getThreadContext ().getHeaders ());
209+ assertEquals (expectedResponseHeaders , threadPool .getThreadContext ().getResponseHeaders ());
204210 latch .countDown ();
205211 }
206212
207213 @ Override
208214 public void clusterStateProcessed (String source , ClusterState oldState , ClusterState newState ) {
209215 assertFalse (threadPool .getThreadContext ().isSystemContext ());
210216 assertEquals (expectedHeaders , threadPool .getThreadContext ().getHeaders ());
217+ assertEquals (expectedResponseHeaders , threadPool .getThreadContext ().getResponseHeaders ());
211218 latch .countDown ();
212219 }
213220
@@ -229,20 +236,23 @@ public TimeValue timeout() {
229236 public void onAllNodesAcked (@ Nullable Exception e ) {
230237 assertFalse (threadPool .getThreadContext ().isSystemContext ());
231238 assertEquals (expectedHeaders , threadPool .getThreadContext ().getHeaders ());
239+ assertEquals (expectedResponseHeaders , threadPool .getThreadContext ().getResponseHeaders ());
232240 latch .countDown ();
233241 }
234242
235243 @ Override
236244 public void onAckTimeout () {
237245 assertFalse (threadPool .getThreadContext ().isSystemContext ());
238246 assertEquals (expectedHeaders , threadPool .getThreadContext ().getHeaders ());
247+ assertEquals (expectedResponseHeaders , threadPool .getThreadContext ().getResponseHeaders ());
239248 latch .countDown ();
240249 }
241250
242251 });
243252
244253 assertFalse (threadPool .getThreadContext ().isSystemContext ());
245254 assertEquals (expectedHeaders , threadPool .getThreadContext ().getHeaders ());
255+ assertEquals (Collections .emptyMap (), threadPool .getThreadContext ().getResponseHeaders ());
246256 }
247257
248258 latch .await ();
0 commit comments