123 changes: 63 additions & 60 deletions tensorflow_addons/optimizers/discriminative_layer_training.py
@@ -14,7 +14,7 @@
# ==============================================================================
"""Discriminative Layer Training Optimizer for TensorFlow."""

from typing import Union
from typing import List, Union

import tensorflow as tf
from typeguard import typechecked
@@ -24,72 +24,70 @@
class MultiOptimizer(tf.keras.optimizers.Optimizer):
"""Multi Optimizer Wrapper for Discriminative Layer Training.

Creates a wrapper around a set of instantiated optimizer layer pairs. Generally useful for transfer learning
of deep networks.
Creates a wrapper around a set of instantiated optimizer layer pairs.
Generally useful for transfer learning of deep networks.

Each optimizer will optimize only the weights associated with its paired layer. This can be used
to implement discriminative layer training by assigning different learning rates to each optimizer
layer pair. (Optimizer, list(Layers)) pairs are also supported. Please note that the layers must be
instantiated before instantiating the optimizer.
Each optimizer will optimize only the weights associated with its paired layer.
This can be used to implement discriminative layer training by assigning
different learning rates to each optimizer layer pair.
`(tf.keras.optimizers.Optimizer, List[tf.keras.layers.Layer])` pairs are also supported.
Review comment (Contributor): List[tf.keras.layers.Layer]) -> List([tf.keras.layers.Layer]). Was missing a (

Reply (Member, author): It's (optimizer, List[layer]), where () stands for Tuple.
Please note that the layers must be instantiated before instantiating the optimizer.

Args:
optimizers_and_layers: a list of tuples of an optimizer and a layer or model. Each tuple should contain
exactly 1 instantiated optimizer and 1 object that subclasses tf.keras.Model or tf.keras.Layer. Nested
layers and models will be automatically discovered. Alternatively, in place of a single layer, you can pass
a list of layers.
optimizer_specs: specialized list for serialization. Should be left as None for almost all cases. If you are
loading a serialized version of this optimizer, please use tf.keras.models.load_model after saving a
model compiled with this optimizer.
optimizers_and_layers: a list of tuples of an optimizer and a layer or model.
Each tuple should contain exactly 1 instantiated optimizer and 1 object that
subclasses `tf.keras.Model`, `tf.keras.Sequential` or `tf.keras.layers.Layer`.
Nested layers and models will be automatically discovered.
Alternatively, in place of a single layer, you can pass a list of layers.
optimizer_specs: specialized list for serialization.
Should be left as None for almost all cases.
If you are loading a serialized version of this optimizer,
please use `tf.keras.models.load_model` after saving a model compiled with this optimizer.

Usage:

```python
model = get_model()

opt1 = tf.keras.optimizers.Adam(learning_rate=1e-4)
opt2 = tf.keras.optimizers.Adam(learning_rate=1e-2)

opt_layer_pairs = [(opt1, model.layers[0]), (opt2, model.layers[1:])]

loss = tf.keras.losses.MSE
optimizer = tfa.optimizers.MultiOpt(opt_layer_pairs)

model.compile(optimizer=optimizer, loss = loss)

model.fit(x,y)
```
>>> model = tf.keras.Sequential([
... tf.keras.Input(shape=(4,)),
... tf.keras.layers.Dense(8),
... tf.keras.layers.Dense(16),
... tf.keras.layers.Dense(32),
... ])
>>> optimizers = [
... tf.keras.optimizers.Adam(learning_rate=1e-4),
... tf.keras.optimizers.Adam(learning_rate=1e-2)
... ]
>>> optimizers_and_layers = [(optimizers[0], model.layers[0]), (optimizers[1], model.layers[1:])]
>>> optimizer = tfa.optimizers.MultiOptimizer(optimizers_and_layers)
>>> model.compile(optimizer=optimizer, loss="mse")
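To see the wrapper in action, a hypothetical continuation of the doctest above could train on random data; the data shapes below are assumptions chosen to match the model, not part of the original example:

```python
# Hypothetical continuation of the doctest above; shapes are assumptions.
import numpy as np

x = np.random.rand(64, 4).astype("float32")   # matches tf.keras.Input(shape=(4,))
y = np.random.rand(64, 32).astype("float32")  # matches the final Dense(32)
model.fit(x, y, epochs=1, verbose=0)  # each Adam updates only its paired layers
```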

Reference:
- [Universal Language Model Fine-tuning for Text Classification](https://arxiv.org/abs/1801.06146)
- [Collaborative Layer-wise Discriminative Learning in Deep Neural Networks](https://arxiv.org/abs/1607.05440)

[Universal Language Model Fine-tuning for Text Classification](https://arxiv.org/abs/1801.06146)
[Collaborative Layer-wise Discriminative Learning in Deep Neural Networks](https://arxiv.org/abs/1607.05440)

Notes:

Currently, MultiOpt does not support callbacks that modify optimizers. However, you can instantiate
optimizer layer pairs with tf.keras.optimizers.schedules.LearningRateSchedule instead of a static learning
rate.
Note: Currently, `tfa.optimizers.MultiOptimizer` does not support callbacks that modify optimizers.
However, you can instantiate optimizer layer pairs with
`tf.keras.optimizers.schedules.LearningRateSchedule`
instead of a static learning rate.

This code should function on CPU, GPU, and TPU. Apply the with strategy.scope() context as you
This code should function on CPU, GPU, and TPU. Apply with `tf.distribute.Strategy().scope()` context as you
would with any other optimizer.

"""

@typechecked
def __init__(
self,
optimizers_and_layers: Union[list, None] = None,
optimizer_specs: Union[list, None] = None,
name: str = "MultiOptimzer",
name: str = "MultiOptimizer",
**kwargs
):

super(MultiOptimizer, self).__init__(name, **kwargs)

if optimizer_specs is None and optimizers_and_layers is not None:
self.optimizer_specs = [
self.create_optimizer_spec(opt, layer)
for opt, layer in optimizers_and_layers
self.create_optimizer_spec(optimizer, layers_or_model)
for optimizer, layers_or_model in optimizers_and_layers
]

elif optimizer_specs is not None and optimizers_and_layers is None:
@@ -99,14 +97,13 @@ def __init__(

else:
raise RuntimeError(
"You must specify either an list of optimizers and layers or a list of optimizer_specs"
"Must specify one of `optimizers_and_layers` or `optimizer_specs`."
)

def apply_gradients(self, grads_and_vars, name=None, **kwargs):
def apply_gradients(self, grads_and_vars, **kwargs):
"""Wrapped apply_gradient method.

Returns a list of tf ops to be executed.
Name of variable is used rather than var.ref() to enable serialization and deserialization.
Returns an operation to be executed.
"""

for spec in self.optimizer_specs:
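The rest of the method body is collapsed in this hunk. Conceptually, the routing can be sketched as below: bucket each incoming `(gradient, variable)` pair under the spec whose recorded variable names contain `var.name`, then let every wrapped optimizer apply only its own bucket. This is a paraphrase of the idea, not the file's elided lines:

```python
# Conceptual sketch of the gradient routing, not the elided source lines.
import tensorflow as tf

def apply_gradients(self, grads_and_vars, **kwargs):
    for spec in self.optimizer_specs:
        spec["gv"] = []  # per-spec bucket of (grad, var) pairs
    for grad, var in tuple(grads_and_vars):
        for spec in self.optimizer_specs:
            if var.name in spec["weights"]:
                spec["gv"].append((grad, var))
    # each wrapped optimizer applies only the gradients routed to it
    return tf.group(
        [
            spec["optimizer"].apply_gradients(spec["gv"], **kwargs)
            for spec in self.optimizer_specs
        ]
    )
```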
@@ -131,29 +128,35 @@ def get_config(self):
return config

@classmethod
def create_optimizer_spec(cls, optimizer_instance, layer):

assert isinstance(
optimizer_instance, tf.keras.optimizers.Optimizer
), "Object passed is not an instance of tf.keras.optimizers.Optimizer"

assert isinstance(layer, tf.keras.layers.Layer) or isinstance(
layer, tf.keras.Model
), "Object passed is not an instance of tf.keras.layers.Layer nor tf.keras.Model"
def create_optimizer_spec(
cls,
optimizer: tf.keras.optimizers.Optimizer,
layers_or_model: Union[
tf.keras.Model,
tf.keras.Sequential,
tf.keras.layers.Layer,
List[tf.keras.layers.Layer],
],
):
"""Creates a serializable optimizer spec.

if type(layer) == list:
weights = [var.name for sublayer in layer for var in sublayer.weights]
The name of each variable is used rather than `var.ref()` to enable serialization and deserialization.
"""
if isinstance(layers_or_model, list):
weights = [
var.name for sublayer in layers_or_model for var in sublayer.weights
]
else:
weights = [var.name for var in layer.weights]
weights = [var.name for var in layers_or_model.weights]

return {
"optimizer": optimizer_instance,
"optimizer": optimizer,
"weights": weights,
}
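For illustration, a spec built for a single `Dense` layer might look like the following; the weight names shown are assumptions, since they depend on how the model was constructed:

```python
# Hypothetical illustration; the weight names depend on model construction.
import tensorflow as tf
import tensorflow_addons as tfa

layer = tf.keras.layers.Dense(8)
layer.build(input_shape=(None, 4))  # weights must exist before building the spec
spec = tfa.optimizers.MultiOptimizer.create_optimizer_spec(
    tf.keras.optimizers.Adam(learning_rate=1e-3), layer
)
# spec == {"optimizer": <Adam>, "weights": ["dense/kernel:0", "dense/bias:0"]}
```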

@classmethod
def maybe_initialize_optimizer_spec(cls, optimizer_spec):
if type(optimizer_spec["optimizer"]) == dict:
if isinstance(optimizer_spec["optimizer"], dict):
optimizer_spec["optimizer"] = tf.keras.optimizers.deserialize(
optimizer_spec["optimizer"]
)
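Tying this back to the `optimizer_specs` note in the Args section, here is a hedged sketch of the save/load round trip the docstring recommends; the path, shapes, and layer sizes are illustrative assumptions:

```python
# Hedged sketch of the serialization round trip; path and shapes are
# illustrative assumptions.
import numpy as np
import tensorflow as tf
import tensorflow_addons as tfa

model = tf.keras.Sequential(
    [tf.keras.Input(shape=(4,)), tf.keras.layers.Dense(8), tf.keras.layers.Dense(1)]
)
optimizer = tfa.optimizers.MultiOptimizer(
    [
        (tf.keras.optimizers.Adam(1e-4), model.layers[0]),
        (tf.keras.optimizers.Adam(1e-2), model.layers[1]),
    ]
)
model.compile(optimizer=optimizer, loss="mse")
model.fit(np.random.rand(8, 4), np.random.rand(8, 1), verbose=0)
model.save("multi_opt_model")  # specs (optimizer + weight names) are serialized
restored = tf.keras.models.load_model("multi_opt_model")  # specs deserialized here
```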