1+ import ipaddress
2+ import uuid
3+ from datetime import date , datetime , time , timedelta
4+ from decimal import Decimal
5+ from enum import Enum
6+ from pathlib import Path
17from types import NoneType
28from typing import (
39 TYPE_CHECKING ,
612 Dict ,
713 ForwardRef ,
814 Optional ,
15+ Sequence ,
916 Type ,
1017 TypeVar ,
1118 Union ,
19+ cast ,
1220 get_args ,
1321 get_origin ,
1422)
1523
1624from pydantic import VERSION as PYDANTIC_VERSION
25+ from sqlalchemy import (
26+ Boolean ,
27+ Column ,
28+ Date ,
29+ DateTime ,
30+ Float ,
31+ ForeignKey ,
32+ Integer ,
33+ Interval ,
34+ Numeric ,
35+ )
36+ from sqlalchemy import Enum as sa_Enum
37+ from sqlalchemy .sql .sqltypes import LargeBinary , Time
38+
39+ from .sql .sqltypes import GUID , AutoString
1740
1841IS_PYDANTIC_V2 = int (PYDANTIC_VERSION .split ("." )[0 ]) >= 2
1942
2043
2144if IS_PYDANTIC_V2 :
2245 from pydantic import ConfigDict as PydanticModelConfig
46+ from pydantic ._internal ._fields import PydanticMetadata
47+ from pydantic ._internal ._model_construction import ModelMetaclass
2348 from pydantic_core import PydanticUndefined as PydanticUndefined # noqa
2449 from pydantic_core import PydanticUndefinedType as PydanticUndefinedType
2550else :
2651 from pydantic import BaseConfig as PydanticModelConfig
27- from pydantic .fields import ModelField # noqa
28- from pydantic .fields import Undefined as PydanticUndefined , UndefinedType as PydanticUndefinedType , SHAPE_SINGLETON # noqa
52+ from pydantic .fields import SHAPE_SINGLETON , ModelField
53+ from pydantic .fields import Undefined as PydanticUndefined # noqa
54+ from pydantic .fields import UndefinedType as PydanticUndefinedType
55+ from pydantic .main import ModelMetaclass as ModelMetaclass
2956 from pydantic .typing import resolve_annotations
3057
3158if TYPE_CHECKING :
3764InstanceOrType = Union [T , Type [T ]]
3865
3966if IS_PYDANTIC_V2 :
67+
4068 class SQLModelConfig (PydanticModelConfig , total = False ):
4169 table : Optional [bool ]
4270 registry : Optional [Any ]
4371
4472else :
73+
4574 class SQLModelConfig (PydanticModelConfig ):
4675 table : Optional [bool ] = None
4776 registry : Optional [Any ] = None
@@ -78,14 +107,14 @@ def set_config_value(
78107
79108def get_model_fields (model : InstanceOrType ["SQLModel" ]) -> Dict [str , "FieldInfo" ]:
80109 if IS_PYDANTIC_V2 :
81- return model .model_fields # type: ignore
110+ return model .model_fields # type: ignore
82111 else :
83112 return model .__fields__ # type: ignore
84113
85114
86115def get_fields_set (model : InstanceOrType ["SQLModel" ]) -> set [str ]:
87116 if IS_PYDANTIC_V2 :
88- return model .__pydantic_fields_set__ # type: ignore
117+ return model .__pydantic_fields_set__ # type: ignore
89118 else :
90119 return model .__fields_set__ # type: ignore
91120
@@ -115,21 +144,36 @@ def get_annotations(class_dict: dict[str, Any]) -> dict[str, Any]:
115144 )
116145
117146
118- def is_table (class_dict : dict [str , Any ]) -> bool :
147+ def class_dict_is_table (
148+ class_dict : dict [str , Any ], class_kwargs : dict [str , Any ]
149+ ) -> bool :
119150 config : SQLModelConfig = {}
120151 if IS_PYDANTIC_V2 :
121152 config = class_dict .get ("model_config" , {})
122153 else :
123154 config = class_dict .get ("__config__" , {})
124155 config_table = config .get ("table" , PydanticUndefined )
125156 if config_table is not PydanticUndefined :
126- return config_table # type: ignore
127- kw_table = class_dict .get ("table" , PydanticUndefined )
157+ return config_table # type: ignore
158+ kw_table = class_kwargs .get ("table" , PydanticUndefined )
128159 if kw_table is not PydanticUndefined :
129- return kw_table # type: ignore
160+ return kw_table # type: ignore
130161 return False
131162
132163
164+ def cls_is_table (cls : Type ) -> bool :
165+ if IS_PYDANTIC_V2 :
166+ config = getattr (cls , "model_config" , None )
167+ if not config :
168+ return False
169+ return config .get ("table" , False )
170+ else :
171+ config = getattr (cls , "__config__" , None )
172+ if not config :
173+ return False
174+ return getattr (config , "table" , False )
175+
176+
133177def get_relationship_to (
134178 name : str ,
135179 rel_info : "RelationshipInfo" ,
@@ -186,17 +230,15 @@ def set_empty_defaults(annotations: Dict[str, Any], class_dict: Dict[str, Any])
186230 value .default in (PydanticUndefined , Ellipsis )
187231 ) and value .default_factory is None :
188232 # So we can check for nullable
189- value .original_default = value .default
190233 value .default = None
191234
192235
193- def is_field_noneable (field : "FieldInfo" ) -> bool :
236+ def _is_field_noneable (field : "FieldInfo" ) -> bool :
194237 if IS_PYDANTIC_V2 :
195238 if getattr (field , "nullable" , PydanticUndefined ) is not PydanticUndefined :
196- return field .nullable # type: ignore
239+ return field .nullable # type: ignore
197240 if not field .is_required ():
198- default = getattr (field , "original_default" , field .default )
199- if default is PydanticUndefined :
241+ if field .default is PydanticUndefined :
200242 return False
201243 if field .annotation is None or field .annotation is NoneType :
202244 return True
@@ -212,4 +254,163 @@ def is_field_noneable(field: "FieldInfo") -> bool:
212254 return field .allow_none and (
213255 field .shape != SHAPE_SINGLETON or not field .sub_fields
214256 )
215- return False
257+ return field .allow_none
258+
259+
260+ def get_sqlalchemy_type (field : Any ) -> Any :
261+ if IS_PYDANTIC_V2 :
262+ field_info = field
263+ else :
264+ field_info = field .field_info
265+ sa_type = getattr (field_info , "sa_type" , PydanticUndefined ) # noqa: B009
266+ if sa_type is not PydanticUndefined :
267+ return sa_type
268+
269+ type_ = get_type_from_field (field )
270+ metadata = get_field_metadata (field )
271+
272+ # Check enums first as an enum can also be a str, needed by Pydantic/FastAPI
273+ if issubclass (type_ , Enum ):
274+ return sa_Enum (type_ )
275+ if issubclass (type_ , str ):
276+ max_length = getattr (metadata , "max_length" , None )
277+ if max_length :
278+ return AutoString (length = max_length )
279+ return AutoString
280+ if issubclass (type_ , float ):
281+ return Float
282+ if issubclass (type_ , bool ):
283+ return Boolean
284+ if issubclass (type_ , int ):
285+ return Integer
286+ if issubclass (type_ , datetime ):
287+ return DateTime
288+ if issubclass (type_ , date ):
289+ return Date
290+ if issubclass (type_ , timedelta ):
291+ return Interval
292+ if issubclass (type_ , time ):
293+ return Time
294+ if issubclass (type_ , bytes ):
295+ return LargeBinary
296+ if issubclass (type_ , Decimal ):
297+ return Numeric (
298+ precision = getattr (metadata , "max_digits" , None ),
299+ scale = getattr (metadata , "decimal_places" , None ),
300+ )
301+ if issubclass (type_ , ipaddress .IPv4Address ):
302+ return AutoString
303+ if issubclass (type_ , ipaddress .IPv4Network ):
304+ return AutoString
305+ if issubclass (type_ , ipaddress .IPv6Address ):
306+ return AutoString
307+ if issubclass (type_ , ipaddress .IPv6Network ):
308+ return AutoString
309+ if issubclass (type_ , Path ):
310+ return AutoString
311+ if issubclass (type_ , uuid .UUID ):
312+ return GUID
313+ raise ValueError (f"{ type_ } has no matching SQLAlchemy type" )
314+
315+
316+ def get_type_from_field (field : Any ) -> type :
317+ if IS_PYDANTIC_V2 :
318+ type_ : type | None = field .annotation
319+ # Resolve Optional fields
320+ if type_ is None :
321+ raise ValueError ("Missing field type" )
322+ origin = get_origin (type_ )
323+ if origin is None :
324+ return type_
325+ if origin is Union :
326+ bases = get_args (type_ )
327+ if len (bases ) > 2 :
328+ raise ValueError (
329+ "Cannot have a (non-optional) union as a SQL alchemy field"
330+ )
331+ # Non optional unions are not allowed
332+ if bases [0 ] is not NoneType and bases [1 ] is not NoneType :
333+ raise ValueError (
334+ "Cannot have a (non-optional) union as a SQL alchemy field"
335+ )
336+ # Optional unions are allowed
337+ return bases [0 ] if bases [0 ] is not NoneType else bases [1 ]
338+ return origin
339+ else :
340+ if isinstance (field .type_ , type ) and field .shape == SHAPE_SINGLETON :
341+ return field .type_
342+ raise ValueError (f"The field { field .name } has no matching SQLAlchemy type" )
343+
344+
345+ class FakeMetadata :
346+ max_length : Optional [int ] = None
347+ max_digits : Optional [int ] = None
348+ decimal_places : Optional [int ] = None
349+
350+
351+ def get_field_metadata (field : Any ) -> Any :
352+ if IS_PYDANTIC_V2 :
353+ for meta in field .metadata :
354+ if isinstance (meta , PydanticMetadata ):
355+ return meta
356+ return FakeMetadata ()
357+ else :
358+ metadata = FakeMetadata ()
359+ metadata .max_length = field .field_info .max_length
360+ metadata .max_digits = getattr (field .type_ , "max_digits" , None )
361+ metadata .decimal_places = getattr (field .type_ , "decimal_places" , None )
362+ return metadata
363+
364+
365+ def get_column_from_field (field : Any ) -> Column : # type: ignore
366+ if IS_PYDANTIC_V2 :
367+ field_info = field
368+ else :
369+ field_info = field .field_info
370+ sa_column = getattr (field_info , "sa_column" , PydanticUndefined )
371+ if isinstance (sa_column , Column ):
372+ return sa_column
373+ sa_type = get_sqlalchemy_type (field )
374+ primary_key = getattr (field_info , "primary_key" , PydanticUndefined )
375+ if primary_key is PydanticUndefined :
376+ primary_key = False
377+ index = getattr (field_info , "index" , PydanticUndefined )
378+ if index is PydanticUndefined :
379+ index = False
380+ nullable = not primary_key and _is_field_noneable (field )
381+ # Override derived nullability if the nullable property is set explicitly
382+ # on the field
383+ field_nullable = getattr (field_info , "nullable" , PydanticUndefined ) # noqa: B009
384+ if field_nullable is not PydanticUndefined :
385+ assert not isinstance (field_nullable , PydanticUndefinedType )
386+ nullable = field_nullable
387+ args = []
388+ foreign_key = getattr (field_info , "foreign_key" , PydanticUndefined )
389+ if foreign_key is PydanticUndefined :
390+ foreign_key = None
391+ unique = getattr (field_info , "unique" , PydanticUndefined )
392+ if unique is PydanticUndefined :
393+ unique = False
394+ if foreign_key :
395+ assert isinstance (foreign_key , str )
396+ args .append (ForeignKey (foreign_key ))
397+ kwargs = {
398+ "primary_key" : primary_key ,
399+ "nullable" : nullable ,
400+ "index" : index ,
401+ "unique" : unique ,
402+ }
403+ sa_default = PydanticUndefined
404+ if field_info .default_factory :
405+ sa_default = field_info .default_factory
406+ elif field_info .default is not PydanticUndefined :
407+ sa_default = field_info .default
408+ if sa_default is not PydanticUndefined :
409+ kwargs ["default" ] = sa_default
410+ sa_column_args = getattr (field_info , "sa_column_args" , PydanticUndefined )
411+ if sa_column_args is not PydanticUndefined :
412+ args .extend (list (cast (Sequence [Any ], sa_column_args )))
413+ sa_column_kwargs = getattr (field_info , "sa_column_kwargs" , PydanticUndefined )
414+ if sa_column_kwargs is not PydanticUndefined :
415+ kwargs .update (cast (Dict [Any , Any ], sa_column_kwargs ))
416+ return Column (sa_type , * args , ** kwargs ) # type: ignore
0 commit comments