44# LICENSE file in the root directory of this source tree.
55
66import torch
7+ from executorch .backends .arm .test import common
78from executorch .backends .arm .test .common import (
89 XfailIfNoCorstone300 ,
910 XfailIfNoCorstone320 ,
1415 TosaPipelineBI ,
1516 TosaPipelineMI ,
1617)
17- from parameterized import parameterized
1818
1919scalar_input_t = tuple [torch .Tensor , int ]
2020
@@ -23,11 +23,20 @@ class LshiftScalar(torch.nn.Module):
2323 torch_op_MI = "torch.ops.aten.__lshift__.Scalar"
2424 torch_op_BI = "torch.ops.aten.bitwise_left_shift.Tensor"
2525 exir_op = "executorch_exir_dialects_edge__ops_aten_bitwise_left_shift_Tensor"
26- test_data = [
27- ((torch .randint (- 8 , 8 , (1 , 12 , 3 , 4 ), dtype = torch .int8 ), 1 ),),
28- ((torch .randint (- 100 , 100 , (1 , 5 , 3 , 4 ), dtype = torch .int16 ), 5 ),),
29- ((torch .randint (- 100 , 100 , (1 , 5 , 3 , 4 ), dtype = torch .int32 ), 2 ),),
30- ]
26+ test_data = {
27+ "randint_neg_8_int8" : (
28+ torch .randint (- 8 , 8 , (1 , 12 , 3 , 4 ), dtype = torch .int8 ),
29+ 1 ,
30+ ),
31+ "randint_neg_100_int16" : (
32+ torch .randint (- 100 , 100 , (1 , 5 , 3 , 4 ), dtype = torch .int16 ),
33+ 5 ,
34+ ),
35+ "randint_neg_100_int32" : (
36+ torch .randint (- 100 , 100 , (1 , 5 , 3 , 4 ), dtype = torch .int32 ),
37+ 2 ,
38+ ),
39+ }
3140
3241 def forward (self , x : torch .Tensor , shift : int ):
3342 return x << shift
@@ -39,33 +48,27 @@ def forward(self, x: torch.Tensor, shift: int):
3948class LshiftTensor (torch .nn .Module ):
4049 torch_op = "torch.ops.aten.bitwise_left_shift.Tensor"
4150 exir_op = "executorch_exir_dialects_edge__ops_aten_bitwise_left_shift_Tensor"
42- test_data = [
43- (
44- (
45- torch .randint (- 8 , 8 , (3 , 3 ), dtype = torch .int8 ),
46- torch .randint (0 , 4 , (3 , 3 ), dtype = torch .int8 ),
47- ),
51+ test_data = {
52+ "randint_neg_8_tensor_int8" : (
53+ torch .randint (- 8 , 8 , (3 , 3 ), dtype = torch .int8 ),
54+ torch .randint (0 , 4 , (3 , 3 ), dtype = torch .int8 ),
4855 ),
49- (
50- (
51- torch .randint (- 1024 , 1024 , (3 , 3 , 3 ), dtype = torch .int16 ),
52- torch .randint (0 , 5 , (3 , 3 , 3 ), dtype = torch .int16 ),
53- ),
56+ "randint_neg_1024_tensor_int16" : (
57+ torch .randint (- 1024 , 1024 , (3 , 3 , 3 ), dtype = torch .int16 ),
58+ torch .randint (0 , 5 , (3 , 3 , 3 ), dtype = torch .int16 ),
5459 ),
55- (
56- (
57- torch .randint (0 , 127 , (1 , 2 , 3 , 3 ), dtype = torch .int32 ),
58- torch .randint (0 , 5 , (1 , 2 , 3 , 3 ), dtype = torch .int32 ),
59- ),
60+ "randint_0_tensor_int16" : (
61+ torch .randint (0 , 127 , (1 , 2 , 3 , 3 ), dtype = torch .int32 ),
62+ torch .randint (0 , 5 , (1 , 2 , 3 , 3 ), dtype = torch .int32 ),
6063 ),
61- ]
64+ }
6265
6366 def forward (self , x : torch .Tensor , shift : torch .Tensor ):
6467 return x .bitwise_left_shift (shift )
6568
6669
67- @parameterized . expand ( LshiftScalar .test_data )
68- def test_lshift_scalar_tosa_MI (test_data ):
70+ @common . parametrize ( "test_data" , LshiftScalar .test_data )
71+ def test_lshift_scalar_tosa_MI_scalar (test_data ):
6972 TosaPipelineMI [scalar_input_t ](
7073 LshiftScalar (),
7174 test_data ,
@@ -74,18 +77,21 @@ def test_lshift_scalar_tosa_MI(test_data):
7477 ).run ()
7578
7679
77- @parameterized . expand ( LshiftScalar .test_data )
78- def test_lshift_scalar_tosa_BI (test_data ):
80+ @common . parametrize ( "test_data" , LshiftScalar .test_data )
81+ def test_bitwise_left_shift_tensor_tosa_BI_scalar (test_data ):
7982 pipeline = TosaPipelineBI [scalar_input_t ](
80- LshiftScalar (), test_data , LshiftScalar .torch_op_BI , LshiftScalar .exir_op
83+ LshiftScalar (),
84+ test_data ,
85+ LshiftScalar .torch_op_BI ,
86+ LshiftScalar .exir_op ,
8187 )
8288 pipeline .pop_stage ("check.quant_nodes" )
8389 pipeline .run ()
8490
8591
86- @parameterized . expand ( LshiftScalar .test_data )
92+ @common . parametrize ( "test_data" , LshiftScalar .test_data )
8793@XfailIfNoCorstone300
88- def test_lshift_scalar_tosa_u55 (test_data ):
94+ def test_bitwise_left_shift_tensor_u55_BI_scalar (test_data ):
8995 pipeline = EthosU55PipelineBI [scalar_input_t ](
9096 LshiftScalar (),
9197 test_data ,
@@ -97,9 +103,9 @@ def test_lshift_scalar_tosa_u55(test_data):
97103 pipeline .run ()
98104
99105
100- @parameterized . expand ( LshiftScalar .test_data )
106+ @common . parametrize ( "test_data" , LshiftScalar .test_data )
101107@XfailIfNoCorstone320
102- def test_lshift_scalar_tosa_u85 (test_data ):
108+ def test_bitwise_left_shift_tensor_u85_BI_scalar (test_data ):
103109 pipeline = EthosU85PipelineBI [scalar_input_t ](
104110 LshiftScalar (),
105111 test_data ,
@@ -111,8 +117,8 @@ def test_lshift_scalar_tosa_u85(test_data):
111117 pipeline .run ()
112118
113119
114- @parameterized . expand ( LshiftTensor .test_data )
115- def test_lshift_tensor_tosa_MI (test_data ):
120+ @common . parametrize ( "test_data" , LshiftTensor .test_data )
121+ def test_lshift_scalar_tosa_MI (test_data ):
116122 TosaPipelineMI [scalar_input_t ](
117123 LshiftTensor (),
118124 test_data ,
@@ -121,18 +127,21 @@ def test_lshift_tensor_tosa_MI(test_data):
121127 ).run ()
122128
123129
124- @parameterized . expand ( LshiftTensor .test_data )
125- def test_lshift_tensor_tosa_BI (test_data ):
130+ @common . parametrize ( "test_data" , LshiftTensor .test_data )
131+ def test_bitwise_left_shift_tensor_tosa_BI (test_data ):
126132 pipeline = TosaPipelineBI [scalar_input_t ](
127- LshiftTensor (), test_data , LshiftTensor .torch_op , LshiftTensor .exir_op
133+ LshiftTensor (),
134+ test_data ,
135+ LshiftTensor .torch_op ,
136+ LshiftTensor .exir_op ,
128137 )
129138 pipeline .pop_stage ("check.quant_nodes" )
130139 pipeline .run ()
131140
132141
133- @parameterized . expand ( LshiftTensor .test_data )
142+ @common . parametrize ( "test_data" , LshiftTensor .test_data )
134143@XfailIfNoCorstone300
135- def test_lshift_tensor_tosa_u55 (test_data ):
144+ def test_bitwise_left_shift_tensor_u55_BI (test_data ):
136145 pipeline = EthosU55PipelineBI [scalar_input_t ](
137146 LshiftTensor (),
138147 test_data ,
@@ -144,9 +153,9 @@ def test_lshift_tensor_tosa_u55(test_data):
144153 pipeline .run ()
145154
146155
147- @parameterized . expand ( LshiftTensor .test_data )
156+ @common . parametrize ( "test_data" , LshiftTensor .test_data )
148157@XfailIfNoCorstone320
149- def test_lshift_tensor_tosa_u85 (test_data ):
158+ def test_bitwise_left_shift_tensor_u85_BI (test_data ):
150159 pipeline = EthosU85PipelineBI [scalar_input_t ](
151160 LshiftTensor (),
152161 test_data ,
0 commit comments