Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 3 additions & 7 deletions src/lightning_lite/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,15 +12,11 @@
_logger.addHandler(logging.StreamHandler())
_logger.propagate = False

# TODO(lite): Re-enable this import
# from lightning_lite.lite import LightningLite

from lightning_lite.lite import LightningLite # noqa: E402
from lightning_lite.utilities.seed import seed_everything # noqa: E402

__all__ = [
# TODO(lite): Re-enable this import
# "LightningLite",
"seed_everything",
]
__all__ = ["LightningLite", "seed_everything"]

# for compatibility with namespace packages
__import__("pkg_resources").declare_namespace(__name__)
63 changes: 4 additions & 59 deletions src/lightning_lite/connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,11 +52,11 @@
XLAStrategy,
)
from lightning_lite.strategies.ddp_spawn import _DDP_FORK_ALIASES
from lightning_lite.utilities import _StrategyType, rank_zero_deprecation, rank_zero_info, rank_zero_warn
from lightning_lite.utilities import _StrategyType, rank_zero_info, rank_zero_warn
from lightning_lite.utilities.device_parser import determine_root_gpu_device
from lightning_lite.utilities.imports import _HPU_AVAILABLE, _IPU_AVAILABLE, _IS_INTERACTIVE, _TPU_AVAILABLE

_PLUGIN = Union[Strategy, Precision, ClusterEnvironment, CheckpointIO]
_PLUGIN = Union[Precision, ClusterEnvironment, CheckpointIO]
_PLUGIN_INPUT = Union[_PLUGIN, str]


Expand Down Expand Up @@ -99,8 +99,6 @@ def __init__(
num_nodes: int = 1,
precision: Union[int, str] = 32,
plugins: Optional[Union[_PLUGIN_INPUT, List[_PLUGIN_INPUT]]] = None,
tpu_cores: Optional[Union[List[int], str, int]] = None, # deprecated
gpus: Optional[Union[List[int], str, int]] = None, # deprecated
) -> None:
# 1. Parsing flags
# Get registered strategies, built-in accelerators and precision plugins
Expand All @@ -125,9 +123,7 @@ def __init__(
precision=precision,
plugins=plugins,
)
self._check_device_config_and_set_final_flags(
devices=devices, num_nodes=num_nodes, gpus=gpus, tpu_cores=tpu_cores
)
self._check_device_config_and_set_final_flags(devices=devices, num_nodes=num_nodes)

# 2. Instantiate Accelerator
# handle `auto`, `None` and `gpu`
Expand Down Expand Up @@ -278,11 +274,7 @@ def _check_config_and_set_final_flags(
self._parallel_devices = self._strategy_flag.parallel_devices

def _check_device_config_and_set_final_flags(
self,
devices: Optional[Union[List[int], str, int]],
num_nodes: int,
gpus: Optional[Union[List[int], str, int]],
tpu_cores: Optional[Union[List[int], str, int]],
self, devices: Optional[Union[List[int], str, int]], num_nodes: int
) -> None:
self._num_nodes_flag = int(num_nodes) if num_nodes is not None else 1
self._devices_flag = devices
Expand All @@ -298,56 +290,12 @@ def _check_device_config_and_set_final_flags(
f" using {accelerator_name} accelerator."
)

# TODO: Delete this method when num_processes, gpus, ipus and tpu_cores gets removed
self._map_deprecated_devices_specific_info_to_accelerator_and_device_flag(devices, gpus, tpu_cores)

if self._devices_flag == "auto" and self._accelerator_flag is None:
raise ValueError(
f"You passed `devices={devices}` but haven't specified"
" `accelerator=('auto'|'tpu'|'gpu'|'cpu'|'mps')` for the devices mapping."
)

def _map_deprecated_devices_specific_info_to_accelerator_and_device_flag(
self,
devices: Optional[Union[List[int], str, int]],
gpus: Optional[Union[List[int], str, int]],
tpu_cores: Optional[Union[List[int], str, int]],
) -> None:
"""Emit deprecation warnings for num_processes, gpus, ipus, tpu_cores and set the `devices_flag` and
`accelerator_flag`."""
if gpus is not None:
rank_zero_deprecation(
f"Setting `Lite(gpus={gpus!r})` is deprecated in v1.7 and will be removed"
f" in v2.0. Please use `Lite(accelerator='gpu', devices={gpus!r})` instead."
)
if tpu_cores is not None:
rank_zero_deprecation(
f"Setting `Lite(tpu_cores={tpu_cores!r})` is deprecated in v1.7 and will be removed"
f" in v2.0. Please use `Lite(accelerator='tpu', devices={tpu_cores!r})` instead."
)
self._gpus: Optional[Union[List[int], str, int]] = gpus
self._tpu_cores: Optional[Union[List[int], str, int]] = tpu_cores
deprecated_devices_specific_flag = gpus or tpu_cores
if deprecated_devices_specific_flag and deprecated_devices_specific_flag not in ([], 0, "0"):
if devices:
# TODO: improve error message
rank_zero_warn(
f"The flag `devices={devices}` will be ignored, "
f"instead the device specific number {deprecated_devices_specific_flag} will be used"
)

if [(gpus is not None), (tpu_cores is not None)].count(True) > 1:
# TODO: improve error message
rank_zero_warn("more than one device specific flag has been set")
self._devices_flag = deprecated_devices_specific_flag

if self._accelerator_flag is None:
# set accelerator type based on num_processes, gpus, ipus, tpu_cores
if tpu_cores:
self._accelerator_flag = "tpu"
if gpus:
self._accelerator_flag = "cuda"

def _choose_auto_accelerator(self) -> str:
"""Choose the accelerator type (str) based on availability when ``accelerator='auto'``."""
if self._accelerator_flag == "auto":
Expand Down Expand Up @@ -392,9 +340,6 @@ def _set_parallel_devices_and_init_accelerator(self) -> None:

self._set_devices_flag_if_auto_passed()

self._gpus = self._devices_flag if not self._gpus else self._gpus
self._tpu_cores = self._devices_flag if not self._tpu_cores else self._tpu_cores

self._devices_flag = self.accelerator.parse_devices(self._devices_flag)
if not self._parallel_devices:
self._parallel_devices = self.accelerator.get_parallel_devices(self._devices_flag)
Expand Down
Loading