Skip to content

Commit cde83b9

Browse files
committed
Add image
1 parent 1d63510 commit cde83b9

File tree

3 files changed

+55
-6
lines changed

3 files changed

+55
-6
lines changed

_static/img/rpc_trace_img.png

307 KB
Loading

recipes_source/distributed_rpc_profiling.rst

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ two RPC workers on the same host, named "worker0" and "worker1" respectively. Th
4444
be spawned as subprocesses, and we set some environment variables required for proper
4545
initialization (see torch.distributed documentation for more details).
4646

47-
.. code:: python
47+
.. code:: python3
4848
import torch
4949
import torch.distributed.rpc as rpc
5050
import torch.autograd.profiler as profiler
@@ -95,7 +95,7 @@ Now that we have a skeleton setup of our RPC framework, we can move on to
9595
sending RPCs back and forth and using the profiler to obtain a view of what's
9696
happening under the hood. Let's add to the above "worker" function:
9797

98-
..code:: python
98+
.. code:: python3
9999
def worker(rank, world_size):
100100
# Above code omitted...
101101
if rank == 0:
@@ -152,7 +152,7 @@ call are prefixed with ::rpc_async#aten::mul(worker0 -> worker1).
152152
We can also use the profiler to gain insight into user-defined functions that are executed over RPC.
153153
For example, let's add the following to the above "worker" function:
154154

155-
..code:: python
155+
.. code:: python3
156156
# Define somewhere outside of worker() func.
157157
def udf_with_ops():
158158
import time
@@ -161,7 +161,7 @@ For example, let's add the following to the above "worker" function:
161161
torch.add(t1, t2)
162162
torch.mul(t1, t2)
163163

164-
..code::python
164+
.. code:: python3
165165
def worker(rank, world_size):
166166
# Above code omitted
167167
with profiler.profile() as p:
@@ -197,7 +197,7 @@ remote operators that have been executed on worker 1 as part of executing this R
197197
Lastly, we can visualize remote execution using the tracing functionality provided by the profiler.
198198
Let's add the following code to the above "worker" function:
199199

200-
..code:: python
200+
.. code:: python3
201201
def worker(rank, world_size):
202202
# Above code omitted
203203
# Will generate a trace for the above profiling output
@@ -215,7 +215,7 @@ in this case, given in the trace column for "node_id: 1".
215215

216216
Putting it all together, we have the following code for this recipe:
217217

218-
..code:: python
218+
.. code:: python3
219219
import torch
220220
import torch.distributed.rpc as rpc
221221
import torch.autograd.profiler as profiler

recipes_source/prof_test.py

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
import torch
2+
import torch.distributed.rpc as rpc
3+
import torch.autograd.profiler as profiler
4+
import torch.multiprocessing as mp
5+
import os
6+
import logging
7+
import sys
8+
9+
# Emit DEBUG-level log records to stdout so output from every spawned
# subprocess is visible in the parent terminal.
logging.basicConfig(stream=sys.stdout, level=logging.DEBUG)
logger = logging.getLogger()
11+
12+
def random_tensor():
    """Return a fresh 3x3 tensor of uniform random values with autograd enabled."""
    shape = (3, 3)
    return torch.rand(shape, requires_grad=True)
14+
15+
16+
def worker(rank, world_size):
    """Entry point for one RPC worker process.

    Initializes the RPC framework for this rank; rank 0 additionally sends
    two async RPCs (torch.add, torch.mul) to its peer under the autograd
    profiler and prints the aggregated profiling table.

    Args:
        rank (int): index of this process, supplied by mp.spawn.
        world_size (int): total number of RPC workers.
    """
    # init_rpc uses these environment variables to rendezvous the workers.
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "29500"
    worker_name = f"worker{rank}"

    # Initialize RPC framework.
    rpc.init_rpc(
        name=worker_name,
        rank=rank,
        world_size=world_size,
    )
    logger.debug(f"{worker_name} successfully initialized RPC.")

    if rank == 0:
        dst_worker_rank = (rank + 1) % world_size
        dst_worker_name = f"worker{dst_worker_rank}"
        t1, t2 = random_tensor(), random_tensor()
        # Send and wait RPC completion under profiling scope.
        with profiler.profile() as p:
            fut1 = rpc.rpc_async(dst_worker_name, torch.add, args=(t1, t2))
            fut2 = rpc.rpc_async(dst_worker_name, torch.mul, args=(t1, t2))
            # RPCs must be awaited within profiling scope so the profiler
            # records the remote execution.
            fut1.wait()
            fut2.wait()

        print(p.key_averages().table())

    # Fix: every worker must call rpc.shutdown() before exiting. Without it,
    # a process can terminate while peers still hold outstanding RPCs,
    # causing hangs or errors at teardown; shutdown() blocks until all
    # workers have completed all RPCs.
    rpc.shutdown()
43+
44+
45+
46+
if __name__ == "__main__":
    # Launch one subprocess per RPC worker; mp.spawn passes each process
    # its rank as the first positional argument to worker().
    num_workers = 2
    mp.spawn(worker, args=(num_workers,), nprocs=num_workers)

0 commit comments

Comments
 (0)