Skip to content

Commit 091cbe5

Browse files
committed
ServletOAuth2AuthorizedClientExchangeFilterFunction supports chaining
Fixes gh-6483
1 parent 04f8a79 commit 091cbe5

File tree

2 files changed

+231
-11
lines changed

2 files changed

+231
-11
lines changed

oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServletOAuth2AuthorizedClientExchangeFilterFunction.java

Lines changed: 113 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,9 @@
1616

1717
package org.springframework.security.oauth2.client.web.reactive.function.client;
1818

19+
import org.reactivestreams.Subscription;
20+
import org.springframework.beans.factory.DisposableBean;
21+
import org.springframework.beans.factory.InitializingBean;
1922
import org.springframework.http.HttpHeaders;
2023
import org.springframework.http.HttpMethod;
2124
import org.springframework.http.MediaType;
@@ -44,8 +47,12 @@
4447
import org.springframework.web.reactive.function.client.ExchangeFilterFunction;
4548
import org.springframework.web.reactive.function.client.ExchangeFunction;
4649
import org.springframework.web.reactive.function.client.WebClient;
50+
import reactor.core.CoreSubscriber;
51+
import reactor.core.publisher.Hooks;
4752
import reactor.core.publisher.Mono;
53+
import reactor.core.publisher.Operators;
4854
import reactor.core.scheduler.Schedulers;
55+
import reactor.util.context.Context;
4956

5057
import javax.servlet.http.HttpServletRequest;
5158
import javax.servlet.http.HttpServletResponse;
@@ -98,7 +105,9 @@
98105
* @author Rob Winch
99106
* @since 5.1
100107
*/
101-
public final class ServletOAuth2AuthorizedClientExchangeFilterFunction implements ExchangeFilterFunction {
108+
public final class ServletOAuth2AuthorizedClientExchangeFilterFunction
109+
implements ExchangeFilterFunction, InitializingBean, DisposableBean {
110+
102111
/**
103112
* The request attribute name used to locate the {@link OAuth2AuthorizedClient}.
104113
*/
@@ -108,6 +117,8 @@ public final class ServletOAuth2AuthorizedClientExchangeFilterFunction implement
108117
private static final String HTTP_SERVLET_REQUEST_ATTR_NAME = HttpServletRequest.class.getName();
109118
private static final String HTTP_SERVLET_RESPONSE_ATTR_NAME = HttpServletResponse.class.getName();
110119

120+
private static final String REQUEST_CONTEXT_OPERATOR_KEY = RequestContextSubscriber.class.getName();
121+
111122
private Clock clock = Clock.systemUTC();
112123

113124
private Duration accessTokenExpiresSkew = Duration.ofMinutes(1);
@@ -123,7 +134,8 @@ public final class ServletOAuth2AuthorizedClientExchangeFilterFunction implement
123134

124135
private String defaultClientRegistrationId;
125136

126-
public ServletOAuth2AuthorizedClientExchangeFilterFunction() {}
137+
public ServletOAuth2AuthorizedClientExchangeFilterFunction() {
138+
}
127139

128140
public ServletOAuth2AuthorizedClientExchangeFilterFunction(
129141
ClientRegistrationRepository clientRegistrationRepository,
@@ -132,6 +144,16 @@ public ServletOAuth2AuthorizedClientExchangeFilterFunction(
132144
this.authorizedClientRepository = authorizedClientRepository;
133145
}
134146

147+
@Override
148+
public void afterPropertiesSet() throws Exception {
149+
Hooks.onLastOperator(REQUEST_CONTEXT_OPERATOR_KEY, Operators.lift((s, sub) -> createRequestContextSubscriber(sub)));
150+
}
151+
152+
@Override
153+
public void destroy() throws Exception {
154+
Hooks.resetOnLastOperator(REQUEST_CONTEXT_OPERATOR_KEY);
155+
}
156+
135157
/**
136158
* Sets the {@link OAuth2AccessTokenResponseClient} to be used for getting an {@link OAuth2AuthorizedClient} for
137159
* client_credentials grant.
@@ -266,15 +288,36 @@ public void setAccessTokenExpiresSkew(Duration accessTokenExpiresSkew) {
266288

267289
@Override
268290
public Mono<ClientResponse> filter(ClientRequest request, ExchangeFunction next) {
269-
Optional<OAuth2AuthorizedClient> attribute = request.attribute(OAUTH2_AUTHORIZED_CLIENT_ATTR_NAME)
270-
.map(OAuth2AuthorizedClient.class::cast);
271-
return Mono.justOrEmpty(attribute)
272-
.flatMap(authorizedClient -> authorizedClient(request, next, authorizedClient))
291+
return Mono.just(request)
292+
.filter(req -> req.attribute(OAUTH2_AUTHORIZED_CLIENT_ATTR_NAME).isPresent())
293+
.switchIfEmpty(mergeRequestAttributesFromContext(request))
294+
.filter(req -> req.attribute(OAUTH2_AUTHORIZED_CLIENT_ATTR_NAME).isPresent())
295+
.flatMap(req -> authorizedClient(req, next, getOAuth2AuthorizedClient(req.attributes())))
273296
.map(authorizedClient -> bearer(request, authorizedClient))
274297
.flatMap(next::exchange)
275298
.switchIfEmpty(next.exchange(request));
276299
}
277300

301+
private Mono<ClientRequest> mergeRequestAttributesFromContext(ClientRequest request) {
302+
return Mono.just(ClientRequest.from(request))
303+
.flatMap(builder -> Mono.subscriberContext()
304+
.map(ctx -> builder.attributes(attrs -> populateRequestAttributes(attrs, ctx))))
305+
.map(ClientRequest.Builder::build);
306+
}
307+
308+
private void populateRequestAttributes(Map<String, Object> attrs, Context ctx) {
309+
if (ctx.hasKey(HTTP_SERVLET_REQUEST_ATTR_NAME)) {
310+
attrs.putIfAbsent(HTTP_SERVLET_REQUEST_ATTR_NAME, ctx.get(HTTP_SERVLET_REQUEST_ATTR_NAME));
311+
}
312+
if (ctx.hasKey(HTTP_SERVLET_RESPONSE_ATTR_NAME)) {
313+
attrs.putIfAbsent(HTTP_SERVLET_RESPONSE_ATTR_NAME, ctx.get(HTTP_SERVLET_RESPONSE_ATTR_NAME));
314+
}
315+
if (ctx.hasKey(AUTHENTICATION_ATTR_NAME)) {
316+
attrs.putIfAbsent(AUTHENTICATION_ATTR_NAME, ctx.get(AUTHENTICATION_ATTR_NAME));
317+
}
318+
populateDefaultOAuth2AuthorizedClient(attrs);
319+
}
320+
278321
private void populateDefaultRequestResponse(Map<String, Object> attrs) {
279322
if (attrs.containsKey(HTTP_SERVLET_REQUEST_ATTR_NAME) && attrs.containsKey(
280323
HTTP_SERVLET_RESPONSE_ATTR_NAME)) {
@@ -425,6 +468,19 @@ private ClientRequest bearer(ClientRequest request, OAuth2AuthorizedClient autho
425468
.build();
426469
}
427470

471+
private <T> CoreSubscriber<T> createRequestContextSubscriber(CoreSubscriber<T> delegate) {
472+
HttpServletRequest request = null;
473+
HttpServletResponse response = null;
474+
ServletRequestAttributes requestAttributes =
475+
(ServletRequestAttributes) RequestContextHolder.getRequestAttributes();
476+
if (requestAttributes != null) {
477+
request = requestAttributes.getRequest();
478+
response = requestAttributes.getResponse();
479+
}
480+
Authentication authentication = SecurityContextHolder.getContext().getAuthentication();
481+
return new RequestContextSubscriber<>(delegate, request, response, authentication);
482+
}
483+
428484
private static BodyInserters.FormInserter<String> refreshTokenBody(String refreshToken) {
429485
return BodyInserters
430486
.fromFormData("grant_type", AuthorizationGrantType.REFRESH_TOKEN.getValue())
@@ -498,4 +554,55 @@ private UnsupportedOperationException unsupported() {
498554
return new UnsupportedOperationException("Not Supported");
499555
}
500556
}
557+
558+
private static class RequestContextSubscriber<T> implements CoreSubscriber<T> {
559+
private static final String CONTEXT_DEFAULTED_ATTR_NAME = RequestContextSubscriber.class.getName().concat(".CONTEXT_DEFAULTED_ATTR_NAME");
560+
private final CoreSubscriber<T> delegate;
561+
private final HttpServletRequest request;
562+
private final HttpServletResponse response;
563+
private final Authentication authentication;
564+
565+
private RequestContextSubscriber(CoreSubscriber<T> delegate,
566+
HttpServletRequest request,
567+
HttpServletResponse response,
568+
Authentication authentication) {
569+
this.delegate = delegate;
570+
this.request = request;
571+
this.response = response;
572+
this.authentication = authentication;
573+
}
574+
575+
@Override
576+
public Context currentContext() {
577+
Context context = this.delegate.currentContext();
578+
if (context.hasKey(CONTEXT_DEFAULTED_ATTR_NAME)) {
579+
return context;
580+
}
581+
return Context.of(
582+
CONTEXT_DEFAULTED_ATTR_NAME, Boolean.TRUE,
583+
HTTP_SERVLET_REQUEST_ATTR_NAME, this.request,
584+
HTTP_SERVLET_RESPONSE_ATTR_NAME, this.response,
585+
AUTHENTICATION_ATTR_NAME, this.authentication);
586+
}
587+
588+
@Override
589+
public void onSubscribe(Subscription s) {
590+
this.delegate.onSubscribe(s);
591+
}
592+
593+
@Override
594+
public void onNext(T t) {
595+
this.delegate.onNext(t);
596+
}
597+
598+
@Override
599+
public void onError(Throwable t) {
600+
this.delegate.onError(t);
601+
}
602+
603+
@Override
604+
public void onComplete() {
605+
this.delegate.onComplete();
606+
}
607+
}
501608
}

oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServletOAuth2AuthorizedClientExchangeFilterFunctionTests.java

Lines changed: 118 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -74,13 +74,11 @@
7474
import java.util.Optional;
7575
import java.util.function.Consumer;
7676

77-
import static org.assertj.core.api.Assertions.*;
77+
import static org.assertj.core.api.Assertions.assertThat;
78+
import static org.assertj.core.api.Assertions.assertThatCode;
7879
import static org.mockito.ArgumentMatchers.any;
7980
import static org.mockito.ArgumentMatchers.eq;
80-
import static org.mockito.Mockito.mock;
81-
import static org.mockito.Mockito.verify;
82-
import static org.mockito.Mockito.verifyZeroInteractions;
83-
import static org.mockito.Mockito.when;
81+
import static org.mockito.Mockito.*;
8482
import static org.springframework.http.HttpMethod.GET;
8583
import static org.springframework.security.oauth2.client.web.reactive.function.client.ServletOAuth2AuthorizedClientExchangeFilterFunction.*;
8684

@@ -572,6 +570,121 @@ public void filterWhenNotExpiredThenShouldRefreshFalse() {
572570
assertThat(getBody(request0)).isEmpty();
573571
}
574572

573+
// gh-6483
574+
@Test
575+
public void filterWhenChainedThenDefaultsStillAvailable() throws Exception {
576+
this.function = new ServletOAuth2AuthorizedClientExchangeFilterFunction(
577+
this.clientRegistrationRepository, this.authorizedClientRepository);
578+
this.function.afterPropertiesSet(); // Hooks.onLastOperator() initialized
579+
this.function.setDefaultOAuth2AuthorizedClient(true);
580+
581+
MockHttpServletRequest servletRequest = new MockHttpServletRequest();
582+
MockHttpServletResponse servletResponse = new MockHttpServletResponse();
583+
RequestContextHolder.setRequestAttributes(new ServletRequestAttributes(servletRequest, servletResponse));
584+
585+
OAuth2User user = mock(OAuth2User.class);
586+
List<GrantedAuthority> authorities = AuthorityUtils.createAuthorityList("ROLE_USER");
587+
OAuth2AuthenticationToken authentication = new OAuth2AuthenticationToken(
588+
user, authorities, this.registration.getRegistrationId());
589+
SecurityContextHolder.getContext().setAuthentication(authentication);
590+
591+
OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(
592+
this.registration, "principalName", this.accessToken);
593+
when(this.authorizedClientRepository.loadAuthorizedClient(eq(authentication.getAuthorizedClientRegistrationId()),
594+
eq(authentication), eq(servletRequest))).thenReturn(authorizedClient);
595+
596+
// Default request attributes set
597+
final ClientRequest request1 = ClientRequest.create(GET, URI.create("https://example1.com"))
598+
.attributes(attrs -> attrs.putAll(getDefaultRequestAttributes())).build();
599+
600+
// Default request attributes NOT set
601+
final ClientRequest request2 = ClientRequest.create(GET, URI.create("https://example2.com")).build();
602+
603+
this.function.filter(request1, this.exchange)
604+
.flatMap(response -> this.function.filter(request2, this.exchange))
605+
.block();
606+
607+
this.function.destroy(); // Hooks.onLastOperator() released
608+
609+
List<ClientRequest> requests = this.exchange.getRequests();
610+
assertThat(requests).hasSize(2);
611+
612+
ClientRequest request = requests.get(0);
613+
assertThat(request.headers().getFirst(HttpHeaders.AUTHORIZATION)).isEqualTo("Bearer token-0");
614+
assertThat(request.url().toASCIIString()).isEqualTo("https://example1.com");
615+
assertThat(request.method()).isEqualTo(HttpMethod.GET);
616+
assertThat(getBody(request)).isEmpty();
617+
618+
request = requests.get(1);
619+
assertThat(request.headers().getFirst(HttpHeaders.AUTHORIZATION)).isEqualTo("Bearer token-0");
620+
assertThat(request.url().toASCIIString()).isEqualTo("https://example2.com");
621+
assertThat(request.method()).isEqualTo(HttpMethod.GET);
622+
assertThat(getBody(request)).isEmpty();
623+
}
624+
625+
@Test
626+
public void filterWhenRequestAttributesNotSetAndHooksNotInitThenDefaultsNotAvailable() throws Exception {
627+
this.function = new ServletOAuth2AuthorizedClientExchangeFilterFunction(
628+
this.clientRegistrationRepository, this.authorizedClientRepository);
629+
// this.function.afterPropertiesSet(); // Hooks.onLastOperator() NOT initialized
630+
this.function.setDefaultOAuth2AuthorizedClient(true);
631+
632+
MockHttpServletRequest servletRequest = new MockHttpServletRequest();
633+
MockHttpServletResponse servletResponse = new MockHttpServletResponse();
634+
RequestContextHolder.setRequestAttributes(new ServletRequestAttributes(servletRequest, servletResponse));
635+
636+
OAuth2User user = mock(OAuth2User.class);
637+
List<GrantedAuthority> authorities = AuthorityUtils.createAuthorityList("ROLE_USER");
638+
OAuth2AuthenticationToken authentication = new OAuth2AuthenticationToken(
639+
user, authorities, this.registration.getRegistrationId());
640+
SecurityContextHolder.getContext().setAuthentication(authentication);
641+
642+
ClientRequest request = ClientRequest.create(GET, URI.create("https://example.com")).build();
643+
644+
this.function.filter(request, this.exchange).block();
645+
646+
List<ClientRequest> requests = this.exchange.getRequests();
647+
assertThat(requests).hasSize(1);
648+
649+
request = requests.get(0);
650+
assertThat(request.headers().getFirst(HttpHeaders.AUTHORIZATION)).isNull();
651+
assertThat(request.url().toASCIIString()).isEqualTo("https://example.com");
652+
assertThat(request.method()).isEqualTo(HttpMethod.GET);
653+
assertThat(getBody(request)).isEmpty();
654+
}
655+
656+
@Test
657+
public void filterWhenRequestAttributesNotSetAndHooksInitHooksResetThenDefaultsNotAvailable() throws Exception {
658+
this.function = new ServletOAuth2AuthorizedClientExchangeFilterFunction(
659+
this.clientRegistrationRepository, this.authorizedClientRepository);
660+
this.function.afterPropertiesSet(); // Hooks.onLastOperator() initialized
661+
this.function.destroy(); // Hooks.onLastOperator() released
662+
this.function.setDefaultOAuth2AuthorizedClient(true);
663+
664+
MockHttpServletRequest servletRequest = new MockHttpServletRequest();
665+
MockHttpServletResponse servletResponse = new MockHttpServletResponse();
666+
RequestContextHolder.setRequestAttributes(new ServletRequestAttributes(servletRequest, servletResponse));
667+
668+
OAuth2User user = mock(OAuth2User.class);
669+
List<GrantedAuthority> authorities = AuthorityUtils.createAuthorityList("ROLE_USER");
670+
OAuth2AuthenticationToken authentication = new OAuth2AuthenticationToken(
671+
user, authorities, this.registration.getRegistrationId());
672+
SecurityContextHolder.getContext().setAuthentication(authentication);
673+
674+
ClientRequest request = ClientRequest.create(GET, URI.create("https://example.com")).build();
675+
676+
this.function.filter(request, this.exchange).block();
677+
678+
List<ClientRequest> requests = this.exchange.getRequests();
679+
assertThat(requests).hasSize(1);
680+
681+
request = requests.get(0);
682+
assertThat(request.headers().getFirst(HttpHeaders.AUTHORIZATION)).isNull();
683+
assertThat(request.url().toASCIIString()).isEqualTo("https://example.com");
684+
assertThat(request.method()).isEqualTo(HttpMethod.GET);
685+
assertThat(getBody(request)).isEmpty();
686+
}
687+
575688
private static String getBody(ClientRequest request) {
576689
final List<HttpMessageWriter<?>> messageWriters = new ArrayList<>();
577690
messageWriters.add(new EncoderHttpMessageWriter<>(new ByteBufferEncoder()));

0 commit comments

Comments
 (0)