Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion beginner_source/basics/buildmodel_tutorial.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ def forward(self, x):
# Model Layers
# -------------------------
#
# Lets break down the layers in the FashionMNIST model. To illustrate it, we
# Let's break down the layers in the FashionMNIST model. To illustrate it, we
# will take a sample minibatch of 3 images of size 28x28 and see what happens to it as
# we pass it through the network.

Expand Down
2 changes: 1 addition & 1 deletion beginner_source/basics/data_tutorial.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@
# -------------------
#
# Here is an example of how to load the `Fashion-MNIST <https://research.zalando.com/project/fashion_mnist/fashion_mnist/>`_ dataset from TorchVision.
# Fashion-MNIST is a dataset of Zalando’s article images consisting of of 60,000 training examples and 10,000 test examples.
# Fashion-MNIST is a dataset of Zalando’s article images consisting of 60,000 training examples and 10,000 test examples.
# Each example comprises a 28×28 grayscale image and an associated label from one of 10 classes.
#
# We load the `FashionMNIST Dataset <https://pytorch.org/vision/stable/datasets.html#fashion-mnist>`_ with the following parameters:
Expand Down
2 changes: 1 addition & 1 deletion beginner_source/dcgan_faces_tutorial.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@
# with the discriminator. Let :math:`x` be data representing an image.
# :math:`D(x)` is the discriminator network which outputs the (scalar)
# probability that :math:`x` came from training data rather than the
# generator. Here, since we are dealing with images the input to
# generator. Here, since we are dealing with images, the input to
# :math:`D(x)` is an image of CHW size 3x64x64. Intuitively, :math:`D(x)`
# should be HIGH when :math:`x` comes from training data and LOW when
# :math:`x` comes from the generator. :math:`D(x)` can also be thought of
Expand Down
2 changes: 1 addition & 1 deletion beginner_source/text_sentiment_ngrams_tutorial.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@
#
# Before sending to the model, ``collate_fn`` function works on a batch of samples generated from ``DataLoader``. The input to ``collate_fn`` is a batch of data with the batch size in ``DataLoader``, and ``collate_fn`` processes them according to the data processing pipelines declared previously. Pay attention here and make sure that ``collate_fn`` is declared as a top level def. This ensures that the function is available in each worker.
#
# In this example, the text entries in the original data batch input are packed into a list and concatenated as a single tensor for the input of ``nn.EmbeddingBag``. The offset is a tensor of delimiters to represent the beginning index of the individual sequence in the text tensor. Label is a tensor saving the labels of indidividual text entries.
# In this example, the text entries in the original data batch input are packed into a list and concatenated as a single tensor for the input of ``nn.EmbeddingBag``. The offset is a tensor of delimiters to represent the beginning index of the individual sequence in the text tensor. Label is a tensor saving the labels of individual text entries.


from torch.utils.data import DataLoader
Expand Down
63 changes: 12 additions & 51 deletions intermediate_source/pipeline_tutorial.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,27 +148,24 @@ def forward(self, x):
# efficient batch processing.
#

import io
import torch
from torchtext.utils import download_from_url, extract_archive
from torchtext.datasets import WikiText2
from torchtext.data.utils import get_tokenizer
from torchtext.vocab import build_vocab_from_iterator

url = 'https://s3.amazonaws.com/research.metamind.io/wikitext/wikitext-2-v1.zip'
test_filepath, valid_filepath, train_filepath = extract_archive(download_from_url(url))
train_iter = WikiText2(split='train')
tokenizer = get_tokenizer('basic_english')
vocab = build_vocab_from_iterator(map(tokenizer,
iter(io.open(train_filepath,
encoding="utf8"))))
vocab = build_vocab_from_iterator(map(tokenizer, train_iter), specials=["<unk>"])
vocab.set_default_index(vocab["<unk>"])

def data_process(raw_text_iter):
data = [torch.tensor([vocab[token] for token in tokenizer(item)],
dtype=torch.long) for item in raw_text_iter]
data = [torch.tensor(vocab(tokenizer(item)), dtype=torch.long) for item in raw_text_iter]
return torch.cat(tuple(filter(lambda t: t.numel() > 0, data)))

train_data = data_process(iter(io.open(train_filepath, encoding="utf8")))
val_data = data_process(iter(io.open(valid_filepath, encoding="utf8")))
test_data = data_process(iter(io.open(test_filepath, encoding="utf8")))
train_iter, val_iter, test_iter = WikiText2()
train_data = data_process(train_iter)
val_data = data_process(val_iter)
test_data = data_process(test_iter)

device = torch.device("cuda")

Expand Down Expand Up @@ -244,7 +241,7 @@ def get_batch(source, i):
# allows the Pipe to work with only two partitions and avoid any
# cross-partition overheads.

