Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 13 additions & 8 deletions extension/httpfs/create_secret_functions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ void CreateS3SecretFunctions::Register(DatabaseInstance &instance) {
RegisterCreateSecretFunction(instance, "gcs");
}

static Value MapToStruct(const Value &map){
static Value MapToStruct(const Value &map) {
auto children = MapValue::GetChildren(map);

child_list_t<Value> struct_fields;
Expand Down Expand Up @@ -109,15 +109,17 @@ unique_ptr<BaseSecret> CreateS3SecretFunctions::CreateSecretFunctionInternal(Cli
refresh = true;
secret->secret_map["refresh_info"] = MapToStruct(named_param.second);
} else {
throw InvalidInputException("Unknown named parameter passed to CreateSecretFunctionInternal: " + lower_name);
throw InvalidInputException("Unknown named parameter passed to CreateSecretFunctionInternal: " +
lower_name);
}
}

return std::move(secret);
}

CreateSecretInput CreateS3SecretFunctions::GenerateRefreshSecretInfo(const SecretEntry &secret_entry, Value &refresh_info) {
const auto &kv_secret = dynamic_cast<const KeyValueSecret&>(*secret_entry.secret);
CreateSecretInput CreateS3SecretFunctions::GenerateRefreshSecretInfo(const SecretEntry &secret_entry,
Value &refresh_info) {
const auto &kv_secret = dynamic_cast<const KeyValueSecret &>(*secret_entry.secret);

CreateSecretInput result;
result.on_conflict = OnCreateConflict::REPLACE_ON_CONFLICT;
Expand All @@ -143,7 +145,7 @@ CreateSecretInput CreateS3SecretFunctions::GenerateRefreshSecretInfo(const Secre

//! Function that will automatically try to refresh a secret
bool CreateS3SecretFunctions::TryRefreshS3Secret(ClientContext &context, const SecretEntry &secret_to_refresh) {
const auto &kv_secret = dynamic_cast<const KeyValueSecret&>(*secret_to_refresh.secret);
const auto &kv_secret = dynamic_cast<const KeyValueSecret &>(*secret_to_refresh.secret);

Value refresh_info;
if (!kv_secret.TryGetValue("refresh_info", refresh_info)) {
Expand All @@ -155,12 +157,15 @@ bool CreateS3SecretFunctions::TryRefreshS3Secret(ClientContext &context, const S
// TODO: change SecretManager API to avoid requiring catching this exception
try {
auto res = secret_manager.CreateSecret(context, refresh_input);
auto &new_secret = dynamic_cast<const KeyValueSecret&>(*res->secret);
DUCKDB_LOG_INFO(context, "httpfs.SecretRefresh", "Successfully refreshed secret: %s, new key_id: %s", secret_to_refresh.secret->GetName(), new_secret.TryGetValue("key_id").ToString());
auto &new_secret = dynamic_cast<const KeyValueSecret &>(*res->secret);
DUCKDB_LOG_INFO(context, "httpfs.SecretRefresh", "Successfully refreshed secret: %s, new key_id: %s",
secret_to_refresh.secret->GetName(), new_secret.TryGetValue("key_id").ToString());
return true;
} catch (std::exception &ex) {
ErrorData error(ex);
string new_message = StringUtil::Format("Exception thrown while trying to refresh secret %s. To fix this, please recreate or remove the secret and try again. Error: '%s'", secret_to_refresh.secret->GetName(), error.Message());
string new_message = StringUtil::Format("Exception thrown while trying to refresh secret %s. To fix this, "
"please recreate or remove the secret and try again. Error: '%s'",
secret_to_refresh.secret->GetName(), error.Message());
throw Exception(error.Type(), new_message);
}
}
Expand Down
189 changes: 92 additions & 97 deletions extension/httpfs/crypto.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -44,34 +44,33 @@ AESStateSSL::~AESStateSSL() {

const EVP_CIPHER *AESStateSSL::GetCipher(const string &key) {

switch (cipher) {
case GCM:
switch (key.size()) {
case 16:
return EVP_aes_128_gcm();
case 24:
return EVP_aes_192_gcm();
case 32:
return EVP_aes_256_gcm();
default:
throw InternalException("Invalid AES key length");
}
case CTR:
switch (key.size()) {
case 16:
return EVP_aes_128_ctr();
case 24:
return EVP_aes_192_ctr();
case 32:
return EVP_aes_256_ctr();
default:
throw InternalException("Invalid AES key length");
}

default:
throw duckdb::InternalException("Invalid Encryption/Decryption Cipher: %d",
static_cast<int>(cipher));
}
switch (cipher) {
case GCM:
switch (key.size()) {
case 16:
return EVP_aes_128_gcm();
case 24:
return EVP_aes_192_gcm();
case 32:
return EVP_aes_256_gcm();
default:
throw InternalException("Invalid AES key length");
}
case CTR:
switch (key.size()) {
case 16:
return EVP_aes_128_ctr();
case 24:
return EVP_aes_192_ctr();
case 32:
return EVP_aes_256_ctr();
default:
throw InternalException("Invalid AES key length");
}

default:
throw duckdb::InternalException("Invalid Encryption/Decryption Cipher: %d", static_cast<int>(cipher));
}
}

void AESStateSSL::GenerateRandomData(data_ptr_t data, idx_t len) {
Expand Down Expand Up @@ -121,79 +120,75 @@ size_t AESStateSSL::Process(const_data_ptr_t in, idx_t in_len, data_ptr_t out, i
return out_len;
}

size_t AESStateSSL::FinalizeGCM(data_ptr_t out, idx_t out_len, data_ptr_t tag, idx_t tag_len){
auto text_len = out_len;

switch (mode) {
case ENCRYPT:
{
if (1 != EVP_EncryptFinal_ex(context, data_ptr_cast(out) + out_len, reinterpret_cast<int *>(&out_len))) {
throw InternalException("EncryptFinal failed");
}
text_len += out_len;

// The computed tag is written at the end of a chunk
if (1 != EVP_CIPHER_CTX_ctrl(context, EVP_CTRL_GCM_GET_TAG, tag_len, tag)) {
throw InternalException("Calculating the tag failed");
}
return text_len;
}
case DECRYPT:
{
// Set expected tag value
if (!EVP_CIPHER_CTX_ctrl(context, EVP_CTRL_GCM_SET_TAG, tag_len, tag)) {
throw InternalException("Finalizing tag failed");
}

// EVP_DecryptFinal() will return an error code if final block is not correctly formatted.
int ret = EVP_DecryptFinal_ex(context, data_ptr_cast(out) + out_len, reinterpret_cast<int *>(&out_len));
text_len += out_len;

if (ret > 0) {
// success
return text_len;
}
throw InvalidInputException("Computed AES tag differs from read AES tag, are you using the right key?");
}
default:
throw InternalException("Unhandled encryption mode %d", static_cast<int>(mode));
}
size_t AESStateSSL::FinalizeGCM(data_ptr_t out, idx_t out_len, data_ptr_t tag, idx_t tag_len) {
auto text_len = out_len;

switch (mode) {
case ENCRYPT: {
if (1 != EVP_EncryptFinal_ex(context, data_ptr_cast(out) + out_len, reinterpret_cast<int *>(&out_len))) {
throw InternalException("EncryptFinal failed");
}
text_len += out_len;

// The computed tag is written at the end of a chunk
if (1 != EVP_CIPHER_CTX_ctrl(context, EVP_CTRL_GCM_GET_TAG, tag_len, tag)) {
throw InternalException("Calculating the tag failed");
}
return text_len;
}
case DECRYPT: {
// Set expected tag value
if (!EVP_CIPHER_CTX_ctrl(context, EVP_CTRL_GCM_SET_TAG, tag_len, tag)) {
throw InternalException("Finalizing tag failed");
}

// EVP_DecryptFinal() will return an error code if final block is not correctly formatted.
int ret = EVP_DecryptFinal_ex(context, data_ptr_cast(out) + out_len, reinterpret_cast<int *>(&out_len));
text_len += out_len;

if (ret > 0) {
// success
return text_len;
}
throw InvalidInputException("Computed AES tag differs from read AES tag, are you using the right key?");
}
default:
throw InternalException("Unhandled encryption mode %d", static_cast<int>(mode));
}
}

size_t AESStateSSL::Finalize(data_ptr_t out, idx_t out_len, data_ptr_t tag, idx_t tag_len) {

if (cipher == GCM){
return FinalizeGCM(out, out_len, tag, tag_len);
}

auto text_len = out_len;
switch (mode) {

case ENCRYPT:
{
if (1 != EVP_EncryptFinal_ex(context, data_ptr_cast(out) + out_len, reinterpret_cast<int *>(&out_len))) {
throw InternalException("EncryptFinal failed");
}

return text_len += out_len;
}

case DECRYPT:
{
// EVP_DecryptFinal() will return an error code if final block is not correctly formatted.
int ret = EVP_DecryptFinal_ex(context, data_ptr_cast(out) + out_len, reinterpret_cast<int *>(&out_len));
text_len += out_len;

if (ret > 0) {
// success
return text_len;
}

throw InvalidInputException("Computed AES tag differs from read AES tag, are you using the right key?");
}
default:
throw InternalException("Unhandled encryption mode %d", static_cast<int>(mode));
}
if (cipher == GCM) {
return FinalizeGCM(out, out_len, tag, tag_len);
}

auto text_len = out_len;
switch (mode) {

case ENCRYPT: {
if (1 != EVP_EncryptFinal_ex(context, data_ptr_cast(out) + out_len, reinterpret_cast<int *>(&out_len))) {
throw InternalException("EncryptFinal failed");
}

return text_len += out_len;
}

case DECRYPT: {
// EVP_DecryptFinal() will return an error code if final block is not correctly formatted.
int ret = EVP_DecryptFinal_ex(context, data_ptr_cast(out) + out_len, reinterpret_cast<int *>(&out_len));
text_len += out_len;

if (ret > 0) {
// success
return text_len;
}

throw InvalidInputException("Computed AES tag differs from read AES tag, are you using the right key?");
}
default:
throw InternalException("Unhandled encryption mode %d", static_cast<int>(mode));
}
}

} // namespace duckdb
Expand Down
36 changes: 18 additions & 18 deletions extension/httpfs/hffs.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -54,20 +54,21 @@ string HuggingFaceFileSystem::ListHFRequest(ParsedHFUrl &url, HTTPFSParams &http
string link_header_result;

std::stringstream response;
GetRequestInfo get_request(url.endpoint, next_page_url, header_map, http_params,
[&](const HTTPResponse &response) {
if (static_cast<int>(response.status) >= 400) {
throw HTTPException(response, "HTTP GET error on '%s' (HTTP %d)", next_page_url, response.status);
}
if (response.HasHeader("Link")) {
link_header_result = response.GetHeaderValue("Link");
}
return true;
},
[&](const_data_ptr_t data, idx_t data_length) {
response << string(const_char_ptr_cast(data), data_length);
return true;
});
GetRequestInfo get_request(
url.endpoint, next_page_url, header_map, http_params,
[&](const HTTPResponse &response) {
if (static_cast<int>(response.status) >= 400) {
throw HTTPException(response, "HTTP GET error on '%s' (HTTP %d)", next_page_url, response.status);
}
if (response.HasHeader("Link")) {
link_header_result = response.GetHeaderValue("Link");
}
return true;
},
[&](const_data_ptr_t data, idx_t data_length) {
response << string(const_char_ptr_cast(data), data_length);
return true;
});
auto res = http_params.http_util->Request(get_request);
if (res->status != HTTPStatusCode::OK_200) {
throw IOException(res->GetError() + " error for HTTP GET to '" + next_page_url + "'");
Expand Down Expand Up @@ -248,8 +249,7 @@ vector<OpenFileInfo> HuggingFaceFileSystem::Glob(const string &path, FileOpener
return result;
}

unique_ptr<HTTPResponse> HuggingFaceFileSystem::HeadRequest(FileHandle &handle, string hf_url,
HTTPHeaders header_map) {
unique_ptr<HTTPResponse> HuggingFaceFileSystem::HeadRequest(FileHandle &handle, string hf_url, HTTPHeaders header_map) {
auto &hf_handle = handle.Cast<HFFileHandle>();
auto http_url = HuggingFaceFileSystem::GetFileUrl(hf_handle.parsed_url);
return HTTPFileSystem::HeadRequest(handle, http_url, header_map);
Expand All @@ -262,8 +262,8 @@ unique_ptr<HTTPResponse> HuggingFaceFileSystem::GetRequest(FileHandle &handle, s
}

unique_ptr<HTTPResponse> HuggingFaceFileSystem::GetRangeRequest(FileHandle &handle, string s3_url,
HTTPHeaders header_map, idx_t file_offset,
char *buffer_out, idx_t buffer_out_len) {
HTTPHeaders header_map, idx_t file_offset,
char *buffer_out, idx_t buffer_out_len) {
auto &hf_handle = handle.Cast<HFFileHandle>();
auto http_url = HuggingFaceFileSystem::GetFileUrl(hf_handle.parsed_url);
return HTTPFileSystem::GetRangeRequest(handle, http_url, header_map, file_offset, buffer_out, buffer_out_len);
Expand Down
Loading
Loading