Skip to content

Commit d6e4e68

Browse files
committed
build only one extension.
1 parent bde1aba commit d6e4e68

File tree

4 files changed

+91
-96
lines changed

4 files changed

+91
-96
lines changed

build_tools/setup_helpers/extension.py

Lines changed: 4 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,8 @@
1818
_CSRC_DIR = _ROOT_DIR / 'torchaudio' / 'csrc'
1919
_TP_BASE_DIR = _ROOT_DIR / 'third_party'
2020
_TP_INSTALL_DIR = _TP_BASE_DIR / 'install'
21+
_TRANSDUCER_BUILD_DIR = _TP_BASE_DIR / 'build' / 'warp_transducer'
22+
_TRANSDUCER_BASE_DIR = _TP_BASE_DIR / 'warp_transducer' / 'submodule'
2123

2224

2325
def _get_build_sox():
@@ -64,6 +66,7 @@ def _get_srcs():
6466
def _get_include_dirs():
6567
dirs = [
6668
str(_ROOT_DIR),
69+
str(_TRANSDUCER_BASE_DIR / 'include'),
6770
]
6871
if _BUILD_SOX:
6972
dirs.append(str(_TP_INSTALL_DIR / 'include'))
@@ -94,6 +97,7 @@ def _get_extra_objects():
9497
]
9598
for lib in libs:
9699
objs.append(str(_TP_INSTALL_DIR / 'lib' / lib))
100+
objs.append(str(_TRANSDUCER_BUILD_DIR / 'libwarprnnt.a'))
97101
return objs
98102

99103

@@ -132,60 +136,11 @@ def get_ext_modules(debug=False):
132136
extra_objects=_get_extra_objects(),
133137
extra_link_args=_get_ela(debug),
134138
),
135-
_get_transducer_module(),
136139
]
137140

138141

139142
class BuildExtension(TorchBuildExtension):
140143
def build_extension(self, ext):
141144
if ext.name == _EXT_NAME and _BUILD_SOX:
142145
_build_third_party()
143-
if ext.name == _TRANSDUCER_NAME:
144-
_build_transducer()
145146
super().build_extension(ext)
146-
147-
148-
_TRANSDUCER_NAME = '_warp_transducer'
149-
_TP_TRANSDUCER_BASE_DIR = _ROOT_DIR / 'third_party' / 'warp_transducer'
150-
151-
152-
def _build_transducer():
153-
build_dir = str(_TP_TRANSDUCER_BASE_DIR / 'submodule' / 'build')
154-
os.makedirs(build_dir, exist_ok=True)
155-
subprocess.run(
156-
args=['cmake', str(_TP_TRANSDUCER_BASE_DIR), '-DWITH_OMP=OFF'],
157-
cwd=build_dir,
158-
check=True,
159-
)
160-
subprocess.run(
161-
args=['cmake', '--build', '.'],
162-
cwd=build_dir,
163-
check=True,
164-
)
165-
166-
167-
def _get_transducer_module():
168-
extra_compile_args = [
169-
'-fPIC',
170-
'-std=c++14',
171-
]
172-
173-
libraries = ['warprnnt']
174-
175-
source_paths = [
176-
_TP_TRANSDUCER_BASE_DIR / 'binding.cpp',
177-
_TP_TRANSDUCER_BASE_DIR / 'submodule' / 'pytorch_binding' / 'src' / 'binding.cpp',
178-
]
179-
build_path = _TP_TRANSDUCER_BASE_DIR / 'submodule' / 'build'
180-
include_path = _TP_TRANSDUCER_BASE_DIR / 'submodule' / 'include'
181-
182-
return CppExtension(
183-
name=_TRANSDUCER_NAME,
184-
sources=[os.path.realpath(path) for path in source_paths],
185-
libraries=libraries,
186-
include_dirs=[os.path.realpath(include_path)],
187-
library_dirs=[os.path.realpath(build_path)],
188-
extra_compile_args=extra_compile_args,
189-
extra_objects=[str(build_path / f'lib{lib}.a') for lib in libraries],
190-
extra_link_args=['-Wl,-rpath,' + os.path.realpath(build_path)],
191-
)

