 from torchvision.models.resnet import Bottleneck


-#########################################################
-#                   helper functions                    #
-#########################################################
-
-
-def _call_method(method, rref, *args, **kwargs):
-    r"""
-    a helper function to call a method on the given RRef
-    """
-    return method(rref.local_value(), *args, **kwargs)
-
-
-def _remote_on_rref(method, rref, *args, **kwargs):
-    r"""
-    a helper function to run method on the owner of rref and return an RRef
-    of the result.
-    """
-    return rpc.remote(
-        rref.owner(),
-        _call_method,
-        args=[method, rref] + list(args),
-        kwargs=kwargs
-    )
-
-
-def _async_on_rref(method, rref, *args, **kwargs):
-    r"""
-    a helper function to run method on the owner of rref and fetch back the
-    result using RPC
-    """
-    return rpc.rpc_async(
-        rref.owner(),
-        _call_method,
-        args=[method, rref] + list(args),
-        kwargs=kwargs
-    )
-
-
-def _parameter_rrefs(module):
-    r"""
-    Create one RRef for each parameter in the given local module, and return a
-    list of RRefs.
-    """
-    param_rrefs = []
-    for param in module.parameters():
-        param_rrefs.append(RRef(param))
-    return param_rrefs
-
-
 #########################################################
 #           Define Model Parallel ResNet50              #
 #########################################################

+# In order to split the ResNet50 and place it on two different workers, we
+# implement it in two model shards. The ResNetBase class defines common
+# attributes and methods shared by two shards. ResNetShard1 and ResNetShard2
+# contain two partitions of the model layers respectively.
+

 num_classes = 1000

@@ -76,9 +32,8 @@ def conv1x1(in_planes, out_planes, stride=1):
7632 """1x1 convolution"""
7733 return nn .Conv2d (in_planes , out_planes , kernel_size = 1 , stride = stride , bias = False )
7834
79-
8035class ResNetBase (nn .Module ):
81- def __init__ (self , block , inplanes , num_classes = 1000 ,
36+ def __init__ (self , block , inplanes , num_classes = 1000 ,
8237 groups = 1 , width_per_group = 64 , norm_layer = None ):
8338 super (ResNetBase , self ).__init__ ()
8439
@@ -111,13 +66,20 @@ def _make_layer(self, planes, blocks, stride=1):

         return nn.Sequential(*layers)

+    def parameter_rrefs(self):
+        r"""
+        Create one RRef for each parameter in the given local module, and return a
+        list of RRefs.
+        """
+        return [RRef(p) for p in self.parameters()]
+

-class ResNetPart1(ResNetBase):
+class ResNetShard1(ResNetBase):
     """
     The first part of ResNet.
     """
     def __init__(self, device, *args, **kwargs):
-        super(ResNetPart1, self).__init__(
+        super(ResNetShard1, self).__init__(
             Bottleneck, 64, num_classes=num_classes, *args, **kwargs)

         self.device = device
@@ -144,12 +106,12 @@ def forward(self, x_rref):
         return out.cpu()


-class ResNetPart2(ResNetBase):
+class ResNetShard2(ResNetBase):
     """
     The second part of ResNet.
     """
     def __init__(self, device, *args, **kwargs):
-        super(ResNetPart2, self).__init__(
+        super(ResNetShard2, self).__init__(
             Bottleneck, 512, num_classes=num_classes, *args, **kwargs)

         self.device = device
@@ -180,15 +142,15 @@ def __init__(self, split_size, workers, *args, **kwargs):
         # Put the first part of the ResNet50 on workers[0]
         self.p1_rref = rpc.remote(
             workers[0],
-            ResNetPart1,
+            ResNetShard1,
             args=("cuda:0",) + args,
             kwargs=kwargs
         )

         # Put the second part of the ResNet50 on workers[1]
         self.p2_rref = rpc.remote(
             workers[1],
-            ResNetPart2,
+            ResNetShard2,
             args=("cuda:1",) + args,
             kwargs=kwargs
         )
@@ -199,22 +161,19 @@ def forward(self, xs):
         out_futures = []
         for x in iter(xs.split(self.split_size, dim=0)):
             x_rref = RRef(x)
-            y_rref = _remote_on_rref(ResNetPart1.forward, self.p1_rref, x_rref)
-            z_fut = _async_on_rref(ResNetPart2.forward, self.p2_rref, y_rref)
+            y_rref = self.p1_rref.remote().forward(x_rref)
+            z_fut = self.p2_rref.rpc_async().forward(y_rref)
             out_futures.append(z_fut)

-        # wait for all RPC to finish
-        outs = [fut.wait() for fut in out_futures]
-        # cat all tensors into one tensor.
-        out = torch.cat(outs)
-        return out
-
+        # collect and cat all output tensors into one tensor.
+        return torch.cat(torch.futures.wait_all(out_futures))
+
     def parameter_rrefs(self):
         remote_params = []
-        remote_params.extend(_remote_on_rref(_parameter_rrefs, self.p1_rref).to_here())
-        remote_params.extend(_remote_on_rref(_parameter_rrefs, self.p2_rref).to_here())
+        remote_params.extend(self.p1_rref.remote().parameter_rrefs().to_here())
+        remote_params.extend(self.p2_rref.remote().parameter_rrefs().to_here())
         return remote_params
-
+

 #########################################################
 #                   Run RPC Processes                   #
@@ -248,6 +207,9 @@ def run_master(split_size):
         labels = torch.zeros(batch_size, num_classes) \
                       .scatter_(1, one_hot_indices, 1)

+        # The distributed autograd context is the dedicated scope for the
+        # distributed backward pass to store gradients, which can later be
+        # retrieved using the context_id by the distributed optimizer.
         with dist_autograd.context() as context_id:
             outputs = model(inputs)
             dist_autograd.backward(context_id, [loss_fn(outputs, labels)])
@@ -261,17 +223,17 @@ def run_worker(rank, world_size, num_split):

     if rank == 0:
         rpc.init_rpc(
-            "master",
-            rank=rank,
-            world_size=world_size,
+            "master",
+            rank=rank,
+            world_size=world_size,
             rpc_backend_options=options
         )
         run_master(num_split)
     else:
         rpc.init_rpc(
-            f"worker{rank}",
-            rank=rank,
-            world_size=world_size,
+            f"worker{rank}",
+            rank=rank,
+            world_size=world_size,
             rpc_backend_options=options
         )
         pass
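The core of this change shows up in DistResNet50.forward and parameter_rrefs above: the hand-written _call_method / _remote_on_rref / _async_on_rref helpers are dropped in favor of the builtin RRef proxies, where rref.remote().some_method(...) returns an RRef of the result and rref.rpc_async().some_method(...) returns a Future. The snippet below is a minimal, self-contained sketch of those proxies outside this model, not part of the commit; the worker names, the toy nn.Linear module, and the port are illustrative assumptions.

# Sketch only: demonstrates RRef.remote() and RRef.rpc_async() proxies with a
# hypothetical two-process setup and a toy nn.Linear owned by "worker1".
import os

import torch
import torch.distributed.rpc as rpc
import torch.multiprocessing as mp
import torch.nn as nn


def run(rank, world_size):
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "29500"  # assumed free port
    rpc.init_rpc(f"worker{rank}", rank=rank, world_size=world_size)

    if rank == 0:
        # Construct an nn.Linear remotely on worker1 and keep an RRef to it.
        layer_rref = rpc.remote("worker1", nn.Linear, args=(4, 2))

        # remote(): run forward on the owner, get back an RRef of the output.
        out_rref = layer_rref.remote().forward(torch.ones(3, 4))

        # rpc_async(): run forward on the owner, get back a Future of the output.
        out_fut = layer_rref.rpc_async().forward(torch.ones(3, 4))

        print(out_rref.to_here().shape)                    # torch.Size([3, 2])
        print(torch.futures.wait_all([out_fut])[0].shape)  # torch.Size([3, 2])

    rpc.shutdown()


if __name__ == "__main__":
    mp.spawn(run, args=(2,), nprocs=2, join=True)

Both proxies dispatch the call to the RRef's owner, which is why the new forward above no longer has to name the shard class explicitly the way the removed helpers did with ResNetPart1.forward.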
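The new comment in run_master says the gradients stored under a context can later be retrieved through the context_id by the distributed optimizer. The rest of run_master sits outside this hunk, so the following is only a hedged sketch of how that wiring typically fits together for this model: DistributedOptimizer is built from the parameter RRefs collected by parameter_rrefs(), dist_autograd.backward() records gradients under the active context on every participating worker, and step(context_id) applies exactly those gradients. The loss function, learning rate, and the train_step wrapper are illustrative assumptions, not code from this commit.

# Sketch only: ties dist_autograd's context_id to the distributed optimizer,
# assuming a model with parameter_rrefs() such as DistResNet50 above.
import torch.distributed.autograd as dist_autograd
import torch.nn as nn
import torch.optim as optim
from torch.distributed.optim import DistributedOptimizer


def train_step(model, inputs, labels):
    loss_fn = nn.MSELoss()  # illustrative choice of loss

    # The distributed optimizer holds RRefs to parameters living on the
    # remote shards, collected via model.parameter_rrefs().
    opt = DistributedOptimizer(
        optim.SGD,
        model.parameter_rrefs(),
        lr=0.05,
    )

    with dist_autograd.context() as context_id:
        outputs = model(inputs)
        # Gradients are accumulated under this context on every involved
        # worker, not in param.grad as in local autograd.
        dist_autograd.backward(context_id, [loss_fn(outputs, labels)])
        # step() contacts the parameter owners and applies the gradients
        # recorded under this context_id.
        opt.step(context_id)

If the optimizer is reused across iterations it would normally be constructed once outside the per-batch loop; it is created inside train_step here only to keep the sketch self-contained.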