import os
import threading
import time
from functools import wraps

import torch
import torch.nn as nn
import torch.distributed.autograd as dist_autograd
import torch.distributed.rpc as rpc
import torch.multiprocessing as mp
import torch.optim as optim
from torch.distributed.optim import DistributedOptimizer
from torch.distributed.rpc import RRef

from torchvision.models.resnet import Bottleneck


#########################################################
#                   helper functions                    #
#########################################################

def _call_method(method, rref, *args, **kwargs):
    r"""
    A helper function to call a method on the given RRef.
    """
    return method(rref.local_value(), *args, **kwargs)


def _remote_on_rref(method, rref, *args, **kwargs):
    r"""
    A helper function to run a method on the owner of the given RRef and return
    an RRef to the result.
    """
    return rpc.remote(
        rref.owner(),
        _call_method,
        args=[method, rref] + list(args),
        kwargs=kwargs
    )


def _async_on_rref(method, rref, *args, **kwargs):
    r"""
    A helper function to run a method on the owner of the given RRef and return
    a Future to the result.
    """
    return rpc.rpc_async(
        rref.owner(),
        _call_method,
        args=[method, rref] + list(args),
        kwargs=kwargs
    )


def _parameter_rrefs(module):
    r"""
    Create one RRef for each parameter in the given local module, and return a
    list of RRefs.
    """
    param_rrefs = []
    for param in module.parameters():
        param_rrefs.append(RRef(param))
    return param_rrefs

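
# The sketch below is illustrative only and is never called by this example. It
# shows how the two RPC helpers above differ, assuming ``module_rref`` is an
# RRef to an nn.Module owned by a remote worker and ``x`` is an input tensor.
def _helper_usage_example(module_rref, x):
    # rpc.remote returns an RRef to the result; fetch it with to_here()
    out_rref = _remote_on_rref(nn.Module.__call__, module_rref, x)
    out_from_rref = out_rref.to_here()

    # rpc.rpc_async returns a Future to the result; block on it with wait()
    fut = _async_on_rref(nn.Module.__call__, module_rref, x)
    out_from_future = fut.wait()

    return out_from_rref, out_from_future
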
#########################################################
#           Define Model Parallel ResNet50              #
#########################################################


num_classes = 1000


def conv1x1(in_planes, out_planes, stride=1):
    """1x1 convolution"""
    return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)


class ResNetBase(nn.Module):
    def __init__(self, block, inplanes, num_classes=1000,
                 groups=1, width_per_group=64, norm_layer=None):
        super(ResNetBase, self).__init__()

        self._lock = threading.Lock()
        self._block = block
        self._norm_layer = nn.BatchNorm2d
        self.inplanes = inplanes
        self.dilation = 1
        self.groups = groups
        self.base_width = width_per_group

    def _make_layer(self, planes, blocks, stride=1, dilate=False):
        norm_layer = self._norm_layer
        downsample = None
        previous_dilation = self.dilation
        if dilate:
            self.dilation *= stride
            stride = 1
        if stride != 1 or self.inplanes != planes * self._block.expansion:
            downsample = nn.Sequential(
                conv1x1(self.inplanes, planes * self._block.expansion, stride),
                norm_layer(planes * self._block.expansion),
            )

        layers = []
        layers.append(self._block(self.inplanes, planes, stride, downsample, self.groups,
                                  self.base_width, previous_dilation, norm_layer))
        self.inplanes = planes * self._block.expansion
        for _ in range(1, blocks):
            layers.append(self._block(self.inplanes, planes, groups=self.groups,
                                      base_width=self.base_width, dilation=self.dilation,
                                      norm_layer=norm_layer))

        return nn.Sequential(*layers)

class ResNetPart1(ResNetBase):
    """
    The first part of ResNet.
    """
    def __init__(self, device, *args, **kwargs):
        super(ResNetPart1, self).__init__(
            Bottleneck, 64, num_classes=num_classes, *args, **kwargs)

        self.device = device
        self.seq = nn.Sequential(
            nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3, bias=False),
            self._norm_layer(self.inplanes),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
            self._make_layer(64, 3),
            self._make_layer(128, 4, stride=2)
        ).to(self.device)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def forward(self, x_rref):
        x = x_rref.to_here().to(self.device)
        with self._lock:
            out = self.seq(x)
        return out.cpu()


class ResNetPart2(ResNetBase):
    """
    The second part of ResNet.
    """
    def __init__(self, device, *args, **kwargs):
        super(ResNetPart2, self).__init__(
            Bottleneck, 512, num_classes=num_classes, *args, **kwargs)

        self.device = device
        self.seq = nn.Sequential(
            self._make_layer(256, 6, stride=2),
            self._make_layer(512, 3, stride=2),
            nn.AdaptiveAvgPool2d((1, 1)),
        ).to(self.device)

        self.fc = nn.Linear(512 * self._block.expansion, num_classes).to(self.device)

    def forward(self, x_rref):
        x = x_rref.to_here().to(self.device)
        with self._lock:
            out = self.fc(torch.flatten(self.seq(x), 1))
        return out.cpu()

class DistResNet50(nn.Module):
    """
    Assemble the two parts as an nn.Module and define the pipelining logic.
    """
    def __init__(self, split_size, workers, *args, **kwargs):
        super(DistResNet50, self).__init__()

        self.split_size = split_size

        # Put the first part of the ResNet50 on workers[0]
        self.p1_rref = rpc.remote(
            workers[0],
            ResNetPart1,
            args=("cuda:0",) + args,
            kwargs=kwargs
        )

        # Put the second part of the ResNet50 on workers[1]
        self.p2_rref = rpc.remote(
            workers[1],
            ResNetPart2,
            args=("cuda:1",) + args,
            kwargs=kwargs
        )

    def forward(self, xs):
        # Split the input batch xs into micro-batches, and collect async RPC
        # futures into a list
        out_futures = []
        for x in iter(xs.split(self.split_size, dim=0)):
            x_rref = RRef(x)
            y_rref = _remote_on_rref(ResNetPart1.forward, self.p1_rref, x_rref)
            z_fut = _async_on_rref(ResNetPart2.forward, self.p2_rref, y_rref)
            out_futures.append(z_fut)

        # wait for all RPCs to finish
        outs = [fut.wait() for fut in out_futures]
        # concatenate all output tensors into one tensor
        out = torch.cat(outs)
        return out

    def parameter_rrefs(self):
        remote_params = []
        remote_params.extend(_remote_on_rref(_parameter_rrefs, self.p1_rref).to_here())
        remote_params.extend(_remote_on_rref(_parameter_rrefs, self.p2_rref).to_here())
        return remote_params

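
# Note on pipelining (descriptive only): for each micro-batch, the rpc.remote
# call to ResNetPart1.forward returns immediately with an RRef, and the
# rpc.rpc_async call to ResNetPart2.forward returns a Future, so the loop in
# DistResNet50.forward can dispatch the next micro-batch while earlier ones are
# still in flight. This lets worker1 run shard 1 on micro-batch i+1 while
# worker2 runs shard 2 on micro-batch i. The per-shard threading.Lock serializes
# the forward passes within a single shard, since the RPC framework may service
# multiple requests on concurrent threads.
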
#########################################################
#                   Run RPC Processes                   #
#########################################################

num_batches = 3
batch_size = 120
image_w = 128
image_h = 128


def run_master(split_size):
    model = DistResNet50(split_size, ["worker1", "worker2"])
    loss_fn = nn.MSELoss()
    opt = DistributedOptimizer(
        optim.SGD,
        model.parameter_rrefs(),
        lr=0.05,
    )

    one_hot_indices = torch.LongTensor(batch_size) \
                           .random_(0, num_classes) \
                           .view(batch_size, 1)

    for i in range(num_batches):
        print(f"Processing batch {i}")
        # generate random inputs and labels
        inputs = torch.randn(batch_size, 3, image_w, image_h)
        labels = torch.zeros(batch_size, num_classes) \
                      .scatter_(1, one_hot_indices, 1)

        with dist_autograd.context() as context_id:
            outputs = model(inputs)
            dist_autograd.backward(context_id, [loss_fn(outputs, labels)])
            opt.step(context_id)

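
# Note (descriptive only): dist_autograd.backward() stores the resulting
# gradients in the distributed autograd context rather than in param.grad, and
# opt.step(context_id) applies them on the workers that own the parameters.
# For debugging, the gradients accumulated on the local worker could be
# inspected inside the context with dist_autograd.get_gradients(context_id)
# (not done in this example).
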
def run_worker(rank, world_size, split_size):
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '29500'
    options = rpc.ProcessGroupRpcBackendOptions(num_send_recv_threads=256)

    if rank == 0:
        rpc.init_rpc(
            "master",
            rank=rank,
            world_size=world_size,
            rpc_backend_options=options
        )
        run_master(split_size)
    else:
        rpc.init_rpc(
            f"worker{rank}",
            rank=rank,
            world_size=world_size,
            rpc_backend_options=options
        )
        # worker processes passively wait for commands from the master
        pass

    # block until all rpcs finish
    rpc.shutdown()

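
# The entry point below spawns world_size processes on a single host: rank 0
# becomes the master that drives training, while ranks 1 and 2 host the two
# model shards on cuda:0 and cuda:1 respectively (as placed by DistResNet50).
# Running the processes on separate machines would mainly require pointing
# MASTER_ADDR / MASTER_PORT at an address reachable by all workers, which this
# example does not show.
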
if __name__ == "__main__":
    world_size = 3
    for split_size in [1, 2, 4, 8]:
        tik = time.time()
        mp.spawn(run_worker, args=(world_size, split_size), nprocs=world_size, join=True)
        tok = time.time()
        print(f"split size = {split_size}, execution time = {tok - tik}")