Skip to content

Commit 244c947

Browse files
Merge pull request #1 from mbsantiago/main
Making checks pass
2 parents 40bcdfe + 3005495 commit 244c947

File tree

23 files changed

+580
-168
lines changed

23 files changed

+580
-168
lines changed

docs_src/tutorial/fastapi/delete/tutorial001.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,7 @@ def update_hero(hero_id: int, hero: HeroUpdate):
7979
db_hero = session.get(Hero, hero_id)
8080
if not db_hero:
8181
raise HTTPException(status_code=404, detail="Hero not found")
82-
hero_data = hero.dict(exclude_unset=True)
82+
hero_data = hero.model_dump(exclude_unset=True)
8383
for key, value in hero_data.items():
8484
setattr(db_hero, key, value)
8585
session.add(db_hero)

pyproject.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ classifiers = [
3333
[tool.poetry.dependencies]
3434
python = "^3.7"
3535
SQLAlchemy = ">=2.0.0,<=2.0.11"
36-
pydantic = "^2.1.1"
36+
pydantic = { version = ">=2.1.1,<=2.4", extras = ["email"] }
3737

3838
[tool.poetry.dev-dependencies]
3939
pytest = "^7.0.1"
@@ -52,6 +52,7 @@ autoflake = "^1.4"
5252
isort = "^5.9.3"
5353
async_generator = {version = "*", python = "~3.7"}
5454
async-exit-stack = {version = "*", python = "~3.7"}
55+
importlib-metadata = { version = "*", python = ">3.7" }
5556
httpx = "^0.24.1"
5657

5758
[build-system]

sqlmodel/ext/asyncio/session.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,11 @@
1-
from typing import Any, Mapping, Optional, Sequence, TypeVar, Union
1+
from typing import Any, Dict, Mapping, Optional, Sequence, Type, TypeVar, Union
22

33
from sqlalchemy import util
44
from sqlalchemy.ext.asyncio import AsyncSession as _AsyncSession
55
from sqlalchemy.ext.asyncio import engine
66
from sqlalchemy.ext.asyncio.engine import AsyncConnection, AsyncEngine
7+
from sqlalchemy.orm import Mapper
8+
from sqlalchemy.sql.expression import TableClause
79
from sqlalchemy.util.concurrency import greenlet_spawn
810
from sqlmodel.sql.base import Executable
911

@@ -14,13 +16,18 @@
1416
_T = TypeVar("_T")
1517

1618

19+
BindsType = Dict[
20+
Union[Type[Any], Mapper[Any], TableClause, str], Union[AsyncEngine, AsyncConnection]
21+
]
22+
23+
1724
class AsyncSession(_AsyncSession):
1825
sync_session: Session
1926

2027
def __init__(
2128
self,
2229
bind: Optional[Union[AsyncConnection, AsyncEngine]] = None,
23-
binds: Optional[Mapping[object, Union[AsyncConnection, AsyncEngine]]] = None,
30+
binds: Optional[BindsType] = None,
2431
**kw: Any,
2532
):
2633
# All the same code of the original AsyncSession

sqlmodel/main.py

Lines changed: 122 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,8 @@
1+
from __future__ import annotations
2+
13
import ipaddress
4+
import sys
5+
import types
26
import uuid
37
import weakref
48
from datetime import date, datetime, time, timedelta
@@ -22,11 +26,11 @@
2226
TypeVar,
2327
Union,
2428
cast,
25-
get_args,
26-
get_origin,
2729
)
2830

29-
from pydantic import BaseModel
31+
import pydantic
32+
from annotated_types import MaxLen
33+
from pydantic import BaseModel, EmailStr, ImportString, NameEmail
3034
from pydantic._internal._fields import PydanticGeneralMetadata
3135
from pydantic._internal._model_construction import ModelMetaclass
3236
from pydantic._internal._repr import Representation
@@ -39,12 +43,21 @@
3943
from sqlalchemy.orm.attributes import set_attribute
4044
from sqlalchemy.orm.decl_api import DeclarativeMeta
4145
from sqlalchemy.orm.instrumentation import is_instrumented
42-
from sqlalchemy.sql.schema import MetaData
46+
from sqlalchemy.orm.properties import MappedColumn
47+
from sqlalchemy.sql import false, true
48+
from sqlalchemy.sql.schema import DefaultClause, MetaData
4349
from sqlalchemy.sql.sqltypes import LargeBinary, Time
4450

4551
from .sql.sqltypes import GUID, AutoString
4652
from .typing import SQLModelConfig
4753

