Add RNN Transducer Loss for CPU #1137
Force-pushed from 9d6589a to e2e6562.
```python
def rnnt_loss(acts, labels, act_lens, label_lens, blank=0, reduction="mean"):
    """RNN Transducer Loss

    Args:
```
I think the documentation could be improved a bit. It could also be useful to reference the paper.
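A sketch of the kind of docstring the reviewer is asking for, with a pointer to the Graves paper (the argument shapes follow the warp-transducer convention; the exact wording here is illustrative, not the text merged in the PR):

```python
def rnnt_loss(acts, labels, act_lens, label_lens, blank=0, reduction="mean"):
    """Compute the RNN Transducer loss.

    See *Sequence Transduction with Recurrent Neural Networks*
    (Graves, 2012, https://arxiv.org/abs/1211.3711).

    Args:
        acts: unnormalized logits, shape ``(batch, max_act_len, max_label_len + 1, num_classes)``
        labels: zero-padded label sequences, shape ``(batch, max_label_len)``
        act_lens: valid lengths of each row of ``acts``, shape ``(batch,)``
        label_lens: valid lengths of each row of ``labels``, shape ``(batch,)``
        blank: index of the blank label (default: 0)
        reduction: ``'none'`` | ``'mean'`` | ``'sum'``, applied over the batch
    """
    raise NotImplementedError("illustrative stub; the real kernel is in C++")
```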
```python
super().build_extension(ext)

_TRANSDUCER_NAME = '_warp_transducer'
```
This will get installed in the global namespace, outside of the torchaudio package directory. Please put it in the torchaudio package.
```cmake
MESSAGE(STATUS "Building static library with GPU support")

CUDA_ADD_LIBRARY(warprnnt STATIC submodule/src/rnnt_entrypoint.cu)
IF (!Torch_FOUND)
```
If torch is not found, shouldn't it be failing?
```python
self.reduction = reduction
self.loss = _RNNT.apply

def forward(self, acts, labels, act_lens, label_lens):
```
If you don't want to copy-paste the docs from the functional, you could reference it here within the documentation.
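For instance, the class docstring could delegate to the functional with a Sphinx cross-reference rather than duplicating it (a sketch with a minimal stand-in class, not the PR's actual module):

```python
class RNNTLoss:
    """RNN Transducer loss module.

    See :func:`rnnt_loss` for the argument and shape conventions; this
    class only stores ``blank`` and ``reduction`` and forwards them to
    the functional in ``forward``.
    """

    def __init__(self, blank=0, reduction="mean"):
        self.blank = blank
        self.reduction = reduction
```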
Force-pushed from b6c4ce8 to ca66151.
Some follow-ups:
Force-pushed from 82b7186 to 456eefc.
```python
# Test if example provided in README runs
# https://github.com/HawkAaron/warp-transducer

acts = torch.FloatTensor(
```
nit: use the factory function `torch.tensor([xyz], dtype=torch.float)` instead of the type constructor. The same applies to `IntTensor`.
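Concretely, the suggested change looks like this (the values are stand-ins, not the README's actual example):

```python
import torch

# Type constructor: the dtype is baked into the constructor's name.
acts_old = torch.FloatTensor([[0.1, 0.6, 0.1], [0.1, 0.1, 0.2]])
lens_old = torch.IntTensor([2])

# Factory function: the dtype is an explicit, greppable argument.
acts_new = torch.tensor([[0.1, 0.6, 0.1], [0.1, 0.1, 0.2]], dtype=torch.float)
lens_new = torch.tensor([2], dtype=torch.int)
```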
Force-pushed from f96089b to 299310c.
```python
U = data["tgt_lengths"][b]
for t in range(gradients.shape[1]):
    for u in range(gradients.shape[2]):
        np.testing.assert_allclose(
```
`self.assertEqual` should be preferred.
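The difference in practice, as a self-contained sketch with plain `unittest` (the values are made up; in the actual test suite the shared PyTorch `TestCase` extends `assertEqual` to compare tensors with a tolerance):

```python
import unittest

class GradientValueTest(unittest.TestCase):
    def test_values(self):
        computed = [0.125, -0.5, 0.25]   # stand-in for the computed gradients
        expected = [0.125, -0.5, 0.25]   # stand-in for the reference gradients
        # A single assertEqual over the whole structure yields a readable
        # diff on failure, instead of a hand-rolled double loop around
        # np.testing.assert_allclose.
        self.assertEqual(computed, expected)

result = unittest.TextTestRunner(verbosity=0).run(
    unittest.TestLoader().loadTestsFromTestCase(GradientValueTest)
)
```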
Force-pushed from f18105a to 1d2c5db.
Some more follow-ups:

Error below also happens on master:
Force-pushed from 64c8220 to 32e3398.
```python
loss = rnnt_loss(acts, labels, act_length, label_length)
loss.backward()

def _test_costs_and_gradients(
```
This could be inlined since it only has one call-site and is pretty small (though that alone isn't necessarily a reason to remove an abstraction).
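Illustratively (with hypothetical helpers, not the PR's actual code), inlining a single-call-site helper looks like:

```python
# Before: a small helper with exactly one call-site.
def _relative_error(a, b):
    return abs(a - b) / max(abs(b), 1e-8)

def costs_close(a, b, tol=1e-6):
    return _relative_error(a, b) < tol

# After inlining: the same check lives at its only call-site,
# removing one level of indirection.
def costs_close_inlined(a, b, tol=1e-6):
    return abs(a - b) / max(abs(b), 1e-8) < tol
```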
Force-pushed from 32e3398 to fddfbd1.
```cpp
}

TORCH_LIBRARY_IMPL(torchaudio, CPU, m) {
  m.impl("rnnt_loss", &cpu_rnnt_loss);
```
@vincentqb Can you define a proper namespace, e.g. `torchaudio::<something>::rnnt_loss`?

I am not sure how you want to move on, but if you have a plan to add different types of RNNT loss, then a more descriptive name would work better later, like `warprnnt`.
Adding an anonymous namespace in #1159 for the time being.
@vincentqb I updated the follow-up description for the things addressed in #1159 and #1161. Please stamp these PRs when you have time.

For the C++ ABI issue, see #880.
This pull request introduces `rnnt_loss` and `RNNTLoss` as a prototype in `torchaudio.prototype.transducer`, using HawkAaron's warp-transducer. Follow-up work is detailed in #1240.
cc @astaff, internal, #1099