Skip to content

Commit 0f82217

Browse files
authored
Split extension into custom impl and Python wrapper libraries (#1752)
* Split `libtorchaudio` and `_torchaudio` This change extracts the core implementation from `_torchaudio` to `libtorchaudio`, so that `libtorchaudio` is reusable in TorchScript-based apps. `_torchaudio` is a wrapper around `libtorchaudio` and only provides PyBind11-based features (currently file-like object support in I/O). * Removed `BUILD_LIBTORCHAUDIO` option When invoking `cmake`, `libtorchaudio` is always built, so this option is removed. The new assumptions around the library discoverability - In regular OSS workflow (`pip`/`conda`-based binary installation), both `libtorchaudio` and `_torchaudio` are present. In this case, `libtorchaudio` has to be loaded manually with `torch.ops.load_library` and/or `torch.classes.load_library`; otherwise importing `_torchaudio` would not be able to resolve the symbols defined in `libtorchaudio`. - When `torchaudio` is deployed with PEX format (single zip file) - We expect that `libtorchaudio.so` exists as a file in some search path configured by client code. - `_torchaudio` is still importable and because we do not know where `libtorchaudio` will exist, we will let the dynamic loader resolve the dependency from `_torchaudio` to `libtorchaudio`, which should work as long as `libtorchaudio` is in a library search path (the search path is not modifiable from an already-running Python process).
1 parent 0d007b7 commit 0f82217

File tree

6 files changed

+96
-105
lines changed

6 files changed

+96
-105
lines changed

CMakeLists.txt

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,6 @@ endif()
5050
option(BUILD_SOX "Build libsox statically" ON)
5151
option(BUILD_KALDI "Build kaldi statically" ON)
5252
option(BUILD_RNNT "Enable RNN transducer" ON)
53-
option(BUILD_LIBTORCHAUDIO "Build C++ Library" ON)
5453
option(BUILD_TORCHAUDIO_PYTHON_EXTENSION "Build Python extension" OFF)
5554
option(USE_CUDA "Enable CUDA support" OFF)
5655
option(USE_ROCM "Enable ROCM support" OFF)

build_tools/setup_helpers/extension.py

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,10 @@ def _get_build(var, default=False):
4343

4444

4545
def get_ext_modules():
46-
return [Extension(name='torchaudio._torchaudio', sources=[])]
46+
return [
47+
Extension(name='torchaudio.libtorchaudio', sources=[]),
48+
Extension(name='torchaudio._torchaudio', sources=[]),
49+
]
4750

4851

4952
# Based off of
@@ -53,10 +56,19 @@ def run(self):
5356
try:
5457
subprocess.check_output(['cmake', '--version'])
5558
except OSError:
56-
raise RuntimeError("CMake is not available.")
59+
raise RuntimeError("CMake is not available.") from None
5760
super().run()
5861

5962
def build_extension(self, ext):
63+
# Since two library files (libtorchaudio and _torchaudio) need to be
64+
# recognized by setuptools, we instantiate `Extension` twice. (see `get_ext_modules`)
65+
# This leads to the situation where this `build_extension` method is called twice.
66+
# However, the following `cmake` command will build all of them at the same time,
67+
# so, we do not need to perform `cmake` twice.
68+
# Therefore we call `cmake` only for `torchaudio._torchaudio`.
69+
if ext.name != 'torchaudio._torchaudio':
70+
return
71+
6072
extdir = os.path.abspath(
6173
os.path.dirname(self.get_ext_fullpath(ext.name)))
6274

@@ -76,7 +88,6 @@ def build_extension(self, ext):
7688
f"-DBUILD_KALDI:BOOL={'ON' if _BUILD_KALDI else 'OFF'}",
7789
f"-DBUILD_RNNT:BOOL={'ON' if _BUILD_RNNT else 'OFF'}",
7890
"-DBUILD_TORCHAUDIO_PYTHON_EXTENSION:BOOL=ON",
79-
"-DBUILD_LIBTORCHAUDIO:BOOL=OFF",
8091
f"-DUSE_ROCM:BOOL={'ON' if _USE_ROCM else 'OFF'}",
8192
f"-DUSE_CUDA:BOOL={'ON' if _USE_CUDA else 'OFF'}",
8293
]

examples/libtorchaudio/CMakeLists.txt

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@ cmake_minimum_required(VERSION 3.5)
22

33
project(libtorchaudio-cpp-example)
44

5-
SET(BUILD_LIBTORCHAUDIO ON CACHE BOOL "Build libtorchaudio")
65
SET(BUILD_SOX ON CACHE BOOL "Build libsox into libtorchaudio")
76

87
SET(BUILD_KALDI OFF CACHE BOOL "Build Kaldi into libtorchaudio")

torchaudio/__init__.py

