Description
🐛 Bug
I am trying to train a PyTorch model on audio samples using Google Colab. To load these samples, I am using the following dataset class:
import torch
from torch.utils.data import Dataset
from torchaudio import load, transforms
import glob
import os
import numpy as np


class AudioDataset(Dataset):
    def __init__(self, path, sample_rate=22050, n_fft=2048, n_mels=128, log_mel=True):
        """
        A custom dataset class to load audio snippets and create
        mel spectrograms.

        Args:
            path (string): path to folder with audio files
            sample_rate (integer): sample rate of audio signal
            n_fft (integer): number of Fourier transforms to use for the mel spectrogram
            n_mels (integer): number of mel bins to use for the mel spectrogram
            log_mel (boolean): whether to use log-mel spectrograms instead of db-scaled
        """
        self.path = path
        self.sr = sample_rate
        self.n_fft = n_fft
        self.n_mels = n_mels
        self.log_mel = log_mel
        self.file_paths = glob.glob(
            os.path.join(self.path, "**", f"*wav"), recursive=True
        )
        self.labels = [x.split("/")[-2] for x in self.file_paths]
        self.mapping = {"ads_other": 0, "music": 1}
        for i, label in enumerate(self.labels):
            self.labels[i] = self.mapping[label]

    def __len__(self):
        return len(self.file_paths)

    def __getitem__(self, index):
        audio, sr = load(self.file_paths[index])
        audio = torch.mean(audio, dim=0, keepdim=True)
        if self.sr != sr:
            audio = transforms.Resample(sr, self.sr)(audio)
        mel_spectrogram = transforms.MelSpectrogram(
            sample_rate=self.sr, n_fft=self.n_fft, n_mels=self.n_mels, f_max=self.sr / 2
        )(audio)
        if self.log_mel:
            offset = 1e-6
            mel_spectrogram = torch.log(mel_spectrogram + offset)
        else:
            mel_spectrogram = transforms.AmplitudeToDB(stype="power", top_db=80)(mel_spectrogram)
        label = self.labels[index]
        return mel_spectrogram, label
When training my model, RAM usage stays constant at first, but after around 120 batches it suddenly starts to increase. This continues until around 170 batches, at which point memory usage reaches 100% and Google Colab crashes/disconnects (see this image for the memory usage over time: https://i.imgur.com/jGC4CH2.png).
To Reproduce
I've narrowed the problem down to the following minimal working example:
import torch
import torchaudio

for epoch in range(10):
    print(f"Epoch {epoch + 1}/10")
    for example in range(20000):
        if (example + 1) % 100 == 0:
            print(f"Example {example + 1}")
        x = torch.rand(1, 88200)
        # A new MelSpectrogram transform is constructed on every iteration.
        mel_spec = torchaudio.transforms.MelSpectrogram()(x)
Memory usage started to spike at around 14,000 examples into the first epoch.
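To put numbers on the growth instead of relying on Colab's RAM indicator, the loop can be instrumented to record the process's peak resident set size. The following is only a minimal sketch, assuming Linux (where `ru_maxrss` is reported in kilobytes); `peak_rss_mb`, `track_peak_rss`, and `step` are hypothetical names introduced here and are not part of torch or torchaudio.

```python
import resource


def peak_rss_mb():
    """Peak resident set size of this process in MiB.

    Assumption: Linux, where ru_maxrss is given in kilobytes.
    """
    return resource.getrusage(resource.RUSAGE_SELF).ru_maxrss / 1024.0


def track_peak_rss(step, n_iters, every=100):
    """Call step(i) n_iters times and sample peak RSS every `every` iterations.

    Returns a list of (iteration, peak_rss_in_mib) tuples.
    """
    history = []
    for i in range(n_iters):
        step(i)
        if (i + 1) % every == 0:
            history.append((i + 1, peak_rss_mb()))
    return history
```

For the example above, `step` could be `lambda i: torchaudio.transforms.MelSpectrogram()(torch.rand(1, 88200))`. Because this records the peak rather than the current usage, the values never decrease; a steady climb across samples would be consistent with the spike described here.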
Environment
I installed torchaudio using pip and am using version 0.3.1 (which is also what torchaudio.__version__ prints).
Output from collect_env.py:
Collecting environment information...
PyTorch version: 1.3.1
Is debug build: No
CUDA used to build PyTorch: 10.1.243
OS: Ubuntu 18.04.3 LTS
GCC version: (Ubuntu 7.4.0-1ubuntu1~18.04.1) 7.4.0
CMake version: version 3.12.0
Python version: 3.6
Is CUDA available: No
CUDA runtime version: 10.1.243
GPU models and configuration: Could not collect
Nvidia driver version: Could not collect
cuDNN version: /usr/lib/x86_64-linux-gnu/libcudnn.so.7.6.5
Versions of relevant libraries:
[pip3] numpy==1.17.4
[pip3] torch==1.3.1
[pip3] torchaudio==0.3.1
[pip3] torchsummary==1.5.1
[pip3] torchtext==0.3.1
[pip3] torchvision==0.4.2
[conda] Could not collect
Additional Context
The same issue occurs when calling torchaudio.transforms.Spectrogram directly (which is what torchaudio.transforms.MelSpectrogram calls first), although it seems to take a little longer before the out-of-memory problem occurs.
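Both the dataset and the minimal example construct a new transform object on every iteration. One way to check whether that matters is to run the same loop twice, once building the transform per iteration and once reusing a single instance, and compare memory usage. This is a sketch of such a comparison harness, not a confirmed diagnosis or fix; `run_loop`, `make_transform`, and `make_input` are hypothetical names introduced for illustration.

```python
def run_loop(make_transform, make_input, n_examples, reuse):
    """Apply a transform to n_examples freshly generated inputs.

    reuse=True  -> build the transform once and reuse it for every input.
    reuse=False -> build a new transform on every iteration (as in the report).
    Returns the number of transform objects that were constructed.
    """
    constructed = 0
    transform = None
    for _ in range(n_examples):
        if transform is None or not reuse:
            transform = make_transform()
            constructed += 1
        transform(make_input())
    return constructed
```

With torchaudio this could be called as `run_loop(torchaudio.transforms.Spectrogram, lambda: torch.rand(1, 88200), 20000, reuse=False)` and then again with `reuse=True`. If memory only grows in the `reuse=False` run, that would point at transform construction rather than at the spectrogram computation itself.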