1 change: 1 addition & 0 deletions .azure-pipelines/scripts/codeScan/pyspelling/inc_dict.txt
@@ -2385,3 +2385,4 @@ Nsh
UmK
fe
vmware
PythonLauncher
5 changes: 4 additions & 1 deletion neural_coder/README.md
@@ -35,11 +35,14 @@ simultaneously on below PyTorch evaluation code, we generate the optimized code

## Getting Started!

-There are currently 2 ways to use Neural Coder for automatic quantization enabling and benchmark.
+There are currently 3 ways to use Neural Coder for automatic quantization enabling and benchmark.

### Jupyter Lab Extension
We offer Neural Coder as an extension plugin for JupyterLab. This enables users to utilize Neural Coder while writing their Deep Learning models in the JupyterLab coding platform. Users can simply search for ```jupyter-lab-neural-compressor``` in the JupyterLab Extension Manager and install Neural Coder with one click. For more details, please refer to this [guide](extensions/neural_compressor_ext_lab/README.md)

### Python Launcher
Neural Coder can be used as a Python launcher: users can run their Python Deep Learning model code as-is, with optimizations enabled automatically, by simply adding the inline prefix ```-m neural_coder``` to the Python command line. For more details, please refer to this [guide](docs/PythonLauncher.md)

### Python API
There are 3 user-facing APIs for Neural Coder: enable, bench, and superbench. For more details, please refer to this [guide](docs/PythonAPI.md). We have provided a [list](docs/SupportMatrix.md) of supported Deep Learning optimization features. Specifically for quantization, we provide an auto-quantization API that automatically enables quantization on Deep Learning models and evaluates the model for the best performance, with no manual coding needed. Supported features include Post-Training Static Quantization, Post-Training Dynamic Quantization, and Mixed Precision. For more details, please refer to this [guide](docs/Quantization.md).
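For instance, a minimal sketch of driving the ```enable``` API from Python, assuming it is exported at the package root; ```run_glue.py``` is a placeholder script path, and the feature string should be checked against the [support matrix](docs/SupportMatrix.md):

```python
# A sketch, not the definitive API surface: the script path and feature
# name below are placeholders chosen for illustration.
from neural_coder import enable

enable(
    code="run_glue.py",                      # script to transform
    features=["pytorch_inc_dynamic_quant"],  # INT8 dynamic quantization via INC
    save_patch_path="intel_optimization",    # also save the change as a patch file
)
```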

2 changes: 1 addition & 1 deletion neural_coder/__main__.py
@@ -28,7 +28,7 @@ def parse_args():
parser.add_argument("--opt", type=str, default="",
help="optimization feature to enable")

parser.add_argument("--approach", type=str, default="static",
parser.add_argument("--approach", type=str, default="dynamic",
help="quantization approach (strategy)")

parser.add_argument('--config', type=str, default="",
35 changes: 35 additions & 0 deletions neural_coder/backends/intel_extension_for_transformers.yaml
@@ -0,0 +1,35 @@
# Copyright (c) 2022 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Note: for intel_extension_for_transformers support,
# we apply "PostTrainingDynamic" and "eval_f1" by default;
# support for customization is pending further evaluation

transformation:
  location:
    - ["insert_below_dataloader_definition_line", "insert_below_model_definition_line"]
  content:
    - |-
      [+] metric = metrics.Metric(name="eval_f1", is_relative=True, criterion=0.01)
      [+] objective = objectives.performance
      [+] q_config = QuantizationConfig(approach="PostTrainingDynamic", metrics=[metric], objectives=[objective])
      [+] MODEL_NAME = trainer.quantize(quant_config=q_config)
  order:
    - below:
      above:
        - pytorch_jit_script
        - pytorch_jit_script_ofi
        - pytorch_jit_trace
        - pytorch_jit_trace_ofi
        - pytorch_channels_last
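In effect, once this backend is selected, Neural Coder inserts the ```content``` lines above (the ```[+]``` prefix marks lines to be inserted) below the user's model definition. A minimal sketch of the resulting user code, assuming the three symbols are importable from intel_extension_for_transformers (the exact import path is an assumption) and that ```trainer``` is the user's ```NLPTrainer```:

```python
# Sketch of the transformed user code; the import path below is an
# assumption, and `trainer` is the user's NLPTrainer instance.
from intel_extension_for_transformers.optimization import (
    QuantizationConfig, metrics, objectives,
)

metric = metrics.Metric(name="eval_f1", is_relative=True, criterion=0.01)
objective = objectives.performance
q_config = QuantizationConfig(approach="PostTrainingDynamic",
                              metrics=[metric], objectives=[objective])
model = trainer.quantize(quant_config=q_config)  # MODEL_NAME in the template
```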
2 changes: 0 additions & 2 deletions neural_coder/coders/autoinc/autoinc_harness.py
@@ -271,8 +271,6 @@ def register_transformation(self):
        lines_to_insert = lines_to_insert \
            .replace("DATALOADER_NAME", dataloader_name)

-        if globals.optimum_quant_config == "":
-            globals.optimum_quant_config = "quantization/quant_config"
        optimum_quant_config_line = \
            'IncQuantizationConfig.from_pretrained("' + globals.optimum_quant_config + '")'

46 changes: 46 additions & 0 deletions neural_coder/coders/pytorch/change_trainer_to_nlptrainer.py
@@ -0,0 +1,46 @@
# Copyright (c) 2022 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


from ...utils.line_operation import get_line_indent_level


class TrainerToNLPTrainer(object):
    def __init__(self, file) -> None:
        self.file = file
        self.result = []

    def transform(self):
        lines = self.file.split('\n')

        for line in lines:
            if self.is_modify(line):
                new_line = self.modify(line)
                self.result.append(new_line)
            else:
                self.result.append(line)
        # re-join with a newline on every line except the last
        for index, line in enumerate(self.result):
            if index != len(self.result) - 1:
                self.result[index] += '\n'
        return ''.join(self.result)

    def is_modify(self, s):
        if 'trainer = Trainer(' in s:
            return True
        else:
            return False

    def modify(self, s):
        old = 'Trainer'
        s = s.replace(old, 'NLPTrainer')
        return s
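Usage is a single pass over a source file's text; a minimal sketch (```train.py``` is a placeholder path):

```python
# Hypothetical usage; "train.py" is a placeholder path.
source = open("train.py").read()
rewritten = TrainerToNLPTrainer(source).transform()
# Any line containing "trainer = Trainer(" now instantiates NLPTrainer instead.
```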
6 changes: 3 additions & 3 deletions neural_coder/docs/PythonLauncher.md
@@ -10,7 +10,7 @@ Example: Let's say you are running an NLP model using ```run_glue.py``` from Hug
python run_glue.py --model_name_or_path bert-base-cased --task_name mrpc --do_eval --output_dir result
```

-With Neural Coder's **Launcher**, users can easily enjoy Deep Learning optimizations (default: INT8 static quantization by Intel® Neural Compressor) by simply adding an inline prefix
+With Neural Coder's **Launcher**, users can easily enjoy Deep Learning optimizations (default: INT8 dynamic quantization by Intel® Neural Compressor) by simply adding an inline prefix
```bash
-m neural_coder
```
@@ -27,7 +27,7 @@ Note: Any modification on the optimized code ```run_glue_optimized.py``` will be

Users can specify which Deep Learning optimization they want to conduct using the ```--opt``` argument. The list of supported Deep Learning optimization features can be found [here](SupportMatrix.md).

-Note that if specifically optimizing with INT8 quantization by Intel® Neural Compressor, to choose a quantization approach (strategy), the ```--approach``` argument can be specified with either ```static```, ```static_ipex``` or ```dynamic```. For example, to run INT8 dynamic quantization by Intel® Neural Compressor instead of the default static quantization:
+Note that if specifically optimizing with INT8 quantization by Intel® Neural Compressor, to choose a quantization approach (strategy), the ```--approach``` argument can be specified with either ```static```, ```static_ipex``` or ```dynamic```. For example, to run INT8 static quantization by Intel® Neural Compressor instead of the default dynamic quantization:
```bash
-python -m neural_coder --approach dynamic run_glue.py --model_name_or_path bert-base-cased --task_name mrpc --do_eval --output_dir result
+python -m neural_coder --approach static run_glue.py --model_name_or_path bert-base-cased --task_name mrpc --do_eval --output_dir result
```
32 changes: 30 additions & 2 deletions neural_coder/interface.py
@@ -65,6 +65,7 @@ def enable(
    test_code_line=False,  # print code line info for debug use
    cache_load_transformers=True,
    optimum_quant_config="",  # only for HF optimum optimizations, yaml or hub path
    use_inc=False,
):
    """enable a feature or a couple of features for the code

@@ -184,6 +185,7 @@
"pytorch_cuda_to_cpu",
"pytorch_lightning_bf16_cpu",
"tensorflow_mixed_precision",
"change_trainer_to_nlptrainer",
]

# # features that need creating dummy dataloader (when needed) first
@@ -198,6 +200,10 @@
"pytorch_inc_static_quant_ipex" in features:
features = ["pytorch_reclaim_inputs"] + features

    # intel_extension_for_transformers
    if "intel_extension_for_transformers" in features:
        features = ["change_trainer_to_nlptrainer"] + features

transformed_list_code_path = []

## Determine Code Domain
@@ -276,20 +282,32 @@
"pytorch_inc_static_quant_ipex",
"pytorch_inc_huggingface_optimum_static",
"pytorch_inc_huggingface_optimum_dynamic",
"onnx_inc_static_quant_qlinear"
"onnx_inc_static_quant_qlinear",
"onnx_inc_static_quant_qdq",
"onnx_inc_dynamic_quant",
"intel_extension_for_transformers",
]:

        # determine domain
        from .coders.autoinc.domain import determine_domain
        globals.code_domain = determine_domain(globals.list_code_path[0])

        # for transformers code, enable optimum-intel api by default
-        if "transformers" in globals.code_domain:
+        # if use_inc is specified, still use the INC API
+        if "transformers" in globals.code_domain and not use_inc:
            if "static_quant" in feature:
                feature = "pytorch_inc_huggingface_optimum_static"
            elif "dynamic_quant" in feature:
                feature = "pytorch_inc_huggingface_optimum_dynamic"

        # optimum-intel quantization config for static and dynamic
        if feature == "pytorch_inc_huggingface_optimum_static":
            globals.optimum_quant_config = "quantization/quant_config_static"
        elif feature == "pytorch_inc_huggingface_optimum_dynamic":
            globals.optimum_quant_config = "quantization/quant_config_dynamic"
        else:
            pass

        from .coders.autoinc.autoinc_harness import AutoInc_Harness
        from .coders.autoinc.calib_dataloader import Calib_Dataloader
        from .coders.autoinc.eval_func import Eval_Func
@@ -332,6 +350,10 @@ def enable(
if "tensorflow_mixed_precision" in features:
from .coders.tensorflow.amp import TensorFlowKerasAMP
list_transformed_code[i] = TensorFlowKerasAMP(list_transformed_code[i]).transform()
# Change Trainer to NLPTrainer (only for intel_extension_for_pytorch)
if "change_trainer_to_nlptrainer" in features:
from .coders.pytorch.change_trainer_to_nlptrainer import TrainerToNLPTrainer
list_transformed_code[i] = TrainerToNLPTrainer(list_transformed_code[i]).transform()

logger.info(f"Code transformation for feature: [{feature}] finished.")

@@ -700,6 +722,7 @@ def superbench(
    ncore_per_instance=-1,  # only for "self_defined" mode
    ninstances=-1,  # only for "self_defined" mode
    bench_batch_size=-1,  # only for "self_defined" mode
    use_inc=False,
    auto_quant=False,
):

@@ -866,6 +889,7 @@ def superbench(
        ncore_per_instance=ncore_per_instance,
        ninstances=ninstances,
        bench_batch_size=bench_batch_size,
        use_inc=use_inc,
    )

    if dry_run:
@@ -1002,6 +1026,7 @@ def remove_if_have(list, element):
            code=code,
            features=features_to_generate,
            save_patch_path="intel_optimization",
            use_inc=use_inc,
        )
        logger.info('The optimization patch was saved to "intel_optimization.diff"')

@@ -1061,6 +1086,7 @@ def remove_if_have(list, element):
            ncore_per_instance=ncore_per_instance,
            ninstances=ninstances,
            bench_batch_size=bench_batch_size,
            use_inc=use_inc,
        )

    if dry_run:
@@ -1225,6 +1251,7 @@ def auto_quant(
    ncore_per_instance=-1,  # only for "self_defined" mode
    ninstances=-1,  # only for "self_defined" mode
    bench_batch_size=-1,  # only for "self_defined" mode
    use_inc=False,
):
    return superbench(
        code,
@@ -1240,5 +1267,6 @@
        ncore_per_instance=ncore_per_instance,  # only for "self_defined" mode
        ninstances=ninstances,  # only for "self_defined" mode
        bench_batch_size=bench_batch_size,  # only for "self_defined" mode
        use_inc=use_inc,
        auto_quant=True,
    )
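Taken together, a minimal sketch of exercising the new flag end to end, assuming ```auto_quant``` is exported at the package root like the other user-facing APIs (```run_glue.py``` is a placeholder script path):

```python
# Sketch only: use_inc=True forces the Intel Neural Compressor API path
# even for transformers code, instead of the optimum-intel default.
from neural_coder import auto_quant

auto_quant(code="run_glue.py", use_inc=True)
```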