Skip to content

Commit 9552f9a

Browse files
committed
Add dataloader doc_string
Signed-off-by: mengniwa <[email protected]>
1 parent 49560bc commit 9552f9a

File tree

8 files changed

+169
-37
lines changed

8 files changed

+169
-37
lines changed

neural_compressor/experimental/data/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,8 @@
1414
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1515
# See the License for the specific language governing permissions and
1616
# limitations under the License.
17-
17+
# ==============================================================================
18+
"""Built-in dataloaders, datasets, transforms, filters for multiple framework backends."""
1819

1920
from .datasets import DATASETS, Dataset, IterableDataset, dataset_registry
2021
from .transforms import TRANSFORMS, BaseTransform, transform_registry

neural_compressor/experimental/data/dataloaders/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@
1414
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1515
# See the License for the specific language governing permissions and
1616
# limitations under the License.
17+
# ==============================================================================
18+
"""Built-in dataloaders for multiple framework backends."""
1719

1820
from .dataloader import DATALOADERS
1921

neural_compressor/experimental/data/dataloaders/base_dataloader.py

Lines changed: 42 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -14,21 +14,40 @@
1414
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1515
# See the License for the specific language governing permissions and
1616
# limitations under the License.
17+
# ==============================================================================
18+
"""BaseDataLoader of all dataloaders."""
1719

1820
from abc import abstractmethod
1921

2022

21-
class BaseDataLoader(object):
22-
"""Base class for all DataLoaders. _generate_dataloader is needed to create a dataloader object
23-
from the general params like batch_size and sampler. The dynamic batching is just to
24-
generate a new dataloader by setting batch_size and last_batch.
23+
class BaseDataLoader:
24+
"""Base class for all DataLoaders.
2525
26-
"""
26+
_generate_dataloader is needed to create a dataloader object
27+
from the general params like batch_size and sampler. The dynamic batching is just to
28+
generate a new dataloader by setting batch_size and last_batch.
2729
30+
"""
31+
2832
def __init__(self, dataset, batch_size=1, last_batch='rollover', collate_fn=None,
2933
sampler=None, batch_sampler=None, num_workers=0, pin_memory=False,
3034
shuffle=False, distributed=False):
35+
"""Initialize BaseDataLoader.
3136
37+
Args:
38+
dataset (object): dataset from which to load the data
39+
batch_size (int, optional): number of samples per batch. Defaults to 1.
40+
last_batch (str, optional): whether to drop the last batch if it is incomplete.
41+
Support ['rollover', 'discard'], rollover means False, discard means True.
42+
Defaults to 'rollover'.
43+
collate_fn (callable, optional): merge data with outer dimension batch size. Defaults to None.
44+
sampler (Sampler, optional): Sampler object to sample data. Defaults to None.
45+
batch_sampler (BatchSampler, optional): BatchSampler object to generate batch of indices. Defaults to None.
46+
num_workers (int, optional): number of subprocesses to use for data loading. Defaults to 0.
47+
pin_memory (bool, optional): whether to copy data into pinned memory before returning. Defaults to False.
48+
shuffle (bool, optional): whether to shuffle data. Defaults to False.
49+
distributed (bool, optional): whether the dataloader is distributed. Defaults to False.
50+
"""
3251
self.dataset = dataset
3352
self.collate_fn = collate_fn
3453
self.sampler = sampler
@@ -54,6 +73,14 @@ def __init__(self, dataset, batch_size=1, last_batch='rollover', collate_fn=None
5473
distributed=distributed)
5574

5675
def batch(self, batch_size, last_batch=None):
76+
"""Set batch size for dataloader.
77+
78+
Args:
79+
batch_size (int): number of samples per batch.
80+
last_batch (str, optional): whether to drop the last batch if it is incomplete.
81+
Support ['rollover', 'discard'], rollover means False, discard means True.
82+
Defaults to None.
83+
"""
5784
self._batch_size = batch_size
5885
if last_batch is not None:
5986
self.last_batch = last_batch
@@ -71,9 +98,19 @@ def batch(self, batch_size, last_batch=None):
7198

7299
@property
73100
def batch_size(self):
101+
"""Get dataloader's batch_size.
102+
103+
Returns:
104+
int: batch_size
105+
"""
74106
return self._batch_size
75107

