diff --git a/auth/client.go b/auth/client.go index 7c65e9de..3ed6b695 100644 --- a/auth/client.go +++ b/auth/client.go @@ -7,8 +7,10 @@ package auth import ( + "bytes" "context" "errors" + "io" "net/http" "sync" @@ -67,6 +69,28 @@ func (t *HTTPTransport) RoundTrip(req *http.Request) (*http.Response, error) { base := t.opts.Base t.mu.Unlock() + var ( + // If haveBody is set, the request has a nontrivial body, and we need avoid + // reading (or closing) it multiple times. In that case, bodyBytes is its + // content. + haveBody bool + bodyBytes []byte + ) + if req.Body != nil && req.Body != http.NoBody { + // if we're setting Body, we must mutate first. + req = req.Clone(req.Context()) + haveBody = true + var err error + bodyBytes, err = io.ReadAll(req.Body) + if err != nil { + return nil, err + } + // Now that we've read the request body, http.RoundTripper requires that we + // close it. + req.Body.Close() // ignore error + req.Body = io.NopCloser(bytes.NewReader(bodyBytes)) + } + resp, err := base.RoundTrip(req) if err != nil { return nil, err @@ -97,7 +121,15 @@ func (t *HTTPTransport) RoundTrip(req *http.Request) (*http.Response, error) { } t.opts.Base = &oauth2.Transport{Base: t.opts.Base, Source: ts} } - return t.opts.Base.RoundTrip(req.Clone(req.Context())) + + // If we don't have a body, the request is reusable, though it will be cloned + // by the base. However, if we've had to read the body, we must clone. + if haveBody { + req = req.Clone(req.Context()) + req.Body = io.NopCloser(bytes.NewReader(bodyBytes)) + } + + return t.opts.Base.RoundTrip(req) } func extractResourceMetadataURL(authHeaders []string) string { diff --git a/auth/client_test.go b/auth/client_test.go index 310fc56e..e1b7dc70 100644 --- a/auth/client_test.go +++ b/auth/client_test.go @@ -10,13 +10,24 @@ import ( "context" "errors" "fmt" + "io" "net/http" "net/http/httptest" + "strings" "testing" "golang.org/x/oauth2" ) +// A basicReader is an io.Reader to be used as a non-rereadable request body. +// +// net/http has special handling for strings.Reader that we want to avoid. +type basicReader struct { + r *strings.Reader +} + +func (r *basicReader) Read(p []byte) (n int, err error) { return r.r.Read(p) } + // TestHTTPTransport validates the OAuth HTTPTransport. func TestHTTPTransport(t *testing.T) { const testToken = "test-token-123" @@ -27,6 +38,20 @@ func TestHTTPTransport(t *testing.T) { // authServer simulates a resource that requires OAuth. authServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method == http.MethodPost { + // Ensure that the body was properly cloned, by reading it completely. + // If the body is not cloned, reading it the second time may yield no + // bytes. + body, err := io.ReadAll(r.Body) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + if len(body) == 0 { + http.Error(w, "empty body", http.StatusBadRequest) + return + } + } authHeader := r.Header.Get("Authorization") if authHeader == fmt.Sprintf("Bearer %s", testToken) { w.WriteHeader(http.StatusOK) @@ -82,6 +107,31 @@ func TestHTTPTransport(t *testing.T) { } }) + t.Run("request body is cloned", func(t *testing.T) { + handler := func(ctx context.Context, args OAuthHandlerArgs) (oauth2.TokenSource, error) { + if args.ResourceMetadataURL != "http://metadata.example.com" { + t.Errorf("handler got metadata URL %q, want %q", args.ResourceMetadataURL, "http://metadata.example.com") + } + return fakeTokenSource, nil + } + + transport, err := NewHTTPTransport(handler, nil) + if err != nil { + t.Fatalf("NewHTTPTransport() failed: %v", err) + } + client := &http.Client{Transport: transport} + + resp, err := client.Post(authServer.URL, "application/json", &basicReader{strings.NewReader("{}")}) + if err != nil { + t.Fatalf("client.Post() failed: %v", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + t.Errorf("got status %d, want %d", resp.StatusCode, http.StatusOK) + } + }) + t.Run("handler returns error", func(t *testing.T) { handlerErr := errors.New("user rejected auth") handler := func(ctx context.Context, args OAuthHandlerArgs) (oauth2.TokenSource, error) {