From 2833b4d719586bb561805d2641e3338a92e9eb03 Mon Sep 17 00:00:00 2001 From: Haoyu Weng Date: Wed, 29 Jan 2025 15:53:59 -0800 Subject: [PATCH 1/4] rename files to follow split naming convention --- pyspark_huggingface/huggingface_sink.py | 120 +++++++++++++++--------- tests/test_huggingface_writer.py | 50 +++++----- 2 files changed, 105 insertions(+), 65 deletions(-) diff --git a/pyspark_huggingface/huggingface_sink.py b/pyspark_huggingface/huggingface_sink.py index 7e6689c..9fb6127 100644 --- a/pyspark_huggingface/huggingface_sink.py +++ b/pyspark_huggingface/huggingface_sink.py @@ -1,7 +1,7 @@ import ast import logging from dataclasses import dataclass -from typing import TYPE_CHECKING, Iterator, List, Optional +from typing import TYPE_CHECKING, Iterator, List, Optional, Union from pyspark.sql.datasource import ( DataSource, @@ -11,7 +11,13 @@ from pyspark.sql.types import StructType if TYPE_CHECKING: - from huggingface_hub import CommitOperationAdd, CommitOperationDelete + from huggingface_hub import ( + CommitOperation, + CommitOperationAdd, + CommitOperationDelete, + HfApi, + ) + from huggingface_hub.hf_api import RepoFile, RepoFolder from pyarrow import RecordBatch logger = logging.getLogger(__name__) @@ -27,8 +33,8 @@ class HuggingFaceSink(DataSource): Data Source Options: - token (str, required): HuggingFace API token for authentication. - path (str, required): HuggingFace repository ID, e.g. `{username}/{dataset}`. - - path_in_repo (str): Path within the repository to write the data. Defaults to the root. - - split (str): Split name to write the data to. Defaults to `train`. Only `train`, `test`, and `validation` are supported. + - path_in_repo (str): Path within the repository to write the data. Defaults to "data". + - split (str): Split name to write the data to. Defaults to `train`. - revision (str): Branch, tag, or commit to write to. Defaults to the main branch. - endpoint (str): Custom HuggingFace API endpoint URL. - max_bytes_per_file (int): Maximum size of each Parquet file. @@ -125,7 +131,9 @@ def __init__( import uuid self.repo_id = repo_id - self.path_in_repo = (path_in_repo or "").strip("/") + self.path_in_repo = ( + path_in_repo.strip("/") if path_in_repo is not None else "data" + ) self.split = split or "train" self.revision = revision self.schema = schema @@ -140,26 +148,7 @@ def __init__( # Use a unique filename prefix to avoid conflicts with existing files self.uuid = uuid.uuid4() - self.validate() - - def validate(self): - if self.split not in ["train", "test", "validation"]: - """ - TODO: Add support for custom splits. - - For custom split names to be recognized, the files must have path with format: - `data/{split}-{iiiii}-of-{nnnnn}.parquet` - where `iiiii` is the part number and `nnnnn` is the total number of parts, both padded to 5 digits. - Example: `data/custom-00000-of-00002.parquet` - - Therefore the current usage of UUID to avoid naming conflicts won't work for custom split names. - To fix this we can rename the files in the commit phase to satisfy the naming convention. - """ - raise NotImplementedError( - f"Only 'train', 'test', and 'validation' splits are supported. Got '{self.split}'." 
- ) - - def get_api(self): + def _get_api(self): from huggingface_hub import HfApi return HfApi(token=self.token, endpoint=self.endpoint) @@ -168,16 +157,11 @@ def get_api(self): def prefix(self) -> str: return f"{self.path_in_repo}/{self.split}".strip("/") - def get_delete_operations(self) -> Iterator["CommitOperationDelete"]: + def _list_split(self, api: "HfApi") -> Iterator[Union["RepoFile", "RepoFolder"]]: """ - Get the commit operations to delete all existing Parquet files. - This is used when `overwrite=True` to clear the target directory. + Get all existing files of the current split. """ - from huggingface_hub import CommitOperationDelete from huggingface_hub.errors import EntryNotFoundError - from huggingface_hub.hf_api import RepoFolder - - api = self.get_api() try: objects = api.list_repo_tree( @@ -190,11 +174,48 @@ def get_delete_operations(self) -> Iterator["CommitOperationDelete"]: ) for obj in objects: if obj.path.startswith(self.prefix): - yield CommitOperationDelete( - path_in_repo=obj.path, is_folder=isinstance(obj, RepoFolder) + yield obj + except EntryNotFoundError: + pass + + def _get_rename_operations( + self, api: "HfApi", additions: List["CommitOperationAdd"] + ) -> Iterator["CommitOperation"]: + """ + Note: mutates additions to update the path_in_repo of each addition. + """ + from huggingface_hub import CommitOperationCopy, CommitOperationDelete + from huggingface_hub.hf_api import RepoFile, RepoFolder + + count_new = len(additions) + count_existing = 0 + + def format_path(i): + return f"{self.prefix}-{i:05d}-of-{count_new + count_existing:05d}.parquet" + + # Rename existing files to have the correct total number of parts + if not self.overwrite: + existing = list( + obj for obj in self._list_split(api) if isinstance(obj, RepoFile) + ) + count_existing = len(existing) + for i, obj in enumerate(existing): + new_path = format_path(i) + if obj.path != new_path: + yield CommitOperationCopy( + src_path_in_repo=obj.path, path_in_repo=new_path ) - except EntryNotFoundError as e: - logger.info(f"Writing to a new path: {e}") + yield CommitOperationDelete(path_in_repo=obj.path) + # Otherwise, delete existing files + else: + for obj in self._list_split(api): + yield CommitOperationDelete( + path_in_repo=obj.path, is_folder=isinstance(obj, RepoFolder) + ) + + # Rename additions + for i, addition in enumerate(additions): + addition.path_in_repo = format_path(i + count_existing) def write(self, iterator: Iterator["RecordBatch"]) -> HuggingFaceCommitMessage: import io @@ -208,7 +229,7 @@ def write(self, iterator: Iterator["RecordBatch"]) -> HuggingFaceCommitMessage: context = TaskContext.get() partition_id = context.partitionId() if context else 0 - api = self.get_api() + api = self._get_api() schema = to_arrow_schema(self.schema) num_files = 0 @@ -265,25 +286,40 @@ def flush(writer: pq.ParquetWriter): return HuggingFaceCommitMessage(additions=additions) def commit(self, messages: List[HuggingFaceCommitMessage]) -> None: # type: ignore[override] - import math + api = self._get_api() - api = self.get_api() operations = [ addition for message in messages for addition in message.additions ] - if self.overwrite: # Delete existing files if overwrite is enabled - operations.extend(self.get_delete_operations()) + prepare_operations = list(self._get_rename_operations(api, operations)) + self._create_commits( + api, + operations=prepare_operations, + message="Prepare for upload using PySpark", + ) + + self._create_commits( + api, + operations=operations, + message="Upload using PySpark", 
+ ) + + def _create_commits( + self, api: "HfApi", operations: List["CommitOperation"], message: str + ) -> None: """ Split the commit into multiple parts if necessary. The HuggingFace API may time out if there are too many operations in a single commit. """ + import math + num_commits = math.ceil(len(operations) / self.max_operations_per_commit) for i in range(num_commits): begin = i * self.max_operations_per_commit end = (i + 1) * self.max_operations_per_commit part = operations[begin:end] - commit_message = "Upload using PySpark" + ( + commit_message = message + ( f" (part {i:05d}-of-{num_commits:05d})" if num_commits > 1 else "" ) api.create_commit( diff --git a/tests/test_huggingface_writer.py b/tests/test_huggingface_writer.py index 734621b..e9ea6ae 100644 --- a/tests/test_huggingface_writer.py +++ b/tests/test_huggingface_writer.py @@ -6,7 +6,6 @@ from pyspark.testing import assertDataFrameEqual from pytest_mock import MockerFixture - # ============== Fixtures & Helpers ============== @pytest.fixture(scope="session") @@ -22,8 +21,10 @@ def token(): return os.environ["HF_TOKEN"] -def reader(spark): - return spark.read.format("huggingface").option("token", token()) +def load(repo, split): + from datasets import load_dataset + + return load_dataset(repo, token=token(), split=split).to_pandas() def writer(df: DataFrame): @@ -34,7 +35,7 @@ def writer(df: DataFrame): def random_df(spark: SparkSession): from pyspark.sql.functions import rand - return lambda n: spark.range(n).select((rand()).alias("value")) + return lambda n: spark.range(n, numPartitions=2).select((rand()).alias("value")) @pytest.fixture(scope="session") @@ -59,41 +60,44 @@ def repo(api, username): # ============== Tests ============== -def test_basic(spark, repo, random_df): + +def test_basic(repo, random_df): df = random_df(10) writer(df).mode("append").save(repo) - actual = reader(spark).load(repo) - assertDataFrameEqual(df, actual) + actual = load(repo, "train") + assertDataFrameEqual(actual, df.toPandas()) -def test_append(spark, repo, random_df): +@pytest.mark.parametrize("split", ["train", "custom"]) +def test_append(repo, random_df, split): df1 = random_df(10) df2 = random_df(10) - writer(df1).mode("append").save(repo) - writer(df2).mode("append").save(repo) - actual = reader(spark).load(repo) + writer(df1).options(split=split).mode("append").save(repo) + writer(df2).options(split=split).mode("append").save(repo) + actual = load(repo, split) expected = df1.union(df2) - assertDataFrameEqual(actual, expected) + assertDataFrameEqual(actual, expected.toPandas()) -def test_overwrite(spark, repo, random_df): +@pytest.mark.parametrize("split", ["train", "custom"]) +def test_overwrite(repo, random_df, split): df1 = random_df(10) df2 = random_df(10) - writer(df1).mode("append").save(repo) - writer(df2).mode("overwrite").save(repo) - actual = reader(spark).load(repo) - assertDataFrameEqual(actual, df2) + writer(df1).options(split=split).mode("append").save(repo) + writer(df2).options(split=split).mode("overwrite").save(repo) + actual = load(repo, split) + assertDataFrameEqual(actual, df2.toPandas()) -def test_split(spark, repo, random_df): +def test_split(repo, random_df): df1 = random_df(10) df2 = random_df(10) writer(df1).mode("append").save(repo) - writer(df2).mode("append").options(split="test").save(repo) - actual1 = reader(spark).options(split="train").load(repo) - actual2 = reader(spark).options(split="test").load(repo) - assertDataFrameEqual(actual1, df1) - assertDataFrameEqual(actual2, df2) + 
writer(df2).mode("append").options(split="custom").save(repo) + actual1 = load(repo, "train") + actual2 = load(repo, "custom") + assertDataFrameEqual(actual1, df1.toPandas()) + assertDataFrameEqual(actual2, df2.toPandas()) def test_revision(repo, random_df, api): From c1134f7d88793dfd540604db90cf91aefccd16ae Mon Sep 17 00:00:00 2001 From: Haoyu Weng Date: Wed, 29 Jan 2025 16:09:15 -0800 Subject: [PATCH 2/4] remove unused import --- pyspark_huggingface/huggingface_sink.py | 1 - 1 file changed, 1 deletion(-) diff --git a/pyspark_huggingface/huggingface_sink.py b/pyspark_huggingface/huggingface_sink.py index 9fb6127..d51903f 100644 --- a/pyspark_huggingface/huggingface_sink.py +++ b/pyspark_huggingface/huggingface_sink.py @@ -14,7 +14,6 @@ from huggingface_hub import ( CommitOperation, CommitOperationAdd, - CommitOperationDelete, HfApi, ) from huggingface_hub.hf_api import RepoFile, RepoFolder From 2e25a9e008f5deaa9c226e0fa482dc46d11f4067 Mon Sep 17 00:00:00 2001 From: Haoyu Weng Date: Wed, 29 Jan 2025 16:45:51 -0800 Subject: [PATCH 3/4] add explanation --- pyspark_huggingface/huggingface_sink.py | 15 ++++++++++++--- 1 file changed, 12 insertions(+), 3 deletions(-) diff --git a/pyspark_huggingface/huggingface_sink.py b/pyspark_huggingface/huggingface_sink.py index d51903f..07e6529 100644 --- a/pyspark_huggingface/huggingface_sink.py +++ b/pyspark_huggingface/huggingface_sink.py @@ -177,11 +177,17 @@ def _list_split(self, api: "HfApi") -> Iterator[Union["RepoFile", "RepoFolder"]] except EntryNotFoundError: pass - def _get_rename_operations( + def _prepare_operations( self, api: "HfApi", additions: List["CommitOperationAdd"] ) -> Iterator["CommitOperation"]: """ - Note: mutates additions to update the path_in_repo of each addition. + Prepare operations for upload. + - Rename files to be recognizable by HuggingFace: `{split}-{current:05d}-of-{total:05d}.parquet` + - Delete existing files if `overwrite=True` + + See: https://huggingface.co/docs/hub/en/datasets-file-names-and-splits + + Note: additions are mutated to update the path_in_repo to the new filename. """ from huggingface_hub import CommitOperationCopy, CommitOperationDelete from huggingface_hub.hf_api import RepoFile, RepoFolder @@ -290,14 +296,17 @@ def commit(self, messages: List[HuggingFaceCommitMessage]) -> None: # type: ign operations = [ addition for message in messages for addition in message.additions ] - prepare_operations = list(self._get_rename_operations(api, operations)) + prepare_operations = list(self._prepare_operations(api, operations)) + # First rename existing files or delete files to be overwritten self._create_commits( api, operations=prepare_operations, message="Prepare for upload using PySpark", ) + # Then upload the new files + # This is a separate commit to avoid conflicts when e.g. 
a renamed file's old name is the same as a new file self._create_commits( api, operations=operations, From 6bc7bb09e837b499e92e141e1a3def97291c5794 Mon Sep 17 00:00:00 2001 From: Haoyu Weng Date: Thu, 30 Jan 2025 11:55:51 -0800 Subject: [PATCH 4/4] 1 phase upload when overwriting --- pyspark_huggingface/huggingface_sink.py | 112 ++++++++++++------------ 1 file changed, 54 insertions(+), 58 deletions(-) diff --git a/pyspark_huggingface/huggingface_sink.py b/pyspark_huggingface/huggingface_sink.py index 07e6529..8733fee 100644 --- a/pyspark_huggingface/huggingface_sink.py +++ b/pyspark_huggingface/huggingface_sink.py @@ -177,51 +177,6 @@ def _list_split(self, api: "HfApi") -> Iterator[Union["RepoFile", "RepoFolder"]] except EntryNotFoundError: pass - def _prepare_operations( - self, api: "HfApi", additions: List["CommitOperationAdd"] - ) -> Iterator["CommitOperation"]: - """ - Prepare operations for upload. - - Rename files to be recognizable by HuggingFace: `{split}-{current:05d}-of-{total:05d}.parquet` - - Delete existing files if `overwrite=True` - - See: https://huggingface.co/docs/hub/en/datasets-file-names-and-splits - - Note: additions are mutated to update the path_in_repo to the new filename. - """ - from huggingface_hub import CommitOperationCopy, CommitOperationDelete - from huggingface_hub.hf_api import RepoFile, RepoFolder - - count_new = len(additions) - count_existing = 0 - - def format_path(i): - return f"{self.prefix}-{i:05d}-of-{count_new + count_existing:05d}.parquet" - - # Rename existing files to have the correct total number of parts - if not self.overwrite: - existing = list( - obj for obj in self._list_split(api) if isinstance(obj, RepoFile) - ) - count_existing = len(existing) - for i, obj in enumerate(existing): - new_path = format_path(i) - if obj.path != new_path: - yield CommitOperationCopy( - src_path_in_repo=obj.path, path_in_repo=new_path - ) - yield CommitOperationDelete(path_in_repo=obj.path) - # Otherwise, delete existing files - else: - for obj in self._list_split(api): - yield CommitOperationDelete( - path_in_repo=obj.path, is_folder=isinstance(obj, RepoFolder) - ) - - # Rename additions - for i, addition in enumerate(additions): - addition.path_in_repo = format_path(i + count_existing) - def write(self, iterator: Iterator["RecordBatch"]) -> HuggingFaceCommitMessage: import io @@ -291,25 +246,66 @@ def flush(writer: pq.ParquetWriter): return HuggingFaceCommitMessage(additions=additions) def commit(self, messages: List[HuggingFaceCommitMessage]) -> None: # type: ignore[override] + """ + Commit the pre-uploaded Parquet files to the HuggingFace Hub, renaming them to match the expected format: + `{split}-{current:05d}-of-{total:05d}.parquet`. + Also delete or rename existing files of the split, depending on the mode. 
+ """ + + from huggingface_hub import CommitOperationCopy, CommitOperationDelete + from huggingface_hub.hf_api import RepoFile, RepoFolder + api = self._get_api() - operations = [ - addition for message in messages for addition in message.additions - ] - prepare_operations = list(self._prepare_operations(api, operations)) + additions = [addition for message in messages for addition in message.additions] + operations = {} + count_new = len(additions) + count_existing = 0 - # First rename existing files or delete files to be overwritten - self._create_commits( - api, - operations=prepare_operations, - message="Prepare for upload using PySpark", - ) + def format_path(i): + return f"{self.prefix}-{i:05d}-of-{count_new + count_existing:05d}.parquet" + + def rename(old_path, new_path): + if old_path != new_path: + yield CommitOperationCopy( + src_path_in_repo=old_path, path_in_repo=new_path + ) + yield CommitOperationDelete(path_in_repo=old_path) + + # In overwrite mode, delete existing files + if self.overwrite: + for obj in self._list_split(api): + # Delete old file + operations[obj.path] = CommitOperationDelete( + path_in_repo=obj.path, is_folder=isinstance(obj, RepoFolder) + ) + # In append mode, rename existing files to have the correct total number of parts + else: + rename_operations = [] + existing = list( + obj for obj in self._list_split(api) if isinstance(obj, RepoFile) + ) + count_existing = len(existing) + for i, obj in enumerate(existing): + new_path = format_path(i) + rename_operations.extend(rename(obj.path, new_path)) + # Rename files in a separate commit to prevent them from being overwritten by new files of the same name + self._create_commits( + api, + operations=rename_operations, + message="Rename existing files before uploading new files using PySpark", + ) + + # Rename additions, putting them after existing files if any + for i, addition in enumerate(additions): + addition.path_in_repo = format_path(i + count_existing) + # Overwrite the deletion operation if the file already exists + operations[addition.path_in_repo] = addition - # Then upload the new files - # This is a separate commit to avoid conflicts when e.g. a renamed file's old name is the same as a new file + # Upload the new files self._create_commits( api, - operations=operations, + operations=list(operations.values()), message="Upload using PySpark", )
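
For reference, a minimal usage sketch of the sink these patches modify, pieced together from the options documented in HuggingFaceSink and the calls exercised in tests/test_huggingface_writer.py. The token, repository ID, and split name below are placeholders, and the sketch assumes the pyspark_huggingface package is installed and its "huggingface" data source is available to the Spark session:

    from pyspark.sql import SparkSession

    spark = SparkSession.builder.getOrCreate()
    df = spark.range(10)

    (
        df.write.format("huggingface")
        .option("token", "hf_xxx")       # HuggingFace API token (placeholder)
        .option("split", "custom")       # optional; defaults to "train"
        .mode("append")                  # "overwrite" replaces the target split instead
        .save("username/dataset")        # repository ID (placeholder)
    )

With the defaults introduced in this series (path_in_repo of "data"), uploaded files follow the Hub naming convention data/{split}-{iiiii}-of-{nnnnn}.parquet. Appending first renames any existing parts of the split so the -of- totals stay consistent, while overwriting deletes the split's existing files in the same upload phase as the new files.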