
Commit 047f121

PenghuiCheng authored and intel-zhangyi committed
Fixed pruning and distillation bugs and removed invalid code (#251)
Signed-off-by: Cheng, Penghui <[email protected]>
1 parent 162611b commit 047f121

File tree: 10 files changed, +44 −192 lines

docs/source/dataloader.md

Lines changed: 1 addition & 1 deletion
@@ -100,7 +100,7 @@ calib_data = mx.io.ImageRecordIter(path_imgrec=dataset,
                                    ctx=args.ctx,
                                    **combine_mean_std)

-from neural_compressor import Quantization, common
+from neural_compressor.experimental import Quantization, common
 quantizer = Quantization('conf.yaml')
 quantizer.model = fp32_model
 quantizer.calib_dataloader = calib_data
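
For readers following this doc change, a hedged sketch of the full experimental flow; `fp32_model` and `calib_data` come from the surrounding doc example, and the closing `fit()` call follows the pattern this commit uses in its updated tests:

```python
# Minimal sketch, not verbatim from the doc: wires the pieces above together.
from neural_compressor.experimental import Quantization, common

quantizer = Quantization('conf.yaml')      # YAML drives tuning settings
quantizer.model = fp32_model               # framework model to quantize
quantizer.calib_dataloader = calib_data    # the ImageRecordIter built above
q_model = quantizer.fit()                  # returns the tuned low-precision model
```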

docs/source/dataset.md

Lines changed: 1 addition & 1 deletion
@@ -96,7 +96,7 @@ class Dataset(object):
 After defining the dataset class, pass it to the quantizer:

 ```python
-from neural_compressor import Quantization, common
+from neural_compressor.experimental import Quantization, common
 quantizer = Quantization(yaml_file)
 quantizer.calib_dataloader = common.DataLoader(dataset)  # user can pass more optional args to dataloader such as batch_size and collate_fn
 quantizer.model = graph
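
As a companion to the hunk above, a minimal sketch of a dataset class that satisfies the `__getitem__`/`__len__` contract `common.DataLoader` expects; the class name, shapes, zero labels, and `'conf.yaml'` path are illustrative only:

```python
import numpy as np
from neural_compressor.experimental import Quantization, common

class ToyDataset(object):
    """Illustrative calibration dataset yielding (input, label) pairs."""
    def __init__(self, n=100):
        self.samples = [np.random.randn(3, 224, 224).astype(np.float32)
                        for _ in range(n)]
    def __getitem__(self, index):
        return self.samples[index], 0   # label unused for label-free calibration
    def __len__(self):
        return len(self.samples)

quantizer = Quantization('conf.yaml')   # placeholder YAML path
quantizer.calib_dataloader = common.DataLoader(ToyDataset(), batch_size=1)
```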

neural_compressor/__init__.py

Lines changed: 0 additions & 1 deletion
@@ -15,7 +15,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-from .quantization import Quantization
 from .pruning import Pruning
 from .benchmark import Benchmark
 from .version import __version__
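
Since the top-level re-export is gone, existing imports need a one-line change; a sketch of the before/after:

```python
# Before this commit (now raises ImportError):
#   from neural_compressor import Quantization
# After this commit:
from neural_compressor.experimental import Quantization
```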

neural_compressor/experimental/common/criterion.py

Lines changed: 30 additions & 0 deletions
@@ -1454,6 +1454,36 @@ def loss_cal(self, student_outputs):
         self.loss += tmp_loss
         return self.loss

+    def teacher_model_forward(self, input, teacher_model=None, device=None):
+        """Teacher model forward.
+
+        Args:
+            input (tensor): input data
+            teacher_model (torch.nn.Module, optional): teacher model. Defaults to None.
+            device (torch.device, optional): device. Defaults to None.
+
+        Returns:
+            tensor: output
+        """
+        outputs = None
+        if self.loss_weights[1] > 0:
+            model = self.teacher_model if teacher_model is None else teacher_model
+            assert isinstance(model, torch.nn.Module), \
+                'Teacher model should be a torch Module instead of {}'.format(type(model))
+            model.eval()
+            try:
+                model_device = next(model.parameters()).device
+            except StopIteration:
+                logger.warning("Cannot get model device, assuming it's on CPU.")
+                model_device = "cpu"
+            device = model_device if device is None else device
+            if device != model_device:
+                model.to(device)
+            with torch.no_grad():
+                outputs = pytorch_forward_wrapper(model, input, device=device)
+            self.teacher_outputs = outputs
+        return outputs
+

 @criterion_registry('SelfKnowledgeDistillationLoss', 'pytorch')
 class PyTorchSelfKnowledgeDistillationLossWrapper(object):
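
Outside the criterion class, the pattern `teacher_model_forward` implements looks roughly like the sketch below; the helper name and toy modules are illustrative, not part of the commit:

```python
import torch

def run_teacher(teacher: torch.nn.Module, batch, device=None):
    """Illustrative stand-in for the eval/no_grad teacher forward above."""
    teacher.eval()                          # the teacher is never trained
    try:
        model_device = next(teacher.parameters()).device
    except StopIteration:                   # parameter-less module
        model_device = torch.device("cpu")
    device = model_device if device is None else device
    if device != model_device:
        teacher.to(device)
    with torch.no_grad():                   # no gradients through the teacher
        return teacher(batch.to(device))

teacher = torch.nn.Linear(8, 4)
print(run_teacher(teacher, torch.randn(2, 8)).shape)  # torch.Size([2, 4])
```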

neural_compressor/experimental/distillation.py

Lines changed: 3 additions & 1 deletion
@@ -226,7 +226,9 @@ def pre_process(self):
         framework_specific_info = {'device': self.cfg.device,
                                    'random_seed': self.cfg.tuning.random_seed,
                                    'workspace_path': self.cfg.tuning.workspace.path,
-                                   'q_dataloader': None}
+                                   'q_dataloader': None,
+                                   'format': 'default',
+                                   'backend': 'default'}

         if self.framework == 'tensorflow':
             framework_specific_info.update(

neural_compressor/quantization.py

Lines changed: 0 additions & 178 deletions
@@ -15,188 +15,10 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-import os
-from .utils import logger
-from .data import DATALOADERS, DATASETS
 from .experimental import Quantization as ExpQuantization
-from deprecated import deprecated
 from neural_compressor.conf.pythonic_config import Config
 from neural_compressor.config import PostTrainingQuantConfig

-class Quantization(object):
-    """Quantization class automatically searches for optimal quantization recipes for low
-    precision model inference, achieving best tuning objectives like inference performance
-    within accuracy loss constraints.
-
-    Tuner abstracts out the differences of quantization APIs across various DL frameworks
-    and brings a unified API for automatic quantization that works on frameworks including
-    tensorflow, pytorch and mxnet.
-
-    Since DL use cases vary in accuracy metrics (Top-1, MAP, ROC etc.), loss criteria
-    (<1% or <0.1% etc.) and tuning objectives (performance, memory footprint etc.),
-    the Tuner class provides a flexible configuration interface via YAML for users to
-    specify these parameters.
-
-    Args:
-        conf_fname_or_obj (string or obj): The path to the YAML configuration file or
-            a Quantization_Conf class containing accuracy goal, tuning objective and
-            preferred calibration & quantization tuning space etc.
-
-    """
-
-    def __init__(self, conf_fname_or_obj):
-        self.exp_quantizer = ExpQuantization(conf_fname_or_obj)
-
-    @deprecated(version='2.0', reason="please use neural_compressor.quantization.fit instead")
-    def __call__(self, model, q_dataloader=None, q_func=None, eval_dataloader=None,
-                 eval_func=None):
-        """The main entry point of automatic quantization tuning.
-
-        This interface works on all the DL frameworks that neural_compressor supports
-        and provides three usages:
-        a) Fully yaml configuration: User specifies all the info through yaml,
-           including dataloaders used in calibration and evaluation phases
-           and quantization tuning settings.
-
-           For this usage, only the model parameter is mandatory.
-
-        b) Partial yaml configuration: User specifies dataloaders used in calibration
-           and evaluation phase by code.
-           The tool provides built-in dataloaders and evaluators; the user just needs
-           to provide a dataset implementing the __iter__ or __getitem__ methods and
-           invoke dataloader() with the dataset as input parameter to create a
-           neural_compressor dataloader before calling this function.
-
-           After that, the user specifies the fp32 "model", calibration dataset
-           "q_dataloader" and evaluation dataset "eval_dataloader".
-           The calibrated and quantized model is evaluated with "eval_dataloader"
-           with evaluation metrics specified in the configuration file. The evaluation tells
-           the tuner whether the quantized model meets the accuracy criteria. If not,
-           the tuner starts a new calibration and tuning flow.
-
-           For this usage, the model, q_dataloader and eval_dataloader parameters are
-           mandatory.
-
-        c) Partial yaml configuration: User specifies dataloaders used in calibration phase
-           by code.
-           This usage is quite similar to b); the user specifies a custom "eval_func"
-           which encapsulates the evaluation dataset by itself.
-           The calibrated and quantized model is evaluated with "eval_func".
-           The "eval_func" tells the tuner whether the quantized model meets
-           the accuracy criteria. If not, the tuner starts a new calibration and tuning flow.
-
-           For this usage, the model, q_dataloader and eval_func parameters are mandatory.
-
-        Args:
-            model (object): For a TensorFlow model, it could be a path to a frozen pb,
-                a loaded graph_def object or a path to a ckpt/savedmodel folder.
-                For a PyTorch model, it's a torch.nn.Module instance.
-                For an MXNet model, it's an mxnet.symbol.Symbol or gluon.HybridBlock
-                instance.
-            q_dataloader (generator): Data loader for calibration, mandatory for
-                post-training quantization. It is iterable and should yield a tuple
-                (input, label) for a calibration dataset containing labels, or yield
-                (input, _) for a label-free calibration dataset. The input could be
-                an object, list, tuple or dict, depending on user implementation,
-                and it can be taken as model input.
-            q_func (function, optional): Training function for Quantization-Aware
-                Training. It is optional and only takes effect when the user chooses
-                the "quant_aware_training" approach in yaml. This function takes
-                "model" as input parameter and executes the entire training process
-                with self-contained training hyper-parameters. If this parameter is
-                specified, the eval_dataloader parameter plus a metric defined in yaml,
-                or the eval_func parameter, should also be specified at the same time.
-            eval_dataloader (generator, optional): Data loader for evaluation. It is
-                iterable and should yield a tuple of (input, label). The input could
-                be an object, list, tuple or dict, depending on user implementation,
-                and it can be taken as model input. The label should be able to be
-                taken as input of supported metrics. If this parameter is not None,
-                the user needs to specify pre-defined evaluation metrics through the
-                configuration file and should set the "eval_func" parameter as None.
-                The tuner will combine model, eval_dataloader and pre-defined metrics
-                to run the evaluation process.
-            eval_func (function, optional): The evaluation function provided by the
-                user. This function takes the model as a parameter; the evaluation
-                dataset and metrics should be encapsulated in this function
-                implementation, which outputs a higher-is-better accuracy scalar value.
-
-                The pseudo code should be something like:
-
-                def eval_func(model):
-                    input, label = dataloader()
-                    output = model(input)
-                    accuracy = metric(output, label)
-                    return accuracy
-
-        Returns:
-            quantized model: the best quantized model found, otherwise None
-
-        """
-
-        logger.warning("This API is going to be deprecated. Please import "
-                       "neural_compressor.experimental.Quantization, initialize an instance of `Quantization`, "
-                       "set its dataloader and metric attributes, then invoke its __call__ method.")
-
-        self.exp_quantizer.model = model
-        if q_dataloader is not None:
-            self.exp_quantizer.calib_dataloader = q_dataloader
-        elif q_func is not None:
-            self.exp_quantizer.q_func = q_func
-
-        if eval_func is not None:
-            self.exp_quantizer.eval_func = eval_func
-        elif eval_dataloader is not None:
-            self.exp_quantizer.eval_dataloader = eval_dataloader
-
-        nc_model = self.exp_quantizer.fit()
-        if self.exp_quantizer.framework == 'tensorflow':
-            return nc_model.graph if nc_model else None
-        if self.exp_quantizer.framework == 'pytorch':
-            saved_path = os.path.abspath(os.path.join(os.path.expanduser(
-                self.exp_quantizer.conf.usr_cfg.tuning.workspace.path), 'checkpoint'))
-            nc_model.save(saved_path)
-        return nc_model.model
-
-    fit = __call__
-
-    @deprecated(version='2.0', reason="this function has been deprecated")
-    def dataset(self, dataset_type, *args, **kwargs):
-        return DATASETS(self.exp_quantizer.framework)[dataset_type](*args, **kwargs)
-
-    @deprecated(version='2.0', reason="this function has been deprecated")
-    def dataloader(self, dataset, batch_size=1, collate_fn=None, last_batch='rollover',
-                   sampler=None, batch_sampler=None, num_workers=0, pin_memory=False):
-        return DATALOADERS[self.exp_quantizer.framework](
-            dataset=dataset,
-            batch_size=batch_size, collate_fn=collate_fn, last_batch=last_batch,
-            sampler=sampler, batch_sampler=batch_sampler, num_workers=num_workers,
-            pin_memory=pin_memory
-        )
-
-    @deprecated(version='2.0', reason="this function has been deprecated")
-    def metric(self, name, metric_cls, **kwargs):
-        from .experimental.common import Metric as NCMetric
-        nc_metric = NCMetric(metric_cls, name, **kwargs)
-        self.exp_quantizer.metric = nc_metric
-
-    @deprecated(version='2.0', reason="this function has been deprecated")
-    def postprocess(self, name, postprocess_cls, **kwargs):
-        from .experimental.common import Postprocess as NCPostprocess
-        nc_postprocess = NCPostprocess(postprocess_cls, name, **kwargs)
-        self.exp_quantizer.postprocess = nc_postprocess


 def fit(model,
         conf,
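
The deprecation notices in the removed class pointed users at the module-level `fit` that this file keeps; a hedged sketch of that replacement call, with `fp32_model` and `calib_loader` as placeholders:

```python
# Sketch of the surviving entry point; keyword names beyond `model` and
# `conf` mirror common INC 2.x usage and may differ from the exact signature.
from neural_compressor.quantization import fit
from neural_compressor.config import PostTrainingQuantConfig

conf = PostTrainingQuantConfig()              # default PTQ settings
q_model = fit(model=fp32_model, conf=conf,
              calib_dataloader=calib_loader)  # tuned model, or None on failure
```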

neural_compressor/training.py

Lines changed: 2 additions & 0 deletions
@@ -232,6 +232,8 @@ def prepare_compression(model: Callable, confs: Union[Callable, List], **kwargs)
     component.model = model
     if isinstance(confs, QuantizationAwareTrainingConfig):
         component.prepare_qat()
+    else:
+        component.prepare()
     compression_manager = CompressionManager(component)

     return compression_manager
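
The new `else` branch matters for every non-QAT config; a hedged sketch of a distillation setup that now reaches `component.prepare()` (the models and config arguments are placeholders):

```python
from neural_compressor.training import prepare_compression
from neural_compressor.config import (DistillationConfig,
                                      KnowledgeDistillationLossConfig)

criterion = KnowledgeDistillationLossConfig()   # default KD loss settings
conf = DistillationConfig(teacher_model=teacher_model, criterion=criterion)
# Not a QuantizationAwareTrainingConfig, so the fixed path calls prepare():
compression_manager = prepare_compression(student_model, conf)
```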

test/distillation/test_distillation.py

Lines changed: 2 additions & 2 deletions
@@ -9,7 +9,7 @@
 from neural_compressor.data import DATASETS
 from neural_compressor.config import DistillationConfig, KnowledgeDistillationLossConfig
 from neural_compressor.experimental.data.dataloaders.pytorch_dataloader import PyTorchDataLoader
-
+from neural_compressor.adaptor.tf_utils.util import version1_lt_version2

 def build_fake_yaml():
     fake_yaml = """
@@ -252,7 +252,7 @@ def test_distillation_external_new_API(self):
         stat = torch.load('./saved/best_model.pt')
         opt_model = self.student_model.load_state_dict(stat)

-    @unittest.skipIf(tf.version.VERSION < '2.3.0', " keras requires higher version than tf-2.3.0")
+    @unittest.skipIf(version1_lt_version2(tf.version.VERSION, '2.3.0'), " keras requires higher version than tf-2.3.0")
     def test_tf_distillation(self):
         from neural_compressor.experimental import Distillation
         from neural_compressor.conf.config import DistillationConf
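
The swap from `<` to `version1_lt_version2` fixes a real bug: Python compares strings lexicographically, which misorders dotted versions. A quick illustration (the tuple helper is only a stand-in for what a numeric version compare does):

```python
print('2.10.0' < '2.3.0')        # True as strings, wrong as a version check

as_tuple = lambda v: tuple(int(p) for p in v.split('.'))
print(as_tuple('2.10.0') < as_tuple('2.3.0'))   # False: numeric order is correct
```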

test/metric/test_register_metric_transform.py

Lines changed: 1 addition & 4 deletions
@@ -42,7 +42,7 @@ def test_register_metric_postprocess(self):
         resize_image = resize_image - mean
         images = np.expand_dims(resize_image, axis=0)
         labels = [768]
-        from neural_compressor import Benchmark, Quantization
+        from neural_compressor import Benchmark
         from neural_compressor.experimental.data.transforms.imagenet_transform import LabelShift
         from neural_compressor.experimental.metric.metric import TensorflowTopK
         os.environ['NC_ENV_CONF'] = 'True'
@@ -53,9 +53,6 @@ def test_register_metric_postprocess(self):
         dataloader = evaluator.dataloader(dataset=list(zip(images, labels)))
         evaluator(self.pb_path, b_dataloader=dataloader)

-        quantizer = Quantization('fake_yaml.yaml')
-        quantizer.postprocess('label_quantize', LabelShift, label_shift=1)
-        quantizer.metric('topk_quantize', TensorflowTopK)
         evaluator = Benchmark('fake_yaml.yaml')
         evaluator.metric('topk_second', TensorflowTopK)
         dataloader = evaluator.dataloader(dataset=list(zip(images, labels)))

test/quantization/test_quantization.py

Lines changed: 4 additions & 4 deletions
@@ -342,13 +342,13 @@ def test_resume(self):

     def test_autodump(self):
         # test auto_dump using old api
-        from neural_compressor.quantization import Quantization
+        from neural_compressor.experimental import Quantization, common
         quantizer = Quantization('fake_yaml3.yaml')
         dataset = quantizer.dataset('dummy', shape=(100, 3, 3, 1), label=True)
-        dataloader = quantizer.dataloader(dataset)
+        quantizer.eval_dataloader = common.DataLoader(dataset)
+        quantizer.calib_dataloader = common.DataLoader(dataset)
         quantizer.model = self.constant_graph
-        output_graph = quantizer(self.constant_graph, \
-            q_dataloader=dataloader, eval_dataloader=dataloader)
+        output_graph = quantizer.fit()
         self.assertNotEqual(output_graph, None)

     def test_performance_only(self):
