22 changes: 20 additions & 2 deletions recipes_source/recipes/tuning_guide.py
@@ -240,7 +240,7 @@ def fused_gelu(x):
# Use oneDNN Graph with TorchScript for inference
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# oneDNN Graph can significantly boost inference performance. It fuses compute-intensive operations such as convolution and matmul with their neighboring operations.
- # Currently, it's supported as an experimental feature for Float32 data-type.
+ # In PyTorch 2.0, it is supported as a beta feature for Float32 & BFloat16 data-types.
# oneDNN Graph receives the model’s graph and identifies candidates for operator fusion with respect to the shape of the example input.
# A model should be JIT-traced with an example input.
# A speed-up is then observed after a couple of warm-up iterations for inputs with the same shape as the example input.
@@ -250,7 +250,7 @@ def fused_gelu(x):
torch.jit.enable_onednn_fusion(True)

###############################################################################
- # Using the oneDNN Graph API requires just one extra line of code.
+ # Using the oneDNN Graph API requires just one extra line of code for inference with Float32.
# If you are using oneDNN Graph, please avoid calling ``torch.jit.optimize_for_inference``.

# sample input should be of the same shape as expected inputs
@@ -273,6 +273,24 @@ def fused_gelu(x):
# speedup would be observed after warmup runs
traced_model(*sample_input)
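
###############################################################################
# For reference, the end-to-end Float32 flow looks like the following.
# This is a minimal sketch: the ResNet-50 model from torchvision and the
# input shape are illustrative assumptions, not requirements of oneDNN Graph.

import torch
import torchvision.models as models

torch.jit.enable_onednn_fusion(True)

model = models.resnet50().eval()
# sample input should be of the same shape as expected inputs
sample_input = [torch.rand(32, 3, 224, 224)]

with torch.no_grad():
    traced_model = torch.jit.trace(model, sample_input)
    traced_model = torch.jit.freeze(traced_model)
    # a couple of warm-up runs let oneDNN Graph rewrite the traced graph
    traced_model(*sample_input)
    traced_model(*sample_input)
    # a speed-up is observed in subsequent runs
    traced_model(*sample_input)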

###############################################################################
# While the JIT fuser for oneDNN Graph also supports inference with the BFloat16 data-type,
# a performance benefit from oneDNN Graph is only seen on machines with the AVX512_BF16
# instruction set architecture (ISA).
# The following code snippet serves as an example of using the BFloat16 data-type for inference with oneDNN Graph:

# AMP for JIT mode is enabled by default and diverges from its eager-mode counterpart, so we disable it here
torch._C._jit_set_autocast_mode(False)

# ``model`` and ``example_input`` are assumed to be defined earlier, as in the Float32 example above
with torch.no_grad(), torch.cpu.amp.autocast(cache_enabled=False, dtype=torch.bfloat16):
    model = torch.jit.trace(model, (example_input,))
    model = torch.jit.freeze(model)
    # a couple of warm-up runs
    model(example_input)
    model(example_input)
    # a speed-up is observed in subsequent runs
    model(example_input)
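
###############################################################################
# To check whether oneDNN Graph fusion actually kicked in, you can inspect
# the most recently executed optimized graph. This is a debugging sketch:
# ``torch.jit.last_executed_optimized_graph`` is an internal debugging helper,
# and the ``LlgaFusionGroup`` node name is an implementation detail of the
# oneDNN Graph fuser that may change across PyTorch releases.

model(example_input)
optimized_graph = torch.jit.last_executed_optimized_graph()
# count the operator-fusion groups created by the oneDNN Graph fuser
fusion_groups = [n for n in optimized_graph.nodes() if "LlgaFusionGroup" in n.kind()]
print(f"oneDNN Graph fusion groups: {len(fusion_groups)}")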


###############################################################################
# Train a model on CPU with PyTorch DistributedDataParallel (DDP) functionality
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~