1+ from __future__ import annotations
2+
13import ipaddress
4+ import sys
5+ import types
26import uuid
37import weakref
48from datetime import date , datetime , time , timedelta
2226 TypeVar ,
2327 Union ,
2428 cast ,
25- get_args ,
26- get_origin ,
2729)
2830
29- from pydantic import BaseModel
31+ import pydantic
32+ from annotated_types import MaxLen
33+ from pydantic import BaseModel , EmailStr , ImportString , NameEmail
3034from pydantic ._internal ._fields import PydanticGeneralMetadata
3135from pydantic ._internal ._model_construction import ModelMetaclass
3236from pydantic ._internal ._repr import Representation
3943from sqlalchemy .orm .attributes import set_attribute
4044from sqlalchemy .orm .decl_api import DeclarativeMeta
4145from sqlalchemy .orm .instrumentation import is_instrumented
42- from sqlalchemy .sql .schema import MetaData
46+ from sqlalchemy .orm .properties import MappedColumn
47+ from sqlalchemy .sql import false , true
48+ from sqlalchemy .sql .schema import DefaultClause , MetaData
4349from sqlalchemy .sql .sqltypes import LargeBinary , Time
4450
4551from .sql .sqltypes import GUID , AutoString
4652from .typing import SQLModelConfig
4753
54+ if sys .version_info >= (3 , 8 ):
55+ from typing import get_args , get_origin
56+ else :
57+ from typing_extensions import get_args , get_origin
58+
59+ from typing_extensions import Annotated , _AnnotatedAlias
60+
4861_T = TypeVar ("_T" )
4962NoArgAnyCallable = Callable [[], Any ]
5063NoneType = type (None )
@@ -61,6 +74,8 @@ def __dataclass_transform__(
6174
6275
6376class FieldInfo (PydanticFieldInfo ):
77+ nullable : Union [bool , PydanticUndefinedType ]
78+
6479 def __init__ (self , default : Any = PydanticUndefined , ** kwargs : Any ) -> None :
6580 primary_key = kwargs .pop ("primary_key" , False )
6681 nullable = kwargs .pop ("nullable" , PydanticUndefined )
@@ -150,14 +165,40 @@ def Field(
150165 unique : bool = False ,
151166 nullable : Union [bool , PydanticUndefinedType ] = PydanticUndefined ,
152167 index : Union [bool , PydanticUndefinedType ] = PydanticUndefined ,
153- sa_column : Union [Column , PydanticUndefinedType ] = PydanticUndefined , # type: ignore
168+ sa_column : Union [Column , PydanticUndefinedType , Callable [[], Column ] ] = PydanticUndefined , # type: ignore
154169 sa_column_args : Union [Sequence [Any ], PydanticUndefinedType ] = PydanticUndefined ,
155170 sa_column_kwargs : Union [
156171 Mapping [str , Any ], PydanticUndefinedType
157172 ] = PydanticUndefined ,
158173 schema_extra : Optional [Dict [str , Any ]] = None ,
159174) -> Any :
160175 current_schema_extra = schema_extra or {}
176+ if default is PydanticUndefined :
177+ if isinstance (sa_column , types .FunctionType ): # lambda
178+ sa_column_ = sa_column ()
179+ else :
180+ sa_column_ = sa_column
181+
182+ # server_default -> default
183+ if isinstance (sa_column_ , Column ) and isinstance (
184+ sa_column_ .server_default , DefaultClause
185+ ):
186+ default_value = sa_column_ .server_default .arg
187+ if issubclass (type (sa_column_ .type ), Integer ) and isinstance (
188+ default_value , str
189+ ):
190+ default = int (default_value )
191+ elif issubclass (type (sa_column_ .type ), Boolean ):
192+ if default_value is false ():
193+ default = False
194+ elif default_value is true ():
195+ default = True
196+ elif isinstance (default_value , str ):
197+ if default_value == "1" :
198+ default = True
199+ elif default_value == "0" :
200+ default = False
201+
161202 field_info = FieldInfo (
162203 default ,
163204 default_factory = default_factory ,
@@ -236,7 +277,6 @@ def __new__(
236277 class_dict : Dict [str , Any ],
237278 ** kwargs : Any ,
238279 ) -> Any :
239-
240280 relationships : Dict [str , RelationshipInfo ] = {}
241281 dict_for_pydantic = {}
242282 original_annotations = class_dict .get ("__annotations__" , {})
@@ -398,23 +438,50 @@ def __init__(
398438 ModelMetaclass .__init__ (cls , classname , bases , dict_ , ** kw )
399439
400440
441+ def _is_optional_or_union (type_ : Optional [type ]) -> bool :
442+ if sys .version_info >= (3 , 10 ):
443+ return get_origin (type_ ) in (types .UnionType , Union )
444+ else :
445+ return get_origin (type_ ) is Union
446+
447+
401448def get_sqlalchemy_type (field : FieldInfo ) -> Any :
402- type_ : type | None = field .annotation
449+ type_ : Optional [type ] | _AnnotatedAlias = field .annotation
450+
451+ # Resolve Optional/Union fields
403452
404- # Resolve Optional fields
405- if type_ is not None and get_origin (type_ ) is Union :
453+ if type_ is not None and _is_optional_or_union (type_ ):
406454 bases = get_args (type_ )
407455 if len (bases ) > 2 :
408456 raise RuntimeError (
409457 "Cannot have a (non-optional) union as a SQL alchemy field"
410458 )
411459 type_ = bases [0 ]
460+ # Resolve Annoted fields,
461+ # like typing.Annotated[pydantic_core._pydantic_core.Url,
462+ # UrlConstraints(max_length=512,
463+ # allowed_schemes=['smb', 'ftp', 'file']) ]
464+ if type_ is pydantic .AnyUrl :
465+ if field .metadata :
466+ meta = field .metadata [0 ]
467+ return AutoString (length = meta .max_length )
468+ else :
469+ return AutoString
470+
471+ org_type = get_origin (type_ )
472+ if org_type is Annotated :
473+ type2 = get_args (type_ )[0 ]
474+ if type2 is pydantic .AnyUrl :
475+ meta = get_args (type_ )[1 ]
476+ return AutoString (length = meta .max_length )
477+ elif org_type is pydantic .AnyUrl and type (type_ ) is _AnnotatedAlias :
478+ return AutoString (type_ .__metadata__ [0 ].max_length )
412479
413480 # The 3rd is PydanticGeneralMetadata
414481 metadata = _get_field_metadata (field )
415482 if type_ is None :
416483 raise ValueError ("Missing field type" )
417- if issubclass (type_ , str ):
484+ if issubclass (type_ , str ) or type_ in ( EmailStr , NameEmail , ImportString ) :
418485 max_length = getattr (metadata , "max_length" , None )
419486 if max_length :
420487 return AutoString (length = max_length )
@@ -458,9 +525,18 @@ def get_sqlalchemy_type(field: FieldInfo) -> Any:
458525
459526
460527def get_column_from_field (field : FieldInfo ) -> Column : # type: ignore
528+ """
529+ sa_column > field attributes > annotation info
530+ """
461531 sa_column = getattr (field , "sa_column" , PydanticUndefined )
462532 if isinstance (sa_column , Column ):
463533 return sa_column
534+ if isinstance (sa_column , MappedColumn ):
535+ return sa_column .column
536+ if isinstance (sa_column , types .FunctionType ):
537+ col = sa_column ()
538+ assert isinstance (col , Column )
539+ return col
464540 sa_type = get_sqlalchemy_type (field )
465541 primary_key = getattr (field , "primary_key" , False )
466542 index = getattr (field , "index" , PydanticUndefined )
@@ -484,7 +560,7 @@ def get_column_from_field(field: FieldInfo) -> Column: # type: ignore
484560 "index" : index ,
485561 "unique" : unique ,
486562 }
487- sa_default : PydanticUndefinedType | Callable [[], Any ] = PydanticUndefined
563+ sa_default : Union [ PydanticUndefinedType , Callable [[], Any ] ] = PydanticUndefined
488564 if field .default_factory :
489565 sa_default = field .default_factory
490566 elif field .default is not PydanticUndefined :
@@ -524,12 +600,16 @@ def __new__(cls, *args: Any, **kwargs: Any) -> Any:
524600 # in the Pydantic model so that when SQLAlchemy sets attributes that are
525601 # added (e.g. when querying from DB) to the __fields_set__, this already exists
526602 object .__setattr__ (new_object , "__pydantic_fields_set__" , set ())
603+ if not hasattr (new_object , "__pydantic_extra__" ):
604+ object .__setattr__ (new_object , "__pydantic_extra__" , None )
605+ if not hasattr (new_object , "__pydantic_private__" ):
606+ object .__setattr__ (new_object , "__pydantic_private__" , None )
527607 return new_object
528608
529609 def __init__ (__pydantic_self__ , ** data : Any ) -> None :
530610 old_dict = __pydantic_self__ .__dict__ .copy ()
531611 super ().__init__ (** data )
532- __pydantic_self__ .__dict__ = old_dict | __pydantic_self__ .__dict__
612+ __pydantic_self__ .__dict__ = { ** old_dict , ** __pydantic_self__ .__dict__ }
533613 non_pydantic_keys = data .keys () - __pydantic_self__ .model_fields
534614 for key in non_pydantic_keys :
535615 if key in __pydantic_self__ .__sqlmodel_relationships__ :
@@ -558,33 +638,54 @@ def __tablename__(cls) -> str:
558638
559639 @classmethod
560640 def model_validate (
561- cls : type [_TSQLModel ],
641+ cls : Type [_TSQLModel ],
562642 obj : Any ,
563643 * ,
564- strict : bool | None = None ,
565- from_attributes : bool | None = None ,
566- context : dict [ str , Any ] | None = None ,
644+ strict : Optional [ bool ] = None ,
645+ from_attributes : Optional [ bool ] = None ,
646+ context : Optional [ Dict [ str , Any ]] = None ,
567647 ) -> _TSQLModel :
568648 # Somehow model validate doesn't call __init__ so it would remove our init logic
569649 validated = super ().model_validate (
570650 obj , strict = strict , from_attributes = from_attributes , context = context
571651 )
572- return cls (** {key : value for key , value in validated })
652+
653+ # remove defaults so they don't get validated
654+ data = {}
655+ for key , value in validated :
656+ field = cls .model_fields .get (key )
657+
658+ if field is None :
659+ continue
660+
661+ if (
662+ hasattr (field , "default" )
663+ and field .default is not PydanticUndefined
664+ and value == field .default
665+ ):
666+ continue
667+
668+ data [key ] = value
669+
670+ return cls (** data )
573671
574672
575673def _is_field_noneable (field : FieldInfo ) -> bool :
576- if getattr (field , "nullable" , PydanticUndefined ) is not PydanticUndefined :
674+ if hasattr (field , "nullable" ) and not isinstance (
675+ field .nullable , PydanticUndefinedType
676+ ):
577677 return field .nullable
578678 if not field .is_required ():
579679 default = getattr (field , "original_default" , field .default )
580680 if default is PydanticUndefined :
581681 return False
582682 if field .annotation is None or field .annotation is NoneType :
583683 return True
584- if get_origin (field .annotation ) is Union :
684+ if _is_optional_or_union (field .annotation ):
585685 for base in get_args (field .annotation ):
586686 if base is NoneType :
587687 return True
688+
588689 return False
589690 return False
590691
@@ -593,4 +694,6 @@ def _get_field_metadata(field: FieldInfo) -> object:
593694 for meta in field .metadata :
594695 if isinstance (meta , PydanticGeneralMetadata ):
595696 return meta
697+ if isinstance (meta , MaxLen ):
698+ return meta
596699 return object ()
0 commit comments