@@ -93,77 +93,57 @@ def split(x):
9393inp = np .random .standard_normal ((1 , 3 )).astype (np .float32 )
9494save_tflite_model (split , inp , 'split' )
9595
96+ def keras_to_tf (model , input_shape ):
97+ tf_func = tf .function (
98+ model .call ,
99+ input_signature = [tf .TensorSpec (input_shape , tf .float32 )],
100+ )
101+ inp = np .random .standard_normal ((input_shape )).astype (np .float32 )
102+
103+ return tf_func , inp
96104
97105fully_connected = tf .keras .models .Sequential ([
98106 tf .keras .layers .Dense (3 ),
99107 tf .keras .layers .ReLU (),
100108 tf .keras .layers .Softmax (),
101109])
102110
103- fully_connected = tf .function (
104- fully_connected .call ,
105- input_signature = [tf .TensorSpec ((1 ,2 ), tf .float32 )],
106- )
107-
108- inp = np .random .standard_normal ((1 , 2 )).astype (np .float32 )
111+ fully_connected , inp = keras_to_tf (fully_connected , (1 , 2 ))
109112save_tflite_model (fully_connected , inp , 'fully_connected' )
110113
111114permutation_3d = tf .keras .models .Sequential ([
112- tf .keras .layers .Permute ((2 ,1 ))
115+ tf .keras .layers .Permute ((2 , 1 ))
113116])
114117
115- permutation_3d = tf .function (
116- permutation_3d .call ,
117- input_signature = [tf .TensorSpec ((1 ,2 ,3 ), tf .float32 )],
118- )
119- inp = np .random .standard_normal ((1 , 2 , 3 )).astype (np .float32 )
118+ permutation_3d , inp = keras_to_tf (permutation_3d , (1 , 2 , 3 ))
120119save_tflite_model (permutation_3d , inp , 'permutation_3d' )
121120
122- # Temporarily disabled as TFLiteConverter produces a incorrect graph in this case
123- #permutation_4d_0123 = tf.keras.models.Sequential([
124- # tf.keras.layers.Permute((1,2,3)),
125- # tf.keras.layers.Conv2D(3,1)
126- #])
127- #
128- #permutation_4d_0123 = tf.function(
129- # permutation_4d_0123.call,
130- # input_signature=[tf.TensorSpec((1,2,3,4), tf.float32)],
131- #)
132- #inp = np.random.standard_normal((1, 2, 3, 4)).astype(np.float32)
133- #save_tflite_model(permutation_4d_0123, inp, 'permutation_4d_0123')
134-
135- permutation_4d_0132 = tf .keras .models .Sequential ([
136- tf .keras .layers .Permute ((1 ,3 ,2 )),
137- tf .keras .layers .Conv2D (3 ,1 )
138- ])
139-
140- permutation_4d_0132 = tf .function (
141- permutation_4d_0132 .call ,
142- input_signature = [tf .TensorSpec ((1 ,2 ,3 ,4 ), tf .float32 )],
143- )
144- inp = np .random .standard_normal ((1 , 2 , 3 , 4 )).astype (np .float32 )
145- save_tflite_model (permutation_4d_0132 , inp , 'permutation_4d_0132' )
146-
147- permutation_4d_0213 = tf .keras .models .Sequential ([
148- tf .keras .layers .Permute ((2 ,1 ,3 )),
149- tf .keras .layers .Conv2D (3 ,1 )
121+ # (1, 2, 3) is temporarily disabled as TFLiteConverter produces a incorrect graph in this case
122+ permutation_4d_list = [(1 , 3 , 2 ), (2 , 1 , 3 ), (2 , 3 , 1 )]
123+ for perm_axis in permutation_4d_list :
124+ permutation_4d_model = tf .keras .models .Sequential ([
125+ tf .keras .layers .Permute (perm_axis ),
126+ tf .keras .layers .Conv2D (3 , 1 )
127+ ])
128+
129+ permutation_4d_model , inp = keras_to_tf (permutation_4d_model , (1 , 2 , 3 , 4 ))
130+ model_name = f"permutation_4d_0{ '' .join (map (str , perm_axis ))} "
131+ save_tflite_model (permutation_4d_model , inp , model_name )
132+
133+ global_average_pooling_2d = tf .keras .models .Sequential ([
134+ tf .keras .layers .GlobalAveragePooling2D (keepdims = True ),
135+ tf .keras .layers .ZeroPadding2D (1 ),
136+ tf .keras .layers .GlobalAveragePooling2D (keepdims = False )
150137])
151138
152- permutation_4d_0213 = tf .function (
153- permutation_4d_0213 .call ,
154- input_signature = [tf .TensorSpec ((1 ,2 ,3 ,4 ), tf .float32 )],
155- )
156- inp = np .random .standard_normal ((1 , 2 , 3 , 4 )).astype (np .float32 )
157- save_tflite_model (permutation_4d_0213 , inp , 'permutation_4d_0213' )
139+ global_average_pooling_2d , inp = keras_to_tf (global_average_pooling_2d , (1 , 7 , 7 , 5 ))
140+ save_tflite_model (global_average_pooling_2d , inp , 'global_average_pooling_2d' )
158141
159- permutation_4d_0231 = tf .keras .models .Sequential ([
160- tf .keras .layers .Permute ((2 ,3 ,1 )),
161- tf .keras .layers .Conv2D (3 ,1 )
142+ global_max_pool = tf .keras .models .Sequential ([
143+ tf .keras .layers .GlobalMaxPool2D (keepdims = True ),
144+ tf .keras .layers .ZeroPadding2D (1 ),
145+ tf .keras .layers .GlobalMaxPool2D (keepdims = True )
162146])
163147
164- permutation_4d_0231 = tf .function (
165- permutation_4d_0231 .call ,
166- input_signature = [tf .TensorSpec ((1 ,2 ,3 ,4 ), tf .float32 )],
167- )
168- inp = np .random .standard_normal ((1 , 2 , 3 , 4 )).astype (np .float32 )
169- save_tflite_model (permutation_4d_0231 , inp , 'permutation_4d_0231' )
148+ global_max_pool , inp = keras_to_tf (global_max_pool , (1 , 7 , 7 , 5 ))
149+ save_tflite_model (global_max_pool , inp , 'global_max_pooling_2d' )
0 commit comments