Skip to content

Commit dbcd5aa

Browse files
authored
Added stochastic depth layer (#2154)
* Added stochastic depth layer * Fixed code style and added missing __init__ entry * Fixed tests and style * Fixed code style * Updated CODEOWNERS * Added codeowners for tests * Changes after code review * Test and formatting fixes * Fixed doc string * Added mixed precision test * Further code review changes * Code review changes
1 parent 1c3c072 commit dbcd5aa

File tree

4 files changed

+149
-0
lines changed

4 files changed

+149
-0
lines changed

.github/CODEOWNERS

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,8 @@
105105
/tensorflow_addons/layers/tests/esn_test.py @pedrolarben
106106
/tensorflow_addons/layers/snake.py @failure-to-thrive
107107
/tensorflow_addons/layers/tests/snake_test.py @failure-to-thrive
108+
/tensorflow_addons/layers/stochastic_depth.py @mhstadler @windqaq
109+
/tensorflow_addons/layers/tests/stochastic_depth_test.py @mhstadler @windqaq
108110
/tensorflow_addons/layers/noisy_dense.py @leonshams
109111
/tensorflow_addons/layers/tests/noisy_dense_test.py @leonshams
110112

tensorflow_addons/layers/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,4 +38,5 @@
3838
from tensorflow_addons.layers.tlu import TLU
3939
from tensorflow_addons.layers.wrappers import WeightNormalization
4040
from tensorflow_addons.layers.esn import ESN
41+
from tensorflow_addons.layers.stochastic_depth import StochasticDepth
4142
from tensorflow_addons.layers.noisy_dense import NoisyDense
Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,88 @@
1+
import tensorflow as tf
2+
from typeguard import typechecked
3+
4+
5+
@tf.keras.utils.register_keras_serializable(package="Addons")
6+
class StochasticDepth(tf.keras.layers.Layer):
7+
"""Stochastic Depth layer.
8+
9+
Implements Stochastic Depth as described in
10+
[Deep Networks with Stochastic Depth](https://arxiv.org/abs/1603.09382), to randomly drop residual branches
11+
in residual architectures.
12+
13+
Usage:
14+
Residual architectures with fixed depth, use residual branches that are merged back into the main network
15+
by adding the residual branch back to the input:
16+
17+
>>> input = np.ones((1, 3, 3, 1), dtype = np.float32)
18+
>>> residual = tf.keras.layers.Conv2D(1, 1)(input)
19+
>>> output = tf.keras.layers.Add()([input, residual])
20+
>>> output.shape
21+
TensorShape([1, 3, 3, 1])
22+
23+
StochasticDepth acts as a drop-in replacement for the addition:
24+
25+
>>> input = np.ones((1, 3, 3, 1), dtype = np.float32)
26+
>>> residual = tf.keras.layers.Conv2D(1, 1)(input)
27+
>>> output = tfa.layers.StochasticDepth()([input, residual])
28+
>>> output.shape
29+
TensorShape([1, 3, 3, 1])
30+
31+
At train time, StochasticDepth returns:
32+
33+
$$
34+
x[0] + b_l * x[1],
35+
$$
36+
37+
where $b_l$ is a random Bernoulli variable with probability $P(b_l = 1) = p_l$
38+
39+
At test time, StochasticDepth rescales the activations of the residual branch based on the survival probability ($p_l$):
40+
41+
$$
42+
x[0] + p_l * x[1]
43+
$$
44+
45+
Arguments:
46+
survival_probability: float, the probability of the residual branch being kept.
47+
48+
Call Arguments:
49+
inputs: List of `[shortcut, residual]` where `shortcut`, and `residual` are tensors of equal shape.
50+
51+
Output shape:
52+
Equal to the shape of inputs `shortcut`, and `residual`
53+
"""
54+
55+
@typechecked
56+
def __init__(self, survival_probability: float = 0.5, **kwargs):
57+
super().__init__(**kwargs)
58+
59+
self.survival_probability = survival_probability
60+
61+
def call(self, x, training=None):
62+
if not isinstance(x, list) or len(x) != 2:
63+
raise ValueError("input must be a list of length 2.")
64+
65+
shortcut, residual = x
66+
67+
# Random bernoulli variable indicating whether the branch should be kept or not or not
68+
b_l = tf.keras.backend.random_bernoulli([], p=self.survival_probability)
69+
70+
def _call_train():
71+
return shortcut + b_l * residual
72+
73+
def _call_test():
74+
return shortcut + self.survival_probability * residual
75+
76+
return tf.keras.backend.in_train_phase(
77+
_call_train, _call_test, training=training
78+
)
79+
80+
def compute_output_shape(self, input_shape):
81+
return input_shape[0]
82+
83+
def get_config(self):
84+
base_config = super().get_config()
85+
86+
config = {"survival_probability": self.survival_probability}
87+
88+
return {**base_config, **config}
Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
import pytest
2+
import numpy as np
3+
import tensorflow as tf
4+
5+
from tensorflow_addons.layers.stochastic_depth import StochasticDepth
6+
from tensorflow_addons.utils import test_utils
7+
8+
_KEEP_SEED = 1111
9+
_DROP_SEED = 2222
10+
11+
12+
@pytest.mark.parametrize("seed", [_KEEP_SEED, _DROP_SEED])
13+
@pytest.mark.parametrize("training", [True, False])
14+
def stochastic_depth_test(seed, training):
15+
np.random.seed(seed)
16+
tf.random.set_seed(seed)
17+
18+
survival_probability = 0.5
19+
20+
shortcut = np.asarray([[0.2, 0.1, 0.4]]).astype(np.float32)
21+
residual = np.asarray([[0.2, 0.4, 0.5]]).astype(np.float32)
22+
23+
if training:
24+
if seed == _KEEP_SEED:
25+
# shortcut + residual
26+
expected_output = np.asarray([[0.4, 0.5, 0.9]]).astype(np.float32)
27+
elif seed == _DROP_SEED:
28+
# shortcut
29+
expected_output = np.asarray([[0.2, 0.1, 0.4]]).astype(np.float32)
30+
else:
31+
# shortcut + p_l * residual
32+
expected_output = np.asarray([[0.3, 0.3, 0.65]]).astype(np.float32)
33+
34+
test_utils.layer_test(
35+
StochasticDepth,
36+
kwargs={"survival_probability": survival_probability},
37+
input_data=[shortcut, residual],
38+
expected_output=expected_output,
39+
)
40+
41+
42+
@pytest.mark.usefixtures("run_with_mixed_precision_policy")
43+
def test_with_mixed_precision_policy():
44+
policy = tf.keras.mixed_precision.experimental.global_policy()
45+
46+
shortcut = np.asarray([[0.2, 0.1, 0.4]])
47+
residual = np.asarray([[0.2, 0.4, 0.5]])
48+
49+
output = StochasticDepth()([shortcut, residual])
50+
51+
assert output.dtype == policy.compute_dtype
52+
53+
54+
def test_serialization():
55+
stoch_depth = StochasticDepth(survival_probability=0.5)
56+
serialized_stoch_depth = tf.keras.layers.serialize(stoch_depth)
57+
new_layer = tf.keras.layers.deserialize(serialized_stoch_depth)
58+
assert stoch_depth.get_config() == new_layer.get_config()

0 commit comments

Comments
 (0)