Skip to content

Commit 64c8220

Browse files
committed
guard for transducer
1 parent ab23f3c commit 64c8220

File tree

7 files changed

+21
-10
lines changed

7 files changed

+21
-10
lines changed

.circleci/torchscript_bc_test/common.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -67,5 +67,5 @@ build_master() {
6767
printf "* Installing torchaudio\n"
6868
cd "${_root_dir}" || exit 1
6969
git submodule update --init --recursive
70-
BUILD_SOX=1 python setup.py clean install
70+
BUILD_TRANSDUCER=1 BUILD_SOX=1 python setup.py clean install
7171
}

.circleci/unittest/linux/scripts/install.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ conda install -y -c "pytorch-${UPLOAD_CHANNEL}" pytorch ${cudatoolkit}
3838

3939
# 2. Install torchaudio
4040
printf "* Installing torchaudio\n"
41-
BUILD_SOX=1 python setup.py install
41+
BUILD_TRANSDUCER=1 BUILD_SOX=1 python setup.py install
4242

4343
# 3. Install Test tools
4444
printf "* Installing test tools\n"

build_tools/setup_helpers/extension.py

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -20,20 +20,21 @@
2020
_TP_INSTALL_DIR = _TP_BASE_DIR / 'install'
2121

2222

23-
def _get_build_sox():
24-
val = os.environ.get('BUILD_SOX', '0')
23+
def _get_build(var):
24+
val = os.environ.get(var, '0')
2525
trues = ['1', 'true', 'TRUE', 'on', 'ON', 'yes', 'YES']
2626
falses = ['0', 'false', 'FALSE', 'off', 'OFF', 'no', 'NO']
2727
if val in trues:
2828
return True
2929
if val not in falses:
3030
print(
31-
f'WARNING: Unexpected environment variable value `BUILD_SOX={val}`. '
31+
f'WARNING: Unexpected environment variable value `{var}={val}`. '
3232
f'Expected one of {trues + falses}')
3333
return False
3434

3535

36-
_BUILD_SOX = _get_build_sox()
36+
_BUILD_SOX = _get_build("BUILD_SOX")
37+
_BUILD_TRANSDUCER = _get_build("BUILD_TRANSDUCER")
3738

3839

3940
def _get_eca(debug):
@@ -42,6 +43,8 @@ def _get_eca(debug):
4243
eca += ["-O0", "-g"]
4344
else:
4445
eca += ["-O3"]
46+
if _BUILD_TRANSDUCER:
47+
eca += ['-DBUILD_TRANSDUCER']
4548
return eca
4649

4750

@@ -64,10 +67,11 @@ def _get_srcs():
6467
def _get_include_dirs():
6568
dirs = [
6669
str(_ROOT_DIR),
67-
str(_TP_BASE_DIR / 'transducer' / 'submodule' / 'include'),
6870
]
6971
if _BUILD_SOX:
7072
dirs.append(str(_TP_INSTALL_DIR / 'include'))
73+
if _BUILD_TRANSDUCER:
74+
dirs.append(str(_TP_BASE_DIR / 'transducer' / 'submodule' / 'include'))
7175
return dirs
7276

7377

@@ -95,7 +99,8 @@ def _get_extra_objects():
9599
]
96100
for lib in libs:
97101
objs.append(str(_TP_INSTALL_DIR / 'lib' / lib))
98-
objs.append(str(_TP_BASE_DIR / 'build' / 'transducer' / 'libwarprnnt.a'))
102+
if _BUILD_TRANSDUCER:
103+
objs.append(str(_TP_BASE_DIR / 'build' / 'transducer' / 'libwarprnnt.a'))
99104
return objs
100105

101106

packaging/build_wheel.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,5 +15,5 @@ if [[ "$OSTYPE" == "msys" ]]; then
1515
python_tag="$(echo "cp$PYTHON_VERSION" | tr -d '.')"
1616
python setup.py bdist_wheel --plat-name win_amd64 --python-tag $python_tag
1717
else
18-
BUILD_SOX=1 python setup.py bdist_wheel
18+
BUILD_TRANSDUCER=1 BUILD_SOX=1 python setup.py bdist_wheel
1919
fi

packaging/torchaudio/build.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
11
#!/usr/bin/env bash
22
set -ex
33

4-
BUILD_SOX=1 python setup.py install --single-version-externally-managed --record=record.txt
4+
BUILD_TRANSDUCER=1 BUILD_SOX=1 python setup.py install --single-version-externally-managed --record=record.txt

torchaudio/csrc/register.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,7 @@ TORCH_LIBRARY(torchaudio, m) {
8181
//////////////////////////////////////////////////////////////////////////////
8282
// transducer.cpp
8383
//////////////////////////////////////////////////////////////////////////////
84+
#ifdef BUILD_TRANSDUCER
8485
m.def("rnnt_loss(Tensor acts,"
8586
"Tensor labels,"
8687
"Tensor input_lengths,"
@@ -89,5 +90,6 @@ TORCH_LIBRARY(torchaudio, m) {
8990
"Tensor grads,"
9091
"int blank_label,"
9192
"int num_threads) -> int");
93+
#endif
9294
}
9395
#endif

torchaudio/csrc/transducer.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
#ifdef BUILD_TRANSDUCER
2+
13
#include <iostream>
24
#include <numeric>
35
#include <string>
@@ -76,3 +78,5 @@ int64_t cpu_rnnt_loss(torch::Tensor acts,
7678
TORCH_LIBRARY_IMPL(torchaudio, CPU, m) {
7779
m.impl("rnnt_loss", &cpu_rnnt_loss);
7880
}
81+
82+
#endif

0 commit comments

Comments
 (0)