Skip to content

Commit 950fb9a

Browse files
author
Caroline Chen
committed
update rnnt loss docstrings
1 parent 77b3082 commit 950fb9a

File tree

1 file changed

+6
-5
lines changed

1 file changed

+6
-5
lines changed

torchaudio/prototype/rnnt_loss.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -24,12 +24,13 @@ def rnnt_loss(
2424
dependencies.
2525
2626
Args:
27-
logits (Tensor): Tensor of dimension (batch, time, target, class) containing output from joiner
27+
logits (Tensor): Tensor of dimension (batch, max seq length, max target length + 1, class)
28+
containing output from joiner
2829
targets (Tensor): Tensor of dimension (batch, max target length) containing targets with zero padded
2930
logit_lengths (Tensor): Tensor of dimension (batch) containing lengths of each sequence from encoder
3031
target_lengths (Tensor): Tensor of dimension (batch) containing lengths of targets for each sequence
31-
blank (int, opt): blank label (Default: ``-1``)
32-
clamp (float): clamp for gradients (Default: ``-1``)
32+
blank (int, optional): blank label (Default: ``-1``)
33+
clamp (float, optional): clamp for gradients (Default: ``-1``)
3334
reduction (string, optional): Specifies the reduction to apply to the output:
3435
``'none'`` | ``'mean'`` | ``'sum'``. (Default: ``'mean'``)
3536
@@ -69,8 +70,8 @@ class RNNTLoss(torch.nn.Module):
6970
dependencies.
7071
7172
Args:
72-
blank (int, opt): blank label (Default: ``-1``)
73-
clamp (float): clamp for gradients (Default: ``-1``)
73+
blank (int, optional): blank label (Default: ``-1``)
74+
clamp (float, optional): clamp for gradients (Default: ``-1``)
7475
reduction (string, optional): Specifies the reduction to apply to the output:
7576
``'none'`` | ``'mean'`` | ``'sum'``. (Default: ``'mean'``)
7677
"""

0 commit comments

Comments
 (0)