-
Notifications
You must be signed in to change notification settings - Fork 617
Added support for noisy dense layers. #2099
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from all commits
Commits
Show all changes
45 commits
Select commit
Hold shift + click to select a range
9b5c7b8
Create noisy_dense.py
LeonShams baf3574
Create noisy_dense_test.py
LeonShams dfa4dcc
Update __init__.py
LeonShams d823bed
Fix minor typo
LeonShams 08bb63e
Update noisy_dense_test.py
LeonShams 68d6ccb
Update comments
LeonShams 5a564c0
Update comments
LeonShams 1785723
Update noisy_dense.py
LeonShams bdc7dfe
fix typo
LeonShams 882bbde
Update noisy_dense.py
LeonShams 5c09b50
Update noisy_dense_test.py
LeonShams 5e16eef
Fix compliance issues
LeonShams 4b14b8a
Fix compliance issues
LeonShams 9e1f82f
Update comments
LeonShams 57ebf5d
Fix typo
LeonShams 009873b
Update CODEOWNERS
LeonShams 840ab1c
Update CODEOWNERS
LeonShams fa54c00
add use bias to config
LeonShams d87069a
Update noisy_dense.py
LeonShams 82e979f
Update CODEOWNERS
LeonShams 48d3c57
Revert "Update CODEOWNERS"
LeonShams 114842f
Update noisy_dense.py
LeonShams 4e2ed8e
Update noisy_dense.py
LeonShams 4fb4bff
Update noisy_dense.py
LeonShams b874f0b
Update noisy_dense.py
LeonShams 7852e62
Revert "Update CODEOWNERS"
LeonShams 133bb20
Revert "Revert "Update CODEOWNERS""
LeonShams f57f895
Update noisy_dense.py
LeonShams 0a95587
Code reformatted with updated black
LeonShams 8582796
Update noisy_dense.py
LeonShams 0598762
Update noisy_dense.py
LeonShams 442f7e3
Update noisy_dense.py
LeonShams 3e0fcdb
Added support for manual noise reset
LeonShams ee96d5f
support for noise removal
LeonShams 6c33f22
tests for noise removal
LeonShams 1ac4699
use typecheck and remove unicode,
LeonShams 14462f4
fix typo and code cleanup
LeonShams d4ad136
control noise removal through call
LeonShams da02bb1
Inherit from Dense instead of Layer
LeonShams 5730a92
Added missing comment
LeonShams 999c8b7
Documentation and test improvement
LeonShams 5d73c90
fix typo
LeonShams 4cdf577
minor formatting changes
LeonShams ce39ee0
Merge branch 'master' of https://github.com/LeonShams/addons
LeonShams 3cd0217
minor formatting fix
LeonShams File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,264 @@ | ||
| # Copyright 2020 The TensorFlow Authors. All Rights Reserved. | ||
| # | ||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||
| # you may not use this file except in compliance with the License. | ||
| # You may obtain a copy of the License at | ||
| # | ||
| # http://www.apache.org/licenses/LICENSE-2.0 | ||
| # | ||
| # Unless required by applicable law or agreed to in writing, software | ||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| # See the License for the specific language governing permissions and | ||
| # limitations under the License. | ||
| # ============================================================================== | ||
|
|
||
| import tensorflow as tf | ||
| from tensorflow.keras import ( | ||
| activations, | ||
| initializers, | ||
| regularizers, | ||
| constraints, | ||
| ) | ||
| from tensorflow.keras import backend as K | ||
| from tensorflow.keras.layers import InputSpec | ||
| from typeguard import typechecked | ||
|
|
||
| from tensorflow_addons.utils import types | ||
|
|
||
|
|
||
| def _scale_noise(x): | ||
| return tf.sign(x) * tf.sqrt(tf.abs(x)) | ||
|
|
||
|
|
||
| @tf.keras.utils.register_keras_serializable(package="Addons") | ||
| class NoisyDense(tf.keras.layers.Dense): | ||
| r"""Noisy dense layer that injects random noise to the weights of dense layer. | ||
|
|
||
| Noisy dense layers are fully connected layers whose weights and biases are | ||
| augmented by factorised Gaussian noise. The factorised Gaussian noise is | ||
| controlled through gradient descent by a second weights layer. | ||
|
|
||
| A `NoisyDense` layer implements the operation: | ||
| $$ | ||
| \mathrm{NoisyDense}(x) = | ||
| \mathrm{activation}(\mathrm{dot}(x, \mu + (\sigma \cdot \epsilon)) | ||
| + \mathrm{bias}) | ||
| $$ | ||
| where $\mu$ is the standard weights layer, $\epsilon$ is the factorised | ||
| Gaussian noise, and $\sigma$ is a second weights layer which controls | ||
| $\epsilon$. | ||
|
|
||
| Note: bias only added if `use_bias` is `True`. | ||
|
|
||
| Example: | ||
|
|
||
| >>> # Create a `Sequential` model and add a NoisyDense | ||
| >>> # layer as the first layer. | ||
| >>> model = tf.keras.models.Sequential() | ||
| >>> model.add(tf.keras.Input(shape=(16,))) | ||
| >>> model.add(NoisyDense(32, activation='relu')) | ||
| >>> # Now the model will take as input arrays of shape (None, 16) | ||
| >>> # and output arrays of shape (None, 32). | ||
| >>> # Note that after the first layer, you don't need to specify | ||
| >>> # the size of the input anymore: | ||
| >>> model.add(NoisyDense(32)) | ||
| >>> model.output_shape | ||
| (None, 32) | ||
|
|
||
| Arguments: | ||
| units: Positive integer, dimensionality of the output space. | ||
| sigma: A float between 0-1 used as a standard deviation figure and is | ||
| applied to the gaussian noise layer (`sigma_kernel` and `sigma_bias`). | ||
| activation: Activation function to use. | ||
| If you don't specify anything, no activation is applied | ||
| (ie. "linear" activation: `a(x) = x`). | ||
| use_bias: Boolean, whether the layer uses a bias vector. | ||
| kernel_regularizer: Regularizer function applied to | ||
| the `kernel` weights matrix. | ||
| bias_regularizer: Regularizer function applied to the bias vector. | ||
| activity_regularizer: Regularizer function applied to | ||
| the output of the layer (its "activation"). | ||
| kernel_constraint: Constraint function applied to | ||
| the `kernel` weights matrix. | ||
| bias_constraint: Constraint function applied to the bias vector. | ||
|
|
||
| Input shape: | ||
| N-D tensor with shape: `(batch_size, ..., input_dim)`. | ||
| The most common situation would be | ||
| a 2D input with shape `(batch_size, input_dim)`. | ||
|
|
||
| Output shape: | ||
| N-D tensor with shape: `(batch_size, ..., units)`. | ||
| For instance, for a 2D input with shape `(batch_size, input_dim)`, | ||
| the output would have shape `(batch_size, units)`. | ||
|
|
||
| References: | ||
| - [Noisy Networks for Explanation](https://arxiv.org/pdf/1706.10295.pdf) | ||
| """ | ||
|
|
||
| @typechecked | ||
| def __init__( | ||
| self, | ||
| units: int, | ||
| sigma: float = 0.5, | ||
| activation: types.Activation = None, | ||
| use_bias: bool = True, | ||
| kernel_regularizer: types.Regularizer = None, | ||
| bias_regularizer: types.Regularizer = None, | ||
| activity_regularizer: types.Regularizer = None, | ||
| kernel_constraint: types.Constraint = None, | ||
| bias_constraint: types.Constraint = None, | ||
| **kwargs | ||
| ): | ||
| super().__init__( | ||
| units=units, | ||
| activation=activation, | ||
| use_bias=use_bias, | ||
| kernel_regularizer=kernel_regularizer, | ||
| bias_regularizer=bias_regularizer, | ||
| activity_regularizer=activity_regularizer, | ||
| kernel_constraint=kernel_constraint, | ||
| bias_constraint=bias_constraint, | ||
| **kwargs, | ||
| ) | ||
| delattr(self, "kernel_initializer") | ||
| delattr(self, "bias_initializer") | ||
| self.sigma = sigma | ||
|
|
||
| def build(self, input_shape): | ||
| # Make sure dtype is correct | ||
| dtype = tf.dtypes.as_dtype(self.dtype or K.floatx()) | ||
| if not (dtype.is_floating or dtype.is_complex): | ||
| raise TypeError( | ||
| "Unable to build `Dense` layer with non-floating point " | ||
| "dtype %s" % (dtype,) | ||
| ) | ||
|
|
||
| input_shape = tf.TensorShape(input_shape) | ||
| self.last_dim = tf.compat.dimension_value(input_shape[-1]) | ||
| sqrt_dim = self.last_dim ** (1 / 2) | ||
bhack marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| if self.last_dim is None: | ||
| raise ValueError( | ||
| "The last dimension of the inputs to `Dense` " | ||
| "should be defined. Found `None`." | ||
| ) | ||
| self.input_spec = InputSpec(min_ndim=2, axes={-1: self.last_dim}) | ||
|
|
||
| sigma_init = initializers.Constant(value=self.sigma / sqrt_dim) | ||
| mu_init = initializers.RandomUniform(minval=-1 / sqrt_dim, maxval=1 / sqrt_dim) | ||
|
|
||
| # Learnable parameters | ||
| self.sigma_kernel = self.add_weight( | ||
| "sigma_kernel", | ||
| shape=[self.last_dim, self.units], | ||
| initializer=sigma_init, | ||
| regularizer=self.kernel_regularizer, | ||
| constraint=self.kernel_constraint, | ||
| dtype=self.dtype, | ||
| trainable=True, | ||
| ) | ||
|
|
||
| self.mu_kernel = self.add_weight( | ||
| "mu_kernel", | ||
| shape=[self.last_dim, self.units], | ||
| initializer=mu_init, | ||
| regularizer=self.kernel_regularizer, | ||
| constraint=self.kernel_constraint, | ||
| dtype=self.dtype, | ||
| trainable=True, | ||
| ) | ||
|
|
||
| if self.use_bias: | ||
| self.sigma_bias = self.add_weight( | ||
| "sigma_bias", | ||
| shape=[ | ||
| self.units, | ||
| ], | ||
| initializer=sigma_init, | ||
| regularizer=self.bias_regularizer, | ||
| constraint=self.bias_constraint, | ||
| dtype=self.dtype, | ||
| trainable=True, | ||
| ) | ||
|
|
||
| self.mu_bias = self.add_weight( | ||
| "mu_bias", | ||
| shape=[ | ||
| self.units, | ||
| ], | ||
| initializer=mu_init, | ||
| regularizer=self.bias_regularizer, | ||
| constraint=self.bias_constraint, | ||
| dtype=self.dtype, | ||
| trainable=True, | ||
| ) | ||
| else: | ||
| self.sigma_bias = None | ||
| self.mu_bias = None | ||
| self._reset_noise() | ||
| self.built = True | ||
|
|
||
| @property | ||
| def kernel(self): | ||
| return self.mu_kernel + (self.sigma_kernel * self.eps_kernel) | ||
|
|
||
| @property | ||
| def bias(self): | ||
| if self.use_bias: | ||
| return self.mu_bias + (self.sigma_bias * self.eps_bias) | ||
|
|
||
| def _reset_noise(self): | ||
LeonShams marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| """Create the factorised Gaussian noise.""" | ||
|
|
||
| dtype = self._compute_dtype_object | ||
|
|
||
| # Generate random noise | ||
| eps_i = tf.random.normal([self.last_dim, self.units], dtype=dtype) | ||
| eps_j = tf.random.normal( | ||
| [ | ||
| self.units, | ||
| ], | ||
| dtype=dtype, | ||
| ) | ||
|
|
||
| # Scale the random noise | ||
| self.eps_kernel = _scale_noise(eps_i) * _scale_noise(eps_j) | ||
| self.eps_bias = _scale_noise(eps_j) | ||
|
|
||
| def _remove_noise(self): | ||
| """Remove the factorised Gaussian noise.""" | ||
|
|
||
| dtype = self._compute_dtype_object | ||
| self.eps_kernel = tf.zeros([self.last_dim, self.units], dtype=dtype) | ||
| self.eps_bias = tf.zeros([self.units], dtype=dtype) | ||
|
|
||
| def call(self, inputs, reset_noise=True, remove_noise=False): | ||
LeonShams marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| # Generate fixed parameters added as the noise | ||
| if remove_noise: | ||
| self._remove_noise() | ||
| elif reset_noise: | ||
| self._reset_noise() | ||
|
|
||
| # TODO(WindQAQ): Replace this with `dense()` once public. | ||
| return super().call(inputs) | ||
|
|
||
| def get_config(self): | ||
| # TODO(WindQAQ): Get rid of this hacky way. | ||
| config = super(tf.keras.layers.Dense, self).get_config() | ||
| config.update( | ||
bhack marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| { | ||
| "units": self.units, | ||
| "sigma": self.sigma, | ||
| "activation": activations.serialize(self.activation), | ||
| "use_bias": self.use_bias, | ||
| "kernel_regularizer": regularizers.serialize(self.kernel_regularizer), | ||
| "bias_regularizer": regularizers.serialize(self.bias_regularizer), | ||
| "activity_regularizer": regularizers.serialize( | ||
| self.activity_regularizer | ||
| ), | ||
| "kernel_constraint": constraints.serialize(self.kernel_constraint), | ||
| "bias_constraint": constraints.serialize(self.bias_constraint), | ||
| } | ||
| ) | ||
| return config | ||
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.