
Commit f4862ee

sanchitintel and Svetlana Karslioglu authored

Update oneDNN Graph JIT Fuser recipe with BF16 dtype (#2229)

* Update oneDNN Graph JIT Fuser recipe with BF16 dtype

Co-authored-by: Svetlana Karslioglu <[email protected]>
1 parent f2fbe6b commit f4862ee

File tree

1 file changed: +20 -2 lines changed


recipes_source/recipes/tuning_guide.py

Lines changed: 20 additions & 2 deletions
@@ -240,7 +240,7 @@ def fused_gelu(x):
 # Use oneDNN Graph with TorchScript for inference
 # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 # oneDNN Graph can significantly boost inference performance. It fuses some compute-intensive operations such as convolution, matmul with their neighbor operations.
-# Currently, it's supported as an experimental feature for Float32 data-type.
+# In PyTorch 2.0, it is supported as a beta feature for Float32 & BFloat16 data-types.
 # oneDNN Graph receives the model’s graph and identifies candidates for operator-fusion with respect to the shape of the example input.
 # A model should be JIT-traced using an example input.
 # Speed-up would then be observed after a couple of warm-up iterations for inputs with the same shape as the example input.
@@ -250,7 +250,7 @@ def fused_gelu(x):
 torch.jit.enable_onednn_fusion(True)
 
 ###############################################################################
-# Using the oneDNN Graph API requires just one extra line of code.
+# Using the oneDNN Graph API requires just one extra line of code for inference with Float32.
 # If you are using oneDNN Graph, please avoid calling ``torch.jit.optimize_for_inference``.
 
 # sample input should be of the same shape as expected inputs
@@ -273,6 +273,24 @@ def fused_gelu(x):
 # speedup would be observed after warmup runs
 traced_model(*sample_input)
 
+###############################################################################
+# While the JIT fuser for oneDNN Graph also supports inference with BFloat16 datatype,
+# performance benefit with oneDNN Graph is only exhibited by machines with AVX512_BF16 ISA.
+# The following code snippet serves as an example of using BFloat16 datatype for inference with oneDNN Graph:
+
+# AMP for JIT mode is enabled by default, and is divergent with its eager mode counterpart
+torch._C._jit_set_autocast_mode(False)
+
+with torch.no_grad(), torch.cpu.amp.autocast(cache_enabled=False, dtype=torch.bfloat16):
+    model = torch.jit.trace(model, (example_input))
+    model = torch.jit.freeze(model)
+    # a couple of warmup runs
+    model(example_input)
+    model(example_input)
+    # speedup would be observed in subsequent runs.
+    model(example_input)
+
+
 ###############################################################################
 # Train a model on CPU with PyTorch DistributedDataParallel(DDP) functionality
 # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
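For context, the Float32 flow that this recipe section describes (partly elided between the hunks above) follows the same trace, freeze, and warm-up pattern as the new BFloat16 snippet. Below is a minimal, self-contained sketch of that flow; the model and sample_input here are hypothetical stand-ins rather than the recipe's actual model:

import torch
import torch.nn as nn

# Hypothetical stand-in model and input; the recipe traces a user-provided
# Float32 model with an example input of the expected shape.
model = nn.Sequential(nn.Conv2d(3, 8, kernel_size=3), nn.ReLU(), nn.Conv2d(8, 8, kernel_size=3))
model.eval()
sample_input = [torch.rand(1, 3, 32, 32)]

# Enable oneDNN Graph fusion before tracing - the "one extra line" the recipe mentions.
torch.jit.enable_onednn_fusion(True)

with torch.no_grad():
    # Trace with the example input, then freeze so the fuser can rewrite the graph.
    traced_model = torch.jit.trace(model, sample_input)
    traced_model = torch.jit.freeze(traced_model)
    # A couple of warm-up runs trigger operator fusion and kernel compilation.
    traced_model(*sample_input)
    traced_model(*sample_input)
    # Speed-up is observed on subsequent runs with inputs of the same shape.
    output = traced_model(*sample_input)

Note that the new BFloat16 snippet first calls torch._C._jit_set_autocast_mode(False): as its comment explains, JIT-mode AMP is enabled by default but diverges from its eager-mode counterpart, so it is disabled before tracing under the torch.cpu.amp.autocast context.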
