@@ -24,12 +24,13 @@ def rnnt_loss(
2424 dependencies.
2525
2626 Args:
27- logits (Tensor): Tensor of dimension (batch, time, target, class) containing output from joiner
27+ logits (Tensor): Tensor of dimension (batch, max seq length, max target length + 1, class)
28+ containing output from joiner
2829 targets (Tensor): Tensor of dimension (batch, max target length) containing targets with zero padded
2930 logit_lengths (Tensor): Tensor of dimension (batch) containing lengths of each sequence from encoder
3031 target_lengths (Tensor): Tensor of dimension (batch) containing lengths of targets for each sequence
31- blank (int, opt ): blank label (Default: ``-1``)
32- clamp (float): clamp for gradients (Default: ``-1``)
32+ blank (int, optional ): blank label (Default: ``-1``)
33+ clamp (float, optional ): clamp for gradients (Default: ``-1``)
3334 reduction (string, optional): Specifies the reduction to apply to the output:
3435 ``'none'`` | ``'mean'`` | ``'sum'``. (Default: ``'mean'``)
3536
@@ -69,8 +70,8 @@ class RNNTLoss(torch.nn.Module):
6970 dependencies.
7071
7172 Args:
72- blank (int, opt ): blank label (Default: ``-1``)
73- clamp (float): clamp for gradients (Default: ``-1``)
73+ blank (int, optional ): blank label (Default: ``-1``)
74+ clamp (float, optional ): clamp for gradients (Default: ``-1``)
7475 reduction (string, optional): Specifies the reduction to apply to the output:
7576 ``'none'`` | ``'mean'`` | ``'sum'``. (Default: ``'mean'``)
7677 """
@@ -95,7 +96,8 @@ def forward(
9596 ):
9697 """
9798 Args:
98- logits (Tensor): Tensor of dimension (batch, time, target, class) containing output from joiner
99+ logits (Tensor): Tensor of dimension (batch, max seq length, max target length + 1, class)
100+ containing output from joiner
99101 targets (Tensor): Tensor of dimension (batch, max target length) containing targets with zero padded
100102 logit_lengths (Tensor): Tensor of dimension (batch) containing lengths of each sequence from encoder
101103 target_lengths (Tensor): Tensor of dimension (batch) containing lengths of targets for each sequence
0 commit comments