third_party/CMakeLists.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -88,3 +88,5 @@ ExternalProject_Add(libsox
8888
# See https://github.com/pytorch/audio/pull/1026
8989
CONFIGURE_COMMAND ${CMAKE_CURRENT_SOURCE_DIR}/build_codec_helper.sh ${CMAKE_CURRENT_SOURCE_DIR}/src/libsox/configure ${COMMON_ARGS} --with-lame --with-flac --with-mad --with-oggvorbis --without-alsa --without-coreaudio --without-png --without-oss --without-sndfile --with-opus --with-amrwb --with-amrnb --disable-openmp
9090
)
91+
92+
add_subdirectory(warp_transducer)

third_party/warp_transducer/binding.cpp

Lines changed: 0 additions & 47 deletions
This file was deleted.

torchaudio/csrc/transducer.cpp

Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,85 @@
1+
#include <iostream>
2+
#include <numeric>
3+
4+
#include <torch/extension.h>
5+
#include "rnnt.h"
6+
7+
int cpu_rnnt(torch::Tensor acts,
8+
torch::Tensor labels,
9+
torch::Tensor input_lengths,
10+
torch::Tensor label_lengths,
11+
torch::Tensor costs,
12+
torch::Tensor grads,
13+
int blank_label,
14+
int num_threads) {
15+
16+
int maxT = acts.size(1);
17+
int maxU = acts.size(2);
18+
int minibatch_size = acts.size(0);
19+
int alphabet_size = acts.size(3);
20+
21+
rnntOptions options;
22+
memset(&options, 0, sizeof(options));
23+
options.maxT = maxT;
24+
options.maxU = maxU;
25+
options.blank_label = blank_label;
26+
options.batch_first = true;
27+
options.loc = RNNT_CPU;
28+
options.num_threads = num_threads;
29+
30+
// have to use at least one
31+
options.num_threads = std::max(options.num_threads, (unsigned int) 1);
32+
33+
size_t cpu_size_bytes = 0;
34+
switch (acts.type().scalarType()) {
35+
case torch::ScalarType::Float:
36+
{
37+
get_workspace_size(maxT, maxU, minibatch_size,
38+
false, &cpu_size_bytes);
39+
40+
float* cpu_workspace = (float*) new unsigned char[cpu_size_bytes];
41+
compute_rnnt_loss(acts.data<float>(), grads.data<float>(),
42+
labels.data<int>(), label_lengths.data<int>(),
43+
input_lengths.data<int>(), alphabet_size,
44+
minibatch_size, costs.data<float>(),
45+
cpu_workspace, options);
46+
47+
delete cpu_workspace;
48+
return 0;
49+
}
50+
case torch::ScalarType::Double:
51+
{
52+
get_workspace_size(maxT, maxU, minibatch_size,
53+
false, &cpu_size_bytes,
54+
sizeof(double));
55+
56+
double* cpu_workspace = (double*) new unsigned char[cpu_size_bytes];
57+
compute_rnnt_loss_fp64(acts.data<double>(), grads.data<double>(),
58+
labels.data<int>(), label_lengths.data<int>(),
59+
input_lengths.data<int>(), alphabet_size,
60+
minibatch_size, costs.data<double>(),
61+
cpu_workspace, options);
62+
63+
delete cpu_workspace;
64+
return 0;
65+
}
66+
default:
67+
std::cerr << __FILE__ << ':' << __LINE__ << ": " << "unsupported data type" << std::endl;
68+
}
69+
return -1;
70+
}
71+
72+
TORCH_LIBRARY(warprnnt_pytorch_warp_rnnt, m) {
73+
m.def("rnnt(Tensor acts,"
74+
"Tensor labels,"
75+
"Tensor input_lengths,"
76+
"Tensor label_lengths,"
77+
"Tensor costs,"
78+
"Tensor grads,"
79+
"int blank_label,"
80+
"int num_threads) -> int");
81+
}
82+
83+
TORCH_LIBRARY_IMPL(warprnnt_pytorch_warp_rnnt, CPU, m) {
84+
m.impl("rnnt", &cpu_rnnt);
85+
}

0 commit comments

Comments
 (0)