Skip to content

Commit a42e8c0

Browse files
committed
.
1 parent ce5a4ab commit a42e8c0

File tree

2 files changed

+3
-3
lines changed

2 files changed

+3
-3
lines changed

pytorch_lightning/overrides/data_parallel.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -285,7 +285,7 @@ def _worker(i, module, input, kwargs, device=None):
285285
if output is None:
286286
warn_missing_output(fx_called)
287287

288-
if output is not None and module.distrib_type in ("dp", "ddp2"):
288+
if output is not None and module._distrib_type in ('dp', 'ddp2'):
289289
auto_squeeze_dim_zeros(output)
290290
# ---------------
291291

pytorch_lightning/utilities/enums.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ class DistributedType(LightningEnum):
5050
>>> DistributedType.DDP == 'ddp'
5151
True
5252
>>> # which is case invariant
53-
>>> DistributedType.DDP2 == 'DDP2'
53+
>>> DistributedType.DDP2 in ('ddp2', )
5454
True
5555
"""
5656
DP = 'dp'
@@ -69,7 +69,7 @@ class DeviceType(LightningEnum):
6969
>>> DeviceType.GPU == 'GPU'
7070
True
7171
>>> # which is case invariant
72-
>>> DeviceType.TPU == 'tpu'
72+
>>> DeviceType.TPU in ('tpu', 'CPU')
7373
True
7474
"""
7575
CPU = 'CPU'

0 commit comments

Comments
 (0)