Skip to content

Commit c98e289

Browse files
author
Ji Chen
committed
update variable names
1 parent 7c58904 commit c98e289

File tree

2 files changed

+24
-24
lines changed

2 files changed

+24
-24
lines changed

test/test_models.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
11
import torch
22
from torchaudio.models import Wav2Letter, _MelResNet, _UpsampleNetwork
33

4+
from . import common_utils
45

5-
class TestWav2Letter:
6+
7+
class TestWav2Letter(common_utils.TorchaudioTestCase):
68

79
def test_waveform(self):
810
batch_size = 2
@@ -31,7 +33,7 @@ def test_mfcc(self):
3133
assert out.size() == (batch_size, num_classes, 2)
3234

3335

34-
class TestMelResNet:
36+
class TestMelResNet(common_utils.TorchaudioTestCase):
3537

3638
def test_waveform(self):
3739

@@ -51,7 +53,7 @@ def test_waveform(self):
5153
assert out.size() == (batch_size, output_dims, num_features - pad * 2)
5254

5355

54-
class TestUpsampleNetwork:
56+
class TestUpsampleNetwork(common_utils.TorchaudioTestCase):
5557

5658
def test_waveform(self):
5759

torchaudio/models/_wavernn.py

Lines changed: 19 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -114,8 +114,8 @@ class _Stretch2d(nn.Module):
114114
Koray Kavukcuoglu. arXiv:1802.08435, 2018.
115115
116116
Args:
117-
x_scale: the scale factor in x axis (required).
118-
y_scale: the scale factor in y axis (required).
117+
x_scale: the scale factor in x axis (required)
118+
y_scale: the scale factor in y axis (required)
119119
120120
Examples::
121121
>>> stretch2d = _Stretch2d(x_scale=1, y_scale=1)
@@ -136,19 +136,17 @@ def forward(self, x: Tensor) -> Tensor:
136136
r"""Pass the input through the _Stretch2d layer.
137137
138138
Args:
139-
x: the input sequence to the _Stretch2d layer (required).
139+
x: the input sequence to the _Stretch2d layer (required)
140140
141141
Shape:
142-
- x: :math:`(N, C, S, T)`.
143-
- output: :math:`(N, C, S * y_scale, T * x_scale)`.
144-
where N is the batch size, C is the channel size, S is the number of input sequence,
145-
T is the length of input sequence.
142+
- x: :math:`(batch_size, channel, freq, time)`
143+
- output: :math:`(batch_size, channel, freq * y_scale, time * x_scale)`
146144
"""
147145

148-
n, c, s, t = x.size()
146+
batch_size, channel, freq, time = x.size()
149147
x = x.unsqueeze(-1).unsqueeze(3)
150148
x = x.repeat(1, 1, 1, self.y_scale, 1, self.x_scale)
151-
return x.view(n, c, s * self.y_scale, t * self.x_scale)
149+
return x.view(batch_size, channel, freq * self.y_scale, time * self.x_scale)
152150

153151

154152
class _UpsampleNetwork(nn.Module):
@@ -158,12 +156,12 @@ class _UpsampleNetwork(nn.Module):
158156
Florian Stimberg, Aaron van den Oord, Sander Dieleman, Koray Kavukcuoglu. arXiv:1802.08435, 2018.
159157
160158
Args:
161-
upsample_scales: the list of upsample scales (required).
162-
res_blocks: the number of ResBlock in stack (default=10).
163-
input_dims: the number of input sequence (default=100).
164-
hidden_dims: the number of compute dimensions (default=128).
165-
output_dims: the number of output sequence (default=128).
166-
pad: the number of kernal size (pad * 2 + 1) in the first Conv1d layer (default=2).
159+
upsample_scales: the list of upsample scales (required)
160+
res_blocks: the number of ResBlock in stack (default=10)
161+
input_dims: the number of input sequence (default=100)
162+
hidden_dims: the number of compute dimensions (default=128)
163+
output_dims: the number of output sequence (default=128)
164+
pad: the padding parameter that sets the kernel size of the first Conv1d layer (kernel_size = pad * 2 + 1) (default=2)
167165
168166
Examples::
169167
>>> upsamplenetwork = _UpsampleNetwork(upsample_scales=[4, 4, 16],
@@ -208,20 +206,20 @@ def forward(self, x: Tensor) -> Tensor:
208206
r"""Pass the input through the _UpsampleNetwork layer.
209207
210208
Args:
211-
x: the input sequence to the _UpsampleNetwork layer (required).
209+
x: the input sequence to the _UpsampleNetwork layer (required)
212210
213211
Shape:
214-
- x: :math:`(N, S, T)`.
215-
- output: :math:`(N, (T - 2 * pad) * Total_Scale, S)`, `(N, (T - 2 * pad) * total_scale, P)`.
216-
where N is the batch size, S is the number of input sequence, T is the length of input sequence.
217-
P is the number of output sequence. Total_Scale is the product of all elements in upsample_scales.
212+
- x: :math:`(batch_size, freq, time)`
213+
- output: a pair of tensors with shapes :math:`(batch_size, (time - 2 * pad) * total_scale, freq)` and :math:`(batch_size, (time - 2 * pad) * total_scale, output_dims)`
214+
where total_scale is the product of all elements in upsample_scales.
218215
"""
219216

220217
resnet_output = self.resnet(x).unsqueeze(1)
221218
resnet_output = self.resnet_stretch(resnet_output)
222219
resnet_output = resnet_output.squeeze(1)
223220

224-
upsampling_output = self.upsample_layers(x.unsqueeze(1))
221+
x = x.unsqueeze(1)
222+
upsampling_output = self.upsample_layers(x)
225223
upsampling_output = upsampling_output.squeeze(1)[:, :, self.indent:-self.indent]
226224

227225
return upsampling_output.transpose(1, 2), resnet_output.transpose(1, 2)

0 commit comments

Comments
 (0)