@@ -68,6 +68,20 @@ def get_config(self):
6868 return {** base_config , ** config }
6969
7070
71+ def normalize_padding (value ):
72+ """A copy of tensorflow.python.keras.util."""
73+ if isinstance (value , (list , tuple )):
74+ return value
75+ padding = value .lower ()
76+ if padding not in {"valid" , "same" , "causal" }:
77+ raise ValueError (
78+ "The `padding` argument must be a list/tuple or one of "
79+ '"valid", "same" (or "causal", only for `Conv1D). '
80+ "Received: " + str (padding )
81+ )
82+ return padding
83+
84+
7185def normalize_data_format (value ):
7286 if value is None :
7387 value = tf .keras .backend .image_data_format ()
@@ -143,6 +157,34 @@ def normalize_tuple(value, n, name):
143157 return value_tuple
144158
145159
160+ def conv_output_length (input_length , filter_size , padding , stride , dilation = 1 ):
161+ """Determines output length of a convolution given input length.
162+
163+ A copy of tensorflow.python.keras.util.
164+
165+ Arguments:
166+ input_length: integer.
167+ filter_size: integer.
168+ padding: one of "same", "valid", "full", "causal"
169+ stride: integer.
170+ dilation: dilation rate, integer.
171+
172+ Returns:
173+ The output length (integer).
174+ """
175+ if input_length is None :
176+ return None
177+ assert padding in {"same" , "valid" , "full" , "causal" }
178+ dilated_filter_size = filter_size + (filter_size - 1 ) * (dilation - 1 )
179+ if padding in ["same" , "causal" ]:
180+ output_length = input_length
181+ elif padding == "valid" :
182+ output_length = input_length - dilated_filter_size + 1
183+ elif padding == "full" :
184+ output_length = input_length + dilated_filter_size - 1
185+ return (output_length + stride - 1 ) // stride
186+
187+
146188def _hasattr (obj , attr_name ):
147189 # If possible, avoid retrieving the attribute as the object might run some
148190 # lazy computation in it.
0 commit comments