@@ -114,8 +114,8 @@ class _Stretch2d(nn.Module):
     Koray Kavukcuoglu. arXiv:1802.08435, 2018.
 
     Args:
-        x_scale: the scale factor in x axis (required).
-        y_scale: the scale factor in y axis (required).
+        x_scale: the scale factor in x axis (required)
+        y_scale: the scale factor in y axis (required)
 
     Examples::
         >>> stretch2d = _Stretch2d(x_scale=1, y_scale=1)
@@ -136,19 +136,17 @@ def forward(self, x: Tensor) -> Tensor:
         r"""Pass the input through the _Stretch2d layer.
 
         Args:
-            x: the input sequence to the _Stretch2d layer (required).
+            x: the input sequence to the _Stretch2d layer (required)
 
         Shape:
-            - x: :math:`(N, C, S, T)`.
-            - output: :math:`(N, C, S * y_scale, T * x_scale)`.
-            where N is the batch size, C is the channel size, S is the number of input sequence,
-            T is the length of input sequence.
+            - x: :math:`(batch_size, channel, freq, time)`
+            - output: :math:`(batch_size, channel, freq * y_scale, time * x_scale)`
         """
 
-        n, c, s, t = x.size()
+        batch_size, channel, freq, time = x.size()
         x = x.unsqueeze(-1).unsqueeze(3)
         x = x.repeat(1, 1, 1, self.y_scale, 1, self.x_scale)
-        return x.view(n, c, s * self.y_scale, t * self.x_scale)
+        return x.view(batch_size, channel, freq * self.y_scale, time * self.x_scale)
 
 
 class _UpsampleNetwork(nn.Module):
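The renamed variables above make the shape bookkeeping in `_Stretch2d.forward` easier to check. Below is a minimal, illustrative sketch of the same unsqueeze/repeat/view trick on a dummy tensor; the scale factors and shapes are invented for the example, not taken from the model:

import torch

# toy input: (batch_size, channel, freq, time) = (1, 1, 2, 3)
x = torch.arange(6.0).view(1, 1, 2, 3)
y_scale, x_scale = 2, 3

out = (
    x.unsqueeze(-1)                          # (1, 1, 2, 3, 1)
     .unsqueeze(3)                           # (1, 1, 2, 1, 3, 1)
     .repeat(1, 1, 1, y_scale, 1, x_scale)   # (1, 1, 2, y_scale, 3, x_scale)
     .view(1, 1, 2 * y_scale, 3 * x_scale)   # (1, 1, 4, 9)
)
print(out.shape)  # torch.Size([1, 1, 4, 9])
# every value is repeated y_scale times along freq and x_scale times along time,
# i.e. nearest-neighbor upsampling of the spectrogram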
@@ -158,12 +156,12 @@ class _UpsampleNetwork(nn.Module):
     Florian Stimberg, Aaron van den Oord, Sander Dieleman, Koray Kavukcuoglu. arXiv:1802.08435, 2018.
 
     Args:
-        upsample_scales: the list of upsample scales (required).
-        res_blocks: the number of ResBlock in stack (default=10).
-        input_dims: the number of input sequence (default=100).
-        hidden_dims: the number of compute dimensions (default=128).
-        output_dims: the number of output sequence (default=128).
-        pad: the number of kernal size (pad * 2 + 1) in the first Conv1d layer (default=2).
+        upsample_scales: the list of upsample scales (required)
+        res_blocks: the number of ResBlock in stack (default=10)
+        input_dims: the number of input sequence (default=100)
+        hidden_dims: the number of compute dimensions (default=128)
+        output_dims: the number of output sequence (default=128)
+        pad: the kernel size (kernel_size = pad * 2 + 1) in the first Conv1d layer (default=2)
 
     Examples::
         >>> upsamplenetwork = _UpsampleNetwork(upsample_scales=[4, 4, 16],
@@ -208,20 +206,20 @@ def forward(self, x: Tensor) -> Tensor:
         r"""Pass the input through the _UpsampleNetwork layer.
 
         Args:
-            x: the input sequence to the _UpsampleNetwork layer (required).
+            x: the input sequence to the _UpsampleNetwork layer (required)
 
         Shape:
-            - x: :math:`(N, S, T)`.
-            - output: :math:`(N, (T - 2 * pad) * Total_Scale, S)`, `(N, (T - 2 * pad) * total_scale, P)`.
-            where N is the batch size, S is the number of input sequence, T is the length of input sequence.
-            P is the number of output sequence. Total_Scale is the product of all elements in upsample_scales.
+            - x: :math:`(batch_size, freq, time)`
+            - output: :math:`(batch_size, (time - 2 * pad) * total_scale, freq)`, `(batch_size, (time - 2 * pad) * total_scale, output_dims)`
+            where total_scale is the product of all elements in upsample_scales.
         """
 
         resnet_output = self.resnet(x).unsqueeze(1)
         resnet_output = self.resnet_stretch(resnet_output)
         resnet_output = resnet_output.squeeze(1)
 
-        upsampling_output = self.upsample_layers(x.unsqueeze(1))
+        x = x.unsqueeze(1)
+        upsampling_output = self.upsample_layers(x)
         upsampling_output = upsampling_output.squeeze(1)[:, :, self.indent:-self.indent]
 
         return upsampling_output.transpose(1, 2), resnet_output.transpose(1, 2)
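The updated shape note can be sanity-checked with a few lines of arithmetic. This is only an illustrative walk-through of the docstring's formula (total_scale is the product of upsample_scales, and the output length is (time - 2 * pad) * total_scale); the concrete numbers are assumptions borrowed from the Examples and Args above:

import math

upsample_scales = [4, 4, 16]                 # as in the Examples block above
pad = 2
total_scale = math.prod(upsample_scales)     # 4 * 4 * 16 = 256

batch_size, freq, time = 3, 100, 10
output_len = (time - 2 * pad) * total_scale  # (10 - 4) * 256 = 1536
# upsampling output: (batch_size, output_len, freq)        -> (3, 1536, 100)
# resnet output:     (batch_size, output_len, output_dims)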