|
7 | 7 | What is Channels Last |
8 | 8 | --------------------- |
9 | 9 |
|
10 | | -Channels Last memory format is an alternative way of ordering NCHW tensors in memory preserving dimensions ordering. Channels Last tensors ordered in such a way that channels become the densest dimension (aka storing images pixel-per-pixel). |
| 10 | +Channels last memory format is an alternative way of ordering NCHW tensors in memory, preserving the dimension ordering. Channels last tensors are ordered in such a way that channels become the densest dimension (aka storing images pixel-per-pixel). |
11 | 11 |
|
12 | 12 | For example, classic (contiguous) storage of an NCHW tensor (in our case it is two 2x2 images with 3 color channels) looks like this: |
13 | 13 |
|
14 | 14 | .. figure:: /_static/img/classic_memory_format.png |
15 | 15 | :alt: classic_memory_format |
16 | 16 |
|
17 | | -Channels Last memory format orders data differently: |
| 17 | +Channels last memory format orders data differently: |
18 | 18 |
|
19 | 19 | .. figure:: /_static/img/channels_last_memory_format.png |
20 | 20 | :alt: channels_last_memory_format |
21 | 21 |
|
22 | 22 | PyTorch supports memory formats (and provides backward compatibility with existing models, including eager, JIT, and TorchScript) by utilizing the existing strides structure. |
23 | | -For example, 10x3x16x16 batch in Channels Last format will have strides equal to (768, 1, 48, 3). |
| 23 | +For example, a 10x3x16x16 batch in channels last format will have strides equal to (768, 1, 48, 3). |
24 | 24 | """ |
25 | 25 |
|
26 | 26 | ###################################################################### |
27 | | -# Channels Last memory format is implemented for 4D NCWH Tensors only. |
| 27 | +# Channels last memory format is implemented for 4D NCHW tensors only. |
28 | 28 | # |
29 | 29 |
|
30 | | -import torch |
31 | | -N, C, H, W = 10, 3, 32, 32 |
32 | | - |
33 | 30 | ###################################################################### |
34 | 31 | # Memory Format API |
35 | 32 | # ----------------------- |
|
39 | 36 |
|
40 | 37 | ###################################################################### |
41 | 38 | # Classic PyTorch contiguous tensor |
| 39 | +import torch |
| 40 | +N, C, H, W = 10, 3, 32, 32 |
42 | 41 | x = torch.empty(N, C, H, W) |
43 | 42 | print(x.stride()) # Outputs: (3072, 1024, 32, 1) |
44 | 43 |
|
45 | 44 | ###################################################################### |
46 | 45 | # Conversion operator |
47 | | -x = x.contiguous(memory_format=torch.channels_last) |
| 46 | +x = x.to(memory_format=torch.channels_last) |
48 | 47 | print(x.shape) # Outputs: (10, 3, 32, 32) as dimensions order is preserved |
49 | 48 | print(x.stride()) # Outputs: (3072, 1, 96, 3) |
50 | 49 |
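| | +###################################################################### |
| | +# A hedged aside (not part of the original tutorial): in channels last |
| | +# format the strides of an ``N x C x H x W`` tensor are |
| | +# ``(H*W*C, 1, W*C, C)``, which is where the ``(3072, 1, 96, 3)`` above |
| | +# and the ``(768, 1, 48, 3)`` quoted earlier for a 10x3x16x16 batch |
| | +# come from. |
| | +y = torch.empty(10, 3, 16, 16).to(memory_format=torch.channels_last) |
| | +print(y.stride())  # Expected: (768, 1, 48, 3) |
| | + |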
|
51 | 50 | ###################################################################### |
52 | 51 | # Back to contiguous |
53 | | -x = x.contiguous(memory_format=torch.contiguous_format) |
| 52 | +x = x.to(memory_format=torch.contiguous_format) |
54 | 53 | print(x.stride()) # Outputs: (3072, 1024, 32, 1) |
55 | 54 |
|
56 | 55 | ###################################################################### |
57 | 56 | # Alternative option |
58 | | -x = x.to(memory_format=torch.channels_last) |
| 57 | +x = x.contiguous(memory_format=torch.channels_last) |
59 | 58 | print(x.stride()) # Outputs: (3072, 1, 96, 3) |
60 | 59 |
|
61 | 60 | ###################################################################### |
62 | 61 | # Format checks |
63 | 62 | print(x.is_contiguous(memory_format=torch.channels_last)) # Outputs: True |
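| | +# A small addition (not in the original tutorial): since ``x`` now |
| | +# carries channels last strides, the default contiguity check is False. |
| | +print(x.is_contiguous())  # Expected: False |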
64 | 63 |
|
65 | 64 | ###################################################################### |
66 | | -# Create as Channels Last |
| 65 | +# There are minor differences between the two APIs ``to`` and |
| 66 | +# ``contiguous``. We suggest sticking with ``to`` when explicitly |
| 67 | +# converting the memory format of a tensor. |
| 68 | +# |
| 69 | +# For general cases the two APIs behave the same. However, in the |
| 70 | +# special case of a 4D tensor with size ``NCHW``, when either ``C==1`` |
| 71 | +# or ``H==1 && W==1``, only ``to`` generates a proper stride to |
| 72 | +# represent the channels last memory format. |
| 73 | +# |
| 74 | +# This is because in either of the two cases above the memory format |
| 75 | +# of the tensor is ambiguous, i.e. a contiguous tensor with size |
| 76 | +# ``N1HW`` is both ``contiguous`` and channels last in memory storage. |
| 77 | +# Therefore, it is already considered ``is_contiguous`` |
| 78 | +# for the given memory format, so a ``contiguous`` call becomes a |
| 79 | +# no-op and does not update the strides. In contrast, ``to`` |
| 80 | +# restrides the tensor with meaningful strides on dimensions whose |
| 81 | +# sizes are 1 in order to properly represent the intended memory |
| 82 | +# format. |
| 83 | +special_x = torch.empty(4, 1, 4, 4) |
| 84 | +print(special_x.is_contiguous(memory_format=torch.channels_last)) # Outputs: True |
| 85 | +print(special_x.is_contiguous(memory_format=torch.contiguous_format)) # Outputs: True |
| 86 | + |
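| | +# A minimal sketch (not from the original tutorial) of the difference |
| | +# described above on this ambiguous ``N1HW`` tensor: ``contiguous`` is |
| | +# a no-op and keeps the contiguous strides, while ``to`` restrides the |
| | +# tensor into channels last form. |
| | +print(special_x.contiguous(memory_format=torch.channels_last).stride())  # Expected: (16, 16, 4, 1) -- unchanged |
| | +print(special_x.to(memory_format=torch.channels_last).stride())  # Expected: (16, 1, 4, 1) -- channels last |
| | + |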
| 87 | +###################################################################### |
| 88 | +# The same thing applies to the explicit permutation API ``permute``. |
| 89 | +# In the special case where ambiguity could occur, ``permute`` is not |
| 90 | +# guaranteed to produce a stride that properly carries the intended |
| 91 | +# memory format. We suggest using ``to`` with an explicit memory |
| 92 | +# format to avoid unintended behavior. |
| 93 | +# |
| 94 | +# As a side note, in the extreme case where all three non-batch |
| 95 | +# dimensions are equal to ``1`` (``C==1 && H==1 && W==1``), the |
| 96 | +# current implementation cannot mark a tensor as channels last memory |
| 97 | +# format. |
| 98 | + |
| 99 | +###################################################################### |
| 100 | +# Create as channels last |
67 | 101 | x = torch.empty(N, C, H, W, memory_format=torch.channels_last) |
68 | 102 | print(x.stride()) # Outputs: (3072, 1, 96, 3) |
69 | 103 |
|
|
89 | 123 | print(z.stride()) # Outputs: (3072, 1, 96, 3) |
90 | 124 |
|
91 | 125 | ###################################################################### |
92 | | -# Conv, Batchnorm modules support Channels Last |
93 | | -# (only works for CudNN >= 7.6) |
| 126 | +# Conv, Batchnorm modules using the cuDNN backend support channels last |
| 127 | +# (only works for cuDNN >= 7.6). Convolution modules, unlike binary |
| 128 | +# p-wise operators, have channels last as the dominating memory format. |
| 129 | +# If all inputs are in contiguous memory format, the operator |
| 130 | +# produces output in contiguous memory format. Otherwise, the output |
| 131 | +# will be in channels last memory format. |
| 132 | + |
94 | 133 | if torch.backends.cudnn.version() >= 7603: |
95 | | - input = torch.randint(1, 10, (2, 8, 4, 4), dtype=torch.float32, device="cuda", requires_grad=True) |
96 | | - model = torch.nn.Conv2d(8, 4, 3).cuda().float() |
| 134 | + model = torch.nn.Conv2d(8, 4, 3).cuda().half() |
| 135 | + model = model.to(memory_format=torch.channels_last) # Module parameters need to be channels last |
97 | 136 |
|
98 | | - input = input.contiguous(memory_format=torch.channels_last) |
99 | | - model = model.to(memory_format=torch.channels_last) # Module parameters need to be Channels Last |
| 137 | + input = torch.randint(1, 10, (2, 8, 4, 4), dtype=torch.float32, requires_grad=True) |
| 138 | + input = input.to(device="cuda", memory_format=torch.channels_last, dtype=torch.float16) |
100 | 139 |
|
101 | 140 | out = model(input) |
102 | 141 | print(out.is_contiguous(memory_format=torch.channels_last)) # Outputs: True |
103 | 142 |
|
| 143 | +###################################################################### |
| 144 | +# When an input tensor reaches an operator without channels last |
| 145 | +# support, a permutation is automatically applied in the kernel to |
| 146 | +# restore the contiguous format on the input tensor. This introduces |
| 147 | +# overhead and stops the channels last memory format propagation. |
| 148 | +# Nevertheless, it guarantees correct output. |
| 149 | + |
104 | 150 | ###################################################################### |
105 | 151 | # Performance Gains |
106 | | -# ------------------------------------------------------------------------------------------- |
107 | | -# The most significant performance gains are observed on Nvidia's hardware with |
108 | | -# Tensor Cores support. We were able to archive over 22% perf gains while running ' |
109 | | -# AMP (Automated Mixed Precision) training scripts supplied by Nvidia https://github.com/NVIDIA/apex. |
| 152 | +# -------------------------------------------------------------------- |
| 153 | +# The most significant performance gains are observed on NVIDIA |
| 154 | +# hardware with Tensor Cores support, running at reduced precision |
| 155 | +# (``torch.float16``). |
| 156 | +# We were able to achieve over 22% perf gains with channels last |
| 157 | +# compared to the contiguous format, while utilizing |
| 158 | +# 'AMP (Automated Mixed Precision)' training scripts. |
| 159 | +# Our scripts use AMP supplied by NVIDIA: |
| 160 | +# https://github.com/NVIDIA/apex. |
110 | 161 | # |
111 | 162 | # ``python main_amp.py -a resnet50 --b 200 --workers 16 --opt-level O2 ./data`` |
112 | 163 |
|
|
143 | 194 | # Epoch: [0][80/125] Time 0.260 (0.335) Speed 770.324 (597.659) Loss 2.2505953312 (1.0879) Prec@1 50.500 (52.938) Prec@5 100.000 (100.000) |
144 | 195 |
|
145 | 196 | ###################################################################### |
146 | | -# Passing ``--channels-last true`` allows running a model in Channels Last format with observed 22% perf gain. |
147 | | -# |
| 197 | +# Passing ``--channels-last true`` allows running a model in the channels last format with an observed 22% performance gain. |
| 198 | +# |
148 | 199 | # ``python main_amp.py -a resnet50 --b 200 --workers 16 --opt-level O2 --channels-last true ./data`` |
149 | 200 |
|
150 | 201 | # opt_level = O2 |
|
184 | 235 | # Epoch: [0][80/125] Time 0.198 (0.269) Speed 1011.827 (743.883) Loss 2.8196096420 (2.4011) Prec@1 47.500 (50.938) Prec@5 100.000 (100.000) |
185 | 236 |
|
186 | 237 | ###################################################################### |
187 | | -# The following list of models has the full support of Channels Last and showing 8%-35% perf gains on Volta devices: |
| 238 | +# The following list of models has full support for channels last and shows 8%-35% perf gains on Volta devices: |
188 | 239 | # ``alexnet``, ``mnasnet0_5``, ``mnasnet0_75``, ``mnasnet1_0``, ``mnasnet1_3``, ``mobilenet_v2``, ``resnet101``, ``resnet152``, ``resnet18``, ``resnet34``, ``resnet50``, ``resnext50_32x4d``, ``shufflenet_v2_x0_5``, ``shufflenet_v2_x1_0``, ``shufflenet_v2_x1_5``, ``shufflenet_v2_x2_0``, ``squeezenet1_0``, ``squeezenet1_1``, ``vgg11``, ``vgg11_bn``, ``vgg13``, ``vgg13_bn``, ``vgg16``, ``vgg16_bn``, ``vgg19``, ``vgg19_bn``, ``wide_resnet101_2``, ``wide_resnet50_2`` |
189 | 240 | # |
190 | 241 |
|
191 | 242 | ###################################################################### |
192 | 243 | # Converting existing models |
193 | 244 | # -------------------------- |
194 | 245 | # |
195 | | -# Channels Last support not limited by existing models, as any model can be converted to Channels Last and propagate format through the graph as soon as input formatted correctly. |
| 246 | +# Channels last support is not limited to existing models, as any |
| 247 | +# model can be converted to channels last and propagate the format |
| 248 | +# through the graph as soon as the input (or certain weights) is |
| 249 | +# formatted correctly. |
196 | 250 | # |
197 | 251 |
|
198 | 252 | # Need to be done once, after model initialization (or load) |
|
203 | 257 | output = model(input) |
204 | 258 |
|
205 | 259 | ####################################################################### |
206 | | -# However, not all operators fully converted to support Channels Last (usually returning |
207 | | -# contiguous output instead). That means you need to verify the list of used operators |
208 | | -# against supported operators list https://github.com/pytorch/pytorch/wiki/Operators-with-Channels-Last-support, |
| 260 | +# However, not all operators are fully converted to support channels |
| 261 | +# last (usually returning contiguous output instead). In the example |
| 262 | +# posted above, layers that do not support channels last will stop the |
| 263 | +# memory format propagation. In spite of that, as we have converted the |
| 264 | +# model to channels last format, each convolution layer, |
| 265 | +# which has its 4-dimensional weight in channels last memory format, |
| 266 | +# will restore channels last memory format and benefit from faster |
| 267 | +# kernels. |
| 268 | +# |
| 269 | +# But operators that do not support channels last do introduce |
| 270 | +# overhead through permutation. Optionally, you can investigate and |
| 271 | +# identify operators in your model that do not support channels last, |
| 272 | +# if you want to improve the performance of the converted model. |
| 273 | +# |
| 274 | +# That means you need to verify the list of used operators |
| 275 | +# against the supported operators list https://github.com/pytorch/pytorch/wiki/Operators-with-Channels-Last-support, |
209 | 276 | # or introduce memory format checks into eager execution mode and run your model. |
210 | 277 | # |
211 | 278 | # After running the code below, operators will raise an exception if the output of the |
@@ -284,8 +351,8 @@ def attribute(m): |
284 | 351 |
|
285 | 352 |
|
286 | 353 | ###################################################################### |
287 | | -# If you found an operator that doesn't support Channels Last tensors |
288 | | -# and you want to contribute, feel free to use following developers |
| 354 | +# If you find an operator that doesn't support channels last tensors |
| 355 | +# and you want to contribute, feel free to use the following developer |
289 | 356 | # guide https://github.com/pytorch/pytorch/wiki/Writing-memory-format-aware-operators. |
290 | 357 | # |
291 | 358 |
|
|