diff --git a/torchaudio/functional.py b/torchaudio/functional.py index 28cb6a3fa2..78c8c594c9 100644 --- a/torchaudio/functional.py +++ b/torchaudio/functional.py @@ -485,9 +485,10 @@ def complex_norm( Returns: Tensor: Power of the normed input tensor. Shape of `(..., )` """ - if power == 1.0: - return torch.norm(complex_tensor, 2, -1) - return torch.norm(complex_tensor, 2, -1).pow(power) + + # Replace by torch.norm once issue is fixed + # https://github.com/pytorch/pytorch/issues/34279 + return complex_tensor.pow(2.).sum(-1).pow(0.5 * power) def angle(