 from torchtitan.datasets.tokenizer import Tokenizer
 from torchtitan.logging_utils import logger
 
-from datasets import load_dataset, load_from_disk
+from datasets import load_dataset
 from datasets.distributed import split_dataset_by_node
 
+# map from dataset name to a local directory, or
+# a dataset repository on the HF hub
 _supported_datasets = {
     "c4_mini": "torchtitan/datasets/c4_mini",
     "c4": "allenai/c4",
@@ -66,7 +68,7 @@ def __init__(
         rank: int = 0,
         infinite: bool = False,
     ) -> None:
-        # allow user to pass in a local path to use unsupported datasets
+        # allow user to pass in a (local or HF hub) path to use unsupported datasets
         if dataset_name not in _supported_datasets:
             if dataset_path:
                 logger.warning(
@@ -79,32 +81,18 @@ def __init__(
                     f"Supported datasets are: {list(_supported_datasets.keys())}."
                 )
 
-        # special case to auto-load c4_mini (and any future datasets) from local dir
-        if dataset_name == "c4_mini":
-            dataset_path = f"torchtitan/datasets/{dataset_name}"
+        if not dataset_path:
+            dataset_path = _supported_datasets[dataset_name]
+        logger.info(f"Preparing {dataset_name} dataset from {dataset_path}")
 
-        # TODO: This is a temporary solution for small datasets.
-        # For large datasets we need to use a more scalable approach,
-        # and support shuffling and checkpointing.
-        if dataset_path:
-            logger.info(f"Loading {dataset_name} dataset locally from {dataset_path}")
-            ds = load_from_disk(dataset_path)
+        if dataset_name == "c4":
+            # c4 is huge, and requires both streaming and language selection
+            # (we default to en)
+            ds = load_dataset(dataset_path, name="en", split="train", streaming=True)
         else:
-            logger.info(f"Preparing {dataset_name} dataset from HuggingFace")
-            # Setting `streaming=True` works for large dataset, but is slightly
-            # slower and unstable.
-            if dataset_name == "c4":
-                # c4 is huge, and requires both streaming and language selection
-                # (we default to en).
-                ds = load_dataset(
-                    _supported_datasets[dataset_name],
-                    "en",
-                    split="train",
-                    streaming=True,
-                )
-            else:
-                ds = load_dataset(_supported_datasets[dataset_name], split="train")
+            ds = load_dataset(dataset_path, split="train")
 
+        # TODO: support shuffling and checkpointing
        self.dataset_name = dataset_name
         self._data = split_dataset_by_node(ds, rank, world_size)
         self._tokenizer = tokenizer
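
For reference, a minimal standalone sketch of the loading path this change produces (the `load_split` wrapper and its defaults are illustrative, not part of the PR; the mapping entries are the ones shown above):

from datasets import load_dataset
from datasets.distributed import split_dataset_by_node

_supported_datasets = {
    "c4_mini": "torchtitan/datasets/c4_mini",  # local directory
    "c4": "allenai/c4",  # dataset repository on the HF hub
}

def load_split(dataset_name, dataset_path=None, rank=0, world_size=1):
    # fall back to the registered path when the caller does not supply one
    path = dataset_path or _supported_datasets[dataset_name]
    if dataset_name == "c4":
        # stream c4 and select the English config rather than downloading it all
        ds = load_dataset(path, name="en", split="train", streaming=True)
    else:
        ds = load_dataset(path, split="train")
    # give each rank a disjoint shard of the data
    return split_dataset_by_node(ds, rank=rank, world_size=world_size)

With `streaming=True`, `load_dataset` returns an `IterableDataset`, so `split_dataset_by_node` shards the stream per rank instead of materializing the full corpus, which is why c4 is special-cased.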