5555else :
5656 from typing_extensions import get_args , get_origin
5757
58- if sys .version_info >= (3 , 9 ):
59- from typing import Annotated
60- else :
61- from typing_extensions import Annotated
58+ from typing_extensions import Annotated , _AnnotatedAlias
6259
6360_T = TypeVar ("_T" )
6461NoArgAnyCallable = Callable [[], Any ]
@@ -167,7 +164,7 @@ def Field(
167164 unique : bool = False ,
168165 nullable : Union [bool , PydanticUndefinedType ] = PydanticUndefined ,
169166 index : Union [bool , PydanticUndefinedType ] = PydanticUndefined ,
170- sa_column : Union [Column , PydanticUndefinedType , types . FunctionType ] = PydanticUndefined , # type: ignore
167+ sa_column : Union [Column , PydanticUndefinedType , Callable [[], Column ] ] = PydanticUndefined , # type: ignore
171168 sa_column_args : Union [Sequence [Any ], PydanticUndefinedType ] = PydanticUndefined ,
172169 sa_column_kwargs : Union [
173170 Mapping [str , Any ], PydanticUndefinedType
@@ -440,17 +437,19 @@ def __init__(
440437 ModelMetaclass .__init__ (cls , classname , bases , dict_ , ** kw )
441438
442439
440+ def _is_optional_or_union (type_ : Optional [type ]) -> bool :
441+ if sys .version_info >= (3 , 10 ):
442+ return get_origin (type_ ) in (types .UnionType , Union )
443+ else :
444+ return get_origin (type_ ) is Union
445+
446+
443447def get_sqlalchemy_type (field : FieldInfo ) -> Any :
444- type_ : Optional [type ] = field .annotation
448+ type_ : Optional [type ] | _AnnotatedAlias = field .annotation
445449
446450 # Resolve Optional/Union fields
447- def is_optional_or_union (type_ : Optional [type ]) -> bool :
448- if sys .version_info >= (3 , 10 ):
449- return get_origin (type_ ) in (types .UnionType , Union )
450- else :
451- return get_origin (type_ ) is Union
452451
453- if type_ is not None and is_optional_or_union (type_ ):
452+ if type_ is not None and _is_optional_or_union (type_ ):
454453 bases = get_args (type_ )
455454 if len (bases ) > 2 :
456455 raise RuntimeError (
@@ -462,14 +461,20 @@ def is_optional_or_union(type_: Optional[type]) -> bool:
462461 # UrlConstraints(max_length=512,
463462 # allowed_schemes=['smb', 'ftp', 'file']) ]
464463 if type_ is pydantic .AnyUrl :
465- meta = field .metadata [0 ]
466- return AutoString (length = meta .max_length )
464+ if field .metadata :
465+ meta = field .metadata [0 ]
466+ return AutoString (length = meta .max_length )
467+ else :
468+ return AutoString
467469
468- if get_origin (type_ ) is Annotated :
470+ org_type = get_origin (type_ )
471+ if org_type is Annotated :
469472 type2 = get_args (type_ )[0 ]
470473 if type2 is pydantic .AnyUrl :
471474 meta = get_args (type_ )[1 ]
472475 return AutoString (length = meta .max_length )
476+ elif org_type is pydantic .AnyUrl and type (type_ ) is _AnnotatedAlias :
477+ return AutoString (type_ .__metadata__ [0 ].max_length )
473478
474479 # The 3rd is PydanticGeneralMetadata
475480 metadata = _get_field_metadata (field )
@@ -519,11 +524,18 @@ def is_optional_or_union(type_: Optional[type]) -> bool:
519524
520525
521526def get_column_from_field (field : FieldInfo ) -> Column : # type: ignore
527+ """
528+ sa_column > field attributes > annotation info
529+ """
522530 sa_column = getattr (field , "sa_column" , PydanticUndefined )
523531 if isinstance (sa_column , Column ):
524532 return sa_column
525533 if isinstance (sa_column , MappedColumn ):
526534 return sa_column .column
535+ if isinstance (sa_column , types .FunctionType ):
536+ col = sa_column ()
537+ assert isinstance (col , Column )
538+ return col
527539 sa_type = get_sqlalchemy_type (field )
528540 primary_key = getattr (field , "primary_key" , False )
529541 index = getattr (field , "index" , PydanticUndefined )
@@ -587,6 +599,10 @@ def __new__(cls, *args: Any, **kwargs: Any) -> Any:
587599 # in the Pydantic model so that when SQLAlchemy sets attributes that are
588600 # added (e.g. when querying from DB) to the __fields_set__, this already exists
589601 object .__setattr__ (new_object , "__pydantic_fields_set__" , set ())
602+ if not hasattr (new_object , "__pydantic_extra__" ):
603+ object .__setattr__ (new_object , "__pydantic_extra__" , None )
604+ if not hasattr (new_object , "__pydantic_private__" ):
605+ object .__setattr__ (new_object , "__pydantic_private__" , None )
590606 return new_object
591607
592608 def __init__ (__pydantic_self__ , ** data : Any ) -> None :
@@ -636,7 +652,10 @@ def model_validate(
636652 # remove defaults so they don't get validated
637653 data = {}
638654 for key , value in validated :
639- field = cls .model_fields [key ]
655+ field = cls .model_fields .get (key )
656+
657+ if field is None :
658+ continue
640659
641660 if (
642661 hasattr (field , "default" )
@@ -661,10 +680,11 @@ def _is_field_noneable(field: FieldInfo) -> bool:
661680 return False
662681 if field .annotation is None or field .annotation is NoneType :
663682 return True
664- if get_origin (field .annotation ) is Union :
683+ if _is_optional_or_union (field .annotation ):
665684 for base in get_args (field .annotation ):
666685 if base is NoneType :
667686 return True
687+
668688 return False
669689 return False
670690
0 commit comments