diff --git a/test/assets/genres/blues/blues.00000.wav b/test/assets/genres/blues/blues.00000.wav new file mode 100644 index 0000000000..cf1cbcde22 Binary files /dev/null and b/test/assets/genres/blues/blues.00000.wav differ diff --git a/torchaudio/datasets/gtzan.py b/torchaudio/datasets/gtzan.py index 13d7e8b68b..9098cf1fe0 100644 --- a/torchaudio/datasets/gtzan.py +++ b/torchaudio/datasets/gtzan.py @@ -8,7 +8,6 @@ from torchaudio.datasets.utils import ( download_url, extract_archive, - walk_files, ) # The following lists prefixed with `filtered_` provide a filtered split @@ -22,6 +21,19 @@ # Those are used when GTZAN is initialised with the `filtered` keyword. # The split was taken from (github) jordipons/sklearn-audio-transfer-learning. +gtzan_genres = [ + "blues", + "classical", + "country", + "disco", + "hiphop", + "jazz", + "metal", + "pop", + "reggae", + "rock", +] + filtered_test = [ "blues.00012", "blues.00013", @@ -964,7 +976,9 @@ URL = "http://opihi.cs.uvic.ca/sound/genres.tar.gz" FOLDER_IN_ARCHIVE = "genres" -_CHECKSUMS = {"http://opihi.cs.uvic.ca/sound/genres.tar.gz": "5b3d6dddb579ab49814ab86dba69e7c7"} +_CHECKSUMS = { + "http://opihi.cs.uvic.ca/sound/genres.tar.gz": "5b3d6dddb579ab49814ab86dba69e7c7" +} def load_gtzan_item(fileid: str, path: str, ext_audio: str) -> Tuple[Tensor, str]: @@ -1032,10 +1046,32 @@ def __init__( ) if self.subset is None: - walker = walk_files( - self._path, suffix=self._ext_audio, prefix=False, remove_suffix=True - ) - self._walker = list(walker) + # Check every subdirectory under dataset root + # which has the same name as the genres in + # GTZAN (e.g. `root_dir'/blues/, `root_dir'/rock, etc.) + # This lets users remove or move around song files, + # useful when e.g. they want to use only some of the files + # in a genre or want to label other files with a different + # genre. + self._walker = [] + + root = os.path.expanduser(self._path) + + for directory in gtzan_genres: + fulldir = os.path.join(root, directory) + + if not os.path.exists(fulldir): + continue + + songs_in_genre = os.listdir(fulldir) + for fname in songs_in_genre: + name, ext = os.path.splitext(fname) + if ext.lower() == ".wav" and "." in name: + # Check whether the file is of the form + # `gtzan_genre`.`5 digit number`.wav + genre, num = name.split(".") + if genre in gtzan_genres and len(num) == 5 and num.isdigit(): + self._walker.append(name) else: if self.subset == "training": self._walker = filtered_train