# -*- coding: utf-8 -*-
"""
(experimental) Channels Last Memory Format in PyTorch
*******************************************************
**Author**: `Vitaly Fedyunin <https://github.com/VitalyFedyunin>`_

What is Channels Last
---------------------

Channels Last memory format is an alternative way of ordering NCHW tensors in memory while preserving the dimension ordering. Channels Last tensors are ordered in such a way that channels become the densest dimension (i.e., images are stored pixel-per-pixel).

For example, the classic (contiguous) storage of an NCHW tensor (in our case, two 2x2 images with 3 color channels) looks like this:

.. figure:: /_static/img/classic_memory_format.png
   :alt: classic_memory_format

Channels Last memory format orders data differently:

.. figure:: /_static/img/channels_last_memory_format.png
   :alt: channels_last_memory_format

PyTorch supports memory formats (and provides backward compatibility with existing models, including eager, JIT, and TorchScript) by utilizing the existing strides structure.
For example, a 10x3x16x16 batch in Channels Last format will have strides equal to (768, 1, 48, 3).
"""

######################################################################
# Channels Last memory format is implemented for 4D NCHW Tensors only.
#

import torch
N, C, H, W = 10, 3, 32, 32

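######################################################################
# As a quick check of the stride claim from the introduction (a minimal
# sketch using the 10x3x16x16 example):
x = torch.empty(10, 3, 16, 16).contiguous(memory_format=torch.channels_last)
print(x.stride())  # Outputs: (768, 1, 48, 3)
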
######################################################################
# Memory Format API
# -----------------------
#
# Here is how to convert tensors between contiguous and channels
# last memory formats.

######################################################################
# Classic PyTorch contiguous tensor
x = torch.empty(N, C, H, W)
print(x.stride())  # Outputs: (3072, 1024, 32, 1)

######################################################################
# Conversion operator
x = x.contiguous(memory_format=torch.channels_last)
print(x.shape)  # Outputs: (10, 3, 32, 32) as the dimension order is preserved
print(x.stride())  # Outputs: (3072, 1, 96, 3)

######################################################################
# Back to contiguous
x = x.contiguous(memory_format=torch.contiguous_format)
print(x.stride())  # Outputs: (3072, 1024, 32, 1)

######################################################################
# Alternative option
x = x.to(memory_format=torch.channels_last)
print(x.stride())  # Outputs: (3072, 1, 96, 3)

######################################################################
# Format checks
print(x.is_contiguous(memory_format=torch.channels_last))  # Outputs: True
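# The default contiguity check now returns False, since the tensor is
# laid out as Channels Last rather than NCHW-contiguous
print(x.is_contiguous())  # Outputs: False
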
######################################################################
# Create as Channels Last
x = torch.empty(N, C, H, W, memory_format=torch.channels_last)
print(x.stride())  # Outputs: (3072, 1, 96, 3)

######################################################################
# ``clone`` preserves memory format
y = x.clone()
print(y.stride())  # Outputs: (3072, 1, 96, 3)

######################################################################
# ``to``, ``cuda``, ``float`` ... preserve memory format
if torch.cuda.is_available():
    y = x.cuda()
    print(y.stride())  # Outputs: (3072, 1, 96, 3)

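######################################################################
# The same holds for dtype conversions on CPU, e.g. ``double`` (a small
# sketch of the claim above, since ``to`` preserves the format by default):
y = x.double()
print(y.stride())  # Outputs: (3072, 1, 96, 3)
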
######################################################################
# ``empty_like``, ``*_like`` operators preserve memory format
y = torch.empty_like(x)
print(y.stride())  # Outputs: (3072, 1, 96, 3)

######################################################################
# Pointwise operators preserve memory format
z = x + y
print(z.stride())  # Outputs: (3072, 1, 96, 3)

######################################################################
# ``Conv``, ``Batchnorm`` modules support Channels Last
# (only works for cuDNN >= 7.6; the CUDA check below also guards builds
# where ``cudnn.version()`` returns None)
if torch.cuda.is_available() and torch.backends.cudnn.version() >= 7603:
    input = torch.randint(1, 10, (2, 8, 4, 4), dtype=torch.float32, device="cuda", requires_grad=True)
    model = torch.nn.Conv2d(8, 4, 3).cuda().float()

    input = input.contiguous(memory_format=torch.channels_last)
    model = model.to(memory_format=torch.channels_last)  # Module parameters need to be Channels Last

    out = model(input)
    print(out.is_contiguous(memory_format=torch.channels_last))  # Outputs: True

######################################################################
# Performance Gains
# -------------------------------------------------------------------------------------------
# The most significant performance gains are observed on NVIDIA hardware with
# Tensor Cores support. We were able to achieve over 22% performance gains while running
# AMP (Automated Mixed Precision) training scripts supplied by NVIDIA: https://github.com/NVIDIA/apex.
#
# ``python main_amp.py -a resnet50 --b 200 --workers 16 --opt-level O2 ./data``

# opt_level = O2
# keep_batchnorm_fp32 = None <class 'NoneType'>
# loss_scale = None <class 'NoneType'>
# CUDNN VERSION: 7603
# => creating model 'resnet50'
# Selected optimization level O2: FP16 training with FP32 batchnorm and FP32 master weights.
# Defaults for this optimization level are:
# enabled : True
# opt_level : O2
# cast_model_type : torch.float16
# patch_torch_functions : False
# keep_batchnorm_fp32 : True
# master_weights : True
# loss_scale : dynamic
# Processing user overrides (additional kwargs that are not None)...
# After processing overrides, optimization options are:
# enabled : True
# opt_level : O2
# cast_model_type : torch.float16
# patch_torch_functions : False
# keep_batchnorm_fp32 : True
# master_weights : True
# loss_scale : dynamic
# Epoch: [0][10/125] Time 0.866 (0.866) Speed 230.949 (230.949) Loss 0.6735125184 (0.6735) Prec@1 61.000 (61.000) Prec@5 100.000 (100.000)
# Epoch: [0][20/125] Time 0.259 (0.562) Speed 773.481 (355.693) Loss 0.6968704462 (0.6852) Prec@1 55.000 (58.000) Prec@5 100.000 (100.000)
# Epoch: [0][30/125] Time 0.258 (0.461) Speed 775.089 (433.965) Loss 0.7877287269 (0.7194) Prec@1 51.500 (55.833) Prec@5 100.000 (100.000)
# Epoch: [0][40/125] Time 0.259 (0.410) Speed 771.710 (487.281) Loss 0.8285319805 (0.7467) Prec@1 48.500 (54.000) Prec@5 100.000 (100.000)
# Epoch: [0][50/125] Time 0.260 (0.380) Speed 770.090 (525.908) Loss 0.7370464802 (0.7447) Prec@1 56.500 (54.500) Prec@5 100.000 (100.000)
# Epoch: [0][60/125] Time 0.258 (0.360) Speed 775.623 (555.728) Loss 0.7592862844 (0.7472) Prec@1 51.000 (53.917) Prec@5 100.000 (100.000)
# Epoch: [0][70/125] Time 0.258 (0.345) Speed 774.746 (579.115) Loss 1.9698858261 (0.9218) Prec@1 49.500 (53.286) Prec@5 100.000 (100.000)
# Epoch: [0][80/125] Time 0.260 (0.335) Speed 770.324 (597.659) Loss 2.2505953312 (1.0879) Prec@1 50.500 (52.938) Prec@5 100.000 (100.000)

######################################################################
# Passing ``--channels-last true`` allows running a model in the Channels Last format, with an observed 22% performance gain.
#
# ``python main_amp.py -a resnet50 --b 200 --workers 16 --opt-level O2 --channels-last true ./data``

# opt_level = O2
# keep_batchnorm_fp32 = None <class 'NoneType'>
# loss_scale = None <class 'NoneType'>
#
# CUDNN VERSION: 7603
#
# => creating model 'resnet50'
# Selected optimization level O2: FP16 training with FP32 batchnorm and FP32 master weights.
#
# Defaults for this optimization level are:
# enabled : True
# opt_level : O2
# cast_model_type : torch.float16
# patch_torch_functions : False
# keep_batchnorm_fp32 : True
# master_weights : True
# loss_scale : dynamic
# Processing user overrides (additional kwargs that are not None)...
# After processing overrides, optimization options are:
# enabled : True
# opt_level : O2
# cast_model_type : torch.float16
# patch_torch_functions : False
# keep_batchnorm_fp32 : True
# master_weights : True
# loss_scale : dynamic
#
# Epoch: [0][10/125] Time 0.767 (0.767) Speed 260.785 (260.785) Loss 0.7579724789 (0.7580) Prec@1 53.500 (53.500) Prec@5 100.000 (100.000)
# Epoch: [0][20/125] Time 0.198 (0.482) Speed 1012.135 (414.716) Loss 0.7007197738 (0.7293) Prec@1 49.000 (51.250) Prec@5 100.000 (100.000)
# Epoch: [0][30/125] Time 0.198 (0.387) Speed 1010.977 (516.198) Loss 0.7113101482 (0.7233) Prec@1 55.500 (52.667) Prec@5 100.000 (100.000)
# Epoch: [0][40/125] Time 0.197 (0.340) Speed 1013.023 (588.333) Loss 0.8943189979 (0.7661) Prec@1 54.000 (53.000) Prec@5 100.000 (100.000)
# Epoch: [0][50/125] Time 0.198 (0.312) Speed 1010.541 (641.977) Loss 1.7113249302 (0.9551) Prec@1 51.000 (52.600) Prec@5 100.000 (100.000)
# Epoch: [0][60/125] Time 0.198 (0.293) Speed 1011.163 (683.574) Loss 5.8537774086 (1.7716) Prec@1 50.500 (52.250) Prec@5 100.000 (100.000)
# Epoch: [0][70/125] Time 0.198 (0.279) Speed 1011.453 (716.767) Loss 5.7595844269 (2.3413) Prec@1 46.500 (51.429) Prec@5 100.000 (100.000)
# Epoch: [0][80/125] Time 0.198 (0.269) Speed 1011.827 (743.883) Loss 2.8196096420 (2.4011) Prec@1 47.500 (50.938) Prec@5 100.000 (100.000)

######################################################################
# The following models have full support for Channels Last and show 8%-35% performance gains on Volta devices:
# ``alexnet``, ``mnasnet0_5``, ``mnasnet0_75``, ``mnasnet1_0``, ``mnasnet1_3``, ``mobilenet_v2``, ``resnet101``, ``resnet152``, ``resnet18``, ``resnet34``, ``resnet50``, ``resnext50_32x4d``, ``shufflenet_v2_x0_5``, ``shufflenet_v2_x1_0``, ``shufflenet_v2_x1_5``, ``shufflenet_v2_x2_0``, ``squeezenet1_0``, ``squeezenet1_1``, ``vgg11``, ``vgg11_bn``, ``vgg13``, ``vgg13_bn``, ``vgg16``, ``vgg16_bn``, ``vgg19``, ``vgg19_bn``, ``wide_resnet101_2``, ``wide_resnet50_2``
#

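######################################################################
# For example, switching a ``torchvision`` ``resnet50`` to Channels Last
# takes two calls (a minimal sketch, shown as a non-executed block since
# it assumes ``torchvision`` is installed):
#
# .. code-block:: python
#
#     import torchvision.models as models
#     model = models.resnet50()
#     model = model.to(memory_format=torch.channels_last)
#     x = torch.randn(1, 3, 224, 224).contiguous(memory_format=torch.channels_last)
#     out = model(x)  # the format propagates through the convolutional layers
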
######################################################################
# Converting existing models
# --------------------------
#
# Channels Last support is not limited to existing models; any model can be converted to Channels Last and will propagate the format through the graph as soon as the input is formatted correctly.
#

# Need to be done once, after model initialization (or load);
# a small placeholder module is used here so the cell runs standalone
model = torch.nn.Conv2d(8, 4, 3)  # Replace with your model
model = model.to(memory_format=torch.channels_last)

# Need to be done for every input
input = torch.randn(2, 8, 4, 4)  # Replace with your input
input = input.to(memory_format=torch.channels_last)
output = model(input)

#######################################################################
# However, not all operators are fully converted to support Channels Last
# (they usually return contiguous output instead). That means you need to verify the list of used operators
# against the supported operators list https://github.com/pytorch/pytorch/wiki/Operators-with-Channels-Last-support,
# or introduce memory format checks into eager execution mode and run your model.
#
# After running the code below, operators will raise an exception if the output of the
# operator doesn't match the memory format of the input.
#
#
def contains_cl(args):
    # Returns True if any tensor (possibly nested in lists/tuples) is
    # Channels Last but not NCHW-contiguous
    for t in args:
        if isinstance(t, torch.Tensor):
            if t.is_contiguous(memory_format=torch.channels_last) and not t.is_contiguous():
                return True
        elif isinstance(t, list) or isinstance(t, tuple):
            if contains_cl(list(t)):
                return True
    return False


def print_inputs(args, indent=''):
    # Pretty-prints strides, shapes, devices and dtypes of (nested) inputs
    for t in args:
        if isinstance(t, torch.Tensor):
            print(indent, t.stride(), t.shape, t.device, t.dtype)
        elif isinstance(t, list) or isinstance(t, tuple):
            print(indent, type(t))
            print_inputs(list(t), indent=indent + '    ')
        else:
            print(indent, t)


def check_wrapper(fn):
    name = fn.__name__

    def check_cl(*args, **kwargs):
        # Remember whether any input was Channels Last, run the operator,
        # then verify that the property survived
        was_cl = contains_cl(args)
        try:
            result = fn(*args, **kwargs)
        except Exception as e:
            print("`{}` inputs are:".format(name))
            print_inputs(args)
            print('-------------------')
            raise e
        failed = False
        if was_cl:
            if isinstance(result, torch.Tensor):
                if result.dim() == 4 and not result.is_contiguous(memory_format=torch.channels_last):
                    print("`{}` got channels_last input, but output is not channels_last:".format(name),
                          result.shape, result.stride(), result.device, result.dtype)
                    failed = True
            if failed:
                print("`{}` inputs are:".format(name))
                print_inputs(args)
                raise Exception(
                    'Operator `{}` lost channels_last property'.format(name))
        return result
    return check_cl


def attribute(m):
    # Wraps every public callable attribute of ``m`` with the format check
    exclude_functions = ['is_cuda', 'has_names', 'numel',
                         'stride', 'Tensor', 'is_contiguous', '__class__']
    for i in dir(m):
        e = getattr(m, i)
        if i not in exclude_functions and not i.startswith('_') and '__call__' in dir(e):
            try:
                setattr(m, i, check_wrapper(e))
            except Exception as e:
                print(i)
                print(e)


attribute(torch.Tensor)
attribute(torch.nn.functional)
attribute(torch)


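######################################################################
# Note that ``attribute`` monkey-patches ``torch`` globally, so to get the
# unwrapped functions back you need to restart the Python process. A
# minimal usage sketch (``model`` and ``input`` are the placeholders from
# the conversion section above):
#
# .. code-block:: python
#
#     output = model(input)  # raises if any wrapped op drops channels_last
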
######################################################################
# If you find an operator that doesn't support Channels Last tensors
# and you want to contribute, feel free to use the following developer
# guide: https://github.com/pytorch/pytorch/wiki/Writing-memory-format-aware-operators.
#

######################################################################
# Work to do
# ----------
# There are still many things to do, such as:
#
# - Resolving ambiguity of N1HW and NC11 Tensors (illustrated below);
# - Testing of Distributed Training support;
# - Improving operators coverage.
#
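# A quick illustration of the N1HW ambiguity mentioned above (a minimal
# sketch): when C == 1 the stride of the channel dimension does not affect
# the actual memory layout, so both format checks can succeed for the same
# tensor (the exact behavior depends on the PyTorch version).
ambiguous = torch.empty(N, 1, H, W)
print(ambiguous.is_contiguous())                                    # True
print(ambiguous.is_contiguous(memory_format=torch.channels_last))   # may also be True

######################################################################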
# If you have feedback and/or suggestions for improvement, please let us
# know by creating `an issue <https://github.com/pytorch/pytorch/issues>`_.