Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 42 additions & 0 deletions src/aws-cpp-sdk-core/include/aws/core/auth/AWSCredentials.h
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,42 @@
#pragma once

#include <aws/core/Core_EXPORTS.h>
#include <aws/core/client/UserAgent.h>
#include <aws/core/utils/memory/stl/AWSString.h>
#include <aws/core/utils/DateTime.h>
namespace Aws
{
namespace Auth
{
/**
* Context class for credential resolution that tracks features used during credential retrieval.
*/
class AWS_CORE_API CredentialsResolutionContext
{
public:
// Default constructor - no features tracked
CredentialsResolutionContext() = default;

/**
* Add a user agent feature to track credential usage.
*/
void AddUserAgentFeature(Aws::Client::UserAgentFeature feature)
{
m_features.insert(feature);
}

/**
* Get all tracked credential features.
*/
const Aws::Set<Aws::Client::UserAgentFeature> GetUserAgentFeatures() const
{
return m_features;
}

private:
Aws::Set<Aws::Client::UserAgentFeature> m_features;
};

/**
* Simple data object around aws credentials
*/
Expand Down Expand Up @@ -214,12 +244,24 @@ namespace Aws
m_expiration = expiration;
}

/**
* Gets credential resolution context. this is information about the call
* such as what credentials provider was used to to resolve the credentials
*/
inline CredentialsResolutionContext GetContext() { return m_context; }

/**
* Adds a user agent feature used during credentials resolution to the credentials
* context. This is useful to track which credentials provider was used.
*/
inline void AddUserAgentFeature(Aws::Client::UserAgentFeature feature) { m_context.AddUserAgentFeature(feature); }
private:
Aws::String m_accessKeyId;
Aws::String m_secretKey;
Aws::String m_sessionToken;
Aws::Utils::DateTime m_expiration;
Aws::String m_accountId;
CredentialsResolutionContext m_context;
};
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ namespace Aws
{
class AWSCredentials;
class AWSCredentialsProvider;
class CredentialsResolutionContext;

enum class AWSSigningAlgorithm
{
Expand Down Expand Up @@ -191,6 +192,7 @@ namespace Aws

protected:
virtual bool ServiceRequireUnsignedPayload(const Aws::String& serviceName) const;
void UpdateUserAgentWithCredentialFeatures(Aws::Http::HttpRequest& request, const Aws::Auth::CredentialsResolutionContext& context) const;
bool m_includeSha256HashHeader;

private:
Expand Down
2 changes: 2 additions & 0 deletions src/aws-cpp-sdk-core/include/aws/core/client/UserAgent.h
Original file line number Diff line number Diff line change
Expand Up @@ -32,10 +32,12 @@ enum class UserAgentFeature {
ACCOUNT_ID_MODE_REQUIRED,
RESOLVED_ACCOUNT_ID,
GZIP_REQUEST_COMPRESSION,
CREDENTIALS_ENV_VARS,
};

class AWS_CORE_API UserAgent {
public:
static Aws::String BusinessMetricForFeature(UserAgentFeature feature);
explicit UserAgent(const ClientConfiguration& clientConfiguration, const Aws::String& retryStrategyName, const Aws::String& apiName);
Aws::String SerializeWithFeatures(const Aws::Set<UserAgentFeature>& features) const;
void SetApiName(const Aws::String& apiName) { m_api = apiName; }
Expand Down
5 changes: 5 additions & 0 deletions src/aws-cpp-sdk-core/source/auth/AWSCredentialsProvider.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
#include <aws/core/client/AWSError.h>
#include <aws/core/utils/StringUtils.h>
#include <aws/core/utils/xml/XmlSerializer.h>
#include <aws/core/client/UserAgent.h>
#include <cstdlib>
#include <fstream>
#include <string.h>
Expand Down Expand Up @@ -103,6 +104,10 @@ AWSCredentials EnvironmentAWSCredentialsProvider::GetAWSCredentials()
}
}

if (!credentials.IsEmpty()) {
credentials.AddUserAgentFeature(UserAgentFeature::CREDENTIALS_ENV_VARS);
}

return credentials;
}

Expand Down
58 changes: 58 additions & 0 deletions src/aws-cpp-sdk-core/source/auth/signer/AWSAuthV4Signer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
#include <aws/core/auth/signer/AWSAuthSignerHelper.h>

