From 13a4976a208dd7c46b2e8680e9ac6813f4b353b3 Mon Sep 17 00:00:00 2001 From: Evan Li Date: Mon, 10 Feb 2025 16:37:09 -0800 Subject: [PATCH 01/17] support dds and nonzero op --- .../dynamo/conversion/aten_ops_converters.py | 17 ++ .../dynamo/conversion/impl/unary/ops.py | 15 ++ .../runtime/_PythonTorchTensorRTModule.py | 217 ++++++++++++++---- .../py/dynamo/conversion/test_nonzero_aten.py | 74 ++++++ 4 files changed, 280 insertions(+), 43 deletions(-) create mode 100644 tests/py/dynamo/conversion/test_nonzero_aten.py diff --git a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py index 2a9255ed68..7792c0a456 100644 --- a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py +++ b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py @@ -3552,3 +3552,20 @@ def aten_ops_full( fill_value=args[1], dtype=kwargs.get("dtype", None), ) + + +@dynamo_tensorrt_converter(torch.ops.aten.nonzero.default) +def aten_ops_nonzero( + ctx: ConversionContext, + target: Target, + args: Tuple[Argument, ...], + kwargs: Dict[str, Argument], + name: str, +) -> Union[TRTTensor, Sequence[TRTTensor]]: + return impl.unary.nonzero( + ctx, + target, + SourceIR.ATEN, + name, + args[0], + ) diff --git a/py/torch_tensorrt/dynamo/conversion/impl/unary/ops.py b/py/torch_tensorrt/dynamo/conversion/impl/unary/ops.py index 34b667acf1..89e490392d 100644 --- a/py/torch_tensorrt/dynamo/conversion/impl/unary/ops.py +++ b/py/torch_tensorrt/dynamo/conversion/impl/unary/ops.py @@ -624,3 +624,18 @@ def native_dropout( mask = np.ones(input_val.shape, dtype=bool) mask = get_trt_tensor(ctx, mask, f"{name}_mask") return identity_layer.get_output(0), mask + + +def nonzero( + ctx: ConversionContext, + target: Target, + source_ir: Optional[SourceIR], + name: str, + input_val: TRTTensor, +) -> TRTTensor: + non_zero_layer = ctx.net.add_non_zero(input_val) + set_layer_name(non_zero_layer, target, f"{name}_non_zero", source_ir) + shuffle_layer = ctx.net.add_shuffle(non_zero_layer.get_output(0)) + shuffle_layer.first_transpose = trt.Permutation([1, 0]) + set_layer_name(shuffle_layer, target, f"{name}_transpose", source_ir) + return shuffle_layer.get_output(0) diff --git a/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py b/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py index 9086de657f..377450517d 100644 --- a/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py +++ b/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py @@ -23,6 +23,41 @@ logger = logging.getLogger(__name__) +class DynamicOutputAllocator(trt.IOutputAllocator): # type: ignore[misc] + def __init__(self, output_dtypes: Dict[str, torch.dtype]) -> None: + trt.IOutputAllocator.__init__(self) + self.buffers: Dict[str, torch.Tensor] = {} + self.shapes: Dict[str, Tuple[int, ...]] = {} + self.dtypes: Dict[str, torch.dtype] = output_dtypes + + def reallocate_output_async( + self, + tensor_name: str, + memory: int, + size: int, + alignment: int, + stream: torch.cuda.Stream, + ) -> Any: + shape = (size,) + if tensor_name not in self.buffers: + self.buffers[tensor_name] = torch.empty( + shape, + dtype=self.dtypes[tensor_name], + device=torch.cuda.current_device(), + ) + else: + if self.buffers[tensor_name].shape != shape: + self.buffers[tensor_name] = torch.empty( + shape, + dtype=self.dtypes[tensor_name], + device=torch.cuda.current_device(), + ) + return self.buffers[tensor_name].data_ptr() + + def notify_shape(self, tensor_name: str, shape: Tuple[int, ...]) -> None: + 
self.shapes[tensor_name] = tuple(shape) + + class TorchTRTRuntimeStates: def __init__(self, new_cudagraphs: bool): # Indicates whether CUDAGraphs were enabled in the previous execute_engine @@ -164,8 +199,11 @@ def __init__( self.runtime_states = TorchTRTRuntimeStates( torch_tensorrt.runtime.get_cudagraphs_mode() ) + + self.contains_dds_layer = False self.pre_allocated_outputs: List[torch.Tensor] = [] self.use_pre_allocated_outputs = False + self.output_allocator: Optional[DynamicOutputAllocator] = None if self.serialized_engine is not None and not self.settings.lazy_engine_init: self.setup_engine() @@ -238,9 +276,19 @@ def setup_engine(self) -> None: for output_name in self.output_names ] + self.contains_dds_layer = self._check_dds_layer() + if self.contains_dds_layer: + self.setup_output_allocator() + if torch_tensorrt.runtime.get_cudagraphs_mode(): self.cudagraph = torch.cuda.CUDAGraph() + def _check_dds_layer(self) -> bool: + layer_info = self.get_layer_info() + if "trainStation" in layer_info: # contains dds layer + return True + return False + def _check_initialized(self) -> None: if not self.initialized: raise RuntimeError("PythonTorchTensorRTModule is not initialized.") @@ -358,19 +406,22 @@ def create_output_tensors(self) -> List[torch.Tensor]: def set_pre_allocated_outputs(self, enable: bool) -> None: self.use_pre_allocated_outputs = enable + def setup_output_allocator(self) -> None: + if self.output_allocator is None: + output_dtypes_dict = {} + for o, output_name in enumerate(self.output_names): + output_dtypes_dict[output_name] = self.output_dtypes[o] + self.output_allocator = DynamicOutputAllocator(output_dtypes_dict) + + for output_name in self.output_names: + if not self.context.set_output_allocator( + output_name, self.output_allocator + ): + raise RuntimeError(f"Failed to set output allocator for {output_name}") + def forward(self, *inputs: torch.Tensor) -> torch.Tensor | Tuple[torch.Tensor, ...]: - # Ensure inputs are available in all scopes and cast symbolic integers to Tensors - contiguous_inputs: List[torch.Tensor] = [ - (i.contiguous() if isinstance(i, torch.Tensor) else torch.tensor(i).cuda()) - for i in inputs - ] - with ( - torch.autograd.profiler.record_function("PythonTorchTensorRTModule:Forward") - if self.profiling_enabled - else nullcontext() - ): - self._check_initialized() + def run_cuda_graph() -> torch.Tensor | Tuple[torch.Tensor, ...]: cudagraphs_enabled = torch_tensorrt.runtime.get_cudagraphs_mode() shape_changed = self.validate_input_shapes(inputs) ( @@ -389,38 +440,6 @@ def forward(self, *inputs: torch.Tensor) -> torch.Tensor | Tuple[torch.Tensor, . 
self._input_buffers = [None] * len(self.input_names) self._output_buffers = [None] * len(self.output_names) - # If in safe mode, check at each iteration for whether a switch is required - if ( - torch_tensorrt.runtime._multi_device_safe_mode._PY_RT_MULTI_DEVICE_SAFE_MODE - ): - curr_device_id = torch.cuda.current_device() - curr_device_properties = torch.cuda.get_device_properties( - curr_device_id - ) - logger.debug(f"Current Device: cuda:{curr_device_id}") - - # If a switch is required, move all inputs to new device and set as active device - if _is_switch_required( - curr_device_id, - self.target_device_id, - curr_device_properties, - self.target_device_properties, - ): - device_id, _ = _select_rt_device( - curr_device_id, - self.target_device_id, - self.target_device_properties, - ) - - # Update current device - device = torch.device(device_id) - torch.cuda.set_device(device_id) - - contiguous_inputs = [ - tensor.to(device) for tensor in contiguous_inputs - ] - logger.warning(f"Moved all input Tensors to cuda:{device_id}") - with ( torch.autograd.profiler.record_function( "PythonTorchTensorRTModule:ProcessInputs" @@ -536,6 +555,118 @@ def forward(self, *inputs: torch.Tensor) -> torch.Tensor | Tuple[torch.Tensor, . return outputs + def run_output_allocator() -> torch.Tensor | Tuple[torch.Tensor, ...]: + with ( + torch.autograd.profiler.record_function( + "PythonTorchTensorRTModule:ProcessInputs" + ) + if self.profiling_enabled + else nullcontext() + ): + assert len(contiguous_inputs) == len( + self.input_names + ), f"Wrong number of inputs, expect {len(self.input_names)} get {len(contiguous_inputs)}." + + self.setup_input_tensors(contiguous_inputs, False, False) + + with ( + torch.autograd.profiler.record_function( + "PythonTorchTensorRTModule:TensorRTRuntime" + ) + if self.profiling_enabled + else nullcontext() + ): + self._caller_stream = torch.cuda.current_stream() + if ( + self._engine_stream == torch.cuda.default_stream() + or self._engine_stream is None + ): + self._engine_stream = torch.cuda.Stream() + + self._engine_stream.wait_stream(self._caller_stream) + + with torch.cuda.stream(self._engine_stream): + self.context.execute_async_v3( + self._engine_stream.cuda_stream + ) # The OutputAllocator is called by execute_async_v3() + + self._caller_stream.wait_stream(self._engine_stream) + + with ( + torch.autograd.profiler.record_function( + "PythonTorchTensorRTModule:ProcessOutputs" + ) + if self.profiling_enabled + else nullcontext() + ): + outputs = [] + assert self.output_allocator is not None + for o, output_name in enumerate(self.output_names): + shape = self.output_allocator.shapes.get(output_name, None) + dtype = self.output_dtypes[o] + output = ( + self.output_allocator.buffers.get(output_name, None) + .clone() + .detach() + ) + prod = int(torch.prod(torch.tensor(shape))) + output = output.reshape(-1).view(dtype)[:prod].reshape(shape) + outputs.append(output) + + if len(outputs) == 1: + return outputs[0] + + return outputs + + # Run forward function + contiguous_inputs: List[torch.Tensor] = [ + (i.contiguous() if isinstance(i, torch.Tensor) else torch.tensor(i).cuda()) + for i in inputs + ] + with ( + torch.autograd.profiler.record_function("PythonTorchTensorRTModule:Forward") + if self.profiling_enabled + else nullcontext() + ): + self._check_initialized() + + # If in safe mode, check at each iteration for whether a switch is required + if ( + torch_tensorrt.runtime._multi_device_safe_mode._PY_RT_MULTI_DEVICE_SAFE_MODE + ): + curr_device_id = torch.cuda.current_device() + 
curr_device_properties = torch.cuda.get_device_properties( + curr_device_id + ) + logger.debug(f"Current Device: cuda:{curr_device_id}") + + # If a switch is required, move all inputs to new device and set as active device + if _is_switch_required( + curr_device_id, + self.target_device_id, + curr_device_properties, + self.target_device_properties, + ): + device_id, _ = _select_rt_device( + curr_device_id, + self.target_device_id, + self.target_device_properties, + ) + + # Update current device + device = torch.device(device_id) + torch.cuda.set_device(device_id) + + contiguous_inputs = [ + tensor.to(device) for tensor in contiguous_inputs + ] + logger.warning(f"Moved all input Tensors to cuda:{device_id}") + + if self.contains_dds_layer: + return run_output_allocator() + else: + return run_cuda_graph() + def enable_profiling(self, profiler: "trt.IProfiler" = None) -> None: """ Enable TensorRT profiling. After calling this function, TensorRT will report diff --git a/tests/py/dynamo/conversion/test_nonzero_aten.py b/tests/py/dynamo/conversion/test_nonzero_aten.py new file mode 100644 index 0000000000..c75bd4d4a9 --- /dev/null +++ b/tests/py/dynamo/conversion/test_nonzero_aten.py @@ -0,0 +1,74 @@ +import torch +import torch.nn as nn +from parameterized import parameterized +from torch.testing._internal.common_utils import run_tests +from torch_tensorrt import Input + +from .harness import DispatchTestCase + + +class TestNonZeroConverter(DispatchTestCase): + @parameterized.expand( + [ + ((10,), torch.int), + ((1, 20), torch.int32), + ((2, 3), torch.int64), + ((2, 3, 4), torch.float), + ((2, 3, 4, 5), torch.float), + ] + ) + def test_non_zero(self, input_shape, dtype): + class NonZero(nn.Module): + def forward(self, input): + return torch.ops.aten.nonzero.default(input) + + inputs = [torch.randint(low=0, high=3, size=input_shape, dtype=dtype)] + self.run_test( + NonZero(), + inputs, + ) + + @parameterized.expand( + [ + ( + "1d", + (1,), + (10,), + (100,), + torch.int32, + ), + ( + "2d", + (1, 2), + (5, 10), + (20, 40), + torch.float16, + ), + ( + "3d", + (1, 2, 3), + (5, 10, 20), + (30, 40, 50), + torch.float, + ), + ] + ) + def test_nonzero_dynamic_shape(self, _, min_shape, opt_shape, max_shape, dtype): + class NonZero(nn.Module): + def forward(self, input): + return torch.ops.aten.nonzero.default(input) + + input_specs = [ + Input( + min_shape=min_shape, + opt_shape=opt_shape, + max_shape=max_shape, + dtype=dtype, + ), + ] + + self.run_test_with_dynamic_shape(NonZero(), input_specs) + + +if __name__ == "__main__": + run_tests() From 975e5231e274f6892b9bddf9fc09920f6a5a0ee8 Mon Sep 17 00:00:00 2001 From: Evan Li Date: Wed, 12 Feb 2025 21:59:56 -0800 Subject: [PATCH 02/17] check output shape to implicitly decide whether network is dds --- .../dynamo/conversion/_TRTInterpreter.py | 20 +++++++++ .../dynamo/conversion/_conversion.py | 3 +- .../runtime/_PythonTorchTensorRTModule.py | 42 +++++++++++-------- .../dynamo/runtime/_TorchTensorRTModule.py | 10 ++++- tests/py/dynamo/conversion/harness.py | 2 + .../py/dynamo/conversion/test_nonzero_aten.py | 27 +++++++++++- 6 files changed, 83 insertions(+), 21 deletions(-) diff --git a/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py b/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py index 2f35a6d124..f8ac22ecf8 100644 --- a/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py +++ b/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py @@ -64,6 +64,7 @@ class TRTInterpreterResult(NamedTuple): input_names: Sequence[str] output_names: 
Sequence[str] weight_name_map: Optional[dict[Any, Any]] + engine_is_dds: bool class TRTInterpreter(torch.fx.Interpreter): # type: ignore[misc] @@ -138,6 +139,9 @@ def __init__( # Engine cache for storing and reusing TRT engines self.engine_cache = engine_cache + # Whether the engine is data-dependent shape (dds) + self.engine_is_dds: bool = False + def validate_conversion(self) -> Set[str]: missing_converters: Set[str] = set() @@ -582,6 +586,7 @@ def _insert_engine_to_cache(self, hash_val: str, serialized_engine: bytes) -> No self.input_specs, self.compilation_settings, self.weight_name_map, + self.engine_is_dds, ), ) @@ -596,6 +601,7 @@ def _pull_cached_engine(self, hash_val: str) -> Optional[TRTInterpreterResult]: cached_engine_input_specs, engine_compilation_settings, self.weight_name_map, + self.engine_is_dds, ) = cached_data setting_compatiblity, incompattible_settings = settings_are_compatible( @@ -657,9 +663,20 @@ def _pull_cached_engine(self, hash_val: str) -> Optional[TRTInterpreterResult]: self._input_names, self._output_names, self.weight_name_map, + self.engine_is_dds, ) return None + def check_dds(self, serialized_engine: bytes, output_names: List[str]) -> bool: + runtime = trt.Runtime(TRT_LOGGER) + engine = runtime.deserialize_cuda_engine(serialized_engine) + + for output_name in output_names: + output_shape = engine.get_tensor_shape(output_name) + if -1 in output_shape: + return True + return False + def run( self, strict_type_constraints: bool = False, @@ -716,6 +733,8 @@ def run( ) assert serialized_engine + self.engine_is_dds = self.check_dds(serialized_engine, self._output_names) + _LOGGER.info( f"Build TRT engine elapsed time: {datetime.now() - build_engine_start_time}" ) @@ -742,6 +761,7 @@ def run( self._input_names, self._output_names, self.weight_name_map, + self.engine_is_dds, ) def run_node(self, n: torch.fx.Node) -> torch.fx.Node: diff --git a/py/torch_tensorrt/dynamo/conversion/_conversion.py b/py/torch_tensorrt/dynamo/conversion/_conversion.py index 1dad18989c..82c3ceb48e 100644 --- a/py/torch_tensorrt/dynamo/conversion/_conversion.py +++ b/py/torch_tensorrt/dynamo/conversion/_conversion.py @@ -30,7 +30,7 @@ def infer_module_output_dtypes( """ outputs = [node for node in module.graph.nodes if node.op == "output"] outputs = outputs[0].args - return get_output_dtypes(outputs, truncate_double) + return get_output_dtypes(outputs, truncate_double) # type: ignore[no-any-return] def interpret_module_to_result( @@ -112,4 +112,5 @@ def convert_module( name=name, settings=settings, weight_name_map=interpreter_result.weight_name_map, + engine_is_dds=interpreter_result.engine_is_dds, ) diff --git a/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py b/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py index 377450517d..4d02ac3bc7 100644 --- a/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py +++ b/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py @@ -127,6 +127,7 @@ def __init__( name: str = "", settings: CompilationSettings = CompilationSettings(), weight_name_map: Optional[dict[Any, Any]] = None, + engine_is_dds: bool = False, ): """Takes a name, target device, serialized TensorRT engine, and binding names / order and constructs a PyTorch ``torch.nn.Module`` around it. 
Uses TensorRT Python APIs to run the engine @@ -140,6 +141,7 @@ def __init__( name (str): Name for module settings (torch_tensorrt.dynamo.CompilationSettings): Settings used to compile engine, assumes engine was built with default compilation settings if object not passed weight_name_map (dict): Mapping of engine weight name to state_dict weight name + engine_is_dds (bool): Whether the engine is Data Dependent Shape Example: @@ -200,7 +202,7 @@ def __init__( torch_tensorrt.runtime.get_cudagraphs_mode() ) - self.contains_dds_layer = False + self.engine_is_dds = engine_is_dds self.pre_allocated_outputs: List[torch.Tensor] = [] self.use_pre_allocated_outputs = False self.output_allocator: Optional[DynamicOutputAllocator] = None @@ -276,19 +278,12 @@ def setup_engine(self) -> None: for output_name in self.output_names ] - self.contains_dds_layer = self._check_dds_layer() - if self.contains_dds_layer: - self.setup_output_allocator() + if self.engine_is_dds: + self.create_output_allocator() if torch_tensorrt.runtime.get_cudagraphs_mode(): self.cudagraph = torch.cuda.CUDAGraph() - def _check_dds_layer(self) -> bool: - layer_info = self.get_layer_info() - if "trainStation" in layer_info: # contains dds layer - return True - return False - def _check_initialized(self) -> None: if not self.initialized: raise RuntimeError("PythonTorchTensorRTModule is not initialized.") @@ -406,19 +401,13 @@ def create_output_tensors(self) -> List[torch.Tensor]: def set_pre_allocated_outputs(self, enable: bool) -> None: self.use_pre_allocated_outputs = enable - def setup_output_allocator(self) -> None: + def create_output_allocator(self) -> None: if self.output_allocator is None: output_dtypes_dict = {} for o, output_name in enumerate(self.output_names): output_dtypes_dict[output_name] = self.output_dtypes[o] self.output_allocator = DynamicOutputAllocator(output_dtypes_dict) - for output_name in self.output_names: - if not self.context.set_output_allocator( - output_name, self.output_allocator - ): - raise RuntimeError(f"Failed to set output allocator for {output_name}") - def forward(self, *inputs: torch.Tensor) -> torch.Tensor | Tuple[torch.Tensor, ...]: def run_cuda_graph() -> torch.Tensor | Tuple[torch.Tensor, ...]: @@ -569,6 +558,23 @@ def run_output_allocator() -> torch.Tensor | Tuple[torch.Tensor, ...]: self.setup_input_tensors(contiguous_inputs, False, False) + with ( + torch.autograd.profiler.record_function( + "PythonTorchTensorRTModule:SetupOutputAllocator" + ) + if self.profiling_enabled + else nullcontext() + ): + self.create_output_allocator() + # need to set output allocator every run + for output_name in self.output_names: + if not self.context.set_output_allocator( + output_name, self.output_allocator + ): + raise RuntimeError( + f"Failed to set output allocator for {output_name}" + ) + with ( torch.autograd.profiler.record_function( "PythonTorchTensorRTModule:TensorRTRuntime" @@ -662,7 +668,7 @@ def run_output_allocator() -> torch.Tensor | Tuple[torch.Tensor, ...]: ] logger.warning(f"Moved all input Tensors to cuda:{device_id}") - if self.contains_dds_layer: + if self.engine_is_dds: return run_output_allocator() else: return run_cuda_graph() diff --git a/py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py b/py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py index b809e70ddf..ae0cc3cec6 100644 --- a/py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py +++ b/py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py @@ -79,6 +79,7 @@ def __init__( name: str = "", settings: 
CompilationSettings = CompilationSettings(), # Assumes engine was built with default compilation settings if object not passed weight_name_map: Optional[dict[Any, Any]] = None, + engine_is_dds: bool = False, ): """Takes a name, target device, serialized TensorRT engine, and binding names / order and constructs a PyTorch ``torch.nn.Module`` around it. Uses the Torch-TensorRT runtime extension to run the engines @@ -97,6 +98,7 @@ def __init__( name (str): Name for module settings (torch_tensorrt.dynamo.CompilationSettings): Settings used to compile engine, assumes engine was built with default compilation settings if object not passed weight_name_map (dict): Mapping of engine weight name to state_dict weight name + engine_is_dds (bool): Whether the engine is Data Dependent Shape Example: @@ -132,6 +134,7 @@ def __init__( self.weight_name_map = weight_name_map self.serialized_engine = serialized_engine self.engine = None + self.engine_is_dds = engine_is_dds if ( serialized_engine @@ -146,7 +149,11 @@ def _pack_engine_info(self) -> List[str | bytes]: if self.settings.device is not None else Device._current_device() ) - metadata = {"settings": self.settings, "weight_name_map": self.weight_name_map} + metadata = { + "settings": self.settings, + "weight_name_map": self.weight_name_map, + "engine_is_dds": self.engine_is_dds, + } target_platform = ( Platform.current_platform() if not self.settings.enable_cross_compile_for_windows @@ -263,6 +270,7 @@ def set_extra_state(self, state: SerializedTorchTensorRTModuleFmt) -> None: metadata = TorchTensorRTModule.decode_metadata(serialized_metadata) self.settings = metadata["settings"] self.weight_name_map = metadata["weight_name_map"] + self.engine_is_dds = metadata["engine_is_dds"] else: self.engine = None diff --git a/tests/py/dynamo/conversion/harness.py b/tests/py/dynamo/conversion/harness.py index 9813548a10..3e17273b98 100644 --- a/tests/py/dynamo/conversion/harness.py +++ b/tests/py/dynamo/conversion/harness.py @@ -207,6 +207,7 @@ def run_test( input_binding_names=list(interpreter_result.input_names), output_binding_names=list(interpreter_result.output_names), name="test_engine", + engine_is_dds=interpreter_result.engine_is_dds, ) mod = mod.cuda() if pyt_inputs is not None: @@ -289,6 +290,7 @@ def run_test_custom_compare_results( input_binding_names=list(interpreter_result.input_names), output_binding_names=list(interpreter_result.output_names), name="test_engine", + engine_is_dds=interpreter_result.engine_is_dds, ) res_trt = trt_mod(*cuda_inputs).cpu() res_cpu = mod(*cuda_inputs).cpu() diff --git a/tests/py/dynamo/conversion/test_nonzero_aten.py b/tests/py/dynamo/conversion/test_nonzero_aten.py index c75bd4d4a9..4479993780 100644 --- a/tests/py/dynamo/conversion/test_nonzero_aten.py +++ b/tests/py/dynamo/conversion/test_nonzero_aten.py @@ -19,8 +19,33 @@ class TestNonZeroConverter(DispatchTestCase): ) def test_non_zero(self, input_shape, dtype): class NonZero(nn.Module): + # This is a DDS network def forward(self, input): - return torch.ops.aten.nonzero.default(input) + out = torch.ops.aten.nonzero.default(input) + return out + + inputs = [torch.randint(low=0, high=3, size=input_shape, dtype=dtype)] + self.run_test( + NonZero(), + inputs, + ) + + @parameterized.expand( + [ + ((10,), torch.int), + ((1, 20), torch.int32), + ((2, 3), torch.int64), + ((2, 3, 4), torch.float), + ((2, 3, 4, 5), torch.float), + ] + ) + def test_non_zero(self, input_shape, dtype): + class NonZero(nn.Module): + # This is a static network + def forward(self, input): + out = 
torch.ops.aten.nonzero.default(input) + out = torch.ops.aten.sum.dim_IntList(out, 0) + return out inputs = [torch.randint(low=0, high=3, size=input_shape, dtype=dtype)] self.run_test( From f58d8a99fd2bb78588ab53ac0601a53271c15c9c Mon Sep 17 00:00:00 2001 From: Evan Li Date: Wed, 12 Feb 2025 23:03:28 -0800 Subject: [PATCH 03/17] fix bug1 --- py/torch_tensorrt/dynamo/_engine_cache.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/py/torch_tensorrt/dynamo/_engine_cache.py b/py/torch_tensorrt/dynamo/_engine_cache.py index 83f75dc4e9..2355b13469 100644 --- a/py/torch_tensorrt/dynamo/_engine_cache.py +++ b/py/torch_tensorrt/dynamo/_engine_cache.py @@ -25,6 +25,7 @@ Sequence[Input], CompilationSettings, Optional[Dict[str, Any]], + bool, ] @@ -106,6 +107,7 @@ def pack( input_specs: Sequence[Input], compilation_settings: CompilationSettings, weight_name_map: Optional[Dict[Any, Any]], + engine_is_dds: bool, ) -> bytes: """Pack serialized engine, input names, output names, and weight map into a single blob @@ -116,7 +118,7 @@ def pack( input_specs (Sequence[Input]): input specs of TRT engine compilation_settings (CompilationSettings): compilation settings of TRT engine weight_name_map (Optional[Dict[Any, Any]]): weight name map for refitting - + engine_is_dds (bool): whether the engine is data-dependent shape Returns: bytes: packed blob """ @@ -130,6 +132,7 @@ def pack( "input_specs": input_specs, "compilation_settings": settings, "weight_name_map": weight_name_map, + "engine_is_dds": engine_is_dds, } ) @@ -151,6 +154,7 @@ def unpack(packed_obj: bytes) -> UnpackedCacheHit: unpacked["input_specs"], unpacked["compilation_settings"], unpacked["weight_name_map"], + unpacked["engine_is_dds"], ) def insert( From 4fd440d67db330921b9576f5d9d439df1ce0e592 Mon Sep 17 00:00:00 2001 From: Evan Li Date: Thu, 13 Feb 2025 19:35:23 -0800 Subject: [PATCH 04/17] disable cuda graph for output allocator mode --- py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py | 1 + 1 file changed, 1 insertion(+) diff --git a/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py b/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py index 4d02ac3bc7..48fb475cc7 100644 --- a/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py +++ b/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py @@ -545,6 +545,7 @@ def run_cuda_graph() -> torch.Tensor | Tuple[torch.Tensor, ...]: return outputs def run_output_allocator() -> torch.Tensor | Tuple[torch.Tensor, ...]: + torch_tensorrt.runtime.set_cudagraphs_mode(False) with ( torch.autograd.profiler.record_function( "PythonTorchTensorRTModule:ProcessInputs" From 58ab2c2e4e20b2737c0a437e153f70b6a3551ee6 Mon Sep 17 00:00:00 2001 From: Evan Li Date: Wed, 26 Feb 2025 13:23:56 -0800 Subject: [PATCH 05/17] implement with ctx manager --- .../dynamo/conversion/aten_ops_converters.py | 2 +- .../lowering/passes/_aten_lowering_pass.py | 2 + .../passes/remove_num_users_is_0_nodes.py | 27 ++ .../remove_sym_size_and_constrain_nodes.py | 34 ++ .../runtime/_CudaGraphsTorchTensorRTModule.py | 9 + .../runtime/_PythonTorchTensorRTModule.py | 46 ++- py/torch_tensorrt/runtime/__init__.py | 1 + .../runtime/_output_allocator.py | 51 +++ .../py/dynamo/conversion/test_nonzero_aten.py | 51 ++- .../runtime/test_output_allocator_py.py | 375 ++++++++++++++++++ 10 files changed, 584 insertions(+), 14 deletions(-) create mode 100644 py/torch_tensorrt/dynamo/lowering/passes/remove_num_users_is_0_nodes.py create mode 100644 
py/torch_tensorrt/dynamo/lowering/passes/remove_sym_size_and_constrain_nodes.py create mode 100644 py/torch_tensorrt/runtime/_output_allocator.py create mode 100644 tests/py/dynamo/runtime/test_output_allocator_py.py diff --git a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py index 7792c0a456..8b5932a19c 100644 --- a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py +++ b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py @@ -3554,7 +3554,7 @@ def aten_ops_full( ) -@dynamo_tensorrt_converter(torch.ops.aten.nonzero.default) +@dynamo_tensorrt_converter(torch.ops.aten.nonzero.default, supports_dynamic_shapes=True) def aten_ops_nonzero( ctx: ConversionContext, target: Target, diff --git a/py/torch_tensorrt/dynamo/lowering/passes/_aten_lowering_pass.py b/py/torch_tensorrt/dynamo/lowering/passes/_aten_lowering_pass.py index b66f36c11e..87acc4bb3b 100644 --- a/py/torch_tensorrt/dynamo/lowering/passes/_aten_lowering_pass.py +++ b/py/torch_tensorrt/dynamo/lowering/passes/_aten_lowering_pass.py @@ -13,6 +13,8 @@ from .remove_assert_nodes import remove_assert_nodes from .remove_detach import remove_detach from .remove_input_alias_fixing_clones import remove_input_alias_fixing_clones +from .remove_num_users_is_0_nodes import remove_num_users_is_0_nodes +from .remove_sym_size_and_constrain_nodes import remove_sym_size_and_constrain_nodes from .repair_input_as_output import repair_input_as_output from .replace_max_pool_with_indices import replace_max_pool_with_indices diff --git a/py/torch_tensorrt/dynamo/lowering/passes/remove_num_users_is_0_nodes.py b/py/torch_tensorrt/dynamo/lowering/passes/remove_num_users_is_0_nodes.py new file mode 100644 index 0000000000..72cd3b0e53 --- /dev/null +++ b/py/torch_tensorrt/dynamo/lowering/passes/remove_num_users_is_0_nodes.py @@ -0,0 +1,27 @@ +import logging + +import torch +from torch_tensorrt.dynamo._settings import CompilationSettings +from torch_tensorrt.dynamo.lowering.passes.pass_utils import ( + clean_up_graph_after_modifications, +) + +logger = logging.getLogger(__name__) + + +def remove_num_users_is_0_nodes( + gm: torch.fx.GraphModule, settings: CompilationSettings +) -> torch.fx.GraphModule: + """Remove ops that [num_users=0] in the graph""" + output_node = list(gm.graph.nodes)[-1] + + for node in gm.graph.nodes: + if node != output_node and len(node.users) == 0: + node_input = node.all_input_nodes[0] + node.replace_all_uses_with(node_input) + gm.graph.erase_node(node) + gm = clean_up_graph_after_modifications(gm) + + logger.debug(f"Removed ops that [num_users=0] nodes:\n{gm.graph}") + + return gm diff --git a/py/torch_tensorrt/dynamo/lowering/passes/remove_sym_size_and_constrain_nodes.py b/py/torch_tensorrt/dynamo/lowering/passes/remove_sym_size_and_constrain_nodes.py new file mode 100644 index 0000000000..4d24be5cac --- /dev/null +++ b/py/torch_tensorrt/dynamo/lowering/passes/remove_sym_size_and_constrain_nodes.py @@ -0,0 +1,34 @@ +import logging + +import torch +from torch_tensorrt.dynamo._settings import CompilationSettings +from torch_tensorrt.dynamo.lowering.passes.pass_utils import ( + clean_up_graph_after_modifications, +) + +logger = logging.getLogger(__name__) + + +def remove_sym_size_and_constrain_nodes( + gm: torch.fx.GraphModule, settings: CompilationSettings +) -> torch.fx.GraphModule: + """Remove aten.sym_size.int and aten.sym_constrain_range_for_size.default ops in the graph""" + count = 0 + for node in gm.graph.nodes: + if node.op == "call_function" 
and ( + node.target == torch.ops.aten.sym_size.int + or node.target == torch.ops.aten.sym_constrain_range_for_size.default + ): + node_input = node.all_input_nodes[0] + node.replace_all_uses_with(node_input) + gm.graph.erase_node(node) + count += 1 + + if count > 0: + gm = clean_up_graph_after_modifications(gm) + + logger.debug( + f"Removed {count} aten.sym_size.int or aten.sym_constrain_range_for_size.default nodes:\n{gm.graph}" + ) + + return gm diff --git a/py/torch_tensorrt/dynamo/runtime/_CudaGraphsTorchTensorRTModule.py b/py/torch_tensorrt/dynamo/runtime/_CudaGraphsTorchTensorRTModule.py index 1cc6d6c785..2163d237fe 100644 --- a/py/torch_tensorrt/dynamo/runtime/_CudaGraphsTorchTensorRTModule.py +++ b/py/torch_tensorrt/dynamo/runtime/_CudaGraphsTorchTensorRTModule.py @@ -32,6 +32,7 @@ def __init__( self._input_buffers: List[torch.Tensor] = [] self._output_buffers: List[torch.Tensor] = [] self.cudagraph: Optional[torch.cuda.CUDAGraph] = None + self.use_output_allocator_outputs = False self.shape_key: Optional[str] = None self._caller_stream: Optional[torch.cuda.Stream] = None self._engine_stream: Optional[torch.cuda.Stream] = None @@ -73,8 +74,16 @@ def __del__(self) -> None: if self.cudagraph: self.cudagraph.reset() + def set_output_allocator_outputs(self, enable: bool) -> None: + self.use_output_allocator_outputs = enable + def forward(self, *inputs: torch.Tensor) -> torch.Tensor | Tuple[torch.Tensor, ...]: cudagraphs_enabled = torch_tensorrt.runtime.get_whole_cudagraphs_mode() + if cudagraphs_enabled and self.use_output_allocator_outputs: + raise RuntimeError( + "There are non-TRT submodules in the module. OutputAllocator is not compatible with modules with non-TRT submodules." + ) + if cudagraphs_enabled: shape_changed = self.validate_input_shapes(inputs) need_cudagraphs_record = shape_changed or self.is_weight_streaming_set diff --git a/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py b/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py index 48fb475cc7..f78f89ee9e 100644 --- a/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py +++ b/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py @@ -202,10 +202,13 @@ def __init__( torch_tensorrt.runtime.get_cudagraphs_mode() ) - self.engine_is_dds = engine_is_dds + self.cudagraphs_enabled = False self.pre_allocated_outputs: List[torch.Tensor] = [] self.use_pre_allocated_outputs = False + + self.engine_is_dds = engine_is_dds self.output_allocator: Optional[DynamicOutputAllocator] = None + self.use_output_allocator_outputs = False if self.serialized_engine is not None and not self.settings.lazy_engine_init: self.setup_engine() @@ -401,6 +404,9 @@ def create_output_tensors(self) -> List[torch.Tensor]: def set_pre_allocated_outputs(self, enable: bool) -> None: self.use_pre_allocated_outputs = enable + def set_output_allocator_outputs(self, enable: bool) -> None: + self.use_output_allocator_outputs = enable + def create_output_allocator(self) -> None: if self.output_allocator is None: output_dtypes_dict = {} @@ -410,15 +416,14 @@ def create_output_allocator(self) -> None: def forward(self, *inputs: torch.Tensor) -> torch.Tensor | Tuple[torch.Tensor, ...]: - def run_cuda_graph() -> torch.Tensor | Tuple[torch.Tensor, ...]: - cudagraphs_enabled = torch_tensorrt.runtime.get_cudagraphs_mode() + def run_standard_execution() -> torch.Tensor | Tuple[torch.Tensor, ...]: shape_changed = self.validate_input_shapes(inputs) ( need_cudagraphs_record, can_use_pre_allocated_outputs, need_cudagraphs_reset, ) = 
self.runtime_states.set_runtime_states( - cudagraphs_enabled, self.use_pre_allocated_outputs, shape_changed + self.cudagraphs_enabled, self.use_pre_allocated_outputs, shape_changed ) if need_cudagraphs_reset and self.cudagraph: @@ -441,7 +446,7 @@ def run_cuda_graph() -> torch.Tensor | Tuple[torch.Tensor, ...]: ), f"Wrong number of inputs, expect {len(self.input_names)} get {len(contiguous_inputs)}." self.setup_input_tensors( - contiguous_inputs, cudagraphs_enabled, need_cudagraphs_record + contiguous_inputs, self.cudagraphs_enabled, need_cudagraphs_record ) if shape_changed: @@ -477,7 +482,7 @@ def run_cuda_graph() -> torch.Tensor | Tuple[torch.Tensor, ...]: if need_cudagraphs_record: self._output_buffers[o] = outputs[o].clone() - if cudagraphs_enabled: + if self.cudagraphs_enabled: self.context.set_tensor_address( output_name, self._output_buffers[o].data_ptr() ) @@ -503,7 +508,7 @@ def run_cuda_graph() -> torch.Tensor | Tuple[torch.Tensor, ...]: self._engine_stream.wait_stream(self._caller_stream) with torch.cuda.stream(self._engine_stream): - if cudagraphs_enabled: + if self.cudagraphs_enabled: if need_cudagraphs_record: self.cudagraph = torch.cuda.CUDAGraph() @@ -535,7 +540,7 @@ def run_cuda_graph() -> torch.Tensor | Tuple[torch.Tensor, ...]: if self.use_pre_allocated_outputs: self.pre_allocated_outputs = self.create_output_tensors() - if cudagraphs_enabled: + if self.cudagraphs_enabled: for idx, o in enumerate(outputs): o.copy_(self._output_buffers[idx]) @@ -545,7 +550,9 @@ def run_cuda_graph() -> torch.Tensor | Tuple[torch.Tensor, ...]: return outputs def run_output_allocator() -> torch.Tensor | Tuple[torch.Tensor, ...]: - torch_tensorrt.runtime.set_cudagraphs_mode(False) + assert ( + not torch_tensorrt.runtime.get_cudagraphs_mode() + ), "CUDA Graphs are not compatible with OutputAllocator." with ( torch.autograd.profiler.record_function( "PythonTorchTensorRTModule:ProcessInputs" @@ -625,6 +632,8 @@ def run_output_allocator() -> torch.Tensor | Tuple[torch.Tensor, ...]: return outputs + self.cudagraphs_enabled = torch_tensorrt.runtime.get_cudagraphs_mode() + # Run forward function contiguous_inputs: List[torch.Tensor] = [ (i.contiguous() if isinstance(i, torch.Tensor) else torch.tensor(i).cuda()) @@ -670,9 +679,26 @@ def run_output_allocator() -> torch.Tensor | Tuple[torch.Tensor, ...]: logger.warning(f"Moved all input Tensors to cuda:{device_id}") if self.engine_is_dds: + if self.cudagraphs_enabled: + raise RuntimeError( + "The module is Data-Dependent Shape (DDS). It has to be handled by OutputAllocator which is not compatible with CUDA Graphs. Please disable CUDA Graphs." + ) + logger.debug( + "The module is Data-Dependent Shape (DDS). Using output allocator." + ) return run_output_allocator() else: - return run_cuda_graph() + if self.cudagraphs_enabled and self.use_output_allocator_outputs: + raise RuntimeError( + "Both CUDA Graphs and OutputAllocator are enabled. Please disable either one." + ) + if self.use_output_allocator_outputs: + logger.debug("Using output allocator.") + return run_output_allocator() + logger.debug( + f"Using standard execution with cudagraphs={self.cudagraphs_enabled}." 
+ ) + return run_standard_execution() def enable_profiling(self, profiler: "trt.IProfiler" = None) -> None: """ diff --git a/py/torch_tensorrt/runtime/__init__.py b/py/torch_tensorrt/runtime/__init__.py index 470074a377..cfc9b322b5 100644 --- a/py/torch_tensorrt/runtime/__init__.py +++ b/py/torch_tensorrt/runtime/__init__.py @@ -9,5 +9,6 @@ set_cudagraphs_mode, ) from torch_tensorrt.runtime._multi_device_safe_mode import set_multi_device_safe_mode +from torch_tensorrt.runtime._output_allocator import enable_output_allocator from torch_tensorrt.runtime._pre_allocated_outputs import enable_pre_allocated_outputs from torch_tensorrt.runtime._weight_streaming import weight_streaming diff --git a/py/torch_tensorrt/runtime/_output_allocator.py b/py/torch_tensorrt/runtime/_output_allocator.py new file mode 100644 index 0000000000..4b0d644231 --- /dev/null +++ b/py/torch_tensorrt/runtime/_output_allocator.py @@ -0,0 +1,51 @@ +import logging +from typing import Any, Union + +import torch +from torch_tensorrt.dynamo.runtime import PythonTorchTensorRTModule, TorchTensorRTModule +from torch_tensorrt.dynamo.runtime._CudaGraphsTorchTensorRTModule import ( + CudaGraphsTorchTensorRTModule, +) + +logger = logging.getLogger(__name__) + + +class _OutputAllocatorContextManager(object): + """ + Helper class to set up output_allocator + """ + + def __init__( + self, module: Union[torch.fx.GraphModule, CudaGraphsTorchTensorRTModule] + ) -> None: + if isinstance(module, CudaGraphsTorchTensorRTModule): + rt_mods = [module] + else: + rt_mods = [] + + for name, rt_mod in module.named_children(): + if "_run_on_acc" in name and isinstance( + rt_mod, (PythonTorchTensorRTModule, TorchTensorRTModule) + ): + rt_mods.append(rt_mod) + + self.rt_mods = rt_mods + + def set_output_allocator_output(self, enable: bool) -> None: + for mod in self.rt_mods: + mod.set_output_allocator_outputs(enable) + + def __enter__(self) -> "_OutputAllocatorContextManager": + # Enable output_allocator for TRT submodules + self.set_output_allocator_output(True) + return self + + def __exit__(self, *args: Any) -> None: + # Disable output_allocator + self.set_output_allocator_output(False) + + +def enable_output_allocator( + module: torch.fx.GraphModule, +) -> _OutputAllocatorContextManager: + return _OutputAllocatorContextManager(module) diff --git a/tests/py/dynamo/conversion/test_nonzero_aten.py b/tests/py/dynamo/conversion/test_nonzero_aten.py index 4479993780..f2c5123575 100644 --- a/tests/py/dynamo/conversion/test_nonzero_aten.py +++ b/tests/py/dynamo/conversion/test_nonzero_aten.py @@ -17,7 +17,7 @@ class TestNonZeroConverter(DispatchTestCase): ((2, 3, 4, 5), torch.float), ] ) - def test_non_zero(self, input_shape, dtype): + def test_nonzero_dds(self, input_shape, dtype): class NonZero(nn.Module): # This is a DDS network def forward(self, input): @@ -39,7 +39,7 @@ def forward(self, input): ((2, 3, 4, 5), torch.float), ] ) - def test_non_zero(self, input_shape, dtype): + def test_nonzero_non_dds(self, input_shape, dtype): class NonZero(nn.Module): # This is a static network def forward(self, input): @@ -78,7 +78,7 @@ def forward(self, input): ), ] ) - def test_nonzero_dynamic_shape(self, _, min_shape, opt_shape, max_shape, dtype): + def test_nonzero_dynamic_shape_dds(self, _, min_shape, opt_shape, max_shape, dtype): class NonZero(nn.Module): def forward(self, input): return torch.ops.aten.nonzero.default(input) @@ -94,6 +94,51 @@ def forward(self, input): self.run_test_with_dynamic_shape(NonZero(), input_specs) + @parameterized.expand( + [ + ( 
+ "1d", + (1,), + (10,), + (100,), + torch.int32, + ), + ( + "2d", + (1, 2), + (5, 10), + (20, 40), + torch.float16, + ), + ( + "3d", + (1, 2, 3), + (5, 10, 20), + (30, 40, 50), + torch.float, + ), + ] + ) + def test_nonzero_dynamic_shape_non_dds( + self, _, min_shape, opt_shape, max_shape, dtype + ): + class NonZero(nn.Module): + def forward(self, input): + out = torch.ops.aten.nonzero.default(input) + out = torch.ops.aten.sum.dim_IntList(out, 0) + return out + + input_specs = [ + Input( + min_shape=min_shape, + opt_shape=opt_shape, + max_shape=max_shape, + dtype=dtype, + ), + ] + + self.run_test_with_dynamic_shape(NonZero(), input_specs) + if __name__ == "__main__": run_tests() diff --git a/tests/py/dynamo/runtime/test_output_allocator_py.py b/tests/py/dynamo/runtime/test_output_allocator_py.py new file mode 100644 index 0000000000..98c9fcf155 --- /dev/null +++ b/tests/py/dynamo/runtime/test_output_allocator_py.py @@ -0,0 +1,375 @@ +import pytest +import torch +import torch_tensorrt +from torch.testing._internal.common_utils import TestCase, run_tests + +from ..testing_utilities import DECIMALS_OF_AGREEMENT + +INPUT_SIZE = (3, 16, 16) +TRIALS = 5 + + +class StaticModel(torch.nn.Module): + def forward(self, input): + return torch.ops.aten.abs.default(input) + + +class DDSModel(torch.nn.Module): + def forward(self, input): + return torch.ops.aten.nonzero.default(input) + + +class NonDDSModel(torch.nn.Module): + def forward(self, inputs): + out = torch.ops.aten.nonzero.default(inputs) + out = torch.ops.aten.sum.dim_IntList(out, 0) + return out + + +class TestOutputAllocatorPythonStaticModel(TestCase): + def test_cudagraphs(self): + model = StaticModel().eval().cuda() + inputs = [torch.randn((2, 3), dtype=torch.float).cuda()] + compiled_model = torch_tensorrt.compile( + model, + "dynamo", + inputs, + min_block_size=1, + use_python_runtime=True, + ) + + ref_out = model(*inputs) + + with torch_tensorrt.runtime.enable_cudagraphs( + compiled_model + ) as cudagraphs_module: + cg_out = cudagraphs_module(*inputs) + + self.assertAlmostEqual( + float(torch.max(torch.abs(ref_out - cg_out))), + 0, + DECIMALS_OF_AGREEMENT, + msg="CUDA Graphs Python TRT outputs don't match with the original model.", + ) + + def test_output_allocator(self): + model = StaticModel().eval().cuda() + inputs = [torch.randn((2, 3), dtype=torch.float).cuda()] + compiled_model = torch_tensorrt.compile( + model, + "dynamo", + inputs, + min_block_size=1, + use_python_runtime=True, + ) + + ref_out = model(*inputs) + + with torch_tensorrt.runtime.enable_output_allocator(compiled_model): + oa_out = compiled_model(*inputs) + + self.assertAlmostEqual( + float(torch.max(torch.abs(ref_out - oa_out))), + 0, + DECIMALS_OF_AGREEMENT, + msg="Output Allocator Python TRT outputs don't match with the original model.", + ) + + def test_default(self): + """ + Static models use standard execution with cudagraphs=False by default. + """ + model = StaticModel().eval().cuda() + inputs = [torch.randn((2, 3), dtype=torch.float).cuda()] + compiled_model = torch_tensorrt.compile( + model, + "dynamo", + inputs, + min_block_size=1, + use_python_runtime=True, + ) + standard_out = compiled_model(*inputs) + ref_out = model(*inputs) + + self.assertAlmostEqual( + float(torch.max(torch.abs(ref_out - standard_out))), + 0, + DECIMALS_OF_AGREEMENT, + msg="Default standard execution outputs don't match with the original model.", + ) + + @pytest.mark.xfail( + strict=True, + raises=RuntimeError, + reason="Both CUDA Graphs and OutputAllocator are enabled. 
Please disable either one.", + ) + def test_cudagraphs_and_output_allocator(self): + model = StaticModel().eval().cuda() + inputs = [torch.randn((2, 3), dtype=torch.float).cuda()] + compiled_model = torch_tensorrt.compile( + model, + "dynamo", + inputs, + min_block_size=1, + use_python_runtime=True, + ) + + with torch_tensorrt.runtime.enable_cudagraphs( + compiled_model + ) as cudagraphs_module: + with torch_tensorrt.runtime.enable_output_allocator(cudagraphs_module): + out = cudagraphs_module(*inputs) + + @pytest.mark.xfail( + strict=True, + raises=RuntimeError, + reason="Both CUDA Graphs and OutputAllocator are enabled. Please disable either one.", + ) + def test_output_allocator_and_cudagraphs(self): + model = StaticModel().eval().cuda() + inputs = [torch.randn((2, 3), dtype=torch.float).cuda()] + compiled_model = torch_tensorrt.compile( + model, + "dynamo", + inputs, + min_block_size=1, + use_python_runtime=True, + ) + with torch_tensorrt.runtime.enable_output_allocator(compiled_model): + with torch_tensorrt.runtime.enable_cudagraphs( + compiled_model + ) as cudagraphs_module: + out = cudagraphs_module(*inputs) + + +class TestOutputAllocatorPythonDDSModel(TestCase): + @pytest.mark.xfail( + strict=True, + raises=RuntimeError, + reason="The module is Data-Dependent Shape (DDS). It has to be handled by OutputAllocator which is not compatible with CUDA Graphs. Please disable CUDA Graphs.", + ) + def test_cudagraphs(self): + model = DDSModel().eval().cuda() + inputs = (torch.randint(low=0, high=3, size=(10,), dtype=torch.int).to("cuda"),) + compiled_model = torch_tensorrt.compile( + model, + "dynamo", + inputs, + min_block_size=1, + use_python_runtime=True, + ) + + with torch_tensorrt.runtime.enable_cudagraphs( + compiled_model + ) as cudagraphs_module: + cg_out = cudagraphs_module(*inputs) + + def test_output_allocator(self): + model = DDSModel().eval().cuda() + inputs = (torch.randint(low=0, high=3, size=(10,), dtype=torch.int).to("cuda"),) + compiled_model = torch_tensorrt.compile( + model, + "dynamo", + inputs, + min_block_size=1, + use_python_runtime=True, + ) + + ref_out = model(*inputs) + + with torch_tensorrt.runtime.enable_output_allocator(compiled_model): + oa_out = compiled_model(*inputs) + + self.assertAlmostEqual( + float(torch.max(torch.abs(ref_out - oa_out))), + 0, + DECIMALS_OF_AGREEMENT, + msg="Output Allocator Python TRT outputs don't match with the original model.", + ) + + def test_default(self): + """ + DDS models use OutputAllocator by default. + """ + model = DDSModel().eval().cuda() + inputs = (torch.randint(low=0, high=3, size=(10,), dtype=torch.int).to("cuda"),) + compiled_model = torch_tensorrt.compile( + model, + "dynamo", + inputs, + min_block_size=1, + use_python_runtime=True, + ) + oa_out = compiled_model(*inputs) + ref_out = model(*inputs) + + self.assertAlmostEqual( + float(torch.max(torch.abs(ref_out - oa_out))), + 0, + DECIMALS_OF_AGREEMENT, + msg="Default Output Allocator Python TRT outputs don't match with the original model.", + ) + + @pytest.mark.xfail( + strict=True, + raises=RuntimeError, + reason="Both CUDA Graphs and OutputAllocator are enabled. 
Please disable either one.", + ) + def test_cudagraphs_and_output_allocator(self): + model = DDSModel().eval().cuda() + inputs = (torch.randint(low=0, high=3, size=(10,), dtype=torch.int).to("cuda"),) + compiled_model = torch_tensorrt.compile( + model, + "dynamo", + inputs, + min_block_size=1, + use_python_runtime=True, + ) + + with torch_tensorrt.runtime.enable_cudagraphs( + compiled_model + ) as cudagraphs_module: + with torch_tensorrt.runtime.enable_output_allocator(cudagraphs_module): + out = cudagraphs_module(*inputs) + + @pytest.mark.xfail( + strict=True, + raises=RuntimeError, + reason="Both CUDA Graphs and OutputAllocator are enabled. Please disable either one.", + ) + def test_output_allocator_and_cudagraphs(self): + model = DDSModel().eval().cuda() + inputs = (torch.randint(low=0, high=3, size=(10,), dtype=torch.int).to("cuda"),) + compiled_model = torch_tensorrt.compile( + model, + "dynamo", + inputs, + min_block_size=1, + use_python_runtime=True, + ) + with torch_tensorrt.runtime.enable_output_allocator(compiled_model): + with torch_tensorrt.runtime.enable_cudagraphs( + compiled_model + ) as cudagraphs_module: + out = cudagraphs_module(*inputs) + + +class TestOutputAllocatorPythonNonDDSModel(TestCase): + @pytest.mark.skip(reason="NonDDSModel is currently not supported in CUDA Graphs.") + def test_cudagraphs(self): + model = NonDDSModel().eval().cuda() + inputs = (torch.randint(low=0, high=3, size=(10,), dtype=torch.int).to("cuda"),) + compiled_model = torch_tensorrt.compile( + model, + "dynamo", + inputs, + min_block_size=1, + use_python_runtime=True, + ) + + ref_out = model(*inputs) + + with torch_tensorrt.runtime.enable_cudagraphs( + compiled_model + ) as cudagraphs_module: + cg_out = cudagraphs_module(*inputs) + + self.assertAlmostEqual( + float(torch.max(torch.abs(ref_out - cg_out))), + 0, + DECIMALS_OF_AGREEMENT, + msg="CUDA Graphs Python TRT outputs don't match with the original model.", + ) + + def test_output_allocator(self): + model = NonDDSModel().eval().cuda() + inputs = (torch.randint(low=0, high=3, size=(10,), dtype=torch.int).to("cuda"),) + compiled_model = torch_tensorrt.compile( + model, + "dynamo", + inputs, + min_block_size=1, + use_python_runtime=True, + ) + + ref_out = model(*inputs) + + with torch_tensorrt.runtime.enable_output_allocator(compiled_model): + oa_out = compiled_model(*inputs) + + self.assertAlmostEqual( + float(torch.max(torch.abs(ref_out - oa_out))), + 0, + DECIMALS_OF_AGREEMENT, + msg="Output Allocator Python TRT outputs don't match with the original model.", + ) + + def test_default(self): + """ + NonDDS models use standard execution with cudagraphs=False by default. + """ + model = DDSModel().eval().cuda() + inputs = (torch.randint(low=0, high=3, size=(10,), dtype=torch.int).to("cuda"),) + compiled_model = torch_tensorrt.compile( + model, + "dynamo", + inputs, + min_block_size=1, + use_python_runtime=True, + ) + standard_out = compiled_model(*inputs) + ref_out = model(*inputs) + + self.assertAlmostEqual( + float(torch.max(torch.abs(ref_out - standard_out))), + 0, + DECIMALS_OF_AGREEMENT, + msg="Default standard execution outputs don't match with the original model.", + ) + + @pytest.mark.xfail( + strict=True, + raises=RuntimeError, + reason="Both CUDA Graphs and OutputAllocator are enabled. 
Please disable either one.", + ) + def test_cudagraphs_and_output_allocator(self): + model = NonDDSModel().eval().cuda() + inputs = (torch.randint(low=0, high=3, size=(10,), dtype=torch.int).to("cuda"),) + compiled_model = torch_tensorrt.compile( + model, + "dynamo", + inputs, + min_block_size=1, + use_python_runtime=True, + ) + + with torch_tensorrt.runtime.enable_cudagraphs( + compiled_model + ) as cudagraphs_module: + with torch_tensorrt.runtime.enable_output_allocator(cudagraphs_module): + out = cudagraphs_module(*inputs) + + @pytest.mark.xfail( + strict=True, + raises=RuntimeError, + reason="Both CUDA Graphs and OutputAllocator are enabled. Please disable either one.", + ) + def test_output_allocator_and_cudagraphs(self): + model = NonDDSModel().eval().cuda() + inputs = (torch.randint(low=0, high=3, size=(10,), dtype=torch.int).to("cuda"),) + compiled_model = torch_tensorrt.compile( + model, + "dynamo", + inputs, + min_block_size=1, + use_python_runtime=True, + ) + with torch_tensorrt.runtime.enable_output_allocator(compiled_model): + with torch_tensorrt.runtime.enable_cudagraphs( + compiled_model + ) as cudagraphs_module: + out = cudagraphs_module(*inputs) + + +if __name__ == "__main__": + run_tests() From 11890ffe09bbdef9704631aa39adf755f7f72045 Mon Sep 17 00:00:00 2001 From: Evan Li Date: Thu, 27 Feb 2025 21:56:08 -0800 Subject: [PATCH 06/17] refactor --- py/torch_tensorrt/dynamo/_engine_cache.py | 8 +- .../dynamo/conversion/_ConversionContext.py | 2 + .../dynamo/conversion/_ConverterRegistry.py | 21 ++- .../dynamo/conversion/_TRTInterpreter.py | 34 ++--- .../dynamo/conversion/_conversion.py | 2 +- .../dynamo/conversion/aten_ops_converters.py | 6 +- .../dynamo/partitioning/common.py | 6 +- .../runtime/_CudaGraphsTorchTensorRTModule.py | 5 - .../runtime/_PythonTorchTensorRTModule.py | 18 +-- .../dynamo/runtime/_TorchTensorRTModule.py | 10 +- py/torch_tensorrt/runtime/_cudagraphs.py | 11 +- tests/py/dynamo/conversion/harness.py | 4 +- .../runtime/test_output_allocator_py.py | 140 +++++++++++++++++- 13 files changed, 202 insertions(+), 65 deletions(-) diff --git a/py/torch_tensorrt/dynamo/_engine_cache.py b/py/torch_tensorrt/dynamo/_engine_cache.py index 2355b13469..3baef500ef 100644 --- a/py/torch_tensorrt/dynamo/_engine_cache.py +++ b/py/torch_tensorrt/dynamo/_engine_cache.py @@ -107,7 +107,7 @@ def pack( input_specs: Sequence[Input], compilation_settings: CompilationSettings, weight_name_map: Optional[Dict[Any, Any]], - engine_is_dds: bool, + requires_output_allocator: bool, ) -> bytes: """Pack serialized engine, input names, output names, and weight map into a single blob @@ -118,7 +118,7 @@ def pack( input_specs (Sequence[Input]): input specs of TRT engine compilation_settings (CompilationSettings): compilation settings of TRT engine weight_name_map (Optional[Dict[Any, Any]]): weight name map for refitting - engine_is_dds (bool): whether the engine is data-dependent shape + requires_output_allocator (bool): whether the engine requires output allocator Returns: bytes: packed blob """ @@ -132,7 +132,7 @@ def pack( "input_specs": input_specs, "compilation_settings": settings, "weight_name_map": weight_name_map, - "engine_is_dds": engine_is_dds, + "requires_output_allocator": requires_output_allocator, } ) @@ -154,7 +154,7 @@ def unpack(packed_obj: bytes) -> UnpackedCacheHit: unpacked["input_specs"], unpacked["compilation_settings"], unpacked["weight_name_map"], - unpacked["engine_is_dds"], + unpacked["requires_output_allocator"], ) def insert( diff --git 
a/py/torch_tensorrt/dynamo/conversion/_ConversionContext.py b/py/torch_tensorrt/dynamo/conversion/_ConversionContext.py index 37581f76cd..691f38281e 100644 --- a/py/torch_tensorrt/dynamo/conversion/_ConversionContext.py +++ b/py/torch_tensorrt/dynamo/conversion/_ConversionContext.py @@ -11,9 +11,11 @@ class ConversionContext: Args: net: TensorRT Network being built compilation_settings: Settings selected by the user for compilation + requires_output_allocator: Whether the network requires output allocator """ net: TRTNetwork compilation_settings: CompilationSettings = field( default_factory=CompilationSettings ) + requires_output_allocator: bool = False diff --git a/py/torch_tensorrt/dynamo/conversion/_ConverterRegistry.py b/py/torch_tensorrt/dynamo/conversion/_ConverterRegistry.py index 1efacea619..a518f7d80b 100644 --- a/py/torch_tensorrt/dynamo/conversion/_ConverterRegistry.py +++ b/py/torch_tensorrt/dynamo/conversion/_ConverterRegistry.py @@ -18,6 +18,7 @@ cast, ) +import tensorrt as trt import torch from torch import SymBool, SymFloat, SymInt from torch._ops import OpOverloadPacket @@ -26,8 +27,6 @@ from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext from torch_tensorrt.fx.converter_registry import CONVERTERS as FX_CONVERTERS -import tensorrt as trt - logger = logging.getLogger(__name__) LegacyConverterImplSignature = Callable[ @@ -81,6 +80,7 @@ class ConverterSupport: whether that node can be supported by its companion converter. Note that this function must not modify the node or its graph supports_dynamic_shapes: Boolean flag indicating if the converter has support for dynamic inputs. + requires_output_allocator: Boolean flag indicating if the converter requires to run in output allocator. """ converter_implementation: ConverterImplSignature @@ -88,6 +88,7 @@ class ConverterSupport: default=lambda node, compilation_settings: True ) supports_dynamic_shapes: bool = False + requires_output_allocator: bool = False # Dictionary representing Dynamo aten-only converters @@ -197,6 +198,7 @@ def dynamo_tensorrt_converter( capability_validator: Optional[Callable[[Node, CompilationSettings], bool]] = None, priority: ConverterPriority = ConverterPriority.STANDARD, supports_dynamic_shapes: bool = False, + requires_output_allocator: bool = False, ) -> Callable[[ConverterImplSignature], ConverterImplSignature]: """Decorator for Dynamo TensorRT Converter @@ -212,6 +214,8 @@ def dynamo_tensorrt_converter( this means all nodes of "key" kind can be supported by this converter priority: Converter's level of priority relative to other converters with the same target + supports_dynamic_shapes: Boolean flag indicating if the converter has support for dynamic shapes. + requires_output_allocator: Boolean flag indicating if the converter requires to run in output allocator. 
Returns: The converter being decorated """ @@ -225,6 +229,7 @@ def register_converter(converter: ConverterImplSignature) -> ConverterImplSignat converter_support = ConverterSupport( converter_implementation=converter, supports_dynamic_shapes=supports_dynamic_shapes, + requires_output_allocator=requires_output_allocator, ) else: assert callable( @@ -234,6 +239,7 @@ def register_converter(converter: ConverterImplSignature) -> ConverterImplSignat converter_implementation=converter, capability_validator=capability_validator, supports_dynamic_shapes=supports_dynamic_shapes, + requires_output_allocator=requires_output_allocator, ) # OpOverloadPackets are only valid if they have a single overload, or @@ -404,7 +410,7 @@ def __getitem_without_validation__( def __getitem__( self, node: Node ) -> Tuple[ - Any, CallingConvention + Any, CallingConvention, bool ]: # TODO: Narrow to ConverterImplSignature this when we can remove FX converters """Get the first-found validated converter in any registry @@ -462,6 +468,7 @@ def __getitem__( return ( candidate.converter_implementation, calling_convention, + candidate.requires_output_allocator, ) else: logger.debug( @@ -471,7 +478,11 @@ def __getitem__( else: # Assuming FX converters don't have dynamic shapes supported if not node_has_dynamic_shapes(node): - return converters, calling_convention + return ( + converters, + calling_convention, + candidate.requires_output_allocator, + ) raise KeyError( f"None of the converter registries have a validated entry for {key}, with node {node}" @@ -495,7 +506,7 @@ def get_unvalidated( def get( self, node: Node, value: Optional[ConverterImplSignature] = None ) -> Union[ - Any, Tuple[Any, CallingConvention] + Any, Tuple[Any, CallingConvention, bool] ]: # TODO: Narrow to ConverterImplSignature this when we can remove FX converters """Get validated converter for input node with a default return""" try: diff --git a/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py b/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py index f8ac22ecf8..c34495b8c2 100644 --- a/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py +++ b/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py @@ -64,7 +64,7 @@ class TRTInterpreterResult(NamedTuple): input_names: Sequence[str] output_names: Sequence[str] weight_name_map: Optional[dict[Any, Any]] - engine_is_dds: bool + requires_output_allocator: bool class TRTInterpreter(torch.fx.Interpreter): # type: ignore[misc] @@ -139,9 +139,6 @@ def __init__( # Engine cache for storing and reusing TRT engines self.engine_cache = engine_cache - # Whether the engine is data-dependent shape (dds) - self.engine_is_dds: bool = False - def validate_conversion(self) -> Set[str]: missing_converters: Set[str] = set() @@ -586,7 +583,7 @@ def _insert_engine_to_cache(self, hash_val: str, serialized_engine: bytes) -> No self.input_specs, self.compilation_settings, self.weight_name_map, - self.engine_is_dds, + self.ctx.requires_output_allocator, ), ) @@ -601,7 +598,7 @@ def _pull_cached_engine(self, hash_val: str) -> Optional[TRTInterpreterResult]: cached_engine_input_specs, engine_compilation_settings, self.weight_name_map, - self.engine_is_dds, + self.ctx.requires_output_allocator, ) = cached_data setting_compatiblity, incompattible_settings = settings_are_compatible( @@ -663,20 +660,10 @@ def _pull_cached_engine(self, hash_val: str) -> Optional[TRTInterpreterResult]: self._input_names, self._output_names, self.weight_name_map, - self.engine_is_dds, + self.ctx.requires_output_allocator, ) return None - def 
check_dds(self, serialized_engine: bytes, output_names: List[str]) -> bool: - runtime = trt.Runtime(TRT_LOGGER) - engine = runtime.deserialize_cuda_engine(serialized_engine) - - for output_name in output_names: - output_shape = engine.get_tensor_shape(output_name) - if -1 in output_shape: - return True - return False - def run( self, strict_type_constraints: bool = False, @@ -733,8 +720,6 @@ def run( ) assert serialized_engine - self.engine_is_dds = self.check_dds(serialized_engine, self._output_names) - _LOGGER.info( f"Build TRT engine elapsed time: {datetime.now() - build_engine_start_time}" ) @@ -761,7 +746,7 @@ def run( self._input_names, self._output_names, self.weight_name_map, - self.engine_is_dds, + self.ctx.requires_output_allocator, ) def run_node(self, n: torch.fx.Node) -> torch.fx.Node: @@ -855,7 +840,7 @@ def call_module( f"Conversion of module of type {submod_type} not currently supported!" ) - converter, calling_convention = converter_packet + converter, calling_convention, requires_output_allocator = converter_packet assert self._cur_node_name is not None @@ -872,7 +857,10 @@ def call_function(self, target: str, args: Any, kwargs: Any) -> Any: f"Conversion of function {torch.typename(target)} not currently supported!" ) - converter, calling_convention = converter_packet + converter, calling_convention, requires_output_allocator = converter_packet + if requires_output_allocator: + self.ctx.requires_output_allocator = True + _LOGGER.debug(f"{target} requires output allocator") if calling_convention is CallingConvention.LEGACY: return converter(self.ctx.net, target, args, kwargs, self._cur_node_name) @@ -902,7 +890,7 @@ def call_method(self, target: str, args: Any, kwargs: Any) -> Any: raise UnsupportedOperatorException( f"Conversion of method {target} not currently supported!" 
) - converter, calling_convention = converter_packet + converter, calling_convention, requires_output_allocator = converter_packet if calling_convention is CallingConvention.LEGACY: return converter(self.ctx.net, target, args, kwargs, self._cur_node_name) diff --git a/py/torch_tensorrt/dynamo/conversion/_conversion.py b/py/torch_tensorrt/dynamo/conversion/_conversion.py index 82c3ceb48e..adb7039e7e 100644 --- a/py/torch_tensorrt/dynamo/conversion/_conversion.py +++ b/py/torch_tensorrt/dynamo/conversion/_conversion.py @@ -112,5 +112,5 @@ def convert_module( name=name, settings=settings, weight_name_map=interpreter_result.weight_name_map, - engine_is_dds=interpreter_result.engine_is_dds, + requires_output_allocator=interpreter_result.requires_output_allocator, ) diff --git a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py index 8b5932a19c..e4b5d732b6 100644 --- a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py +++ b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py @@ -3554,7 +3554,11 @@ def aten_ops_full( ) -@dynamo_tensorrt_converter(torch.ops.aten.nonzero.default, supports_dynamic_shapes=True) +@dynamo_tensorrt_converter( + torch.ops.aten.nonzero.default, + supports_dynamic_shapes=True, + requires_output_allocator=True, +) def aten_ops_nonzero( ctx: ConversionContext, target: Target, diff --git a/py/torch_tensorrt/dynamo/partitioning/common.py b/py/torch_tensorrt/dynamo/partitioning/common.py index 45c3508458..685ec6ebef 100644 --- a/py/torch_tensorrt/dynamo/partitioning/common.py +++ b/py/torch_tensorrt/dynamo/partitioning/common.py @@ -31,8 +31,10 @@ def construct_dynamic_input( if isinstance(dim, torch.SymInt): min_max_opt = extract_var_range_info(dim) min_shape.append(min_max_opt["min"]) - # opt might not exist - opt_shape.append(min_max_opt.get("opt")) + # if opt does not exist, set it to the mean of min and max + opt_shape.append( + min_max_opt.get("opt", int((min_max_opt["min"] + min_max_opt["max"]) / 2)) + ) max_shape.append(min_max_opt["max"]) else: min_shape.append(dim) diff --git a/py/torch_tensorrt/dynamo/runtime/_CudaGraphsTorchTensorRTModule.py b/py/torch_tensorrt/dynamo/runtime/_CudaGraphsTorchTensorRTModule.py index 2163d237fe..7ec8930983 100644 --- a/py/torch_tensorrt/dynamo/runtime/_CudaGraphsTorchTensorRTModule.py +++ b/py/torch_tensorrt/dynamo/runtime/_CudaGraphsTorchTensorRTModule.py @@ -79,11 +79,6 @@ def set_output_allocator_outputs(self, enable: bool) -> None: def forward(self, *inputs: torch.Tensor) -> torch.Tensor | Tuple[torch.Tensor, ...]: cudagraphs_enabled = torch_tensorrt.runtime.get_whole_cudagraphs_mode() - if cudagraphs_enabled and self.use_output_allocator_outputs: - raise RuntimeError( - "There are non-TRT submodules in the module. OutputAllocator is not compatible with modules with non-TRT submodules."
- ) - if cudagraphs_enabled: shape_changed = self.validate_input_shapes(inputs) need_cudagraphs_record = shape_changed or self.is_weight_streaming_set diff --git a/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py b/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py index f78f89ee9e..0ac81c9baf 100644 --- a/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py +++ b/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py @@ -127,7 +127,7 @@ def __init__( name: str = "", settings: CompilationSettings = CompilationSettings(), weight_name_map: Optional[dict[Any, Any]] = None, - engine_is_dds: bool = False, + requires_output_allocator: bool = False, ): """Takes a name, target device, serialized TensorRT engine, and binding names / order and constructs a PyTorch ``torch.nn.Module`` around it. Uses TensorRT Python APIs to run the engine @@ -141,7 +141,7 @@ def __init__( name (str): Name for module settings (torch_tensorrt.dynamo.CompilationSettings): Settings used to compile engine, assumes engine was built with default compilation settings if object not passed weight_name_map (dict): Mapping of engine weight name to state_dict weight name - engine_is_dds (bool): Whether the engine is Data Dependent Shape + requires_output_allocator (bool): Whether the engine requires an output allocator Example: @@ -206,7 +206,7 @@ def __init__( self.pre_allocated_outputs: List[torch.Tensor] = [] self.use_pre_allocated_outputs = False - self.engine_is_dds = engine_is_dds + self.requires_output_allocator = requires_output_allocator self.output_allocator: Optional[DynamicOutputAllocator] = None self.use_output_allocator_outputs = False @@ -281,7 +281,7 @@ def setup_engine(self) -> None: for output_name in self.output_names ] - if self.engine_is_dds: + if self.requires_output_allocator: self.create_output_allocator() if torch_tensorrt.runtime.get_cudagraphs_mode(): @@ -678,14 +678,12 @@ def run_output_allocator() -> torch.Tensor | Tuple[torch.Tensor, ...]: ] logger.warning(f"Moved all input Tensors to cuda:{device_id}") - if self.engine_is_dds: + if self.requires_output_allocator: if self.cudagraphs_enabled: raise RuntimeError( - "The module is Data-Dependent Shape (DDS). It has to be handled by OutputAllocator which is not compatible with CUDA Graphs. Please disable CUDA Graphs." + "This module requires OutputAllocator which is not compatible with CUDA Graphs. Please disable CUDA Graphs." ) - logger.debug( - "The module is Data-Dependent Shape (DDS). Using output allocator." - ) + logger.debug("Using OutputAllocator in runtime.") return run_output_allocator() else: if self.cudagraphs_enabled and self.use_output_allocator_outputs: @@ -693,7 +691,7 @@ def run_output_allocator() -> torch.Tensor | Tuple[torch.Tensor, ...]: "Both CUDA Graphs and OutputAllocator are enabled. Please disable either one." ) if self.use_output_allocator_outputs: - logger.debug("Using output allocator.") + logger.debug("Using OutputAllocator in runtime.") return run_output_allocator() logger.debug( f"Using standard execution with cudagraphs={self.cudagraphs_enabled}." 
diff --git a/py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py b/py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py index ae0cc3cec6..ca9a9a9f25 100644 --- a/py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py +++ b/py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py @@ -79,7 +79,7 @@ def __init__( name: str = "", settings: CompilationSettings = CompilationSettings(), # Assumes engine was built with default compilation settings if object not passed weight_name_map: Optional[dict[Any, Any]] = None, - engine_is_dds: bool = False, + requires_output_allocator: bool = False, ): """Takes a name, target device, serialized TensorRT engine, and binding names / order and constructs a PyTorch ``torch.nn.Module`` around it. Uses the Torch-TensorRT runtime extension to run the engines @@ -98,7 +98,7 @@ def __init__( name (str): Name for module settings (torch_tensorrt.dynamo.CompilationSettings): Settings used to compile engine, assumes engine was built with default compilation settings if object not passed weight_name_map (dict): Mapping of engine weight name to state_dict weight name - engine_is_dds (bool): Whether the engine is Data Dependent Shape + requires_output_allocator (bool): Whether the engine requires an output allocator Example: @@ -134,7 +134,7 @@ def __init__( self.weight_name_map = weight_name_map self.serialized_engine = serialized_engine self.engine = None - self.engine_is_dds = engine_is_dds + self.requires_output_allocator = requires_output_allocator if ( serialized_engine @@ -152,7 +152,7 @@ def _pack_engine_info(self) -> List[str | bytes]: metadata = { "settings": self.settings, "weight_name_map": self.weight_name_map, - "engine_is_dds": self.engine_is_dds, + "requires_output_allocator": self.requires_output_allocator, } target_platform = ( Platform.current_platform() @@ -270,7 +270,7 @@ def set_extra_state(self, state: SerializedTorchTensorRTModuleFmt) -> None: metadata = TorchTensorRTModule.decode_metadata(serialized_metadata) self.settings = metadata["settings"] self.weight_name_map = metadata["weight_name_map"] - self.engine_is_dds = metadata["engine_is_dds"] + self.requires_output_allocator = metadata["requires_output_allocator"] else: self.engine = None diff --git a/py/torch_tensorrt/runtime/_cudagraphs.py b/py/torch_tensorrt/runtime/_cudagraphs.py index d1564cb4dc..d3c4ea2a05 100644 --- a/py/torch_tensorrt/runtime/_cudagraphs.py +++ b/py/torch_tensorrt/runtime/_cudagraphs.py @@ -74,13 +74,20 @@ def __enter__(self) -> torch.nn.Module: num_torch_module = 0 num_trt_module = 0 - for name, _ in self.compiled_module.named_children(): + disable_cudagraphs = False + for name, module in self.compiled_module.named_children(): + # disable cudagraphs if any model requires output allocator + if ( + hasattr(module, "requires_output_allocator") + and module.requires_output_allocator + ): + disable_cudagraphs = True if "_run_on_acc" in name: num_trt_module += 1 elif "_run_on_gpu" in name: num_torch_module += 1 - if num_torch_module > 0: + if num_torch_module > 0 and not disable_cudagraphs: # Set whole cudagraphs mode and returns wrapped module _PY_RT_CUDAGRAPHS = CudaGraphsMode.WHOLE_GRAPH_CUDAGRAPHS # Set new mode for C++ diff --git a/tests/py/dynamo/conversion/harness.py b/tests/py/dynamo/conversion/harness.py index 3e17273b98..6ff45507a0 100644 --- a/tests/py/dynamo/conversion/harness.py +++ b/tests/py/dynamo/conversion/harness.py @@ -207,7 +207,7 @@ def run_test( input_binding_names=list(interpreter_result.input_names), 
output_binding_names=list(interpreter_result.output_names), name="test_engine", - engine_is_dds=interpreter_result.engine_is_dds, + requires_output_allocator=interpreter_result.requires_output_allocator, ) mod = mod.cuda() if pyt_inputs is not None: @@ -290,7 +290,7 @@ def run_test_custom_compare_results( input_binding_names=list(interpreter_result.input_names), output_binding_names=list(interpreter_result.output_names), name="test_engine", - engine_is_dds=interpreter_result.engine_is_dds, + requires_output_allocator=interpreter_result.requires_output_allocator, ) res_trt = trt_mod(*cuda_inputs).cpu() res_cpu = mod(*cuda_inputs).cpu() diff --git a/tests/py/dynamo/runtime/test_output_allocator_py.py b/tests/py/dynamo/runtime/test_output_allocator_py.py index 98c9fcf155..c0ac064f7a 100644 --- a/tests/py/dynamo/runtime/test_output_allocator_py.py +++ b/tests/py/dynamo/runtime/test_output_allocator_py.py @@ -26,7 +26,16 @@ def forward(self, inputs): return out -class TestOutputAllocatorPythonStaticModel(TestCase): +class DDSModel2(torch.nn.Module): + def forward(self, input): + # combination of multiple non-zero and other ops + out = torch.ops.aten.nonzero.default(input) + out = torch.ops.aten.abs.default(out) + out = torch.ops.aten.nonzero.default(out) + return out + + +class TestOutputAllocatorStaticModelPython(TestCase): def test_cudagraphs(self): model = StaticModel().eval().cuda() inputs = [torch.randn((2, 3), dtype=torch.float).cuda()] @@ -142,11 +151,11 @@ def test_output_allocator_and_cudagraphs(self): out = cudagraphs_module(*inputs) -class TestOutputAllocatorPythonDDSModel(TestCase): +class TestOutputAllocatorDDSModelPython(TestCase): @pytest.mark.xfail( strict=True, raises=RuntimeError, - reason="The module is Data-Dependent Shape (DDS). It has to be handled by OutputAllocator which is not compatible with CUDA Graphs. Please disable CUDA Graphs.", + reason="This module requires OutputAllocator which is not compatible with CUDA Graphs. Please disable CUDA Graphs.", ) def test_cudagraphs(self): model = DDSModel().eval().cuda() @@ -254,8 +263,12 @@ def test_output_allocator_and_cudagraphs(self): out = cudagraphs_module(*inputs) -class TestOutputAllocatorPythonNonDDSModel(TestCase): - @pytest.mark.skip(reason="NonDDSModel is currently not supported in CUDA Graphs.") +class TestOutputAllocatorNonDDSModelPython(TestCase): + @pytest.mark.xfail( + strict=True, + raises=RuntimeError, + reason="This module requires OutputAllocator which is not compatible with CUDA Graphs. Please disable CUDA Graphs.", + ) def test_cudagraphs(self): model = NonDDSModel().eval().cuda() inputs = (torch.randint(low=0, high=3, size=(10,), dtype=torch.int).to("cuda"),) @@ -371,5 +384,122 @@ def test_output_allocator_and_cudagraphs(self): out = cudagraphs_module(*inputs) +class TestOutputAllocatorDDSModelWithGraphBreakPython(TestCase): + @pytest.mark.xfail( + strict=True, + raises=RuntimeError, + reason="This module requires OutputAllocator which is not compatible with CUDA Graphs. 
Please disable CUDA Graphs.", + ) + def test_cudagraphs(self): + model = DDSModel2().eval().cuda() + inputs = (torch.randint(low=0, high=3, size=(10,), dtype=torch.int).to("cuda"),) + compiled_model = torch_tensorrt.compile( + model, + "dynamo", + inputs, + min_block_size=1, + use_python_runtime=True, + torch_executed_ops={"torch.ops.aten.abs.default"}, + ) + + with torch_tensorrt.runtime.enable_cudagraphs( + compiled_model + ) as cudagraphs_module: + cg_out = cudagraphs_module(*inputs) + + def test_output_allocator(self): + model = DDSModel2().eval().cuda() + inputs = (torch.randint(low=0, high=3, size=(10,), dtype=torch.int).to("cuda"),) + compiled_model = torch_tensorrt.compile( + model, + "dynamo", + inputs, + min_block_size=1, + use_python_runtime=True, + torch_executed_ops={"torch.ops.aten.abs.default"}, + ) + + ref_out = model(*inputs) + + with torch_tensorrt.runtime.enable_output_allocator(compiled_model): + oa_out = compiled_model(*inputs) + + self.assertAlmostEqual( + float(torch.max(torch.abs(ref_out - oa_out))), + 0, + DECIMALS_OF_AGREEMENT, + msg="Output Allocator Python TRT outputs don't match with the original model.", + ) + + def test_default(self): + """ + Use OutputAllocator by default. + """ + model = DDSModel2().eval().cuda() + inputs = (torch.randint(low=0, high=3, size=(10,), dtype=torch.int).to("cuda"),) + compiled_model = torch_tensorrt.compile( + model, + "dynamo", + inputs, + min_block_size=1, + use_python_runtime=True, + torch_executed_ops={"torch.ops.aten.abs.default"}, + ) + oa_out = compiled_model(*inputs) + ref_out = model(*inputs) + + self.assertAlmostEqual( + float(torch.max(torch.abs(ref_out - oa_out))), + 0, + DECIMALS_OF_AGREEMENT, + msg="Default Output Allocator Python TRT outputs don't match with the original model.", + ) + + @pytest.mark.xfail( + strict=True, + raises=RuntimeError, + reason="Both CUDA Graphs and OutputAllocator are enabled. Please disable either one.", + ) + def test_cudagraphs_and_output_allocator(self): + model = DDSModel2().eval().cuda() + inputs = (torch.randint(low=0, high=3, size=(10,), dtype=torch.int).to("cuda"),) + compiled_model = torch_tensorrt.compile( + model, + "dynamo", + inputs, + min_block_size=1, + use_python_runtime=True, + torch_executed_ops={"torch.ops.aten.abs.default"}, + ) + + with torch_tensorrt.runtime.enable_cudagraphs( + compiled_model + ) as cudagraphs_module: + with torch_tensorrt.runtime.enable_output_allocator(cudagraphs_module): + out = cudagraphs_module(*inputs) + + @pytest.mark.xfail( + strict=True, + raises=RuntimeError, + reason="Both CUDA Graphs and OutputAllocator are enabled. 
Please disable either one.", + ) + def test_output_allocator_and_cudagraphs(self): + model = DDSModel2().eval().cuda() + inputs = (torch.randint(low=0, high=3, size=(10,), dtype=torch.int).to("cuda"),) + compiled_model = torch_tensorrt.compile( + model, + "dynamo", + inputs, + min_block_size=1, + use_python_runtime=True, + torch_executed_ops={"torch.ops.aten.abs.default"}, + ) + with torch_tensorrt.runtime.enable_output_allocator(compiled_model): + with torch_tensorrt.runtime.enable_cudagraphs( + compiled_model + ) as cudagraphs_module: + out = cudagraphs_module(*inputs) + + if __name__ == "__main__": run_tests() From fa49005a7b7fe9a2af4a9a96d5b579d3d4272b97 Mon Sep 17 00:00:00 2001 From: Evan Li Date: Fri, 28 Feb 2025 17:08:47 -0800 Subject: [PATCH 07/17] remove sym_size lowering pass --- .../lowering/passes/_aten_lowering_pass.py | 28 +++++++-------- .../remove_sym_size_and_constrain_nodes.py | 34 ------------------- 2 files changed, 13 insertions(+), 49 deletions(-) delete mode 100644 py/torch_tensorrt/dynamo/lowering/passes/remove_sym_size_and_constrain_nodes.py diff --git a/py/torch_tensorrt/dynamo/lowering/passes/_aten_lowering_pass.py b/py/torch_tensorrt/dynamo/lowering/passes/_aten_lowering_pass.py index 87acc4bb3b..1398b0687e 100644 --- a/py/torch_tensorrt/dynamo/lowering/passes/_aten_lowering_pass.py +++ b/py/torch_tensorrt/dynamo/lowering/passes/_aten_lowering_pass.py @@ -14,24 +14,22 @@ from .remove_detach import remove_detach from .remove_input_alias_fixing_clones import remove_input_alias_fixing_clones from .remove_num_users_is_0_nodes import remove_num_users_is_0_nodes -from .remove_sym_size_and_constrain_nodes import remove_sym_size_and_constrain_nodes from .repair_input_as_output import repair_input_as_output from .replace_max_pool_with_indices import replace_max_pool_with_indices -pass_list = [ - remove_input_alias_fixing_clones, - constant_fold, - repair_input_as_output, - fuse_prims_broadcast, - replace_max_pool_with_indices, - remove_assert_nodes, - accumulate_fp32_matmul, -] - -if not is_tegra_platform(): - pass_list.append(fuse_distributed_ops) - -ATEN_POST_LOWERING_PASSES = DynamoPassManager.build_from_passlist(pass_list) +ATEN_POST_LOWERING_PASSES = DynamoPassManager.build_from_passlist( + [ + remove_input_alias_fixing_clones, + constant_fold, + repair_input_as_output, + fuse_prims_broadcast, + fuse_distributed_ops, + replace_max_pool_with_indices, + remove_assert_nodes, + accumulate_fp32_matmul, + remove_num_users_is_0_nodes, + ] +) ATEN_PRE_LOWERING_PASSES = DynamoPassManager.build_from_passlist( [ diff --git a/py/torch_tensorrt/dynamo/lowering/passes/remove_sym_size_and_constrain_nodes.py b/py/torch_tensorrt/dynamo/lowering/passes/remove_sym_size_and_constrain_nodes.py deleted file mode 100644 index 4d24be5cac..0000000000 --- a/py/torch_tensorrt/dynamo/lowering/passes/remove_sym_size_and_constrain_nodes.py +++ /dev/null @@ -1,34 +0,0 @@ -import logging - -import torch -from torch_tensorrt.dynamo._settings import CompilationSettings -from torch_tensorrt.dynamo.lowering.passes.pass_utils import ( - clean_up_graph_after_modifications, -) - -logger = logging.getLogger(__name__) - - -def remove_sym_size_and_constrain_nodes( - gm: torch.fx.GraphModule, settings: CompilationSettings -) -> torch.fx.GraphModule: - """Remove aten.sym_size.int and aten.sym_constrain_range_for_size.default ops in the graph""" - count = 0 - for node in gm.graph.nodes: - if node.op == "call_function" and ( - node.target == torch.ops.aten.sym_size.int - or node.target == 
torch.ops.aten.sym_constrain_range_for_size.default - ): - node_input = node.all_input_nodes[0] - node.replace_all_uses_with(node_input) - gm.graph.erase_node(node) - count += 1 - - if count > 0: - gm = clean_up_graph_after_modifications(gm) - - logger.debug( - f"Removed {count} aten.sym_size.int or aten.sym_constrain_range_for_size.default nodes:\n{gm.graph}" - ) - - return gm From 0700ea3710db93bf1c5e3b919538889df0734422 Mon Sep 17 00:00:00 2001 From: Evan Li Date: Fri, 28 Feb 2025 23:20:48 -0800 Subject: [PATCH 08/17] fix bugs from CI --- py/torch_tensorrt/dynamo/conversion/_ConverterRegistry.py | 2 +- .../dynamo/lowering/passes/remove_num_users_is_0_nodes.py | 6 +++++- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/py/torch_tensorrt/dynamo/conversion/_ConverterRegistry.py b/py/torch_tensorrt/dynamo/conversion/_ConverterRegistry.py index a518f7d80b..cead3e14fd 100644 --- a/py/torch_tensorrt/dynamo/conversion/_ConverterRegistry.py +++ b/py/torch_tensorrt/dynamo/conversion/_ConverterRegistry.py @@ -481,7 +481,7 @@ def __getitem__( return ( converters, calling_convention, - candidate.requires_output_allocator, + False, ) raise KeyError( diff --git a/py/torch_tensorrt/dynamo/lowering/passes/remove_num_users_is_0_nodes.py b/py/torch_tensorrt/dynamo/lowering/passes/remove_num_users_is_0_nodes.py index 72cd3b0e53..3151c05cd4 100644 --- a/py/torch_tensorrt/dynamo/lowering/passes/remove_num_users_is_0_nodes.py +++ b/py/torch_tensorrt/dynamo/lowering/passes/remove_num_users_is_0_nodes.py @@ -16,7 +16,11 @@ def remove_num_users_is_0_nodes( output_node = list(gm.graph.nodes)[-1] for node in gm.graph.nodes: - if node != output_node and len(node.users) == 0: + if ( + node != output_node + and len(node.users) == 0 + and len(node.all_input_nodes) > 0 + ): node_input = node.all_input_nodes[0] node.replace_all_uses_with(node_input) gm.graph.erase_node(node) From ff235f854cb963bc6840509b43c1df90f0218afb Mon Sep 17 00:00:00 2001 From: Evan Li Date: Mon, 3 Mar 2025 21:37:56 -0800 Subject: [PATCH 09/17] resolve comments --- py/torch_tensorrt/dynamo/_engine_cache.py | 2 +- .../dynamo/conversion/_ConversionContext.py | 2 +- .../dynamo/conversion/_ConverterRegistry.py | 18 ++++++++++++------ .../dynamo/conversion/_TRTInterpreter.py | 8 ++++---- .../passes/remove_num_users_is_0_nodes.py | 7 +++---- .../runtime/_PythonTorchTensorRTModule.py | 2 +- .../dynamo/runtime/_TorchTensorRTModule.py | 2 +- py/torch_tensorrt/runtime/_cudagraphs.py | 9 +++++---- 8 files changed, 28 insertions(+), 22 deletions(-) diff --git a/py/torch_tensorrt/dynamo/_engine_cache.py b/py/torch_tensorrt/dynamo/_engine_cache.py index 3baef500ef..a6d9a1face 100644 --- a/py/torch_tensorrt/dynamo/_engine_cache.py +++ b/py/torch_tensorrt/dynamo/_engine_cache.py @@ -118,7 +118,7 @@ def pack( input_specs (Sequence[Input]): input specs of TRT engine compilation_settings (CompilationSettings): compilation settings of TRT engine weight_name_map (Optional[Dict[Any, Any]]): weight name map for refitting - requires_output_allocator (bool): whether the engine requires output allocator + requires_output_allocator (bool): Boolean flag indicating if the converter creates operators which require an Output Allocator to run (e.g. 
data dependent operators) Returns: bytes: packed blob """ diff --git a/py/torch_tensorrt/dynamo/conversion/_ConversionContext.py b/py/torch_tensorrt/dynamo/conversion/_ConversionContext.py index 691f38281e..0dbdb2a8f4 100644 --- a/py/torch_tensorrt/dynamo/conversion/_ConversionContext.py +++ b/py/torch_tensorrt/dynamo/conversion/_ConversionContext.py @@ -11,7 +11,7 @@ class ConversionContext: Args: net: TensorRT Network being built compilation_settings: Settings selected by the user for compilation - requires_output_allocator: Whether the network requires output allocator + requires_output_allocator: Boolean flag indicating if the converter creates operators which require an Output Allocator to run (e.g. data dependent operators) """ net: TRTNetwork diff --git a/py/torch_tensorrt/dynamo/conversion/_ConverterRegistry.py b/py/torch_tensorrt/dynamo/conversion/_ConverterRegistry.py index cead3e14fd..eb1692e392 100644 --- a/py/torch_tensorrt/dynamo/conversion/_ConverterRegistry.py +++ b/py/torch_tensorrt/dynamo/conversion/_ConverterRegistry.py @@ -80,7 +80,7 @@ class ConverterSupport: whether that node can be supported by its companion converter. Note that this function must not modify the node or its graph supports_dynamic_shapes: Boolean flag indicating if the converter has support for dynamic inputs. - requires_output_allocator: Boolean flag indicating if the converter requires to run in output allocator. + requires_output_allocator: Boolean flag indicating if the converter creates operators which require an Output Allocator to run (e.g. data dependent operators). """ converter_implementation: ConverterImplSignature @@ -215,7 +215,7 @@ def dynamo_tensorrt_converter( priority: Converter's level of priority relative to other converters with the same target supports_dynamic_shapes: Boolean flag indicating if the converter has support for dynamic shapes. - requires_output_allocator: Boolean flag indicating if the converter requires to run in output allocator. + requires_output_allocator: Boolean flag indicating if the converter creates operators which require an Output Allocator to run (e.g. data dependent operators). 
Returns: The converter being decorated """ @@ -410,7 +410,7 @@ def __getitem_without_validation__( def __getitem__( self, node: Node ) -> Tuple[ - Any, CallingConvention, bool + Any, CallingConvention, Dict[str, bool] ]: # TODO: Narrow to ConverterImplSignature this when we can remove FX converters """Get the first-found validated converter in any registry @@ -468,7 +468,10 @@ def __getitem__( return ( candidate.converter_implementation, calling_convention, - candidate.requires_output_allocator, + { + "supports_dynamic_shapes": candidate.supports_dynamic_shapes, + "requires_output_allocator": candidate.requires_output_allocator, + }, ) else: logger.debug( @@ -481,7 +484,10 @@ def __getitem__( return ( converters, calling_convention, - False, + { + "supports_dynamic_shapes": False, + "requires_output_allocator": False, + }, ) raise KeyError( @@ -506,7 +512,7 @@ def get_unvalidated( def get( self, node: Node, value: Optional[ConverterImplSignature] = None ) -> Union[ - Any, Tuple[Any, CallingConvention, bool] + Any, Tuple[Any, CallingConvention, Dict[str, bool]] ]: # TODO: Narrow to ConverterImplSignature this when we can remove FX converters """Get validated converter for input node with a default return""" try: diff --git a/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py b/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py index c34495b8c2..7f26a7c3e6 100644 --- a/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py +++ b/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py @@ -840,7 +840,7 @@ def call_module( f"Conversion of module of type {submod_type} not currently supported!" ) - converter, calling_convention, requires_output_allocator = converter_packet + converter, calling_convention, _ = converter_packet assert self._cur_node_name is not None @@ -857,8 +857,8 @@ def call_function(self, target: str, args: Any, kwargs: Any) -> Any: f"Conversion of function {torch.typename(target)} not currently supported!" ) - converter, calling_convention, requires_output_allocator = converter_packet - if requires_output_allocator: + converter, calling_convention, converter_info = converter_packet + if converter_info.get("requires_output_allocator", False): self.ctx.requires_output_allocator = True _LOGGER.debug(f"{target} requires output allocator") @@ -890,7 +890,7 @@ def call_method(self, target: str, args: Any, kwargs: Any) -> Any: raise UnsupportedOperatorException( f"Conversion of method {target} not currently supported!" 
) - converter, calling_convention, requires_output_allocator = converter_packet + converter, calling_convention, _ = converter_packet if calling_convention is CallingConvention.LEGACY: return converter(self.ctx.net, target, args, kwargs, self._cur_node_name) diff --git a/py/torch_tensorrt/dynamo/lowering/passes/remove_num_users_is_0_nodes.py b/py/torch_tensorrt/dynamo/lowering/passes/remove_num_users_is_0_nodes.py index 3151c05cd4..2a2c8e9d5e 100644 --- a/py/torch_tensorrt/dynamo/lowering/passes/remove_num_users_is_0_nodes.py +++ b/py/torch_tensorrt/dynamo/lowering/passes/remove_num_users_is_0_nodes.py @@ -13,16 +13,15 @@ def remove_num_users_is_0_nodes( gm: torch.fx.GraphModule, settings: CompilationSettings ) -> torch.fx.GraphModule: """Remove ops that [num_users=0] in the graph""" - output_node = list(gm.graph.nodes)[-1] + nodes = list(gm.graph.nodes) + output_node = nodes[-1] - for node in gm.graph.nodes: + for node in nodes[::-1]: if ( node != output_node and len(node.users) == 0 and len(node.all_input_nodes) > 0 ): - node_input = node.all_input_nodes[0] - node.replace_all_uses_with(node_input) gm.graph.erase_node(node) gm = clean_up_graph_after_modifications(gm) diff --git a/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py b/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py index 0ac81c9baf..82f5214267 100644 --- a/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py +++ b/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py @@ -141,7 +141,7 @@ def __init__( name (str): Name for module settings (torch_tensorrt.dynamo.CompilationSettings): Settings used to compile engine, assumes engine was built with default compilation settings if object not passed weight_name_map (dict): Mapping of engine weight name to state_dict weight name - requires_output_allocator (bool): Whether the engine requires an output allocator + requires_output_allocator (bool): Boolean flag indicating if the converter creates operators which require an Output Allocator to run (e.g. data dependent operators) Example: diff --git a/py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py b/py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py index ca9a9a9f25..e8ad07f2d6 100644 --- a/py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py +++ b/py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py @@ -98,7 +98,7 @@ def __init__( name (str): Name for module settings (torch_tensorrt.dynamo.CompilationSettings): Settings used to compile engine, assumes engine was built with default compilation settings if object not passed weight_name_map (dict): Mapping of engine weight name to state_dict weight name - requires_output_allocator (bool): Whether the engine requires an output allocator + requires_output_allocator (bool): Boolean flag indicating if the converter creates operators which require an Output Allocator to run (e.g. 
data dependent operators) Example: diff --git a/py/torch_tensorrt/runtime/_cudagraphs.py b/py/torch_tensorrt/runtime/_cudagraphs.py index d3c4ea2a05..1802df7a9f 100644 --- a/py/torch_tensorrt/runtime/_cudagraphs.py +++ b/py/torch_tensorrt/runtime/_cudagraphs.py @@ -74,20 +74,21 @@ def __enter__(self) -> torch.nn.Module: num_torch_module = 0 num_trt_module = 0 - disable_cudagraphs = False for name, module in self.compiled_module.named_children(): - # disable cudagraphs if any model requires output allocator + # need to disable cudagraphs if any model requires output allocator if ( hasattr(module, "requires_output_allocator") and module.requires_output_allocator ): - disable_cudagraphs = True + raise RuntimeError( + "There are converters that require Output Allocator. Please disable CUDA Graphs." + ) if "_run_on_acc" in name: num_trt_module += 1 elif "_run_on_gpu" in name: num_torch_module += 1 - if num_torch_module > 0 and not disable_cudagraphs: + if num_torch_module > 0: # Set whole cudagraphs mode and returns wrapped module _PY_RT_CUDAGRAPHS = CudaGraphsMode.WHOLE_GRAPH_CUDAGRAPHS # Set new mode for C++ From 41a81bb79d9d946ea9d0468d0436d81a246eb56e Mon Sep 17 00:00:00 2001 From: Evan Li Date: Mon, 10 Mar 2025 17:03:49 -0700 Subject: [PATCH 10/17] support C++ runtime and add tests --- core/runtime/TRTEngine.cpp | 30 ++ core/runtime/TRTEngine.h | 35 ++ core/runtime/execute_engine.cpp | 412 ++++++++++------ core/runtime/register_jit_hooks.cpp | 2 + core/runtime/runtime.h | 1 + .../runtime/_PythonTorchTensorRTModule.py | 21 +- .../dynamo/runtime/_TorchTensorRTModule.py | 21 +- .../runtime/meta_ops/register_meta_ops.py | 3 + .../runtime/test_output_allocator_py.py | 449 ++++++++---------- 9 files changed, 585 insertions(+), 389 deletions(-) diff --git a/core/runtime/TRTEngine.cpp b/core/runtime/TRTEngine.cpp index 3e10fd7c7d..4b67de7e42 100644 --- a/core/runtime/TRTEngine.cpp +++ b/core/runtime/TRTEngine.cpp @@ -30,6 +30,29 @@ std::vector split(const std::string& str, char delim) { return strings; } +DynamicOutputAllocator::DynamicOutputAllocator(const std::unordered_map& output_dtypes) + : dtypes(output_dtypes) {} + +void* DynamicOutputAllocator::reallocateOutputAsync( + char const* tensorName, + void* currentMemory, + uint64_t size, + uint64_t alignment, + cudaStream_t stream) { + std::vector shape = {static_cast(size)}; + auto it = buffers.find(tensorName); + if (it == buffers.end() || it->second.sizes() != shape) { + buffers[tensorName] = at::empty(shape, at::TensorOptions().dtype(dtypes.at(tensorName)).device(c10::kCUDA)); + return buffers[tensorName].data_ptr(); + } else { + return it->second.data_ptr(); + } +} + +void DynamicOutputAllocator::notifyShape(char const* tensorName, nvinfer1::Dims const& dims) noexcept { + shapes[tensorName] = dims; +} + TRTEngine::TRTEngine( const std::string& serialized_engine, const RTDevice& cuda_device, @@ -37,6 +60,7 @@ TRTEngine::TRTEngine( const std::vector& _out_binding_names, const Platform& target_platform, bool hardware_compatible, + bool requires_output_allocator, const std::string& serialized_metadata) : TRTEngine( "deserialized_trt", @@ -46,6 +70,7 @@ TRTEngine::TRTEngine( _out_binding_names, target_platform, hardware_compatible, + requires_output_allocator, serialized_metadata) {} TRTEngine::TRTEngine(std::vector serialized_info) @@ -57,6 +82,7 @@ TRTEngine::TRTEngine(std::vector serialized_info) split(serialized_info[OUTPUT_BINDING_NAMES_IDX], BINDING_DELIM), Platform(serialized_info[TARGET_PLATFORM_IDX]), 
static_cast(std::stoi(serialized_info[HW_COMPATIBLE_IDX])), + static_cast(std::stoi(serialized_info[REQUIRES_OUTPUT_ALLOCATOR_IDX])), serialized_info[SERIALIZED_METADATA_IDX]) {} TRTEngine::TRTEngine( @@ -67,6 +93,7 @@ TRTEngine::TRTEngine( const std::vector& _out_binding_names, const Platform& target_platform, bool hardware_compatible, + bool requires_output_allocator, const std::string& serialized_metadata) { TORCHTRT_CHECK( is_supported_on_current_platform(target_platform), @@ -79,6 +106,7 @@ TRTEngine::TRTEngine( TORCHTRT_CHECK(most_compatible_device, "No compatible device was found for instantiating TensorRT engine"); this->serialized_metadata = serialized_metadata; + this->requires_output_allocator = requires_output_allocator; device_info = most_compatible_device.value(); multi_gpu_device_check(); set_rt_device(device_info); @@ -397,6 +425,7 @@ FlattenedState TRTEngine::__obj_flatten__() { std::tuple("out_binding_names", serialized_info[OUTPUT_BINDING_NAMES_IDX]), std::tuple("hardware_compatible", serialized_info[HW_COMPATIBLE_IDX]), std::tuple("serialized_metadata", serialized_info[SERIALIZED_METADATA_IDX]), + std::tuple("requires_output_allocator", serialized_info[REQUIRES_OUTPUT_ALLOCATOR_IDX]), std::tuple("target_platform", serialized_info[TARGET_PLATFORM_IDX])); } @@ -417,6 +446,7 @@ std::vector TRTEngine::serialize() { serialized_info[INPUT_BINDING_NAMES_IDX] = serialize_bindings(this->in_binding_names); serialized_info[OUTPUT_BINDING_NAMES_IDX] = serialize_bindings(this->out_binding_names); serialized_info[HW_COMPATIBLE_IDX] = this->hardware_compatible ? "1" : "0"; + serialized_info[REQUIRES_OUTPUT_ALLOCATOR_IDX] = this->requires_output_allocator ? "1" : "0"; serialized_info[SERIALIZED_METADATA_IDX] = this->serialized_metadata; serialized_info[TARGET_PLATFORM_IDX] = this->target_platform.serialize(); diff --git a/core/runtime/TRTEngine.h b/core/runtime/TRTEngine.h index e1d8ba5471..e9b1905610 100644 --- a/core/runtime/TRTEngine.h +++ b/core/runtime/TRTEngine.h @@ -27,6 +27,7 @@ using FlattenedState = std::tuple< std::tuple, // input binding names std::tuple, // output binding names std::tuple, // HW compatibility + std::tuple, // requires_output_allocator std::tuple, // serialized metadata std::tuple>; // Platform @@ -69,6 +70,33 @@ struct TorchTRTRuntimeStates { } }; +class DynamicOutputAllocator : public nvinfer1::IOutputAllocator { + public: + DynamicOutputAllocator(const std::unordered_map& output_dtypes); + + void* reallocateOutputAsync( + char const* tensorName, + void* currentMemory, + uint64_t size, + uint64_t alignment, + cudaStream_t stream) override; + + void notifyShape(char const* tensorName, nvinfer1::Dims const& dims) noexcept override; + + const std::unordered_map& getBuffers() const { + return buffers; + } + + const std::unordered_map& getShapes() const { + return shapes; + } + + private: + std::unordered_map dtypes; + std::unordered_map buffers; + std::unordered_map shapes; +}; + struct TRTEngine : torch::CustomClassHolder { // Each engine needs it's own runtime object std::shared_ptr rt; @@ -99,6 +127,7 @@ struct TRTEngine : torch::CustomClassHolder { const std::vector& out_binding_names, const Platform& target_platform = get_current_platform(), bool hardware_compatible = false, + bool requires_output_allocator = false, const std::string& serialized_metadata = ""); TRTEngine(std::vector serialized_info); @@ -111,6 +140,7 @@ struct TRTEngine : torch::CustomClassHolder { const std::vector& out_binding_names, const Platform& target_platform = 
get_current_platform(), bool hardware_compatible = false, + bool requires_output_allocator = false, const std::string& serialized_metadata = ""); TRTEngine& operator=(const TRTEngine& other); @@ -146,6 +176,11 @@ struct TRTEngine : torch::CustomClassHolder { bool use_pre_allocated_outputs = false; std::vector pre_allocated_outputs; + // Output Allocator-Related Functionality + bool requires_output_allocator = false; // engine requires output allocator + bool use_output_allocator_outputs = false; // users specify to use output allocator + std::shared_ptr output_allocator; + // TODO: Implement a call method // c10::List Run(c10::List inputs); diff --git a/core/runtime/execute_engine.cpp b/core/runtime/execute_engine.cpp index 5348ade8c4..c857cff6c8 100644 --- a/core/runtime/execute_engine.cpp +++ b/core/runtime/execute_engine.cpp @@ -13,7 +13,7 @@ namespace torch_tensorrt { namespace core { namespace runtime { -// Checks if the context switch requred for device ID +// Checks if the context switch required for device ID bool is_switch_required(const RTDevice& curr_device, const RTDevice& engine_device) { // If SM capability is not the same as configured then switch if ((curr_device.major != engine_device.major) || (curr_device.minor != engine_device.minor)) { @@ -91,6 +91,7 @@ bool _validate_shapes(std::vector inputs, c10::intrusive_ptr inputs, c10::intrusive_ptr compiled_engine, @@ -163,6 +164,7 @@ void setup_input_tensors( } } } + std::vector create_output_tensors(c10::intrusive_ptr compiled_engine) { std::vector outputs(compiled_engine->num_io.second); for (auto output_indices : compiled_engine->out_binding_map) { @@ -181,7 +183,268 @@ std::vector create_output_tensors(c10::intrusive_ptr comp return outputs; } +void create_output_allocator(c10::intrusive_ptr compiled_engine) { + if (compiled_engine->output_allocator == nullptr) { + std::unordered_map output_dtypes_dict; + for (size_t o = 0; o < compiled_engine->out_binding_names.size(); ++o) { + auto name = compiled_engine->out_binding_names[o]; + output_dtypes_dict[name] = + util::TRTDataTypeToScalarType(compiled_engine->exec_ctx->getEngine().getTensorDataType(name.c_str())); + } + compiled_engine->output_allocator = std::make_shared(output_dtypes_dict); + } + for (const auto& output_name : compiled_engine->out_binding_names) { + if (!compiled_engine->exec_ctx->setOutputAllocator(output_name.c_str(), compiled_engine->output_allocator.get())) { + TORCHTRT_THROW_ERROR("Failed to set output allocator for " + output_name); + } + } +} + std::vector execute_engine(std::vector inputs, c10::intrusive_ptr compiled_engine) { + auto run_standard_execution = [&]() { + bool cudagraphs_enabled = (CUDAGRAPHS_MODE == SUBGRAPH_CUDAGRAPHS); + bool shape_changed = _validate_shapes(inputs, compiled_engine); + + // Whether cudagraphs needs to record the graph on this pass + auto result = compiled_engine->runtime_states.set_runtime_states( + cudagraphs_enabled, compiled_engine->use_pre_allocated_outputs, shape_changed); + + bool need_cudagraphs_record = std::get<0>(result); + bool can_use_pre_allocated_outputs = std::get<1>(result); + bool need_cudagraphs_reset = std::get<2>(result); + + if (need_cudagraphs_reset) { + compiled_engine->cudagraph.reset(); + } + + std::vector outputs(compiled_engine->num_io.second); + + // Intialize inputs and outputs to be available throughout the succeeding scopes + { // Input Setup + std::unique_ptr input_profiler_guard; + if (compiled_engine->profile_execution) { + input_profiler_guard = + 
std::make_unique(compiled_engine->input_profile_path); + } + + setup_input_tensors(inputs, compiled_engine, cudagraphs_enabled, need_cudagraphs_record); + // Check if input shapes can be inferred. + int32_t const io_size{compiled_engine->cuda_engine->getNbIOTensors()}; + std::vector names(io_size); + int32_t const nbNames = compiled_engine->exec_ctx->inferShapes(names.size(), names.data()); + TORCHTRT_CHECK( + nbNames == 0, + "The shapes of the inputs: " + << names + << " cannot be inferred. This could happen if the input tensor addresses/shapes haven't been configured correctly"); + } + + { // Output Setup + std::unique_ptr output_profiler_guard; + if (compiled_engine->profile_execution) { + output_profiler_guard = + std::make_unique(compiled_engine->output_profile_path); + } + if (can_use_pre_allocated_outputs) { + outputs = compiled_engine->pre_allocated_outputs; + } else { + outputs = create_output_tensors(compiled_engine); + } + + for (auto output_indices : compiled_engine->out_binding_map) { + auto pyt_idx = output_indices.second; + std::string name = compiled_engine->out_binding_names[pyt_idx]; + if (need_cudagraphs_record) { + // If we are recording the cuda graph then we need to update the persistent output buffer + compiled_engine->output_buffers[pyt_idx] = std::move(outputs[pyt_idx].clone()); + } + + if (cudagraphs_enabled) { + TORCHTRT_CHECK( + compiled_engine->exec_ctx->setTensorAddress( + name.c_str(), compiled_engine->output_buffers[pyt_idx].data_ptr()), + "Error while setting the output tensor address"); + } else { + TORCHTRT_CHECK( + compiled_engine->exec_ctx->setTensorAddress(name.c_str(), outputs[pyt_idx].data_ptr()), + "Error while setting the output tensor address"); + } + } + } + + auto current_device_id = -1; + if (inputs.size() > 0) { + current_device_id = inputs[0].device().index(); // Done this way to avoid a call to cudart + } else if (outputs.size() > 0) { + current_device_id = outputs[0].device().index(); // Done this way to avoid a call to cudart + } + + compiled_engine->caller_stream = c10::cuda::getCurrentCUDAStream(current_device_id); + if (compiled_engine->engine_stream == c10::cuda::getDefaultCUDAStream(current_device_id)) { + // Create a new stream if the engine stream is the default stream + compiled_engine->engine_stream = c10::cuda::getStreamFromPool(false, current_device_id); + } + + { // Engine Execution (execute on engine stream) + c10::cuda::CUDAStreamGuard stream_guard(compiled_engine->engine_stream); + + std::unique_ptr enqueue_profiler_guard; + if (compiled_engine->profile_execution) { + enqueue_profiler_guard = + std::make_unique(compiled_engine->enqueue_profile_path); + } + + // Block engine stream until results are available on caller stream + at::cuda::CUDAEvent caller_exec_complete; + caller_exec_complete.record(compiled_engine->caller_stream); + caller_exec_complete.block(compiled_engine->engine_stream); + + if (!cudagraphs_enabled) { + // Direct execution uses the caller buffers directly + compiled_engine->exec_ctx->enqueueV3(compiled_engine->engine_stream); + } else { + if (need_cudagraphs_record) { + // If cudagraphs needs to record a graph, capture the enqueueV3 call in a graph + c10::cuda::CUDAStream recording_stream = compiled_engine->engine_stream; + compiled_engine->cudagraph.capture_begin(); + compiled_engine->exec_ctx->enqueueV3(recording_stream); + compiled_engine->cudagraph.capture_end(); + + if (compiled_engine->profile_execution) { + compiled_engine->cudagraph.debug_dump(compiled_engine->cuda_graph_debug_path); + } + } + + 
// Replay the CUDAGraph + compiled_engine->cudagraph.replay(); // Has a cudaDeviceSynchronize internally + } + } // End engine exeuction (resets to caller stream) + + // Create output buffer for next execution of graph or trt context. + if (compiled_engine->use_pre_allocated_outputs) { + compiled_engine->pre_allocated_outputs = create_output_tensors(compiled_engine); + } + + // Block caller stream until engine execution is complete + at::cuda::CUDAEvent trt_exec_complete; + trt_exec_complete.record(compiled_engine->engine_stream); + trt_exec_complete.block(compiled_engine->caller_stream); + + if (cudagraphs_enabled) { + // If in CUDAGraph mode, results need to be copied to the result buffers (on caller stream) + for (size_t o = 0; o < compiled_engine->output_buffers.size(); o++) { + outputs[o].copy_(compiled_engine->output_buffers[o], false); + } + } + + if (compiled_engine->profile_execution) { + LOG_INFO(std::endl << *compiled_engine->trt_engine_profiler); + dump_trace(compiled_engine->trt_engine_profile_path, *compiled_engine->trt_engine_profiler); + compiled_engine->dump_engine_layer_info(); + } + + return outputs; + }; + + auto run_output_allocator = [&]() { + { // Input Setup + std::unique_ptr input_profiler_guard; + if (compiled_engine->profile_execution) { + input_profiler_guard = + std::make_unique(compiled_engine->input_profile_path); + } + + setup_input_tensors(inputs, compiled_engine, false, false); + // Check if input shapes can be inferred. + int32_t const io_size{compiled_engine->cuda_engine->getNbIOTensors()}; + std::vector names(io_size); + int32_t const nbNames = compiled_engine->exec_ctx->inferShapes(names.size(), names.data()); + TORCHTRT_CHECK( + nbNames == 0, + "The shapes of the inputs: " + << names + << " cannot be inferred. 
This could happen if the input tensor addresses/shapes haven't been configured correctly"); + } + + { // OutputAllocator Setup + std::unique_ptr output_allocator_profiler_guard; + if (compiled_engine->profile_execution) { + output_allocator_profiler_guard = + std::make_unique(compiled_engine->output_profile_path); + } + create_output_allocator(compiled_engine); + } + + auto current_device_id = -1; + if (inputs.size() > 0) { + current_device_id = inputs[0].device().index(); // Done this way to avoid a call to cudart + } else { + current_device_id = c10::cuda::current_device(); + } + + compiled_engine->caller_stream = c10::cuda::getCurrentCUDAStream(current_device_id); + if (compiled_engine->engine_stream == c10::cuda::getDefaultCUDAStream(current_device_id)) { + // Create a new stream if the engine stream is the default stream + compiled_engine->engine_stream = c10::cuda::getStreamFromPool(false, current_device_id); + } + + { // Engine Execution (execute on engine stream) + c10::cuda::CUDAStreamGuard stream_guard(compiled_engine->engine_stream); + + std::unique_ptr enqueue_profiler_guard; + if (compiled_engine->profile_execution) { + enqueue_profiler_guard = + std::make_unique(compiled_engine->enqueue_profile_path); + } + + // Block engine stream until results are available on caller stream + at::cuda::CUDAEvent caller_exec_complete; + caller_exec_complete.record(compiled_engine->caller_stream); + caller_exec_complete.block(compiled_engine->engine_stream); + + // Direct execution uses the caller buffers directly + compiled_engine->exec_ctx->enqueueV3(compiled_engine->engine_stream); + + } // End engine exeuction (resets to caller stream) + + // Block caller stream until engine execution is complete + at::cuda::CUDAEvent trt_exec_complete; + trt_exec_complete.record(compiled_engine->engine_stream); + trt_exec_complete.block(compiled_engine->caller_stream); + + std::unique_ptr output_profiler_guard; + if (compiled_engine->profile_execution) { + output_profiler_guard = + std::make_unique(compiled_engine->output_profile_path); + } + std::vector outputs; + for (size_t i = 0; i < compiled_engine->out_binding_names.size(); i++) { + auto name = compiled_engine->out_binding_names[i]; + auto dims = compiled_engine->output_allocator->getShapes().at(name); + auto dtype = + util::TRTDataTypeToScalarType(compiled_engine->exec_ctx->getEngine().getTensorDataType(name.c_str())); + at::Tensor output = compiled_engine->output_allocator->getBuffers().at(name).clone().detach(); + int64_t prod = 1; + for (int i = 0; i < dims.nbDims; ++i) { + prod *= dims.d[i]; + } + std::vector dims_vec(dims.nbDims); + for (int i = 0; i < dims.nbDims; ++i) { + dims_vec[i] = dims.d[i]; + } + output = output.reshape(-1).view(dtype).slice(0, 0, prod).reshape(dims_vec); + outputs.push_back(output); + } + + if (compiled_engine->profile_execution) { + LOG_INFO(std::endl << *compiled_engine->trt_engine_profiler); + dump_trace(compiled_engine->trt_engine_profile_path, *compiled_engine->trt_engine_profiler); + compiled_engine->dump_engine_layer_info(); + } + + return outputs; + }; + LOG_DEBUG( "Attempting to run engine (ID: " << compiled_engine->name << "); Hardware Compatible: " << compiled_engine->hardware_compatible); @@ -203,22 +466,6 @@ std::vector execute_engine(std::vector inputs, c10::intr compiled_engine->cudagraph.enable_debug_mode(); } bool cudagraphs_enabled = (CUDAGRAPHS_MODE == SUBGRAPH_CUDAGRAPHS); - bool shape_changed = _validate_shapes(inputs, compiled_engine); - - // Whether cudagraphs needs to record the graph on this 
pass - auto result = compiled_engine->runtime_states.set_runtime_states( - cudagraphs_enabled, compiled_engine->use_pre_allocated_outputs, shape_changed); - - bool need_cudagraphs_record = std::get<0>(result); - bool can_use_pre_allocated_outputs = std::get<1>(result); - bool need_cudagraphs_reset = std::get<2>(result); - - if (need_cudagraphs_reset) { - compiled_engine->cudagraph.reset(); - } - - // Intialize inputs and outputs to be available throughout the succeeding scopes - std::vector outputs(compiled_engine->num_io.second); if (MULTI_DEVICE_SAFE_MODE) { std::unique_ptr device_profiler_guard; @@ -268,130 +515,25 @@ std::vector execute_engine(std::vector inputs, c10::intr } } - { // Input Setup - std::unique_ptr input_profiler_guard; - if (compiled_engine->profile_execution) { - input_profiler_guard = - std::make_unique(compiled_engine->input_profile_path); - } - - setup_input_tensors(inputs, compiled_engine, cudagraphs_enabled, need_cudagraphs_record); - // Check if input shapes can be inferred. - int32_t const io_size{compiled_engine->cuda_engine->getNbIOTensors()}; - std::vector names(io_size); - int32_t const nbNames = compiled_engine->exec_ctx->inferShapes(names.size(), names.data()); - TORCHTRT_CHECK( - nbNames == 0, - "The shapes of the inputs: " - << names - << " cannot be inferred. This could happen if the input tensor addresses/shapes haven't been configured correctly"); - } - - { // Output Setup - std::unique_ptr output_profiler_guard; - if (compiled_engine->profile_execution) { - output_profiler_guard = - std::make_unique(compiled_engine->output_profile_path); - } - if (can_use_pre_allocated_outputs) { - outputs = compiled_engine->pre_allocated_outputs; - } else { - outputs = create_output_tensors(compiled_engine); + if (compiled_engine->requires_output_allocator) { // engine requires OA + if (cudagraphs_enabled) { + TORCHTRT_THROW_ERROR( + "This module requires OutputAllocator which is not compatible with CUDA Graphs. Please disable CUDA Graphs."); } - - for (auto output_indices : compiled_engine->out_binding_map) { - auto pyt_idx = output_indices.second; - std::string name = compiled_engine->out_binding_names[pyt_idx]; - if (need_cudagraphs_record) { - // If we are recording the cuda graph then we need to update the persistent output buffer - compiled_engine->output_buffers[pyt_idx] = std::move(outputs[pyt_idx].clone()); - } - + LOG_DEBUG("Using OutputAllocator in runtime."); + return run_output_allocator(); + } else { + if (compiled_engine->use_output_allocator_outputs) { // users call OA context manager if (cudagraphs_enabled) { - TORCHTRT_CHECK( - compiled_engine->exec_ctx->setTensorAddress( - name.c_str(), compiled_engine->output_buffers[pyt_idx].data_ptr()), - "Error while setting the output tensor address"); - } else { - TORCHTRT_CHECK( - compiled_engine->exec_ctx->setTensorAddress(name.c_str(), outputs[pyt_idx].data_ptr()), - "Error while setting the output tensor address"); + TORCHTRT_THROW_ERROR("Both CUDA Graphs and OutputAllocator are enabled. 
Please disable either one."); } - } - } - - auto current_device_id = -1; - if (inputs.size() > 0) { - current_device_id = inputs[0].device().index(); // Done this way to avoid a call to cudart - } else if (outputs.size() > 0) { - current_device_id = outputs[0].device().index(); // Done this way to avoid a call to cudart - } - - compiled_engine->caller_stream = c10::cuda::getCurrentCUDAStream(current_device_id); - if (compiled_engine->engine_stream == c10::cuda::getDefaultCUDAStream(current_device_id)) { - // Create a new stream if the engine stream is the default stream - compiled_engine->engine_stream = c10::cuda::getStreamFromPool(false, current_device_id); - } - - { // Engine Execution (execute on engine stream) - c10::cuda::CUDAStreamGuard stream_guard(compiled_engine->engine_stream); - - std::unique_ptr enqueue_profiler_guard; - if (compiled_engine->profile_execution) { - enqueue_profiler_guard = - std::make_unique(compiled_engine->enqueue_profile_path); - } - - // Block engine stream until results are available on caller stream - at::cuda::CUDAEvent caller_exec_complete; - caller_exec_complete.record(compiled_engine->caller_stream); - caller_exec_complete.block(compiled_engine->engine_stream); - - if (!cudagraphs_enabled) { - // Direct execution uses the caller buffers directly - compiled_engine->exec_ctx->enqueueV3(compiled_engine->engine_stream); + LOG_DEBUG("Using OutputAllocator in runtime."); + return run_output_allocator(); } else { - if (need_cudagraphs_record) { - // If cudagraphs needs to record a graph, capture the enqueueV3 call in a graph - c10::cuda::CUDAStream recording_stream = compiled_engine->engine_stream; - compiled_engine->cudagraph.capture_begin(); - compiled_engine->exec_ctx->enqueueV3(recording_stream); - compiled_engine->cudagraph.capture_end(); - - if (compiled_engine->profile_execution) { - compiled_engine->cudagraph.debug_dump(compiled_engine->cuda_graph_debug_path); - } - } - - // Replay the CUDAGraph - compiled_engine->cudagraph.replay(); // Has a cudaDeviceSynchronize internally + LOG_DEBUG("Using standard execution with cudagraphs=" << cudagraphs_enabled << "."); + return run_standard_execution(); } - } // End engine exeuction (resets to caller stream) - - // Create output buffer for next execution of graph or trt context. 
- if (compiled_engine->use_pre_allocated_outputs) { - compiled_engine->pre_allocated_outputs = create_output_tensors(compiled_engine); } - - // Block caller stream until engine execution is complete - at::cuda::CUDAEvent trt_exec_complete; - trt_exec_complete.record(compiled_engine->engine_stream); - trt_exec_complete.block(compiled_engine->caller_stream); - - if (cudagraphs_enabled) { - // If in CUDAGraph mode, results need to be copied to the result buffers (on caller stream) - for (size_t o = 0; o < compiled_engine->output_buffers.size(); o++) { - outputs[o].copy_(compiled_engine->output_buffers[o], false); - } - } - - if (compiled_engine->profile_execution) { - LOG_INFO(std::endl << *compiled_engine->trt_engine_profiler); - dump_trace(compiled_engine->trt_engine_profile_path, *compiled_engine->trt_engine_profiler); - compiled_engine->dump_engine_layer_info(); - } - - return outputs; } } // namespace runtime diff --git a/core/runtime/register_jit_hooks.cpp b/core/runtime/register_jit_hooks.cpp index 3ded080b1d..c05be4e8aa 100644 --- a/core/runtime/register_jit_hooks.cpp +++ b/core/runtime/register_jit_hooks.cpp @@ -89,6 +89,7 @@ static auto TORCHTRT_UNUSED TRTEngineTSRegistrtion = .def("get_engine_layer_info", &TRTEngine::get_engine_layer_info) .def("infer_outputs", &TRTEngine::infer_outputs) .def_readwrite("use_pre_allocated_outputs", &TRTEngine::use_pre_allocated_outputs) + .def_readwrite("use_output_allocator_outputs", &TRTEngine::use_output_allocator_outputs) .def_property( "device_memory_budget", &TRTEngine::get_device_memory_budget, @@ -130,6 +131,7 @@ TORCH_LIBRARY(tensorrt, m) { m.def("HW_COMPATIBLE_IDX", []() -> int64_t { return HW_COMPATIBLE_IDX; }); m.def("SERIALIZED_METADATA_IDX", []() -> int64_t { return SERIALIZED_METADATA_IDX; }); m.def("TARGET_PLATFORM_IDX", []() -> int64_t { return TARGET_PLATFORM_IDX; }); + m.def("REQUIRES_OUTPUT_ALLOCATOR_IDX", []() -> int64_t { return REQUIRES_OUTPUT_ALLOCATOR_IDX; }); m.def("SERIALIZATION_LEN", []() -> int64_t { return SERIALIZATION_LEN; }); m.def("_platform_linux_x86_64", []() -> std::string { auto it = get_platform_name_map().find(Platform::PlatformEnum::kLINUX_X86_64); diff --git a/core/runtime/runtime.h b/core/runtime/runtime.h index 6f1436c745..3a2fe9c24b 100644 --- a/core/runtime/runtime.h +++ b/core/runtime/runtime.h @@ -37,6 +37,7 @@ typedef enum { HW_COMPATIBLE_IDX, SERIALIZED_METADATA_IDX, TARGET_PLATFORM_IDX, + REQUIRES_OUTPUT_ALLOCATOR_IDX, SERIALIZATION_LEN, // NEVER USED FOR DATA, USED TO DETERMINE LENGTH OF SERIALIZED INFO } SerializedInfoIndex; diff --git a/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py b/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py index 82f5214267..d0a1cf6c2e 100644 --- a/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py +++ b/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py @@ -678,7 +678,7 @@ def run_output_allocator() -> torch.Tensor | Tuple[torch.Tensor, ...]: ] logger.warning(f"Moved all input Tensors to cuda:{device_id}") - if self.requires_output_allocator: + if self.requires_output_allocator: # engine requires OA if self.cudagraphs_enabled: raise RuntimeError( "This module requires OutputAllocator which is not compatible with CUDA Graphs. Please disable CUDA Graphs." 
@@ -686,17 +686,18 @@ def run_output_allocator() -> torch.Tensor | Tuple[torch.Tensor, ...]: logger.debug("Using OutputAllocator in runtime.") return run_output_allocator() else: - if self.cudagraphs_enabled and self.use_output_allocator_outputs: - raise RuntimeError( - "Both CUDA Graphs and OutputAllocator are enabled. Please disable either one." - ) - if self.use_output_allocator_outputs: + if self.use_output_allocator_outputs: # users call OA context manager + if self.cudagraphs_enabled: + raise RuntimeError( + "Both CUDA Graphs and OutputAllocator are enabled. Please disable either one." + ) logger.debug("Using OutputAllocator in runtime.") return run_output_allocator() - logger.debug( - f"Using standard execution with cudagraphs={self.cudagraphs_enabled}." - ) - return run_standard_execution() + else: + logger.debug( + f"Using standard execution with cudagraphs={self.cudagraphs_enabled}." + ) + return run_standard_execution() def enable_profiling(self, profiler: "trt.IProfiler" = None) -> None: """ diff --git a/py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py b/py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py index e8ad07f2d6..5f49326e28 100644 --- a/py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py +++ b/py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py @@ -34,6 +34,7 @@ HW_COMPATIBLE_IDX = -1 # Not implemented SERIALIZED_METADATA_IDX = -1 # Not implemented TARGET_PLATFORM_IDX = -1 # Not implemented +REQUIRES_OUTPUT_ALLOCATOR_IDX = -1 # Not implemented SERIALIZATION_LEN = -1 # Not implemented if ENABLED_FEATURES.torch_tensorrt_runtime: @@ -46,7 +47,10 @@ HW_COMPATIBLE_IDX = torch.ops.tensorrt.HW_COMPATIBLE_IDX() # 6 SERIALIZED_METADATA_IDX = torch.ops.tensorrt.SERIALIZED_METADATA_IDX() # 7 TARGET_PLATFORM_IDX = torch.ops.tensorrt.TARGET_PLATFORM_IDX() # 8 - SERIALIZATION_LEN = torch.ops.tensorrt.SERIALIZATION_LEN() # 9 + REQUIRES_OUTPUT_ALLOCATOR_IDX = ( + torch.ops.tensorrt.REQUIRES_OUTPUT_ALLOCATOR_IDX() + ) # 9 + SERIALIZATION_LEN = torch.ops.tensorrt.SERIALIZATION_LEN() # 10 @for_all_methods(needs_torch_tensorrt_runtime) @@ -152,7 +156,6 @@ def _pack_engine_info(self) -> List[str | bytes]: metadata = { "settings": self.settings, "weight_name_map": self.weight_name_map, - "requires_output_allocator": self.requires_output_allocator, } target_platform = ( Platform.current_platform() @@ -178,6 +181,9 @@ def _pack_engine_info(self) -> List[str | bytes]: engine_info[HW_COMPATIBLE_IDX] = str(int(self.hardware_compatible)) engine_info[SERIALIZED_METADATA_IDX] = self.encode_metadata(metadata) engine_info[TARGET_PLATFORM_IDX] = target_platform._to_serialized_rt_platform() + engine_info[REQUIRES_OUTPUT_ALLOCATOR_IDX] = str( + int(self.requires_output_allocator) + ) return engine_info @@ -263,14 +269,18 @@ def set_extra_state(self, state: SerializedTorchTensorRTModuleFmt) -> None: serialized_engine_info[ENGINE_IDX] ) self.engine = torch.classes.tensorrt.Engine(serialized_engine_info) - self.hardware_compatible = bool(int(state[1][HW_COMPATIBLE_IDX])) + self.hardware_compatible = bool( + int(serialized_engine_info[HW_COMPATIBLE_IDX]) + ) + self.requires_output_allocator = bool( + int(serialized_engine_info[REQUIRES_OUTPUT_ALLOCATOR_IDX]) + ) serialized_metadata = serialized_engine_info[SERIALIZED_METADATA_IDX] assert isinstance(serialized_metadata, bytes) metadata = TorchTensorRTModule.decode_metadata(serialized_metadata) self.settings = metadata["settings"] self.weight_name_map = metadata["weight_name_map"] - self.requires_output_allocator = 
metadata["requires_output_allocator"] else: self.engine = None @@ -283,6 +293,9 @@ def set_extra_state(self, state: SerializedTorchTensorRTModuleFmt) -> None: def set_pre_allocated_outputs(self, enable: bool) -> None: self.engine.use_pre_allocated_outputs = enable + def set_output_allocator_outputs(self, enable: bool) -> None: + self.engine.use_output_allocator_outputs = enable + def forward(self, *inputs: Any) -> torch.Tensor | Tuple[torch.Tensor, ...]: """Implementation of the forward pass for a TensorRT engine diff --git a/py/torch_tensorrt/dynamo/runtime/meta_ops/register_meta_ops.py b/py/torch_tensorrt/dynamo/runtime/meta_ops/register_meta_ops.py index eee743c497..f481c5b2b8 100644 --- a/py/torch_tensorrt/dynamo/runtime/meta_ops/register_meta_ops.py +++ b/py/torch_tensorrt/dynamo/runtime/meta_ops/register_meta_ops.py @@ -93,6 +93,9 @@ def __init__(self, engine_info: List[str]) -> None: self.serialized_metadata = engine_info[ torch.ops.tensorrt.SERIALIZED_METADATA_IDX() ] + self.requires_output_allocator = engine_info[ + torch.ops.tensorrt.REQUIRES_OUTPUT_ALLOCATOR_IDX() + ] self.target_platform = engine_info[torch.ops.tensorrt.TARGET_PLATFORM_IDX()] @classmethod diff --git a/tests/py/dynamo/runtime/test_output_allocator_py.py b/tests/py/dynamo/runtime/test_output_allocator_py.py index c0ac064f7a..7065df1649 100644 --- a/tests/py/dynamo/runtime/test_output_allocator_py.py +++ b/tests/py/dynamo/runtime/test_output_allocator_py.py @@ -1,6 +1,7 @@ import pytest import torch import torch_tensorrt +from parameterized import parameterized from torch.testing._internal.common_utils import TestCase, run_tests from ..testing_utilities import DECIMALS_OF_AGREEMENT @@ -35,8 +36,14 @@ def forward(self, input): return out -class TestOutputAllocatorStaticModelPython(TestCase): - def test_cudagraphs(self): +class TestOutputAllocatorStaticModel(TestCase): + @parameterized.expand( + [ + ("python_runtime", True), + ("cpp_runtime", False), + ] + ) + def test_cudagraphs_and_output_allocator(self, _, use_python_runtime): model = StaticModel().eval().cuda() inputs = [torch.randn((2, 3), dtype=torch.float).cuda()] compiled_model = torch_tensorrt.compile( @@ -44,7 +51,7 @@ def test_cudagraphs(self): "dynamo", inputs, min_block_size=1, - use_python_runtime=True, + use_python_runtime=use_python_runtime, ) ref_out = model(*inputs) @@ -58,22 +65,9 @@ def test_cudagraphs(self): float(torch.max(torch.abs(ref_out - cg_out))), 0, DECIMALS_OF_AGREEMENT, - msg="CUDA Graphs Python TRT outputs don't match with the original model.", - ) - - def test_output_allocator(self): - model = StaticModel().eval().cuda() - inputs = [torch.randn((2, 3), dtype=torch.float).cuda()] - compiled_model = torch_tensorrt.compile( - model, - "dynamo", - inputs, - min_block_size=1, - use_python_runtime=True, + msg="CUDA Graphs runtime outputs don't match with the original model.", ) - ref_out = model(*inputs) - with torch_tensorrt.runtime.enable_output_allocator(compiled_model): oa_out = compiled_model(*inputs) @@ -81,10 +75,16 @@ def test_output_allocator(self): float(torch.max(torch.abs(ref_out - oa_out))), 0, DECIMALS_OF_AGREEMENT, - msg="Output Allocator Python TRT outputs don't match with the original model.", + msg="Output Allocator runtime outputs don't match with the original model.", ) - def test_default(self): + @parameterized.expand( + [ + ("python_runtime", True), + ("cpp_runtime", False), + ] + ) + def test_default(self, _, use_python_runtime): """ Static models use standard execution with cudagraphs=False by default. 
""" @@ -95,7 +95,7 @@ def test_default(self): "dynamo", inputs, min_block_size=1, - use_python_runtime=True, + use_python_runtime=use_python_runtime, ) standard_out = compiled_model(*inputs) ref_out = model(*inputs) @@ -104,15 +104,16 @@ def test_default(self): float(torch.max(torch.abs(ref_out - standard_out))), 0, DECIMALS_OF_AGREEMENT, - msg="Default standard execution outputs don't match with the original model.", + msg="Default standard execution (cudagraphs=False) outputs don't match with the original model.", ) - @pytest.mark.xfail( - strict=True, - raises=RuntimeError, - reason="Both CUDA Graphs and OutputAllocator are enabled. Please disable either one.", + @parameterized.expand( + [ + ("python_runtime", True), + ("cpp_runtime", False), + ] ) - def test_cudagraphs_and_output_allocator(self): + def test_combination_of_cg_and_oa(self, _, use_python_runtime): model = StaticModel().eval().cuda() inputs = [torch.randn((2, 3), dtype=torch.float).cuda()] compiled_model = torch_tensorrt.compile( @@ -120,44 +121,38 @@ def test_cudagraphs_and_output_allocator(self): "dynamo", inputs, min_block_size=1, - use_python_runtime=True, + use_python_runtime=use_python_runtime, ) - with torch_tensorrt.runtime.enable_cudagraphs( - compiled_model - ) as cudagraphs_module: - with torch_tensorrt.runtime.enable_output_allocator(cudagraphs_module): - out = cudagraphs_module(*inputs) - - @pytest.mark.xfail( - strict=True, - raises=RuntimeError, - reason="Both CUDA Graphs and OutputAllocator are enabled. Please disable either one.", - ) - def test_output_allocator_and_cudagraphs(self): - model = StaticModel().eval().cuda() - inputs = [torch.randn((2, 3), dtype=torch.float).cuda()] - compiled_model = torch_tensorrt.compile( - model, - "dynamo", - inputs, - min_block_size=1, - use_python_runtime=True, - ) - with torch_tensorrt.runtime.enable_output_allocator(compiled_model): + with pytest.raises( + RuntimeError, + match="Both CUDA Graphs and OutputAllocator are enabled. Please disable either one.", + ): with torch_tensorrt.runtime.enable_cudagraphs( compiled_model ) as cudagraphs_module: - out = cudagraphs_module(*inputs) - - -class TestOutputAllocatorDDSModelPython(TestCase): - @pytest.mark.xfail( - strict=True, - raises=RuntimeError, - reason="This module requires OutputAllocator which is not compatible with CUDA Graphs. Please disable CUDA Graphs.", + with torch_tensorrt.runtime.enable_output_allocator(cudagraphs_module): + out = cudagraphs_module(*inputs) + + with pytest.raises( + RuntimeError, + match="Both CUDA Graphs and OutputAllocator are enabled. 
Please disable either one.", + ): + with torch_tensorrt.runtime.enable_output_allocator(compiled_model): + with torch_tensorrt.runtime.enable_cudagraphs( + compiled_model + ) as cudagraphs_module: + out = cudagraphs_module(*inputs) + + +class TestOutputAllocatorDDSModel(TestCase): + @parameterized.expand( + [ + ("python_runtime", True), + ("cpp_runtime", False), + ] ) - def test_cudagraphs(self): + def test_cudagraphs_and_output_allocator(self, _, use_python_runtime): model = DDSModel().eval().cuda() inputs = (torch.randint(low=0, high=3, size=(10,), dtype=torch.int).to("cuda"),) compiled_model = torch_tensorrt.compile( @@ -165,38 +160,37 @@ def test_cudagraphs(self): "dynamo", inputs, min_block_size=1, - use_python_runtime=True, + use_python_runtime=use_python_runtime, ) - with torch_tensorrt.runtime.enable_cudagraphs( - compiled_model - ) as cudagraphs_module: - cg_out = cudagraphs_module(*inputs) - - def test_output_allocator(self): - model = DDSModel().eval().cuda() - inputs = (torch.randint(low=0, high=3, size=(10,), dtype=torch.int).to("cuda"),) - compiled_model = torch_tensorrt.compile( - model, - "dynamo", - inputs, - min_block_size=1, - use_python_runtime=True, - ) - - ref_out = model(*inputs) + with pytest.raises( + RuntimeError, + match="There are converters that require Output Allocator. Please disable CUDA Graphs.", + ): + with torch_tensorrt.runtime.enable_cudagraphs( + compiled_model + ) as cudagraphs_module: + cg_out = cudagraphs_module(*inputs) with torch_tensorrt.runtime.enable_output_allocator(compiled_model): oa_out = compiled_model(*inputs) + ref_out = model(*inputs) + self.assertAlmostEqual( float(torch.max(torch.abs(ref_out - oa_out))), 0, DECIMALS_OF_AGREEMENT, - msg="Output Allocator Python TRT outputs don't match with the original model.", + msg="Output Allocator runtime outputs don't match with the original model.", ) - def test_default(self): + @parameterized.expand( + [ + ("python_runtime", True), + ("cpp_runtime", False), + ] + ) + def test_default(self, _, use_python_runtime): """ DDS models use OutputAllocator by default. """ @@ -207,7 +201,7 @@ def test_default(self): "dynamo", inputs, min_block_size=1, - use_python_runtime=True, + use_python_runtime=use_python_runtime, ) oa_out = compiled_model(*inputs) ref_out = model(*inputs) @@ -216,15 +210,16 @@ def test_default(self): float(torch.max(torch.abs(ref_out - oa_out))), 0, DECIMALS_OF_AGREEMENT, - msg="Default Output Allocator Python TRT outputs don't match with the original model.", + msg="Default Output Allocator runtime outputs don't match with the original model.", ) - @pytest.mark.xfail( - strict=True, - raises=RuntimeError, - reason="Both CUDA Graphs and OutputAllocator are enabled. 
Please disable either one.", + @parameterized.expand( + [ + ("python_runtime", True), + ("cpp_runtime", False), + ] ) - def test_cudagraphs_and_output_allocator(self): + def test_combination_of_cg_and_oa(self, _, use_python_runtime): model = DDSModel().eval().cuda() inputs = (torch.randint(low=0, high=3, size=(10,), dtype=torch.int).to("cuda"),) compiled_model = torch_tensorrt.compile( @@ -232,69 +227,42 @@ def test_cudagraphs_and_output_allocator(self): "dynamo", inputs, min_block_size=1, - use_python_runtime=True, + use_python_runtime=use_python_runtime, ) - with torch_tensorrt.runtime.enable_cudagraphs( - compiled_model - ) as cudagraphs_module: - with torch_tensorrt.runtime.enable_output_allocator(cudagraphs_module): - out = cudagraphs_module(*inputs) - - @pytest.mark.xfail( - strict=True, - raises=RuntimeError, - reason="Both CUDA Graphs and OutputAllocator are enabled. Please disable either one.", - ) - def test_output_allocator_and_cudagraphs(self): - model = DDSModel().eval().cuda() - inputs = (torch.randint(low=0, high=3, size=(10,), dtype=torch.int).to("cuda"),) - compiled_model = torch_tensorrt.compile( - model, - "dynamo", - inputs, - min_block_size=1, - use_python_runtime=True, - ) - with torch_tensorrt.runtime.enable_output_allocator(compiled_model): + with pytest.raises( + RuntimeError, + match="There are converters that require Output Allocator. Please disable CUDA Graphs.", + ): with torch_tensorrt.runtime.enable_cudagraphs( compiled_model ) as cudagraphs_module: - out = cudagraphs_module(*inputs) - - -class TestOutputAllocatorNonDDSModelPython(TestCase): - @pytest.mark.xfail( - strict=True, - raises=RuntimeError, - reason="This module requires OutputAllocator which is not compatible with CUDA Graphs. Please disable CUDA Graphs.", + with torch_tensorrt.runtime.enable_output_allocator(cudagraphs_module): + out = cudagraphs_module(*inputs) + + with pytest.raises( + RuntimeError, + match="There are converters that require Output Allocator. Please disable CUDA Graphs.", + ): + with torch_tensorrt.runtime.enable_output_allocator(compiled_model): + with torch_tensorrt.runtime.enable_cudagraphs( + compiled_model + ) as cudagraphs_module: + out = cudagraphs_module(*inputs) + + +class TestOutputAllocatorNonDDSModel(TestCase): + """ + The NonDDSModel is a model that contains DDS op + reduction op. 
+ """ + + @parameterized.expand( + [ + ("python_runtime", True), + ("cpp_runtime", False), + ] ) - def test_cudagraphs(self): - model = NonDDSModel().eval().cuda() - inputs = (torch.randint(low=0, high=3, size=(10,), dtype=torch.int).to("cuda"),) - compiled_model = torch_tensorrt.compile( - model, - "dynamo", - inputs, - min_block_size=1, - use_python_runtime=True, - ) - - ref_out = model(*inputs) - - with torch_tensorrt.runtime.enable_cudagraphs( - compiled_model - ) as cudagraphs_module: - cg_out = cudagraphs_module(*inputs) - - self.assertAlmostEqual( - float(torch.max(torch.abs(ref_out - cg_out))), - 0, - DECIMALS_OF_AGREEMENT, - msg="CUDA Graphs Python TRT outputs don't match with the original model.", - ) - - def test_output_allocator(self): + def test_cudagraphs_and_output_allocator(self, _, use_python_runtime): model = NonDDSModel().eval().cuda() inputs = (torch.randint(low=0, high=3, size=(10,), dtype=torch.int).to("cuda"),) compiled_model = torch_tensorrt.compile( @@ -302,33 +270,48 @@ def test_output_allocator(self): "dynamo", inputs, min_block_size=1, - use_python_runtime=True, + use_python_runtime=use_python_runtime, ) - ref_out = model(*inputs) + with pytest.raises( + RuntimeError, + match="There are converters that require Output Allocator. Please disable CUDA Graphs.", + ): + with torch_tensorrt.runtime.enable_cudagraphs( + compiled_model + ) as cudagraphs_module: + cg_out = cudagraphs_module(*inputs) with torch_tensorrt.runtime.enable_output_allocator(compiled_model): oa_out = compiled_model(*inputs) + ref_out = model(*inputs) + self.assertAlmostEqual( float(torch.max(torch.abs(ref_out - oa_out))), 0, DECIMALS_OF_AGREEMENT, - msg="Output Allocator Python TRT outputs don't match with the original model.", + msg="Output Allocator runtime outputs don't match with the original model.", ) - def test_default(self): + @parameterized.expand( + [ + ("python_runtime", True), + ("cpp_runtime", False), + ] + ) + def test_default(self, _, use_python_runtime): """ NonDDS models use standard execution with cudagraphs=False by default. """ - model = DDSModel().eval().cuda() + model = NonDDSModel().eval().cuda() inputs = (torch.randint(low=0, high=3, size=(10,), dtype=torch.int).to("cuda"),) compiled_model = torch_tensorrt.compile( model, "dynamo", inputs, min_block_size=1, - use_python_runtime=True, + use_python_runtime=use_python_runtime, ) standard_out = compiled_model(*inputs) ref_out = model(*inputs) @@ -337,15 +320,16 @@ def test_default(self): float(torch.max(torch.abs(ref_out - standard_out))), 0, DECIMALS_OF_AGREEMENT, - msg="Default standard execution outputs don't match with the original model.", + msg="Default Output Allocator runtime outputs don't match with the original model.", ) - @pytest.mark.xfail( - strict=True, - raises=RuntimeError, - reason="Both CUDA Graphs and OutputAllocator are enabled. 
Please disable either one.", + @parameterized.expand( + [ + ("python_runtime", True), + ("cpp_runtime", False), + ] ) - def test_cudagraphs_and_output_allocator(self): + def test_combination_of_cg_and_oa(self, _, use_python_runtime): model = NonDDSModel().eval().cuda() inputs = (torch.randint(low=0, high=3, size=(10,), dtype=torch.int).to("cuda"),) compiled_model = torch_tensorrt.compile( @@ -353,44 +337,38 @@ def test_cudagraphs_and_output_allocator(self): "dynamo", inputs, min_block_size=1, - use_python_runtime=True, + use_python_runtime=use_python_runtime, ) - with torch_tensorrt.runtime.enable_cudagraphs( - compiled_model - ) as cudagraphs_module: - with torch_tensorrt.runtime.enable_output_allocator(cudagraphs_module): - out = cudagraphs_module(*inputs) - - @pytest.mark.xfail( - strict=True, - raises=RuntimeError, - reason="Both CUDA Graphs and OutputAllocator are enabled. Please disable either one.", - ) - def test_output_allocator_and_cudagraphs(self): - model = NonDDSModel().eval().cuda() - inputs = (torch.randint(low=0, high=3, size=(10,), dtype=torch.int).to("cuda"),) - compiled_model = torch_tensorrt.compile( - model, - "dynamo", - inputs, - min_block_size=1, - use_python_runtime=True, - ) - with torch_tensorrt.runtime.enable_output_allocator(compiled_model): + with pytest.raises( + RuntimeError, + match="There are converters that require Output Allocator. Please disable CUDA Graphs.", + ): with torch_tensorrt.runtime.enable_cudagraphs( compiled_model ) as cudagraphs_module: - out = cudagraphs_module(*inputs) - - -class TestOutputAllocatorDDSModelWithGraphBreakPython(TestCase): - @pytest.mark.xfail( - strict=True, - raises=RuntimeError, - reason="This module requires OutputAllocator which is not compatible with CUDA Graphs. Please disable CUDA Graphs.", + with torch_tensorrt.runtime.enable_output_allocator(cudagraphs_module): + out = cudagraphs_module(*inputs) + + with pytest.raises( + RuntimeError, + match="There are converters that require Output Allocator. Please disable CUDA Graphs.", + ): + with torch_tensorrt.runtime.enable_output_allocator(compiled_model): + with torch_tensorrt.runtime.enable_cudagraphs( + compiled_model + ) as cudagraphs_module: + out = cudagraphs_module(*inputs) + + +class TestOutputAllocatorDDSModelWithGraphBreak(TestCase): + @parameterized.expand( + [ + ("python_runtime", True), + ("cpp_runtime", False), + ] ) - def test_cudagraphs(self): + def test_cudagraphs_and_output_allocator(self, _, use_python_runtime): model = DDSModel2().eval().cuda() inputs = (torch.randint(low=0, high=3, size=(10,), dtype=torch.int).to("cuda"),) compiled_model = torch_tensorrt.compile( @@ -398,42 +376,40 @@ def test_cudagraphs(self): "dynamo", inputs, min_block_size=1, - use_python_runtime=True, + use_python_runtime=use_python_runtime, torch_executed_ops={"torch.ops.aten.abs.default"}, ) - with torch_tensorrt.runtime.enable_cudagraphs( - compiled_model - ) as cudagraphs_module: - cg_out = cudagraphs_module(*inputs) - - def test_output_allocator(self): - model = DDSModel2().eval().cuda() - inputs = (torch.randint(low=0, high=3, size=(10,), dtype=torch.int).to("cuda"),) - compiled_model = torch_tensorrt.compile( - model, - "dynamo", - inputs, - min_block_size=1, - use_python_runtime=True, - torch_executed_ops={"torch.ops.aten.abs.default"}, - ) - - ref_out = model(*inputs) + with pytest.raises( + RuntimeError, + match="There are converters that require Output Allocator. 
Please disable CUDA Graphs.", + ): + with torch_tensorrt.runtime.enable_cudagraphs( + compiled_model + ) as cudagraphs_module: + cg_out = cudagraphs_module(*inputs) with torch_tensorrt.runtime.enable_output_allocator(compiled_model): oa_out = compiled_model(*inputs) + ref_out = model(*inputs) + self.assertAlmostEqual( float(torch.max(torch.abs(ref_out - oa_out))), 0, DECIMALS_OF_AGREEMENT, - msg="Output Allocator Python TRT outputs don't match with the original model.", + msg="Output Allocator runtime outputs don't match with the original model.", ) - def test_default(self): + @parameterized.expand( + [ + ("python_runtime", True), + ("cpp_runtime", False), + ] + ) + def test_default(self, _, use_python_runtime): """ - Use OutputAllocator by default. + Use Output Allocator by default. """ model = DDSModel2().eval().cuda() inputs = (torch.randint(low=0, high=3, size=(10,), dtype=torch.int).to("cuda"),) @@ -442,7 +418,7 @@ def test_default(self): "dynamo", inputs, min_block_size=1, - use_python_runtime=True, + use_python_runtime=use_python_runtime, torch_executed_ops={"torch.ops.aten.abs.default"}, ) oa_out = compiled_model(*inputs) @@ -452,15 +428,16 @@ def test_default(self): float(torch.max(torch.abs(ref_out - oa_out))), 0, DECIMALS_OF_AGREEMENT, - msg="Default Output Allocator Python TRT outputs don't match with the original model.", + msg="Default Output Allocator runtime outputs don't match with the original model.", ) - @pytest.mark.xfail( - strict=True, - raises=RuntimeError, - reason="Both CUDA Graphs and OutputAllocator are enabled. Please disable either one.", + @parameterized.expand( + [ + ("python_runtime", True), + ("cpp_runtime", False), + ] ) - def test_cudagraphs_and_output_allocator(self): + def test_combination_of_cg_and_oa(self, _, use_python_runtime): model = DDSModel2().eval().cuda() inputs = (torch.randint(low=0, high=3, size=(10,), dtype=torch.int).to("cuda"),) compiled_model = torch_tensorrt.compile( @@ -468,37 +445,29 @@ def test_cudagraphs_and_output_allocator(self): "dynamo", inputs, min_block_size=1, - use_python_runtime=True, + use_python_runtime=use_python_runtime, torch_executed_ops={"torch.ops.aten.abs.default"}, ) - with torch_tensorrt.runtime.enable_cudagraphs( - compiled_model - ) as cudagraphs_module: - with torch_tensorrt.runtime.enable_output_allocator(cudagraphs_module): - out = cudagraphs_module(*inputs) - - @pytest.mark.xfail( - strict=True, - raises=RuntimeError, - reason="Both CUDA Graphs and OutputAllocator are enabled. Please disable either one.", - ) - def test_output_allocator_and_cudagraphs(self): - model = DDSModel2().eval().cuda() - inputs = (torch.randint(low=0, high=3, size=(10,), dtype=torch.int).to("cuda"),) - compiled_model = torch_tensorrt.compile( - model, - "dynamo", - inputs, - min_block_size=1, - use_python_runtime=True, - torch_executed_ops={"torch.ops.aten.abs.default"}, - ) - with torch_tensorrt.runtime.enable_output_allocator(compiled_model): + with pytest.raises( + RuntimeError, + match="There are converters that require Output Allocator. Please disable CUDA Graphs.", + ): with torch_tensorrt.runtime.enable_cudagraphs( compiled_model ) as cudagraphs_module: - out = cudagraphs_module(*inputs) + with torch_tensorrt.runtime.enable_output_allocator(cudagraphs_module): + out = cudagraphs_module(*inputs) + + with pytest.raises( + RuntimeError, + match="There are converters that require Output Allocator. 
Please disable CUDA Graphs.", + ): + with torch_tensorrt.runtime.enable_output_allocator(compiled_model): + with torch_tensorrt.runtime.enable_cudagraphs( + compiled_model + ) as cudagraphs_module: + out = cudagraphs_module(*inputs) if __name__ == "__main__": From 5199067884123894e4b0da01421e047883580d54 Mon Sep 17 00:00:00 2001 From: Evan Li Date: Mon, 10 Mar 2025 21:19:40 -0700 Subject: [PATCH 11/17] minor fixes --- core/runtime/TRTEngine.cpp | 2 +- core/runtime/execute_engine.cpp | 10 ++++++---- .../dynamo/runtime/_PythonTorchTensorRTModule.py | 2 ++ 3 files changed, 9 insertions(+), 5 deletions(-) diff --git a/core/runtime/TRTEngine.cpp b/core/runtime/TRTEngine.cpp index 4b67de7e42..9f93fe4b4e 100644 --- a/core/runtime/TRTEngine.cpp +++ b/core/runtime/TRTEngine.cpp @@ -42,7 +42,7 @@ void* DynamicOutputAllocator::reallocateOutputAsync( std::vector shape = {static_cast(size)}; auto it = buffers.find(tensorName); if (it == buffers.end() || it->second.sizes() != shape) { - buffers[tensorName] = at::empty(shape, at::TensorOptions().dtype(dtypes.at(tensorName)).device(c10::kCUDA)); + buffers[tensorName] = at::empty(shape, at::TensorOptions().dtype(dtypes.at(tensorName)).device(at::kCUDA)); return buffers[tensorName].data_ptr(); } else { return it->second.data_ptr(); diff --git a/core/runtime/execute_engine.cpp b/core/runtime/execute_engine.cpp index c857cff6c8..c35fb46a1e 100644 --- a/core/runtime/execute_engine.cpp +++ b/core/runtime/execute_engine.cpp @@ -379,7 +379,7 @@ std::vector execute_engine(std::vector inputs, c10::intr if (inputs.size() > 0) { current_device_id = inputs[0].device().index(); // Done this way to avoid a call to cudart } else { - current_device_id = c10::cuda::current_device(); + current_device_id = at::cuda::current_device(); } compiled_engine->caller_stream = c10::cuda::getCurrentCUDAStream(current_device_id); @@ -428,11 +428,13 @@ std::vector execute_engine(std::vector inputs, c10::intr for (int i = 0; i < dims.nbDims; ++i) { prod *= dims.d[i]; } - std::vector dims_vec(dims.nbDims); + std::vector shape(dims.nbDims); for (int i = 0; i < dims.nbDims; ++i) { - dims_vec[i] = dims.d[i]; + shape[i] = dims.d[i]; } - output = output.reshape(-1).view(dtype).slice(0, 0, prod).reshape(dims_vec); + // When using the OutputAllocator, the allocated buffer might be larger than the size of the output, + // so we need to reshape the buffer to the output shape + output = output.reshape(-1).view(dtype).slice(0, 0, prod).reshape(shape); outputs.push_back(output); } diff --git a/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py b/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py index d0a1cf6c2e..f4e25dcedd 100644 --- a/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py +++ b/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py @@ -624,6 +624,8 @@ def run_output_allocator() -> torch.Tensor | Tuple[torch.Tensor, ...]: .detach() ) prod = int(torch.prod(torch.tensor(shape))) + # When using the OutputAllocator, the allocated buffer might be larger than the size of the output, + # so we need to reshape the buffer to the output shape output = output.reshape(-1).view(dtype)[:prod].reshape(shape) outputs.append(output) From ed558d9675211190a00509e74db13cce067c54c5 Mon Sep 17 00:00:00 2001 From: Evan Li Date: Tue, 11 Mar 2025 12:13:55 -0700 Subject: [PATCH 12/17] fix comments --- core/runtime/execute_engine.cpp | 11 +++++----- core/runtime/runtime.h | 2 +- .../runtime/_CudaGraphsTorchTensorRTModule.py | 2 +- 
.../runtime/_PythonTorchTensorRTModule.py | 12 +++++----- .../dynamo/runtime/_TorchTensorRTModule.py | 2 +- py/torch_tensorrt/runtime/_cudagraphs.py | 2 +- .../runtime/_output_allocator.py | 2 +- .../runtime/test_output_allocator_py.py | 22 +++++++++---------- 8 files changed, 28 insertions(+), 27 deletions(-) diff --git a/core/runtime/execute_engine.cpp b/core/runtime/execute_engine.cpp index c35fb46a1e..64b111750f 100644 --- a/core/runtime/execute_engine.cpp +++ b/core/runtime/execute_engine.cpp @@ -520,19 +520,20 @@ std::vector execute_engine(std::vector inputs, c10::intr if (compiled_engine->requires_output_allocator) { // engine requires OA if (cudagraphs_enabled) { TORCHTRT_THROW_ERROR( - "This module requires OutputAllocator which is not compatible with CUDA Graphs. Please disable CUDA Graphs."); + "The model contains submodules that require a dynamic output allocator at runtime, which is incompatible with CUDA Graphs. Please disable CUDA Graphs."); } - LOG_DEBUG("Using OutputAllocator in runtime."); + LOG_DEBUG("Using the dynamic allocator runtime mode."); return run_output_allocator(); } else { if (compiled_engine->use_output_allocator_outputs) { // users call OA context manager if (cudagraphs_enabled) { - TORCHTRT_THROW_ERROR("Both CUDA Graphs and OutputAllocator are enabled. Please disable either one."); + TORCHTRT_THROW_ERROR( + "Both CUDA Graphs and dynamic output allocation are enabled, which are incompatible runtime modes. Please disable one of the two."); } - LOG_DEBUG("Using OutputAllocator in runtime."); + LOG_DEBUG("Using the dynamic allocator runtime mode."); return run_output_allocator(); } else { - LOG_DEBUG("Using standard execution with cudagraphs=" << cudagraphs_enabled << "."); + LOG_DEBUG("Using the standard execution runtime mode with cudagraphs=" << cudagraphs_enabled << "."); return run_standard_execution(); } } diff --git a/core/runtime/runtime.h b/core/runtime/runtime.h index 3a2fe9c24b..894df55bfe 100644 --- a/core/runtime/runtime.h +++ b/core/runtime/runtime.h @@ -16,7 +16,7 @@ namespace core { namespace runtime { using EngineID = int64_t; -const std::string ABI_VERSION = "6"; +const std::string ABI_VERSION = "7"; extern bool MULTI_DEVICE_SAFE_MODE; typedef enum { diff --git a/py/torch_tensorrt/dynamo/runtime/_CudaGraphsTorchTensorRTModule.py b/py/torch_tensorrt/dynamo/runtime/_CudaGraphsTorchTensorRTModule.py index 7ec8930983..b3ac25bc3a 100644 --- a/py/torch_tensorrt/dynamo/runtime/_CudaGraphsTorchTensorRTModule.py +++ b/py/torch_tensorrt/dynamo/runtime/_CudaGraphsTorchTensorRTModule.py @@ -74,7 +74,7 @@ def __del__(self) -> None: if self.cudagraph: self.cudagraph.reset() - def set_output_allocator_outputs(self, enable: bool) -> None: + def set_use_output_allocator(self, enable: bool) -> None: self.use_output_allocator_outputs = enable def forward(self, *inputs: torch.Tensor) -> torch.Tensor | Tuple[torch.Tensor, ...]: diff --git a/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py b/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py index f4e25dcedd..891d063ed3 100644 --- a/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py +++ b/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py @@ -404,7 +404,7 @@ def create_output_tensors(self) -> List[torch.Tensor]: def set_pre_allocated_outputs(self, enable: bool) -> None: self.use_pre_allocated_outputs = enable - def set_output_allocator_outputs(self, enable: bool) -> None: + def set_use_output_allocator(self, enable: bool) -> None: self.use_output_allocator_outputs 
= enable def create_output_allocator(self) -> None: @@ -683,21 +683,21 @@ def run_output_allocator() -> torch.Tensor | Tuple[torch.Tensor, ...]: if self.requires_output_allocator: # engine requires OA if self.cudagraphs_enabled: raise RuntimeError( - "This module requires OutputAllocator which is not compatible with CUDA Graphs. Please disable CUDA Graphs." + "The model contains submodules that require a dynamic output allocator at runtime, which is incompatible with CUDA Graphs. Please disable CUDA Graphs." ) - logger.debug("Using OutputAllocator in runtime.") + logger.debug("Using the dynamic allocator runtime mode.") return run_output_allocator() else: if self.use_output_allocator_outputs: # users call OA context manager if self.cudagraphs_enabled: raise RuntimeError( - "Both CUDA Graphs and OutputAllocator are enabled. Please disable either one." + "Both CUDA Graphs and dynamic output allocation are enabled, which are incompatible runtime modes. Please disable one of the two." ) - logger.debug("Using OutputAllocator in runtime.") + logger.debug("Using the dynamic allocator runtime mode.") return run_output_allocator() else: logger.debug( - f"Using standard execution with cudagraphs={self.cudagraphs_enabled}." + f"Using the standard execution runtime mode with cudagraphs={self.cudagraphs_enabled}." ) return run_standard_execution() diff --git a/py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py b/py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py index 5f49326e28..e6b6a21421 100644 --- a/py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py +++ b/py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py @@ -293,7 +293,7 @@ def set_extra_state(self, state: SerializedTorchTensorRTModuleFmt) -> None: def set_pre_allocated_outputs(self, enable: bool) -> None: self.engine.use_pre_allocated_outputs = enable - def set_output_allocator_outputs(self, enable: bool) -> None: + def set_use_output_allocator(self, enable: bool) -> None: self.engine.use_output_allocator_outputs = enable def forward(self, *inputs: Any) -> torch.Tensor | Tuple[torch.Tensor, ...]: diff --git a/py/torch_tensorrt/runtime/_cudagraphs.py b/py/torch_tensorrt/runtime/_cudagraphs.py index 1802df7a9f..c771564826 100644 --- a/py/torch_tensorrt/runtime/_cudagraphs.py +++ b/py/torch_tensorrt/runtime/_cudagraphs.py @@ -81,7 +81,7 @@ def __enter__(self) -> torch.nn.Module: and module.requires_output_allocator ): raise RuntimeError( - "There are converters that require Output Allocator. Please disable CUDA Graphs." + "The model contains submodules that require a dynamic output allocator at runtime, which is incompatible with CUDA Graphs. Please disable CUDA Graphs." 
) if "_run_on_acc" in name: num_trt_module += 1 diff --git a/py/torch_tensorrt/runtime/_output_allocator.py b/py/torch_tensorrt/runtime/_output_allocator.py index 4b0d644231..163fc26306 100644 --- a/py/torch_tensorrt/runtime/_output_allocator.py +++ b/py/torch_tensorrt/runtime/_output_allocator.py @@ -33,7 +33,7 @@ def __init__( def set_output_allocator_output(self, enable: bool) -> None: for mod in self.rt_mods: - mod.set_output_allocator_outputs(enable) + mod.set_use_output_allocator(enable) def __enter__(self) -> "_OutputAllocatorContextManager": # Enable output_allocator for TRT submodules diff --git a/tests/py/dynamo/runtime/test_output_allocator_py.py b/tests/py/dynamo/runtime/test_output_allocator_py.py index 7065df1649..5466f1bce7 100644 --- a/tests/py/dynamo/runtime/test_output_allocator_py.py +++ b/tests/py/dynamo/runtime/test_output_allocator_py.py @@ -126,7 +126,7 @@ def test_combination_of_cg_and_oa(self, _, use_python_runtime): with pytest.raises( RuntimeError, - match="Both CUDA Graphs and OutputAllocator are enabled. Please disable either one.", + match="Both CUDA Graphs and dynamic output allocation are enabled, which are incompatible runtime modes. Please disable one of the two.", ): with torch_tensorrt.runtime.enable_cudagraphs( compiled_model @@ -136,7 +136,7 @@ def test_combination_of_cg_and_oa(self, _, use_python_runtime): with pytest.raises( RuntimeError, - match="Both CUDA Graphs and OutputAllocator are enabled. Please disable either one.", + match="Both CUDA Graphs and dynamic output allocation are enabled, which are incompatible runtime modes. Please disable one of the two.", ): with torch_tensorrt.runtime.enable_output_allocator(compiled_model): with torch_tensorrt.runtime.enable_cudagraphs( @@ -165,7 +165,7 @@ def test_cudagraphs_and_output_allocator(self, _, use_python_runtime): with pytest.raises( RuntimeError, - match="There are converters that require Output Allocator. Please disable CUDA Graphs.", + match="The model contains submodules that require a dynamic output allocator at runtime, which is incompatible with CUDA Graphs. Please disable CUDA Graphs.", ): with torch_tensorrt.runtime.enable_cudagraphs( compiled_model @@ -232,7 +232,7 @@ def test_combination_of_cg_and_oa(self, _, use_python_runtime): with pytest.raises( RuntimeError, - match="There are converters that require Output Allocator. Please disable CUDA Graphs.", + match="The model contains submodules that require a dynamic output allocator at runtime, which is incompatible with CUDA Graphs. Please disable CUDA Graphs.", ): with torch_tensorrt.runtime.enable_cudagraphs( compiled_model @@ -242,7 +242,7 @@ def test_combination_of_cg_and_oa(self, _, use_python_runtime): with pytest.raises( RuntimeError, - match="There are converters that require Output Allocator. Please disable CUDA Graphs.", + match="The model contains submodules that require a dynamic output allocator at runtime, which is incompatible with CUDA Graphs. Please disable CUDA Graphs.", ): with torch_tensorrt.runtime.enable_output_allocator(compiled_model): with torch_tensorrt.runtime.enable_cudagraphs( @@ -275,7 +275,7 @@ def test_cudagraphs_and_output_allocator(self, _, use_python_runtime): with pytest.raises( RuntimeError, - match="There are converters that require Output Allocator. Please disable CUDA Graphs.", + match="The model contains submodules that require a dynamic output allocator at runtime, which is incompatible with CUDA Graphs. 
Please disable CUDA Graphs.", ): with torch_tensorrt.runtime.enable_cudagraphs( compiled_model @@ -342,7 +342,7 @@ def test_combination_of_cg_and_oa(self, _, use_python_runtime): with pytest.raises( RuntimeError, - match="There are converters that require Output Allocator. Please disable CUDA Graphs.", + match="The model contains submodules that require a dynamic output allocator at runtime, which is incompatible with CUDA Graphs. Please disable CUDA Graphs.", ): with torch_tensorrt.runtime.enable_cudagraphs( compiled_model @@ -352,7 +352,7 @@ def test_combination_of_cg_and_oa(self, _, use_python_runtime): with pytest.raises( RuntimeError, - match="There are converters that require Output Allocator. Please disable CUDA Graphs.", + match="The model contains submodules that require a dynamic output allocator at runtime, which is incompatible with CUDA Graphs. Please disable CUDA Graphs.", ): with torch_tensorrt.runtime.enable_output_allocator(compiled_model): with torch_tensorrt.runtime.enable_cudagraphs( @@ -382,7 +382,7 @@ def test_cudagraphs_and_output_allocator(self, _, use_python_runtime): with pytest.raises( RuntimeError, - match="There are converters that require Output Allocator. Please disable CUDA Graphs.", + match="The model contains submodules that require a dynamic output allocator at runtime, which is incompatible with CUDA Graphs. Please disable CUDA Graphs.", ): with torch_tensorrt.runtime.enable_cudagraphs( compiled_model @@ -451,7 +451,7 @@ def test_combination_of_cg_and_oa(self, _, use_python_runtime): with pytest.raises( RuntimeError, - match="There are converters that require Output Allocator. Please disable CUDA Graphs.", + match="The model contains submodules that require a dynamic output allocator at runtime, which is incompatible with CUDA Graphs. Please disable CUDA Graphs.", ): with torch_tensorrt.runtime.enable_cudagraphs( compiled_model @@ -461,7 +461,7 @@ def test_combination_of_cg_and_oa(self, _, use_python_runtime): with pytest.raises( RuntimeError, - match="There are converters that require Output Allocator. Please disable CUDA Graphs.", + match="The model contains submodules that require a dynamic output allocator at runtime, which is incompatible with CUDA Graphs. Please disable CUDA Graphs.", ): with torch_tensorrt.runtime.enable_output_allocator(compiled_model): with torch_tensorrt.runtime.enable_cudagraphs( From 4c91bc8d0d790ca9617c0a2448d012e471a88801 Mon Sep 17 00:00:00 2001 From: Evan Li Date: Wed, 12 Mar 2025 16:05:57 -0700 Subject: [PATCH 13/17] add doc --- docs/_sources/py_api/runtime.rst.txt | 2 +- docsrc/py_api/runtime.rst | 4 ++- docsrc/user_guide/runtime.rst | 38 ++++++++++++++++++++++++ examples/dynamo/converter_overloading.py | 7 +++-- 4 files changed, 46 insertions(+), 5 deletions(-) diff --git a/docs/_sources/py_api/runtime.rst.txt b/docs/_sources/py_api/runtime.rst.txt index 4ece390816..55a31933ee 100644 --- a/docs/_sources/py_api/runtime.rst.txt +++ b/docs/_sources/py_api/runtime.rst.txt @@ -19,7 +19,7 @@ Functions .. autofunction:: get_whole_cudagraphs_mode -.. autofunction:: set_cudagraphs_modue +.. autofunction:: set_cudagraphs_mode .. autofunction:: enable_pre_allocated_outputs diff --git a/docsrc/py_api/runtime.rst b/docsrc/py_api/runtime.rst index 4ece390816..719d8f6555 100644 --- a/docsrc/py_api/runtime.rst +++ b/docsrc/py_api/runtime.rst @@ -19,12 +19,14 @@ Functions .. autofunction:: get_whole_cudagraphs_mode -.. autofunction:: set_cudagraphs_modue +.. autofunction:: set_cudagraphs_mode .. 
autofunction:: enable_pre_allocated_outputs .. autofunction:: weight_streaming +.. autofunction:: enable_output_allocator + Classes --------- diff --git a/docsrc/user_guide/runtime.rst b/docsrc/user_guide/runtime.rst index 8672fdebe4..b0a44b0df4 100644 --- a/docsrc/user_guide/runtime.rst +++ b/docsrc/user_guide/runtime.rst @@ -92,3 +92,41 @@ Cudagraphs can accelerate certain models by reducing kernel overheads, as docume In the current implementation, use of a new input shape (for instance in dynamic shape cases), will cause the cudagraph to be re-recorded. Cudagraph recording is generally not latency intensive, and future improvements include caching cudagraphs for multiple input shapes. + +Dynamic Output Allocation Mode +------------------------------ + +Dynamic output allocation is a feature in Torch-TensorRT which allows the output buffer of TensorRT engines to be +dynamically allocated. This is useful for models with dynamic output shapes, especially ops with data-dependent shapes. +Without dynamic output allocation, the output buffer is statically allocated and the size is the maximum possible size +required by the op. This can lead to inefficient memory usage if the actual output size is smaller than the maximum possible size. + +There are two scenarios in which dynamic output allocation is enabled: + +1. When the model contains submodules that require a dynamic output allocator at runtime, users don't have to manually enable dynamic output allocation mode. + +To specify if a module requires a dynamic output allocator, users can set the ``requires_output_allocator=True`` flag in the ``@dynamo_tensorrt_converter`` decorator of converters. e.g., + +.. code-block:: python + + @dynamo_tensorrt_converter( + torch.ops.aten.nonzero.default, + supports_dynamic_shapes=True, + requires_output_allocator=True, + ) + def aten_ops_nonzero( + ctx: ConversionContext, + target: Target, + args: Tuple[Argument, ...], + kwargs: Dict[str, Argument], + name: str, + ) -> Union[TRTTensor, Sequence[TRTTensor]]: + ... + +2. When users manually enable dynamic output allocation via the ``torch_tensorrt.runtime.enable_output_allocator`` context manager. + +.. code-block:: python + + # Enables Dynamic Output Allocation Mode, then resets the mode to its prior setting + with torch_tensorrt.runtime.enable_output_allocator(trt_module): + ... diff --git a/examples/dynamo/converter_overloading.py b/examples/dynamo/converter_overloading.py index dc25a18287..e27c53cb50 100644 --- a/examples/dynamo/converter_overloading.py +++ b/examples/dynamo/converter_overloading.py @@ -58,12 +58,11 @@ def forward(self, x): from typing import Dict, Sequence, Tuple, Union +import tensorrt as trt from torch.fx.node import Argument, Node, Target from torch_tensorrt.dynamo import CompilationSettings from torch_tensorrt.dynamo.conversion import ConversionContext -import tensorrt as trt - # %% # Converter Metadata # ^^^^^^^^^^^^^^^^^^^^^^^^^^^ @@ -80,6 +79,8 @@ def forward(self, x): supports_dynamic_shapes=True, # Set the priority of the converter to supersede the default one priority=torch_tensorrt.dynamo.conversion.ConverterPriority.HIGH, + # Whether the converter requires a dynamic output allocator to run (e.g. data dependent ops) + requires_output_allocator=True, ) # %% @@ -98,7 +99,7 @@ def forward(self, x): # # Finally there is the ``priority`` argument, which is an enum from the ``torch_tensorrt.dynamo.conversion.ConverterPriority`` class that defines the priority of the converter. The two options are ``HIGH`` and ``STANDARD``. 
 # Converters registered with ``STANDARD`` will be appended to the converter list for a given operation, while converters registered with ``HIGH`` will be prepended to the list.
-# Candidate converters are evalated for their suitablity in this priority order and the first converter that passes the validator is used.
+# Candidate converters are evaluated for their suitability in this priority order and the first converter that passes the validator is used.
 # %%

From 5193d3572defd4fd0af9d86a3dc1713b16d56703 Mon Sep 17 00:00:00 2001
From: Evan Li
Date: Wed, 12 Mar 2025 18:01:12 -0700
Subject: [PATCH 14/17] fix doc

---
 docs/_sources/py_api/runtime.rst.txt |  2 +-
 docsrc/user_guide/runtime.rst        | 15 +++++++++------
 2 files changed, 10 insertions(+), 7 deletions(-)

diff --git a/docs/_sources/py_api/runtime.rst.txt b/docs/_sources/py_api/runtime.rst.txt
index 55a31933ee..4ece390816 100644
--- a/docs/_sources/py_api/runtime.rst.txt
+++ b/docs/_sources/py_api/runtime.rst.txt
@@ -19,7 +19,7 @@ Functions

 .. autofunction:: get_whole_cudagraphs_mode

-.. autofunction:: set_cudagraphs_mode
+.. autofunction:: set_cudagraphs_modue

 .. autofunction:: enable_pre_allocated_outputs

diff --git a/docsrc/user_guide/runtime.rst b/docsrc/user_guide/runtime.rst
index b0a44b0df4..5ca842514e 100644
--- a/docsrc/user_guide/runtime.rst
+++ b/docsrc/user_guide/runtime.rst
@@ -97,15 +97,18 @@ Dynamic Output Allocation Mode
 ------------------------------

 Dynamic output allocation is a feature in Torch-TensorRT which allows the output buffer of TensorRT engines to be
-dynamically allocated. This is useful for models with dynamic output shapes, especially ops with data-dependent shapes.
-Without dynamic output allocation, the output buffer is statically allocated and the size is the maximum possible size
-required by the op. This can lead to inefficient memory usage if the actual output size is smaller than the maximum possible size.
+dynamically allocated. This is useful for models with dynamic output shapes, especially ops with data-dependent shapes.
+Dynamic output allocation mode cannot be used in conjunction with CUDA Graphs or the pre-allocated outputs feature.
+Without dynamic output allocation, the output buffer is allocated based on the output shape inferred from the input size.

 There are two scenarios in which dynamic output allocation is enabled:

-1. When the model contains submodules that require a dynamic output allocator at runtime, users don't have to manually enable dynamic output allocation mode.
+1. The model has been identified at compile time to require dynamic output allocation for at least one TensorRT subgraph.
+These models will engage the runtime mode automatically (with logging) and are incompatible with other runtime modes
+such as CUDA Graphs.

-To specify if a module requires a dynamic output allocator, users can set the ``requires_output_allocator=True`` flag in the ``@dynamo_tensorrt_converter`` decorator of converters. e.g.,
+Converters can declare that the subgraphs they produce will require the output allocator using ``requires_output_allocator=True``,
+thereby forcing any model which utilizes the converter to automatically use the output allocator runtime mode. e.g.,

 .. code-block:: python

     @dynamo_tensorrt_converter(
         torch.ops.aten.nonzero.default,
         supports_dynamic_shapes=True,
         requires_output_allocator=True,
     )
     def aten_ops_nonzero(
         ctx: ConversionContext,
         target: Target,
         args: Tuple[Argument, ...],
         kwargs: Dict[str, Argument],
         name: str,
     ) -> Union[TRTTensor, Sequence[TRTTensor]]:
         ...

-2. 
Users may manually enable dynamic output allocation mode via the ``torch_tensorrt.runtime.enable_output_allocator`` context manager. .. code-block:: python From 8cedc031199d6d1d07440dca1a21874a9b773889 Mon Sep 17 00:00:00 2001 From: Evan Li Date: Thu, 13 Mar 2025 10:54:35 -0700 Subject: [PATCH 15/17] update naming --- .../runtime/test_output_allocator_py.py | 23 +++++++++++-------- 1 file changed, 14 insertions(+), 9 deletions(-) diff --git a/tests/py/dynamo/runtime/test_output_allocator_py.py b/tests/py/dynamo/runtime/test_output_allocator_py.py index 5466f1bce7..c915f42173 100644 --- a/tests/py/dynamo/runtime/test_output_allocator_py.py +++ b/tests/py/dynamo/runtime/test_output_allocator_py.py @@ -20,7 +20,12 @@ def forward(self, input): return torch.ops.aten.nonzero.default(input) -class NonDDSModel(torch.nn.Module): +class DDSOpWithReductionOpModel(torch.nn.Module): + """ + DDSOpWithReductionOpModel is a model that contains DDS op + reduction op. + Since nonzero requires output allocator, this model will use output allocator by default. + """ + def forward(self, inputs): out = torch.ops.aten.nonzero.default(inputs) out = torch.ops.aten.sum.dim_IntList(out, 0) @@ -251,9 +256,9 @@ def test_combination_of_cg_and_oa(self, _, use_python_runtime): out = cudagraphs_module(*inputs) -class TestOutputAllocatorNonDDSModel(TestCase): +class TestOutputAllocatorDDSOpWithReductionOpModel(TestCase): """ - The NonDDSModel is a model that contains DDS op + reduction op. + The DDSOpWithReductionOpModel is a model that contains DDS op + reduction op. """ @parameterized.expand( @@ -263,7 +268,7 @@ class TestOutputAllocatorNonDDSModel(TestCase): ] ) def test_cudagraphs_and_output_allocator(self, _, use_python_runtime): - model = NonDDSModel().eval().cuda() + model = DDSOpWithReductionOpModel().eval().cuda() inputs = (torch.randint(low=0, high=3, size=(10,), dtype=torch.int).to("cuda"),) compiled_model = torch_tensorrt.compile( model, @@ -302,9 +307,9 @@ def test_cudagraphs_and_output_allocator(self, _, use_python_runtime): ) def test_default(self, _, use_python_runtime): """ - NonDDS models use standard execution with cudagraphs=False by default. + The DDSOpWithReductionOpModel is a model that contains nonzero op + reduction op, in which nonzero op requires output allocator. 
""" - model = NonDDSModel().eval().cuda() + model = DDSOpWithReductionOpModel().eval().cuda() inputs = (torch.randint(low=0, high=3, size=(10,), dtype=torch.int).to("cuda"),) compiled_model = torch_tensorrt.compile( model, @@ -313,11 +318,11 @@ def test_default(self, _, use_python_runtime): min_block_size=1, use_python_runtime=use_python_runtime, ) - standard_out = compiled_model(*inputs) + oa_out = compiled_model(*inputs) ref_out = model(*inputs) self.assertAlmostEqual( - float(torch.max(torch.abs(ref_out - standard_out))), + float(torch.max(torch.abs(ref_out - oa_out))), 0, DECIMALS_OF_AGREEMENT, msg="Default Output Allocator runtime outputs don't match with the original model.", @@ -330,7 +335,7 @@ def test_default(self, _, use_python_runtime): ] ) def test_combination_of_cg_and_oa(self, _, use_python_runtime): - model = NonDDSModel().eval().cuda() + model = DDSOpWithReductionOpModel().eval().cuda() inputs = (torch.randint(low=0, high=3, size=(10,), dtype=torch.int).to("cuda"),) compiled_model = torch_tensorrt.compile( model, From 8e3de2354579b8785563371825484079d55e8f2c Mon Sep 17 00:00:00 2001 From: Evan Li Date: Thu, 13 Mar 2025 12:03:17 -0700 Subject: [PATCH 16/17] rebase --- .../lowering/passes/_aten_lowering_pass.py | 28 ++++++++++--------- 1 file changed, 15 insertions(+), 13 deletions(-) diff --git a/py/torch_tensorrt/dynamo/lowering/passes/_aten_lowering_pass.py b/py/torch_tensorrt/dynamo/lowering/passes/_aten_lowering_pass.py index 1398b0687e..2ecc45ecf3 100644 --- a/py/torch_tensorrt/dynamo/lowering/passes/_aten_lowering_pass.py +++ b/py/torch_tensorrt/dynamo/lowering/passes/_aten_lowering_pass.py @@ -17,19 +17,21 @@ from .repair_input_as_output import repair_input_as_output from .replace_max_pool_with_indices import replace_max_pool_with_indices -ATEN_POST_LOWERING_PASSES = DynamoPassManager.build_from_passlist( - [ - remove_input_alias_fixing_clones, - constant_fold, - repair_input_as_output, - fuse_prims_broadcast, - fuse_distributed_ops, - replace_max_pool_with_indices, - remove_assert_nodes, - accumulate_fp32_matmul, - remove_num_users_is_0_nodes, - ] -) +pass_list = [ + remove_input_alias_fixing_clones, + constant_fold, + repair_input_as_output, + fuse_prims_broadcast, + replace_max_pool_with_indices, + remove_assert_nodes, + accumulate_fp32_matmul, + remove_num_users_is_0_nodes, +] + +if not is_tegra_platform(): + pass_list.append(fuse_distributed_ops) + +ATEN_POST_LOWERING_PASSES = DynamoPassManager.build_from_passlist(pass_list) ATEN_PRE_LOWERING_PASSES = DynamoPassManager.build_from_passlist( [ From ec572332676ca2ce39224792822ecf71124fd063 Mon Sep 17 00:00:00 2001 From: Evan Li Date: Thu, 13 Mar 2025 14:20:59 -0700 Subject: [PATCH 17/17] rename test file --- .../{test_output_allocator_py.py => test_output_allocator.py} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename tests/py/dynamo/runtime/{test_output_allocator_py.py => test_output_allocator.py} (100%) diff --git a/tests/py/dynamo/runtime/test_output_allocator_py.py b/tests/py/dynamo/runtime/test_output_allocator.py similarity index 100% rename from tests/py/dynamo/runtime/test_output_allocator_py.py rename to tests/py/dynamo/runtime/test_output_allocator.py