Commit d5e2113

unify data loading from HF and from disk
ghstack-source-id: a8f5228
Pull Request resolved: #287

1 parent d442743

4 files changed: +13 additions, −58 deletions

Binary file not shown (−96.6 MB).

torchtitan/datasets/c4_mini/dataset_info.json

Lines changed: 0 additions & 20 deletions
This file was deleted.

torchtitan/datasets/c4_mini/state.json

Lines changed: 0 additions & 13 deletions
This file was deleted.
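
The two deleted JSON files are metadata written by `datasets.Dataset.save_to_disk`, the format the old `load_from_disk` code path consumed; with loading unified on `load_dataset` below, they are no longer needed. A minimal sketch of how such a directory gets produced (toy data, hypothetical path):

from datasets import Dataset

# save_to_disk writes the arrow shards plus dataset_info.json and state.json
ds = Dataset.from_dict({"text": ["hello", "world"]})
ds.save_to_disk("tmp/saved_dataset")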

torchtitan/datasets/hf_datasets.py

Lines changed: 13 additions & 25 deletions
@@ -12,9 +12,11 @@
 from torchtitan.datasets.tokenizer import Tokenizer
 from torchtitan.logging_utils import logger
 
-from datasets import load_dataset, load_from_disk
+from datasets import load_dataset
 from datasets.distributed import split_dataset_by_node
 
+# map from dataset name to a local directory, or
+# a dataset repository on the HF hub
 _supported_datasets = {
     "c4_mini": "torchtitan/datasets/c4_mini",
     "c4": "allenai/c4",
@@ -66,7 +68,7 @@ def __init__(
         rank: int = 0,
         infinite: bool = False,
     ) -> None:
-        # allow user to pass in a local path to use unsupported datasets
+        # allow user to pass in a (local or HF hub) path to use unsupported datasets
         if dataset_name not in _supported_datasets:
             if dataset_path:
                 logger.warning(
@@ -79,32 +81,18 @@ def __init__(
                 f"Supported datasets are: {list(_supported_datasets.keys())}."
             )
 
-        # special case to auto-load c4_mini (and any future datasets) from local dir
-        if dataset_name == "c4_mini":
-            dataset_path = f"torchtitan/datasets/{dataset_name}"
+        if not dataset_path:
+            dataset_path = _supported_datasets[dataset_name]
+        logger.info(f"Preparing {dataset_name} dataset from {dataset_path}")
 
-        # TODO: This is a temporary solution for small datasets.
-        # For large datasets we need to use a more scalable approach,
-        # and support shuffling and checkpointing.
-        if dataset_path:
-            logger.info(f"Loading {dataset_name} dataset locally from {dataset_path}")
-            ds = load_from_disk(dataset_path)
+        if dataset_name == "c4":
+            # c4 is huge, and requires both streaming and language selection
+            # (we default to en)
+            ds = load_dataset(dataset_path, name="en", split="train", streaming=True)
         else:
-            logger.info(f"Preparing {dataset_name} dataset from HuggingFace")
-            # Setting `streaming=True` works for large dataset, but is slightly
-            # slower and unstable.
-            if dataset_name == "c4":
-                # c4 is huge, and requires both streaming and language selection
-                # (we default to en).
-                ds = load_dataset(
-                    _supported_datasets[dataset_name],
-                    "en",
-                    split="train",
-                    streaming=True,
-                )
-            else:
-                ds = load_dataset(_supported_datasets[dataset_name], split="train")
+            ds = load_dataset(dataset_path, split="train")
 
+        # TODO: support shuffling and checkpointing
         self.dataset_name = dataset_name
         self._data = split_dataset_by_node(ds, rank, world_size)
         self._tokenizer = tokenizer
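
With this change, both entries in `_supported_datasets` flow through the same `load_dataset` call, since Hugging Face `datasets` resolves a local directory of data files and a hub repository id through the same API. A minimal sketch of the two unified call shapes (paths taken from the diff; assumes the local directory holds files `load_dataset` can auto-detect):

from datasets import load_dataset
from datasets.distributed import split_dataset_by_node

# local directory checked into the repo, loaded eagerly
ds = load_dataset("torchtitan/datasets/c4_mini", split="train")

# dataset repository on the HF hub; c4 is streamed and needs a language config
ds = load_dataset("allenai/c4", name="en", split="train", streaming=True)

# either way, the result can then be sharded per rank (example values)
shard = split_dataset_by_node(ds, rank=0, world_size=8)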

0 commit comments
