Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
102 changes: 101 additions & 1 deletion torchaudio/datasets/yesno.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,16 @@
import os
from pathlib import Path
from dataclasses import dataclass
from typing import List, Tuple, Union

from torch import Tensor
from torch.utils.data import Dataset
from torch.utils.data import Dataset, IterDataPipe

import torchaudio
from torchaudio.datasets.utils import (
download_url,
extract_archive,

)


Expand All @@ -17,10 +19,108 @@
"folder_in_archive": "waves_yesno",
"url": "http://www.openslr.org/resources/1/waves_yesno.tar.gz",
"checksum": "30301975fd8c5cac4040c261c0852f57cfa8adbbad2ce78e77e4986957445f27",
"files": [
"0_0_0_0_1_1_1_1.wav",
"0_0_0_1_0_0_0_1.wav",
"0_0_0_1_0_1_1_0.wav",
"0_0_1_0_0_0_1_0.wav",
"0_0_1_0_0_1_1_0.wav",
"0_0_1_0_0_1_1_1.wav",
"0_0_1_0_1_0_0_0.wav",
"0_0_1_0_1_0_0_1.wav",
"0_0_1_0_1_0_1_1.wav",
"0_0_1_1_0_0_0_1.wav",
"0_0_1_1_0_1_0_0.wav",
"0_0_1_1_0_1_1_0.wav",
"0_0_1_1_0_1_1_1.wav",
"0_0_1_1_1_0_0_0.wav",
"0_0_1_1_1_0_0_1.wav",
"0_0_1_1_1_1_0_0.wav",
"0_0_1_1_1_1_1_0.wav",
"0_1_0_0_0_1_0_0.wav",
"0_1_0_0_0_1_1_0.wav",
"0_1_0_0_1_0_1_0.wav",
"0_1_0_0_1_0_1_1.wav",
"0_1_0_1_0_0_0_0.wav",
"0_1_0_1_1_0_1_0.wav",
"0_1_0_1_1_1_0_0.wav",
"0_1_1_0_0_1_1_0.wav",
"0_1_1_0_0_1_1_1.wav",
"0_1_1_1_0_0_0_0.wav",
"0_1_1_1_0_0_1_0.wav",
"0_1_1_1_0_1_0_1.wav",
"0_1_1_1_1_0_1_0.wav",
"0_1_1_1_1_1_1_1.wav",
"1_0_0_0_0_0_0_0.wav",
"1_0_0_0_0_0_0_1.wav",
"1_0_0_0_0_0_1_1.wav",
"1_0_0_0_1_0_0_1.wav",
"1_0_0_1_0_1_1_1.wav",
"1_0_1_0_1_0_0_1.wav",
"1_0_1_1_0_1_1_1.wav",
"1_0_1_1_1_0_1_0.wav",
"1_0_1_1_1_1_0_1.wav",
"1_1_0_0_0_0_0_1.wav",
"1_1_0_0_0_1_1_1.wav",
"1_1_0_0_1_0_1_0.wav",
"1_1_0_0_1_0_1_1.wav",
"1_1_0_0_1_1_1_0.wav",
"1_1_0_1_0_1_0_0.wav",
"1_1_0_1_0_1_1_0.wav",
"1_1_0_1_1_0_0_1.wav",
"1_1_0_1_1_0_1_1.wav",
"1_1_0_1_1_1_1_0.wav",
"1_1_1_0_0_0_0_1.wav",
"1_1_1_0_0_1_0_1.wav",
"1_1_1_0_0_1_1_1.wav",
"1_1_1_0_1_0_1_0.wav",
"1_1_1_0_1_0_1_1.wav",
"1_1_1_1_0_0_1_0.wav",
"1_1_1_1_0_1_0_0.wav",
"1_1_1_1_1_0_0_0.wav",
"1_1_1_1_1_1_0_0.wav",
"1_1_1_1_1_1_1_1.wav",
]
}
}


@dataclass
class YesNoItem:
path: str
label: List[int]
waveform: Tensor
sample_rate: int


class ListYesNoItems(IterDataPipe):
"""Given a root directory, return the list of files"""
def __init__(self, root):
self.data_dir = os.path.join(root, 'waves_yesno')
self.files = _RELEASE_CONFIGS['release1']['files']

def __iter__(self):
for filename in self.files:
path = os.path.join(self.data_dir, filename)
label = [int(c) for c in path.split("_")]
yield path, label
Comment on lines +102 to +106
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you use ListDirFiles to generate a filename per iteration.



class LoadYesNoItem(IterDataPipe):
def __init__(self, data_pipe):
self.data_pipe = data_pipe

def __iter__(self):
for path, label in self.data_pipe:
waveform, sample_rate = torchaudio.load(path)
yield YesNoItem(path, label, waveform, sample_rate)
Comment on lines +109 to +116
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You can use Map to apply a function to each item in the pipeline.



def get_yesno_dataset(root_dir: str): # , download=False, url=_RELEASE_CONFIGS["release1"]["url"]):
# TODO: download dataset if necessary
return LoadYesNoItem(ListYesNoItems(root_dir))
Comment on lines +120 to +121
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You can use existing DataPipes

from torch.utils.data import datapipes

dp = datapipes.iter. ListDirFiles(root_dir)
dp = datapipes.iter.Map(dp, fn=loading_fn)
return dp

Or, you use the functional API (I prefer this way)

return ListDirFiles(root_dir).map(fn=loading_fn)

where loading_fn converts file to YesNoItem.



class YESNO(Dataset):
"""Create a Dataset for YesNo.
Expand Down