Skip to content

Commit a64e05b

Browse files
committed
Use keras_utils instead of TensorFlow private API
1 parent 536fc16 commit a64e05b

File tree

2 files changed

+43
-1
lines changed

2 files changed

+43
-1
lines changed

tensorflow_addons/layers/deformable_conv2d.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
from typeguard import typechecked
2020
from tensorflow_addons.utils import types
2121
from tensorflow_addons.utils.resource_loader import LazySO
22-
from tensorflow.python.keras.utils import conv_utils
22+
import tensorflow_addons.utils.keras_utils as conv_utils
2323

2424
_deformable_conv2d_ops_so = LazySO("custom_ops/layers/_deformable_conv2d_ops.so")
2525

tensorflow_addons/utils/keras_utils.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,20 @@ def get_config(self):
6868
return {**base_config, **config}
6969

7070

71+
def normalize_padding(value):
72+
"""A copy of tensorflow.python.keras.util."""
73+
if isinstance(value, (list, tuple)):
74+
return value
75+
padding = value.lower()
76+
if padding not in {"valid", "same", "causal"}:
77+
raise ValueError(
78+
"The `padding` argument must be a list/tuple or one of "
79+
'"valid", "same" (or "causal", only for `Conv1D). '
80+
"Received: " + str(padding)
81+
)
82+
return padding
83+
84+
7185
def normalize_data_format(value):
7286
if value is None:
7387
value = tf.keras.backend.image_data_format()
@@ -143,6 +157,34 @@ def normalize_tuple(value, n, name):
143157
return value_tuple
144158

145159

160+
def conv_output_length(input_length, filter_size, padding, stride, dilation=1):
161+
"""Determines output length of a convolution given input length.
162+
163+
A copy of tensorflow.python.keras.util.
164+
165+
Arguments:
166+
input_length: integer.
167+
filter_size: integer.
168+
padding: one of "same", "valid", "full", "causal"
169+
stride: integer.
170+
dilation: dilation rate, integer.
171+
172+
Returns:
173+
The output length (integer).
174+
"""
175+
if input_length is None:
176+
return None
177+
assert padding in {"same", "valid", "full", "causal"}
178+
dilated_filter_size = filter_size + (filter_size - 1) * (dilation - 1)
179+
if padding in ["same", "causal"]:
180+
output_length = input_length
181+
elif padding == "valid":
182+
output_length = input_length - dilated_filter_size + 1
183+
elif padding == "full":
184+
output_length = input_length + dilated_filter_size - 1
185+
return (output_length + stride - 1) // stride
186+
187+
146188
def _hasattr(obj, attr_name):
147189
# If possible, avoid retrieving the attribute as the object might run some
148190
# lazy computation in it.

0 commit comments

Comments
 (0)