@@ -1518,6 +1518,47 @@ def forward(self, x):
 save_data_and_model("einsum_transpose", mat, einsum, export_params=True)
 
 
+class TorchAttentionLayer(nn.Module):
+    def __init__(self, embed_dim=6, num_heads=1):
+        super(TorchAttentionLayer, self).__init__()
+        self.attention = nn.MultiheadAttention(
+            embed_dim=embed_dim,
+            num_heads=num_heads,
+            bias=True,
+            batch_first=True)
+    def forward(self, x):
+        return self.attention(x, x, x)[0]
+
+num_heads = 1
+batch_size = 2
+num_tokens = 5
+emb_dim = 6
+model = TorchAttentionLayer(embed_dim=emb_dim, num_heads=num_heads).eval()
+
+x = torch.rand(batch_size, num_tokens, emb_dim)
+with torch.no_grad():
+    output = model(x)
+
+save_data_and_model("torch_attention_single_head", x, model, export_params=True)
+class Unflatten(torch.nn.Module):
+    def __init__(self, E, times):
+        super(Unflatten, self).__init__()
+        self.E = E
+        self.times = times
+
+    def forward(self, x):
+        return x.unflatten(-1, (self.times, self.E))
+
+unflatten_dim = 5
+times = 3
+model = Unflatten(unflatten_dim, times).eval()
+
+x = torch.rand(10, 3, unflatten_dim * times)
+with torch.no_grad():
+    output = model(x)
+
+save_data_and_model("unflatten", x, model, export_params=True)
+
 def _extract_value_info(x, name, type_proto=None):  # type: (Union[List[Any], np.ndarray, None], Text, Optional[TypeProto]) -> onnx.ValueInfoProto
     if type_proto is None:
         if x is None:
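
A quick way to sanity-check the two new test cases before committing the generated files is to export each module with torch.onnx.export and compare the PyTorch output against onnxruntime. The sketch below is illustrative only and not part of the patch: it assumes onnxruntime is installed, reuses the TorchAttentionLayer and Unflatten classes added above, and the helper name check_export and the scratch file names attention_check.onnx / unflatten_check.onnx are hypothetical. save_data_and_model elsewhere in this script presumably performs the actual export and data dumping used by the OpenCV tests.

# Illustrative sanity check (not part of the patch). Assumes torch, numpy and
# onnxruntime are available; file names below are hypothetical scratch files.
import numpy as np
import torch
import onnxruntime as ort

def check_export(model, x, path):
    # Export the module and confirm onnxruntime reproduces the PyTorch output.
    with torch.no_grad():
        expected = model(x).numpy()
    torch.onnx.export(model, x, path, export_params=True)
    sess = ort.InferenceSession(path)
    input_name = sess.get_inputs()[0].name
    got = sess.run(None, {input_name: x.numpy()})[0]
    assert np.allclose(expected, got, atol=1e-5), "ONNX output mismatch"

check_export(TorchAttentionLayer(embed_dim=6, num_heads=1).eval(),
             torch.rand(2, 5, 6), "attention_check.onnx")
check_export(Unflatten(5, 3).eval(),
             torch.rand(10, 3, 15), "unflatten_check.onnx")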