Skip to content

Commit ddf7e38

Browse files
WindQAQfacaiy
authored andcommitted
make sparsemax deserializable (#441)
1 parent 594e183 commit ddf7e38

File tree

4 files changed

+22
-2
lines changed

4 files changed

+22
-2
lines changed

tensorflow_addons/activations/sparsemax.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,8 +22,8 @@
2222
from tensorflow_addons.utils import keras_utils
2323

2424

25-
@tf.function
2625
@keras_utils.register_keras_custom_object
26+
@tf.function
2727
def sparsemax(logits, axis=-1, name=None):
2828
"""Sparsemax activation function [1].
2929

tensorflow_addons/activations/sparsemax_test.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -274,6 +274,20 @@ def test_gradient_against_estimate(self, dtype=None):
274274
lambda logits: sparsemax(logits), [z], delta=1e-6)
275275
self.assertAllCloseAccordingToType(jacob_sym, jacob_num)
276276

277+
def test_serialization(self, dtype=None):
278+
ref_fn = sparsemax
279+
config = tf.keras.activations.serialize(ref_fn)
280+
fn = tf.keras.activations.deserialize(config)
281+
self.assertEqual(fn, ref_fn)
282+
283+
def test_serialization_with_layers(self, dtype=None):
284+
layer = tf.keras.layers.Dense(3, activation=sparsemax)
285+
config = tf.keras.layers.serialize(layer)
286+
deserialized_layer = tf.keras.layers.deserialize(config)
287+
self.assertEqual(deserialized_layer.__class__.__name__,
288+
layer.__class__.__name__)
289+
self.assertEqual(deserialized_layer.activation.__name__, "sparsemax")
290+
277291

278292
if __name__ == '__main__':
279293
tf.test.main()

tensorflow_addons/losses/sparsemax_loss.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,8 +23,8 @@
2323
from tensorflow_addons.utils import keras_utils
2424

2525

26-
@tf.function
2726
@keras_utils.register_keras_custom_object
27+
@tf.function
2828
def sparsemax_loss(logits, sparsemax, labels, name=None):
2929
"""Sparsemax loss function [1].
3030

tensorflow_addons/losses/sparsemax_loss_test.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -226,6 +226,12 @@ def test_gradient_against_estimate(self, dtype=None):
226226
lambda logits: sparsemax_loss(logits, sparsemax(logits), q), [z])
227227
self.assertAllCloseAccordingToType(jacob_sym, jacob_num)
228228

229+
def test_serialization(self, dtype=None):
230+
ref_fn = sparsemax_loss
231+
config = tf.keras.losses.serialize(ref_fn)
232+
fn = tf.keras.losses.deserialize(config)
233+
self.assertEqual(ref_fn, fn)
234+
229235

230236
if __name__ == '__main__':
231237
tf.test.main()

0 commit comments

Comments
 (0)