-
Notifications
You must be signed in to change notification settings - Fork 617
Added stochastic depth layer #2154
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
WindQAQ
merged 13 commits into
tensorflow:master
from
MHStadler:feature/stochastic_depth
Sep 16, 2020
Merged
Changes from all commits
Commits
Show all changes
13 commits
Select commit
Hold shift + click to select a range
bb828f6
Added stochastic depth layer
MHStadler fb94ee8
Fixed code style and added missing __init__ entry
MHStadler 0a20a5e
Fixed tests and style
MHStadler 4449b09
Fixed code style
MHStadler 6e7999a
Updated CODEOWNERS
MHStadler 5c77ff4
Added codeowners for tests
MHStadler 6321ce4
Changes after code review
MHStadler 6ea93cd
Test and formatting fixes
MHStadler d2bbed7
Fixed doc string
MHStadler bb3f955
Added mixed precision test
MHStadler adced7a
Further code review changes
MHStadler de99e1c
Code review changes
MHStadler b2189b4
Merge branch 'master' into feature/stochastic_depth
MHStadler File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,88 @@ | ||
| import tensorflow as tf | ||
| from typeguard import typechecked | ||
|
|
||
|
|
||
| @tf.keras.utils.register_keras_serializable(package="Addons") | ||
| class StochasticDepth(tf.keras.layers.Layer): | ||
| """Stochastic Depth layer. | ||
|
|
||
| Implements Stochastic Depth as described in | ||
| [Deep Networks with Stochastic Depth](https://arxiv.org/abs/1603.09382), to randomly drop residual branches | ||
| in residual architectures. | ||
|
|
||
| Usage: | ||
| Residual architectures with fixed depth, use residual branches that are merged back into the main network | ||
| by adding the residual branch back to the input: | ||
|
|
||
| >>> input = np.ones((1, 3, 3, 1), dtype = np.float32) | ||
| >>> residual = tf.keras.layers.Conv2D(1, 1)(input) | ||
| >>> output = tf.keras.layers.Add()([input, residual]) | ||
| >>> output.shape | ||
| TensorShape([1, 3, 3, 1]) | ||
|
|
||
| StochasticDepth acts as a drop-in replacement for the addition: | ||
|
|
||
| >>> input = np.ones((1, 3, 3, 1), dtype = np.float32) | ||
| >>> residual = tf.keras.layers.Conv2D(1, 1)(input) | ||
| >>> output = tfa.layers.StochasticDepth()([input, residual]) | ||
MHStadler marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| >>> output.shape | ||
| TensorShape([1, 3, 3, 1]) | ||
|
|
||
| At train time, StochasticDepth returns: | ||
|
|
||
| $$ | ||
| x[0] + b_l * x[1], | ||
| $$ | ||
|
|
||
| where $b_l$ is a random Bernoulli variable with probability $P(b_l = 1) = p_l$ | ||
|
|
||
| At test time, StochasticDepth rescales the activations of the residual branch based on the survival probability ($p_l$): | ||
|
|
||
| $$ | ||
| x[0] + p_l * x[1] | ||
| $$ | ||
|
|
||
| Arguments: | ||
| survival_probability: float, the probability of the residual branch being kept. | ||
MHStadler marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
|
||
| Call Arguments: | ||
| inputs: List of `[shortcut, residual]` where `shortcut`, and `residual` are tensors of equal shape. | ||
|
|
||
| Output shape: | ||
| Equal to the shape of inputs `shortcut`, and `residual` | ||
| """ | ||
|
|
||
| @typechecked | ||
| def __init__(self, survival_probability: float = 0.5, **kwargs): | ||
| super().__init__(**kwargs) | ||
|
|
||
| self.survival_probability = survival_probability | ||
|
|
||
| def call(self, x, training=None): | ||
| if not isinstance(x, list) or len(x) != 2: | ||
| raise ValueError("input must be a list of length 2.") | ||
|
|
||
| shortcut, residual = x | ||
|
|
||
| # Random bernoulli variable indicating whether the branch should be kept or not or not | ||
| b_l = tf.keras.backend.random_bernoulli([], p=self.survival_probability) | ||
|
|
||
| def _call_train(): | ||
| return shortcut + b_l * residual | ||
|
|
||
| def _call_test(): | ||
| return shortcut + self.survival_probability * residual | ||
|
|
||
| return tf.keras.backend.in_train_phase( | ||
| _call_train, _call_test, training=training | ||
| ) | ||
|
|
||
| def compute_output_shape(self, input_shape): | ||
| return input_shape[0] | ||
|
|
||
| def get_config(self): | ||
| base_config = super().get_config() | ||
|
|
||
| config = {"survival_probability": self.survival_probability} | ||
|
|
||
| return {**base_config, **config} | ||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,58 @@ | ||
| import pytest | ||
| import numpy as np | ||
| import tensorflow as tf | ||
|
|
||
| from tensorflow_addons.layers.stochastic_depth import StochasticDepth | ||
| from tensorflow_addons.utils import test_utils | ||
MHStadler marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
|
||
| _KEEP_SEED = 1111 | ||
| _DROP_SEED = 2222 | ||
|
|
||
|
|
||
| @pytest.mark.parametrize("seed", [_KEEP_SEED, _DROP_SEED]) | ||
| @pytest.mark.parametrize("training", [True, False]) | ||
| def stochastic_depth_test(seed, training): | ||
| np.random.seed(seed) | ||
| tf.random.set_seed(seed) | ||
|
|
||
| survival_probability = 0.5 | ||
|
|
||
| shortcut = np.asarray([[0.2, 0.1, 0.4]]).astype(np.float32) | ||
| residual = np.asarray([[0.2, 0.4, 0.5]]).astype(np.float32) | ||
|
|
||
| if training: | ||
| if seed == _KEEP_SEED: | ||
| # shortcut + residual | ||
| expected_output = np.asarray([[0.4, 0.5, 0.9]]).astype(np.float32) | ||
| elif seed == _DROP_SEED: | ||
| # shortcut | ||
| expected_output = np.asarray([[0.2, 0.1, 0.4]]).astype(np.float32) | ||
| else: | ||
| # shortcut + p_l * residual | ||
| expected_output = np.asarray([[0.3, 0.3, 0.65]]).astype(np.float32) | ||
|
|
||
| test_utils.layer_test( | ||
| StochasticDepth, | ||
| kwargs={"survival_probability": survival_probability}, | ||
| input_data=[shortcut, residual], | ||
| expected_output=expected_output, | ||
| ) | ||
|
|
||
|
|
||
| @pytest.mark.usefixtures("run_with_mixed_precision_policy") | ||
| def test_with_mixed_precision_policy(): | ||
| policy = tf.keras.mixed_precision.experimental.global_policy() | ||
|
|
||
| shortcut = np.asarray([[0.2, 0.1, 0.4]]) | ||
| residual = np.asarray([[0.2, 0.4, 0.5]]) | ||
|
|
||
| output = StochasticDepth()([shortcut, residual]) | ||
|
|
||
| assert output.dtype == policy.compute_dtype | ||
|
|
||
|
|
||
| def test_serialization(): | ||
| stoch_depth = StochasticDepth(survival_probability=0.5) | ||
| serialized_stoch_depth = tf.keras.layers.serialize(stoch_depth) | ||
| new_layer = tf.keras.layers.deserialize(serialized_stoch_depth) | ||
| assert stoch_depth.get_config() == new_layer.get_config() | ||
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.