Add target dtype argument to save function for sox backend
#1204
Made some suggestions on testing; otherwise it looks good.
torchaudio/backend/sox_io_backend.py
Outdated
    ``dtype`` parameter is only effective for ``float32`` Tensor.
    """
    if src.dtype == torch.float32 and dtype == None:
        warnings.warn('`dtype` default value will be changed to `int16` in 0.9 release')
Can you add an instruction for suppressing this warning? Many users want to be able to suppress warnings.
Something like "provide dtype argument to suppress this warning."
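For illustration, here is a minimal sketch of the requested behaviour. It does not use torchaudio; `save_sketch` is a hypothetical stand-in for `save` that takes the source dtype as a string, and the message wording is only one possible phrasing. The point is that the warning names its own remedy, and that passing `dtype` explicitly keeps it quiet.

```python
import warnings
from typing import Optional


def save_sketch(src_dtype: str, dtype: Optional[str] = None) -> str:
    """Hypothetical stand-in for ``save``; returns the dtype that would be written."""
    if src_dtype == "float32" and dtype is None:
        # The message tells the user how to make the warning go away.
        warnings.warn(
            "`dtype` default value will be changed to `int16` in 0.9 release. "
            "Specify `dtype` to suppress this warning."
        )
    # `dtype` only takes effect for float32 input; otherwise keep the source dtype.
    if src_dtype == "float32" and dtype is not None:
        return dtype
    return src_dtype


with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    save_sketch("float32")                   # warns: dtype omitted
    save_sketch("float32", dtype="float32")  # silent: dtype given explicitly
    save_sketch("int16")                     # silent: source is not float32

num_warnings = len(caught)
```

Only the first call should record a warning, so callers who want the current behaviour can opt in explicitly instead of filtering warnings globally.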
torchaudio/csrc/sox/io.cpp
Outdated
    throw std::runtime_error("dtype conversion only supported for float32 tensors");
  }
  const auto tgt_dtype = (tensor.dtype() == torch::kFloat32 && dtype.has_value()) ?
    get_dtype_from_str(dtype.value().c_str()) : tensor.dtype();
Is the call to c_str() necessary? It looks like the resulting pointer will be implicitly converted back to std::string.
sox_io_backend.save(path, data, 8000, dtype=dtype)
found = load_wav(path, normalize=False)[0]
self.assertEqual(found.dtype, getattr(torch, dtype))
assert(torch.all(found >= range[0]) and torch.all(found <= range[1]))
This range assertion does not look quite right. In the case of int32 conversion, it will still pass even if the conversion does not happen.
How about performing an exact comparison with Tensors that are small enough to catch element-wise parity?
@parameterized.expand([
    ('float32', torch.Tensor([-1.0, -0.5, 0, 0.5, 1.0]).to(torch.float32)),
    ('int32', torch.Tensor([-2147483648, ..., 0, ..., 2147483647]).to(torch.int32)),
    ...
])
def test_dtype_conversion(self, dtype, expected):
    data = torch.Tensor([-1.0, -0.5, 0, 0.5, 1.0]).to(torch.float32)
    ...
    self.assertEqual(found, expected)
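To make the concern concrete, here is a small torch-free sketch of why the range check is too weak. `to_int32_sketch` is a hypothetical conversion (scale by 2**31 and clamp); it is an assumption for illustration and not necessarily what the sox backend does, so the intermediate expected values below follow from this sketch's rule rather than from the PR.

```python
INT32_MIN, INT32_MAX = -2147483648, 2147483647


def to_int32_sketch(samples):
    """Hypothetical float -> int32 conversion: scale by 2**31, then clamp."""
    return [max(INT32_MIN, min(INT32_MAX, int(s * 2**31))) for s in samples]


def in_range(values, low, high):
    """The style of check used in the original test."""
    return all(low <= v <= high for v in values)


data = [-1.0, -0.5, 0.0, 0.5, 1.0]
converted = to_int32_sketch(data)
# Expected values under the assumed scaling rule above.
expected = [-2147483648, -1073741824, 0, 1073741824, 2147483647]

# Unconverted floats in [-1, 1] trivially fall inside the int32 range,
# so the range check cannot tell whether any conversion took place.
range_check_unconverted = in_range(data, INT32_MIN, INT32_MAX)

# An exact element-wise comparison does distinguish the two cases.
exact_check_unconverted = data == expected
exact_check_converted = converted == expected
```

The range check passes on data that was never converted, while the exact comparison passes only for the converted values.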
torchaudio/backend/sox_io_backend.py
Outdated
    ``dtype=None`` means no conversion is performed.
    ``dtype`` parameter is only effective for ``float32`` Tensor.
    """
    if src.dtype == torch.float32 and dtype == None:
The save function has to be compatible with the TorchScript compiler. Check out the "Optional Type Refinement" section to resolve the None case.
https://pytorch.org/docs/stable/jit_language_reference.html#optional-type-refinement
I am not sure how the TorchScript compiler processes warnings.warn. If the torchscript_compatibility test fails, you can move this warning to the _save function.
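As a rough sketch of the refinement pattern that section describes: an `Optional[str]` argument is checked with `is None` in an `if` condition, after which it can be treated as a plain `str`. The function below is hypothetical (`resolve_target_dtype` is not part of torchaudio) and is written in plain Python so it runs without torch; it has not been passed through `torch.jit.script` here.

```python
from typing import Optional


def resolve_target_dtype(src_dtype: str, dtype: Optional[str]) -> str:
    """Hypothetical helper: pick the dtype to write, given an optional override."""
    # `dtype` is only effective for float32 sources.
    if src_dtype != "float32":
        return src_dtype
    # Comparing against None inside the `if` condition is the refinement point.
    if dtype is None:
        return src_dtype
    # Past this point `dtype` is known to be a str, not Optional[str].
    return dtype
```

This mirrors the selection logic in the C++ snippet above: use the requested dtype only when the source is float32 and a value was given, otherwise keep the source dtype.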
Looks good. Can you fix the warning message?
torchaudio/backend/sox_io_backend.py
Outdated
    """
    if src.dtype == torch.float32 and dtype is None:
        warnings.warn(
            '`dtype` default value will be changed to `int16` in 0.9 release'
Suggested change:
- '`dtype` default value will be changed to `int16` in 0.9 release'
+ '`dtype` default value will be changed to `int16` in 0.9 release. '
torchaudio/backend/sox_io_backend.py
Outdated
    if src.dtype == torch.float32 and dtype is None:
        warnings.warn(
            '`dtype` default value will be changed to `int16` in 0.9 release'
            'provide `dtype argument to suppress this warning'
Suggested change:
- 'provide `dtype argument to suppress this warning'
+ 'Specify `dtype` to suppress this warning.'

Thanks!
issue: #1197