15 changes: 9 additions & 6 deletions src/lightning_lite/plugins/precision/utils.py
@@ -11,17 +11,20 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Union

import torch

from lightning_lite.utilities.enums import PrecisionType


def _fp_to_half(tensor: torch.Tensor, precision: PrecisionType) -> torch.Tensor:
if torch.is_floating_point(tensor):
if precision == PrecisionType.HALF:
return tensor.half()
if precision == PrecisionType.BFLOAT:
return tensor.bfloat16()

if precision == PrecisionType.HALF:
return _convert_fp_tensor(tensor, torch.half)
if precision == PrecisionType.BFLOAT:
return _convert_fp_tensor(tensor, torch.bfloat16)
return tensor


def _convert_fp_tensor(tensor: torch.Tensor, dst_type: Union[str, torch.dtype]) -> torch.Tensor:
return tensor.to(dst_type) if torch.is_floating_point(tensor) else tensor
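
A minimal sketch of how the refactored helpers behave, assuming only the two functions and the `PrecisionType` enum visible in this diff (tensor values are illustrative): `_fp_to_half` now delegates to `_convert_fp_tensor`, which casts floating-point tensors and leaves integral tensors untouched.

import torch

from lightning_lite.plugins.precision.utils import _convert_fp_tensor, _fp_to_half
from lightning_lite.utilities.enums import PrecisionType

float_batch = torch.rand(2, 2)   # float32 input
int_batch = torch.arange(4)      # int64 input

# floating-point tensors are cast to the requested low-precision dtype
assert _fp_to_half(float_batch, PrecisionType.HALF).dtype is torch.float16
assert _fp_to_half(float_batch, PrecisionType.BFLOAT).dtype is torch.bfloat16

# integral tensors pass through unchanged
assert _fp_to_half(int_batch, PrecisionType.HALF).dtype is torch.int64

# `_convert_fp_tensor` accepts an arbitrary destination dtype
assert _convert_fp_tensor(float_batch, torch.double).dtype is torch.float64
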
98 changes: 9 additions & 89 deletions src/lightning_lite/utilities/device_dtype_mixin.py
@@ -47,67 +47,10 @@ def device(self) -> torch.device:
return device

def to(self, *args: Any, **kwargs: Any) -> Self: # type: ignore[valid-type]
"""Moves and/or casts the parameters and buffers.

This can be called as
.. function:: to(device=None, dtype=None, non_blocking=False)
.. function:: to(dtype, non_blocking=False)
.. function:: to(tensor, non_blocking=False)
Its signature is similar to :meth:`torch.Tensor.to`, but only accepts
floating point desired :attr:`dtype` s. In addition, this method will
only cast the floating point parameters and buffers to :attr:`dtype`
(if given). The integral parameters and buffers will be moved
:attr:`device`, if that is given, but with dtypes unchanged. When
:attr:`non_blocking` is set, it tries to convert/move asynchronously
with respect to the host if possible, e.g., moving CPU Tensors with
pinned memory to CUDA devices.
See below for examples.

Note:
This method modifies the module in-place.

Args:
device: the desired device of the parameters
and buffers in this module
dtype: the desired floating point type of
the floating point parameters and buffers in this module
tensor: Tensor whose dtype and device are the desired
dtype and device for all parameters and buffers in this module

Returns:
Module: self

Example::
>>> from torch import Tensor
>>> class ExampleModule(_DeviceDtypeModuleMixin):
... def __init__(self, weight: Tensor):
... super().__init__()
... self.register_buffer('weight', weight)
>>> _ = torch.manual_seed(0)
>>> module = ExampleModule(torch.rand(3, 4))
>>> module.weight #doctest: +ELLIPSIS
tensor([[...]])
>>> module.to(torch.double)
ExampleModule()
>>> module.weight #doctest: +ELLIPSIS
tensor([[...]], dtype=torch.float64)
>>> cpu = torch.device('cpu')
>>> module.to(cpu, dtype=torch.half, non_blocking=True)
ExampleModule()
>>> module.weight #doctest: +ELLIPSIS
tensor([[...]], dtype=torch.float16)
>>> module.to(cpu)
ExampleModule()
>>> module.weight #doctest: +ELLIPSIS
tensor([[...]], dtype=torch.float16)
>>> module.device
device(type='cpu')
>>> module.dtype
torch.float16
"""
# there is diff nb vars in PT 1.5
out = torch._C._nn._parse_to(*args, **kwargs)
self.__update_properties(device=out[0], dtype=out[1])
"""See :meth:`torch.nn.Module.to`."""
# this converts `str` device to `torch.device`
device, dtype = torch._C._nn._parse_to(*args, **kwargs)[:2]
self.__update_properties(device=device, dtype=dtype)
return super().to(*args, **kwargs)

def cuda(self, device: Optional[Union[torch.device, int]] = None) -> Self: # type: ignore[valid-type]
@@ -130,50 +73,27 @@ def cuda(self, device: Optional[Union[torch.device, int]] = None) -> Self: # ty
return super().cuda(device=device)

def cpu(self) -> Self: # type: ignore[valid-type]
"""Moves all model parameters and buffers to the CPU.

Returns:
Module: self
"""
"""See :meth:`torch.nn.Module.cpu`."""
self.__update_properties(device=torch.device("cpu"))
return super().cpu()

def type(self, dst_type: Union[str, torch.dtype]) -> Self: # type: ignore[valid-type]
"""Casts all parameters and buffers to :attr:`dst_type`.

Arguments:
dst_type (type or string): the desired type

Returns:
Module: self
"""
"""See :meth:`torch.nn.Module.type`."""
self.__update_properties(dtype=dst_type)
return super().type(dst_type=dst_type)

def float(self) -> Self: # type: ignore[valid-type]
"""Casts all floating point parameters and buffers to ``float`` datatype.

Returns:
Module: self
"""
"""See :meth:`torch.nn.Module.float`."""
self.__update_properties(dtype=torch.float)
return super().float()

def double(self) -> Self: # type: ignore[valid-type]
"""Casts all floating point parameters and buffers to ``double`` datatype.

Returns:
Module: self
"""
"""See :meth:`torch.nn.Module.double`."""
self.__update_properties(dtype=torch.double)
return super().double()

def half(self) -> Self: # type: ignore[valid-type]
"""Casts all floating point parameters and buffers to ``half`` datatype.

Returns:
Module: self
"""
"""See :meth:`torch.nn.Module.half`."""
self.__update_properties(dtype=torch.half)
return super().half()

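
For context on the `to()` change above: `torch._C._nn._parse_to` is a private PyTorch helper (its exact signature may differ between releases), and the sketch below only assumes the behavior this diff relies on — it normalizes every supported `.to()` call form into a `(device, dtype, non_blocking, convert_to_format)` tuple, so taking `[:2]` yields the device and dtype, with a `str` device already converted to `torch.device`.

import torch

device, dtype = torch._C._nn._parse_to("cpu", dtype=torch.half)[:2]
assert device == torch.device("cpu")   # the string "cpu" was parsed into a torch.device
assert dtype is torch.float16

# a dtype-only call leaves the device entry as None
device, dtype = torch._C._nn._parse_to(torch.double)[:2]
assert device is None
assert dtype is torch.float64
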
15 changes: 7 additions & 8 deletions src/lightning_lite/wrappers.py
@@ -22,6 +22,7 @@
from torch.utils.data import DataLoader

from lightning_lite.plugins import Precision
from lightning_lite.plugins.precision.utils import _convert_fp_tensor
from lightning_lite.strategies import Strategy
from lightning_lite.utilities import move_data_to_device
from lightning_lite.utilities.device_dtype_mixin import _DeviceDtypeModuleMixin
@@ -104,18 +105,16 @@ def forward(self, *args: Any, **kwargs: Any) -> Any:
64: torch.float64,
}
# TODO: let the precision plugin handle the conversion
to_type = precision_to_type[precision]

def _convert_float_tensor(t: Tensor) -> Tensor:
return t.to(to_type) if torch.is_floating_point(t) else t

args, kwargs = apply_to_collection([args, kwargs], function=_convert_float_tensor, dtype=Tensor)
args, kwargs = apply_to_collection(
[args, kwargs], dtype=Tensor, function=_convert_fp_tensor, dst_type=precision_to_type[precision]
)

with self._precision_plugin.forward_context():
output = self._forward_module(*args, **kwargs)

to_type = torch.get_default_dtype()
output = apply_to_collection(output, function=_convert_float_tensor, dtype=Tensor)
output = apply_to_collection(
output, dtype=Tensor, function=_convert_fp_tensor, dst_type=torch.get_default_dtype()
)
return output

@overload
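
The wrapper change relies on `apply_to_collection` forwarding extra keyword arguments to the conversion function — a minimal sketch of that pattern (the nested batch structure here is made up for illustration):

import torch
from lightning_utilities.core.apply_func import apply_to_collection
from torch import Tensor

from lightning_lite.plugins.precision.utils import _convert_fp_tensor

batch = {"x": torch.rand(2), "meta": {"ids": torch.arange(2)}, "note": "unchanged"}

# every Tensor in the (arbitrarily nested) collection is passed through
# `_convert_fp_tensor`, with `dst_type` forwarded as a keyword argument
converted = apply_to_collection(batch, dtype=Tensor, function=_convert_fp_tensor, dst_type=torch.half)

assert converted["x"].dtype is torch.float16           # floating-point tensor cast
assert converted["meta"]["ids"].dtype is torch.int64   # integral tensor untouched
assert converted["note"] == "unchanged"                # non-tensor values pass through
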
2 changes: 1 addition & 1 deletion src/pytorch_lightning/core/module.py
@@ -602,7 +602,7 @@ def all_gather(

def forward(self, *args: Any, **kwargs: Any) -> Any:
r"""
Same as :meth:`torch.nn.Module.forward()`.
Same as :meth:`torch.nn.Module.forward`.

Args:
*args: Whatever you decide to pass into the forward method.
9 changes: 2 additions & 7 deletions src/pytorch_lightning/plugins/precision/double.py
@@ -21,6 +21,7 @@
from torch.optim import Optimizer

import pytorch_lightning as pl
from lightning_lite.plugins.precision.utils import _convert_fp_tensor
from pytorch_lightning.overrides.base import _LightningPrecisionModuleWrapperBase
from pytorch_lightning.plugins.precision.precision_plugin import PrecisionPlugin

@@ -33,15 +34,9 @@ class LightningDoublePrecisionModule(_LightningPrecisionModuleWrapperBase):
pl_module: the model to wrap
"""

@staticmethod
def _to_double_precision(data: Tensor) -> Tensor:
if data.is_floating_point():
return data.double()
return data

@staticmethod
def _move_float_tensors_to_double(collection: Any) -> Any:
return apply_to_collection(collection, Tensor, LightningDoublePrecisionModule._to_double_precision)
return apply_to_collection(collection, Tensor, function=_convert_fp_tensor, dst_type=torch.double)

def training_step(self, *args: Any, **kwargs: Any) -> Any:
return self.module.training_step(
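
For reference, the deleted `_to_double_precision` staticmethod and the new `_convert_fp_tensor(..., dst_type=torch.double)` call should be behaviorally equivalent — a small sketch checking that assumption, with the removed helper reproduced inline for comparison:

import torch

from lightning_lite.plugins.precision.utils import _convert_fp_tensor


def _to_double_precision(data: torch.Tensor) -> torch.Tensor:
    # the helper removed in this diff, copied here only for the comparison
    if data.is_floating_point():
        return data.double()
    return data


for t in (torch.rand(3), torch.arange(3)):
    assert _to_double_precision(t).dtype is _convert_fp_tensor(t, torch.double).dtype
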
10 changes: 1 addition & 9 deletions src/pytorch_lightning/strategies/deepspeed.py
@@ -85,17 +85,9 @@ def __init__(
self.precision = precision

def forward(self, *inputs: Any, **kwargs: Any) -> Any:
inputs = apply_to_collection(inputs, Tensor, function=self._batch_to)
inputs = apply_to_collection(inputs, Tensor, function=_fp_to_half, precision=self.precision)
return super().forward(*inputs, **kwargs)

def _batch_to(self, batch: Tensor) -> Tensor:
if torch.is_floating_point(batch):
if self.precision == PrecisionType.HALF:
return batch.half()
elif self.precision == PrecisionType.BFLOAT:
return batch.bfloat16()
return batch


class DeepSpeedStrategy(DDPStrategy):
strategy_name = "deepspeed"
17 changes: 2 additions & 15 deletions src/pytorch_lightning/strategies/ipu.py
@@ -17,14 +17,13 @@

import torch
from lightning_utilities.core.apply_func import apply_to_collection
from torch import FloatTensor, Tensor
from torch import Tensor
from torch.utils.data import DataLoader, Sampler

import pytorch_lightning as pl
from lightning_lite.plugins import CheckpointIO, ClusterEnvironment
from lightning_lite.plugins.precision.utils import _fp_to_half
from lightning_lite.utilities.cloud_io import get_filesystem
from lightning_lite.utilities.enums import PrecisionType
from pytorch_lightning.overrides.base import _LightningModuleWrapperBase, _LightningPrecisionModuleWrapperBase
from pytorch_lightning.plugins.precision import PrecisionPlugin
from pytorch_lightning.strategies.parallel import ParallelStrategy
@@ -61,21 +60,9 @@ def __init__(
self.precision = precision

def forward(self, *inputs: Any, **kwargs: Any) -> Any:
if self.precision == PrecisionType.HALF:
inputs = self._move_float_tensors_to_half(inputs)

inputs = apply_to_collection(inputs, Tensor, function=_fp_to_half, precision=self.precision)
return super().forward(*inputs, **kwargs)

@staticmethod
def batch_to(data: Tensor) -> Tensor:
if torch.is_floating_point(data):
return data.half()
return data

def _move_float_tensors_to_half(self, batch: Any) -> Any:
batch = apply_to_collection(batch, (FloatTensor, torch.cuda.FloatTensor), function=self.batch_to)
return batch


class IPUStrategy(ParallelStrategy):
"""Plugin for training on IPU devices."""
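
The IPU wrapper previously converted inputs only when running in half precision; the shared `_fp_to_half` helper keeps that behavior implicitly, because any precision other than HALF/BFLOAT falls through to `return tensor`. A small sketch of that assumption (the `PrecisionType.FLOAT` member name for "32" is assumed from `lightning_lite.utilities.enums`):

import torch

from lightning_lite.plugins.precision.utils import _fp_to_half
from lightning_lite.utilities.enums import PrecisionType

t = torch.rand(2)
assert _fp_to_half(t, PrecisionType.HALF).dtype is torch.float16
# full precision leaves the tensor untouched, so the unconditional
# `apply_to_collection` call in `forward` is a no-op in that case
assert _fp_to_half(t, PrecisionType.FLOAT).dtype is torch.float32
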
63 changes: 57 additions & 6 deletions tests/tests_lite/utilities/test_device_dtype_mixin.py
@@ -23,7 +23,7 @@ def __init__(self) -> None:


@pytest.mark.parametrize(
"dst_device_str,dst_dtype",
"dst_device_str,dst_type",
[
("cpu", torch.half),
("cpu", torch.float),
@@ -35,21 +35,19 @@ def __init__(self) -> None:
],
)
@RunIf(min_cuda_gpus=1)
def test_submodules_device_and_dtype(dst_device_str, dst_dtype):
def test_submodules_device_and_dtype(dst_device_str, dst_type):
"""Test that the device and dtype property updates propagate through mixed nesting of regular nn.Modules and
the special modules of type DeviceDtypeModuleMixin (e.g. Metric or LightningModule)."""

dst_device = torch.device(dst_device_str)

model = TopModule()
assert model.device == torch.device("cpu")
model = model.to(device=dst_device, dtype=dst_dtype)
model = model.to(device=dst_device, dtype=dst_type)
# nn.Module does not have these attributes
assert not hasattr(model.module, "_device")
assert not hasattr(model.module, "_dtype")
# device and dtype change should propagate down into all children
assert model.device == model.module.module.device == dst_device
assert model.dtype == model.module.module.dtype == dst_dtype
assert model.dtype == model.module.module.dtype == dst_type


@pytest.mark.parametrize(
@@ -72,6 +70,16 @@ def test_cuda_device(device):
assert device.index == torch.cuda.current_device()


@RunIf(min_cuda_gpus=1)
def test_cpu_device():
model = SubSubModule().cuda()
assert model.device.type == "cuda"
assert model.device.index == 0
model.cpu()
assert model.device.type == "cpu"
assert model.device.index is None


@RunIf(min_cuda_gpus=2)
def test_cuda_current_device():
"""Test that calling .cuda() moves the model to the correct device and respects current cuda device setting."""
@@ -92,3 +100,46 @@ def __init__(self):
model.cuda() # model is already on device 1, and calling .cuda() without device index should not move model
assert model.device == torch.device("cuda", 1)
assert model.layer.weight.device == torch.device("cuda", 1)


class ExampleModule(_DeviceDtypeModuleMixin):
def __init__(self, weight):
super().__init__()
self.register_buffer("weight", weight)


def test_to_combinations():
module = ExampleModule(torch.rand(3, 4))
# sanity check
assert module.weight.shape == (3, 4)
assert module.weight.dtype is torch.float32
# positional dtype
module.to(torch.double)
assert module.weight.dtype is torch.float64
# positional device
module.to("cpu", dtype=torch.half, non_blocking=True)
assert module.weight.dtype is torch.float16
assert module.device == torch.device("cpu")
assert module.dtype is torch.float16


def test_dtype_conversions():
module = ExampleModule(torch.tensor(1))
# different dtypes
assert module.weight.dtype is torch.int64
assert module.dtype is torch.float32
# `.double()` skips non floating points
module.double()
assert module.weight.dtype is torch.int64
assert module.dtype is torch.float64
# but `type` doesn't
module.type(torch.float)
assert module.weight.dtype is torch.float32
assert module.dtype is torch.float32
# now, test the rest
module.float()
assert module.weight.dtype is torch.float32
assert module.dtype is torch.float32
module.half()
assert module.weight.dtype is torch.float16
assert module.dtype is torch.float16