diff --git a/tensorflow_addons/activations/softshrink.py b/tensorflow_addons/activations/softshrink.py index 861c35dcad..577e0712f9 100644 --- a/tensorflow_addons/activations/softshrink.py +++ b/tensorflow_addons/activations/softshrink.py @@ -69,10 +69,6 @@ def _softshrink_py(x, lower, upper): " not be higher than the value " "variable upper, which is {} .".format(lower, upper) ) - mask_lower = x < lower - mask_upper = upper < x - mask_middle = tf.logical_not(tf.logical_or(mask_lower, mask_upper)) - mask_lower = tf.cast(mask_lower, x.dtype) - mask_upper = tf.cast(mask_upper, x.dtype) - mask_middle = tf.cast(mask_middle, x.dtype) - return x * (1 - mask_middle) - mask_lower * lower - mask_upper * upper + values_below_lower = tf.where(x < lower, x - lower, 0) + values_above_upper = tf.where(upper < x, x - upper, 0) + return values_below_lower + values_above_upper