From e6209803439ec1e080b8ee526140bc79a994a652 Mon Sep 17 00:00:00 2001 From: bhack Date: Fri, 25 Sep 2020 18:12:04 +0000 Subject: [PATCH] Fix assert for list of layers --- .../discriminative_layer_training.py | 19 +++++++++++-------- 1 file changed, 11 insertions(+), 8 deletions(-) diff --git a/tensorflow_addons/optimizers/discriminative_layer_training.py b/tensorflow_addons/optimizers/discriminative_layer_training.py index 494c29d365..56fb0b9fdc 100644 --- a/tensorflow_addons/optimizers/discriminative_layer_training.py +++ b/tensorflow_addons/optimizers/discriminative_layer_training.py @@ -88,8 +88,8 @@ def __init__( if optimizer_specs is None and optimizers_and_layers is not None: self.optimizer_specs = [ - self.create_optimizer_spec(opt, layer) - for opt, layer in optimizers_and_layers + self.create_optimizer_spec(opt, layers) + for opt, layers in optimizers_and_layers ] elif optimizer_specs is not None and optimizers_and_layers is None: @@ -131,20 +131,23 @@ def get_config(self): return config @classmethod - def create_optimizer_spec(cls, optimizer_instance, layer): + def create_optimizer_spec(cls, optimizer_instance, layers): assert isinstance( optimizer_instance, tf.keras.optimizers.Optimizer ), "Object passed is not an instance of tf.keras.optimizers.Optimizer" - assert isinstance(layer, tf.keras.layers.Layer) or isinstance( - layer, tf.keras.Model + assert isinstance(layers, (tf.keras.layers.Layer, tf.keras.Model)) or ( + isinstance(layers, list) + and all(isinstance(layer, tf.keras.layers.Layer) for layer in layers) ), "Object passed is not an instance of tf.keras.layers.Layer nor tf.keras.Model" - if type(layer) == list: - weights = [var.name for sublayer in layer for var in sublayer.weights] + if type(layers) == list: + weights = [ + var.name for sublayer in layers for var in sublayer.trainable_weights + ] else: - weights = [var.name for var in layer.weights] + weights = [var.name for var in layers.trainable_weights] return { "optimizer": optimizer_instance,