diff --git a/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/authc/TokenService.java b/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/authc/TokenService.java index d42fa9909ed93..d431998c33714 100644 --- a/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/authc/TokenService.java +++ b/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/authc/TokenService.java @@ -657,9 +657,15 @@ public void invalidateRefreshToken(String refreshToken, ActionListener 1) { listener.onFailure(new IllegalStateException("multiple tokens share the same refresh token")); } else { - final Tuple parsedTokens = - parseTokensFromDocument(searchHits.getAt(0).getSourceAsMap(), null); - indexInvalidation(Collections.singletonList(parsedTokens.v1()), backoff, "refresh_token", null, listener); + final Tuple parsedTokens = + parseTokenAndRefreshStatus(searchHits.getAt(0).getSourceAsMap()); + final UserToken userToken = parsedTokens.v1(); + final RefreshTokenStatus refresh = parsedTokens.v2(); + if (refresh.isInvalidated()) { + listener.onResponse(new TokensInvalidationResult(List.of(), List.of(userToken.getId()), null, RestStatus.OK)); + } else { + indexInvalidation(Collections.singletonList(userToken), backoff, "refresh_token", null, listener); + } } }, e -> { if (e instanceof IndexNotFoundException || e instanceof IndexClosedException) { @@ -1258,9 +1264,7 @@ private static Optional checkTokenDocumentExpire */ private static Tuple> checkTokenDocumentForRefresh( Instant refreshRequested, Authentication clientAuth, Map source) throws IllegalStateException, DateTimeException { - final RefreshTokenStatus refreshTokenStatus = RefreshTokenStatus.fromSourceMap(getRefreshTokenSourceMap(source)); - final UserToken userToken = UserToken.fromSourceMap(getUserTokenSourceMap(source)); - refreshTokenStatus.setVersion(userToken.getVersion()); + final RefreshTokenStatus refreshTokenStatus = parseTokenAndRefreshStatus(source).v2(); final ElasticsearchSecurityException validationException = checkTokenDocumentExpired(refreshRequested, source).orElseGet(() -> { if (refreshTokenStatus.isInvalidated()) { return invalidGrantException("token has been invalidated"); @@ -1272,6 +1276,13 @@ private static Tuple(refreshTokenStatus, Optional.ofNullable(validationException)); } + private static Tuple parseTokenAndRefreshStatus(Map source) { + final RefreshTokenStatus refreshTokenStatus = RefreshTokenStatus.fromSourceMap(getRefreshTokenSourceMap(source)); + final UserToken userToken = UserToken.fromSourceMap(getUserTokenSourceMap(source)); + refreshTokenStatus.setVersion(userToken.getVersion()); + return new Tuple<>(userToken, refreshTokenStatus); + } + /** * Refresh tokens are bound to be used only by the client that originally created them. This check validates this condition, given the * {@code Authentication} of the client that attempted the refresh operation. @@ -1484,12 +1495,18 @@ private void sourceIndicesWithTokensAndRun(ActionListener> listener private BytesReference createTokenDocument(UserToken userToken, @Nullable String refreshToken, @Nullable Authentication originatingClientAuth) { + final Instant creationTime = getCreationTime(userToken.getExpirationTime()); + return createTokenDocument(userToken, refreshToken, originatingClientAuth, creationTime); + } + + static BytesReference createTokenDocument(UserToken userToken, String refreshToken, Authentication originatingClientAuth, + Instant creationTime) { assert refreshToken == null || originatingClientAuth != null : "non-null refresh token " + refreshToken + " requires non-null client authn " + originatingClientAuth; try (XContentBuilder builder = XContentFactory.jsonBuilder()) { builder.startObject(); builder.field("doc_type", TOKEN_DOC_TYPE); - builder.field("creation_time", getCreationTime(userToken.getExpirationTime()).toEpochMilli()); + builder.field("creation_time", creationTime.toEpochMilli()); if (refreshToken != null) { builder.startObject("refresh_token") .field("token", refreshToken) diff --git a/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/TokenServiceTests.java b/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/TokenServiceTests.java index eb3dcf208e0a9..d04c68c570e88 100644 --- a/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/TokenServiceTests.java +++ b/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/TokenServiceTests.java @@ -9,8 +9,15 @@ import org.elasticsearch.ElasticsearchSecurityException; import org.elasticsearch.Version; import org.elasticsearch.action.ActionListener; +import org.elasticsearch.action.DocWriteRequest; +import org.elasticsearch.action.DocWriteResponse; import org.elasticsearch.action.NoShardAvailableActionException; import org.elasticsearch.action.UnavailableShardsException; +import org.elasticsearch.action.bulk.BulkAction; +import org.elasticsearch.action.bulk.BulkItemResponse; +import org.elasticsearch.action.bulk.BulkRequest; +import org.elasticsearch.action.bulk.BulkRequestBuilder; +import org.elasticsearch.action.bulk.BulkResponse; import org.elasticsearch.action.get.GetAction; import org.elasticsearch.action.get.GetRequest; import org.elasticsearch.action.get.GetRequestBuilder; @@ -19,17 +26,24 @@ import org.elasticsearch.action.index.IndexRequest; import org.elasticsearch.action.index.IndexRequestBuilder; import org.elasticsearch.action.index.IndexResponse; +import org.elasticsearch.action.search.SearchAction; +import org.elasticsearch.action.search.SearchRequest; +import org.elasticsearch.action.search.SearchRequestBuilder; +import org.elasticsearch.action.search.SearchResponse; import org.elasticsearch.action.support.PlainActionFuture; import org.elasticsearch.action.update.UpdateAction; import org.elasticsearch.action.update.UpdateRequestBuilder; +import org.elasticsearch.action.update.UpdateResponse; import org.elasticsearch.client.Client; import org.elasticsearch.cluster.ClusterState; import org.elasticsearch.cluster.node.DiscoveryNode; import org.elasticsearch.cluster.node.DiscoveryNodeRole; import org.elasticsearch.cluster.node.DiscoveryNodes; import org.elasticsearch.cluster.service.ClusterService; +import org.elasticsearch.common.Nullable; import org.elasticsearch.common.Strings; import org.elasticsearch.common.UUIDs; +import org.elasticsearch.common.bytes.BytesReference; import org.elasticsearch.common.collect.Tuple; import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.unit.TimeValue; @@ -39,13 +53,19 @@ import org.elasticsearch.common.xcontent.XContentHelper; import org.elasticsearch.common.xcontent.XContentType; import org.elasticsearch.index.Index; +import org.elasticsearch.index.get.GetResult; +import org.elasticsearch.index.query.BoolQueryBuilder; +import org.elasticsearch.index.query.TermQueryBuilder; import org.elasticsearch.index.shard.ShardId; import org.elasticsearch.license.XPackLicenseState; import org.elasticsearch.license.XPackLicenseState.Feature; import org.elasticsearch.node.Node; +import org.elasticsearch.search.SearchHit; +import org.elasticsearch.search.SearchHits; import org.elasticsearch.test.ClusterServiceUtils; import org.elasticsearch.test.ESTestCase; import org.elasticsearch.test.EqualsHashCodeTestUtils; +import org.elasticsearch.test.XContentTestUtils; import org.elasticsearch.threadpool.FixedExecutorBuilder; import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.xpack.core.XPackSettings; @@ -57,6 +77,7 @@ import org.elasticsearch.xpack.core.security.authc.support.TokensInvalidationResult; import org.elasticsearch.xpack.core.security.user.User; import org.elasticsearch.xpack.core.watcher.watch.ClockMock; +import org.elasticsearch.xpack.security.authc.TokenService.RefreshTokenStatus; import org.elasticsearch.xpack.security.support.FeatureNotEnabledException; import org.elasticsearch.xpack.security.support.SecurityIndexManager; import org.elasticsearch.xpack.security.test.SecurityMocks; @@ -83,9 +104,12 @@ import static org.elasticsearch.repositories.blobstore.ESBlobStoreRepositoryIntegTestCase.randomBytes; import static org.elasticsearch.test.ClusterServiceUtils.setState; import static org.elasticsearch.test.TestMatchers.throwableWithMessage; +import static org.hamcrest.CoreMatchers.is; import static org.hamcrest.Matchers.contains; import static org.hamcrest.Matchers.containsString; +import static org.hamcrest.Matchers.empty; import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.hasSize; import static org.hamcrest.Matchers.instanceOf; import static org.hamcrest.Matchers.notNullValue; import static org.hamcrest.Matchers.nullValue; @@ -100,7 +124,7 @@ public class TokenServiceTests extends ESTestCase { private static ThreadPool threadPool; private static final Settings settings = Settings.builder().put(Node.NODE_NAME_SETTING.getKey(), "TokenServiceTests") - .put(XPackSettings.TOKEN_SERVICE_ENABLED_SETTING.getKey(), true).build(); + .put(XPackSettings.TOKEN_SERVICE_ENABLED_SETTING.getKey(), true).build(); private Client client; private SecurityIndexManager securityMainIndex; @@ -120,19 +144,44 @@ public void setupClient() { doAnswer(invocationOnMock -> { GetRequestBuilder builder = new GetRequestBuilder(client, GetAction.INSTANCE); builder.setIndex((String) invocationOnMock.getArguments()[0]) - .setId((String) invocationOnMock.getArguments()[1]); + .setId((String) invocationOnMock.getArguments()[1]); return builder; }).when(client).prepareGet(anyString(), anyString()); when(client.prepareIndex(any(String.class))) - .thenReturn(new IndexRequestBuilder(client, IndexAction.INSTANCE)); + .thenReturn(new IndexRequestBuilder(client, IndexAction.INSTANCE)); + when(client.prepareBulk()) + .thenReturn(new BulkRequestBuilder(client, BulkAction.INSTANCE)); when(client.prepareUpdate(any(String.class), any(String.class))) - .thenReturn(new UpdateRequestBuilder(client, UpdateAction.INSTANCE)); + .thenAnswer(inv -> { + final String index = (String) inv.getArguments()[0]; + final String id = (String) inv.getArguments()[1]; + return new UpdateRequestBuilder(client, UpdateAction.INSTANCE).setIndex(index).setId(id); + }); + when(client.prepareSearch(any(String.class))) + .thenReturn(new SearchRequestBuilder(client, SearchAction.INSTANCE)); doAnswer(invocationOnMock -> { ActionListener responseActionListener = (ActionListener) invocationOnMock.getArguments()[2]; responseActionListener.onResponse(new IndexResponse(new ShardId(".security", UUIDs.randomBase64UUID(), randomInt()), randomAlphaOfLength(4), randomNonNegativeLong(), randomNonNegativeLong(), randomNonNegativeLong(), true)); return null; }).when(client).execute(eq(IndexAction.INSTANCE), any(IndexRequest.class), any(ActionListener.class)); + doAnswer(invocationOnMock -> { + BulkRequest request = (BulkRequest) invocationOnMock.getArguments()[0]; + ActionListener responseActionListener = (ActionListener) invocationOnMock.getArguments()[1]; + BulkItemResponse[] responses = new BulkItemResponse[request.requests().size()]; + final String indexUUID = randomAlphaOfLength(22); + for (int i = 0; i < responses.length; i++) { + var shardId = new ShardId(securityTokensIndex.aliasName(), indexUUID, 1); + var docId = request.requests().get(i).id(); + var result = new GetResult(shardId.getIndexName(), docId, 1, 1, 1, true, null, null, null); + final UpdateResponse response = new UpdateResponse(shardId, result.getId(), result.getSeqNo(), result.getPrimaryTerm(), + result.getVersion() + 1, DocWriteResponse.Result.UPDATED); + response.setGetResult(result); + responses[i] = new BulkItemResponse(i, DocWriteRequest.OpType.UPDATE, response); + } + responseActionListener.onResponse(new BulkResponse(responses, randomLongBetween(1, 500))); + return null; + }).when(client).bulk(any(BulkRequest.class), any(ActionListener.class)); this.securityContext = new SecurityContext(settings, threadPool.getThreadContext()); // setup lifecycle service @@ -162,8 +211,8 @@ public void tearDown() throws Exception { @BeforeClass public static void startThreadPool() throws IOException { threadPool = new ThreadPool(settings, - new FixedExecutorBuilder(settings, TokenService.THREAD_POOL_NAME, 1, 1000, "xpack.security.authc.token.thread_pool", - false)); + new FixedExecutorBuilder(settings, TokenService.THREAD_POOL_NAME, 1, 1000, "xpack.security.authc.token.thread_pool", + false)); new Authentication(new User("foo"), new RealmRef("realm", "type", "node"), null).writeToContext(threadPool.getThreadContext()); } @@ -474,6 +523,60 @@ public void testInvalidatedToken() throws Exception { } } + public void testInvalidateRefreshToken() throws Exception { + when(securityMainIndex.indexExists()).thenReturn(true); + TokenService tokenService = createTokenService(tokenServiceEnabledSettings, systemUTC()); + Authentication authentication = new Authentication(new User("joe", "admin"), new RealmRef("native_realm", "native", "node1"), null); + PlainActionFuture> tokenFuture = new PlainActionFuture<>(); + final String userTokenId = UUIDs.randomBase64UUID(); + final String rawRefreshToken = UUIDs.randomBase64UUID(); + tokenService.createOAuth2Tokens(userTokenId, rawRefreshToken, authentication, authentication, Collections.emptyMap(), tokenFuture); + final String accessToken = tokenFuture.get().v1(); + final String clientRefreshToken = tokenFuture.get().v2(); + assertNotNull(accessToken); + mockFindTokenFromRefreshToken(rawRefreshToken, buildUserToken(tokenService, userTokenId, authentication), null); + + ThreadContext requestContext = new ThreadContext(Settings.EMPTY); + storeTokenHeader(requestContext, accessToken); + + try (ThreadContext.StoredContext ignore = requestContext.newStoredContext(true)) { + PlainActionFuture future = new PlainActionFuture<>(); + tokenService.invalidateRefreshToken(clientRefreshToken, future); + final TokensInvalidationResult result = future.get(); + assertThat(result.getInvalidatedTokens(), hasSize(1)); + assertThat(result.getPreviouslyInvalidatedTokens(), empty()); + assertThat(result.getErrors(), empty()); + } + } + + public void testInvalidateRefreshTokenThatIsAlreadyInvalidated() throws Exception { + when(securityMainIndex.indexExists()).thenReturn(true); + TokenService tokenService = createTokenService(tokenServiceEnabledSettings, systemUTC()); + Authentication authentication = new Authentication(new User("joe", "admin"), new RealmRef("native_realm", "native", "node1"), null); + PlainActionFuture> tokenFuture = new PlainActionFuture<>(); + final String userTokenId = UUIDs.randomBase64UUID(); + final String rawRefreshToken = UUIDs.randomBase64UUID(); + tokenService.createOAuth2Tokens(userTokenId, rawRefreshToken, authentication, authentication, Collections.emptyMap(), tokenFuture); + final String accessToken = tokenFuture.get().v1(); + final String clientRefreshToken = tokenFuture.get().v2(); + assertNotNull(accessToken); + mockFindTokenFromRefreshToken(rawRefreshToken, buildUserToken(tokenService, userTokenId, authentication), + new RefreshTokenStatus(true, randomAlphaOfLength(12), randomAlphaOfLength(6), false, null, null, null, null) + ); + + ThreadContext requestContext = new ThreadContext(Settings.EMPTY); + storeTokenHeader(requestContext, accessToken); + + try (ThreadContext.StoredContext ignore = requestContext.newStoredContext(true)) { + PlainActionFuture future = new PlainActionFuture<>(); + tokenService.invalidateRefreshToken(clientRefreshToken, future); + final TokensInvalidationResult result = future.get(); + assertThat(result.getPreviouslyInvalidatedTokens(), hasSize(1)); + assertThat(result.getInvalidatedTokens(), empty()); + assertThat(result.getErrors(), empty()); + } + } + private void storeTokenHeader(ThreadContext requestContext, String tokenString) throws IOException, GeneralSecurityException { requestContext.putHeader("Authorization", "Bearer " + tokenString); } @@ -489,26 +592,26 @@ public void testComputeSecretKeyIsConsistent() throws Exception { } public void testTokenExpiryConfig() { - TimeValue expiration = TokenService.TOKEN_EXPIRATION.get(tokenServiceEnabledSettings); + TimeValue expiration = TokenService.TOKEN_EXPIRATION.get(tokenServiceEnabledSettings); assertThat(expiration, equalTo(TimeValue.timeValueMinutes(20L))); // Configure Minimum expiration tokenServiceEnabledSettings = Settings.builder().put(TokenService.TOKEN_EXPIRATION.getKey(), "1s").build(); - expiration = TokenService.TOKEN_EXPIRATION.get(tokenServiceEnabledSettings); + expiration = TokenService.TOKEN_EXPIRATION.get(tokenServiceEnabledSettings); assertThat(expiration, equalTo(TimeValue.timeValueSeconds(1L))); // Configure Maximum expiration tokenServiceEnabledSettings = Settings.builder().put(TokenService.TOKEN_EXPIRATION.getKey(), "60m").build(); - expiration = TokenService.TOKEN_EXPIRATION.get(tokenServiceEnabledSettings); + expiration = TokenService.TOKEN_EXPIRATION.get(tokenServiceEnabledSettings); assertThat(expiration, equalTo(TimeValue.timeValueHours(1L))); // Outside range should fail tokenServiceEnabledSettings = Settings.builder().put(TokenService.TOKEN_EXPIRATION.getKey(), "1ms").build(); IllegalArgumentException ile = expectThrows(IllegalArgumentException.class, - () -> TokenService.TOKEN_EXPIRATION.get(tokenServiceEnabledSettings)); + () -> TokenService.TOKEN_EXPIRATION.get(tokenServiceEnabledSettings)); assertThat(ile.getMessage(), - containsString("failed to parse value [1ms] for setting [xpack.security.authc.token.timeout], must be >= [1s]")); + containsString("failed to parse value [1ms] for setting [xpack.security.authc.token.timeout], must be >= [1s]")); tokenServiceEnabledSettings = Settings.builder().put(TokenService.TOKEN_EXPIRATION.getKey(), "120m").build(); ile = expectThrows(IllegalArgumentException.class, () -> TokenService.TOKEN_EXPIRATION.get(tokenServiceEnabledSettings)); assertThat(ile.getMessage(), - containsString("failed to parse value [120m] for setting [xpack.security.authc.token.timeout], must be <= [1h]")); + containsString("failed to parse value [120m] for setting [xpack.security.authc.token.timeout], must be <= [1h]")); } public void testTokenExpiry() throws Exception { @@ -564,8 +667,8 @@ public void testTokenExpiry() throws Exception { public void testTokenServiceDisabled() throws Exception { TokenService tokenService = new TokenService(Settings.builder() - .put(XPackSettings.TOKEN_SERVICE_ENABLED_SETTING.getKey(), false) - .build(), + .put(XPackSettings.TOKEN_SERVICE_ENABLED_SETTING.getKey(), false) + .build(), Clock.systemUTC(), client, licenseState, securityContext, securityMainIndex, securityTokensIndex, clusterService); ElasticsearchException e = expectThrows(ElasticsearchException.class, () -> tokenService.createOAuth2Tokens(null, null, null, true, null)); @@ -740,7 +843,7 @@ public void testSupercedingTokenEncryption() throws Exception { final Version version = tokenService.getTokenVersionCompatibility(); String encryptedTokens = tokenService.encryptSupersedingTokens(newAccessToken, newRefreshToken, refrehToken, iv, salt); - TokenService.RefreshTokenStatus refreshTokenStatus = new TokenService.RefreshTokenStatus(false, + RefreshTokenStatus refreshTokenStatus = new RefreshTokenStatus(false, authentication.getUser().principal(), authentication.getAuthenticatedBy().getName(), true, Instant.now().minusSeconds(5L), encryptedTokens, Base64.getEncoder().encodeToString(iv), Base64.getEncoder().encodeToString(salt)); @@ -806,10 +909,7 @@ public static void mockGetTokenFromId(TokenService tokenService, String userToke if (possiblyHashedUserTokenId.equals(request.id().replace("token_", ""))) { when(response.isExists()).thenReturn(true); Map sourceMap = new HashMap<>(); - final Authentication tokenAuth = new Authentication(authentication.getUser(), authentication.getAuthenticatedBy(), - authentication.getLookedUpBy(), tokenVersion, AuthenticationType.TOKEN, authentication.getMetadata()); - final UserToken userToken = new UserToken(possiblyHashedUserTokenId, tokenVersion, tokenAuth, - tokenService.getExpirationTime(), authentication.getMetadata()); + final UserToken userToken = buildUserToken(tokenService, userTokenId, authentication); try (XContentBuilder builder = XContentBuilder.builder(XContentType.JSON.xContent())) { userToken.toXContent(builder, ToXContent.EMPTY_PARAMS); Map accessTokenMap = new HashMap<>(); @@ -825,6 +925,22 @@ public static void mockGetTokenFromId(TokenService tokenService, String userToke }).when(client).get(any(GetRequest.class), any(ActionListener.class)); } + protected static UserToken buildUserToken(TokenService tokenService, String userTokenId, Authentication authentication) { + final Version tokenVersion = tokenService.getTokenVersionCompatibility(); + final String possiblyHashedUserTokenId; + if (tokenVersion.onOrAfter(TokenService.VERSION_ACCESS_TOKENS_AS_UUIDS)) { + possiblyHashedUserTokenId = TokenService.hashTokenString(userTokenId); + } else { + possiblyHashedUserTokenId = userTokenId; + } + + final Authentication tokenAuth = new Authentication(authentication.getUser(), authentication.getAuthenticatedBy(), + authentication.getLookedUpBy(), tokenVersion, AuthenticationType.TOKEN, authentication.getMetadata()); + final UserToken userToken = new UserToken(possiblyHashedUserTokenId, tokenVersion, tokenAuth, + tokenService.getExpirationTime(), authentication.getMetadata()); + return userToken; + } + private void mockGetTokenFromId(UserToken userToken, boolean isExpired) { doAnswer(invocationOnMock -> { GetRequest request = (GetRequest) invocationOnMock.getArguments()[0]; @@ -856,6 +972,58 @@ private void mockGetTokenFromId(UserToken userToken, boolean isExpired) { }).when(client).get(any(GetRequest.class), any(ActionListener.class)); } + private void mockFindTokenFromRefreshToken(String refreshToken, UserToken userToken, @Nullable RefreshTokenStatus refreshTokenStatus) { + String storedRefreshToken; + if (userToken.getVersion().onOrAfter(TokenService.VERSION_HASHED_TOKENS)) { + storedRefreshToken = TokenService.hashTokenString(refreshToken); + } else { + storedRefreshToken = refreshToken; + } + doAnswer(invocationOnMock -> { + final SearchRequest request = (SearchRequest) invocationOnMock.getArguments()[0]; + final ActionListener listener = (ActionListener) invocationOnMock.getArguments()[1]; + final SearchResponse response = mock(SearchResponse.class); + + assertThat(request.source().query(), instanceOf(BoolQueryBuilder.class)); + BoolQueryBuilder bool = (BoolQueryBuilder) request.source().query(); + assertThat(bool.filter(), hasSize(2)); + + assertThat(bool.filter().get(0), instanceOf(TermQueryBuilder.class)); + TermQueryBuilder docType = (TermQueryBuilder) bool.filter().get(0); + assertThat(docType.fieldName(), is("doc_type")); + assertThat(docType.value(), is("token")); + + assertThat(bool.filter().get(1), instanceOf(TermQueryBuilder.class)); + TermQueryBuilder refreshFilter = (TermQueryBuilder) bool.filter().get(1); + assertThat(refreshFilter.fieldName(), is("refresh_token.token")); + assertThat(refreshFilter.value(), is(storedRefreshToken)); + + final RealmRef realmRef = new RealmRef( + refreshTokenStatus == null ? randomAlphaOfLength(6) : refreshTokenStatus.getAssociatedRealm(), + "test", + randomAlphaOfLength(12)); + final Authentication clientAuthentication = new Authentication( + new User(refreshTokenStatus == null ? randomAlphaOfLength(8) : refreshTokenStatus.getAssociatedUser()), + realmRef, realmRef); + + final SearchHit hit = new SearchHit(randomInt(), "token_" + TokenService.hashTokenString(userToken.getId()), null, null); + BytesReference source = TokenService.createTokenDocument(userToken, storedRefreshToken, clientAuthentication, Instant.now()); + if (refreshTokenStatus != null) { + var sourceAsMap = XContentHelper.convertToMap(source, false, XContentType.JSON).v2(); + var refreshTokenSource = (Map) sourceAsMap.get("refresh_token"); + refreshTokenSource.put("invalidated", refreshTokenStatus.isInvalidated()); + refreshTokenSource.put("refreshed", refreshTokenStatus.isRefreshed()); + source = XContentTestUtils.convertToXContent(sourceAsMap, XContentType.JSON); + } + hit.sourceRef(source); + + final SearchHits hits = new SearchHits(new SearchHit[]{hit}, null, 1); + when(response.getHits()).thenReturn(hits); + listener.onResponse(response); + return Void.TYPE; + }).when(client).search(any(SearchRequest.class), any(ActionListener.class)); + } + private void mockGetTokenAsyncForDecryptedToken(String accessToken) { doAnswer(invocationOnMock -> { GetRequest request = (GetRequest) invocationOnMock.getArguments()[0]; @@ -880,7 +1048,7 @@ private DiscoveryNode addAnotherDataNodeWithVersion(ClusterService clusterServic final ClusterState currentState = clusterService.state(); final DiscoveryNodes.Builder discoBuilder = DiscoveryNodes.builder(currentState.getNodes()); final DiscoveryNode anotherDataNode = new DiscoveryNode("another_data_node#" + version, buildNewFakeTransportAddress(), - Collections.emptyMap(), Collections.singleton(DiscoveryNodeRole.DATA_ROLE), version); + Collections.emptyMap(), Collections.singleton(DiscoveryNodeRole.DATA_ROLE), version); discoBuilder.add(anotherDataNode); final ClusterState.Builder newStateBuilder = ClusterState.builder(currentState); newStateBuilder.nodes(discoBuilder);