diff --git a/tensorflow_addons/image/compose_ops.py b/tensorflow_addons/image/compose_ops.py index 44c477760d..827b360cf8 100644 --- a/tensorflow_addons/image/compose_ops.py +++ b/tensorflow_addons/image/compose_ops.py @@ -16,10 +16,10 @@ import tensorflow as tf -from tensorflow_addons.utils.types import TensorLike +from tensorflow_addons.utils.types import TensorLike, Number -def blend(image1: TensorLike, image2: TensorLike, factor: float) -> TensorLike: +def blend(image1: TensorLike, image2: TensorLike, factor: Number) -> tf.Tensor: """Blend image1 and image2 using 'factor'. Factor can be above 0.0. A value of 0.0 means only image1 is used. @@ -30,17 +30,19 @@ def blend(image1: TensorLike, image2: TensorLike, factor: float) -> TensorLike: between 0 and 255. Args: - image1: An image Tensor of type uint8. - image2: An image Tensor of type uint8. - factor: A floating point value above 0.0. + image1: An image Tensor of shape (num_rows, num_columns, + num_channels) (HWC), or (num_rows, num_columns) (HW), + or (num_channels, num_rows, num_columns). + image2: An image Tensor of shape (num_rows, num_columns, + num_channels) (HWC), or (num_rows, num_columns) (HW), + or (num_channels, num_rows, num_columns). + factor: A floating point value or Tensor of type tf.float32 above 0.0. Returns: - A blended image Tensor of type uint8. + A blended image Tensor of tf.float32. """ with tf.name_scope("blend"): - if image1.dtype != tf.uint8 or image2.dtype != tf.uint8: - raise ValueError("Images must have dtype tf.uint8") if factor == 0.0: return tf.convert_to_tensor(image1) @@ -60,10 +62,10 @@ def blend(image1: TensorLike, image2: TensorLike, factor: float) -> TensorLike: if factor > 0.0 and factor < 1.0: # Interpolation means we always stay within 0 and 255. temp = tf.round(temp) - return tf.cast(temp, tf.dtypes.uint8) + return temp # Extrapolate: # # We need to clip and then cast. temp = tf.round(tf.clip_by_value(temp, 0.0, 255.0)) - return tf.cast(temp, tf.dtypes.uint8) + return temp diff --git a/tensorflow_addons/image/compose_ops_test.py b/tensorflow_addons/image/compose_ops_test.py index c93852bd43..2968295336 100644 --- a/tensorflow_addons/image/compose_ops_test.py +++ b/tensorflow_addons/image/compose_ops_test.py @@ -23,6 +23,11 @@ _DTYPES = { tf.dtypes.uint8, + tf.dtypes.int32, + tf.dtypes.int64, + tf.dtypes.float16, + tf.dtypes.float32, + tf.dtypes.float64, } @@ -34,11 +39,12 @@ def blend_np(image1, image2, factor): temp = image1 + scaled if factor >= 0.0 and factor <= 1.0: temp = np.round(temp) - return temp.astype("uint8") + return temp temp = np.round(np.clip(temp, 0.0, 255.0)) - return temp.astype("uint8") + return temp +@pytest.mark.usefixtures("maybe_run_functions_eagerly") @pytest.mark.parametrize("dtype", _DTYPES) def test_blend(dtype): image1 = tf.constant( @@ -64,13 +70,17 @@ def test_blend(dtype): ], ) - image1 = np.random.randint(0, 255, (4, 4, 3), np.uint8) - image2 = np.random.randint(0, 255, (4, 4, 3), np.uint8) + np.random.seed(0) + image1 = np.random.randint(0, 255, (3, 5, 5), np.uint8) + image2 = np.random.randint(0, 255, (3, 5, 5), np.uint8) + tf.random.set_seed(0) + factor = tf.random.uniform(shape=[], maxval=1, dtype=tf.dtypes.float32, seed=0) blended = compose_ops.blend( - tf.convert_to_tensor(image1), tf.convert_to_tensor(image2), 0.35 + tf.convert_to_tensor(image1), tf.convert_to_tensor(image2), factor ).numpy() - expected = blend_np(image1, image2, 0.35) + expected = blend_np(image1, image2, factor.numpy()) np.testing.assert_equal(blended, expected) + assert blended.dtype == expected.dtype if __name__ == "__main__":