
Commit 0c41aaa

Merge pull request #959 from rogday:lstm

2 parents: 294a5ca + b840f46

10 files changed (+33 lines, -0 lines)
6 binary files changed (288 B, 288 B, 160 B, 608 B, 368 B, 152 B); binary contents not shown.

testdata/dnn/onnx/generate_onnx_models.py

Lines changed: 33 additions & 0 deletions
@@ -892,6 +892,39 @@ def forward(self, t):
 save_data_and_model("hidden_lstm_bi", input, hidden_lstm, version=11, export_params=True)
 
 
+batch = 5
+features = 4
+hidden = 3
+seq_len = 2
+num_layers = 1
+bidirectional = True
+
+class LSTM(nn.Module):
+
+    def __init__(self):
+        super(LSTM, self).__init__()
+        self.lstm = nn.LSTM(features, hidden, num_layers, bidirectional=bidirectional)
+        self.h0 = torch.from_numpy(np.ones((num_layers + int(bidirectional), batch, hidden), dtype=np.float32))
+        self.c0 = torch.from_numpy(np.ones((num_layers + int(bidirectional), batch, hidden), dtype=np.float32))
+
+    def forward(self, x):
+        a, (b, c) = self.lstm(x, (self.h0, self.c0))
+        if bidirectional:
+            return torch.cat((a, b, c), dim=2)
+        else:
+            return torch.cat((a, b, c), dim=0)
+
+
+input_ = Variable(torch.randn(seq_len, batch, features))
+lstm = LSTM()
+save_data_and_model("lstm_cell_bidirectional", input_, lstm, export_params=True)
+
+bidirectional = False
+input_ = Variable(torch.randn(seq_len, batch, features))
+lstm = LSTM()
+save_data_and_model("lstm_cell_forward", input_, lstm, export_params=True)
+
+
 class MatMul(nn.Module):
     def __init__(self):
         super(MatMul, self).__init__()
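Why the added test concatenates its outputs: nn.LSTM returns the per-step output together with the final hidden and cell states, and the model glues all three tensors into one, presumably so a single reference output blob covers them all. In the bidirectional case the concatenation has to run along dim 2, because the step output is already 2*hidden wide while hn/cn stay hidden wide (and dim 0 only matches here because seq_len equals num_directions, both 2); in the unidirectional case all three tensors share the last two dims and stack along dim 0. A minimal standalone sketch, not part of the commit, that just prints the shapes involved using the same constants as in the diff:

# Shape check only; mirrors the constants used in the new test code above.
import torch
import torch.nn as nn

seq_len, batch, features, hidden = 2, 5, 4, 3
x = torch.randn(seq_len, batch, features)

# Bidirectional: out is (seq_len, batch, 2*hidden), hn/cn are (2, batch, hidden),
# so the tensors can only be joined along dim 2.
h0 = c0 = torch.ones(2, batch, hidden)
out, (hn, cn) = nn.LSTM(features, hidden, 1, bidirectional=True)(x, (h0, c0))
print(out.shape, hn.shape, cn.shape)           # (2, 5, 6) (2, 5, 3) (2, 5, 3)
print(torch.cat((out, hn, cn), dim=2).shape)   # (2, 5, 12)

# Unidirectional: all three are (*, batch, hidden), so they stack along dim 0.
h0 = c0 = torch.ones(1, batch, hidden)
out, (hn, cn) = nn.LSTM(features, hidden, 1)(x, (h0, c0))
print(torch.cat((out, hn, cn), dim=0).shape)   # (4, 5, 3)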
3 binary files changed (2.28 KB, 1.58 KB, 841 B); binary contents not shown.
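The binary files in this commit are presumably the exported ONNX models and the input/output reference blobs written by save_data_and_model for the two new names. A rough sketch of how such a model/blob pair could be checked with OpenCV's DNN module; the directory layout (models/*.onnx, data/input_*.npy, data/output_*.npy) is an assumption about this repository's conventions, not something visible in the diff:

# Hypothetical sanity check for one of the new models; the paths are assumed,
# adjust them to wherever save_data_and_model actually writes its files.
import numpy as np
import cv2

net = cv2.dnn.readNetFromONNX("models/lstm_cell_bidirectional.onnx")
inp = np.load("data/input_lstm_cell_bidirectional.npy")
ref = np.load("data/output_lstm_cell_bidirectional.npy")

net.setInput(inp)
out = net.forward()
print("max abs diff:", np.max(np.abs(out - ref)))  # should be close to 0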
