Skip to content

Commit 7ecbc38

Browse files
Write for pydantic v1 and v2 compat
1 parent f590548 commit 7ecbc38

File tree

8 files changed

+478
-146
lines changed

8 files changed

+478
-146
lines changed

.github/workflows/test.yml

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ jobs:
2121
strategy:
2222
matrix:
2323
python-version: ["3.7", "3.8", "3.9", "3.10", "3.11"]
24+
pydantic-version: ["pydantic-v1", "pydantic-v2"]
2425
fail-fast: false
2526

2627
steps:
@@ -51,6 +52,12 @@ jobs:
5152
- name: Install Dependencies
5253
if: steps.cache.outputs.cache-hit != 'true'
5354
run: python -m poetry install
55+
- name: Install Pydantic v1
56+
if: matrix.pydantic-version == 'pydantic-v1'
57+
run: pip install "pydantic>=1.10.0,<2.0.0"
58+
- name: Install Pydantic v2
59+
if: matrix.pydantic-version == 'pydantic-v2'
60+
run: pip install "pydantic>=2.0.2,<3.0.0"
5461
- name: Lint
5562
run: python -m poetry run bash scripts/lint.sh
5663
- run: mkdir coverage

sqlmodel/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
from sqlalchemy.sql import (
3131
LABEL_STYLE_TABLENAME_PLUS_COL as LABEL_STYLE_TABLENAME_PLUS_COL,
3232
)
33+
from sqlalchemy.sql import Subquery as Subquery
3334
from sqlalchemy.sql import alias as alias
3435
from sqlalchemy.sql import all_ as all_
3536
from sqlalchemy.sql import and_ as and_
@@ -70,7 +71,6 @@
7071
from sqlalchemy.sql import outerjoin as outerjoin
7172
from sqlalchemy.sql import outparam as outparam
7273
from sqlalchemy.sql import over as over
73-
from sqlalchemy.sql import Subquery as Subquery
7474
from sqlalchemy.sql import table as table
7575
from sqlalchemy.sql import tablesample as tablesample
7676
from sqlalchemy.sql import text as text

sqlmodel/compat.py

