6969)
7070from pyiceberg .partitioning import PartitionSpec
7171from pyiceberg .schema import (
72+ PartnerAccessor ,
7273 Schema ,
7374 SchemaVisitor ,
75+ SchemaWithPartnerVisitor ,
7476 assign_fresh_schema_ids ,
7577 promote ,
7678 visit ,
79+ visit_with_partner ,
7780)
7881from pyiceberg .table .metadata import (
7982 INITIAL_SEQUENCE_NUMBER ,
@@ -1379,7 +1382,7 @@ class Move:
13791382
13801383
13811384class UpdateSchema :
1382- _table : Table
1385+ _table : Optional [ Table ]
13831386 _schema : Schema
13841387 _last_column_id : itertools .count [int ]
13851388 _identifier_field_names : Set [str ]
@@ -1398,14 +1401,23 @@ class UpdateSchema:
13981401
13991402 def __init__ (
14001403 self ,
1401- table : Table ,
1404+ table : Optional [ Table ] ,
14021405 transaction : Optional [Transaction ] = None ,
14031406 allow_incompatible_changes : bool = False ,
14041407 case_sensitive : bool = True ,
1408+ schema : Optional [Schema ] = None ,
14051409 ) -> None :
14061410 self ._table = table
1407- self ._schema = table .schema ()
1408- self ._last_column_id = itertools .count (table .metadata .last_column_id + 1 )
1411+
1412+ if isinstance (schema , Schema ):
1413+ self ._schema = schema
1414+ self ._last_column_id = itertools .count (1 + schema .highest_field_id )
1415+ elif table is not None :
1416+ self ._schema = table .schema ()
1417+ self ._last_column_id = itertools .count (1 + table .metadata .last_column_id )
1418+ else :
1419+ raise ValueError ("Either provide a table or a schema" )
1420+
14091421 self ._identifier_field_names = self ._schema .identifier_field_names ()
14101422
14111423 self ._adds = {}
@@ -1449,6 +1461,15 @@ def case_sensitive(self, case_sensitive: bool) -> UpdateSchema:
14491461 self ._case_sensitive = case_sensitive
14501462 return self
14511463
1464+ def union_by_name (self , new_schema : Schema ) -> UpdateSchema :
1465+ visit_with_partner (
1466+ new_schema ,
1467+ - 1 ,
1468+ UnionByNameVisitor (update_schema = self , existing_schema = self ._schema , case_sensitive = self ._case_sensitive ), # type: ignore
1469+ PartnerIdByNameAccessor (partner_schema = self ._schema , case_sensitive = self ._case_sensitive ),
1470+ )
1471+ return self
1472+
14521473 def add_column (
14531474 self , path : Union [str , Tuple [str , ...]], field_type : IcebergType , doc : Optional [str ] = None , required : bool = False
14541475 ) -> UpdateSchema :
@@ -1816,6 +1837,9 @@ def move_after(self, path: Union[str, Tuple[str, ...]], after_name: Union[str, T
18161837
18171838 def commit (self ) -> None :
18181839 """Apply the pending changes and commit."""
1840+ if self ._table is None :
1841+ raise ValueError ("Requires a table to commit to" )
1842+
18191843 new_schema = self ._apply ()
18201844
18211845 existing_schema_id = next ((schema .schema_id for schema in self ._table .metadata .schemas if schema == new_schema ), None )
@@ -1862,7 +1886,8 @@ def _apply(self) -> Schema:
18621886
18631887 field_ids .add (field .field_id )
18641888
1865- return Schema (* struct .fields , schema_id = 1 + max (self ._table .schemas ().keys ()), identifier_field_ids = field_ids )
1889+ next_schema_id = 1 + (max (self ._table .schemas ().keys ()) if self ._table is not None else self ._schema .schema_id )
1890+ return Schema (* struct .fields , schema_id = next_schema_id , identifier_field_ids = field_ids )
18661891
18671892 def assign_new_column_id (self ) -> int :
18681893 return next (self ._last_column_id )
@@ -1995,6 +2020,159 @@ def primitive(self, primitive: PrimitiveType) -> Optional[IcebergType]:
19952020 return primitive
19962021
19972022
2023+ class UnionByNameVisitor (SchemaWithPartnerVisitor [int , bool ]):
2024+ update_schema : UpdateSchema
2025+ existing_schema : Schema
2026+ case_sensitive : bool
2027+
2028+ def __init__ (self , update_schema : UpdateSchema , existing_schema : Schema , case_sensitive : bool ) -> None :
2029+ self .update_schema = update_schema
2030+ self .existing_schema = existing_schema
2031+ self .case_sensitive = case_sensitive
2032+
2033+ def schema (self , schema : Schema , partner_id : Optional [int ], struct_result : bool ) -> bool :
2034+ return struct_result
2035+
2036+ def struct (self , struct : StructType , partner_id : Optional [int ], missing_positions : List [bool ]) -> bool :
2037+ if partner_id is None :
2038+ return True
2039+
2040+ fields = struct .fields
2041+ partner_struct = self ._find_field_type (partner_id )
2042+
2043+ if not partner_struct .is_struct :
2044+ raise ValueError (f"Expected a struct, got: { partner_struct } " )
2045+
2046+ for pos , missing in enumerate (missing_positions ):
2047+ if missing :
2048+ self ._add_column (partner_id , fields [pos ])
2049+ else :
2050+ field = fields [pos ]
2051+ if nested_field := partner_struct .field_by_name (field .name , case_sensitive = self .case_sensitive ):
2052+ self ._update_column (field , nested_field )
2053+
2054+ return False
2055+
2056+ def _add_column (self , parent_id : int , field : NestedField ) -> None :
2057+ if parent_name := self .existing_schema .find_column_name (parent_id ):
2058+ path : Tuple [str , ...] = (parent_name , field .name )
2059+ else :
2060+ path = (field .name ,)
2061+
2062+ self .update_schema .add_column (path = path , field_type = field .field_type , required = field .required , doc = field .doc )
2063+
2064+ def _update_column (self , field : NestedField , existing_field : NestedField ) -> None :
2065+ full_name = self .existing_schema .find_column_name (existing_field .field_id )
2066+
2067+ if full_name is None :
2068+ raise ValueError (f"Could not find field: { existing_field } " )
2069+
2070+ if field .optional and existing_field .required :
2071+ self .update_schema .make_column_optional (full_name )
2072+
2073+ if field .field_type .is_primitive and field .field_type != existing_field .field_type :
2074+ self .update_schema .update_column (full_name , field_type = field .field_type )
2075+
2076+ if field .doc is not None and not field .doc != existing_field .doc :
2077+ self .update_schema .update_column (full_name , doc = field .doc )
2078+
2079+ def _find_field_type (self , field_id : int ) -> IcebergType :
2080+ if field_id == - 1 :
2081+ return self .existing_schema .as_struct ()
2082+ else :
2083+ return self .existing_schema .find_field (field_id ).field_type
2084+
2085+ def field (self , field : NestedField , partner_id : Optional [int ], field_result : bool ) -> bool :
2086+ return partner_id is None
2087+
2088+ def list (self , list_type : ListType , list_partner_id : Optional [int ], element_missing : bool ) -> bool :
2089+ if list_partner_id is None :
2090+ return True
2091+
2092+ if element_missing :
2093+ raise ValueError ("Error traversing schemas: element is missing, but list is present" )
2094+
2095+ partner_list_type = self ._find_field_type (list_partner_id )
2096+ if not isinstance (partner_list_type , ListType ):
2097+ raise ValueError (f"Expected list-type, got: { partner_list_type } " )
2098+
2099+ self ._update_column (list_type .element_field , partner_list_type .element_field )
2100+
2101+ return False
2102+
2103+ def map (self , map_type : MapType , map_partner_id : Optional [int ], key_missing : bool , value_missing : bool ) -> bool :
2104+ if map_partner_id is None :
2105+ return True
2106+
2107+ if key_missing :
2108+ raise ValueError ("Error traversing schemas: key is missing, but map is present" )
2109+
2110+ if value_missing :
2111+ raise ValueError ("Error traversing schemas: value is missing, but map is present" )
2112+
2113+ partner_map_type = self ._find_field_type (map_partner_id )
2114+ if not isinstance (partner_map_type , MapType ):
2115+ raise ValueError (f"Expected map-type, got: { partner_map_type } " )
2116+
2117+ self ._update_column (map_type .key_field , partner_map_type .key_field )
2118+ self ._update_column (map_type .value_field , partner_map_type .value_field )
2119+
2120+ return False
2121+
2122+ def primitive (self , primitive : PrimitiveType , primitive_partner_id : Optional [int ]) -> bool :
2123+ return primitive_partner_id is None
2124+
2125+
2126+ class PartnerIdByNameAccessor (PartnerAccessor [int ]):
2127+ partner_schema : Schema
2128+ case_sensitive : bool
2129+
2130+ def __init__ (self , partner_schema : Schema , case_sensitive : bool ) -> None :
2131+ self .partner_schema = partner_schema
2132+ self .case_sensitive = case_sensitive
2133+
2134+ def schema_partner (self , partner : Optional [int ]) -> Optional [int ]:
2135+ return - 1
2136+
2137+ def field_partner (self , partner_field_id : Optional [int ], field_id : int , field_name : str ) -> Optional [int ]:
2138+ if partner_field_id is not None :
2139+ if partner_field_id == - 1 :
2140+ struct = self .partner_schema .as_struct ()
2141+ else :
2142+ struct = self .partner_schema .find_field (partner_field_id ).field_type
2143+ if not struct .is_struct :
2144+ raise ValueError (f"Expected StructType: { struct } " )
2145+
2146+ if field := struct .field_by_name (name = field_name , case_sensitive = self .case_sensitive ):
2147+ return field .field_id
2148+
2149+ return None
2150+
2151+ def list_element_partner (self , partner_list_id : Optional [int ]) -> Optional [int ]:
2152+ if partner_list_id is not None and (field := self .partner_schema .find_field (partner_list_id )):
2153+ if not isinstance (field .field_type , ListType ):
2154+ raise ValueError (f"Expected ListType: { field } " )
2155+ return field .field_type .element_field .field_id
2156+ else :
2157+ return None
2158+
2159+ def map_key_partner (self , partner_map_id : Optional [int ]) -> Optional [int ]:
2160+ if partner_map_id is not None and (field := self .partner_schema .find_field (partner_map_id )):
2161+ if not isinstance (field .field_type , MapType ):
2162+ raise ValueError (f"Expected MapType: { field } " )
2163+ return field .field_type .key_field .field_id
2164+ else :
2165+ return None
2166+
2167+ def map_value_partner (self , partner_map_id : Optional [int ]) -> Optional [int ]:
2168+ if partner_map_id is not None and (field := self .partner_schema .find_field (partner_map_id )):
2169+ if not isinstance (field .field_type , MapType ):
2170+ raise ValueError (f"Expected MapType: { field } " )
2171+ return field .field_type .value_field .field_id
2172+ else :
2173+ return None
2174+
2175+
19982176def _add_fields (fields : Tuple [NestedField , ...], adds : Optional [List [NestedField ]]) -> Tuple [NestedField , ...]:
19992177 adds = adds or []
20002178 return fields + tuple (adds )
0 commit comments