ntokens = len(vocab.stoi) # the size of vocabulary
ntokens = len(vocab) # the size of vocabulary
emsize = 4096 # embedding dimension
nhid = 4096 # the dimension of the feedforward network model in nn.TransformerEncoder
nlayers = 12 # the number of nn.TransformerEncoderLayer in nn.TransformerEncoder
Expand Down Expand Up @@ -330,7 +327,7 @@ def train():
model.train() # Turn on the train mode
total_loss = 0.
start_time = time.time()
ntokens = len(vocab.stoi)
ntokens = len(vocab)

# Train only for 50 batches to keep script execution time low.
nbatches = min(50 * bptt, train_data.size(0) - 1)
Expand Down Expand Up @@ -366,7 +363,7 @@ def train():
def evaluate(eval_model, data_source):
eval_model.eval() # Turn on the evaluation mode
total_loss = 0.
ntokens = len(vocab.stoi)
ntokens = len(vocab)
# Evaluate only for 50 batches to keep script execution time low.
nbatches = min(50 * bptt, data_source.size(0) - 1)
with torch.no_grad():
Expand Down Expand Up @@ -418,39 +415,3 @@ def evaluate(eval_model, data_source):
print('| End of training | test loss {:5.2f} | test ppl {:8.2f}'.format(
test_loss, math.exp(test_loss)))
print('=' * 89)


######################################################################
# Output
# ------
#


######################################################################
#.. code-block:: py
#
# Total parameters in model: 1,847,087,215
# | epoch 1 | 10/ 50 batches | lr 5.00 | ms/batch 2387.45 | loss 42.16 | ppl 2036775646369743616.00
# | epoch 1 | 20/ 50 batches | lr 5.00 | ms/batch 2150.93 | loss 48.24 | ppl 891334049215401558016.00
# | epoch 1 | 30/ 50 batches | lr 5.00 | ms/batch 2155.23 | loss 34.66 | ppl 1125676483188404.62
# | epoch 1 | 40/ 50 batches | lr 5.00 | ms/batch 2158.42 | loss 38.87 | ppl 76287208340888368.00
# -----------------------------------------------------------------------------------------
# | end of epoch 1 | time: 119.65s | valid loss 2.95 | valid ppl 19.15
# -----------------------------------------------------------------------------------------
# | epoch 2 | 10/ 50 batches | lr 4.51 | ms/batch 2376.16 | loss 34.92 | ppl 1458001430957104.00
# | epoch 2 | 20/ 50 batches | lr 4.51 | ms/batch 2160.96 | loss 34.75 | ppl 1232463826541886.50
# | epoch 2 | 30/ 50 batches | lr 4.51 | ms/batch 2160.66 | loss 28.10 | ppl 1599598251136.51
# | epoch 2 | 40/ 50 batches | lr 4.51 | ms/batch 2160.07 | loss 20.25 | ppl 621174306.77
# -----------------------------------------------------------------------------------------
# | end of epoch 2 | time: 119.76s | valid loss 0.87 | valid ppl 2.38
# -----------------------------------------------------------------------------------------
# | epoch 3 | 10/ 50 batches | lr 4.29 | ms/batch 2376.49 | loss 13.20 | ppl 537727.23
# | epoch 3 | 20/ 50 batches | lr 4.29 | ms/batch 2160.12 | loss 10.98 | ppl 58548.58
# | epoch 3 | 30/ 50 batches | lr 4.29 | ms/batch 2160.05 | loss 12.01 | ppl 164152.79
# | epoch 3 | 40/ 50 batches | lr 4.29 | ms/batch 2160.03 | loss 10.63 | ppl 41348.00
# -----------------------------------------------------------------------------------------
# | end of epoch 3 | time: 119.76s | valid loss 0.78 | valid ppl 2.17
# -----------------------------------------------------------------------------------------
# =========================================================================================
# | End of training | test loss 0.69 | test ppl 1.99
# =========================================================================================
2 changes: 1 addition & 1 deletion intermediate_source/seq2seq_translation_tutorial.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@
# the networks later. To keep track of all this we will use a helper class
# called ``Lang`` which has word → index (``word2index``) and index → word
# (``index2word``) dictionaries, as well as a count of each word
# ``word2count`` to use to later replace rare words.
# ``word2count`` which will be used to replace rare words later.
#

SOS_token = 0
Expand Down
63 changes: 63 additions & 0 deletions recipes_source/recipes/tuning_guide.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,69 @@ def fused_gelu(x):
# `torch.autograd.gradgradcheck <https://pytorch.org/docs/stable/autograd.html#torch.autograd.gradgradcheck>`_
#

###############################################################################
# CPU specific optimizations
# --------------------------

