From bb828f6bf8375adb056eb850ecae5feaef5ee341 Mon Sep 17 00:00:00 2001 From: Michael Stadler Date: Tue, 8 Sep 2020 12:12:05 +0100 Subject: [PATCH 01/12] Added stochastic depth layer --- tensorflow_addons/layers/stochastic_depth.py | 86 +++++++++++++++++++ .../layers/tests/stochastic_depth_test.py | 43 ++++++++++ 2 files changed, 129 insertions(+) create mode 100644 tensorflow_addons/layers/stochastic_depth.py create mode 100644 tensorflow_addons/layers/tests/stochastic_depth_test.py diff --git a/tensorflow_addons/layers/stochastic_depth.py b/tensorflow_addons/layers/stochastic_depth.py new file mode 100644 index 0000000000..db3040bb65 --- /dev/null +++ b/tensorflow_addons/layers/stochastic_depth.py @@ -0,0 +1,86 @@ +import tensorflow as tf + +@tf.keras.utils.register_keras_serializable(package="Addons") +class StochasticDepth(tf.keras.layers.Layer): + r"""Stochastic Depth layer. + + Implements Stochastic Depth as described in + [Deep Networks with Stochastic Depth](https://arxiv.org/abs/1603.09382), to randomly drop residual branches + in residual architectures. + + Usage: + Residual architectures with fixed depth, use residual branches that are merged back into the main network + by adding the residual branch back to the input: + + ```python + residual = tf.keras.layers.Conv2D(...)(input) + + return tf.keras.layers.Add()([input, residual]) + ``` + + StochasticDepth acts as a drop-in replacement for the addition: + + ```python + residual = tf.keras.layers.Conv2D(...)(input) + + return tfa.layers.StochasticDepth()([input, residual]) + ``` + + At train time, StochasticDepth returns: + + ```python + x[0] + b_l * x[1] + ``` + + , where b_l is a random Bernoulli variable with probability p(b_l == 1) == p_l + + At test time, StochasticDepth rescales the activations of the residual branch based on the survival probability: + + ```python + x[0] + p_l * x[1] + ``` + + Arguments: + p_l: float, the probability of the residual branch being kept. + + Call Arguments: + inputs: List of `[shortcut, residual]` where + * `shortcut`, and `residual` are tensors of equal shape. + + Output shape: + Equal to the shape of inputs `shortcut`, and `residual` + """ + + @typechecked + def __init__(self, p_l: float = 0.5, **kwargs): + super().__init__(**kwargs) + + self.p_l = p_l + + def call(self, x, training = None): + assert isinstance(x, list) + + shortcut, residual = x + + # Random bernoulli variable with probability p_l, indiciathing wheter the branch should be kept or not or not + b_l = tf.keras.backend.random_binomial([], p = self.p_l) + + def _call_train(): + return shortcut + b_l * residual + + def _call_test(): + return shortcut + self.p_l * residual + + return tf.keras.backend.in_train_phase(_call_train, _call_test, training = training) + + def compute_output_shape(self, input_shape): + assert isinstance(input_shape, list) + + return input_shape[0] + + def get_config(self): + base_config = super().get_config() + + config = {"p_l": self.p_l} + + return {**base_config, **config} \ No newline at end of file diff --git a/tensorflow_addons/layers/tests/stochastic_depth_test.py b/tensorflow_addons/layers/tests/stochastic_depth_test.py new file mode 100644 index 0000000000..c0347b77ff --- /dev/null +++ b/tensorflow_addons/layers/tests/stochastic_depth_test.py @@ -0,0 +1,43 @@ +import pytest + +import numpy as np +import tensorflow as tf +from tensorflow_addons.layers.stochastic_depth import StochasticDepth +from tensorflow_addons.utils import test_utils + +KEEP_SEED = 1111 +DROP_SEED = 2222 + +@pytest.mark.parametrize("seed", [KEEP_SEED, DROP_SEED]) +@pytest.mark.parametrize("training", [True, False]) +def stochastic_depth_test(seed, training): + np.random.seed(seed) + tf.random.set_seed(seed) + + p_l = 0.5 + + shortcut = np.asarray([[0.2, 0.1, 0.4]]).astype(np.float32) + residual = np.asarray([[0.2, 0.4, 0.5]]).astype(np.float32) + + if training: + if seed == KEEP_SEED: + # shortcut + residual + expected_output = np.asarray([[0.4, 0.5, 0.9]]).astype(np.float32) + elif seed == DROP_SEED: + # shortcut + expected_output = np.asarray([[0.2, 0.1, 0.4]]).astype(np.float32) + else: + # shortcut + p_l * residual + expected_output = np.asarray([[0.3, 0.3, 0.65]]).astype(np.float32) + + test_utils.layer_test( + StochasticDepth, kwargs = {"p_l": p_l}, input_data = [shortcut, residual], expected_output = expected_output + ) + +def test_serialization(): + stoch_depth = StochasticDepth( + p_l = 0.5 + ) + serialized_stoch_depth = tf.keras.layers.serialize(stoch_depth) + new_layer = tf.keras.layers.deserialize(serialized_stoch_depth) + assert serialized_stoch_depth.get_config() == new_layer.get_config() \ No newline at end of file From fb94ee824ffc7c44580bb5b96e911e40883d4ed5 Mon Sep 17 00:00:00 2001 From: Michael Date: Tue, 8 Sep 2020 12:53:35 +0100 Subject: [PATCH 02/12] Fixed code style and added missing __init__ entry --- tensorflow_addons/layers/__init__.py | 1 + tensorflow_addons/layers/stochastic_depth.py | 29 ++++++++++--------- .../layers/tests/stochastic_depth_test.py | 13 +++++---- 3 files changed, 25 insertions(+), 18 deletions(-) diff --git a/tensorflow_addons/layers/__init__.py b/tensorflow_addons/layers/__init__.py index 594d025a9f..4d43976474 100644 --- a/tensorflow_addons/layers/__init__.py +++ b/tensorflow_addons/layers/__init__.py @@ -38,3 +38,4 @@ from tensorflow_addons.layers.tlu import TLU from tensorflow_addons.layers.wrappers import WeightNormalization from tensorflow_addons.layers.esn import ESN +from tensorflow_addons.layers.stochastic_depth import StochasticDepth diff --git a/tensorflow_addons/layers/stochastic_depth.py b/tensorflow_addons/layers/stochastic_depth.py index db3040bb65..53db1c12d0 100644 --- a/tensorflow_addons/layers/stochastic_depth.py +++ b/tensorflow_addons/layers/stochastic_depth.py @@ -1,5 +1,6 @@ import tensorflow as tf + @tf.keras.utils.register_keras_serializable(package="Addons") class StochasticDepth(tf.keras.layers.Layer): r"""Stochastic Depth layer. @@ -9,8 +10,8 @@ class StochasticDepth(tf.keras.layers.Layer): in residual architectures. Usage: - Residual architectures with fixed depth, use residual branches that are merged back into the main network - by adding the residual branch back to the input: + Residual architectures with fixed depth, use residual branches that are merged back into the main network + by adding the residual branch back to the input: ```python residual = tf.keras.layers.Conv2D(...)(input) @@ -31,7 +32,7 @@ class StochasticDepth(tf.keras.layers.Layer): ```python x[0] + b_l * x[1] ``` - + , where b_l is a random Bernoulli variable with probability p(b_l == 1) == p_l At test time, StochasticDepth rescales the activations of the residual branch based on the survival probability: @@ -39,7 +40,7 @@ class StochasticDepth(tf.keras.layers.Layer): ```python x[0] + p_l * x[1] ``` - + Arguments: p_l: float, the probability of the residual branch being kept. @@ -57,21 +58,23 @@ def __init__(self, p_l: float = 0.5, **kwargs): self.p_l = p_l - def call(self, x, training = None): + def call(self, x, training=None): assert isinstance(x, list) - + shortcut, residual = x # Random bernoulli variable with probability p_l, indiciathing wheter the branch should be kept or not or not - b_l = tf.keras.backend.random_binomial([], p = self.p_l) + b_l = tf.keras.backend.random_binomial([], p=self.p_l) def _call_train(): return shortcut + b_l * residual - + def _call_test(): return shortcut + self.p_l * residual - - return tf.keras.backend.in_train_phase(_call_train, _call_test, training = training) + + return tf.keras.backend.in_train_phase( + _call_train, _call_test, training=training + ) def compute_output_shape(self, input_shape): assert isinstance(input_shape, list) @@ -80,7 +83,7 @@ def compute_output_shape(self, input_shape): def get_config(self): base_config = super().get_config() - + config = {"p_l": self.p_l} - - return {**base_config, **config} \ No newline at end of file + + return {**base_config, **config} diff --git a/tensorflow_addons/layers/tests/stochastic_depth_test.py b/tensorflow_addons/layers/tests/stochastic_depth_test.py index c0347b77ff..3095410a16 100644 --- a/tensorflow_addons/layers/tests/stochastic_depth_test.py +++ b/tensorflow_addons/layers/tests/stochastic_depth_test.py @@ -8,6 +8,7 @@ KEEP_SEED = 1111 DROP_SEED = 2222 + @pytest.mark.parametrize("seed", [KEEP_SEED, DROP_SEED]) @pytest.mark.parametrize("training", [True, False]) def stochastic_depth_test(seed, training): @@ -31,13 +32,15 @@ def stochastic_depth_test(seed, training): expected_output = np.asarray([[0.3, 0.3, 0.65]]).astype(np.float32) test_utils.layer_test( - StochasticDepth, kwargs = {"p_l": p_l}, input_data = [shortcut, residual], expected_output = expected_output + StochasticDepth, + kwargs={"p_l": p_l}, + input_data=[shortcut, residual], + expected_output=expected_output, ) + def test_serialization(): - stoch_depth = StochasticDepth( - p_l = 0.5 - ) + stoch_depth = StochasticDepth(p_l=0.5) serialized_stoch_depth = tf.keras.layers.serialize(stoch_depth) new_layer = tf.keras.layers.deserialize(serialized_stoch_depth) - assert serialized_stoch_depth.get_config() == new_layer.get_config() \ No newline at end of file + assert serialized_stoch_depth.get_config() == new_layer.get_config() From 0a20a5e993facf74726babb6d65f39d32b46e708 Mon Sep 17 00:00:00 2001 From: Michael Date: Tue, 8 Sep 2020 13:53:08 +0100 Subject: [PATCH 03/12] Fixed tests and style --- tensorflow_addons/layers/stochastic_depth.py | 2 +- tensorflow_addons/layers/tests/stochastic_depth_test.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/tensorflow_addons/layers/stochastic_depth.py b/tensorflow_addons/layers/stochastic_depth.py index 53db1c12d0..8f60b2c070 100644 --- a/tensorflow_addons/layers/stochastic_depth.py +++ b/tensorflow_addons/layers/stochastic_depth.py @@ -1,5 +1,5 @@ import tensorflow as tf - +from typeguard import typechecked @tf.keras.utils.register_keras_serializable(package="Addons") class StochasticDepth(tf.keras.layers.Layer): diff --git a/tensorflow_addons/layers/tests/stochastic_depth_test.py b/tensorflow_addons/layers/tests/stochastic_depth_test.py index 3095410a16..d0bb3582fe 100644 --- a/tensorflow_addons/layers/tests/stochastic_depth_test.py +++ b/tensorflow_addons/layers/tests/stochastic_depth_test.py @@ -43,4 +43,4 @@ def test_serialization(): stoch_depth = StochasticDepth(p_l=0.5) serialized_stoch_depth = tf.keras.layers.serialize(stoch_depth) new_layer = tf.keras.layers.deserialize(serialized_stoch_depth) - assert serialized_stoch_depth.get_config() == new_layer.get_config() + assert stoch_depth.get_config() == new_layer.get_config() From 4449b09310f7c70fb5069a1678faca8a38c83baf Mon Sep 17 00:00:00 2001 From: Michael Date: Tue, 8 Sep 2020 19:55:48 +0100 Subject: [PATCH 04/12] Fixed code style --- tensorflow_addons/layers/stochastic_depth.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tensorflow_addons/layers/stochastic_depth.py b/tensorflow_addons/layers/stochastic_depth.py index 8f60b2c070..b0f571ddc5 100644 --- a/tensorflow_addons/layers/stochastic_depth.py +++ b/tensorflow_addons/layers/stochastic_depth.py @@ -1,6 +1,7 @@ import tensorflow as tf from typeguard import typechecked + @tf.keras.utils.register_keras_serializable(package="Addons") class StochasticDepth(tf.keras.layers.Layer): r"""Stochastic Depth layer. From 6e7999aa6db423c2433207f7d9a193b2227ed1a1 Mon Sep 17 00:00:00 2001 From: Michael Date: Tue, 8 Sep 2020 20:02:07 +0100 Subject: [PATCH 05/12] Updated CODEOWNERS --- .github/CODEOWNERS | 1 + 1 file changed, 1 insertion(+) diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS index 879f140855..50d9086b3a 100644 --- a/.github/CODEOWNERS +++ b/.github/CODEOWNERS @@ -105,6 +105,7 @@ /tensorflow_addons/layers/tests/esn_test.py @pedrolarben /tensorflow_addons/layers/snake.py @failure-to-thrive /tensorflow_addons/layers/tests/snake_test.py @failure-to-thrive +/tensorflow_addons/layers/stochastic_depth.py @mhstadler @windqaq /tensorflow_addons/losses/contrastive.py @windqaq /tensorflow_addons/losses/tests/contrastive_test.py @windqaq From 5c77ff4c0bd4b036bff20a5265127ec31cab903b Mon Sep 17 00:00:00 2001 From: Michael Date: Tue, 8 Sep 2020 20:29:14 +0100 Subject: [PATCH 06/12] Added codeowners for tests --- .github/CODEOWNERS | 1 + 1 file changed, 1 insertion(+) diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS index 50d9086b3a..aea78775cc 100644 --- a/.github/CODEOWNERS +++ b/.github/CODEOWNERS @@ -106,6 +106,7 @@ /tensorflow_addons/layers/snake.py @failure-to-thrive /tensorflow_addons/layers/tests/snake_test.py @failure-to-thrive /tensorflow_addons/layers/stochastic_depth.py @mhstadler @windqaq +/tensorflow_addons/layers/tests/stochastic_depth_test.py @mhstadler @windqaq /tensorflow_addons/losses/contrastive.py @windqaq /tensorflow_addons/losses/tests/contrastive_test.py @windqaq From 6321ce4e279efdf38b02c65ce122aed4ec6c8463 Mon Sep 17 00:00:00 2001 From: Michael Stadler Date: Wed, 9 Sep 2020 20:28:27 +0100 Subject: [PATCH 07/12] Changes after code review --- tensorflow_addons/layers/stochastic_depth.py | 52 ++++++++++--------- .../layers/tests/stochastic_depth_test.py | 16 +++--- 2 files changed, 35 insertions(+), 33 deletions(-) diff --git a/tensorflow_addons/layers/stochastic_depth.py b/tensorflow_addons/layers/stochastic_depth.py index b0f571ddc5..74a5b1e201 100644 --- a/tensorflow_addons/layers/stochastic_depth.py +++ b/tensorflow_addons/layers/stochastic_depth.py @@ -14,36 +14,32 @@ class StochasticDepth(tf.keras.layers.Layer): Residual architectures with fixed depth, use residual branches that are merged back into the main network by adding the residual branch back to the input: - ```python - residual = tf.keras.layers.Conv2D(...)(input) - - return tf.keras.layers.Add()([input, residual]) - ``` + >>> input = np.ones((3, 3, 1)) + >>> residual = tf.keras.layers.Conv2D(1, 1)(input) + >>> tfa.layers.StochasticDepth()([input, residual]) StochasticDepth acts as a drop-in replacement for the addition: - ```python - residual = tf.keras.layers.Conv2D(...)(input) - - return tfa.layers.StochasticDepth()([input, residual]) - ``` + >>> input = np.ones((3, 3, 1)) + >>> residual = tf.keras.layers.Conv2D(1, 1)(input) + >>> tfa.layers.StochasticDepth()([input, residual]) At train time, StochasticDepth returns: - ```python + $$ x[0] + b_l * x[1] - ``` + $$ - , where b_l is a random Bernoulli variable with probability p(b_l == 1) == p_l + , where $b_l$ is a random Bernoulli variable with probability $p(b_l == 1) == p_l$ - At test time, StochasticDepth rescales the activations of the residual branch based on the survival probability: + At test time, StochasticDepth rescales the activations of the residual branch based on the survival probability ($p_l$): - ```python + $$ x[0] + p_l * x[1] - ``` + $$ Arguments: - p_l: float, the probability of the residual branch being kept. + survival_probability: float, the probability of the residual branch being kept. Call Arguments: inputs: List of `[shortcut, residual]` where @@ -54,37 +50,43 @@ class StochasticDepth(tf.keras.layers.Layer): """ @typechecked - def __init__(self, p_l: float = 0.5, **kwargs): + def __init__(self, survival_probability: float = 0.5, **kwargs): super().__init__(**kwargs) - self.p_l = p_l + self.survival_probability = survival_probability def call(self, x, training=None): - assert isinstance(x, list) + assert isinstance(x, list): + raise ValueError("Input must be a list") + assert len(x) == 2: + raise ValueError("Input must have exactly two entries") shortcut, residual = x - # Random bernoulli variable with probability p_l, indiciathing wheter the branch should be kept or not or not - b_l = tf.keras.backend.random_binomial([], p=self.p_l) + # Random bernoulli variable indicating whether the branch should be kept or not or not + b_l = tf.keras.backend.random_binomial([], p=self.survival_probability) def _call_train(): return shortcut + b_l * residual def _call_test(): - return shortcut + self.p_l * residual + return shortcut + self.survival_probability * residual return tf.keras.backend.in_train_phase( _call_train, _call_test, training=training ) def compute_output_shape(self, input_shape): - assert isinstance(input_shape, list) + assert isinstance(x, list): + raise ValueError("Input must be a list") + assert len(x) == 2: + raise ValueError("Input must have exactly two entries") return input_shape[0] def get_config(self): base_config = super().get_config() - config = {"p_l": self.p_l} + config = {"survival_probability": self.survival_probability} return {**base_config, **config} diff --git a/tensorflow_addons/layers/tests/stochastic_depth_test.py b/tensorflow_addons/layers/tests/stochastic_depth_test.py index d0bb3582fe..a834a0ef26 100644 --- a/tensorflow_addons/layers/tests/stochastic_depth_test.py +++ b/tensorflow_addons/layers/tests/stochastic_depth_test.py @@ -5,26 +5,26 @@ from tensorflow_addons.layers.stochastic_depth import StochasticDepth from tensorflow_addons.utils import test_utils -KEEP_SEED = 1111 -DROP_SEED = 2222 +_KEEP_SEED = 1111 +_DROP_SEED = 2222 -@pytest.mark.parametrize("seed", [KEEP_SEED, DROP_SEED]) +@pytest.mark.parametrize("seed", [_KEEP_SEED, _DROP_SEED]) @pytest.mark.parametrize("training", [True, False]) def stochastic_depth_test(seed, training): np.random.seed(seed) tf.random.set_seed(seed) - p_l = 0.5 + survival_probability = 0.5 shortcut = np.asarray([[0.2, 0.1, 0.4]]).astype(np.float32) residual = np.asarray([[0.2, 0.4, 0.5]]).astype(np.float32) if training: - if seed == KEEP_SEED: + if seed == _KEEP_SEED: # shortcut + residual expected_output = np.asarray([[0.4, 0.5, 0.9]]).astype(np.float32) - elif seed == DROP_SEED: + elif seed == _DROP_SEED: # shortcut expected_output = np.asarray([[0.2, 0.1, 0.4]]).astype(np.float32) else: @@ -33,14 +33,14 @@ def stochastic_depth_test(seed, training): test_utils.layer_test( StochasticDepth, - kwargs={"p_l": p_l}, + kwargs={"survival_probability": survival_probability}, input_data=[shortcut, residual], expected_output=expected_output, ) def test_serialization(): - stoch_depth = StochasticDepth(p_l=0.5) + stoch_depth = StochasticDepth(survival_probability=0.5) serialized_stoch_depth = tf.keras.layers.serialize(stoch_depth) new_layer = tf.keras.layers.deserialize(serialized_stoch_depth) assert stoch_depth.get_config() == new_layer.get_config() From 6ea93cd43897d5514744e0b2243db1beedf5005a Mon Sep 17 00:00:00 2001 From: Michael Date: Wed, 9 Sep 2020 21:26:15 +0100 Subject: [PATCH 08/12] Test and formatting fixes --- tensorflow_addons/layers/stochastic_depth.py | 22 ++++++++------------ 1 file changed, 9 insertions(+), 13 deletions(-) diff --git a/tensorflow_addons/layers/stochastic_depth.py b/tensorflow_addons/layers/stochastic_depth.py index 74a5b1e201..66d6f46d34 100644 --- a/tensorflow_addons/layers/stochastic_depth.py +++ b/tensorflow_addons/layers/stochastic_depth.py @@ -4,7 +4,7 @@ @tf.keras.utils.register_keras_serializable(package="Addons") class StochasticDepth(tf.keras.layers.Layer): - r"""Stochastic Depth layer. + """Stochastic Depth layer. Implements Stochastic Depth as described in [Deep Networks with Stochastic Depth](https://arxiv.org/abs/1603.09382), to randomly drop residual branches @@ -14,15 +14,15 @@ class StochasticDepth(tf.keras.layers.Layer): Residual architectures with fixed depth, use residual branches that are merged back into the main network by adding the residual branch back to the input: - >>> input = np.ones((3, 3, 1)) + >>> input = np.ones((1, 3, 3, 1), dtype = np.float32) >>> residual = tf.keras.layers.Conv2D(1, 1)(input) - >>> tfa.layers.StochasticDepth()([input, residual]) + >>> output = tfa.layers.StochasticDepth()([input, residual]) StochasticDepth acts as a drop-in replacement for the addition: - >>> input = np.ones((3, 3, 1)) + >>> input = np.ones((1, 3, 3, 1), dtype = np.float32) >>> residual = tf.keras.layers.Conv2D(1, 1)(input) - >>> tfa.layers.StochasticDepth()([input, residual]) + >>> output = tfa.layers.StochasticDepth()([input, residual]) At train time, StochasticDepth returns: @@ -56,10 +56,8 @@ def __init__(self, survival_probability: float = 0.5, **kwargs): self.survival_probability = survival_probability def call(self, x, training=None): - assert isinstance(x, list): - raise ValueError("Input must be a list") - assert len(x) == 2: - raise ValueError("Input must have exactly two entries") + assert isinstance(x, list), "Input must be a list" + assert len(x) == 2, "Input must have exactly two entries" shortcut, residual = x @@ -77,10 +75,8 @@ def _call_test(): ) def compute_output_shape(self, input_shape): - assert isinstance(x, list): - raise ValueError("Input must be a list") - assert len(x) == 2: - raise ValueError("Input must have exactly two entries") + assert isinstance(input_shape, list), "Input must be a list" + assert len(input_shape) == 2, "Input must have exactly two entries" return input_shape[0] From d2bbed7f583ee265688751a2f3fb2d262577656b Mon Sep 17 00:00:00 2001 From: Michael Date: Wed, 9 Sep 2020 21:47:55 +0100 Subject: [PATCH 09/12] Fixed doc string --- tensorflow_addons/layers/stochastic_depth.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow_addons/layers/stochastic_depth.py b/tensorflow_addons/layers/stochastic_depth.py index 66d6f46d34..3233e2588e 100644 --- a/tensorflow_addons/layers/stochastic_depth.py +++ b/tensorflow_addons/layers/stochastic_depth.py @@ -16,7 +16,7 @@ class StochasticDepth(tf.keras.layers.Layer): >>> input = np.ones((1, 3, 3, 1), dtype = np.float32) >>> residual = tf.keras.layers.Conv2D(1, 1)(input) - >>> output = tfa.layers.StochasticDepth()([input, residual]) + >>> output = tfa.layers.Add()([input, residual]) StochasticDepth acts as a drop-in replacement for the addition: From bb3f95585120cead2c0b788085204fd13b9f3f7e Mon Sep 17 00:00:00 2001 From: Michael Date: Wed, 9 Sep 2020 22:10:11 +0100 Subject: [PATCH 10/12] Added mixed precision test --- tensorflow_addons/layers/stochastic_depth.py | 2 +- .../layers/tests/stochastic_depth_test.py | 12 ++++++++++++ 2 files changed, 13 insertions(+), 1 deletion(-) diff --git a/tensorflow_addons/layers/stochastic_depth.py b/tensorflow_addons/layers/stochastic_depth.py index 3233e2588e..1d810f4659 100644 --- a/tensorflow_addons/layers/stochastic_depth.py +++ b/tensorflow_addons/layers/stochastic_depth.py @@ -16,7 +16,7 @@ class StochasticDepth(tf.keras.layers.Layer): >>> input = np.ones((1, 3, 3, 1), dtype = np.float32) >>> residual = tf.keras.layers.Conv2D(1, 1)(input) - >>> output = tfa.layers.Add()([input, residual]) + >>> output = tf.keras.layers.Add()([input, residual]) StochasticDepth acts as a drop-in replacement for the addition: diff --git a/tensorflow_addons/layers/tests/stochastic_depth_test.py b/tensorflow_addons/layers/tests/stochastic_depth_test.py index a834a0ef26..14464109a7 100644 --- a/tensorflow_addons/layers/tests/stochastic_depth_test.py +++ b/tensorflow_addons/layers/tests/stochastic_depth_test.py @@ -39,6 +39,18 @@ def stochastic_depth_test(seed, training): ) +@pytest.mark.usefixtures("run_with_mixed_precision_policy") +def test_with_mixed_precision_policy(): + policy = tf.keras.mixed_precision.experimental.global_policy() + + shortcut = np.asarray([[0.2, 0.1, 0.4]]) + residual = np.asarray([[0.2, 0.4, 0.5]]) + + output = StochasticDepth()([shortcut, residual]) + + assert output.dtype == policy.compute_dtype + + def test_serialization(): stoch_depth = StochasticDepth(survival_probability=0.5) serialized_stoch_depth = tf.keras.layers.serialize(stoch_depth) From adced7ac28734bff69d5f12f3d979ade63b61bce Mon Sep 17 00:00:00 2001 From: Michael Date: Wed, 9 Sep 2020 23:17:09 +0100 Subject: [PATCH 11/12] Further code review changes --- tensorflow_addons/layers/stochastic_depth.py | 17 ++++++++++------- .../layers/tests/stochastic_depth_test.py | 2 +- 2 files changed, 11 insertions(+), 8 deletions(-) diff --git a/tensorflow_addons/layers/stochastic_depth.py b/tensorflow_addons/layers/stochastic_depth.py index 1d810f4659..597d4762b6 100644 --- a/tensorflow_addons/layers/stochastic_depth.py +++ b/tensorflow_addons/layers/stochastic_depth.py @@ -17,12 +17,16 @@ class StochasticDepth(tf.keras.layers.Layer): >>> input = np.ones((1, 3, 3, 1), dtype = np.float32) >>> residual = tf.keras.layers.Conv2D(1, 1)(input) >>> output = tf.keras.layers.Add()([input, residual]) + >>> output.shape + TensorShape([1, 3, 3, 1]) StochasticDepth acts as a drop-in replacement for the addition: >>> input = np.ones((1, 3, 3, 1), dtype = np.float32) >>> residual = tf.keras.layers.Conv2D(1, 1)(input) >>> output = tfa.layers.StochasticDepth()([input, residual]) + >>> output.shape + TensorShape([1, 3, 3, 1]) At train time, StochasticDepth returns: @@ -30,7 +34,7 @@ class StochasticDepth(tf.keras.layers.Layer): x[0] + b_l * x[1] $$ - , where $b_l$ is a random Bernoulli variable with probability $p(b_l == 1) == p_l$ + , where $b_l$ is a random Bernoulli variable with probability $P(b_l = 1) = p_l$ At test time, StochasticDepth rescales the activations of the residual branch based on the survival probability ($p_l$): @@ -42,8 +46,7 @@ class StochasticDepth(tf.keras.layers.Layer): survival_probability: float, the probability of the residual branch being kept. Call Arguments: - inputs: List of `[shortcut, residual]` where - * `shortcut`, and `residual` are tensors of equal shape. + inputs: List of `[shortcut, residual]` where `shortcut`, and `residual` are tensors of equal shape. Output shape: Equal to the shape of inputs `shortcut`, and `residual` @@ -56,8 +59,8 @@ def __init__(self, survival_probability: float = 0.5, **kwargs): self.survival_probability = survival_probability def call(self, x, training=None): - assert isinstance(x, list), "Input must be a list" - assert len(x) == 2, "Input must have exactly two entries" + if not isinstance(x, list) or len(x) != 2: + raise ValueError("input must be a list of length 2.") shortcut, residual = x @@ -75,8 +78,8 @@ def _call_test(): ) def compute_output_shape(self, input_shape): - assert isinstance(input_shape, list), "Input must be a list" - assert len(input_shape) == 2, "Input must have exactly two entries" + if not isinstance(input_shape, list) or len(input_shape) != 2: + raise ValueError("input_shape must be a list of length 2.") return input_shape[0] diff --git a/tensorflow_addons/layers/tests/stochastic_depth_test.py b/tensorflow_addons/layers/tests/stochastic_depth_test.py index 14464109a7..1122016f57 100644 --- a/tensorflow_addons/layers/tests/stochastic_depth_test.py +++ b/tensorflow_addons/layers/tests/stochastic_depth_test.py @@ -1,7 +1,7 @@ import pytest - import numpy as np import tensorflow as tf + from tensorflow_addons.layers.stochastic_depth import StochasticDepth from tensorflow_addons.utils import test_utils From de99e1cb690995e794c6d07673f84b1c00fdf63e Mon Sep 17 00:00:00 2001 From: Michael Date: Tue, 15 Sep 2020 23:50:06 +0100 Subject: [PATCH 12/12] Code review changes --- tensorflow_addons/layers/stochastic_depth.py | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/tensorflow_addons/layers/stochastic_depth.py b/tensorflow_addons/layers/stochastic_depth.py index 597d4762b6..3cf0df4f8c 100644 --- a/tensorflow_addons/layers/stochastic_depth.py +++ b/tensorflow_addons/layers/stochastic_depth.py @@ -31,10 +31,10 @@ class StochasticDepth(tf.keras.layers.Layer): At train time, StochasticDepth returns: $$ - x[0] + b_l * x[1] + x[0] + b_l * x[1], $$ - , where $b_l$ is a random Bernoulli variable with probability $P(b_l = 1) = p_l$ + where $b_l$ is a random Bernoulli variable with probability $P(b_l = 1) = p_l$ At test time, StochasticDepth rescales the activations of the residual branch based on the survival probability ($p_l$): @@ -65,7 +65,7 @@ def call(self, x, training=None): shortcut, residual = x # Random bernoulli variable indicating whether the branch should be kept or not or not - b_l = tf.keras.backend.random_binomial([], p=self.survival_probability) + b_l = tf.keras.backend.random_bernoulli([], p=self.survival_probability) def _call_train(): return shortcut + b_l * residual @@ -78,9 +78,6 @@ def _call_test(): ) def compute_output_shape(self, input_shape): - if not isinstance(input_shape, list) or len(input_shape) != 2: - raise ValueError("input_shape must be a list of length 2.") - return input_shape[0] def get_config(self):