diff --git a/CMakeLists.txt b/CMakeLists.txt new file mode 100644 index 00000000000..d1609f8e820 --- /dev/null +++ b/CMakeLists.txt @@ -0,0 +1,24 @@ +cmake_minimum_required(VERSION 2.8) +project(torchvision) +set(CMAKE_CXX_STANDARD 11) + +find_package(Torch REQUIRED) + +file(GLOB_RECURSE HEADERS torchvision/csrc/vision.h) +file(GLOB_RECURSE MODELS_HEADERS torchvision/csrc/models/*.h) +file(GLOB_RECURSE MODELS_SOURCES torchvision/csrc/models/*.h torchvision/csrc/models/*.cpp) + +add_library (${PROJECT_NAME} SHARED ${MODELS_SOURCES}) +target_link_libraries(${PROJECT_NAME} "${TORCH_LIBRARIES}") + +add_executable(convertmodels torchvision/csrc/convert_models/convert_models.cpp) +target_link_libraries(convertmodels "${PROJECT_NAME}") +target_link_libraries(convertmodels "${TORCH_LIBRARIES}") + +#add_executable(testmodels test/test_models.cpp) +#target_link_libraries(testmodels "${PROJECT_NAME}") +#target_link_libraries(testmodels "${TORCH_LIBRARIES}") + +install(TARGETS ${PROJECT_NAME} DESTINATION ${CMAKE_INSTALL_PREFIX}/lib) +install(FILES ${HEADERS} DESTINATION ${CMAKE_INSTALL_PREFIX}/include/${PROJECT_NAME}) +install(FILES ${MODELS_HEADERS} DESTINATION ${CMAKE_INSTALL_PREFIX}/include/${PROJECT_NAME}/models) diff --git a/setup.py b/setup.py index 30b0d87e538..5e87a7a099a 100644 --- a/setup.py +++ b/setup.py @@ -89,6 +89,15 @@ def get_extensions(): sources = main_file + source_cpu extension = CppExtension + test_dir = os.path.join(this_dir, 'test') + models_dir = os.path.join(this_dir, 'torchvision', 'csrc', 'models') + test_file = glob.glob(os.path.join(test_dir, '*.cpp')) + source_models = glob.glob(os.path.join(models_dir, '*.cpp')) + + test_file = [os.path.join(test_dir, s) for s in test_file] + source_models = [os.path.join(models_dir, s) for s in source_models] + tests = test_file + source_models + define_macros = [] extra_compile_args = {} @@ -109,6 +118,7 @@ def get_extensions(): sources = [os.path.join(extensions_dir, s) for s in sources] include_dirs = 
[extensions_dir] + tests_include_dirs = [test_dir, models_dir] ext_modules = [ extension( @@ -117,6 +127,13 @@ def get_extensions(): include_dirs=include_dirs, define_macros=define_macros, extra_compile_args=extra_compile_args, + ), + extension( + 'torchvision._C_tests', + tests, + include_dirs=tests_include_dirs, + define_macros=define_macros, + extra_compile_args=extra_compile_args, ) ] diff --git a/test/test_cpp_models.py b/test/test_cpp_models.py new file mode 100644 index 00000000000..3331db112d1 --- /dev/null +++ b/test/test_cpp_models.py @@ -0,0 +1,122 @@ +import torch +import os +import unittest +from torchvision import models, transforms, _C_tests + +from PIL import Image +import torchvision.transforms.functional as F + + +def process_model(model, tensor, func, name): + model.eval() + traced_script_module = torch.jit.trace(model, tensor) + traced_script_module.save("model.pt") + + py_output = model.forward(tensor) + cpp_output = func("model.pt", tensor) + + assert torch.allclose(py_output, cpp_output), 'Output mismatch of ' + name + ' models' + + +def read_image1(): + image_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'assets', 'grace_hopper_517x606.jpg') + image = Image.open(image_path) + image = image.resize((224, 224)) + x = F.to_tensor(image) + return x.view(1, 3, 224, 224) + + +def read_image2(): + image_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'assets', 'grace_hopper_517x606.jpg') + image = Image.open(image_path) + image = image.resize((299, 299)) + x = F.to_tensor(image) + x = x.view(1, 3, 299, 299) + return torch.cat([x, x], 0) + + +class Tester(unittest.TestCase): + pretrained = False + image = read_image1() + + def test_alexnet(self): + process_model(models.alexnet(self.pretrained), self.image, _C_tests.forward_alexnet, 'Alexnet') + + def test_vgg11(self): + process_model(models.vgg11(self.pretrained), self.image, _C_tests.forward_vgg11, 'VGG11') + + def test_vgg13(self): + 
process_model(models.vgg13(self.pretrained), self.image, _C_tests.forward_vgg13, 'VGG13') + + def test_vgg16(self): + process_model(models.vgg16(self.pretrained), self.image, _C_tests.forward_vgg16, 'VGG16') + + def test_vgg19(self): + process_model(models.vgg19(self.pretrained), self.image, _C_tests.forward_vgg19, 'VGG19') + + def test_vgg11_bn(self): + process_model(models.vgg11_bn(self.pretrained), self.image, _C_tests.forward_vgg11bn, 'VGG11BN') + + def test_vgg13_bn(self): + process_model(models.vgg13_bn(self.pretrained), self.image, _C_tests.forward_vgg13bn, 'VGG13BN') + + def test_vgg16_bn(self): + process_model(models.vgg16_bn(self.pretrained), self.image, _C_tests.forward_vgg16bn, 'VGG16BN') + + def test_vgg19_bn(self): + process_model(models.vgg19_bn(self.pretrained), self.image, _C_tests.forward_vgg19bn, 'VGG19BN') + + def test_resnet18(self): + process_model(models.resnet18(self.pretrained), self.image, _C_tests.forward_resnet18, 'Resnet18') + + def test_resnet34(self): + process_model(models.resnet34(self.pretrained), self.image, _C_tests.forward_resnet34, 'Resnet34') + + def test_resnet50(self): + process_model(models.resnet50(self.pretrained), self.image, _C_tests.forward_resnet50, 'Resnet50') + + def test_resnet101(self): + process_model(models.resnet101(self.pretrained), self.image, _C_tests.forward_resnet101, 'Resnet101') + + def test_resnet152(self): + process_model(models.resnet152(self.pretrained), self.image, _C_tests.forward_resnet152, 'Resnet152') + + def test_resnext50_32x4d(self): + process_model(models.resnext50_32x4d(), self.image, _C_tests.forward_resnext50_32x4d, 'ResNext50_32x4d') + + def test_resnext101_32x8d(self): + process_model(models.resnext101_32x8d(), self.image, _C_tests.forward_resnext101_32x8d, 'ResNext101_32x8d') + + def test_squeezenet1_0(self): + process_model(models.squeezenet1_0(self.pretrained), self.image, + _C_tests.forward_squeezenet1_0, 'Squeezenet1.0') + + def test_squeezenet1_1(self): + 
process_model(models.squeezenet1_1(self.pretrained), self.image, + _C_tests.forward_squeezenet1_1, 'Squeezenet1.1') + + def test_densenet121(self): + process_model(models.densenet121(self.pretrained), self.image, _C_tests.forward_densenet121, 'Densenet121') + + def test_densenet169(self): + process_model(models.densenet169(self.pretrained), self.image, _C_tests.forward_densenet169, 'Densenet169') + + def test_densenet201(self): + process_model(models.densenet201(self.pretrained), self.image, _C_tests.forward_densenet201, 'Densenet201') + + def test_densenet161(self): + process_model(models.densenet161(self.pretrained), self.image, _C_tests.forward_densenet161, 'Densenet161') + + def test_mobilenet_v2(self): + process_model(models.mobilenet_v2(self.pretrained), self.image, _C_tests.forward_mobilenetv2, 'MobileNet') + + def test_googlenet(self): + process_model(models.googlenet(self.pretrained), self.image, _C_tests.forward_googlenet, 'GoogLeNet') + + def test_inception_v3(self): + self.image = read_image2() + process_model(models.inception_v3(self.pretrained), self.image, _C_tests.forward_inceptionv3, 'Inceptionv3') + + +if __name__ == '__main__': + unittest.main() diff --git a/test/test_models.cpp b/test/test_models.cpp new file mode 100644 index 00000000000..17e76e8c9d1 --- /dev/null +++ b/test/test_models.cpp @@ -0,0 +1,173 @@ +#include +#include +#include + +#include "../torchvision/csrc/models/models.h" + +using namespace vision::models; + +template +torch::Tensor forward_model(const std::string& input_path, torch::Tensor x) { + Model network; + torch::load(network, input_path); + network->eval(); + return network->forward(x); +} + +torch::Tensor forward_alexnet(const std::string& input_path, torch::Tensor x) { + return forward_model(input_path, x); +} + +torch::Tensor forward_vgg11(const std::string& input_path, torch::Tensor x) { + return forward_model(input_path, x); +} +torch::Tensor forward_vgg13(const std::string& input_path, torch::Tensor x) { + return 
forward_model(input_path, x); +} +torch::Tensor forward_vgg16(const std::string& input_path, torch::Tensor x) { + return forward_model(input_path, x); +} +torch::Tensor forward_vgg19(const std::string& input_path, torch::Tensor x) { + return forward_model(input_path, x); +} + +torch::Tensor forward_vgg11bn(const std::string& input_path, torch::Tensor x) { + return forward_model(input_path, x); +} +torch::Tensor forward_vgg13bn(const std::string& input_path, torch::Tensor x) { + return forward_model(input_path, x); +} +torch::Tensor forward_vgg16bn(const std::string& input_path, torch::Tensor x) { + return forward_model(input_path, x); +} +torch::Tensor forward_vgg19bn(const std::string& input_path, torch::Tensor x) { + return forward_model(input_path, x); +} + +torch::Tensor forward_resnet18(const std::string& input_path, torch::Tensor x) { + return forward_model(input_path, x); +} +torch::Tensor forward_resnet34(const std::string& input_path, torch::Tensor x) { + return forward_model(input_path, x); +} +torch::Tensor forward_resnet50(const std::string& input_path, torch::Tensor x) { + return forward_model(input_path, x); +} +torch::Tensor forward_resnet101( + const std::string& input_path, + torch::Tensor x) { + return forward_model(input_path, x); +} +torch::Tensor forward_resnet152( + const std::string& input_path, + torch::Tensor x) { + return forward_model(input_path, x); +} +torch::Tensor forward_resnext50_32x4d( + const std::string& input_path, + torch::Tensor x) { + return forward_model(input_path, x); +} +torch::Tensor forward_resnext101_32x8d( + const std::string& input_path, + torch::Tensor x) { + return forward_model(input_path, x); +} + +torch::Tensor forward_squeezenet1_0( + const std::string& input_path, + torch::Tensor x) { + return forward_model(input_path, x); +} +torch::Tensor forward_squeezenet1_1( + const std::string& input_path, + torch::Tensor x) { + return forward_model(input_path, x); +} + +torch::Tensor forward_densenet121( + const 
std::string& input_path, + torch::Tensor x) { + return forward_model(input_path, x); +} +torch::Tensor forward_densenet169( + const std::string& input_path, + torch::Tensor x) { + return forward_model(input_path, x); +} +torch::Tensor forward_densenet201( + const std::string& input_path, + torch::Tensor x) { + return forward_model(input_path, x); +} +torch::Tensor forward_densenet161( + const std::string& input_path, + torch::Tensor x) { + return forward_model(input_path, x); +} + +torch::Tensor forward_mobilenetv2( + const std::string& input_path, + torch::Tensor x) { + return forward_model(input_path, x); +} + +torch::Tensor forward_googlenet( + const std::string& input_path, + torch::Tensor x) { + GoogLeNet network; + torch::load(network, input_path); + network->eval(); + return network->forward(x).output; +} +torch::Tensor forward_inceptionv3( + const std::string& input_path, + torch::Tensor x) { + InceptionV3 network; + torch::load(network, input_path); + network->eval(); + return network->forward(x).output; +} + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("forward_alexnet", &forward_alexnet, "forward_alexnet"); + + m.def("forward_vgg11", &forward_vgg11, "forward_vgg11"); + m.def("forward_vgg13", &forward_vgg13, "forward_vgg13"); + m.def("forward_vgg16", &forward_vgg16, "forward_vgg16"); + m.def("forward_vgg19", &forward_vgg19, "forward_vgg19"); + + m.def("forward_vgg11bn", &forward_vgg11bn, "forward_vgg11bn"); + m.def("forward_vgg13bn", &forward_vgg13bn, "forward_vgg13bn"); + m.def("forward_vgg16bn", &forward_vgg16bn, "forward_vgg16bn"); + m.def("forward_vgg19bn", &forward_vgg19bn, "forward_vgg19bn"); + + m.def("forward_resnet18", &forward_resnet18, "forward_resnet18"); + m.def("forward_resnet34", &forward_resnet34, "forward_resnet34"); + m.def("forward_resnet50", &forward_resnet50, "forward_resnet50"); + m.def("forward_resnet101", &forward_resnet101, "forward_resnet101"); + m.def("forward_resnet152", &forward_resnet152, "forward_resnet152"); + m.def( 
+ "forward_resnext50_32x4d", + &forward_resnext50_32x4d, + "forward_resnext50_32x4d"); + m.def( + "forward_resnext101_32x8d", + &forward_resnext101_32x8d, + "forward_resnext101_32x8d"); + + m.def( + "forward_squeezenet1_0", &forward_squeezenet1_0, "forward_squeezenet1_0"); + m.def( + "forward_squeezenet1_1", &forward_squeezenet1_1, "forward_squeezenet1_1"); + + m.def("forward_densenet121", &forward_densenet121, "forward_densenet121"); + m.def("forward_densenet169", &forward_densenet169, "forward_densenet169"); + m.def("forward_densenet201", &forward_densenet201, "forward_densenet201"); + m.def("forward_densenet161", &forward_densenet161, "forward_densenet161"); + + m.def("forward_mobilenetv2", &forward_mobilenetv2, "forward_mobilenetv2"); + + m.def("forward_googlenet", &forward_googlenet, "forward_googlenet"); + m.def("forward_inceptionv3", &forward_inceptionv3, "forward_inceptionv3"); +} diff --git a/torchvision/csrc/convert_models/convert_models.cpp b/torchvision/csrc/convert_models/convert_models.cpp new file mode 100644 index 00000000000..988bc83e948 --- /dev/null +++ b/torchvision/csrc/convert_models/convert_models.cpp @@ -0,0 +1,76 @@ +#include +#include +#include + +#include "../models/models.h" + +using namespace vision::models; + +template +void convert_and_save_model( + const std::string& input_path, + const std::string& output_path) { + Model network; + torch::load(network, input_path); + torch::save(network, output_path); + + auto index = input_path.find("_python"); + auto name = input_path.substr(0, index); + std::cout << "finished loading and saving " << name << std::endl; +} + +int main(int argc, const char* argv[]) { + convert_and_save_model("alexnet_python.pt", "alexnet_cpp.pt"); + + convert_and_save_model("vgg11_python.pt", "vgg11_cpp.pt"); + convert_and_save_model("vgg13_python.pt", "vgg13_cpp.pt"); + convert_and_save_model("vgg16_python.pt", "vgg16_cpp.pt"); + convert_and_save_model("vgg19_python.pt", "vgg19_cpp.pt"); + + 
convert_and_save_model("vgg11bn_python.pt", "vgg11bn_cpp.pt"); + convert_and_save_model("vgg13bn_python.pt", "vgg13bn_cpp.pt"); + convert_and_save_model("vgg16bn_python.pt", "vgg16bn_cpp.pt"); + convert_and_save_model("vgg19bn_python.pt", "vgg19bn_cpp.pt"); + + convert_and_save_model("resnet18_python.pt", "resnet18_cpp.pt"); + convert_and_save_model("resnet34_python.pt", "resnet34_cpp.pt"); + convert_and_save_model("resnet50_python.pt", "resnet50_cpp.pt"); + convert_and_save_model("resnet101_python.pt", "resnet101_cpp.pt"); + convert_and_save_model("resnet152_python.pt", "resnet152_cpp.pt"); + convert_and_save_model( + "resnext50_32x4d_python.pt", "resnext50_32x4d_cpp.pt"); + convert_and_save_model( + "resnext101_32x8d_python.pt", "resnext101_32x8d_cpp.pt"); + + convert_and_save_model( + "squeezenet1_0_python.pt", "squeezenet1_0_cpp.pt"); + convert_and_save_model( + "squeezenet1_1_python.pt", "squeezenet1_1_cpp.pt"); + + convert_and_save_model( + "densenet121_python.pt", "densenet121_cpp.pt"); + convert_and_save_model( + "densenet169_python.pt", "densenet169_cpp.pt"); + convert_and_save_model( + "densenet201_python.pt", "densenet201_cpp.pt"); + convert_and_save_model( + "densenet161_python.pt", "densenet161_cpp.pt"); + + convert_and_save_model( + "mobilenetv2_python.pt", "mobilenetv2_cpp.pt"); + + convert_and_save_model( + "shufflenetv2_x0_5_python.pt", "shufflenetv2_x0_5_cpp.pt"); + convert_and_save_model( + "shufflenetv2_x1_0_python.pt", "shufflenetv2_x1_0_cpp.pt"); + convert_and_save_model( + "shufflenetv2_x1_5_python.pt", "shufflenetv2_x1_5_cpp.pt"); + convert_and_save_model( + "shufflenetv2_x2_0_python.pt", "shufflenetv2_x2_0_cpp.pt"); + + convert_and_save_model("googlenet_python.pt", "googlenet_cpp.pt"); + convert_and_save_model( + "inceptionv3_python.pt", "inceptionv3_cpp.pt"); + + return 0; +} diff --git a/torchvision/csrc/models/alexnet.cpp b/torchvision/csrc/models/alexnet.cpp new file mode 100644 index 00000000000..e29674b706a --- /dev/null +++ 
b/torchvision/csrc/models/alexnet.cpp @@ -0,0 +1,47 @@ +#include "alexnet.h" + +#include "modelsimpl.h" + +namespace vision { +namespace models { +AlexNetImpl::AlexNetImpl(int64_t num_classes) { + features = torch::nn::Sequential( + torch::nn::Conv2d( + torch::nn::Conv2dOptions(3, 64, 11).stride(4).padding(2)), + torch::nn::Functional(modelsimpl::relu_), + torch::nn::Functional(modelsimpl::max_pool2d, 3, 2), + torch::nn::Conv2d(torch::nn::Conv2dOptions(64, 192, 5).padding(2)), + torch::nn::Functional(modelsimpl::relu_), + torch::nn::Functional(modelsimpl::max_pool2d, 3, 2), + torch::nn::Conv2d(torch::nn::Conv2dOptions(192, 384, 3).padding(1)), + torch::nn::Functional(modelsimpl::relu_), + torch::nn::Conv2d(torch::nn::Conv2dOptions(384, 256, 3).padding(1)), + torch::nn::Functional(modelsimpl::relu_), + torch::nn::Conv2d(torch::nn::Conv2dOptions(256, 256, 3).padding(1)), + torch::nn::Functional(modelsimpl::relu_), + torch::nn::Functional(modelsimpl::max_pool2d, 3, 2)); + + classifier = torch::nn::Sequential( + torch::nn::Dropout(), + torch::nn::Linear(256 * 6 * 6, 4096), + torch::nn::Functional(torch::relu), + torch::nn::Dropout(), + torch::nn::Linear(4096, 4096), + torch::nn::Functional(torch::relu), + torch::nn::Linear(4096, num_classes)); + + register_module("features", features); + register_module("classifier", classifier); +} + +torch::Tensor AlexNetImpl::forward(torch::Tensor x) { + x = features->forward(x); + x = torch::adaptive_avg_pool2d(x, {6, 6}); + x = x.view({x.size(0), -1}); + x = classifier->forward(x); + + return x; +} + +} // namespace models +} // namespace vision diff --git a/torchvision/csrc/models/alexnet.h b/torchvision/csrc/models/alexnet.h new file mode 100644 index 00000000000..66694e0c982 --- /dev/null +++ b/torchvision/csrc/models/alexnet.h @@ -0,0 +1,23 @@ +#ifndef ALEXNET_H +#define ALEXNET_H + +#include + +namespace vision { +namespace models { +// AlexNet model architecture from the +// "One weird trick..." paper. 
+struct AlexNetImpl : torch::nn::Module { + torch::nn::Sequential features{nullptr}, classifier{nullptr}; + + AlexNetImpl(int64_t num_classes = 1000); + + torch::Tensor forward(torch::Tensor x); +}; + +TORCH_MODULE(AlexNet); + +} // namespace models +} // namespace vision + +#endif // ALEXNET_H diff --git a/torchvision/csrc/models/densenet.cpp b/torchvision/csrc/models/densenet.cpp new file mode 100644 index 00000000000..4c50b509b10 --- /dev/null +++ b/torchvision/csrc/models/densenet.cpp @@ -0,0 +1,219 @@ +#include "densenet.h" + +#include "modelsimpl.h" + +namespace vision { +namespace models { +using Options = torch::nn::Conv2dOptions; + +struct _DenseLayerImpl : torch::nn::SequentialImpl { + double drop_rate; + + _DenseLayerImpl( + int64_t num_input_features, + int64_t growth_rate, + int64_t bn_size, + double drop_rate) + : drop_rate(drop_rate) { + push_back("norm1", torch::nn::BatchNorm(num_input_features)); + push_back("relu1", torch::nn::Functional(modelsimpl::relu_)); + push_back( + "conv1", + torch::nn::Conv2d(Options(num_input_features, bn_size * growth_rate, 1) + .stride(1) + .with_bias(false))); + push_back("norm2", torch::nn::BatchNorm(bn_size * growth_rate)); + push_back("relu2", torch::nn::Functional(modelsimpl::relu_)); + push_back( + "conv2", + torch::nn::Conv2d(Options(bn_size * growth_rate, growth_rate, 3) + .stride(1) + .padding(1) + .with_bias(false))); + } + + torch::Tensor forward(torch::Tensor x) { + auto new_features = torch::nn::SequentialImpl::forward(x); + if (drop_rate > 0) + new_features = + torch::dropout(new_features, drop_rate, this->is_training()); + return torch::cat({x, new_features}, 1); + } +}; + +TORCH_MODULE(_DenseLayer); + +struct _DenseBlockImpl : torch::nn::SequentialImpl { + _DenseBlockImpl( + int64_t num_layers, + int64_t num_input_features, + int64_t bn_size, + int64_t growth_rate, + double drop_rate) { + for (int64_t i = 0; i < num_layers; ++i) { + auto layer = _DenseLayer( + num_input_features + i * growth_rate, + 
growth_rate, + bn_size, + drop_rate); + push_back("denselayer" + std::to_string(i + 1), layer); + } + } + + torch::Tensor forward(torch::Tensor x) { + return torch::nn::SequentialImpl::forward(x); + } +}; + +TORCH_MODULE(_DenseBlock); + +struct _TransitionImpl : torch::nn::SequentialImpl { + _TransitionImpl(int64_t num_input_features, int64_t num_output_features) { + push_back("norm", torch::nn::BatchNorm(num_input_features)); + push_back("relu ", torch::nn::Functional(modelsimpl::relu_)); + push_back( + "conv", + torch::nn::Conv2d(Options(num_input_features, num_output_features, 1) + .stride(1) + .with_bias(false))); + push_back( + "pool", torch::nn::Functional(torch::avg_pool2d, 2, 2, 0, false, true)); + } + + torch::Tensor forward(torch::Tensor x) { + return torch::nn::SequentialImpl::forward(x); + } +}; + +TORCH_MODULE(_Transition); + +DenseNetImpl::DenseNetImpl( + int64_t num_classes, + int64_t growth_rate, + std::vector block_config, + int64_t num_init_features, + int64_t bn_size, + double drop_rate) { + // First convolution + features = torch::nn::Sequential(); + features->push_back( + "conv0", + torch::nn::Conv2d(Options(3, num_init_features, 7) + .stride(2) + .padding(3) + .with_bias(false))); + + features->push_back("norm0", torch::nn::BatchNorm(num_init_features)); + features->push_back("relu0", torch::nn::Functional(modelsimpl::relu_)); + features->push_back( + "pool0", torch::nn::Functional(torch::max_pool2d, 3, 2, 1, 1, false)); + + // Each denseblock + auto num_features = num_init_features; + for (size_t i = 0; i < block_config.size(); ++i) { + auto num_layers = block_config[i]; + _DenseBlock block( + num_layers, num_features, bn_size, growth_rate, drop_rate); + + features->push_back("denseblock" + std::to_string(i + 1), block); + num_features = num_features + num_layers * growth_rate; + + if (i != block_config.size() - 1) { + auto trans = _Transition(num_features, num_features / 2); + features->push_back("transition" + std::to_string(i + 1), trans); 
+ num_features = num_features / 2; + } + } + + // Final batch norm + features->push_back("norm5", torch::nn::BatchNorm(num_features)); + // Linear layer + classifier = torch::nn::Linear(num_features, num_classes); + + register_module("features", features); + register_module("classifier", classifier); + + // Official init from torch repo. + for (auto& module : modules(/*include_self=*/false)) { + if (auto M = dynamic_cast(module.get())) + torch::nn::init::kaiming_normal_(M->weight); + else if (auto M = dynamic_cast(module.get())) { + torch::nn::init::constant_(M->weight, 1); + torch::nn::init::constant_(M->bias, 0); + } else if (auto M = dynamic_cast(module.get())) + torch::nn::init::constant_(M->bias, 0); + } +} + +torch::Tensor DenseNetImpl::forward(torch::Tensor x) { + auto features = this->features->forward(x); + auto out = torch::relu_(features); + out = torch::adaptive_avg_pool2d(out, {1, 1}); + + out = out.view({features.size(0), -1}); + out = this->classifier->forward(out); + return out; +} + +DenseNet121Impl::DenseNet121Impl( + int64_t num_classes, + int64_t growth_rate, + std::vector block_config, + int64_t num_init_features, + int64_t bn_size, + double drop_rate) + : DenseNetImpl( + num_classes, + growth_rate, + block_config, + num_init_features, + bn_size, + drop_rate) {} + +DenseNet169Impl::DenseNet169Impl( + int64_t num_classes, + int64_t growth_rate, + std::vector block_config, + int64_t num_init_features, + int64_t bn_size, + double drop_rate) + : DenseNetImpl( + num_classes, + growth_rate, + block_config, + num_init_features, + bn_size, + drop_rate) {} + +DenseNet201Impl::DenseNet201Impl( + int64_t num_classes, + int64_t growth_rate, + std::vector block_config, + int64_t num_init_features, + int64_t bn_size, + double drop_rate) + : DenseNetImpl( + num_classes, + growth_rate, + block_config, + num_init_features, + bn_size, + drop_rate) {} + +DenseNet161Impl::DenseNet161Impl( + int64_t num_classes, + int64_t growth_rate, + std::vector block_config, + 
int64_t num_init_features, + int64_t bn_size, + double drop_rate) + : DenseNetImpl( + num_classes, + growth_rate, + block_config, + num_init_features, + bn_size, + drop_rate) {} + +} // namespace models +} // namespace vision diff --git a/torchvision/csrc/models/densenet.h b/torchvision/csrc/models/densenet.h new file mode 100644 index 00000000000..b27efc10a07 --- /dev/null +++ b/torchvision/csrc/models/densenet.h @@ -0,0 +1,85 @@ +#ifndef DENSENET_H +#define DENSENET_H + +#include + +namespace vision { +namespace models { +// Densenet-BC model class, based on +// "Densely Connected Convolutional Networks" +// + +// Args: +// num_classes (int) - number of classification classes +// growth_rate (int) - how many filters to add each layer (`k` in paper) +// block_config (list of 4 ints) - how many layers in each pooling block +// num_init_features (int) - the number of filters to learn in the first +// convolution layer +// bn_size (int) - multiplicative factor for number of bottle neck layers +// (i.e. 
bn_size * k features in the bottleneck layer) +// drop_rate (float) - dropout rate after each dense layer +struct DenseNetImpl : torch::nn::Module { + torch::nn::Sequential features{nullptr}; + torch::nn::Linear classifier{nullptr}; + + DenseNetImpl( + int64_t num_classes = 1000, + int64_t growth_rate = 32, + std::vector block_config = {6, 12, 24, 16}, + int64_t num_init_features = 64, + int64_t bn_size = 4, + double drop_rate = 0); + + torch::Tensor forward(torch::Tensor x); +}; + +struct DenseNet121Impl : DenseNetImpl { + DenseNet121Impl( + int64_t num_classes = 1000, + int64_t growth_rate = 32, + std::vector block_config = {6, 12, 24, 16}, + int64_t num_init_features = 64, + int64_t bn_size = 4, + double drop_rate = 0); +}; + +struct DenseNet169Impl : DenseNetImpl { + DenseNet169Impl( + int64_t num_classes = 1000, + int64_t growth_rate = 32, + std::vector block_config = {6, 12, 32, 32}, + int64_t num_init_features = 64, + int64_t bn_size = 4, + double drop_rate = 0); +}; + +struct DenseNet201Impl : DenseNetImpl { + DenseNet201Impl( + int64_t num_classes = 1000, + int64_t growth_rate = 32, + std::vector block_config = {6, 12, 48, 32}, + int64_t num_init_features = 64, + int64_t bn_size = 4, + double drop_rate = 0); +}; + +struct DenseNet161Impl : DenseNetImpl { + DenseNet161Impl( + int64_t num_classes = 1000, + int64_t growth_rate = 48, + std::vector block_config = {6, 12, 36, 24}, + int64_t num_init_features = 96, + int64_t bn_size = 4, + double drop_rate = 0); +}; + +TORCH_MODULE(DenseNet); +TORCH_MODULE(DenseNet121); +TORCH_MODULE(DenseNet169); +TORCH_MODULE(DenseNet201); +TORCH_MODULE(DenseNet161); + +} // namespace models +} // namespace vision + +#endif // DENSENET_H diff --git a/torchvision/csrc/models/googlenet.cpp b/torchvision/csrc/models/googlenet.cpp new file mode 100644 index 00000000000..c053b65b4d6 --- /dev/null +++ b/torchvision/csrc/models/googlenet.cpp @@ -0,0 +1,232 @@ +#include "googlenet.h" + +#include "modelsimpl.h" + +namespace vision { 
+namespace models { + +using Options = torch::nn::Conv2dOptions; + +namespace _googlenetimpl { +BasicConv2dImpl::BasicConv2dImpl(torch::nn::Conv2dOptions options) { + options.with_bias(false); + conv = torch::nn::Conv2d(options); + bn = torch::nn::BatchNorm( + torch::nn::BatchNormOptions(options.output_channels()).eps(0.001)); + + register_module("conv", conv); + register_module("bn", bn); +} + +torch::Tensor BasicConv2dImpl::forward(torch::Tensor x) { + x = conv->forward(x); + x = bn->forward(x); + return x.relu_(); +} + +InceptionImpl::InceptionImpl( + int64_t in_channels, + int64_t ch1x1, + int64_t ch3x3red, + int64_t ch3x3, + int64_t ch5x5red, + int64_t ch5x5, + int64_t pool_proj) { + branch1 = BasicConv2d(Options(in_channels, ch1x1, 1)); + + branch2->push_back(BasicConv2d(Options(in_channels, ch3x3red, 1))); + branch2->push_back(BasicConv2d(Options(ch3x3red, ch3x3, 3).padding(1))); + + branch3->push_back(BasicConv2d(Options(in_channels, ch5x5red, 1))); + branch3->push_back(BasicConv2d(Options(ch5x5red, ch5x5, 3).padding(1))); + + branch4->push_back( + torch::nn::Functional(torch::max_pool2d, 3, 1, 1, 1, true)); + branch4->push_back(BasicConv2d(Options(in_channels, pool_proj, 1))); + + register_module("branch1", branch1); + register_module("branch2", branch2); + register_module("branch3", branch3); + register_module("branch4", branch4); +} + +torch::Tensor InceptionImpl::forward(torch::Tensor x) { + auto b1 = branch1->forward(x); + auto b2 = branch2->forward(x); + auto b3 = branch3->forward(x); + auto b4 = branch4->forward(x); + + return torch::cat({b1, b2, b3, b4}, 1); +} + +InceptionAuxImpl::InceptionAuxImpl(int64_t in_channels, int64_t num_classes) { + conv = BasicConv2d(Options(in_channels, 128, 1)); + fc1 = torch::nn::Linear(2048, 1024); + fc2 = torch::nn::Linear(1024, num_classes); + + register_module("conv", conv); + register_module("fc1", fc1); + register_module("fc2", fc2); +} + +torch::Tensor InceptionAuxImpl::forward(at::Tensor x) { + // aux1: N x 
512 x 14 x 14, aux2: N x 528 x 14 x 14 + x = torch::adaptive_avg_pool2d(x, {4, 4}); + // aux1: N x 512 x 4 x 4, aux2: N x 528 x 4 x 4 + x = conv->forward(x); + // N x 128 x 4 x 4 + x = x.view({x.size(0), -1}); + // N x 2048 + x = fc1->forward(x).relu_(); + // N x 2048 + x = torch::dropout(x, 0.7, is_training()); + // N x 2048 + x = fc2->forward(x); + // N x 1024 + + return x; +} + +} // namespace _googlenetimpl + +GoogLeNetImpl::GoogLeNetImpl( + int64_t num_classes, + bool aux_logits, + bool transform_input, + bool init_weights) { + this->aux_logits = aux_logits; + this->transform_input = transform_input; + + conv1 = _googlenetimpl::BasicConv2d(Options(3, 64, 7).stride(2).padding(3)); + conv2 = _googlenetimpl::BasicConv2d(Options(64, 64, 1)); + conv3 = _googlenetimpl::BasicConv2d(Options(64, 192, 3).padding(1)); + + inception3a = _googlenetimpl::Inception(192, 64, 96, 128, 16, 32, 32); + inception3b = _googlenetimpl::Inception(256, 128, 128, 192, 32, 96, 64); + + inception4a = _googlenetimpl::Inception(480, 192, 96, 208, 16, 48, 64); + inception4b = _googlenetimpl::Inception(512, 160, 112, 224, 24, 64, 64); + inception4c = _googlenetimpl::Inception(512, 128, 128, 256, 24, 64, 64); + inception4d = _googlenetimpl::Inception(512, 112, 144, 288, 32, 64, 64); + inception4e = _googlenetimpl::Inception(528, 256, 160, 320, 32, 128, 128); + + inception5a = _googlenetimpl::Inception(832, 256, 160, 320, 32, 128, 128); + inception5b = _googlenetimpl::Inception(832, 384, 192, 384, 48, 128, 128); + + if (aux_logits) { + aux1 = _googlenetimpl::InceptionAux(512, num_classes); + aux2 = _googlenetimpl::InceptionAux(528, num_classes); + + register_module("aux1", aux1); + register_module("aux2", aux2); + } + + dropout = torch::nn::Dropout(0.2); + fc = torch::nn::Linear(1024, num_classes); + + register_module("conv1", conv1); + register_module("conv2", conv2); + register_module("conv3", conv3); + + register_module("inception3a", inception3a); + register_module("inception3b", 
inception3b); + + register_module("inception4a", inception4a); + register_module("inception4b", inception4b); + register_module("inception4c", inception4c); + register_module("inception4d", inception4d); + register_module("inception4e", inception4e); + + register_module("inception5a", inception5a); + register_module("inception5b", inception5b); + + register_module("dropout", dropout); + register_module("fc", fc); + + if (init_weights) + _initialize_weights(); +} + +void GoogLeNetImpl::_initialize_weights() { + for (auto& module : modules(/*include_self=*/false)) { + if (auto M = dynamic_cast(module.get())) + torch::nn::init::normal_(M->weight); // Note: used instead of truncated + // normal initialization + else if (auto M = dynamic_cast(module.get())) + torch::nn::init::normal_(M->weight); // Note: used instead of truncated + // normal initialization + else if (auto M = dynamic_cast(module.get())) { + torch::nn::init::ones_(M->weight); + torch::nn::init::zeros_(M->bias); + } + } +} + +GoogLeNetOutput GoogLeNetImpl::forward(torch::Tensor x) { + if (transform_input) { + auto x_ch0 = torch::unsqueeze(x.select(1, 0), 1) * (0.229 / 0.5) + + (0.485 - 0.5) / 0.5; + auto x_ch1 = torch::unsqueeze(x.select(1, 1), 1) * (0.224 / 0.5) + + (0.456 - 0.5) / 0.5; + auto x_ch2 = torch::unsqueeze(x.select(1, 2), 1) * (0.225 / 0.5) + + (0.406 - 0.5) / 0.5; + + x = torch::cat({x_ch0, x_ch1, x_ch2}, 1); + } + + // N x 3 x 224 x 224 + x = conv1->forward(x); + // N x 64 x 112 x 112 + x = torch::max_pool2d(x, 3, 2, 0, 1, true); + // N x 64 x 56 x 56 + x = conv2->forward(x); + // N x 64 x 56 x 56 + x = conv3->forward(x); + // N x 192 x 56 x 56 + x = torch::max_pool2d(x, 3, 2, 0, 1, true); + + // N x 192 x 28 x 28 + x = inception3a->forward(x); + // N x 256 x 28 x 28 + x = inception3b->forward(x); + // N x 480 x 28 x 28 + x = torch::max_pool2d(x, 3, 2, 0, 1, true); + // N x 480 x 14 x 14 + x = inception4a->forward(x); + // N x 512 x 14 x 14 + torch::Tensor aux1; + if (is_training() && 
aux_logits) + aux1 = this->aux1->forward(x); + + x = inception4b->forward(x); + // N x 512 x 14 x 14 + x = inception4c->forward(x); + // N x 512 x 14 x 14 + x = inception4d->forward(x); + // N x 528 x 14 x 14 + torch::Tensor aux2; + if (is_training() && aux_logits) + aux2 = this->aux2->forward(x); + + x = inception4e(x); + // N x 832 x 14 x 14 + x = torch::max_pool2d(x, 2, 2, 0, 1, true); + // N x 832 x 7 x 7 + x = inception5a(x); + // N x 832 x 7 x 7 + x = inception5b(x); + // N x 1024 x 7 x 7 + + x = torch::adaptive_avg_pool2d(x, {1, 1}); + // N x 1024 x 1 x 1 + x = x.view({x.size(0), -1}); + // N x 1024 + x = dropout->forward(x); + x = fc->forward(x); + // N x 1000(num_classes) + + return {x, aux1, aux2}; +} + +} // namespace models +} // namespace vision diff --git a/torchvision/csrc/models/googlenet.h b/torchvision/csrc/models/googlenet.h new file mode 100644 index 00000000000..5324e8e758d --- /dev/null +++ b/torchvision/csrc/models/googlenet.h @@ -0,0 +1,89 @@ +#ifndef GOOGLENET_H +#define GOOGLENET_H + +#include + +namespace vision { +namespace models { + +namespace _googlenetimpl { +struct BasicConv2dImpl : torch::nn::Module { + torch::nn::Conv2d conv{nullptr}; + torch::nn::BatchNorm bn{nullptr}; + + BasicConv2dImpl(torch::nn::Conv2dOptions options); + + torch::Tensor forward(torch::Tensor x); +}; + +TORCH_MODULE(BasicConv2d); + +struct InceptionImpl : torch::nn::Module { + BasicConv2d branch1{nullptr}; + torch::nn::Sequential branch2, branch3, branch4; + + InceptionImpl( + int64_t in_channels, + int64_t ch1x1, + int64_t ch3x3red, + int64_t ch3x3, + int64_t ch5x5red, + int64_t ch5x5, + int64_t pool_proj); + + torch::Tensor forward(torch::Tensor x); +}; + +TORCH_MODULE(Inception); + +struct InceptionAuxImpl : torch::nn::Module { + BasicConv2d conv{nullptr}; + torch::nn::Linear fc1{nullptr}, fc2{nullptr}; + + InceptionAuxImpl(int64_t in_channels, int64_t num_classes); + + torch::Tensor forward(torch::Tensor x); +}; + +TORCH_MODULE(InceptionAux); + +} // 
namespace _googlenetimpl + +struct GoogLeNetOutput { + torch::Tensor output; + torch::Tensor aux1; + torch::Tensor aux2; +}; + +struct GoogLeNetImpl : torch::nn::Module { + bool aux_logits, transform_input; + + _googlenetimpl::BasicConv2d conv1{nullptr}, conv2{nullptr}, conv3{nullptr}; + + _googlenetimpl::Inception inception3a{nullptr}, inception3b{nullptr}, + inception4a{nullptr}, inception4b{nullptr}, inception4c{nullptr}, + inception4d{nullptr}, inception4e{nullptr}, inception5a{nullptr}, + inception5b{nullptr}; + + _googlenetimpl::InceptionAux aux1{nullptr}, aux2{nullptr}; + + torch::nn::Dropout dropout{nullptr}; + torch::nn::Linear fc{nullptr}; + + GoogLeNetImpl( + int64_t num_classes = 1000, + bool aux_logits = true, + bool transform_input = false, + bool init_weights = true); + + void _initialize_weights(); + + GoogLeNetOutput forward(torch::Tensor x); +}; + +TORCH_MODULE(GoogLeNet); + +} // namespace models +} // namespace vision + +#endif // GOOGLENET_H diff --git a/torchvision/csrc/models/inception.cpp b/torchvision/csrc/models/inception.cpp new file mode 100644 index 00000000000..ebb35089d33 --- /dev/null +++ b/torchvision/csrc/models/inception.cpp @@ -0,0 +1,373 @@ +#include "inception.h" + +namespace vision { +namespace models { + +using Options = torch::nn::Conv2dOptions; + +namespace _inceptionimpl { +BasicConv2dImpl::BasicConv2dImpl( + torch::nn::Conv2dOptions options, + double std_dev) { + options.with_bias(false); + conv = torch::nn::Conv2d(options); + bn = torch::nn::BatchNorm( + torch::nn::BatchNormOptions(options.output_channels()).eps(0.001)); + + register_module("conv", conv); + register_module("bn", bn); + + torch::nn::init::normal_( + conv->weight, + 0, + std_dev); // Note: used instead of truncated normal initialization + + torch::nn::init::constant_(bn->weight, 1); + torch::nn::init::constant_(bn->bias, 0); +} + +torch::Tensor BasicConv2dImpl::forward(torch::Tensor x) { + x = conv->forward(x); + x = bn->forward(x); + return 
torch::relu_(x); +} + +InceptionAImpl::InceptionAImpl(int64_t in_channels, int64_t pool_features) + : branch1x1(Options(in_channels, 64, 1)), + branch5x5_1(Options(in_channels, 48, 1)), + branch5x5_2(Options(48, 64, 5).padding(2)), + branch3x3dbl_1(Options(in_channels, 64, 1)), + branch3x3dbl_2(Options(64, 96, 3).padding(1)), + branch3x3dbl_3(Options(96, 96, 3).padding(1)), + branch_pool(Options(in_channels, pool_features, 1)) { + register_module("branch1x1", branch1x1); + register_module("branch5x5_1", branch5x5_1); + register_module("branch5x5_2", branch5x5_2); + register_module("branch3x3dbl_1", branch3x3dbl_1); + register_module("branch3x3dbl_2", branch3x3dbl_2); + register_module("branch3x3dbl_3", branch3x3dbl_3); + register_module("branch_pool", branch_pool); +} + +torch::Tensor InceptionAImpl::forward(torch::Tensor x) { + auto branch1x1 = this->branch1x1->forward(x); + + auto branch5x5 = this->branch5x5_1->forward(x); + branch5x5 = this->branch5x5_2->forward(branch5x5); + + auto branch3x3dbl = this->branch3x3dbl_1->forward(x); + branch3x3dbl = this->branch3x3dbl_2->forward(branch3x3dbl); + branch3x3dbl = this->branch3x3dbl_3->forward(branch3x3dbl); + + auto branch_pool = torch::avg_pool2d(x, 3, 1, 1); + branch_pool = this->branch_pool->forward(branch_pool); + + return torch::cat({branch1x1, branch5x5, branch3x3dbl, branch_pool}, 1); +} + +InceptionBImpl::InceptionBImpl(int64_t in_channels) + : branch3x3(Options(in_channels, 384, 3).stride(2)), + branch3x3dbl_1(Options(in_channels, 64, 1)), + branch3x3dbl_2(Options(64, 96, 3).padding(1)), + branch3x3dbl_3(Options(96, 96, 3).stride(2)) { + register_module("branch3x3", branch3x3); + register_module("branch3x3dbl_1", branch3x3dbl_1); + register_module("branch3x3dbl_2", branch3x3dbl_2); + register_module("branch3x3dbl_3", branch3x3dbl_3); +} + +torch::Tensor InceptionBImpl::forward(torch::Tensor x) { + auto branch3x3 = this->branch3x3->forward(x); + + auto branch3x3dbl = this->branch3x3dbl_1->forward(x); + 
branch3x3dbl = this->branch3x3dbl_2->forward(branch3x3dbl); + branch3x3dbl = this->branch3x3dbl_3->forward(branch3x3dbl); + + auto branch_pool = torch::max_pool2d(x, 3, 2); + return torch::cat({branch3x3, branch3x3dbl, branch_pool}, 1); +} + +InceptionCImpl::InceptionCImpl(int64_t in_channels, int64_t channels_7x7) { + branch1x1 = BasicConv2d(Options(in_channels, 192, 1)); + + auto c7 = channels_7x7; + branch7x7_1 = BasicConv2d(Options(in_channels, c7, 1)); + branch7x7_2 = BasicConv2d(Options(c7, c7, {1, 7}).padding({0, 3})); + branch7x7_3 = BasicConv2d(Options(c7, 192, {7, 1}).padding({3, 0})); + + branch7x7dbl_1 = BasicConv2d(Options(in_channels, c7, 1)); + branch7x7dbl_2 = BasicConv2d(Options(c7, c7, {7, 1}).padding({3, 0})); + branch7x7dbl_3 = BasicConv2d(Options(c7, c7, {1, 7}).padding({0, 3})); + branch7x7dbl_4 = BasicConv2d(Options(c7, c7, {7, 1}).padding({3, 0})); + branch7x7dbl_5 = BasicConv2d(Options(c7, 192, {1, 7}).padding({0, 3})); + + branch_pool = BasicConv2d(Options(in_channels, 192, 1)); + + register_module("branch1x1", branch1x1); + register_module("branch7x7_1", branch7x7_1); + register_module("branch7x7_2", branch7x7_2); + register_module("branch7x7_3", branch7x7_3); + register_module("branch7x7dbl_1", branch7x7dbl_1); + register_module("branch7x7dbl_2", branch7x7dbl_2); + register_module("branch7x7dbl_3", branch7x7dbl_3); + register_module("branch7x7dbl_4", branch7x7dbl_4); + register_module("branch7x7dbl_5", branch7x7dbl_5); + register_module("branch_pool", branch_pool); +} + +torch::Tensor InceptionCImpl::forward(torch::Tensor x) { + auto branch1x1 = this->branch1x1->forward(x); + + auto branch7x7 = this->branch7x7_1->forward(x); + branch7x7 = this->branch7x7_2->forward(branch7x7); + branch7x7 = this->branch7x7_3->forward(branch7x7); + + auto branch7x7dbl = this->branch7x7dbl_1->forward(x); + branch7x7dbl = this->branch7x7dbl_2->forward(branch7x7dbl); + branch7x7dbl = this->branch7x7dbl_3->forward(branch7x7dbl); + branch7x7dbl = 
this->branch7x7dbl_4->forward(branch7x7dbl); + branch7x7dbl = this->branch7x7dbl_5->forward(branch7x7dbl); + + auto branch_pool = torch::avg_pool2d(x, 3, 1, 1); + branch_pool = this->branch_pool->forward(branch_pool); + + return torch::cat({branch1x1, branch7x7, branch7x7dbl, branch_pool}, 1); +} + +InceptionDImpl::InceptionDImpl(int64_t in_channels) + : branch3x3_1(Options(in_channels, 192, 1)), + branch3x3_2(Options(192, 320, 3).stride(2)), + branch7x7x3_1(Options(in_channels, 192, 1)), + branch7x7x3_2(Options(192, 192, {1, 7}).padding({0, 3})), + branch7x7x3_3(Options(192, 192, {7, 1}).padding({3, 0})), + branch7x7x3_4(Options(192, 192, 3).stride(2)) + +{ + register_module("branch3x3_1", branch3x3_1); + register_module("branch3x3_2", branch3x3_2); + register_module("branch7x7x3_1", branch7x7x3_1); + register_module("branch7x7x3_2", branch7x7x3_2); + register_module("branch7x7x3_3", branch7x7x3_3); + register_module("branch7x7x3_4", branch7x7x3_4); +} + +torch::Tensor InceptionDImpl::forward(torch::Tensor x) { + auto branch3x3 = this->branch3x3_1->forward(x); + branch3x3 = this->branch3x3_2->forward(branch3x3); + + auto branch7x7x3 = this->branch7x7x3_1->forward(x); + branch7x7x3 = this->branch7x7x3_2->forward(branch7x7x3); + branch7x7x3 = this->branch7x7x3_3->forward(branch7x7x3); + branch7x7x3 = this->branch7x7x3_4->forward(branch7x7x3); + + auto branch_pool = torch::max_pool2d(x, 3, 2); + return torch::cat({branch3x3, branch7x7x3, branch_pool}, 1); +} + +InceptionEImpl::InceptionEImpl(int64_t in_channels) + : branch1x1(Options(in_channels, 320, 1)), + branch3x3_1(Options(in_channels, 384, 1)), + branch3x3_2a(Options(384, 384, {1, 3}).padding({0, 1})), + branch3x3_2b(Options(384, 384, {3, 1}).padding({1, 0})), + branch3x3dbl_1(Options(in_channels, 448, 1)), + branch3x3dbl_2(Options(448, 384, 3).padding(1)), + branch3x3dbl_3a(Options(384, 384, {1, 3}).padding({0, 1})), + branch3x3dbl_3b(Options(384, 384, {3, 1}).padding({1, 0})), + 
branch_pool(Options(in_channels, 192, 1)) { + register_module("branch1x1", branch1x1); + register_module("branch3x3_1", branch3x3_1); + register_module("branch3x3_2a", branch3x3_2a); + register_module("branch3x3_2b", branch3x3_2b); + register_module("branch3x3dbl_1", branch3x3dbl_1); + register_module("branch3x3dbl_2", branch3x3dbl_2); + register_module("branch3x3dbl_3a", branch3x3dbl_3a); + register_module("branch3x3dbl_3b", branch3x3dbl_3b); + register_module("branch_pool", branch_pool); +} + +torch::Tensor InceptionEImpl::forward(torch::Tensor x) { + auto branch1x1 = this->branch1x1->forward(x); + + auto branch3x3 = this->branch3x3_1->forward(x); + branch3x3 = torch::cat( + { + this->branch3x3_2a->forward(branch3x3), + this->branch3x3_2b->forward(branch3x3), + }, + 1); + + auto branch3x3dbl = this->branch3x3dbl_1->forward(x); + branch3x3dbl = this->branch3x3dbl_2->forward(branch3x3dbl); + branch3x3dbl = torch::cat( + {this->branch3x3dbl_3a->forward(branch3x3dbl), + this->branch3x3dbl_3b->forward(branch3x3dbl)}, + 1); + + auto branch_pool = torch::avg_pool2d(x, 3, 1, 1); + branch_pool = this->branch_pool->forward(branch_pool); + + return torch::cat({branch1x1, branch3x3, branch3x3dbl, branch_pool}, 1); +} + +InceptionAuxImpl::InceptionAuxImpl(int64_t in_channels, int64_t num_classes) + : conv0(BasicConv2d(Options(in_channels, 128, 1))), + conv1(BasicConv2d(Options(128, 768, 5), 0.01)), + fc(768, num_classes) { + torch::nn::init::normal_( + fc->weight, + 0, + 0.001); // Note: used instead of truncated normal initialization + + register_module("conv0", conv0); + register_module("conv1", conv1); + register_module("fc", fc); +} + +torch::Tensor InceptionAuxImpl::forward(torch::Tensor x) { + // N x 768 x 17 x 17 + x = torch::avg_pool2d(x, 5, 3); + // N x 768 x 5 x 5 + x = conv0->forward(x); + // N x 128 x 5 x 5 + x = conv1->forward(x); + // N x 768 x 1 x 1 + x = torch::adaptive_avg_pool2d(x, {1, 1}); + // N x 768 x 1 x 1 + x = x.view({x.size(0), -1}); + // N x 768 + x 
= fc->forward(x); + // N x 1000 (num_classes) + return x; +} + +} // namespace _inceptionimpl + +InceptionV3Impl::InceptionV3Impl( + int64_t num_classes, + bool aux_logits, + bool transform_input) + : aux_logits(aux_logits), transform_input(transform_input) { + Conv2d_1a_3x3 = _inceptionimpl::BasicConv2d(Options(3, 32, 3).stride(2)); + Conv2d_2a_3x3 = _inceptionimpl::BasicConv2d(Options(32, 32, 3)); + Conv2d_2b_3x3 = _inceptionimpl::BasicConv2d(Options(32, 64, 3).padding(1)); + Conv2d_3b_1x1 = _inceptionimpl::BasicConv2d(Options(64, 80, 1)); + Conv2d_4a_3x3 = _inceptionimpl::BasicConv2d(Options(80, 192, 3)); + + Mixed_5b = _inceptionimpl::InceptionA(192, 32); + Mixed_5c = _inceptionimpl::InceptionA(256, 64); + Mixed_5d = _inceptionimpl::InceptionA(288, 64); + + Mixed_6a = _inceptionimpl::InceptionB(288); + Mixed_6b = _inceptionimpl::InceptionC(768, 128); + Mixed_6c = _inceptionimpl::InceptionC(768, 160); + Mixed_6d = _inceptionimpl::InceptionC(768, 160); + Mixed_6e = _inceptionimpl::InceptionC(768, 192); + + if (aux_logits) + AuxLogits = _inceptionimpl::InceptionAux(768, num_classes); + + Mixed_7a = _inceptionimpl::InceptionD(768); + Mixed_7b = _inceptionimpl::InceptionE(1280); + Mixed_7c = _inceptionimpl::InceptionE(2048); + + fc = torch::nn::Linear(2048, num_classes); + torch::nn::init::normal_( + fc->weight, + 0, + 0.1); // Note: used instead of truncated normal initialization + + register_module("Conv2d_1a_3x3", Conv2d_1a_3x3); + register_module("Conv2d_2a_3x3", Conv2d_2a_3x3); + register_module("Conv2d_2b_3x3", Conv2d_2b_3x3); + register_module("Conv2d_3b_1x1", Conv2d_3b_1x1); + register_module("Conv2d_4a_3x3", Conv2d_4a_3x3); + register_module("Mixed_5b", Mixed_5b); + register_module("Mixed_5c", Mixed_5c); + register_module("Mixed_5d", Mixed_5d); + register_module("Mixed_6a", Mixed_6a); + register_module("Mixed_6b", Mixed_6b); + register_module("Mixed_6c", Mixed_6c); + register_module("Mixed_6d", Mixed_6d); + register_module("Mixed_6e", Mixed_6e); + + if 
(!AuxLogits.is_empty()) + register_module("AuxLogits", AuxLogits); + + register_module("Mixed_7a", Mixed_7a); + register_module("Mixed_7b", Mixed_7b); + register_module("Mixed_7c", Mixed_7c); + register_module("fc", fc); +} + +InceptionV3Output InceptionV3Impl::forward(torch::Tensor x) { + if (transform_input) { + auto x_ch0 = torch::unsqueeze(x.select(1, 0), 1) * (0.229 / 0.5) + + (0.485 - 0.5) / 0.5; + auto x_ch1 = torch::unsqueeze(x.select(1, 1), 1) * (0.224 / 0.5) + + (0.456 - 0.5) / 0.5; + auto x_ch2 = torch::unsqueeze(x.select(1, 2), 1) * (0.225 / 0.5) + + (0.406 - 0.5) / 0.5; + + x = torch::cat({x_ch0, x_ch1, x_ch2}, 1); + } + + // N x 3 x 299 x 299 + x = Conv2d_1a_3x3->forward(x); + // N x 32 x 149 x 149 + x = Conv2d_2a_3x3->forward(x); + // N x 32 x 147 x 147 + x = Conv2d_2b_3x3->forward(x); + // N x 64 x 147 x 147 + x = torch::max_pool2d(x, 3, 2); + // N x 64 x 73 x 73 + x = Conv2d_3b_1x1->forward(x); + // N x 80 x 73 x 73 + x = Conv2d_4a_3x3->forward(x); + // N x 192 x 71 x 71 + x = torch::max_pool2d(x, 3, 2); + // N x 192 x 35 x 35 + x = Mixed_5b->forward(x); + // N x 256 x 35 x 35 + x = Mixed_5c->forward(x); + // N x 288 x 35 x 35 + x = Mixed_5d->forward(x); + // N x 288 x 35 x 35 + x = Mixed_6a->forward(x); + // N x 768 x 17 x 17 + x = Mixed_6b->forward(x); + // N x 768 x 17 x 17 + x = Mixed_6c->forward(x); + // N x 768 x 17 x 17 + x = Mixed_6d->forward(x); + // N x 768 x 17 x 17 + x = Mixed_6e->forward(x); + // N x 768 x 17 x 17 + + torch::Tensor aux; + if (is_training() && aux_logits) + aux = AuxLogits->forward(x); + + // N x 768 x 17 x 17 + x = Mixed_7a->forward(x); + // N x 1280 x 8 x 8 + x = Mixed_7b->forward(x); + // N x 2048 x 8 x 8 + x = Mixed_7c->forward(x); + // N x 2048 x 8 x 8 + x = torch::adaptive_avg_pool2d(x, {1, 1}); + // N x 2048 x 1 x 1 + x = torch::dropout(x, 0.5, is_training()); + // N x 2048 x 1 x 1 + x = x.view({x.size(0), -1}); + // N x 2048 + x = fc->forward(x); + // N x 1000 (num_classes) + + if (is_training() && aux_logits) + 
return {x, aux}; + return {x, {}}; +} + +// namespace _inceptionimpl +} // namespace models +} // namespace vision diff --git a/torchvision/csrc/models/inception.h b/torchvision/csrc/models/inception.h new file mode 100644 index 00000000000..213152f98dc --- /dev/null +++ b/torchvision/csrc/models/inception.h @@ -0,0 +1,125 @@ +#ifndef INCEPTION_H +#define INCEPTION_H + +#include + +namespace vision { +namespace models { +namespace _inceptionimpl { +struct BasicConv2dImpl : torch::nn::Module { + torch::nn::Conv2d conv{nullptr}; + torch::nn::BatchNorm bn{nullptr}; + + BasicConv2dImpl(torch::nn::Conv2dOptions options, double std_dev = 0.1); + + torch::Tensor forward(torch::Tensor x); +}; + +TORCH_MODULE(BasicConv2d); + +struct InceptionAImpl : torch::nn::Module { + BasicConv2d branch1x1, branch5x5_1, branch5x5_2, branch3x3dbl_1, + branch3x3dbl_2, branch3x3dbl_3, branch_pool; + + InceptionAImpl(int64_t in_channels, int64_t pool_features); + + torch::Tensor forward(torch::Tensor x); +}; + +struct InceptionBImpl : torch::nn::Module { + BasicConv2d branch3x3, branch3x3dbl_1, branch3x3dbl_2, branch3x3dbl_3; + + InceptionBImpl(int64_t in_channels); + + torch::Tensor forward(torch::Tensor x); +}; + +struct InceptionCImpl : torch::nn::Module { + BasicConv2d branch1x1{nullptr}, branch7x7_1{nullptr}, branch7x7_2{nullptr}, + branch7x7_3{nullptr}, branch7x7dbl_1{nullptr}, branch7x7dbl_2{nullptr}, + branch7x7dbl_3{nullptr}, branch7x7dbl_4{nullptr}, branch7x7dbl_5{nullptr}, + branch_pool{nullptr}; + + InceptionCImpl(int64_t in_channels, int64_t channels_7x7); + + torch::Tensor forward(torch::Tensor x); +}; + +struct InceptionDImpl : torch::nn::Module { + BasicConv2d branch3x3_1, branch3x3_2, branch7x7x3_1, branch7x7x3_2, + branch7x7x3_3, branch7x7x3_4; + + InceptionDImpl(int64_t in_channels); + + torch::Tensor forward(torch::Tensor x); +}; + +struct InceptionEImpl : torch::nn::Module { + BasicConv2d branch1x1, branch3x3_1, branch3x3_2a, branch3x3_2b, + branch3x3dbl_1, 
branch3x3dbl_2, branch3x3dbl_3a, branch3x3dbl_3b, + branch_pool; + + InceptionEImpl(int64_t in_channels); + + torch::Tensor forward(torch::Tensor x); +}; + +struct InceptionAuxImpl : torch::nn::Module { + BasicConv2d conv0; + BasicConv2d conv1; + torch::nn::Linear fc; + + InceptionAuxImpl(int64_t in_channels, int64_t num_classes); + + torch::Tensor forward(torch::Tensor x); +}; + +TORCH_MODULE(InceptionA); +TORCH_MODULE(InceptionB); +TORCH_MODULE(InceptionC); +TORCH_MODULE(InceptionD); +TORCH_MODULE(InceptionE); +TORCH_MODULE(InceptionAux); + +} // namespace _inceptionimpl + +struct InceptionV3Output { + torch::Tensor output; + torch::Tensor aux; +}; + +// Inception v3 model architecture from +//"Rethinking the Inception Architecture for Computer Vision" +// +struct InceptionV3Impl : torch::nn::Module { + bool aux_logits, transform_input; + + _inceptionimpl::BasicConv2d Conv2d_1a_3x3{nullptr}, Conv2d_2a_3x3{nullptr}, + Conv2d_2b_3x3{nullptr}, Conv2d_3b_1x1{nullptr}, Conv2d_4a_3x3{nullptr}; + + _inceptionimpl::InceptionA Mixed_5b{nullptr}, Mixed_5c{nullptr}, + Mixed_5d{nullptr}; + _inceptionimpl::InceptionB Mixed_6a{nullptr}; + _inceptionimpl::InceptionC Mixed_6b{nullptr}, Mixed_6c{nullptr}, + Mixed_6d{nullptr}, Mixed_6e{nullptr}; + _inceptionimpl::InceptionD Mixed_7a{nullptr}; + _inceptionimpl::InceptionE Mixed_7b{nullptr}, Mixed_7c{nullptr}; + + torch::nn::Linear fc{nullptr}; + + _inceptionimpl::InceptionAux AuxLogits{nullptr}; + + InceptionV3Impl( + int64_t num_classes = 1000, + bool aux_logits = true, + bool transform_input = false); + + InceptionV3Output forward(torch::Tensor x); +}; + +TORCH_MODULE(InceptionV3); + +} // namespace models +} // namespace vision + +#endif // INCEPTION_H diff --git a/torchvision/csrc/models/mobilenet.cpp b/torchvision/csrc/models/mobilenet.cpp new file mode 100644 index 00000000000..e38b3d75594 --- /dev/null +++ b/torchvision/csrc/models/mobilenet.cpp @@ -0,0 +1,135 @@ +#include "mobilenet.h" + +#include "modelsimpl.h" + 
+namespace vision { +namespace models { +using Options = torch::nn::Conv2dOptions; + +struct ConvBNReLUImpl : torch::nn::SequentialImpl { + ConvBNReLUImpl( + int64_t in_planes, + int64_t out_planes, + int64_t kernel_size = 3, + int64_t stride = 1, + int64_t groups = 1) { + auto padding = (kernel_size - 1) / 2; + + push_back(torch::nn::Conv2d(Options(in_planes, out_planes, kernel_size) + .stride(stride) + .padding(padding) + .groups(groups) + .with_bias(false))); + push_back(torch::nn::BatchNorm(out_planes)); + push_back(torch::nn::Functional(modelsimpl::relu6_)); + } + + torch::Tensor forward(torch::Tensor x) { + return torch::nn::SequentialImpl::forward(x); + } +}; + +TORCH_MODULE(ConvBNReLU); + +struct MobileNetInvertedResidualImpl : torch::nn::Module { + int64_t stride; + bool use_res_connect; + torch::nn::Sequential conv; + + MobileNetInvertedResidualImpl( + int64_t input, + int64_t output, + int64_t stride, + double expand_ratio) + : stride(stride), use_res_connect(stride == 1 && input == output) { + auto double_compare = [](double a, double b) { + return double(std::abs(a - b)) < std::numeric_limits::epsilon(); + }; + + assert(stride == 1 || stride == 2); + auto hidden_dim = int64_t(std::round(input * expand_ratio)); + + if (!double_compare(expand_ratio, 1)) + conv->push_back(ConvBNReLU(input, hidden_dim, 1)); + + conv->push_back(ConvBNReLU(hidden_dim, hidden_dim, 3, stride, hidden_dim)); + conv->push_back(torch::nn::Conv2d( + Options(hidden_dim, output, 1).stride(1).padding(0).with_bias(false))); + conv->push_back(torch::nn::BatchNorm(output)); + + register_module("conv", conv); + } + + torch::Tensor forward(torch::Tensor x) { + if (use_res_connect) + return x + conv->forward(x); + return conv->forward(x); + } +}; + +TORCH_MODULE(MobileNetInvertedResidual); + +MobileNetV2Impl::MobileNetV2Impl(int64_t num_classes, double width_mult) { + using Block = MobileNetInvertedResidual; + int64_t input_channel = 32; + int64_t last_channel = 1280; + + std::vector> 
inverted_residual_settings = { + // t, c, n, s + {1, 16, 1, 1}, + {6, 24, 2, 2}, + {6, 32, 3, 2}, + {6, 64, 4, 2}, + {6, 96, 3, 1}, + {6, 160, 3, 2}, + {6, 320, 1, 1}, + }; + + input_channel = int64_t(input_channel * width_mult); + this->last_channel = int64_t(last_channel * std::max(1.0, width_mult)); + features->push_back(ConvBNReLU(3, input_channel, 3, 2)); + + for (auto setting : inverted_residual_settings) { + auto output_channel = int64_t(setting[1] * width_mult); + + for (int64_t i = 0; i < setting[2]; ++i) { + auto stride = i == 0 ? setting[3] : 1; + features->push_back( + Block(input_channel, output_channel, stride, setting[0])); + input_channel = output_channel; + } + } + + features->push_back(ConvBNReLU(input_channel, this->last_channel, 1)); + + classifier->push_back(torch::nn::Dropout(0.2)); + classifier->push_back(torch::nn::Linear(this->last_channel, num_classes)); + + register_module("features", features); + register_module("classifier", classifier); + + for (auto& module : modules(/*include_self=*/false)) { + if (auto M = dynamic_cast<torch::nn::Conv2dImpl*>(module.get())) { + torch::nn::init::kaiming_normal_( + M->weight, 0, torch::nn::init::FanMode::FanOut); + if (M->options.with_bias()) + torch::nn::init::zeros_(M->bias); + } else if (auto M = dynamic_cast<torch::nn::BatchNormImpl*>(module.get())) { + torch::nn::init::ones_(M->weight); + torch::nn::init::zeros_(M->bias); + } else if (auto M = dynamic_cast<torch::nn::LinearImpl*>(module.get())) { + torch::nn::init::normal_(M->weight, 0, 0.01); + torch::nn::init::zeros_(M->bias); + } + } +} + +torch::Tensor MobileNetV2Impl::forward(at::Tensor x) { + x = features->forward(x); + x = x.mean({2, 3}); + x = classifier->forward(x); + return x; +} + +} // namespace models +} // namespace vision diff --git a/torchvision/csrc/models/mobilenet.h b/torchvision/csrc/models/mobilenet.h new file mode 100644 index 00000000000..6cb296d820a --- /dev/null +++ b/torchvision/csrc/models/mobilenet.h @@ -0,0 +1,21 @@ +#ifndef MOBILENET_H +#define MOBILENET_H + +#include <torch/torch.h> + +namespace vision {
+namespace models { +struct MobileNetV2Impl : torch::nn::Module { + int64_t last_channel; + torch::nn::Sequential features, classifier; + + MobileNetV2Impl(int64_t num_classes = 1000, double width_mult = 1.0); + + torch::Tensor forward(torch::Tensor x); +}; + +TORCH_MODULE(MobileNetV2); +} // namespace models +} // namespace vision + +#endif // MOBILENET_H diff --git a/torchvision/csrc/models/models.h b/torchvision/csrc/models/models.h new file mode 100644 index 00000000000..55127668be5 --- /dev/null +++ b/torchvision/csrc/models/models.h @@ -0,0 +1,14 @@ +#ifndef MODELS_H +#define MODELS_H + +#include "alexnet.h" +#include "densenet.h" +#include "googlenet.h" +#include "inception.h" +#include "mobilenet.h" +#include "resnet.h" +#include "shufflenetv2.h" +#include "squeezenet.h" +#include "vgg.h" + +#endif // MODELS_H diff --git a/torchvision/csrc/models/modelsimpl.h b/torchvision/csrc/models/modelsimpl.h new file mode 100644 index 00000000000..5d0c9f3467d --- /dev/null +++ b/torchvision/csrc/models/modelsimpl.h @@ -0,0 +1,42 @@ +#ifndef MODELSIMPL_H +#define MODELSIMPL_H + +#include + +namespace vision { +namespace models { +namespace modelsimpl { + +// TODO here torch::relu_ and torch::adaptive_avg_pool2d wrapped in +// torch::nn::Fuctional don't work. 
so keeping these for now + +inline torch::Tensor& relu_(torch::Tensor x) { + return torch::relu_(x); +} + +inline torch::Tensor relu6_(torch::Tensor x) { + return torch::clamp_(x, 0, 6); +} + +inline torch::Tensor adaptive_avg_pool2d( + torch::Tensor x, + torch::ExpandingArray<2> output_size) { + return torch::adaptive_avg_pool2d(x, output_size); +} + +inline torch::Tensor max_pool2d( + torch::Tensor x, + torch::ExpandingArray<2> kernel_size, + torch::ExpandingArray<2> stride) { + return torch::max_pool2d(x, kernel_size, stride); +} + +inline bool double_compare(double a, double b) { + return double(std::abs(a - b)) < std::numeric_limits::epsilon(); +}; + +} // namespace modelsimpl +} // namespace models +} // namespace vision + +#endif // MODELSIMPL_H diff --git a/torchvision/csrc/models/resnet.cpp b/torchvision/csrc/models/resnet.cpp new file mode 100644 index 00000000000..da9a2e25f25 --- /dev/null +++ b/torchvision/csrc/models/resnet.cpp @@ -0,0 +1,149 @@ +#include "resnet.h" + +namespace vision { +namespace models { +namespace _resnetimpl { +torch::nn::Conv2d conv3x3( + int64_t in, + int64_t out, + int64_t stride, + int64_t groups) { + torch::nn::Conv2dOptions O(in, out, 3); + O.padding(1).stride(stride).groups(groups).with_bias(false); + return torch::nn::Conv2d(O); +} + +torch::nn::Conv2d conv1x1(int64_t in, int64_t out, int64_t stride) { + torch::nn::Conv2dOptions O(in, out, 1); + O.stride(stride).with_bias(false); + return torch::nn::Conv2d(O); +} + +int BasicBlock::expansion = 1; +int Bottleneck::expansion = 4; + +BasicBlock::BasicBlock( + int64_t inplanes, + int64_t planes, + int64_t stride, + torch::nn::Sequential downsample, + int64_t groups, + int64_t base_width) + : stride(stride), downsample(downsample) { + if (groups != 1 or base_width != 64) { + std::cerr << "BasicBlock only supports groups=1 and base_width=64" + << std::endl; + assert(false); + } + + // Both conv1 and downsample layers downsample the input when stride != 1 + conv1 = 
conv3x3(inplanes, planes, stride); + conv2 = conv3x3(planes, planes); + + bn1 = torch::nn::BatchNorm(planes); + bn2 = torch::nn::BatchNorm(planes); + + register_module("conv1", conv1); + register_module("conv2", conv2); + + register_module("bn1", bn1); + register_module("bn2", bn2); + + if (!downsample.is_empty()) + register_module("downsample", this->downsample); +} + +Bottleneck::Bottleneck( + int64_t inplanes, + int64_t planes, + int64_t stride, + torch::nn::Sequential downsample, + int64_t groups, + int64_t base_width) + : stride(stride), downsample(downsample) { + auto width = int64_t(planes * (base_width / 64.)) * groups; + + // Both conv2 and downsample layers downsample the input when stride != 1 + conv1 = conv1x1(inplanes, width); + conv2 = conv3x3(width, width, stride, groups); + conv3 = conv1x1(width, planes * expansion); + + bn1 = torch::nn::BatchNorm(width); + bn2 = torch::nn::BatchNorm(width); + bn3 = torch::nn::BatchNorm(planes * expansion); + + register_module("conv1", conv1); + register_module("conv2", conv2); + register_module("conv3", conv3); + + register_module("bn1", bn1); + register_module("bn2", bn2); + register_module("bn3", bn3); + + if (!downsample.is_empty()) + register_module("downsample", this->downsample); +} + +torch::Tensor Bottleneck::forward(torch::Tensor X) { + auto identity = X; + + auto out = conv1->forward(X); + out = bn1->forward(out).relu_(); + + out = conv2->forward(out); + out = bn2->forward(out).relu_(); + + out = conv3->forward(out); + out = bn3->forward(out); + + if (!downsample.is_empty()) + identity = downsample->forward(X); + + out += identity; + return out.relu_(); +} + +torch::Tensor BasicBlock::forward(torch::Tensor x) { + auto identity = x; + + auto out = conv1->forward(x); + out = bn1->forward(out).relu_(); + + out = conv2->forward(out); + out = bn2->forward(out); + + if (!downsample.is_empty()) + identity = downsample->forward(x); + + out += identity; + return out.relu_(); +} +} // namespace _resnetimpl + 
+ResNet18Impl::ResNet18Impl(int64_t num_classes, bool zero_init_residual) + : ResNetImpl({2, 2, 2, 2}, num_classes, zero_init_residual) {} + +ResNet34Impl::ResNet34Impl(int64_t num_classes, bool zero_init_residual) + : ResNetImpl({3, 4, 6, 3}, num_classes, zero_init_residual) {} + +ResNet50Impl::ResNet50Impl(int64_t num_classes, bool zero_init_residual) + : ResNetImpl({3, 4, 6, 3}, num_classes, zero_init_residual) {} + +ResNet101Impl::ResNet101Impl(int64_t num_classes, bool zero_init_residual) + : ResNetImpl({3, 4, 23, 3}, num_classes, zero_init_residual) {} + +ResNet152Impl::ResNet152Impl(int64_t num_classes, bool zero_init_residual) + : ResNetImpl({3, 8, 36, 3}, num_classes, zero_init_residual) {} + +ResNext50_32x4dImpl::ResNext50_32x4dImpl( + int64_t num_classes, + bool zero_init_residual) + : ResNetImpl({3, 4, 6, 3}, num_classes, zero_init_residual, 32, 4) {} + +ResNext101_32x8dImpl::ResNext101_32x8dImpl( + int64_t num_classes, + bool zero_init_residual) + : ResNetImpl({3, 4, 23, 3}, num_classes, zero_init_residual, 32, 8) {} + +} // namespace models +} // namespace vision diff --git a/torchvision/csrc/models/resnet.h b/torchvision/csrc/models/resnet.h new file mode 100644 index 00000000000..a01c2da7dc5 --- /dev/null +++ b/torchvision/csrc/models/resnet.h @@ -0,0 +1,235 @@ +#ifndef RESNET_H +#define RESNET_H + +#include + +namespace vision { +namespace models { +template +struct ResNetImpl; + +namespace _resnetimpl { +// 3x3 convolution with padding +torch::nn::Conv2d conv3x3( + int64_t in, + int64_t out, + int64_t stride = 1, + int64_t groups = 1); + +// 1x1 convolution +torch::nn::Conv2d conv1x1(int64_t in, int64_t out, int64_t stride = 1); + +struct BasicBlock : torch::nn::Module { + template + friend struct vision::models::ResNetImpl; + + int64_t stride; + torch::nn::Sequential downsample; + + torch::nn::Conv2d conv1{nullptr}, conv2{nullptr}; + torch::nn::BatchNorm bn1{nullptr}, bn2{nullptr}; + + static int expansion; + + BasicBlock( + int64_t inplanes, + 
int64_t planes, + int64_t stride = 1, + torch::nn::Sequential downsample = nullptr, + int64_t groups = 1, + int64_t base_width = 64); + + torch::Tensor forward(torch::Tensor x); +}; + +struct Bottleneck : torch::nn::Module { + template + friend struct vision::models::ResNetImpl; + + int64_t stride; + torch::nn::Sequential downsample; + + torch::nn::Conv2d conv1{nullptr}, conv2{nullptr}, conv3{nullptr}; + torch::nn::BatchNorm bn1{nullptr}, bn2{nullptr}, bn3{nullptr}; + + static int expansion; + + Bottleneck( + int64_t inplanes, + int64_t planes, + int64_t stride = 1, + torch::nn::Sequential downsample = nullptr, + int64_t groups = 1, + int64_t base_width = 64); + + torch::Tensor forward(torch::Tensor X); +}; +} // namespace _resnetimpl + +template +struct ResNetImpl : torch::nn::Module { + int64_t groups, base_width, inplanes; + torch::nn::Conv2d conv1; + torch::nn::BatchNorm bn1; + torch::nn::Linear fc; + torch::nn::Sequential layer1, layer2, layer3, layer4; + + torch::nn::Sequential _make_layer( + int64_t planes, + int64_t blocks, + int64_t stride = 1); + + ResNetImpl( + const std::vector& layers, + int64_t num_classes = 1000, + bool zero_init_residual = false, + int64_t groups = 1, + int64_t width_per_group = 64); + + torch::Tensor forward(torch::Tensor X); +}; + +template +torch::nn::Sequential ResNetImpl::_make_layer( + int64_t planes, + int64_t blocks, + int64_t stride) { + torch::nn::Sequential downsample = nullptr; + if (stride != 1 || inplanes != planes * Block::expansion) { + downsample = torch::nn::Sequential( + _resnetimpl::conv1x1(inplanes, planes * Block::expansion, stride), + torch::nn::BatchNorm(planes * Block::expansion)); + } + + torch::nn::Sequential layers; + layers->push_back( + Block(inplanes, planes, stride, downsample, groups, base_width)); + + inplanes = planes * Block::expansion; + + for (int i = 1; i < blocks; ++i) + layers->push_back(Block(inplanes, planes, 1, nullptr, groups, base_width)); + + return layers; +} + +template 
+ResNetImpl::ResNetImpl( + const std::vector& layers, + int64_t num_classes, + bool zero_init_residual, + int64_t groups, + int64_t width_per_group) + : groups(groups), + base_width(width_per_group), + inplanes(64), + conv1(torch::nn::Conv2dOptions(3, 64, 7).stride(2).padding(3).with_bias( + false)), + bn1(64), + layer1(_make_layer(64, layers[0])), + layer2(_make_layer(128, layers[1], 2)), + layer3(_make_layer(256, layers[2], 2)), + layer4(_make_layer(512, layers[3], 2)), + fc(512 * Block::expansion, num_classes) { + register_module("conv1", conv1); + register_module("bn1", bn1); + register_module("fc", fc); + + register_module("layer1", layer1); + register_module("layer2", layer2); + register_module("layer3", layer3); + register_module("layer4", layer4); + + for (auto& module : modules(/*include_self=*/false)) { + if (auto M = dynamic_cast(module.get())) + torch::nn::init::kaiming_normal_( + M->weight, + /*a=*/0, + torch::nn::init::FanMode::FanOut, + torch::nn::init::Nonlinearity::ReLU); + else if (auto M = dynamic_cast(module.get())) { + torch::nn::init::constant_(M->weight, 1); + torch::nn::init::constant_(M->bias, 0); + } + } + + // Zero-initialize the last BN in each residual branch, so that the residual + // branch starts with zeros, and each residual block behaves like an + // identity. 
This improves the model by 0.2~0.3% according to + // https://arxiv.org/abs/1706.02677 + if (zero_init_residual) + for (auto& module : modules(/*include_self=*/false)) { + if (auto* M = dynamic_cast<_resnetimpl::Bottleneck*>(module.get())) + torch::nn::init::constant_(M->bn3->weight, 0); + else if (auto* M = dynamic_cast<_resnetimpl::BasicBlock*>(module.get())) + torch::nn::init::constant_(M->bn2->weight, 0); + } +} + +template +torch::Tensor ResNetImpl::forward(torch::Tensor x) { + x = conv1->forward(x); + x = bn1->forward(x).relu_(); + x = torch::max_pool2d(x, 3, 2, 1); + + x = layer1->forward(x); + x = layer2->forward(x); + x = layer3->forward(x); + x = layer4->forward(x); + + x = torch::adaptive_avg_pool2d(x, {1, 1}); + x = x.reshape({x.size(0), -1}); + x = fc->forward(x); + + return x; +} + +struct ResNet18Impl : ResNetImpl<_resnetimpl::BasicBlock> { + ResNet18Impl(int64_t num_classes = 1000, bool zero_init_residual = false); +}; + +struct ResNet34Impl : ResNetImpl<_resnetimpl::BasicBlock> { + ResNet34Impl(int64_t num_classes = 1000, bool zero_init_residual = false); +}; + +struct ResNet50Impl : ResNetImpl<_resnetimpl::Bottleneck> { + ResNet50Impl(int64_t num_classes = 1000, bool zero_init_residual = false); +}; + +struct ResNet101Impl : ResNetImpl<_resnetimpl::Bottleneck> { + ResNet101Impl(int64_t num_classes = 1000, bool zero_init_residual = false); +}; + +struct ResNet152Impl : ResNetImpl<_resnetimpl::Bottleneck> { + ResNet152Impl(int64_t num_classes = 1000, bool zero_init_residual = false); +}; + +struct ResNext50_32x4dImpl : ResNetImpl<_resnetimpl::Bottleneck> { + ResNext50_32x4dImpl( + int64_t num_classes = 1000, + bool zero_init_residual = false); +}; + +struct ResNext101_32x8dImpl : ResNetImpl<_resnetimpl::Bottleneck> { + ResNext101_32x8dImpl( + int64_t num_classes = 1000, + bool zero_init_residual = false); +}; + +template +struct ResNet : torch::nn::ModuleHolder> { + using torch::nn::ModuleHolder>::ModuleHolder; +}; + +TORCH_MODULE(ResNet18); 
+TORCH_MODULE(ResNet34); +TORCH_MODULE(ResNet50); +TORCH_MODULE(ResNet101); +TORCH_MODULE(ResNet152); +TORCH_MODULE(ResNext50_32x4d); +TORCH_MODULE(ResNext101_32x8d); + +} // namespace models +} // namespace vision + +#endif // RESNET_H diff --git a/torchvision/csrc/models/shufflenetv2.cpp b/torchvision/csrc/models/shufflenetv2.cpp new file mode 100644 index 00000000000..79fad1a4a13 --- /dev/null +++ b/torchvision/csrc/models/shufflenetv2.cpp @@ -0,0 +1,185 @@ +#include "shufflenetv2.h" + +#include "modelsimpl.h" + +namespace vision { +namespace models { + +using Options = torch::nn::Conv2dOptions; + +torch::Tensor channel_shuffle(torch::Tensor x, int64_t groups) { + auto shape = x.sizes(); + auto batchsize = shape[0]; + auto num_channels = shape[1]; + auto height = shape[2]; + auto width = shape[3]; + + auto channels_per_group = num_channels / groups; + + x = x.view({batchsize, groups, channels_per_group, height, width}); + x = torch::transpose(x, 1, 2).contiguous(); + x = x.view({batchsize, -1, height, width}); + + return x; +} + +torch::nn::Conv2d conv11(int64_t input, int64_t output) { + Options opts(input, output, 1); + opts = opts.stride(1).padding(0).with_bias(false); + return torch::nn::Conv2d(opts); +} + +torch::nn::Conv2d conv33(int64_t input, int64_t output, int64_t stride) { + Options opts(input, output, 3); + opts = opts.stride(stride).padding(1).with_bias(false).groups(input); + return torch::nn::Conv2d(opts); +} + +struct ShuffleNetV2InvertedResidualImpl : torch::nn::Module { + int64_t stride; + torch::nn::Sequential branch1{nullptr}, branch2{nullptr}; + + ShuffleNetV2InvertedResidualImpl(int64_t inp, int64_t oup, int64_t stride) + : stride(stride) { + if (stride < 1 || stride > 3) { + std::cerr << "illegal stride value'" << std::endl; + assert(false); + } + + auto branch_features = oup / 2; + assert(stride != 1 || inp == branch_features << 1); + + if (stride > 1) { + branch1 = torch::nn::Sequential( + conv33(inp, inp, stride), + 
torch::nn::BatchNorm(inp), + conv11(inp, branch_features), + torch::nn::BatchNorm(branch_features), + torch::nn::Functional(modelsimpl::relu_)); + } + + branch2 = torch::nn::Sequential( + conv11(stride > 1 ? inp : branch_features, branch_features), + torch::nn::BatchNorm(branch_features), + torch::nn::Functional(modelsimpl::relu_), + conv33(branch_features, branch_features, stride), + torch::nn::BatchNorm(branch_features), + conv11(branch_features, branch_features), + torch::nn::BatchNorm(branch_features), + torch::nn::Functional(modelsimpl::relu_)); + + if (!branch1.is_empty()) + register_module("branch1", branch1); + + register_module("branch2", branch2); + } + + torch::Tensor forward(torch::Tensor x) { + torch::Tensor out; + + if (stride == 1) { + auto chunks = x.chunk(2, 1); + out = torch::cat({chunks[0], branch2->forward(chunks[1])}, 1); + } else + out = torch::cat({branch1->forward(x), branch2->forward(x)}, 1); + + out = channel_shuffle(out, 2); + return out; + } +}; + +TORCH_MODULE(ShuffleNetV2InvertedResidual); + +ShuffleNetV2Impl::ShuffleNetV2Impl( + const std::vector& stage_repeats, + const std::vector& stage_out_channels, + int64_t num_classes) { + if (stage_repeats.size() != 3) { + std::cerr << "expected stage_repeats as vector of 3 positive ints" + << std::endl; + assert(false); + } + + if (stage_out_channels.size() != 5) { + std::cerr << "expected stage_out_channels as vector of 5 positive ints" + << std::endl; + assert(false); + } + + _stage_out_channels = stage_out_channels; + int64_t input_channels = 3; + auto output_channels = _stage_out_channels[0]; + + conv1 = torch::nn::Sequential( + torch::nn::Conv2d(Options(input_channels, output_channels, 3) + .stride(2) + .padding(1) + .with_bias(false)), + torch::nn::BatchNorm(output_channels), + torch::nn::Functional(modelsimpl::relu_)); + + input_channels = output_channels; + std::vector stages = {stage2, stage3, stage4}; + + for (size_t i = 0; i < stages.size(); ++i) { + auto& seq = stages[i]; + auto 
repeats = stage_repeats[i]; + auto output_channels = _stage_out_channels[i + 1]; + + seq->push_back( + ShuffleNetV2InvertedResidual(input_channels, output_channels, 2)); + + for (size_t j = 0; j < size_t(repeats - 1); ++j) + seq->push_back( + ShuffleNetV2InvertedResidual(output_channels, output_channels, 1)); + + input_channels = output_channels; + } + + output_channels = _stage_out_channels.back(); + conv5 = torch::nn::Sequential( + torch::nn::Conv2d(Options(input_channels, output_channels, 1) + .stride(1) + .padding(0) + .with_bias(false)), + torch::nn::BatchNorm(output_channels), + torch::nn::Functional(modelsimpl::relu_)); + + fc = torch::nn::Linear(output_channels, num_classes); + + register_module("conv1", conv1); + register_module("stage2", stage2); + register_module("stage3", stage3); + register_module("stage4", stage4); + register_module("conv2", conv5); + register_module("fc", fc); +} + +torch::Tensor ShuffleNetV2Impl::forward(torch::Tensor x) { + x = conv1->forward(x); + x = torch::max_pool2d(x, 3, 2, 1); + + x = stage2->forward(x); + x = stage3->forward(x); + x = stage4->forward(x); + x = conv5->forward(x); + + x = x.mean({2, 3}); + x = fc->forward(x); + return x; +} + +ShuffleNetV2_x0_5Impl::ShuffleNetV2_x0_5Impl(int64_t num_classes) + : ShuffleNetV2Impl({4, 8, 4}, {24, 48, 96, 192, 1024}, num_classes) {} + +ShuffleNetV2_x1_0Impl::ShuffleNetV2_x1_0Impl(int64_t num_classes) + : ShuffleNetV2Impl({4, 8, 4}, {24, 116, 232, 464, 1024}, num_classes) {} + +ShuffleNetV2_x1_5Impl::ShuffleNetV2_x1_5Impl(int64_t num_classes) + : ShuffleNetV2Impl({4, 8, 4}, {24, 176, 352, 704, 1024}, num_classes) {} + +ShuffleNetV2_x2_0Impl::ShuffleNetV2_x2_0Impl(int64_t num_classes) + : ShuffleNetV2Impl({4, 8, 4}, {24, 244, 488, 976, 2048}, num_classes) {} + +} // namespace models +} // namespace vision diff --git a/torchvision/csrc/models/shufflenetv2.h b/torchvision/csrc/models/shufflenetv2.h new file mode 100644 index 00000000000..ddb1bccbb14 --- /dev/null +++ 
b/torchvision/csrc/models/shufflenetv2.h @@ -0,0 +1,47 @@ +#ifndef SHUFFLENETV2_H +#define SHUFFLENETV2_H + +#include + +namespace vision { +namespace models { + +struct ShuffleNetV2Impl : torch::nn::Module { + std::vector _stage_out_channels; + torch::nn::Sequential conv1{nullptr}, stage2, stage3, stage4, conv5{nullptr}; + torch::nn::Linear fc{nullptr}; + + ShuffleNetV2Impl( + const std::vector& stage_repeats, + const std::vector& stage_out_channels, + int64_t num_classes = 1000); + + torch::Tensor forward(torch::Tensor x); +}; + +struct ShuffleNetV2_x0_5Impl : ShuffleNetV2Impl { + ShuffleNetV2_x0_5Impl(int64_t num_classes = 1000); +}; + +struct ShuffleNetV2_x1_0Impl : ShuffleNetV2Impl { + ShuffleNetV2_x1_0Impl(int64_t num_classes = 1000); +}; + +struct ShuffleNetV2_x1_5Impl : ShuffleNetV2Impl { + ShuffleNetV2_x1_5Impl(int64_t num_classes = 1000); +}; + +struct ShuffleNetV2_x2_0Impl : ShuffleNetV2Impl { + ShuffleNetV2_x2_0Impl(int64_t num_classes = 1000); +}; + +TORCH_MODULE(ShuffleNetV2); +TORCH_MODULE(ShuffleNetV2_x0_5); +TORCH_MODULE(ShuffleNetV2_x1_0); +TORCH_MODULE(ShuffleNetV2_x1_5); +TORCH_MODULE(ShuffleNetV2_x2_0); + +} // namespace models +} // namespace vision + +#endif // SHUFFLENETV2_H diff --git a/torchvision/csrc/models/squeezenet.cpp b/torchvision/csrc/models/squeezenet.cpp new file mode 100644 index 00000000000..300415c2396 --- /dev/null +++ b/torchvision/csrc/models/squeezenet.cpp @@ -0,0 +1,111 @@ +#include "squeezenet.h" + +#include +#include "modelsimpl.h" + +namespace vision { +namespace models { +struct Fire : torch::nn::Module { + torch::nn::Conv2d squeeze, expand1x1, expand3x3; + + Fire( + int64_t inplanes, + int64_t squeeze_planes, + int64_t expand1x1_planes, + int64_t expand3x3_planes) + : squeeze(torch::nn::Conv2dOptions(inplanes, squeeze_planes, 1)), + expand1x1( + torch::nn::Conv2dOptions(squeeze_planes, expand1x1_planes, 1)), + expand3x3(torch::nn::Conv2dOptions(squeeze_planes, expand3x3_planes, 3) + .padding(1)) { + 
register_module("squeeze", squeeze); + register_module("expand1x1", expand1x1); + register_module("expand3x3", expand3x3); + } + + torch::Tensor forward(torch::Tensor x) { + x = torch::relu(squeeze->forward(x)); + return torch::cat( + {torch::relu(expand1x1->forward(x)), + torch::relu(expand3x3->forward(x))}, + 1); + } +}; + +SqueezeNetImpl::SqueezeNetImpl(double version, int64_t num_classes) + : num_classes(num_classes) { + if (modelsimpl::double_compare(version, 1.0)) { + features = torch::nn::Sequential( + torch::nn::Conv2d(torch::nn::Conv2dOptions(3, 96, 7).stride(2)), + torch::nn::Functional(modelsimpl::relu_), + torch::nn::Functional(torch::max_pool2d, 3, 2, 0, 1, true), + Fire(96, 16, 64, 64), + Fire(128, 16, 64, 64), + Fire(128, 32, 128, 128), + torch::nn::Functional(torch::max_pool2d, 3, 2, 0, 1, true), + Fire(256, 32, 128, 128), + Fire(256, 48, 192, 192), + Fire(384, 48, 192, 192), + Fire(384, 64, 256, 256), + torch::nn::Functional(torch::max_pool2d, 3, 2, 0, 1, true), + Fire(512, 64, 256, 256)); + } else if (modelsimpl::double_compare(version, 1.1)) { + features = torch::nn::Sequential( + torch::nn::Conv2d(torch::nn::Conv2dOptions(3, 64, 3).stride(2)), + torch::nn::Functional(modelsimpl::relu_), + torch::nn::Functional(torch::max_pool2d, 3, 2, 0, 1, true), + Fire(64, 16, 64, 64), + Fire(128, 16, 64, 64), + torch::nn::Functional(torch::max_pool2d, 3, 2, 0, 1, true), + Fire(128, 32, 128, 128), + Fire(256, 32, 128, 128), + torch::nn::Functional(torch::max_pool2d, 3, 2, 0, 1, true), + Fire(256, 48, 192, 192), + Fire(384, 48, 192, 192), + Fire(384, 64, 256, 256), + Fire(512, 64, 256, 256)); + } else { + std::cerr << "Wrong version number is passed th SqueeseNet constructor!" 
+ << std::endl; + assert(false); + } + + // Final convolution is initialized differently from the rest + auto final_conv = + torch::nn::Conv2d(torch::nn::Conv2dOptions(512, num_classes, 1)); + + classifier = torch::nn::Sequential( + torch::nn::Dropout(0.5), + final_conv, + torch::nn::Functional(modelsimpl::relu_), + torch::nn::Functional(modelsimpl::adaptive_avg_pool2d, 1)); + + register_module("features", features); + register_module("classifier", classifier); + + for (auto& module : modules(/*include_self=*/false)) + if (auto M = dynamic_cast(module.get())) { + if (M == final_conv.get()) + torch::nn::init::normal_(M->weight, 0.0, 0.01); + else + torch::nn::init::kaiming_uniform_(M->weight); + + if (M->options.with_bias()) + torch::nn::init::constant_(M->bias, 0); + } +} + +torch::Tensor SqueezeNetImpl::forward(torch::Tensor x) { + x = features->forward(x); + x = classifier->forward(x); + return x.view({x.size(0), -1}); +} + +SqueezeNet1_0Impl::SqueezeNet1_0Impl(int64_t num_classes) + : SqueezeNetImpl(1.0, num_classes) {} + +SqueezeNet1_1Impl::SqueezeNet1_1Impl(int64_t num_classes) + : SqueezeNetImpl(1.1, num_classes) {} + +} // namespace models +} // namespace vision diff --git a/torchvision/csrc/models/squeezenet.h b/torchvision/csrc/models/squeezenet.h new file mode 100644 index 00000000000..ee3350b4bf3 --- /dev/null +++ b/torchvision/csrc/models/squeezenet.h @@ -0,0 +1,39 @@ +#ifndef SQUEEZENET_H +#define SQUEEZENET_H + +#include + +namespace vision { +namespace models { +struct SqueezeNetImpl : torch::nn::Module { + int64_t num_classes; + torch::nn::Sequential features{nullptr}, classifier{nullptr}; + + SqueezeNetImpl(double version = 1.0, int64_t num_classes = 1000); + + torch::Tensor forward(torch::Tensor x); +}; + +// SqueezeNet model architecture from the "SqueezeNet: AlexNet-level +// accuracy with 50x fewer parameters and <0.5MB model size" +// paper. 
+struct SqueezeNet1_0Impl : SqueezeNetImpl { + SqueezeNet1_0Impl(int64_t num_classes = 1000); +}; + +// SqueezeNet 1.1 model from the official SqueezeNet repo +// . +// SqueezeNet 1.1 has 2.4x less computation and slightly fewer parameters +// than SqueezeNet 1.0, without sacrificing accuracy. +struct SqueezeNet1_1Impl : SqueezeNetImpl { + SqueezeNet1_1Impl(int64_t num_classes = 1000); +}; + +TORCH_MODULE(SqueezeNet); +TORCH_MODULE(SqueezeNet1_0); +TORCH_MODULE(SqueezeNet1_1); + +} // namespace models +} // namespace vision + +#endif // SQUEEZENET_H diff --git a/torchvision/csrc/models/vgg.cpp b/torchvision/csrc/models/vgg.cpp new file mode 100644 index 00000000000..f0991b7b085 --- /dev/null +++ b/torchvision/csrc/models/vgg.cpp @@ -0,0 +1,114 @@ +#include "vgg.h" + +#include +#include "modelsimpl.h" + +namespace vision { +namespace models { +torch::nn::Sequential makeLayers( + const std::vector& cfg, + bool batch_norm = false) { + torch::nn::Sequential seq; + auto channels = 3; + + for (const auto& V : cfg) { + if (V <= -1) + seq->push_back(torch::nn::Functional(modelsimpl::max_pool2d, 2, 2)); + else { + seq->push_back(torch::nn::Conv2d( + torch::nn::Conv2dOptions(channels, V, 3).padding(1))); + + if (batch_norm) + seq->push_back(torch::nn::BatchNorm(V)); + seq->push_back(torch::nn::Functional(modelsimpl::relu_)); + + channels = V; + } + } + + return seq; +} + +void VGGImpl::_initialize_weights() { + for (auto& module : modules(/*include_self=*/false)) { + if (auto M = dynamic_cast(module.get())) { + torch::nn::init::kaiming_normal_( + M->weight, + /*a=*/0, + torch::nn::init::FanMode::FanOut, + torch::nn::init::Nonlinearity::ReLU); + torch::nn::init::constant_(M->bias, 0); + } else if (auto M = dynamic_cast(module.get())) { + torch::nn::init::constant_(M->weight, 1); + torch::nn::init::constant_(M->bias, 0); + } else if (auto M = dynamic_cast(module.get())) { + torch::nn::init::normal_(M->weight, 0, 0.01); + torch::nn::init::constant_(M->bias, 0); + } + } +} + 
+VGGImpl::VGGImpl( + torch::nn::Sequential features, + int64_t num_classes, + bool initialize_weights) { + classifier = torch::nn::Sequential( + torch::nn::Linear(512 * 7 * 7, 4096), + torch::nn::Functional(modelsimpl::relu_), + torch::nn::Dropout(), + torch::nn::Linear(4096, 4096), + torch::nn::Functional(modelsimpl::relu_), + torch::nn::Dropout(), + torch::nn::Linear(4096, num_classes)); + + this->features = features; + + register_module("features", this->features); + register_module("classifier", classifier); + + if (initialize_weights) + _initialize_weights(); +} + +torch::Tensor VGGImpl::forward(torch::Tensor x) { + x = features->forward(x); + x = torch::adaptive_avg_pool2d(x, {7, 7}); + x = x.view({x.size(0), -1}); + x = classifier->forward(x); + return x; +} + +// clang-format off +static std::unordered_map> cfg = { + {'A', {64, -1, 128, -1, 256, 256, -1, 512, 512, -1, 512, 512, -1}}, + {'B', {64, 64, -1, 128, 128, -1, 256, 256, -1, 512, 512, -1, 512, 512, -1}}, + {'D', {64, 64, -1, 128, 128, -1, 256, 256, 256, -1, 512, 512, 512, -1, 512, 512, 512, -1}}, + {'E', {64, 64, -1, 128, 128, -1, 256, 256, 256, 256, -1, 512, 512, 512, 512, -1, 512, 512, 512, 512, -1}}}; +// clang-format on + +VGG11Impl::VGG11Impl(int64_t num_classes, bool initialize_weights) + : VGGImpl(makeLayers(cfg['A']), num_classes, initialize_weights) {} + +VGG13Impl::VGG13Impl(int64_t num_classes, bool initialize_weights) + : VGGImpl(makeLayers(cfg['B']), num_classes, initialize_weights) {} + +VGG16Impl::VGG16Impl(int64_t num_classes, bool initialize_weights) + : VGGImpl(makeLayers(cfg['D']), num_classes, initialize_weights) {} + +VGG19Impl::VGG19Impl(int64_t num_classes, bool initialize_weights) + : VGGImpl(makeLayers(cfg['E']), num_classes, initialize_weights) {} + +VGG11BNImpl::VGG11BNImpl(int64_t num_classes, bool initialize_weights) + : VGGImpl(makeLayers(cfg['A'], true), num_classes, initialize_weights) {} + +VGG13BNImpl::VGG13BNImpl(int64_t num_classes, bool initialize_weights) + : 
VGGImpl(makeLayers(cfg['B'], true), num_classes, initialize_weights) {} + +VGG16BNImpl::VGG16BNImpl(int64_t num_classes, bool initialize_weights) + : VGGImpl(makeLayers(cfg['D'], true), num_classes, initialize_weights) {} + +VGG19BNImpl::VGG19BNImpl(int64_t num_classes, bool initialize_weights) + : VGGImpl(makeLayers(cfg['E'], true), num_classes, initialize_weights) {} + +} // namespace models +} // namespace vision diff --git a/torchvision/csrc/models/vgg.h b/torchvision/csrc/models/vgg.h new file mode 100644 index 00000000000..04f77007817 --- /dev/null +++ b/torchvision/csrc/models/vgg.h @@ -0,0 +1,76 @@ +#ifndef VGG_H +#define VGG_H + +#include + +namespace vision { +namespace models { +struct VGGImpl : torch::nn::Module { + torch::nn::Sequential features{nullptr}, classifier{nullptr}; + + void _initialize_weights(); + + VGGImpl( + torch::nn::Sequential features, + int64_t num_classes = 1000, + bool initialize_weights = true); + + torch::Tensor forward(torch::Tensor x); +}; + +// VGG 11-layer model (configuration "A") +struct VGG11Impl : VGGImpl { + VGG11Impl(int64_t num_classes = 1000, bool initialize_weights = true); +}; + +// VGG 13-layer model (configuration "B") +struct VGG13Impl : VGGImpl { + VGG13Impl(int64_t num_classes = 1000, bool initialize_weights = true); +}; + +// VGG 16-layer model (configuration "D") +struct VGG16Impl : VGGImpl { + VGG16Impl(int64_t num_classes = 1000, bool initialize_weights = true); +}; + +// VGG 19-layer model (configuration "E") +struct VGG19Impl : VGGImpl { + VGG19Impl(int64_t num_classes = 1000, bool initialize_weights = true); +}; + +// VGG 11-layer model (configuration "A") with batch normalization +struct VGG11BNImpl : VGGImpl { + VGG11BNImpl(int64_t num_classes = 1000, bool initialize_weights = true); +}; + +// VGG 13-layer model (configuration "B") with batch normalization +struct VGG13BNImpl : VGGImpl { + VGG13BNImpl(int64_t num_classes = 1000, bool initialize_weights = true); +}; + +// VGG 16-layer model 
(configuration "D") with batch normalization +struct VGG16BNImpl : VGGImpl { + VGG16BNImpl(int64_t num_classes = 1000, bool initialize_weights = true); +}; + +// VGG 19-layer model (configuration 'E') with batch normalization +struct VGG19BNImpl : VGGImpl { + VGG19BNImpl(int64_t num_classes = 1000, bool initialize_weights = true); +}; + +TORCH_MODULE(VGG); + +TORCH_MODULE(VGG11); +TORCH_MODULE(VGG13); +TORCH_MODULE(VGG16); +TORCH_MODULE(VGG19); + +TORCH_MODULE(VGG11BN); +TORCH_MODULE(VGG13BN); +TORCH_MODULE(VGG16BN); +TORCH_MODULE(VGG19BN); + +} // namespace models +} // namespace vision + +#endif // VGG_H diff --git a/torchvision/csrc/vision.h b/torchvision/csrc/vision.h new file mode 100644 index 00000000000..120eb034f38 --- /dev/null +++ b/torchvision/csrc/vision.h @@ -0,0 +1,6 @@ +#ifndef VISION_H +#define VISION_H + +#include + +#endif // VISION_H