|
1 | 1 | /* |
2 | | - * Copyright 2019 the original author or authors. |
| 2 | + * Copyright 2019-2021 the original author or authors. |
3 | 3 | * |
4 | 4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | 5 | * you may not use this file except in compliance with the License. |
|
19 | 19 | import java.util.Arrays; |
20 | 20 | import java.util.Collections; |
21 | 21 | import java.util.List; |
| 22 | +import java.util.concurrent.ExecutorService; |
| 23 | +import java.util.concurrent.Executors; |
22 | 24 |
|
23 | 25 | import io.rsocket.Payload; |
24 | 26 | import io.rsocket.RSocket; |
25 | 27 | import io.rsocket.metadata.WellKnownMimeType; |
| 28 | +import io.rsocket.util.ByteBufPayload; |
| 29 | +import io.rsocket.util.DefaultPayload; |
26 | 30 | import io.rsocket.util.RSocketProxy; |
27 | 31 | import org.junit.Test; |
28 | 32 | import org.junit.runner.RunWith; |
|
32 | 36 | import org.mockito.runners.MockitoJUnitRunner; |
33 | 37 | import org.mockito.stubbing.Answer; |
34 | 38 | import org.reactivestreams.Publisher; |
| 39 | +import org.reactivestreams.Subscription; |
| 40 | +import reactor.core.CoreSubscriber; |
35 | 41 | import reactor.core.publisher.Flux; |
36 | 42 | import reactor.core.publisher.Mono; |
37 | 43 | import reactor.test.StepVerifier; |
38 | 44 | import reactor.test.publisher.PublisherProbe; |
39 | 45 | import reactor.test.publisher.TestPublisher; |
| 46 | +import reactor.util.context.Context; |
40 | 47 |
|
41 | 48 | import org.springframework.http.MediaType; |
| 49 | +import org.springframework.security.access.AccessDeniedException; |
42 | 50 | import org.springframework.security.authentication.TestingAuthenticationToken; |
43 | 51 | import org.springframework.security.core.Authentication; |
44 | 52 | import org.springframework.security.core.context.ReactiveSecurityContextHolder; |
|
56 | 64 | import static org.mockito.ArgumentMatchers.any; |
57 | 65 | import static org.mockito.ArgumentMatchers.eq; |
58 | 66 | import static org.mockito.BDDMockito.given; |
| 67 | +import static org.mockito.Mockito.times; |
59 | 68 | import static org.mockito.Mockito.verify; |
60 | 69 | import static org.mockito.Mockito.verifyZeroInteractions; |
61 | 70 |
|
@@ -265,6 +274,57 @@ public void requestChannelWhenInterceptorCompletesThenDelegateSubscribed() { |
265 | 274 | verify(this.delegate).requestChannel(any()); |
266 | 275 | } |
267 | 276 |
|
| 277 | + // gh-9345 |
| 278 | + @Test |
| 279 | + public void requestChannelWhenInterceptorCompletesThenAllPayloadsRetained() { |
| 280 | + ExecutorService executors = Executors.newSingleThreadExecutor(); |
| 281 | + Payload payload = ByteBufPayload.create("data"); |
| 282 | + Payload payloadTwo = ByteBufPayload.create("moredata"); |
| 283 | + Payload payloadThree = ByteBufPayload.create("stillmoredata"); |
| 284 | + Context ctx = Context.empty(); |
| 285 | + Flux<Payload> payloads = this.payloadResult.flux(); |
| 286 | + given(this.interceptor.intercept(any(), any())).willReturn(Mono.empty()) |
| 287 | + .willReturn(Mono.error(() -> new AccessDeniedException("Access Denied"))); |
| 288 | + given(this.delegate.requestChannel(any())).willAnswer((invocation) -> { |
| 289 | + Flux<Payload> input = invocation.getArgument(0); |
| 290 | + return Flux.from(input).switchOnFirst((signal, innerFlux) -> innerFlux.map(Payload::getDataUtf8) |
| 291 | + .transform((data) -> Flux.<String>create((emitter) -> { |
| 292 | + Runnable run = () -> data.subscribe(new CoreSubscriber<String>() { |
| 293 | + @Override |
| 294 | + public void onSubscribe(Subscription s) { |
| 295 | + s.request(3); |
| 296 | + } |
| 297 | + |
| 298 | + @Override |
| 299 | + public void onNext(String s) { |
| 300 | + emitter.next(s); |
| 301 | + } |
| 302 | + |
| 303 | + @Override |
| 304 | + public void onError(Throwable t) { |
| 305 | + emitter.error(t); |
| 306 | + } |
| 307 | + |
| 308 | + @Override |
| 309 | + public void onComplete() { |
| 310 | + emitter.complete(); |
| 311 | + } |
| 312 | + }); |
| 313 | + executors.execute(run); |
| 314 | + })).map(DefaultPayload::create)); |
| 315 | + }); |
| 316 | + PayloadInterceptorRSocket interceptor = new PayloadInterceptorRSocket(this.delegate, |
| 317 | + Arrays.asList(this.interceptor), this.metadataMimeType, this.dataMimeType, ctx); |
| 318 | + StepVerifier.create(interceptor.requestChannel(payloads).doOnDiscard(Payload.class, Payload::release)) |
| 319 | + .then(() -> this.payloadResult.assertSubscribers()) |
| 320 | + .then(() -> this.payloadResult.emit(payload, payloadTwo, payloadThree)) |
| 321 | + .assertNext((next) -> assertThat(next.getDataUtf8()).isEqualTo(payload.getDataUtf8())) |
| 322 | + .verifyError(AccessDeniedException.class); |
| 323 | + verify(this.interceptor, times(2)).intercept(this.exchange.capture(), any()); |
| 324 | + assertThat(this.exchange.getValue().getPayload()).isEqualTo(payloadTwo); |
| 325 | + verify(this.delegate).requestChannel(any()); |
| 326 | + } |
| 327 | + |
268 | 328 | @Test |
269 | 329 | public void requestChannelWhenInterceptorErrorsThenDelegateNotSubscribed() { |
270 | 330 | RuntimeException expected = new RuntimeException("Oops"); |
|
0 commit comments