-
Notifications
You must be signed in to change notification settings - Fork 617
GeLU activation as a layer #424
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from all commits
Commits
Show all changes
17 commits
Select commit
Hold shift + click to select a range
de64975
add gelu activation
AakashKumarNain 3e8ae83
add tests for gelu activation
AakashKumarNain 9041c57
add gelu to imports
AakashKumarNain 0c3fa0c
include gelu in build file
AakashKumarNain afd2c80
update tests and refactor
AakashKumarNain 05e11c4
refactor
AakashKumarNain 66c40ae
make compatible with every fp dtype and fulfill layer requirements
AakashKumarNain 4646a8f
add dummy model test
AakashKumarNain 32659c3
Merge branch 'master' of https://github.com/tensorflow/addons into la…
AakashKumarNain 821a17a
code format
AakashKumarNain c6f981d
code format and sanity check pass
AakashKumarNain d569f2b
code format
AakashKumarNain 6a13de9
auto code format
AakashKumarNain 8c89a62
Merge branch 'master' of https://github.com/tensorflow/addons into la…
AakashKumarNain 6cdba92
use fused gelu activation
AakashKumarNain 8b9ee5f
Merge branch 'master' of https://github.com/tensorflow/addons into la…
AakashKumarNain db52c48
remove redundant test cases
AakashKumarNain File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,57 @@ | ||
| # Copyright 2019 The TensorFlow Authors. All Rights Reserved. | ||
| # | ||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||
| # you may not use this file except in compliance with the License. | ||
| # You may obtain a copy of the License at | ||
| # | ||
| # http://www.apache.org/licenses/LICENSE-2.0 | ||
| # | ||
| # Unless required by applicable law or agreed to in writing, software | ||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| # See the License for the specific language governing permissions and | ||
| # limitations under the License. | ||
| # ============================================================================== | ||
| """Implements GeLU activation.""" | ||
|
|
||
| from __future__ import absolute_import | ||
| from __future__ import division | ||
| from __future__ import print_function | ||
|
|
||
| import tensorflow as tf | ||
| from tensorflow_addons.utils import keras_utils | ||
| from tensorflow_addons.activations import gelu | ||
|
|
||
|
|
||
| @keras_utils.register_keras_custom_object | ||
| class GeLU(tf.keras.layers.Layer): | ||
| """Gaussian Error Linear Unit. | ||
|
|
||
| A smoother version of ReLU generally used | ||
| in the BERT or BERT architecture based models. | ||
| Original paper: https://arxiv.org/abs/1606.08415 | ||
|
|
||
| Input shape: | ||
| Arbitrary. Use the keyword argument `input_shape` | ||
| (tuple of integers, does not include the samples axis) | ||
| when using this layer as the first layer in a model. | ||
|
|
||
| Output shape: | ||
| Same shape as the input. | ||
| """ | ||
|
|
||
| def __init__(self, approximate=True, **kwargs): | ||
| super(GeLU, self).__init__(**kwargs) | ||
| self.approximate = approximate | ||
| self.supports_masking = True | ||
|
|
||
| def call(self, inputs): | ||
| return gelu(inputs, approximate=self.approximate) | ||
|
|
||
| def get_config(self): | ||
| config = {'approximate': self.approximate} | ||
| base_config = super(GeLU, self).get_config() | ||
| return dict(list(base_config.items()) + list(config.items())) | ||
|
|
||
| def compute_output_shape(self, input_shape): | ||
| return input_shape | ||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,39 @@ | ||
| # Copyright 2019 The TensorFlow Authors. All Rights Reserved. | ||
| # | ||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||
| # you may not use this file except in compliance with the License. | ||
| # You may obtain a copy of the License at | ||
| # | ||
| # http://www.apache.org/licenses/LICENSE-2.0 | ||
| # | ||
| # Unless required by applicable law or agreed to in writing, software | ||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| # See the License for the specific language governing permissions and | ||
| # limitations under the License. | ||
| # ============================================================================== | ||
| """Tests for GeLU activation.""" | ||
|
|
||
| from __future__ import absolute_import | ||
| from __future__ import division | ||
| from __future__ import print_function | ||
|
|
||
| import numpy as np | ||
| import tensorflow as tf | ||
| from absl.testing import parameterized | ||
| from tensorflow_addons.layers.gelu import GeLU | ||
| from tensorflow_addons.utils import test_utils | ||
|
|
||
|
|
||
| @parameterized.parameters([np.float16, np.float32, np.float64]) | ||
| @test_utils.run_all_in_graph_and_eager_modes | ||
| class TestGeLU(tf.test.TestCase): | ||
| def test_random(self, dtype): | ||
| x = np.array([[0.5, 1.2, -0.3]]).astype(dtype) | ||
| val = np.array([[0.345714, 1.0617027, -0.11462909]]).astype(dtype) | ||
| test_utils.layer_test( | ||
| GeLU, kwargs={'dtype': dtype}, input_data=x, expected_output=val) | ||
|
|
||
|
|
||
| if __name__ == '__main__': | ||
| tf.test.main() |
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.