
torch.norm is 113x slower than torch.sqrt(a**2 + b**2) #455

@PetrochukM

Description


🐛 Bug

This is primarily a PyTorch bug, but it has a significant impact on the performance of computing a magnitude spectrogram.
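For context, here is a minimal sketch of where this pattern arises (the signal and n_fft values are arbitrary): in PyTorch 1.4, torch.stft returns a real tensor whose last dimension of size 2 holds the real and imaginary parts, so the magnitude spectrogram is exactly a norm over that last dimension.

import torch

signal = torch.randn(16000)           # e.g. one second of 16 kHz audio
spec = torch.stft(signal, n_fft=400)  # shape: (freq, frames, 2)

# Magnitude spectrogram: the torch.norm(..., dim=-1) pattern
# benchmarked below.
magnitude = torch.norm(spec, dim=-1)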

To Reproduce

import torch
import timeit
import numpy

input_ = torch.randn(500, 1000, 2)


def version_one(input_):
    # Manual Euclidean norm over the last dimension.
    a, b = input_.unbind(-1)
    return torch.sqrt(a**2 + b**2)


def version_two(input_):
    # Built-in norm over the last dimension; mathematically equivalent.
    return torch.norm(input_, dim=-1)


numpy.testing.assert_almost_equal(
    version_one(input_).numpy(), version_two(input_).numpy(), decimal=6)
print('Version One:', timeit.timeit('version_one(input_)', number=100, globals=globals()))
print('Version Two:', timeit.timeit('version_two(input_)', number=100, globals=globals()))
Version One: 0.10085414599999998
Version Two: 13.104690384

Expected behavior

The two versions should have similar performance; a two-orders-of-magnitude slowdown for the built-in torch.norm is unexpected.
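Until this is fixed upstream, a workaround along these lines (a sketch, not benchmarked here) avoids both unbind and torch.norm, and it generalizes to any size of the last dimension, not just 2:

import torch

def version_three(input_):
    # Same result as torch.norm(input_, dim=-1), built from
    # elementwise ops that should avoid the slow norm kernel.
    return input_.pow(2).sum(dim=-1).sqrt()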

Environment

Collecting environment information...
PyTorch version: 1.4.0
Is debug build: No
CUDA used to build PyTorch: None

OS: Mac OSX 10.14.6
GCC version: Could not collect
CMake version: Could not collect

Python version: 3.7
Is CUDA available: No
CUDA runtime version: No CUDA
GPU models and configuration: No CUDA
Nvidia driver version: No CUDA
cuDNN version: No CUDA

Versions of relevant libraries:
[pip3] numpy==1.18.1
[pip3] pytorch-nlp==0.5.0
[pip3] torch==1.4.0
[pip3] torchaudio==0.4.0
[conda] Could not collect
