Skip to content

Commit fd9dc88

Browse files
sungwyFokko
andauthored
Update NameMapping on update_schema() (#441)
* update name-mapping * Update __init__.py Co-authored-by: Fokko Driesprong <[email protected]> * Update pyiceberg/table/name_mapping.py Co-authored-by: Fokko Driesprong <[email protected]> * validation mode after * type --------- Co-authored-by: Fokko Driesprong <[email protected]>
1 parent b32d3a5 commit fd9dc88

File tree

4 files changed

+188
-19
lines changed

4 files changed

+188
-19
lines changed

pyiceberg/table/__init__.py

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -84,11 +84,7 @@
8484
TableMetadata,
8585
TableMetadataUtil,
8686
)
87-
from pyiceberg.table.name_mapping import (
88-
NameMapping,
89-
create_mapping_from_schema,
90-
parse_mapping_from_json,
91-
)
87+
from pyiceberg.table.name_mapping import NameMapping, parse_mapping_from_json, update_mapping
9288
from pyiceberg.table.refs import MAIN_BRANCH, SnapshotRef
9389
from pyiceberg.table.snapshots import (
9490
Operation,
@@ -994,12 +990,12 @@ def update_snapshot(self) -> UpdateSnapshot:
994990
"""
995991
return UpdateSnapshot(self)
996992

997-
def name_mapping(self) -> NameMapping:
993+
def name_mapping(self) -> Optional[NameMapping]:
998994
"""Return the table's field-id NameMapping."""
999995
if name_mapping_json := self.properties.get(TableProperties.DEFAULT_NAME_MAPPING):
1000996
return parse_mapping_from_json(name_mapping_json)
1001997
else:
1002-
return create_mapping_from_schema(self.schema())
998+
return None
1003999

10041000
def append(self, df: pa.Table) -> None:
10051001
"""
@@ -1950,6 +1946,12 @@ def commit(self) -> None:
19501946
else:
19511947
updates = (SetCurrentSchemaUpdate(schema_id=existing_schema_id),) # type: ignore
19521948

1949+
if name_mapping := self._table.name_mapping():
1950+
updated_name_mapping = update_mapping(name_mapping, self._updates, self._adds)
1951+
updates += ( # type: ignore
1952+
SetPropertiesUpdate(updates={TableProperties.DEFAULT_NAME_MAPPING: updated_name_mapping.model_dump_json()}),
1953+
)
1954+
19531955
if self._transaction is not None:
19541956
self._transaction._append_updates(*updates) # pylint: disable=W0212
19551957
self._transaction._append_requirements(*requirements) # pylint: disable=W0212

pyiceberg/table/name_mapping.py

Lines changed: 88 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626
from abc import ABC, abstractmethod
2727
from collections import ChainMap
2828
from functools import cached_property, singledispatch
29-
from typing import Any, Dict, Generic, List, TypeVar, Union
29+
from typing import Any, Dict, Generic, List, Optional, TypeVar, Union
3030

3131
from pydantic import Field, conlist, field_validator, model_serializer
3232

@@ -45,6 +45,18 @@ class MappedField(IcebergBaseModel):
4545
def convert_null_to_empty_List(cls, v: Any) -> Any:
4646
return v or []
4747

48+
@field_validator('names', mode='after')
49+
@classmethod
50+
def check_at_least_one(cls, v: List[str]) -> Any:
51+
"""
52+
Conlist constraint does not seem to be validating the class on instantiation.
53+
54+
Adding a custom validator to enforce min_length=1 constraint.
55+
"""
56+
if len(v) < 1:
57+
raise ValueError("At least one mapped name must be provided for the field")
58+
return v
59+
4860
@model_serializer
4961
def ser_model(self) -> Dict[str, Any]:
5062
"""Set custom serializer to leave out the field when it is empty."""
@@ -93,24 +105,25 @@ def __str__(self) -> str:
93105
return "[\n " + "\n ".join([str(e) for e in self.root]) + "\n]"
94106

95107

108+
S = TypeVar('S')
96109
T = TypeVar("T")
97110

98111

99-
class NameMappingVisitor(Generic[T], ABC):
112+
class NameMappingVisitor(Generic[S, T], ABC):
100113
@abstractmethod
101-
def mapping(self, nm: NameMapping, field_results: T) -> T:
114+
def mapping(self, nm: NameMapping, field_results: S) -> S:
102115
"""Visit a NameMapping."""
103116

104117
@abstractmethod
105-
def fields(self, struct: List[MappedField], field_results: List[T]) -> T:
118+
def fields(self, struct: List[MappedField], field_results: List[T]) -> S:
106119
"""Visit a List[MappedField]."""
107120

108121
@abstractmethod
109-
def field(self, field: MappedField, field_result: T) -> T:
122+
def field(self, field: MappedField, field_result: S) -> T:
110123
"""Visit a MappedField."""
111124

112125

113-
class _IndexByName(NameMappingVisitor[Dict[str, MappedField]]):
126+
class _IndexByName(NameMappingVisitor[Dict[str, MappedField], Dict[str, MappedField]]):
114127
def mapping(self, nm: NameMapping, field_results: Dict[str, MappedField]) -> Dict[str, MappedField]:
115128
return field_results
116129

@@ -129,18 +142,18 @@ def field(self, field: MappedField, field_result: Dict[str, MappedField]) -> Dic
129142

130143

131144
@singledispatch
132-
def visit_name_mapping(obj: Union[NameMapping, List[MappedField], MappedField], visitor: NameMappingVisitor[T]) -> T:
145+
def visit_name_mapping(obj: Union[NameMapping, List[MappedField], MappedField], visitor: NameMappingVisitor[S, T]) -> S:
133146
"""Traverse the name mapping in post-order traversal."""
134147
raise NotImplementedError(f"Cannot visit non-type: {obj}")
135148

136149

137150
@visit_name_mapping.register(NameMapping)
138-
def _(obj: NameMapping, visitor: NameMappingVisitor[T]) -> T:
151+
def _(obj: NameMapping, visitor: NameMappingVisitor[S, T]) -> S:
139152
return visitor.mapping(obj, visit_name_mapping(obj.root, visitor))
140153

141154

142155
@visit_name_mapping.register(list)
143-
def _(fields: List[MappedField], visitor: NameMappingVisitor[T]) -> T:
156+
def _(fields: List[MappedField], visitor: NameMappingVisitor[S, T]) -> S:
144157
results = [visitor.field(field, visit_name_mapping(field.fields, visitor)) for field in fields]
145158
return visitor.fields(fields, results)
146159

@@ -175,5 +188,71 @@ def primitive(self, primitive: PrimitiveType) -> List[MappedField]:
175188
return []
176189

177190

191+
class _UpdateMapping(NameMappingVisitor[List[MappedField], MappedField]):
192+
_updates: Dict[int, NestedField]
193+
_adds: Dict[int, List[NestedField]]
194+
195+
def __init__(self, updates: Dict[int, NestedField], adds: Dict[int, List[NestedField]]):
196+
self._updates = updates
197+
self._adds = adds
198+
199+
@staticmethod
200+
def _remove_reassigned_names(field: MappedField, assignments: Dict[str, int]) -> Optional[MappedField]:
201+
removed_names = set()
202+
for name in field.names:
203+
if (assigned_id := assignments.get(name)) and assigned_id != field.field_id:
204+
removed_names.add(name)
205+
206+
remaining_names = [f for f in field.names if f not in removed_names]
207+
if remaining_names:
208+
return MappedField(field_id=field.field_id, names=remaining_names, fields=field.fields)
209+
else:
210+
return None
211+
212+
def _add_new_fields(self, mapped_fields: List[MappedField], parent_id: int) -> List[MappedField]:
213+
if fields_to_add := self._adds.get(parent_id):
214+
fields: List[MappedField] = []
215+
new_fields: List[MappedField] = []
216+
217+
for add in fields_to_add:
218+
new_fields.append(
219+
MappedField(field_id=add.field_id, names=[add.name], fields=visit(add.field_type, _CreateMapping()))
220+
)
221+
222+
reassignments = {f.name: f.field_id for f in fields_to_add}
223+
fields = [
224+
updated_field
225+
for field in mapped_fields
226+
if (updated_field := self._remove_reassigned_names(field, reassignments)) is not None
227+
] + new_fields
228+
return fields
229+
else:
230+
return mapped_fields
231+
232+
def mapping(self, nm: NameMapping, field_results: List[MappedField]) -> List[MappedField]:
233+
return self._add_new_fields(field_results, -1)
234+
235+
def fields(self, struct: List[MappedField], field_results: List[MappedField]) -> List[MappedField]:
236+
reassignments: Dict[str, int] = {
237+
update.name: update.field_id for f in field_results if (update := self._updates.get(f.field_id))
238+
}
239+
return [
240+
updated_field
241+
for field in field_results
242+
if (updated_field := self._remove_reassigned_names(field, reassignments)) is not None
243+
]
244+
245+
def field(self, field: MappedField, field_result: List[MappedField]) -> MappedField:
246+
field_names = field.names
247+
if (update := self._updates.get(field.field_id)) is not None and update.name not in field_names:
248+
field_names.append(update.name)
249+
250+
return MappedField(field_id=field.field_id, names=field_names, fields=self._add_new_fields(field_result, field.field_id))
251+
252+
178253
def create_mapping_from_schema(schema: Schema) -> NameMapping:
179254
return NameMapping(visit(schema, _CreateMapping()))
255+
256+
257+
def update_mapping(mapping: NameMapping, updates: Dict[int, NestedField], adds: Dict[int, List[NestedField]]) -> NameMapping:
258+
return NameMapping(visit_name_mapping(mapping, _UpdateMapping(updates, adds)))

tests/integration/test_rest_schema.py

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,8 @@
2222
from pyiceberg.exceptions import CommitFailedException, NoSuchTableError, ValidationError
2323
from pyiceberg.partitioning import PartitionField, PartitionSpec
2424
from pyiceberg.schema import Schema, prune_columns
25-
from pyiceberg.table import Table, UpdateSchema
25+
from pyiceberg.table import Table, TableProperties, UpdateSchema
26+
from pyiceberg.table.name_mapping import MappedField, NameMapping, create_mapping_from_schema
2627
from pyiceberg.table.sorting import SortField, SortOrder
2728
from pyiceberg.transforms import IdentityTransform
2829
from pyiceberg.types import (
@@ -73,7 +74,11 @@ def _create_table_with_schema(catalog: Catalog, schema: Schema) -> Table:
7374
catalog.drop_table(tbl_name)
7475
except NoSuchTableError:
7576
pass
76-
return catalog.create_table(identifier=tbl_name, schema=schema)
77+
return catalog.create_table(
78+
identifier=tbl_name,
79+
schema=schema,
80+
properties={TableProperties.DEFAULT_NAME_MAPPING: create_mapping_from_schema(schema).model_dump_json()},
81+
)
7782

7883

7984
@pytest.mark.integration
@@ -674,6 +679,13 @@ def test_rename_simple(simple_table: Table) -> None:
674679
identifier_field_ids=[2],
675680
)
676681

682+
# Check that the name mapping gets updated
683+
assert simple_table.name_mapping() == NameMapping([
684+
MappedField(field_id=1, names=['foo', 'vo']),
685+
MappedField(field_id=2, names=['bar']),
686+
MappedField(field_id=3, names=['baz']),
687+
])
688+
677689

678690
@pytest.mark.integration
679691
def test_rename_simple_nested(catalog: Catalog) -> None:
@@ -701,6 +713,11 @@ def test_rename_simple_nested(catalog: Catalog) -> None:
701713
),
702714
)
703715

716+
# Check that the name mapping gets updated
717+
assert tbl.name_mapping() == NameMapping([
718+
MappedField(field_id=1, names=['foo'], fields=[MappedField(field_id=2, names=['bar', 'vo'])]),
719+
])
720+
704721

705722
@pytest.mark.integration
706723
def test_rename_simple_nested_with_dots(catalog: Catalog) -> None:

tests/table/test_name_mapping.py

Lines changed: 72 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,14 @@
1717
import pytest
1818

1919
from pyiceberg.schema import Schema
20-
from pyiceberg.table.name_mapping import MappedField, NameMapping, create_mapping_from_schema, parse_mapping_from_json
20+
from pyiceberg.table.name_mapping import (
21+
MappedField,
22+
NameMapping,
23+
create_mapping_from_schema,
24+
parse_mapping_from_json,
25+
update_mapping,
26+
)
27+
from pyiceberg.types import NestedField, StringType
2128

2229

2330
@pytest.fixture(scope="session")
@@ -238,3 +245,67 @@ def test_mapping_lookup_by_name(table_name_mapping_nested: NameMapping) -> None:
238245

239246
with pytest.raises(ValueError, match="Could not find field with name: boom"):
240247
table_name_mapping_nested.find("boom")
248+
249+
250+
def test_invalid_mapped_field() -> None:
251+
with pytest.raises(ValueError):
252+
MappedField(field_id=1, names=[])
253+
254+
255+
def test_update_mapping_no_updates_or_adds(table_name_mapping_nested: NameMapping) -> None:
256+
assert update_mapping(table_name_mapping_nested, {}, {}) == table_name_mapping_nested
257+
258+
259+
def test_update_mapping(table_name_mapping_nested: NameMapping) -> None:
260+
updates = {1: NestedField(1, "foo_update", StringType(), True)}
261+
adds = {
262+
-1: [NestedField(18, "add_18", StringType(), True)],
263+
15: [NestedField(19, "name", StringType(), True), NestedField(20, "add_20", StringType(), True)],
264+
}
265+
266+
expected = NameMapping([
267+
MappedField(field_id=1, names=['foo', 'foo_update']),
268+
MappedField(field_id=2, names=['bar']),
269+
MappedField(field_id=3, names=['baz']),
270+
MappedField(field_id=4, names=['qux'], fields=[MappedField(field_id=5, names=['element'])]),
271+
MappedField(
272+
field_id=6,
273+
names=['quux'],
274+
fields=[
275+
MappedField(field_id=7, names=['key']),
276+
MappedField(
277+
field_id=8,
278+
names=['value'],
279+
fields=[
280+
MappedField(field_id=9, names=['key']),
281+
MappedField(field_id=10, names=['value']),
282+
],
283+
),
284+
],
285+
),
286+
MappedField(
287+
field_id=11,
288+
names=['location'],
289+
fields=[
290+
MappedField(
291+
field_id=12,
292+
names=['element'],
293+
fields=[
294+
MappedField(field_id=13, names=['latitude']),
295+
MappedField(field_id=14, names=['longitude']),
296+
],
297+
)
298+
],
299+
),
300+
MappedField(
301+
field_id=15,
302+
names=['person'],
303+
fields=[
304+
MappedField(field_id=17, names=['age']),
305+
MappedField(field_id=19, names=['name']),
306+
MappedField(field_id=20, names=['add_20']),
307+
],
308+
),
309+
MappedField(field_id=18, names=['add_18']),
310+
])
311+
assert update_mapping(table_name_mapping_nested, updates, adds) == expected

0 commit comments

Comments
 (0)