Skip to content

Commit 34b8ece

Browse files
committed
Fixed pruning and distillation bug and updated examples
Signed-off-by: Cheng, Penghui <[email protected]>
1 parent 583545b commit 34b8ece

File tree

13 files changed

+21
-314
lines changed

13 files changed

+21
-314
lines changed

docs/source/dataloader.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,7 @@ calib_data = mx.io.ImageRecordIter(path_imgrec=dataset,
100100
ctx=args.ctx,
101101
**combine_mean_std)
102102

103-
from neural_compressor import Quantization, common
103+
from neural_compressor.experimental import Quantization, common
104104
quantizer = Quantization('conf.yaml')
105105
quantizer.model = fp32_model
106106
quantizer.calib_dataloader = calib_data

docs/source/dataset.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -96,7 +96,7 @@ class Dataset(object):
9696
After defining the dataset class, pass it to the quantizer:
9797

9898
```python
99-
from neural_compressor import Quantization, common
99+
from neural_compressor.experimental import Quantization, common
100100
quantizer = Quantization(yaml_file)
101101
quantizer.calib_dataloader = common.DataLoader(dataset) # user can pass more optional args to dataloader such as batch_size and collate_fn
102102
quantizer.model = graph

examples/.config/model_params_pytorch.json

Lines changed: 13 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -220,91 +220,61 @@
220220
"model_src_dir": "nlp/huggingface_models/text-classification/quantization/ptq_static/fx",
221221
"dataset_location": "",
222222
"input_model": "/tf_dataset/pytorch/glue_data/base_weights/bert_MRPC_output",
223-
"yaml": "conf.yaml",
224-
"strategy": "basic",
225-
"batch_size": 64,
226-
"new_benchmark": false
223+
"batch_size": 64
227224
},
228225
"bert_base_CoLA": {
229226
"model_src_dir": "nlp/huggingface_models/text-classification/quantization/ptq_static/fx",
230227
"dataset_location": "",
231228
"input_model": "/tf_dataset/pytorch/glue_data/base_weights/bert_CoLA_output",
232-
"yaml": "conf.yaml",
233-
"strategy": "basic",
234-
"batch_size": 64,
235-
"new_benchmark": false
229+
"batch_size": 64
236230
},
237231
"bert_base_STS-B": {
238232
"model_src_dir": "nlp/huggingface_models/text-classification/quantization/ptq_static/fx",
239233
"dataset_location": "",
240234
"input_model": "/tf_dataset/pytorch/glue_data/base_weights/bert_STS-B_output",
241-
"yaml": "conf.yaml",
242-
"strategy": "basic",
243-
"batch_size": 64,
244-
"new_benchmark": false
235+
"batch_size": 64
245236
},
246237
"bert_base_SST-2": {
247238
"model_src_dir": "nlp/huggingface_models/text-classification/quantization/ptq_static/fx",
248239
"dataset_location": "",
249240
"input_model": "/tf_dataset/pytorch/glue_data/base_weights/bert_SST-2_output",
250-
"yaml": "conf.yaml",
251-
"strategy": "basic",
252-
"batch_size": 64,
253-
"new_benchmark": false
241+
"batch_size": 64
254242
},
255243
"bert_base_RTE": {
256244
"model_src_dir": "nlp/huggingface_models/text-classification/quantization/ptq_static/fx",
257245
"dataset_location": "",
258246
"input_model": "/tf_dataset/pytorch/glue_data/base_weights/bert_RTE_output",
259-
"yaml": "conf.yaml",
260-
"strategy": "basic",
261-
"batch_size": 64,
262-
"new_benchmark": false
247+
"batch_size": 64
263248
},
264249
"bert_large_MRPC": {
265250
"model_src_dir": "nlp/huggingface_models/text-classification/quantization/ptq_static/fx",
266251
"dataset_location": "",
267252
"input_model": "/tf_dataset/pytorch/glue_data/weights/bert_MRPC_output",
268-
"yaml": "conf.yaml",
269-
"strategy": "basic",
270-
"batch_size": 64,
271-
"new_benchmark": false
253+
"batch_size": 64
272254
},
273255
"bert_large_SQuAD": {
274256
"model_src_dir": "nlp/huggingface_models/question-answering/quantization/ptq_static/fx",
275257
"dataset_location": "",
276258
"input_model": "",
277-
"yaml": "conf.yaml",
278-
"strategy": "basic",
279-
"batch_size": 64,
280-
"new_benchmark": false
259+
"batch_size": 64
281260
},
282261
"bert_large_QNLI": {
283262
"model_src_dir": "nlp/huggingface_models/text-classification/quantization/ptq_static/fx",
284263
"dataset_location": "",
285264
"input_model": "/tf_dataset/pytorch/glue_data/weights/bert_QNLI_output",
286-
"yaml": "conf.yaml",
287-
"strategy": "basic",
288-
"batch_size": 64,
289-
"new_benchmark": false
265+
"batch_size": 64
290266
},
291267
"bert_large_RTE": {
292268
"model_src_dir": "nlp/huggingface_models/text-classification/quantization/ptq_dynamic/fx",
293269
"dataset_location": "",
294270
"input_model": "/tf_dataset/pytorch/glue_data/weights/bert_large_rte",
295-
"yaml": "conf.yaml",
296-
"strategy": "basic",
297-
"batch_size": 64,
298-
"new_benchmark": false
271+
"batch_size": 64
299272
},
300273
"bert_large_CoLA": {
301274
"model_src_dir": "nlp/huggingface_models/text-classification/quantization/ptq_static/fx",
302275
"dataset_location": "",
303276
"input_model": "/tf_dataset/pytorch/glue_data/weights/bert_CoLA_output",
304-
"yaml": "conf.yaml",
305-
"strategy": "basic",
306-
"batch_size": 64,
307-
"new_benchmark": false
277+
"batch_size": 64
308278
},
309279
"dlrm": {
310280
"model_src_dir": "recommendation/dlrm/quantization/ptq/eager",
@@ -436,10 +406,7 @@
436406
"model_src_dir": "nlp/huggingface_models/text-classification/quantization/ptq_static/fx",
437407
"dataset_location": "/tf_dataset/pytorch/glue_data_new",
438408
"input_model": "/tf_dataset/pytorch/huggingface/language_translation_pt/distilbert_mrpc",
439-
"yaml": "conf.yaml",
440-
"strategy": "basic",
441-
"batch_size": 64,
442-
"new_benchmark": false
409+
"batch_size": 64
443410
},
444411
"albert_base_MRPC": {
445412
"model_src_dir": "nlp/huggingface_models/text-classification/quantization/ptq_dynamic/eager",
@@ -463,10 +430,7 @@
463430
"model_src_dir": "nlp/huggingface_models/text-classification/quantization/ptq_static/fx",
464431
"dataset_location": "/tf_dataset/pytorch/glue_data_new",
465432
"input_model": "/tf_dataset/pytorch/huggingface/language_translation_pt/funnel_mrpc",
466-
"yaml": "conf.yaml",
467-
"strategy": "basic",
468-
"batch_size": 64,
469-
"new_benchmark": false
433+
"batch_size": 64
470434
},
471435
"bart_WNLI(rm)": {
472436
"model_src_dir": "eager/huggingface_models",
@@ -697,10 +661,7 @@
697661
"model_src_dir": "nlp/huggingface_models/text-classification/quantization/qat/fx",
698662
"dataset_location": "/tf_dataset/pytorch/glue_data/MRPC/",
699663
"input_model": "/tf_dataset2/models/pytorch/bert_model",
700-
"yaml": "conf_qat.yaml",
701-
"strategy": "basic",
702-
"batch_size": 8,
703-
"new_benchmark": false
664+
"batch_size": 8
704665
},
705666
"wide_resnet101_2_fx": {
706667
"model_src_dir": "oob_models/gen-efficientnet-pytorch",

examples/helloworld/tf_example4/test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
import time
33
import numpy as np
44
from tensorflow import keras
5-
from neural_compressor.experimental import Quantization, common
5+
from neural_compressor.experimental import Quantization, common
66

77
def main():
88

examples/helloworld/tf_example6/README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,7 @@ python test.py --benchmark
6464
### 6. Introduction
6565
* Only the following lines need to be added to quantize the model and create an int8 model.
6666
```python
67-
from neural_compressor import Quantization
67+
from neural_compressor.experimental import Quantization
6868
quantizer = Quantization('./conf.yaml')
6969
quantized_model = quantizer('./mobilenet_v1_1.0_224_frozen.pb')
7070
tf.io.write_graph(graph_or_graph_def=quantized_model,

examples/helloworld/tf_example6/test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ def main():
99
args = arg_parser.parse_args()
1010

1111
if args.tune:
12-
from neural_compressor import Quantization
12+
from neural_compressor.experimental import Quantization
1313
quantizer = Quantization('./conf.yaml')
1414
quantized_model = quantizer("./mobilenet_v1_1.0_224_frozen.pb")
1515
tf.io.write_graph(graph_or_graph_def=quantized_model,

examples/pytorch/nlp/huggingface_models/text-classification/quantization/ptq_dynamic/fx/conf.yaml

Lines changed: 0 additions & 31 deletions
This file was deleted.

examples/pytorch/nlp/huggingface_models/text-classification/quantization/ptq_static/fx/conf.yaml

Lines changed: 0 additions & 31 deletions
This file was deleted.

examples/pytorch/nlp/huggingface_models/text-classification/quantization/qat/fx/conf_qat.yaml

Lines changed: 0 additions & 16 deletions
This file was deleted.

examples/pytorch/nlp/huggingface_models/text-classification/quantization/qat/fx/run_glue.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -528,6 +528,7 @@ def benchmark(model):
528528
conf = QuantizationAwareTrainingConfig()
529529
compression_manager = prepare_compression(model, conf)
530530
compression_manager.callbacks.on_train_begin()
531+
trainer.model = compression_manager.model
531532
trainer.train()
532533
compression_manager.callbacks.on_train_end()
533534

0 commit comments

Comments
 (0)