Skip to content

Commit dcc3d9f

Browse files
authored
Add name-mapping (#212)
1 parent 9490922 commit dcc3d9f

File tree

2 files changed

+424
-0
lines changed

2 files changed

+424
-0
lines changed

pyiceberg/table/name_mapping.py

Lines changed: 178 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,178 @@
1+
# Licensed to the Apache Software Foundation (ASF) under one
2+
# or more contributor license agreements. See the NOTICE file
3+
# distributed with this work for additional information
4+
# regarding copyright ownership. The ASF licenses this file
5+
# to you under the Apache License, Version 2.0 (the
6+
# "License"); you may not use this file except in compliance
7+
# with the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the License.
17+
"""
18+
Contains everything around the name mapping.
19+
20+
More information can be found here:
21+
https://iceberg.apache.org/spec/#name-mapping-serialization
22+
"""
23+
from __future__ import annotations
24+
25+
from abc import ABC, abstractmethod
26+
from collections import ChainMap
27+
from functools import cached_property, singledispatch
28+
from typing import Any, Dict, Generic, List, TypeVar, Union
29+
30+
from pydantic import Field, conlist, field_validator, model_serializer
31+
32+
from pyiceberg.schema import Schema, SchemaVisitor, visit
33+
from pyiceberg.typedef import IcebergBaseModel, IcebergRootModel
34+
from pyiceberg.types import ListType, MapType, NestedField, PrimitiveType, StructType
35+
36+
37+
class MappedField(IcebergBaseModel):
    """One entry of a name mapping: a field ID, its names, and nested mapped fields."""

    field_id: int = Field(alias="field-id")
    # NOTE(review): `conlist(...)` is assigned as the default value here rather than
    # used as the annotation, so the `min_length=1` constraint is presumably not
    # enforced by pydantic and `names` is not required -- confirm the intent.
    names: List[str] = conlist(str, min_length=1)
    fields: List[MappedField] = Field(default_factory=list)

    @field_validator('fields', mode='before')
    @classmethod
    def convert_null_to_empty_List(cls, v: Any) -> Any:
        """Replace a falsy `fields` input (e.g. None) with an empty list before validation."""
        return v if v else []

    @model_serializer
    def ser_model(self) -> Dict[str, Any]:
        """Serialize the field, leaving out `fields` when there are no nested fields."""
        serialized: Dict[str, Any] = {
            'field-id': self.field_id,
            'names': self.names,
        }
        if self.fields:
            serialized['fields'] = self.fields
        return serialized

    def __len__(self) -> int:
        """Return the number of nested fields."""
        return len(self.fields)

    def __str__(self) -> str:
        """Render the mapped field as `([names] -> field_id nested...)`."""
        names_text = ", ".join(self.names)
        id_text = str(self.field_id) or "?"
        nested_text = ", ".join(str(child) for child in self.fields)
        # Nested fields are appended after a single space, and only when present.
        suffix = " " + nested_text if nested_text else ""
        return "([" + names_text + "] -> " + id_text + suffix + ")"
67+
68+
69+
class NameMapping(IcebergRootModel[List[MappedField]]):
    """Root model wrapping the top-level list of mapped fields."""

    root: List[MappedField]

    @cached_property
    def _field_by_name(self) -> Dict[str, MappedField]:
        # Built by running the _IndexByName visitor over this mapping; the
        # cached_property decorator stores the result after the first access.
        index: Dict[str, MappedField] = visit_name_mapping(self, _IndexByName())
        return index

    def find(self, *names: str) -> MappedField:
        """Return the mapped field registered under the dot-joined name.

        Args:
            names: Name parts; they are joined with '.' to form the lookup key.

        Raises:
            ValueError: If no field is indexed under the joined name.
        """
        name = '.'.join(names)
        try:
            found = self._field_by_name[name]
        except KeyError as e:
            raise ValueError(f"Could not find field with name: {name}") from e
        return found

    def __len__(self) -> int:
        """Return the number of top-level mappings."""
        return len(self.root)

    def __str__(self) -> str:
        """Render the name mapping with one top-level mapped field per line."""
        if not self.root:
            return "[]"
        rendered = [str(entry) for entry in self.root]
        return "[\n " + "\n ".join(rendered) + "\n]"
93+
94+
95+
# Generic result type produced by the methods of a NameMappingVisitor.
T = TypeVar("T")
96+
97+
98+
class NameMappingVisitor(Generic[T], ABC):
    """Abstract visitor over a name mapping; each hook returns a result of type T."""

    @abstractmethod
    def mapping(self, nm: NameMapping, field_results: T) -> T:
        """Visit a NameMapping.

        Args:
            nm: The name mapping being visited.
            field_results: The result already computed for the mapping's list of fields.
        """

    @abstractmethod
    def fields(self, struct: List[MappedField], field_results: List[T]) -> T:
        """Visit a List[MappedField].

        Args:
            struct: The list of mapped fields being visited.
            field_results: One result per mapped field in `struct`.
        """

    @abstractmethod
    def field(self, field: MappedField, field_result: T) -> T:
        """Visit a MappedField.

        Args:
            field: The mapped field being visited.
            field_result: The result already computed for the field's nested fields.
        """
110+
111+
112+
class _IndexByName(NameMappingVisitor[Dict[str, MappedField]]):
    """Visitor that indexes mapped fields by their (dot-joined) names."""

    def mapping(self, nm: NameMapping, field_results: Dict[str, MappedField]) -> Dict[str, MappedField]:
        """Return the index built for the mapping's fields unchanged."""
        return field_results

    def fields(self, struct: List[MappedField], field_results: List[Dict[str, MappedField]]) -> Dict[str, MappedField]:
        """Merge the per-field indexes into one dictionary.

        When the same key occurs in several indexes, the entry from the
        earliest index in `field_results` is kept.
        """
        merged: Dict[str, MappedField] = {}
        # Apply the indexes last-to-first so that earlier ones overwrite later ones.
        for partial_index in reversed(field_results):
            merged.update(partial_index)
        return merged

    def field(self, field: MappedField, field_result: Dict[str, MappedField]) -> Dict[str, MappedField]:
        """Index a field under each of its names, prefixing nested keys with those names."""
        result: Dict[str, MappedField] = {}

        # Every nested key is re-registered once per name of this field.
        for key, result_field in field_result.items():
            for field_name in field.names:
                result[f"{field_name}.{key}"] = result_field

        # The field itself is reachable under each of its own names.
        result.update(dict.fromkeys(field.names, field))

        return result
128+
129+
130+
@singledispatch
def visit_name_mapping(obj: Union[NameMapping, List[MappedField], MappedField], visitor: NameMappingVisitor[T]) -> T:
    """Traverse the name mapping in post-order traversal.

    This is the single-dispatch fallback: it runs only for objects whose type
    has no registered implementation.

    Raises:
        NotImplementedError: Always, for an object of an unregistered type.
    """
    message = f"Cannot visit non-type: {obj}"
    raise NotImplementedError(message)
134+
135+
136+
@visit_name_mapping.register(NameMapping)
def _(obj: NameMapping, visitor: NameMappingVisitor[T]) -> T:
    # Visit the root list of fields first, then hand its result to the mapping hook.
    root_result = visit_name_mapping(obj.root, visitor)
    return visitor.mapping(obj, root_result)
139+
140+
141+
@visit_name_mapping.register(list)
def _(fields: List[MappedField], visitor: NameMappingVisitor[T]) -> T:
    # Each field's nested fields are visited before the field itself.
    results = []
    for mapped_field in fields:
        nested_result = visit_name_mapping(mapped_field.fields, visitor)
        results.append(visitor.field(mapped_field, nested_result))
    return visitor.fields(fields, results)
145+
146+
147+
def parse_mapping_from_json(mapping: str) -> NameMapping:
    """Parse a JSON string into a NameMapping using the model's JSON validation."""
    parsed: NameMapping = NameMapping.model_validate_json(mapping)
    return parsed
149+
150+
151+
class _CreateMapping(SchemaVisitor[List[MappedField]]):
    """Schema visitor that produces the list of MappedField for each visited type."""

    def schema(self, schema: Schema, struct_result: List[MappedField]) -> List[MappedField]:
        """Return the mapped fields computed for the schema's struct."""
        return struct_result

    def struct(self, struct: StructType, field_results: List[List[MappedField]]) -> List[MappedField]:
        """Create one MappedField per struct field, nesting that field's own result."""
        mapped: List[MappedField] = []
        for nested_field, nested_result in zip(struct.fields, field_results):
            mapped.append(MappedField(field_id=nested_field.field_id, names=[nested_field.name], fields=nested_result))
        return mapped

    def field(self, field: NestedField, field_result: List[MappedField]) -> List[MappedField]:
        """Pass the result of the field's type through unchanged."""
        return field_result

    def list(self, list_type: ListType, element_result: List[MappedField]) -> List[MappedField]:
        """Map a list type to a single field named "element"."""
        element = MappedField(field_id=list_type.element_id, names=["element"], fields=element_result)
        return [element]

    def map(self, map_type: MapType, key_result: List[MappedField], value_result: List[MappedField]) -> List[MappedField]:
        """Map a map type to two fields named "key" and "value"."""
        key = MappedField(field_id=map_type.key_id, names=["key"], fields=key_result)
        value = MappedField(field_id=map_type.value_id, names=["value"], fields=value_result)
        return [key, value]

    def primitive(self, primitive: PrimitiveType) -> List[MappedField]:
        """Return an empty list: a primitive has no nested fields."""
        return []
175+
176+
177+
def create_mapping_from_schema(schema: Schema) -> NameMapping:
    """Build a NameMapping for `schema` by visiting it with _CreateMapping."""
    mapped_fields = visit(schema, _CreateMapping())
    return NameMapping(mapped_fields)

0 commit comments

Comments
 (0)