Commit 1a221eb

seanpmorgan and shun-lin authored and committed

[WIP] Use public Keras object registration (#669)

* Remove test skips since upstream fix
* Use public keras object registration
* Merge master and update registration
* Fix tf import
* Update READMEs
* F scores keras registration
* Lint

1 parent 257a39b · commit 1a221eb


45 files changed: +60 −109 lines

tensorflow_addons/activations/README.md
Lines changed: 1 addition & 2 deletions

```diff
@@ -29,8 +29,7 @@
 In order to conform with the current API standard, all activations
 must:
 * Be a `tf.function`.
-* [Register as a keras global object](https://github.com/tensorflow/addons/blob/master/tensorflow_addons/utils/keras_utils.py)
-so it can be serialized properly.
+* Register as a keras global object so it can be serialized properly: `@tf.keras.utils.register_keras_serializable(package='Addons')`
 * Add the addon to the `py_library` in this sub-package's BUILD file.
 
 #### Testing Requirements
```
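The decorator this commit switches to registers an object in Keras's global custom-object registry under a namespaced key (`Addons>name`), so models that use it can be deserialized without a `custom_objects` argument. A minimal sketch, assuming TensorFlow 2.x; `my_scaled_tanh` is a hypothetical activation, not part of this commit:

```python
import tensorflow as tf

# Register under the namespaced key "Addons>my_scaled_tanh".
# my_scaled_tanh is an illustrative example, not an actual Addons activation.
@tf.keras.utils.register_keras_serializable(package='Addons')
def my_scaled_tanh(x):
    return 2.0 * tf.tanh(x)

# A model using the registered activation round-trips through its JSON
# config with no custom_objects argument.
model = tf.keras.Sequential([
    tf.keras.Input(shape=(3,)),
    tf.keras.layers.Dense(4, activation=my_scaled_tanh),
])
restored = tf.keras.models.model_from_json(model.to_json())

# The registry maps the namespaced name back to the original function.
fn = tf.keras.utils.get_registered_object('Addons>my_scaled_tanh')
```

Compared with the old `keras_utils.register_keras_custom_object` helper, the public API adds the `package` prefix, which namespaces registered names and avoids collisions with user-defined objects.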

tensorflow_addons/activations/gelu.py
Lines changed: 1 addition & 2 deletions

```diff
@@ -18,14 +18,13 @@
 from __future__ import print_function
 
 import tensorflow as tf
-from tensorflow_addons.utils import keras_utils
 from tensorflow_addons.utils.resource_loader import get_path_to_datafile
 
 _activation_ops_so = tf.load_op_library(
     get_path_to_datafile("custom_ops/activations/_activation_ops.so"))
 
 
-@keras_utils.register_keras_custom_object
+@tf.keras.utils.register_keras_serializable(package='Addons')
 @tf.function
 def gelu(x, approximate=True):
     """Gaussian Error Linear Unit.
```

tensorflow_addons/activations/hardshrink.py
Lines changed: 1 addition & 2 deletions

```diff
@@ -18,14 +18,13 @@
 from __future__ import print_function
 
 import tensorflow as tf
-from tensorflow_addons.utils import keras_utils
 from tensorflow_addons.utils.resource_loader import get_path_to_datafile
 
 _activation_ops_so = tf.load_op_library(
     get_path_to_datafile("custom_ops/activations/_activation_ops.so"))
 
 
-@keras_utils.register_keras_custom_object
+@tf.keras.utils.register_keras_serializable(package='Addons')
 @tf.function
 def hardshrink(x, lower=-0.5, upper=0.5):
     """Hard shrink function.
```

tensorflow_addons/activations/lisht.py
Lines changed: 1 addition & 2 deletions

```diff
@@ -18,14 +18,13 @@
 from __future__ import print_function
 
 import tensorflow as tf
-from tensorflow_addons.utils import keras_utils
 from tensorflow_addons.utils.resource_loader import get_path_to_datafile
 
 _activation_ops_so = tf.load_op_library(
     get_path_to_datafile("custom_ops/activations/_activation_ops.so"))
 
 
-@keras_utils.register_keras_custom_object
+@tf.keras.utils.register_keras_serializable(package='Addons')
 @tf.function
 def lisht(x):
     """LiSHT: Non-Parameteric Linearly Scaled Hyperbolic Tangent Activation Function.
```

tensorflow_addons/activations/mish.py
Lines changed: 1 addition & 2 deletions

```diff
@@ -18,14 +18,13 @@
 from __future__ import print_function
 
 import tensorflow as tf
-from tensorflow_addons.utils import keras_utils
 from tensorflow_addons.utils.resource_loader import get_path_to_datafile
 
 _activation_ops_so = tf.load_op_library(
     get_path_to_datafile("custom_ops/activations/_activation_ops.so"))
 
 
-@keras_utils.register_keras_custom_object
+@tf.keras.utils.register_keras_serializable(package='Addons')
 @tf.function
 def mish(x):
     """Mish: A Self Regularized Non-Monotonic Neural Activation Function.
```

tensorflow_addons/activations/rrelu.py
Lines changed: 1 addition & 2 deletions

```diff
@@ -18,14 +18,13 @@
 from __future__ import print_function
 
 import tensorflow as tf
-from tensorflow_addons.utils import keras_utils
 from tensorflow_addons.utils.resource_loader import get_path_to_datafile
 
 _activation_ops_so = tf.load_op_library(
     get_path_to_datafile("custom_ops/activations/_activation_ops.so"))
 
 
-@keras_utils.register_keras_custom_object
+@tf.keras.utils.register_keras_serializable(package='Addons')
 @tf.function
 def rrelu(x, lower=0.125, upper=0.3333333333333333, training=None, seed=None):
     """rrelu function.
```

tensorflow_addons/activations/softshrink.py
Lines changed: 1 addition & 2 deletions

```diff
@@ -18,14 +18,13 @@
 from __future__ import print_function
 
 import tensorflow as tf
-from tensorflow_addons.utils import keras_utils
 from tensorflow_addons.utils.resource_loader import get_path_to_datafile
 
 _activation_ops_so = tf.load_op_library(
     get_path_to_datafile("custom_ops/activations/_activation_ops.so"))
 
 
-@keras_utils.register_keras_custom_object
+@tf.keras.utils.register_keras_serializable(package='Addons')
 @tf.function
 def softshrink(x, lower=-0.5, upper=0.5):
     """Soft shrink function.
```

tensorflow_addons/activations/sparsemax.py
Lines changed: 1 addition & 3 deletions

```diff
@@ -19,10 +19,8 @@
 
 import tensorflow as tf
 
-from tensorflow_addons.utils import keras_utils
 
-
-@keras_utils.register_keras_custom_object
+@tf.keras.utils.register_keras_serializable(package='Addons')
 @tf.function
 def sparsemax(logits, axis=-1):
     """Sparsemax activation function [1].
```

tensorflow_addons/activations/tanhshrink.py
Lines changed: 1 addition & 2 deletions

```diff
@@ -18,14 +18,13 @@
 from __future__ import print_function
 
 import tensorflow as tf
-from tensorflow_addons.utils import keras_utils
 from tensorflow_addons.utils.resource_loader import get_path_to_datafile
 
 _activation_ops_so = tf.load_op_library(
     get_path_to_datafile("custom_ops/activations/_activation_ops.so"))
 
 
-@keras_utils.register_keras_custom_object
+@tf.keras.utils.register_keras_serializable(package='Addons')
 @tf.function
 def tanhshrink(x):
     """Applies the element-wise function: x - tanh(x)
```

tensorflow_addons/callbacks/README.md
Lines changed: 1 addition & 2 deletions

```diff
@@ -16,8 +16,7 @@
 In order to conform with the current API standard, all callbacks
 must:
 * Inherit from `tf.keras.callbacks.Callback`.
-* [Register as a keras global object](https://github.com/tensorflow/addons/blob/master/tensorflow_addons/utils/keras_utils.py)
-so it can be serialized properly.
+* Register as a keras global object so it can be serialized properly: `@tf.keras.utils.register_keras_serializable(package='Addons')`
 * Add the addon to the `py_library` in this sub-package's BUILD file.
 
 #### Testing Requirements
```
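For classes, the same public decorator applies, but registration requires the class to implement `get_config`. A minimal sketch of a callback meeting both requirements, assuming TensorFlow 2.x; `EveryNEpochsLogger` and its `log_every` parameter are illustrative, not part of this commit:

```python
import tensorflow as tf

# Hypothetical callback: inherits from tf.keras.callbacks.Callback and is
# registered as a Keras global object under the "Addons" package prefix.
@tf.keras.utils.register_keras_serializable(package='Addons')
class EveryNEpochsLogger(tf.keras.callbacks.Callback):

    def __init__(self, log_every=1):
        super().__init__()
        self.log_every = log_every

    def on_epoch_end(self, epoch, logs=None):
        if epoch % self.log_every == 0:
            print('epoch %d: %s' % (epoch, logs))

    def get_config(self):
        # Required for registration: lets Keras reconstruct the object
        # from its config during deserialization.
        return {'log_every': self.log_every}

# The class is retrievable via its namespaced registered name.
name = tf.keras.utils.get_registered_name(EveryNEpochsLogger)
```

Registering the class records it under `Addons>EveryNEpochsLogger` in the global registry, so deserialization can resolve it by name without `custom_objects`.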
