4 changes: 4 additions & 0 deletions CHANGELOG.md
@@ -162,6 +162,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Removed the `precision_plugin` attribute from `Accelerator` in favor of its equivalent attribute `precision_plugin` in the `TrainingTypePlugin` ([#10570](https://github.com/PyTorchLightning/pytorch-lightning/pull/10570))


- Removed `DeepSpeedPlugin.{precision,amp_type,amp_level}` properties ([#10657](https://github.com/PyTorchLightning/pytorch-lightning/pull/10657))


### Fixed

- When a tensor is logged with `self.log`, run its computation with the same `dtype` ([#10076](https://github.com/PyTorchLightning/pytorch-lightning/pull/10076))
13 changes: 0 additions & 13 deletions pytorch_lightning/lite/lite.py
@@ -385,7 +385,6 @@ def seed_everything(seed: Optional[int] = None, workers: Optional[bool] = None)
return seed_everything(seed=seed, workers=workers)

def _run_impl(self, run_method: Callable, *args: Any, **kwargs: Any) -> Any:
self._set_plugin_specific_precision_variables()
self._accelerator.setup_environment()

# apply sharded context to prevent OOM
@@ -400,11 +399,6 @@ def _run_with_sharded_context(self, run_method: Callable, *args: Any, **kwargs:
with self._strategy.model_sharded_context(), _replace_dataloader_init_method():
return run_method(*args, **kwargs)

def _set_plugin_specific_precision_variables(self) -> None:
# todo: these are hacks as plugins rely on access to the precision plugin
if isinstance(self._strategy, DeepSpeedPlugin):
self._set_deepspeed_precision_variables()

def _move_model_to_device(self, model: nn.Module, optimizers: List[Optimizer]) -> nn.Module:
if isinstance(self._strategy, TPUSpawnPlugin):
# When the user creates the optimizer, they reference the parameters on the CPU.
@@ -423,13 +417,6 @@ def _move_model_to_device(self, model: nn.Module, optimizers: List[Optimizer]) -
model = self.to_device(model)
return model

def _set_deepspeed_precision_variables(self) -> None:
# TODO: Refactor this once precision pluging is part of the strategy.
amp_type = self._accelerator_connector.amp_type
amp_level = self._accelerator_connector.amp_level
precision = self._accelerator_connector.precision
self._strategy._amp_level, self._strategy._amp_type, self._strategy._precision = amp_level, amp_type, precision

def _requires_distributed_sampler(self, dataloader: DataLoader) -> bool:
return (
self._accelerator_connector.is_distributed
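Note on the removal above: Lite previously copied `amp_type`, `amp_level` and `precision` from the accelerator connector onto private fields of the strategy. With the precision plugin owning these values, that bridge code is no longer needed. A toy sketch of the delegation pattern this moves to (not PyTorch Lightning code; class names are illustrative stand-ins):

```python
# Toy illustration of the refactor: the strategy no longer caches precision/amp
# fields injected from the outside; it delegates to the precision plugin it holds.
from dataclasses import dataclass
from typing import Optional, Union


@dataclass
class ToyPrecisionPlugin:  # stands in for DeepSpeedPrecisionPlugin
    precision: Union[str, int] = 32
    amp_type: str = "native"
    amp_level: Optional[str] = None


class ToyStrategy:  # stands in for DeepSpeedPlugin
    def __init__(self, precision_plugin: ToyPrecisionPlugin) -> None:
        self.precision_plugin = precision_plugin

    def uses_fp16(self) -> bool:
        # single source of truth instead of `self._precision or ...` fallbacks
        return self.precision_plugin.precision in (16, "mixed")


print(ToyStrategy(ToyPrecisionPlugin(precision=16)).uses_fp16())  # True
```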
6 changes: 4 additions & 2 deletions pytorch_lightning/plugins/precision/deepspeed.py
@@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Callable, Union
from typing import Any, Callable, Optional, Union

from torch import Tensor
from torch.nn import Module
@@ -34,9 +34,11 @@
class DeepSpeedPrecisionPlugin(PrecisionPlugin):
"""Precision plugin for DeepSpeed integration."""

def __init__(self, precision: int) -> None:
def __init__(self, precision: Union[str, int], amp_type: str, amp_level: Optional[str] = None) -> None:
super().__init__()
self.precision = precision
self.amp_type = amp_type
self.amp_level = amp_level

def backward(self, model: "pl.LightningModule", closure_loss: Tensor, *args: Any, **kwargs: Any) -> None:
if is_overridden("backward", model):
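For reference, a minimal construction sketch under the new signature (assumes a checkout that includes this PR; the values shown are illustrative, not taken from the diff):

```python
# Minimal sketch: the DeepSpeed precision plugin now carries the AMP backend
# and opt level itself instead of the training-type plugin.
from pytorch_lightning.plugins import DeepSpeedPrecisionPlugin

# DeepSpeed's native FP16 implementation: no opt level required.
native_fp16 = DeepSpeedPrecisionPlugin(precision=16, amp_type="native")

# APEX-backed mixed precision: the opt level travels with the precision plugin.
apex_fp16 = DeepSpeedPrecisionPlugin(precision=16, amp_type="apex", amp_level="O2")

print(native_fp16.precision, native_fp16.amp_type, native_fp16.amp_level)  # 16 native None
```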
44 changes: 12 additions & 32 deletions pytorch_lightning/plugins/training_type/deepspeed.py
@@ -34,10 +34,10 @@
from pytorch_lightning.plugins.training_type.ddp import DDPPlugin
from pytorch_lightning.trainer.optimizers import _get_default_scheduler_config
from pytorch_lightning.trainer.states import TrainerFn
from pytorch_lightning.utilities import AMPType, GradClipAlgorithmType
from pytorch_lightning.utilities import GradClipAlgorithmType
from pytorch_lightning.utilities.apply_func import apply_to_collection
from pytorch_lightning.utilities.distributed import log, rank_zero_info, rank_zero_only
from pytorch_lightning.utilities.enums import _StrategyType
from pytorch_lightning.utilities.distributed import log, rank_zero_info
from pytorch_lightning.utilities.enums import _StrategyType, AMPType
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.imports import _DEEPSPEED_AVAILABLE
from pytorch_lightning.utilities.model_helpers import is_overridden
@@ -327,24 +327,6 @@ def __init__(
self.hysteresis = hysteresis
self.min_loss_scale = min_loss_scale

# optionally set by Lite
self._precision: Optional[Union[str, int]] = None
self._amp_level: Optional[str] = None
self._amp_type: Optional[str] = None

@property
def precision(self) -> Union[str, int]:
return self._precision or self.precision_plugin.precision

@property
def amp_level(self) -> Optional[str]:
if self._amp_type == AMPType.APEX:
return self._amp_level or self.lightning_module.trainer._accelerator_connector.amp_level

@property
def amp_type(self) -> Optional[str]:
return self._amp_type or self.lightning_module.trainer._accelerator_connector.amp_type

def _load_config(self, config):
if config is None and self.DEEPSPEED_ENV_VAR in os.environ:
rank_zero_info(f"Loading DeepSpeed config from set {self.DEEPSPEED_ENV_VAR} environment variable")
@@ -459,11 +441,11 @@ def init_deepspeed(self):
"DeepSpeed currently does not support different `accumulate_grad_batches` at different epochs."
)

model = LightningDeepSpeedModule(pl_module=self.model, precision=self.precision)
model = LightningDeepSpeedModule(pl_module=self.model, precision=self.precision_plugin.precision)

if self.zero_stage_3 and self.partition_module:
# Ensure the entire model has been moved to the appropriate device
dtype = torch.float16 if self.precision in (16, "mixed") else torch.float32
dtype = torch.float16 if self.precision_plugin.precision in (16, "mixed") else torch.float32
deepspeed.zero.Init(
module=model, remote_device=self.remote_device, pin_memory=True, config=self.config, dtype=dtype
)
@@ -520,7 +502,7 @@ def _initialize_deepspeed_train(self, model):
def model_sharded_context(self) -> Generator[None, None, None]:
if self.zero_stage_3:
assert self._config_initialized
dtype = torch.float16 if self.precision in (16, "mixed") else torch.float32
dtype = torch.float16 if self.precision_plugin.precision in (16, "mixed") else torch.float32
model_parallel_context = deepspeed.zero.Init(
remote_device=self.remote_device, pin_memory=True, config=self.config, dtype=dtype
)
@@ -646,11 +628,9 @@ def _auto_select_batch_size(self):
)
return batch_size

def _format_precision_config(self):
if self.amp_type == AMPType.APEX:
amp_level = self.amp_level
if self.precision in (16, "mixed"):
if "fp16" not in self.config and self.amp_type == AMPType.NATIVE:
def _format_precision_config(self) -> None:
if self.precision_plugin.precision in (16, "mixed"):
if "fp16" not in self.config and self.precision_plugin.amp_type == AMPType.NATIVE:
# FP16 is a DeepSpeed standalone AMP implementation
rank_zero_info("Enabling DeepSpeed FP16.")
self.config["fp16"] = {
Expand All @@ -661,9 +641,9 @@ def _format_precision_config(self):
"hysteresis": self.hysteresis,
"min_loss_scale": self.min_loss_scale,
}
elif "amp" not in self.config and self.amp_type == AMPType.APEX:
rank_zero_only("Enabling DeepSpeed APEX Implementation.")
self.config["amp"] = {"enabled": True, "opt_level": amp_level}
elif "amp" not in self.config and self.precision_plugin.amp_type == AMPType.APEX:
rank_zero_info("Enabling DeepSpeed APEX Implementation.")
self.config["amp"] = {"enabled": True, "opt_level": self.precision_plugin.amp_level}

def _create_default_config(
self,
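The rewritten `_format_precision_config` now reads everything from `self.precision_plugin`. A standalone sketch of the config sections it fills in (plain Python, not the PR's code), useful for seeing the fp16-vs-apex branching in isolation; parameter defaults mirror the plugin's loss-scale arguments:

```python
# Standalone sketch of the DeepSpeed config sections produced by
# _format_precision_config, driven by the precision plugin's attributes.
from typing import Optional, Union


def format_precision_config(
    config: dict,
    precision: Union[str, int],
    amp_type: str,                      # "native" or "apex"
    amp_level: Optional[str] = None,    # e.g. "O2" when amp_type == "apex"
    loss_scale: float = 0,
    initial_scale_power: int = 16,
    loss_scale_window: int = 1000,
    hysteresis: int = 2,
    min_loss_scale: int = 1,
) -> dict:
    if precision in (16, "mixed"):
        if "fp16" not in config and amp_type == "native":
            # FP16 is DeepSpeed's standalone AMP implementation
            config["fp16"] = {
                "enabled": True,
                "loss_scale": loss_scale,
                "initial_scale_power": initial_scale_power,
                "loss_scale_window": loss_scale_window,
                "hysteresis": hysteresis,
                "min_loss_scale": min_loss_scale,
            }
        elif "amp" not in config and amp_type == "apex":
            config["amp"] = {"enabled": True, "opt_level": amp_level}
    return config


print(format_precision_config({}, precision=16, amp_type="native"))
```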
pytorch_lightning/trainer/connectors/accelerator_connector.py
@@ -637,7 +637,7 @@ def select_precision_plugin(self) -> PrecisionPlugin:
return TPUBf16PrecisionPlugin()

if self._distrib_type == _StrategyType.DEEPSPEED or isinstance(self._training_type_plugin, DeepSpeedPlugin):
return DeepSpeedPrecisionPlugin(self.precision)
return DeepSpeedPrecisionPlugin(self.precision, self.amp_type, self.amp_level)

if self.precision == 32:
return PrecisionPlugin()
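End to end, nothing changes for users: the connector simply forwards its `amp_type`/`amp_level` into the precision plugin it builds, and `DeepSpeedPlugin` reads them back from `self.precision_plugin`. A hypothetical usage sketch (requires a GPU machine with the `deepspeed` package installed):

```python
# Hypothetical end-to-end usage; the user-facing Trainer API is unchanged by
# this PR -- the amp settings simply end up on DeepSpeedPrecisionPlugin rather
# than on DeepSpeedPlugin.
import pytorch_lightning as pl

trainer = pl.Trainer(
    gpus=1,
    strategy="deepspeed_stage_2",
    precision=16,
    amp_backend="native",   # or "apex" together with amp_level="O2"
)
# Internally, select_precision_plugin() now builds roughly:
# DeepSpeedPrecisionPlugin(16, amp_type=AMPType.NATIVE, amp_level=None)
```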