diff --git a/tensorflow_addons/layers/optical_flow.py b/tensorflow_addons/layers/optical_flow.py index d8ef483fa8..7fcf1c0971 100644 --- a/tensorflow_addons/layers/optical_flow.py +++ b/tensorflow_addons/layers/optical_flow.py @@ -27,15 +27,15 @@ @tf.function -def correlation_cost(input_a, - input_b, - kernel_size, - max_displacement, - stride_1, - stride_2, - pad, - data_format='channels_last', - name=None): +def _correlation_cost(input_a, + input_b, + kernel_size, + max_displacement, + stride_1, + stride_2, + pad, + data_format='channels_last', + name=None): """Correlation Cost Volume computation. "FlowNet: Learning Optical Flow with Convolutional Networks" @@ -141,6 +141,27 @@ def _correlation_cost_grad(op, grad_output): @keras_utils.register_keras_custom_object class CorrelationCost(tf.keras.layers.Layer): + """Correlation Cost Layer. + + This layer implements the correlation operation from FlowNet Learning + Optical Flow with Convolutional Networks (Fischer et al.): + https://arxiv.org/abs/1504.06 + + Args: + kernel_size: An integer specifying the height and width of the + patch used to compute the per-patch costs. + max_displacement: An integer specifying the maximum search radius + for each position. + stride_1: An integer specifying the stride length in the input. + stride_2: An integer specifying the stride length in the patch. + pad: An integer specifying the paddings in height and width. + data_format: Specifies the data format. + Possible values are: + "channels_last" float [batch, height, width, channels] + "channels_first" float [batch, channels, height, width] + Defaults to `"channels_last"`. + """ + def __init__(self, kernel_size, max_displacement, stride_1, stride_2, pad, data_format, **kwargs): self.kernel_size = kernel_size @@ -169,7 +190,7 @@ def call(self, inputs): input_a = tf.convert_to_tensor(inputs[0]) input_b = tf.convert_to_tensor(inputs[1]) - return correlation_cost( + return _correlation_cost( input_a, input_b, kernel_size=self.kernel_size,