From c491bb45ee27f8f634e0d0c54835f26d58004694 Mon Sep 17 00:00:00 2001 From: Rob Findley Date: Wed, 22 Oct 2025 16:39:59 +0000 Subject: [PATCH 1/2] auth: clone the client request body before roundtripping RoundTrippers may read and close the body, so be careful to clone before roundtripping during client oauth, as the request may be issued multiple times. Fixes #590 --- auth/client.go | 37 +++++++++++++++++++++++++++++++- auth/client_test.go | 51 +++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 87 insertions(+), 1 deletion(-) diff --git a/auth/client.go b/auth/client.go index 7c65e9de..f313be47 100644 --- a/auth/client.go +++ b/auth/client.go @@ -7,8 +7,10 @@ package auth import ( + "bytes" "context" "errors" + "io" "net/http" "sync" @@ -67,6 +69,31 @@ func (t *HTTPTransport) RoundTrip(req *http.Request) (*http.Response, error) { base := t.opts.Base t.mu.Unlock() + // req1 is our first request in the authorization flow. + // + // If we mutate its body, we must clone it first. + var ( + // If haveBody is set, the request has a nontrivial body, and we need avoid + // reading (or closing) it multiple times. In that case, bodyBytes is its + // content. + haveBody bool + bodyBytes []byte + ) + if req.Body != nil && req.Body != http.NoBody { + // if we're setting Body, we must mutate first. + req = req.Clone(req.Context()) + haveBody = true + var err error + bodyBytes, err = io.ReadAll(req.Body) + if err != nil { + return nil, err + } + // Now that we've read the request body, http.RoundTripper requires that we + // close it. + req.Body.Close() // ignore error + req.Body = io.NopCloser(bytes.NewReader(bodyBytes)) + } + resp, err := base.RoundTrip(req) if err != nil { return nil, err @@ -97,7 +124,15 @@ func (t *HTTPTransport) RoundTrip(req *http.Request) (*http.Response, error) { } t.opts.Base = &oauth2.Transport{Base: t.opts.Base, Source: ts} } - return t.opts.Base.RoundTrip(req.Clone(req.Context())) + + // If we don't have a body, the request is reusable, though it will be cloned + // by the base. However, if we've had to read the body, we must clone. + if haveBody { + req = req.Clone(req.Context()) + req.Body = io.NopCloser(bytes.NewReader(bodyBytes)) + } + + return t.opts.Base.RoundTrip(req) } func extractResourceMetadataURL(authHeaders []string) string { diff --git a/auth/client_test.go b/auth/client_test.go index 310fc56e..63a45b93 100644 --- a/auth/client_test.go +++ b/auth/client_test.go @@ -10,13 +10,24 @@ import ( "context" "errors" "fmt" + "io" "net/http" "net/http/httptest" + "strings" "testing" "golang.org/x/oauth2" ) +// A basicReader is an io.Reader to be used as a non-rereadable request body. +// +// net/http has special handling for strings.Reader that we want to avoid. +type basicReader struct { + r *strings.Reader +} + +func (r *basicReader) Read(p []byte) (n int, err error) { return r.r.Read(p) } + // TestHTTPTransport validates the OAuth HTTPTransport. func TestHTTPTransport(t *testing.T) { const testToken = "test-token-123" @@ -27,6 +38,20 @@ func TestHTTPTransport(t *testing.T) { // authServer simulates a resource that requires OAuth. authServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method == http.MethodPost { + // Ensure that the body was properly cloned, by reading it completely. + // If the body is not cloned, reading it the second time may yield no + // bytes. + body, err := io.ReadAll(r.Body) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + if len(body) == 0 { + http.Error(w, "empty body", http.StatusBadRequest) + return + } + } authHeader := r.Header.Get("Authorization") if authHeader == fmt.Sprintf("Bearer %s", testToken) { w.WriteHeader(http.StatusOK) @@ -82,6 +107,32 @@ func TestHTTPTransport(t *testing.T) { } }) + t.Run("request body is cloned", func(t *testing.T) { + handler := func(ctx context.Context, args OAuthHandlerArgs) (oauth2.TokenSource, error) { + if args.ResourceMetadataURL != "http://metadata.example.com" { + t.Errorf("handler got metadata URL %q, want %q", args.ResourceMetadataURL, "http://metadata.example.com") + } + return fakeTokenSource, nil + } + + transport, err := NewHTTPTransport(handler, nil) + if err != nil { + t.Fatalf("NewHTTPTransport() failed: %v", err) + } + client := &http.Client{Transport: transport} + + resp, err := client.Post(authServer.URL, "application/json", &basicReader{strings.NewReader("{}")}) + // resp, err := client.Post(authServer.URL, "application/json", strings.NewReader("{}")) + if err != nil { + t.Fatalf("client.Post() failed: %v", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + t.Errorf("got status %d, want %d", resp.StatusCode, http.StatusOK) + } + }) + t.Run("handler returns error", func(t *testing.T) { handlerErr := errors.New("user rejected auth") handler := func(ctx context.Context, args OAuthHandlerArgs) (oauth2.TokenSource, error) { From 3e333bde1880d7f5603d0a8797ea21d065b59621 Mon Sep 17 00:00:00 2001 From: Rob Findley Date: Thu, 23 Oct 2025 13:27:50 +0000 Subject: [PATCH 2/2] clean up stale comments --- auth/client.go | 3 --- auth/client_test.go | 1 - 2 files changed, 4 deletions(-) diff --git a/auth/client.go b/auth/client.go index f313be47..3ed6b695 100644 --- a/auth/client.go +++ b/auth/client.go @@ -69,9 +69,6 @@ func (t *HTTPTransport) RoundTrip(req *http.Request) (*http.Response, error) { base := t.opts.Base t.mu.Unlock() - // req1 is our first request in the authorization flow. - // - // If we mutate its body, we must clone it first. var ( // If haveBody is set, the request has a nontrivial body, and we need avoid // reading (or closing) it multiple times. In that case, bodyBytes is its diff --git a/auth/client_test.go b/auth/client_test.go index 63a45b93..e1b7dc70 100644 --- a/auth/client_test.go +++ b/auth/client_test.go @@ -122,7 +122,6 @@ func TestHTTPTransport(t *testing.T) { client := &http.Client{Transport: transport} resp, err := client.Post(authServer.URL, "application/json", &basicReader{strings.NewReader("{}")}) - // resp, err := client.Post(authServer.URL, "application/json", strings.NewReader("{}")) if err != nil { t.Fatalf("client.Post() failed: %v", err) }