diff --git a/core/lowering/lowering.cpp b/core/lowering/lowering.cpp index d3296c347c..7b6f0d5908 100644 --- a/core/lowering/lowering.cpp +++ b/core/lowering/lowering.cpp @@ -39,7 +39,7 @@ void LowerGraph(std::shared_ptr& g, LowerInfo lower_info) { } torch::jit::EliminateDeadCode(g); if (lower_info.forced_fallback_modules.size() > 0) { - passes::MarkNodesForFallback(g, true); + passes::MarkNodesForFallback(g, true, lower_info.default_torch_execution); } passes::UnpackHardSwish(g); passes::EliminateExceptionOrPassPattern(g); diff --git a/core/lowering/lowering.h b/core/lowering/lowering.h index 22ce9101d9..eb87ca538f 100644 --- a/core/lowering/lowering.h +++ b/core/lowering/lowering.h @@ -15,6 +15,7 @@ struct LowerInfo { // Since these QDQ nodes will be identical as they share same input, one of them is eliminated due to CSE lowering // pass. Disable this in order to not disturb TensorRT's QAT optimizations. bool disable_cse = false; + bool default_torch_execution = false; std::vector forced_fallback_modules; friend std::ostream& operator<<(std::ostream& os, const LowerInfo& l); }; diff --git a/core/lowering/passes/module_fallback.cpp b/core/lowering/passes/module_fallback.cpp index 415b385634..49e713c60c 100644 --- a/core/lowering/passes/module_fallback.cpp +++ b/core/lowering/passes/module_fallback.cpp @@ -92,7 +92,7 @@ void NotateModuleForFallback( } } -void MarkNodesForFallback(std::shared_ptr& g, bool delete_delims) { +void MarkNodesForFallback(std::shared_ptr& g, bool delete_delims, bool default_torch_execution) { auto b = g->block(); std::stack mark = std::stack({false}); @@ -126,7 +126,7 @@ void MarkNodesForFallback(std::shared_ptr& g, bool delete_del if (n->s(c10::Symbol::attr("compilation_edge")) == "end") { LOG_WARNING("Found the end of segmented block targeted for torch while not actively marking a block"); } - } else if (mark.top()) { + } else if ((!mark.top() && default_torch_execution) || (mark.top() && !default_torch_execution)) { LOG_GRAPH("Marking 
" << util::node_info(n) << " to run in PyTorch"); n->i_(c10::Symbol::attr("to_compile"), (int64_t) false); } diff --git a/core/lowering/passes/passes.h b/core/lowering/passes/passes.h index 73bd9f61d7..5af5602db3 100644 --- a/core/lowering/passes/passes.h +++ b/core/lowering/passes/passes.h @@ -22,7 +22,7 @@ void EliminateExceptionOrPassPattern(std::shared_ptr graph); void ReduceToOperation(std::shared_ptr& graph); void ReduceGelu(std::shared_ptr& graph); void ReduceRemainder(std::shared_ptr& graph); -void MarkNodesForFallback(std::shared_ptr& g, bool delete_delims); +void MarkNodesForFallback(std::shared_ptr& g, bool delete_delims, bool default_torch_execution); void RemoveBNDimCheck(std::shared_ptr graph); void RemoveContiguous(std::shared_ptr& graph); void ViewToReshape(std::shared_ptr& graph); diff --git a/py/torch_tensorrt/csrc/tensorrt_classes.cpp b/py/torch_tensorrt/csrc/tensorrt_classes.cpp index a89fe692bd..aaa8d93d1b 100644 --- a/py/torch_tensorrt/csrc/tensorrt_classes.cpp +++ b/py/torch_tensorrt/csrc/tensorrt_classes.cpp @@ -218,6 +218,7 @@ core::CompileSpec CompileSpec::toInternalCompileSpec() { info.partition_info.forced_fallback_operators = torch_fallback.forced_fallback_operators; info.partition_info.truncate_long_and_double = truncate_long_and_double; info.lower_info.forced_fallback_modules = torch_fallback.forced_fallback_modules; + info.lower_info.default_torch_execution = default_torch_execution; info.convert_info.engine_settings.truncate_long_and_double = truncate_long_and_double; info.convert_info.engine_settings.capability = toTRTEngineCapability(capability); diff --git a/py/torch_tensorrt/csrc/tensorrt_classes.h b/py/torch_tensorrt/csrc/tensorrt_classes.h index 0c80641005..72d40923ef 100644 --- a/py/torch_tensorrt/csrc/tensorrt_classes.h +++ b/py/torch_tensorrt/csrc/tensorrt_classes.h @@ -163,6 +163,7 @@ struct CompileSpec : torch::CustomClassHolder { bool refit = false; bool debug = false; bool truncate_long_and_double = false; + bool 
default_torch_execution = false; Device device; TorchFallback torch_fallback; EngineCapability capability = EngineCapability::kDEFAULT; diff --git a/py/torch_tensorrt/csrc/torch_tensorrt_py.cpp b/py/torch_tensorrt/csrc/torch_tensorrt_py.cpp index 6e5f333f78..d47a396f16 100644 --- a/py/torch_tensorrt/csrc/torch_tensorrt_py.cpp +++ b/py/torch_tensorrt/csrc/torch_tensorrt_py.cpp @@ -304,7 +304,8 @@ PYBIND11_MODULE(_C, m) { .def_readwrite("num_avg_timing_iters", &CompileSpec::num_avg_timing_iters) .def_readwrite("workspace_size", &CompileSpec::workspace_size) .def_readwrite("torch_fallback", &CompileSpec::torch_fallback) - .def_readwrite("truncate_long_and_double", &CompileSpec::truncate_long_and_double); + .def_readwrite("truncate_long_and_double", &CompileSpec::truncate_long_and_double) + .def_readwrite("default_torch_execution", &CompileSpec::default_torch_execution); py::class_(ts_sub_mod, "TorchFallback") .def(py::init<>()) diff --git a/py/torch_tensorrt/ts/_compile_spec.py b/py/torch_tensorrt/ts/_compile_spec.py index e406096677..567422a0ae 100644 --- a/py/torch_tensorrt/ts/_compile_spec.py +++ b/py/torch_tensorrt/ts/_compile_spec.py @@ -221,7 +221,10 @@ def _parse_compile_spec(compile_spec: Dict[str, Any]) -> _ts_C.CompileSpec: if "torch_fallback" in compile_spec: info.torch_fallback = _parse_torch_fallback(compile_spec["torch_fallback"]) - + + if "default_torch_execution" in compile_spec: + assert type(compile_spec["default_torch_execution"]) is bool + info.default_torch_execution = compile_spec["default_torch_execution"] return info diff --git a/py/torch_tensorrt/ts/_compiler.py b/py/torch_tensorrt/ts/_compiler.py index c0e88b99ce..498d04d2cb 100644 --- a/py/torch_tensorrt/ts/_compiler.py +++ b/py/torch_tensorrt/ts/_compiler.py @@ -26,7 +26,9 @@ def compile(module: torch.jit.ScriptModule, require_full_compilation=False, min_block_size=3, torch_executed_ops=[], - torch_executed_modules=[]) -> torch.jit.ScriptModule: + torch_executed_modules=[], + 
default_torch_execution=False, + trt_executed_modules=[]) -> torch.jit.ScriptModule: """Compile a TorchScript module for NVIDIA GPUs using TensorRT Takes a existing TorchScript module and a set of settings to configure the compiler @@ -74,6 +76,8 @@ def compile(module: torch.jit.ScriptModule, min_block_size (int): The minimum number of contiguous TensorRT convertable operations in order to run a set of operations in TensorRT torch_executed_ops (List[str]): List of aten operators that must be run in PyTorch. An error will be thrown if this list is not empty but ``require_full_compilation`` is True torch_executed_modules (List[str]): List of modules that must be run in PyTorch. An error will be thrown if this list is not empty but ``require_full_compilation`` is True + default_torch_execution (bool): If turned on, modules will be executed in torch by default, and those specified by trt_executed_modules will be compiled + trt_executed_modules (List[str]): List of modules that will be compiled to TensorRT. An error will be thrown if this list is not empty but ``default_torch_execution`` is False Returns: torch.jit.ScriptModule: Compiled TorchScript Module, when run it will execute via TensorRT @@ -87,6 +91,14 @@ def compile(module: torch.jit.ScriptModule, raise ValueError( "require_full_compilation is enabled however the list of modules and ops to run in torch is not empty. 
Found: torch_executed_ops: " + torch_executed_ops + ", torch_executed_modules: " + torch_executed_modules) + + if default_torch_execution: + if require_full_compilation: + raise ValueError("require_full_compilation is enabled however default_torch_execution mode is also switched on, which causes a conflict") + if len(torch_executed_modules) > 0: + raise ValueError("With default_torch_execution=True, it is unnecessary to specify torch_executed_modules") + if len(trt_executed_modules) == 0: + raise ValueError("With default_torch_execution=True, it is necessary to specify some trt_executed_modules otherwise nothing will be compiled") spec = { "inputs": inputs, @@ -105,9 +117,10 @@ def compile(module: torch.jit.ScriptModule, "torch_fallback": { "enabled": not require_full_compilation, "forced_fallback_ops": torch_executed_ops, - "forced_fallback_modules": torch_executed_modules if not default_torch_execution else trt_executed_modules, "min_block_size": min_block_size - } + }, + "default_torch_execution": default_torch_execution } compiled_cpp_mod = _C.compile_graph(module._c, _parse_compile_spec(spec))