diff --git a/pyiceberg/table/name_mapping.py b/pyiceberg/table/name_mapping.py new file mode 100644 index 0000000000..b07fbc0735 --- /dev/null +++ b/pyiceberg/table/name_mapping.py @@ -0,0 +1,178 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +""" +Contains everything around the name mapping. + +More information can be found on here: +https://iceberg.apache.org/spec/#name-mapping-serialization +""" +from __future__ import annotations + +from abc import ABC, abstractmethod +from collections import ChainMap +from functools import cached_property, singledispatch +from typing import Any, Dict, Generic, List, TypeVar, Union + +from pydantic import Field, conlist, field_validator, model_serializer + +from pyiceberg.schema import Schema, SchemaVisitor, visit +from pyiceberg.typedef import IcebergBaseModel, IcebergRootModel +from pyiceberg.types import ListType, MapType, NestedField, PrimitiveType, StructType + + +class MappedField(IcebergBaseModel): + field_id: int = Field(alias="field-id") + names: List[str] = conlist(str, min_length=1) + fields: List[MappedField] = Field(default_factory=list) + + @field_validator('fields', mode='before') + @classmethod + def convert_null_to_empty_List(cls, v: Any) -> Any: + return v or [] + + @model_serializer + def ser_model(self) -> Dict[str, Any]: + """Set custom serializer to leave out the field when it is empty.""" + fields = {'fields': self.fields} if len(self.fields) > 0 else {} + return { + 'field-id': self.field_id, + 'names': self.names, + **fields, + } + + def __len__(self) -> int: + """Return the number of fields.""" + return len(self.fields) + + def __str__(self) -> str: + """Convert the mapped-field into a nicely formatted string.""" + # Otherwise the UTs fail because the order of the set can change + fields_str = ", ".join([str(e) for e in self.fields]) or "" + fields_str = " " + fields_str if fields_str else "" + return "([" + ", ".join(self.names) + "] -> " + (str(self.field_id) or "?") + fields_str + ")" + + +class NameMapping(IcebergRootModel[List[MappedField]]): + root: List[MappedField] + + @cached_property + def _field_by_name(self) -> Dict[str, MappedField]: + return visit_name_mapping(self, _IndexByName()) + + def find(self, *names: str) -> MappedField: + name = '.'.join(names) + try: + return self._field_by_name[name] + except KeyError as e: + raise ValueError(f"Could not find field with name: {name}") from e + + def __len__(self) -> int: + """Return the number of mappings.""" + return len(self.root) + + def __str__(self) -> str: + """Convert the name-mapping into a nicely formatted string.""" + if len(self.root) == 0: + return "[]" + else: + return "[\n " + "\n ".join([str(e) for e in self.root]) + "\n]" + + +T = TypeVar("T") + + +class NameMappingVisitor(Generic[T], ABC): + @abstractmethod + def mapping(self, nm: NameMapping, field_results: T) -> T: + """Visit a NameMapping.""" + + @abstractmethod + def fields(self, struct: List[MappedField], field_results: List[T]) -> T: + """Visit a List[MappedField].""" + + @abstractmethod + def field(self, field: MappedField, field_result: T) -> T: + """Visit a MappedField.""" + + +class _IndexByName(NameMappingVisitor[Dict[str, MappedField]]): + def mapping(self, nm: NameMapping, field_results: Dict[str, MappedField]) -> Dict[str, MappedField]: + return field_results + + def fields(self, struct: List[MappedField], field_results: List[Dict[str, MappedField]]) -> Dict[str, MappedField]: + return dict(ChainMap(*field_results)) + + def field(self, field: MappedField, field_result: Dict[str, MappedField]) -> Dict[str, MappedField]: + result: Dict[str, MappedField] = { + f"{field_name}.{key}": result_field for key, result_field in field_result.items() for field_name in field.names + } + + for name in field.names: + result[name] = field + + return result + + +@singledispatch +def visit_name_mapping(obj: Union[NameMapping, List[MappedField], MappedField], visitor: NameMappingVisitor[T]) -> T: + """Traverse the name mapping in post-order traversal.""" + raise NotImplementedError(f"Cannot visit non-type: {obj}") + + +@visit_name_mapping.register(NameMapping) +def _(obj: NameMapping, visitor: NameMappingVisitor[T]) -> T: + return visitor.mapping(obj, visit_name_mapping(obj.root, visitor)) + + +@visit_name_mapping.register(list) +def _(fields: List[MappedField], visitor: NameMappingVisitor[T]) -> T: + results = [visitor.field(field, visit_name_mapping(field.fields, visitor)) for field in fields] + return visitor.fields(fields, results) + + +def parse_mapping_from_json(mapping: str) -> NameMapping: + return NameMapping.model_validate_json(mapping) + + +class _CreateMapping(SchemaVisitor[List[MappedField]]): + def schema(self, schema: Schema, struct_result: List[MappedField]) -> List[MappedField]: + return struct_result + + def struct(self, struct: StructType, field_results: List[List[MappedField]]) -> List[MappedField]: + return [ + MappedField(field_id=field.field_id, names=[field.name], fields=result) + for field, result in zip(struct.fields, field_results) + ] + + def field(self, field: NestedField, field_result: List[MappedField]) -> List[MappedField]: + return field_result + + def list(self, list_type: ListType, element_result: List[MappedField]) -> List[MappedField]: + return [MappedField(field_id=list_type.element_id, names=["element"], fields=element_result)] + + def map(self, map_type: MapType, key_result: List[MappedField], value_result: List[MappedField]) -> List[MappedField]: + return [ + MappedField(field_id=map_type.key_id, names=["key"], fields=key_result), + MappedField(field_id=map_type.value_id, names=["value"], fields=value_result), + ] + + def primitive(self, primitive: PrimitiveType) -> List[MappedField]: + return [] + + +def create_mapping_from_schema(schema: Schema) -> NameMapping: + return NameMapping(visit(schema, _CreateMapping())) diff --git a/tests/table/test_name_mapping.py b/tests/table/test_name_mapping.py new file mode 100644 index 0000000000..37111a5e3e --- /dev/null +++ b/tests/table/test_name_mapping.py @@ -0,0 +1,246 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import pytest + +from pyiceberg.schema import Schema +from pyiceberg.table.name_mapping import MappedField, NameMapping, create_mapping_from_schema, parse_mapping_from_json + + +@pytest.fixture(scope="session") +def table_name_mapping_nested() -> NameMapping: + return NameMapping( + [ + MappedField(field_id=1, names=['foo']), + MappedField(field_id=2, names=['bar']), + MappedField(field_id=3, names=['baz']), + MappedField(field_id=4, names=['qux'], fields=[MappedField(field_id=5, names=['element'])]), + MappedField( + field_id=6, + names=['quux'], + fields=[ + MappedField(field_id=7, names=['key']), + MappedField( + field_id=8, + names=['value'], + fields=[ + MappedField(field_id=9, names=['key']), + MappedField(field_id=10, names=['value']), + ], + ), + ], + ), + MappedField( + field_id=11, + names=['location'], + fields=[ + MappedField( + field_id=12, + names=['element'], + fields=[ + MappedField(field_id=13, names=['latitude']), + MappedField(field_id=14, names=['longitude']), + ], + ) + ], + ), + MappedField( + field_id=15, + names=['person'], + fields=[ + MappedField(field_id=16, names=['name']), + MappedField(field_id=17, names=['age']), + ], + ), + ] + ) + + +def test_json_mapped_field_deserialization() -> None: + mapped_field = """{ + "field-id": 1, + "names": ["id", "record_id"] + } + """ + assert MappedField(field_id=1, names=['id', 'record_id']) == MappedField.model_validate_json(mapped_field) + + mapped_field_with_null_fields = """{ + "field-id": 1, + "names": ["id", "record_id"], + "fields": null + } + """ + assert MappedField(field_id=1, names=['id', 'record_id']) == MappedField.model_validate_json(mapped_field_with_null_fields) + + +def test_json_name_mapping_deserialization() -> None: + name_mapping = """ +[ + { + "field-id": 1, + "names": [ + "id", + "record_id" + ] + }, + { + "field-id": 2, + "names": [ + "data" + ] + }, + { + "field-id": 3, + "names": [ + "location" + ], + "fields": [ + { + "field-id": 4, + "names": [ + "latitude", + "lat" + ] + }, + { + "field-id": 5, + "names": [ + "longitude", + "long" + ] + } + ] + } +] + """ + + assert parse_mapping_from_json(name_mapping) == NameMapping( + [ + MappedField(field_id=1, names=['id', 'record_id']), + MappedField(field_id=2, names=['data']), + MappedField( + names=['location'], + field_id=3, + fields=[ + MappedField(field_id=4, names=['latitude', 'lat']), + MappedField(field_id=5, names=['longitude', 'long']), + ], + ), + ] + ) + + +def test_json_serialization(table_name_mapping_nested: NameMapping) -> None: + assert ( + table_name_mapping_nested.model_dump_json() + == """[{"field-id":1,"names":["foo"]},{"field-id":2,"names":["bar"]},{"field-id":3,"names":["baz"]},{"field-id":4,"names":["qux"],"fields":[{"field-id":5,"names":["element"]}]},{"field-id":6,"names":["quux"],"fields":[{"field-id":7,"names":["key"]},{"field-id":8,"names":["value"],"fields":[{"field-id":9,"names":["key"]},{"field-id":10,"names":["value"]}]}]},{"field-id":11,"names":["location"],"fields":[{"field-id":12,"names":["element"],"fields":[{"field-id":13,"names":["latitude"]},{"field-id":14,"names":["longitude"]}]}]},{"field-id":15,"names":["person"],"fields":[{"field-id":16,"names":["name"]},{"field-id":17,"names":["age"]}]}]""" + ) + + +def test_name_mapping_to_string() -> None: + nm = NameMapping( + [ + MappedField(field_id=1, names=['id', 'record_id']), + MappedField(field_id=2, names=['data']), + MappedField( + names=['location'], + field_id=3, + fields=[ + MappedField(field_id=4, names=['lat', 'latitude']), + MappedField(field_id=5, names=['long', 'longitude']), + ], + ), + ] + ) + + assert ( + str(nm) + == """[ + ([id, record_id] -> 1) + ([data] -> 2) + ([location] -> 3 ([lat, latitude] -> 4), ([long, longitude] -> 5)) +]""" + ) + + +def test_mapping_from_schema(table_schema_nested: Schema, table_name_mapping_nested: NameMapping) -> None: + nm = create_mapping_from_schema(table_schema_nested) + assert nm == table_name_mapping_nested + + +def test_mapping_by_name(table_name_mapping_nested: NameMapping) -> None: + assert table_name_mapping_nested._field_by_name == { + 'person.age': MappedField(field_id=17, names=['age']), + 'person.name': MappedField(field_id=16, names=['name']), + 'person': MappedField( + field_id=15, + names=['person'], + fields=[MappedField(field_id=16, names=['name']), MappedField(field_id=17, names=['age'])], + ), + 'location.element.longitude': MappedField(field_id=14, names=['longitude']), + 'location.element.latitude': MappedField(field_id=13, names=['latitude']), + 'location.element': MappedField( + field_id=12, + names=['element'], + fields=[MappedField(field_id=13, names=['latitude']), MappedField(field_id=14, names=['longitude'])], + ), + 'location': MappedField( + field_id=11, + names=['location'], + fields=[ + MappedField( + field_id=12, + names=['element'], + fields=[MappedField(field_id=13, names=['latitude']), MappedField(field_id=14, names=['longitude'])], + ) + ], + ), + 'quux.value.value': MappedField(field_id=10, names=['value']), + 'quux.value.key': MappedField(field_id=9, names=['key']), + 'quux.value': MappedField( + field_id=8, + names=['value'], + fields=[MappedField(field_id=9, names=['key']), MappedField(field_id=10, names=['value'])], + ), + 'quux.key': MappedField(field_id=7, names=['key']), + 'quux': MappedField( + field_id=6, + names=['quux'], + fields=[ + MappedField(field_id=7, names=['key']), + MappedField( + field_id=8, + names=['value'], + fields=[MappedField(field_id=9, names=['key']), MappedField(field_id=10, names=['value'])], + ), + ], + ), + 'qux.element': MappedField(field_id=5, names=['element']), + 'qux': MappedField(field_id=4, names=['qux'], fields=[MappedField(field_id=5, names=['element'])]), + 'baz': MappedField(field_id=3, names=['baz']), + 'bar': MappedField(field_id=2, names=['bar']), + 'foo': MappedField(field_id=1, names=['foo']), + } + + +def test_mapping_lookup_by_name(table_name_mapping_nested: NameMapping) -> None: + assert table_name_mapping_nested.find("foo") == MappedField(field_id=1, names=['foo']) + assert table_name_mapping_nested.find("location.element.latitude") == MappedField(field_id=13, names=['latitude']) + assert table_name_mapping_nested.find("location", "element", "latitude") == MappedField(field_id=13, names=['latitude']) + assert table_name_mapping_nested.find(*["location", "element", "latitude"]) == MappedField(field_id=13, names=['latitude']) + + with pytest.raises(ValueError, match="Could not find field with name: boom"): + table_name_mapping_nested.find("boom")