From 6dbfa1bed87219a47ca7f42785617fe59d654fc1 Mon Sep 17 00:00:00 2001 From: William Falcon Date: Thu, 8 Aug 2019 10:45:47 -0400 Subject: [PATCH 1/5] removed reduce on non-loss outputs from dp --- pytorch_lightning/models/trainer.py | 19 +++++++++++++------ 1 file changed, 13 insertions(+), 6 deletions(-) diff --git a/pytorch_lightning/models/trainer.py b/pytorch_lightning/models/trainer.py index 605f2c78b61b3..51960d900fb4e 100644 --- a/pytorch_lightning/models/trainer.py +++ b/pytorch_lightning/models/trainer.py @@ -395,8 +395,6 @@ def validate(self, model, dataloader, max_batches): output = model(data_batch, batch_i) elif self.use_dp: output = model(data_batch, batch_i) - output = reduce_distributed_output(output, len(self.data_parallel_device_ids)) - elif self.single_gpu: gpu_id = self.data_parallel_device_ids[0] for i, x in enumerate(data_batch): @@ -853,7 +851,6 @@ def __run_tng_batch(self, data_batch, batch_nb): output = self.model(data_batch, batch_nb) elif self.use_dp: output = self.model(data_batch, batch_nb) - output = reduce_distributed_output(output, len(self.data_parallel_device_ids)) elif self.single_gpu: gpu_id = self.data_parallel_device_ids[0] for i, x in enumerate(data_batch): @@ -865,7 +862,13 @@ def __run_tng_batch(self, data_batch, batch_nb): output = self.model.training_step(data_batch, batch_nb) try: - model_specific_tqdm_metrics_dic = output['prog'] + prog_output = output['prog'] + + # reduce prog metrics for tqdm when using dp + if self.use_dp: + prog_output = reduce_distributed_output(prog_output, len(self.data_parallel_device_ids)) + + model_specific_tqdm_metrics_dic = prog_output except Exception: model_specific_tqdm_metrics_dic = {} @@ -877,6 +880,10 @@ def __run_tng_batch(self, data_batch, batch_nb): if type(output) is torch.Tensor: loss = output + # when using dp need to reduce the loss + if self.use_dp: + loss = reduce_distributed_output(loss, len(self.data_parallel_device_ids)) + self.__add_tqdm_metrics(model_specific_tqdm_metrics_dic) # backward pass @@ -959,12 +966,12 @@ def __run_validation(self): # use full val set on end of epoch # use a small portion otherwise max_batches = None if not self.fast_dev_run else 1 - model_specific_tqdm_metrics_dic = self.validate( + validation_results = self.validate( self.model, self.val_dataloader, max_batches ) - self.__add_tqdm_metrics(model_specific_tqdm_metrics_dic) + self.__add_tqdm_metrics(validation_results) # hook if self.__is_function_implemented('on_post_performance_check'): From 776b08e8e22bd27140f42621c69db20f9c48129a Mon Sep 17 00:00:00 2001 From: William Falcon Date: Thu, 8 Aug 2019 10:58:16 -0400 Subject: [PATCH 2/5] fixed val reduce --- .../lightning_module_template.py | 13 ++++++++++++- pytorch_lightning/models/trainer.py | 3 +++ 2 files changed, 15 insertions(+), 1 deletion(-) diff --git a/examples/new_project_templates/lightning_module_template.py b/examples/new_project_templates/lightning_module_template.py index c11dd1b335e57..5762d469f83f0 100644 --- a/examples/new_project_templates/lightning_module_template.py +++ b/examples/new_project_templates/lightning_module_template.py @@ -151,7 +151,18 @@ def validation_end(self, outputs): val_loss_mean = 0 val_acc_mean = 0 for output in outputs: - val_loss_mean += output['val_loss'] + val_loss = output['val_loss'] + + # reduce manually when using dp + if self.trainer.use_dp: + val_loss = torch.mean(val_loss) + val_loss_mean += val_loss + + # reduce manually when using dp + val_acc = output['val_acc'] + if self.trainer.use_dp: + val_acc_mean = torch.mean(val_acc) + val_acc_mean += output['val_acc'] val_loss_mean /= len(outputs) diff --git a/pytorch_lightning/models/trainer.py b/pytorch_lightning/models/trainer.py index 51960d900fb4e..91cda180c309f 100644 --- a/pytorch_lightning/models/trainer.py +++ b/pytorch_lightning/models/trainer.py @@ -396,10 +396,13 @@ def validate(self, model, dataloader, max_batches): elif self.use_dp: output = model(data_batch, batch_i) elif self.single_gpu: + # put inputs on gpu manually gpu_id = self.data_parallel_device_ids[0] for i, x in enumerate(data_batch): if isinstance(x, torch.Tensor): data_batch[i] = x.cuda(gpu_id) + + # do non dp, ddp step output = model.validation_step(data_batch, batch_i) else: From a1cdc37d141533494f42a69a05520292f7f4f0d4 Mon Sep 17 00:00:00 2001 From: William Falcon Date: Thu, 8 Aug 2019 11:06:57 -0400 Subject: [PATCH 3/5] fixed val reduce --- examples/new_project_templates/lightning_module_template.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/new_project_templates/lightning_module_template.py b/examples/new_project_templates/lightning_module_template.py index 5762d469f83f0..8a81aef9b17c0 100644 --- a/examples/new_project_templates/lightning_module_template.py +++ b/examples/new_project_templates/lightning_module_template.py @@ -163,7 +163,7 @@ def validation_end(self, outputs): if self.trainer.use_dp: val_acc_mean = torch.mean(val_acc) - val_acc_mean += output['val_acc'] + val_acc_mean += val_acc_mean val_loss_mean /= len(outputs) val_acc_mean /= len(outputs) From 9ecf3408615625564a17b15fe51e679bf0742ea6 Mon Sep 17 00:00:00 2001 From: William Falcon Date: Thu, 8 Aug 2019 11:26:00 -0400 Subject: [PATCH 4/5] fixed val reduce --- examples/new_project_templates/lightning_module_template.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/new_project_templates/lightning_module_template.py b/examples/new_project_templates/lightning_module_template.py index 8a81aef9b17c0..94e3407d96e1e 100644 --- a/examples/new_project_templates/lightning_module_template.py +++ b/examples/new_project_templates/lightning_module_template.py @@ -167,7 +167,7 @@ def validation_end(self, outputs): val_loss_mean /= len(outputs) val_acc_mean /= len(outputs) - tqdm_dic = {'val_loss': val_loss_mean.item(), 'val_acc': val_acc_mean.item()} + tqdm_dic = {'val_loss': val_loss_mean, 'val_acc': val_acc_mean} return tqdm_dic # --------------------- From 7ee0c58eae417f3911a229fcec8d2a34d786304d Mon Sep 17 00:00:00 2001 From: William Falcon Date: Thu, 8 Aug 2019 11:49:01 -0400 Subject: [PATCH 5/5] fixed val reduce --- pytorch_lightning/models/trainer.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/pytorch_lightning/models/trainer.py b/pytorch_lightning/models/trainer.py index 91cda180c309f..e4b2caafb2793 100644 --- a/pytorch_lightning/models/trainer.py +++ b/pytorch_lightning/models/trainer.py @@ -869,7 +869,8 @@ def __run_tng_batch(self, data_batch, batch_nb): # reduce prog metrics for tqdm when using dp if self.use_dp: - prog_output = reduce_distributed_output(prog_output, len(self.data_parallel_device_ids)) + nb_gpus = len(self.data_parallel_device_ids) + prog_output = reduce_distributed_output(prog_output, nb_gpus) model_specific_tqdm_metrics_dic = prog_output except Exception: