diff --git a/docsrc/tutorials/images/cuda_graphs.png b/docsrc/tutorials/images/cuda_graphs.png
new file mode 100755
index 0000000000..4e6632868d
Binary files /dev/null and b/docsrc/tutorials/images/cuda_graphs.png differ
diff --git a/docsrc/tutorials/images/cuda_graphs_breaks.png b/docsrc/tutorials/images/cuda_graphs_breaks.png
new file mode 100755
index 0000000000..60247b1af0
Binary files /dev/null and b/docsrc/tutorials/images/cuda_graphs_breaks.png differ
diff --git a/examples/dynamo/torch_export_cudagraphs.py b/examples/dynamo/torch_export_cudagraphs.py
index fb31766b7c..e316dffc58 100644
--- a/examples/dynamo/torch_export_cudagraphs.py
+++ b/examples/dynamo/torch_export_cudagraphs.py
@@ -4,7 +4,12 @@
 Torch Export with Cudagraphs
 ======================================================
 
-This interactive script is intended as an overview of the process by which the Torch-TensorRT Cudagraphs integration can be used in the `ir="dynamo"` path. The functionality works similarly in the `torch.compile` path as well.
+CUDA Graphs allow multiple GPU operations to be launched through a single CPU operation, reducing launch overheads and improving GPU utilization. Torch-TensorRT provides a simple interface to enable CUDA Graphs, letting users leverage their performance benefits without manually managing the complexities of capture and replay.
+
+.. image:: /tutorials/images/cuda_graphs.png
+
+This interactive script is intended as an overview of the process by which the Torch-TensorRT Cudagraphs integration can be used in the `ir="dynamo"` path. The functionality works similarly in the
+`torch.compile` path as well.
 """
 
 # %%
@@ -70,19 +75,25 @@
 
 # %%
 # Cuda graphs with module that contains graph breaks
-# ----------------------------------
+# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
 #
 # When CUDA Graphs are applied to a TensorRT model that contains graph breaks, each break introduces additional
 # overhead. This occurs because graph breaks prevent the entire model from being executed as a single, continuous
 # optimized unit. As a result, some of the performance benefits typically provided by CUDA Graphs, such as reduced
 # kernel launch overhead and improved execution efficiency, may be diminished.
+#
 # Using a wrapped runtime module with CUDA Graphs allows you to encapsulate sequences of operations into graphs
-# that can be executed efficiently, even in the presence of graph breaks.
-# If TensorRT module has graph breaks, CUDA Graph context manager returns a wrapped_module. This module captures entire
-# execution graph, enabling efficient replay during subsequent inferences by reducing kernel launch overheads
-# and improving performance. Note that initializing with the wrapper module involves a warm-up phase where the
+# that can be executed efficiently, even in the presence of graph breaks. If the TensorRT module has graph breaks,
+# the CUDA Graphs context manager returns a wrapped_module. This module captures the entire execution graph, enabling
+# efficient replay during subsequent inferences by reducing kernel launch overheads and improving performance.
+#
+# Note that initializing with the wrapper module involves a warm-up phase where the
 # module is executed several times. This warm-up ensures that memory allocations and initializations are not
 # recorded in CUDA Graphs, which helps maintain consistent execution paths and optimize performance.
+#
+# .. image:: /tutorials/images/cuda_graphs_breaks.png
+#    :scale: 60 %
+#    :align: left
 
 
 class SampleModel(torch.nn.Module):
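The graph-break flow documented above can be exercised roughly as follows. This is a minimal sketch rather than the shipped example: it assumes the `torch_tensorrt.runtime.enable_cudagraphs` context manager accepts the compiled module and yields the (possibly wrapped) module, and it uses `torch_executed_ops` purely to force a graph break; exact option names may vary across Torch-TensorRT versions.

import torch
import torch_tensorrt


class SampleModel(torch.nn.Module):
    def forward(self, x):
        return torch.relu((x + 2) * 0.5)


model = SampleModel().eval().cuda()
inputs = [torch.randn((1, 3, 224, 224)).cuda()]

# Keeping aten.mul in PyTorch forces a graph break in the compiled module
# (assumed here purely for illustration).
opt_with_break = torch_tensorrt.compile(
    model,
    ir="dynamo",
    inputs=inputs,
    min_block_size=1,
    torch_executed_ops={"torch.ops.aten.mul.Tensor"},
)

# With graph breaks present, the context manager is expected to yield a wrapped
# module. The first few calls serve as the warm-up phase noted above; later
# calls replay the captured graph.
with torch_tensorrt.runtime.enable_cudagraphs(opt_with_break) as cudagraphs_module:
    out = cudagraphs_module(*inputs)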
diff --git a/py/torch_tensorrt/dynamo/runtime/_CudaGraphsTorchTensorRTModule.py b/py/torch_tensorrt/dynamo/runtime/_CudaGraphsTorchTensorRTModule.py
index e860c5762f..1cc6d6c785 100644
--- a/py/torch_tensorrt/dynamo/runtime/_CudaGraphsTorchTensorRTModule.py
+++ b/py/torch_tensorrt/dynamo/runtime/_CudaGraphsTorchTensorRTModule.py
@@ -115,12 +115,12 @@ def forward(self, *inputs: torch.Tensor) -> torch.Tensor | Tuple[torch.Tensor, ...]:
                     contiguous_inputs[i].dtype == self.inputs[i].dtype
                 ), f"Dtype mismatch for {i}th input. Expect {self.inputs[i].dtype}, got {contiguous_inputs[i].dtype}."
 
-            if need_cudagraphs_record:
-                # If cudagraphs is enabled, this memory is reserved for future cudagraph runs
-                # Clone is required to avoid re-using user-provided GPU memory
-                self._input_buffers[i] = contiguous_inputs[i].clone()
-            else:
-                self._input_buffers[i].copy_(contiguous_inputs[i])
+                if need_cudagraphs_record:
+                    # If cudagraphs is enabled, this memory is reserved for future cudagraph runs
+                    # Clone is required to avoid re-using user-provided GPU memory
+                    self._input_buffers[i] = contiguous_inputs[i].clone()
+                else:
+                    self._input_buffers[i].copy_(contiguous_inputs[i])
 
             self._caller_stream = torch.cuda.current_stream()
             if (
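The indentation fix above moves the buffer staging inside the per-input loop, so every input (not just the last) gets a dedicated static buffer. The clone-on-record / copy_-on-replay split mirrors the standard PyTorch CUDA Graphs pattern; the following is a minimal standalone sketch in plain PyTorch, independent of the Torch-TensorRT runtime.

import torch

model = torch.nn.Linear(16, 4).cuda()
static_input = torch.randn(8, 16, device="cuda")  # fixed address, like _input_buffers[i]

# Warm-up on a side stream so one-time allocations are not captured,
# mirroring the wrapper module's warm-up phase described in the example.
side_stream = torch.cuda.Stream()
side_stream.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(side_stream):
    for _ in range(3):
        model(static_input)
torch.cuda.current_stream().wait_stream(side_stream)

# Record: the graph is captured against static_input's device address.
# This is why the runtime clones user inputs into buffers it owns when recording.
graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph):
    static_output = model(static_input)

# Replay: new data must be copied into the same buffer (copy_), never rebound,
# because the captured kernels keep reading from the original address.
static_input.copy_(torch.randn(8, 16, device="cuda"))
graph.replay()
torch.cuda.synchronize()
print(static_output)  # output for the new data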