Skip to content

Commit 5bd3cd5

Browse files
ejohba and awaelchli authored
Bugfix/cuda oom detection and handling (#6934)
Co-authored-by: Adrian Wälchli <[email protected]>
1 parent 895bea1 commit 5bd3cd5

File tree

1 file changed

+9
-2
lines changed

1 file changed

+9
-2
lines changed

pytorch_lightning/utilities/memory.py

Lines changed: 9 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -53,7 +53,8 @@ def is_oom_error(exception):
5353
def is_cuda_out_of_memory(exception):
5454
return isinstance(exception, RuntimeError) \
5555
and len(exception.args) == 1 \
56-
and "CUDA out of memory." in exception.args[0]
56+
and "CUDA" in exception.args[0] \
57+
and "out of memory" in exception.args[0]
5758

5859

5960
# based on https://github.com/BlackHC/toma/blob/master/toma/torch_cuda_memory.py
@@ -76,4 +77,10 @@ def garbage_collection_cuda():
7677
"""Garbage collection Torch (CUDA) memory."""
7778
gc.collect()
7879
if torch.cuda.is_available():
79-
torch.cuda.empty_cache()
80+
try:
81+
# This is the last thing that should cause an OOM error, but seemingly it can.
82+
torch.cuda.empty_cache()
83+
except RuntimeError as exception:
84+
if not is_oom_error(exception):
85+
# Only handle OOM errors
86+
raise

0 commit comments

Comments (0)