Merged
2 changes: 1 addition & 1 deletion pytorch_lightning/accelerators/accelerator.py
@@ -366,7 +366,7 @@ def amp_backend(self) -> Optional[LightningEnum]:
return None

@property
def precision(self) -> int:
def precision(self) -> Union[str, int]:
return self.precision_plugin.precision

@property
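A small illustration of the widened return type: the precision reported by a plugin can be an integer bit width or the string "mixed" (as the `MixedPrecisionPlugin` diff below sets), so `int` alone no longer describes it. The sketch below uses a stand-in class, not the real accelerator or plugin:

```python
from typing import Union


class _PluginStub:
    # Stand-in for a precision plugin: the value may be 32, 16, or "mixed".
    precision: Union[str, int] = "mixed"


def precision_of(plugin: _PluginStub) -> Union[str, int]:
    # Mirrors the property above: it simply forwards the plugin's value,
    # which is why the annotation must admit both str and int.
    return plugin.precision


print(precision_of(_PluginStub()))  # -> mixed
```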
17 changes: 2 additions & 15 deletions pytorch_lightning/plugins/base_plugin.py
@@ -12,26 +12,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import contextlib
from abc import ABC, abstractmethod
from typing import Generator, Optional, Sequence, Tuple

from torch.nn import Module
from abc import ABC
from typing import Generator


class Plugin(ABC):
"""Basic Plugin class to derive precision and training type plugins from."""

@abstractmethod
def connect(
self,
model: Module,
*args: Sequence,
**kwargs: Sequence,
) -> Optional[Tuple[Module, Sequence, Sequence]]:
"""Connects the plugin with the accelerator (and thereby with trainer and model).
Will be called by the accelerator.
"""

def pre_dispatch(self) -> None:
"""Hook to do something before the training/evaluation/prediction starts."""

40 changes: 24 additions & 16 deletions pytorch_lightning/plugins/precision/apex_amp.py
@@ -11,10 +11,9 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Callable, List, Tuple
from typing import Any, Callable, Generator, List, Sequence, Tuple, Type, TYPE_CHECKING

import torch
from torch.optim import Optimizer

from pytorch_lightning.core import LightningModule
from pytorch_lightning.plugins.precision.mixed import MixedPrecisionPlugin
@@ -23,37 +22,41 @@
if _APEX_AVAILABLE:
from apex import amp

if TYPE_CHECKING:
from torch.optim import Optimizer


class ApexMixedPrecisionPlugin(MixedPrecisionPlugin):
"""Mixed Precision Plugin based on Nvidia/Apex (https://github.com/NVIDIA/apex)"""

def __init__(self, amp_level: str):
def __init__(self, amp_level: str) -> None:
self.backend = AMPType.APEX
self.amp_level = amp_level

def master_params(self, optimizer: torch.optim.Optimizer):
def master_params(self, optimizer: 'Optimizer') -> Generator[torch.Tensor, None, None]:
return amp.master_params(optimizer)

def connect(self, model: torch.nn.Module, optimizers, lr_schedulers):
def connect(self, model: torch.nn.Module, optimizers: Sequence['Optimizer'],
lr_schedulers: Sequence[Any]) -> Tuple[torch.nn.Module, Sequence['Optimizer'], Sequence[Any]]:
"""Connects the precision plugin to the training process,
configures apex and reinits the schedulers
"""
if model.device.type != "cuda":
return model, optimizers, lr_schedulers
model, optimizers = self.configure_apex(amp, model, optimizers, self.amp_level)
model, optimizers = self.configure_apex(amp, model, list(optimizers), self.amp_level)
self.reinit_scheduler_properties(optimizers, lr_schedulers)
return model, optimizers, lr_schedulers

def backward(
self,
model: LightningModule,
closure_loss: torch.Tensor,
optimizer: torch.optim.Optimizer,
optimizer: 'Optimizer',
opt_idx: int,
should_accumulate: bool,
*args,
**kwargs,
):
*args: Any,
**kwargs: Any,
) -> torch.Tensor:
"""performs the actual backpropagation

Args:
@@ -94,11 +97,11 @@ def backward(

def configure_apex(
self,
amp: object,
amp: Type,
model: LightningModule,
optimizers: List[Optimizer],
optimizers: List['Optimizer'],
amp_level: str,
) -> Tuple[LightningModule, List[Optimizer]]:
) -> Tuple[LightningModule, List['Optimizer']]:
r"""
Override to init AMP your own way.
Must return a model and list of optimizers.
@@ -127,7 +130,7 @@ def configure_apex(self, amp, model, optimizers, amp_level):
return model, optimizers

@staticmethod
def reinit_scheduler_properties(optimizers: list, schedulers: list):
def reinit_scheduler_properties(optimizers: Sequence['Optimizer'], schedulers: Sequence[Any]) -> None:
"""Reinitializes schedulers with correct properties"""
# Reinitialize optimizer.step properties added by schedulers
for scheduler in schedulers:
@@ -149,7 +152,12 @@ def reinit_scheduler_properties(optimizers: list, schedulers: list):
break

def pre_optimizer_step(
self, pl_module: LightningModule, optimizer: Optimizer, optimizer_idx: int, lambda_closure: Callable, **kwargs
self,
pl_module: LightningModule,
optimizer: 'Optimizer',
optimizer_idx: int,
lambda_closure: Callable,
**kwargs: Any,
) -> bool:
"""
always called before the optimizer step.
@@ -160,6 +168,6 @@ def pre_optimizer_step(
if not pl_module.automatic_optimization:
pl_module.trainer.call_hook("on_after_backward")

optimizer.step()
optimizer.step(**kwargs)

return False
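The `if TYPE_CHECKING:` block introduced above is the standard pattern for imports needed only by the type checker: `TYPE_CHECKING` is `False` at runtime, so the import is skipped, and annotations that refer to the name are written as strings so they are never evaluated. A minimal, self-contained sketch of the pattern (the helper function is illustrative, not part of the plugin):

```python
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Only the type checker sees this import; at runtime it is skipped,
    # which keeps import time down and avoids circular-import issues.
    from torch.optim import Optimizer


def count_params(optimizer: 'Optimizer') -> int:
    # The quoted 'Optimizer' is a forward reference: it is never evaluated
    # at runtime, so this module works even though torch.optim was not imported.
    return sum(len(group["params"]) for group in optimizer.param_groups)
```

The same pattern is applied to `LightningModule` and `Optimizer` in the DeepSpeed and native AMP plugin diffs below.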
38 changes: 24 additions & 14 deletions pytorch_lightning/plugins/precision/deepspeed_precision.py
@@ -11,27 +11,35 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Callable, Union
from typing import Any, Callable, TYPE_CHECKING, Union

import torch
from torch.optim import Optimizer

from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.plugins.precision.precision_plugin import PrecisionPlugin
from pytorch_lightning.utilities.model_helpers import is_overridden
from pytorch_lightning.utilities.warnings import WarningCache

if TYPE_CHECKING:
from torch.optim import Optimizer

from pytorch_lightning.core.lightning import LightningModule

warning_cache = WarningCache()


class DeepSpeedPrecisionPlugin(PrecisionPlugin):

def __init__(self, precision):
def __init__(self, precision: int) -> None:
super().__init__()
self.precision = precision

def pre_optimizer_step(
self, pl_module: LightningModule, optimizer: Optimizer, optimizer_idx: int, lambda_closure: Callable, **kwargs
self,
pl_module: 'LightningModule',
optimizer: 'Optimizer',
optimizer_idx: int,
lambda_closure: Callable,
**kwargs: Any,
) -> bool:
deepspeed_engine = pl_module.trainer.model
# DeepSpeed does not support closures.
@@ -46,28 +54,30 @@ def pre_optimizer_step(

def backward(
self,
lightning_module: LightningModule,
model: 'LightningModule',
Collaborator: isn't this an API change?

Member (author): It was named `model` in all the other plugins, and I checked: it is never called by name. But if we don't do this, mypy complains that the signature does not match the base class, since the parameter names differ.

(A minimal sketch of this mypy check follows this file's diff.)
closure_loss: torch.Tensor,
optimizer: torch.optim.Optimizer,
optimizer: 'Optimizer',
opt_idx: int,
should_accumulate: bool,
*args,
**kwargs,
):
if is_overridden('backward', lightning_module):
*args: Any,
**kwargs: Any,
) -> torch.Tensor:
if is_overridden('backward', model):
warning_cache.warn(
"Overridden backward hook in the LightningModule will be ignored since DeepSpeed handles"
"backward logic outside of the LightningModule"
)
# todo: hack around for deepspeed engine to call backward
deepspeed_engine = lightning_module.trainer.model
deepspeed_engine.backward(closure_loss, **kwargs)
deepspeed_engine = model.trainer.model
deepspeed_engine.backward(closure_loss, *args, **kwargs)
# once backward has been applied, release graph
closure_loss = closure_loss.detach()

return closure_loss

def clip_gradients(self, optimizer: Optimizer, clip_val: Union[int, float], norm_type: float = float(2.0)):
def clip_gradients(
self, optimizer: 'Optimizer', clip_val: Union[int, float], norm_type: float = 2.0
) -> None:
"""
DeepSpeed handles clipping gradients via the training type plugin.
"""
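A minimal sketch of the mypy behaviour described in the review thread above, using stand-in classes rather than the real plugin hierarchy: mypy treats an override whose parameter is renamed as incompatible with the base method, because a caller passing that argument by keyword against the base class would break on the subclass.

```python
class BasePrecision:
    def backward(self, model: object, closure_loss: float) -> float:
        return closure_loss


class RenamedOverride(BasePrecision):
    # mypy reports roughly:
    #   error: Signature of "backward" incompatible with supertype "BasePrecision"
    # because plugin.backward(model=..., closure_loss=...) written against the
    # base class would fail here, where the parameter is named lightning_module.
    def backward(self, lightning_module: object, closure_loss: float) -> float:
        return closure_loss
```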
12 changes: 8 additions & 4 deletions pytorch_lightning/plugins/precision/mixed.py
@@ -11,13 +11,17 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING, Union

from pytorch_lightning.plugins.precision.precision_plugin import PrecisionPlugin
from pytorch_lightning.utilities import AMPType

if TYPE_CHECKING:
from pytorch_lightning.utilities import AMPType


class MixedPrecisionPlugin(PrecisionPlugin):
"""Base Class for mixed precision"""

EPSILON = 1e-5
backend: AMPType
precision = "mixed"
EPSILON: float = 1e-5
backend: 'AMPType'
precision: Union[str, int] = "mixed"
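One detail of the annotated class attributes above worth noting: `backend: 'AMPType'` is a bare annotation with no assignment, so it only declares the attribute for the type checker; no class attribute is created at runtime, and subclasses are expected to assign it, as the Apex and native AMP plugins do in `__init__`. A small sketch of the difference, with illustrative names:

```python
from enum import Enum


class Backend(Enum):  # stand-in for AMPType
    APEX = "apex"
    NATIVE = "native"


class MixedBase:
    EPSILON: float = 1e-5  # annotation plus value: a real class attribute
    backend: Backend       # annotation only: declared for the checker, not created
    precision: str = "mixed"


print(MixedBase.EPSILON)                     # 1e-05
print('backend' in vars(MixedBase))          # False: no attribute exists at runtime
print(MixedBase.__annotations__['backend'])  # the Backend enum class
```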
36 changes: 20 additions & 16 deletions pytorch_lightning/plugins/precision/native_amp.py
@@ -12,37 +12,36 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from contextlib import contextmanager
from typing import Callable, Generator
from typing import Any, Callable, Generator, TYPE_CHECKING

import torch
from torch.optim import LBFGS, Optimizer
from torch.optim import LBFGS

from pytorch_lightning.core import LightningModule
from pytorch_lightning.plugins.precision.mixed import MixedPrecisionPlugin
from pytorch_lightning.utilities import _NATIVE_AMP_AVAILABLE, AMPType
from pytorch_lightning.utilities import AMPType
from pytorch_lightning.utilities.exceptions import MisconfigurationException

if _NATIVE_AMP_AVAILABLE:
from torch.cuda.amp import autocast
else:
autocast = None
if TYPE_CHECKING:
from torch.optim import Optimizer

from pytorch_lightning.core import LightningModule


class NativeMixedPrecisionPlugin(MixedPrecisionPlugin):

def __init__(self):
def __init__(self) -> None:
Contributor: why is it necessary to annotate `__init__` with None?

(An illustration of this follows this file's diff.)
self.backend = AMPType.NATIVE
self.scaler = torch.cuda.amp.GradScaler()

def backward(
self,
model: LightningModule,
model: 'LightningModule',
closure_loss: torch.Tensor,
optimizer: Optimizer,
optimizer: 'Optimizer',
opt_idx: int,
should_accumulate: bool,
*args,
**kwargs,
*args: Any,
**kwargs: Any,
) -> torch.Tensor:
"""performs the actual backpropagation

@@ -65,7 +64,12 @@ def backward(
return closure_loss

def pre_optimizer_step(
self, pl_module: LightningModule, optimizer: Optimizer, optimizer_idx: int, lambda_closure: Callable, **kwargs
self,
pl_module: 'LightningModule',
optimizer: 'Optimizer',
optimizer_idx: int,
lambda_closure: Callable,
**kwargs: Any,
) -> bool:
"""always called before the optimizer step.
Checks that the optimizer is not LBFGS, as this one is not supported by native amp
@@ -83,13 +87,13 @@ def pre_optimizer_step(

return False

def post_optimizer_step(self, optimizer: Optimizer, optimizer_idx: int) -> None:
def post_optimizer_step(self, optimizer: 'Optimizer', optimizer_idx: int) -> None:
"""Updates the GradScaler"""
self.scaler.step(optimizer)
self.scaler.update()

@contextmanager
def train_step_context(self) -> Generator[autocast, None, None]:
def train_step_context(self) -> Generator[None, None, None]:
"""Enable autocast context"""
with torch.cuda.amp.autocast():
yield
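Regarding the reviewer's question above about annotating `__init__` with `-> None` (the illustration promised in that thread): mypy only type-checks the body of a function whose signature is annotated, and with strictness options such as `--disallow-untyped-defs` an unannotated `__init__` is reported as missing a return type; since `__init__` never returns a value, `-> None` is the annotation that marks it as fully typed. A small sketch, assuming mypy is run with `--disallow-untyped-defs`:

```python
class Unannotated:
    def __init__(self):
        # mypy with --disallow-untyped-defs reports:
        #   error: Function is missing a return type annotation
        #   note: Use "-> None" if function does not return a value
        self.scaler = None


class Annotated:
    def __init__(self) -> None:
        # Fully annotated signature: mypy now type-checks the body as well.
        self.scaler = None
```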