diff --git a/docs/source/_static/imgs/pruning/Pruning_patterns.JPG b/docs/source/_static/imgs/pruning/Pruning_patterns.JPG new file mode 100644 index 00000000000..38c061489c8 Binary files /dev/null and b/docs/source/_static/imgs/pruning/Pruning_patterns.JPG differ diff --git a/docs/source/_static/imgs/pruning/Pruning_patterns.PNG b/docs/source/_static/imgs/pruning/Pruning_patterns.PNG index d453622ed5a..0bb10d43906 100644 Binary files a/docs/source/_static/imgs/pruning/Pruning_patterns.PNG and b/docs/source/_static/imgs/pruning/Pruning_patterns.PNG differ diff --git a/docs/source/_static/imgs/pruning/Pruning_schedule.JPG b/docs/source/_static/imgs/pruning/Pruning_schedule.JPG new file mode 100644 index 00000000000..9e5063381a1 Binary files /dev/null and b/docs/source/_static/imgs/pruning/Pruning_schedule.JPG differ diff --git a/docs/source/_static/imgs/pruning/Regularization.JPG b/docs/source/_static/imgs/pruning/Regularization.JPG new file mode 100644 index 00000000000..94de6c74816 Binary files /dev/null and b/docs/source/_static/imgs/pruning/Regularization.JPG differ diff --git a/docs/source/pruning_details.md b/docs/source/pruning_details.md index 48e1df7398f..e55cd4fdca3 100644 --- a/docs/source/pruning_details.md +++ b/docs/source/pruning_details.md @@ -251,7 +251,8 @@ Pattern_lock pruning type uses masks of a fixed pattern during the pruning proce -Progressive pruning aims at smoothing the structured pruning by automatically interpolating a group of interval masks during the pruning process. In this method, a sequence of masks are generated to enable a more flexible pruning process and those masks would gradually change into ones to fit the target pruning structure. +Progressive pruning aims at smoothing the structured pruning by automatically interpolating a group of interval masks during the pruning process. In this method, a sequence of masks are generated to enable a more flexible pruning process and those masks would gradually change into ones to fit the target pruning structure. +Progressive pruning is used mainly for channel-wise pruning and currently only supports NxM pruning pattern. diff --git a/examples/pytorch/nlp/huggingface_models/question-answering/pruning/pytorch_pruner/eager/run_qa_no_trainer.py b/examples/pytorch/nlp/huggingface_models/question-answering/pruning/pytorch_pruner/eager/run_qa_no_trainer.py index a3966af4845..ddb785e6c5b 100644 --- a/examples/pytorch/nlp/huggingface_models/question-answering/pruning/pytorch_pruner/eager/run_qa_no_trainer.py +++ b/examples/pytorch/nlp/huggingface_models/question-answering/pruning/pytorch_pruner/eager/run_qa_no_trainer.py @@ -57,6 +57,8 @@ from transformers.utils import check_min_version from transformers.utils.versions import require_version from utils_qa import postprocess_qa_predictions +from neural_compressor.pruning import Pruning +from neural_compressor.pruner.utils import WeightPruningConfig # Will error if the minimal version of Transformers is not installed. Remove at your own risks. check_min_version("4.21.0.dev0") @@ -118,32 +120,32 @@ def parse_args(): help="The configuration name of the dataset to use (via the datasets library).", ) parser.add_argument( - "--train_file", - type=str, - default=None, + "--train_file", + type=str, + default=None, help="A csv or a json file containing the training data." ) parser.add_argument( - "--preprocessing_num_workers", - type=int, default=4, + "--preprocessing_num_workers", + type=int, default=4, help="A csv or a json file containing the training data." 
) parser.add_argument( - "--do_predict", - action="store_true", + "--do_predict", + action="store_true", help="To do prediction on the question answering model" ) parser.add_argument( - "--validation_file", - type=str, - default=None, + "--validation_file", + type=str, + default=None, help="A csv or a json file containing the validation data." ) parser.add_argument( - "--test_file", - type=str, - default=None, + "--test_file", + type=str, + default=None, help="A csv or a json file containing the Prediction data." ) parser.add_argument( @@ -163,15 +165,14 @@ def parse_args(): parser.add_argument( "--model_name_or_path", type=str, - help="Path to pretrained model or model identifier from huggingface.co/models.", - required=False, + help="Path to pretrained model or model identifier from huggingface.co/models." ) parser.add_argument( "--teacher_model_name_or_path", type=str, default=None, help="Path to pretrained model or model identifier from huggingface.co/models.", - required=False, + required=False ) parser.add_argument( "--config_name", @@ -199,8 +200,8 @@ def parse_args(): parser.add_argument( "--distill_loss_weight", type=float, - default=1.0, - help="distiller loss weight", + default=0.0, + help="distiller loss weight" ) parser.add_argument( "--per_device_eval_batch_size", @@ -215,15 +216,15 @@ def parse_args(): help="Initial learning rate (after the potential warmup period) to use.", ) parser.add_argument( - "--weight_decay", - type=float, - default=0.0, + "--weight_decay", + type=float, + default=0.0, help="Weight decay to use." ) parser.add_argument( - "--num_train_epochs", - type=int, - default=3, + "--num_train_epochs", + type=int, + default=3, help="Total number of training epochs to perform." ) parser.add_argument( @@ -245,29 +246,28 @@ def parse_args(): help="The scheduler type to use.", choices=["linear", "cosine", "cosine_with_restarts", "polynomial", "constant", "constant_with_warmup"], ) - parser.add_argument( - "--warm_epochs", - type=int, - default=0, + "--warm_epochs", + type=int, + default=0, help="Number of epochs the network not be purned" ) parser.add_argument( - "--num_warmup_steps", - type=int, - default=0, + "--num_warmup_steps", + type=int, + default=0, help="Number of steps for the warmup in the lr scheduler." ) parser.add_argument( - "--output_dir", - type=str, - default=None, + "--output_dir", + type=str, + default=None, help="Where to store the final model." ) parser.add_argument( - "--seed", - type=int, - default=None, + "--seed", + type=int, + default=None, help="A seed for reproducible training." ) parser.add_argument( @@ -341,33 +341,18 @@ def parse_args(): choices=MODEL_TYPES, ) parser.add_argument( - "--cooldown_epochs", - type=int, default=0, - help="Cooling epochs after pruning." - ) - parser.add_argument( - "--do_prune", action="store_true", - help="Whether or not to prune the model" - ) - parser.add_argument( - "--pruning_config", - type=str, - help="pruning_config", - ) - - parser.add_argument( - "--push_to_hub", - action="store_true", + "--push_to_hub", + action="store_true", help="Whether or not to push the model to the Hub." ) parser.add_argument( - "--hub_model_id", - type=str, + "--hub_model_id", + type=str, help="The name of the repository to keep in sync with the local `output_dir`." ) parser.add_argument( - "--hub_token", - type=str, + "--hub_token", + type=str, help="The token to use to push to the Model Hub." 
) parser.add_argument( @@ -382,13 +367,6 @@ def parse_args(): default=None, help="If the training should continue from a checkpoint folder.", ) - parser.add_argument( - "--distiller", - type=str, - default=None, - help="teacher model path", - ) - parser.add_argument( "--with_tracking", action="store_true", @@ -405,6 +383,35 @@ def parse_args(): ), ) + parser.add_argument( + "--cooldown_epochs", + type=int, default=0, + help="Cooling epochs after pruning." + ) + parser.add_argument( + "--do_prune", action="store_true", + help="Whether or not to prune the model" + ) + # parser.add_argument( + # "--keep_conf", action="store_true", + # help="Whether or not to keep the prune config infos" + # ) + parser.add_argument( + "--pruning_pattern", + type=str, default="1x1", + help="pruning pattern type, we support NxM and N:M." + ) + parser.add_argument( + "--target_sparsity", + type=float, default=0.8, + help="Target sparsity of the model." + ) + parser.add_argument( + "--pruning_frequency", + type=int, default=-1, + help="Sparse step frequency for iterative pruning, default to a quarter of pruning steps." + ) + args = parser.parse_args() # Sanity checks @@ -435,7 +442,7 @@ def parse_args(): def main(): args = parse_args() - # Sending telemetry. Tracking the example usage helps us better allocate resources to maintain them. The + # Sending telemetry. Tracking the example usage helps us better allocate resources to maintain them. The # information sent is the one passed as arguments along with your Python/PyTorch versions. # send_example_telemetry("run_qa_no_trainer", args) @@ -528,10 +535,13 @@ def main(): "You are instantiating a new tokenizer from scratch. This is not supported by this script." "You can do it from another script, save it, and load it from here, using --tokenizer_name." ) - - if args.teacher_model_name_or_path != None: + + if args.distill_loss_weight > 0: + teacher_path = args.teacher_model_name_or_path + if teacher_path is None: + teacher_path = args.model_name_or_path teacher_model = AutoModelForQuestionAnswering.from_pretrained( - args.teacher_model_name_or_path, + teacher_path, from_tf=bool(".ckpt" in args.model_name_or_path), config=config, ) @@ -815,7 +825,6 @@ def post_processing_function(examples, features, predictions, stage="eval"): def create_and_fill_np_array(start_or_end_logits, dataset, max_len): """ Create and fill numpy array of size len_of_validation_data * max_length_of_output_tensor - Args: start_or_end_logits(:obj:`tensor`): This is the output predictions of the model. We can only enter either start or end logits. @@ -847,6 +856,7 @@ def create_and_fill_np_array(start_or_end_logits, dataset, max_len): # Optimizer # Split weights in two groups, one with weight decay and the other not. no_decay = ["bias", "LayerNorm.weight"] + no_decay_outputs = ["bias", "LayerNorm.weight", "qa_outputs"] optimizer_grouped_parameters = [ { "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)], @@ -876,10 +886,11 @@ def create_and_fill_np_array(start_or_end_logits, dataset, max_len): num_training_steps=args.max_train_steps, ) - if args.teacher_model_name_or_path != None: + if args.distill_loss_weight > 0: teacher_model, model, optimizer, train_dataloader, eval_dataloader, lr_scheduler = accelerator.prepare( teacher_model, model, optimizer, train_dataloader, eval_dataloader, lr_scheduler ) + teacher_model.eval() else: # Prepare everything with our `accelerator`. 
model, optimizer, train_dataloader, eval_dataloader, lr_scheduler = accelerator.prepare( @@ -949,36 +960,36 @@ def create_and_fill_np_array(start_or_end_logits, dataset, max_len): starting_epoch = resume_step // len(train_dataloader) resume_step -= starting_epoch * len(train_dataloader) - params = [(n, p) for (n, p) in model.named_parameters() if - "bias" not in n and "LayerNorm" not in n and "embeddings" not in n \ - and "layer.3.attention.output.dense.weight" not in n and "qa_outputs" not in n] - - params_keys = [n for (n, p) in params] - for key in params_keys: - print(key) - # Pruning preparation + pruning_configs=[ + { + "pruning_type": "snip_momentum", + "pruning_scope": "global", + "sparsity_decay_type": "exp" + } + ] + config = WeightPruningConfig( + pruning_configs, + target_sparsity=args.target_sparsity, + excluded_op_names=["qa_outputs", "pooler", ".*embeddings*"], + pruning_op_types=["Linear"], + max_sparsity_ratio_per_op=0.98, + pruning_scope="global", + pattern=args.pruning_pattern, + pruning_frequency=1000 + ) + pruner = Pruning(config) num_iterations = len(train_dataset) / total_batch_size - total_iterations = num_iterations * (args.num_train_epochs \ - - args.warm_epochs - args.cooldown_epochs) - args.num_warmup_steps - completed_pruned_cnt = 0 - total_cnt = 0 - for n, param in params: - total_cnt += param.numel() - print(f"The total param quantity is {total_cnt}") - - if args.teacher_model_name_or_path != None: - teacher_model.eval() - - from pytorch_pruner.pruning import Pruning - pruner = Pruning(args.pruning_config) + total_iterations = num_iterations * (args.num_train_epochs - args.warm_epochs - args.cooldown_epochs) \ + - args.num_warmup_steps if args.do_prune: - pruner.update_items_for_all_pruners( \ - start_step = int(args.warm_epochs*num_iterations+args.num_warmup_steps), \ - end_step = int(total_iterations))##iterative + start = int(args.warm_epochs * num_iterations+args.num_warmup_steps) + end = int(total_iterations) + frequency = int((end - start + 1) / 4) if (args.pruning_frequency == -1) else args.pruning_frequency + pruner.update_config(start_step=start, end_step=end, pruning_frequency=frequency)##iterative else: total_step = num_iterations * args.num_train_epochs + 1 - pruner.update_items_for_all_pruners(start_step=total_step, end_step=total_step) + pruner.update_config(start_step=total_step, end_step=total_step) pruner.model = model pruner.on_train_begin() @@ -994,14 +1005,14 @@ def create_and_fill_np_array(start_or_end_logits, dataset, max_len): # We keep track of the loss at each epoch if args.with_tracking: total_loss += loss.detach().float() - if args.teacher_model_name_or_path != None: + if args.distill_loss_weight > 0: distill_loss_weight = args.distill_loss_weight with torch.no_grad(): teacher_outputs = teacher_model(**batch) loss = (distill_loss_weight) / 2 * get_loss_one_logit(outputs['start_logits'], - teacher_outputs['start_logits']) \ + teacher_outputs['start_logits']) \ + (distill_loss_weight) / 2 * get_loss_one_logit(outputs['end_logits'], - teacher_outputs['end_logits']) + teacher_outputs['end_logits']) loss = loss / args.gradient_accumulation_steps accelerator.backward(loss) @@ -1160,3 +1171,4 @@ def create_and_fill_np_array(start_or_end_logits, dataset, max_len): if __name__ == "__main__": main() + diff --git a/neural_compressor/pruner/README.md b/neural_compressor/pruner/README.md index f81b47bd0ed..fee44bfde70 100644 --- a/neural_compressor/pruner/README.md +++ b/neural_compressor/pruner/README.md @@ -61,8 +61,8 @@ Neural network pruning 
is a promising model compression technique that removes t Pruning patterns define the rules of pruned weights' arrangements in space. INC currently supports unstructured, N:M and NxM patterns. Please note that the N:M pattern is applied to input channels while the NxM pattern is applied to output channels. [Details](../../docs/source/pruning_details.md#pruning-patterns).
- -    Sparsity Pattern + +    Sparsity Pattern
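The pattern semantics above can be made concrete with a small, self-contained sketch. The snippet below builds an NxM block mask (blocks scored by their L1 norm) and an N:M mask (keep the N largest magnitudes in every group of M consecutive input-channel weights) for a single 2-D weight. It uses plain PyTorch, assumes the weight dimensions are divisible by the block sizes, and the 4x1 and 2:4 sizes are just example choices; it is an illustration, not the INC implementation.

```python
import torch

def nxm_mask(weight: torch.Tensor, n: int = 4, m: int = 1, sparsity: float = 0.5) -> torch.Tensor:
    """NxM block mask: score each NxM block by its L1 norm and zero out the weakest blocks."""
    out_ch, in_ch = weight.shape
    blocks = weight.abs().reshape(out_ch // n, n, in_ch // m, m).sum(dim=(1, 3))  # one score per block
    k = int(blocks.numel() * sparsity)                   # number of blocks to prune
    threshold = blocks.flatten().kthvalue(k).values      # k-th smallest block score
    block_mask = (blocks > threshold).float()            # 1 = keep block, 0 = prune block
    return block_mask.repeat_interleave(n, dim=0).repeat_interleave(m, dim=1)

def n_of_m_mask(weight: torch.Tensor, n: int = 2, m: int = 4) -> torch.Tensor:
    """N:M mask: in every group of M consecutive input-channel weights, keep the N largest magnitudes."""
    groups = weight.abs().reshape(-1, m)                 # group along the input dimension
    idx = groups.topk(n, dim=1).indices
    mask = torch.zeros_like(groups).scatter_(1, idx, 1.0)
    return mask.reshape(weight.shape)

w = torch.randn(8, 16)
print(nxm_mask(w).mean(), n_of_m_mask(w).mean())   # both roughly 0.5 of the weights kept
```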
@@ -70,11 +70,11 @@ Pruning patterns defines the rules of pruned weights' arrangements in space. INC -Pruning Criteria determines how should the weights of a neural network be scored and pruned. In the image below, pruning scores are represented by neurons' color and those with the lowest scores are pruned. The magnitude and gradient are widely used to score the weights. Currently, INC supports **magnitude**, **gradient**, **snip** and **snip_momentum** criteria. [Details](../../docs/source/pruning_details.md#pruning-criteria). +Pruning Criteria determine how the weights of a neural network should be scored and pruned. In the image below, pruning scores are represented by neurons' color and those with the lowest scores are pruned. The magnitude and gradient are widely used to score the weights. Currently, INC supports **magnitude**, **gradient**, **snip** and **snip_momentum** criteria; the pruning criterion is defined along with the pruning type in INC configurations. [Details](../../docs/source/pruning_details.md#pruning-criteria).
-    Pruning criteria +    Pruning criteria
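As a rough illustration of how these criteria differ, the sketch below computes the textbook per-weight scores: magnitude |w|, gradient |∂L/∂w|, snip |w·∂L/∂w|, and a running-average (momentum) variant of the snip score. These are the commonly cited formulas, not necessarily INC's exact implementation.

```python
import torch

def importance_scores(weight, grad, snip_momentum_buf, beta=0.9):
    """Per-weight scores for the four criteria named above (textbook formulas, not INC internals)."""
    magnitude = weight.abs()                     # |w|
    gradient = grad.abs()                        # |dL/dw|
    snip = (weight * grad).abs()                 # |w * dL/dw|, first-order saliency
    snip_momentum_buf.mul_(beta).add_(snip, alpha=1 - beta)   # running average of the snip score across steps
    return magnitude, gradient, snip, snip_momentum_buf

# usage sketch: scores are computed after a backward pass has filled w.grad
w = torch.randn(8, 16, requires_grad=True)
(w ** 2).sum().backward()
buf = torch.zeros(8, 16)
mag, grad_score, snip, snip_mom = importance_scores(w.detach(), w.grad, buf)
```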
@@ -85,8 +85,8 @@ Pruning Criteria determines how should the weights of a neural network be scored Pruning schedule defines the way the model reaches the target sparsity (the ratio of pruned weights). Both **one-shot** and **iterative** pruning schedules are supported. [Details](../../docs/source/pruning_details.md#pruning-schedule).
- -    Pruning schedule + +    Pruning schedule
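The difference between the two schedules fits in a few lines: one-shot jumps straight to the target sparsity, while iterative pruning ramps the sparsity up between a start and an end step. The cubic ramp below mirrors the widely used polynomial-decay schedule and is only an illustration; INC's own decay options (e.g. the `sparsity_decay_type` seen elsewhere in this patch) may follow different curves.

```python
def scheduled_sparsity(step, start_step, end_step, target_sparsity, one_shot=False):
    """Sparsity ratio to enforce at a given training step (illustrative schedule, not INC's exact curve)."""
    if step < start_step:
        return 0.0
    if one_shot or step >= end_step:
        return target_sparsity                  # one-shot: jump straight to the target
    progress = (step - start_step) / (end_step - start_step)
    return target_sparsity * (1.0 - (1.0 - progress) ** 3)   # iterative: prune fast early, slow near the end

# e.g. with start_step=0, end_step=1000, target_sparsity=0.8:
#   step 250 -> ~0.46, step 500 -> 0.70, step 1000 -> 0.80
```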
@@ -107,8 +107,8 @@ Regularization is a technique that discourages learning a more complex model and [Details](../../docs/source/pruning_details.md#regularization).
- -    Regularization + +    Regularization
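A sparsity-oriented regularizer is typically just an extra penalty term added to the task loss. The sketch below shows a group-Lasso-style penalty over blocks of consecutive output channels, in plain PyTorch; the block size of 4 and the coefficient are arbitrary example values, and this is not the regularizer INC applies internally.

```python
import torch

def group_lasso_penalty(model, alpha=1e-4, block=4):
    """Sum of L2 norms over blocks of consecutive output channels; drives whole blocks toward zero."""
    penalty = 0.0
    for name, param in model.named_parameters():
        if param.dim() != 2 or param.shape[0] % block != 0:   # only 2-D weights with a full block count
            continue
        out_ch, in_ch = param.shape
        groups = param.reshape(out_ch // block, block, in_ch)
        penalty = penalty + groups.norm(p=2, dim=1).sum()
    return alpha * penalty

# in the training loop, added on top of the task loss:
# loss = task_loss + group_lasso_penalty(model)
```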
@@ -122,7 +122,8 @@ Users can pass the customized training/evaluation functions to `Pruning` in vari -The following section is an example of how to use hooks in user pass-in training function to perform BERT training. Our pruning API supports multiple pruner objects in a single Pruning object, which means we can apply different pruning configurations for different layers in a model. Since these pruning configurations share the same parameter names, we introduce a global-local configuration structure to initialize a Pruning object. First, we set up a dict-like local_config, which refers to some unique configurations for specific pruners. Afterwards, we pass this local_config dict and common configurations for all pruners (known as "global setting") to Pruning's initialization function. Below is code example for how to utilize our global-local configuration method to initialize a Pruning object. +The following section exemplifies how to use hooks in a user pass-in training function to perform model pruning. Through the pruning API, multiple pruner objects are supported within a single Pruning object, enabling layer-specific configurations; a default (global) setting fills in whatever a local configuration omits. + @@ -130,8 +131,8 @@ The following section is an example of how to use hooks in user pass-in training from neural_compressor.pruning import Pruning, WeightPruningConfig config = WeightPruningConfig( - local_configs, # An example of local_configs is shown below. - target_sparsity=0.8, start_step=1, end_step=10, pruning_frequency=1 + pruning_configs, # An example of pruning_configs is shown below. + target_sparsity=0.8, start_step=1, end_step=10, pruning_frequency=1 # Default pruning setting. ) prune = Pruning(config) # Pruning constructor. prune.model = model # Set model object to prune. @@ -155,7 +156,7 @@ for epoch in range(num_train_epochs): ``` ```python -local_configs = [ +pruning_configs = [ { 'target_sparsity': 0.9, # Target sparsity ratio of modules. 'pruning_type': "snip_momentum", # Default pruning type. @@ -173,9 +174,9 @@ local_configs = [ }, { "op_names": ['layer3.*'], # A list of modules that would be pruned. - 'target_sparsity': 0.7, # Target sparsity ratio of modules. "pruning_type": "snip_momentum_progressive", # Pruning type for the listed ops. - } + # 'target_sparsity' is omitted here on purpose. + } # For layer3, the missing 'target_sparsity' is filled in by the default setting (i.e. 0.8). ] ```
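Putting the global-local pieces together, here is a condensed sketch that sticks to the calls visible in this patch (`WeightPruningConfig`, `Pruning`, `update_config`, the `model` assignment and `on_train_begin`); the `TinyNet` placeholder, layer names, and step numbers are illustrative only.

```python
import torch.nn as nn
from neural_compressor.pruning import Pruning, WeightPruningConfig

class TinyNet(nn.Module):            # placeholder model so op_names like "layer1.*" have something to match
    def __init__(self):
        super().__init__()
        self.layer1 = nn.Linear(16, 16)
        self.layer3 = nn.Linear(16, 4)
    def forward(self, x):
        return self.layer3(self.layer1(x))

model = TinyNet()

pruning_configs = [
    {"op_names": ["layer1.*"], "pruning_type": "snip_momentum", "target_sparsity": 0.9},
    {"op_names": ["layer3.*"], "pruning_type": "snip_momentum_progressive"},  # inherits target_sparsity=0.8
]
config = WeightPruningConfig(
    pruning_configs,                 # local, per-layer settings
    target_sparsity=0.8,             # global defaults complement whatever a local dict omits
    pattern="4x1",
    pruning_op_types=["Linear"],
    start_step=1, end_step=10, pruning_frequency=1,
)
pruner = Pruning(config)
pruner.update_config(start_step=100, end_step=1000, pruning_frequency=100)  # steps can be recomputed later
pruner.model = model                 # set the model object to prune
pruner.on_train_begin()              # prepare masks before entering the training loop shown in the README snippet
```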