From 2a32c6c6eb351f8a3b419cd1a664c4ca6528c5ea Mon Sep 17 00:00:00 2001 From: Alexandre Dutra Date: Mon, 28 Apr 2025 11:44:45 -0300 Subject: [PATCH] Token decoding improvements --- .../InternalAuthenticationMechanism.java | 42 +++------ .../internal/InternalIdentityProvider.java | 66 +++++++++++--- .../quarkus/auth/JWTRSAKeyPairTest.java | 3 +- .../auth/JWTSymmetricKeyGeneratorTest.java | 3 +- .../InternalAuthenticationMechanismTest.java | 78 +++++++++++----- .../InternalIdentityProviderTest.java | 53 +++++++++-- .../polaris/service/auth/DecodedToken.java | 9 ++ .../polaris/service/auth/JWTBroker.java | 89 +++++++++++++------ .../polaris/service/auth/JWTRSAKeyPair.java | 4 +- .../service/auth/JWTRSAKeyPairFactory.java | 1 + .../service/auth/JWTSymmetricKeyBroker.java | 4 +- .../service/auth/JWTSymmetricKeyFactory.java | 1 + .../service/auth/NoneTokenBrokerFactory.java | 7 +- .../polaris/service/auth/TokenBroker.java | 36 +++++++- .../service/auth/TokenDecodeResult.java | 47 ++++++++++ 15 files changed, 342 insertions(+), 101 deletions(-) create mode 100644 service/common/src/main/java/org/apache/polaris/service/auth/TokenDecodeResult.java diff --git a/quarkus/service/src/main/java/org/apache/polaris/service/quarkus/auth/internal/InternalAuthenticationMechanism.java b/quarkus/service/src/main/java/org/apache/polaris/service/quarkus/auth/internal/InternalAuthenticationMechanism.java index 9a7bb4a025..2c6635e3c4 100644 --- a/quarkus/service/src/main/java/org/apache/polaris/service/quarkus/auth/internal/InternalAuthenticationMechanism.java +++ b/quarkus/service/src/main/java/org/apache/polaris/service/quarkus/auth/internal/InternalAuthenticationMechanism.java @@ -32,7 +32,6 @@ import io.quarkus.vertx.http.runtime.security.HttpSecurityUtils; import io.smallrye.mutiny.Uni; import io.vertx.ext.web.RoutingContext; -import jakarta.annotation.Nullable; import jakarta.enterprise.context.ApplicationScoped; import jakarta.inject.Inject; import java.util.Collections; @@ -41,7 +40,7 @@ import org.apache.polaris.service.auth.AuthenticationType; import org.apache.polaris.service.auth.DecodedToken; import org.apache.polaris.service.auth.TokenBroker; -import org.apache.polaris.service.quarkus.auth.QuarkusPrincipalAuthInfo; +import org.apache.polaris.service.auth.TokenDecodeResult; /** * A custom {@link HttpAuthenticationMechanism} that handles internal token authentication, that is, @@ -86,22 +85,21 @@ public Uni authenticate( String credential = authHeader.substring(spaceIdx + 1); - DecodedToken token; - try { - token = tokenBroker.verify(credential); - } catch (Exception e) { + TokenDecodeResult result = tokenBroker.decode(credential); + + if (result.token().isEmpty()) { return configuration.type() == AuthenticationType.MIXED + && result.status() != TokenDecodeResult.Status.MALFORMED_TOKEN ? Uni.createFrom().nullItem() // let other auth mechanisms handle it - : Uni.createFrom().failure(new AuthenticationFailedException(e)); // stop here + : Uni.createFrom() + .failure(new AuthenticationFailedException(result.status().message())); // stop here } - if (token == null) { - return Uni.createFrom().nullItem(); - } + DecodedToken token = result.token().get(); return identityProviderManager.authenticate( HttpSecurityUtils.setRoutingContextAttribute( - new TokenAuthenticationRequest(new InternalPrincipalAuthInfo(credential, token)), + new TokenAuthenticationRequest(new InternalTokenCredential(credential, token)), context)); } @@ -124,31 +122,17 @@ public Uni getCredentialTransport(RoutingContext contex .item(new HttpCredentialTransport(HttpCredentialTransport.Type.AUTHORIZATION, BEARER)); } - static class InternalPrincipalAuthInfo extends TokenCredential - implements QuarkusPrincipalAuthInfo { + static class InternalTokenCredential extends TokenCredential { private final DecodedToken token; - InternalPrincipalAuthInfo(String credential, DecodedToken token) { + InternalTokenCredential(String credential, DecodedToken token) { super(credential, "bearer"); this.token = token; } - @Nullable - @Override - public Long getPrincipalId() { - return token.getPrincipalId(); - } - - @Nullable - @Override - public String getPrincipalName() { - return token.getPrincipalName(); - } - - @Override - public Set getPrincipalRoles() { - return token.getPrincipalRoles(); + public DecodedToken getDecodedToken() { + return token; } } } diff --git a/quarkus/service/src/main/java/org/apache/polaris/service/quarkus/auth/internal/InternalIdentityProvider.java b/quarkus/service/src/main/java/org/apache/polaris/service/quarkus/auth/internal/InternalIdentityProvider.java index f7f29ecb81..928d390003 100644 --- a/quarkus/service/src/main/java/org/apache/polaris/service/quarkus/auth/internal/InternalIdentityProvider.java +++ b/quarkus/service/src/main/java/org/apache/polaris/service/quarkus/auth/internal/InternalIdentityProvider.java @@ -18,6 +18,7 @@ */ package org.apache.polaris.service.quarkus.auth.internal; +import io.quarkus.security.AuthenticationFailedException; import io.quarkus.security.identity.AuthenticationRequestContext; import io.quarkus.security.identity.IdentityProvider; import io.quarkus.security.identity.SecurityIdentity; @@ -26,13 +27,22 @@ import io.quarkus.vertx.http.runtime.security.HttpSecurityUtils; import io.smallrye.mutiny.Uni; import io.vertx.ext.web.RoutingContext; +import jakarta.annotation.Nullable; import jakarta.enterprise.context.ApplicationScoped; +import jakarta.inject.Inject; import java.security.Principal; +import java.util.Set; +import org.apache.polaris.service.auth.DecodedToken; +import org.apache.polaris.service.auth.TokenBroker; +import org.apache.polaris.service.quarkus.auth.QuarkusPrincipalAuthInfo; +import org.apache.polaris.service.quarkus.auth.internal.InternalAuthenticationMechanism.InternalTokenCredential; /** A custom {@link IdentityProvider} that handles internal token authentication requests. */ @ApplicationScoped public class InternalIdentityProvider implements IdentityProvider { + @Inject TokenBroker tokenBroker; + @Override public Class getRequestType() { return TokenAuthenticationRequest.class; @@ -41,21 +51,53 @@ public Class getRequestType() { @Override public Uni authenticate( TokenAuthenticationRequest request, AuthenticationRequestContext context) { - if (!(request.getToken() - instanceof InternalAuthenticationMechanism.InternalPrincipalAuthInfo credential)) { + if (!(request.getToken() instanceof InternalTokenCredential credential)) { return Uni.createFrom().nullItem(); } - InternalTokenPrincipal principal = new InternalTokenPrincipal(credential.getPrincipalName()); - return Uni.createFrom() - .item( - QuarkusSecurityIdentity.builder() - .setPrincipal(principal) - .addCredential(credential) - .addAttribute( - RoutingContext.class.getName(), - HttpSecurityUtils.getRoutingContextAttribute(request)) - .build()); + return Uni.createFrom().item(() -> verifyToken(credential, request)); + } + + private SecurityIdentity verifyToken( + InternalTokenCredential credential, TokenAuthenticationRequest request) { + DecodedToken verified; + try { + verified = tokenBroker.verify(credential.getDecodedToken()); + } catch (Exception e) { + throw new AuthenticationFailedException(e); + } + return QuarkusSecurityIdentity.builder() + .setPrincipal(new InternalTokenPrincipal(verified.getSub())) + .addCredential(new InternalPrincipalAuthInfo(verified)) + .addAttribute( + RoutingContext.class.getName(), HttpSecurityUtils.getRoutingContextAttribute(request)) + .build(); } private record InternalTokenPrincipal(String getName) implements Principal {} + + static class InternalPrincipalAuthInfo implements QuarkusPrincipalAuthInfo { + + private final DecodedToken token; + + InternalPrincipalAuthInfo(DecodedToken token) { + this.token = token; + } + + @Nullable + @Override + public Long getPrincipalId() { + return token.getPrincipalId(); + } + + @Nullable + @Override + public String getPrincipalName() { + return token.getPrincipalName(); + } + + @Override + public Set getPrincipalRoles() { + return token.getPrincipalRoles(); + } + } } diff --git a/quarkus/service/src/test/java/org/apache/polaris/service/quarkus/auth/JWTRSAKeyPairTest.java b/quarkus/service/src/test/java/org/apache/polaris/service/quarkus/auth/JWTRSAKeyPairTest.java index 84ca59926b..7c1a6f3e27 100644 --- a/quarkus/service/src/test/java/org/apache/polaris/service/quarkus/auth/JWTRSAKeyPairTest.java +++ b/quarkus/service/src/test/java/org/apache/polaris/service/quarkus/auth/JWTRSAKeyPairTest.java @@ -79,7 +79,8 @@ public void testSuccessfulTokenGeneration() throws Exception { metastoreManager.loadEntity(polarisCallContext, 0L, 1L, PolarisEntityType.PRINCIPAL)) .thenReturn(new EntityResult(principal)); TokenBroker tokenBroker = - new JWTRSAKeyPair(metastoreManager, 420, publicFileLocation, privateFileLocation); + new JWTRSAKeyPair( + () -> "test", metastoreManager, 420, publicFileLocation, privateFileLocation); TokenResponse token = tokenBroker.generateFromClientSecrets( clientId, diff --git a/quarkus/service/src/test/java/org/apache/polaris/service/quarkus/auth/JWTSymmetricKeyGeneratorTest.java b/quarkus/service/src/test/java/org/apache/polaris/service/quarkus/auth/JWTSymmetricKeyGeneratorTest.java index fed5d20db0..fd1f6fd3cf 100644 --- a/quarkus/service/src/test/java/org/apache/polaris/service/quarkus/auth/JWTSymmetricKeyGeneratorTest.java +++ b/quarkus/service/src/test/java/org/apache/polaris/service/quarkus/auth/JWTSymmetricKeyGeneratorTest.java @@ -84,7 +84,8 @@ public Map contextVariables() { Mockito.when( metastoreManager.loadEntity(polarisCallContext, 0L, 1L, PolarisEntityType.PRINCIPAL)) .thenReturn(new EntityResult(principal)); - TokenBroker generator = new JWTSymmetricKeyBroker(metastoreManager, 666, () -> "polaris"); + TokenBroker generator = + new JWTSymmetricKeyBroker(() -> "test", metastoreManager, 666, () -> "polaris"); TokenResponse token = generator.generateFromClientSecrets( clientId, diff --git a/quarkus/service/src/test/java/org/apache/polaris/service/quarkus/auth/internal/InternalAuthenticationMechanismTest.java b/quarkus/service/src/test/java/org/apache/polaris/service/quarkus/auth/internal/InternalAuthenticationMechanismTest.java index 403b94cdaa..5b70f20f54 100644 --- a/quarkus/service/src/test/java/org/apache/polaris/service/quarkus/auth/internal/InternalAuthenticationMechanismTest.java +++ b/quarkus/service/src/test/java/org/apache/polaris/service/quarkus/auth/internal/InternalAuthenticationMechanismTest.java @@ -31,16 +31,19 @@ import io.quarkus.security.identity.SecurityIdentity; import io.quarkus.security.identity.request.TokenAuthenticationRequest; import io.smallrye.mutiny.Uni; +import io.vertx.core.http.HttpServerRequest; import io.vertx.ext.web.RoutingContext; -import org.apache.iceberg.exceptions.NotAuthorizedException; +import java.util.Optional; import org.apache.polaris.service.auth.AuthenticationRealmConfiguration; import org.apache.polaris.service.auth.AuthenticationType; import org.apache.polaris.service.auth.DecodedToken; import org.apache.polaris.service.auth.TokenBroker; +import org.apache.polaris.service.auth.TokenDecodeResult; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.CsvSource; +import org.junit.jupiter.params.provider.EnumSource; public class InternalAuthenticationMechanismTest { @@ -79,47 +82,73 @@ public void testShouldProcess(AuthenticationType type, boolean expectedResult) { @Test public void testAuthenticateWithNoAuthHeader() { when(configuration.type()).thenReturn(AuthenticationType.INTERNAL); - when(routingContext.request()).thenReturn(mock(io.vertx.core.http.HttpServerRequest.class)); + when(routingContext.request()).thenReturn(mock(HttpServerRequest.class)); when(routingContext.request().getHeader("Authorization")).thenReturn(null); Uni result = mechanism.authenticate(routingContext, identityProviderManager); assertThat(result.await().indefinitely()).isNull(); - verify(tokenBroker, never()).verify(any()); + verify(tokenBroker, never()).decode(any()); } @Test public void testAuthenticateWithInvalidAuthHeaderFormat() { when(configuration.type()).thenReturn(AuthenticationType.INTERNAL); - when(routingContext.request()).thenReturn(mock(io.vertx.core.http.HttpServerRequest.class)); + when(routingContext.request()).thenReturn(mock(HttpServerRequest.class)); when(routingContext.request().getHeader("Authorization")).thenReturn("InvalidFormat"); Uni result = mechanism.authenticate(routingContext, identityProviderManager); assertThat(result.await().indefinitely()).isNull(); - verify(tokenBroker, never()).verify(any()); + verify(tokenBroker, never()).decode(any()); } @Test public void testAuthenticateWithNonBearerAuthHeader() { when(configuration.type()).thenReturn(AuthenticationType.INTERNAL); - when(routingContext.request()).thenReturn(mock(io.vertx.core.http.HttpServerRequest.class)); + when(routingContext.request()).thenReturn(mock(HttpServerRequest.class)); when(routingContext.request().getHeader("Authorization")).thenReturn("Basic dXNlcjpwYXNz"); Uni result = mechanism.authenticate(routingContext, identityProviderManager); assertThat(result.await().indefinitely()).isNull(); - verify(tokenBroker, never()).verify(any()); + verify(tokenBroker, never()).decode(any()); } @Test - public void testAuthenticateWithInvalidTokenInternalAuth() { + public void testAuthenticateWithMalformedToken() { + when(configuration.type()).thenReturn(AuthenticationType.MIXED); + when(routingContext.request()).thenReturn(mock(HttpServerRequest.class)); + when(routingContext.request().getHeader("Authorization")).thenReturn("Bearer malformedToken"); + + when(tokenBroker.decode("malformedToken")) + .thenReturn( + new TokenDecodeResult(TokenDecodeResult.Status.MALFORMED_TOKEN, Optional.empty())); + + SecurityIdentity securityIdentity = mock(SecurityIdentity.class); + when(identityProviderManager.authenticate(any())) + .thenReturn(Uni.createFrom().item(securityIdentity)); + + Uni result = mechanism.authenticate(routingContext, identityProviderManager); + + assertThatThrownBy(() -> result.await().indefinitely()) + .isInstanceOf(AuthenticationFailedException.class) + .hasMessage("Malformed token"); + verify(tokenBroker).decode("malformedToken"); + verify(identityProviderManager, never()).authenticate(any(TokenAuthenticationRequest.class)); + } + + @ParameterizedTest + @EnumSource( + value = TokenDecodeResult.Status.class, + names = {"INVALID_ISSUER", "INVALID_REALM"}) + public void testAuthenticateWithInvalidTokenInternalAuth(TokenDecodeResult.Status status) { when(configuration.type()).thenReturn(AuthenticationType.INTERNAL); - when(routingContext.request()).thenReturn(mock(io.vertx.core.http.HttpServerRequest.class)); + when(routingContext.request()).thenReturn(mock(HttpServerRequest.class)); when(routingContext.request().getHeader("Authorization")).thenReturn("Bearer invalidToken"); - NotAuthorizedException cause = new NotAuthorizedException("Invalid token"); - when(tokenBroker.verify("invalidToken")).thenThrow(cause); + when(tokenBroker.decode("invalidToken")) + .thenReturn(new TokenDecodeResult(status, Optional.empty())); SecurityIdentity securityIdentity = mock(SecurityIdentity.class); when(identityProviderManager.authenticate(any())) @@ -129,19 +158,22 @@ public void testAuthenticateWithInvalidTokenInternalAuth() { assertThatThrownBy(() -> result.await().indefinitely()) .isInstanceOf(AuthenticationFailedException.class) - .hasCause(cause); - verify(tokenBroker).verify("invalidToken"); + .hasMessage(status.message()); + verify(tokenBroker).decode("invalidToken"); verify(identityProviderManager, never()).authenticate(any(TokenAuthenticationRequest.class)); } - @Test - public void testAuthenticateWithInvalidTokenMixedAuth() { + @ParameterizedTest + @EnumSource( + value = TokenDecodeResult.Status.class, + names = {"INVALID_ISSUER", "INVALID_REALM"}) + public void testAuthenticateWithInvalidTokenMixedAuth(TokenDecodeResult.Status status) { when(configuration.type()).thenReturn(AuthenticationType.MIXED); - when(routingContext.request()).thenReturn(mock(io.vertx.core.http.HttpServerRequest.class)); + when(routingContext.request()).thenReturn(mock(HttpServerRequest.class)); when(routingContext.request().getHeader("Authorization")).thenReturn("Bearer invalidToken"); - NotAuthorizedException cause = new NotAuthorizedException("Invalid token"); - when(tokenBroker.verify("invalidToken")).thenThrow(cause); + when(tokenBroker.decode("invalidToken")) + .thenReturn(new TokenDecodeResult(status, Optional.empty())); SecurityIdentity securityIdentity = mock(SecurityIdentity.class); when(identityProviderManager.authenticate(any())) @@ -150,18 +182,20 @@ public void testAuthenticateWithInvalidTokenMixedAuth() { Uni result = mechanism.authenticate(routingContext, identityProviderManager); assertThat(result.await().indefinitely()).isNull(); - verify(tokenBroker).verify("invalidToken"); + verify(tokenBroker).decode("invalidToken"); verify(identityProviderManager, never()).authenticate(any(TokenAuthenticationRequest.class)); } @Test public void testAuthenticateWithValidToken() { when(configuration.type()).thenReturn(AuthenticationType.INTERNAL); - when(routingContext.request()).thenReturn(mock(io.vertx.core.http.HttpServerRequest.class)); + when(routingContext.request()).thenReturn(mock(HttpServerRequest.class)); when(routingContext.request().getHeader("Authorization")).thenReturn("Bearer validToken"); DecodedToken decodedToken = mock(DecodedToken.class); - when(tokenBroker.verify("validToken")).thenReturn(decodedToken); + when(tokenBroker.decode("validToken")) + .thenReturn( + new TokenDecodeResult(TokenDecodeResult.Status.SUCCESS, Optional.of(decodedToken))); SecurityIdentity securityIdentity = mock(SecurityIdentity.class); when(identityProviderManager.authenticate(any())) @@ -170,7 +204,7 @@ public void testAuthenticateWithValidToken() { Uni result = mechanism.authenticate(routingContext, identityProviderManager); assertThat(result.await().indefinitely()).isSameAs(securityIdentity); - verify(tokenBroker).verify("validToken"); + verify(tokenBroker).decode("validToken"); verify(identityProviderManager).authenticate(any(TokenAuthenticationRequest.class)); } } diff --git a/quarkus/service/src/test/java/org/apache/polaris/service/quarkus/auth/internal/InternalIdentityProviderTest.java b/quarkus/service/src/test/java/org/apache/polaris/service/quarkus/auth/internal/InternalIdentityProviderTest.java index 18b623713c..f684508902 100644 --- a/quarkus/service/src/test/java/org/apache/polaris/service/quarkus/auth/internal/InternalIdentityProviderTest.java +++ b/quarkus/service/src/test/java/org/apache/polaris/service/quarkus/auth/internal/InternalIdentityProviderTest.java @@ -19,9 +19,11 @@ package org.apache.polaris.service.quarkus.auth.internal; import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.when; +import io.quarkus.security.AuthenticationFailedException; import io.quarkus.security.credential.TokenCredential; import io.quarkus.security.identity.AuthenticationRequestContext; import io.quarkus.security.identity.SecurityIdentity; @@ -30,7 +32,12 @@ import io.smallrye.mutiny.Uni; import io.vertx.ext.web.RoutingContext; import java.security.Principal; -import org.apache.polaris.service.quarkus.auth.internal.InternalAuthenticationMechanism.InternalPrincipalAuthInfo; +import java.util.Set; +import org.apache.iceberg.exceptions.NotAuthorizedException; +import org.apache.polaris.service.auth.DecodedToken; +import org.apache.polaris.service.auth.TokenBroker; +import org.apache.polaris.service.quarkus.auth.QuarkusPrincipalAuthInfo; +import org.apache.polaris.service.quarkus.auth.internal.InternalAuthenticationMechanism.InternalTokenCredential; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; @@ -38,11 +45,14 @@ public class InternalIdentityProviderTest { private InternalIdentityProvider provider; private AuthenticationRequestContext context; + private TokenBroker tokenBroker; @BeforeEach public void setup() { provider = new InternalIdentityProvider(); context = mock(AuthenticationRequestContext.class); + tokenBroker = mock(TokenBroker.class); + provider.tokenBroker = tokenBroker; } @Test @@ -55,11 +65,36 @@ public void testAuthenticateWithWrongCredential() { assertThat(result.await().indefinitely()).isNull(); } + @Test + public void testAuthenticateWithInvalidCredential() { + // Create a mock InternalTokenCredential + InternalTokenCredential credential = mock(InternalTokenCredential.class); + DecodedToken token = mock(DecodedToken.class); + when(credential.getDecodedToken()).thenReturn(token); + when(tokenBroker.verify(token)).thenThrow(new NotAuthorizedException("Invalid token")); + + // Create a request with the credential + TokenAuthenticationRequest request = new TokenAuthenticationRequest(credential); + + // Authenticate the request + Uni result = provider.authenticate(request, context); + + // Verify the result + assertThatThrownBy(() -> result.await().indefinitely()) + .isInstanceOf(AuthenticationFailedException.class) + .hasCauseInstanceOf(NotAuthorizedException.class); + } + @Test public void testAuthenticateWithValidCredential() { - // Create a mock InternalPrincipalAuthInfo - InternalPrincipalAuthInfo credential = mock(InternalPrincipalAuthInfo.class); - when(credential.getPrincipalName()).thenReturn("testUser"); + // Create a mock InternalTokenCredential + InternalTokenCredential credential = mock(InternalTokenCredential.class); + DecodedToken token = mock(DecodedToken.class); + when(token.getPrincipalId()).thenReturn(123L); + when(token.getSub()).thenReturn("123"); + when(token.getPrincipalRoles()).thenReturn(Set.of("role1", "role2")); + when(credential.getDecodedToken()).thenReturn(token); + when(tokenBroker.verify(token)).thenReturn(token); // Create a request with the credential and a routing context attribute RoutingContext routingContext = mock(RoutingContext.class); @@ -76,10 +111,14 @@ public void testAuthenticateWithValidCredential() { // Verify the principal Principal principal = identity.getPrincipal(); assertThat(principal).isNotNull(); - assertThat(principal.getName()).isEqualTo("testUser"); + assertThat(principal.getName()).isEqualTo("123"); - // Verify the credential is set - assertThat(identity.getCredential(InternalPrincipalAuthInfo.class)).isSameAs(credential); + // Verify the credential + QuarkusPrincipalAuthInfo authInfo = identity.getCredential(QuarkusPrincipalAuthInfo.class); + assertThat(authInfo).isNotNull(); + assertThat(authInfo.getPrincipalName()).isNull(); // not set by Polaris tokens + assertThat(authInfo.getPrincipalId()).isEqualTo(123L); + assertThat(authInfo.getPrincipalRoles()).containsExactlyInAnyOrder("role1", "role2"); // Verify the routing context attribute is set assertThat((RoutingContext) identity.getAttribute(RoutingContext.class.getName())) diff --git a/service/common/src/main/java/org/apache/polaris/service/auth/DecodedToken.java b/service/common/src/main/java/org/apache/polaris/service/auth/DecodedToken.java index 73fada6752..33bbb21929 100644 --- a/service/common/src/main/java/org/apache/polaris/service/auth/DecodedToken.java +++ b/service/common/src/main/java/org/apache/polaris/service/auth/DecodedToken.java @@ -28,6 +28,15 @@ */ public interface DecodedToken extends PrincipalAuthInfo { + /** + * Returns the underlying implementation of the decoded JWT token. + * + * @param clazz the class of the token to unwrap + * @return the underlying implementation of the decoded JWT token, never null + * @throws ClassCastException if the token cannot be unwrapped to the specified type + */ + T unwrap(Class clazz); + String getClientId(); String getSub(); diff --git a/service/common/src/main/java/org/apache/polaris/service/auth/JWTBroker.java b/service/common/src/main/java/org/apache/polaris/service/auth/JWTBroker.java index c0ce0b471c..a75c961b42 100644 --- a/service/common/src/main/java/org/apache/polaris/service/auth/JWTBroker.java +++ b/service/common/src/main/java/org/apache/polaris/service/auth/JWTBroker.java @@ -20,6 +20,7 @@ import com.auth0.jwt.JWT; import com.auth0.jwt.algorithms.Algorithm; +import com.auth0.jwt.exceptions.JWTDecodeException; import com.auth0.jwt.exceptions.JWTVerificationException; import com.auth0.jwt.interfaces.DecodedJWT; import com.auth0.jwt.interfaces.JWTVerifier; @@ -32,6 +33,7 @@ import org.apache.iceberg.exceptions.NotAuthorizedException; import org.apache.polaris.core.PolarisCallContext; import org.apache.polaris.core.context.CallContext; +import org.apache.polaris.core.context.RealmContext; import org.apache.polaris.core.entity.PolarisEntityType; import org.apache.polaris.core.entity.PrincipalEntity; import org.apache.polaris.core.persistence.PolarisMetaStoreManager; @@ -47,14 +49,20 @@ public abstract class JWTBroker implements TokenBroker { private static final String ISSUER_KEY = "polaris"; private static final String CLAIM_KEY_ACTIVE = "active"; + private static final String CLAIM_KEY_REALM = "realm"; private static final String CLAIM_KEY_CLIENT_ID = "client_id"; private static final String CLAIM_KEY_PRINCIPAL_ID = "principalId"; private static final String CLAIM_KEY_SCOPE = "scope"; + private final RealmContext realmContext; private final PolarisMetaStoreManager metaStoreManager; private final int maxTokenGenerationInSeconds; - JWTBroker(PolarisMetaStoreManager metaStoreManager, int maxTokenGenerationInSeconds) { + JWTBroker( + RealmContext realmContext, + PolarisMetaStoreManager metaStoreManager, + int maxTokenGenerationInSeconds) { + this.realmContext = realmContext; this.metaStoreManager = metaStoreManager; this.maxTokenGenerationInSeconds = maxTokenGenerationInSeconds; } @@ -62,32 +70,34 @@ public abstract class JWTBroker implements TokenBroker { public abstract Algorithm getAlgorithm(); @Override - public DecodedToken verify(String token) { - JWTVerifier verifier = JWT.require(getAlgorithm()).withClaim(CLAIM_KEY_ACTIVE, true).build(); + public TokenDecodeResult decode(String token) { + DecodedJWT jwt; + try { + jwt = JWT.decode(token); + } catch (JWTDecodeException e) { + return new TokenDecodeResult(TokenDecodeResult.Status.MALFORMED_TOKEN, Optional.empty()); + } + if (!ISSUER_KEY.equals(jwt.getIssuer())) { + return new TokenDecodeResult(TokenDecodeResult.Status.INVALID_ISSUER, Optional.empty()); + } + if (!realmContext.getRealmIdentifier().equals(jwt.getClaim(CLAIM_KEY_REALM).asString())) { + return new TokenDecodeResult(TokenDecodeResult.Status.INVALID_REALM, Optional.empty()); + } + return new TokenDecodeResult( + TokenDecodeResult.Status.SUCCESS, Optional.of(new DecodedTokenImpl(jwt))); + } + + @Override + public DecodedToken verify(DecodedToken token) { + JWTVerifier verifier = + JWT.require(getAlgorithm()) + .withClaim(CLAIM_KEY_ACTIVE, true) + .withClaim(CLAIM_KEY_REALM, realmContext.getRealmIdentifier()) + .build(); try { - DecodedJWT decodedJWT = verifier.verify(token); - return new DecodedToken() { - @Override - public Long getPrincipalId() { - return decodedJWT.getClaim("principalId").asLong(); - } - - @Override - public String getClientId() { - return decodedJWT.getClaim("client_id").asString(); - } - - @Override - public String getSub() { - return decodedJWT.getSubject(); - } - - @Override - public String getScope() { - return decodedJWT.getClaim("scope").asString(); - } - }; + DecodedJWT decodedJWT = verifier.verify(token.unwrap(DecodedJWT.class)); + return new DecodedTokenImpl(decodedJWT); } catch (JWTVerificationException e) { LOGGER.error("Failed to verify the token with error", e); @@ -170,6 +180,7 @@ private String generateTokenString(String clientId, String scope, Long principal .withExpiresAt(now.plus(maxTokenGenerationInSeconds, ChronoUnit.SECONDS)) .withJWTId(UUID.randomUUID().toString()) .withClaim(CLAIM_KEY_ACTIVE, true) + .withClaim(CLAIM_KEY_REALM, realmContext.getRealmIdentifier()) .withClaim(CLAIM_KEY_CLIENT_ID, clientId) .withClaim(CLAIM_KEY_PRINCIPAL_ID, principalId) .withClaim(CLAIM_KEY_SCOPE, scopes(scope)) @@ -189,4 +200,32 @@ public boolean supportsRequestedTokenType(TokenType tokenType) { private String scopes(String scope) { return StringUtils.isNotBlank(scope) ? scope : DefaultAuthenticator.PRINCIPAL_ROLE_ALL; } + + private record DecodedTokenImpl(DecodedJWT decodedJWT) implements DecodedToken { + + @Override + public T unwrap(Class clazz) { + return clazz.cast(decodedJWT); + } + + @Override + public Long getPrincipalId() { + return decodedJWT.getClaim("principalId").asLong(); + } + + @Override + public String getClientId() { + return decodedJWT.getClaim("client_id").asString(); + } + + @Override + public String getSub() { + return decodedJWT.getSubject(); + } + + @Override + public String getScope() { + return decodedJWT.getClaim("scope").asString(); + } + } } diff --git a/service/common/src/main/java/org/apache/polaris/service/auth/JWTRSAKeyPair.java b/service/common/src/main/java/org/apache/polaris/service/auth/JWTRSAKeyPair.java index 18f2702380..c051f7f751 100644 --- a/service/common/src/main/java/org/apache/polaris/service/auth/JWTRSAKeyPair.java +++ b/service/common/src/main/java/org/apache/polaris/service/auth/JWTRSAKeyPair.java @@ -22,6 +22,7 @@ import java.nio.file.Path; import java.security.interfaces.RSAPrivateKey; import java.security.interfaces.RSAPublicKey; +import org.apache.polaris.core.context.RealmContext; import org.apache.polaris.core.persistence.PolarisMetaStoreManager; /** Generates a JWT using a Public/Private RSA Key */ @@ -30,11 +31,12 @@ public class JWTRSAKeyPair extends JWTBroker { private final KeyProvider keyProvider; public JWTRSAKeyPair( + RealmContext realmContext, PolarisMetaStoreManager metaStoreManager, int maxTokenGenerationInSeconds, Path publicKeyFile, Path privateKeyFile) { - super(metaStoreManager, maxTokenGenerationInSeconds); + super(realmContext, metaStoreManager, maxTokenGenerationInSeconds); keyProvider = new LocalRSAKeyProvider(publicKeyFile, privateKeyFile); } diff --git a/service/common/src/main/java/org/apache/polaris/service/auth/JWTRSAKeyPairFactory.java b/service/common/src/main/java/org/apache/polaris/service/auth/JWTRSAKeyPairFactory.java index ee74caf466..9adeb34c68 100644 --- a/service/common/src/main/java/org/apache/polaris/service/auth/JWTRSAKeyPairFactory.java +++ b/service/common/src/main/java/org/apache/polaris/service/auth/JWTRSAKeyPairFactory.java @@ -64,6 +64,7 @@ private JWTRSAKeyPair createTokenBroker(RealmContext realmContext) { PolarisMetaStoreManager metaStoreManager = metaStoreManagerFactory.getOrCreateMetaStoreManager(realmContext); return new JWTRSAKeyPair( + realmContext, metaStoreManager, (int) maxTokenGeneration.toSeconds(), keyPairConfiguration.publicKeyFile(), diff --git a/service/common/src/main/java/org/apache/polaris/service/auth/JWTSymmetricKeyBroker.java b/service/common/src/main/java/org/apache/polaris/service/auth/JWTSymmetricKeyBroker.java index 16c9e15511..da72a20043 100644 --- a/service/common/src/main/java/org/apache/polaris/service/auth/JWTSymmetricKeyBroker.java +++ b/service/common/src/main/java/org/apache/polaris/service/auth/JWTSymmetricKeyBroker.java @@ -20,6 +20,7 @@ import com.auth0.jwt.algorithms.Algorithm; import java.util.function.Supplier; +import org.apache.polaris.core.context.RealmContext; import org.apache.polaris.core.persistence.PolarisMetaStoreManager; /** Generates a JWT using a Symmetric Key. */ @@ -27,10 +28,11 @@ public class JWTSymmetricKeyBroker extends JWTBroker { private final Supplier secretSupplier; public JWTSymmetricKeyBroker( + RealmContext realmContext, PolarisMetaStoreManager metaStoreManager, int maxTokenGenerationInSeconds, Supplier secretSupplier) { - super(metaStoreManager, maxTokenGenerationInSeconds); + super(realmContext, metaStoreManager, maxTokenGenerationInSeconds); this.secretSupplier = secretSupplier; } diff --git a/service/common/src/main/java/org/apache/polaris/service/auth/JWTSymmetricKeyFactory.java b/service/common/src/main/java/org/apache/polaris/service/auth/JWTSymmetricKeyFactory.java index ed33800646..1f381e7cbf 100644 --- a/service/common/src/main/java/org/apache/polaris/service/auth/JWTSymmetricKeyFactory.java +++ b/service/common/src/main/java/org/apache/polaris/service/auth/JWTSymmetricKeyFactory.java @@ -71,6 +71,7 @@ private JWTSymmetricKeyBroker createTokenBroker(RealmContext realmContext) { checkState(secret != null || file != null, "Either file or secret must be set"); Supplier secretSupplier = secret != null ? () -> secret : readSecretFromDisk(file); return new JWTSymmetricKeyBroker( + realmContext, metaStoreManagerFactory.getOrCreateMetaStoreManager(realmContext), (int) maxTokenGeneration.toSeconds(), secretSupplier); diff --git a/service/common/src/main/java/org/apache/polaris/service/auth/NoneTokenBrokerFactory.java b/service/common/src/main/java/org/apache/polaris/service/auth/NoneTokenBrokerFactory.java index a352159372..64ed0ad2c5 100644 --- a/service/common/src/main/java/org/apache/polaris/service/auth/NoneTokenBrokerFactory.java +++ b/service/common/src/main/java/org/apache/polaris/service/auth/NoneTokenBrokerFactory.java @@ -63,7 +63,12 @@ public TokenResponse generateFromToken( } @Override - public DecodedToken verify(String token) { + public TokenDecodeResult decode(String token) { + return null; + } + + @Override + public DecodedToken verify(DecodedToken token) { return null; } }; diff --git a/service/common/src/main/java/org/apache/polaris/service/auth/TokenBroker.java b/service/common/src/main/java/org/apache/polaris/service/auth/TokenBroker.java index b5d242070a..5005370765 100644 --- a/service/common/src/main/java/org/apache/polaris/service/auth/TokenBroker.java +++ b/service/common/src/main/java/org/apache/polaris/service/auth/TokenBroker.java @@ -20,6 +20,7 @@ import jakarta.annotation.Nonnull; import java.util.Optional; +import org.apache.iceberg.exceptions.NotAuthorizedException; import org.apache.polaris.core.PolarisCallContext; import org.apache.polaris.core.entity.PolarisEntityType; import org.apache.polaris.core.entity.PrincipalEntity; @@ -109,7 +110,40 @@ TokenResponse generateFromToken( final String scope, TokenType requestedTokenType); - DecodedToken verify(String token); + /** + * Decode a token. This method does not verify the token signature nor its expiration; it only + * checks the issuer and realm. Do NOT rely on this method only for authentication: {@link + * #verify} must also be called to ensure the token is valid. + * + *

This method never throws an exception. If the token is invalid, it will return a {@link + * TokenDecodeResult} with the error code set. + * + * @param token The token to decode + * @return The decode result + */ + TokenDecodeResult decode(String token); + + /** + * Verify a token and return the decoded token. This method verifies the token signature and + * required claims. + * + * @param token The decoded token to verify + * @return The verified token + * @throws NotAuthorizedException if the token is invalid + */ + DecodedToken verify(DecodedToken token); + + /** + * Verify a token and return the decoded token. This method verifies the token signature and + * required claims. This is a convenience method that calls {@link #decode(String)} and then + * {@link #verify(DecodedToken)}. + * + * @param token The token to verify + * @return The verified token + */ + default DecodedToken verify(String token) { + return verify(decode(token).token().orElseThrow()); + } static @Nonnull Optional findPrincipalEntity( PolarisMetaStoreManager metaStoreManager, diff --git a/service/common/src/main/java/org/apache/polaris/service/auth/TokenDecodeResult.java b/service/common/src/main/java/org/apache/polaris/service/auth/TokenDecodeResult.java new file mode 100644 index 0000000000..7a84a7c22f --- /dev/null +++ b/service/common/src/main/java/org/apache/polaris/service/auth/TokenDecodeResult.java @@ -0,0 +1,47 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.polaris.service.auth; + +import com.google.common.base.Preconditions; +import java.util.Optional; + +public record TokenDecodeResult(Status status, Optional token) { + + public TokenDecodeResult { + Preconditions.checkState(!(status == Status.SUCCESS && token.isEmpty())); + Preconditions.checkState(!(status != Status.SUCCESS && token.isPresent())); + } + + public enum Status { + SUCCESS(""), + MALFORMED_TOKEN("Malformed token"), + INVALID_ISSUER("Invalid issuer"), + INVALID_REALM("Invalid realm"), + ; + private final String message; + + Status(String message) { + this.message = message; + } + + public String message() { + return message; + } + } +}