diff --git a/tensorflow_addons/optimizers/discriminative_layer_training.py b/tensorflow_addons/optimizers/discriminative_layer_training.py index 494c29d365..56fb0b9fdc 100644 --- a/tensorflow_addons/optimizers/discriminative_layer_training.py +++ b/tensorflow_addons/optimizers/discriminative_layer_training.py @@ -88,8 +88,8 @@ def __init__( if optimizer_specs is None and optimizers_and_layers is not None: self.optimizer_specs = [ - self.create_optimizer_spec(opt, layer) - for opt, layer in optimizers_and_layers + self.create_optimizer_spec(opt, layers) + for opt, layers in optimizers_and_layers ] elif optimizer_specs is not None and optimizers_and_layers is None: @@ -131,20 +131,23 @@ def get_config(self): return config @classmethod - def create_optimizer_spec(cls, optimizer_instance, layer): + def create_optimizer_spec(cls, optimizer_instance, layers): assert isinstance( optimizer_instance, tf.keras.optimizers.Optimizer ), "Object passed is not an instance of tf.keras.optimizers.Optimizer" - assert isinstance(layer, tf.keras.layers.Layer) or isinstance( - layer, tf.keras.Model + assert isinstance(layers, (tf.keras.layers.Layer, tf.keras.Model)) or ( + isinstance(layers, list) + and all(isinstance(layer, tf.keras.layers.Layer) for layer in layers) ), "Object passed is not an instance of tf.keras.layers.Layer nor tf.keras.Model" - if type(layer) == list: - weights = [var.name for sublayer in layer for var in sublayer.weights] + if type(layers) == list: + weights = [ + var.name for sublayer in layers for var in sublayer.trainable_weights + ] else: - weights = [var.name for var in layer.weights] + weights = [var.name for var in layers.trainable_weights] return { "optimizer": optimizer_instance,