Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,6 @@
import io.quarkus.vertx.http.runtime.security.HttpSecurityUtils;
import io.smallrye.mutiny.Uni;
import io.vertx.ext.web.RoutingContext;
import jakarta.annotation.Nullable;
import jakarta.enterprise.context.ApplicationScoped;
import jakarta.inject.Inject;
import java.util.Collections;
Expand All @@ -41,7 +40,7 @@
import org.apache.polaris.service.auth.AuthenticationType;
import org.apache.polaris.service.auth.DecodedToken;
import org.apache.polaris.service.auth.TokenBroker;
import org.apache.polaris.service.quarkus.auth.QuarkusPrincipalAuthInfo;
import org.apache.polaris.service.auth.TokenDecodeResult;

/**
* A custom {@link HttpAuthenticationMechanism} that handles internal token authentication, that is,
Expand Down Expand Up @@ -86,22 +85,21 @@ public Uni<SecurityIdentity> authenticate(

String credential = authHeader.substring(spaceIdx + 1);

DecodedToken token;
try {
token = tokenBroker.verify(credential);
} catch (Exception e) {
TokenDecodeResult result = tokenBroker.decode(credential);

if (result.token().isEmpty()) {
return configuration.type() == AuthenticationType.MIXED
&& result.status() != TokenDecodeResult.Status.MALFORMED_TOKEN
? Uni.createFrom().nullItem() // let other auth mechanisms handle it
: Uni.createFrom().failure(new AuthenticationFailedException(e)); // stop here
: Uni.createFrom()
.failure(new AuthenticationFailedException(result.status().message())); // stop here
}

if (token == null) {
return Uni.createFrom().nullItem();
}
DecodedToken token = result.token().get();

return identityProviderManager.authenticate(
HttpSecurityUtils.setRoutingContextAttribute(
new TokenAuthenticationRequest(new InternalPrincipalAuthInfo(credential, token)),
new TokenAuthenticationRequest(new InternalTokenCredential(credential, token)),
context));
}

Expand All @@ -124,31 +122,17 @@ public Uni<HttpCredentialTransport> getCredentialTransport(RoutingContext contex
.item(new HttpCredentialTransport(HttpCredentialTransport.Type.AUTHORIZATION, BEARER));
}

static class InternalPrincipalAuthInfo extends TokenCredential
implements QuarkusPrincipalAuthInfo {
static class InternalTokenCredential extends TokenCredential {

private final DecodedToken token;

InternalPrincipalAuthInfo(String credential, DecodedToken token) {
InternalTokenCredential(String credential, DecodedToken token) {
super(credential, "bearer");
this.token = token;
}

@Nullable
@Override
public Long getPrincipalId() {
return token.getPrincipalId();
}

@Nullable
@Override
public String getPrincipalName() {
return token.getPrincipalName();
}

@Override
public Set<String> getPrincipalRoles() {
return token.getPrincipalRoles();
public DecodedToken getDecodedToken() {
return token;
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
*/
package org.apache.polaris.service.quarkus.auth.internal;

import io.quarkus.security.AuthenticationFailedException;
import io.quarkus.security.identity.AuthenticationRequestContext;
import io.quarkus.security.identity.IdentityProvider;
import io.quarkus.security.identity.SecurityIdentity;
Expand All @@ -26,13 +27,22 @@
import io.quarkus.vertx.http.runtime.security.HttpSecurityUtils;
import io.smallrye.mutiny.Uni;
import io.vertx.ext.web.RoutingContext;
import jakarta.annotation.Nullable;
import jakarta.enterprise.context.ApplicationScoped;
import jakarta.inject.Inject;
import java.security.Principal;
import java.util.Set;
import org.apache.polaris.service.auth.DecodedToken;
import org.apache.polaris.service.auth.TokenBroker;
import org.apache.polaris.service.quarkus.auth.QuarkusPrincipalAuthInfo;
import org.apache.polaris.service.quarkus.auth.internal.InternalAuthenticationMechanism.InternalTokenCredential;

