Merged

Changes from all commits (45 commits)
9b5c7b8
Create noisy_dense.py
LeonShams Aug 17, 2020
baf3574
Create noisy_dense_test.py
LeonShams Aug 17, 2020
dfa4dcc
Update __init__.py
LeonShams Aug 17, 2020
d823bed
Fix minor typo
LeonShams Aug 17, 2020
08bb63e
Update noisy_dense_test.py
LeonShams Aug 17, 2020
68d6ccb
Update comments
LeonShams Aug 17, 2020
5a564c0
Update comments
LeonShams Aug 17, 2020
1785723
Update noisy_dense.py
LeonShams Aug 17, 2020
bdc7dfe
fix typo
LeonShams Aug 18, 2020
882bbde
Update noisy_dense.py
LeonShams Aug 18, 2020
5c09b50
Update noisy_dense_test.py
LeonShams Aug 18, 2020
5e16eef
Fix compliance issues
LeonShams Aug 23, 2020
4b14b8a
Fix compliance issues
LeonShams Aug 23, 2020
9e1f82f
Update comments
LeonShams Aug 23, 2020
57ebf5d
Fix typo
LeonShams Aug 23, 2020
009873b
Update CODEOWNERS
LeonShams Aug 27, 2020
840ab1c
Update CODEOWNERS
LeonShams Aug 27, 2020
fa54c00
add use bias to config
LeonShams Aug 29, 2020
d87069a
Update noisy_dense.py
LeonShams Aug 30, 2020
82e979f
Update CODEOWNERS
LeonShams Aug 30, 2020
48d3c57
Revert "Update CODEOWNERS"
LeonShams Aug 30, 2020
114842f
Update noisy_dense.py
LeonShams Aug 30, 2020
4e2ed8e
Update noisy_dense.py
LeonShams Aug 30, 2020
4fb4bff
Update noisy_dense.py
LeonShams Aug 30, 2020
b874f0b
Update noisy_dense.py
LeonShams Aug 30, 2020
7852e62
Revert "Update CODEOWNERS"
LeonShams Aug 30, 2020
133bb20
Revert "Revert "Update CODEOWNERS""
LeonShams Aug 30, 2020
f57f895
Update noisy_dense.py
LeonShams Aug 30, 2020
0a95587
Code reformatted with updated black
LeonShams Aug 30, 2020
8582796
Update noisy_dense.py
LeonShams Sep 2, 2020
0598762
Update noisy_dense.py
LeonShams Sep 2, 2020
442f7e3
Update noisy_dense.py
LeonShams Sep 5, 2020
3e0fcdb
Added support for manual noise reset
LeonShams Sep 5, 2020
ee96d5f
support for noise removal
LeonShams Sep 5, 2020
6c33f22
tests for noise removal
LeonShams Sep 5, 2020
1ac4699
use typecheck and remove unicode,
LeonShams Sep 11, 2020
14462f4
fix typo and code cleanup
LeonShams Sep 11, 2020
d4ad136
control noise removal through call
LeonShams Sep 11, 2020
da02bb1
Inherit from Dense instead of Layer
LeonShams Sep 12, 2020
5730a92
Added missing comment
LeonShams Sep 12, 2020
999c8b7
Documentation and test improvement
LeonShams Sep 13, 2020
5d73c90
fix typo
LeonShams Sep 13, 2020
4cdf577
minor formatting changes
LeonShams Sep 13, 2020
ce39ee0
Merge branch 'master' of https://github.com/LeonShams/addons
LeonShams Sep 13, 2020
3cd0217
minor formatting fix
LeonShams Sep 14, 2020
2 changes: 2 additions & 0 deletions .github/CODEOWNERS
@@ -105,6 +105,8 @@
/tensorflow_addons/layers/tests/esn_test.py @pedrolarben
/tensorflow_addons/layers/snake.py @failure-to-thrive
/tensorflow_addons/layers/tests/snake_test.py @failure-to-thrive
/tensorflow_addons/layers/noisy_dense.py @leonshams
/tensorflow_addons/layers/tests/noisy_dense_test.py @leonshams

/tensorflow_addons/losses/contrastive.py @windqaq
/tensorflow_addons/losses/tests/contrastive_test.py @windqaq
1 change: 1 addition & 0 deletions tensorflow_addons/layers/__init__.py
@@ -38,3 +38,4 @@
from tensorflow_addons.layers.tlu import TLU
from tensorflow_addons.layers.wrappers import WeightNormalization
from tensorflow_addons.layers.esn import ESN
from tensorflow_addons.layers.noisy_dense import NoisyDense
264 changes: 264 additions & 0 deletions tensorflow_addons/layers/noisy_dense.py
@@ -0,0 +1,264 @@
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

import tensorflow as tf
from tensorflow.keras import (
    activations,
    initializers,
    regularizers,
    constraints,
)
from tensorflow.keras import backend as K
from tensorflow.keras.layers import InputSpec
from typeguard import typechecked

from tensorflow_addons.utils import types


def _scale_noise(x):
    return tf.sign(x) * tf.sqrt(tf.abs(x))


@tf.keras.utils.register_keras_serializable(package="Addons")
class NoisyDense(tf.keras.layers.Dense):
r"""Noisy dense layer that injects random noise to the weights of dense layer.

Noisy dense layers are fully connected layers whose weights and biases are
augmented by factorised Gaussian noise. The factorised Gaussian noise is
controlled through gradient descent by a second weights layer.

A `NoisyDense` layer implements the operation:
$$
\mathrm{NoisyDense}(x) =
\mathrm{activation}(\mathrm{dot}(x, \mu + (\sigma \cdot \epsilon))
+ \mathrm{bias})
$$
where $\mu$ is the standard weights layer, $\epsilon$ is the factorised
Gaussian noise, and $\sigma$ is a second weights layer which controls
$\epsilon$.

Note: bias only added if `use_bias` is `True`.

Example:

>>> # Create a `Sequential` model and add a NoisyDense
>>> # layer as the first layer.
>>> model = tf.keras.models.Sequential()
>>> model.add(tf.keras.Input(shape=(16,)))
>>> model.add(NoisyDense(32, activation='relu'))
>>> # Now the model will take as input arrays of shape (None, 16)
>>> # and output arrays of shape (None, 32).
>>> # Note that after the first layer, you don't need to specify
>>> # the size of the input anymore:
>>> model.add(NoisyDense(32))
>>> model.output_shape
(None, 32)

Arguments:
units: Positive integer, dimensionality of the output space.
sigma: A float between 0-1 used as a standard deviation figure and is
applied to the gaussian noise layer (`sigma_kernel` and `sigma_bias`).
activation: Activation function to use.
If you don't specify anything, no activation is applied
(ie. "linear" activation: `a(x) = x`).
use_bias: Boolean, whether the layer uses a bias vector.
kernel_regularizer: Regularizer function applied to
the `kernel` weights matrix.
bias_regularizer: Regularizer function applied to the bias vector.
activity_regularizer: Regularizer function applied to
the output of the layer (its "activation").
kernel_constraint: Constraint function applied to
the `kernel` weights matrix.
bias_constraint: Constraint function applied to the bias vector.

Input shape:
N-D tensor with shape: `(batch_size, ..., input_dim)`.
The most common situation would be
a 2D input with shape `(batch_size, input_dim)`.

Output shape:
N-D tensor with shape: `(batch_size, ..., units)`.
For instance, for a 2D input with shape `(batch_size, input_dim)`,
the output would have shape `(batch_size, units)`.

References:
- [Noisy Networks for Explanation](https://arxiv.org/pdf/1706.10295.pdf)
"""

    @typechecked
    def __init__(
        self,
        units: int,
        sigma: float = 0.5,
        activation: types.Activation = None,
        use_bias: bool = True,
        kernel_regularizer: types.Regularizer = None,
        bias_regularizer: types.Regularizer = None,
        activity_regularizer: types.Regularizer = None,
        kernel_constraint: types.Constraint = None,
        bias_constraint: types.Constraint = None,
        **kwargs
    ):
        super().__init__(
            units=units,
            activation=activation,
            use_bias=use_bias,
            kernel_regularizer=kernel_regularizer,
            bias_regularizer=bias_regularizer,
            activity_regularizer=activity_regularizer,
            kernel_constraint=kernel_constraint,
            bias_constraint=bias_constraint,
            **kwargs,
        )
        delattr(self, "kernel_initializer")
        delattr(self, "bias_initializer")
        self.sigma = sigma

    def build(self, input_shape):
        # Make sure dtype is correct
        dtype = tf.dtypes.as_dtype(self.dtype or K.floatx())
        if not (dtype.is_floating or dtype.is_complex):
            raise TypeError(
                "Unable to build `Dense` layer with non-floating point "
                "dtype %s" % (dtype,)
            )

        input_shape = tf.TensorShape(input_shape)
        self.last_dim = tf.compat.dimension_value(input_shape[-1])
        if self.last_dim is None:
            raise ValueError(
                "The last dimension of the inputs to `Dense` "
                "should be defined. Found `None`."
            )
        sqrt_dim = self.last_dim ** (1 / 2)
        self.input_spec = InputSpec(min_ndim=2, axes={-1: self.last_dim})

        sigma_init = initializers.Constant(value=self.sigma / sqrt_dim)
        mu_init = initializers.RandomUniform(minval=-1 / sqrt_dim, maxval=1 / sqrt_dim)

        # Learnable parameters
        self.sigma_kernel = self.add_weight(
            "sigma_kernel",
            shape=[self.last_dim, self.units],
            initializer=sigma_init,
            regularizer=self.kernel_regularizer,
            constraint=self.kernel_constraint,
            dtype=self.dtype,
            trainable=True,
        )

        self.mu_kernel = self.add_weight(
            "mu_kernel",
            shape=[self.last_dim, self.units],
            initializer=mu_init,
            regularizer=self.kernel_regularizer,
            constraint=self.kernel_constraint,
            dtype=self.dtype,
            trainable=True,
        )

        if self.use_bias:
            self.sigma_bias = self.add_weight(
                "sigma_bias",
                shape=[
                    self.units,
                ],
                initializer=sigma_init,
                regularizer=self.bias_regularizer,
                constraint=self.bias_constraint,
                dtype=self.dtype,
                trainable=True,
            )

            self.mu_bias = self.add_weight(
                "mu_bias",
                shape=[
                    self.units,
                ],
                initializer=mu_init,
                regularizer=self.bias_regularizer,
                constraint=self.bias_constraint,
                dtype=self.dtype,
                trainable=True,
            )
        else:
            self.sigma_bias = None
            self.mu_bias = None
        self._reset_noise()
        self.built = True

    @property
    def kernel(self):
        return self.mu_kernel + (self.sigma_kernel * self.eps_kernel)

    @property
    def bias(self):
        if self.use_bias:
            return self.mu_bias + (self.sigma_bias * self.eps_bias)

    def _reset_noise(self):
        """Create the factorised Gaussian noise."""

        dtype = self._compute_dtype_object

        # Generate random noise
        eps_i = tf.random.normal([self.last_dim, self.units], dtype=dtype)
        eps_j = tf.random.normal(
            [
                self.units,
            ],
            dtype=dtype,
        )

        # Scale the random noise
        self.eps_kernel = _scale_noise(eps_i) * _scale_noise(eps_j)
        self.eps_bias = _scale_noise(eps_j)

    def _remove_noise(self):
        """Remove the factorised Gaussian noise."""

        dtype = self._compute_dtype_object
        self.eps_kernel = tf.zeros([self.last_dim, self.units], dtype=dtype)
        self.eps_bias = tf.zeros([self.units], dtype=dtype)

    def call(self, inputs, reset_noise=True, remove_noise=False):
        # Zero out or resample the factorised noise before running the
        # standard dense computation.
        if remove_noise:
            self._remove_noise()
        elif reset_noise:
            self._reset_noise()

        # TODO(WindQAQ): Replace this with `dense()` once public.
        return super().call(inputs)

    def get_config(self):
        # TODO(WindQAQ): Get rid of this hacky way.
        config = super(tf.keras.layers.Dense, self).get_config()
        config.update(
            {
                "units": self.units,
                "sigma": self.sigma,
                "activation": activations.serialize(self.activation),
                "use_bias": self.use_bias,
                "kernel_regularizer": regularizers.serialize(self.kernel_regularizer),
                "bias_regularizer": regularizers.serialize(self.bias_regularizer),
                "activity_regularizer": regularizers.serialize(
                    self.activity_regularizer
                ),
                "kernel_constraint": constraints.serialize(self.kernel_constraint),
                "bias_constraint": constraints.serialize(self.bias_constraint),
            }
        )
        return config
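
For reviewers who want to exercise the new layer, here is a minimal usage sketch. It assumes the branch is installed and the layer is importable as `tensorflow_addons.layers.NoisyDense` (per the `__init__.py` change above); the input shape, unit count, and `sigma` value are arbitrary.

import tensorflow as tf

from tensorflow_addons.layers import NoisyDense

# Arbitrary toy input: a batch of 8 vectors with 4 features each.
x = tf.random.normal([8, 4])
layer = NoisyDense(16, sigma=0.5, activation="relu")

# Default call: fresh factorised Gaussian noise is sampled for this forward pass.
noisy_out = layer(x)

# Reuse the noise sampled on the previous call instead of resampling it.
frozen_out = layer(x, reset_noise=False)

# Zero out the noise, e.g. for evaluation; the layer then reduces to a plain
# Dense computation using only the learned mu weights.
clean_out = layer(x, remove_noise=True)

The `reset_noise` and `remove_noise` call arguments map directly onto the `_reset_noise` and `_remove_noise` helpers defined in the diff above.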