@@ -18,6 +18,7 @@ def googlenet(pretrained=False, **kwargs):
1818        pretrained (bool): If True, returns a model pre-trained on ImageNet 
1919    """ 
2020    if  pretrained :
21+         kwargs ['init_weights' ] =  False 
2122        model  =  GoogLeNet (** kwargs )
2223        model .load_state_dict (model_zoo .load_url (model_urls ['googlenet' ]))
2324        return  model 
@@ -32,6 +33,7 @@ def googlenet_bn(pretrained=False, **kwargs):
3233        pretrained (bool): If True, returns a model pre-trained on ImageNet 
3334    """ 
3435    if  pretrained :
36+         kwargs ['init_weights' ] =  False 
3537        model  =  GoogLeNet (batch_norm = True , ** kwargs )
3638        model .load_state_dict (model_zoo .load_url (model_urls ['googlenet_bn' ]))
3739        return  model 
@@ -41,7 +43,7 @@ def googlenet_bn(pretrained=False, **kwargs):
4143
4244class  GoogLeNet (nn .Module ):
4345
44-     def  __init__ (self , num_classes = 1000 , aux_logits = True , batch_norm = False ):
46+     def  __init__ (self , num_classes = 1000 , aux_logits = True , batch_norm = False ,  init_weights = True ):
4547        super (GoogLeNet , self ).__init__ ()
4648        self .aux_logits  =  aux_logits 
4749
@@ -73,11 +75,25 @@ def __init__(self, num_classes=1000, aux_logits=True, batch_norm=False):
7375        self .dropout  =  nn .Dropout (0.4 )
7476        self .fc  =  nn .Linear (1024 , num_classes )
7577
78+         if  init_weights :
79+             self ._initialize_weights ()
80+ 
81+     def  _initialize_weights (self ):
7682        for  m  in  self .modules ():
7783            if  isinstance (m , nn .Conv2d ) or  isinstance (m , nn .Linear ):
7884                nn .init .xavier_uniform_ (m .weight )
7985                if  m .bias  is  not None :
8086                    nn .init .constant_ (m .bias , 0.2 )
87+             elif  isinstance (m , nn .BatchNorm2d ):
88+                 nn .init .constant_ (m .weight , 1 )
89+                 nn .init .constant_ (m .bias , 0 )
90+ 
91+         # zero init classifier 
92+         for  m  in  self .modules ():
93+             if  isinstance (m , InceptionAux ):
94+                 nn .init .zeros_ (m .fc2 .bias )
95+             elif  m  ==  self .fc :
96+                 nn .init .zeros_ (m .bias )
8197
8298    def  forward (self , x ):
8399        x  =  self .conv1 (x )
0 commit comments