Skip to content

Commit 0137564

Browse files
tchatoncarmocca
andauthored
[bugfix] Add set_default_tensor_type to torch.DoubleTensor with precision=64 (#7108)
* update * Update pytorch_lightning/plugins/precision/double.py Co-authored-by: Carlos Mocholí <[email protected]> * Update pytorch_lightning/plugins/precision/double.py Co-authored-by: Carlos Mocholí <[email protected]> * Update pytorch_lightning/plugins/precision/double.py Co-authored-by: Carlos Mocholí <[email protected]> * resolve tests Co-authored-by: Carlos Mocholí <[email protected]>
1 parent ca21da4 commit 0137564

File tree

5 files changed

+32
-4
lines changed

5 files changed

+32
-4
lines changed

pytorch_lightning/accelerators/accelerator.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -242,7 +242,7 @@ def predict_step(self, args: List[Union[Any, int]]) -> STEP_OUTPUT:
242242

243243
args[0] = batch
244244

245-
with self.precision_plugin.predict_context(), self.training_type_plugin.predict_context():
245+
with self.precision_plugin.predict_step_context(), self.training_type_plugin.predict_step_context():
246246
return self.training_type_plugin.predict_step(*args)
247247

248248
def training_step_end(self, output: STEP_OUTPUT) -> STEP_OUTPUT:

pytorch_lightning/plugins/base_plugin.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,6 @@ def test_step_context(self) -> Generator:
4141
yield
4242

4343
@contextlib.contextmanager
44-
def predict_context(self) -> Generator:
44+
def predict_step_context(self) -> Generator:
4545
"""A contextmanager for the predict step"""
4646
yield

pytorch_lightning/plugins/precision/double.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,9 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14+
import contextlib
1415
from functools import wraps
15-
from typing import Any, List, Tuple
16+
from typing import Any, Generator, List, Tuple
1617

1718
import torch
1819
import torch.nn as nn
@@ -90,3 +91,18 @@ def connect(
9091
def post_dispatch(self) -> None:
9192
while len(self.patches) > 0:
9293
self.patches.pop().teardown()
94+
95+
@contextlib.contextmanager
96+
def tensor_type_context(self) -> Generator:
97+
"""
98+
A context manager to change the default tensor type.
99+
See: :meth:`torch.set_default_tensor_type`
100+
"""
101+
torch.set_default_tensor_type(torch.DoubleTensor)
102+
yield
103+
torch.set_default_tensor_type(torch.FloatTensor)
104+
105+
train_step_context = tensor_type_context
106+
val_step_context = tensor_type_context
107+
test_step_context = tensor_type_context
108+
predict_step_context = tensor_type_context

pytorch_lightning/plugins/precision/native_amp.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -115,7 +115,7 @@ def test_step_context(self) -> Generator[None, None, None]:
115115
yield
116116

117117
@contextmanager
118-
def predict_context(self) -> Generator[None, None, None]:
118+
def predict_step_context(self) -> Generator[None, None, None]:
119119
"""Enable autocast context"""
120120
with torch.cuda.amp.autocast():
121121
yield

tests/plugins/test_double_plugin.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,25 +37,37 @@ class DoublePrecisionBoringModel(BoringModel):
3737

3838
def training_step(self, batch, batch_idx):
3939
float_data, int_data = batch
40+
assert torch.tensor([0.]).dtype == torch.float64
41+
assert torch.tensor([0.], dtype=torch.float16).dtype == torch.float16
4042
assert float_data.dtype == torch.float64
4143
output = self(float_data)
4244
loss = self.loss(batch, output)
4345
return {"loss": loss}
4446

47+
def training_epoch_end(self, outputs) -> None:
48+
assert torch.tensor([0.]).dtype == torch.float32
49+
return super().training_epoch_end(outputs)
50+
4551
def validation_step(self, batch, batch_idx):
4652
assert batch.dtype == torch.float64
53+
assert torch.tensor([0.]).dtype == torch.float64
54+
assert torch.tensor([0.], dtype=torch.float16).dtype == torch.float16
4755
output = self(batch)
4856
loss = self.loss(batch, output)
4957
return {"x": loss}
5058

5159
def test_step(self, batch, batch_idx):
5260
assert batch.dtype == torch.float64
61+
assert torch.tensor([0.]).dtype == torch.float64
62+
assert torch.tensor([0.], dtype=torch.float16).dtype == torch.float16
5363
output = self(batch)
5464
loss = self.loss(batch, output)
5565
return {"y": loss}
5666

5767
def predict_step(self, batch, batch_idx, dataloader_idx=None):
5868
assert batch.dtype == torch.float64
69+
assert torch.tensor([0.]).dtype == torch.float64
70+
assert torch.tensor([0.], dtype=torch.float16).dtype == torch.float16
5971
return self(batch)
6072

6173
def on_fit_start(self):

0 commit comments

Comments
 (0)