2020from copy import copy
2121from dataclasses import dataclass
2222from enum import Enum
23- from typing import TYPE_CHECKING , Dict , List , Optional , Set , Tuple , Union
23+ from typing import TYPE_CHECKING , Any , Dict , List , Optional , Set , Tuple , Union
2424
2525from pyiceberg .exceptions import ResolveError , ValidationError
26+ from pyiceberg .expressions import literal # type: ignore
2627from pyiceberg .schema import (
2728 PartnerAccessor ,
2829 Schema ,
4748 UpdatesAndRequirements ,
4849 UpdateTableMetadata ,
4950)
51+ from pyiceberg .typedef import L
5052from pyiceberg .types import IcebergType , ListType , MapType , NestedField , PrimitiveType , StructType
5153
5254if TYPE_CHECKING :
@@ -153,7 +155,12 @@ def union_by_name(self, new_schema: Union[Schema, "pa.Schema"]) -> UpdateSchema:
153155 return self
154156
155157 def add_column (
156- self , path : Union [str , Tuple [str , ...]], field_type : IcebergType , doc : Optional [str ] = None , required : bool = False
158+ self ,
159+ path : Union [str , Tuple [str , ...]],
160+ field_type : IcebergType ,
161+ doc : Optional [str ] = None ,
162+ required : bool = False ,
163+ default_value : Optional [L ] = None ,
157164 ) -> UpdateSchema :
158165 """Add a new column to a nested struct or Add a new top-level column.
159166
@@ -168,6 +175,7 @@ def add_column(
168175 field_type: Type for the new column.
169176 doc: Documentation string for the new column.
170177 required: Whether the new column is required.
178+ default_value: Default value for the new column.
171179
172180 Returns:
173181 This for method chaining.
@@ -177,10 +185,6 @@ def add_column(
177185 raise ValueError (f"Cannot add column with ambiguous name: { path } , provide a tuple instead" )
178186 path = (path ,)
179187
180- if required and not self ._allow_incompatible_changes :
181- # Table format version 1 and 2 cannot add required column because there is no initial value
182- raise ValueError (f"Incompatible change: cannot add required column: { '.' .join (path )} " )
183-
184188 name = path [- 1 ]
185189 parent = path [:- 1 ]
186190
@@ -212,13 +216,34 @@ def add_column(
212216
213217 # assign new IDs in order
214218 new_id = self .assign_new_column_id ()
219+ new_type = assign_fresh_schema_ids (field_type , self .assign_new_column_id )
220+
221+ if default_value is not None :
222+ try :
223+ # To make sure that the value is valid for the type
224+ initial_default = literal (default_value ).to (new_type ).value
225+ except ValueError as e :
226+ raise ValueError (f"Invalid default value: { e } " ) from e
227+ else :
228+ initial_default = default_value # type: ignore
229+
230+ if (required and initial_default is None ) and not self ._allow_incompatible_changes :
231+ # Table format version 1 and 2 cannot add required column because there is no initial value
232+ raise ValueError (f"Incompatible change: cannot add required column: { '.' .join (path )} " )
215233
216234 # update tracking for moves
217235 self ._added_name_to_id [full_name ] = new_id
218236 self ._id_to_parent [new_id ] = parent_full_path
219237
220- new_type = assign_fresh_schema_ids (field_type , self .assign_new_column_id )
221- field = NestedField (field_id = new_id , name = name , field_type = new_type , required = required , doc = doc )
238+ field = NestedField (
239+ field_id = new_id ,
240+ name = name ,
241+ field_type = new_type ,
242+ required = required ,
243+ doc = doc ,
244+ initial_default = initial_default ,
245+ write_default = initial_default ,
246+ )
222247
223248 if parent_id in self ._adds :
224249 self ._adds [parent_id ].append (field )
@@ -250,6 +275,19 @@ def delete_column(self, path: Union[str, Tuple[str, ...]]) -> UpdateSchema:
250275
251276 return self
252277
def set_default_value(self, path: Union[str, Tuple[str, ...]], default_value: Optional[L]) -> UpdateSchema:
    """Set the default value of a column.

    Args:
        path: The path to the column.
        default_value: The new default value for the column, or None.

    Returns:
        The UpdateSchema with the default value change staged.
    """
    # All validation and staging is done by the private helper; this wrapper
    # only returns self so calls can be chained.
    self._set_column_default_value(path, default_value)

    return self
290+
253291 def rename_column (self , path_from : Union [str , Tuple [str , ...]], new_name : str ) -> UpdateSchema :
254292 """Update the name of a column.
255293
@@ -273,6 +311,8 @@ def rename_column(self, path_from: Union[str, Tuple[str, ...]], new_name: str) -
273311 field_type = updated .field_type ,
274312 doc = updated .doc ,
275313 required = updated .required ,
314+ initial_default = updated .initial_default ,
315+ write_default = updated .write_default ,
276316 )
277317 else :
278318 self ._updates [field_from .field_id ] = NestedField (
@@ -281,6 +321,8 @@ def rename_column(self, path_from: Union[str, Tuple[str, ...]], new_name: str) -
281321 field_type = field_from .field_type ,
282322 doc = field_from .doc ,
283323 required = field_from .required ,
324+ initial_default = field_from .initial_default ,
325+ write_default = field_from .write_default ,
284326 )
285327
286328 # Lookup the field because of casing
@@ -330,6 +372,8 @@ def _set_column_requirement(self, path: Union[str, Tuple[str, ...]], required: b
330372 field_type = updated .field_type ,
331373 doc = updated .doc ,
332374 required = required ,
375+ initial_default = updated .initial_default ,
376+ write_default = updated .write_default ,
333377 )
334378 else :
335379 self ._updates [field .field_id ] = NestedField (
@@ -338,6 +382,52 @@ def _set_column_requirement(self, path: Union[str, Tuple[str, ...]], required: b
338382 field_type = field .field_type ,
339383 doc = field .doc ,
340384 required = required ,
385+ initial_default = field .initial_default ,
386+ write_default = field .write_default ,
387+ )
388+
def _set_column_default_value(self, path: Union[str, Tuple[str, ...]], default_value: Any) -> None:
    """Stage a change of the write default of an existing column.

    Only ``write_default`` is replaced; ``initial_default`` and every other
    attribute are carried over from the already staged update for the column
    (if there is one) or from the field found in the current schema.

    Args:
        path: Column name, or a tuple of name parts that is joined with "." for the lookup.
        default_value: The new write default. A non-None value is passed through
            ``literal(...).to(field_type)`` and the resulting value is stored;
            None is stored as-is.

    Raises:
        ValueError: If the value cannot be converted to the column's type, if a
            required column would get a None default while incompatible changes
            are not allowed, or if the column is staged for deletion.
    """
    path = (path,) if isinstance(path, str) else path
    name = ".".join(path)

    field = self._schema.find_field(name, self._case_sensitive)

    if default_value is not None:
        try:
            # To make sure that the value is valid for the type
            default_value = literal(default_value).to(field.field_type).value
        except ValueError as e:
            raise ValueError(f"Invalid default value: {e}") from e

    if field.required and default_value == field.write_default:
        # If the change is a noop, allow it even if incompatible changes are not allowed
        return

    if not self._allow_incompatible_changes and field.required and default_value is None:
        raise ValueError("Cannot change default-value of a required column to None")

    if field.field_id in self._deletes:
        raise ValueError(f"Cannot update a column that will be deleted: {name}")

    # Build on an update that is already staged for this column so that earlier
    # staged changes (name, type, doc, required) are not lost.
    base = self._updates.get(field.field_id) or field
    self._updates[field.field_id] = NestedField(
        field_id=base.field_id,
        name=base.name,
        field_type=base.field_type,
        doc=base.doc,
        required=base.required,
        initial_default=base.initial_default,
        write_default=default_value,
    )
342432
343433 def update_column (
@@ -387,6 +477,8 @@ def update_column(
387477 field_type = field_type or updated .field_type ,
388478 doc = doc if doc is not None else updated .doc ,
389479 required = updated .required ,
480+ initial_default = updated .initial_default ,
481+ write_default = updated .write_default ,
390482 )
391483 else :
392484 self ._updates [field .field_id ] = NestedField (
@@ -395,6 +487,8 @@ def update_column(
395487 field_type = field_type or field .field_type ,
396488 doc = doc if doc is not None else field .doc ,
397489 required = field .required ,
490+ initial_default = field .initial_default ,
491+ write_default = field .write_default ,
398492 )
399493
400494 if required is not None :
@@ -636,19 +730,35 @@ def struct(self, struct: StructType, field_results: List[Optional[IcebergType]])
636730 name = field .name
637731 doc = field .doc
638732 required = field .required
733+ write_default = field .write_default
639734
640735 # There is an update
641736 if update := self ._updates .get (field .field_id ):
642737 name = update .name
643738 doc = update .doc
644739 required = update .required
645-
646- if field .name == name and field .field_type == result_type and field .required == required and field .doc == doc :
740+ write_default = update .write_default
741+
742+ if (
743+ field .name == name
744+ and field .field_type == result_type
745+ and field .required == required
746+ and field .doc == doc
747+ and field .write_default == write_default
748+ ):
647749 new_fields .append (field )
648750 else :
649751 has_changes = True
650752 new_fields .append (
651- NestedField (field_id = field .field_id , name = name , field_type = result_type , required = required , doc = doc )
753+ NestedField (
754+ field_id = field .field_id ,
755+ name = name ,
756+ field_type = result_type ,
757+ required = required ,
758+ doc = doc ,
759+ initial_default = field .initial_default ,
760+ write_default = write_default ,
761+ )
652762 )
653763
654764 if has_changes :
0 commit comments