From e56ed982c9b0d43f1dacfd930f0b8cf22f640666 Mon Sep 17 00:00:00 2001
From: vishwakftw
Date: Sun, 3 May 2020 14:28:10 -0400
Subject: [PATCH] Eliminate .data access for parameters as much as possible

---
 dcgan/main.py                    |  6 +++---
 distributed/rpc/pipeline/main.py |  6 +++---
 distributed/rpc/rnn/rnn.py       |  6 +++---
 regression/main.py               |  2 +-
 word_language_model/main.py      |  2 +-
 word_language_model/model.py     | 12 ++++++------
 6 files changed, 17 insertions(+), 17 deletions(-)

diff --git a/dcgan/main.py b/dcgan/main.py
index b426db32f8..674ba620b8 100644
--- a/dcgan/main.py
+++ b/dcgan/main.py
@@ -113,10 +113,10 @@ def weights_init(m):
     classname = m.__class__.__name__
     if classname.find('Conv') != -1:
-        m.weight.data.normal_(0.0, 0.02)
+        torch.nn.init.normal_(m.weight, 0.0, 0.02)
     elif classname.find('BatchNorm') != -1:
-        m.weight.data.normal_(1.0, 0.02)
-        m.bias.data.fill_(0)
+        torch.nn.init.normal_(m.weight, 1.0, 0.02)
+        torch.nn.init.zeros_(m.bias)
 
 
 class Generator(nn.Module):
diff --git a/distributed/rpc/pipeline/main.py b/distributed/rpc/pipeline/main.py
index 2fcd8220a0..e1c8cd38f0 100644
--- a/distributed/rpc/pipeline/main.py
+++ b/distributed/rpc/pipeline/main.py
@@ -134,8 +134,8 @@ def __init__(self, device, *args, **kwargs):
             if isinstance(m, nn.Conv2d):
                 nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
             elif isinstance(m, nn.BatchNorm2d):
-                nn.init.constant_(m.weight, 1)
-                nn.init.constant_(m.bias, 0)
+                nn.init.ones_(m.weight)
+                nn.init.zeros_(m.bias)
 
     def forward(self, x_rref):
         x = x_rref.to_here().to(self.device)
@@ -285,4 +285,4 @@ def run_worker(rank, world_size, num_split):
         tik = time.time()
         mp.spawn(run_worker, args=(world_size, num_split), nprocs=world_size, join=True)
         tok = time.time()
-        print(f"number of splits = {num_split}, execution time = {tok - tik}")
\ No newline at end of file
+        print(f"number of splits = {num_split}, execution time = {tok - tik}")
diff --git a/distributed/rpc/rnn/rnn.py b/distributed/rpc/rnn/rnn.py
index de1206f0a4..64096d035a 100644
--- a/distributed/rpc/rnn/rnn.py
+++ b/distributed/rpc/rnn/rnn.py
@@ -42,7 +42,7 @@ def __init__(self, ntoken, ninp, dropout):
         super(EmbeddingTable, self).__init__()
         self.drop = nn.Dropout(dropout)
         self.encoder = nn.Embedding(ntoken, ninp).cuda()
-        self.encoder.weight.data.uniform_(-0.1, 0.1)
+        nn.init.uniform_(self.encoder.weight, -0.1, 0.1)
 
     def forward(self, input):
         return self.drop(self.encoder(input.cuda())).cpu()
@@ -56,8 +56,8 @@ def __init__(self, ntoken, nhid, dropout):
         super(Decoder, self).__init__()
         self.drop = nn.Dropout(dropout)
         self.decoder = nn.Linear(nhid, ntoken)
-        self.decoder.bias.data.zero_()
-        self.decoder.weight.data.uniform_(-0.1, 0.1)
+        nn.init.zeros_(self.decoder.bias)
+        nn.init.uniform_(self.decoder.weight, -0.1, 0.1)
 
     def forward(self, output):
         return self.decoder(self.drop(output))
diff --git a/regression/main.py b/regression/main.py
index af4dd228b7..9cc4cfc8dd 100755
--- a/regression/main.py
+++ b/regression/main.py
@@ -57,7 +57,7 @@ def get_batch(batch_size=32):
 
     # Apply gradients
     for param in fc.parameters():
-        param.data.add_(-0.1 * param.grad.data)
+        param.add_(-0.1 * param.grad)
 
     # Stop criterion
     if loss < 1e-3:
diff --git a/word_language_model/main.py b/word_language_model/main.py
index 9d1429c62b..b888c2aa94 100644
--- a/word_language_model/main.py
+++ b/word_language_model/main.py
@@ -178,7 +178,7 @@ def train():
         # `clip_grad_norm` helps prevent the exploding gradient problem in RNNs / LSTMs.
         torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip)
         for p in model.parameters():
-            p.data.add_(-lr, p.grad.data)
+            p.add_(-lr, p.grad)
 
         total_loss += loss.item()
 
diff --git a/word_language_model/model.py b/word_language_model/model.py
index ca876ccdd5..fb2bb3c39d 100644
--- a/word_language_model/model.py
+++ b/word_language_model/model.py
@@ -41,9 +41,9 @@ def __init__(self, rnn_type, ntoken, ninp, nhid, nlayers, dropout=0.5, tie_weigh
 
     def init_weights(self):
         initrange = 0.1
-        self.encoder.weight.data.uniform_(-initrange, initrange)
-        self.decoder.bias.data.zero_()
-        self.decoder.weight.data.uniform_(-initrange, initrange)
+        nn.init.uniform_(self.encoder.weight, -initrange, initrange)
+        nn.init.zeros_(self.decoder.bias)
+        nn.init.uniform_(self.decoder.weight, -initrange, initrange)
 
     def forward(self, input, hidden):
         emb = self.drop(self.encoder(input))
@@ -132,9 +132,9 @@ def _generate_square_subsequent_mask(self, sz):
 
     def init_weights(self):
         initrange = 0.1
-        self.encoder.weight.data.uniform_(-initrange, initrange)
-        self.decoder.bias.data.zero_()
-        self.decoder.weight.data.uniform_(-initrange, initrange)
+        nn.init.uniform_(self.encoder.weight, -initrange, initrange)
+        nn.init.zeros_(self.decoder.bias)
+        nn.init.uniform_(self.decoder.weight, -initrange, initrange)
 
     def forward(self, src, has_mask=True):
         if has_mask:
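
Note on the two manual weight-update hunks (regression/main.py and
word_language_model/main.py): without the .data indirection, param.add_(...)
and p.add_(...) are in-place operations on leaf tensors that require grad, so
they need to run under torch.no_grad() if the surrounding loop does not
already do so; the old .data form bypassed autograd implicitly. Below is a
minimal, illustrative sketch of that pattern, not part of the patch: the toy
model, data, and lr names are made up for the example. It also uses the
alpha= keyword form of add_, since the positional add_(-lr, grad) overload is
deprecated in recent PyTorch releases.

    # Minimal sketch (illustrative only): one manual SGD step without .data.
    import torch
    import torch.nn as nn

    model = nn.Linear(4, 1)                        # stands in for fc / the RNN
    x, y = torch.randn(8, 4), torch.randn(8, 1)    # toy batch
    lr = 0.1

    model.zero_grad()
    loss = ((model(x) - y) ** 2).mean()
    loss.backward()

    # In-place updates on parameters (leaf tensors with requires_grad=True)
    # must happen outside autograd tracking, hence the no_grad() block.
    with torch.no_grad():
        for param in model.parameters():
            param.add_(param.grad, alpha=-lr)      # same update as add_(-lr, param.grad)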