import os
import threading
import time
from functools import wraps

import torch
import torch.nn as nn
import torch.distributed.autograd as dist_autograd
import torch.distributed.rpc as rpc
import torch.multiprocessing as mp
import torch.optim as optim
from torch.distributed.optim import DistributedOptimizer
from torch.distributed.rpc import RRef

from torchvision.models.resnet import Bottleneck


#########################################################
#                   helper functions                    #
#########################################################


def _call_method(method, rref, *args, **kwargs):
    r"""
    A helper function to call a method on the local value held by the given RRef.
    """
    return method(rref.local_value(), *args, **kwargs)


def _remote_on_rref(method, rref, *args, **kwargs):
    r"""
    A helper function to run a method on the owner of rref and return an RRef
    of the result.
    """
    return rpc.remote(
        rref.owner(),
        _call_method,
        args=[method, rref] + list(args),
        kwargs=kwargs
    )


def _async_on_rref(method, rref, *args, **kwargs):
    r"""
    A helper function to run a method on the owner of rref and fetch the
    result back asynchronously via RPC.
    """
    return rpc.rpc_async(
        rref.owner(),
        _call_method,
        args=[method, rref] + list(args),
        kwargs=kwargs
    )


def _parameter_rrefs(module):
    r"""
    Create one RRef for each parameter in the given local module, and return a
    list of RRefs.
    """
    param_rrefs = []
    for param in module.parameters():
        param_rrefs.append(RRef(param))
    return param_rrefs
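
# Note: these helpers implement a generic "run a method on the RRef's owner"
# pattern (newer PyTorch releases expose the same idea via RRef.rpc_async()
# and RRef.remote()). For instance, given an RRef ``net_rref`` owned by
# another worker,
#
#     _async_on_rref(SomeNet.forward, net_rref, x_rref)
#
# runs ``SomeNet.forward(net_rref.local_value(), x_rref)`` on the owner and
# returns a Future for the result. (``SomeNet``, ``net_rref``, and ``x_rref``
# are illustrative names only; see DistResNet50.forward below for the real
# usage in this script.)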


#########################################################
#            Define Model Parallel ResNet50             #
#########################################################


num_classes = 1000


def conv1x1(in_planes, out_planes, stride=1):
    """1x1 convolution"""
    return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)


class ResNetBase(nn.Module):
    def __init__(self, block, inplanes, num_classes=1000,
                 groups=1, width_per_group=64, norm_layer=None):
        super(ResNetBase, self).__init__()

        self._lock = threading.Lock()
        self._block = block
        self._norm_layer = nn.BatchNorm2d
        self.inplanes = inplanes
        self.dilation = 1
        self.groups = groups
        self.base_width = width_per_group

    def _make_layer(self, planes, blocks, stride=1):
        norm_layer = self._norm_layer
        downsample = None
        previous_dilation = self.dilation
        if stride != 1 or self.inplanes != planes * self._block.expansion:
            downsample = nn.Sequential(
                conv1x1(self.inplanes, planes * self._block.expansion, stride),
                norm_layer(planes * self._block.expansion),
            )

        layers = []
        layers.append(self._block(self.inplanes, planes, stride, downsample, self.groups,
                                  self.base_width, previous_dilation, norm_layer))
        self.inplanes = planes * self._block.expansion
        for _ in range(1, blocks):
            layers.append(self._block(self.inplanes, planes, groups=self.groups,
                                      base_width=self.base_width, dilation=self.dilation,
                                      norm_layer=norm_layer))

        return nn.Sequential(*layers)


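# The two shards below split torchvision's ResNet50 (Bottleneck layers
# [3, 4, 6, 3]) across two pipeline stages: ResNetPart1 holds the stem plus
# layer1/layer2, and ResNetPart2 holds layer3/layer4 plus the classifier head.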
class ResNetPart1(ResNetBase):
    """
    The first part of ResNet.
    """
    def __init__(self, device, *args, **kwargs):
        super(ResNetPart1, self).__init__(
            Bottleneck, 64, num_classes=num_classes, *args, **kwargs)

        self.device = device
        self.seq = nn.Sequential(
            nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3, bias=False),
            self._norm_layer(self.inplanes),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
            self._make_layer(64, 3),
            self._make_layer(128, 4, stride=2)
        ).to(self.device)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def forward(self, x_rref):
        # Fetch the micro-batch from its RRef and move it to this shard's GPU.
        x = x_rref.to_here().to(self.device)
        # Several RPC threads may call forward concurrently (one per
        # micro-batch), so the lock serializes access to this shard.
        with self._lock:
            out = self.seq(x)
        # Return the activation on CPU, since the RPC layer used here ships
        # tensors between workers as CPU tensors.
        return out.cpu()


class ResNetPart2(ResNetBase):
    """
    The second part of ResNet.
    """
    def __init__(self, device, *args, **kwargs):
        super(ResNetPart2, self).__init__(
            Bottleneck, 512, num_classes=num_classes, *args, **kwargs)

        self.device = device
        self.seq = nn.Sequential(
            self._make_layer(256, 6, stride=2),
            self._make_layer(512, 3, stride=2),
            nn.AdaptiveAvgPool2d((1, 1)),
        ).to(self.device)

        self.fc = nn.Linear(512 * self._block.expansion, num_classes).to(self.device)

    def forward(self, x_rref):
        x = x_rref.to_here().to(self.device)
        with self._lock:
            out = self.fc(torch.flatten(self.seq(x), 1))
        return out.cpu()


class DistResNet50(nn.Module):
    """
    Assemble the two parts as an nn.Module and define the pipelining logic.
    """
    def __init__(self, split_size, workers, *args, **kwargs):
        super(DistResNet50, self).__init__()

        self.split_size = split_size

        # Put the first part of the ResNet50 on workers[0]
        self.p1_rref = rpc.remote(
            workers[0],
            ResNetPart1,
            args=("cuda:0",) + args,
            kwargs=kwargs
        )

        # Put the second part of the ResNet50 on workers[1]
        self.p2_rref = rpc.remote(
            workers[1],
            ResNetPart2,
            args=("cuda:1",) + args,
            kwargs=kwargs
        )

    def forward(self, xs):
        # Split the input batch xs into micro-batches, and collect the async
        # RPC futures into a list. Both rpc.remote and rpc.rpc_async return
        # immediately, so part 1 can start on the next micro-batch while
        # part 2 is still working on the previous one.
        out_futures = []
        for x in iter(xs.split(self.split_size, dim=0)):
            x_rref = RRef(x)
            y_rref = _remote_on_rref(ResNetPart1.forward, self.p1_rref, x_rref)
            z_fut = _async_on_rref(ResNetPart2.forward, self.p2_rref, y_rref)
            out_futures.append(z_fut)

        # wait for all RPCs to finish
        outs = [fut.wait() for fut in out_futures]
        # concatenate all micro-batch outputs into one tensor
        out = torch.cat(outs)
        return out

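    # DistributedOptimizer needs an RRef to every parameter it will update.
    # The two shards live on remote workers, so ask each owner to wrap its
    # parameters in RRefs and send the (lightweight) lists back to the master.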
    def parameter_rrefs(self):
        remote_params = []
        remote_params.extend(_remote_on_rref(_parameter_rrefs, self.p1_rref).to_here())
        remote_params.extend(_remote_on_rref(_parameter_rrefs, self.p2_rref).to_here())
        return remote_params


#########################################################
#                  Run RPC Processes                    #
#########################################################

num_batches = 3
batch_size = 120
image_w = 128
image_h = 128


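# run_master drives a few batches of synthetic data through the pipeline;
# the __main__ block below sweeps the number of micro-batch splits and times
# each configuration.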
def run_master(num_split):
    # put the two model parts on worker1 and worker2 respectively
    model = DistResNet50(num_split, ["worker1", "worker2"])
    loss_fn = nn.MSELoss()
    opt = DistributedOptimizer(
        optim.SGD,
        model.parameter_rrefs(),
        lr=0.05,
    )

    one_hot_indices = torch.LongTensor(batch_size) \
                           .random_(0, num_classes) \
                           .view(batch_size, 1)

    for i in range(num_batches):
        print(f"Processing batch {i}")
        # generate random inputs and labels
        inputs = torch.randn(batch_size, 3, image_w, image_h)
        labels = torch.zeros(batch_size, num_classes) \
                      .scatter_(1, one_hot_indices, 1)

        # The distributed autograd context ties together the forward RPCs,
        # the distributed backward pass, and the distributed optimizer step.
        with dist_autograd.context() as context_id:
            outputs = model(inputs)
            dist_autograd.backward(context_id, [loss_fn(outputs, labels)])
            opt.step(context_id)


def run_worker(rank, world_size, num_split):
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '29500'
    options = rpc.ProcessGroupRpcBackendOptions(num_send_recv_threads=256)

    if rank == 0:
        rpc.init_rpc(
            "master",
            rank=rank,
            world_size=world_size,
            rpc_backend_options=options
        )
        run_master(num_split)
    else:
        rpc.init_rpc(
            f"worker{rank}",
            rank=rank,
            world_size=world_size,
            rpc_backend_options=options
        )
        # worker processes passively wait for RPCs from the master

    # block until all rpcs finish
    rpc.shutdown()


if __name__ == "__main__":
    world_size = 3  # 1 master ("master") + 2 workers ("worker1", "worker2")
    for num_split in [1, 2, 4, 8]:
        tik = time.time()
        mp.spawn(run_worker, args=(world_size, num_split), nprocs=world_size, join=True)
        tok = time.time()
        print(f"number of splits = {num_split}, execution time = {tok - tik}")