77from torch_tensorrt .fx .tools .common_fx2trt import DispatchTestCase , InputTensorSpec
88
99class TestSelectConverter (DispatchTestCase ):
10- def test_select (self ):
10+ @parameterized .expand (
11+ [
12+ ("select_dim_index" , 2 , 1 ),
13+ ]
14+ )
15+ def test_select (self , dim_test , index_test ):
1116 class TestModule (torch .nn .Module ):
12- def forward (self , input , dim , index ):
13- return torch .select (input , dim , index )
17+ def __init__ (self , dim , index ):
18+ super ().__init__ ()
19+ self .dim = dim
20+ self .index = index
21+ def forward (self , input ):
22+ return torch .select (input , self .dim , self .index )
1423 input = [torch .randn (1 , 3 , 32 )]
15- dim = 2
16- index = 1
17- inputs = (input , dim , index )
1824 self .run_test (
19- TestModule (), input , expected_ops = {torch .ops .aten .select . Tensor }, test_explicit_precision = True ,
25+ TestModule (dim_test , index_test ), input , expected_ops = {torch .ops .aten .select }, test_explicit_precision = True ,
2026 )
2127
22- def test_select_with_dynamic_shape (self , x , y ):
28+ def test_select_with_dynamic_shape (self , dim_test , index_test ):
2329 class TestModule (torch .nn .Module ):
2430 def forward (self , input , dim , index ):
2531 return torch .select (input , dim , index )
@@ -31,9 +37,9 @@ def forward(self, input, dim, index):
3137 shape_ranges = [((1 , 3 , 3 ), (3 , 3 , 3 ), (32 , 32 , 32 ))],
3238 ),
3339 ]
34- dim = 2
35- index = 1
36- inputs_spec = (input_spec , dim , index )
3740 self .run_test_with_dynamic_shape (
38- TestModule (), inputs_spec , expected_ops = {torch .ops .aten .select .Tensor }
39- )
41+ TestModule (dim_test , index_test ), input_spec , expected_ops = {torch .ops .aten .select }
42+ )
43+
44+ if __name__ == "__main__" :
45+ run_tests ()
0 commit comments