Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@

import com.google.auth.oauth2.AccessToken;
import com.google.auth.oauth2.GoogleCredentials;
import com.google.common.base.Suppliers;
import java.io.IOException;
import java.time.Duration;
import java.time.Instant;
Expand Down Expand Up @@ -60,38 +61,40 @@ public interface StorageConfiguration {
Optional<Duration> gcpAccessTokenLifespan();

default Supplier<StsClient> stsClientSupplier() {
return () -> {
StsClientBuilder stsClientBuilder = StsClient.builder();
if (awsAccessKey().isPresent() && awsSecretKey().isPresent()) {
LoggerFactory.getLogger(StorageConfiguration.class)
.warn("Using hard-coded AWS credentials - this is not recommended for production");
StaticCredentialsProvider awsCredentialsProvider =
StaticCredentialsProvider.create(
AwsBasicCredentials.create(awsAccessKey().get(), awsSecretKey().get()));
stsClientBuilder.credentialsProvider(awsCredentialsProvider);
}
return stsClientBuilder.build();
};
return Suppliers.memoize(
() -> {
StsClientBuilder stsClientBuilder = StsClient.builder();
if (awsAccessKey().isPresent() && awsSecretKey().isPresent()) {
LoggerFactory.getLogger(StorageConfiguration.class)
.warn("Using hard-coded AWS credentials - this is not recommended for production");
StaticCredentialsProvider awsCredentialsProvider =
StaticCredentialsProvider.create(
AwsBasicCredentials.create(awsAccessKey().get(), awsSecretKey().get()));
stsClientBuilder.credentialsProvider(awsCredentialsProvider);
}
return stsClientBuilder.build();
});
}

default Supplier<GoogleCredentials> gcpCredentialsSupplier() {
return () -> {
if (gcpAccessToken().isEmpty()) {
try {
return GoogleCredentials.getApplicationDefault();
} catch (IOException e) {
throw new RuntimeException("Failed to get GCP credentials", e);
}
} else {
AccessToken accessToken =
new AccessToken(
gcpAccessToken().get(),
new Date(
Instant.now()
.plus(gcpAccessTokenLifespan().orElse(DEFAULT_TOKEN_LIFESPAN))
.toEpochMilli()));
return GoogleCredentials.create(accessToken);
}
};
return Suppliers.memoize(
() -> {
if (gcpAccessToken().isEmpty()) {
try {
return GoogleCredentials.getApplicationDefault();
} catch (IOException e) {
throw new RuntimeException("Failed to get GCP credentials", e);
}
} else {
AccessToken accessToken =
new AccessToken(
gcpAccessToken().get(),
new Date(
Instant.now()
.plus(gcpAccessTokenLifespan().orElse(DEFAULT_TOKEN_LIFESPAN))
.toEpochMilli()));
return GoogleCredentials.create(accessToken);
}
});
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,154 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.polaris.service.storage;

import static org.assertj.core.api.Assertions.assertThat;
import static org.mockito.Mockito.*;

import com.google.auth.oauth2.AccessToken;
import com.google.auth.oauth2.GoogleCredentials;
import java.time.Duration;
import java.time.Instant;
import java.util.Optional;
import java.util.function.Supplier;
import org.junit.jupiter.api.Test;
import org.mockito.ArgumentCaptor;
import org.mockito.MockedStatic;
import org.mockito.Mockito;
import software.amazon.awssdk.auth.credentials.StaticCredentialsProvider;
import software.amazon.awssdk.services.sts.StsClient;
import software.amazon.awssdk.services.sts.StsClientBuilder;

public class StorageConfigurationTest {

private static final String TEST_ACCESS_KEY = "test-access-key";
private static final String TEST_GCP_TOKEN = "ya29.test-token";
private static final String TEST_SECRET_KEY = "test-secret-key";
private static final Duration TEST_TOKEN_LIFESPAN = Duration.ofMinutes(20);

private StorageConfiguration configWithAwsCredentialsAndGcpToken() {
return new StorageConfiguration() {
@Override
public Optional<String> awsAccessKey() {
return Optional.of(TEST_ACCESS_KEY);
}

@Override
public Optional<String> awsSecretKey() {
return Optional.of(TEST_SECRET_KEY);
}

@Override
public Optional<String> gcpAccessToken() {
return Optional.of(TEST_GCP_TOKEN);
}

@Override
public Optional<Duration> gcpAccessTokenLifespan() {
return Optional.of(TEST_TOKEN_LIFESPAN);
}
};
}

private StorageConfiguration configWithoutGcpToken() {
return new StorageConfiguration() {
@Override
public Optional<String> awsAccessKey() {
return Optional.empty();
}

@Override
public Optional<String> awsSecretKey() {
return Optional.empty();
}

@Override
public Optional<String> gcpAccessToken() {
return Optional.empty();
}

@Override
public Optional<Duration> gcpAccessTokenLifespan() {
return Optional.empty();
}
};
}

@Test
public void testSingletonStsClientWithStaticCredentials() {
StsClientBuilder mockBuilder = mock(StsClientBuilder.class);
StsClient mockStsClient = mock(StsClient.class);
ArgumentCaptor<StaticCredentialsProvider> credsCaptor =
ArgumentCaptor.forClass(StaticCredentialsProvider.class);

when(mockBuilder.credentialsProvider(credsCaptor.capture())).thenReturn(mockBuilder);
when(mockBuilder.region(any())).thenReturn(mockBuilder);
when(mockBuilder.build()).thenReturn(mockStsClient);

try (MockedStatic<StsClient> staticMock = Mockito.mockStatic(StsClient.class)) {
staticMock.when(StsClient::builder).thenReturn(mockBuilder);

StorageConfiguration config = configWithAwsCredentialsAndGcpToken();
Supplier<StsClient> supplier = config.stsClientSupplier();
StsClient client1 = supplier.get();
StsClient client2 = supplier.get();

assertThat(client1).isSameAs(client2);
assertThat(client1).isNotNull();

StaticCredentialsProvider credentialsProvider = credsCaptor.getValue();
assertThat(credentialsProvider.resolveCredentials().accessKeyId()).isEqualTo(TEST_ACCESS_KEY);
assertThat(credentialsProvider.resolveCredentials().secretAccessKey())
.isEqualTo(TEST_SECRET_KEY);
}
}

@Test
public void testCreateGcpCredentialsFromStaticToken() {
Supplier<GoogleCredentials> supplier =
configWithAwsCredentialsAndGcpToken().gcpCredentialsSupplier();

GoogleCredentials credentials = supplier.get();
assertThat(credentials).isNotNull();

AccessToken accessToken = credentials.getAccessToken();
assertThat(accessToken).isNotNull();
assertThat(accessToken.getTokenValue()).isEqualTo(TEST_GCP_TOKEN);
long expectedExpiry = Instant.now().plus(Duration.ofMinutes(20)).toEpochMilli();
long actualExpiry = accessToken.getExpirationTime().getTime();
assertThat(actualExpiry).isBetween(expectedExpiry - 500, expectedExpiry + 500);
}

@Test
public void testGcpCredentialsFromDefault() {
GoogleCredentials mockDefaultCreds = mock(GoogleCredentials.class);

try (MockedStatic<GoogleCredentials> mockedStatic =
Mockito.mockStatic(GoogleCredentials.class)) {

mockedStatic.when(GoogleCredentials::getApplicationDefault).thenReturn(mockDefaultCreds);

Supplier<GoogleCredentials> supplier = configWithoutGcpToken().gcpCredentialsSupplier();
GoogleCredentials result = supplier.get();

assertThat(result).isSameAs(mockDefaultCreds);
mockedStatic.verify(GoogleCredentials::getApplicationDefault, times(1));
}
}
}