@@ -18,23 +18,22 @@ def _test_costs_and_gradients(
1818 self , data , ref_costs , ref_gradients , atol = 1e-6 , rtol = 1e-2
1919 ):
2020 logits_shape = data ["logits" ].shape
21- with self .subTest ():
22- costs , gradients = compute_with_pytorch_transducer (data = data )
23- np .testing .assert_allclose (costs , ref_costs , atol = atol , rtol = rtol )
24- self .assertEqual (logits_shape , gradients .shape )
25- if not np .allclose (gradients , ref_gradients , atol = atol , rtol = rtol ):
26- for b in range (len (gradients )):
27- T = data ["logit_lengths" ][b ]
28- U = data ["target_lengths" ][b ]
29- for t in range (gradients .shape [1 ]):
30- for u in range (gradients .shape [2 ]):
31- np .testing .assert_allclose (
32- gradients [b , t , u ],
33- ref_gradients [b , t , u ],
34- atol = atol ,
35- rtol = rtol ,
36- err_msg = f"failed on b={ b } , t={ t } /T={ T } , u={ u } /U={ U } " ,
37- )
21+ costs , gradients = compute_with_pytorch_transducer (data = data )
22+ np .testing .assert_allclose (costs , ref_costs , atol = atol , rtol = rtol )
23+ self .assertEqual (logits_shape , gradients .shape )
24+ if not np .allclose (gradients , ref_gradients , atol = atol , rtol = rtol ):
25+ for b in range (len (gradients )):
26+ T = data ["logit_lengths" ][b ]
27+ U = data ["target_lengths" ][b ]
28+ for t in range (gradients .shape [1 ]):
29+ for u in range (gradients .shape [2 ]):
30+ np .testing .assert_allclose (
31+ gradients [b , t , u ],
32+ ref_gradients [b , t , u ],
33+ atol = atol ,
34+ rtol = rtol ,
35+ err_msg = f"failed on b={ b } , t={ t } /T={ T } , u={ u } /U={ U } " ,
36+ )
3837
3938 def test_basic_backward (self ):
4039 rnnt_loss = RNNTLoss ()
0 commit comments