Lines changed: 1 addition & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -1,39 +1,4 @@
1-
from torchaudio._internal import module_utils as _mod_utils # noqa: F401
2-
3-
if _mod_utils.is_module_available('torchaudio._torchaudio'):
4-
# Note this import has two purposes
5-
# 1. Make _torchaudio accessible by the other modules (regular import)
6-
# 2. Register torchaudio's custom ops bound via TorchScript
7-
#
8-
# For 2, normally function calls `torch.ops.load_library` and `torch.classes.load_library`
9-
# are used. However, in our cases, this is inconvenient and unnecessary.
10-
#
11-
# - Why inconvenient?
12-
# When torchaudio is deployed with `pex` format, all the files are deployed as a single zip
13-
# file, and the extension module is not present as a file with full path. Therefore it is not
14-
# possible to pass the path to library to `torch.[ops|classes].load_library` functions.
15-
#
16-
# - Why unnecessary?
17-
# When torchaudio extension module (C++ module) is available, it is assumed that
18-
# the extension contains both TorchScript-based binding and PyBind11-based binding.*
19-
# Under this assumption, simply performing `from torchaudio import _torchaudio` will load the
20-
# library which contains TorchScript-based binding as well, and the functions/classes bound
21-
# via TorchScript become accessible under `torch.ops` and `torch.classes`.
22-
#
23-
# *Note that this holds true even when these two bindings are split into two library files and
24-
# the library that contains PyBind11-based binding (`_torchaudio.so` in the following diagram)
25-
# depends on the other one (`libtorchaudio.so`), because when the process tries to load
26-
# `_torchaudio.so` it detects undefined symbols from `libtorchaudio.so` and will automatically
27-
# load `libtorchaudio.so` (given that the library is found in a search path).
28-
#
29-
# [libtorchaudio.so] <- [_torchaudio.so]
30-
#
31-
#
32-
from torchaudio import _torchaudio # noqa
33-
else:
34-
import warnings
35-
warnings.warn('torchaudio C++ extension is not available.')
36-
1+
from torchaudio import _extension # noqa: F401
372
from torchaudio import (
383
compliance,
394
datasets,

torchaudio/_extension.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
import os
2+
import warnings
3+
from pathlib import Path
4+
5+
import torch
6+
from torchaudio._internal import module_utils as _mod_utils # noqa: F401
7+
8+
9+
def _init_extension():
10+
if not _mod_utils.is_module_available('torchaudio._torchaudio'):
11+
warnings.warn('torchaudio C++ extension is not available.')
12+
return
13+
14+
suffix = 'dll' if os.name == 'nt' else 'so'
15+
path = Path(__file__).parent / f'libtorchaudio.{suffix}'
16+
# In case `torchaudio` is deployed with `pex` format, this file does not exist.
17+
# In this case, we expect that `libtorchaudio` is available somewhere
18+
# in the search path of dynamic loading mechanism, and importing `_torchaudio`,
19+
# which depends on `libtorchaudio`, lets the dynamic loader handle it for us.
20+
if path.exists():
21+
torch.ops.load_library(path)
22+
torch.classes.load_library(path)
23+
# This import is for initializing the methods registered via PyBind11
24+
from torchaudio import _torchaudio # noqa
25+
26+
27+
_init_extension()

torchaudio/csrc/CMakeLists.txt

Lines changed: 54 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
get_property(TORCHAUDIO_THIRD_PARTIES GLOBAL PROPERTY TORCHAUDIO_THIRD_PARTIES)
22

33
################################################################################
4-
# Stuff common to libtorchaudio and _torchaudio.so
4+
# libtorchaudio
55
################################################################################
66
set(
77
LIBTORCHAUDIO_SOURCES
@@ -11,75 +11,94 @@ set(
1111
)
1212

1313
if(BUILD_RNNT)
14-
set(
15-
RNNT_SOURCES
14+
list(
15+
APPEND
16+
LIBTORCHAUDIO_SOURCES
1617
rnnt/cpu/compute_alphas.cpp
1718
rnnt/cpu/compute_betas.cpp
1819
rnnt/cpu/compute.cpp
1920
rnnt/compute_alphas.cpp
2021
rnnt/compute_betas.cpp
2122
rnnt/compute.cpp
2223
rnnt/autograd.cpp
23-
)
24-
24+
)
2525
if (USE_CUDA)
26-
set(
27-
CUDA_RNNT_SOURCES
26+
list(
27+
APPEND
28+
LIBTORCHAUDIO_SOURCES
2829
rnnt/gpu/compute_alphas.cu
2930
rnnt/gpu/compute_betas.cu
3031
rnnt/gpu/compute.cu
31-
)
32-
list(APPEND RNNT_SOURCES ${CUDA_RNNT_SOURCES})
32+
)
3333
endif()
34-
35-
list(APPEND LIBTORCHAUDIO_SOURCES ${RNNT_SOURCES})
3634
endif()
3735

3836
if(BUILD_KALDI)
3937
list(APPEND LIBTORCHAUDIO_SOURCES kaldi.cpp)
4038
endif()
4139

4240
if(BUILD_SOX)
43-
set(
44-
SOX_SOURCES
41+
list(
42+
APPEND
43+
LIBTORCHAUDIO_SOURCES
4544
sox/io.cpp
4645
sox/utils.cpp
4746
sox/effects.cpp
4847
sox/effects_chain.cpp
4948
sox/types.cpp
49+
)
50+
endif()
51+
52+
add_library(
53+
libtorchaudio
54+
SHARED
55+
${LIBTORCHAUDIO_SOURCES}
56+
)
57+
set_target_properties(libtorchaudio PROPERTIES PREFIX "")
58+
59+
target_include_directories(
60+
libtorchaudio
61+
PRIVATE
62+
${PROJECT_SOURCE_DIR}
5063
)
51-
list(APPEND LIBTORCHAUDIO_SOURCES ${SOX_SOURCES})
64+
65+
target_link_libraries(
66+
libtorchaudio
67+
torch
68+
${TORCHAUDIO_THIRD_PARTIES}
69+
)
70+
71+
if (BUILD_SOX)
72+
target_compile_definitions(libtorchaudio PUBLIC INCLUDE_SOX)
5273
endif()
5374

54-
################################################################################
55-
# libtorchaudio.so
56-
################################################################################
57-
if(BUILD_LIBTORCHAUDIO)
58-
add_library(
59-
libtorchaudio
60-
SHARED
61-
${LIBTORCHAUDIO_SOURCES}
62-
)
63-
set_target_properties(libtorchaudio PROPERTIES PREFIX "")
75+
if (BUILD_KALDI)
76+
target_compile_definitions(libtorchaudio PUBLIC INCLUDE_KALDI)
77+
endif()
6478

79+
if(USE_CUDA)
80+
target_compile_definitions(libtorchaudio PRIVATE USE_CUDA)
6581
target_include_directories(
6682
libtorchaudio
67-
PUBLIC
68-
${PROJECT_SOURCE_DIR}
83+
PRIVATE
84+
${CUDA_TOOLKIT_INCLUDE}
6985
)
70-
7186
target_link_libraries(
7287
libtorchaudio
73-
${TORCH_LIBRARIES}
74-
${TORCHAUDIO_THIRD_PARTIES}
88+
${C10_CUDA_LIBRARY}
89+
${CUDA_CUDART_LIBRARY}
7590
)
91+
endif()
7692

77-
install(
78-
TARGETS
79-
libtorchaudio
80-
LIBRARY DESTINATION lib
81-
)
93+
install(
94+
TARGETS libtorchaudio
95+
LIBRARY DESTINATION .
96+
RUNTIME DESTINATION .
97+
)
8298

99+
if (APPLE)
100+
set(TORCHAUDIO_LIBRARY libtorchaudio CACHE INTERNAL "")
101+
else()
83102
set(TORCHAUDIO_LIBRARY -Wl,--no-as-needed libtorchaudio -Wl,--as-needed CACHE INTERNAL "")
84103
endif()
85104

@@ -104,7 +123,6 @@ if (BUILD_TORCHAUDIO_PYTHON_EXTENSION)
104123
add_library(
105124
_torchaudio
106125
SHARED
107-
${LIBTORCHAUDIO_SOURCES}
108126
${EXTENSION_SOURCES}
109127
)
110128

@@ -119,31 +137,12 @@ if (BUILD_TORCHAUDIO_PYTHON_EXTENSION)
119137
set_target_properties(_torchaudio PROPERTIES LINK_FLAGS "-undefined dynamic_lookup")
120138
endif()
121139

122-
if (BUILD_SOX)
123-
target_compile_definitions(_torchaudio PRIVATE INCLUDE_SOX)
124-
endif()
125-
126-
if (BUILD_KALDI)
127-
target_compile_definitions(_torchaudio PRIVATE INCLUDE_KALDI)
128-
endif()
129-
130-
if (USE_CUDA)
131-
target_compile_definitions(_torchaudio PRIVATE USE_CUDA)
132-
endif()
133-
134140
target_include_directories(
135141
_torchaudio
136142
PRIVATE
137143
${PROJECT_SOURCE_DIR}
138144
${Python_INCLUDE_DIR}
139145
)
140-
if(USE_CUDA)
141-
target_include_directories(
142-
_torchaudio
143-
PRIVATE
144-
${CUDA_TOOLKIT_INCLUDE}
145-
)
146-
endif()
147146

148147
# See https://github.com/pytorch/pytorch/issues/38122
149148
find_library(TORCH_PYTHON_LIBRARY torch_python PATHS "${TORCH_INSTALL_PREFIX}/lib")
@@ -155,20 +154,11 @@ if (BUILD_TORCHAUDIO_PYTHON_EXTENSION)
155154

156155
target_link_libraries(
157156
_torchaudio
158-
torch
157+
libtorchaudio
159158
${TORCH_PYTHON_LIBRARY}
160-
${TORCHAUDIO_THIRD_PARTIES}
161159
${ADDITIONAL_ITEMS}
162160
)
163161

164-
if(USE_CUDA)
165-
target_link_libraries(
166-
_torchaudio
167-
${C10_CUDA_LIBRARY}
168-
${CUDA_CUDART_LIBRARY}
169-
)
170-
endif()
171-
172162
install(
173163
TARGETS _torchaudio
174164
LIBRARY DESTINATION .

0 commit comments

Comments
 (0)