From bbbdb0bd3ea3fb00d507026e2806fba32a0455f2 Mon Sep 17 00:00:00 2001 From: Simon Unge Date: Tue, 1 Jul 2025 17:19:24 +0000 Subject: [PATCH 01/11] wip --- deps/rabbitmq_aws/include/rabbitmq_aws.hrl | 4 +- deps/rabbitmq_aws/src/rabbitmq_aws.erl | 259 ++++++++++++++++-- deps/rabbitmq_aws/src/rabbitmq_aws_config.erl | 149 ++++++++-- deps/rabbitmq_aws/src/rabbitmq_aws_xml.erl | 2 + .../test/rabbitmq_aws_config_tests.erl | 182 +++++++----- deps/rabbitmq_aws/test/rabbitmq_aws_tests.erl | 214 +++++++++------ 6 files changed, 606 insertions(+), 204 deletions(-) diff --git a/deps/rabbitmq_aws/include/rabbitmq_aws.hrl b/deps/rabbitmq_aws/include/rabbitmq_aws.hrl index 6a0cacd81131..2106b3a372df 100644 --- a/deps/rabbitmq_aws/include/rabbitmq_aws.hrl +++ b/deps/rabbitmq_aws/include/rabbitmq_aws.hrl @@ -71,7 +71,9 @@ security_token :: security_token() | undefined, region :: region() | undefined, imdsv2_token :: imdsv2token() | undefined, - error :: atom() | string() | undefined + error :: atom() | string() | undefined, + % host -> gun_pid mapping + gun_connections = #{} :: #{string() => pid()} }). -type state() :: #state{}. diff --git a/deps/rabbitmq_aws/src/rabbitmq_aws.erl b/deps/rabbitmq_aws/src/rabbitmq_aws.erl index e0c85ec55372..d54f890e0af0 100644 --- a/deps/rabbitmq_aws/src/rabbitmq_aws.erl +++ b/deps/rabbitmq_aws/src/rabbitmq_aws.erl @@ -10,15 +10,19 @@ %% API exports -export([ - get/2, get/3, + get/2, get/3, get/4, + put/4, put/5, post/4, refresh_credentials/0, request/5, request/6, request/7, set_credentials/2, has_credentials/0, + parse_uri/1, set_region/1, ensure_imdsv2_token_valid/0, - api_get_request/2 + api_get_request/2, + close_connection/3, + status_text/1 ]). %% gen-server exports @@ -65,7 +69,10 @@ get(Service, Path) -> %% format. %% @end get(Service, Path, Headers) -> - request(Service, get, Path, "", Headers). + request(Service, get, Path, "", Headers, []). + +get(Service, Path, Headers, Options) -> + request(Service, get, Path, "", Headers, Options). -spec post( Service :: string(), @@ -80,12 +87,31 @@ get(Service, Path, Headers) -> post(Service, Path, Body, Headers) -> request(Service, post, Path, Body, Headers). +-spec put( + Service :: string(), + Path :: path(), + Body :: body(), + Headers :: headers() +) -> result(). +%% @doc Perform a HTTP Post request to the AWS API for the specified service. The +%% response will automatically be decoded if it is either in JSON or XML +%% format. +%% @end +put(Service, Path, Body, Headers) -> + put(Service, Path, Body, Headers, []). + +put(Service, Path, Body, Headers, Options) -> + request(Service, put, Path, Body, Headers, Options). + -spec refresh_credentials() -> ok | error. %% @doc Manually refresh the credentials from the environment, filesystem or EC2 Instance Metadata Service. %% @end refresh_credentials() -> gen_server:call(rabbitmq_aws, refresh_credentials). +close_connection(Service, Path, Options) -> + gen_server:cast(?MODULE, {close_connection, Service, Path, Options}). + -spec refresh_credentials(state()) -> ok | error. %% @doc Manually refresh the credentials from the environment, filesystem or EC2 Instance Metadata Service. %% @end @@ -186,9 +212,18 @@ start_link() -> -spec init(list()) -> {ok, state()}. init([]) -> + {ok, _} = application:ensure_all_started(gun), {ok, #state{}}. -terminate(_, _) -> +terminate(_, State) -> + %% Close all Gun connections + maps:fold( + fun(_Host, ConnPid, _Acc) -> + gun:close(ConnPid) + end, + ok, + State#state.gun_connections + ), ok. code_change(_, _, State) -> @@ -197,6 +232,8 @@ code_change(_, _, State) -> handle_call(Msg, _From, State) -> handle_msg(Msg, State). +handle_cast({close_connection, Service, Path, Options}, State) -> + {noreply, close_connection(Service, Path, Options, State)}; handle_cast(_Request, State) -> {noreply, State}. @@ -225,12 +262,23 @@ handle_msg({set_credentials, AccessKey, SecretAccessKey}, State) -> error = undefined }}; handle_msg({set_credentials, NewState}, State) -> + spawn(fun() -> + maps:fold( + fun(_Host, ConnPid, _Acc) -> + gun:close(ConnPid) + end, + ok, + State#state.gun_connections + ) + end), {reply, ok, State#state{ access_key = NewState#state.access_key, secret_access_key = NewState#state.secret_access_key, security_token = NewState#state.security_token, expiration = NewState#state.expiration, - error = NewState#state.error + error = NewState#state.error, + % Potentially new credentials, so clear the connection pool? + gun_connections = #{} }}; handle_msg({set_region, Region}, State) -> {reply, ok, State#state{region = Region}}; @@ -282,6 +330,8 @@ endpoint_tld(_Other) -> %% @end format_response({ok, {{_Version, 200, _Message}, Headers, Body}}) -> {ok, {Headers, maybe_decode_body(get_content_type(Headers), Body)}}; +format_response({ok, {{_Version, 206, _Message}, Headers, Body}}) -> + {ok, {Headers, maybe_decode_body(get_content_type(Headers), Body)}}; format_response({ok, {{_Version, StatusCode, Message}, Headers, Body}}) when StatusCode >= 400 -> {error, Message, {Headers, maybe_decode_body(get_content_type(Headers), Body)}}; format_response({error, Reason}) -> @@ -293,9 +343,9 @@ format_response({error, Reason}) -> %% @end get_content_type(Headers) -> Value = - case proplists:get_value("content-type", Headers, undefined) of + case proplists:get_value(<<"content-type">>, Headers, undefined) of undefined -> - proplists:get_value("Content-Type", Headers, "text/xml"); + proplists:get_value(<<"Content-Type">>, Headers, "text/xml"); Other -> Other end, @@ -329,7 +379,7 @@ expired_credentials(Expiration) -> %% - Credentials file %% - EC2 Instance Metadata Service %% @end -load_credentials(#state{region = Region}) -> +load_credentials(#state{region = Region, gun_connections = GunConnections}) -> case rabbitmq_aws_config:credentials() of {ok, AccessKey, SecretAccessKey, Expiration, SecurityToken} -> {ok, #state{ @@ -339,7 +389,8 @@ load_credentials(#state{region = Region}) -> secret_access_key = SecretAccessKey, expiration = Expiration, security_token = SecurityToken, - imdsv2_token = undefined + imdsv2_token = undefined, + gun_connections = GunConnections }}; {error, Reason} -> ?LOG_ERROR( @@ -353,7 +404,8 @@ load_credentials(#state{region = Region}) -> secret_access_key = undefined, expiration = undefined, security_token = undefined, - imdsv2_token = undefined + imdsv2_token = undefined, + gun_connections = GunConnections }} end. @@ -368,6 +420,8 @@ local_time() -> list() | body(). %% @doc Attempt to decode the response body by its MIME %% @end +maybe_decode_body(_, <<>>) -> + <<>>; maybe_decode_body({"application", "x-amz-json-1.0"}, Body) -> rabbitmq_aws_json:decode(Body); maybe_decode_body({"application", "json"}, Body) -> @@ -380,6 +434,8 @@ maybe_decode_body(_ContentType, Body) -> -spec parse_content_type(ContentType :: string()) -> {Type :: string(), Subtype :: string()}. %% @doc parse a content type string returning a tuple of type/subtype %% @end +parse_content_type(ContentType) when is_binary(ContentType) -> + parse_content_type(binary_to_list(ContentType)); parse_content_type(ContentType) -> Parts = string:tokens(ContentType, ";"), [Type, Subtype] = string:tokens(lists:nth(1, Parts), "/"), @@ -480,15 +536,13 @@ perform_request_creds_expired(true, State, _, _, _, _, _, _, _) -> perform_request_with_creds(State, Service, Method, Headers, Path, Body, Options, Host) -> URI = endpoint(State, Host, Service, Path), SignedHeaders = sign_headers(State, Service, Method, URI, Headers, Body), - ContentType = proplists:get_value("content-type", SignedHeaders, undefined), - perform_request_with_creds(State, Method, URI, SignedHeaders, ContentType, Body, Options). + perform_request_with_creds(State, Method, URI, SignedHeaders, Body, Options). -spec perform_request_with_creds( State :: state(), Method :: method(), URI :: string(), Headers :: headers(), - ContentType :: string() | undefined, Body :: body(), Options :: http_options() ) -> @@ -496,14 +550,12 @@ perform_request_with_creds(State, Service, Method, Headers, Path, Body, Options, %% @doc Once it is validated that there are credentials to try and that they have not %% expired, perform the request and return the response. %% @end -perform_request_with_creds(State, Method, URI, Headers, undefined, "", Options0) -> - Options1 = ensure_timeout(Options0), - Response = httpc:request(Method, {URI, Headers}, Options1, []), - {format_response(Response), State}; -perform_request_with_creds(State, Method, URI, Headers, ContentType, Body, Options0) -> - Options1 = ensure_timeout(Options0), - Response = httpc:request(Method, {URI, Headers, ContentType, Body}, Options1, []), - {format_response(Response), State}. +perform_request_with_creds(State, Method, URI, Headers, "", Options0) -> + {Response, NewState} = gun_request(State, Method, URI, Headers, <<>>, Options0), + {format_response(Response), NewState}; +perform_request_with_creds(State, Method, URI, Headers, Body, Options0) -> + {Response, NewState} = gun_request(State, Method, URI, Headers, Body, Options0), + {format_response(Response), NewState}. -spec perform_request_creds_error(State :: state()) -> {result_error(), NewState :: state()}. @@ -648,3 +700,168 @@ api_get_request_with_retries(Service, Path, Retries, WaitTimeBetweenRetries) -> timer:sleep(WaitTimeBetweenRetries), api_get_request_with_retries(Service, Path, Retries - 1, WaitTimeBetweenRetries) end. + +%% Gun HTTP client functions +gun_request(State, Method, URI, Headers, Body, Options) -> + HeadersBin = lists:map( + fun({Key, Value}) -> + {list_to_binary(Key), list_to_binary(Value)} + end, + Headers + ), + {Host, Port, Path} = parse_uri(URI), + {ConnPid, NewState} = get_or_create_gun_connection(State, Host, Port, Path, Options), + Timeout = proplists:get_value(timeout, Options, ?DEFAULT_HTTP_TIMEOUT), + try + StreamRef = do_gun_request(ConnPid, Method, Path, HeadersBin, Body), + case gun:await(ConnPid, StreamRef, Timeout) of + {response, fin, Status, RespHeaders} -> + Response = {ok, {{http_version, Status, status_text(Status)}, RespHeaders, <<>>}}, + {Response, NewState}; + {response, nofin, Status, RespHeaders} -> + {ok, RespBody} = gun:await_body(ConnPid, StreamRef, Timeout), + Response = + {ok, {{http_version, Status, status_text(Status)}, RespHeaders, RespBody}}, + {Response, NewState}; + {error, Reason} -> + {{error, Reason}, NewState} + end + catch + _:Error -> + % Connection failed, remove from pool and return error + HostKey = get_connection_key(Host, Port, Path, Options), + NewConnections = maps:remove(HostKey, NewState#state.gun_connections), + gun:close(ConnPid), + {{error, Error}, NewState#state{gun_connections = NewConnections}} + end. + +do_gun_request(ConnPid, get, Path, Headers, _Body) -> + gun:get(ConnPid, Path, Headers); +do_gun_request(ConnPid, post, Path, Headers, Body) -> + gun:post(ConnPid, Path, Headers, Body, #{}); +do_gun_request(ConnPid, put, Path, Headers, Body) -> + gun:put(ConnPid, Path, Headers, Body, #{}); +do_gun_request(ConnPid, head, Path, Headers, _Body) -> + gun:head(ConnPid, Path, Headers, #{}); +do_gun_request(ConnPid, delete, Path, Headers, _Body) -> + gun:delete(ConnPid, Path, Headers, #{}); +do_gun_request(ConnPid, patch, Path, Headers, Body) -> + gun:patch(ConnPid, Path, Headers, Body, #{}); +do_gun_request(ConnPid, options, Path, Headers, _Body) -> + gun:options(ConnPid, Path, Headers, #{}). + +get_or_create_gun_connection(State, Host, Port, Path, Options) -> + HostKey = get_connection_key(Host, Port, Path, Options), + case maps:get(HostKey, State#state.gun_connections, undefined) of + undefined -> + create_gun_connection(State, Host, Port, HostKey, Options); + ConnPid -> + case is_process_alive(ConnPid) andalso gun:info(ConnPid) =/= undefined of + true -> + {ConnPid, State}; + false -> + % Connection is dead, create new one + gun:close(ConnPid), + create_gun_connection(State, Host, Port, HostKey, Options) + end + end. + +get_connection_key(Host, Port, Path, Options) -> + case proplists:get_value(connection_key_type, Options, host) of + host -> + Host ++ ":" ++ integer_to_list(Port); + path -> + Host ++ ":" ++ integer_to_list(Port) ++ Path; + {path_custom, Extra} -> + Host ++ ":" ++ integer_to_list(Port) ++ Path ++ ":" ++ Extra; + _ -> + Host ++ ":" ++ integer_to_list(Port) + end. + +create_gun_connection(State, Host, Port, HostKey, Options) -> + % Map HTTP version to Gun protocols, always include http as fallback + HttpVersion = proplists:get_value(version, Options, "HTTP/1.1"), + Protocols = + case HttpVersion of + "HTTP/2" -> [http2, http]; + "HTTP/2.0" -> [http2, http]; + "HTTP/1.1" -> [http]; + "HTTP/1.0" -> [http]; + % Default: try HTTP/2, fallback to HTTP/1.1 + _ -> [http2, http] + end, + ConnectTimeout = proplists:get_value(connect_timeout, Options, 5000), + Opts = #{ + transport => + if + Port == 443 -> tls; + true -> tcp + end, + protocols => Protocols, + connect_timeout => ConnectTimeout + }, + case gun:open(Host, Port, Opts) of + {ok, ConnPid} -> + case gun:await_up(ConnPid, ConnectTimeout) of + {ok, _Protocol} -> + NewConnections = maps:put(HostKey, ConnPid, State#state.gun_connections), + NewState = State#state{gun_connections = NewConnections}, + {ConnPid, NewState}; + {error, Reason} -> + gun:close(ConnPid), + error({gun_connection_failed, Reason}) + end; + {error, Reason} -> + error({gun_open_failed, Reason}) + end. + +close_connection(Service, Path, Options, State) -> + URI = endpoint(State, undefined, Service, Path), + {Host, Port, Path} = parse_uri(URI), + HostKey = get_connection_key(Host, Port, Path, Options), + case maps:get(HostKey, State#state.gun_connections, undefined) of + undefined -> + State; + ConnPid -> + gun:close(ConnPid), + NewConnections = maps:remove(HostKey, State#state.gun_connections), + State#state{gun_connections = NewConnections} + end. + +parse_uri(URI) -> + case string:split(URI, "://", leading) of + [Scheme, Rest] -> + case string:split(Rest, "/", leading) of + [HostPort] -> + {Host, Port} = parse_host_port(HostPort, Scheme), + {Host, Port, "/"}; + [HostPort, Path] -> + {Host, Port} = parse_host_port(HostPort, Scheme), + {Host, Port, "/" ++ Path} + end + end. + +parse_host_port(HostPort, Scheme) -> + DefaultPort = + case Scheme of + "https" -> 443; + "http" -> 80; + % Fallback to HTTPS + _ -> 443 + end, + case string:split(HostPort, ":", trailing) of + [Host] -> + {Host, DefaultPort}; + [Host, PortStr] -> + {Host, list_to_integer(PortStr)} + end. + +status_text(200) -> "OK"; +status_text(206) -> "Partial Content"; +status_text(400) -> "Bad Request"; +status_text(401) -> "Unauthorized"; +status_text(403) -> "Forbidden"; +status_text(404) -> "Not Found"; +status_text(416) -> "Range Not Satisfiable"; +status_text(500) -> "Internal Server Error"; +status_text(Code) -> integer_to_list(Code). diff --git a/deps/rabbitmq_aws/src/rabbitmq_aws_config.erl b/deps/rabbitmq_aws/src/rabbitmq_aws_config.erl index 3d2ae89fe918..4ba821249a99 100644 --- a/deps/rabbitmq_aws/src/rabbitmq_aws_config.erl +++ b/deps/rabbitmq_aws/src/rabbitmq_aws_config.erl @@ -629,9 +629,14 @@ maybe_get_role_from_instance_metadata() -> %% @doc Parse the response from the Availability Zone query to the %% Instance Metadata service, returning the Region if successful. %% end. -parse_az_response({error, _}) -> {error, undefined}; -parse_az_response({ok, {{_, 200, _}, _, Body}}) -> {ok, region_from_availability_zone(Body)}; -parse_az_response({ok, {{_, _, _}, _, _}}) -> {error, undefined}. +parse_az_response({error, _}) -> + {error, undefined}; +parse_az_response({ok, {{_, 200, _}, _, Body}}) when is_binary(Body) -> + {ok, region_from_availability_zone(binary_to_list(Body))}; +parse_az_response({ok, {{_, 200, _}, _, Body}}) -> + {ok, region_from_availability_zone(Body)}; +parse_az_response({ok, {{_, _, _}, _, _}}) -> + {error, undefined}. -spec parse_body_response(httpc_result()) -> {ok, Value :: string()} | {error, Reason :: atom()}. @@ -640,8 +645,9 @@ parse_az_response({ok, {{_, _, _}, _, _}}) -> {error, undefined}. %% end. parse_body_response({error, _}) -> {error, undefined}; -parse_body_response({ok, {{_, 200, _}, _, Body}}) -> - {ok, Body}; +parse_body_response({ok, {{_, 200, _}, _, Body}}) when is_binary(Body) -> + {ok, binary_to_list(Body)}; +parse_body_response({ok, {{_, 200, _}, _, Body}}) when is_list(Body) -> {ok, Body}; parse_body_response({ok, {{_, 401, _}, _, _}}) -> ?LOG_ERROR( get_instruction_on_instance_metadata_error( @@ -678,12 +684,47 @@ parse_credentials_response({ok, {{_, 200, _}, _, Body}}) -> %% @end perform_http_get_instance_metadata(URL) -> ?LOG_DEBUG("Querying instance metadata service: ~tp", [URL]), - httpc:request( - get, - {URL, instance_metadata_request_headers()}, - [{timeout, ?DEFAULT_HTTP_TIMEOUT}], - [] - ). + % Parse metadata service URL + {Host, Port, Path} = rabbitmq_aws:parse_uri(URL), + % Simple Gun connection for metadata service + + % HTTP only, no TLS + Opts = #{transport => tcp, protocols => [http]}, + case gun:open(Host, Port, Opts) of + {ok, ConnPid} -> + case gun:await_up(ConnPid, 5000) of + {ok, _Protocol} -> + Headers = instance_metadata_request_headers(), + StreamRef = gun:get(ConnPid, Path, Headers), + Result = + case gun:await(ConnPid, StreamRef, ?DEFAULT_HTTP_TIMEOUT) of + {response, fin, Status, RespHeaders} -> + {ok, { + {http_version, Status, rabbitmq_aws:status_text(Status)}, + RespHeaders, + <<>> + }}; + {response, nofin, Status, RespHeaders} -> + {ok, Body} = gun:await_body( + ConnPid, StreamRef, ?DEFAULT_HTTP_TIMEOUT + ), + {ok, { + {http_version, Status, rabbitmq_aws:status_text(Status)}, + RespHeaders, + Body + }}; + {error, Reason} -> + {error, Reason} + end, + gun:close(ConnPid), + Result; + {error, Reason} -> + gun:close(ConnPid), + {error, Reason} + end; + {error, Reason} -> + {error, Reason} + end. -spec get_instruction_on_instance_metadata_error(string()) -> string(). %% @doc Return error message on failures related to EC2 Instance Metadata Service with a reference to AWS document. @@ -742,29 +783,77 @@ region_from_availability_zone(Value) -> load_imdsv2_token() -> TokenUrl = imdsv2_token_url(), ?LOG_INFO("Attempting to obtain EC2 IMDSv2 token from ~tp ...", [TokenUrl]), - case - httpc:request( - put, - {TokenUrl, [{?METADATA_TOKEN_TTL_HEADER, integer_to_list(?METADATA_TOKEN_TTL_SECONDS)}]}, - [{timeout, ?DEFAULT_HTTP_TIMEOUT}], - [] - ) - of - {ok, {{_, 200, _}, _, Value}} -> - ?LOG_DEBUG("Successfully obtained EC2 IMDSv2 token."), - Value; - {error, {{_, 400, _}, _, _}} -> - ?LOG_WARNING( - "Failed to obtain EC2 IMDSv2 token: Missing or Invalid Parameters – The PUT request is not valid." - ), - undefined; - Other -> + % Parse metadata service URL + {Host, Port, Path} = rabbitmq_aws:parse_uri(TokenUrl), + % Simple Gun connection for metadata service + + % HTTP only, no TLS + Opts = #{transport => tcp, protocols => [http]}, + case gun:open(Host, Port, Opts) of + {ok, ConnPid} -> + case gun:await_up(ConnPid, 5000) of + {ok, _Protocol} -> + % PUT request with IMDSv2 token TTL header + Headers = [ + {?METADATA_TOKEN_TTL_HEADER, integer_to_list(?METADATA_TOKEN_TTL_SECONDS)} + ], + StreamRef = gun:put(ConnPid, Path, Headers, <<>>), + Result = + case gun:await(ConnPid, StreamRef, ?DEFAULT_HTTP_TIMEOUT) of + {response, fin, 200, _RespHeaders} -> + ?LOG_DEBUG("Successfully obtained EC2 IMDSv2 token."), + % Empty body for fin response + <<>>; + {response, nofin, 200, _RespHeaders} -> + {ok, Body} = gun:await_body( + ConnPid, StreamRef, ?DEFAULT_HTTP_TIMEOUT + ), + ?LOG_DEBUG("Successfully obtained EC2 IMDSv2 token."), + binary_to_list(Body); + {response, _, 400, _RespHeaders} -> + ?LOG_WARNING( + "Failed to obtain EC2 IMDSv2 token: Missing or Invalid Parameters – The PUT request is not valid." + ), + undefined; + {error, Reason} -> + ?LOG_WARNING( + get_instruction_on_instance_metadata_error( + "Failed to obtain EC2 IMDSv2 token: ~tp. " + "Falling back to EC2 IMDSv1 for now. It is recommended to use EC2 IMDSv2." + ), + [Reason] + ), + undefined; + Other -> + ?LOG_WARNING( + get_instruction_on_instance_metadata_error( + "Failed to obtain EC2 IMDSv2 token: ~tp. " + "Falling back to EC2 IMDSv1 for now. It is recommended to use EC2 IMDSv2." + ), + [Other] + ), + undefined + end, + gun:close(ConnPid), + Result; + {error, Reason} -> + gun:close(ConnPid), + ?LOG_WARNING( + get_instruction_on_instance_metadata_error( + "Failed to connect for EC2 IMDSv2 token: ~tp. " + "Falling back to EC2 IMDSv1 for now. It is recommended to use EC2 IMDSv2." + ), + [Reason] + ), + undefined + end; + {error, Reason} -> ?LOG_WARNING( get_instruction_on_instance_metadata_error( - "Failed to obtain EC2 IMDSv2 token: ~tp. " + "Failed to open connection for EC2 IMDSv2 token: ~tp. " "Falling back to EC2 IMDSv1 for now. It is recommended to use EC2 IMDSv2." ), - [Other] + [Reason] ), undefined end. diff --git a/deps/rabbitmq_aws/src/rabbitmq_aws_xml.erl b/deps/rabbitmq_aws/src/rabbitmq_aws_xml.erl index 250fc1fc882e..98094aea87eb 100644 --- a/deps/rabbitmq_aws/src/rabbitmq_aws_xml.erl +++ b/deps/rabbitmq_aws/src/rabbitmq_aws_xml.erl @@ -11,6 +11,8 @@ -include_lib("xmerl/include/xmerl.hrl"). -spec parse(Value :: string() | binary()) -> list(). +parse(Value) when is_binary(Value) -> + parse(binary_to_list(Value)); parse(Value) -> {Element, _} = xmerl_scan:string(Value), parse_node(Element). diff --git a/deps/rabbitmq_aws/test/rabbitmq_aws_config_tests.erl b/deps/rabbitmq_aws/test/rabbitmq_aws_config_tests.erl index cca1b4af8231..fd6c30376c37 100644 --- a/deps/rabbitmq_aws/test/rabbitmq_aws_config_tests.erl +++ b/deps/rabbitmq_aws/test/rabbitmq_aws_config_tests.erl @@ -120,10 +120,10 @@ credentials_test_() -> { foreach, fun() -> - meck:new(httpc), - meck:new(rabbitmq_aws), + meck:new(gun, []), + meck:new(rabbitmq_aws, [passthrough]), reset_environment(), - [httpc, rabbitmq_aws] + [gun, rabbitmq_aws] end, fun meck:unload/1, [ @@ -222,13 +222,26 @@ credentials_test_() -> {"from instance metadata service", fun() -> CredsBody = "{\n \"Code\" : \"Success\",\n \"LastUpdated\" : \"2016-03-31T21:51:49Z\",\n \"Type\" : \"AWS-HMAC\",\n \"AccessKeyId\" : \"ASIAIMAFAKEACCESSKEY\",\n \"SecretAccessKey\" : \"2+t64tZZVaz0yp0x1G23ZRYn+FAKEyVALUEs/4qh\",\n \"Token\" : \"FAKE//////////wEAK/TOKEN/VALUE=\",\n \"Expiration\" : \"2016-04-01T04:13:28Z\"\n}", + meck:expect(gun, open, fun(_, _, _) -> {ok, pid} end), + meck:expect(gun, close, fun(_) -> ok end), + meck:expect(gun, await_up, fun(_, _) -> {ok, protocol} end), + meck:sequence(gun, get, 3, [stream_ref1, stream_ref2]), meck:sequence( - httpc, - request, - 4, + gun, + await, + 3, [ - {ok, {{protocol, 200, message}, headers, "Bob"}}, - {ok, {{protocol, 200, message}, headers, CredsBody}} + {response, nofin, 200, headers}, + {response, nofin, 200, headers} + ] + ), + meck:sequence( + gun, + await_body, + 3, + [ + {ok, <<"Bob">>}, + {ok, list_to_binary(CredsBody)} ] ), meck:expect(rabbitmq_aws, ensure_imdsv2_token_valid, 0, undefined), @@ -239,41 +252,59 @@ credentials_test_() -> end}, {"with instance metadata service role error", fun() -> meck:expect(rabbitmq_aws, ensure_imdsv2_token_valid, 0, undefined), - meck:expect(httpc, request, 4, {error, timeout}), + meck:expect(gun, open, fun(_, _, _) -> {error, timeout} end), ?assertEqual({error, undefined}, rabbitmq_aws_config:credentials()) end}, {"with instance metadata service role http error", fun() -> meck:expect(rabbitmq_aws, ensure_imdsv2_token_valid, 0, undefined), - meck:expect( - httpc, - request, - 4, - {ok, {{protocol, 500, message}, headers, "Internal Server Error"}} - ), + meck:expect(gun, open, fun(_, _, _) -> {ok, pid} end), + meck:expect(gun, close, fun(_) -> ok end), + meck:expect(gun, await_up, fun(_, _) -> {ok, protocol} end), + meck:expect(gun, get, fun(_, _, _) -> stream_ref end), + meck:expect(gun, await, fun(_, _, _) -> {response, nofin, 500, headers} end), + meck:expect(gun, await_body, fun(_, _, _) -> {ok, <<"Internal Server Error">>} end), ?assertEqual({error, undefined}, rabbitmq_aws_config:credentials()) end}, {"with instance metadata service credentials error", fun() -> meck:expect(rabbitmq_aws, ensure_imdsv2_token_valid, 0, undefined), + meck:expect(gun, open, fun(_, _, _) -> {ok, pid} end), + meck:expect(gun, close, fun(_) -> ok end), + meck:expect(gun, await_up, fun(_, _) -> {ok, protocol} end), + meck:sequence(gun, get, 3, [stream_ref1, stream_ref2]), meck:sequence( - httpc, - request, - 4, + gun, + await, + 3, [ - {ok, {{protocol, 200, message}, headers, "Bob"}}, + {response, nofin, 200, headers}, {error, timeout} ] ), + meck:expect(gun, await_body, fun(_, _, _) -> {ok, <<"Bob">>} end), ?assertEqual({error, undefined}, rabbitmq_aws_config:credentials()) end}, {"with instance metadata service credentials not found", fun() -> meck:expect(rabbitmq_aws, ensure_imdsv2_token_valid, 0, undefined), + meck:expect(gun, open, fun(_, _, _) -> {ok, pid} end), + meck:expect(gun, close, fun(_) -> ok end), + meck:expect(gun, await_up, fun(_, _) -> {ok, protocol} end), + meck:sequence(gun, get, 3, [stream_ref1, stream_ref2]), + meck:sequence( + gun, + await, + 3, + [ + {response, nofin, 200, headers}, + {response, nofin, 404, headers} + ] + ), meck:sequence( - httpc, - request, - 4, + gun, + await_body, + 3, [ - {ok, {{protocol, 200, message}, headers, "Bob"}}, - {ok, {{protocol, 404, message}, headers, "File Not Found"}} + {ok, <<"Bob">>}, + {ok, <<"File Not Found">>} ] ), ?assertEqual({error, undefined}, rabbitmq_aws_config:credentials()) @@ -357,10 +388,10 @@ region_test_() -> { foreach, fun() -> - meck:new(httpc), - meck:new(rabbitmq_aws), + meck:new(gun, []), + meck:new(rabbitmq_aws, [passthrough]), reset_environment(), - [httpc, rabbitmq_aws] + [gun, rabbitmq_aws] end, fun meck:unload/1, [ @@ -383,12 +414,12 @@ region_test_() -> end}, {"from instance metadata service", fun() -> meck:expect(rabbitmq_aws, ensure_imdsv2_token_valid, 0, undefined), - meck:expect( - httpc, - request, - 4, - {ok, {{protocol, 200, message}, headers, "us-west-1a"}} - ), + meck:expect(gun, open, fun(_, _, _) -> {ok, pid} end), + meck:expect(gun, close, fun(_) -> ok end), + meck:expect(gun, await_up, fun(_, _) -> {ok, protocol} end), + meck:expect(gun, get, fun(_, _, _) -> stream_ref end), + meck:expect(gun, await, fun(_, _, _) -> {response, nofin, 200, headers} end), + meck:expect(gun, await_body, fun(_, _, _) -> {ok, <<"us-west-1a">>} end), ?assertEqual({ok, "us-west-1"}, rabbitmq_aws_config:region()) end}, {"full lookup failure", fun() -> @@ -397,12 +428,12 @@ region_test_() -> end}, {"http error failure", fun() -> meck:expect(rabbitmq_aws, ensure_imdsv2_token_valid, 0, undefined), - meck:expect( - httpc, - request, - 4, - {ok, {{protocol, 500, message}, headers, "Internal Server Error"}} - ), + meck:expect(gun, open, fun(_, _, _) -> {ok, pid} end), + meck:expect(gun, close, fun(_) -> ok end), + meck:expect(gun, await_up, fun(_, _) -> {ok, protocol} end), + meck:expect(gun, get, fun(_, _, _) -> stream_ref end), + meck:expect(gun, await, fun(_, _, _) -> {response, nofin, 500, headers} end), + meck:expect(gun, await_body, fun(_, _, _) -> {ok, <<"Internal Server Error">>} end), ?assertEqual({ok, ?DEFAULT_REGION}, rabbitmq_aws_config:region()) end} ] @@ -412,32 +443,41 @@ instance_id_test_() -> { foreach, fun() -> - meck:new(httpc), - meck:new(rabbitmq_aws), + meck:new(gun, []), + meck:new(rabbitmq_aws, [passthrough]), reset_environment(), - [httpc, rabbitmq_aws] + [gun, rabbitmq_aws] end, fun meck:unload/1, [ {"get instance id successfully", fun() -> meck:expect(rabbitmq_aws, ensure_imdsv2_token_valid, 0, undefined), - meck:expect( - httpc, request, 4, {ok, {{protocol, 200, message}, headers, "instance-id"}} - ), + meck:expect(gun, open, fun(_, _, _) -> {ok, pid} end), + meck:expect(gun, close, fun(_) -> ok end), + meck:expect(gun, await_up, fun(_, _) -> {ok, protocol} end), + meck:expect(gun, get, fun(_, _, _) -> stream_ref end), + meck:expect(gun, await, fun(_, _, _) -> {response, nofin, 200, headers} end), + meck:expect(gun, await_body, fun(_, _, _) -> {ok, <<"instance-id">>} end), ?assertEqual({ok, "instance-id"}, rabbitmq_aws_config:instance_id()) end}, {"getting instance id is rejected with invalid token error", fun() -> meck:expect(rabbitmq_aws, ensure_imdsv2_token_valid, 0, "invalid"), - meck:expect( - httpc, request, 4, {error, {{protocol, 401, message}, headers, "Invalid token"}} - ), + meck:expect(gun, open, fun(_, _, _) -> {ok, pid} end), + meck:expect(gun, close, fun(_) -> ok end), + meck:expect(gun, await_up, fun(_, _) -> {ok, protocol} end), + meck:expect(gun, get, fun(_, _, _) -> stream_ref end), + meck:expect(gun, await, fun(_, _, _) -> {response, nofin, 401, headers} end), + meck:expect(gun, await_body, fun(_, _, _) -> {ok, <<"Invalid token">>} end), ?assertEqual({error, undefined}, rabbitmq_aws_config:instance_id()) end}, {"getting instance id is rejected with access denied error", fun() -> meck:expect(rabbitmq_aws, ensure_imdsv2_token_valid, 0, "expired token"), - meck:expect( - httpc, request, 4, {error, {{protocol, 403, message}, headers, "access denied"}} - ), + meck:expect(gun, open, fun(_, _, _) -> {ok, pid} end), + meck:expect(gun, close, fun(_) -> ok end), + meck:expect(gun, await_up, fun(_, _) -> {ok, protocol} end), + meck:expect(gun, get, fun(_, _, _) -> stream_ref end), + meck:expect(gun, await, fun(_, _, _) -> {response, nofin, 403, headers} end), + meck:expect(gun, await_body, fun(_, _, _) -> {ok, <<"access denied">>} end), ?assertEqual({error, undefined}, rabbitmq_aws_config:instance_id()) end} ] @@ -447,36 +487,34 @@ load_imdsv2_token_test_() -> { foreach, fun() -> - meck:new(httpc), - [httpc] + meck:new(gun, []), + [gun] end, fun meck:unload/1, [ {"fail to get imdsv2 token - timeout", fun() -> - meck:expect(httpc, request, 4, {error, timeout}), + meck:expect(gun, open, fun(_, _, _) -> {error, timeout} end), ?assertEqual(undefined, rabbitmq_aws_config:load_imdsv2_token()) end}, {"fail to get imdsv2 token - PUT request is not valid", fun() -> - meck:expect( - httpc, - request, - 4, - {error, { - {protocol, 400, messge}, - headers, - "Missing or Invalid Parameters – The PUT request is not valid." - }} - ), + meck:expect(gun, open, fun(_, _, _) -> {ok, pid} end), + meck:expect(gun, close, fun(_) -> ok end), + meck:expect(gun, await_up, fun(_, _) -> {ok, protocol} end), + meck:expect(gun, put, fun(_, _, _, _) -> stream_ref end), + meck:expect(gun, await, fun(_, _, _) -> {response, nofin, 400, headers} end), + meck:expect(gun, await_body, fun(_, _, _) -> + {ok, <<"Missing or Invalid Parameters – The PUT request is not valid.">>} + end), ?assertEqual(undefined, rabbitmq_aws_config:load_imdsv2_token()) end}, {"successfully get imdsv2 token from instance metadata service", fun() -> IMDSv2Token = "super_secret_token_value", - meck:sequence( - httpc, - request, - 4, - [{ok, {{protocol, 200, message}, headers, IMDSv2Token}}] - ), + meck:expect(gun, open, fun(_, _, _) -> {ok, pid} end), + meck:expect(gun, close, fun(_) -> ok end), + meck:expect(gun, await_up, fun(_, _) -> {ok, protocol} end), + meck:expect(gun, put, fun(_, _, _, _) -> stream_ref end), + meck:expect(gun, await, fun(_, _, _) -> {response, nofin, 200, headers} end), + meck:expect(gun, await_body, fun(_, _, _) -> {ok, list_to_binary(IMDSv2Token)} end), ?assertEqual(IMDSv2Token, rabbitmq_aws_config:load_imdsv2_token()) end} ] @@ -486,7 +524,7 @@ maybe_imdsv2_token_headers_test_() -> { foreach, fun() -> - meck:new(rabbitmq_aws), + meck:new(rabbitmq_aws, [passthrough]), [rabbitmq_aws] end, fun meck:unload/1, @@ -516,7 +554,7 @@ reset_environment() -> "AWS_SHARED_CREDENTIALS_FILE", "bad_credentials.ini" ), - meck:expect(httpc, request, 4, {error, timeout}). + meck:expect(gun, open, fun(_, _, _) -> {error, timeout} end). setup_test_config_env_var() -> setup_test_file_with_env_var("AWS_CONFIG_FILE", "test_aws_config.ini"). diff --git a/deps/rabbitmq_aws/test/rabbitmq_aws_tests.erl b/deps/rabbitmq_aws/test/rabbitmq_aws_tests.erl index 7f5eaa906e44..44d5cf917edf 100644 --- a/deps/rabbitmq_aws/test/rabbitmq_aws_tests.erl +++ b/deps/rabbitmq_aws/test/rabbitmq_aws_tests.erl @@ -27,7 +27,7 @@ init_test_() -> os:unsetenv("AWS_SECRET_ACCESS_KEY"), Expectation = {state, "Sésame", "ouvre-toi", undefined, undefined, "us-west-3", undefined, - undefined}, + undefined, #{}}, ?assertEqual(Expectation, State) end}, {"error", fun() -> @@ -39,14 +39,21 @@ init_test_() -> ok = gen_server:stop(Pid), Expectation = {state, undefined, undefined, undefined, undefined, "us-west-3", undefined, - test_result}, + test_result, #{}}, ?assertEqual(Expectation, State), meck:validate(rabbitmq_aws_config) end} ]}. terminate_test() -> - ?assertEqual(ok, rabbitmq_aws:terminate(foo, bar)). + ?assertEqual( + ok, + rabbitmq_aws:terminate( + foo, + {state, undefined, undefined, undefined, undefined, "us-west-3", undefined, test_result, + #{}} + ) + ). code_change_test() -> ?assertEqual({ok, {state, denial}}, rabbitmq_aws:code_change(foo, bar, {state, denial})). @@ -133,9 +140,11 @@ format_response_test_() -> {"ok", fun() -> Response = {ok, { - {"HTTP/1.1", 200, "Ok"}, [{"Content-Type", "text/xml"}], "Value" + {"HTTP/1.1", 200, "Ok"}, + [{<<"Content-Type">>, <<"text/xml">>}], + "Value" }}, - Expectation = {ok, {[{"Content-Type", "text/xml"}], [{"test", "Value"}]}}, + Expectation = {ok, {[{<<"Content-Type">>, <<"text/xml">>}], [{"test", "Value"}]}}, ?assertEqual(Expectation, rabbitmq_aws:format_response(Response)) end}, {"error", fun() -> @@ -161,8 +170,8 @@ gen_server_call_test_() -> os:putenv("AWS_DEFAULT_REGION", "us-west-3"), os:putenv("AWS_ACCESS_KEY_ID", "Sésame"), os:putenv("AWS_SECRET_ACCESS_KEY", "ouvre-toi"), - meck:new(httpc, []), - [httpc] + meck:new(gun, []), + [gun] end, fun(Mods) -> meck:unload(Mods), @@ -186,31 +195,43 @@ gen_server_call_test_() -> Body = "", Options = [], Host = undefined, + meck:expect(gun, open, fun(_, _, _) -> {ok, pid} end), + meck:expect(gun, close, fun(_) -> ok end), + meck:expect(gun, await_up, fun(_, _) -> {ok, protocol} end), meck:expect( - httpc, - request, - fun( - get, - {"https://ec2.us-east-1.amazonaws.com/?Action=DescribeTags&Version=2015-10-01", - _Headers}, - _Options, - [] - ) -> - {ok, { - {"HTTP/1.0", 200, "OK"}, - [{"content-type", "application/json"}], - "{\"pass\": true}" - }} + gun, + get, + fun(_Pid, _Path, _Headers) -> nofin end + ), + %% {ok, {{"HTTP/1.0", 200, "OK"}, [{"content-type", "application/json"}], "{\"pass\": true}"}} + %% end), + meck:expect( + gun, + await, + fun(_Pid, _, _) -> + {response, nofin, 200, [{<<"content-type">>, <<"application/json">>}]} end ), + meck:expect( + gun, + await_body, + fun(_Pid, _, _) -> {ok, <<"{\"pass\": true}">>} end + ), + + %% {ok, {{"HTTP/1.0", 200, "OK"}, [{"content-type", "application/json"}], "{\"pass\": true}"}} + %% end), Expectation = - {reply, {ok, {[{"content-type", "application/json"}], [{"pass", true}]}}, - State}, + {reply, + {ok, + {[{<<"content-type">>, <<"application/json">>}], [{"pass", true}]}}, + State#state{ + gun_connections = #{"ec2.us-east-1.amazonaws.com:443" => pid} + }}, Result = rabbitmq_aws:handle_call( {request, Service, Method, Headers, Path, Body, Options, Host}, eunit, State ), ?assertEqual(Expectation, Result), - meck:validate(httpc) + meck:validate(gun) end }, { @@ -388,9 +409,9 @@ perform_request_test_() -> { foreach, fun() -> - meck:new(httpc, []), + meck:new(gun, []), meck:new(rabbitmq_aws_config, []), - [httpc, rabbitmq_aws_config] + [gun, rabbitmq_aws_config] end, fun meck:unload/1, [ @@ -411,33 +432,37 @@ perform_request_test_() -> Host = undefined, ExpectURI = "https://ec2.us-east-1.amazonaws.com/?Action=DescribeTags&Version=2015-10-01", + meck:expect(gun, open, fun(_, _, _) -> {ok, pid} end), + meck:expect(gun, close, fun(_) -> ok end), + meck:expect(gun, await_up, fun(_, _) -> {ok, protocol} end), + + meck:expect( + gun, + get, + fun(_Pid, "/?Action=DescribeTags&Version=2015-10-01", _Headers) -> nofin end + ), meck:expect( - httpc, - request, - fun(get, {URI, _Headers}, _Options, []) -> - case URI of - ExpectURI -> - {ok, { - {"HTTP/1.0", 200, "OK"}, - [{"content-type", "application/json"}], - "{\"pass\": true}" - }}; - _ -> - {ok, - {{"HTTP/1.0", 400, "RequestFailure", - [{"content-type", "application/json"}], - "{\"pass\": false}"}}} - end + gun, + await, + fun(_Pid, _, _) -> + {response, nofin, 200, [{<<"content-type">>, <<"application/json">>}]} end ), + meck:expect( + gun, + await_body, + fun(_Pid, _, _) -> {ok, <<"{\"pass\": true}">>} end + ), + Expectation = { - {ok, {[{"content-type", "application/json"}], [{"pass", true}]}}, State + {ok, {[{<<"content-type">>, <<"application/json">>}], [{"pass", true}]}}, + State#state{gun_connections = #{"ec2.us-east-1.amazonaws.com:443" => pid}} }, Result = rabbitmq_aws:perform_request( State, Service, Method, Headers, Path, Body, Options, Host ), ?assertEqual(Expectation, Result), - meck:validate(httpc) + meck:validate(gun) end }, { @@ -451,19 +476,11 @@ perform_request_test_() -> Body = "", Options = [], Host = undefined, - meck:expect(httpc, request, fun(get, {_URI, _Headers}, _Options, []) -> - {ok, { - {"HTTP/1.0", 400, "RequestFailure"}, - [{"content-type", "application/json"}], - "{\"pass\": false}" - }} - end), Expectation = {{error, {credentials, State#state.error}}, State}, Result = rabbitmq_aws:perform_request( State, Service, Method, Headers, Path, Body, Options, Host ), - ?assertEqual(Expectation, Result), - meck:validate(httpc) + ?assertEqual(Expectation, Result) end }, { @@ -554,9 +571,9 @@ api_get_request_test_() -> { foreach, fun() -> - meck:new(httpc, []), + meck:new(gun, []), meck:new(rabbitmq_aws_config, []), - [httpc, rabbitmq_aws_config] + [gun, rabbitmq_aws_config] end, fun meck:unload/1, [ @@ -567,23 +584,34 @@ api_get_request_test_() -> region = "us-east-1", expiration = {{3016, 4, 1}, {12, 0, 0}} }, + meck:expect(gun, open, fun(_, _, _) -> {ok, pid} end), + meck:expect(gun, close, fun(_) -> ok end), + meck:expect(gun, await_up, fun(_, _) -> {ok, protocol} end), meck:expect( - httpc, - request, - 4, - {ok, { - {"HTTP/1.0", 200, "OK"}, - [{"content-type", "application/json"}], - "{\"data\": \"value\"}" - }} + gun, + get, + fun(_Pid, _Path, _Headers) -> nofin end ), + meck:expect( + gun, + await, + fun(_Pid, _, _) -> + {response, nofin, 200, [{<<"content-type">>, <<"application/json">>}]} + end + ), + meck:expect( + gun, + await_body, + fun(_Pid, _, _) -> {ok, <<"{\"data\": \"value\"}">>} end + ), + {ok, Pid} = rabbitmq_aws:start_link(), rabbitmq_aws:set_region("us-east-1"), rabbitmq_aws:set_credentials(State), Result = rabbitmq_aws:api_get_request("AWS", "API"), ok = gen_server:stop(Pid), ?assertEqual({ok, [{"data", "value"}]}, Result), - meck:validate(httpc) + meck:validate(gun) end}, {"AWS service API request failed - credentials", fun() -> meck:expect(rabbitmq_aws_config, credentials, 0, {error, undefined}), @@ -600,14 +628,27 @@ api_get_request_test_() -> region = "us-east-1", expiration = {{3016, 4, 1}, {12, 0, 0}} }, - meck:expect(httpc, request, 4, {error, "network error"}), + meck:expect(gun, open, fun(_, _, _) -> {ok, spawn(fun() -> ok end)} end), + meck:expect(gun, close, fun(_) -> ok end), + meck:expect(gun, await_up, fun(_, _) -> {ok, protocol} end), + meck:expect( + gun, + get, + fun(_Pid, _Path, _Headers) -> nofin end + ), + meck:expect( + gun, + await, + fun(_Pid, _, _) -> {error, "network error"} end + ), + {ok, Pid} = rabbitmq_aws:start_link(), rabbitmq_aws:set_region("us-east-1"), rabbitmq_aws:set_credentials(State), Result = rabbitmq_aws:api_get_request_with_retries("AWS", "API", 3, 1), ok = gen_server:stop(Pid), ?assertEqual({error, "AWS service is unavailable"}, Result), - meck:validate(httpc) + meck:validate(gun) end}, {"AWS service API request succeeded after a transient error", fun() -> State = #state{ @@ -616,22 +657,35 @@ api_get_request_test_() -> region = "us-east-1", expiration = {{3016, 4, 1}, {12, 0, 0}} }, + meck:expect(gun, open, fun(_, _, _) -> {ok, spawn(fun() -> ok end)} end), + meck:expect(gun, close, fun(_) -> ok end), + meck:expect(gun, await_up, fun(_, _) -> {ok, protocol} end), meck:expect( - httpc, - request, - 4, + gun, + get, + fun(_Pid, _Path, _Headers) -> nofin end + ), + + %% meck:expect(gun, get, 3, meck:seq( + %% fun(_Pid, _Path, _Headers) -> {error, "network errors"} end), + meck:expect( + gun, + await, + 3, meck:seq([ {error, "network error"}, - {ok, { - {"HTTP/1.0", 500, "OK"}, - [{"content-type", "application/json"}], - "{\"error\": \"server error\"}" - }}, - {ok, { - {"HTTP/1.0", 200, "OK"}, - [{"content-type", "application/json"}], - "{\"data\": \"value\"}" - }} + {response, nofin, 500, [{<<"content-type">>, <<"application/json">>}]}, + {response, nofin, 200, [{<<"content-type">>, <<"application/json">>}]} + ]) + ), + + meck:expect( + gun, + await_body, + 3, + meck:seq([ + {ok, <<"{\"error\": \"server error\"}">>}, + {ok, <<"{\"data\": \"value\"}">>} ]) ), {ok, Pid} = rabbitmq_aws:start_link(), @@ -640,7 +694,7 @@ api_get_request_test_() -> Result = rabbitmq_aws:api_get_request_with_retries("AWS", "API", 3, 1), ok = gen_server:stop(Pid), ?assertEqual({ok, [{"data", "value"}]}, Result), - meck:validate(httpc) + meck:validate(gun) end} ] }. From b319649e4260a475ad0d4eae363400793e0deb76 Mon Sep 17 00:00:00 2001 From: Simon Unge Date: Thu, 31 Jul 2025 11:40:25 +0000 Subject: [PATCH 02/11] Hacky new "direct" api for async handling. Will most likely replace the old one but curretly they will live side by side --- deps/rabbitmq_aws/src/rabbitmq_aws.erl | 221 +++++++++++++++++++- deps/rabbitmq_aws/src/rabbitmq_aws_sign.erl | 10 +- 2 files changed, 228 insertions(+), 3 deletions(-) diff --git a/deps/rabbitmq_aws/src/rabbitmq_aws.erl b/deps/rabbitmq_aws/src/rabbitmq_aws.erl index d54f890e0af0..445e6955a1ae 100644 --- a/deps/rabbitmq_aws/src/rabbitmq_aws.erl +++ b/deps/rabbitmq_aws/src/rabbitmq_aws.erl @@ -22,7 +22,14 @@ ensure_imdsv2_token_valid/0, api_get_request/2, close_connection/3, - status_text/1 + status_text/1, + %% New concurrent API + open_connection/1, open_connection/2, + close_direct_connection/1, + direct_get/3, direct_get/4, + direct_post/4, direct_post/5, + direct_put/4, direct_put/5, + direct_request/6 ]). %% gen-server exports @@ -44,6 +51,16 @@ -include("rabbitmq_aws.hrl"). -include_lib("kernel/include/logger.hrl"). +%% Types for new concurrent API +-type connection_handle() :: {gun:conn_ref(), credential_context()}. +-type credential_context() :: #{ + access_key => access_key(), + secret_access_key => secret_access_key(), + security_token => security_token(), + region => region(), + service => string() +}. + %%==================================================================== %% exported wrapper functions %%==================================================================== @@ -112,6 +129,67 @@ refresh_credentials() -> close_connection(Service, Path, Options) -> gen_server:cast(?MODULE, {close_connection, Service, Path, Options}). +%%==================================================================== +%% New Concurrent API Functions +%%==================================================================== + +%% Open a connection and return handle for direct use +-spec open_connection(Service :: string()) -> {ok, connection_handle()} | {error, term()}. +open_connection(Service) -> + open_connection(Service, []). + +-spec open_connection(Service :: string(), Options :: list()) -> {ok, connection_handle()} | {error, term()}. +open_connection(Service, Options) -> + gen_server:call(?MODULE, {open_direct_connection, Service, Options}). + +%% Close a direct connection +-spec close_direct_connection(Handle :: connection_handle()) -> ok. +close_direct_connection({GunPid, _CredContext}) -> + gun:close(GunPid). + +%% Direct API calls that bypass gen_server +-spec direct_get(Handle :: connection_handle(), Path :: path(), Headers :: headers()) -> result(). +direct_get(Handle, Path, Headers) -> + direct_request(Handle, get, Path, <<>>, Headers, []). + +-spec direct_get(Handle :: connection_handle(), Path :: path(), Headers :: headers(), Options :: list()) -> result(). +direct_get(Handle, Path, Headers, Options) -> + direct_request(Handle, get, Path, <<>>, Headers, Options). + +-spec direct_post(Handle :: connection_handle(), Path :: path(), Body :: body(), Headers :: headers()) -> result(). +direct_post(Handle, Path, Body, Headers) -> + direct_request(Handle, post, Path, Body, Headers, []). + +-spec direct_post(Handle :: connection_handle(), Path :: path(), Body :: body(), Headers :: headers(), Options :: list()) -> result(). +direct_post(Handle, Path, Body, Headers, Options) -> + direct_request(Handle, post, Path, Body, Headers, Options). + +-spec direct_put(Handle :: connection_handle(), Path :: path(), Body :: body(), Headers :: headers()) -> result(). +direct_put(Handle, Path, Body, Headers) -> + direct_request(Handle, put, Path, Body, Headers, []). + +-spec direct_put(Handle :: connection_handle(), Path :: path(), Body :: body(), Headers :: headers(), Options :: list()) -> result(). +direct_put(Handle, Path, Body, Headers, Options) -> + direct_request(Handle, put, Path, Body, Headers, Options). + +-spec direct_request( + Handle :: connection_handle(), + Method :: method(), + Path :: path(), + Body :: body(), + Headers :: headers(), + Options :: list() +) -> result(). +direct_request({GunPid, CredContext}, Method, Path, Body, Headers, Options) -> + #{service := Service, region := Region} = CredContext, + % Build URI for signing + Host = endpoint_host(Region, Service), + URI = "https://" ++ Host ++ Path, + % Sign headers directly (no gen_server call) + SignedHeaders = sign_headers_with_context(CredContext, Method, URI, Headers, Body), + % Make Gun request directly + perform_direct_gun_request(GunPid, Method, Path, SignedHeaders, Body, Options). + -spec refresh_credentials(state()) -> ok | error. %% @doc Manually refresh the credentials from the environment, filesystem or EC2 Instance Metadata Service. %% @end @@ -248,6 +326,18 @@ handle_msg({request, Service, Method, Headers, Path, Body, Options, Host}, State State, Service, Method, Headers, Path, Body, Options, Host ), {reply, Response, NewState}; +handle_msg({open_direct_connection, Service, Options}, State) -> + case ensure_credentials_valid_internal(State) of + {ok, ValidState} -> + case create_direct_connection(ValidState, Service, Options) of + {ok, Handle} -> + {reply, {ok, Handle}, ValidState}; + {error, Reason} -> + {reply, {error, Reason}, ValidState} + end; + {error, Reason} -> + {reply, {error, Reason}, State} + end; handle_msg(get_state, State) -> {reply, {ok, State}, State}; handle_msg(refresh_credentials, State) -> @@ -711,7 +801,7 @@ gun_request(State, Method, URI, Headers, Body, Options) -> ), {Host, Port, Path} = parse_uri(URI), {ConnPid, NewState} = get_or_create_gun_connection(State, Host, Port, Path, Options), - Timeout = proplists:get_value(timeout, Options, ?DEFAULT_HTTP_TIMEOUT), + Timeout = proplists:get_value(timeout, ensure_timeout(Options), ?DEFAULT_HTTP_TIMEOUT), try StreamRef = do_gun_request(ConnPid, Method, Path, HeadersBin, Body), case gun:await(ConnPid, StreamRef, Timeout) of @@ -865,3 +955,130 @@ status_text(404) -> "Not Found"; status_text(416) -> "Range Not Satisfiable"; status_text(500) -> "Internal Server Error"; status_text(Code) -> integer_to_list(Code). + +%%==================================================================== +%% New Concurrent API Helper Functions +%%==================================================================== + +%% Create a direct connection handle +-spec create_direct_connection(State :: state(), Service :: string(), Options :: list()) -> + {ok, connection_handle()} | {error, term()}. +create_direct_connection(State, Service, Options) -> + Region = State#state.region, + Host = endpoint_host(Region, Service), + Port = 443, + + HttpVersion = proplists:get_value(version, Options, "HTTP/1.1"), + Protocols = + case HttpVersion of + "HTTP/2" -> [http2, http]; + "HTTP/2.0" -> [http2, http]; + "HTTP/1.1" -> [http]; + "HTTP/1.0" -> [http]; + % Default: try HTTP/2, fallback to HTTP/1.1 + _ -> [http2, http] + end, + ConnectTimeout = proplists:get_value(connect_timeout, Options, 5000), + GunOpts = #{ + transport => tls, + %% if + %% Port == 443 -> tls; + %% true -> tcp + %% end, + protocols => Protocols, + connect_timeout => ConnectTimeout + }, + + case gun:open(Host, Port, GunOpts) of + {ok, GunPid} -> + case gun:await_up(GunPid, ConnectTimeout) of + {ok, _Protocol} -> + CredContext = #{ + access_key => State#state.access_key, + secret_access_key => State#state.secret_access_key, + security_token => State#state.security_token, + region => Region, + service => Service + }, + {ok, {GunPid, CredContext}}; + {error, Reason} -> + gun:close(GunPid), + {error, {connection_failed, Reason}} + end; + {error, Reason} -> + {error, {gun_open_failed, Reason}} + end. + +%% Sign headers using credential context (no gen_server state needed) +-spec sign_headers_with_context( + CredContext :: credential_context(), + Method :: method(), + URI :: string(), + Headers :: headers(), + Body :: body() +) -> headers(). +sign_headers_with_context(CredContext, Method, URI, Headers, Body) -> + #{ + access_key := AccessKey, + secret_access_key := SecretKey, + security_token := SecurityToken, + region := Region, + service := Service + } = CredContext, + rabbitmq_aws_sign:headers(#request{ + access_key = AccessKey, + secret_access_key = SecretKey, + security_token = SecurityToken, + region = Region, + service = Service, + method = Method, + uri = URI, + headers = Headers, + body = Body + }). + +%% Direct Gun request (extracted from existing gun_request function) +-spec perform_direct_gun_request( + GunPid :: gun:conn_ref(), + Method :: method(), + Path :: path(), + Headers :: headers(), + Body :: body(), + Options :: list() +) -> result(). +perform_direct_gun_request(GunPid, Method, Path, Headers, Body, Options) -> + HeadersBin = lists:map( + fun({Key, Value}) -> + {list_to_binary(Key), list_to_binary(Value)} + end, + Headers + ), + Timeout = proplists:get_value(timeout, Options, ?DEFAULT_HTTP_TIMEOUT), + try + StreamRef = do_gun_request(GunPid, Method, Path, HeadersBin, Body), + case gun:await(GunPid, StreamRef, Timeout) of + {response, fin, Status, RespHeaders} -> + format_response({ok, {{http_version, Status, status_text(Status)}, RespHeaders, <<>>}}); + {response, nofin, Status, RespHeaders} -> + {ok, RespBody} = gun:await_body(GunPid, StreamRef, Timeout), + format_response({ok, {{http_version, Status, status_text(Status)}, RespHeaders, RespBody}}); + {error, Reason} -> + {error, Reason} + end + catch + _:Error -> + {error, Error} + end. + +%% Internal credential validation (extracted from existing logic) +-spec ensure_credentials_valid_internal(State :: state()) -> {ok, state()} | {error, term()}. +ensure_credentials_valid_internal(State) -> + case has_credentials(State) of + true -> + case expired_credentials(State#state.expiration) of + false -> {ok, State}; + true -> load_credentials(State) + end; + false -> + load_credentials(State) + end. diff --git a/deps/rabbitmq_aws/src/rabbitmq_aws_sign.erl b/deps/rabbitmq_aws/src/rabbitmq_aws_sign.erl index 7a95a2b44e77..da76d9cc6211 100644 --- a/deps/rabbitmq_aws/src/rabbitmq_aws_sign.erl +++ b/deps/rabbitmq_aws/src/rabbitmq_aws_sign.erl @@ -28,9 +28,17 @@ headers(Request) -> PayloadHash = sha256(Request#request.body), URI = rabbitmq_aws_urilib:parse(Request#request.uri), {_, Host, _} = URI#uri.authority, + + BodyLength = case Request#request.body of + Body when is_binary(Body) -> + size(Body); + Body when is_list(Body) -> + length(Body) + end, + Headers = append_headers( RequestTimestamp, - length(Request#request.body), + BodyLength, PayloadHash, Host, Request#request.security_token, From dbd3455937e6d291ae204aac8b6ffd45009a5a41 Mon Sep 17 00:00:00 2001 From: Simon Unge Date: Thu, 31 Jul 2025 13:11:54 +0000 Subject: [PATCH 03/11] Remove gun_connection pool --- deps/rabbitmq_aws/include/rabbitmq_aws.hrl | 4 +- deps/rabbitmq_aws/src/rabbitmq_aws.erl | 132 +++++---------------- 2 files changed, 31 insertions(+), 105 deletions(-) diff --git a/deps/rabbitmq_aws/include/rabbitmq_aws.hrl b/deps/rabbitmq_aws/include/rabbitmq_aws.hrl index 2106b3a372df..6a0cacd81131 100644 --- a/deps/rabbitmq_aws/include/rabbitmq_aws.hrl +++ b/deps/rabbitmq_aws/include/rabbitmq_aws.hrl @@ -71,9 +71,7 @@ security_token :: security_token() | undefined, region :: region() | undefined, imdsv2_token :: imdsv2token() | undefined, - error :: atom() | string() | undefined, - % host -> gun_pid mapping - gun_connections = #{} :: #{string() => pid()} + error :: atom() | string() | undefined }). -type state() :: #state{}. diff --git a/deps/rabbitmq_aws/src/rabbitmq_aws.erl b/deps/rabbitmq_aws/src/rabbitmq_aws.erl index 445e6955a1ae..b941baae907c 100644 --- a/deps/rabbitmq_aws/src/rabbitmq_aws.erl +++ b/deps/rabbitmq_aws/src/rabbitmq_aws.erl @@ -21,7 +21,6 @@ set_region/1, ensure_imdsv2_token_valid/0, api_get_request/2, - close_connection/3, status_text/1, %% New concurrent API open_connection/1, open_connection/2, @@ -126,9 +125,6 @@ put(Service, Path, Body, Headers, Options) -> refresh_credentials() -> gen_server:call(rabbitmq_aws, refresh_credentials). -close_connection(Service, Path, Options) -> - gen_server:cast(?MODULE, {close_connection, Service, Path, Options}). - %%==================================================================== %% New Concurrent API Functions %%==================================================================== @@ -188,7 +184,7 @@ direct_request({GunPid, CredContext}, Method, Path, Body, Headers, Options) -> % Sign headers directly (no gen_server call) SignedHeaders = sign_headers_with_context(CredContext, Method, URI, Headers, Body), % Make Gun request directly - perform_direct_gun_request(GunPid, Method, Path, SignedHeaders, Body, Options). + direct_gun_request(GunPid, Method, Path, SignedHeaders, Body, Options). -spec refresh_credentials(state()) -> ok | error. %% @doc Manually refresh the credentials from the environment, filesystem or EC2 Instance Metadata Service. @@ -294,14 +290,6 @@ init([]) -> {ok, #state{}}. terminate(_, State) -> - %% Close all Gun connections - maps:fold( - fun(_Host, ConnPid, _Acc) -> - gun:close(ConnPid) - end, - ok, - State#state.gun_connections - ), ok. code_change(_, _, State) -> @@ -310,8 +298,6 @@ code_change(_, _, State) -> handle_call(Msg, _From, State) -> handle_msg(Msg, State). -handle_cast({close_connection, Service, Path, Options}, State) -> - {noreply, close_connection(Service, Path, Options, State)}; handle_cast(_Request, State) -> {noreply, State}. @@ -352,23 +338,12 @@ handle_msg({set_credentials, AccessKey, SecretAccessKey}, State) -> error = undefined }}; handle_msg({set_credentials, NewState}, State) -> - spawn(fun() -> - maps:fold( - fun(_Host, ConnPid, _Acc) -> - gun:close(ConnPid) - end, - ok, - State#state.gun_connections - ) - end), {reply, ok, State#state{ access_key = NewState#state.access_key, secret_access_key = NewState#state.secret_access_key, security_token = NewState#state.security_token, expiration = NewState#state.expiration, - error = NewState#state.error, - % Potentially new credentials, so clear the connection pool? - gun_connections = #{} + error = NewState#state.error }}; handle_msg({set_region, Region}, State) -> {reply, ok, State#state{region = Region}}; @@ -469,7 +444,7 @@ expired_credentials(Expiration) -> %% - Credentials file %% - EC2 Instance Metadata Service %% @end -load_credentials(#state{region = Region, gun_connections = GunConnections}) -> +load_credentials(#state{region = Region}) -> case rabbitmq_aws_config:credentials() of {ok, AccessKey, SecretAccessKey, Expiration, SecurityToken} -> {ok, #state{ @@ -479,8 +454,7 @@ load_credentials(#state{region = Region, gun_connections = GunConnections}) -> secret_access_key = SecretAccessKey, expiration = Expiration, security_token = SecurityToken, - imdsv2_token = undefined, - gun_connections = GunConnections + imdsv2_token = undefined }}; {error, Reason} -> ?LOG_ERROR( @@ -494,8 +468,7 @@ load_credentials(#state{region = Region, gun_connections = GunConnections}) -> secret_access_key = undefined, expiration = undefined, security_token = undefined, - imdsv2_token = undefined, - gun_connections = GunConnections + imdsv2_token = undefined }} end. @@ -800,30 +773,28 @@ gun_request(State, Method, URI, Headers, Body, Options) -> Headers ), {Host, Port, Path} = parse_uri(URI), - {ConnPid, NewState} = get_or_create_gun_connection(State, Host, Port, Path, Options), + ConnPid = create_gun_connection(Host, Port, Path, Options), Timeout = proplists:get_value(timeout, ensure_timeout(Options), ?DEFAULT_HTTP_TIMEOUT), - try - StreamRef = do_gun_request(ConnPid, Method, Path, HeadersBin, Body), - case gun:await(ConnPid, StreamRef, Timeout) of - {response, fin, Status, RespHeaders} -> - Response = {ok, {{http_version, Status, status_text(Status)}, RespHeaders, <<>>}}, - {Response, NewState}; - {response, nofin, Status, RespHeaders} -> - {ok, RespBody} = gun:await_body(ConnPid, StreamRef, Timeout), - Response = - {ok, {{http_version, Status, status_text(Status)}, RespHeaders, RespBody}}, - {Response, NewState}; - {error, Reason} -> - {{error, Reason}, NewState} - end - catch - _:Error -> - % Connection failed, remove from pool and return error - HostKey = get_connection_key(Host, Port, Path, Options), - NewConnections = maps:remove(HostKey, NewState#state.gun_connections), - gun:close(ConnPid), - {{error, Error}, NewState#state{gun_connections = NewConnections}} - end. + Response = try + StreamRef = do_gun_request(ConnPid, Method, Path, HeadersBin, Body), + case gun:await(ConnPid, StreamRef, Timeout) of + {response, fin, Status, RespHeaders} -> + Response = {ok, {{http_version, Status, status_text(Status)}, RespHeaders, <<>>}}, + {Response, NewState}; + {response, nofin, Status, RespHeaders} -> + {ok, RespBody} = gun:await_body(ConnPid, StreamRef, Timeout), + Response = + {ok, {{http_version, Status, status_text(Status)}, RespHeaders, RespBody}}, + {Response, NewState}; + {error, Reason} -> + {{error, Reason}, NewState} + end + catch + _:Error -> + {{error, Error}, NewState} + end, + gun:close(ConnPid), + Reponse. do_gun_request(ConnPid, get, Path, Headers, _Body) -> gun:get(ConnPid, Path, Headers); @@ -840,35 +811,7 @@ do_gun_request(ConnPid, patch, Path, Headers, Body) -> do_gun_request(ConnPid, options, Path, Headers, _Body) -> gun:options(ConnPid, Path, Headers, #{}). -get_or_create_gun_connection(State, Host, Port, Path, Options) -> - HostKey = get_connection_key(Host, Port, Path, Options), - case maps:get(HostKey, State#state.gun_connections, undefined) of - undefined -> - create_gun_connection(State, Host, Port, HostKey, Options); - ConnPid -> - case is_process_alive(ConnPid) andalso gun:info(ConnPid) =/= undefined of - true -> - {ConnPid, State}; - false -> - % Connection is dead, create new one - gun:close(ConnPid), - create_gun_connection(State, Host, Port, HostKey, Options) - end - end. - -get_connection_key(Host, Port, Path, Options) -> - case proplists:get_value(connection_key_type, Options, host) of - host -> - Host ++ ":" ++ integer_to_list(Port); - path -> - Host ++ ":" ++ integer_to_list(Port) ++ Path; - {path_custom, Extra} -> - Host ++ ":" ++ integer_to_list(Port) ++ Path ++ ":" ++ Extra; - _ -> - Host ++ ":" ++ integer_to_list(Port) - end. - -create_gun_connection(State, Host, Port, HostKey, Options) -> +create_gun_connection(Host, Port, HostKey, Options) -> % Map HTTP version to Gun protocols, always include http as fallback HttpVersion = proplists:get_value(version, Options, "HTTP/1.1"), Protocols = @@ -894,9 +837,7 @@ create_gun_connection(State, Host, Port, HostKey, Options) -> {ok, ConnPid} -> case gun:await_up(ConnPid, ConnectTimeout) of {ok, _Protocol} -> - NewConnections = maps:put(HostKey, ConnPid, State#state.gun_connections), - NewState = State#state{gun_connections = NewConnections}, - {ConnPid, NewState}; + ConnPid; {error, Reason} -> gun:close(ConnPid), error({gun_connection_failed, Reason}) @@ -905,19 +846,6 @@ create_gun_connection(State, Host, Port, HostKey, Options) -> error({gun_open_failed, Reason}) end. -close_connection(Service, Path, Options, State) -> - URI = endpoint(State, undefined, Service, Path), - {Host, Port, Path} = parse_uri(URI), - HostKey = get_connection_key(Host, Port, Path, Options), - case maps:get(HostKey, State#state.gun_connections, undefined) of - undefined -> - State; - ConnPid -> - gun:close(ConnPid), - NewConnections = maps:remove(HostKey, State#state.gun_connections), - State#state{gun_connections = NewConnections} - end. - parse_uri(URI) -> case string:split(URI, "://", leading) of [Scheme, Rest] -> @@ -1038,7 +966,7 @@ sign_headers_with_context(CredContext, Method, URI, Headers, Body) -> }). %% Direct Gun request (extracted from existing gun_request function) --spec perform_direct_gun_request( +-spec direct_gun_request( GunPid :: gun:conn_ref(), Method :: method(), Path :: path(), @@ -1046,7 +974,7 @@ sign_headers_with_context(CredContext, Method, URI, Headers, Body) -> Body :: body(), Options :: list() ) -> result(). -perform_direct_gun_request(GunPid, Method, Path, Headers, Body, Options) -> +direct_gun_request(GunPid, Method, Path, Headers, Body, Options) -> HeadersBin = lists:map( fun({Key, Value}) -> {list_to_binary(Key), list_to_binary(Value)} From 52b507811eb00a622e04e752a1af9f8439311165 Mon Sep 17 00:00:00 2001 From: Simon Unge Date: Thu, 31 Jul 2025 17:16:34 +0000 Subject: [PATCH 04/11] Cleanup --- deps/rabbitmq_aws/src/rabbitmq_aws.erl | 203 ++++++------------ deps/rabbitmq_aws/src/rabbitmq_aws_sign.erl | 13 +- deps/rabbitmq_aws/test/rabbitmq_aws_tests.erl | 13 +- 3 files changed, 73 insertions(+), 156 deletions(-) diff --git a/deps/rabbitmq_aws/src/rabbitmq_aws.erl b/deps/rabbitmq_aws/src/rabbitmq_aws.erl index b941baae907c..53576bdcbdd4 100644 --- a/deps/rabbitmq_aws/src/rabbitmq_aws.erl +++ b/deps/rabbitmq_aws/src/rabbitmq_aws.erl @@ -22,13 +22,8 @@ ensure_imdsv2_token_valid/0, api_get_request/2, status_text/1, - %% New concurrent API open_connection/1, open_connection/2, - close_direct_connection/1, - direct_get/3, direct_get/4, - direct_post/4, direct_post/5, - direct_put/4, direct_put/5, - direct_request/6 + close_connection/1 ]). %% gen-server exports @@ -65,18 +60,18 @@ %%==================================================================== -spec get( - Service :: string(), + ServiceOrHandle :: string() | connection_handle(), Path :: path() ) -> result(). %% @doc Perform a HTTP GET request to the AWS API for the specified service. The %% response will automatically be decoded if it is either in JSON, or XML %% format. %% @end -get(Service, Path) -> - get(Service, Path, []). +get(ServiceOrHandle, Path) -> + get(ServiceOrHandle, Path, []). -spec get( - Service :: string(), + ServiceOrHandle :: string() | connection_handle(), Path :: path(), Headers :: headers() ) -> result(). @@ -84,14 +79,14 @@ get(Service, Path) -> %% response will automatically be decoded if it is either in JSON or XML %% format. %% @end -get(Service, Path, Headers) -> - request(Service, get, Path, "", Headers, []). +get(ServiceOrHandle, Path, Headers) -> + get(ServiceOrHandle, Path, Headers, []). get(Service, Path, Headers, Options) -> - request(Service, get, Path, "", Headers, Options). + request(Service, get, Path, <<>>, Headers, Options). -spec post( - Service :: string(), + ServiceOrHandle :: string() | connection_handle(), Path :: path(), Body :: body(), Headers :: headers() @@ -100,11 +95,14 @@ get(Service, Path, Headers, Options) -> %% response will automatically be decoded if it is either in JSON or XML %% format. %% @end -post(Service, Path, Body, Headers) -> - request(Service, post, Path, Body, Headers). +post(ServiceOrHandle, Path, Body, Headers) -> + post(ServiceOrHandle, Path, Body, Headers, []). + +post(Service, Path, Body, Headers, Options) -> + request(Service, post, Path, Body, Headers, Options). -spec put( - Service :: string(), + ServiceOrHandle :: string() | connection_handle(), Path :: path(), Body :: body(), Headers :: headers() @@ -113,8 +111,8 @@ post(Service, Path, Body, Headers) -> %% response will automatically be decoded if it is either in JSON or XML %% format. %% @end -put(Service, Path, Body, Headers) -> - put(Service, Path, Body, Headers, []). +put(ServiceOrHandle, Path, Body, Headers) -> + put(ServiceOrHandle, Path, Body, Headers, []). put(Service, Path, Body, Headers, Options) -> request(Service, put, Path, Body, Headers, Options). @@ -134,40 +132,16 @@ refresh_credentials() -> open_connection(Service) -> open_connection(Service, []). --spec open_connection(Service :: string(), Options :: list()) -> {ok, connection_handle()} | {error, term()}. +-spec open_connection(Service :: string(), Options :: list()) -> + {ok, connection_handle()} | {error, term()}. open_connection(Service, Options) -> gen_server:call(?MODULE, {open_direct_connection, Service, Options}). %% Close a direct connection --spec close_direct_connection(Handle :: connection_handle()) -> ok. -close_direct_connection({GunPid, _CredContext}) -> +-spec close_connection(Handle :: connection_handle()) -> ok. +close_connection({GunPid, _CredContext}) -> gun:close(GunPid). -%% Direct API calls that bypass gen_server --spec direct_get(Handle :: connection_handle(), Path :: path(), Headers :: headers()) -> result(). -direct_get(Handle, Path, Headers) -> - direct_request(Handle, get, Path, <<>>, Headers, []). - --spec direct_get(Handle :: connection_handle(), Path :: path(), Headers :: headers(), Options :: list()) -> result(). -direct_get(Handle, Path, Headers, Options) -> - direct_request(Handle, get, Path, <<>>, Headers, Options). - --spec direct_post(Handle :: connection_handle(), Path :: path(), Body :: body(), Headers :: headers()) -> result(). -direct_post(Handle, Path, Body, Headers) -> - direct_request(Handle, post, Path, Body, Headers, []). - --spec direct_post(Handle :: connection_handle(), Path :: path(), Body :: body(), Headers :: headers(), Options :: list()) -> result(). -direct_post(Handle, Path, Body, Headers, Options) -> - direct_request(Handle, post, Path, Body, Headers, Options). - --spec direct_put(Handle :: connection_handle(), Path :: path(), Body :: body(), Headers :: headers()) -> result(). -direct_put(Handle, Path, Body, Headers) -> - direct_request(Handle, put, Path, Body, Headers, []). - --spec direct_put(Handle :: connection_handle(), Path :: path(), Body :: body(), Headers :: headers(), Options :: list()) -> result(). -direct_put(Handle, Path, Body, Headers, Options) -> - direct_request(Handle, put, Path, Body, Headers, Options). - -spec direct_request( Handle :: connection_handle(), Method :: method(), @@ -207,7 +181,7 @@ refresh_credentials(State) -> %% format. %% @end request(Service, Method, Path, Body, Headers) -> - gen_server:call(rabbitmq_aws, {request, Service, Method, Headers, Path, Body, [], undefined}). + request(Service, Method, Path, Body, Headers, []). -spec request( Service :: string(), @@ -222,12 +196,10 @@ request(Service, Method, Path, Body, Headers) -> %% format. %% @end request(Service, Method, Path, Body, Headers, HTTPOptions) -> - gen_server:call( - rabbitmq_aws, {request, Service, Method, Headers, Path, Body, HTTPOptions, undefined} - ). + request(Service, Method, Path, Body, Headers, HTTPOptions, undefined). -spec request( - Service :: string(), + ServiceOrHandle :: string() | connection_handle(), Method :: method(), Path :: path(), Body :: body(), @@ -240,6 +212,8 @@ request(Service, Method, Path, Body, Headers, HTTPOptions) -> %% of services such as DynamoDB. The response will automatically be decoded %% if it is either in JSON or XML format. %% @end +request({GunPid, _CredContext} = Handle, Method, Path, Body, Headers, HTTPOptions, _) when is_pid(GunPid) -> + direct_request(Handle, Method, Path, Body, Headers, HTTPOptions); request(Service, Method, Path, Body, Headers, HTTPOptions, Endpoint) -> gen_server:call( rabbitmq_aws, {request, Service, Method, Headers, Path, Body, HTTPOptions, Endpoint} @@ -289,7 +263,7 @@ init([]) -> {ok, _} = application:ensure_all_started(gun), {ok, #state{}}. -terminate(_, State) -> +terminate(_, _State) -> ok. code_change(_, _, State) -> @@ -614,11 +588,11 @@ perform_request_with_creds(State, Service, Method, Headers, Path, Body, Options, %% expired, perform the request and return the response. %% @end perform_request_with_creds(State, Method, URI, Headers, "", Options0) -> - {Response, NewState} = gun_request(State, Method, URI, Headers, <<>>, Options0), - {format_response(Response), NewState}; + Response = gun_request(Method, URI, Headers, <<>>, Options0), + {Response, State}; perform_request_with_creds(State, Method, URI, Headers, Body, Options0) -> - {Response, NewState} = gun_request(State, Method, URI, Headers, Body, Options0), - {format_response(Response), NewState}. + Response = gun_request(Method, URI, Headers, Body, Options0), + {Response, State}. -spec perform_request_creds_error(State :: state()) -> {result_error(), NewState :: state()}. @@ -765,36 +739,12 @@ api_get_request_with_retries(Service, Path, Retries, WaitTimeBetweenRetries) -> end. %% Gun HTTP client functions -gun_request(State, Method, URI, Headers, Body, Options) -> - HeadersBin = lists:map( - fun({Key, Value}) -> - {list_to_binary(Key), list_to_binary(Value)} - end, - Headers - ), +gun_request(Method, URI, Headers, Body, Options) -> {Host, Port, Path} = parse_uri(URI), - ConnPid = create_gun_connection(Host, Port, Path, Options), - Timeout = proplists:get_value(timeout, ensure_timeout(Options), ?DEFAULT_HTTP_TIMEOUT), - Response = try - StreamRef = do_gun_request(ConnPid, Method, Path, HeadersBin, Body), - case gun:await(ConnPid, StreamRef, Timeout) of - {response, fin, Status, RespHeaders} -> - Response = {ok, {{http_version, Status, status_text(Status)}, RespHeaders, <<>>}}, - {Response, NewState}; - {response, nofin, Status, RespHeaders} -> - {ok, RespBody} = gun:await_body(ConnPid, StreamRef, Timeout), - Response = - {ok, {{http_version, Status, status_text(Status)}, RespHeaders, RespBody}}, - {Response, NewState}; - {error, Reason} -> - {{error, Reason}, NewState} - end - catch - _:Error -> - {{error, Error}, NewState} - end, - gun:close(ConnPid), - Reponse. + GunPid = create_gun_connection(Host, Port, Options), + Reply = direct_gun_request(GunPid, Method, Path, Headers, Body, ensure_timeout(Options)), + gun:close(GunPid), + Reply. do_gun_request(ConnPid, get, Path, Headers, _Body) -> gun:get(ConnPid, Path, Headers); @@ -811,7 +761,7 @@ do_gun_request(ConnPid, patch, Path, Headers, Body) -> do_gun_request(ConnPid, options, Path, Headers, _Body) -> gun:options(ConnPid, Path, Headers, #{}). -create_gun_connection(Host, Port, HostKey, Options) -> +create_gun_connection(Host, Port, Options) -> % Map HTTP version to Gun protocols, always include http as fallback HttpVersion = proplists:get_value(version, Options, "HTTP/1.1"), Protocols = @@ -895,47 +845,15 @@ create_direct_connection(State, Service, Options) -> Region = State#state.region, Host = endpoint_host(Region, Service), Port = 443, - - HttpVersion = proplists:get_value(version, Options, "HTTP/1.1"), - Protocols = - case HttpVersion of - "HTTP/2" -> [http2, http]; - "HTTP/2.0" -> [http2, http]; - "HTTP/1.1" -> [http]; - "HTTP/1.0" -> [http]; - % Default: try HTTP/2, fallback to HTTP/1.1 - _ -> [http2, http] - end, - ConnectTimeout = proplists:get_value(connect_timeout, Options, 5000), - GunOpts = #{ - transport => tls, - %% if - %% Port == 443 -> tls; - %% true -> tcp - %% end, - protocols => Protocols, - connect_timeout => ConnectTimeout + GunPid = create_gun_connection(Host, Port, Options), + CredContext = #{ + access_key => State#state.access_key, + secret_access_key => State#state.secret_access_key, + security_token => State#state.security_token, + region => Region, + service => Service }, - - case gun:open(Host, Port, GunOpts) of - {ok, GunPid} -> - case gun:await_up(GunPid, ConnectTimeout) of - {ok, _Protocol} -> - CredContext = #{ - access_key => State#state.access_key, - secret_access_key => State#state.secret_access_key, - security_token => State#state.security_token, - region => Region, - service => Service - }, - {ok, {GunPid, CredContext}}; - {error, Reason} -> - gun:close(GunPid), - {error, {connection_failed, Reason}} - end; - {error, Reason} -> - {error, {gun_open_failed, Reason}} - end. + {ok, {GunPid, CredContext}}. %% Sign headers using credential context (no gen_server state needed) -spec sign_headers_with_context( @@ -982,21 +900,22 @@ direct_gun_request(GunPid, Method, Path, Headers, Body, Options) -> Headers ), Timeout = proplists:get_value(timeout, Options, ?DEFAULT_HTTP_TIMEOUT), - try - StreamRef = do_gun_request(GunPid, Method, Path, HeadersBin, Body), - case gun:await(GunPid, StreamRef, Timeout) of - {response, fin, Status, RespHeaders} -> - format_response({ok, {{http_version, Status, status_text(Status)}, RespHeaders, <<>>}}); - {response, nofin, Status, RespHeaders} -> - {ok, RespBody} = gun:await_body(GunPid, StreamRef, Timeout), - format_response({ok, {{http_version, Status, status_text(Status)}, RespHeaders, RespBody}}); - {error, Reason} -> - {error, Reason} - end - catch - _:Error -> - {error, Error} - end. + Response = try + StreamRef = do_gun_request(GunPid, Method, Path, HeadersBin, Body), + case gun:await(GunPid, StreamRef, Timeout) of + {response, fin, Status, RespHeaders} -> + {ok, {{http_version, Status, status_text(Status)}, RespHeaders, <<>>}}; + {response, nofin, Status, RespHeaders} -> + {ok, RespBody} = gun:await_body(GunPid, StreamRef, Timeout), + {ok, {{http_version, Status, status_text(Status)}, RespHeaders, RespBody}}; + {error, Reason} -> + {error, Reason} + end + catch + _:Error -> + {error, Error} + end, + format_response(Response). %% Internal credential validation (extracted from existing logic) -spec ensure_credentials_valid_internal(State :: state()) -> {ok, state()} | {error, term()}. diff --git a/deps/rabbitmq_aws/src/rabbitmq_aws_sign.erl b/deps/rabbitmq_aws/src/rabbitmq_aws_sign.erl index da76d9cc6211..bed1e5f85967 100644 --- a/deps/rabbitmq_aws/src/rabbitmq_aws_sign.erl +++ b/deps/rabbitmq_aws/src/rabbitmq_aws_sign.erl @@ -29,12 +29,13 @@ headers(Request) -> URI = rabbitmq_aws_urilib:parse(Request#request.uri), {_, Host, _} = URI#uri.authority, - BodyLength = case Request#request.body of - Body when is_binary(Body) -> - size(Body); - Body when is_list(Body) -> - length(Body) - end, + BodyLength = + case Request#request.body of + Body when is_binary(Body) -> + size(Body); + Body when is_list(Body) -> + length(Body) + end, Headers = append_headers( RequestTimestamp, diff --git a/deps/rabbitmq_aws/test/rabbitmq_aws_tests.erl b/deps/rabbitmq_aws/test/rabbitmq_aws_tests.erl index 44d5cf917edf..66c23e0f65cc 100644 --- a/deps/rabbitmq_aws/test/rabbitmq_aws_tests.erl +++ b/deps/rabbitmq_aws/test/rabbitmq_aws_tests.erl @@ -27,7 +27,7 @@ init_test_() -> os:unsetenv("AWS_SECRET_ACCESS_KEY"), Expectation = {state, "Sésame", "ouvre-toi", undefined, undefined, "us-west-3", undefined, - undefined, #{}}, + undefined}, ?assertEqual(Expectation, State) end}, {"error", fun() -> @@ -39,7 +39,7 @@ init_test_() -> ok = gen_server:stop(Pid), Expectation = {state, undefined, undefined, undefined, undefined, "us-west-3", undefined, - test_result, #{}}, + test_result}, ?assertEqual(Expectation, State), meck:validate(rabbitmq_aws_config) end} @@ -50,8 +50,7 @@ terminate_test() -> ok, rabbitmq_aws:terminate( foo, - {state, undefined, undefined, undefined, undefined, "us-west-3", undefined, test_result, - #{}} + {state, undefined, undefined, undefined, undefined, "us-west-3", undefined, test_result} ) ). @@ -224,9 +223,7 @@ gen_server_call_test_() -> {reply, {ok, {[{<<"content-type">>, <<"application/json">>}], [{"pass", true}]}}, - State#state{ - gun_connections = #{"ec2.us-east-1.amazonaws.com:443" => pid} - }}, + State}, Result = rabbitmq_aws:handle_call( {request, Service, Method, Headers, Path, Body, Options, Host}, eunit, State ), @@ -456,7 +453,7 @@ perform_request_test_() -> Expectation = { {ok, {[{<<"content-type">>, <<"application/json">>}], [{"pass", true}]}}, - State#state{gun_connections = #{"ec2.us-east-1.amazonaws.com:443" => pid}} + State }, Result = rabbitmq_aws:perform_request( State, Service, Method, Headers, Path, Body, Options, Host From 42055f00d8c05a07c2ad66107e321577e9bb1e68 Mon Sep 17 00:00:00 2001 From: Simon Unge Date: Thu, 31 Jul 2025 17:58:01 +0000 Subject: [PATCH 05/11] And Formattet --- deps/rabbitmq_aws/src/rabbitmq_aws.erl | 35 ++++++++++++++------------ 1 file changed, 19 insertions(+), 16 deletions(-) diff --git a/deps/rabbitmq_aws/src/rabbitmq_aws.erl b/deps/rabbitmq_aws/src/rabbitmq_aws.erl index 53576bdcbdd4..c83cd5040f82 100644 --- a/deps/rabbitmq_aws/src/rabbitmq_aws.erl +++ b/deps/rabbitmq_aws/src/rabbitmq_aws.erl @@ -212,7 +212,9 @@ request(Service, Method, Path, Body, Headers, HTTPOptions) -> %% of services such as DynamoDB. The response will automatically be decoded %% if it is either in JSON or XML format. %% @end -request({GunPid, _CredContext} = Handle, Method, Path, Body, Headers, HTTPOptions, _) when is_pid(GunPid) -> +request({GunPid, _CredContext} = Handle, Method, Path, Body, Headers, HTTPOptions, _) when + is_pid(GunPid) +-> direct_request(Handle, Method, Path, Body, Headers, HTTPOptions); request(Service, Method, Path, Body, Headers, HTTPOptions, Endpoint) -> gen_server:call( @@ -900,21 +902,22 @@ direct_gun_request(GunPid, Method, Path, Headers, Body, Options) -> Headers ), Timeout = proplists:get_value(timeout, Options, ?DEFAULT_HTTP_TIMEOUT), - Response = try - StreamRef = do_gun_request(GunPid, Method, Path, HeadersBin, Body), - case gun:await(GunPid, StreamRef, Timeout) of - {response, fin, Status, RespHeaders} -> - {ok, {{http_version, Status, status_text(Status)}, RespHeaders, <<>>}}; - {response, nofin, Status, RespHeaders} -> - {ok, RespBody} = gun:await_body(GunPid, StreamRef, Timeout), - {ok, {{http_version, Status, status_text(Status)}, RespHeaders, RespBody}}; - {error, Reason} -> - {error, Reason} - end - catch - _:Error -> - {error, Error} - end, + Response = + try + StreamRef = do_gun_request(GunPid, Method, Path, HeadersBin, Body), + case gun:await(GunPid, StreamRef, Timeout) of + {response, fin, Status, RespHeaders} -> + {ok, {{http_version, Status, status_text(Status)}, RespHeaders, <<>>}}; + {response, nofin, Status, RespHeaders} -> + {ok, RespBody} = gun:await_body(GunPid, StreamRef, Timeout), + {ok, {{http_version, Status, status_text(Status)}, RespHeaders, RespBody}}; + {error, Reason} -> + {error, Reason} + end + catch + _:Error -> + {error, Error} + end, format_response(Response). %% Internal credential validation (extracted from existing logic) From feffcf5ec016946c4424f05e75649d6d02c0ece0 Mon Sep 17 00:00:00 2001 From: Simon Unge Date: Thu, 31 Jul 2025 22:50:43 +0000 Subject: [PATCH 06/11] Improve timeouts --- deps/rabbitmq_aws/src/rabbitmq_aws.erl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/deps/rabbitmq_aws/src/rabbitmq_aws.erl b/deps/rabbitmq_aws/src/rabbitmq_aws.erl index c83cd5040f82..2f26d5248925 100644 --- a/deps/rabbitmq_aws/src/rabbitmq_aws.erl +++ b/deps/rabbitmq_aws/src/rabbitmq_aws.erl @@ -744,7 +744,7 @@ api_get_request_with_retries(Service, Path, Retries, WaitTimeBetweenRetries) -> gun_request(Method, URI, Headers, Body, Options) -> {Host, Port, Path} = parse_uri(URI), GunPid = create_gun_connection(Host, Port, Options), - Reply = direct_gun_request(GunPid, Method, Path, Headers, Body, ensure_timeout(Options)), + Reply = direct_gun_request(GunPid, Method, Path, Headers, Body, Options), gun:close(GunPid), Reply. From 279f7dda1713d01817f33617518cc1dfd355e82b Mon Sep 17 00:00:00 2001 From: Simon Unge Date: Mon, 4 Aug 2025 09:51:26 +0000 Subject: [PATCH 07/11] Use virtual host style --- deps/rabbitmq_aws/src/rabbitmq_aws.erl | 27 +++++++++----------------- 1 file changed, 9 insertions(+), 18 deletions(-) diff --git a/deps/rabbitmq_aws/src/rabbitmq_aws.erl b/deps/rabbitmq_aws/src/rabbitmq_aws.erl index 2f26d5248925..0b4f7c7e9e48 100644 --- a/deps/rabbitmq_aws/src/rabbitmq_aws.erl +++ b/deps/rabbitmq_aws/src/rabbitmq_aws.erl @@ -154,7 +154,7 @@ direct_request({GunPid, CredContext}, Method, Path, Body, Headers, Options) -> #{service := Service, region := Region} = CredContext, % Build URI for signing Host = endpoint_host(Region, Service), - URI = "https://" ++ Host ++ Path, + URI = create_uri(Host, Path), % Sign headers directly (no gen_server call) SignedHeaders = sign_headers_with_context(CredContext, Method, URI, Headers, Body), % Make Gun request directly @@ -604,22 +604,6 @@ perform_request_with_creds(State, Method, URI, Headers, Body, Options0) -> perform_request_creds_error(State) -> {{error, {credentials, State#state.error}}, State}. -%% @doc Ensure that the timeout option is set and greater than 0 and less -%% than about 1/2 of the default gen_server:call timeout. This gives -%% enough time for a long connect and request phase to succeed. -%% @end --spec ensure_timeout(Options :: http_options()) -> http_options(). -ensure_timeout(Options) -> - case proplists:get_value(timeout, Options) of - undefined -> - Options ++ [{timeout, ?DEFAULT_HTTP_TIMEOUT}]; - Value when is_integer(Value) andalso Value >= 0 andalso Value =< ?DEFAULT_HTTP_TIMEOUT -> - Options; - _ -> - Options1 = proplists:delete(timeout, Options), - Options1 ++ [{timeout, ?DEFAULT_HTTP_TIMEOUT}] - end. - -spec sign_headers( State :: state(), Service :: string(), @@ -775,7 +759,7 @@ create_gun_connection(Host, Port, Options) -> % Default: try HTTP/2, fallback to HTTP/1.1 _ -> [http2, http] end, - ConnectTimeout = proplists:get_value(connect_timeout, Options, 5000), + ConnectTimeout = proplists:get_value(connect_timeout, Options, infinity), Opts = #{ transport => if @@ -798,6 +782,11 @@ create_gun_connection(Host, Port, Options) -> error({gun_open_failed, Reason}) end. +create_uri(Host, Path) when is_list(Path) -> + "https://" ++ Host ++ Path; +create_uri(Host, {Bucket, Key}) -> + "https://" ++ Bucket ++ "." ++ Host ++ "/" ++ Key. + parse_uri(URI) -> case string:split(URI, "://", leading) of [Scheme, Rest] -> @@ -894,6 +883,8 @@ sign_headers_with_context(CredContext, Method, URI, Headers, Body) -> Body :: body(), Options :: list() ) -> result(). +direct_gun_request(GunPid, Method, {_, Path}, Headers, Body, Options) -> + direct_gun_request(GunPid, Method, [$/|Path], Headers, Body, Options); direct_gun_request(GunPid, Method, Path, Headers, Body, Options) -> HeadersBin = lists:map( fun({Key, Value}) -> From 34fe08d82bd1d0ae240875ef604810a4afa8c0a4 Mon Sep 17 00:00:00 2001 From: Michael Davis Date: Tue, 12 Aug 2025 14:19:56 -0400 Subject: [PATCH 08/11] minor: Fix type of `gun` conn pid There is `gun:stream_ref()` but not a custom type for the conn pid, so dialyzer was unhappy. --- deps/rabbitmq_aws/src/rabbitmq_aws.erl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/deps/rabbitmq_aws/src/rabbitmq_aws.erl b/deps/rabbitmq_aws/src/rabbitmq_aws.erl index 0b4f7c7e9e48..e0a779bcedc9 100644 --- a/deps/rabbitmq_aws/src/rabbitmq_aws.erl +++ b/deps/rabbitmq_aws/src/rabbitmq_aws.erl @@ -46,7 +46,7 @@ -include_lib("kernel/include/logger.hrl"). %% Types for new concurrent API --type connection_handle() :: {gun:conn_ref(), credential_context()}. +-type connection_handle() :: {pid(), credential_context()}. -type credential_context() :: #{ access_key => access_key(), secret_access_key => secret_access_key(), @@ -876,7 +876,7 @@ sign_headers_with_context(CredContext, Method, URI, Headers, Body) -> %% Direct Gun request (extracted from existing gun_request function) -spec direct_gun_request( - GunPid :: gun:conn_ref(), + GunPid :: pid(), Method :: method(), Path :: path(), Headers :: headers(), From cae38060a78687c8b036081454f467d2ec66c10f Mon Sep 17 00:00:00 2001 From: Michael Davis Date: Tue, 12 Aug 2025 14:23:02 -0400 Subject: [PATCH 09/11] rabbitmq_aws: Accept `iodata()` in request bodies Binaries are unnecessarily restrictive since gun allows sending iodata bodies. This would be useful for larger requests where we want to combine multiple lists or binaries without paying the cost of concatenating them. Note that `crypto:hash/2` accepts an `iodata()` arg. Also this commit avoids double-SHA256-hashing the request body in `rabbitmq_aws_sign`. The `request_hash/5` function performed a second `sha256/1` of the body to get the payload hash but this value already exists at the top of `headers/1`. --- deps/rabbitmq_aws/include/rabbitmq_aws.hrl | 2 +- deps/rabbitmq_aws/src/rabbitmq_aws_sign.erl | 19 ++++++------------- 2 files changed, 7 insertions(+), 14 deletions(-) diff --git a/deps/rabbitmq_aws/include/rabbitmq_aws.hrl b/deps/rabbitmq_aws/include/rabbitmq_aws.hrl index 6a0cacd81131..88880b9ff992 100644 --- a/deps/rabbitmq_aws/include/rabbitmq_aws.hrl +++ b/deps/rabbitmq_aws/include/rabbitmq_aws.hrl @@ -103,7 +103,7 @@ -type value() :: string(). -type header() :: {Field :: field(), Value :: value()}. -type headers() :: [header()]. --type body() :: string() | binary(). +-type body() :: iodata(). -type ssl_options() :: [ssl:tls_client_option()]. diff --git a/deps/rabbitmq_aws/src/rabbitmq_aws_sign.erl b/deps/rabbitmq_aws/src/rabbitmq_aws_sign.erl index bed1e5f85967..149fc2ed303c 100644 --- a/deps/rabbitmq_aws/src/rabbitmq_aws_sign.erl +++ b/deps/rabbitmq_aws/src/rabbitmq_aws_sign.erl @@ -28,14 +28,7 @@ headers(Request) -> PayloadHash = sha256(Request#request.body), URI = rabbitmq_aws_urilib:parse(Request#request.uri), {_, Host, _} = URI#uri.authority, - - BodyLength = - case Request#request.body of - Body when is_binary(Body) -> - size(Body); - Body when is_list(Body) -> - length(Body) - end, + BodyLength = iolist_size(Request#request.body), Headers = append_headers( RequestTimestamp, @@ -50,7 +43,7 @@ headers(Request) -> URI#uri.path, URI#uri.query, Headers, - Request#request.body + PayloadHash ), AuthValue = authorization( Request#request.access_key, @@ -211,11 +204,11 @@ query_string(QueryArgs) -> rabbitmq_aws_urilib:build_query_string(lists:keysort( Path :: path(), QArgs :: query_args(), Headers :: headers(), - Payload :: string() + PayloadHash :: string() ) -> string(). %% @doc Create the request hash value %% @end -request_hash(Method, Path, QArgs, Headers, Payload) -> +request_hash(Method, Path, QArgs, Headers, PayloadHash) -> RawPath = case string:slice(Path, 0, 1) of "/" -> Path; @@ -229,7 +222,7 @@ request_hash(Method, Path, QArgs, Headers, Payload) -> query_string(QArgs), canonical_headers(Headers), signed_headers(Headers), - sha256(Payload) + PayloadHash ], "\n" ), @@ -245,7 +238,7 @@ request_hash(Method, Path, QArgs, Headers, Payload) -> scope(AMZDate, Region, Service) -> string:join([AMZDate, Region, Service, "aws4_request"], "/"). --spec sha256(Value :: string()) -> string(). +-spec sha256(Value :: iodata()) -> string(). %% @doc Return the SHA-256 hash for the specified value. %% @end sha256(Value) -> From a7e30aad2210103ac1e27b16d6e974d7fa58162e Mon Sep 17 00:00:00 2001 From: Michael Davis Date: Tue, 12 Aug 2025 15:47:20 -0400 Subject: [PATCH 10/11] rabbitmq_aws: Allow passing in payload sha256 signatures In some cases we can compute this incrementally with the streaming hash utilities from crypto: > Hash0 = crypto:hash_init(sha256), > Hash1 = crypto:hash_update(Hash0, Data0), %% ... > HashN = crypto:hash_update(HashN1, DataN), > Hash = crypto:hash_final(HashN). Especially for large bodies this lets us skip a lot of double work. Currently this is only added to the direct_request API, with the idea that the other method that blocks the server is deprecated. --- deps/rabbitmq_aws/src/rabbitmq_aws.erl | 35 ++++++++++++--------- deps/rabbitmq_aws/src/rabbitmq_aws_sign.erl | 8 +++-- 2 files changed, 27 insertions(+), 16 deletions(-) diff --git a/deps/rabbitmq_aws/src/rabbitmq_aws.erl b/deps/rabbitmq_aws/src/rabbitmq_aws.erl index e0a779bcedc9..4ca12856de6b 100644 --- a/deps/rabbitmq_aws/src/rabbitmq_aws.erl +++ b/deps/rabbitmq_aws/src/rabbitmq_aws.erl @@ -155,8 +155,11 @@ direct_request({GunPid, CredContext}, Method, Path, Body, Headers, Options) -> % Build URI for signing Host = endpoint_host(Region, Service), URI = create_uri(Host, Path), + BodyHash = proplists:get_value(payload_hash, Options), % Sign headers directly (no gen_server call) - SignedHeaders = sign_headers_with_context(CredContext, Method, URI, Headers, Body), + SignedHeaders = sign_headers_with_context( + CredContext, Method, URI, Headers, Body, BodyHash + ), % Make Gun request directly direct_gun_request(GunPid, Method, Path, SignedHeaders, Body, Options). @@ -852,9 +855,10 @@ create_direct_connection(State, Service, Options) -> Method :: method(), URI :: string(), Headers :: headers(), - Body :: body() + Body :: body(), + BodyHash :: iodata() ) -> headers(). -sign_headers_with_context(CredContext, Method, URI, Headers, Body) -> +sign_headers_with_context(CredContext, Method, URI, Headers, Body, BodyHash) -> #{ access_key := AccessKey, secret_access_key := SecretKey, @@ -862,17 +866,20 @@ sign_headers_with_context(CredContext, Method, URI, Headers, Body) -> region := Region, service := Service } = CredContext, - rabbitmq_aws_sign:headers(#request{ - access_key = AccessKey, - secret_access_key = SecretKey, - security_token = SecurityToken, - region = Region, - service = Service, - method = Method, - uri = URI, - headers = Headers, - body = Body - }). + rabbitmq_aws_sign:headers( + #request{ + access_key = AccessKey, + secret_access_key = SecretKey, + security_token = SecurityToken, + region = Region, + service = Service, + method = Method, + uri = URI, + headers = Headers, + body = Body + }, + BodyHash + ). %% Direct Gun request (extracted from existing gun_request function) -spec direct_gun_request( diff --git a/deps/rabbitmq_aws/src/rabbitmq_aws_sign.erl b/deps/rabbitmq_aws/src/rabbitmq_aws_sign.erl index 149fc2ed303c..c5f9b0bddd9d 100644 --- a/deps/rabbitmq_aws/src/rabbitmq_aws_sign.erl +++ b/deps/rabbitmq_aws/src/rabbitmq_aws_sign.erl @@ -8,7 +8,7 @@ -module(rabbitmq_aws_sign). %% API --export([headers/1, request_hash/5]). +-export([headers/1, headers/2, request_hash/5]). %% Export all for unit tests -ifdef(TEST). @@ -24,8 +24,12 @@ %% @doc Create the signed request headers %% end headers(Request) -> + headers(Request, undefined). + +headers(Request, undefined) -> + headers(Request, sha256(Request#request.body)); +headers(Request, PayloadHash) -> RequestTimestamp = local_time(), - PayloadHash = sha256(Request#request.body), URI = rabbitmq_aws_urilib:parse(Request#request.uri), {_, Host, _} = URI#uri.authority, BodyLength = iolist_size(Request#request.body), From 517270025f0ea99c224170ce3bb12ae7f662bedc Mon Sep 17 00:00:00 2001 From: Michael Davis Date: Thu, 14 Aug 2025 12:40:44 -0400 Subject: [PATCH 11/11] erlfmt rabbitmq_aws --- deps/rabbitmq_aws/src/rabbitmq_aws.erl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/deps/rabbitmq_aws/src/rabbitmq_aws.erl b/deps/rabbitmq_aws/src/rabbitmq_aws.erl index 4ca12856de6b..94452827e008 100644 --- a/deps/rabbitmq_aws/src/rabbitmq_aws.erl +++ b/deps/rabbitmq_aws/src/rabbitmq_aws.erl @@ -158,7 +158,7 @@ direct_request({GunPid, CredContext}, Method, Path, Body, Headers, Options) -> BodyHash = proplists:get_value(payload_hash, Options), % Sign headers directly (no gen_server call) SignedHeaders = sign_headers_with_context( - CredContext, Method, URI, Headers, Body, BodyHash + CredContext, Method, URI, Headers, Body, BodyHash ), % Make Gun request directly direct_gun_request(GunPid, Method, Path, SignedHeaders, Body, Options). @@ -891,7 +891,7 @@ sign_headers_with_context(CredContext, Method, URI, Headers, Body, BodyHash) -> Options :: list() ) -> result(). direct_gun_request(GunPid, Method, {_, Path}, Headers, Body, Options) -> - direct_gun_request(GunPid, Method, [$/|Path], Headers, Body, Options); + direct_gun_request(GunPid, Method, [$/ | Path], Headers, Body, Options); direct_gun_request(GunPid, Method, Path, Headers, Body, Options) -> HeadersBin = lists:map( fun({Key, Value}) ->