2828# map from dataset name to a local directory, or
2929# a dataset repository on the HF hub
3030_supported_datasets = {
31- "c4_mini " : "torchtitan/datasets/c4_mini " ,
31+ "c4_test " : "test/assets/c4_test " ,
3232 "c4" : "allenai/c4" ,
3333}
3434
@@ -48,8 +48,8 @@ class HuggingFaceDataset(IterableDataset, Stateful):
4848 rank (int): rank of the current data parallel process
4949 infinite (bool): whether to loop infinitely over the dataset
5050
51- We currently support the c4 dataset and a subset of it:
52- c4_mini (45K training entries)
51+ We currently support the c4 dataset, and a subset of it for testing purposes :
52+ c4_test (2K training entries)
5353 c4 (177M training entries - this dataset is streamed due to the size)
5454
5555 >> c4 (EN) <<:
@@ -83,12 +83,12 @@ def __init__(
8383 if dataset_path :
8484 logger .warning (
8585 f"Dataset { dataset_name } is not tested or verfied. "
86- f"Recommended datasets are: { list (_supported_datasets .keys ())} . "
86+ f"Recommended datasets are: { list (_supported_datasets .keys ())} "
8787 )
8888 else :
8989 raise ValueError (
9090 f"Dataset { dataset_name } is not supported. "
91- f"Supported datasets are: { list (_supported_datasets .keys ())} . "
91+ f"Supported datasets are: { list (_supported_datasets .keys ())} "
9292 )
9393
9494 if not dataset_path :
@@ -132,15 +132,12 @@ def __iter__(self):
132132 yield input , label
133133
134134 if not self .infinite :
135- logger .warning (f"Dataset { self .dataset_name } has run out of data. " )
135+ logger .warning (f"Dataset { self .dataset_name } has run out of data" )
136136 break
137137 else :
138138 # Reset offset for the next iteration
139139 self ._sample_idx = 0
140- logger .warning (
141- f"Dataset { self .dataset_name } is being re-looped. "
142- "Loss related metrics might be misleading."
143- )
140+ logger .warning (f"Dataset { self .dataset_name } is being re-looped" )
144141
145142 def _get_data_iter (self ):
146143 if self ._sample_idx == 0 :
@@ -188,7 +185,7 @@ def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
188185
189186 if self ._rank_id not in state_dict :
190187 logger .warning (
191- f"DataLoader state is empty for dp rank { self ._dp_rank } , expected key { self ._rank_id } . "
188+ f"DataLoader state is empty for dp rank { self ._dp_rank } , expected key { self ._rank_id } "
192189 )
193190 return
194191 super ().load_state_dict (pickle .loads (state_dict [self ._rank_id ]))
0 commit comments