Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Removed deprecated `GPUStatsMonitor` callback ([#12554](https://github.com/PyTorchLightning/pytorch-lightning/pull/12554))


- Removed support for passing strategy names or strategy classes to the accelerator Trainer argument ([#12696](https://github.com/PyTorchLightning/pytorch-lightning/pull/12696))
- Removed support for passing strategy names or strategy instances to the accelerator Trainer argument ([#12696](https://github.com/PyTorchLightning/pytorch-lightning/pull/12696))


- Removed support for passing strategy names or strategy instances to the plugins Trainer argument ([#12700](https://github.com/PyTorchLightning/pytorch-lightning/pull/12700))


### Fixed
Expand Down
46 changes: 12 additions & 34 deletions pytorch_lightning/trainer/connectors/accelerator_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,21 +123,18 @@ def __init__(

C. plugins flag could be:
1. List of str, which could contain:
i. strategy str
ii. precision str (Not supported in the old accelerator_connector version)
iii. checkpoint_io str (Not supported in the old accelerator_connector version)
iv. cluster_environment str (Not supported in the old accelerator_connector version)
i. precision str (Not supported in the old accelerator_connector version)
ii. checkpoint_io str (Not supported in the old accelerator_connector version)
iii. cluster_environment str (Not supported in the old accelerator_connector version)
2. List of class, which could contains:
i. strategy class (deprecated in 1.5 will be removed in 1.7)
ii. precision class (should be removed, and precision flag should allow user pass classes)
iii. checkpoint_io class
iv. cluster_environment class
i. precision class (should be removed, and precision flag should allow user pass classes)
ii. checkpoint_io class
iii. cluster_environment class


priorities which to take when:
A. Class > str
B. Strategy > Accelerator/precision/plugins
C. TODO When multiple flag set to the same thing
"""
if benchmark and deterministic:
rank_zero_warn(
Expand Down Expand Up @@ -228,13 +225,14 @@ def _check_config_and_set_final_flags(
) -> None:
"""This method checks:

1. strategy: strategy and plugin can be set to strategies
1. strategy: whether the strategy name is valid, and sets the internal flags if it is.
2. accelerator: if the value of the accelerator argument is a type of accelerator (instance or string),
set self._accelerator_flag accordingly.
3. precision: The final value of the precision flag may be determined either by the precision argument or
by a plugin instance.
4. plugins: a plugin could occur as a value of the strategy argument (handled by 1), or the precision
argument (handled by 3). We also extract the CheckpointIO and ClusterEnvironment plugins.
4. plugins: The list of plugins may contain a Precision plugin, CheckpointIO, ClusterEnvironment and others.
Additionally, other flags such as `precision` or `sync_batchnorm` can populate the list with the
corresponding plugin instances.
"""
if plugins is not None:
plugins = [plugins] if not isinstance(plugins, list) else plugins
Expand All @@ -254,18 +252,6 @@ def _check_config_and_set_final_flags(
"`Trainer(strategy='tpu_spawn')` is not a valid strategy,"
" you can use `Trainer(strategy='ddp_spawn', accelerator='tpu')` instead."
)
if plugins:
for plugin in plugins:
if isinstance(plugin, Strategy):
raise MisconfigurationException(
f"You have passed `Trainer(strategy={strategy})`"
f" and you can only specify one strategy, but you have passed {plugin} as a plugin."
)
if isinstance(plugin, str) and plugin in self._registered_strategies:
raise MisconfigurationException(
f"You have passed `Trainer(strategy={strategy})`"
f" and you can only specify one strategy, but you have passed {plugin} as a plugin."
)

if accelerator is not None:
if accelerator in self._accelerator_types or accelerator == "auto" or isinstance(accelerator, Accelerator):
Expand All @@ -281,15 +267,7 @@ def _check_config_and_set_final_flags(
if plugins:
plugins_flags_types: Dict[str, int] = Counter()
for plugin in plugins:
if isinstance(plugin, Strategy) or isinstance(plugin, str) and plugin in self._registered_strategies:
self._strategy_flag = plugin
rank_zero_deprecation(
f"Passing {plugin} `strategy` to the `plugins` flag in Trainer has been deprecated"
f" in v1.5 and will be removed in v1.7. Use `Trainer(strategy={plugin})` instead."
)
plugins_flags_types[Strategy.__name__] += 1

elif isinstance(plugin, PrecisionPlugin):
if isinstance(plugin, PrecisionPlugin):
self._precision_plugin_flag = plugin
plugins_flags_types[PrecisionPlugin.__name__] += 1
elif isinstance(plugin, CheckpointIO):
Expand All @@ -309,7 +287,7 @@ def _check_config_and_set_final_flags(
else:
raise MisconfigurationException(
f"Found invalid type for plugin {plugin}. Expected one of: PrecisionPlugin, "
"CheckpointIO, ClusterEnviroment, LayerSync, or Strategy."
"CheckpointIO, ClusterEnviroment, or LayerSync."
)

duplicated_plugin_key = [k for k, v in plugins_flags_types.items() if v > 1]
Expand Down
6 changes: 0 additions & 6 deletions tests/accelerators/test_accelerator_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -330,12 +330,6 @@ def test_unsupported_strategy_types_on_cpu(strategy):
assert isinstance(trainer.strategy, DDPStrategy)


def test_exception_when_strategy_used_with_plugins():
with pytest.raises(MisconfigurationException, match="only specify one strategy, but you have passed"):
with pytest.deprecated_call(match=r"`strategy` to the `plugins` flag in Trainer has been deprecated"):
Trainer(plugins="ddp_find_unused_parameters_false", strategy="ddp_spawn")


def test_exception_invalid_strategy():
with pytest.raises(MisconfigurationException, match=r"strategy='ddp_cpu'\)` is not a valid"):
Trainer(strategy="ddp_cpu")
Expand Down
5 changes: 0 additions & 5 deletions tests/deprecated_api/test_remove_1-7.py
Original file line number Diff line number Diff line change
Expand Up @@ -289,11 +289,6 @@ def test_v1_7_0_deprecate_parameter_validation():
from pytorch_lightning.core.decorators import parameter_validation # noqa: F401


def test_v1_7_0_passing_strategy_to_plugins_flag():
with pytest.deprecated_call(match="has been deprecated in v1.5 and will be removed in v1.7."):
Trainer(plugins="ddp_spawn")


def test_v1_7_0_weights_summary_trainer(tmpdir):
with pytest.deprecated_call(match=r"Setting `Trainer\(weights_summary=full\)` is deprecated in v1.5"):
t = Trainer(weights_summary="full")
Expand Down