
Commit 86ee4d0

Tabrizian and holly1238 authored
change reduce_op to ReduceOp (#1048)
Co-authored-by: holly1238 <[email protected]>
1 parent 4f2ce0e commit 86ee4d0

File tree

1 file changed: +7 −7 lines changed


intermediate_source/dist_tuto.rst

Lines changed: 7 additions & 7 deletions
@@ -210,19 +210,19 @@ to obtain the sum of all tensors at all processes, we can use the
     """ Simple point-to-point communication. """
     group = dist.new_group([0, 1])
     tensor = torch.ones(1)
-    dist.all_reduce(tensor, op=dist.reduce_op.SUM, group=group)
+    dist.all_reduce(tensor, op=dist.ReduceOp.SUM, group=group)
     print('Rank ', rank, ' has data ', tensor[0])

 Since we want the sum of all tensors in the group, we use
-``dist.reduce_op.SUM`` as the reduce operator. Generally speaking, any
+``dist.ReduceOp.SUM`` as the reduce operator. Generally speaking, any
 commutative mathematical operation can be used as an operator.
 Out-of-the-box, PyTorch comes with 4 such operators, all working at the
 element-wise level:

-- ``dist.reduce_op.SUM``,
-- ``dist.reduce_op.PRODUCT``,
-- ``dist.reduce_op.MAX``,
-- ``dist.reduce_op.MIN``.
+- ``dist.ReduceOp.SUM``,
+- ``dist.ReduceOp.PRODUCT``,
+- ``dist.ReduceOp.MAX``,
+- ``dist.ReduceOp.MIN``.

 In addition to ``dist.all_reduce(tensor, op, group)``, there are a total
 of 6 collectives currently implemented in PyTorch.
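
For reference, a minimal runnable sketch (not part of this commit) of the updated ``all_reduce`` call with the renamed ``dist.ReduceOp`` enum; it assumes the "gloo" backend is available locally and uses an arbitrary loopback address and port, mirroring the tutorial's ``init_process`` helper.

    # Sketch: two local processes summing a tensor with dist.ReduceOp.SUM.
    # Assumes the "gloo" backend is available; MASTER_ADDR/PORT are arbitrary.
    import os
    import torch
    import torch.distributed as dist
    import torch.multiprocessing as mp

    def run(rank, size):
        tensor = torch.ones(1)
        dist.all_reduce(tensor, op=dist.ReduceOp.SUM)  # every rank ends up with `size`
        print('Rank ', rank, ' has data ', tensor[0])

    def init_process(rank, size, fn, backend='gloo'):
        os.environ['MASTER_ADDR'] = '127.0.0.1'
        os.environ['MASTER_PORT'] = '29500'
        dist.init_process_group(backend, rank=rank, world_size=size)
        fn(rank, size)

    if __name__ == "__main__":
        size = 2
        mp.set_start_method("spawn")
        processes = []
        for rank in range(size):
            p = mp.Process(target=init_process, args=(rank, size, run))
            p.start()
            processes.append(p)
        for p in processes:
            p.join()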
@@ -376,7 +376,7 @@ world.
 def average_gradients(model):
     size = float(dist.get_world_size())
     for param in model.parameters():
-        dist.all_reduce(param.grad.data, op=dist.reduce_op.SUM)
+        dist.all_reduce(param.grad.data, op=dist.ReduceOp.SUM)
         param.grad.data /= size

 *Et voilà*! We successfully implemented distributed synchronous SGD and
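
A hedged sketch of how ``average_gradients`` from the hunk above would typically be called inside a training step; ``train_step``, the model, optimizer, and data are hypothetical placeholders, and ``dist.init_process_group(...)`` is assumed to have already been run on each rank.

    # Sketch: one training step per rank using the gradient-averaging helper.
    # Assumes the process group is already initialized on this rank;
    # `model`, `optimizer`, `data`, and `target` are placeholders.
    import torch
    import torch.nn.functional as F
    import torch.distributed as dist

    def average_gradients(model):
        # Same logic as the tutorial: sum gradients across ranks, then divide.
        size = float(dist.get_world_size())
        for param in model.parameters():
            dist.all_reduce(param.grad.data, op=dist.ReduceOp.SUM)
            param.grad.data /= size

    def train_step(model, optimizer, data, target):
        optimizer.zero_grad()
        loss = F.nll_loss(model(data), target)
        loss.backward()
        average_gradients(model)  # replicas now hold identical averaged gradients
        optimizer.step()
        return loss.item()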
