diff --git a/.azure-pipelines/scripts/codeScan/pyspelling/lpot_dict.txt b/.azure-pipelines/scripts/codeScan/pyspelling/lpot_dict.txt
index 3eea4060374..a7e46bb180d 100644
--- a/.azure-pipelines/scripts/codeScan/pyspelling/lpot_dict.txt
+++ b/.azure-pipelines/scripts/codeScan/pyspelling/lpot_dict.txt
@@ -696,6 +696,7 @@ Goyal
gpg
GPG
gpt
+GPTJ
gpu
gpus
GPUs
diff --git a/examples/.config/model_params_pytorch.json b/examples/.config/model_params_pytorch.json
index 848c1e9f0c6..df42ff22308 100644
--- a/examples/.config/model_params_pytorch.json
+++ b/examples/.config/model_params_pytorch.json
@@ -531,6 +531,15 @@
"batch_size": 64,
"new_benchmark": false
},
+ "gpt_j_wikitext":{
+ "model_src_dir": "nlp/huggingface_models/language-modeling/quantization/ptq_static/fx",
+ "dataset_location": "",
+ "input_model": "/tf_dataset2/models/pytorch/gpt-j-6B",
+ "yaml": "conf.yaml",
+ "strategy": "basic",
+ "batch_size": 8,
+ "new_benchmark": false
+ },
"xlm-roberta-base_MRPC": {
"model_src_dir": "nlp/huggingface_models/text-classification/quantization/ptq_static/eager",
"dataset_location": "",
diff --git a/examples/README.md b/examples/README.md
index 3cbfd9758d9..49cd7d200c3 100644
--- a/examples/README.md
+++ b/examples/README.md
@@ -519,6 +519,12 @@ Intel® Neural Compressor validated examples with multiple compression technique
         <td>Post-Training Dynamic Quantization</td>
         <td>eager</td>
     </tr>
+    <tr>
+        <td>GPTJ</td>
+        <td>Natural Language Processing</td>
+        <td>Post-Training Static Quantization</td>
+        <td>fx</td>
+    </tr>
diff --git a/examples/pytorch/nlp/huggingface_models/language-modeling/quantization/ptq_static/fx/README.md b/examples/pytorch/nlp/huggingface_models/language-modeling/quantization/ptq_static/fx/README.md
new file mode 100644
index 00000000000..b3599c59a88
--- /dev/null
+++ b/examples/pytorch/nlp/huggingface_models/language-modeling/quantization/ptq_static/fx/README.md
@@ -0,0 +1,38 @@
+Step-by-Step
+============
+
+This document lists the steps to reproduce the PyTorch GPT-J tuning result.
+
+# Prerequisite
+
+## 1. Installation
+
+All dependent packages are listed in `requirements.txt`; install them as follows.
+
+```
+pip install -r requirements.txt
+```
+
+## 2. Run
+
+If the automatic download from the Hugging Face model hub fails, you can download [EleutherAI/gpt-j-6B](https://huggingface.co/EleutherAI/gpt-j-6B?text=My+name+is+Clara+and+I+am) ahead of time and point `--model_name_or_path` at the local copy; see the sketch after the command below.
+
+```shell
+
+python run_clm.py \
+ --model_name_or_path EleutherAI/gpt-j-6B \
+ --dataset_name wikitext\
+ --dataset_config_name wikitext-2-raw-v1 \
+ --do_train \
+ --do_eval \
+ --tune \
+ --output_dir /path/to/checkpoint/dir
+```
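+
+One way to prepare an offline copy (a minimal sketch, not something this example's scripts automate) is to fetch the checkpoint once with `huggingface_hub` and then pass the resulting directory to `--model_name_or_path`:
+
+```python
+# Sketch: pre-download EleutherAI/gpt-j-6B and reuse the local path offline.
+# snapshot_download is provided by the huggingface_hub package, a dependency of transformers.
+from huggingface_hub import snapshot_download
+
+local_path = snapshot_download(repo_id="EleutherAI/gpt-j-6B")
+print(local_path)  # pass this directory to run_clm.py via --model_name_or_path
+```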
+
+
+## 3. Run Tuning and Benchmark Scripts
+
+```
+bash run_tuning.sh --topology=gpt_j_wikitext
+bash run_benchmark.sh --topology=gpt_j_wikitext --mode=benchmark --int8=true
+```
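+
+## 4. Saved Results
+
+After tuning, the quantized model is saved to the directory passed as `--output_dir` (`saved_results` by default in the scripts). A minimal sketch of reloading it for inference, following the loading logic in `run_clm.py` (the FP32 model path is a placeholder for wherever the original checkpoint lives):
+
+```python
+from transformers import AutoModelForCausalLM
+from neural_compressor.utils.pytorch import load
+
+# The FP32 model supplies the architecture; load() then restores the saved INT8 state on top of it.
+fp32_model = AutoModelForCausalLM.from_pretrained("EleutherAI/gpt-j-6B")
+int8_model = load("./saved_results", fp32_model)
+```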
diff --git a/examples/pytorch/nlp/huggingface_models/language-modeling/quantization/ptq_static/fx/conf.yaml b/examples/pytorch/nlp/huggingface_models/language-modeling/quantization/ptq_static/fx/conf.yaml
new file mode 100644
index 00000000000..0f75f809781
--- /dev/null
+++ b/examples/pytorch/nlp/huggingface_models/language-modeling/quantization/ptq_static/fx/conf.yaml
@@ -0,0 +1,31 @@
+#
+# Copyright (c) 2021 Intel Corporation
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+version: 1.0
+
+model: # mandatory. used to specify model specific information.
+  name: gpt_j
+  framework: pytorch_fx # mandatory. possible values are tensorflow, pytorch, pytorch_fx, pytorch_ipex, mxnet, onnxrt_integerops and onnxrt_qlinearops.
+
+quantization: # optional. model-wise tuning constraints for advanced users to reduce the tuning space.
+ approach: post_training_static_quant
+
+tuning:
+ accuracy_criterion:
+    relative: 0.5 # optional. the default criterion type is relative; this example tolerates up to 50% relative regression on the evaluation loss (eval_loss, lower is better).
+ higher_is_better: False
+ exit_policy:
+ max_trials: 600
+ random_seed: 9527 # optional. random seed for deterministic tuning.
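+
+# Note: run_clm.py consumes this file directly via
+#   quantizer = Quantization("./conf.yaml")
+# (see the Neural Compressor tuning block in run_clm.py).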
diff --git a/examples/pytorch/nlp/huggingface_models/language-modeling/quantization/ptq_static/fx/requirements.txt b/examples/pytorch/nlp/huggingface_models/language-modeling/quantization/ptq_static/fx/requirements.txt
new file mode 100644
index 00000000000..763bed755a8
--- /dev/null
+++ b/examples/pytorch/nlp/huggingface_models/language-modeling/quantization/ptq_static/fx/requirements.txt
@@ -0,0 +1,5 @@
+sentencepiece != 0.1.92
+protobuf
+evaluate
+datasets
+transformers >= 4.22.0
diff --git a/examples/pytorch/nlp/huggingface_models/language-modeling/quantization/ptq_static/fx/run_benchmark.sh b/examples/pytorch/nlp/huggingface_models/language-modeling/quantization/ptq_static/fx/run_benchmark.sh
new file mode 100644
index 00000000000..a36507f4fca
--- /dev/null
+++ b/examples/pytorch/nlp/huggingface_models/language-modeling/quantization/ptq_static/fx/run_benchmark.sh
@@ -0,0 +1,91 @@
+#!/bin/bash
+set -x
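+
+# Usage (see README.md):
+#   bash run_benchmark.sh --topology=gpt_j_wikitext --mode=benchmark --int8=true
+#   bash run_benchmark.sh --topology=gpt_j_wikitext --mode=accuracy --int8=true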
+
+function main {
+
+ init_params "$@"
+ run_benchmark
+
+}
+
+# init params
+function init_params {
+ iters=100
+ batch_size=16
+ tuned_checkpoint=saved_results
+ max_eval_samples=`expr ${iters} \* ${batch_size}`
+ echo ${max_eval_samples}
+ for var in "$@"
+ do
+ case $var in
+ --topology=*)
+ topology=$(echo $var |cut -f2 -d=)
+ ;;
+ --dataset_location=*)
+ dataset_location=$(echo $var |cut -f2 -d=)
+ ;;
+ --input_model=*)
+ input_model=$(echo $var |cut -f2 -d=)
+ ;;
+ --mode=*)
+ mode=$(echo $var |cut -f2 -d=)
+ ;;
+ --batch_size=*)
+ batch_size=$(echo $var |cut -f2 -d=)
+ ;;
+ --iters=*)
+ iters=$(echo ${var} |cut -f2 -d=)
+ ;;
+ --int8=*)
+ int8=$(echo ${var} |cut -f2 -d=)
+ ;;
+ --config=*)
+ tuned_checkpoint=$(echo $var |cut -f2 -d=)
+ ;;
+ *)
+ echo "Error: No such parameter: ${var}"
+ exit 1
+ ;;
+ esac
+ done
+
+}
+
+
+# run_benchmark
+function run_benchmark {
+ extra_cmd=''
+
+ if [[ ${mode} == "accuracy" ]]; then
+ mode_cmd=" --accuracy_only "
+ elif [[ ${mode} == "benchmark" ]]; then
+ mode_cmd=" --benchmark "
+ extra_cmd=$extra_cmd" --max_eval_samples ${max_eval_samples}"
+ else
+ echo "Error: No such mode: ${mode}"
+ exit 1
+ fi
+
+ if [ "${topology}" = "gpt_j_wikitext" ]; then
+ TASK_NAME='wikitext'
+ model_name_or_path=$input_model
+        extra_cmd=$extra_cmd" --dataset_config_name=wikitext-2-raw-v1"
+ fi
+
+ if [[ ${int8} == "true" ]]; then
+ extra_cmd=$extra_cmd" --int8"
+ fi
+ echo $extra_cmd
+
+ python -u run_clm.py \
+ --model_name_or_path ${model_name_or_path} \
+ --dataset_name ${TASK_NAME} \
+ --do_eval \
+ --per_device_eval_batch_size ${batch_size} \
+ --output_dir ${tuned_checkpoint} \
+ ${mode_cmd} \
+ ${extra_cmd}
+
+}
+
+main "$@"
diff --git a/examples/pytorch/nlp/huggingface_models/language-modeling/quantization/ptq_static/fx/run_clm.py b/examples/pytorch/nlp/huggingface_models/language-modeling/quantization/ptq_static/fx/run_clm.py
new file mode 100644
index 00000000000..17a32f1b57a
--- /dev/null
+++ b/examples/pytorch/nlp/huggingface_models/language-modeling/quantization/ptq_static/fx/run_clm.py
@@ -0,0 +1,650 @@
+#!/usr/bin/env python
+# coding=utf-8
+# Copyright 2020 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Fine-tuning the library models for causal language modeling (GPT, GPT-2, CTRL, ...) on a text file or a dataset.
+Here is the full list of checkpoints on the hub that can be fine-tuned by this script:
+https://huggingface.co/models?filter=text-generation
+"""
+# You can also adapt this script on your own causal language modeling task. Pointers for this are left as comments.
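+#
+# This example extends the upstream Hugging Face run_clm.py with Intel Neural Compressor
+# options (--tune, --int8, --benchmark, --accuracy_only) for post-training static quantization.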
+
+import logging
+import math
+import os
+import sys
+from dataclasses import dataclass, field
+from itertools import chain
+from typing import Optional
+
+import datasets
+from datasets import load_dataset
+
+import evaluate
+import transformers
+from transformers import (
+ CONFIG_MAPPING,
+ MODEL_FOR_CAUSAL_LM_MAPPING,
+ AutoConfig,
+ AutoModelForCausalLM,
+ AutoTokenizer,
+ HfArgumentParser,
+ Trainer,
+ TrainingArguments,
+ default_data_collator,
+ is_torch_tpu_available,
+ set_seed,
+)
+from transformers.testing_utils import CaptureLogger
+from transformers.trainer_utils import get_last_checkpoint
+from transformers.utils import check_min_version, send_example_telemetry
+from transformers.utils.versions import require_version
+
+
+# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
+check_min_version("4.22.0.dev0")
+
+require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/language-modeling/requirements.txt")
+
+logger = logging.getLogger(__name__)
+
+
+MODEL_CONFIG_CLASSES = list(MODEL_FOR_CAUSAL_LM_MAPPING.keys())
+MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES)
+
+
+@dataclass
+class ModelArguments:
+ """
+ Arguments pertaining to which model/config/tokenizer we are going to fine-tune, or train from scratch.
+ """
+
+ model_name_or_path: Optional[str] = field(
+ default=None,
+ metadata={
+ "help": (
+                "The model checkpoint for weights initialization. Don't set if you want to train a model from scratch."
+ )
+ },
+ )
+ model_type: Optional[str] = field(
+ default=None,
+ metadata={"help": "If training from scratch, pass a model type from the list: " + ", ".join(MODEL_TYPES)},
+ )
+ config_overrides: Optional[str] = field(
+ default=None,
+ metadata={
+ "help": (
+ "Override some existing default config settings when a model is trained from scratch. Example: "
+ "n_embd=10,resid_pdrop=0.2,scale_attn_weights=false,summary_type=cls_index"
+ )
+ },
+ )
+ config_name: Optional[str] = field(
+ default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
+ )
+ tokenizer_name: Optional[str] = field(
+ default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"}
+ )
+ cache_dir: Optional[str] = field(
+ default=None,
+ metadata={"help": "Where do you want to store the pretrained models downloaded from huggingface.co"},
+ )
+ use_fast_tokenizer: bool = field(
+ default=True,
+ metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."},
+ )
+ model_revision: str = field(
+ default="main",
+ metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."},
+ )
+ use_auth_token: bool = field(
+ default=False,
+ metadata={
+ "help": (
+ "Will use the token generated when running `huggingface-cli login` (necessary to use this script "
+ "with private models)."
+ )
+ },
+ )
+ tune: bool = field(
+        default=False, metadata={"help": "Run Neural Compressor tuning to produce a quantized model."}
+ )
+ int8: bool = field(
+ default=False, metadata={"help": "use int8 model to get accuracy or benchmark"}
+ )
+ benchmark: bool = field(
+ default=False, metadata={"help": "get benchmark instead of accuracy"}
+ )
+ accuracy_only: bool = field(
+ default=False, metadata={"help": "get accuracy"}
+ )
+
+ def __post_init__(self):
+ if self.config_overrides is not None and (self.config_name is not None or self.model_name_or_path is not None):
+ raise ValueError(
+ "--config_overrides can't be used in combination with --config_name or --model_name_or_path"
+ )
+
+
+@dataclass
+class DataTrainingArguments:
+ """
+ Arguments pertaining to what data we are going to input our model for training and eval.
+ """
+
+ dataset_name: Optional[str] = field(
+ default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."}
+ )
+ dataset_config_name: Optional[str] = field(
+ default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."}
+ )
+ train_file: Optional[str] = field(default=None, metadata={"help": "The input training data file (a text file)."})
+ validation_file: Optional[str] = field(
+ default=None,
+ metadata={"help": "An optional input evaluation data file to evaluate the perplexity on (a text file)."},
+ )
+ max_train_samples: Optional[int] = field(
+ default=None,
+ metadata={
+ "help": (
+ "For debugging purposes or quicker training, truncate the number of training examples to this "
+ "value if set."
+ )
+ },
+ )
+ max_eval_samples: Optional[int] = field(
+ default=None,
+ metadata={
+ "help": (
+ "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
+ "value if set."
+ )
+ },
+ )
+
+ block_size: Optional[int] = field(
+ default=None,
+ metadata={
+ "help": (
+ "Optional input sequence length after tokenization. "
+ "The training dataset will be truncated in block of this size for training. "
+ "Default to the model max input length for single sentence inputs (take into account special tokens)."
+ )
+ },
+ )
+ overwrite_cache: bool = field(
+ default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
+ )
+ validation_split_percentage: Optional[int] = field(
+ default=5,
+ metadata={
+ "help": "The percentage of the train set used as validation set in case there's no validation split"
+ },
+ )
+ preprocessing_num_workers: Optional[int] = field(
+ default=None,
+ metadata={"help": "The number of processes to use for the preprocessing."},
+ )
+ keep_linebreaks: bool = field(
+ default=True, metadata={"help": "Whether to keep line breaks when using TXT files or not."}
+ )
+
+ def __post_init__(self):
+ if self.dataset_name is None and self.train_file is None and self.validation_file is None:
+ raise ValueError("Need either a dataset name or a training/validation file.")
+ else:
+ if self.train_file is not None:
+ extension = self.train_file.split(".")[-1]
+ assert extension in ["csv", "json", "txt"], "`train_file` should be a csv, a json or a txt file."
+ if self.validation_file is not None:
+ extension = self.validation_file.split(".")[-1]
+ assert extension in ["csv", "json", "txt"], "`validation_file` should be a csv, a json or a txt file."
+
+
+def main():
+ # See all possible arguments in src/transformers/training_args.py
+ # or by passing the --help flag to this script.
+ # We now keep distinct sets of args, for a cleaner separation of concerns.
+
+ parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments))
+ if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
+ # If we pass only one argument to the script and it's the path to a json file,
+ # let's parse it to get our arguments.
+ model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
+ else:
+ model_args, data_args, training_args = parser.parse_args_into_dataclasses()
+
+ # Sending telemetry. Tracking the example usage helps us better allocate resources to maintain them. The
+ # information sent is the one passed as arguments along with your Python/PyTorch versions.
+ send_example_telemetry("run_clm", model_args, data_args)
+
+ # Setup logging
+ logging.basicConfig(
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
+ datefmt="%m/%d/%Y %H:%M:%S",
+ handlers=[logging.StreamHandler(sys.stdout)],
+ )
+
+ log_level = training_args.get_process_log_level()
+ logger.setLevel(log_level)
+ datasets.utils.logging.set_verbosity(log_level)
+ transformers.utils.logging.set_verbosity(log_level)
+ transformers.utils.logging.enable_default_handler()
+ transformers.utils.logging.enable_explicit_format()
+
+ # Log on each process the small summary:
+ logger.warning(
+ f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}"
+        + f", distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}"
+ )
+ logger.info(f"Training/evaluation parameters {training_args}")
+
+ # Detecting last checkpoint.
+ last_checkpoint = None
+ if os.path.isdir(training_args.output_dir) and training_args.do_train and not training_args.overwrite_output_dir:
+ last_checkpoint = get_last_checkpoint(training_args.output_dir)
+ if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0:
+ raise ValueError(
+ f"Output directory ({training_args.output_dir}) already exists and is not empty. "
+ "Use --overwrite_output_dir to overcome."
+ )
+ elif last_checkpoint is not None and training_args.resume_from_checkpoint is None:
+ logger.info(
+ f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change "
+ "the `--output_dir` or add `--overwrite_output_dir` to train from scratch."
+ )
+
+ # Set seed before initializing model.
+ set_seed(training_args.seed)
+
+ # Get the datasets: you can either provide your own CSV/JSON/TXT training and evaluation files (see below)
+ # or just provide the name of one of the public datasets available on the hub at https://huggingface.co/datasets/
+ # (the dataset will be downloaded automatically from the datasets Hub).
+ #
+ # For CSV/JSON files, this script will use the column called 'text' or the first column if no column called
+ # 'text' is found. You can easily tweak this behavior (see below).
+ #
+ # In distributed training, the load_dataset function guarantee that only one local process can concurrently
+ # download the dataset.
+ if data_args.dataset_name is not None:
+ # Downloading and loading a dataset from the hub.
+ raw_datasets = load_dataset(
+ data_args.dataset_name,
+ data_args.dataset_config_name,
+ cache_dir=model_args.cache_dir,
+ use_auth_token=True if model_args.use_auth_token else None,
+ )
+ if "validation" not in raw_datasets.keys():
+ raw_datasets["validation"] = load_dataset(
+ data_args.dataset_name,
+ data_args.dataset_config_name,
+ split=f"train[:{data_args.validation_split_percentage}%]",
+ cache_dir=model_args.cache_dir,
+ use_auth_token=True if model_args.use_auth_token else None,
+ )
+ raw_datasets["train"] = load_dataset(
+ data_args.dataset_name,
+ data_args.dataset_config_name,
+ split=f"train[{data_args.validation_split_percentage}%:]",
+ cache_dir=model_args.cache_dir,
+ use_auth_token=True if model_args.use_auth_token else None,
+ )
+ else:
+ data_files = {}
+ dataset_args = {}
+ if data_args.train_file is not None:
+ data_files["train"] = data_args.train_file
+ if data_args.validation_file is not None:
+ data_files["validation"] = data_args.validation_file
+ extension = (
+ data_args.train_file.split(".")[-1]
+ if data_args.train_file is not None
+ else data_args.validation_file.split(".")[-1]
+ )
+ if extension == "txt":
+ extension = "text"
+ dataset_args["keep_linebreaks"] = data_args.keep_linebreaks
+ raw_datasets = load_dataset(
+ extension,
+ data_files=data_files,
+ cache_dir=model_args.cache_dir,
+ use_auth_token=True if model_args.use_auth_token else None,
+ **dataset_args,
+ )
+ # If no validation data is there, validation_split_percentage will be used to divide the dataset.
+ if "validation" not in raw_datasets.keys():
+ raw_datasets["validation"] = load_dataset(
+ extension,
+ data_files=data_files,
+ split=f"train[:{data_args.validation_split_percentage}%]",
+ cache_dir=model_args.cache_dir,
+ use_auth_token=True if model_args.use_auth_token else None,
+ **dataset_args,
+ )
+ raw_datasets["train"] = load_dataset(
+ extension,
+ data_files=data_files,
+ split=f"train[{data_args.validation_split_percentage}%:]",
+ cache_dir=model_args.cache_dir,
+ use_auth_token=True if model_args.use_auth_token else None,
+ **dataset_args,
+ )
+
+ # See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at
+ # https://huggingface.co/docs/datasets/loading_datasets.html.
+
+ # Load pretrained model and tokenizer
+ #
+ # Distributed training:
+ # The .from_pretrained methods guarantee that only one local process can concurrently
+ # download model & vocab.
+
+ config_kwargs = {
+ "cache_dir": model_args.cache_dir,
+ "revision": model_args.model_revision,
+ "use_auth_token": True if model_args.use_auth_token else None,
+ }
+ if model_args.config_name:
+ config = AutoConfig.from_pretrained(model_args.config_name, **config_kwargs)
+ elif model_args.model_name_or_path:
+ config = AutoConfig.from_pretrained(model_args.model_name_or_path, **config_kwargs)
+ else:
+ config = CONFIG_MAPPING[model_args.model_type]()
+ logger.warning("You are instantiating a new config instance from scratch.")
+ if model_args.config_overrides is not None:
+ logger.info(f"Overriding config: {model_args.config_overrides}")
+ config.update_from_string(model_args.config_overrides)
+ logger.info(f"New config: {config}")
+
+ tokenizer_kwargs = {
+ "cache_dir": model_args.cache_dir,
+ "use_fast": model_args.use_fast_tokenizer,
+ "revision": model_args.model_revision,
+ "use_auth_token": True if model_args.use_auth_token else None,
+ }
+ if model_args.tokenizer_name:
+ tokenizer = AutoTokenizer.from_pretrained(model_args.tokenizer_name, **tokenizer_kwargs)
+ elif model_args.model_name_or_path:
+ tokenizer = AutoTokenizer.from_pretrained(model_args.model_name_or_path, **tokenizer_kwargs)
+ else:
+ raise ValueError(
+ "You are instantiating a new tokenizer from scratch. This is not supported by this script."
+ "You can do it from another script, save it, and load it from here, using --tokenizer_name."
+ )
+
+ if model_args.model_name_or_path:
+ model = AutoModelForCausalLM.from_pretrained(
+ model_args.model_name_or_path,
+ from_tf=bool(".ckpt" in model_args.model_name_or_path),
+ config=config,
+ cache_dir=model_args.cache_dir,
+ revision=model_args.model_revision,
+ use_auth_token=True if model_args.use_auth_token else None,
+ )
+ else:
+ model = AutoModelForCausalLM.from_config(config)
+ n_params = sum(dict((p.data_ptr(), p.numel()) for p in model.parameters()).values())
+ logger.info(f"Training new model from scratch - Total size={n_params/2**20:.2f}M params")
+
+ # We resize the embeddings only when necessary to avoid index errors. If you are creating a model from scratch
+ # on a small vocab and want a smaller embedding size, remove this test.
+ embedding_size = model.get_input_embeddings().weight.shape[0]
+ if len(tokenizer) > embedding_size:
+ model.resize_token_embeddings(len(tokenizer))
+
+ # Preprocessing the datasets.
+ # First we tokenize all the texts.
+ if training_args.do_train:
+ column_names = raw_datasets["train"].column_names
+ else:
+ column_names = raw_datasets["validation"].column_names
+ text_column_name = "text" if "text" in column_names else column_names[0]
+
+ # since this will be pickled to avoid _LazyModule error in Hasher force logger loading before tokenize_function
+ tok_logger = transformers.utils.logging.get_logger("transformers.tokenization_utils_base")
+
+ def tokenize_function(examples):
+ with CaptureLogger(tok_logger) as cl:
+ output = tokenizer(examples[text_column_name])
+ # clm input could be much much longer than block_size
+ if "Token indices sequence length is longer than the" in cl.out:
+ tok_logger.warning(
+ "^^^^^^^^^^^^^^^^ Please ignore the warning above - this long input will be chunked into smaller bits"
+ " before being passed to the model."
+ )
+ return output
+
+ with training_args.main_process_first(desc="dataset map tokenization"):
+ tokenized_datasets = raw_datasets.map(
+ tokenize_function,
+ batched=True,
+ num_proc=data_args.preprocessing_num_workers,
+ remove_columns=column_names,
+ load_from_cache_file=not data_args.overwrite_cache,
+ desc="Running tokenizer on dataset",
+ )
+
+ if data_args.block_size is None:
+ block_size = tokenizer.model_max_length
+ if block_size > 1024:
+ logger.warning(
+ f"The tokenizer picked seems to have a very large `model_max_length` ({tokenizer.model_max_length}). "
+ "Picking 1024 instead. You can change that default value by passing --block_size xxx."
+ )
+ block_size = 1024
+ else:
+ if data_args.block_size > tokenizer.model_max_length:
+ logger.warning(
+                f"The block_size passed ({data_args.block_size}) is larger than the maximum length for the model "
+ f"({tokenizer.model_max_length}). Using block_size={tokenizer.model_max_length}."
+ )
+ block_size = min(data_args.block_size, tokenizer.model_max_length)
+
+ # Main data processing function that will concatenate all texts from our dataset and generate chunks of block_size.
+ def group_texts(examples):
+ # Concatenate all texts.
+ concatenated_examples = {k: list(chain(*examples[k])) for k in examples.keys()}
+ total_length = len(concatenated_examples[list(examples.keys())[0]])
+ # We drop the small remainder, we could add padding if the model supported it instead of this drop, you can
+ # customize this part to your needs.
+ if total_length >= block_size:
+ total_length = (total_length // block_size) * block_size
+ # Split by chunks of max_len.
+ result = {
+ k: [t[i : i + block_size] for i in range(0, total_length, block_size)]
+ for k, t in concatenated_examples.items()
+ }
+ result["labels"] = result["input_ids"].copy()
+ return result
+
+ # Note that with `batched=True`, this map processes 1,000 texts together, so group_texts throws away a remainder
+ # for each of those groups of 1,000 texts. You can adjust that batch_size here but a higher value might be slower
+ # to preprocess.
+ #
+ # To speed up this part, we use multiprocessing. See the documentation of the map method for more information:
+ # https://huggingface.co/docs/datasets/package_reference/main_classes.html#datasets.Dataset.map
+
+ with training_args.main_process_first(desc="grouping texts together"):
+ lm_datasets = tokenized_datasets.map(
+ group_texts,
+ batched=True,
+ num_proc=data_args.preprocessing_num_workers,
+ load_from_cache_file=not data_args.overwrite_cache,
+ desc=f"Grouping texts in chunks of {block_size}",
+ )
+
+ if training_args.do_train:
+ if "train" not in tokenized_datasets:
+ raise ValueError("--do_train requires a train dataset")
+ train_dataset = lm_datasets["train"]
+ if data_args.max_train_samples is not None:
+ max_train_samples = min(len(train_dataset), data_args.max_train_samples)
+ train_dataset = train_dataset.select(range(max_train_samples))
+
+ if training_args.do_eval:
+ if "validation" not in tokenized_datasets:
+ raise ValueError("--do_eval requires a validation dataset")
+ eval_dataset = lm_datasets["validation"]
+ if data_args.max_eval_samples is not None:
+ max_eval_samples = min(len(eval_dataset), data_args.max_eval_samples)
+ eval_dataset = eval_dataset.select(range(max_eval_samples))
+
+ def preprocess_logits_for_metrics(logits, labels):
+ if isinstance(logits, tuple):
+ # Depending on the model and config, logits may contain extra tensors,
+ # like past_key_values, but logits always come first
+ logits = logits[0]
+ return logits.argmax(dim=-1)
+
+ metric = evaluate.load("accuracy")
+
+ def compute_metrics(eval_preds):
+ preds, labels = eval_preds
+ # preds have the same shape as the labels, after the argmax(-1) has been calculated
+ # by preprocess_logits_for_metrics but we need to shift the labels
+ labels = labels[:, 1:].reshape(-1)
+ preds = preds[:, :-1].reshape(-1)
+ return metric.compute(predictions=preds, references=labels)
+
+ # Initialize our Trainer
+ trainer = Trainer(
+ model=model,
+ args=training_args,
+ train_dataset=train_dataset if training_args.do_train else None,
+ eval_dataset=eval_dataset if training_args.do_eval else None,
+ tokenizer=tokenizer,
+ # Data collator will default to DataCollatorWithPadding, so we change it.
+ data_collator=default_data_collator,
+ compute_metrics=compute_metrics if training_args.do_eval and not is_torch_tpu_available() else None,
+ preprocess_logits_for_metrics=preprocess_logits_for_metrics
+ if training_args.do_eval and not is_torch_tpu_available()
+ else None,
+ )
+
+ # Tune
+ if model_args.tune:
+ def eval_func_for_nc(model_tuned):
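+            # Neural Compressor calls this function with each candidate quantized model and
+            # expects a single scalar score back. The scalar returned here is eval_loss,
+            # which is lower-is-better, matching `higher_is_better: False` in conf.yaml.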
+ trainer.model = model_tuned
+ eval_output = trainer.evaluate(eval_dataset=eval_dataset)
+ perplexity = math.exp(eval_output["eval_loss"])
+            results = {"perplexity": perplexity, "eval_loss": eval_output["eval_loss"],
+                       "eval_samples_per_second": eval_output["eval_samples_per_second"]}
+ clm_task_metrics_keys = ["perplexity","eval_loss"]
+ for key in clm_task_metrics_keys:
+ if key in results.keys():
+ logger.info("Finally Eval {}:{}".format(key, results[key]))
+ if key=="eval_loss":
+ eval_loss = results[key]
+ break
+ print("Accuracy: %.5f" % eval_loss)
+ print('Throughput: %.3f samples/sec' % (results["eval_samples_per_second"]))
+ print('Latency: %.3f ms' % (1 * 1000 / results["eval_samples_per_second"]))
+ print('Batch size = %d' % training_args.per_device_eval_batch_size)
+
+ return eval_loss
+
+ from neural_compressor.experimental import Quantization, common
+ quantizer = Quantization("./conf.yaml")
+ quantizer.model = common.Model(model)
+ quantizer.calib_dataloader = trainer.get_eval_dataloader()
+ quantizer.eval_func = eval_func_for_nc
+ q_model = quantizer.fit()
+ q_model.save(training_args.output_dir)
+ exit(0)
+
+ # Benchmark or accuracy
+ if model_args.benchmark or model_args.accuracy_only:
+ if model_args.int8:
+ from neural_compressor.utils.pytorch import load
+ new_model = load(
+ os.path.abspath(os.path.expanduser(training_args.output_dir)), model)
+ else:
+ new_model = model
+ trainer.model = new_model
+ eval_output = trainer.evaluate(eval_dataset=eval_dataset)
+ perplexity = math.exp(eval_output["eval_loss"])
+        results = {"perplexity": perplexity, "eval_loss": eval_output["eval_loss"],
+                   "eval_samples_per_second": eval_output["eval_samples_per_second"]}
+ clm_task_metrics_keys = ["eval_loss"]
+ for key in clm_task_metrics_keys:
+ if key in results.keys():
+ acc = results[key]
+ break
+ print("Accuracy: %.5f" % acc)
+ print('Throughput: %.3f samples/sec' % (results["eval_samples_per_second"]))
+ print('Latency: %.3f ms' % (1 * 1000 / results["eval_samples_per_second"]))
+ print('Batch size = %d' % training_args.per_device_eval_batch_size)
+ exit(0)
+
+ # Training
+ if training_args.do_train:
+ checkpoint = None
+ if training_args.resume_from_checkpoint is not None:
+ checkpoint = training_args.resume_from_checkpoint
+ elif last_checkpoint is not None:
+ checkpoint = last_checkpoint
+ train_result = trainer.train(resume_from_checkpoint=checkpoint)
+ trainer.save_model() # Saves the tokenizer too for easy upload
+
+ metrics = train_result.metrics
+
+ max_train_samples = (
+ data_args.max_train_samples if data_args.max_train_samples is not None else len(train_dataset)
+ )
+ metrics["train_samples"] = min(max_train_samples, len(train_dataset))
+
+ trainer.log_metrics("train", metrics)
+ trainer.save_metrics("train", metrics)
+ trainer.save_state()
+
+ # Evaluation
+ if training_args.do_eval:
+ logger.info("*** Evaluate ***")
+
+ metrics = trainer.evaluate()
+
+ max_eval_samples = data_args.max_eval_samples if data_args.max_eval_samples is not None else len(eval_dataset)
+ metrics["eval_samples"] = min(max_eval_samples, len(eval_dataset))
+ try:
+ perplexity = math.exp(metrics["eval_loss"])
+ except OverflowError:
+ perplexity = float("inf")
+ metrics["perplexity"] = perplexity
+
+ trainer.log_metrics("eval", metrics)
+ trainer.save_metrics("eval", metrics)
+
+ kwargs = {"finetuned_from": model_args.model_name_or_path, "tasks": "text-generation"}
+ if data_args.dataset_name is not None:
+ kwargs["dataset_tags"] = data_args.dataset_name
+ if data_args.dataset_config_name is not None:
+ kwargs["dataset_args"] = data_args.dataset_config_name
+ kwargs["dataset"] = f"{data_args.dataset_name} {data_args.dataset_config_name}"
+ else:
+ kwargs["dataset"] = data_args.dataset_name
+
+ if training_args.push_to_hub:
+ trainer.push_to_hub(**kwargs)
+ else:
+ trainer.create_model_card(**kwargs)
+
+
+def _mp_fn(index):
+ # For xla_spawn (TPUs)
+ main()
+
+
+if __name__ == "__main__":
+ main()
diff --git a/examples/pytorch/nlp/huggingface_models/language-modeling/quantization/ptq_static/fx/run_tuning.sh b/examples/pytorch/nlp/huggingface_models/language-modeling/quantization/ptq_static/fx/run_tuning.sh
new file mode 100644
index 00000000000..04b16872a59
--- /dev/null
+++ b/examples/pytorch/nlp/huggingface_models/language-modeling/quantization/ptq_static/fx/run_tuning.sh
@@ -0,0 +1,63 @@
+#!/bin/bash
+set -x
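+
+# Usage (see README.md):
+#   bash run_tuning.sh --topology=gpt_j_wikitext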
+
+function main {
+
+ init_params "$@"
+ run_tuning
+
+}
+
+# init params
+function init_params {
+ tuned_checkpoint=saved_results
+ for var in "$@"
+ do
+ case $var in
+ --topology=*)
+ topology=$(echo $var |cut -f2 -d=)
+ ;;
+ --dataset_location=*)
+ dataset_location=$(echo $var |cut -f2 -d=)
+ ;;
+ --input_model=*)
+ input_model=$(echo $var |cut -f2 -d=)
+ ;;
+ --output_model=*)
+ tuned_checkpoint=$(echo $var |cut -f2 -d=)
+ ;;
+ *)
+ echo "Error: No such parameter: ${var}"
+ exit 1
+ ;;
+ esac
+ done
+
+}
+
+# run_tuning
+function run_tuning {
+ extra_cmd=''
+ batch_size=8
+    model_type='gpt_j'
+ approach='post_training_static_quant'
+
+ if [ "${topology}" = "gpt_j_wikitext" ]; then
+ TASK_NAME='wikitext'
+ model_name_or_path=$input_model
+ extra_cmd='--dataset_config_name=wikitext-2-raw-v1'
+ fi
+
+
+ python -u run_clm.py \
+ --model_name_or_path ${model_name_or_path} \
+ --dataset_name ${TASK_NAME} \
+ --do_eval \
+ --per_device_eval_batch_size ${batch_size} \
+ --output_dir ${tuned_checkpoint} \
+ --tune \
+ ${extra_cmd}
+
+}
+
+main "$@"