Skip to content

Commit c99c1a9

Browse files
committed
Make sure tests pass in all supported python versions
1 parent 179183c commit c99c1a9

File tree

1 file changed

+18
-11
lines changed

1 file changed

+18
-11
lines changed

sqlmodel/main.py

Lines changed: 18 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,7 @@
1+
from __future__ import annotations
2+
13
import ipaddress
4+
import sys
25
import uuid
36
import weakref
47
from datetime import date, datetime, time, timedelta
@@ -22,8 +25,6 @@
2225
TypeVar,
2326
Union,
2427
cast,
25-
get_args,
26-
get_origin,
2728
)
2829

2930
from pydantic import BaseModel
@@ -45,6 +46,11 @@
4546
from .sql.sqltypes import GUID, AutoString
4647
from .typing import SQLModelConfig
4748

49+
if sys.version_info >= (3, 8):
50+
from typing import get_args, get_origin
51+
else:
52+
from typing_extensions import get_args, get_origin
53+
4854
_T = TypeVar("_T")
4955
NoArgAnyCallable = Callable[[], Any]
5056
NoneType = type(None)
@@ -61,7 +67,6 @@ def __dataclass_transform__(
6167

6268

6369
class FieldInfo(PydanticFieldInfo):
64-
6570
nullable: Union[bool, PydanticUndefinedType]
6671

6772
def __init__(self, default: Any = PydanticUndefined, **kwargs: Any) -> None:
@@ -401,7 +406,7 @@ def __init__(
401406

402407

403408
def get_sqlalchemy_type(field: FieldInfo) -> Any:
404-
type_: type | None = field.annotation
409+
type_: Optional[type] = field.annotation
405410

406411
# Resolve Optional fields
407412
if type_ is not None and get_origin(type_) is Union:
@@ -486,7 +491,7 @@ def get_column_from_field(field: FieldInfo) -> Column: # type: ignore
486491
"index": index,
487492
"unique": unique,
488493
}
489-
sa_default: PydanticUndefinedType | Callable[[], Any] = PydanticUndefined
494+
sa_default: Union[PydanticUndefinedType, Callable[[], Any]] = PydanticUndefined
490495
if field.default_factory:
491496
sa_default = field.default_factory
492497
elif field.default is not PydanticUndefined:
@@ -531,7 +536,7 @@ def __new__(cls, *args: Any, **kwargs: Any) -> Any:
531536
def __init__(__pydantic_self__, **data: Any) -> None:
532537
old_dict = __pydantic_self__.__dict__.copy()
533538
super().__init__(**data)
534-
__pydantic_self__.__dict__ = old_dict | __pydantic_self__.__dict__
539+
__pydantic_self__.__dict__ = {**old_dict, **__pydantic_self__.__dict__}
535540
non_pydantic_keys = data.keys() - __pydantic_self__.model_fields
536541
for key in non_pydantic_keys:
537542
if key in __pydantic_self__.__sqlmodel_relationships__:
@@ -560,12 +565,12 @@ def __tablename__(cls) -> str:
560565

561566
@classmethod
562567
def model_validate(
563-
cls: type[_TSQLModel],
568+
cls: Type[_TSQLModel],
564569
obj: Any,
565570
*,
566-
strict: bool | None = None,
567-
from_attributes: bool | None = None,
568-
context: dict[str, Any] | None = None,
571+
strict: Optional[bool] = None,
572+
from_attributes: Optional[bool] = None,
573+
context: Optional[Dict[str, Any]] = None,
569574
) -> _TSQLModel:
570575
# Somehow model validate doesn't call __init__ so it would remove our init logic
571576
validated = super().model_validate(
@@ -590,7 +595,9 @@ def model_validate(
590595

591596

592597
def _is_field_noneable(field: FieldInfo) -> bool:
593-
if not isinstance(field.nullable, PydanticUndefinedType):
598+
if hasattr(field, "nullable") and not isinstance(
599+
field.nullable, PydanticUndefinedType
600+
):
594601
return field.nullable
595602
if not field.is_required():
596603
default = getattr(field, "original_default", field.default)

0 commit comments

Comments
 (0)