Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .circleci/torchscript_bc_test/common.sh
Original file line number Diff line number Diff line change
Expand Up @@ -66,5 +66,6 @@ build_master() {
conda install -y -q pytorch "cpuonly" -c pytorch-nightly
printf "* Installing torchaudio\n"
cd "${_root_dir}" || exit 1
git submodule update --init --recursive
BUILD_SOX=1 python setup.py clean install
}
1 change: 1 addition & 0 deletions .circleci/unittest/linux/scripts/setup_env.sh
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ conda activate "${env_dir}"
pip --quiet install cmake ninja

# 4. Buld codecs
git submodule update --init --recursive
mkdir -p third_party/build
(
cd third_party/build
Expand Down
4 changes: 4 additions & 0 deletions .gitmodules
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
[submodule "third_party/warp_transducer/submodule"]
path = third_party/transducer/submodule
url = https://github.com/HawkAaron/warp-transducer
branch = master
2 changes: 2 additions & 0 deletions build_tools/setup_helpers/extension.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ def _get_srcs():
def _get_include_dirs():
dirs = [
str(_ROOT_DIR),
str(_TP_BASE_DIR / 'transducer' / 'submodule' / 'include'),
]
if _BUILD_SOX:
dirs.append(str(_TP_INSTALL_DIR / 'include'))
Expand Down Expand Up @@ -94,6 +95,7 @@ def _get_extra_objects():
]
for lib in libs:
objs.append(str(_TP_INSTALL_DIR / 'lib' / lib))
objs.append(str(_TP_BASE_DIR / 'build' / 'transducer' / 'libwarprnnt.a'))
return objs


