diff --git a/tensorflow_addons/optimizers/lazy_adam.py b/tensorflow_addons/optimizers/lazy_adam.py
index 4efe4416b3..494d108732 100644
--- a/tensorflow_addons/optimizers/lazy_adam.py
+++ b/tensorflow_addons/optimizers/lazy_adam.py
@@ -47,6 +47,23 @@ class LazyAdam(tf.keras.optimizers.Adam):
         False.
     """
 
+    def __init__(self,
+                 learning_rate=0.001,
+                 beta_1=0.9,
+                 beta_2=0.999,
+                 epsilon=1e-7,
+                 amsgrad=False,
+                 name='LazyAdam',
+                 **kwargs):
+        super(LazyAdam, self).__init__(
+            learning_rate=learning_rate,
+            beta_1=beta_1,
+            beta_2=beta_2,
+            epsilon=epsilon,
+            amsgrad=amsgrad,
+            name=name,
+            **kwargs)
+
     def _resource_apply_sparse(self, grad, var, indices):
         var_dtype = var.dtype.base_dtype
         lr_t = self._decayed_lr(var_dtype)
diff --git a/tensorflow_addons/optimizers/lazy_adam_test.py b/tensorflow_addons/optimizers/lazy_adam_test.py
index 9a3adab71d..cea6484df5 100644
--- a/tensorflow_addons/optimizers/lazy_adam_test.py
+++ b/tensorflow_addons/optimizers/lazy_adam_test.py
@@ -212,7 +212,7 @@ def doTestBasic(self, use_callable_params=False):
                                                self.evaluate(var0))
             self.assertAllCloseAccordingToType(var1_np,
                                                self.evaluate(var1))
-            self.assertEqual("var0_%d/m:0" % (i,),
+            self.assertEqual("LazyAdam/var0_%d/m:0" % (i,),
                              opt.get_slot(var0, "m").name)
 
     @test_utils.run_in_graph_and_eager_modes(reset_test=True)