From d0ae3e81db377cf4d3ccf489e5fe3ce18bc86ddb Mon Sep 17 00:00:00 2001 From: Mohamed Farahat Date: Tue, 7 Mar 2023 02:08:36 +0200 Subject: [PATCH 001/105] migration to sqlalchemy 2.0 --- docs_src/tutorial/many_to_many/tutorial003.py | 30 +++++++-------- .../back_populates/tutorial003.py | 30 +++++++-------- sqlmodel/__init__.py | 2 - sqlmodel/main.py | 1 + sqlmodel/sql/expression.py | 38 +++---------------- 5 files changed, 36 insertions(+), 65 deletions(-) diff --git a/docs_src/tutorial/many_to_many/tutorial003.py b/docs_src/tutorial/many_to_many/tutorial003.py index 1e03c4af89..cec6e56560 100644 --- a/docs_src/tutorial/many_to_many/tutorial003.py +++ b/docs_src/tutorial/many_to_many/tutorial003.py @@ -3,25 +3,12 @@ from sqlmodel import Field, Relationship, Session, SQLModel, create_engine, select -class HeroTeamLink(SQLModel, table=True): - team_id: Optional[int] = Field( - default=None, foreign_key="team.id", primary_key=True - ) - hero_id: Optional[int] = Field( - default=None, foreign_key="hero.id", primary_key=True - ) - is_training: bool = False - - team: "Team" = Relationship(back_populates="hero_links") - hero: "Hero" = Relationship(back_populates="team_links") - - class Team(SQLModel, table=True): id: Optional[int] = Field(default=None, primary_key=True) name: str = Field(index=True) headquarters: str - hero_links: List[HeroTeamLink] = Relationship(back_populates="team") + hero_links: List["HeroTeamLink"] = Relationship(back_populates="team") class Hero(SQLModel, table=True): @@ -30,7 +17,20 @@ class Hero(SQLModel, table=True): secret_name: str age: Optional[int] = Field(default=None, index=True) - team_links: List[HeroTeamLink] = Relationship(back_populates="hero") + team_links: List["HeroTeamLink"] = Relationship(back_populates="hero") + + +class HeroTeamLink(SQLModel, table=True): + team_id: Optional[int] = Field( + default=None, foreign_key="team.id", primary_key=True + ) + hero_id: Optional[int] = Field( + default=None, foreign_key="hero.id", primary_key=True + ) + is_training: bool = False + + team: "Team" = Relationship(back_populates="hero_links") + hero: "Hero" = Relationship(back_populates="team_links") sqlite_file_name = "database.db" diff --git a/docs_src/tutorial/relationship_attributes/back_populates/tutorial003.py b/docs_src/tutorial/relationship_attributes/back_populates/tutorial003.py index 98e197002e..8d91a0bc25 100644 --- a/docs_src/tutorial/relationship_attributes/back_populates/tutorial003.py +++ b/docs_src/tutorial/relationship_attributes/back_populates/tutorial003.py @@ -3,6 +3,21 @@ from sqlmodel import Field, Relationship, SQLModel, create_engine +class Hero(SQLModel, table=True): + id: Optional[int] = Field(default=None, primary_key=True) + name: str = Field(index=True) + secret_name: str + age: Optional[int] = Field(default=None, index=True) + + team_id: Optional[int] = Field(default=None, foreign_key="team.id") + team: Optional["Team"] = Relationship(back_populates="heroes") + + weapon_id: Optional[int] = Field(default=None, foreign_key="weapon.id") + weapon: Optional["Weapon"] = Relationship(back_populates="hero") + + powers: List["Power"] = Relationship(back_populates="hero") + + class Weapon(SQLModel, table=True): id: Optional[int] = Field(default=None, primary_key=True) name: str = Field(index=True) @@ -26,21 +41,6 @@ class Team(SQLModel, table=True): heroes: List["Hero"] = Relationship(back_populates="team") -class Hero(SQLModel, table=True): - id: Optional[int] = Field(default=None, primary_key=True) - name: str = Field(index=True) - secret_name: str - age: Optional[int] = Field(default=None, index=True) - - team_id: Optional[int] = Field(default=None, foreign_key="team.id") - team: Optional[Team] = Relationship(back_populates="heroes") - - weapon_id: Optional[int] = Field(default=None, foreign_key="weapon.id") - weapon: Optional[Weapon] = Relationship(back_populates="hero") - - powers: List[Power] = Relationship(back_populates="hero") - - sqlite_file_name = "database.db" sqlite_url = f"sqlite:///{sqlite_file_name}" diff --git a/sqlmodel/__init__.py b/sqlmodel/__init__.py index 495ac9c8a8..af37421754 100644 --- a/sqlmodel/__init__.py +++ b/sqlmodel/__init__.py @@ -21,7 +21,6 @@ from sqlalchemy.schema import PrimaryKeyConstraint as PrimaryKeyConstraint from sqlalchemy.schema import Sequence as Sequence from sqlalchemy.schema import Table as Table -from sqlalchemy.schema import ThreadLocalMetaData as ThreadLocalMetaData from sqlalchemy.schema import UniqueConstraint as UniqueConstraint from sqlalchemy.sql import LABEL_STYLE_DEFAULT as LABEL_STYLE_DEFAULT from sqlalchemy.sql import ( @@ -71,7 +70,6 @@ from sqlalchemy.sql import outerjoin as outerjoin from sqlalchemy.sql import outparam as outparam from sqlalchemy.sql import over as over -from sqlalchemy.sql import subquery as subquery from sqlalchemy.sql import table as table from sqlalchemy.sql import tablesample as tablesample from sqlalchemy.sql import text as text diff --git a/sqlmodel/main.py b/sqlmodel/main.py index 2b69dd2a75..e06b651cfd 100644 --- a/sqlmodel/main.py +++ b/sqlmodel/main.py @@ -642,6 +642,7 @@ class SQLModel(BaseModel, metaclass=SQLModelMetaclass, registry=default_registry __sqlmodel_relationships__: ClassVar[Dict[str, RelationshipProperty]] # type: ignore __name__: ClassVar[str] metadata: ClassVar[MetaData] + __allow_unmapped__ = True # https://docs.sqlalchemy.org/en/20/changelog/migration_20.html#migration-20-step-six class Config: orm_mode = True diff --git a/sqlmodel/sql/expression.py b/sqlmodel/sql/expression.py index 264e39cba7..29e7524ce7 100644 --- a/sqlmodel/sql/expression.py +++ b/sqlmodel/sql/expression.py @@ -117,12 +117,12 @@ class SelectOfScalar(_Select, Generic[_TSelect]): @overload -def select(entity_0: _TScalar_0, **kw: Any) -> SelectOfScalar[_TScalar_0]: # type: ignore +def select(entity_0: _TScalar_0) -> SelectOfScalar[_TScalar_0]: # type: ignore ... @overload -def select(entity_0: Type[_TModel_0], **kw: Any) -> SelectOfScalar[_TModel_0]: # type: ignore +def select(entity_0: Type[_TModel_0]) -> SelectOfScalar[_TModel_0]: # type: ignore ... @@ -133,7 +133,6 @@ def select(entity_0: Type[_TModel_0], **kw: Any) -> SelectOfScalar[_TModel_0]: def select( # type: ignore entity_0: _TScalar_0, entity_1: _TScalar_1, - **kw: Any, ) -> Select[Tuple[_TScalar_0, _TScalar_1]]: ... @@ -142,7 +141,6 @@ def select( # type: ignore def select( # type: ignore entity_0: _TScalar_0, entity_1: Type[_TModel_1], - **kw: Any, ) -> Select[Tuple[_TScalar_0, _TModel_1]]: ... @@ -151,7 +149,6 @@ def select( # type: ignore def select( # type: ignore entity_0: Type[_TModel_0], entity_1: _TScalar_1, - **kw: Any, ) -> Select[Tuple[_TModel_0, _TScalar_1]]: ... @@ -160,7 +157,6 @@ def select( # type: ignore def select( # type: ignore entity_0: Type[_TModel_0], entity_1: Type[_TModel_1], - **kw: Any, ) -> Select[Tuple[_TModel_0, _TModel_1]]: ... @@ -170,7 +166,6 @@ def select( # type: ignore entity_0: _TScalar_0, entity_1: _TScalar_1, entity_2: _TScalar_2, - **kw: Any, ) -> Select[Tuple[_TScalar_0, _TScalar_1, _TScalar_2]]: ... @@ -180,7 +175,6 @@ def select( # type: ignore entity_0: _TScalar_0, entity_1: _TScalar_1, entity_2: Type[_TModel_2], - **kw: Any, ) -> Select[Tuple[_TScalar_0, _TScalar_1, _TModel_2]]: ... @@ -190,7 +184,6 @@ def select( # type: ignore entity_0: _TScalar_0, entity_1: Type[_TModel_1], entity_2: _TScalar_2, - **kw: Any, ) -> Select[Tuple[_TScalar_0, _TModel_1, _TScalar_2]]: ... @@ -200,7 +193,6 @@ def select( # type: ignore entity_0: _TScalar_0, entity_1: Type[_TModel_1], entity_2: Type[_TModel_2], - **kw: Any, ) -> Select[Tuple[_TScalar_0, _TModel_1, _TModel_2]]: ... @@ -210,7 +202,6 @@ def select( # type: ignore entity_0: Type[_TModel_0], entity_1: _TScalar_1, entity_2: _TScalar_2, - **kw: Any, ) -> Select[Tuple[_TModel_0, _TScalar_1, _TScalar_2]]: ... @@ -220,7 +211,6 @@ def select( # type: ignore entity_0: Type[_TModel_0], entity_1: _TScalar_1, entity_2: Type[_TModel_2], - **kw: Any, ) -> Select[Tuple[_TModel_0, _TScalar_1, _TModel_2]]: ... @@ -230,7 +220,6 @@ def select( # type: ignore entity_0: Type[_TModel_0], entity_1: Type[_TModel_1], entity_2: _TScalar_2, - **kw: Any, ) -> Select[Tuple[_TModel_0, _TModel_1, _TScalar_2]]: ... @@ -240,7 +229,6 @@ def select( # type: ignore entity_0: Type[_TModel_0], entity_1: Type[_TModel_1], entity_2: Type[_TModel_2], - **kw: Any, ) -> Select[Tuple[_TModel_0, _TModel_1, _TModel_2]]: ... @@ -251,7 +239,6 @@ def select( # type: ignore entity_1: _TScalar_1, entity_2: _TScalar_2, entity_3: _TScalar_3, - **kw: Any, ) -> Select[Tuple[_TScalar_0, _TScalar_1, _TScalar_2, _TScalar_3]]: ... @@ -262,7 +249,6 @@ def select( # type: ignore entity_1: _TScalar_1, entity_2: _TScalar_2, entity_3: Type[_TModel_3], - **kw: Any, ) -> Select[Tuple[_TScalar_0, _TScalar_1, _TScalar_2, _TModel_3]]: ... @@ -273,7 +259,6 @@ def select( # type: ignore entity_1: _TScalar_1, entity_2: Type[_TModel_2], entity_3: _TScalar_3, - **kw: Any, ) -> Select[Tuple[_TScalar_0, _TScalar_1, _TModel_2, _TScalar_3]]: ... @@ -284,7 +269,6 @@ def select( # type: ignore entity_1: _TScalar_1, entity_2: Type[_TModel_2], entity_3: Type[_TModel_3], - **kw: Any, ) -> Select[Tuple[_TScalar_0, _TScalar_1, _TModel_2, _TModel_3]]: ... @@ -295,7 +279,6 @@ def select( # type: ignore entity_1: Type[_TModel_1], entity_2: _TScalar_2, entity_3: _TScalar_3, - **kw: Any, ) -> Select[Tuple[_TScalar_0, _TModel_1, _TScalar_2, _TScalar_3]]: ... @@ -306,7 +289,6 @@ def select( # type: ignore entity_1: Type[_TModel_1], entity_2: _TScalar_2, entity_3: Type[_TModel_3], - **kw: Any, ) -> Select[Tuple[_TScalar_0, _TModel_1, _TScalar_2, _TModel_3]]: ... @@ -317,7 +299,6 @@ def select( # type: ignore entity_1: Type[_TModel_1], entity_2: Type[_TModel_2], entity_3: _TScalar_3, - **kw: Any, ) -> Select[Tuple[_TScalar_0, _TModel_1, _TModel_2, _TScalar_3]]: ... @@ -328,7 +309,6 @@ def select( # type: ignore entity_1: Type[_TModel_1], entity_2: Type[_TModel_2], entity_3: Type[_TModel_3], - **kw: Any, ) -> Select[Tuple[_TScalar_0, _TModel_1, _TModel_2, _TModel_3]]: ... @@ -339,7 +319,6 @@ def select( # type: ignore entity_1: _TScalar_1, entity_2: _TScalar_2, entity_3: _TScalar_3, - **kw: Any, ) -> Select[Tuple[_TModel_0, _TScalar_1, _TScalar_2, _TScalar_3]]: ... @@ -350,7 +329,6 @@ def select( # type: ignore entity_1: _TScalar_1, entity_2: _TScalar_2, entity_3: Type[_TModel_3], - **kw: Any, ) -> Select[Tuple[_TModel_0, _TScalar_1, _TScalar_2, _TModel_3]]: ... @@ -361,7 +339,6 @@ def select( # type: ignore entity_1: _TScalar_1, entity_2: Type[_TModel_2], entity_3: _TScalar_3, - **kw: Any, ) -> Select[Tuple[_TModel_0, _TScalar_1, _TModel_2, _TScalar_3]]: ... @@ -372,7 +349,6 @@ def select( # type: ignore entity_1: _TScalar_1, entity_2: Type[_TModel_2], entity_3: Type[_TModel_3], - **kw: Any, ) -> Select[Tuple[_TModel_0, _TScalar_1, _TModel_2, _TModel_3]]: ... @@ -383,7 +359,6 @@ def select( # type: ignore entity_1: Type[_TModel_1], entity_2: _TScalar_2, entity_3: _TScalar_3, - **kw: Any, ) -> Select[Tuple[_TModel_0, _TModel_1, _TScalar_2, _TScalar_3]]: ... @@ -394,7 +369,6 @@ def select( # type: ignore entity_1: Type[_TModel_1], entity_2: _TScalar_2, entity_3: Type[_TModel_3], - **kw: Any, ) -> Select[Tuple[_TModel_0, _TModel_1, _TScalar_2, _TModel_3]]: ... @@ -405,7 +379,6 @@ def select( # type: ignore entity_1: Type[_TModel_1], entity_2: Type[_TModel_2], entity_3: _TScalar_3, - **kw: Any, ) -> Select[Tuple[_TModel_0, _TModel_1, _TModel_2, _TScalar_3]]: ... @@ -416,7 +389,6 @@ def select( # type: ignore entity_1: Type[_TModel_1], entity_2: Type[_TModel_2], entity_3: Type[_TModel_3], - **kw: Any, ) -> Select[Tuple[_TModel_0, _TModel_1, _TModel_2, _TModel_3]]: ... @@ -424,10 +396,10 @@ def select( # type: ignore # Generated overloads end -def select(*entities: Any, **kw: Any) -> Union[Select, SelectOfScalar]: # type: ignore +def select(*entities: Any) -> Union[Select, SelectOfScalar]: # type: ignore if len(entities) == 1: - return SelectOfScalar._create(*entities, **kw) # type: ignore - return Select._create(*entities, **kw) # type: ignore + return SelectOfScalar(*entities) # type: ignore + return Select(*entities) # type: ignore # TODO: add several @overload from Python types to SQLAlchemy equivalents From 8a27655052dfdc7a6317da72895df5a46a7760b1 Mon Sep 17 00:00:00 2001 From: Mohamed Farahat Date: Fri, 24 Mar 2023 14:56:26 +0200 Subject: [PATCH 002/105] fix some linting errors --- sqlmodel/engine/create.py | 2 +- sqlmodel/engine/result.py | 20 ++++++++++---------- sqlmodel/main.py | 4 ++-- sqlmodel/orm/session.py | 2 +- sqlmodel/sql/expression.py | 2 +- sqlmodel/sql/sqltypes.py | 6 +++--- 6 files changed, 18 insertions(+), 18 deletions(-) diff --git a/sqlmodel/engine/create.py b/sqlmodel/engine/create.py index b2d567b1b1..97481259e2 100644 --- a/sqlmodel/engine/create.py +++ b/sqlmodel/engine/create.py @@ -136,4 +136,4 @@ def create_engine( if not isinstance(query_cache_size, _DefaultPlaceholder): current_kwargs["query_cache_size"] = query_cache_size current_kwargs.update(kwargs) - return _create_engine(url, **current_kwargs) # type: ignore + return _create_engine(url, **current_kwargs) diff --git a/sqlmodel/engine/result.py b/sqlmodel/engine/result.py index 7a25422227..17020d9995 100644 --- a/sqlmodel/engine/result.py +++ b/sqlmodel/engine/result.py @@ -1,4 +1,4 @@ -from typing import Generic, Iterator, List, Optional, TypeVar +from typing import Generic, Iterator, List, Optional, Sequence, TypeVar from sqlalchemy.engine.result import Result as _Result from sqlalchemy.engine.result import ScalarResult as _ScalarResult @@ -6,24 +6,24 @@ _T = TypeVar("_T") -class ScalarResult(_ScalarResult, Generic[_T]): - def all(self) -> List[_T]: +class ScalarResult(_ScalarResult[_T], Generic[_T]): + def all(self) -> Sequence[_T]: return super().all() - def partitions(self, size: Optional[int] = None) -> Iterator[List[_T]]: + def partitions(self, size: Optional[int] = None) -> Iterator[Sequence[_T]]: return super().partitions(size) - def fetchall(self) -> List[_T]: + def fetchall(self) -> Sequence[_T]: return super().fetchall() - def fetchmany(self, size: Optional[int] = None) -> List[_T]: + def fetchmany(self, size: Optional[int] = None) -> Sequence[_T]: return super().fetchmany(size) def __iter__(self) -> Iterator[_T]: return super().__iter__() def __next__(self) -> _T: - return super().__next__() # type: ignore + return super().__next__() def first(self) -> Optional[_T]: return super().first() @@ -32,10 +32,10 @@ def one_or_none(self) -> Optional[_T]: return super().one_or_none() def one(self) -> _T: - return super().one() # type: ignore + return super().one() -class Result(_Result, Generic[_T]): +class Result(_Result[_T], Generic[_T]): def scalars(self, index: int = 0) -> ScalarResult[_T]: return super().scalars(index) # type: ignore @@ -76,4 +76,4 @@ def one(self) -> _T: # type: ignore return super().one() # type: ignore def scalar(self) -> Optional[_T]: - return super().scalar() + return super().scalar() # type: ignore diff --git a/sqlmodel/main.py b/sqlmodel/main.py index e06b651cfd..06e58618aa 100644 --- a/sqlmodel/main.py +++ b/sqlmodel/main.py @@ -642,7 +642,7 @@ class SQLModel(BaseModel, metaclass=SQLModelMetaclass, registry=default_registry __sqlmodel_relationships__: ClassVar[Dict[str, RelationshipProperty]] # type: ignore __name__: ClassVar[str] metadata: ClassVar[MetaData] - __allow_unmapped__ = True # https://docs.sqlalchemy.org/en/20/changelog/migration_20.html#migration-20-step-six + __allow_unmapped__ = True # https://docs.sqlalchemy.org/en/20/changelog/migration_20.html#migration-20-step-six class Config: orm_mode = True @@ -686,7 +686,7 @@ def __setattr__(self, name: str, value: Any) -> None: return else: # Set in SQLAlchemy, before Pydantic to trigger events and updates - if getattr(self.__config__, "table", False) and is_instrumented(self, name): + if getattr(self.__config__, "table", False) and is_instrumented(self, name): # type: ignore set_attribute(self, name, value) # Set in Pydantic model to trigger possible validation changes, only for # non relationship values diff --git a/sqlmodel/orm/session.py b/sqlmodel/orm/session.py index 0c70c290ae..03f33037ff 100644 --- a/sqlmodel/orm/session.py +++ b/sqlmodel/orm/session.py @@ -118,7 +118,7 @@ def query(self, *entities: Any, **kwargs: Any) -> "_Query[Any]": Or otherwise you might want to use `session.execute()` instead of `session.query()`. """ - return super().query(*entities, **kwargs) + return super().query(*entities, **kwargs) # type: ignore def get( self, diff --git a/sqlmodel/sql/expression.py b/sqlmodel/sql/expression.py index 29e7524ce7..f473ba4a5a 100644 --- a/sqlmodel/sql/expression.py +++ b/sqlmodel/sql/expression.py @@ -406,4 +406,4 @@ def select(*entities: Any) -> Union[Select, SelectOfScalar]: # type: ignore def col(column_expression: Any) -> ColumnClause: # type: ignore if not isinstance(column_expression, (ColumnClause, Column, InstrumentedAttribute)): raise RuntimeError(f"Not a SQLAlchemy column: {column_expression}") - return column_expression + return column_expression # type: ignore diff --git a/sqlmodel/sql/sqltypes.py b/sqlmodel/sql/sqltypes.py index 17d9b06126..aa30950702 100644 --- a/sqlmodel/sql/sqltypes.py +++ b/sqlmodel/sql/sqltypes.py @@ -15,7 +15,7 @@ class AutoString(types.TypeDecorator): # type: ignore def load_dialect_impl(self, dialect: Dialect) -> "types.TypeEngine[Any]": impl = cast(types.String, self.impl) if impl.length is None and dialect.name == "mysql": - return dialect.type_descriptor(types.String(self.mysql_default_length)) # type: ignore + return dialect.type_descriptor(types.String(self.mysql_default_length)) return super().load_dialect_impl(dialect) @@ -34,9 +34,9 @@ class GUID(types.TypeDecorator): # type: ignore def load_dialect_impl(self, dialect: Dialect) -> TypeEngine: # type: ignore if dialect.name == "postgresql": - return dialect.type_descriptor(UUID()) # type: ignore + return dialect.type_descriptor(UUID()) else: - return dialect.type_descriptor(CHAR(32)) # type: ignore + return dialect.type_descriptor(CHAR(32)) def process_bind_param(self, value: Any, dialect: Dialect) -> Optional[str]: if value is None: From a2b3c1465892acf5b3309110a791cb99520e5f8d Mon Sep 17 00:00:00 2001 From: Mohamed Farahat Date: Fri, 31 Mar 2023 12:59:56 +0200 Subject: [PATCH 003/105] reflecting python 3.6 deprecation in docs and tests --- .github/workflows/test.yml | 7 ++----- docs/contributing.md | 4 ++++ docs/features.md | 2 +- docs/tutorial/index.md | 3 +++ pyproject.toml | 1 + 5 files changed, 11 insertions(+), 6 deletions(-) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 201abc7c22..d2d045aeb9 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -20,11 +20,7 @@ jobs: runs-on: ubuntu-latest strategy: matrix: - python-version: - - "3.7" - - "3.8" - - "3.9" - - "3.10" + python-version: ["3.7", "3.8", "3.9", "3.10", "3.11"] fail-fast: false steps: @@ -56,6 +52,7 @@ jobs: if: steps.cache.outputs.cache-hit != 'true' run: python -m poetry install - name: Lint + if: ${{ matrix.python-version != '3.7' }} run: python -m poetry run bash scripts/lint.sh - run: mkdir coverage - name: Test diff --git a/docs/contributing.md b/docs/contributing.md index 217ed61c56..6a1ae2d6c7 100644 --- a/docs/contributing.md +++ b/docs/contributing.md @@ -6,6 +6,10 @@ First, you might want to see the basic ways to [help SQLModel and get help](help If you already cloned the repository and you know that you need to deep dive in the code, here are some guidelines to set up your environment. +### Python + +SQLModel supports Python 3.7 and above, but for development you should have at least **Python 3.7**. + ### Poetry **SQLModel** uses Poetry to build, package, and publish the project. diff --git a/docs/features.md b/docs/features.md index 102edef725..2d5e11d84f 100644 --- a/docs/features.md +++ b/docs/features.md @@ -12,7 +12,7 @@ Nevertheless, SQLModel is completely **independent** of FastAPI and can be used ## Just Modern Python -It's all based on standard modern **Python** type annotations. No new syntax to learn. Just standard modern Python. +It's all based on standard modern **Python** type annotations. No new syntax to learn. Just standard modern Python. If you need a 2 minute refresher of how to use Python types (even if you don't use SQLModel or FastAPI), check the FastAPI tutorial section: Python types intro. diff --git a/docs/tutorial/index.md b/docs/tutorial/index.md index 74107776c2..5a4333df11 100644 --- a/docs/tutorial/index.md +++ b/docs/tutorial/index.md @@ -64,6 +64,8 @@ $ cd sqlmodel-tutorial Make sure you have an officially supported version of Python. +Currently it is **Python 3.7** and above (Python 3.6 was already deprecated). + You can check which version you have with:
@@ -82,6 +84,7 @@ You might want to try with the specific versions, for example with: * `python3.10` * `python3.9` * `python3.8` +* `python3.7` The code would look like this: diff --git a/pyproject.toml b/pyproject.toml index 23fa79bf31..6777d69018 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -21,6 +21,7 @@ classifiers = [ "Programming Language :: Python :: 3.8", "Programming Language :: Python :: 3.9", "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", "Topic :: Database", "Topic :: Database :: Database Engines/Servers", "Topic :: Internet", From 4760dbde8c88d296aab3fde2a9cf3c7974cee381 Mon Sep 17 00:00:00 2001 From: farahats9 Date: Fri, 31 Mar 2023 12:58:30 +0200 Subject: [PATCH 004/105] Update sqlmodel/sql/expression.py Co-authored-by: Stefan Borer --- sqlmodel/sql/expression.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/sqlmodel/sql/expression.py b/sqlmodel/sql/expression.py index f473ba4a5a..10776e9389 100644 --- a/sqlmodel/sql/expression.py +++ b/sqlmodel/sql/expression.py @@ -22,8 +22,7 @@ _TSelect = TypeVar("_TSelect") - -class Select(_Select, Generic[_TSelect]): +class Select(_Select[_TSelect], Generic[_TSelect]): inherit_cache = True From ad76a886ccadaabd6d8617c8cf2eda119e25284d Mon Sep 17 00:00:00 2001 From: Mohamed Farahat Date: Fri, 31 Mar 2023 13:18:38 +0200 Subject: [PATCH 005/105] resolving @sbor23 comments --- .github/workflows/test.yml | 1 - docs/contributing.md | 2 +- docs/tutorial/index.md | 1 + 3 files changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index d2d045aeb9..9f2e06ed71 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -52,7 +52,6 @@ jobs: if: steps.cache.outputs.cache-hit != 'true' run: python -m poetry install - name: Lint - if: ${{ matrix.python-version != '3.7' }} run: python -m poetry run bash scripts/lint.sh - run: mkdir coverage - name: Test diff --git a/docs/contributing.md b/docs/contributing.md index 6a1ae2d6c7..5c160a7c0a 100644 --- a/docs/contributing.md +++ b/docs/contributing.md @@ -8,7 +8,7 @@ If you already cloned the repository and you know that you need to deep dive in ### Python -SQLModel supports Python 3.7 and above, but for development you should have at least **Python 3.7**. +SQLModel supports Python 3.7 and above. ### Poetry diff --git a/docs/tutorial/index.md b/docs/tutorial/index.md index 5a4333df11..54e1147d68 100644 --- a/docs/tutorial/index.md +++ b/docs/tutorial/index.md @@ -81,6 +81,7 @@ There's a chance that you have multiple Python versions installed. You might want to try with the specific versions, for example with: +* `python3.11` * `python3.10` * `python3.9` * `python3.8` From 96e44e5cadebc0d32b209ab275f61294743ab7e4 Mon Sep 17 00:00:00 2001 From: Mohamed Farahat Date: Sun, 30 Apr 2023 17:59:23 +0300 Subject: [PATCH 006/105] add the new Subquery class --- sqlmodel/__init__.py | 1 + 1 file changed, 1 insertion(+) diff --git a/sqlmodel/__init__.py b/sqlmodel/__init__.py index af37421754..4431d7cea7 100644 --- a/sqlmodel/__init__.py +++ b/sqlmodel/__init__.py @@ -70,6 +70,7 @@ from sqlalchemy.sql import outerjoin as outerjoin from sqlalchemy.sql import outparam as outparam from sqlalchemy.sql import over as over +from sqlalchemy.sql import Subquery as Subquery from sqlalchemy.sql import table as table from sqlalchemy.sql import tablesample as tablesample from sqlalchemy.sql import text as text From 9e72750eea8ada4d270013bf56db5b38a6b9a709 Mon Sep 17 00:00:00 2001 From: Mohamed Farahat Date: Sun, 30 Apr 2023 18:42:46 +0300 Subject: [PATCH 007/105] fix jinja2 template --- sqlmodel/sql/expression.py | 1 + sqlmodel/sql/expression.py.jinja2 | 25 +++++++++++++++++-------- 2 files changed, 18 insertions(+), 8 deletions(-) diff --git a/sqlmodel/sql/expression.py b/sqlmodel/sql/expression.py index 10776e9389..32ec7c428b 100644 --- a/sqlmodel/sql/expression.py +++ b/sqlmodel/sql/expression.py @@ -22,6 +22,7 @@ _TSelect = TypeVar("_TSelect") + class Select(_Select[_TSelect], Generic[_TSelect]): inherit_cache = True diff --git a/sqlmodel/sql/expression.py.jinja2 b/sqlmodel/sql/expression.py.jinja2 index 26d12a0395..9bda190415 100644 --- a/sqlmodel/sql/expression.py.jinja2 +++ b/sqlmodel/sql/expression.py.jinja2 @@ -20,7 +20,14 @@ from sqlalchemy.sql.expression import Select as _Select _TSelect = TypeVar("_TSelect") -class Select(_Select, Generic[_TSelect]): +class Select(_Select[_TSelect], Generic[_TSelect]): + inherit_cache = True + +# This is not comparable to sqlalchemy.sql.selectable.ScalarSelect, that has a different +# purpose. This is the same as a normal SQLAlchemy Select class where there's only one +# entity, so the result will be converted to a scalar by default. This way writing +# for loops on the results will feel natural. +class SelectOfScalar(_Select[_TSelect], Generic[_TSelect]): inherit_cache = True # This is not comparable to sqlalchemy.sql.selectable.ScalarSelect, that has a different @@ -35,6 +42,7 @@ if TYPE_CHECKING: # pragma: no cover # Generated TypeVars start + {% for i in range(number_of_types) %} _TScalar_{{ i }} = TypeVar( "_TScalar_{{ i }}", @@ -58,12 +66,12 @@ _TModel_{{ i }} = TypeVar("_TModel_{{ i }}", bound="SQLModel") # Generated TypeVars end @overload -def select(entity_0: _TScalar_0, **kw: Any) -> SelectOfScalar[_TScalar_0]: # type: ignore +def select(entity_0: _TScalar_0) -> SelectOfScalar[_TScalar_0]: # type: ignore ... @overload -def select(entity_0: Type[_TModel_0], **kw: Any) -> SelectOfScalar[_TModel_0]: # type: ignore +def select(entity_0: Type[_TModel_0]) -> SelectOfScalar[_TModel_0]: # type: ignore ... @@ -73,7 +81,7 @@ def select(entity_0: Type[_TModel_0], **kw: Any) -> SelectOfScalar[_TModel_0]: @overload def select( # type: ignore - {% for arg in signature[0] %}{{ arg.name }}: {{ arg.annotation }}, {% endfor %}**kw: Any, + {% for arg in signature[0] %}{{ arg.name }}: {{ arg.annotation }}, {% endfor %} ) -> Select[Tuple[{%for ret in signature[1] %}{{ ret }} {% if not loop.last %}, {% endif %}{% endfor %}]]: ... @@ -81,14 +89,15 @@ def select( # type: ignore # Generated overloads end -def select(*entities: Any, **kw: Any) -> Union[Select, SelectOfScalar]: # type: ignore + +def select(*entities: Any) -> Union[Select, SelectOfScalar]: # type: ignore if len(entities) == 1: - return SelectOfScalar._create(*entities, **kw) # type: ignore - return Select._create(*entities, **kw) # type: ignore + return SelectOfScalar(*entities) # type: ignore + return Select(*entities) # type: ignore # TODO: add several @overload from Python types to SQLAlchemy equivalents def col(column_expression: Any) -> ColumnClause: # type: ignore if not isinstance(column_expression, (ColumnClause, Column, InstrumentedAttribute)): raise RuntimeError(f"Not a SQLAlchemy column: {column_expression}") - return column_expression + return column_expression # type: ignore From b19a70961ffa9d7c8673bc19d5732c1c8351e436 Mon Sep 17 00:00:00 2001 From: Peter Landry Date: Wed, 26 Jul 2023 13:37:18 +0200 Subject: [PATCH 008/105] `Result` expects a type `Tuple[_T]` --- sqlmodel/engine/result.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sqlmodel/engine/result.py b/sqlmodel/engine/result.py index 17020d9995..a0ddf283b9 100644 --- a/sqlmodel/engine/result.py +++ b/sqlmodel/engine/result.py @@ -1,4 +1,4 @@ -from typing import Generic, Iterator, List, Optional, Sequence, TypeVar +from typing import Generic, Iterator, List, Optional, Sequence, Tuple, TypeVar from sqlalchemy.engine.result import Result as _Result from sqlalchemy.engine.result import ScalarResult as _ScalarResult @@ -35,7 +35,7 @@ def one(self) -> _T: return super().one() -class Result(_Result[_T], Generic[_T]): +class Result(_Result[Tuple[_T]], Generic[_T]): def scalars(self, index: int = 0) -> ScalarResult[_T]: return super().scalars(index) # type: ignore From ae369ed23a58ed47a053d11e7263f9a9ca4bf7b9 Mon Sep 17 00:00:00 2001 From: Peter Landry Date: Wed, 26 Jul 2023 13:37:39 +0200 Subject: [PATCH 009/105] Remove unused type ignore --- sqlmodel/engine/result.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sqlmodel/engine/result.py b/sqlmodel/engine/result.py index a0ddf283b9..ecdb6cd547 100644 --- a/sqlmodel/engine/result.py +++ b/sqlmodel/engine/result.py @@ -67,7 +67,7 @@ def one_or_none(self) -> Optional[_T]: # type: ignore return super().one_or_none() # type: ignore def scalar_one(self) -> _T: - return super().scalar_one() # type: ignore + return super().scalar_one() def scalar_one_or_none(self) -> Optional[_T]: return super().scalar_one_or_none() @@ -76,4 +76,4 @@ def one(self) -> _T: # type: ignore return super().one() # type: ignore def scalar(self) -> Optional[_T]: - return super().scalar() # type: ignore + return super().scalar() From 2ff42db5d9f06447a2d0b25371ef39fb75c04b21 Mon Sep 17 00:00:00 2001 From: Peter Landry Date: Wed, 26 Jul 2023 14:05:06 +0200 Subject: [PATCH 010/105] Result seems well enough typed in SqlAlchemy now we can simply shim over --- sqlmodel/engine/result.py | 44 ++------------------------------------- 1 file changed, 2 insertions(+), 42 deletions(-) diff --git a/sqlmodel/engine/result.py b/sqlmodel/engine/result.py index ecdb6cd547..650dd92b27 100644 --- a/sqlmodel/engine/result.py +++ b/sqlmodel/engine/result.py @@ -1,4 +1,4 @@ -from typing import Generic, Iterator, List, Optional, Sequence, Tuple, TypeVar +from typing import Generic, Iterator, Optional, Sequence, Tuple, TypeVar from sqlalchemy.engine.result import Result as _Result from sqlalchemy.engine.result import ScalarResult as _ScalarResult @@ -36,44 +36,4 @@ def one(self) -> _T: class Result(_Result[Tuple[_T]], Generic[_T]): - def scalars(self, index: int = 0) -> ScalarResult[_T]: - return super().scalars(index) # type: ignore - - def __iter__(self) -> Iterator[_T]: # type: ignore - return super().__iter__() # type: ignore - - def __next__(self) -> _T: # type: ignore - return super().__next__() # type: ignore - - def partitions(self, size: Optional[int] = None) -> Iterator[List[_T]]: # type: ignore - return super().partitions(size) # type: ignore - - def fetchall(self) -> List[_T]: # type: ignore - return super().fetchall() # type: ignore - - def fetchone(self) -> Optional[_T]: # type: ignore - return super().fetchone() # type: ignore - - def fetchmany(self, size: Optional[int] = None) -> List[_T]: # type: ignore - return super().fetchmany() # type: ignore - - def all(self) -> List[_T]: # type: ignore - return super().all() # type: ignore - - def first(self) -> Optional[_T]: # type: ignore - return super().first() # type: ignore - - def one_or_none(self) -> Optional[_T]: # type: ignore - return super().one_or_none() # type: ignore - - def scalar_one(self) -> _T: - return super().scalar_one() - - def scalar_one_or_none(self) -> Optional[_T]: - return super().scalar_one_or_none() - - def one(self) -> _T: # type: ignore - return super().one() # type: ignore - - def scalar(self) -> Optional[_T]: - return super().scalar() + ... From b6bd94f5a7a2551fabc0ddcbecb2abcde0ba28c8 Mon Sep 17 00:00:00 2001 From: Peter Landry Date: Wed, 26 Jul 2023 14:11:13 +0200 Subject: [PATCH 011/105] _Select expects a `Tuple[Any, ...]` --- sqlmodel/sql/expression.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sqlmodel/sql/expression.py b/sqlmodel/sql/expression.py index 32ec7c428b..a0ac1bd9d9 100644 --- a/sqlmodel/sql/expression.py +++ b/sqlmodel/sql/expression.py @@ -23,7 +23,7 @@ _TSelect = TypeVar("_TSelect") -class Select(_Select[_TSelect], Generic[_TSelect]): +class Select(_Select[Tuple[_TSelect]], Generic[_TSelect]): inherit_cache = True @@ -31,7 +31,7 @@ class Select(_Select[_TSelect], Generic[_TSelect]): # purpose. This is the same as a normal SQLAlchemy Select class where there's only one # entity, so the result will be converted to a scalar by default. This way writing # for loops on the results will feel natural. -class SelectOfScalar(_Select, Generic[_TSelect]): +class SelectOfScalar(_Select[Tuple[_TSelect]], Generic[_TSelect]): inherit_cache = True From 1dbce4d3c46e95c0cc8e089f2c936f2931136210 Mon Sep 17 00:00:00 2001 From: Peter Landry Date: Wed, 26 Jul 2023 14:21:03 +0200 Subject: [PATCH 012/105] Use Dict type instead of Mapping for SqlAlchemy compat --- sqlmodel/orm/session.py | 20 +++++++++++++++----- 1 file changed, 15 insertions(+), 5 deletions(-) diff --git a/sqlmodel/orm/session.py b/sqlmodel/orm/session.py index 03f33037ff..21bc0780e7 100644 --- a/sqlmodel/orm/session.py +++ b/sqlmodel/orm/session.py @@ -1,4 +1,14 @@ -from typing import Any, Mapping, Optional, Sequence, Type, TypeVar, Union, overload +from typing import ( + Any, + Dict, + Mapping, + Optional, + Sequence, + Type, + TypeVar, + Union, + overload, +) from sqlalchemy import util from sqlalchemy.orm import Query as _Query @@ -21,7 +31,7 @@ def exec( *, params: Optional[Union[Mapping[str, Any], Sequence[Mapping[str, Any]]]] = None, execution_options: Mapping[str, Any] = util.EMPTY_DICT, - bind_arguments: Optional[Mapping[str, Any]] = None, + bind_arguments: Optional[Dict[str, Any]] = None, _parent_execute_state: Optional[Any] = None, _add_event: Optional[Any] = None, **kw: Any, @@ -35,7 +45,7 @@ def exec( *, params: Optional[Union[Mapping[str, Any], Sequence[Mapping[str, Any]]]] = None, execution_options: Mapping[str, Any] = util.EMPTY_DICT, - bind_arguments: Optional[Mapping[str, Any]] = None, + bind_arguments: Optional[Dict[str, Any]] = None, _parent_execute_state: Optional[Any] = None, _add_event: Optional[Any] = None, **kw: Any, @@ -52,7 +62,7 @@ def exec( *, params: Optional[Union[Mapping[str, Any], Sequence[Mapping[str, Any]]]] = None, execution_options: Mapping[str, Any] = util.EMPTY_DICT, - bind_arguments: Optional[Mapping[str, Any]] = None, + bind_arguments: Optional[Dict[str, Any]] = None, _parent_execute_state: Optional[Any] = None, _add_event: Optional[Any] = None, **kw: Any, @@ -75,7 +85,7 @@ def execute( statement: _Executable, params: Optional[Union[Mapping[str, Any], Sequence[Mapping[str, Any]]]] = None, execution_options: Optional[Mapping[str, Any]] = util.EMPTY_DICT, - bind_arguments: Optional[Mapping[str, Any]] = None, + bind_arguments: Optional[Dict[str, Any]] = None, _parent_execute_state: Optional[Any] = None, _add_event: Optional[Any] = None, **kw: Any, From c2d310e94442fc17bc957067b02b030a374485ce Mon Sep 17 00:00:00 2001 From: Peter Landry Date: Wed, 26 Jul 2023 14:24:01 +0200 Subject: [PATCH 013/105] Execution options are not Optional in SA --- sqlmodel/orm/session.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sqlmodel/orm/session.py b/sqlmodel/orm/session.py index 21bc0780e7..69fc7c68a7 100644 --- a/sqlmodel/orm/session.py +++ b/sqlmodel/orm/session.py @@ -84,7 +84,7 @@ def execute( self, statement: _Executable, params: Optional[Union[Mapping[str, Any], Sequence[Mapping[str, Any]]]] = None, - execution_options: Optional[Mapping[str, Any]] = util.EMPTY_DICT, + execution_options: Mapping[str, Any] = util.EMPTY_DICT, bind_arguments: Optional[Dict[str, Any]] = None, _parent_execute_state: Optional[Any] = None, _add_event: Optional[Any] = None, From 05a1352aa59f3db36707747d95d996405794ca68 Mon Sep 17 00:00:00 2001 From: Peter Landry Date: Wed, 26 Jul 2023 14:26:41 +0200 Subject: [PATCH 014/105] Another instance of non-optional execution_options --- sqlmodel/orm/session.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sqlmodel/orm/session.py b/sqlmodel/orm/session.py index 69fc7c68a7..c467a3b524 100644 --- a/sqlmodel/orm/session.py +++ b/sqlmodel/orm/session.py @@ -138,7 +138,7 @@ def get( populate_existing: bool = False, with_for_update: Optional[Union[Literal[True], Mapping[str, Any]]] = None, identity_token: Optional[Any] = None, - execution_options: Optional[Mapping[Any, Any]] = util.EMPTY_DICT, + execution_options: Mapping[Any, Any] = util.EMPTY_DICT, ) -> Optional[_TSelectParam]: return super().get( entity, From bd00a2b4498f5517df3bee2e9c6ae107856c0b2c Mon Sep 17 00:00:00 2001 From: Peter Landry Date: Wed, 26 Jul 2023 14:49:28 +0200 Subject: [PATCH 015/105] Fix Tuple in jinja template as well --- sqlmodel/sql/expression.py.jinja2 | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sqlmodel/sql/expression.py.jinja2 b/sqlmodel/sql/expression.py.jinja2 index 9bda190415..55f4a1ac3e 100644 --- a/sqlmodel/sql/expression.py.jinja2 +++ b/sqlmodel/sql/expression.py.jinja2 @@ -20,14 +20,14 @@ from sqlalchemy.sql.expression import Select as _Select _TSelect = TypeVar("_TSelect") -class Select(_Select[_TSelect], Generic[_TSelect]): +class Select(_Select[Tuple[_TSelect]], Generic[_TSelect]): inherit_cache = True # This is not comparable to sqlalchemy.sql.selectable.ScalarSelect, that has a different # purpose. This is the same as a normal SQLAlchemy Select class where there's only one # entity, so the result will be converted to a scalar by default. This way writing # for loops on the results will feel natural. -class SelectOfScalar(_Select[_TSelect], Generic[_TSelect]): +class SelectOfScalar(_Select[Tuple[_TSelect]], Generic[_TSelect]): inherit_cache = True # This is not comparable to sqlalchemy.sql.selectable.ScalarSelect, that has a different From c65f018356fac713913aa28b1031d5b28477459c Mon Sep 17 00:00:00 2001 From: Peter Landry Date: Wed, 26 Jul 2023 17:39:35 +0200 Subject: [PATCH 016/105] Use ForUpdateArg from sqlalchemy --- sqlmodel/orm/session.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sqlmodel/orm/session.py b/sqlmodel/orm/session.py index c467a3b524..fc96f7def2 100644 --- a/sqlmodel/orm/session.py +++ b/sqlmodel/orm/session.py @@ -136,7 +136,7 @@ def get( ident: Any, options: Optional[Sequence[Any]] = None, populate_existing: bool = False, - with_for_update: Optional[Union[Literal[True], Mapping[str, Any]]] = None, + with_for_update: Optional[_ForUpdateArg] = None, identity_token: Optional[Any] = None, execution_options: Mapping[Any, Any] = util.EMPTY_DICT, ) -> Optional[_TSelectParam]: From 4e29e002ec803b1c7e32450efb372a6fee34fbed Mon Sep 17 00:00:00 2001 From: Peter Landry Date: Wed, 26 Jul 2023 18:05:49 +0200 Subject: [PATCH 017/105] Fix signature for `Session.get` --- sqlmodel/orm/session.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/sqlmodel/orm/session.py b/sqlmodel/orm/session.py index fc96f7def2..1d12188ce8 100644 --- a/sqlmodel/orm/session.py +++ b/sqlmodel/orm/session.py @@ -11,6 +11,7 @@ ) from sqlalchemy import util +from sqlalchemy.orm import Mapper as _Mapper from sqlalchemy.orm import Query as _Query from sqlalchemy.orm import Session as _Session from sqlalchemy.sql.base import Executable as _Executable @@ -132,13 +133,14 @@ def query(self, *entities: Any, **kwargs: Any) -> "_Query[Any]": def get( self, - entity: Type[_TSelectParam], + entity: Union[Type[_TSelectParam], "_Mapper[_TSelectParam]"], ident: Any, options: Optional[Sequence[Any]] = None, populate_existing: bool = False, with_for_update: Optional[_ForUpdateArg] = None, identity_token: Optional[Any] = None, execution_options: Mapping[Any, Any] = util.EMPTY_DICT, + bind_arguments: Optional[Dict[str, Any]] = None, ) -> Optional[_TSelectParam]: return super().get( entity, @@ -148,4 +150,5 @@ def get( with_for_update=with_for_update, identity_token=identity_token, execution_options=execution_options, + bind_arguments=bind_arguments ) From fd85d02ceb16c8368b7285a5991baa618f9c9742 Mon Sep 17 00:00:00 2001 From: Mohamed Farahat Date: Thu, 27 Jul 2023 23:03:52 +0300 Subject: [PATCH 018/105] formatting and remove unused type --- sqlmodel/orm/session.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sqlmodel/orm/session.py b/sqlmodel/orm/session.py index 1d12188ce8..c9bb043336 100644 --- a/sqlmodel/orm/session.py +++ b/sqlmodel/orm/session.py @@ -150,5 +150,5 @@ def get( with_for_update=with_for_update, identity_token=identity_token, execution_options=execution_options, - bind_arguments=bind_arguments + bind_arguments=bind_arguments, ) From d556059a2ce4abfcf489caf94033a7f1f63e5c0f Mon Sep 17 00:00:00 2001 From: Anton De Meester Date: Mon, 31 Jul 2023 13:38:26 +0000 Subject: [PATCH 019/105] Upgrade to Pydantic 2 Change imports Undefined => PydanticUndefined Update SQLModelMetaclass and SQLModel __init__ and __new__ functions Update SQL Alchemy type inference --- docs/tutorial/fastapi/multiple-models.md | 6 +- .../fastapi/app_testing/tutorial001/main.py | 2 +- .../tutorial/fastapi/delete/tutorial001.py | 2 +- .../fastapi/limit_and_offset/tutorial001.py | 2 +- .../fastapi/multiple_models/tutorial001.py | 2 +- .../fastapi/multiple_models/tutorial002.py | 2 +- .../tutorial/fastapi/read_one/tutorial001.py | 2 +- .../fastapi/relationships/tutorial001.py | 4 +- .../session_with_dependency/tutorial001.py | 2 +- .../tutorial/fastapi/teams/tutorial001.py | 4 +- .../tutorial/fastapi/update/tutorial001.py | 2 +- sqlmodel/main.py | 285 +++++------------- sqlmodel/typing.py | 7 + tests/test_instance_no_args.py | 2 +- tests/test_missing_type.py | 3 +- tests/test_validation.py | 8 +- 16 files changed, 109 insertions(+), 226 deletions(-) create mode 100644 sqlmodel/typing.py diff --git a/docs/tutorial/fastapi/multiple-models.md b/docs/tutorial/fastapi/multiple-models.md index 6845b9862d..4ea24c6752 100644 --- a/docs/tutorial/fastapi/multiple-models.md +++ b/docs/tutorial/fastapi/multiple-models.md @@ -174,13 +174,13 @@ Now we use the type annotation `HeroCreate` for the request JSON data in the `he # Code below omitted 👇 ``` -Then we create a new `Hero` (this is the actual **table** model that saves things to the database) using `Hero.from_orm()`. +Then we create a new `Hero` (this is the actual **table** model that saves things to the database) using `Hero.model_validate()`. -The method `.from_orm()` reads data from another object with attributes and creates a new instance of this class, in this case `Hero`. +The method `.model_validate()` reads data from another object with attributes and creates a new instance of this class, in this case `Hero`. The alternative is `Hero.parse_obj()` that reads data from a dictionary. -But as in this case, we have a `HeroCreate` instance in the `hero` variable. This is an object with attributes, so we use `.from_orm()` to read those attributes. +But as in this case, we have a `HeroCreate` instance in the `hero` variable. This is an object with attributes, so we use `.model_validate()` to read those attributes. With this, we create a new `Hero` instance (the one for the database) and put it in the variable `db_hero` from the data in the `hero` variable that is the `HeroCreate` instance we received from the request. diff --git a/docs_src/tutorial/fastapi/app_testing/tutorial001/main.py b/docs_src/tutorial/fastapi/app_testing/tutorial001/main.py index 3f0602e4b4..a23dfad5a8 100644 --- a/docs_src/tutorial/fastapi/app_testing/tutorial001/main.py +++ b/docs_src/tutorial/fastapi/app_testing/tutorial001/main.py @@ -54,7 +54,7 @@ def on_startup(): @app.post("/heroes/", response_model=HeroRead) def create_hero(*, session: Session = Depends(get_session), hero: HeroCreate): - db_hero = Hero.from_orm(hero) + db_hero = Hero.model_validate(hero) session.add(db_hero) session.commit() session.refresh(db_hero) diff --git a/docs_src/tutorial/fastapi/delete/tutorial001.py b/docs_src/tutorial/fastapi/delete/tutorial001.py index 3069fc5e87..77a99a9c97 100644 --- a/docs_src/tutorial/fastapi/delete/tutorial001.py +++ b/docs_src/tutorial/fastapi/delete/tutorial001.py @@ -50,7 +50,7 @@ def on_startup(): @app.post("/heroes/", response_model=HeroRead) def create_hero(hero: HeroCreate): with Session(engine) as session: - db_hero = Hero.from_orm(hero) + db_hero = Hero.model_validate(hero) session.add(db_hero) session.commit() session.refresh(db_hero) diff --git a/docs_src/tutorial/fastapi/limit_and_offset/tutorial001.py b/docs_src/tutorial/fastapi/limit_and_offset/tutorial001.py index 2b8739ca70..2352f39022 100644 --- a/docs_src/tutorial/fastapi/limit_and_offset/tutorial001.py +++ b/docs_src/tutorial/fastapi/limit_and_offset/tutorial001.py @@ -44,7 +44,7 @@ def on_startup(): @app.post("/heroes/", response_model=HeroRead) def create_hero(hero: HeroCreate): with Session(engine) as session: - db_hero = Hero.from_orm(hero) + db_hero = Hero.model_validate(hero) session.add(db_hero) session.commit() session.refresh(db_hero) diff --git a/docs_src/tutorial/fastapi/multiple_models/tutorial001.py b/docs_src/tutorial/fastapi/multiple_models/tutorial001.py index df20123333..7f59ac6a1d 100644 --- a/docs_src/tutorial/fastapi/multiple_models/tutorial001.py +++ b/docs_src/tutorial/fastapi/multiple_models/tutorial001.py @@ -46,7 +46,7 @@ def on_startup(): @app.post("/heroes/", response_model=HeroRead) def create_hero(hero: HeroCreate): with Session(engine) as session: - db_hero = Hero.from_orm(hero) + db_hero = Hero.model_validate(hero) session.add(db_hero) session.commit() session.refresh(db_hero) diff --git a/docs_src/tutorial/fastapi/multiple_models/tutorial002.py b/docs_src/tutorial/fastapi/multiple_models/tutorial002.py index 392c2c5829..fffbe72496 100644 --- a/docs_src/tutorial/fastapi/multiple_models/tutorial002.py +++ b/docs_src/tutorial/fastapi/multiple_models/tutorial002.py @@ -44,7 +44,7 @@ def on_startup(): @app.post("/heroes/", response_model=HeroRead) def create_hero(hero: HeroCreate): with Session(engine) as session: - db_hero = Hero.from_orm(hero) + db_hero = Hero.model_validate(hero) session.add(db_hero) session.commit() session.refresh(db_hero) diff --git a/docs_src/tutorial/fastapi/read_one/tutorial001.py b/docs_src/tutorial/fastapi/read_one/tutorial001.py index 4d66e471a5..f18426e74c 100644 --- a/docs_src/tutorial/fastapi/read_one/tutorial001.py +++ b/docs_src/tutorial/fastapi/read_one/tutorial001.py @@ -44,7 +44,7 @@ def on_startup(): @app.post("/heroes/", response_model=HeroRead) def create_hero(hero: HeroCreate): with Session(engine) as session: - db_hero = Hero.from_orm(hero) + db_hero = Hero.model_validate(hero) session.add(db_hero) session.commit() session.refresh(db_hero) diff --git a/docs_src/tutorial/fastapi/relationships/tutorial001.py b/docs_src/tutorial/fastapi/relationships/tutorial001.py index 8477e4a2a0..e5b196090e 100644 --- a/docs_src/tutorial/fastapi/relationships/tutorial001.py +++ b/docs_src/tutorial/fastapi/relationships/tutorial001.py @@ -92,7 +92,7 @@ def on_startup(): @app.post("/heroes/", response_model=HeroRead) def create_hero(*, session: Session = Depends(get_session), hero: HeroCreate): - db_hero = Hero.from_orm(hero) + db_hero = Hero.model_validate(hero) session.add(db_hero) session.commit() session.refresh(db_hero) @@ -146,7 +146,7 @@ def delete_hero(*, session: Session = Depends(get_session), hero_id: int): @app.post("/teams/", response_model=TeamRead) def create_team(*, session: Session = Depends(get_session), team: TeamCreate): - db_team = Team.from_orm(team) + db_team = Team.model_validate(team) session.add(db_team) session.commit() session.refresh(db_team) diff --git a/docs_src/tutorial/fastapi/session_with_dependency/tutorial001.py b/docs_src/tutorial/fastapi/session_with_dependency/tutorial001.py index 3f0602e4b4..a23dfad5a8 100644 --- a/docs_src/tutorial/fastapi/session_with_dependency/tutorial001.py +++ b/docs_src/tutorial/fastapi/session_with_dependency/tutorial001.py @@ -54,7 +54,7 @@ def on_startup(): @app.post("/heroes/", response_model=HeroRead) def create_hero(*, session: Session = Depends(get_session), hero: HeroCreate): - db_hero = Hero.from_orm(hero) + db_hero = Hero.model_validate(hero) session.add(db_hero) session.commit() session.refresh(db_hero) diff --git a/docs_src/tutorial/fastapi/teams/tutorial001.py b/docs_src/tutorial/fastapi/teams/tutorial001.py index 1da0dad8a2..cc73bb52cb 100644 --- a/docs_src/tutorial/fastapi/teams/tutorial001.py +++ b/docs_src/tutorial/fastapi/teams/tutorial001.py @@ -83,7 +83,7 @@ def on_startup(): @app.post("/heroes/", response_model=HeroRead) def create_hero(*, session: Session = Depends(get_session), hero: HeroCreate): - db_hero = Hero.from_orm(hero) + db_hero = Hero.model_validate(hero) session.add(db_hero) session.commit() session.refresh(db_hero) @@ -137,7 +137,7 @@ def delete_hero(*, session: Session = Depends(get_session), hero_id: int): @app.post("/teams/", response_model=TeamRead) def create_team(*, session: Session = Depends(get_session), team: TeamCreate): - db_team = Team.from_orm(team) + db_team = Team.model_validate(team) session.add(db_team) session.commit() session.refresh(db_team) diff --git a/docs_src/tutorial/fastapi/update/tutorial001.py b/docs_src/tutorial/fastapi/update/tutorial001.py index bb98efd581..28462bff17 100644 --- a/docs_src/tutorial/fastapi/update/tutorial001.py +++ b/docs_src/tutorial/fastapi/update/tutorial001.py @@ -50,7 +50,7 @@ def on_startup(): @app.post("/heroes/", response_model=HeroRead) def create_hero(hero: HeroCreate): with Session(engine) as session: - db_hero = Hero.from_orm(hero) + db_hero = Hero.model_validate(hero) session.add(db_hero) session.commit() session.refresh(db_hero) diff --git a/sqlmodel/main.py b/sqlmodel/main.py index 06e58618aa..0ca69be152 100644 --- a/sqlmodel/main.py +++ b/sqlmodel/main.py @@ -19,7 +19,7 @@ Set, Tuple, Type, - TypeVar, + TypeVar,ForwardRef, Union, cast, overload, @@ -53,8 +53,10 @@ from sqlalchemy.sql.sqltypes import LargeBinary, Time from .sql.sqltypes import GUID, AutoString +from .typing import SQLModelConfig _T = TypeVar("_T") +NoArgAnyCallable = Callable[[], Any] def __dataclass_transform__( @@ -68,10 +70,10 @@ def __dataclass_transform__( class FieldInfo(PydanticFieldInfo): - def __init__(self, default: Any = Undefined, **kwargs: Any) -> None: + def __init__(self, default: Any = PydanticUndefined, **kwargs: Any) -> None: primary_key = kwargs.pop("primary_key", False) - nullable = kwargs.pop("nullable", Undefined) - foreign_key = kwargs.pop("foreign_key", Undefined) + nullable = kwargs.pop("nullable", PydanticUndefined) + foreign_key = kwargs.pop("foreign_key", PydanticUndefined) unique = kwargs.pop("unique", False) index = kwargs.pop("index", Undefined) sa_type = kwargs.pop("sa_type", Undefined) @@ -84,7 +86,7 @@ def __init__(self, default: Any = Undefined, **kwargs: Any) -> None: "Passing sa_column_args is not supported when " "also passing a sa_column" ) - if sa_column_kwargs is not Undefined: + if sa_column_kwargs is not PydanticUndefined: raise RuntimeError( "Passing sa_column_kwargs is not supported when " "also passing a sa_column" @@ -157,7 +159,7 @@ def __init__( @overload def Field( - default: Any = Undefined, + default: Any = PydanticUndefined, *, default_factory: Optional[NoArgAnyCallable] = None, alias: Optional[str] = None, @@ -314,7 +316,6 @@ def Field( sa_column_kwargs=sa_column_kwargs, **current_schema_extra, ) - field_info._validate() return field_info @@ -343,7 +344,7 @@ def Relationship( *, back_populates: Optional[str] = None, link_model: Optional[Any] = None, - sa_relationship: Optional[RelationshipProperty] = None, # type: ignore + sa_relationship: Optional[RelationshipProperty] = None, sa_relationship_args: Optional[Sequence[Any]] = None, sa_relationship_kwargs: Optional[Mapping[str, Any]] = None, ) -> Any: @@ -360,18 +361,18 @@ def Relationship( @__dataclass_transform__(kw_only_default=True, field_descriptors=(Field, FieldInfo)) class SQLModelMetaclass(ModelMetaclass, DeclarativeMeta): __sqlmodel_relationships__: Dict[str, RelationshipInfo] - __config__: Type[BaseConfig] - __fields__: Dict[str, ModelField] + model_config: Type[SQLModelConfig] + model_fields: Dict[str, FieldInfo] # Replicate SQLAlchemy def __setattr__(cls, name: str, value: Any) -> None: - if getattr(cls.__config__, "table", False): + if cls.model_config.get("table", False): DeclarativeMeta.__setattr__(cls, name, value) else: super().__setattr__(name, value) def __delattr__(cls, name: str) -> None: - if getattr(cls.__config__, "table", False): + if cls.model_config.get("table", False): DeclarativeMeta.__delattr__(cls, name) else: super().__delattr__(name) @@ -384,11 +385,10 @@ def __new__( class_dict: Dict[str, Any], **kwargs: Any, ) -> Any: + relationships: Dict[str, RelationshipInfo] = {} dict_for_pydantic = {} - original_annotations = resolve_annotations( - class_dict.get("__annotations__", {}), class_dict.get("__module__", None) - ) + original_annotations = class_dict.get("__annotations__", {}) pydantic_annotations = {} relationship_annotations = {} for k, v in class_dict.items(): @@ -412,7 +412,7 @@ def __new__( # superclass causing an error allowed_config_kwargs: Set[str] = { key - for key in dir(BaseConfig) + for key in dir(SQLModelConfig) if not ( key.startswith("__") and key.endswith("__") ) # skip dunder methods and attributes @@ -422,38 +422,46 @@ def __new__( key: pydantic_kwargs.pop(key) for key in pydantic_kwargs.keys() & allowed_config_kwargs } + config_table = getattr(class_dict.get('Config', object()), 'table', False) + # If we have a table, we need to have defaults for all fields + # Pydantic v2 sets a __pydantic_core_schema__ which is very hard to change. Changing the fields does not do anything + if config_table is True: + for key in original_annotations.keys(): + if dict_used.get(key, PydanticUndefined) is PydanticUndefined: + dict_used[key] = None + new_cls = super().__new__(cls, name, bases, dict_used, **config_kwargs) new_cls.__annotations__ = { **relationship_annotations, **pydantic_annotations, **new_cls.__annotations__, } - + def get_config(name: str) -> Any: - config_class_value = getattr(new_cls.__config__, name, Undefined) - if config_class_value is not Undefined: + config_class_value = new_cls.model_config.get(name, PydanticUndefined) + if config_class_value is not PydanticUndefined: return config_class_value - kwarg_value = kwargs.get(name, Undefined) - if kwarg_value is not Undefined: + kwarg_value = kwargs.get(name, PydanticUndefined) + if kwarg_value is not PydanticUndefined: return kwarg_value - return Undefined + return PydanticUndefined config_table = get_config("table") if config_table is True: # If it was passed by kwargs, ensure it's also set in config - new_cls.__config__.table = config_table - for k, v in new_cls.__fields__.items(): + new_cls.model_config['table'] = config_table + for k, v in new_cls.model_fields.items(): col = get_column_from_field(v) setattr(new_cls, k, col) # Set a config flag to tell FastAPI that this should be read with a field # in orm_mode instead of preemptively converting it to a dict. - # This could be done by reading new_cls.__config__.table in FastAPI, but + # This could be done by reading new_cls.model_config['table'] in FastAPI, but # that's very specific about SQLModel, so let's have another config that # other future tools based on Pydantic can use. - new_cls.__config__.read_with_orm_mode = True + new_cls.model_config['read_from_attributes'] = True config_registry = get_config("registry") - if config_registry is not Undefined: + if config_registry is not PydanticUndefined: config_registry = cast(registry, config_registry) # If it was passed by kwargs, ensure it's also set in config new_cls.__config__.registry = config_table @@ -484,16 +492,15 @@ def __init__( setattr(cls, rel_name, rel_info.sa_relationship) # Fix #315 continue ann = cls.__annotations__[rel_name] - temp_field = ModelField.infer( - name=rel_name, - value=rel_info, - annotation=ann, - class_validators=None, - config=BaseConfig, - ) - relationship_to = temp_field.type_ - if isinstance(temp_field.type_, ForwardRef): - relationship_to = temp_field.type_.__forward_arg__ + relationship_to = get_origin(ann) + # If Union (Optional), get the real field + if relationship_to is Union: + relationship_to = get_args(ann)[0] + # If a list, then also get the real field + elif relationship_to is list: + relationship_to = get_args(ann)[0] + if isinstance(relationship_to, ForwardRef): + relationship_to = relationship_to.__forward_arg__ rel_kwargs: Dict[str, Any] = {} if rel_info.back_populates: rel_kwargs["back_populates"] = rel_info.back_populates @@ -511,7 +518,7 @@ def __init__( rel_args.extend(rel_info.sa_relationship_args) if rel_info.sa_relationship_kwargs: rel_kwargs.update(rel_info.sa_relationship_kwargs) - rel_value: RelationshipProperty = relationship( # type: ignore + rel_value: RelationshipProperty = relationship( relationship_to, *rel_args, **rel_kwargs ) setattr(cls, rel_name, rel_value) # Fix #315 @@ -571,8 +578,8 @@ def get_sqlalchemy_type(field: ModelField) -> Any: raise ValueError(f"The field {field.name} has no matching SQLAlchemy type") -def get_column_from_field(field: ModelField) -> Column: # type: ignore - sa_column = getattr(field.field_info, "sa_column", Undefined) +def get_column_from_field(field: FieldInfo) -> Column: # type: ignore + sa_column = getattr(field, "sa_column", PydanticUndefined) if isinstance(sa_column, Column): return sa_column sa_type = get_sqlalchemy_type(field) @@ -605,18 +612,18 @@ def get_column_from_field(field: ModelField) -> Column: # type: ignore "index": index, "unique": unique, } - sa_default = Undefined - if field.field_info.default_factory: - sa_default = field.field_info.default_factory - elif field.field_info.default is not Undefined: - sa_default = field.field_info.default - if sa_default is not Undefined: + sa_default = PydanticUndefined + if field.default_factory: + sa_default = field.default_factory + elif field.default is not PydanticUndefined: + sa_default = field.default + if sa_default is not PydanticUndefined: kwargs["default"] = sa_default - sa_column_args = getattr(field.field_info, "sa_column_args", Undefined) - if sa_column_args is not Undefined: + sa_column_args = getattr(field, "sa_column_args", PydanticUndefined) + if sa_column_args is not PydanticUndefined: args.extend(list(cast(Sequence[Any], sa_column_args))) - sa_column_kwargs = getattr(field.field_info, "sa_column_kwargs", Undefined) - if sa_column_kwargs is not Undefined: + sa_column_kwargs = getattr(field, "sa_column_kwargs", PydanticUndefined) + if sa_column_kwargs is not PydanticUndefined: kwargs.update(cast(Dict[Any, Any], sa_column_kwargs)) return Column(sa_type, *args, **kwargs) # type: ignore @@ -625,13 +632,6 @@ def get_column_from_field(field: ModelField) -> Column: # type: ignore default_registry = registry() - -def _value_items_is_true(v: Any) -> bool: - # Re-implement Pydantic's ValueItems.is_true() as it hasn't been released as of - # the current latest, Pydantic 1.8.2 - return v is True or v is ... - - _TSQLModel = TypeVar("_TSQLModel", bound="SQLModel") @@ -639,43 +639,17 @@ class SQLModel(BaseModel, metaclass=SQLModelMetaclass, registry=default_registry # SQLAlchemy needs to set weakref(s), Pydantic will set the other slots values __slots__ = ("__weakref__",) __tablename__: ClassVar[Union[str, Callable[..., str]]] - __sqlmodel_relationships__: ClassVar[Dict[str, RelationshipProperty]] # type: ignore + __sqlmodel_relationships__: ClassVar[Dict[str, RelationshipProperty]] __name__: ClassVar[str] metadata: ClassVar[MetaData] __allow_unmapped__ = True # https://docs.sqlalchemy.org/en/20/changelog/migration_20.html#migration-20-step-six - - class Config: - orm_mode = True - - def __new__(cls, *args: Any, **kwargs: Any) -> Any: - new_object = super().__new__(cls) - # SQLAlchemy doesn't call __init__ on the base class - # Ref: https://docs.sqlalchemy.org/en/14/orm/constructors.html - # Set __fields_set__ here, that would have been set when calling __init__ - # in the Pydantic model so that when SQLAlchemy sets attributes that are - # added (e.g. when querying from DB) to the __fields_set__, this already exists - object.__setattr__(new_object, "__fields_set__", set()) - return new_object + model_config = SQLModelConfig(from_attributes=True) def __init__(__pydantic_self__, **data: Any) -> None: - # Uses something other than `self` the first arg to allow "self" as a - # settable attribute - values, fields_set, validation_error = validate_model( - __pydantic_self__.__class__, data - ) - # Only raise errors if not a SQLModel model - if ( - not getattr(__pydantic_self__.__config__, "table", False) - and validation_error - ): - raise validation_error - # Do not set values as in Pydantic, pass them through setattr, so SQLAlchemy - # can handle them - # object.__setattr__(__pydantic_self__, '__dict__', values) - for key, value in values.items(): - setattr(__pydantic_self__, key, value) - object.__setattr__(__pydantic_self__, "__fields_set__", fields_set) - non_pydantic_keys = data.keys() - values.keys() + old_dict = __pydantic_self__.__dict__.copy() + super().__init__(**data) + __pydantic_self__.__dict__ = old_dict | __pydantic_self__.__dict__ + non_pydantic_keys = data.keys() - __pydantic_self__.model_fields for key in non_pydantic_keys: if key in __pydantic_self__.__sqlmodel_relationships__: setattr(__pydantic_self__, key, data[key]) @@ -686,58 +660,12 @@ def __setattr__(self, name: str, value: Any) -> None: return else: # Set in SQLAlchemy, before Pydantic to trigger events and updates - if getattr(self.__config__, "table", False) and is_instrumented(self, name): # type: ignore + if self.model_config.get("table", False) and is_instrumented(self, name): set_attribute(self, name, value) # Set in Pydantic model to trigger possible validation changes, only for # non relationship values if name not in self.__sqlmodel_relationships__: - super().__setattr__(name, value) - - @classmethod - def from_orm( - cls: Type[_TSQLModel], obj: Any, update: Optional[Dict[str, Any]] = None - ) -> _TSQLModel: - # Duplicated from Pydantic - if not cls.__config__.orm_mode: - raise ConfigError( - "You must have the config attribute orm_mode=True to use from_orm" - ) - obj = {ROOT_KEY: obj} if cls.__custom_root_type__ else cls._decompose_class(obj) - # SQLModel, support update dict - if update is not None: - obj = {**obj, **update} - # End SQLModel support dict - if not getattr(cls.__config__, "table", False): - # If not table, normal Pydantic code - m: _TSQLModel = cls.__new__(cls) - else: - # If table, create the new instance normally to make SQLAlchemy create - # the _sa_instance_state attribute - m = cls() - values, fields_set, validation_error = validate_model(cls, obj) - if validation_error: - raise validation_error - # Updated to trigger SQLAlchemy internal handling - if not getattr(cls.__config__, "table", False): - object.__setattr__(m, "__dict__", values) - else: - for key, value in values.items(): - setattr(m, key, value) - # Continue with standard Pydantic logic - object.__setattr__(m, "__fields_set__", fields_set) - m._init_private_attributes() - return m - - @classmethod - def parse_obj( - cls: Type[_TSQLModel], obj: Any, update: Optional[Dict[str, Any]] = None - ) -> _TSQLModel: - obj = cls._enforce_dict_if_root(obj) - # SQLModel, support update dict - if update is not None: - obj = {**obj, **update} - # End SQLModel support dict - return super().parse_obj(obj) + super(SQLModel, self).__setattr__(name, value) def __repr_args__(self) -> Sequence[Tuple[Optional[str], Any]]: # Don't show SQLAlchemy private attributes @@ -747,78 +675,25 @@ def __repr_args__(self) -> Sequence[Tuple[Optional[str], Any]]: if not (isinstance(k, str) and k.startswith("_sa_")) ] - # From Pydantic, override to enforce validation with dict - @classmethod - def validate(cls: Type[_TSQLModel], value: Any) -> _TSQLModel: - if isinstance(value, cls): - return value.copy() if cls.__config__.copy_on_model_validation else value - - value = cls._enforce_dict_if_root(value) - if isinstance(value, dict): - values, fields_set, validation_error = validate_model(cls, value) - if validation_error: - raise validation_error - model = cls(**value) - # Reset fields set, this would have been done in Pydantic in __init__ - object.__setattr__(model, "__fields_set__", fields_set) - return model - elif cls.__config__.orm_mode: - return cls.from_orm(value) - elif cls.__custom_root_type__: - return cls.parse_obj(value) - else: - try: - value_as_dict = dict(value) - except (TypeError, ValueError) as e: - raise DictError() from e - return cls(**value_as_dict) - - # From Pydantic, override to only show keys from fields, omit SQLAlchemy attributes - def _calculate_keys( - self, - include: Optional[Mapping[Union[int, str], Any]], - exclude: Optional[Mapping[Union[int, str], Any]], - exclude_unset: bool, - update: Optional[Dict[str, Any]] = None, - ) -> Optional[AbstractSet[str]]: - if include is None and exclude is None and not exclude_unset: - # Original in Pydantic: - # return None - # Updated to not return SQLAlchemy attributes - # Do not include relationships as that would easily lead to infinite - # recursion, or traversing the whole database - return self.__fields__.keys() # | self.__sqlmodel_relationships__.keys() - - keys: AbstractSet[str] - if exclude_unset: - keys = self.__fields_set__.copy() - else: - # Original in Pydantic: - # keys = self.__dict__.keys() - # Updated to not return SQLAlchemy attributes - # Do not include relationships as that would easily lead to infinite - # recursion, or traversing the whole database - keys = self.__fields__.keys() # | self.__sqlmodel_relationships__.keys() - if include is not None: - keys &= include.keys() - - if update: - keys -= update.keys() - - if exclude: - keys -= {k for k, v in exclude.items() if _value_items_is_true(v)} - - return keys @declared_attr # type: ignore def __tablename__(cls) -> str: return cls.__name__.lower() -def _is_field_noneable(field: ModelField) -> bool: - if not field.required: - # Taken from [Pydantic](https://github.com/samuelcolvin/pydantic/blob/v1.8.2/pydantic/fields.py#L946-L947) - return field.allow_none and ( - field.shape != SHAPE_SINGLETON or not field.sub_fields - ) +def _is_field_noneable(field: FieldInfo) -> bool: + if not field.is_required(): + if field.annotation is None or field.annotation is type(None): + return True + if get_origin(field.annotation) is Union: + for base in get_args(field.annotation): + if base is type(None): + return True + return False return False + +def _get_field_metadata(field: FieldInfo) -> object: + for meta in field.metadata: + if isinstance(meta, PydanticGeneralMetadata): + return meta + return object() \ No newline at end of file diff --git a/sqlmodel/typing.py b/sqlmodel/typing.py new file mode 100644 index 0000000000..570da2fa55 --- /dev/null +++ b/sqlmodel/typing.py @@ -0,0 +1,7 @@ +from pydantic import ConfigDict +from typing import Optional, Any + +class SQLModelConfig(ConfigDict): + table: Optional[bool] + read_from_attributes: Optional[bool] + registry: Optional[Any] \ No newline at end of file diff --git a/tests/test_instance_no_args.py b/tests/test_instance_no_args.py index 14d560628b..5dc520c54f 100644 --- a/tests/test_instance_no_args.py +++ b/tests/test_instance_no_args.py @@ -8,7 +8,7 @@ def test_allow_instantiation_without_arguments(clear_sqlmodel): class Item(SQLModel): id: Optional[int] = Field(default=None, primary_key=True) - name: str + name: str description: Optional[str] = None class Config: diff --git a/tests/test_missing_type.py b/tests/test_missing_type.py index 2185fa43e9..dd12b2547a 100644 --- a/tests/test_missing_type.py +++ b/tests/test_missing_type.py @@ -2,10 +2,11 @@ import pytest from sqlmodel import Field, SQLModel +from pydantic import BaseModel def test_missing_sql_type(): - class CustomType: + class CustomType(BaseModel): @classmethod def __get_validators__(cls): yield cls.validate diff --git a/tests/test_validation.py b/tests/test_validation.py index ad60fcb945..4183986a06 100644 --- a/tests/test_validation.py +++ b/tests/test_validation.py @@ -1,7 +1,7 @@ from typing import Optional import pytest -from pydantic import validator +from pydantic import field_validator from pydantic.error_wrappers import ValidationError from sqlmodel import SQLModel @@ -22,12 +22,12 @@ class Hero(SQLModel): secret_name: Optional[str] = None age: Optional[int] = None - @validator("name", "secret_name", "age") + @field_validator("name", "secret_name", "age") def reject_none(cls, v): assert v is not None return v - Hero.validate({"age": 25}) + Hero.model_validate({"age": 25}) with pytest.raises(ValidationError): - Hero.validate({"name": None, "age": 25}) + Hero.model_validate({"name": None, "age": 25}) From e92a52eb416e7815a54c23f4792ccfb53bf4fccc Mon Sep 17 00:00:00 2001 From: Anton De Meester Date: Mon, 31 Jul 2023 13:40:23 +0000 Subject: [PATCH 020/105] Formatting --- sqlmodel/main.py | 20 ++++++++++---------- sqlmodel/typing.py | 6 ++++-- tests/test_instance_no_args.py | 2 +- tests/test_missing_type.py | 2 +- 4 files changed, 16 insertions(+), 14 deletions(-) diff --git a/sqlmodel/main.py b/sqlmodel/main.py index 0ca69be152..789a3ac4fc 100644 --- a/sqlmodel/main.py +++ b/sqlmodel/main.py @@ -19,7 +19,7 @@ Set, Tuple, Type, - TypeVar,ForwardRef, + TypeVar, Union, cast, overload, @@ -385,7 +385,7 @@ def __new__( class_dict: Dict[str, Any], **kwargs: Any, ) -> Any: - + relationships: Dict[str, RelationshipInfo] = {} dict_for_pydantic = {} original_annotations = class_dict.get("__annotations__", {}) @@ -422,21 +422,21 @@ def __new__( key: pydantic_kwargs.pop(key) for key in pydantic_kwargs.keys() & allowed_config_kwargs } - config_table = getattr(class_dict.get('Config', object()), 'table', False) + config_table = getattr(class_dict.get("Config", object()), "table", False) # If we have a table, we need to have defaults for all fields # Pydantic v2 sets a __pydantic_core_schema__ which is very hard to change. Changing the fields does not do anything if config_table is True: for key in original_annotations.keys(): if dict_used.get(key, PydanticUndefined) is PydanticUndefined: dict_used[key] = None - + new_cls = super().__new__(cls, name, bases, dict_used, **config_kwargs) new_cls.__annotations__ = { **relationship_annotations, **pydantic_annotations, **new_cls.__annotations__, } - + def get_config(name: str) -> Any: config_class_value = new_cls.model_config.get(name, PydanticUndefined) if config_class_value is not PydanticUndefined: @@ -449,7 +449,7 @@ def get_config(name: str) -> Any: config_table = get_config("table") if config_table is True: # If it was passed by kwargs, ensure it's also set in config - new_cls.model_config['table'] = config_table + new_cls.model_config["table"] = config_table for k, v in new_cls.model_fields.items(): col = get_column_from_field(v) setattr(new_cls, k, col) @@ -458,7 +458,7 @@ def get_config(name: str) -> Any: # This could be done by reading new_cls.model_config['table'] in FastAPI, but # that's very specific about SQLModel, so let's have another config that # other future tools based on Pydantic can use. - new_cls.model_config['read_from_attributes'] = True + new_cls.model_config["read_from_attributes"] = True config_registry = get_config("registry") if config_registry is not PydanticUndefined: @@ -518,7 +518,7 @@ def __init__( rel_args.extend(rel_info.sa_relationship_args) if rel_info.sa_relationship_kwargs: rel_kwargs.update(rel_info.sa_relationship_kwargs) - rel_value: RelationshipProperty = relationship( + rel_value: RelationshipProperty = relationship( relationship_to, *rel_args, **rel_kwargs ) setattr(cls, rel_name, rel_value) # Fix #315 @@ -675,7 +675,6 @@ def __repr_args__(self) -> Sequence[Tuple[Optional[str], Any]]: if not (isinstance(k, str) and k.startswith("_sa_")) ] - @declared_attr # type: ignore def __tablename__(cls) -> str: return cls.__name__.lower() @@ -692,8 +691,9 @@ def _is_field_noneable(field: FieldInfo) -> bool: return False return False + def _get_field_metadata(field: FieldInfo) -> object: for meta in field.metadata: if isinstance(meta, PydanticGeneralMetadata): return meta - return object() \ No newline at end of file + return object() diff --git a/sqlmodel/typing.py b/sqlmodel/typing.py index 570da2fa55..f2c87503c0 100644 --- a/sqlmodel/typing.py +++ b/sqlmodel/typing.py @@ -1,7 +1,9 @@ +from typing import Any, Optional + from pydantic import ConfigDict -from typing import Optional, Any + class SQLModelConfig(ConfigDict): table: Optional[bool] read_from_attributes: Optional[bool] - registry: Optional[Any] \ No newline at end of file + registry: Optional[Any] diff --git a/tests/test_instance_no_args.py b/tests/test_instance_no_args.py index 5dc520c54f..14d560628b 100644 --- a/tests/test_instance_no_args.py +++ b/tests/test_instance_no_args.py @@ -8,7 +8,7 @@ def test_allow_instantiation_without_arguments(clear_sqlmodel): class Item(SQLModel): id: Optional[int] = Field(default=None, primary_key=True) - name: str + name: str description: Optional[str] = None class Config: diff --git a/tests/test_missing_type.py b/tests/test_missing_type.py index dd12b2547a..dc31f053ec 100644 --- a/tests/test_missing_type.py +++ b/tests/test_missing_type.py @@ -1,8 +1,8 @@ from typing import Optional import pytest -from sqlmodel import Field, SQLModel from pydantic import BaseModel +from sqlmodel import Field, SQLModel def test_missing_sql_type(): From 61d5d8dfbe2a5af13b6f41f6ca5d38d471955f47 Mon Sep 17 00:00:00 2001 From: Anton De Meester Date: Mon, 31 Jul 2023 13:57:57 +0000 Subject: [PATCH 021/105] Linting --- sqlmodel/main.py | 23 ++++++++++++----------- sqlmodel/typing.py | 2 +- 2 files changed, 13 insertions(+), 12 deletions(-) diff --git a/sqlmodel/main.py b/sqlmodel/main.py index 789a3ac4fc..3af43b5cf1 100644 --- a/sqlmodel/main.py +++ b/sqlmodel/main.py @@ -14,6 +14,7 @@ ForwardRef, List, Mapping, + NoneType, Optional, Sequence, Set, @@ -344,7 +345,7 @@ def Relationship( *, back_populates: Optional[str] = None, link_model: Optional[Any] = None, - sa_relationship: Optional[RelationshipProperty] = None, + sa_relationship: Optional[RelationshipProperty[Any]] = None, sa_relationship_args: Optional[Sequence[Any]] = None, sa_relationship_kwargs: Optional[Mapping[str, Any]] = None, ) -> Any: @@ -361,7 +362,7 @@ def Relationship( @__dataclass_transform__(kw_only_default=True, field_descriptors=(Field, FieldInfo)) class SQLModelMetaclass(ModelMetaclass, DeclarativeMeta): __sqlmodel_relationships__: Dict[str, RelationshipInfo] - model_config: Type[SQLModelConfig] + model_config: SQLModelConfig model_fields: Dict[str, FieldInfo] # Replicate SQLAlchemy @@ -430,7 +431,9 @@ def __new__( if dict_used.get(key, PydanticUndefined) is PydanticUndefined: dict_used[key] = None - new_cls = super().__new__(cls, name, bases, dict_used, **config_kwargs) + new_cls: Type["SQLModelMetaclass"] = super().__new__( + cls, name, bases, dict_used, **config_kwargs + ) new_cls.__annotations__ = { **relationship_annotations, **pydantic_annotations, @@ -518,7 +521,7 @@ def __init__( rel_args.extend(rel_info.sa_relationship_args) if rel_info.sa_relationship_kwargs: rel_kwargs.update(rel_info.sa_relationship_kwargs) - rel_value: RelationshipProperty = relationship( + rel_value: RelationshipProperty[Any] = relationship( relationship_to, *rel_args, **rel_kwargs ) setattr(cls, rel_name, rel_value) # Fix #315 @@ -612,7 +615,7 @@ def get_column_from_field(field: FieldInfo) -> Column: # type: ignore "index": index, "unique": unique, } - sa_default = PydanticUndefined + sa_default: PydanticUndefinedType | Callable[[], Any] = PydanticUndefined if field.default_factory: sa_default = field.default_factory elif field.default is not PydanticUndefined: @@ -632,14 +635,12 @@ def get_column_from_field(field: FieldInfo) -> Column: # type: ignore default_registry = registry() -_TSQLModel = TypeVar("_TSQLModel", bound="SQLModel") - class SQLModel(BaseModel, metaclass=SQLModelMetaclass, registry=default_registry): # SQLAlchemy needs to set weakref(s), Pydantic will set the other slots values __slots__ = ("__weakref__",) __tablename__: ClassVar[Union[str, Callable[..., str]]] - __sqlmodel_relationships__: ClassVar[Dict[str, RelationshipProperty]] + __sqlmodel_relationships__: ClassVar[Dict[str, RelationshipProperty[Any]]] __name__: ClassVar[str] metadata: ClassVar[MetaData] __allow_unmapped__ = True # https://docs.sqlalchemy.org/en/20/changelog/migration_20.html#migration-20-step-six @@ -660,7 +661,7 @@ def __setattr__(self, name: str, value: Any) -> None: return else: # Set in SQLAlchemy, before Pydantic to trigger events and updates - if self.model_config.get("table", False) and is_instrumented(self, name): + if self.model_config.get("table", False) and is_instrumented(self, name): # type: ignore set_attribute(self, name, value) # Set in Pydantic model to trigger possible validation changes, only for # non relationship values @@ -682,11 +683,11 @@ def __tablename__(cls) -> str: def _is_field_noneable(field: FieldInfo) -> bool: if not field.is_required(): - if field.annotation is None or field.annotation is type(None): + if field.annotation is None or field.annotation is NoneType: return True if get_origin(field.annotation) is Union: for base in get_args(field.annotation): - if base is type(None): + if base is NoneType: return True return False return False diff --git a/sqlmodel/typing.py b/sqlmodel/typing.py index f2c87503c0..8151f99692 100644 --- a/sqlmodel/typing.py +++ b/sqlmodel/typing.py @@ -3,7 +3,7 @@ from pydantic import ConfigDict -class SQLModelConfig(ConfigDict): +class SQLModelConfig(ConfigDict, total=False): table: Optional[bool] read_from_attributes: Optional[bool] registry: Optional[Any] From cb494b763cf3088200749373f3819846e91ab67a Mon Sep 17 00:00:00 2001 From: Anton De Meester Date: Tue, 1 Aug 2023 08:13:30 +0000 Subject: [PATCH 022/105] Make all tests but fastapi work --- sqlmodel/main.py | 33 +++++++++++++++++++++++++++------ 1 file changed, 27 insertions(+), 6 deletions(-) diff --git a/sqlmodel/main.py b/sqlmodel/main.py index 3af43b5cf1..be4623bcac 100644 --- a/sqlmodel/main.py +++ b/sqlmodel/main.py @@ -14,7 +14,6 @@ ForwardRef, List, Mapping, - NoneType, Optional, Sequence, Set, @@ -58,6 +57,7 @@ _T = TypeVar("_T") NoArgAnyCallable = Callable[[], Any] +NoneType = type(None) def __dataclass_transform__( @@ -423,13 +423,17 @@ def __new__( key: pydantic_kwargs.pop(key) for key in pydantic_kwargs.keys() & allowed_config_kwargs } - config_table = getattr(class_dict.get("Config", object()), "table", False) + config_table = getattr(class_dict.get("Config", object()), "table", False) or kwargs.get("table", False) # If we have a table, we need to have defaults for all fields # Pydantic v2 sets a __pydantic_core_schema__ which is very hard to change. Changing the fields does not do anything if config_table is True: - for key in original_annotations.keys(): - if dict_used.get(key, PydanticUndefined) is PydanticUndefined: + for key in pydantic_annotations.keys(): + value = dict_used.get(key, PydanticUndefined) + if value is PydanticUndefined: dict_used[key] = None + elif isinstance(value, FieldInfo): + if value.default is PydanticUndefined and value.default_factory is None: + value.default = None new_cls: Type["SQLModelMetaclass"] = super().__new__( cls, name, bases, dict_used, **config_kwargs @@ -496,8 +500,11 @@ def __init__( continue ann = cls.__annotations__[rel_name] relationship_to = get_origin(ann) - # If Union (Optional), get the real field - if relationship_to is Union: + # Direct relationships (e.g. 'Team' or Team) have None as an origin + if relationship_to is None: + relationship_to = ann + # If Union (e.g. Optional), get the real field + elif relationship_to is Union: relationship_to = get_args(ann)[0] # If a list, then also get the real field elif relationship_to is list: @@ -646,6 +653,16 @@ class SQLModel(BaseModel, metaclass=SQLModelMetaclass, registry=default_registry __allow_unmapped__ = True # https://docs.sqlalchemy.org/en/20/changelog/migration_20.html#migration-20-step-six model_config = SQLModelConfig(from_attributes=True) + def __new__(cls, *args: Any, **kwargs: Any) -> Any: + new_object = super().__new__(cls) + # SQLAlchemy doesn't call __init__ on the base class + # Ref: https://docs.sqlalchemy.org/en/14/orm/constructors.html + # Set __fields_set__ here, that would have been set when calling __init__ + # in the Pydantic model so that when SQLAlchemy sets attributes that are + # added (e.g. when querying from DB) to the __fields_set__, this already exists + object.__setattr__(new_object, "__pydantic_fields_set__", set()) + return new_object + def __init__(__pydantic_self__, **data: Any) -> None: old_dict = __pydantic_self__.__dict__.copy() super().__init__(**data) @@ -680,6 +697,10 @@ def __repr_args__(self) -> Sequence[Tuple[Optional[str], Any]]: def __tablename__(cls) -> str: return cls.__name__.lower() + @classmethod + def model_validate(cls, *args, **kwargs): + return super().model_validate(*args, **kwargs) + def _is_field_noneable(field: FieldInfo) -> bool: if not field.is_required(): From f590548b42438ed56c4d37c2f32882d9c601fd59 Mon Sep 17 00:00:00 2001 From: Anton De Meester Date: Tue, 1 Aug 2023 09:33:55 +0000 Subject: [PATCH 023/105] Get all tests except for openapi working --- sqlmodel/main.py | 32 ++++++++++++++++++++++++++++---- 1 file changed, 28 insertions(+), 4 deletions(-) diff --git a/sqlmodel/main.py b/sqlmodel/main.py index be4623bcac..236c38100e 100644 --- a/sqlmodel/main.py +++ b/sqlmodel/main.py @@ -423,7 +423,9 @@ def __new__( key: pydantic_kwargs.pop(key) for key in pydantic_kwargs.keys() & allowed_config_kwargs } - config_table = getattr(class_dict.get("Config", object()), "table", False) or kwargs.get("table", False) + config_table = getattr( + class_dict.get("Config", object()), "table", False + ) or kwargs.get("table", False) # If we have a table, we need to have defaults for all fields # Pydantic v2 sets a __pydantic_core_schema__ which is very hard to change. Changing the fields does not do anything if config_table is True: @@ -432,7 +434,12 @@ def __new__( if value is PydanticUndefined: dict_used[key] = None elif isinstance(value, FieldInfo): - if value.default is PydanticUndefined and value.default_factory is None: + if ( + value.default in (PydanticUndefined, Ellipsis) + ) and value.default_factory is None: + value.original_default = ( + value.default + ) # So we can check for nullable value.default = None new_cls: Type["SQLModelMetaclass"] = super().__new__( @@ -641,6 +648,7 @@ def get_column_from_field(field: FieldInfo) -> Column: # type: ignore class_registry = weakref.WeakValueDictionary() # type: ignore default_registry = registry() +_TSQLModel = TypeVar("_TSQLModel", bound="SQLModel") class SQLModel(BaseModel, metaclass=SQLModelMetaclass, registry=default_registry): @@ -698,12 +706,28 @@ def __tablename__(cls) -> str: return cls.__name__.lower() @classmethod - def model_validate(cls, *args, **kwargs): - return super().model_validate(*args, **kwargs) + def model_validate( + cls: type[_TSQLModel], + obj: Any, + *, + strict: bool | None = None, + from_attributes: bool | None = None, + context: dict[str, Any] | None = None, + ) -> _TSQLModel: + # Somehow model validate doesn't call __init__ so it would remove our init logic + validated = super().model_validate( + obj, strict=strict, from_attributes=from_attributes, context=context + ) + return cls(**{key: value for key, value in validated}) def _is_field_noneable(field: FieldInfo) -> bool: + if getattr(field, "nullable", PydanticUndefined) is not PydanticUndefined: + return field.nullable if not field.is_required(): + default = getattr(field, "original_default", field.default) + if default is PydanticUndefined: + return False if field.annotation is None or field.annotation is NoneType: return True if get_origin(field.annotation) is Union: From 7ecbc38b58460e6167a57cd06212b7bc32d853d6 Mon Sep 17 00:00:00 2001 From: Anton De Meester Date: Mon, 13 Nov 2023 13:13:29 +0000 Subject: [PATCH 024/105] Write for pydantic v1 and v2 compat --- .github/workflows/test.yml | 7 + sqlmodel/__init__.py | 2 +- sqlmodel/compat.py | 169 +++++++++++++++ sqlmodel/main.py | 420 +++++++++++++++++++++++++------------ sqlmodel/orm/session.py | 4 +- sqlmodel/sql/expression.py | 8 + sqlmodel/typing.py | 9 - tests/conftest.py | 5 + 8 files changed, 478 insertions(+), 146 deletions(-) create mode 100644 sqlmodel/compat.py delete mode 100644 sqlmodel/typing.py diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 9f2e06ed71..979a08b2b2 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -21,6 +21,7 @@ jobs: strategy: matrix: python-version: ["3.7", "3.8", "3.9", "3.10", "3.11"] + pydantic-version: ["pydantic-v1", "pydantic-v2"] fail-fast: false steps: @@ -51,6 +52,12 @@ jobs: - name: Install Dependencies if: steps.cache.outputs.cache-hit != 'true' run: python -m poetry install + - name: Install Pydantic v1 + if: matrix.pydantic-version == 'pydantic-v1' + run: pip install "pydantic>=1.10.0,<2.0.0" + - name: Install Pydantic v2 + if: matrix.pydantic-version == 'pydantic-v2' + run: pip install "pydantic>=2.0.2,<3.0.0" - name: Lint run: python -m poetry run bash scripts/lint.sh - run: mkdir coverage diff --git a/sqlmodel/__init__.py b/sqlmodel/__init__.py index 4431d7cea7..7e20e1ba41 100644 --- a/sqlmodel/__init__.py +++ b/sqlmodel/__init__.py @@ -30,6 +30,7 @@ from sqlalchemy.sql import ( LABEL_STYLE_TABLENAME_PLUS_COL as LABEL_STYLE_TABLENAME_PLUS_COL, ) +from sqlalchemy.sql import Subquery as Subquery from sqlalchemy.sql import alias as alias from sqlalchemy.sql import all_ as all_ from sqlalchemy.sql import and_ as and_ @@ -70,7 +71,6 @@ from sqlalchemy.sql import outerjoin as outerjoin from sqlalchemy.sql import outparam as outparam from sqlalchemy.sql import over as over -from sqlalchemy.sql import Subquery as Subquery from sqlalchemy.sql import table as table from sqlalchemy.sql import tablesample as tablesample from sqlalchemy.sql import text as text diff --git a/sqlmodel/compat.py b/sqlmodel/compat.py new file mode 100644 index 0000000000..41abbef251 --- /dev/null +++ b/sqlmodel/compat.py @@ -0,0 +1,169 @@ +from typing import ( + TYPE_CHECKING, + Any, + Callable, + Dict, + ForwardRef, + Optional, + Type, + TypeVar, + Union, + get_args, + get_origin, +) + +from pydantic import VERSION as PYDANTIC_VERSION + +IS_PYDANTIC_V2 = int(PYDANTIC_VERSION.split(".")[0]) >= 2 + + +if IS_PYDANTIC_V2: + from pydantic import ConfigDict + from pydantic_core import PydanticUndefined as PydanticUndefined, PydanticUndefinedType as PydanticUndefinedType # noqa +else: + from pydantic import BaseConfig # noqa + from pydantic.fields import ModelField # noqa + from pydantic.fields import Undefined as PydanticUndefined, UndefinedType as PydanticUndefinedType # noqa + +if TYPE_CHECKING: + from .main import FieldInfo, RelationshipInfo, SQLModel, SQLModelMetaclass + + +NoArgAnyCallable = Callable[[], Any] +T = TypeVar("T") +InstanceOrType = Union[T, Type[T]] + +if IS_PYDANTIC_V2: + + class SQLModelConfig(ConfigDict, total=False): + table: Optional[bool] + read_from_attributes: Optional[bool] + registry: Optional[Any] + +else: + + class SQLModelConfig(BaseConfig): + table: Optional[bool] = None + read_from_attributes: Optional[bool] = None + registry: Optional[Any] = None + + def __getitem__(self, item: str) -> Any: + return self.__getattr__(item) + + def __setitem__(self, item: str, value: Any) -> None: + return self.__setattr__(item, value) + + +# Inspired from https://github.com/roman-right/beanie/blob/main/beanie/odm/utils/pydantic.py +def get_model_config(model: type) -> Optional[SQLModelConfig]: + if IS_PYDANTIC_V2: + return getattr(model, "model_config", None) + else: + return getattr(model, "Config", None) + + +def get_config_value( + model: InstanceOrType["SQLModel"], parameter: str, default: Any = None +) -> Any: + if IS_PYDANTIC_V2: + return model.model_config.get(parameter, default) + else: + return getattr(model.Config, parameter, default) + + +def set_config_value( + model: InstanceOrType["SQLModel"], parameter: str, value: Any, v1_parameter: str = None +) -> None: + if IS_PYDANTIC_V2: + model.model_config[parameter] = value # type: ignore + else: + model.Config[v1_parameter or parameter] = value # type: ignore + + +def get_model_fields(model: InstanceOrType["SQLModel"]) -> Dict[str, "FieldInfo"]: + if IS_PYDANTIC_V2: + return model.model_fields # type: ignore + else: + return model.__fields__ # type: ignore + + +def get_fields_set(model: InstanceOrType["SQLModel"]) -> set[str]: + if IS_PYDANTIC_V2: + return model.__pydantic_fields_set__ + else: + return model.__fields_set__ # type: ignore + + +def set_fields_set( + new_object: InstanceOrType["SQLModel"], fields: set["FieldInfo"] +) -> None: + if IS_PYDANTIC_V2: + object.__setattr__(new_object, "__pydantic_fields_set__", fields) + else: + object.__setattr__(new_object, "__fields_set__", fields) + + +def set_attribute_mode(cls: Type["SQLModelMetaclass"]) -> None: + if IS_PYDANTIC_V2: + cls.model_config["read_from_attributes"] = True + else: + cls.__config__.read_with_orm_mode = True # type: ignore + + +def get_relationship_to( + name: str, + rel_info: "RelationshipInfo", + annotation: Any, +) -> Any: + if IS_PYDANTIC_V2: + relationship_to = get_origin(annotation) + # Direct relationships (e.g. 'Team' or Team) have None as an origin + if relationship_to is None: + relationship_to = annotation + # If Union (e.g. Optional), get the real field + elif relationship_to is Union: + relationship_to = get_args(annotation)[0] + # If a list, then also get the real field + elif relationship_to is list: + relationship_to = get_args(annotation)[0] + if isinstance(relationship_to, ForwardRef): + relationship_to = relationship_to.__forward_arg__ + return relationship_to + else: + temp_field = ModelField.infer( + name=name, + value=rel_info, + annotation=annotation, + class_validators=None, + config=SQLModelConfig, + ) + relationship_to = temp_field.type_ + if isinstance(temp_field.type_, ForwardRef): + relationship_to = temp_field.type_.__forward_arg__ + return relationship_to + + +def set_empty_defaults(annotations: Dict[str, Any], class_dict: Dict[str, Any]) -> None: + """ + Pydantic v2 without required fields with no optionals cannot do empty initialisations. + This means we cannot do Table() and set fields later. + We go around this by adding a default to everything, being None + + Args: + annotations: Dict[str, Any]: The annotations to provide to pydantic + class_dict: Dict[str, Any]: The class dict for the defaults + """ + if IS_PYDANTIC_V2: + from .main import FieldInfo + # Pydantic v2 sets a __pydantic_core_schema__ which is very hard to change. Changing the fields does not do anything + for key in annotations.keys(): + value = class_dict.get(key, PydanticUndefined) + if value is PydanticUndefined: + class_dict[key] = None + elif isinstance(value, FieldInfo): + if ( + value.default in (PydanticUndefined, Ellipsis) + ) and value.default_factory is None: + # So we can check for nullable + value.original_default = value.default + value.default = None diff --git a/sqlmodel/main.py b/sqlmodel/main.py index 236c38100e..d5b69b28a8 100644 --- a/sqlmodel/main.py +++ b/sqlmodel/main.py @@ -5,13 +5,13 @@ from decimal import Decimal from enum import Enum from pathlib import Path +from types import NoneType from typing import ( AbstractSet, Any, Callable, ClassVar, Dict, - ForwardRef, List, Mapping, Optional, @@ -22,16 +22,16 @@ TypeVar, Union, cast, + get_args, + get_origin, overload, ) -from pydantic import BaseConfig, BaseModel -from pydantic.errors import ConfigError, DictError -from pydantic.fields import SHAPE_SINGLETON, ModelField, Undefined, UndefinedType +from pydantic import BaseModel +from pydantic.fields import SHAPE_SINGLETON, ModelField from pydantic.fields import FieldInfo as PydanticFieldInfo -from pydantic.main import ModelMetaclass, validate_model -from pydantic.typing import NoArgAnyCallable, resolve_annotations -from pydantic.utils import ROOT_KEY, Representation +from pydantic.main import ModelMetaclass +from pydantic.utils import Representation from sqlalchemy import ( Boolean, Column, @@ -52,12 +52,27 @@ from sqlalchemy.sql.schema import MetaData from sqlalchemy.sql.sqltypes import LargeBinary, Time +from .compat import ( + IS_PYDANTIC_V2, + NoArgAnyCallable, + PydanticUndefined, + PydanticUndefinedType, + SQLModelConfig, + get_config_value, + get_model_fields, + get_relationship_to, + set_config_value, + set_empty_defaults, + set_fields_set, +) from .sql.sqltypes import GUID, AutoString -from .typing import SQLModelConfig + +if not IS_PYDANTIC_V2: + from pydantic.errors import ConfigError, DictError + from pydantic.main import validate_model + from pydantic.utils import ROOT_KEY _T = TypeVar("_T") -NoArgAnyCallable = Callable[[], Any] -NoneType = type(None) def __dataclass_transform__( @@ -76,13 +91,13 @@ def __init__(self, default: Any = PydanticUndefined, **kwargs: Any) -> None: nullable = kwargs.pop("nullable", PydanticUndefined) foreign_key = kwargs.pop("foreign_key", PydanticUndefined) unique = kwargs.pop("unique", False) - index = kwargs.pop("index", Undefined) - sa_type = kwargs.pop("sa_type", Undefined) - sa_column = kwargs.pop("sa_column", Undefined) - sa_column_args = kwargs.pop("sa_column_args", Undefined) - sa_column_kwargs = kwargs.pop("sa_column_kwargs", Undefined) - if sa_column is not Undefined: - if sa_column_args is not Undefined: + index = kwargs.pop("index", PydanticUndefined) + sa_type = kwargs.pop("sa_type", PydanticUndefined) + sa_column = kwargs.pop("sa_column", PydanticUndefined) + sa_column_args = kwargs.pop("sa_column_args", PydanticUndefined) + sa_column_kwargs = kwargs.pop("sa_column_kwargs", PydanticUndefined) + if sa_column is not PydanticUndefined: + if sa_column_args is not PydanticUndefined: raise RuntimeError( "Passing sa_column_args is not supported when " "also passing a sa_column" @@ -92,29 +107,29 @@ def __init__(self, default: Any = PydanticUndefined, **kwargs: Any) -> None: "Passing sa_column_kwargs is not supported when " "also passing a sa_column" ) - if primary_key is not Undefined: + if primary_key is not PydanticUndefined: raise RuntimeError( "Passing primary_key is not supported when " "also passing a sa_column" ) - if nullable is not Undefined: + if nullable is not PydanticUndefined: raise RuntimeError( "Passing nullable is not supported when " "also passing a sa_column" ) - if foreign_key is not Undefined: + if foreign_key is not PydanticUndefined: raise RuntimeError( "Passing foreign_key is not supported when " "also passing a sa_column" ) - if unique is not Undefined: + if unique is not PydanticUndefined: raise RuntimeError( "Passing unique is not supported when also passing a sa_column" ) - if index is not Undefined: + if index is not PydanticUndefined: raise RuntimeError( "Passing index is not supported when also passing a sa_column" ) - if sa_type is not Undefined: + if sa_type is not PydanticUndefined: raise RuntimeError( "Passing sa_type is not supported when also passing a sa_column" ) @@ -128,6 +143,7 @@ def __init__(self, default: Any = PydanticUndefined, **kwargs: Any) -> None: self.sa_column = sa_column self.sa_column_args = sa_column_args self.sa_column_kwargs = sa_column_kwargs + self.original_default = PydanticUndefined class RelationshipInfo(Representation): @@ -189,14 +205,16 @@ def Field( regex: Optional[str] = None, discriminator: Optional[str] = None, repr: bool = True, - primary_key: Union[bool, UndefinedType] = Undefined, - foreign_key: Any = Undefined, - unique: Union[bool, UndefinedType] = Undefined, - nullable: Union[bool, UndefinedType] = Undefined, - index: Union[bool, UndefinedType] = Undefined, - sa_type: Union[Type[Any], UndefinedType] = Undefined, - sa_column_args: Union[Sequence[Any], UndefinedType] = Undefined, - sa_column_kwargs: Union[Mapping[str, Any], UndefinedType] = Undefined, + primary_key: Union[bool, PydanticUndefinedType] = PydanticUndefined, + foreign_key: Any = PydanticUndefined, + unique: Union[bool, PydanticUndefinedType] = PydanticUndefined, + nullable: Union[bool, PydanticUndefinedType] = PydanticUndefined, + index: Union[bool, PydanticUndefinedType] = PydanticUndefined, + sa_type: Union[Type[Any], PydanticUndefinedType] = PydanticUndefined, + sa_column_args: Union[Sequence[Any], PydanticUndefinedType] = PydanticUndefined, + sa_column_kwargs: Union[ + Mapping[str, Any], PydanticUndefinedType + ] = PydanticUndefined, schema_extra: Optional[Dict[str, Any]] = None, ) -> Any: ... @@ -204,7 +222,7 @@ def Field( @overload def Field( - default: Any = Undefined, + default: Any = PydanticUndefined, *, default_factory: Optional[NoArgAnyCallable] = None, alias: Optional[str] = None, @@ -233,14 +251,14 @@ def Field( regex: Optional[str] = None, discriminator: Optional[str] = None, repr: bool = True, - sa_column: Union[Column, UndefinedType] = Undefined, # type: ignore + sa_column: Union[Column, PydanticUndefinedType] = PydanticUndefined, # type: ignore schema_extra: Optional[Dict[str, Any]] = None, ) -> Any: ... def Field( - default: Any = Undefined, + default: Any = PydanticUndefined, *, default_factory: Optional[NoArgAnyCallable] = None, alias: Optional[str] = None, @@ -269,15 +287,17 @@ def Field( regex: Optional[str] = None, discriminator: Optional[str] = None, repr: bool = True, - primary_key: Union[bool, UndefinedType] = Undefined, - foreign_key: Any = Undefined, - unique: Union[bool, UndefinedType] = Undefined, - nullable: Union[bool, UndefinedType] = Undefined, - index: Union[bool, UndefinedType] = Undefined, - sa_type: Union[Type[Any], UndefinedType] = Undefined, - sa_column: Union[Column, UndefinedType] = Undefined, # type: ignore - sa_column_args: Union[Sequence[Any], UndefinedType] = Undefined, - sa_column_kwargs: Union[Mapping[str, Any], UndefinedType] = Undefined, + primary_key: Union[bool, PydanticUndefinedType] = PydanticUndefined, + foreign_key: Any = PydanticUndefined, + unique: Union[bool, PydanticUndefinedType] = PydanticUndefined, + nullable: Union[bool, PydanticUndefinedType] = PydanticUndefined, + index: Union[bool, PydanticUndefinedType] = PydanticUndefined, + sa_type: Union[Type[Any], PydanticUndefinedType] = PydanticUndefined, + sa_column: Union[Column, PydanticUndefinedType] = PydanticUndefined, # type: ignore + sa_column_args: Union[Sequence[Any], PydanticUndefinedType] = PydanticUndefined, + sa_column_kwargs: Union[ + Mapping[str, Any], PydanticUndefinedType + ] = PydanticUndefined, schema_extra: Optional[Dict[str, Any]] = None, ) -> Any: current_schema_extra = schema_extra or {} @@ -345,7 +365,7 @@ def Relationship( *, back_populates: Optional[str] = None, link_model: Optional[Any] = None, - sa_relationship: Optional[RelationshipProperty[Any]] = None, + sa_relationship: Optional[RelationshipProperty] = None, sa_relationship_args: Optional[Sequence[Any]] = None, sa_relationship_kwargs: Optional[Mapping[str, Any]] = None, ) -> Any: @@ -362,18 +382,22 @@ def Relationship( @__dataclass_transform__(kw_only_default=True, field_descriptors=(Field, FieldInfo)) class SQLModelMetaclass(ModelMetaclass, DeclarativeMeta): __sqlmodel_relationships__: Dict[str, RelationshipInfo] - model_config: SQLModelConfig - model_fields: Dict[str, FieldInfo] + if IS_PYDANTIC_V2: + model_config: SQLModelConfig + model_fields: Dict[str, FieldInfo] + else: + __config__: Type[SQLModelConfig] + __fields__: Dict[str, FieldInfo] # Replicate SQLAlchemy def __setattr__(cls, name: str, value: Any) -> None: - if cls.model_config.get("table", False): + if get_config_value(cls, "table", False): DeclarativeMeta.__setattr__(cls, name, value) else: super().__setattr__(name, value) def __delattr__(cls, name: str) -> None: - if cls.model_config.get("table", False): + if get_config_value(cls, "table", False): DeclarativeMeta.__delattr__(cls, name) else: super().__delattr__(name) @@ -386,7 +410,6 @@ def __new__( class_dict: Dict[str, Any], **kwargs: Any, ) -> Any: - relationships: Dict[str, RelationshipInfo] = {} dict_for_pydantic = {} original_annotations = class_dict.get("__annotations__", {}) @@ -426,21 +449,8 @@ def __new__( config_table = getattr( class_dict.get("Config", object()), "table", False ) or kwargs.get("table", False) - # If we have a table, we need to have defaults for all fields - # Pydantic v2 sets a __pydantic_core_schema__ which is very hard to change. Changing the fields does not do anything - if config_table is True: - for key in pydantic_annotations.keys(): - value = dict_used.get(key, PydanticUndefined) - if value is PydanticUndefined: - dict_used[key] = None - elif isinstance(value, FieldInfo): - if ( - value.default in (PydanticUndefined, Ellipsis) - ) and value.default_factory is None: - value.original_default = ( - value.default - ) # So we can check for nullable - value.default = None + if config_table: + set_empty_defaults(pydantic_annotations, dict_used) new_cls: Type["SQLModelMetaclass"] = super().__new__( cls, name, bases, dict_used, **config_kwargs @@ -452,7 +462,7 @@ def __new__( } def get_config(name: str) -> Any: - config_class_value = new_cls.model_config.get(name, PydanticUndefined) + config_class_value = get_config_value(new_cls, name, PydanticUndefined) if config_class_value is not PydanticUndefined: return config_class_value kwarg_value = kwargs.get(name, PydanticUndefined) @@ -463,8 +473,8 @@ def get_config(name: str) -> Any: config_table = get_config("table") if config_table is True: # If it was passed by kwargs, ensure it's also set in config - new_cls.model_config["table"] = config_table - for k, v in new_cls.model_fields.items(): + set_config_value(new_cls, "table", config_table) + for k, v in get_model_fields(new_cls).items(): col = get_column_from_field(v) setattr(new_cls, k, col) # Set a config flag to tell FastAPI that this should be read with a field @@ -472,13 +482,15 @@ def get_config(name: str) -> Any: # This could be done by reading new_cls.model_config['table'] in FastAPI, but # that's very specific about SQLModel, so let's have another config that # other future tools based on Pydantic can use. - new_cls.model_config["read_from_attributes"] = True + set_config_value( + new_cls, "read_from_attributes", True, v1_parameter="read_with_orm_mode" + ) config_registry = get_config("registry") if config_registry is not PydanticUndefined: config_registry = cast(registry, config_registry) # If it was passed by kwargs, ensure it's also set in config - new_cls.__config__.registry = config_table + set_config_value(new_cls, "registry", config_table) setattr(new_cls, "_sa_registry", config_registry) # noqa: B010 setattr(new_cls, "metadata", config_registry.metadata) # noqa: B010 setattr(new_cls, "__abstract__", True) # noqa: B010 @@ -506,18 +518,7 @@ def __init__( setattr(cls, rel_name, rel_info.sa_relationship) # Fix #315 continue ann = cls.__annotations__[rel_name] - relationship_to = get_origin(ann) - # Direct relationships (e.g. 'Team' or Team) have None as an origin - if relationship_to is None: - relationship_to = ann - # If Union (e.g. Optional), get the real field - elif relationship_to is Union: - relationship_to = get_args(ann)[0] - # If a list, then also get the real field - elif relationship_to is list: - relationship_to = get_args(ann)[0] - if isinstance(relationship_to, ForwardRef): - relationship_to = relationship_to.__forward_arg__ + relationship_to = get_relationship_to(rel_name, rel_info, ann) rel_kwargs: Dict[str, Any] = {} if rel_info.back_populates: rel_kwargs["back_populates"] = rel_info.back_populates @@ -535,7 +536,7 @@ def __init__( rel_args.extend(rel_info.sa_relationship_args) if rel_info.sa_relationship_kwargs: rel_kwargs.update(rel_info.sa_relationship_kwargs) - rel_value: RelationshipProperty[Any] = relationship( + rel_value: RelationshipProperty = relationship( relationship_to, *rel_args, **rel_kwargs ) setattr(cls, rel_name, rel_value) # Fix #315 @@ -548,8 +549,8 @@ def __init__( def get_sqlalchemy_type(field: ModelField) -> Any: - sa_type = getattr(field.field_info, "sa_type", Undefined) # noqa: B009 - if sa_type is not Undefined: + sa_type = getattr(field.field_info, "sa_type", PydanticUndefined) # noqa: B009 + if sa_type is not PydanticUndefined: return sa_type if isinstance(field.type_, type) and field.shape == SHAPE_SINGLETON: # Check enums first as an enum can also be a str, needed by Pydantic/FastAPI @@ -595,30 +596,30 @@ def get_sqlalchemy_type(field: ModelField) -> Any: raise ValueError(f"The field {field.name} has no matching SQLAlchemy type") -def get_column_from_field(field: FieldInfo) -> Column: # type: ignore - sa_column = getattr(field, "sa_column", PydanticUndefined) +def get_column_from_field(field: ModelField) -> Column: # type: ignore + sa_column = getattr(field.field_info, "sa_column", PydanticUndefined) if isinstance(sa_column, Column): return sa_column sa_type = get_sqlalchemy_type(field) - primary_key = getattr(field.field_info, "primary_key", Undefined) - if primary_key is Undefined: + primary_key = getattr(field.field_info, "primary_key", PydanticUndefined) + if primary_key is PydanticUndefined: primary_key = False - index = getattr(field.field_info, "index", Undefined) - if index is Undefined: + index = getattr(field.field_info, "index", PydanticUndefined) + if index is PydanticUndefined: index = False nullable = not primary_key and _is_field_noneable(field) # Override derived nullability if the nullable property is set explicitly # on the field - field_nullable = getattr(field.field_info, "nullable", Undefined) # noqa: B009 - if field_nullable != Undefined: - assert not isinstance(field_nullable, UndefinedType) + field_nullable = getattr(field.field_info, "nullable", PydanticUndefined) # noqa: B009 + if field_nullable != PydanticUndefined: + assert not isinstance(field_nullable, PydanticUndefinedType) nullable = field_nullable args = [] - foreign_key = getattr(field.field_info, "foreign_key", Undefined) - if foreign_key is Undefined: + foreign_key = getattr(field.field_info, "foreign_key", PydanticUndefined) + if foreign_key is PydanticUndefined: foreign_key = None - unique = getattr(field.field_info, "unique", Undefined) - if unique is Undefined: + unique = getattr(field.field_info, "unique", PydanticUndefined) + if unique is PydanticUndefined: unique = False if foreign_key: assert isinstance(foreign_key, str) @@ -629,17 +630,17 @@ def get_column_from_field(field: FieldInfo) -> Column: # type: ignore "index": index, "unique": unique, } - sa_default: PydanticUndefinedType | Callable[[], Any] = PydanticUndefined - if field.default_factory: - sa_default = field.default_factory - elif field.default is not PydanticUndefined: - sa_default = field.default + sa_default = PydanticUndefined + if field.field_info.default_factory: + sa_default = field.field_info.default_factory + elif field.field_info.default is not PydanticUndefined: + sa_default = field.field_info.default if sa_default is not PydanticUndefined: kwargs["default"] = sa_default - sa_column_args = getattr(field, "sa_column_args", PydanticUndefined) + sa_column_args = getattr(field.field_info, "sa_column_args", PydanticUndefined) if sa_column_args is not PydanticUndefined: args.extend(list(cast(Sequence[Any], sa_column_args))) - sa_column_kwargs = getattr(field, "sa_column_kwargs", PydanticUndefined) + sa_column_kwargs = getattr(field.field_info, "sa_column_kwargs", PydanticUndefined) if sa_column_kwargs is not PydanticUndefined: kwargs.update(cast(Dict[Any, Any], sa_column_kwargs)) return Column(sa_type, *args, **kwargs) # type: ignore @@ -648,6 +649,14 @@ def get_column_from_field(field: FieldInfo) -> Column: # type: ignore class_registry = weakref.WeakValueDictionary() # type: ignore default_registry = registry() + + +def _value_items_is_true(v: Any) -> bool: + # Re-implement Pydantic's ValueItems.is_true() as it hasn't been released as of + # the current latest, Pydantic 1.8.2 + return v is True or v is ... + + _TSQLModel = TypeVar("_TSQLModel", bound="SQLModel") @@ -655,11 +664,17 @@ class SQLModel(BaseModel, metaclass=SQLModelMetaclass, registry=default_registry # SQLAlchemy needs to set weakref(s), Pydantic will set the other slots values __slots__ = ("__weakref__",) __tablename__: ClassVar[Union[str, Callable[..., str]]] - __sqlmodel_relationships__: ClassVar[Dict[str, RelationshipProperty[Any]]] + __sqlmodel_relationships__: ClassVar[Dict[str, RelationshipProperty]] __name__: ClassVar[str] metadata: ClassVar[MetaData] __allow_unmapped__ = True # https://docs.sqlalchemy.org/en/20/changelog/migration_20.html#migration-20-step-six - model_config = SQLModelConfig(from_attributes=True) + + if IS_PYDANTIC_V2: + model_config = SQLModelConfig(from_attributes=True) + else: + + class Config: + orm_mode = True def __new__(cls, *args: Any, **kwargs: Any) -> Any: new_object = super().__new__(cls) @@ -668,14 +683,35 @@ def __new__(cls, *args: Any, **kwargs: Any) -> Any: # Set __fields_set__ here, that would have been set when calling __init__ # in the Pydantic model so that when SQLAlchemy sets attributes that are # added (e.g. when querying from DB) to the __fields_set__, this already exists - object.__setattr__(new_object, "__pydantic_fields_set__", set()) + set_fields_set(new_object, set()) return new_object def __init__(__pydantic_self__, **data: Any) -> None: - old_dict = __pydantic_self__.__dict__.copy() - super().__init__(**data) - __pydantic_self__.__dict__ = old_dict | __pydantic_self__.__dict__ - non_pydantic_keys = data.keys() - __pydantic_self__.model_fields + # Uses something other than `self` the first arg to allow "self" as a + # settable attribute + if IS_PYDANTIC_V2: + old_dict = __pydantic_self__.__dict__.copy() + __pydantic_self__.super().__init__(**data) # noqa + __pydantic_self__.__dict__ = {**old_dict, **__pydantic_self__.__dict__} + non_pydantic_keys = data.keys() - __pydantic_self__.model_fields + else: + values, fields_set, validation_error = validate_model( + __pydantic_self__.__class__, data + ) + # Only raise errors if not a SQLModel model + if ( + not getattr(__pydantic_self__.__config__, "table", False) # noqa + and validation_error + ): + raise validation_error + # Do not set values as in Pydantic, pass them through setattr, so SQLAlchemy + # can handle them + # object.__setattr__(__pydantic_self__, '__dict__', values) + for key, value in values.items(): + setattr(__pydantic_self__, key, value) + object.__setattr__(__pydantic_self__, "__fields_set__", fields_set) + non_pydantic_keys = data.keys() - values.keys() + for key in non_pydantic_keys: if key in __pydantic_self__.__sqlmodel_relationships__: setattr(__pydantic_self__, key, data[key]) @@ -686,12 +722,12 @@ def __setattr__(self, name: str, value: Any) -> None: return else: # Set in SQLAlchemy, before Pydantic to trigger events and updates - if self.model_config.get("table", False) and is_instrumented(self, name): # type: ignore + if get_config_value(self, "table", False) and is_instrumented(self, name): set_attribute(self, name, value) # Set in Pydantic model to trigger possible validation changes, only for # non relationship values if name not in self.__sqlmodel_relationships__: - super(SQLModel, self).__setattr__(name, value) + super().__setattr__(name, value) def __repr_args__(self) -> Sequence[Tuple[Optional[str], Any]]: # Don't show SQLAlchemy private attributes @@ -705,20 +741,143 @@ def __repr_args__(self) -> Sequence[Tuple[Optional[str], Any]]: def __tablename__(cls) -> str: return cls.__name__.lower() - @classmethod - def model_validate( - cls: type[_TSQLModel], - obj: Any, - *, - strict: bool | None = None, - from_attributes: bool | None = None, - context: dict[str, Any] | None = None, - ) -> _TSQLModel: - # Somehow model validate doesn't call __init__ so it would remove our init logic - validated = super().model_validate( - obj, strict=strict, from_attributes=from_attributes, context=context - ) - return cls(**{key: value for key, value in validated}) + if IS_PYDANTIC_V2: + + @classmethod + def model_validate( + cls: type[_TSQLModel], + obj: Any, + *, + strict: bool | None = None, + from_attributes: bool | None = None, + context: dict[str, Any] | None = None, + ) -> _TSQLModel: + # Somehow model validate doesn't call __init__ so it would remove our init logic + validated = super().model_validate( + obj, strict=strict, from_attributes=from_attributes, context=context + ) + return cls(**dict(validated)) + + else: + + @classmethod + def from_orm( + cls: Type[_TSQLModel], obj: Any, update: Optional[Dict[str, Any]] = None + ) -> _TSQLModel: + # Duplicated from Pydantic + if not cls.__config__.orm_mode: # noqa: attr-defined + raise ConfigError( + "You must have the config attribute orm_mode=True to use from_orm" + ) + obj = ( + {ROOT_KEY: obj} + if cls.__custom_root_type__ # noqa + else cls._decompose_class(obj) # noqa + ) + # SQLModel, support update dict + if update is not None: + obj = {**obj, **update} + # End SQLModel support dict + if not getattr(cls.__config__, "table", False): # noqa + # If not table, normal Pydantic code + m: _TSQLModel = cls.__new__(cls) + else: + # If table, create the new instance normally to make SQLAlchemy create + # the _sa_instance_state attribute + m = cls() + values, fields_set, validation_error = validate_model(cls, obj) + if validation_error: + raise validation_error + # Updated to trigger SQLAlchemy internal handling + if not getattr(cls.__config__, "table", False): # noqa + object.__setattr__(m, "__dict__", values) + else: + for key, value in values.items(): + setattr(m, key, value) + # Continue with standard Pydantic logic + object.__setattr__(m, "__fields_set__", fields_set) + m._init_private_attributes() # noqa + return m + + @classmethod + def parse_obj( + cls: Type[_TSQLModel], obj: Any, update: Optional[Dict[str, Any]] = None + ) -> _TSQLModel: + obj = cls._enforce_dict_if_root(obj) # noqa + # SQLModel, support update dict + if update is not None: + obj = {**obj, **update} + # End SQLModel support dict + return super().parse_obj(obj) + + # From Pydantic, override to enforce validation with dict + @classmethod + def validate(cls: Type[_TSQLModel], value: Any) -> _TSQLModel: + if isinstance(value, cls): + return ( + value.copy() if cls.__config__.copy_on_model_validation else value # noqa + ) + + value = cls._enforce_dict_if_root(value) + if isinstance(value, dict): + values, fields_set, validation_error = validate_model(cls, value) + if validation_error: + raise validation_error + model = cls(**value) + # Reset fields set, this would have been done in Pydantic in __init__ + object.__setattr__(model, "__fields_set__", fields_set) + return model + elif cls.__config__.orm_mode: # noqa + return cls.from_orm(value) + elif cls.__custom_root_type__: # noqa + return cls.parse_obj(value) + else: + try: + value_as_dict = dict(value) + except (TypeError, ValueError) as e: + raise DictError() from e + return cls(**value_as_dict) + + # From Pydantic, override to only show keys from fields, omit SQLAlchemy attributes + def _calculate_keys( + self, + include: Optional[Mapping[Union[int, str], Any]], + exclude: Optional[Mapping[Union[int, str], Any]], + exclude_unset: bool, + update: Optional[Dict[str, Any]] = None, + ) -> Optional[AbstractSet[str]]: + if include is None and exclude is None and not exclude_unset: + # Original in Pydantic: + # return None + # Updated to not return SQLAlchemy attributes + # Do not include relationships as that would easily lead to infinite + # recursion, or traversing the whole database + return ( + self.__fields__.keys() # noqa + ) # | self.__sqlmodel_relationships__.keys() + + keys: AbstractSet[str] + if exclude_unset: + keys = self.__fields_set__.copy() # noqa + else: + # Original in Pydantic: + # keys = self.__dict__.keys() + # Updated to not return SQLAlchemy attributes + # Do not include relationships as that would easily lead to infinite + # recursion, or traversing the whole database + keys = ( + self.__fields__.keys() # noqa + ) # | self.__sqlmodel_relationships__.keys() + if include is not None: + keys &= include.keys() + + if update: + keys -= update.keys() + + if exclude: + keys -= {k for k, v in exclude.items() if _value_items_is_true(v)} + + return keys def _is_field_noneable(field: FieldInfo) -> bool: @@ -736,10 +895,3 @@ def _is_field_noneable(field: FieldInfo) -> bool: return True return False return False - - -def _get_field_metadata(field: FieldInfo) -> object: - for meta in field.metadata: - if isinstance(meta, PydanticGeneralMetadata): - return meta - return object() diff --git a/sqlmodel/orm/session.py b/sqlmodel/orm/session.py index c9bb043336..29aba05eec 100644 --- a/sqlmodel/orm/session.py +++ b/sqlmodel/orm/session.py @@ -1,6 +1,7 @@ from typing import ( Any, Dict, + Literal, Mapping, Optional, Sequence, @@ -15,7 +16,6 @@ from sqlalchemy.orm import Query as _Query from sqlalchemy.orm import Session as _Session from sqlalchemy.sql.base import Executable as _Executable -from typing_extensions import Literal from ..engine.result import Result, ScalarResult from ..sql.base import Executable @@ -137,7 +137,7 @@ def get( ident: Any, options: Optional[Sequence[Any]] = None, populate_existing: bool = False, - with_for_update: Optional[_ForUpdateArg] = None, + with_for_update: Optional[Union[Literal[True], Mapping[str, Any]]] = None, identity_token: Optional[Any] = None, execution_options: Mapping[Any, Any] = util.EMPTY_DICT, bind_arguments: Optional[Dict[str, Any]] = None, diff --git a/sqlmodel/sql/expression.py b/sqlmodel/sql/expression.py index a0ac1bd9d9..8cb2309228 100644 --- a/sqlmodel/sql/expression.py +++ b/sqlmodel/sql/expression.py @@ -35,6 +35,14 @@ class SelectOfScalar(_Select[Tuple[_TSelect]], Generic[_TSelect]): inherit_cache = True +# This is not comparable to sqlalchemy.sql.selectable.ScalarSelect, that has a different +# purpose. This is the same as a normal SQLAlchemy Select class where there's only one +# entity, so the result will be converted to a scalar by default. This way writing +# for loops on the results will feel natural. +class SelectOfScalar(_Select, Generic[_TSelect]): + inherit_cache = True + + if TYPE_CHECKING: # pragma: no cover from ..main import SQLModel diff --git a/sqlmodel/typing.py b/sqlmodel/typing.py deleted file mode 100644 index 8151f99692..0000000000 --- a/sqlmodel/typing.py +++ /dev/null @@ -1,9 +0,0 @@ -from typing import Any, Optional - -from pydantic import ConfigDict - - -class SQLModelConfig(ConfigDict, total=False): - table: Optional[bool] - read_from_attributes: Optional[bool] - registry: Optional[Any] diff --git a/tests/conftest.py b/tests/conftest.py index 2b8e5fc29e..020b33a566 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -6,6 +6,7 @@ import pytest from pydantic import BaseModel from sqlmodel import SQLModel +from sqlmodel.compat import IS_PYDANTIC_V2 from sqlmodel.main import default_registry top_level_path = Path(__file__).resolve().parent.parent @@ -67,3 +68,7 @@ def new_print(*args): calls.append(data) return new_print + + +needs_pydanticv2 = pytest.mark.skipif(not IS_PYDANTIC_V2, reason="requires Pydantic v2") +needs_pydanticv1 = pytest.mark.skipif(IS_PYDANTIC_V2, reason="requires Pydantic v1") From c21ff699edc62dcda598e952f35a97859549e052 Mon Sep 17 00:00:00 2001 From: Anton De Meester Date: Tue, 14 Nov 2023 08:50:11 +0000 Subject: [PATCH 025/105] Make pydantic v1 work again --- docs/contributing.md | 4 -- docs/features.md | 2 +- .../fastapi/app_testing/tutorial001/main.py | 7 ++- .../tutorial/fastapi/delete/tutorial001.py | 6 +- .../fastapi/limit_and_offset/tutorial001.py | 6 +- .../fastapi/multiple_models/tutorial001.py | 6 +- .../fastapi/multiple_models/tutorial002.py | 6 +- .../tutorial/fastapi/read_one/tutorial001.py | 6 +- .../fastapi/relationships/tutorial001.py | 11 +++- .../session_with_dependency/tutorial001.py | 6 +- .../tutorial/fastapi/teams/tutorial001.py | 11 +++- .../tutorial/fastapi/update/tutorial001.py | 6 +- sqlmodel/compat.py | 63 +++++++++++++++---- sqlmodel/engine/result.py | 61 ++++++++++++++---- sqlmodel/ext/asyncio/session.py | 2 +- sqlmodel/main.py | 39 ++++-------- sqlmodel/orm/session.py | 35 ++++------- sqlmodel/sql/expression.py | 50 ++++++++++----- sqlmodel/sql/expression.py.jinja2 | 25 +++----- sqlmodel/sql/sqltypes.py | 8 +-- tests/test_validation.py | 31 ++++++++- 21 files changed, 260 insertions(+), 131 deletions(-) diff --git a/docs/contributing.md b/docs/contributing.md index 5c160a7c0a..217ed61c56 100644 --- a/docs/contributing.md +++ b/docs/contributing.md @@ -6,10 +6,6 @@ First, you might want to see the basic ways to [help SQLModel and get help](help If you already cloned the repository and you know that you need to deep dive in the code, here are some guidelines to set up your environment. -### Python - -SQLModel supports Python 3.7 and above. - ### Poetry **SQLModel** uses Poetry to build, package, and publish the project. diff --git a/docs/features.md b/docs/features.md index 2d5e11d84f..102edef725 100644 --- a/docs/features.md +++ b/docs/features.md @@ -12,7 +12,7 @@ Nevertheless, SQLModel is completely **independent** of FastAPI and can be used ## Just Modern Python -It's all based on standard modern **Python** type annotations. No new syntax to learn. Just standard modern Python. +It's all based on standard modern **Python** type annotations. No new syntax to learn. Just standard modern Python. If you need a 2 minute refresher of how to use Python types (even if you don't use SQLModel or FastAPI), check the FastAPI tutorial section: Python types intro. diff --git a/docs_src/tutorial/fastapi/app_testing/tutorial001/main.py b/docs_src/tutorial/fastapi/app_testing/tutorial001/main.py index a23dfad5a8..cf2f4da233 100644 --- a/docs_src/tutorial/fastapi/app_testing/tutorial001/main.py +++ b/docs_src/tutorial/fastapi/app_testing/tutorial001/main.py @@ -2,6 +2,7 @@ from fastapi import Depends, FastAPI, HTTPException, Query from sqlmodel import Field, Session, SQLModel, create_engine, select +from sqlmodel.compat import IS_PYDANTIC_V2 class HeroBase(SQLModel): @@ -51,10 +52,12 @@ def get_session(): def on_startup(): create_db_and_tables() - @app.post("/heroes/", response_model=HeroRead) def create_hero(*, session: Session = Depends(get_session), hero: HeroCreate): - db_hero = Hero.model_validate(hero) + if IS_PYDANTIC_V2: + db_hero = Hero.model_validate(hero) + else: + db_hero = Hero.from_orm(hero) session.add(db_hero) session.commit() session.refresh(db_hero) diff --git a/docs_src/tutorial/fastapi/delete/tutorial001.py b/docs_src/tutorial/fastapi/delete/tutorial001.py index 77a99a9c97..f186c42b2b 100644 --- a/docs_src/tutorial/fastapi/delete/tutorial001.py +++ b/docs_src/tutorial/fastapi/delete/tutorial001.py @@ -2,6 +2,7 @@ from fastapi import FastAPI, HTTPException, Query from sqlmodel import Field, Session, SQLModel, create_engine, select +from sqlmodel.compat import IS_PYDANTIC_V2 class HeroBase(SQLModel): @@ -50,7 +51,10 @@ def on_startup(): @app.post("/heroes/", response_model=HeroRead) def create_hero(hero: HeroCreate): with Session(engine) as session: - db_hero = Hero.model_validate(hero) + if IS_PYDANTIC_V2: + db_hero = Hero.model_validate(hero) + else: + db_hero = Hero.from_orm(hero) session.add(db_hero) session.commit() session.refresh(db_hero) diff --git a/docs_src/tutorial/fastapi/limit_and_offset/tutorial001.py b/docs_src/tutorial/fastapi/limit_and_offset/tutorial001.py index 2352f39022..6701355f17 100644 --- a/docs_src/tutorial/fastapi/limit_and_offset/tutorial001.py +++ b/docs_src/tutorial/fastapi/limit_and_offset/tutorial001.py @@ -2,6 +2,7 @@ from fastapi import FastAPI, HTTPException, Query from sqlmodel import Field, Session, SQLModel, create_engine, select +from sqlmodel.compat import IS_PYDANTIC_V2 class HeroBase(SQLModel): @@ -44,7 +45,10 @@ def on_startup(): @app.post("/heroes/", response_model=HeroRead) def create_hero(hero: HeroCreate): with Session(engine) as session: - db_hero = Hero.model_validate(hero) + if IS_PYDANTIC_V2: + db_hero = Hero.model_validate(hero) + else: + db_hero = Hero.from_orm(hero) session.add(db_hero) session.commit() session.refresh(db_hero) diff --git a/docs_src/tutorial/fastapi/multiple_models/tutorial001.py b/docs_src/tutorial/fastapi/multiple_models/tutorial001.py index 7f59ac6a1d..0ceed94ca1 100644 --- a/docs_src/tutorial/fastapi/multiple_models/tutorial001.py +++ b/docs_src/tutorial/fastapi/multiple_models/tutorial001.py @@ -2,6 +2,7 @@ from fastapi import FastAPI from sqlmodel import Field, Session, SQLModel, create_engine, select +from sqlmodel.compat import IS_PYDANTIC_V2 class Hero(SQLModel, table=True): @@ -46,7 +47,10 @@ def on_startup(): @app.post("/heroes/", response_model=HeroRead) def create_hero(hero: HeroCreate): with Session(engine) as session: - db_hero = Hero.model_validate(hero) + if IS_PYDANTIC_V2: + db_hero = Hero.model_validate(hero) + else: + db_hero = Hero.from_orm(hero) session.add(db_hero) session.commit() session.refresh(db_hero) diff --git a/docs_src/tutorial/fastapi/multiple_models/tutorial002.py b/docs_src/tutorial/fastapi/multiple_models/tutorial002.py index fffbe72496..d92745a339 100644 --- a/docs_src/tutorial/fastapi/multiple_models/tutorial002.py +++ b/docs_src/tutorial/fastapi/multiple_models/tutorial002.py @@ -2,6 +2,7 @@ from fastapi import FastAPI from sqlmodel import Field, Session, SQLModel, create_engine, select +from sqlmodel.compat import IS_PYDANTIC_V2 class HeroBase(SQLModel): @@ -44,7 +45,10 @@ def on_startup(): @app.post("/heroes/", response_model=HeroRead) def create_hero(hero: HeroCreate): with Session(engine) as session: - db_hero = Hero.model_validate(hero) + if IS_PYDANTIC_V2: + db_hero = Hero.model_validate(hero) + else: + db_hero = Hero.from_orm(hero) session.add(db_hero) session.commit() session.refresh(db_hero) diff --git a/docs_src/tutorial/fastapi/read_one/tutorial001.py b/docs_src/tutorial/fastapi/read_one/tutorial001.py index f18426e74c..4cdf898922 100644 --- a/docs_src/tutorial/fastapi/read_one/tutorial001.py +++ b/docs_src/tutorial/fastapi/read_one/tutorial001.py @@ -3,6 +3,7 @@ from fastapi import FastAPI, HTTPException from sqlmodel import Field, Session, SQLModel, create_engine, select +from sqlmodel.compat import IS_PYDANTIC_V2 class HeroBase(SQLModel): name: str = Field(index=True) @@ -44,7 +45,10 @@ def on_startup(): @app.post("/heroes/", response_model=HeroRead) def create_hero(hero: HeroCreate): with Session(engine) as session: - db_hero = Hero.model_validate(hero) + if IS_PYDANTIC_V2: + db_hero = Hero.model_validate(hero) + else: + db_hero = Hero.from_orm(hero) session.add(db_hero) session.commit() session.refresh(db_hero) diff --git a/docs_src/tutorial/fastapi/relationships/tutorial001.py b/docs_src/tutorial/fastapi/relationships/tutorial001.py index e5b196090e..dfcedaf881 100644 --- a/docs_src/tutorial/fastapi/relationships/tutorial001.py +++ b/docs_src/tutorial/fastapi/relationships/tutorial001.py @@ -2,6 +2,7 @@ from fastapi import Depends, FastAPI, HTTPException, Query from sqlmodel import Field, Relationship, Session, SQLModel, create_engine, select +from sqlmodel.compat import IS_PYDANTIC_V2 class TeamBase(SQLModel): @@ -92,7 +93,10 @@ def on_startup(): @app.post("/heroes/", response_model=HeroRead) def create_hero(*, session: Session = Depends(get_session), hero: HeroCreate): - db_hero = Hero.model_validate(hero) + if IS_PYDANTIC_V2: + db_hero = Hero.model_validate(hero) + else: + db_hero = Hero.from_orm(hero) session.add(db_hero) session.commit() session.refresh(db_hero) @@ -146,7 +150,10 @@ def delete_hero(*, session: Session = Depends(get_session), hero_id: int): @app.post("/teams/", response_model=TeamRead) def create_team(*, session: Session = Depends(get_session), team: TeamCreate): - db_team = Team.model_validate(team) + if IS_PYDANTIC_V2: + db_team = Team.model_validate(team) + else: + db_team = Team.from_orm(team) session.add(db_team) session.commit() session.refresh(db_team) diff --git a/docs_src/tutorial/fastapi/session_with_dependency/tutorial001.py b/docs_src/tutorial/fastapi/session_with_dependency/tutorial001.py index a23dfad5a8..f305f75194 100644 --- a/docs_src/tutorial/fastapi/session_with_dependency/tutorial001.py +++ b/docs_src/tutorial/fastapi/session_with_dependency/tutorial001.py @@ -2,6 +2,7 @@ from fastapi import Depends, FastAPI, HTTPException, Query from sqlmodel import Field, Session, SQLModel, create_engine, select +from sqlmodel.compat import IS_PYDANTIC_V2 class HeroBase(SQLModel): @@ -54,7 +55,10 @@ def on_startup(): @app.post("/heroes/", response_model=HeroRead) def create_hero(*, session: Session = Depends(get_session), hero: HeroCreate): - db_hero = Hero.model_validate(hero) + if IS_PYDANTIC_V2: + db_hero = Hero.model_validate(hero) + else: + db_hero = Hero.from_orm(hero) session.add(db_hero) session.commit() session.refresh(db_hero) diff --git a/docs_src/tutorial/fastapi/teams/tutorial001.py b/docs_src/tutorial/fastapi/teams/tutorial001.py index cc73bb52cb..46ea0f933c 100644 --- a/docs_src/tutorial/fastapi/teams/tutorial001.py +++ b/docs_src/tutorial/fastapi/teams/tutorial001.py @@ -2,6 +2,7 @@ from fastapi import Depends, FastAPI, HTTPException, Query from sqlmodel import Field, Relationship, Session, SQLModel, create_engine, select +from sqlmodel.compat import IS_PYDANTIC_V2 class TeamBase(SQLModel): @@ -83,7 +84,10 @@ def on_startup(): @app.post("/heroes/", response_model=HeroRead) def create_hero(*, session: Session = Depends(get_session), hero: HeroCreate): - db_hero = Hero.model_validate(hero) + if IS_PYDANTIC_V2: + db_hero = Hero.model_validate(hero) + else: + db_hero = Hero.from_orm(hero) session.add(db_hero) session.commit() session.refresh(db_hero) @@ -137,7 +141,10 @@ def delete_hero(*, session: Session = Depends(get_session), hero_id: int): @app.post("/teams/", response_model=TeamRead) def create_team(*, session: Session = Depends(get_session), team: TeamCreate): - db_team = Team.model_validate(team) + if IS_PYDANTIC_V2: + db_team = Team.model_validate(team) + else: + db_team = Team.from_orm(team) session.add(db_team) session.commit() session.refresh(db_team) diff --git a/docs_src/tutorial/fastapi/update/tutorial001.py b/docs_src/tutorial/fastapi/update/tutorial001.py index 28462bff17..93dfa7496a 100644 --- a/docs_src/tutorial/fastapi/update/tutorial001.py +++ b/docs_src/tutorial/fastapi/update/tutorial001.py @@ -2,6 +2,7 @@ from fastapi import FastAPI, HTTPException, Query from sqlmodel import Field, Session, SQLModel, create_engine, select +from sqlmodel.compat import IS_PYDANTIC_V2 class HeroBase(SQLModel): @@ -50,7 +51,10 @@ def on_startup(): @app.post("/heroes/", response_model=HeroRead) def create_hero(hero: HeroCreate): with Session(engine) as session: - db_hero = Hero.model_validate(hero) + if IS_PYDANTIC_V2: + db_hero = Hero.model_validate(hero) + else: + db_hero = Hero.from_orm(hero) session.add(db_hero) session.commit() session.refresh(db_hero) diff --git a/sqlmodel/compat.py b/sqlmodel/compat.py index 41abbef251..3ffcd1cd34 100644 --- a/sqlmodel/compat.py +++ b/sqlmodel/compat.py @@ -11,6 +11,7 @@ get_args, get_origin, ) +from types import NoneType from pydantic import VERSION as PYDANTIC_VERSION @@ -23,7 +24,8 @@ else: from pydantic import BaseConfig # noqa from pydantic.fields import ModelField # noqa - from pydantic.fields import Undefined as PydanticUndefined, UndefinedType as PydanticUndefinedType # noqa + from pydantic.fields import Undefined as PydanticUndefined, UndefinedType as PydanticUndefinedType, SHAPE_SINGLETON + from pydantic.typing import resolve_annotations if TYPE_CHECKING: from .main import FieldInfo, RelationshipInfo, SQLModel, SQLModelMetaclass @@ -34,32 +36,26 @@ InstanceOrType = Union[T, Type[T]] if IS_PYDANTIC_V2: + PydanticModelConfig = ConfigDict class SQLModelConfig(ConfigDict, total=False): table: Optional[bool] - read_from_attributes: Optional[bool] registry: Optional[Any] else: + PydanticModelConfig = BaseConfig class SQLModelConfig(BaseConfig): table: Optional[bool] = None - read_from_attributes: Optional[bool] = None registry: Optional[Any] = None - def __getitem__(self, item: str) -> Any: - return self.__getattr__(item) - - def __setitem__(self, item: str, value: Any) -> None: - return self.__setattr__(item, value) - # Inspired from https://github.com/roman-right/beanie/blob/main/beanie/odm/utils/pydantic.py def get_model_config(model: type) -> Optional[SQLModelConfig]: if IS_PYDANTIC_V2: return getattr(model, "model_config", None) else: - return getattr(model, "Config", None) + return getattr(model, "__config__", None) def get_config_value( @@ -68,7 +64,7 @@ def get_config_value( if IS_PYDANTIC_V2: return model.model_config.get(parameter, default) else: - return getattr(model.Config, parameter, default) + return getattr(model.__config__, parameter, default) def set_config_value( @@ -77,7 +73,7 @@ def set_config_value( if IS_PYDANTIC_V2: model.model_config[parameter] = value # type: ignore else: - model.Config[v1_parameter or parameter] = value # type: ignore + setattr(model.__config__, v1_parameter or parameter, value) # type: ignore def get_model_fields(model: InstanceOrType["SQLModel"]) -> Dict[str, "FieldInfo"]: @@ -109,6 +105,25 @@ def set_attribute_mode(cls: Type["SQLModelMetaclass"]) -> None: else: cls.__config__.read_with_orm_mode = True # type: ignore +def get_annotations(class_dict: dict[str, Any]) -> dict[str, Any]: + if IS_PYDANTIC_V2: + return class_dict.get("__annotations__", {}) + else: + return resolve_annotations(class_dict.get("__annotations__", {}),class_dict.get("__module__", None)) + +def is_table(class_dict: dict[str, Any]) -> bool: + config: SQLModelConfig = {} + if IS_PYDANTIC_V2: + config = class_dict.get("model_config", {}) + else: + config = class_dict.get("__config__", {}) + config_table = config.get("table", PydanticUndefined) + if config_table is not PydanticUndefined: + return config_table + kw_table = class_dict.get("table", PydanticUndefined) + if kw_table is not PydanticUndefined: + return kw_table + return False def get_relationship_to( name: str, @@ -167,3 +182,27 @@ def set_empty_defaults(annotations: Dict[str, Any], class_dict: Dict[str, Any]) # So we can check for nullable value.original_default = value.default value.default = None + +def is_field_noneable(field: "FieldInfo") -> bool: + if IS_PYDANTIC_V2: + if getattr(field, "nullable", PydanticUndefined) is not PydanticUndefined: + return field.nullable + if not field.is_required(): + default = getattr(field, "original_default", field.default) + if default is PydanticUndefined: + return False + if field.annotation is None or field.annotation is NoneType: + return True + if get_origin(field.annotation) is Union: + for base in get_args(field.annotation): + if base is NoneType: + return True + return False + return False + else: + if not field.required: + # Taken from [Pydantic](https://github.com/samuelcolvin/pydantic/blob/v1.8.2/pydantic/fields.py#L946-L947) + return field.allow_none and ( + field.shape != SHAPE_SINGLETON or not field.sub_fields + ) + return False \ No newline at end of file diff --git a/sqlmodel/engine/result.py b/sqlmodel/engine/result.py index 650dd92b27..2401609ae1 100644 --- a/sqlmodel/engine/result.py +++ b/sqlmodel/engine/result.py @@ -1,4 +1,4 @@ -from typing import Generic, Iterator, Optional, Sequence, Tuple, TypeVar +from typing import Generic, Iterator, List, Optional, TypeVar from sqlalchemy.engine.result import Result as _Result from sqlalchemy.engine.result import ScalarResult as _ScalarResult @@ -6,34 +6,73 @@ _T = TypeVar("_T") -class ScalarResult(_ScalarResult[_T], Generic[_T]): - def all(self) -> Sequence[_T]: +class ScalarResult(_ScalarResult, Generic[_T]): + def all(self) -> List[_T]: return super().all() - def partitions(self, size: Optional[int] = None) -> Iterator[Sequence[_T]]: + def partitions(self, size: Optional[int] = None) -> Iterator[List[_T]]: return super().partitions(size) - def fetchall(self) -> Sequence[_T]: + def fetchall(self) -> List[_T]: return super().fetchall() - def fetchmany(self, size: Optional[int] = None) -> Sequence[_T]: + def fetchmany(self, size: Optional[int] = None) -> List[_T]: return super().fetchmany(size) def __iter__(self) -> Iterator[_T]: return super().__iter__() def __next__(self) -> _T: - return super().__next__() + return super().__next__() # type: ignore def first(self) -> Optional[_T]: return super().first() - def one_or_none(self) -> Optional[_T]: return super().one_or_none() def one(self) -> _T: - return super().one() + return super().one() # type: ignore + + +class Result(_Result, Generic[_T]): + def scalars(self, index: int = 0) -> ScalarResult[_T]: + return super().scalars(index) # type: ignore + + def __iter__(self) -> Iterator[_T]: # type: ignore + return super().__iter__() # type: ignore + + def __next__(self) -> _T: # type: ignore + return super().__next__() # type: ignore + + def partitions(self, size: Optional[int] = None) -> Iterator[List[_T]]: # type: ignore + return super().partitions(size) # type: ignore + + def fetchall(self) -> List[_T]: # type: ignore + return super().fetchall() # type: ignore + + def fetchone(self) -> Optional[_T]: # type: ignore + return super().fetchone() # type: ignore + + def fetchmany(self, size: Optional[int] = None) -> List[_T]: # type: ignore + return super().fetchmany() # type: ignore + + def all(self) -> List[_T]: # type: ignore + return super().all() # type: ignore + + def first(self) -> Optional[_T]: # type: ignore + return super().first() # type: ignore + + def one_or_none(self) -> Optional[_T]: # type: ignore + return super().one_or_none() # type: ignore + + def scalar_one(self) -> _T: + return super().scalar_one() # type: ignore + + def scalar_one_or_none(self) -> Optional[_T]: + return super().scalar_one_or_none() + def one(self) -> _T: # type: ignore + return super().one() # type: ignore -class Result(_Result[Tuple[_T]], Generic[_T]): - ... + def scalar(self) -> Optional[_T]: + return super().scalar() \ No newline at end of file diff --git a/sqlmodel/ext/asyncio/session.py b/sqlmodel/ext/asyncio/session.py index f500c44dc2..d4678b0370 100644 --- a/sqlmodel/ext/asyncio/session.py +++ b/sqlmodel/ext/asyncio/session.py @@ -91,4 +91,4 @@ async def exec( execution_options=execution_options, bind_arguments=bind_arguments, **kw, - ) + ) \ No newline at end of file diff --git a/sqlmodel/main.py b/sqlmodel/main.py index d5b69b28a8..1100fb7da7 100644 --- a/sqlmodel/main.py +++ b/sqlmodel/main.py @@ -5,7 +5,6 @@ from decimal import Decimal from enum import Enum from pathlib import Path -from types import NoneType from typing import ( AbstractSet, Any, @@ -22,8 +21,6 @@ TypeVar, Union, cast, - get_args, - get_origin, overload, ) @@ -64,6 +61,10 @@ set_config_value, set_empty_defaults, set_fields_set, + is_table, + is_field_noneable, + PydanticModelConfig, + get_annotations ) from .sql.sqltypes import GUID, AutoString @@ -71,6 +72,7 @@ from pydantic.errors import ConfigError, DictError from pydantic.main import validate_model from pydantic.utils import ROOT_KEY + from pydantic.typing import resolve_annotations _T = TypeVar("_T") @@ -412,7 +414,7 @@ def __new__( ) -> Any: relationships: Dict[str, RelationshipInfo] = {} dict_for_pydantic = {} - original_annotations = class_dict.get("__annotations__", {}) + original_annotations = get_annotations(class_dict) pydantic_annotations = {} relationship_annotations = {} for k, v in class_dict.items(): @@ -436,19 +438,16 @@ def __new__( # superclass causing an error allowed_config_kwargs: Set[str] = { key - for key in dir(SQLModelConfig) + for key in dir(PydanticModelConfig) if not ( key.startswith("__") and key.endswith("__") ) # skip dunder methods and attributes } - pydantic_kwargs = kwargs.copy() config_kwargs = { - key: pydantic_kwargs.pop(key) - for key in pydantic_kwargs.keys() & allowed_config_kwargs + key: kwargs[key] + for key in kwargs.keys() & allowed_config_kwargs } - config_table = getattr( - class_dict.get("Config", object()), "table", False - ) or kwargs.get("table", False) + config_table = is_table(class_dict) if config_table: set_empty_defaults(pydantic_annotations, dict_used) @@ -607,7 +606,7 @@ def get_column_from_field(field: ModelField) -> Column: # type: ignore index = getattr(field.field_info, "index", PydanticUndefined) if index is PydanticUndefined: index = False - nullable = not primary_key and _is_field_noneable(field) + nullable = not primary_key and is_field_noneable(field) # Override derived nullability if the nullable property is set explicitly # on the field field_nullable = getattr(field.field_info, "nullable", PydanticUndefined) # noqa: B009 @@ -879,19 +878,3 @@ def _calculate_keys( return keys - -def _is_field_noneable(field: FieldInfo) -> bool: - if getattr(field, "nullable", PydanticUndefined) is not PydanticUndefined: - return field.nullable - if not field.is_required(): - default = getattr(field, "original_default", field.default) - if default is PydanticUndefined: - return False - if field.annotation is None or field.annotation is NoneType: - return True - if get_origin(field.annotation) is Union: - for base in get_args(field.annotation): - if base is NoneType: - return True - return False - return False diff --git a/sqlmodel/orm/session.py b/sqlmodel/orm/session.py index 29aba05eec..db5faf6d7f 100644 --- a/sqlmodel/orm/session.py +++ b/sqlmodel/orm/session.py @@ -1,21 +1,10 @@ -from typing import ( - Any, - Dict, - Literal, - Mapping, - Optional, - Sequence, - Type, - TypeVar, - Union, - overload, -) +from typing import Any, Mapping, Optional, Sequence, Type, TypeVar, Union, overload from sqlalchemy import util -from sqlalchemy.orm import Mapper as _Mapper from sqlalchemy.orm import Query as _Query from sqlalchemy.orm import Session as _Session from sqlalchemy.sql.base import Executable as _Executable +from typing_extensions import Literal from ..engine.result import Result, ScalarResult from ..sql.base import Executable @@ -32,7 +21,7 @@ def exec( *, params: Optional[Union[Mapping[str, Any], Sequence[Mapping[str, Any]]]] = None, execution_options: Mapping[str, Any] = util.EMPTY_DICT, - bind_arguments: Optional[Dict[str, Any]] = None, + bind_arguments: Optional[Mapping[str, Any]] = None, _parent_execute_state: Optional[Any] = None, _add_event: Optional[Any] = None, **kw: Any, @@ -46,7 +35,7 @@ def exec( *, params: Optional[Union[Mapping[str, Any], Sequence[Mapping[str, Any]]]] = None, execution_options: Mapping[str, Any] = util.EMPTY_DICT, - bind_arguments: Optional[Dict[str, Any]] = None, + bind_arguments: Optional[Mapping[str, Any]] = None, _parent_execute_state: Optional[Any] = None, _add_event: Optional[Any] = None, **kw: Any, @@ -63,7 +52,7 @@ def exec( *, params: Optional[Union[Mapping[str, Any], Sequence[Mapping[str, Any]]]] = None, execution_options: Mapping[str, Any] = util.EMPTY_DICT, - bind_arguments: Optional[Dict[str, Any]] = None, + bind_arguments: Optional[Mapping[str, Any]] = None, _parent_execute_state: Optional[Any] = None, _add_event: Optional[Any] = None, **kw: Any, @@ -85,8 +74,8 @@ def execute( self, statement: _Executable, params: Optional[Union[Mapping[str, Any], Sequence[Mapping[str, Any]]]] = None, - execution_options: Mapping[str, Any] = util.EMPTY_DICT, - bind_arguments: Optional[Dict[str, Any]] = None, + execution_options: Optional[Mapping[str, Any]] = util.EMPTY_DICT, + bind_arguments: Optional[Mapping[str, Any]] = None, _parent_execute_state: Optional[Any] = None, _add_event: Optional[Any] = None, **kw: Any, @@ -129,18 +118,17 @@ def query(self, *entities: Any, **kwargs: Any) -> "_Query[Any]": Or otherwise you might want to use `session.execute()` instead of `session.query()`. """ - return super().query(*entities, **kwargs) # type: ignore + return super().query(*entities, **kwargs) def get( self, - entity: Union[Type[_TSelectParam], "_Mapper[_TSelectParam]"], + entity: Type[_TSelectParam], ident: Any, options: Optional[Sequence[Any]] = None, populate_existing: bool = False, with_for_update: Optional[Union[Literal[True], Mapping[str, Any]]] = None, identity_token: Optional[Any] = None, - execution_options: Mapping[Any, Any] = util.EMPTY_DICT, - bind_arguments: Optional[Dict[str, Any]] = None, + execution_options: Optional[Mapping[Any, Any]] = util.EMPTY_DICT, ) -> Optional[_TSelectParam]: return super().get( entity, @@ -150,5 +138,4 @@ def get( with_for_update=with_for_update, identity_token=identity_token, execution_options=execution_options, - bind_arguments=bind_arguments, - ) + ) \ No newline at end of file diff --git a/sqlmodel/sql/expression.py b/sqlmodel/sql/expression.py index 8cb2309228..5d5cddaee3 100644 --- a/sqlmodel/sql/expression.py +++ b/sqlmodel/sql/expression.py @@ -23,15 +23,7 @@ _TSelect = TypeVar("_TSelect") -class Select(_Select[Tuple[_TSelect]], Generic[_TSelect]): - inherit_cache = True - - -# This is not comparable to sqlalchemy.sql.selectable.ScalarSelect, that has a different -# purpose. This is the same as a normal SQLAlchemy Select class where there's only one -# entity, so the result will be converted to a scalar by default. This way writing -# for loops on the results will feel natural. -class SelectOfScalar(_Select[Tuple[_TSelect]], Generic[_TSelect]): +class Select(_Select, Generic[_TSelect]): inherit_cache = True @@ -125,12 +117,12 @@ class SelectOfScalar(_Select, Generic[_TSelect]): @overload -def select(entity_0: _TScalar_0) -> SelectOfScalar[_TScalar_0]: # type: ignore +def select(entity_0: _TScalar_0, **kw: Any) -> SelectOfScalar[_TScalar_0]: # type: ignore ... @overload -def select(entity_0: Type[_TModel_0]) -> SelectOfScalar[_TModel_0]: # type: ignore +def select(entity_0: Type[_TModel_0], **kw: Any) -> SelectOfScalar[_TModel_0]: # type: ignore ... @@ -141,6 +133,7 @@ def select(entity_0: Type[_TModel_0]) -> SelectOfScalar[_TModel_0]: # type: ign def select( # type: ignore entity_0: _TScalar_0, entity_1: _TScalar_1, + **kw: Any, ) -> Select[Tuple[_TScalar_0, _TScalar_1]]: ... @@ -149,6 +142,7 @@ def select( # type: ignore def select( # type: ignore entity_0: _TScalar_0, entity_1: Type[_TModel_1], + **kw: Any, ) -> Select[Tuple[_TScalar_0, _TModel_1]]: ... @@ -157,6 +151,7 @@ def select( # type: ignore def select( # type: ignore entity_0: Type[_TModel_0], entity_1: _TScalar_1, + **kw: Any, ) -> Select[Tuple[_TModel_0, _TScalar_1]]: ... @@ -165,6 +160,7 @@ def select( # type: ignore def select( # type: ignore entity_0: Type[_TModel_0], entity_1: Type[_TModel_1], + **kw: Any, ) -> Select[Tuple[_TModel_0, _TModel_1]]: ... @@ -174,6 +170,7 @@ def select( # type: ignore entity_0: _TScalar_0, entity_1: _TScalar_1, entity_2: _TScalar_2, + **kw: Any, ) -> Select[Tuple[_TScalar_0, _TScalar_1, _TScalar_2]]: ... @@ -183,6 +180,7 @@ def select( # type: ignore entity_0: _TScalar_0, entity_1: _TScalar_1, entity_2: Type[_TModel_2], + **kw: Any, ) -> Select[Tuple[_TScalar_0, _TScalar_1, _TModel_2]]: ... @@ -192,6 +190,7 @@ def select( # type: ignore entity_0: _TScalar_0, entity_1: Type[_TModel_1], entity_2: _TScalar_2, + **kw: Any, ) -> Select[Tuple[_TScalar_0, _TModel_1, _TScalar_2]]: ... @@ -201,6 +200,7 @@ def select( # type: ignore entity_0: _TScalar_0, entity_1: Type[_TModel_1], entity_2: Type[_TModel_2], + **kw: Any, ) -> Select[Tuple[_TScalar_0, _TModel_1, _TModel_2]]: ... @@ -210,6 +210,7 @@ def select( # type: ignore entity_0: Type[_TModel_0], entity_1: _TScalar_1, entity_2: _TScalar_2, + **kw: Any, ) -> Select[Tuple[_TModel_0, _TScalar_1, _TScalar_2]]: ... @@ -219,6 +220,7 @@ def select( # type: ignore entity_0: Type[_TModel_0], entity_1: _TScalar_1, entity_2: Type[_TModel_2], + **kw: Any, ) -> Select[Tuple[_TModel_0, _TScalar_1, _TModel_2]]: ... @@ -228,6 +230,7 @@ def select( # type: ignore entity_0: Type[_TModel_0], entity_1: Type[_TModel_1], entity_2: _TScalar_2, + **kw: Any, ) -> Select[Tuple[_TModel_0, _TModel_1, _TScalar_2]]: ... @@ -237,6 +240,7 @@ def select( # type: ignore entity_0: Type[_TModel_0], entity_1: Type[_TModel_1], entity_2: Type[_TModel_2], + **kw: Any, ) -> Select[Tuple[_TModel_0, _TModel_1, _TModel_2]]: ... @@ -247,6 +251,7 @@ def select( # type: ignore entity_1: _TScalar_1, entity_2: _TScalar_2, entity_3: _TScalar_3, + **kw: Any, ) -> Select[Tuple[_TScalar_0, _TScalar_1, _TScalar_2, _TScalar_3]]: ... @@ -257,6 +262,7 @@ def select( # type: ignore entity_1: _TScalar_1, entity_2: _TScalar_2, entity_3: Type[_TModel_3], + **kw: Any, ) -> Select[Tuple[_TScalar_0, _TScalar_1, _TScalar_2, _TModel_3]]: ... @@ -267,6 +273,7 @@ def select( # type: ignore entity_1: _TScalar_1, entity_2: Type[_TModel_2], entity_3: _TScalar_3, + **kw: Any, ) -> Select[Tuple[_TScalar_0, _TScalar_1, _TModel_2, _TScalar_3]]: ... @@ -277,6 +284,7 @@ def select( # type: ignore entity_1: _TScalar_1, entity_2: Type[_TModel_2], entity_3: Type[_TModel_3], + **kw: Any, ) -> Select[Tuple[_TScalar_0, _TScalar_1, _TModel_2, _TModel_3]]: ... @@ -287,6 +295,7 @@ def select( # type: ignore entity_1: Type[_TModel_1], entity_2: _TScalar_2, entity_3: _TScalar_3, + **kw: Any, ) -> Select[Tuple[_TScalar_0, _TModel_1, _TScalar_2, _TScalar_3]]: ... @@ -297,6 +306,7 @@ def select( # type: ignore entity_1: Type[_TModel_1], entity_2: _TScalar_2, entity_3: Type[_TModel_3], + **kw: Any, ) -> Select[Tuple[_TScalar_0, _TModel_1, _TScalar_2, _TModel_3]]: ... @@ -307,6 +317,7 @@ def select( # type: ignore entity_1: Type[_TModel_1], entity_2: Type[_TModel_2], entity_3: _TScalar_3, + **kw: Any, ) -> Select[Tuple[_TScalar_0, _TModel_1, _TModel_2, _TScalar_3]]: ... @@ -317,6 +328,7 @@ def select( # type: ignore entity_1: Type[_TModel_1], entity_2: Type[_TModel_2], entity_3: Type[_TModel_3], + **kw: Any, ) -> Select[Tuple[_TScalar_0, _TModel_1, _TModel_2, _TModel_3]]: ... @@ -327,6 +339,7 @@ def select( # type: ignore entity_1: _TScalar_1, entity_2: _TScalar_2, entity_3: _TScalar_3, + **kw: Any, ) -> Select[Tuple[_TModel_0, _TScalar_1, _TScalar_2, _TScalar_3]]: ... @@ -337,6 +350,7 @@ def select( # type: ignore entity_1: _TScalar_1, entity_2: _TScalar_2, entity_3: Type[_TModel_3], + **kw: Any, ) -> Select[Tuple[_TModel_0, _TScalar_1, _TScalar_2, _TModel_3]]: ... @@ -347,6 +361,7 @@ def select( # type: ignore entity_1: _TScalar_1, entity_2: Type[_TModel_2], entity_3: _TScalar_3, + **kw: Any, ) -> Select[Tuple[_TModel_0, _TScalar_1, _TModel_2, _TScalar_3]]: ... @@ -357,6 +372,7 @@ def select( # type: ignore entity_1: _TScalar_1, entity_2: Type[_TModel_2], entity_3: Type[_TModel_3], + **kw: Any, ) -> Select[Tuple[_TModel_0, _TScalar_1, _TModel_2, _TModel_3]]: ... @@ -367,6 +383,7 @@ def select( # type: ignore entity_1: Type[_TModel_1], entity_2: _TScalar_2, entity_3: _TScalar_3, + **kw: Any, ) -> Select[Tuple[_TModel_0, _TModel_1, _TScalar_2, _TScalar_3]]: ... @@ -377,6 +394,7 @@ def select( # type: ignore entity_1: Type[_TModel_1], entity_2: _TScalar_2, entity_3: Type[_TModel_3], + **kw: Any, ) -> Select[Tuple[_TModel_0, _TModel_1, _TScalar_2, _TModel_3]]: ... @@ -387,6 +405,7 @@ def select( # type: ignore entity_1: Type[_TModel_1], entity_2: Type[_TModel_2], entity_3: _TScalar_3, + **kw: Any, ) -> Select[Tuple[_TModel_0, _TModel_1, _TModel_2, _TScalar_3]]: ... @@ -397,6 +416,7 @@ def select( # type: ignore entity_1: Type[_TModel_1], entity_2: Type[_TModel_2], entity_3: Type[_TModel_3], + **kw: Any, ) -> Select[Tuple[_TModel_0, _TModel_1, _TModel_2, _TModel_3]]: ... @@ -404,14 +424,14 @@ def select( # type: ignore # Generated overloads end -def select(*entities: Any) -> Union[Select, SelectOfScalar]: # type: ignore +def select(*entities: Any, **kw: Any) -> Union[Select, SelectOfScalar]: # type: ignore if len(entities) == 1: - return SelectOfScalar(*entities) # type: ignore - return Select(*entities) # type: ignore + return SelectOfScalar._create(*entities, **kw) # type: ignore + return Select._create(*entities, **kw) # type: ignore # TODO: add several @overload from Python types to SQLAlchemy equivalents def col(column_expression: Any) -> ColumnClause: # type: ignore if not isinstance(column_expression, (ColumnClause, Column, InstrumentedAttribute)): raise RuntimeError(f"Not a SQLAlchemy column: {column_expression}") - return column_expression # type: ignore + return column_expression \ No newline at end of file diff --git a/sqlmodel/sql/expression.py.jinja2 b/sqlmodel/sql/expression.py.jinja2 index 55f4a1ac3e..b3acb22c95 100644 --- a/sqlmodel/sql/expression.py.jinja2 +++ b/sqlmodel/sql/expression.py.jinja2 @@ -20,14 +20,7 @@ from sqlalchemy.sql.expression import Select as _Select _TSelect = TypeVar("_TSelect") -class Select(_Select[Tuple[_TSelect]], Generic[_TSelect]): - inherit_cache = True - -# This is not comparable to sqlalchemy.sql.selectable.ScalarSelect, that has a different -# purpose. This is the same as a normal SQLAlchemy Select class where there's only one -# entity, so the result will be converted to a scalar by default. This way writing -# for loops on the results will feel natural. -class SelectOfScalar(_Select[Tuple[_TSelect]], Generic[_TSelect]): +class Select(_Select, Generic[_TSelect]): inherit_cache = True # This is not comparable to sqlalchemy.sql.selectable.ScalarSelect, that has a different @@ -42,7 +35,6 @@ if TYPE_CHECKING: # pragma: no cover # Generated TypeVars start - {% for i in range(number_of_types) %} _TScalar_{{ i }} = TypeVar( "_TScalar_{{ i }}", @@ -66,12 +58,12 @@ _TModel_{{ i }} = TypeVar("_TModel_{{ i }}", bound="SQLModel") # Generated TypeVars end @overload -def select(entity_0: _TScalar_0) -> SelectOfScalar[_TScalar_0]: # type: ignore +def select(entity_0: _TScalar_0, **kw: Any) -> SelectOfScalar[_TScalar_0]: # type: ignore ... @overload -def select(entity_0: Type[_TModel_0]) -> SelectOfScalar[_TModel_0]: # type: ignore +def select(entity_0: Type[_TModel_0], **kw: Any) -> SelectOfScalar[_TModel_0]: # type: ignore ... @@ -81,7 +73,7 @@ def select(entity_0: Type[_TModel_0]) -> SelectOfScalar[_TModel_0]: # type: ign @overload def select( # type: ignore - {% for arg in signature[0] %}{{ arg.name }}: {{ arg.annotation }}, {% endfor %} + {% for arg in signature[0] %}{{ arg.name }}: {{ arg.annotation }}, {% endfor %}**kw: Any, ) -> Select[Tuple[{%for ret in signature[1] %}{{ ret }} {% if not loop.last %}, {% endif %}{% endfor %}]]: ... @@ -89,15 +81,14 @@ def select( # type: ignore # Generated overloads end - -def select(*entities: Any) -> Union[Select, SelectOfScalar]: # type: ignore +def select(*entities: Any, **kw: Any) -> Union[Select, SelectOfScalar]: # type: ignore if len(entities) == 1: - return SelectOfScalar(*entities) # type: ignore - return Select(*entities) # type: ignore + return SelectOfScalar._create(*entities, **kw) # type: ignore + return Select._create(*entities, **kw) # type: ignore # TODO: add several @overload from Python types to SQLAlchemy equivalents def col(column_expression: Any) -> ColumnClause: # type: ignore if not isinstance(column_expression, (ColumnClause, Column, InstrumentedAttribute)): raise RuntimeError(f"Not a SQLAlchemy column: {column_expression}") - return column_expression # type: ignore + return column_expression \ No newline at end of file diff --git a/sqlmodel/sql/sqltypes.py b/sqlmodel/sql/sqltypes.py index aa30950702..33bc45cdf4 100644 --- a/sqlmodel/sql/sqltypes.py +++ b/sqlmodel/sql/sqltypes.py @@ -15,7 +15,7 @@ class AutoString(types.TypeDecorator): # type: ignore def load_dialect_impl(self, dialect: Dialect) -> "types.TypeEngine[Any]": impl = cast(types.String, self.impl) if impl.length is None and dialect.name == "mysql": - return dialect.type_descriptor(types.String(self.mysql_default_length)) + return dialect.type_descriptor(types.String(self.mysql_default_length)) # type: ignore return super().load_dialect_impl(dialect) @@ -34,9 +34,9 @@ class GUID(types.TypeDecorator): # type: ignore def load_dialect_impl(self, dialect: Dialect) -> TypeEngine: # type: ignore if dialect.name == "postgresql": - return dialect.type_descriptor(UUID()) + return dialect.type_descriptor(UUID()) # type: ignore else: - return dialect.type_descriptor(CHAR(32)) + return dialect.type_descriptor(CHAR(32)) # type: ignore def process_bind_param(self, value: Any, dialect: Dialect) -> Optional[str]: if value is None: @@ -56,4 +56,4 @@ def process_result_value(self, value: Any, dialect: Dialect) -> Optional[uuid.UU else: if not isinstance(value, uuid.UUID): value = uuid.UUID(value) - return cast(uuid.UUID, value) + return cast(uuid.UUID, value) \ No newline at end of file diff --git a/tests/test_validation.py b/tests/test_validation.py index 4183986a06..be648c1015 100644 --- a/tests/test_validation.py +++ b/tests/test_validation.py @@ -5,8 +5,37 @@ from pydantic.error_wrappers import ValidationError from sqlmodel import SQLModel +from .conftest import needs_pydanticv1, needs_pydanticv2 -def test_validation(clear_sqlmodel): +@needs_pydanticv1 +def test_validation_pydantic_v1(clear_sqlmodel): + """Test validation of implicit and explicit None values. + + # For consistency with pydantic, validators are not to be called on + # arguments that are not explicitly provided. + + https://github.com/tiangolo/sqlmodel/issues/230 + https://github.com/samuelcolvin/pydantic/issues/1223 + + """ + + class Hero(SQLModel): + name: Optional[str] = None + secret_name: Optional[str] = None + age: Optional[int] = None + + @field_validator("name", "secret_name", "age") + def reject_none(cls, v): + assert v is not None + return v + + Hero.from_orm({"age": 25}) + + with pytest.raises(ValidationError): + Hero.from_orm({"name": None, "age": 25}) + +@needs_pydanticv2 +def test_validation_pydantic_v2(clear_sqlmodel): """Test validation of implicit and explicit None values. # For consistency with pydantic, validators are not to be called on From 254fb132eef75d73c29527218354114d8ac52186 Mon Sep 17 00:00:00 2001 From: Anton De Meester Date: Tue, 14 Nov 2023 08:50:46 +0000 Subject: [PATCH 026/105] Liniting --- .../fastapi/app_testing/tutorial001/main.py | 1 + .../tutorial/fastapi/read_one/tutorial001.py | 2 +- sqlmodel/compat.py | 39 ++++++++++++------- sqlmodel/engine/result.py | 3 +- sqlmodel/ext/asyncio/session.py | 2 +- sqlmodel/main.py | 39 +++++++++---------- sqlmodel/orm/session.py | 2 +- sqlmodel/sql/expression.py | 2 +- sqlmodel/sql/sqltypes.py | 2 +- tests/test_validation.py | 2 + 10 files changed, 53 insertions(+), 41 deletions(-) diff --git a/docs_src/tutorial/fastapi/app_testing/tutorial001/main.py b/docs_src/tutorial/fastapi/app_testing/tutorial001/main.py index cf2f4da233..f305f75194 100644 --- a/docs_src/tutorial/fastapi/app_testing/tutorial001/main.py +++ b/docs_src/tutorial/fastapi/app_testing/tutorial001/main.py @@ -52,6 +52,7 @@ def get_session(): def on_startup(): create_db_and_tables() + @app.post("/heroes/", response_model=HeroRead) def create_hero(*, session: Session = Depends(get_session), hero: HeroCreate): if IS_PYDANTIC_V2: diff --git a/docs_src/tutorial/fastapi/read_one/tutorial001.py b/docs_src/tutorial/fastapi/read_one/tutorial001.py index 4cdf898922..aa805d6c8f 100644 --- a/docs_src/tutorial/fastapi/read_one/tutorial001.py +++ b/docs_src/tutorial/fastapi/read_one/tutorial001.py @@ -2,9 +2,9 @@ from fastapi import FastAPI, HTTPException from sqlmodel import Field, Session, SQLModel, create_engine, select - from sqlmodel.compat import IS_PYDANTIC_V2 + class HeroBase(SQLModel): name: str = Field(index=True) secret_name: str diff --git a/sqlmodel/compat.py b/sqlmodel/compat.py index 3ffcd1cd34..37a7c3716a 100644 --- a/sqlmodel/compat.py +++ b/sqlmodel/compat.py @@ -1,3 +1,4 @@ +from types import NoneType from typing import ( TYPE_CHECKING, Any, @@ -11,7 +12,6 @@ get_args, get_origin, ) -from types import NoneType from pydantic import VERSION as PYDANTIC_VERSION @@ -20,11 +20,12 @@ if IS_PYDANTIC_V2: from pydantic import ConfigDict - from pydantic_core import PydanticUndefined as PydanticUndefined, PydanticUndefinedType as PydanticUndefinedType # noqa + from pydantic_core import PydanticUndefined as PydanticUndefined # noqa + from pydantic_core import PydanticUndefinedType as PydanticUndefinedType else: - from pydantic import BaseConfig # noqa - from pydantic.fields import ModelField # noqa - from pydantic.fields import Undefined as PydanticUndefined, UndefinedType as PydanticUndefinedType, SHAPE_SINGLETON + from pydantic import BaseConfig # noqa + from pydantic.fields import ModelField # noqa + from pydantic.fields import Undefined as PydanticUndefined, SHAPE_SINGLETON from pydantic.typing import resolve_annotations if TYPE_CHECKING: @@ -68,26 +69,29 @@ def get_config_value( def set_config_value( - model: InstanceOrType["SQLModel"], parameter: str, value: Any, v1_parameter: str = None + model: InstanceOrType["SQLModel"], + parameter: str, + value: Any, + v1_parameter: str = None, ) -> None: if IS_PYDANTIC_V2: - model.model_config[parameter] = value # type: ignore + model.model_config[parameter] = value # type: ignore else: setattr(model.__config__, v1_parameter or parameter, value) # type: ignore def get_model_fields(model: InstanceOrType["SQLModel"]) -> Dict[str, "FieldInfo"]: if IS_PYDANTIC_V2: - return model.model_fields # type: ignore + return model.model_fields # type: ignore else: - return model.__fields__ # type: ignore + return model.__fields__ # type: ignore def get_fields_set(model: InstanceOrType["SQLModel"]) -> set[str]: if IS_PYDANTIC_V2: return model.__pydantic_fields_set__ else: - return model.__fields_set__ # type: ignore + return model.__fields_set__ # type: ignore def set_fields_set( @@ -103,13 +107,17 @@ def set_attribute_mode(cls: Type["SQLModelMetaclass"]) -> None: if IS_PYDANTIC_V2: cls.model_config["read_from_attributes"] = True else: - cls.__config__.read_with_orm_mode = True # type: ignore + cls.__config__.read_with_orm_mode = True # type: ignore + def get_annotations(class_dict: dict[str, Any]) -> dict[str, Any]: if IS_PYDANTIC_V2: return class_dict.get("__annotations__", {}) else: - return resolve_annotations(class_dict.get("__annotations__", {}),class_dict.get("__module__", None)) + return resolve_annotations( + class_dict.get("__annotations__", {}), class_dict.get("__module__", None) + ) + def is_table(class_dict: dict[str, Any]) -> bool: config: SQLModelConfig = {} @@ -125,6 +133,7 @@ def is_table(class_dict: dict[str, Any]) -> bool: return kw_table return False + def get_relationship_to( name: str, rel_info: "RelationshipInfo", @@ -170,6 +179,7 @@ def set_empty_defaults(annotations: Dict[str, Any], class_dict: Dict[str, Any]) """ if IS_PYDANTIC_V2: from .main import FieldInfo + # Pydantic v2 sets a __pydantic_core_schema__ which is very hard to change. Changing the fields does not do anything for key in annotations.keys(): value = class_dict.get(key, PydanticUndefined) @@ -180,9 +190,10 @@ def set_empty_defaults(annotations: Dict[str, Any], class_dict: Dict[str, Any]) value.default in (PydanticUndefined, Ellipsis) ) and value.default_factory is None: # So we can check for nullable - value.original_default = value.default + value.original_default = value.default value.default = None + def is_field_noneable(field: "FieldInfo") -> bool: if IS_PYDANTIC_V2: if getattr(field, "nullable", PydanticUndefined) is not PydanticUndefined: @@ -205,4 +216,4 @@ def is_field_noneable(field: "FieldInfo") -> bool: return field.allow_none and ( field.shape != SHAPE_SINGLETON or not field.sub_fields ) - return False \ No newline at end of file + return False diff --git a/sqlmodel/engine/result.py b/sqlmodel/engine/result.py index 2401609ae1..7a25422227 100644 --- a/sqlmodel/engine/result.py +++ b/sqlmodel/engine/result.py @@ -27,6 +27,7 @@ def __next__(self) -> _T: def first(self) -> Optional[_T]: return super().first() + def one_or_none(self) -> Optional[_T]: return super().one_or_none() @@ -75,4 +76,4 @@ def one(self) -> _T: # type: ignore return super().one() # type: ignore def scalar(self) -> Optional[_T]: - return super().scalar() \ No newline at end of file + return super().scalar() diff --git a/sqlmodel/ext/asyncio/session.py b/sqlmodel/ext/asyncio/session.py index d4678b0370..f500c44dc2 100644 --- a/sqlmodel/ext/asyncio/session.py +++ b/sqlmodel/ext/asyncio/session.py @@ -91,4 +91,4 @@ async def exec( execution_options=execution_options, bind_arguments=bind_arguments, **kw, - ) \ No newline at end of file + ) diff --git a/sqlmodel/main.py b/sqlmodel/main.py index 1100fb7da7..7a05ef3888 100644 --- a/sqlmodel/main.py +++ b/sqlmodel/main.py @@ -52,19 +52,19 @@ from .compat import ( IS_PYDANTIC_V2, NoArgAnyCallable, + PydanticModelConfig, PydanticUndefined, PydanticUndefinedType, SQLModelConfig, + get_annotations, get_config_value, get_model_fields, get_relationship_to, + is_field_noneable, + is_table, set_config_value, set_empty_defaults, set_fields_set, - is_table, - is_field_noneable, - PydanticModelConfig, - get_annotations ) from .sql.sqltypes import GUID, AutoString @@ -72,7 +72,6 @@ from pydantic.errors import ConfigError, DictError from pydantic.main import validate_model from pydantic.utils import ROOT_KEY - from pydantic.typing import resolve_annotations _T = TypeVar("_T") @@ -444,8 +443,7 @@ def __new__( ) # skip dunder methods and attributes } config_kwargs = { - key: kwargs[key] - for key in kwargs.keys() & allowed_config_kwargs + key: kwargs[key] for key in kwargs.keys() & allowed_config_kwargs } config_table = is_table(class_dict) if config_table: @@ -690,7 +688,7 @@ def __init__(__pydantic_self__, **data: Any) -> None: # settable attribute if IS_PYDANTIC_V2: old_dict = __pydantic_self__.__dict__.copy() - __pydantic_self__.super().__init__(**data) # noqa + __pydantic_self__.super().__init__(**data) # noqa __pydantic_self__.__dict__ = {**old_dict, **__pydantic_self__.__dict__} non_pydantic_keys = data.keys() - __pydantic_self__.model_fields else: @@ -699,7 +697,7 @@ def __init__(__pydantic_self__, **data: Any) -> None: ) # Only raise errors if not a SQLModel model if ( - not getattr(__pydantic_self__.__config__, "table", False) # noqa + not getattr(__pydantic_self__.__config__, "table", False) # noqa and validation_error ): raise validation_error @@ -764,7 +762,7 @@ def from_orm( cls: Type[_TSQLModel], obj: Any, update: Optional[Dict[str, Any]] = None ) -> _TSQLModel: # Duplicated from Pydantic - if not cls.__config__.orm_mode: # noqa: attr-defined + if not cls.__config__.orm_mode: # noqa: attr-defined raise ConfigError( "You must have the config attribute orm_mode=True to use from_orm" ) @@ -777,7 +775,7 @@ def from_orm( if update is not None: obj = {**obj, **update} # End SQLModel support dict - if not getattr(cls.__config__, "table", False): # noqa + if not getattr(cls.__config__, "table", False): # noqa # If not table, normal Pydantic code m: _TSQLModel = cls.__new__(cls) else: @@ -788,21 +786,21 @@ def from_orm( if validation_error: raise validation_error # Updated to trigger SQLAlchemy internal handling - if not getattr(cls.__config__, "table", False): # noqa + if not getattr(cls.__config__, "table", False): # noqa object.__setattr__(m, "__dict__", values) else: for key, value in values.items(): setattr(m, key, value) # Continue with standard Pydantic logic object.__setattr__(m, "__fields_set__", fields_set) - m._init_private_attributes() # noqa + m._init_private_attributes() # noqa return m @classmethod def parse_obj( cls: Type[_TSQLModel], obj: Any, update: Optional[Dict[str, Any]] = None ) -> _TSQLModel: - obj = cls._enforce_dict_if_root(obj) # noqa + obj = cls._enforce_dict_if_root(obj) # noqa # SQLModel, support update dict if update is not None: obj = {**obj, **update} @@ -814,7 +812,7 @@ def parse_obj( def validate(cls: Type[_TSQLModel], value: Any) -> _TSQLModel: if isinstance(value, cls): return ( - value.copy() if cls.__config__.copy_on_model_validation else value # noqa + value.copy() if cls.__config__.copy_on_model_validation else value # noqa ) value = cls._enforce_dict_if_root(value) @@ -826,9 +824,9 @@ def validate(cls: Type[_TSQLModel], value: Any) -> _TSQLModel: # Reset fields set, this would have been done in Pydantic in __init__ object.__setattr__(model, "__fields_set__", fields_set) return model - elif cls.__config__.orm_mode: # noqa + elif cls.__config__.orm_mode: # noqa return cls.from_orm(value) - elif cls.__custom_root_type__: # noqa + elif cls.__custom_root_type__: # noqa return cls.parse_obj(value) else: try: @@ -852,12 +850,12 @@ def _calculate_keys( # Do not include relationships as that would easily lead to infinite # recursion, or traversing the whole database return ( - self.__fields__.keys() # noqa + self.__fields__.keys() # noqa ) # | self.__sqlmodel_relationships__.keys() keys: AbstractSet[str] if exclude_unset: - keys = self.__fields_set__.copy() # noqa + keys = self.__fields_set__.copy() # noqa else: # Original in Pydantic: # keys = self.__dict__.keys() @@ -865,7 +863,7 @@ def _calculate_keys( # Do not include relationships as that would easily lead to infinite # recursion, or traversing the whole database keys = ( - self.__fields__.keys() # noqa + self.__fields__.keys() # noqa ) # | self.__sqlmodel_relationships__.keys() if include is not None: keys &= include.keys() @@ -877,4 +875,3 @@ def _calculate_keys( keys -= {k for k, v in exclude.items() if _value_items_is_true(v)} return keys - diff --git a/sqlmodel/orm/session.py b/sqlmodel/orm/session.py index db5faf6d7f..0c70c290ae 100644 --- a/sqlmodel/orm/session.py +++ b/sqlmodel/orm/session.py @@ -138,4 +138,4 @@ def get( with_for_update=with_for_update, identity_token=identity_token, execution_options=execution_options, - ) \ No newline at end of file + ) diff --git a/sqlmodel/sql/expression.py b/sqlmodel/sql/expression.py index 5d5cddaee3..264e39cba7 100644 --- a/sqlmodel/sql/expression.py +++ b/sqlmodel/sql/expression.py @@ -434,4 +434,4 @@ def select(*entities: Any, **kw: Any) -> Union[Select, SelectOfScalar]: # type: def col(column_expression: Any) -> ColumnClause: # type: ignore if not isinstance(column_expression, (ColumnClause, Column, InstrumentedAttribute)): raise RuntimeError(f"Not a SQLAlchemy column: {column_expression}") - return column_expression \ No newline at end of file + return column_expression diff --git a/sqlmodel/sql/sqltypes.py b/sqlmodel/sql/sqltypes.py index 33bc45cdf4..17d9b06126 100644 --- a/sqlmodel/sql/sqltypes.py +++ b/sqlmodel/sql/sqltypes.py @@ -56,4 +56,4 @@ def process_result_value(self, value: Any, dialect: Dialect) -> Optional[uuid.UU else: if not isinstance(value, uuid.UUID): value = uuid.UUID(value) - return cast(uuid.UUID, value) \ No newline at end of file + return cast(uuid.UUID, value) diff --git a/tests/test_validation.py b/tests/test_validation.py index be648c1015..e200a1e73d 100644 --- a/tests/test_validation.py +++ b/tests/test_validation.py @@ -7,6 +7,7 @@ from .conftest import needs_pydanticv1, needs_pydanticv2 + @needs_pydanticv1 def test_validation_pydantic_v1(clear_sqlmodel): """Test validation of implicit and explicit None values. @@ -34,6 +35,7 @@ def reject_none(cls, v): with pytest.raises(ValidationError): Hero.from_orm({"name": None, "age": 25}) + @needs_pydanticv2 def test_validation_pydantic_v2(clear_sqlmodel): """Test validation of implicit and explicit None values. From ab075145f34b6fea1b6600a63ff00ae0411fce8d Mon Sep 17 00:00:00 2001 From: Anton De Meester Date: Tue, 14 Nov 2023 08:57:30 +0000 Subject: [PATCH 027/105] Linter --- sqlmodel/compat.py | 26 +++++++++++--------------- sqlmodel/main.py | 1 - 2 files changed, 11 insertions(+), 16 deletions(-) diff --git a/sqlmodel/compat.py b/sqlmodel/compat.py index 37a7c3716a..e9a63f6144 100644 --- a/sqlmodel/compat.py +++ b/sqlmodel/compat.py @@ -19,13 +19,13 @@ if IS_PYDANTIC_V2: - from pydantic import ConfigDict + from pydantic import ConfigDict as PydanticModelConfig from pydantic_core import PydanticUndefined as PydanticUndefined # noqa from pydantic_core import PydanticUndefinedType as PydanticUndefinedType else: - from pydantic import BaseConfig # noqa + from pydantic import BaseConfig as PydanticModelConfig from pydantic.fields import ModelField # noqa - from pydantic.fields import Undefined as PydanticUndefined, SHAPE_SINGLETON + from pydantic.fields import Undefined as PydanticUndefined, UndefinedType as PydanticUndefinedType, SHAPE_SINGLETON # noqa from pydantic.typing import resolve_annotations if TYPE_CHECKING: @@ -37,16 +37,12 @@ InstanceOrType = Union[T, Type[T]] if IS_PYDANTIC_V2: - PydanticModelConfig = ConfigDict - - class SQLModelConfig(ConfigDict, total=False): + class SQLModelConfig(PydanticModelConfig, total=False): table: Optional[bool] registry: Optional[Any] else: - PydanticModelConfig = BaseConfig - - class SQLModelConfig(BaseConfig): + class SQLModelConfig(PydanticModelConfig): table: Optional[bool] = None registry: Optional[Any] = None @@ -72,7 +68,7 @@ def set_config_value( model: InstanceOrType["SQLModel"], parameter: str, value: Any, - v1_parameter: str = None, + v1_parameter: Optional[str] = None, ) -> None: if IS_PYDANTIC_V2: model.model_config[parameter] = value # type: ignore @@ -82,14 +78,14 @@ def set_config_value( def get_model_fields(model: InstanceOrType["SQLModel"]) -> Dict[str, "FieldInfo"]: if IS_PYDANTIC_V2: - return model.model_fields # type: ignore + return model.model_fields # type: ignore else: return model.__fields__ # type: ignore def get_fields_set(model: InstanceOrType["SQLModel"]) -> set[str]: if IS_PYDANTIC_V2: - return model.__pydantic_fields_set__ + return model.__pydantic_fields_set__ # type: ignore else: return model.__fields_set__ # type: ignore @@ -127,10 +123,10 @@ def is_table(class_dict: dict[str, Any]) -> bool: config = class_dict.get("__config__", {}) config_table = config.get("table", PydanticUndefined) if config_table is not PydanticUndefined: - return config_table + return config_table # type: ignore kw_table = class_dict.get("table", PydanticUndefined) if kw_table is not PydanticUndefined: - return kw_table + return kw_table # type: ignore return False @@ -197,7 +193,7 @@ def set_empty_defaults(annotations: Dict[str, Any], class_dict: Dict[str, Any]) def is_field_noneable(field: "FieldInfo") -> bool: if IS_PYDANTIC_V2: if getattr(field, "nullable", PydanticUndefined) is not PydanticUndefined: - return field.nullable + return field.nullable # type: ignore if not field.is_required(): default = getattr(field, "original_default", field.default) if default is PydanticUndefined: diff --git a/sqlmodel/main.py b/sqlmodel/main.py index 7a05ef3888..435b9f31f3 100644 --- a/sqlmodel/main.py +++ b/sqlmodel/main.py @@ -739,7 +739,6 @@ def __tablename__(cls) -> str: return cls.__name__.lower() if IS_PYDANTIC_V2: - @classmethod def model_validate( cls: type[_TSQLModel], From 4a4161a52776de966f2b4008a72acb51637f778d Mon Sep 17 00:00:00 2001 From: Anton De Meester Date: Tue, 14 Nov 2023 09:03:11 +0000 Subject: [PATCH 028/105] Move lint to after tests to see tests --- .github/workflows/test.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 979a08b2b2..3b7cfaf95a 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -58,8 +58,6 @@ jobs: - name: Install Pydantic v2 if: matrix.pydantic-version == 'pydantic-v2' run: pip install "pydantic>=2.0.2,<3.0.0" - - name: Lint - run: python -m poetry run bash scripts/lint.sh - run: mkdir coverage - name: Test run: python -m poetry run bash scripts/test.sh @@ -71,6 +69,8 @@ jobs: with: name: coverage path: coverage + - name: Lint + run: python -m poetry run bash scripts/lint.sh coverage-combine: needs: - test From e2d4d1fa833c2ea33c64159ef35d8ce614270e28 Mon Sep 17 00:00:00 2001 From: Anton De Meester Date: Thu, 16 Nov 2023 09:50:27 +0000 Subject: [PATCH 029/105] Make most tests succeed Only need to fix OPEN API things I think --- pyproject.toml | 2 +- sqlmodel/compat.py | 229 +++++++++++++++++++++++++++++++-- sqlmodel/main.py | 143 ++------------------ tests/test_enums.py | 47 ++++++- tests/test_instance_no_args.py | 37 +++++- tests/test_nullable.py | 2 +- tests/test_validation.py | 9 +- 7 files changed, 314 insertions(+), 155 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 6777d69018..f104631655 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -46,7 +46,7 @@ pillow = "^9.3.0" cairosvg = "^2.5.2" mdx-include = "^1.4.1" coverage = {extras = ["toml"], version = ">=6.2,<8.0"} -fastapi = "^0.68.1" +fastapi = "^0.100.0" requests = "^2.26.0" ruff = "^0.1.2" diff --git a/sqlmodel/compat.py b/sqlmodel/compat.py index e9a63f6144..dbd22053a8 100644 --- a/sqlmodel/compat.py +++ b/sqlmodel/compat.py @@ -1,3 +1,9 @@ +import ipaddress +import uuid +from datetime import date, datetime, time, timedelta +from decimal import Decimal +from enum import Enum +from pathlib import Path from types import NoneType from typing import ( TYPE_CHECKING, @@ -6,26 +12,47 @@ Dict, ForwardRef, Optional, + Sequence, Type, TypeVar, Union, + cast, get_args, get_origin, ) from pydantic import VERSION as PYDANTIC_VERSION +from sqlalchemy import ( + Boolean, + Column, + Date, + DateTime, + Float, + ForeignKey, + Integer, + Interval, + Numeric, +) +from sqlalchemy import Enum as sa_Enum +from sqlalchemy.sql.sqltypes import LargeBinary, Time + +from .sql.sqltypes import GUID, AutoString IS_PYDANTIC_V2 = int(PYDANTIC_VERSION.split(".")[0]) >= 2 if IS_PYDANTIC_V2: from pydantic import ConfigDict as PydanticModelConfig + from pydantic._internal._fields import PydanticMetadata + from pydantic._internal._model_construction import ModelMetaclass from pydantic_core import PydanticUndefined as PydanticUndefined # noqa from pydantic_core import PydanticUndefinedType as PydanticUndefinedType else: from pydantic import BaseConfig as PydanticModelConfig - from pydantic.fields import ModelField # noqa - from pydantic.fields import Undefined as PydanticUndefined, UndefinedType as PydanticUndefinedType, SHAPE_SINGLETON # noqa + from pydantic.fields import SHAPE_SINGLETON, ModelField + from pydantic.fields import Undefined as PydanticUndefined # noqa + from pydantic.fields import UndefinedType as PydanticUndefinedType + from pydantic.main import ModelMetaclass as ModelMetaclass from pydantic.typing import resolve_annotations if TYPE_CHECKING: @@ -37,11 +64,13 @@ InstanceOrType = Union[T, Type[T]] if IS_PYDANTIC_V2: + class SQLModelConfig(PydanticModelConfig, total=False): table: Optional[bool] registry: Optional[Any] else: + class SQLModelConfig(PydanticModelConfig): table: Optional[bool] = None registry: Optional[Any] = None @@ -78,14 +107,14 @@ def set_config_value( def get_model_fields(model: InstanceOrType["SQLModel"]) -> Dict[str, "FieldInfo"]: if IS_PYDANTIC_V2: - return model.model_fields # type: ignore + return model.model_fields # type: ignore else: return model.__fields__ # type: ignore def get_fields_set(model: InstanceOrType["SQLModel"]) -> set[str]: if IS_PYDANTIC_V2: - return model.__pydantic_fields_set__ # type: ignore + return model.__pydantic_fields_set__ # type: ignore else: return model.__fields_set__ # type: ignore @@ -115,7 +144,9 @@ def get_annotations(class_dict: dict[str, Any]) -> dict[str, Any]: ) -def is_table(class_dict: dict[str, Any]) -> bool: +def class_dict_is_table( + class_dict: dict[str, Any], class_kwargs: dict[str, Any] +) -> bool: config: SQLModelConfig = {} if IS_PYDANTIC_V2: config = class_dict.get("model_config", {}) @@ -123,13 +154,26 @@ def is_table(class_dict: dict[str, Any]) -> bool: config = class_dict.get("__config__", {}) config_table = config.get("table", PydanticUndefined) if config_table is not PydanticUndefined: - return config_table # type: ignore - kw_table = class_dict.get("table", PydanticUndefined) + return config_table # type: ignore + kw_table = class_kwargs.get("table", PydanticUndefined) if kw_table is not PydanticUndefined: - return kw_table # type: ignore + return kw_table # type: ignore return False +def cls_is_table(cls: Type) -> bool: + if IS_PYDANTIC_V2: + config = getattr(cls, "model_config", None) + if not config: + return False + return config.get("table", False) + else: + config = getattr(cls, "__config__", None) + if not config: + return False + return getattr(config, "table", False) + + def get_relationship_to( name: str, rel_info: "RelationshipInfo", @@ -186,17 +230,15 @@ def set_empty_defaults(annotations: Dict[str, Any], class_dict: Dict[str, Any]) value.default in (PydanticUndefined, Ellipsis) ) and value.default_factory is None: # So we can check for nullable - value.original_default = value.default value.default = None -def is_field_noneable(field: "FieldInfo") -> bool: +def _is_field_noneable(field: "FieldInfo") -> bool: if IS_PYDANTIC_V2: if getattr(field, "nullable", PydanticUndefined) is not PydanticUndefined: - return field.nullable # type: ignore + return field.nullable # type: ignore if not field.is_required(): - default = getattr(field, "original_default", field.default) - if default is PydanticUndefined: + if field.default is PydanticUndefined: return False if field.annotation is None or field.annotation is NoneType: return True @@ -212,4 +254,163 @@ def is_field_noneable(field: "FieldInfo") -> bool: return field.allow_none and ( field.shape != SHAPE_SINGLETON or not field.sub_fields ) - return False + return field.allow_none + + +def get_sqlalchemy_type(field: Any) -> Any: + if IS_PYDANTIC_V2: + field_info = field + else: + field_info = field.field_info + sa_type = getattr(field_info, "sa_type", PydanticUndefined) # noqa: B009 + if sa_type is not PydanticUndefined: + return sa_type + + type_ = get_type_from_field(field) + metadata = get_field_metadata(field) + + # Check enums first as an enum can also be a str, needed by Pydantic/FastAPI + if issubclass(type_, Enum): + return sa_Enum(type_) + if issubclass(type_, str): + max_length = getattr(metadata, "max_length", None) + if max_length: + return AutoString(length=max_length) + return AutoString + if issubclass(type_, float): + return Float + if issubclass(type_, bool): + return Boolean + if issubclass(type_, int): + return Integer + if issubclass(type_, datetime): + return DateTime + if issubclass(type_, date): + return Date + if issubclass(type_, timedelta): + return Interval + if issubclass(type_, time): + return Time + if issubclass(type_, bytes): + return LargeBinary + if issubclass(type_, Decimal): + return Numeric( + precision=getattr(metadata, "max_digits", None), + scale=getattr(metadata, "decimal_places", None), + ) + if issubclass(type_, ipaddress.IPv4Address): + return AutoString + if issubclass(type_, ipaddress.IPv4Network): + return AutoString + if issubclass(type_, ipaddress.IPv6Address): + return AutoString + if issubclass(type_, ipaddress.IPv6Network): + return AutoString + if issubclass(type_, Path): + return AutoString + if issubclass(type_, uuid.UUID): + return GUID + raise ValueError(f"{type_} has no matching SQLAlchemy type") + + +def get_type_from_field(field: Any) -> type: + if IS_PYDANTIC_V2: + type_: type | None = field.annotation + # Resolve Optional fields + if type_ is None: + raise ValueError("Missing field type") + origin = get_origin(type_) + if origin is None: + return type_ + if origin is Union: + bases = get_args(type_) + if len(bases) > 2: + raise ValueError( + "Cannot have a (non-optional) union as a SQL alchemy field" + ) + # Non optional unions are not allowed + if bases[0] is not NoneType and bases[1] is not NoneType: + raise ValueError( + "Cannot have a (non-optional) union as a SQL alchemy field" + ) + # Optional unions are allowed + return bases[0] if bases[0] is not NoneType else bases[1] + return origin + else: + if isinstance(field.type_, type) and field.shape == SHAPE_SINGLETON: + return field.type_ + raise ValueError(f"The field {field.name} has no matching SQLAlchemy type") + + +class FakeMetadata: + max_length: Optional[int] = None + max_digits: Optional[int] = None + decimal_places: Optional[int] = None + + +def get_field_metadata(field: Any) -> Any: + if IS_PYDANTIC_V2: + for meta in field.metadata: + if isinstance(meta, PydanticMetadata): + return meta + return FakeMetadata() + else: + metadata = FakeMetadata() + metadata.max_length = field.field_info.max_length + metadata.max_digits = getattr(field.type_, "max_digits", None) + metadata.decimal_places = getattr(field.type_, "decimal_places", None) + return metadata + + +def get_column_from_field(field: Any) -> Column: # type: ignore + if IS_PYDANTIC_V2: + field_info = field + else: + field_info = field.field_info + sa_column = getattr(field_info, "sa_column", PydanticUndefined) + if isinstance(sa_column, Column): + return sa_column + sa_type = get_sqlalchemy_type(field) + primary_key = getattr(field_info, "primary_key", PydanticUndefined) + if primary_key is PydanticUndefined: + primary_key = False + index = getattr(field_info, "index", PydanticUndefined) + if index is PydanticUndefined: + index = False + nullable = not primary_key and _is_field_noneable(field) + # Override derived nullability if the nullable property is set explicitly + # on the field + field_nullable = getattr(field_info, "nullable", PydanticUndefined) # noqa: B009 + if field_nullable is not PydanticUndefined: + assert not isinstance(field_nullable, PydanticUndefinedType) + nullable = field_nullable + args = [] + foreign_key = getattr(field_info, "foreign_key", PydanticUndefined) + if foreign_key is PydanticUndefined: + foreign_key = None + unique = getattr(field_info, "unique", PydanticUndefined) + if unique is PydanticUndefined: + unique = False + if foreign_key: + assert isinstance(foreign_key, str) + args.append(ForeignKey(foreign_key)) + kwargs = { + "primary_key": primary_key, + "nullable": nullable, + "index": index, + "unique": unique, + } + sa_default = PydanticUndefined + if field_info.default_factory: + sa_default = field_info.default_factory + elif field_info.default is not PydanticUndefined: + sa_default = field_info.default + if sa_default is not PydanticUndefined: + kwargs["default"] = sa_default + sa_column_args = getattr(field_info, "sa_column_args", PydanticUndefined) + if sa_column_args is not PydanticUndefined: + args.extend(list(cast(Sequence[Any], sa_column_args))) + sa_column_kwargs = getattr(field_info, "sa_column_kwargs", PydanticUndefined) + if sa_column_kwargs is not PydanticUndefined: + kwargs.update(cast(Dict[Any, Any], sa_column_kwargs)) + return Column(sa_type, *args, **kwargs) # type: ignore diff --git a/sqlmodel/main.py b/sqlmodel/main.py index 435b9f31f3..cb008bb663 100644 --- a/sqlmodel/main.py +++ b/sqlmodel/main.py @@ -1,10 +1,4 @@ -import ipaddress -import uuid import weakref -from datetime import date, datetime, time, timedelta -from decimal import Decimal -from enum import Enum -from pathlib import Path from typing import ( AbstractSet, Any, @@ -25,48 +19,37 @@ ) from pydantic import BaseModel -from pydantic.fields import SHAPE_SINGLETON, ModelField from pydantic.fields import FieldInfo as PydanticFieldInfo -from pydantic.main import ModelMetaclass from pydantic.utils import Representation from sqlalchemy import ( - Boolean, Column, - Date, - DateTime, - Float, - ForeignKey, - Integer, - Interval, - Numeric, inspect, ) -from sqlalchemy import Enum as sa_Enum from sqlalchemy.orm import RelationshipProperty, declared_attr, registry, relationship from sqlalchemy.orm.attributes import set_attribute from sqlalchemy.orm.decl_api import DeclarativeMeta from sqlalchemy.orm.instrumentation import is_instrumented from sqlalchemy.sql.schema import MetaData -from sqlalchemy.sql.sqltypes import LargeBinary, Time from .compat import ( IS_PYDANTIC_V2, + ModelMetaclass, NoArgAnyCallable, PydanticModelConfig, PydanticUndefined, PydanticUndefinedType, SQLModelConfig, + class_dict_is_table, + cls_is_table, get_annotations, + get_column_from_field, get_config_value, get_model_fields, get_relationship_to, - is_field_noneable, - is_table, set_config_value, set_empty_defaults, set_fields_set, ) -from .sql.sqltypes import GUID, AutoString if not IS_PYDANTIC_V2: from pydantic.errors import ConfigError, DictError @@ -144,7 +127,6 @@ def __init__(self, default: Any = PydanticUndefined, **kwargs: Any) -> None: self.sa_column = sa_column self.sa_column_args = sa_column_args self.sa_column_kwargs = sa_column_kwargs - self.original_default = PydanticUndefined class RelationshipInfo(Representation): @@ -445,8 +427,7 @@ def __new__( config_kwargs = { key: kwargs[key] for key in kwargs.keys() & allowed_config_kwargs } - config_table = is_table(class_dict) - if config_table: + if class_dict_is_table(class_dict, kwargs): set_empty_defaults(pydantic_annotations, dict_used) new_cls: Type["SQLModelMetaclass"] = super().__new__( @@ -501,13 +482,8 @@ def __init__( # this allows FastAPI cloning a SQLModel for the response_model without # trying to create a new SQLAlchemy, for a new table, with the same name, that # triggers an error - base_is_table = False - for base in bases: - config = getattr(base, "__config__") # noqa: B009 - if config and getattr(config, "table", False): - base_is_table = True - break - if getattr(cls.__config__, "table", False) and not base_is_table: + base_is_table = any(cls_is_table(base) for base in bases) + if cls_is_table(cls) and not base_is_table: for rel_name, rel_info in cls.__sqlmodel_relationships__.items(): if rel_info.sa_relationship: # There's a SQLAlchemy relationship declared, that takes precedence @@ -545,104 +521,6 @@ def __init__( ModelMetaclass.__init__(cls, classname, bases, dict_, **kw) -def get_sqlalchemy_type(field: ModelField) -> Any: - sa_type = getattr(field.field_info, "sa_type", PydanticUndefined) # noqa: B009 - if sa_type is not PydanticUndefined: - return sa_type - if isinstance(field.type_, type) and field.shape == SHAPE_SINGLETON: - # Check enums first as an enum can also be a str, needed by Pydantic/FastAPI - if issubclass(field.type_, Enum): - return sa_Enum(field.type_) - if issubclass(field.type_, str): - if field.field_info.max_length: - return AutoString(length=field.field_info.max_length) - return AutoString - if issubclass(field.type_, float): - return Float - if issubclass(field.type_, bool): - return Boolean - if issubclass(field.type_, int): - return Integer - if issubclass(field.type_, datetime): - return DateTime - if issubclass(field.type_, date): - return Date - if issubclass(field.type_, timedelta): - return Interval - if issubclass(field.type_, time): - return Time - if issubclass(field.type_, bytes): - return LargeBinary - if issubclass(field.type_, Decimal): - return Numeric( - precision=getattr(field.type_, "max_digits", None), - scale=getattr(field.type_, "decimal_places", None), - ) - if issubclass(field.type_, ipaddress.IPv4Address): - return AutoString - if issubclass(field.type_, ipaddress.IPv4Network): - return AutoString - if issubclass(field.type_, ipaddress.IPv6Address): - return AutoString - if issubclass(field.type_, ipaddress.IPv6Network): - return AutoString - if issubclass(field.type_, Path): - return AutoString - if issubclass(field.type_, uuid.UUID): - return GUID - raise ValueError(f"The field {field.name} has no matching SQLAlchemy type") - - -def get_column_from_field(field: ModelField) -> Column: # type: ignore - sa_column = getattr(field.field_info, "sa_column", PydanticUndefined) - if isinstance(sa_column, Column): - return sa_column - sa_type = get_sqlalchemy_type(field) - primary_key = getattr(field.field_info, "primary_key", PydanticUndefined) - if primary_key is PydanticUndefined: - primary_key = False - index = getattr(field.field_info, "index", PydanticUndefined) - if index is PydanticUndefined: - index = False - nullable = not primary_key and is_field_noneable(field) - # Override derived nullability if the nullable property is set explicitly - # on the field - field_nullable = getattr(field.field_info, "nullable", PydanticUndefined) # noqa: B009 - if field_nullable != PydanticUndefined: - assert not isinstance(field_nullable, PydanticUndefinedType) - nullable = field_nullable - args = [] - foreign_key = getattr(field.field_info, "foreign_key", PydanticUndefined) - if foreign_key is PydanticUndefined: - foreign_key = None - unique = getattr(field.field_info, "unique", PydanticUndefined) - if unique is PydanticUndefined: - unique = False - if foreign_key: - assert isinstance(foreign_key, str) - args.append(ForeignKey(foreign_key)) - kwargs = { - "primary_key": primary_key, - "nullable": nullable, - "index": index, - "unique": unique, - } - sa_default = PydanticUndefined - if field.field_info.default_factory: - sa_default = field.field_info.default_factory - elif field.field_info.default is not PydanticUndefined: - sa_default = field.field_info.default - if sa_default is not PydanticUndefined: - kwargs["default"] = sa_default - sa_column_args = getattr(field.field_info, "sa_column_args", PydanticUndefined) - if sa_column_args is not PydanticUndefined: - args.extend(list(cast(Sequence[Any], sa_column_args))) - sa_column_kwargs = getattr(field.field_info, "sa_column_kwargs", PydanticUndefined) - if sa_column_kwargs is not PydanticUndefined: - kwargs.update(cast(Dict[Any, Any], sa_column_kwargs)) - return Column(sa_type, *args, **kwargs) # type: ignore - - class_registry = weakref.WeakValueDictionary() # type: ignore default_registry = registry() @@ -688,7 +566,7 @@ def __init__(__pydantic_self__, **data: Any) -> None: # settable attribute if IS_PYDANTIC_V2: old_dict = __pydantic_self__.__dict__.copy() - __pydantic_self__.super().__init__(**data) # noqa + super().__init__(**data) # noqa __pydantic_self__.__dict__ = {**old_dict, **__pydantic_self__.__dict__} non_pydantic_keys = data.keys() - __pydantic_self__.model_fields else: @@ -739,6 +617,7 @@ def __tablename__(cls) -> str: return cls.__name__.lower() if IS_PYDANTIC_V2: + @classmethod def model_validate( cls: type[_TSQLModel], @@ -752,7 +631,7 @@ def model_validate( validated = super().model_validate( obj, strict=strict, from_attributes=from_attributes, context=context ) - return cls(**dict(validated)) + return cls(**validated.model_dump(exclude_unset=True)) else: @@ -761,7 +640,7 @@ def from_orm( cls: Type[_TSQLModel], obj: Any, update: Optional[Dict[str, Any]] = None ) -> _TSQLModel: # Duplicated from Pydantic - if not cls.__config__.orm_mode: # noqa: attr-defined + if not cls.__config__.orm_mode: # noqa raise ConfigError( "You must have the config attribute orm_mode=True to use from_orm" ) diff --git a/tests/test_enums.py b/tests/test_enums.py index 194bdefea1..07a04c686e 100644 --- a/tests/test_enums.py +++ b/tests/test_enums.py @@ -5,6 +5,8 @@ from sqlalchemy.sql.type_api import TypeEngine from sqlmodel import Field, SQLModel +from .conftest import needs_pydanticv1, needs_pydanticv2 + """ Tests related to Enums @@ -72,7 +74,8 @@ def test_sqlite_ddl_sql(capsys): assert "CREATE TYPE" not in captured.out -def test_json_schema_flat_model(): +@needs_pydanticv1 +def test_json_schema_flat_model_pydantic_v1(): assert FlatModel.schema() == { "title": "FlatModel", "type": "object", @@ -92,7 +95,8 @@ def test_json_schema_flat_model(): } -def test_json_schema_inherit_model(): +@needs_pydanticv1 +def test_json_schema_inherit_model_pydantic_v1(): assert InheritModel.schema() == { "title": "InheritModel", "type": "object", @@ -110,3 +114,42 @@ def test_json_schema_inherit_model(): } }, } + + +@needs_pydanticv2 +def test_json_schema_flat_model_pydantic_v2(): + assert FlatModel.model_json_schema() == { + "title": "FlatModel", + "type": "object", + "properties": { + "id": {"default": None, "format": "uuid", "title": "Id", "type": "string"}, + "enum_field": {"allOf": [{"$ref": "#/$defs/MyEnum1"}], "default": None}, + }, + "$defs": { + "MyEnum1": { + "title": "MyEnum1", + "enum": ["A", "B"], + "type": "string", + } + }, + } + + +@needs_pydanticv2 +def test_json_schema_inherit_model_pydantic_v2(): + assert InheritModel.model_json_schema() == { + "title": "InheritModel", + "type": "object", + "properties": { + "id": {"title": "Id", "type": "string", "format": "uuid"}, + "enum_field": {"$ref": "#/$defs/MyEnum2"}, + }, + "required": ["id", "enum_field"], + "$defs": { + "MyEnum2": { + "title": "MyEnum2", + "enum": ["C", "D"], + "type": "string", + } + }, + } diff --git a/tests/test_instance_no_args.py b/tests/test_instance_no_args.py index 14d560628b..e54e8163b3 100644 --- a/tests/test_instance_no_args.py +++ b/tests/test_instance_no_args.py @@ -1,11 +1,16 @@ from typing import Optional +import pytest +from pydantic import ValidationError from sqlalchemy import create_engine, select from sqlalchemy.orm import Session from sqlmodel import Field, SQLModel +from .conftest import needs_pydanticv1, needs_pydanticv2 -def test_allow_instantiation_without_arguments(clear_sqlmodel): + +@needs_pydanticv1 +def test_allow_instantiation_without_arguments_pydantic_v1(clear_sqlmodel): class Item(SQLModel): id: Optional[int] = Field(default=None, primary_key=True) name: str @@ -25,3 +30,33 @@ class Config: assert len(result) == 1 assert isinstance(item.id, int) SQLModel.metadata.clear() + + +def test_not_allow_instantiation_without_arguments_if_not_table(): + class Item(SQLModel): + id: Optional[int] = Field(default=None, primary_key=True) + name: str + description: Optional[str] = None + + with pytest.raises(ValidationError): + Item() + + +@needs_pydanticv2 +def test_allow_instantiation_without_arguments_pydnatic_v2(clear_sqlmodel): + class Item(SQLModel, table=True): + id: Optional[int] = Field(default=None, primary_key=True) + name: str + description: Optional[str] = None + + engine = create_engine("sqlite:///:memory:") + SQLModel.metadata.create_all(engine) + with Session(engine) as db: + item = Item() + item.name = "Rick" + db.add(item) + db.commit() + result = db.execute(select(Item)).scalars().all() + assert len(result) == 1 + assert isinstance(item.id, int) + SQLModel.metadata.clear() diff --git a/tests/test_nullable.py b/tests/test_nullable.py index 1c8b37b218..a40bb5b5f0 100644 --- a/tests/test_nullable.py +++ b/tests/test_nullable.py @@ -58,7 +58,7 @@ class Hero(SQLModel, table=True): ][0] assert "primary_key INTEGER NOT NULL," in create_table_log assert "required_value VARCHAR NOT NULL," in create_table_log - assert "optional_default_ellipsis VARCHAR NOT NULL," in create_table_log + assert "optional_default_ellipsis VARCHAR," in create_table_log assert "optional_default_none VARCHAR," in create_table_log assert "optional_non_nullable VARCHAR NOT NULL," in create_table_log assert "optional_nullable VARCHAR," in create_table_log diff --git a/tests/test_validation.py b/tests/test_validation.py index e200a1e73d..3265922070 100644 --- a/tests/test_validation.py +++ b/tests/test_validation.py @@ -1,7 +1,6 @@ from typing import Optional import pytest -from pydantic import field_validator from pydantic.error_wrappers import ValidationError from sqlmodel import SQLModel @@ -19,21 +18,22 @@ def test_validation_pydantic_v1(clear_sqlmodel): https://github.com/samuelcolvin/pydantic/issues/1223 """ + from pydantic import validator class Hero(SQLModel): name: Optional[str] = None secret_name: Optional[str] = None age: Optional[int] = None - @field_validator("name", "secret_name", "age") + @validator("name", "secret_name", "age") def reject_none(cls, v): assert v is not None return v - Hero.from_orm({"age": 25}) + Hero.validate({"age": 25}) with pytest.raises(ValidationError): - Hero.from_orm({"name": None, "age": 25}) + Hero.validate({"name": None, "age": 25}) @needs_pydanticv2 @@ -47,6 +47,7 @@ def test_validation_pydantic_v2(clear_sqlmodel): https://github.com/samuelcolvin/pydantic/issues/1223 """ + from pydantic import field_validator class Hero(SQLModel): name: Optional[str] = None From 4c358b2c5fea6a4611e4f9e6951b1baa68386210 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sebasti=C3=A1n=20Ram=C3=ADrez?= Date: Thu, 23 Nov 2023 13:12:56 +0100 Subject: [PATCH 030/105] =?UTF-8?q?=F0=9F=92=A1=20Add=20TODO=20comments?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- sqlmodel/compat.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/sqlmodel/compat.py b/sqlmodel/compat.py index dbd22053a8..445bbf2e38 100644 --- a/sqlmodel/compat.py +++ b/sqlmodel/compat.py @@ -180,6 +180,7 @@ def get_relationship_to( annotation: Any, ) -> Any: if IS_PYDANTIC_V2: + # TODO: review rename origin and relationship_to relationship_to = get_origin(annotation) # Direct relationships (e.g. 'Team' or Team) have None as an origin if relationship_to is None: @@ -190,6 +191,7 @@ def get_relationship_to( # If a list, then also get the real field elif relationship_to is list: relationship_to = get_args(annotation)[0] + # TODO: given this, should there be a recursive call in this whole if block to get_relationship_to? if isinstance(relationship_to, ForwardRef): relationship_to = relationship_to.__forward_arg__ return relationship_to @@ -217,6 +219,7 @@ def set_empty_defaults(annotations: Dict[str, Any], class_dict: Dict[str, Any]) annotations: Dict[str, Any]: The annotations to provide to pydantic class_dict: Dict[str, Any]: The class dict for the defaults """ + # TODO: no v1? if IS_PYDANTIC_V2: from .main import FieldInfo From 00291938ae1f1e1c5bf6f05b7116ca02cafdf40c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sebasti=C3=A1n=20Ram=C3=ADrez?= Date: Thu, 23 Nov 2023 13:15:47 +0100 Subject: [PATCH 031/105] =?UTF-8?q?=F0=9F=91=B7=20Update=20CI=20from=20mai?= =?UTF-8?q?n?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .github/workflows/test.yml | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index c646ce470f..d35614e59d 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -20,8 +20,15 @@ jobs: runs-on: ubuntu-latest strategy: matrix: - python-version: ["3.7", "3.8", "3.9", "3.10", "3.11"] - pydantic-version: ["pydantic-v1", "pydantic-v2"] + python-version: + - "3.7" + - "3.8" + - "3.9" + - "3.10" + - "3.11" + pydantic-version: + - pydantic-v1 + - pydantic-v2 fail-fast: false steps: @@ -73,8 +80,6 @@ jobs: with: name: coverage path: coverage - - name: Lint - run: python -m poetry run bash scripts/lint.sh coverage-combine: needs: - test From 8a139457c3e3874d09d0e2dd968e2a760f4faee7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sebasti=C3=A1n=20Ram=C3=ADrez?= Date: Thu, 23 Nov 2023 13:20:21 +0100 Subject: [PATCH 032/105] =?UTF-8?q?=F0=9F=93=9D=20Update=20index.md=20with?= =?UTF-8?q?=20Python=20versions?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- docs/tutorial/index.md | 3 --- 1 file changed, 3 deletions(-) diff --git a/docs/tutorial/index.md b/docs/tutorial/index.md index 54e1147d68..e8da8ece79 100644 --- a/docs/tutorial/index.md +++ b/docs/tutorial/index.md @@ -64,8 +64,6 @@ $ cd sqlmodel-tutorial Make sure you have an officially supported version of Python. -Currently it is **Python 3.7** and above (Python 3.6 was already deprecated). - You can check which version you have with:
@@ -85,7 +83,6 @@ You might want to try with the specific versions, for example with: * `python3.10` * `python3.9` * `python3.8` -* `python3.7` The code would look like this: From bebd5d8fdc4d76fe8d8c750c271aa59237d710b3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sebasti=C3=A1n=20Ram=C3=ADrez?= Date: Thu, 23 Nov 2023 13:22:56 +0100 Subject: [PATCH 033/105] =?UTF-8?q?=E2=8F=AA=EF=B8=8F=20Revert=20changing?= =?UTF-8?q?=20model=20order=20in=20examples=20to=20account=20for=20previou?= =?UTF-8?q?s=20lack=20of=20support=20for=20forward=20references=20in=20SQL?= =?UTF-8?q?Alchemy=20(from=20a=20previous=20PR=20this=20was=20based=20of)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- docs_src/tutorial/many_to_many/tutorial003.py | 30 +++++++++---------- 1 file changed, 15 insertions(+), 15 deletions(-) diff --git a/docs_src/tutorial/many_to_many/tutorial003.py b/docs_src/tutorial/many_to_many/tutorial003.py index cec6e56560..1e03c4af89 100644 --- a/docs_src/tutorial/many_to_many/tutorial003.py +++ b/docs_src/tutorial/many_to_many/tutorial003.py @@ -3,12 +3,25 @@ from sqlmodel import Field, Relationship, Session, SQLModel, create_engine, select +class HeroTeamLink(SQLModel, table=True): + team_id: Optional[int] = Field( + default=None, foreign_key="team.id", primary_key=True + ) + hero_id: Optional[int] = Field( + default=None, foreign_key="hero.id", primary_key=True + ) + is_training: bool = False + + team: "Team" = Relationship(back_populates="hero_links") + hero: "Hero" = Relationship(back_populates="team_links") + + class Team(SQLModel, table=True): id: Optional[int] = Field(default=None, primary_key=True) name: str = Field(index=True) headquarters: str - hero_links: List["HeroTeamLink"] = Relationship(back_populates="team") + hero_links: List[HeroTeamLink] = Relationship(back_populates="team") class Hero(SQLModel, table=True): @@ -17,20 +30,7 @@ class Hero(SQLModel, table=True): secret_name: str age: Optional[int] = Field(default=None, index=True) - team_links: List["HeroTeamLink"] = Relationship(back_populates="hero") - - -class HeroTeamLink(SQLModel, table=True): - team_id: Optional[int] = Field( - default=None, foreign_key="team.id", primary_key=True - ) - hero_id: Optional[int] = Field( - default=None, foreign_key="hero.id", primary_key=True - ) - is_training: bool = False - - team: "Team" = Relationship(back_populates="hero_links") - hero: "Hero" = Relationship(back_populates="team_links") + team_links: List[HeroTeamLink] = Relationship(back_populates="hero") sqlite_file_name = "database.db" From d229a87ae6969291dfc40aecbd1bb863e5166404 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sebasti=C3=A1n=20Ram=C3=ADrez?= Date: Thu, 23 Nov 2023 13:23:51 +0100 Subject: [PATCH 034/105] =?UTF-8?q?=E2=8F=AA=EF=B8=8F=20Revert=20example?= =?UTF-8?q?=20changing=20the=20order=20of=20models=20to=20account=20for=20?= =?UTF-8?q?the=20lack=20of=20support=20for=20forward=20references=20in=20S?= =?UTF-8?q?QLAlchemy=20in=20a=20previous=20PR=20this=20was=20based=20of?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../back_populates/tutorial003.py | 30 +++++++++---------- 1 file changed, 15 insertions(+), 15 deletions(-) diff --git a/docs_src/tutorial/relationship_attributes/back_populates/tutorial003.py b/docs_src/tutorial/relationship_attributes/back_populates/tutorial003.py index 8d91a0bc25..98e197002e 100644 --- a/docs_src/tutorial/relationship_attributes/back_populates/tutorial003.py +++ b/docs_src/tutorial/relationship_attributes/back_populates/tutorial003.py @@ -3,21 +3,6 @@ from sqlmodel import Field, Relationship, SQLModel, create_engine -class Hero(SQLModel, table=True): - id: Optional[int] = Field(default=None, primary_key=True) - name: str = Field(index=True) - secret_name: str - age: Optional[int] = Field(default=None, index=True) - - team_id: Optional[int] = Field(default=None, foreign_key="team.id") - team: Optional["Team"] = Relationship(back_populates="heroes") - - weapon_id: Optional[int] = Field(default=None, foreign_key="weapon.id") - weapon: Optional["Weapon"] = Relationship(back_populates="hero") - - powers: List["Power"] = Relationship(back_populates="hero") - - class Weapon(SQLModel, table=True): id: Optional[int] = Field(default=None, primary_key=True) name: str = Field(index=True) @@ -41,6 +26,21 @@ class Team(SQLModel, table=True): heroes: List["Hero"] = Relationship(back_populates="team") +class Hero(SQLModel, table=True): + id: Optional[int] = Field(default=None, primary_key=True) + name: str = Field(index=True) + secret_name: str + age: Optional[int] = Field(default=None, index=True) + + team_id: Optional[int] = Field(default=None, foreign_key="team.id") + team: Optional[Team] = Relationship(back_populates="heroes") + + weapon_id: Optional[int] = Field(default=None, foreign_key="weapon.id") + weapon: Optional[Weapon] = Relationship(back_populates="hero") + + powers: List[Power] = Relationship(back_populates="hero") + + sqlite_file_name = "database.db" sqlite_url = f"sqlite:///{sqlite_file_name}" From 29e372b721df25ce79af273a4fa3ede041d92dd7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sebasti=C3=A1n=20Ram=C3=ADrez?= Date: Thu, 23 Nov 2023 13:24:39 +0100 Subject: [PATCH 035/105] =?UTF-8?q?=F0=9F=94=A5=20Remove=20import=20remove?= =?UTF-8?q?d=20in=20=5F=5Finit=5F=5F=20from=20main?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- sqlmodel/__init__.py | 1 - 1 file changed, 1 deletion(-) diff --git a/sqlmodel/__init__.py b/sqlmodel/__init__.py index 4ed87e8ad7..5b117a1c05 100644 --- a/sqlmodel/__init__.py +++ b/sqlmodel/__init__.py @@ -33,7 +33,6 @@ from sqlalchemy.sql import ( LABEL_STYLE_TABLENAME_PLUS_COL as LABEL_STYLE_TABLENAME_PLUS_COL, ) -from sqlalchemy.sql import Subquery as Subquery from sqlalchemy.sql import alias as alias from sqlalchemy.sql import bindparam as bindparam from sqlalchemy.sql import column as column From 665902c2008fca7caeebe1168d46e7218dcb8b0a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sebasti=C3=A1n=20Ram=C3=ADrez?= Date: Fri, 24 Nov 2023 09:01:27 +0100 Subject: [PATCH 036/105] =?UTF-8?q?=E2=8F=AA=EF=B8=8F=20Revert=20renaming?= =?UTF-8?q?=20Undefined=20to=20PydanticUndefined,=20limit=20the=20changes?= =?UTF-8?q?=20and=20difff?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- sqlmodel/main.py | 96 +++++++++++++++++++++++------------------------- 1 file changed, 46 insertions(+), 50 deletions(-) diff --git a/sqlmodel/main.py b/sqlmodel/main.py index 9e0ddf1493..fe5272f4d9 100644 --- a/sqlmodel/main.py +++ b/sqlmodel/main.py @@ -43,9 +43,9 @@ ModelMetaclass, NoArgAnyCallable, PydanticModelConfig, - PydanticUndefined, - PydanticUndefinedType, SQLModelConfig, + Undefined, + UndefinedType, class_dict_is_table, cls_is_table, get_annotations, @@ -77,50 +77,50 @@ def __dataclass_transform__( class FieldInfo(PydanticFieldInfo): - def __init__(self, default: Any = PydanticUndefined, **kwargs: Any) -> None: + def __init__(self, default: Any = Undefined, **kwargs: Any) -> None: primary_key = kwargs.pop("primary_key", False) - nullable = kwargs.pop("nullable", PydanticUndefined) - foreign_key = kwargs.pop("foreign_key", PydanticUndefined) + nullable = kwargs.pop("nullable", Undefined) + foreign_key = kwargs.pop("foreign_key", Undefined) unique = kwargs.pop("unique", False) - index = kwargs.pop("index", PydanticUndefined) - sa_type = kwargs.pop("sa_type", PydanticUndefined) - sa_column = kwargs.pop("sa_column", PydanticUndefined) - sa_column_args = kwargs.pop("sa_column_args", PydanticUndefined) - sa_column_kwargs = kwargs.pop("sa_column_kwargs", PydanticUndefined) - if sa_column is not PydanticUndefined: - if sa_column_args is not PydanticUndefined: + index = kwargs.pop("index", Undefined) + sa_type = kwargs.pop("sa_type", Undefined) + sa_column = kwargs.pop("sa_column", Undefined) + sa_column_args = kwargs.pop("sa_column_args", Undefined) + sa_column_kwargs = kwargs.pop("sa_column_kwargs", Undefined) + if sa_column is not Undefined: + if sa_column_args is not Undefined: raise RuntimeError( "Passing sa_column_args is not supported when " "also passing a sa_column" ) - if sa_column_kwargs is not PydanticUndefined: + if sa_column_kwargs is not Undefined: raise RuntimeError( "Passing sa_column_kwargs is not supported when " "also passing a sa_column" ) - if primary_key is not PydanticUndefined: + if primary_key is not Undefined: raise RuntimeError( "Passing primary_key is not supported when " "also passing a sa_column" ) - if nullable is not PydanticUndefined: + if nullable is not Undefined: raise RuntimeError( "Passing nullable is not supported when " "also passing a sa_column" ) - if foreign_key is not PydanticUndefined: + if foreign_key is not Undefined: raise RuntimeError( "Passing foreign_key is not supported when " "also passing a sa_column" ) - if unique is not PydanticUndefined: + if unique is not Undefined: raise RuntimeError( "Passing unique is not supported when also passing a sa_column" ) - if index is not PydanticUndefined: + if index is not Undefined: raise RuntimeError( "Passing index is not supported when also passing a sa_column" ) - if sa_type is not PydanticUndefined: + if sa_type is not Undefined: raise RuntimeError( "Passing sa_type is not supported when also passing a sa_column" ) @@ -166,7 +166,7 @@ def __init__( @overload def Field( - default: Any = PydanticUndefined, + default: Any = Undefined, *, default_factory: Optional[NoArgAnyCallable] = None, alias: Optional[str] = None, @@ -195,16 +195,14 @@ def Field( regex: Optional[str] = None, discriminator: Optional[str] = None, repr: bool = True, - primary_key: Union[bool, PydanticUndefinedType] = PydanticUndefined, - foreign_key: Any = PydanticUndefined, - unique: Union[bool, PydanticUndefinedType] = PydanticUndefined, - nullable: Union[bool, PydanticUndefinedType] = PydanticUndefined, - index: Union[bool, PydanticUndefinedType] = PydanticUndefined, - sa_type: Union[Type[Any], PydanticUndefinedType] = PydanticUndefined, - sa_column_args: Union[Sequence[Any], PydanticUndefinedType] = PydanticUndefined, - sa_column_kwargs: Union[ - Mapping[str, Any], PydanticUndefinedType - ] = PydanticUndefined, + primary_key: Union[bool, UndefinedType] = Undefined, + foreign_key: Any = Undefined, + unique: Union[bool, UndefinedType] = Undefined, + nullable: Union[bool, UndefinedType] = Undefined, + index: Union[bool, UndefinedType] = Undefined, + sa_type: Union[Type[Any], UndefinedType] = Undefined, + sa_column_args: Union[Sequence[Any], UndefinedType] = Undefined, + sa_column_kwargs: Union[Mapping[str, Any], UndefinedType] = Undefined, schema_extra: Optional[Dict[str, Any]] = None, ) -> Any: ... @@ -212,7 +210,7 @@ def Field( @overload def Field( - default: Any = PydanticUndefined, + default: Any = Undefined, *, default_factory: Optional[NoArgAnyCallable] = None, alias: Optional[str] = None, @@ -241,14 +239,14 @@ def Field( regex: Optional[str] = None, discriminator: Optional[str] = None, repr: bool = True, - sa_column: Union[Column, PydanticUndefinedType] = PydanticUndefined, # type: ignore + sa_column: Union[Column, UndefinedType] = Undefined, # type: ignore schema_extra: Optional[Dict[str, Any]] = None, ) -> Any: ... def Field( - default: Any = PydanticUndefined, + default: Any = Undefined, *, default_factory: Optional[NoArgAnyCallable] = None, alias: Optional[str] = None, @@ -277,17 +275,15 @@ def Field( regex: Optional[str] = None, discriminator: Optional[str] = None, repr: bool = True, - primary_key: Union[bool, PydanticUndefinedType] = PydanticUndefined, - foreign_key: Any = PydanticUndefined, - unique: Union[bool, PydanticUndefinedType] = PydanticUndefined, - nullable: Union[bool, PydanticUndefinedType] = PydanticUndefined, - index: Union[bool, PydanticUndefinedType] = PydanticUndefined, - sa_type: Union[Type[Any], PydanticUndefinedType] = PydanticUndefined, - sa_column: Union[Column, PydanticUndefinedType] = PydanticUndefined, # type: ignore - sa_column_args: Union[Sequence[Any], PydanticUndefinedType] = PydanticUndefined, - sa_column_kwargs: Union[ - Mapping[str, Any], PydanticUndefinedType - ] = PydanticUndefined, + primary_key: Union[bool, UndefinedType] = Undefined, + foreign_key: Any = Undefined, + unique: Union[bool, UndefinedType] = Undefined, + nullable: Union[bool, UndefinedType] = Undefined, + index: Union[bool, UndefinedType] = Undefined, + sa_type: Union[Type[Any], UndefinedType] = Undefined, + sa_column: Union[Column, UndefinedType] = Undefined, # type: ignore + sa_column_args: Union[Sequence[Any], UndefinedType] = Undefined, + sa_column_kwargs: Union[Mapping[str, Any], UndefinedType] = Undefined, schema_extra: Optional[Dict[str, Any]] = None, ) -> Any: current_schema_extra = schema_extra or {} @@ -447,13 +443,13 @@ def __new__( } def get_config(name: str) -> Any: - config_class_value = get_config_value(new_cls, name, PydanticUndefined) - if config_class_value is not PydanticUndefined: + config_class_value = get_config_value(new_cls, name, Undefined) + if config_class_value is not Undefined: return config_class_value - kwarg_value = kwargs.get(name, PydanticUndefined) - if kwarg_value is not PydanticUndefined: + kwarg_value = kwargs.get(name, Undefined) + if kwarg_value is not Undefined: return kwarg_value - return PydanticUndefined + return Undefined config_table = get_config("table") if config_table is True: @@ -472,7 +468,7 @@ def get_config(name: str) -> Any: ) config_registry = get_config("registry") - if config_registry is not PydanticUndefined: + if config_registry is not Undefined: config_registry = cast(registry, config_registry) # If it was passed by kwargs, ensure it's also set in config set_config_value(new_cls, "registry", config_table) From b5460f9303d7e964c6ad37585530c06b18a71de5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sebasti=C3=A1n=20Ram=C3=ADrez?= Date: Fri, 24 Nov 2023 09:03:13 +0100 Subject: [PATCH 037/105] =?UTF-8?q?=E2=99=BB=EF=B8=8F=20Refactor=20compat,?= =?UTF-8?q?=20rename=20PydanticUndefined=20to=20Undefined=20as=20before?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- sqlmodel/compat.py | 71 +++++++++++++++++++++++----------------------- 1 file changed, 35 insertions(+), 36 deletions(-) diff --git a/sqlmodel/compat.py b/sqlmodel/compat.py index 445bbf2e38..9d2a097ed6 100644 --- a/sqlmodel/compat.py +++ b/sqlmodel/compat.py @@ -38,20 +38,19 @@ from .sql.sqltypes import GUID, AutoString -IS_PYDANTIC_V2 = int(PYDANTIC_VERSION.split(".")[0]) >= 2 - +IS_PYDANTIC_V2 = PYDANTIC_VERSION.startswith("2.") if IS_PYDANTIC_V2: from pydantic import ConfigDict as PydanticModelConfig from pydantic._internal._fields import PydanticMetadata from pydantic._internal._model_construction import ModelMetaclass - from pydantic_core import PydanticUndefined as PydanticUndefined # noqa - from pydantic_core import PydanticUndefinedType as PydanticUndefinedType + from pydantic_core import PydanticUndefined as Undefined # noqa + from pydantic_core import PydanticUndefinedType as UndefinedType else: from pydantic import BaseConfig as PydanticModelConfig from pydantic.fields import SHAPE_SINGLETON, ModelField - from pydantic.fields import Undefined as PydanticUndefined # noqa - from pydantic.fields import UndefinedType as PydanticUndefinedType + from pydantic.fields import Undefined as Undefined # noqa + from pydantic.fields import UndefinedType as UndefinedType from pydantic.main import ModelMetaclass as ModelMetaclass from pydantic.typing import resolve_annotations @@ -152,11 +151,11 @@ def class_dict_is_table( config = class_dict.get("model_config", {}) else: config = class_dict.get("__config__", {}) - config_table = config.get("table", PydanticUndefined) - if config_table is not PydanticUndefined: + config_table = config.get("table", Undefined) + if config_table is not Undefined: return config_table # type: ignore - kw_table = class_kwargs.get("table", PydanticUndefined) - if kw_table is not PydanticUndefined: + kw_table = class_kwargs.get("table", Undefined) + if kw_table is not Undefined: return kw_table # type: ignore return False @@ -225,12 +224,12 @@ def set_empty_defaults(annotations: Dict[str, Any], class_dict: Dict[str, Any]) # Pydantic v2 sets a __pydantic_core_schema__ which is very hard to change. Changing the fields does not do anything for key in annotations.keys(): - value = class_dict.get(key, PydanticUndefined) - if value is PydanticUndefined: + value = class_dict.get(key, Undefined) + if value is Undefined: class_dict[key] = None elif isinstance(value, FieldInfo): if ( - value.default in (PydanticUndefined, Ellipsis) + value.default in (Undefined, Ellipsis) ) and value.default_factory is None: # So we can check for nullable value.default = None @@ -238,10 +237,10 @@ def set_empty_defaults(annotations: Dict[str, Any], class_dict: Dict[str, Any]) def _is_field_noneable(field: "FieldInfo") -> bool: if IS_PYDANTIC_V2: - if getattr(field, "nullable", PydanticUndefined) is not PydanticUndefined: + if getattr(field, "nullable", Undefined) is not Undefined: return field.nullable # type: ignore if not field.is_required(): - if field.default is PydanticUndefined: + if field.default is Undefined: return False if field.annotation is None or field.annotation is NoneType: return True @@ -265,8 +264,8 @@ def get_sqlalchemy_type(field: Any) -> Any: field_info = field else: field_info = field.field_info - sa_type = getattr(field_info, "sa_type", PydanticUndefined) # noqa: B009 - if sa_type is not PydanticUndefined: + sa_type = getattr(field_info, "sa_type", Undefined) # noqa: B009 + if sa_type is not Undefined: return sa_type type_ = get_type_from_field(field) @@ -370,29 +369,29 @@ def get_column_from_field(field: Any) -> Column: # type: ignore field_info = field else: field_info = field.field_info - sa_column = getattr(field_info, "sa_column", PydanticUndefined) + sa_column = getattr(field_info, "sa_column", Undefined) if isinstance(sa_column, Column): return sa_column sa_type = get_sqlalchemy_type(field) - primary_key = getattr(field_info, "primary_key", PydanticUndefined) - if primary_key is PydanticUndefined: + primary_key = getattr(field_info, "primary_key", Undefined) + if primary_key is Undefined: primary_key = False - index = getattr(field_info, "index", PydanticUndefined) - if index is PydanticUndefined: + index = getattr(field_info, "index", Undefined) + if index is Undefined: index = False nullable = not primary_key and _is_field_noneable(field) # Override derived nullability if the nullable property is set explicitly # on the field - field_nullable = getattr(field_info, "nullable", PydanticUndefined) # noqa: B009 - if field_nullable is not PydanticUndefined: - assert not isinstance(field_nullable, PydanticUndefinedType) + field_nullable = getattr(field_info, "nullable", Undefined) # noqa: B009 + if field_nullable is not Undefined: + assert not isinstance(field_nullable, UndefinedType) nullable = field_nullable args = [] - foreign_key = getattr(field_info, "foreign_key", PydanticUndefined) - if foreign_key is PydanticUndefined: + foreign_key = getattr(field_info, "foreign_key", Undefined) + if foreign_key is Undefined: foreign_key = None - unique = getattr(field_info, "unique", PydanticUndefined) - if unique is PydanticUndefined: + unique = getattr(field_info, "unique", Undefined) + if unique is Undefined: unique = False if foreign_key: assert isinstance(foreign_key, str) @@ -403,17 +402,17 @@ def get_column_from_field(field: Any) -> Column: # type: ignore "index": index, "unique": unique, } - sa_default = PydanticUndefined + sa_default = Undefined if field_info.default_factory: sa_default = field_info.default_factory - elif field_info.default is not PydanticUndefined: + elif field_info.default is not Undefined: sa_default = field_info.default - if sa_default is not PydanticUndefined: + if sa_default is not Undefined: kwargs["default"] = sa_default - sa_column_args = getattr(field_info, "sa_column_args", PydanticUndefined) - if sa_column_args is not PydanticUndefined: + sa_column_args = getattr(field_info, "sa_column_args", Undefined) + if sa_column_args is not Undefined: args.extend(list(cast(Sequence[Any], sa_column_args))) - sa_column_kwargs = getattr(field_info, "sa_column_kwargs", PydanticUndefined) - if sa_column_kwargs is not PydanticUndefined: + sa_column_kwargs = getattr(field_info, "sa_column_kwargs", Undefined) + if sa_column_kwargs is not Undefined: kwargs.update(cast(Dict[Any, Any], sa_column_kwargs)) return Column(sa_type, *args, **kwargs) # type: ignore From d9bafa34367f2a98c93313dbe66cacb1da63648d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sebasti=C3=A1n=20Ram=C3=ADrez?= Date: Fri, 24 Nov 2023 09:31:55 +0100 Subject: [PATCH 038/105] =?UTF-8?q?=F0=9F=90=9B=20Fix=20type=20of=20model.?= =?UTF-8?q?=5F=5Ffields=5F=5F=20in=20Pydantic=20v1=20and=20add=20compatibi?= =?UTF-8?q?lity=20class=20for=20Pydantic=20v2?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- sqlmodel/compat.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/sqlmodel/compat.py b/sqlmodel/compat.py index 9d2a097ed6..4f4c71aa77 100644 --- a/sqlmodel/compat.py +++ b/sqlmodel/compat.py @@ -46,6 +46,10 @@ from pydantic._internal._model_construction import ModelMetaclass from pydantic_core import PydanticUndefined as Undefined # noqa from pydantic_core import PydanticUndefinedType as UndefinedType + + # Dummy for types, to make it importable + class ModelField: + pass else: from pydantic import BaseConfig as PydanticModelConfig from pydantic.fields import SHAPE_SINGLETON, ModelField @@ -75,7 +79,6 @@ class SQLModelConfig(PydanticModelConfig): registry: Optional[Any] = None -# Inspired from https://github.com/roman-right/beanie/blob/main/beanie/odm/utils/pydantic.py def get_model_config(model: type) -> Optional[SQLModelConfig]: if IS_PYDANTIC_V2: return getattr(model, "model_config", None) @@ -416,3 +419,8 @@ def get_column_from_field(field: Any) -> Column: # type: ignore if sa_column_kwargs is not Undefined: kwargs.update(cast(Dict[Any, Any], sa_column_kwargs)) return Column(sa_type, *args, **kwargs) # type: ignore + + +def post_init_field_info(field_info: FieldInfo) -> None: + if not IS_PYDANTIC_V2: + field_info._validate() From 6e29b228c342a0e13edc44667e4d58d815afa0d2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sebasti=C3=A1n=20Ram=C3=ADrez?= Date: Fri, 24 Nov 2023 09:34:33 +0100 Subject: [PATCH 039/105] =?UTF-8?q?=F0=9F=90=9B=20Use=20compat=20ModelFiel?= =?UTF-8?q?d=20for=20type=20annotation=20in=20model.=5F=5Ffields=5F=5F=20f?= =?UTF-8?q?or=20Pydantic=20v1?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- sqlmodel/main.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sqlmodel/main.py b/sqlmodel/main.py index fe5272f4d9..af84be7b73 100644 --- a/sqlmodel/main.py +++ b/sqlmodel/main.py @@ -373,7 +373,7 @@ class SQLModelMetaclass(ModelMetaclass, DeclarativeMeta): model_fields: Dict[str, FieldInfo] else: __config__: Type[SQLModelConfig] - __fields__: Dict[str, FieldInfo] + __fields__: Dict[str, ModelField] # Replicate SQLAlchemy def __setattr__(cls, name: str, value: Any) -> None: From 835cbc8cb41cdd7668cdd7594b97101ecfeb43ed Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sebasti=C3=A1n=20Ram=C3=ADrez?= Date: Fri, 24 Nov 2023 09:35:26 +0100 Subject: [PATCH 040/105] =?UTF-8?q?=F0=9F=90=9B=20Fix=20removed=20Pydantic?= =?UTF-8?q?=20v1=20field=20post=20init?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- sqlmodel/main.py | 1 + 1 file changed, 1 insertion(+) diff --git a/sqlmodel/main.py b/sqlmodel/main.py index af84be7b73..3e8968fda1 100644 --- a/sqlmodel/main.py +++ b/sqlmodel/main.py @@ -323,6 +323,7 @@ def Field( sa_column_kwargs=sa_column_kwargs, **current_schema_extra, ) + post_init_field_info(field_info) return field_info From dbaef533dc66618d2b4d8ca0426907ebd750d04c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sebasti=C3=A1n=20Ram=C3=ADrez?= Date: Sat, 25 Nov 2023 11:25:39 +0100 Subject: [PATCH 041/105] =?UTF-8?q?=F0=9F=9A=9A=20Move=20get=5Fsqlalchemy?= =?UTF-8?q?=5Ftype=20and=20get=5Fcolumn=5Ffrom=5Ffield=20back=20to=20main?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- sqlmodel/compat.py | 110 --------------------------------------------- sqlmodel/main.py | 110 +++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 110 insertions(+), 110 deletions(-) diff --git a/sqlmodel/compat.py b/sqlmodel/compat.py index 4f4c71aa77..03a9cfccd1 100644 --- a/sqlmodel/compat.py +++ b/sqlmodel/compat.py @@ -262,62 +262,6 @@ def _is_field_noneable(field: "FieldInfo") -> bool: return field.allow_none -def get_sqlalchemy_type(field: Any) -> Any: - if IS_PYDANTIC_V2: - field_info = field - else: - field_info = field.field_info - sa_type = getattr(field_info, "sa_type", Undefined) # noqa: B009 - if sa_type is not Undefined: - return sa_type - - type_ = get_type_from_field(field) - metadata = get_field_metadata(field) - - # Check enums first as an enum can also be a str, needed by Pydantic/FastAPI - if issubclass(type_, Enum): - return sa_Enum(type_) - if issubclass(type_, str): - max_length = getattr(metadata, "max_length", None) - if max_length: - return AutoString(length=max_length) - return AutoString - if issubclass(type_, float): - return Float - if issubclass(type_, bool): - return Boolean - if issubclass(type_, int): - return Integer - if issubclass(type_, datetime): - return DateTime - if issubclass(type_, date): - return Date - if issubclass(type_, timedelta): - return Interval - if issubclass(type_, time): - return Time - if issubclass(type_, bytes): - return LargeBinary - if issubclass(type_, Decimal): - return Numeric( - precision=getattr(metadata, "max_digits", None), - scale=getattr(metadata, "decimal_places", None), - ) - if issubclass(type_, ipaddress.IPv4Address): - return AutoString - if issubclass(type_, ipaddress.IPv4Network): - return AutoString - if issubclass(type_, ipaddress.IPv6Address): - return AutoString - if issubclass(type_, ipaddress.IPv6Network): - return AutoString - if issubclass(type_, Path): - return AutoString - if issubclass(type_, uuid.UUID): - return GUID - raise ValueError(f"{type_} has no matching SQLAlchemy type") - - def get_type_from_field(field: Any) -> type: if IS_PYDANTIC_V2: type_: type | None = field.annotation @@ -367,60 +311,6 @@ def get_field_metadata(field: Any) -> Any: return metadata -def get_column_from_field(field: Any) -> Column: # type: ignore - if IS_PYDANTIC_V2: - field_info = field - else: - field_info = field.field_info - sa_column = getattr(field_info, "sa_column", Undefined) - if isinstance(sa_column, Column): - return sa_column - sa_type = get_sqlalchemy_type(field) - primary_key = getattr(field_info, "primary_key", Undefined) - if primary_key is Undefined: - primary_key = False - index = getattr(field_info, "index", Undefined) - if index is Undefined: - index = False - nullable = not primary_key and _is_field_noneable(field) - # Override derived nullability if the nullable property is set explicitly - # on the field - field_nullable = getattr(field_info, "nullable", Undefined) # noqa: B009 - if field_nullable is not Undefined: - assert not isinstance(field_nullable, UndefinedType) - nullable = field_nullable - args = [] - foreign_key = getattr(field_info, "foreign_key", Undefined) - if foreign_key is Undefined: - foreign_key = None - unique = getattr(field_info, "unique", Undefined) - if unique is Undefined: - unique = False - if foreign_key: - assert isinstance(foreign_key, str) - args.append(ForeignKey(foreign_key)) - kwargs = { - "primary_key": primary_key, - "nullable": nullable, - "index": index, - "unique": unique, - } - sa_default = Undefined - if field_info.default_factory: - sa_default = field_info.default_factory - elif field_info.default is not Undefined: - sa_default = field_info.default - if sa_default is not Undefined: - kwargs["default"] = sa_default - sa_column_args = getattr(field_info, "sa_column_args", Undefined) - if sa_column_args is not Undefined: - args.extend(list(cast(Sequence[Any], sa_column_args))) - sa_column_kwargs = getattr(field_info, "sa_column_kwargs", Undefined) - if sa_column_kwargs is not Undefined: - kwargs.update(cast(Dict[Any, Any], sa_column_kwargs)) - return Column(sa_type, *args, **kwargs) # type: ignore - - def post_init_field_info(field_info: FieldInfo) -> None: if not IS_PYDANTIC_V2: field_info._validate() diff --git a/sqlmodel/main.py b/sqlmodel/main.py index 3e8968fda1..bc4e82e079 100644 --- a/sqlmodel/main.py +++ b/sqlmodel/main.py @@ -548,6 +548,116 @@ def __init__( ModelMetaclass.__init__(cls, classname, bases, dict_, **kw) +def get_sqlalchemy_type(field: Any) -> Any: + if IS_PYDANTIC_V2: + field_info = field + else: + field_info = field.field_info + sa_type = getattr(field_info, "sa_type", Undefined) # noqa: B009 + if sa_type is not Undefined: + return sa_type + + type_ = get_type_from_field(field) + metadata = get_field_metadata(field) + + # Check enums first as an enum can also be a str, needed by Pydantic/FastAPI + if issubclass(type_, Enum): + return sa_Enum(type_) + if issubclass(type_, str): + max_length = getattr(metadata, "max_length", None) + if max_length: + return AutoString(length=max_length) + return AutoString + if issubclass(type_, float): + return Float + if issubclass(type_, bool): + return Boolean + if issubclass(type_, int): + return Integer + if issubclass(type_, datetime): + return DateTime + if issubclass(type_, date): + return Date + if issubclass(type_, timedelta): + return Interval + if issubclass(type_, time): + return Time + if issubclass(type_, bytes): + return LargeBinary + if issubclass(type_, Decimal): + return Numeric( + precision=getattr(metadata, "max_digits", None), + scale=getattr(metadata, "decimal_places", None), + ) + if issubclass(type_, ipaddress.IPv4Address): + return AutoString + if issubclass(type_, ipaddress.IPv4Network): + return AutoString + if issubclass(type_, ipaddress.IPv6Address): + return AutoString + if issubclass(type_, ipaddress.IPv6Network): + return AutoString + if issubclass(type_, Path): + return AutoString + if issubclass(type_, uuid.UUID): + return GUID + raise ValueError(f"{type_} has no matching SQLAlchemy type") + + +def get_column_from_field(field: Any) -> Column: # type: ignore + if IS_PYDANTIC_V2: + field_info = field + else: + field_info = field.field_info + sa_column = getattr(field_info, "sa_column", Undefined) + if isinstance(sa_column, Column): + return sa_column + sa_type = get_sqlalchemy_type(field) + primary_key = getattr(field_info, "primary_key", Undefined) + if primary_key is Undefined: + primary_key = False + index = getattr(field_info, "index", Undefined) + if index is Undefined: + index = False + nullable = not primary_key and _is_field_noneable(field) + # Override derived nullability if the nullable property is set explicitly + # on the field + field_nullable = getattr(field_info, "nullable", Undefined) # noqa: B009 + if field_nullable is not Undefined: + assert not isinstance(field_nullable, UndefinedType) + nullable = field_nullable + args = [] + foreign_key = getattr(field_info, "foreign_key", Undefined) + if foreign_key is Undefined: + foreign_key = None + unique = getattr(field_info, "unique", Undefined) + if unique is Undefined: + unique = False + if foreign_key: + assert isinstance(foreign_key, str) + args.append(ForeignKey(foreign_key)) + kwargs = { + "primary_key": primary_key, + "nullable": nullable, + "index": index, + "unique": unique, + } + sa_default = Undefined + if field_info.default_factory: + sa_default = field_info.default_factory + elif field_info.default is not Undefined: + sa_default = field_info.default + if sa_default is not Undefined: + kwargs["default"] = sa_default + sa_column_args = getattr(field_info, "sa_column_args", Undefined) + if sa_column_args is not Undefined: + args.extend(list(cast(Sequence[Any], sa_column_args))) + sa_column_kwargs = getattr(field_info, "sa_column_kwargs", Undefined) + if sa_column_kwargs is not Undefined: + kwargs.update(cast(Dict[Any, Any], sa_column_kwargs)) + return Column(sa_type, *args, **kwargs) # type: ignore + + class_registry = weakref.WeakValueDictionary() # type: ignore default_registry = registry() From dcbc08428e78cd29c5c157eeed08c9b90fe4c488 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sebasti=C3=A1n=20Ram=C3=ADrez?= Date: Sat, 25 Nov 2023 11:46:58 +0100 Subject: [PATCH 042/105] =?UTF-8?q?=E2=99=BB=EF=B8=8F=20Refactor=20paramet?= =?UTF-8?q?ers=20for=20set=5Fconfig=5Fvalue?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- sqlmodel/compat.py | 4 ++-- sqlmodel/main.py | 9 ++++++--- 2 files changed, 8 insertions(+), 5 deletions(-) diff --git a/sqlmodel/compat.py b/sqlmodel/compat.py index 03a9cfccd1..27f73205fe 100644 --- a/sqlmodel/compat.py +++ b/sqlmodel/compat.py @@ -96,15 +96,15 @@ def get_config_value( def set_config_value( + *, model: InstanceOrType["SQLModel"], parameter: str, value: Any, - v1_parameter: Optional[str] = None, ) -> None: if IS_PYDANTIC_V2: model.model_config[parameter] = value # type: ignore else: - setattr(model.__config__, v1_parameter or parameter, value) # type: ignore + setattr(model.__config__, parameter, value) # type: ignore def get_model_fields(model: InstanceOrType["SQLModel"]) -> Dict[str, "FieldInfo"]: diff --git a/sqlmodel/main.py b/sqlmodel/main.py index bc4e82e079..e2a7f5623c 100644 --- a/sqlmodel/main.py +++ b/sqlmodel/main.py @@ -455,7 +455,7 @@ def get_config(name: str) -> Any: config_table = get_config("table") if config_table is True: # If it was passed by kwargs, ensure it's also set in config - set_config_value(new_cls, "table", config_table) + set_config_value(model=new_cls, parameter="table", value=config_table) for k, v in get_model_fields(new_cls).items(): col = get_column_from_field(v) setattr(new_cls, k, col) @@ -465,14 +465,17 @@ def get_config(name: str) -> Any: # that's very specific about SQLModel, so let's have another config that # other future tools based on Pydantic can use. set_config_value( - new_cls, "read_from_attributes", True, v1_parameter="read_with_orm_mode" + model=new_cls, parameter="read_from_attributes", value=True ) + # For compatibility with older versions + # TODO: remove this in the future + set_config_value(model=new_cls, parameter="read_with_orm_mode", value=True) config_registry = get_config("registry") if config_registry is not Undefined: config_registry = cast(registry, config_registry) # If it was passed by kwargs, ensure it's also set in config - set_config_value(new_cls, "registry", config_table) + set_config_value(model=new_cls, parameter="registry", value=config_table) setattr(new_cls, "_sa_registry", config_registry) # noqa: B010 setattr(new_cls, "metadata", config_registry.metadata) # noqa: B010 setattr(new_cls, "__abstract__", True) # noqa: B010 From c836ec1fb9f6820c7934f4854d4f9a8dabe767f0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sebasti=C3=A1n=20Ram=C3=ADrez?= Date: Sat, 25 Nov 2023 11:48:12 +0100 Subject: [PATCH 043/105] =?UTF-8?q?=E2=99=BB=EF=B8=8F=20Update=20calls=20f?= =?UTF-8?q?or=20get=5Frelationship=5Fto=20with=20keyword=20args,=20remove?= =?UTF-8?q?=20done=20related=20TODOs?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- sqlmodel/main.py | 20 +++----------------- 1 file changed, 3 insertions(+), 17 deletions(-) diff --git a/sqlmodel/main.py b/sqlmodel/main.py index e2a7f5623c..0178d09acf 100644 --- a/sqlmodel/main.py +++ b/sqlmodel/main.py @@ -497,11 +497,6 @@ def __init__( # over anything else, use that and continue with the next attribute setattr(cls, rel_name, rel_info.sa_relationship) # Fix #315 continue - # TODO: remove this - # From PR - # ann = cls.__annotations__[rel_name] - # relationship_to = get_relationship_to(rel_name, rel_info, ann) - # From main, modified with PR code raw_ann = cls.__annotations__[rel_name] origin = get_origin(raw_ann) if origin is Mapped: @@ -512,18 +507,9 @@ def __init__( # handled well by SQLAlchemy without Mapped, so, wrap the # annotations in Mapped here cls.__annotations__[rel_name] = Mapped[ann] # type: ignore[valid-type] - # TODO: remove this, moved to get_relationship_to - # temp_field = ModelField.infer( - # name=rel_name, - # value=rel_info, - # annotation=ann, - # class_validators=None, - # config=BaseConfig, - # ) - # relationship_to = temp_field.type_ - # if isinstance(temp_field.type_, ForwardRef): - # relationship_to = temp_field.type_.__forward_arg__ - relationship_to = get_relationship_to(rel_name, rel_info, ann) + relationship_to = get_relationship_to( + name=rel_name, rel_info=rel_info, annotation=ann + ) rel_kwargs: Dict[str, Any] = {} if rel_info.back_populates: rel_kwargs["back_populates"] = rel_info.back_populates From 3c6b7a9f1f107f386bb91acf7cd3a6df68382cc4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sebasti=C3=A1n=20Ram=C3=ADrez?= Date: Sat, 25 Nov 2023 11:49:32 +0100 Subject: [PATCH 044/105] =?UTF-8?q?=E2=99=BB=EF=B8=8F=20Revert=20import=20?= =?UTF-8?q?renaming=20of=20BaseConfig,=20to=20reduce=20the=20diff?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- sqlmodel/compat.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sqlmodel/compat.py b/sqlmodel/compat.py index 27f73205fe..301c7fdb32 100644 --- a/sqlmodel/compat.py +++ b/sqlmodel/compat.py @@ -41,7 +41,7 @@ IS_PYDANTIC_V2 = PYDANTIC_VERSION.startswith("2.") if IS_PYDANTIC_V2: - from pydantic import ConfigDict as PydanticModelConfig + from pydantic import ConfigDict as BaseConfig from pydantic._internal._fields import PydanticMetadata from pydantic._internal._model_construction import ModelMetaclass from pydantic_core import PydanticUndefined as Undefined # noqa @@ -51,7 +51,7 @@ class ModelField: pass else: - from pydantic import BaseConfig as PydanticModelConfig + from pydantic import BaseConfig as BaseConfig from pydantic.fields import SHAPE_SINGLETON, ModelField from pydantic.fields import Undefined as Undefined # noqa from pydantic.fields import UndefinedType as UndefinedType From c7390e15505ca2b5092422e27ee50d2a7cc7749f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sebasti=C3=A1n=20Ram=C3=ADrez?= Date: Sat, 25 Nov 2023 11:49:59 +0100 Subject: [PATCH 045/105] =?UTF-8?q?=E2=9C=8F=EF=B8=8F=20Fix=20typos=20in?= =?UTF-8?q?=20warning=20texts?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- sqlmodel/compat.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sqlmodel/compat.py b/sqlmodel/compat.py index 301c7fdb32..27041f7dd3 100644 --- a/sqlmodel/compat.py +++ b/sqlmodel/compat.py @@ -275,12 +275,12 @@ def get_type_from_field(field: Any) -> type: bases = get_args(type_) if len(bases) > 2: raise ValueError( - "Cannot have a (non-optional) union as a SQL alchemy field" + "Cannot have a (non-optional) union as a SQLAlchemy field" ) # Non optional unions are not allowed if bases[0] is not NoneType and bases[1] is not NoneType: raise ValueError( - "Cannot have a (non-optional) union as a SQL alchemy field" + "Cannot have a (non-optional) union as a SQLlchemy field" ) # Optional unions are allowed return bases[0] if bases[0] is not NoneType else bases[1] From 3e45276ccdc03991cd303ef609b991b58e95ab66 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sebasti=C3=A1n=20Ram=C3=ADrez?= Date: Sat, 25 Nov 2023 11:52:28 +0100 Subject: [PATCH 046/105] =?UTF-8?q?=F0=9F=94=A5=20Remove=20unused=20import?= =?UTF-8?q?s?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- sqlmodel/compat.py | 23 ----------------------- 1 file changed, 23 deletions(-) diff --git a/sqlmodel/compat.py b/sqlmodel/compat.py index 27041f7dd3..9ac9b809ef 100644 --- a/sqlmodel/compat.py +++ b/sqlmodel/compat.py @@ -1,9 +1,3 @@ -import ipaddress -import uuid -from datetime import date, datetime, time, timedelta -from decimal import Decimal -from enum import Enum -from pathlib import Path from types import NoneType from typing import ( TYPE_CHECKING, @@ -12,31 +6,14 @@ Dict, ForwardRef, Optional, - Sequence, Type, TypeVar, Union, - cast, get_args, get_origin, ) from pydantic import VERSION as PYDANTIC_VERSION -from sqlalchemy import ( - Boolean, - Column, - Date, - DateTime, - Float, - ForeignKey, - Integer, - Interval, - Numeric, -) -from sqlalchemy import Enum as sa_Enum -from sqlalchemy.sql.sqltypes import LargeBinary, Time - -from .sql.sqltypes import GUID, AutoString IS_PYDANTIC_V2 = PYDANTIC_VERSION.startswith("2.") From eadf5bcac5b1ad9ee2bdceb2463e9d82fad48519 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sebasti=C3=A1n=20Ram=C3=ADrez?= Date: Sat, 25 Nov 2023 12:01:55 +0100 Subject: [PATCH 047/105] =?UTF-8?q?=F0=9F=94=A5=20Remove=20unused=20utils?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- sqlmodel/compat.py | 38 -------------------------------------- 1 file changed, 38 deletions(-) diff --git a/sqlmodel/compat.py b/sqlmodel/compat.py index 9ac9b809ef..49ab36b3ca 100644 --- a/sqlmodel/compat.py +++ b/sqlmodel/compat.py @@ -56,13 +56,6 @@ class SQLModelConfig(PydanticModelConfig): registry: Optional[Any] = None -def get_model_config(model: type) -> Optional[SQLModelConfig]: - if IS_PYDANTIC_V2: - return getattr(model, "model_config", None) - else: - return getattr(model, "__config__", None) - - def get_config_value( model: InstanceOrType["SQLModel"], parameter: str, default: Any = None ) -> Any: @@ -91,13 +84,6 @@ def get_model_fields(model: InstanceOrType["SQLModel"]) -> Dict[str, "FieldInfo" return model.__fields__ # type: ignore -def get_fields_set(model: InstanceOrType["SQLModel"]) -> set[str]: - if IS_PYDANTIC_V2: - return model.__pydantic_fields_set__ # type: ignore - else: - return model.__fields_set__ # type: ignore - - def set_fields_set( new_object: InstanceOrType["SQLModel"], fields: set["FieldInfo"] ) -> None: @@ -107,13 +93,6 @@ def set_fields_set( object.__setattr__(new_object, "__fields_set__", fields) -def set_attribute_mode(cls: Type["SQLModelMetaclass"]) -> None: - if IS_PYDANTIC_V2: - cls.model_config["read_from_attributes"] = True - else: - cls.__config__.read_with_orm_mode = True # type: ignore - - def get_annotations(class_dict: dict[str, Any]) -> dict[str, Any]: if IS_PYDANTIC_V2: return class_dict.get("__annotations__", {}) @@ -123,23 +102,6 @@ def get_annotations(class_dict: dict[str, Any]) -> dict[str, Any]: ) -def class_dict_is_table( - class_dict: dict[str, Any], class_kwargs: dict[str, Any] -) -> bool: - config: SQLModelConfig = {} - if IS_PYDANTIC_V2: - config = class_dict.get("model_config", {}) - else: - config = class_dict.get("__config__", {}) - config_table = config.get("table", Undefined) - if config_table is not Undefined: - return config_table # type: ignore - kw_table = class_kwargs.get("table", Undefined) - if kw_table is not Undefined: - return kw_table # type: ignore - return False - - def cls_is_table(cls: Type) -> bool: if IS_PYDANTIC_V2: config = getattr(cls, "model_config", None) From d35989af38bd93539967b5008a3e2fa373c84e18 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sebasti=C3=A1n=20Ram=C3=ADrez?= Date: Sat, 25 Nov 2023 12:02:23 +0100 Subject: [PATCH 048/105] =?UTF-8?q?=F0=9F=9A=9A=20Move=20NoArgAnyCallable?= =?UTF-8?q?=20type=20to=20main?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- sqlmodel/compat.py | 1 - sqlmodel/main.py | 1 + 2 files changed, 1 insertion(+), 1 deletion(-) diff --git a/sqlmodel/compat.py b/sqlmodel/compat.py index 49ab36b3ca..4af1e79797 100644 --- a/sqlmodel/compat.py +++ b/sqlmodel/compat.py @@ -39,7 +39,6 @@ class ModelField: from .main import FieldInfo, RelationshipInfo, SQLModel, SQLModelMetaclass -NoArgAnyCallable = Callable[[], Any] T = TypeVar("T") InstanceOrType = Union[T, Type[T]] diff --git a/sqlmodel/main.py b/sqlmodel/main.py index 0178d09acf..850639b9fd 100644 --- a/sqlmodel/main.py +++ b/sqlmodel/main.py @@ -64,6 +64,7 @@ from pydantic.utils import ROOT_KEY _T = TypeVar("_T") +NoArgAnyCallable = Callable[[], Any] def __dataclass_transform__( From ef8bbc63c25547dcda736e2ad0f86b993dee9452 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sebasti=C3=A1n=20Ram=C3=ADrez?= Date: Sat, 25 Nov 2023 12:03:29 +0100 Subject: [PATCH 049/105] =?UTF-8?q?=E2=99=BB=EF=B8=8F=20Refactor=20signatu?= =?UTF-8?q?re=20of=20get=5Fconfig=5Fvalue,=20require=20all=20keyword=20arg?= =?UTF-8?q?uments?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- sqlmodel/compat.py | 2 +- sqlmodel/main.py | 12 ++++++++---- 2 files changed, 9 insertions(+), 5 deletions(-) diff --git a/sqlmodel/compat.py b/sqlmodel/compat.py index 4af1e79797..365a0067f4 100644 --- a/sqlmodel/compat.py +++ b/sqlmodel/compat.py @@ -56,7 +56,7 @@ class SQLModelConfig(PydanticModelConfig): def get_config_value( - model: InstanceOrType["SQLModel"], parameter: str, default: Any = None + *, model: InstanceOrType["SQLModel"], parameter: str, default: Any = None ) -> Any: if IS_PYDANTIC_V2: return model.model_config.get(parameter, default) diff --git a/sqlmodel/main.py b/sqlmodel/main.py index 850639b9fd..95a8b0e535 100644 --- a/sqlmodel/main.py +++ b/sqlmodel/main.py @@ -379,13 +379,13 @@ class SQLModelMetaclass(ModelMetaclass, DeclarativeMeta): # Replicate SQLAlchemy def __setattr__(cls, name: str, value: Any) -> None: - if get_config_value(cls, "table", False): + if get_config_value(model=cls, parameter="table", default=False): DeclarativeMeta.__setattr__(cls, name, value) else: super().__setattr__(name, value) def __delattr__(cls, name: str) -> None: - if get_config_value(cls, "table", False): + if get_config_value(model=cls, parameter="table", default=False): DeclarativeMeta.__delattr__(cls, name) else: super().__delattr__(name) @@ -445,7 +445,9 @@ def __new__( } def get_config(name: str) -> Any: - config_class_value = get_config_value(new_cls, name, Undefined) + config_class_value = get_config_value( + model=new_cls, parameter=name, default=Undefined + ) if config_class_value is not Undefined: return config_class_value kwarg_value = kwargs.get(name, Undefined) @@ -724,7 +726,9 @@ def __setattr__(self, name: str, value: Any) -> None: return else: # Set in SQLAlchemy, before Pydantic to trigger events and updates - if get_config_value(self, "table", False) and is_instrumented(self, name): + if get_config_value( + model=self, parameter="table", default=False + ) and is_instrumented(self, name): set_attribute(self, name, value) # Set in Pydantic model to trigger possible validation changes, only for # non relationship values From 9ba2e5904f4dd5331b2c950beb6af3d6ed839cf2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sebasti=C3=A1n=20Ram=C3=ADrez?= Date: Sat, 25 Nov 2023 13:25:06 +0100 Subject: [PATCH 050/105] =?UTF-8?q?=F0=9F=94=A5=20Remove=20unused=20import?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- sqlmodel/compat.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sqlmodel/compat.py b/sqlmodel/compat.py index 365a0067f4..cc7b4415c6 100644 --- a/sqlmodel/compat.py +++ b/sqlmodel/compat.py @@ -36,7 +36,7 @@ class ModelField: from pydantic.typing import resolve_annotations if TYPE_CHECKING: - from .main import FieldInfo, RelationshipInfo, SQLModel, SQLModelMetaclass + from .main import FieldInfo, RelationshipInfo, SQLModel T = TypeVar("T") From 63a55a8353c2db6fae76cadb16300d7e84835a11 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sebasti=C3=A1n=20Ram=C3=ADrez?= Date: Sat, 25 Nov 2023 13:25:25 +0100 Subject: [PATCH 051/105] =?UTF-8?q?=F0=9F=94=A5=20Remove=20unused=20type?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- sqlmodel/compat.py | 1 - 1 file changed, 1 deletion(-) diff --git a/sqlmodel/compat.py b/sqlmodel/compat.py index cc7b4415c6..6b52b8be4c 100644 --- a/sqlmodel/compat.py +++ b/sqlmodel/compat.py @@ -2,7 +2,6 @@ from typing import ( TYPE_CHECKING, Any, - Callable, Dict, ForwardRef, Optional, From f754390d80f87871205cdf316d9a30d2f8a2cc99 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sebasti=C3=A1n=20Ram=C3=ADrez?= Date: Sat, 25 Nov 2023 13:26:13 +0100 Subject: [PATCH 052/105] =?UTF-8?q?=E2=8F=AA=EF=B8=8F=20Revert=20removing?= =?UTF-8?q?=20class=5Fdict=5Fis=5Ftable=20until=20it's=20clear=20if=20it's?= =?UTF-8?q?=20needed=20or=20not?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- sqlmodel/compat.py | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/sqlmodel/compat.py b/sqlmodel/compat.py index 6b52b8be4c..0db4b2b037 100644 --- a/sqlmodel/compat.py +++ b/sqlmodel/compat.py @@ -99,6 +99,23 @@ def get_annotations(class_dict: dict[str, Any]) -> dict[str, Any]: class_dict.get("__annotations__", {}), class_dict.get("__module__", None) ) +# TODO: review if this is necessary +def class_dict_is_table( + class_dict: dict[str, Any], class_kwargs: dict[str, Any] +) -> bool: + config: SQLModelConfig = {} + if IS_PYDANTIC_V2: + config = class_dict.get("model_config", {}) + else: + config = class_dict.get("__config__", {}) + config_table = config.get("table", Undefined) + if config_table is not Undefined: + return config_table # type: ignore + kw_table = class_kwargs.get("table", Undefined) + if kw_table is not Undefined: + return kw_table # type: ignore + return False + def cls_is_table(cls: Type) -> bool: if IS_PYDANTIC_V2: From 4a3fea803fa7876c9c4bf76e3d40ccfc07490b77 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sebasti=C3=A1n=20Ram=C3=ADrez?= Date: Sat, 25 Nov 2023 13:27:07 +0100 Subject: [PATCH 053/105] =?UTF-8?q?=F0=9F=92=A1=20Add=20comments=20for=20n?= =?UTF-8?q?ext=20tasks?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- sqlmodel/compat.py | 1 + sqlmodel/main.py | 4 ++++ 2 files changed, 5 insertions(+) diff --git a/sqlmodel/compat.py b/sqlmodel/compat.py index 0db4b2b037..c909c2d06b 100644 --- a/sqlmodel/compat.py +++ b/sqlmodel/compat.py @@ -99,6 +99,7 @@ def get_annotations(class_dict: dict[str, Any]) -> dict[str, Any]: class_dict.get("__annotations__", {}), class_dict.get("__module__", None) ) + # TODO: review if this is necessary def class_dict_is_table( class_dict: dict[str, Any], class_kwargs: dict[str, Any] diff --git a/sqlmodel/main.py b/sqlmodel/main.py index 95a8b0e535..12dc862424 100644 --- a/sqlmodel/main.py +++ b/sqlmodel/main.py @@ -432,6 +432,7 @@ def __new__( config_kwargs = { key: kwargs[key] for key in kwargs.keys() & allowed_config_kwargs } + # TODO: review if this is necessary if class_dict_is_table(class_dict, kwargs): set_empty_defaults(pydantic_annotations, dict_used) @@ -693,6 +694,7 @@ def __new__(cls, *args: Any, **kwargs: Any) -> Any: def __init__(__pydantic_self__, **data: Any) -> None: # Uses something other than `self` the first arg to allow "self" as a # settable attribute + # TODO: review how this works and check defaults set in metaclass __new__ if IS_PYDANTIC_V2: old_dict = __pydantic_self__.__dict__.copy() super().__init__(**data) # noqa @@ -747,6 +749,8 @@ def __repr_args__(self) -> Sequence[Tuple[Optional[str], Any]]: def __tablename__(cls) -> str: return cls.__name__.lower() + # TODO: refactor this and make each method available in both Pydantic v1 and v2 + # add deprecations, re-use methods from backwards compatibility parts, etc. if IS_PYDANTIC_V2: @classmethod From b76366806e4cc54e333a73aad99258c974788a5a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sebasti=C3=A1n=20Ram=C3=ADrez?= Date: Sat, 25 Nov 2023 13:30:59 +0100 Subject: [PATCH 054/105] =?UTF-8?q?=E2=99=BB=EF=B8=8F=20Refactor=20imports?= =?UTF-8?q?=20in=20main?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- sqlmodel/main.py | 24 ++++++++++++++++++++++-- 1 file changed, 22 insertions(+), 2 deletions(-) diff --git a/sqlmodel/main.py b/sqlmodel/main.py index 12dc862424..f8ef913601 100644 --- a/sqlmodel/main.py +++ b/sqlmodel/main.py @@ -1,4 +1,10 @@ +import ipaddress +import uuid import weakref +from datetime import date, datetime, time, timedelta +from decimal import Decimal +from enum import Enum +from pathlib import Path from typing import ( AbstractSet, Any, @@ -22,9 +28,18 @@ from pydantic.fields import FieldInfo as PydanticFieldInfo from pydantic.utils import Representation from sqlalchemy import ( + Boolean, Column, + Date, + DateTime, + Float, + ForeignKey, + Integer, + Interval, + Numeric, inspect, ) +from sqlalchemy import Enum as sa_Enum from sqlalchemy.orm import ( Mapped, RelationshipProperty, @@ -36,27 +51,32 @@ from sqlalchemy.orm.decl_api import DeclarativeMeta from sqlalchemy.orm.instrumentation import is_instrumented from sqlalchemy.sql.schema import MetaData +from sqlalchemy.sql.sqltypes import LargeBinary, Time from typing_extensions import get_origin from .compat import ( IS_PYDANTIC_V2, + ModelField, ModelMetaclass, - NoArgAnyCallable, PydanticModelConfig, SQLModelConfig, Undefined, UndefinedType, + _is_field_noneable, class_dict_is_table, cls_is_table, get_annotations, - get_column_from_field, get_config_value, + get_field_metadata, get_model_fields, get_relationship_to, + get_type_from_field, + post_init_field_info, set_config_value, set_empty_defaults, set_fields_set, ) +from .sql.sqltypes import GUID, AutoString if not IS_PYDANTIC_V2: from pydantic.errors import ConfigError, DictError From fb112854cc05351cd5494d80ff57cfe81f814efd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sebasti=C3=A1n=20Ram=C3=ADrez?= Date: Sat, 25 Nov 2023 13:34:51 +0100 Subject: [PATCH 055/105] =?UTF-8?q?=E2=99=BB=EF=B8=8F=20Tweak=20names=20an?= =?UTF-8?q?d=20imports?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- sqlmodel/compat.py | 7 ++++--- sqlmodel/main.py | 4 ++-- 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/sqlmodel/compat.py b/sqlmodel/compat.py index c909c2d06b..63fb9953bb 100644 --- a/sqlmodel/compat.py +++ b/sqlmodel/compat.py @@ -13,6 +13,7 @@ ) from pydantic import VERSION as PYDANTIC_VERSION +from pydantic.fields import FieldInfo IS_PYDANTIC_V2 = PYDANTIC_VERSION.startswith("2.") @@ -35,7 +36,7 @@ class ModelField: from pydantic.typing import resolve_annotations if TYPE_CHECKING: - from .main import FieldInfo, RelationshipInfo, SQLModel + from .main import RelationshipInfo, SQLModel T = TypeVar("T") @@ -43,13 +44,13 @@ class ModelField: if IS_PYDANTIC_V2: - class SQLModelConfig(PydanticModelConfig, total=False): + class SQLModelConfig(BaseConfig, total=False): table: Optional[bool] registry: Optional[Any] else: - class SQLModelConfig(PydanticModelConfig): + class SQLModelConfig(BaseConfig): table: Optional[bool] = None registry: Optional[Any] = None diff --git a/sqlmodel/main.py b/sqlmodel/main.py index f8ef913601..6ac332d364 100644 --- a/sqlmodel/main.py +++ b/sqlmodel/main.py @@ -56,9 +56,9 @@ from .compat import ( IS_PYDANTIC_V2, + BaseConfig, ModelField, ModelMetaclass, - PydanticModelConfig, SQLModelConfig, Undefined, UndefinedType, @@ -444,7 +444,7 @@ def __new__( # superclass causing an error allowed_config_kwargs: Set[str] = { key - for key in dir(PydanticModelConfig) + for key in dir(BaseConfig) if not ( key.startswith("__") and key.endswith("__") ) # skip dunder methods and attributes From aeabcf2e3745f26bed542848bc5ab69545c72cd9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sebasti=C3=A1n=20Ram=C3=ADrez?= Date: Sat, 25 Nov 2023 18:23:20 +0100 Subject: [PATCH 056/105] =?UTF-8?q?=E2=99=BB=EF=B8=8F=20Refactor=20get=5Fr?= =?UTF-8?q?elationship=5Fto=20to=20support=20more=20than=20one=20level=20o?= =?UTF-8?q?f=20type=20unwrapping/origins?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- sqlmodel/compat.py | 25 +++++++++++++------------ 1 file changed, 13 insertions(+), 12 deletions(-) diff --git a/sqlmodel/compat.py b/sqlmodel/compat.py index 63fb9953bb..02b369fccd 100644 --- a/sqlmodel/compat.py +++ b/sqlmodel/compat.py @@ -138,21 +138,22 @@ def get_relationship_to( annotation: Any, ) -> Any: if IS_PYDANTIC_V2: - # TODO: review rename origin and relationship_to - relationship_to = get_origin(annotation) + origin = get_origin(annotation) + use_annotation = annotation # Direct relationships (e.g. 'Team' or Team) have None as an origin - if relationship_to is None: - relationship_to = annotation + if origin is None: + return annotation # If Union (e.g. Optional), get the real field - elif relationship_to is Union: - relationship_to = get_args(annotation)[0] + elif origin is Union: + use_annotation = get_args(annotation)[0] # If a list, then also get the real field - elif relationship_to is list: - relationship_to = get_args(annotation)[0] - # TODO: given this, should there be a recursive call in this whole if block to get_relationship_to? - if isinstance(relationship_to, ForwardRef): - relationship_to = relationship_to.__forward_arg__ - return relationship_to + elif origin is list: + use_annotation = get_args(annotation)[0] + elif isinstance(origin, ForwardRef): + use_annotation = origin.__forward_arg__ + return get_relationship_to( + name=name, rel_info=rel_info, annotation=use_annotation + ) else: temp_field = ModelField.infer( name=name, From d6de62247dcf4d24205030e34edb8eaa72066b32 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sebasti=C3=A1n=20Ram=C3=ADrez?= Date: Sat, 25 Nov 2023 22:16:23 +0100 Subject: [PATCH 057/105] =?UTF-8?q?=F0=9F=90=9B=20Fix=20recursion=20for=20?= =?UTF-8?q?forward=20refs?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- sqlmodel/compat.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/sqlmodel/compat.py b/sqlmodel/compat.py index 02b369fccd..21aada2921 100644 --- a/sqlmodel/compat.py +++ b/sqlmodel/compat.py @@ -142,15 +142,17 @@ def get_relationship_to( use_annotation = annotation # Direct relationships (e.g. 'Team' or Team) have None as an origin if origin is None: - return annotation + if isinstance(use_annotation, ForwardRef): + use_annotation = use_annotation.__forward_arg__ + else: + return use_annotation # If Union (e.g. Optional), get the real field elif origin is Union: use_annotation = get_args(annotation)[0] # If a list, then also get the real field elif origin is list: use_annotation = get_args(annotation)[0] - elif isinstance(origin, ForwardRef): - use_annotation = origin.__forward_arg__ + return get_relationship_to( name=name, rel_info=rel_info, annotation=use_annotation ) From cb02fb2e66a69a899c5556f78f5fca14d8991f63 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sebasti=C3=A1n=20Ram=C3=ADrez?= Date: Wed, 29 Nov 2023 19:41:04 +0100 Subject: [PATCH 058/105] =?UTF-8?q?=F0=9F=90=9B=20Fix=20detection=20of=20u?= =?UTF-8?q?nions,=20detected=20by=20new=20Python=203.10=20tests?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- sqlmodel/compat.py | 31 +++++++++++++++++++++++++------ 1 file changed, 25 insertions(+), 6 deletions(-) diff --git a/sqlmodel/compat.py b/sqlmodel/compat.py index 21aada2921..4f7511dc75 100644 --- a/sqlmodel/compat.py +++ b/sqlmodel/compat.py @@ -1,4 +1,4 @@ -from types import NoneType +import types from typing import ( TYPE_CHECKING, Any, @@ -38,7 +38,8 @@ class ModelField: if TYPE_CHECKING: from .main import RelationshipInfo, SQLModel - +UnionType = getattr(types, "UnionType", Union) +NoneType = type(None) T = TypeVar("T") InstanceOrType = Union[T, Type[T]] @@ -55,6 +56,10 @@ class SQLModelConfig(BaseConfig): registry: Optional[Any] = None +def _is_union_type(t: Any) -> bool: + return t is UnionType or t is Union + + def get_config_value( *, model: InstanceOrType["SQLModel"], parameter: str, default: Any = None ) -> Any: @@ -147,8 +152,22 @@ def get_relationship_to( else: return use_annotation # If Union (e.g. Optional), get the real field - elif origin is Union: - use_annotation = get_args(annotation)[0] + elif _is_union_type(origin): + use_annotation = get_args(annotation) + if len(use_annotation) > 2: + raise ValueError( + "Cannot have a (non-optional) union as a SQLAlchemy field" + ) + arg1, arg2 = use_annotation + if arg1 is NoneType and arg2 is not NoneType: + use_annotation = arg2 + elif arg2 is NoneType and arg1 is not NoneType: + use_annotation = arg1 + else: + raise ValueError( + "Cannot have a Union of None and None as a SQLAlchemy field" + ) + # If a list, then also get the real field elif origin is list: use_annotation = get_args(annotation)[0] @@ -206,7 +225,7 @@ def _is_field_noneable(field: "FieldInfo") -> bool: return False if field.annotation is None or field.annotation is NoneType: return True - if get_origin(field.annotation) is Union: + if _is_union_type(get_origin(field.annotation)): for base in get_args(field.annotation): if base is NoneType: return True @@ -230,7 +249,7 @@ def get_type_from_field(field: Any) -> type: origin = get_origin(type_) if origin is None: return type_ - if origin is Union: + if _is_union_type(origin): bases = get_args(type_) if len(bases) > 2: raise ValueError( From 0460f05ed6e228fb09f7a3f54a8ec1e72ba3595d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sebasti=C3=A1n=20Ram=C3=ADrez?= Date: Wed, 29 Nov 2023 19:53:57 +0100 Subject: [PATCH 059/105] =?UTF-8?q?=F0=9F=94=A5=20Remove=20unnecessary=20c?= =?UTF-8?q?hecks=20and=20overrides=20for=20class=20defaults?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- sqlmodel/compat.py | 45 --------------------------------------------- sqlmodel/main.py | 6 ------ 2 files changed, 51 deletions(-) diff --git a/sqlmodel/compat.py b/sqlmodel/compat.py index 4f7511dc75..c5b944a907 100644 --- a/sqlmodel/compat.py +++ b/sqlmodel/compat.py @@ -106,24 +106,6 @@ def get_annotations(class_dict: dict[str, Any]) -> dict[str, Any]: ) -# TODO: review if this is necessary -def class_dict_is_table( - class_dict: dict[str, Any], class_kwargs: dict[str, Any] -) -> bool: - config: SQLModelConfig = {} - if IS_PYDANTIC_V2: - config = class_dict.get("model_config", {}) - else: - config = class_dict.get("__config__", {}) - config_table = config.get("table", Undefined) - if config_table is not Undefined: - return config_table # type: ignore - kw_table = class_kwargs.get("table", Undefined) - if kw_table is not Undefined: - return kw_table # type: ignore - return False - - def cls_is_table(cls: Type) -> bool: if IS_PYDANTIC_V2: config = getattr(cls, "model_config", None) @@ -189,33 +171,6 @@ def get_relationship_to( return relationship_to -def set_empty_defaults(annotations: Dict[str, Any], class_dict: Dict[str, Any]) -> None: - """ - Pydantic v2 without required fields with no optionals cannot do empty initialisations. - This means we cannot do Table() and set fields later. - We go around this by adding a default to everything, being None - - Args: - annotations: Dict[str, Any]: The annotations to provide to pydantic - class_dict: Dict[str, Any]: The class dict for the defaults - """ - # TODO: no v1? - if IS_PYDANTIC_V2: - from .main import FieldInfo - - # Pydantic v2 sets a __pydantic_core_schema__ which is very hard to change. Changing the fields does not do anything - for key in annotations.keys(): - value = class_dict.get(key, Undefined) - if value is Undefined: - class_dict[key] = None - elif isinstance(value, FieldInfo): - if ( - value.default in (Undefined, Ellipsis) - ) and value.default_factory is None: - # So we can check for nullable - value.default = None - - def _is_field_noneable(field: "FieldInfo") -> bool: if IS_PYDANTIC_V2: if getattr(field, "nullable", Undefined) is not Undefined: diff --git a/sqlmodel/main.py b/sqlmodel/main.py index 6ac332d364..c14cc63154 100644 --- a/sqlmodel/main.py +++ b/sqlmodel/main.py @@ -63,7 +63,6 @@ Undefined, UndefinedType, _is_field_noneable, - class_dict_is_table, cls_is_table, get_annotations, get_config_value, @@ -73,7 +72,6 @@ get_type_from_field, post_init_field_info, set_config_value, - set_empty_defaults, set_fields_set, ) from .sql.sqltypes import GUID, AutoString @@ -452,10 +450,6 @@ def __new__( config_kwargs = { key: kwargs[key] for key in kwargs.keys() & allowed_config_kwargs } - # TODO: review if this is necessary - if class_dict_is_table(class_dict, kwargs): - set_empty_defaults(pydantic_annotations, dict_used) - new_cls: Type["SQLModelMetaclass"] = super().__new__( cls, name, bases, dict_used, **config_kwargs ) From 0187c361c6fb518bb50027ee4356d271c1f3ce8d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sebasti=C3=A1n=20Ram=C3=ADrez?= Date: Wed, 29 Nov 2023 20:05:48 +0100 Subject: [PATCH 060/105] =?UTF-8?q?=E2=99=BB=EF=B8=8F=20Refactor=20compat?= =?UTF-8?q?=20to=20run=20all=20conditionals=20at=20import=20time=20and=20n?= =?UTF-8?q?ot=20at=20every=20execution?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- sqlmodel/compat.py | 209 +++++++++++++++++++++++---------------------- 1 file changed, 107 insertions(+), 102 deletions(-) diff --git a/sqlmodel/compat.py b/sqlmodel/compat.py index c5b944a907..dbd9f15479 100644 --- a/sqlmodel/compat.py +++ b/sqlmodel/compat.py @@ -43,88 +43,58 @@ class ModelField: T = TypeVar("T") InstanceOrType = Union[T, Type[T]] -if IS_PYDANTIC_V2: - - class SQLModelConfig(BaseConfig, total=False): - table: Optional[bool] - registry: Optional[Any] - -else: - class SQLModelConfig(BaseConfig): - table: Optional[bool] = None - registry: Optional[Any] = None +class FakeMetadata: + max_length: Optional[int] = None + max_digits: Optional[int] = None + decimal_places: Optional[int] = None def _is_union_type(t: Any) -> bool: return t is UnionType or t is Union -def get_config_value( - *, model: InstanceOrType["SQLModel"], parameter: str, default: Any = None -) -> Any: - if IS_PYDANTIC_V2: - return model.model_config.get(parameter, default) - else: - return getattr(model.__config__, parameter, default) +if IS_PYDANTIC_V2: + class SQLModelConfig(BaseConfig, total=False): + table: Optional[bool] + registry: Optional[Any] -def set_config_value( - *, - model: InstanceOrType["SQLModel"], - parameter: str, - value: Any, -) -> None: - if IS_PYDANTIC_V2: - model.model_config[parameter] = value # type: ignore - else: - setattr(model.__config__, parameter, value) # type: ignore + def get_config_value( + *, model: InstanceOrType["SQLModel"], parameter: str, default: Any = None + ) -> Any: + return model.model_config.get(parameter, default) + def set_config_value( + *, + model: InstanceOrType["SQLModel"], + parameter: str, + value: Any, + ) -> None: + model.model_config[parameter] = value -def get_model_fields(model: InstanceOrType["SQLModel"]) -> Dict[str, "FieldInfo"]: - if IS_PYDANTIC_V2: + def get_model_fields(model: InstanceOrType["SQLModel"]) -> Dict[str, "FieldInfo"]: return model.model_fields # type: ignore - else: - return model.__fields__ # type: ignore - -def set_fields_set( - new_object: InstanceOrType["SQLModel"], fields: set["FieldInfo"] -) -> None: - if IS_PYDANTIC_V2: + def set_fields_set( + new_object: InstanceOrType["SQLModel"], fields: set["FieldInfo"] + ) -> None: object.__setattr__(new_object, "__pydantic_fields_set__", fields) - else: - object.__setattr__(new_object, "__fields_set__", fields) - -def get_annotations(class_dict: dict[str, Any]) -> dict[str, Any]: - if IS_PYDANTIC_V2: + def get_annotations(class_dict: dict[str, Any]) -> dict[str, Any]: return class_dict.get("__annotations__", {}) - else: - return resolve_annotations( - class_dict.get("__annotations__", {}), class_dict.get("__module__", None) - ) - -def cls_is_table(cls: Type) -> bool: - if IS_PYDANTIC_V2: + def cls_is_table(cls: Type) -> bool: config = getattr(cls, "model_config", None) if not config: return False return config.get("table", False) - else: - config = getattr(cls, "__config__", None) - if not config: - return False - return getattr(config, "table", False) - -def get_relationship_to( - name: str, - rel_info: "RelationshipInfo", - annotation: Any, -) -> Any: - if IS_PYDANTIC_V2: + def get_relationship_to( + name: str, + rel_info: "RelationshipInfo", + annotation: Any, + ) -> Any: origin = get_origin(annotation) use_annotation = annotation # Direct relationships (e.g. 'Team' or Team) have None as an origin @@ -157,22 +127,8 @@ def get_relationship_to( return get_relationship_to( name=name, rel_info=rel_info, annotation=use_annotation ) - else: - temp_field = ModelField.infer( - name=name, - value=rel_info, - annotation=annotation, - class_validators=None, - config=SQLModelConfig, - ) - relationship_to = temp_field.type_ - if isinstance(temp_field.type_, ForwardRef): - relationship_to = temp_field.type_.__forward_arg__ - return relationship_to - -def _is_field_noneable(field: "FieldInfo") -> bool: - if IS_PYDANTIC_V2: + def _is_field_noneable(field: "FieldInfo") -> bool: if getattr(field, "nullable", Undefined) is not Undefined: return field.nullable # type: ignore if not field.is_required(): @@ -186,17 +142,8 @@ def _is_field_noneable(field: "FieldInfo") -> bool: return True return False return False - else: - if not field.required: - # Taken from [Pydantic](https://github.com/samuelcolvin/pydantic/blob/v1.8.2/pydantic/fields.py#L946-L947) - return field.allow_none and ( - field.shape != SHAPE_SINGLETON or not field.sub_fields - ) - return field.allow_none - -def get_type_from_field(field: Any) -> type: - if IS_PYDANTIC_V2: + def get_type_from_field(field: Any) -> type: type_: type | None = field.annotation # Resolve Optional fields if type_ is None: @@ -218,32 +165,90 @@ def get_type_from_field(field: Any) -> type: # Optional unions are allowed return bases[0] if bases[0] is not NoneType else bases[1] return origin - else: - if isinstance(field.type_, type) and field.shape == SHAPE_SINGLETON: - return field.type_ - raise ValueError(f"The field {field.name} has no matching SQLAlchemy type") - - -class FakeMetadata: - max_length: Optional[int] = None - max_digits: Optional[int] = None - decimal_places: Optional[int] = None - -def get_field_metadata(field: Any) -> Any: - if IS_PYDANTIC_V2: + def get_field_metadata(field: Any) -> Any: for meta in field.metadata: if isinstance(meta, PydanticMetadata): return meta return FakeMetadata() - else: + + def post_init_field_info(field_info: FieldInfo) -> None: + return None +else: + + class SQLModelConfig(BaseConfig): + table: Optional[bool] = None + registry: Optional[Any] = None + + def get_config_value( + *, model: InstanceOrType["SQLModel"], parameter: str, default: Any = None + ) -> Any: + return getattr(model.__config__, parameter, default) + + def set_config_value( + *, + model: InstanceOrType["SQLModel"], + parameter: str, + value: Any, + ) -> None: + setattr(model.__config__, parameter, value) # type: ignore + + def get_model_fields(model: InstanceOrType["SQLModel"]) -> Dict[str, "FieldInfo"]: + return model.__fields__ # type: ignore + + def set_fields_set( + new_object: InstanceOrType["SQLModel"], fields: set["FieldInfo"] + ) -> None: + object.__setattr__(new_object, "__fields_set__", fields) + + def get_annotations(class_dict: dict[str, Any]) -> dict[str, Any]: + return resolve_annotations( + class_dict.get("__annotations__", {}), + class_dict.get("__module__", None), + ) + + def cls_is_table(cls: Type) -> bool: + config = getattr(cls, "__config__", None) + if not config: + return False + return getattr(config, "table", False) + + def get_relationship_to( + name: str, + rel_info: "RelationshipInfo", + annotation: Any, + ) -> Any: + temp_field = ModelField.infer( + name=name, + value=rel_info, + annotation=annotation, + class_validators=None, + config=SQLModelConfig, + ) + relationship_to = temp_field.type_ + if isinstance(temp_field.type_, ForwardRef): + relationship_to = temp_field.type_.__forward_arg__ + return relationship_to + + def _is_field_noneable(field: "FieldInfo") -> bool: + if not field.required: + # Taken from [Pydantic](https://github.com/samuelcolvin/pydantic/blob/v1.8.2/pydantic/fields.py#L946-L947) + return field.allow_none and ( + field.shape != SHAPE_SINGLETON or not field.sub_fields + ) + return field.allow_none + + def get_type_from_field(field: Any) -> type: + if isinstance(field.type_, type) and field.shape == SHAPE_SINGLETON: + return field.type_ + raise ValueError(f"The field {field.name} has no matching SQLAlchemy type") + + def get_field_metadata(field: Any) -> Any: metadata = FakeMetadata() metadata.max_length = field.field_info.max_length metadata.max_digits = getattr(field.type_, "max_digits", None) metadata.decimal_places = getattr(field.type_, "decimal_places", None) return metadata - -def post_init_field_info(field_info: FieldInfo) -> None: - if not IS_PYDANTIC_V2: + def post_init_field_info(field_info: FieldInfo) -> None: field_info._validate() From 92490213ba68cdf940b313cf4636dbb45e2ae531 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sebasti=C3=A1n=20Ram=C3=ADrez?= Date: Wed, 29 Nov 2023 20:07:18 +0100 Subject: [PATCH 061/105] =?UTF-8?q?=F0=9F=9A=9A=20Rename=20compat.py=20to?= =?UTF-8?q?=20=5Fcompat.py?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- sqlmodel/{compat.py => _compat.py} | 0 sqlmodel/main.py | 2 +- tests/conftest.py | 2 +- 3 files changed, 2 insertions(+), 2 deletions(-) rename sqlmodel/{compat.py => _compat.py} (100%) diff --git a/sqlmodel/compat.py b/sqlmodel/_compat.py similarity index 100% rename from sqlmodel/compat.py rename to sqlmodel/_compat.py diff --git a/sqlmodel/main.py b/sqlmodel/main.py index c14cc63154..4c9e5325b7 100644 --- a/sqlmodel/main.py +++ b/sqlmodel/main.py @@ -54,7 +54,7 @@ from sqlalchemy.sql.sqltypes import LargeBinary, Time from typing_extensions import get_origin -from .compat import ( +from ._compat import ( IS_PYDANTIC_V2, BaseConfig, ModelField, diff --git a/tests/conftest.py b/tests/conftest.py index a73d4c1e97..7cf7b054b8 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -7,7 +7,7 @@ import pytest from pydantic import BaseModel from sqlmodel import SQLModel -from sqlmodel.compat import IS_PYDANTIC_V2 +from sqlmodel._compat import IS_PYDANTIC_V2 from sqlmodel.main import default_registry top_level_path = Path(__file__).resolve().parent.parent From d400fd77dca5122c8609705e2c333e44be6bf4ee Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sebasti=C3=A1n=20Ram=C3=ADrez?= Date: Thu, 30 Nov 2023 17:02:52 +0100 Subject: [PATCH 062/105] =?UTF-8?q?=F0=9F=90=9B=20Fix=20handling=20of=20No?= =?UTF-8?q?nable=20unions?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- sqlmodel/_compat.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/sqlmodel/_compat.py b/sqlmodel/_compat.py index dbd9f15479..fc249e90df 100644 --- a/sqlmodel/_compat.py +++ b/sqlmodel/_compat.py @@ -131,15 +131,16 @@ def get_relationship_to( def _is_field_noneable(field: "FieldInfo") -> bool: if getattr(field, "nullable", Undefined) is not Undefined: return field.nullable # type: ignore + origin = get_origin(field.annotation) + if origin is not None and _is_union_type(origin): + args = get_args(field.annotation) + if any(arg is NoneType for arg in args): + return True if not field.is_required(): if field.default is Undefined: return False if field.annotation is None or field.annotation is NoneType: return True - if _is_union_type(get_origin(field.annotation)): - for base in get_args(field.annotation): - if base is NoneType: - return True return False return False From 9bb29aafb5fc84cd2de36139ad1ce10553738c5f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sebasti=C3=A1n=20Ram=C3=ADrez?= Date: Thu, 30 Nov 2023 17:12:58 +0100 Subject: [PATCH 063/105] =?UTF-8?q?=E2=9C=85=20Simplify=20test=20that=20sh?= =?UTF-8?q?oudln't=20require=20Pydantic=20v1=20or=20v2?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- tests/test_instance_no_args.py | 28 +--------------------------- 1 file changed, 1 insertion(+), 27 deletions(-) diff --git a/tests/test_instance_no_args.py b/tests/test_instance_no_args.py index e54e8163b3..9630efa33b 100644 --- a/tests/test_instance_no_args.py +++ b/tests/test_instance_no_args.py @@ -6,19 +6,13 @@ from sqlalchemy.orm import Session from sqlmodel import Field, SQLModel -from .conftest import needs_pydanticv1, needs_pydanticv2 - -@needs_pydanticv1 def test_allow_instantiation_without_arguments_pydantic_v1(clear_sqlmodel): - class Item(SQLModel): + class Item(SQLModel, table=True): id: Optional[int] = Field(default=None, primary_key=True) name: str description: Optional[str] = None - class Config: - table = True - engine = create_engine("sqlite:///:memory:") SQLModel.metadata.create_all(engine) with Session(engine) as db: @@ -40,23 +34,3 @@ class Item(SQLModel): with pytest.raises(ValidationError): Item() - - -@needs_pydanticv2 -def test_allow_instantiation_without_arguments_pydnatic_v2(clear_sqlmodel): - class Item(SQLModel, table=True): - id: Optional[int] = Field(default=None, primary_key=True) - name: str - description: Optional[str] = None - - engine = create_engine("sqlite:///:memory:") - SQLModel.metadata.create_all(engine) - with Session(engine) as db: - item = Item() - item.name = "Rick" - db.add(item) - db.commit() - result = db.execute(select(Item)).scalars().all() - assert len(result) == 1 - assert isinstance(item.id, int) - SQLModel.metadata.clear() From edb051e16029e3a5cc4aafb4c5769aa8e60353df Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sebasti=C3=A1n=20Ram=C3=ADrez?= Date: Thu, 30 Nov 2023 17:13:28 +0100 Subject: [PATCH 064/105] =?UTF-8?q?=E2=AC=86=EF=B8=8F=20Set=20minimum=20Py?= =?UTF-8?q?dantic=20version=20to=201.10.13?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 9bfc434cfb..ab007a51c6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -34,7 +34,7 @@ classifiers = [ [tool.poetry.dependencies] python = "^3.7" SQLAlchemy = ">=2.0.0,<2.1.0" -pydantic = "^1.9.0" +pydantic = ">1.10.13,<3.0.0" [tool.poetry.group.dev.dependencies] pytest = "^7.0.1" From 65d6b3e6698098bf36471ae1444c36d8db4fd1aa Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sebasti=C3=A1n=20Ram=C3=ADrez?= Date: Sun, 3 Dec 2023 09:02:05 +0100 Subject: [PATCH 065/105] =?UTF-8?q?=F0=9F=9A=9A=20Rename=20compat=20to=20?= =?UTF-8?q?=5Fcompat?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- docs_src/tutorial/fastapi/app_testing/tutorial001/main.py | 2 +- docs_src/tutorial/fastapi/delete/tutorial001.py | 2 +- docs_src/tutorial/fastapi/limit_and_offset/tutorial001.py | 2 +- docs_src/tutorial/fastapi/multiple_models/tutorial001.py | 2 +- docs_src/tutorial/fastapi/multiple_models/tutorial002.py | 2 +- docs_src/tutorial/fastapi/read_one/tutorial001.py | 2 +- docs_src/tutorial/fastapi/relationships/tutorial001.py | 2 +- .../tutorial/fastapi/session_with_dependency/tutorial001.py | 2 +- docs_src/tutorial/fastapi/teams/tutorial001.py | 2 +- docs_src/tutorial/fastapi/update/tutorial001.py | 2 +- 10 files changed, 10 insertions(+), 10 deletions(-) diff --git a/docs_src/tutorial/fastapi/app_testing/tutorial001/main.py b/docs_src/tutorial/fastapi/app_testing/tutorial001/main.py index f305f75194..cc830e8b19 100644 --- a/docs_src/tutorial/fastapi/app_testing/tutorial001/main.py +++ b/docs_src/tutorial/fastapi/app_testing/tutorial001/main.py @@ -2,7 +2,7 @@ from fastapi import Depends, FastAPI, HTTPException, Query from sqlmodel import Field, Session, SQLModel, create_engine, select -from sqlmodel.compat import IS_PYDANTIC_V2 +from sqlmodel._compat import IS_PYDANTIC_V2 class HeroBase(SQLModel): diff --git a/docs_src/tutorial/fastapi/delete/tutorial001.py b/docs_src/tutorial/fastapi/delete/tutorial001.py index f186c42b2b..04e23ee251 100644 --- a/docs_src/tutorial/fastapi/delete/tutorial001.py +++ b/docs_src/tutorial/fastapi/delete/tutorial001.py @@ -2,7 +2,7 @@ from fastapi import FastAPI, HTTPException, Query from sqlmodel import Field, Session, SQLModel, create_engine, select -from sqlmodel.compat import IS_PYDANTIC_V2 +from sqlmodel._compat import IS_PYDANTIC_V2 class HeroBase(SQLModel): diff --git a/docs_src/tutorial/fastapi/limit_and_offset/tutorial001.py b/docs_src/tutorial/fastapi/limit_and_offset/tutorial001.py index 6701355f17..223edddfa7 100644 --- a/docs_src/tutorial/fastapi/limit_and_offset/tutorial001.py +++ b/docs_src/tutorial/fastapi/limit_and_offset/tutorial001.py @@ -2,7 +2,7 @@ from fastapi import FastAPI, HTTPException, Query from sqlmodel import Field, Session, SQLModel, create_engine, select -from sqlmodel.compat import IS_PYDANTIC_V2 +from sqlmodel._compat import IS_PYDANTIC_V2 class HeroBase(SQLModel): diff --git a/docs_src/tutorial/fastapi/multiple_models/tutorial001.py b/docs_src/tutorial/fastapi/multiple_models/tutorial001.py index 0ceed94ca1..10b169d3b8 100644 --- a/docs_src/tutorial/fastapi/multiple_models/tutorial001.py +++ b/docs_src/tutorial/fastapi/multiple_models/tutorial001.py @@ -2,7 +2,7 @@ from fastapi import FastAPI from sqlmodel import Field, Session, SQLModel, create_engine, select -from sqlmodel.compat import IS_PYDANTIC_V2 +from sqlmodel._compat import IS_PYDANTIC_V2 class Hero(SQLModel, table=True): diff --git a/docs_src/tutorial/fastapi/multiple_models/tutorial002.py b/docs_src/tutorial/fastapi/multiple_models/tutorial002.py index d92745a339..daa34ccedd 100644 --- a/docs_src/tutorial/fastapi/multiple_models/tutorial002.py +++ b/docs_src/tutorial/fastapi/multiple_models/tutorial002.py @@ -2,7 +2,7 @@ from fastapi import FastAPI from sqlmodel import Field, Session, SQLModel, create_engine, select -from sqlmodel.compat import IS_PYDANTIC_V2 +from sqlmodel._compat import IS_PYDANTIC_V2 class HeroBase(SQLModel): diff --git a/docs_src/tutorial/fastapi/read_one/tutorial001.py b/docs_src/tutorial/fastapi/read_one/tutorial001.py index aa805d6c8f..1a106b80ca 100644 --- a/docs_src/tutorial/fastapi/read_one/tutorial001.py +++ b/docs_src/tutorial/fastapi/read_one/tutorial001.py @@ -2,7 +2,7 @@ from fastapi import FastAPI, HTTPException from sqlmodel import Field, Session, SQLModel, create_engine, select -from sqlmodel.compat import IS_PYDANTIC_V2 +from sqlmodel._compat import IS_PYDANTIC_V2 class HeroBase(SQLModel): diff --git a/docs_src/tutorial/fastapi/relationships/tutorial001.py b/docs_src/tutorial/fastapi/relationships/tutorial001.py index dfcedaf881..1f067c302a 100644 --- a/docs_src/tutorial/fastapi/relationships/tutorial001.py +++ b/docs_src/tutorial/fastapi/relationships/tutorial001.py @@ -2,7 +2,7 @@ from fastapi import Depends, FastAPI, HTTPException, Query from sqlmodel import Field, Relationship, Session, SQLModel, create_engine, select -from sqlmodel.compat import IS_PYDANTIC_V2 +from sqlmodel._compat import IS_PYDANTIC_V2 class TeamBase(SQLModel): diff --git a/docs_src/tutorial/fastapi/session_with_dependency/tutorial001.py b/docs_src/tutorial/fastapi/session_with_dependency/tutorial001.py index f305f75194..cc830e8b19 100644 --- a/docs_src/tutorial/fastapi/session_with_dependency/tutorial001.py +++ b/docs_src/tutorial/fastapi/session_with_dependency/tutorial001.py @@ -2,7 +2,7 @@ from fastapi import Depends, FastAPI, HTTPException, Query from sqlmodel import Field, Session, SQLModel, create_engine, select -from sqlmodel.compat import IS_PYDANTIC_V2 +from sqlmodel._compat import IS_PYDANTIC_V2 class HeroBase(SQLModel): diff --git a/docs_src/tutorial/fastapi/teams/tutorial001.py b/docs_src/tutorial/fastapi/teams/tutorial001.py index 46ea0f933c..3f757c0cec 100644 --- a/docs_src/tutorial/fastapi/teams/tutorial001.py +++ b/docs_src/tutorial/fastapi/teams/tutorial001.py @@ -2,7 +2,7 @@ from fastapi import Depends, FastAPI, HTTPException, Query from sqlmodel import Field, Relationship, Session, SQLModel, create_engine, select -from sqlmodel.compat import IS_PYDANTIC_V2 +from sqlmodel._compat import IS_PYDANTIC_V2 class TeamBase(SQLModel): diff --git a/docs_src/tutorial/fastapi/update/tutorial001.py b/docs_src/tutorial/fastapi/update/tutorial001.py index 93dfa7496a..c179c6e363 100644 --- a/docs_src/tutorial/fastapi/update/tutorial001.py +++ b/docs_src/tutorial/fastapi/update/tutorial001.py @@ -2,7 +2,7 @@ from fastapi import FastAPI, HTTPException, Query from sqlmodel import Field, Session, SQLModel, create_engine, select -from sqlmodel.compat import IS_PYDANTIC_V2 +from sqlmodel._compat import IS_PYDANTIC_V2 class HeroBase(SQLModel): From f879c7789d8917f7ba95368c631bdc5f18fee5ce Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sebasti=C3=A1n=20Ram=C3=ADrez?= Date: Sun, 3 Dec 2023 09:02:23 +0100 Subject: [PATCH 066/105] =?UTF-8?q?=F0=9F=9A=9A=20Rename=20test?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- tests/test_instance_no_args.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_instance_no_args.py b/tests/test_instance_no_args.py index 9630efa33b..f424c16970 100644 --- a/tests/test_instance_no_args.py +++ b/tests/test_instance_no_args.py @@ -7,7 +7,7 @@ from sqlmodel import Field, SQLModel -def test_allow_instantiation_without_arguments_pydantic_v1(clear_sqlmodel): +def test_allow_instantiation_without_arguments(clear_sqlmodel): class Item(SQLModel, table=True): id: Optional[int] = Field(default=None, primary_key=True) name: str From baf8f47dea927cc2f342bc11d348c3998ced3d8e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sebasti=C3=A1n=20Ram=C3=ADrez?= Date: Sun, 3 Dec 2023 09:03:53 +0100 Subject: [PATCH 067/105] =?UTF-8?q?=E2=9C=85=20Refactor=20and=20update=20t?= =?UTF-8?q?est?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- tests/test_instance_no_args.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/tests/test_instance_no_args.py b/tests/test_instance_no_args.py index f424c16970..5c8ad77531 100644 --- a/tests/test_instance_no_args.py +++ b/tests/test_instance_no_args.py @@ -2,9 +2,7 @@ import pytest from pydantic import ValidationError -from sqlalchemy import create_engine, select -from sqlalchemy.orm import Session -from sqlmodel import Field, SQLModel +from sqlmodel import Field, Session, SQLModel, create_engine, select def test_allow_instantiation_without_arguments(clear_sqlmodel): @@ -20,7 +18,8 @@ class Item(SQLModel, table=True): item.name = "Rick" db.add(item) db.commit() - result = db.execute(select(Item)).scalars().all() + statement = select(Item) + result = db.exec(statement).all() assert len(result) == 1 assert isinstance(item.id, int) SQLModel.metadata.clear() From 8c603f1180e53884d72a996ea07dbbb3a4e4cfe3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sebasti=C3=A1n=20Ram=C3=ADrez?= Date: Sun, 3 Dec 2023 09:04:21 +0100 Subject: [PATCH 068/105] =?UTF-8?q?=E2=9C=85=20Update=20tests,=20revert=20?= =?UTF-8?q?broken=20JSON=20Schema=20as=20a=20workaround=20for=20new=20mode?= =?UTF-8?q?ls=20without=20defaults?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- tests/test_enums.py | 17 +++++------------ 1 file changed, 5 insertions(+), 12 deletions(-) diff --git a/tests/test_enums.py b/tests/test_enums.py index 07a04c686e..f0543e90f1 100644 --- a/tests/test_enums.py +++ b/tests/test_enums.py @@ -122,15 +122,12 @@ def test_json_schema_flat_model_pydantic_v2(): "title": "FlatModel", "type": "object", "properties": { - "id": {"default": None, "format": "uuid", "title": "Id", "type": "string"}, - "enum_field": {"allOf": [{"$ref": "#/$defs/MyEnum1"}], "default": None}, + "id": {"title": "Id", "type": "string", "format": "uuid"}, + "enum_field": {"$ref": "#/$defs/MyEnum1"}, }, + "required": ["id", "enum_field"], "$defs": { - "MyEnum1": { - "title": "MyEnum1", - "enum": ["A", "B"], - "type": "string", - } + "MyEnum1": {"enum": ["A", "B"], "title": "MyEnum1", "type": "string"} }, } @@ -146,10 +143,6 @@ def test_json_schema_inherit_model_pydantic_v2(): }, "required": ["id", "enum_field"], "$defs": { - "MyEnum2": { - "title": "MyEnum2", - "enum": ["C", "D"], - "type": "string", - } + "MyEnum2": {"enum": ["C", "D"], "title": "MyEnum2", "type": "string"} }, } From 9f388f138eb5f6dbaf7ba17ad57eb49932fad648 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sebasti=C3=A1n=20Ram=C3=ADrez?= Date: Sun, 3 Dec 2023 10:50:48 +0100 Subject: [PATCH 069/105] =?UTF-8?q?=E2=9E=95=20Add=20dirty=5Fequals=20to?= =?UTF-8?q?=20testing=20dependencies?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- pyproject.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/pyproject.toml b/pyproject.toml index ab007a51c6..d2562011e6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -50,6 +50,7 @@ fastapi = "^0.103.2" ruff = "^0.1.2" # For FastAPI tests httpx = "0.24.1" +dirty-equals = "^0.7.1.post0" [build-system] requires = ["poetry-core"] From 8d763dadf3d578bd09b0a29ae6e563c38ececc73 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sebasti=C3=A1n=20Ram=C3=ADrez?= Date: Sun, 3 Dec 2023 10:54:25 +0100 Subject: [PATCH 070/105] =?UTF-8?q?=E2=9C=85=20Update=20tests=20with=20com?= =?UTF-8?q?patibility=20for=20Pydantic=20v2=20JSON=20Schema=20changes?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../test_delete/test_tutorial001.py | 56 ++++++- .../test_delete/test_tutorial001_py310.py | 56 ++++++- .../test_delete/test_tutorial001_py39.py | 56 ++++++- .../test_limit_and_offset/test_tutorial001.py | 23 ++- .../test_tutorial001_py310.py | 23 ++- .../test_tutorial001_py39.py | 23 ++- .../test_read_one/test_tutorial001.py | 26 ++- .../test_read_one/test_tutorial001_py310.py | 26 ++- .../test_read_one/test_tutorial001_py39.py | 26 ++- .../test_relationships/test_tutorial001.py | 157 ++++++++++++++++-- .../test_tutorial001_py310.py | 157 ++++++++++++++++-- .../test_tutorial001_py39.py | 157 ++++++++++++++++-- .../test_response_model/test_tutorial001.py | 26 ++- .../test_tutorial001_py310.py | 26 ++- .../test_tutorial001_py39.py | 26 ++- .../test_tutorial001.py | 56 ++++++- .../test_tutorial001_py310.py | 56 ++++++- .../test_tutorial001_py39.py | 56 ++++++- .../test_teams/test_tutorial001.py | 111 +++++++++++-- .../test_teams/test_tutorial001_py310.py | 111 +++++++++++-- .../test_teams/test_tutorial001_py39.py | 111 +++++++++++-- 21 files changed, 1233 insertions(+), 132 deletions(-) diff --git a/tests/test_tutorial/test_fastapi/test_delete/test_tutorial001.py b/tests/test_tutorial/test_fastapi/test_delete/test_tutorial001.py index 6a55d6cb98..706cc8aed7 100644 --- a/tests/test_tutorial/test_fastapi/test_delete/test_tutorial001.py +++ b/tests/test_tutorial/test_fastapi/test_delete/test_tutorial001.py @@ -1,3 +1,4 @@ +from dirty_equals import IsDict from fastapi.testclient import TestClient from sqlmodel import create_engine from sqlmodel.pool import StaticPool @@ -284,7 +285,16 @@ def test_tutorial(clear_sqlmodel): "properties": { "name": {"title": "Name", "type": "string"}, "secret_name": {"title": "Secret Name", "type": "string"}, - "age": {"title": "Age", "type": "integer"}, + "age": IsDict( + { + "title": "Age", + "anyOf": [{"type": "integer"}, {"type": "null"}], + } + ) + | IsDict( + # TODO: remove when deprecating Pydantic v1 + {"title": "Age", "type": "integer"} + ), }, }, "HeroRead": { @@ -294,7 +304,16 @@ def test_tutorial(clear_sqlmodel): "properties": { "name": {"title": "Name", "type": "string"}, "secret_name": {"title": "Secret Name", "type": "string"}, - "age": {"title": "Age", "type": "integer"}, + "age": IsDict( + { + "title": "Age", + "anyOf": [{"type": "integer"}, {"type": "null"}], + } + ) + | IsDict( + # TODO: remove when deprecating Pydantic v1 + {"title": "Age", "type": "integer"} + ), "id": {"title": "Id", "type": "integer"}, }, }, @@ -302,9 +321,36 @@ def test_tutorial(clear_sqlmodel): "title": "HeroUpdate", "type": "object", "properties": { - "name": {"title": "Name", "type": "string"}, - "secret_name": {"title": "Secret Name", "type": "string"}, - "age": {"title": "Age", "type": "integer"}, + "name": IsDict( + { + "title": "Name", + "anyOf": [{"type": "string"}, {"type": "null"}], + } + ) + | IsDict( + # TODO: remove when deprecating Pydantic v1 + {"title": "Name", "type": "string"} + ), + "secret_name": IsDict( + { + "title": "Secret Name", + "anyOf": [{"type": "string"}, {"type": "null"}], + } + ) + | IsDict( + # TODO: remove when deprecating Pydantic v1 + {"title": "Secret Name", "type": "string"} + ), + "age": IsDict( + { + "title": "Age", + "anyOf": [{"type": "integer"}, {"type": "null"}], + } + ) + | IsDict( + # TODO: remove when deprecating Pydantic v1 + {"title": "Age", "type": "integer"} + ), }, }, "ValidationError": { diff --git a/tests/test_tutorial/test_fastapi/test_delete/test_tutorial001_py310.py b/tests/test_tutorial/test_fastapi/test_delete/test_tutorial001_py310.py index 133b287630..46c8c42dd3 100644 --- a/tests/test_tutorial/test_fastapi/test_delete/test_tutorial001_py310.py +++ b/tests/test_tutorial/test_fastapi/test_delete/test_tutorial001_py310.py @@ -1,3 +1,4 @@ +from dirty_equals import IsDict from fastapi.testclient import TestClient from sqlmodel import create_engine from sqlmodel.pool import StaticPool @@ -287,7 +288,16 @@ def test_tutorial(clear_sqlmodel): "properties": { "name": {"title": "Name", "type": "string"}, "secret_name": {"title": "Secret Name", "type": "string"}, - "age": {"title": "Age", "type": "integer"}, + "age": IsDict( + { + "title": "Age", + "anyOf": [{"type": "integer"}, {"type": "null"}], + } + ) + | IsDict( + # TODO: remove when deprecating Pydantic v1 + {"title": "Age", "type": "integer"} + ), }, }, "HeroRead": { @@ -297,7 +307,16 @@ def test_tutorial(clear_sqlmodel): "properties": { "name": {"title": "Name", "type": "string"}, "secret_name": {"title": "Secret Name", "type": "string"}, - "age": {"title": "Age", "type": "integer"}, + "age": IsDict( + { + "title": "Age", + "anyOf": [{"type": "integer"}, {"type": "null"}], + } + ) + | IsDict( + # TODO: remove when deprecating Pydantic v1 + {"title": "Age", "type": "integer"} + ), "id": {"title": "Id", "type": "integer"}, }, }, @@ -305,9 +324,36 @@ def test_tutorial(clear_sqlmodel): "title": "HeroUpdate", "type": "object", "properties": { - "name": {"title": "Name", "type": "string"}, - "secret_name": {"title": "Secret Name", "type": "string"}, - "age": {"title": "Age", "type": "integer"}, + "name": IsDict( + { + "title": "Name", + "anyOf": [{"type": "string"}, {"type": "null"}], + } + ) + | IsDict( + # TODO: remove when deprecating Pydantic v1 + {"title": "Name", "type": "string"} + ), + "secret_name": IsDict( + { + "title": "Secret Name", + "anyOf": [{"type": "string"}, {"type": "null"}], + } + ) + | IsDict( + # TODO: remove when deprecating Pydantic v1 + {"title": "Secret Name", "type": "string"} + ), + "age": IsDict( + { + "title": "Age", + "anyOf": [{"type": "integer"}, {"type": "null"}], + } + ) + | IsDict( + # TODO: remove when deprecating Pydantic v1 + {"title": "Age", "type": "integer"} + ), }, }, "ValidationError": { diff --git a/tests/test_tutorial/test_fastapi/test_delete/test_tutorial001_py39.py b/tests/test_tutorial/test_fastapi/test_delete/test_tutorial001_py39.py index 5aac8cb11f..e2874c1095 100644 --- a/tests/test_tutorial/test_fastapi/test_delete/test_tutorial001_py39.py +++ b/tests/test_tutorial/test_fastapi/test_delete/test_tutorial001_py39.py @@ -1,3 +1,4 @@ +from dirty_equals import IsDict from fastapi.testclient import TestClient from sqlmodel import create_engine from sqlmodel.pool import StaticPool @@ -287,7 +288,16 @@ def test_tutorial(clear_sqlmodel): "properties": { "name": {"title": "Name", "type": "string"}, "secret_name": {"title": "Secret Name", "type": "string"}, - "age": {"title": "Age", "type": "integer"}, + "age": IsDict( + { + "title": "Age", + "anyOf": [{"type": "integer"}, {"type": "null"}], + } + ) + | IsDict( + # TODO: remove when deprecating Pydantic v1 + {"title": "Age", "type": "integer"} + ), }, }, "HeroRead": { @@ -297,7 +307,16 @@ def test_tutorial(clear_sqlmodel): "properties": { "name": {"title": "Name", "type": "string"}, "secret_name": {"title": "Secret Name", "type": "string"}, - "age": {"title": "Age", "type": "integer"}, + "age": IsDict( + { + "title": "Age", + "anyOf": [{"type": "integer"}, {"type": "null"}], + } + ) + | IsDict( + # TODO: remove when deprecating Pydantic v1 + {"title": "Age", "type": "integer"} + ), "id": {"title": "Id", "type": "integer"}, }, }, @@ -305,9 +324,36 @@ def test_tutorial(clear_sqlmodel): "title": "HeroUpdate", "type": "object", "properties": { - "name": {"title": "Name", "type": "string"}, - "secret_name": {"title": "Secret Name", "type": "string"}, - "age": {"title": "Age", "type": "integer"}, + "name": IsDict( + { + "title": "Name", + "anyOf": [{"type": "string"}, {"type": "null"}], + } + ) + | IsDict( + # TODO: remove when deprecating Pydantic v1 + {"title": "Name", "type": "string"} + ), + "secret_name": IsDict( + { + "title": "Secret Name", + "anyOf": [{"type": "string"}, {"type": "null"}], + } + ) + | IsDict( + # TODO: remove when deprecating Pydantic v1 + {"title": "Secret Name", "type": "string"} + ), + "age": IsDict( + { + "title": "Age", + "anyOf": [{"type": "integer"}, {"type": "null"}], + } + ) + | IsDict( + # TODO: remove when deprecating Pydantic v1 + {"title": "Age", "type": "integer"} + ), }, }, "ValidationError": { diff --git a/tests/test_tutorial/test_fastapi/test_limit_and_offset/test_tutorial001.py b/tests/test_tutorial/test_fastapi/test_limit_and_offset/test_tutorial001.py index 2709231504..d177c80c4c 100644 --- a/tests/test_tutorial/test_fastapi/test_limit_and_offset/test_tutorial001.py +++ b/tests/test_tutorial/test_fastapi/test_limit_and_offset/test_tutorial001.py @@ -1,3 +1,4 @@ +from dirty_equals import IsDict from fastapi.testclient import TestClient from sqlmodel import create_engine from sqlmodel.pool import StaticPool @@ -217,7 +218,16 @@ def test_tutorial(clear_sqlmodel): "properties": { "name": {"title": "Name", "type": "string"}, "secret_name": {"title": "Secret Name", "type": "string"}, - "age": {"title": "Age", "type": "integer"}, + "age": IsDict( + { + "title": "Age", + "anyOf": [{"type": "integer"}, {"type": "null"}], + } + ) + | IsDict( + # TODO: remove when deprecating Pydantic v1 + {"title": "Age", "type": "integer"} + ), }, }, "HeroRead": { @@ -227,7 +237,16 @@ def test_tutorial(clear_sqlmodel): "properties": { "name": {"title": "Name", "type": "string"}, "secret_name": {"title": "Secret Name", "type": "string"}, - "age": {"title": "Age", "type": "integer"}, + "age": IsDict( + { + "title": "Age", + "anyOf": [{"type": "integer"}, {"type": "null"}], + } + ) + | IsDict( + # TODO: remove when deprecating Pydantic v1 + {"title": "Age", "type": "integer"} + ), "id": {"title": "Id", "type": "integer"}, }, }, diff --git a/tests/test_tutorial/test_fastapi/test_limit_and_offset/test_tutorial001_py310.py b/tests/test_tutorial/test_fastapi/test_limit_and_offset/test_tutorial001_py310.py index ee0d89ac55..03086996ca 100644 --- a/tests/test_tutorial/test_fastapi/test_limit_and_offset/test_tutorial001_py310.py +++ b/tests/test_tutorial/test_fastapi/test_limit_and_offset/test_tutorial001_py310.py @@ -1,3 +1,4 @@ +from dirty_equals import IsDict from fastapi.testclient import TestClient from sqlmodel import create_engine from sqlmodel.pool import StaticPool @@ -220,7 +221,16 @@ def test_tutorial(clear_sqlmodel): "properties": { "name": {"title": "Name", "type": "string"}, "secret_name": {"title": "Secret Name", "type": "string"}, - "age": {"title": "Age", "type": "integer"}, + "age": IsDict( + { + "title": "Age", + "anyOf": [{"type": "integer"}, {"type": "null"}], + } + ) + | IsDict( + # TODO: remove when deprecating Pydantic v1 + {"title": "Age", "type": "integer"} + ), }, }, "HeroRead": { @@ -230,7 +240,16 @@ def test_tutorial(clear_sqlmodel): "properties": { "name": {"title": "Name", "type": "string"}, "secret_name": {"title": "Secret Name", "type": "string"}, - "age": {"title": "Age", "type": "integer"}, + "age": IsDict( + { + "title": "Age", + "anyOf": [{"type": "integer"}, {"type": "null"}], + } + ) + | IsDict( + # TODO: remove when deprecating Pydantic v1 + {"title": "Age", "type": "integer"} + ), "id": {"title": "Id", "type": "integer"}, }, }, diff --git a/tests/test_tutorial/test_fastapi/test_limit_and_offset/test_tutorial001_py39.py b/tests/test_tutorial/test_fastapi/test_limit_and_offset/test_tutorial001_py39.py index f4ef44abc5..f7e42e4e20 100644 --- a/tests/test_tutorial/test_fastapi/test_limit_and_offset/test_tutorial001_py39.py +++ b/tests/test_tutorial/test_fastapi/test_limit_and_offset/test_tutorial001_py39.py @@ -1,3 +1,4 @@ +from dirty_equals import IsDict from fastapi.testclient import TestClient from sqlmodel import create_engine from sqlmodel.pool import StaticPool @@ -220,7 +221,16 @@ def test_tutorial(clear_sqlmodel): "properties": { "name": {"title": "Name", "type": "string"}, "secret_name": {"title": "Secret Name", "type": "string"}, - "age": {"title": "Age", "type": "integer"}, + "age": IsDict( + { + "title": "Age", + "anyOf": [{"type": "integer"}, {"type": "null"}], + } + ) + | IsDict( + # TODO: remove when deprecating Pydantic v1 + {"title": "Age", "type": "integer"} + ), }, }, "HeroRead": { @@ -230,7 +240,16 @@ def test_tutorial(clear_sqlmodel): "properties": { "name": {"title": "Name", "type": "string"}, "secret_name": {"title": "Secret Name", "type": "string"}, - "age": {"title": "Age", "type": "integer"}, + "age": IsDict( + { + "title": "Age", + "anyOf": [{"type": "integer"}, {"type": "null"}], + } + ) + | IsDict( + # TODO: remove when deprecating Pydantic v1 + {"title": "Age", "type": "integer"} + ), "id": {"title": "Id", "type": "integer"}, }, }, diff --git a/tests/test_tutorial/test_fastapi/test_read_one/test_tutorial001.py b/tests/test_tutorial/test_fastapi/test_read_one/test_tutorial001.py index 5d2327095e..62fbb25a9c 100644 --- a/tests/test_tutorial/test_fastapi/test_read_one/test_tutorial001.py +++ b/tests/test_tutorial/test_fastapi/test_read_one/test_tutorial001.py @@ -1,3 +1,4 @@ +from dirty_equals import IsDict from fastapi.testclient import TestClient from sqlmodel import create_engine from sqlmodel.pool import StaticPool @@ -38,11 +39,10 @@ def test_tutorial(clear_sqlmodel): assert response.status_code == 404, response.text response = client.get("/openapi.json") - data = response.json() assert response.status_code == 200, response.text - assert data == { + assert response.json() == { "openapi": "3.1.0", "info": {"title": "FastAPI", "version": "0.1.0"}, "paths": { @@ -163,7 +163,16 @@ def test_tutorial(clear_sqlmodel): "properties": { "name": {"title": "Name", "type": "string"}, "secret_name": {"title": "Secret Name", "type": "string"}, - "age": {"title": "Age", "type": "integer"}, + "age": IsDict( + { + "title": "Age", + "anyOf": [{"type": "integer"}, {"type": "null"}], + } + ) + | IsDict( + # TODO: remove when deprecating Pydantic v1 + {"title": "Age", "type": "integer"} + ), }, }, "HeroRead": { @@ -173,7 +182,16 @@ def test_tutorial(clear_sqlmodel): "properties": { "name": {"title": "Name", "type": "string"}, "secret_name": {"title": "Secret Name", "type": "string"}, - "age": {"title": "Age", "type": "integer"}, + "age": IsDict( + { + "title": "Age", + "anyOf": [{"type": "integer"}, {"type": "null"}], + } + ) + | IsDict( + # TODO: remove when deprecating Pydantic v1 + {"title": "Age", "type": "integer"} + ), "id": {"title": "Id", "type": "integer"}, }, }, diff --git a/tests/test_tutorial/test_fastapi/test_read_one/test_tutorial001_py310.py b/tests/test_tutorial/test_fastapi/test_read_one/test_tutorial001_py310.py index 2e0a97e780..913d098882 100644 --- a/tests/test_tutorial/test_fastapi/test_read_one/test_tutorial001_py310.py +++ b/tests/test_tutorial/test_fastapi/test_read_one/test_tutorial001_py310.py @@ -1,3 +1,4 @@ +from dirty_equals import IsDict from fastapi.testclient import TestClient from sqlmodel import create_engine from sqlmodel.pool import StaticPool @@ -41,11 +42,10 @@ def test_tutorial(clear_sqlmodel): assert response.status_code == 404, response.text response = client.get("/openapi.json") - data = response.json() assert response.status_code == 200, response.text - assert data == { + assert response.json() == { "openapi": "3.1.0", "info": {"title": "FastAPI", "version": "0.1.0"}, "paths": { @@ -166,7 +166,16 @@ def test_tutorial(clear_sqlmodel): "properties": { "name": {"title": "Name", "type": "string"}, "secret_name": {"title": "Secret Name", "type": "string"}, - "age": {"title": "Age", "type": "integer"}, + "age": IsDict( + { + "title": "Age", + "anyOf": [{"type": "integer"}, {"type": "null"}], + } + ) + | IsDict( + # TODO: remove when deprecating Pydantic v1 + {"title": "Age", "type": "integer"} + ), }, }, "HeroRead": { @@ -176,7 +185,16 @@ def test_tutorial(clear_sqlmodel): "properties": { "name": {"title": "Name", "type": "string"}, "secret_name": {"title": "Secret Name", "type": "string"}, - "age": {"title": "Age", "type": "integer"}, + "age": IsDict( + { + "title": "Age", + "anyOf": [{"type": "integer"}, {"type": "null"}], + } + ) + | IsDict( + # TODO: remove when deprecating Pydantic v1 + {"title": "Age", "type": "integer"} + ), "id": {"title": "Id", "type": "integer"}, }, }, diff --git a/tests/test_tutorial/test_fastapi/test_read_one/test_tutorial001_py39.py b/tests/test_tutorial/test_fastapi/test_read_one/test_tutorial001_py39.py index a663eccac3..9bedf5c62d 100644 --- a/tests/test_tutorial/test_fastapi/test_read_one/test_tutorial001_py39.py +++ b/tests/test_tutorial/test_fastapi/test_read_one/test_tutorial001_py39.py @@ -1,3 +1,4 @@ +from dirty_equals import IsDict from fastapi.testclient import TestClient from sqlmodel import create_engine from sqlmodel.pool import StaticPool @@ -41,11 +42,10 @@ def test_tutorial(clear_sqlmodel): assert response.status_code == 404, response.text response = client.get("/openapi.json") - data = response.json() assert response.status_code == 200, response.text - assert data == { + assert response.json() == { "openapi": "3.1.0", "info": {"title": "FastAPI", "version": "0.1.0"}, "paths": { @@ -166,7 +166,16 @@ def test_tutorial(clear_sqlmodel): "properties": { "name": {"title": "Name", "type": "string"}, "secret_name": {"title": "Secret Name", "type": "string"}, - "age": {"title": "Age", "type": "integer"}, + "age": IsDict( + { + "title": "Age", + "anyOf": [{"type": "integer"}, {"type": "null"}], + } + ) + | IsDict( + # TODO: remove when deprecating Pydantic v1 + {"title": "Age", "type": "integer"} + ), }, }, "HeroRead": { @@ -176,7 +185,16 @@ def test_tutorial(clear_sqlmodel): "properties": { "name": {"title": "Name", "type": "string"}, "secret_name": {"title": "Secret Name", "type": "string"}, - "age": {"title": "Age", "type": "integer"}, + "age": IsDict( + { + "title": "Age", + "anyOf": [{"type": "integer"}, {"type": "null"}], + } + ) + | IsDict( + # TODO: remove when deprecating Pydantic v1 + {"title": "Age", "type": "integer"} + ), "id": {"title": "Id", "type": "integer"}, }, }, diff --git a/tests/test_tutorial/test_fastapi/test_relationships/test_tutorial001.py b/tests/test_tutorial/test_fastapi/test_relationships/test_tutorial001.py index fb08b9a5fd..18fe0e7c44 100644 --- a/tests/test_tutorial/test_fastapi/test_relationships/test_tutorial001.py +++ b/tests/test_tutorial/test_fastapi/test_relationships/test_tutorial001.py @@ -1,3 +1,4 @@ +from dirty_equals import IsDict from fastapi.testclient import TestClient from sqlmodel import create_engine from sqlmodel.pool import StaticPool @@ -531,8 +532,26 @@ def test_tutorial(clear_sqlmodel): "properties": { "name": {"title": "Name", "type": "string"}, "secret_name": {"title": "Secret Name", "type": "string"}, - "age": {"title": "Age", "type": "integer"}, - "team_id": {"title": "Team Id", "type": "integer"}, + "age": IsDict( + { + "title": "Age", + "anyOf": [{"type": "integer"}, {"type": "null"}], + } + ) + | IsDict( + # TODO: remove when deprecating Pydantic v1 + {"title": "Age", "type": "integer"} + ), + "team_id": IsDict( + { + "title": "Team Id", + "anyOf": [{"type": "integer"}, {"type": "null"}], + } + ) + | IsDict( + # TODO: remove when deprecating Pydantic v1 + {"title": "Team Id", "type": "integer"} + ), }, }, "HeroRead": { @@ -542,8 +561,26 @@ def test_tutorial(clear_sqlmodel): "properties": { "name": {"title": "Name", "type": "string"}, "secret_name": {"title": "Secret Name", "type": "string"}, - "age": {"title": "Age", "type": "integer"}, - "team_id": {"title": "Team Id", "type": "integer"}, + "age": IsDict( + { + "title": "Age", + "anyOf": [{"type": "integer"}, {"type": "null"}], + } + ) + | IsDict( + # TODO: remove when deprecating Pydantic v1 + {"title": "Age", "type": "integer"} + ), + "team_id": IsDict( + { + "title": "Team Id", + "anyOf": [{"type": "integer"}, {"type": "null"}], + } + ) + | IsDict( + # TODO: remove when deprecating Pydantic v1 + {"title": "Team Id", "type": "integer"} + ), "id": {"title": "Id", "type": "integer"}, }, }, @@ -554,20 +591,85 @@ def test_tutorial(clear_sqlmodel): "properties": { "name": {"title": "Name", "type": "string"}, "secret_name": {"title": "Secret Name", "type": "string"}, - "age": {"title": "Age", "type": "integer"}, - "team_id": {"title": "Team Id", "type": "integer"}, + "age": IsDict( + { + "title": "Age", + "anyOf": [{"type": "integer"}, {"type": "null"}], + } + ) + | IsDict( + # TODO: remove when deprecating Pydantic v1 + {"title": "Age", "type": "integer"} + ), + "team_id": IsDict( + { + "title": "Team Id", + "anyOf": [{"type": "integer"}, {"type": "null"}], + } + ) + | IsDict( + # TODO: remove when deprecating Pydantic v1 + {"title": "Team Id", "type": "integer"} + ), "id": {"title": "Id", "type": "integer"}, - "team": {"$ref": "#/components/schemas/TeamRead"}, + "team": IsDict( + { + "anyOf": [ + {"$ref": "#/components/schemas/TeamRead"}, + {"type": "null"}, + ] + } + ) + | IsDict( + # TODO: remove when deprecating Pydantic v1 + {"$ref": "#/components/schemas/TeamRead"} + ), }, }, "HeroUpdate": { "title": "HeroUpdate", "type": "object", "properties": { - "name": {"title": "Name", "type": "string"}, - "secret_name": {"title": "Secret Name", "type": "string"}, - "age": {"title": "Age", "type": "integer"}, - "team_id": {"title": "Team Id", "type": "integer"}, + "name": IsDict( + { + "title": "Name", + "anyOf": [{"type": "string"}, {"type": "null"}], + } + ) + | IsDict( + # TODO: remove when deprecating Pydantic v1 + {"title": "Name", "type": "string"} + ), + "secret_name": IsDict( + { + "title": "Secret Name", + "anyOf": [{"type": "string"}, {"type": "null"}], + } + ) + | IsDict( + # TODO: remove when deprecating Pydantic v1 + {"title": "Secret Name", "type": "string"} + ), + "age": IsDict( + { + "title": "Age", + "anyOf": [{"type": "integer"}, {"type": "null"}], + } + ) + | IsDict( + # TODO: remove when deprecating Pydantic v1 + {"title": "Age", "type": "integer"} + ), + "team_id": IsDict( + { + "title": "Team Id", + "anyOf": [{"type": "integer"}, {"type": "null"}], + } + ) + | IsDict( + # TODO: remove when deprecating Pydantic v1 + {"title": "Team Id", "type": "integer"} + ), }, }, "TeamCreate": { @@ -609,9 +711,36 @@ def test_tutorial(clear_sqlmodel): "title": "TeamUpdate", "type": "object", "properties": { - "id": {"title": "Id", "type": "integer"}, - "name": {"title": "Name", "type": "string"}, - "headquarters": {"title": "Headquarters", "type": "string"}, + "id": IsDict( + { + "title": "Id", + "anyOf": [{"type": "integer"}, {"type": "null"}], + } + ) + | IsDict( + # TODO: remove when deprecating Pydantic v1 + {"title": "Id", "type": "integer"} + ), + "name": IsDict( + { + "title": "Name", + "anyOf": [{"type": "string"}, {"type": "null"}], + } + ) + | IsDict( + # TODO: remove when deprecating Pydantic v1 + {"title": "Name", "type": "string"} + ), + "headquarters": IsDict( + { + "title": "Headquarters", + "anyOf": [{"type": "string"}, {"type": "null"}], + } + ) + | IsDict( + # TODO: remove when deprecating Pydantic v1 + {"title": "Headquarters", "type": "string"} + ), }, }, "ValidationError": { diff --git a/tests/test_tutorial/test_fastapi/test_relationships/test_tutorial001_py310.py b/tests/test_tutorial/test_fastapi/test_relationships/test_tutorial001_py310.py index dae7db3378..282c807096 100644 --- a/tests/test_tutorial/test_fastapi/test_relationships/test_tutorial001_py310.py +++ b/tests/test_tutorial/test_fastapi/test_relationships/test_tutorial001_py310.py @@ -1,3 +1,4 @@ +from dirty_equals import IsDict from fastapi.testclient import TestClient from sqlmodel import create_engine from sqlmodel.pool import StaticPool @@ -534,8 +535,26 @@ def test_tutorial(clear_sqlmodel): "properties": { "name": {"title": "Name", "type": "string"}, "secret_name": {"title": "Secret Name", "type": "string"}, - "age": {"title": "Age", "type": "integer"}, - "team_id": {"title": "Team Id", "type": "integer"}, + "age": IsDict( + { + "title": "Age", + "anyOf": [{"type": "integer"}, {"type": "null"}], + } + ) + | IsDict( + # TODO: remove when deprecating Pydantic v1 + {"title": "Age", "type": "integer"} + ), + "team_id": IsDict( + { + "title": "Team Id", + "anyOf": [{"type": "integer"}, {"type": "null"}], + } + ) + | IsDict( + # TODO: remove when deprecating Pydantic v1 + {"title": "Team Id", "type": "integer"} + ), }, }, "HeroRead": { @@ -545,8 +564,26 @@ def test_tutorial(clear_sqlmodel): "properties": { "name": {"title": "Name", "type": "string"}, "secret_name": {"title": "Secret Name", "type": "string"}, - "age": {"title": "Age", "type": "integer"}, - "team_id": {"title": "Team Id", "type": "integer"}, + "age": IsDict( + { + "title": "Age", + "anyOf": [{"type": "integer"}, {"type": "null"}], + } + ) + | IsDict( + # TODO: remove when deprecating Pydantic v1 + {"title": "Age", "type": "integer"} + ), + "team_id": IsDict( + { + "title": "Team Id", + "anyOf": [{"type": "integer"}, {"type": "null"}], + } + ) + | IsDict( + # TODO: remove when deprecating Pydantic v1 + {"title": "Team Id", "type": "integer"} + ), "id": {"title": "Id", "type": "integer"}, }, }, @@ -557,20 +594,85 @@ def test_tutorial(clear_sqlmodel): "properties": { "name": {"title": "Name", "type": "string"}, "secret_name": {"title": "Secret Name", "type": "string"}, - "age": {"title": "Age", "type": "integer"}, - "team_id": {"title": "Team Id", "type": "integer"}, + "age": IsDict( + { + "title": "Age", + "anyOf": [{"type": "integer"}, {"type": "null"}], + } + ) + | IsDict( + # TODO: remove when deprecating Pydantic v1 + {"title": "Age", "type": "integer"} + ), + "team_id": IsDict( + { + "title": "Team Id", + "anyOf": [{"type": "integer"}, {"type": "null"}], + } + ) + | IsDict( + # TODO: remove when deprecating Pydantic v1 + {"title": "Team Id", "type": "integer"} + ), "id": {"title": "Id", "type": "integer"}, - "team": {"$ref": "#/components/schemas/TeamRead"}, + "team": IsDict( + { + "anyOf": [ + {"$ref": "#/components/schemas/TeamRead"}, + {"type": "null"}, + ] + } + ) + | IsDict( + # TODO: remove when deprecating Pydantic v1 + {"$ref": "#/components/schemas/TeamRead"} + ), }, }, "HeroUpdate": { "title": "HeroUpdate", "type": "object", "properties": { - "name": {"title": "Name", "type": "string"}, - "secret_name": {"title": "Secret Name", "type": "string"}, - "age": {"title": "Age", "type": "integer"}, - "team_id": {"title": "Team Id", "type": "integer"}, + "name": IsDict( + { + "title": "Name", + "anyOf": [{"type": "string"}, {"type": "null"}], + } + ) + | IsDict( + # TODO: remove when deprecating Pydantic v1 + {"title": "Name", "type": "string"} + ), + "secret_name": IsDict( + { + "title": "Secret Name", + "anyOf": [{"type": "string"}, {"type": "null"}], + } + ) + | IsDict( + # TODO: remove when deprecating Pydantic v1 + {"title": "Secret Name", "type": "string"} + ), + "age": IsDict( + { + "title": "Age", + "anyOf": [{"type": "integer"}, {"type": "null"}], + } + ) + | IsDict( + # TODO: remove when deprecating Pydantic v1 + {"title": "Age", "type": "integer"} + ), + "team_id": IsDict( + { + "title": "Team Id", + "anyOf": [{"type": "integer"}, {"type": "null"}], + } + ) + | IsDict( + # TODO: remove when deprecating Pydantic v1 + {"title": "Team Id", "type": "integer"} + ), }, }, "TeamCreate": { @@ -612,9 +714,36 @@ def test_tutorial(clear_sqlmodel): "title": "TeamUpdate", "type": "object", "properties": { - "id": {"title": "Id", "type": "integer"}, - "name": {"title": "Name", "type": "string"}, - "headquarters": {"title": "Headquarters", "type": "string"}, + "id": IsDict( + { + "title": "Id", + "anyOf": [{"type": "integer"}, {"type": "null"}], + } + ) + | IsDict( + # TODO: remove when deprecating Pydantic v1 + {"title": "Id", "type": "integer"} + ), + "name": IsDict( + { + "title": "Name", + "anyOf": [{"type": "string"}, {"type": "null"}], + } + ) + | IsDict( + # TODO: remove when deprecating Pydantic v1 + {"title": "Name", "type": "string"} + ), + "headquarters": IsDict( + { + "title": "Headquarters", + "anyOf": [{"type": "string"}, {"type": "null"}], + } + ) + | IsDict( + # TODO: remove when deprecating Pydantic v1 + {"title": "Headquarters", "type": "string"} + ), }, }, "ValidationError": { diff --git a/tests/test_tutorial/test_fastapi/test_relationships/test_tutorial001_py39.py b/tests/test_tutorial/test_fastapi/test_relationships/test_tutorial001_py39.py index 72dee33434..f71ef04721 100644 --- a/tests/test_tutorial/test_fastapi/test_relationships/test_tutorial001_py39.py +++ b/tests/test_tutorial/test_fastapi/test_relationships/test_tutorial001_py39.py @@ -1,3 +1,4 @@ +from dirty_equals import IsDict from fastapi.testclient import TestClient from sqlmodel import create_engine from sqlmodel.pool import StaticPool @@ -534,8 +535,26 @@ def test_tutorial(clear_sqlmodel): "properties": { "name": {"title": "Name", "type": "string"}, "secret_name": {"title": "Secret Name", "type": "string"}, - "age": {"title": "Age", "type": "integer"}, - "team_id": {"title": "Team Id", "type": "integer"}, + "age": IsDict( + { + "title": "Age", + "anyOf": [{"type": "integer"}, {"type": "null"}], + } + ) + | IsDict( + # TODO: remove when deprecating Pydantic v1 + {"title": "Age", "type": "integer"} + ), + "team_id": IsDict( + { + "title": "Team Id", + "anyOf": [{"type": "integer"}, {"type": "null"}], + } + ) + | IsDict( + # TODO: remove when deprecating Pydantic v1 + {"title": "Team Id", "type": "integer"} + ), }, }, "HeroRead": { @@ -545,8 +564,26 @@ def test_tutorial(clear_sqlmodel): "properties": { "name": {"title": "Name", "type": "string"}, "secret_name": {"title": "Secret Name", "type": "string"}, - "age": {"title": "Age", "type": "integer"}, - "team_id": {"title": "Team Id", "type": "integer"}, + "age": IsDict( + { + "title": "Age", + "anyOf": [{"type": "integer"}, {"type": "null"}], + } + ) + | IsDict( + # TODO: remove when deprecating Pydantic v1 + {"title": "Age", "type": "integer"} + ), + "team_id": IsDict( + { + "title": "Team Id", + "anyOf": [{"type": "integer"}, {"type": "null"}], + } + ) + | IsDict( + # TODO: remove when deprecating Pydantic v1 + {"title": "Team Id", "type": "integer"} + ), "id": {"title": "Id", "type": "integer"}, }, }, @@ -557,20 +594,85 @@ def test_tutorial(clear_sqlmodel): "properties": { "name": {"title": "Name", "type": "string"}, "secret_name": {"title": "Secret Name", "type": "string"}, - "age": {"title": "Age", "type": "integer"}, - "team_id": {"title": "Team Id", "type": "integer"}, + "age": IsDict( + { + "title": "Age", + "anyOf": [{"type": "integer"}, {"type": "null"}], + } + ) + | IsDict( + # TODO: remove when deprecating Pydantic v1 + {"title": "Age", "type": "integer"} + ), + "team_id": IsDict( + { + "title": "Team Id", + "anyOf": [{"type": "integer"}, {"type": "null"}], + } + ) + | IsDict( + # TODO: remove when deprecating Pydantic v1 + {"title": "Team Id", "type": "integer"} + ), "id": {"title": "Id", "type": "integer"}, - "team": {"$ref": "#/components/schemas/TeamRead"}, + "team": IsDict( + { + "anyOf": [ + {"$ref": "#/components/schemas/TeamRead"}, + {"type": "null"}, + ] + } + ) + | IsDict( + # TODO: remove when deprecating Pydantic v1 + {"$ref": "#/components/schemas/TeamRead"} + ), }, }, "HeroUpdate": { "title": "HeroUpdate", "type": "object", "properties": { - "name": {"title": "Name", "type": "string"}, - "secret_name": {"title": "Secret Name", "type": "string"}, - "age": {"title": "Age", "type": "integer"}, - "team_id": {"title": "Team Id", "type": "integer"}, + "name": IsDict( + { + "title": "Name", + "anyOf": [{"type": "string"}, {"type": "null"}], + } + ) + | IsDict( + # TODO: remove when deprecating Pydantic v1 + {"title": "Name", "type": "string"} + ), + "secret_name": IsDict( + { + "title": "Secret Name", + "anyOf": [{"type": "string"}, {"type": "null"}], + } + ) + | IsDict( + # TODO: remove when deprecating Pydantic v1 + {"title": "Secret Name", "type": "string"} + ), + "age": IsDict( + { + "title": "Age", + "anyOf": [{"type": "integer"}, {"type": "null"}], + } + ) + | IsDict( + # TODO: remove when deprecating Pydantic v1 + {"title": "Age", "type": "integer"} + ), + "team_id": IsDict( + { + "title": "Team Id", + "anyOf": [{"type": "integer"}, {"type": "null"}], + } + ) + | IsDict( + # TODO: remove when deprecating Pydantic v1 + {"title": "Team Id", "type": "integer"} + ), }, }, "TeamCreate": { @@ -612,9 +714,36 @@ def test_tutorial(clear_sqlmodel): "title": "TeamUpdate", "type": "object", "properties": { - "id": {"title": "Id", "type": "integer"}, - "name": {"title": "Name", "type": "string"}, - "headquarters": {"title": "Headquarters", "type": "string"}, + "id": IsDict( + { + "title": "Id", + "anyOf": [{"type": "integer"}, {"type": "null"}], + } + ) + | IsDict( + # TODO: remove when deprecating Pydantic v1 + {"title": "Id", "type": "integer"} + ), + "name": IsDict( + { + "title": "Name", + "anyOf": [{"type": "string"}, {"type": "null"}], + } + ) + | IsDict( + # TODO: remove when deprecating Pydantic v1 + {"title": "Name", "type": "string"} + ), + "headquarters": IsDict( + { + "title": "Headquarters", + "anyOf": [{"type": "string"}, {"type": "null"}], + } + ) + | IsDict( + # TODO: remove when deprecating Pydantic v1 + {"title": "Headquarters", "type": "string"} + ), }, }, "ValidationError": { diff --git a/tests/test_tutorial/test_fastapi/test_response_model/test_tutorial001.py b/tests/test_tutorial/test_fastapi/test_response_model/test_tutorial001.py index ca8a41845e..8f273bbd93 100644 --- a/tests/test_tutorial/test_fastapi/test_response_model/test_tutorial001.py +++ b/tests/test_tutorial/test_fastapi/test_response_model/test_tutorial001.py @@ -1,3 +1,4 @@ +from dirty_equals import IsDict from fastapi.testclient import TestClient from sqlmodel import create_engine from sqlmodel.pool import StaticPool @@ -31,11 +32,10 @@ def test_tutorial(clear_sqlmodel): assert data[0]["secret_name"] == hero_data["secret_name"] response = client.get("/openapi.json") - data = response.json() assert response.status_code == 200, response.text - assert data == { + assert response.json() == { "openapi": "3.1.0", "info": {"title": "FastAPI", "version": "0.1.0"}, "paths": { @@ -114,10 +114,28 @@ def test_tutorial(clear_sqlmodel): "required": ["name", "secret_name"], "type": "object", "properties": { - "id": {"title": "Id", "type": "integer"}, + "id": IsDict( + { + "title": "Id", + "anyOf": [{"type": "integer"}, {"type": "null"}], + } + ) + | IsDict( + # TODO: remove when deprecating Pydantic v1 + {"title": "Id", "type": "integer"} + ), "name": {"title": "Name", "type": "string"}, "secret_name": {"title": "Secret Name", "type": "string"}, - "age": {"title": "Age", "type": "integer"}, + "age": IsDict( + { + "title": "Age", + "anyOf": [{"type": "integer"}, {"type": "null"}], + } + ) + | IsDict( + # TODO: remove when deprecating Pydantic v1 + {"title": "Age", "type": "integer"} + ), }, }, "ValidationError": { diff --git a/tests/test_tutorial/test_fastapi/test_response_model/test_tutorial001_py310.py b/tests/test_tutorial/test_fastapi/test_response_model/test_tutorial001_py310.py index 4acb0068a1..d249cc4e90 100644 --- a/tests/test_tutorial/test_fastapi/test_response_model/test_tutorial001_py310.py +++ b/tests/test_tutorial/test_fastapi/test_response_model/test_tutorial001_py310.py @@ -1,3 +1,4 @@ +from dirty_equals import IsDict from fastapi.testclient import TestClient from sqlmodel import create_engine from sqlmodel.pool import StaticPool @@ -34,11 +35,10 @@ def test_tutorial(clear_sqlmodel): assert data[0]["secret_name"] == hero_data["secret_name"] response = client.get("/openapi.json") - data = response.json() assert response.status_code == 200, response.text - assert data == { + assert response.json() == { "openapi": "3.1.0", "info": {"title": "FastAPI", "version": "0.1.0"}, "paths": { @@ -117,10 +117,28 @@ def test_tutorial(clear_sqlmodel): "required": ["name", "secret_name"], "type": "object", "properties": { - "id": {"title": "Id", "type": "integer"}, + "id": IsDict( + { + "title": "Id", + "anyOf": [{"type": "integer"}, {"type": "null"}], + } + ) + | IsDict( + # TODO: remove when deprecating Pydantic v1 + {"title": "Id", "type": "integer"} + ), "name": {"title": "Name", "type": "string"}, "secret_name": {"title": "Secret Name", "type": "string"}, - "age": {"title": "Age", "type": "integer"}, + "age": IsDict( + { + "title": "Age", + "anyOf": [{"type": "integer"}, {"type": "null"}], + } + ) + | IsDict( + # TODO: remove when deprecating Pydantic v1 + {"title": "Age", "type": "integer"} + ), }, }, "ValidationError": { diff --git a/tests/test_tutorial/test_fastapi/test_response_model/test_tutorial001_py39.py b/tests/test_tutorial/test_fastapi/test_response_model/test_tutorial001_py39.py index 20f3f52313..b9fb2be03f 100644 --- a/tests/test_tutorial/test_fastapi/test_response_model/test_tutorial001_py39.py +++ b/tests/test_tutorial/test_fastapi/test_response_model/test_tutorial001_py39.py @@ -1,3 +1,4 @@ +from dirty_equals import IsDict from fastapi.testclient import TestClient from sqlmodel import create_engine from sqlmodel.pool import StaticPool @@ -34,11 +35,10 @@ def test_tutorial(clear_sqlmodel): assert data[0]["secret_name"] == hero_data["secret_name"] response = client.get("/openapi.json") - data = response.json() assert response.status_code == 200, response.text - assert data == { + assert response.json() == { "openapi": "3.1.0", "info": {"title": "FastAPI", "version": "0.1.0"}, "paths": { @@ -117,10 +117,28 @@ def test_tutorial(clear_sqlmodel): "required": ["name", "secret_name"], "type": "object", "properties": { - "id": {"title": "Id", "type": "integer"}, + "id": IsDict( + { + "title": "Id", + "anyOf": [{"type": "integer"}, {"type": "null"}], + } + ) + | IsDict( + # TODO: remove when deprecating Pydantic v1 + {"title": "Id", "type": "integer"} + ), "name": {"title": "Name", "type": "string"}, "secret_name": {"title": "Secret Name", "type": "string"}, - "age": {"title": "Age", "type": "integer"}, + "age": IsDict( + { + "title": "Age", + "anyOf": [{"type": "integer"}, {"type": "null"}], + } + ) + | IsDict( + # TODO: remove when deprecating Pydantic v1 + {"title": "Age", "type": "integer"} + ), }, }, "ValidationError": { diff --git a/tests/test_tutorial/test_fastapi/test_session_with_dependency/test_tutorial001.py b/tests/test_tutorial/test_fastapi/test_session_with_dependency/test_tutorial001.py index 6f97cbf92b..441cc42b28 100644 --- a/tests/test_tutorial/test_fastapi/test_session_with_dependency/test_tutorial001.py +++ b/tests/test_tutorial/test_fastapi/test_session_with_dependency/test_tutorial001.py @@ -1,3 +1,4 @@ +from dirty_equals import IsDict from fastapi.testclient import TestClient from sqlmodel import create_engine from sqlmodel.pool import StaticPool @@ -284,7 +285,16 @@ def test_tutorial(clear_sqlmodel): "properties": { "name": {"title": "Name", "type": "string"}, "secret_name": {"title": "Secret Name", "type": "string"}, - "age": {"title": "Age", "type": "integer"}, + "age": IsDict( + { + "title": "Age", + "anyOf": [{"type": "integer"}, {"type": "null"}], + } + ) + | IsDict( + # TODO: remove when deprecating Pydantic v1 + {"title": "Age", "type": "integer"} + ), }, }, "HeroRead": { @@ -294,7 +304,16 @@ def test_tutorial(clear_sqlmodel): "properties": { "name": {"title": "Name", "type": "string"}, "secret_name": {"title": "Secret Name", "type": "string"}, - "age": {"title": "Age", "type": "integer"}, + "age": IsDict( + { + "title": "Age", + "anyOf": [{"type": "integer"}, {"type": "null"}], + } + ) + | IsDict( + # TODO: remove when deprecating Pydantic v1 + {"title": "Age", "type": "integer"} + ), "id": {"title": "Id", "type": "integer"}, }, }, @@ -302,9 +321,36 @@ def test_tutorial(clear_sqlmodel): "title": "HeroUpdate", "type": "object", "properties": { - "name": {"title": "Name", "type": "string"}, - "secret_name": {"title": "Secret Name", "type": "string"}, - "age": {"title": "Age", "type": "integer"}, + "name": IsDict( + { + "title": "Name", + "anyOf": [{"type": "string"}, {"type": "null"}], + } + ) + | IsDict( + # TODO: remove when deprecating Pydantic v1 + {"title": "Name", "type": "string"} + ), + "secret_name": IsDict( + { + "title": "Secret Name", + "anyOf": [{"type": "string"}, {"type": "null"}], + } + ) + | IsDict( + # TODO: remove when deprecating Pydantic v1 + {"title": "Secret Name", "type": "string"} + ), + "age": IsDict( + { + "title": "Age", + "anyOf": [{"type": "integer"}, {"type": "null"}], + } + ) + | IsDict( + # TODO: remove when deprecating Pydantic v1 + {"title": "Age", "type": "integer"} + ), }, }, "ValidationError": { diff --git a/tests/test_tutorial/test_fastapi/test_session_with_dependency/test_tutorial001_py310.py b/tests/test_tutorial/test_fastapi/test_session_with_dependency/test_tutorial001_py310.py index f0c5416bdf..7c427a1c67 100644 --- a/tests/test_tutorial/test_fastapi/test_session_with_dependency/test_tutorial001_py310.py +++ b/tests/test_tutorial/test_fastapi/test_session_with_dependency/test_tutorial001_py310.py @@ -1,3 +1,4 @@ +from dirty_equals import IsDict from fastapi.testclient import TestClient from sqlmodel import create_engine from sqlmodel.pool import StaticPool @@ -289,7 +290,16 @@ def test_tutorial(clear_sqlmodel): "properties": { "name": {"title": "Name", "type": "string"}, "secret_name": {"title": "Secret Name", "type": "string"}, - "age": {"title": "Age", "type": "integer"}, + "age": IsDict( + { + "title": "Age", + "anyOf": [{"type": "integer"}, {"type": "null"}], + } + ) + | IsDict( + # TODO: remove when deprecating Pydantic v1 + {"title": "Age", "type": "integer"} + ), }, }, "HeroRead": { @@ -299,7 +309,16 @@ def test_tutorial(clear_sqlmodel): "properties": { "name": {"title": "Name", "type": "string"}, "secret_name": {"title": "Secret Name", "type": "string"}, - "age": {"title": "Age", "type": "integer"}, + "age": IsDict( + { + "title": "Age", + "anyOf": [{"type": "integer"}, {"type": "null"}], + } + ) + | IsDict( + # TODO: remove when deprecating Pydantic v1 + {"title": "Age", "type": "integer"} + ), "id": {"title": "Id", "type": "integer"}, }, }, @@ -307,9 +326,36 @@ def test_tutorial(clear_sqlmodel): "title": "HeroUpdate", "type": "object", "properties": { - "name": {"title": "Name", "type": "string"}, - "secret_name": {"title": "Secret Name", "type": "string"}, - "age": {"title": "Age", "type": "integer"}, + "name": IsDict( + { + "title": "Name", + "anyOf": [{"type": "string"}, {"type": "null"}], + } + ) + | IsDict( + # TODO: remove when deprecating Pydantic v1 + {"title": "Name", "type": "string"} + ), + "secret_name": IsDict( + { + "title": "Secret Name", + "anyOf": [{"type": "string"}, {"type": "null"}], + } + ) + | IsDict( + # TODO: remove when deprecating Pydantic v1 + {"title": "Secret Name", "type": "string"} + ), + "age": IsDict( + { + "title": "Age", + "anyOf": [{"type": "integer"}, {"type": "null"}], + } + ) + | IsDict( + # TODO: remove when deprecating Pydantic v1 + {"title": "Age", "type": "integer"} + ), }, }, "ValidationError": { diff --git a/tests/test_tutorial/test_fastapi/test_session_with_dependency/test_tutorial001_py39.py b/tests/test_tutorial/test_fastapi/test_session_with_dependency/test_tutorial001_py39.py index 5b911c8462..ea63f52c41 100644 --- a/tests/test_tutorial/test_fastapi/test_session_with_dependency/test_tutorial001_py39.py +++ b/tests/test_tutorial/test_fastapi/test_session_with_dependency/test_tutorial001_py39.py @@ -1,3 +1,4 @@ +from dirty_equals import IsDict from fastapi.testclient import TestClient from sqlmodel import create_engine from sqlmodel.pool import StaticPool @@ -289,7 +290,16 @@ def test_tutorial(clear_sqlmodel): "properties": { "name": {"title": "Name", "type": "string"}, "secret_name": {"title": "Secret Name", "type": "string"}, - "age": {"title": "Age", "type": "integer"}, + "age": IsDict( + { + "title": "Age", + "anyOf": [{"type": "integer"}, {"type": "null"}], + } + ) + | IsDict( + # TODO: remove when deprecating Pydantic v1 + {"title": "Age", "type": "integer"} + ), }, }, "HeroRead": { @@ -299,7 +309,16 @@ def test_tutorial(clear_sqlmodel): "properties": { "name": {"title": "Name", "type": "string"}, "secret_name": {"title": "Secret Name", "type": "string"}, - "age": {"title": "Age", "type": "integer"}, + "age": IsDict( + { + "title": "Age", + "anyOf": [{"type": "integer"}, {"type": "null"}], + } + ) + | IsDict( + # TODO: remove when deprecating Pydantic v1 + {"title": "Age", "type": "integer"} + ), "id": {"title": "Id", "type": "integer"}, }, }, @@ -307,9 +326,36 @@ def test_tutorial(clear_sqlmodel): "title": "HeroUpdate", "type": "object", "properties": { - "name": {"title": "Name", "type": "string"}, - "secret_name": {"title": "Secret Name", "type": "string"}, - "age": {"title": "Age", "type": "integer"}, + "name": IsDict( + { + "title": "Name", + "anyOf": [{"type": "string"}, {"type": "null"}], + } + ) + | IsDict( + # TODO: remove when deprecating Pydantic v1 + {"title": "Name", "type": "string"} + ), + "secret_name": IsDict( + { + "title": "Secret Name", + "anyOf": [{"type": "string"}, {"type": "null"}], + } + ) + | IsDict( + # TODO: remove when deprecating Pydantic v1 + {"title": "Secret Name", "type": "string"} + ), + "age": IsDict( + { + "title": "Age", + "anyOf": [{"type": "integer"}, {"type": "null"}], + } + ) + | IsDict( + # TODO: remove when deprecating Pydantic v1 + {"title": "Age", "type": "integer"} + ), }, }, "ValidationError": { diff --git a/tests/test_tutorial/test_fastapi/test_teams/test_tutorial001.py b/tests/test_tutorial/test_fastapi/test_teams/test_tutorial001.py index 42f87cef76..7ac06f6245 100644 --- a/tests/test_tutorial/test_fastapi/test_teams/test_tutorial001.py +++ b/tests/test_tutorial/test_fastapi/test_teams/test_tutorial001.py @@ -1,3 +1,4 @@ +from dirty_equals import IsDict from fastapi.testclient import TestClient from sqlmodel import create_engine from sqlmodel.pool import StaticPool @@ -518,8 +519,26 @@ def test_tutorial(clear_sqlmodel): "properties": { "name": {"title": "Name", "type": "string"}, "secret_name": {"title": "Secret Name", "type": "string"}, - "age": {"title": "Age", "type": "integer"}, - "team_id": {"title": "Team Id", "type": "integer"}, + "age": IsDict( + { + "title": "Age", + "anyOf": [{"type": "integer"}, {"type": "null"}], + } + ) + | IsDict( + # TODO: remove when deprecating Pydantic v1 + {"title": "Age", "type": "integer"} + ), + "team_id": IsDict( + { + "title": "Team Id", + "anyOf": [{"type": "integer"}, {"type": "null"}], + } + ) + | IsDict( + # TODO: remove when deprecating Pydantic v1 + {"title": "Team Id", "type": "integer"} + ), }, }, "HeroRead": { @@ -529,8 +548,26 @@ def test_tutorial(clear_sqlmodel): "properties": { "name": {"title": "Name", "type": "string"}, "secret_name": {"title": "Secret Name", "type": "string"}, - "age": {"title": "Age", "type": "integer"}, - "team_id": {"title": "Team Id", "type": "integer"}, + "age": IsDict( + { + "title": "Age", + "anyOf": [{"type": "integer"}, {"type": "null"}], + } + ) + | IsDict( + # TODO: remove when deprecating Pydantic v1 + {"title": "Age", "type": "integer"} + ), + "team_id": IsDict( + { + "title": "Team Id", + "anyOf": [{"type": "integer"}, {"type": "null"}], + } + ) + | IsDict( + # TODO: remove when deprecating Pydantic v1 + {"title": "Team Id", "type": "integer"} + ), "id": {"title": "Id", "type": "integer"}, }, }, @@ -538,10 +575,46 @@ def test_tutorial(clear_sqlmodel): "title": "HeroUpdate", "type": "object", "properties": { - "name": {"title": "Name", "type": "string"}, - "secret_name": {"title": "Secret Name", "type": "string"}, - "age": {"title": "Age", "type": "integer"}, - "team_id": {"title": "Team Id", "type": "integer"}, + "name": IsDict( + { + "title": "Name", + "anyOf": [{"type": "string"}, {"type": "null"}], + } + ) + | IsDict( + # TODO: remove when deprecating Pydantic v1 + {"title": "Name", "type": "string"} + ), + "secret_name": IsDict( + { + "title": "Secret Name", + "anyOf": [{"type": "string"}, {"type": "null"}], + } + ) + | IsDict( + # TODO: remove when deprecating Pydantic v1 + {"title": "Secret Name", "type": "string"} + ), + "age": IsDict( + { + "title": "Age", + "anyOf": [{"type": "integer"}, {"type": "null"}], + } + ) + | IsDict( + # TODO: remove when deprecating Pydantic v1 + {"title": "Age", "type": "integer"} + ), + "team_id": IsDict( + { + "title": "Team Id", + "anyOf": [{"type": "integer"}, {"type": "null"}], + } + ) + | IsDict( + # TODO: remove when deprecating Pydantic v1 + {"title": "Team Id", "type": "integer"} + ), }, }, "TeamCreate": { @@ -567,8 +640,26 @@ def test_tutorial(clear_sqlmodel): "title": "TeamUpdate", "type": "object", "properties": { - "name": {"title": "Name", "type": "string"}, - "headquarters": {"title": "Headquarters", "type": "string"}, + "name": IsDict( + { + "title": "Name", + "anyOf": [{"type": "string"}, {"type": "null"}], + } + ) + | IsDict( + # TODO: remove when deprecating Pydantic v1 + {"title": "Name", "type": "string"} + ), + "headquarters": IsDict( + { + "title": "Headquarters", + "anyOf": [{"type": "string"}, {"type": "null"}], + } + ) + | IsDict( + # TODO: remove when deprecating Pydantic v1 + {"title": "Headquarters", "type": "string"} + ), }, }, "ValidationError": { diff --git a/tests/test_tutorial/test_fastapi/test_teams/test_tutorial001_py310.py b/tests/test_tutorial/test_fastapi/test_teams/test_tutorial001_py310.py index 6cec87a0a7..e875162e00 100644 --- a/tests/test_tutorial/test_fastapi/test_teams/test_tutorial001_py310.py +++ b/tests/test_tutorial/test_fastapi/test_teams/test_tutorial001_py310.py @@ -1,3 +1,4 @@ +from dirty_equals import IsDict from fastapi.testclient import TestClient from sqlmodel import create_engine from sqlmodel.pool import StaticPool @@ -521,8 +522,26 @@ def test_tutorial(clear_sqlmodel): "properties": { "name": {"title": "Name", "type": "string"}, "secret_name": {"title": "Secret Name", "type": "string"}, - "age": {"title": "Age", "type": "integer"}, - "team_id": {"title": "Team Id", "type": "integer"}, + "age": IsDict( + { + "title": "Age", + "anyOf": [{"type": "integer"}, {"type": "null"}], + } + ) + | IsDict( + # TODO: remove when deprecating Pydantic v1 + {"title": "Age", "type": "integer"} + ), + "team_id": IsDict( + { + "title": "Team Id", + "anyOf": [{"type": "integer"}, {"type": "null"}], + } + ) + | IsDict( + # TODO: remove when deprecating Pydantic v1 + {"title": "Team Id", "type": "integer"} + ), }, }, "HeroRead": { @@ -532,8 +551,26 @@ def test_tutorial(clear_sqlmodel): "properties": { "name": {"title": "Name", "type": "string"}, "secret_name": {"title": "Secret Name", "type": "string"}, - "age": {"title": "Age", "type": "integer"}, - "team_id": {"title": "Team Id", "type": "integer"}, + "age": IsDict( + { + "title": "Age", + "anyOf": [{"type": "integer"}, {"type": "null"}], + } + ) + | IsDict( + # TODO: remove when deprecating Pydantic v1 + {"title": "Age", "type": "integer"} + ), + "team_id": IsDict( + { + "title": "Team Id", + "anyOf": [{"type": "integer"}, {"type": "null"}], + } + ) + | IsDict( + # TODO: remove when deprecating Pydantic v1 + {"title": "Team Id", "type": "integer"} + ), "id": {"title": "Id", "type": "integer"}, }, }, @@ -541,10 +578,46 @@ def test_tutorial(clear_sqlmodel): "title": "HeroUpdate", "type": "object", "properties": { - "name": {"title": "Name", "type": "string"}, - "secret_name": {"title": "Secret Name", "type": "string"}, - "age": {"title": "Age", "type": "integer"}, - "team_id": {"title": "Team Id", "type": "integer"}, + "name": IsDict( + { + "title": "Name", + "anyOf": [{"type": "string"}, {"type": "null"}], + } + ) + | IsDict( + # TODO: remove when deprecating Pydantic v1 + {"title": "Name", "type": "string"} + ), + "secret_name": IsDict( + { + "title": "Secret Name", + "anyOf": [{"type": "string"}, {"type": "null"}], + } + ) + | IsDict( + # TODO: remove when deprecating Pydantic v1 + {"title": "Secret Name", "type": "string"} + ), + "age": IsDict( + { + "title": "Age", + "anyOf": [{"type": "integer"}, {"type": "null"}], + } + ) + | IsDict( + # TODO: remove when deprecating Pydantic v1 + {"title": "Age", "type": "integer"} + ), + "team_id": IsDict( + { + "title": "Team Id", + "anyOf": [{"type": "integer"}, {"type": "null"}], + } + ) + | IsDict( + # TODO: remove when deprecating Pydantic v1 + {"title": "Team Id", "type": "integer"} + ), }, }, "TeamCreate": { @@ -570,8 +643,26 @@ def test_tutorial(clear_sqlmodel): "title": "TeamUpdate", "type": "object", "properties": { - "name": {"title": "Name", "type": "string"}, - "headquarters": {"title": "Headquarters", "type": "string"}, + "name": IsDict( + { + "title": "Name", + "anyOf": [{"type": "string"}, {"type": "null"}], + } + ) + | IsDict( + # TODO: remove when deprecating Pydantic v1 + {"title": "Name", "type": "string"} + ), + "headquarters": IsDict( + { + "title": "Headquarters", + "anyOf": [{"type": "string"}, {"type": "null"}], + } + ) + | IsDict( + # TODO: remove when deprecating Pydantic v1 + {"title": "Headquarters", "type": "string"} + ), }, }, "ValidationError": { diff --git a/tests/test_tutorial/test_fastapi/test_teams/test_tutorial001_py39.py b/tests/test_tutorial/test_fastapi/test_teams/test_tutorial001_py39.py index 70279f5b8d..6c93c87514 100644 --- a/tests/test_tutorial/test_fastapi/test_teams/test_tutorial001_py39.py +++ b/tests/test_tutorial/test_fastapi/test_teams/test_tutorial001_py39.py @@ -1,3 +1,4 @@ +from dirty_equals import IsDict from fastapi.testclient import TestClient from sqlmodel import create_engine from sqlmodel.pool import StaticPool @@ -521,8 +522,26 @@ def test_tutorial(clear_sqlmodel): "properties": { "name": {"title": "Name", "type": "string"}, "secret_name": {"title": "Secret Name", "type": "string"}, - "age": {"title": "Age", "type": "integer"}, - "team_id": {"title": "Team Id", "type": "integer"}, + "age": IsDict( + { + "title": "Age", + "anyOf": [{"type": "integer"}, {"type": "null"}], + } + ) + | IsDict( + # TODO: remove when deprecating Pydantic v1 + {"title": "Age", "type": "integer"} + ), + "team_id": IsDict( + { + "title": "Team Id", + "anyOf": [{"type": "integer"}, {"type": "null"}], + } + ) + | IsDict( + # TODO: remove when deprecating Pydantic v1 + {"title": "Team Id", "type": "integer"} + ), }, }, "HeroRead": { @@ -532,8 +551,26 @@ def test_tutorial(clear_sqlmodel): "properties": { "name": {"title": "Name", "type": "string"}, "secret_name": {"title": "Secret Name", "type": "string"}, - "age": {"title": "Age", "type": "integer"}, - "team_id": {"title": "Team Id", "type": "integer"}, + "age": IsDict( + { + "title": "Age", + "anyOf": [{"type": "integer"}, {"type": "null"}], + } + ) + | IsDict( + # TODO: remove when deprecating Pydantic v1 + {"title": "Age", "type": "integer"} + ), + "team_id": IsDict( + { + "title": "Team Id", + "anyOf": [{"type": "integer"}, {"type": "null"}], + } + ) + | IsDict( + # TODO: remove when deprecating Pydantic v1 + {"title": "Team Id", "type": "integer"} + ), "id": {"title": "Id", "type": "integer"}, }, }, @@ -541,10 +578,46 @@ def test_tutorial(clear_sqlmodel): "title": "HeroUpdate", "type": "object", "properties": { - "name": {"title": "Name", "type": "string"}, - "secret_name": {"title": "Secret Name", "type": "string"}, - "age": {"title": "Age", "type": "integer"}, - "team_id": {"title": "Team Id", "type": "integer"}, + "name": IsDict( + { + "title": "Name", + "anyOf": [{"type": "string"}, {"type": "null"}], + } + ) + | IsDict( + # TODO: remove when deprecating Pydantic v1 + {"title": "Name", "type": "string"} + ), + "secret_name": IsDict( + { + "title": "Secret Name", + "anyOf": [{"type": "string"}, {"type": "null"}], + } + ) + | IsDict( + # TODO: remove when deprecating Pydantic v1 + {"title": "Secret Name", "type": "string"} + ), + "age": IsDict( + { + "title": "Age", + "anyOf": [{"type": "integer"}, {"type": "null"}], + } + ) + | IsDict( + # TODO: remove when deprecating Pydantic v1 + {"title": "Age", "type": "integer"} + ), + "team_id": IsDict( + { + "title": "Team Id", + "anyOf": [{"type": "integer"}, {"type": "null"}], + } + ) + | IsDict( + # TODO: remove when deprecating Pydantic v1 + {"title": "Team Id", "type": "integer"} + ), }, }, "TeamCreate": { @@ -570,8 +643,26 @@ def test_tutorial(clear_sqlmodel): "title": "TeamUpdate", "type": "object", "properties": { - "name": {"title": "Name", "type": "string"}, - "headquarters": {"title": "Headquarters", "type": "string"}, + "name": IsDict( + { + "title": "Name", + "anyOf": [{"type": "string"}, {"type": "null"}], + } + ) + | IsDict( + # TODO: remove when deprecating Pydantic v1 + {"title": "Name", "type": "string"} + ), + "headquarters": IsDict( + { + "title": "Headquarters", + "anyOf": [{"type": "string"}, {"type": "null"}], + } + ) + | IsDict( + # TODO: remove when deprecating Pydantic v1 + {"title": "Headquarters", "type": "string"} + ), }, }, "ValidationError": { From 553a3c569a3debc3dff50b6bc3c3735c527ff947 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sebasti=C3=A1n=20Ram=C3=ADrez?= Date: Sun, 3 Dec 2023 15:18:48 +0100 Subject: [PATCH 071/105] =?UTF-8?q?=E2=9C=85=20Update=20additional=20tests?= =?UTF-8?q?=20for=20compatibility=20with=20Pydantic=20v2?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../test_multiple_models/test_tutorial001.py | 26 +++++++-- .../test_tutorial001_py310.py | 27 +++++++-- .../test_tutorial001_py39.py | 26 +++++++-- .../test_multiple_models/test_tutorial002.py | 26 +++++++-- .../test_tutorial002_py310.py | 26 +++++++-- .../test_tutorial002_py39.py | 26 +++++++-- .../test_simple_hero_api/test_tutorial001.py | 26 +++++++-- .../test_tutorial001_py310.py | 26 +++++++-- .../test_update/test_tutorial001.py | 56 +++++++++++++++++-- .../test_update/test_tutorial001_py310.py | 56 +++++++++++++++++-- .../test_update/test_tutorial001_py39.py | 56 +++++++++++++++++-- 11 files changed, 329 insertions(+), 48 deletions(-) diff --git a/tests/test_tutorial/test_fastapi/test_multiple_models/test_tutorial001.py b/tests/test_tutorial/test_fastapi/test_multiple_models/test_tutorial001.py index 7444f8858d..2ebfc0c0d0 100644 --- a/tests/test_tutorial/test_fastapi/test_multiple_models/test_tutorial001.py +++ b/tests/test_tutorial/test_fastapi/test_multiple_models/test_tutorial001.py @@ -1,3 +1,4 @@ +from dirty_equals import IsDict from fastapi.testclient import TestClient from sqlalchemy import inspect from sqlalchemy.engine.reflection import Inspector @@ -53,11 +54,10 @@ def test_tutorial(clear_sqlmodel): assert data[1]["id"] != hero2_data["id"] response = client.get("/openapi.json") - data = response.json() assert response.status_code == 200, response.text - assert data == { + assert response.json() == { "openapi": "3.1.0", "info": {"title": "FastAPI", "version": "0.1.0"}, "paths": { @@ -142,7 +142,16 @@ def test_tutorial(clear_sqlmodel): "properties": { "name": {"title": "Name", "type": "string"}, "secret_name": {"title": "Secret Name", "type": "string"}, - "age": {"title": "Age", "type": "integer"}, + "age": IsDict( + { + "title": "Age", + "anyOf": [{"type": "integer"}, {"type": "null"}], + } + ) + | IsDict( + # TODO: remove when deprecating Pydantic v1 + {"title": "Age", "type": "integer"} + ), }, }, "HeroRead": { @@ -153,7 +162,16 @@ def test_tutorial(clear_sqlmodel): "id": {"title": "Id", "type": "integer"}, "name": {"title": "Name", "type": "string"}, "secret_name": {"title": "Secret Name", "type": "string"}, - "age": {"title": "Age", "type": "integer"}, + "age": IsDict( + { + "title": "Age", + "anyOf": [{"type": "integer"}, {"type": "null"}], + } + ) + | IsDict( + # TODO: remove when deprecating Pydantic v1 + {"title": "Age", "type": "integer"} + ), }, }, "ValidationError": { diff --git a/tests/test_tutorial/test_fastapi/test_multiple_models/test_tutorial001_py310.py b/tests/test_tutorial/test_fastapi/test_multiple_models/test_tutorial001_py310.py index 080a907e0e..c17e482921 100644 --- a/tests/test_tutorial/test_fastapi/test_multiple_models/test_tutorial001_py310.py +++ b/tests/test_tutorial/test_fastapi/test_multiple_models/test_tutorial001_py310.py @@ -1,3 +1,4 @@ +from dirty_equals import IsDict from fastapi.testclient import TestClient from sqlalchemy import inspect from sqlalchemy.engine.reflection import Inspector @@ -56,11 +57,9 @@ def test_tutorial(clear_sqlmodel): assert data[1]["id"] != hero2_data["id"] response = client.get("/openapi.json") - data = response.json() - assert response.status_code == 200, response.text - assert data == { + assert response.json() == { "openapi": "3.1.0", "info": {"title": "FastAPI", "version": "0.1.0"}, "paths": { @@ -145,7 +144,16 @@ def test_tutorial(clear_sqlmodel): "properties": { "name": {"title": "Name", "type": "string"}, "secret_name": {"title": "Secret Name", "type": "string"}, - "age": {"title": "Age", "type": "integer"}, + "age": IsDict( + { + "title": "Age", + "anyOf": [{"type": "integer"}, {"type": "null"}], + } + ) + | IsDict( + # TODO: remove when deprecating Pydantic v1 + {"title": "Age", "type": "integer"} + ), }, }, "HeroRead": { @@ -156,7 +164,16 @@ def test_tutorial(clear_sqlmodel): "id": {"title": "Id", "type": "integer"}, "name": {"title": "Name", "type": "string"}, "secret_name": {"title": "Secret Name", "type": "string"}, - "age": {"title": "Age", "type": "integer"}, + "age": IsDict( + { + "title": "Age", + "anyOf": [{"type": "integer"}, {"type": "null"}], + } + ) + | IsDict( + # TODO: remove when deprecating Pydantic v1 + {"title": "Age", "type": "integer"} + ), }, }, "ValidationError": { diff --git a/tests/test_tutorial/test_fastapi/test_multiple_models/test_tutorial001_py39.py b/tests/test_tutorial/test_fastapi/test_multiple_models/test_tutorial001_py39.py index 7c320093ae..258b3a4e54 100644 --- a/tests/test_tutorial/test_fastapi/test_multiple_models/test_tutorial001_py39.py +++ b/tests/test_tutorial/test_fastapi/test_multiple_models/test_tutorial001_py39.py @@ -1,3 +1,4 @@ +from dirty_equals import IsDict from fastapi.testclient import TestClient from sqlalchemy import inspect from sqlalchemy.engine.reflection import Inspector @@ -56,11 +57,10 @@ def test_tutorial(clear_sqlmodel): assert data[1]["id"] != hero2_data["id"] response = client.get("/openapi.json") - data = response.json() assert response.status_code == 200, response.text - assert data == { + assert response.json() == { "openapi": "3.1.0", "info": {"title": "FastAPI", "version": "0.1.0"}, "paths": { @@ -145,7 +145,16 @@ def test_tutorial(clear_sqlmodel): "properties": { "name": {"title": "Name", "type": "string"}, "secret_name": {"title": "Secret Name", "type": "string"}, - "age": {"title": "Age", "type": "integer"}, + "age": IsDict( + { + "title": "Age", + "anyOf": [{"type": "integer"}, {"type": "null"}], + } + ) + | IsDict( + # TODO: remove when deprecating Pydantic v1 + {"title": "Age", "type": "integer"} + ), }, }, "HeroRead": { @@ -156,7 +165,16 @@ def test_tutorial(clear_sqlmodel): "id": {"title": "Id", "type": "integer"}, "name": {"title": "Name", "type": "string"}, "secret_name": {"title": "Secret Name", "type": "string"}, - "age": {"title": "Age", "type": "integer"}, + "age": IsDict( + { + "title": "Age", + "anyOf": [{"type": "integer"}, {"type": "null"}], + } + ) + | IsDict( + # TODO: remove when deprecating Pydantic v1 + {"title": "Age", "type": "integer"} + ), }, }, "ValidationError": { diff --git a/tests/test_tutorial/test_fastapi/test_multiple_models/test_tutorial002.py b/tests/test_tutorial/test_fastapi/test_multiple_models/test_tutorial002.py index 4a6bb7499e..47f2e64155 100644 --- a/tests/test_tutorial/test_fastapi/test_multiple_models/test_tutorial002.py +++ b/tests/test_tutorial/test_fastapi/test_multiple_models/test_tutorial002.py @@ -1,3 +1,4 @@ +from dirty_equals import IsDict from fastapi.testclient import TestClient from sqlalchemy import inspect from sqlalchemy.engine.reflection import Inspector @@ -53,11 +54,10 @@ def test_tutorial(clear_sqlmodel): assert data[1]["id"] != hero2_data["id"] response = client.get("/openapi.json") - data = response.json() assert response.status_code == 200, response.text - assert data == { + assert response.json() == { "openapi": "3.1.0", "info": {"title": "FastAPI", "version": "0.1.0"}, "paths": { @@ -142,7 +142,16 @@ def test_tutorial(clear_sqlmodel): "properties": { "name": {"title": "Name", "type": "string"}, "secret_name": {"title": "Secret Name", "type": "string"}, - "age": {"title": "Age", "type": "integer"}, + "age": IsDict( + { + "title": "Age", + "anyOf": [{"type": "integer"}, {"type": "null"}], + } + ) + | IsDict( + # TODO: remove when deprecating Pydantic v1 + {"title": "Age", "type": "integer"} + ), }, }, "HeroRead": { @@ -152,7 +161,16 @@ def test_tutorial(clear_sqlmodel): "properties": { "name": {"title": "Name", "type": "string"}, "secret_name": {"title": "Secret Name", "type": "string"}, - "age": {"title": "Age", "type": "integer"}, + "age": IsDict( + { + "title": "Age", + "anyOf": [{"type": "integer"}, {"type": "null"}], + } + ) + | IsDict( + # TODO: remove when deprecating Pydantic v1 + {"title": "Age", "type": "integer"} + ), "id": {"title": "Id", "type": "integer"}, }, }, diff --git a/tests/test_tutorial/test_fastapi/test_multiple_models/test_tutorial002_py310.py b/tests/test_tutorial/test_fastapi/test_multiple_models/test_tutorial002_py310.py index 20195c6fdf..c09b15bd53 100644 --- a/tests/test_tutorial/test_fastapi/test_multiple_models/test_tutorial002_py310.py +++ b/tests/test_tutorial/test_fastapi/test_multiple_models/test_tutorial002_py310.py @@ -1,3 +1,4 @@ +from dirty_equals import IsDict from fastapi.testclient import TestClient from sqlalchemy import inspect from sqlalchemy.engine.reflection import Inspector @@ -56,11 +57,10 @@ def test_tutorial(clear_sqlmodel): assert data[1]["id"] != hero2_data["id"] response = client.get("/openapi.json") - data = response.json() assert response.status_code == 200, response.text - assert data == { + assert response.json() == { "openapi": "3.1.0", "info": {"title": "FastAPI", "version": "0.1.0"}, "paths": { @@ -145,7 +145,16 @@ def test_tutorial(clear_sqlmodel): "properties": { "name": {"title": "Name", "type": "string"}, "secret_name": {"title": "Secret Name", "type": "string"}, - "age": {"title": "Age", "type": "integer"}, + "age": IsDict( + { + "title": "Age", + "anyOf": [{"type": "integer"}, {"type": "null"}], + } + ) + | IsDict( + # TODO: remove when deprecating Pydantic v1 + {"title": "Age", "type": "integer"} + ), }, }, "HeroRead": { @@ -155,7 +164,16 @@ def test_tutorial(clear_sqlmodel): "properties": { "name": {"title": "Name", "type": "string"}, "secret_name": {"title": "Secret Name", "type": "string"}, - "age": {"title": "Age", "type": "integer"}, + "age": IsDict( + { + "title": "Age", + "anyOf": [{"type": "integer"}, {"type": "null"}], + } + ) + | IsDict( + # TODO: remove when deprecating Pydantic v1 + {"title": "Age", "type": "integer"} + ), "id": {"title": "Id", "type": "integer"}, }, }, diff --git a/tests/test_tutorial/test_fastapi/test_multiple_models/test_tutorial002_py39.py b/tests/test_tutorial/test_fastapi/test_multiple_models/test_tutorial002_py39.py index 45b061b401..8ad0f271e1 100644 --- a/tests/test_tutorial/test_fastapi/test_multiple_models/test_tutorial002_py39.py +++ b/tests/test_tutorial/test_fastapi/test_multiple_models/test_tutorial002_py39.py @@ -1,3 +1,4 @@ +from dirty_equals import IsDict from fastapi.testclient import TestClient from sqlalchemy import inspect from sqlalchemy.engine.reflection import Inspector @@ -56,11 +57,10 @@ def test_tutorial(clear_sqlmodel): assert data[1]["id"] != hero2_data["id"] response = client.get("/openapi.json") - data = response.json() assert response.status_code == 200, response.text - assert data == { + assert response.json() == { "openapi": "3.1.0", "info": {"title": "FastAPI", "version": "0.1.0"}, "paths": { @@ -145,7 +145,16 @@ def test_tutorial(clear_sqlmodel): "properties": { "name": {"title": "Name", "type": "string"}, "secret_name": {"title": "Secret Name", "type": "string"}, - "age": {"title": "Age", "type": "integer"}, + "age": IsDict( + { + "title": "Age", + "anyOf": [{"type": "integer"}, {"type": "null"}], + } + ) + | IsDict( + # TODO: remove when deprecating Pydantic v1 + {"title": "Age", "type": "integer"} + ), }, }, "HeroRead": { @@ -155,7 +164,16 @@ def test_tutorial(clear_sqlmodel): "properties": { "name": {"title": "Name", "type": "string"}, "secret_name": {"title": "Secret Name", "type": "string"}, - "age": {"title": "Age", "type": "integer"}, + "age": IsDict( + { + "title": "Age", + "anyOf": [{"type": "integer"}, {"type": "null"}], + } + ) + | IsDict( + # TODO: remove when deprecating Pydantic v1 + {"title": "Age", "type": "integer"} + ), "id": {"title": "Id", "type": "integer"}, }, }, diff --git a/tests/test_tutorial/test_fastapi/test_simple_hero_api/test_tutorial001.py b/tests/test_tutorial/test_fastapi/test_simple_hero_api/test_tutorial001.py index 2136ed8a1f..9df7e50b81 100644 --- a/tests/test_tutorial/test_fastapi/test_simple_hero_api/test_tutorial001.py +++ b/tests/test_tutorial/test_fastapi/test_simple_hero_api/test_tutorial001.py @@ -1,3 +1,4 @@ +from dirty_equals import IsDict from fastapi.testclient import TestClient from sqlmodel import create_engine from sqlmodel.pool import StaticPool @@ -51,11 +52,10 @@ def test_tutorial(clear_sqlmodel): assert data[1]["id"] == hero2_data["id"] response = client.get("/openapi.json") - data = response.json() assert response.status_code == 200, response.text - assert data == { + assert response.json() == { "openapi": "3.1.0", "info": {"title": "FastAPI", "version": "0.1.0"}, "paths": { @@ -120,10 +120,28 @@ def test_tutorial(clear_sqlmodel): "required": ["name", "secret_name"], "type": "object", "properties": { - "id": {"title": "Id", "type": "integer"}, + "id": IsDict( + { + "title": "Id", + "anyOf": [{"type": "integer"}, {"type": "null"}], + } + ) + | IsDict( + # TODO: remove when deprecating Pydantic v1 + {"title": "Id", "type": "integer"} + ), "name": {"title": "Name", "type": "string"}, "secret_name": {"title": "Secret Name", "type": "string"}, - "age": {"title": "Age", "type": "integer"}, + "age": IsDict( + { + "title": "Age", + "anyOf": [{"type": "integer"}, {"type": "null"}], + } + ) + | IsDict( + # TODO: remove when deprecating Pydantic v1 + {"title": "Age", "type": "integer"} + ), }, }, "ValidationError": { diff --git a/tests/test_tutorial/test_fastapi/test_simple_hero_api/test_tutorial001_py310.py b/tests/test_tutorial/test_fastapi/test_simple_hero_api/test_tutorial001_py310.py index d85d9ee5b2..a47513dde2 100644 --- a/tests/test_tutorial/test_fastapi/test_simple_hero_api/test_tutorial001_py310.py +++ b/tests/test_tutorial/test_fastapi/test_simple_hero_api/test_tutorial001_py310.py @@ -1,3 +1,4 @@ +from dirty_equals import IsDict from fastapi.testclient import TestClient from sqlmodel import create_engine from sqlmodel.pool import StaticPool @@ -54,11 +55,10 @@ def test_tutorial(clear_sqlmodel): assert data[1]["id"] == hero2_data["id"] response = client.get("/openapi.json") - data = response.json() assert response.status_code == 200, response.text - assert data == { + assert response.json() == { "openapi": "3.1.0", "info": {"title": "FastAPI", "version": "0.1.0"}, "paths": { @@ -123,10 +123,28 @@ def test_tutorial(clear_sqlmodel): "required": ["name", "secret_name"], "type": "object", "properties": { - "id": {"title": "Id", "type": "integer"}, + "id": IsDict( + { + "title": "Id", + "anyOf": [{"type": "integer"}, {"type": "null"}], + } + ) + | IsDict( + # TODO: remove when deprecating Pydantic v1 + {"title": "Id", "type": "integer"} + ), "name": {"title": "Name", "type": "string"}, "secret_name": {"title": "Secret Name", "type": "string"}, - "age": {"title": "Age", "type": "integer"}, + "age": IsDict( + { + "title": "Age", + "anyOf": [{"type": "integer"}, {"type": "null"}], + } + ) + | IsDict( + # TODO: remove when deprecating Pydantic v1 + {"title": "Age", "type": "integer"} + ), }, }, "ValidationError": { diff --git a/tests/test_tutorial/test_fastapi/test_update/test_tutorial001.py b/tests/test_tutorial/test_fastapi/test_update/test_tutorial001.py index a4573ef11b..973ab2db04 100644 --- a/tests/test_tutorial/test_fastapi/test_update/test_tutorial001.py +++ b/tests/test_tutorial/test_fastapi/test_update/test_tutorial001.py @@ -1,3 +1,4 @@ +from dirty_equals import IsDict from fastapi.testclient import TestClient from sqlmodel import create_engine from sqlmodel.pool import StaticPool @@ -263,7 +264,16 @@ def test_tutorial(clear_sqlmodel): "properties": { "name": {"title": "Name", "type": "string"}, "secret_name": {"title": "Secret Name", "type": "string"}, - "age": {"title": "Age", "type": "integer"}, + "age": IsDict( + { + "title": "Age", + "anyOf": [{"type": "integer"}, {"type": "null"}], + } + ) + | IsDict( + # TODO: remove when deprecating Pydantic v1 + {"title": "Age", "type": "integer"} + ), }, }, "HeroRead": { @@ -273,7 +283,16 @@ def test_tutorial(clear_sqlmodel): "properties": { "name": {"title": "Name", "type": "string"}, "secret_name": {"title": "Secret Name", "type": "string"}, - "age": {"title": "Age", "type": "integer"}, + "age": IsDict( + { + "title": "Age", + "anyOf": [{"type": "integer"}, {"type": "null"}], + } + ) + | IsDict( + # TODO: remove when deprecating Pydantic v1 + {"title": "Age", "type": "integer"} + ), "id": {"title": "Id", "type": "integer"}, }, }, @@ -281,9 +300,36 @@ def test_tutorial(clear_sqlmodel): "title": "HeroUpdate", "type": "object", "properties": { - "name": {"title": "Name", "type": "string"}, - "secret_name": {"title": "Secret Name", "type": "string"}, - "age": {"title": "Age", "type": "integer"}, + "name": IsDict( + { + "title": "Name", + "anyOf": [{"type": "string"}, {"type": "null"}], + } + ) + | IsDict( + # TODO: remove when deprecating Pydantic v1 + {"title": "Name", "type": "string"} + ), + "secret_name": IsDict( + { + "title": "Secret Name", + "anyOf": [{"type": "string"}, {"type": "null"}], + } + ) + | IsDict( + # TODO: remove when deprecating Pydantic v1 + {"title": "Secret Name", "type": "string"} + ), + "age": IsDict( + { + "title": "Age", + "anyOf": [{"type": "integer"}, {"type": "null"}], + } + ) + | IsDict( + # TODO: remove when deprecating Pydantic v1 + {"title": "Age", "type": "integer"} + ), }, }, "ValidationError": { diff --git a/tests/test_tutorial/test_fastapi/test_update/test_tutorial001_py310.py b/tests/test_tutorial/test_fastapi/test_update/test_tutorial001_py310.py index cf56e3cb01..090af8c60f 100644 --- a/tests/test_tutorial/test_fastapi/test_update/test_tutorial001_py310.py +++ b/tests/test_tutorial/test_fastapi/test_update/test_tutorial001_py310.py @@ -1,3 +1,4 @@ +from dirty_equals import IsDict from fastapi.testclient import TestClient from sqlmodel import create_engine from sqlmodel.pool import StaticPool @@ -266,7 +267,16 @@ def test_tutorial(clear_sqlmodel): "properties": { "name": {"title": "Name", "type": "string"}, "secret_name": {"title": "Secret Name", "type": "string"}, - "age": {"title": "Age", "type": "integer"}, + "age": IsDict( + { + "title": "Age", + "anyOf": [{"type": "integer"}, {"type": "null"}], + } + ) + | IsDict( + # TODO: remove when deprecating Pydantic v1 + {"title": "Age", "type": "integer"} + ), }, }, "HeroRead": { @@ -276,7 +286,16 @@ def test_tutorial(clear_sqlmodel): "properties": { "name": {"title": "Name", "type": "string"}, "secret_name": {"title": "Secret Name", "type": "string"}, - "age": {"title": "Age", "type": "integer"}, + "age": IsDict( + { + "title": "Age", + "anyOf": [{"type": "integer"}, {"type": "null"}], + } + ) + | IsDict( + # TODO: remove when deprecating Pydantic v1 + {"title": "Age", "type": "integer"} + ), "id": {"title": "Id", "type": "integer"}, }, }, @@ -284,9 +303,36 @@ def test_tutorial(clear_sqlmodel): "title": "HeroUpdate", "type": "object", "properties": { - "name": {"title": "Name", "type": "string"}, - "secret_name": {"title": "Secret Name", "type": "string"}, - "age": {"title": "Age", "type": "integer"}, + "name": IsDict( + { + "title": "Name", + "anyOf": [{"type": "string"}, {"type": "null"}], + } + ) + | IsDict( + # TODO: remove when deprecating Pydantic v1 + {"title": "Name", "type": "string"} + ), + "secret_name": IsDict( + { + "title": "Secret Name", + "anyOf": [{"type": "string"}, {"type": "null"}], + } + ) + | IsDict( + # TODO: remove when deprecating Pydantic v1 + {"title": "Secret Name", "type": "string"} + ), + "age": IsDict( + { + "title": "Age", + "anyOf": [{"type": "integer"}, {"type": "null"}], + } + ) + | IsDict( + # TODO: remove when deprecating Pydantic v1 + {"title": "Age", "type": "integer"} + ), }, }, "ValidationError": { diff --git a/tests/test_tutorial/test_fastapi/test_update/test_tutorial001_py39.py b/tests/test_tutorial/test_fastapi/test_update/test_tutorial001_py39.py index b301ca3bf1..22dfb8f268 100644 --- a/tests/test_tutorial/test_fastapi/test_update/test_tutorial001_py39.py +++ b/tests/test_tutorial/test_fastapi/test_update/test_tutorial001_py39.py @@ -1,3 +1,4 @@ +from dirty_equals import IsDict from fastapi.testclient import TestClient from sqlmodel import create_engine from sqlmodel.pool import StaticPool @@ -266,7 +267,16 @@ def test_tutorial(clear_sqlmodel): "properties": { "name": {"title": "Name", "type": "string"}, "secret_name": {"title": "Secret Name", "type": "string"}, - "age": {"title": "Age", "type": "integer"}, + "age": IsDict( + { + "title": "Age", + "anyOf": [{"type": "integer"}, {"type": "null"}], + } + ) + | IsDict( + # TODO: remove when deprecating Pydantic v1 + {"title": "Age", "type": "integer"} + ), }, }, "HeroRead": { @@ -276,7 +286,16 @@ def test_tutorial(clear_sqlmodel): "properties": { "name": {"title": "Name", "type": "string"}, "secret_name": {"title": "Secret Name", "type": "string"}, - "age": {"title": "Age", "type": "integer"}, + "age": IsDict( + { + "title": "Age", + "anyOf": [{"type": "integer"}, {"type": "null"}], + } + ) + | IsDict( + # TODO: remove when deprecating Pydantic v1 + {"title": "Age", "type": "integer"} + ), "id": {"title": "Id", "type": "integer"}, }, }, @@ -284,9 +303,36 @@ def test_tutorial(clear_sqlmodel): "title": "HeroUpdate", "type": "object", "properties": { - "name": {"title": "Name", "type": "string"}, - "secret_name": {"title": "Secret Name", "type": "string"}, - "age": {"title": "Age", "type": "integer"}, + "name": IsDict( + { + "title": "Name", + "anyOf": [{"type": "string"}, {"type": "null"}], + } + ) + | IsDict( + # TODO: remove when deprecating Pydantic v1 + {"title": "Name", "type": "string"} + ), + "secret_name": IsDict( + { + "title": "Secret Name", + "anyOf": [{"type": "string"}, {"type": "null"}], + } + ) + | IsDict( + # TODO: remove when deprecating Pydantic v1 + {"title": "Secret Name", "type": "string"} + ), + "age": IsDict( + { + "title": "Age", + "anyOf": [{"type": "integer"}, {"type": "null"}], + } + ) + | IsDict( + # TODO: remove when deprecating Pydantic v1 + {"title": "Age", "type": "integer"} + ), }, }, "ValidationError": { From f0d088cf11d8b2cab4f61de9c26e1059cf7171a6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sebasti=C3=A1n=20Ram=C3=ADrez?= Date: Sun, 3 Dec 2023 16:44:51 +0100 Subject: [PATCH 072/105] =?UTF-8?q?=E2=99=BB=EF=B8=8F=20Refactor=20main=20?= =?UTF-8?q?and=20compat,=20restructure,=20simplify,=20and=20fix=20implemen?= =?UTF-8?q?tation,=20supporting=20table=20models=20without=20defaults,=20S?= =?UTF-8?q?QLAlchemy=20=5F=5Finit=5F=5F=20overrides?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- sqlmodel/_compat.py | 271 ++++++++++++++++++++++++++++++++++++++--- sqlmodel/main.py | 287 +++++++++++++++++++------------------------- 2 files changed, 378 insertions(+), 180 deletions(-) diff --git a/sqlmodel/_compat.py b/sqlmodel/_compat.py index fc249e90df..4ceca7ad61 100644 --- a/sqlmodel/_compat.py +++ b/sqlmodel/_compat.py @@ -1,10 +1,14 @@ import types +from contextlib import contextmanager +from contextvars import ContextVar +from dataclasses import dataclass from typing import ( TYPE_CHECKING, Any, Dict, ForwardRef, Optional, + Set, Type, TypeVar, Union, @@ -17,23 +21,6 @@ IS_PYDANTIC_V2 = PYDANTIC_VERSION.startswith("2.") -if IS_PYDANTIC_V2: - from pydantic import ConfigDict as BaseConfig - from pydantic._internal._fields import PydanticMetadata - from pydantic._internal._model_construction import ModelMetaclass - from pydantic_core import PydanticUndefined as Undefined # noqa - from pydantic_core import PydanticUndefinedType as UndefinedType - - # Dummy for types, to make it importable - class ModelField: - pass -else: - from pydantic import BaseConfig as BaseConfig - from pydantic.fields import SHAPE_SINGLETON, ModelField - from pydantic.fields import Undefined as Undefined # noqa - from pydantic.fields import UndefinedType as UndefinedType - from pydantic.main import ModelMetaclass as ModelMetaclass - from pydantic.typing import resolve_annotations if TYPE_CHECKING: from .main import RelationshipInfo, SQLModel @@ -42,6 +29,7 @@ class ModelField: NoneType = type(None) T = TypeVar("T") InstanceOrType = Union[T, Type[T]] +_TSQLModel = TypeVar("_TSQLModel", bound="SQLModel") class FakeMetadata: @@ -50,11 +38,41 @@ class FakeMetadata: decimal_places: Optional[int] = None +@dataclass +class ObjectWithUpdateWrapper: + obj: Any + update: Dict[str, Any] + + def __getattribute__(self, __name: str) -> Any: + if __name in self.update: + return self.update[__name] + return getattr(self.obj, __name) + + def _is_union_type(t: Any) -> bool: return t is UnionType or t is Union +_finish_init: ContextVar[bool] = ContextVar("_finish_init", default=True) + + +@contextmanager +def _partial_init(): + token = _finish_init.set(False) + yield + _finish_init.reset(token) + + if IS_PYDANTIC_V2: + from pydantic import ConfigDict as BaseConfig + from pydantic._internal._fields import PydanticMetadata + from pydantic._internal._model_construction import ModelMetaclass + from pydantic_core import PydanticUndefined as Undefined # noqa + from pydantic_core import PydanticUndefinedType as UndefinedType + + # Dummy for types, to make it importable + class ModelField: + pass class SQLModelConfig(BaseConfig, total=False): table: Optional[bool] @@ -175,7 +193,156 @@ def get_field_metadata(field: Any) -> Any: def post_init_field_info(field_info: FieldInfo) -> None: return None + + def _sqlmodel_table_construct( + # SQLModel override + # cls: Type[_TSQLModel], _fields_set: Union[Set[str], None] = None, **values: Any + *, + self_instance: _TSQLModel, + values: Dict[str, Any], + _fields_set: Union[Set[str], None] = None, + ) -> _TSQLModel: + # Copy from Pydantic's BaseModel.construct() + # Ref: https://github.com/pydantic/pydantic/blob/v2.5.2/pydantic/main.py#L198 + # Modified to not include everything, only the model fields, and to + # set relationships + # SQLModel override to get class SQLAlchemy __dict__ attributes and + # set them back in after creating the object + # new_obj = cls.__new__(cls) + cls = type(self_instance) + old_dict = self_instance.__dict__.copy() + # End SQLModel override + + fields_values: dict[str, Any] = {} + defaults: dict[ + str, Any + ] = {} # keeping this separate from `fields_values` helps us compute `_fields_set` + for name, field in cls.model_fields.items(): + if field.alias and field.alias in values: + fields_values[name] = values.pop(field.alias) + elif name in values: + fields_values[name] = values.pop(name) + elif not field.is_required(): + defaults[name] = field.get_default(call_default_factory=True) + if _fields_set is None: + _fields_set = set(fields_values.keys()) + fields_values.update(defaults) + + _extra: dict[str, Any] | None = None + if cls.model_config.get("extra") == "allow": + _extra = {} + for k, v in values.items(): + _extra[k] = v + # SQLModel override, do not include everything, only the model fields + # else: + # fields_values.update(values) + # End SQLModel override + # SQLModel override + # Do not set __dict__, instead use setattr to trigger SQLAlchemy + # object.__setattr__(new_obj, "__dict__", fields_values) + # instrumentation + for key, value in {**old_dict, **fields_values}.items(): + setattr(self_instance, key, value) + # End SQLModel override + object.__setattr__(self_instance, "__pydantic_fields_set__", _fields_set) + if not cls.__pydantic_root_model__: + object.__setattr__(self_instance, "__pydantic_extra__", _extra) + + if cls.__pydantic_post_init__: + self_instance.model_post_init(None) + elif not cls.__pydantic_root_model__: + # Note: if there are any private attributes, cls.__pydantic_post_init__ would exist + # Since it doesn't, that means that `__pydantic_private__` should be set to None + object.__setattr__(self_instance, "__pydantic_private__", None) + # SQLModel override, set relationships + # Get and set any relationship objects + for key in self_instance.__sqlmodel_relationships__: + value = values.get(key, Undefined) + if value is not Undefined: + setattr(self_instance, key, value) + # End SQLModel override + return self_instance + + def _model_validate( + cls: Type[_TSQLModel], + obj: Any, + *, + strict: Union[bool, None] = None, + from_attributes: Union[bool, None] = None, + context: Union[Dict[str, Any], None] = None, + update: Union[Dict[str, Any], None] = None, + ) -> _TSQLModel: + if not cls_is_table(cls): + new_obj: _TSQLModel = cls.__new__(cls) + else: + # If table, create the new instance normally to make SQLAlchemy create + # the _sa_instance_state attribute + # The wrapper of this function should use with _partial_init() + with _partial_init(): + new_obj = cls() + # SQLModel Override to get class SQLAlchemy __dict__ attributes and + # set them back in after creating the object + old_dict = new_obj.__dict__.copy() + use_obj = obj + if isinstance(obj, dict) and update: + use_obj = {**obj, **update} + elif update: + use_obj = ObjectWithUpdateWrapper(obj=obj, update=update) + cls.__pydantic_validator__.validate_python( + use_obj, + strict=strict, + from_attributes=from_attributes, + context=context, + self_instance=new_obj, + ) + # Capture fields set to restore it later + fields_set = new_obj.__pydantic_fields_set__.copy() + if not cls_is_table(cls): + # If not table, normal Pydantic code, set __dict__ + new_obj.__dict__ = {**old_dict, **new_obj.__dict__} + else: + # Do not set __dict__, instead use setattr to trigger SQLAlchemy + # instrumentation + for key, value in {**old_dict, **new_obj.__dict__}.items(): + setattr(new_obj, key, value) + # Restore fields set + object.__setattr__(new_obj, "__pydantic_fields_set__", fields_set) + # Get and set any relationship objects + if cls_is_table(cls): + for key in new_obj.__sqlmodel_relationships__: + value = getattr(use_obj, key, Undefined) + if value is not Undefined: + setattr(new_obj, key, value) + return new_obj + + def __sqlmodel_init__(__pydantic_self__: "SQLModel", **data: Any) -> None: + old_dict = __pydantic_self__.__dict__.copy() + if not cls_is_table(__pydantic_self__.__class__): + __pydantic_self__.__pydantic_validator__.validate_python( + data, + self_instance=__pydantic_self__, + ) + else: + _sqlmodel_table_construct( + self_instance=__pydantic_self__, + values=data, + ) + object.__setattr__( + __pydantic_self__, + "__dict__", + {**old_dict, **__pydantic_self__.__dict__}, + ) + else: + from pydantic import BaseConfig as BaseConfig + from pydantic.errors import ConfigError + from pydantic.fields import SHAPE_SINGLETON, ModelField + from pydantic.fields import Undefined as Undefined # noqa + from pydantic.fields import UndefinedType as UndefinedType + from pydantic.main import ModelMetaclass as ModelMetaclass + from pydantic.main import validate_model + from pydantic.typing import resolve_annotations + from pydantic.utils import ROOT_KEY class SQLModelConfig(BaseConfig): table: Optional[bool] = None @@ -253,3 +420,73 @@ def get_field_metadata(field: Any) -> Any: def post_init_field_info(field_info: FieldInfo) -> None: field_info._validate() + + def _model_validate( + cls: Type[_TSQLModel], + obj: Any, + *, + strict: Union[bool, None] = None, + from_attributes: Union[bool, None] = None, + context: Union[Dict[str, Any], None] = None, + update: Union[Dict[str, Any], None] = None, + ) -> _TSQLModel: + # This was SQLModel's original from_orm() for Pydantic v1 + # Duplicated from Pydantic + if not cls.__config__.orm_mode: # noqa + raise ConfigError( + "You must have the config attribute orm_mode=True to use from_orm" + ) + obj = ( + {ROOT_KEY: obj} + if cls.__custom_root_type__ # noqa + else cls._decompose_class(obj) # noqa + ) + # SQLModel, support update dict + if update is not None: + obj = {**obj, **update} + # End SQLModel support dict + if not getattr(cls.__config__, "table", False): # noqa + # If not table, normal Pydantic code + m: _TSQLModel = cls.__new__(cls) + else: + # If table, create the new instance normally to make SQLAlchemy create + # the _sa_instance_state attribute + m = cls() + values, fields_set, validation_error = validate_model(cls, obj) + if validation_error: + raise validation_error + # Updated to trigger SQLAlchemy internal handling + if not getattr(cls.__config__, "table", False): # noqa + object.__setattr__(m, "__dict__", values) + else: + for key, value in values.items(): + setattr(m, key, value) + # Continue with standard Pydantic logic + object.__setattr__(m, "__fields_set__", fields_set) + m._init_private_attributes() # noqa + return m + + def __sqlmodel_init__(__pydantic_self__: "SQLModel", **data: Any) -> None: + values, fields_set, validation_error = validate_model( + __pydantic_self__.__class__, data + ) + # Only raise errors if not a SQLModel model + if ( + not cls_is_table(__pydantic_self__.__class__) # noqa + and validation_error + ): + raise validation_error + if not cls_is_table(__pydantic_self__.__class__): + object.__setattr__(__pydantic_self__, "__dict__", values) + else: + # Do not set values as in Pydantic, pass them through setattr, so + # SQLAlchemy can handle them + for key, value in values.items(): + setattr(__pydantic_self__, key, value) + object.__setattr__(__pydantic_self__, "__fields_set__", fields_set) + non_pydantic_keys = data.keys() - values.keys() + + if cls_is_table(__pydantic_self__.__class__): + for key in non_pydantic_keys: + if key in __pydantic_self__.__sqlmodel_relationships__: + setattr(__pydantic_self__, key, data[key]) diff --git a/sqlmodel/main.py b/sqlmodel/main.py index 4c9e5325b7..520ecbf145 100644 --- a/sqlmodel/main.py +++ b/sqlmodel/main.py @@ -26,7 +26,6 @@ from pydantic import BaseModel from pydantic.fields import FieldInfo as PydanticFieldInfo -from pydantic.utils import Representation from sqlalchemy import ( Boolean, Column, @@ -62,7 +61,10 @@ SQLModelConfig, Undefined, UndefinedType, + __sqlmodel_init__, + _finish_init, _is_field_noneable, + _model_validate, cls_is_table, get_annotations, get_config_value, @@ -76,10 +78,12 @@ ) from .sql.sqltypes import GUID, AutoString -if not IS_PYDANTIC_V2: - from pydantic.errors import ConfigError, DictError +if IS_PYDANTIC_V2: + from pydantic._internal._repr import Representation +else: + from pydantic.errors import DictError from pydantic.main import validate_model - from pydantic.utils import ROOT_KEY + from pydantic.utils import Representation _T = TypeVar("_T") NoArgAnyCallable = Callable[[], Any] @@ -708,33 +712,22 @@ def __new__(cls, *args: Any, **kwargs: Any) -> Any: def __init__(__pydantic_self__, **data: Any) -> None: # Uses something other than `self` the first arg to allow "self" as a # settable attribute - # TODO: review how this works and check defaults set in metaclass __new__ - if IS_PYDANTIC_V2: - old_dict = __pydantic_self__.__dict__.copy() - super().__init__(**data) # noqa - __pydantic_self__.__dict__ = {**old_dict, **__pydantic_self__.__dict__} - non_pydantic_keys = data.keys() - __pydantic_self__.model_fields - else: - values, fields_set, validation_error = validate_model( - __pydantic_self__.__class__, data - ) - # Only raise errors if not a SQLModel model - if ( - not getattr(__pydantic_self__.__config__, "table", False) # noqa - and validation_error - ): - raise validation_error - # Do not set values as in Pydantic, pass them through setattr, so SQLAlchemy - # can handle them - # object.__setattr__(__pydantic_self__, '__dict__', values) - for key, value in values.items(): - setattr(__pydantic_self__, key, value) - object.__setattr__(__pydantic_self__, "__fields_set__", fields_set) - non_pydantic_keys = data.keys() - values.keys() - - for key in non_pydantic_keys: - if key in __pydantic_self__.__sqlmodel_relationships__: - setattr(__pydantic_self__, key, data[key]) + + # SQLAlchemy does very dark black magic and modifies the __init__ method in + # sqlalchemy.orm.instrumentation._generate_init() + # so, to make SQLAlchemy work, it's needed to explicitly call __init__ to + # trigger all the SQLAlchemy logic, it doesn't work using cls.__new__, setting + # attributes obj.__dict__, etc. The __init__ method has to be called. But + # there are cases where calling all the default logic is not ideal, e.g. + # when calling Model.model_validate(), as the validation is done outside + # of instance creation. + # At the same time, __init__ is what users would normally call, by creating + # a new instance, which should have validation and all the default logic. + # So, to be able to set up the internal SQLAlchemy logic alone without + # executing the rest, and support things like Model.model_validate(), we + # use a contextvar to know if we should execute everything. + if _finish_init.get(): + __sqlmodel_init__(__pydantic_self__, **data) def __setattr__(self, name: str, value: Any) -> None: if name in {"_sa_instance_state"}: @@ -765,140 +758,108 @@ def __tablename__(cls) -> str: # TODO: refactor this and make each method available in both Pydantic v1 and v2 # add deprecations, re-use methods from backwards compatibility parts, etc. - if IS_PYDANTIC_V2: - @classmethod - def model_validate( - cls: type[_TSQLModel], - obj: Any, - *, - strict: bool | None = None, - from_attributes: bool | None = None, - context: dict[str, Any] | None = None, - ) -> _TSQLModel: - # Somehow model validate doesn't call __init__ so it would remove our init logic - validated = super().model_validate( - obj, strict=strict, from_attributes=from_attributes, context=context - ) - return cls(**validated.model_dump(exclude_unset=True)) - - else: + @classmethod + def model_validate( + cls: Type[_TSQLModel], + obj: Any, + *, + strict: Union[bool, None] = None, + from_attributes: Union[bool, None] = None, + context: Union[Dict[str, Any], None] = None, + update: Union[Dict[str, Any], None] = None, + ) -> _TSQLModel: + return _model_validate( + cls=cls, + obj=obj, + strict=strict, + from_attributes=from_attributes, + context=context, + update=update, + ) - @classmethod - def from_orm( - cls: Type[_TSQLModel], obj: Any, update: Optional[Dict[str, Any]] = None - ) -> _TSQLModel: - # Duplicated from Pydantic - if not cls.__config__.orm_mode: # noqa - raise ConfigError( - "You must have the config attribute orm_mode=True to use from_orm" - ) - obj = ( - {ROOT_KEY: obj} - if cls.__custom_root_type__ # noqa - else cls._decompose_class(obj) # noqa + @classmethod + def from_orm( + cls: Type[_TSQLModel], obj: Any, update: Optional[Dict[str, Any]] = None + ) -> _TSQLModel: + return cls.model_validate(obj, update=update) + + @classmethod + def parse_obj( + cls: Type[_TSQLModel], obj: Any, update: Optional[Dict[str, Any]] = None + ) -> _TSQLModel: + obj = cls._enforce_dict_if_root(obj) # noqa + # SQLModel, support update dict + if update is not None: + obj = {**obj, **update} + # End SQLModel support dict + return super().parse_obj(obj) + + # From Pydantic, override to enforce validation with dict + @classmethod + def validate(cls: Type[_TSQLModel], value: Any) -> _TSQLModel: + if isinstance(value, cls): + return ( + value.copy() if cls.__config__.copy_on_model_validation else value # noqa ) - # SQLModel, support update dict - if update is not None: - obj = {**obj, **update} - # End SQLModel support dict - if not getattr(cls.__config__, "table", False): # noqa - # If not table, normal Pydantic code - m: _TSQLModel = cls.__new__(cls) - else: - # If table, create the new instance normally to make SQLAlchemy create - # the _sa_instance_state attribute - m = cls() - values, fields_set, validation_error = validate_model(cls, obj) + + value = cls._enforce_dict_if_root(value) + if isinstance(value, dict): + values, fields_set, validation_error = validate_model(cls, value) if validation_error: raise validation_error - # Updated to trigger SQLAlchemy internal handling - if not getattr(cls.__config__, "table", False): # noqa - object.__setattr__(m, "__dict__", values) - else: - for key, value in values.items(): - setattr(m, key, value) - # Continue with standard Pydantic logic - object.__setattr__(m, "__fields_set__", fields_set) - m._init_private_attributes() # noqa - return m - - @classmethod - def parse_obj( - cls: Type[_TSQLModel], obj: Any, update: Optional[Dict[str, Any]] = None - ) -> _TSQLModel: - obj = cls._enforce_dict_if_root(obj) # noqa - # SQLModel, support update dict - if update is not None: - obj = {**obj, **update} - # End SQLModel support dict - return super().parse_obj(obj) - - # From Pydantic, override to enforce validation with dict - @classmethod - def validate(cls: Type[_TSQLModel], value: Any) -> _TSQLModel: - if isinstance(value, cls): - return ( - value.copy() if cls.__config__.copy_on_model_validation else value # noqa - ) - - value = cls._enforce_dict_if_root(value) - if isinstance(value, dict): - values, fields_set, validation_error = validate_model(cls, value) - if validation_error: - raise validation_error - model = cls(**value) - # Reset fields set, this would have been done in Pydantic in __init__ - object.__setattr__(model, "__fields_set__", fields_set) - return model - elif cls.__config__.orm_mode: # noqa - return cls.from_orm(value) - elif cls.__custom_root_type__: # noqa - return cls.parse_obj(value) - else: - try: - value_as_dict = dict(value) - except (TypeError, ValueError) as e: - raise DictError() from e - return cls(**value_as_dict) - - # From Pydantic, override to only show keys from fields, omit SQLAlchemy attributes - def _calculate_keys( - self, - include: Optional[Mapping[Union[int, str], Any]], - exclude: Optional[Mapping[Union[int, str], Any]], - exclude_unset: bool, - update: Optional[Dict[str, Any]] = None, - ) -> Optional[AbstractSet[str]]: - if include is None and exclude is None and not exclude_unset: - # Original in Pydantic: - # return None - # Updated to not return SQLAlchemy attributes - # Do not include relationships as that would easily lead to infinite - # recursion, or traversing the whole database - return ( - self.__fields__.keys() # noqa - ) # | self.__sqlmodel_relationships__.keys() - - keys: AbstractSet[str] - if exclude_unset: - keys = self.__fields_set__.copy() # noqa - else: - # Original in Pydantic: - # keys = self.__dict__.keys() - # Updated to not return SQLAlchemy attributes - # Do not include relationships as that would easily lead to infinite - # recursion, or traversing the whole database - keys = ( - self.__fields__.keys() # noqa - ) # | self.__sqlmodel_relationships__.keys() - if include is not None: - keys &= include.keys() - - if update: - keys -= update.keys() - - if exclude: - keys -= {k for k, v in exclude.items() if _value_items_is_true(v)} - - return keys + model = cls(**value) + # Reset fields set, this would have been done in Pydantic in __init__ + object.__setattr__(model, "__fields_set__", fields_set) + return model + elif cls.__config__.orm_mode: # noqa + return cls.from_orm(value) + elif cls.__custom_root_type__: # noqa + return cls.parse_obj(value) + else: + try: + value_as_dict = dict(value) + except (TypeError, ValueError) as e: + raise DictError() from e + return cls(**value_as_dict) + + # From Pydantic, override to only show keys from fields, omit SQLAlchemy attributes + def _calculate_keys( + self, + include: Optional[Mapping[Union[int, str], Any]], + exclude: Optional[Mapping[Union[int, str], Any]], + exclude_unset: bool, + update: Optional[Dict[str, Any]] = None, + ) -> Optional[AbstractSet[str]]: + if include is None and exclude is None and not exclude_unset: + # Original in Pydantic: + # return None + # Updated to not return SQLAlchemy attributes + # Do not include relationships as that would easily lead to infinite + # recursion, or traversing the whole database + return ( + self.__fields__.keys() # noqa + ) # | self.__sqlmodel_relationships__.keys() + + keys: AbstractSet[str] + if exclude_unset: + keys = self.__fields_set__.copy() # noqa + else: + # Original in Pydantic: + # keys = self.__dict__.keys() + # Updated to not return SQLAlchemy attributes + # Do not include relationships as that would easily lead to infinite + # recursion, or traversing the whole database + keys = ( + self.__fields__.keys() # noqa + ) # | self.__sqlmodel_relationships__.keys() + if include is not None: + keys &= include.keys() + + if update: + keys -= update.keys() + + if exclude: + keys -= {k for k, v in exclude.items() if _value_items_is_true(v)} + + return keys From 2806b385e9b633868e879b404db1b2987bc88b8a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sebasti=C3=A1n=20Ram=C3=ADrez?= Date: Sun, 3 Dec 2023 16:47:05 +0100 Subject: [PATCH 073/105] =?UTF-8?q?=E2=99=BB=EF=B8=8F=20Re-export=20Repres?= =?UTF-8?q?entation=20from=20=5Fcompat?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- sqlmodel/_compat.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/sqlmodel/_compat.py b/sqlmodel/_compat.py index 4ceca7ad61..67f9a6b221 100644 --- a/sqlmodel/_compat.py +++ b/sqlmodel/_compat.py @@ -67,6 +67,7 @@ def _partial_init(): from pydantic import ConfigDict as BaseConfig from pydantic._internal._fields import PydanticMetadata from pydantic._internal._model_construction import ModelMetaclass + from pydantic._internal._repr import Representation as Representation from pydantic_core import PydanticUndefined as Undefined # noqa from pydantic_core import PydanticUndefinedType as UndefinedType @@ -343,6 +344,7 @@ def __sqlmodel_init__(__pydantic_self__: "SQLModel", **data: Any) -> None: from pydantic.main import validate_model from pydantic.typing import resolve_annotations from pydantic.utils import ROOT_KEY + from pydantic.utils import Representation as Representation class SQLModelConfig(BaseConfig): table: Optional[bool] = None From 08bd1ef0d1c8b0a04240aa383d68751439b5b4e8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sebasti=C3=A1n=20Ram=C3=ADrez?= Date: Sun, 3 Dec 2023 16:49:37 +0100 Subject: [PATCH 074/105] =?UTF-8?q?=E2=99=BB=EF=B8=8F=20Refactor=20imports?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- sqlmodel/main.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/sqlmodel/main.py b/sqlmodel/main.py index 520ecbf145..ed687d64bb 100644 --- a/sqlmodel/main.py +++ b/sqlmodel/main.py @@ -58,6 +58,7 @@ BaseConfig, ModelField, ModelMetaclass, + Representation, SQLModelConfig, Undefined, UndefinedType, @@ -78,12 +79,9 @@ ) from .sql.sqltypes import GUID, AutoString -if IS_PYDANTIC_V2: - from pydantic._internal._repr import Representation -else: +if not IS_PYDANTIC_V2: from pydantic.errors import DictError from pydantic.main import validate_model - from pydantic.utils import Representation _T = TypeVar("_T") NoArgAnyCallable = Callable[[], Any] From e07ada1ca40b5807ecd07750647fbea08653e2ce Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sebasti=C3=A1n=20Ram=C3=ADrez?= Date: Sun, 3 Dec 2023 17:06:08 +0100 Subject: [PATCH 075/105] =?UTF-8?q?=F0=9F=9A=9A=20Update=20and=20rename=20?= =?UTF-8?q?symbols?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- sqlmodel/_compat.py | 88 ++++++++++++++++++++++----------------------- sqlmodel/main.py | 22 ++++++------ 2 files changed, 54 insertions(+), 56 deletions(-) diff --git a/sqlmodel/_compat.py b/sqlmodel/_compat.py index 67f9a6b221..94d3c2790a 100644 --- a/sqlmodel/_compat.py +++ b/sqlmodel/_compat.py @@ -29,7 +29,7 @@ NoneType = type(None) T = TypeVar("T") InstanceOrType = Union[T, Type[T]] -_TSQLModel = TypeVar("_TSQLModel", bound="SQLModel") +TSQLModel = TypeVar("TSQLModel", bound="SQLModel") class FakeMetadata: @@ -53,14 +53,14 @@ def _is_union_type(t: Any) -> bool: return t is UnionType or t is Union -_finish_init: ContextVar[bool] = ContextVar("_finish_init", default=True) +finish_init: ContextVar[bool] = ContextVar("finish_init", default=True) @contextmanager -def _partial_init(): - token = _finish_init.set(False) +def partial_init(): + token = finish_init.set(False) yield - _finish_init.reset(token) + finish_init.reset(token) if IS_PYDANTIC_V2: @@ -103,7 +103,7 @@ def set_fields_set( def get_annotations(class_dict: dict[str, Any]) -> dict[str, Any]: return class_dict.get("__annotations__", {}) - def cls_is_table(cls: Type) -> bool: + def is_table_model_class(cls: Type) -> bool: config = getattr(cls, "model_config", None) if not config: return False @@ -147,7 +147,7 @@ def get_relationship_to( name=name, rel_info=rel_info, annotation=use_annotation ) - def _is_field_noneable(field: "FieldInfo") -> bool: + def is_field_noneable(field: "FieldInfo") -> bool: if getattr(field, "nullable", Undefined) is not Undefined: return field.nullable # type: ignore origin = get_origin(field.annotation) @@ -195,14 +195,14 @@ def get_field_metadata(field: Any) -> Any: def post_init_field_info(field_info: FieldInfo) -> None: return None - def _sqlmodel_table_construct( + def sqlmodel_table_construct( # SQLModel override # cls: Type[_TSQLModel], _fields_set: Union[Set[str], None] = None, **values: Any *, - self_instance: _TSQLModel, + self_instance: TSQLModel, values: Dict[str, Any], _fields_set: Union[Set[str], None] = None, - ) -> _TSQLModel: + ) -> TSQLModel: # Copy from Pydantic's BaseModel.construct() # Ref: https://github.com/pydantic/pydantic/blob/v2.5.2/pydantic/main.py#L198 # Modified to not include everything, only the model fields, and to @@ -264,22 +264,22 @@ def _sqlmodel_table_construct( # End SQLModel override return self_instance - def _model_validate( - cls: Type[_TSQLModel], + def sqlmodel_validate( + cls: Type[TSQLModel], obj: Any, *, strict: Union[bool, None] = None, from_attributes: Union[bool, None] = None, context: Union[Dict[str, Any], None] = None, update: Union[Dict[str, Any], None] = None, - ) -> _TSQLModel: - if not cls_is_table(cls): - new_obj: _TSQLModel = cls.__new__(cls) + ) -> TSQLModel: + if not is_table_model_class(cls): + new_obj: TSQLModel = cls.__new__(cls) else: # If table, create the new instance normally to make SQLAlchemy create # the _sa_instance_state attribute # The wrapper of this function should use with _partial_init() - with _partial_init(): + with partial_init(): new_obj = cls() # SQLModel Override to get class SQLAlchemy __dict__ attributes and # set them back in after creating the object @@ -298,7 +298,7 @@ def _model_validate( ) # Capture fields set to restore it later fields_set = new_obj.__pydantic_fields_set__.copy() - if not cls_is_table(cls): + if not is_table_model_class(cls): # If not table, normal Pydantic code, set __dict__ new_obj.__dict__ = {**old_dict, **new_obj.__dict__} else: @@ -309,29 +309,29 @@ def _model_validate( # Restore fields set object.__setattr__(new_obj, "__pydantic_fields_set__", fields_set) # Get and set any relationship objects - if cls_is_table(cls): + if is_table_model_class(cls): for key in new_obj.__sqlmodel_relationships__: value = getattr(use_obj, key, Undefined) if value is not Undefined: setattr(new_obj, key, value) return new_obj - def __sqlmodel_init__(__pydantic_self__: "SQLModel", **data: Any) -> None: - old_dict = __pydantic_self__.__dict__.copy() - if not cls_is_table(__pydantic_self__.__class__): - __pydantic_self__.__pydantic_validator__.validate_python( + def sqlmodel_init(*, self: "SQLModel", data: Dict[str, Any]) -> None: + old_dict = self.__dict__.copy() + if not is_table_model_class(self.__class__): + self.__pydantic_validator__.validate_python( data, - self_instance=__pydantic_self__, + self_instance=self, ) else: - _sqlmodel_table_construct( - self_instance=__pydantic_self__, + sqlmodel_table_construct( + self_instance=self, values=data, ) object.__setattr__( - __pydantic_self__, + self, "__dict__", - {**old_dict, **__pydantic_self__.__dict__}, + {**old_dict, **self.__dict__}, ) else: @@ -377,7 +377,7 @@ def get_annotations(class_dict: dict[str, Any]) -> dict[str, Any]: class_dict.get("__module__", None), ) - def cls_is_table(cls: Type) -> bool: + def is_table_model_class(cls: Type) -> bool: config = getattr(cls, "__config__", None) if not config: return False @@ -400,7 +400,7 @@ def get_relationship_to( relationship_to = temp_field.type_.__forward_arg__ return relationship_to - def _is_field_noneable(field: "FieldInfo") -> bool: + def is_field_noneable(field: "FieldInfo") -> bool: if not field.required: # Taken from [Pydantic](https://github.com/samuelcolvin/pydantic/blob/v1.8.2/pydantic/fields.py#L946-L947) return field.allow_none and ( @@ -423,15 +423,15 @@ def get_field_metadata(field: Any) -> Any: def post_init_field_info(field_info: FieldInfo) -> None: field_info._validate() - def _model_validate( - cls: Type[_TSQLModel], + def sqlmodel_validate( + cls: Type[TSQLModel], obj: Any, *, strict: Union[bool, None] = None, from_attributes: Union[bool, None] = None, context: Union[Dict[str, Any], None] = None, update: Union[Dict[str, Any], None] = None, - ) -> _TSQLModel: + ) -> TSQLModel: # This was SQLModel's original from_orm() for Pydantic v1 # Duplicated from Pydantic if not cls.__config__.orm_mode: # noqa @@ -449,7 +449,7 @@ def _model_validate( # End SQLModel support dict if not getattr(cls.__config__, "table", False): # noqa # If not table, normal Pydantic code - m: _TSQLModel = cls.__new__(cls) + m: TSQLModel = cls.__new__(cls) else: # If table, create the new instance normally to make SQLAlchemy create # the _sa_instance_state attribute @@ -468,27 +468,25 @@ def _model_validate( m._init_private_attributes() # noqa return m - def __sqlmodel_init__(__pydantic_self__: "SQLModel", **data: Any) -> None: - values, fields_set, validation_error = validate_model( - __pydantic_self__.__class__, data - ) + def sqlmodel_init(*, self: "SQLModel", data: Dict[str, Any]) -> None: + values, fields_set, validation_error = validate_model(self.__class__, data) # Only raise errors if not a SQLModel model if ( - not cls_is_table(__pydantic_self__.__class__) # noqa + not is_table_model_class(self.__class__) # noqa and validation_error ): raise validation_error - if not cls_is_table(__pydantic_self__.__class__): - object.__setattr__(__pydantic_self__, "__dict__", values) + if not is_table_model_class(self.__class__): + object.__setattr__(self, "__dict__", values) else: # Do not set values as in Pydantic, pass them through setattr, so # SQLAlchemy can handle them for key, value in values.items(): - setattr(__pydantic_self__, key, value) - object.__setattr__(__pydantic_self__, "__fields_set__", fields_set) + setattr(self, key, value) + object.__setattr__(self, "__fields_set__", fields_set) non_pydantic_keys = data.keys() - values.keys() - if cls_is_table(__pydantic_self__.__class__): + if is_table_model_class(self.__class__): for key in non_pydantic_keys: - if key in __pydantic_self__.__sqlmodel_relationships__: - setattr(__pydantic_self__, key, data[key]) + if key in self.__sqlmodel_relationships__: + setattr(self, key, data[key]) diff --git a/sqlmodel/main.py b/sqlmodel/main.py index ed687d64bb..6fe336babd 100644 --- a/sqlmodel/main.py +++ b/sqlmodel/main.py @@ -62,20 +62,20 @@ SQLModelConfig, Undefined, UndefinedType, - __sqlmodel_init__, - _finish_init, - _is_field_noneable, - _model_validate, - cls_is_table, + finish_init, get_annotations, get_config_value, get_field_metadata, get_model_fields, get_relationship_to, get_type_from_field, + is_field_noneable, + is_table_model_class, post_init_field_info, set_config_value, set_fields_set, + sqlmodel_init, + sqlmodel_validate, ) from .sql.sqltypes import GUID, AutoString @@ -509,8 +509,8 @@ def __init__( # this allows FastAPI cloning a SQLModel for the response_model without # trying to create a new SQLAlchemy, for a new table, with the same name, that # triggers an error - base_is_table = any(cls_is_table(base) for base in bases) - if cls_is_table(cls) and not base_is_table: + base_is_table = any(is_table_model_class(base) for base in bases) + if is_table_model_class(cls) and not base_is_table: for rel_name, rel_info in cls.__sqlmodel_relationships__.items(): if rel_info.sa_relationship: # There's a SQLAlchemy relationship declared, that takes precedence @@ -628,7 +628,7 @@ def get_column_from_field(field: Any) -> Column: # type: ignore index = getattr(field_info, "index", Undefined) if index is Undefined: index = False - nullable = not primary_key and _is_field_noneable(field) + nullable = not primary_key and is_field_noneable(field) # Override derived nullability if the nullable property is set explicitly # on the field field_nullable = getattr(field_info, "nullable", Undefined) # noqa: B009 @@ -724,8 +724,8 @@ def __init__(__pydantic_self__, **data: Any) -> None: # So, to be able to set up the internal SQLAlchemy logic alone without # executing the rest, and support things like Model.model_validate(), we # use a contextvar to know if we should execute everything. - if _finish_init.get(): - __sqlmodel_init__(__pydantic_self__, **data) + if finish_init.get(): + sqlmodel_init(self=__pydantic_self__, data=data) def __setattr__(self, name: str, value: Any) -> None: if name in {"_sa_instance_state"}: @@ -767,7 +767,7 @@ def model_validate( context: Union[Dict[str, Any], None] = None, update: Union[Dict[str, Any], None] = None, ) -> _TSQLModel: - return _model_validate( + return sqlmodel_validate( cls=cls, obj=obj, strict=strict, From 349a374a48d109f72e3fa522f02ee33c0bc8a508 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sebasti=C3=A1n=20Ram=C3=ADrez?= Date: Sun, 3 Dec 2023 17:11:24 +0100 Subject: [PATCH 076/105] =?UTF-8?q?=F0=9F=93=9D=20Update=20docs,=20use=20M?= =?UTF-8?q?odel.model=5Fvalidate=20everywhere?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../tutorial/fastapi/app_testing/tutorial001/main.py | 6 +----- docs_src/tutorial/fastapi/delete/tutorial001.py | 6 +----- .../tutorial/fastapi/limit_and_offset/tutorial001.py | 6 +----- .../tutorial/fastapi/multiple_models/tutorial001.py | 6 +----- .../tutorial/fastapi/multiple_models/tutorial002.py | 6 +----- docs_src/tutorial/fastapi/read_one/tutorial001.py | 6 +----- .../tutorial/fastapi/relationships/tutorial001.py | 11 ++--------- .../fastapi/session_with_dependency/tutorial001.py | 6 +----- docs_src/tutorial/fastapi/teams/tutorial001.py | 11 ++--------- docs_src/tutorial/fastapi/update/tutorial001.py | 6 +----- 10 files changed, 12 insertions(+), 58 deletions(-) diff --git a/docs_src/tutorial/fastapi/app_testing/tutorial001/main.py b/docs_src/tutorial/fastapi/app_testing/tutorial001/main.py index cc830e8b19..a23dfad5a8 100644 --- a/docs_src/tutorial/fastapi/app_testing/tutorial001/main.py +++ b/docs_src/tutorial/fastapi/app_testing/tutorial001/main.py @@ -2,7 +2,6 @@ from fastapi import Depends, FastAPI, HTTPException, Query from sqlmodel import Field, Session, SQLModel, create_engine, select -from sqlmodel._compat import IS_PYDANTIC_V2 class HeroBase(SQLModel): @@ -55,10 +54,7 @@ def on_startup(): @app.post("/heroes/", response_model=HeroRead) def create_hero(*, session: Session = Depends(get_session), hero: HeroCreate): - if IS_PYDANTIC_V2: - db_hero = Hero.model_validate(hero) - else: - db_hero = Hero.from_orm(hero) + db_hero = Hero.model_validate(hero) session.add(db_hero) session.commit() session.refresh(db_hero) diff --git a/docs_src/tutorial/fastapi/delete/tutorial001.py b/docs_src/tutorial/fastapi/delete/tutorial001.py index 04e23ee251..77a99a9c97 100644 --- a/docs_src/tutorial/fastapi/delete/tutorial001.py +++ b/docs_src/tutorial/fastapi/delete/tutorial001.py @@ -2,7 +2,6 @@ from fastapi import FastAPI, HTTPException, Query from sqlmodel import Field, Session, SQLModel, create_engine, select -from sqlmodel._compat import IS_PYDANTIC_V2 class HeroBase(SQLModel): @@ -51,10 +50,7 @@ def on_startup(): @app.post("/heroes/", response_model=HeroRead) def create_hero(hero: HeroCreate): with Session(engine) as session: - if IS_PYDANTIC_V2: - db_hero = Hero.model_validate(hero) - else: - db_hero = Hero.from_orm(hero) + db_hero = Hero.model_validate(hero) session.add(db_hero) session.commit() session.refresh(db_hero) diff --git a/docs_src/tutorial/fastapi/limit_and_offset/tutorial001.py b/docs_src/tutorial/fastapi/limit_and_offset/tutorial001.py index 223edddfa7..2352f39022 100644 --- a/docs_src/tutorial/fastapi/limit_and_offset/tutorial001.py +++ b/docs_src/tutorial/fastapi/limit_and_offset/tutorial001.py @@ -2,7 +2,6 @@ from fastapi import FastAPI, HTTPException, Query from sqlmodel import Field, Session, SQLModel, create_engine, select -from sqlmodel._compat import IS_PYDANTIC_V2 class HeroBase(SQLModel): @@ -45,10 +44,7 @@ def on_startup(): @app.post("/heroes/", response_model=HeroRead) def create_hero(hero: HeroCreate): with Session(engine) as session: - if IS_PYDANTIC_V2: - db_hero = Hero.model_validate(hero) - else: - db_hero = Hero.from_orm(hero) + db_hero = Hero.model_validate(hero) session.add(db_hero) session.commit() session.refresh(db_hero) diff --git a/docs_src/tutorial/fastapi/multiple_models/tutorial001.py b/docs_src/tutorial/fastapi/multiple_models/tutorial001.py index 10b169d3b8..7f59ac6a1d 100644 --- a/docs_src/tutorial/fastapi/multiple_models/tutorial001.py +++ b/docs_src/tutorial/fastapi/multiple_models/tutorial001.py @@ -2,7 +2,6 @@ from fastapi import FastAPI from sqlmodel import Field, Session, SQLModel, create_engine, select -from sqlmodel._compat import IS_PYDANTIC_V2 class Hero(SQLModel, table=True): @@ -47,10 +46,7 @@ def on_startup(): @app.post("/heroes/", response_model=HeroRead) def create_hero(hero: HeroCreate): with Session(engine) as session: - if IS_PYDANTIC_V2: - db_hero = Hero.model_validate(hero) - else: - db_hero = Hero.from_orm(hero) + db_hero = Hero.model_validate(hero) session.add(db_hero) session.commit() session.refresh(db_hero) diff --git a/docs_src/tutorial/fastapi/multiple_models/tutorial002.py b/docs_src/tutorial/fastapi/multiple_models/tutorial002.py index daa34ccedd..fffbe72496 100644 --- a/docs_src/tutorial/fastapi/multiple_models/tutorial002.py +++ b/docs_src/tutorial/fastapi/multiple_models/tutorial002.py @@ -2,7 +2,6 @@ from fastapi import FastAPI from sqlmodel import Field, Session, SQLModel, create_engine, select -from sqlmodel._compat import IS_PYDANTIC_V2 class HeroBase(SQLModel): @@ -45,10 +44,7 @@ def on_startup(): @app.post("/heroes/", response_model=HeroRead) def create_hero(hero: HeroCreate): with Session(engine) as session: - if IS_PYDANTIC_V2: - db_hero = Hero.model_validate(hero) - else: - db_hero = Hero.from_orm(hero) + db_hero = Hero.model_validate(hero) session.add(db_hero) session.commit() session.refresh(db_hero) diff --git a/docs_src/tutorial/fastapi/read_one/tutorial001.py b/docs_src/tutorial/fastapi/read_one/tutorial001.py index 1a106b80ca..f18426e74c 100644 --- a/docs_src/tutorial/fastapi/read_one/tutorial001.py +++ b/docs_src/tutorial/fastapi/read_one/tutorial001.py @@ -2,7 +2,6 @@ from fastapi import FastAPI, HTTPException from sqlmodel import Field, Session, SQLModel, create_engine, select -from sqlmodel._compat import IS_PYDANTIC_V2 class HeroBase(SQLModel): @@ -45,10 +44,7 @@ def on_startup(): @app.post("/heroes/", response_model=HeroRead) def create_hero(hero: HeroCreate): with Session(engine) as session: - if IS_PYDANTIC_V2: - db_hero = Hero.model_validate(hero) - else: - db_hero = Hero.from_orm(hero) + db_hero = Hero.model_validate(hero) session.add(db_hero) session.commit() session.refresh(db_hero) diff --git a/docs_src/tutorial/fastapi/relationships/tutorial001.py b/docs_src/tutorial/fastapi/relationships/tutorial001.py index 1f067c302a..e5b196090e 100644 --- a/docs_src/tutorial/fastapi/relationships/tutorial001.py +++ b/docs_src/tutorial/fastapi/relationships/tutorial001.py @@ -2,7 +2,6 @@ from fastapi import Depends, FastAPI, HTTPException, Query from sqlmodel import Field, Relationship, Session, SQLModel, create_engine, select -from sqlmodel._compat import IS_PYDANTIC_V2 class TeamBase(SQLModel): @@ -93,10 +92,7 @@ def on_startup(): @app.post("/heroes/", response_model=HeroRead) def create_hero(*, session: Session = Depends(get_session), hero: HeroCreate): - if IS_PYDANTIC_V2: - db_hero = Hero.model_validate(hero) - else: - db_hero = Hero.from_orm(hero) + db_hero = Hero.model_validate(hero) session.add(db_hero) session.commit() session.refresh(db_hero) @@ -150,10 +146,7 @@ def delete_hero(*, session: Session = Depends(get_session), hero_id: int): @app.post("/teams/", response_model=TeamRead) def create_team(*, session: Session = Depends(get_session), team: TeamCreate): - if IS_PYDANTIC_V2: - db_team = Team.model_validate(team) - else: - db_team = Team.from_orm(team) + db_team = Team.model_validate(team) session.add(db_team) session.commit() session.refresh(db_team) diff --git a/docs_src/tutorial/fastapi/session_with_dependency/tutorial001.py b/docs_src/tutorial/fastapi/session_with_dependency/tutorial001.py index cc830e8b19..a23dfad5a8 100644 --- a/docs_src/tutorial/fastapi/session_with_dependency/tutorial001.py +++ b/docs_src/tutorial/fastapi/session_with_dependency/tutorial001.py @@ -2,7 +2,6 @@ from fastapi import Depends, FastAPI, HTTPException, Query from sqlmodel import Field, Session, SQLModel, create_engine, select -from sqlmodel._compat import IS_PYDANTIC_V2 class HeroBase(SQLModel): @@ -55,10 +54,7 @@ def on_startup(): @app.post("/heroes/", response_model=HeroRead) def create_hero(*, session: Session = Depends(get_session), hero: HeroCreate): - if IS_PYDANTIC_V2: - db_hero = Hero.model_validate(hero) - else: - db_hero = Hero.from_orm(hero) + db_hero = Hero.model_validate(hero) session.add(db_hero) session.commit() session.refresh(db_hero) diff --git a/docs_src/tutorial/fastapi/teams/tutorial001.py b/docs_src/tutorial/fastapi/teams/tutorial001.py index 3f757c0cec..cc73bb52cb 100644 --- a/docs_src/tutorial/fastapi/teams/tutorial001.py +++ b/docs_src/tutorial/fastapi/teams/tutorial001.py @@ -2,7 +2,6 @@ from fastapi import Depends, FastAPI, HTTPException, Query from sqlmodel import Field, Relationship, Session, SQLModel, create_engine, select -from sqlmodel._compat import IS_PYDANTIC_V2 class TeamBase(SQLModel): @@ -84,10 +83,7 @@ def on_startup(): @app.post("/heroes/", response_model=HeroRead) def create_hero(*, session: Session = Depends(get_session), hero: HeroCreate): - if IS_PYDANTIC_V2: - db_hero = Hero.model_validate(hero) - else: - db_hero = Hero.from_orm(hero) + db_hero = Hero.model_validate(hero) session.add(db_hero) session.commit() session.refresh(db_hero) @@ -141,10 +137,7 @@ def delete_hero(*, session: Session = Depends(get_session), hero_id: int): @app.post("/teams/", response_model=TeamRead) def create_team(*, session: Session = Depends(get_session), team: TeamCreate): - if IS_PYDANTIC_V2: - db_team = Team.model_validate(team) - else: - db_team = Team.from_orm(team) + db_team = Team.model_validate(team) session.add(db_team) session.commit() session.refresh(db_team) diff --git a/docs_src/tutorial/fastapi/update/tutorial001.py b/docs_src/tutorial/fastapi/update/tutorial001.py index c179c6e363..28462bff17 100644 --- a/docs_src/tutorial/fastapi/update/tutorial001.py +++ b/docs_src/tutorial/fastapi/update/tutorial001.py @@ -2,7 +2,6 @@ from fastapi import FastAPI, HTTPException, Query from sqlmodel import Field, Session, SQLModel, create_engine, select -from sqlmodel._compat import IS_PYDANTIC_V2 class HeroBase(SQLModel): @@ -51,10 +50,7 @@ def on_startup(): @app.post("/heroes/", response_model=HeroRead) def create_hero(hero: HeroCreate): with Session(engine) as session: - if IS_PYDANTIC_V2: - db_hero = Hero.model_validate(hero) - else: - db_hero = Hero.from_orm(hero) + db_hero = Hero.model_validate(hero) session.add(db_hero) session.commit() session.refresh(db_hero) From 577ebb92be1b856ec8e7775b8a44a8a622283c83 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sebasti=C3=A1n=20Ram=C3=ADrez?= Date: Sun, 3 Dec 2023 17:34:34 +0100 Subject: [PATCH 077/105] =?UTF-8?q?=E2=99=BB=EF=B8=8F=20Refactor=20and=20s?= =?UTF-8?q?implify=20main.py?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- sqlmodel/main.py | 18 +++++++----------- 1 file changed, 7 insertions(+), 11 deletions(-) diff --git a/sqlmodel/main.py b/sqlmodel/main.py index 6fe336babd..cc8d0cf43c 100644 --- a/sqlmodel/main.py +++ b/sqlmodel/main.py @@ -390,22 +390,20 @@ def Relationship( @__dataclass_transform__(kw_only_default=True, field_descriptors=(Field, FieldInfo)) class SQLModelMetaclass(ModelMetaclass, DeclarativeMeta): __sqlmodel_relationships__: Dict[str, RelationshipInfo] - if IS_PYDANTIC_V2: - model_config: SQLModelConfig - model_fields: Dict[str, FieldInfo] - else: - __config__: Type[SQLModelConfig] - __fields__: Dict[str, ModelField] + model_config: SQLModelConfig + model_fields: Dict[str, FieldInfo] + __config__: Type[SQLModelConfig] + __fields__: Dict[str, ModelField] # Replicate SQLAlchemy def __setattr__(cls, name: str, value: Any) -> None: - if get_config_value(model=cls, parameter="table", default=False): + if is_table_model_class(cls): DeclarativeMeta.__setattr__(cls, name, value) else: super().__setattr__(name, value) def __delattr__(cls, name: str) -> None: - if get_config_value(model=cls, parameter="table", default=False): + if is_table_model_class(cls): DeclarativeMeta.__delattr__(cls, name) else: super().__delattr__(name) @@ -733,9 +731,7 @@ def __setattr__(self, name: str, value: Any) -> None: return else: # Set in SQLAlchemy, before Pydantic to trigger events and updates - if get_config_value( - model=self, parameter="table", default=False - ) and is_instrumented(self, name): + if is_table_model_class(self) and is_instrumented(self, name): set_attribute(self, name, value) # Set in Pydantic model to trigger possible validation changes, only for # non relationship values From f39bd8fc1f3d243de30dbb947315e83077e4cc83 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sebasti=C3=A1n=20Ram=C3=ADrez?= Date: Sun, 3 Dec 2023 18:30:49 +0100 Subject: [PATCH 078/105] =?UTF-8?q?=E2=99=BB=EF=B8=8F=20Reimplement=20=5Fc?= =?UTF-8?q?alculate=5Fkeys=20in=20=5Fcompat.py?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- sqlmodel/_compat.py | 74 +++++++++++++++++++++++++++++++++++++-------- sqlmodel/main.py | 47 +++++----------------------- 2 files changed, 70 insertions(+), 51 deletions(-) diff --git a/sqlmodel/_compat.py b/sqlmodel/_compat.py index 94d3c2790a..b7f3f22fc8 100644 --- a/sqlmodel/_compat.py +++ b/sqlmodel/_compat.py @@ -4,9 +4,11 @@ from dataclasses import dataclass from typing import ( TYPE_CHECKING, + AbstractSet, Any, Dict, ForwardRef, + Mapping, Optional, Set, Type, @@ -29,7 +31,7 @@ NoneType = type(None) T = TypeVar("T") InstanceOrType = Union[T, Type[T]] -TSQLModel = TypeVar("TSQLModel", bound="SQLModel") +_TSQLModel = TypeVar("_TSQLModel", bound="SQLModel") class FakeMetadata: @@ -195,14 +197,22 @@ def get_field_metadata(field: Any) -> Any: def post_init_field_info(field_info: FieldInfo) -> None: return None + # Dummy to make it importable + def _calculate_keys( + self, + include: Optional[Mapping[Union[int, str], Any]], + exclude: Optional[Mapping[Union[int, str], Any]], + exclude_unset: bool, + update: Optional[Dict[str, Any]] = None, + ) -> Optional[AbstractSet[str]]: + return None + def sqlmodel_table_construct( - # SQLModel override - # cls: Type[_TSQLModel], _fields_set: Union[Set[str], None] = None, **values: Any *, - self_instance: TSQLModel, + self_instance: _TSQLModel, values: Dict[str, Any], _fields_set: Union[Set[str], None] = None, - ) -> TSQLModel: + ) -> _TSQLModel: # Copy from Pydantic's BaseModel.construct() # Ref: https://github.com/pydantic/pydantic/blob/v2.5.2/pydantic/main.py#L198 # Modified to not include everything, only the model fields, and to @@ -265,16 +275,16 @@ def sqlmodel_table_construct( return self_instance def sqlmodel_validate( - cls: Type[TSQLModel], + cls: Type[_TSQLModel], obj: Any, *, strict: Union[bool, None] = None, from_attributes: Union[bool, None] = None, context: Union[Dict[str, Any], None] = None, update: Union[Dict[str, Any], None] = None, - ) -> TSQLModel: + ) -> _TSQLModel: if not is_table_model_class(cls): - new_obj: TSQLModel = cls.__new__(cls) + new_obj: _TSQLModel = cls.__new__(cls) else: # If table, create the new instance normally to make SQLAlchemy create # the _sa_instance_state attribute @@ -343,7 +353,7 @@ def sqlmodel_init(*, self: "SQLModel", data: Dict[str, Any]) -> None: from pydantic.main import ModelMetaclass as ModelMetaclass from pydantic.main import validate_model from pydantic.typing import resolve_annotations - from pydantic.utils import ROOT_KEY + from pydantic.utils import ROOT_KEY, ValueItems from pydantic.utils import Representation as Representation class SQLModelConfig(BaseConfig): @@ -423,15 +433,55 @@ def get_field_metadata(field: Any) -> Any: def post_init_field_info(field_info: FieldInfo) -> None: field_info._validate() + def _calculate_keys( + self, + include: Optional[Mapping[Union[int, str], Any]], + exclude: Optional[Mapping[Union[int, str], Any]], + exclude_unset: bool, + update: Optional[Dict[str, Any]] = None, + ) -> Optional[AbstractSet[str]]: + if include is None and exclude is None and not exclude_unset: + # Original in Pydantic: + # return None + # Updated to not return SQLAlchemy attributes + # Do not include relationships as that would easily lead to infinite + # recursion, or traversing the whole database + return ( + self.__fields__.keys() # noqa + ) # | self.__sqlmodel_relationships__.keys() + + keys: AbstractSet[str] + if exclude_unset: + keys = self.__fields_set__.copy() # noqa + else: + # Original in Pydantic: + # keys = self.__dict__.keys() + # Updated to not return SQLAlchemy attributes + # Do not include relationships as that would easily lead to infinite + # recursion, or traversing the whole database + keys = ( + self.__fields__.keys() # noqa + ) # | self.__sqlmodel_relationships__.keys() + if include is not None: + keys &= include.keys() + + if update: + keys -= update.keys() + + if exclude: + keys -= {k for k, v in exclude.items() if ValueItems.is_true(v)} + + return keys + def sqlmodel_validate( - cls: Type[TSQLModel], + cls: Type[_TSQLModel], obj: Any, *, strict: Union[bool, None] = None, from_attributes: Union[bool, None] = None, context: Union[Dict[str, Any], None] = None, update: Union[Dict[str, Any], None] = None, - ) -> TSQLModel: + ) -> _TSQLModel: # This was SQLModel's original from_orm() for Pydantic v1 # Duplicated from Pydantic if not cls.__config__.orm_mode: # noqa @@ -449,7 +499,7 @@ def sqlmodel_validate( # End SQLModel support dict if not getattr(cls.__config__, "table", False): # noqa # If not table, normal Pydantic code - m: TSQLModel = cls.__new__(cls) + m: _TSQLModel = cls.__new__(cls) else: # If table, create the new instance normally to make SQLAlchemy create # the _sa_instance_state attribute diff --git a/sqlmodel/main.py b/sqlmodel/main.py index cc8d0cf43c..aff673d525 100644 --- a/sqlmodel/main.py +++ b/sqlmodel/main.py @@ -62,6 +62,7 @@ SQLModelConfig, Undefined, UndefinedType, + _calculate_keys, finish_init, get_annotations, get_config_value, @@ -669,13 +670,6 @@ def get_column_from_field(field: Any) -> Column: # type: ignore default_registry = registry() - -def _value_items_is_true(v: Any) -> bool: - # Re-implement Pydantic's ValueItems.is_true() as it hasn't been released as of - # the current latest, Pydantic 1.8.2 - return v is True or v is ... - - _TSQLModel = TypeVar("_TSQLModel", bound="SQLModel") @@ -825,35 +819,10 @@ def _calculate_keys( exclude_unset: bool, update: Optional[Dict[str, Any]] = None, ) -> Optional[AbstractSet[str]]: - if include is None and exclude is None and not exclude_unset: - # Original in Pydantic: - # return None - # Updated to not return SQLAlchemy attributes - # Do not include relationships as that would easily lead to infinite - # recursion, or traversing the whole database - return ( - self.__fields__.keys() # noqa - ) # | self.__sqlmodel_relationships__.keys() - - keys: AbstractSet[str] - if exclude_unset: - keys = self.__fields_set__.copy() # noqa - else: - # Original in Pydantic: - # keys = self.__dict__.keys() - # Updated to not return SQLAlchemy attributes - # Do not include relationships as that would easily lead to infinite - # recursion, or traversing the whole database - keys = ( - self.__fields__.keys() # noqa - ) # | self.__sqlmodel_relationships__.keys() - if include is not None: - keys &= include.keys() - - if update: - keys -= update.keys() - - if exclude: - keys -= {k for k, v in exclude.items() if _value_items_is_true(v)} - - return keys + return _calculate_keys( + self, + include=include, + exclude=exclude, + exclude_unset=exclude_unset, + update=update, + ) From 4c38fb0fb8026fb0502bea2fd8a78d328d7ef298 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sebasti=C3=A1n=20Ram=C3=ADrez?= Date: Sun, 3 Dec 2023 18:32:18 +0100 Subject: [PATCH 079/105] =?UTF-8?q?=F0=9F=97=91=EF=B8=8F=20Add=20deprecati?= =?UTF-8?q?on=20markers=20for=20methods?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- sqlmodel/main.py | 67 ++++++++++++++++++------------------------------ 1 file changed, 25 insertions(+), 42 deletions(-) diff --git a/sqlmodel/main.py b/sqlmodel/main.py index aff673d525..c2268f7415 100644 --- a/sqlmodel/main.py +++ b/sqlmodel/main.py @@ -51,7 +51,7 @@ from sqlalchemy.orm.instrumentation import is_instrumented from sqlalchemy.sql.schema import MetaData from sqlalchemy.sql.sqltypes import LargeBinary, Time -from typing_extensions import get_origin +from typing_extensions import deprecated, get_origin from ._compat import ( IS_PYDANTIC_V2, @@ -80,10 +80,6 @@ ) from .sql.sqltypes import GUID, AutoString -if not IS_PYDANTIC_V2: - from pydantic.errors import DictError - from pydantic.main import validate_model - _T = TypeVar("_T") NoArgAnyCallable = Callable[[], Any] @@ -744,9 +740,6 @@ def __repr_args__(self) -> Sequence[Tuple[Optional[str], Any]]: def __tablename__(cls) -> str: return cls.__name__.lower() - # TODO: refactor this and make each method available in both Pydantic v1 and v2 - # add deprecations, re-use methods from backwards compatibility parts, etc. - @classmethod def model_validate( cls: Type[_TSQLModel], @@ -767,51 +760,41 @@ def model_validate( ) @classmethod + @deprecated( + """ + 🚨 `obj.from_orm(data)` was deprecated in SQLModel 0.0.12, you should + instead use `obj.model_validate(data)`. + """ + ) def from_orm( cls: Type[_TSQLModel], obj: Any, update: Optional[Dict[str, Any]] = None ) -> _TSQLModel: return cls.model_validate(obj, update=update) @classmethod + @deprecated( + """ + 🚨 `obj.parse_obj(data)` was deprecated in SQLModel 0.0.12, you should + instead use `obj.model_validate(data)`. + """ + ) def parse_obj( cls: Type[_TSQLModel], obj: Any, update: Optional[Dict[str, Any]] = None ) -> _TSQLModel: - obj = cls._enforce_dict_if_root(obj) # noqa - # SQLModel, support update dict - if update is not None: - obj = {**obj, **update} - # End SQLModel support dict - return super().parse_obj(obj) - - # From Pydantic, override to enforce validation with dict - @classmethod - def validate(cls: Type[_TSQLModel], value: Any) -> _TSQLModel: - if isinstance(value, cls): - return ( - value.copy() if cls.__config__.copy_on_model_validation else value # noqa - ) - - value = cls._enforce_dict_if_root(value) - if isinstance(value, dict): - values, fields_set, validation_error = validate_model(cls, value) - if validation_error: - raise validation_error - model = cls(**value) - # Reset fields set, this would have been done in Pydantic in __init__ - object.__setattr__(model, "__fields_set__", fields_set) - return model - elif cls.__config__.orm_mode: # noqa - return cls.from_orm(value) - elif cls.__custom_root_type__: # noqa - return cls.parse_obj(value) - else: - try: - value_as_dict = dict(value) - except (TypeError, ValueError) as e: - raise DictError() from e - return cls(**value_as_dict) + if not IS_PYDANTIC_V2: + obj = cls._enforce_dict_if_root(obj) # noqa + return cls.model_validate(obj, update=update) # From Pydantic, override to only show keys from fields, omit SQLAlchemy attributes + @deprecated( + """ + 🚨 You should not access `obj._calculate_keys()` directly. + + It is only useful for Pydantic v1.X, you should probably upgrade to + Pydantic v2.X. + """, + category=None, + ) def _calculate_keys( self, include: Optional[Mapping[Union[int, str], Any]], From 5eb99a371ef3eb9e604067cac11a312c00c129fe Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sebasti=C3=A1n=20Ram=C3=ADrez?= Date: Sun, 3 Dec 2023 18:33:26 +0100 Subject: [PATCH 080/105] =?UTF-8?q?=F0=9F=93=9D=20Update=20reference=20to?= =?UTF-8?q?=20from=5Form()=20in=20docs?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- docs/tutorial/fastapi/multiple-models.md | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/docs/tutorial/fastapi/multiple-models.md b/docs/tutorial/fastapi/multiple-models.md index b8b63bdbb9..183d9cdff7 100644 --- a/docs/tutorial/fastapi/multiple-models.md +++ b/docs/tutorial/fastapi/multiple-models.md @@ -177,9 +177,7 @@ Now we use the type annotation `HeroCreate` for the request JSON data in the `he Then we create a new `Hero` (this is the actual **table** model that saves things to the database) using `Hero.model_validate()`. -The method `.model_validate()` reads data from another object with attributes and creates a new instance of this class, in this case `Hero`. - -The alternative is `Hero.parse_obj()` that reads data from a dictionary. +The method `.model_validate()` reads data from another object with attributes (or a dict) and creates a new instance of this class, in this case `Hero`. But as in this case, we have a `HeroCreate` instance in the `hero` variable. This is an object with attributes, so we use `.model_validate()` to read those attributes. From b7f27aa6d24c17757e0e9ef2df11e8241fb8078b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sebasti=C3=A1n=20Ram=C3=ADrez?= Date: Sun, 3 Dec 2023 18:38:59 +0100 Subject: [PATCH 081/105] =?UTF-8?q?=E2=99=BB=EF=B8=8F=20Refactor=20source?= =?UTF-8?q?=20examples=20to=20use=20model=5Fvalidate=20instead=20of=20from?= =?UTF-8?q?=5Form?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../tutorial/fastapi/app_testing/tutorial001_py310/main.py | 2 +- .../tutorial/fastapi/app_testing/tutorial001_py39/main.py | 2 +- docs_src/tutorial/fastapi/delete/tutorial001_py310.py | 2 +- docs_src/tutorial/fastapi/delete/tutorial001_py39.py | 2 +- .../tutorial/fastapi/limit_and_offset/tutorial001_py310.py | 2 +- .../tutorial/fastapi/limit_and_offset/tutorial001_py39.py | 2 +- .../tutorial/fastapi/multiple_models/tutorial001_py310.py | 2 +- docs_src/tutorial/fastapi/multiple_models/tutorial001_py39.py | 2 +- .../tutorial/fastapi/multiple_models/tutorial002_py310.py | 2 +- docs_src/tutorial/fastapi/multiple_models/tutorial002_py39.py | 2 +- docs_src/tutorial/fastapi/read_one/tutorial001_py310.py | 2 +- docs_src/tutorial/fastapi/read_one/tutorial001_py39.py | 2 +- docs_src/tutorial/fastapi/relationships/tutorial001_py310.py | 4 ++-- docs_src/tutorial/fastapi/relationships/tutorial001_py39.py | 4 ++-- .../fastapi/session_with_dependency/tutorial001_py310.py | 2 +- .../fastapi/session_with_dependency/tutorial001_py39.py | 2 +- docs_src/tutorial/fastapi/teams/tutorial001_py310.py | 4 ++-- docs_src/tutorial/fastapi/teams/tutorial001_py39.py | 4 ++-- docs_src/tutorial/fastapi/update/tutorial001_py310.py | 2 +- docs_src/tutorial/fastapi/update/tutorial001_py39.py | 2 +- 20 files changed, 24 insertions(+), 24 deletions(-) diff --git a/docs_src/tutorial/fastapi/app_testing/tutorial001_py310/main.py b/docs_src/tutorial/fastapi/app_testing/tutorial001_py310/main.py index e8615d91df..66c34d6939 100644 --- a/docs_src/tutorial/fastapi/app_testing/tutorial001_py310/main.py +++ b/docs_src/tutorial/fastapi/app_testing/tutorial001_py310/main.py @@ -52,7 +52,7 @@ def on_startup(): @app.post("/heroes/", response_model=HeroRead) def create_hero(*, session: Session = Depends(get_session), hero: HeroCreate): - db_hero = Hero.from_orm(hero) + db_hero = Hero.model_validate(hero) session.add(db_hero) session.commit() session.refresh(db_hero) diff --git a/docs_src/tutorial/fastapi/app_testing/tutorial001_py39/main.py b/docs_src/tutorial/fastapi/app_testing/tutorial001_py39/main.py index 9816e70eb0..d71cb34f49 100644 --- a/docs_src/tutorial/fastapi/app_testing/tutorial001_py39/main.py +++ b/docs_src/tutorial/fastapi/app_testing/tutorial001_py39/main.py @@ -54,7 +54,7 @@ def on_startup(): @app.post("/heroes/", response_model=HeroRead) def create_hero(*, session: Session = Depends(get_session), hero: HeroCreate): - db_hero = Hero.from_orm(hero) + db_hero = Hero.model_validate(hero) session.add(db_hero) session.commit() session.refresh(db_hero) diff --git a/docs_src/tutorial/fastapi/delete/tutorial001_py310.py b/docs_src/tutorial/fastapi/delete/tutorial001_py310.py index 5b2da0a0b1..9a3cbb1dd6 100644 --- a/docs_src/tutorial/fastapi/delete/tutorial001_py310.py +++ b/docs_src/tutorial/fastapi/delete/tutorial001_py310.py @@ -48,7 +48,7 @@ def on_startup(): @app.post("/heroes/", response_model=HeroRead) def create_hero(hero: HeroCreate): with Session(engine) as session: - db_hero = Hero.from_orm(hero) + db_hero = Hero.model_validate(hero) session.add(db_hero) session.commit() session.refresh(db_hero) diff --git a/docs_src/tutorial/fastapi/delete/tutorial001_py39.py b/docs_src/tutorial/fastapi/delete/tutorial001_py39.py index 5f498cf136..218d423ff2 100644 --- a/docs_src/tutorial/fastapi/delete/tutorial001_py39.py +++ b/docs_src/tutorial/fastapi/delete/tutorial001_py39.py @@ -50,7 +50,7 @@ def on_startup(): @app.post("/heroes/", response_model=HeroRead) def create_hero(hero: HeroCreate): with Session(engine) as session: - db_hero = Hero.from_orm(hero) + db_hero = Hero.model_validate(hero) session.add(db_hero) session.commit() session.refresh(db_hero) diff --git a/docs_src/tutorial/fastapi/limit_and_offset/tutorial001_py310.py b/docs_src/tutorial/fastapi/limit_and_offset/tutorial001_py310.py index 874a6e8438..ad8ff95e3a 100644 --- a/docs_src/tutorial/fastapi/limit_and_offset/tutorial001_py310.py +++ b/docs_src/tutorial/fastapi/limit_and_offset/tutorial001_py310.py @@ -42,7 +42,7 @@ def on_startup(): @app.post("/heroes/", response_model=HeroRead) def create_hero(hero: HeroCreate): with Session(engine) as session: - db_hero = Hero.from_orm(hero) + db_hero = Hero.model_validate(hero) session.add(db_hero) session.commit() session.refresh(db_hero) diff --git a/docs_src/tutorial/fastapi/limit_and_offset/tutorial001_py39.py b/docs_src/tutorial/fastapi/limit_and_offset/tutorial001_py39.py index b63fa753ff..b1f7cdcb6a 100644 --- a/docs_src/tutorial/fastapi/limit_and_offset/tutorial001_py39.py +++ b/docs_src/tutorial/fastapi/limit_and_offset/tutorial001_py39.py @@ -44,7 +44,7 @@ def on_startup(): @app.post("/heroes/", response_model=HeroRead) def create_hero(hero: HeroCreate): with Session(engine) as session: - db_hero = Hero.from_orm(hero) + db_hero = Hero.model_validate(hero) session.add(db_hero) session.commit() session.refresh(db_hero) diff --git a/docs_src/tutorial/fastapi/multiple_models/tutorial001_py310.py b/docs_src/tutorial/fastapi/multiple_models/tutorial001_py310.py index 13129f383f..ff12eff55c 100644 --- a/docs_src/tutorial/fastapi/multiple_models/tutorial001_py310.py +++ b/docs_src/tutorial/fastapi/multiple_models/tutorial001_py310.py @@ -44,7 +44,7 @@ def on_startup(): @app.post("/heroes/", response_model=HeroRead) def create_hero(hero: HeroCreate): with Session(engine) as session: - db_hero = Hero.from_orm(hero) + db_hero = Hero.model_validate(hero) session.add(db_hero) session.commit() session.refresh(db_hero) diff --git a/docs_src/tutorial/fastapi/multiple_models/tutorial001_py39.py b/docs_src/tutorial/fastapi/multiple_models/tutorial001_py39.py index 41a51f448d..977a1ac8db 100644 --- a/docs_src/tutorial/fastapi/multiple_models/tutorial001_py39.py +++ b/docs_src/tutorial/fastapi/multiple_models/tutorial001_py39.py @@ -46,7 +46,7 @@ def on_startup(): @app.post("/heroes/", response_model=HeroRead) def create_hero(hero: HeroCreate): with Session(engine) as session: - db_hero = Hero.from_orm(hero) + db_hero = Hero.model_validate(hero) session.add(db_hero) session.commit() session.refresh(db_hero) diff --git a/docs_src/tutorial/fastapi/multiple_models/tutorial002_py310.py b/docs_src/tutorial/fastapi/multiple_models/tutorial002_py310.py index 3eda88b194..7373edff5e 100644 --- a/docs_src/tutorial/fastapi/multiple_models/tutorial002_py310.py +++ b/docs_src/tutorial/fastapi/multiple_models/tutorial002_py310.py @@ -42,7 +42,7 @@ def on_startup(): @app.post("/heroes/", response_model=HeroRead) def create_hero(hero: HeroCreate): with Session(engine) as session: - db_hero = Hero.from_orm(hero) + db_hero = Hero.model_validate(hero) session.add(db_hero) session.commit() session.refresh(db_hero) diff --git a/docs_src/tutorial/fastapi/multiple_models/tutorial002_py39.py b/docs_src/tutorial/fastapi/multiple_models/tutorial002_py39.py index 473fe5b832..1b4a512520 100644 --- a/docs_src/tutorial/fastapi/multiple_models/tutorial002_py39.py +++ b/docs_src/tutorial/fastapi/multiple_models/tutorial002_py39.py @@ -44,7 +44,7 @@ def on_startup(): @app.post("/heroes/", response_model=HeroRead) def create_hero(hero: HeroCreate): with Session(engine) as session: - db_hero = Hero.from_orm(hero) + db_hero = Hero.model_validate(hero) session.add(db_hero) session.commit() session.refresh(db_hero) diff --git a/docs_src/tutorial/fastapi/read_one/tutorial001_py310.py b/docs_src/tutorial/fastapi/read_one/tutorial001_py310.py index 8883570dc5..e8c7d49b99 100644 --- a/docs_src/tutorial/fastapi/read_one/tutorial001_py310.py +++ b/docs_src/tutorial/fastapi/read_one/tutorial001_py310.py @@ -42,7 +42,7 @@ def on_startup(): @app.post("/heroes/", response_model=HeroRead) def create_hero(hero: HeroCreate): with Session(engine) as session: - db_hero = Hero.from_orm(hero) + db_hero = Hero.model_validate(hero) session.add(db_hero) session.commit() session.refresh(db_hero) diff --git a/docs_src/tutorial/fastapi/read_one/tutorial001_py39.py b/docs_src/tutorial/fastapi/read_one/tutorial001_py39.py index 0ad7016687..4dc5702fb6 100644 --- a/docs_src/tutorial/fastapi/read_one/tutorial001_py39.py +++ b/docs_src/tutorial/fastapi/read_one/tutorial001_py39.py @@ -44,7 +44,7 @@ def on_startup(): @app.post("/heroes/", response_model=HeroRead) def create_hero(hero: HeroCreate): with Session(engine) as session: - db_hero = Hero.from_orm(hero) + db_hero = Hero.model_validate(hero) session.add(db_hero) session.commit() session.refresh(db_hero) diff --git a/docs_src/tutorial/fastapi/relationships/tutorial001_py310.py b/docs_src/tutorial/fastapi/relationships/tutorial001_py310.py index bec6a6f2e2..8ac9eef332 100644 --- a/docs_src/tutorial/fastapi/relationships/tutorial001_py310.py +++ b/docs_src/tutorial/fastapi/relationships/tutorial001_py310.py @@ -90,7 +90,7 @@ def on_startup(): @app.post("/heroes/", response_model=HeroRead) def create_hero(*, session: Session = Depends(get_session), hero: HeroCreate): - db_hero = Hero.from_orm(hero) + db_hero = Hero.model_validate(hero) session.add(db_hero) session.commit() session.refresh(db_hero) @@ -144,7 +144,7 @@ def delete_hero(*, session: Session = Depends(get_session), hero_id: int): @app.post("/teams/", response_model=TeamRead) def create_team(*, session: Session = Depends(get_session), team: TeamCreate): - db_team = Team.from_orm(team) + db_team = Team.model_validate(team) session.add(db_team) session.commit() session.refresh(db_team) diff --git a/docs_src/tutorial/fastapi/relationships/tutorial001_py39.py b/docs_src/tutorial/fastapi/relationships/tutorial001_py39.py index 3893905519..d5c209e083 100644 --- a/docs_src/tutorial/fastapi/relationships/tutorial001_py39.py +++ b/docs_src/tutorial/fastapi/relationships/tutorial001_py39.py @@ -92,7 +92,7 @@ def on_startup(): @app.post("/heroes/", response_model=HeroRead) def create_hero(*, session: Session = Depends(get_session), hero: HeroCreate): - db_hero = Hero.from_orm(hero) + db_hero = Hero.model_validate(hero) session.add(db_hero) session.commit() session.refresh(db_hero) @@ -146,7 +146,7 @@ def delete_hero(*, session: Session = Depends(get_session), hero_id: int): @app.post("/teams/", response_model=TeamRead) def create_team(*, session: Session = Depends(get_session), team: TeamCreate): - db_team = Team.from_orm(team) + db_team = Team.model_validate(team) session.add(db_team) session.commit() session.refresh(db_team) diff --git a/docs_src/tutorial/fastapi/session_with_dependency/tutorial001_py310.py b/docs_src/tutorial/fastapi/session_with_dependency/tutorial001_py310.py index e8615d91df..66c34d6939 100644 --- a/docs_src/tutorial/fastapi/session_with_dependency/tutorial001_py310.py +++ b/docs_src/tutorial/fastapi/session_with_dependency/tutorial001_py310.py @@ -52,7 +52,7 @@ def on_startup(): @app.post("/heroes/", response_model=HeroRead) def create_hero(*, session: Session = Depends(get_session), hero: HeroCreate): - db_hero = Hero.from_orm(hero) + db_hero = Hero.model_validate(hero) session.add(db_hero) session.commit() session.refresh(db_hero) diff --git a/docs_src/tutorial/fastapi/session_with_dependency/tutorial001_py39.py b/docs_src/tutorial/fastapi/session_with_dependency/tutorial001_py39.py index 9816e70eb0..d71cb34f49 100644 --- a/docs_src/tutorial/fastapi/session_with_dependency/tutorial001_py39.py +++ b/docs_src/tutorial/fastapi/session_with_dependency/tutorial001_py39.py @@ -54,7 +54,7 @@ def on_startup(): @app.post("/heroes/", response_model=HeroRead) def create_hero(*, session: Session = Depends(get_session), hero: HeroCreate): - db_hero = Hero.from_orm(hero) + db_hero = Hero.model_validate(hero) session.add(db_hero) session.commit() session.refresh(db_hero) diff --git a/docs_src/tutorial/fastapi/teams/tutorial001_py310.py b/docs_src/tutorial/fastapi/teams/tutorial001_py310.py index a9a527df73..e57078c2ad 100644 --- a/docs_src/tutorial/fastapi/teams/tutorial001_py310.py +++ b/docs_src/tutorial/fastapi/teams/tutorial001_py310.py @@ -81,7 +81,7 @@ def on_startup(): @app.post("/heroes/", response_model=HeroRead) def create_hero(*, session: Session = Depends(get_session), hero: HeroCreate): - db_hero = Hero.from_orm(hero) + db_hero = Hero.model_validate(hero) session.add(db_hero) session.commit() session.refresh(db_hero) @@ -135,7 +135,7 @@ def delete_hero(*, session: Session = Depends(get_session), hero_id: int): @app.post("/teams/", response_model=TeamRead) def create_team(*, session: Session = Depends(get_session), team: TeamCreate): - db_team = Team.from_orm(team) + db_team = Team.model_validate(team) session.add(db_team) session.commit() session.refresh(db_team) diff --git a/docs_src/tutorial/fastapi/teams/tutorial001_py39.py b/docs_src/tutorial/fastapi/teams/tutorial001_py39.py index 1a36428994..eaa2e04896 100644 --- a/docs_src/tutorial/fastapi/teams/tutorial001_py39.py +++ b/docs_src/tutorial/fastapi/teams/tutorial001_py39.py @@ -83,7 +83,7 @@ def on_startup(): @app.post("/heroes/", response_model=HeroRead) def create_hero(*, session: Session = Depends(get_session), hero: HeroCreate): - db_hero = Hero.from_orm(hero) + db_hero = Hero.model_validate(hero) session.add(db_hero) session.commit() session.refresh(db_hero) @@ -137,7 +137,7 @@ def delete_hero(*, session: Session = Depends(get_session), hero_id: int): @app.post("/teams/", response_model=TeamRead) def create_team(*, session: Session = Depends(get_session), team: TeamCreate): - db_team = Team.from_orm(team) + db_team = Team.model_validate(team) session.add(db_team) session.commit() session.refresh(db_team) diff --git a/docs_src/tutorial/fastapi/update/tutorial001_py310.py b/docs_src/tutorial/fastapi/update/tutorial001_py310.py index 79069181fb..4125c40437 100644 --- a/docs_src/tutorial/fastapi/update/tutorial001_py310.py +++ b/docs_src/tutorial/fastapi/update/tutorial001_py310.py @@ -48,7 +48,7 @@ def on_startup(): @app.post("/heroes/", response_model=HeroRead) def create_hero(hero: HeroCreate): with Session(engine) as session: - db_hero = Hero.from_orm(hero) + db_hero = Hero.model_validate(hero) session.add(db_hero) session.commit() session.refresh(db_hero) diff --git a/docs_src/tutorial/fastapi/update/tutorial001_py39.py b/docs_src/tutorial/fastapi/update/tutorial001_py39.py index c788eb1c7a..2c1321dd22 100644 --- a/docs_src/tutorial/fastapi/update/tutorial001_py39.py +++ b/docs_src/tutorial/fastapi/update/tutorial001_py39.py @@ -50,7 +50,7 @@ def on_startup(): @app.post("/heroes/", response_model=HeroRead) def create_hero(hero: HeroCreate): with Session(engine) as session: - db_hero = Hero.from_orm(hero) + db_hero = Hero.model_validate(hero) session.add(db_hero) session.commit() session.refresh(db_hero) From 3070a1e791778fd27ca0da68c767ea2bb8396aa5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sebasti=C3=A1n=20Ram=C3=ADrez?= Date: Sun, 3 Dec 2023 18:46:42 +0100 Subject: [PATCH 082/105] =?UTF-8?q?=F0=9F=93=9D=20Add=20note=20about=20fro?= =?UTF-8?q?m=5Form=20to=20the=20docs?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- docs/tutorial/fastapi/multiple-models.md | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/docs/tutorial/fastapi/multiple-models.md b/docs/tutorial/fastapi/multiple-models.md index 183d9cdff7..16c9d96070 100644 --- a/docs/tutorial/fastapi/multiple-models.md +++ b/docs/tutorial/fastapi/multiple-models.md @@ -179,9 +179,13 @@ Then we create a new `Hero` (this is the actual **table** model that saves thing The method `.model_validate()` reads data from another object with attributes (or a dict) and creates a new instance of this class, in this case `Hero`. -But as in this case, we have a `HeroCreate` instance in the `hero` variable. This is an object with attributes, so we use `.model_validate()` to read those attributes. +In this case, we have a `HeroCreate` instance in the `hero` variable. This is an object with attributes, so we use `.model_validate()` to read those attributes. -With this, we create a new `Hero` instance (the one for the database) and put it in the variable `db_hero` from the data in the `hero` variable that is the `HeroCreate` instance we received from the request. +/// tip +In versions of **SQLModel** before `0.0.13` you would use the method `.from_orm()`, but it is now deprecated and you should use `.model_validate()` instead. +/// + +We can now create a new `Hero` instance (the one for the database) and put it in the variable `db_hero` from the data in the `hero` variable that is the `HeroCreate` instance we received from the request. ```Python hl_lines="3" # Code above omitted 👆 From 8ac78048b54c13283cf3e405166ce53b82e4d060 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sebasti=C3=A1n=20Ram=C3=ADrez?= Date: Sun, 3 Dec 2023 18:53:42 +0100 Subject: [PATCH 083/105] =?UTF-8?q?=E2=9C=85=20Tweak=20unreachable=20code?= =?UTF-8?q?=20coverage?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- sqlmodel/_compat.py | 2 +- tests/test_main.py | 1 - tests/test_missing_type.py | 2 +- 3 files changed, 2 insertions(+), 3 deletions(-) diff --git a/sqlmodel/_compat.py b/sqlmodel/_compat.py index b7f3f22fc8..cf134a6e42 100644 --- a/sqlmodel/_compat.py +++ b/sqlmodel/_compat.py @@ -204,7 +204,7 @@ def _calculate_keys( exclude: Optional[Mapping[Union[int, str], Any]], exclude_unset: bool, update: Optional[Dict[str, Any]] = None, - ) -> Optional[AbstractSet[str]]: + ) -> Optional[AbstractSet[str]]: # pragma: no cover return None def sqlmodel_table_construct( diff --git a/tests/test_main.py b/tests/test_main.py index bdbcdeb76d..60d5c40ebb 100644 --- a/tests/test_main.py +++ b/tests/test_main.py @@ -91,7 +91,6 @@ class Hero(SQLModel, table=True): with Session(engine) as session: session.add(hero_2) session.commit() - session.refresh(hero_2) def test_sa_relationship_property(clear_sqlmodel): diff --git a/tests/test_missing_type.py b/tests/test_missing_type.py index dc31f053ec..6f76eb6ad9 100644 --- a/tests/test_missing_type.py +++ b/tests/test_missing_type.py @@ -12,7 +12,7 @@ def __get_validators__(cls): yield cls.validate @classmethod - def validate(cls, v): + def validate(cls, v): # pragma: nocover return v with pytest.raises(ValueError): From ad5eb089991c684d673adafbe71eb08b392ab6f5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sebasti=C3=A1n=20Ram=C3=ADrez?= Date: Sun, 3 Dec 2023 18:57:30 +0100 Subject: [PATCH 084/105] =?UTF-8?q?=E2=99=BB=EF=B8=8F=20Simplify=20logic?= =?UTF-8?q?=20for=20is=5Ftable=5Fmodel=5Fclass?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- sqlmodel/_compat.py | 10 ++-------- 1 file changed, 2 insertions(+), 8 deletions(-) diff --git a/sqlmodel/_compat.py b/sqlmodel/_compat.py index cf134a6e42..e522cf7ce6 100644 --- a/sqlmodel/_compat.py +++ b/sqlmodel/_compat.py @@ -106,10 +106,7 @@ def get_annotations(class_dict: dict[str, Any]) -> dict[str, Any]: return class_dict.get("__annotations__", {}) def is_table_model_class(cls: Type) -> bool: - config = getattr(cls, "model_config", None) - if not config: - return False - return config.get("table", False) + return cls.model_config.get("table", False) def get_relationship_to( name: str, @@ -388,10 +385,7 @@ def get_annotations(class_dict: dict[str, Any]) -> dict[str, Any]: ) def is_table_model_class(cls: Type) -> bool: - config = getattr(cls, "__config__", None) - if not config: - return False - return getattr(config, "table", False) + return getattr(cls.__config__, "table", False) def get_relationship_to( name: str, From 0e74c50900dece592652e6fb0df46837400c8626 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sebasti=C3=A1n=20Ram=C3=ADrez?= Date: Sun, 3 Dec 2023 19:11:32 +0100 Subject: [PATCH 085/105] =?UTF-8?q?=E2=9C=85=20Add=20tests=20for=20depreca?= =?UTF-8?q?tions?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- tests/test_deprecations.py | 24 ++++++++++++++++++++++++ 1 file changed, 24 insertions(+) create mode 100644 tests/test_deprecations.py diff --git a/tests/test_deprecations.py b/tests/test_deprecations.py new file mode 100644 index 0000000000..0e565c5410 --- /dev/null +++ b/tests/test_deprecations.py @@ -0,0 +1,24 @@ +import pytest +from sqlmodel import SQLModel + + +class Item(SQLModel): + name: str + + +class SubItem(Item): + password: str + + +def test_deprecated_from_orm_inheritance(): + new_item = SubItem(name="Hello", password="secret") + with pytest.warns(DeprecationWarning): + item = Item.from_orm(new_item) + assert item.name == "Hello" + assert not hasattr(item, "password") + + +def test_deprecated_parse_obj(): + with pytest.warns(DeprecationWarning): + item = Item.parse_obj({"name": "Hello"}) + assert item.name == "Hello" From cac622a552899190bf0b20ea2950ff42db922453 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sebasti=C3=A1n=20Ram=C3=ADrez?= Date: Sun, 3 Dec 2023 19:11:54 +0100 Subject: [PATCH 086/105] =?UTF-8?q?=F0=9F=90=9B=20Fix=20implementation=20o?= =?UTF-8?q?f=20sqlmodel=5Fvalidate=20for=20Pydantic=20v1=20with=20deprecat?= =?UTF-8?q?ed=20parse=5Fobj=20with=20a=20dict?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- sqlmodel/_compat.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/sqlmodel/_compat.py b/sqlmodel/_compat.py index e522cf7ce6..eab14afd37 100644 --- a/sqlmodel/_compat.py +++ b/sqlmodel/_compat.py @@ -482,11 +482,12 @@ def sqlmodel_validate( raise ConfigError( "You must have the config attribute orm_mode=True to use from_orm" ) - obj = ( - {ROOT_KEY: obj} - if cls.__custom_root_type__ # noqa - else cls._decompose_class(obj) # noqa - ) + if not isinstance(obj, Mapping): + obj = ( + {ROOT_KEY: obj} + if cls.__custom_root_type__ # noqa + else cls._decompose_class(obj) # noqa + ) # SQLModel, support update dict if update is not None: obj = {**obj, **update} From 9d883f5c817433e25727b9ba3446604c5fc3c1e3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sebasti=C3=A1n=20Ram=C3=ADrez?= Date: Sun, 3 Dec 2023 19:18:07 +0100 Subject: [PATCH 087/105] =?UTF-8?q?=E2=9C=85=20Tweak=20coverage=20for=20un?= =?UTF-8?q?reachable=20tests?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- tests/test_field_sa_relationship.py | 4 ++-- tests/test_missing_type.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/test_field_sa_relationship.py b/tests/test_field_sa_relationship.py index 7606fd86d8..022a100a78 100644 --- a/tests/test_field_sa_relationship.py +++ b/tests/test_field_sa_relationship.py @@ -6,7 +6,7 @@ def test_sa_relationship_no_args() -> None: - with pytest.raises(RuntimeError): + with pytest.raises(RuntimeError): # pragma: no cover class Team(SQLModel, table=True): id: Optional[int] = Field(default=None, primary_key=True) @@ -30,7 +30,7 @@ class Hero(SQLModel, table=True): def test_sa_relationship_no_kwargs() -> None: - with pytest.raises(RuntimeError): + with pytest.raises(RuntimeError): # pragma: no cover class Team(SQLModel, table=True): id: Optional[int] = Field(default=None, primary_key=True) diff --git a/tests/test_missing_type.py b/tests/test_missing_type.py index 6f76eb6ad9..ac4aa42e05 100644 --- a/tests/test_missing_type.py +++ b/tests/test_missing_type.py @@ -12,7 +12,7 @@ def __get_validators__(cls): yield cls.validate @classmethod - def validate(cls, v): # pragma: nocover + def validate(cls, v): # pragma: no cover return v with pytest.raises(ValueError): From cac349c3b46b763a9b5aa36f3fbd1a856fa2d800 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sebasti=C3=A1n=20Ram=C3=ADrez?= Date: Mon, 4 Dec 2023 13:20:38 +0100 Subject: [PATCH 088/105] =?UTF-8?q?=E2=99=BB=EF=B8=8F=20Update=20reference?= =?UTF-8?q?s=20to=20version=200.0.13=20as=20latest?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- docs/tutorial/fastapi/multiple-models.md | 2 +- sqlmodel/main.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/docs/tutorial/fastapi/multiple-models.md b/docs/tutorial/fastapi/multiple-models.md index 16c9d96070..3995daa650 100644 --- a/docs/tutorial/fastapi/multiple-models.md +++ b/docs/tutorial/fastapi/multiple-models.md @@ -182,7 +182,7 @@ The method `.model_validate()` reads data from another object with attributes (o In this case, we have a `HeroCreate` instance in the `hero` variable. This is an object with attributes, so we use `.model_validate()` to read those attributes. /// tip -In versions of **SQLModel** before `0.0.13` you would use the method `.from_orm()`, but it is now deprecated and you should use `.model_validate()` instead. +In versions of **SQLModel** before `0.0.14` you would use the method `.from_orm()`, but it is now deprecated and you should use `.model_validate()` instead. /// We can now create a new `Hero` instance (the one for the database) and put it in the variable `db_hero` from the data in the `hero` variable that is the `HeroCreate` instance we received from the request. diff --git a/sqlmodel/main.py b/sqlmodel/main.py index c2268f7415..24b7c514f0 100644 --- a/sqlmodel/main.py +++ b/sqlmodel/main.py @@ -762,7 +762,7 @@ def model_validate( @classmethod @deprecated( """ - 🚨 `obj.from_orm(data)` was deprecated in SQLModel 0.0.12, you should + 🚨 `obj.from_orm(data)` was deprecated in SQLModel 0.0.14, you should instead use `obj.model_validate(data)`. """ ) @@ -774,7 +774,7 @@ def from_orm( @classmethod @deprecated( """ - 🚨 `obj.parse_obj(data)` was deprecated in SQLModel 0.0.12, you should + 🚨 `obj.parse_obj(data)` was deprecated in SQLModel 0.0.14, you should instead use `obj.model_validate(data)`. """ ) From 8d5e62b0533ea23d06b4c8e9ecb8fbf0ab3494bd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sebasti=C3=A1n=20Ram=C3=ADrez?= Date: Mon, 4 Dec 2023 14:40:52 +0100 Subject: [PATCH 089/105] =?UTF-8?q?=F0=9F=8E=A8=20Fix=20type=20annotations?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- sqlmodel/_compat.py | 86 +++++++++++++++++++++++++++------------------ sqlmodel/main.py | 18 +++++----- 2 files changed, 60 insertions(+), 44 deletions(-) diff --git a/sqlmodel/_compat.py b/sqlmodel/_compat.py index eab14afd37..97467dbaf7 100644 --- a/sqlmodel/_compat.py +++ b/sqlmodel/_compat.py @@ -8,6 +8,7 @@ Any, Dict, ForwardRef, + Generator, Mapping, Optional, Set, @@ -59,7 +60,7 @@ def _is_union_type(t: Any) -> bool: @contextmanager -def partial_init(): +def partial_init() -> Generator[None, None, None]: token = finish_init.set(False) yield finish_init.reset(token) @@ -70,7 +71,7 @@ def partial_init(): from pydantic._internal._fields import PydanticMetadata from pydantic._internal._model_construction import ModelMetaclass from pydantic._internal._repr import Representation as Representation - from pydantic_core import PydanticUndefined as Undefined # noqa + from pydantic_core import PydanticUndefined as Undefined from pydantic_core import PydanticUndefinedType as UndefinedType # Dummy for types, to make it importable @@ -92,10 +93,10 @@ def set_config_value( parameter: str, value: Any, ) -> None: - model.model_config[parameter] = value + model.model_config[parameter] = value # type: ignore[literal-required] def get_model_fields(model: InstanceOrType["SQLModel"]) -> Dict[str, "FieldInfo"]: - return model.model_fields # type: ignore + return model.model_fields def set_fields_set( new_object: InstanceOrType["SQLModel"], fields: set["FieldInfo"] @@ -105,8 +106,11 @@ def set_fields_set( def get_annotations(class_dict: dict[str, Any]) -> dict[str, Any]: return class_dict.get("__annotations__", {}) - def is_table_model_class(cls: Type) -> bool: - return cls.model_config.get("table", False) + def is_table_model_class(cls: Type[Any]) -> bool: + config = getattr(cls, "model_config", {}) + if config: + return config.get("table", False) or False + return False def get_relationship_to( name: str, @@ -157,7 +161,7 @@ def is_field_noneable(field: "FieldInfo") -> bool: if not field.is_required(): if field.default is Undefined: return False - if field.annotation is None or field.annotation is NoneType: + if field.annotation is None or field.annotation is NoneType: # type: ignore[comparison-overlap] return True return False return False @@ -182,8 +186,8 @@ def get_type_from_field(field: Any) -> type: "Cannot have a (non-optional) union as a SQLlchemy field" ) # Optional unions are allowed - return bases[0] if bases[0] is not NoneType else bases[1] - return origin + return bases[0] if bases[0] is not NoneType else bases[1] # type: ignore[no-any-return] + return origin # type: ignore[no-any-return] def get_field_metadata(field: Any) -> Any: for meta in field.metadata: @@ -196,7 +200,7 @@ def post_init_field_info(field_info: FieldInfo) -> None: # Dummy to make it importable def _calculate_keys( - self, + self: "SQLModel", include: Optional[Mapping[Union[int, str], Any]], exclude: Optional[Mapping[Union[int, str], Any]], exclude_unset: bool, @@ -342,25 +346,36 @@ def sqlmodel_init(*, self: "SQLModel", data: Dict[str, Any]) -> None: ) else: - from pydantic import BaseConfig as BaseConfig + from pydantic import BaseConfig as BaseConfig # type: ignore[assignment] from pydantic.errors import ConfigError - from pydantic.fields import SHAPE_SINGLETON, ModelField - from pydantic.fields import Undefined as Undefined # noqa - from pydantic.fields import UndefinedType as UndefinedType - from pydantic.main import ModelMetaclass as ModelMetaclass + from pydantic.fields import ( # type: ignore[attr-defined, no-redef] + SHAPE_SINGLETON, + ModelField, + ) + from pydantic.fields import ( # type: ignore[attr-defined, no-redef] + Undefined as Undefined, # noqa + ) + from pydantic.fields import ( # type: ignore[attr-defined, no-redef] + UndefinedType as UndefinedType, + ) + from pydantic.main import ( # type: ignore[no-redef] + ModelMetaclass as ModelMetaclass, + ) from pydantic.main import validate_model from pydantic.typing import resolve_annotations from pydantic.utils import ROOT_KEY, ValueItems - from pydantic.utils import Representation as Representation + from pydantic.utils import ( # type: ignore[no-redef] + Representation as Representation, + ) - class SQLModelConfig(BaseConfig): - table: Optional[bool] = None - registry: Optional[Any] = None + class SQLModelConfig(BaseConfig): # type: ignore[no-redef] + table: Optional[bool] = None # type: ignore[misc] + registry: Optional[Any] = None # type: ignore[misc] def get_config_value( *, model: InstanceOrType["SQLModel"], parameter: str, default: Any = None ) -> Any: - return getattr(model.__config__, parameter, default) + return getattr(model.__config__, parameter, default) # type: ignore[union-attr] def set_config_value( *, @@ -379,20 +394,23 @@ def set_fields_set( object.__setattr__(new_object, "__fields_set__", fields) def get_annotations(class_dict: dict[str, Any]) -> dict[str, Any]: - return resolve_annotations( + return resolve_annotations( # type: ignore[no-any-return] class_dict.get("__annotations__", {}), class_dict.get("__module__", None), ) - def is_table_model_class(cls: Type) -> bool: - return getattr(cls.__config__, "table", False) + def is_table_model_class(cls: Type[Any]) -> bool: + config = getattr(cls, "__config__", None) + if config: + return getattr(config, "table", False) + return False def get_relationship_to( name: str, rel_info: "RelationshipInfo", annotation: Any, ) -> Any: - temp_field = ModelField.infer( + temp_field = ModelField.infer( # type: ignore[attr-defined] name=name, value=rel_info, annotation=annotation, @@ -405,12 +423,12 @@ def get_relationship_to( return relationship_to def is_field_noneable(field: "FieldInfo") -> bool: - if not field.required: + if not field.required: # type: ignore[attr-defined] # Taken from [Pydantic](https://github.com/samuelcolvin/pydantic/blob/v1.8.2/pydantic/fields.py#L946-L947) - return field.allow_none and ( - field.shape != SHAPE_SINGLETON or not field.sub_fields + return field.allow_none and ( # type: ignore[attr-defined] + field.shape != SHAPE_SINGLETON or not field.sub_fields # type: ignore[attr-defined] ) - return field.allow_none + return field.allow_none # type: ignore[no-any-return, attr-defined] def get_type_from_field(field: Any) -> type: if isinstance(field.type_, type) and field.shape == SHAPE_SINGLETON: @@ -425,10 +443,10 @@ def get_field_metadata(field: Any) -> Any: return metadata def post_init_field_info(field_info: FieldInfo) -> None: - field_info._validate() + field_info._validate() # type: ignore[attr-defined] def _calculate_keys( - self, + self: "SQLModel", include: Optional[Mapping[Union[int, str], Any]], exclude: Optional[Mapping[Union[int, str], Any]], exclude_unset: bool, @@ -478,15 +496,15 @@ def sqlmodel_validate( ) -> _TSQLModel: # This was SQLModel's original from_orm() for Pydantic v1 # Duplicated from Pydantic - if not cls.__config__.orm_mode: # noqa + if not cls.__config__.orm_mode: # type: ignore[attr-defined] # noqa raise ConfigError( "You must have the config attribute orm_mode=True to use from_orm" ) if not isinstance(obj, Mapping): obj = ( {ROOT_KEY: obj} - if cls.__custom_root_type__ # noqa - else cls._decompose_class(obj) # noqa + if cls.__custom_root_type__ # type: ignore[attr-defined] # noqa + else cls._decompose_class(obj) # type: ignore[attr-defined] # noqa ) # SQLModel, support update dict if update is not None: @@ -510,7 +528,7 @@ def sqlmodel_validate( setattr(m, key, value) # Continue with standard Pydantic logic object.__setattr__(m, "__fields_set__", fields_set) - m._init_private_attributes() # noqa + m._init_private_attributes() # type: ignore[attr-defined] # noqa return m def sqlmodel_init(*, self: "SQLModel", data: Dict[str, Any]) -> None: diff --git a/sqlmodel/main.py b/sqlmodel/main.py index 24b7c514f0..736b89849e 100644 --- a/sqlmodel/main.py +++ b/sqlmodel/main.py @@ -53,7 +53,7 @@ from sqlalchemy.sql.sqltypes import LargeBinary, Time from typing_extensions import deprecated, get_origin -from ._compat import ( +from ._compat import ( # type: ignore[attr-defined] IS_PYDANTIC_V2, BaseConfig, ModelField, @@ -361,7 +361,7 @@ def Relationship( *, back_populates: Optional[str] = None, link_model: Optional[Any] = None, - sa_relationship: Optional[RelationshipProperty] = None, # type: ignore + sa_relationship: Optional[RelationshipProperty[Any]] = None, ) -> Any: ... @@ -370,7 +370,7 @@ def Relationship( *, back_populates: Optional[str] = None, link_model: Optional[Any] = None, - sa_relationship: Optional[RelationshipProperty] = None, + sa_relationship: Optional[RelationshipProperty[Any]] = None, sa_relationship_args: Optional[Sequence[Any]] = None, sa_relationship_kwargs: Optional[Mapping[str, Any]] = None, ) -> Any: @@ -390,7 +390,7 @@ class SQLModelMetaclass(ModelMetaclass, DeclarativeMeta): model_config: SQLModelConfig model_fields: Dict[str, FieldInfo] __config__: Type[SQLModelConfig] - __fields__: Dict[str, ModelField] + __fields__: Dict[str, ModelField] # type: ignore[assignment] # Replicate SQLAlchemy def __setattr__(cls, name: str, value: Any) -> None: @@ -447,9 +447,7 @@ def __new__( config_kwargs = { key: kwargs[key] for key in kwargs.keys() & allowed_config_kwargs } - new_cls: Type["SQLModelMetaclass"] = super().__new__( - cls, name, bases, dict_used, **config_kwargs - ) + new_cls = super().__new__(cls, name, bases, dict_used, **config_kwargs) new_cls.__annotations__ = { **relationship_annotations, **pydantic_annotations, @@ -673,7 +671,7 @@ class SQLModel(BaseModel, metaclass=SQLModelMetaclass, registry=default_registry # SQLAlchemy needs to set weakref(s), Pydantic will set the other slots values __slots__ = ("__weakref__",) __tablename__: ClassVar[Union[str, Callable[..., str]]] - __sqlmodel_relationships__: ClassVar[Dict[str, RelationshipProperty]] + __sqlmodel_relationships__: ClassVar[Dict[str, RelationshipProperty[Any]]] __name__: ClassVar[str] metadata: ClassVar[MetaData] __allow_unmapped__ = True # https://docs.sqlalchemy.org/en/20/changelog/migration_20.html#migration-20-step-six @@ -721,7 +719,7 @@ def __setattr__(self, name: str, value: Any) -> None: return else: # Set in SQLAlchemy, before Pydantic to trigger events and updates - if is_table_model_class(self) and is_instrumented(self, name): + if is_table_model_class(self.__class__) and is_instrumented(self, name): # type: ignore[no-untyped-call] set_attribute(self, name, value) # Set in Pydantic model to trigger possible validation changes, only for # non relationship values @@ -782,7 +780,7 @@ def parse_obj( cls: Type[_TSQLModel], obj: Any, update: Optional[Dict[str, Any]] = None ) -> _TSQLModel: if not IS_PYDANTIC_V2: - obj = cls._enforce_dict_if_root(obj) # noqa + obj = cls._enforce_dict_if_root(obj) # type: ignore[attr-defined] # noqa return cls.model_validate(obj, update=update) # From Pydantic, override to only show keys from fields, omit SQLAlchemy attributes From d03c233ab976bc894d77708d388303345ccdbf1a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sebasti=C3=A1n=20Ram=C3=ADrez?= Date: Mon, 4 Dec 2023 14:51:39 +0100 Subject: [PATCH 090/105] =?UTF-8?q?=E2=9C=A8=20Add=20method=20model=5Fdump?= =?UTF-8?q?=20for=20compatibility=20with=20Pydantic=20v1=20while=20keeping?= =?UTF-8?q?=20v2=20method=20names?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- sqlmodel/main.py | 39 ++++++++++++++++++++++++++++++++++++++- 1 file changed, 38 insertions(+), 1 deletion(-) diff --git a/sqlmodel/main.py b/sqlmodel/main.py index 736b89849e..70f904813c 100644 --- a/sqlmodel/main.py +++ b/sqlmodel/main.py @@ -51,7 +51,7 @@ from sqlalchemy.orm.instrumentation import is_instrumented from sqlalchemy.sql.schema import MetaData from sqlalchemy.sql.sqltypes import LargeBinary, Time -from typing_extensions import deprecated, get_origin +from typing_extensions import Literal, deprecated, get_origin from ._compat import ( # type: ignore[attr-defined] IS_PYDANTIC_V2, @@ -82,6 +82,7 @@ _T = TypeVar("_T") NoArgAnyCallable = Callable[[], Any] +IncEx = Set[int] | Set[str] | Dict[int, Any] | Dict[str, Any] | None def __dataclass_transform__( @@ -757,6 +758,42 @@ def model_validate( update=update, ) + # TODO: remove when deprecating Pydantic v1, only for compatibility + def model_dump( + self, + *, + mode: Literal["json", "python"] | str = "python", + include: IncEx = None, + exclude: IncEx = None, + by_alias: bool = False, + exclude_unset: bool = False, + exclude_defaults: bool = False, + exclude_none: bool = False, + round_trip: bool = False, + warnings: bool = True, + ) -> Dict[str, Any]: + if IS_PYDANTIC_V2: + return super().model_dump( + mode=mode, + include=include, + exclude=exclude, + by_alias=by_alias, + exclude_unset=exclude_unset, + exclude_defaults=exclude_defaults, + exclude_none=exclude_none, + round_trip=round_trip, + warnings=warnings, + ) + else: + return self.dict( + include=include, + exclude=exclude, + by_alias=by_alias, + exclude_unset=exclude_unset, + exclude_defaults=exclude_defaults, + exclude_none=exclude_none, + ) + @classmethod @deprecated( """ From be49ad5f9215959471fe687dc712508c57b5b449 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sebasti=C3=A1n=20Ram=C3=ADrez?= Date: Mon, 4 Dec 2023 14:54:40 +0100 Subject: [PATCH 091/105] =?UTF-8?q?=E2=9C=A8=20Add=20override=20for=20.dic?= =?UTF-8?q?t()=20only=20to=20deprecate=20it=20and=20help=20people=20migrat?= =?UTF-8?q?e=20to=20model=5Fdump,=20even=20(if)=20before=20they=20migrate?= =?UTF-8?q?=20to=20Pydantic=20v2?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- sqlmodel/main.py | 25 +++++++++++++++++++++++++ 1 file changed, 25 insertions(+) diff --git a/sqlmodel/main.py b/sqlmodel/main.py index 70f904813c..9d2eb3af9f 100644 --- a/sqlmodel/main.py +++ b/sqlmodel/main.py @@ -794,6 +794,31 @@ def model_dump( exclude_none=exclude_none, ) + @deprecated( + """ + 🚨 `obj.dict()` was deprecated in SQLModel 0.0.14, you should + instead use `obj.model_dump()`. + """ + ) + def dict( + self, + *, + include: IncEx = None, + exclude: IncEx = None, + by_alias: bool = False, + exclude_unset: bool = False, + exclude_defaults: bool = False, + exclude_none: bool = False, + ) -> Dict[str, Any]: + return self.model_dump( + include=include, + exclude=exclude, + by_alias=by_alias, + exclude_unset=exclude_unset, + exclude_defaults=exclude_defaults, + exclude_none=exclude_none, + ) + @classmethod @deprecated( """ From 10e435613808e8fb72579ccae4efa676e057f3ed Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sebasti=C3=A1n=20Ram=C3=ADrez?= Date: Mon, 4 Dec 2023 14:56:17 +0100 Subject: [PATCH 092/105] =?UTF-8?q?=F0=9F=93=9D=20Update=20docs=20to=20use?= =?UTF-8?q?=20new=20method=20obj.model=5Fdump()=20instead=20of=20obj.dict(?= =?UTF-8?q?)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- docs/tutorial/fastapi/update.md | 15 +++++++++------ .../fastapi/app_testing/tutorial001/main.py | 2 +- .../fastapi/app_testing/tutorial001_py310/main.py | 2 +- .../fastapi/app_testing/tutorial001_py39/main.py | 2 +- docs_src/tutorial/fastapi/delete/tutorial001.py | 2 +- .../tutorial/fastapi/delete/tutorial001_py310.py | 2 +- .../tutorial/fastapi/delete/tutorial001_py39.py | 2 +- .../tutorial/fastapi/relationships/tutorial001.py | 4 ++-- .../fastapi/relationships/tutorial001_py310.py | 4 ++-- .../fastapi/relationships/tutorial001_py39.py | 4 ++-- .../session_with_dependency/tutorial001.py | 2 +- .../session_with_dependency/tutorial001_py310.py | 2 +- .../session_with_dependency/tutorial001_py39.py | 2 +- docs_src/tutorial/fastapi/teams/tutorial001.py | 4 ++-- .../tutorial/fastapi/teams/tutorial001_py310.py | 4 ++-- .../tutorial/fastapi/teams/tutorial001_py39.py | 4 ++-- docs_src/tutorial/fastapi/update/tutorial001.py | 2 +- .../tutorial/fastapi/update/tutorial001_py310.py | 2 +- .../tutorial/fastapi/update/tutorial001_py39.py | 2 +- 19 files changed, 33 insertions(+), 30 deletions(-) diff --git a/docs/tutorial/fastapi/update.md b/docs/tutorial/fastapi/update.md index 27c413f387..cfcf8a98e7 100644 --- a/docs/tutorial/fastapi/update.md +++ b/docs/tutorial/fastapi/update.md @@ -90,7 +90,7 @@ So, we need to read the hero from the database, with the **same logic** we used The `HeroUpdate` model has all the fields with **default values**, because they all have defaults, they are all optional, which is what we want. -But that also means that if we just call `hero.dict()` we will get a dictionary that could potentially have several or all of those values with their defaults, for example: +But that also means that if we just call `hero.model_dump()` we will get a dictionary that could potentially have several or all of those values with their defaults, for example: ```Python { @@ -102,7 +102,7 @@ But that also means that if we just call `hero.dict()` we will get a dictionary And then, if we update the hero in the database with this data, we would be removing any existing values, and that's probably **not what the client intended**. -But fortunately Pydantic models (and so SQLModel models) have a parameter we can pass to the `.dict()` method for that: `exclude_unset=True`. +But fortunately Pydantic models (and so SQLModel models) have a parameter we can pass to the `.model_dump()` method for that: `exclude_unset=True`. This tells Pydantic to **not include** the values that were **not sent** by the client. Saying it another way, it would **only** include the values that were **sent by the client**. @@ -112,7 +112,7 @@ So, if the client sent a JSON with no values: {} ``` -Then the dictionary we would get in Python using `hero.dict(exclude_unset=True)` would be: +Then the dictionary we would get in Python using `hero.model_dump(exclude_unset=True)` would be: ```Python {} @@ -126,7 +126,7 @@ But if the client sent a JSON with: } ``` -Then the dictionary we would get in Python using `hero.dict(exclude_unset=True)` would be: +Then the dictionary we would get in Python using `hero.model_dump(exclude_unset=True)` would be: ```Python { @@ -152,6 +152,9 @@ Then we use that to get the data that was actually sent by the client: /// +/// tip +Before SQLModel 0.0.14, the method was called `hero.dict(exclude_unset=True)`, but it was renamed to `hero.model_dump(exclude_unset=True)` to be consistent with Pydantic v2. + ## Update the Hero in the Database Now that we have a **dictionary with the data sent by the client**, we can iterate for each one of the keys and the values, and then we set them in the database hero model `db_hero` using `setattr()`. @@ -208,7 +211,7 @@ So, if the client wanted to intentionally remove the `age` of a hero, they could } ``` -And when getting the data with `hero.dict(exclude_unset=True)`, we would get: +And when getting the data with `hero.model_dump(exclude_unset=True)`, we would get: ```Python { @@ -226,4 +229,4 @@ These are some of the advantages of Pydantic, that we can use with SQLModel. ## Recap -Using `.dict(exclude_unset=True)` in SQLModel models (and Pydantic models) we can easily update data **correctly**, even in the **edge cases**. 😎 +Using `.model_dump(exclude_unset=True)` in SQLModel models (and Pydantic models) we can easily update data **correctly**, even in the **edge cases**. 😎 diff --git a/docs_src/tutorial/fastapi/app_testing/tutorial001/main.py b/docs_src/tutorial/fastapi/app_testing/tutorial001/main.py index a23dfad5a8..7014a73918 100644 --- a/docs_src/tutorial/fastapi/app_testing/tutorial001/main.py +++ b/docs_src/tutorial/fastapi/app_testing/tutorial001/main.py @@ -87,7 +87,7 @@ def update_hero( db_hero = session.get(Hero, hero_id) if not db_hero: raise HTTPException(status_code=404, detail="Hero not found") - hero_data = hero.dict(exclude_unset=True) + hero_data = hero.model_dump(exclude_unset=True) for key, value in hero_data.items(): setattr(db_hero, key, value) session.add(db_hero) diff --git a/docs_src/tutorial/fastapi/app_testing/tutorial001_py310/main.py b/docs_src/tutorial/fastapi/app_testing/tutorial001_py310/main.py index 66c34d6939..cf1bbb7130 100644 --- a/docs_src/tutorial/fastapi/app_testing/tutorial001_py310/main.py +++ b/docs_src/tutorial/fastapi/app_testing/tutorial001_py310/main.py @@ -85,7 +85,7 @@ def update_hero( db_hero = session.get(Hero, hero_id) if not db_hero: raise HTTPException(status_code=404, detail="Hero not found") - hero_data = hero.dict(exclude_unset=True) + hero_data = hero.model_dump(exclude_unset=True) for key, value in hero_data.items(): setattr(db_hero, key, value) session.add(db_hero) diff --git a/docs_src/tutorial/fastapi/app_testing/tutorial001_py39/main.py b/docs_src/tutorial/fastapi/app_testing/tutorial001_py39/main.py index d71cb34f49..9f428ab3e8 100644 --- a/docs_src/tutorial/fastapi/app_testing/tutorial001_py39/main.py +++ b/docs_src/tutorial/fastapi/app_testing/tutorial001_py39/main.py @@ -87,7 +87,7 @@ def update_hero( db_hero = session.get(Hero, hero_id) if not db_hero: raise HTTPException(status_code=404, detail="Hero not found") - hero_data = hero.dict(exclude_unset=True) + hero_data = hero.model_dump(exclude_unset=True) for key, value in hero_data.items(): setattr(db_hero, key, value) session.add(db_hero) diff --git a/docs_src/tutorial/fastapi/delete/tutorial001.py b/docs_src/tutorial/fastapi/delete/tutorial001.py index 77a99a9c97..532817360a 100644 --- a/docs_src/tutorial/fastapi/delete/tutorial001.py +++ b/docs_src/tutorial/fastapi/delete/tutorial001.py @@ -79,7 +79,7 @@ def update_hero(hero_id: int, hero: HeroUpdate): db_hero = session.get(Hero, hero_id) if not db_hero: raise HTTPException(status_code=404, detail="Hero not found") - hero_data = hero.dict(exclude_unset=True) + hero_data = hero.model_dump(exclude_unset=True) for key, value in hero_data.items(): setattr(db_hero, key, value) session.add(db_hero) diff --git a/docs_src/tutorial/fastapi/delete/tutorial001_py310.py b/docs_src/tutorial/fastapi/delete/tutorial001_py310.py index 9a3cbb1dd6..45e2e1d515 100644 --- a/docs_src/tutorial/fastapi/delete/tutorial001_py310.py +++ b/docs_src/tutorial/fastapi/delete/tutorial001_py310.py @@ -77,7 +77,7 @@ def update_hero(hero_id: int, hero: HeroUpdate): db_hero = session.get(Hero, hero_id) if not db_hero: raise HTTPException(status_code=404, detail="Hero not found") - hero_data = hero.dict(exclude_unset=True) + hero_data = hero.model_dump(exclude_unset=True) for key, value in hero_data.items(): setattr(db_hero, key, value) session.add(db_hero) diff --git a/docs_src/tutorial/fastapi/delete/tutorial001_py39.py b/docs_src/tutorial/fastapi/delete/tutorial001_py39.py index 218d423ff2..12f6bc3f9b 100644 --- a/docs_src/tutorial/fastapi/delete/tutorial001_py39.py +++ b/docs_src/tutorial/fastapi/delete/tutorial001_py39.py @@ -79,7 +79,7 @@ def update_hero(hero_id: int, hero: HeroUpdate): db_hero = session.get(Hero, hero_id) if not db_hero: raise HTTPException(status_code=404, detail="Hero not found") - hero_data = hero.dict(exclude_unset=True) + hero_data = hero.model_dump(exclude_unset=True) for key, value in hero_data.items(): setattr(db_hero, key, value) session.add(db_hero) diff --git a/docs_src/tutorial/fastapi/relationships/tutorial001.py b/docs_src/tutorial/fastapi/relationships/tutorial001.py index e5b196090e..51339e2a20 100644 --- a/docs_src/tutorial/fastapi/relationships/tutorial001.py +++ b/docs_src/tutorial/fastapi/relationships/tutorial001.py @@ -125,7 +125,7 @@ def update_hero( db_hero = session.get(Hero, hero_id) if not db_hero: raise HTTPException(status_code=404, detail="Hero not found") - hero_data = hero.dict(exclude_unset=True) + hero_data = hero.model_dump(exclude_unset=True) for key, value in hero_data.items(): setattr(db_hero, key, value) session.add(db_hero) @@ -182,7 +182,7 @@ def update_team( db_team = session.get(Team, team_id) if not db_team: raise HTTPException(status_code=404, detail="Team not found") - team_data = team.dict(exclude_unset=True) + team_data = team.model_dump(exclude_unset=True) for key, value in team_data.items(): setattr(db_team, key, value) session.add(db_team) diff --git a/docs_src/tutorial/fastapi/relationships/tutorial001_py310.py b/docs_src/tutorial/fastapi/relationships/tutorial001_py310.py index 8ac9eef332..35257bd513 100644 --- a/docs_src/tutorial/fastapi/relationships/tutorial001_py310.py +++ b/docs_src/tutorial/fastapi/relationships/tutorial001_py310.py @@ -123,7 +123,7 @@ def update_hero( db_hero = session.get(Hero, hero_id) if not db_hero: raise HTTPException(status_code=404, detail="Hero not found") - hero_data = hero.dict(exclude_unset=True) + hero_data = hero.model_dump(exclude_unset=True) for key, value in hero_data.items(): setattr(db_hero, key, value) session.add(db_hero) @@ -180,7 +180,7 @@ def update_team( db_team = session.get(Team, team_id) if not db_team: raise HTTPException(status_code=404, detail="Team not found") - team_data = team.dict(exclude_unset=True) + team_data = team.model_dump(exclude_unset=True) for key, value in team_data.items(): setattr(db_team, key, value) session.add(db_team) diff --git a/docs_src/tutorial/fastapi/relationships/tutorial001_py39.py b/docs_src/tutorial/fastapi/relationships/tutorial001_py39.py index d5c209e083..6ceae130a3 100644 --- a/docs_src/tutorial/fastapi/relationships/tutorial001_py39.py +++ b/docs_src/tutorial/fastapi/relationships/tutorial001_py39.py @@ -125,7 +125,7 @@ def update_hero( db_hero = session.get(Hero, hero_id) if not db_hero: raise HTTPException(status_code=404, detail="Hero not found") - hero_data = hero.dict(exclude_unset=True) + hero_data = hero.model_dump(exclude_unset=True) for key, value in hero_data.items(): setattr(db_hero, key, value) session.add(db_hero) @@ -182,7 +182,7 @@ def update_team( db_team = session.get(Team, team_id) if not db_team: raise HTTPException(status_code=404, detail="Team not found") - team_data = team.dict(exclude_unset=True) + team_data = team.model_dump(exclude_unset=True) for key, value in team_data.items(): setattr(db_team, key, value) session.add(db_team) diff --git a/docs_src/tutorial/fastapi/session_with_dependency/tutorial001.py b/docs_src/tutorial/fastapi/session_with_dependency/tutorial001.py index a23dfad5a8..7014a73918 100644 --- a/docs_src/tutorial/fastapi/session_with_dependency/tutorial001.py +++ b/docs_src/tutorial/fastapi/session_with_dependency/tutorial001.py @@ -87,7 +87,7 @@ def update_hero( db_hero = session.get(Hero, hero_id) if not db_hero: raise HTTPException(status_code=404, detail="Hero not found") - hero_data = hero.dict(exclude_unset=True) + hero_data = hero.model_dump(exclude_unset=True) for key, value in hero_data.items(): setattr(db_hero, key, value) session.add(db_hero) diff --git a/docs_src/tutorial/fastapi/session_with_dependency/tutorial001_py310.py b/docs_src/tutorial/fastapi/session_with_dependency/tutorial001_py310.py index 66c34d6939..cf1bbb7130 100644 --- a/docs_src/tutorial/fastapi/session_with_dependency/tutorial001_py310.py +++ b/docs_src/tutorial/fastapi/session_with_dependency/tutorial001_py310.py @@ -85,7 +85,7 @@ def update_hero( db_hero = session.get(Hero, hero_id) if not db_hero: raise HTTPException(status_code=404, detail="Hero not found") - hero_data = hero.dict(exclude_unset=True) + hero_data = hero.model_dump(exclude_unset=True) for key, value in hero_data.items(): setattr(db_hero, key, value) session.add(db_hero) diff --git a/docs_src/tutorial/fastapi/session_with_dependency/tutorial001_py39.py b/docs_src/tutorial/fastapi/session_with_dependency/tutorial001_py39.py index d71cb34f49..9f428ab3e8 100644 --- a/docs_src/tutorial/fastapi/session_with_dependency/tutorial001_py39.py +++ b/docs_src/tutorial/fastapi/session_with_dependency/tutorial001_py39.py @@ -87,7 +87,7 @@ def update_hero( db_hero = session.get(Hero, hero_id) if not db_hero: raise HTTPException(status_code=404, detail="Hero not found") - hero_data = hero.dict(exclude_unset=True) + hero_data = hero.model_dump(exclude_unset=True) for key, value in hero_data.items(): setattr(db_hero, key, value) session.add(db_hero) diff --git a/docs_src/tutorial/fastapi/teams/tutorial001.py b/docs_src/tutorial/fastapi/teams/tutorial001.py index cc73bb52cb..785c525918 100644 --- a/docs_src/tutorial/fastapi/teams/tutorial001.py +++ b/docs_src/tutorial/fastapi/teams/tutorial001.py @@ -116,7 +116,7 @@ def update_hero( db_hero = session.get(Hero, hero_id) if not db_hero: raise HTTPException(status_code=404, detail="Hero not found") - hero_data = hero.dict(exclude_unset=True) + hero_data = hero.model_dump(exclude_unset=True) for key, value in hero_data.items(): setattr(db_hero, key, value) session.add(db_hero) @@ -173,7 +173,7 @@ def update_team( db_team = session.get(Team, team_id) if not db_team: raise HTTPException(status_code=404, detail="Team not found") - team_data = team.dict(exclude_unset=True) + team_data = team.model_dump(exclude_unset=True) for key, value in team_data.items(): setattr(db_team, key, value) session.add(db_team) diff --git a/docs_src/tutorial/fastapi/teams/tutorial001_py310.py b/docs_src/tutorial/fastapi/teams/tutorial001_py310.py index e57078c2ad..dea4bd8a9b 100644 --- a/docs_src/tutorial/fastapi/teams/tutorial001_py310.py +++ b/docs_src/tutorial/fastapi/teams/tutorial001_py310.py @@ -114,7 +114,7 @@ def update_hero( db_hero = session.get(Hero, hero_id) if not db_hero: raise HTTPException(status_code=404, detail="Hero not found") - hero_data = hero.dict(exclude_unset=True) + hero_data = hero.model_dump(exclude_unset=True) for key, value in hero_data.items(): setattr(db_hero, key, value) session.add(db_hero) @@ -171,7 +171,7 @@ def update_team( db_team = session.get(Team, team_id) if not db_team: raise HTTPException(status_code=404, detail="Team not found") - team_data = team.dict(exclude_unset=True) + team_data = team.model_dump(exclude_unset=True) for key, value in team_data.items(): setattr(db_team, key, value) session.add(db_team) diff --git a/docs_src/tutorial/fastapi/teams/tutorial001_py39.py b/docs_src/tutorial/fastapi/teams/tutorial001_py39.py index eaa2e04896..cc6429adcf 100644 --- a/docs_src/tutorial/fastapi/teams/tutorial001_py39.py +++ b/docs_src/tutorial/fastapi/teams/tutorial001_py39.py @@ -116,7 +116,7 @@ def update_hero( db_hero = session.get(Hero, hero_id) if not db_hero: raise HTTPException(status_code=404, detail="Hero not found") - hero_data = hero.dict(exclude_unset=True) + hero_data = hero.model_dump(exclude_unset=True) for key, value in hero_data.items(): setattr(db_hero, key, value) session.add(db_hero) @@ -173,7 +173,7 @@ def update_team( db_team = session.get(Team, team_id) if not db_team: raise HTTPException(status_code=404, detail="Team not found") - team_data = team.dict(exclude_unset=True) + team_data = team.model_dump(exclude_unset=True) for key, value in team_data.items(): setattr(db_team, key, value) session.add(db_team) diff --git a/docs_src/tutorial/fastapi/update/tutorial001.py b/docs_src/tutorial/fastapi/update/tutorial001.py index 28462bff17..5639638d5c 100644 --- a/docs_src/tutorial/fastapi/update/tutorial001.py +++ b/docs_src/tutorial/fastapi/update/tutorial001.py @@ -79,7 +79,7 @@ def update_hero(hero_id: int, hero: HeroUpdate): db_hero = session.get(Hero, hero_id) if not db_hero: raise HTTPException(status_code=404, detail="Hero not found") - hero_data = hero.dict(exclude_unset=True) + hero_data = hero.model_dump(exclude_unset=True) for key, value in hero_data.items(): setattr(db_hero, key, value) session.add(db_hero) diff --git a/docs_src/tutorial/fastapi/update/tutorial001_py310.py b/docs_src/tutorial/fastapi/update/tutorial001_py310.py index 4125c40437..4faf266f84 100644 --- a/docs_src/tutorial/fastapi/update/tutorial001_py310.py +++ b/docs_src/tutorial/fastapi/update/tutorial001_py310.py @@ -77,7 +77,7 @@ def update_hero(hero_id: int, hero: HeroUpdate): db_hero = session.get(Hero, hero_id) if not db_hero: raise HTTPException(status_code=404, detail="Hero not found") - hero_data = hero.dict(exclude_unset=True) + hero_data = hero.model_dump(exclude_unset=True) for key, value in hero_data.items(): setattr(db_hero, key, value) session.add(db_hero) diff --git a/docs_src/tutorial/fastapi/update/tutorial001_py39.py b/docs_src/tutorial/fastapi/update/tutorial001_py39.py index 2c1321dd22..b0daa87880 100644 --- a/docs_src/tutorial/fastapi/update/tutorial001_py39.py +++ b/docs_src/tutorial/fastapi/update/tutorial001_py39.py @@ -79,7 +79,7 @@ def update_hero(hero_id: int, hero: HeroUpdate): db_hero = session.get(Hero, hero_id) if not db_hero: raise HTTPException(status_code=404, detail="Hero not found") - hero_data = hero.dict(exclude_unset=True) + hero_data = hero.model_dump(exclude_unset=True) for key, value in hero_data.items(): setattr(db_hero, key, value) session.add(db_hero) From 30dbd40b4797095ef83e90d4ad3492cf3db8692b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sebasti=C3=A1n=20Ram=C3=ADrez?= Date: Mon, 4 Dec 2023 14:56:46 +0100 Subject: [PATCH 093/105] =?UTF-8?q?=E2=9C=85=20Update=20tests?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- tests/conftest.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index 7cf7b054b8..e273e23538 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -57,12 +57,12 @@ def new_print(*args): data = [] for arg in args: if isinstance(arg, BaseModel): - data.append(arg.dict()) + data.append(arg.model_dump()) elif isinstance(arg, list): new_list = [] for item in arg: if isinstance(item, BaseModel): - new_list.append(item.dict()) + new_list.append(item.model_dump()) data.append(new_list) else: data.append(arg) From 47d473d6ad0e06f4e2e0ec62c8070c4a1c268a62 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sebasti=C3=A1n=20Ram=C3=ADrez?= Date: Mon, 4 Dec 2023 14:57:02 +0100 Subject: [PATCH 094/105] =?UTF-8?q?=E2=9C=85=20Add=20tests=20for=20depreca?= =?UTF-8?q?ted=20obj.dict()?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- tests/test_deprecations.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/tests/test_deprecations.py b/tests/test_deprecations.py index 0e565c5410..ef66c91b53 100644 --- a/tests/test_deprecations.py +++ b/tests/test_deprecations.py @@ -22,3 +22,9 @@ def test_deprecated_parse_obj(): with pytest.warns(DeprecationWarning): item = Item.parse_obj({"name": "Hello"}) assert item.name == "Hello" + + +def test_deprecated_dict(): + with pytest.warns(DeprecationWarning): + data = Item(name="Hello").dict() + assert data == {"name": "Hello"} From fca110112062b1777e98d9facc3a3ffb4a2feb92 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sebasti=C3=A1n=20Ram=C3=ADrez?= Date: Mon, 4 Dec 2023 14:59:04 +0100 Subject: [PATCH 095/105] =?UTF-8?q?=F0=9F=90=9B=20Fix=20recursion=20in=20P?= =?UTF-8?q?ydantic=20v1?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- sqlmodel/main.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sqlmodel/main.py b/sqlmodel/main.py index 9d2eb3af9f..6a70a81bc4 100644 --- a/sqlmodel/main.py +++ b/sqlmodel/main.py @@ -785,7 +785,7 @@ def model_dump( warnings=warnings, ) else: - return self.dict( + return super().dict( include=include, exclude=exclude, by_alias=by_alias, From f1ab6a6f8994afe299318ecb87a22545e425726c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sebasti=C3=A1n=20Ram=C3=ADrez?= Date: Mon, 4 Dec 2023 14:59:36 +0100 Subject: [PATCH 096/105] =?UTF-8?q?=F0=9F=91=B7=20Run=20lints=20only=20on?= =?UTF-8?q?=20Pydantic=20v2?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .github/workflows/test.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 58ba2ba492..ade60f2559 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -68,7 +68,7 @@ jobs: run: pip install "pydantic>=2.0.2,<3.0.0" - name: Lint # Do not run on Python 3.7 as mypy behaves differently - if: matrix.python-version != '3.7' + if: matrix.python-version != '3.7' && matrix.pydantic-version == 'pydantic-v2' run: python -m poetry run bash scripts/lint.sh - run: mkdir coverage - name: Test From f68e93e2da27e688f115c3d48415c3b012e21b7a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sebasti=C3=A1n=20Ram=C3=ADrez?= Date: Mon, 4 Dec 2023 15:12:25 +0100 Subject: [PATCH 097/105] =?UTF-8?q?=F0=9F=8E=A8=20Fix=20type=20annotations?= =?UTF-8?q?=20for=20Python=203.7?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- sqlmodel/_compat.py | 11 +++++------ sqlmodel/main.py | 4 ++-- 2 files changed, 7 insertions(+), 8 deletions(-) diff --git a/sqlmodel/_compat.py b/sqlmodel/_compat.py index 97467dbaf7..27d6bc0de9 100644 --- a/sqlmodel/_compat.py +++ b/sqlmodel/_compat.py @@ -15,12 +15,11 @@ Type, TypeVar, Union, - get_args, - get_origin, ) from pydantic import VERSION as PYDANTIC_VERSION from pydantic.fields import FieldInfo +from typing_extensions import get_args, get_origin IS_PYDANTIC_V2 = PYDANTIC_VERSION.startswith("2.") @@ -99,11 +98,11 @@ def get_model_fields(model: InstanceOrType["SQLModel"]) -> Dict[str, "FieldInfo" return model.model_fields def set_fields_set( - new_object: InstanceOrType["SQLModel"], fields: set["FieldInfo"] + new_object: InstanceOrType["SQLModel"], fields: Set["FieldInfo"] ) -> None: object.__setattr__(new_object, "__pydantic_fields_set__", fields) - def get_annotations(class_dict: dict[str, Any]) -> dict[str, Any]: + def get_annotations(class_dict: Dict[str, Any]) -> Dict[str, Any]: return class_dict.get("__annotations__", {}) def is_table_model_class(cls: Type[Any]) -> bool: @@ -167,7 +166,7 @@ def is_field_noneable(field: "FieldInfo") -> bool: return False def get_type_from_field(field: Any) -> type: - type_: type | None = field.annotation + type_: Any = field.annotation # Resolve Optional fields if type_ is None: raise ValueError("Missing field type") @@ -240,7 +239,7 @@ def sqlmodel_table_construct( _fields_set = set(fields_values.keys()) fields_values.update(defaults) - _extra: dict[str, Any] | None = None + _extra: Union[Dict[str, Any], None] = None if cls.model_config.get("extra") == "allow": _extra = {} for k, v in values.items(): diff --git a/sqlmodel/main.py b/sqlmodel/main.py index 6a70a81bc4..10064c7116 100644 --- a/sqlmodel/main.py +++ b/sqlmodel/main.py @@ -82,7 +82,7 @@ _T = TypeVar("_T") NoArgAnyCallable = Callable[[], Any] -IncEx = Set[int] | Set[str] | Dict[int, Any] | Dict[str, Any] | None +IncEx = Union[Set[int], Set[str], Dict[int, Any], Dict[str, Any], None] def __dataclass_transform__( @@ -762,7 +762,7 @@ def model_validate( def model_dump( self, *, - mode: Literal["json", "python"] | str = "python", + mode: Union[Literal["json", "python"], str] = "python", include: IncEx = None, exclude: IncEx = None, by_alias: bool = False, From 3b2f955f41c9dc4c00ee5cff4b6b4d90965f8848 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sebasti=C3=A1n=20Ram=C3=ADrez?= Date: Mon, 4 Dec 2023 15:20:42 +0100 Subject: [PATCH 098/105] =?UTF-8?q?=E2=AC=87=EF=B8=8F=20Downgrade=20dirty-?= =?UTF-8?q?equals=20for=20compatibility=20with=20Python=203.7?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- pyproject.toml | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index ee737d5633..5287cb1882 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -50,7 +50,8 @@ fastapi = "^0.103.2" ruff = "^0.1.2" # For FastAPI tests httpx = "0.24.1" -dirty-equals = "^0.7.1.post0" +# TODO: upgrade when deprecating Python 3.7 +dirty-equals = "^0.6.0" typer-cli = "^0.0.13" mkdocs-markdownextradata-plugin = ">=0.1.7,<0.3.0" From 5f4d502030c26845b7bc7c73c4de4cce329b44e6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sebasti=C3=A1n=20Ram=C3=ADrez?= Date: Mon, 4 Dec 2023 15:21:14 +0100 Subject: [PATCH 099/105] =?UTF-8?q?=E2=9C=85=20Update=20tests=20to=20check?= =?UTF-8?q?=20for=20deprecation?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- tests/test_query.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tests/test_query.py b/tests/test_query.py index abca97253b..88517b92fe 100644 --- a/tests/test_query.py +++ b/tests/test_query.py @@ -1,5 +1,6 @@ from typing import Optional +import pytest from sqlmodel import Field, Session, SQLModel, create_engine @@ -21,6 +22,7 @@ class Hero(SQLModel, table=True): session.refresh(hero_1) with Session(engine) as session: - query_hero = session.query(Hero).first() + with pytest.warns(DeprecationWarning): + query_hero = session.query(Hero).first() assert query_hero assert query_hero.name == hero_1.name From b995766da5f6cf18fc9537f4b44097e3eadaa581 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sebasti=C3=A1n=20Ram=C3=ADrez?= Date: Mon, 4 Dec 2023 15:24:42 +0100 Subject: [PATCH 100/105] =?UTF-8?q?=F0=9F=8E=A8=20Fix=20type=20annotation?= =?UTF-8?q?=20for=20Python=203.7?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- sqlmodel/_compat.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sqlmodel/_compat.py b/sqlmodel/_compat.py index 27d6bc0de9..facb6fabe5 100644 --- a/sqlmodel/_compat.py +++ b/sqlmodel/_compat.py @@ -388,7 +388,7 @@ def get_model_fields(model: InstanceOrType["SQLModel"]) -> Dict[str, "FieldInfo" return model.__fields__ # type: ignore def set_fields_set( - new_object: InstanceOrType["SQLModel"], fields: set["FieldInfo"] + new_object: InstanceOrType["SQLModel"], fields: Set["FieldInfo"] ) -> None: object.__setattr__(new_object, "__fields_set__", fields) From b3fcaf01a2ff067455aa007ae069760dcae4efd8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sebasti=C3=A1n=20Ram=C3=ADrez?= Date: Mon, 4 Dec 2023 15:26:07 +0100 Subject: [PATCH 101/105] =?UTF-8?q?=F0=9F=8E=A8=20Fix=20type=20annotations?= =?UTF-8?q?=20in=20=5Fcompat?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- sqlmodel/_compat.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sqlmodel/_compat.py b/sqlmodel/_compat.py index facb6fabe5..97c971dba4 100644 --- a/sqlmodel/_compat.py +++ b/sqlmodel/_compat.py @@ -165,7 +165,7 @@ def is_field_noneable(field: "FieldInfo") -> bool: return False return False - def get_type_from_field(field: Any) -> type: + def get_type_from_field(field: Any) -> Any: type_: Any = field.annotation # Resolve Optional fields if type_ is None: @@ -429,7 +429,7 @@ def is_field_noneable(field: "FieldInfo") -> bool: ) return field.allow_none # type: ignore[no-any-return, attr-defined] - def get_type_from_field(field: Any) -> type: + def get_type_from_field(field: Any) -> Any: if isinstance(field.type_, type) and field.shape == SHAPE_SINGLETON: return field.type_ raise ValueError(f"The field {field.name} has no matching SQLAlchemy type") From 7ab2fdcc6dc2cc4004999f421abbe88637a006dd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sebasti=C3=A1n=20Ram=C3=ADrez?= Date: Mon, 4 Dec 2023 15:30:06 +0100 Subject: [PATCH 102/105] =?UTF-8?q?=F0=9F=8E=A8=20Fix=20type=20annotations?= =?UTF-8?q?=20for=20Python=203.7=20in=20=5Fcompat?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- sqlmodel/_compat.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sqlmodel/_compat.py b/sqlmodel/_compat.py index 97c971dba4..697c1b6a95 100644 --- a/sqlmodel/_compat.py +++ b/sqlmodel/_compat.py @@ -392,7 +392,7 @@ def set_fields_set( ) -> None: object.__setattr__(new_object, "__fields_set__", fields) - def get_annotations(class_dict: dict[str, Any]) -> dict[str, Any]: + def get_annotations(class_dict: Dict[str, Any]) -> Dict[str, Any]: return resolve_annotations( # type: ignore[no-any-return] class_dict.get("__annotations__", {}), class_dict.get("__module__", None), From 43fba3619d2cc5564975f9f46f9d4fc9255458aa Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sebasti=C3=A1n=20Ram=C3=ADrez?= Date: Mon, 4 Dec 2023 15:30:28 +0100 Subject: [PATCH 103/105] =?UTF-8?q?=F0=9F=93=8C=20Fix=20pin=20for=20Pydant?= =?UTF-8?q?ic=20v1?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 5287cb1882..10d73793d2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -34,7 +34,7 @@ classifiers = [ [tool.poetry.dependencies] python = "^3.7" SQLAlchemy = ">=2.0.0,<2.1.0" -pydantic = ">1.10.13,<3.0.0" +pydantic = ">=1.10.13,<3.0.0" [tool.poetry.group.dev.dependencies] pytest = "^7.0.1" From 115d4392d086385237b07e86bd59ed3a131b2ab4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sebasti=C3=A1n=20Ram=C3=ADrez?= Date: Mon, 4 Dec 2023 15:31:49 +0100 Subject: [PATCH 104/105] =?UTF-8?q?=F0=9F=8E=A8=20Fix=20type=20ignores=20a?= =?UTF-8?q?fter=20fixing=20type=20annotations?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- sqlmodel/_compat.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sqlmodel/_compat.py b/sqlmodel/_compat.py index 697c1b6a95..393e8416d3 100644 --- a/sqlmodel/_compat.py +++ b/sqlmodel/_compat.py @@ -185,8 +185,8 @@ def get_type_from_field(field: Any) -> Any: "Cannot have a (non-optional) union as a SQLlchemy field" ) # Optional unions are allowed - return bases[0] if bases[0] is not NoneType else bases[1] # type: ignore[no-any-return] - return origin # type: ignore[no-any-return] + return bases[0] if bases[0] is not NoneType else bases[1] + return origin def get_field_metadata(field: Any) -> Any: for meta in field.metadata: From f121b00e2c949167d08999e1ebbb3e91eb6e23f4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sebasti=C3=A1n=20Ram=C3=ADrez?= Date: Mon, 4 Dec 2023 15:33:46 +0100 Subject: [PATCH 105/105] =?UTF-8?q?=F0=9F=8E=A8=20Fix=20more=20type=20anno?= =?UTF-8?q?tations?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- sqlmodel/_compat.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sqlmodel/_compat.py b/sqlmodel/_compat.py index 393e8416d3..2a2caca3e8 100644 --- a/sqlmodel/_compat.py +++ b/sqlmodel/_compat.py @@ -224,8 +224,8 @@ def sqlmodel_table_construct( old_dict = self_instance.__dict__.copy() # End SQLModel override - fields_values: dict[str, Any] = {} - defaults: dict[ + fields_values: Dict[str, Any] = {} + defaults: Dict[ str, Any ] = {} # keeping this separate from `fields_values` helps us compute `_fields_set` for name, field in cls.model_fields.items():