16 changes: 9 additions & 7 deletions pyiceberg/table/__init__.py
@@ -84,11 +84,7 @@
TableMetadata,
TableMetadataUtil,
)
from pyiceberg.table.name_mapping import (
NameMapping,
create_mapping_from_schema,
parse_mapping_from_json,
)
from pyiceberg.table.name_mapping import NameMapping, parse_mapping_from_json, update_mapping
from pyiceberg.table.refs import MAIN_BRANCH, SnapshotRef
from pyiceberg.table.snapshots import (
Operation,
@@ -994,12 +990,12 @@ def update_snapshot(self) -> UpdateSnapshot:
"""
return UpdateSnapshot(self)

def name_mapping(self) -> NameMapping:
def name_mapping(self) -> Optional[NameMapping]:
"""Return the table's field-id NameMapping."""
if name_mapping_json := self.properties.get(TableProperties.DEFAULT_NAME_MAPPING):
return parse_mapping_from_json(name_mapping_json)
else:
return create_mapping_from_schema(self.schema())
return None

def append(self, df: pa.Table) -> None:
"""
@@ -1950,6 +1946,12 @@ def commit(self) -> None:
else:
updates = (SetCurrentSchemaUpdate(schema_id=existing_schema_id),) # type: ignore

if name_mapping := self._table.name_mapping():
updated_name_mapping = update_mapping(name_mapping, self._updates, self._adds)
updates += ( # type: ignore
SetPropertiesUpdate(updates={TableProperties.DEFAULT_NAME_MAPPING: updated_name_mapping.model_dump_json()}),
)

if self._transaction is not None:
self._transaction._append_updates(*updates) # pylint: disable=W0212
self._transaction._append_requirements(*requirements) # pylint: disable=W0212
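
For context, a minimal usage sketch (not part of this pull request) of the changed Table.name_mapping() behaviour: it now returns None unless the schema.name-mapping.default property is set on the table, instead of deriving a mapping from the current schema on the fly. The helper name below is made up for illustration.

from pyiceberg.table import Table, TableProperties
from pyiceberg.table.name_mapping import create_mapping_from_schema

def ensure_name_mapping_property(tbl: Table) -> None:
    """Illustration only: show how a table without an explicit mapping could opt in."""
    if tbl.name_mapping() is None:
        # No mapping is stored under schema.name-mapping.default; build one from the
        # current schema, as the integration test below does at table-creation time.
        mapping_json = create_mapping_from_schema(tbl.schema()).model_dump_json()
        print(f"no mapping set; store it under {TableProperties.DEFAULT_NAME_MAPPING}: {mapping_json}")
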
97 changes: 88 additions & 9 deletions pyiceberg/table/name_mapping.py
@@ -26,7 +26,7 @@
from abc import ABC, abstractmethod
from collections import ChainMap
from functools import cached_property, singledispatch
from typing import Any, Dict, Generic, List, TypeVar, Union
from typing import Any, Dict, Generic, List, Optional, TypeVar, Union

from pydantic import Field, conlist, field_validator, model_serializer

@@ -45,6 +45,18 @@ class MappedField(IcebergBaseModel):
def convert_null_to_empty_List(cls, v: Any) -> Any:
return v or []

@field_validator('names', mode='after')
@classmethod
def check_at_least_one(cls, v: List[str]) -> Any:
"""
Conlist constraint does not seem to be validating the class on instantiation.

Adding a custom validator to enforce min_length=1 constraint.
"""
if len(v) < 1:
raise ValueError("At least one mapped name must be provided for the field")
return v

@model_serializer
def ser_model(self) -> Dict[str, Any]:
"""Set custom serializer to leave out the field when it is empty."""
@@ -93,24 +105,25 @@ def __str__(self) -> str:
return "[\n " + "\n ".join([str(e) for e in self.root]) + "\n]"


S = TypeVar('S')
T = TypeVar("T")


class NameMappingVisitor(Generic[T], ABC):
class NameMappingVisitor(Generic[S, T], ABC):
@abstractmethod
def mapping(self, nm: NameMapping, field_results: T) -> T:
def mapping(self, nm: NameMapping, field_results: S) -> S:
"""Visit a NameMapping."""

@abstractmethod
def fields(self, struct: List[MappedField], field_results: List[T]) -> T:
def fields(self, struct: List[MappedField], field_results: List[T]) -> S:
"""Visit a List[MappedField]."""

@abstractmethod
def field(self, field: MappedField, field_result: T) -> T:
def field(self, field: MappedField, field_result: S) -> T:
"""Visit a MappedField."""


class _IndexByName(NameMappingVisitor[Dict[str, MappedField]]):
class _IndexByName(NameMappingVisitor[Dict[str, MappedField], Dict[str, MappedField]]):
def mapping(self, nm: NameMapping, field_results: Dict[str, MappedField]) -> Dict[str, MappedField]:
return field_results

@@ -129,18 +142,18 @@ def field(self, field: MappedField, field_result: Dict[str, MappedField]) -> Dic


@singledispatch
def visit_name_mapping(obj: Union[NameMapping, List[MappedField], MappedField], visitor: NameMappingVisitor[T]) -> T:
def visit_name_mapping(obj: Union[NameMapping, List[MappedField], MappedField], visitor: NameMappingVisitor[S, T]) -> S:
"""Traverse the name mapping in post-order traversal."""
raise NotImplementedError(f"Cannot visit non-type: {obj}")


@visit_name_mapping.register(NameMapping)
def _(obj: NameMapping, visitor: NameMappingVisitor[T]) -> T:
def _(obj: NameMapping, visitor: NameMappingVisitor[S, T]) -> S:
return visitor.mapping(obj, visit_name_mapping(obj.root, visitor))


@visit_name_mapping.register(list)
def _(fields: List[MappedField], visitor: NameMappingVisitor[T]) -> T:
def _(fields: List[MappedField], visitor: NameMappingVisitor[S, T]) -> S:
results = [visitor.field(field, visit_name_mapping(field.fields, visitor)) for field in fields]
return visitor.fields(fields, results)

@@ -175,5 +188,71 @@ def primitive(self, primitive: PrimitiveType) -> List[MappedField]:
return []


class _UpdateMapping(NameMappingVisitor[List[MappedField], MappedField]):
_updates: Dict[int, NestedField]
_adds: Dict[int, List[NestedField]]

def __init__(self, updates: Dict[int, NestedField], adds: Dict[int, List[NestedField]]):
self._updates = updates
self._adds = adds

@staticmethod
def _remove_reassigned_names(field: MappedField, assignments: Dict[str, int]) -> Optional[MappedField]:
removed_names = set()
for name in field.names:
if (assigned_id := assignments.get(name)) and assigned_id != field.field_id:
removed_names.add(name)

remaining_names = [f for f in field.names if f not in removed_names]
if remaining_names:
return MappedField(field_id=field.field_id, names=remaining_names, fields=field.fields)
else:
return None

def _add_new_fields(self, mapped_fields: List[MappedField], parent_id: int) -> List[MappedField]:
if fields_to_add := self._adds.get(parent_id):
fields: List[MappedField] = []
new_fields: List[MappedField] = []

for add in fields_to_add:
new_fields.append(
MappedField(field_id=add.field_id, names=[add.name], fields=visit(add.field_type, _CreateMapping()))
)

reassignments = {f.name: f.field_id for f in fields_to_add}
fields = [
updated_field
for field in mapped_fields
if (updated_field := self._remove_reassigned_names(field, reassignments)) is not None
] + new_fields
return fields
else:
return mapped_fields

def mapping(self, nm: NameMapping, field_results: List[MappedField]) -> List[MappedField]:
return self._add_new_fields(field_results, -1)

def fields(self, struct: List[MappedField], field_results: List[MappedField]) -> List[MappedField]:
reassignments: Dict[str, int] = {
update.name: update.field_id for f in field_results if (update := self._updates.get(f.field_id))
}
return [
updated_field
for field in field_results
if (updated_field := self._remove_reassigned_names(field, reassignments)) is not None
]

def field(self, field: MappedField, field_result: List[MappedField]) -> MappedField:
field_names = field.names
if (update := self._updates.get(field.field_id)) is not None and update.name not in field_names:
field_names.append(update.name)

return MappedField(field_id=field.field_id, names=field_names, fields=self._add_new_fields(field_result, field.field_id))


def create_mapping_from_schema(schema: Schema) -> NameMapping:
return NameMapping(visit(schema, _CreateMapping()))


def update_mapping(mapping: NameMapping, updates: Dict[int, NestedField], adds: Dict[int, List[NestedField]]) -> NameMapping:
return NameMapping(visit_name_mapping(mapping, _UpdateMapping(updates, adds)))
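
A small usage sketch (not part of this pull request) of the new update_mapping helper, showing how a rename keeps the old column name as an alias and how a top-level add (parent id -1) appends a new entry; the field names here are made up for illustration.

from pyiceberg.table.name_mapping import MappedField, NameMapping, update_mapping
from pyiceberg.types import NestedField, StringType

mapping = NameMapping([
    MappedField(field_id=1, names=['foo']),
    MappedField(field_id=2, names=['bar']),
])

# Rename field 1 and add a new top-level field 3.
updated = update_mapping(
    mapping,
    updates={1: NestedField(1, "foo_renamed", StringType(), True)},
    adds={-1: [NestedField(3, "baz", StringType(), True)]},
)

# Field 1 now carries both names, field 2 is untouched, and field 3 is appended:
# NameMapping([
#     MappedField(field_id=1, names=['foo', 'foo_renamed']),
#     MappedField(field_id=2, names=['bar']),
#     MappedField(field_id=3, names=['baz']),
# ])
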
21 changes: 19 additions & 2 deletions tests/integration/test_rest_schema.py
@@ -22,7 +22,8 @@
from pyiceberg.exceptions import CommitFailedException, NoSuchTableError, ValidationError
from pyiceberg.partitioning import PartitionField, PartitionSpec
from pyiceberg.schema import Schema, prune_columns
from pyiceberg.table import Table, UpdateSchema
from pyiceberg.table import Table, TableProperties, UpdateSchema
from pyiceberg.table.name_mapping import MappedField, NameMapping, create_mapping_from_schema
from pyiceberg.table.sorting import SortField, SortOrder
from pyiceberg.transforms import IdentityTransform
from pyiceberg.types import (
@@ -73,7 +74,11 @@ def _create_table_with_schema(catalog: Catalog, schema: Schema) -> Table:
catalog.drop_table(tbl_name)
except NoSuchTableError:
pass
return catalog.create_table(identifier=tbl_name, schema=schema)
return catalog.create_table(
identifier=tbl_name,
schema=schema,
properties={TableProperties.DEFAULT_NAME_MAPPING: create_mapping_from_schema(schema).model_dump_json()},
)


@pytest.mark.integration
Expand Down Expand Up @@ -674,6 +679,13 @@ def test_rename_simple(simple_table: Table) -> None:
identifier_field_ids=[2],
)

# Check that the name mapping gets updated
assert simple_table.name_mapping() == NameMapping([
MappedField(field_id=1, names=['foo', 'vo']),
MappedField(field_id=2, names=['bar']),
MappedField(field_id=3, names=['baz']),
])


@pytest.mark.integration
def test_rename_simple_nested(catalog: Catalog) -> None:
Expand Down Expand Up @@ -701,6 +713,11 @@ def test_rename_simple_nested(catalog: Catalog) -> None:
),
)

# Check that the name mapping gets updated
assert tbl.name_mapping() == NameMapping([
MappedField(field_id=1, names=['foo'], fields=[MappedField(field_id=2, names=['bar', 'vo'])]),
])


@pytest.mark.integration
def test_rename_simple_nested_with_dots(catalog: Catalog) -> None:
73 changes: 72 additions & 1 deletion tests/table/test_name_mapping.py
@@ -17,7 +17,14 @@
import pytest

from pyiceberg.schema import Schema
from pyiceberg.table.name_mapping import MappedField, NameMapping, create_mapping_from_schema, parse_mapping_from_json
from pyiceberg.table.name_mapping import (
MappedField,
NameMapping,
create_mapping_from_schema,
parse_mapping_from_json,
update_mapping,
)
from pyiceberg.types import NestedField, StringType


@pytest.fixture(scope="session")
@@ -238,3 +245,67 @@ def test_mapping_lookup_by_name(table_name_mapping_nested: NameMapping) -> None:

with pytest.raises(ValueError, match="Could not find field with name: boom"):
table_name_mapping_nested.find("boom")


def test_invalid_mapped_field() -> None:
with pytest.raises(ValueError):
MappedField(field_id=1, names=[])


def test_update_mapping_no_updates_or_adds(table_name_mapping_nested: NameMapping) -> None:
assert update_mapping(table_name_mapping_nested, {}, {}) == table_name_mapping_nested


def test_update_mapping(table_name_mapping_nested: NameMapping) -> None:
updates = {1: NestedField(1, "foo_update", StringType(), True)}
adds = {
-1: [NestedField(18, "add_18", StringType(), True)],
15: [NestedField(19, "name", StringType(), True), NestedField(20, "add_20", StringType(), True)],
}

expected = NameMapping([
MappedField(field_id=1, names=['foo', 'foo_update']),
MappedField(field_id=2, names=['bar']),
MappedField(field_id=3, names=['baz']),
MappedField(field_id=4, names=['qux'], fields=[MappedField(field_id=5, names=['element'])]),
MappedField(
field_id=6,
names=['quux'],
fields=[
MappedField(field_id=7, names=['key']),
MappedField(
field_id=8,
names=['value'],
fields=[
MappedField(field_id=9, names=['key']),
MappedField(field_id=10, names=['value']),
],
),
],
),
MappedField(
field_id=11,
names=['location'],
fields=[
MappedField(
field_id=12,
names=['element'],
fields=[
MappedField(field_id=13, names=['latitude']),
MappedField(field_id=14, names=['longitude']),
],
)
],
),
MappedField(
field_id=15,
names=['person'],
fields=[
MappedField(field_id=17, names=['age']),
MappedField(field_id=19, names=['name']),
MappedField(field_id=20, names=['add_20']),
],
),
MappedField(field_id=18, names=['add_18']),
])
assert update_mapping(table_name_mapping_nested, updates, adds) == expected