@@ -431,50 +431,52 @@ def caltech256(info, root, config):
431431
432432@register_mock
433433def imagenet (info , root , config ):
434- wnids = tuple (info .extra .wnid_to_category .keys ())
435- if config .split == "train" :
436- images_root = root / "ILSVRC2012_img_train"
434+ from scipy .io import savemat
437435
436+ categories = info .categories
437+ wnids = [info .extra .category_to_wnid [category ] for category in categories ]
438+ if config .split == "train" :
438439 num_samples = len (wnids )
440+ archive_name = "ILSVRC2012_img_train.tar"
439441
442+ files = []
440443 for wnid in wnids :
441- files = create_image_folder (
442- root = images_root ,
444+ create_image_folder (
445+ root = root ,
443446 name = wnid ,
444447 file_name_fn = lambda image_idx : f"{ wnid } _{ image_idx :04d} .JPEG" ,
445448 num_examples = 1 ,
446449 )
447- make_tar (images_root , f"{ wnid } .tar" , files [ 0 ]. parent )
450+ files . append ( make_tar (root , f"{ wnid } .tar" ) )
448451 elif config .split == "val" :
449452 num_samples = 3
450- files = create_image_folder (
451- root = root ,
452- name = "ILSVRC2012_img_val" ,
453- file_name_fn = lambda image_idx : f"ILSVRC2012_val_{ image_idx + 1 :08d} .JPEG" ,
454- num_examples = num_samples ,
455- )
456- images_root = files [0 ].parent
457- else : # config.split == "test"
458- images_root = root / "ILSVRC2012_img_test_v10102019"
453+ archive_name = "ILSVRC2012_img_val.tar"
454+ files = [create_image_file (root , f"ILSVRC2012_val_{ idx + 1 :08d} .JPEG" ) for idx in range (num_samples )]
459455
460- num_samples = 3
456+ devkit_root = root / "ILSVRC2012_devkit_t12"
457+ data_root = devkit_root / "data"
458+ data_root .mkdir (parents = True )
461459
462- create_image_folder (
463- root = images_root ,
464- name = "test" ,
465- file_name_fn = lambda image_idx : f"ILSVRC2012_test_{ image_idx + 1 :08d} .JPEG" ,
466- num_examples = num_samples ,
467- )
468- make_tar (root , f"{ images_root .name } .tar" , images_root )
460+ with open (data_root / "ILSVRC2012_validation_ground_truth.txt" , "w" ) as file :
461+ for label in torch .randint (0 , len (wnids ), (num_samples ,)).tolist ():
462+ file .write (f"{ label } \n " )
463+
464+ num_children = 0
465+ synsets = [
466+ (idx , wnid , category , "" , num_children , [], 0 , 0 )
467+ for idx , (category , wnid ) in enumerate (zip (categories , wnids ), 1 )
468+ ]
469+ num_children = 1
470+ synsets .extend ((0 , "" , "" , "" , num_children , [], 0 , 0 ) for _ in range (5 ))
471+ savemat (data_root / "meta.mat" , dict (synsets = synsets ))
472+
473+ make_tar (root , devkit_root .with_suffix (".tar.gz" ).name , compression = "gz" )
474+ else : # config.split == "test"
475+ num_samples = 5
476+ archive_name = "ILSVRC2012_img_test_v10102019.tar"
477+ files = [create_image_file (root , f"ILSVRC2012_test_{ idx + 1 :08d} .JPEG" ) for idx in range (num_samples )]
469478
470- devkit_root = root / "ILSVRC2012_devkit_t12"
471- devkit_root .mkdir ()
472- data_root = devkit_root / "data"
473- data_root .mkdir ()
474- with open (data_root / "ILSVRC2012_validation_ground_truth.txt" , "w" ) as file :
475- for label in torch .randint (0 , len (wnids ), (num_samples ,)).tolist ():
476- file .write (f"{ label } \n " )
477- make_tar (root , f"{ devkit_root } .tar.gz" , devkit_root , compression = "gz" )
479+ make_tar (root , archive_name , * files )
478480
479481 return num_samples
480482
@@ -666,14 +668,15 @@ def sbd(info, root, config):
666668@register_mock
667669def semeion (info , root , config ):
668670 num_samples = 3
671+ num_categories = len (info .categories )
669672
670673 images = torch .rand (num_samples , 256 )
671- labels = one_hot (torch .randint (len ( info . categories ) , size = (num_samples ,)))
674+ labels = one_hot (torch .randint (num_categories , size = (num_samples ,)), num_classes = num_categories )
672675 with open (root / "semeion.data" , "w" ) as fh :
673676 for image , one_hot_label in zip (images , labels ):
674677 image_columns = " " .join ([f"{ pixel .item ():.4f} " for pixel in image ])
675678 labels_columns = " " .join ([str (label .item ()) for label in one_hot_label ])
676- fh .write (f"{ image_columns } { labels_columns } \n " )
679+ fh .write (f"{ image_columns } { labels_columns } \n " )
677680
678681 return num_samples
679682
@@ -728,32 +731,33 @@ def _make_detection_anns_folder(cls, root, name, *, file_name_fn, num_examples):
728731 def _make_detection_ann_file (cls , root , name ):
729732 def add_child (parent , name , text = None ):
730733 child = ET .SubElement (parent , name )
731- child .text = text
734+ child .text = str ( text )
732735 return child
733736
734737 def add_name (obj , name = "dog" ):
735738 add_child (obj , "name" , name )
736- return name
737739
738- def add_bndbox (obj , bndbox = None ):
739- if bndbox is None :
740- bndbox = {"xmin" : "1" , "xmax" : "2" , "ymin" : "3" , "ymax" : "4" }
740+ def add_size (obj ):
741+ obj = add_child (obj , "size" )
742+ size = {"width" : 0 , "height" : 0 , "depth" : 3 }
743+ for name , text in size .items ():
744+ add_child (obj , name , text )
741745
746+ def add_bndbox (obj ):
742747 obj = add_child (obj , "bndbox" )
748+ bndbox = {"xmin" : 1 , "xmax" : 2 , "ymin" : 3 , "ymax" : 4 }
743749 for name , text in bndbox .items ():
744750 add_child (obj , name , text )
745751
746- return bndbox
747-
748752 annotation = ET .Element ("annotation" )
753+ add_size (annotation )
749754 obj = add_child (annotation , "object" )
750- data = dict (name = add_name (obj ), bndbox = add_bndbox (obj ))
755+ add_name (obj )
756+ add_bndbox (obj )
751757
752758 with open (root / name , "wb" ) as fh :
753759 fh .write (ET .tostring (annotation ))
754760
755- return data
756-
757761 @classmethod
758762 def generate (cls , root , * , year , trainval ):
759763 archive_folder = root
0 commit comments