|
2 | 2 |
|
3 | 3 | import ipaddress |
4 | 4 | import sys |
| 5 | +import types |
5 | 6 | import uuid |
6 | 7 | import weakref |
7 | 8 | from datetime import date, datetime, time, timedelta |
|
27 | 28 | cast, |
28 | 29 | ) |
29 | 30 |
|
| 31 | +import pydantic |
30 | 32 | from pydantic import BaseModel |
31 | 33 | from pydantic._internal._fields import PydanticGeneralMetadata |
32 | 34 | from pydantic._internal._model_construction import ModelMetaclass |
|
40 | 42 | from sqlalchemy.orm.attributes import set_attribute |
41 | 43 | from sqlalchemy.orm.decl_api import DeclarativeMeta |
42 | 44 | from sqlalchemy.orm.instrumentation import is_instrumented |
43 | | -from sqlalchemy.sql.schema import MetaData |
| 45 | +from sqlalchemy.orm.properties import MappedColumn |
| 46 | +from sqlalchemy.sql import false, true |
| 47 | +from sqlalchemy.sql.schema import DefaultClause, MetaData |
44 | 48 | from sqlalchemy.sql.sqltypes import LargeBinary, Time |
45 | 49 |
|
46 | 50 | from .sql.sqltypes import GUID, AutoString |
|
51 | 55 | else: |
52 | 56 | from typing_extensions import get_args, get_origin |
53 | 57 |
|
| 58 | +if sys.version_info >= (3, 9): |
| 59 | + from typing import Annotated |
| 60 | +else: |
| 61 | + from typing_extensions import Annotated |
| 62 | + |
54 | 63 | _T = TypeVar("_T") |
55 | 64 | NoArgAnyCallable = Callable[[], Any] |
56 | 65 | NoneType = type(None) |
@@ -158,14 +167,40 @@ def Field( |
158 | 167 | unique: bool = False, |
159 | 168 | nullable: Union[bool, PydanticUndefinedType] = PydanticUndefined, |
160 | 169 | index: Union[bool, PydanticUndefinedType] = PydanticUndefined, |
161 | | - sa_column: Union[Column, PydanticUndefinedType] = PydanticUndefined, # type: ignore |
| 170 | + sa_column: Union[Column, PydanticUndefinedType, types.FunctionType] = PydanticUndefined, # type: ignore |
162 | 171 | sa_column_args: Union[Sequence[Any], PydanticUndefinedType] = PydanticUndefined, |
163 | 172 | sa_column_kwargs: Union[ |
164 | 173 | Mapping[str, Any], PydanticUndefinedType |
165 | 174 | ] = PydanticUndefined, |
166 | 175 | schema_extra: Optional[Dict[str, Any]] = None, |
167 | 176 | ) -> Any: |
168 | 177 | current_schema_extra = schema_extra or {} |
| 178 | + if default is PydanticUndefined: |
| 179 | + if isinstance(sa_column, types.FunctionType): # lambda |
| 180 | + sa_column_ = sa_column() |
| 181 | + else: |
| 182 | + sa_column_ = sa_column |
| 183 | + |
| 184 | + # server_default -> default |
| 185 | + if isinstance(sa_column_, Column) and isinstance( |
| 186 | + sa_column_.server_default, DefaultClause |
| 187 | + ): |
| 188 | + default_value = sa_column_.server_default.arg |
| 189 | + if issubclass(type(sa_column_.type), Integer) and isinstance( |
| 190 | + default_value, str |
| 191 | + ): |
| 192 | + default = int(default_value) |
| 193 | + elif issubclass(type(sa_column_.type), Boolean): |
| 194 | + if default_value is false(): |
| 195 | + default = False |
| 196 | + elif default_value is true(): |
| 197 | + default = True |
| 198 | + elif isinstance(default_value, str): |
| 199 | + if default_value == "1": |
| 200 | + default = True |
| 201 | + elif default_value == "0": |
| 202 | + default = False |
| 203 | + |
169 | 204 | field_info = FieldInfo( |
170 | 205 | default, |
171 | 206 | default_factory=default_factory, |
@@ -408,14 +443,33 @@ def __init__( |
408 | 443 | def get_sqlalchemy_type(field: FieldInfo) -> Any: |
409 | 444 | type_: Optional[type] = field.annotation |
410 | 445 |
|
411 | | - # Resolve Optional fields |
412 | | - if type_ is not None and get_origin(type_) is Union: |
| 446 | + # Resolve Optional/Union fields |
| 447 | + def is_optional_or_union(type_: Optional[type]) -> bool: |
| 448 | + if sys.version_info >= (3, 10): |
| 449 | + return get_origin(type_) in (types.UnionType, Union) |
| 450 | + else: |
| 451 | + return get_origin(type_) is Union |
| 452 | + |
| 453 | + if type_ is not None and is_optional_or_union(type_): |
413 | 454 | bases = get_args(type_) |
414 | 455 | if len(bases) > 2: |
415 | 456 | raise RuntimeError( |
416 | 457 | "Cannot have a (non-optional) union as a SQL alchemy field" |
417 | 458 | ) |
418 | 459 | type_ = bases[0] |
| 460 | + # Resolve Annoted fields, |
| 461 | + # like typing.Annotated[pydantic_core._pydantic_core.Url, |
| 462 | + # UrlConstraints(max_length=512, |
| 463 | + # allowed_schemes=['smb', 'ftp', 'file']) ] |
| 464 | + if type_ is pydantic.AnyUrl: |
| 465 | + meta = field.metadata[0] |
| 466 | + return AutoString(length=meta.max_length) |
| 467 | + |
| 468 | + if get_origin(type_) is Annotated: |
| 469 | + type2 = get_args(type_)[0] |
| 470 | + if type2 is pydantic.AnyUrl: |
| 471 | + meta = get_args(type_)[1] |
| 472 | + return AutoString(length=meta.max_length) |
419 | 473 |
|
420 | 474 | # The 3rd is PydanticGeneralMetadata |
421 | 475 | metadata = _get_field_metadata(field) |
@@ -468,6 +522,8 @@ def get_column_from_field(field: FieldInfo) -> Column: # type: ignore |
468 | 522 | sa_column = getattr(field, "sa_column", PydanticUndefined) |
469 | 523 | if isinstance(sa_column, Column): |
470 | 524 | return sa_column |
| 525 | + if isinstance(sa_column, MappedColumn): |
| 526 | + return sa_column.column |
471 | 527 | sa_type = get_sqlalchemy_type(field) |
472 | 528 | primary_key = getattr(field, "primary_key", False) |
473 | 529 | index = getattr(field, "index", PydanticUndefined) |
|
0 commit comments