/** A custom {@link IdentityProvider} that handles internal token authentication requests. */
@ApplicationScoped
public class InternalIdentityProvider implements IdentityProvider<TokenAuthenticationRequest> {

@Inject TokenBroker tokenBroker;

@Override
public Class<TokenAuthenticationRequest> getRequestType() {
return TokenAuthenticationRequest.class;
Expand All @@ -41,21 +51,53 @@ public Class<TokenAuthenticationRequest> getRequestType() {
@Override
public Uni<SecurityIdentity> authenticate(
TokenAuthenticationRequest request, AuthenticationRequestContext context) {
if (!(request.getToken()
instanceof InternalAuthenticationMechanism.InternalPrincipalAuthInfo credential)) {
if (!(request.getToken() instanceof InternalTokenCredential credential)) {
return Uni.createFrom().nullItem();
}
InternalTokenPrincipal principal = new InternalTokenPrincipal(credential.getPrincipalName());
return Uni.createFrom()
.item(
QuarkusSecurityIdentity.builder()
.setPrincipal(principal)
.addCredential(credential)
.addAttribute(
RoutingContext.class.getName(),
HttpSecurityUtils.getRoutingContextAttribute(request))
.build());
return Uni.createFrom().item(() -> verifyToken(credential, request));
}

private SecurityIdentity verifyToken(
InternalTokenCredential credential, TokenAuthenticationRequest request) {
DecodedToken verified;
try {
verified = tokenBroker.verify(credential.getDecodedToken());
} catch (Exception e) {
throw new AuthenticationFailedException(e);
}
return QuarkusSecurityIdentity.builder()
.setPrincipal(new InternalTokenPrincipal(verified.getSub()))
.addCredential(new InternalPrincipalAuthInfo(verified))
.addAttribute(
RoutingContext.class.getName(), HttpSecurityUtils.getRoutingContextAttribute(request))
.build();
}

private record InternalTokenPrincipal(String getName) implements Principal {}

static class InternalPrincipalAuthInfo implements QuarkusPrincipalAuthInfo {

private final DecodedToken token;

InternalPrincipalAuthInfo(DecodedToken token) {
this.token = token;
}

@Nullable
@Override
public Long getPrincipalId() {
return token.getPrincipalId();
}

@Nullable
@Override
public String getPrincipalName() {
return token.getPrincipalName();
}

@Override
public Set<String> getPrincipalRoles() {
return token.getPrincipalRoles();
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,8 @@ public void testSuccessfulTokenGeneration() throws Exception {
metastoreManager.loadEntity(polarisCallContext, 0L, 1L, PolarisEntityType.PRINCIPAL))
.thenReturn(new EntityResult(principal));
TokenBroker tokenBroker =
new JWTRSAKeyPair(metastoreManager, 420, publicFileLocation, privateFileLocation);
new JWTRSAKeyPair(
() -> "test", metastoreManager, 420, publicFileLocation, privateFileLocation);
TokenResponse token =
tokenBroker.generateFromClientSecrets(
clientId,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,8 @@ public Map<String, Object> contextVariables() {
Mockito.when(
metastoreManager.loadEntity(polarisCallContext, 0L, 1L, PolarisEntityType.PRINCIPAL))
.thenReturn(new EntityResult(principal));
TokenBroker generator = new JWTSymmetricKeyBroker(metastoreManager, 666, () -> "polaris");
TokenBroker generator =
new JWTSymmetricKeyBroker(() -> "test", metastoreManager, 666, () -> "polaris");
TokenResponse token =
generator.generateFromClientSecrets(
clientId,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,16 +31,19 @@
import io.quarkus.security.identity.SecurityIdentity;
import io.quarkus.security.identity.request.TokenAuthenticationRequest;
import io.smallrye.mutiny.Uni;
import io.vertx.core.http.HttpServerRequest;
import io.vertx.ext.web.RoutingContext;
import org.apache.iceberg.exceptions.NotAuthorizedException;
import java.util.Optional;
import org.apache.polaris.service.auth.AuthenticationRealmConfiguration;
import org.apache.polaris.service.auth.AuthenticationType;
import org.apache.polaris.service.auth.DecodedToken;
import org.apache.polaris.service.auth.TokenBroker;
import org.apache.polaris.service.auth.TokenDecodeResult;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.CsvSource;
import org.junit.jupiter.params.provider.EnumSource;

public class InternalAuthenticationMechanismTest {

Expand Down Expand Up @@ -79,47 +82,73 @@ public void testShouldProcess(AuthenticationType type, boolean expectedResult) {
@Test
public void testAuthenticateWithNoAuthHeader() {
when(configuration.type()).thenReturn(AuthenticationType.INTERNAL);
when(routingContext.request()).thenReturn(mock(io.vertx.core.http.HttpServerRequest.class));
when(routingContext.request()).thenReturn(mock(HttpServerRequest.class));
when(routingContext.request().getHeader("Authorization")).thenReturn(null);

Uni<SecurityIdentity> result = mechanism.authenticate(routingContext, identityProviderManager);

assertThat(result.await().indefinitely()).isNull();
verify(tokenBroker, never()).verify(any());
verify(tokenBroker, never()).decode(any());
}

@Test
public void testAuthenticateWithInvalidAuthHeaderFormat() {
when(configuration.type()).thenReturn(AuthenticationType.INTERNAL);
when(routingContext.request()).thenReturn(mock(io.vertx.core.http.HttpServerRequest.class));
when(routingContext.request()).thenReturn(mock(HttpServerRequest.class));
when(routingContext.request().getHeader("Authorization")).thenReturn("InvalidFormat");

Uni<SecurityIdentity> result = mechanism.authenticate(routingContext, identityProviderManager);

assertThat(result.await().indefinitely()).isNull();
verify(tokenBroker, never()).verify(any());
verify(tokenBroker, never()).decode(any());
}

@Test
public void testAuthenticateWithNonBearerAuthHeader() {
when(configuration.type()).thenReturn(AuthenticationType.INTERNAL);
when(routingContext.request()).thenReturn(mock(io.vertx.core.http.HttpServerRequest.class));
when(routingContext.request()).thenReturn(mock(HttpServerRequest.class));
when(routingContext.request().getHeader("Authorization")).thenReturn("Basic dXNlcjpwYXNz");

Uni<SecurityIdentity> result = mechanism.authenticate(routingContext, identityProviderManager);

assertThat(result.await().indefinitely()).isNull();
verify(tokenBroker, never()).verify(any());
verify(tokenBroker, never()).decode(any());
}

@Test
public void testAuthenticateWithInvalidTokenInternalAuth() {
public void testAuthenticateWithMalformedToken() {
when(configuration.type()).thenReturn(AuthenticationType.MIXED);
when(routingContext.request()).thenReturn(mock(HttpServerRequest.class));
when(routingContext.request().getHeader("Authorization")).thenReturn("Bearer malformedToken");

when(tokenBroker.decode("malformedToken"))
.thenReturn(
new TokenDecodeResult(TokenDecodeResult.Status.MALFORMED_TOKEN, Optional.empty()));

SecurityIdentity securityIdentity = mock(SecurityIdentity.class);
when(identityProviderManager.authenticate(any()))
.thenReturn(Uni.createFrom().item(securityIdentity));

Uni<SecurityIdentity> result = mechanism.authenticate(routingContext, identityProviderManager);

assertThatThrownBy(() -> result.await().indefinitely())
.isInstanceOf(AuthenticationFailedException.class)
.hasMessage("Malformed token");
verify(tokenBroker).decode("malformedToken");
verify(identityProviderManager, never()).authenticate(any(TokenAuthenticationRequest.class));
}

@ParameterizedTest
@EnumSource(
value = TokenDecodeResult.Status.class,
names = {"INVALID_ISSUER", "INVALID_REALM"})
public void testAuthenticateWithInvalidTokenInternalAuth(TokenDecodeResult.Status status) {
when(configuration.type()).thenReturn(AuthenticationType.INTERNAL);
when(routingContext.request()).thenReturn(mock(io.vertx.core.http.HttpServerRequest.class));
when(routingContext.request()).thenReturn(mock(HttpServerRequest.class));
when(routingContext.request().getHeader("Authorization")).thenReturn("Bearer invalidToken");

NotAuthorizedException cause = new NotAuthorizedException("Invalid token");
when(tokenBroker.verify("invalidToken")).thenThrow(cause);
when(tokenBroker.decode("invalidToken"))
.thenReturn(new TokenDecodeResult(status, Optional.empty()));

SecurityIdentity securityIdentity = mock(SecurityIdentity.class);
when(identityProviderManager.authenticate(any()))
Expand All @@ -129,19 +158,22 @@ public void testAuthenticateWithInvalidTokenInternalAuth() {

assertThatThrownBy(() -> result.await().indefinitely())
.isInstanceOf(AuthenticationFailedException.class)
.hasCause(cause);
verify(tokenBroker).verify("invalidToken");
.hasMessage(status.message());
verify(tokenBroker).decode("invalidToken");
verify(identityProviderManager, never()).authenticate(any(TokenAuthenticationRequest.class));
}

@Test
public void testAuthenticateWithInvalidTokenMixedAuth() {
@ParameterizedTest
@EnumSource(
value = TokenDecodeResult.Status.class,
names = {"INVALID_ISSUER", "INVALID_REALM"})
public void testAuthenticateWithInvalidTokenMixedAuth(TokenDecodeResult.Status status) {
when(configuration.type()).thenReturn(AuthenticationType.MIXED);
when(routingContext.request()).thenReturn(mock(io.vertx.core.http.HttpServerRequest.class));
when(routingContext.request()).thenReturn(mock(HttpServerRequest.class));
when(routingContext.request().getHeader("Authorization")).thenReturn("Bearer invalidToken");

NotAuthorizedException cause = new NotAuthorizedException("Invalid token");
when(tokenBroker.verify("invalidToken")).thenThrow(cause);
when(tokenBroker.decode("invalidToken"))
.thenReturn(new TokenDecodeResult(status, Optional.empty()));

SecurityIdentity securityIdentity = mock(SecurityIdentity.class);
when(identityProviderManager.authenticate(any()))
Expand All @@ -150,18 +182,20 @@ public void testAuthenticateWithInvalidTokenMixedAuth() {
Uni<SecurityIdentity> result = mechanism.authenticate(routingContext, identityProviderManager);

assertThat(result.await().indefinitely()).isNull();
verify(tokenBroker).verify("invalidToken");
verify(tokenBroker).decode("invalidToken");
verify(identityProviderManager, never()).authenticate(any(TokenAuthenticationRequest.class));
}

@Test
public void testAuthenticateWithValidToken() {
when(configuration.type()).thenReturn(AuthenticationType.INTERNAL);
when(routingContext.request()).thenReturn(mock(io.vertx.core.http.HttpServerRequest.class));
when(routingContext.request()).thenReturn(mock(HttpServerRequest.class));
when(routingContext.request().getHeader("Authorization")).thenReturn("Bearer validToken");

DecodedToken decodedToken = mock(DecodedToken.class);
when(tokenBroker.verify("validToken")).thenReturn(decodedToken);
when(tokenBroker.decode("validToken"))
.thenReturn(
new TokenDecodeResult(TokenDecodeResult.Status.SUCCESS, Optional.of(decodedToken)));

SecurityIdentity securityIdentity = mock(SecurityIdentity.class);
when(identityProviderManager.authenticate(any()))
Expand All @@ -170,7 +204,7 @@ public void testAuthenticateWithValidToken() {
Uni<SecurityIdentity> result = mechanism.authenticate(routingContext, identityProviderManager);

assertThat(result.await().indefinitely()).isSameAs(securityIdentity);
verify(tokenBroker).verify("validToken");
verify(tokenBroker).decode("validToken");
verify(identityProviderManager).authenticate(any(TokenAuthenticationRequest.class));
}
}
Loading
Loading