|
| 1 | +Dynamic Parallelism in TorchScript |
| 2 | +================================== |
| 3 | + |
| 4 | +In this tutorial, we introduce the syntax for doing *dynamic inter-op parallelism* |
| 5 | +in TorchScript. This parallelism has the following properties: |
| 6 | + |
| 7 | +* dynamic - The number of parallel tasks created and their workload can depend on the control flow of the program. |
| 8 | +* inter-op - The parallelism is concerned with running TorchScript program fragments in parallel. This is distinct from *intra-op parallelism*, which is concerned with splitting up individual operators and running subsets of the operator's work in parallel. |
| 9 | +Basic Syntax |
| 10 | +------------ |
| 11 | + |
| 12 | +The two important APIs for dynamic parallelism are: |
| 13 | + |
| 14 | +* ``torch.jit.fork(fn : Callable[..., T], *args, **kwargs) -> torch.jit.Future[T]`` |
| 15 | +* ``torch.jit.wait(fut : torch.jit.Future[T]) -> T`` |
| 16 | + |
| 17 | +A good way to demonstrate how these work is by way of an example: |
| 18 | + |
| 19 | +.. code-block:: python |
| 20 | +
|
| 21 | + import torch |
| 22 | +
|
| 23 | + def foo(x): |
| 24 | + return torch.neg(x) |
| 25 | +
|
| 26 | + @torch.jit.script |
| 27 | + def example(x): |
| 28 | + # Call `foo` using parallelism: |
| 29 | + # First, we "fork" off a task. This task will run `foo` with argument `x` |
| 30 | + future = torch.jit.fork(foo, x) |
| 31 | +
|
| 32 | + # Call `foo` normally |
| 33 | + x_normal = foo(x) |
| 34 | +
|
| 35 | + # Second, we "wait" on the task. Since the task may be running in |
| 36 | + # parallel, we have to "wait" for its result to become available. |
| 37 | + # Notice that by having lines of code between the "fork()" and "wait()" |
| 38 | + # call for a given Future, we can overlap computations so that they |
| 39 | + # run in parallel. |
| 40 | + x_parallel = torch.jit.wait(future) |
| 41 | +
|
| 42 | + return x_normal, x_parallel |
| 43 | +
|
| 44 | + print(example(torch.ones(1))) # (-1., -1.) |
| 45 | +
|
| 46 | +
|
| 47 | +``fork()`` takes the callable ``fn`` and arguments to that callable ``args`` |
| 48 | +and ``kwargs`` and creates an asynchronous task for the execution of ``fn``. |
| 49 | +``fn`` can be a function, method, or Module instance. ``fork()`` returns a |
| 50 | +reference to the value of the result of this execution, called a ``Future``. |
| 51 | +Because ``fork`` returns immediately after creating the async task, ``fn`` may |
| 52 | +not have been executed by the time the line of code after the ``fork()`` call |
| 53 | +is executed. Thus, ``wait()`` is used to wait for the async task to complete |
| 54 | +and return the value. |
| 55 | + |
| 56 | +These constructs can be used to overlap the execution of statements within a |
| 57 | +function (shown in the worked example section) or be composed with other language |
| 58 | +constructs like loops: |
| 59 | + |
| 60 | +.. code-block:: python |
| 61 | +
|
| 62 | + import torch |
| 63 | + from typing import List |
| 64 | +
|
| 65 | + def foo(x): |
| 66 | + return torch.neg(x) |
| 67 | +
|
| 68 | + @torch.jit.script |
| 69 | + def example(x): |
| 70 | + futures : List[torch.jit.Future[torch.Tensor]] = [] |
| 71 | + for _ in range(100): |
| 72 | + futures.append(torch.jit.fork(foo, x)) |
| 73 | +
|
| 74 | + results = [] |
| 75 | + for future in futures: |
| 76 | + results.append(torch.jit.wait(future)) |
| 77 | +
|
| 78 | + return torch.sum(torch.stack(results)) |
| 79 | +
|
| 80 | + print(example(torch.ones([]))) |
| 81 | +
|
| 82 | +.. note:: |
| 83 | + |
| 84 | + When we initialized an empty list of Futures, we needed to add an explicit |
| 85 | + type annotation to ``futures``. In TorchScript, empty containers default |
| 86 | + to assuming they contain Tensor values, so we annotate the list constructor |
| 87 | + # as being of type ``List[torch.jit.Future[torch.Tensor]]`` |
| 88 | + |
| 89 | +This example uses ``fork()`` to launch 100 instances of the function ``foo``, |
| 90 | +waits on the 100 tasks to complete, then sums the results, returning ``-100.0``. |
| 91 | + |
| 92 | +Applied Example: Ensemble of Bidirectional LSTMs |
| 93 | +------------------------------------------------ |
| 94 | + |
| 95 | +Let's try to apply parallelism to a more realistic example and see what sort |
| 96 | +of performance we can get out of it. First, let's define the baseline model: an |
| 97 | +ensemble of bidirectional LSTM layers. |
| 98 | + |
| 99 | +.. code-block:: python |
| 100 | +
|
| 101 | + import torch, time |
| 102 | +
|
| 103 | + # In RNN parlance, the dimensions we care about are: |
| 104 | + # # of time-steps (T) |
| 105 | + # Batch size (B) |
| 106 | + # Hidden size/number of "channels" (C) |
| 107 | + T, B, C = 50, 50, 1024 |
| 108 | +
|
| 109 | + # A module that defines a single "bidirectional LSTM". This is simply two |
| 110 | + # LSTMs applied to the same sequence, but one in reverse |
| 111 | + class BidirectionalRecurrentLSTM(torch.nn.Module): |
| 112 | + def __init__(self): |
| 113 | + super().__init__() |
| 114 | + self.cell_f = torch.nn.LSTM(input_size=C, hidden_size=C) |
| 115 | + self.cell_b = torch.nn.LSTM(input_size=C, hidden_size=C) |
| 116 | +
|
| 117 | + def forward(self, x : torch.Tensor) -> torch.Tensor: |
| 118 | + # Forward layer |
| 119 | + output_f, _ = self.cell_f(x) |
| 120 | +
|
| 121 | + # Backward layer. Flip input in the time dimension (dim 0), apply the |
| 122 | + # layer, then flip the outputs in the time dimension |
| 123 | + x_rev = torch.flip(x, dims=[0]) |
| 124 | + output_b, _ = self.cell_b(torch.flip(x, dims=[0])) |
| 125 | + output_b_rev = torch.flip(output_b, dims=[0]) |
| 126 | +
|
| 127 | + return torch.cat((output_f, output_b_rev), dim=2) |
| 128 | +
|
| 129 | +
|
| 130 | + # An "ensemble" of `BidirectionalRecurrentLSTM` modules. The modules in the |
| 131 | + # ensemble are run one-by-one on the same input then their results are |
| 132 | + # stacked and summed together, returning the combined result. |
| 133 | + class LSTMEnsemble(torch.nn.Module): |
| 134 | + def __init__(self, n_models): |
| 135 | + super().__init__() |
| 136 | + self.n_models = n_models |
| 137 | + self.models = torch.nn.ModuleList([ |
| 138 | + BidirectionalRecurrentLSTM() for _ in range(self.n_models)]) |
| 139 | +
|
| 140 | + def forward(self, x : torch.Tensor) -> torch.Tensor: |
| 141 | + results = [] |
| 142 | + for model in self.models: |
| 143 | + results.append(model(x)) |
| 144 | + return torch.stack(results).sum(dim=0) |
| 145 | +
|
| 146 | + # For a head-to-head comparison to what we're going to do with fork/wait, let's |
| 147 | + # instantiate the model and compile it with TorchScript |
| 148 | + ens = torch.jit.script(LSTMEnsemble(n_models=4)) |
| 149 | +
|
| 150 | + # Normally you would pull this input out of an embedding table, but for the |
| 151 | + # purpose of this demo let's just use random data. |
| 152 | + x = torch.rand(T, B, C) |
| 153 | +
|
| 154 | + # Let's run the model once to warm up things like the memory allocator |
| 155 | + ens(x) |
| 156 | +
|
| 157 | + x = torch.rand(T, B, C) |
| 158 | +
|
| 159 | + # Let's see how fast it runs! |
| 160 | + s = time.time() |
| 161 | + ens(x) |
| 162 | + print('Inference took', time.time() - s, ' seconds') |
| 163 | +
|
| 164 | +On my machine, this network runs in ``2.05`` seconds. We can do a lot better! |
| 165 | + |
| 166 | +Parallelizing Forward and Backward Layers |
| 167 | +----------------------------------------- |
| 168 | + |
| 169 | +A very simple thing we can do is parallelize the forward and backward layers |
| 170 | +within ``BidirectionalRecurrentLSTM``. For this, the structure of the computation |
| 171 | +is static, so we don't actually even need any loops. Let's rewrite the ``forward`` |
| 172 | +method of ``BidirectionalRecurrentLSTM`` like so: |
| 173 | + |
| 174 | +.. code-block:: python |
| 175 | +
|
| 176 | + def forward(self, x : torch.Tensor) -> torch.Tensor: |
| 177 | + # Forward layer - fork() so this can run in parallel to the backward |
| 178 | + # layer |
| 179 | + future_f = torch.jit.fork(self.cell_f, x) |
| 180 | +
|
| 181 | + # Backward layer. Flip input in the time dimension (dim 0), apply the |
| 182 | + # layer, then flip the outputs in the time dimension |
| 183 | + x_rev = torch.flip(x, dims=[0]) |
| 184 | + output_b, _ = self.cell_b(torch.flip(x, dims=[0])) |
| 185 | + output_b_rev = torch.flip(output_b, dims=[0]) |
| 186 | +
|
| 187 | + # Retrieve the output from the forward layer. Note this needs to happen |
| 188 | + # *after* the stuff we want to parallelize with |
| 189 | + output_f, _ = torch.jit.wait(future_f) |
| 190 | +
|
| 191 | + return torch.cat((output_f, output_b_rev), dim=2) |
| 192 | +
|
| 193 | +In this example, ``forward()`` delegates execution of ``cell_f`` to another thread, |
| 194 | +while it continues to execute ``cell_b``. This causes the execution of both the |
| 195 | +cells to be overlapped with each other. |
| 196 | + |
| 197 | +Running the script again with this simple modification yields a runtime of |
| 198 | +``1.71`` seconds for an improvement of ``17%``! |
| 199 | + |
| 200 | +Aside: Visualizing Parallelism |
| 201 | +------------------------------ |
| 202 | + |
| 203 | +We're not done optimizing our model but it's worth introducing the tooling we |
| 204 | +have for visualizing performance. One important tool is the `PyTorch profiler <https://pytorch.org/docs/stable/autograd.html#profiler>`_. |
| 205 | + |
| 206 | +Let's use the profiler along with the Chrome trace export functionality to |
| 207 | +visualize the performance of our parallelized model: |
| 208 | + |
| 209 | +.. code-block:: python |
| 210 | + with torch.autograd.profiler.profile() as prof: |
| 211 | + ens(x) |
| 212 | + prof.export_chrome_trace('parallel.json') |
| 213 | +
|
| 214 | +This snippet of code will write out a file named ``parallel.json``. If you |
| 215 | +navigate Google Chrome to ``chrome://tracing``, click the ``Load`` button, and |
| 216 | +load in that JSON file, you should see a timeline like the following: |
| 217 | + |
| 218 | +.. image:: https://i.imgur.com/rm5hdG9.png |
| 219 | + |
| 220 | +The horizontal axis of the timeline represents time and the vertical axis |
| 221 | +represents threads of execution. As we can see, we are running two ``lstm`` |
| 222 | +instances at a time. This is the result of our hard work parallelizing the |
| 223 | +bidirectional layers! |
| 224 | + |
| 225 | +Parallelizing Models in the Ensemble |
| 226 | +------------------------------------ |
| 227 | + |
| 228 | +You may have noticed that there is a further parallelization opportunity in our |
| 229 | +code: we can also run the models contained in ``LSTMEnsemble`` in parallel with |
| 230 | +each other. The way to do that is simple enough, this is how we should change |
| 231 | +the ``forward`` method of ``LSTMEnsemble``: |
| 232 | + |
| 233 | +.. code-block:: python |
| 234 | +
|
| 235 | + def forward(self, x : torch.Tensor) -> torch.Tensor: |
| 236 | + # Launch tasks for each model |
| 237 | + futures : List[torch.jit.Future[torch.Tensor]] = [] |
| 238 | + for model in self.models: |
| 239 | + futures.append(torch.jit.fork(model, x)) |
| 240 | +
|
| 241 | + # Collect the results from the launched tasks |
| 242 | + results : List[torch.Tensor] = [] |
| 243 | + for future in futures: |
| 244 | + results.append(torch.jit.wait(future)) |
| 245 | +
|
| 246 | + return torch.stack(results).sum(dim=0) |
| 247 | +
|
| 248 | +Or, if you value brevity, we can use list comprehensions: |
| 249 | + |
| 250 | +.. code-block:: python |
| 251 | +
|
| 252 | + def forward(self, x : torch.Tensor) -> torch.Tensor: |
| 253 | + futures = [torch.jit.fork(model, x) for model in self.models] |
| 254 | + results = [torch.jit.wait(fut) for fut in futures] |
| 255 | + return torch.stack(results).sum(dim=0) |
| 256 | +
|
| 257 | +Like described in the intro, we've used loops to fork off tasks for each of the |
| 258 | +models in our ensemble. We've then used another loop to wait for all of the |
| 259 | +tasks to be completed. This provides even more overlap of computation. |
| 260 | + |
| 261 | +With this small update, the script runs in ``1.4`` seconds, for a total speedup |
| 262 | +of ``32%``! Pretty good for two lines of code. |
| 263 | + |
| 264 | +We can also use the Chrome tracer again to see where's going on: |
| 265 | + |
| 266 | +.. image:: https://i.imgur.com/kA0gyQm.png |
| 267 | + |
| 268 | +We can now see that all ``LSTM`` instances are being run fully in parallel. |
| 269 | + |
| 270 | +Conclusion |
| 271 | +---------- |
| 272 | + |
| 273 | +In this tutorial, we learned about ``fork()`` and ``wait()``, the basic APIs |
| 274 | +for doing dynamic, inter-op parallelism in TorchScript. We saw a few typical |
| 275 | +usage patterns for using these functions to parallelize the execution of |
| 276 | +functions, methods, or ``Modules`` in TorchScript code. Finally, we worked through |
| 277 | +an example of optimizing a model using this technique and explored the performance |
| 278 | +measurement and visualization tooling available in PyTorch. |
0 commit comments