1-
-- NOTE(review): this span is a unified-diff rendering, not runnable Lua. Old/new
-- line numbers are fused onto each content line ("53local ok" = old line 5 /
-- new line 3). The module table was renamed LSTMTDNN -> LSTMCNN in this change.
2- local LSTMTDNN = {}
3-
1+ local LSTMCNN = {}
42
-- Optional GPU lookup table: probe for Facebook's fbcunn without erroring.
-- NOTE(review): the space inside ' fbcunn' looks like extraction garbling of
-- 'fbcunn' — a require with a leading space would fail; confirm against the repo.
53local ok , cunn = pcall (require , ' fbcunn' )
64if not ok then
-- NOTE(review): old lines 7-8 / new lines 5-6 are elided here (numbering jumps
-- from 6 to 9). The visible assignment uses fbcunn.LookupTableGPU, which is
-- presumably the *else* arm of this guard, with the CPU fallback
-- (LookupTable = nn.LookupTable) hidden in the elision — verify upstream.
97 LookupTable = fbcunn .LookupTableGPU
108end
119
-- Builds an nngraph LSTM language-model cell whose layer-1 input is a
-- character-level CNN feature vector concatenated with a word embedding.
-- This diff renames lstmtdnn -> lstmcnn, replaces the single num_feature_maps
-- scalar with a per-kernel feature_maps table, and re-adds the word-index
-- input alongside the char-index input.
-- NOTE(review): three "@@" hunk headers below elide portions of the function
-- (the i2h/h2h projections, the cell update, and the output head are not
-- visible here) — this block is an incomplete view, kept byte-identical.
12- function LSTMTDNN .lstmtdnn (word_vocab_size , rnn_size , n , dropout , word_vec_size , char_vec_size , char_vocab_size ,
13- num_feature_maps , kernels , word2char2idx )
14- -- input_size = vocab size
10+ function LSTMCNN .lstmcnn (word_vocab_size , rnn_size , n , dropout , word_vec_size , char_vec_size , char_vocab_size ,
11+ feature_maps , kernels , word2char2idx )
1512 -- rnn_size = dimensionality of hidden layers
1613 -- n = number of layers
1714 -- k = word embedding size
@@ -20,38 +17,39 @@ function LSTMTDNN.lstmtdnn(word_vocab_size, rnn_size, n, dropout, word_vec_size,
2017
2118 -- there will be 2*n+1 inputs
-- max word length in characters, taken from the word->char index matrix.
2219 local length = word2char2idx :size (2 )
23- local word_vec_size = word_vec_size or rnn_size
2420 local inputs = {}
-- inputs[1] = word indices (restored by this diff), inputs[2] = char indices,
-- then (prev_c, prev_h) pairs for each of the n layers.
25- -- table.insert(inputs, nn.Identity()()) -- batch_size x 1 (word indices)
21+ table.insert (inputs , nn .Identity ()()) -- batch_size x 1 (word indices)
2622 table.insert (inputs , nn .Identity ()()) -- batch_size x word length (char indices)
2723 for L = 1 ,n do
2824 table.insert (inputs , nn .Identity ()()) -- prev_c[L]
2925 table.insert (inputs , nn .Identity ()()) -- prev_h[L]
3026 end
3127
32- local x , input_size_L , word_vec , char_vec , tdnn_output
28+ local x , input_size_L , word_vec , char_vec , cnn_output , pool_layer
3329 local outputs = {}
3430 for L = 1 ,n do
3531 -- c,h from previous timesteps
-- Index shift: the extra word-index input moved the state inputs by one slot,
-- so the old "-1" offsets are dropped.
36- local prev_h = inputs [L * 2 + 2 - 1 ]
37- local prev_c = inputs [L * 2 + 1 - 1 ]
32+ local prev_h = inputs [L * 2 + 2 ]
33+ local prev_c = inputs [L * 2 + 1 ]
3834 -- the input to this layer
3935 if L == 1 then
-- Layer 1: embed the word id and every char id of the word.
36+ word_vec = nn .LookupTable (word_vocab_size , word_vec_size )(inputs [1 ])
37+ char_vec = nn .LookupTable (char_vocab_size , char_vec_size )(inputs [2 ]) -- batch_size * word length * char_vec_size
40- char_vec = nn .LookupTable (char_vocab_size , char_vec_size )(inputs [1 ]) -- batch_size * word length * char_vec_size
4138 local layer1 = {}
-- One temporal convolution per kernel width, each with its own feature-map
-- count (feature_maps[i] replaces the former shared num_feature_maps),
-- followed by tanh and max-over-time pooling down to a single position.
4239 for i = 1 , # kernels do
4340 local reduced_l = length - kernels [i ] + 1
44- local conv_layer = nn .TemporalConvolution (char_vec_size , num_feature_maps , kernels [i ])(char_vec )
45- local pool_layer = nn .TemporalMaxPooling (reduced_l )(nn .Tanh ()(conv_layer ))
41+ local conv_layer = nn .TemporalConvolution (char_vec_size , feature_maps [ i ] , kernels [i ])(char_vec )
42+ pool_layer = nn .TemporalMaxPooling (reduced_l )(nn .Tanh ()(conv_layer ))
4643 table.insert (layer1 , pool_layer )
4744 end
48- local layer1_concat = nn .JoinTable (3 )(layer1 )
49- tdnn_output = nn .Squeeze ()(layer1_concat )
50- -- tdnn_output = TDNN.tdnn(length, char_vec_size, tdnn_output_size, kernels) -- batch_size * tdnn_output_size
51- -- word_vec = LookupTable(word_vocab_size, word_vec_size)(inputs[1])
52- -- x = nn.Identity()(word_vec)
53- x = nn .Identity ()(tdnn_output )
54- input_size_L = word_vec_size
-- JoinTable(3) needs >=2 members; with a single kernel just squeeze the one
-- pooled map (pool_layer still holds the last loop value by design here).
45+ if # kernels > 1 then
46+ local layer1_concat = nn .JoinTable (3 )(layer1 )
47+ cnn_output = nn .Squeeze ()(layer1_concat )
48+ else
49+ cnn_output = nn .Squeeze ()(pool_layer )
50+ end
-- Layer-1 input = [char-CNN features ; word embedding]; its width is the sum
-- of all feature-map counts plus the word-embedding size.
51+ x = nn .JoinTable (2 )({cnn_output , word_vec })
52+ input_size_L = torch .Tensor (feature_maps ):sum () + word_vec_size
5553 else
-- Layers >1 consume the hidden state of the layer below (outputs are stored
-- as (c, h) pairs, so index (L-1)*2 is next_h of the previous layer).
5654 x = outputs [(L - 1 )* 2 ]
5755 if dropout > 0 then x = nn .Dropout (dropout )(x ) end -- apply dropout, if any
@@ -65,8 +63,8 @@ function LSTMTDNN.lstmtdnn(word_vocab_size, rnn_size, n, dropout, word_vec_size,
-- Standard LSTM gate decode: the first 3*rnn_size columns of the joint
-- projection pass through a sigmoid and are sliced into the three gates.
6563 local sigmoid_chunk = nn .Narrow (2 , 1 , 3 * rnn_size )(all_input_sums )
6664 sigmoid_chunk = nn .Sigmoid ()(sigmoid_chunk )
6765 local in_gate = nn .Narrow (2 , 1 , rnn_size )(sigmoid_chunk )
-- NOTE(review): this hunk only swaps which slice is *named* out_gate vs
-- forget_gate. Before training the assignment of slices to gates is an
-- arbitrary convention, but it does break compatibility with checkpoints
-- trained under the old naming — presumably intentional; confirm.
68- local forget_gate = nn .Narrow (2 , rnn_size + 1 , rnn_size )(sigmoid_chunk )
69- local out_gate = nn .Narrow (2 , 2 * rnn_size + 1 , rnn_size )(sigmoid_chunk )
66+ local out_gate = nn .Narrow (2 , rnn_size + 1 , rnn_size )(sigmoid_chunk )
67+ local forget_gate = nn .Narrow (2 , 2 * rnn_size + 1 , rnn_size )(sigmoid_chunk )
7068 -- decode the write inputs
7169 local in_transform = nn .Narrow (2 , 3 * rnn_size + 1 , rnn_size )(all_input_sums )
7270 in_transform = nn .Tanh ()(in_transform )
@@ -92,5 +90,5 @@ function LSTMTDNN.lstmtdnn(word_vocab_size, rnn_size, n, dropout, word_vec_size,
-- Wrap the symbolic graph (elided cell-update/output lines included) into a
-- single nngraph module mapping the inputs table to the outputs table.
9290 return nn .gModule (inputs , outputs )
9391end
9492
-- Module export, renamed to match the new table name at the top of the file.
95- return LSTMTDNN
93+ return LSTMCNN
9694
0 commit comments