We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent 86ed12a commit 0fe13a4Copy full SHA for 0fe13a4
torchaudio/prototype/rnnt_loss.py
@@ -105,16 +105,6 @@ def rnnt_loss(
105
False # softmax needs the original logits value
106
)
107
108
- # move everything to the same device.
109
- targets = targets.to(device=logits.device)
110
- logit_lengths = logit_lengths.to(device=logits.device)
111
- target_lengths = target_lengths.to(device=logits.device)
112
-
113
- # make sure all int tensors are of type int32.
114
- targets = targets.int()
115
- logit_lengths = logit_lengths.int()
116
- target_lengths = target_lengths.int()
117
118
if blank < 0: # reinterpret blank index if blank < 0.
119
blank = logits.shape[-1] + blank
120
0 commit comments