
Commit d02fe34

ethanwharris, carmocca, rohitgr7, and justusschock authored
Feature/double precision (#6595)
Co-authored-by: Carlos Mocholí <[email protected]>
Co-authored-by: Rohit Gupta <[email protected]>
Co-authored-by: Justus Schock <[email protected]>
1 parent 5733889 commit d02fe34

8 files changed: +240 -4 lines changed


CHANGELOG.md

Lines changed: 3 additions & 0 deletions
@@ -61,6 +61,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Added `outputs` parameter to callback's `on_validation_epoch_end` & `on_test_epoch_end` hooks ([#6120](https://github.com/PyTorchLightning/pytorch-lightning/pull/6120))
 
 
+- Added support for `precision=64`, enabling training with double precision ([#6595](https://github.com/PyTorchLightning/pytorch-lightning/pull/6595))
+
+
 ### Changed
 
 - Renamed `pytorch_lightning.callbacks.swa` to `pytorch_lightning.callbacks.stochastic_weight_avg` ([#6259](https://github.com/PyTorchLightning/pytorch-lightning/pull/6259))

docs/source/common/trainer.rst

Lines changed: 4 additions & 1 deletion
@@ -1157,7 +1157,7 @@ precision
 
 |
 
-Full precision (32), half precision (16).
+Double precision (64), full precision (32) or half precision (16).
 Can be used on CPU, GPU or TPUs.
 
 If used on TPU will use torch.bfloat16 but tensor printing

@@ -1172,6 +1172,9 @@ will still show torch.float32.
     # 16-bit precision
     trainer = Trainer(precision=16, gpus=1)
 
+    # 64-bit precision
+    trainer = Trainer(precision=64)
+
 Example::
 
     # one day
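
Not part of the diff above, just an illustration of what the new flag enables: a minimal, self-contained sketch in which `TinyModel` and `RandomData` are hypothetical names of our own, not from this PR. With `precision=64`, both the module parameters and the float batches seen in `training_step` should arrive as `torch.float64`:

# Illustrative sketch only; TinyModel / RandomData are hypothetical names, not part of the commit.
import torch
from torch.utils.data import DataLoader, Dataset

import pytorch_lightning as pl


class RandomData(Dataset):
    def __init__(self, size, length):
        # deliberately float32 data; with precision=64 it is expected to be cast to double
        self.data = torch.randn(length, size)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]


class TinyModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def forward(self, x):
        return self.layer(x)

    def training_step(self, batch, batch_idx):
        assert batch.dtype == torch.float64  # incoming float tensors arrive as double
        return self(batch).sum()

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)


if __name__ == "__main__":
    model = TinyModel()
    trainer = pl.Trainer(precision=64, fast_dev_run=True)
    trainer.fit(model, DataLoader(RandomData(32, 64), batch_size=8))
    assert model.layer.weight.dtype == torch.float64  # the weights were converted as well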

pytorch_lightning/plugins/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -1,6 +1,7 @@
 from pytorch_lightning.plugins.base_plugin import Plugin # noqa: F401
 from pytorch_lightning.plugins.precision.apex_amp import ApexMixedPrecisionPlugin # noqa: F401
 from pytorch_lightning.plugins.precision.deepspeed_precision import DeepSpeedPrecisionPlugin # noqa: F401
+from pytorch_lightning.plugins.precision.double import DoublePrecisionPlugin # noqa: F401
 from pytorch_lightning.plugins.precision.native_amp import NativeMixedPrecisionPlugin # noqa: F401
 from pytorch_lightning.plugins.precision.precision_plugin import PrecisionPlugin # noqa: F401
 from pytorch_lightning.plugins.precision.sharded_native_amp import ShardedNativeMixedPrecisionPlugin # noqa: F401

@@ -29,6 +30,7 @@
     "DDPSpawnPlugin",
     "DeepSpeedPlugin",
     "DeepSpeedPrecisionPlugin",
+    "DoublePrecisionPlugin",
     "HorovodPlugin",
     "NativeMixedPrecisionPlugin",
     "PrecisionPlugin",

pytorch_lightning/plugins/precision/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -1,5 +1,6 @@
 from pytorch_lightning.plugins.precision.apex_amp import ApexMixedPrecisionPlugin # noqa: F401
 from pytorch_lightning.plugins.precision.deepspeed_precision import DeepSpeedPrecisionPlugin # noqa: F401
+from pytorch_lightning.plugins.precision.double import DoublePrecisionPlugin # noqa: F401
 from pytorch_lightning.plugins.precision.mixed import MixedPrecisionPlugin # noqa: F401
 from pytorch_lightning.plugins.precision.native_amp import NativeMixedPrecisionPlugin # noqa: F401
 from pytorch_lightning.plugins.precision.precision_plugin import PrecisionPlugin # noqa: F401
pytorch_lightning/plugins/precision/double.py (new file)

Lines changed: 95 additions & 0 deletions
@@ -0,0 +1,95 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from functools import wraps
from typing import Any, Sequence, Tuple, TYPE_CHECKING, List

import torch

from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.plugins.precision.precision_plugin import PrecisionPlugin
from pytorch_lightning.utilities.apply_func import apply_to_collection

if TYPE_CHECKING:
    from torch.nn import Module
    from torch.optim import Optimizer


class _DoublePrecisionPatch:
    """Class to handle patching of methods in the ``LightningModule`` and subsequent teardown."""

    def __init__(self, model: 'Module', method_name: str, old_method: Any) -> None:
        self.model = model
        self.method_name = method_name
        self.old_method = old_method

    def teardown(self) -> None:
        setattr(self.model, self.method_name, self.old_method)

    @staticmethod
    def _to_double_precision(data: torch.Tensor) -> torch.Tensor:
        if data.is_floating_point():
            return data.double()
        return data

    @staticmethod
    def _move_float_tensors_to_double(collection: Any) -> Any:
        return apply_to_collection(
            collection, torch.Tensor, function=_DoublePrecisionPatch._to_double_precision
        )

    @classmethod
    def patch(cls, model: 'Module', method_name: str) -> '_DoublePrecisionPatch':
        old_method = getattr(model, method_name)

        @wraps(old_method)
        def new_method(*args: Any, **kwargs: Any) -> Any:
            return old_method(
                *_DoublePrecisionPatch._move_float_tensors_to_double(args),
                **_DoublePrecisionPatch._move_float_tensors_to_double(kwargs)
            )

        setattr(model, method_name, new_method if callable(old_method) else old_method)
        return cls(model, method_name, old_method)


class DoublePrecisionPlugin(PrecisionPlugin):
    """Plugin for training with double (``torch.float64``) precision."""

    precision: int = 64

    def __init__(self) -> None:
        self.patches: List[_DoublePrecisionPatch] = []

    def connect(
        self,
        model: 'Module',
        optimizers: Sequence['Optimizer'],
        lr_schedulers: Sequence[Any],
    ) -> Tuple['Module', Sequence['Optimizer'], Sequence[Any]]:
        """Converts the model to double precision and wraps the `training_step`, `validation_step`, `test_step`,
        `predict_step`, and `forward` methods to convert incoming floating point data to double. Does not alter
        `optimizers` or `lr_schedulers`."""
        model = model.to(dtype=torch.float64)
        if isinstance(model, LightningModule):
            self.patches.append(_DoublePrecisionPatch.patch(model, 'training_step'))
            self.patches.append(_DoublePrecisionPatch.patch(model, 'validation_step'))
            self.patches.append(_DoublePrecisionPatch.patch(model, 'test_step'))
            self.patches.append(_DoublePrecisionPatch.patch(model, 'predict_step'))
            self.patches.append(_DoublePrecisionPatch.patch(model, 'forward'))

        return super().connect(model, optimizers, lr_schedulers)

    def post_dispatch(self) -> None:
        while len(self.patches) > 0:
            self.patches.pop().teardown()
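
The heart of the new file is `_DoublePrecisionPatch.patch`, which wraps a method so that floating-point tensor arguments are cast to double before the original method runs, while `teardown` restores the original method afterwards. A stripped-down illustration of that wrapping idea (our sketch, not code from the commit; the real plugin uses `apply_to_collection`, so nested lists and dicts of tensors are handled as well, which this flat-argument version does not):

# Illustration only, not from the commit: the casting wrapper shown on a free function.
from functools import wraps

import torch


def cast_floats_to_double(value):
    # Only floating-point tensors are converted; ints, bools, etc. pass through unchanged.
    if isinstance(value, torch.Tensor) and value.is_floating_point():
        return value.double()
    return value


def patch_to_double(fn):
    @wraps(fn)
    def wrapper(*args, **kwargs):
        args = tuple(cast_floats_to_double(a) for a in args)
        kwargs = {k: cast_floats_to_double(v) for k, v in kwargs.items()}
        return fn(*args, **kwargs)

    return wrapper


@patch_to_double
def step(x, y):
    return x.dtype, y.dtype


print(step(torch.randn(2), torch.randint(10, (2,))))  # (torch.float64, torch.int64)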

pytorch_lightning/trainer/connectors/accelerator_connector.py

Lines changed: 4 additions & 2 deletions
@@ -32,6 +32,7 @@
     DDPSpawnShardedPlugin,
     DeepSpeedPlugin,
     DeepSpeedPrecisionPlugin,
+    DoublePrecisionPlugin,
     HorovodPlugin,
     NativeMixedPrecisionPlugin,
     PrecisionPlugin,

@@ -319,7 +320,8 @@ def select_precision_plugin(self) -> PrecisionPlugin:
 
         if self.precision == 32:
             return PrecisionPlugin()
-
+        elif self.precision == 64:
+            return DoublePrecisionPlugin()
         elif self.precision == 16:
             if self.on_tpu:
                 return TPUHalfPrecisionPlugin()

@@ -358,7 +360,7 @@ def select_precision_plugin(self) -> PrecisionPlugin:
             log.info("Using APEX 16bit precision.")
             return ApexMixedPrecisionPlugin(self.amp_level)
 
-        raise NotImplementedError("We only support precisions 32 and 16!")
+        raise NotImplementedError("We only support precisions 64, 32 and 16!")
 
     def select_training_type_plugin(self) -> TrainingTypePlugin:
         if self.use_ddp2:
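
The connector change above is where `precision=64` gets mapped to the new plugin. A boiled-down sketch of that selection logic, with stub classes standing in for the real plugins and the TPU/AMP/sharded branches omitted (not the actual connector code):

# Simplified stand-in classes; the real plugins live in pytorch_lightning.plugins.
class PrecisionPlugin:
    pass


class DoublePrecisionPlugin(PrecisionPlugin):
    pass


class NativeMixedPrecisionPlugin(PrecisionPlugin):
    pass


def select_precision_plugin(precision: int) -> PrecisionPlugin:
    # Mirrors the if/elif chain in the diff, minus backend-specific branches.
    if precision == 32:
        return PrecisionPlugin()
    elif precision == 64:
        return DoublePrecisionPlugin()
    elif precision == 16:
        return NativeMixedPrecisionPlugin()
    raise NotImplementedError("We only support precisions 64, 32 and 16!")


assert isinstance(select_precision_plugin(64), DoublePrecisionPlugin)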

pytorch_lightning/trainer/trainer.py

Lines changed: 2 additions & 1 deletion
@@ -227,7 +227,8 @@ def __init__(
 
             plugins: Plugins allow modification of core behavior like ddp and amp, and enable custom lightning plugins.
 
-            precision: Full precision (32), half precision (16). Can be used on CPU, GPU or TPUs.
+            precision: Double precision (64), full precision (32) or half precision (16). Can be used on CPU, GPU or
+                TPUs.
 
             max_epochs: Stop training once this number of epochs is reached. Disabled by default (None).
                 If both max_epochs and max_steps are not specified, defaults to ``max_epochs`` = 1000.
New test file for the double precision plugin

Lines changed: 129 additions & 0 deletions
@@ -0,0 +1,129 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import pytest

import torch
from torch.utils.data import DataLoader, Dataset

from pytorch_lightning import Trainer
from tests.helpers.boring_model import BoringModel, RandomDataset


class RandomFloatIntDataset(Dataset):

    def __init__(self, size, length):
        self.len = length
        self.float_data = torch.randn(length, size)
        self.int_data = torch.randint(10, (length, 1))

    def __getitem__(self, index):
        return self.float_data[index], self.int_data[index]

    def __len__(self):
        return self.len


class DoublePrecisionBoringModel(BoringModel):

    def training_step(self, batch, batch_idx):
        float_data, int_data = batch
        assert float_data.dtype == torch.float64
        output = self(float_data)
        loss = self.loss(batch, output)
        return {"loss": loss}

    def validation_step(self, batch, batch_idx):
        assert batch.dtype == torch.float64
        output = self(batch)
        loss = self.loss(batch, output)
        return {"x": loss}

    def test_step(self, batch, batch_idx):
        assert batch.dtype == torch.float64
        output = self(batch)
        loss = self.loss(batch, output)
        return {"y": loss}

    def predict_step(self, batch, batch_idx, dataloader_idx=None):
        assert batch.dtype == torch.float64
        return self(batch)

    def on_fit_start(self):
        assert self.layer.weight.dtype == torch.float64

    def on_after_backward(self):
        assert self.layer.weight.grad.dtype == torch.float64

    def train_dataloader(self):
        dataset = RandomFloatIntDataset(32, 64)
        assert dataset.float_data.dtype == torch.float32  # Don't start with double data
        return DataLoader(dataset)

    def predict_dataloader(self):
        return DataLoader(RandomDataset(32, 64))


class DoublePrecisionBoringModelNoForward(BoringModel):

    def training_step(self, batch, batch_idx):
        assert batch.dtype == torch.float64
        output = self.layer(batch)
        assert output.dtype == torch.float64
        loss = self.loss(batch, output)
        return {"loss": loss}

    def validation_step(self, batch, batch_idx):
        assert batch.dtype == torch.float64
        output = self.layer(batch)
        assert output.dtype == torch.float64
        loss = self.loss(batch, output)
        return {"x": loss}

    def test_step(self, batch, batch_idx):
        assert batch.dtype == torch.float64
        output = self.layer(batch)
        assert output.dtype == torch.float64
        loss = self.loss(batch, output)
        return {"y": loss}

    def predict_step(self, batch, batch_idx, dataloader_idx=None):
        assert batch.dtype == torch.float64
        output = self.layer(batch)
        assert output.dtype == torch.float64
        return output

    def predict_dataloader(self):
        return DataLoader(RandomDataset(32, 64))


@pytest.mark.parametrize(
    'boring_model',
    (DoublePrecisionBoringModel, DoublePrecisionBoringModelNoForward)
)
def test_double_precision(tmpdir, boring_model):
    model = boring_model()
    original_training_step = model.training_step

    trainer = Trainer(
        max_epochs=2,
        default_root_dir=tmpdir,
        fast_dev_run=2,
        precision=64,
        log_every_n_steps=1,
    )
    trainer.fit(model)
    trainer.test(model)
    trainer.predict(model)

    assert model.training_step == original_training_step

0 commit comments
