Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 14 additions & 3 deletions examples/new_project_templates/lightning_module_template.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,12 +151,23 @@ def validation_end(self, outputs):
val_loss_mean = 0
val_acc_mean = 0
for output in outputs:
val_loss_mean += output['val_loss']
val_acc_mean += output['val_acc']
val_loss = output['val_loss']

# reduce manually when using dp
if self.trainer.use_dp:
val_loss = torch.mean(val_loss)
val_loss_mean += val_loss

# reduce manually when using dp
val_acc = output['val_acc']
if self.trainer.use_dp:
val_acc = torch.mean(val_acc)

val_acc_mean += val_acc

val_loss_mean /= len(outputs)
val_acc_mean /= len(outputs)
tqdm_dic = {'val_loss': val_loss_mean.item(), 'val_acc': val_acc_mean.item()}
tqdm_dic = {'val_loss': val_loss_mean, 'val_acc': val_acc_mean}
return tqdm_dic

# ---------------------
Expand Down
23 changes: 17 additions & 6 deletions pytorch_lightning/models/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -395,13 +395,14 @@ def validate(self, model, dataloader, max_batches):
output = model(data_batch, batch_i)
elif self.use_dp:
output = model(data_batch, batch_i)
output = reduce_distributed_output(output, len(self.data_parallel_device_ids))

elif self.single_gpu:
# put inputs on gpu manually
gpu_id = self.data_parallel_device_ids[0]
for i, x in enumerate(data_batch):
if isinstance(x, torch.Tensor):
data_batch[i] = x.cuda(gpu_id)

# do non dp, ddp step
output = model.validation_step(data_batch, batch_i)

else:
Expand Down Expand Up @@ -853,7 +854,6 @@ def __run_tng_batch(self, data_batch, batch_nb):
output = self.model(data_batch, batch_nb)
elif self.use_dp:
output = self.model(data_batch, batch_nb)
output = reduce_distributed_output(output, len(self.data_parallel_device_ids))
elif self.single_gpu:
gpu_id = self.data_parallel_device_ids[0]
for i, x in enumerate(data_batch):
Expand All @@ -865,7 +865,14 @@ def __run_tng_batch(self, data_batch, batch_nb):
output = self.model.training_step(data_batch, batch_nb)

try:
model_specific_tqdm_metrics_dic = output['prog']
prog_output = output['prog']

# reduce prog metrics for tqdm when using dp
if self.use_dp:
nb_gpus = len(self.data_parallel_device_ids)
prog_output = reduce_distributed_output(prog_output, nb_gpus)

model_specific_tqdm_metrics_dic = prog_output
except Exception:
model_specific_tqdm_metrics_dic = {}

Expand All @@ -877,6 +884,10 @@ def __run_tng_batch(self, data_batch, batch_nb):
if type(output) is torch.Tensor:
loss = output

# when using dp need to reduce the loss
if self.use_dp:
loss = reduce_distributed_output(loss, len(self.data_parallel_device_ids))

self.__add_tqdm_metrics(model_specific_tqdm_metrics_dic)

# backward pass
Expand Down Expand Up @@ -959,12 +970,12 @@ def __run_validation(self):
# use full val set on end of epoch
# use a small portion otherwise
max_batches = None if not self.fast_dev_run else 1
model_specific_tqdm_metrics_dic = self.validate(
validation_results = self.validate(
self.model,
self.val_dataloader,
max_batches
)
self.__add_tqdm_metrics(model_specific_tqdm_metrics_dic)
self.__add_tqdm_metrics(validation_results)

# hook
if self.__is_function_implemented('on_post_performance_check'):
Expand Down