From de64975fe293b058b8dcdac3c5e651510c39465e Mon Sep 17 00:00:00 2001 From: Aakash Kumar Nain Date: Fri, 16 Aug 2019 22:08:49 +0530 Subject: [PATCH 01/14] add gelu activation --- tensorflow_addons/layers/gelu.py | 60 ++++++++++++++++++++++++++++++++ 1 file changed, 60 insertions(+) create mode 100644 tensorflow_addons/layers/gelu.py diff --git a/tensorflow_addons/layers/gelu.py b/tensorflow_addons/layers/gelu.py new file mode 100644 index 0000000000..e2bfeb2235 --- /dev/null +++ b/tensorflow_addons/layers/gelu.py @@ -0,0 +1,60 @@ +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Implements GeLU activation.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import math +import tensorflow as tf +from tensorflow.keras import backend as K +from tensorflow_addons.utils import keras_utils + + +@keras_utils.register_keras_custom_object +@keras_utils.register_keras_custom_object +class GeLU(tf.keras.layers.Layer): + """Gaussian Error Linear Unit. + + A smoother version of ReLU generally used + in the BERT or BERT architecture based models. + Original paper: https://arxiv.org/abs/1606.08415 + + Input shape: + Arbitrary. Use the keyword argument `input_shape` + (tuple of integers, does not include the samples axis) + when using this layer as the first layer in a model. + + Output shape: + Same shape as the input. + """ + + def __init__(self, **kwargs): + super(GeLU, self).__init__(**kwargs) + self.supports_masking = True + + def call(self, inputs): + pi = K.cast(math.pi, inputs.dtype) + return 0.5 * inputs * (1 + tf.tanh(tf.sqrt(2.0 / pi) * \ + (inputs + 0.044715 * tf.pow(inputs, 3)))) + + def get_config(self): + config = {} + base_config = super(GeLU, self).get_config() + return dict(list(base_config.items()) + list(config.items())) + + def compute_output_shape(self, input_shape): + return input_shape \ No newline at end of file From 3e8ae8375c203b5b078406ee3aa5ea61959b3437 Mon Sep 17 00:00:00 2001 From: Aakash Kumar Nain Date: Fri, 16 Aug 2019 22:09:06 +0530 Subject: [PATCH 02/14] add tests for gelu activation --- tensorflow_addons/layers/gelu_test.py | 58 +++++++++++++++++++++++++++ 1 file changed, 58 insertions(+) create mode 100644 tensorflow_addons/layers/gelu_test.py diff --git a/tensorflow_addons/layers/gelu_test.py b/tensorflow_addons/layers/gelu_test.py new file mode 100644 index 0000000000..5edb43270c --- /dev/null +++ b/tensorflow_addons/layers/gelu_test.py @@ -0,0 +1,58 @@ +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for GeLU activation.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np +import tensorflow as tf +from tensorflow.keras import backend as K +from tensorflow_addons.layers.gelu import GeLU +from tensorflow_addons.utils import test_utils +from absl.testing import parameterized + + +@parameterized.parameters([np.float16, np.float32, np.float64]) +@test_utils.run_all_in_graph_and_eager_modes +class GELUTest(tf.test.TestCase): + def random_test(self, dtype): + x = tf.constant([2.5, 0.02, -0.001], shape=(3,1)) + val = np.array([ 2.4849157e+00, + 1.0159566e-02, + -4.9960107e-04], + dtype=dtype).reshape(3,1) + + test_utils.layer_test( + GeLU, + kwargs={'dtype': dtype}, + input_data=x, + expected_output=val) + + def random_test_with_numpy(self, dtype): + x = np.array([[0.5, 1.2, -0.3]]).astype(dtype) + val = np.array([[0.345714, 1.0617027, -0.11462909]]).astype(dtype) + + test_utils.layer_test( + GeLU, + kwargs={'dtype': dtype}, + input_data=x, + expected_output=val) + +if __name__ == '__main__': + tf.test.main() + + From 9041c5720441840185d7261726da006f55f057a1 Mon Sep 17 00:00:00 2001 From: Aakash Kumar Nain Date: Fri, 16 Aug 2019 22:09:31 +0530 Subject: [PATCH 03/14] add gelu to imports --- tensorflow_addons/layers/__init__.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tensorflow_addons/layers/__init__.py b/tensorflow_addons/layers/__init__.py index 382f2aa80e..9da564e967 100644 --- a/tensorflow_addons/layers/__init__.py +++ b/tensorflow_addons/layers/__init__.py @@ -18,6 +18,7 @@ from __future__ import division from __future__ import print_function +from tensorflow_addons.layers.gelu import GeLU from tensorflow_addons.layers.maxout import Maxout from tensorflow_addons.layers.normalizations import GroupNormalization from tensorflow_addons.layers.normalizations import InstanceNormalization From 0c3fa0c2079ffb989487cd030397ae7ff692fbc0 Mon Sep 17 00:00:00 2001 From: Aakash Kumar Nain Date: Fri, 16 Aug 2019 22:09:56 +0530 Subject: [PATCH 04/14] include gelu in build file --- tensorflow_addons/layers/BUILD | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/tensorflow_addons/layers/BUILD b/tensorflow_addons/layers/BUILD index 59aeb562b5..01f475a2fb 100644 --- a/tensorflow_addons/layers/BUILD +++ b/tensorflow_addons/layers/BUILD @@ -6,6 +6,7 @@ py_library( name = "layers", srcs = [ "__init__.py", + "gelu.py", "maxout.py", "normalizations.py", "optical_flow.py", @@ -23,6 +24,19 @@ py_library( ], ) +py_test( + name = "gelu_test", + size = "small", + srcs = [ + "gelu_test.py", + ], + main = "gelu_test.py", + srcs_version = "PY2AND3", + deps = [ + ":layers", + ], +) + py_test( name = "layers_wrappers_test", size = "small", From afd2c80613a8ec42bf6118abdef310b07a0092a4 Mon Sep 17 00:00:00 2001 From: Aakash Kumar Nain Date: Sat, 17 Aug 2019 15:21:19 +0530 Subject: [PATCH 05/14] update tests and refactor --- tensorflow_addons/layers/gelu_test.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/tensorflow_addons/layers/gelu_test.py b/tensorflow_addons/layers/gelu_test.py index 5edb43270c..fb993e04ed 100644 --- a/tensorflow_addons/layers/gelu_test.py +++ b/tensorflow_addons/layers/gelu_test.py @@ -53,6 +53,4 @@ def random_test_with_numpy(self, dtype): expected_output=val) if __name__ == '__main__': - tf.test.main() - - + tf.test.main() \ No newline at end of file From 05e11c404bc6445728abcb8dce728ee792f1703d Mon Sep 17 00:00:00 2001 From: Aakash Kumar Nain Date: Sat, 17 Aug 2019 15:21:39 +0530 Subject: [PATCH 06/14] refactor --- tensorflow_addons/layers/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow_addons/layers/__init__.py b/tensorflow_addons/layers/__init__.py index 9da564e967..d527e16362 100644 --- a/tensorflow_addons/layers/__init__.py +++ b/tensorflow_addons/layers/__init__.py @@ -25,4 +25,4 @@ from tensorflow_addons.layers.optical_flow import CorrelationCost from tensorflow_addons.layers.poincare import PoincareNormalize from tensorflow_addons.layers.sparsemax import Sparsemax -from tensorflow_addons.layers.wrappers import WeightNormalization +from tensorflow_addons.layers.wrappers import WeightNormalization \ No newline at end of file From 66c40ae41089fe4dd931149662cebec5cb9deacd Mon Sep 17 00:00:00 2001 From: Aakash Kumar Nain Date: Sat, 17 Aug 2019 18:20:25 +0530 Subject: [PATCH 07/14] make compatible with every fp dtype and fulfill layer requirements --- tensorflow_addons/layers/gelu.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/tensorflow_addons/layers/gelu.py b/tensorflow_addons/layers/gelu.py index e2bfeb2235..5b984c2b63 100644 --- a/tensorflow_addons/layers/gelu.py +++ b/tensorflow_addons/layers/gelu.py @@ -22,9 +22,8 @@ import tensorflow as tf from tensorflow.keras import backend as K from tensorflow_addons.utils import keras_utils +from tensorflow_addons.activations import gelu - -@keras_utils.register_keras_custom_object @keras_utils.register_keras_custom_object class GeLU(tf.keras.layers.Layer): """Gaussian Error Linear Unit. @@ -47,9 +46,7 @@ def __init__(self, **kwargs): self.supports_masking = True def call(self, inputs): - pi = K.cast(math.pi, inputs.dtype) - return 0.5 * inputs * (1 + tf.tanh(tf.sqrt(2.0 / pi) * \ - (inputs + 0.044715 * tf.pow(inputs, 3)))) + return gelu(inputs, dtype=inputs.dtype) def get_config(self): config = {} From 4646a8faeae11287efd4331e94bdaef670040915 Mon Sep 17 00:00:00 2001 From: Aakash Kumar Nain Date: Sat, 17 Aug 2019 18:21:31 +0530 Subject: [PATCH 08/14] add dummy model test --- tensorflow_addons/layers/gelu_test.py | 32 ++++++++++++++------------- 1 file changed, 17 insertions(+), 15 deletions(-) diff --git a/tensorflow_addons/layers/gelu_test.py b/tensorflow_addons/layers/gelu_test.py index fb993e04ed..dfc932bd3c 100644 --- a/tensorflow_addons/layers/gelu_test.py +++ b/tensorflow_addons/layers/gelu_test.py @@ -23,26 +23,14 @@ from tensorflow.keras import backend as K from tensorflow_addons.layers.gelu import GeLU from tensorflow_addons.utils import test_utils +from tensorflow_addons.utils.test_utils import keras_parameterized from absl.testing import parameterized @parameterized.parameters([np.float16, np.float32, np.float64]) @test_utils.run_all_in_graph_and_eager_modes -class GELUTest(tf.test.TestCase): - def random_test(self, dtype): - x = tf.constant([2.5, 0.02, -0.001], shape=(3,1)) - val = np.array([ 2.4849157e+00, - 1.0159566e-02, - -4.9960107e-04], - dtype=dtype).reshape(3,1) - - test_utils.layer_test( - GeLU, - kwargs={'dtype': dtype}, - input_data=x, - expected_output=val) - - def random_test_with_numpy(self, dtype): +class TestGeLU(tf.test.TestCase): + def test_random(self, dtype): x = np.array([[0.5, 1.2, -0.3]]).astype(dtype) val = np.array([[0.345714, 1.0617027, -0.11462909]]).astype(dtype) @@ -52,5 +40,19 @@ def random_test_with_numpy(self, dtype): input_data=x, expected_output=val) + +@keras_parameterized.run_all_keras_modes +@keras_parameterized.run_with_all_model_types +class TestGeLU_v2(keras_parameterized.TestCase): + def test_layer_random(self): + layer = tf.keras.layers.Dense(1, activation=GeLU()) + model = keras_parameterized.testing_utils.get_model_from_layers([layer], + input_shape=(10,)) + model.compile( + 'sgd', + 'mse', + run_eagerly=keras_parameterized.testing_utils.should_run_eagerly()) + model.fit(np.ones((10, 10)), np.ones((10, 1)), batch_size=2) + if __name__ == '__main__': tf.test.main() \ No newline at end of file From 821a17ac4e927eaa7ea9e5706784168a23f451e0 Mon Sep 17 00:00:00 2001 From: Aakash Kumar Nain Date: Tue, 20 Aug 2019 18:39:17 +0530 Subject: [PATCH 09/14] code format --- tensorflow_addons/layers/gelu.py | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/tensorflow_addons/layers/gelu.py b/tensorflow_addons/layers/gelu.py index 5b984c2b63..76f9779cfb 100644 --- a/tensorflow_addons/layers/gelu.py +++ b/tensorflow_addons/layers/gelu.py @@ -18,9 +18,7 @@ from __future__ import division from __future__ import print_function -import math import tensorflow as tf -from tensorflow.keras import backend as K from tensorflow_addons.utils import keras_utils from tensorflow_addons.activations import gelu @@ -28,19 +26,19 @@ class GeLU(tf.keras.layers.Layer): """Gaussian Error Linear Unit. - A smoother version of ReLU generally used + A smoother version of ReLU generally used in the BERT or BERT architecture based models. Original paper: https://arxiv.org/abs/1606.08415 - + Input shape: Arbitrary. Use the keyword argument `input_shape` (tuple of integers, does not include the samples axis) when using this layer as the first layer in a model. - + Output shape: Same shape as the input. """ - + def __init__(self, **kwargs): super(GeLU, self).__init__(**kwargs) self.supports_masking = True @@ -54,4 +52,4 @@ def get_config(self): return dict(list(base_config.items()) + list(config.items())) def compute_output_shape(self, input_shape): - return input_shape \ No newline at end of file + return input_shape From c6f981df6e27f912a93666e84590ba27c5e9b7e9 Mon Sep 17 00:00:00 2001 From: Aakash Kumar Nain Date: Tue, 20 Aug 2019 18:39:43 +0530 Subject: [PATCH 10/14] code format and sanity check pass --- tensorflow_addons/layers/gelu_test.py | 28 +++++++++++++-------------- 1 file changed, 14 insertions(+), 14 deletions(-) diff --git a/tensorflow_addons/layers/gelu_test.py b/tensorflow_addons/layers/gelu_test.py index dfc932bd3c..e95388823c 100644 --- a/tensorflow_addons/layers/gelu_test.py +++ b/tensorflow_addons/layers/gelu_test.py @@ -20,11 +20,10 @@ import numpy as np import tensorflow as tf -from tensorflow.keras import backend as K +from absl.testing import parameterized from tensorflow_addons.layers.gelu import GeLU from tensorflow_addons.utils import test_utils from tensorflow_addons.utils.test_utils import keras_parameterized -from absl.testing import parameterized @parameterized.parameters([np.float16, np.float32, np.float64]) @@ -33,7 +32,6 @@ class TestGeLU(tf.test.TestCase): def test_random(self, dtype): x = np.array([[0.5, 1.2, -0.3]]).astype(dtype) val = np.array([[0.345714, 1.0617027, -0.11462909]]).astype(dtype) - test_utils.layer_test( GeLU, kwargs={'dtype': dtype}, @@ -43,16 +41,18 @@ def test_random(self, dtype): @keras_parameterized.run_all_keras_modes @keras_parameterized.run_with_all_model_types -class TestGeLU_v2(keras_parameterized.TestCase): - def test_layer_random(self): - layer = tf.keras.layers.Dense(1, activation=GeLU()) - model = keras_parameterized.testing_utils.get_model_from_layers([layer], - input_shape=(10,)) - model.compile( - 'sgd', - 'mse', - run_eagerly=keras_parameterized.testing_utils.should_run_eagerly()) - model.fit(np.ones((10, 10)), np.ones((10, 1)), batch_size=2) +class TestGeLUModel(keras_parameterized.TestCase): + """Test GeLU with random keras model""" + def test_layer_random(self): + layer = tf.keras.layers.Dense(1, activation=GeLU()) + model = keras_parameterized.testing_utils.get_model_from_layers( + [layer], + input_shape=(10,)) + model.compile( + 'sgd', + 'mse', + run_eagerly=keras_parameterized.testing_utils.should_run_eagerly()) + model.fit(np.ones((10, 10)), np.ones((10, 1)), batch_size=2) if __name__ == '__main__': - tf.test.main() \ No newline at end of file + tf.test.main() From d569f2b28e98143db9ffe3562a55b92fc80b7cd8 Mon Sep 17 00:00:00 2001 From: Aakash Kumar Nain Date: Tue, 20 Aug 2019 20:36:28 +0530 Subject: [PATCH 11/14] code format --- tensorflow_addons/layers/gelu.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tensorflow_addons/layers/gelu.py b/tensorflow_addons/layers/gelu.py index 76f9779cfb..314fe06ee8 100644 --- a/tensorflow_addons/layers/gelu.py +++ b/tensorflow_addons/layers/gelu.py @@ -22,6 +22,7 @@ from tensorflow_addons.utils import keras_utils from tensorflow_addons.activations import gelu + @keras_utils.register_keras_custom_object class GeLU(tf.keras.layers.Layer): """Gaussian Error Linear Unit. From 6a13de9f6e2a9c4e5a935a89dc1edee39517d88f Mon Sep 17 00:00:00 2001 From: Aakash Kumar Nain Date: Tue, 20 Aug 2019 20:36:48 +0530 Subject: [PATCH 12/14] auto code format --- tensorflow_addons/layers/gelu_test.py | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/tensorflow_addons/layers/gelu_test.py b/tensorflow_addons/layers/gelu_test.py index e95388823c..1c1a8c4071 100644 --- a/tensorflow_addons/layers/gelu_test.py +++ b/tensorflow_addons/layers/gelu_test.py @@ -33,26 +33,24 @@ def test_random(self, dtype): x = np.array([[0.5, 1.2, -0.3]]).astype(dtype) val = np.array([[0.345714, 1.0617027, -0.11462909]]).astype(dtype) test_utils.layer_test( - GeLU, - kwargs={'dtype': dtype}, - input_data=x, - expected_output=val) + GeLU, kwargs={'dtype': dtype}, input_data=x, expected_output=val) @keras_parameterized.run_all_keras_modes @keras_parameterized.run_with_all_model_types class TestGeLUModel(keras_parameterized.TestCase): - """Test GeLU with random keras model""" + """Test GeLU with random keras model.""" + def test_layer_random(self): layer = tf.keras.layers.Dense(1, activation=GeLU()) model = keras_parameterized.testing_utils.get_model_from_layers( - [layer], - input_shape=(10,)) + [layer], input_shape=(10,)) model.compile( 'sgd', 'mse', run_eagerly=keras_parameterized.testing_utils.should_run_eagerly()) model.fit(np.ones((10, 10)), np.ones((10, 1)), batch_size=2) + if __name__ == '__main__': tf.test.main() From 6cdba92bf4e67197b779060c66345f9570d28e96 Mon Sep 17 00:00:00 2001 From: Aakash Kumar Nain Date: Tue, 27 Aug 2019 23:18:47 +0530 Subject: [PATCH 13/14] use fused gelu activation --- tensorflow_addons/layers/gelu.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/tensorflow_addons/layers/gelu.py b/tensorflow_addons/layers/gelu.py index 314fe06ee8..159e00f729 100644 --- a/tensorflow_addons/layers/gelu.py +++ b/tensorflow_addons/layers/gelu.py @@ -40,15 +40,16 @@ class GeLU(tf.keras.layers.Layer): Same shape as the input. """ - def __init__(self, **kwargs): + def __init__(self, approximate=True, **kwargs): super(GeLU, self).__init__(**kwargs) + self.approximate = approximate self.supports_masking = True def call(self, inputs): - return gelu(inputs, dtype=inputs.dtype) + return gelu(inputs, approximate=self.approximate) def get_config(self): - config = {} + config = {'approximate': self.approximate} base_config = super(GeLU, self).get_config() return dict(list(base_config.items()) + list(config.items())) From db52c48a09652caa8fa3caa49ecc89a00ad9c53a Mon Sep 17 00:00:00 2001 From: Aakash Kumar Nain Date: Thu, 29 Aug 2019 18:28:37 +0530 Subject: [PATCH 14/14] remove redundant test cases --- tensorflow_addons/layers/gelu_test.py | 17 ----------------- 1 file changed, 17 deletions(-) diff --git a/tensorflow_addons/layers/gelu_test.py b/tensorflow_addons/layers/gelu_test.py index 1c1a8c4071..99331fb44e 100644 --- a/tensorflow_addons/layers/gelu_test.py +++ b/tensorflow_addons/layers/gelu_test.py @@ -23,7 +23,6 @@ from absl.testing import parameterized from tensorflow_addons.layers.gelu import GeLU from tensorflow_addons.utils import test_utils -from tensorflow_addons.utils.test_utils import keras_parameterized @parameterized.parameters([np.float16, np.float32, np.float64]) @@ -36,21 +35,5 @@ def test_random(self, dtype): GeLU, kwargs={'dtype': dtype}, input_data=x, expected_output=val) -@keras_parameterized.run_all_keras_modes -@keras_parameterized.run_with_all_model_types -class TestGeLUModel(keras_parameterized.TestCase): - """Test GeLU with random keras model.""" - - def test_layer_random(self): - layer = tf.keras.layers.Dense(1, activation=GeLU()) - model = keras_parameterized.testing_utils.get_model_from_layers( - [layer], input_shape=(10,)) - model.compile( - 'sgd', - 'mse', - run_eagerly=keras_parameterized.testing_utils.should_run_eagerly()) - model.fit(np.ones((10, 10)), np.ones((10, 1)), batch_size=2) - - if __name__ == '__main__': tf.test.main()