Skip to content

Commit e2e6562

Browse files
committed
copy bindings inside torchaudio prototype. build static with torchaudio. deactivate OMP. add test.
1 parent 5da2ac3 commit e2e6562

File tree

11 files changed

+696
-0
lines changed

11 files changed

+696
-0
lines changed

build_tools/setup_helpers/extension.py

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -132,11 +132,60 @@ def get_ext_modules(debug=False):
132132
extra_objects=_get_extra_objects(),
133133
extra_link_args=_get_ela(debug),
134134
),
135+
_get_transducer_module(),
135136
]
136137

137138

138139
class BuildExtension(TorchBuildExtension):
    """Build command that runs prerequisite native builds before an extension."""

    def build_extension(self, ext):
        """Build ``ext`` after running whichever prerequisite step its name requires."""
        target = ext.name

        # The main extension triggers the third-party build only when
        # _BUILD_SOX is set.
        if target == _EXT_NAME:
            if _BUILD_SOX:
                _build_third_party()

        # The transducer extension always needs its CMake build first.
        if target == _TRANSDUCER_NAME:
            _build_transducer()

        super().build_extension(ext)
146+
147+
148+
# Name of the extension module built from the transducer bindings.
_TRANSDUCER_NAME = '_warp_transducer'
# Directory passed to CMake as the source tree; it also holds `binding.cpp`
# and the `submodule` checkout used by the helpers below.
_TP_TRANSDUCER_BASE_DIR = _ROOT_DIR / 'third_party' / 'warp_transducer'
150+
151+
152+
def _build_transducer():
    """Configure and compile the transducer library with CMake.

    The build tree is ``submodule/build`` under ``_TP_TRANSDUCER_BASE_DIR``
    and is created when missing. CMake is first configured with
    ``-DWITH_OMP=OFF`` and then asked to build the tree.

    Raises:
        subprocess.CalledProcessError: if either CMake invocation exits with
            a non-zero status (``check=True``).
    """
    build_dir = str(_TP_TRANSDUCER_BASE_DIR / 'submodule' / 'build')
    os.makedirs(build_dir, exist_ok=True)

    # Configure first, then build; both commands run from inside the build tree.
    cmake_steps = (
        ['cmake', str(_TP_TRANSDUCER_BASE_DIR), '-DWITH_OMP=OFF'],
        ['cmake', '--build', '.'],
    )
    for step in cmake_steps:
        subprocess.run(args=step, cwd=build_dir, check=True)
165+
166+
167+
def _get_transducer_module():
    """Describe the ``_warp_transducer`` C++ extension.

    Returns:
        A ``CppExtension`` compiling the two ``binding.cpp`` sources against
        the headers in ``submodule/include`` and linking the ``warprnnt``
        library produced in ``submodule/build`` (both via ``libraries`` /
        ``library_dirs`` and as a static archive in ``extra_objects``).
    """
    base_dir = _TP_TRANSDUCER_BASE_DIR
    submodule_dir = base_dir / 'submodule'
    build_dir = submodule_dir / 'build'
    include_dir = submodule_dir / 'include'

    libraries = ['warprnnt']
    binding_sources = [
        base_dir / 'binding.cpp',
        submodule_dir / 'pytorch_binding' / 'src' / 'binding.cpp',
    ]
    resolved_build_dir = os.path.realpath(build_dir)

    return CppExtension(
        name=_TRANSDUCER_NAME,
        sources=[os.path.realpath(src) for src in binding_sources],
        libraries=libraries,
        include_dirs=[os.path.realpath(include_dir)],
        library_dirs=[resolved_build_dir],
        extra_compile_args=['-fPIC', '-std=c++14'],
        # Static archives are referenced by their unresolved build path.
        extra_objects=[str(build_dir / f'lib{name}.a') for name in libraries],
        extra_link_args=['-Wl,-rpath,' + resolved_build_dir],
    )

