Skip to content

Commit ff2ab1d

Browse files
committed
add code to generate model and data
1 parent 67b5625 commit ff2ab1d

File tree

4 files changed

+57
-0
lines changed

4 files changed

+57
-0
lines changed
0 Bytes
Binary file not shown.
0 Bytes
Binary file not shown.

testdata/dnn/onnx/generate_onnx_models_with_onnxscript.py

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -124,3 +124,60 @@ def two_resizes_with_shared_subgraphs(x: ost.FLOAT["batch", 1, "height", "width"
124124
return opset11.Add(resized_y, resized_z)
125125

126126
# Generate the shared-subgraph resize model: one image-like input plus two
# target-shape tensors (random values consumed in the same order as before).
resize_input = np.random.rand(1, 1, 4, 5).astype(np.float32)
resize_target_a = np.random.rand(1, 1, 3, 2).astype(np.float32)
resize_target_b = np.random.rand(1, 1, 2, 1).astype(np.float32)
make_model_and_data(two_resizes_with_shared_subgraphs, resize_input, resize_target_a, resize_target_b)
127+
128+
# --- Hyperparameters for the attention test model -------------------------
batch_size = 1
sequence_length = 320
input_hidden_size = 48   # channels of the model input
qk_hidden_size = 48      # total query/key projection width (all heads)
v_hidden_size = 48       # total value projection width (all heads)
num_heads = 4
# Per-head sizes. Use integer floor division instead of int(x / n): it stays
# in integer arithmetic (no float round-trip) and is the idiomatic form.
qk_head_size = qk_hidden_size // num_heads
v_head_size = v_hidden_size // num_heads
# Fused QKV projection parameters: Q, K and V weights are packed along the
# last axis, so the weight is (input, q+k+v) and the bias is (q+k+v,).
attention_weight = np.random.rand(input_hidden_size, qk_hidden_size + qk_hidden_size + v_hidden_size).astype(np.float32)
attention_bias = np.random.rand(qk_hidden_size + qk_hidden_size + v_hidden_size).astype(np.float32)
139+
@ost.script()
def attention(x: ost.FLOAT[batch_size, sequence_length, input_hidden_size]) -> ost.FLOAT[batch_size, sequence_length, input_hidden_size]:
    """Decomposed scaled-dot-product self-attention with a fused QKV projection.

    Builds the unfused attention graph (QKV MatMul+Add, per-head slice/reshape,
    Q*K^T -> Softmax -> *V, reshape back) so the DNN importer's attention
    fusion can be exercised against it.
    """
    # (batch, seq, hidden) -> sequence-first (seq, batch, hidden).
    transpose = op.Transpose(x, perm=[1, 0, 2])
    # One fused projection: weight (input, q+k+v) yields Q, K, V packed along the last axis.
    qkv_matmul_weight = op.Constant(value=onnx.helper.make_tensor("", onnx.TensorProto.FLOAT, attention_weight.shape, attention_weight))
    qkv_matmul = op.MatMul(transpose, qkv_matmul_weight)

    qkv_add_bias = op.Constant(value=onnx.helper.make_tensor("", onnx.TensorProto.FLOAT, attention_bias.shape, attention_bias))
    qkv_add = op.Add(qkv_add_bias, qkv_matmul)

    # q path
    # Slice channels [0, qk_hidden_size) of the packed tensor along the last axis (-1).
    q_path_slice = op.Slice(qkv_add,
                            op.Constant(value=onnx.helper.make_tensor("", onnx.TensorProto.INT64, [1], np.array([0], dtype=np.int64))),
                            op.Constant(value=onnx.helper.make_tensor("", onnx.TensorProto.INT64, [1], np.array([qk_hidden_size], dtype=np.int64))),
                            op.Constant(value=onnx.helper.make_tensor("", onnx.TensorProto.INT64, [1], np.array([-1], dtype=np.int64))))
    # Split heads: (seq, batch, qk_hidden) -> (seq, batch*heads, head), then heads-first for batched MatMul.
    q_path_reshape = op.Reshape(q_path_slice, op.Constant(value=onnx.helper.make_tensor("", onnx.TensorProto.INT64, [3], np.array([sequence_length, batch_size * num_heads, qk_head_size], dtype=np.int64))), allowzero=0)
    q_path_transpose = op.Transpose(q_path_reshape, perm=[1, 0, 2])
    # NOTE(review): scales by sqrt(qk_hidden_size) (=sqrt(48)), not the usual
    # sqrt(qk_head_size) — confirm this is the pattern the fusion expects.
    q_path_div = op.Div(q_path_transpose, op.Constant(value=onnx.helper.make_tensor("", onnx.TensorProto.FLOAT, [], np.array([np.sqrt(qk_hidden_size)], dtype=np.float32))))
    # k path
    # Slice channels [qk_hidden_size, 2*qk_hidden_size) for K.
    k_path_slice = op.Slice(qkv_add,
                            op.Constant(value=onnx.helper.make_tensor("", onnx.TensorProto.INT64, [1], np.array([qk_hidden_size], dtype=np.int64))),
                            op.Constant(value=onnx.helper.make_tensor("", onnx.TensorProto.INT64, [1], np.array([qk_hidden_size + qk_hidden_size], dtype=np.int64))),
                            op.Constant(value=onnx.helper.make_tensor("", onnx.TensorProto.INT64, [1], np.array([-1], dtype=np.int64))))
    k_path_reshape = op.Reshape(k_path_slice, op.Constant(value=onnx.helper.make_tensor("", onnx.TensorProto.INT64, [3], np.array([sequence_length, batch_size * num_heads, qk_head_size], dtype=np.int64))), allowzero=0)
    # perm=[1, 2, 0] yields K already transposed, so MatMul below computes Q*K^T.
    k_path_transpose = op.Transpose(k_path_reshape, perm=[1, 2, 0])

    # qk path
    qk_matmul = op.MatMul(q_path_div, k_path_transpose)
    # NOTE(review): no explicit axis — relies on the opset's Softmax default
    # (last axis in opset >= 13); confirm the opset bound to `op`.
    qk_softmax = op.Softmax(qk_matmul)

    # v path
    # Slice the remaining v_hidden_size channels for V.
    v_path_slice = op.Slice(qkv_add,
                            op.Constant(value=onnx.helper.make_tensor("", onnx.TensorProto.INT64, [1], np.array([qk_hidden_size + qk_hidden_size], dtype=np.int64))),
                            op.Constant(value=onnx.helper.make_tensor("", onnx.TensorProto.INT64, [1], np.array([qk_hidden_size + qk_hidden_size + v_hidden_size], dtype=np.int64))),
                            op.Constant(value=onnx.helper.make_tensor("", onnx.TensorProto.INT64, [1], np.array([-1], dtype=np.int64))))
    v_path_reshape = op.Reshape(v_path_slice, op.Constant(value=onnx.helper.make_tensor("", onnx.TensorProto.INT64, [3], np.array([sequence_length, batch_size * num_heads, v_head_size], dtype=np.int64))), allowzero=0)
    v_path_transpose = op.Transpose(v_path_reshape, perm=[1, 0, 2])

    # matmul
    # Attention-weighted values, then merge heads back to (batch, seq, v_hidden).
    matmul = op.MatMul(qk_softmax, v_path_transpose)
    trans = op.Transpose(matmul, perm=[1, 0, 2])
    reshape = op.Reshape(trans, op.Constant(value=onnx.helper.make_tensor("", onnx.TensorProto.INT64, [3], np.array([batch_size, sequence_length, v_hidden_size], dtype=np.int64))))

    return reshape
183+
# Emit the attention model together with a random input of the declared shape.
attention_input = np.random.rand(batch_size, sequence_length, input_hidden_size).astype(np.float32)
make_model_and_data(attention, attention_input)
-39 Bytes
Binary file not shown.

0 commit comments

Comments
 (0)