File tree Expand file tree Collapse file tree 1 file changed +4
-4
lines changed
test/torchaudio_unittest/rnnt Expand file tree Collapse file tree 1 file changed +4
-4
lines changed Original file line number Diff line number Diff line change @@ -405,10 +405,10 @@ def get_numpy_random_data(
405405
406406
407407def numpy_to_torch (data , device , requires_grad = True ):
408- logits = torch .from_numpy (data ["logits" ])
409- targets = torch .from_numpy (data ["targets" ])
410- logit_lengths = torch .from_numpy (data ["logit_lengths" ])
411- target_lengths = torch .from_numpy (data ["target_lengths" ])
408+ logits = torch .from_numpy (data ["logits" ]). to ( device = device )
409+ targets = torch .from_numpy (data ["targets" ]). to ( device = device )
410+ logit_lengths = torch .from_numpy (data ["logit_lengths" ]). to ( device = device )
411+ target_lengths = torch .from_numpy (data ["target_lengths" ]). to ( device = device )
412412
413413 if "nbest_wers" in data :
414414 data ["nbest_wers" ] = torch .from_numpy (data ["nbest_wers" ]).to (device = device )
You can’t perform that action at this time.
0 commit comments