Add more features to huggingface reader #3
Conversation
def partitions(self) -> Sequence[InputPartition]:
    from datasets import load_dataset
    if not self.streaming:
        return [Shard(index=0)]
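For context, a rough sketch of what the streaming branch of this method could look like, based on the PR description further down. The attribute names (self.path, self.config_name, self.split) and the exact shard-count property on IterableDataset (n_shards vs. num_shards, depending on the datasets version) are assumptions here, not necessarily what the PR uses:

def partitions(self) -> Sequence[InputPartition]:
    from datasets import load_dataset
    if not self.streaming:
        # Non-streaming: a single partition for now (discussed below)
        return [Shard(index=0)]
    # Streaming: one Spark input partition per dataset shard
    ds = load_dataset(self.path, name=self.config_name, split=self.split, streaming=True)
    return [Shard(index=i) for i in range(ds.n_shards)]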
@lhoestq I am not able to get num_shards for a non-streaming dataset. Do you know if this is supported?
When the dataset is available locally it is loaded as a Dataset backed by a memory-mapped Arrow Table that is the concatenation of all the shards. So in practice you don't really care about the shards themselves, since you can take whatever slice of the Table you want. The number of shards can be set to the maximum level of parallelism of the Spark setup, or we can decide to have as many shards as cached Arrow files, or as many shards as Arrow Record Batches, for example.
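A rough sketch of the slicing idea described above, splitting the memory-mapped Dataset into contiguous row ranges (RowRangePartition and target_partitions are hypothetical names used only for illustration, not part of this PR):

from dataclasses import dataclass
from datasets import load_dataset

@dataclass
class RowRangePartition:
    start: int
    end: int

def row_range_partitions(path, split, target_partitions):
    # Loads (or reuses from the cache) the memory-mapped, Arrow-backed Dataset
    ds = load_dataset(path, split=split, streaming=False)
    step = max(1, len(ds) // target_partitions)
    return [RowRangePartition(i, min(i + step, len(ds))) for i in range(0, len(ds), step)]

# Each Spark task would then read only its own slice,
# e.g. ds.select(range(p.start, p.end)) for its partition p.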
Here load_dataset(..., streaming=False) downloads the full dataset and prepares it as Arrow files locally, so it must be called only once. I understand that in the current implementation it would be called once, since the number of partitions is set to 1, so it works, but it doesn't leverage Spark's distributed execution.
There is an experimental feature that was added a while ago to let load_dataset run in parallel using Spark via joblibspark: huggingface/datasets#5924
with parallel_backend('spark') as backend:
    ds = load_dataset(..., streaming=False, num_proc=<number of spark jobs to spawn>)  # returns directly if the dataset is cached

It's also possible to get the Dataset from the downloaded and prepared Arrow dataset like this:
builder = load_dataset_builder(...)
with parallel_backend('spark') as backend:
    builder.download_and_prepare(..., num_proc=...)  # returns directly if the dataset is cached
ds = builder.as_dataset(split)
# EDIT: it should be possible to get an IterableDataset as well but I need to double check

If this doesn't fit the current implementation well, we can keep it for later and call the internals of the builder manually in proper Spark code if needed.
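A more complete, runnable sketch of the joblibspark approach above; the dataset name and num_proc value are placeholders, and it assumes joblibspark is installed and a Spark session is available:

from datasets import load_dataset
from joblib import parallel_backend
from joblibspark import register_spark

register_spark()  # registers the "spark" joblib backend

with parallel_backend("spark") as backend:
    # download and prepare the shards as Spark jobs; returns directly if the dataset is cached
    ds = load_dataset("username/some_dataset", streaming=False, num_proc=8)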
Sounds good. I can try it out using Spark cluster mode instead of the local mode to see if streaming works better.
lhoestq left a comment:
Looking good! We can see if we want to parallelize the non-streaming case later anyway :)
This PR:
- Adds a partitions method in DataSourceReader that uses the num_shards parameter from the IterableDataset to read from a streaming dataset across multiple workers.
- Supports non-streaming datasets via load_dataset(..., streaming=False).
- Supports dataset config names via load_dataset(path, name=config_name, ...).
- Updates the read method to use an iterator of Arrow batches. Note this can only be tested against the Spark master branch build (not with the Spark 4.0.0.dev2 release).

Example
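A minimal usage sketch of the reader; the data source class name, the "huggingface" format name, and the option keys below are assumptions for illustration and may not match this repo exactly:

from pyspark.sql import SparkSession
from pyspark_huggingface import HuggingFaceDatasets  # assumed module/class name

spark = SparkSession.builder.getOrCreate()

# Register the Python data source (requires the Spark 4 Python Data Source API)
spark.dataSource.register(HuggingFaceDatasets)

# Read a dataset; the config option maps to load_dataset(path, name=config_name, ...)
df = (
    spark.read.format("huggingface")
    .option("config", "some_config")  # assumed option key
    .load("rotten_tomatoes")
)
df.show()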