Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 11 additions & 8 deletions tensorflow_addons/optimizers/discriminative_layer_training.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,8 +88,8 @@ def __init__(

if optimizer_specs is None and optimizers_and_layers is not None:
self.optimizer_specs = [
self.create_optimizer_spec(opt, layer)
for opt, layer in optimizers_and_layers
self.create_optimizer_spec(opt, layers)
for opt, layers in optimizers_and_layers
]

elif optimizer_specs is not None and optimizers_and_layers is None:
Expand Down Expand Up @@ -131,20 +131,23 @@ def get_config(self):
return config

@classmethod
def create_optimizer_spec(cls, optimizer_instance, layer):
def create_optimizer_spec(cls, optimizer_instance, layers):

assert isinstance(
optimizer_instance, tf.keras.optimizers.Optimizer
), "Object passed is not an instance of tf.keras.optimizers.Optimizer"

assert isinstance(layer, tf.keras.layers.Layer) or isinstance(
layer, tf.keras.Model
assert isinstance(layers, (tf.keras.layers.Layer, tf.keras.Model)) or (
isinstance(layers, list)
and all(isinstance(layer, tf.keras.layers.Layer) for layer in layers)
), "Object passed is not an instance of tf.keras.layers.Layer nor tf.keras.Model"

if type(layer) == list:
weights = [var.name for sublayer in layer for var in sublayer.weights]
if type(layers) == list:
weights = [
var.name for sublayer in layers for var in sublayer.trainable_weights
]
else:
weights = [var.name for var in layer.weights]
weights = [var.name for var in layers.trainable_weights]

return {
"optimizer": optimizer_instance,
Expand Down