###############################################################################
# Utilize Non-Uniform Memory Access (NUMA) Controls
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# NUMA or non-uniform memory access is a memory layout design used in data center machines meant to take advantage of locality of memory in multi-socket machines with multiple memory controllers and blocks. Generally speaking, all deep learning workloads, training or inference, get better performance without accessing hardware resources across NUMA nodes. Thus, inference can be run with multiple instances, each instance runs on one socket, to raise throughput. For training tasks on single node, distributed training is recommended to make each training process run on one socket.
#
# In general cases the following command executes a PyTorch script on cores on the Nth node only, and avoids cross-socket memory access to reduce memory access overhead.

# numactl --cpunodebind=N --membind=N python <pytorch_script>

###############################################################################
# More detailed descriptions can be found `here <https://software.intel.com/content/www/us/en/develop/articles/how-to-get-better-performance-on-pytorchcaffe2-with-intel-acceleration.html>`_.

###############################################################################
# Utilize OpenMP
# ~~~~~~~~~~~~~~
# OpenMP is utilized to bring better performance for parallel computation tasks.
# OMP_NUM_THREADS is the easiest switch that can be used to accelerate computations. It determines number of threads used for OpenMP computations.
# CPU affinity setting controls how workloads are distributed over multiple cores. It affects communication overhead, cache line invalidation overhead, or page thrashing, thus proper setting of CPU affinity brings performance benefits. GOMP_CPU_AFFINITY or KMP_AFFINITY determines how to bind OpenMP* threads to physical processing units. Detailed information can be found `here <https://software.intel.com/content/www/us/en/develop/articles/how-to-get-better-performance-on-pytorchcaffe2-with-intel-acceleration.html>`_.

###############################################################################
# With the following command, PyTorch run the task on N OpenMP threads.

# export OMP_NUM_THREADS=N

###############################################################################
# Typically, the following environment variables are used to set for CPU affinity with GNU OpenMP implementation. OMP_PROC_BIND specifies whether threads may be moved between processors. Setting it to CLOSE keeps OpenMP threads close to the primary thread in contiguous place partitions. OMP_SCHEDULE determines how OpenMP threads are scheduled. GOMP_CPU_AFFINITY binds threads to specific CPUs.

# export OMP_SCHEDULE=STATIC
# export OMP_PROC_BIND=CLOSE
# export GOMP_CPU_AFFINITY="N-M"

###############################################################################
# Intel OpenMP Runtime Library (libiomp)
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# By default, PyTorch uses GNU OpenMP (GNU libgomp) for parallel computation. On Intel platforms, Intel OpenMP Runtime Library (libiomp) provides OpenMP API specification support. It sometimes brings more performance benefits compared to libgomp. Utilizing environment variable LD_PRELOAD can switch OpenMP library to libiomp:

# export LD_PRELOAD=<path>/libiomp5.so:$LD_PRELOAD

###############################################################################
# Similar to CPU affinity settings in GNU OpenMP, environment variables are provided in libiomp to control CPU affinity settings.
# KMP_AFFINITY binds OpenMP threads to physical processing units. KMP_BLOCKTIME sets the time, in milliseconds, that a thread should wait, after completing the execution of a parallel region, before sleeping. In most cases, setting KMP_BLOCKTIME to 1 or 0 yields good performances.
# The following commands show a common settings with Intel OpenMP Runtime Library.

# export KMP_AFFINITY=granularity=fine,compact,1,0
# export KMP_BLOCKTIME=1

###############################################################################
# Switch Memory allocator
# ~~~~~~~~~~~~~~~~~~~~~~~
# For deep learning workloads, Jemalloc or TCMalloc can get better performance by reusing memory as much as possible than default malloc funtion. `Jemalloc <https://github.com/jemalloc/jemalloc>`_ is a general purpose malloc implementation that emphasizes fragmentation avoidance and scalable concurrency support. `TCMalloc <https://google.github.io/tcmalloc/overview.html>`_ also features a couple of optimizations to speed up program executions. One of them is holding memory in caches to speed up access of commonly-used objects. Holding such caches even after deallocation also helps avoid costly system calls if such memory is later re-allocated.
# Use environment variable LD_PRELOAD to take advantage of one of them.

# export LD_PRELOAD=<jemalloc.so/tcmalloc.so>:$LD_PRELOAD

###############################################################################
# Train a model on CPU with PyTorch DistributedDataParallel(DDP) functionality
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# For small scale models or memory-bound models, such as DLRM, training on CPU is also a good choice. On a machine with multiple sockets, distributed training brings a high-efficient hardware resource usage to accelerate the training process. `Torch-ccl <https://github.com/intel/torch-ccl>`_, optimized with Intel(R) oneCCL (collective commnications library) for efficient distributed deep learning training implementing such collectives like allreduce, allgather, alltoall, implements PyTorch C10D ProcessGroup API and can be dynamically loaded as external ProcessGroup. Upon optimizations implemented in PyTorch DDP moduel, torhc-ccl accelerates communication operations. Beside the optimizations made to communication kernels, torch-ccl also features simultaneous computation-communication functionality.

###############################################################################
# GPU specific optimizations
# --------------------------
Expand Down