Skip to content

Commit f3aff2f

Browse files
Make torch version check numeric (#4285)
Co-authored-by: Vasilis Vryniotis <[email protected]>
1 parent af16236 commit f3aff2f

File tree

1 file changed

+3
-1
lines changed

1 file changed

+3
-1
lines changed

setup.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -147,7 +147,9 @@ def get_extensions():
147147
)
148148

149149
is_rocm_pytorch = False
150-
if torch.__version__ >= '1.5':
150+
TORCH_MAJOR = int(torch.__version__.split('.')[0])
151+
TORCH_MINOR = int(torch.__version__.split('.')[1])
152+
if TORCH_MAJOR > 1 or (TORCH_MAJOR == 1 and TORCH_MINOR >= 5):
151153
from torch.utils.cpp_extension import ROCM_HOME
152154
is_rocm_pytorch = True if ((torch.version.hip is not None) and (ROCM_HOME is not None)) else False
153155

0 commit comments

Comments
 (0)