Skip to content

Commit 91f7f97

Browse files
committed
Test data for conditional BatchNorm from TensorFlow
1 parent 87e8e51 commit 91f7f97

File tree

4 files changed

+27
-1
lines changed

4 files changed

+27
-1
lines changed

testdata/dnn/tensorflow/generate_tf_models.py

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ def prepare_for_dnn(sess, graph_def, in_node, out_node, out_graph, dtype, optimi
3030
graph_def = TransformGraph(graph_def, [in_node], [out_node], transforms)
3131
# Serialize
3232
with tf.gfile.FastGFile(out_graph, 'wb') as f:
33-
f.write(graph_def.SerializeToString())
33+
f.write(graph_def.SerializeToString())
3434

3535
tf.reset_default_graph()
3636
tf.Graph().as_default()
@@ -677,6 +677,32 @@ def pad_depth(x, desired_channels):
677677
softmax = tf.contrib.slim.softmax(conv)
678678
save(inp, softmax, 'slim_softmax')
679679
################################################################################
680+
# issue https://github.com/opencv/opencv/issues/14224
681+
inp_node = 'img_inputs'
682+
out_node = 'MobileFaceNet/MobileFaceNet/Conv2d_0/add'
683+
with tf.Session(graph=tf.Graph()) as localSession:
684+
localSession.graph.as_default()
685+
686+
with tf.gfile.FastGFile('frozen_model.pb') as f:
687+
graph_def = tf.GraphDef()
688+
graph_def.ParseFromString(f.read())
689+
for node in graph_def.node:
690+
if node.name == inp_node:
691+
del node.attr['shape']
692+
693+
tf.import_graph_def(graph_def, name='')
694+
695+
inputData = gen_data(tf.placeholder(tf.float32, [1, 4, 5, 3], inp_node))
696+
outputData = localSession.run(localSession.graph.get_tensor_by_name(out_node + ':0'),
697+
feed_dict={inp_node + ':0': inputData})
698+
writeBlob(inputData, 'slim_batch_norm_in')
699+
writeBlob(outputData, 'slim_batch_norm_out')
700+
701+
graph_def = TransformGraph(graph_def, [inp_node], [out_node], ['fold_constants', 'strip_unused_nodes'])
702+
with tf.gfile.FastGFile('slim_batch_norm_net.pb', 'wb') as f:
703+
f.write(graph_def.SerializeToString())
704+
705+
################################################################################
680706

681707
# Uncomment to print the final graph.
682708
# with tf.gfile.FastGFile('fused_batch_norm_net.pb') as f:
320 Bytes
Binary file not shown.
20.2 KB
Binary file not shown.
1.58 KB
Binary file not shown.

0 commit comments

Comments
 (0)