Skip to content

Commit f67b414

Browse files
authored
Merge pull request #1 from honglei/main
support str|None , mapped_column, AnyURL
2 parents c99c1a9 + 46b130d commit f67b414

File tree

1 file changed

+60
-4
lines changed

1 file changed

+60
-4
lines changed

sqlmodel/main.py

Lines changed: 60 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
import ipaddress
44
import sys
5+
import types
56
import uuid
67
import weakref
78
from datetime import date, datetime, time, timedelta
@@ -27,6 +28,7 @@
2728
cast,
2829
)
2930

31+
import pydantic
3032
from pydantic import BaseModel
3133
from pydantic._internal._fields import PydanticGeneralMetadata
3234
from pydantic._internal._model_construction import ModelMetaclass
@@ -40,7 +42,9 @@
4042
from sqlalchemy.orm.attributes import set_attribute
4143
from sqlalchemy.orm.decl_api import DeclarativeMeta
4244
from sqlalchemy.orm.instrumentation import is_instrumented
43-
from sqlalchemy.sql.schema import MetaData
45+
from sqlalchemy.orm.properties import MappedColumn
46+
from sqlalchemy.sql import false, true
47+
from sqlalchemy.sql.schema import DefaultClause, MetaData
4448
from sqlalchemy.sql.sqltypes import LargeBinary, Time
4549

4650
from .sql.sqltypes import GUID, AutoString
@@ -51,6 +55,11 @@
5155
else:
5256
from typing_extensions import get_args, get_origin
5357

58+
if sys.version_info >= (3, 9):
59+
from typing import Annotated
60+
else:
61+
from typing_extensions import Annotated
62+
5463
_T = TypeVar("_T")
5564
NoArgAnyCallable = Callable[[], Any]
5665
NoneType = type(None)
@@ -158,14 +167,40 @@ def Field(
158167
unique: bool = False,
159168
nullable: Union[bool, PydanticUndefinedType] = PydanticUndefined,
160169
index: Union[bool, PydanticUndefinedType] = PydanticUndefined,
161-
sa_column: Union[Column, PydanticUndefinedType] = PydanticUndefined, # type: ignore
170+
sa_column: Union[Column, PydanticUndefinedType, types.FunctionType] = PydanticUndefined, # type: ignore
162171
sa_column_args: Union[Sequence[Any], PydanticUndefinedType] = PydanticUndefined,
163172
sa_column_kwargs: Union[
164173
Mapping[str, Any], PydanticUndefinedType
165174
] = PydanticUndefined,
166175
schema_extra: Optional[Dict[str, Any]] = None,
167176
) -> Any:
168177
current_schema_extra = schema_extra or {}
178+
if default is PydanticUndefined:
179+
if isinstance(sa_column, types.FunctionType): # lambda
180+
sa_column_ = sa_column()
181+
else:
182+
sa_column_ = sa_column
183+
184+
# server_default -> default
185+
if isinstance(sa_column_, Column) and isinstance(
186+
sa_column_.server_default, DefaultClause
187+
):
188+
default_value = sa_column_.server_default.arg
189+
if issubclass(type(sa_column_.type), Integer) and isinstance(
190+
default_value, str
191+
):
192+
default = int(default_value)
193+
elif issubclass(type(sa_column_.type), Boolean):
194+
if default_value is false():
195+
default = False
196+
elif default_value is true():
197+
default = True
198+
elif isinstance(default_value, str):
199+
if default_value == "1":
200+
default = True
201+
elif default_value == "0":
202+
default = False
203+
169204
field_info = FieldInfo(
170205
default,
171206
default_factory=default_factory,
@@ -408,14 +443,33 @@ def __init__(
408443
def get_sqlalchemy_type(field: FieldInfo) -> Any:
409444
type_: Optional[type] = field.annotation
410445

411-
# Resolve Optional fields
412-
if type_ is not None and get_origin(type_) is Union:
446+
# Resolve Optional/Union fields
447+
def is_optional_or_union(type_: Optional[type]) -> bool:
448+
if sys.version_info >= (3, 10):
449+
return get_origin(type_) in (types.UnionType, Union)
450+
else:
451+
return get_origin(type_) is Union
452+
453+
if type_ is not None and is_optional_or_union(type_):
413454
bases = get_args(type_)
414455
if len(bases) > 2:
415456
raise RuntimeError(
416457
"Cannot have a (non-optional) union as a SQL alchemy field"
417458
)
418459
type_ = bases[0]
460+
# Resolve Annoted fields,
461+
# like typing.Annotated[pydantic_core._pydantic_core.Url,
462+
# UrlConstraints(max_length=512,
463+
# allowed_schemes=['smb', 'ftp', 'file']) ]
464+
if type_ is pydantic.AnyUrl:
465+
meta = field.metadata[0]
466+
return AutoString(length=meta.max_length)
467+
468+
if get_origin(type_) is Annotated:
469+
type2 = get_args(type_)[0]
470+
if type2 is pydantic.AnyUrl:
471+
meta = get_args(type_)[1]
472+
return AutoString(length=meta.max_length)
419473

420474
# The 3rd is PydanticGeneralMetadata
421475
metadata = _get_field_metadata(field)
@@ -468,6 +522,8 @@ def get_column_from_field(field: FieldInfo) -> Column: # type: ignore
468522
sa_column = getattr(field, "sa_column", PydanticUndefined)
469523
if isinstance(sa_column, Column):
470524
return sa_column
525+
if isinstance(sa_column, MappedColumn):
526+
return sa_column.column
471527
sa_type = get_sqlalchemy_type(field)
472528
primary_key = getattr(field, "primary_key", False)
473529
index = getattr(field, "index", PydanticUndefined)

0 commit comments

Comments
 (0)