Skip to content
Binary file added test/assets/genres/blues/blues.00000.wav
Binary file not shown.
48 changes: 42 additions & 6 deletions torchaudio/datasets/gtzan.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
from torchaudio.datasets.utils import (
download_url,
extract_archive,
walk_files,
)

# The following lists prefixed with `filtered_` provide a filtered split
Expand All @@ -22,6 +21,19 @@
# Those are used when GTZAN is initialised with the `filtered` keyword.
# The split was taken from (github) jordipons/sklearn-audio-transfer-learning.

# Names of the ten GTZAN genres. Each name doubles as the name of a genre
# subdirectory looked up under the dataset root and as the prefix of the
# file ids found in the `filtered_` lists (e.g. "blues.00012").
gtzan_genres = [
"blues",
"classical",
"country",
"disco",
"hiphop",
"jazz",
"metal",
"pop",
"reggae",
"rock",
]

filtered_test = [
"blues.00012",
"blues.00013",
Expand Down Expand Up @@ -964,7 +976,9 @@

# URL of the GTZAN archive (a gzipped tarball).
URL = "http://opihi.cs.uvic.ca/sound/genres.tar.gz"
# Presumably the top-level folder the archive unpacks to -- confirm against
# the code that builds the dataset path.
FOLDER_IN_ARCHIVE = "genres"
# Maps the archive URL to its expected checksum (32 hex digits, presumably an
# MD5 digest -- confirm against the download helper that consumes it).
# NOTE(review): the dict is assigned twice below with identical contents
# (one-line form, then the reformatted multi-line form); the second
# assignment is the one that takes effect, so the first is redundant.
_CHECKSUMS = {"http://opihi.cs.uvic.ca/sound/genres.tar.gz": "5b3d6dddb579ab49814ab86dba69e7c7"}
_CHECKSUMS = {
"http://opihi.cs.uvic.ca/sound/genres.tar.gz": "5b3d6dddb579ab49814ab86dba69e7c7"
}


def load_gtzan_item(fileid: str, path: str, ext_audio: str) -> Tuple[Tensor, str]:
Expand Down Expand Up @@ -1032,10 +1046,32 @@ def __init__(
)

if self.subset is None:
walker = walk_files(
self._path, suffix=self._ext_audio, prefix=False, remove_suffix=True
)
self._walker = list(walker)
# Check every subdirectory under the dataset root
# that has the same name as one of the GTZAN genres
# (e.g. `root_dir`/blues/, `root_dir`/rock/, etc.).
# This lets users remove or move around song files,
# which is useful when e.g. they want to use only some of the files
# in a genre or want to label other files with a different
# genre.
self._walker = []

root = os.path.expanduser(self._path)

for directory in gtzan_genres:
fulldir = os.path.join(root, directory)

if not os.path.exists(fulldir):
continue

songs_in_genre = os.listdir(fulldir)
for fname in songs_in_genre:
name, ext = os.path.splitext(fname)
if ext.lower() == ".wav" and "." in name:
# Check whether the file is of the form
# `gtzan_genre`.`5 digit number`.wav
genre, num = name.split(".")
if genre in gtzan_genres and len(num) == 5 and num.isdigit():
self._walker.append(name)
else:
if self.subset == "training":
self._walker = filtered_train
Expand Down