diff --git a/docs/source/dataloader.md b/docs/source/dataloader.md
index e89a79f4a5f..fab91d72368 100644
--- a/docs/source/dataloader.md
+++ b/docs/source/dataloader.md
@@ -100,7 +100,7 @@ calib_data = mx.io.ImageRecordIter(path_imgrec=dataset,
                                     ctx=args.ctx,
                                     **combine_mean_std)
 
-from neural_compressor import Quantization, common
+from neural_compressor.experimental import Quantization, common
 quantizer = Quantization('conf.yaml')
 quantizer.model = fp32_model
 quantizer.calib_dataloader = calib_data
diff --git a/docs/source/dataset.md b/docs/source/dataset.md
index b92bb828b9f..8d51cbdf723 100644
--- a/docs/source/dataset.md
+++ b/docs/source/dataset.md
@@ -96,7 +96,7 @@ class Dataset(object):
 
 After defining the dataset class, pass it to the quantizer:
 ```python
-from neural_compressor import Quantization, common
+from neural_compressor.experimental import Quantization, common
 quantizer = Quantization(yaml_file)
 quantizer.calib_dataloader = common.DataLoader(dataset)  # user can pass more optional args to dataloader such as batch_size and collate_fn
 quantizer.model = graph
diff --git a/neural_compressor/__init__.py b/neural_compressor/__init__.py
index 5a2236347af..bcd0491a646 100644
--- a/neural_compressor/__init__.py
+++ b/neural_compressor/__init__.py
@@ -15,7 +15,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from .quantization import Quantization
 from .pruning import Pruning
 from .benchmark import Benchmark
 from .version import __version__
diff --git a/neural_compressor/experimental/common/criterion.py b/neural_compressor/experimental/common/criterion.py
index 11308854d10..abcba32299c 100644
--- a/neural_compressor/experimental/common/criterion.py
+++ b/neural_compressor/experimental/common/criterion.py
@@ -1454,6 +1454,36 @@ def loss_cal(self, student_outputs):
             self.loss += tmp_loss
         return self.loss
 
+    def teacher_model_forward(self, input, teacher_model=None, device=None):
+        """Teacher model forward.
+
+        Args:
+            input (tensor): input data
+            teacher_model (torch.nn.Module, optional): teacher model. Defaults to None.
+            device (torch.device, optional): device. Defaults to None.
+
+        Returns:
+            tensor: output
+        """
+        outputs = None
+        if self.loss_weights[1] > 0:
+            model = self.teacher_model if teacher_model is None else teacher_model
+            assert isinstance(model, torch.nn.Module), \
+                'Teacher model should be a torch Module instead of {}'.format(type(model))
+            model.eval()
+            try:
+                model_device = next(model.parameters()).device
+            except StopIteration:
+                logger.warning("Cannot get model device, assuming it's on CPU.")
+                model_device = "cpu"
+            device = model_device if device is None else device
+            if device != model_device:
+                model.to(device)
+            with torch.no_grad():
+                outputs = pytorch_forward_wrapper(model, input, device=device)
+            self.teacher_outputs = outputs
+        return outputs
+
 
 @criterion_registry('SelfKnowledgeDistillationLoss', 'pytorch')
 class PyTorchSelfKnowledgeDistillationLossWrapper(object):
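For orientation: the `teacher_model_forward` added above runs the teacher under `torch.no_grad()` and caches its logits on `self.teacher_outputs`, so `loss_cal` (visible in the hunk context) can fold them into the distillation loss without backpropagating through the teacher. A minimal sketch of how a training step might drive the two calls; `criterion`, `student`, and `batch` are illustrative placeholders, not names taken from this patch:

```python
import torch

def distillation_step(criterion, student, batch, device="cpu"):
    # Hypothetical wiring: only teacher_model_forward and loss_cal come
    # from this patch; the surrounding step is assumed.
    inputs, _ = batch
    # Teacher inference happens under torch.no_grad() inside
    # teacher_model_forward and is cached on criterion.teacher_outputs.
    criterion.teacher_model_forward(inputs, device=device)
    student_outputs = student(inputs.to(device))
    # loss_cal accumulates the KD loss against the cached teacher outputs.
    return criterion.loss_cal(student_outputs)
```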
diff --git a/neural_compressor/experimental/distillation.py b/neural_compressor/experimental/distillation.py
index 6938c191186..85ae6e52c85 100644
--- a/neural_compressor/experimental/distillation.py
+++ b/neural_compressor/experimental/distillation.py
@@ -226,7 +226,9 @@ def pre_process(self):
         framework_specific_info = {'device': self.cfg.device,
                                    'random_seed': self.cfg.tuning.random_seed,
                                    'workspace_path': self.cfg.tuning.workspace.path,
-                                   'q_dataloader': None}
+                                   'q_dataloader': None,
+                                   'format': 'default',
+                                   'backend': 'default'}
 
         if self.framework == 'tensorflow':
             framework_specific_info.update(
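The net effect of this hunk is that the adaptor built for distillation always receives explicit `format` and `backend` keys instead of relying on adaptor-side defaults. Roughly, the dict before any framework-specific updates looks as follows; the concrete values are illustrative, since they come from the user's config at runtime:

```python
framework_specific_info = {
    'device': 'cpu',                     # self.cfg.device (illustrative)
    'random_seed': 1978,                 # self.cfg.tuning.random_seed (illustrative)
    'workspace_path': './nc_workspace',  # self.cfg.tuning.workspace.path (illustrative)
    'q_dataloader': None,
    'format': 'default',                 # newly added explicit default
    'backend': 'default',                # newly added explicit default
}
```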
diff --git a/neural_compressor/quantization.py b/neural_compressor/quantization.py
index f90dfb74ffa..36504b5c12e 100644
--- a/neural_compressor/quantization.py
+++ b/neural_compressor/quantization.py
@@ -15,188 +15,10 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import os
-from .utils import logger
-from .data import DATALOADERS, DATASETS
 from .experimental import Quantization as ExpQuantization
-from deprecated import deprecated
 from neural_compressor.conf.pythonic_config import Config
 from neural_compressor.config import PostTrainingQuantConfig
 
-
-class Quantization(object):
-    """Quantization class automatically searches for optimal quantization recipes for low
-       precision model inference, achieving best tuning objectives like inference performance
-       within accuracy loss constraints.
-
-       Tuner abstracts out the differences of quantization APIs across various DL frameworks
-       and brings a unified API for automatic quantization that works on frameworks including
-       tensorflow, pytorch and mxnet.
-
-       Since DL use cases vary in the accuracy metrics (Top-1, MAP, ROC etc.), loss criteria
-       (<1% or <0.1% etc.) and tuning objectives (performance, memory footprint etc.),
-       Tuner class provides a flexible configuration interface via YAML for users to specify
-       these parameters.
-
-    Args:
-        conf_fname_or_obj (string or obj): The path to the YAML configuration file or
-            Quantization_Conf class containing accuracy goal, tuning objective and preferred
-            calibration & quantization tuning space etc.
-
-    """
-
-    def __init__(self, conf_fname_or_obj):
-        self.exp_quantizer = ExpQuantization(conf_fname_or_obj)
-
-    @deprecated(version='2.0', reason="please use neural_compressor.quantization.fit instead")
-    def __call__(self, model, q_dataloader=None, q_func=None, eval_dataloader=None,
-                 eval_func=None):
-        """The main entry point of automatic quantization tuning.
-
-        This interface works on all the DL frameworks that neural_compressor supports
-        and provides three usages:
-        a) Fully yaml configuration: User specifies all the info through yaml,
-           including dataloaders used in calibration and evaluation phases
-           and quantization tuning settings.
-
-           For this usage, only the model parameter is mandatory.
-
-        b) Partial yaml configuration: User specifies dataloaders used in the calibration
-           and evaluation phase by code.
-           The tool provides built-in dataloaders and evaluators; the user just needs to provide
-           a dataset implementing __iter__ or __getitem__ methods and invoke dataloader()
-           with the dataset as input parameter to create a neural_compressor dataloader
-           before calling this function.
-
-           After that, the user specifies the fp32 "model", calibration dataset "q_dataloader"
-           and evaluation dataset "eval_dataloader".
-           The calibrated and quantized model is evaluated with "eval_dataloader"
-           with evaluation metrics specified in the configuration file. The evaluation tells
-           the tuner whether the quantized model meets the accuracy criteria. If not,
-           the tuner starts a new calibration and tuning flow.
-
-           For this usage, the model, q_dataloader and eval_dataloader parameters are mandatory.
-
-        c) Partial yaml configuration: User specifies dataloaders used in the calibration phase
-           by code.
-           This usage is quite similar to b); the user just specifies a custom "eval_func"
-           which encapsulates the evaluation dataset by itself.
-           The calibrated and quantized model is evaluated with "eval_func".
-           The "eval_func" tells the tuner whether the quantized model meets
-           the accuracy criteria. If not, the tuner starts a new calibration and tuning flow.
-
-           For this usage, the model, q_dataloader and eval_func parameters are mandatory.
-
-        Args:
-            model (object):                        For Tensorflow model, it could be a path
-                                                   to frozen pb, loaded graph_def object or
-                                                   a path to ckpt/savedmodel folder.
-                                                   For PyTorch model, it's a torch.nn.Module
-                                                   instance.
-                                                   For MXNet model, it's mxnet.symbol.Symbol
-                                                   or gluon.HybridBlock instance.
-            q_dataloader (generator):              Data loader for calibration, mandatory for
-                                                   post-training quantization. It is iterable
-                                                   and should yield a tuple (input, label) for
-                                                   a calibration dataset containing labels,
-                                                   or yield (input, _) for a label-free calibration
-                                                   dataset. The input could be an object, list,
-                                                   tuple or dict, depending on user implementation,
-                                                   as well as it can be taken as model input.
-            q_func (function, optional):           Training function for Quantization-Aware
-                                                   Training. It is optional and only takes effect
-                                                   when the user chooses the "quant_aware_training"
-                                                   approach in yaml.
-                                                   This function takes "model" as input parameter
-                                                   and executes the entire training process with
-                                                   self-contained training hyper-parameters. If this
-                                                   parameter is specified, the eval_dataloader
-                                                   parameter plus metric defined in yaml, or the
-                                                   eval_func parameter should also be specified
-                                                   at the same time.
-            eval_dataloader (generator, optional): Data loader for evaluation. It is iterable
-                                                   and should yield a tuple of (input, label).
-                                                   The input could be an object, list, tuple or
-                                                   dict, depending on user implementation,
-                                                   as well as it can be taken as model input.
-                                                   The label should be able to be taken as input
-                                                   of supported metrics. If this parameter is
-                                                   not None, the user needs to specify pre-defined
-                                                   evaluation metrics through the configuration file
-                                                   and should set the "eval_func" parameter as None.
-                                                   Tuner will combine model, eval_dataloader
-                                                   and pre-defined metrics to run the evaluation
-                                                   process.
-            eval_func (function, optional):        The evaluation function provided by the user.
-                                                   This function takes model as parameter,
-                                                   and evaluation dataset and metrics should be
-                                                   encapsulated in this function implementation
-                                                   and outputs a higher-is-better accuracy scalar
-                                                   value.
-
-                                                   The pseudo code should be something like:
-
-                                                   def eval_func(model):
-                                                        input, label = dataloader()
-                                                        output = model(input)
-                                                        accuracy = metric(output, label)
-                                                        return accuracy
-
-        Returns:
-            quantized model: best quantized model found, otherwise return None
-
-        """
-
-        logger.warning("This API is going to be deprecated. Please import "
-                       "neural_compressor.experimental.Quantization, initialize an instance of `Quantization`, "
-                       "set its dataloader and metric attributes, then invoke its __call__ method.")
-
-        self.exp_quantizer.model = model
-        if q_dataloader is not None:
-            self.exp_quantizer.calib_dataloader = q_dataloader
-        elif q_func is not None:
-            self.exp_quantizer.q_func = q_func
-
-        if eval_func is not None:
-            self.exp_quantizer.eval_func = eval_func
-        elif eval_dataloader is not None:
-            self.exp_quantizer.eval_dataloader = eval_dataloader
-
-        nc_model = self.exp_quantizer.fit()
-        if self.exp_quantizer.framework == 'tensorflow':
-            return nc_model.graph if nc_model else None
-        if self.exp_quantizer.framework == 'pytorch':
-            saved_path = os.path.abspath(os.path.join(os.path.expanduser(
-                self.exp_quantizer.conf.usr_cfg.tuning.workspace.path), 'checkpoint'))
-            nc_model.save(saved_path)
-        return nc_model.model
-
-    fit = __call__
-
-    @deprecated(version='2.0', reason="this function has been deprecated")
-    def dataset(self, dataset_type, *args, **kwargs):
-        return DATASETS(self.exp_quantizer.framework)[dataset_type](*args, **kwargs)
-
-    @deprecated(version='2.0', reason="this function has been deprecated")
-    def dataloader(self, dataset, batch_size=1, collate_fn=None, last_batch='rollover',
-                   sampler=None, batch_sampler=None, num_workers=0, pin_memory=False):
-        return DATALOADERS[self.exp_quantizer.framework](
-            dataset=dataset,
-            batch_size=batch_size, collate_fn=collate_fn, last_batch=last_batch,
-            sampler=sampler, batch_sampler=batch_sampler, num_workers=num_workers,
-            pin_memory=pin_memory
-        )
-
-    @deprecated(version='2.0', reason="this function has been deprecated")
-    def metric(self, name, metric_cls, **kwargs):
-        from .experimental.common import Metric as NCMetric
-        nc_metric = NCMetric(metric_cls, name, **kwargs)
-        self.exp_quantizer.metric = nc_metric
-
-    @deprecated(version='2.0', reason="this function has been deprecated")
-    def postprocess(self, name, postprocess_cls, **kwargs):
-        from .experimental.common import Postprocess as NCPostprocess
-        nc_postprocess = NCPostprocess(postprocess_cls, name, **kwargs)
-        self.exp_quantizer.postprocess = nc_postprocess
-
 
 def fit(model,
         conf,
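With the deprecated class removed, the module-level `fit` kept at the end of this hunk is the remaining entry point. A hedged migration sketch: the keyword names beyond `model` and `conf` follow the 2.x `fit` signature and should be verified against the installed release, while `fp32_model`, `calib_loader`, and `eval_func` are user-supplied placeholders:

```python
from neural_compressor.config import PostTrainingQuantConfig
from neural_compressor.quantization import fit

# Replaces the old pattern:
#   Quantization('conf.yaml')(model, q_dataloader=..., eval_dataloader=...)
q_model = fit(
    model=fp32_model,                # placeholder FP32 model object
    conf=PostTrainingQuantConfig(),  # pythonic config replaces the YAML file
    calib_dataloader=calib_loader,   # placeholder iterable of (input, label)
    eval_func=eval_func,             # placeholder higher-is-better accuracy fn
)  # returns the best quantized model found, otherwise None
```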
diff --git a/neural_compressor/training.py b/neural_compressor/training.py
index eee317ef57f..e0b60eb190e 100644
--- a/neural_compressor/training.py
+++ b/neural_compressor/training.py
@@ -232,6 +232,8 @@ def prepare_compression(model: Callable, confs: Union[Callable, List], **kwargs)
     component.model = model
     if isinstance(confs, QuantizationAwareTrainingConfig):
         component.prepare_qat()
+    else:
+        component.prepare()
 
     compression_manager = CompressionManager(component)
     return compression_manager
diff --git a/test/distillation/test_distillation.py b/test/distillation/test_distillation.py
index a5a993f2fdf..a43259ed7bb 100644
--- a/test/distillation/test_distillation.py
+++ b/test/distillation/test_distillation.py
@@ -9,7 +9,7 @@ from neural_compressor.data import DATASETS
 from neural_compressor.config import DistillationConfig, KnowledgeDistillationLossConfig
 from neural_compressor.experimental.data.dataloaders.pytorch_dataloader import PyTorchDataLoader
-
+from neural_compressor.adaptor.tf_utils.util import version1_lt_version2
 
 def build_fake_yaml():
     fake_yaml = """
@@ -252,7 +252,7 @@ def test_distillation_external_new_API(self):
         stat = torch.load('./saved/best_model.pt')
         opt_model = self.student_model.load_state_dict(stat)
 
-    @unittest.skipIf(tf.version.VERSION < '2.3.0', " keras requires higher version than tf-2.3.0")
+    @unittest.skipIf(version1_lt_version2(tf.version.VERSION, '2.3.0'), " keras requires higher version than tf-2.3.0")
     def test_tf_distillation(self):
         from neural_compressor.experimental import Distillation
         from neural_compressor.conf.config import DistillationConf
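The `training.py` change is what makes non-QAT configs usable through `prepare_compression`: anything other than a `QuantizationAwareTrainingConfig` now reaches `component.prepare()` instead of being returned unprepared. A sketch of the distillation path, assuming `DistillationConfig` accepts a teacher model and a criterion config as the updated tests suggest; `teacher_model` and `student_model` are user-supplied modules:

```python
from neural_compressor.training import prepare_compression
from neural_compressor.config import DistillationConfig, KnowledgeDistillationLossConfig

distil_conf = DistillationConfig(
    teacher_model=teacher_model,                  # placeholder torch module
    criterion=KnowledgeDistillationLossConfig(),  # KD loss settings
)
# With this patch, the else-branch calls component.prepare() for
# distillation/pruning configs; previously they were silently skipped.
compression_manager = prepare_compression(student_model, distil_conf)
```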
diff --git a/test/metric/test_register_metric_transform.py b/test/metric/test_register_metric_transform.py
index e8695006688..ff42df870a5 100644
--- a/test/metric/test_register_metric_transform.py
+++ b/test/metric/test_register_metric_transform.py
@@ -42,7 +42,7 @@ def test_register_metric_postprocess(self):
         resize_image = resize_image - mean
         images = np.expand_dims(resize_image, axis=0)
         labels = [768]
-        from neural_compressor import Benchmark, Quantization
+        from neural_compressor import Benchmark
         from neural_compressor.experimental.data.transforms.imagenet_transform import LabelShift
         from neural_compressor.experimental.metric.metric import TensorflowTopK
         os.environ['NC_ENV_CONF'] = 'True'
@@ -53,9 +53,6 @@ def test_register_metric_postprocess(self):
         dataloader = evaluator.dataloader(dataset=list(zip(images, labels)))
         evaluator(self.pb_path, b_dataloader=dataloader)
 
-        quantizer = Quantization('fake_yaml.yaml')
-        quantizer.postprocess('label_quantize', LabelShift, label_shift=1)
-        quantizer.metric('topk_quantize', TensorflowTopK)
         evaluator = Benchmark('fake_yaml.yaml')
         evaluator.metric('topk_second', TensorflowTopK)
         dataloader = evaluator.dataloader(dataset=list(zip(images, labels)))
diff --git a/test/quantization/test_quantization.py b/test/quantization/test_quantization.py
index 61698fc21d0..6ab9351b042 100644
--- a/test/quantization/test_quantization.py
+++ b/test/quantization/test_quantization.py
@@ -342,13 +342,13 @@ def test_resume(self):
 
     def test_autodump(self):
         # test auto_dump using old api
-        from neural_compressor.quantization import Quantization
+        from neural_compressor.experimental import Quantization, common
         quantizer = Quantization('fake_yaml3.yaml')
         dataset = quantizer.dataset('dummy', shape=(100, 3, 3, 1), label=True)
-        dataloader = quantizer.dataloader(dataset)
+        quantizer.eval_dataloader = common.DataLoader(dataset)
+        quantizer.calib_dataloader = common.DataLoader(dataset)
         quantizer.model = self.constant_graph
-        output_graph = quantizer(self.constant_graph, \
-            q_dataloader=dataloader, eval_dataloader=dataloader)
+        output_graph = quantizer.fit()
         self.assertNotEqual(output_graph, None)
 
     def test_performance_only(self):
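The `test_autodump` rewrite doubles as a migration recipe for user code still on the removed old API; in general form (the YAML path, the dummy-dataset shape, and `fp32_model` are placeholders):

```python
from neural_compressor.experimental import Quantization, common

quantizer = Quantization('conf.yaml')  # placeholder YAML path
dataset = quantizer.dataset('dummy', shape=(100, 3, 3, 1), label=True)
quantizer.calib_dataloader = common.DataLoader(dataset)
quantizer.eval_dataloader = common.DataLoader(dataset)
quantizer.model = fp32_model           # placeholder model object
q_model = quantizer.fit()              # replaces quantizer(model, q_dataloader=..., eval_dataloader=...)
```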