From 1837b25e9a37898027e41c5f0d24167083de427b Mon Sep 17 00:00:00 2001 From: Haoyu Weng Date: Thu, 23 Jan 2025 15:16:08 -0800 Subject: [PATCH 01/13] add huggingfacesink data source --- pyspark_huggingface/huggingface_sink.py | 150 ++++++++++++++++++++++++ 1 file changed, 150 insertions(+) create mode 100644 pyspark_huggingface/huggingface_sink.py diff --git a/pyspark_huggingface/huggingface_sink.py b/pyspark_huggingface/huggingface_sink.py new file mode 100644 index 0000000..668488b --- /dev/null +++ b/pyspark_huggingface/huggingface_sink.py @@ -0,0 +1,150 @@ +import ast +from dataclasses import dataclass +from typing import TYPE_CHECKING, Iterator, List, Optional + +from pyspark.sql.datasource import ( + DataSource, + DataSourceArrowWriter, + WriterCommitMessage, +) +from pyspark.sql.types import StructType + +if TYPE_CHECKING: + from huggingface_hub import CommitOperationAdd + from pyarrow import RecordBatch + + +class HuggingFaceSink(DataSource): + def __init__(self, options): + super().__init__(options) + + if "path" not in options or not options["path"]: + raise Exception("You must specify a dataset name.") + + kwargs = dict(self.options) + self.path = kwargs.pop("path") + self.token = kwargs.pop("token") + self.endpoint = kwargs.pop("endpoint", None) + for arg in kwargs: + if kwargs[arg].lower() == "true": + kwargs[arg] = True + elif kwargs[arg].lower() == "false": + kwargs[arg] = False + else: + try: + kwargs[arg] = ast.literal_eval(kwargs[arg]) + except ValueError: + pass + self.kwargs = kwargs + + @classmethod + def name(cls): + return "huggingfacesink" + + def writer(self, schema: StructType, overwrite: bool) -> DataSourceArrowWriter: + return HuggingFaceDatasetsWriter( + path=self.path, + schema=schema, + token=self.token, + endpoint=self.endpoint, + **self.kwargs, + ) + + +@dataclass +class HuggingFaceCommitMessage(WriterCommitMessage): + addition: Optional["CommitOperationAdd"] + + +class HuggingFaceDatasetsWriter(DataSourceArrowWriter): + def __init__( + self, + path: str, + schema: StructType, + token: str, + endpoint: Optional[str] = None, + row_group_size: Optional[int] = None, + max_operations_per_commit=50, + **kwargs, + ): + self.path = path + self.schema = schema + self.token = token + self.endpoint = endpoint + self.row_group_size = row_group_size + self.max_operations_per_commit = max_operations_per_commit + self.kwargs = kwargs + + def get_filesystem(self): + from huggingface_hub import HfFileSystem + + return HfFileSystem(token=self.token, endpoint=self.endpoint) + + def write(self, iterator: Iterator["RecordBatch"]) -> HuggingFaceCommitMessage: + import io + + from huggingface_hub import CommitOperationAdd + from pyarrow import parquet as pq + from pyspark import TaskContext + from pyspark.sql.pandas.types import to_arrow_schema + + context = TaskContext.get() + assert context, "No active Spark task context" + partition_id = context.partitionId() + + schema = to_arrow_schema(self.schema) + parquet = io.BytesIO() + is_empty = True + with pq.ParquetWriter(parquet, schema, **self.kwargs) as writer: + for batch in iterator: + writer.write_batch(batch, row_group_size=self.row_group_size) + is_empty = False + + if is_empty: + return HuggingFaceCommitMessage(addition=None) + + name = f"part-{partition_id}.parquet" # Name of the file in the repo + parquet.seek(0) + addition = CommitOperationAdd(path_in_repo=name, path_or_fileobj=parquet) + + fs = self.get_filesystem() + resolved_path = fs.resolve_path(self.path) + fs._api.preupload_lfs_files( + 
repo_id=resolved_path.repo_id, + additions=[addition], + repo_type=resolved_path.repo_type, + revision=resolved_path.revision, + ) + + print(f"Written {name} with content") + return HuggingFaceCommitMessage(addition=addition) + + def commit(self, messages: List[HuggingFaceCommitMessage]) -> None: # type: ignore[override] + import math + + fs = self.get_filesystem() + resolved_path = fs.resolve_path(self.path) + additions = [message.addition for message in messages if message.addition] + num_commits = math.ceil(len(additions) / self.max_operations_per_commit) + for i in range(num_commits): + begin = i * self.max_operations_per_commit + end = (i + 1) * self.max_operations_per_commit + operations = additions[begin:end] + commit_message = "Upload using PySpark" + ( + f" (part {i:05d}-of-{num_commits:05d})" if num_commits > 1 else "" + ) + print(operations, commit_message) + fs._api.create_commit( + repo_id=resolved_path.repo_id, + repo_type=resolved_path.repo_type, + revision=resolved_path.revision, + operations=operations, + commit_message=commit_message, + ) + for addition in operations: + print(f"Committed {addition.path_in_repo}") + + def abort(self, messages: List[HuggingFaceCommitMessage]) -> None: # type: ignore[override] + additions = [message.addition for message in messages if message.addition] + for addition in additions: + print(f"Aborted {addition.path_in_repo}") From d0764a8a7ab2f3d54531fe7fc0243c7ee374141a Mon Sep 17 00:00:00 2001 From: Haoyu Weng Date: Thu, 23 Jan 2025 17:15:47 -0800 Subject: [PATCH 02/13] support overwrite and append modes --- pyspark_huggingface/huggingface_sink.py | 59 ++++++++++++++++++++----- 1 file changed, 48 insertions(+), 11 deletions(-) diff --git a/pyspark_huggingface/huggingface_sink.py b/pyspark_huggingface/huggingface_sink.py index 668488b..df9e101 100644 --- a/pyspark_huggingface/huggingface_sink.py +++ b/pyspark_huggingface/huggingface_sink.py @@ -1,4 +1,5 @@ import ast +import logging from dataclasses import dataclass from typing import TYPE_CHECKING, Iterator, List, Optional @@ -13,6 +14,8 @@ from huggingface_hub import CommitOperationAdd from pyarrow import RecordBatch +logger = logging.getLogger(__name__) +logger.setLevel(logging.INFO) class HuggingFaceSink(DataSource): def __init__(self, options): @@ -45,6 +48,7 @@ def writer(self, schema: StructType, overwrite: bool) -> DataSourceArrowWriter: return HuggingFaceDatasetsWriter( path=self.path, schema=schema, + overwrite=overwrite, token=self.token, endpoint=self.endpoint, **self.kwargs, @@ -57,29 +61,58 @@ class HuggingFaceCommitMessage(WriterCommitMessage): class HuggingFaceDatasetsWriter(DataSourceArrowWriter): + def __init__( self, + *, path: str, schema: StructType, + overwrite: bool, token: str, endpoint: Optional[str] = None, row_group_size: Optional[int] = None, - max_operations_per_commit=50, + max_operations_per_commit=25000, **kwargs, ): + import uuid + self.path = path self.schema = schema + self.overwrite = overwrite self.token = token self.endpoint = endpoint self.row_group_size = row_group_size self.max_operations_per_commit = max_operations_per_commit self.kwargs = kwargs + # Use a unique prefix to avoid conflicts with existing files + self.prefix = f"{uuid.uuid4()}-" + def get_filesystem(self): from huggingface_hub import HfFileSystem return HfFileSystem(token=self.token, endpoint=self.endpoint) + def get_delete_operations(self, resolved_path): + from huggingface_hub import CommitOperationDelete, list_repo_tree + from huggingface_hub.hf_api import RepoFile + + # list all 
files in the directory + objects = list_repo_tree( + token=self.token, + path_in_repo=resolved_path.path_in_repo, + repo_id=resolved_path.repo_id, + repo_type=resolved_path.repo_type, + revision=resolved_path.revision, + expand=False, + recursive=False, + ) + + # delete all parquet files + for obj in objects: + if isinstance(obj, RepoFile) and obj.path.endswith(".parquet"): + yield CommitOperationDelete(path_in_repo=obj.path) + def write(self, iterator: Iterator["RecordBatch"]) -> HuggingFaceCommitMessage: import io @@ -100,10 +133,12 @@ def write(self, iterator: Iterator["RecordBatch"]) -> HuggingFaceCommitMessage: writer.write_batch(batch, row_group_size=self.row_group_size) is_empty = False - if is_empty: + if is_empty: # Skip empty partitions return HuggingFaceCommitMessage(addition=None) - name = f"part-{partition_id}.parquet" # Name of the file in the repo + name = ( + f"{self.prefix}part-{partition_id}.parquet" # Name of the file in the repo + ) parquet.seek(0) addition = CommitOperationAdd(path_in_repo=name, path_or_fileobj=parquet) @@ -116,7 +151,7 @@ def write(self, iterator: Iterator["RecordBatch"]) -> HuggingFaceCommitMessage: revision=resolved_path.revision, ) - print(f"Written {name} with content") + logger.info(f"Written {name} with content") return HuggingFaceCommitMessage(addition=addition) def commit(self, messages: List[HuggingFaceCommitMessage]) -> None: # type: ignore[override] @@ -124,16 +159,18 @@ def commit(self, messages: List[HuggingFaceCommitMessage]) -> None: # type: ign fs = self.get_filesystem() resolved_path = fs.resolve_path(self.path) - additions = [message.addition for message in messages if message.addition] - num_commits = math.ceil(len(additions) / self.max_operations_per_commit) + operations = [message.addition for message in messages if message.addition] + if self.overwrite: + operations += list(self.get_delete_operations(resolved_path)) + num_commits = math.ceil(len(operations) / self.max_operations_per_commit) + for i in range(num_commits): begin = i * self.max_operations_per_commit end = (i + 1) * self.max_operations_per_commit - operations = additions[begin:end] + operations = operations[begin:end] commit_message = "Upload using PySpark" + ( f" (part {i:05d}-of-{num_commits:05d})" if num_commits > 1 else "" ) - print(operations, commit_message) fs._api.create_commit( repo_id=resolved_path.repo_id, repo_type=resolved_path.repo_type, @@ -141,10 +178,10 @@ def commit(self, messages: List[HuggingFaceCommitMessage]) -> None: # type: ign operations=operations, commit_message=commit_message, ) - for addition in operations: - print(f"Committed {addition.path_in_repo}") + for operation in operations: + logger.info(f"Committed {operation}") def abort(self, messages: List[HuggingFaceCommitMessage]) -> None: # type: ignore[override] additions = [message.addition for message in messages if message.addition] for addition in additions: - print(f"Aborted {addition.path_in_repo}") + logger.info(f"Aborted {addition}") From 703581ba67c195dd92b7cc189f167464391e33e8 Mon Sep 17 00:00:00 2001 From: Haoyu Weng Date: Fri, 24 Jan 2025 14:11:22 -0800 Subject: [PATCH 03/13] limit size of each single parquet file --- pyspark_huggingface/huggingface_sink.py | 101 ++++++++++++++---------- 1 file changed, 61 insertions(+), 40 deletions(-) diff --git a/pyspark_huggingface/huggingface_sink.py b/pyspark_huggingface/huggingface_sink.py index df9e101..6fc7574 100644 --- a/pyspark_huggingface/huggingface_sink.py +++ b/pyspark_huggingface/huggingface_sink.py @@ -15,7 +15,6 @@ from 
pyarrow import RecordBatch logger = logging.getLogger(__name__) -logger.setLevel(logging.INFO) class HuggingFaceSink(DataSource): def __init__(self, options): @@ -57,7 +56,7 @@ def writer(self, schema: StructType, overwrite: bool) -> DataSourceArrowWriter: @dataclass class HuggingFaceCommitMessage(WriterCommitMessage): - addition: Optional["CommitOperationAdd"] + additions: List["CommitOperationAdd"] class HuggingFaceDatasetsWriter(DataSourceArrowWriter): @@ -71,7 +70,8 @@ def __init__( token: str, endpoint: Optional[str] = None, row_group_size: Optional[int] = None, - max_operations_per_commit=25000, + max_bytes_per_file=500_000_000, + max_operations_per_commit=25_000, **kwargs, ): import uuid @@ -82,22 +82,24 @@ def __init__( self.token = token self.endpoint = endpoint self.row_group_size = row_group_size + self.max_bytes_per_file = max_bytes_per_file self.max_operations_per_commit = max_operations_per_commit self.kwargs = kwargs - # Use a unique prefix to avoid conflicts with existing files - self.prefix = f"{uuid.uuid4()}-" + # Use a unique filename prefix to avoid conflicts with existing files + self.uuid = uuid.uuid4() def get_filesystem(self): from huggingface_hub import HfFileSystem return HfFileSystem(token=self.token, endpoint=self.endpoint) + # Get the commit operations to delete all existing Parquet files def get_delete_operations(self, resolved_path): from huggingface_hub import CommitOperationDelete, list_repo_tree from huggingface_hub.hf_api import RepoFile - # list all files in the directory + # List all files in the directory objects = list_repo_tree( token=self.token, path_in_repo=resolved_path.path_in_repo, @@ -108,10 +110,10 @@ def get_delete_operations(self, resolved_path): recursive=False, ) - # delete all parquet files + # Delete all existing parquet files for obj in objects: if isinstance(obj, RepoFile) and obj.path.endswith(".parquet"): - yield CommitOperationDelete(path_in_repo=obj.path) + yield CommitOperationDelete(path_in_repo=obj.path, is_folder=False) def write(self, iterator: Iterator["RecordBatch"]) -> HuggingFaceCommitMessage: import io @@ -121,49 +123,70 @@ def write(self, iterator: Iterator["RecordBatch"]) -> HuggingFaceCommitMessage: from pyspark import TaskContext from pyspark.sql.pandas.types import to_arrow_schema + # Get the current partition ID context = TaskContext.get() assert context, "No active Spark task context" partition_id = context.partitionId() - schema = to_arrow_schema(self.schema) - parquet = io.BytesIO() - is_empty = True - with pq.ParquetWriter(parquet, schema, **self.kwargs) as writer: - for batch in iterator: - writer.write_batch(batch, row_group_size=self.row_group_size) - is_empty = False - - if is_empty: # Skip empty partitions - return HuggingFaceCommitMessage(addition=None) - - name = ( - f"{self.prefix}part-{partition_id}.parquet" # Name of the file in the repo - ) - parquet.seek(0) - addition = CommitOperationAdd(path_in_repo=name, path_or_fileobj=parquet) - fs = self.get_filesystem() resolved_path = fs.resolve_path(self.path) - fs._api.preupload_lfs_files( - repo_id=resolved_path.repo_id, - additions=[addition], - repo_type=resolved_path.repo_type, - revision=resolved_path.revision, - ) - logger.info(f"Written {name} with content") - return HuggingFaceCommitMessage(addition=addition) + schema = to_arrow_schema(self.schema) + num_files = 0 + additions = [] + + with io.BytesIO() as parquet: + + def flush(): + nonlocal num_files + name = f"{self.uuid}-part-{partition_id}-{num_files}.parquet" # Name of the file in the repo + 
num_files += 1 + parquet.seek(0) + + addition = CommitOperationAdd( + path_in_repo=name, path_or_fileobj=parquet + ) + fs._api.preupload_lfs_files( + repo_id=resolved_path.repo_id, + additions=[addition], + repo_type=resolved_path.repo_type, + revision=resolved_path.revision, + ) + additions.append(addition) + + # Reuse the buffer for the next file + parquet.seek(0) + parquet.truncate() + + # Write the Parquet files, limiting the size of each file + while True: + with pq.ParquetWriter(parquet, schema, **self.kwargs) as writer: + num_batches = 0 + for batch in iterator: + writer.write_batch(batch, row_group_size=self.row_group_size) + num_batches += 1 + if parquet.tell() > self.max_bytes_per_file: + flush() + break + else: + if num_batches > 0: + flush() + break + + return HuggingFaceCommitMessage(additions=additions) def commit(self, messages: List[HuggingFaceCommitMessage]) -> None: # type: ignore[override] import math fs = self.get_filesystem() resolved_path = fs.resolve_path(self.path) - operations = [message.addition for message in messages if message.addition] - if self.overwrite: - operations += list(self.get_delete_operations(resolved_path)) - num_commits = math.ceil(len(operations) / self.max_operations_per_commit) + operations = [ + addition for message in messages for addition in message.additions + ] + if self.overwrite: # Delete existing files if overwrite is enabled + operations.extend(self.get_delete_operations(resolved_path)) + num_commits = math.ceil(len(operations) / self.max_operations_per_commit) for i in range(num_commits): begin = i * self.max_operations_per_commit end = (i + 1) * self.max_operations_per_commit @@ -178,10 +201,8 @@ def commit(self, messages: List[HuggingFaceCommitMessage]) -> None: # type: ign operations=operations, commit_message=commit_message, ) - for operation in operations: - logger.info(f"Committed {operation}") def abort(self, messages: List[HuggingFaceCommitMessage]) -> None: # type: ignore[override] - additions = [message.addition for message in messages if message.addition] + additions = [addition for message in messages for addition in message.additions] for addition in additions: logger.info(f"Aborted {addition}") From 95d7a9d4f498bc3ab1f421ece80e0f733cd9ace1 Mon Sep 17 00:00:00 2001 From: Haoyu Weng Date: Fri, 24 Jan 2025 16:02:03 -0800 Subject: [PATCH 04/13] fix writing to new split --- pyspark_huggingface/huggingface_sink.py | 73 +++++++++++++++++++------ 1 file changed, 57 insertions(+), 16 deletions(-) diff --git a/pyspark_huggingface/huggingface_sink.py b/pyspark_huggingface/huggingface_sink.py index 6fc7574..8c63d97 100644 --- a/pyspark_huggingface/huggingface_sink.py +++ b/pyspark_huggingface/huggingface_sink.py @@ -17,6 +17,41 @@ logger = logging.getLogger(__name__) class HuggingFaceSink(DataSource): + """ + A DataSource for writing Spark DataFrames to HuggingFace Datasets. + + This data source allows writing Spark DataFrames to the HuggingFace Hub as Parquet files. + The path must be a valid `hf://` URL. + + Name: `huggingfacesink` + + Data Source Options: + - token (str, required): HuggingFace API token for authentication. + - endpoint (str): Custom HuggingFace API endpoint URL. + - max_bytes_per_file (int): Maximum size of each Parquet file. + - row_group_size (int): Row group size when writing Parquet files. + - max_operations_per_commit (int): Maximum number of files to add/delete per commit. + + Modes: + - `overwrite`: Overwrite an existing dataset by deleting existing Parquet files. 
+ - `append`: Append data to an existing dataset. + + Examples + -------- + + Write a DataFrame to the HuggingFace Hub. + + >>> df.write.format("huggingfacesink").mode("overwrite").options(token="...").save("hf://datasets/user/dataset") + + Append data to an existing dataset on the HuggingFace Hub. + + >>> df.write.format("huggingfacesink").mode("append").options(token="...").save("hf://datasets/user/dataset") + + Write data to configuration `en` and split `train` of a dataset. + + >>> df.write.format("huggingfacesink").mode("overwrite").options(token="...").save("hf://datasets/user/dataset/en/train") + """ + def __init__(self, options): super().__init__(options) @@ -96,27 +131,32 @@ def get_filesystem(self): # Get the commit operations to delete all existing Parquet files def get_delete_operations(self, resolved_path): - from huggingface_hub import CommitOperationDelete, list_repo_tree + from huggingface_hub import CommitOperationDelete from huggingface_hub.hf_api import RepoFile + from huggingface_hub.errors import EntryNotFoundError - # List all files in the directory - objects = list_repo_tree( - token=self.token, - path_in_repo=resolved_path.path_in_repo, - repo_id=resolved_path.repo_id, - repo_type=resolved_path.repo_type, - revision=resolved_path.revision, - expand=False, - recursive=False, - ) + fs = self.get_filesystem() - # Delete all existing parquet files - for obj in objects: - if isinstance(obj, RepoFile) and obj.path.endswith(".parquet"): - yield CommitOperationDelete(path_in_repo=obj.path, is_folder=False) + # List all files in the directory + try: + # Delete all existing parquet files + objects = fs._api.list_repo_tree( + path_in_repo=resolved_path.path_in_repo, + repo_id=resolved_path.repo_id, + repo_type=resolved_path.repo_type, + revision=resolved_path.revision, + expand=False, + recursive=False, + ) + for obj in objects: + if isinstance(obj, RepoFile) and obj.path.endswith(".parquet"): + yield CommitOperationDelete(path_in_repo=obj.path, is_folder=False) + except EntryNotFoundError as e: + logger.info(f"Writing to a new path: {e}") def write(self, iterator: Iterator["RecordBatch"]) -> HuggingFaceCommitMessage: import io + import os from huggingface_hub import CommitOperationAdd from pyarrow import parquet as pq @@ -140,11 +180,12 @@ def write(self, iterator: Iterator["RecordBatch"]) -> HuggingFaceCommitMessage: def flush(): nonlocal num_files name = f"{self.uuid}-part-{partition_id}-{num_files}.parquet" # Name of the file in the repo + path_in_repo = os.path.join(resolved_path.path_in_repo, name) num_files += 1 parquet.seek(0) addition = CommitOperationAdd( - path_in_repo=name, path_or_fileobj=parquet + path_in_repo=path_in_repo, path_or_fileobj=parquet ) fs._api.preupload_lfs_files( repo_id=resolved_path.repo_id, From eb374a401dc908781c157f55db0a0bd2e035466d Mon Sep 17 00:00:00 2001 From: Haoyu Weng Date: Mon, 27 Jan 2025 12:53:30 -0800 Subject: [PATCH 05/13] improve comments --- pyspark_huggingface/huggingface_sink.py | 70 ++++++++++++++++--------- 1 file changed, 46 insertions(+), 24 deletions(-) diff --git a/pyspark_huggingface/huggingface_sink.py b/pyspark_huggingface/huggingface_sink.py index 8c63d97..d5e37ac 100644 --- a/pyspark_huggingface/huggingface_sink.py +++ b/pyspark_huggingface/huggingface_sink.py @@ -11,7 +11,11 @@ from pyspark.sql.types import StructType if TYPE_CHECKING: - from huggingface_hub import CommitOperationAdd + from huggingface_hub import ( + CommitOperationAdd, + CommitOperationDelete, + HfFileSystemResolvedPath, + ) from pyarrow import 
RecordBatch logger = logging.getLogger(__name__) @@ -21,10 +25,13 @@ class HuggingFaceSink(DataSource): A DataSource for writing Spark DataFrames to HuggingFace Datasets. This data source allows writing Spark DataFrames to the HuggingFace Hub as Parquet files. - The path must be a valid `hf://` URL. Name: `huggingfacesink` + Path: + - The path must be a valid HuggingFace dataset path, e.g. `{user}/{repo}` or `{user}/{repo}/{split}`. + - A revision can be specified using the `@` symbol, e.g. `{user}/{repo}@{revision}`. + Data Source Options: - token (str, required): HuggingFace API token for authentication. - endpoint (str): Custom HuggingFace API endpoint URL. @@ -41,15 +48,15 @@ class HuggingFaceSink(DataSource): Write a DataFrame to the HuggingFace Hub. - >>> df.write.format("huggingfacesink").mode("overwrite").options(token="...").save("hf://datasets/user/dataset") + >>> df.write.format("huggingfacesink").mode("overwrite").options(token="...").save("user/dataset") - Append data to an existing dataset on the HuggingFace Hub. + Append to an existing directory on the HuggingFace Hub. - >>> df.write.format("huggingfacesink").mode("append").options(token="...").save("hf://datasets/user/dataset") + >>> df.write.format("huggingfacesink").mode("append").options(token="...").save("user/dataset") - Write data to configuration `en` and split `train` of a dataset. + Write to the `test` split of a dataset. - >>> df.write.format("huggingfacesink").mode("overwrite").options(token="...").save("hf://datasets/user/dataset/en/train") + >>> df.write.format("huggingfacesink").mode("overwrite").options(token="...").save("user/dataset/test") """ def __init__(self, options): @@ -95,7 +102,6 @@ class HuggingFaceCommitMessage(WriterCommitMessage): class HuggingFaceDatasetsWriter(DataSourceArrowWriter): - def __init__( self, *, @@ -111,7 +117,7 @@ def __init__( ): import uuid - self.path = path + self.path = f"datasets/{path}" self.schema = schema self.overwrite = overwrite self.token = token @@ -129,17 +135,20 @@ def get_filesystem(self): return HfFileSystem(token=self.token, endpoint=self.endpoint) - # Get the commit operations to delete all existing Parquet files - def get_delete_operations(self, resolved_path): + def get_delete_operations( + self, resolved_path: "HfFileSystemResolvedPath" + ) -> Iterator["CommitOperationDelete"]: + """ + Get the commit operations to delete all existing Parquet files. + This is used when `overwrite=True` to clear the target directory. + """ from huggingface_hub import CommitOperationDelete from huggingface_hub.hf_api import RepoFile from huggingface_hub.errors import EntryNotFoundError fs = self.get_filesystem() - # List all files in the directory try: - # Delete all existing parquet files objects = fs._api.list_repo_tree( path_in_repo=resolved_path.path_in_repo, repo_id=resolved_path.repo_id, @@ -163,7 +172,7 @@ def write(self, iterator: Iterator["RecordBatch"]) -> HuggingFaceCommitMessage: from pyspark import TaskContext from pyspark.sql.pandas.types import to_arrow_schema - # Get the current partition ID + # Get the current partition ID. Use this to generate unique filenames for each partition. context = TaskContext.get() assert context, "No active Spark task context" partition_id = context.partitionId() @@ -175,9 +184,14 @@ def write(self, iterator: Iterator["RecordBatch"]) -> HuggingFaceCommitMessage: num_files = 0 additions = [] + # TODO: Evaluate the performance of using a temp file instead of an in-memory buffer. 
with io.BytesIO() as parquet: - def flush(): + def flush(writer: pq.ParquetWriter): + """ + Upload the current Parquet file and reset the buffer. + """ + writer.close() # Close the writer to flush the buffer nonlocal num_files name = f"{self.uuid}-part-{partition_id}-{num_files}.parquet" # Name of the file in the repo path_in_repo = os.path.join(resolved_path.path_in_repo, name) @@ -199,20 +213,23 @@ def flush(): parquet.seek(0) parquet.truncate() - # Write the Parquet files, limiting the size of each file + """ + Write the Parquet files, flushing the buffer when the file size exceeds the limit. + Limiting the size is necessary because we are writing them in memory. + """ while True: with pq.ParquetWriter(parquet, schema, **self.kwargs) as writer: num_batches = 0 - for batch in iterator: + for batch in iterator: # Start iterating from where we left off writer.write_batch(batch, row_group_size=self.row_group_size) num_batches += 1 if parquet.tell() > self.max_bytes_per_file: - flush() - break - else: + flush(writer) + break # Start a new file + else: # Finished writing all batches if num_batches > 0: - flush() - break + flush(writer) + break # Exit while loop return HuggingFaceCommitMessage(additions=additions) @@ -227,11 +244,15 @@ def commit(self, messages: List[HuggingFaceCommitMessage]) -> None: # type: ign if self.overwrite: # Delete existing files if overwrite is enabled operations.extend(self.get_delete_operations(resolved_path)) + """ + Split the commit into multiple parts if necessary. + The HuggingFace API has a limit of 25,000 operations per commit. + """ num_commits = math.ceil(len(operations) / self.max_operations_per_commit) for i in range(num_commits): begin = i * self.max_operations_per_commit end = (i + 1) * self.max_operations_per_commit - operations = operations[begin:end] + part = operations[begin:end] commit_message = "Upload using PySpark" + ( f" (part {i:05d}-of-{num_commits:05d})" if num_commits > 1 else "" ) @@ -239,11 +260,12 @@ def commit(self, messages: List[HuggingFaceCommitMessage]) -> None: # type: ign repo_id=resolved_path.repo_id, repo_type=resolved_path.repo_type, revision=resolved_path.revision, - operations=operations, + operations=part, commit_message=commit_message, ) def abort(self, messages: List[HuggingFaceCommitMessage]) -> None: # type: ignore[override] + # We don't need to do anything here, as the files are not included in the repo until commit additions = [addition for message in messages for addition in message.additions] for addition in additions: logger.info(f"Aborted {addition}") From 42592255b03dd188b412739adf278135404d297b Mon Sep 17 00:00:00 2001 From: Haoyu Weng Date: Mon, 27 Jan 2025 13:32:11 -0800 Subject: [PATCH 06/13] organize imports --- pyspark_huggingface/huggingface_sink.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyspark_huggingface/huggingface_sink.py b/pyspark_huggingface/huggingface_sink.py index d5e37ac..4fba52f 100644 --- a/pyspark_huggingface/huggingface_sink.py +++ b/pyspark_huggingface/huggingface_sink.py @@ -143,8 +143,8 @@ def get_delete_operations( This is used when `overwrite=True` to clear the target directory. 
""" from huggingface_hub import CommitOperationDelete - from huggingface_hub.hf_api import RepoFile from huggingface_hub.errors import EntryNotFoundError + from huggingface_hub.hf_api import RepoFile fs = self.get_filesystem() From 6c68ae347dd50fae04d901b34c6b012748bb2a7b Mon Sep 17 00:00:00 2001 From: Haoyu Weng Date: Mon, 27 Jan 2025 15:18:03 -0800 Subject: [PATCH 07/13] add tests --- pyspark_huggingface/huggingface_sink.py | 3 +- tests/test_huggingface_writer.py | 125 ++++++++++++++++++++++++ 2 files changed, 126 insertions(+), 2 deletions(-) create mode 100644 tests/test_huggingface_writer.py diff --git a/pyspark_huggingface/huggingface_sink.py b/pyspark_huggingface/huggingface_sink.py index 4fba52f..931866b 100644 --- a/pyspark_huggingface/huggingface_sink.py +++ b/pyspark_huggingface/huggingface_sink.py @@ -174,8 +174,7 @@ def write(self, iterator: Iterator["RecordBatch"]) -> HuggingFaceCommitMessage: # Get the current partition ID. Use this to generate unique filenames for each partition. context = TaskContext.get() - assert context, "No active Spark task context" - partition_id = context.partitionId() + partition_id = context.partitionId() if context else 0 fs = self.get_filesystem() resolved_path = fs.resolve_path(self.path) diff --git a/tests/test_huggingface_writer.py b/tests/test_huggingface_writer.py new file mode 100644 index 0000000..ba87299 --- /dev/null +++ b/tests/test_huggingface_writer.py @@ -0,0 +1,125 @@ +import os + +import pytest +from pyspark.sql import DataFrame, SparkSession +from pyspark.testing import assertDataFrameEqual +from pytest_mock import MockerFixture + + +@pytest.fixture(scope="session") +def spark(): + from pyspark_huggingface.huggingface_sink import HuggingFaceSink + + spark = SparkSession.builder.getOrCreate() + spark.dataSource.register(HuggingFaceSink) + yield spark + + +def token(): + return os.environ["HF_TOKEN"] + + +def reader(spark): + return spark.read.format("huggingface").option("token", token()) + + +def writer(df: DataFrame): + return df.write.format("huggingfacesink").option("token", token()) + + +@pytest.fixture(scope="session") +def random_df(spark: SparkSession): + from pyspark.sql.functions import rand + + return lambda n: spark.range(n).select((rand()).alias("value")) + + +@pytest.fixture(scope="session") +def api(): + import huggingface_hub + + return huggingface_hub.HfApi(token=token()) + + +@pytest.fixture(scope="session") +def username(api): + return api.whoami()["name"] + + +@pytest.fixture +def repo(api, username): + import uuid + + repo_id = f"{username}/test-{uuid.uuid4()}" + api.create_repo(repo_id, private=True, repo_type="dataset") + yield repo_id + api.delete_repo(repo_id, repo_type="dataset") + + +def test_basic(spark, repo, random_df): + df = random_df(10) + writer(df).mode("append").save(repo) + actual = reader(spark).load(repo) + assertDataFrameEqual(df, actual) + + +def test_append(spark, repo, random_df): + df1 = random_df(10) + df2 = random_df(10) + writer(df1).mode("append").save(repo) + writer(df2).mode("append").save(repo) + actual = reader(spark).load(repo) + expected = df1.union(df2) + assertDataFrameEqual(actual, expected) + + +def test_overwrite(spark, repo, random_df): + df1 = random_df(10) + df2 = random_df(10) + writer(df1).mode("append").save(repo) + writer(df2).mode("overwrite").save(repo) + actual = reader(spark).load(repo) + assertDataFrameEqual(actual, df2) + + +def test_dir(repo, random_df, api): + df = random_df(10) + writer(df).mode("append").save(repo + "/dir") + assert any( + 
file.path.endswith(".parquet") + for file in api.list_repo_tree(repo, "dir", repo_type="dataset") + ) + + +def test_revision(repo, random_df, api): + df = random_df(10) + api.create_branch(repo, branch="test", repo_type="dataset") + writer(df).mode("append").save(repo + "@test") + assert any( + file.path.endswith(".parquet") + for file in api.list_repo_tree(repo, repo_type="dataset", revision="test") + ) + + +def test_max_bytes_per_file(spark, mocker: MockerFixture): + from pyspark_huggingface.huggingface_sink import HuggingFaceDatasetsWriter + + repo = "user/test" + fs = mocker.patch("huggingface_hub.HfFileSystem").return_value = mocker.MagicMock() + # mock fs._api.preupload_lfs_files + resolved_path = fs.resolve_path.return_value = mocker.MagicMock() + resolved_path.path_in_repo = repo + resolved_path.repo_id = repo + resolved_path.repo_type = "dataset" + resolved_path.revision = "main" + + df = spark.range(10) + writer = HuggingFaceDatasetsWriter( + path=repo, + overwrite=False, + schema=df.schema, + token="token", + max_bytes_per_file=1, + ) + writer.write(iter(df.toArrow().to_batches(max_chunksize=1))) + assert fs._api.preupload_lfs_files.call_count == 10 From 7b4bba9ee4e4a52acdc6bfb5a8b293014c84aa05 Mon Sep 17 00:00:00 2001 From: Haoyu Weng Date: Mon, 27 Jan 2025 15:18:49 -0800 Subject: [PATCH 08/13] update dependencies --- pyproject.toml | 2 ++ 1 file changed, 2 insertions(+) diff --git a/pyproject.toml b/pyproject.toml index 79b722e..aba295e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -15,6 +15,8 @@ datasets = "^3.2" [tool.poetry.group.dev.dependencies] pytest = "^8.0.0" +pytest-dotenv = "^0.5.2" +pytest-mock = "^3.14.0" [build-system] requires = ["poetry-core"] From c536a6980b75df17bf1f9d994ef38c90adaf4379 Mon Sep 17 00:00:00 2001 From: Haoyu Weng Date: Mon, 27 Jan 2025 15:21:45 -0800 Subject: [PATCH 09/13] clean up --- tests/test_huggingface_writer.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/tests/test_huggingface_writer.py b/tests/test_huggingface_writer.py index ba87299..f08514b 100644 --- a/tests/test_huggingface_writer.py +++ b/tests/test_huggingface_writer.py @@ -6,6 +6,8 @@ from pytest_mock import MockerFixture +# ============== Fixtures & Helpers ============== + @pytest.fixture(scope="session") def spark(): from pyspark_huggingface.huggingface_sink import HuggingFaceSink @@ -56,6 +58,8 @@ def repo(api, username): api.delete_repo(repo_id, repo_type="dataset") +# ============== Tests ============== + def test_basic(spark, repo, random_df): df = random_df(10) writer(df).mode("append").save(repo) @@ -106,7 +110,6 @@ def test_max_bytes_per_file(spark, mocker: MockerFixture): repo = "user/test" fs = mocker.patch("huggingface_hub.HfFileSystem").return_value = mocker.MagicMock() - # mock fs._api.preupload_lfs_files resolved_path = fs.resolve_path.return_value = mocker.MagicMock() resolved_path.path_in_repo = repo resolved_path.repo_id = repo From 77f0667962c9a26991286f263723e3ff965f545f Mon Sep 17 00:00:00 2001 From: Haoyu Weng Date: Wed, 29 Jan 2025 10:14:12 -0800 Subject: [PATCH 10/13] pass split as a separate option --- pyspark_huggingface/huggingface_sink.py | 117 ++++++++++++++---------- tests/test_huggingface_writer.py | 38 ++++---- 2 files changed, 87 insertions(+), 68 deletions(-) diff --git a/pyspark_huggingface/huggingface_sink.py b/pyspark_huggingface/huggingface_sink.py index 931866b..baafa5f 100644 --- a/pyspark_huggingface/huggingface_sink.py +++ b/pyspark_huggingface/huggingface_sink.py @@ -1,5 +1,6 @@ import ast import logging 
+import os from dataclasses import dataclass from typing import TYPE_CHECKING, Iterator, List, Optional @@ -11,11 +12,7 @@ from pyspark.sql.types import StructType if TYPE_CHECKING: - from huggingface_hub import ( - CommitOperationAdd, - CommitOperationDelete, - HfFileSystemResolvedPath, - ) + from huggingface_hub import CommitOperationAdd, CommitOperationDelete from pyarrow import RecordBatch logger = logging.getLogger(__name__) @@ -28,12 +25,12 @@ class HuggingFaceSink(DataSource): Name: `huggingfacesink` - Path: - - The path must be a valid HuggingFace dataset path, e.g. `{user}/{repo}` or `{user}/{repo}/{split}`. - - A revision can be specified using the `@` symbol, e.g. `{user}/{repo}@{revision}`. - Data Source Options: - token (str, required): HuggingFace API token for authentication. + - path (str, required): HuggingFace repository ID, e.g. `{username}/{dataset}`. + - path_in_repo (str): Path within the repository to write the data. Defaults to the root. + - split (str): Split name to write the data to. Defaults to `train`. Only `train`, `test`, and `validation` are supported. + - revision (str): Branch, tag, or commit to write to. Defaults to the main branch. - endpoint (str): Custom HuggingFace API endpoint URL. - max_bytes_per_file (int): Maximum size of each Parquet file. - row_group_size (int): Row group size when writing Parquet files. @@ -50,13 +47,13 @@ class HuggingFaceSink(DataSource): >>> df.write.format("huggingfacesink").mode("overwrite").options(token="...").save("user/dataset") - Append to an existing directory on the HuggingFace Hub. + Append to an existing dataset on the HuggingFace Hub. >>> df.write.format("huggingfacesink").mode("append").options(token="...").save("user/dataset") Write to the `test` split of a dataset. - >>> df.write.format("huggingfacesink").mode("overwrite").options(token="...").save("user/dataset/test") + >>> df.write.format("huggingfacesink").mode("overwrite").options(token="...", split="test").save("user/dataset") """ def __init__(self, options): @@ -66,8 +63,11 @@ def __init__(self, options): raise Exception("You must specify a dataset name.") kwargs = dict(self.options) - self.path = kwargs.pop("path") self.token = kwargs.pop("token") + self.repo_id = kwargs.pop("path") + self.path_in_repo = kwargs.pop("path_in_repo", None) + self.split = kwargs.pop("split", None) + self.revision = kwargs.pop("revision", None) self.endpoint = kwargs.pop("endpoint", None) for arg in kwargs: if kwargs[arg].lower() == "true": @@ -87,7 +87,10 @@ def name(cls): def writer(self, schema: StructType, overwrite: bool) -> DataSourceArrowWriter: return HuggingFaceDatasetsWriter( - path=self.path, + repo_id=self.repo_id, + path_in_repo=self.path_in_repo, + split=self.split, + revision=self.revision, schema=schema, overwrite=overwrite, token=self.token, @@ -102,22 +105,30 @@ class HuggingFaceCommitMessage(WriterCommitMessage): class HuggingFaceDatasetsWriter(DataSourceArrowWriter): + repo_type = "dataset" + def __init__( self, *, - path: str, + repo_id: str, + path_in_repo: Optional[str] = None, + split: Optional[str] = None, + revision: Optional[str] = None, schema: StructType, overwrite: bool, token: str, endpoint: Optional[str] = None, row_group_size: Optional[int] = None, max_bytes_per_file=500_000_000, - max_operations_per_commit=25_000, + max_operations_per_commit=100, **kwargs, ): import uuid - self.path = f"datasets/{path}" + self.repo_id = repo_id + self.path_in_repo = (path_in_repo or "").strip("/") + self.split = split or "train" + self.revision = revision 
self.schema = schema self.overwrite = overwrite self.token = token @@ -130,42 +141,53 @@ def __init__( # Use a unique filename prefix to avoid conflicts with existing files self.uuid = uuid.uuid4() - def get_filesystem(self): - from huggingface_hub import HfFileSystem + self.validate() + + def validate(self): + if self.split not in ["train", "test", "validation"]: + raise NotImplementedError( + f"Only 'train', 'test', and 'validation' splits are supported. Got '{self.split}'." + ) - return HfFileSystem(token=self.token, endpoint=self.endpoint) + def get_api(self): + from huggingface_hub import HfApi - def get_delete_operations( - self, resolved_path: "HfFileSystemResolvedPath" - ) -> Iterator["CommitOperationDelete"]: + return HfApi(token=self.token, endpoint=self.endpoint) + + @property + def prefix(self) -> str: + return os.path.join(self.path_in_repo, self.split) + + def get_delete_operations(self) -> Iterator["CommitOperationDelete"]: """ Get the commit operations to delete all existing Parquet files. This is used when `overwrite=True` to clear the target directory. """ from huggingface_hub import CommitOperationDelete from huggingface_hub.errors import EntryNotFoundError - from huggingface_hub.hf_api import RepoFile + from huggingface_hub.hf_api import RepoFolder - fs = self.get_filesystem() + api = self.get_api() try: - objects = fs._api.list_repo_tree( - path_in_repo=resolved_path.path_in_repo, - repo_id=resolved_path.repo_id, - repo_type=resolved_path.repo_type, - revision=resolved_path.revision, + objects = api.list_repo_tree( + path_in_repo=self.path_in_repo, + repo_id=self.repo_id, + repo_type=self.repo_type, + revision=self.revision, expand=False, recursive=False, ) for obj in objects: - if isinstance(obj, RepoFile) and obj.path.endswith(".parquet"): - yield CommitOperationDelete(path_in_repo=obj.path, is_folder=False) + if obj.path.startswith(self.prefix): + yield CommitOperationDelete( + path_in_repo=obj.path, is_folder=isinstance(obj, RepoFolder) + ) except EntryNotFoundError as e: logger.info(f"Writing to a new path: {e}") def write(self, iterator: Iterator["RecordBatch"]) -> HuggingFaceCommitMessage: import io - import os from huggingface_hub import CommitOperationAdd from pyarrow import parquet as pq @@ -176,8 +198,7 @@ def write(self, iterator: Iterator["RecordBatch"]) -> HuggingFaceCommitMessage: context = TaskContext.get() partition_id = context.partitionId() if context else 0 - fs = self.get_filesystem() - resolved_path = fs.resolve_path(self.path) + api = self.get_api() schema = to_arrow_schema(self.schema) num_files = 0 @@ -192,19 +213,20 @@ def flush(writer: pq.ParquetWriter): """ writer.close() # Close the writer to flush the buffer nonlocal num_files - name = f"{self.uuid}-part-{partition_id}-{num_files}.parquet" # Name of the file in the repo - path_in_repo = os.path.join(resolved_path.path_in_repo, name) + name = ( + f"{self.prefix}-{self.uuid}-part-{partition_id}-{num_files}.parquet" + ) num_files += 1 parquet.seek(0) addition = CommitOperationAdd( - path_in_repo=path_in_repo, path_or_fileobj=parquet + path_in_repo=name, path_or_fileobj=parquet ) - fs._api.preupload_lfs_files( - repo_id=resolved_path.repo_id, + api.preupload_lfs_files( + repo_id=self.repo_id, additions=[addition], - repo_type=resolved_path.repo_type, - revision=resolved_path.revision, + repo_type=self.repo_type, + revision=self.revision, ) additions.append(addition) @@ -235,13 +257,12 @@ def flush(writer: pq.ParquetWriter): def commit(self, messages: List[HuggingFaceCommitMessage]) -> None: # 
type: ignore[override] import math - fs = self.get_filesystem() - resolved_path = fs.resolve_path(self.path) + api = self.get_api() operations = [ addition for message in messages for addition in message.additions ] if self.overwrite: # Delete existing files if overwrite is enabled - operations.extend(self.get_delete_operations(resolved_path)) + operations.extend(self.get_delete_operations()) """ Split the commit into multiple parts if necessary. @@ -255,10 +276,10 @@ def commit(self, messages: List[HuggingFaceCommitMessage]) -> None: # type: ign commit_message = "Upload using PySpark" + ( f" (part {i:05d}-of-{num_commits:05d})" if num_commits > 1 else "" ) - fs._api.create_commit( - repo_id=resolved_path.repo_id, - repo_type=resolved_path.repo_type, - revision=resolved_path.revision, + api.create_commit( + repo_id=self.repo_id, + repo_type=self.repo_type, + revision=self.revision, operations=part, commit_message=commit_message, ) diff --git a/tests/test_huggingface_writer.py b/tests/test_huggingface_writer.py index f08514b..734621b 100644 --- a/tests/test_huggingface_writer.py +++ b/tests/test_huggingface_writer.py @@ -1,4 +1,5 @@ import os +import uuid import pytest from pyspark.sql import DataFrame, SparkSession @@ -50,10 +51,8 @@ def username(api): @pytest.fixture def repo(api, username): - import uuid - repo_id = f"{username}/test-{uuid.uuid4()}" - api.create_repo(repo_id, private=True, repo_type="dataset") + api.create_repo(repo_id, private=False, repo_type="dataset") yield repo_id api.delete_repo(repo_id, repo_type="dataset") @@ -86,22 +85,26 @@ def test_overwrite(spark, repo, random_df): assertDataFrameEqual(actual, df2) -def test_dir(repo, random_df, api): - df = random_df(10) - writer(df).mode("append").save(repo + "/dir") - assert any( - file.path.endswith(".parquet") - for file in api.list_repo_tree(repo, "dir", repo_type="dataset") - ) +def test_split(spark, repo, random_df): + df1 = random_df(10) + df2 = random_df(10) + writer(df1).mode("append").save(repo) + writer(df2).mode("append").options(split="test").save(repo) + actual1 = reader(spark).options(split="train").load(repo) + actual2 = reader(spark).options(split="test").load(repo) + assertDataFrameEqual(actual1, df1) + assertDataFrameEqual(actual2, df2) def test_revision(repo, random_df, api): df = random_df(10) api.create_branch(repo, branch="test", repo_type="dataset") - writer(df).mode("append").save(repo + "@test") + writer(df).mode("append").options(revision="test").save(repo) assert any( file.path.endswith(".parquet") - for file in api.list_repo_tree(repo, repo_type="dataset", revision="test") + for file in api.list_repo_tree( + repo, repo_type="dataset", revision="test", recursive=True + ) ) @@ -109,20 +112,15 @@ def test_max_bytes_per_file(spark, mocker: MockerFixture): from pyspark_huggingface.huggingface_sink import HuggingFaceDatasetsWriter repo = "user/test" - fs = mocker.patch("huggingface_hub.HfFileSystem").return_value = mocker.MagicMock() - resolved_path = fs.resolve_path.return_value = mocker.MagicMock() - resolved_path.path_in_repo = repo - resolved_path.repo_id = repo - resolved_path.repo_type = "dataset" - resolved_path.revision = "main" + api = mocker.patch("huggingface_hub.HfApi").return_value = mocker.MagicMock() df = spark.range(10) writer = HuggingFaceDatasetsWriter( - path=repo, + repo_id=repo, overwrite=False, schema=df.schema, token="token", max_bytes_per_file=1, ) writer.write(iter(df.toArrow().to_batches(max_chunksize=1))) - assert fs._api.preupload_lfs_files.call_count == 10 + assert 
api.preupload_lfs_files.call_count == 10 From fa3bdd98badb45eee88451621c214b04f0cb3cd1 Mon Sep 17 00:00:00 2001 From: Haoyu Weng Date: Wed, 29 Jan 2025 10:23:29 -0800 Subject: [PATCH 11/13] add comment explaining split name support --- pyspark_huggingface/huggingface_sink.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/pyspark_huggingface/huggingface_sink.py b/pyspark_huggingface/huggingface_sink.py index baafa5f..ed37d7d 100644 --- a/pyspark_huggingface/huggingface_sink.py +++ b/pyspark_huggingface/huggingface_sink.py @@ -145,6 +145,17 @@ def __init__( def validate(self): if self.split not in ["train", "test", "validation"]: + """ + TODO: Add support for custom splits. + + For custom split names to be recognized, the files must have path with format: + `data/{split}-{iiiii}-of-{nnnnn}.parquet` + where `iiiii` is the part number and `nnnnn` is the total number of parts, both padded to 5 digits. + Example: `data/custom-00000-of-00002.parquet` + + Therefore the current usage of UUID to avoid naming conflicts won't work for custom split names. + To fix this we can rename the files in the commit phase to satisfy the naming convention. + """ raise NotImplementedError( f"Only 'train', 'test', and 'validation' splits are supported. Got '{self.split}'." ) From 3cb2a0b63c263afe10fc5d91a2d27ff0be3be724 Mon Sep 17 00:00:00 2001 From: Haoyu Weng Date: Wed, 29 Jan 2025 15:51:56 -0800 Subject: [PATCH 12/13] fix comment --- pyspark_huggingface/huggingface_sink.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyspark_huggingface/huggingface_sink.py b/pyspark_huggingface/huggingface_sink.py index ed37d7d..b131975 100644 --- a/pyspark_huggingface/huggingface_sink.py +++ b/pyspark_huggingface/huggingface_sink.py @@ -277,7 +277,7 @@ def commit(self, messages: List[HuggingFaceCommitMessage]) -> None: # type: ign """ Split the commit into multiple parts if necessary. - The HuggingFace API has a limit of 25,000 operations per commit. + The HuggingFace API may time out if there are too many operations in a single commit. """ num_commits = math.ceil(len(operations) / self.max_operations_per_commit) for i in range(num_commits): From 104af7c7eb8ec2484dee6b9bc4b174fcdbaf6941 Mon Sep 17 00:00:00 2001 From: Haoyu Weng Date: Wed, 29 Jan 2025 16:00:57 -0800 Subject: [PATCH 13/13] fix path separator --- pyspark_huggingface/huggingface_sink.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/pyspark_huggingface/huggingface_sink.py b/pyspark_huggingface/huggingface_sink.py index b131975..7e6689c 100644 --- a/pyspark_huggingface/huggingface_sink.py +++ b/pyspark_huggingface/huggingface_sink.py @@ -1,6 +1,5 @@ import ast import logging -import os from dataclasses import dataclass from typing import TYPE_CHECKING, Iterator, List, Optional @@ -167,7 +166,7 @@ def get_api(self): @property def prefix(self) -> str: - return os.path.join(self.path_in_repo, self.split) + return f"{self.path_in_repo}/{self.split}".strip("/") def get_delete_operations(self) -> Iterator["CommitOperationDelete"]: """
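
Usage sketch: assuming the `pyspark_huggingface` package built from this series is installed and `HF_TOKEN` holds a HuggingFace write token, the sink can be registered and exercised roughly as follows; the repo id `user/dataset` and the column name `value` are placeholders, and the option names follow the `HuggingFaceSink` docstring and tests above.

    import os

    from pyspark.sql import SparkSession

    from pyspark_huggingface.huggingface_sink import HuggingFaceSink

    spark = SparkSession.builder.getOrCreate()
    spark.dataSource.register(HuggingFaceSink)  # registers the "huggingfacesink" format

    df = spark.range(100).withColumnRenamed("id", "value")

    (
        df.write.format("huggingfacesink")
        .mode("overwrite")                        # delete existing Parquet files, then upload
        .option("token", os.environ["HF_TOKEN"])  # required HuggingFace API token
        .option("split", "train")                 # only train/test/validation are supported
        .save("user/dataset")                     # {username}/{dataset} repo id
    )

    # Reading back uses the package's existing "huggingface" read format, as in the tests above.
    result = spark.read.format("huggingface").option("token", os.environ["HF_TOKEN"]).load("user/dataset")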