@@ -394,29 +394,28 @@ using point-to-point collectives.
394394
395395 """ Implementation of a ring-reduce with addition. """
def allreduce(send, recv):
    """Sum ``send`` across all ranks with a ring pass and store it in ``recv``.

    Each rank forwards a buffer to its right neighbour and receives one from
    its left neighbour, ``size - 1`` times. The buffer received in one step is
    the one forwarded in the next, so every rank's original ``send`` visits
    every other rank exactly once and is added to the local accumulator.

    Args:
        send: Tensor holding this rank's contribution. It is cloned up front
            and never written to.
        recv: Tensor that is overwritten in place with the accumulated sum.

    Note:
        Relies on the default process group (``dist.get_rank`` /
        ``dist.get_world_size``). Every rank must make the call, since each
        one blocks on a ``dist.recv`` from its left neighbour.
    """
    rank = dist.get_rank()
    size = dist.get_world_size()

    # Two scratch buffers: one is in flight to the right neighbour while the
    # other is being filled from the left neighbour.
    outgoing = send.clone()
    incoming = send.clone()
    accum = send.clone()

    # Python's % is non-negative for a positive modulus, so rank 0 wraps to
    # size - 1 without adding `size` first.
    left = (rank - 1) % size
    right = (rank + 1) % size

    for _ in range(size - 1):
        send_req = dist.isend(outgoing, right)
        dist.recv(incoming, left)
        # Whole-tensor in-place add: unlike `accum[:] += ...`, this also
        # works for 0-dim tensors, which cannot be sliced.
        accum += incoming
        # The send must complete before `outgoing` is reused as the receive
        # target in the next step.
        send_req.wait()
        # What was just received is what gets forwarded next.
        outgoing, incoming = incoming, outgoing

    recv.copy_(accum)
420419
In the above script, the ``allreduce(send, recv)`` function has a
slightly different signature than the ones in PyTorch. It takes a
0 commit comments