2222from tensorflow_addons .utils import types
2323
2424
25- @tf .keras .utils .register_keras_serializable (package = ' Addons' )
25+ @tf .keras .utils .register_keras_serializable (package = " Addons" )
2626class GroupNormalization (tf .keras .layers .Layer ):
2727 """Group normalization layer.
2828
@@ -71,19 +71,21 @@ class GroupNormalization(tf.keras.layers.Layer):
7171 """
7272
7373 @typechecked
74- def __init__ (self ,
75- groups : int = 2 ,
76- axis : int = - 1 ,
77- epsilon : int = 1e-3 ,
78- center : bool = True ,
79- scale : bool = True ,
80- beta_initializer : types .Initializer = 'zeros' ,
81- gamma_initializer : types .Initializer = 'ones' ,
82- beta_regularizer : types .Regularizer = None ,
83- gamma_regularizer : types .Regularizer = None ,
84- beta_constraint : types .Constraint = None ,
85- gamma_constraint : types .Constraint = None ,
86- ** kwargs ):
74+ def __init__ (
75+ self ,
76+ groups : int = 2 ,
77+ axis : int = - 1 ,
78+ epsilon : float = 1e-3 ,
79+ center : bool = True ,
80+ scale : bool = True ,
81+ beta_initializer : types .Initializer = "zeros" ,
82+ gamma_initializer : types .Initializer = "ones" ,
83+ beta_regularizer : types .Regularizer = None ,
84+ gamma_regularizer : types .Regularizer = None ,
85+ beta_constraint : types .Constraint = None ,
86+ gamma_constraint : types .Constraint = None ,
87+ ** kwargs
88+ ):
8789 super ().__init__ (** kwargs )
8890 self .supports_masking = True
8991 self .groups = groups
@@ -117,39 +119,32 @@ def call(self, inputs):
117119 tensor_input_shape = tf .shape (inputs )
118120
119121 reshaped_inputs , group_shape = self ._reshape_into_groups (
120- inputs , input_shape , tensor_input_shape )
122+ inputs , input_shape , tensor_input_shape
123+ )
121124
122- normalized_inputs = self ._apply_normalization (reshaped_inputs ,
123- input_shape )
125+ normalized_inputs = self ._apply_normalization (reshaped_inputs , input_shape )
124126
125127 outputs = tf .reshape (normalized_inputs , tensor_input_shape )
126128
127129 return outputs
128130
129131 def get_config (self ):
130132 config = {
131- 'groups' :
132- self .groups ,
133- 'axis' :
134- self .axis ,
135- 'epsilon' :
136- self .epsilon ,
137- 'center' :
138- self .center ,
139- 'scale' :
140- self .scale ,
141- 'beta_initializer' :
142- tf .keras .initializers .serialize (self .beta_initializer ),
143- 'gamma_initializer' :
144- tf .keras .initializers .serialize (self .gamma_initializer ),
145- 'beta_regularizer' :
146- tf .keras .regularizers .serialize (self .beta_regularizer ),
147- 'gamma_regularizer' :
148- tf .keras .regularizers .serialize (self .gamma_regularizer ),
149- 'beta_constraint' :
150- tf .keras .constraints .serialize (self .beta_constraint ),
151- 'gamma_constraint' :
152- tf .keras .constraints .serialize (self .gamma_constraint )
133+ "groups" : self .groups ,
134+ "axis" : self .axis ,
135+ "epsilon" : self .epsilon ,
136+ "center" : self .center ,
137+ "scale" : self .scale ,
138+ "beta_initializer" : tf .keras .initializers .serialize (self .beta_initializer ),
139+ "gamma_initializer" : tf .keras .initializers .serialize (
140+ self .gamma_initializer
141+ ),
142+ "beta_regularizer" : tf .keras .regularizers .serialize (self .beta_regularizer ),
143+ "gamma_regularizer" : tf .keras .regularizers .serialize (
144+ self .gamma_regularizer
145+ ),
146+ "beta_constraint" : tf .keras .constraints .serialize (self .beta_constraint ),
147+ "gamma_constraint" : tf .keras .constraints .serialize (self .gamma_constraint ),
153148 }
154149 base_config = super ().get_config ()
155150 return {** base_config , ** config }
@@ -174,7 +169,8 @@ def _apply_normalization(self, reshaped_inputs, input_shape):
174169 group_reduction_axes .pop (axis )
175170
176171 mean , variance = tf .nn .moments (
177- reshaped_inputs , group_reduction_axes , keepdims = True )
172+ reshaped_inputs , group_reduction_axes , keepdims = True
173+ )
178174
179175 gamma , beta = self ._get_reshaped_weights (input_shape )
180176 normalized_inputs = tf .nn .batch_normalization (
@@ -183,7 +179,8 @@ def _apply_normalization(self, reshaped_inputs, input_shape):
183179 variance = variance ,
184180 scale = gamma ,
185181 offset = beta ,
186- variance_epsilon = self .epsilon )
182+ variance_epsilon = self .epsilon ,
183+ )
187184 return normalized_inputs
188185
189186 def _get_reshaped_weights (self , input_shape ):
@@ -200,10 +197,11 @@ def _get_reshaped_weights(self, input_shape):
200197 def _check_if_input_shape_is_none (self , input_shape ):
201198 dim = input_shape [self .axis ]
202199 if dim is None :
203- raise ValueError ('Axis ' + str (self .axis ) + ' of '
204- 'input tensor should have a defined dimension '
205- 'but the layer received an input with shape ' +
206- str (input_shape ) + '.' )
200+ raise ValueError (
201+ "Axis " + str (self .axis ) + " of "
202+ "input tensor should have a defined dimension "
203+ "but the layer received an input with shape " + str (input_shape ) + "."
204+ )
207205
208206 def _set_number_of_groups_for_instance_norm (self , input_shape ):
209207 dim = input_shape [self .axis ]
@@ -216,26 +214,30 @@ def _check_size_of_dimensions(self, input_shape):
216214 dim = input_shape [self .axis ]
217215 if dim < self .groups :
218216 raise ValueError (
219- 'Number of groups (' + str (self .groups ) + ') cannot be '
220- 'more than the number of channels (' + str (dim ) + ').' )
217+ "Number of groups (" + str (self .groups ) + ") cannot be "
218+ "more than the number of channels (" + str (dim ) + ")."
219+ )
221220
222221 if dim % self .groups != 0 :
223222 raise ValueError (
224- 'Number of groups (' + str (self .groups ) + ') must be a '
225- 'multiple of the number of channels (' + str (dim ) + ').' )
223+ "Number of groups (" + str (self .groups ) + ") must be a "
224+ "multiple of the number of channels (" + str (dim ) + ")."
225+ )
226226
227227 def _check_axis (self ):
228228
229229 if self .axis == 0 :
230230 raise ValueError (
231231 "You are trying to normalize your batch axis. Do you want to "
232- "use tf.layer.batch_normalization instead" )
232+ "use tf.layer.batch_normalization instead"
233+ )
233234
234235 def _create_input_spec (self , input_shape ):
235236
236237 dim = input_shape [self .axis ]
237238 self .input_spec = tf .keras .layers .InputSpec (
238- ndim = len (input_shape ), axes = {self .axis : dim })
239+ ndim = len (input_shape ), axes = {self .axis : dim }
240+ )
239241
240242 def _add_gamma_weight (self , input_shape ):
241243
@@ -245,10 +247,11 @@ def _add_gamma_weight(self, input_shape):
245247 if self .scale :
246248 self .gamma = self .add_weight (
247249 shape = shape ,
248- name = ' gamma' ,
250+ name = " gamma" ,
249251 initializer = self .gamma_initializer ,
250252 regularizer = self .gamma_regularizer ,
251- constraint = self .gamma_constraint )
253+ constraint = self .gamma_constraint ,
254+ )
252255 else :
253256 self .gamma = None
254257
@@ -260,10 +263,11 @@ def _add_beta_weight(self, input_shape):
260263 if self .center :
261264 self .beta = self .add_weight (
262265 shape = shape ,
263- name = ' beta' ,
266+ name = " beta" ,
264267 initializer = self .beta_initializer ,
265268 regularizer = self .beta_regularizer ,
266- constraint = self .beta_constraint )
269+ constraint = self .beta_constraint ,
270+ )
267271 else :
268272 self .beta = None
269273
@@ -274,7 +278,7 @@ def _create_broadcast_shape(self, input_shape):
274278 return broadcast_shape
275279
276280
277- @tf .keras .utils .register_keras_serializable (package = ' Addons' )
281+ @tf .keras .utils .register_keras_serializable (package = " Addons" )
278282class InstanceNormalization (GroupNormalization ):
279283 """Instance normalization layer.
280284
0 commit comments