Skip to content

Commit 326e34d

Browse files
author
Caroline Chen
committed
remove unncessary code
1 parent 1398197 commit 326e34d

File tree

1 file changed

+16
-17
lines changed

1 file changed

+16
-17
lines changed

test/torchaudio_unittest/rnnt/rnnt_loss_impl.py

Lines changed: 16 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -18,23 +18,22 @@ def _test_costs_and_gradients(
1818
self, data, ref_costs, ref_gradients, atol=1e-6, rtol=1e-2
1919
):
2020
logits_shape = data["logits"].shape
21-
with self.subTest():
22-
costs, gradients = compute_with_pytorch_transducer(data=data)
23-
np.testing.assert_allclose(costs, ref_costs, atol=atol, rtol=rtol)
24-
self.assertEqual(logits_shape, gradients.shape)
25-
if not np.allclose(gradients, ref_gradients, atol=atol, rtol=rtol):
26-
for b in range(len(gradients)):
27-
T = data["logit_lengths"][b]
28-
U = data["target_lengths"][b]
29-
for t in range(gradients.shape[1]):
30-
for u in range(gradients.shape[2]):
31-
np.testing.assert_allclose(
32-
gradients[b, t, u],
33-
ref_gradients[b, t, u],
34-
atol=atol,
35-
rtol=rtol,
36-
err_msg=f"failed on b={b}, t={t}/T={T}, u={u}/U={U}",
37-
)
21+
costs, gradients = compute_with_pytorch_transducer(data=data)
22+
np.testing.assert_allclose(costs, ref_costs, atol=atol, rtol=rtol)
23+
self.assertEqual(logits_shape, gradients.shape)
24+
if not np.allclose(gradients, ref_gradients, atol=atol, rtol=rtol):
25+
for b in range(len(gradients)):
26+
T = data["logit_lengths"][b]
27+
U = data["target_lengths"][b]
28+
for t in range(gradients.shape[1]):
29+
for u in range(gradients.shape[2]):
30+
np.testing.assert_allclose(
31+
gradients[b, t, u],
32+
ref_gradients[b, t, u],
33+
atol=atol,
34+
rtol=rtol,
35+
err_msg=f"failed on b={b}, t={t}/T={T}, u={u}/U={U}",
36+
)
3837

3938
def test_basic_backward(self):
4039
rnnt_loss = RNNTLoss()

0 commit comments

Comments
 (0)