word_language_model: fix decoder initialisation in init_weights

The old code passed self.decoder itself to nn.init.zeros_ instead of one
of its tensors. Zeroing self.decoder.weight would be dead work, because
the very next statement overwrites that same tensor with uniform_. This
patch zeroes the decoder bias and leaves the weight to the uniform_ call.

NOTE(review): assumes self.decoder exposes a bias tensor (e.g. nn.Linear
with its default bias=True) -- confirm against __init__, which is outside
this hunk.
NOTE(review): context indentation and the blank context line were
reconstructed from a whitespace-collapsed copy of this patch, and the
post-image blob id on the index line was not regenerated -- verify
before applying.

diff --git a/word_language_model/model.py b/word_language_model/model.py
index 776573b4c5..7b6c23135e 100644
--- a/word_language_model/model.py
+++ b/word_language_model/model.py
@@ -133,7 +133,7 @@ def _generate_square_subsequent_mask(self, sz):
     def init_weights(self):
         initrange = 0.1
         nn.init.uniform_(self.encoder.weight, -initrange, initrange)
-        nn.init.zeros_(self.decoder)
+        nn.init.zeros_(self.decoder.bias)
         nn.init.uniform_(self.decoder.weight, -initrange, initrange)
 
     def forward(self, src, has_mask=True):