diff --git a/.azure-pipelines/scripts/codeScan/pyspelling/inc_dict.txt b/.azure-pipelines/scripts/codeScan/pyspelling/inc_dict.txt
index cecf227cbaf..83a0d9c1521 100644
--- a/.azure-pipelines/scripts/codeScan/pyspelling/inc_dict.txt
+++ b/.azure-pipelines/scripts/codeScan/pyspelling/inc_dict.txt
@@ -150,6 +150,7 @@ berts
bertsquad
BertTokenizer
bfloat
+blockwise
BFP
BGR
Bianchi
@@ -327,6 +328,7 @@ convolutional
Convolutional
ConvPerStage
ConvReLU
+cooldown
copt
coreml
CoreML
@@ -741,6 +743,7 @@ horovodrun
hostfile
Hounsfield
howpublished
+hyp
HqEgzS
href
html
@@ -1179,6 +1182,7 @@ ngatang
NGPUS
ngram
NHWC
+ni
NIC
nifti
niftis
@@ -1240,8 +1244,11 @@ nvidia
NVIDIA
NVIDIA's
nvme
+nw
Nx
+NxM
nyu
+oc
ok
ol
Omer
@@ -1251,6 +1258,7 @@ oneapi
oneAPI
onednn
oneDNN
+oneshot
onlinedocs
onnx
ONNX
@@ -1885,6 +1893,7 @@ UI
UID
uint
uk
+ultralytics
un
uncomment
uncompress
@@ -1895,6 +1904,7 @@ unidecode
uniq
unittest
unref
+unscale
unsqueeze
unstack
upenn
@@ -2121,6 +2131,7 @@ tensorrt
hardwares
BenchmarkConf
PruningConf
+Pruning's
DistillationConf
grey
ModelZoo
@@ -2442,3 +2453,8 @@ QuantizationAwareTrainingConfig
Startup
doesn
startup
+Ajanthan
+WeightPruningConfig
+Namhoon
+Thalaiyasingam
+Torr
diff --git a/docs/source/_static/imgs/pruning/Pruning_patterns.PNG b/docs/source/_static/imgs/pruning/Pruning_patterns.PNG
new file mode 100644
index 00000000000..d453622ed5a
Binary files /dev/null and b/docs/source/_static/imgs/pruning/Pruning_patterns.PNG differ
diff --git a/docs/source/_static/imgs/pruning/pruning.PNG b/docs/source/_static/imgs/pruning/pruning.PNG
new file mode 100644
index 00000000000..0c6c53295ab
Binary files /dev/null and b/docs/source/_static/imgs/pruning/pruning.PNG differ
diff --git a/docs/source/_static/imgs/pruning/pruning_criteria.PNG b/docs/source/_static/imgs/pruning/pruning_criteria.PNG
new file mode 100644
index 00000000000..a91fcbabb5f
Binary files /dev/null and b/docs/source/_static/imgs/pruning/pruning_criteria.PNG differ
diff --git a/docs/source/_static/imgs/pruning/pruning_patterns.png b/docs/source/_static/imgs/pruning/pruning_patterns.png
index 872c6cf8b35..d453622ed5a 100644
Binary files a/docs/source/_static/imgs/pruning/pruning_patterns.png and b/docs/source/_static/imgs/pruning/pruning_patterns.png differ
diff --git a/docs/source/_static/imgs/pruning/pruning_schedule.PNG b/docs/source/_static/imgs/pruning/pruning_schedule.PNG
new file mode 100644
index 00000000000..abd07603d5d
Binary files /dev/null and b/docs/source/_static/imgs/pruning/pruning_schedule.PNG differ
diff --git a/docs/source/_static/imgs/pruning/regularization.PNG b/docs/source/_static/imgs/pruning/regularization.PNG
new file mode 100644
index 00000000000..2feb6ae276e
Binary files /dev/null and b/docs/source/_static/imgs/pruning/regularization.PNG differ
diff --git a/docs/source/pruning.md b/docs/source/pruning.md
index 89e6567737e..fe951fc98e4 100644
--- a/docs/source/pruning.md
+++ b/docs/source/pruning.md
@@ -32,7 +32,7 @@ Neural network pruning (briefly known as pruning or sparsity) is one of the most
Pruning patterns define the rules of pruned weights' arrangements in space.
-
+
diff --git a/docs/source/pruning_details.md b/docs/source/pruning_details.md
new file mode 100644
index 00000000000..48e1df7398f
--- /dev/null
+++ b/docs/source/pruning_details.md
@@ -0,0 +1,315 @@
+Pruning details
+===============
+
+
+
+
+
+
+1. [Introduction](#introduction)
+
+
+
+
+
+
+>>>[Neural Network Pruning](#neural-network-pruning)
+
+
+
+
+
+
+>>>[Pruning Patterns](#pruning-patterns)
+
+
+
+
+
+
+>>>[Pruning Criteria](#pruning-criteria)
+
+
+
+
+
+
+>>>[Pruning Schedule](#pruning-schedule)
+
+
+
+
+
+
+>>>[Pruning Type](#pruning-type)
+
+
+
+
+
+
+>>>[Regularization](#regularization)
+
+
+
+
+
+
+
+
+2. [Pruning examples](#pruning-examples)
+
+
+
+
+
+
+3. [Reference](#reference)
+
+
+
+
+
+
+## Introduction
+
+
+
+
+
+
+### Neural Network Pruning
+
+Neural network pruning is a promising model compression technique that removes the least important parameters in the network and achieves compact architectures with minimal accuracy drop and maximal inference acceleration. As state-of-the-art model sizes have grown at an unprecedented speed, pruning has become increasingly crucial for reducing the computational and memory footprint that huge neural networks require.
+
+
+
+
+
+
+
+### Pruning Patterns
+
+
+
+
+
+- Unstructured Pruning
+
+
+
+
+
+Unstructured pruning means pruning the least salient connections in the model. The nonzero patterns are irregular and could be anywhere in the matrix.
+
+
+
+
+
+- Structured Pruning
+
+
+
+
+
+Structured pruning means pruning parameters in groups and deleting entire blocks, filters, or channels according to some pruning criteria. In general, structured pruning leads to lower accuracy than unstructured pruning due to its restrictive structure, but it can significantly accelerate the model execution as it fits better with hardware designs.
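+
+For illustration, below is a minimal sketch of structured block pruning driven by magnitude scores. It assumes a 2-D weight whose dimensions divide evenly by the block shape; `block_magnitude_mask` is a hypothetical helper, not this library's API.
+
+```python
+import torch
+
+
+def block_magnitude_mask(weight, block=(4, 1), sparsity=0.5):
+    """Prune whole block-shaped groups with the lowest summed magnitude."""
+    rows, cols = weight.shape
+    n, m = block
+    # Group the 2-D weight into (rows/n, n, cols/m, m) blocks and score each block.
+    scores = weight.abs().reshape(rows // n, n, cols // m, m).sum(dim=(1, 3))
+    k = max(int(sparsity * scores.numel()), 1)
+    threshold = torch.kthvalue(scores.flatten(), k).values
+    block_mask = (scores > threshold).float()
+    # Expand the block-level mask back to the weight's shape.
+    return block_mask.repeat_interleave(n, dim=0).repeat_interleave(m, dim=1)
+
+
+mask = block_magnitude_mask(torch.randn(8, 8), block=(4, 1), sparsity=0.5)
+```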
+
+
+
+
+
+
+
+
+
+### Pruning Criteria
+
+
+
+
+
+
+Pruning criteria determine how the weights of a neural network should be scored and pruned. The magnitude and gradient are widely used to score the weights.
+
+
+
+
+
+- Magnitude
+
+
+
+
+
+  The algorithm prunes the weights with the lowest absolute values in each layer, up to the given sparsity target.
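+
+  As an element-wise sketch of this idea (a hypothetical helper, assuming a nonzero sparsity target):
+
+```python
+import torch
+
+
+def magnitude_mask(weight, sparsity):
+    # Mask weights whose absolute value falls at or below the k-th smallest magnitude.
+    scores = torch.flatten(weight.abs())
+    k = max(int(sparsity * scores.numel()), 1)
+    threshold = torch.kthvalue(scores, k).values
+    return (weight.abs() > threshold).float()
+```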
+
+
+
+
+
+- Gradient
+
+
+
+
+  The algorithm prunes the weights with the lowest gradient values in each layer, up to the given sparsity target.
+
+
+
+
+- SNIP
+
+
+
+
+
+  The algorithm prunes the dense model at its initialization by analyzing each weight's effect on the loss function when it is masked. Please refer to the original [paper](https://arxiv.org/abs/1810.02340) for details.
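+
+  In the notation used for the momentum variant below, where $W$ and $G$ are the model's weights and gradients, the per-weight SNIP saliency can be written as
+
+  $$Score = |W \times G|$$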
+
+
+
+
+
+- SNIP with momentum
+
+
+
+
+
+  The algorithm improves on the original SNIP algorithm by introducing a score map for the weights that is updated with momentum.\
+
+  In the following formula, $n$ is the pruning step and $W$ and $G$ are the model's weights and gradients, respectively.
+
+  $$Score_{n} = 0.9 \times Score_{n-1} + 1.0 \times |W_{n} \times G_{n}|$$
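+
+  A minimal sketch of this update, mirroring the snip_momentum criterion implemented in this repository (to be called after `loss.backward()`, once gradients exist):
+
+```python
+import torch
+
+ALPHA, BETA = 0.9, 1.0  # preserved portion vs. newly added portion of the score
+
+
+def update_snip_momentum_scores(scores, modules):
+    """Apply one momentum update to the SNIP score maps."""
+    with torch.no_grad():
+        for name, module in modules.items():
+            w = module.weight
+            scores[name] = ALPHA * scores[name] + BETA * torch.abs(w * w.grad)
+    return scores
+```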
+
+
+
+
+
+
+### Pruning Schedule
+
+
+
+
+
+Pruning schedule defines the way the model reaches the target sparsity (the ratio of pruned weights).
+
+
+
+
+
+- One-shot Pruning
+
+
+
+
+
+  One-shot pruning means the model is pruned to its target sparsity in a single step. This pruning method often operates at the model's initialization step. It can easily cause an accuracy drop, but it saves much training time.
+
+
+
+
+
+
+- Iterative Pruning
+
+
+
+
+
+  Iterative pruning means the model is gradually pruned to its target sparsity during a training process. The pruning process contains several pruning steps, and each step raises the model's sparsity to a higher value. In the final pruning step, the model reaches the target sparsity and the pruning process ends.
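+
+  As an illustration, the cubic schedule of Zhu & Gupta (2017) is one common way to interpolate sparsity across pruning steps; the sketch below shows that schedule, not necessarily the exact curve this library uses.
+
+```python
+def scheduled_sparsity(step, start_step, end_step, target_sparsity, initial_sparsity=0.0):
+    """Cubic interpolation from initial to target sparsity between start_step and end_step."""
+    if step <= start_step:
+        return initial_sparsity
+    if step >= end_step:
+        return target_sparsity
+    progress = (step - start_step) / (end_step - start_step)
+    return target_sparsity + (initial_sparsity - target_sparsity) * (1 - progress) ** 3
+```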
+
+
+
+
+
+
+
+### Pruning Type
+
+
+
+
+
+
+- Pattern_lock Pruning
+
+
+
+
+
+The pattern_lock pruning type uses masks with a fixed pattern during the pruning process.
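+
+A pattern-lock mask can be derived directly from the zero positions of an already-sparse weight; a minimal sketch of that logic:
+
+```python
+import torch
+
+
+def pattern_lock_mask(weight):
+    # Freeze the existing sparsity pattern: zeros stay pruned, nonzeros stay trainable.
+    mask = torch.ones_like(weight)
+    mask[weight == 0] = 0.0
+    return mask
+```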
+
+
+
+
+
+- Progressive Pruning
+
+
+
+
+
+Progressive pruning aims at smoothing structured pruning by automatically interpolating a group of interval masks during the pruning process. In this method, a sequence of masks is generated to enable a more flexible pruning process, and these masks gradually change to fit the target pruning structure.
+
+
+
+
+
+### Regularization
+
+
+
+
+
+Regularization is a technique that discourages learning a more complex model and therefore performs variable selection.
+
+
+
+
+
+- Group Lasso
+
+
+
+
+
+  The Group-lasso algorithm is used to prune entire rows, columns, or blocks of parameters, resulting in a smaller dense network.
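+
+  For reference, group lasso augments the task loss with the sum of the groups' $l_{2}$ norms, which pushes entire groups of weights toward zero simultaneously:
+
+  $$Loss = Loss_{task} + \lambda \sum_{g} \Vert W_{g} \Vert_{2}$$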
+
+
+
+
+
+
+
+## Pruning Examples
+
+
+
+
+We validate the pruning technique on typical models across various domains (including CV and NLP).
+
+
+
+
+## Reference
+
+
+
+
+[1] Namhoon Lee, Thalaiyasingam Ajanthan, and Philip Torr. SNIP: Single-shot Network Pruning based on Connection Sensitivity. In International Conference on Learning Representations, 2019.
+
+
+
+
+
+
+
+
+
+
diff --git a/neural_compressor/conf/config.py b/neural_compressor/conf/config.py
index 64a100dc113..de0105e7f84 100644
--- a/neural_compressor/conf/config.py
+++ b/neural_compressor/conf/config.py
@@ -20,7 +20,7 @@
from ..adaptor import FRAMEWORKS
from ..strategy import STRATEGIES
from ..objective import OBJECTIVES
-from ..pruners import PRUNERS
+from ..pruner.pruner_legacy import PRUNERS
from ..utils import logger
from ..version import __version__
import re
diff --git a/neural_compressor/experimental/pruning.py b/neural_compressor/experimental/pruning.py
index 9f4e5bb128a..727b437f644 100644
--- a/neural_compressor/experimental/pruning.py
+++ b/neural_compressor/experimental/pruning.py
@@ -17,7 +17,7 @@
# limitations under the License.
from .component import Component
-from ..pruners import PRUNERS
+from ..pruner.pruner_legacy import PRUNERS
from ..utils import logger
from ..utils.utility import GLOBAL_STATE, MODE
from ..utils.create_obj_from_config import create_dataloader, create_train_func, create_eval_func
diff --git a/neural_compressor/pruner/README.md b/neural_compressor/pruner/README.md
new file mode 100644
index 00000000000..f81b47bd0ed
--- /dev/null
+++ b/neural_compressor/pruner/README.md
@@ -0,0 +1,210 @@
+Pruning
+============
+
+
+
+1. [Introduction](#introduction)
+
+
+
+ - [Neural Network Pruning](#neural-network-pruning)
+
+
+
+ - [Pruning Patterns](#pruning-patterns)
+
+
+
+ - [Pruning Criteria](#pruning-criteria)
+
+
+
+    - [Pruning Schedules](#pruning-schedules)
+
+
+
+    - [Pruning Types](#pruning-types)
+
+
+
+ - [Regularization](#regularization)
+
+
+
+2. [Get Started With Pruning API](#get-started-with-pruning-api)
+
+
+
+3. [Examples](#examples)
+
+
+
+
+## Introduction
+
+
+
+### Neural Network Pruning
+Neural network pruning is a promising model compression technique that removes the least important parameters/neurons in the network and achieves compact architectures with minimal accuracy drop and maximal inference acceleration. As state-of-the-art model sizes have grown at an unprecedented speed, pruning has become increasingly crucial for reducing the computational and memory footprint that huge neural networks require.
+
+
+
+
+### Pruning Patterns
+
+
+
+Pruning patterns define the rules of pruned weights' arrangements in space. INC currently supports unstructured, N:M and NxM patterns. Please note that the N:M pattern is applied to input channels while the NxM pattern is applied to output channels. [Details](../../docs/source/pruning_details.md#pruning-patterns).
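+
+For illustration, the pattern is selected through the `pattern` field of the pruning configuration. A minimal sketch, assuming `pattern` is accepted as a global setting in the same way as the keyword arguments shown in the next section:
+
+```python
+from neural_compressor.pruning import WeightPruningConfig
+
+# N:M pattern: prune the 2 least salient of every 4 weights along input channels.
+config_nm = WeightPruningConfig(pattern="2:4")
+# NxM pattern: prune whole 4x1 blocks along output channels.
+config_nxm = WeightPruningConfig(pattern="4x1")
+```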
+
+
+
+### Pruning Criteria
+
+
+
+Pruning criteria determine how the weights of a neural network should be scored and pruned. In the image below, pruning scores are represented by the neurons' color, and those with the lowest scores are pruned. The magnitude and gradient are widely used to score the weights. Currently, INC supports **magnitude**, **gradient**, **snip** and **snip_momentum** criteria. [Details](../../docs/source/pruning_details.md#pruning-criteria).
+
+
+
+### Pruning Schedules
+
+
+
+Pruning schedule defines the way the model reaches the target sparsity (the ratio of pruned weights). Both **one-shot** and **iterative** pruning schedules are supported. [Details](../../docs/source/pruning_details.md#pruning-schedule).
+
+
+
+
+### Pruning Types
+
+
+
+Pruning type defines how the masks are generated and applied to a neural network. Both **pattern_lock** and **progressive** types are supported by INC. [Details](../../docs/source/pruning_details.md#pruning-type).
+
+
+
+### Regularization
+
+
+
+Regularization is a technique that discourages learning a more complex model and therefore performs variable selection. In the image below, some weights are pushed to be as small as possible and the connections are thus sparsified. The **Group-lasso** method is used in INC.
+[Details](../../docs/source/pruning_details.md#regularization).
+
+
+
+
+## Get Started with Pruning API
+
+
+
+The Neural Compressor `Pruning` API is defined under `neural_compressor.pruning` and takes a user-defined config object as input.
+Users can pass the customized training/evaluation functions to `Pruning` in various scenarios.
+
+
+
+The following section is an example of how to use hooks in a user-provided training function to perform BERT training. Our pruning API supports multiple pruner objects in a single Pruning object, which means we can apply different pruning configurations to different layers in a model. Since these pruning configurations share the same parameter names, we introduce a global-local configuration structure to initialize a Pruning object. First, we set up `local_configs`, a list of dicts holding the configurations unique to specific pruners. Afterwards, we pass these local configurations and the common configurations for all pruners (known as the "global setting") to Pruning's initialization function. Below is a code example of how to utilize our global-local configuration method to initialize a Pruning object.
+
+
+
+```python
+from neural_compressor.pruning import Pruning, WeightPruningConfig
+
+config = WeightPruningConfig(
+ local_configs, # An example of local_configs is shown below.
+ target_sparsity=0.8, start_step=1, end_step=10, pruning_frequency=1
+)
+prune = Pruning(config) # Pruning constructor.
+prune.model = model # Set model object to prune.
+prune.on_train_begin() # Execute on_train_begin hook before training.
+for epoch in range(num_train_epochs):
+ model.train()
+ prune.on_epoch_begin(epoch) # Execute on_epoch_begin hook before each epoch.
+ for step, batch in enumerate(train_dataloader):
+ prune.on_step_begin(step) # Execute on_step_begin hook before each step.
+ outputs = model(**batch)
+ loss = outputs.loss
+ loss.backward()
+        prune.on_before_optimizer_step() # Execute on_before_optimizer_step() hook before optimization.
+        optimizer.step()
+        prune.on_after_optimizer_step() # Execute on_after_optimizer_step() hook after optimization.
+ scheduler.step() # Update learning rate schedule
+ model.zero_grad()
+ prune.on_step_end() # Execute on_step_end hook after each step.
+ prune.on_epoch_end() # Execute on_epoch_end hook after each epoch.
+...
+```
+
+```python
+local_configs = [
+ {
+ 'target_sparsity': 0.9, # Target sparsity ratio of modules.
+ 'pruning_type': "snip_momentum", # Default pruning type.
+ 'pattern': "4x1", # Default pruning pattern.
+ 'op_names': ['layer1.*'], # A list of modules that would be pruned.
+ 'excluded_op_names': ['layer3.*'], # A list of modules that would not be pruned.
+ 'start_step': 0, # Step at which to begin pruning.
+ 'end_step': 10, # Step at which to end pruning.
+ 'pruning_scope': "global", # Default pruning scope.
+ 'pruning_frequency': 1, # Frequency of applying pruning.
+ 'min_sparsity_ratio_per_op': 0.0, # Minimum sparsity ratio of each module.
+ 'max_sparsity_ratio_per_op': 0.98, # Maximum sparsity ratio of each module.
+ 'sparsity_decay_type': "exp", # Function applied to control pruning rate.
+ 'pruning_op_types': ['Conv', 'Linear'], # Types of op that would be pruned.
+ },
+ {
+ "op_names": ['layer3.*'], # A list of modules that would be pruned.
+ 'target_sparsity': 0.7, # Target sparsity ratio of modules.
+ "pruning_type": "snip_momentum_progressive", # Pruning type for the listed ops.
+ }
+ ]
+```
+
+ In the case mentioned above, the pruning process can be driven by the pre-defined hooks in Neural Compressor. Users need to place these hooks inside the training function. The pre-defined Neural Compressor hooks are listed below.
+
+
+
+```
+on_train_begin() : Execute at the beginning of training phase.
+on_epoch_begin(epoch) : Execute at the beginning of each epoch.
+on_step_begin(batch) : Execute at the beginning of each batch.
+on_step_end() : Execute at the end of each batch.
+on_epoch_end() : Execute at the end of each epoch.
+on_before_optimizer_step() : Execute before optimization step.
+on_after_optimizer_step() : Execute after optimization step.
+```
+
+
+
+
+
+
+## Examples
+
+
+
+We validate the pruning technique on typical models across various domains (including CV and NLP) and the examples are listed in [Pruning Examples](../../docs/source/pruning_details.md#pruning-examples). A complete overview of validated examples, including quantization, pruning and distillation results, can be found in [INC Validated examples](../../docs/source/validated_model_list.md#validated-pruning-examples).
+
+
+Please refer to the pruning examples ([PyTorch](../../examples/README.md#Pruning-1)) for more information.
+
+
diff --git a/neural_compressor/pruner/__init__.py b/neural_compressor/pruner/__init__.py
new file mode 100644
index 00000000000..d33331cae08
--- /dev/null
+++ b/neural_compressor/pruner/__init__.py
@@ -0,0 +1,17 @@
+"""prune init."""
+# !/usr/bin/env python
+# -*- coding: utf-8 -*-
+#
+# Copyright (c) 2022 Intel Corporation
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
\ No newline at end of file
diff --git a/neural_compressor/pruner/criteria.py b/neural_compressor/pruner/criteria.py
new file mode 100644
index 00000000000..0397fca4c82
--- /dev/null
+++ b/neural_compressor/pruner/criteria.py
@@ -0,0 +1,188 @@
+"""pruning criterion."""
+# !/usr/bin/env python
+# -*- coding: utf-8 -*-
+#
+# Copyright (c) 2022 Intel Corporation
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from neural_compressor.utils.utility import LazyImport
+torch = LazyImport('torch')
+
+
+CRITERIAS = {}
+
+
+def register_criterion(name):
+ """Register a criterion to the registry."""
+
+ def register(criterion):
+ CRITERIAS[name] = criterion
+ return criterion
+
+ return register
+
+
+def get_criterion(config, modules):
+ """Get registered criterion class."""
+ name = config["criterion_type"]
+ if name not in CRITERIAS.keys():
+ assert False, f"criteria does not support {name}, currently only support {CRITERIAS.keys()}"
+ return CRITERIAS[name](modules, config)
+
+
+class PruningCriterion:
+ """Pruning base criterion.
+
+ Args:
+ config: A config dict object that includes information about pruner and pruning criterion.
+ modules: A dict {"module_name": Tensor} that stores the pruning modules' weights.
+
+ Attributes:
+ scores: A dict {"module_name": Tensor} that stores the scores of pruning modules.
+ """
+
+ def __init__(self, modules, config):
+        """Initialize a pruning criterion."""
+ self.scores = {}
+ self.modules = modules
+ self.config = config
+
+ def on_step_begin(self):
+ """Calculate and store the pruning scores of pruning modules at the beginning of a step."""
+ pass
+
+ def on_after_optimizer_step(self):
+ """Calculate and store the pruning scores of pruning modules after the optimizer step."""
+ pass
+
+
+@register_criterion('magnitude')
+class MagnitudeCriterion(PruningCriterion):
+ """Pruning criterion.
+
+ The magnitude criterion_class is derived from PruningCriterion.
+ The magnitude value is used to score and determine if a weight is to be pruned.
+
+ Args:
+ config: A config dict object that includes information about pruner and pruning criterion.
+ modules: A dict {"module_name": Tensor} that stores the pruning modules' weights.
+
+ Attributes:
+ scores: A dict {"module_name": Tensor} that stores the scores of pruning modules.
+ """
+
+ def __init__(self, modules, config):
+        """Initialize a magnitude pruning criterion."""
+ super(MagnitudeCriterion, self).__init__(modules, config)
+
+ def on_step_begin(self):
+        """Calculate and store the pruning scores based on the magnitude criterion."""
+ with torch.no_grad():
+ for key in self.modules.keys():
+ p = self.modules[key].weight.data
+ self.scores[key] = p
+
+
+@register_criterion('gradient')
+class GradientCriterion(PruningCriterion):
+ """Pruning criterion.
+
+ The gradient criterion_class is derived from PruningCriterion.
+ The absolute value of gradient is used to score and determine if a weight is to be pruned.
+
+ Args:
+ config: A config dict object that includes information about pruner and pruning criterion.
+ modules: A dict {"module_name": Tensor} that stores the pruning modules' weights.
+
+ Attributes:
+ scores: A dict {"module_name": Tensor} that stores the scores of pruning modules.
+ """
+
+ def __init__(self, modules, config):
+        """Initialize a gradient pruning criterion."""
+ super(GradientCriterion, self).__init__(modules, config)
+
+ def on_after_optimizer_step(self):
+ """Calculate and store the pruning scores based on gradient criterion."""
+ with torch.no_grad():
+ for key in self.modules.keys():
+ p = self.modules[key].weight
+ self.scores[key] = torch.abs(p.grad)
+
+
+@register_criterion('snip')
+class SnipCriterion(PruningCriterion):
+ """Pruning criterion.
+
+ The snip criterion_class is derived from PruningCriterion.
+ The product of magnitude and gradient is used to score and determine if a weight is to be pruned.
+ Please refer to SNIP: Single-shot Network Pruning based on Connection Sensitivity.
+ (https://arxiv.org/abs/1810.02340)
+
+ Args:
+ config: A config dict object that includes information about pruner and pruning criterion.
+ modules: A dict {"module_name": Tensor} that stores the pruning modules' weights.
+
+ Attributes:
+ scores: A dict {"module_name": Tensor} that stores the scores of pruning modules.
+ """
+
+ def __init__(self, modules, config):
+        """Initialize a snip pruning criterion."""
+ super(SnipCriterion, self).__init__(modules, config)
+ assert self.config.end_step > 0, "gradient based criterion does not work on step 0"
+
+ def on_after_optimizer_step(self):
+ """Calculate and store the pruning scores based on snip criterion."""
+ with torch.no_grad():
+ for key in self.modules.keys():
+ p = self.modules[key].weight
+ self.scores[key] = torch.abs(p * p.grad)
+
+
+@register_criterion('snip_momentum')
+class SnipMomentumCriterion(PruningCriterion):
+ """Pruning criterion.
+
+ The snip_momentum criterion_class is derived from PruningCriterion.
+ A momentum mechanism is used to calculate snip score, which determines if a weight is to be pruned.
+
+ Args:
+ config: A config dict object that includes information about pruner and pruning criterion.
+ modules: A dict {"module_name": Tensor} that stores the pruning modules' weights.
+        alpha: A parameter that determines how much of the snip score is preserved from the last pruning step.
+ beta: A parameter that determines how much of the snip score is updated at the current step.
+
+ Attributes:
+ scores: A dict {"module_name": Tensor} that stores the scores of pruning modules.
+ """
+
+ def __init__(self, modules, config):
+        """Initialize a snip_momentum pruning criterion."""
+ super(SnipMomentumCriterion, self).__init__(modules, config)
+ assert self.config.end_step > 0, "gradient based criterion does not work on step 0"
+ for key in modules.keys():
+ p = modules[key].weight
+ self.scores[key] = torch.zeros(p.shape).to(p.device)
+
+ self.alpha = 0.9
+ self.beta = 1.0
+
+ def on_after_optimizer_step(self):
+ """Calculate and store the pruning scores based on snip_momentum criterion."""
+ with torch.no_grad():
+ for key in self.modules.keys():
+ p = self.modules[key].weight
+ self.scores[key] *= self.alpha
+ self.scores[key] += self.beta * torch.abs(p * p.grad)
diff --git a/neural_compressor/pruner/logger.py b/neural_compressor/pruner/logger.py
new file mode 100644
index 00000000000..f39f1198a65
--- /dev/null
+++ b/neural_compressor/pruner/logger.py
@@ -0,0 +1,23 @@
+"""logger module."""
+# !/usr/bin/env python
+# -*- coding: utf-8 -*-
+#
+# Copyright (c) 2022 Intel Corporation
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+try:
+ from neural_compressor.utils import logger
+except ImportError:
+ import logging
+ logger = logging.getLogger(__name__)
diff --git a/neural_compressor/pruner/patterns.py b/neural_compressor/pruner/patterns.py
new file mode 100644
index 00000000000..8ad1d1fb6f0
--- /dev/null
+++ b/neural_compressor/pruner/patterns.py
@@ -0,0 +1,1110 @@
+"""pruning patterns."""
+# !/usr/bin/env python
+# -*- coding: utf-8 -*-
+#
+# Copyright (c) 2022 Intel Corporation
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from neural_compressor.utils.utility import LazyImport
+torch = LazyImport('torch')
+from .logger import logger
+from collections import namedtuple
+
+PATTERNS = {}
+
+
+def register_pattern(name):
+ """Class decorator used to register a Pattern subclass to the registry.
+
+    Decorator function used before a Pattern subclass.
+ Make sure that this Pattern class can be registered in PATTERNS.
+
+ Args:
+ name: A string. Define the pattern type name which will be included in a pruning process.
+
+ Returns:
+ cls: The class of register.
+ """
+
+ def register(pattern):
+ """Register patterns."""
+ PATTERNS[name] = pattern
+ return pattern
+
+ return register
+
+
+def get_pattern(config, modules):
+ """Get registered pattern class.
+
+ Get a Pattern object from PATTERNS.
+
+ Args:
+ config: A config dict object. Contains the pattern information.
+ modules: torch neural network modules, which will be pruned with the pattern
+
+ Returns:
+ A Pattern object.
+
+ Raises:
+ AssertionError: Currently only support patterns which have been registered in PATTERNS.
+ """
+ name = config.pattern
+ name = name.split('_')[-1]
+ if "x" in name:
+ return PATTERNS["NxM"](config, modules)
+ if ":" in name:
+ return PATTERNS["N:M"](config, modules)
+ assert False, f"currently only support {PATTERNS.keys()}"
+
+
+SparsityInfo = namedtuple("SparsityInfo", ['zero_cnt', 'total_cnt', 'sparsity_ratio'])
+
+
+class BasePattern:
+ """Pruning Pattern.
+
+ It defines the basic pruning unit and how this unit will be pruned during pruning, e.g. 4x1, 2:4
+
+ Args:
+ config: A config dict object. Contains the pattern information.
+ modules: torch neural network modules, which will be pruned with the pattern
+
+ Attributes:
+ pattern: A config dict object. The pattern related part in args config.
+        is_global: A bool. Whether the pruning takes the global pruning option.
+            Global pruning means that all pruning layers are gathered to calculate the pruning criterion.
+            Local pruning, in contrast, means that each pruning layer calculates its criterion individually.
+        keep_mask_layers: A dict. The layers whose masks will not be updated.
+        invalid_layers: A list. The layers whose shapes don't fit the pattern.
+        modules: Torch neural network modules, which will be pruned with the pattern.
+        config: A config dict object. Contains all the information, including the pattern's.
+        max_sparsity_ratio_per_op: A float. The maximum sparsity that one layer could reach.
+        min_sparsity_ratio_per_op: A float. The minimum sparsity that one layer could reach.
+        target_sparsity: A float. The sparsity ratio the modules will reach after pruning.
+
+ """
+
+ def __init__(self, config, modules):
+ """Initialize the basic pruning unit of a pattern."""
+ self.pattern = config.pattern
+ self.is_global = config.pruning_scope == "global"
+ self.keep_mask_layers = {}
+ self.invalid_layers = []
+ self.modules = modules
+ self.config = config
+ self.max_sparsity_ratio_per_op = self.config['max_sparsity_ratio_per_op']
+ self.min_sparsity_ratio_per_op = self.config['min_sparsity_ratio_per_op']
+ self.target_sparsity_ratio = self.config['target_sparsity']
+ # Not using deterministic_algorithms for all examples
+ torch.use_deterministic_algorithms(False)
+
+ def reduce_tensor(self, data, dim):
+ """Reduce the data along the given dimension.
+
+ Args:
+ data: The input data
+ dim: The reduced axis
+
+ Returns:
+ The reduced tensor
+
+ """
+ name = self.config['criterion_reduce_type']
+ if name == "mean":
+ return torch.mean(data, dim=dim)
+ elif name == "sum":
+ return torch.sum(data, dim=dim)
+ elif name == "max":
+ return torch.max(data, dim=dim)[0]
+ else:
+ assert False, "currently only support mean, sum and max reduce type"
+
+ def get_masks(self, scores, target_sparsity_ratio, pre_masks):
+ """Generate the weight masks according to the weight score and the current target sparsity ratio.
+
+ Args:
+ scores: A dict{“layer_name”: Tensor}. Store the pruning scores of weights.
+ target_sparsity_ratio: A float. After pruning, the sparsity of the modules will reach this value.
+ pre_masks: A dict{"layer_name": Tensor}. The previous masks generated after the last pruning step.
+
+ Returns:
+ A dict with the identical size as pre_masks. Update the 0/1 values in it. 1 means keep, 0 means drop
+
+ """
+ if self.is_global:
+ return self.get_masks_global(scores, target_sparsity_ratio, pre_masks)
+ else:
+ return self.get_masks_local(scores, target_sparsity_ratio, pre_masks)
+
+ def get_masks_global(self, scores, target_sparsity_ratio, pre_masks):
+ """Generate the weight masks for global pruning, please refer to function get_masks for more information."""
+ raise NotImplementedError
+
+ def get_masks_local(self, scores, target_sparsity_ratio, pre_masks):
+ """Generate the weight masks for local pruning.
+
+ Args:
+ scores: A dict{“layer_name”: Tensor}. Store the pruning scores of weights.
+ target_sparsity_ratio: A float. After pruning, the sparsity of the modules will reach this value.
+ pre_masks: A dict{"layer_name": Tensor}. The previous masks generated after the last pruning step.
+
+ Returns:
+ A dict with the identical size as pre_masks. Update the 0/1 values in it. 1 means keep, 0 means drop
+
+ """
+ masks = {}
+ if isinstance(self, PatternNxM) and not isinstance(self.block_size, dict):
+ self.block_size = self.get_block_size_dict(pre_masks)
+ for key in scores.keys():
+ score = {key: scores[key]}
+ pre_mask = {key: pre_masks[key]}
+ mask = self.get_masks_global(score, target_sparsity_ratio, pre_mask)
+ masks[key] = mask[key]
+ return masks
+
+ def get_single_mask_per_target_ratio(self, score, exact_sparsity_ratio):
+ """Generate a mask for one layer with the exact_sparsity_ratio.
+
+ Args:
+            score: A Tensor. The pruning scores of the weight elements.
+            exact_sparsity_ratio: A float. After pruning, the layer's sparsity will reach this value.
+
+        Returns:
+            A Tensor with the identical size as score; the new mask.
+        """
+        flattened_score = torch.flatten(score)
+        k = int(exact_sparsity_ratio * flattened_score.numel())
+        if k >= 1:
+            # kthvalue requires k >= 1; guard against an empty pruning target.
+            threshold, _ = torch.kthvalue(flattened_score, k)
+            zero = torch.tensor([0.]).to(score.device)
+            one = torch.tensor([1.]).to(score.device)
+            mask = torch.where(score <= threshold, zero, one)
+        else:
+            mask = torch.ones(score.shape, device=score.device)
+        return mask
+
+ def get_block_size_dict(self, data):
+ """Get pattern size for each module.
+
+        This is mainly for per-channel pruning, when each module has a different pruning size.
+
+ Args:
+ data: the input data
+
+ Returns:
+ To be implemented in subclasses.
+ """
+ raise NotImplementedError
+
+ def get_sparsity_ratio(self, pre_masks, return_dict=False):
+ """Calculate the zero elements' ratio in pre_masks.
+
+        Please note that the implementations in subclasses are a little tricky.
+        TODO: need to refactor this function.
+
+ Args:
+ pre_masks: Dict{"layer_name": Tensor}. The masks generated after the last pruning step.
+ return_dict: Whether need to return more information like zero_cnt and total_cnt
+ Returns:
+ A float. The zero elements' ratio in pre_masks.
+ """
+ zero_cnt = 0
+ total_cnt = 0
+ for key in pre_masks.keys():
+ pre_mask = pre_masks[key]
+ zero_cnt += torch.sum(pre_mask == 0.0).data.item()
+ total_cnt += pre_masks[key].numel() ##FIXME
+ if return_dict:
+ return {"sparsity_ratio": float(zero_cnt) / total_cnt, "zero_cnt": zero_cnt, "total_cnt": total_cnt}
+ else:
+ return float(zero_cnt) / total_cnt
+
+ def get_pattern_lock_masks(self, modules):
+        """Obtain masks from the original weight map according to the pattern and the weights' zero positions.
+
+ Args:
+ modules: a dict{“layer_name”: Tensor}. Store weights.
+
+ Returns:
+ A dict with the identical size as modules, containing pattern lock masks.
+ """
+ pattern_lock_masks = {}
+ for key in modules.keys():
+ weight = modules[key].weight
+ shape = weight.shape
+ mask = torch.ones(shape)
+ mask[weight == 0] = 0.0
+ pattern_lock_masks[key] = mask.to(weight.device)
+ return pattern_lock_masks
+
+ def check_layer_validity(self):
+ """Check if a layer is valid for this block_size."""
+ pass
+
+ def get_reduced_masks_from_data(self, data, key):
+ """Obtain the unpruned weights and reshape according to the block_size."""
+ raise NotImplementedError
+
+ def update_residual_cnt(self, masks, target_sparsity_ratio):
+ """Update the number of parameters yet to be pruned.
+
+ Args:
+ masks: the current pruning mask
+ target_sparsity_ratio: A float. After pruning, the sparsity of the modules will reach this value.
+
+ Returns:
+ An int. How many weights still need to be pruned to achieve the target sparsity ratio
+ """
+ self.total_params_cnt = self.get_sparsity_ratio(masks, return_dict=True)["total_cnt"]
+ to_prune_cnt = int(self.total_params_cnt * target_sparsity_ratio)
+ for key in masks.keys():
+ if self.keep_mask_layers.get(key, False):
+ zero_cnt = self.get_sparsity_ratio({key: masks[key]}, return_dict=True)["zero_cnt"]
+ to_prune_cnt -= zero_cnt
+
+ return to_prune_cnt
+
+ def get_sparsity_ratio_each_layer(self, masks):
+ """Calculate the sparsity ratio of each layer.
+
+ TODO: need to refactor this function
+
+ Args:
+ masks: The current weight masks
+
+ Returns:
+            infos: The sparsity information for each layer, including zero_cnt, total_cnt and sparsity_ratio.
+            SparsityInfo: The sparsity information for the whole model.
+ """
+ infos = {}
+ zero_cnts = 0
+ total_cnts = 0
+ for key in masks.keys():
+ if key in self.invalid_layers:
+ continue
+ reduced_mask = self.get_reduced_masks_from_data(masks[key], key)
+ zero_cnt = (int(torch.sum(reduced_mask == 0.0).data.item()))
+ total_cnt = int(reduced_mask.numel())
+ sparsity_ratio = float(zero_cnt) / total_cnt
+ val = SparsityInfo(zero_cnt, total_cnt, sparsity_ratio)
+ infos[key] = val
+ zero_cnts += zero_cnt
+ total_cnts += total_cnt
+ sparsity_ratio = float(zero_cnts) / total_cnts
+ return infos, SparsityInfo(zero_cnts, total_cnts, sparsity_ratio)
+
+ def adjust_ratio(self, masks: dict, layer_name: str, key_new_sparsity: SparsityInfo,
+ max_sparsity_ratio: float, min_sparsity_ratio: float, \
+ final_target_sparsity_ratio: float):
+        """Limit the sparsity of a layer to the set threshold interval.
+
+ Args:
+            masks: The weight masks.
+            layer_name: The name of the layer being examined.
+            key_new_sparsity: The proposed sparsity info for the layer.
+            max_sparsity_ratio: A float. The maximum sparsity that one layer could reach.
+            min_sparsity_ratio: A float. The minimum sparsity that one layer could reach.
+            final_target_sparsity_ratio: The final target sparsity ratio.
+
+        Returns:
+            need_adjust: A bool indicating whether the ratio needs to be adjusted.
+            adjust_sparsity_ratio: The adjusted sparsity ratio.
+ """
+ need_adjust = False
+ adjust_zero_cnt = key_new_sparsity.zero_cnt
+ adjust_sparsity_ratio = key_new_sparsity.sparsity_ratio
+ adjust_total_cnt = key_new_sparsity.total_cnt
+
+ if adjust_sparsity_ratio > max_sparsity_ratio:
+ need_adjust = True
+ adjust_sparsity_ratio = max_sparsity_ratio
+ adjust_zero_cnt = int(adjust_total_cnt * max_sparsity_ratio)
+
+ if adjust_sparsity_ratio < min_sparsity_ratio:
+ return need_adjust, adjust_sparsity_ratio
+
+ ##TODO no need to calculate each time
+ infos, net_info = self.get_sparsity_ratio_each_layer(masks)
+
+ any_exceed_target_ratio = False
+ for key in infos.keys():
+ if infos[key].sparsity_ratio > final_target_sparsity_ratio:
+ any_exceed_target_ratio = True
+ break
+ if adjust_sparsity_ratio > final_target_sparsity_ratio:
+ any_exceed_target_ratio = True
+ if not any_exceed_target_ratio:
+ return need_adjust, adjust_sparsity_ratio
+
+ zero_cnt_below_min_sparsity = 0
+ total_cnt_below_min_sparsity = 0
+ zero_cnt_above_min_sparsity = 0
+ for key in infos.keys():
+ info = infos[key]
+ if key == layer_name:
+ info = SparsityInfo(zero_cnt=adjust_zero_cnt, total_cnt=adjust_total_cnt,
+ sparsity_ratio=adjust_sparsity_ratio)
+ if info.sparsity_ratio < min_sparsity_ratio:
+ zero_cnt_below_min_sparsity += info.zero_cnt
+ total_cnt_below_min_sparsity += info.total_cnt
+ else:
+ zero_cnt_above_min_sparsity += info.zero_cnt
+
+ gap_cnt = int(total_cnt_below_min_sparsity * min_sparsity_ratio) - zero_cnt_below_min_sparsity
+ remaining_cnt = int(net_info.total_cnt * final_target_sparsity_ratio) \
+ - zero_cnt_above_min_sparsity - zero_cnt_below_min_sparsity
+ if remaining_cnt >= gap_cnt:
+ return need_adjust, adjust_sparsity_ratio
+ else:
+ new_zero_cnt = adjust_zero_cnt - (gap_cnt - remaining_cnt)
+ new_sparsity_ratio = float(new_zero_cnt) / adjust_total_cnt
+ ##adjust_zero_cnt = new_zero_cnt
+ adjust_sparsity_ratio = new_sparsity_ratio
+ return True, adjust_sparsity_ratio
+
+
+@register_pattern('NxM')
+class PatternNxM(BasePattern):
+ """Pruning Pattern.
+
+ A Pattern class derived from BasePattern. In this pattern, the weights in a NxM block will be pruned or kept
+ during one pruning step.
+
+ Args:
+ config: A config dict object. Contains the pattern information.
+
+ Attributes:
+ block_size: A list of two Integers. The height and width of the block.
+            Please be aware that the vertical direction of a Linear layer's weight in PyTorch refers to the output channel,
+            because PyTorch's tensor matmul has a hidden transpose operation.
+ """
+
+ def __init__(self, config, modules):
+ """Initialize the basic pruning unit of NXM pattern."""
+ super(PatternNxM, self).__init__(config, modules)
+ pattern = self.pattern.split('_')[-1]
+ self.N = pattern.split('x')[0]
+ self.M = pattern.split('x')[1]
+ if self.N == "channel": ##channel-wise pruning mode
+ self.block_size = ["channel", int(self.M)]
+ elif self.M == "channel": ##channel-wise pruning mode
+ self.block_size = [int(self.N), "channel"]
+ else:
+ self.block_size = [int(pattern.split('x')[0]), int(pattern.split('x')[1])]
+ self.total_params_cnt = -1
+
+ self.block_size = self.get_block_size_dict()
+ self.check_layer_validity()
+
+ def get_block_size_dict(self):
+        """Calculate the block size for each layer.
+
+        Returns:
+            A dict. Dict{"layer_name": [block_size_1, block_size_2]}.
+            Containing layers' corresponding pruning pattern's block shape.
+            Because in channel-wise pruning different layers can have different pruning patterns.
+        """
+ data = self.modules
+ block_sizes_dict = {}
+ if self.N == "channel" or self.M == "channel":
+ for key in data.keys():
+ if isinstance(data[key], torch.nn.Module):
+ shape = data[key].weight.shape
+ else:
+ shape = data[key].shape
+ if self.N == "channel":
+ block_sizes_dict[key] = [shape[0], 1]
+ else:
+ block_sizes_dict[key] = [1, shape[1]]
+ return block_sizes_dict
+ for key in data.keys():
+ block_sizes_dict[key] = self.block_size
+ return block_sizes_dict
+
+ def check_layer_validity(self):
+ """Check if a layer is valid for this block_size."""
+ block_sizes = self.block_size
+ datas = self.modules
+ for key in datas.keys():
+ data = datas[key].weight
+ data = self._reshape_orig_to_2dims(data)
+ shape = data.shape
+ block_size = block_sizes[key]
+ if shape[0] % block_size[0] != 0 or shape[1] % block_size[1] != 0: ## only consider input channel
+ self.invalid_layers.append(key)
+ logger.warning(f"{key} shape {data.shape} cannot be divided by {self.pattern}")
+
+ def get_reduced_masks_from_data(self, data, key):
+ """Obtain the unpruned weights and reshape according to the block_size."""
+ assert key not in self.invalid_layers
+ block_size = self.block_size[key]
+ data = self._reshape_orig_to_2dims(data)
+ shape = data.shape
+ new_shape = [shape[0] // block_size[0], block_size[0], shape[1] // block_size[1], block_size[1]]
+ data = data.reshape(new_shape)
+ data = data.sum(-1).sum(1)
+ reduced_mask = data != 0
+ return reduced_mask
+
+ def get_sparsity_ratio(self, pre_masks, return_dict=False):
+        """Calculate the zero elements' ratio in pre_masks.
+
+        Please note that the zero cnt and total cnt are all block-wise for supporting channel-wise pruning.
+
+ Args:
+ pre_masks: Dict{"layer_name": Tensor}. The masks generated after the last pruning step.
+
+ Returns:
+ A float. Calculate the zero elements' ratio in pre_masks.
+ """
+ zero_cnt = 0
+ total_cnt = 0
+ for key in pre_masks.keys():
+ if key in self.invalid_layers:
+ continue
+ reduced_mask = self.get_reduced_masks_from_data(pre_masks[key], key)
+ zero_cnt += (int(torch.sum(reduced_mask == 0.0).data.item()))
+ total_cnt += int(reduced_mask.numel())
+ if total_cnt == 0:
+ sparsity_ratio = 0.0
+ else:
+ sparsity_ratio = float(zero_cnt) / total_cnt
+ if return_dict:
+ return {"sparsity_ratio": sparsity_ratio, "zero_cnt": zero_cnt, "total_cnt": total_cnt}
+ else:
+ return sparsity_ratio
+
+ def get_sparsity_ratio_progressive(self, pre_masks, return_dict=False):
+        """Calculate the overall element-wise sparsity ratio of the progressive masks."""
+ zero_cnt = 0
+ total_cnt = 0
+ for key in pre_masks.keys():
+ if key in self.invalid_layers:
+ continue
+ # progressive masks are unstructured, therefore directly find zeros
+ zero_cnt += float(torch.sum(pre_masks[key] == 0).data.item())
+ total_cnt += float(pre_masks[key].numel())
+ return (zero_cnt / total_cnt)
+
+ def _reshape_orig_to_2dims(self, data):
+        """Mainly for processing layers whose dims are not equal to 2, for example conv layers.
+
+ Args:
+ data: the input
+
+ Returns:
+ a reshaped data
+ """
+ ##TODO need to verify whether it's ok for transposed conv
+ if len(data.shape) == 4:
+ data = data.permute(0, 2, 3, 1) ##cout,k,k,cin
+ data = data.reshape(data.shape[0], -1)
+ return data
+
+ def _reshape_2dims_to_orig(self, data, orig_shape):
+        """Mainly for recovering layers whose dims are not equal to 2, for example conv layers.
+
+ Args:
+ data: input
+ orig_shape: target shape
+
+ Returns:
+ a reshaped data
+ """
+ if len(orig_shape) == 4:
+ data = data.reshape(orig_shape[0], orig_shape[2], orig_shape[3],
+ orig_shape[1])
+ data = data.permute(0, 3, 1, 2)
+ return data
+
+ def reshape_orig_to_pattern(self, data, key):
+        """Reshape the data (s1, s2) to [s1/N, N, s2/M, M].
+
+ Args:
+ data: the input
+ key: the layer name
+
+ Returns:
+ The reshaped input tensor.
+ """
+ block_size = self.block_size[key]
+ data = self._reshape_orig_to_2dims(data)
+ shape = data.shape
+ new_shape = [shape[0] // block_size[0], block_size[0], shape[1] // block_size[1],
+ block_size[1]]
+ data = data.reshape(new_shape)
+ return data
+
+ def reshape_reduced_to_orig(self, data, key, orig_shape):
+ """Reshape the data [s1/N,s2/M] to [s1,s2], also permute dims for conv layer.
+
+ Args:
+            data: The reduced data.
+            key: The layer name.
+            orig_shape: The original shape of the layer's weight.
+
+ Returns:
+ Original shape data
+ """
+ block_size = self.block_size[key]
+ data = data.repeat_interleave(block_size[0], dim=0).repeat_interleave(block_size[1], dim=-1)
+ data = self._reshape_2dims_to_orig(data, orig_shape)
+ return data
+
+ def reduce_scores(self, scores):
+ """Recalculate the pruning scores after reducing the data."""
+ new_scores = {}
+ for key in scores.keys():
+ if key in self.invalid_layers:
+ continue
+ if self.keep_mask_layers.get(key, False):
+ continue
+ self.keep_mask_layers[key] = False
+ current_score = scores[key]
+ current_score = self.reshape_orig_to_pattern(current_score, key)
+ ##sum or mean is quite different for per channel pruning
+ current_score_sum = self.reduce_tensor(self.reduce_tensor(current_score, dim=-1), dim=1)
+ new_scores[key] = current_score_sum
+ return new_scores
+
+ def get_mask_per_threshold(self, score, threshold, block_size):
+ """Get the mask per threshold."""
+ zero = torch.tensor([0.]).to(score.device)
+ one = torch.tensor([1.]).to(score.device)
+ mask = torch.where(score <= threshold, zero, one)
+ mask = mask.repeat_interleave(block_size[0], dim=0).repeat_interleave(block_size[1], dim=-1)
+ return mask
+
+ def get_masks_global(self, scores, cur_target_sparsity_ratio, pre_masks,
+ keep_exact_sparsity_ratio=True):
+ """Generate masks for layers.
+
+ Gather all layer's scores together and calculate a common threshold.
+ This threshold will be applied for all layers.
+
+ Args:
+ scores: A dict{“layer_name”: Tensor}. Store the pruning scores of weights.
+ cur_target_sparsity_ratio: A float. After pruning, the model's sparsity will reach this value.
+ pre_masks: A dict{"layer_name": Tensor}. The masks generated after the last pruning step.
+            keep_exact_sparsity_ratio: A bool. If True, iterate until the exact target sparsity ratio is met.
+
+ Returns:
+ A dict with the identical size as pre_masks. Update the 0/1 values in it.
+ """
+        ## keep the masks if the layer exceeds the max sparsity ratio
+
+ masks = pre_masks
+
+ k_blockwise = self.update_residual_cnt(masks, cur_target_sparsity_ratio)
+ if k_blockwise <= 0:
+ return masks
+ new_scores = self.reduce_scores(scores)
+ global_scores = torch.cat([torch.flatten(v) for v in new_scores.values()])
+ residual_k = k_blockwise
+ not_exceed_layers = [key for key in new_scores.keys()]
+ if self.min_sparsity_ratio_per_op > 0:
+ sparsity_infos_perlayer, _ = self.get_sparsity_ratio_each_layer(masks)
+
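+        # Iteratively re-threshold: when a layer hits its per-op sparsity bounds, freeze its
+        # mask, drop its scores from the global pool, and recompute the global threshold.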
+ while True:
+ threshold, _ = torch.kthvalue(global_scores, residual_k)
+ for key in not_exceed_layers:
+ block_size = self.block_size[key]
+ score = new_scores[key]
+ mask = self.get_mask_per_threshold(score, threshold, block_size)
+ info = self.get_sparsity_ratio({key: mask}, return_dict=True)
+ zero_cnt = info["zero_cnt"]
+ total_cnt = info["total_cnt"]
+ current_sparsity_ratio = float(zero_cnt) / total_cnt
+ key_new_sparsity = SparsityInfo(zero_cnt, total_cnt, current_sparsity_ratio)
+ need_adjust, adjust_ratio = self.adjust_ratio(masks, key, key_new_sparsity,
+ self.max_sparsity_ratio_per_op,
+ self.min_sparsity_ratio_per_op,
+ self.target_sparsity_ratio)
+ if need_adjust:
+                    # update status
+ self.keep_mask_layers[key] = True
+ masks[key] = self.get_single_mask_per_target_ratio(new_scores[key], adjust_ratio)
+ masks[key] = masks[key].repeat_interleave(block_size[0], 0).repeat_interleave(block_size[1], -1)
+ if keep_exact_sparsity_ratio:
+ zero_cnt = self.get_sparsity_ratio({key: masks[key]}, return_dict=True)["zero_cnt"]
+ residual_k -= zero_cnt
+ else:
+ masks[key] = mask
+ if not keep_exact_sparsity_ratio:
+ break
+ new_not_exceed_layers = [key for key in new_scores.keys() if not self.keep_mask_layers.get(key, False)]
+ if not_exceed_layers == new_not_exceed_layers or len(new_not_exceed_layers) == 0:
+ break
+ not_exceed_layers = new_not_exceed_layers
+ global_scores = torch.cat([torch.flatten(new_scores[key]) for key in not_exceed_layers])
+
+ for key in masks.keys():
+ if key in self.invalid_layers:
+ continue
+ if len(scores[key].shape) == 4: ## need to permute
+ mask = masks[key]
+ orig_shape = scores[key].shape
+ mask = self._reshape_2dims_to_orig(mask, orig_shape)
+ masks[key] = mask
+ layer_ratio = torch.sum(masks[key] == 0.0).data.item() / masks[key].numel()
+ logger.info(f'layer {key} sparsity_ratio is {layer_ratio}')
+ return masks
+
+ def get_pattern_lock_masks(self, modules):
+        """Obtain masks from the original weight map by masking where weights are zero.
+
+ Args:
+ modules: A dict{“layer_name”: Tensor}. Store weights.
+
+ Returns:
+ A dict with the identical size as modules, containing pattern lock masks.
+ """
+ pattern_lock_masks = {}
+ for key in modules.keys():
+ weight = modules[key].weight
+ ori_shape = weight.shape
+ if key in self.invalid_layers:
+ mask = torch.ones(weight.shape, device=weight.device)
+                pattern_lock_masks[key] = mask
+ continue
+ reduced_mask = self.get_reduced_masks_from_data(weight, key)
+ mask = self.reshape_reduced_to_orig(reduced_mask, key, ori_shape)
+ pattern_lock_masks[key] = mask
+ return pattern_lock_masks
+
+ # ---------------progressive related--------------------
+ def count_new_masked_cnts(self, new_added_masks):
+        """Count the number of elements to be masked."""
+ new_masked_cnts = 0
+ for key in new_added_masks.keys():
+ new_masked_cnts += torch.nonzero(1 - new_added_masks[key]).size()[0]
+ return new_masked_cnts
+
+ def update_new_added_masks(self, pre_masks, cur_masks):
+ """Obtain the new set-to-zero mask during a pruning procedure.
+
+        Pre_masks and cur_masks should have identical keys because they stand for masks of the same model.
+ """
+ new_added_masks = {}
+ for key in pre_masks.keys():
+ pre_mask = pre_masks[key]
+ cur_mask = cur_masks[key]
+ zero = torch.tensor([0.]).to(pre_mask.device)
+ one = torch.tensor([1.]).to(cur_mask.device)
+ new_added_masks[key] = torch.where(pre_mask == cur_mask, one, zero)
+ return new_added_masks
+
+ def update_progressive_masks(self, pre_masks, cur_masks, scores, progressive_step, progressive_configs):
+ """Generate the progressive masks."""
+ use_global = progressive_configs["use_global"]
+ if use_global:
+ return self.update_progressive_masks_global(pre_masks, cur_masks, scores, \
+ progressive_step, progressive_configs)
+ else:
+ return self.update_progressive_masks_local(pre_masks, cur_masks, scores, \
+ progressive_step, progressive_configs)
+
+ def update_progressive_masks_linear(self, pre_masks, cur_masks, progressive_step, progressive_configs):
+ """Generate the progressive masks along the block's larger dimension."""
+ progressive_steps = progressive_configs["progressive_steps"]
+ progressive_masks = {}
+ new_added_masks = self.update_new_added_masks(pre_masks, cur_masks)
+ for key in pre_masks.keys():
+ block_size = self.block_size[key]
+ new_added_mask = new_added_masks[key]
+ # conv
+ new_added_mask = self._reshape_orig_to_2dims(new_added_mask)
+ shape = new_added_mask.shape
+ # progressive masks are generated in the direction of block's large dim.
+ if block_size[0] >= block_size[1]:
+ # NxM (N>=M), output channel pruning
+ new_shape = [shape[0] // block_size[0], progressive_steps, block_size[0] // progressive_steps,
+ shape[1] // block_size[1], block_size[1]]
+ new_added_mask_reshape = new_added_mask.reshape(new_shape)
+ new_added_mask_reshape[:, progressive_step:, :, :, :] = 1.0
+ else:
+                # NxM (N<M), input channel pruning
+ return reduced_mask
+
+ def get_least_ninm_mask_from_data(self, score):
+ """Generate the least N scores in M."""
+ current_score = score
+ M = self.M
+ N = self.N
+ current_score = self._reshape_orig_to_2dims(current_score)
+ shape = current_score.shape
+ new_shape = [shape[0], shape[1] // M, M]
+ current_score_new = current_score.reshape(new_shape)
+
+ threshold, _ = torch.kthvalue(current_score_new, N, dim=2)
+ threshold = threshold.unsqueeze(-1)
+
+ threshold = threshold.expand(shape[0], shape[1] // M, M)
+ threshold = threshold.reshape((shape[0], shape[1]))
+
+ one = torch.tensor([1.]).to(current_score.device)
+ zero = torch.tensor([0.]).to(current_score.device)
+ mask = torch.where(current_score <= threshold, zero, one)
+ return mask
+
+ def get_sparsity_ratio(self, pre_masks, return_dict=False):
+        """Calculate the zero elements' ratio in pre_masks.
+
+        Please note that the zero cnt and total cnt are block-wise for supporting channel-wise pruning,
+        while the returned sparsity ratio is element-wise (TODO: clarify).
+
+        Args:
+            pre_masks: Dict{"layer_name": Tensor}. The masks generated after the last pruning step.
+            return_dict: Whether to return more information such as zero_cnt and total_cnt.
+
+        Returns:
+            An element-wise sparsity ratio.
+ """
+ ##simply use elemwise sparsity
+ zero_cnt = 0
+ total_cnt = 0
+ for key in pre_masks.keys():
+ if key in self.invalid_layers:
+ # total_cnt += pre_masks[key].numel() // self.M
+ continue
+ reduced_mask = self.get_reduced_masks_from_data(pre_masks[key], key)
+ zero_cnt += int((torch.sum(reduced_mask == 0)).data.item())
+ total_cnt += int(reduced_mask.numel())
+ sparsity_ratio = float(zero_cnt) / total_cnt * self.N / self.M
+
+ if return_dict:
+ return {"sparsity_ratio": sparsity_ratio, "zero_cnt": zero_cnt,
+ "total_cnt": total_cnt}
+ else:
+ return sparsity_ratio
+
+ def _reshape_orig_to_2dims(self, data):
+ if len(data.shape) == 4: ##TODO need to verify whether it's ok for transposed conv
+ data = data.permute(0, 2, 3, 1) ##cout,k,k,cin
+ data = data.reshape(data.shape[0], -1)
+ return data
+
+ def _reshape_2dims_to_orig(self, data, orig_shape):
+ if len(orig_shape) == 4:
+ data = data.reshape(orig_shape[0], orig_shape[2], orig_shape[3], orig_shape[1])
+ data = data.permute(0, 3, 1, 2)
+ return data
+
+ def reshape_orig_to_pattern(self, data, key):
+ """Reshape the data based on the pruning pattern."""
+ data = self._reshape_orig_to_2dims(data)
+ shape = data.shape
+ new_shape = [shape[0], shape[1] // self.M, self.M]
+ data = data.reshape(new_shape)
+ return data
+
+ def reshape_reduced_to_orig(self, data, key, orig_shape):
+ """Reshape the reduced data to its original shape."""
+ data = data.repeat_interleave(self.M, dim=-1)
+ return self._reshape_2dims_to_orig(data, orig_shape)
+
+ def reduce_scores(self, scores):
+ """Calculate the pruning scores after reducing the data and obtain the least N scores in M."""
+ ##to get the least N scores in M
+ M = self.M
+ N = self.N
+ least_ninm_masks = {}
+ new_scores = {}
+ for key in scores.keys():
+ if key in self.invalid_layers:
+ continue
+ if self.keep_mask_layers.get(key, False):
+ continue
+ current_score = scores[key]
+ mask = self.get_least_ninm_mask_from_data(current_score)
+ current_score_new = self._reshape_orig_to_2dims(current_score)
+ shape = current_score_new.shape
+ current_score_new = current_score_new.reshape((shape[0], shape[1]))
+ ##to get the sum of N scores in each block with M
+ current_score_new = current_score_new * (1.0 - mask)
+ current_score_new = current_score_new.reshape(shape[0], shape[1] // M, M)
+ score_sum = self.reduce_tensor(current_score_new, dim=-1)
+ least_ninm_masks[key] = mask
+ new_scores[key] = score_sum
+ return new_scores, least_ninm_masks
+
+ def get_ele_mask_per_threshold(self, score, threshold, block_size, least_ninm_mask):
+        """Get the element-wise mask per threshold.
+
+        Args:
+            score: A Tensor. The block-wise pruning scores of a layer.
+            threshold: A float. Scores less than or equal to it will be masked.
+            block_size: A tuple (N, M) describing the pattern's block shape.
+            least_ninm_mask: A Tensor. Marks the least N scores in each block of M.
+
+        Returns:
+            mask: A Tensor. The element-wise mask of the layer.
+        """
+ zero = torch.tensor([0.]).to(score.device)
+ one = torch.tensor([1.]).to(score.device)
+ mask = torch.where(score <= threshold, zero, one)
+ mask = mask.repeat_interleave(block_size[1], dim=-1)
+ ## both zero will be zero
+ mask = (mask + least_ninm_mask)
+ mask = torch.where(mask <= 0, zero, one)
+ return mask
+
+ def get_masks_global(self, scores, cur_target_sparsity_ratio, pre_masks,
+ keep_exact_sparsity_ratio=True):
+ """Generate masks for layers.
+
+ Gather all layer's scores together and calculate a common threshold.
+ This threshold will be applied for all layers.
+
+ Args:
+ scores: A dict{“layer_name”: Tensor}. Store the pruning scores of weights.
+            cur_target_sparsity_ratio: A float. After pruning, the model's sparsity will reach this value.
+            pre_masks: A dict{"layer_name": Tensor}. The masks generated after the last pruning step.
+            keep_exact_sparsity_ratio: A bool. If True, iterate until the exact target sparsity ratio is met.
+
+ Returns:
+ A dict with the identical size as pre_masks. Update the 0/1 values in it.
+ """
+ masks = pre_masks
+
+ block_sparsity_ratio = cur_target_sparsity_ratio * self.M / self.N
+ k_blockwise = self.update_residual_cnt(pre_masks, block_sparsity_ratio)
+ if k_blockwise <= 0:
+ return masks
+ new_scores, least_ninm_masks = self.reduce_scores(scores)
+ global_scores = torch.cat([torch.flatten(v) for v in new_scores.values()]) ##block_wise
+ residual_k = k_blockwise
+ not_exceed_layers = [key for key in new_scores.keys()]
+
+ while True:
+ threshold, _ = torch.kthvalue(global_scores, residual_k)
+ for key in not_exceed_layers:
+ score = new_scores[key]
+ mask = self.get_ele_mask_per_threshold(score, threshold, (self.N, self.M), least_ninm_masks[key])
+ info = self.get_sparsity_ratio({key: mask}, return_dict=True)
+ zero_cnt = info["zero_cnt"]
+ total_cnt = info["total_cnt"]
+ current_sparsity_ratio = float(zero_cnt) / total_cnt
+ key_new_sparsity = SparsityInfo(zero_cnt, total_cnt, current_sparsity_ratio)
+ need_adjust, adjust_ratio = self.adjust_ratio(masks, key, key_new_sparsity,
+ self.max_sparsity_ratio_per_op * self.M / self.N,
+ self.min_sparsity_ratio_per_op * self.M / self.N,
+ self.target_sparsity_ratio * self.M / self.N)
+
+ if need_adjust:
+ self.keep_mask_layers[key] = True
+ masks[key] = self.get_single_mask_per_target_ratio(new_scores[key], adjust_ratio)
+ masks[key] = masks[key].repeat_interleave(self.M, dim=-1)
+                    ## positions that are zero in both masks stay zero
+ masks[key] = (masks[key] + least_ninm_masks[key])
+ zero = torch.tensor([0.]).to(score.device)
+ one = torch.tensor([1.]).to(score.device)
+ masks[key] = torch.where(masks[key] <= 0, zero, one)
+ if keep_exact_sparsity_ratio:
+ zero_cnt = self.get_sparsity_ratio({key: masks[key]}, return_dict=True)["zero_cnt"]
+ residual_k -= zero_cnt
+ else:
+ masks[key] = mask
+ if not keep_exact_sparsity_ratio:
+ break
+ new_not_exceed_layers = [key for key in new_scores.keys() if not self.keep_mask_layers.get(key, False)]
+ if not_exceed_layers == new_not_exceed_layers or len(new_not_exceed_layers) == 0:
+ break
+ not_exceed_layers = new_not_exceed_layers
+ global_scores = torch.cat([torch.flatten(new_scores[key]) for key in not_exceed_layers])
+
+ for key in masks.keys():
+ if key in self.invalid_layers:
+ continue
+ if len(scores[key].shape) == 4: ## need to permute
+ mask = masks[key]
+ orig_shape = scores[key].shape
+ mask = self._reshape_2dims_to_orig(mask, orig_shape)
+ masks[key] = mask
+ layer_ratio = torch.sum(masks[key] == 0.0).data.item() / masks[key].numel()
+ logger.info(f'layer {key} sparsity_ratio is {layer_ratio}')
+ return masks
+
+ def get_pattern_lock_masks(self, modules):
+ """Obtain masks from original weight map, by masking where weights' are zero.
+
+ Args:
+ modules: A dict{“layer_name”: Tensor}. Store weights.
+
+ Returns:
+ A dict with the identical size as modules, containing pattern lock masks.
+ """
+ pattern_lock_masks = {}
+ for key in modules.keys():
+ weight = modules[key].weight
+ orig_shape = weight.shape
+ if key in self.invalid_layers:
+ mask = torch.ones(orig_shape, device=weight.device)
+ pattern_lock_masks[key] = mask
+ continue
+ mask = self.get_least_ninm_mask_from_data(weight)
+ mask = self._reshape_2dims_to_orig(mask, orig_shape)
+ pattern_lock_masks[key] = mask
+ return pattern_lock_masks
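To make the "least N in M" computation above concrete, here is a minimal, self-contained sketch of a 2:4 mask (assuming PyTorch; the helper name and shapes are illustrative and not part of this patch):

```python
import torch

def least_2_in_4_mask(weight: torch.Tensor) -> torch.Tensor:
    """Toy 2:4 mask: zero the 2 smallest-magnitude weights in each block of 4."""
    blocks = weight.reshape(-1, 4)                      # group weights into blocks of M=4
    _, idx = torch.topk(blocks.abs(), k=2, dim=1, largest=False)
    mask = torch.ones_like(blocks)
    mask.scatter_(1, idx, 0.0)                          # zero the least-N positions per block
    return mask.reshape(weight.shape)

w = torch.randn(8, 16)
m = least_2_in_4_mask(w)
assert torch.all(m.reshape(-1, 4).sum(dim=1) == 2)      # exactly 2 weights survive per block
```

The pattern class above follows the same idea, but it additionally reduces the per-block scores so that a single global threshold can select which blocks to prune.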
diff --git a/neural_compressor/pruners/__init__.py b/neural_compressor/pruner/pruner_legacy/__init__.py
similarity index 100%
rename from neural_compressor/pruners/__init__.py
rename to neural_compressor/pruner/pruner_legacy/__init__.py
diff --git a/neural_compressor/pruners/gradient_sensitivity.py b/neural_compressor/pruner/pruner_legacy/gradient_sensitivity.py
similarity index 99%
rename from neural_compressor/pruners/gradient_sensitivity.py
rename to neural_compressor/pruner/pruner_legacy/gradient_sensitivity.py
index e6ae10e0ee6..46683c14e23 100644
--- a/neural_compressor/pruners/gradient_sensitivity.py
+++ b/neural_compressor/pruner/pruner_legacy/gradient_sensitivity.py
@@ -18,7 +18,7 @@
import numpy as np
from .pruner import pruner_registry, Pruner
from heapq import heappush, heappop
-from ..utils import logger
+from neural_compressor.utils import logger
import re
@pruner_registry
diff --git a/neural_compressor/pruners/group_lasso.py b/neural_compressor/pruner/pruner_legacy/group_lasso.py
similarity index 98%
rename from neural_compressor/pruners/group_lasso.py
rename to neural_compressor/pruner/pruner_legacy/group_lasso.py
index fc659bdafa1..045fa18d07d 100644
--- a/neural_compressor/pruners/group_lasso.py
+++ b/neural_compressor/pruner/pruner_legacy/group_lasso.py
@@ -20,7 +20,7 @@
import numpy as np
from .pruner import pruner_registry, Pruner
from .magnitude import BasicMagnitudePruner
-from ..utils import logger
+from neural_compressor.utils import logger
@pruner_registry
class GroupLassoPruner(BasicMagnitudePruner):
diff --git a/neural_compressor/pruners/magnitude.py b/neural_compressor/pruner/pruner_legacy/magnitude.py
similarity index 98%
rename from neural_compressor/pruners/magnitude.py
rename to neural_compressor/pruner/pruner_legacy/magnitude.py
index 752e1cf2268..9544d9474b2 100644
--- a/neural_compressor/pruners/magnitude.py
+++ b/neural_compressor/pruner/pruner_legacy/magnitude.py
@@ -17,7 +17,7 @@
import numpy as np
from .pruner import pruner_registry, Pruner
-from ..utils import logger
+from neural_compressor.utils import logger
@pruner_registry
class BasicMagnitudePruner(Pruner):
diff --git a/neural_compressor/pruners/pattern_lock.py b/neural_compressor/pruner/pruner_legacy/pattern_lock.py
similarity index 100%
rename from neural_compressor/pruners/pattern_lock.py
rename to neural_compressor/pruner/pruner_legacy/pattern_lock.py
diff --git a/neural_compressor/pruners/pruner.py b/neural_compressor/pruner/pruner_legacy/pruner.py
similarity index 98%
rename from neural_compressor/pruners/pruner.py
rename to neural_compressor/pruner/pruner_legacy/pruner.py
index 64d2e44cdda..6384235af30 100644
--- a/neural_compressor/pruners/pruner.py
+++ b/neural_compressor/pruner/pruner_legacy/pruner.py
@@ -15,7 +15,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from ..experimental.pruning_recipes.patterns import patterns
+from neural_compressor.experimental.pruning_recipes.patterns import patterns
PRUNERS = {}
diff --git a/neural_compressor/pruners/util/block_mask.py b/neural_compressor/pruner/pruner_legacy/util/block_mask.py
similarity index 100%
rename from neural_compressor/pruners/util/block_mask.py
rename to neural_compressor/pruner/pruner_legacy/util/block_mask.py
diff --git a/neural_compressor/pruner/pruners.py b/neural_compressor/pruner/pruners.py
new file mode 100644
index 00000000000..c9a7cf436ae
--- /dev/null
+++ b/neural_compressor/pruner/pruners.py
@@ -0,0 +1,565 @@
+"""Pruner."""
+# !/usr/bin/env python
+# -*- coding: utf-8 -*-
+#
+# Copyright (c) 2022 Intel Corporation
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import copy
+from neural_compressor.utils.utility import LazyImport
+torch = LazyImport('torch')
+from .patterns import get_pattern
+from .schedulers import get_scheduler
+from .criteria import get_criterion, CRITERIAS
+from .regs import get_reg
+from .logger import logger
+
+PRUNERS = {}
+
+
+def register_pruner(name):
+ """Class decorator to register a Pruner subclass to the registry.
+
+    Decorator function used before a Pruner subclass.
+    Make sure that the Pruner class decorated by this function can be registered in PRUNERS.
+
+    Args:
+        name: A string. Defines the pruner type.
+
+    Returns:
+        register: A decorator that registers the decorated class in PRUNERS and returns it.
+ """
+
+ def register(pruner):
+ PRUNERS[name] = pruner
+ return pruner
+
+ return register
+
+
+def get_pruner(config, modules):
+ """Get registered pruner class.
+
+ Get a Pruner object from PRUNERS.
+
+ Args:
+        config: A config dict object. Contains the pruner information.
+        modules: A dict {"module_name": Tensor}. Store the pruning modules' weights.
+
+ Returns:
+ A Pruner object.
+
+    Raises: AssertionError: Currently only supports pruners that have been registered in PRUNERS.
+ """
+ ## do the ugly work here
+ if "progressive" not in config["pruning_type"]:
+ name = config["pruning_type"]
+ config["progressive"] = False
+ else:
+        # if progressive, strip the trailing "_progressive" suffix and set config["progressive"]
+ name = config["pruning_type"][0:-12]
+ config["progressive"] = True
+ if name in CRITERIAS:
+ if config["progressive"] == False:
+ config['criterion_type'] = name
+ name = "basic" ##return the basic pruner
+ else:
+ config['criterion_type'] = name
+ name = "progressive" ## return the progressive pruner
+
+ if name not in PRUNERS.keys():
+ assert False, f"does not support {name}, currently only support {PRUNERS.keys()}"
+ return PRUNERS[name](config, modules)
+
+
+class BasePruner:
+ """Pruning Pruner.
+
+ The class which executes pruning process.
+
+ Args:
+ modules: A dict {"module_name": Tensor}. Store the pruning modules' weights.
+ config: A config dict object. Contains the pruner information.
+
+ Attributes:
+ modules: A dict {"module_name": Tensor}. Store the pruning modules' weights.
+ config: A config dict object. Contains the pruner information.
+ masks: A dict {"module_name": Tensor}. Store the masks for modules' weights.
+ scores: A dict {"module_name": Tensor}. Store the score for modules' weights,
+ which are used to decide pruning parts with a criterion.
+ pattern: A Pattern object. Defined in ./patterns.py
+ scheduler: A scheduler object. Defined in ./scheduler.py
+ current_sparsity_ratio: A float. Current model's sparsity ratio, initialized as zero.
+ global_step: An integer. The total steps the model has run.
+ start_step: An integer. When to trigger pruning process.
+ end_step: An integer. When to end pruning process.
+        pruning_frequency: An integer. The pruning frequency, which is valid when iterative
+            pruning is enabled.
+ target_sparsity_ratio: A float. The final sparsity after pruning.
+        max_sparsity_ratio_per_op: A float. The maximum sparsity ratio allowed for each module.
+ """
+
+ def __init__(self, config, modules):
+ """Initialize."""
+ self.modules = modules
+ self.config = config
+ self.masks = {}
+ self.global_step = 0
+ self.handled_global_step = -1
+ self.start_step = self.config['start_step']
+ self.end_step = self.config['end_step']
+ self.pruning_frequency = self.config['pruning_frequency']
+        ## this differs from the original code
+ self.total_prune_cnt = (self.end_step - self.start_step + 1) \
+ // self.pruning_frequency
+ self.completed_pruned_cnt = 0
+ for key in self.modules.keys():
+ module = self.modules[key]
+ self.masks[key] = torch.ones(module.weight.shape).to(module.weight.device) ##TODO support bias or others
+
+ self.target_sparsity_ratio = self.config['target_sparsity']
+ self.current_sparsity_ratio = 0.0
+ self.init_sparsity_ratio = 0.0
+ self._init()
+
+ def _init(self):
+ """Auxiliary function for initializing."""
+ pass
+
+ def on_epoch_begin(self, epoch):
+ """Implement at the beginning of each epoch."""
+ pass
+
+ def mask_weights(self):
+ """Apply masks to corresponding modules' weights.
+
+        Weights are multiplied by masks. This is the formal pruning process.
+ """
+ with torch.no_grad():
+ for key in self.modules.keys():
+ module = self.modules[key]
+ module.weight.data = module.weight.data * self.masks[key]
+
+ def mask_weights_general(self, input_masks):
+ """Apply input masks to corresponding modules' weights.
+
+        Weights are multiplied by input_masks.
+
+ Args:
+ input_masks: A dict {"module_name": Tensor}. Store the masks for modules' weights.
+ """
+ with torch.no_grad():
+ for key in self.modules.keys():
+ module = self.modules[key]
+ module.weight.data = module.weight.data * input_masks[key]
+
+ def on_step_begin(self, local_step):
+ """Implement at the start of each step."""
+ if self.handled_global_step == self.global_step:
+ return
+ self.update_masks(local_step)
+ self.handled_global_step = self.global_step
+
+ def update_masks(self, local_step):
+ """Update the masks at a given local step."""
+ pass
+
+ def on_epoch_end(self):
+ """Implement at the end of each epoch."""
+ pass
+
+ def on_step_end(self):
+ """Implement at the end of each step."""
+ pass
+
+ def on_before_optimizer_step(self):
+ """Implement before optimizer.step()."""
+ pass
+
+ def on_after_optimizer_step(self):
+ """Implement after optimizer.step().
+
+ Prune the model after optimization.
+ """
+ self.mask_weights()
+ self.global_step += 1
+
+ def on_train_begin(self):
+ """Implement at the beginning of training phase."""
+ pass
+
+ def on_train_end(self):
+ """Implement at the end of training phase."""
+ pass
+
+ def on_before_eval(self):
+ """Implement at the beginning of evaluation phase."""
+ pass
+
+ def on_after_eval(self):
+ """Implement at the end of evaluation phase."""
+ pass
+
+ def check_is_pruned_step(self, step):
+ """Check if a pruning process should be performed at the current step.
+
+ Args:
+            step: An integer representing the current step number.
+
+ Returns:
+ A Boolean.
+ """
+ if step < self.start_step or step > self.end_step:
+ return False
+ if int(step - self.start_step) % self.pruning_frequency == 0:
+ return True
+ return False
+
+
+@register_pruner("basic")
+class BasicPruner(BasePruner):
+ """Pruning Pruner.
+
+ The class which executes pruning process.
+ 1. Defines pruning functions called at step begin/end, epoch begin/end.
+ 2. Defines the pruning criterion.
+
+ Args:
+ modules: A dict {"module_name": Tensor}. Store the pruning modules' weights.
+ config: A config dict object. Contains the pruner information.
+
+ Attributes:
+        pattern: A Pattern object. Defines the pruned weights' arrangements within space.
+        criterion: A Criterion object. Defines which weights are to be pruned.
+        scheduler: A Scheduler object. Defines the model's sparsity changing method as training/pruning executes.
+        reg: A Reg object. Defines regularization terms.
+ """
+
+ def __init__(self, config, modules):
+ """Initialize."""
+ # self.modules = modules
+ # self.config = config
+ # self.masks = {}
+ super(BasicPruner, self).__init__(config, modules)
+
+ def _init(self):
+ """Auxiliary function for initializing."""
+ self.pattern = get_pattern(self.config, self.modules)
+ self.scheduler = get_scheduler(self.config)
+ self.criterion = get_criterion(self.config, self.modules)
+ self.reg = get_reg(self.config, self.modules, self.pattern)
+        # if progressive pruning is switched off but a per-channel pattern is used, give a warning
+        if "channel" in self.pattern.pattern:
+            logger.info("UserWarning: using a per-channel pruning pattern without progressive pruning!")
+            logger.info("Instead, enabling progressive pruning would be a better choice.")
+ else:
+ pass
+
+ def set_global_step(self, global_step):
+ """Set global step number."""
+ self.global_step = global_step
+
+ # def on_step_begin(self, local_step):
+ # """Implement at the start of each step.
+ #
+ # Update the masks at a given local_step.
+ # """
+ # self.update_masks(local_step)
+
+ def update_masks(self, local_step):
+ """Update the masks at a given local step."""
+ if self.global_step == self.start_step:
+ if self.config['lock_init_sparsity']:
+ self.masks = self.pattern.get_pattern_lock_masks(self.modules)
+ self.init_sparsity_ratio = self.pattern.get_sparsity_ratio(self.masks)
+ self.current_sparsity_ratio = self.init_sparsity_ratio
+
+ if not self.check_is_pruned_step(self.global_step):
+ return
+
+ if self.current_sparsity_ratio > self.target_sparsity_ratio:
+ return
+
+ self.criterion.on_step_begin()
+ current_target_sparsity_ratio = self.scheduler.update_sparsity_ratio(self.target_sparsity_ratio,
+ self.completed_pruned_cnt,
+ self.total_prune_cnt, self.masks,
+ self.init_sparsity_ratio)
+ logger.info(f"current target ratio is {current_target_sparsity_ratio}")
+
+ self.completed_pruned_cnt += 1
+ if self.criterion.scores == {}:
+ return
+ self.masks = self.pattern.get_masks(self.criterion.scores, current_target_sparsity_ratio, self.masks)
+ self.mask_weights()
+
+ self.current_sparsity_ratio = self.pattern.get_sparsity_ratio(self.masks)
+ logger.info(f"current sparsity ratio is {self.current_sparsity_ratio}")
+
+ def on_before_optimizer_step(self):
+ """Implement before optimizer.step()."""
+ self.reg.on_before_optimizer_step()
+
+ def on_after_optimizer_step(self):
+ """Prune the model after optimization."""
+        ## the order of the following three lines cannot be exchanged
+ self.reg.on_after_optimizer_step()
+ self.mask_weights()
+ self.criterion.on_after_optimizer_step()
+ self.global_step += 1
+
+
+@register_pruner('pattern_lock')
+class PatternLockPruner(BasePruner):
+ """Pruning Pruner.
+
+ A Pruner class derived from BasePruner.
+ In this pruner, original model's sparsity pattern will be fixed while training.
+ This pruner is useful when you want to train a sparse model without change its original structure.
+
+ Args:
+ modules: A dict {"module_name": Tensor}. Store the pruning modules' weights.
+ config: A config dict object. Contains the pruner information.
+
+ Attributes:
+        Inherit from parent class BasePruner.
+ """
+
+ def __init__(self, config, modules):
+ """Initialize."""
+ super(PatternLockPruner, self).__init__(config, modules)
+ self.pattern = get_pattern(self.config, modules)
+ assert self.config.end_step == self.config.start_step, "pattern_lock pruner only supports one shot mode"
+
+ def update_masks(self, local_step):
+ """Update the masks at a given local step."""
+ if not self.check_is_pruned_step(self.global_step):
+ return
+ self.masks = self.pattern.get_pattern_lock_masks(self.modules)
+
+ def on_after_optimizer_step(self):
+ """Implement after optimizer.step().
+
+ Prune the model after optimization.
+ """
+ self.mask_weights()
+ self.global_step += 1
+
+
+@register_pruner('progressive')
+class ProgressivePruner(BasicPruner):
+ """Pruning Pruner.
+
+ A Pruner class derived from BasePruner. In this pruner, mask interpolation will be applied.
+ Mask interpolation is a fine-grained improvement for NxM structured pruning,
+ By adding interval masks between masks of two pruning steps
+
+ Args:
+ modules: A dict {"module_name": Tensor}. Store the pruning modules' weights.
+ config: A config dict object. Contains the pruner information.
+
+ Attributes:
+        Inherit from parent class BasicPruner.
+ """
+
+ def __init__(self, config, modules):
+ """Initialize."""
+ super(ProgressivePruner, self).__init__(config, modules)
+
+ def _init(self):
+ """Auxiliary function for initialization."""
+ self.pattern = get_pattern(self.config, self.modules)
+ self.scheduler = get_scheduler(self.config)
+ self.criterion = get_criterion(self.config, self.modules)
+ self.reg = get_reg(self.config, self.modules, self.pattern)
+        # progressive pruning setup, including parameter checks.
+ self.use_progressive = self.config["progressive"]
+ # progressive parameters
+ # dict passed to Pattern's functions
+ self.progressive_configs = {
+ "progressive_steps": 4,
+ "progressive_type": "scores",
+ "use_global": True
+ }
+ self.progressive_steps = self.progressive_configs["progressive_steps"]
+ self.progressive_type = self.progressive_configs["progressive_type"]
+ self.use_global = self.progressive_configs["use_global"]
+ self.progressive_logger = False
+ self._init_for_progressive()
+
+ def _init_for_progressive(self):
+ """Auxiliary function for initializing progressive pruning."""
+        # detailed progressive parameters are stored in patterns.py
+ # step 1: check if pattern is NxM
+ if "x" not in self.pattern.pattern:
+ raise NotImplementedError(f"Currently progressive only " \
+ f"support NxM and per-channel pruning patterns.")
+
+ # step 2: check if current set up will "degrade" into non-progressive
+ degrading_flag = False
+ if (self.end_step - self.start_step) <= self.progressive_steps or self.progressive_steps <= 1:
+ logger.info("Current progressive setting will degrading to non-progressive pruning.")
+ self.use_progressive = False
+ return
+
+        # step 3: log hyper-parameters and check validity.
+ if self.use_progressive:
+ logger.info(f"Progressive pruning is enabled!")
+ logger.info(f"Progressive pruning steps: {self.progressive_steps}")
+ logger.info(f"Progressive type: {self.progressive_type}")
+ logger.info(f"Progressive balance: {self.use_global}")
+ self.check_progressive_validity()
+ self.pre_masks = copy.deepcopy(self.masks)
+ self.progressive_masks = copy.deepcopy(self.masks)
+ if self.pruning_frequency < self.progressive_steps:##TODO trick
+ self.progressive_steps = self.pruning_frequency
+ # if self.progressive_steps == 3:
+ # self.progressive_steps = 2
+ self.pruning_frequency_progressive = self.progressive_steps
+ else:
+ self.pruning_frequency_progressive = self.pruning_frequency // self.progressive_steps
+ # this is a structural pruning step, it fits self.pruning_frequency
+ self.structured_update_step = 0
+
+ def check_progressive_validity(self):
+ """Check if the settings of progressive pruning are valid."""
+ # check some problematic settings
+ if self.progressive_type == "linear":
+ if self.use_global:
+                # when global progressive pruning is applied, the linear type is contradictory
+                raise NotImplementedError("Global progressive pruning does not support the linear pattern")
+            # for the linear type, progressive_steps must evenly divide the progressive direction
+ for key in self.pattern.block_size.keys():
+ block_size = self.pattern.block_size[key]
+ progressive_direction = max(block_size)
+ if progressive_direction % self.progressive_steps != 0:
+ raise ValueError(
+ f"In layer {key}, its pruning pattern is {block_size}, " \
+ f"while progressive steps {self.progressive_steps} is indivisible.")
+ else:
+ for key in self.pattern.block_size.keys():
+ block_size = self.pattern.block_size[key]
+ total_block_size = block_size[0] * block_size[1]
+ if total_block_size < self.progressive_steps:
+ raise ValueError(
+ f"In layer {key}, its pruning pattern is {block_size}, " \
+ f"while progressive steps {self.progressive_steps} is overflowing.")
+
+ def check_is_pruned_progressive_step(self, step):
+ """Check if a progressive pruning process should be performed at the current step.
+
+ Args:
+            step: An integer representing the current step number.
+
+ Returns:
+ A Boolean.
+ """
+ # used in progressive pruning
+ if step < self.start_step or step > self.end_step:
+ return False
+ if int(step - self.start_step) % self.pruning_frequency_progressive == 0:
+ return True
+ return False
+
+ def update_masks_progressive(self, local_step):
+ """Update the masks in progressive pruning mode at a given local step."""
+ if self.global_step == self.start_step:
+ if self.config['lock_init_sparsity']:
+ self.masks = self.pattern.get_pattern_lock_masks(self.modules)
+ self.init_sparsity_ratio = self.pattern.get_sparsity_ratio(self.masks)
+ self.current_sparsity_ratio = self.init_sparsity_ratio
+
+        # case 1: the step is outside [start_step, end_step], or it is neither a pruning step
+        # nor a progressive pruning step.
+        if (not self.check_is_pruned_step(self.global_step)) and (
+                not self.check_is_pruned_progressive_step(self.global_step)):
+ return
+ if self.current_sparsity_ratio > self.target_sparsity_ratio:
+ return
+
+        # case 2: a step that performs a progressive mask update but is not a structured pruning step (case 3).
+        if self.check_is_pruned_progressive_step(self.global_step) \
+                and not self.check_is_pruned_step(self.global_step):
+ # do not do global pruning, only do the progressive mask update.
+ step_offset = self.global_step - self.structured_update_step
+ progressive_idx = step_offset // self.pruning_frequency_progressive
+ if progressive_idx < (self.progressive_steps - 1):
+ self.progressive_masks = self.pattern.update_progressive_masks(self.pre_masks, self.masks, \
+ self.criterion.scores, \
+ progressive_idx + 1, \
+ self.progressive_configs)
+ else:
+ # in the end, directly use new masks.
+ for n in self.masks.keys():
+ self.progressive_masks[n] = self.masks[n].clone()
+ self.mask_weights_general(self.progressive_masks)
+ if self.progressive_logger:
+ self.print_progressive_sparsity()
+ return
+
+ # case 3: a pruning step, generate new masks, progressive masks also update.
+ tmp_step = self.global_step
+ self.structured_update_step = tmp_step
+ current_target_sparsity_ratio = self.scheduler.update_sparsity_ratio(self.target_sparsity_ratio,
+ self.completed_pruned_cnt,
+ self.total_prune_cnt, self.masks)
+ logger.info(f"current target ratio is {current_target_sparsity_ratio}")
+ self.criterion.on_step_begin()
+ self.completed_pruned_cnt += 1
+ if self.criterion.scores == {}:
+ return
+ for n in self.masks.keys():
+ self.pre_masks[n] = self.masks[n].clone()
+ # update new masks
+        self.masks = self.pattern.get_masks(self.criterion.scores, current_target_sparsity_ratio, self.masks)
+ self.progressive_masks = self.pattern.update_progressive_masks(self.pre_masks, self.masks, \
+ self.criterion.scores, 1, \
+ self.progressive_configs)
+ self.mask_weights_general(self.progressive_masks)
+ if self.progressive_logger:
+ self.print_progressive_sparsity()
+ return
+
+ def on_step_begin(self, local_step):
+ """Update the masks at a given local_step."""
+ """Implement at the start of each step."""
+ if self.handled_global_step == self.global_step:
+ return
+
+ if not self.use_progressive:
+            # if _init_for_progressive() degraded the setup to non-progressive,
+            # just call BasicPruner's update_masks().
+ self.update_masks(local_step)
+ else:
+ self.update_masks_progressive(local_step)
+ self.handled_global_step = self.global_step
+
+ def on_before_optimizer_step(self):
+ """Implement before optimizer.step()."""
+ self.reg.on_before_optimizer_step()
+
+ def on_after_optimizer_step(self):
+ """Prune the model after optimization."""
+        ## the order of the following three lines cannot be exchanged
+ self.reg.on_after_optimizer_step()
+ if not self.use_progressive:
+ self.mask_weights()
+ else:
+ self.mask_weights_general(self.progressive_masks)
+ self.criterion.on_after_optimizer_step()
+ self.global_step += 1
+
+ def print_progressive_sparsity(self):
+ """Output the progressive sparsity."""
+ cur_sp = self.pattern.get_sparsity_ratio_progressive(self.progressive_masks)
+ logger.info("Step: {} -> Current progressive sparsity: {}".format(self.global_step, cur_sp))
diff --git a/neural_compressor/pruner/regs.py b/neural_compressor/pruner/regs.py
new file mode 100644
index 00000000000..8ce97e4c87e
--- /dev/null
+++ b/neural_compressor/pruner/regs.py
@@ -0,0 +1,128 @@
+"""Regularizer."""
+# !/usr/bin/env python
+# -*- coding: utf-8 -*-
+#
+# Copyright (c) 2022 Intel Corporation
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from .patterns import BasePattern
+from neural_compressor.utils.utility import LazyImport
+torch = LazyImport('torch')
+
+REGS = {}
+
+
+def register_reg(name):
+ """Register a regularizator to the registry."""
+
+ def register(reg):
+ REGS[name] = reg
+ return reg
+
+ return register
+
+
+def get_reg_type(config):
+ """Obtain the regularizer type."""
+ for key in REGS.keys(): ##assume there is only one reg
+ if config.get(key, None) != None:
+ return key
+ return None
+
+
+def get_reg(config, modules, pattern):
+ """Get registered regularizator class."""
+ reg_type = config["reg_type"]
+ if reg_type == None:
+ return BaseReg(config, modules, pattern)
+ if reg_type not in REGS.keys():
+ assert False, f"regularizator does not support {reg_type}, currently only support {REGS.keys()}"
+ return REGS[reg_type](config, modules, pattern, config["reg_coeff"])
+
+
+class BaseReg:
+ """Regularizer.
+
+ The class which performs regularization.
+
+ Args:
+ modules: A dict {"module_name": Tensor}. Store the pruning modules' weights.
+ config: A config dict object that includes information of the regularizer.
+        pattern: A BasePattern object. The pruning pattern that defines weight grouping.
+ """
+
+ def __init__(self, config: dict, modules: dict, pattern: BasePattern):
+ """Initialize."""
+ self.modules = modules
+ self.config = config
+ self.pattern = pattern
+
+ def on_before_optimizer_step(self):
+ """Implement before optimizer.step()."""
+ pass
+
+ def on_after_optimizer_step(self):
+ """Implement after optimizer.step()."""
+ pass
+
+
+@register_reg("group_lasso")
+class GroupLasso(BaseReg):
+ """Regularizer.
+
+ A regularizer class derived from BaseReg. In this class, the Group-lasso regularization will be performed.
+ Group-lasso is a variable-selection and regularization method.
+
+ Args:
+ modules: A dict {"module_name": Tensor}. Store the pruning modules' weights.
+ config: A config dict object that includes information of the regularizer.
+        pattern: A BasePattern object. The pruning pattern that defines weight grouping.
+
+ Attributes:
+ reg_terms: A dict {"module_name": Tensor} of regularization terms.
+        alpha: A float representing the coefficient related to group lasso.
+ """
+
+ def __init__(self, config: dict, modules: dict, pattern: BasePattern, coeff):
+ """Initialize."""
+ super(GroupLasso, self).__init__(config, modules, pattern)
+ assert "x" in self.config.pattern, "group lasso only supports NXM pattern"
+ self.reg_terms = {}
+ self.alpha = float(coeff)
+ assert self.alpha >= 0, "group lasso only supports positive coeff"
+
+ def on_before_optimizer_step(self):
+ """Calculate the group-lasso score map."""
+ with torch.no_grad():
+ if self.pattern.invalid_layers == None:
+ self.pattern.check_layer_validity()
+ for key in self.modules.keys():
+ if key in self.pattern.invalid_layers:
+ continue
+ grad = self.modules[key].weight.grad
+ reg_term = self.pattern.reshape_orig_to_pattern(grad, key)
+ reg_term = self.alpha / (torch.norm(reg_term, p=2, dim=[1, 3]) + 1e-12)
+ reg_term[torch.isinf(reg_term)] = 0.0
+ self.reg_terms[key] = reg_term
+
+    def on_after_optimizer_step(self):  ## decoupled from gradient descent
+ """Perform group lasso regularization after optimization."""
+ with torch.no_grad():
+ for key in self.modules.keys():
+ if key in self.pattern.invalid_layers:
+ continue
+ reg_term = self.pattern.reshape_reduced_to_orig(self.reg_terms[key], key,
+ self.modules[key].weight.shape)
+ self.modules[key].weight -= reg_term
+
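As a rough, pattern-free illustration of the shrinkage that `GroupLasso` performs: gradients are grouped into blocks, each group's L2 norm is computed, and a term proportional to the inverse norm is subtracted from the weights after the optimizer update. The block shape and `alpha` below are assumed values; the real class delegates the reshaping to the pattern object:

```python
import torch

alpha = 1e-4
weight = torch.randn(8, 16, requires_grad=True)
loss = (weight ** 2).sum()                      # stand-in for a real training loss
loss.backward()

grad = weight.grad.reshape(2, 4, 16)            # group rows into blocks of 4
group_norm = torch.norm(grad, p=2, dim=1)       # L2 norm per (block, column) group
shrink = alpha / (group_norm + 1e-12)           # groups with small gradients shrink more
shrink = shrink.repeat_interleave(4, dim=0)     # broadcast back to the weight shape
with torch.no_grad():
    weight -= shrink                            # the post-optimizer shrinkage step
```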
diff --git a/neural_compressor/pruner/schedulers.py b/neural_compressor/pruner/schedulers.py
new file mode 100644
index 00000000000..78e985da05f
--- /dev/null
+++ b/neural_compressor/pruner/schedulers.py
@@ -0,0 +1,177 @@
+"""scheduler module."""
+# !/usr/bin/env python
+# -*- coding: utf-8 -*-
+#
+# Copyright (c) 2022 Intel Corporation
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import math
+
+SCHEDULERS = {}
+
+
+def register_scheduler(name):
+ """Class decorator used to register a Scheduler subclass to the registry.
+
+ Decorator function used before a Scheduler subclass.
+ Make sure that the Scheduler class decorated by this function can be registered in SCHEDULERS.
+
+ Args:
+        name: A string. Defines the scheduler type.
+
+    Returns:
+        register: A decorator that registers the decorated class in SCHEDULERS and returns it.
+ """
+
+ def register(scheduler):
+ SCHEDULERS[name] = scheduler
+ return scheduler
+
+ return register
+
+
+def get_scheduler(config):
+ """Get registered scheduler class.
+
+ Get a scheduler object from SCHEDULERS.
+
+ Args:
+ config: A config dict object. Contains the scheduler information.
+
+ Returns:
+ A Scheduler object.
+ """
+ name = "iterative"
+ if config.start_step == config.end_step:
+ name = "oneshot"
+ return SCHEDULERS[name](config)
+
+
+class PruningScheduler:
+ """Pruning Scheduler.
+
+ The class which defines a sparsity changing process during pruning.
+ Mainly contains two types:
+ 1. iterative scheduler. Prune the model from dense to target sparsity gradually.
+ 2. one-shot scheduler. Prune the model in a single step and reach the target sparsity.
+
+ Args:
+ config: A config dict object. Contains the scheduler information.
+
+ Attributes:
+ config: A config dict object. Contains the scheduler information.
+ """
+
+ def __init__(self, config):
+ """Initialize."""
+ self.config = config
+
+ def update_sparsity_ratio(self, target_ratio, current_prune_step, total_prune_steps, masks, init_ratio=0.0):
+ """To be implemented in subclasses."""
+ raise NotImplementedError
+
+
+@register_scheduler('oneshot')
+class OneshotScheduler(PruningScheduler):
+ """Pruning Scheduler.
+
+    A Scheduler class derived from PruningScheduler.
+ Prune the model to target sparsity once.
+
+ Args:
+ config: A config dict object. Contains the scheduler information.
+
+ Attributes:
+        Inherit from parent class PruningScheduler.
+ """
+
+ def __init__(self, config):
+ """Initialize."""
+ super(OneshotScheduler, self).__init__(config)
+
+ def update_sparsity_ratio(self, target_ratio, current_prune_step, total_prune_steps, masks, init_ratio=0.0):
+ """Update sparsity ratio.
+
+ Args:
+ target_ratio: A float representing the sparsity ratio after pruning.
+ current_prune_step: An integer representing the current pruning step.
+ total_prune_steps: An integer representing the total number of steps of the pruning process.
+ masks: A dict {"module_name": Tensor} that stores the masks for modules' weights.
+ init_ratio: A float representing the sparsity ratio before pruning.
+
+        Returns:
+ A float representing the sparsity ratio that the model will reach after the next pruning step.
+ """
+ return target_ratio
+
+
+@register_scheduler('iterative')
+class IterativeScheduler(PruningScheduler):
+ """Pruning Scheduler.
+
+    A Scheduler class derived from PruningScheduler.
+ Prune the model from dense to target sparsity in several steps.
+
+ Args:
+ config: A config dict object. Contains the scheduler information.
+
+ Attributes:
+        Inherit from parent class PruningScheduler.
+ """
+
+ def __init__(self, config):
+ """Initialize."""
+ super(IterativeScheduler, self).__init__(config)
+
+ def update_sparsity_ratio(self, target_ratio, current_prune_step, total_prune_steps, masks,
+ init_sparsity_ratio=0.0):
+ """Obtain new target sparsity ratio according to the step.
+
+ Args:
+ target_ratio: A float. The target sparsity ratio.
+            current_prune_step: An integer. The current pruning step.
+            total_prune_steps: An integer. The total number of steps in the pruning process.
+            masks: A dict{"module_name": Tensor}. The masks for modules' weights.
+            init_sparsity_ratio: A float. The sparsity ratio before pruning.
+
+ Returns:
+ A float representing the target sparsity ratio the model will reach after the next pruning step.
+ """
+ aggressive_ratio = target_ratio
+ aggressive_ratio = min(self.config.max_sparsity_ratio_per_op,
+ aggressive_ratio) ##legacy issue
+
+ decay_type = self.config.sparsity_decay_type
+ if decay_type == "cos":
+ current_target_sparsity = (aggressive_ratio - init_sparsity_ratio) * (
+ 1.0 - math.cos(float(current_prune_step) / total_prune_steps * (math.pi / 2))) + init_sparsity_ratio
+ elif decay_type == "exp":
+ target_dense_change_ratio = ((1.0 - aggressive_ratio) / (1.0 - init_sparsity_ratio)) ** (
+ 1 / total_prune_steps)
+ current_target_sparsity = 1.0 - (
+ 1.0 - init_sparsity_ratio) * target_dense_change_ratio ** current_prune_step
+
+ elif decay_type == "linear":
+ current_target_sparsity = (aggressive_ratio - init_sparsity_ratio) * float(
+ current_prune_step) / total_prune_steps + init_sparsity_ratio
+
+ elif decay_type == "cube":
+ current_target_sparsity = (aggressive_ratio - init_sparsity_ratio) * (
+ (float(current_prune_step) / total_prune_steps) ** 3) + init_sparsity_ratio
+ else:
+ assert False, "{} is not supported".format(decay_type)
+
+ current_target_sparsity = min(target_ratio, current_target_sparsity)
+ return current_target_sparsity
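The four decay curves implemented in `IterativeScheduler` can be tabulated standalone. This sketch recomputes them with assumed values (dense start, 0.9 target, 10 pruning steps) so the shapes of the schedules are easy to compare:

```python
import math

init, target, total = 0.0, 0.9, 10

def ratio(decay_type: str, step: int) -> float:
    """Mirror of the decay formulas above (simplified: no per-op cap)."""
    if decay_type == "cos":
        return (target - init) * (1.0 - math.cos(step / total * (math.pi / 2))) + init
    if decay_type == "exp":
        change = ((1.0 - target) / (1.0 - init)) ** (1 / total)
        return 1.0 - (1.0 - init) * change ** step
    if decay_type == "linear":
        return (target - init) * step / total + init
    if decay_type == "cube":
        return (target - init) * (step / total) ** 3 + init
    raise ValueError(decay_type)

for t in ("cos", "exp", "linear", "cube"):
    print(t, [round(ratio(t, s), 3) for s in range(total + 1)])
```

Running this shows that "exp" front-loads pruning, "cube" (and to a lesser degree "cos") back-loads it, and "linear" is uniform.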
diff --git a/neural_compressor/pruner/utils.py b/neural_compressor/pruner/utils.py
new file mode 100644
index 00000000000..5598167dee5
--- /dev/null
+++ b/neural_compressor/pruner/utils.py
@@ -0,0 +1,247 @@
+"""prune utils."""
+# !/usr/bin/env python
+# -*- coding: utf-8 -*-
+#
+# Copyright (c) 2022 Intel Corporation
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import re
+import yaml
+
+try:
+ from neural_compressor.conf.dotdict import DotDict
+except:
+ from .dot_dict import DotDict ##TODO
+from .logger import logger
+
+
+class WeightPruningConfig:
+ """
+ similiar to torch optimizer's interface
+ """
+
+ def __init__(self, pruning_configs=[{}], ##empty dict will use global values
+ target_sparsity=0.9, pruning_type="snip_momentum", pattern="4x1", op_names=[],
+ excluded_op_names=[],
+ start_step=0, end_step=0, pruning_scope="global", pruning_frequency=1,
+ min_sparsity_ratio_per_op=0.0, max_sparsity_ratio_per_op=0.98,
+ sparsity_decay_type="exp", pruning_op_types=['Conv', 'Linear'],
+ **kwargs):
+ self.pruning_configs = pruning_configs
+ self._weight_compression = DotDict({
+ 'target_sparsity': target_sparsity,
+ 'pruning_type': pruning_type,
+ 'pattern': pattern,
+ 'op_names': op_names,
+ 'excluded_op_names': excluded_op_names, ##global only
+ 'start_step': start_step,
+ 'end_step': end_step,
+ 'pruning_scope': pruning_scope,
+ 'pruning_frequency': pruning_frequency,
+ 'min_sparsity_ratio_per_op': min_sparsity_ratio_per_op,
+ 'max_sparsity_ratio_per_op': max_sparsity_ratio_per_op,
+ 'sparsity_decay_type': sparsity_decay_type,
+ 'pruning_op_types': pruning_op_types,
+ ##reg_type=None, reduce_type="mean", parameters={"reg_coeff": 0.0}
+ ##'resume_from_pruned_checkpoint': resume_from_pruned_checkpoint ##resume_from_pruned_checkpoint
+ })
+ self._weight_compression.update(kwargs)
+
+ @property
+ def weight_compression(self):
+ return self._weight_compression
+
+ @weight_compression.setter
+ def weight_compression(self, weight_compression):
+ self._weight_compression = weight_compression
+
+
+def check_config(prune_config):
+ """Functions that check key-value is valid to run Pruning object.
+
+ Args:
+ prune_config: A config dict object. Contains Pruning parameters and configurations.
+
+ Returns:
+ None if everything is correct.
+
+ Raises:
+ AssertionError.
+ """
+    assert prune_config['start_step'] >= 0, "start_step should be no less than 0"
+    assert prune_config['end_step'] >= -1, "end_step should be no less than -1"
+    assert prune_config['end_step'] >= prune_config['start_step'], \
+        "end_step should be no less than start_step"
+    assert prune_config['target_sparsity'] >= 0 and prune_config['target_sparsity'] < 1.0, \
+        "target_sparsity should be in range [0, 1)"
+    assert prune_config['pruning_frequency'] > 0, "pruning_frequency should be greater than 0"
+    assert prune_config['max_sparsity_ratio_per_op'] >= 0 and prune_config['max_sparsity_ratio_per_op'] < 1, \
+        "max_sparsity_ratio_per_op should be in range [0, 1)"
+    assert prune_config['pruning_scope'] == "global" or prune_config['pruning_scope'] == "local", \
+        "only 'global' and 'local' pruning scopes are supported"
+ try:
+ prune_config['resume_from_pruned_checkpoint'] = bool(prune_config['resume_from_pruned_checkpoint'])
+ except:
+ assert False, "resume_from_pruned_checkpoint should be bool value"
+ if "x" in prune_config["pattern"]:
+ pattern = prune_config["pattern"].split('_')[-1].split('x')
+ if pattern[0] == "channel" or pattern[1] == "channel":
+ pass
+ else:
+ try:
+ N = int(pattern[0])
+ M = int(pattern[1])
+ except:
+ assert False, "N or M can't convert to int"
+ assert N > 0, "N should be greater than 0"
+ assert M > 0, "M should be greater than 0"
+ if ":" in prune_config["pattern"]:
+ pattern = prune_config["pattern"].split('_')[-1].split(':')
+ try:
+ N = int(pattern[0])
+ M = int(pattern[1])
+ except:
+ assert False, "N or M can't convert to int"
+ assert N > 0, "N should be greater than 0"
+ assert M > N, "M should be greater than N"
+ max_ratio = float(N) / M
+ assert prune_config['target_sparsity'] <= max_ratio, \
+ "in N:M pattern, the max sparsity is N/M={}".format(max_ratio)
+ prune_config['max_sparsity_ratio_per_op'] = min(max_ratio, prune_config['max_sparsity_ratio_per_op'])
+    if prune_config['reg_coeff'] is not None:
+        prune_config['reg_coeff'] = float(prune_config['reg_coeff'])
+        assert prune_config['reg_coeff'] >= 0, "reg_coeff should be no less than 0"
+ assert prune_config["min_sparsity_ratio_per_op"] >= 0 and prune_config["min_sparsity_ratio_per_op"] <= \
+ prune_config['max_sparsity_ratio_per_op'], \
+ "min_sparsity_ratio_per_op should in[0, max_sparsity_ratio_per_op]"
+
+
+def reset_none_to_default(obj, key, default):
+ """Functions that add up undefined configurations.
+
+ If some configurations are not defined in the configuration, set it to a default value.
+
+ Args:
+ obj: A dict{key: value}
+ key: A string. Key in obj.
+ default: When the key is not in obj, Add key: default item in original obj.
+
+ """
+ if obj == None:
+ return None
+ if isinstance(obj, dict):
+ if (not key in obj.keys()) or obj[key] == None:
+ return default
+ else:
+ return obj[key]
+ else:
+ if not hasattr(obj, key) or getattr(obj, key) == None:
+ return default
+ else:
+ return getattr(obj, key)
+
+
+def update_params(info):
+    """Flatten the 'parameters' sub-dict of a pruner config into top-level keys."""
+    if "parameters" in info.keys():
+ params = info["parameters"]
+ for key in params:
+ info[key] = params[key]
+
+
+def process_and_check_weight_config(val: WeightPruningConfig):
+ default_global_config = {'target_sparsity': 0.9, 'pruning_type': 'snip_momentum', 'pattern': '4x1', 'op_names': [],
+ 'excluded_op_names': [],
+ 'start_step': 0, 'end_step': 0, 'pruning_scope': 'global', 'pruning_frequency': 1,
+ 'min_sparsity_ratio_per_op': 0.0, 'max_sparsity_ratio_per_op': 0.98,
+ 'sparsity_decay_type': 'exp',
+ 'pruning_op_types': ['Conv', 'Linear'],
+
+ }
+ default_local_config = {'resume_from_pruned_checkpoint': False, 'reg_type': None,
+ 'criterion_reduce_type': "mean", 'parameters': {"reg_coeff": 0.0}}
+
+ params_default_config = {"reg_coeff": 0.0}
+
+ default_config = {}
+ default_config.update(default_global_config)
+ default_config.update(default_local_config)
+ default_config.update(params_default_config)
+
+ pruning_configs = val.pruning_configs
+ pruners_info = []
+ global_info = val.weight_compression
+    if len(pruning_configs) == 0:  ## no local configs; use the global config as the single pruner
+ pruner_info = global_info
+ for key in default_config.keys():
+ pruner_info[key] = reset_none_to_default(pruner_info, key, default_config[key])
+ update_params(pruner_info)
+ check_config(pruner_info)
+ pruner_info = DotDict(pruner_info)
+ pruners_info.append(pruner_info)
+
+    else:  ## TODO needs update; in this mode, the global op names are ignored
+ for pruner_info in pruning_configs:
+ for key in default_config.keys():
+ pruner_info[key] = reset_none_to_default(pruner_info, key, global_info[key])
+ pruner_info[key] = reset_none_to_default(pruner_info, key, default_config[key])
+ update_params(pruner_info)
+ check_config(pruner_info)
+ pruner_info = DotDict(pruner_info)
+ pruners_info.append(pruner_info)
+
+ return pruners_info
+
+
+def process_config(config):
+ """Obtain a config dict object from a config file.
+
+ Args:
+ config: A string. The path to configuration file.
+
+ Returns:
+ A config dict object.
+ """
+ if isinstance(config, WeightPruningConfig):
+ return process_and_check_weight_config(config)
+ else:
+ assert False, f"not supported type {config}"
+
+
+def parse_to_prune(config, model):
+ """Keep target pruned layers."""
+ modules = {}
+ if config["op_names"] == None or config["op_names"] == []:
+ config["op_names"] = [".*"]
+ for raw in config["op_names"]:
+ try:
+ pattern = re.compile(raw)
+ except:
+ assert False, f"regular expression match does not support {raw}"
+ for name, module in filter(lambda t: pattern.search(t[0]), model.named_modules()):
+ for layer_type in config["pruning_op_types"]:
+ if layer_type in type(module).__name__:
+ modules[name] = module
+ break
+    ## drop the layers excluded from pruning
+ exclude_names = config["excluded_op_names"]
+ patterns = [re.compile(s) for s in exclude_names]
+ if len(patterns) <= 0:
+ return modules
+ new_modules = {}
+ for name in modules.keys():
+ if any([p.search(name) for p in patterns]):
+ continue
+ new_modules[name] = modules[name]
+ return new_modules
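Mirroring the test added later in this patch, a WeightPruningConfig combines one global config with optional per-layer overrides, and `process_config` expands it into validated per-pruner dicts. The values below are illustrative; note that a 2:4 pattern caps the reachable sparsity at N/M = 0.5, which `check_config` enforces:

```python
from neural_compressor.pruner.utils import WeightPruningConfig, process_config

config = WeightPruningConfig(
    pruning_configs=[{"op_names": ["layer1.*"], "pattern": "2:4",
                      "target_sparsity": 0.5}],
    target_sparsity=0.8,            # global default for fields a local config omits
    start_step=1, end_step=100, pruning_frequency=10,
)
for info in process_config(config):
    print(info["op_names"], info["pattern"], info["target_sparsity"])
```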
diff --git a/neural_compressor/pruning.py b/neural_compressor/pruning.py
index 3205ffff99b..0094b0fcdf9 100644
--- a/neural_compressor/pruning.py
+++ b/neural_compressor/pruning.py
@@ -1,7 +1,8 @@
-#!/usr/bin/env python
+"""Pruning."""
+# !/usr/bin/env python
# -*- coding: utf-8 -*-
#
-# Copyright (c) 2021 Intel Corporation
+# Copyright (c) 2022 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -14,144 +15,188 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+from neural_compressor.utils.utility import LazyImport
+LazyImport('torch.nn')
+torch = LazyImport('torch')
-from .utils import logger
-from .utils.utility import singleton
-from .experimental import Pruning as ExpPruning
-from deprecated import deprecated
+from neural_compressor.pruner.utils import process_config, parse_to_prune,\
+ check_config, update_params
+from neural_compressor.pruner.pruners import get_pruner
+from neural_compressor.utils import logger
+import re
+from neural_compressor.pruner.utils import WeightPruningConfig
-@singleton
class Pruning:
- """This is base class of pruning object.
+ """Pruning.
- Since DL use cases vary in the accuracy metrics (Top-1, MAP, ROC etc.), loss criteria
- (<1% or <0.1% etc.) and pruning objectives (performance, memory footprint etc.).
- Pruning class provides a flexible configuration interface via YAML for users to specify
- these parameters.
+    The main class that users use in their code to perform pruning.
+    It contains at least one Pruner object.
Args:
- conf_fname_or_obj (string or obj): The path to the YAML configuration file or
- Pruning_Conf class containing accuracy goal, pruning objective and related
- dataloaders etc.
-
+        config: A WeightPruningConfig object that contains the pruning settings. For examples, please refer to
+            https://github.com/intel/neural-compressor/tree/master/examples/pytorch/nlp/huggingface_models/text-classification/pruning/pytorch_pruner/eager/
+
+ Attributes:
+ model: The model object to prune.
+        pruners: A list of Pruner objects.
+        pruners_info: A config dict object that contains the pruners' information.
"""
- def __init__(self, conf_fname_or_obj):
- self.exp_pruner = ExpPruning(conf_fname_or_obj)
+ def __init__(self, config):
+ """Initialize."""
+ self.model = None
+ self.pruners = []
+ self.pruners_info = process_config(config)
- def on_epoch_begin(self, epoch):
- """ called on the begining of epochs"""
- self.exp_pruner.on_epoch_begin(epoch)
+ def update_config(self, *args, **kwargs):
+ """Add user-defined arguments to the original configurations.
+
+        The original pruning configuration comes from the WeightPruningConfig passed at construction.
+        However, users can still modify configurations by passing key-value arguments to this function.
+        Please note that the keys of the key-value arguments must exist in the current configuration.
+ """
+ for item in self.pruners_info:
+ for key in kwargs:
+ if key in item.keys():
+ item[key] = kwargs[key]
+
+ update_params(item)
+ check_config(item)
+
+ # def _call_pruners(self, func):
+ # """Function which decorates the Pruning class's functions.
+ #
+ # It can simplify codes by calling same-name functions in Pruning's Pruner objects.
+ # For example, when it decorates on_step_begin function of Pruning,
+ # it automatically calls its Pruners' on_step_begin functions without a "for" code.
+    #     However, when this trick is enabled, the pylint validation on INC cannot pass, so it is commented out.
+ # """
+    #     def wrapper(self, *args, **kw):
+ # func_name = f"{func.__name__}"
+ # func(self, *args, **kw)
+ # for prune in self.pruners:
+ # prun_func = getattr(prune, func_name)
+ # prun_func(*args, **kw)
+ #
+    #         return wrapper
+
+ def get_sparsity_ratio(self):
+ """Calculate sparsity ratio of a module/layer.
- def on_step_begin(self, batch_id):
- """ called on the begining of batches"""
- self.exp_pruner.on_step_begin(batch_id)
+ Returns:
+ Three floats.
+ elementwise_over_matmul_gemm_conv refers to zero elements' ratio in pruning layers.
+ elementwise_over_all refers to zero elements' ratio in all layers in the model.
+ blockwise_over_matmul_gemm_conv refers to all-zero blocks' ratio in pruning layers.
+ """
+ pattern_sparsity_cnt = 0
+ element_sparsity_cnt = 0
+ for pruner in self.pruners:
+ modules = pruner.modules
+ sparsity_ratio = pruner.pattern.get_sparsity_ratio(pruner.masks)
+ cnt = 0
+ for key in modules.keys():
+ cnt += modules[key].weight.numel()
+ pattern_sparsity_cnt += int(cnt * sparsity_ratio)
+ for key in pruner.masks.keys():
+ element_sparsity_cnt += torch.sum(pruner.masks[key] == 0).data.item()
+
+ linear_conv_cnt = 0
+ param_cnt = 0
+ for name, module in self.model.named_modules():
+ if type(module).__name__ in ["Linear"] or re.search(r'Conv.d', type(module).__name__) != None:
+ linear_conv_cnt += module.weight.numel()
+
+ for n, param in self.model.named_parameters():
+ param_cnt += param.numel()
+ if linear_conv_cnt == 0:
+ blockwise_over_matmul_gemm_conv = 0
+ elementwise_over_matmul_gemm_conv = 0
+ else:
+ blockwise_over_matmul_gemm_conv = float(pattern_sparsity_cnt) / linear_conv_cnt
+ elementwise_over_matmul_gemm_conv = float(element_sparsity_cnt) / linear_conv_cnt
+ if param_cnt == 0:
+ elementwise_over_all = 0
+ else:
+ elementwise_over_all = float(
+ element_sparsity_cnt) / param_cnt
+
+ return elementwise_over_matmul_gemm_conv, elementwise_over_all, blockwise_over_matmul_gemm_conv
+
+ def _generate_pruners(self):
+ """Obtain Pruner objects."""
+ assert isinstance(self.model, torch.nn.Module)
+
+ for info in self.pruners_info:
+ modules = parse_to_prune(info, self.model)
+ if modules == {}:
+ logger.warning("one pruner hooks no layers, please have a check")
+
+ self.pruners.append(get_pruner(info, modules))
+ info['modules'] = [key for key in modules.keys()]
+ info['len_of_modules'] = len(info['modules'])
+ logger.info(info)
+
+ # @_call_pruners
+ def on_train_begin(self):
+ """Implement at the beginning of training process.
+
+ Before training, ensure that pruners are generated.
+ """
+        self._generate_pruners()  ## TODO is there a better place to put this
+ # @_call_pruners
+ def on_epoch_begin(self, epoch):
+ """Implement at the beginning of every epoch."""
+ for pruner in self.pruners:
+ pruner.on_epoch_begin(epoch)
+
+ # @_call_pruners
+ def on_step_begin(self, local_step):
+ """Implement at the beginning of every step."""
+ for pruner in self.pruners:
+ pruner.on_step_begin(local_step)
+
+ # @_call_pruners
+ def on_before_optimizer_step(self):
+ """Implement before optimizer.step()."""
+ for pruner in self.pruners:
+ pruner.on_before_optimizer_step()
+
+ # @_call_pruners
def on_step_end(self):
- """ called on the end of batches"""
- self.exp_pruner.on_step_end()
+ """Implement at the end of every step."""
+ for pruner in self.pruners:
+ pruner.on_step_end()
+ # @_call_pruners
def on_epoch_end(self):
- """ called on the end of epochs"""
- self.exp_pruner.on_epoch_end()
-
- @deprecated(version='2.0', reason="please use neural_compressor.prepare and neural_compressor.fit instead")
- def __call__(self, model, train_dataloader=None, pruning_func=None, eval_dataloader=None,
- eval_func=None):
- """The main entry point of pruning.
-
- This interface currently only works on pytorch
- and provides three usages:
- a) Fully yaml configuration: User specifies all the info through yaml,
- including dataloaders used in training and evaluation phases
- and pruning tuning settings.
-
- For this usage, only model parameter is mandatory.
-
- b) Partial yaml configuration: User specifies dataloaders used in training
- and evaluation phase by code.
- The tool provides built-in dataloaders and evaluators, user just need provide
- a dataset implemented __iter__ or __getitem__ methods and invoke dataloader()
- with dataset as input parameter to create neural_compressor dataloader before calling this
- function.
-
- After that, User specifies fp32 "model", train dataset "train_dataloader"
- and evaluation dataset "eval_dataloader".
- The trained and pruned model is evaluated with "eval_dataloader"
- with evaluation metrics specified in the configuration file. The evaluation tells
- the tuner whether the pruned model meets the accuracy criteria. If not,
- the tuner starts a new training and tuning flow.
-
- For this usage, model, q_dataloader and eval_dataloader parameters are mandatory.
-
- c) Partial yaml configuration: User specifies dataloaders used in training phase
- by code.
- This usage is quite similar with b), just user specifies a custom "eval_func"
- which encapsulates the evaluation dataset by itself.
- The trained and pruned model is evaluated with "eval_func".
- The "eval_func" tells the tuner whether the pruned model meets
- the accuracy criteria. If not, the Tuner starts a new training and tuning flow.
-
- For this usage, model, q_dataloader and eval_func parameters are mandatory.
-
- Args:
- model (object): For PyTorch model, it's torch.nn.model
- instance.
- train_dataloader (generator): Data loader for training. It is iterable
- and should yield a tuple (input, label) for
- training dataset containing label,
- or yield (input, _) for label-free training
- dataset. The input could be a object, list,
- tuple or dict, depending on user implementation,
- as well as it can be taken as model input.
- pruning_func (function, optional): Training function for pruning.
- This function takes "model" as input parameter
- and executes entire training process with self
- contained training hyper-parameters. If this
- parameter specified, eval_dataloader parameter
- plus metric defined in yaml, or eval_func
- parameter should also be specified at same time.
- eval_dataloader (generator, optional): Data loader for evaluation. It is iterable
- and should yield a tuple of (input, label).
- The input could be a object, list, tuple or
- dict, depending on user implementation,
- as well as it can be taken as model input.
- The label should be able to take as input of
- supported metrics. If this parameter is
- not None, user needs to specify pre-defined
- evaluation metrics through configuration file
- and should set "eval_func" paramter as None.
- Tuner will combine model, eval_dataloader
- and pre-defined metrics to run evaluation
- process.
- eval_func (function, optional): The evaluation function provided by user.
- This function takes model as parameter,
- and evaluation dataset and metrics should be
- encapsulated in this function implementation
- and outputs a higher-is-better accuracy scalar
- value.
-
- The pseudo code should be something like:
-
- def eval_func(model):
- input, label = dataloader()
- output = model(input)
- accuracy = metric(output, label)
- return accuracy
-
- Returns:
- pruned model: best pruned model found, otherwise return None
-
- """
- logger.warning("This API is going to be deprecated. Please import "
- "neural_compressor.experimental.Pruning, initialize an instance of `Pruning`,"
- "set its dataloader and metric attributes, then invoke its __call__ method.")
- self.exp_pruner.model = model
- self.exp_pruner.train_dataloader = train_dataloader
- self.exp_pruner.pruning_func = pruning_func
- self.exp_pruner.eval_dataloader = eval_dataloader
- self.exp_pruner.eval_func = eval_func
- return self.exp_pruner()
-
- fit = __call__
+ """Implement the end of every epoch."""
+ for pruner in self.pruners:
+ pruner.on_epoch_end()
+
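+ # Each callback below simply fans out to every configured pruner; the
+ # commented-out "@_call_pruners" markers appear to reserve a spot for a
+ # shared dispatch decorator.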
+ # @_call_pruners
+ def on_train_end(self):
+ """Implement the end of training phase."""
+ for pruner in self.pruners:
+ pruner.on_train_end()
+
+ # @_call_pruners
+ def on_before_eval(self):
+ """Implement at the beginning of evaluation phase."""
+ for pruner in self.pruners:
+ pruner.on_before_eval()
+
+ # @_call_pruners
+ def on_after_eval(self):
+ """Implement at the end of evaluation phase."""
+ for pruner in self.pruners:
+ pruner.on_after_eval()
+
+ # @_call_pruners
+ def on_after_optimizer_step(self):
+ """Implement after optimizer.step()."""
+ for pruner in self.pruners:
+ pruner.on_after_optimizer_step()
diff --git a/test/pruning/test_pruning.py b/test/pruning/test_pruning.py
index 5871f6bcc34..57fdb9fd604 100644
--- a/test/pruning/test_pruning.py
+++ b/test/pruning/test_pruning.py
@@ -5,127 +5,71 @@
import torch
import torchvision
import torch.nn as nn
-
-from neural_compressor.config import Pruner, PruningConfig
from neural_compressor.data import Datasets
from neural_compressor.experimental.data.dataloaders.pytorch_dataloader import PyTorchDataLoader
-from neural_compressor.training import prepare_compression
-
-
-def build_fake_yaml():
- fake_yaml = """
- model:
- name: imagenet_prune
- framework: pytorch
-
- pruning:
- approach:
- weight_compression:
- initial_sparsity: 0.0
- target_sparsity: 0.97
- start_epoch: 0
- end_epoch: 2
- pruners:
- - !Pruner
- start_epoch: 1
- end_epoch: 2
- prune_type: basic_magnitude
- names: ['layer1.0.conv1.weight']
-
- - !Pruner
- target_sparsity: 0.6
- prune_type: basic_magnitude
- update_frequency: 2
- names: ['layer1.0.conv2.weight']
- """
- with open('fake.yaml', 'w', encoding="utf-8") as f:
- f.write(fake_yaml)
+from neural_compressor.pruning import Pruning, WeightPruningConfig
class TestPruning(unittest.TestCase):
-
model = torchvision.models.resnet18()
- @classmethod
- def setUpClass(cls):
- build_fake_yaml()
-
- @classmethod
- def tearDownClass(cls):
- os.remove('fake.yaml')
- shutil.rmtree('./saved', ignore_errors=True)
- shutil.rmtree('runs', ignore_errors=True)
-
- def test_pruning(self):
- pruner1 = Pruner(start_epoch=1, end_epoch=2, names=['layer1.0.conv1.weight'])
- pruner2 = Pruner(target_sparsity=0.6, update_frequency=2, names=['layer1.0.conv2.weight'])
- conf = PruningConfig(pruners=[pruner1, pruner2], end_epoch=2)
- datasets = Datasets('pytorch')
- dummy_dataset = datasets['dummy'](shape=(100, 3, 224, 224), low=0., high=1., label=True)
- dummy_dataloader = PyTorchDataLoader(dummy_dataset)
- compression_manager = prepare_compression(self.model, conf)
- model = compression_manager.model
+ def test_pruning_basic(self):
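+ # Each entry customizes pruning for the ops matched by "op_names"; keys left
+ # unset fall back to the global WeightPruningConfig arguments (e.g. target_sparsity=0.8).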
+ local_configs = [
+ {
+ "op_names": ['layer1.*'],
+ 'target_sparsity': 0.5,
+ "pattern": '8x2',
+ "pruning_type": "magnitude_progressive"
+ },
+ {
+ "op_names": ['layer2.*'],
+ 'target_sparsity': 0.5,
+ 'pattern': '2:4'
+ },
+ {
+ "op_names": ['layer3.*'],
+ 'target_sparsity': 0.7,
+ 'pattern': '5x1',
+ "pruning_type": "snip_progressive"
+ }
+ ]
+ config = WeightPruningConfig(
+ local_configs,
+ target_sparsity=0.8
+ )
+ prune = Pruning(config)
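+ # update_config pushes keyword overrides into every pruner built from the config.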
+ prune.update_config(start_step=1, end_step=10)
+ prune.model = self.model
- epochs = 2
- iters = 3
criterion = nn.CrossEntropyLoss()
- optimizer = torch.optim.SGD(model.parameters(), lr=0.0001)
- for nepoch in range(epochs):
- model.train()
- cnt = 0
- compression_manager.callbacks.on_epoch_begin(nepoch)
- for image, target in dummy_dataloader:
- compression_manager.callbacks.on_step_begin(cnt)
- print('.', end='')
- cnt += 1
- output = model(image)
- loss = criterion(output, target)
- optimizer.zero_grad()
- loss.backward()
- optimizer.step()
- compression_manager.callbacks.on_step_end()
- if cnt >= iters:
- break
- compression_manager.callbacks.on_epoch_end()
-
- model.save("./saved")
-
- def test_pruning_external(self):
- from neural_compressor.experimental import common
- from neural_compressor import Pruning
- from neural_compressor.conf.config import PruningConf
- pruners = [Pruner(1,3,names=['layer1.0.conv1.weight']),
- Pruner(target_sparsity=0.6,update_frequency=2,names=['layer1.0.conv2.weight'])]
- conf = PruningConfig(pruners)
-
+ optimizer = torch.optim.SGD(self.model.parameters(), lr=0.0001)
datasets = Datasets('pytorch')
- dummy_dataset = datasets['dummy'](shape=(100, 3, 224, 224), low=0., high=1., label=True)
+ dummy_dataset = datasets['dummy'](shape=(10, 3, 224, 224), low=0., high=1., label=True)
dummy_dataloader = PyTorchDataLoader(dummy_dataset)
- compression_manager = prepare_compression(self.model, conf)
- model = compression_manager.model
- epochs = 2
- iters = 3
- criterion = nn.CrossEntropyLoss()
- optimizer = torch.optim.SGD(model.parameters(), lr=0.0001)
- for nepoch in range(epochs):
- model.train()
- cnt = 0
- compression_manager.callbacks.on_epoch_begin(nepoch)
+ prune.on_train_begin()
+ prune.update_config(pruning_frequency=4)
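+ # Per-step callback order: on_step_begin -> forward/backward ->
+ # on_before_optimizer_step -> optimizer.step() -> on_after_optimizer_step -> on_step_end.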
+ for epoch in range(2):
+ self.model.train()
+ prune.on_epoch_begin(epoch)
+ local_step = 0
for image, target in dummy_dataloader:
- compression_manager.callbacks.on_step_begin(cnt)
- print('.', end='')
- cnt += 1
- output = model(image)
+ prune.on_step_begin(local_step)
+ output = self.model(image)
loss = criterion(output, target)
optimizer.zero_grad()
loss.backward()
+ prune.on_before_optimizer_step()
optimizer.step()
- compression_manager.callbacks.on_step_end()
- if cnt >= iters:
- break
- compression_manager.callbacks.on_epoch_end()
- model.save("./saved")
+ prune.on_after_optimizer_step()
+ prune.on_step_end()
+ local_step += 1
+
+ prune.on_epoch_end()
+ prune.get_sparsity_ratio()
+ prune.on_train_end()
+ prune.on_before_eval()
+ prune.on_after_eval()
if __name__ == "__main__":
diff --git a/test/pruning/test_pruning_config.py b/test/pruning/test_pruning_config.py
new file mode 100644
index 00000000000..4430affbb49
--- /dev/null
+++ b/test/pruning/test_pruning_config.py
@@ -0,0 +1,80 @@
+import os
+import shutil
+import unittest
+
+import torch
+import torchvision
+import torch.nn as nn
+
+from neural_compressor.data import Datasets
+from neural_compressor.experimental.data.dataloaders.pytorch_dataloader import PyTorchDataLoader
+from neural_compressor.pruning import Pruning, WeightPruningConfig
+
+
+class TestPytorchPruning(unittest.TestCase):
+ model = torchvision.models.resnet18()
+
+ def test_pruning_class_config(self):
+ local_configs = [
+ {
+ "op_names": ['layer1.*', 'layer2.*'],
+ "excluded_op_names": ['downsample.*'],
+ 'target_sparsity': 0.6,
+ "pattern": 'channelx1',
+ "pruning_type": "snip_progressive",
+ "pruning_scope": "local",
+ "start_step": 0,
+ "end_step": 10
+ },
+ {
+ "op_names": ['layer3.*'],
+ "pruning_type": "pattern_lock"
+ }
+ ]
+ config = WeightPruningConfig(
+ local_configs,
+ pruning_frequency=2,
+ target_sparsity=0.8,
+ )
+ prune = Pruning(config)
+ prune.model = self.model
+
+ criterion = nn.CrossEntropyLoss()
+ optimizer = torch.optim.SGD(self.model.parameters(), lr=0.0001)
+ datasets = Datasets('pytorch')
+ dummy_dataset = datasets['dummy'](shape=(12, 3, 224, 224), low=0., high=1., label=True)
+ dummy_dataloader = PyTorchDataLoader(dummy_dataset)
+
+ prune.on_train_begin()
+ prune.update_config(pruning_frequency=4)
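+ # Local (per-op) settings win over globals: pruner 0 keeps its own
+ # target_sparsity (0.6) while pruner 1 inherits the global 0.8.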
+ assert prune.pruners[0].config['pruning_frequency'] == 4
+ assert prune.pruners[0].config['target_sparsity'] == 0.6
+ assert prune.pruners[1].config['target_sparsity'] == 0.8
+ assert prune.pruners[0].config['pattern'] == "channelx1"
+ assert prune.pruners[1].config['pruning_type'] == 'pattern_lock'
+
+ for epoch in range(1):
+ self.model.train()
+ prune.on_epoch_begin(epoch)
+ local_step = 0
+ for image, target in dummy_dataloader:
+ prune.on_step_begin(local_step)
+ output = self.model(image)
+ loss = criterion(output, target)
+ optimizer.zero_grad()
+ loss.backward()
+ prune.on_before_optimizer_step()
+ optimizer.step()
+ prune.on_after_optimizer_step()
+ prune.on_step_end()
+ local_step += 1
+
+ prune.on_epoch_end()
+ prune.get_sparsity_ratio()
+ prune.on_train_end()
+ prune.on_before_eval()
+ prune.on_after_eval()
+
+
+if __name__ == "__main__":
+ unittest.main()
diff --git a/test/pruning/test_pruning_criteria.py b/test/pruning/test_pruning_criteria.py
new file mode 100644
index 00000000000..03a54d60d7c
--- /dev/null
+++ b/test/pruning/test_pruning_criteria.py
@@ -0,0 +1,87 @@
+import os
+import shutil
+import unittest
+
+import torch
+import torchvision
+import torch.nn as nn
+
+from neural_compressor.data import Datasets
+from neural_compressor.experimental.data.dataloaders.pytorch_dataloader import PyTorchDataLoader
+from neural_compressor.pruning import Pruning, WeightPruningConfig
+
+
+class TestPruningCriteria(unittest.TestCase):
+ model = torchvision.models.resnet18()
+
+ def test_pruning_criteria(self):
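+ # Exercises three pruning criteria: magnitude (weight values), snip
+ # (gradient-based saliency), and snip_momentum (saliency accumulated with momentum).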
+ local_configs = [
+ {
+ "op_names": ['layer1.*'],
+ 'target_sparsity': 0.4,
+ "pattern": '8x2',
+ "pruning_type": "magnitude_progressive",
+ "pruning_scope": "local",
+ "sparsity_decay_type": "cube"
+ },
+ {
+ "op_names": ['layer2.*'],
+ 'target_sparsity': 0.45,
+ 'pattern': '2:4',
+ "pruning_type": "snip",
+ 'start_step': 6,
+ 'end_step': 6
+ },
+ {
+ "op_names": ['layer3.*'],
+ 'excluded_op_names': ['downsample.*'],
+ 'target_sparsity': 0.7,
+ 'pattern': '4x1',
+ "pruning_type": "snip_momentum_progressive",
+ "pruning_frequency": 4,
+ "min_sparsity_ratio_per_op": 0.5,
+ "max_sparsity_ratio_per_op": 0.8,
+ }
+ ]
+ config = WeightPruningConfig(
+ local_configs,
+ target_sparsity=0.8,
+ sparsity_decay_type="cube"
+ )
+ prune = Pruning(config)
+ prune.update_config(start_step=1, end_step=10)
+ prune.model = self.model
+
+ criterion = nn.CrossEntropyLoss()
+ optimizer = torch.optim.SGD(self.model.parameters(), lr=0.0001)
+ datasets = Datasets('pytorch')
+ dummy_dataset = datasets['dummy'](shape=(10, 3, 224, 224), low=0., high=1., label=True)
+ dummy_dataloader = PyTorchDataLoader(dummy_dataset)
+
+ prune.on_train_begin()
+ prune.update_config(pruning_frequency=4)
+ for epoch in range(2):
+ self.model.train()
+ prune.on_epoch_begin(epoch)
+ local_step = 0
+ for image, target in dummy_dataloader:
+ prune.on_step_begin(local_step)
+ output = self.model(image)
+ loss = criterion(output, target)
+ optimizer.zero_grad()
+ loss.backward()
+ prune.on_before_optimizer_step()
+ optimizer.step()
+ prune.on_after_optimizer_step()
+ prune.on_step_end()
+ local_step += 1
+
+ prune.on_epoch_end()
+ prune.get_sparsity_ratio()
+ prune.on_train_end()
+ prune.on_before_eval()
+ prune.on_after_eval()
+
+
+if __name__ == "__main__":
+ unittest.main()
diff --git a/test/pruning/test_pruning_patterns.py b/test/pruning/test_pruning_patterns.py
new file mode 100644
index 00000000000..f5f6db91f34
--- /dev/null
+++ b/test/pruning/test_pruning_patterns.py
@@ -0,0 +1,83 @@
+import os
+import shutil
+import unittest
+
+import torch
+import torchvision
+import torch.nn as nn
+
+from neural_compressor.data import Datasets
+from neural_compressor.experimental.data.dataloaders.pytorch_dataloader import PyTorchDataLoader
+from neural_compressor.pruning import Pruning, WeightPruningConfig
+
+
+class TestPruningPatterns(unittest.TestCase):
+ model = torchvision.models.resnet18()
+
+ def test_pruning_pattern(self):
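+ # Pattern notation: 'N:M' prunes N out of every M consecutive weights,
+ # 'NxM' prunes contiguous NxM blocks, and 'channelx1'/'1xchannel' prune whole channels.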
+ local_configs = [
+ {
+ "op_names": ['layer1.*'],
+ 'target_sparsity': 0.5,
+ "pattern": '5:8',
+ "pruning_type": "magnitude"
+ },
+ {
+ "op_names": ['layer2.*'],
+ "pattern": '1xchannel',
+ "pruning_scope": "global"
+ },
+ {
+ "start_step": 2,
+ "end_step": 20,
+ "op_names": ['layer3.*'],
+ 'target_sparsity': 0.666666,
+ 'pattern': '4x2',
+ "pruning_type": "snip_progressive",
+ "pruning_frequency": 5
+ }
+ ]
+ config = WeightPruningConfig(
+ local_configs,
+ target_sparsity=0.8,
+ sparsity_decay_type="cos",
+ excluded_op_names=["downsample.*"],
+ pruning_scope="local",
+ min_sparsity_ratio_per_op=0.1
+ )
+ prune = Pruning(config)
+ prune.update_config(start_step=1, end_step=10)
+ prune.model = self.model
+
+ criterion = nn.CrossEntropyLoss()
+ optimizer = torch.optim.SGD(self.model.parameters(), lr=0.0001)
+ datasets = Datasets('pytorch')
+ dummy_dataset = datasets['dummy'](shape=(10, 3, 224, 224), low=0., high=1., label=True)
+ dummy_dataloader = PyTorchDataLoader(dummy_dataset)
+
+ prune.on_train_begin()
+ for epoch in range(5):
+ self.model.train()
+ prune.on_epoch_begin(epoch)
+ local_step = 0
+ for image, target in dummy_dataloader:
+ prune.on_step_begin(local_step)
+ output = self.model(image)
+ loss = criterion(output, target)
+ optimizer.zero_grad()
+ loss.backward()
+ prune.on_before_optimizer_step()
+ optimizer.step()
+ prune.on_after_optimizer_step()
+ prune.on_step_end()
+ local_step += 1
+
+ prune.on_epoch_end()
+ prune.get_sparsity_ratio()
+ prune.on_train_end()
+ prune.on_before_eval()
+ prune.on_after_eval()
+
+
+if __name__ == "__main__":
+ unittest.main()
diff --git a/test/pruning/test_pruning_regs.py b/test/pruning/test_pruning_regs.py
new file mode 100644
index 00000000000..7da5f44852f
--- /dev/null
+++ b/test/pruning/test_pruning_regs.py
@@ -0,0 +1,98 @@
+import os
+import shutil
+import unittest
+
+import torch
+import torchvision
+import torch.nn as nn
+
+from neural_compressor.data import Datasets
+from neural_compressor.experimental.data.dataloaders.pytorch_dataloader import PyTorchDataLoader
+from neural_compressor.pruning import Pruning, WeightPruningConfig
+
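+# Each config below attaches a group-lasso regularizer via "reg_type"; "parameters"
+# carries its strength ("reg_coeff"), with 0.0 effectively disabling the penalty.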
+local_regs_config = [
+ {
+ "start_step": 0,
+ "end_step": 10,
+ "pruning_type": "magnitude",
+ "op_names": ['layer1.*'],
+ "excluded_op_names": ['layer2.*'],
+ "pruning_scope": "global",
+ "target_sparsity": 0.5,
+ "pattern": "4x1",
+ "reg_type": "group_lasso",
+ "parameters": {'reg_coeff': 0.2}
+ },
+ {
+ "start_step": 1,
+ "end_step": 1,
+ "target_sparsity": 0.5,
+ "pruning_type": "snip_momentum",
+ "pruning_frequency": 2,
+ "op_names": ['layer2.*'],
+ "pruning_scope": "local",
+ "target_sparsity": 0.75,
+ "pattern": "1x1",
+ "sparsity_decay_type": "exp",
+ "reg_type": "group_lasso",
+ "parameters": {'reg_coeff': 0.1}
+ },
+ {
+ "start_step": 2,
+ "end_step": 8,
+ "target_sparsity": 0.1,
+ "pruning_type": "gradient",
+ "pruning_frequency": 2,
+ "op_names": ['fc'],
+ "pruning_scope": "local",
+ "target_sparsity": 0.75,
+ "pattern": "1x1",
+ "sparsity_decay_type": "cube",
+ "reg_type": "group_lasso",
+ "parameters": {'reg_coeff': 0.0}
+ }
+]
+
+fake_snip_config = WeightPruningConfig(local_regs_config, target_sparsity=0.9, start_step=0, \
+ end_step=10, pruning_frequency=1, sparsity_decay_type="exp")
+
+
+class TestPruningRegs(unittest.TestCase):
+ model = torchvision.models.resnet18()
+
+ def test_pruning_regs(self):
+ prune = Pruning(fake_snip_config)
+ prune.update_config(start_step=1)
+ prune.model = self.model
+ criterion = nn.CrossEntropyLoss()
+ optimizer = torch.optim.SGD(self.model.parameters(), lr=0.0001)
+ datasets = Datasets('pytorch')
+ dummy_dataset = datasets['dummy'](shape=(10, 3, 224, 224), low=0., high=1., label=True)
+ dummy_dataloader = PyTorchDataLoader(dummy_dataset)
+ prune.on_train_begin()
+ prune.update_config(pruning_frequency=1)
+ for epoch in range(2):
+ self.model.train()
+ prune.on_epoch_begin(epoch)
+ local_step = 0
+ for image, target in dummy_dataloader:
+ prune.on_step_begin(local_step)
+ output = self.model(image)
+ loss = criterion(output, target)
+ optimizer.zero_grad()
+ loss.backward()
+ prune.on_before_optimizer_step()
+ optimizer.step()
+ prune.on_after_optimizer_step()
+ prune.on_step_end()
+ local_step += 1
+
+ prune.on_epoch_end()
+ prune.get_sparsity_ratio()
+ prune.on_train_end()
+ prune.on_before_eval()
+ prune.on_after_eval()
+
+
+if __name__ == "__main__":
+ unittest.main()
diff --git a/test/pruning/test_pruning_schedulers.py b/test/pruning/test_pruning_schedulers.py
new file mode 100644
index 00000000000..272b766f661
--- /dev/null
+++ b/test/pruning/test_pruning_schedulers.py
@@ -0,0 +1,81 @@
+import os
+import shutil
+import unittest
+
+import torch
+import torchvision
+import torch.nn as nn
+
+from neural_compressor.data import Datasets
+from neural_compressor.experimental.data.dataloaders.pytorch_dataloader import PyTorchDataLoader
+from neural_compressor.pruning import Pruning, WeightPruningConfig
+
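+# Covers iterative pruning schedules: masks update every "pruning_frequency" steps
+# between start_step and end_step while sparsity ramps up per "sparsity_decay_type".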
+local_schedulers_config = [
+ {
+ "start_step": 0,
+ "end_step": 2,
+ "pruning_type": "magnitude",
+ "op_names": ['layer1.*'],
+ "excluded_op_names": ['layer2.*'],
+ "pruning_scope": "global",
+ "target_sparsity": 0.5,
+ "pattern": "4x1"
+ },
+ {
+ "start_step": 1,
+ "end_step": 10,
+ "target_sparsity": 0.5,
+ "pruning_type": "snip_momentum",
+ "pruning_frequency": 2,
+ "op_names": ['layer2.*'],
+ "pruning_scope": "local",
+ "target_sparsity": 0.75,
+ "pattern": "32x1",
+ "sparsity_decay_type": "exp"
+ }
+]
+
+fake_snip_config = WeightPruningConfig(local_schedulers_config, target_sparsity=0.9, start_step=0, \
+ end_step=10, pruning_frequency=1, sparsity_decay_type="exp")
+
+
+class TestPruningSchedulers(unittest.TestCase):
+ model = torchvision.models.resnet18()
+
+ def test_pruning_schedulers(self):
+
+ prune = Pruning(fake_snip_config)
+ prune.update_config(start_step=1)
+ prune.model = self.model
+ criterion = nn.CrossEntropyLoss()
+ optimizer = torch.optim.SGD(self.model.parameters(), lr=0.0001)
+ datasets = Datasets('pytorch')
+ dummy_dataset = datasets['dummy'](shape=(10, 3, 224, 224), low=0., high=1., label=True)
+ dummy_dataloader = PyTorchDataLoader(dummy_dataset)
+ prune.on_train_begin()
+ prune.update_config(pruning_frequency=1)
+ for epoch in range(2):
+ self.model.train()
+ prune.on_epoch_begin(epoch)
+ local_step = 0
+ for image, target in dummy_dataloader:
+ prune.on_step_begin(local_step)
+ output = self.model(image)
+ loss = criterion(output, target)
+ optimizer.zero_grad()
+ loss.backward()
+ prune.on_before_optimizer_step()
+ optimizer.step()
+ prune.on_after_optimizer_step()
+ prune.on_step_end()
+ local_step += 1
+
+ prune.on_epoch_end()
+ prune.get_sparsity_ratio()
+ prune.on_train_end()
+ prune.on_before_eval()
+ prune.on_after_eval()
+
+
+if __name__ == "__main__":
+ unittest.main()
diff --git a/test/pruning/test_pruning_types.py b/test/pruning/test_pruning_types.py
new file mode 100644
index 00000000000..3adbc78452e
--- /dev/null
+++ b/test/pruning/test_pruning_types.py
@@ -0,0 +1,87 @@
+import os
+import shutil
+import unittest
+
+import torch
+import torchvision
+import torch.nn as nn
+
+from neural_compressor.data import Datasets
+from neural_compressor.experimental.data.dataloaders.pytorch_dataloader import PyTorchDataLoader
+from neural_compressor.pruning import Pruning, WeightPruningConfig
+
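+# Covers pruning types: "pattern_lock" freezes the model's existing sparsity pattern,
+# while the progressive snip variants raise block sparsity gradually over the schedule.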
+local_types_config = [
+ {
+ "start_step": 0,
+ "end_step": 0,
+ "pruning_type": "pattern_lock",
+ "op_names": ['layer1.*'],
+ "excluded_op_names": ['layer2.*'],
+ "pruning_scope": "global"
+ },
+ {
+ "start_step": 1,
+ "end_step": 1,
+ "target_sparsity": 0.5,
+ "pruning_type": "snip_momentum_progressive",
+ "pruning_frequency": 2,
+ "op_names": ['layer2.*'],
+ "pruning_scope": "local",
+ "pattern": "4x1",
+ "sparsity_decay_type": "exp"
+ },
+ {
+ "start_step": 2,
+ "end_step": 8,
+ "target_sparsity": 0.8,
+ "pruning_type": "snip_progressive",
+ "pruning_frequency": 1,
+ "op_names": ['layer3.*'],
+ "pruning_scope": "local",
+ "pattern": "16x1",
+ "sparsity_decay_type": "cube"
+ }
+]
+
+fake_snip_config = WeightPruningConfig(local_types_config, target_sparsity=0.9, start_step=0, \
+ end_step=10, pruning_frequency=3, sparsity_decay_type="exp")
+
+
+class TestPruningTypes(unittest.TestCase):
+ model = torchvision.models.resnet18()
+
+ def test_pruning_types(self):
+ prune = Pruning(fake_snip_config)
+ prune.model = self.model
+ criterion = nn.CrossEntropyLoss()
+ optimizer = torch.optim.SGD(self.model.parameters(), lr=0.0001)
+ datasets = Datasets('pytorch')
+ dummy_dataset = datasets['dummy'](shape=(10, 3, 224, 224), low=0., high=1., label=True)
+ dummy_dataloader = PyTorchDataLoader(dummy_dataset)
+ prune.on_train_begin()
+ prune.update_config(pruning_frequency=1)
+ for epoch in range(2):
+ self.model.train()
+ prune.on_epoch_begin(epoch)
+ local_step = 0
+ for image, target in dummy_dataloader:
+ prune.on_step_begin(local_step)
+ output = self.model(image)
+ loss = criterion(output, target)
+ optimizer.zero_grad()
+ loss.backward()
+ prune.on_before_optimizer_step()
+ optimizer.step()
+ prune.on_after_optimizer_step()
+ prune.on_step_end()
+ local_step += 1
+
+ prune.on_epoch_end()
+ prune.get_sparsity_ratio()
+ prune.on_train_end()
+ prune.on_before_eval()
+ prune.on_after_eval()
+
+
+if __name__ == "__main__":
+ unittest.main()
diff --git a/test/pruning/test_pytorch_pruning.py b/test/pruning/test_pytorch_pruning.py
deleted file mode 100644
index 5fb2047b7b2..00000000000
--- a/test/pruning/test_pytorch_pruning.py
+++ /dev/null
@@ -1,203 +0,0 @@
-import os
-import shutil
-import unittest
-
-import torch
-import torchvision
-import torch.nn as nn
-
-from neural_compressor.data import Datasets
-from neural_compressor.experimental.data.dataloaders.pytorch_dataloader import PyTorchDataLoader
-
-
-def build_fake_yaml_basic():
- fake_snip_yaml = """
- model:
- name: imagenet_prune
- framework: pytorch
-
- pruning:
- approach:
- weight_compression_pytorch:
- initial_sparsity: 0.0
- target_sparsity: 0.9
- start_step: 0
- end_step: 10
- excluded_names: ["classifier"]
-
- update_frequency_on_step: 1
- sparsity_decay_type: "exp"
- pruners:
- - !Pruner
- start_step: 0
- sparsity_decay_type: "cos"
- end_step: 10
- prune_type: "magnitude"
- names: ['layer1.*']
- extra_excluded_names: ['layer2.*']
- prune_domain: "global"
- pattern: "tile_pattern_4x1"
-
- - !Pruner
- start_step: 1
- end_step: 1
- target_sparsity: 0.5
- prune_type: "snip_momentum"
- update_frequency: 2
- names: ['layer2.*']
- prune_domain: local
- pattern: "tile_pattern_2:4"
-
- - !Pruner
- start_step: 2
- end_step: 8
- target_sparsity: 0.8
- prune_type: "snip"
- names: ['layer3.*']
- prune_domain: "local"
- pattern: "tile_pattern_16x1"
- sparsity_decay_type: "cube"
-
- """
- with open('fake_snip.yaml', 'w', encoding="utf-8") as f:
- f.write(fake_snip_yaml)
-
-def build_fake_yaml_channel():
- fake_channel_pruning_yaml = """
- model:
- name: imagenet_prune
- framework: pytorch
-
- pruning:
- approach:
- weight_compression_pytorch:
- initial_sparsity: 0.0
- target_sparsity: 0.9
- start_step: 0
- end_step: 10
- excluded_names: ["classifier"]
-
- update_frequency_on_step: 1
- sparsity_decay_type: "exp"
- pruners:
- - !Pruner
- start_step: 5
- end_step: 5
- prune_type: "pattern_lock"
- names: ['layer1.*']
- extra_excluded_names: ['layer2.*']
- prune_domain: "global"
- pattern: "channelx1"
-
- - !Pruner
- start_step: 1
- end_step: 1
- target_sparsity: 0.5
- prune_type: "pattern_lock"
- update_frequency: 2
- names: ['layer2.*']
- prune_domain: local
- pattern: "2:4"
-
- - !Pruner
- start_step: 2
- end_step: 8
- target_sparsity: 0.8
- prune_type: "snip"
- names: ['layer3.*']
- prune_domain: "local"
- pattern: "1xchannel"
- sparsity_decay_type: "cube"
-
- """
-
- with open('fake_channel_pruning.yaml', 'w', encoding="utf-8") as f:
- f.write(fake_channel_pruning_yaml)
-
-
-class TestPytorchPruning(unittest.TestCase):
-
- model = torchvision.models.resnet18()
-
- @classmethod
- def setUpClass(cls):
- build_fake_yaml_basic()
- build_fake_yaml_channel()
-
-
- @classmethod
- def tearDownClass(cls):
- os.remove('fake_channel_pruning.yaml')
- os.remove('fake_snip.yaml')
- shutil.rmtree('./saved', ignore_errors=True)
- shutil.rmtree('runs', ignore_errors=True)
-
- def test_pytorch_pruning_basic(self):
- from neural_compressor.experimental.pytorch_pruner.pruning import Pruning
-
- prune = Pruning("fake_snip.yaml")
- ##prune.generate_pruners()
- prune.update_items_for_all_pruners(start_step=1)
- prune.model = self.model
- criterion = nn.CrossEntropyLoss()
- optimizer = torch.optim.SGD(self.model.parameters(), lr=0.0001)
- datasets = Datasets('pytorch')
- dummy_dataset = datasets['dummy'](shape=(10, 3, 224, 224), low=0., high=1., label=True)
- dummy_dataloader = PyTorchDataLoader(dummy_dataset)
- prune.on_train_begin()
- prune.update_items_for_all_pruners(update_frequency_on_step=1)
- for epoch in range(2):
- self.model.train()
- prune.on_epoch_begin(epoch)
- local_step = 0
- for image, target in dummy_dataloader:
- prune.on_step_begin(local_step)
- output = self.model(image)
- loss = criterion(output, target)
- optimizer.zero_grad()
- loss.backward()
- prune.on_before_optimizer_step()
- optimizer.step()
- prune.on_after_optimizer_step()
- prune.on_step_end()
- local_step += 1
-
- prune.on_epoch_end()
- prune.get_sparsity_ratio()
- prune.on_train_end()
- prune.on_before_eval()
- prune.on_after_eval()
-
- def test_pytorch_pruner_channel_pruning(self):
- from neural_compressor.experimental.pytorch_pruner.pruning import Pruning
- prune = Pruning("fake_channel_pruning.yaml")
- ##prune.generate_pruners()
- prune.model = self.model
- criterion = nn.CrossEntropyLoss()
- optimizer = torch.optim.SGD(self.model.parameters(), lr=0.0001)
- datasets = Datasets('pytorch')
- dummy_dataset = datasets['dummy'](shape=(10, 3, 224, 224), low=0., high=1., label=True)
- dummy_dataloader = PyTorchDataLoader(dummy_dataset)
- prune.on_train_begin()
- for epoch in range(2):
- self.model.train()
- prune.on_epoch_begin(epoch)
- local_step = 0
- for image, target in dummy_dataloader:
- prune.on_step_begin(local_step)
- output = self.model(image)
- loss = criterion(output, target)
- optimizer.zero_grad()
- loss.backward()
- prune.on_before_optimizer_step()
- optimizer.step()
- prune.on_after_optimizer_step()
- prune.on_step_end()
- local_step += 1
-
- prune.on_epoch_end()
-
-if __name__ == "__main__":
- unittest.main()
-
-
diff --git a/test/pruning/test_gradient_sensitivity.py b/test/pruning_v1/test_gradient_sensitivity.py
similarity index 100%
rename from test/pruning/test_gradient_sensitivity.py
rename to test/pruning_v1/test_gradient_sensitivity.py
diff --git a/test/pruning/test_pattern_lock.py b/test/pruning_v1/test_pattern_lock.py
similarity index 100%
rename from test/pruning/test_pattern_lock.py
rename to test/pruning_v1/test_pattern_lock.py
diff --git a/test/pruning_v1/test_pruning.py b/test/pruning_v1/test_pruning.py
new file mode 100644
index 00000000000..5871f6bcc34
--- /dev/null
+++ b/test/pruning_v1/test_pruning.py
@@ -0,0 +1,132 @@
+import os
+import shutil
+import unittest
+
+import torch
+import torchvision
+import torch.nn as nn
+
+from neural_compressor.config import Pruner, PruningConfig
+from neural_compressor.data import Datasets
+from neural_compressor.experimental.data.dataloaders.pytorch_dataloader import PyTorchDataLoader
+from neural_compressor.training import prepare_compression
+
+
+def build_fake_yaml():
+ fake_yaml = """
+ model:
+ name: imagenet_prune
+ framework: pytorch
+
+ pruning:
+ approach:
+ weight_compression:
+ initial_sparsity: 0.0
+ target_sparsity: 0.97
+ start_epoch: 0
+ end_epoch: 2
+ pruners:
+ - !Pruner
+ start_epoch: 1
+ end_epoch: 2
+ prune_type: basic_magnitude
+ names: ['layer1.0.conv1.weight']
+
+ - !Pruner
+ target_sparsity: 0.6
+ prune_type: basic_magnitude
+ update_frequency: 2
+ names: ['layer1.0.conv2.weight']
+ """
+ with open('fake.yaml', 'w', encoding="utf-8") as f:
+ f.write(fake_yaml)
+
+
+class TestPruning(unittest.TestCase):
+
+ model = torchvision.models.resnet18()
+
+ @classmethod
+ def setUpClass(cls):
+ build_fake_yaml()
+
+ @classmethod
+ def tearDownClass(cls):
+ os.remove('fake.yaml')
+ shutil.rmtree('./saved', ignore_errors=True)
+ shutil.rmtree('runs', ignore_errors=True)
+
+ def test_pruning(self):
+ pruner1 = Pruner(start_epoch=1, end_epoch=2, names=['layer1.0.conv1.weight'])
+ pruner2 = Pruner(target_sparsity=0.6, update_frequency=2, names=['layer1.0.conv2.weight'])
+ conf = PruningConfig(pruners=[pruner1, pruner2], end_epoch=2)
+ datasets = Datasets('pytorch')
+ dummy_dataset = datasets['dummy'](shape=(100, 3, 224, 224), low=0., high=1., label=True)
+ dummy_dataloader = PyTorchDataLoader(dummy_dataset)
+ compression_manager = prepare_compression(self.model, conf)
+ model = compression_manager.model
+
+ epochs = 2
+ iters = 3
+ criterion = nn.CrossEntropyLoss()
+ optimizer = torch.optim.SGD(model.parameters(), lr=0.0001)
+ for nepoch in range(epochs):
+ model.train()
+ cnt = 0
+ compression_manager.callbacks.on_epoch_begin(nepoch)
+ for image, target in dummy_dataloader:
+ compression_manager.callbacks.on_step_begin(cnt)
+ print('.', end='')
+ cnt += 1
+ output = model(image)
+ loss = criterion(output, target)
+ optimizer.zero_grad()
+ loss.backward()
+ optimizer.step()
+ compression_manager.callbacks.on_step_end()
+ if cnt >= iters:
+ break
+ compression_manager.callbacks.on_epoch_end()
+
+ model.save("./saved")
+
+ def test_pruning_external(self):
+ from neural_compressor.experimental import common
+ from neural_compressor import Pruning
+ from neural_compressor.conf.config import PruningConf
+ pruners = [Pruner(1,3,names=['layer1.0.conv1.weight']),
+ Pruner(target_sparsity=0.6,update_frequency=2,names=['layer1.0.conv2.weight'])]
+ conf = PruningConfig(pruners)
+
+ datasets = Datasets('pytorch')
+ dummy_dataset = datasets['dummy'](shape=(100, 3, 224, 224), low=0., high=1., label=True)
+ dummy_dataloader = PyTorchDataLoader(dummy_dataset)
+ compression_manager = prepare_compression(self.model, conf)
+ model = compression_manager.model
+
+ epochs = 2
+ iters = 3
+ criterion = nn.CrossEntropyLoss()
+ optimizer = torch.optim.SGD(model.parameters(), lr=0.0001)
+ for nepoch in range(epochs):
+ model.train()
+ cnt = 0
+ compression_manager.callbacks.on_epoch_begin(nepoch)
+ for image, target in dummy_dataloader:
+ compression_manager.callbacks.on_step_begin(cnt)
+ print('.', end='')
+ cnt += 1
+ output = model(image)
+ loss = criterion(output, target)
+ optimizer.zero_grad()
+ loss.backward()
+ optimizer.step()
+ compression_manager.callbacks.on_step_end()
+ if cnt >= iters:
+ break
+ compression_manager.callbacks.on_epoch_end()
+ model.save("./saved")
+
+
+if __name__ == "__main__":
+ unittest.main()
diff --git a/test/pruning/test_pruning_group_lasso.py b/test/pruning_v1/test_pruning_group_lasso.py
similarity index 100%
rename from test/pruning/test_pruning_group_lasso.py
rename to test/pruning_v1/test_pruning_group_lasso.py
diff --git a/test/pruning/test_pruning_pattern.py b/test/pruning_v1/test_pruning_pattern.py
similarity index 100%
rename from test/pruning/test_pruning_pattern.py
rename to test/pruning_v1/test_pruning_pattern.py
diff --git a/test/pruning/test_pruning_pure_yaml.py b/test/pruning_v1/test_pruning_pure_yaml.py
similarity index 100%
rename from test/pruning/test_pruning_pure_yaml.py
rename to test/pruning_v1/test_pruning_pure_yaml.py
diff --git a/test/pruning/test_tensorflow_distributed_pruning.py b/test/pruning_v1/test_tensorflow_distributed_pruning.py
similarity index 100%
rename from test/pruning/test_tensorflow_distributed_pruning.py
rename to test/pruning_v1/test_tensorflow_distributed_pruning.py
diff --git a/test/pruning/test_tensorflow_pruning.py b/test/pruning_v1/test_tensorflow_pruning.py
similarity index 100%
rename from test/pruning/test_tensorflow_pruning.py
rename to test/pruning_v1/test_tensorflow_pruning.py
diff --git a/test/pruning/test_tensorflow_pruning_utility.py b/test/pruning_v1/test_tensorflow_pruning_utility.py
similarity index 100%
rename from test/pruning/test_tensorflow_pruning_utility.py
rename to test/pruning_v1/test_tensorflow_pruning_utility.py