@@ -24,12 +24,13 @@ def rnnt_loss(
2424    dependencies. 
2525
2626    Args: 
27-         logits (Tensor): Tensor of dimension (batch, time, target, class) containing output from joiner 
27+         logits (Tensor): Tensor of dimension (batch, max seq length, max target length + 1, class) 
28+             containing output from joiner 
2829        targets (Tensor): Tensor of dimension (batch, max target length) containing targets with zero padded 
2930        logit_lengths (Tensor): Tensor of dimension (batch) containing lengths of each sequence from encoder 
3031        target_lengths (Tensor): Tensor of dimension (batch) containing lengths of targets for each sequence 
31-         blank (int, opt ): blank label (Default: ``-1``) 
32-         clamp (float): clamp for gradients (Default: ``-1``) 
32+         blank (int, optional ): blank label (Default: ``-1``) 
33+         clamp (float, optional ): clamp for gradients (Default: ``-1``) 
3334        reduction (string, optional): Specifies the reduction to apply to the output: 
3435            ``'none'`` | ``'mean'`` | ``'sum'``. (Default: ``'mean'``) 
3536
@@ -69,8 +70,8 @@ class RNNTLoss(torch.nn.Module):
6970    dependencies. 
7071
7172    Args: 
72-         blank (int, opt ): blank label (Default: ``-1``) 
73-         clamp (float): clamp for gradients (Default: ``-1``) 
73+         blank (int, optional ): blank label (Default: ``-1``) 
74+         clamp (float, optional ): clamp for gradients (Default: ``-1``) 
7475        reduction (string, optional): Specifies the reduction to apply to the output: 
7576            ``'none'`` | ``'mean'`` | ``'sum'``. (Default: ``'mean'``) 
7677    """ 
@@ -95,7 +96,8 @@ def forward(
9596    ):
9697        """ 
9798        Args: 
98-             logits (Tensor): Tensor of dimension (batch, time, target, class) containing output from joiner 
99+             logits (Tensor): Tensor of dimension (batch, max seq length, max target length + 1, class) 
100+                 containing output from joiner 
99101            targets (Tensor): Tensor of dimension (batch, max target length) containing targets with zero padded 
100102            logit_lengths (Tensor): Tensor of dimension (batch) containing lengths of each sequence from encoder 
101103            target_lengths (Tensor): Tensor of dimension (batch) containing lengths of targets for each sequence 
0 commit comments