Skip to content
9 changes: 9 additions & 0 deletions pyiceberg/avro/file.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,6 +228,7 @@ class AvroOutputFile(Generic[D]):
encoder: BinaryEncoder
sync_bytes: bytes
writer: Writer
closed: bool

def __init__(
self,
Expand All @@ -247,6 +248,7 @@ def __init__(
else resolve_writer(record_schema=record_schema, file_schema=self.file_schema)
)
self.metadata = metadata
self.closed = False

def __enter__(self) -> AvroOutputFile[D]:
"""
Expand All @@ -267,6 +269,7 @@ def __exit__(
) -> None:
"""Perform cleanup when exiting the scope of a 'with' statement."""
self.output_stream.close()
self.closed = True

def _write_header(self) -> None:
json_schema = json.dumps(AvroSchemaConversion().iceberg_to_avro(self.file_schema, schema_name=self.schema_name))
Expand All @@ -285,3 +288,9 @@ def write_block(self, objects: List[D]) -> None:
self.encoder.write_int(len(block_content))
self.encoder.write(block_content)
self.encoder.write(self.sync_bytes)

def __len__(self) -> int:
    """Return the total number of bytes written.

    While the file is still open the size is taken from the output stream's
    current position; once closed it is taken from the length of the output file.
    """
    if not self.closed:
        # Stream is still open: its position is the number of bytes written so far.
        return self.output_stream.tell()
    return len(self.output_file)
4 changes: 4 additions & 0 deletions pyiceberg/io/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,10 @@ def __exit__(
) -> None:
"""Perform cleanup when exiting the scope of a 'with' statement."""

@abstractmethod
def tell(self) -> int:
"""Return the total number of bytes written to the stream."""


class InputFile(ABC):
"""A base class for InputFile implementations.
Expand Down
75 changes: 75 additions & 0 deletions pyiceberg/manifest.py
Original file line number Diff line number Diff line change
Expand Up @@ -840,6 +840,81 @@ def existing(self, entry: ManifestEntry) -> ManifestWriter:
return self


def __len__(self) -> int:
"""Return the total number of bytes written."""
# Delegates to the length reported by the underlying writer object.
return len(self._writer)


class RollingManifestWriter:
    """Spread manifest entries over several manifest files.

    Writers are pulled lazily from ``supplier``. Before each entry is added,
    the current writer is closed and replaced by the next one from the supplier
    when either the number of rows recorded for the current file or the current
    writer's length in bytes has reached its target. Use it as a context
    manager; the produced manifest files are available from
    ``to_manifest_files`` once the ``with`` block has exited.
    """

    _closed: bool
    _supplier: Generator[ManifestWriter, None, None]
    _manifest_files: list[ManifestFile]
    _target_file_size_in_bytes: int
    _target_number_of_rows: int
    _current_writer: Optional[ManifestWriter]
    _current_file_rows: int

    def __init__(
        self, supplier: Generator[ManifestWriter, None, None], target_file_size_in_bytes: int, target_number_of_rows: int
    ) -> None:
        """Create a rolling writer.

        Args:
            supplier: Generator yielding a fresh, not yet entered writer for each new file.
            target_file_size_in_bytes: Roll once the current writer's length reaches this value.
            target_number_of_rows: Roll once the rows recorded for the current file reach this value.
        """
        self._closed = False
        self._manifest_files = []
        self._supplier = supplier
        self._target_file_size_in_bytes = target_file_size_in_bytes
        self._target_number_of_rows = target_number_of_rows
        self._current_writer = None
        self._current_file_rows = 0

    def __enter__(self) -> RollingManifestWriter:
        """Open the first underlying writer and return this rolling writer."""
        # _get_current_writer() already enters the writer it creates; entering
        # it a second time here would open the same writer twice.
        self._get_current_writer()
        return self

    def __exit__(
        self,
        exc_type: Optional[Type[BaseException]],
        exc_value: Optional[BaseException],
        traceback: Optional[TracebackType],
    ) -> None:
        """Close the current writer, if any, and mark this rolling writer as closed."""
        self._close_current_writer()
        self._closed = True

    def _get_current_writer(self) -> ManifestWriter:
        """Return the writer to use for the next entry, rolling to a new one when needed."""
        if self._should_roll_to_new_file():
            self._close_current_writer()
        # Compare against None explicitly: the writer defines __len__, so its
        # truthiness depends on how many bytes it reports, not on its presence.
        if self._current_writer is None:
            self._current_writer = next(self._supplier)
            self._current_writer.__enter__()
        return self._current_writer

    def _should_roll_to_new_file(self) -> bool:
        """Return True when the current writer has reached the row or byte target."""
        if self._current_writer is None:
            return False
        return (
            self._current_file_rows >= self._target_number_of_rows
            or len(self._current_writer) >= self._target_file_size_in_bytes
        )

    def _close_current_writer(self) -> None:
        """Exit the current writer, record its manifest file and reset the per-file state."""
        if self._current_writer is not None:
            self._current_writer.__exit__(None, None, None)
            current_file = self._current_writer.to_manifest_file()
            self._manifest_files.append(current_file)
            self._current_writer = None
            self._current_file_rows = 0

    def to_manifest_files(self) -> list[ManifestFile]:
        """Return the manifest files collected so far.

        Raises:
            RuntimeError: If this rolling writer has not been closed yet.
        """
        if not self._closed:
            raise RuntimeError("Cannot create manifest files from unclosed writer")
        return self._manifest_files

    def add_entry(self, entry: ManifestEntry) -> RollingManifestWriter:
        """Add an entry to the current writer and return this rolling writer.

        Raises:
            RuntimeError: If this rolling writer has already been closed.
        """
        if self._closed:
            raise RuntimeError("Cannot add entry to closed manifest writer")
        self._get_current_writer().add_entry(entry)
        # Rows are counted after the write, so the roll check runs before the next entry.
        self._current_file_rows += entry.data_file.record_count
        return self


class ManifestWriterV1(ManifestWriter):
def __init__(self, spec: PartitionSpec, schema: Schema, output_file: OutputFile, snapshot_id: int):
super().__init__(
Expand Down
75 changes: 74 additions & 1 deletion tests/utils/test_manifest.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
# under the License.
# pylint: disable=redefined-outer-name,arguments-renamed,fixme
from tempfile import TemporaryDirectory
from typing import Dict
from typing import Dict, Generator

import fastavro
import pytest
Expand All @@ -30,7 +30,9 @@
ManifestContent,
ManifestEntryStatus,
ManifestFile,
ManifestWriter,
PartitionFieldSummary,
RollingManifestWriter,
read_manifest_list,
write_manifest,
write_manifest_list,
Expand Down Expand Up @@ -493,6 +495,75 @@ def test_write_manifest(
assert data_file.sort_order_id == 0


@pytest.mark.parametrize("format_version", [1, 2])
@pytest.mark.parametrize(
    "target_number_of_rows,target_file_size_in_bytes,expected_number_of_files",
    [
        (19514, 388873, 1),  # should not roll over
        (19513, 388873, 2),  # should roll over due to target_rows
        # NOTE(review): this case used to be listed twice, once labelled as a
        # target_bytes-only roll-over. It lowers target_number_of_rows as well,
        # so it does not isolate the byte threshold. A bytes-only case (row target
        # kept at 19514) would be worth adding - TODO confirm the matching size.
        (4000, 388872, 2),  # both targets lowered; should roll over
    ],
)
def test_rolling_manifest_writer(
    generated_manifest_file_file_v1: str,
    generated_manifest_file_file_v2: str,
    format_version: TableVersion,
    target_number_of_rows: int,
    target_file_size_in_bytes: int,
    expected_number_of_files: int,
) -> None:
    """Check that RollingManifestWriter splits entries over the expected number of files.

    Entries are read from a generated manifest list, written back through a
    rolling writer configured with the parametrized targets, and the number of
    resulting manifest files is compared with the expectation. Adding an entry
    after the writer has been closed must raise RuntimeError.
    """
    io = load_file_io()
    snapshot = Snapshot(
        snapshot_id=25,
        parent_snapshot_id=19,
        timestamp_ms=1602638573590,
        manifest_list=generated_manifest_file_file_v1 if format_version == 1 else generated_manifest_file_file_v2,
        summary=Summary(Operation.APPEND),
        schema_id=3,
    )
    demo_manifest_file = snapshot.manifests(io)[0]
    manifest_entries = demo_manifest_file.fetch_manifest_entry(io)
    test_schema = Schema(
        NestedField(1, "VendorID", IntegerType(), False), NestedField(2, "tpep_pickup_datetime", IntegerType(), False)
    )
    test_spec = PartitionSpec(
        PartitionField(source_id=1, field_id=1, transform=IdentityTransform(), name="VendorID"),
        PartitionField(source_id=2, field_id=2, transform=IdentityTransform(), name="tpep_pickup_datetime"),
        spec_id=demo_manifest_file.partition_spec_id,
    )

    with TemporaryDirectory() as tmpdir:

        def supplier() -> Generator[ManifestWriter, None, None]:
            # Yield one writer per output file, each with a distinct file name.
            i = 0
            while True:
                tmp_avro_file = f"{tmpdir}/test_write_manifest-{i}.avro"
                output = io.new_output(tmp_avro_file)
                yield write_manifest(
                    format_version=format_version,
                    spec=test_spec,
                    schema=test_schema,
                    output_file=output,
                    snapshot_id=8744736658442914487,
                )
                i += 1

        with RollingManifestWriter(
            supplier=supplier(),
            target_file_size_in_bytes=target_file_size_in_bytes,
            target_number_of_rows=target_number_of_rows,
        ) as writer:
            for entry in manifest_entries:
                writer.add_entry(entry)

        # The manifest files can only be requested once the writer is closed.
        manifest_files = writer.to_manifest_files()
        assert len(manifest_files) == expected_number_of_files
        with pytest.raises(RuntimeError):
            # It is already closed
            writer.add_entry(manifest_entries[0])


@pytest.mark.parametrize("format_version", [1, 2])
def test_write_manifest_list(
generated_manifest_file_file_v1: str, generated_manifest_file_file_v2: str, format_version: TableVersion
Expand Down Expand Up @@ -560,3 +631,5 @@ def test_write_manifest_list(
assert entry.file_sequence_number == 0 if format_version == 1 else 3
assert entry.snapshot_id == 8744736658442914487
assert entry.status == ManifestEntryStatus.ADDED