@@ -388,20 +388,16 @@ def wraps_optimizer(cls):
388
388
HvdOptimizer
389
389
'''
390
390
class HvdOptimizer (cls , optimizer .Optimizer ):
391
- def __init__ (self , * args , ** kwargs ):
392
- kwargs ["learning_rate" ] = kwargs .get ("learning_rate" , 0.001 ) * \
393
- HvdContext .get ().world_size
394
- super (HvdOptimizer , self ).__init__ (* args , ** kwargs )
391
+ def __init__ (self , learning_rate = 0.001 , * args , ** kwargs ):
392
+ learning_rate = learning_rate * HvdContext .get ().world_size
393
+ super (HvdOptimizer , self ).__init__ (learning_rate , * args , ** kwargs )
395
394
396
- def compute_gradients (self , loss , ** kwargs ):
397
- loss = hvd .allreduce (loss , op = hvd .Sum )
398
- return super ().compute_gradients (loss , ** kwargs )
399
-
400
395
if isinstance (cls , HvdOptimizer ):
401
396
return cls
402
397
else :
403
398
def horovod_optimizer (* args , ** kwargs ):
404
- return HvdOptimizer (* args , ** kwargs )
399
+ from horovod .tensorflow import DistributedOptimizer
400
+ return DistributedOptimizer (HvdOptimizer (* args , ** kwargs ))
405
401
return horovod_optimizer
406
402
407
403
@@ -478,16 +474,6 @@ def HorovodMonitoredTrainingSession(*args, **kwargs): # pylint: disable=invalid
478
474
kwargs ['config' ] = wraps_session_config (kwargs .pop ('config' , None ))
479
475
kwargs ['is_chief' ] = True
480
476
args = list (args )
481
- if args :
482
- master = args [0 ]
483
- if not master :
484
- master = ''
485
- args [0 ] = master
486
- else :
487
- master = kwargs .pop ('master' , None )
488
- if not master :
489
- master = ''
490
- kwargs ['master' ] = master
491
477
492
478
prev_monitored_session = _monitored_session .MonitoredSession
493
479
sess = fn (* args , ** kwargs )
@@ -1449,4 +1435,4 @@ def export(export_dir_base,
1449
1435
as_text = as_text ,
1450
1436
clear_devices = clear_devices ,
1451
1437
strip_default_attrs = strip_default_attrs ,
1452
- modes = [mode ])
1438
+ modes = [mode ])
0 commit comments