From 2cad108ae28e891315cec8a0706d4cc5710ff27c Mon Sep 17 00:00:00 2001 From: fsx950223 Date: Sun, 7 Jun 2020 11:05:32 +0800 Subject: [PATCH] use raw ops --- tensorflow_addons/image/transform_ops.py | 47 ++---------------------- 1 file changed, 3 insertions(+), 44 deletions(-) diff --git a/tensorflow_addons/image/transform_ops.py b/tensorflow_addons/image/transform_ops.py index 75ade66760..844ef4dcba 100644 --- a/tensorflow_addons/image/transform_ops.py +++ b/tensorflow_addons/image/transform_ops.py @@ -16,13 +16,11 @@ import tensorflow as tf from tensorflow_addons.image import utils as img_utils -from tensorflow_addons.utils.resource_loader import LazySO from tensorflow_addons.utils.types import TensorLike from tensorflow_addons.image.utils import wrap, unwrap from typing import Optional -_image_so = LazySO("custom_ops/image/_image_ops.so") _IMAGE_DTYPES = { tf.dtypes.uint8, @@ -34,7 +32,6 @@ } -@tf.function def transform( images: TensorLike, transforms: TensorLike, @@ -108,10 +105,10 @@ def transform( % len(transforms.get_shape()) ) - output = _image_so.ops.addons_image_projective_transform_v2( - images, - output_shape=output_shape, + output = tf.raw_ops.ImageProjectiveTransformV2( + images=images, transforms=transforms, + output_shape=output_shape, interpolation=interpolation.upper(), ) return img_utils.from_4D_image(output, original_ndims) @@ -271,44 +268,6 @@ def angles_to_projective_transforms( ) -@tf.RegisterGradient("Addons>ImageProjectiveTransformV2") -def _image_projective_transform_grad(op, grad): - """Computes the gradient for ImageProjectiveTransform.""" - images = op.inputs[0] - transforms = op.inputs[1] - interpolation = op.get_attr("interpolation") - - image_or_images = tf.convert_to_tensor(images, name="images") - transform_or_transforms = tf.convert_to_tensor( - transforms, name="transforms", dtype=tf.dtypes.float32 - ) - - if image_or_images.dtype.base_dtype not in _IMAGE_DTYPES: - raise ValueError("Invalid dtype %s." % image_or_images.dtype) - if len(transform_or_transforms.get_shape()) == 1: - transforms = transform_or_transforms[None] - elif len(transform_or_transforms.get_shape()) == 2: - transforms = transform_or_transforms - else: - transforms = transform_or_transforms - raise ValueError( - "transforms should have rank 1 or 2, but got rank %d" - % len(transforms.get_shape()) - ) - - # Invert transformations - transforms = flat_transforms_to_matrices(transforms=transforms) - inverse = tf.linalg.inv(transforms) - transforms = matrices_to_flat_transforms(inverse) - output = _image_so.ops.addons_image_projective_transform_v2( - images=grad, - transforms=transforms, - output_shape=tf.shape(image_or_images)[1:3], - interpolation=interpolation, - ) - return [output, None, None] - - def rotate( images: TensorLike, angles: TensorLike,