From 9913e7f1de8c4e9c3e75bc9e1658278fb3c2325c Mon Sep 17 00:00:00 2001 From: Vincent Quenneville-Belair Date: Mon, 21 Dec 2020 14:51:54 -0800 Subject: [PATCH 01/18] git add warp-transducer. --- .gitmodules | 4 ++++ third_party/warp_transducer/submodule | 1 + 2 files changed, 5 insertions(+) create mode 100644 .gitmodules create mode 160000 third_party/warp_transducer/submodule diff --git a/.gitmodules b/.gitmodules new file mode 100644 index 0000000000..41b1684ccd --- /dev/null +++ b/.gitmodules @@ -0,0 +1,4 @@ +[submodule "third_party/warp_transducer/submodule"] + path = third_party/warp_transducer/submodule + url = https://github.com/HawkAaron/warp-transducer + branch = master diff --git a/third_party/warp_transducer/submodule b/third_party/warp_transducer/submodule new file mode 160000 index 0000000000..f546575109 --- /dev/null +++ b/third_party/warp_transducer/submodule @@ -0,0 +1 @@ +Subproject commit f546575109111c455354861a0567c8aa794208a2 From 1866cb8c397ce8737502f4a6cfe0a616c5c1b7de Mon Sep 17 00:00:00 2001 From: Vincent Quenneville-Belair Date: Mon, 21 Dec 2020 15:32:10 -0800 Subject: [PATCH 02/18] submodule update. --- .circleci/torchscript_bc_test/common.sh | 1 + .circleci/unittest/linux/scripts/install.sh | 1 + packaging/build_conda.sh | 1 + packaging/build_wheel.sh | 1 + 4 files changed, 4 insertions(+) diff --git a/.circleci/torchscript_bc_test/common.sh b/.circleci/torchscript_bc_test/common.sh index 24ad5e45fb..9b1d0e1ef4 100644 --- a/.circleci/torchscript_bc_test/common.sh +++ b/.circleci/torchscript_bc_test/common.sh @@ -66,5 +66,6 @@ build_master() { conda install -y -q pytorch "cpuonly" -c pytorch-nightly printf "* Installing torchaudio\n" cd "${_root_dir}" || exit 1 + git submodule update --init --recursive BUILD_SOX=1 python setup.py clean install } diff --git a/.circleci/unittest/linux/scripts/install.sh b/.circleci/unittest/linux/scripts/install.sh index 27dec79251..c945cf9609 100755 --- a/.circleci/unittest/linux/scripts/install.sh +++ b/.circleci/unittest/linux/scripts/install.sh @@ -38,6 +38,7 @@ conda install -y -c "pytorch-${UPLOAD_CHANNEL}" pytorch ${cudatoolkit} # 2. Install torchaudio printf "* Installing torchaudio\n" +git submodule update --init --recursive BUILD_SOX=1 python setup.py install # 3. Install Test tools diff --git a/packaging/build_conda.sh b/packaging/build_conda.sh index 488d0fe871..090c5fa84e 100755 --- a/packaging/build_conda.sh +++ b/packaging/build_conda.sh @@ -9,4 +9,5 @@ export NO_CUDA_PACKAGE=1 setup_env 0.8.0 export SOURCE_ROOT_DIR="$PWD" setup_conda_pytorch_constraint +git submodule update --init --recursive conda build $CONDA_CHANNEL_FLAGS --no-anaconda-upload --python "$PYTHON_VERSION" packaging/torchaudio diff --git a/packaging/build_wheel.sh b/packaging/build_wheel.sh index ad45a08d59..47cfbdeba3 100755 --- a/packaging/build_wheel.sh +++ b/packaging/build_wheel.sh @@ -11,6 +11,7 @@ setup_wheel_python pip_install numpy future setup_pip_pytorch_version python setup.py clean +git submodule update --init --recursive if [[ "$OSTYPE" == "msys" ]]; then python_tag="$(echo "cp$PYTHON_VERSION" | tr -d '.')" python setup.py bdist_wheel --plat-name win_amd64 --python-tag $python_tag From e6903f5f46fad60e263f739fe96ec5a07222ceb3 Mon Sep 17 00:00:00 2001 From: Vincent Quenneville-Belair Date: Mon, 21 Dec 2020 15:09:21 -0800 Subject: [PATCH 03/18] copy bindings inside torchaudio prototype. build static with torchaudio. deactivate OMP. add test. 
--- build_tools/setup_helpers/extension.py | 49 +++ .../common_utils/__init__.py | 1 + .../common_utils/case_utils.py | 6 + test/torchaudio_unittest/transducer_test.py | 291 ++++++++++++++++++ third_party/warp_transducer/CMakeLists.txt | 142 +++++++++ third_party/warp_transducer/binding.cpp | 47 +++ torchaudio/__init__.py | 2 + torchaudio/extension/__init__.py | 2 + torchaudio/extension/extension.py | 8 + torchaudio/prototype/__init__.py | 1 + torchaudio/prototype/transducer.py | 147 +++++++++ 11 files changed, 696 insertions(+) create mode 100644 test/torchaudio_unittest/transducer_test.py create mode 100755 third_party/warp_transducer/CMakeLists.txt create mode 100644 third_party/warp_transducer/binding.cpp create mode 100644 torchaudio/prototype/__init__.py create mode 100644 torchaudio/prototype/transducer.py diff --git a/build_tools/setup_helpers/extension.py b/build_tools/setup_helpers/extension.py index 8e6cd337cb..363c810660 100644 --- a/build_tools/setup_helpers/extension.py +++ b/build_tools/setup_helpers/extension.py @@ -132,6 +132,7 @@ def get_ext_modules(debug=False): extra_objects=_get_extra_objects(), extra_link_args=_get_ela(debug), ), + _get_transducer_module(), ] @@ -139,4 +140,52 @@ class BuildExtension(TorchBuildExtension): def build_extension(self, ext): if ext.name == _EXT_NAME and _BUILD_SOX: _build_third_party() + if ext.name == _TRANSDUCER_NAME: + _build_transducer() super().build_extension(ext) + + +_TRANSDUCER_NAME = '_warp_transducer' +_TP_TRANSDUCER_BASE_DIR = _ROOT_DIR / 'third_party' / 'warp_transducer' + + +def _build_transducer(): + build_dir = str(_TP_TRANSDUCER_BASE_DIR / 'submodule' / 'build') + os.makedirs(build_dir, exist_ok=True) + subprocess.run( + args=['cmake', str(_TP_TRANSDUCER_BASE_DIR), '-DWITH_OMP=OFF'], + cwd=build_dir, + check=True, + ) + subprocess.run( + args=['cmake', '--build', '.'], + cwd=build_dir, + check=True, + ) + + +def _get_transducer_module(): + extra_compile_args = [ + '-fPIC', + '-std=c++14', + ] + + librairies = ['warprnnt'] + + source_paths = [ + _TP_TRANSDUCER_BASE_DIR / 'binding.cpp', + _TP_TRANSDUCER_BASE_DIR / 'submodule' / 'pytorch_binding' / 'src' / 'binding.cpp', + ] + build_path = _TP_TRANSDUCER_BASE_DIR / 'submodule' / 'build' + include_path = _TP_TRANSDUCER_BASE_DIR / 'submodule' / 'include' + + return CppExtension( + name=_TRANSDUCER_NAME, + sources=[os.path.realpath(path) for path in source_paths], + libraries=librairies, + include_dirs=[os.path.realpath(include_path)], + library_dirs=[os.path.realpath(build_path)], + extra_compile_args=extra_compile_args, + extra_objects=[str(build_path / f'lib{lib}.a') for lib in librairies], + extra_link_args=['-Wl,-rpath,' + os.path.realpath(build_path)], + ) diff --git a/test/torchaudio_unittest/common_utils/__init__.py b/test/torchaudio_unittest/common_utils/__init__.py index 105a054864..33379a0a15 100644 --- a/test/torchaudio_unittest/common_utils/__init__.py +++ b/test/torchaudio_unittest/common_utils/__init__.py @@ -16,6 +16,7 @@ skipIfNoModule, skipIfNoExtension, skipIfNoSoxBackend, + skipIfNoTransducer, ) from .wav_utils import ( get_wav_data, diff --git a/test/torchaudio_unittest/common_utils/case_utils.py b/test/torchaudio_unittest/common_utils/case_utils.py index 2e0a17b5da..c2583cc721 100644 --- a/test/torchaudio_unittest/common_utils/case_utils.py +++ b/test/torchaudio_unittest/common_utils/case_utils.py @@ -75,3 +75,9 @@ def skipIfNoExtension(test_item): if 'TORCHAUDIO_TEST_FAIL_IF_NO_EXTENSION' in os.environ: raise RuntimeError('torchaudio C++ extension is not 
available.') return unittest.skip('torchaudio C++ extension is not available')(test_item) + + +skipIfNoTransducer = unittest.skipIf( + not is_module_available('_warp_transducer'), + '"_warp_transducer" is not available', +) diff --git a/test/torchaudio_unittest/transducer_test.py b/test/torchaudio_unittest/transducer_test.py new file mode 100644 index 0000000000..f2ac952d73 --- /dev/null +++ b/test/torchaudio_unittest/transducer_test.py @@ -0,0 +1,291 @@ +import numpy as np +import torch + +from torchaudio_unittest import common_utils +from torchaudio.prototype.transducer import RNNTLoss + + +def get_numpy_data_B2_T4_U3_D3(dtype=np.float32): + logits = np.array( + [ + 0.065357, + 0.787530, + 0.081592, + 0.529716, + 0.750675, + 0.754135, + 0.609764, + 0.868140, + 0.622532, + 0.668522, + 0.858039, + 0.164539, + 0.989780, + 0.944298, + 0.603168, + 0.946783, + 0.666203, + 0.286882, + 0.094184, + 0.366674, + 0.736168, + 0.166680, + 0.714154, + 0.399400, + 0.535982, + 0.291821, + 0.612642, + 0.324241, + 0.800764, + 0.524106, + 0.779195, + 0.183314, + 0.113745, + 0.240222, + 0.339470, + 0.134160, + 0.505562, + 0.051597, + 0.640290, + 0.430733, + 0.829473, + 0.177467, + 0.320700, + 0.042883, + 0.302803, + 0.675178, + 0.569537, + 0.558474, + 0.083132, + 0.060165, + 0.107958, + 0.748615, + 0.943918, + 0.486356, + 0.418199, + 0.652408, + 0.024243, + 0.134582, + 0.366342, + 0.295830, + 0.923670, + 0.689929, + 0.741898, + 0.250005, + 0.603430, + 0.987289, + 0.592606, + 0.884672, + 0.543450, + 0.660770, + 0.377128, + 0.358021, + ], + dtype=dtype, + ).reshape(2, 4, 3, 3) + + targets = np.array([[1, 2], [1, 1]], dtype=np.int32) + src_lengths = np.array([4, 4], dtype=np.int32) + tgt_lengths = np.array([2, 2], dtype=np.int32) + + blank = 0 + + ref_costs = np.array([4.2806528590890736, 3.9384369822503591], dtype=dtype) + + ref_gradients = np.array( + [ + -0.186844, + -0.062555, + 0.249399, + -0.203377, + 0.202399, + 0.000977, + -0.141016, + 0.079123, + 0.061893, + -0.011552, + -0.081280, + 0.092832, + -0.154257, + 0.229433, + -0.075176, + -0.246593, + 0.146405, + 0.100188, + -0.012918, + -0.061593, + 0.074512, + -0.055986, + 0.219831, + -0.163845, + -0.497627, + 0.209240, + 0.288387, + 0.013605, + -0.030220, + 0.016615, + 0.113925, + 0.062781, + -0.176706, + -0.667078, + 0.367659, + 0.299419, + -0.356344, + -0.055347, + 0.411691, + -0.096922, + 0.029459, + 0.067463, + -0.063518, + 0.027654, + 0.035863, + -0.154499, + -0.073942, + 0.228441, + -0.166790, + -0.000088, + 0.166878, + -0.172370, + 0.105565, + 0.066804, + 0.023875, + -0.118256, + 0.094381, + -0.104707, + -0.108934, + 0.213642, + -0.369844, + 0.180118, + 0.189726, + 0.025714, + -0.079462, + 0.053748, + 0.122328, + -0.238789, + 0.116460, + -0.598687, + 0.302203, + 0.296484, + ], + dtype=dtype, + ).reshape(2, 4, 3, 3) + + data = { + "logits": logits, + "targets": targets, + "src_lengths": src_lengths, + "tgt_lengths": tgt_lengths, + "blank": blank, + } + + return data, ref_costs, ref_gradients + + +def numpy_to_torch(data, device, requires_grad=True): + + logits = torch.from_numpy(data["logits"]) + targets = torch.from_numpy(data["targets"]) + src_lengths = torch.from_numpy(data["src_lengths"]) + tgt_lengths = torch.from_numpy(data["tgt_lengths"]) + + logits.requires_grad_(requires_grad) + + logits = logits.to(device) + + def grad_hook(grad): + logits.saved_grad = grad.clone() + + logits.register_hook(grad_hook) + + data["logits"] = logits + data["src_lengths"] = src_lengths + data["tgt_lengths"] = tgt_lengths + data["targets"] = targets + + return 
data + + +def compute_with_pytorch_transducer(data): + costs = RNNTLoss(blank=data["blank"], reduction="none")( + acts=data["logits_sparse"] if "logits_sparse" in data else data["logits"], + labels=data["targets"], + act_lens=data["src_lengths"], + label_lens=data["tgt_lengths"], + ) + + loss = torch.sum(costs) + loss.backward() + costs = costs.cpu().data.numpy() + gradients = data["logits"].saved_grad.cpu().data.numpy() + return costs, gradients + + +class TransducerTester: + def test_basic_backward(self): + # Test if example provided in README runs + # https://github.com/HawkAaron/warp-transducer + + rnnt_loss = RNNTLoss() + + acts = torch.FloatTensor( + [ + [ + [ + [0.1, 0.6, 0.1, 0.1, 0.1], + [0.1, 0.1, 0.6, 0.1, 0.1], + [0.1, 0.1, 0.2, 0.8, 0.1], + ], + [ + [0.1, 0.6, 0.1, 0.1, 0.1], + [0.1, 0.1, 0.2, 0.1, 0.1], + [0.7, 0.1, 0.2, 0.1, 0.1], + ], + ] + ] + ) + labels = torch.IntTensor([[1, 2]]) + act_length = torch.IntTensor([2]) + label_length = torch.IntTensor([2]) + + acts = acts.to(self.device) + labels = labels.to(self.device) + act_length = act_length.to(self.device) + label_length = label_length.to(self.device) + + acts.requires_grad_(True) + + loss = rnnt_loss(acts, labels, act_length, label_length) + loss.backward() + + def _test_costs_and_gradients( + self, data, ref_costs, ref_gradients, atol=1e-6, rtol=1e-2 + ): + logits_shape = data["logits"].shape + costs, gradients = compute_with_pytorch_transducer(data=data) + np.testing.assert_allclose(costs, ref_costs, atol=atol, rtol=rtol) + self.assertEqual(logits_shape, gradients.shape) + if not np.allclose(gradients, ref_gradients, atol=atol, rtol=rtol): + for b in range(len(gradients)): + T = data["src_lengths"][b] + U = data["tgt_lengths"][b] + for t in range(gradients.shape[1]): + for u in range(gradients.shape[2]): + np.testing.assert_allclose( + gradients[b, t, u], + ref_gradients[b, t, u], + atol=atol, + rtol=rtol, + err_msg=f"failed on b={b}, t={t}/T={T}, u={u}/U={U}", + ) + + def test_costs_and_gradients_B2_T4_U3_D3_fp32(self): + data, ref_costs, ref_gradients = get_numpy_data_B2_T4_U3_D3(dtype=np.float32) + data = numpy_to_torch(data=data, device=self.device, requires_grad=True) + self._test_costs_and_gradients( + data=data, ref_costs=ref_costs, ref_gradients=ref_gradients + ) + + +@common_utils.skipIfNoTransducer +class CPUTransducerTester(TransducerTester, common_utils.PytorchTestCase): + device = "cpu" diff --git a/third_party/warp_transducer/CMakeLists.txt b/third_party/warp_transducer/CMakeLists.txt new file mode 100755 index 0000000000..deb8d78065 --- /dev/null +++ b/third_party/warp_transducer/CMakeLists.txt @@ -0,0 +1,142 @@ +# Modified from HawkAaron/warp-transducer/CMakeLists.txt to build statically + +IF (APPLE) + cmake_minimum_required(VERSION 3.4) +ELSE() + cmake_minimum_required(VERSION 2.8) +ENDIF() + +project(rnnt_release) + +IF (NOT APPLE) + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -O2") +ENDIF() + +IF (APPLE) + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++11 -O2") + add_definitions(-DAPPLE) +ENDIF() + +include_directories(submodule/include) + +FIND_PACKAGE(CUDA) +MESSAGE(STATUS "cuda found ${CUDA_FOUND}") + +option(USE_NAIVE_KERNEL "use naive alpha-beta kernel" OFF) +option(DEBUG_TIME "output kernel time" OFF) +option(DEBUG_KERNEL "output alpha beta" OFF) +if (USE_NAIVE_KERNEL) + add_definitions(-DUSE_NAIVE_KERNEL) +endif() +if (DEBUG_TIME) + add_definitions(-DDEBUG_TIME) +endif() +if (DEBUG_KERNEL) + add_definitions(-DDEBUG_KERNEL) +endif() + +option(WITH_GPU "compile warp-rnnt with cuda." 
OFF) +option(WITH_OMP "compile warp-rnnt with openmp." ON) + +set(CMAKE_POSITION_INDEPENDENT_CODE ON) +set(CUDA_NVCC_FLAGS "${CUDA_NVCC_FLAGS} -Xcompiler -fPIC") + +if(NOT WITH_OMP) + add_definitions(-DRNNT_DISABLE_OMP) +endif() + +if (WITH_OMP) + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fopenmp") + set(CUDA_NVCC_FLAGS "${CUDA_NVCC_FLAGS} -fopenmp") +endif() + +# need to be at least 30 or __shfl_down in reduce wont compile +set(CUDA_NVCC_FLAGS "${CUDA_NVCC_FLAGS} -gencode arch=compute_30,code=sm_30 -O2") +set(CUDA_NVCC_FLAGS "${CUDA_NVCC_FLAGS} -gencode arch=compute_35,code=sm_35") + +set(CUDA_NVCC_FLAGS "${CUDA_NVCC_FLAGS} -gencode arch=compute_50,code=sm_50") +set(CUDA_NVCC_FLAGS "${CUDA_NVCC_FLAGS} -gencode arch=compute_52,code=sm_52") +IF(CMAKE_CXX_COMPILER_VERSION VERSION_GREATER 5) + SET(CUDA_NVCC_FLAGS "${CUDA_NVCC_FLAGS} -D_MWAITXINTRIN_H_INCLUDED -D_FORCE_INLINES") +ENDIF() + +IF (CUDA_VERSION GREATER 7.6) + set(CUDA_NVCC_FLAGS "${CUDA_NVCC_FLAGS} -gencode arch=compute_60,code=sm_60") + set(CUDA_NVCC_FLAGS "${CUDA_NVCC_FLAGS} -gencode arch=compute_61,code=sm_61") + set(CUDA_NVCC_FLAGS "${CUDA_NVCC_FLAGS} -gencode arch=compute_62,code=sm_62") +ENDIF() + +IF (CUDA_VERSION GREATER 8.9) + set(CUDA_NVCC_FLAGS "${CUDA_NVCC_FLAGS} -gencode arch=compute_70,code=sm_70") +ENDIF() + +IF (CUDA_VERSION GREATER 9.9) + set(CUDA_NVCC_FLAGS "${CUDA_NVCC_FLAGS} -gencode arch=compute_75,code=sm_75") +ENDIF() + +if (NOT APPLE) + set(CUDA_NVCC_FLAGS "${CUDA_NVCC_FLAGS} --std=c++11") + set(CUDA_NVCC_FLAGS "${CUDA_NVCC_FLAGS}") +ENDIF() + + +IF (APPLE) + EXEC_PROGRAM(uname ARGS -v OUTPUT_VARIABLE DARWIN_VERSION) + STRING(REGEX MATCH "[0-9]+" DARWIN_VERSION ${DARWIN_VERSION}) + MESSAGE(STATUS "DARWIN_VERSION=${DARWIN_VERSION}") + + #for el capitain have to use rpath + + IF (DARWIN_VERSION LESS 15) + set(CMAKE_SKIP_RPATH TRUE) + ENDIF () + +ELSE() + #always skip for linux + set(CMAKE_SKIP_RPATH TRUE) +ENDIF() + + +IF (WITH_GPU) + + MESSAGE(STATUS "Building static library with GPU support") + + CUDA_ADD_LIBRARY(warprnnt STATIC submodule/src/rnnt_entrypoint.cu) + IF (!Torch_FOUND) + TARGET_LINK_LIBRARIES(warprnnt ${CUDA_curand_LIBRARY}) + ENDIF() + + cuda_add_executable(test_time_gpu submodule/tests/test_time.cu submodule/tests/random.cpp ) + TARGET_LINK_LIBRARIES(test_time_gpu warprnnt ${CUDA_curand_LIBRARY}) + SET_TARGET_PROPERTIES(test_time_gpu PROPERTIES COMPILE_FLAGS "${CMAKE_CXX_FLAGS} --std=c++11") + + cuda_add_executable(test_gpu submodule/tests/test_gpu.cu submodule/tests/random.cpp ) + TARGET_LINK_LIBRARIES(test_gpu warprnnt ${CUDA_curand_LIBRARY}) + SET_TARGET_PROPERTIES(test_gpu PROPERTIES COMPILE_FLAGS "${CMAKE_CXX_FLAGS} --std=c++11") + +ELSE() + MESSAGE(STATUS "Building static library with no GPU support") + + if (NOT APPLE) + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++11 -O2") + ENDIF() + + ADD_LIBRARY(warprnnt STATIC submodule/src/rnnt_entrypoint.cpp) + +ENDIF() + + +add_executable(test_cpu submodule/tests/test_cpu.cpp submodule/tests/random.cpp ) +TARGET_LINK_LIBRARIES(test_cpu warprnnt) +SET_TARGET_PROPERTIES(test_cpu PROPERTIES COMPILE_FLAGS "${CMAKE_CXX_FLAGS} --std=c++11") + +add_executable(test_time submodule/tests/test_time.cpp submodule/tests/random.cpp ) +TARGET_LINK_LIBRARIES(test_time warprnnt) +SET_TARGET_PROPERTIES(test_time PROPERTIES COMPILE_FLAGS "${CMAKE_CXX_FLAGS} --std=c++11") + +INSTALL(TARGETS warprnnt + RUNTIME DESTINATION "bin" + LIBRARY DESTINATION "lib" + ARCHIVE DESTINATION "lib") + +INSTALL(FILES submodule/include/rnnt.h DESTINATION "submodule/include") diff 
--git a/third_party/warp_transducer/binding.cpp b/third_party/warp_transducer/binding.cpp new file mode 100644 index 0000000000..31d3cb6253 --- /dev/null +++ b/third_party/warp_transducer/binding.cpp @@ -0,0 +1,47 @@ +#include +#include + +#include +#include "rnnt.h" + +int cpu_rnnt(torch::Tensor acts, + torch::Tensor labels, + torch::Tensor input_lengths, + torch::Tensor label_lengths, + torch::Tensor costs, + torch::Tensor grads, + int blank_label, + int num_threads); + +int64_t cpu_rnnt_torchbind(torch::Tensor acts, + torch::Tensor labels, + torch::Tensor input_lengths, + torch::Tensor label_lengths, + torch::Tensor costs, + torch::Tensor grads, + int64_t blank_label, + int64_t num_threads) { +return cpu_rnnt(acts, + labels, + input_lengths, + label_lengths, + costs, + grads, + blank_label, + num_threads); +} + +TORCH_LIBRARY(warprnnt_pytorch_warp_rnnt, m) { + m.def("rnnt(Tensor acts," + "Tensor labels," + "Tensor input_lengths," + "Tensor label_lengths," + "Tensor costs," + "Tensor grads," + "int blank_label," + "int num_threads) -> int"); +} + +TORCH_LIBRARY_IMPL(warprnnt_pytorch_warp_rnnt, CPU, m) { + m.impl("rnnt", &cpu_rnnt_torchbind); +} diff --git a/torchaudio/__init__.py b/torchaudio/__init__.py index be46d8a0f6..db8ac84f48 100644 --- a/torchaudio/__init__.py +++ b/torchaudio/__init__.py @@ -1,4 +1,6 @@ from . import extension +from . import prototype + from torchaudio._internal import module_utils as _mod_utils from torchaudio import ( compliance, diff --git a/torchaudio/extension/__init__.py b/torchaudio/extension/__init__.py index d9b6c76fac..53168c9a8b 100644 --- a/torchaudio/extension/__init__.py +++ b/torchaudio/extension/__init__.py @@ -1,7 +1,9 @@ from .extension import ( _init_extension, + _init_transducer_extension, ) _init_extension() +_init_transducer_extension() del _init_extension diff --git a/torchaudio/extension/extension.py b/torchaudio/extension/extension.py index b01ba13e39..2c92fe8cec 100644 --- a/torchaudio/extension/extension.py +++ b/torchaudio/extension/extension.py @@ -14,6 +14,14 @@ def _init_extension(): warnings.warn('torchaudio C++ extension is not available.') +def _init_transducer_extension(): + ext = '_warp_transducer' + if _mod_utils.is_module_available(ext): + _init_script_module(ext) + else: + warnings.warn('{ext} extension is not available.') + + def _init_script_module(module): path = importlib.util.find_spec(module).origin torch.classes.load_library(path) diff --git a/torchaudio/prototype/__init__.py b/torchaudio/prototype/__init__.py new file mode 100644 index 0000000000..955ebd9884 --- /dev/null +++ b/torchaudio/prototype/__init__.py @@ -0,0 +1 @@ +from . 
import transducer diff --git a/torchaudio/prototype/transducer.py b/torchaudio/prototype/transducer.py new file mode 100644 index 0000000000..77d85962b8 --- /dev/null +++ b/torchaudio/prototype/transducer.py @@ -0,0 +1,147 @@ +import torch +from torch.autograd import Function +from torch.nn import Module +from torchaudio._internal import ( + module_utils as _mod_utils, +) + +__all__ = ["rnnt_loss", "RNNTLoss"] + + +class _RNNT(Function): + @staticmethod + def forward(ctx, acts, labels, act_lens, label_lens, blank, reduction): + """ + acts: Tensor of (batch x seqLength x labelLength x outputDim) containing output from network + labels: 2 dimensional Tensor containing all the targets of the batch with zero padded + act_lens: Tensor of size (batch) containing size of each output sequence from the network + label_lens: Tensor of (batch) containing label length of each example + """ + + device = acts.device + certify_inputs(acts, labels, act_lens, label_lens) + + acts = acts.to("cpu") + labels = labels.to("cpu") + act_lens = act_lens.to("cpu") + label_lens = label_lens.to("cpu") + + loss_func = torch.ops.warprnnt_pytorch_warp_rnnt.rnnt + + grads = torch.zeros_like(acts) + minibatch_size = acts.size(0) + costs = torch.zeros(minibatch_size, dtype=acts.dtype) + + loss_func(acts, labels, act_lens, label_lens, costs, grads, blank, 0) + + if reduction in ["sum", "mean"]: + costs = costs.sum().unsqueeze_(-1) + if reduction == "mean": + costs /= minibatch_size + grads /= minibatch_size + + costs = costs.to(device) + ctx.grads = grads.to(device) + + return costs + + @staticmethod + def backward(ctx, grad_output): + grad_output = grad_output.view(-1, 1, 1, 1).to(ctx.grads) + return ctx.grads.mul_(grad_output), None, None, None, None, None + + +@_mod_utils.requires_module('_warp_transducer') +def rnnt_loss(acts, labels, act_lens, label_lens, blank=0, reduction="mean"): + """RNN Transducer Loss + + Args: + acts: Tensor of (batch x seqLength x labelLength x outputDim) containing output from network + labels: 2 dimensional Tensor containing all the targets of the batch with zero padded + act_lens: Tensor of size (batch) containing size of each output sequence from the network + label_lens: Tensor of (batch) containing label length of each example + blank (int, optional): blank label. Default: 0. + reduction (string, optional): Specifies the reduction to apply to the output: + 'none' | 'mean' | 'sum'. 'none': no reduction will be applied, + 'mean': the output losses will be divided by the target lengths and + then the mean over the batch is taken. Default: 'mean' + """ + + # NOTE manually done log_softmax for CPU version, + # log_softmax is computed within GPU version. + acts = torch.nn.functional.log_softmax(acts, -1) + return _RNNT.apply(acts, labels, act_lens, label_lens, blank, reduction) + + +@_mod_utils.requires_module('_warp_transducer') +class RNNTLoss(Module): + """ + Parameters: + blank (int, optional): blank label. Default: 0. + reduction (string, optional): Specifies the reduction to apply to the output: + 'none' | 'mean' | 'sum'. 'none': no reduction will be applied, + 'mean': the output losses will be divided by the target lengths and + then the mean over the batch is taken. 
Default: 'mean' + """ + + def __init__(self, blank=0, reduction="mean"): + super(RNNTLoss, self).__init__() + self.blank = blank + self.reduction = reduction + self.loss = _RNNT.apply + + def forward(self, acts, labels, act_lens, label_lens): + """ + acts: Tensor of (batch x seqLength x labelLength x outputDim) containing output from network + labels: 2 dimensional Tensor containing all the targets of the batch with zero padded + act_lens: Tensor of size (batch) containing size of each output sequence from the network + label_lens: Tensor of (batch) containing label length of each example + """ + + # NOTE manually done log_softmax for CPU version, + # log_softmax is computed within GPU version. + acts = torch.nn.functional.log_softmax(acts, -1) + return self.loss(acts, labels, act_lens, label_lens, self.blank, self.reduction) + + +def check_type(var, t, name): + if var.dtype is not t: + raise TypeError("{} must be {}".format(name, t)) + + +def check_contiguous(var, name): + if not var.is_contiguous(): + raise ValueError("{} must be contiguous".format(name)) + + +def check_dim(var, dim, name): + if len(var.shape) != dim: + raise ValueError("{} must be {}D".format(name, dim)) + + +def certify_inputs(log_probs, labels, lengths, label_lengths): + # check_type(log_probs, torch.float32, "log_probs") + check_type(labels, torch.int32, "labels") + check_type(label_lengths, torch.int32, "label_lengths") + check_type(lengths, torch.int32, "lengths") + check_contiguous(log_probs, "log_probs") + check_contiguous(labels, "labels") + check_contiguous(label_lengths, "label_lengths") + check_contiguous(lengths, "lengths") + + if lengths.shape[0] != log_probs.shape[0]: + raise ValueError("must have a length per example.") + if label_lengths.shape[0] != log_probs.shape[0]: + raise ValueError("must have a label length per example.") + + check_dim(log_probs, 4, "log_probs") + check_dim(labels, 2, "labels") + check_dim(lengths, 1, "lenghts") + check_dim(label_lengths, 1, "label_lenghts") + max_T = torch.max(lengths) + max_U = torch.max(label_lengths) + T, U = log_probs.shape[1:3] + if T != max_T: + raise ValueError("Input length mismatch") + if U != max_U + 1: + raise ValueError("Output length mismatch") From 3394a551c633f8e9227555d89488be4eeb08ffea Mon Sep 17 00:00:00 2001 From: Vincent Quenneville-Belair Date: Wed, 30 Dec 2020 11:26:20 -0800 Subject: [PATCH 04/18] lint. 
--- build_tools/setup_helpers/extension.py | 6 +-- third_party/warp_transducer/binding.cpp | 42 +++++++++--------- torchaudio/prototype/transducer.py | 59 +++++++++++++++---------- 3 files changed, 59 insertions(+), 48 deletions(-) diff --git a/build_tools/setup_helpers/extension.py b/build_tools/setup_helpers/extension.py index 363c810660..d963d35495 100644 --- a/build_tools/setup_helpers/extension.py +++ b/build_tools/setup_helpers/extension.py @@ -170,7 +170,7 @@ def _get_transducer_module(): '-std=c++14', ] - librairies = ['warprnnt'] + libraries = ['warprnnt'] source_paths = [ _TP_TRANSDUCER_BASE_DIR / 'binding.cpp', @@ -182,10 +182,10 @@ def _get_transducer_module(): return CppExtension( name=_TRANSDUCER_NAME, sources=[os.path.realpath(path) for path in source_paths], - libraries=librairies, + libraries=libraries, include_dirs=[os.path.realpath(include_path)], library_dirs=[os.path.realpath(build_path)], extra_compile_args=extra_compile_args, - extra_objects=[str(build_path / f'lib{lib}.a') for lib in librairies], + extra_objects=[str(build_path / f'lib{lib}.a') for lib in libraries], extra_link_args=['-Wl,-rpath,' + os.path.realpath(build_path)], ) diff --git a/third_party/warp_transducer/binding.cpp b/third_party/warp_transducer/binding.cpp index 31d3cb6253..3280e264de 100644 --- a/third_party/warp_transducer/binding.cpp +++ b/third_party/warp_transducer/binding.cpp @@ -5,30 +5,30 @@ #include "rnnt.h" int cpu_rnnt(torch::Tensor acts, - torch::Tensor labels, - torch::Tensor input_lengths, - torch::Tensor label_lengths, - torch::Tensor costs, - torch::Tensor grads, - int blank_label, - int num_threads); + torch::Tensor labels, + torch::Tensor input_lengths, + torch::Tensor label_lengths, + torch::Tensor costs, + torch::Tensor grads, + int blank_label, + int num_threads); int64_t cpu_rnnt_torchbind(torch::Tensor acts, - torch::Tensor labels, - torch::Tensor input_lengths, - torch::Tensor label_lengths, - torch::Tensor costs, - torch::Tensor grads, - int64_t blank_label, - int64_t num_threads) { + torch::Tensor labels, + torch::Tensor input_lengths, + torch::Tensor label_lengths, + torch::Tensor costs, + torch::Tensor grads, + int64_t blank_label, + int64_t num_threads) { return cpu_rnnt(acts, - labels, - input_lengths, - label_lengths, - costs, - grads, - blank_label, - num_threads); + labels, + input_lengths, + label_lengths, + costs, + grads, + blank_label, + num_threads); } TORCH_LIBRARY(warprnnt_pytorch_warp_rnnt, m) { diff --git a/torchaudio/prototype/transducer.py b/torchaudio/prototype/transducer.py index 77d85962b8..26e8004b83 100644 --- a/torchaudio/prototype/transducer.py +++ b/torchaudio/prototype/transducer.py @@ -12,10 +12,12 @@ class _RNNT(Function): @staticmethod def forward(ctx, acts, labels, act_lens, label_lens, blank, reduction): """ - acts: Tensor of (batch x seqLength x labelLength x outputDim) containing output from network - labels: 2 dimensional Tensor containing all the targets of the batch with zero padded - act_lens: Tensor of size (batch) containing size of each output sequence from the network - label_lens: Tensor of (batch) containing label length of each example + Args: + acts (Tensor): Tensor of dimension (batch, time, label, class) containing output from network + before applying ``torch.nn.functional.log_softmax``. 
+ labels (Tensor): Tensor of dimension (batch, max label length) containing the labels padded by zero + act_lens (Tensor): Tensor of dimension (batch) containing the length of each output sequence + label_lens (Tensor): Tensor of dimension (batch) containing the length of each output sequence """ device = acts.device @@ -53,18 +55,25 @@ def backward(ctx, grad_output): @_mod_utils.requires_module('_warp_transducer') def rnnt_loss(acts, labels, act_lens, label_lens, blank=0, reduction="mean"): - """RNN Transducer Loss + """Compute the RNN Transducer Loss. + + The RNN Transducer loss (`Graves 2012 `__) extends the CTC loss by defining + a distribution over output sequences of all lengths, and by jointly modelling both input-output and output-output + dependencies. + + The implementation uses `warp-transducer `__. Args: - acts: Tensor of (batch x seqLength x labelLength x outputDim) containing output from network - labels: 2 dimensional Tensor containing all the targets of the batch with zero padded - act_lens: Tensor of size (batch) containing size of each output sequence from the network - label_lens: Tensor of (batch) containing label length of each example - blank (int, optional): blank label. Default: 0. - reduction (string, optional): Specifies the reduction to apply to the output: - 'none' | 'mean' | 'sum'. 'none': no reduction will be applied, - 'mean': the output losses will be divided by the target lengths and - then the mean over the batch is taken. Default: 'mean' + acts (Tensor): Tensor of dimension (batch, time, label, class) containing output from network + before applying ``torch.nn.functional.log_softmax``. + labels (Tensor): Tensor of dimension (batch, max label length) containing the labels padded by zero + act_lens (Tensor): Tensor of dimension (batch) containing the length of each output sequence + label_lens (Tensor): Tensor of dimension (batch) containing the length of each output sequence + blank (int): blank label. (Default: ``0``) + reduction (string): If ``'sum'``, the output losses will be summed. + If ``'mean'``, the output losses will be divided by the target lengths and + then the mean over the batch is taken. If ``'none'``, no reduction will be applied. + (Default: ``'mean'``) """ # NOTE manually done log_softmax for CPU version, @@ -76,12 +85,12 @@ def rnnt_loss(acts, labels, act_lens, label_lens, blank=0, reduction="mean"): @_mod_utils.requires_module('_warp_transducer') class RNNTLoss(Module): """ - Parameters: - blank (int, optional): blank label. Default: 0. - reduction (string, optional): Specifies the reduction to apply to the output: - 'none' | 'mean' | 'sum'. 'none': no reduction will be applied, - 'mean': the output losses will be divided by the target lengths and - then the mean over the batch is taken. Default: 'mean' + Args: + blank (int): blank label. (Default: ``0``) + reduction (string): If ``'sum'``, the output losses will be summed. + If ``'mean'``, the output losses will be divided by the target lengths and + then the mean over the batch is taken. If ``'none'``, no reduction will be applied. 
+ (Default: ``'mean'``) """ def __init__(self, blank=0, reduction="mean"): @@ -92,10 +101,12 @@ def __init__(self, blank=0, reduction="mean"): def forward(self, acts, labels, act_lens, label_lens): """ - acts: Tensor of (batch x seqLength x labelLength x outputDim) containing output from network - labels: 2 dimensional Tensor containing all the targets of the batch with zero padded - act_lens: Tensor of size (batch) containing size of each output sequence from the network - label_lens: Tensor of (batch) containing label length of each example + Args: + acts (Tensor): Tensor of dimension (batch, time, label, class) containing output from network + before applying ``torch.nn.functional.log_softmax``. + labels (Tensor): Tensor of dimension (batch, max label length) containing the labels padded by zero + act_lens (Tensor): Tensor of dimension (batch) containing the length of each output sequence + label_lens (Tensor): Tensor of dimension (batch) containing the length of each output sequence """ # NOTE manually done log_softmax for CPU version, From bde1aba0141dc228cf00894afe5e253615d884ea Mon Sep 17 00:00:00 2001 From: Vincent Quenneville-Belair Date: Wed, 30 Dec 2020 12:22:13 -0800 Subject: [PATCH 05/18] shorten cmake file. --- third_party/warp_transducer/CMakeLists.txt | 138 ++++----------------- 1 file changed, 22 insertions(+), 116 deletions(-) diff --git a/third_party/warp_transducer/CMakeLists.txt b/third_party/warp_transducer/CMakeLists.txt index deb8d78065..873a7acb2e 100755 --- a/third_party/warp_transducer/CMakeLists.txt +++ b/third_party/warp_transducer/CMakeLists.txt @@ -1,138 +1,44 @@ -# Modified from HawkAaron/warp-transducer/CMakeLists.txt to build statically - -IF (APPLE) - cmake_minimum_required(VERSION 3.4) +IF(APPLE) + CMAKE_MINIMUM_REQUIRED(VERSION 3.4) ELSE() - cmake_minimum_required(VERSION 2.8) -ENDIF() - -project(rnnt_release) - -IF (NOT APPLE) - set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -O2") -ENDIF() - -IF (APPLE) - set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++11 -O2") - add_definitions(-DAPPLE) -ENDIF() - -include_directories(submodule/include) - -FIND_PACKAGE(CUDA) -MESSAGE(STATUS "cuda found ${CUDA_FOUND}") - -option(USE_NAIVE_KERNEL "use naive alpha-beta kernel" OFF) -option(DEBUG_TIME "output kernel time" OFF) -option(DEBUG_KERNEL "output alpha beta" OFF) -if (USE_NAIVE_KERNEL) - add_definitions(-DUSE_NAIVE_KERNEL) -endif() -if (DEBUG_TIME) - add_definitions(-DDEBUG_TIME) -endif() -if (DEBUG_KERNEL) - add_definitions(-DDEBUG_KERNEL) -endif() - -option(WITH_GPU "compile warp-rnnt with cuda." OFF) -option(WITH_OMP "compile warp-rnnt with openmp." 
ON) - -set(CMAKE_POSITION_INDEPENDENT_CODE ON) -set(CUDA_NVCC_FLAGS "${CUDA_NVCC_FLAGS} -Xcompiler -fPIC") - -if(NOT WITH_OMP) - add_definitions(-DRNNT_DISABLE_OMP) -endif() - -if (WITH_OMP) - set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fopenmp") - set(CUDA_NVCC_FLAGS "${CUDA_NVCC_FLAGS} -fopenmp") -endif() - -# need to be at least 30 or __shfl_down in reduce wont compile -set(CUDA_NVCC_FLAGS "${CUDA_NVCC_FLAGS} -gencode arch=compute_30,code=sm_30 -O2") -set(CUDA_NVCC_FLAGS "${CUDA_NVCC_FLAGS} -gencode arch=compute_35,code=sm_35") - -set(CUDA_NVCC_FLAGS "${CUDA_NVCC_FLAGS} -gencode arch=compute_50,code=sm_50") -set(CUDA_NVCC_FLAGS "${CUDA_NVCC_FLAGS} -gencode arch=compute_52,code=sm_52") -IF(CMAKE_CXX_COMPILER_VERSION VERSION_GREATER 5) - SET(CUDA_NVCC_FLAGS "${CUDA_NVCC_FLAGS} -D_MWAITXINTRIN_H_INCLUDED -D_FORCE_INLINES") + CMAKE_MINIMUM_REQUIRED(VERSION 2.8) ENDIF() -IF (CUDA_VERSION GREATER 7.6) - set(CUDA_NVCC_FLAGS "${CUDA_NVCC_FLAGS} -gencode arch=compute_60,code=sm_60") - set(CUDA_NVCC_FLAGS "${CUDA_NVCC_FLAGS} -gencode arch=compute_61,code=sm_61") - set(CUDA_NVCC_FLAGS "${CUDA_NVCC_FLAGS} -gencode arch=compute_62,code=sm_62") -ENDIF() +PROJECT(rnnt_release) -IF (CUDA_VERSION GREATER 8.9) - set(CUDA_NVCC_FLAGS "${CUDA_NVCC_FLAGS} -gencode arch=compute_70,code=sm_70") +IF(APPLE) + SET(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++11 -O2") + ADD_DEFINITIONS(-DAPPLE) +ELSE() + SET(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -O2") ENDIF() -IF (CUDA_VERSION GREATER 9.9) - set(CUDA_NVCC_FLAGS "${CUDA_NVCC_FLAGS} -gencode arch=compute_75,code=sm_75") -ENDIF() +INCLUDE_DIRECTORIES(submodule/include) -if (NOT APPLE) - set(CUDA_NVCC_FLAGS "${CUDA_NVCC_FLAGS} --std=c++11") - set(CUDA_NVCC_FLAGS "${CUDA_NVCC_FLAGS}") -ENDIF() +SET(CMAKE_POSITION_INDEPENDENT_CODE ON) +ADD_DEFINITIONS(-DRNNT_DISABLE_OMP) -IF (APPLE) +IF(APPLE) EXEC_PROGRAM(uname ARGS -v OUTPUT_VARIABLE DARWIN_VERSION) STRING(REGEX MATCH "[0-9]+" DARWIN_VERSION ${DARWIN_VERSION}) MESSAGE(STATUS "DARWIN_VERSION=${DARWIN_VERSION}") - #for el capitain have to use rpath - - IF (DARWIN_VERSION LESS 15) - set(CMAKE_SKIP_RPATH TRUE) - ENDIF () - -ELSE() - #always skip for linux - set(CMAKE_SKIP_RPATH TRUE) -ENDIF() - - -IF (WITH_GPU) - - MESSAGE(STATUS "Building static library with GPU support") - - CUDA_ADD_LIBRARY(warprnnt STATIC submodule/src/rnnt_entrypoint.cu) - IF (!Torch_FOUND) - TARGET_LINK_LIBRARIES(warprnnt ${CUDA_curand_LIBRARY}) + # for el capitain have to use rpath + IF(DARWIN_VERSION LESS 15) + SET(CMAKE_SKIP_RPATH TRUE) ENDIF() - cuda_add_executable(test_time_gpu submodule/tests/test_time.cu submodule/tests/random.cpp ) - TARGET_LINK_LIBRARIES(test_time_gpu warprnnt ${CUDA_curand_LIBRARY}) - SET_TARGET_PROPERTIES(test_time_gpu PROPERTIES COMPILE_FLAGS "${CMAKE_CXX_FLAGS} --std=c++11") - - cuda_add_executable(test_gpu submodule/tests/test_gpu.cu submodule/tests/random.cpp ) - TARGET_LINK_LIBRARIES(test_gpu warprnnt ${CUDA_curand_LIBRARY}) - SET_TARGET_PROPERTIES(test_gpu PROPERTIES COMPILE_FLAGS "${CMAKE_CXX_FLAGS} --std=c++11") - ELSE() - MESSAGE(STATUS "Building static library with no GPU support") - - if (NOT APPLE) - set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++11 -O2") - ENDIF() - - ADD_LIBRARY(warprnnt STATIC submodule/src/rnnt_entrypoint.cpp) - + # always skip for linux + SET(CMAKE_SKIP_RPATH TRUE) ENDIF() +IF(NOT APPLE) + SET(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++11 -O2") +ENDIF() -add_executable(test_cpu submodule/tests/test_cpu.cpp submodule/tests/random.cpp ) -TARGET_LINK_LIBRARIES(test_cpu warprnnt) 
-SET_TARGET_PROPERTIES(test_cpu PROPERTIES COMPILE_FLAGS "${CMAKE_CXX_FLAGS} --std=c++11") - -add_executable(test_time submodule/tests/test_time.cpp submodule/tests/random.cpp ) -TARGET_LINK_LIBRARIES(test_time warprnnt) -SET_TARGET_PROPERTIES(test_time PROPERTIES COMPILE_FLAGS "${CMAKE_CXX_FLAGS} --std=c++11") +ADD_LIBRARY(warprnnt STATIC submodule/src/rnnt_entrypoint.cpp) INSTALL(TARGETS warprnnt RUNTIME DESTINATION "bin" From 0e637564a9e1c1e3f29638a065812e75b80a7698 Mon Sep 17 00:00:00 2001 From: Vincent Quenneville-Belair Date: Thu, 31 Dec 2020 11:06:35 -0800 Subject: [PATCH 06/18] build only one extension. --- build_tools/setup_helpers/extension.py | 53 ++------------- third_party/CMakeLists.txt | 2 + third_party/warp_transducer/binding.cpp | 47 -------------- torchaudio/csrc/transducer.cpp | 85 +++++++++++++++++++++++++ 4 files changed, 91 insertions(+), 96 deletions(-) delete mode 100644 third_party/warp_transducer/binding.cpp create mode 100644 torchaudio/csrc/transducer.cpp diff --git a/build_tools/setup_helpers/extension.py b/build_tools/setup_helpers/extension.py index d963d35495..27d89150ce 100644 --- a/build_tools/setup_helpers/extension.py +++ b/build_tools/setup_helpers/extension.py @@ -18,6 +18,8 @@ _CSRC_DIR = _ROOT_DIR / 'torchaudio' / 'csrc' _TP_BASE_DIR = _ROOT_DIR / 'third_party' _TP_INSTALL_DIR = _TP_BASE_DIR / 'install' +_TRANSDUCER_BUILD_DIR = _TP_BASE_DIR / 'build' / 'warp_transducer' +_TRANSDUCER_BASE_DIR = _TP_BASE_DIR / 'warp_transducer' / 'submodule' def _get_build_sox(): @@ -64,6 +66,7 @@ def _get_srcs(): def _get_include_dirs(): dirs = [ str(_ROOT_DIR), + str(_TRANSDUCER_BASE_DIR / 'include'), ] if _BUILD_SOX: dirs.append(str(_TP_INSTALL_DIR / 'include')) @@ -94,6 +97,7 @@ def _get_extra_objects(): ] for lib in libs: objs.append(str(_TP_INSTALL_DIR / 'lib' / lib)) + objs.append(str(_TRANSDUCER_BUILD_DIR / 'libwarprnnt.a')) return objs @@ -132,7 +136,6 @@ def get_ext_modules(debug=False): extra_objects=_get_extra_objects(), extra_link_args=_get_ela(debug), ), - _get_transducer_module(), ] @@ -140,52 +143,4 @@ class BuildExtension(TorchBuildExtension): def build_extension(self, ext): if ext.name == _EXT_NAME and _BUILD_SOX: _build_third_party() - if ext.name == _TRANSDUCER_NAME: - _build_transducer() super().build_extension(ext) - - -_TRANSDUCER_NAME = '_warp_transducer' -_TP_TRANSDUCER_BASE_DIR = _ROOT_DIR / 'third_party' / 'warp_transducer' - - -def _build_transducer(): - build_dir = str(_TP_TRANSDUCER_BASE_DIR / 'submodule' / 'build') - os.makedirs(build_dir, exist_ok=True) - subprocess.run( - args=['cmake', str(_TP_TRANSDUCER_BASE_DIR), '-DWITH_OMP=OFF'], - cwd=build_dir, - check=True, - ) - subprocess.run( - args=['cmake', '--build', '.'], - cwd=build_dir, - check=True, - ) - - -def _get_transducer_module(): - extra_compile_args = [ - '-fPIC', - '-std=c++14', - ] - - libraries = ['warprnnt'] - - source_paths = [ - _TP_TRANSDUCER_BASE_DIR / 'binding.cpp', - _TP_TRANSDUCER_BASE_DIR / 'submodule' / 'pytorch_binding' / 'src' / 'binding.cpp', - ] - build_path = _TP_TRANSDUCER_BASE_DIR / 'submodule' / 'build' - include_path = _TP_TRANSDUCER_BASE_DIR / 'submodule' / 'include' - - return CppExtension( - name=_TRANSDUCER_NAME, - sources=[os.path.realpath(path) for path in source_paths], - libraries=libraries, - include_dirs=[os.path.realpath(include_path)], - library_dirs=[os.path.realpath(build_path)], - extra_compile_args=extra_compile_args, - extra_objects=[str(build_path / f'lib{lib}.a') for lib in libraries], - extra_link_args=['-Wl,-rpath,' + 
os.path.realpath(build_path)], - ) diff --git a/third_party/CMakeLists.txt b/third_party/CMakeLists.txt index ad0eac582b..187ebfa260 100644 --- a/third_party/CMakeLists.txt +++ b/third_party/CMakeLists.txt @@ -88,3 +88,5 @@ ExternalProject_Add(libsox # See https://github.com/pytorch/audio/pull/1026 CONFIGURE_COMMAND ${CMAKE_CURRENT_SOURCE_DIR}/build_codec_helper.sh ${CMAKE_CURRENT_SOURCE_DIR}/src/libsox/configure ${COMMON_ARGS} --with-lame --with-flac --with-mad --with-oggvorbis --without-alsa --without-coreaudio --without-png --without-oss --without-sndfile --with-opus --with-amrwb --with-amrnb --disable-openmp ) + +add_subdirectory(warp_transducer) diff --git a/third_party/warp_transducer/binding.cpp b/third_party/warp_transducer/binding.cpp deleted file mode 100644 index 3280e264de..0000000000 --- a/third_party/warp_transducer/binding.cpp +++ /dev/null @@ -1,47 +0,0 @@ -#include -#include - -#include -#include "rnnt.h" - -int cpu_rnnt(torch::Tensor acts, - torch::Tensor labels, - torch::Tensor input_lengths, - torch::Tensor label_lengths, - torch::Tensor costs, - torch::Tensor grads, - int blank_label, - int num_threads); - -int64_t cpu_rnnt_torchbind(torch::Tensor acts, - torch::Tensor labels, - torch::Tensor input_lengths, - torch::Tensor label_lengths, - torch::Tensor costs, - torch::Tensor grads, - int64_t blank_label, - int64_t num_threads) { -return cpu_rnnt(acts, - labels, - input_lengths, - label_lengths, - costs, - grads, - blank_label, - num_threads); -} - -TORCH_LIBRARY(warprnnt_pytorch_warp_rnnt, m) { - m.def("rnnt(Tensor acts," - "Tensor labels," - "Tensor input_lengths," - "Tensor label_lengths," - "Tensor costs," - "Tensor grads," - "int blank_label," - "int num_threads) -> int"); -} - -TORCH_LIBRARY_IMPL(warprnnt_pytorch_warp_rnnt, CPU, m) { - m.impl("rnnt", &cpu_rnnt_torchbind); -} diff --git a/torchaudio/csrc/transducer.cpp b/torchaudio/csrc/transducer.cpp new file mode 100644 index 0000000000..07621b3022 --- /dev/null +++ b/torchaudio/csrc/transducer.cpp @@ -0,0 +1,85 @@ +#include +#include + +#include +#include "rnnt.h" + +int64_t cpu_rnnt(torch::Tensor acts, + torch::Tensor labels, + torch::Tensor input_lengths, + torch::Tensor label_lengths, + torch::Tensor costs, + torch::Tensor grads, + int64_t blank_label, + int64_t num_threads) { + + int maxT = acts.size(1); + int maxU = acts.size(2); + int minibatch_size = acts.size(0); + int alphabet_size = acts.size(3); + + rnntOptions options; + memset(&options, 0, sizeof(options)); + options.maxT = maxT; + options.maxU = maxU; + options.blank_label = blank_label; + options.batch_first = true; + options.loc = RNNT_CPU; + options.num_threads = num_threads; + + // have to use at least one + options.num_threads = std::max(options.num_threads, (unsigned int) 1); + + size_t cpu_size_bytes = 0; + switch (acts.type().scalarType()) { + case torch::ScalarType::Float: + { + get_workspace_size(maxT, maxU, minibatch_size, + false, &cpu_size_bytes); + + float* cpu_workspace = (float*) new unsigned char[cpu_size_bytes]; + compute_rnnt_loss(acts.data(), grads.data(), + labels.data(), label_lengths.data(), + input_lengths.data(), alphabet_size, + minibatch_size, costs.data(), + cpu_workspace, options); + + delete cpu_workspace; + return 0; + } + case torch::ScalarType::Double: + { + get_workspace_size(maxT, maxU, minibatch_size, + false, &cpu_size_bytes, + sizeof(double)); + + double* cpu_workspace = (double*) new unsigned char[cpu_size_bytes]; + compute_rnnt_loss_fp64(acts.data(), grads.data(), + labels.data(), label_lengths.data(), + 
input_lengths.data(), alphabet_size, + minibatch_size, costs.data(), + cpu_workspace, options); + + delete cpu_workspace; + return 0; + } + default: + std::cerr << __FILE__ << ':' << __LINE__ << ": " << "unsupported data type" << std::endl; + } + return -1; +} + +TORCH_LIBRARY(warprnnt_pytorch_warp_rnnt, m) { + m.def("rnnt(Tensor acts," + "Tensor labels," + "Tensor input_lengths," + "Tensor label_lengths," + "Tensor costs," + "Tensor grads," + "int blank_label," + "int num_threads) -> int"); +} + +TORCH_LIBRARY_IMPL(warprnnt_pytorch_warp_rnnt, CPU, m) { + m.impl("rnnt", &cpu_rnnt); +} From 8dab3737a5aa58df7f1fa3a8a394026951ac3ba8 Mon Sep 17 00:00:00 2001 From: Vincent Quenneville-Belair Date: Thu, 31 Dec 2020 11:30:18 -0800 Subject: [PATCH 07/18] move submodule update. --- .circleci/unittest/linux/scripts/install.sh | 1 - .circleci/unittest/linux/scripts/setup_env.sh | 1 + packaging/build_conda.sh | 1 - packaging/build_wheel.sh | 1 - packaging/pkg_helpers.bash | 1 + 5 files changed, 2 insertions(+), 3 deletions(-) diff --git a/.circleci/unittest/linux/scripts/install.sh b/.circleci/unittest/linux/scripts/install.sh index c945cf9609..27dec79251 100755 --- a/.circleci/unittest/linux/scripts/install.sh +++ b/.circleci/unittest/linux/scripts/install.sh @@ -38,7 +38,6 @@ conda install -y -c "pytorch-${UPLOAD_CHANNEL}" pytorch ${cudatoolkit} # 2. Install torchaudio printf "* Installing torchaudio\n" -git submodule update --init --recursive BUILD_SOX=1 python setup.py install # 3. Install Test tools diff --git a/.circleci/unittest/linux/scripts/setup_env.sh b/.circleci/unittest/linux/scripts/setup_env.sh index 26292a00d5..f56b211ac9 100755 --- a/.circleci/unittest/linux/scripts/setup_env.sh +++ b/.circleci/unittest/linux/scripts/setup_env.sh @@ -43,6 +43,7 @@ conda activate "${env_dir}" pip --quiet install cmake ninja # 4. Buld codecs +git submodule update --init --recursive mkdir -p third_party/build ( cd third_party/build diff --git a/packaging/build_conda.sh b/packaging/build_conda.sh index 090c5fa84e..488d0fe871 100755 --- a/packaging/build_conda.sh +++ b/packaging/build_conda.sh @@ -9,5 +9,4 @@ export NO_CUDA_PACKAGE=1 setup_env 0.8.0 export SOURCE_ROOT_DIR="$PWD" setup_conda_pytorch_constraint -git submodule update --init --recursive conda build $CONDA_CHANNEL_FLAGS --no-anaconda-upload --python "$PYTHON_VERSION" packaging/torchaudio diff --git a/packaging/build_wheel.sh b/packaging/build_wheel.sh index 47cfbdeba3..ad45a08d59 100755 --- a/packaging/build_wheel.sh +++ b/packaging/build_wheel.sh @@ -11,7 +11,6 @@ setup_wheel_python pip_install numpy future setup_pip_pytorch_version python setup.py clean -git submodule update --init --recursive if [[ "$OSTYPE" == "msys" ]]; then python_tag="$(echo "cp$PYTHON_VERSION" | tr -d '.')" python setup.py bdist_wheel --plat-name win_amd64 --python-tag $python_tag diff --git a/packaging/pkg_helpers.bash b/packaging/pkg_helpers.bash index 3316f789e2..635bfa3a14 100644 --- a/packaging/pkg_helpers.bash +++ b/packaging/pkg_helpers.bash @@ -103,6 +103,7 @@ setup_macos() { # # Usage: setup_env 0.2.0 setup_env() { + git submodule update --init --recursive setup_cuda setup_build_version "$1" setup_macos From 986e54f4ae473c4d40c649f792c0c5b83a9a1bd6 Mon Sep 17 00:00:00 2001 From: Vincent Quenneville-Belair Date: Thu, 31 Dec 2020 11:40:38 -0800 Subject: [PATCH 08/18] print dir. 
--- third_party/warp_transducer/CMakeLists.txt | 1 + 1 file changed, 1 insertion(+) diff --git a/third_party/warp_transducer/CMakeLists.txt b/third_party/warp_transducer/CMakeLists.txt index 873a7acb2e..2703bfa532 100755 --- a/third_party/warp_transducer/CMakeLists.txt +++ b/third_party/warp_transducer/CMakeLists.txt @@ -1,3 +1,4 @@ +MESSAGE("path to cmake current source dir: ${CMAKE_CURRENT_SOURCE_DIR}") IF(APPLE) CMAKE_MINIMUM_REQUIRED(VERSION 3.4) ELSE() From 25e07ac343884c2a95bfab44595ffe6adf712b48 Mon Sep 17 00:00:00 2001 From: Vincent Quenneville-Belair Date: Thu, 31 Dec 2020 13:11:53 -0800 Subject: [PATCH 09/18] remove runtime destination, and use archives. --- third_party/warp_transducer/CMakeLists.txt | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/third_party/warp_transducer/CMakeLists.txt b/third_party/warp_transducer/CMakeLists.txt index 2703bfa532..9cac5476ac 100755 --- a/third_party/warp_transducer/CMakeLists.txt +++ b/third_party/warp_transducer/CMakeLists.txt @@ -42,8 +42,7 @@ ENDIF() ADD_LIBRARY(warprnnt STATIC submodule/src/rnnt_entrypoint.cpp) INSTALL(TARGETS warprnnt - RUNTIME DESTINATION "bin" LIBRARY DESTINATION "lib" - ARCHIVE DESTINATION "lib") + ARCHIVE DESTINATION "archives") INSTALL(FILES submodule/include/rnnt.h DESTINATION "submodule/include") From 2674bea248b9b61bbbd922227242672d56b4521c Mon Sep 17 00:00:00 2001 From: Vincent Quenneville-Belair Date: Thu, 31 Dec 2020 13:22:11 -0800 Subject: [PATCH 10/18] rename. --- .gitmodules | 2 +- build_tools/setup_helpers/extension.py | 4 ++-- third_party/CMakeLists.txt | 2 +- third_party/{warp_transducer => transducer}/CMakeLists.txt | 0 third_party/{warp_transducer => transducer}/submodule | 0 5 files changed, 4 insertions(+), 4 deletions(-) rename third_party/{warp_transducer => transducer}/CMakeLists.txt (100%) rename third_party/{warp_transducer => transducer}/submodule (100%) diff --git a/.gitmodules b/.gitmodules index 41b1684ccd..43202ee4ff 100644 --- a/.gitmodules +++ b/.gitmodules @@ -1,4 +1,4 @@ [submodule "third_party/warp_transducer/submodule"] - path = third_party/warp_transducer/submodule + path = third_party/transducer/submodule url = https://github.com/HawkAaron/warp-transducer branch = master diff --git a/build_tools/setup_helpers/extension.py b/build_tools/setup_helpers/extension.py index 27d89150ce..01b768b570 100644 --- a/build_tools/setup_helpers/extension.py +++ b/build_tools/setup_helpers/extension.py @@ -18,8 +18,8 @@ _CSRC_DIR = _ROOT_DIR / 'torchaudio' / 'csrc' _TP_BASE_DIR = _ROOT_DIR / 'third_party' _TP_INSTALL_DIR = _TP_BASE_DIR / 'install' -_TRANSDUCER_BUILD_DIR = _TP_BASE_DIR / 'build' / 'warp_transducer' -_TRANSDUCER_BASE_DIR = _TP_BASE_DIR / 'warp_transducer' / 'submodule' +_TRANSDUCER_BUILD_DIR = _TP_BASE_DIR / 'build' / 'transducer' +_TRANSDUCER_BASE_DIR = _TP_BASE_DIR / 'transducer' / 'submodule' def _get_build_sox(): diff --git a/third_party/CMakeLists.txt b/third_party/CMakeLists.txt index 187ebfa260..b43d9fc594 100644 --- a/third_party/CMakeLists.txt +++ b/third_party/CMakeLists.txt @@ -89,4 +89,4 @@ ExternalProject_Add(libsox CONFIGURE_COMMAND ${CMAKE_CURRENT_SOURCE_DIR}/build_codec_helper.sh ${CMAKE_CURRENT_SOURCE_DIR}/src/libsox/configure ${COMMON_ARGS} --with-lame --with-flac --with-mad --with-oggvorbis --without-alsa --without-coreaudio --without-png --without-oss --without-sndfile --with-opus --with-amrwb --with-amrnb --disable-openmp ) -add_subdirectory(warp_transducer) +add_subdirectory(transducer) diff --git a/third_party/warp_transducer/CMakeLists.txt 
b/third_party/transducer/CMakeLists.txt similarity index 100% rename from third_party/warp_transducer/CMakeLists.txt rename to third_party/transducer/CMakeLists.txt diff --git a/third_party/warp_transducer/submodule b/third_party/transducer/submodule similarity index 100% rename from third_party/warp_transducer/submodule rename to third_party/transducer/submodule From ed0a4c5d2561ab6ffb1948f6732cdf3f45565b45 Mon Sep 17 00:00:00 2001 From: Vincent Quenneville-Belair Date: Thu, 31 Dec 2020 13:38:05 -0800 Subject: [PATCH 11/18] use torchaudio namespace. --- .../common_utils/__init__.py | 1 - .../common_utils/case_utils.py | 6 ---- test/torchaudio_unittest/transducer_test.py | 2 +- torchaudio/csrc/register.cpp | 12 +++++++ torchaudio/csrc/transducer.cpp | 31 ++++++------------- torchaudio/extension/__init__.py | 2 -- torchaudio/extension/extension.py | 8 ----- torchaudio/prototype/transducer.py | 2 +- 8 files changed, 24 insertions(+), 40 deletions(-) diff --git a/test/torchaudio_unittest/common_utils/__init__.py b/test/torchaudio_unittest/common_utils/__init__.py index 33379a0a15..105a054864 100644 --- a/test/torchaudio_unittest/common_utils/__init__.py +++ b/test/torchaudio_unittest/common_utils/__init__.py @@ -16,7 +16,6 @@ skipIfNoModule, skipIfNoExtension, skipIfNoSoxBackend, - skipIfNoTransducer, ) from .wav_utils import ( get_wav_data, diff --git a/test/torchaudio_unittest/common_utils/case_utils.py b/test/torchaudio_unittest/common_utils/case_utils.py index c2583cc721..2e0a17b5da 100644 --- a/test/torchaudio_unittest/common_utils/case_utils.py +++ b/test/torchaudio_unittest/common_utils/case_utils.py @@ -75,9 +75,3 @@ def skipIfNoExtension(test_item): if 'TORCHAUDIO_TEST_FAIL_IF_NO_EXTENSION' in os.environ: raise RuntimeError('torchaudio C++ extension is not available.') return unittest.skip('torchaudio C++ extension is not available')(test_item) - - -skipIfNoTransducer = unittest.skipIf( - not is_module_available('_warp_transducer'), - '"_warp_transducer" is not available', -) diff --git a/test/torchaudio_unittest/transducer_test.py b/test/torchaudio_unittest/transducer_test.py index f2ac952d73..a3fa375259 100644 --- a/test/torchaudio_unittest/transducer_test.py +++ b/test/torchaudio_unittest/transducer_test.py @@ -286,6 +286,6 @@ def test_costs_and_gradients_B2_T4_U3_D3_fp32(self): ) -@common_utils.skipIfNoTransducer +@common_utils.skipIfNoExtension class CPUTransducerTester(TransducerTester, common_utils.PytorchTestCase): device = "cpu" diff --git a/torchaudio/csrc/register.cpp b/torchaudio/csrc/register.cpp index 0eb73f1daf..e977bca883 100644 --- a/torchaudio/csrc/register.cpp +++ b/torchaudio/csrc/register.cpp @@ -77,5 +77,17 @@ TORCH_LIBRARY(torchaudio, m) { m.def( "torchaudio::sox_effects_apply_effects_file", &torchaudio::sox_effects::apply_effects_file); + + ////////////////////////////////////////////////////////////////////////////// + // transducer.cpp + ////////////////////////////////////////////////////////////////////////////// + m.def("rnnt(Tensor acts," + "Tensor labels," + "Tensor input_lengths," + "Tensor label_lengths," + "Tensor costs," + "Tensor grads," + "int blank_label," + "int num_threads) -> int"); } #endif diff --git a/torchaudio/csrc/transducer.cpp b/torchaudio/csrc/transducer.cpp index 07621b3022..04408d2a64 100644 --- a/torchaudio/csrc/transducer.cpp +++ b/torchaudio/csrc/transducer.cpp @@ -4,14 +4,14 @@ #include #include "rnnt.h" -int64_t cpu_rnnt(torch::Tensor acts, - torch::Tensor labels, - torch::Tensor input_lengths, - torch::Tensor label_lengths, - 
torch::Tensor costs, - torch::Tensor grads, - int64_t blank_label, - int64_t num_threads) { +int64_t cpu_rnnt_loss(torch::Tensor acts, + torch::Tensor labels, + torch::Tensor input_lengths, + torch::Tensor label_lengths, + torch::Tensor costs, + torch::Tensor grads, + int64_t blank_label, + int64_t num_threads) { int maxT = acts.size(1); int maxU = acts.size(2); @@ -69,17 +69,6 @@ int64_t cpu_rnnt(torch::Tensor acts, return -1; } -TORCH_LIBRARY(warprnnt_pytorch_warp_rnnt, m) { - m.def("rnnt(Tensor acts," - "Tensor labels," - "Tensor input_lengths," - "Tensor label_lengths," - "Tensor costs," - "Tensor grads," - "int blank_label," - "int num_threads) -> int"); -} - -TORCH_LIBRARY_IMPL(warprnnt_pytorch_warp_rnnt, CPU, m) { - m.impl("rnnt", &cpu_rnnt); +TORCH_LIBRARY_IMPL(torchaudio, CPU, m) { + m.impl("rnnt_loss", &cpu_rnnt_loss); } diff --git a/torchaudio/extension/__init__.py b/torchaudio/extension/__init__.py index 53168c9a8b..d9b6c76fac 100644 --- a/torchaudio/extension/__init__.py +++ b/torchaudio/extension/__init__.py @@ -1,9 +1,7 @@ from .extension import ( _init_extension, - _init_transducer_extension, ) _init_extension() -_init_transducer_extension() del _init_extension diff --git a/torchaudio/extension/extension.py b/torchaudio/extension/extension.py index 2c92fe8cec..b01ba13e39 100644 --- a/torchaudio/extension/extension.py +++ b/torchaudio/extension/extension.py @@ -14,14 +14,6 @@ def _init_extension(): warnings.warn('torchaudio C++ extension is not available.') -def _init_transducer_extension(): - ext = '_warp_transducer' - if _mod_utils.is_module_available(ext): - _init_script_module(ext) - else: - warnings.warn('{ext} extension is not available.') - - def _init_script_module(module): path = importlib.util.find_spec(module).origin torch.classes.load_library(path) diff --git a/torchaudio/prototype/transducer.py b/torchaudio/prototype/transducer.py index 26e8004b83..d1ccd20956 100644 --- a/torchaudio/prototype/transducer.py +++ b/torchaudio/prototype/transducer.py @@ -28,7 +28,7 @@ def forward(ctx, acts, labels, act_lens, label_lens, blank, reduction): act_lens = act_lens.to("cpu") label_lens = label_lens.to("cpu") - loss_func = torch.ops.warprnnt_pytorch_warp_rnnt.rnnt + loss_func = torch.ops.torchaudio.rnnt_loss grads = torch.zeros_like(acts) minibatch_size = acts.size(0) From f1d51f69e3d5cd4f571ef7438610b430608c56d3 Mon Sep 17 00:00:00 2001 From: Vincent Quenneville-Belair Date: Thu, 31 Dec 2020 13:47:34 -0800 Subject: [PATCH 12/18] replace certify by check. remove float32 check. 
--- torchaudio/prototype/transducer.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/torchaudio/prototype/transducer.py b/torchaudio/prototype/transducer.py index d1ccd20956..b7766c6670 100644 --- a/torchaudio/prototype/transducer.py +++ b/torchaudio/prototype/transducer.py @@ -21,7 +21,7 @@ def forward(ctx, acts, labels, act_lens, label_lens, blank, reduction): """ device = acts.device - certify_inputs(acts, labels, act_lens, label_lens) + check_inputs(acts, labels, act_lens, label_lens) acts = acts.to("cpu") labels = labels.to("cpu") @@ -130,8 +130,7 @@ def check_dim(var, dim, name): raise ValueError("{} must be {}D".format(name, dim)) -def certify_inputs(log_probs, labels, lengths, label_lengths): - # check_type(log_probs, torch.float32, "log_probs") +def check_inputs(log_probs, labels, lengths, label_lengths): check_type(labels, torch.int32, "labels") check_type(label_lengths, torch.int32, "label_lengths") check_type(lengths, torch.int32, "lengths") From 379a2b8eb6902a178fa5062682bd33635f5f0508 Mon Sep 17 00:00:00 2001 From: Vincent Quenneville-Belair Date: Thu, 31 Dec 2020 13:48:01 -0800 Subject: [PATCH 13/18] ask user to import prototype explicitly. --- torchaudio/__init__.py | 1 - 1 file changed, 1 deletion(-) diff --git a/torchaudio/__init__.py b/torchaudio/__init__.py index db8ac84f48..dbe2acc013 100644 --- a/torchaudio/__init__.py +++ b/torchaudio/__init__.py @@ -1,5 +1,4 @@ from . import extension -from . import prototype from torchaudio._internal import module_utils as _mod_utils from torchaudio import ( From 16c7636847dab8b95da7588c688c13017f9a21e3 Mon Sep 17 00:00:00 2001 From: Vincent Quenneville-Belair Date: Thu, 31 Dec 2020 13:49:58 -0800 Subject: [PATCH 14/18] avoid new variables. --- build_tools/setup_helpers/extension.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/build_tools/setup_helpers/extension.py b/build_tools/setup_helpers/extension.py index 01b768b570..6b8e132b9a 100644 --- a/build_tools/setup_helpers/extension.py +++ b/build_tools/setup_helpers/extension.py @@ -18,8 +18,6 @@ _CSRC_DIR = _ROOT_DIR / 'torchaudio' / 'csrc' _TP_BASE_DIR = _ROOT_DIR / 'third_party' _TP_INSTALL_DIR = _TP_BASE_DIR / 'install' -_TRANSDUCER_BUILD_DIR = _TP_BASE_DIR / 'build' / 'transducer' -_TRANSDUCER_BASE_DIR = _TP_BASE_DIR / 'transducer' / 'submodule' def _get_build_sox(): @@ -66,7 +64,7 @@ def _get_srcs(): def _get_include_dirs(): dirs = [ str(_ROOT_DIR), - str(_TRANSDUCER_BASE_DIR / 'include'), + str(_TP_BASE_DIR / 'transducer' / 'submodule'/ 'include'), ] if _BUILD_SOX: dirs.append(str(_TP_INSTALL_DIR / 'include')) @@ -97,7 +95,7 @@ def _get_extra_objects(): ] for lib in libs: objs.append(str(_TP_INSTALL_DIR / 'lib' / lib)) - objs.append(str(_TRANSDUCER_BUILD_DIR / 'libwarprnnt.a')) + objs.append(str(_TP_BASE_DIR / 'build' / 'transducer' / 'libwarprnnt.a')) return objs From e71feb12f5d697f5eb299a937a2344766dd66764 Mon Sep 17 00:00:00 2001 From: Vincent Quenneville-Belair Date: Thu, 31 Dec 2020 13:55:33 -0800 Subject: [PATCH 15/18] replace by torchscript. 
--- torchaudio/csrc/transducer.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchaudio/csrc/transducer.cpp b/torchaudio/csrc/transducer.cpp index 04408d2a64..14289bfb14 100644 --- a/torchaudio/csrc/transducer.cpp +++ b/torchaudio/csrc/transducer.cpp @@ -1,7 +1,7 @@ #include #include -#include +#include #include "rnnt.h" int64_t cpu_rnnt_loss(torch::Tensor acts, From 44ae3006726cd23dd555d7d606796cc1031bc190 Mon Sep 17 00:00:00 2001 From: Vincent Quenneville-Belair Date: Thu, 31 Dec 2020 14:02:02 -0800 Subject: [PATCH 16/18] change module requirement. --- torchaudio/prototype/transducer.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torchaudio/prototype/transducer.py b/torchaudio/prototype/transducer.py index b7766c6670..d96f4d4b98 100644 --- a/torchaudio/prototype/transducer.py +++ b/torchaudio/prototype/transducer.py @@ -53,7 +53,7 @@ def backward(ctx, grad_output): return ctx.grads.mul_(grad_output), None, None, None, None, None -@_mod_utils.requires_module('_warp_transducer') +@_mod_utils.requires_module('torchaudio._torchaudio') def rnnt_loss(acts, labels, act_lens, label_lens, blank=0, reduction="mean"): """Compute the RNN Transducer Loss. @@ -82,7 +82,7 @@ def rnnt_loss(acts, labels, act_lens, label_lens, blank=0, reduction="mean"): return _RNNT.apply(acts, labels, act_lens, label_lens, blank, reduction) -@_mod_utils.requires_module('_warp_transducer') +@_mod_utils.requires_module('torchaudio._torchaudio') class RNNTLoss(Module): """ Args: From f197b67dbdb823a3b7b16c8cf0637010cf805ca2 Mon Sep 17 00:00:00 2001 From: Vincent Quenneville-Belair Date: Thu, 31 Dec 2020 14:04:53 -0800 Subject: [PATCH 17/18] lint. --- build_tools/setup_helpers/extension.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/build_tools/setup_helpers/extension.py b/build_tools/setup_helpers/extension.py index 6b8e132b9a..6a30093238 100644 --- a/build_tools/setup_helpers/extension.py +++ b/build_tools/setup_helpers/extension.py @@ -64,7 +64,7 @@ def _get_srcs(): def _get_include_dirs(): dirs = [ str(_ROOT_DIR), - str(_TP_BASE_DIR / 'transducer' / 'submodule'/ 'include'), + str(_TP_BASE_DIR / 'transducer' / 'submodule' / 'include'), ] if _BUILD_SOX: dirs.append(str(_TP_INSTALL_DIR / 'include')) From c8597365fd0c3a3cb85870aa629bed608baa41e9 Mon Sep 17 00:00:00 2001 From: Vincent Quenneville-Belair Date: Thu, 31 Dec 2020 14:22:15 -0800 Subject: [PATCH 18/18] fix name. --- torchaudio/csrc/register.cpp | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/torchaudio/csrc/register.cpp b/torchaudio/csrc/register.cpp index e977bca883..7ae91f1d7a 100644 --- a/torchaudio/csrc/register.cpp +++ b/torchaudio/csrc/register.cpp @@ -81,13 +81,13 @@ TORCH_LIBRARY(torchaudio, m) { ////////////////////////////////////////////////////////////////////////////// // transducer.cpp ////////////////////////////////////////////////////////////////////////////// - m.def("rnnt(Tensor acts," - "Tensor labels," - "Tensor input_lengths," - "Tensor label_lengths," - "Tensor costs," - "Tensor grads," - "int blank_label," - "int num_threads) -> int"); + m.def("rnnt_loss(Tensor acts," + "Tensor labels," + "Tensor input_lengths," + "Tensor label_lengths," + "Tensor costs," + "Tensor grads," + "int blank_label," + "int num_threads) -> int"); } #endif
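
With the full series applied, the RNN Transducer loss is exposed through the regular torchaudio extension: the CPU kernel is registered under the torchaudio::rnnt_loss schema above, and the autograd wrapper lives in torchaudio.prototype.transducer, which is intentionally not pulled in by a plain "import torchaudio" (PATCH 13). The following is a minimal usage sketch of the wrapper whose signature appears in these diffs; the tensor shapes follow the upstream warp-transducer convention (acts of shape (batch, frames, max label length + 1, vocab size)) and the random values are purely illustrative.

    import torch
    from torchaudio.prototype.transducer import rnnt_loss  # the prototype package must be imported explicitly

    B, T, U, D = 2, 4, 2, 3                                  # batch, frames, target length, vocab size
    acts = torch.rand(B, T, U + 1, D, requires_grad=True)    # joint-network outputs
    labels = torch.randint(1, D, (B, U), dtype=torch.int32)  # targets; int32 as enforced by check_inputs
    act_lens = torch.full((B,), T, dtype=torch.int32)        # valid frames per utterance
    label_lens = torch.full((B,), U, dtype=torch.int32)      # valid labels per utterance

    loss = rnnt_loss(acts, labels, act_lens, label_lens, blank=0, reduction="mean")
    loss.backward()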
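
The same kernel can also be reached directly through the dispatcher, which is what _RNNT.forward does under the hood: it allocates costs and grads buffers, passes them to torch.ops.torchaudio.rnnt_loss, and the kernel fills in the per-utterance losses and the gradients with respect to acts. A sketch of that lower-level call, assuming the extension built and loaded successfully; num_threads is simply forwarded to the warp-transducer options, and 1 is a conservative choice for the CPU binding.

    import torch
    import torchaudio  # importing torchaudio loads the C++ extension that registers torchaudio::rnnt_loss

    B, T, U, D = 2, 4, 2, 3
    acts = torch.rand(B, T, U + 1, D)                        # CPU, contiguous, float32
    labels = torch.randint(1, D, (B, U), dtype=torch.int32)
    input_lengths = torch.full((B,), T, dtype=torch.int32)
    label_lengths = torch.full((B,), U, dtype=torch.int32)

    costs = torch.zeros(B, dtype=acts.dtype)                 # per-utterance losses, filled by the kernel
    grads = torch.zeros_like(acts)                           # d(loss)/d(acts), filled by the kernel

    torch.ops.torchaudio.rnnt_loss(
        acts, labels, input_lengths, label_lengths, costs, grads,
        0,  # blank_label
        1,  # num_threads
    )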