Skip to content

Commit cb52e05

Browse files
PyExtremefacaiy
authored andcommitted
Add docstring correlation cost (#514)
* Add Docstring CorrelationCost * Add CorrelationCost Documentation * Small reformat
1 parent cffee80 commit cb52e05

File tree

1 file changed

+31
-10
lines changed

1 file changed

+31
-10
lines changed

tensorflow_addons/layers/optical_flow.py

Lines changed: 31 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -27,15 +27,15 @@
2727

2828

2929
@tf.function
30-
def correlation_cost(input_a,
31-
input_b,
32-
kernel_size,
33-
max_displacement,
34-
stride_1,
35-
stride_2,
36-
pad,
37-
data_format='channels_last',
38-
name=None):
30+
def _correlation_cost(input_a,
31+
input_b,
32+
kernel_size,
33+
max_displacement,
34+
stride_1,
35+
stride_2,
36+
pad,
37+
data_format='channels_last',
38+
name=None):
3939
"""Correlation Cost Volume computation.
4040
4141
"FlowNet: Learning Optical Flow with Convolutional Networks"
@@ -141,6 +141,27 @@ def _correlation_cost_grad(op, grad_output):
141141

142142
@keras_utils.register_keras_custom_object
143143
class CorrelationCost(tf.keras.layers.Layer):
144+
"""Correlation Cost Layer.
145+
146+
This layer implements the correlation operation from FlowNet Learning
147+
Optical Flow with Convolutional Networks (Fischer et al.):
148+
https://arxiv.org/abs/1504.06
149+
150+
Args:
151+
kernel_size: An integer specifying the height and width of the
152+
patch used to compute the per-patch costs.
153+
max_displacement: An integer specifying the maximum search radius
154+
for each position.
155+
stride_1: An integer specifying the stride length in the input.
156+
stride_2: An integer specifying the stride length in the patch.
157+
pad: An integer specifying the paddings in height and width.
158+
data_format: Specifies the data format.
159+
Possible values are:
160+
"channels_last" float [batch, height, width, channels]
161+
"channels_first" float [batch, channels, height, width]
162+
Defaults to `"channels_last"`.
163+
"""
164+
144165
def __init__(self, kernel_size, max_displacement, stride_1, stride_2, pad,
145166
data_format, **kwargs):
146167
self.kernel_size = kernel_size
@@ -169,7 +190,7 @@ def call(self, inputs):
169190
input_a = tf.convert_to_tensor(inputs[0])
170191
input_b = tf.convert_to_tensor(inputs[1])
171192

172-
return correlation_cost(
193+
return _correlation_cost(
173194
input_a,
174195
input_b,
175196
kernel_size=self.kernel_size,

0 commit comments

Comments
 (0)