Skip to content

Commit 6be0e3d

Browse files
committed
add data for unflatten and pytorch attention test
1 parent 42f7ec2 commit 6be0e3d

10 files changed

+41
-0
lines changed
368 Bytes
Binary file not shown.
368 Bytes
Binary file not shown.
1.88 KB
Binary file not shown.
368 Bytes
Binary file not shown.
368 Bytes
Binary file not shown.
1.88 KB
Binary file not shown.

testdata/dnn/onnx/generate_onnx_models.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1518,6 +1518,47 @@ def forward(self, x):
15181518
save_data_and_model("einsum_transpose", mat, einsum, export_params=True)
15191519

15201520

1521+
class TorchAttentionLayer(nn.Module):
1522+
def __init__(self, embed_dim=6, num_heads=1):
1523+
super(TorchAttentionLayer, self).__init__()
1524+
self.attention = nn.MultiheadAttention(
1525+
embed_dim=embed_dim,
1526+
num_heads=num_heads,
1527+
bias=True,
1528+
batch_first=True)
1529+
def forward(self, x):
1530+
return self.attention(x, x, x)[0]
1531+
1532+
num_heads = 1
1533+
batch_size = 2
1534+
num_tokens = 5
1535+
emb_dim = 6
1536+
model = TorchAttentionLayer(embed_dim=emb_dim, num_heads=num_heads).eval()
1537+
1538+
x = torch.rand(batch_size, num_tokens, emb_dim)
1539+
with torch.no_grad():
1540+
output = model(x)
1541+
1542+
save_data_and_model("pytorch_attention_single_head", x, model, export_params=True)
1543+
class Unflatten(torch.nn.Module):
1544+
def __init__(self, E, times):
1545+
super(Unflatten, self).__init__()
1546+
self.E = E
1547+
self.times = times
1548+
1549+
def forward(self, x):
1550+
return x.unflatten(-1, (self.times, self.E))
1551+
1552+
unflatten_dim = 5
1553+
times = 3
1554+
model = Unflatten(unflatten_dim, times).eval()
1555+
1556+
x = torch.rand(10, 3, unflatten_dim * times)
1557+
with torch.no_grad():
1558+
output = model(x)
1559+
1560+
save_data_and_model("unflatten", x, model, export_params=True)
1561+
15211562
def _extract_value_info(x, name, type_proto=None): # type: (Union[List[Any], np.ndarray, None], Text, Optional[TypeProto]) -> onnx.ValueInfoProto
15221563
if type_proto is None:
15231564
if x is None:
6.79 KB
Binary file not shown.
6.79 KB
Binary file not shown.
1.36 KB
Binary file not shown.

0 commit comments

Comments
 (0)