2626from abc import ABC , abstractmethod
2727from collections import ChainMap
2828from functools import cached_property , singledispatch
29- from typing import Any , Dict , Generic , List , TypeVar , Union
29+ from typing import Any , Dict , Generic , List , Optional , TypeVar , Union
3030
3131from pydantic import Field , conlist , field_validator , model_serializer
3232
@@ -45,6 +45,18 @@ class MappedField(IcebergBaseModel):
4545 def convert_null_to_empty_List (cls , v : Any ) -> Any :
4646 return v or []
4747
48+ @field_validator ('names' , mode = 'after' )
49+ @classmethod
50+ def check_at_least_one (cls , v : List [str ]) -> Any :
51+ """
52+ Conlist constraint does not seem to be validating the class on instantiation.
53+
54+ Adding a custom validator to enforce min_length=1 constraint.
55+ """
56+ if len (v ) < 1 :
57+ raise ValueError ("At least one mapped name must be provided for the field" )
58+ return v
59+
4860 @model_serializer
4961 def ser_model (self ) -> Dict [str , Any ]:
5062 """Set custom serializer to leave out the field when it is empty."""
@@ -93,24 +105,25 @@ def __str__(self) -> str:
93105 return "[\n " + "\n " .join ([str (e ) for e in self .root ]) + "\n ]"
94106
95107
108+ S = TypeVar ('S' )
96109T = TypeVar ("T" )
97110
98111
99- class NameMappingVisitor (Generic [T ], ABC ):
112+ class NameMappingVisitor (Generic [S , T ], ABC ):
100113 @abstractmethod
101- def mapping (self , nm : NameMapping , field_results : T ) -> T :
114+ def mapping (self , nm : NameMapping , field_results : S ) -> S :
102115 """Visit a NameMapping."""
103116
104117 @abstractmethod
105- def fields (self , struct : List [MappedField ], field_results : List [T ]) -> T :
118+ def fields (self , struct : List [MappedField ], field_results : List [T ]) -> S :
106119 """Visit a List[MappedField]."""
107120
108121 @abstractmethod
109- def field (self , field : MappedField , field_result : T ) -> T :
122+ def field (self , field : MappedField , field_result : S ) -> T :
110123 """Visit a MappedField."""
111124
112125
113- class _IndexByName (NameMappingVisitor [Dict [str , MappedField ]]):
126+ class _IndexByName (NameMappingVisitor [Dict [str , MappedField ], Dict [ str , MappedField ] ]):
114127 def mapping (self , nm : NameMapping , field_results : Dict [str , MappedField ]) -> Dict [str , MappedField ]:
115128 return field_results
116129
@@ -129,18 +142,18 @@ def field(self, field: MappedField, field_result: Dict[str, MappedField]) -> Dic
129142
130143
131144@singledispatch
132- def visit_name_mapping (obj : Union [NameMapping , List [MappedField ], MappedField ], visitor : NameMappingVisitor [T ]) -> T :
145+ def visit_name_mapping (obj : Union [NameMapping , List [MappedField ], MappedField ], visitor : NameMappingVisitor [S , T ]) -> S :
133146 """Traverse the name mapping in post-order traversal."""
134147 raise NotImplementedError (f"Cannot visit non-type: { obj } " )
135148
136149
137150@visit_name_mapping .register (NameMapping )
138- def _ (obj : NameMapping , visitor : NameMappingVisitor [T ]) -> T :
151+ def _ (obj : NameMapping , visitor : NameMappingVisitor [S , T ]) -> S :
139152 return visitor .mapping (obj , visit_name_mapping (obj .root , visitor ))
140153
141154
142155@visit_name_mapping .register (list )
143- def _ (fields : List [MappedField ], visitor : NameMappingVisitor [T ]) -> T :
156+ def _ (fields : List [MappedField ], visitor : NameMappingVisitor [S , T ]) -> S :
144157 results = [visitor .field (field , visit_name_mapping (field .fields , visitor )) for field in fields ]
145158 return visitor .fields (fields , results )
146159
@@ -175,5 +188,71 @@ def primitive(self, primitive: PrimitiveType) -> List[MappedField]:
175188 return []
176189
177190
191+ class _UpdateMapping (NameMappingVisitor [List [MappedField ], MappedField ]):
192+ _updates : Dict [int , NestedField ]
193+ _adds : Dict [int , List [NestedField ]]
194+
195+ def __init__ (self , updates : Dict [int , NestedField ], adds : Dict [int , List [NestedField ]]):
196+ self ._updates = updates
197+ self ._adds = adds
198+
199+ @staticmethod
200+ def _remove_reassigned_names (field : MappedField , assignments : Dict [str , int ]) -> Optional [MappedField ]:
201+ removed_names = set ()
202+ for name in field .names :
203+ if (assigned_id := assignments .get (name )) and assigned_id != field .field_id :
204+ removed_names .add (name )
205+
206+ remaining_names = [f for f in field .names if f not in removed_names ]
207+ if remaining_names :
208+ return MappedField (field_id = field .field_id , names = remaining_names , fields = field .fields )
209+ else :
210+ return None
211+
212+ def _add_new_fields (self , mapped_fields : List [MappedField ], parent_id : int ) -> List [MappedField ]:
213+ if fields_to_add := self ._adds .get (parent_id ):
214+ fields : List [MappedField ] = []
215+ new_fields : List [MappedField ] = []
216+
217+ for add in fields_to_add :
218+ new_fields .append (
219+ MappedField (field_id = add .field_id , names = [add .name ], fields = visit (add .field_type , _CreateMapping ()))
220+ )
221+
222+ reassignments = {f .name : f .field_id for f in fields_to_add }
223+ fields = [
224+ updated_field
225+ for field in mapped_fields
226+ if (updated_field := self ._remove_reassigned_names (field , reassignments )) is not None
227+ ] + new_fields
228+ return fields
229+ else :
230+ return mapped_fields
231+
232+ def mapping (self , nm : NameMapping , field_results : List [MappedField ]) -> List [MappedField ]:
233+ return self ._add_new_fields (field_results , - 1 )
234+
235+ def fields (self , struct : List [MappedField ], field_results : List [MappedField ]) -> List [MappedField ]:
236+ reassignments : Dict [str , int ] = {
237+ update .name : update .field_id for f in field_results if (update := self ._updates .get (f .field_id ))
238+ }
239+ return [
240+ updated_field
241+ for field in field_results
242+ if (updated_field := self ._remove_reassigned_names (field , reassignments )) is not None
243+ ]
244+
245+ def field (self , field : MappedField , field_result : List [MappedField ]) -> MappedField :
246+ field_names = field .names
247+ if (update := self ._updates .get (field .field_id )) is not None and update .name not in field_names :
248+ field_names .append (update .name )
249+
250+ return MappedField (field_id = field .field_id , names = field_names , fields = self ._add_new_fields (field_result , field .field_id ))
251+
252+
178253def create_mapping_from_schema (schema : Schema ) -> NameMapping :
179254 return NameMapping (visit (schema , _CreateMapping ()))
255+
256+
257+ def update_mapping (mapping : NameMapping , updates : Dict [int , NestedField ], adds : Dict [int , List [NestedField ]]) -> NameMapping :
258+ return NameMapping (visit_name_mapping (mapping , _UpdateMapping (updates , adds )))
0 commit comments