|  | 
| 26 | 26 | from PIL import Image | 
| 27 | 27 | import time | 
| 28 | 28 | import ilit | 
|  | 29 | +from ilit.adaptor.tf_utils.util import _parse_ckpt_bn_input | 
| 29 | 30 | 
 | 
| 30 | 31 | flags = tf.flags | 
| 31 | 32 | flags.DEFINE_string('style_images_paths', None, 'Paths to the style images' | 
| @@ -77,39 +78,6 @@ def image_style_transfer(sess, content_img_path, style_img_path): | 
| 77 | 78 |     # saves stylized image. | 
| 78 | 79 |     save_image(stylized_image_res, os.path.join(FLAGS.output_dir, 'stylized_image.jpg')) | 
| 79 | 80 | 
 | 
| 80 |  | -def _parse_ckpt_bn_input(graph_def): | 
| 81 |  | -    """parse ckpt batch norm inputs to match correct moving mean and variance | 
| 82 |  | -    Args: | 
| 83 |  | -        graph_def (graph_def): original graph_def | 
| 84 |  | -    Returns: | 
| 85 |  | -        graph_def: well linked graph_def  | 
| 86 |  | -    """ | 
| 87 |  | -    for node in graph_def.node: | 
| 88 |  | -        if node.op == 'FusedBatchNorm': | 
| 89 |  | -            moving_mean_op_name = node.input[3] | 
| 90 |  | -            moving_var_op_name = node.input[4] | 
| 91 |  | -            moving_mean_op = _get_nodes_from_name(moving_mean_op_name, graph_def)[0] | 
| 92 |  | -            moving_var_op = _get_nodes_from_name(moving_var_op_name, graph_def)[0] | 
| 93 |  | - | 
| 94 |  | -            if moving_mean_op.op == 'Const': | 
| 95 |  | -                name_part = moving_mean_op_name.rsplit('/', 1)[0] | 
| 96 |  | -                real_moving_mean_op_name = name_part + '/moving_mean' | 
| 97 |  | -                if len(_get_nodes_from_name(real_moving_mean_op_name, graph_def)) > 0: | 
| 98 |  | -                    # replace the real moving mean op name | 
| 99 |  | -                    node.input[3] = real_moving_mean_op_name | 
| 100 |  | - | 
| 101 |  | -            if moving_var_op.op == 'Const': | 
| 102 |  | -                name_part = moving_var_op_name.rsplit('/', 1)[0] | 
| 103 |  | -                real_moving_var_op_name = name_part + '/moving_variance' | 
| 104 |  | -                if len(_get_nodes_from_name(real_moving_var_op_name, graph_def)) > 0: | 
| 105 |  | -                    # replace the real moving mean op name | 
| 106 |  | -                    node.input[4] = real_moving_var_op_name | 
| 107 |  | - | 
| 108 |  | -    return graph_def | 
| 109 |  | - | 
| 110 |  | -def _get_nodes_from_name(node_name, graph_def): | 
| 111 |  | -    return [node for node in graph_def.node if node.name==node_name] | 
| 112 |  | -  | 
| 113 | 81 | def main(args=None): | 
| 114 | 82 |   tf.logging.set_verbosity(tf.logging.INFO) | 
| 115 | 83 |   if not tf.gfile.Exists(FLAGS.output_dir): | 
|  | 
0 commit comments