@@ -124,3 +124,60 @@ def two_resizes_with_shared_subgraphs(x: ost.FLOAT["batch", 1, "height", "width"
     return opset11.Add(resized_y, resized_z)
 
 make_model_and_data(two_resizes_with_shared_subgraphs, np.random.rand(1, 1, 4, 5).astype(np.float32), np.random.rand(1, 1, 3, 2).astype(np.float32), np.random.rand(1, 1, 2, 1).astype(np.float32))
+
+batch_size = 1
+sequence_length = 320
+input_hidden_size = 48
+qk_hidden_size = 48
+v_hidden_size = 48
+num_heads = 4
+qk_head_size = qk_hidden_size // num_heads
+v_head_size = v_hidden_size // num_heads
+attention_weight = np.random.rand(input_hidden_size, qk_hidden_size + qk_hidden_size + v_hidden_size).astype(np.float32)
+attention_bias = np.random.rand(qk_hidden_size + qk_hidden_size + v_hidden_size).astype(np.float32)
+
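+# Decomposed multi-head attention subgraph: one fused QKV projection, per-head
+# slice/reshape/transpose, scaled dot-product attention with softmax, and a final
+# merge of the heads back into [batch_size, sequence_length, v_hidden_size].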
+@ost.script()
+def attention(x: ost.FLOAT[batch_size, sequence_length, input_hidden_size]) -> ost.FLOAT[batch_size, sequence_length, input_hidden_size]:
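+    # [batch, seq, hidden] -> [seq, batch, hidden], then a single MatMul + Add
+    # produces Q, K and V concatenated along the last axis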
+    transpose = op.Transpose(x, perm=[1, 0, 2])
+    qkv_matmul_weight = op.Constant(value=onnx.helper.make_tensor("", onnx.TensorProto.FLOAT, attention_weight.shape, attention_weight))
+    qkv_matmul = op.MatMul(transpose, qkv_matmul_weight)
+
+    qkv_add_bias = op.Constant(value=onnx.helper.make_tensor("", onnx.TensorProto.FLOAT, attention_bias.shape, attention_bias))
+    qkv_add = op.Add(qkv_add_bias, qkv_matmul)
+
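+    # each path slices its columns from qkv_add and reshapes to
+    # [sequence_length, batch_size * num_heads, head_size];
+    # q is additionally scaled by 1 / sqrt(qk_hidden_size)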
+    # q path
+    q_path_slice = op.Slice(qkv_add,
+                            op.Constant(value=onnx.helper.make_tensor("", onnx.TensorProto.INT64, [1], np.array([0], dtype=np.int64))),
+                            op.Constant(value=onnx.helper.make_tensor("", onnx.TensorProto.INT64, [1], np.array([qk_hidden_size], dtype=np.int64))),
+                            op.Constant(value=onnx.helper.make_tensor("", onnx.TensorProto.INT64, [1], np.array([-1], dtype=np.int64))))
+    q_path_reshape = op.Reshape(q_path_slice, op.Constant(value=onnx.helper.make_tensor("", onnx.TensorProto.INT64, [3], np.array([sequence_length, batch_size * num_heads, qk_head_size], dtype=np.int64))), allowzero=0)
+    q_path_transpose = op.Transpose(q_path_reshape, perm=[1, 0, 2])
+    q_path_div = op.Div(q_path_transpose, op.Constant(value=onnx.helper.make_tensor("", onnx.TensorProto.FLOAT, [], np.array([np.sqrt(qk_hidden_size)], dtype=np.float32))))
+    # k path
+    k_path_slice = op.Slice(qkv_add,
+                            op.Constant(value=onnx.helper.make_tensor("", onnx.TensorProto.INT64, [1], np.array([qk_hidden_size], dtype=np.int64))),
+                            op.Constant(value=onnx.helper.make_tensor("", onnx.TensorProto.INT64, [1], np.array([qk_hidden_size + qk_hidden_size], dtype=np.int64))),
+                            op.Constant(value=onnx.helper.make_tensor("", onnx.TensorProto.INT64, [1], np.array([-1], dtype=np.int64))))
+    k_path_reshape = op.Reshape(k_path_slice, op.Constant(value=onnx.helper.make_tensor("", onnx.TensorProto.INT64, [3], np.array([sequence_length, batch_size * num_heads, qk_head_size], dtype=np.int64))), allowzero=0)
+    k_path_transpose = op.Transpose(k_path_reshape, perm=[1, 2, 0])
+
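+    # the q·k scores have shape [batch_size * num_heads, sequence_length, sequence_length]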
+    # qk path
+    qk_matmul = op.MatMul(q_path_div, k_path_transpose)
+    qk_softmax = op.Softmax(qk_matmul)
+
+    # v path
+    v_path_slice = op.Slice(qkv_add,
+                            op.Constant(value=onnx.helper.make_tensor("", onnx.TensorProto.INT64, [1], np.array([qk_hidden_size + qk_hidden_size], dtype=np.int64))),
+                            op.Constant(value=onnx.helper.make_tensor("", onnx.TensorProto.INT64, [1], np.array([qk_hidden_size + qk_hidden_size + v_hidden_size], dtype=np.int64))),
+                            op.Constant(value=onnx.helper.make_tensor("", onnx.TensorProto.INT64, [1], np.array([-1], dtype=np.int64))))
+    v_path_reshape = op.Reshape(v_path_slice, op.Constant(value=onnx.helper.make_tensor("", onnx.TensorProto.INT64, [3], np.array([sequence_length, batch_size * num_heads, v_head_size], dtype=np.int64))), allowzero=0)
+    v_path_transpose = op.Transpose(v_path_reshape, perm=[1, 0, 2])
+
+    # apply the attention weights to v, transpose back and merge the heads
+    matmul = op.MatMul(qk_softmax, v_path_transpose)
+    trans = op.Transpose(matmul, perm=[1, 0, 2])
+    reshape = op.Reshape(trans, op.Constant(value=onnx.helper.make_tensor("", onnx.TensorProto.INT64, [3], np.array([batch_size, sequence_length, v_hidden_size], dtype=np.int64))))
+
+    return reshape
+
+make_model_and_data(attention, np.random.rand(batch_size, sequence_length, input_hidden_size).astype(np.float32))