Description
Describe the feature and the current behavior/state.
Stochastic depth is a regularization technique for training very deep residual networks (the authors, for example, train ResNets of up to 1202 layers). In particular, it trains networks that are, on average, shallower during training while retaining their full depth at inference time.
Previous attempts (#626) seem to have been unnecessarily convoluted. This implementation would be a simple layer attached to the end of a residual branch (as suggested in tensorflow/tensorflow#8817 (comment), which is also the way it is described in the Shake-Drop paper: https://arxiv.org/abs/1802.02375).
Notably, this means that, if momentum is being used, the dropped layers will still receive small updates due to their historic gradients (similar to Dropout).
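A minimal sketch of what such a layer could look like, assuming it receives the pair [shortcut, residual_branch] as input and takes a hypothetical survival_probability argument (this is illustrative only, not a proposal for the final API):

```python
import tensorflow as tf


class StochasticDepth(tf.keras.layers.Layer):
    """Sketch: randomly drops an entire residual branch during training."""

    def __init__(self, survival_probability=0.5, **kwargs):
        super().__init__(**kwargs)
        self.survival_probability = survival_probability

    def call(self, inputs, training=None):
        shortcut, residual = inputs

        def _train():
            # One Bernoulli draw per call: keep or drop the whole branch.
            keep = tf.keras.backend.random_bernoulli(
                [], p=self.survival_probability)
            return shortcut + keep * residual

        def _inference():
            # Re-scale the branch by its survival probability at test time.
            return shortcut + self.survival_probability * residual

        return tf.keras.backend.in_train_phase(
            _train, _inference, training=training)

    def get_config(self):
        config = super().get_config()
        config["survival_probability"] = self.survival_probability
        return config
```

During training the branch is kept as-is or dropped entirely; at inference it is down-scaled by its survival probability before being added to the shortcut, matching the formulation in the paper.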
Relevant information
- Are you willing to contribute it (yes/no): Yes
- Are you willing to maintain it going forward? (yes/no): Yes
- Is there a relevant academic paper? (if so, where): Yes (https://arxiv.org/abs/1603.09382)
- Is there already an implementation in another framework? (if so, where): The authors' original PyTorch implementation is available here: https://github.com/yueatsprograms/Stochastic_Depth
- Was it part of tf.contrib? (if so, where): No
Which API type would this fall under (layer, metric, optimizer, etc.)
layer
Who will benefit with this feature?
Anybody who wants to use stochastic depth to train deeper ResNets, or who wants to recreate the EfficientNet architecture (https://arxiv.org/abs/1905.11946). Anyone who wants to add Shake-Drop (https://arxiv.org/abs/1802.02375) to their network can use this as a base.
Any other info.
It is important to note that Stochastic Depth is not the same as Dropout with noise_shape=(1, 1, 1), as suggested in the TensorFlow EfficientNet implementation (https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/keras/applications/efficientnet.py#L511).
With dropout, the branch is simply kept as-is at inference time; with stochastic depth, however, the branch is re-scaled by its survival probability before being merged with the main network.
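To make the difference concrete, here is a small sketch of the two inference-time behaviours, assuming a survival probability p_survival and already-computed shortcut and branch tensors (the names are just for illustration):

```python
import tensorflow as tf

p_survival = 0.8
shortcut = tf.ones([1, 4])
branch = tf.ones([1, 4])

# Dropout (inverted dropout): at inference the branch passes through
# unchanged, because the compensating up-scaling already happened
# during training.
dropout_inference = shortcut + branch

# Stochastic depth: the branch was kept un-scaled during training, so at
# inference it must be down-scaled by its survival probability.
stochastic_depth_inference = shortcut + p_survival * branch
```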