99
1010Hacked together by / Copyright 2020 Ross Wightman
1111"""
12+ import logging
1213import os
13- import tarfile
1414import pickle
15- import logging
16- import numpy as np
15+ import tarfile
1716from glob import glob
18- from typing import List , Dict
17+ from typing import List , Tuple , Dict , Set , Optional , Union
18+
19+ import numpy as np
1920
2021from timm .utils .misc import natural_key
2122
22- from .parser import Parser
2323from .class_map import load_class_map
24- from .constants import IMG_EXTENSIONS
25-
24+ from .img_extensions import get_img_extensions
25+ from . parser import Parser
2626
2727_logger = logging .getLogger (__name__ )
2828CACHE_FILENAME_SUFFIX = '_tarinfos.pickle'
@@ -39,7 +39,7 @@ def reset(self):
3939 self .tf = None
4040
4141
42- def _extract_tarinfo (tf : tarfile .TarFile , parent_info : Dict , extensions = IMG_EXTENSIONS ):
42+ def _extract_tarinfo (tf : tarfile .TarFile , parent_info : Dict , extensions : Set [ str ] ):
4343 sample_count = 0
4444 for i , ti in enumerate (tf ):
4545 if not ti .isfile ():
@@ -60,7 +60,14 @@ def _extract_tarinfo(tf: tarfile.TarFile, parent_info: Dict, extensions=IMG_EXTE
6060 return sample_count
6161
6262
63- def extract_tarinfos (root , class_name_to_idx = None , cache_tarinfo = None , extensions = IMG_EXTENSIONS , sort = True ):
63+ def extract_tarinfos (
64+ root ,
65+ class_name_to_idx : Optional [Dict ] = None ,
66+ cache_tarinfo : Optional [bool ] = None ,
67+ extensions : Optional [Union [List , Tuple , Set ]] = None ,
68+ sort : bool = True
69+ ):
70+ extensions = get_img_extensions (as_set = True ) if not extensions else set (extensions )
6471 root_is_tar = False
6572 if os .path .isfile (root ):
6673 assert os .path .splitext (root )[- 1 ].lower () == '.tar'
@@ -176,8 +183,8 @@ def __init__(self, root, class_map='', cache_tarfiles=True, cache_tarinfo=None):
176183 self .samples , self .targets , self .class_name_to_idx , tarfiles = extract_tarinfos (
177184 self .root ,
178185 class_name_to_idx = class_name_to_idx ,
179- cache_tarinfo = cache_tarinfo ,
180- extensions = IMG_EXTENSIONS )
186+ cache_tarinfo = cache_tarinfo
187+ )
181188 self .class_idx_to_name = {v : k for k , v in self .class_name_to_idx .items ()}
182189 if len (tarfiles ) == 1 and tarfiles [0 ][0 ] is None :
183190 self .root_is_tar = True
0 commit comments