54+
if sys.version_info >= (3, 8):
55+
from typing import get_args, get_origin
56+
else:
57+
from typing_extensions import get_args, get_origin
58+
59+
from typing_extensions import Annotated, _AnnotatedAlias
60+
4861
_T = TypeVar("_T")
4962
NoArgAnyCallable = Callable[[], Any]
5063
NoneType = type(None)
@@ -61,6 +74,8 @@ def __dataclass_transform__(
6174

6275

6376
class FieldInfo(PydanticFieldInfo):
77+
nullable: Union[bool, PydanticUndefinedType]
78+
6479
def __init__(self, default: Any = PydanticUndefined, **kwargs: Any) -> None:
6580
primary_key = kwargs.pop("primary_key", False)
6681
nullable = kwargs.pop("nullable", PydanticUndefined)
@@ -150,14 +165,40 @@ def Field(
150165
unique: bool = False,
151166
nullable: Union[bool, PydanticUndefinedType] = PydanticUndefined,
152167
index: Union[bool, PydanticUndefinedType] = PydanticUndefined,
153-
sa_column: Union[Column, PydanticUndefinedType] = PydanticUndefined, # type: ignore
168+
sa_column: Union[Column, PydanticUndefinedType, Callable[[], Column]] = PydanticUndefined, # type: ignore
154169
sa_column_args: Union[Sequence[Any], PydanticUndefinedType] = PydanticUndefined,
155170
sa_column_kwargs: Union[
156171
Mapping[str, Any], PydanticUndefinedType
157172
] = PydanticUndefined,
158173
schema_extra: Optional[Dict[str, Any]] = None,
159174
) -> Any:
160175
current_schema_extra = schema_extra or {}
176+
if default is PydanticUndefined:
177+
if isinstance(sa_column, types.FunctionType): # lambda
178+
sa_column_ = sa_column()
179+
else:
180+
sa_column_ = sa_column
181+
182+
# server_default -> default
183+
if isinstance(sa_column_, Column) and isinstance(
184+
sa_column_.server_default, DefaultClause
185+
):
186+
default_value = sa_column_.server_default.arg
187+
if issubclass(type(sa_column_.type), Integer) and isinstance(
188+
default_value, str
189+
):
190+
default = int(default_value)
191+
elif issubclass(type(sa_column_.type), Boolean):
192+
if default_value is false():
193+
default = False
194+
elif default_value is true():
195+
default = True
196+
elif isinstance(default_value, str):
197+
if default_value == "1":
198+
default = True
199+
elif default_value == "0":
200+
default = False
201+
161202
field_info = FieldInfo(
162203
default,
163204
default_factory=default_factory,
@@ -236,7 +277,6 @@ def __new__(
236277
class_dict: Dict[str, Any],
237278
**kwargs: Any,
238279
) -> Any:
239-
240280
relationships: Dict[str, RelationshipInfo] = {}
241281
dict_for_pydantic = {}
242282
original_annotations = class_dict.get("__annotations__", {})
@@ -398,23 +438,50 @@ def __init__(
398438
ModelMetaclass.__init__(cls, classname, bases, dict_, **kw)
399439

400440

441+
def _is_optional_or_union(type_: Optional[type]) -> bool:
442+
if sys.version_info >= (3, 10):
443+
return get_origin(type_) in (types.UnionType, Union)
444+
else:
445+
return get_origin(type_) is Union
446+
447+
401448
def get_sqlalchemy_type(field: FieldInfo) -> Any:
402-
type_: type | None = field.annotation
449+
type_: Optional[type] | _AnnotatedAlias = field.annotation
450+
451+
# Resolve Optional/Union fields
403452

404-
# Resolve Optional fields
405-
if type_ is not None and get_origin(type_) is Union:
453+
if type_ is not None and _is_optional_or_union(type_):
406454
bases = get_args(type_)
407455
if len(bases) > 2:
408456
raise RuntimeError(
409457
"Cannot have a (non-optional) union as a SQL alchemy field"
410458
)
411459
type_ = bases[0]
460+
# Resolve Annoted fields,
461+
# like typing.Annotated[pydantic_core._pydantic_core.Url,
462+
# UrlConstraints(max_length=512,
463+
# allowed_schemes=['smb', 'ftp', 'file']) ]
464+
if type_ is pydantic.AnyUrl:
465+
if field.metadata:
466+
meta = field.metadata[0]
467+
return AutoString(length=meta.max_length)
468+
else:
469+
return AutoString
470+
471+
org_type = get_origin(type_)
472+
if org_type is Annotated:
473+
type2 = get_args(type_)[0]
474+
if type2 is pydantic.AnyUrl:
475+
meta = get_args(type_)[1]
476+
return AutoString(length=meta.max_length)
477+
elif org_type is pydantic.AnyUrl and type(type_) is _AnnotatedAlias:
478+
return AutoString(type_.__metadata__[0].max_length)
412479

413480
# The 3rd is PydanticGeneralMetadata
414481
metadata = _get_field_metadata(field)
415482
if type_ is None:
416483
raise ValueError("Missing field type")
417-
if issubclass(type_, str):
484+
if issubclass(type_, str) or type_ in (EmailStr, NameEmail, ImportString):
418485
max_length = getattr(metadata, "max_length", None)
419486
if max_length:
420487
return AutoString(length=max_length)
@@ -458,9 +525,18 @@ def get_sqlalchemy_type(field: FieldInfo) -> Any:
458525

459526

460527
def get_column_from_field(field: FieldInfo) -> Column: # type: ignore
528+
"""
529+
sa_column > field attributes > annotation info
530+
"""
461531
sa_column = getattr(field, "sa_column", PydanticUndefined)
462532
if isinstance(sa_column, Column):
463533
return sa_column
534+
if isinstance(sa_column, MappedColumn):
535+
return sa_column.column
536+
if isinstance(sa_column, types.FunctionType):
537+
col = sa_column()
538+
assert isinstance(col, Column)
539+
return col
464540
sa_type = get_sqlalchemy_type(field)
465541
primary_key = getattr(field, "primary_key", False)
466542
index = getattr(field, "index", PydanticUndefined)
@@ -484,7 +560,7 @@ def get_column_from_field(field: FieldInfo) -> Column: # type: ignore
484560
"index": index,
485561
"unique": unique,
486562
}
487-
sa_default: PydanticUndefinedType | Callable[[], Any] = PydanticUndefined
563+
sa_default: Union[PydanticUndefinedType, Callable[[], Any]] = PydanticUndefined
488564
if field.default_factory:
489565
sa_default = field.default_factory
490566
elif field.default is not PydanticUndefined:
@@ -524,12 +600,16 @@ def __new__(cls, *args: Any, **kwargs: Any) -> Any:
524600
# in the Pydantic model so that when SQLAlchemy sets attributes that are
525601
# added (e.g. when querying from DB) to the __fields_set__, this already exists
526602
object.__setattr__(new_object, "__pydantic_fields_set__", set())
603+
if not hasattr(new_object, "__pydantic_extra__"):
604+
object.__setattr__(new_object, "__pydantic_extra__", None)
605+
if not hasattr(new_object, "__pydantic_private__"):
606+
object.__setattr__(new_object, "__pydantic_private__", None)
527607
return new_object
528608

529609
def __init__(__pydantic_self__, **data: Any) -> None:
530610
old_dict = __pydantic_self__.__dict__.copy()
531611
super().__init__(**data)
532-
__pydantic_self__.__dict__ = old_dict | __pydantic_self__.__dict__
612+
__pydantic_self__.__dict__ = {**old_dict, **__pydantic_self__.__dict__}
533613
non_pydantic_keys = data.keys() - __pydantic_self__.model_fields
534614
for key in non_pydantic_keys:
535615
if key in __pydantic_self__.__sqlmodel_relationships__:
@@ -558,33 +638,54 @@ def __tablename__(cls) -> str:
558638

559639
@classmethod
560640
def model_validate(
561-
cls: type[_TSQLModel],
641+
cls: Type[_TSQLModel],
562642
obj: Any,
563643
*,
564-
strict: bool | None = None,
565-
from_attributes: bool | None = None,
566-
context: dict[str, Any] | None = None,
644+
strict: Optional[bool] = None,
645+
from_attributes: Optional[bool] = None,
646+
context: Optional[Dict[str, Any]] = None,
567647
) -> _TSQLModel:
568648
# Somehow model validate doesn't call __init__ so it would remove our init logic
569649
validated = super().model_validate(
570650
obj, strict=strict, from_attributes=from_attributes, context=context
571651
)
572-
return cls(**{key: value for key, value in validated})
652+
653+
# remove defaults so they don't get validated
654+
data = {}
655+
for key, value in validated:
656+
field = cls.model_fields.get(key)
657+
658+
if field is None:
659+
continue
660+
661+
if (
662+
hasattr(field, "default")
663+
and field.default is not PydanticUndefined
664+
and value == field.default
665+
):
666+
continue
667+
668+
data[key] = value
669+
670+
return cls(**data)
573671

574672

575673
def _is_field_noneable(field: FieldInfo) -> bool:
576-
if getattr(field, "nullable", PydanticUndefined) is not PydanticUndefined:
674+
if hasattr(field, "nullable") and not isinstance(
675+
field.nullable, PydanticUndefinedType
676+
):
577677
return field.nullable
578678
if not field.is_required():
579679
default = getattr(field, "original_default", field.default)
580680
if default is PydanticUndefined:
581681
return False
582682
if field.annotation is None or field.annotation is NoneType:
583683
return True
584-
if get_origin(field.annotation) is Union:
684+
if _is_optional_or_union(field.annotation):
585685
for base in get_args(field.annotation):
586686
if base is NoneType:
587687
return True
688+
588689
return False
589690
return False
590691

@@ -593,4 +694,6 @@ def _get_field_metadata(field: FieldInfo) -> object:
593694
for meta in field.metadata:
594695
if isinstance(meta, PydanticGeneralMetadata):
595696
return meta
697+
if isinstance(meta, MaxLen):
698+
return meta
596699
return object()

0 commit comments

Comments
 (0)