#include <aws/core/auth/AWSCredentialsProvider.h>
#include <aws/core/client/UserAgent.h>
#include <aws/core/http/HttpRequest.h>
#include <aws/core/http/URI.h>
#include <aws/core/utils/DateTime.h>
Expand All @@ -25,6 +26,7 @@

#include <iomanip>
#include <cstring>
#include <numeric>

using namespace Aws;
using namespace Aws::Client;
Expand Down Expand Up @@ -81,6 +83,8 @@ bool AWSAuthV4Signer::SignRequestWithSigV4a(Aws::Http::HttpRequest& request, con
bool signBody, long long expirationTimeInSeconds, Aws::Crt::Auth::SignatureType signatureType) const
{
AWSCredentials credentials = GetCredentials(request.GetServiceSpecificParameters());

UpdateUserAgentWithCredentialFeatures(request, credentials.GetContext());
auto crtCredentials = Aws::MakeShared<Aws::Crt::Auth::Credentials>(v4AsymmetricLogTag,
Aws::Crt::ByteCursorFromCString(credentials.GetAWSAccessKeyId().c_str()),
Aws::Crt::ByteCursorFromCString(credentials.GetAWSSecretKey().c_str()),
Expand Down Expand Up @@ -336,6 +340,9 @@ bool AWSAuthV4Signer::SignRequestWithCreds(Aws::Http::HttpRequest& request, cons
bool AWSAuthV4Signer::SignRequest(Aws::Http::HttpRequest& request, const char* region, const char* serviceName, bool signBody) const
{
AWSCredentials credentials = GetCredentials(request.GetServiceSpecificParameters());

UpdateUserAgentWithCredentialFeatures(request, credentials.GetContext());

return SignRequestWithCreds(request, credentials, region, serviceName, signBody);
}

Expand Down Expand Up @@ -464,6 +471,9 @@ bool AWSAuthV4Signer::PresignRequest(Aws::Http::HttpRequest& request, const Aws:
bool AWSAuthV4Signer::PresignRequest(Aws::Http::HttpRequest& request, const char* region, const char* serviceName, long long expirationTimeInSeconds) const
{
AWSCredentials credentials = GetCredentials(request.GetServiceSpecificParameters());

UpdateUserAgentWithCredentialFeatures(request, credentials.GetContext());

return PresignRequest(request, credentials, region,serviceName, expirationTimeInSeconds );
}

Expand Down Expand Up @@ -595,3 +605,51 @@ Aws::Auth::AWSCredentials AWSAuthV4Signer::GetCredentials(const std::shared_ptr<
AWS_UNREFERENCED_PARAM(serviceSpecificParameters);
return m_credentialsProvider->GetAWSCredentials();
}

void AWSAuthV4Signer::UpdateUserAgentWithCredentialFeatures(Aws::Http::HttpRequest& request, const Aws::Auth::CredentialsResolutionContext& context) const {
if (!request.HasHeader(USER_AGENT)) {
AWS_LOGSTREAM_DEBUG(v4LogTag, "Request does not have User-Agent header, skipping credential feature update");
return;
}

const auto features = context.GetUserAgentFeatures();
if (features.empty()) {
AWS_LOGSTREAM_DEBUG(v4LogTag, "No credential features to add to User-Agent");
return;
}

std::vector<Aws::String> businessMetrics(features.size());
std::transform(features.begin(),
features.end(),
businessMetrics.begin(),
[](UserAgentFeature feature) -> Aws::String { return UserAgent::BusinessMetricForFeature(feature); });

const auto credentialFeatures = std::accumulate(std::next(businessMetrics.begin()),
businessMetrics.end(),
businessMetrics.front(),
[](const Aws::String& a, const Aws::String& b) {
return a + "," + b;
});

const auto userAgent = request.GetHeaderValue(USER_AGENT);
auto userAgentParsed = Aws::Utils::StringUtils::Split(userAgent, ' ');
auto metricsSegment = std::find_if(userAgentParsed.begin(), userAgentParsed.end(),
[](const Aws::String& value) { return value.find("m/") != Aws::String::npos; });

if (metricsSegment != userAgentParsed.end()) {
// Add new metrics to existing metrics section
*metricsSegment = Aws::String{*metricsSegment + "," + credentialFeatures};
} else {
// No metrics section exists, add new one
userAgentParsed.push_back("m/" + credentialFeatures);
}

// Reassemble all parts with spaces
const auto newUserAgent = std::accumulate(std::next(userAgentParsed.begin()),
userAgentParsed.end(),
userAgentParsed.front(),
[](const Aws::String& a, const Aws::String& b) {
return a + " " + b;
});
request.SetUserAgent(newUserAgent);
}
23 changes: 12 additions & 11 deletions src/aws-cpp-sdk-core/source/client/UserAgent.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -42,19 +42,9 @@ const std::pair<UserAgentFeature, const char*> BUSINESS_METRIC_MAPPING[] = {
{UserAgentFeature::ACCOUNT_ID_MODE_REQUIRED, "R"},
{UserAgentFeature::RESOLVED_ACCOUNT_ID, "T"},
{UserAgentFeature::GZIP_REQUEST_COMPRESSION, "L"},
{UserAgentFeature::CREDENTIALS_ENV_VARS, "g"},
};

Aws::String BusinessMetricForFeature(UserAgentFeature feature) {
const auto* const metric =
std::find_if(std::begin(BUSINESS_METRIC_MAPPING), std::end(BUSINESS_METRIC_MAPPING),
[feature](const std::pair<UserAgentFeature, const char*>& pair) -> bool { return pair.first == feature; });
if (metric == std::end(BUSINESS_METRIC_MAPPING)) {
AWS_LOGSTREAM_ERROR(LOG_TAG, "business metric mapping not found for feature");
return {};
}
return metric->second;
}

const std::pair<const char*, UserAgentFeature> RETRY_FEATURE_MAPPING[] = {
{"default", UserAgentFeature::RETRY_MODE_LEGACY},
{"standard", UserAgentFeature::RETRY_MODE_STANDARD},
Expand Down Expand Up @@ -96,6 +86,17 @@ const char* APP_ID = "app";
const char* BUSINESS_METRICS = "m";
} // namespace

Aws::String UserAgent::BusinessMetricForFeature(UserAgentFeature feature) {
const auto* const metric =
std::find_if(std::begin(BUSINESS_METRIC_MAPPING), std::end(BUSINESS_METRIC_MAPPING),
[feature](const std::pair<UserAgentFeature, const char*>& pair) -> bool { return pair.first == feature; });
if (metric == std::end(BUSINESS_METRIC_MAPPING)) {
AWS_LOGSTREAM_ERROR(LOG_TAG, "business metric mapping not found for feature");
return {};
}
return metric->second;
}

UserAgent::UserAgent(const ClientConfiguration& clientConfiguration,
const Aws::String& retryStrategyName,
const Aws::String& apiName)
Expand Down
129 changes: 129 additions & 0 deletions tests/aws-cpp-sdk-core-tests/aws/auth/CredentialTrackingTest.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,129 @@
/**
* Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
* SPDX-License-Identifier: Apache-2.0.
*/

#include <aws/testing/AwsCppSdkGTestSuite.h>
#include <aws/testing/AwsTestHelpers.h>
#include <aws/testing/mocks/aws/client/MockAWSClient.h>
#include <aws/testing/mocks/http/MockHttpClient.h>
#include <aws/testing/platform/PlatformTesting.h>
#include <aws/core/auth/AWSCredentialsProvider.h>
#include <aws/core/auth/AWSCredentialsProviderChain.h>
#include <aws/core/client/AWSClient.h>
#include <aws/core/utils/StringUtils.h>

using namespace Aws::Client;
using namespace Aws::Auth;
using namespace Aws::Http;

static const char ALLOCATION_TAG[] = "CredentialTrackingTest";

// Custom client that uses default credential provider for testing
class CredentialTestingClient : public Aws::Client::AWSClient
{
public:
explicit CredentialTestingClient(const Aws::Client::ClientConfiguration& configuration)
: AWSClient(configuration,
Aws::MakeShared<Aws::Client::AWSAuthV4Signer>(ALLOCATION_TAG,
Aws::MakeShared<DefaultAWSCredentialsProviderChain>(ALLOCATION_TAG),
"service", configuration.region),
Aws::MakeShared<MockAWSErrorMarshaller>(ALLOCATION_TAG))
{
}

Aws::Client::HttpResponseOutcome MakeRequest(const Aws::AmazonWebServiceRequest& request)
{
auto uri = Aws::Http::URI("https://test.com");
return AWSClient::AttemptExhaustively(uri, request, Aws::Http::HttpMethod::HTTP_POST, Aws::Auth::SIGV4_SIGNER);
}

const char* GetServiceClientName() const override { return "CredentialTestingClient"; }

protected:
Aws::Client::AWSError<Aws::Client::CoreErrors> BuildAWSError(const std::shared_ptr<Aws::Http::HttpResponse>& response) const override
{
AWS_UNREFERENCED_PARAM(response);
return Aws::Client::AWSError<Aws::Client::CoreErrors>(Aws::Client::CoreErrors::UNKNOWN, false);
}
};

class CredentialTrackingTest : public Aws::Testing::AwsCppSdkGTestSuite
{
protected:
std::shared_ptr<MockHttpClient> mockHttpClient;
std::shared_ptr<MockHttpClientFactory> mockHttpClientFactory;

void SetUp() override
{
mockHttpClient = Aws::MakeShared<MockHttpClient>(ALLOCATION_TAG);
mockHttpClientFactory = Aws::MakeShared<MockHttpClientFactory>(ALLOCATION_TAG);
mockHttpClientFactory->SetClient(mockHttpClient);
SetHttpClientFactory(mockHttpClientFactory);
}

void TearDown() override
{
mockHttpClient->Reset();
mockHttpClient = nullptr;
mockHttpClientFactory = nullptr;
Aws::Http::CleanupHttp();
Aws::Http::InitHttp();
}
};

TEST_F(CredentialTrackingTest, TestEnvironmentCredentialsTracking)
{
Aws::Environment::EnvironmentRAII testEnvironment{{
{"AWS_ACCESS_KEY_ID", "test-access-key"},
{"AWS_SECRET_ACCESS_KEY", "test-secret-key"},
}};

// Setup mock response
std::shared_ptr<HttpRequest> requestTmp =
CreateHttpRequest(Aws::Http::URI("dummy"), Aws::Http::HttpMethod::HTTP_POST,
Aws::Utils::Stream::DefaultResponseStreamFactoryMethod);
auto successResponse = Aws::MakeShared<Standard::StandardHttpResponse>(ALLOCATION_TAG, requestTmp);
successResponse->SetResponseCode(HttpResponseCode::OK);
successResponse->GetResponseBody() << "{}";
mockHttpClient->AddResponseToReturn(successResponse);

// Create client configuration
Aws::Client::ClientConfigurationInitValues cfgInit;
cfgInit.shouldDisableIMDS = true;
Aws::Client::ClientConfiguration clientConfig(cfgInit);
clientConfig.region = Aws::Region::US_EAST_1;

// Create credential testing client that uses default provider chain
CredentialTestingClient client(clientConfig);

// Create mock request
AmazonWebServiceRequestMock mockRequest;

// Make request
auto outcome = client.MakeRequest(mockRequest);
ASSERT_TRUE(outcome.IsSuccess());

// Verify User-Agent contains environment credentials tracking
auto lastRequest = mockHttpClient->GetMostRecentHttpRequest();
EXPECT_TRUE(lastRequest.HasHeader(Aws::Http::USER_AGENT_HEADER));
const auto& userAgent = lastRequest.GetHeaderValue(Aws::Http::USER_AGENT_HEADER);
EXPECT_FALSE(userAgent.empty());

const auto userAgentParsed = Aws::Utils::StringUtils::Split(userAgent, ' ');

// Verify there's only one m/ section (no duplicate m/ sections)
int mSectionCount = 0;
for (const auto& part : userAgentParsed) {
if (part.find("m/") != Aws::String::npos) {
mSectionCount++;
}
}
EXPECT_EQ(1, mSectionCount);

// Check for environment credentials business metric (g) in user agent
auto businessMetrics = std::find_if(userAgentParsed.begin(), userAgentParsed.end(),
[](const Aws::String& value) { return value.find("m/") != Aws::String::npos && value.find("g") != Aws::String::npos; });

EXPECT_TRUE(businessMetrics != userAgentParsed.end());
}
Loading