Lines changed: 169 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,169 @@
1+
from typing import (
2+
TYPE_CHECKING,
3+
Any,
4+
Callable,
5+
Dict,
6+
ForwardRef,
7+
Optional,
8+
Type,
9+
TypeVar,
10+
Union,
11+
get_args,
12+
get_origin,
13+
)
14+
15+
from pydantic import VERSION as PYDANTIC_VERSION
16+
17+
IS_PYDANTIC_V2 = int(PYDANTIC_VERSION.split(".")[0]) >= 2
18+
19+
20+
if IS_PYDANTIC_V2:
21+
from pydantic import ConfigDict
22+
from pydantic_core import PydanticUndefined as PydanticUndefined, PydanticUndefinedType as PydanticUndefinedType # noqa
23+
else:
24+
from pydantic import BaseConfig # noqa
25+
from pydantic.fields import ModelField # noqa
26+
from pydantic.fields import Undefined as PydanticUndefined, UndefinedType as PydanticUndefinedType # noqa
27+
28+
if TYPE_CHECKING:
29+
from .main import FieldInfo, RelationshipInfo, SQLModel, SQLModelMetaclass
30+
31+
32+
NoArgAnyCallable = Callable[[], Any]
33+
T = TypeVar("T")
34+
InstanceOrType = Union[T, Type[T]]
35+
36+
if IS_PYDANTIC_V2:
37+
38+
class SQLModelConfig(ConfigDict, total=False):
39+
table: Optional[bool]
40+
read_from_attributes: Optional[bool]
41+
registry: Optional[Any]
42+
43+
else:
44+
45+
class SQLModelConfig(BaseConfig):
46+
table: Optional[bool] = None
47+
read_from_attributes: Optional[bool] = None
48+
registry: Optional[Any] = None
49+
50+
def __getitem__(self, item: str) -> Any:
51+
return self.__getattr__(item)
52+
53+
def __setitem__(self, item: str, value: Any) -> None:
54+
return self.__setattr__(item, value)
55+
56+
57+
# Inspired from https://github.com/roman-right/beanie/blob/main/beanie/odm/utils/pydantic.py
58+
def get_model_config(model: type) -> Optional[SQLModelConfig]:
59+
if IS_PYDANTIC_V2:
60+
return getattr(model, "model_config", None)
61+
else:
62+
return getattr(model, "Config", None)
63+
64+
65+
def get_config_value(
66+
model: InstanceOrType["SQLModel"], parameter: str, default: Any = None
67+
) -> Any:
68+
if IS_PYDANTIC_V2:
69+
return model.model_config.get(parameter, default)
70+
else:
71+
return getattr(model.Config, parameter, default)
72+
73+
74+
def set_config_value(
75+
model: InstanceOrType["SQLModel"], parameter: str, value: Any, v1_parameter: str = None
76+
) -> None:
77+
if IS_PYDANTIC_V2:
78+
model.model_config[parameter] = value # type: ignore
79+
else:
80+
model.Config[v1_parameter or parameter] = value # type: ignore
81+
82+
83+
def get_model_fields(model: InstanceOrType["SQLModel"]) -> Dict[str, "FieldInfo"]:
84+
if IS_PYDANTIC_V2:
85+
return model.model_fields # type: ignore
86+
else:
87+
return model.__fields__ # type: ignore
88+
89+
90+
def get_fields_set(model: InstanceOrType["SQLModel"]) -> set[str]:
91+
if IS_PYDANTIC_V2:
92+
return model.__pydantic_fields_set__
93+
else:
94+
return model.__fields_set__ # type: ignore
95+
96+
97+
def set_fields_set(
98+
new_object: InstanceOrType["SQLModel"], fields: set["FieldInfo"]
99+
) -> None:
100+
if IS_PYDANTIC_V2:
101+
object.__setattr__(new_object, "__pydantic_fields_set__", fields)
102+
else:
103+
object.__setattr__(new_object, "__fields_set__", fields)
104+
105+
106+
def set_attribute_mode(cls: Type["SQLModelMetaclass"]) -> None:
107+
if IS_PYDANTIC_V2:
108+
cls.model_config["read_from_attributes"] = True
109+
else:
110+
cls.__config__.read_with_orm_mode = True # type: ignore
111+
112+
113+
def get_relationship_to(
114+
name: str,
115+
rel_info: "RelationshipInfo",
116+
annotation: Any,
117+
) -> Any:
118+
if IS_PYDANTIC_V2:
119+
relationship_to = get_origin(annotation)
120+
# Direct relationships (e.g. 'Team' or Team) have None as an origin
121+
if relationship_to is None:
122+
relationship_to = annotation
123+
# If Union (e.g. Optional), get the real field
124+
elif relationship_to is Union:
125+
relationship_to = get_args(annotation)[0]
126+
# If a list, then also get the real field
127+
elif relationship_to is list:
128+
relationship_to = get_args(annotation)[0]
129+
if isinstance(relationship_to, ForwardRef):
130+
relationship_to = relationship_to.__forward_arg__
131+
return relationship_to
132+
else:
133+
temp_field = ModelField.infer(
134+
name=name,
135+
value=rel_info,
136+
annotation=annotation,
137+
class_validators=None,
138+
config=SQLModelConfig,
139+
)
140+
relationship_to = temp_field.type_
141+
if isinstance(temp_field.type_, ForwardRef):
142+
relationship_to = temp_field.type_.__forward_arg__
143+
return relationship_to
144+
145+
146+
def set_empty_defaults(annotations: Dict[str, Any], class_dict: Dict[str, Any]) -> None:
147+
"""
148+
Pydantic v2 without required fields with no optionals cannot do empty initialisations.
149+
This means we cannot do Table() and set fields later.
150+
We go around this by adding a default to everything, being None
151+
152+
Args:
153+
annotations: Dict[str, Any]: The annotations to provide to pydantic
154+
class_dict: Dict[str, Any]: The class dict for the defaults
155+
"""
156+
if IS_PYDANTIC_V2:
157+
from .main import FieldInfo
158+
# Pydantic v2 sets a __pydantic_core_schema__ which is very hard to change. Changing the fields does not do anything
159+
for key in annotations.keys():
160+
value = class_dict.get(key, PydanticUndefined)
161+
if value is PydanticUndefined:
162+
class_dict[key] = None
163+
elif isinstance(value, FieldInfo):
164+
if (
165+
value.default in (PydanticUndefined, Ellipsis)
166+
) and value.default_factory is None:
167+
# So we can check for nullable
168+
value.original_default = value.default
169+
value.default = None

0 commit comments

Comments
 (0)