
Commit 1877b87

Eliminate .data access for parameters as much as possible (#767)
1 parent 31643b2 commit 1877b87
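The pattern is the same in every file below: writing through `param.data` bypasses autograd's version tracking, so errors such as in-place updates during the backward pass can go unnoticed. The `torch.nn.init` helpers perform the same in-place initialization, but under `torch.no_grad()`, keeping autograd consistent. A minimal sketch of the before/after pattern (the `nn.Linear` layer is purely illustrative):

import torch.nn as nn

layer = nn.Linear(4, 2)

# Old pattern: mutate the raw tensor behind the parameter,
# sidestepping autograd's bookkeeping.
layer.weight.data.uniform_(-0.1, 0.1)

# New pattern: nn.init helpers perform the same in-place update
# inside torch.no_grad(), without touching .data.
nn.init.uniform_(layer.weight, -0.1, 0.1)
nn.init.zeros_(layer.bias)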

File tree

6 files changed, +16 -16 lines changed


dcgan/main.py

Lines changed: 3 additions & 3 deletions

@@ -113,10 +113,10 @@
 def weights_init(m):
     classname = m.__class__.__name__
     if classname.find('Conv') != -1:
-        m.weight.data.normal_(0.0, 0.02)
+        torch.nn.init.normal_(m.weight, 0.0, 0.02)
     elif classname.find('BatchNorm') != -1:
-        m.weight.data.normal_(1.0, 0.02)
-        m.bias.data.fill_(0)
+        torch.nn.init.normal_(m.weight, 1.0, 0.02)
+        torch.nn.init.zeros_(m.bias)


 class Generator(nn.Module):
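For context, the DCGAN script applies this initializer recursively via `Module.apply`, which visits every submodule. A sketch of the call site (assuming `ngpu` and `device` are defined elsewhere in the script, as in the example):

netG = Generator(ngpu).to(device)
netG.apply(weights_init)  # runs weights_init on netG and each submodule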

distributed/rpc/pipeline/main.py

Lines changed: 2 additions & 2 deletions

@@ -134,8 +134,8 @@ def __init__(self, device, *args, **kwargs):
             if isinstance(m, nn.Conv2d):
                 nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
             elif isinstance(m, nn.BatchNorm2d):
-                nn.init.constant_(m.weight, 1)
-                nn.init.constant_(m.bias, 0)
+                nn.init.ones_(m.weight)
+                nn.init.zeros_(m.bias)

    def forward(self, x_rref):
        x = x_rref.to_here().to(self.device)
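Here `nn.init.ones_` and `nn.init.zeros_` are simply dedicated spellings of `nn.init.constant_` with values 1 and 0; the two forms are interchangeable:

import torch.nn as nn

bn = nn.BatchNorm2d(8)
nn.init.constant_(bn.weight, 1)  # old spelling
nn.init.ones_(bn.weight)         # equivalent helper used by the commit
nn.init.zeros_(bn.bias)          # likewise for constant_(bias, 0)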

distributed/rpc/rnn/rnn.py

Lines changed: 3 additions & 3 deletions

@@ -42,7 +42,7 @@ def __init__(self, ntoken, ninp, dropout):
         super(EmbeddingTable, self).__init__()
         self.drop = nn.Dropout(dropout)
         self.encoder = nn.Embedding(ntoken, ninp).cuda()
-        self.encoder.weight.data.uniform_(-0.1, 0.1)
+        nn.init.uniform_(self.encoder.weight, -0.1, 0.1)

     def forward(self, input):
         return self.drop(self.encoder(input.cuda())).cpu()

@@ -56,8 +56,8 @@ def __init__(self, ntoken, nhid, dropout):
         super(Decoder, self).__init__()
         self.drop = nn.Dropout(dropout)
         self.decoder = nn.Linear(nhid, ntoken)
-        self.decoder.bias.data.zero_()
-        self.decoder.weight.data.uniform_(-0.1, 0.1)
+        nn.init.zeros_(self.decoder.bias)
+        nn.init.uniform_(self.decoder.weight, -0.1, 0.1)

     def forward(self, output):
         return self.decoder(self.drop(output))
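One detail worth noting: the embedding table is moved to the GPU before initialization, and the `nn.init` functions operate in place on whatever device the parameter already lives on, so that ordering is safe. A small sketch (guarded so it also runs without CUDA):

import torch
import torch.nn as nn

emb = nn.Embedding(10, 4)
if torch.cuda.is_available():
    emb = emb.cuda()

# In-place init works the same on CPU and CUDA parameters.
nn.init.uniform_(emb.weight, -0.1, 0.1)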

regression/main.py

Lines changed: 1 addition & 1 deletion

@@ -57,7 +57,7 @@ def get_batch(batch_size=32):

     # Apply gradients
     for param in fc.parameters():
-        param.data.add_(-0.1 * param.grad.data)
+        param.add_(-0.1 * param.grad)

     # Stop criterion
     if loss < 1e-3:
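One caveat: once `.data` is gone, this in-place update touches a leaf tensor that requires grad, which PyTorch rejects unless gradient tracking is disabled; the surrounding script must run the update under `torch.no_grad()` (an assumption, since the hunk does not show that context). A minimal reproduction of the safe pattern:

import torch
import torch.nn as nn

fc = nn.Linear(4, 1)
fc(torch.randn(8, 4)).pow(2).mean().backward()

with torch.no_grad():  # without this, the in-place add_ raises a RuntimeError
    for param in fc.parameters():
        param.add_(-0.1 * param.grad)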

word_language_model/main.py

Lines changed: 1 addition & 1 deletion

@@ -178,7 +178,7 @@ def train():
        # `clip_grad_norm` helps prevent the exploding gradient problem in RNNs / LSTMs.
        torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip)
        for p in model.parameters():
-            p.data.add_(-lr, p.grad.data)
+            p.add_(-lr, p.grad)

        total_loss += loss.item()

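The new line keeps the old positional overload `p.add_(-lr, p.grad)`, which recent PyTorch versions deprecate in favor of an explicit `alpha=` keyword; the same `torch.no_grad()` caveat from the regression example applies here too. A sketch of the modern spelling, using a stand-in model:

import torch
import torch.nn as nn

model = nn.Linear(4, 2)  # stand-in for the language model
model(torch.randn(3, 4)).sum().backward()

lr = 0.1
with torch.no_grad():  # in-place updates on leaf parameters need this
    for p in model.parameters():
        p.add_(p.grad, alpha=-lr)  # modern form of p.add_(-lr, p.grad)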

word_language_model/model.py

Lines changed: 6 additions & 6 deletions

@@ -41,9 +41,9 @@ def __init__(self, rnn_type, ntoken, ninp, nhid, nlayers, dropout=0.5, tie_weigh

     def init_weights(self):
         initrange = 0.1
-        self.encoder.weight.data.uniform_(-initrange, initrange)
-        self.decoder.bias.data.zero_()
-        self.decoder.weight.data.uniform_(-initrange, initrange)
+        nn.init.uniform_(self.encoder.weight, -initrange, initrange)
+        nn.init.zeros_(self.decoder.bias)
+        nn.init.uniform_(self.decoder.weight, -initrange, initrange)

     def forward(self, input, hidden):
         emb = self.drop(self.encoder(input))

@@ -132,9 +132,9 @@ def _generate_square_subsequent_mask(self, sz):

     def init_weights(self):
         initrange = 0.1
-        self.encoder.weight.data.uniform_(-initrange, initrange)
-        self.decoder.bias.data.zero_()
-        self.decoder.weight.data.uniform_(-initrange, initrange)
+        nn.init.uniform_(self.encoder.weight, -initrange, initrange)
+        nn.init.zeros_(self.decoder.bias)
+        nn.init.uniform_(self.decoder.weight, -initrange, initrange)

     def forward(self, src, has_mask=True):
         if has_mask:
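These `init_weights` methods follow the same shape in both model classes and are called once from the constructor; a condensed sketch of the pattern (class name and sizes are illustrative):

import torch.nn as nn

class TinyLM(nn.Module):
    def __init__(self, ntoken=100, ninp=32):
        super().__init__()
        self.encoder = nn.Embedding(ntoken, ninp)
        self.decoder = nn.Linear(ninp, ntoken)
        self.init_weights()

    def init_weights(self):
        initrange = 0.1
        nn.init.uniform_(self.encoder.weight, -initrange, initrange)
        nn.init.zeros_(self.decoder.bias)
        nn.init.uniform_(self.decoder.weight, -initrange, initrange)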
