diff --git a/quarkus/service/src/test/java/org/apache/polaris/service/quarkus/it/QuarkusApplicationIntegrationTest.java b/quarkus/service/src/test/java/org/apache/polaris/service/quarkus/it/QuarkusApplicationIntegrationTest.java index 8aaac080e5..ba1395ac8e 100644 --- a/quarkus/service/src/test/java/org/apache/polaris/service/quarkus/it/QuarkusApplicationIntegrationTest.java +++ b/quarkus/service/src/test/java/org/apache/polaris/service/quarkus/it/QuarkusApplicationIntegrationTest.java @@ -19,19 +19,26 @@ package org.apache.polaris.service.quarkus.it; import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; import com.auth0.jwt.JWT; import com.auth0.jwt.algorithms.Algorithm; import io.quarkus.test.junit.QuarkusTest; import io.quarkus.test.junit.QuarkusTestProfile; import io.quarkus.test.junit.TestProfile; +import io.smallrye.common.annotation.Identifier; +import jakarta.inject.Inject; import java.io.IOException; import java.time.Instant; import java.util.Map; +import org.apache.iceberg.exceptions.NotAuthorizedException; +import org.apache.iceberg.rest.ErrorHandlers; import org.apache.iceberg.rest.HTTPClient; import org.apache.iceberg.rest.RESTClient; import org.apache.iceberg.rest.auth.AuthConfig; import org.apache.iceberg.rest.auth.OAuth2Util; +import org.apache.iceberg.rest.responses.OAuthTokenResponse; +import org.apache.polaris.service.auth.TokenBrokerFactory; import org.apache.polaris.service.it.env.ClientCredentials; import org.apache.polaris.service.it.env.PolarisApiEndpoints; import org.apache.polaris.service.it.test.PolarisApplicationIntegrationTest; @@ -53,8 +60,12 @@ public Map getConfigOverrides() { } } + @Inject + @Identifier("rsa-key-pair") + TokenBrokerFactory tokenBrokerFactory; + @Test - public void testIcebergRestApiRefreshToken( + public void testIcebergRestApiRefreshExpiredToken( PolarisApiEndpoints endpoints, ClientCredentials clientCredentials) throws IOException { String path = endpoints.catalogApiEndpoint() + "/v1/oauth/tokens"; try (RESTClient client = @@ -82,4 +93,91 @@ public void testIcebergRestApiRefreshToken( assertThat(JWT.decode(session.token()).getExpiresAtAsInstant()).isAfter(Instant.EPOCH); } } + + @Test + public void testIcebergRestApiRefreshValidToken( + PolarisApiEndpoints endpoints, ClientCredentials clientCredentials) throws IOException { + String path = endpoints.catalogApiEndpoint() + "/v1/oauth/tokens"; + try (RESTClient client = + HTTPClient.builder(Map.of()) + .withHeader(endpoints.realmHeaderName(), endpoints.realmId()) + .uri(path) + .build()) { + var response = + client.postForm( + path, + Map.of( + "grant_type", + "client_credentials", + "scope", + "PRINCIPAL_ROLE:ALL", + "client_id", + clientCredentials.clientId(), + "client_secret", + clientCredentials.clientSecret()), + OAuthTokenResponse.class, + Map.of(), + ErrorHandlers.oauthErrorHandler()); + String token = response.token(); + var authConfig = + AuthConfig.builder() + .credential(clientCredentials.clientId() + ":" + clientCredentials.clientSecret()) + .scope("PRINCIPAL_ROLE:ALL") + .oauth2ServerUri(path) + .token(token) + .build(); + var parentSession = new OAuth2Util.AuthSession(Map.of(), authConfig); + var session = OAuth2Util.AuthSession.fromAccessToken(client, null, token, 0L, parentSession); + session.refresh(client); + assertThat(session.token()).isNotEqualTo(token); + assertThat(JWT.decode(session.token()).getExpiresAtAsInstant()).isAfter(Instant.now()); + } + } + + @Test + public void testIcebergRestApiInvalidToken( + PolarisApiEndpoints endpoints, ClientCredentials clientCredentials) throws IOException { + String path = endpoints.catalogApiEndpoint() + "/v1/oauth/tokens"; + try (RESTClient client = + HTTPClient.builder(Map.of()) + .withHeader(endpoints.realmHeaderName(), endpoints.realmId()) + .uri(path) + .build()) { + var response = + client.postForm( + path, + Map.of( + "grant_type", + "client_credentials", + "scope", + "PRINCIPAL_ROLE:ALL", + "client_id", + clientCredentials.clientId(), + "client_secret", + clientCredentials.clientSecret()), + OAuthTokenResponse.class, + Map.of(), + ErrorHandlers.oauthErrorHandler()); + String token = response.token(); + // mimics OAUth2Util.AuthSession refreshing the token + assertThatThrownBy( + () -> + client.postForm( + path, + Map.of( + "grant_type", + "urn:ietf:params:oauth:grant-type:token-exchange", + "scope", + "PRINCIPAL_ROLE:ALL", + "subject_token", + "invalid", + "subject_token_type", + "urn:ietf:params:oauth:token-type:access_token"), + OAuthTokenResponse.class, + Map.of("Authorization", "Bearer " + token), + ErrorHandlers.oauthErrorHandler())) + .isInstanceOf(NotAuthorizedException.class) + .hasMessageContaining("invalid_client"); + } + } } diff --git a/service/common/src/main/java/org/apache/polaris/service/auth/JWTBroker.java b/service/common/src/main/java/org/apache/polaris/service/auth/JWTBroker.java index 43e4804aa4..dafe0732dc 100644 --- a/service/common/src/main/java/org/apache/polaris/service/auth/JWTBroker.java +++ b/service/common/src/main/java/org/apache/polaris/service/auth/JWTBroker.java @@ -35,6 +35,7 @@ import org.apache.polaris.core.entity.PrincipalEntity; import org.apache.polaris.core.persistence.PolarisMetaStoreManager; import org.apache.polaris.core.persistence.dao.entity.EntityResult; +import org.apache.polaris.service.auth.OAuthTokenErrorResponse.Error; import org.apache.polaris.service.types.TokenType; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -100,7 +101,7 @@ public TokenResponse generateFromToken( String grantType, String scope, TokenType requestedTokenType) { - if (!TokenType.ACCESS_TOKEN.equals(requestedTokenType)) { + if (requestedTokenType != null && !TokenType.ACCESS_TOKEN.equals(requestedTokenType)) { return new TokenResponse(OAuthTokenErrorResponse.Error.invalid_request); } if (!TokenType.ACCESS_TOKEN.equals(subjectTokenType)) { @@ -109,7 +110,12 @@ public TokenResponse generateFromToken( if (StringUtils.isBlank(subjectToken)) { return new TokenResponse(OAuthTokenErrorResponse.Error.invalid_request); } - DecodedToken decodedToken = verify(subjectToken); + DecodedToken decodedToken; + try { + decodedToken = verify(subjectToken); + } catch (NotAuthorizedException e) { + return new TokenResponse(Error.invalid_client); + } EntityResult principalLookup = metaStoreManager.loadEntity( CallContext.getCurrentContext().getPolarisCallContext(),