@@ -316,6 +316,11 @@ def inject_weight_metadata(app, what, name, obj, options, lines):
316316 """
317317
318318 if obj .__name__ .endswith (("_Weights" , "_QuantizedWeights" )):
319+
320+ if len (obj ) == 0 :
321+ lines [:] = ["There are no available pre-trained weights." ]
322+ return
323+
319324 lines [:] = [
320325 "The model builder above accepts the following values as the ``weights`` parameter." ,
321326 f"``{ obj .__name__ } .DEFAULT`` is equivalent to ``{ obj .DEFAULT } ``." ,
@@ -329,37 +334,44 @@ def inject_weight_metadata(app, what, name, obj, options, lines):
329334 lines .append ("" )
330335
331336 for field in obj :
337+ meta = copy (field .meta )
338+
332339 lines += [f"**{ str (field )} **:" , "" ]
340+ lines += [meta .pop ("_docs" )]
341+
333342 if field == obj .DEFAULT :
334- lines += [f"This weight is also available as ``{ obj .__name__ } .DEFAULT``." , "" ]
343+ lines += [f"Also available as ``{ obj .__name__ } .DEFAULT``." ]
344+ lines += ["" ]
335345
336346 table = []
347+ metrics = meta .pop ("_metrics" )
348+ for dataset , dataset_metrics in metrics .items ():
349+ for metric_name , metric_value in dataset_metrics .items ():
350+ table .append ((f"{ metric_name } (on { dataset } )" , str (metric_value )))
337351
338- # the `meta` dict contains another embedded `metrics` dict. To
339- # simplify the table generation below, we create the
340- # `meta_with_metrics` dict, where the metrics dict has been "flattened"
341- meta = copy (field .meta )
342- metrics = meta .pop ("metrics" , {})
343- meta_with_metrics = dict (meta , ** metrics )
344-
345- meta_with_metrics .pop ("categories" , None ) # We don't want to document these, they can be too long
346-
347- custom_docs = meta_with_metrics .pop ("_docs" , None ) # Custom per-Weights docs
348- if custom_docs is not None :
349- lines += [custom_docs , "" ]
350-
351- for k , v in meta_with_metrics .items ():
352- if k == "recipe" :
352+ for k , v in meta .items ():
353+ if k in {"recipe" , "license" }:
353354 v = f"`link <{ v } >`__"
355+ elif k == "min_size" :
356+ v = f"height={ v [0 ]} , width={ v [1 ]} "
357+ elif k in {"categories" , "keypoint_names" } and isinstance (v , list ):
358+ max_visible = 3
359+ v_sample = ", " .join (v [:max_visible ])
360+ v = f"{ v_sample } , ... ({ len (v )- max_visible } omitted)" if len (v ) > max_visible else v_sample
354361 table .append ((str (k ), str (v )))
355362 table = tabulate (table , tablefmt = "rst" )
356363 lines += [".. rst-class:: table-weights" ] # Custom CSS class, see custom_torchvision.css
357364 lines += [".. table::" , "" ]
358365 lines += textwrap .indent (table , " " * 4 ).split ("\n " )
359366 lines .append ("" )
367+ lines .append (
368+ f"The inference transforms are available at ``{ str (field )} .transforms`` and "
369+ f"perform the following preprocessing operations: { field .transforms ().describe ()} "
370+ )
371+ lines .append ("" )
360372
361373
362- def generate_weights_table (module , table_name , metrics , include_patterns = None , exclude_patterns = None ):
374+ def generate_weights_table (module , table_name , metrics , dataset , include_patterns = None , exclude_patterns = None ):
363375 weights_endswith = "_QuantizedWeights" if module .__name__ .split ("." )[- 1 ] == "quantization" else "_Weights"
364376 weight_enums = [getattr (module , name ) for name in dir (module ) if name .endswith (weights_endswith )]
365377 weights = [w for weight_enum in weight_enums for w in weight_enum ]
@@ -376,7 +388,7 @@ def generate_weights_table(module, table_name, metrics, include_patterns=None, e
376388 content = [
377389 (
378390 f":class:`{ w } <{ type (w ).__name__ } >`" ,
379- * (w .meta ["metrics" ][metric ] for metric in metrics_keys ),
391+ * (w .meta ["_metrics" ][ dataset ][metric ] for metric in metrics_keys ),
380392 f"{ w .meta ['num_params' ]/ 1e6 :.1f} M" ,
381393 f"`link <{ w .meta ['recipe' ]} >`__" ,
382394 )
@@ -393,29 +405,45 @@ def generate_weights_table(module, table_name, metrics, include_patterns=None, e
393405 table_file .write (f"{ textwrap .indent (table , ' ' * 4 )} \n \n " )
394406
395407
396- generate_weights_table (module = M , table_name = "classification" , metrics = [("acc@1" , "Acc@1" ), ("acc@5" , "Acc@5" )])
397408generate_weights_table (
398- module = M . quantization , table_name = "classification_quant " , metrics = [("acc@1" , "Acc@1" ), ("acc@5" , "Acc@5" )]
409+ module = M , table_name = "classification " , metrics = [("acc@1" , "Acc@1" ), ("acc@5" , "Acc@5" )], dataset = "ImageNet-1K"
399410)
400411generate_weights_table (
401- module = M .detection , table_name = "detection" , metrics = [("box_map" , "Box MAP" )], exclude_patterns = ["Mask" , "Keypoint" ]
412+ module = M .quantization ,
413+ table_name = "classification_quant" ,
414+ metrics = [("acc@1" , "Acc@1" ), ("acc@5" , "Acc@5" )],
415+ dataset = "ImageNet-1K" ,
416+ )
417+ generate_weights_table (
418+ module = M .detection ,
419+ table_name = "detection" ,
420+ metrics = [("box_map" , "Box MAP" )],
421+ exclude_patterns = ["Mask" , "Keypoint" ],
422+ dataset = "COCO-val2017" ,
402423)
403424generate_weights_table (
404425 module = M .detection ,
405426 table_name = "instance_segmentation" ,
406427 metrics = [("box_map" , "Box MAP" ), ("mask_map" , "Mask MAP" )],
428+ dataset = "COCO-val2017" ,
407429 include_patterns = ["Mask" ],
408430)
409431generate_weights_table (
410432 module = M .detection ,
411433 table_name = "detection_keypoint" ,
412434 metrics = [("box_map" , "Box MAP" ), ("kp_map" , "Keypoint MAP" )],
435+ dataset = "COCO-val2017" ,
413436 include_patterns = ["Keypoint" ],
414437)
415438generate_weights_table (
416- module = M .segmentation , table_name = "segmentation" , metrics = [("miou" , "Mean IoU" ), ("pixel_acc" , "pixelwise Acc" )]
439+ module = M .segmentation ,
440+ table_name = "segmentation" ,
441+ metrics = [("miou" , "Mean IoU" ), ("pixel_acc" , "pixelwise Acc" )],
442+ dataset = "COCO-val2017-VOC-labels" ,
443+ )
444+ generate_weights_table (
445+ module = M .video , table_name = "video" , metrics = [("acc@1" , "Acc@1" ), ("acc@5" , "Acc@5" )], dataset = "Kinetics-400"
417446)
418- generate_weights_table (module = M .video , table_name = "video" , metrics = [("acc@1" , "Acc@1" ), ("acc@5" , "Acc@5" )])
419447
420448
421449def setup (app ):
0 commit comments