66import torch_tensorrt as torchtrt
77from parameterized import parameterized
88from torch .testing ._internal .common_utils import TestCase , run_tests
9+ from torch_tensorrt .dynamo .utils import prepare_inputs
910
1011INPUT_SIZE = (64 , 100 )
1112
@@ -302,45 +303,62 @@ def __init__(self):
302303 self .layer2 = torch .nn .Linear (128 , 64 )
303304 self .relu = torch .nn .ReLU ()
304305
def forward(self, x, b=None, c=None, d=None, e=None):
    """Run the sample network, folding in whichever optional inputs are given.

    Args:
        x: Input passed to ``self.layer1``.
        b: Optional addend applied to the ``layer1`` output; skipped
            when ``None``.
        c: Optional multiplier applied after ``b``; skipped when ``None``.
        d: Optional mapping with the keys ``"value"`` and ``"value2"``.
            After the ReLU, ``d["value"]`` is subtracted and
            ``d["value2"]`` is added; skipped when ``None``.
        e: Optional iterable of addends applied in place to the
            ``layer2`` output. ``None`` (the default) means no addends.

    Returns:
        The ``layer2`` output after every addend in ``e`` has been applied.
    """
    out = self.layer1(x)
    # Guard ``b`` the same way as ``c`` and ``d``: its default is None, and
    # adding None to the activation would raise a TypeError.
    if b is not None:
        out = out + b
    if c is not None:
        out = out * c
    out = self.relu((out + 2.0) * 0.05)
    if d is not None:
        out = out - d["value"] + d["value2"]
    out = self.layer2(out)
    # A None default replaces the shared mutable ``[]`` default; an omitted
    # ``e`` still contributes nothing.
    for n in () if e is None else e:
        out += n
    return out
310318
311- inputs = torchtrt .Input (
312- min_shape = (1 , 100 ),
313- opt_shape = (64 , 100 ),
314- max_shape = (128 , 100 ),
315- dtype = torch .float ,
316- name = "x" ,
317- )
318319 model = SampleModel ().eval ().cuda ()
319320 input_list = []
320- input_list .append (torch .randn ((8 , 100 )).cuda ())
321- input_list .append (torch .randn ((12 , 100 )).cuda ())
322- input_list .append (torch .randn ((12 , 100 )).cuda ())
323- input_list .append (torch .randn ((8 , 100 )).cuda ())
324- input_list .append (torch .randn ((8 , 100 )).cuda ())
325-
326- dynamic_shapes = (
327- {
328- 0 : torch .export .Dim ("batch_size" , min = 1 , max = 128 ),
329- },
330- )
331- exp_program = torch .export .export (
332- model , (input_list [0 ],), dynamic_shapes = dynamic_shapes
333- )
334-
321+ for batch_size in [8 , 12 , 12 , 8 , 8 ]:
322+ args = [torch .rand ((batch_size , 100 )).to ("cuda" )]
323+ kwargs = {
324+ "b" : torch .rand ((1 , 128 )).to ("cuda" ),
325+ "d" : {
326+ "value" : torch .rand (1 ).to ("cuda" ),
327+ "value2" : torch .tensor (1.2 ).to ("cuda" ),
328+ },
329+ "e" : [torch .rand (1 ).to ("cuda" ), torch .rand (1 ).to ("cuda" )],
330+ }
331+ input_list .append ((args , kwargs ))
332+
333+ kwarg_torchtrt_input = prepare_inputs (input_list [0 ][1 ])
334+
335+ compile_spec = {
336+ "inputs" : [
337+ torchtrt .Input (
338+ min_shape = (1 , 100 ),
339+ opt_shape = (64 , 100 ),
340+ max_shape = (128 , 100 ),
341+ dtype = torch .float32 ,
342+ name = "x" ,
343+ ),
344+ ],
345+ "kwarg_inputs" : kwarg_torchtrt_input ,
346+ "device" : torchtrt .Device ("cuda:0" ),
347+ "enabled_precisions" : {torch .float },
348+ "pass_through_build_failures" : True ,
349+ "min_block_size" : 1 ,
350+ "ir" : "dynamo" ,
351+ "cache_built_engines" : False ,
352+ "reuse_cached_engines" : False ,
353+ "use_explicit_typing" : True ,
354+ "enable_weight_streaming" : True ,
355+ "torch_executed_ops" : {"torch.ops.aten.mul.Tensor" },
356+ "use_python_runtime" : use_python_runtime ,
357+ }
358+ exp_program = torchtrt .dynamo .trace (model , ** compile_spec )
335359 optimized_model = torchtrt .dynamo .compile (
336360 exp_program ,
337- inputs ,
338- min_block_size = 1 ,
339- pass_through_build_failures = True ,
340- use_explicit_typing = True ,
341- enable_weight_streaming = True ,
342- torch_executed_ops = {"torch.ops.aten.mul.Tensor" },
343- use_python_runtime = use_python_runtime ,
361+ ** compile_spec ,
344362 )
345363
346364 # List of tuples representing different configurations for three features:
@@ -361,12 +379,12 @@ def test_trt_model(enable_weight_streaming, optimized_model, input_list):
361379 for i in range (len (input_list )):
362380 if enable_weight_streaming and i == 4 :
363381 weight_streaming_ctx .device_budget = int (streamable_budget * 0.6 )
364- out_list .append (optimized_model (input_list [i ]))
382+ out_list .append (optimized_model (* input_list [i ][ 0 ], ** input_list [ i ][ 1 ]))
365383 return out_list
366384
367385 ref_out_list = []
368386 for i in range (len (input_list )):
369- ref_out_list .append (model (input_list [i ]))
387+ ref_out_list .append (model (* input_list [i ][ 0 ], ** input_list [ i ][ 1 ]))
370388
371389 pre_allocated_output_ctx = torchtrt .runtime .enable_pre_allocated_outputs (
372390 optimized_model
0 commit comments