Skip to content

Commit 2354919

Browse files
Fix setting V1 format version for Non-REST catalogs
1 parent a576fc9 commit 2354919

File tree

5 files changed

+220
-4
lines changed

5 files changed

+220
-4
lines changed

pyiceberg/table/metadata.py

Lines changed: 52 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -260,8 +260,10 @@ def set_v2_compatible_defaults(cls, data: Dict[str, Any]) -> Dict[str, Any]:
260260
The TableMetadata with the defaults applied.
261261
"""
262262
# When the schema doesn't have an ID
263-
if data.get("schema") and "schema_id" not in data["schema"]:
264-
data["schema"]["schema_id"] = DEFAULT_SCHEMA_ID
263+
schema = data.get("schema")
264+
if isinstance(schema, dict):
265+
if "schema_id" not in schema:
266+
schema["schema_id"] = DEFAULT_SCHEMA_ID
265267

266268
return data
267269

@@ -313,6 +315,34 @@ def construct_partition_specs(cls, data: Dict[str, Any]) -> Dict[str, Any]:
313315

314316
return data
315317

318+
@model_validator(mode="before")
319+
def construct_v1_spec_from_v2_fields(cls, data: Dict[str, Any]) -> Dict[str, Any]:
320+
specs_field = "partition_specs"
321+
default_spec_id_field = "default_spec_id"
322+
if specs_field in data and default_spec_id_field in data:
323+
found_spec = next((spec for spec in data[specs_field] if spec.spec_id == data[default_spec_id_field]), None)
324+
if found_spec is not None:
325+
spec_dict = found_spec.model_dump()
326+
spec_dict['fields'] = list(spec_dict['fields'])
327+
data["partition_spec"] = [spec_dict]
328+
return data
329+
330+
return data
331+
332+
@model_validator(mode="before")
333+
def construct_v1_schema_from_v2_fields(cls, data: Dict[str, Any]) -> Dict[str, Any]:
334+
schemas_field = "schemas"
335+
current_schema_id_field = "current_schema_id"
336+
if schemas_field in data and current_schema_id_field in data:
337+
found_schema = next(
338+
(schema for schema in data[schemas_field] if schema.schema_id == data[current_schema_id_field]), None
339+
)
340+
if found_schema is not None:
341+
data["schema"] = found_schema
342+
return data
343+
344+
return data
345+
316346
@model_validator(mode="before")
317347
def set_sort_orders(cls, data: Dict[str, Any]) -> Dict[str, Any]:
318348
"""Set the sort_orders if not provided.
@@ -335,7 +365,7 @@ def to_v2(self) -> TableMetadataV2:
335365
metadata["format-version"] = 2
336366
return TableMetadataV2.model_validate(metadata)
337367

338-
format_version: Literal[1] = Field(alias="format-version")
368+
format_version: Literal[1] = Field(alias="format-version", default=1)
339369
"""An integer version number for the format. Currently, this can be 1 or 2
340370
based on the spec. Implementations must throw an exception if a table’s
341371
version is higher than the supported version."""
@@ -394,6 +424,7 @@ def construct_refs(cls, table_metadata: TableMetadata) -> TableMetadata:
394424

395425

396426
TableMetadata = Annotated[Union[TableMetadataV1, TableMetadataV2], Field(discriminator="format_version")]
427+
DEFAULT_FORMAT_VERSION = "2"
397428

398429

