Skip to content

Conversation

@ksanjeevan
Copy link
Contributor

@ksanjeevan ksanjeevan commented Mar 5, 2020

@vincentqb, fixing what was mentioned #456.

@vincentqb
Copy link
Contributor

vincentqb commented Mar 5, 2020

Can you show an error that would occur without that change? Some attributes/etc are inferred by nn.Module, see here.

@ksanjeevan
Copy link
Contributor Author

Yeah so here's an example:

import torch
from torchaudio.transforms import TimeStretch
m = TimeStretch().cuda()
m(torch.randn(1, 2, 201, 400, 2).cuda(), 1.3)

Will give a RuntimeError: expected device cuda:0 but got device cpu. Also we're using buffers on the other transforms as far as I'm aware, see here for example.

Copy link
Contributor

@vincentqb vincentqb left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yup, I can reproduce the error. Good catch!

@vincentqb vincentqb merged commit f1a5503 into pytorch:master Mar 5, 2020
@ksanjeevan ksanjeevan deleted the audio454 branch March 5, 2020 19:04
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants