 from typing import Any, Optional, Sequence
 from torch_tensorrt import EngineCapability, Device
-from torch_tensorrt.fx.utils import LowerPrecision
 from torch.fx.passes.pass_manager import PassManager
 from torch.fx.passes.shape_prop import ShapeProp
 from torch_tensorrt.dynamo.aten_tracer import trace
@@ -78,29 +77,29 @@ def compile(
     if not isinstance(inputs, collections.abc.Sequence):
         inputs = [inputs]

-    inputs = prepare_inputs(inputs, prepare_device(device))
+    torchtrt_inputs, torch_inputs = prepare_inputs(inputs, prepare_device(device))

     if (
         torch.float16 in enabled_precisions
         or torch_tensorrt.dtype.half in enabled_precisions
     ):
-        lower_precision = LowerPrecision.FP16
+        precision = torch.float16
     elif (
         torch.float32 in enabled_precisions
         or torch_tensorrt.dtype.float in enabled_precisions
     ):
-        lower_precision = LowerPrecision.FP32
+        precision = torch.float32
     elif len(enabled_precisions) == 0:
         logger.info(f"No precision specified, defaulting to {PRECISION}")
-        lower_precision = PRECISION
+        precision = PRECISION
     else:
         raise ValueError(
             f"Precision {enabled_precisions} not supported in the Dynamo Path"
         )

     if kwargs.get("ir", "dynamo") == "torch_compile":
         custom_backend = create_backend(
-            precision=lower_precision,
+            precision=precision,
             debug=debug,
             workspace_size=workspace_size,
             min_block_size=min_block_size,
@@ -114,13 +113,13 @@ def compile(
         )
         model = torch.compile(gm, backend=custom_backend)
         # Ensure compilation occurs by calling the function with provided inputs
-        model(*inputs)
+        model(*torch_inputs)
         return model

     else:
         settings = CompilationSettings(
             debug=debug,
-            precision=lower_precision,
+            precision=precision,
             workspace_size=workspace_size,
             min_block_size=min_block_size,
             torch_executed_ops=torch_executed_ops,
@@ -131,20 +130,20 @@ def compile(
             use_python_runtime=use_python_runtime,
         )

-        model = trace(gm, inputs, **kwargs)
+        model = trace(gm, torch_inputs, **kwargs)

         if kwargs.get("use_capability_partitioner", None):
-            model = lower_model(model, inputs)
-            return _compile_module(model, inputs, settings)
+            model = lower_model(model, torch_inputs)
+            return _compile_module(model, torch_inputs, settings)
         else:
-            split_result = lower_model_using_trt_splitter(model, inputs)
-            trt_module = _compile_graph(split_result, inputs, settings)
+            split_result = lower_model_using_trt_splitter(model, torch_inputs)
+            trt_module = _compile_graph(split_result, torch_inputs, settings)

             return trt_module


 def create_backend(
-    precision: LowerPrecision = PRECISION,
+    precision: torch.dtype = PRECISION,
     debug: bool = DEBUG,
     workspace_size: int = WORKSPACE_SIZE,
     min_block_size: int = MIN_BLOCK_SIZE,
@@ -234,7 +233,7 @@ def lower_model(model: torch.nn.Module, inputs: Any, **kwargs):
         [fuse_permute_matmul, fuse_permute_linear]
     )
     lowered_model = graph_optimization_pm(model)
-    if isinstance(lowered_model, torch.fx.GraphModule):
-        ShapeProp(lowered_model).propagate(*inputs)
+    # if isinstance(lowered_model, torch.fx.GraphModule):
+    #     ShapeProp(lowered_model).propagate(*inputs)

     return lowered_model
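
For context, a caller-side sketch of how the updated path would be exercised. This is a minimal example under stated assumptions: it assumes this compile() is importable as torch_tensorrt.dynamo.compile, that a CUDA device and TensorRT are available, and that the default device and settings are acceptable; the module and shapes below are made up for illustration.

import torch
import torch_tensorrt


class MatMul(torch.nn.Module):
    def forward(self, x, y):
        return torch.matmul(x, y)


model = MatMul().eval().cuda()
inputs = [torch.randn(8, 16).cuda(), torch.randn(16, 4).cuda()]

# With this change, enabled_precisions={torch.half} resolves to
# precision=torch.float16 internally, and prepare_inputs() hands the raw
# torch Tensors (torch_inputs) to trace()/compilation rather than the
# torch_tensorrt Input wrappers.
trt_model = torch_tensorrt.dynamo.compile(
    model,
    inputs=inputs,
    enabled_precisions={torch.half},
)
out = trt_model(*inputs)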