diff --git a/modules/README.md b/modules/README.md index 26387d2e34c..3f8dc5e42a3 100644 --- a/modules/README.md +++ b/modules/README.md @@ -22,7 +22,9 @@ $ cmake -D OPENCV_EXTRA_MODULES_PATH=/modules -D BUILD_opencv_/modules -Dopencv_dnn_superres=ON +``` + +## Models + +There are four models which are trained. (Not yet implemented!!) + +#### EDSR + +- Size of the model: +- This model was trained for <> iterations with a batch size of <> +- Link to model: +- Advantage: +- Disadvantage: +- Speed: + +#### ESPCN + +Trained models can be downloaded from [here](https://github.com/fannymonori/TF-ESPCN/tree/master/export). + +- Size of the model: ~100kb +- This model was trained for 100 iterations with a batch size of 32 +- Link to implementation code: https://github.com/fannymonori/TF-ESPCN +- x2, x3, x4 trained models available +- Advantage: It is tiny and fast, and still performs well. +- Disadvantage: Performs worse visually than newer, more robust models. +- Speed: + +#### FSRCNN + +- Size of the model: +- This model was trained for <> iterations with a batch size of <> +- Link to model: +- Advantage: +- Disadvantage: +- Speed: + +#### LapSRN + +Trained models can be downloaded from [here](https://github.com/fannymonori/TF-LapSRN/tree/master/export). + +- Size of the model: between 1-5Mb +- This model was trained for ~50 iterations with a batch size of 32 +- Link to implementation code: https://github.com/fannymonori/TF-LAPSRN +- x2, x4, x8 trained models available +- Advantage: The model can do multi-scale super-resolution with one forward pass. It can now support 2x, 4x, 8x, and [2x, 4x] and [2x, 4x, 8x] super-resolution. +- Disadvantage: It is a slower model. 
+- Speed \ No newline at end of file diff --git a/modules/dnn_superres/include/opencv2/dnn_superres.hpp b/modules/dnn_superres/include/opencv2/dnn_superres.hpp new file mode 100644 index 00000000000..f12c20ddd5c --- /dev/null +++ b/modules/dnn_superres/include/opencv2/dnn_superres.hpp @@ -0,0 +1,158 @@ +// This file is part of OpenCV project. +// It is subject to the license terms in the LICENSE file found in the top-level directory +// of this distribution and at http://opencv.org/license.html. + +#ifndef _OPENCV_DNN_SUPERRES_DNNSUPERRESIMPL_HPP_ +#define _OPENCV_DNN_SUPERRES_DNNSUPERRESIMPL_HPP_ + +#include +#include +#include "opencv2/dnn.hpp" + +/** @defgroup dnn_superres DNN used for super resolution + +This module contains functionality for upscaling an image via convolutional neural networks. +The following four models are implemented: + +- EDSR +- ESPCN +- FSRCNN +- LapSRN + +There is also functionality for simply upscaling by bilinear or bicubic interpolation. + +*/ + +namespace cv +{ +namespace dnn +{ +namespace dnn_superres +{ + //! @addtogroup dnn_superres + //! @{ + + /** @brief A class to upscale images via convolutional neural networks. 
+ The following four models are implemented: + + - edsr + - espcn + - fsrcnn + - lapsrn + */ + class CV_EXPORTS DnnSuperResImpl + { + private: + + /** @brief Net which holds the desired neural network + */ + Net net; + + std::string alg; //algorithm + + int sc; //scale factor + + /// @private + static int layer_loaded; + + void registerLayers(); + + void preprocess(const Mat inpImg, Mat &outpImg); + + void reconstruct_YCrCb(const Mat inpImg, const Mat origImg, Mat &outpImg, int scale); + + void reconstruct_YCrCb(const Mat inpImg, const Mat origImg, Mat &outpImg); + + void preprocess_YCrCb(const Mat inpImg, Mat &outpImg); + + public: + + /** @brief Empty constructor + */ + DnnSuperResImpl(); + + /** @brief Constructor which immediately sets the desired model + @param algo String containing one of the desired models: + - __edsr__ + - __espcn__ + - __fsrcnn__ + - __lapsrn__ + @param scale Integer specifying the upscale factor + */ + DnnSuperResImpl(std::string algo, int scale); + + /** @brief Read the model from the given path + @param path Path to the model file. + */ + void readModel(std::string path); + + /** @brief Read the model from the given path + @param weights Path to the model weights file. + @param definition Path to the model definition file. 
+ */ + void readModel(std::string weights, std::string definition); + + /** @brief Set desired model + @param algo String containing one of the desired models: + - __edsr__ + - __espcn__ + - __fsrcnn__ + - __lapsrn__ + @param scale Integer specifying the upscale factor + */ + void setModel(std::string algo, int scale); + + /** @brief Upsample via neural network + @param img Image to upscale + @param img_new Destination upscaled image + */ + void upsample(Mat img, Mat &img_new); + + /** @brief Upsample via neural network of multiple outputs + @param img Image to upscale + @param imgs_new Destination upscaled images + @param scale_factors Scaling factors of the output nodes + @param node_names Names of the output nodes in the neural network + */ + void upsample_multioutput(Mat img, std::vector &imgs_new, std::vector scale_factors, std::vector node_names); + + /** @brief Returns the scale factor of the model: + @return Current scale factor. + */ + int getScale(); + + /** @brief Returns the algorithm of the model: + @return Current algorithm. + */ + std::string getAlgorithm(); + + private: + /** @brief Class for importing DepthToSpace layer from the ESPCN model + */ + class DepthToSpace CV_FINAL : public cv::dnn::Layer + { + public: + + /// @private + DepthToSpace(const cv::dnn::LayerParams &params); + + /// @private + static cv::Ptr create(cv::dnn::LayerParams& params); + + /// @private + virtual bool getMemoryShapes(const std::vector > &inputs, + const int, + std::vector > &outputs, + std::vector > &) const CV_OVERRIDE; + + /// @private + virtual void forward(cv::InputArrayOfArrays inputs_arr, + cv::OutputArrayOfArrays outputs_arr, + cv::OutputArrayOfArrays) CV_OVERRIDE; + }; + }; + //! 
@} +} +} +} +#endif \ No newline at end of file diff --git a/modules/dnn_superres/samples/butterfly.png b/modules/dnn_superres/samples/butterfly.png new file mode 100644 index 00000000000..90c31f1759c Binary files /dev/null and b/modules/dnn_superres/samples/butterfly.png differ diff --git a/modules/dnn_superres/samples/dnn_superres.cpp b/modules/dnn_superres/samples/dnn_superres.cpp new file mode 100644 index 00000000000..3aadfb8f6dc --- /dev/null +++ b/modules/dnn_superres/samples/dnn_superres.cpp @@ -0,0 +1,76 @@ +#include + +#include + +using namespace std; +using namespace cv; +using namespace dnn; +using namespace dnn_superres; + +int main(int argc, char *argv[]) +{ + // Check for valid command line arguments, print usage + // if insufficient arguments were given. + if (argc < 4) { + cout << "usage: Arg 1: image | Path to image" << endl; + cout << "\t Arg 2: algorithm | bilinear, bicubic, edsr, espcn, fsrcnn or lapsrn" << endl; + cout << "\t Arg 3: scale | 2, 3 or 4 \n"; + cout << "\t Arg 4: path to model file \n"; + return -1; + } + + string img_path = string(argv[1]); + string algorithm = string(argv[2]); + int scale = atoi(argv[3]); + string path = ""; + + if( argc > 4) + path = string(argv[4]); + + // Load the image + Mat img = cv::imread(img_path); + Mat original_img(img); + if (img.empty()) + { + std::cerr << "Couldn't load image: " << img_path << "\n"; + return -2; + } + + //Make dnn super resolution instance + DnnSuperResImpl sr; + + Mat img_new; + + if(algorithm == "bilinear"){ + resize(img, img_new, Size(), scale, scale, INTER_LINEAR); + } + else if(algorithm == "bicubic") + { + resize(img, img_new, Size(), scale, scale, INTER_CUBIC); + } + else if(algorithm == "espcn" || algorithm == "lapsrn") + { + sr.readModel(path); + sr.setModel(algorithm, scale); + sr.upsample(img, img_new); + } + else{ //one of the neural networks + sr.setModel(algorithm, scale); + sr.upsample(img, img_new); + } + + if (img_new.empty()) + { + std::cerr << "Upsampling failed. 
\n"; + return -3; + } + cout << "Upsampling succeeded. \n"; + + // Display image + cv::namedWindow("Initial Image", WINDOW_AUTOSIZE); + cv::imshow("Initial Image", img_new); + //cv::imwrite("./saved.jpg", img_new); + cv::waitKey(0); + + return 0; +} \ No newline at end of file diff --git a/modules/dnn_superres/samples/dnn_superres_multioutput.cpp b/modules/dnn_superres/samples/dnn_superres_multioutput.cpp new file mode 100644 index 00000000000..829dc68347b --- /dev/null +++ b/modules/dnn_superres/samples/dnn_superres_multioutput.cpp @@ -0,0 +1,68 @@ +#include +#include +#include + +using namespace std; +using namespace cv; +using namespace dnn; +using namespace dnn_superres; + +int main(int argc, char *argv[]) +{ + // Check for valid command line arguments, print usage + // if insufficient arguments were given. + if (argc < 5) { + cout << "usage: Arg 1: image | Path to image" << endl; + cout << "\t Arg 2: scales in a format of 2,4,8\n"; + cout << "\t Arg 3: output node names in a format of nchw_output_0,nchw_output_1\n"; + cout << "\t Arg 4: path to model file \n"; + return -1; + } + + string img_path = string(argv[1]); + string scales_str = string(argv[2]); + string output_names_str = string(argv[3]); + std::string path = string(argv[4]); + + std::stringstream ss(scales_str); + std::vector scales; + std::string token; + char delim = ','; + while (std::getline(ss, token, delim)) { + scales.push_back(atoi(token.c_str())); + } + + ss = std::stringstream(output_names_str); + std::vector node_names; + while (std::getline(ss, token, delim)) { + node_names.push_back(token); + } + + // Load the image + Mat img = cv::imread(img_path); + Mat original_img(img); + if (img.empty()) + { + std::cerr << "Couldn't load image: " << img_path << "\n"; + return -2; + } + + //Make dnn super resolution instance + DnnSuperResImpl sr; + int scale = *max_element(scales.begin(), scales.end()); + std::vector outputs; + sr.readModel(path); + sr.setModel("lapsrn", scale); + + 
sr.upsample_multioutput(img, outputs, scales, node_names); + + for(unsigned int i = 0; i < outputs.size(); i++) + { + cv::namedWindow("Upsampled image", WINDOW_AUTOSIZE); + cv::imshow("Upsampled image", outputs[i]); + //cv::imwrite("./saved.jpg", img_new); + cv::waitKey(0); + } + + return 0; +} \ No newline at end of file diff --git a/modules/dnn_superres/src/dnn_superres.cpp b/modules/dnn_superres/src/dnn_superres.cpp new file mode 100644 index 00000000000..ac70b488af3 --- /dev/null +++ b/modules/dnn_superres/src/dnn_superres.cpp @@ -0,0 +1,319 @@ +// This file is part of OpenCV project. +// It is subject to the license terms in the LICENSE file found in the top-level directory +// of this distribution and at http://opencv.org/license.html. + +#include "precomp.hpp" + +#include "opencv2/dnn_superres.hpp" + +namespace cv +{ + namespace dnn + { + namespace dnn_superres + { + + int DnnSuperResImpl::layer_loaded = 0; + + DnnSuperResImpl::DnnSuperResImpl() + { + if( !this->layer_loaded ) + { + layer_loaded = true; + registerLayers(); + } + } + + DnnSuperResImpl::DnnSuperResImpl(std::string algo, int scale) : alg(algo), sc(scale) + { + if( !this->layer_loaded ) + { + layer_loaded = true; + registerLayers(); + } + } + + void DnnSuperResImpl::registerLayers() + { + //Register custom layer that implements pixel shuffling + std::string name = "DepthToSpace"; + dnn::LayerParams layerParams = dnn::LayerParams(); + cv::dnn::LayerFactory::registerLayer("DepthToSpace", DepthToSpace::create); + } + + void DnnSuperResImpl::readModel(std::string path) + { + if ( path.size() ) + { + this->net = dnn::readNetFromTensorflow(path); + std::cout << "Successfully loaded model. \n"; + } + else + { + std::cout << "Could not load model. \n"; + } + } + + void DnnSuperResImpl::readModel(std::string weights, std::string definition) + { + if ( weights.size() && definition.size() ) + { + this->net = dnn::readNetFromTensorflow(weights, definition); + std::cout << "Successfully loaded model. 
\n"; + } + else + { + std::cout << "Could not load model. \n"; + } + } + + void DnnSuperResImpl::setModel(std::string algo, int scale) + { + this->sc = scale; + this->alg = algo; + } + + void DnnSuperResImpl::upsample(Mat img, Mat &img_new) + { + if( !net.empty() ) + { + if ( this->alg == "espcn" || this->alg == "lapsrn" ) + { + //Preprocess the image: convert to YCrCb float image and normalize + Mat preproc_img; + preprocess_YCrCb(img, preproc_img); + + //Split the image: only the Y channel is used for inference + Mat ycbcr_channels[3]; + split(preproc_img, ycbcr_channels); + + Mat Y = ycbcr_channels[0]; + + //Create blob from image so it has size 1,1,Width,Height + cv::Mat blob; + dnn::blobFromImage(Y, blob, 1.0); + + //Get the HR output + this->net.setInput(blob); + Mat blob_output = this->net.forward(); + + //Convert from blob + std::vector model_outs; + dnn::imagesFromBlob(blob_output, model_outs); + Mat out_img = model_outs[0]; + + //Reconstruct: upscale the Cr and Cb space and merge the three layer + reconstruct_YCrCb(out_img, preproc_img, img_new, this->sc); + } + else + { + //get blob + //Mat blob = blobFromImage(img, 1.0); + //std::cout << "Made a blob. \n"; + + //get prediction + //net.setInput(blob); + //img_new = net.forward(); + //std::cout << "Made a Prediction. \n"; + } + } + else + { + std::cout << "Model not specified. Please set model via setModel(). \n"; + } + } + + void DnnSuperResImpl::upsample_multioutput(Mat img, std::vector &imgs_new, std::vector scale_factors, std::vector node_names) + { + CV_Assert(scale_factors.size() == node_names.size()); + CV_Assert(!scale_factors.empty()); + CV_Assert(!node_names.empty()); + + if ( this->alg != "lapsrn" ) + { + std::cout << "Only LapSRN support multiscale upsampling for now!" 
<< std::endl; + return; + } + + if( !net.empty() ) + { + if ( this->alg == "lapsrn" ) + { + Mat orig = img; + + //Preprocess the image: convert to YCrCb float image and normalize + Mat preproc_img; + preprocess_YCrCb(orig, preproc_img); + + //Split the image: only the Y channel is used for inference + Mat ycbcr_channels[3]; + split(preproc_img, ycbcr_channels); + + Mat Y = ycbcr_channels[0]; + + //Create blob from image so it has size 1,1,Width,Height + cv::Mat blob; + dnn::blobFromImage(Y, blob, 1.0); + + //Get the HR outputs + std::vector outputs_blobs; + this->net.setInput(blob); + this->net.forward(outputs_blobs, node_names); + + for(unsigned int i = 0; i < scale_factors.size(); i++) + { + std::vector model_outs; + dnn::imagesFromBlob(outputs_blobs[i], model_outs); + Mat out_img = model_outs[0]; + Mat reconstructed; + + reconstruct_YCrCb(out_img, preproc_img, reconstructed, scale_factors[i]); + + imgs_new.push_back(reconstructed); + } + } + } + else + { + std::cout << "Model not specified. Please set model via setModel(). \n"; + } + } + + int DnnSuperResImpl::getScale() + { + return this->sc; + } + + std::string DnnSuperResImpl::getAlgorithm() + { + return this->alg; + } + + void DnnSuperResImpl::preprocess_YCrCb(const Mat inpImg, Mat &outImg) + { + if ( inpImg.type() == CV_8UC1 ) + { + Mat ycrcb; + inpImg.convertTo(outImg, CV_32F, 1.0 / 255.0); + } + else if ( inpImg.type() == CV_32FC1 ) + { + Mat ycrcb; + inpImg.convertTo(outImg, CV_32F, 1.0 / 255.0); + } + else if ( inpImg.type() == CV_32FC3 ) + { + Mat img_float; + inpImg.convertTo(img_float, CV_32F, 1.0 / 255.0); + cvtColor(img_float, outImg, COLOR_BGR2YCrCb); + } + else if ( inpImg.type() == CV_8UC3 ) + { + Mat ycrcb; + cvtColor(inpImg, ycrcb, COLOR_BGR2YCrCb); + ycrcb.convertTo(outImg, CV_32F, 1.0 / 255.0); + } + else + { + std::cout << "Not supported image type!" 
<< std::endl; + } + } + + void DnnSuperResImpl::reconstruct_YCrCb(const Mat inpImg, const Mat origImg, Mat &outImg, int scale) + { + if ( origImg.type() == CV_32FC3 ) + { + Mat orig_channels[3]; + split(origImg, orig_channels); + + Mat Cr, Cb; + cv::resize(orig_channels[1], Cr, cv::Size(), scale, scale); + cv::resize(orig_channels[2], Cb, cv::Size(), scale, scale); + + std::vector channels; + channels.push_back(inpImg); + channels.push_back(Cr); + channels.push_back(Cb); + + Mat merged_img; + merge(channels, merged_img); + + Mat merged_8u_img; + merged_img.convertTo(merged_8u_img, CV_8U, 255.0); + + cvtColor(merged_8u_img, outImg, COLOR_YCrCb2BGR); + } + else if ( origImg.type() == CV_32FC1 ) + { + inpImg.convertTo(outImg, CV_8U, 255.0); + } + else + { + std::cout << "Not supported image type!" << std::endl; + } + } + + DnnSuperResImpl::DepthToSpace::DepthToSpace(const cv::dnn::LayerParams ¶ms) : Layer(params) + { + + } + + cv::Ptr DnnSuperResImpl::DepthToSpace::create(cv::dnn::LayerParams ¶ms) + { + return cv::Ptr(new DepthToSpace(params)); + } + + bool DnnSuperResImpl::DepthToSpace::getMemoryShapes(const std::vector > &inputs, + const int, + std::vector > &outputs, + std::vector > &) const + { + std::vector outShape(4); + outShape[0] = inputs[0][0]; + outShape[1] = 1; + outShape[2] = static_cast(sqrt(inputs[0][1])) * inputs[0][2]; + outShape[3] = static_cast(sqrt(inputs[0][1])) * inputs[0][3]; + + outputs.assign(4, outShape); + + return false; + } + + void DnnSuperResImpl::DepthToSpace::forward(cv::InputArrayOfArrays inputs_arr, + cv::OutputArrayOfArrays outputs_arr, + cv::OutputArrayOfArrays) + { + std::vector inputs, outputs; + inputs_arr.getMatVector(inputs); + outputs_arr.getMatVector(outputs); + cv::Mat &inp = inputs[0]; + cv::Mat &out = outputs[0]; + const float *inpData = (float *) inp.data; + float *outData = (float *) out.data; + + const int height = out.size[2]; + const int width = out.size[3]; + + const int inpHeight = inp.size[2]; + const int inpWidth 
= inp.size[3]; + + int scale = int(sqrt(inp.size[1])); + + int count = 0; + for (int y = 0; y < height; y++) + { + for (int x = 0; x < width; x++) + { + int x_coord = static_cast(floor((y / scale))); + int y_coord = static_cast(floor((x / scale))); + int c_coord = scale * (y % scale) + (x % scale); + + int index = (((c_coord * inpHeight) + x_coord) * inpWidth) + y_coord; + outData[count] = inpData[index]; + count = count + 1; + } + } + } + } + } +} diff --git a/modules/dnn_superres/src/precomp.hpp b/modules/dnn_superres/src/precomp.hpp new file mode 100644 index 00000000000..f82eb2eaf1b --- /dev/null +++ b/modules/dnn_superres/src/precomp.hpp @@ -0,0 +1,17 @@ +// This file is part of OpenCV project. +// It is subject to the license terms in the LICENSE file found in the top-level directory +// of this distribution and at http://opencv.org/license.html. +#ifndef __OPENCV_DNN_SUPERRES_PRECOMP_HPP__ +#define __OPENCV_DNN_SUPERRES_PRECOMP_HPP__ + +#include +#include +#include +#include +#include +#include +#include + +#include "opencv2/core.hpp" + +#endif // __OPENCV_DNN_SUPERRES_PRECOMP_HPP__ diff --git a/modules/dnn_superres/test/test_dnn_superres.cpp b/modules/dnn_superres/test/test_dnn_superres.cpp new file mode 100644 index 00000000000..57dc64321aa --- /dev/null +++ b/modules/dnn_superres/test/test_dnn_superres.cpp @@ -0,0 +1,215 @@ +// This file is part of OpenCV project. +// It is subject to the license terms in the LICENSE file found in the top-level directory +// of this distribution and at http://opencv.org/license.html. 
+ +#include "test_precomp.hpp" + +namespace opencv_test +{ + namespace + { + + const std::string DNN_SUPERRES_DIR = "dnn_superres"; + const std::string IMAGE_FILENAME = "butterfly.png"; + + + /****************************************************************************************\ + * Test single output models * + \****************************************************************************************/ + + class CV_DnnSuperResSingleOutputTest : public cvtest::BaseTest + { + public: + CV_DnnSuperResSingleOutputTest(); + + protected: + Ptr dnn_sr; + + virtual void run(int); + + void runOneModel(std::string algorithm, int scale, std::string model_filename); + }; + + void CV_DnnSuperResSingleOutputTest::runOneModel(std::string algorithm, int scale, std::string model_filename) + { + std::string path = std::string(ts->get_data_path()) + DNN_SUPERRES_DIR + "/" + IMAGE_FILENAME; + + Mat img = imread(path); + if (img.empty()) + { + ts->printf(cvtest::TS::LOG, "Test image not found!\n"); + ts->set_failed_test_info(cvtest::TS::FAIL_INVALID_TEST_DATA); + return; + } + + std::string pb_path = std::string(ts->get_data_path()) + DNN_SUPERRES_DIR + "/" + model_filename; + + this->dnn_sr->readModel(pb_path); + + this->dnn_sr->setModel(algorithm, scale); + + if (this->dnn_sr->getScale() != scale) + { + ts->printf(cvtest::TS::LOG, + "Scale factor could not be set for scale algorithm %s and scale factor %d!\n", + algorithm.c_str(), scale); + ts->set_failed_test_info(cvtest::TS::FAIL_BAD_ACCURACY); + return; + } + + if (this->dnn_sr->getAlgorithm() != algorithm) + { + ts->printf(cvtest::TS::LOG, "Algorithm could not be set for scale algorithm %s and scale factor %d!\n", + algorithm.c_str(), scale); + ts->set_failed_test_info(cvtest::TS::FAIL_BAD_ACCURACY); + return; + } + + Mat img_new; + this->dnn_sr->upsample(img, img_new); + + if (img_new.empty()) + { + ts->printf(cvtest::TS::LOG, + "Could not perform upsampling for scale algorithm %s and scale factor %d!\n", + algorithm.c_str(), 
scale); + ts->set_failed_test_info(cvtest::TS::FAIL_BAD_ACCURACY); + return; + } + + int new_cols = img.cols * scale; + int new_rows = img.rows * scale; + if (img_new.cols != new_cols || img_new.rows != new_rows) + { + ts->printf(cvtest::TS::LOG, "Dimensions are not correct for scale algorithm %s and scale factor %d!\n", + algorithm.c_str(), scale); + ts->set_failed_test_info(cvtest::TS::FAIL_BAD_ACCURACY); + return; + } + } + + CV_DnnSuperResSingleOutputTest::CV_DnnSuperResSingleOutputTest() + { + dnn_sr = makePtr(); + } + + void CV_DnnSuperResSingleOutputTest::run(int) + { + //x2 + runOneModel("espcn", 2, "ESPCN_x2.pb"); + + //x3 + runOneModel("espcn", 3, "ESPCN_x3.pb"); + + //x4 + runOneModel("espcn", 4, "ESPCN_x4.pb"); + } + + TEST(CV_DnnSuperResSingleOutputTest, accuracy) + { + CV_DnnSuperResSingleOutputTest test; + test.safe_run(); + } + + /****************************************************************************************\ + * Test multi output models * + \****************************************************************************************/ + + class CV_DnnSuperResMultiOutputTest : public cvtest::BaseTest + { + public: + CV_DnnSuperResMultiOutputTest(); + + protected: + Ptr dnn_sr; + + virtual void run(int); + + void runOneModel(std::string algorithm, int scale, std::string model_filename, + std::vector scales, std::vector node_names); + }; + + void CV_DnnSuperResMultiOutputTest::runOneModel(std::string algorithm, int scale, std::string model_filename, + std::vector scales, std::vector node_names) + { + std::string path = std::string(ts->get_data_path()) + DNN_SUPERRES_DIR + "/" + IMAGE_FILENAME; + + Mat img = imread(path); + if ( img.empty() ) + { + ts->printf(cvtest::TS::LOG, "Test image not found!\n"); + ts->set_failed_test_info(cvtest::TS::FAIL_INVALID_TEST_DATA); + return; + } + + std::string pb_path = std::string(ts->get_data_path()) + DNN_SUPERRES_DIR + "/" + model_filename; + + this->dnn_sr->readModel(pb_path); + + 
this->dnn_sr->setModel(algorithm, scale); + + if ( this->dnn_sr->getScale() != scale ) + { + ts->printf(cvtest::TS::LOG, + "Scale factor could not be set for scale algorithm %s and scale factor %d!\n", + algorithm.c_str(), scale); + ts->set_failed_test_info(cvtest::TS::FAIL_BAD_ACCURACY); + return; + } + + if ( this->dnn_sr->getAlgorithm() != algorithm ) + { + ts->printf(cvtest::TS::LOG, "Algorithm could not be set for scale algorithm %s and scale factor %d!\n", + algorithm.c_str(), scale); + ts->set_failed_test_info(cvtest::TS::FAIL_BAD_ACCURACY); + return; + } + + std::vector outputs; + this->dnn_sr->upsample_multioutput(img, outputs, scales, node_names); + + for(unsigned int i = 0; i < outputs.size(); i++) + { + if( outputs[i].empty() ) + { + ts->printf(cvtest::TS::LOG, + "Could not perform upsampling for scale algorithm %s and scale factor %d!\n", + algorithm.c_str(), scale); + ts->set_failed_test_info(cvtest::TS::FAIL_BAD_ACCURACY); + return; + } + + int new_cols = img.cols * scales[i]; + int new_rows = img.rows * scales[i]; + + if ( outputs[i].cols != new_cols || outputs[i].rows != new_rows ) + { + ts->printf(cvtest::TS::LOG, "Dimensions are not correct for scale algorithm %s and scale factor %d!\n", + algorithm.c_str(), scale); + ts->set_failed_test_info(cvtest::TS::FAIL_BAD_ACCURACY); + return; + } + } + } + + CV_DnnSuperResMultiOutputTest::CV_DnnSuperResMultiOutputTest() + { + dnn_sr = makePtr(); + } + + void CV_DnnSuperResMultiOutputTest::run(int) + { + //LAPSRN + //x4 + std::vector names_4x {"NCHW_output_2x", "NCHW_output_4x"}; + std::vector scales_4x {2, 4}; + runOneModel("lapsrn", 4, "LapSRN_x4.pb", scales_4x, names_4x); + } + + TEST(CV_DnnSuperResMultiOutputTest, accuracy) + { + CV_DnnSuperResMultiOutputTest test; + test.safe_run(); + } + +}} \ No newline at end of file diff --git a/modules/dnn_superres/test/test_main.cpp b/modules/dnn_superres/test/test_main.cpp new file mode 100644 index 00000000000..0e51ddfd050 --- /dev/null +++ 
b/modules/dnn_superres/test/test_main.cpp @@ -0,0 +1,6 @@ +// This file is part of OpenCV project. +// It is subject to the license terms in the LICENSE file found in the top-level directory +// of this distribution and at http://opencv.org/license.html. +#include "test_precomp.hpp" + +CV_TEST_MAIN("cv") diff --git a/modules/dnn_superres/test/test_precomp.hpp b/modules/dnn_superres/test/test_precomp.hpp new file mode 100644 index 00000000000..2afb739eaed --- /dev/null +++ b/modules/dnn_superres/test/test_precomp.hpp @@ -0,0 +1,15 @@ +// This file is part of OpenCV project. +// It is subject to the license terms in the LICENSE file found in the top-level directory +// of this distribution and at http://opencv.org/license.html. + +#ifndef __OPENCV_TEST_PRECOMP_HPP__ +#define __OPENCV_TEST_PRECOMP_HPP__ + +#include "opencv2/ts.hpp" +#include "opencv2/dnn_superres.hpp" + +namespace opencv_test { + using namespace cv::dnn::dnn_superres; +} + +#endif diff --git a/modules/dnn_superres/tutorials/dnn_superres_tutorial.markdown b/modules/dnn_superres/tutorials/dnn_superres_tutorial.markdown new file mode 100644 index 00000000000..b2f4e70aac8 --- /dev/null +++ b/modules/dnn_superres/tutorials/dnn_superres_tutorial.markdown @@ -0,0 +1,41 @@ +Super Resolution using CNNs {#tutorial_dnn_superres} +=========================== + +# Building + +Run the following command to build this module: + +```make +cmake -DOPENCV_EXTRA_MODULES_PATH=/modules -Dopencv_dnn_superres=ON +``` + +# Super resolution sample code + +See the "dnn_superres" in the samples for an idea of how to run it. For example: + +``` +dnn_superres/samples/dnn_superres.cpp ./butterfly.png edsr 2 +``` + +# Single output + +Run the sample code to do single output super-resolution with the implemented models.\ +ESPCN model can now support 2x, 3x, and 4x super resolution. 
+ +``` +./bin/example_dnn_superres_dnn_superres path/to/image.png espcn 2 \ +/path/to/opencv_contrib/modules/dnn_superres/models/ESPCN_x2.pb +``` + +# Multiple output + +LapSRN supports multiple outputs with one forward pass. It can now support 2x, 4x, 8x, and (2x, 4x) and (2x, 4x, 8x) super-resolution.\ +The uploaded trained model files have the following output node names: +- 2x model: NCHW_output +- 4x model: NCHW_output_2x, NCHW_output_4x +- 8x model: NCHW_output_2x, NCHW_output_4x, NCHW_output_8x + +``` +./bin/example_dnn_superres_dnn_superres_multioutput path/to/image.png 2,4 NCHW_output_2x,NCHW_output_4x \ +path/to/opencv_contrib/modules/dnn_superres/models/LapSRN_x4.pb +``` \ No newline at end of file