Skip to content

Commit 8c5f085

Browse files
Fix setting V1 format version for Non-REST catalogs
1 parent a576fc9 commit 8c5f085

File tree

3 files changed

+173
-4
lines changed

3 files changed

+173
-4
lines changed

pyiceberg/table/metadata.py

Lines changed: 52 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -260,8 +260,10 @@ def set_v2_compatible_defaults(cls, data: Dict[str, Any]) -> Dict[str, Any]:
260260
The TableMetadata with the defaults applied.
261261
"""
262262
# When the schema doesn't have an ID
263-
if data.get("schema") and "schema_id" not in data["schema"]:
264-
data["schema"]["schema_id"] = DEFAULT_SCHEMA_ID
263+
schema = data.get("schema")
264+
if isinstance(schema, dict):
265+
if "schema_id" not in schema:
266+
schema["schema_id"] = DEFAULT_SCHEMA_ID
265267

266268
return data
267269

@@ -313,6 +315,34 @@ def construct_partition_specs(cls, data: Dict[str, Any]) -> Dict[str, Any]:
313315

314316
return data
315317

318+
@model_validator(mode="before")
319+
def construct_v1_spec_from_v2_fields(cls, data: Dict[str, Any]) -> Dict[str, Any]:
320+
specs_field = "partition_specs"
321+
default_spec_id_field = "default_spec_id"
322+
if specs_field in data and default_spec_id_field in data:
323+
specs = data[specs_field]
324+
spec_id = data[default_spec_id_field]
325+
for spec in specs:
326+
if spec.spec_id == spec_id:
327+
data["partition_spec"] = [spec.model_dump()]
328+
return data
329+
330+
return data
331+
332+
@model_validator(mode="before")
333+
def construct_v1_schema_from_v2_fields(cls, data: Dict[str, Any]) -> Dict[str, Any]:
334+
schemas_field = "schemas"
335+
current_schema_id_field = "current_schema_id"
336+
if schemas_field in data and current_schema_id_field in data:
337+
schemas = data[schemas_field]
338+
current_schema_id = data[current_schema_id_field]
339+
for schema in schemas:
340+
if schema.schema_id == current_schema_id:
341+
data["schema"] = schema
342+
return data
343+
344+
return data
345+
316346
@model_validator(mode="before")
317347
def set_sort_orders(cls, data: Dict[str, Any]) -> Dict[str, Any]:
318348
"""Set the sort_orders if not provided.
@@ -335,7 +365,7 @@ def to_v2(self) -> TableMetadataV2:
335365
metadata["format-version"] = 2
336366
return TableMetadataV2.model_validate(metadata)
337367

338-
format_version: Literal[1] = Field(alias="format-version")
368+
format_version: Literal[1] = Field(alias="format-version", default=1)
339369
"""An integer version number for the format. Currently, this can be 1 or 2
340370
based on the spec. Implementations must throw an exception if a table’s
341371
version is higher than the supported version."""
@@ -394,6 +424,7 @@ def construct_refs(cls, table_metadata: TableMetadata) -> TableMetadata:
394424

395425

396426
TableMetadata = Annotated[Union[TableMetadataV1, TableMetadataV2], Field(discriminator="format_version")]
427+
DEFAULT_FORMAT_VERSION = "2"
397428

398429

