extend batch support #391
Conversation
waveform, sample_rate = torchaudio.load(self.test_filepath)

# Single then transform then batch
expected = transforms.MelSpectrogram()(waveform).repeat(3, 1, 1, 1)
This is repetitive. I'm wondering if there is a way of creating a test generator that explicitly exercises this decorator. It's a function that accepts a function with some args and kwargs and then assumes that the first input is to be batchable (so it applies various reshapes to args[0] etc.)
def _gen_batchable(self, func, *args, **kwargs):
    self.assertEqual(func(*args, **kwargs),
                     func(*((args[0].reshape(3, -1, -1, -1),) + args[1:]), **kwargs))
(not sure on the reshape behavior, could also be a lambda).
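One way to make this concrete is a helper along the following lines. This is a hypothetical sketch, not the PR's actual test code: the `assert_batch_consistency` name, the `repeat`-based batching, and the assumption that only the first positional argument is batchable are all illustrative choices.

```python
import torch

def assert_batch_consistency(func, tensor, *args, batch_size=3, **kwargs):
    # Treat `tensor` (the first positional input) as the batchable argument:
    # applying `func` to a stack of identical copies should match applying
    # `func` once and repeating the result along a new leading dimension.
    single = func(tensor, *args, **kwargs)
    batched = func(tensor.unsqueeze(0).repeat([batch_size] + [1] * tensor.dim()),
                   *args, **kwargs)
    expected = single.unsqueeze(0).repeat([batch_size] + [1] * single.dim())
    assert torch.allclose(batched, expected)
```

A generator like this would let each batch test be a one-liner instead of repeating the pack/repeat/compare boilerplate per transform.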
I followed the format I used for testing jitability and simply introduced common functions to test batching. Thoughts?
Yeah, I'm mostly just trying to iterate on top of that. Might be one of those non-problems I'm trying to fix.
Args:
-    waveform (torch.Tensor): Tensor of audio of dimension (channel, time).
+    waveform (torch.Tensor): Tensor of audio of dimension (..., time).
Maybe ([...], channel, time) to indicate that we still follow that convention of channel x time?
As long as it's documented and standardized so we can apply it consistently, this is fine by me.
torch.random.manual_seed(42)
tensor = torch.rand((1, 201, 6))

n_fft = 400
Are these parameters special to the batch test?
Using the same parameters as test_torchscript_spectrogram, but we can change them, or we could set them in the class directly.
stft = torch.tensor([
    [[4., 0.], [4., 0.], [4., 0.], [4., 0.], [4., 0.]],
    [[0., 0.], [0., 0.], [0., 0.], [0., 0.], [0., 0.]],
Why repeat all 0s instead of something more interesting (maybe just [0., 4.])?
I'm reusing the parameters from test_istft_of_ones.
# pack batch
shape = specgram.size()
specgram = specgram.reshape([-1] + list(shape[-2:]))
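The pack/unpack idiom in this diff can be sketched end to end as follows. The `apply_batched` wrapper and the doubling op are illustrative, not code from the PR: the idea is to collapse every leading dimension into one batch dimension, run a 3-D operation, then restore the original leading shape.

```python
import torch

def apply_batched(op, specgram):
    # pack batch: collapse all leading dims into a single batch dim
    shape = specgram.size()
    packed = specgram.reshape([-1] + list(shape[-2:]))

    out = op(packed)  # op only needs to handle (batch, freq, time)

    # unpack batch: restore the original leading dims
    return out.reshape(list(shape[:-2]) + list(out.shape[-2:]))
```

This is what lets functionals document their input as `(..., freq, time)` while the core implementation stays written for a single batch dimension.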
Instead of reshape I suggest you use view. Reshape dispatches to view if it doesn't require a copy. If this does require a copy I think it'd be good to know about since it's not something we'd expect to be necessary.
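As background on this suggestion, here is a generic illustration (unrelated to the PR's tensors) of the difference: `view` requires a compatible memory layout and raises when it would need a copy, while `reshape` falls back to copying silently.

```python
import torch

x = torch.rand(4, 6)              # contiguous: view reinterprets in place
assert x.view(-1, 3).shape == (8, 3)

t = x.t()                         # (6, 4), transposed, so non-contiguous
try:
    t.view(-1)                    # flattening this layout needs a copy: view raises
    raised = False
except RuntimeError:
    raised = True
assert raised
assert t.reshape(-1).shape == (24,)   # reshape copies silently instead
```

Using `view` in the batching code would therefore surface any unexpected copy as a hard error rather than hiding it.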
Good idea, but that means I need to change all of the batching code. I suggest doing that explicitly, all at once, in a separate PR.
cpuhrsch left a comment
Looks fine, I'd just revisit reshape vs. view
* extend batch support, closes pytorch#383
* function for batch test
* set seed
`mask_along_axis_iid` already supports batching, so it is not updated here; see "standardizing freq/time axis" #401 and "Move batch from vocoder transform to functional" #350.

Closes #383