diff --git a/docs/source/datasets.rst b/docs/source/datasets.rst index 27d4b20a34..8ba05dc670 100644 --- a/docs/source/datasets.rst +++ b/docs/source/datasets.rst @@ -33,6 +33,14 @@ COMMONVOICE :special-members: +GTZAN +~~~~~ + +.. autoclass:: GTZAN + :members: __getitem__ + :special-members: + + LIBRISPEECH ~~~~~~~~~~~ diff --git a/test/assets/genres/noise/noise.0000.wav b/test/assets/genres/noise/noise.0000.wav new file mode 100644 index 0000000000..cf1cbcde22 Binary files /dev/null and b/test/assets/genres/noise/noise.0000.wav differ diff --git a/test/test_datasets.py b/test/test_datasets.py index 96c12884d6..ab1e74e888 100644 --- a/test/test_datasets.py +++ b/test/test_datasets.py @@ -7,6 +7,7 @@ from torchaudio.datasets.vctk import VCTK from torchaudio.datasets.yesno import YESNO from torchaudio.datasets.ljspeech import LJSPEECH +from torchaudio.datasets.gtzan import GTZAN from . import common_utils @@ -55,6 +56,10 @@ def test_speechcommands(self): data = SPEECHCOMMANDS(self.path) data[0] + def test_gtzan(self): + data = GTZAN(self.path) + data[0] + if __name__ == "__main__": unittest.main() diff --git a/torchaudio/datasets/__init__.py b/torchaudio/datasets/__init__.py index 26d8216114..1741889029 100644 --- a/torchaudio/datasets/__init__.py +++ b/torchaudio/datasets/__init__.py @@ -3,6 +3,7 @@ from .speechcommands import SPEECHCOMMANDS from .utils import bg_iterator, diskcache_iterator from .vctk import VCTK +from .gtzan import GTZAN from .yesno import YESNO from .ljspeech import LJSPEECH @@ -13,6 +14,7 @@ "VCTK", "YESNO", "LJSPEECH", + "GTZAN", "diskcache_iterator", "bg_iterator", ) diff --git a/torchaudio/datasets/gtzan.py b/torchaudio/datasets/gtzan.py new file mode 100644 index 0000000000..13d7e8b68b --- /dev/null +++ b/torchaudio/datasets/gtzan.py @@ -0,0 +1,1054 @@ +import os +import warnings +from typing import Any, Tuple + +import torchaudio +from torch import Tensor +from torch.utils.data import Dataset +from torchaudio.datasets.utils import ( + download_url, + extract_archive, + walk_files, +) + +# The following lists prefixed with `filtered_` provide a filtered split +# that: +# +# a. Mitigate a known issue with GTZAN (duplication) +# +# b. Provide a standard split for testing it against other +# methods (e.g. the one in jordipons/sklearn-audio-transfer-learning). +# +# Those are used when GTZAN is initialised with the `filtered` keyword. +# The split was taken from (github) jordipons/sklearn-audio-transfer-learning. + +filtered_test = [ + "blues.00012", + "blues.00013", + "blues.00014", + "blues.00015", + "blues.00016", + "blues.00017", + "blues.00018", + "blues.00019", + "blues.00020", + "blues.00021", + "blues.00022", + "blues.00023", + "blues.00024", + "blues.00025", + "blues.00026", + "blues.00027", + "blues.00028", + "blues.00061", + "blues.00062", + "blues.00063", + "blues.00064", + "blues.00065", + "blues.00066", + "blues.00067", + "blues.00068", + "blues.00069", + "blues.00070", + "blues.00071", + "blues.00072", + "blues.00098", + "blues.00099", + "classical.00011", + "classical.00012", + "classical.00013", + "classical.00014", + "classical.00015", + "classical.00016", + "classical.00017", + "classical.00018", + "classical.00019", + "classical.00020", + "classical.00021", + "classical.00022", + "classical.00023", + "classical.00024", + "classical.00025", + "classical.00026", + "classical.00027", + "classical.00028", + "classical.00029", + "classical.00034", + "classical.00035", + "classical.00036", + "classical.00037", + "classical.00038", + "classical.00039", + "classical.00040", + "classical.00041", + "classical.00049", + "classical.00077", + "classical.00078", + "classical.00079", + "country.00030", + "country.00031", + "country.00032", + "country.00033", + "country.00034", + "country.00035", + "country.00036", + "country.00037", + "country.00038", + "country.00039", + "country.00040", + "country.00043", + "country.00044", + "country.00046", + "country.00047", + "country.00048", + "country.00050", + "country.00051", + "country.00053", + "country.00054", + "country.00055", + "country.00056", + "country.00057", + "country.00058", + "country.00059", + "country.00060", + "country.00061", + "country.00062", + "country.00063", + "country.00064", + "disco.00001", + "disco.00021", + "disco.00058", + "disco.00062", + "disco.00063", + "disco.00064", + "disco.00065", + "disco.00066", + "disco.00069", + "disco.00076", + "disco.00077", + "disco.00078", + "disco.00079", + "disco.00080", + "disco.00081", + "disco.00082", + "disco.00083", + "disco.00084", + "disco.00085", + "disco.00086", + "disco.00087", + "disco.00088", + "disco.00091", + "disco.00092", + "disco.00093", + "disco.00094", + "disco.00096", + "disco.00097", + "disco.00099", + "hiphop.00000", + "hiphop.00026", + "hiphop.00027", + "hiphop.00030", + "hiphop.00040", + "hiphop.00043", + "hiphop.00044", + "hiphop.00045", + "hiphop.00051", + "hiphop.00052", + "hiphop.00053", + "hiphop.00054", + "hiphop.00062", + "hiphop.00063", + "hiphop.00064", + "hiphop.00065", + "hiphop.00066", + "hiphop.00067", + "hiphop.00068", + "hiphop.00069", + "hiphop.00070", + "hiphop.00071", + "hiphop.00072", + "hiphop.00073", + "hiphop.00074", + "hiphop.00075", + "hiphop.00099", + "jazz.00073", + "jazz.00074", + "jazz.00075", + "jazz.00076", + "jazz.00077", + "jazz.00078", + "jazz.00079", + "jazz.00080", + "jazz.00081", + "jazz.00082", + "jazz.00083", + "jazz.00084", + "jazz.00085", + "jazz.00086", + "jazz.00087", + "jazz.00088", + "jazz.00089", + "jazz.00090", + "jazz.00091", + "jazz.00092", + "jazz.00093", + "jazz.00094", + "jazz.00095", + "jazz.00096", + "jazz.00097", + "jazz.00098", + "jazz.00099", + "metal.00012", + "metal.00013", + "metal.00014", + "metal.00015", + "metal.00022", + "metal.00023", + "metal.00025", + "metal.00026", + "metal.00027", + "metal.00028", + "metal.00029", + "metal.00030", + "metal.00031", + "metal.00032", + "metal.00033", + "metal.00038", + "metal.00039", + "metal.00067", + "metal.00070", + "metal.00073", + "metal.00074", + "metal.00075", + "metal.00078", + "metal.00083", + "metal.00085", + "metal.00087", + "metal.00088", + "pop.00000", + "pop.00001", + "pop.00013", + "pop.00014", + "pop.00043", + "pop.00063", + "pop.00064", + "pop.00065", + "pop.00066", + "pop.00069", + "pop.00070", + "pop.00071", + "pop.00072", + "pop.00073", + "pop.00074", + "pop.00075", + "pop.00076", + "pop.00077", + "pop.00078", + "pop.00079", + "pop.00082", + "pop.00088", + "pop.00089", + "pop.00090", + "pop.00091", + "pop.00092", + "pop.00093", + "pop.00094", + "pop.00095", + "pop.00096", + "reggae.00034", + "reggae.00035", + "reggae.00036", + "reggae.00037", + "reggae.00038", + "reggae.00039", + "reggae.00040", + "reggae.00046", + "reggae.00047", + "reggae.00048", + "reggae.00052", + "reggae.00053", + "reggae.00064", + "reggae.00065", + "reggae.00066", + "reggae.00067", + "reggae.00068", + "reggae.00071", + "reggae.00079", + "reggae.00082", + "reggae.00083", + "reggae.00084", + "reggae.00087", + "reggae.00088", + "reggae.00089", + "reggae.00090", + "rock.00010", + "rock.00011", + "rock.00012", + "rock.00013", + "rock.00014", + "rock.00015", + "rock.00027", + "rock.00028", + "rock.00029", + "rock.00030", + "rock.00031", + "rock.00032", + "rock.00033", + "rock.00034", + "rock.00035", + "rock.00036", + "rock.00037", + "rock.00039", + "rock.00040", + "rock.00041", + "rock.00042", + "rock.00043", + "rock.00044", + "rock.00045", + "rock.00046", + "rock.00047", + "rock.00048", + "rock.00086", + "rock.00087", + "rock.00088", + "rock.00089", + "rock.00090", +] + +filtered_train = [ + "blues.00029", + "blues.00030", + "blues.00031", + "blues.00032", + "blues.00033", + "blues.00034", + "blues.00035", + "blues.00036", + "blues.00037", + "blues.00038", + "blues.00039", + "blues.00040", + "blues.00041", + "blues.00042", + "blues.00043", + "blues.00044", + "blues.00045", + "blues.00046", + "blues.00047", + "blues.00048", + "blues.00049", + "blues.00073", + "blues.00074", + "blues.00075", + "blues.00076", + "blues.00077", + "blues.00078", + "blues.00079", + "blues.00080", + "blues.00081", + "blues.00082", + "blues.00083", + "blues.00084", + "blues.00085", + "blues.00086", + "blues.00087", + "blues.00088", + "blues.00089", + "blues.00090", + "blues.00091", + "blues.00092", + "blues.00093", + "blues.00094", + "blues.00095", + "blues.00096", + "blues.00097", + "classical.00030", + "classical.00031", + "classical.00032", + "classical.00033", + "classical.00043", + "classical.00044", + "classical.00045", + "classical.00046", + "classical.00047", + "classical.00048", + "classical.00050", + "classical.00051", + "classical.00052", + "classical.00053", + "classical.00054", + "classical.00055", + "classical.00056", + "classical.00057", + "classical.00058", + "classical.00059", + "classical.00060", + "classical.00061", + "classical.00062", + "classical.00063", + "classical.00064", + "classical.00065", + "classical.00066", + "classical.00067", + "classical.00080", + "classical.00081", + "classical.00082", + "classical.00083", + "classical.00084", + "classical.00085", + "classical.00086", + "classical.00087", + "classical.00088", + "classical.00089", + "classical.00090", + "classical.00091", + "classical.00092", + "classical.00093", + "classical.00094", + "classical.00095", + "classical.00096", + "classical.00097", + "classical.00098", + "classical.00099", + "country.00019", + "country.00020", + "country.00021", + "country.00022", + "country.00023", + "country.00024", + "country.00025", + "country.00026", + "country.00028", + "country.00029", + "country.00065", + "country.00066", + "country.00067", + "country.00068", + "country.00069", + "country.00070", + "country.00071", + "country.00072", + "country.00073", + "country.00074", + "country.00075", + "country.00076", + "country.00077", + "country.00078", + "country.00079", + "country.00080", + "country.00081", + "country.00082", + "country.00083", + "country.00084", + "country.00085", + "country.00086", + "country.00087", + "country.00088", + "country.00089", + "country.00090", + "country.00091", + "country.00092", + "country.00093", + "country.00094", + "country.00095", + "country.00096", + "country.00097", + "country.00098", + "country.00099", + "disco.00005", + "disco.00015", + "disco.00016", + "disco.00017", + "disco.00018", + "disco.00019", + "disco.00020", + "disco.00022", + "disco.00023", + "disco.00024", + "disco.00025", + "disco.00026", + "disco.00027", + "disco.00028", + "disco.00029", + "disco.00030", + "disco.00031", + "disco.00032", + "disco.00033", + "disco.00034", + "disco.00035", + "disco.00036", + "disco.00037", + "disco.00039", + "disco.00040", + "disco.00041", + "disco.00042", + "disco.00043", + "disco.00044", + "disco.00045", + "disco.00047", + "disco.00049", + "disco.00053", + "disco.00054", + "disco.00056", + "disco.00057", + "disco.00059", + "disco.00061", + "disco.00070", + "disco.00073", + "disco.00074", + "disco.00089", + "hiphop.00002", + "hiphop.00003", + "hiphop.00004", + "hiphop.00005", + "hiphop.00006", + "hiphop.00007", + "hiphop.00008", + "hiphop.00009", + "hiphop.00010", + "hiphop.00011", + "hiphop.00012", + "hiphop.00013", + "hiphop.00014", + "hiphop.00015", + "hiphop.00016", + "hiphop.00017", + "hiphop.00018", + "hiphop.00019", + "hiphop.00020", + "hiphop.00021", + "hiphop.00022", + "hiphop.00023", + "hiphop.00024", + "hiphop.00025", + "hiphop.00028", + "hiphop.00029", + "hiphop.00031", + "hiphop.00032", + "hiphop.00033", + "hiphop.00034", + "hiphop.00035", + "hiphop.00036", + "hiphop.00037", + "hiphop.00038", + "hiphop.00041", + "hiphop.00042", + "hiphop.00055", + "hiphop.00056", + "hiphop.00057", + "hiphop.00058", + "hiphop.00059", + "hiphop.00060", + "hiphop.00061", + "hiphop.00077", + "hiphop.00078", + "hiphop.00079", + "hiphop.00080", + "jazz.00000", + "jazz.00001", + "jazz.00011", + "jazz.00012", + "jazz.00013", + "jazz.00014", + "jazz.00015", + "jazz.00016", + "jazz.00017", + "jazz.00018", + "jazz.00019", + "jazz.00020", + "jazz.00021", + "jazz.00022", + "jazz.00023", + "jazz.00024", + "jazz.00041", + "jazz.00047", + "jazz.00048", + "jazz.00049", + "jazz.00050", + "jazz.00051", + "jazz.00052", + "jazz.00053", + "jazz.00054", + "jazz.00055", + "jazz.00056", + "jazz.00057", + "jazz.00058", + "jazz.00059", + "jazz.00060", + "jazz.00061", + "jazz.00062", + "jazz.00063", + "jazz.00064", + "jazz.00065", + "jazz.00066", + "jazz.00067", + "jazz.00068", + "jazz.00069", + "jazz.00070", + "jazz.00071", + "jazz.00072", + "metal.00002", + "metal.00003", + "metal.00005", + "metal.00021", + "metal.00024", + "metal.00035", + "metal.00046", + "metal.00047", + "metal.00048", + "metal.00049", + "metal.00050", + "metal.00051", + "metal.00052", + "metal.00053", + "metal.00054", + "metal.00055", + "metal.00056", + "metal.00057", + "metal.00059", + "metal.00060", + "metal.00061", + "metal.00062", + "metal.00063", + "metal.00064", + "metal.00065", + "metal.00066", + "metal.00069", + "metal.00071", + "metal.00072", + "metal.00079", + "metal.00080", + "metal.00084", + "metal.00086", + "metal.00089", + "metal.00090", + "metal.00091", + "metal.00092", + "metal.00093", + "metal.00094", + "metal.00095", + "metal.00096", + "metal.00097", + "metal.00098", + "metal.00099", + "pop.00002", + "pop.00003", + "pop.00004", + "pop.00005", + "pop.00006", + "pop.00007", + "pop.00008", + "pop.00009", + "pop.00011", + "pop.00012", + "pop.00016", + "pop.00017", + "pop.00018", + "pop.00019", + "pop.00020", + "pop.00023", + "pop.00024", + "pop.00025", + "pop.00026", + "pop.00027", + "pop.00028", + "pop.00029", + "pop.00031", + "pop.00032", + "pop.00033", + "pop.00034", + "pop.00035", + "pop.00036", + "pop.00038", + "pop.00039", + "pop.00040", + "pop.00041", + "pop.00042", + "pop.00044", + "pop.00046", + "pop.00049", + "pop.00050", + "pop.00080", + "pop.00097", + "pop.00098", + "pop.00099", + "reggae.00000", + "reggae.00001", + "reggae.00002", + "reggae.00004", + "reggae.00006", + "reggae.00009", + "reggae.00011", + "reggae.00012", + "reggae.00014", + "reggae.00015", + "reggae.00016", + "reggae.00017", + "reggae.00018", + "reggae.00019", + "reggae.00020", + "reggae.00021", + "reggae.00022", + "reggae.00023", + "reggae.00024", + "reggae.00025", + "reggae.00026", + "reggae.00027", + "reggae.00028", + "reggae.00029", + "reggae.00030", + "reggae.00031", + "reggae.00032", + "reggae.00042", + "reggae.00043", + "reggae.00044", + "reggae.00045", + "reggae.00049", + "reggae.00050", + "reggae.00051", + "reggae.00054", + "reggae.00055", + "reggae.00056", + "reggae.00057", + "reggae.00058", + "reggae.00059", + "reggae.00060", + "reggae.00063", + "reggae.00069", + "rock.00000", + "rock.00001", + "rock.00002", + "rock.00003", + "rock.00004", + "rock.00005", + "rock.00006", + "rock.00007", + "rock.00008", + "rock.00009", + "rock.00016", + "rock.00017", + "rock.00018", + "rock.00019", + "rock.00020", + "rock.00021", + "rock.00022", + "rock.00023", + "rock.00024", + "rock.00025", + "rock.00026", + "rock.00057", + "rock.00058", + "rock.00059", + "rock.00060", + "rock.00061", + "rock.00062", + "rock.00063", + "rock.00064", + "rock.00065", + "rock.00066", + "rock.00067", + "rock.00068", + "rock.00069", + "rock.00070", + "rock.00091", + "rock.00092", + "rock.00093", + "rock.00094", + "rock.00095", + "rock.00096", + "rock.00097", + "rock.00098", + "rock.00099", +] + +filtered_valid = [ + "blues.00000", + "blues.00001", + "blues.00002", + "blues.00003", + "blues.00004", + "blues.00005", + "blues.00006", + "blues.00007", + "blues.00008", + "blues.00009", + "blues.00010", + "blues.00011", + "blues.00050", + "blues.00051", + "blues.00052", + "blues.00053", + "blues.00054", + "blues.00055", + "blues.00056", + "blues.00057", + "blues.00058", + "blues.00059", + "blues.00060", + "classical.00000", + "classical.00001", + "classical.00002", + "classical.00003", + "classical.00004", + "classical.00005", + "classical.00006", + "classical.00007", + "classical.00008", + "classical.00009", + "classical.00010", + "classical.00068", + "classical.00069", + "classical.00070", + "classical.00071", + "classical.00072", + "classical.00073", + "classical.00074", + "classical.00075", + "classical.00076", + "country.00000", + "country.00001", + "country.00002", + "country.00003", + "country.00004", + "country.00005", + "country.00006", + "country.00007", + "country.00009", + "country.00010", + "country.00011", + "country.00012", + "country.00013", + "country.00014", + "country.00015", + "country.00016", + "country.00017", + "country.00018", + "country.00027", + "country.00041", + "country.00042", + "country.00045", + "country.00049", + "disco.00000", + "disco.00002", + "disco.00003", + "disco.00004", + "disco.00006", + "disco.00007", + "disco.00008", + "disco.00009", + "disco.00010", + "disco.00011", + "disco.00012", + "disco.00013", + "disco.00014", + "disco.00046", + "disco.00048", + "disco.00052", + "disco.00067", + "disco.00068", + "disco.00072", + "disco.00075", + "disco.00090", + "disco.00095", + "hiphop.00081", + "hiphop.00082", + "hiphop.00083", + "hiphop.00084", + "hiphop.00085", + "hiphop.00086", + "hiphop.00087", + "hiphop.00088", + "hiphop.00089", + "hiphop.00090", + "hiphop.00091", + "hiphop.00092", + "hiphop.00093", + "hiphop.00094", + "hiphop.00095", + "hiphop.00096", + "hiphop.00097", + "hiphop.00098", + "jazz.00002", + "jazz.00003", + "jazz.00004", + "jazz.00005", + "jazz.00006", + "jazz.00007", + "jazz.00008", + "jazz.00009", + "jazz.00010", + "jazz.00025", + "jazz.00026", + "jazz.00027", + "jazz.00028", + "jazz.00029", + "jazz.00030", + "jazz.00031", + "jazz.00032", + "metal.00000", + "metal.00001", + "metal.00006", + "metal.00007", + "metal.00008", + "metal.00009", + "metal.00010", + "metal.00011", + "metal.00016", + "metal.00017", + "metal.00018", + "metal.00019", + "metal.00020", + "metal.00036", + "metal.00037", + "metal.00068", + "metal.00076", + "metal.00077", + "metal.00081", + "metal.00082", + "pop.00010", + "pop.00053", + "pop.00055", + "pop.00058", + "pop.00059", + "pop.00060", + "pop.00061", + "pop.00062", + "pop.00081", + "pop.00083", + "pop.00084", + "pop.00085", + "pop.00086", + "reggae.00061", + "reggae.00062", + "reggae.00070", + "reggae.00072", + "reggae.00074", + "reggae.00076", + "reggae.00077", + "reggae.00078", + "reggae.00085", + "reggae.00092", + "reggae.00093", + "reggae.00094", + "reggae.00095", + "reggae.00096", + "reggae.00097", + "reggae.00098", + "reggae.00099", + "rock.00038", + "rock.00049", + "rock.00050", + "rock.00051", + "rock.00052", + "rock.00053", + "rock.00054", + "rock.00055", + "rock.00056", + "rock.00071", + "rock.00072", + "rock.00073", + "rock.00074", + "rock.00075", + "rock.00076", + "rock.00077", + "rock.00078", + "rock.00079", + "rock.00080", + "rock.00081", + "rock.00082", + "rock.00083", + "rock.00084", + "rock.00085", +] + + +URL = "http://opihi.cs.uvic.ca/sound/genres.tar.gz" +FOLDER_IN_ARCHIVE = "genres" +_CHECKSUMS = {"http://opihi.cs.uvic.ca/sound/genres.tar.gz": "5b3d6dddb579ab49814ab86dba69e7c7"} + + +def load_gtzan_item(fileid: str, path: str, ext_audio: str) -> Tuple[Tensor, str]: + """ + Loads a file from the dataset and returns the raw waveform + as a Torch Tensor, its sample rate as an integer, and its + genre as a string. + """ + # Filenames are of the form label.id, e.g. blues.00078 + label, _ = fileid.split(".") + + # Read wav + file_audio = os.path.join(path, label, fileid + ext_audio) + waveform, sample_rate = torchaudio.load(file_audio) + + return waveform, sample_rate, label + + +class GTZAN(Dataset): + """ + Create a Dataset for GTZAN. Each item is a tuple of the form: + waveform, sample_rate, label. + + Please see http://marsyas.info/downloads/datasets.html + if you are planning to use this dataset to publish results. + """ + + _ext_audio = ".wav" + + def __init__( + self, + root: str, + url: str = URL, + folder_in_archive: str = FOLDER_IN_ARCHIVE, + download: bool = False, + subset: Any = None, + ) -> None: + + # super(GTZAN, self).__init__() + self.root = root + self.url = url + self.folder_in_archive = folder_in_archive + self.download = download + self.subset = subset + + assert subset is None or subset in ["training", "validation", "testing"], ( + "When `subset` not None, it must take a value from " + + "{'training', 'validation', 'testing'}." + ) + + archive = os.path.basename(url) + archive = os.path.join(root, archive) + self._path = os.path.join(root, folder_in_archive) + + if download: + if not os.path.isdir(self._path): + if not os.path.isfile(archive): + checksum = _CHECKSUMS.get(url, None) + download_url(url, root, hash_value=checksum, hash_type="md5") + extract_archive(archive) + + if not os.path.isdir(self._path): + raise RuntimeError( + "Dataset not found. Please use `download=True` to download it." + ) + + if self.subset is None: + walker = walk_files( + self._path, suffix=self._ext_audio, prefix=False, remove_suffix=True + ) + self._walker = list(walker) + else: + if self.subset == "training": + self._walker = filtered_train + elif self.subset == "validation": + self._walker = filtered_valid + elif self.subset == "testing": + self._walker = filtered_test + + def __getitem__(self, n: int) -> Tuple[Tensor, int, str]: + fileid = self._walker[n] + item = load_gtzan_item(fileid, self._path, self._ext_audio) + waveform, sample_rate, label = item + return waveform, sample_rate, label + + def __len__(self) -> int: + return len(self._walker)