@@ -817,72 +817,33 @@ class sliceScatter(torch.nn.Module):
             def __init__(self, *args, **kwargs) -> None:
                 super().__init__(*args, **kwargs)

-            def forward(self, x, src, dim, start=None, end=None, step=1):
-                y = torch.ops.aten.slice_scatter(x, src, dim, start, end, step)
+            def forward(self, x, src):
+                y = torch.ops.aten.slice_scatter(x, src, 1, 6, None, 1)
                 return y

-        # Operations expected to be removed in the traced graph after decompositions
-        expected_ops = {
-            torch.ops.aten.scatter.src,
-        }
-        unexpected_ops = {torch.ops.aten.select_scatter}
-
-        a = torch.zeros(8, 8).cuda()
-        b = torch.ones(8, 2).cuda()
-
-        # 0-D tensors for dynamic scalar values
-        start = torch.tensor(1, dtype=torch.int64).cuda()
-        end = torch.tensor(6, dtype=torch.int64).cuda()
-        step = torch.tensor(1, dtype=torch.int64).cuda()
-
-        # Mark scalar tensors as dynamic (note: shape = ())
-        torch._dynamo.mark_dynamic(start, (), min=1, max=3)
-        torch._dynamo.mark_dynamic(end, (), min=4, max=6)
-
-        inputs = (a, b, start, end, None, step)
         fx_graph = torch.fx.symbolic_trace(sliceScatter())
-        unexpected_ops_seen, expected_ops_unseen = lower_graph_testing(
-            fx_graph,
-            inputs,
-            expected_ops=expected_ops,
-            unexpected_ops=unexpected_ops,
-            min_block_size=1,
-        )

-        self.assertEqual(
-            len(unexpected_ops_seen),
-            0,
-            f"The following unexpected ops were encountered: {unexpected_ops_seen}",
-        )
-
-        self.assertEqual(
-            len(expected_ops_unseen),
-            0,
-            f"The following expected ops were not encountered: {expected_ops_unseen}",
+        dim1 = torch.export.Dim("dim1", min=8, max=10)
+        dynamic_shapes = {
+            "x": [torch.export.Dim.STATIC, dim1],
+            "src": [torch.export.Dim.STATIC, None],
+        }
+        inputs = (torch.zeros(8, 8).cuda(), torch.ones(8, 2).cuda())
+        exported_program = torch.export.export(
+            sliceScatter(), tuple(inputs), dynamic_shapes=dynamic_shapes
         )
-
+        fx_graph = exported_program.module()
+        inputs = [
+            torch_tensorrt.Input(
+                min_shape=[8, 8], opt_shape=[8, 10], max_shape=[8, 10]
+            ),
+            torch_tensorrt.Input(min_shape=[8, 2], opt_shape=[8, 2], max_shape=[8, 2]),
+        ]
         torch._dynamo.reset()
-
-        # Validate that the results between Torch and Torch-TRT are similar
-        optimized_model = torch_tensorrt.compile(
-            fx_graph,
-            "torch_compile",
-            inputs,
-            min_block_size=1,
-            truncate_double=True,
-            pass_through_build_failures=True,
-        )
-        optimized_model_results = optimized_model(*inputs).detach().cpu()
-        torch_model_results = fx_graph(*inputs).detach().cpu()
-
-        max_diff = float(
-            torch.max(torch.abs(optimized_model_results - torch_model_results))
-        )
-        self.assertAlmostEqual(
-            max_diff,
-            0,
-            DECIMALS_OF_AGREEMENT,
-            f"Slice_scatter TRT outputs don't match with the original model.",
+        trt_model = torch_tensorrt.dynamo.compile(exported_program, inputs)
+        inputs = (torch.zeros(8, 8).cuda(), torch.ones(8, 2).cuda())
+        torch.testing.assert_close(
+            trt_model(*inputs), fx_graph(*inputs), rtol=RTOL, atol=ATOL
        )

     def test_lowering_select_scatter_dimZero_module(self):
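The replacement test drives dynamic shapes through torch.export instead of 0-D scalar tensors. For readability, here is the added code pieced together into one self-contained script: a minimal sketch, assuming a torch build that provides torch.export.Dim/Dim.STATIC and a torch_tensorrt build with the dynamo frontend; the rtol/atol values below are illustrative stand-ins for the RTOL/ATOL constants the test module defines elsewhere.

import torch
import torch_tensorrt


class SliceScatter(torch.nn.Module):
    def forward(self, x, src):
        # Write src into x[:, 6:] (dim=1, start=6, end=None, step=1).
        return torch.ops.aten.slice_scatter(x, src, 1, 6, None, 1)


x = torch.zeros(8, 8).cuda()
src = torch.ones(8, 2).cuda()

# Declare x's second dimension symbolic over [8, 10]; all other dims stay static.
dim1 = torch.export.Dim("dim1", min=8, max=10)
dynamic_shapes = {
    "x": [torch.export.Dim.STATIC, dim1],
    "src": [torch.export.Dim.STATIC, None],
}
exported_program = torch.export.export(
    SliceScatter(), (x, src), dynamic_shapes=dynamic_shapes
)

# Mirror the same range on the TensorRT side as a min/opt/max optimization profile.
trt_inputs = [
    torch_tensorrt.Input(min_shape=[8, 8], opt_shape=[8, 10], max_shape=[8, 10]),
    torch_tensorrt.Input(min_shape=[8, 2], opt_shape=[8, 2], max_shape=[8, 2]),
]
trt_model = torch_tensorrt.dynamo.compile(exported_program, trt_inputs)

# Compare the TRT engine against the exported module run eagerly.
torch.testing.assert_close(
    trt_model(x, src), exported_program.module()(x, src), rtol=1e-3, atol=1e-3
)

The design shift is worth noting: the removed version passed start/end/step as 0-D int64 tensors marked dynamic with torch._dynamo.mark_dynamic, while the new version keeps the slice parameters constant and instead makes an input dimension symbolic, a range that torch_tensorrt.dynamo.compile can map directly onto a TensorRT shape profile via torch_tensorrt.Input.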