@@ -180,10 +180,6 @@ def _create_and_fit_Sequential_model(self, layer, shape):
180180 model .fit (x = input_batch , y = output_batch , epochs = 1 , batch_size = 1 )
181181 return model
182182
183- def test_axis_error (self ):
184- with self .assertRaises (ValueError ):
185- GroupNormalization (axis = 0 )
186-
187183 def test_groupnorm_flat (self ):
188184 # Check basic usage of groupnorm_flat
189185 # Testing for 1 == LayerNorm, 16 == GroupNorm, -1 == InstanceNorm
@@ -219,22 +215,29 @@ def test_initializer(self):
219215 negativ = weights [weights < 0.0 ]
220216 self .assertTrue (len (negativ ) == 0 )
221217
222- def test_groupnorm_conv (self ):
223- # Check if Axis is working for CONV nets
224- # Testing for 1 == LayerNorm, 5 == GroupNorm, -1 == InstanceNorm
225- np .random .seed (0x2020 )
226- groups = [- 1 , 5 , 1 ]
227- for i in groups :
228- model = tf .keras .models .Sequential ()
229- model .add (GroupNormalization (axis = 1 , groups = i , input_shape = (20 , 20 , 3 )))
230- model .add (tf .keras .layers .Conv2D (5 , (1 , 1 ), padding = "same" ))
231- model .add (tf .keras .layers .Flatten ())
232- model .add (tf .keras .layers .Dense (1 , activation = "softmax" ))
233- model .compile (optimizer = tf .keras .optimizers .RMSprop (0.01 ), loss = "mse" )
234- x = np .random .randint (1000 , size = (10 , 20 , 20 , 3 ))
235- y = np .random .randint (1000 , size = (10 , 1 ))
236- model .fit (x = x , y = y , epochs = 1 )
237- self .assertTrue (hasattr (model .layers [0 ], "gamma" ))
218+
219+ def test_axis_error ():
220+ with pytest .raises (ValueError ):
221+ GroupNormalization (axis = 0 )
222+
223+
224+ @pytest .mark .usefixtures ("maybe_run_functions_eagerly" )
225+ def test_groupnorm_conv ():
226+ # Check if Axis is working for CONV nets
227+ # Testing for 1 == LayerNorm, 5 == GroupNorm, -1 == InstanceNorm
228+ np .random .seed (0x2020 )
229+ groups = [- 1 , 5 , 1 ]
230+ for i in groups :
231+ model = tf .keras .models .Sequential ()
232+ model .add (GroupNormalization (axis = 1 , groups = i , input_shape = (20 , 20 , 3 )))
233+ model .add (tf .keras .layers .Conv2D (5 , (1 , 1 ), padding = "same" ))
234+ model .add (tf .keras .layers .Flatten ())
235+ model .add (tf .keras .layers .Dense (1 , activation = "softmax" ))
236+ model .compile (optimizer = tf .keras .optimizers .RMSprop (0.01 ), loss = "mse" )
237+ x = np .random .randint (1000 , size = (10 , 20 , 20 , 3 ))
238+ y = np .random .randint (1000 , size = (10 , 1 ))
239+ model .fit (x = x , y = y , epochs = 1 )
240+ assert hasattr (model .layers [0 ], "gamma" )
238241
239242
240243@pytest .mark .usefixtures ("maybe_run_functions_eagerly" )
0 commit comments