Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions dcgan/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,10 +113,10 @@
def weights_init(m):
classname = m.__class__.__name__
if classname.find('Conv') != -1:
m.weight.data.normal_(0.0, 0.02)
torch.nn.init.normal_(m.weight, 0.0, 0.02)
elif classname.find('BatchNorm') != -1:
m.weight.data.normal_(1.0, 0.02)
m.bias.data.fill_(0)
torch.nn.init.normal_(m.weight, 1.0, 0.02)
torch.nn.init.zeros_(m.bias)


class Generator(nn.Module):
Expand Down
6 changes: 3 additions & 3 deletions distributed/rpc/pipeline/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,8 +134,8 @@ def __init__(self, device, *args, **kwargs):
if isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
elif isinstance(m, nn.BatchNorm2d):
nn.init.constant_(m.weight, 1)
nn.init.constant_(m.bias, 0)
nn.init.ones_(m.weight)
nn.init.zeros_(m.bias)

def forward(self, x_rref):
x = x_rref.to_here().to(self.device)
Expand Down Expand Up @@ -285,4 +285,4 @@ def run_worker(rank, world_size, num_split):
tik = time.time()
mp.spawn(run_worker, args=(world_size, num_split), nprocs=world_size, join=True)
tok = time.time()
print(f"number of splits = {num_split}, execution time = {tok - tik}")
print(f"number of splits = {num_split}, execution time = {tok - tik}")
6 changes: 3 additions & 3 deletions distributed/rpc/rnn/rnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ def __init__(self, ntoken, ninp, dropout):
super(EmbeddingTable, self).__init__()
self.drop = nn.Dropout(dropout)
self.encoder = nn.Embedding(ntoken, ninp).cuda()
self.encoder.weight.data.uniform_(-0.1, 0.1)
nn.init.uniform_(self.encoder.weight, -0.1, 0.1)

def forward(self, input):
return self.drop(self.encoder(input.cuda())).cpu()
Expand All @@ -56,8 +56,8 @@ def __init__(self, ntoken, nhid, dropout):
super(Decoder, self).__init__()
self.drop = nn.Dropout(dropout)
self.decoder = nn.Linear(nhid, ntoken)
self.decoder.bias.data.zero_()
self.decoder.weight.data.uniform_(-0.1, 0.1)
nn.init.zeros_(self.decoder.bias)
nn.init.uniform_(self.decoder.weight, -0.1, 0.1)

def forward(self, output):
return self.decoder(self.drop(output))
Expand Down
2 changes: 1 addition & 1 deletion regression/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ def get_batch(batch_size=32):

# Apply gradients
for param in fc.parameters():
param.data.add_(-0.1 * param.grad.data)
param.add_(-0.1 * param.grad)

# Stop criterion
if loss < 1e-3:
Expand Down
2 changes: 1 addition & 1 deletion word_language_model/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,7 @@ def train():
# `clip_grad_norm` helps prevent the exploding gradient problem in RNNs / LSTMs.
torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip)
for p in model.parameters():
p.data.add_(-lr, p.grad.data)
p.add_(-lr, p.grad)

total_loss += loss.item()

Expand Down
12 changes: 6 additions & 6 deletions word_language_model/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,9 +41,9 @@ def __init__(self, rnn_type, ntoken, ninp, nhid, nlayers, dropout=0.5, tie_weigh

def init_weights(self):
initrange = 0.1
self.encoder.weight.data.uniform_(-initrange, initrange)
self.decoder.bias.data.zero_()
self.decoder.weight.data.uniform_(-initrange, initrange)
nn.init.uniform_(self.encoder.weight, -initrange, initrange)
nn.init.zeros_(self.decoder)
nn.init.uniform_(self.decoder.weight, -initrange, initrange)

def forward(self, input, hidden):
emb = self.drop(self.encoder(input))
Expand Down Expand Up @@ -132,9 +132,9 @@ def _generate_square_subsequent_mask(self, sz):

def init_weights(self):
initrange = 0.1
self.encoder.weight.data.uniform_(-initrange, initrange)
self.decoder.bias.data.zero_()
self.decoder.weight.data.uniform_(-initrange, initrange)
nn.init.uniform_(self.encoder.weight, -initrange, initrange)
nn.init.zeros_(self.decoder)
nn.init.uniform_(self.decoder.weight, -initrange, initrange)

def forward(self, src, has_mask=True):
if has_mask:
Expand Down