Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .github/CODEOWNERS
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,8 @@
/tensorflow_addons/layers/tests/esn_test.py @pedrolarben
/tensorflow_addons/layers/snake.py @failure-to-thrive
/tensorflow_addons/layers/tests/snake_test.py @failure-to-thrive
/tensorflow_addons/layers/stochastic_depth.py @mhstadler @windqaq
/tensorflow_addons/layers/tests/stochastic_depth_test.py @mhstadler @windqaq

/tensorflow_addons/losses/contrastive.py @windqaq
/tensorflow_addons/losses/tests/contrastive_test.py @windqaq
Expand Down
4 changes: 0 additions & 4 deletions tensorflow_addons/layers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,9 +38,5 @@
from tensorflow_addons.layers.tlu import TLU
from tensorflow_addons.layers.wrappers import WeightNormalization
from tensorflow_addons.layers.esn import ESN
<<<<<<< HEAD
=======
from tensorflow_addons.layers.stochastic_depth import StochasticDepth
from tensorflow_addons.layers.noisy_dense import NoisyDense
from tensorflow_addons.layers.crf import CRF
>>>>>>> 2cd311a... updating documentation for CRF (#2168)
88 changes: 88 additions & 0 deletions tensorflow_addons/layers/stochastic_depth.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
import tensorflow as tf
from typeguard import typechecked


@tf.keras.utils.register_keras_serializable(package="Addons")
class StochasticDepth(tf.keras.layers.Layer):
"""Stochastic Depth layer.

Implements Stochastic Depth as described in
[Deep Networks with Stochastic Depth](https://arxiv.org/abs/1603.09382), to randomly drop residual branches
in residual architectures.

Usage:
Residual architectures with fixed depth, use residual branches that are merged back into the main network
by adding the residual branch back to the input:

>>> input = np.ones((1, 3, 3, 1), dtype = np.float32)
>>> residual = tf.keras.layers.Conv2D(1, 1)(input)
>>> output = tf.keras.layers.Add()([input, residual])
>>> output.shape
TensorShape([1, 3, 3, 1])

StochasticDepth acts as a drop-in replacement for the addition:

>>> input = np.ones((1, 3, 3, 1), dtype = np.float32)
>>> residual = tf.keras.layers.Conv2D(1, 1)(input)
>>> output = tfa.layers.StochasticDepth()([input, residual])
>>> output.shape
TensorShape([1, 3, 3, 1])

At train time, StochasticDepth returns:

$$
x[0] + b_l * x[1],
$$

where $b_l$ is a random Bernoulli variable with probability $P(b_l = 1) = p_l$

At test time, StochasticDepth rescales the activations of the residual branch based on the survival probability ($p_l$):

$$
x[0] + p_l * x[1]
$$

Arguments:
survival_probability: float, the probability of the residual branch being kept.

Call Arguments:
inputs: List of `[shortcut, residual]` where `shortcut`, and `residual` are tensors of equal shape.

Output shape:
Equal to the shape of inputs `shortcut`, and `residual`
"""

@typechecked
def __init__(self, survival_probability: float = 0.5, **kwargs):
super().__init__(**kwargs)

self.survival_probability = survival_probability

def call(self, x, training=None):
if not isinstance(x, list) or len(x) != 2:
raise ValueError("input must be a list of length 2.")

shortcut, residual = x

# Random bernoulli variable indicating whether the branch should be kept or not or not
b_l = tf.keras.backend.random_bernoulli([], p=self.survival_probability)

def _call_train():
return shortcut + b_l * residual

def _call_test():
return shortcut + self.survival_probability * residual

return tf.keras.backend.in_train_phase(
_call_train, _call_test, training=training
)

def compute_output_shape(self, input_shape):
return input_shape[0]

def get_config(self):
base_config = super().get_config()

config = {"survival_probability": self.survival_probability}

return {**base_config, **config}
58 changes: 58 additions & 0 deletions tensorflow_addons/layers/tests/stochastic_depth_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
import pytest
import numpy as np
import tensorflow as tf

from tensorflow_addons.layers.stochastic_depth import StochasticDepth
from tensorflow_addons.utils import test_utils

_KEEP_SEED = 1111
_DROP_SEED = 2222


@pytest.mark.parametrize("seed", [_KEEP_SEED, _DROP_SEED])
@pytest.mark.parametrize("training", [True, False])
def stochastic_depth_test(seed, training):
np.random.seed(seed)
tf.random.set_seed(seed)

survival_probability = 0.5

shortcut = np.asarray([[0.2, 0.1, 0.4]]).astype(np.float32)
residual = np.asarray([[0.2, 0.4, 0.5]]).astype(np.float32)

if training:
if seed == _KEEP_SEED:
# shortcut + residual
expected_output = np.asarray([[0.4, 0.5, 0.9]]).astype(np.float32)
elif seed == _DROP_SEED:
# shortcut
expected_output = np.asarray([[0.2, 0.1, 0.4]]).astype(np.float32)
else:
# shortcut + p_l * residual
expected_output = np.asarray([[0.3, 0.3, 0.65]]).astype(np.float32)

test_utils.layer_test(
StochasticDepth,
kwargs={"survival_probability": survival_probability},
input_data=[shortcut, residual],
expected_output=expected_output,
)


@pytest.mark.usefixtures("run_with_mixed_precision_policy")
def test_with_mixed_precision_policy():
policy = tf.keras.mixed_precision.experimental.global_policy()

shortcut = np.asarray([[0.2, 0.1, 0.4]])
residual = np.asarray([[0.2, 0.4, 0.5]])

output = StochasticDepth()([shortcut, residual])

assert output.dtype == policy.compute_dtype


def test_serialization():
stoch_depth = StochasticDepth(survival_probability=0.5)
serialized_stoch_depth = tf.keras.layers.serialize(stoch_depth)
new_layer = tf.keras.layers.deserialize(serialized_stoch_depth)
assert stoch_depth.get_config() == new_layer.get_config()