Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 9 additions & 4 deletions examples/tutorials/forced_alignment_tutorial.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,10 +146,11 @@ def plot():
#
# To generate, the probability of time step :math:`t+1`, we look at the
# trellis from time step :math:`t` and emission at time step :math:`t+1`.
# There are two path to reach to time step :math:`t+1` with label
# :math:`c_{j+1}`. The first one is the case where the label was
# There are three paths to reach to time step :math:`t+1` with label
# :math:`c_{j+1}`. The first two are the cases where the label was
# :math:`c_{j+1}` at :math:`t` and there was no label change from
# :math:`t` to :math:`t+1`. The other case is where the label was
# :math:`t` to :math:`t+1`. For this we use he probability of 'blank'
# emission and repeating the same letter. The last case is where the label was
# :math:`c_j` at :math:`t` and it transitioned to the next label
# :math:`c_{j+1}` at :math:`t+1`.
#
Expand All @@ -160,7 +161,7 @@ def plot():
# Since we are looking for the most likely transitions, we take the more
# likely path for the value of :math:`k_{(t+1, j+1)}`, that is
#
# :math:`k_{(t+1, j+1)} = max( k_{(t, j)} p(t+1, c_{j+1}), k_{(t, j+1)} p(t+1, repeat) )`
# :math:`k_{(t+1, j+1)} = max( k_{(t, j)} p(t+1, c_{j+1}), k_{(t, j+1)} p(t+1, blank), k_{(t, j+1)} p(t+1, c_{j+1}))`
#
# where :math:`k` represents is trellis matrix, and :math:`p(t, c_j)`
# represents the probability of label :math:`c_j` at time step :math:`t`.
Expand Down Expand Up @@ -194,6 +195,10 @@ def get_trellis(emission, tokens, blank_id=0):
# Score for changing to the next token
trellis[t, :-1] + emission[t, tokens[1:]],
)
trellis[t+1,1:] = torch.maximum(
trellis[t+1,1:],
trellis[t, 1:] + emission[t, tokens[1:]],
)
return trellis


Expand Down