Merged
1 change: 0 additions & 1 deletion .gitignore
@@ -10,6 +10,5 @@ dist/*
 data
 out
 wandb
-*.json

 torchtitan/datasets/**/*.model
Binary file not shown.
20 changes: 0 additions & 20 deletions torchtitan/datasets/c4_mini/dataset_info.json

This file was deleted.

13 changes: 0 additions & 13 deletions torchtitan/datasets/c4_mini/state.json

This file was deleted.

38 changes: 13 additions & 25 deletions torchtitan/datasets/hf_datasets.py
@@ -12,9 +12,11 @@
 from torchtitan.datasets.tokenizer import Tokenizer
 from torchtitan.logging_utils import logger

-from datasets import load_dataset, load_from_disk
+from datasets import load_dataset
 from datasets.distributed import split_dataset_by_node

+# map from dataset name to a local directory, or
+# a dataset repository on the HF hub
 _supported_datasets = {
     "c4_mini": "torchtitan/datasets/c4_mini",
     "c4": "allenai/c4",
@@ -66,7 +68,7 @@ def __init__(
         rank: int = 0,
         infinite: bool = False,
     ) -> None:
-        # allow user to pass in a local path to use unsupported datasets
+        # allow user to pass in a (local or HF hub) path to use unsupported datasets
         if dataset_name not in _supported_datasets:
             if dataset_path:
                 logger.warning(
@@ -79,32 +81,18 @@ def __init__(
                     f"Supported datasets are: {list(_supported_datasets.keys())}."
                 )

-        # special case to auto-load c4_mini (and any future datasets) from local dir
-        if dataset_name == "c4_mini":
-            dataset_path = f"torchtitan/datasets/{dataset_name}"
+        if not dataset_path:
+            dataset_path = _supported_datasets[dataset_name]
+        logger.info(f"Preparing {dataset_name} dataset from {dataset_path}")

-        # TODO: This is a temporary solution for small datasets.
-        # For large datasets we need to use a more scalable approach,
-        # and support shuffling and checkpointing.
-        if dataset_path:
-            logger.info(f"Loading {dataset_name} dataset locally from {dataset_path}")
-            ds = load_from_disk(dataset_path)
+        if dataset_name == "c4":
+            # c4 is huge, and requires both streaming and language selection
+            # (we default to en)
+            ds = load_dataset(dataset_path, name="en", split="train", streaming=True)
Collaborator:
Can we add streaming as an arg to HuggingFaceDataset, and then, when creating the dataloader for c4, just pass streaming=True to the HuggingFaceDataset constructor?

Contributor Author:
We certainly can, although this would create if-c4 statements in other places as well, and more args in function calls in general. Since we only have two datasets right now, maybe let's keep it as is until the change becomes necessary.

         else:
-            logger.info(f"Preparing {dataset_name} dataset from HuggingFace")
-            # Setting `streaming=True` works for large dataset, but is slightly
-            # slower and unstable.
-            if dataset_name == "c4":
-                # c4 is huge, and requires both streaming and language selection
-                # (we default to en).
-                ds = load_dataset(
-                    _supported_datasets[dataset_name],
-                    "en",
-                    split="train",
-                    streaming=True,
-                )
-            else:
-                ds = load_dataset(_supported_datasets[dataset_name], split="train")
+            ds = load_dataset(dataset_path, split="train")

+        # TODO: support shuffling and checkpointing
         self.dataset_name = dataset_name
         self._data = split_dataset_by_node(ds, rank, world_size)
         self._tokenizer = tokenizer
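For context on the review thread above: the reviewer's suggestion (a streaming constructor arg) was not adopted in this PR. The following is a minimal, hypothetical sketch of what it could look like, assuming a simplified constructor; the real HuggingFaceDataset also takes a tokenizer and other arguments only partially visible in the diff. The streaming flag moves into the constructor, and the call site that builds the c4 dataloader passes streaming=True. Note that c4's "en" config still has to be selected somewhere, which is the extra if-c4 branching the author points out in the reply.

# Hypothetical sketch of the reviewer's suggestion -- not the merged code.
# Only the dataset-loading part of the constructor is shown.
from typing import Optional

from datasets import load_dataset
from datasets.distributed import split_dataset_by_node

_supported_datasets = {
    "c4_mini": "torchtitan/datasets/c4_mini",
    "c4": "allenai/c4",
}


class HuggingFaceDataset:
    def __init__(
        self,
        dataset_name: str,
        dataset_path: Optional[str] = None,
        streaming: bool = False,  # the proposed extra constructor arg
        world_size: int = 1,
        rank: int = 0,
    ) -> None:
        if not dataset_path:
            dataset_path = _supported_datasets[dataset_name]
        # c4 still needs its "en" config selected somewhere -- this is the
        # extra if-c4 branch the author mentions in the reply above.
        config_name = "en" if dataset_name == "c4" else None
        ds = load_dataset(
            dataset_path, name=config_name, split="train", streaming=streaming
        )
        self._data = split_dataset_by_node(ds, rank, world_size)


# Call site (e.g. the dataloader builder for c4) would decide about streaming:
# dataset = HuggingFaceDataset("c4", streaming=True, world_size=8, rank=0)

The merged PR instead keeps the c4 special case inside the constructor, trading a little generality for fewer arguments and a single place that knows about dataset quirks.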