399430
def new_table_metadata(
@@ -411,6 +442,24 @@ def new_table_metadata(
411442
if table_uuid is None:
412443
table_uuid = uuid.uuid4()
413444

445+
# Remove format-version so it does not get persisted
446+
format_version = properties.pop("format-version", DEFAULT_FORMAT_VERSION)
447+
448+
if int(format_version) == 1:
449+
return TableMetadataV1(
450+
location=location,
451+
schema=fresh_schema,
452+
last_column_id=fresh_schema.highest_field_id,
453+
current_schema_id=fresh_schema.schema_id,
454+
partition_specs=[fresh_partition_spec],
455+
default_spec_id=fresh_partition_spec.spec_id,
456+
sort_orders=[fresh_sort_order],
457+
default_sort_order_id=fresh_sort_order.order_id,
458+
properties=properties,
459+
last_partition_id=fresh_partition_spec.last_assigned_field_id,
460+
table_uuid=table_uuid,
461+
)
462+
414463
return TableMetadataV2(
415464
location=location,
416465
schemas=[fresh_schema],

tests/catalog/test_hive.py

Lines changed: 55 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@
4343
)
4444
from pyiceberg.partitioning import PartitionField, PartitionSpec
4545
from pyiceberg.schema import Schema
46-
from pyiceberg.table.metadata import TableMetadataUtil, TableMetadataV2
46+
from pyiceberg.table.metadata import TableMetadataUtil, TableMetadataV1, TableMetadataV2
4747
from pyiceberg.table.refs import SnapshotRef, SnapshotRefType
4848
from pyiceberg.table.snapshots import (
4949
MetadataLogEntry,
@@ -294,6 +294,60 @@ def test_create_table(table_schema_simple: Schema, hive_database: HiveDatabase,
294294

295295
assert metadata.model_dump() == expected.model_dump()
296296

297+
catalog.create_table(
298+
("default", "table_v1"), schema=table_schema_simple, properties={"owner": "javaberg", "format-version": "1"}
299+
)
300+
301+
302+
@patch("time.time", MagicMock(return_value=12345))
303+
def test_create_v1_table(table_schema_simple: Schema, hive_database: HiveDatabase, hive_table: HiveTable) -> None:
304+
catalog = HiveCatalog(HIVE_CATALOG_NAME, uri=HIVE_METASTORE_FAKE_URL)
305+
306+
catalog._client = MagicMock()
307+
catalog._client.__enter__().create_table.return_value = None
308+
catalog._client.__enter__().get_table.return_value = hive_table
309+
catalog._client.__enter__().get_database.return_value = hive_database
310+
catalog.create_table(
311+
("default", "table"), schema=table_schema_simple, properties={"owner": "javaberg", "format-version": "1"}
312+
)
313+
314+
# Test creating V1 table
315+
called_v1_table: HiveTable = catalog._client.__enter__().create_table.call_args[0][0]
316+
metadata_location = called_v1_table.parameters["metadata_location"]
317+
with open(metadata_location, encoding=UTF8) as f:
318+
payload = f.read()
319+
320+
actual_v1_metadata = TableMetadataUtil.parse_raw(payload)
321+
expected_v1_metadata = TableMetadataV1(
322+
location=actual_v1_metadata.location,
323+
table_uuid=actual_v1_metadata.table_uuid,
324+
last_updated_ms=actual_v1_metadata.last_updated_ms,
325+
last_column_id=3,
326+
schemas=[
327+
Schema(
328+
NestedField(field_id=1, name="foo", field_type=StringType(), required=False),
329+
NestedField(field_id=2, name="bar", field_type=IntegerType(), required=True),
330+
NestedField(field_id=3, name="baz", field_type=BooleanType(), required=False),
331+
schema_id=0,
332+
identifier_field_ids=[2],
333+
)
334+
],
335+
current_schema_id=0,
336+
last_partition_id=1000,
337+
properties={"owner": "javaberg", "write.parquet.compression-codec": "zstd"},
338+
partition_spec=[{'fields': [], 'spec-id': 0}],
339+
current_snapshot_id=None,
340+
snapshots=[],
341+
snapshot_log=[],
342+
metadata_log=[],
343+
sort_orders=[SortOrder(order_id=0)],
344+
default_sort_order_id=0,
345+
refs={},
346+
format_version=1,
347+
)
348+
349+
assert actual_v1_metadata.model_dump() == expected_v1_metadata.model_dump()
350+
297351

298352
def test_load_table(hive_table: HiveTable) -> None:
299353
catalog = HiveCatalog(HIVE_CATALOG_NAME, uri=HIVE_METASTORE_FAKE_URL)

tests/table/test_metadata.py

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -199,6 +199,72 @@ def test_migrate_v1_partition_specs(example_table_metadata_v1: Dict[str, Any]) -
199199
]
200200

201201

202+
def test_new_table_metadata_with_explicit_v1_format() -> None:
203+
schema = Schema(
204+
NestedField(field_id=10, name="foo", field_type=StringType(), required=False),
205+
NestedField(field_id=22, name="bar", field_type=IntegerType(), required=True),
206+
NestedField(field_id=33, name="baz", field_type=BooleanType(), required=False),
207+
schema_id=10,
208+
identifier_field_ids=[22],
209+
)
210+
211+
partition_spec = PartitionSpec(
212+
PartitionField(source_id=22, field_id=1022, transform=IdentityTransform(), name="bar"), spec_id=10
213+
)
214+
215+
sort_order = SortOrder(
216+
SortField(source_id=10, transform=IdentityTransform(), direction=SortDirection.ASC, null_order=NullOrder.NULLS_LAST),
217+
order_id=10,
218+
)
219+
220+
actual = new_table_metadata(
221+
schema=schema,
222+
partition_spec=partition_spec,
223+
sort_order=sort_order,
224+
location="s3://some_v1_location/",
225+
properties={'format-version': "1"},
226+
)
227+
228+
expected = TableMetadataV1(
229+
location="s3://some_v1_location/",
230+
table_uuid=actual.table_uuid,
231+
last_updated_ms=actual.last_updated_ms,
232+
last_column_id=3,
233+
schemas=[
234+
Schema(
235+
NestedField(field_id=1, name="foo", field_type=StringType(), required=False),
236+
NestedField(field_id=2, name="bar", field_type=IntegerType(), required=True),
237+
NestedField(field_id=3, name="baz", field_type=BooleanType(), required=False),
238+
schema_id=0,
239+
identifier_field_ids=[2],
240+
)
241+
],
242+
current_schema_id=0,
243+
partition_specs=[PartitionSpec(PartitionField(source_id=2, field_id=1000, transform=IdentityTransform(), name="bar"))],
244+
default_spec_id=0,
245+
last_partition_id=1000,
246+
properties={},
247+
current_snapshot_id=None,
248+
snapshots=[],
249+
snapshot_log=[],
250+
metadata_log=[],
251+
sort_orders=[
252+
SortOrder(
253+
SortField(
254+
source_id=1, transform=IdentityTransform(), direction=SortDirection.ASC, null_order=NullOrder.NULLS_LAST
255+
),
256+
order_id=1,
257+
)
258+
],
259+
default_sort_order_id=1,
260+
refs={},
261+
format_version=1,
262+
last_sequence_number=0,
263+
)
264+
265+
assert actual.model_dump() == expected.model_dump()
266+
267+
202268
def test_invalid_format_version(example_table_metadata_v1: Dict[str, Any]) -> None:
203269
"""Test the exception when trying to load an unknown version"""
204270

0 commit comments

Comments
 (0)