diff --git a/CHANGELOG.md b/CHANGELOG.md index fb558f6e2cb16..2287e97f7751e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -103,7 +103,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Removed deprecated `GPUStatsMonitor` callback ([#12554](https://github.com/PyTorchLightning/pytorch-lightning/pull/12554)) -- Removed support for passing strategy names or strategy classes to the accelerator Trainer argument ([#12696](https://github.com/PyTorchLightning/pytorch-lightning/pull/12696)) +- Removed support for passing strategy names or strategy instances to the accelerator Trainer argument ([#12696](https://github.com/PyTorchLightning/pytorch-lightning/pull/12696)) + + +- Removed support for passing strategy names or strategy instances to the plugins Trainer argument ([#12700](https://github.com/PyTorchLightning/pytorch-lightning/pull/12700)) ### Fixed diff --git a/pytorch_lightning/trainer/connectors/accelerator_connector.py b/pytorch_lightning/trainer/connectors/accelerator_connector.py index 7c48fd95f3500..d472a7d17ab83 100644 --- a/pytorch_lightning/trainer/connectors/accelerator_connector.py +++ b/pytorch_lightning/trainer/connectors/accelerator_connector.py @@ -123,21 +123,18 @@ def __init__( C. plugins flag could be: 1. List of str, which could contain: - i. strategy str - ii. precision str (Not supported in the old accelerator_connector version) - iii. checkpoint_io str (Not supported in the old accelerator_connector version) - iv. cluster_environment str (Not supported in the old accelerator_connector version) + i. precision str (Not supported in the old accelerator_connector version) + ii. checkpoint_io str (Not supported in the old accelerator_connector version) + iii. cluster_environment str (Not supported in the old accelerator_connector version) 2. List of class, which could contains: - i. strategy class (deprecated in 1.5 will be removed in 1.7) - ii. precision class (should be removed, and precision flag should allow user pass classes) - iii. checkpoint_io class - iv. cluster_environment class + i. precision class (should be removed, and precision flag should allow user pass classes) + ii. checkpoint_io class + iii. cluster_environment class priorities which to take when: A. Class > str B. Strategy > Accelerator/precision/plugins - C. TODO When multiple flag set to the same thing """ if benchmark and deterministic: rank_zero_warn( @@ -228,13 +225,14 @@ def _check_config_and_set_final_flags( ) -> None: """This method checks: - 1. strategy: strategy and plugin can be set to strategies + 1. strategy: whether the strategy name is valid, and sets the internal flags if it is. 2. accelerator: if the value of the accelerator argument is a type of accelerator (instance or string), set self._accelerator_flag accordingly. 3. precision: The final value of the precision flag may be determined either by the precision argument or by a plugin instance. - 4. plugins: a plugin could occur as a value of the strategy argument (handled by 1), or the precision - argument (handled by 3). We also extract the CheckpointIO and ClusterEnvironment plugins. + 4. plugins: The list of plugins may contain a Precision plugin, CheckpointIO, ClusterEnvironment and others. + Additionally, other flags such as `precision` or `sync_batchnorm` can populate the list with the + corresponding plugin instances. """ if plugins is not None: plugins = [plugins] if not isinstance(plugins, list) else plugins @@ -254,18 +252,6 @@ def _check_config_and_set_final_flags( "`Trainer(strategy='tpu_spawn')` is not a valid strategy," " you can use `Trainer(strategy='ddp_spawn', accelerator='tpu')` instead." ) - if plugins: - for plugin in plugins: - if isinstance(plugin, Strategy): - raise MisconfigurationException( - f"You have passed `Trainer(strategy={strategy})`" - f" and you can only specify one strategy, but you have passed {plugin} as a plugin." - ) - if isinstance(plugin, str) and plugin in self._registered_strategies: - raise MisconfigurationException( - f"You have passed `Trainer(strategy={strategy})`" - f" and you can only specify one strategy, but you have passed {plugin} as a plugin." - ) if accelerator is not None: if accelerator in self._accelerator_types or accelerator == "auto" or isinstance(accelerator, Accelerator): @@ -281,15 +267,7 @@ def _check_config_and_set_final_flags( if plugins: plugins_flags_types: Dict[str, int] = Counter() for plugin in plugins: - if isinstance(plugin, Strategy) or isinstance(plugin, str) and plugin in self._registered_strategies: - self._strategy_flag = plugin - rank_zero_deprecation( - f"Passing {plugin} `strategy` to the `plugins` flag in Trainer has been deprecated" - f" in v1.5 and will be removed in v1.7. Use `Trainer(strategy={plugin})` instead." - ) - plugins_flags_types[Strategy.__name__] += 1 - - elif isinstance(plugin, PrecisionPlugin): + if isinstance(plugin, PrecisionPlugin): self._precision_plugin_flag = plugin plugins_flags_types[PrecisionPlugin.__name__] += 1 elif isinstance(plugin, CheckpointIO): @@ -309,7 +287,7 @@ def _check_config_and_set_final_flags( else: raise MisconfigurationException( f"Found invalid type for plugin {plugin}. Expected one of: PrecisionPlugin, " - "CheckpointIO, ClusterEnviroment, LayerSync, or Strategy." + "CheckpointIO, ClusterEnviroment, or LayerSync." ) duplicated_plugin_key = [k for k, v in plugins_flags_types.items() if v > 1] diff --git a/tests/accelerators/test_accelerator_connector.py b/tests/accelerators/test_accelerator_connector.py index 64961026977ee..55d859033bbda 100644 --- a/tests/accelerators/test_accelerator_connector.py +++ b/tests/accelerators/test_accelerator_connector.py @@ -330,12 +330,6 @@ def test_unsupported_strategy_types_on_cpu(strategy): assert isinstance(trainer.strategy, DDPStrategy) -def test_exception_when_strategy_used_with_plugins(): - with pytest.raises(MisconfigurationException, match="only specify one strategy, but you have passed"): - with pytest.deprecated_call(match=r"`strategy` to the `plugins` flag in Trainer has been deprecated"): - Trainer(plugins="ddp_find_unused_parameters_false", strategy="ddp_spawn") - - def test_exception_invalid_strategy(): with pytest.raises(MisconfigurationException, match=r"strategy='ddp_cpu'\)` is not a valid"): Trainer(strategy="ddp_cpu") diff --git a/tests/deprecated_api/test_remove_1-7.py b/tests/deprecated_api/test_remove_1-7.py index 7001728226bb2..466a4a5561664 100644 --- a/tests/deprecated_api/test_remove_1-7.py +++ b/tests/deprecated_api/test_remove_1-7.py @@ -289,11 +289,6 @@ def test_v1_7_0_deprecate_parameter_validation(): from pytorch_lightning.core.decorators import parameter_validation # noqa: F401 -def test_v1_7_0_passing_strategy_to_plugins_flag(): - with pytest.deprecated_call(match="has been deprecated in v1.5 and will be removed in v1.7."): - Trainer(plugins="ddp_spawn") - - def test_v1_7_0_weights_summary_trainer(tmpdir): with pytest.deprecated_call(match=r"Setting `Trainer\(weights_summary=full\)` is deprecated in v1.5"): t = Trainer(weights_summary="full")