diff --git a/CMakeLists.txt b/CMakeLists.txt index f376947aa4..5e0e8346f6 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -62,6 +62,7 @@ option(BUILD_RNNT "Enable RNN transducer" ON) option(BUILD_TORCHAUDIO_PYTHON_EXTENSION "Build Python extension" OFF) option(USE_CUDA "Enable CUDA support" OFF) option(USE_ROCM "Enable ROCM support" OFF) +option(USE_OPENMP "Enable OpenMP support" OFF) # check that USE_CUDA and USE_ROCM are not set at the same time @@ -122,6 +123,10 @@ if(MSVC) endif() endif() +if(USE_OPENMP) + find_package(OpenMP REQUIRED) +endif() + # TORCH_CXX_FLAGS contains the same -D_GLIBCXX_USE_CXX11_ABI value as PyTorch set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wall ${TORCH_CXX_FLAGS}") diff --git a/tools/setup_helpers/extension.py b/tools/setup_helpers/extension.py index 8f062c6ff8..945c7e8c21 100644 --- a/tools/setup_helpers/extension.py +++ b/tools/setup_helpers/extension.py @@ -39,6 +39,8 @@ def _get_build(var, default=False): _BUILD_RNNT = _get_build("BUILD_RNNT", True) _USE_ROCM = _get_build("USE_ROCM", torch.cuda.is_available() and torch.version.hip is not None) _USE_CUDA = _get_build("USE_CUDA", torch.cuda.is_available() and torch.version.hip is None) +_USE_OPENMP = _get_build("USE_OPENMP", True) and \ + 'ATen parallel backend: OpenMP' in torch.__config__.parallel_info() _TORCH_CUDA_ARCH_LIST = os.environ.get('TORCH_CUDA_ARCH_LIST', None) @@ -90,6 +92,7 @@ def build_extension(self, ext): "-DBUILD_TORCHAUDIO_PYTHON_EXTENSION:BOOL=ON", f"-DUSE_ROCM:BOOL={'ON' if _USE_ROCM else 'OFF'}", f"-DUSE_CUDA:BOOL={'ON' if _USE_CUDA else 'OFF'}", + f"-DUSE_OPENMP:BOOL={'ON' if _USE_OPENMP else 'OFF'}", ] build_args = [ '--target', 'install' diff --git a/torchaudio/csrc/CMakeLists.txt b/torchaudio/csrc/CMakeLists.txt index 2187abafa1..9ca8efe0cb 100644 --- a/torchaudio/csrc/CMakeLists.txt +++ b/torchaudio/csrc/CMakeLists.txt @@ -94,6 +94,10 @@ if (MSVC) set_target_properties(libtorchaudio PROPERTIES SUFFIX ".pyd") endif(MSVC) +if(OpenMP_CXX_FOUND) + target_link_libraries(libtorchaudio OpenMP::OpenMP_CXX) +endif() + install( TARGETS libtorchaudio LIBRARY DESTINATION lib