
Commit 3449e2d

Authored by edenlightning, tchaton, Borda, carmocca, and SeanNaren
Docs for Pruning, Quantization, and SWA (#6041)
Co-authored-by: chaton <[email protected]>
Co-authored-by: Jirka Borovec <[email protected]>
Co-authored-by: Carlos Mocholí <[email protected]>
Co-authored-by: Sean Naren <[email protected]>
1 parent f48a933 commit 3449e2d

6 files changed: +181 -27 lines changed

docs/source/advanced/pruning_quantization.rst

Lines changed: 119 additions & 0 deletions
@@ -0,0 +1,119 @@
.. testsetup:: *

    import os
    from pytorch_lightning.trainer.trainer import Trainer
    from pytorch_lightning.core.lightning import LightningModule

.. _pruning_quantization:

########################
Pruning and Quantization
########################

Pruning and Quantization are techniques to compress model size for deployment, allowing inference speed-up and energy saving without significant accuracy losses.

*******
Pruning
*******

.. warning::

    Pruning is in beta and subject to change.

Pruning is a technique which focuses on eliminating some of the model weights to reduce the model size and decrease inference requirements.

Pruning has been shown to achieve significant efficiency improvements while minimizing the drop in model performance (prediction quality). Model pruning is recommended for cloud endpoints, deploying models on edge devices, or mobile inference (among others).

To enable pruning during training in Lightning, simply pass the :class:`~pytorch_lightning.callbacks.ModelPruning` callback to the Lightning Trainer. PyTorch's native pruning implementation is used under the hood.

This callback supports multiple pruning functions: pass any `torch.nn.utils.prune <https://pytorch.org/docs/stable/nn.html#utilities>`_ function as a string to select which weights to prune (`random_unstructured <https://pytorch.org/docs/stable/generated/torch.nn.utils.prune.random_unstructured.html#torch.nn.utils.prune.random_unstructured>`_, `RandomStructured <https://pytorch.org/docs/stable/generated/torch.nn.utils.prune.RandomStructured.html#torch.nn.utils.prune.RandomStructured>`_, etc.) or implement your own by subclassing `BasePruningMethod <https://pytorch.org/tutorials/intermediate/pruning_tutorial.html#extending-torch-nn-utils-prune-with-custom-pruning-functions>`_ (see the sketch after the examples below).

.. code-block:: python

    from pytorch_lightning.callbacks import ModelPruning

    # set the amount to be the fraction of parameters to prune
    trainer = Trainer(callbacks=[ModelPruning("l1_unstructured", amount=0.5)])

You can also perform iterative pruning, apply the `lottery ticket hypothesis <https://arxiv.org/pdf/1803.03635.pdf>`__, and more!

.. code-block:: python

    def compute_amount(epoch):
        # the sum of all returned values needs to be smaller than 1
        if epoch == 10:
            return 0.5

        elif epoch == 50:
            return 0.25

        elif 75 < epoch < 99:
            return 0.01

    # the amount can also be a callable
    trainer = Trainer(callbacks=[ModelPruning("l1_unstructured", amount=compute_amount)])

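If no function in ``torch.nn.utils.prune`` fits your needs, you can write your own. The snippet below is only a minimal sketch following the PyTorch pruning tutorial; the ``ThresholdPruning`` class and its ``threshold`` value are illustrative and not part of Lightning's or PyTorch's API:

.. code-block:: python

    from torch.nn.utils.prune import BasePruningMethod

    class ThresholdPruning(BasePruningMethod):
        # prune individual weights independently of the tensor structure
        PRUNING_TYPE = "unstructured"

        def __init__(self, threshold=1e-3):
            self.threshold = threshold

        def compute_mask(self, tensor, default_mask):
            # keep only the entries whose magnitude is above the threshold
            return default_mask * (tensor.abs() > self.threshold)
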
************
58+
Quantization
59+
************
60+
61+
.. warning ::
62+
Quantization is in beta and subject to change.
63+
64+
Model quantization is another performance optimization technique that allows speeding up inference and decreasing memory requirements by performing computations and storing tensors at lower bitwidths (such as INT8 or FLOAT16) than floating-point precision. This is particularly beneficial during model deployment.
65+
66+
Quantization Aware Training (QAT) mimics the effects of quantization during training: The computations are carried-out in floating-point precision but the subsequent quantization effect is taken into account. The weights and activations are quantized into lower precision only for inference, when training is completed.
67+
68+
Quantization is useful when it is required to serve large models on machines with limited memory, or when there's a need to switch between models and reducing the I/O time is important. For example, switching between monolingual speech recognition models across multiple languages.
69+
70+
Lightning includes :class:`~pytorch_lightning.callbacks.QuantizationAwareTraining` callback (using PyTorch's native quantization, read more `here <https://pytorch.org/docs/stable/quantization.html#quantization-aware-training>`__), which allows creating fully quantized models (compatible with torchscript).
71+
72+
.. code-block:: python
73+
74+
from pytorch_lightning.callbacks import QuantizationAwareTraining
75+
76+
class RegressionModel(LightningModule):
77+
78+
def __init__(self):
79+
super().__init__()
80+
self.layer_0 = nn.Linear(16, 64)
81+
self.layer_0a = torch.nn.ReLU()
82+
self.layer_1 = nn.Linear(64, 64)
83+
self.layer_1a = torch.nn.ReLU()
84+
self.layer_end = nn.Linear(64, 1)
85+
86+
def forward(self, x):
87+
x = self.layer_0(x)
88+
x = self.layer_0a(x)
89+
x = self.layer_1(x)
90+
x = self.layer_1a(x)
91+
x = self.layer_end(x)
92+
return x
93+
94+
trainer = Trainer(callbacks=[QuantizationAwareTraining()])
95+
qmodel = RegressionModel()
96+
trainer.fit(qmodel, ...)
97+
98+
batch = iter(my_dataloader()).next()
99+
qmodel(qmodel.quant(batch[0]))
100+
101+
tsmodel = qmodel.to_torchscript()
102+
tsmodel(tsmodel.quant(batch[0]))
103+
104+
You can further customize the callback:
105+
106+
.. code-block:: python
107+
108+
109+
qcb = QuantizationAwareTraining(
110+
# specification of quant estimation quality
111+
observer_type='histogram',
112+
# specify which layers shall be merged together to increase efficiency
113+
modules_to_fuse=[(f'layer_{i}', f'layer_{i}a') for i in range(2)]
114+
# make your model compatible with all original input/outputs, in such case the model is wrapped in a shell with entry/exit layers.
115+
input_compatible=True
116+
)
117+
118+
batch = iter(my_dataloader()).next()
119+
qmodel(batch[0])
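
The callback also accepts a ``qconfig`` argument: ``'fbgemm'`` targets server inference, ``'qnnpack'`` targets mobile inference, and a custom ``torch.quantization.QConfig`` is accepted as well (see the docstring changes in ``pytorch_lightning/callbacks/quantization.py`` further down this commit). A minimal sketch, assuming the ``qnnpack`` engine is listed in ``torch.backends.quantized.supported_engines`` on your platform:

.. code-block:: python

    # target mobile inference by picking the qnnpack backend instead of fbgemm
    qcb_mobile = QuantizationAwareTraining(qconfig='qnnpack')
    trainer = Trainer(callbacks=[qcb_mobile])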

docs/source/advanced/training_tricks.rst

Lines changed: 18 additions & 0 deletions
@@ -41,6 +41,24 @@ norm <https://pytorch.org/docs/stable/nn.html#torch.nn.utils.clip_grad_norm_>`_
----------

Stochastic Weight Averaging
---------------------------
Stochastic Weight Averaging (SWA) can make your models generalize better at virtually no additional cost.
This can be used with both non-trained and trained models. The SWA procedure smooths the loss landscape, thus making
it harder to end up in a local minimum during optimization.

For a more detailed explanation of SWA and how it works,
read `this <https://pytorch.org/blog/pytorch-1.6-now-includes-stochastic-weight-averaging>`__ post by the PyTorch team.

.. seealso:: :class:`~pytorch_lightning.callbacks.StochasticWeightAveraging` (Callback)

.. testcode::

    # Enable Stochastic Weight Averaging
    trainer = Trainer(stochastic_weight_avg=True)

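The Trainer flag is the simplest entry point; if you want to control when averaging starts, the :class:`~pytorch_lightning.callbacks.StochasticWeightAveraging` callback referenced above can be passed explicitly. A minimal sketch, assuming the ``swa_epoch_start`` argument shown in ``pytorch_lightning/callbacks/swa.py`` below:

.. code-block:: python

    from pytorch_lightning.callbacks import StochasticWeightAveraging

    # start averaging the weights from epoch 10 onwards
    trainer = Trainer(callbacks=[StochasticWeightAveraging(swa_epoch_start=10)])
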
----------

Auto scaling of batch size
--------------------------
Auto scaling of batch size may be enabled to find the largest batch size that fits into

docs/source/extensions/callbacks.rst

Lines changed: 1 addition & 0 deletions
@@ -106,6 +106,7 @@ Lightning has a few built-in callbacks.
    ModelPruning
    ProgressBar
    ProgressBarBase
    QuantizationAwareTraining
    StochasticWeightAveraging

----------

docs/source/index.rst

Lines changed: 1 addition & 0 deletions
@@ -111,6 +111,7 @@ PyTorch Lightning Documentation
   common/single_gpu
   advanced/sequences
   advanced/training_tricks
   advanced/pruning_quantization
   advanced/transfer_learning
   advanced/tpu
   advanced/cluster

pytorch_lightning/callbacks/quantization.py

Lines changed: 36 additions & 27 deletions
@@ -83,15 +83,6 @@ def _recursive_hasattr(obj: Any, attribs: str, state: bool = True) -> bool:
 
 
 class QuantizationAwareTraining(Callback):
-    """
-    Quantization allows speeding up inference and decreasing memory requirements by performing computations
-    and storing tensors at lower bitwidths (such as INT8 or FLOAT16) than floating point precision.
-    We use native PyTorch API so for more information see
-    `Quantization <https://pytorch.org/docs/stable/quantization.html#quantization-aware-training>_`
-
-    .. warning:: ``QuantizationAwareTraining`` is in beta and subject to change.
-    """
-
     OBSERVER_TYPES = ('histogram', 'average')
 
     def __init__(
@@ -103,31 +94,49 @@ def __init__(
         input_compatible: bool = True,
     ) -> None:
         """
+        Quantization allows speeding up inference and decreasing memory requirements
+        by performing computations and storing tensors at lower bitwidths
+        (such as INT8 or FLOAT16) than floating point precision.
+        We use native PyTorch API so for more information
+        see `Quantization <https://pytorch.org/docs/stable/quantization.html#quantization-aware-training>`_.
+
+        .. warning:: ``QuantizationAwareTraining`` is in beta and subject to change.
+
+
         Args:
-            qconfig: define quantization configuration see: `torch.quantization.QConfig
-                <https://pytorch.org/docs/stable/torch.quantization.html#torch.quantization.QConfig>_`
-                or use pre-defined: 'fbgemm' for server inference and 'qnnpack' for mobile inference
+
+            qconfig: quantization configuration:
+
+                - 'fbgemm' for server inference.
+                - 'qnnpack' for mobile inference.
+                - a custom `torch.quantization.QConfig <https://pytorch.org/docs/stable/torch.quantization.html#torch.quantization.QConfig>`_.
+
             observer_type: allows switching between ``MovingAverageMinMaxObserver`` as "average" (default)
-                and ``HistogramObserver`` as "histogram" which is more computationally expensive
-            collect_quantization: count or custom function to collect quantization statistics
+                and ``HistogramObserver`` as "histogram" which is more computationally expensive.
+
+            collect_quantization: count or custom function to collect quantization statistics:
+
+                - ``None`` (default). The quantization observer is called in each module forward
+                    (useful for collecting extended statistics when using image/data augmentation).
+                - ``int``. Use to set a fixed number of calls, starting from the beginning.
+                - ``Callable``. Custom function with a single trainer argument.
+                    See this example to trigger only the last epoch:
 
-                - with default ``None`` the quantization observer is called each module forward,
-                    typical use-case can be collecting extended statistic when user uses image/data augmentation
-                - custom call count to set a fixed number of calls, starting from the beginning
-                - custom ``Callable`` function with single trainer argument,
-                    see example when you limit call only for last epoch::
+                    .. code-block:: python
 
-                        def custom_trigger_last(trainer):
-                            return trainer.current_epoch == (trainer.max_epochs - 1)
+                        def custom_trigger_last(trainer):
+                            return trainer.current_epoch == (trainer.max_epochs - 1)
 
-                        QuantizationAwareTraining(collect_quantization=custom_trigger_last)
+                        QuantizationAwareTraining(collect_quantization=custom_trigger_last)
+
+            modules_to_fuse: allows you to fuse a few layers together as shown in the
+                `diagram <https://pytorch.org/docs/stable/quantization.html#quantization-aware-training>`_;
+                to find which layer types can be fused, check https://github.com/pytorch/pytorch/pull/43286.
 
-            modules_to_fuse: allows you fuse a few layers together as shown in `diagram
-                <https://pytorch.org/docs/stable/quantization.html#quantization-aware-training>_`
-                to find which layer types can be fused, check https://github.com/pytorch/pytorch/pull/43286
             input_compatible: preserve quant/dequant layers. This allows to feat any input as to the original model,
-                but break compatibility to torchscript
-            """
+                but breaks compatibility with torchscript.
+
+        """  # noqa: E501
         _valid_qconf_str = isinstance(qconfig, str) and qconfig in torch.backends.quantized.supported_engines
         if not isinstance(qconfig, QConfig) and not _valid_qconf_str:
             raise MisconfigurationException(

pytorch_lightning/callbacks/swa.py

Lines changed: 6 additions & 0 deletions
@@ -63,6 +63,12 @@ def __init__(

        .. warning:: ``StochasticWeightAveraging`` is currently not supported for multiple optimizers/schedulers.

        SWA can easily be activated directly from the Trainer as follows:

        .. code-block:: python

            Trainer(stochastic_weight_avg=True)

        Arguments:

            swa_epoch_start: If provided as int, the procedure will start from
