@@ -892,6 +892,45 @@ def forward(self, t):
 save_data_and_model("hidden_lstm_bi", input, hidden_lstm, version=11, export_params=True)
 
 
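+# Export LSTM test models: bidirectional first, then unidirectional (forward-only).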
+batch = 5
+features = 4
+hidden = 3
+seq_len = 2
+num_layers = 1
+bidirectional = True
+
+class LSTM(nn.Module):
+
+    def __init__(self):
+        super(LSTM, self).__init__()
+        self.lstm = nn.LSTM(features, hidden, num_layers, bidirectional=bidirectional)
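+        # Initial states have shape (num_layers * num_directions, batch, hidden);
+        # with num_layers == 1, num_layers + int(bidirectional) yields the same value.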
+        self.h0 = torch.from_numpy(np.ones((num_layers + int(bidirectional), batch, hidden), dtype=np.float32))
+        self.c0 = torch.from_numpy(np.ones((num_layers + int(bidirectional), batch, hidden), dtype=np.float32))
+
+    def forward(self, x):
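+        # a: outputs of shape (seq_len, batch, num_directions * hidden);
+        # b, c: final hidden and cell states of shape (num_layers * num_directions, batch, hidden).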
+        a, (b, c) = self.lstm(x, (self.h0, self.c0))
+        if bidirectional:
+            return torch.cat((a, b, c), dim=2)
+        else:
+            return torch.cat((a, b, c), dim=0)
+
+
+input_ = Variable(torch.randn(seq_len, batch, features))
+lstm = LSTM()
+save_data_and_model("lstm_cell_bidirectional", input_, lstm, export_params=True)
+
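+# Re-export as a unidirectional LSTM; LSTM() picks up the module-level bidirectional flag.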
+bidirectional = False
+input_ = Variable(torch.randn(seq_len, batch, features))
+lstm = LSTM()
+save_data_and_model("lstm_cell_forward", input_, lstm, export_params=True)
+
+
 class MatMul(nn.Module):
     def __init__(self):
         super(MatMul, self).__init__()