diff --git a/torchaudio/prototype/rnnt_loss.py b/torchaudio/prototype/rnnt_loss.py
index 2bce60835b..ca48353ff5 100644
--- a/torchaudio/prototype/rnnt_loss.py
+++ b/torchaudio/prototype/rnnt_loss.py
@@ -15,7 +15,6 @@ def rnnt_loss(
     blank: int = -1,
     clamp: float = -1,
     fused_log_softmax: bool = True,
-    reuse_logits_for_grads: bool = True,
 ):
     """
     Compute the RNN Transducer Loss.
@@ -33,13 +32,9 @@ def rnnt_loss(
         clamp (float): clamp for gradients (Default: ``-1``)
         runtime_check (bool): whether to do sanity check during runtime. (Default: ``False``)
         fused_log_softmax (bool): set to False if calling log_softmax outside loss (Default: ``True``)
-        reuse_logits_for_grads (bool): whether to save memory by reusing logits memory for grads (Default: ``True``)
     """
     if not fused_log_softmax:
         logits = torch.nn.functional.log_softmax(logits, dim=-1)
-        reuse_logits_for_grads = (
-            False  # softmax needs the original logits value
-        )
 
     if blank < 0:  # reinterpret blank index if blank < 0.
         blank = logits.shape[-1] + blank
@@ -52,7 +47,7 @@ def rnnt_loss(
         blank=blank,
         clamp=clamp,
         fused_log_smax=fused_log_softmax,
-        reuse_logits_for_grads=reuse_logits_for_grads,)
+        reuse_logits_for_grads=False,)
 
     return costs
 
@@ -69,7 +64,6 @@ class RNNTLoss(torch.nn.Module):
         blank (int, opt): blank label (Default: ``-1``)
         clamp (float): clamp for gradients (Default: ``-1``)
         fused_log_softmax (bool): set to False if calling log_softmax outside loss (Default: ``True``)
-        reuse_logits_for_grads (bool): whether to save memory by reusing logits memory for grads (Default: ``True``)
     """
 
     def __init__(
@@ -77,13 +71,11 @@ def __init__(
         blank: int = -1,
         clamp: float = -1.,
         fused_log_softmax: bool = True,
-        reuse_logits_for_grads: bool = True,
     ):
         super().__init__()
         self.blank = blank
         self.clamp = clamp
         self.fused_log_softmax = fused_log_softmax
-        self.reuse_logits_for_grads = reuse_logits_for_grads
 
     def forward(
         self,
@@ -107,5 +99,4 @@ def forward(
             self.blank,
             self.clamp,
             self.fused_log_softmax,
-            self.reuse_logits_for_grads,
         )
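
Note for reviewers: a minimal usage sketch of the public API after this change, with reuse_logits_for_grads removed from both rnnt_loss and RNNTLoss. The input names, shapes, and forward argument order below are assumptions based on the usual RNN-T convention (logits of shape (batch, max_seq_len, max_target_len + 1, num_classes), int32 targets and lengths); they are not taken from the hunks above.

    # Hypothetical usage sketch; shapes and argument order are assumed,
    # not part of this diff.
    import torch
    from torchaudio.prototype.rnnt_loss import RNNTLoss

    batch, max_seq_len, max_target_len, num_classes = 2, 10, 5, 20
    logits = torch.randn(
        batch, max_seq_len, max_target_len + 1, num_classes, requires_grad=True
    )
    # Keep target labels away from the blank index (num_classes - 1 when blank=-1).
    targets = torch.randint(0, num_classes - 1, (batch, max_target_len), dtype=torch.int32)
    logit_lengths = torch.full((batch,), max_seq_len, dtype=torch.int32)
    target_lengths = torch.full((batch,), max_target_len, dtype=torch.int32)

    # Only blank, clamp, and fused_log_softmax remain configurable after this change.
    loss = RNNTLoss(blank=-1, clamp=-1.0, fused_log_softmax=True)
    costs = loss(logits, targets, logit_lengths, target_lengths)
    costs.sum().backward()  # reduce per-utterance costs, then backpropagate into logits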