Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions docs/source/models.rst
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ architectures:
- `SqueezeNet`_
- `DenseNet`_
- `Inception`_ v3
- `GoogLeNet`_

You can construct a model with random weights by calling its constructor:

Expand All @@ -22,6 +23,7 @@ You can construct a model with random weights by calling its constructor:
squeezenet = models.squeezenet1_0()
densenet = models.densenet161()
inception = models.inception_v3()
googlenet = models.googlenet()

We provide pre-trained models, using the PyTorch :mod:`torch.utils.model_zoo`.
These can be constructed by passing ``pretrained=True``:
Expand All @@ -35,6 +37,7 @@ These can be constructed by passing ``pretrained=True``:
vgg16 = models.vgg16(pretrained=True)
densenet = models.densenet161(pretrained=True)
inception = models.inception_v3(pretrained=True)
googlenet = models.googlenet(pretrained=True)

Instancing a pre-trained model will download its weights to a cache directory.
This directory can be set using the `TORCH_MODEL_ZOO` environment variable. See
Expand Down Expand Up @@ -93,6 +96,7 @@ Inception v3 22.55 6.44
.. _SqueezeNet: https://arxiv.org/abs/1602.07360
.. _DenseNet: https://arxiv.org/abs/1608.06993
.. _Inception: https://arxiv.org/abs/1512.00567
.. _GoogLeNet: https://arxiv.org/abs/1409.4842

.. currentmodule:: torchvision.models

Expand Down Expand Up @@ -142,3 +146,8 @@ Inception v3

.. autofunction:: inception_v3

GoogLeNet
------------

.. autofunction:: googlenet

1 change: 1 addition & 0 deletions torchvision/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,3 +4,4 @@
from .squeezenet import *
from .inception import *
from .densenet import *
from .googlenet import *
166 changes: 166 additions & 0 deletions torchvision/models/googlenet.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,166 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils import model_zoo

__all__ = ['GoogLeNet', 'googlenet']

model_urls = {
'googlenet': ''
}


def googlenet(pretrained=False, **kwargs):
r"""GoogLeNet (Inception v1) model architecture from
`"Going Deeper with Convolutions" <http://arxiv.org/abs/1409.4842>`_.
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
"""
if pretrained:
model = GoogLeNet(**kwargs)
model.load_state_dict(model_zoo.load_url(model_urls['googlenet']))
return model

return GoogLeNet(**kwargs)


class GoogLeNet(nn.Module):

def __init__(self, num_classes=1000, aux_logits=True):
super(GoogLeNet, self).__init__()
self.aux_logits = aux_logits

self.conv1 = BasicConv2d(3, 64, kernel_size=7, stride=2, padding=3)
self.maxpool1 = nn.MaxPool2d(3, stride=2, padding=1)
self.lrn1 = nn.LocalResponseNorm(5, alpha=0.0001)
self.conv2 = BasicConv2d(64, 64, kernel_size=1)
self.conv3 = BasicConv2d(64, 192, kernel_size=3, stride=1, padding=1)
self.lrn2 = nn.LocalResponseNorm(5, alpha=0.0001)
self.maxpool2 = nn.MaxPool2d(3, stride=2, padding=1)

self.inception3a = Inception(192, 64, 96, 128, 16, 32, 32)
self.inception3b = Inception(256, 128, 128, 192, 32, 96, 64)
self.maxpool3 = nn.MaxPool2d(3, stride=2)

self.inception4a = Inception(480, 192, 96, 208, 16, 48, 64)
self.inception4b = Inception(512, 160, 112, 224, 24, 64, 64)
self.inception4c = Inception(512, 128, 128, 256, 24, 64, 64)
self.inception4d = Inception(512, 112, 144, 288, 32, 64, 64)
self.inception4e = Inception(528, 256, 160, 320, 32, 128, 128)
self.maxpool4 = nn.MaxPool2d(3, stride=2, padding=1)

self.inception5a = Inception(832, 256, 160, 320, 32, 128, 128)
self.inception5b = Inception(832, 384, 192, 384, 48, 128, 128)
if aux_logits:
self.aux1 = InceptionAux(512, num_classes)
self.aux2 = InceptionAux(528, num_classes)
self.avgpool = nn.AvgPool2d(7, stride=1)
self.dropout = nn.Dropout(0.4)
self.fc = nn.Linear(1024, num_classes)

for m in self.modules():
if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear):
import scipy.stats as stats
X = stats.truncnorm(-2, 2, scale=0.01)
values = torch.Tensor(X.rvs(m.weight.numel()))
values = values.view(m.weight.size())
m.weight.data.copy_(values)

def forward(self, x):
x = self.conv1(x)
x = self.maxpool1(x)
x = self.lrn1(x)
x = self.conv2(x)
x = self.conv3(x)
x = self.lrn2(x)
x = self.maxpool2(x)

x = self.inception3a(x)
x = self.inception3b(x)
x = self.maxpool3(x)
x = self.inception4a(x)
if self.training and self.aux_logits:
aux1 = self.aux1(x)

x = self.inception4b(x)
x = self.inception4c(x)
x = self.inception4d(x)
if self.training and self.aux_logits:
aux2 = self.aux2(x)

x = self.inception4e(x)
x = self.maxpool4(x)
x = self.inception5a(x)
x = self.inception5b(x)

x = self.avgpool(x)
x = x.view(x.size(0), -1)
x = self.dropout(x)
x = self.fc(x)
if self.training and self.aux_logits:
return aux1, aux2, x
return x


class Inception(nn.Module):

def __init__(self, in_channels, ch1x1, ch3x3red, ch3x3, ch5x5red, ch5x5, pool_proj):
super(Inception, self).__init__()

self.branch1 = BasicConv2d(in_channels, ch1x1, kernel_size=1)

self.branch2 = nn.Sequential(
BasicConv2d(in_channels, ch3x3red, kernel_size=1, stride=1),
BasicConv2d(ch3x3red, ch3x3, kernel_size=3, padding=1)
)

self.branch3 = nn.Sequential(
BasicConv2d(in_channels, ch5x5red, kernel_size=1),
BasicConv2d(ch5x5red, ch5x5, kernel_size=5, padding=2)
)

self.branch4 = nn.Sequential(
nn.MaxPool2d(kernel_size=3, stride=1, padding=1),
BasicConv2d(in_channels, pool_proj, kernel_size=1)
)

def forward(self, x):
branch1 = self.branch1(x)
branch2 = self.branch2(x)
branch3 = self.branch3(x)
branch4 = self.branch4(x)

outputs = [branch1, branch2, branch3, branch4]
return torch.cat(outputs, 1)


class InceptionAux(nn.Module):

def __init__(self, in_channels, num_classes):
super(InceptionAux, self).__init__()
self.conv = BasicConv2d(in_channels, 128, kernel_size=1)

self.fc1 = nn.Linear(128 * 3 * 3, 1024)
self.fc2 = nn.Linear(1024, num_classes)

def forward(self, x):
x = F.avg_pool2d(x, kernel_size=5, stride=3)

x = self.conv(x)
x = x.view(x.size(0), -1)
x = self.fc1(x)
x = F.dropout(x, 0.7, training=self.training)
x = self.fc2(x)

return x


class BasicConv2d(nn.Module):

def __init__(self, in_channels, out_channels, **kwargs):
super(BasicConv2d, self).__init__()
self.conv = nn.Conv2d(in_channels, out_channels, **kwargs)

def forward(self, x):
x = self.conv(x)
return F.relu(x, inplace=True)