Skip to content

Commit bcb6f32

Browse files
authored
Merge pull request #2 from honglei/main
get_column_from_field support functional sa_column
2 parents 63e2692 + 4213c97 commit bcb6f32

File tree

5 files changed

+187
-18
lines changed

5 files changed

+187
-18
lines changed

sqlmodel/main.py

Lines changed: 37 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -55,10 +55,7 @@
5555
else:
5656
from typing_extensions import get_args, get_origin
5757

58-
if sys.version_info >= (3, 9):
59-
from typing import Annotated
60-
else:
61-
from typing_extensions import Annotated
58+
from typing_extensions import Annotated, _AnnotatedAlias
6259

6360
_T = TypeVar("_T")
6461
NoArgAnyCallable = Callable[[], Any]
@@ -167,7 +164,7 @@ def Field(
167164
unique: bool = False,
168165
nullable: Union[bool, PydanticUndefinedType] = PydanticUndefined,
169166
index: Union[bool, PydanticUndefinedType] = PydanticUndefined,
170-
sa_column: Union[Column, PydanticUndefinedType, types.FunctionType] = PydanticUndefined, # type: ignore
167+
sa_column: Union[Column, PydanticUndefinedType, Callable[[], Column]] = PydanticUndefined, # type: ignore
171168
sa_column_args: Union[Sequence[Any], PydanticUndefinedType] = PydanticUndefined,
172169
sa_column_kwargs: Union[
173170
Mapping[str, Any], PydanticUndefinedType
@@ -440,17 +437,19 @@ def __init__(
440437
ModelMetaclass.__init__(cls, classname, bases, dict_, **kw)
441438

442439

440+
def _is_optional_or_union(type_: Optional[type]) -> bool:
441+
if sys.version_info >= (3, 10):
442+
return get_origin(type_) in (types.UnionType, Union)
443+
else:
444+
return get_origin(type_) is Union
445+
446+
443447
def get_sqlalchemy_type(field: FieldInfo) -> Any:
444-
type_: Optional[type] = field.annotation
448+
type_: Optional[type] | _AnnotatedAlias = field.annotation
445449

446450
# Resolve Optional/Union fields
447-
def is_optional_or_union(type_: Optional[type]) -> bool:
448-
if sys.version_info >= (3, 10):
449-
return get_origin(type_) in (types.UnionType, Union)
450-
else:
451-
return get_origin(type_) is Union
452451

453-
if type_ is not None and is_optional_or_union(type_):
452+
if type_ is not None and _is_optional_or_union(type_):
454453
bases = get_args(type_)
455454
if len(bases) > 2:
456455
raise RuntimeError(
@@ -462,14 +461,20 @@ def is_optional_or_union(type_: Optional[type]) -> bool:
462461
# UrlConstraints(max_length=512,
463462
# allowed_schemes=['smb', 'ftp', 'file']) ]
464463
if type_ is pydantic.AnyUrl:
465-
meta = field.metadata[0]
466-
return AutoString(length=meta.max_length)
464+
if field.metadata:
465+
meta = field.metadata[0]
466+
return AutoString(length=meta.max_length)
467+
else:
468+
return AutoString
467469

468-
if get_origin(type_) is Annotated:
470+
org_type = get_origin(type_)
471+
if org_type is Annotated:
469472
type2 = get_args(type_)[0]
470473
if type2 is pydantic.AnyUrl:
471474
meta = get_args(type_)[1]
472475
return AutoString(length=meta.max_length)
476+
elif org_type is pydantic.AnyUrl and type(type_) is _AnnotatedAlias:
477+
return AutoString(type_.__metadata__[0].max_length)
473478

474479
# The 3rd is PydanticGeneralMetadata
475480
metadata = _get_field_metadata(field)
@@ -519,11 +524,18 @@ def is_optional_or_union(type_: Optional[type]) -> bool:
519524

520525

521526
def get_column_from_field(field: FieldInfo) -> Column: # type: ignore
527+
"""
528+
sa_column > field attributes > annotation info
529+
"""
522530
sa_column = getattr(field, "sa_column", PydanticUndefined)
523531
if isinstance(sa_column, Column):
524532
return sa_column
525533
if isinstance(sa_column, MappedColumn):
526534
return sa_column.column
535+
if isinstance(sa_column, types.FunctionType):
536+
col = sa_column()
537+
assert isinstance(col, Column)
538+
return col
527539
sa_type = get_sqlalchemy_type(field)
528540
primary_key = getattr(field, "primary_key", False)
529541
index = getattr(field, "index", PydanticUndefined)
@@ -587,6 +599,10 @@ def __new__(cls, *args: Any, **kwargs: Any) -> Any:
587599
# in the Pydantic model so that when SQLAlchemy sets attributes that are
588600
# added (e.g. when querying from DB) to the __fields_set__, this already exists
589601
object.__setattr__(new_object, "__pydantic_fields_set__", set())
602+
if not hasattr(new_object, "__pydantic_extra__"):
603+
object.__setattr__(new_object, "__pydantic_extra__", None)
604+
if not hasattr(new_object, "__pydantic_private__"):
605+
object.__setattr__(new_object, "__pydantic_private__", None)
590606
return new_object
591607

592608
def __init__(__pydantic_self__, **data: Any) -> None:
@@ -636,7 +652,10 @@ def model_validate(
636652
# remove defaults so they don't get validated
637653
data = {}
638654
for key, value in validated:
639-
field = cls.model_fields[key]
655+
field = cls.model_fields.get(key)
656+
657+
if field is None:
658+
continue
640659

641660
if (
642661
hasattr(field, "default")
@@ -661,10 +680,11 @@ def _is_field_noneable(field: FieldInfo) -> bool:
661680
return False
662681
if field.annotation is None or field.annotation is NoneType:
663682
return True
664-
if get_origin(field.annotation) is Union:
683+
if _is_optional_or_union(field.annotation):
665684
for base in get_args(field.annotation):
666685
if base is NoneType:
667686
return True
687+
668688
return False
669689
return False
670690

sqlmodel/sql/sqltypes.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@
88

99

1010
class AutoString(types.TypeDecorator): # type: ignore
11-
1211
impl = types.String
1312
cache_ok = True
1413
mysql_default_length = 255

tests/test_class_hierarchy.py

Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,78 @@
1+
import datetime
2+
import sys
3+
4+
import pytest
5+
from pydantic import AnyUrl, UrlConstraints
6+
from sqlmodel import (
7+
BigInteger,
8+
Column,
9+
DateTime,
10+
Field,
11+
Integer,
12+
SQLModel,
13+
String,
14+
create_engine,
15+
)
16+
from typing_extensions import Annotated
17+
18+
MoveSharedUrl = Annotated[
19+
AnyUrl, UrlConstraints(max_length=512, allowed_schemes=["smb", "ftp", "file"])
20+
]
21+
22+
23+
@pytest.mark.skipif(sys.version_info < (3, 10), reason="requires python3.10 or higher")
24+
def test_field_resuse():
25+
class BasicFileLog(SQLModel):
26+
resourceID: int = Field(
27+
sa_column=lambda: Column(Integer, index=True), description=""" """
28+
)
29+
transportID: Annotated[int | None, Field(description=" for ")] = None
30+
fileName: str = Field(
31+
sa_column=lambda: Column(String, index=True), description=""" """
32+
)
33+
fileSize: int | None = Field(
34+
sa_column=lambda: Column(BigInteger), ge=0, description=""" """
35+
)
36+
beginTime: datetime.datetime | None = Field(
37+
sa_column=lambda: Column(
38+
DateTime(timezone=True),
39+
index=True,
40+
),
41+
description="",
42+
)
43+
44+
class SendFileLog(BasicFileLog, table=True):
45+
id: int | None = Field(
46+
sa_column=Column(Integer, primary_key=True, autoincrement=True),
47+
description=""" """,
48+
)
49+
sendUser: str
50+
dstUrl: MoveSharedUrl | None
51+
52+
class RecvFileLog(BasicFileLog, table=True):
53+
id: int | None = Field(
54+
sa_column=Column(Integer, primary_key=True, autoincrement=True),
55+
description=""" """,
56+
)
57+
recvUser: str
58+
59+
sqlite_file_name = "database.db"
60+
sqlite_url = f"sqlite:///{sqlite_file_name}"
61+
62+
engine = create_engine(sqlite_url, echo=True)
63+
SQLModel.metadata.drop_all(engine)
64+
SQLModel.metadata.create_all(engine)
65+
SendFileLog(
66+
sendUser="j",
67+
resourceID=1,
68+
fileName="a.txt",
69+
fileSize=3234,
70+
beginTime=datetime.datetime.now(),
71+
)
72+
RecvFileLog(
73+
sendUser="j",
74+
resourceID=1,
75+
fileName="a.txt",
76+
fileSize=3234,
77+
beginTime=datetime.datetime.now(),
78+
)

tests/test_model_copy.py

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
from typing import Optional
2+
3+
from sqlmodel import Field, Session, SQLModel, create_engine
4+
5+
6+
def test_model_copy(clear_sqlmodel):
7+
"""Test validation of implicit and explict None values.
8+
9+
# For consistency with pydantic, validators are not to be called on
10+
# arguments that are not explicitly provided.
11+
12+
https://github.com/tiangolo/sqlmodel/issues/230
13+
https://github.com/samuelcolvin/pydantic/issues/1223
14+
15+
"""
16+
17+
class Hero(SQLModel, table=True):
18+
id: Optional[int] = Field(default=None, primary_key=True)
19+
name: str
20+
secret_name: str
21+
age: Optional[int] = None
22+
23+
hero = Hero(name="Deadpond", secret_name="Dive Wilson", age=25)
24+
25+
engine = create_engine("sqlite://")
26+
27+
SQLModel.metadata.create_all(engine)
28+
29+
with Session(engine) as session:
30+
session.add(hero)
31+
session.commit()
32+
session.refresh(hero)
33+
34+
model_copy = hero.model_copy(update={"name": "Deadpond Copy"})
35+
36+
assert (
37+
model_copy.name == "Deadpond Copy"
38+
and model_copy.secret_name == "Dive Wilson"
39+
and model_copy.age == 25
40+
)
41+
42+
db_hero = session.get(Hero, hero.id)
43+
44+
db_copy = db_hero.model_copy(update={"name": "Deadpond Copy"})
45+
46+
assert (
47+
db_copy.name == "Deadpond Copy"
48+
and db_copy.secret_name == "Dive Wilson"
49+
and db_copy.age == 25
50+
)

tests/test_nullable.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,14 @@
11
from typing import Optional
22

33
import pytest
4+
from pydantic import AnyUrl, UrlConstraints
45
from sqlalchemy.exc import IntegrityError
56
from sqlmodel import Field, Session, SQLModel, create_engine
7+
from typing_extensions import Annotated
8+
9+
MoveSharedUrl = Annotated[
10+
AnyUrl, UrlConstraints(max_length=512, allowed_schemes=["smb", "ftp", "file"])
11+
]
612

713

814
def test_nullable_fields(clear_sqlmodel, caplog):
@@ -13,6 +19,8 @@ class Hero(SQLModel, table=True):
1319
)
1420
required_value: str
1521
optional_default_ellipsis: Optional[str] = Field(default=...)
22+
optional_no_field: Optional[str]
23+
optional_no_field_default: Optional[str] = Field(description="no default")
1624
optional_default_none: Optional[str] = Field(default=None)
1725
optional_non_nullable: Optional[str] = Field(
1826
nullable=False,
@@ -49,6 +57,13 @@ class Hero(SQLModel, table=True):
4957
str_default_str_nullable: str = Field(default="default", nullable=True)
5058
str_default_ellipsis_non_nullable: str = Field(default=..., nullable=False)
5159
str_default_ellipsis_nullable: str = Field(default=..., nullable=True)
60+
base_url: AnyUrl
61+
optional_url: Optional[MoveSharedUrl] = Field(default=None, description="")
62+
url: MoveSharedUrl
63+
annotated_url: Annotated[MoveSharedUrl, Field(description="")]
64+
annotated_optional_url: Annotated[
65+
Optional[MoveSharedUrl], Field(description="")
66+
] = None
5267

5368
engine = create_engine("sqlite://", echo=True)
5469
SQLModel.metadata.create_all(engine)
@@ -59,6 +74,8 @@ class Hero(SQLModel, table=True):
5974
assert "primary_key INTEGER NOT NULL," in create_table_log
6075
assert "required_value VARCHAR NOT NULL," in create_table_log
6176
assert "optional_default_ellipsis VARCHAR NOT NULL," in create_table_log
77+
assert "optional_no_field VARCHAR," in create_table_log
78+
assert "optional_no_field_default VARCHAR NOT NULL," in create_table_log
6279
assert "optional_default_none VARCHAR," in create_table_log
6380
assert "optional_non_nullable VARCHAR NOT NULL," in create_table_log
6481
assert "optional_nullable VARCHAR," in create_table_log
@@ -77,6 +94,11 @@ class Hero(SQLModel, table=True):
7794
assert "str_default_str_nullable VARCHAR," in create_table_log
7895
assert "str_default_ellipsis_non_nullable VARCHAR NOT NULL," in create_table_log
7996
assert "str_default_ellipsis_nullable VARCHAR," in create_table_log
97+
assert "base_url VARCHAR NOT NULL," in create_table_log
98+
assert "optional_url VARCHAR(512), " in create_table_log
99+
assert "url VARCHAR(512) NOT NULL," in create_table_log
100+
assert "annotated_url VARCHAR(512) NOT NULL," in create_table_log
101+
assert "annotated_optional_url VARCHAR(512)," in create_table_log
80102

81103

82104
# Test for regression in https://github.com/tiangolo/sqlmodel/issues/420

0 commit comments

Comments
 (0)