Commit 2efdc8e

Authored by Jessica Lin
Merge pull request #958 from VitalyFedyunin/channels_last
Add Channels Last tutorial
2 parents 5cc89d8 + 9bc982a commit 2efdc8e

File tree

5 files changed: +305 -0 lines changed

_static/img/memory_format_logo.png (2.87 KB)

index.rst

Lines changed: 6 additions & 0 deletions
@@ -175,6 +175,11 @@ Frontend APIs
    :tooltip: Named Tensor
    :description: :doc:`intermediate/named_tensor_tutorial`

+.. customgalleryitem::
+   :figure: /_static/img/memory_format_logo.png
+   :tooltip: Memory Format
+   :description: :doc:`intermediate/memory_format_tutorial`
+
 .. customgalleryitem::
    :tooltip: Using the PyTorch C++ Frontend
    :figure: /_static/img/cpp-pytorch.png
@@ -348,6 +353,7 @@ Parallel and Distributed Training
    :caption: Frontend APIs

    intermediate/named_tensor_tutorial
+   intermediate/memory_format_tutorial
    advanced/cpp_frontend
    advanced/cpp_extension
    advanced/torch_script_custom_ops
intermediate_source/memory_format_tutorial.py

Lines changed: 299 additions & 0 deletions
@@ -0,0 +1,299 @@
# -*- coding: utf-8 -*-
"""
(experimental) Channels Last Memory Format in PyTorch
*******************************************************
**Author**: `Vitaly Fedyunin <https://github.com/VitalyFedyunin>`_

What is Channels Last
---------------------

Channels Last memory format is an alternative way of ordering NCHW tensors in memory while preserving the dimensions ordering. Channels Last tensors are ordered in such a way that channels become the densest dimension (i.e. images are stored pixel-per-pixel).

For example, the classic (contiguous) storage of an NCHW tensor (in our case, two 2x2 images with 3 color channels) looks like this:

.. figure:: /_static/img/classic_memory_format.png
   :alt: classic_memory_format

Channels Last memory format orders the data differently:

.. figure:: /_static/img/channels_last_memory_format.png
   :alt: channels_last_memory_format

PyTorch supports memory formats (and provides backward compatibility with existing models, including eager, JIT, and TorchScript) by utilizing the existing strides structure.
For example, a 10x3x16x16 batch in Channels Last format will have strides equal to (768, 1, 48, 3).
"""
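
######################################################################
# The Channels Last strides follow directly from the layout: for an N x C x H x W
# tensor they are (H*W*C, 1, W*C, C). A quick arithmetic check for the 10x3x16x16
# example quoted above:
c, h, w = 3, 16, 16
print((h * w * c, 1, w * c, c))  # prints (768, 1, 48, 3)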

######################################################################
# Channels Last memory format is implemented for 4D NCHW Tensors only.
#

import torch
N, C, H, W = 10, 3, 32, 32

######################################################################
# Memory Format API
# -----------------------
#
# Here is how to convert tensors between contiguous and channels
# last memory formats.

######################################################################
# Classic PyTorch contiguous tensor
x = torch.empty(N, C, H, W)
print(x.stride())  # Outputs: (3072, 1024, 32, 1)

######################################################################
# Conversion operator
x = x.contiguous(memory_format=torch.channels_last)
print(x.shape)  # Outputs: torch.Size([10, 3, 32, 32]), dimensions order preserved
print(x.stride())  # Outputs: (3072, 1, 96, 3)

######################################################################
# Back to contiguous
x = x.contiguous(memory_format=torch.contiguous_format)
print(x.stride())  # Outputs: (3072, 1024, 32, 1)

######################################################################
# Alternative option
x = x.to(memory_format=torch.channels_last)
print(x.stride())  # Outputs: (3072, 1, 96, 3)

######################################################################
# Format checks
print(x.is_contiguous(memory_format=torch.channels_last))  # Outputs: True
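
######################################################################
# Note that the two checks are complementary here: a tensor in Channels Last
# format (with more than one channel) is not contiguous in the default NCHW sense.
print(x.is_contiguous())  # Outputs: False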

######################################################################
# Create as Channels Last
x = torch.empty(N, C, H, W, memory_format=torch.channels_last)
print(x.stride())  # Outputs: (3072, 1, 96, 3)

######################################################################
# ``clone`` preserves memory format
y = x.clone()
print(y.stride())  # Outputs: (3072, 1, 96, 3)

######################################################################
# ``to``, ``cuda``, ``float`` ... preserve memory format
if torch.cuda.is_available():
    y = x.cuda()
    print(y.stride())  # Outputs: (3072, 1, 96, 3)
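
######################################################################
# The same holds for dtype casts, since ``torch.preserve_format`` is the default
# ``memory_format`` argument of ``Tensor.to`` and its shortcuts (a small added
# check, not part of the original tutorial):
y = x.half()
print(y.stride())  # (3072, 1, 96, 3) -- memory format preserved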

######################################################################
# ``empty_like``, ``*_like`` operators preserve memory format
y = torch.empty_like(x)
print(y.stride())  # Outputs: (3072, 1, 96, 3)

######################################################################
# Pointwise operators preserve memory format
z = x + y
print(z.stride())  # Outputs: (3072, 1, 96, 3)

######################################################################
# Conv, Batchnorm modules support Channels Last
# (only works for cuDNN >= 7.6)
if torch.cuda.is_available() and torch.backends.cudnn.version() >= 7603:
    input = torch.randint(1, 10, (2, 8, 4, 4), dtype=torch.float32, device="cuda", requires_grad=True)
    model = torch.nn.Conv2d(8, 4, 3).cuda().float()

    input = input.contiguous(memory_format=torch.channels_last)
    model = model.to(memory_format=torch.channels_last)  # Module parameters need to be Channels Last

    out = model(input)
    print(out.is_contiguous(memory_format=torch.channels_last))  # Outputs: True
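
######################################################################
# The same claim covers ``Batchnorm``; as a small added sketch (not in the
# original tutorial), the pattern above also works for a ``Conv2d`` +
# ``BatchNorm2d`` stack:
if torch.cuda.is_available() and torch.backends.cudnn.version() >= 7603:
    block = torch.nn.Sequential(
        torch.nn.Conv2d(8, 4, 3),
        torch.nn.BatchNorm2d(4),
    ).cuda().float().to(memory_format=torch.channels_last)
    inp = torch.rand(2, 8, 8, 8, device="cuda").contiguous(memory_format=torch.channels_last)
    print(block(inp).is_contiguous(memory_format=torch.channels_last))  # Expected: True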

######################################################################
# Performance Gains
# -------------------------------------------------------------------------------------------
# The most significant performance gains are observed on NVIDIA hardware with
# Tensor Cores support. We were able to achieve over 22% performance gains while
# running AMP (Automatic Mixed Precision) training scripts supplied by NVIDIA:
# https://github.com/NVIDIA/apex.
#
# ``python main_amp.py -a resnet50 --b 200 --workers 16 --opt-level O2 ./data``

# opt_level = O2
# keep_batchnorm_fp32 = None <class 'NoneType'>
# loss_scale = None <class 'NoneType'>
# CUDNN VERSION: 7603
# => creating model 'resnet50'
# Selected optimization level O2: FP16 training with FP32 batchnorm and FP32 master weights.
# Defaults for this optimization level are:
# enabled : True
# opt_level : O2
# cast_model_type : torch.float16
# patch_torch_functions : False
# keep_batchnorm_fp32 : True
# master_weights : True
# loss_scale : dynamic
# Processing user overrides (additional kwargs that are not None)...
# After processing overrides, optimization options are:
# enabled : True
# opt_level : O2
# cast_model_type : torch.float16
# patch_torch_functions : False
# keep_batchnorm_fp32 : True
# master_weights : True
# loss_scale : dynamic
# Epoch: [0][10/125] Time 0.866 (0.866) Speed 230.949 (230.949) Loss 0.6735125184 (0.6735) Prec@1 61.000 (61.000) Prec@5 100.000 (100.000)
# Epoch: [0][20/125] Time 0.259 (0.562) Speed 773.481 (355.693) Loss 0.6968704462 (0.6852) Prec@1 55.000 (58.000) Prec@5 100.000 (100.000)
# Epoch: [0][30/125] Time 0.258 (0.461) Speed 775.089 (433.965) Loss 0.7877287269 (0.7194) Prec@1 51.500 (55.833) Prec@5 100.000 (100.000)
# Epoch: [0][40/125] Time 0.259 (0.410) Speed 771.710 (487.281) Loss 0.8285319805 (0.7467) Prec@1 48.500 (54.000) Prec@5 100.000 (100.000)
# Epoch: [0][50/125] Time 0.260 (0.380) Speed 770.090 (525.908) Loss 0.7370464802 (0.7447) Prec@1 56.500 (54.500) Prec@5 100.000 (100.000)
# Epoch: [0][60/125] Time 0.258 (0.360) Speed 775.623 (555.728) Loss 0.7592862844 (0.7472) Prec@1 51.000 (53.917) Prec@5 100.000 (100.000)
# Epoch: [0][70/125] Time 0.258 (0.345) Speed 774.746 (579.115) Loss 1.9698858261 (0.9218) Prec@1 49.500 (53.286) Prec@5 100.000 (100.000)
# Epoch: [0][80/125] Time 0.260 (0.335) Speed 770.324 (597.659) Loss 2.2505953312 (1.0879) Prec@1 50.500 (52.938) Prec@5 100.000 (100.000)

######################################################################
# Passing ``--channels-last true`` allows running a model in the Channels Last
# format, with the observed 22% performance gain.
#
# ``python main_amp.py -a resnet50 --b 200 --workers 16 --opt-level O2 --channels-last true ./data``

# opt_level = O2
# keep_batchnorm_fp32 = None <class 'NoneType'>
# loss_scale = None <class 'NoneType'>
#
# CUDNN VERSION: 7603
#
# => creating model 'resnet50'
# Selected optimization level O2: FP16 training with FP32 batchnorm and FP32 master weights.
#
# Defaults for this optimization level are:
# enabled : True
# opt_level : O2
# cast_model_type : torch.float16
# patch_torch_functions : False
# keep_batchnorm_fp32 : True
# master_weights : True
# loss_scale : dynamic
# Processing user overrides (additional kwargs that are not None)...
# After processing overrides, optimization options are:
# enabled : True
# opt_level : O2
# cast_model_type : torch.float16
# patch_torch_functions : False
# keep_batchnorm_fp32 : True
# master_weights : True
# loss_scale : dynamic
#
# Epoch: [0][10/125] Time 0.767 (0.767) Speed 260.785 (260.785) Loss 0.7579724789 (0.7580) Prec@1 53.500 (53.500) Prec@5 100.000 (100.000)
# Epoch: [0][20/125] Time 0.198 (0.482) Speed 1012.135 (414.716) Loss 0.7007197738 (0.7293) Prec@1 49.000 (51.250) Prec@5 100.000 (100.000)
# Epoch: [0][30/125] Time 0.198 (0.387) Speed 1010.977 (516.198) Loss 0.7113101482 (0.7233) Prec@1 55.500 (52.667) Prec@5 100.000 (100.000)
# Epoch: [0][40/125] Time 0.197 (0.340) Speed 1013.023 (588.333) Loss 0.8943189979 (0.7661) Prec@1 54.000 (53.000) Prec@5 100.000 (100.000)
# Epoch: [0][50/125] Time 0.198 (0.312) Speed 1010.541 (641.977) Loss 1.7113249302 (0.9551) Prec@1 51.000 (52.600) Prec@5 100.000 (100.000)
# Epoch: [0][60/125] Time 0.198 (0.293) Speed 1011.163 (683.574) Loss 5.8537774086 (1.7716) Prec@1 50.500 (52.250) Prec@5 100.000 (100.000)
# Epoch: [0][70/125] Time 0.198 (0.279) Speed 1011.453 (716.767) Loss 5.7595844269 (2.3413) Prec@1 46.500 (51.429) Prec@5 100.000 (100.000)
# Epoch: [0][80/125] Time 0.198 (0.269) Speed 1011.827 (743.883) Loss 2.8196096420 (2.4011) Prec@1 47.500 (50.938) Prec@5 100.000 (100.000)
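
######################################################################
# The numbers above come from NVIDIA's training script. As a rough, self-contained
# way to probe your own setup, the added sketch below times a single fp16
# ``Conv2d`` forward pass in both memory formats (layer and input sizes are
# arbitrary; results depend on the GPU, cuDNN version, and Tensor Core support):

import time

def time_conv2d(memory_format, iters=50):
    conv = torch.nn.Conv2d(64, 128, 3, padding=1).cuda().half().to(memory_format=memory_format)
    x = torch.randn(32, 64, 56, 56, device="cuda", dtype=torch.half).to(memory_format=memory_format)
    for _ in range(10):  # warm-up, lets cuDNN pick its algorithms before timing
        conv(x)
    torch.cuda.synchronize()
    start = time.time()
    for _ in range(iters):
        conv(x)
    torch.cuda.synchronize()
    return (time.time() - start) / iters

if torch.cuda.is_available() and torch.backends.cudnn.version() >= 7603:
    print("contiguous:    {:.6f} sec per forward".format(time_conv2d(torch.contiguous_format)))
    print("channels_last: {:.6f} sec per forward".format(time_conv2d(torch.channels_last)))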

######################################################################
# The following models fully support Channels Last and show 8%-35% performance
# gains on Volta devices:
# ``alexnet``, ``mnasnet0_5``, ``mnasnet0_75``, ``mnasnet1_0``, ``mnasnet1_3``, ``mobilenet_v2``, ``resnet101``, ``resnet152``, ``resnet18``, ``resnet34``, ``resnet50``, ``resnext50_32x4d``, ``shufflenet_v2_x0_5``, ``shufflenet_v2_x1_0``, ``shufflenet_v2_x1_5``, ``shufflenet_v2_x2_0``, ``squeezenet1_0``, ``squeezenet1_1``, ``vgg11``, ``vgg11_bn``, ``vgg13``, ``vgg13_bn``, ``vgg16``, ``vgg16_bn``, ``vgg19``, ``vgg19_bn``, ``wide_resnet101_2``, ``wide_resnet50_2``
#

######################################################################
# Converting existing models
# --------------------------
#
# Channels Last support is not limited to existing models; any model can be
# converted to Channels Last, and the format will propagate through the graph as
# soon as the input is formatted correctly.
#

# Needs to be done once, after model initialization (or load)
model = model.to(memory_format=torch.channels_last)  # Replace with your model

# Needs to be done for every input
input = input.to(memory_format=torch.channels_last)  # Replace with your input
output = model(input)
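
######################################################################
# As a concrete, hedged end-to-end example (assuming ``torchvision`` is installed;
# any CNN taking 4D inputs works the same way):

import torchvision.models as models

net = models.resnet18().to(memory_format=torch.channels_last)
batch = torch.rand(2, 3, 224, 224).to(memory_format=torch.channels_last)
result = net(batch)
print(result.shape)  # torch.Size([2, 1000]); the final output is 2D, so the memory
                     # format question only applies to the 4D feature maps inside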

#######################################################################
# However, not all operators are fully converted to support Channels Last yet
# (they usually return a contiguous output instead). That means you need to verify
# the list of operators you use against the supported operators list
# https://github.com/pytorch/pytorch/wiki/Operators-with-Channels-Last-support,
# or introduce memory format checks into eager execution mode and run your model.
#
# After running the code below, operators will raise an exception if the output of
# the operator doesn't match the memory format of the input.
#
#
def contains_cl(args):
    for t in args:
        if isinstance(t, torch.Tensor):
            if t.is_contiguous(memory_format=torch.channels_last) and not t.is_contiguous():
                return True
        elif isinstance(t, list) or isinstance(t, tuple):
            if contains_cl(list(t)):
                return True
    return False


def print_inputs(args, indent=''):
    for t in args:
        if isinstance(t, torch.Tensor):
            print(indent, t.stride(), t.shape, t.device, t.dtype)
        elif isinstance(t, list) or isinstance(t, tuple):
            print(indent, type(t))
            print_inputs(list(t), indent=indent + ' ')
        else:
            print(indent, t)


def check_wrapper(fn):
    name = fn.__name__

    def check_cl(*args, **kwargs):
        was_cl = contains_cl(args)
        try:
            result = fn(*args, **kwargs)
        except Exception as e:
            print("`{}` inputs are:".format(name))
            print_inputs(args)
            print('-------------------')
            raise e
        failed = False
        if was_cl:
            if isinstance(result, torch.Tensor):
                if result.dim() == 4 and not result.is_contiguous(memory_format=torch.channels_last):
                    print("`{}` got channels_last input, but output is not channels_last:".format(name),
                          result.shape, result.stride(), result.device, result.dtype)
                    failed = True
        if failed:
            print("`{}` inputs are:".format(name))
            print_inputs(args)
            raise Exception(
                'Operator `{}` lost channels_last property'.format(name))
        return result
    return check_cl


def attribute(m):
    for i in dir(m):
        e = getattr(m, i)
        exclude_functions = ['is_cuda', 'has_names', 'numel',
                             'stride', 'Tensor', 'is_contiguous', '__class__']
        if i not in exclude_functions and not i.startswith('_') and '__call__' in dir(e):
            try:
                setattr(m, i, check_wrapper(e))
            except Exception as e:
                print(i)
                print(e)


attribute(torch.Tensor)
attribute(torch.nn.functional)
attribute(torch)
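
######################################################################
# A minimal usage sketch (an addition; which operator, if any, trips the check
# depends on your PyTorch build): with the wrappers installed, run your model, or
# even a single operator, on Channels Last inputs and watch for the exception.
x_cl = torch.rand(2, 8, 4, 4).contiguous(memory_format=torch.channels_last)
try:
    y_cl = torch.nn.functional.relu(x_cl)
    print("relu kept channels_last:", y_cl.is_contiguous(memory_format=torch.channels_last))
except Exception as e:
    print("channels_last was lost:", e)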

######################################################################
# If you find an operator that doesn't support Channels Last tensors and you want
# to contribute, feel free to use the following developer guide:
# https://github.com/pytorch/pytorch/wiki/Writing-memory-format-aware-operators.
#

######################################################################
# Work to do
# ----------
# There are still many things to do, such as:
#
# - Resolving ambiguity of N1HW and NC11 Tensors;
# - Testing of Distributed Training support;
# - Improving operators coverage.
#
# If you have feedback and/or suggestions for improvement, please let us
# know by creating `an issue <https://github.com/pytorch/pytorch/issues>`_.
