diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS index 879f140855..aea78775cc 100644 --- a/.github/CODEOWNERS +++ b/.github/CODEOWNERS @@ -105,6 +105,8 @@ /tensorflow_addons/layers/tests/esn_test.py @pedrolarben /tensorflow_addons/layers/snake.py @failure-to-thrive /tensorflow_addons/layers/tests/snake_test.py @failure-to-thrive +/tensorflow_addons/layers/stochastic_depth.py @mhstadler @windqaq +/tensorflow_addons/layers/tests/stochastic_depth_test.py @mhstadler @windqaq /tensorflow_addons/losses/contrastive.py @windqaq /tensorflow_addons/losses/tests/contrastive_test.py @windqaq diff --git a/tensorflow_addons/layers/__init__.py b/tensorflow_addons/layers/__init__.py index 8545334690..7511d7858a 100644 --- a/tensorflow_addons/layers/__init__.py +++ b/tensorflow_addons/layers/__init__.py @@ -38,9 +38,5 @@ from tensorflow_addons.layers.tlu import TLU from tensorflow_addons.layers.wrappers import WeightNormalization from tensorflow_addons.layers.esn import ESN -<<<<<<< HEAD -======= from tensorflow_addons.layers.stochastic_depth import StochasticDepth -from tensorflow_addons.layers.noisy_dense import NoisyDense from tensorflow_addons.layers.crf import CRF ->>>>>>> 2cd311a... updating documentation for CRF (#2168) diff --git a/tensorflow_addons/layers/stochastic_depth.py b/tensorflow_addons/layers/stochastic_depth.py new file mode 100644 index 0000000000..3cf0df4f8c --- /dev/null +++ b/tensorflow_addons/layers/stochastic_depth.py @@ -0,0 +1,88 @@ +import tensorflow as tf +from typeguard import typechecked + + +@tf.keras.utils.register_keras_serializable(package="Addons") +class StochasticDepth(tf.keras.layers.Layer): + """Stochastic Depth layer. + + Implements Stochastic Depth as described in + [Deep Networks with Stochastic Depth](https://arxiv.org/abs/1603.09382), to randomly drop residual branches + in residual architectures. + + Usage: + Residual architectures with fixed depth, use residual branches that are merged back into the main network + by adding the residual branch back to the input: + + >>> input = np.ones((1, 3, 3, 1), dtype = np.float32) + >>> residual = tf.keras.layers.Conv2D(1, 1)(input) + >>> output = tf.keras.layers.Add()([input, residual]) + >>> output.shape + TensorShape([1, 3, 3, 1]) + + StochasticDepth acts as a drop-in replacement for the addition: + + >>> input = np.ones((1, 3, 3, 1), dtype = np.float32) + >>> residual = tf.keras.layers.Conv2D(1, 1)(input) + >>> output = tfa.layers.StochasticDepth()([input, residual]) + >>> output.shape + TensorShape([1, 3, 3, 1]) + + At train time, StochasticDepth returns: + + $$ + x[0] + b_l * x[1], + $$ + + where $b_l$ is a random Bernoulli variable with probability $P(b_l = 1) = p_l$ + + At test time, StochasticDepth rescales the activations of the residual branch based on the survival probability ($p_l$): + + $$ + x[0] + p_l * x[1] + $$ + + Arguments: + survival_probability: float, the probability of the residual branch being kept. + + Call Arguments: + inputs: List of `[shortcut, residual]` where `shortcut`, and `residual` are tensors of equal shape. + + Output shape: + Equal to the shape of inputs `shortcut`, and `residual` + """ + + @typechecked + def __init__(self, survival_probability: float = 0.5, **kwargs): + super().__init__(**kwargs) + + self.survival_probability = survival_probability + + def call(self, x, training=None): + if not isinstance(x, list) or len(x) != 2: + raise ValueError("input must be a list of length 2.") + + shortcut, residual = x + + # Random bernoulli variable indicating whether the branch should be kept or not or not + b_l = tf.keras.backend.random_bernoulli([], p=self.survival_probability) + + def _call_train(): + return shortcut + b_l * residual + + def _call_test(): + return shortcut + self.survival_probability * residual + + return tf.keras.backend.in_train_phase( + _call_train, _call_test, training=training + ) + + def compute_output_shape(self, input_shape): + return input_shape[0] + + def get_config(self): + base_config = super().get_config() + + config = {"survival_probability": self.survival_probability} + + return {**base_config, **config} diff --git a/tensorflow_addons/layers/tests/stochastic_depth_test.py b/tensorflow_addons/layers/tests/stochastic_depth_test.py new file mode 100644 index 0000000000..1122016f57 --- /dev/null +++ b/tensorflow_addons/layers/tests/stochastic_depth_test.py @@ -0,0 +1,58 @@ +import pytest +import numpy as np +import tensorflow as tf + +from tensorflow_addons.layers.stochastic_depth import StochasticDepth +from tensorflow_addons.utils import test_utils + +_KEEP_SEED = 1111 +_DROP_SEED = 2222 + + +@pytest.mark.parametrize("seed", [_KEEP_SEED, _DROP_SEED]) +@pytest.mark.parametrize("training", [True, False]) +def stochastic_depth_test(seed, training): + np.random.seed(seed) + tf.random.set_seed(seed) + + survival_probability = 0.5 + + shortcut = np.asarray([[0.2, 0.1, 0.4]]).astype(np.float32) + residual = np.asarray([[0.2, 0.4, 0.5]]).astype(np.float32) + + if training: + if seed == _KEEP_SEED: + # shortcut + residual + expected_output = np.asarray([[0.4, 0.5, 0.9]]).astype(np.float32) + elif seed == _DROP_SEED: + # shortcut + expected_output = np.asarray([[0.2, 0.1, 0.4]]).astype(np.float32) + else: + # shortcut + p_l * residual + expected_output = np.asarray([[0.3, 0.3, 0.65]]).astype(np.float32) + + test_utils.layer_test( + StochasticDepth, + kwargs={"survival_probability": survival_probability}, + input_data=[shortcut, residual], + expected_output=expected_output, + ) + + +@pytest.mark.usefixtures("run_with_mixed_precision_policy") +def test_with_mixed_precision_policy(): + policy = tf.keras.mixed_precision.experimental.global_policy() + + shortcut = np.asarray([[0.2, 0.1, 0.4]]) + residual = np.asarray([[0.2, 0.4, 0.5]]) + + output = StochasticDepth()([shortcut, residual]) + + assert output.dtype == policy.compute_dtype + + +def test_serialization(): + stoch_depth = StochasticDepth(survival_probability=0.5) + serialized_stoch_depth = tf.keras.layers.serialize(stoch_depth) + new_layer = tf.keras.layers.deserialize(serialized_stoch_depth) + assert stoch_depth.get_config() == new_layer.get_config()