From b28bed000cb3cbe27842c50f338b5182104da49b Mon Sep 17 00:00:00 2001 From: Kaushik B Date: Fri, 19 Nov 2021 11:14:28 +0530 Subject: [PATCH 1/5] Update Lightning examples with the latest updates --- python/ray/tune/examples/mnist_ptl_mini.py | 5 +++-- python/ray/tune/examples/mnist_pytorch_lightning.py | 6 +++--- 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/python/ray/tune/examples/mnist_ptl_mini.py b/python/ray/tune/examples/mnist_ptl_mini.py index 3183aa7984fe..8632f4f61ffb 100644 --- a/python/ray/tune/examples/mnist_ptl_mini.py +++ b/python/ray/tune/examples/mnist_ptl_mini.py @@ -3,6 +3,7 @@ import torch from filelock import FileLock from torch.nn import functional as F +from torchmetrics import Accuracy import pytorch_lightning as pl from pl_bolts.datamodules.mnist_datamodule import MNISTDataModule import os @@ -24,7 +25,7 @@ def __init__(self, config, data_dir=None): self.layer_1 = torch.nn.Linear(28 * 28, layer_1) self.layer_2 = torch.nn.Linear(layer_1, layer_2) self.layer_3 = torch.nn.Linear(layer_2, 10) - self.accuracy = pl.metrics.Accuracy() + self.accuracy = Accuracy() def forward(self, x): batch_size, channels, width, height = x.size() @@ -74,7 +75,7 @@ def train_mnist_tune(config, num_epochs=10, num_gpus=0): max_epochs=num_epochs, # If fractional GPUs passed in, convert to int. gpus=math.ceil(num_gpus), - progress_bar_refresh_rate=0, + enable_progress_bar=False, callbacks=[TuneReportCallback(metrics, on="validation_end")]) trainer.fit(model, dm) diff --git a/python/ray/tune/examples/mnist_pytorch_lightning.py b/python/ray/tune/examples/mnist_pytorch_lightning.py index 1fb720e3eda2..5bfa6a4b98a8 100644 --- a/python/ray/tune/examples/mnist_pytorch_lightning.py +++ b/python/ray/tune/examples/mnist_pytorch_lightning.py @@ -121,7 +121,7 @@ def configure_optimizers(self): def train_mnist(config): model = LightningMNISTClassifier(config) - trainer = pl.Trainer(max_epochs=10, show_progress_bar=False) + trainer = pl.Trainer(max_epochs=10, enable_progress_bar=False) trainer.fit(model) # __lightning_end__ @@ -148,7 +148,7 @@ def train_mnist_tune(config, num_epochs=10, num_gpus=0, data_dir="~/data"): gpus=math.ceil(num_gpus), logger=TensorBoardLogger( save_dir=tune.get_trial_dir(), name="", version="."), - progress_bar_refresh_rate=0, + enable_progress_bar=False, callbacks=[ TuneReportCallback( { @@ -174,7 +174,7 @@ def train_mnist_tune_checkpoint(config, "gpus": math.ceil(num_gpus), "logger": TensorBoardLogger( save_dir=tune.get_trial_dir(), name="", version="."), - "progress_bar_refresh_rate": 0, + "enable_progress_bar": False, "callbacks": [ TuneReportCheckpointCallback( metrics={ From 9cdcff9fa5e553b29bed6b2d053c009a36378cec Mon Sep 17 00:00:00 2001 From: Kaushik B Date: Sun, 28 Nov 2021 21:24:37 +0530 Subject: [PATCH 2/5] Update lightning version --- python/requirements/ml/requirements_tune.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/requirements/ml/requirements_tune.txt b/python/requirements/ml/requirements_tune.txt index c60fa72e1d24..5ad2a804d540 100644 --- a/python/requirements/ml/requirements_tune.txt +++ b/python/requirements/ml/requirements_tune.txt @@ -29,7 +29,7 @@ nevergrad==0.4.3.post7 optuna==2.9.1 pytest-remotedata==0.3.2 lightning-bolts==0.4.0 -pytorch-lightning==1.4.9 +pytorch-lightning>=1.5.3 shortuuid==1.0.1 scikit-learn==0.24.2 scikit-optimize==0.8.1 From 7a3ea1d926ffa88a05f58a4652ef83c751439de4 Mon Sep 17 00:00:00 2001 From: Kaushik B Date: Sun, 28 Nov 2021 21:25:17 +0530 Subject: [PATCH 3/5] Fix sanity check calls 
--- python/ray/tune/integration/pytorch_lightning.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/ray/tune/integration/pytorch_lightning.py b/python/ray/tune/integration/pytorch_lightning.py index 66ca898462df..45ad72cb9607 100644 --- a/python/ray/tune/integration/pytorch_lightning.py +++ b/python/ray/tune/integration/pytorch_lightning.py @@ -174,7 +174,7 @@ def __init__(self, def _get_report_dict(self, trainer: Trainer, pl_module: LightningModule): # Don't report if just doing initial validation sanity checks. - if trainer.running_sanity_check: + if trainer.sanity_checking: return if not self._metrics: report_dict = { @@ -228,7 +228,7 @@ def __init__(self, self._filename = filename def _handle(self, trainer: Trainer, pl_module: LightningModule): - if trainer.running_sanity_check: + if trainer.sanity_checking: return step = f"epoch={trainer.current_epoch}-step={trainer.global_step}" with tune.checkpoint_dir(step=step) as checkpoint_dir: From 821f33f9971be92bb8a2743cf9c2dd3128b547fa Mon Sep 17 00:00:00 2001 From: Amog Kamsetty Date: Fri, 11 Feb 2022 10:05:08 -0800 Subject: [PATCH 4/5] Format with black --- python/ray/tune/examples/mnist_ptl_mini.py | 25 ++-- .../ray/tune/integration/pytorch_lightning.py | 136 +++++++++++------- 2 files changed, 101 insertions(+), 60 deletions(-) diff --git a/python/ray/tune/examples/mnist_ptl_mini.py b/python/ray/tune/examples/mnist_ptl_mini.py index 8632f4f61ffb..99243a0ab6e1 100644 --- a/python/ray/tune/examples/mnist_ptl_mini.py +++ b/python/ray/tune/examples/mnist_ptl_mini.py @@ -69,14 +69,16 @@ def train_mnist_tune(config, num_epochs=10, num_gpus=0): model = LightningMNISTClassifier(config, data_dir) with FileLock(os.path.expanduser("~/.data.lock")): dm = MNISTDataModule( - data_dir=data_dir, num_workers=1, batch_size=config["batch_size"]) + data_dir=data_dir, num_workers=1, batch_size=config["batch_size"] + ) metrics = {"loss": "ptl/val_loss", "acc": "ptl/val_accuracy"} trainer = pl.Trainer( max_epochs=num_epochs, # If fractional GPUs passed in, convert to int. 
gpus=math.ceil(num_gpus), enable_progress_bar=False, - callbacks=[TuneReportCallback(metrics, on="validation_end")]) + callbacks=[TuneReportCallback(metrics, on="validation_end")], + ) trainer.fit(model, dm) @@ -89,18 +91,17 @@ def tune_mnist(num_samples=10, num_epochs=10, gpus_per_trial=0): } trainable = tune.with_parameters( - train_mnist_tune, num_epochs=num_epochs, num_gpus=gpus_per_trial) + train_mnist_tune, num_epochs=num_epochs, num_gpus=gpus_per_trial + ) analysis = tune.run( trainable, - resources_per_trial={ - "cpu": 1, - "gpu": gpus_per_trial - }, + resources_per_trial={"cpu": 1, "gpu": gpus_per_trial}, metric="loss", mode="min", config=config, num_samples=num_samples, - name="tune_mnist") + name="tune_mnist", + ) print("Best hyperparameters found were: ", analysis.best_config) @@ -110,14 +111,15 @@ def tune_mnist(num_samples=10, num_epochs=10, gpus_per_trial=0): parser = argparse.ArgumentParser() parser.add_argument( - "--smoke-test", action="store_true", help="Finish quickly for testing") + "--smoke-test", action="store_true", help="Finish quickly for testing" + ) parser.add_argument( "--server-address", type=str, default=None, required=False, - help="The address of server to connect to if using " - "Ray Client.") + help="The address of server to connect to if using " "Ray Client.", + ) args, _ = parser.parse_known_args() if args.smoke_test: @@ -125,6 +127,7 @@ def tune_mnist(num_samples=10, num_epochs=10, gpus_per_trial=0): else: if args.server_address: import ray + ray.init(f"ray://{args.server_address}") tune_mnist(num_samples=10, num_epochs=10, gpus_per_trial=0) diff --git a/python/ray/tune/integration/pytorch_lightning.py b/python/ray/tune/integration/pytorch_lightning.py index 45ad72cb9607..daea752416b7 100644 --- a/python/ray/tune/integration/pytorch_lightning.py +++ b/python/ray/tune/integration/pytorch_lightning.py @@ -11,13 +11,29 @@ class TuneCallback(Callback): """Base class for Tune's PyTorch Lightning callbacks.""" + _allowed = [ - "init_start", "init_end", "fit_start", "fit_end", "sanity_check_start", - "sanity_check_end", "epoch_start", "epoch_end", "batch_start", - "validation_batch_start", "validation_batch_end", "test_batch_start", - "test_batch_end", "batch_end", "train_start", "train_end", - "validation_start", "validation_end", "test_start", "test_end", - "keyboard_interrupt" + "init_start", + "init_end", + "fit_start", + "fit_end", + "sanity_check_start", + "sanity_check_end", + "epoch_start", + "epoch_end", + "batch_start", + "validation_batch_start", + "validation_batch_end", + "test_batch_start", + "test_batch_end", + "batch_end", + "train_start", + "train_end", + "validation_start", + "validation_end", + "test_start", + "test_end", + "keyboard_interrupt", ] def __init__(self, on: Union[str, List[str]] = "validation_end"): @@ -26,7 +42,9 @@ def __init__(self, on: Union[str, List[str]] = "validation_end"): if any(w not in self._allowed for w in on): raise ValueError( "Invalid trigger time selected: {}. 
Must be one of {}".format( - on, self._allowed)) + on, self._allowed + ) + ) self._on = on def _handle(self, trainer: Trainer, pl_module: Optional[LightningModule]): @@ -40,25 +58,21 @@ def on_init_end(self, trainer: Trainer): if "init_end" in self._on: self._handle(trainer, None) - def on_fit_start(self, - trainer: Trainer, - pl_module: Optional[LightningModule] = None): + def on_fit_start( + self, trainer: Trainer, pl_module: Optional[LightningModule] = None + ): if "fit_start" in self._on: self._handle(trainer, None) - def on_fit_end(self, - trainer: Trainer, - pl_module: Optional[LightningModule] = None): + def on_fit_end(self, trainer: Trainer, pl_module: Optional[LightningModule] = None): if "fit_end" in self._on: self._handle(trainer, None) - def on_sanity_check_start(self, trainer: Trainer, - pl_module: LightningModule): + def on_sanity_check_start(self, trainer: Trainer, pl_module: LightningModule): if "sanity_check_start" in self._on: self._handle(trainer, pl_module) - def on_sanity_check_end(self, trainer: Trainer, - pl_module: LightningModule): + def on_sanity_check_end(self, trainer: Trainer, pl_module: LightningModule): if "sanity_check_end" in self._on: self._handle(trainer, pl_module) @@ -74,25 +88,49 @@ def on_batch_start(self, trainer: Trainer, pl_module: LightningModule): if "batch_start" in self._on: self._handle(trainer, pl_module) - def on_validation_batch_start(self, trainer: Trainer, - pl_module: LightningModule, batch, batch_idx, - dataloader_idx): + def on_validation_batch_start( + self, + trainer: Trainer, + pl_module: LightningModule, + batch, + batch_idx, + dataloader_idx, + ): if "validation_batch_start" in self._on: self._handle(trainer, pl_module) - def on_validation_batch_end(self, trainer: Trainer, - pl_module: LightningModule, outputs, batch, - batch_idx, dataloader_idx): + def on_validation_batch_end( + self, + trainer: Trainer, + pl_module: LightningModule, + outputs, + batch, + batch_idx, + dataloader_idx, + ): if "validation_batch_end" in self._on: self._handle(trainer, pl_module) - def on_test_batch_start(self, trainer: Trainer, pl_module: LightningModule, - batch, batch_idx, dataloader_idx): + def on_test_batch_start( + self, + trainer: Trainer, + pl_module: LightningModule, + batch, + batch_idx, + dataloader_idx, + ): if "test_batch_start" in self._on: self._handle(trainer, pl_module) - def on_test_batch_end(self, trainer: Trainer, pl_module: LightningModule, - outputs, batch, batch_idx, dataloader_idx): + def on_test_batch_end( + self, + trainer: Trainer, + pl_module: LightningModule, + outputs, + batch, + batch_idx, + dataloader_idx, + ): if "test_batch_end" in self._on: self._handle(trainer, pl_module) @@ -108,8 +146,7 @@ def on_train_end(self, trainer: Trainer, pl_module: LightningModule): if "train_end" in self._on: self._handle(trainer, pl_module) - def on_validation_start(self, trainer: Trainer, - pl_module: LightningModule): + def on_validation_start(self, trainer: Trainer, pl_module: LightningModule): if "validation_start" in self._on: self._handle(trainer, pl_module) @@ -125,8 +162,7 @@ def on_test_end(self, trainer: Trainer, pl_module: LightningModule): if "test_end" in self._on: self._handle(trainer, pl_module) - def on_keyboard_interrupt(self, trainer: Trainer, - pl_module: LightningModule): + def on_keyboard_interrupt(self, trainer: Trainer, pl_module: LightningModule): if "keyboard_interrupt" in self._on: self._handle(trainer, pl_module) @@ -164,9 +200,11 @@ class TuneReportCallback(TuneCallback): """ - def __init__(self, - metrics: 
Union[None, str, List[str], Dict[str, str]] = None, - on: Union[str, List[str]] = "validation_end"): + def __init__( + self, + metrics: Union[None, str, List[str], Dict[str, str]] = None, + on: Union[str, List[str]] = "validation_end", + ): super(TuneReportCallback, self).__init__(on) if isinstance(metrics, str): metrics = [metrics] @@ -177,10 +215,7 @@ def _get_report_dict(self, trainer: Trainer, pl_module: LightningModule): if trainer.sanity_checking: return if not self._metrics: - report_dict = { - k: v.item() - for k, v in trainer.callback_metrics.items() - } + report_dict = {k: v.item() for k, v in trainer.callback_metrics.items()} else: report_dict = {} for key in self._metrics: @@ -191,8 +226,10 @@ def _get_report_dict(self, trainer: Trainer, pl_module: LightningModule): if metric in trainer.callback_metrics: report_dict[key] = trainer.callback_metrics[metric].item() else: - logger.warning(f"Metric {metric} does not exist in " - "`trainer.callback_metrics.") + logger.warning( + f"Metric {metric} does not exist in " + "`trainer.callback_metrics." + ) return report_dict @@ -221,9 +258,9 @@ class _TuneCheckpointCallback(TuneCallback): """ - def __init__(self, - filename: str = "checkpoint", - on: Union[str, List[str]] = "validation_end"): + def __init__( + self, filename: str = "checkpoint", on: Union[str, List[str]] = "validation_end" + ): super(_TuneCheckpointCallback, self).__init__(on) self._filename = filename @@ -232,8 +269,7 @@ def _handle(self, trainer: Trainer, pl_module: LightningModule): return step = f"epoch={trainer.current_epoch}-step={trainer.global_step}" with tune.checkpoint_dir(step=step) as checkpoint_dir: - trainer.save_checkpoint( - os.path.join(checkpoint_dir, self._filename)) + trainer.save_checkpoint(os.path.join(checkpoint_dir, self._filename)) class TuneReportCheckpointCallback(TuneCallback): @@ -275,10 +311,12 @@ class TuneReportCheckpointCallback(TuneCallback): _checkpoint_callback_cls = _TuneCheckpointCallback _report_callbacks_cls = TuneReportCallback - def __init__(self, - metrics: Union[None, str, List[str], Dict[str, str]] = None, - filename: str = "checkpoint", - on: Union[str, List[str]] = "validation_end"): + def __init__( + self, + metrics: Union[None, str, List[str], Dict[str, str]] = None, + filename: str = "checkpoint", + on: Union[str, List[str]] = "validation_end", + ): super(TuneReportCheckpointCallback, self).__init__(on) self._checkpoint = self._checkpoint_callback_cls(filename, on) self._report = self._report_callbacks_cls(metrics, on) From 153d3a889c0986ddc24de6333fe8406d6df4c771 Mon Sep 17 00:00:00 2001 From: Amog Kamsetty Date: Fri, 11 Feb 2022 12:41:03 -0800 Subject: [PATCH 5/5] update api --- python/ray/util/ray_lightning/simple_tune.py | 4 ++-- python/ray/util/ray_lightning/tune/__init__.py | 6 +++--- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/python/ray/util/ray_lightning/simple_tune.py b/python/ray/util/ray_lightning/simple_tune.py index d95acb3f290f..7a5f171d7fcd 100644 --- a/python/ray/util/ray_lightning/simple_tune.py +++ b/python/ray/util/ray_lightning/simple_tune.py @@ -8,7 +8,7 @@ import pytorch_lightning as pl from ray.util.ray_lightning import RayPlugin -from ray.util.ray_lightning.tune import TuneReportCallback, get_tune_ddp_resources +from ray.util.ray_lightning.tune import TuneReportCallback, get_tune_resources num_cpus_per_actor = 1 num_workers = 1 @@ -70,7 +70,7 @@ def main(): num_samples=1, metric="loss", mode="min", - resources_per_trial=get_tune_ddp_resources( + 
resources_per_trial=get_tune_resources( num_workers=num_workers, cpus_per_worker=num_cpus_per_actor ), ) diff --git a/python/ray/util/ray_lightning/tune/__init__.py b/python/ray/util/ray_lightning/tune/__init__.py index 2cc1bd329c84..6c90b5515529 100644 --- a/python/ray/util/ray_lightning/tune/__init__.py +++ b/python/ray/util/ray_lightning/tune/__init__.py @@ -4,13 +4,13 @@ TuneReportCallback = None TuneReportCheckpointCallback = None -get_tune_ddp_resources = None +get_tune_resources = None try: from ray_lightning.tune import ( TuneReportCallback, TuneReportCheckpointCallback, - get_tune_ddp_resources, + get_tune_resources, ) except ImportError: logger.info( @@ -22,5 +22,5 @@ __all__ = [ "TuneReportCallback", "TuneReportCheckpointCallback", - "get_tune_ddp_resources", + "get_tune_resources", ]
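
Note (not part of the patches above): taken together, this series migrates the Tune + PyTorch Lightning integration to the 1.5-era API — `pl.metrics.Accuracy` becomes `torchmetrics.Accuracy`, `progress_bar_refresh_rate=0` / `show_progress_bar=False` become `enable_progress_bar=False`, the callbacks check the renamed `trainer.sanity_checking` attribute instead of `trainer.running_sanity_check`, and `get_tune_ddp_resources` is renamed to `get_tune_resources` in the ray_lightning tune helpers. The snippet below is only a minimal sketch of how a trainable wires the updated `Trainer` arguments to `TuneReportCallback` after these changes, assuming pytorch-lightning>=1.5.3 and the integration module patched above; `train_fn`, `MyLightningModule`, and the config keys are illustrative placeholders, not code from this series.

import math

import pytorch_lightning as pl
from ray import tune
from ray.tune.integration.pytorch_lightning import TuneReportCallback


def train_fn(config, num_epochs=10, num_gpus=0):
    # `MyLightningModule` is a hypothetical stand-in for a user-defined
    # LightningModule (e.g. the LightningMNISTClassifier in the examples).
    model = MyLightningModule(config)
    trainer = pl.Trainer(
        max_epochs=num_epochs,
        # If fractional GPUs are passed in, convert to int.
        gpus=math.ceil(num_gpus),
        # `progress_bar_refresh_rate` / `show_progress_bar` were removed in
        # PL 1.5; `enable_progress_bar=False` is the replacement used here.
        enable_progress_bar=False,
        callbacks=[
            TuneReportCallback(
                {"loss": "ptl/val_loss", "acc": "ptl/val_accuracy"},
                on="validation_end",
            )
        ],
    )
    trainer.fit(model)


# Illustrative launch, mirroring the tune.run() call in the patched example.
analysis = tune.run(
    tune.with_parameters(train_fn, num_epochs=10, num_gpus=0),
    resources_per_trial={"cpu": 1, "gpu": 0},
    metric="loss",
    mode="min",
    config={"lr": tune.loguniform(1e-4, 1e-1), "batch_size": tune.choice([32, 64])},
    num_samples=10,
)

The `running_sanity_check` → `sanity_checking` rename is internal to the Tune callbacks, so reporting is still skipped during the initial validation sanity check without any change to user code.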