Remove data, metric and common to neural_compressor #244
Merged
Changes from all commits (39 commits):
All 39 commits are authored by changwangss:

f55987f  refine class name
c69b9d1  add common folder
d2389da  add data folder
66ddd2a  rename DATASETS
61982c0  append dataset and transform feature
86f8188  fix pylint and pydocstyle
bc016ca  fix docstring
ad5d78e  replace UT DATASETS
9db4958  remove common
d01f5c0  rebase master
9d2e61b  fix path
a98c7e7  rebase model/model.py
f6dabae  add no cover
c27df84  remove import for tensorflow_model.py
5179326  fix MODELS
6e06124  fix ut issue about experimental data and data
1970aab  Merge branch 'master' into wangchang/api
c3ce078  refine class name
cb64192  add common folder
b282506  add data folder
0ff55f8  rename DATASETS
7c7e7d9  append dataset and transform feature
13b90f5  fix pylint and pydocstyle
ad154d3  fix docstring
4dcc749  replace UT DATASETS
02b41d5  remove common
f9f497c  rebase master
f64ce3a  fix path
df3203f  rebase model/model.py
78e01ae  add no cover
7b60dc2  remove import for tensorflow_model.py
9dbc7be  fix MODELS
bbfc9e4  fix ut issue about experimental data and data
dabb103  rebase master
a9d67bb  fix ut
5219a12  fix conflict
6273157  Merge branch 'master' into wangchang/api
c33cc2a  fix name model to tensorflow_model
d8890bc  fix import name
New file (119 additions, 0 deletions):

#!/usr/bin/env python
# -*- coding: utf-8 -*-
#
# Copyright (c) 2021 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""BaseDataLoader of all dataloaders."""

from abc import abstractmethod


class BaseDataLoader:  # pragma: no cover
    """Base class for all DataLoaders.

    _generate_dataloader is needed to create a dataloader object
    from the general params like batch_size and sampler. Dynamic batching
    simply generates a new dataloader by resetting batch_size and last_batch.
    """

    def __init__(self, dataset, batch_size=1, last_batch='rollover', collate_fn=None,
                 sampler=None, batch_sampler=None, num_workers=0, pin_memory=False,
                 shuffle=False, distributed=False):
        """Initialize BaseDataLoader.

        Args:
            dataset (object): dataset from which to load the data.
            batch_size (int, optional): number of samples per batch. Defaults to 1.
            last_batch (str, optional): how to handle an incomplete final batch.
                One of ['rollover', 'discard']; 'rollover' keeps the incomplete
                last batch, 'discard' drops it. Defaults to 'rollover'.
            collate_fn (callable, optional): merges data with outer dimension batch size. Defaults to None.
            sampler (Sampler, optional): Sampler object to sample data. Defaults to None.
            batch_sampler (BatchSampler, optional): BatchSampler object to generate batches of indices. Defaults to None.
            num_workers (int, optional): number of subprocesses to use for data loading. Defaults to 0.
            pin_memory (bool, optional): whether to copy data into pinned memory before returning. Defaults to False.
            shuffle (bool, optional): whether to shuffle the data. Defaults to False.
            distributed (bool, optional): whether the dataloader is distributed. Defaults to False.
        """
        self.dataset = dataset
        self.collate_fn = collate_fn
        self.sampler = sampler
        self.batch_sampler = batch_sampler
        self.num_workers = num_workers
        self.pin_memory = pin_memory
        self._batch_size = batch_size
        self.shuffle = shuffle
        self.distributed = distributed
        self.last_batch = last_batch
        self.drop_last = last_batch != 'rollover'

        self.dataloader = self._generate_dataloader(
            self.dataset,
            batch_size=batch_size,
            last_batch=last_batch,
            collate_fn=collate_fn,
            sampler=sampler,
            batch_sampler=batch_sampler,
            num_workers=num_workers,
            pin_memory=pin_memory,
            shuffle=shuffle,
            distributed=distributed)

    def batch(self, batch_size, last_batch=None):
        """Set a new batch size and regenerate the dataloader.

        Args:
            batch_size (int): number of samples per batch.
            last_batch (str, optional): how to handle an incomplete final batch.
                One of ['rollover', 'discard']. Defaults to None, which keeps
                the current setting.
        """
        self._batch_size = batch_size
        if last_batch is not None:
            self.last_batch = last_batch
        self.dataloader = self._generate_dataloader(
            self.dataset,
            batch_size,
            self.last_batch,
            self.collate_fn,
            self.sampler,
            self.batch_sampler,
            self.num_workers,
            self.pin_memory,
            self.shuffle,
            self.distributed)

    @property
    def batch_size(self):
        """Get the dataloader's batch_size.

        Returns:
            int: batch_size
        """
        return self._batch_size

    def __iter__(self):
        """Yield data in iterative order.

        Returns:
            iterator: iterator for the dataloader
        """
        return iter(self.dataloader)

    @abstractmethod
    def _generate_dataloader(self, dataset, batch_size, last_batch, collate_fn, sampler,
                             batch_sampler, num_workers, pin_memory, shuffle, distributed):
        raise NotImplementedError
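The abstract `_generate_dataloader` hook is the extension point: the base class owns the configuration and dynamic batching, while subclasses decide how batches are actually produced. A minimal, self-contained sketch of that template-method shape (a standalone mirror for illustration, not the library's actual classes; `SketchBaseDataLoader` and `ListDataLoader` are hypothetical names):

```python
from abc import ABC, abstractmethod

class SketchBaseDataLoader(ABC):
    # Trimmed mirror of BaseDataLoader's shape; the real class also
    # carries sampler/num_workers/pin_memory/distributed plumbing.
    def __init__(self, dataset, batch_size=1, last_batch='rollover'):
        self.dataset = dataset
        self._batch_size = batch_size
        self.last_batch = last_batch
        self.drop_last = last_batch != 'rollover'
        # The base class builds the concrete loader via the subclass hook.
        self.dataloader = self._generate_dataloader(dataset, batch_size, last_batch)

    @property
    def batch_size(self):
        return self._batch_size

    def __iter__(self):
        return iter(self.dataloader)

    @abstractmethod
    def _generate_dataloader(self, dataset, batch_size, last_batch):
        raise NotImplementedError

class ListDataLoader(SketchBaseDataLoader):
    """Toy subclass: batches a plain Python list eagerly."""
    def _generate_dataloader(self, dataset, batch_size, last_batch):
        batches = [dataset[i:i + batch_size]
                   for i in range(0, len(dataset), batch_size)]
        if self.drop_last and batches and len(batches[-1]) < batch_size:
            batches.pop()  # 'discard': drop the incomplete final batch
        return batches

print(list(ListDataLoader(list(range(5)), batch_size=2)))
# 'rollover' keeps the short final batch: [[0, 1], [2, 3], [4]]
```

With `last_batch='discard'` the same five samples yield only the two full batches, which is exactly the `drop_last` distinction the docstring describes.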
neural_compressor/data/dataloaders/default_dataloader.py (143 additions, 0 deletions):
#!/usr/bin/env python
# -*- coding: utf-8 -*-
#
# Copyright (c) 2021 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Default dataloader for multiple framework backends."""

import collections
from math import ceil, floor

import numpy as np

from .base_dataloader import BaseDataLoader
from .fetcher import FETCHERS
from .sampler import BatchSampler, IterableSampler, SequentialSampler


def default_collate(batch):  # pragma: no cover
    """Merge data with outer dimension batch size."""
    elem = batch[0]
    if isinstance(elem, collections.abc.Mapping):
        return {key: default_collate([d[key] for d in batch]) for key in elem}
    elif isinstance(elem, collections.abc.Sequence):
        batch = zip(*batch)
        return [default_collate(samples) for samples in batch]
    elif isinstance(elem, np.ndarray):
        try:
            return np.stack(batch)
        except Exception:
            # Ragged arrays cannot be stacked; return them unmerged.
            return batch
    else:
        return batch


class DefaultDataLoader(BaseDataLoader):  # pragma: no cover
    """DefaultDataLoader for multiple framework backends."""

    def __init__(self, dataset, batch_size=1, last_batch='rollover', collate_fn=None,
                 sampler=None, batch_sampler=None, num_workers=0, pin_memory=False,
                 shuffle=False, distributed=False):
        """Initialize DefaultDataLoader.

        Args:
            dataset (object): dataset from which to load the data.
            batch_size (int, optional): number of samples per batch. Defaults to 1.
            last_batch (str, optional): how to handle an incomplete final batch.
                One of ['rollover', 'discard']; 'rollover' keeps the incomplete
                last batch, 'discard' drops it. Defaults to 'rollover'.
            collate_fn (callable, optional): merges data with outer dimension batch size. Defaults to None.
            sampler (Sampler, optional): Sampler object to sample data. Defaults to None.
            batch_sampler (BatchSampler, optional): BatchSampler object to generate batches of indices. Defaults to None.
            num_workers (int, optional): number of subprocesses to use for data loading. Defaults to 0.
            pin_memory (bool, optional): whether to copy data into pinned memory before returning. Defaults to False.
            shuffle (bool, optional): whether to shuffle the data. Defaults to False.
            distributed (bool, optional): whether the dataloader is distributed. Defaults to False.
        """
        self.dataset = dataset
        self.last_batch = last_batch
        self.sampler = sampler
        self.batch_sampler = batch_sampler
        self.num_workers = num_workers
        self.pin_memory = pin_memory
        self.collate_fn = collate_fn
        self._batch_size = batch_size
        self.shuffle = shuffle
        self.distributed = distributed
        self.drop_last = last_batch != 'rollover'
        if self.collate_fn is None:
            self.collate_fn = default_collate

    def batch(self, batch_size, last_batch='rollover'):
        """Set batch_size and last_batch."""
        self._batch_size = batch_size
        self.last_batch = last_batch

    @property
    def dataloader(self):
        """Return the dataloader; this class is its own iterable."""
        return self

    def __iter__(self):
        """Yield data in iterative order."""
        return self._generate_dataloader(
            self.dataset,
            batch_size=self.batch_size,
            last_batch=self.last_batch,
            collate_fn=self.collate_fn,
            sampler=self.sampler,
            batch_sampler=self.batch_sampler,
            num_workers=self.num_workers,
            pin_memory=self.pin_memory,
            shuffle=self.shuffle,
            distributed=self.distributed)

    def __len__(self):
        """Get the number of batches."""
        try:
            dataset_len = len(self.dataset)
        except (AttributeError, TypeError):
            # Fall back to exhausting the iterable to count samples.
            dataset_len = 0
            for _ in self.dataset:
                dataset_len += 1
        except Exception:
            raise ValueError(f"{self.dataset} is invalid and does not support "
                             "calculating the length of its dataloader")
        if not self.drop_last:
            dataloader_len = ceil(dataset_len / self.batch_size)
        else:
            dataloader_len = floor(dataset_len / self.batch_size)
        return dataloader_len

    def _generate_dataloader(self, dataset, batch_size, last_batch, collate_fn, sampler,
                             batch_sampler, num_workers, pin_memory, shuffle, distributed):
        sampler = self._generate_sampler(dataset, distributed)
        self.batch_sampler = BatchSampler(sampler, batch_size, self.drop_last)
        self.fetcher = FETCHERS[self.dataset_type](dataset, collate_fn, self.drop_last, distributed)

        for batched_indices in self.batch_sampler:
            try:
                yield self.fetcher(batched_indices)
            except StopIteration:
                return

    def _generate_sampler(self, dataset, distributed):
        if hasattr(dataset, "__getitem__"):
            self.dataset_type = 'index'
            return SequentialSampler(dataset, distributed)
        elif hasattr(dataset, "__iter__"):
            self.dataset_type = 'iter'
            return IterableSampler(dataset)
        else:
            raise ValueError("dataset type only supports (index, iter)")
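The recursion in `default_collate` is easiest to see on a list of per-sample dicts: mappings are merged key by key, sequences are transposed with `zip`, and leaf values are returned as the batched group. A quick standalone check of that recursion (a trimmed copy for illustration, without the NumPy leaf branch so it runs without numpy installed; `collate_demo` is a hypothetical name):

```python
import collections.abc

def collate_demo(batch):
    # Trimmed copy of default_collate's Mapping/Sequence recursion.
    elem = batch[0]
    if isinstance(elem, collections.abc.Mapping):
        # Merge dicts key by key, batching each key's values.
        return {key: collate_demo([d[key] for d in batch]) for key in elem}
    elif isinstance(elem, collections.abc.Sequence):
        # Transpose: one group per position in the inner sequences.
        return [collate_demo(samples) for samples in zip(*batch)]
    return batch  # leaf values: the group itself is the batch

samples = [{'image': (1, 2), 'label': 0},
           {'image': (3, 4), 'label': 1}]
print(collate_demo(samples))
# {'image': [(1, 3), (2, 4)], 'label': [0, 1]}
```

Note the transpose: the two `image` tuples come back grouped by position, which is the "outer dimension batch size" layout the docstring refers to.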