76108
def __iter__(self):
109+
"""Yield data in iterative order.
110+
111+
Returns:
112+
iterator: iterator for dataloader
113+
"""
77114
return iter(self.dataloader)
78115

79116
@abstractmethod

neural_compressor/experimental/data/dataloaders/dataloader.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@
1414
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1515
# See the License for the specific language governing permissions and
1616
# limitations under the License.
17+
# ==============================================================================
18+
"""Built-in dataloaders for multiple framework backends."""
1719

1820
from .tensorflow_dataloader import TensorflowDataLoader
1921
from .mxnet_dataloader import MXNetDataLoader

neural_compressor/experimental/data/dataloaders/default_dataloader.py

Lines changed: 25 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@
1414
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1515
# See the License for the specific language governing permissions and
1616
# limitations under the License.
17+
# ==============================================================================
18+
"""Default dataloader for multiple framework backends."""
1719

1820
import collections
1921
import numpy as np
@@ -24,7 +26,7 @@
2426
from .base_dataloader import BaseDataLoader
2527

2628
def default_collate(batch):
27-
"""Puts each data field into a pd frame with outer dimension batch size"""
29+
"""Merge data with outer dimension batch size."""
2830
elem = batch[0]
2931
if isinstance(elem, collections.abc.Mapping):
3032
return {key: default_collate([d[key] for d in batch]) for key in elem}
@@ -40,13 +42,27 @@ def default_collate(batch):
4042
return batch
4143

4244
class DefaultDataLoader(BaseDataLoader):
43-
"""DefaultDataLoader
44-
45-
"""
46-
45+
"""DefaultDataLoader for multiple framework backends."""
46+
4747
def __init__(self, dataset, batch_size=1, last_batch='rollover', collate_fn=None,
4848
sampler=None, batch_sampler=None, num_workers=0, pin_memory=False,
4949
shuffle=False, distributed=False):
50+
"""Initialize DefaultDataLoader.
51+
52+
Args:
53+
dataset (object): dataset from which to load the data
54+
batch_size (int, optional): number of samples per batch. Defaults to 1.
55+
last_batch (str, optional): whether to drop the last batch if it is incomplete.
56+
Support ['rollover', 'discard'], rollover means False, discard means True.
57+
Defaults to 'rollover'.
58+
collate_fn (callable, optional): merge data with outer dimension batch size. Defaults to None.
59+
sampler (Sampler, optional): Sampler object to sample data. Defaults to None.
60+
batch_sampler (BatchSampler, optional): BatchSampler object to generate batch of indices. Defaults to None.
61+
num_workers (int, optional): number of subprocesses to use for data loading. Defaults to 0.
62+
pin_memory (bool, optional): whether to copy data into pinned memory before returning. Defaults to False.
63+
shuffle (bool, optional): whether to shuffle data. Defaults to False.
64+
distributed (bool, optional): whether the dataloader is distributed. Defaults to False.
65+
"""
5066
self.dataset = dataset
5167
self.last_batch = last_batch
5268
self.sampler = sampler
@@ -62,14 +78,17 @@ def __init__(self, dataset, batch_size=1, last_batch='rollover', collate_fn=None
6278
self.collate_fn = default_collate
6379

6480
def batch(self, batch_size, last_batch='rollover'):
81+
"""Set batch_size and last_batch."""
6582
self._batch_size = batch_size
6683
self.last_batch = last_batch
6784

6885
@property
6986
def dataloader(self):
87+
"""Return dataloader."""
7088
return self
7189

7290
def __iter__(self):
91+
"""Yield data in iterative order."""
7392
return self._generate_dataloader(
7493
self.dataset,
7594
batch_size=self.batch_size,
@@ -83,6 +102,7 @@ def __iter__(self):
83102
distributed=self.distributed)
84103

85104
def __len__(self):
105+
"""Get dataset length."""
86106
try:
87107
dataset_len = self.dataset.__len__()
88108
except (AttributeError, TypeError):

neural_compressor/experimental/data/dataloaders/fetcher.py

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,21 +14,49 @@
1414
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1515
# See the License for the specific language governing permissions and
1616
# limitations under the License.
17+
# ==============================================================================
18+
"""Definitions of the methods to fetch data from an iterable-style or list-style dataset."""
1719

1820
from abc import abstractmethod
1921

2022
class Fetcher(object):
23+
"""Base class for different fetchers."""
24+
2125
def __init__(self, dataset, collate_fn, drop_last):
26+
"""Initialize Fetcher.
27+
28+
Args:
29+
dataset (object): dataset object from which to get data
30+
collate_fn (callable): merge data with outer dimension batch size
31+
drop_last (bool): whether to drop the last batch if it is incomplete
32+
"""
2233
self.dataset = dataset
2334
self.collate_fn = collate_fn
2435
self.drop_last = drop_last
2536

2637
@abstractmethod
2738
def __call__(self, batched_indices):
39+
"""Fetch data.
40+
41+
Args:
42+
batched_indices (list): fetch data according to batched_indices
43+
44+
"""
2845
raise NotImplementedError
2946

3047
class IterableFetcher(Fetcher):
48+
"""Iterate to get next batch-size samples as a batch."""
49+
3150
def __init__(self, dataset, collate_fn, drop_last, distributed):
51+
"""Initialize IterableFetcher.
52+
53+
Args:
54+
dataset (object): dataset object from which to get data
55+
collate_fn (callable): merge data with outer dimension batch size
56+
drop_last (bool): whether to drop the last batch if it is incomplete
57+
distributed (bool): whether the dataloader is distributed
58+
59+
"""
3260
super(IterableFetcher, self).__init__(dataset, collate_fn, drop_last)
3361
self.dataset_iter = iter(dataset)
3462
self.index_whole = 0
@@ -47,6 +75,12 @@ def __init__(self, dataset, collate_fn, drop_last, distributed):
4775
" please set 'distributed: True' and launch multiple processes.")
4876

