Skip to content

Commit dd1fbd0

Browse files
authored
Merge pull request #1180 from CNOCycle:tflite/ops
* Simpify generating permutation testes for tflite * Simpify converting keras into TF for tflite tests * Add global_pool_2d tests for tflite models
1 parent 723bdf2 commit dd1fbd0

7 files changed

+35
-55
lines changed

testdata/dnn/tflite/generate.py

Lines changed: 35 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -93,77 +93,57 @@ def split(x):
9393
inp = np.random.standard_normal((1, 3)).astype(np.float32)
9494
save_tflite_model(split, inp, 'split')
9595

96+
def keras_to_tf(model, input_shape):
97+
tf_func = tf.function(
98+
model.call,
99+
input_signature=[tf.TensorSpec(input_shape, tf.float32)],
100+
)
101+
inp = np.random.standard_normal((input_shape)).astype(np.float32)
102+
103+
return tf_func, inp
96104

97105
fully_connected = tf.keras.models.Sequential([
98106
tf.keras.layers.Dense(3),
99107
tf.keras.layers.ReLU(),
100108
tf.keras.layers.Softmax(),
101109
])
102110

103-
fully_connected = tf.function(
104-
fully_connected.call,
105-
input_signature=[tf.TensorSpec((1,2), tf.float32)],
106-
)
107-
108-
inp = np.random.standard_normal((1, 2)).astype(np.float32)
111+
fully_connected, inp = keras_to_tf(fully_connected, (1, 2))
109112
save_tflite_model(fully_connected, inp, 'fully_connected')
110113

111114
permutation_3d = tf.keras.models.Sequential([
112-
tf.keras.layers.Permute((2,1))
115+
tf.keras.layers.Permute((2, 1))
113116
])
114117

115-
permutation_3d = tf.function(
116-
permutation_3d.call,
117-
input_signature=[tf.TensorSpec((1,2,3), tf.float32)],
118-
)
119-
inp = np.random.standard_normal((1, 2, 3)).astype(np.float32)
118+
permutation_3d, inp = keras_to_tf(permutation_3d, (1, 2, 3))
120119
save_tflite_model(permutation_3d, inp, 'permutation_3d')
121120

122-
# Temporarily disabled as TFLiteConverter produces a incorrect graph in this case
123-
#permutation_4d_0123 = tf.keras.models.Sequential([
124-
# tf.keras.layers.Permute((1,2,3)),
125-
# tf.keras.layers.Conv2D(3,1)
126-
#])
127-
#
128-
#permutation_4d_0123 = tf.function(
129-
# permutation_4d_0123.call,
130-
# input_signature=[tf.TensorSpec((1,2,3,4), tf.float32)],
131-
#)
132-
#inp = np.random.standard_normal((1, 2, 3, 4)).astype(np.float32)
133-
#save_tflite_model(permutation_4d_0123, inp, 'permutation_4d_0123')
134-
135-
permutation_4d_0132 = tf.keras.models.Sequential([
136-
tf.keras.layers.Permute((1,3,2)),
137-
tf.keras.layers.Conv2D(3,1)
138-
])
139-
140-
permutation_4d_0132 = tf.function(
141-
permutation_4d_0132.call,
142-
input_signature=[tf.TensorSpec((1,2,3,4), tf.float32)],
143-
)
144-
inp = np.random.standard_normal((1, 2, 3, 4)).astype(np.float32)
145-
save_tflite_model(permutation_4d_0132, inp, 'permutation_4d_0132')
146-
147-
permutation_4d_0213 = tf.keras.models.Sequential([
148-
tf.keras.layers.Permute((2,1,3)),
149-
tf.keras.layers.Conv2D(3,1)
121+
# (1, 2, 3) is temporarily disabled as TFLiteConverter produces a incorrect graph in this case
122+
permutation_4d_list = [(1, 3, 2), (2, 1, 3), (2, 3, 1)]
123+
for perm_axis in permutation_4d_list:
124+
permutation_4d_model = tf.keras.models.Sequential([
125+
tf.keras.layers.Permute(perm_axis),
126+
tf.keras.layers.Conv2D(3, 1)
127+
])
128+
129+
permutation_4d_model, inp = keras_to_tf(permutation_4d_model, (1, 2, 3, 4))
130+
model_name = f"permutation_4d_0{''.join(map(str, perm_axis))}"
131+
save_tflite_model(permutation_4d_model, inp, model_name)
132+
133+
global_average_pooling_2d = tf.keras.models.Sequential([
134+
tf.keras.layers.GlobalAveragePooling2D(keepdims=True),
135+
tf.keras.layers.ZeroPadding2D(1),
136+
tf.keras.layers.GlobalAveragePooling2D(keepdims=False)
150137
])
151138

152-
permutation_4d_0213 = tf.function(
153-
permutation_4d_0213.call,
154-
input_signature=[tf.TensorSpec((1,2,3,4), tf.float32)],
155-
)
156-
inp = np.random.standard_normal((1, 2, 3, 4)).astype(np.float32)
157-
save_tflite_model(permutation_4d_0213, inp, 'permutation_4d_0213')
139+
global_average_pooling_2d, inp = keras_to_tf(global_average_pooling_2d, (1, 7, 7, 5))
140+
save_tflite_model(global_average_pooling_2d, inp, 'global_average_pooling_2d')
158141

159-
permutation_4d_0231 = tf.keras.models.Sequential([
160-
tf.keras.layers.Permute((2,3,1)),
161-
tf.keras.layers.Conv2D(3,1)
142+
global_max_pool = tf.keras.models.Sequential([
143+
tf.keras.layers.GlobalMaxPool2D(keepdims=True),
144+
tf.keras.layers.ZeroPadding2D(1),
145+
tf.keras.layers.GlobalMaxPool2D(keepdims=True)
162146
])
163147

164-
permutation_4d_0231 = tf.function(
165-
permutation_4d_0231.call,
166-
input_signature=[tf.TensorSpec((1,2,3,4), tf.float32)],
167-
)
168-
inp = np.random.standard_normal((1, 2, 3, 4)).astype(np.float32)
169-
save_tflite_model(permutation_4d_0231, inp, 'permutation_4d_0231')
148+
global_max_pool, inp = keras_to_tf(global_max_pool, (1, 7, 7, 5))
149+
save_tflite_model(global_max_pool, inp, 'global_max_pooling_2d')
1.24 KB
Binary file not shown.
1.08 KB
Binary file not shown.
148 Bytes
Binary file not shown.
1.25 KB
Binary file not shown.
1.08 KB
Binary file not shown.
148 Bytes
Binary file not shown.

0 commit comments

Comments
 (0)