399430
def new_table_metadata(
@@ -411,6 +442,24 @@ def new_table_metadata(
411442
if table_uuid is None:
412443
table_uuid = uuid.uuid4()
413444

445+
# Remove format-version so it does not get persisted
446+
format_version = int(properties.pop("format-version", DEFAULT_FORMAT_VERSION))
447+
448+
if format_version == 1:
449+
return TableMetadataV1(
450+
location=location,
451+
schema=fresh_schema,
452+
last_column_id=fresh_schema.highest_field_id,
453+
current_schema_id=fresh_schema.schema_id,
454+
partition_specs=[fresh_partition_spec],
455+
default_spec_id=fresh_partition_spec.spec_id,
456+
sort_orders=[fresh_sort_order],
457+
default_sort_order_id=fresh_sort_order.order_id,
458+
properties=properties,
459+
last_partition_id=fresh_partition_spec.last_assigned_field_id,
460+
table_uuid=table_uuid,
461+
)
462+
414463
return TableMetadataV2(
415464
location=location,
416465
schemas=[fresh_schema],

tests/catalog/test_glue.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,38 @@ def test_create_table_with_database_location(
7272
assert storage_descriptor["Location"] == f"s3://{BUCKET_NAME}/{database_name}.db/{table_name}"
7373

7474

75+
@mock_aws
76+
def test_create_v1_table(
77+
_bucket_initialize: None,
78+
_glue: boto3.client,
79+
moto_endpoint_url: str,
80+
table_schema_nested: Schema,
81+
database_name: str,
82+
table_name: str,
83+
) -> None:
84+
catalog_name = "glue"
85+
test_catalog = GlueCatalog(catalog_name, **{"s3.endpoint": moto_endpoint_url})
86+
test_catalog.create_namespace(namespace=database_name, properties={"location": f"s3://{BUCKET_NAME}/{database_name}.db"})
87+
table = test_catalog.create_table((database_name, table_name), table_schema_nested, properties={"format-version": "1"})
88+
assert table.format_version == 1
89+
90+
table_info = _glue.get_table(
91+
DatabaseName=database_name,
92+
Name=table_name,
93+
)
94+
95+
storage_descriptor = table_info["Table"]["StorageDescriptor"]
96+
columns = storage_descriptor["Columns"]
97+
assert len(columns) == len(table_schema_nested.fields)
98+
assert columns[0] == {
99+
"Name": "foo",
100+
"Type": "string",
101+
"Parameters": {"iceberg.field.id": "1", "iceberg.field.optional": "true", "iceberg.field.current": "true"},
102+
}
103+
104+
assert storage_descriptor["Location"] == f"s3://{BUCKET_NAME}/{database_name}.db/{table_name}"
105+
106+
75107
@mock_aws
76108
def test_create_table_with_default_warehouse(
77109
_bucket_initialize: None, moto_endpoint_url: str, table_schema_nested: Schema, database_name: str, table_name: str

tests/catalog/test_hive.py

Lines changed: 52 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@
4343
)
4444
from pyiceberg.partitioning import PartitionField, PartitionSpec
4545
from pyiceberg.schema import Schema
46-
from pyiceberg.table.metadata import TableMetadataUtil, TableMetadataV2
46+
from pyiceberg.table.metadata import TableMetadataUtil, TableMetadataV1, TableMetadataV2
4747
from pyiceberg.table.refs import SnapshotRef, SnapshotRefType
4848
from pyiceberg.table.snapshots import (
4949
MetadataLogEntry,
@@ -295,6 +295,57 @@ def test_create_table(table_schema_simple: Schema, hive_database: HiveDatabase,
295295
assert metadata.model_dump() == expected.model_dump()
296296

297297

298+
@patch("time.time", MagicMock(return_value=12345))
299+
def test_create_v1_table(table_schema_simple: Schema, hive_database: HiveDatabase, hive_table: HiveTable) -> None:
300+
catalog = HiveCatalog(HIVE_CATALOG_NAME, uri=HIVE_METASTORE_FAKE_URL)
301+
302+
catalog._client = MagicMock()
303+
catalog._client.__enter__().create_table.return_value = None
304+
catalog._client.__enter__().get_table.return_value = hive_table
305+
catalog._client.__enter__().get_database.return_value = hive_database
306+
catalog.create_table(
307+
("default", "table"), schema=table_schema_simple, properties={"owner": "javaberg", "format-version": "1"}
308+
)
309+
310+
# Test creating V1 table
311+
called_v1_table: HiveTable = catalog._client.__enter__().create_table.call_args[0][0]
312+
metadata_location = called_v1_table.parameters["metadata_location"]
313+
with open(metadata_location, encoding=UTF8) as f:
314+
payload = f.read()
315+
316+
actual_v1_metadata = TableMetadataUtil.parse_raw(payload)
317+
expected_v1_metadata = TableMetadataV1(
318+
location=actual_v1_metadata.location,
319+
table_uuid=actual_v1_metadata.table_uuid,
320+
last_updated_ms=actual_v1_metadata.last_updated_ms,
321+
last_column_id=3,
322+
schemas=[
323+
Schema(
324+
NestedField(field_id=1, name="foo", field_type=StringType(), required=False),
325+
NestedField(field_id=2, name="bar", field_type=IntegerType(), required=True),
326+
NestedField(field_id=3, name="baz", field_type=BooleanType(), required=False),
327+
schema_id=0,
328+
identifier_field_ids=[2],
329+
)
330+
],
331+
current_schema_id=0,
332+
last_partition_id=1000,
333+
properties={"owner": "javaberg", "write.parquet.compression-codec": "zstd"},
334+
partition_specs=[PartitionSpec()],
335+
default_spec_id=0,
336+
current_snapshot_id=None,
337+
snapshots=[],
338+
snapshot_log=[],
339+
metadata_log=[],
340+
sort_orders=[SortOrder(order_id=0)],
341+
default_sort_order_id=0,
342+
refs={},
343+
format_version=1,
344+
)
345+
346+
assert actual_v1_metadata.model_dump() == expected_v1_metadata.model_dump()
347+
348+
298349
def test_load_table(hive_table: HiveTable) -> None:
299350
catalog = HiveCatalog(HIVE_CATALOG_NAME, uri=HIVE_METASTORE_FAKE_URL)
300351

tests/catalog/test_sql.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@
3838
)
3939
from pyiceberg.io import FSSPEC_FILE_IO, PY_IO_IMPL
4040
from pyiceberg.io.pyarrow import schema_to_pyarrow
41+
from pyiceberg.partitioning import UNPARTITIONED_PARTITION_SPEC
4142
from pyiceberg.schema import Schema
4243
from pyiceberg.table.snapshots import Operation
4344
from pyiceberg.table.sorting import (
@@ -158,6 +159,24 @@ def test_create_table_default_sort_order(catalog: SqlCatalog, table_schema_neste
158159
catalog.drop_table(random_identifier)
159160

160161

162+
@pytest.mark.parametrize(
163+
'catalog',
164+
[
165+
lazy_fixture('catalog_memory'),
166+
lazy_fixture('catalog_sqlite'),
167+
],
168+
)
169+
def test_create_v1_table(catalog: SqlCatalog, table_schema_nested: Schema, random_identifier: Identifier) -> None:
170+
database_name, _table_name = random_identifier
171+
catalog.create_namespace(database_name)
172+
table = catalog.create_table(random_identifier, table_schema_nested, properties={"format-version": "1"})
173+
assert table.sort_order().order_id == 0, "Order ID must match"
174+
assert table.sort_order().is_unsorted is True, "Order must be unsorted"
175+
assert table.format_version == 1
176+
assert table.spec() == UNPARTITIONED_PARTITION_SPEC
177+
catalog.drop_table(random_identifier)
178+
179+
161180
@pytest.mark.parametrize(
162181
'catalog',
163182
[

tests/table/test_metadata.py

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -199,6 +199,71 @@ def test_migrate_v1_partition_specs(example_table_metadata_v1: Dict[str, Any]) -
199199
]
200200

201201

202+
def test_new_table_metadata_with_explicit_v1_format() -> None:
203+
schema = Schema(
204+
NestedField(field_id=10, name="foo", field_type=StringType(), required=False),
205+
NestedField(field_id=22, name="bar", field_type=IntegerType(), required=True),
206+
NestedField(field_id=33, name="baz", field_type=BooleanType(), required=False),
207+
schema_id=10,
208+
identifier_field_ids=[22],
209+
)
210+
211+
partition_spec = PartitionSpec(
212+
PartitionField(source_id=22, field_id=1022, transform=IdentityTransform(), name="bar"), spec_id=10
213+
)
214+
215+
sort_order = SortOrder(
216+
SortField(source_id=10, transform=IdentityTransform(), direction=SortDirection.ASC, null_order=NullOrder.NULLS_LAST),
217+
order_id=10,
218+
)
219+
220+
actual = new_table_metadata(
221+
schema=schema,
222+
partition_spec=partition_spec,
223+
sort_order=sort_order,
224+
location="s3://some_v1_location/",
225+
properties={'format-version': "1"},
226+
)
227+
228+
expected = TableMetadataV1(
229+
location="s3://some_v1_location/",
230+
table_uuid=actual.table_uuid,
231+
last_updated_ms=actual.last_updated_ms,
232+
last_column_id=3,
233+
schemas=[
234+
Schema(
235+
NestedField(field_id=1, name="foo", field_type=StringType(), required=False),
236+
NestedField(field_id=2, name="bar", field_type=IntegerType(), required=True),
237+
NestedField(field_id=3, name="baz", field_type=BooleanType(), required=False),
238+
schema_id=0,
239+
identifier_field_ids=[2],
240+
)
241+
],
242+
current_schema_id=0,
243+
partition_specs=[PartitionSpec(PartitionField(source_id=2, field_id=1000, transform=IdentityTransform(), name="bar"))],
244+
default_spec_id=0,
245+
last_partition_id=1000,
246+
properties={},
247+
current_snapshot_id=None,
248+
snapshots=[],
249+
snapshot_log=[],
250+
metadata_log=[],
251+
sort_orders=[
252+
SortOrder(
253+
SortField(
254+
source_id=1, transform=IdentityTransform(), direction=SortDirection.ASC, null_order=NullOrder.NULLS_LAST
255+
),
256+
order_id=1,
257+
)
258+
],
259+
default_sort_order_id=1,
260+
refs={},
261+
format_version=1,
262+
)
263+
264+
assert actual.model_dump() == expected.model_dump()
265+
266+
202267
def test_invalid_format_version(example_table_metadata_v1: Dict[str, Any]) -> None:
203268
"""Test the exception when trying to load an unknown version"""
204269

0 commit comments

Comments
 (0)