diff --git a/CHANGELOG.md b/CHANGELOG.md
index adb1b070dc386..ad5029a09bcae 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -162,6 +162,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Removed the `precision_plugin` attribute from `Accelerator` in favor of its equivalent attribute `precision_plugin` in the `TrainingTypePlugin` ([#10570](https://github.com/PyTorchLightning/pytorch-lightning/pull/10570))
 
+
+- Removed `DeepSpeedPlugin.{precision,amp_type,amp_level}` properties ([#10657](https://github.com/PyTorchLightning/pytorch-lightning/pull/10657))
+
+
 ### Fixed
 
 - When a tensor is logged with `self.log`, run its computation with the same `dtype` ([#10076](https://github.com/PyTorchLightning/pytorch-lightning/pull/10076))
diff --git a/pytorch_lightning/lite/lite.py b/pytorch_lightning/lite/lite.py
index 4997d7db779e7..6c41c80a56171 100644
--- a/pytorch_lightning/lite/lite.py
+++ b/pytorch_lightning/lite/lite.py
@@ -385,7 +385,6 @@ def seed_everything(seed: Optional[int] = None, workers: Optional[bool] = None)
         return seed_everything(seed=seed, workers=workers)
 
     def _run_impl(self, run_method: Callable, *args: Any, **kwargs: Any) -> Any:
-        self._set_plugin_specific_precision_variables()
         self._accelerator.setup_environment()
 
         # apply sharded context to prevent OOM
@@ -400,11 +399,6 @@ def _run_with_sharded_context(self, run_method: Callable, *args: Any, **kwargs:
         with self._strategy.model_sharded_context(), _replace_dataloader_init_method():
             return run_method(*args, **kwargs)
 
-    def _set_plugin_specific_precision_variables(self) -> None:
-        # todo: these are hacks as plugins rely on access to the precision plugin
-        if isinstance(self._strategy, DeepSpeedPlugin):
-            self._set_deepspeed_precision_variables()
-
     def _move_model_to_device(self, model: nn.Module, optimizers: List[Optimizer]) -> nn.Module:
         if isinstance(self._strategy, TPUSpawnPlugin):
             # When the user creates the optimizer, they reference the parameters on the CPU.
@@ -423,13 +417,6 @@ def _move_model_to_device(self, model: nn.Module, optimizers: List[Optimizer]) -
             model = self.to_device(model)
         return model
 
-    def _set_deepspeed_precision_variables(self) -> None:
-        # TODO: Refactor this once precision pluging is part of the strategy.
-        amp_type = self._accelerator_connector.amp_type
-        amp_level = self._accelerator_connector.amp_level
-        precision = self._accelerator_connector.precision
-        self._strategy._amp_level, self._strategy._amp_type, self._strategy._precision = amp_level, amp_type, precision
-
     def _requires_distributed_sampler(self, dataloader: DataLoader) -> bool:
         return (
             self._accelerator_connector.is_distributed
diff --git a/pytorch_lightning/plugins/precision/deepspeed.py b/pytorch_lightning/plugins/precision/deepspeed.py
index 27ac384d25303..46cf023fc5d32 100644
--- a/pytorch_lightning/plugins/precision/deepspeed.py
+++ b/pytorch_lightning/plugins/precision/deepspeed.py
@@ -11,7 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Any, Callable, Union
+from typing import Any, Callable, Optional, Union
 
 from torch import Tensor
 from torch.nn import Module
@@ -34,9 +34,11 @@ class DeepSpeedPrecisionPlugin(PrecisionPlugin):
     """Precision plugin for DeepSpeed integration."""
 
-    def __init__(self, precision: int) -> None:
+    def __init__(self, precision: Union[str, int], amp_type: str, amp_level: Optional[str] = None) -> None:
         super().__init__()
         self.precision = precision
+        self.amp_type = amp_type
+        self.amp_level = amp_level
 
     def backward(self, model: "pl.LightningModule", closure_loss: Tensor, *args: Any, **kwargs: Any) -> None:
         if is_overridden("backward", model):
diff --git a/pytorch_lightning/plugins/training_type/deepspeed.py b/pytorch_lightning/plugins/training_type/deepspeed.py
index 01959bdcee212..86d380ac24ce8 100644
--- a/pytorch_lightning/plugins/training_type/deepspeed.py
+++ b/pytorch_lightning/plugins/training_type/deepspeed.py
@@ -34,10 +34,10 @@
 from pytorch_lightning.plugins.training_type.ddp import DDPPlugin
 from pytorch_lightning.trainer.optimizers import _get_default_scheduler_config
 from pytorch_lightning.trainer.states import TrainerFn
-from pytorch_lightning.utilities import AMPType, GradClipAlgorithmType
+from pytorch_lightning.utilities import GradClipAlgorithmType
 from pytorch_lightning.utilities.apply_func import apply_to_collection
-from pytorch_lightning.utilities.distributed import log, rank_zero_info, rank_zero_only
-from pytorch_lightning.utilities.enums import _StrategyType
+from pytorch_lightning.utilities.distributed import log, rank_zero_info
+from pytorch_lightning.utilities.enums import _StrategyType, AMPType
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.utilities.imports import _DEEPSPEED_AVAILABLE
 from pytorch_lightning.utilities.model_helpers import is_overridden
@@ -327,24 +327,6 @@ def __init__(
         self.hysteresis = hysteresis
         self.min_loss_scale = min_loss_scale
 
-        # optionally set by Lite
-        self._precision: Optional[Union[str, int]] = None
-        self._amp_level: Optional[str] = None
-        self._amp_type: Optional[str] = None
-
-    @property
-    def precision(self) -> Union[str, int]:
-        return self._precision or self.precision_plugin.precision
-
-    @property
-    def amp_level(self) -> Optional[str]:
-        if self._amp_type == AMPType.APEX:
-            return self._amp_level or self.lightning_module.trainer._accelerator_connector.amp_level
-
-    @property
-    def amp_type(self) -> Optional[str]:
-        return self._amp_type or self.lightning_module.trainer._accelerator_connector.amp_type
-
     def _load_config(self, config):
         if config is None and self.DEEPSPEED_ENV_VAR in os.environ:
             rank_zero_info(f"Loading DeepSpeed config from set {self.DEEPSPEED_ENV_VAR} environment variable")
@@ -459,11 +441,11 @@ def init_deepspeed(self):
                 "DeepSpeed currently does not support different `accumulate_grad_batches` at different epochs."
             )
 
-        model = LightningDeepSpeedModule(pl_module=self.model, precision=self.precision)
+        model = LightningDeepSpeedModule(pl_module=self.model, precision=self.precision_plugin.precision)
 
         if self.zero_stage_3 and self.partition_module:
             # Ensure the entire model has been moved to the appropriate device
-            dtype = torch.float16 if self.precision in (16, "mixed") else torch.float32
+            dtype = torch.float16 if self.precision_plugin.precision in (16, "mixed") else torch.float32
             deepspeed.zero.Init(
                 module=model, remote_device=self.remote_device, pin_memory=True, config=self.config, dtype=dtype
             )
@@ -520,7 +502,7 @@ def _initialize_deepspeed_train(self, model):
     def model_sharded_context(self) -> Generator[None, None, None]:
         if self.zero_stage_3:
             assert self._config_initialized
-            dtype = torch.float16 if self.precision in (16, "mixed") else torch.float32
+            dtype = torch.float16 if self.precision_plugin.precision in (16, "mixed") else torch.float32
             model_parallel_context = deepspeed.zero.Init(
                 remote_device=self.remote_device, pin_memory=True, config=self.config, dtype=dtype
             )
@@ -646,11 +628,9 @@ def _auto_select_batch_size(self):
             )
         return batch_size
 
-    def _format_precision_config(self):
-        if self.amp_type == AMPType.APEX:
-            amp_level = self.amp_level
-        if self.precision in (16, "mixed"):
-            if "fp16" not in self.config and self.amp_type == AMPType.NATIVE:
+    def _format_precision_config(self) -> None:
+        if self.precision_plugin.precision in (16, "mixed"):
+            if "fp16" not in self.config and self.precision_plugin.amp_type == AMPType.NATIVE:
                 # FP16 is a DeepSpeed standalone AMP implementation
                 rank_zero_info("Enabling DeepSpeed FP16.")
                 self.config["fp16"] = {
@@ -661,9 +641,9 @@ def _format_precision_config(self):
                     "hysteresis": self.hysteresis,
                     "min_loss_scale": self.min_loss_scale,
                 }
-            elif "amp" not in self.config and self.amp_type == AMPType.APEX:
-                rank_zero_only("Enabling DeepSpeed APEX Implementation.")
-                self.config["amp"] = {"enabled": True, "opt_level": amp_level}
+            elif "amp" not in self.config and self.precision_plugin.amp_type == AMPType.APEX:
+                rank_zero_info("Enabling DeepSpeed APEX Implementation.")
+                self.config["amp"] = {"enabled": True, "opt_level": self.precision_plugin.amp_level}
 
     def _create_default_config(
         self,
diff --git a/pytorch_lightning/trainer/connectors/accelerator_connector.py b/pytorch_lightning/trainer/connectors/accelerator_connector.py
index 7136437bbc69d..c95d46e77b977 100644
--- a/pytorch_lightning/trainer/connectors/accelerator_connector.py
+++ b/pytorch_lightning/trainer/connectors/accelerator_connector.py
@@ -637,7 +637,7 @@ def select_precision_plugin(self) -> PrecisionPlugin:
             return TPUBf16PrecisionPlugin()
 
         if self._distrib_type == _StrategyType.DEEPSPEED or isinstance(self._training_type_plugin, DeepSpeedPlugin):
-            return DeepSpeedPrecisionPlugin(self.precision)
+            return DeepSpeedPrecisionPlugin(self.precision, self.amp_type, self.amp_level)
 
         if self.precision == 32:
             return PrecisionPlugin()
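Note (illustrative, not part of the patch): with this change the AMP backend and opt level travel with `DeepSpeedPrecisionPlugin` instead of being stashed on the `DeepSpeedPlugin` strategy by Lite. A minimal usage sketch, assuming the string values used by `AMPType` ("native"/"apex") and "O2" as an example Apex opt level:

    from pytorch_lightning.plugins.precision.deepspeed import DeepSpeedPrecisionPlugin

    # DeepSpeed's standalone FP16 implementation; no Apex opt level needed.
    native_fp16 = DeepSpeedPrecisionPlugin(precision=16, amp_type="native")

    # Apex-backed mixed precision; "O2" is an illustrative opt level, not a value set by this PR.
    apex_fp16 = DeepSpeedPrecisionPlugin(precision=16, amp_type="apex", amp_level="O2")

    # The strategy now reads these settings through its precision plugin,
    # e.g. `self.precision_plugin.amp_type` / `.amp_level` in `_format_precision_config`.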