@@ -432,50 +432,52 @@ def caltech256(info, root, config):
432432
433433@register_mock
def imagenet(info, root, config):
    """Create the mock ImageNet archives for ``config.split`` inside ``root``.

    Args:
        info: Dataset info object. Reads ``info.categories`` and the
            ``info.extra.category_to_wnid`` mapping.
        root: Directory (``pathlib.Path``) that receives the generated files.
        config: Dataset config. Only ``config.split`` is read; ``"train"`` and
            ``"val"`` are handled explicitly, anything else is treated as
            ``"test"``.

    Returns:
        The number of samples generated for the requested split.
    """
    # Imported lazily so that scipy is only required when this mock is built.
    from scipy.io import savemat

    categories = info.categories
    wnids = [info.extra.category_to_wnid[category] for category in categories]
    if config.split == "train":
        # One sample (and one inner archive) per category.
        num_samples = len(wnids)
        archive_name = "ILSVRC2012_img_train.tar"

        files = []
        for wnid in wnids:
            create_image_folder(
                root=root,
                name=wnid,
                # Bind ``wnid`` as a default so the callback does not depend on
                # the loop variable's value at call time.
                file_name_fn=lambda image_idx, wnid=wnid: f"{wnid}_{image_idx:04d}.JPEG",
                num_examples=1,
            )
            # The per-wnid archive returned by make_tar becomes a member of the
            # outer train archive created at the end of this function.
            files.append(make_tar(root, f"{wnid}.tar"))
    elif config.split == "val":
        num_samples = 3
        archive_name = "ILSVRC2012_img_val.tar"
        files = [create_image_file(root, f"ILSVRC2012_val_{idx + 1:08d}.JPEG") for idx in range(num_samples)]

        devkit_root = root / "ILSVRC2012_devkit_t12"
        data_root = devkit_root / "data"
        data_root.mkdir(parents=True)

        # One random label per validation image, one per line.
        # NOTE(review): labels are drawn from [0, len(wnids)) while the synset
        # ids written below start at 1 -- confirm the reader expects this.
        with open(data_root / "ILSVRC2012_validation_ground_truth.txt", "w") as file:
            for label in torch.randint(0, len(wnids), (num_samples,)).tolist():
                file.write(f"{label}\n")

        # One record per category with ``num_children == 0``; ids start at 1.
        num_children = 0
        synsets = [
            (idx, wnid, category, "", num_children, [], 0, 0)
            for idx, (category, wnid) in enumerate(zip(categories, wnids), 1)
        ]
        # Five extra placeholder records with ``num_children == 1``; presumably
        # these stand in for non-leaf synsets the reader must filter out.
        num_children = 1
        synsets.extend((0, "", "", "", num_children, [], 0, 0) for _ in range(5))
        savemat(data_root / "meta.mat", dict(synsets=synsets))

        # Resolves to "ILSVRC2012_devkit_t12.tar.gz" next to the image archive.
        make_tar(root, devkit_root.with_suffix(".tar.gz").name, compression="gz")
    else:  # config.split == "test"
        num_samples = 5
        archive_name = "ILSVRC2012_img_test_v10102019.tar"
        files = [create_image_file(root, f"ILSVRC2012_test_{idx + 1:08d}.JPEG") for idx in range(num_samples)]

    make_tar(root, archive_name, *files)

    return num_samples
481483
@@ -667,14 +669,15 @@ def sbd(info, root, config):
667669@register_mock
def semeion(info, root, config):
    """Write a mock ``semeion.data`` file into ``root``.

    Each line holds 256 pixel values formatted with four decimals followed by
    a one-hot label with ``len(info.categories)`` columns, all separated by
    single spaces.

    Args:
        info: Dataset info object; only ``info.categories`` is read.
        root: Directory (``pathlib.Path``) that receives ``semeion.data``.
        config: Unused.

    Returns:
        The number of samples written (always 3).
    """
    num_samples = 3
    num_categories = len(info.categories)

    images = torch.rand(num_samples, 256)
    # Pass ``num_classes`` explicitly: otherwise the one-hot width is inferred
    # from the largest sampled label and may be smaller than ``num_categories``.
    labels = one_hot(torch.randint(num_categories, size=(num_samples,)), num_classes=num_categories)
    with open(root / "semeion.data", "w") as fh:
        for image, one_hot_label in zip(images, labels):
            image_columns = " ".join(f"{pixel:.4f}" for pixel in image.tolist())
            labels_columns = " ".join(str(label) for label in one_hot_label.tolist())
            # NOTE(review): the trailing space before the newline is inferred
            # from the whitespace-only change on this line -- confirm against
            # the reader of semeion.data.
            fh.write(f"{image_columns} {labels_columns} \n")

    return num_samples
680683
@@ -729,32 +732,33 @@ def _make_detection_anns_folder(cls, root, name, *, file_name_fn, num_examples):
def _make_detection_ann_file(cls, root, name):
    """Write a minimal detection annotation XML file to ``root / name``.

    The document has an ``annotation`` root containing a ``size`` element
    (``width``, ``height``, ``depth``) and a single ``object`` element with a
    ``name`` and a ``bndbox`` (``xmin``, ``xmax``, ``ymin``, ``ymax``).

    Args:
        cls: The owning class (unused).
        root: Directory (``pathlib.Path``) that receives the file.
        name: File name of the annotation file.

    Returns:
        None.
    """

    def add_child(parent, tag, text=None):
        child = ET.SubElement(parent, tag)
        # Only leaf elements carry text. Stringifying unconditionally would
        # write the literal "None" into container elements.
        if text is not None:
            child.text = str(text)
        return child

    def add_name(obj, label="dog"):
        add_child(obj, "name", label)

    def add_size(parent):
        size_node = add_child(parent, "size")
        size = {"width": 0, "height": 0, "depth": 3}
        for tag, value in size.items():
            add_child(size_node, tag, value)

    def add_bndbox(parent):
        bndbox_node = add_child(parent, "bndbox")
        bndbox = {"xmin": 1, "xmax": 2, "ymin": 3, "ymax": 4}
        for tag, value in bndbox.items():
            add_child(bndbox_node, tag, value)

    annotation = ET.Element("annotation")
    add_size(annotation)
    obj = add_child(annotation, "object")
    add_name(obj)
    add_bndbox(obj)

    # ET.tostring returns bytes by default, hence the binary mode.
    with open(root / name, "wb") as fh:
        fh.write(ET.tostring(annotation))
757-
758762 @classmethod
759763 def generate (cls , root , * , year , trainval ):
760764 archive_folder = root
0 commit comments