From 1652f44929e4eac6126b482f0391320220cde4be Mon Sep 17 00:00:00 2001 From: gabrieldemarmiesse Date: Sun, 15 Mar 2020 15:25:07 +0000 Subject: [PATCH] Simpler python implementation of softshrink. --- tensorflow_addons/activations/softshrink.py | 10 +++------- 1 file changed, 3 insertions(+), 7 deletions(-) diff --git a/tensorflow_addons/activations/softshrink.py b/tensorflow_addons/activations/softshrink.py index 861c35dcad..577e0712f9 100644 --- a/tensorflow_addons/activations/softshrink.py +++ b/tensorflow_addons/activations/softshrink.py @@ -69,10 +69,6 @@ def _softshrink_py(x, lower, upper): " not be higher than the value " "variable upper, which is {} .".format(lower, upper) ) - mask_lower = x < lower - mask_upper = upper < x - mask_middle = tf.logical_not(tf.logical_or(mask_lower, mask_upper)) - mask_lower = tf.cast(mask_lower, x.dtype) - mask_upper = tf.cast(mask_upper, x.dtype) - mask_middle = tf.cast(mask_middle, x.dtype) - return x * (1 - mask_middle) - mask_lower * lower - mask_upper * upper + values_below_lower = tf.where(x < lower, x - lower, 0) + values_above_upper = tf.where(upper < x, x - upper, 0) + return values_below_lower + values_above_upper