From 89a5937efab3f35b1ee05c646ba8f119999ba7d2 Mon Sep 17 00:00:00 2001 From: Lukas Folle Date: Tue, 6 Apr 2021 10:31:26 +0200 Subject: [PATCH 1/3] Added datamodules to lr_find --- pytorch_lightning/tuner/tuning.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/pytorch_lightning/tuner/tuning.py b/pytorch_lightning/tuner/tuning.py index b9fa9afe0e77e..b23836dc44826 100644 --- a/pytorch_lightning/tuner/tuning.py +++ b/pytorch_lightning/tuner/tuning.py @@ -61,7 +61,11 @@ def tune(self, model, train_dataloader, val_dataloaders, datamodule): # Run learning rate finder: if self.trainer.auto_lr_find: - self.lr_find(model, update_attr=True) + self.lr_find(model, + train_dataloader=train_dataloader, + val_dataloaders=val_dataloaders, + datamodule=datamodule + update_attr=True) self.trainer.state = TrainerState.FINISHED From ea928bc15f060d09c24ca655112e2e654556de40 Mon Sep 17 00:00:00 2001 From: Lukas Folle Date: Tue, 6 Apr 2021 10:33:01 +0200 Subject: [PATCH 2/3] small change --- pytorch_lightning/tuner/tuning.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/tuner/tuning.py b/pytorch_lightning/tuner/tuning.py index b23836dc44826..a520a1caf7e03 100644 --- a/pytorch_lightning/tuner/tuning.py +++ b/pytorch_lightning/tuner/tuning.py @@ -64,7 +64,7 @@ def tune(self, model, train_dataloader, val_dataloaders, datamodule): self.lr_find(model, train_dataloader=train_dataloader, val_dataloaders=val_dataloaders, - datamodule=datamodule + datamodule=datamodule, update_attr=True) self.trainer.state = TrainerState.FINISHED From 78866609e7c4922e45bd672330b8f6536ad07692 Mon Sep 17 00:00:00 2001 From: Lukas Folle Date: Tue, 6 Apr 2021 10:37:51 +0200 Subject: [PATCH 3/3] Removed whitespaces --- pytorch_lightning/tuner/tuning.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/tuner/tuning.py b/pytorch_lightning/tuner/tuning.py index a520a1caf7e03..6ebc25915ef83 100644 --- a/pytorch_lightning/tuner/tuning.py +++ b/pytorch_lightning/tuner/tuning.py @@ -61,7 +61,7 @@ def tune(self, model, train_dataloader, val_dataloaders, datamodule): # Run learning rate finder: if self.trainer.auto_lr_find: - self.lr_find(model, + self.lr_find(model, train_dataloader=train_dataloader, val_dataloaders=val_dataloaders, datamodule=datamodule,