
Commit 3136bad

carmocca authored and justusschock committed
Clean-up dtype management (#14823)
Co-authored-by: Justus Schock <[email protected]>
1 parent 62038cf commit 3136bad

File tree: 9 files changed, +90 -161 lines


src/lightning_lite/plugins/precision/utils.py
Lines changed: 9 additions & 6 deletions

@@ -11,17 +11,20 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from typing import Union
 
 import torch
 
 from lightning_lite.utilities.enums import PrecisionType
 
 
 def _fp_to_half(tensor: torch.Tensor, precision: PrecisionType) -> torch.Tensor:
-    if torch.is_floating_point(tensor):
-        if precision == PrecisionType.HALF:
-            return tensor.half()
-        if precision == PrecisionType.BFLOAT:
-            return tensor.bfloat16()
-
+    if precision == PrecisionType.HALF:
+        return _convert_fp_tensor(tensor, torch.half)
+    if precision == PrecisionType.BFLOAT:
+        return _convert_fp_tensor(tensor, torch.bfloat16)
     return tensor
+
+
+def _convert_fp_tensor(tensor: torch.Tensor, dst_type: Union[str, torch.dtype]) -> torch.Tensor:
+    return tensor.to(dst_type) if torch.is_floating_point(tensor) else tensor
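
The new helpers split the old branching in two: `_convert_fp_tensor` performs the cast only for floating-point tensors, and `_fp_to_half` maps the `PrecisionType` setting onto that cast. A quick usage sketch, assuming the import paths shown in the diff (the sample tensors are made up for illustration):

import torch

from lightning_lite.plugins.precision.utils import _convert_fp_tensor, _fp_to_half
from lightning_lite.utilities.enums import PrecisionType

# Floating-point tensors are cast to the requested dtype ...
assert _convert_fp_tensor(torch.rand(2), torch.half).dtype is torch.float16
# ... while non-floating-point tensors pass through unchanged.
assert _convert_fp_tensor(torch.tensor([1, 2]), torch.half).dtype is torch.int64

# `_fp_to_half` picks the cast from the precision setting.
assert _fp_to_half(torch.rand(2), PrecisionType.HALF).dtype is torch.float16
assert _fp_to_half(torch.rand(2), PrecisionType.BFLOAT).dtype is torch.bfloat16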

src/lightning_lite/utilities/device_dtype_mixin.py
Lines changed: 9 additions & 89 deletions

@@ -47,67 +47,10 @@ def device(self) -> torch.device:
         return device
 
     def to(self, *args: Any, **kwargs: Any) -> Self:  # type: ignore[valid-type]
-        """Moves and/or casts the parameters and buffers.
-
-        This can be called as
-        .. function:: to(device=None, dtype=None, non_blocking=False)
-        .. function:: to(dtype, non_blocking=False)
-        .. function:: to(tensor, non_blocking=False)
-        Its signature is similar to :meth:`torch.Tensor.to`, but only accepts
-        floating point desired :attr:`dtype` s. In addition, this method will
-        only cast the floating point parameters and buffers to :attr:`dtype`
-        (if given). The integral parameters and buffers will be moved
-        :attr:`device`, if that is given, but with dtypes unchanged. When
-        :attr:`non_blocking` is set, it tries to convert/move asynchronously
-        with respect to the host if possible, e.g., moving CPU Tensors with
-        pinned memory to CUDA devices.
-        See below for examples.
-
-        Note:
-            This method modifies the module in-place.
-
-        Args:
-            device: the desired device of the parameters
-                and buffers in this module
-            dtype: the desired floating point type of
-                the floating point parameters and buffers in this module
-            tensor: Tensor whose dtype and device are the desired
-                dtype and device for all parameters and buffers in this module
-
-        Returns:
-            Module: self
-
-        Example::
-            >>> from torch import Tensor
-            >>> class ExampleModule(_DeviceDtypeModuleMixin):
-            ...     def __init__(self, weight: Tensor):
-            ...         super().__init__()
-            ...         self.register_buffer('weight', weight)
-            >>> _ = torch.manual_seed(0)
-            >>> module = ExampleModule(torch.rand(3, 4))
-            >>> module.weight #doctest: +ELLIPSIS
-            tensor([[...]])
-            >>> module.to(torch.double)
-            ExampleModule()
-            >>> module.weight #doctest: +ELLIPSIS
-            tensor([[...]], dtype=torch.float64)
-            >>> cpu = torch.device('cpu')
-            >>> module.to(cpu, dtype=torch.half, non_blocking=True)
-            ExampleModule()
-            >>> module.weight #doctest: +ELLIPSIS
-            tensor([[...]], dtype=torch.float16)
-            >>> module.to(cpu)
-            ExampleModule()
-            >>> module.weight #doctest: +ELLIPSIS
-            tensor([[...]], dtype=torch.float16)
-            >>> module.device
-            device(type='cpu')
-            >>> module.dtype
-            torch.float16
-        """
-        # there is diff nb vars in PT 1.5
-        out = torch._C._nn._parse_to(*args, **kwargs)
-        self.__update_properties(device=out[0], dtype=out[1])
+        """See :meth:`torch.nn.Module.to`."""
+        # this converts `str` device to `torch.device`
+        device, dtype = torch._C._nn._parse_to(*args, **kwargs)[:2]
+        self.__update_properties(device=device, dtype=dtype)
         return super().to(*args, **kwargs)
 
     def cuda(self, device: Optional[Union[torch.device, int]] = None) -> Self:  # type: ignore[valid-type]
@@ -130,50 +73,27 @@ def cuda(self, device: Optional[Union[torch.device, int]] = None) -> Self:  # ty
         return super().cuda(device=device)
 
     def cpu(self) -> Self:  # type: ignore[valid-type]
-        """Moves all model parameters and buffers to the CPU.
-
-        Returns:
-            Module: self
-        """
+        """See :meth:`torch.nn.Module.cpu`."""
         self.__update_properties(device=torch.device("cpu"))
         return super().cpu()
 
     def type(self, dst_type: Union[str, torch.dtype]) -> Self:  # type: ignore[valid-type]
-        """Casts all parameters and buffers to :attr:`dst_type`.
-
-        Arguments:
-            dst_type (type or string): the desired type
-
-        Returns:
-            Module: self
-        """
+        """See :meth:`torch.nn.Module.type`."""
         self.__update_properties(dtype=dst_type)
         return super().type(dst_type=dst_type)
 
     def float(self) -> Self:  # type: ignore[valid-type]
-        """Casts all floating point parameters and buffers to ``float`` datatype.
-
-        Returns:
-            Module: self
-        """
+        """See :meth:`torch.nn.Module.float`."""
         self.__update_properties(dtype=torch.float)
         return super().float()
 
     def double(self) -> Self:  # type: ignore[valid-type]
-        """Casts all floating point parameters and buffers to ``double`` datatype.
-
-        Returns:
-            Module: self
-        """
+        """See :meth:`torch.nn.Module.double`."""
        self.__update_properties(dtype=torch.double)
        return super().double()

    def half(self) -> Self:  # type: ignore[valid-type]
-        """Casts all floating point parameters and buffers to ``half`` datatype.
-
-        Returns:
-            Module: self
-        """
+        """See :meth:`torch.nn.Module.half`."""
         self.__update_properties(dtype=torch.half)
         return super().half()
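
The trimmed `to()` relies on PyTorch's private `torch._C._nn._parse_to` to normalize every supported call signature before the mixin caches `device` and `dtype`. A minimal illustration, assuming that helper keeps its current `(device, dtype, non_blocking, memory_format)` return layout; only the first two entries matter here:

import torch

# A string device is normalized to a `torch.device`; the keyword dtype passes through.
device, dtype, *_ = torch._C._nn._parse_to("cuda:1", dtype=torch.half)
assert device == torch.device("cuda", 1)
assert dtype is torch.float16

# A dtype-only call leaves the device slot as None, so the mixin keeps its current device.
device, dtype, *_ = torch._C._nn._parse_to(torch.double)
assert device is None and dtype is torch.float64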

src/lightning_lite/wrappers.py
Lines changed: 7 additions & 8 deletions

@@ -22,6 +22,7 @@
 from torch.utils.data import DataLoader
 
 from lightning_lite.plugins import Precision
+from lightning_lite.plugins.precision.utils import _convert_fp_tensor
 from lightning_lite.strategies import Strategy
 from lightning_lite.utilities import move_data_to_device
 from lightning_lite.utilities.device_dtype_mixin import _DeviceDtypeModuleMixin
@@ -104,18 +105,16 @@ def forward(self, *args: Any, **kwargs: Any) -> Any:
             64: torch.float64,
         }
         # TODO: let the precision plugin handle the conversion
-        to_type = precision_to_type[precision]
-
-        def _convert_float_tensor(t: Tensor) -> Tensor:
-            return t.to(to_type) if torch.is_floating_point(t) else t
-
-        args, kwargs = apply_to_collection([args, kwargs], function=_convert_float_tensor, dtype=Tensor)
+        args, kwargs = apply_to_collection(
+            [args, kwargs], dtype=Tensor, function=_convert_fp_tensor, dst_type=precision_to_type[precision]
+        )
 
         with self._precision_plugin.forward_context():
             output = self._forward_module(*args, **kwargs)
 
-        to_type = torch.get_default_dtype()
-        output = apply_to_collection(output, function=_convert_float_tensor, dtype=Tensor)
+        output = apply_to_collection(
+            output, dtype=Tensor, function=_convert_fp_tensor, dst_type=torch.get_default_dtype()
+        )
         return output
 
     @overload
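
The local `_convert_float_tensor` closure becomes unnecessary because `apply_to_collection` forwards extra keyword arguments (here `dst_type`) to the given function for every matching element. A small sketch of that pattern with a made-up batch dictionary:

import torch
from lightning_utilities.core.apply_func import apply_to_collection
from torch import Tensor

from lightning_lite.plugins.precision.utils import _convert_fp_tensor

batch = {"x": torch.rand(2, 2), "lengths": torch.tensor([2, 1])}
converted = apply_to_collection(batch, dtype=Tensor, function=_convert_fp_tensor, dst_type=torch.half)
assert converted["x"].dtype is torch.float16  # the float tensor was cast
assert converted["lengths"].dtype is torch.int64  # the integer tensor was left alone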

src/pytorch_lightning/core/module.py
Lines changed: 1 addition & 1 deletion

@@ -602,7 +602,7 @@ def all_gather(
 
     def forward(self, *args: Any, **kwargs: Any) -> Any:
         r"""
-        Same as :meth:`torch.nn.Module.forward()`.
+        Same as :meth:`torch.nn.Module.forward`.
 
         Args:
             *args: Whatever you decide to pass into the forward method.

src/pytorch_lightning/plugins/precision/double.py
Lines changed: 2 additions & 7 deletions

@@ -21,6 +21,7 @@
 from torch.optim import Optimizer
 
 import pytorch_lightning as pl
+from lightning_lite.plugins.precision.utils import _convert_fp_tensor
 from pytorch_lightning.overrides.base import _LightningPrecisionModuleWrapperBase
 from pytorch_lightning.plugins.precision.precision_plugin import PrecisionPlugin
 
@@ -33,15 +34,9 @@ class LightningDoublePrecisionModule(_LightningPrecisionModuleWrapperBase):
         pl_module: the model to wrap
     """
 
-    @staticmethod
-    def _to_double_precision(data: Tensor) -> Tensor:
-        if data.is_floating_point():
-            return data.double()
-        return data
-
     @staticmethod
     def _move_float_tensors_to_double(collection: Any) -> Any:
-        return apply_to_collection(collection, Tensor, LightningDoublePrecisionModule._to_double_precision)
+        return apply_to_collection(collection, Tensor, function=_convert_fp_tensor, dst_type=torch.double)
 
     def training_step(self, *args: Any, **kwargs: Any) -> Any:
         return self.module.training_step(

src/pytorch_lightning/strategies/deepspeed.py
Lines changed: 1 addition & 9 deletions

@@ -85,17 +85,9 @@ def __init__(
         self.precision = precision
 
     def forward(self, *inputs: Any, **kwargs: Any) -> Any:
-        inputs = apply_to_collection(inputs, Tensor, function=self._batch_to)
+        inputs = apply_to_collection(inputs, Tensor, function=_fp_to_half, precision=self.precision)
         return super().forward(*inputs, **kwargs)
 
-    def _batch_to(self, batch: Tensor) -> Tensor:
-        if torch.is_floating_point(batch):
-            if self.precision == PrecisionType.HALF:
-                return batch.half()
-            elif self.precision == PrecisionType.BFLOAT:
-                return batch.bfloat16()
-        return batch
-
 
 class DeepSpeedStrategy(DDPStrategy):
     strategy_name = "deepspeed"

src/pytorch_lightning/strategies/ipu.py
Lines changed: 2 additions & 15 deletions

@@ -17,14 +17,13 @@
 
 import torch
 from lightning_utilities.core.apply_func import apply_to_collection
-from torch import FloatTensor, Tensor
+from torch import Tensor
 from torch.utils.data import DataLoader, Sampler
 
 import pytorch_lightning as pl
 from lightning_lite.plugins import CheckpointIO, ClusterEnvironment
 from lightning_lite.plugins.precision.utils import _fp_to_half
 from lightning_lite.utilities.cloud_io import get_filesystem
-from lightning_lite.utilities.enums import PrecisionType
 from pytorch_lightning.overrides.base import _LightningModuleWrapperBase, _LightningPrecisionModuleWrapperBase
 from pytorch_lightning.plugins.precision import PrecisionPlugin
 from pytorch_lightning.strategies.parallel import ParallelStrategy
@@ -61,21 +60,9 @@ def __init__(
         self.precision = precision
 
     def forward(self, *inputs: Any, **kwargs: Any) -> Any:
-        if self.precision == PrecisionType.HALF:
-            inputs = self._move_float_tensors_to_half(inputs)
-
+        inputs = apply_to_collection(inputs, Tensor, function=_fp_to_half, precision=self.precision)
         return super().forward(*inputs, **kwargs)
 
-    @staticmethod
-    def batch_to(data: Tensor) -> Tensor:
-        if torch.is_floating_point(data):
-            return data.half()
-        return data
-
-    def _move_float_tensors_to_half(self, batch: Any) -> Any:
-        batch = apply_to_collection(batch, (FloatTensor, torch.cuda.FloatTensor), function=self.batch_to)
-        return batch
-
 
 class IPUStrategy(ParallelStrategy):
     """Plugin for training on IPU devices."""

tests/tests_lite/utilities/test_device_dtype_mixin.py
Lines changed: 57 additions & 6 deletions

@@ -23,7 +23,7 @@ def __init__(self) -> None:
 
 
 @pytest.mark.parametrize(
-    "dst_device_str,dst_dtype",
+    "dst_device_str,dst_type",
     [
         ("cpu", torch.half),
         ("cpu", torch.float),
@@ -35,21 +35,19 @@ def __init__(self) -> None:
     ],
 )
 @RunIf(min_cuda_gpus=1)
-def test_submodules_device_and_dtype(dst_device_str, dst_dtype):
+def test_submodules_device_and_dtype(dst_device_str, dst_type):
     """Test that the device and dtype property updates propagate through mixed nesting of regular nn.Modules and
     the special modules of type DeviceDtypeModuleMixin (e.g. Metric or LightningModule)."""
-
     dst_device = torch.device(dst_device_str)
-
     model = TopModule()
     assert model.device == torch.device("cpu")
-    model = model.to(device=dst_device, dtype=dst_dtype)
+    model = model.to(device=dst_device, dtype=dst_type)
     # nn.Module does not have these attributes
     assert not hasattr(model.module, "_device")
     assert not hasattr(model.module, "_dtype")
     # device and dtype change should propagate down into all children
     assert model.device == model.module.module.device == dst_device
-    assert model.dtype == model.module.module.dtype == dst_dtype
+    assert model.dtype == model.module.module.dtype == dst_type
 
 
 @pytest.mark.parametrize(
@@ -72,6 +70,16 @@ def test_cuda_device(device):
     assert device.index == torch.cuda.current_device()
 
 
+@RunIf(min_cuda_gpus=1)
+def test_cpu_device():
+    model = SubSubModule().cuda()
+    assert model.device.type == "cuda"
+    assert model.device.index == 0
+    model.cpu()
+    assert model.device.type == "cpu"
+    assert model.device.index is None
+
+
 @RunIf(min_cuda_gpus=2)
 def test_cuda_current_device():
     """Test that calling .cuda() moves the model to the correct device and respects current cuda device setting."""
@@ -92,3 +100,46 @@ def __init__(self):
     model.cuda()  # model is already on device 1, and calling .cuda() without device index should not move model
     assert model.device == torch.device("cuda", 1)
     assert model.layer.weight.device == torch.device("cuda", 1)
+
+
+class ExampleModule(_DeviceDtypeModuleMixin):
+    def __init__(self, weight):
+        super().__init__()
+        self.register_buffer("weight", weight)
+
+
+def test_to_combinations():
+    module = ExampleModule(torch.rand(3, 4))
+    # sanity check
+    assert module.weight.shape == (3, 4)
+    assert module.weight.dtype is torch.float32
+    # positional dtype
+    module.to(torch.double)
+    assert module.weight.dtype is torch.float64
+    # positional device
+    module.to("cpu", dtype=torch.half, non_blocking=True)
+    assert module.weight.dtype is torch.float16
+    assert module.device == torch.device("cpu")
+    assert module.dtype is torch.float16
+
+
+def test_dtype_conversions():
+    module = ExampleModule(torch.tensor(1))
+    # different dtypes
+    assert module.weight.dtype is torch.int64
+    assert module.dtype is torch.float32
+    # `.double()` skips non floating points
+    module.double()
+    assert module.weight.dtype is torch.int64
+    assert module.dtype is torch.float64
+    # but `type` doesn't
+    module.type(torch.float)
+    assert module.weight.dtype is torch.float32
+    assert module.dtype is torch.float32
+    # now, test the rest
+    module.float()
+    assert module.weight.dtype is torch.float32
+    assert module.dtype is torch.float32
+    module.half()
+    assert module.weight.dtype is torch.float16
+    assert module.dtype is torch.float16