4977
def __call__(self, batched_indices):
78+
"""Fetch data.
79+
80+
Args:
81+
batched_indices (list): fetch data according to batched_indices
82+
83+
"""
5084
batch_data = []
5185
batch_size = len(batched_indices)
5286
while True:
@@ -64,10 +98,26 @@ def __call__(self, batched_indices):
6498
return self.collate_fn(batch_data)
6599

66100
class IndexFetcher(Fetcher):
101+
"""Take single index or a batch of indices to fetch samples as a batch."""
102+
67103
def __init__(self, dataset, collate_fn, drop_last, distributed):
104+
"""Initialize IndexFetcher.
105+
106+
Args:
107+
dataset (object): dataset object from which to get data
108+
collate_fn (callable): merge data with outer dimension batch size
109+
drop_last (bool): whether to drop the last batch if it is incomplete
110+
distributed (bool): whether the dataloader is distributed
111+
"""
68112
super(IndexFetcher, self).__init__(dataset, collate_fn, drop_last)
69113

70114
def __call__(self, batched_indices):
115+
"""Fetch data.
116+
117+
Args:
118+
batched_indices (list): fetch data according to batched_indices
119+
120+
"""
71121
data = [self.dataset[idx] for idx in batched_indices]
72122
return self.collate_fn(data)
73123

neural_compressor/experimental/data/dataloaders/onnxrt_dataloader.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@
1414
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1515
# See the License for the specific language governing permissions and
1616
# limitations under the License.
17+
# ==============================================================================
18+
"""Built-in dataloaders for onnxruntime framework backends."""
1719

1820
from neural_compressor.utils.utility import LazyImport
1921
from .base_dataloader import BaseDataLoader
@@ -23,6 +25,8 @@
2325
torch = LazyImport('torch')
2426

2527
class ONNXRTBertDataLoader(DefaultDataLoader):
28+
"""Built-in dataloader for onnx bert model and its variants."""
29+
2630
def _generate_dataloader(self, dataset, batch_size, last_batch, collate_fn,
2731
sampler, batch_sampler, num_workers, pin_memory,
2832
shuffle, distributed):
@@ -59,6 +63,8 @@ def _generate_dataloader(self, dataset, batch_size, last_batch, collate_fn,
5963
return
6064

6165
class ONNXRTDataLoader(BaseDataLoader):
66+
"""Built-in dataloader for onnxruntime framework backends."""
67+
6268
def _generate_dataloader(self, dataset, batch_size, last_batch, collate_fn,
6369
sampler, batch_sampler, num_workers, pin_memory,
6470
shuffle, distributed):

0 commit comments

Comments
 (0)