2121
2222import argparse
2323import os
24+ import sys
2425
2526from absl import app
2627from absl .flags import argparse_flags
3031
3132import tensorflow_compression as tfc # pylint:disable=unused-import
3233
34+ # Default URL to fetch metagraphs from.
35+ URL_PREFIX = "https://storage.googleapis.com/tensorflow_compression/metagraphs"
36+ # Default location to store cached metagraphs.
37+ METAGRAPH_CACHE = "/tmp/tfc_metagraphs"
38+
3339
3440def read_png (filename ):
3541 """Creates graph to load a PNG image file."""
@@ -50,22 +56,28 @@ def write_png(filename, image):
5056 return tf .io .write_file (filename , string )
5157
5258
53- def load_metagraph ( model , url_prefix , metagraph_cache ):
54- """Loads and caches a trained model metagraph ."""
55- filename = os .path .join (metagraph_cache , model + ".metagraph" )
59+ def load_cached ( filename ):
60+ """Downloads and caches files from web storage ."""
61+ pathname = os .path .join (METAGRAPH_CACHE , filename )
5662 try :
57- with tf .io .gfile .GFile (filename , "rb" ) as f :
63+ with tf .io .gfile .GFile (pathname , "rb" ) as f :
5864 string = f .read ()
5965 except tf .errors .NotFoundError :
60- url = url_prefix + "/" + model + ".metagraph"
66+ url = URL_PREFIX + "/" + filename
6167 try :
6268 request = urllib .request .urlopen (url )
6369 string = request .read ()
6470 finally :
6571 request .close ()
66- tf .io .gfile .makedirs (os .path .dirname (filename ))
67- with tf .io .gfile .GFile (filename , "wb" ) as f :
72+ tf .io .gfile .makedirs (os .path .dirname (pathname ))
73+ with tf .io .gfile .GFile (pathname , "wb" ) as f :
6874 f .write (string )
75+ return string
76+
77+
78+ def import_metagraph (model ):
79+ """Imports a trained model metagraph into the current graph."""
80+ string = load_cached (model + ".metagraph" )
6981 metagraph = tf .MetaGraphDef ()
7082 metagraph .ParseFromString (string )
7183 tf .train .import_meta_graph (metagraph )
@@ -86,14 +98,11 @@ def instantiate_signature(signature_def):
8698 return inputs , outputs
8799
88100
89- def compress (model , input_file , output_file , url_prefix , metagraph_cache ):
90- """Compresses a PNG file to a TFCI file."""
91- if not output_file :
92- output_file = input_file + ".tfci"
93-
101+ def compress_image (model , input_image ):
102+ """Compresses an image array into a bitstring."""
94103 with tf .Graph ().as_default ():
95104 # Load model metagraph.
96- signature_defs = load_metagraph (model , url_prefix , metagraph_cache )
105+ signature_defs = import_metagraph (model )
97106 inputs , outputs = instantiate_signature (signature_defs ["sender" ])
98107
99108 # Just one input tensor.
@@ -103,12 +112,12 @@ def compress(model, input_file, output_file, url_prefix, metagraph_cache):
103112
104113 # Run encoder.
105114 with tf .Session () as sess :
106- feed_dict = {inputs : sess . run ( read_png ( input_file )) }
115+ feed_dict = {inputs : input_image }
107116 arrays = sess .run (outputs , feed_dict = feed_dict )
108117
109118 # Pack data into tf.Example.
110119 example = tf .train .Example ()
111- example .features .feature ["MD" ].bytes_list .value [:] = [model ]
120+ example .features .feature ["MD" ].bytes_list .value [:] = [model . encode ( "ascii" ) ]
112121 for i , (array , tensor ) in enumerate (zip (arrays , outputs )):
113122 feature = example .features .feature [chr (i + 1 )]
114123 if array .ndim != 1 :
@@ -121,12 +130,60 @@ def compress(model, input_file, output_file, url_prefix, metagraph_cache):
121130 raise RuntimeError (
122131 "Unexpected tensor dtype: '{}'." .format (tensor .dtype ))
123132
124- # Write serialized tf.Example to disk.
125- with tf .io .gfile .GFile (output_file , "wb" ) as f :
126- f .write (example .SerializeToString ())
133+ return example .SerializeToString ()
127134
128135
129- def decompress (input_file , output_file , url_prefix , metagraph_cache ):
136+ def compress (model , input_file , output_file , target_bpp = None , bpp_strict = False ):
137+ """Compresses a PNG file to a TFCI file."""
138+ if not output_file :
139+ output_file = input_file + ".tfci"
140+
141+ # Load image.
142+ with tf .Graph ().as_default ():
143+ with tf .Session () as sess :
144+ input_image = sess .run (read_png (input_file ))
145+ num_pixels = input_image .shape [- 2 ] * input_image .shape [- 3 ]
146+
147+ if not target_bpp :
148+ # Just compress with a specific model.
149+ bitstring = compress_image (model , input_image )
150+ else :
151+ # Get model list.
152+ models = load_cached (model + ".models" )
153+ models = models .decode ("ascii" ).split ()
154+
155+ # Do a binary search over all RD points.
156+ lower = - 1
157+ upper = len (models )
158+ bpp = None
159+ best_bitstring = None
160+ best_bpp = None
161+ while bpp != target_bpp and upper - lower > 1 :
162+ i = (upper + lower ) // 2
163+ bitstring = compress_image (models [i ], input_image )
164+ bpp = 8 * len (bitstring ) / num_pixels
165+ is_admissible = bpp <= target_bpp or not bpp_strict
166+ is_better = (best_bpp is None or
167+ abs (bpp - target_bpp ) < abs (best_bpp - target_bpp ))
168+ if is_admissible and is_better :
169+ best_bitstring = bitstring
170+ best_bpp = bpp
171+ if bpp < target_bpp :
172+ lower = i
173+ if bpp > target_bpp :
174+ upper = i
175+ if best_bpp is None :
176+ assert bpp_strict
177+ raise RuntimeError (
178+ "Could not compress image to less than {} bpp." .format (target_bpp ))
179+ bitstring = best_bitstring
180+
181+ # Write bitstring to disk.
182+ with tf .io .gfile .GFile (output_file , "wb" ) as f :
183+ f .write (bitstring )
184+
185+
186+ def decompress (input_file , output_file ):
130187 """Decompresses a TFCI file and writes a PNG file."""
131188 if not output_file :
132189 output_file = input_file + ".png"
@@ -136,10 +193,10 @@ def decompress(input_file, output_file, url_prefix, metagraph_cache):
136193 with tf .io .gfile .GFile (input_file , "rb" ) as f :
137194 example = tf .train .Example ()
138195 example .ParseFromString (f .read ())
139- model = example .features .feature ["MD" ].bytes_list .value [0 ]
196+ model = example .features .feature ["MD" ].bytes_list .value [0 ]. decode ( "ascii" )
140197
141198 # Load model metagraph.
142- signature_defs = load_metagraph (model , url_prefix , metagraph_cache )
199+ signature_defs = import_metagraph (model )
143200 inputs , outputs = instantiate_signature (signature_defs ["receiver" ])
144201
145202 # Multiple input tensors, ordered alphabetically, without names.
@@ -166,52 +223,59 @@ def decompress(input_file, output_file, url_prefix, metagraph_cache):
166223 sess .run (outputs , feed_dict = feed_dict )
167224
168225
169- def list_models (url_prefix ):
170- url = url_prefix + "/models.txt"
226+ def list_models ():
227+ url = URL_PREFIX + "/models.txt"
171228 try :
172229 request = urllib .request .urlopen (url )
173- print (request .read ())
230+ print (request .read (). decode ( "utf-8" ) )
174231 finally :
175232 request .close ()
176233
177234
178235def parse_args (argv ):
236+ """Parses command line arguments."""
179237 parser = argparse_flags .ArgumentParser (
180238 formatter_class = argparse .ArgumentDefaultsHelpFormatter )
181239
182240 # High-level options.
183241 parser .add_argument (
184242 "--url_prefix" ,
185- default = "https://storage.googleapis.com/tensorflow_compression/"
186- "metagraphs" ,
243+ default = URL_PREFIX ,
187244 help = "URL prefix for downloading model metagraphs." )
188245 parser .add_argument (
189246 "--metagraph_cache" ,
190- default = "/tmp/tfc_metagraphs" ,
247+ default = METAGRAPH_CACHE ,
191248 help = "Directory where to cache model metagraphs." )
192249 subparsers = parser .add_subparsers (
193- title = "commands" , help = "Invoke '<command> -h' for more information." )
250+ title = "commands" , dest = "command" ,
251+ help = "Invoke '<command> -h' for more information." )
194252
195253 # 'compress' subcommand.
196254 compress_cmd = subparsers .add_parser (
197255 "compress" ,
198256 description = "Reads a PNG file, compresses it using the given model, and "
199257 "writes a TFCI file." )
200- compress_cmd .set_defaults (
201- f = compress ,
202- a = ["model" , "input_file" , "output_file" , "url_prefix" , "metagraph_cache" ])
203258 compress_cmd .add_argument (
204259 "model" ,
205- help = "Unique model identifier. See 'models' command for options." )
260+ help = "Unique model identifier. See 'models' command for options. If "
261+ "'target_bpp' is provided, don't specify the index at the end of "
262+ "the model identifier." )
263+ compress_cmd .add_argument (
264+ "--target_bpp" , type = float ,
265+ help = "Target bits per pixel. If provided, a binary search is used to try "
266+ "to match the given bpp as close as possible. In this case, don't "
267+ "specify the index at the end of the model identifier. It will be "
268+ "automatically determined." )
269+ compress_cmd .add_argument (
270+ "--bpp_strict" , action = "store_true" ,
271+ help = "Try never to exceed 'target_bpp'. Ignored if 'target_bpp' is not "
272+ "set." )
206273
207274 # 'decompress' subcommand.
208275 decompress_cmd = subparsers .add_parser (
209276 "decompress" ,
210277 description = "Reads a TFCI file, reconstructs the image using the model "
211278 "it was compressed with, and writes back a PNG file." )
212- decompress_cmd .set_defaults (
213- f = decompress ,
214- a = ["input_file" , "output_file" , "url_prefix" , "metagraph_cache" ])
215279
216280 # Arguments for both 'compress' and 'decompress'.
217281 for cmd , ext in ((compress_cmd , ".tfci" ), (decompress_cmd , ".png" )):
@@ -224,18 +288,34 @@ def parse_args(argv):
224288 "the input filename." .format (ext ))
225289
226290 # 'models' subcommand.
227- models_cmd = subparsers .add_parser (
291+ subparsers .add_parser (
228292 "models" ,
229293 description = "Lists available trained models. Requires an internet "
230294 "connection." )
231- models_cmd .set_defaults (f = list_models , a = ["url_prefix" ])
232295
233296 # Parse arguments.
234- return parser .parse_args (argv [1 :])
297+ args = parser .parse_args (argv [1 :])
298+ if args .command is None :
299+ parser .print_usage ()
300+ sys .exit (2 )
301+ return args
302+
303+
304+ def main (args ):
305+ # Command line can override these defaults.
306+ global URL_PREFIX , METAGRAPH_CACHE
307+ URL_PREFIX = args .url_prefix
308+ METAGRAPH_CACHE = args .metagraph_cache
309+
310+ # Invoke subcommand.
311+ if args .command == "compress" :
312+ compress (args .model , args .input_file , args .output_file ,
313+ args .target_bpp , args .bpp_strict )
314+ if args .command == "decompress" :
315+ decompress (args .input_file , args .output_file )
316+ if args .command == "models" :
317+ list_models ()
235318
236319
237320if __name__ == "__main__" :
238- # Parse arguments and run function determined by subcommand.
239- app .run (
240- lambda args : args .f (** {k : getattr (args , k ) for k in args .a }),
241- flags_parser = parse_args )
321+ app .run (main , flags_parser = parse_args )
0 commit comments