Skip to content

Commit eecd1bd

Browse files
committed
models and data for unflatten and attention tests
1 parent 42f7ec2 commit eecd1bd

File tree

7 files changed

+41
-0
lines changed

7 files changed

+41
-0
lines changed
368 Bytes
Binary file not shown.
1.88 KB
Binary file not shown.
368 Bytes
Binary file not shown.
1.88 KB
Binary file not shown.

testdata/dnn/onnx/generate_onnx_models.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1518,6 +1518,47 @@ def forward(self, x):
15181518
save_data_and_model("einsum_transpose", mat, einsum, export_params=True)
15191519

15201520

1521+
class TorchAttentionLayer(nn.Module):
1522+
def __init__(self, embed_dim=6, num_heads=1):
1523+
super(TorchAttentionLayer, self).__init__()
1524+
self.attention = nn.MultiheadAttention(
1525+
embed_dim=embed_dim,
1526+
num_heads=num_heads,
1527+
bias=True,
1528+
batch_first=True)
1529+
def forward(self, x):
1530+
return self.attention(x, x, x)[0]
1531+
1532+
num_heads = 1
1533+
batch_size = 2
1534+
num_tokens = 5
1535+
emb_dim = 6
1536+
model = TorchAttentionLayer(embed_dim=emb_dim, num_heads=num_heads).eval()
1537+
1538+
x = torch.rand(batch_size, num_tokens, emb_dim)
1539+
with torch.no_grad():
1540+
output = model(x)
1541+
1542+
save_data_and_model("torch_attention_single_head", x, model, export_params=True)
1543+
class Unflatten(torch.nn.Module):
1544+
def __init__(self, E, times):
1545+
super(Unflatten, self).__init__()
1546+
self.E = E
1547+
self.times = times
1548+
1549+
def forward(self, x):
1550+
return x.unflatten(-1, (self.times, self.E))
1551+
1552+
unflatten_dim = 5
1553+
times = 3
1554+
model = Unflatten(unflatten_dim, times).eval()
1555+
1556+
x = torch.rand(10, 3, unflatten_dim * times)
1557+
with torch.no_grad():
1558+
output = model(x)
1559+
1560+
save_data_and_model("unflatten", x, model, export_params=True)
1561+
15211562
def _extract_value_info(x, name, type_proto=None): # type: (Union[List[Any], np.ndarray, None], Text, Optional[TypeProto]) -> onnx.ValueInfoProto
15221563
if type_proto is None:
15231564
if x is None:
6.79 KB
Binary file not shown.
1.36 KB
Binary file not shown.

0 commit comments

Comments
 (0)