Skip to content
2 changes: 2 additions & 0 deletions .github/CODEOWNERS
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,8 @@
/tensorflow_addons/layers/tests/esn_test.py @pedrolarben
/tensorflow_addons/layers/snake.py @failure-to-thrive
/tensorflow_addons/layers/tests/snake_test.py @failure-to-thrive
/tensorflow_addons/layers/stochastic_depth.py @mhstadler @windqaq
/tensorflow_addons/layers/tests/stochastic_depth_test.py @mhstadler @windqaq
/tensorflow_addons/layers/noisy_dense.py @leonshams
/tensorflow_addons/layers/tests/noisy_dense_test.py @leonshams

Expand Down
1 change: 1 addition & 0 deletions tensorflow_addons/layers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,4 +38,5 @@
from tensorflow_addons.layers.tlu import TLU
from tensorflow_addons.layers.wrappers import WeightNormalization
from tensorflow_addons.layers.esn import ESN
from tensorflow_addons.layers.stochastic_depth import StochasticDepth
from tensorflow_addons.layers.noisy_dense import NoisyDense
88 changes: 88 additions & 0 deletions tensorflow_addons/layers/stochastic_depth.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
import tensorflow as tf
from typeguard import typechecked


@tf.keras.utils.register_keras_serializable(package="Addons")
class StochasticDepth(tf.keras.layers.Layer):
"""Stochastic Depth layer.

Implements Stochastic Depth as described in
[Deep Networks with Stochastic Depth](https://arxiv.org/abs/1603.09382), to randomly drop residual branches
in residual architectures.

Usage:
Residual architectures with fixed depth, use residual branches that are merged back into the main network
by adding the residual branch back to the input:

>>> input = np.ones((1, 3, 3, 1), dtype = np.float32)
>>> residual = tf.keras.layers.Conv2D(1, 1)(input)
>>> output = tf.keras.layers.Add()([input, residual])
>>> output.shape
TensorShape([1, 3, 3, 1])

StochasticDepth acts as a drop-in replacement for the addition:

>>> input = np.ones((1, 3, 3, 1), dtype = np.float32)
>>> residual = tf.keras.layers.Conv2D(1, 1)(input)
>>> output = tfa.layers.StochasticDepth()([input, residual])
>>> output.shape
TensorShape([1, 3, 3, 1])

At train time, StochasticDepth returns:

$$
x[0] + b_l * x[1],
$$

where $b_l$ is a random Bernoulli variable with probability $P(b_l = 1) = p_l$

At test time, StochasticDepth rescales the activations of the residual branch based on the survival probability ($p_l$):

$$
x[0] + p_l * x[1]
$$

Arguments:
survival_probability: float, the probability of the residual branch being kept.

Call Arguments:
inputs: List of `[shortcut, residual]` where `shortcut`, and `residual` are tensors of equal shape.

Output shape:
Equal to the shape of inputs `shortcut`, and `residual`
"""

@typechecked
def __init__(self, survival_probability: float = 0.5, **kwargs):
super().__init__(**kwargs)

self.survival_probability = survival_probability

def call(self, x, training=None):
if not isinstance(x, list) or len(x) != 2:
raise ValueError("input must be a list of length 2.")

shortcut, residual = x

# Random bernoulli variable indicating whether the branch should be kept or not or not
b_l = tf.keras.backend.random_bernoulli([], p=self.survival_probability)

def _call_train():
return shortcut + b_l * residual

def _call_test():
return shortcut + self.survival_probability * residual

return tf.keras.backend.in_train_phase(
_call_train, _call_test, training=training
)

def compute_output_shape(self, input_shape):
return input_shape[0]

def get_config(self):
base_config = super().get_config()

config = {"survival_probability": self.survival_probability}

return {**base_config, **config}
58 changes: 58 additions & 0 deletions tensorflow_addons/layers/tests/stochastic_depth_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
import pytest
import numpy as np
import tensorflow as tf

from tensorflow_addons.layers.stochastic_depth import StochasticDepth
from tensorflow_addons.utils import test_utils

_KEEP_SEED = 1111
_DROP_SEED = 2222


@pytest.mark.parametrize("seed", [_KEEP_SEED, _DROP_SEED])
@pytest.mark.parametrize("training", [True, False])
def stochastic_depth_test(seed, training):
np.random.seed(seed)
tf.random.set_seed(seed)

survival_probability = 0.5

shortcut = np.asarray([[0.2, 0.1, 0.4]]).astype(np.float32)
residual = np.asarray([[0.2, 0.4, 0.5]]).astype(np.float32)

if training:
if seed == _KEEP_SEED:
# shortcut + residual
expected_output = np.asarray([[0.4, 0.5, 0.9]]).astype(np.float32)
elif seed == _DROP_SEED:
# shortcut
expected_output = np.asarray([[0.2, 0.1, 0.4]]).astype(np.float32)
else:
# shortcut + p_l * residual
expected_output = np.asarray([[0.3, 0.3, 0.65]]).astype(np.float32)

test_utils.layer_test(
StochasticDepth,
kwargs={"survival_probability": survival_probability},
input_data=[shortcut, residual],
expected_output=expected_output,
)


@pytest.mark.usefixtures("run_with_mixed_precision_policy")
def test_with_mixed_precision_policy():
policy = tf.keras.mixed_precision.experimental.global_policy()

shortcut = np.asarray([[0.2, 0.1, 0.4]])
residual = np.asarray([[0.2, 0.4, 0.5]])

output = StochasticDepth()([shortcut, residual])

assert output.dtype == policy.compute_dtype


def test_serialization():
stoch_depth = StochasticDepth(survival_probability=0.5)
serialized_stoch_depth = tf.keras.layers.serialize(stoch_depth)
new_layer = tf.keras.layers.deserialize(serialized_stoch_depth)
assert stoch_depth.get_config() == new_layer.get_config()