test/torchaudio_unittest/common_utils/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
skipIfNoModule,
1717
skipIfNoExtension,
1818
skipIfNoSoxBackend,
19+
skipIfNoTransducer,
1920
)
2021
from .wav_utils import (
2122
get_wav_data,

test/torchaudio_unittest/common_utils/case_utils.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,3 +75,9 @@ def skipIfNoExtension(test_item):
7575
if 'TORCHAUDIO_TEST_FAIL_IF_NO_EXTENSION' in os.environ:
7676
raise RuntimeError('torchaudio C++ extension is not available.')
7777
return unittest.skip('torchaudio C++ extension is not available')(test_item)
78+
79+
80+
# Decorator that skips the decorated test when `is_module_available` reports
# that the `_warp_transducer` extension module is missing.
skipIfNoTransducer = unittest.skipIf(
    not is_module_available('_warp_transducer'), '"_warp_transducer" is not available'
)
Lines changed: 291 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,291 @@
1+
import numpy as np
2+
import torch
3+
4+
from torchaudio_unittest import common_utils
5+
from torchaudio.prototype.transducer import RNNTLoss
6+
7+
8+
def get_numpy_data_B2_T4_U3_D3(dtype=np.float32):
    """Return a fixed transducer-loss fixture with its reference results.

    Args:
        dtype: NumPy dtype used for the logits, reference costs and
            reference gradients.

    Returns:
        A ``(data, ref_costs, ref_gradients)`` tuple. ``data`` is a dict with
        keys ``"logits"`` (shape ``(2, 4, 3, 3)``), ``"targets"``,
        ``"src_lengths"``, ``"tgt_lengths"`` (all ``int32``) and ``"blank"``.
        ``ref_costs`` has one entry per batch item and ``ref_gradients`` has
        the same shape as the logits.
    """
    # Values are laid out three per row, matching the size of the last axis.
    logits = np.array(
        [
            0.065357, 0.787530, 0.081592,
            0.529716, 0.750675, 0.754135,
            0.609764, 0.868140, 0.622532,
            0.668522, 0.858039, 0.164539,
            0.989780, 0.944298, 0.603168,
            0.946783, 0.666203, 0.286882,
            0.094184, 0.366674, 0.736168,
            0.166680, 0.714154, 0.399400,
            0.535982, 0.291821, 0.612642,
            0.324241, 0.800764, 0.524106,
            0.779195, 0.183314, 0.113745,
            0.240222, 0.339470, 0.134160,
            0.505562, 0.051597, 0.640290,
            0.430733, 0.829473, 0.177467,
            0.320700, 0.042883, 0.302803,
            0.675178, 0.569537, 0.558474,
            0.083132, 0.060165, 0.107958,
            0.748615, 0.943918, 0.486356,
            0.418199, 0.652408, 0.024243,
            0.134582, 0.366342, 0.295830,
            0.923670, 0.689929, 0.741898,
            0.250005, 0.603430, 0.987289,
            0.592606, 0.884672, 0.543450,
            0.660770, 0.377128, 0.358021,
        ],
        dtype=dtype,
    ).reshape(2, 4, 3, 3)

    targets = np.array([[1, 2], [1, 1]], dtype=np.int32)
    src_lengths = np.array([4, 4], dtype=np.int32)
    tgt_lengths = np.array([2, 2], dtype=np.int32)

    blank = 0

    ref_costs = np.array([4.2806528590890736, 3.9384369822503591], dtype=dtype)

    ref_gradients = np.array(
        [
            -0.186844, -0.062555, 0.249399,
            -0.203377, 0.202399, 0.000977,
            -0.141016, 0.079123, 0.061893,
            -0.011552, -0.081280, 0.092832,
            -0.154257, 0.229433, -0.075176,
            -0.246593, 0.146405, 0.100188,
            -0.012918, -0.061593, 0.074512,
            -0.055986, 0.219831, -0.163845,
            -0.497627, 0.209240, 0.288387,
            0.013605, -0.030220, 0.016615,
            0.113925, 0.062781, -0.176706,
            -0.667078, 0.367659, 0.299419,
            -0.356344, -0.055347, 0.411691,
            -0.096922, 0.029459, 0.067463,
            -0.063518, 0.027654, 0.035863,
            -0.154499, -0.073942, 0.228441,
            -0.166790, -0.000088, 0.166878,
            -0.172370, 0.105565, 0.066804,
            0.023875, -0.118256, 0.094381,
            -0.104707, -0.108934, 0.213642,
            -0.369844, 0.180118, 0.189726,
            0.025714, -0.079462, 0.053748,
            0.122328, -0.238789, 0.116460,
            -0.598687, 0.302203, 0.296484,
        ],
        dtype=dtype,
    ).reshape(2, 4, 3, 3)

    data = {
        "logits": logits,
        "targets": targets,
        "src_lengths": src_lengths,
        "tgt_lengths": tgt_lengths,
        "blank": blank,
    }

    return data, ref_costs, ref_gradients
182+
183+
184+
def numpy_to_torch(data, device, requires_grad=True):
    """Convert the NumPy arrays in ``data`` to torch tensors, in place.

    Args:
        data: dict holding NumPy arrays under ``"logits"``, ``"targets"``,
            ``"src_lengths"`` and ``"tgt_lengths"``. The dict is mutated and
            also returned; other keys are left untouched.
        device: device the logits tensor is moved to. The remaining tensors
            stay where ``torch.from_numpy`` created them.
        requires_grad: whether the logits should track gradients. When true,
            a hook stores a copy of the incoming gradient on the logits
            tensor as ``saved_grad`` during backward.

    Returns:
        The same ``data`` dict, now holding torch tensors.
    """
    logits = torch.from_numpy(data["logits"])
    targets = torch.from_numpy(data["targets"])
    src_lengths = torch.from_numpy(data["src_lengths"])
    tgt_lengths = torch.from_numpy(data["tgt_lengths"])

    logits.requires_grad_(requires_grad)
    logits = logits.to(device)

    # Registering a hook on a tensor that does not require grad raises a
    # RuntimeError, so only install it when gradients are actually tracked.
    if logits.requires_grad:

        def grad_hook(grad):
            logits.saved_grad = grad.clone()

        logits.register_hook(grad_hook)

    data["logits"] = logits
    data["src_lengths"] = src_lengths
    data["tgt_lengths"] = tgt_lengths
    data["targets"] = targets

    return data
206+
207+
208+
def compute_with_pytorch_transducer(data):
    """Run ``RNNTLoss`` forward and backward on ``data``.

    Args:
        data: dict with ``"blank"``, ``"targets"``, ``"src_lengths"``,
            ``"tgt_lengths"`` and ``"logits"``; ``"logits_sparse"`` is used
            as the loss input instead of ``"logits"`` when present. The
            ``"logits"`` tensor must have received a ``saved_grad`` attribute
            by the time backward finishes.

    Returns:
        ``(costs, gradients)`` as NumPy arrays: the unreduced loss output and
        the gradient captured on ``data["logits"]``.
    """
    criterion = RNNTLoss(blank=data["blank"], reduction="none")
    activations = data["logits_sparse"] if "logits_sparse" in data else data["logits"]
    costs = criterion(
        acts=activations,
        labels=data["targets"],
        act_lens=data["src_lengths"],
        label_lens=data["tgt_lengths"],
    )

    # Backpropagate the summed costs so the gradient hook on the logits fires.
    torch.sum(costs).backward()

    costs_np = costs.cpu().detach().numpy()
    gradients_np = data["logits"].saved_grad.cpu().detach().numpy()
    return costs_np, gradients_np
221+
222+
223+
class TransducerTester:
    """Transducer-loss tests shared across devices.

    Subclasses are expected to supply a ``device`` attribute and the
    ``assertEqual`` method of a test-case base class.
    """

    def test_basic_backward(self):
        # Test if example provided in README runs
        # https://github.com/HawkAaron/warp-transducer
        rnnt_loss = RNNTLoss()

        acts = torch.tensor(
            [
                [
                    [
                        [0.1, 0.6, 0.1, 0.1, 0.1],
                        [0.1, 0.1, 0.6, 0.1, 0.1],
                        [0.1, 0.1, 0.2, 0.8, 0.1],
                    ],
                    [
                        [0.1, 0.6, 0.1, 0.1, 0.1],
                        [0.1, 0.1, 0.2, 0.1, 0.1],
                        [0.7, 0.1, 0.2, 0.1, 0.1],
                    ],
                ]
            ],
            dtype=torch.float32,
        ).to(self.device)
        labels = torch.tensor([[1, 2]], dtype=torch.int32).to(self.device)
        act_length = torch.tensor([2], dtype=torch.int32).to(self.device)
        label_length = torch.tensor([2], dtype=torch.int32).to(self.device)

        acts.requires_grad_(True)

        # No assertion: the test passes when forward and backward both run.
        loss = rnnt_loss(acts, labels, act_length, label_length)
        loss.backward()

    def _test_costs_and_gradients(
        self, data, ref_costs, ref_gradients, atol=1e-6, rtol=1e-2
    ):
        """Compare computed costs and gradients against reference values.

        Costs are checked in one shot. Gradients are first compared as a
        whole; only on mismatch are they re-checked per ``(b, t, u)`` cell so
        the failure message names the offending position.
        """
        expected_shape = data["logits"].shape
        costs, gradients = compute_with_pytorch_transducer(data=data)

        np.testing.assert_allclose(costs, ref_costs, atol=atol, rtol=rtol)
        self.assertEqual(expected_shape, gradients.shape)

        if np.allclose(gradients, ref_gradients, atol=atol, rtol=rtol):
            return

        batch_size = len(gradients)
        for b in range(batch_size):
            T = data["src_lengths"][b]
            U = data["tgt_lengths"][b]
            for t in range(gradients.shape[1]):
                for u in range(gradients.shape[2]):
                    np.testing.assert_allclose(
                        gradients[b, t, u],
                        ref_gradients[b, t, u],
                        atol=atol,
                        rtol=rtol,
                        err_msg=f"failed on b={b}, t={t}/T={T}, u={u}/U={U}",
                    )

    def test_costs_and_gradients_B2_T4_U3_D3_fp32(self):
        # Check the float32 fixture against its stored reference results.
        fixture, ref_costs, ref_gradients = get_numpy_data_B2_T4_U3_D3(dtype=np.float32)
        tensors = numpy_to_torch(data=fixture, device=self.device, requires_grad=True)
        self._test_costs_and_gradients(
            data=tensors, ref_costs=ref_costs, ref_gradients=ref_gradients
        )
287+
288+
289+
@common_utils.skipIfNoTransducer
class CPUTransducerTester(TransducerTester, common_utils.PytorchTestCase):
    """Runs the ``TransducerTester`` tests with ``device`` set to the CPU."""

    device = "cpu"

0 commit comments

Comments
 (0)