Skip to content

Conversation

@kohei-us
Copy link

@kohei-us kohei-us commented Mar 7, 2019

This is related to opencv/opencv#13989.

@dkurt
Copy link
Member

dkurt commented Mar 7, 2019

@kohei-us, Please show us how these test files are generated. Add a modified https://github.com/opencv/opencv_extra/blob/master/testdata/dnn/onnx/generate_onnx_models.py

@kohei-us
Copy link
Author

kohei-us commented Mar 8, 2019

@kurt I used Keras to create a model and then convert it to onnx. I noticed that the script uses PyTorch to generate models. Would a use of Keras be okay in that script, or is PyTorch required?

@dkurt
Copy link
Member

dkurt commented Mar 8, 2019

@kohei-us, If it's possible to do all the conversions in python I'd like to suggest to add Keras model in this script. Please use try except scope to check if TensorFlow exists.

@kohei-us
Copy link
Author

kohei-us commented Mar 8, 2019

I've now stumbled upon one major issue.

Originally, I created a keras model and converted to onnx using an older version of onnxmltools, and that converter added the output_shape attribute to ConvTranspose, which is what this test case is designed to test.

However, since then the onnxmltools has changed quite a bit, and when using their latest code (it's now been split into a separate project called keras-onnx), the converted onnx model no longer contains the output_shape attribute... :-(

I have also tried using PyTorch to create an onnx model with ConvTranspose, but that one doesn't have output_shape either.

Hmm...

@dkurt
Copy link
Member

dkurt commented Mar 8, 2019

@kohei-us, May I ask you to add a line

std::cout << layerParams << std::endl;

at
https://github.com/opencv/opencv/blob/c3cf35ab63c04fb1d7b2f6760128f42c20cac0e1/modules/dnn/src/onnx/onnx_importer.cpp#L585-L594?

So we can check all the parameters except output_shape. Please show the output.

@kohei-us
Copy link
Author

@dkurt Sure thing. Here is the output for the new test data I've added:

dilation_h : 1
dilation_w : 1
group : 1
kernel_h : 4
kernel_w : 4
output_shape : -1, 1, 83, 83
pad_mode : SAME
stride_h : 3
stride_w : 3

@dkurt
Copy link
Member

dkurt commented Mar 11, 2019

@kohei-us, May I also ask you to specify which size has input blob for this layer?

@kohei-us
Copy link
Author

@dkurt Sure. It's (1, 1, 28, 28).

@dkurt
Copy link
Member

dkurt commented Mar 11, 2019

@kohei-us, This way deconvolution layer uses the following formula to compute output shape:

        else if (padMode == "SAME")
        {
            outH = stride.height * (inpH - 1) + 1 + adjustPad.height;
            outW = stride.width * (inpW - 1) + 1 + adjustPad.width;
        }

So

3 * (28 - 1) + 1 + 0 = 82

However user specified output shape 83 that means we need to add extra padding called adjust padding:

3 * (28 - 1) + 1 + 1 = 83

May I ask you to check how OpenCV manage it for TensorFlow models: https://github.com/opencv/opencv/blob/master/modules/dnn/src/tensorflow/tf_importer.cpp:

            const int outH = outShape.at<int>(1);
            const int outW = outShape.at<int>(2);
            if (layerParams.get<String>("pad_mode") == "SAME")
            {
                layerParams.set("adj_w", (outW - 1) % strideX);
                layerParams.set("adj_h", (outH - 1) % strideY);
            }
            else if (layerParams.get<String>("pad_mode") == "VALID")
            {
                layerParams.set("adj_w", (outW - kernelW) % strideX);
                layerParams.set("adj_h", (outH - kernelH) % strideY);
}

So you can just read output shape and compute adjust padding:

                layerParams.set("adj_w", (83 - 1) % 3);  // == 1
                layerParams.set("adj_h", (83 - 1) % 3);

@kohei-us kohei-us force-pushed the onnx-conv-transpose-output-shape branch from 50286a5 to 46f1cbd Compare March 26, 2019 14:50
@kohei-us kohei-us changed the base branch from master to 3.4 March 26, 2019 14:50
@opencv-pushbot opencv-pushbot merged commit 46f1cbd into opencv:3.4 Mar 26, 2019
@alalek alalek mentioned this pull request Mar 26, 2019
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants