Skip to content

Commit 07c57f7

Browse files
committed
[Distributed] Directly use hvd DistributedOptimizer.
Signed-off-by: 泊霆 <[email protected]>
1 parent 6dae552 commit 07c57f7

File tree

1 file changed

+6
-20
lines changed

1 file changed

+6
-20
lines changed

tensorflow/python/distribute/hvd_strategy.py

Lines changed: 6 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -388,20 +388,16 @@ def wraps_optimizer(cls):
388388
HvdOptimizer
389389
'''
390390
class HvdOptimizer(cls, optimizer.Optimizer):
391-
def __init__(self, *args, **kwargs):
392-
kwargs["learning_rate"] = kwargs.get("learning_rate", 0.001) *\
393-
HvdContext.get().world_size
394-
super(HvdOptimizer, self).__init__(*args, **kwargs)
391+
def __init__(self, learning_rate=0.001, *args, **kwargs):
392+
learning_rate = learning_rate * HvdContext.get().world_size
393+
super(HvdOptimizer, self).__init__(learning_rate, *args, **kwargs)
395394

396-
def compute_gradients(self, loss, **kwargs):
397-
loss = hvd.allreduce(loss, op=hvd.Sum)
398-
return super().compute_gradients(loss, **kwargs)
399-
400395
if isinstance(cls, HvdOptimizer):
401396
return cls
402397
else:
403398
def horovod_optimizer(*args, **kwargs):
404-
return HvdOptimizer(*args, **kwargs)
399+
from horovod.tensorflow import DistributedOptimizer
400+
return DistributedOptimizer(HvdOptimizer(*args, **kwargs))
405401
return horovod_optimizer
406402

407403

@@ -478,16 +474,6 @@ def HorovodMonitoredTrainingSession(*args, **kwargs): # pylint: disable=invalid
478474
kwargs['config'] = wraps_session_config(kwargs.pop('config', None))
479475
kwargs['is_chief'] = True
480476
args = list(args)
481-
if args:
482-
master = args[0]
483-
if not master:
484-
master = ''
485-
args[0] = master
486-
else:
487-
master = kwargs.pop('master', None)
488-
if not master:
489-
master = ''
490-
kwargs['master'] = master
491477

492478
prev_monitored_session = _monitored_session.MonitoredSession
493479
sess = fn(*args, **kwargs)
@@ -1449,4 +1435,4 @@ def export(export_dir_base,
14491435
as_text=as_text,
14501436
clear_devices=clear_devices,
14511437
strip_default_attrs=strip_default_attrs,
1452-
modes=[mode])
1438+
modes=[mode])

0 commit comments

Comments
 (0)