Expand Down
1 change: 1 addition & 0 deletions packaging/pkg_helpers.bash
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,7 @@ setup_macos() {
#
# Usage: setup_env 0.2.0
setup_env() {
git submodule update --init --recursive
setup_cuda
setup_build_version "$1"
setup_macos
Expand Down
291 changes: 291 additions & 0 deletions test/torchaudio_unittest/transducer_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,291 @@
import numpy as np
import torch

from torchaudio_unittest import common_utils
from torchaudio.prototype.transducer import RNNTLoss


def get_numpy_data_B2_T4_U3_D3(dtype=np.float32):
logits = np.array(
[
0.065357,
0.787530,
0.081592,
0.529716,
0.750675,
0.754135,
0.609764,
0.868140,
0.622532,
0.668522,
0.858039,
0.164539,
0.989780,
0.944298,
0.603168,
0.946783,
0.666203,
0.286882,
0.094184,
0.366674,
0.736168,
0.166680,
0.714154,
0.399400,
0.535982,
0.291821,
0.612642,
0.324241,
0.800764,
0.524106,
0.779195,
0.183314,
0.113745,
0.240222,
0.339470,
0.134160,
0.505562,
0.051597,
0.640290,
0.430733,
0.829473,
0.177467,
0.320700,
0.042883,
0.302803,
0.675178,
0.569537,
0.558474,
0.083132,
0.060165,
0.107958,
0.748615,
0.943918,
0.486356,
0.418199,
0.652408,
0.024243,
0.134582,
0.366342,
0.295830,
0.923670,
0.689929,
0.741898,
0.250005,
0.603430,
0.987289,
0.592606,
0.884672,
0.543450,
0.660770,
0.377128,
0.358021,
],
dtype=dtype,
).reshape(2, 4, 3, 3)

targets = np.array([[1, 2], [1, 1]], dtype=np.int32)
src_lengths = np.array([4, 4], dtype=np.int32)
tgt_lengths = np.array([2, 2], dtype=np.int32)

blank = 0

ref_costs = np.array([4.2806528590890736, 3.9384369822503591], dtype=dtype)

ref_gradients = np.array(
[
-0.186844,
-0.062555,
0.249399,
-0.203377,
0.202399,
0.000977,
-0.141016,
0.079123,
0.061893,
-0.011552,
-0.081280,
0.092832,
-0.154257,
0.229433,
-0.075176,
-0.246593,
0.146405,
0.100188,
-0.012918,
-0.061593,
0.074512,
-0.055986,
0.219831,
-0.163845,
-0.497627,
0.209240,
0.288387,
0.013605,
-0.030220,
0.016615,
0.113925,
0.062781,
-0.176706,
-0.667078,
0.367659,
0.299419,
-0.356344,
-0.055347,
0.411691,
-0.096922,
0.029459,
0.067463,
-0.063518,
0.027654,
0.035863,
-0.154499,
-0.073942,
0.228441,
-0.166790,
-0.000088,
0.166878,
-0.172370,
0.105565,
0.066804,
0.023875,
-0.118256,
0.094381,
-0.104707,
-0.108934,
0.213642,
-0.369844,
0.180118,
0.189726,
0.025714,
-0.079462,
0.053748,
0.122328,
-0.238789,
0.116460,
-0.598687,
0.302203,
0.296484,
],
dtype=dtype,
).reshape(2, 4, 3, 3)

data = {
"logits": logits,
"targets": targets,
"src_lengths": src_lengths,
"tgt_lengths": tgt_lengths,
"blank": blank,
}

return data, ref_costs, ref_gradients


def numpy_to_torch(data, device, requires_grad=True):

logits = torch.from_numpy(data["logits"])
targets = torch.from_numpy(data["targets"])
src_lengths = torch.from_numpy(data["src_lengths"])
tgt_lengths = torch.from_numpy(data["tgt_lengths"])

logits.requires_grad_(requires_grad)

logits = logits.to(device)

def grad_hook(grad):
logits.saved_grad = grad.clone()

logits.register_hook(grad_hook)

data["logits"] = logits
data["src_lengths"] = src_lengths
data["tgt_lengths"] = tgt_lengths
data["targets"] = targets

return data


def compute_with_pytorch_transducer(data):
costs = RNNTLoss(blank=data["blank"], reduction="none")(
acts=data["logits_sparse"] if "logits_sparse" in data else data["logits"],
labels=data["targets"],
act_lens=data["src_lengths"],
label_lens=data["tgt_lengths"],
)

loss = torch.sum(costs)
loss.backward()
costs = costs.cpu().data.numpy()
gradients = data["logits"].saved_grad.cpu().data.numpy()
return costs, gradients


class TransducerTester:
def test_basic_backward(self):
# Test if example provided in README runs
# https://github.com/HawkAaron/warp-transducer

rnnt_loss = RNNTLoss()

acts = torch.FloatTensor(
[
[
[
[0.1, 0.6, 0.1, 0.1, 0.1],
[0.1, 0.1, 0.6, 0.1, 0.1],
[0.1, 0.1, 0.2, 0.8, 0.1],
],
[
[0.1, 0.6, 0.1, 0.1, 0.1],
[0.1, 0.1, 0.2, 0.1, 0.1],
[0.7, 0.1, 0.2, 0.1, 0.1],
],
]
]
)
labels = torch.IntTensor([[1, 2]])
act_length = torch.IntTensor([2])
label_length = torch.IntTensor([2])

acts = acts.to(self.device)
labels = labels.to(self.device)
act_length = act_length.to(self.device)
label_length = label_length.to(self.device)

acts.requires_grad_(True)

loss = rnnt_loss(acts, labels, act_length, label_length)
loss.backward()

def _test_costs_and_gradients(
self, data, ref_costs, ref_gradients, atol=1e-6, rtol=1e-2
):
logits_shape = data["logits"].shape
costs, gradients = compute_with_pytorch_transducer(data=data)
np.testing.assert_allclose(costs, ref_costs, atol=atol, rtol=rtol)
self.assertEqual(logits_shape, gradients.shape)
if not np.allclose(gradients, ref_gradients, atol=atol, rtol=rtol):
for b in range(len(gradients)):
T = data["src_lengths"][b]
U = data["tgt_lengths"][b]
for t in range(gradients.shape[1]):
for u in range(gradients.shape[2]):
np.testing.assert_allclose(
gradients[b, t, u],
ref_gradients[b, t, u],
atol=atol,
rtol=rtol,
err_msg=f"failed on b={b}, t={t}/T={T}, u={u}/U={U}",
)

def test_costs_and_gradients_B2_T4_U3_D3_fp32(self):
data, ref_costs, ref_gradients = get_numpy_data_B2_T4_U3_D3(dtype=np.float32)
data = numpy_to_torch(data=data, device=self.device, requires_grad=True)
self._test_costs_and_gradients(
data=data, ref_costs=ref_costs, ref_gradients=ref_gradients
)


@common_utils.skipIfNoExtension
class CPUTransducerTester(TransducerTester, common_utils.PytorchTestCase):
device = "cpu"
2 changes: 2 additions & 0 deletions third_party/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -88,3 +88,5 @@ ExternalProject_Add(libsox
# See https://github.com/pytorch/audio/pull/1026
CONFIGURE_COMMAND ${CMAKE_CURRENT_SOURCE_DIR}/build_codec_helper.sh ${CMAKE_CURRENT_SOURCE_DIR}/src/libsox/configure ${COMMON_ARGS} --with-lame --with-flac --with-mad --with-oggvorbis --without-alsa --without-coreaudio --without-png --without-oss --without-sndfile --with-opus --with-amrwb --with-amrnb --disable-openmp
)

add_subdirectory(transducer)
48 changes: 48 additions & 0 deletions third_party/transducer/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
MESSAGE("path to cmake current source dir: ${CMAKE_CURRENT_SOURCE_DIR}")
IF(APPLE)
CMAKE_MINIMUM_REQUIRED(VERSION 3.4)
ELSE()
CMAKE_MINIMUM_REQUIRED(VERSION 2.8)
ENDIF()

PROJECT(rnnt_release)

IF(APPLE)
SET(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++11 -O2")
ADD_DEFINITIONS(-DAPPLE)
ELSE()
SET(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -O2")
ENDIF()

INCLUDE_DIRECTORIES(submodule/include)

SET(CMAKE_POSITION_INDEPENDENT_CODE ON)

ADD_DEFINITIONS(-DRNNT_DISABLE_OMP)

IF(APPLE)
EXEC_PROGRAM(uname ARGS -v OUTPUT_VARIABLE DARWIN_VERSION)
STRING(REGEX MATCH "[0-9]+" DARWIN_VERSION ${DARWIN_VERSION})
MESSAGE(STATUS "DARWIN_VERSION=${DARWIN_VERSION}")

# for el capitain have to use rpath
IF(DARWIN_VERSION LESS 15)
SET(CMAKE_SKIP_RPATH TRUE)
ENDIF()

ELSE()
# always skip for linux
SET(CMAKE_SKIP_RPATH TRUE)
ENDIF()

IF(NOT APPLE)
SET(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++11 -O2")
ENDIF()

ADD_LIBRARY(warprnnt STATIC submodule/src/rnnt_entrypoint.cpp)

INSTALL(TARGETS warprnnt
LIBRARY DESTINATION "lib"
ARCHIVE DESTINATION "archives")

INSTALL(FILES submodule/include/rnnt.h DESTINATION "submodule/include")
1 change: 1 addition & 0 deletions third_party/transducer/submodule
Submodule submodule added at f54657
1 change: 1 addition & 0 deletions torchaudio/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from . import extension

from torchaudio._internal import module_utils as _mod_utils
from torchaudio import (
compliance,
Expand Down
Loading