|
27 | 27 |
|
28 | 28 |
|
29 | 29 | @tf.function |
30 | | -def correlation_cost(input_a, |
31 | | - input_b, |
32 | | - kernel_size, |
33 | | - max_displacement, |
34 | | - stride_1, |
35 | | - stride_2, |
36 | | - pad, |
37 | | - data_format='channels_last', |
38 | | - name=None): |
| 30 | +def _correlation_cost(input_a, |
| 31 | + input_b, |
| 32 | + kernel_size, |
| 33 | + max_displacement, |
| 34 | + stride_1, |
| 35 | + stride_2, |
| 36 | + pad, |
| 37 | + data_format='channels_last', |
| 38 | + name=None): |
39 | 39 | """Correlation Cost Volume computation. |
40 | 40 |
|
41 | 41 | "FlowNet: Learning Optical Flow with Convolutional Networks" |
@@ -141,6 +141,27 @@ def _correlation_cost_grad(op, grad_output): |
141 | 141 |
|
142 | 142 | @keras_utils.register_keras_custom_object |
143 | 143 | class CorrelationCost(tf.keras.layers.Layer): |
| 144 | + """Correlation Cost Layer. |
| 145 | +
|
| 146 | + This layer implements the correlation operation from FlowNet Learning |
| 147 | + Optical Flow with Convolutional Networks (Fischer et al.): |
| 148 | + https://arxiv.org/abs/1504.06 |
| 149 | +
|
| 150 | + Args: |
| 151 | + kernel_size: An integer specifying the height and width of the |
| 152 | + patch used to compute the per-patch costs. |
| 153 | + max_displacement: An integer specifying the maximum search radius |
| 154 | + for each position. |
| 155 | + stride_1: An integer specifying the stride length in the input. |
| 156 | + stride_2: An integer specifying the stride length in the patch. |
| 157 | + pad: An integer specifying the paddings in height and width. |
| 158 | + data_format: Specifies the data format. |
| 159 | + Possible values are: |
| 160 | + "channels_last" float [batch, height, width, channels] |
| 161 | + "channels_first" float [batch, channels, height, width] |
| 162 | + Defaults to `"channels_last"`. |
| 163 | + """ |
| 164 | + |
144 | 165 | def __init__(self, kernel_size, max_displacement, stride_1, stride_2, pad, |
145 | 166 | data_format, **kwargs): |
146 | 167 | self.kernel_size = kernel_size |
@@ -169,7 +190,7 @@ def call(self, inputs): |
169 | 190 | input_a = tf.convert_to_tensor(inputs[0]) |
170 | 191 | input_b = tf.convert_to_tensor(inputs[1]) |
171 | 192 |
|
172 | | - return correlation_cost( |
| 193 | + return _correlation_cost( |
173 | 194 | input_a, |
174 | 195 | input_b, |
175 | 196 | kernel_size=self.kernel_size, |
|
0 commit comments