Commit 67ef2c0

Add backward-compatibility for LightningLite in PL (#14735)

1 parent ecca4f5
15 files changed: +448 −84 lines

examples/convert_from_pt_to_pl/image_classifier_2_lite.py

Lines changed: 1 addition & 1 deletion

@@ -38,10 +38,10 @@
 from torch.optim.lr_scheduler import StepLR
 from torchmetrics.classification import Accuracy
 
-from lightning_lite.lite import LightningLite  # import LightningLite
 from pytorch_lightning import seed_everything
 from pytorch_lightning.demos.boring_classes import Net
 from pytorch_lightning.demos.mnist_datamodule import MNIST
+from pytorch_lightning.lite import LightningLite  # import LightningLite
 
 DATASETS_PATH = path.join(path.dirname(__file__), "..", "..", "Datasets")

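Note on the change above: only the import line moves; the script's use of Lite is unchanged. As a minimal sketch of the pattern these example scripts follow (the model, data, and hyperparameters below are placeholders, assuming the `run`/`setup`/`setup_dataloaders`/`backward` API of this release):

import torch
from torch.utils.data import DataLoader, TensorDataset

from pytorch_lightning.lite import LightningLite  # new backward-compatible path


class Lite(LightningLite):
    def run(self, epochs: int = 1) -> None:
        model = torch.nn.Linear(32, 2)
        optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
        # Lite wraps the model/optimizer and handles device placement.
        model, optimizer = self.setup(model, optimizer)
        dataloader = self.setup_dataloaders(
            DataLoader(TensorDataset(torch.randn(64, 32)), batch_size=16)
        )
        for _ in range(epochs):
            for (x,) in dataloader:
                optimizer.zero_grad()
                loss = model(x).sum()
                self.backward(loss)  # instead of loss.backward()
                optimizer.step()


if __name__ == "__main__":
    Lite(accelerator="cpu", devices=1).run()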
examples/convert_from_pt_to_pl/image_classifier_3_lite_to_lightning_module.py

Lines changed: 1 addition & 1 deletion

@@ -34,10 +34,10 @@
 from torch.optim.lr_scheduler import StepLR
 from torchmetrics import Accuracy
 
-from lightning_lite.lite import LightningLite
 from pytorch_lightning import seed_everything
 from pytorch_lightning.demos.boring_classes import Net
 from pytorch_lightning.demos.mnist_datamodule import MNIST
+from pytorch_lightning.lite import LightningLite
 
 DATASETS_PATH = path.join(path.dirname(__file__), "..", "..", "Datasets")

examples/pl_loops/mnist_lite.py

Lines changed: 1 addition & 1 deletion

@@ -22,10 +22,10 @@
 from torch.optim.lr_scheduler import StepLR
 from torchmetrics import Accuracy
 
-from lightning_lite.lite import LightningLite
 from pytorch_lightning import seed_everything
 from pytorch_lightning.demos.boring_classes import Net
 from pytorch_lightning.demos.mnist_datamodule import MNIST
+from pytorch_lightning.lite import LightningLite
 from pytorch_lightning.loops import Loop
 
 DATASETS_PATH = path.join(path.dirname(__file__), "..", "..", "Datasets")

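The three example-script diffs above are identical in substance; side by side, the one-line migration is:

# before: import from the standalone package internals
# from lightning_lite.lite import LightningLite
# after: import from the backward-compatible pytorch_lightning location
from pytorch_lightning.lite import LightningLite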
src/lightning_lite/connector.py

Lines changed: 4 additions & 59 deletions

@@ -52,11 +52,11 @@
     XLAStrategy,
 )
 from lightning_lite.strategies.ddp_spawn import _DDP_FORK_ALIASES
-from lightning_lite.utilities import _StrategyType, rank_zero_deprecation, rank_zero_info, rank_zero_warn
+from lightning_lite.utilities import _StrategyType, rank_zero_info, rank_zero_warn
 from lightning_lite.utilities.device_parser import determine_root_gpu_device
 from lightning_lite.utilities.imports import _HPU_AVAILABLE, _IPU_AVAILABLE, _IS_INTERACTIVE, _TPU_AVAILABLE
 
-_PLUGIN = Union[Strategy, Precision, ClusterEnvironment, CheckpointIO]
+_PLUGIN = Union[Precision, ClusterEnvironment, CheckpointIO]
 _PLUGIN_INPUT = Union[_PLUGIN, str]
 
 
@@ -99,8 +99,6 @@ def __init__(
         num_nodes: int = 1,
         precision: Union[int, str] = 32,
         plugins: Optional[Union[_PLUGIN_INPUT, List[_PLUGIN_INPUT]]] = None,
-        tpu_cores: Optional[Union[List[int], str, int]] = None,  # deprecated
-        gpus: Optional[Union[List[int], str, int]] = None,  # deprecated
     ) -> None:
         # 1. Parsing flags
         # Get registered strategies, built-in accelerators and precision plugins
@@ -125,9 +123,7 @@ def __init__(
             precision=precision,
             plugins=plugins,
         )
-        self._check_device_config_and_set_final_flags(
-            devices=devices, num_nodes=num_nodes, gpus=gpus, tpu_cores=tpu_cores
-        )
+        self._check_device_config_and_set_final_flags(devices=devices, num_nodes=num_nodes)
 
         # 2. Instantiate Accelerator
         # handle `auto`, `None` and `gpu`
@@ -278,11 +274,7 @@ def _check_config_and_set_final_flags(
                 self._parallel_devices = self._strategy_flag.parallel_devices
 
     def _check_device_config_and_set_final_flags(
-        self,
-        devices: Optional[Union[List[int], str, int]],
-        num_nodes: int,
-        gpus: Optional[Union[List[int], str, int]],
-        tpu_cores: Optional[Union[List[int], str, int]],
+        self, devices: Optional[Union[List[int], str, int]], num_nodes: int
     ) -> None:
         self._num_nodes_flag = int(num_nodes) if num_nodes is not None else 1
         self._devices_flag = devices
@@ -298,56 +290,12 @@ def _check_device_config_and_set_final_flags(
                 f" using {accelerator_name} accelerator."
             )
 
-        # TODO: Delete this method when num_processes, gpus, ipus and tpu_cores gets removed
-        self._map_deprecated_devices_specific_info_to_accelerator_and_device_flag(devices, gpus, tpu_cores)
-
         if self._devices_flag == "auto" and self._accelerator_flag is None:
             raise ValueError(
                 f"You passed `devices={devices}` but haven't specified"
                 " `accelerator=('auto'|'tpu'|'gpu'|'cpu'|'mps')` for the devices mapping."
             )
 
-    def _map_deprecated_devices_specific_info_to_accelerator_and_device_flag(
-        self,
-        devices: Optional[Union[List[int], str, int]],
-        gpus: Optional[Union[List[int], str, int]],
-        tpu_cores: Optional[Union[List[int], str, int]],
-    ) -> None:
-        """Emit deprecation warnings for num_processes, gpus, ipus, tpu_cores and set the `devices_flag` and
-        `accelerator_flag`."""
-        if gpus is not None:
-            rank_zero_deprecation(
-                f"Setting `Lite(gpus={gpus!r})` is deprecated in v1.7 and will be removed"
-                f" in v2.0. Please use `Lite(accelerator='gpu', devices={gpus!r})` instead."
-            )
-        if tpu_cores is not None:
-            rank_zero_deprecation(
-                f"Setting `Lite(tpu_cores={tpu_cores!r})` is deprecated in v1.7 and will be removed"
-                f" in v2.0. Please use `Lite(accelerator='tpu', devices={tpu_cores!r})` instead."
-            )
-        self._gpus: Optional[Union[List[int], str, int]] = gpus
-        self._tpu_cores: Optional[Union[List[int], str, int]] = tpu_cores
-        deprecated_devices_specific_flag = gpus or tpu_cores
-        if deprecated_devices_specific_flag and deprecated_devices_specific_flag not in ([], 0, "0"):
-            if devices:
-                # TODO: improve error message
-                rank_zero_warn(
-                    f"The flag `devices={devices}` will be ignored, "
-                    f"instead the device specific number {deprecated_devices_specific_flag} will be used"
-                )
-
-            if [(gpus is not None), (tpu_cores is not None)].count(True) > 1:
-                # TODO: improve error message
-                rank_zero_warn("more than one device specific flag has been set")
-            self._devices_flag = deprecated_devices_specific_flag
-
-            if self._accelerator_flag is None:
-                # set accelerator type based on num_processes, gpus, ipus, tpu_cores
-                if tpu_cores:
-                    self._accelerator_flag = "tpu"
-                if gpus:
-                    self._accelerator_flag = "cuda"
-
     def _choose_auto_accelerator(self) -> str:
         """Choose the accelerator type (str) based on availability when ``accelerator='auto'``."""
         if self._accelerator_flag == "auto":
@@ -392,9 +340,6 @@ def _set_parallel_devices_and_init_accelerator(self) -> None:
 
         self._set_devices_flag_if_auto_passed()
 
-        self._gpus = self._devices_flag if not self._gpus else self._gpus
-        self._tpu_cores = self._devices_flag if not self._tpu_cores else self._tpu_cores
-
         self._devices_flag = self.accelerator.parse_devices(self._devices_flag)
         if not self._parallel_devices:
             self._parallel_devices = self.accelerator.get_parallel_devices(self._devices_flag)

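The helper deleted above existed only to translate the deprecated `gpus`/`tpu_cores` flags into `accelerator`/`devices`. The replacement spelling, taken from the removed deprecation messages, looks like this (a sketch; the `Lite` subclass is illustrative):

from pytorch_lightning.lite import LightningLite


class Lite(LightningLite):
    def run(self) -> None:
        print(f"running on {self.device}")


# Deprecated flags, no longer accepted by the connector after this commit:
#   Lite(gpus=2)
#   Lite(tpu_cores=8)

# Replacement, per the removed deprecation messages:
Lite(accelerator="gpu", devices=2).run()
Lite(accelerator="tpu", devices=8).run()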
src/lightning_lite/lite.py

Lines changed: 0 additions & 6 deletions

@@ -64,8 +64,6 @@ class LightningLite(ABC):
         precision: Double precision (``64``), full precision (``32``), half precision (``16``),
             or bfloat16 precision (``"bf16"``).
         plugins: One or several custom plugins
-        gpus: Provides the same function as the ``devices`` argument but implies ``accelerator="gpu"``.
-        tpu_cores: Provides the same function as the ``devices`` argument but implies ``accelerator="tpu"``.
     """
 
     def __init__(
@@ -76,8 +74,6 @@ def __init__(
        num_nodes: int = 1,
        precision: Union[int, str] = 32,
        plugins: Optional[Union[_PLUGIN_INPUT, List[_PLUGIN_INPUT]]] = None,
-       gpus: Optional[Union[List[int], str, int]] = None,
-       tpu_cores: Optional[Union[List[int], str, int]] = None,
    ) -> None:
        self._connector = _Connector(
            accelerator=accelerator,
@@ -86,8 +82,6 @@ def __init__(
            num_nodes=num_nodes,
            precision=precision,
            plugins=plugins,
-           tpu_cores=tpu_cores,
-           gpus=gpus,
        )
        self._strategy: Strategy = self._connector.strategy
        self._accelerator: Accelerator = self._connector.accelerator

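Since `gpus` and `tpu_cores` are removed from the signature outright rather than deprecated, passing them now fails like any unknown keyword argument; a quick illustration with a hypothetical subclass:

from pytorch_lightning.lite import LightningLite


class Lite(LightningLite):
    def run(self) -> None:
        ...


try:
    Lite(gpus=2)  # the parameter no longer exists on LightningLite.__init__
except TypeError as err:
    print(err)  # e.g. "__init__() got an unexpected keyword argument 'gpus'"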
src/lightning_lite/plugins/precision/__init__.py

Lines changed: 3 additions & 1 deletion

@@ -11,15 +11,17 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from lightning_lite.plugins.precision.precision import Precision  # isort:skip
 from lightning_lite.plugins.precision.deepspeed import DeepSpeedPrecision
+from lightning_lite.plugins.precision.double import DoublePrecision
 from lightning_lite.plugins.precision.mixed import MixedPrecision
 from lightning_lite.plugins.precision.native_amp import NativeMixedPrecision
-from lightning_lite.plugins.precision.precision import Precision
 from lightning_lite.plugins.precision.tpu import TPUPrecision
 from lightning_lite.plugins.precision.tpu_bf16 import TPUBf16Precision
 
 __all__ = [
     "DeepSpeedPrecision",
+    "DoublePrecision",
     "MixedPrecision",
     "NativeMixedPrecision",
     "Precision",

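Two things happen above: `DoublePrecision` becomes part of the public package surface, and the base `Precision` import is hoisted (with `isort:skip`) so it is imported before the modules that subclass it. Because `_PLUGIN` in the connector includes `Precision`, an instance can be handed to Lite via `plugins` — a sketch, assuming `DoublePrecision()` needs no constructor arguments:

from lightning_lite.lite import LightningLite
from lightning_lite.plugins.precision import DoublePrecision


class Lite(LightningLite):
    def run(self) -> None:
        ...


# A Precision instance is a valid `plugins` value per the connector's _PLUGIN union.
Lite(accelerator="cpu", devices=1, plugins=DoublePrecision()).run()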
src/lightning_lite/strategies/xla.py

Lines changed: 0 additions & 1 deletion

@@ -54,7 +54,6 @@ def __init__(
         parallel_devices: Optional[List[torch.device]] = None,
         checkpoint_io: Optional[CheckpointIO] = None,
         precision_plugin: Optional[Precision] = None,
-        **_: Any,
     ) -> None:
         super().__init__(
             accelerator=accelerator,

src/pytorch_lightning/callbacks/early_stopping.py

Lines changed: 3 additions & 1 deletion

@@ -261,7 +261,9 @@ def _improvement_message(self, current: Tensor) -> str:
 
     @staticmethod
     def _log_info(trainer: Optional["pl.Trainer"], message: str, log_rank_zero_only: bool) -> None:
-        rank = _get_rank(strategy=(trainer.strategy if trainer is not None else None))  # type: ignore[arg-type]
+        rank = _get_rank(
+            strategy=(trainer.strategy if trainer is not None else None),  # type: ignore[arg-type]
+        )
         if trainer is not None and trainer.world_size <= 1:
             rank = None
         message = rank_prefixed_message(message, rank)
src/pytorch_lightning/lite/__init__.py (new file)

Lines changed: 17 additions & 0 deletions

@@ -0,0 +1,17 @@
+# Copyright The PyTorch Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from pytorch_lightning.lite.lite import LightningLite
+
+__all__ = ["LightningLite"]

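This new `__init__.py` is the backward-compatibility shim the commit title refers to: it re-exports `LightningLite` from the internal `pytorch_lightning.lite.lite` module so the old `pytorch_lightning.lite` import location keeps working. A quick sanity check of the public surface:

import pytorch_lightning.lite as pl_lite
from pytorch_lightning.lite import LightningLite  # resolves via the re-export

# The package deliberately exposes a single public name.
assert pl_lite.__all__ == ["LightningLite"]
print(LightningLite)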