From d0ae3e81db377cf4d3ccf489e5fe3ce18bc86ddb Mon Sep 17 00:00:00 2001 From: Mohamed Farahat Date: Tue, 7 Mar 2023 02:08:36 +0200 Subject: [PATCH 01/29] migration to sqlalchemy 2.0 --- docs_src/tutorial/many_to_many/tutorial003.py | 30 +++++++-------- .../back_populates/tutorial003.py | 30 +++++++-------- sqlmodel/__init__.py | 2 - sqlmodel/main.py | 1 + sqlmodel/sql/expression.py | 38 +++---------------- 5 files changed, 36 insertions(+), 65 deletions(-) diff --git a/docs_src/tutorial/many_to_many/tutorial003.py b/docs_src/tutorial/many_to_many/tutorial003.py index 1e03c4af89..cec6e56560 100644 --- a/docs_src/tutorial/many_to_many/tutorial003.py +++ b/docs_src/tutorial/many_to_many/tutorial003.py @@ -3,25 +3,12 @@ from sqlmodel import Field, Relationship, Session, SQLModel, create_engine, select -class HeroTeamLink(SQLModel, table=True): - team_id: Optional[int] = Field( - default=None, foreign_key="team.id", primary_key=True - ) - hero_id: Optional[int] = Field( - default=None, foreign_key="hero.id", primary_key=True - ) - is_training: bool = False - - team: "Team" = Relationship(back_populates="hero_links") - hero: "Hero" = Relationship(back_populates="team_links") - - class Team(SQLModel, table=True): id: Optional[int] = Field(default=None, primary_key=True) name: str = Field(index=True) headquarters: str - hero_links: List[HeroTeamLink] = Relationship(back_populates="team") + hero_links: List["HeroTeamLink"] = Relationship(back_populates="team") class Hero(SQLModel, table=True): @@ -30,7 +17,20 @@ class Hero(SQLModel, table=True): secret_name: str age: Optional[int] = Field(default=None, index=True) - team_links: List[HeroTeamLink] = Relationship(back_populates="hero") + team_links: List["HeroTeamLink"] = Relationship(back_populates="hero") + + +class HeroTeamLink(SQLModel, table=True): + team_id: Optional[int] = Field( + default=None, foreign_key="team.id", primary_key=True + ) + hero_id: Optional[int] = Field( + default=None, foreign_key="hero.id", primary_key=True + ) + is_training: bool = False + + team: "Team" = Relationship(back_populates="hero_links") + hero: "Hero" = Relationship(back_populates="team_links") sqlite_file_name = "database.db" diff --git a/docs_src/tutorial/relationship_attributes/back_populates/tutorial003.py b/docs_src/tutorial/relationship_attributes/back_populates/tutorial003.py index 98e197002e..8d91a0bc25 100644 --- a/docs_src/tutorial/relationship_attributes/back_populates/tutorial003.py +++ b/docs_src/tutorial/relationship_attributes/back_populates/tutorial003.py @@ -3,6 +3,21 @@ from sqlmodel import Field, Relationship, SQLModel, create_engine +class Hero(SQLModel, table=True): + id: Optional[int] = Field(default=None, primary_key=True) + name: str = Field(index=True) + secret_name: str + age: Optional[int] = Field(default=None, index=True) + + team_id: Optional[int] = Field(default=None, foreign_key="team.id") + team: Optional["Team"] = Relationship(back_populates="heroes") + + weapon_id: Optional[int] = Field(default=None, foreign_key="weapon.id") + weapon: Optional["Weapon"] = Relationship(back_populates="hero") + + powers: List["Power"] = Relationship(back_populates="hero") + + class Weapon(SQLModel, table=True): id: Optional[int] = Field(default=None, primary_key=True) name: str = Field(index=True) @@ -26,21 +41,6 @@ class Team(SQLModel, table=True): heroes: List["Hero"] = Relationship(back_populates="team") -class Hero(SQLModel, table=True): - id: Optional[int] = Field(default=None, primary_key=True) - name: str = Field(index=True) - secret_name: str - age: Optional[int] = Field(default=None, index=True) - - team_id: Optional[int] = Field(default=None, foreign_key="team.id") - team: Optional[Team] = Relationship(back_populates="heroes") - - weapon_id: Optional[int] = Field(default=None, foreign_key="weapon.id") - weapon: Optional[Weapon] = Relationship(back_populates="hero") - - powers: List[Power] = Relationship(back_populates="hero") - - sqlite_file_name = "database.db" sqlite_url = f"sqlite:///{sqlite_file_name}" diff --git a/sqlmodel/__init__.py b/sqlmodel/__init__.py index 495ac9c8a8..af37421754 100644 --- a/sqlmodel/__init__.py +++ b/sqlmodel/__init__.py @@ -21,7 +21,6 @@ from sqlalchemy.schema import PrimaryKeyConstraint as PrimaryKeyConstraint from sqlalchemy.schema import Sequence as Sequence from sqlalchemy.schema import Table as Table -from sqlalchemy.schema import ThreadLocalMetaData as ThreadLocalMetaData from sqlalchemy.schema import UniqueConstraint as UniqueConstraint from sqlalchemy.sql import LABEL_STYLE_DEFAULT as LABEL_STYLE_DEFAULT from sqlalchemy.sql import ( @@ -71,7 +70,6 @@ from sqlalchemy.sql import outerjoin as outerjoin from sqlalchemy.sql import outparam as outparam from sqlalchemy.sql import over as over -from sqlalchemy.sql import subquery as subquery from sqlalchemy.sql import table as table from sqlalchemy.sql import tablesample as tablesample from sqlalchemy.sql import text as text diff --git a/sqlmodel/main.py b/sqlmodel/main.py index 2b69dd2a75..e06b651cfd 100644 --- a/sqlmodel/main.py +++ b/sqlmodel/main.py @@ -642,6 +642,7 @@ class SQLModel(BaseModel, metaclass=SQLModelMetaclass, registry=default_registry __sqlmodel_relationships__: ClassVar[Dict[str, RelationshipProperty]] # type: ignore __name__: ClassVar[str] metadata: ClassVar[MetaData] + __allow_unmapped__ = True # https://docs.sqlalchemy.org/en/20/changelog/migration_20.html#migration-20-step-six class Config: orm_mode = True diff --git a/sqlmodel/sql/expression.py b/sqlmodel/sql/expression.py index 264e39cba7..29e7524ce7 100644 --- a/sqlmodel/sql/expression.py +++ b/sqlmodel/sql/expression.py @@ -117,12 +117,12 @@ class SelectOfScalar(_Select, Generic[_TSelect]): @overload -def select(entity_0: _TScalar_0, **kw: Any) -> SelectOfScalar[_TScalar_0]: # type: ignore +def select(entity_0: _TScalar_0) -> SelectOfScalar[_TScalar_0]: # type: ignore ... @overload -def select(entity_0: Type[_TModel_0], **kw: Any) -> SelectOfScalar[_TModel_0]: # type: ignore +def select(entity_0: Type[_TModel_0]) -> SelectOfScalar[_TModel_0]: # type: ignore ... @@ -133,7 +133,6 @@ def select(entity_0: Type[_TModel_0], **kw: Any) -> SelectOfScalar[_TModel_0]: def select( # type: ignore entity_0: _TScalar_0, entity_1: _TScalar_1, - **kw: Any, ) -> Select[Tuple[_TScalar_0, _TScalar_1]]: ... @@ -142,7 +141,6 @@ def select( # type: ignore def select( # type: ignore entity_0: _TScalar_0, entity_1: Type[_TModel_1], - **kw: Any, ) -> Select[Tuple[_TScalar_0, _TModel_1]]: ... @@ -151,7 +149,6 @@ def select( # type: ignore def select( # type: ignore entity_0: Type[_TModel_0], entity_1: _TScalar_1, - **kw: Any, ) -> Select[Tuple[_TModel_0, _TScalar_1]]: ... @@ -160,7 +157,6 @@ def select( # type: ignore def select( # type: ignore entity_0: Type[_TModel_0], entity_1: Type[_TModel_1], - **kw: Any, ) -> Select[Tuple[_TModel_0, _TModel_1]]: ... @@ -170,7 +166,6 @@ def select( # type: ignore entity_0: _TScalar_0, entity_1: _TScalar_1, entity_2: _TScalar_2, - **kw: Any, ) -> Select[Tuple[_TScalar_0, _TScalar_1, _TScalar_2]]: ... @@ -180,7 +175,6 @@ def select( # type: ignore entity_0: _TScalar_0, entity_1: _TScalar_1, entity_2: Type[_TModel_2], - **kw: Any, ) -> Select[Tuple[_TScalar_0, _TScalar_1, _TModel_2]]: ... @@ -190,7 +184,6 @@ def select( # type: ignore entity_0: _TScalar_0, entity_1: Type[_TModel_1], entity_2: _TScalar_2, - **kw: Any, ) -> Select[Tuple[_TScalar_0, _TModel_1, _TScalar_2]]: ... @@ -200,7 +193,6 @@ def select( # type: ignore entity_0: _TScalar_0, entity_1: Type[_TModel_1], entity_2: Type[_TModel_2], - **kw: Any, ) -> Select[Tuple[_TScalar_0, _TModel_1, _TModel_2]]: ... @@ -210,7 +202,6 @@ def select( # type: ignore entity_0: Type[_TModel_0], entity_1: _TScalar_1, entity_2: _TScalar_2, - **kw: Any, ) -> Select[Tuple[_TModel_0, _TScalar_1, _TScalar_2]]: ... @@ -220,7 +211,6 @@ def select( # type: ignore entity_0: Type[_TModel_0], entity_1: _TScalar_1, entity_2: Type[_TModel_2], - **kw: Any, ) -> Select[Tuple[_TModel_0, _TScalar_1, _TModel_2]]: ... @@ -230,7 +220,6 @@ def select( # type: ignore entity_0: Type[_TModel_0], entity_1: Type[_TModel_1], entity_2: _TScalar_2, - **kw: Any, ) -> Select[Tuple[_TModel_0, _TModel_1, _TScalar_2]]: ... @@ -240,7 +229,6 @@ def select( # type: ignore entity_0: Type[_TModel_0], entity_1: Type[_TModel_1], entity_2: Type[_TModel_2], - **kw: Any, ) -> Select[Tuple[_TModel_0, _TModel_1, _TModel_2]]: ... @@ -251,7 +239,6 @@ def select( # type: ignore entity_1: _TScalar_1, entity_2: _TScalar_2, entity_3: _TScalar_3, - **kw: Any, ) -> Select[Tuple[_TScalar_0, _TScalar_1, _TScalar_2, _TScalar_3]]: ... @@ -262,7 +249,6 @@ def select( # type: ignore entity_1: _TScalar_1, entity_2: _TScalar_2, entity_3: Type[_TModel_3], - **kw: Any, ) -> Select[Tuple[_TScalar_0, _TScalar_1, _TScalar_2, _TModel_3]]: ... @@ -273,7 +259,6 @@ def select( # type: ignore entity_1: _TScalar_1, entity_2: Type[_TModel_2], entity_3: _TScalar_3, - **kw: Any, ) -> Select[Tuple[_TScalar_0, _TScalar_1, _TModel_2, _TScalar_3]]: ... @@ -284,7 +269,6 @@ def select( # type: ignore entity_1: _TScalar_1, entity_2: Type[_TModel_2], entity_3: Type[_TModel_3], - **kw: Any, ) -> Select[Tuple[_TScalar_0, _TScalar_1, _TModel_2, _TModel_3]]: ... @@ -295,7 +279,6 @@ def select( # type: ignore entity_1: Type[_TModel_1], entity_2: _TScalar_2, entity_3: _TScalar_3, - **kw: Any, ) -> Select[Tuple[_TScalar_0, _TModel_1, _TScalar_2, _TScalar_3]]: ... @@ -306,7 +289,6 @@ def select( # type: ignore entity_1: Type[_TModel_1], entity_2: _TScalar_2, entity_3: Type[_TModel_3], - **kw: Any, ) -> Select[Tuple[_TScalar_0, _TModel_1, _TScalar_2, _TModel_3]]: ... @@ -317,7 +299,6 @@ def select( # type: ignore entity_1: Type[_TModel_1], entity_2: Type[_TModel_2], entity_3: _TScalar_3, - **kw: Any, ) -> Select[Tuple[_TScalar_0, _TModel_1, _TModel_2, _TScalar_3]]: ... @@ -328,7 +309,6 @@ def select( # type: ignore entity_1: Type[_TModel_1], entity_2: Type[_TModel_2], entity_3: Type[_TModel_3], - **kw: Any, ) -> Select[Tuple[_TScalar_0, _TModel_1, _TModel_2, _TModel_3]]: ... @@ -339,7 +319,6 @@ def select( # type: ignore entity_1: _TScalar_1, entity_2: _TScalar_2, entity_3: _TScalar_3, - **kw: Any, ) -> Select[Tuple[_TModel_0, _TScalar_1, _TScalar_2, _TScalar_3]]: ... @@ -350,7 +329,6 @@ def select( # type: ignore entity_1: _TScalar_1, entity_2: _TScalar_2, entity_3: Type[_TModel_3], - **kw: Any, ) -> Select[Tuple[_TModel_0, _TScalar_1, _TScalar_2, _TModel_3]]: ... @@ -361,7 +339,6 @@ def select( # type: ignore entity_1: _TScalar_1, entity_2: Type[_TModel_2], entity_3: _TScalar_3, - **kw: Any, ) -> Select[Tuple[_TModel_0, _TScalar_1, _TModel_2, _TScalar_3]]: ... @@ -372,7 +349,6 @@ def select( # type: ignore entity_1: _TScalar_1, entity_2: Type[_TModel_2], entity_3: Type[_TModel_3], - **kw: Any, ) -> Select[Tuple[_TModel_0, _TScalar_1, _TModel_2, _TModel_3]]: ... @@ -383,7 +359,6 @@ def select( # type: ignore entity_1: Type[_TModel_1], entity_2: _TScalar_2, entity_3: _TScalar_3, - **kw: Any, ) -> Select[Tuple[_TModel_0, _TModel_1, _TScalar_2, _TScalar_3]]: ... @@ -394,7 +369,6 @@ def select( # type: ignore entity_1: Type[_TModel_1], entity_2: _TScalar_2, entity_3: Type[_TModel_3], - **kw: Any, ) -> Select[Tuple[_TModel_0, _TModel_1, _TScalar_2, _TModel_3]]: ... @@ -405,7 +379,6 @@ def select( # type: ignore entity_1: Type[_TModel_1], entity_2: Type[_TModel_2], entity_3: _TScalar_3, - **kw: Any, ) -> Select[Tuple[_TModel_0, _TModel_1, _TModel_2, _TScalar_3]]: ... @@ -416,7 +389,6 @@ def select( # type: ignore entity_1: Type[_TModel_1], entity_2: Type[_TModel_2], entity_3: Type[_TModel_3], - **kw: Any, ) -> Select[Tuple[_TModel_0, _TModel_1, _TModel_2, _TModel_3]]: ... @@ -424,10 +396,10 @@ def select( # type: ignore # Generated overloads end -def select(*entities: Any, **kw: Any) -> Union[Select, SelectOfScalar]: # type: ignore +def select(*entities: Any) -> Union[Select, SelectOfScalar]: # type: ignore if len(entities) == 1: - return SelectOfScalar._create(*entities, **kw) # type: ignore - return Select._create(*entities, **kw) # type: ignore + return SelectOfScalar(*entities) # type: ignore + return Select(*entities) # type: ignore # TODO: add several @overload from Python types to SQLAlchemy equivalents From 8a27655052dfdc7a6317da72895df5a46a7760b1 Mon Sep 17 00:00:00 2001 From: Mohamed Farahat Date: Fri, 24 Mar 2023 14:56:26 +0200 Subject: [PATCH 02/29] fix some linting errors --- sqlmodel/engine/create.py | 2 +- sqlmodel/engine/result.py | 20 ++++++++++---------- sqlmodel/main.py | 4 ++-- sqlmodel/orm/session.py | 2 +- sqlmodel/sql/expression.py | 2 +- sqlmodel/sql/sqltypes.py | 6 +++--- 6 files changed, 18 insertions(+), 18 deletions(-) diff --git a/sqlmodel/engine/create.py b/sqlmodel/engine/create.py index b2d567b1b1..97481259e2 100644 --- a/sqlmodel/engine/create.py +++ b/sqlmodel/engine/create.py @@ -136,4 +136,4 @@ def create_engine( if not isinstance(query_cache_size, _DefaultPlaceholder): current_kwargs["query_cache_size"] = query_cache_size current_kwargs.update(kwargs) - return _create_engine(url, **current_kwargs) # type: ignore + return _create_engine(url, **current_kwargs) diff --git a/sqlmodel/engine/result.py b/sqlmodel/engine/result.py index 7a25422227..17020d9995 100644 --- a/sqlmodel/engine/result.py +++ b/sqlmodel/engine/result.py @@ -1,4 +1,4 @@ -from typing import Generic, Iterator, List, Optional, TypeVar +from typing import Generic, Iterator, List, Optional, Sequence, TypeVar from sqlalchemy.engine.result import Result as _Result from sqlalchemy.engine.result import ScalarResult as _ScalarResult @@ -6,24 +6,24 @@ _T = TypeVar("_T") -class ScalarResult(_ScalarResult, Generic[_T]): - def all(self) -> List[_T]: +class ScalarResult(_ScalarResult[_T], Generic[_T]): + def all(self) -> Sequence[_T]: return super().all() - def partitions(self, size: Optional[int] = None) -> Iterator[List[_T]]: + def partitions(self, size: Optional[int] = None) -> Iterator[Sequence[_T]]: return super().partitions(size) - def fetchall(self) -> List[_T]: + def fetchall(self) -> Sequence[_T]: return super().fetchall() - def fetchmany(self, size: Optional[int] = None) -> List[_T]: + def fetchmany(self, size: Optional[int] = None) -> Sequence[_T]: return super().fetchmany(size) def __iter__(self) -> Iterator[_T]: return super().__iter__() def __next__(self) -> _T: - return super().__next__() # type: ignore + return super().__next__() def first(self) -> Optional[_T]: return super().first() @@ -32,10 +32,10 @@ def one_or_none(self) -> Optional[_T]: return super().one_or_none() def one(self) -> _T: - return super().one() # type: ignore + return super().one() -class Result(_Result, Generic[_T]): +class Result(_Result[_T], Generic[_T]): def scalars(self, index: int = 0) -> ScalarResult[_T]: return super().scalars(index) # type: ignore @@ -76,4 +76,4 @@ def one(self) -> _T: # type: ignore return super().one() # type: ignore def scalar(self) -> Optional[_T]: - return super().scalar() + return super().scalar() # type: ignore diff --git a/sqlmodel/main.py b/sqlmodel/main.py index e06b651cfd..06e58618aa 100644 --- a/sqlmodel/main.py +++ b/sqlmodel/main.py @@ -642,7 +642,7 @@ class SQLModel(BaseModel, metaclass=SQLModelMetaclass, registry=default_registry __sqlmodel_relationships__: ClassVar[Dict[str, RelationshipProperty]] # type: ignore __name__: ClassVar[str] metadata: ClassVar[MetaData] - __allow_unmapped__ = True # https://docs.sqlalchemy.org/en/20/changelog/migration_20.html#migration-20-step-six + __allow_unmapped__ = True # https://docs.sqlalchemy.org/en/20/changelog/migration_20.html#migration-20-step-six class Config: orm_mode = True @@ -686,7 +686,7 @@ def __setattr__(self, name: str, value: Any) -> None: return else: # Set in SQLAlchemy, before Pydantic to trigger events and updates - if getattr(self.__config__, "table", False) and is_instrumented(self, name): + if getattr(self.__config__, "table", False) and is_instrumented(self, name): # type: ignore set_attribute(self, name, value) # Set in Pydantic model to trigger possible validation changes, only for # non relationship values diff --git a/sqlmodel/orm/session.py b/sqlmodel/orm/session.py index 0c70c290ae..03f33037ff 100644 --- a/sqlmodel/orm/session.py +++ b/sqlmodel/orm/session.py @@ -118,7 +118,7 @@ def query(self, *entities: Any, **kwargs: Any) -> "_Query[Any]": Or otherwise you might want to use `session.execute()` instead of `session.query()`. """ - return super().query(*entities, **kwargs) + return super().query(*entities, **kwargs) # type: ignore def get( self, diff --git a/sqlmodel/sql/expression.py b/sqlmodel/sql/expression.py index 29e7524ce7..f473ba4a5a 100644 --- a/sqlmodel/sql/expression.py +++ b/sqlmodel/sql/expression.py @@ -406,4 +406,4 @@ def select(*entities: Any) -> Union[Select, SelectOfScalar]: # type: ignore def col(column_expression: Any) -> ColumnClause: # type: ignore if not isinstance(column_expression, (ColumnClause, Column, InstrumentedAttribute)): raise RuntimeError(f"Not a SQLAlchemy column: {column_expression}") - return column_expression + return column_expression # type: ignore diff --git a/sqlmodel/sql/sqltypes.py b/sqlmodel/sql/sqltypes.py index 17d9b06126..aa30950702 100644 --- a/sqlmodel/sql/sqltypes.py +++ b/sqlmodel/sql/sqltypes.py @@ -15,7 +15,7 @@ class AutoString(types.TypeDecorator): # type: ignore def load_dialect_impl(self, dialect: Dialect) -> "types.TypeEngine[Any]": impl = cast(types.String, self.impl) if impl.length is None and dialect.name == "mysql": - return dialect.type_descriptor(types.String(self.mysql_default_length)) # type: ignore + return dialect.type_descriptor(types.String(self.mysql_default_length)) return super().load_dialect_impl(dialect) @@ -34,9 +34,9 @@ class GUID(types.TypeDecorator): # type: ignore def load_dialect_impl(self, dialect: Dialect) -> TypeEngine: # type: ignore if dialect.name == "postgresql": - return dialect.type_descriptor(UUID()) # type: ignore + return dialect.type_descriptor(UUID()) else: - return dialect.type_descriptor(CHAR(32)) # type: ignore + return dialect.type_descriptor(CHAR(32)) def process_bind_param(self, value: Any, dialect: Dialect) -> Optional[str]: if value is None: From a2b3c1465892acf5b3309110a791cb99520e5f8d Mon Sep 17 00:00:00 2001 From: Mohamed Farahat Date: Fri, 31 Mar 2023 12:59:56 +0200 Subject: [PATCH 03/29] reflecting python 3.6 deprecation in docs and tests --- .github/workflows/test.yml | 7 ++----- docs/contributing.md | 4 ++++ docs/features.md | 2 +- docs/tutorial/index.md | 3 +++ pyproject.toml | 1 + 5 files changed, 11 insertions(+), 6 deletions(-) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 201abc7c22..d2d045aeb9 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -20,11 +20,7 @@ jobs: runs-on: ubuntu-latest strategy: matrix: - python-version: - - "3.7" - - "3.8" - - "3.9" - - "3.10" + python-version: ["3.7", "3.8", "3.9", "3.10", "3.11"] fail-fast: false steps: @@ -56,6 +52,7 @@ jobs: if: steps.cache.outputs.cache-hit != 'true' run: python -m poetry install - name: Lint + if: ${{ matrix.python-version != '3.7' }} run: python -m poetry run bash scripts/lint.sh - run: mkdir coverage - name: Test diff --git a/docs/contributing.md b/docs/contributing.md index 217ed61c56..6a1ae2d6c7 100644 --- a/docs/contributing.md +++ b/docs/contributing.md @@ -6,6 +6,10 @@ First, you might want to see the basic ways to [help SQLModel and get help](help If you already cloned the repository and you know that you need to deep dive in the code, here are some guidelines to set up your environment. +### Python + +SQLModel supports Python 3.7 and above, but for development you should have at least **Python 3.7**. + ### Poetry **SQLModel** uses Poetry to build, package, and publish the project. diff --git a/docs/features.md b/docs/features.md index 102edef725..2d5e11d84f 100644 --- a/docs/features.md +++ b/docs/features.md @@ -12,7 +12,7 @@ Nevertheless, SQLModel is completely **independent** of FastAPI and can be used ## Just Modern Python -It's all based on standard modern **Python** type annotations. No new syntax to learn. Just standard modern Python. +It's all based on standard modern **Python** type annotations. No new syntax to learn. Just standard modern Python. If you need a 2 minute refresher of how to use Python types (even if you don't use SQLModel or FastAPI), check the FastAPI tutorial section: Python types intro. diff --git a/docs/tutorial/index.md b/docs/tutorial/index.md index 74107776c2..5a4333df11 100644 --- a/docs/tutorial/index.md +++ b/docs/tutorial/index.md @@ -64,6 +64,8 @@ $ cd sqlmodel-tutorial Make sure you have an officially supported version of Python. +Currently it is **Python 3.7** and above (Python 3.6 was already deprecated). + You can check which version you have with:
@@ -82,6 +84,7 @@ You might want to try with the specific versions, for example with: * `python3.10` * `python3.9` * `python3.8` +* `python3.7` The code would look like this: diff --git a/pyproject.toml b/pyproject.toml index 23fa79bf31..6777d69018 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -21,6 +21,7 @@ classifiers = [ "Programming Language :: Python :: 3.8", "Programming Language :: Python :: 3.9", "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", "Topic :: Database", "Topic :: Database :: Database Engines/Servers", "Topic :: Internet", From 4760dbde8c88d296aab3fde2a9cf3c7974cee381 Mon Sep 17 00:00:00 2001 From: farahats9 Date: Fri, 31 Mar 2023 12:58:30 +0200 Subject: [PATCH 04/29] Update sqlmodel/sql/expression.py Co-authored-by: Stefan Borer --- sqlmodel/sql/expression.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/sqlmodel/sql/expression.py b/sqlmodel/sql/expression.py index f473ba4a5a..10776e9389 100644 --- a/sqlmodel/sql/expression.py +++ b/sqlmodel/sql/expression.py @@ -22,8 +22,7 @@ _TSelect = TypeVar("_TSelect") - -class Select(_Select, Generic[_TSelect]): +class Select(_Select[_TSelect], Generic[_TSelect]): inherit_cache = True From ad76a886ccadaabd6d8617c8cf2eda119e25284d Mon Sep 17 00:00:00 2001 From: Mohamed Farahat Date: Fri, 31 Mar 2023 13:18:38 +0200 Subject: [PATCH 05/29] resolving @sbor23 comments --- .github/workflows/test.yml | 1 - docs/contributing.md | 2 +- docs/tutorial/index.md | 1 + 3 files changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index d2d045aeb9..9f2e06ed71 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -52,7 +52,6 @@ jobs: if: steps.cache.outputs.cache-hit != 'true' run: python -m poetry install - name: Lint - if: ${{ matrix.python-version != '3.7' }} run: python -m poetry run bash scripts/lint.sh - run: mkdir coverage - name: Test diff --git a/docs/contributing.md b/docs/contributing.md index 6a1ae2d6c7..5c160a7c0a 100644 --- a/docs/contributing.md +++ b/docs/contributing.md @@ -8,7 +8,7 @@ If you already cloned the repository and you know that you need to deep dive in ### Python -SQLModel supports Python 3.7 and above, but for development you should have at least **Python 3.7**. +SQLModel supports Python 3.7 and above. ### Poetry diff --git a/docs/tutorial/index.md b/docs/tutorial/index.md index 5a4333df11..54e1147d68 100644 --- a/docs/tutorial/index.md +++ b/docs/tutorial/index.md @@ -81,6 +81,7 @@ There's a chance that you have multiple Python versions installed. You might want to try with the specific versions, for example with: +* `python3.11` * `python3.10` * `python3.9` * `python3.8` From 96e44e5cadebc0d32b209ab275f61294743ab7e4 Mon Sep 17 00:00:00 2001 From: Mohamed Farahat Date: Sun, 30 Apr 2023 17:59:23 +0300 Subject: [PATCH 06/29] add the new Subquery class --- sqlmodel/__init__.py | 1 + 1 file changed, 1 insertion(+) diff --git a/sqlmodel/__init__.py b/sqlmodel/__init__.py index af37421754..4431d7cea7 100644 --- a/sqlmodel/__init__.py +++ b/sqlmodel/__init__.py @@ -70,6 +70,7 @@ from sqlalchemy.sql import outerjoin as outerjoin from sqlalchemy.sql import outparam as outparam from sqlalchemy.sql import over as over +from sqlalchemy.sql import Subquery as Subquery from sqlalchemy.sql import table as table from sqlalchemy.sql import tablesample as tablesample from sqlalchemy.sql import text as text From 9e72750eea8ada4d270013bf56db5b38a6b9a709 Mon Sep 17 00:00:00 2001 From: Mohamed Farahat Date: Sun, 30 Apr 2023 18:42:46 +0300 Subject: [PATCH 07/29] fix jinja2 template --- sqlmodel/sql/expression.py | 1 + sqlmodel/sql/expression.py.jinja2 | 25 +++++++++++++++++-------- 2 files changed, 18 insertions(+), 8 deletions(-) diff --git a/sqlmodel/sql/expression.py b/sqlmodel/sql/expression.py index 10776e9389..32ec7c428b 100644 --- a/sqlmodel/sql/expression.py +++ b/sqlmodel/sql/expression.py @@ -22,6 +22,7 @@ _TSelect = TypeVar("_TSelect") + class Select(_Select[_TSelect], Generic[_TSelect]): inherit_cache = True diff --git a/sqlmodel/sql/expression.py.jinja2 b/sqlmodel/sql/expression.py.jinja2 index 26d12a0395..9bda190415 100644 --- a/sqlmodel/sql/expression.py.jinja2 +++ b/sqlmodel/sql/expression.py.jinja2 @@ -20,7 +20,14 @@ from sqlalchemy.sql.expression import Select as _Select _TSelect = TypeVar("_TSelect") -class Select(_Select, Generic[_TSelect]): +class Select(_Select[_TSelect], Generic[_TSelect]): + inherit_cache = True + +# This is not comparable to sqlalchemy.sql.selectable.ScalarSelect, that has a different +# purpose. This is the same as a normal SQLAlchemy Select class where there's only one +# entity, so the result will be converted to a scalar by default. This way writing +# for loops on the results will feel natural. +class SelectOfScalar(_Select[_TSelect], Generic[_TSelect]): inherit_cache = True # This is not comparable to sqlalchemy.sql.selectable.ScalarSelect, that has a different @@ -35,6 +42,7 @@ if TYPE_CHECKING: # pragma: no cover # Generated TypeVars start + {% for i in range(number_of_types) %} _TScalar_{{ i }} = TypeVar( "_TScalar_{{ i }}", @@ -58,12 +66,12 @@ _TModel_{{ i }} = TypeVar("_TModel_{{ i }}", bound="SQLModel") # Generated TypeVars end @overload -def select(entity_0: _TScalar_0, **kw: Any) -> SelectOfScalar[_TScalar_0]: # type: ignore +def select(entity_0: _TScalar_0) -> SelectOfScalar[_TScalar_0]: # type: ignore ... @overload -def select(entity_0: Type[_TModel_0], **kw: Any) -> SelectOfScalar[_TModel_0]: # type: ignore +def select(entity_0: Type[_TModel_0]) -> SelectOfScalar[_TModel_0]: # type: ignore ... @@ -73,7 +81,7 @@ def select(entity_0: Type[_TModel_0], **kw: Any) -> SelectOfScalar[_TModel_0]: @overload def select( # type: ignore - {% for arg in signature[0] %}{{ arg.name }}: {{ arg.annotation }}, {% endfor %}**kw: Any, + {% for arg in signature[0] %}{{ arg.name }}: {{ arg.annotation }}, {% endfor %} ) -> Select[Tuple[{%for ret in signature[1] %}{{ ret }} {% if not loop.last %}, {% endif %}{% endfor %}]]: ... @@ -81,14 +89,15 @@ def select( # type: ignore # Generated overloads end -def select(*entities: Any, **kw: Any) -> Union[Select, SelectOfScalar]: # type: ignore + +def select(*entities: Any) -> Union[Select, SelectOfScalar]: # type: ignore if len(entities) == 1: - return SelectOfScalar._create(*entities, **kw) # type: ignore - return Select._create(*entities, **kw) # type: ignore + return SelectOfScalar(*entities) # type: ignore + return Select(*entities) # type: ignore # TODO: add several @overload from Python types to SQLAlchemy equivalents def col(column_expression: Any) -> ColumnClause: # type: ignore if not isinstance(column_expression, (ColumnClause, Column, InstrumentedAttribute)): raise RuntimeError(f"Not a SQLAlchemy column: {column_expression}") - return column_expression + return column_expression # type: ignore From b19a70961ffa9d7c8673bc19d5732c1c8351e436 Mon Sep 17 00:00:00 2001 From: Peter Landry Date: Wed, 26 Jul 2023 13:37:18 +0200 Subject: [PATCH 08/29] `Result` expects a type `Tuple[_T]` --- sqlmodel/engine/result.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sqlmodel/engine/result.py b/sqlmodel/engine/result.py index 17020d9995..a0ddf283b9 100644 --- a/sqlmodel/engine/result.py +++ b/sqlmodel/engine/result.py @@ -1,4 +1,4 @@ -from typing import Generic, Iterator, List, Optional, Sequence, TypeVar +from typing import Generic, Iterator, List, Optional, Sequence, Tuple, TypeVar from sqlalchemy.engine.result import Result as _Result from sqlalchemy.engine.result import ScalarResult as _ScalarResult @@ -35,7 +35,7 @@ def one(self) -> _T: return super().one() -class Result(_Result[_T], Generic[_T]): +class Result(_Result[Tuple[_T]], Generic[_T]): def scalars(self, index: int = 0) -> ScalarResult[_T]: return super().scalars(index) # type: ignore From ae369ed23a58ed47a053d11e7263f9a9ca4bf7b9 Mon Sep 17 00:00:00 2001 From: Peter Landry Date: Wed, 26 Jul 2023 13:37:39 +0200 Subject: [PATCH 09/29] Remove unused type ignore --- sqlmodel/engine/result.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sqlmodel/engine/result.py b/sqlmodel/engine/result.py index a0ddf283b9..ecdb6cd547 100644 --- a/sqlmodel/engine/result.py +++ b/sqlmodel/engine/result.py @@ -67,7 +67,7 @@ def one_or_none(self) -> Optional[_T]: # type: ignore return super().one_or_none() # type: ignore def scalar_one(self) -> _T: - return super().scalar_one() # type: ignore + return super().scalar_one() def scalar_one_or_none(self) -> Optional[_T]: return super().scalar_one_or_none() @@ -76,4 +76,4 @@ def one(self) -> _T: # type: ignore return super().one() # type: ignore def scalar(self) -> Optional[_T]: - return super().scalar() # type: ignore + return super().scalar() From 2ff42db5d9f06447a2d0b25371ef39fb75c04b21 Mon Sep 17 00:00:00 2001 From: Peter Landry Date: Wed, 26 Jul 2023 14:05:06 +0200 Subject: [PATCH 10/29] Result seems well enough typed in SqlAlchemy now we can simply shim over --- sqlmodel/engine/result.py | 44 ++------------------------------------- 1 file changed, 2 insertions(+), 42 deletions(-) diff --git a/sqlmodel/engine/result.py b/sqlmodel/engine/result.py index ecdb6cd547..650dd92b27 100644 --- a/sqlmodel/engine/result.py +++ b/sqlmodel/engine/result.py @@ -1,4 +1,4 @@ -from typing import Generic, Iterator, List, Optional, Sequence, Tuple, TypeVar +from typing import Generic, Iterator, Optional, Sequence, Tuple, TypeVar from sqlalchemy.engine.result import Result as _Result from sqlalchemy.engine.result import ScalarResult as _ScalarResult @@ -36,44 +36,4 @@ def one(self) -> _T: class Result(_Result[Tuple[_T]], Generic[_T]): - def scalars(self, index: int = 0) -> ScalarResult[_T]: - return super().scalars(index) # type: ignore - - def __iter__(self) -> Iterator[_T]: # type: ignore - return super().__iter__() # type: ignore - - def __next__(self) -> _T: # type: ignore - return super().__next__() # type: ignore - - def partitions(self, size: Optional[int] = None) -> Iterator[List[_T]]: # type: ignore - return super().partitions(size) # type: ignore - - def fetchall(self) -> List[_T]: # type: ignore - return super().fetchall() # type: ignore - - def fetchone(self) -> Optional[_T]: # type: ignore - return super().fetchone() # type: ignore - - def fetchmany(self, size: Optional[int] = None) -> List[_T]: # type: ignore - return super().fetchmany() # type: ignore - - def all(self) -> List[_T]: # type: ignore - return super().all() # type: ignore - - def first(self) -> Optional[_T]: # type: ignore - return super().first() # type: ignore - - def one_or_none(self) -> Optional[_T]: # type: ignore - return super().one_or_none() # type: ignore - - def scalar_one(self) -> _T: - return super().scalar_one() - - def scalar_one_or_none(self) -> Optional[_T]: - return super().scalar_one_or_none() - - def one(self) -> _T: # type: ignore - return super().one() # type: ignore - - def scalar(self) -> Optional[_T]: - return super().scalar() + ... From b6bd94f5a7a2551fabc0ddcbecb2abcde0ba28c8 Mon Sep 17 00:00:00 2001 From: Peter Landry Date: Wed, 26 Jul 2023 14:11:13 +0200 Subject: [PATCH 11/29] _Select expects a `Tuple[Any, ...]` --- sqlmodel/sql/expression.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sqlmodel/sql/expression.py b/sqlmodel/sql/expression.py index 32ec7c428b..a0ac1bd9d9 100644 --- a/sqlmodel/sql/expression.py +++ b/sqlmodel/sql/expression.py @@ -23,7 +23,7 @@ _TSelect = TypeVar("_TSelect") -class Select(_Select[_TSelect], Generic[_TSelect]): +class Select(_Select[Tuple[_TSelect]], Generic[_TSelect]): inherit_cache = True @@ -31,7 +31,7 @@ class Select(_Select[_TSelect], Generic[_TSelect]): # purpose. This is the same as a normal SQLAlchemy Select class where there's only one # entity, so the result will be converted to a scalar by default. This way writing # for loops on the results will feel natural. -class SelectOfScalar(_Select, Generic[_TSelect]): +class SelectOfScalar(_Select[Tuple[_TSelect]], Generic[_TSelect]): inherit_cache = True From 1dbce4d3c46e95c0cc8e089f2c936f2931136210 Mon Sep 17 00:00:00 2001 From: Peter Landry Date: Wed, 26 Jul 2023 14:21:03 +0200 Subject: [PATCH 12/29] Use Dict type instead of Mapping for SqlAlchemy compat --- sqlmodel/orm/session.py | 20 +++++++++++++++----- 1 file changed, 15 insertions(+), 5 deletions(-) diff --git a/sqlmodel/orm/session.py b/sqlmodel/orm/session.py index 03f33037ff..21bc0780e7 100644 --- a/sqlmodel/orm/session.py +++ b/sqlmodel/orm/session.py @@ -1,4 +1,14 @@ -from typing import Any, Mapping, Optional, Sequence, Type, TypeVar, Union, overload +from typing import ( + Any, + Dict, + Mapping, + Optional, + Sequence, + Type, + TypeVar, + Union, + overload, +) from sqlalchemy import util from sqlalchemy.orm import Query as _Query @@ -21,7 +31,7 @@ def exec( *, params: Optional[Union[Mapping[str, Any], Sequence[Mapping[str, Any]]]] = None, execution_options: Mapping[str, Any] = util.EMPTY_DICT, - bind_arguments: Optional[Mapping[str, Any]] = None, + bind_arguments: Optional[Dict[str, Any]] = None, _parent_execute_state: Optional[Any] = None, _add_event: Optional[Any] = None, **kw: Any, @@ -35,7 +45,7 @@ def exec( *, params: Optional[Union[Mapping[str, Any], Sequence[Mapping[str, Any]]]] = None, execution_options: Mapping[str, Any] = util.EMPTY_DICT, - bind_arguments: Optional[Mapping[str, Any]] = None, + bind_arguments: Optional[Dict[str, Any]] = None, _parent_execute_state: Optional[Any] = None, _add_event: Optional[Any] = None, **kw: Any, @@ -52,7 +62,7 @@ def exec( *, params: Optional[Union[Mapping[str, Any], Sequence[Mapping[str, Any]]]] = None, execution_options: Mapping[str, Any] = util.EMPTY_DICT, - bind_arguments: Optional[Mapping[str, Any]] = None, + bind_arguments: Optional[Dict[str, Any]] = None, _parent_execute_state: Optional[Any] = None, _add_event: Optional[Any] = None, **kw: Any, @@ -75,7 +85,7 @@ def execute( statement: _Executable, params: Optional[Union[Mapping[str, Any], Sequence[Mapping[str, Any]]]] = None, execution_options: Optional[Mapping[str, Any]] = util.EMPTY_DICT, - bind_arguments: Optional[Mapping[str, Any]] = None, + bind_arguments: Optional[Dict[str, Any]] = None, _parent_execute_state: Optional[Any] = None, _add_event: Optional[Any] = None, **kw: Any, From c2d310e94442fc17bc957067b02b030a374485ce Mon Sep 17 00:00:00 2001 From: Peter Landry Date: Wed, 26 Jul 2023 14:24:01 +0200 Subject: [PATCH 13/29] Execution options are not Optional in SA --- sqlmodel/orm/session.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sqlmodel/orm/session.py b/sqlmodel/orm/session.py index 21bc0780e7..69fc7c68a7 100644 --- a/sqlmodel/orm/session.py +++ b/sqlmodel/orm/session.py @@ -84,7 +84,7 @@ def execute( self, statement: _Executable, params: Optional[Union[Mapping[str, Any], Sequence[Mapping[str, Any]]]] = None, - execution_options: Optional[Mapping[str, Any]] = util.EMPTY_DICT, + execution_options: Mapping[str, Any] = util.EMPTY_DICT, bind_arguments: Optional[Dict[str, Any]] = None, _parent_execute_state: Optional[Any] = None, _add_event: Optional[Any] = None, From 05a1352aa59f3db36707747d95d996405794ca68 Mon Sep 17 00:00:00 2001 From: Peter Landry Date: Wed, 26 Jul 2023 14:26:41 +0200 Subject: [PATCH 14/29] Another instance of non-optional execution_options --- sqlmodel/orm/session.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sqlmodel/orm/session.py b/sqlmodel/orm/session.py index 69fc7c68a7..c467a3b524 100644 --- a/sqlmodel/orm/session.py +++ b/sqlmodel/orm/session.py @@ -138,7 +138,7 @@ def get( populate_existing: bool = False, with_for_update: Optional[Union[Literal[True], Mapping[str, Any]]] = None, identity_token: Optional[Any] = None, - execution_options: Optional[Mapping[Any, Any]] = util.EMPTY_DICT, + execution_options: Mapping[Any, Any] = util.EMPTY_DICT, ) -> Optional[_TSelectParam]: return super().get( entity, From bd00a2b4498f5517df3bee2e9c6ae107856c0b2c Mon Sep 17 00:00:00 2001 From: Peter Landry Date: Wed, 26 Jul 2023 14:49:28 +0200 Subject: [PATCH 15/29] Fix Tuple in jinja template as well --- sqlmodel/sql/expression.py.jinja2 | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sqlmodel/sql/expression.py.jinja2 b/sqlmodel/sql/expression.py.jinja2 index 9bda190415..55f4a1ac3e 100644 --- a/sqlmodel/sql/expression.py.jinja2 +++ b/sqlmodel/sql/expression.py.jinja2 @@ -20,14 +20,14 @@ from sqlalchemy.sql.expression import Select as _Select _TSelect = TypeVar("_TSelect") -class Select(_Select[_TSelect], Generic[_TSelect]): +class Select(_Select[Tuple[_TSelect]], Generic[_TSelect]): inherit_cache = True # This is not comparable to sqlalchemy.sql.selectable.ScalarSelect, that has a different # purpose. This is the same as a normal SQLAlchemy Select class where there's only one # entity, so the result will be converted to a scalar by default. This way writing # for loops on the results will feel natural. -class SelectOfScalar(_Select[_TSelect], Generic[_TSelect]): +class SelectOfScalar(_Select[Tuple[_TSelect]], Generic[_TSelect]): inherit_cache = True # This is not comparable to sqlalchemy.sql.selectable.ScalarSelect, that has a different From c65f018356fac713913aa28b1031d5b28477459c Mon Sep 17 00:00:00 2001 From: Peter Landry Date: Wed, 26 Jul 2023 17:39:35 +0200 Subject: [PATCH 16/29] Use ForUpdateArg from sqlalchemy --- sqlmodel/orm/session.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sqlmodel/orm/session.py b/sqlmodel/orm/session.py index c467a3b524..fc96f7def2 100644 --- a/sqlmodel/orm/session.py +++ b/sqlmodel/orm/session.py @@ -136,7 +136,7 @@ def get( ident: Any, options: Optional[Sequence[Any]] = None, populate_existing: bool = False, - with_for_update: Optional[Union[Literal[True], Mapping[str, Any]]] = None, + with_for_update: Optional[_ForUpdateArg] = None, identity_token: Optional[Any] = None, execution_options: Mapping[Any, Any] = util.EMPTY_DICT, ) -> Optional[_TSelectParam]: From 4e29e002ec803b1c7e32450efb372a6fee34fbed Mon Sep 17 00:00:00 2001 From: Peter Landry Date: Wed, 26 Jul 2023 18:05:49 +0200 Subject: [PATCH 17/29] Fix signature for `Session.get` --- sqlmodel/orm/session.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/sqlmodel/orm/session.py b/sqlmodel/orm/session.py index fc96f7def2..1d12188ce8 100644 --- a/sqlmodel/orm/session.py +++ b/sqlmodel/orm/session.py @@ -11,6 +11,7 @@ ) from sqlalchemy import util +from sqlalchemy.orm import Mapper as _Mapper from sqlalchemy.orm import Query as _Query from sqlalchemy.orm import Session as _Session from sqlalchemy.sql.base import Executable as _Executable @@ -132,13 +133,14 @@ def query(self, *entities: Any, **kwargs: Any) -> "_Query[Any]": def get( self, - entity: Type[_TSelectParam], + entity: Union[Type[_TSelectParam], "_Mapper[_TSelectParam]"], ident: Any, options: Optional[Sequence[Any]] = None, populate_existing: bool = False, with_for_update: Optional[_ForUpdateArg] = None, identity_token: Optional[Any] = None, execution_options: Mapping[Any, Any] = util.EMPTY_DICT, + bind_arguments: Optional[Dict[str, Any]] = None, ) -> Optional[_TSelectParam]: return super().get( entity, @@ -148,4 +150,5 @@ def get( with_for_update=with_for_update, identity_token=identity_token, execution_options=execution_options, + bind_arguments=bind_arguments ) From fd85d02ceb16c8368b7285a5991baa618f9c9742 Mon Sep 17 00:00:00 2001 From: Mohamed Farahat Date: Thu, 27 Jul 2023 23:03:52 +0300 Subject: [PATCH 18/29] formatting and remove unused type --- sqlmodel/orm/session.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sqlmodel/orm/session.py b/sqlmodel/orm/session.py index 1d12188ce8..c9bb043336 100644 --- a/sqlmodel/orm/session.py +++ b/sqlmodel/orm/session.py @@ -150,5 +150,5 @@ def get( with_for_update=with_for_update, identity_token=identity_token, execution_options=execution_options, - bind_arguments=bind_arguments + bind_arguments=bind_arguments, ) From d556059a2ce4abfcf489caf94033a7f1f63e5c0f Mon Sep 17 00:00:00 2001 From: Anton De Meester Date: Mon, 31 Jul 2023 13:38:26 +0000 Subject: [PATCH 19/29] Upgrade to Pydantic 2 Change imports Undefined => PydanticUndefined Update SQLModelMetaclass and SQLModel __init__ and __new__ functions Update SQL Alchemy type inference --- docs/tutorial/fastapi/multiple-models.md | 6 +- .../fastapi/app_testing/tutorial001/main.py | 2 +- .../tutorial/fastapi/delete/tutorial001.py | 2 +- .../fastapi/limit_and_offset/tutorial001.py | 2 +- .../fastapi/multiple_models/tutorial001.py | 2 +- .../fastapi/multiple_models/tutorial002.py | 2 +- .../tutorial/fastapi/read_one/tutorial001.py | 2 +- .../fastapi/relationships/tutorial001.py | 4 +- .../session_with_dependency/tutorial001.py | 2 +- .../tutorial/fastapi/teams/tutorial001.py | 4 +- .../tutorial/fastapi/update/tutorial001.py | 2 +- sqlmodel/main.py | 285 +++++------------- sqlmodel/typing.py | 7 + tests/test_instance_no_args.py | 2 +- tests/test_missing_type.py | 3 +- tests/test_validation.py | 8 +- 16 files changed, 109 insertions(+), 226 deletions(-) create mode 100644 sqlmodel/typing.py diff --git a/docs/tutorial/fastapi/multiple-models.md b/docs/tutorial/fastapi/multiple-models.md index 6845b9862d..4ea24c6752 100644 --- a/docs/tutorial/fastapi/multiple-models.md +++ b/docs/tutorial/fastapi/multiple-models.md @@ -174,13 +174,13 @@ Now we use the type annotation `HeroCreate` for the request JSON data in the `he # Code below omitted 👇 ``` -Then we create a new `Hero` (this is the actual **table** model that saves things to the database) using `Hero.from_orm()`. +Then we create a new `Hero` (this is the actual **table** model that saves things to the database) using `Hero.model_validate()`. -The method `.from_orm()` reads data from another object with attributes and creates a new instance of this class, in this case `Hero`. +The method `.model_validate()` reads data from another object with attributes and creates a new instance of this class, in this case `Hero`. The alternative is `Hero.parse_obj()` that reads data from a dictionary. -But as in this case, we have a `HeroCreate` instance in the `hero` variable. This is an object with attributes, so we use `.from_orm()` to read those attributes. +But as in this case, we have a `HeroCreate` instance in the `hero` variable. This is an object with attributes, so we use `.model_validate()` to read those attributes. With this, we create a new `Hero` instance (the one for the database) and put it in the variable `db_hero` from the data in the `hero` variable that is the `HeroCreate` instance we received from the request. diff --git a/docs_src/tutorial/fastapi/app_testing/tutorial001/main.py b/docs_src/tutorial/fastapi/app_testing/tutorial001/main.py index 3f0602e4b4..a23dfad5a8 100644 --- a/docs_src/tutorial/fastapi/app_testing/tutorial001/main.py +++ b/docs_src/tutorial/fastapi/app_testing/tutorial001/main.py @@ -54,7 +54,7 @@ def on_startup(): @app.post("/heroes/", response_model=HeroRead) def create_hero(*, session: Session = Depends(get_session), hero: HeroCreate): - db_hero = Hero.from_orm(hero) + db_hero = Hero.model_validate(hero) session.add(db_hero) session.commit() session.refresh(db_hero) diff --git a/docs_src/tutorial/fastapi/delete/tutorial001.py b/docs_src/tutorial/fastapi/delete/tutorial001.py index 3069fc5e87..77a99a9c97 100644 --- a/docs_src/tutorial/fastapi/delete/tutorial001.py +++ b/docs_src/tutorial/fastapi/delete/tutorial001.py @@ -50,7 +50,7 @@ def on_startup(): @app.post("/heroes/", response_model=HeroRead) def create_hero(hero: HeroCreate): with Session(engine) as session: - db_hero = Hero.from_orm(hero) + db_hero = Hero.model_validate(hero) session.add(db_hero) session.commit() session.refresh(db_hero) diff --git a/docs_src/tutorial/fastapi/limit_and_offset/tutorial001.py b/docs_src/tutorial/fastapi/limit_and_offset/tutorial001.py index 2b8739ca70..2352f39022 100644 --- a/docs_src/tutorial/fastapi/limit_and_offset/tutorial001.py +++ b/docs_src/tutorial/fastapi/limit_and_offset/tutorial001.py @@ -44,7 +44,7 @@ def on_startup(): @app.post("/heroes/", response_model=HeroRead) def create_hero(hero: HeroCreate): with Session(engine) as session: - db_hero = Hero.from_orm(hero) + db_hero = Hero.model_validate(hero) session.add(db_hero) session.commit() session.refresh(db_hero) diff --git a/docs_src/tutorial/fastapi/multiple_models/tutorial001.py b/docs_src/tutorial/fastapi/multiple_models/tutorial001.py index df20123333..7f59ac6a1d 100644 --- a/docs_src/tutorial/fastapi/multiple_models/tutorial001.py +++ b/docs_src/tutorial/fastapi/multiple_models/tutorial001.py @@ -46,7 +46,7 @@ def on_startup(): @app.post("/heroes/", response_model=HeroRead) def create_hero(hero: HeroCreate): with Session(engine) as session: - db_hero = Hero.from_orm(hero) + db_hero = Hero.model_validate(hero) session.add(db_hero) session.commit() session.refresh(db_hero) diff --git a/docs_src/tutorial/fastapi/multiple_models/tutorial002.py b/docs_src/tutorial/fastapi/multiple_models/tutorial002.py index 392c2c5829..fffbe72496 100644 --- a/docs_src/tutorial/fastapi/multiple_models/tutorial002.py +++ b/docs_src/tutorial/fastapi/multiple_models/tutorial002.py @@ -44,7 +44,7 @@ def on_startup(): @app.post("/heroes/", response_model=HeroRead) def create_hero(hero: HeroCreate): with Session(engine) as session: - db_hero = Hero.from_orm(hero) + db_hero = Hero.model_validate(hero) session.add(db_hero) session.commit() session.refresh(db_hero) diff --git a/docs_src/tutorial/fastapi/read_one/tutorial001.py b/docs_src/tutorial/fastapi/read_one/tutorial001.py index 4d66e471a5..f18426e74c 100644 --- a/docs_src/tutorial/fastapi/read_one/tutorial001.py +++ b/docs_src/tutorial/fastapi/read_one/tutorial001.py @@ -44,7 +44,7 @@ def on_startup(): @app.post("/heroes/", response_model=HeroRead) def create_hero(hero: HeroCreate): with Session(engine) as session: - db_hero = Hero.from_orm(hero) + db_hero = Hero.model_validate(hero) session.add(db_hero) session.commit() session.refresh(db_hero) diff --git a/docs_src/tutorial/fastapi/relationships/tutorial001.py b/docs_src/tutorial/fastapi/relationships/tutorial001.py index 8477e4a2a0..e5b196090e 100644 --- a/docs_src/tutorial/fastapi/relationships/tutorial001.py +++ b/docs_src/tutorial/fastapi/relationships/tutorial001.py @@ -92,7 +92,7 @@ def on_startup(): @app.post("/heroes/", response_model=HeroRead) def create_hero(*, session: Session = Depends(get_session), hero: HeroCreate): - db_hero = Hero.from_orm(hero) + db_hero = Hero.model_validate(hero) session.add(db_hero) session.commit() session.refresh(db_hero) @@ -146,7 +146,7 @@ def delete_hero(*, session: Session = Depends(get_session), hero_id: int): @app.post("/teams/", response_model=TeamRead) def create_team(*, session: Session = Depends(get_session), team: TeamCreate): - db_team = Team.from_orm(team) + db_team = Team.model_validate(team) session.add(db_team) session.commit() session.refresh(db_team) diff --git a/docs_src/tutorial/fastapi/session_with_dependency/tutorial001.py b/docs_src/tutorial/fastapi/session_with_dependency/tutorial001.py index 3f0602e4b4..a23dfad5a8 100644 --- a/docs_src/tutorial/fastapi/session_with_dependency/tutorial001.py +++ b/docs_src/tutorial/fastapi/session_with_dependency/tutorial001.py @@ -54,7 +54,7 @@ def on_startup(): @app.post("/heroes/", response_model=HeroRead) def create_hero(*, session: Session = Depends(get_session), hero: HeroCreate): - db_hero = Hero.from_orm(hero) + db_hero = Hero.model_validate(hero) session.add(db_hero) session.commit() session.refresh(db_hero) diff --git a/docs_src/tutorial/fastapi/teams/tutorial001.py b/docs_src/tutorial/fastapi/teams/tutorial001.py index 1da0dad8a2..cc73bb52cb 100644 --- a/docs_src/tutorial/fastapi/teams/tutorial001.py +++ b/docs_src/tutorial/fastapi/teams/tutorial001.py @@ -83,7 +83,7 @@ def on_startup(): @app.post("/heroes/", response_model=HeroRead) def create_hero(*, session: Session = Depends(get_session), hero: HeroCreate): - db_hero = Hero.from_orm(hero) + db_hero = Hero.model_validate(hero) session.add(db_hero) session.commit() session.refresh(db_hero) @@ -137,7 +137,7 @@ def delete_hero(*, session: Session = Depends(get_session), hero_id: int): @app.post("/teams/", response_model=TeamRead) def create_team(*, session: Session = Depends(get_session), team: TeamCreate): - db_team = Team.from_orm(team) + db_team = Team.model_validate(team) session.add(db_team) session.commit() session.refresh(db_team) diff --git a/docs_src/tutorial/fastapi/update/tutorial001.py b/docs_src/tutorial/fastapi/update/tutorial001.py index bb98efd581..28462bff17 100644 --- a/docs_src/tutorial/fastapi/update/tutorial001.py +++ b/docs_src/tutorial/fastapi/update/tutorial001.py @@ -50,7 +50,7 @@ def on_startup(): @app.post("/heroes/", response_model=HeroRead) def create_hero(hero: HeroCreate): with Session(engine) as session: - db_hero = Hero.from_orm(hero) + db_hero = Hero.model_validate(hero) session.add(db_hero) session.commit() session.refresh(db_hero) diff --git a/sqlmodel/main.py b/sqlmodel/main.py index 06e58618aa..0ca69be152 100644 --- a/sqlmodel/main.py +++ b/sqlmodel/main.py @@ -19,7 +19,7 @@ Set, Tuple, Type, - TypeVar, + TypeVar,ForwardRef, Union, cast, overload, @@ -53,8 +53,10 @@ from sqlalchemy.sql.sqltypes import LargeBinary, Time from .sql.sqltypes import GUID, AutoString +from .typing import SQLModelConfig _T = TypeVar("_T") +NoArgAnyCallable = Callable[[], Any] def __dataclass_transform__( @@ -68,10 +70,10 @@ def __dataclass_transform__( class FieldInfo(PydanticFieldInfo): - def __init__(self, default: Any = Undefined, **kwargs: Any) -> None: + def __init__(self, default: Any = PydanticUndefined, **kwargs: Any) -> None: primary_key = kwargs.pop("primary_key", False) - nullable = kwargs.pop("nullable", Undefined) - foreign_key = kwargs.pop("foreign_key", Undefined) + nullable = kwargs.pop("nullable", PydanticUndefined) + foreign_key = kwargs.pop("foreign_key", PydanticUndefined) unique = kwargs.pop("unique", False) index = kwargs.pop("index", Undefined) sa_type = kwargs.pop("sa_type", Undefined) @@ -84,7 +86,7 @@ def __init__(self, default: Any = Undefined, **kwargs: Any) -> None: "Passing sa_column_args is not supported when " "also passing a sa_column" ) - if sa_column_kwargs is not Undefined: + if sa_column_kwargs is not PydanticUndefined: raise RuntimeError( "Passing sa_column_kwargs is not supported when " "also passing a sa_column" @@ -157,7 +159,7 @@ def __init__( @overload def Field( - default: Any = Undefined, + default: Any = PydanticUndefined, *, default_factory: Optional[NoArgAnyCallable] = None, alias: Optional[str] = None, @@ -314,7 +316,6 @@ def Field( sa_column_kwargs=sa_column_kwargs, **current_schema_extra, ) - field_info._validate() return field_info @@ -343,7 +344,7 @@ def Relationship( *, back_populates: Optional[str] = None, link_model: Optional[Any] = None, - sa_relationship: Optional[RelationshipProperty] = None, # type: ignore + sa_relationship: Optional[RelationshipProperty] = None, sa_relationship_args: Optional[Sequence[Any]] = None, sa_relationship_kwargs: Optional[Mapping[str, Any]] = None, ) -> Any: @@ -360,18 +361,18 @@ def Relationship( @__dataclass_transform__(kw_only_default=True, field_descriptors=(Field, FieldInfo)) class SQLModelMetaclass(ModelMetaclass, DeclarativeMeta): __sqlmodel_relationships__: Dict[str, RelationshipInfo] - __config__: Type[BaseConfig] - __fields__: Dict[str, ModelField] + model_config: Type[SQLModelConfig] + model_fields: Dict[str, FieldInfo] # Replicate SQLAlchemy def __setattr__(cls, name: str, value: Any) -> None: - if getattr(cls.__config__, "table", False): + if cls.model_config.get("table", False): DeclarativeMeta.__setattr__(cls, name, value) else: super().__setattr__(name, value) def __delattr__(cls, name: str) -> None: - if getattr(cls.__config__, "table", False): + if cls.model_config.get("table", False): DeclarativeMeta.__delattr__(cls, name) else: super().__delattr__(name) @@ -384,11 +385,10 @@ def __new__( class_dict: Dict[str, Any], **kwargs: Any, ) -> Any: + relationships: Dict[str, RelationshipInfo] = {} dict_for_pydantic = {} - original_annotations = resolve_annotations( - class_dict.get("__annotations__", {}), class_dict.get("__module__", None) - ) + original_annotations = class_dict.get("__annotations__", {}) pydantic_annotations = {} relationship_annotations = {} for k, v in class_dict.items(): @@ -412,7 +412,7 @@ def __new__( # superclass causing an error allowed_config_kwargs: Set[str] = { key - for key in dir(BaseConfig) + for key in dir(SQLModelConfig) if not ( key.startswith("__") and key.endswith("__") ) # skip dunder methods and attributes @@ -422,38 +422,46 @@ def __new__( key: pydantic_kwargs.pop(key) for key in pydantic_kwargs.keys() & allowed_config_kwargs } + config_table = getattr(class_dict.get('Config', object()), 'table', False) + # If we have a table, we need to have defaults for all fields + # Pydantic v2 sets a __pydantic_core_schema__ which is very hard to change. Changing the fields does not do anything + if config_table is True: + for key in original_annotations.keys(): + if dict_used.get(key, PydanticUndefined) is PydanticUndefined: + dict_used[key] = None + new_cls = super().__new__(cls, name, bases, dict_used, **config_kwargs) new_cls.__annotations__ = { **relationship_annotations, **pydantic_annotations, **new_cls.__annotations__, } - + def get_config(name: str) -> Any: - config_class_value = getattr(new_cls.__config__, name, Undefined) - if config_class_value is not Undefined: + config_class_value = new_cls.model_config.get(name, PydanticUndefined) + if config_class_value is not PydanticUndefined: return config_class_value - kwarg_value = kwargs.get(name, Undefined) - if kwarg_value is not Undefined: + kwarg_value = kwargs.get(name, PydanticUndefined) + if kwarg_value is not PydanticUndefined: return kwarg_value - return Undefined + return PydanticUndefined config_table = get_config("table") if config_table is True: # If it was passed by kwargs, ensure it's also set in config - new_cls.__config__.table = config_table - for k, v in new_cls.__fields__.items(): + new_cls.model_config['table'] = config_table + for k, v in new_cls.model_fields.items(): col = get_column_from_field(v) setattr(new_cls, k, col) # Set a config flag to tell FastAPI that this should be read with a field # in orm_mode instead of preemptively converting it to a dict. - # This could be done by reading new_cls.__config__.table in FastAPI, but + # This could be done by reading new_cls.model_config['table'] in FastAPI, but # that's very specific about SQLModel, so let's have another config that # other future tools based on Pydantic can use. - new_cls.__config__.read_with_orm_mode = True + new_cls.model_config['read_from_attributes'] = True config_registry = get_config("registry") - if config_registry is not Undefined: + if config_registry is not PydanticUndefined: config_registry = cast(registry, config_registry) # If it was passed by kwargs, ensure it's also set in config new_cls.__config__.registry = config_table @@ -484,16 +492,15 @@ def __init__( setattr(cls, rel_name, rel_info.sa_relationship) # Fix #315 continue ann = cls.__annotations__[rel_name] - temp_field = ModelField.infer( - name=rel_name, - value=rel_info, - annotation=ann, - class_validators=None, - config=BaseConfig, - ) - relationship_to = temp_field.type_ - if isinstance(temp_field.type_, ForwardRef): - relationship_to = temp_field.type_.__forward_arg__ + relationship_to = get_origin(ann) + # If Union (Optional), get the real field + if relationship_to is Union: + relationship_to = get_args(ann)[0] + # If a list, then also get the real field + elif relationship_to is list: + relationship_to = get_args(ann)[0] + if isinstance(relationship_to, ForwardRef): + relationship_to = relationship_to.__forward_arg__ rel_kwargs: Dict[str, Any] = {} if rel_info.back_populates: rel_kwargs["back_populates"] = rel_info.back_populates @@ -511,7 +518,7 @@ def __init__( rel_args.extend(rel_info.sa_relationship_args) if rel_info.sa_relationship_kwargs: rel_kwargs.update(rel_info.sa_relationship_kwargs) - rel_value: RelationshipProperty = relationship( # type: ignore + rel_value: RelationshipProperty = relationship( relationship_to, *rel_args, **rel_kwargs ) setattr(cls, rel_name, rel_value) # Fix #315 @@ -571,8 +578,8 @@ def get_sqlalchemy_type(field: ModelField) -> Any: raise ValueError(f"The field {field.name} has no matching SQLAlchemy type") -def get_column_from_field(field: ModelField) -> Column: # type: ignore - sa_column = getattr(field.field_info, "sa_column", Undefined) +def get_column_from_field(field: FieldInfo) -> Column: # type: ignore + sa_column = getattr(field, "sa_column", PydanticUndefined) if isinstance(sa_column, Column): return sa_column sa_type = get_sqlalchemy_type(field) @@ -605,18 +612,18 @@ def get_column_from_field(field: ModelField) -> Column: # type: ignore "index": index, "unique": unique, } - sa_default = Undefined - if field.field_info.default_factory: - sa_default = field.field_info.default_factory - elif field.field_info.default is not Undefined: - sa_default = field.field_info.default - if sa_default is not Undefined: + sa_default = PydanticUndefined + if field.default_factory: + sa_default = field.default_factory + elif field.default is not PydanticUndefined: + sa_default = field.default + if sa_default is not PydanticUndefined: kwargs["default"] = sa_default - sa_column_args = getattr(field.field_info, "sa_column_args", Undefined) - if sa_column_args is not Undefined: + sa_column_args = getattr(field, "sa_column_args", PydanticUndefined) + if sa_column_args is not PydanticUndefined: args.extend(list(cast(Sequence[Any], sa_column_args))) - sa_column_kwargs = getattr(field.field_info, "sa_column_kwargs", Undefined) - if sa_column_kwargs is not Undefined: + sa_column_kwargs = getattr(field, "sa_column_kwargs", PydanticUndefined) + if sa_column_kwargs is not PydanticUndefined: kwargs.update(cast(Dict[Any, Any], sa_column_kwargs)) return Column(sa_type, *args, **kwargs) # type: ignore @@ -625,13 +632,6 @@ def get_column_from_field(field: ModelField) -> Column: # type: ignore default_registry = registry() - -def _value_items_is_true(v: Any) -> bool: - # Re-implement Pydantic's ValueItems.is_true() as it hasn't been released as of - # the current latest, Pydantic 1.8.2 - return v is True or v is ... - - _TSQLModel = TypeVar("_TSQLModel", bound="SQLModel") @@ -639,43 +639,17 @@ class SQLModel(BaseModel, metaclass=SQLModelMetaclass, registry=default_registry # SQLAlchemy needs to set weakref(s), Pydantic will set the other slots values __slots__ = ("__weakref__",) __tablename__: ClassVar[Union[str, Callable[..., str]]] - __sqlmodel_relationships__: ClassVar[Dict[str, RelationshipProperty]] # type: ignore + __sqlmodel_relationships__: ClassVar[Dict[str, RelationshipProperty]] __name__: ClassVar[str] metadata: ClassVar[MetaData] __allow_unmapped__ = True # https://docs.sqlalchemy.org/en/20/changelog/migration_20.html#migration-20-step-six - - class Config: - orm_mode = True - - def __new__(cls, *args: Any, **kwargs: Any) -> Any: - new_object = super().__new__(cls) - # SQLAlchemy doesn't call __init__ on the base class - # Ref: https://docs.sqlalchemy.org/en/14/orm/constructors.html - # Set __fields_set__ here, that would have been set when calling __init__ - # in the Pydantic model so that when SQLAlchemy sets attributes that are - # added (e.g. when querying from DB) to the __fields_set__, this already exists - object.__setattr__(new_object, "__fields_set__", set()) - return new_object + model_config = SQLModelConfig(from_attributes=True) def __init__(__pydantic_self__, **data: Any) -> None: - # Uses something other than `self` the first arg to allow "self" as a - # settable attribute - values, fields_set, validation_error = validate_model( - __pydantic_self__.__class__, data - ) - # Only raise errors if not a SQLModel model - if ( - not getattr(__pydantic_self__.__config__, "table", False) - and validation_error - ): - raise validation_error - # Do not set values as in Pydantic, pass them through setattr, so SQLAlchemy - # can handle them - # object.__setattr__(__pydantic_self__, '__dict__', values) - for key, value in values.items(): - setattr(__pydantic_self__, key, value) - object.__setattr__(__pydantic_self__, "__fields_set__", fields_set) - non_pydantic_keys = data.keys() - values.keys() + old_dict = __pydantic_self__.__dict__.copy() + super().__init__(**data) + __pydantic_self__.__dict__ = old_dict | __pydantic_self__.__dict__ + non_pydantic_keys = data.keys() - __pydantic_self__.model_fields for key in non_pydantic_keys: if key in __pydantic_self__.__sqlmodel_relationships__: setattr(__pydantic_self__, key, data[key]) @@ -686,58 +660,12 @@ def __setattr__(self, name: str, value: Any) -> None: return else: # Set in SQLAlchemy, before Pydantic to trigger events and updates - if getattr(self.__config__, "table", False) and is_instrumented(self, name): # type: ignore + if self.model_config.get("table", False) and is_instrumented(self, name): set_attribute(self, name, value) # Set in Pydantic model to trigger possible validation changes, only for # non relationship values if name not in self.__sqlmodel_relationships__: - super().__setattr__(name, value) - - @classmethod - def from_orm( - cls: Type[_TSQLModel], obj: Any, update: Optional[Dict[str, Any]] = None - ) -> _TSQLModel: - # Duplicated from Pydantic - if not cls.__config__.orm_mode: - raise ConfigError( - "You must have the config attribute orm_mode=True to use from_orm" - ) - obj = {ROOT_KEY: obj} if cls.__custom_root_type__ else cls._decompose_class(obj) - # SQLModel, support update dict - if update is not None: - obj = {**obj, **update} - # End SQLModel support dict - if not getattr(cls.__config__, "table", False): - # If not table, normal Pydantic code - m: _TSQLModel = cls.__new__(cls) - else: - # If table, create the new instance normally to make SQLAlchemy create - # the _sa_instance_state attribute - m = cls() - values, fields_set, validation_error = validate_model(cls, obj) - if validation_error: - raise validation_error - # Updated to trigger SQLAlchemy internal handling - if not getattr(cls.__config__, "table", False): - object.__setattr__(m, "__dict__", values) - else: - for key, value in values.items(): - setattr(m, key, value) - # Continue with standard Pydantic logic - object.__setattr__(m, "__fields_set__", fields_set) - m._init_private_attributes() - return m - - @classmethod - def parse_obj( - cls: Type[_TSQLModel], obj: Any, update: Optional[Dict[str, Any]] = None - ) -> _TSQLModel: - obj = cls._enforce_dict_if_root(obj) - # SQLModel, support update dict - if update is not None: - obj = {**obj, **update} - # End SQLModel support dict - return super().parse_obj(obj) + super(SQLModel, self).__setattr__(name, value) def __repr_args__(self) -> Sequence[Tuple[Optional[str], Any]]: # Don't show SQLAlchemy private attributes @@ -747,78 +675,25 @@ def __repr_args__(self) -> Sequence[Tuple[Optional[str], Any]]: if not (isinstance(k, str) and k.startswith("_sa_")) ] - # From Pydantic, override to enforce validation with dict - @classmethod - def validate(cls: Type[_TSQLModel], value: Any) -> _TSQLModel: - if isinstance(value, cls): - return value.copy() if cls.__config__.copy_on_model_validation else value - - value = cls._enforce_dict_if_root(value) - if isinstance(value, dict): - values, fields_set, validation_error = validate_model(cls, value) - if validation_error: - raise validation_error - model = cls(**value) - # Reset fields set, this would have been done in Pydantic in __init__ - object.__setattr__(model, "__fields_set__", fields_set) - return model - elif cls.__config__.orm_mode: - return cls.from_orm(value) - elif cls.__custom_root_type__: - return cls.parse_obj(value) - else: - try: - value_as_dict = dict(value) - except (TypeError, ValueError) as e: - raise DictError() from e - return cls(**value_as_dict) - - # From Pydantic, override to only show keys from fields, omit SQLAlchemy attributes - def _calculate_keys( - self, - include: Optional[Mapping[Union[int, str], Any]], - exclude: Optional[Mapping[Union[int, str], Any]], - exclude_unset: bool, - update: Optional[Dict[str, Any]] = None, - ) -> Optional[AbstractSet[str]]: - if include is None and exclude is None and not exclude_unset: - # Original in Pydantic: - # return None - # Updated to not return SQLAlchemy attributes - # Do not include relationships as that would easily lead to infinite - # recursion, or traversing the whole database - return self.__fields__.keys() # | self.__sqlmodel_relationships__.keys() - - keys: AbstractSet[str] - if exclude_unset: - keys = self.__fields_set__.copy() - else: - # Original in Pydantic: - # keys = self.__dict__.keys() - # Updated to not return SQLAlchemy attributes - # Do not include relationships as that would easily lead to infinite - # recursion, or traversing the whole database - keys = self.__fields__.keys() # | self.__sqlmodel_relationships__.keys() - if include is not None: - keys &= include.keys() - - if update: - keys -= update.keys() - - if exclude: - keys -= {k for k, v in exclude.items() if _value_items_is_true(v)} - - return keys @declared_attr # type: ignore def __tablename__(cls) -> str: return cls.__name__.lower() -def _is_field_noneable(field: ModelField) -> bool: - if not field.required: - # Taken from [Pydantic](https://github.com/samuelcolvin/pydantic/blob/v1.8.2/pydantic/fields.py#L946-L947) - return field.allow_none and ( - field.shape != SHAPE_SINGLETON or not field.sub_fields - ) +def _is_field_noneable(field: FieldInfo) -> bool: + if not field.is_required(): + if field.annotation is None or field.annotation is type(None): + return True + if get_origin(field.annotation) is Union: + for base in get_args(field.annotation): + if base is type(None): + return True + return False return False + +def _get_field_metadata(field: FieldInfo) -> object: + for meta in field.metadata: + if isinstance(meta, PydanticGeneralMetadata): + return meta + return object() \ No newline at end of file diff --git a/sqlmodel/typing.py b/sqlmodel/typing.py new file mode 100644 index 0000000000..570da2fa55 --- /dev/null +++ b/sqlmodel/typing.py @@ -0,0 +1,7 @@ +from pydantic import ConfigDict +from typing import Optional, Any + +class SQLModelConfig(ConfigDict): + table: Optional[bool] + read_from_attributes: Optional[bool] + registry: Optional[Any] \ No newline at end of file diff --git a/tests/test_instance_no_args.py b/tests/test_instance_no_args.py index 14d560628b..5dc520c54f 100644 --- a/tests/test_instance_no_args.py +++ b/tests/test_instance_no_args.py @@ -8,7 +8,7 @@ def test_allow_instantiation_without_arguments(clear_sqlmodel): class Item(SQLModel): id: Optional[int] = Field(default=None, primary_key=True) - name: str + name: str description: Optional[str] = None class Config: diff --git a/tests/test_missing_type.py b/tests/test_missing_type.py index 2185fa43e9..dd12b2547a 100644 --- a/tests/test_missing_type.py +++ b/tests/test_missing_type.py @@ -2,10 +2,11 @@ import pytest from sqlmodel import Field, SQLModel +from pydantic import BaseModel def test_missing_sql_type(): - class CustomType: + class CustomType(BaseModel): @classmethod def __get_validators__(cls): yield cls.validate diff --git a/tests/test_validation.py b/tests/test_validation.py index ad60fcb945..4183986a06 100644 --- a/tests/test_validation.py +++ b/tests/test_validation.py @@ -1,7 +1,7 @@ from typing import Optional import pytest -from pydantic import validator +from pydantic import field_validator from pydantic.error_wrappers import ValidationError from sqlmodel import SQLModel @@ -22,12 +22,12 @@ class Hero(SQLModel): secret_name: Optional[str] = None age: Optional[int] = None - @validator("name", "secret_name", "age") + @field_validator("name", "secret_name", "age") def reject_none(cls, v): assert v is not None return v - Hero.validate({"age": 25}) + Hero.model_validate({"age": 25}) with pytest.raises(ValidationError): - Hero.validate({"name": None, "age": 25}) + Hero.model_validate({"name": None, "age": 25}) From e92a52eb416e7815a54c23f4792ccfb53bf4fccc Mon Sep 17 00:00:00 2001 From: Anton De Meester Date: Mon, 31 Jul 2023 13:40:23 +0000 Subject: [PATCH 20/29] Formatting --- sqlmodel/main.py | 20 ++++++++++---------- sqlmodel/typing.py | 6 ++++-- tests/test_instance_no_args.py | 2 +- tests/test_missing_type.py | 2 +- 4 files changed, 16 insertions(+), 14 deletions(-) diff --git a/sqlmodel/main.py b/sqlmodel/main.py index 0ca69be152..789a3ac4fc 100644 --- a/sqlmodel/main.py +++ b/sqlmodel/main.py @@ -19,7 +19,7 @@ Set, Tuple, Type, - TypeVar,ForwardRef, + TypeVar, Union, cast, overload, @@ -385,7 +385,7 @@ def __new__( class_dict: Dict[str, Any], **kwargs: Any, ) -> Any: - + relationships: Dict[str, RelationshipInfo] = {} dict_for_pydantic = {} original_annotations = class_dict.get("__annotations__", {}) @@ -422,21 +422,21 @@ def __new__( key: pydantic_kwargs.pop(key) for key in pydantic_kwargs.keys() & allowed_config_kwargs } - config_table = getattr(class_dict.get('Config', object()), 'table', False) + config_table = getattr(class_dict.get("Config", object()), "table", False) # If we have a table, we need to have defaults for all fields # Pydantic v2 sets a __pydantic_core_schema__ which is very hard to change. Changing the fields does not do anything if config_table is True: for key in original_annotations.keys(): if dict_used.get(key, PydanticUndefined) is PydanticUndefined: dict_used[key] = None - + new_cls = super().__new__(cls, name, bases, dict_used, **config_kwargs) new_cls.__annotations__ = { **relationship_annotations, **pydantic_annotations, **new_cls.__annotations__, } - + def get_config(name: str) -> Any: config_class_value = new_cls.model_config.get(name, PydanticUndefined) if config_class_value is not PydanticUndefined: @@ -449,7 +449,7 @@ def get_config(name: str) -> Any: config_table = get_config("table") if config_table is True: # If it was passed by kwargs, ensure it's also set in config - new_cls.model_config['table'] = config_table + new_cls.model_config["table"] = config_table for k, v in new_cls.model_fields.items(): col = get_column_from_field(v) setattr(new_cls, k, col) @@ -458,7 +458,7 @@ def get_config(name: str) -> Any: # This could be done by reading new_cls.model_config['table'] in FastAPI, but # that's very specific about SQLModel, so let's have another config that # other future tools based on Pydantic can use. - new_cls.model_config['read_from_attributes'] = True + new_cls.model_config["read_from_attributes"] = True config_registry = get_config("registry") if config_registry is not PydanticUndefined: @@ -518,7 +518,7 @@ def __init__( rel_args.extend(rel_info.sa_relationship_args) if rel_info.sa_relationship_kwargs: rel_kwargs.update(rel_info.sa_relationship_kwargs) - rel_value: RelationshipProperty = relationship( + rel_value: RelationshipProperty = relationship( relationship_to, *rel_args, **rel_kwargs ) setattr(cls, rel_name, rel_value) # Fix #315 @@ -675,7 +675,6 @@ def __repr_args__(self) -> Sequence[Tuple[Optional[str], Any]]: if not (isinstance(k, str) and k.startswith("_sa_")) ] - @declared_attr # type: ignore def __tablename__(cls) -> str: return cls.__name__.lower() @@ -692,8 +691,9 @@ def _is_field_noneable(field: FieldInfo) -> bool: return False return False + def _get_field_metadata(field: FieldInfo) -> object: for meta in field.metadata: if isinstance(meta, PydanticGeneralMetadata): return meta - return object() \ No newline at end of file + return object() diff --git a/sqlmodel/typing.py b/sqlmodel/typing.py index 570da2fa55..f2c87503c0 100644 --- a/sqlmodel/typing.py +++ b/sqlmodel/typing.py @@ -1,7 +1,9 @@ +from typing import Any, Optional + from pydantic import ConfigDict -from typing import Optional, Any + class SQLModelConfig(ConfigDict): table: Optional[bool] read_from_attributes: Optional[bool] - registry: Optional[Any] \ No newline at end of file + registry: Optional[Any] diff --git a/tests/test_instance_no_args.py b/tests/test_instance_no_args.py index 5dc520c54f..14d560628b 100644 --- a/tests/test_instance_no_args.py +++ b/tests/test_instance_no_args.py @@ -8,7 +8,7 @@ def test_allow_instantiation_without_arguments(clear_sqlmodel): class Item(SQLModel): id: Optional[int] = Field(default=None, primary_key=True) - name: str + name: str description: Optional[str] = None class Config: diff --git a/tests/test_missing_type.py b/tests/test_missing_type.py index dd12b2547a..dc31f053ec 100644 --- a/tests/test_missing_type.py +++ b/tests/test_missing_type.py @@ -1,8 +1,8 @@ from typing import Optional import pytest -from sqlmodel import Field, SQLModel from pydantic import BaseModel +from sqlmodel import Field, SQLModel def test_missing_sql_type(): From 61d5d8dfbe2a5af13b6f41f6ca5d38d471955f47 Mon Sep 17 00:00:00 2001 From: Anton De Meester Date: Mon, 31 Jul 2023 13:57:57 +0000 Subject: [PATCH 21/29] Linting --- sqlmodel/main.py | 23 ++++++++++++----------- sqlmodel/typing.py | 2 +- 2 files changed, 13 insertions(+), 12 deletions(-) diff --git a/sqlmodel/main.py b/sqlmodel/main.py index 789a3ac4fc..3af43b5cf1 100644 --- a/sqlmodel/main.py +++ b/sqlmodel/main.py @@ -14,6 +14,7 @@ ForwardRef, List, Mapping, + NoneType, Optional, Sequence, Set, @@ -344,7 +345,7 @@ def Relationship( *, back_populates: Optional[str] = None, link_model: Optional[Any] = None, - sa_relationship: Optional[RelationshipProperty] = None, + sa_relationship: Optional[RelationshipProperty[Any]] = None, sa_relationship_args: Optional[Sequence[Any]] = None, sa_relationship_kwargs: Optional[Mapping[str, Any]] = None, ) -> Any: @@ -361,7 +362,7 @@ def Relationship( @__dataclass_transform__(kw_only_default=True, field_descriptors=(Field, FieldInfo)) class SQLModelMetaclass(ModelMetaclass, DeclarativeMeta): __sqlmodel_relationships__: Dict[str, RelationshipInfo] - model_config: Type[SQLModelConfig] + model_config: SQLModelConfig model_fields: Dict[str, FieldInfo] # Replicate SQLAlchemy @@ -430,7 +431,9 @@ def __new__( if dict_used.get(key, PydanticUndefined) is PydanticUndefined: dict_used[key] = None - new_cls = super().__new__(cls, name, bases, dict_used, **config_kwargs) + new_cls: Type["SQLModelMetaclass"] = super().__new__( + cls, name, bases, dict_used, **config_kwargs + ) new_cls.__annotations__ = { **relationship_annotations, **pydantic_annotations, @@ -518,7 +521,7 @@ def __init__( rel_args.extend(rel_info.sa_relationship_args) if rel_info.sa_relationship_kwargs: rel_kwargs.update(rel_info.sa_relationship_kwargs) - rel_value: RelationshipProperty = relationship( + rel_value: RelationshipProperty[Any] = relationship( relationship_to, *rel_args, **rel_kwargs ) setattr(cls, rel_name, rel_value) # Fix #315 @@ -612,7 +615,7 @@ def get_column_from_field(field: FieldInfo) -> Column: # type: ignore "index": index, "unique": unique, } - sa_default = PydanticUndefined + sa_default: PydanticUndefinedType | Callable[[], Any] = PydanticUndefined if field.default_factory: sa_default = field.default_factory elif field.default is not PydanticUndefined: @@ -632,14 +635,12 @@ def get_column_from_field(field: FieldInfo) -> Column: # type: ignore default_registry = registry() -_TSQLModel = TypeVar("_TSQLModel", bound="SQLModel") - class SQLModel(BaseModel, metaclass=SQLModelMetaclass, registry=default_registry): # SQLAlchemy needs to set weakref(s), Pydantic will set the other slots values __slots__ = ("__weakref__",) __tablename__: ClassVar[Union[str, Callable[..., str]]] - __sqlmodel_relationships__: ClassVar[Dict[str, RelationshipProperty]] + __sqlmodel_relationships__: ClassVar[Dict[str, RelationshipProperty[Any]]] __name__: ClassVar[str] metadata: ClassVar[MetaData] __allow_unmapped__ = True # https://docs.sqlalchemy.org/en/20/changelog/migration_20.html#migration-20-step-six @@ -660,7 +661,7 @@ def __setattr__(self, name: str, value: Any) -> None: return else: # Set in SQLAlchemy, before Pydantic to trigger events and updates - if self.model_config.get("table", False) and is_instrumented(self, name): + if self.model_config.get("table", False) and is_instrumented(self, name): # type: ignore set_attribute(self, name, value) # Set in Pydantic model to trigger possible validation changes, only for # non relationship values @@ -682,11 +683,11 @@ def __tablename__(cls) -> str: def _is_field_noneable(field: FieldInfo) -> bool: if not field.is_required(): - if field.annotation is None or field.annotation is type(None): + if field.annotation is None or field.annotation is NoneType: return True if get_origin(field.annotation) is Union: for base in get_args(field.annotation): - if base is type(None): + if base is NoneType: return True return False return False diff --git a/sqlmodel/typing.py b/sqlmodel/typing.py index f2c87503c0..8151f99692 100644 --- a/sqlmodel/typing.py +++ b/sqlmodel/typing.py @@ -3,7 +3,7 @@ from pydantic import ConfigDict -class SQLModelConfig(ConfigDict): +class SQLModelConfig(ConfigDict, total=False): table: Optional[bool] read_from_attributes: Optional[bool] registry: Optional[Any] From cb494b763cf3088200749373f3819846e91ab67a Mon Sep 17 00:00:00 2001 From: Anton De Meester Date: Tue, 1 Aug 2023 08:13:30 +0000 Subject: [PATCH 22/29] Make all tests but fastapi work --- sqlmodel/main.py | 33 +++++++++++++++++++++++++++------ 1 file changed, 27 insertions(+), 6 deletions(-) diff --git a/sqlmodel/main.py b/sqlmodel/main.py index 3af43b5cf1..be4623bcac 100644 --- a/sqlmodel/main.py +++ b/sqlmodel/main.py @@ -14,7 +14,6 @@ ForwardRef, List, Mapping, - NoneType, Optional, Sequence, Set, @@ -58,6 +57,7 @@ _T = TypeVar("_T") NoArgAnyCallable = Callable[[], Any] +NoneType = type(None) def __dataclass_transform__( @@ -423,13 +423,17 @@ def __new__( key: pydantic_kwargs.pop(key) for key in pydantic_kwargs.keys() & allowed_config_kwargs } - config_table = getattr(class_dict.get("Config", object()), "table", False) + config_table = getattr(class_dict.get("Config", object()), "table", False) or kwargs.get("table", False) # If we have a table, we need to have defaults for all fields # Pydantic v2 sets a __pydantic_core_schema__ which is very hard to change. Changing the fields does not do anything if config_table is True: - for key in original_annotations.keys(): - if dict_used.get(key, PydanticUndefined) is PydanticUndefined: + for key in pydantic_annotations.keys(): + value = dict_used.get(key, PydanticUndefined) + if value is PydanticUndefined: dict_used[key] = None + elif isinstance(value, FieldInfo): + if value.default is PydanticUndefined and value.default_factory is None: + value.default = None new_cls: Type["SQLModelMetaclass"] = super().__new__( cls, name, bases, dict_used, **config_kwargs @@ -496,8 +500,11 @@ def __init__( continue ann = cls.__annotations__[rel_name] relationship_to = get_origin(ann) - # If Union (Optional), get the real field - if relationship_to is Union: + # Direct relationships (e.g. 'Team' or Team) have None as an origin + if relationship_to is None: + relationship_to = ann + # If Union (e.g. Optional), get the real field + elif relationship_to is Union: relationship_to = get_args(ann)[0] # If a list, then also get the real field elif relationship_to is list: @@ -646,6 +653,16 @@ class SQLModel(BaseModel, metaclass=SQLModelMetaclass, registry=default_registry __allow_unmapped__ = True # https://docs.sqlalchemy.org/en/20/changelog/migration_20.html#migration-20-step-six model_config = SQLModelConfig(from_attributes=True) + def __new__(cls, *args: Any, **kwargs: Any) -> Any: + new_object = super().__new__(cls) + # SQLAlchemy doesn't call __init__ on the base class + # Ref: https://docs.sqlalchemy.org/en/14/orm/constructors.html + # Set __fields_set__ here, that would have been set when calling __init__ + # in the Pydantic model so that when SQLAlchemy sets attributes that are + # added (e.g. when querying from DB) to the __fields_set__, this already exists + object.__setattr__(new_object, "__pydantic_fields_set__", set()) + return new_object + def __init__(__pydantic_self__, **data: Any) -> None: old_dict = __pydantic_self__.__dict__.copy() super().__init__(**data) @@ -680,6 +697,10 @@ def __repr_args__(self) -> Sequence[Tuple[Optional[str], Any]]: def __tablename__(cls) -> str: return cls.__name__.lower() + @classmethod + def model_validate(cls, *args, **kwargs): + return super().model_validate(*args, **kwargs) + def _is_field_noneable(field: FieldInfo) -> bool: if not field.is_required(): From f590548b42438ed56c4d37c2f32882d9c601fd59 Mon Sep 17 00:00:00 2001 From: Anton De Meester Date: Tue, 1 Aug 2023 09:33:55 +0000 Subject: [PATCH 23/29] Get all tests except for openapi working --- sqlmodel/main.py | 32 ++++++++++++++++++++++++++++---- 1 file changed, 28 insertions(+), 4 deletions(-) diff --git a/sqlmodel/main.py b/sqlmodel/main.py index be4623bcac..236c38100e 100644 --- a/sqlmodel/main.py +++ b/sqlmodel/main.py @@ -423,7 +423,9 @@ def __new__( key: pydantic_kwargs.pop(key) for key in pydantic_kwargs.keys() & allowed_config_kwargs } - config_table = getattr(class_dict.get("Config", object()), "table", False) or kwargs.get("table", False) + config_table = getattr( + class_dict.get("Config", object()), "table", False + ) or kwargs.get("table", False) # If we have a table, we need to have defaults for all fields # Pydantic v2 sets a __pydantic_core_schema__ which is very hard to change. Changing the fields does not do anything if config_table is True: @@ -432,7 +434,12 @@ def __new__( if value is PydanticUndefined: dict_used[key] = None elif isinstance(value, FieldInfo): - if value.default is PydanticUndefined and value.default_factory is None: + if ( + value.default in (PydanticUndefined, Ellipsis) + ) and value.default_factory is None: + value.original_default = ( + value.default + ) # So we can check for nullable value.default = None new_cls: Type["SQLModelMetaclass"] = super().__new__( @@ -641,6 +648,7 @@ def get_column_from_field(field: FieldInfo) -> Column: # type: ignore class_registry = weakref.WeakValueDictionary() # type: ignore default_registry = registry() +_TSQLModel = TypeVar("_TSQLModel", bound="SQLModel") class SQLModel(BaseModel, metaclass=SQLModelMetaclass, registry=default_registry): @@ -698,12 +706,28 @@ def __tablename__(cls) -> str: return cls.__name__.lower() @classmethod - def model_validate(cls, *args, **kwargs): - return super().model_validate(*args, **kwargs) + def model_validate( + cls: type[_TSQLModel], + obj: Any, + *, + strict: bool | None = None, + from_attributes: bool | None = None, + context: dict[str, Any] | None = None, + ) -> _TSQLModel: + # Somehow model validate doesn't call __init__ so it would remove our init logic + validated = super().model_validate( + obj, strict=strict, from_attributes=from_attributes, context=context + ) + return cls(**{key: value for key, value in validated}) def _is_field_noneable(field: FieldInfo) -> bool: + if getattr(field, "nullable", PydanticUndefined) is not PydanticUndefined: + return field.nullable if not field.is_required(): + default = getattr(field, "original_default", field.default) + if default is PydanticUndefined: + return False if field.annotation is None or field.annotation is NoneType: return True if get_origin(field.annotation) is Union: From 7ecbc38b58460e6167a57cd06212b7bc32d853d6 Mon Sep 17 00:00:00 2001 From: Anton De Meester Date: Mon, 13 Nov 2023 13:13:29 +0000 Subject: [PATCH 24/29] Write for pydantic v1 and v2 compat --- .github/workflows/test.yml | 7 + sqlmodel/__init__.py | 2 +- sqlmodel/compat.py | 169 +++++++++++++++ sqlmodel/main.py | 420 +++++++++++++++++++++++++------------ sqlmodel/orm/session.py | 4 +- sqlmodel/sql/expression.py | 8 + sqlmodel/typing.py | 9 - tests/conftest.py | 5 + 8 files changed, 478 insertions(+), 146 deletions(-) create mode 100644 sqlmodel/compat.py delete mode 100644 sqlmodel/typing.py diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 9f2e06ed71..979a08b2b2 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -21,6 +21,7 @@ jobs: strategy: matrix: python-version: ["3.7", "3.8", "3.9", "3.10", "3.11"] + pydantic-version: ["pydantic-v1", "pydantic-v2"] fail-fast: false steps: @@ -51,6 +52,12 @@ jobs: - name: Install Dependencies if: steps.cache.outputs.cache-hit != 'true' run: python -m poetry install + - name: Install Pydantic v1 + if: matrix.pydantic-version == 'pydantic-v1' + run: pip install "pydantic>=1.10.0,<2.0.0" + - name: Install Pydantic v2 + if: matrix.pydantic-version == 'pydantic-v2' + run: pip install "pydantic>=2.0.2,<3.0.0" - name: Lint run: python -m poetry run bash scripts/lint.sh - run: mkdir coverage diff --git a/sqlmodel/__init__.py b/sqlmodel/__init__.py index 4431d7cea7..7e20e1ba41 100644 --- a/sqlmodel/__init__.py +++ b/sqlmodel/__init__.py @@ -30,6 +30,7 @@ from sqlalchemy.sql import ( LABEL_STYLE_TABLENAME_PLUS_COL as LABEL_STYLE_TABLENAME_PLUS_COL, ) +from sqlalchemy.sql import Subquery as Subquery from sqlalchemy.sql import alias as alias from sqlalchemy.sql import all_ as all_ from sqlalchemy.sql import and_ as and_ @@ -70,7 +71,6 @@ from sqlalchemy.sql import outerjoin as outerjoin from sqlalchemy.sql import outparam as outparam from sqlalchemy.sql import over as over -from sqlalchemy.sql import Subquery as Subquery from sqlalchemy.sql import table as table from sqlalchemy.sql import tablesample as tablesample from sqlalchemy.sql import text as text diff --git a/sqlmodel/compat.py b/sqlmodel/compat.py new file mode 100644 index 0000000000..41abbef251 --- /dev/null +++ b/sqlmodel/compat.py @@ -0,0 +1,169 @@ +from typing import ( + TYPE_CHECKING, + Any, + Callable, + Dict, + ForwardRef, + Optional, + Type, + TypeVar, + Union, + get_args, + get_origin, +) + +from pydantic import VERSION as PYDANTIC_VERSION + +IS_PYDANTIC_V2 = int(PYDANTIC_VERSION.split(".")[0]) >= 2 + + +if IS_PYDANTIC_V2: + from pydantic import ConfigDict + from pydantic_core import PydanticUndefined as PydanticUndefined, PydanticUndefinedType as PydanticUndefinedType # noqa +else: + from pydantic import BaseConfig # noqa + from pydantic.fields import ModelField # noqa + from pydantic.fields import Undefined as PydanticUndefined, UndefinedType as PydanticUndefinedType # noqa + +if TYPE_CHECKING: + from .main import FieldInfo, RelationshipInfo, SQLModel, SQLModelMetaclass + + +NoArgAnyCallable = Callable[[], Any] +T = TypeVar("T") +InstanceOrType = Union[T, Type[T]] + +if IS_PYDANTIC_V2: + + class SQLModelConfig(ConfigDict, total=False): + table: Optional[bool] + read_from_attributes: Optional[bool] + registry: Optional[Any] + +else: + + class SQLModelConfig(BaseConfig): + table: Optional[bool] = None + read_from_attributes: Optional[bool] = None + registry: Optional[Any] = None + + def __getitem__(self, item: str) -> Any: + return self.__getattr__(item) + + def __setitem__(self, item: str, value: Any) -> None: + return self.__setattr__(item, value) + + +# Inspired from https://github.com/roman-right/beanie/blob/main/beanie/odm/utils/pydantic.py +def get_model_config(model: type) -> Optional[SQLModelConfig]: + if IS_PYDANTIC_V2: + return getattr(model, "model_config", None) + else: + return getattr(model, "Config", None) + + +def get_config_value( + model: InstanceOrType["SQLModel"], parameter: str, default: Any = None +) -> Any: + if IS_PYDANTIC_V2: + return model.model_config.get(parameter, default) + else: + return getattr(model.Config, parameter, default) + + +def set_config_value( + model: InstanceOrType["SQLModel"], parameter: str, value: Any, v1_parameter: str = None +) -> None: + if IS_PYDANTIC_V2: + model.model_config[parameter] = value # type: ignore + else: + model.Config[v1_parameter or parameter] = value # type: ignore + + +def get_model_fields(model: InstanceOrType["SQLModel"]) -> Dict[str, "FieldInfo"]: + if IS_PYDANTIC_V2: + return model.model_fields # type: ignore + else: + return model.__fields__ # type: ignore + + +def get_fields_set(model: InstanceOrType["SQLModel"]) -> set[str]: + if IS_PYDANTIC_V2: + return model.__pydantic_fields_set__ + else: + return model.__fields_set__ # type: ignore + + +def set_fields_set( + new_object: InstanceOrType["SQLModel"], fields: set["FieldInfo"] +) -> None: + if IS_PYDANTIC_V2: + object.__setattr__(new_object, "__pydantic_fields_set__", fields) + else: + object.__setattr__(new_object, "__fields_set__", fields) + + +def set_attribute_mode(cls: Type["SQLModelMetaclass"]) -> None: + if IS_PYDANTIC_V2: + cls.model_config["read_from_attributes"] = True + else: + cls.__config__.read_with_orm_mode = True # type: ignore + + +def get_relationship_to( + name: str, + rel_info: "RelationshipInfo", + annotation: Any, +) -> Any: + if IS_PYDANTIC_V2: + relationship_to = get_origin(annotation) + # Direct relationships (e.g. 'Team' or Team) have None as an origin + if relationship_to is None: + relationship_to = annotation + # If Union (e.g. Optional), get the real field + elif relationship_to is Union: + relationship_to = get_args(annotation)[0] + # If a list, then also get the real field + elif relationship_to is list: + relationship_to = get_args(annotation)[0] + if isinstance(relationship_to, ForwardRef): + relationship_to = relationship_to.__forward_arg__ + return relationship_to + else: + temp_field = ModelField.infer( + name=name, + value=rel_info, + annotation=annotation, + class_validators=None, + config=SQLModelConfig, + ) + relationship_to = temp_field.type_ + if isinstance(temp_field.type_, ForwardRef): + relationship_to = temp_field.type_.__forward_arg__ + return relationship_to + + +def set_empty_defaults(annotations: Dict[str, Any], class_dict: Dict[str, Any]) -> None: + """ + Pydantic v2 without required fields with no optionals cannot do empty initialisations. + This means we cannot do Table() and set fields later. + We go around this by adding a default to everything, being None + + Args: + annotations: Dict[str, Any]: The annotations to provide to pydantic + class_dict: Dict[str, Any]: The class dict for the defaults + """ + if IS_PYDANTIC_V2: + from .main import FieldInfo + # Pydantic v2 sets a __pydantic_core_schema__ which is very hard to change. Changing the fields does not do anything + for key in annotations.keys(): + value = class_dict.get(key, PydanticUndefined) + if value is PydanticUndefined: + class_dict[key] = None + elif isinstance(value, FieldInfo): + if ( + value.default in (PydanticUndefined, Ellipsis) + ) and value.default_factory is None: + # So we can check for nullable + value.original_default = value.default + value.default = None diff --git a/sqlmodel/main.py b/sqlmodel/main.py index 236c38100e..d5b69b28a8 100644 --- a/sqlmodel/main.py +++ b/sqlmodel/main.py @@ -5,13 +5,13 @@ from decimal import Decimal from enum import Enum from pathlib import Path +from types import NoneType from typing import ( AbstractSet, Any, Callable, ClassVar, Dict, - ForwardRef, List, Mapping, Optional, @@ -22,16 +22,16 @@ TypeVar, Union, cast, + get_args, + get_origin, overload, ) -from pydantic import BaseConfig, BaseModel -from pydantic.errors import ConfigError, DictError -from pydantic.fields import SHAPE_SINGLETON, ModelField, Undefined, UndefinedType +from pydantic import BaseModel +from pydantic.fields import SHAPE_SINGLETON, ModelField from pydantic.fields import FieldInfo as PydanticFieldInfo -from pydantic.main import ModelMetaclass, validate_model -from pydantic.typing import NoArgAnyCallable, resolve_annotations -from pydantic.utils import ROOT_KEY, Representation +from pydantic.main import ModelMetaclass +from pydantic.utils import Representation from sqlalchemy import ( Boolean, Column, @@ -52,12 +52,27 @@ from sqlalchemy.sql.schema import MetaData from sqlalchemy.sql.sqltypes import LargeBinary, Time +from .compat import ( + IS_PYDANTIC_V2, + NoArgAnyCallable, + PydanticUndefined, + PydanticUndefinedType, + SQLModelConfig, + get_config_value, + get_model_fields, + get_relationship_to, + set_config_value, + set_empty_defaults, + set_fields_set, +) from .sql.sqltypes import GUID, AutoString -from .typing import SQLModelConfig + +if not IS_PYDANTIC_V2: + from pydantic.errors import ConfigError, DictError + from pydantic.main import validate_model + from pydantic.utils import ROOT_KEY _T = TypeVar("_T") -NoArgAnyCallable = Callable[[], Any] -NoneType = type(None) def __dataclass_transform__( @@ -76,13 +91,13 @@ def __init__(self, default: Any = PydanticUndefined, **kwargs: Any) -> None: nullable = kwargs.pop("nullable", PydanticUndefined) foreign_key = kwargs.pop("foreign_key", PydanticUndefined) unique = kwargs.pop("unique", False) - index = kwargs.pop("index", Undefined) - sa_type = kwargs.pop("sa_type", Undefined) - sa_column = kwargs.pop("sa_column", Undefined) - sa_column_args = kwargs.pop("sa_column_args", Undefined) - sa_column_kwargs = kwargs.pop("sa_column_kwargs", Undefined) - if sa_column is not Undefined: - if sa_column_args is not Undefined: + index = kwargs.pop("index", PydanticUndefined) + sa_type = kwargs.pop("sa_type", PydanticUndefined) + sa_column = kwargs.pop("sa_column", PydanticUndefined) + sa_column_args = kwargs.pop("sa_column_args", PydanticUndefined) + sa_column_kwargs = kwargs.pop("sa_column_kwargs", PydanticUndefined) + if sa_column is not PydanticUndefined: + if sa_column_args is not PydanticUndefined: raise RuntimeError( "Passing sa_column_args is not supported when " "also passing a sa_column" @@ -92,29 +107,29 @@ def __init__(self, default: Any = PydanticUndefined, **kwargs: Any) -> None: "Passing sa_column_kwargs is not supported when " "also passing a sa_column" ) - if primary_key is not Undefined: + if primary_key is not PydanticUndefined: raise RuntimeError( "Passing primary_key is not supported when " "also passing a sa_column" ) - if nullable is not Undefined: + if nullable is not PydanticUndefined: raise RuntimeError( "Passing nullable is not supported when " "also passing a sa_column" ) - if foreign_key is not Undefined: + if foreign_key is not PydanticUndefined: raise RuntimeError( "Passing foreign_key is not supported when " "also passing a sa_column" ) - if unique is not Undefined: + if unique is not PydanticUndefined: raise RuntimeError( "Passing unique is not supported when also passing a sa_column" ) - if index is not Undefined: + if index is not PydanticUndefined: raise RuntimeError( "Passing index is not supported when also passing a sa_column" ) - if sa_type is not Undefined: + if sa_type is not PydanticUndefined: raise RuntimeError( "Passing sa_type is not supported when also passing a sa_column" ) @@ -128,6 +143,7 @@ def __init__(self, default: Any = PydanticUndefined, **kwargs: Any) -> None: self.sa_column = sa_column self.sa_column_args = sa_column_args self.sa_column_kwargs = sa_column_kwargs + self.original_default = PydanticUndefined class RelationshipInfo(Representation): @@ -189,14 +205,16 @@ def Field( regex: Optional[str] = None, discriminator: Optional[str] = None, repr: bool = True, - primary_key: Union[bool, UndefinedType] = Undefined, - foreign_key: Any = Undefined, - unique: Union[bool, UndefinedType] = Undefined, - nullable: Union[bool, UndefinedType] = Undefined, - index: Union[bool, UndefinedType] = Undefined, - sa_type: Union[Type[Any], UndefinedType] = Undefined, - sa_column_args: Union[Sequence[Any], UndefinedType] = Undefined, - sa_column_kwargs: Union[Mapping[str, Any], UndefinedType] = Undefined, + primary_key: Union[bool, PydanticUndefinedType] = PydanticUndefined, + foreign_key: Any = PydanticUndefined, + unique: Union[bool, PydanticUndefinedType] = PydanticUndefined, + nullable: Union[bool, PydanticUndefinedType] = PydanticUndefined, + index: Union[bool, PydanticUndefinedType] = PydanticUndefined, + sa_type: Union[Type[Any], PydanticUndefinedType] = PydanticUndefined, + sa_column_args: Union[Sequence[Any], PydanticUndefinedType] = PydanticUndefined, + sa_column_kwargs: Union[ + Mapping[str, Any], PydanticUndefinedType + ] = PydanticUndefined, schema_extra: Optional[Dict[str, Any]] = None, ) -> Any: ... @@ -204,7 +222,7 @@ def Field( @overload def Field( - default: Any = Undefined, + default: Any = PydanticUndefined, *, default_factory: Optional[NoArgAnyCallable] = None, alias: Optional[str] = None, @@ -233,14 +251,14 @@ def Field( regex: Optional[str] = None, discriminator: Optional[str] = None, repr: bool = True, - sa_column: Union[Column, UndefinedType] = Undefined, # type: ignore + sa_column: Union[Column, PydanticUndefinedType] = PydanticUndefined, # type: ignore schema_extra: Optional[Dict[str, Any]] = None, ) -> Any: ... def Field( - default: Any = Undefined, + default: Any = PydanticUndefined, *, default_factory: Optional[NoArgAnyCallable] = None, alias: Optional[str] = None, @@ -269,15 +287,17 @@ def Field( regex: Optional[str] = None, discriminator: Optional[str] = None, repr: bool = True, - primary_key: Union[bool, UndefinedType] = Undefined, - foreign_key: Any = Undefined, - unique: Union[bool, UndefinedType] = Undefined, - nullable: Union[bool, UndefinedType] = Undefined, - index: Union[bool, UndefinedType] = Undefined, - sa_type: Union[Type[Any], UndefinedType] = Undefined, - sa_column: Union[Column, UndefinedType] = Undefined, # type: ignore - sa_column_args: Union[Sequence[Any], UndefinedType] = Undefined, - sa_column_kwargs: Union[Mapping[str, Any], UndefinedType] = Undefined, + primary_key: Union[bool, PydanticUndefinedType] = PydanticUndefined, + foreign_key: Any = PydanticUndefined, + unique: Union[bool, PydanticUndefinedType] = PydanticUndefined, + nullable: Union[bool, PydanticUndefinedType] = PydanticUndefined, + index: Union[bool, PydanticUndefinedType] = PydanticUndefined, + sa_type: Union[Type[Any], PydanticUndefinedType] = PydanticUndefined, + sa_column: Union[Column, PydanticUndefinedType] = PydanticUndefined, # type: ignore + sa_column_args: Union[Sequence[Any], PydanticUndefinedType] = PydanticUndefined, + sa_column_kwargs: Union[ + Mapping[str, Any], PydanticUndefinedType + ] = PydanticUndefined, schema_extra: Optional[Dict[str, Any]] = None, ) -> Any: current_schema_extra = schema_extra or {} @@ -345,7 +365,7 @@ def Relationship( *, back_populates: Optional[str] = None, link_model: Optional[Any] = None, - sa_relationship: Optional[RelationshipProperty[Any]] = None, + sa_relationship: Optional[RelationshipProperty] = None, sa_relationship_args: Optional[Sequence[Any]] = None, sa_relationship_kwargs: Optional[Mapping[str, Any]] = None, ) -> Any: @@ -362,18 +382,22 @@ def Relationship( @__dataclass_transform__(kw_only_default=True, field_descriptors=(Field, FieldInfo)) class SQLModelMetaclass(ModelMetaclass, DeclarativeMeta): __sqlmodel_relationships__: Dict[str, RelationshipInfo] - model_config: SQLModelConfig - model_fields: Dict[str, FieldInfo] + if IS_PYDANTIC_V2: + model_config: SQLModelConfig + model_fields: Dict[str, FieldInfo] + else: + __config__: Type[SQLModelConfig] + __fields__: Dict[str, FieldInfo] # Replicate SQLAlchemy def __setattr__(cls, name: str, value: Any) -> None: - if cls.model_config.get("table", False): + if get_config_value(cls, "table", False): DeclarativeMeta.__setattr__(cls, name, value) else: super().__setattr__(name, value) def __delattr__(cls, name: str) -> None: - if cls.model_config.get("table", False): + if get_config_value(cls, "table", False): DeclarativeMeta.__delattr__(cls, name) else: super().__delattr__(name) @@ -386,7 +410,6 @@ def __new__( class_dict: Dict[str, Any], **kwargs: Any, ) -> Any: - relationships: Dict[str, RelationshipInfo] = {} dict_for_pydantic = {} original_annotations = class_dict.get("__annotations__", {}) @@ -426,21 +449,8 @@ def __new__( config_table = getattr( class_dict.get("Config", object()), "table", False ) or kwargs.get("table", False) - # If we have a table, we need to have defaults for all fields - # Pydantic v2 sets a __pydantic_core_schema__ which is very hard to change. Changing the fields does not do anything - if config_table is True: - for key in pydantic_annotations.keys(): - value = dict_used.get(key, PydanticUndefined) - if value is PydanticUndefined: - dict_used[key] = None - elif isinstance(value, FieldInfo): - if ( - value.default in (PydanticUndefined, Ellipsis) - ) and value.default_factory is None: - value.original_default = ( - value.default - ) # So we can check for nullable - value.default = None + if config_table: + set_empty_defaults(pydantic_annotations, dict_used) new_cls: Type["SQLModelMetaclass"] = super().__new__( cls, name, bases, dict_used, **config_kwargs @@ -452,7 +462,7 @@ def __new__( } def get_config(name: str) -> Any: - config_class_value = new_cls.model_config.get(name, PydanticUndefined) + config_class_value = get_config_value(new_cls, name, PydanticUndefined) if config_class_value is not PydanticUndefined: return config_class_value kwarg_value = kwargs.get(name, PydanticUndefined) @@ -463,8 +473,8 @@ def get_config(name: str) -> Any: config_table = get_config("table") if config_table is True: # If it was passed by kwargs, ensure it's also set in config - new_cls.model_config["table"] = config_table - for k, v in new_cls.model_fields.items(): + set_config_value(new_cls, "table", config_table) + for k, v in get_model_fields(new_cls).items(): col = get_column_from_field(v) setattr(new_cls, k, col) # Set a config flag to tell FastAPI that this should be read with a field @@ -472,13 +482,15 @@ def get_config(name: str) -> Any: # This could be done by reading new_cls.model_config['table'] in FastAPI, but # that's very specific about SQLModel, so let's have another config that # other future tools based on Pydantic can use. - new_cls.model_config["read_from_attributes"] = True + set_config_value( + new_cls, "read_from_attributes", True, v1_parameter="read_with_orm_mode" + ) config_registry = get_config("registry") if config_registry is not PydanticUndefined: config_registry = cast(registry, config_registry) # If it was passed by kwargs, ensure it's also set in config - new_cls.__config__.registry = config_table + set_config_value(new_cls, "registry", config_table) setattr(new_cls, "_sa_registry", config_registry) # noqa: B010 setattr(new_cls, "metadata", config_registry.metadata) # noqa: B010 setattr(new_cls, "__abstract__", True) # noqa: B010 @@ -506,18 +518,7 @@ def __init__( setattr(cls, rel_name, rel_info.sa_relationship) # Fix #315 continue ann = cls.__annotations__[rel_name] - relationship_to = get_origin(ann) - # Direct relationships (e.g. 'Team' or Team) have None as an origin - if relationship_to is None: - relationship_to = ann - # If Union (e.g. Optional), get the real field - elif relationship_to is Union: - relationship_to = get_args(ann)[0] - # If a list, then also get the real field - elif relationship_to is list: - relationship_to = get_args(ann)[0] - if isinstance(relationship_to, ForwardRef): - relationship_to = relationship_to.__forward_arg__ + relationship_to = get_relationship_to(rel_name, rel_info, ann) rel_kwargs: Dict[str, Any] = {} if rel_info.back_populates: rel_kwargs["back_populates"] = rel_info.back_populates @@ -535,7 +536,7 @@ def __init__( rel_args.extend(rel_info.sa_relationship_args) if rel_info.sa_relationship_kwargs: rel_kwargs.update(rel_info.sa_relationship_kwargs) - rel_value: RelationshipProperty[Any] = relationship( + rel_value: RelationshipProperty = relationship( relationship_to, *rel_args, **rel_kwargs ) setattr(cls, rel_name, rel_value) # Fix #315 @@ -548,8 +549,8 @@ def __init__( def get_sqlalchemy_type(field: ModelField) -> Any: - sa_type = getattr(field.field_info, "sa_type", Undefined) # noqa: B009 - if sa_type is not Undefined: + sa_type = getattr(field.field_info, "sa_type", PydanticUndefined) # noqa: B009 + if sa_type is not PydanticUndefined: return sa_type if isinstance(field.type_, type) and field.shape == SHAPE_SINGLETON: # Check enums first as an enum can also be a str, needed by Pydantic/FastAPI @@ -595,30 +596,30 @@ def get_sqlalchemy_type(field: ModelField) -> Any: raise ValueError(f"The field {field.name} has no matching SQLAlchemy type") -def get_column_from_field(field: FieldInfo) -> Column: # type: ignore - sa_column = getattr(field, "sa_column", PydanticUndefined) +def get_column_from_field(field: ModelField) -> Column: # type: ignore + sa_column = getattr(field.field_info, "sa_column", PydanticUndefined) if isinstance(sa_column, Column): return sa_column sa_type = get_sqlalchemy_type(field) - primary_key = getattr(field.field_info, "primary_key", Undefined) - if primary_key is Undefined: + primary_key = getattr(field.field_info, "primary_key", PydanticUndefined) + if primary_key is PydanticUndefined: primary_key = False - index = getattr(field.field_info, "index", Undefined) - if index is Undefined: + index = getattr(field.field_info, "index", PydanticUndefined) + if index is PydanticUndefined: index = False nullable = not primary_key and _is_field_noneable(field) # Override derived nullability if the nullable property is set explicitly # on the field - field_nullable = getattr(field.field_info, "nullable", Undefined) # noqa: B009 - if field_nullable != Undefined: - assert not isinstance(field_nullable, UndefinedType) + field_nullable = getattr(field.field_info, "nullable", PydanticUndefined) # noqa: B009 + if field_nullable != PydanticUndefined: + assert not isinstance(field_nullable, PydanticUndefinedType) nullable = field_nullable args = [] - foreign_key = getattr(field.field_info, "foreign_key", Undefined) - if foreign_key is Undefined: + foreign_key = getattr(field.field_info, "foreign_key", PydanticUndefined) + if foreign_key is PydanticUndefined: foreign_key = None - unique = getattr(field.field_info, "unique", Undefined) - if unique is Undefined: + unique = getattr(field.field_info, "unique", PydanticUndefined) + if unique is PydanticUndefined: unique = False if foreign_key: assert isinstance(foreign_key, str) @@ -629,17 +630,17 @@ def get_column_from_field(field: FieldInfo) -> Column: # type: ignore "index": index, "unique": unique, } - sa_default: PydanticUndefinedType | Callable[[], Any] = PydanticUndefined - if field.default_factory: - sa_default = field.default_factory - elif field.default is not PydanticUndefined: - sa_default = field.default + sa_default = PydanticUndefined + if field.field_info.default_factory: + sa_default = field.field_info.default_factory + elif field.field_info.default is not PydanticUndefined: + sa_default = field.field_info.default if sa_default is not PydanticUndefined: kwargs["default"] = sa_default - sa_column_args = getattr(field, "sa_column_args", PydanticUndefined) + sa_column_args = getattr(field.field_info, "sa_column_args", PydanticUndefined) if sa_column_args is not PydanticUndefined: args.extend(list(cast(Sequence[Any], sa_column_args))) - sa_column_kwargs = getattr(field, "sa_column_kwargs", PydanticUndefined) + sa_column_kwargs = getattr(field.field_info, "sa_column_kwargs", PydanticUndefined) if sa_column_kwargs is not PydanticUndefined: kwargs.update(cast(Dict[Any, Any], sa_column_kwargs)) return Column(sa_type, *args, **kwargs) # type: ignore @@ -648,6 +649,14 @@ def get_column_from_field(field: FieldInfo) -> Column: # type: ignore class_registry = weakref.WeakValueDictionary() # type: ignore default_registry = registry() + + +def _value_items_is_true(v: Any) -> bool: + # Re-implement Pydantic's ValueItems.is_true() as it hasn't been released as of + # the current latest, Pydantic 1.8.2 + return v is True or v is ... + + _TSQLModel = TypeVar("_TSQLModel", bound="SQLModel") @@ -655,11 +664,17 @@ class SQLModel(BaseModel, metaclass=SQLModelMetaclass, registry=default_registry # SQLAlchemy needs to set weakref(s), Pydantic will set the other slots values __slots__ = ("__weakref__",) __tablename__: ClassVar[Union[str, Callable[..., str]]] - __sqlmodel_relationships__: ClassVar[Dict[str, RelationshipProperty[Any]]] + __sqlmodel_relationships__: ClassVar[Dict[str, RelationshipProperty]] __name__: ClassVar[str] metadata: ClassVar[MetaData] __allow_unmapped__ = True # https://docs.sqlalchemy.org/en/20/changelog/migration_20.html#migration-20-step-six - model_config = SQLModelConfig(from_attributes=True) + + if IS_PYDANTIC_V2: + model_config = SQLModelConfig(from_attributes=True) + else: + + class Config: + orm_mode = True def __new__(cls, *args: Any, **kwargs: Any) -> Any: new_object = super().__new__(cls) @@ -668,14 +683,35 @@ def __new__(cls, *args: Any, **kwargs: Any) -> Any: # Set __fields_set__ here, that would have been set when calling __init__ # in the Pydantic model so that when SQLAlchemy sets attributes that are # added (e.g. when querying from DB) to the __fields_set__, this already exists - object.__setattr__(new_object, "__pydantic_fields_set__", set()) + set_fields_set(new_object, set()) return new_object def __init__(__pydantic_self__, **data: Any) -> None: - old_dict = __pydantic_self__.__dict__.copy() - super().__init__(**data) - __pydantic_self__.__dict__ = old_dict | __pydantic_self__.__dict__ - non_pydantic_keys = data.keys() - __pydantic_self__.model_fields + # Uses something other than `self` the first arg to allow "self" as a + # settable attribute + if IS_PYDANTIC_V2: + old_dict = __pydantic_self__.__dict__.copy() + __pydantic_self__.super().__init__(**data) # noqa + __pydantic_self__.__dict__ = {**old_dict, **__pydantic_self__.__dict__} + non_pydantic_keys = data.keys() - __pydantic_self__.model_fields + else: + values, fields_set, validation_error = validate_model( + __pydantic_self__.__class__, data + ) + # Only raise errors if not a SQLModel model + if ( + not getattr(__pydantic_self__.__config__, "table", False) # noqa + and validation_error + ): + raise validation_error + # Do not set values as in Pydantic, pass them through setattr, so SQLAlchemy + # can handle them + # object.__setattr__(__pydantic_self__, '__dict__', values) + for key, value in values.items(): + setattr(__pydantic_self__, key, value) + object.__setattr__(__pydantic_self__, "__fields_set__", fields_set) + non_pydantic_keys = data.keys() - values.keys() + for key in non_pydantic_keys: if key in __pydantic_self__.__sqlmodel_relationships__: setattr(__pydantic_self__, key, data[key]) @@ -686,12 +722,12 @@ def __setattr__(self, name: str, value: Any) -> None: return else: # Set in SQLAlchemy, before Pydantic to trigger events and updates - if self.model_config.get("table", False) and is_instrumented(self, name): # type: ignore + if get_config_value(self, "table", False) and is_instrumented(self, name): set_attribute(self, name, value) # Set in Pydantic model to trigger possible validation changes, only for # non relationship values if name not in self.__sqlmodel_relationships__: - super(SQLModel, self).__setattr__(name, value) + super().__setattr__(name, value) def __repr_args__(self) -> Sequence[Tuple[Optional[str], Any]]: # Don't show SQLAlchemy private attributes @@ -705,20 +741,143 @@ def __repr_args__(self) -> Sequence[Tuple[Optional[str], Any]]: def __tablename__(cls) -> str: return cls.__name__.lower() - @classmethod - def model_validate( - cls: type[_TSQLModel], - obj: Any, - *, - strict: bool | None = None, - from_attributes: bool | None = None, - context: dict[str, Any] | None = None, - ) -> _TSQLModel: - # Somehow model validate doesn't call __init__ so it would remove our init logic - validated = super().model_validate( - obj, strict=strict, from_attributes=from_attributes, context=context - ) - return cls(**{key: value for key, value in validated}) + if IS_PYDANTIC_V2: + + @classmethod + def model_validate( + cls: type[_TSQLModel], + obj: Any, + *, + strict: bool | None = None, + from_attributes: bool | None = None, + context: dict[str, Any] | None = None, + ) -> _TSQLModel: + # Somehow model validate doesn't call __init__ so it would remove our init logic + validated = super().model_validate( + obj, strict=strict, from_attributes=from_attributes, context=context + ) + return cls(**dict(validated)) + + else: + + @classmethod + def from_orm( + cls: Type[_TSQLModel], obj: Any, update: Optional[Dict[str, Any]] = None + ) -> _TSQLModel: + # Duplicated from Pydantic + if not cls.__config__.orm_mode: # noqa: attr-defined + raise ConfigError( + "You must have the config attribute orm_mode=True to use from_orm" + ) + obj = ( + {ROOT_KEY: obj} + if cls.__custom_root_type__ # noqa + else cls._decompose_class(obj) # noqa + ) + # SQLModel, support update dict + if update is not None: + obj = {**obj, **update} + # End SQLModel support dict + if not getattr(cls.__config__, "table", False): # noqa + # If not table, normal Pydantic code + m: _TSQLModel = cls.__new__(cls) + else: + # If table, create the new instance normally to make SQLAlchemy create + # the _sa_instance_state attribute + m = cls() + values, fields_set, validation_error = validate_model(cls, obj) + if validation_error: + raise validation_error + # Updated to trigger SQLAlchemy internal handling + if not getattr(cls.__config__, "table", False): # noqa + object.__setattr__(m, "__dict__", values) + else: + for key, value in values.items(): + setattr(m, key, value) + # Continue with standard Pydantic logic + object.__setattr__(m, "__fields_set__", fields_set) + m._init_private_attributes() # noqa + return m + + @classmethod + def parse_obj( + cls: Type[_TSQLModel], obj: Any, update: Optional[Dict[str, Any]] = None + ) -> _TSQLModel: + obj = cls._enforce_dict_if_root(obj) # noqa + # SQLModel, support update dict + if update is not None: + obj = {**obj, **update} + # End SQLModel support dict + return super().parse_obj(obj) + + # From Pydantic, override to enforce validation with dict + @classmethod + def validate(cls: Type[_TSQLModel], value: Any) -> _TSQLModel: + if isinstance(value, cls): + return ( + value.copy() if cls.__config__.copy_on_model_validation else value # noqa + ) + + value = cls._enforce_dict_if_root(value) + if isinstance(value, dict): + values, fields_set, validation_error = validate_model(cls, value) + if validation_error: + raise validation_error + model = cls(**value) + # Reset fields set, this would have been done in Pydantic in __init__ + object.__setattr__(model, "__fields_set__", fields_set) + return model + elif cls.__config__.orm_mode: # noqa + return cls.from_orm(value) + elif cls.__custom_root_type__: # noqa + return cls.parse_obj(value) + else: + try: + value_as_dict = dict(value) + except (TypeError, ValueError) as e: + raise DictError() from e + return cls(**value_as_dict) + + # From Pydantic, override to only show keys from fields, omit SQLAlchemy attributes + def _calculate_keys( + self, + include: Optional[Mapping[Union[int, str], Any]], + exclude: Optional[Mapping[Union[int, str], Any]], + exclude_unset: bool, + update: Optional[Dict[str, Any]] = None, + ) -> Optional[AbstractSet[str]]: + if include is None and exclude is None and not exclude_unset: + # Original in Pydantic: + # return None + # Updated to not return SQLAlchemy attributes + # Do not include relationships as that would easily lead to infinite + # recursion, or traversing the whole database + return ( + self.__fields__.keys() # noqa + ) # | self.__sqlmodel_relationships__.keys() + + keys: AbstractSet[str] + if exclude_unset: + keys = self.__fields_set__.copy() # noqa + else: + # Original in Pydantic: + # keys = self.__dict__.keys() + # Updated to not return SQLAlchemy attributes + # Do not include relationships as that would easily lead to infinite + # recursion, or traversing the whole database + keys = ( + self.__fields__.keys() # noqa + ) # | self.__sqlmodel_relationships__.keys() + if include is not None: + keys &= include.keys() + + if update: + keys -= update.keys() + + if exclude: + keys -= {k for k, v in exclude.items() if _value_items_is_true(v)} + + return keys def _is_field_noneable(field: FieldInfo) -> bool: @@ -736,10 +895,3 @@ def _is_field_noneable(field: FieldInfo) -> bool: return True return False return False - - -def _get_field_metadata(field: FieldInfo) -> object: - for meta in field.metadata: - if isinstance(meta, PydanticGeneralMetadata): - return meta - return object() diff --git a/sqlmodel/orm/session.py b/sqlmodel/orm/session.py index c9bb043336..29aba05eec 100644 --- a/sqlmodel/orm/session.py +++ b/sqlmodel/orm/session.py @@ -1,6 +1,7 @@ from typing import ( Any, Dict, + Literal, Mapping, Optional, Sequence, @@ -15,7 +16,6 @@ from sqlalchemy.orm import Query as _Query from sqlalchemy.orm import Session as _Session from sqlalchemy.sql.base import Executable as _Executable -from typing_extensions import Literal from ..engine.result import Result, ScalarResult from ..sql.base import Executable @@ -137,7 +137,7 @@ def get( ident: Any, options: Optional[Sequence[Any]] = None, populate_existing: bool = False, - with_for_update: Optional[_ForUpdateArg] = None, + with_for_update: Optional[Union[Literal[True], Mapping[str, Any]]] = None, identity_token: Optional[Any] = None, execution_options: Mapping[Any, Any] = util.EMPTY_DICT, bind_arguments: Optional[Dict[str, Any]] = None, diff --git a/sqlmodel/sql/expression.py b/sqlmodel/sql/expression.py index a0ac1bd9d9..8cb2309228 100644 --- a/sqlmodel/sql/expression.py +++ b/sqlmodel/sql/expression.py @@ -35,6 +35,14 @@ class SelectOfScalar(_Select[Tuple[_TSelect]], Generic[_TSelect]): inherit_cache = True +# This is not comparable to sqlalchemy.sql.selectable.ScalarSelect, that has a different +# purpose. This is the same as a normal SQLAlchemy Select class where there's only one +# entity, so the result will be converted to a scalar by default. This way writing +# for loops on the results will feel natural. +class SelectOfScalar(_Select, Generic[_TSelect]): + inherit_cache = True + + if TYPE_CHECKING: # pragma: no cover from ..main import SQLModel diff --git a/sqlmodel/typing.py b/sqlmodel/typing.py deleted file mode 100644 index 8151f99692..0000000000 --- a/sqlmodel/typing.py +++ /dev/null @@ -1,9 +0,0 @@ -from typing import Any, Optional - -from pydantic import ConfigDict - - -class SQLModelConfig(ConfigDict, total=False): - table: Optional[bool] - read_from_attributes: Optional[bool] - registry: Optional[Any] diff --git a/tests/conftest.py b/tests/conftest.py index 2b8e5fc29e..020b33a566 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -6,6 +6,7 @@ import pytest from pydantic import BaseModel from sqlmodel import SQLModel +from sqlmodel.compat import IS_PYDANTIC_V2 from sqlmodel.main import default_registry top_level_path = Path(__file__).resolve().parent.parent @@ -67,3 +68,7 @@ def new_print(*args): calls.append(data) return new_print + + +needs_pydanticv2 = pytest.mark.skipif(not IS_PYDANTIC_V2, reason="requires Pydantic v2") +needs_pydanticv1 = pytest.mark.skipif(IS_PYDANTIC_V2, reason="requires Pydantic v1") From c21ff699edc62dcda598e952f35a97859549e052 Mon Sep 17 00:00:00 2001 From: Anton De Meester Date: Tue, 14 Nov 2023 08:50:11 +0000 Subject: [PATCH 25/29] Make pydantic v1 work again --- docs/contributing.md | 4 -- docs/features.md | 2 +- .../fastapi/app_testing/tutorial001/main.py | 7 ++- .../tutorial/fastapi/delete/tutorial001.py | 6 +- .../fastapi/limit_and_offset/tutorial001.py | 6 +- .../fastapi/multiple_models/tutorial001.py | 6 +- .../fastapi/multiple_models/tutorial002.py | 6 +- .../tutorial/fastapi/read_one/tutorial001.py | 6 +- .../fastapi/relationships/tutorial001.py | 11 +++- .../session_with_dependency/tutorial001.py | 6 +- .../tutorial/fastapi/teams/tutorial001.py | 11 +++- .../tutorial/fastapi/update/tutorial001.py | 6 +- sqlmodel/compat.py | 63 +++++++++++++++---- sqlmodel/engine/result.py | 61 ++++++++++++++---- sqlmodel/ext/asyncio/session.py | 2 +- sqlmodel/main.py | 39 ++++-------- sqlmodel/orm/session.py | 35 ++++------- sqlmodel/sql/expression.py | 50 ++++++++++----- sqlmodel/sql/expression.py.jinja2 | 25 +++----- sqlmodel/sql/sqltypes.py | 8 +-- tests/test_validation.py | 31 ++++++++- 21 files changed, 260 insertions(+), 131 deletions(-) diff --git a/docs/contributing.md b/docs/contributing.md index 5c160a7c0a..217ed61c56 100644 --- a/docs/contributing.md +++ b/docs/contributing.md @@ -6,10 +6,6 @@ First, you might want to see the basic ways to [help SQLModel and get help](help If you already cloned the repository and you know that you need to deep dive in the code, here are some guidelines to set up your environment. -### Python - -SQLModel supports Python 3.7 and above. - ### Poetry **SQLModel** uses Poetry to build, package, and publish the project. diff --git a/docs/features.md b/docs/features.md index 2d5e11d84f..102edef725 100644 --- a/docs/features.md +++ b/docs/features.md @@ -12,7 +12,7 @@ Nevertheless, SQLModel is completely **independent** of FastAPI and can be used ## Just Modern Python -It's all based on standard modern **Python** type annotations. No new syntax to learn. Just standard modern Python. +It's all based on standard modern **Python** type annotations. No new syntax to learn. Just standard modern Python. If you need a 2 minute refresher of how to use Python types (even if you don't use SQLModel or FastAPI), check the FastAPI tutorial section: Python types intro. diff --git a/docs_src/tutorial/fastapi/app_testing/tutorial001/main.py b/docs_src/tutorial/fastapi/app_testing/tutorial001/main.py index a23dfad5a8..cf2f4da233 100644 --- a/docs_src/tutorial/fastapi/app_testing/tutorial001/main.py +++ b/docs_src/tutorial/fastapi/app_testing/tutorial001/main.py @@ -2,6 +2,7 @@ from fastapi import Depends, FastAPI, HTTPException, Query from sqlmodel import Field, Session, SQLModel, create_engine, select +from sqlmodel.compat import IS_PYDANTIC_V2 class HeroBase(SQLModel): @@ -51,10 +52,12 @@ def get_session(): def on_startup(): create_db_and_tables() - @app.post("/heroes/", response_model=HeroRead) def create_hero(*, session: Session = Depends(get_session), hero: HeroCreate): - db_hero = Hero.model_validate(hero) + if IS_PYDANTIC_V2: + db_hero = Hero.model_validate(hero) + else: + db_hero = Hero.from_orm(hero) session.add(db_hero) session.commit() session.refresh(db_hero) diff --git a/docs_src/tutorial/fastapi/delete/tutorial001.py b/docs_src/tutorial/fastapi/delete/tutorial001.py index 77a99a9c97..f186c42b2b 100644 --- a/docs_src/tutorial/fastapi/delete/tutorial001.py +++ b/docs_src/tutorial/fastapi/delete/tutorial001.py @@ -2,6 +2,7 @@ from fastapi import FastAPI, HTTPException, Query from sqlmodel import Field, Session, SQLModel, create_engine, select +from sqlmodel.compat import IS_PYDANTIC_V2 class HeroBase(SQLModel): @@ -50,7 +51,10 @@ def on_startup(): @app.post("/heroes/", response_model=HeroRead) def create_hero(hero: HeroCreate): with Session(engine) as session: - db_hero = Hero.model_validate(hero) + if IS_PYDANTIC_V2: + db_hero = Hero.model_validate(hero) + else: + db_hero = Hero.from_orm(hero) session.add(db_hero) session.commit() session.refresh(db_hero) diff --git a/docs_src/tutorial/fastapi/limit_and_offset/tutorial001.py b/docs_src/tutorial/fastapi/limit_and_offset/tutorial001.py index 2352f39022..6701355f17 100644 --- a/docs_src/tutorial/fastapi/limit_and_offset/tutorial001.py +++ b/docs_src/tutorial/fastapi/limit_and_offset/tutorial001.py @@ -2,6 +2,7 @@ from fastapi import FastAPI, HTTPException, Query from sqlmodel import Field, Session, SQLModel, create_engine, select +from sqlmodel.compat import IS_PYDANTIC_V2 class HeroBase(SQLModel): @@ -44,7 +45,10 @@ def on_startup(): @app.post("/heroes/", response_model=HeroRead) def create_hero(hero: HeroCreate): with Session(engine) as session: - db_hero = Hero.model_validate(hero) + if IS_PYDANTIC_V2: + db_hero = Hero.model_validate(hero) + else: + db_hero = Hero.from_orm(hero) session.add(db_hero) session.commit() session.refresh(db_hero) diff --git a/docs_src/tutorial/fastapi/multiple_models/tutorial001.py b/docs_src/tutorial/fastapi/multiple_models/tutorial001.py index 7f59ac6a1d..0ceed94ca1 100644 --- a/docs_src/tutorial/fastapi/multiple_models/tutorial001.py +++ b/docs_src/tutorial/fastapi/multiple_models/tutorial001.py @@ -2,6 +2,7 @@ from fastapi import FastAPI from sqlmodel import Field, Session, SQLModel, create_engine, select +from sqlmodel.compat import IS_PYDANTIC_V2 class Hero(SQLModel, table=True): @@ -46,7 +47,10 @@ def on_startup(): @app.post("/heroes/", response_model=HeroRead) def create_hero(hero: HeroCreate): with Session(engine) as session: - db_hero = Hero.model_validate(hero) + if IS_PYDANTIC_V2: + db_hero = Hero.model_validate(hero) + else: + db_hero = Hero.from_orm(hero) session.add(db_hero) session.commit() session.refresh(db_hero) diff --git a/docs_src/tutorial/fastapi/multiple_models/tutorial002.py b/docs_src/tutorial/fastapi/multiple_models/tutorial002.py index fffbe72496..d92745a339 100644 --- a/docs_src/tutorial/fastapi/multiple_models/tutorial002.py +++ b/docs_src/tutorial/fastapi/multiple_models/tutorial002.py @@ -2,6 +2,7 @@ from fastapi import FastAPI from sqlmodel import Field, Session, SQLModel, create_engine, select +from sqlmodel.compat import IS_PYDANTIC_V2 class HeroBase(SQLModel): @@ -44,7 +45,10 @@ def on_startup(): @app.post("/heroes/", response_model=HeroRead) def create_hero(hero: HeroCreate): with Session(engine) as session: - db_hero = Hero.model_validate(hero) + if IS_PYDANTIC_V2: + db_hero = Hero.model_validate(hero) + else: + db_hero = Hero.from_orm(hero) session.add(db_hero) session.commit() session.refresh(db_hero) diff --git a/docs_src/tutorial/fastapi/read_one/tutorial001.py b/docs_src/tutorial/fastapi/read_one/tutorial001.py index f18426e74c..4cdf898922 100644 --- a/docs_src/tutorial/fastapi/read_one/tutorial001.py +++ b/docs_src/tutorial/fastapi/read_one/tutorial001.py @@ -3,6 +3,7 @@ from fastapi import FastAPI, HTTPException from sqlmodel import Field, Session, SQLModel, create_engine, select +from sqlmodel.compat import IS_PYDANTIC_V2 class HeroBase(SQLModel): name: str = Field(index=True) @@ -44,7 +45,10 @@ def on_startup(): @app.post("/heroes/", response_model=HeroRead) def create_hero(hero: HeroCreate): with Session(engine) as session: - db_hero = Hero.model_validate(hero) + if IS_PYDANTIC_V2: + db_hero = Hero.model_validate(hero) + else: + db_hero = Hero.from_orm(hero) session.add(db_hero) session.commit() session.refresh(db_hero) diff --git a/docs_src/tutorial/fastapi/relationships/tutorial001.py b/docs_src/tutorial/fastapi/relationships/tutorial001.py index e5b196090e..dfcedaf881 100644 --- a/docs_src/tutorial/fastapi/relationships/tutorial001.py +++ b/docs_src/tutorial/fastapi/relationships/tutorial001.py @@ -2,6 +2,7 @@ from fastapi import Depends, FastAPI, HTTPException, Query from sqlmodel import Field, Relationship, Session, SQLModel, create_engine, select +from sqlmodel.compat import IS_PYDANTIC_V2 class TeamBase(SQLModel): @@ -92,7 +93,10 @@ def on_startup(): @app.post("/heroes/", response_model=HeroRead) def create_hero(*, session: Session = Depends(get_session), hero: HeroCreate): - db_hero = Hero.model_validate(hero) + if IS_PYDANTIC_V2: + db_hero = Hero.model_validate(hero) + else: + db_hero = Hero.from_orm(hero) session.add(db_hero) session.commit() session.refresh(db_hero) @@ -146,7 +150,10 @@ def delete_hero(*, session: Session = Depends(get_session), hero_id: int): @app.post("/teams/", response_model=TeamRead) def create_team(*, session: Session = Depends(get_session), team: TeamCreate): - db_team = Team.model_validate(team) + if IS_PYDANTIC_V2: + db_team = Team.model_validate(team) + else: + db_team = Team.from_orm(team) session.add(db_team) session.commit() session.refresh(db_team) diff --git a/docs_src/tutorial/fastapi/session_with_dependency/tutorial001.py b/docs_src/tutorial/fastapi/session_with_dependency/tutorial001.py index a23dfad5a8..f305f75194 100644 --- a/docs_src/tutorial/fastapi/session_with_dependency/tutorial001.py +++ b/docs_src/tutorial/fastapi/session_with_dependency/tutorial001.py @@ -2,6 +2,7 @@ from fastapi import Depends, FastAPI, HTTPException, Query from sqlmodel import Field, Session, SQLModel, create_engine, select +from sqlmodel.compat import IS_PYDANTIC_V2 class HeroBase(SQLModel): @@ -54,7 +55,10 @@ def on_startup(): @app.post("/heroes/", response_model=HeroRead) def create_hero(*, session: Session = Depends(get_session), hero: HeroCreate): - db_hero = Hero.model_validate(hero) + if IS_PYDANTIC_V2: + db_hero = Hero.model_validate(hero) + else: + db_hero = Hero.from_orm(hero) session.add(db_hero) session.commit() session.refresh(db_hero) diff --git a/docs_src/tutorial/fastapi/teams/tutorial001.py b/docs_src/tutorial/fastapi/teams/tutorial001.py index cc73bb52cb..46ea0f933c 100644 --- a/docs_src/tutorial/fastapi/teams/tutorial001.py +++ b/docs_src/tutorial/fastapi/teams/tutorial001.py @@ -2,6 +2,7 @@ from fastapi import Depends, FastAPI, HTTPException, Query from sqlmodel import Field, Relationship, Session, SQLModel, create_engine, select +from sqlmodel.compat import IS_PYDANTIC_V2 class TeamBase(SQLModel): @@ -83,7 +84,10 @@ def on_startup(): @app.post("/heroes/", response_model=HeroRead) def create_hero(*, session: Session = Depends(get_session), hero: HeroCreate): - db_hero = Hero.model_validate(hero) + if IS_PYDANTIC_V2: + db_hero = Hero.model_validate(hero) + else: + db_hero = Hero.from_orm(hero) session.add(db_hero) session.commit() session.refresh(db_hero) @@ -137,7 +141,10 @@ def delete_hero(*, session: Session = Depends(get_session), hero_id: int): @app.post("/teams/", response_model=TeamRead) def create_team(*, session: Session = Depends(get_session), team: TeamCreate): - db_team = Team.model_validate(team) + if IS_PYDANTIC_V2: + db_team = Team.model_validate(team) + else: + db_team = Team.from_orm(team) session.add(db_team) session.commit() session.refresh(db_team) diff --git a/docs_src/tutorial/fastapi/update/tutorial001.py b/docs_src/tutorial/fastapi/update/tutorial001.py index 28462bff17..93dfa7496a 100644 --- a/docs_src/tutorial/fastapi/update/tutorial001.py +++ b/docs_src/tutorial/fastapi/update/tutorial001.py @@ -2,6 +2,7 @@ from fastapi import FastAPI, HTTPException, Query from sqlmodel import Field, Session, SQLModel, create_engine, select +from sqlmodel.compat import IS_PYDANTIC_V2 class HeroBase(SQLModel): @@ -50,7 +51,10 @@ def on_startup(): @app.post("/heroes/", response_model=HeroRead) def create_hero(hero: HeroCreate): with Session(engine) as session: - db_hero = Hero.model_validate(hero) + if IS_PYDANTIC_V2: + db_hero = Hero.model_validate(hero) + else: + db_hero = Hero.from_orm(hero) session.add(db_hero) session.commit() session.refresh(db_hero) diff --git a/sqlmodel/compat.py b/sqlmodel/compat.py index 41abbef251..3ffcd1cd34 100644 --- a/sqlmodel/compat.py +++ b/sqlmodel/compat.py @@ -11,6 +11,7 @@ get_args, get_origin, ) +from types import NoneType from pydantic import VERSION as PYDANTIC_VERSION @@ -23,7 +24,8 @@ else: from pydantic import BaseConfig # noqa from pydantic.fields import ModelField # noqa - from pydantic.fields import Undefined as PydanticUndefined, UndefinedType as PydanticUndefinedType # noqa + from pydantic.fields import Undefined as PydanticUndefined, UndefinedType as PydanticUndefinedType, SHAPE_SINGLETON + from pydantic.typing import resolve_annotations if TYPE_CHECKING: from .main import FieldInfo, RelationshipInfo, SQLModel, SQLModelMetaclass @@ -34,32 +36,26 @@ InstanceOrType = Union[T, Type[T]] if IS_PYDANTIC_V2: + PydanticModelConfig = ConfigDict class SQLModelConfig(ConfigDict, total=False): table: Optional[bool] - read_from_attributes: Optional[bool] registry: Optional[Any] else: + PydanticModelConfig = BaseConfig class SQLModelConfig(BaseConfig): table: Optional[bool] = None - read_from_attributes: Optional[bool] = None registry: Optional[Any] = None - def __getitem__(self, item: str) -> Any: - return self.__getattr__(item) - - def __setitem__(self, item: str, value: Any) -> None: - return self.__setattr__(item, value) - # Inspired from https://github.com/roman-right/beanie/blob/main/beanie/odm/utils/pydantic.py def get_model_config(model: type) -> Optional[SQLModelConfig]: if IS_PYDANTIC_V2: return getattr(model, "model_config", None) else: - return getattr(model, "Config", None) + return getattr(model, "__config__", None) def get_config_value( @@ -68,7 +64,7 @@ def get_config_value( if IS_PYDANTIC_V2: return model.model_config.get(parameter, default) else: - return getattr(model.Config, parameter, default) + return getattr(model.__config__, parameter, default) def set_config_value( @@ -77,7 +73,7 @@ def set_config_value( if IS_PYDANTIC_V2: model.model_config[parameter] = value # type: ignore else: - model.Config[v1_parameter or parameter] = value # type: ignore + setattr(model.__config__, v1_parameter or parameter, value) # type: ignore def get_model_fields(model: InstanceOrType["SQLModel"]) -> Dict[str, "FieldInfo"]: @@ -109,6 +105,25 @@ def set_attribute_mode(cls: Type["SQLModelMetaclass"]) -> None: else: cls.__config__.read_with_orm_mode = True # type: ignore +def get_annotations(class_dict: dict[str, Any]) -> dict[str, Any]: + if IS_PYDANTIC_V2: + return class_dict.get("__annotations__", {}) + else: + return resolve_annotations(class_dict.get("__annotations__", {}),class_dict.get("__module__", None)) + +def is_table(class_dict: dict[str, Any]) -> bool: + config: SQLModelConfig = {} + if IS_PYDANTIC_V2: + config = class_dict.get("model_config", {}) + else: + config = class_dict.get("__config__", {}) + config_table = config.get("table", PydanticUndefined) + if config_table is not PydanticUndefined: + return config_table + kw_table = class_dict.get("table", PydanticUndefined) + if kw_table is not PydanticUndefined: + return kw_table + return False def get_relationship_to( name: str, @@ -167,3 +182,27 @@ def set_empty_defaults(annotations: Dict[str, Any], class_dict: Dict[str, Any]) # So we can check for nullable value.original_default = value.default value.default = None + +def is_field_noneable(field: "FieldInfo") -> bool: + if IS_PYDANTIC_V2: + if getattr(field, "nullable", PydanticUndefined) is not PydanticUndefined: + return field.nullable + if not field.is_required(): + default = getattr(field, "original_default", field.default) + if default is PydanticUndefined: + return False + if field.annotation is None or field.annotation is NoneType: + return True + if get_origin(field.annotation) is Union: + for base in get_args(field.annotation): + if base is NoneType: + return True + return False + return False + else: + if not field.required: + # Taken from [Pydantic](https://github.com/samuelcolvin/pydantic/blob/v1.8.2/pydantic/fields.py#L946-L947) + return field.allow_none and ( + field.shape != SHAPE_SINGLETON or not field.sub_fields + ) + return False \ No newline at end of file diff --git a/sqlmodel/engine/result.py b/sqlmodel/engine/result.py index 650dd92b27..2401609ae1 100644 --- a/sqlmodel/engine/result.py +++ b/sqlmodel/engine/result.py @@ -1,4 +1,4 @@ -from typing import Generic, Iterator, Optional, Sequence, Tuple, TypeVar +from typing import Generic, Iterator, List, Optional, TypeVar from sqlalchemy.engine.result import Result as _Result from sqlalchemy.engine.result import ScalarResult as _ScalarResult @@ -6,34 +6,73 @@ _T = TypeVar("_T") -class ScalarResult(_ScalarResult[_T], Generic[_T]): - def all(self) -> Sequence[_T]: +class ScalarResult(_ScalarResult, Generic[_T]): + def all(self) -> List[_T]: return super().all() - def partitions(self, size: Optional[int] = None) -> Iterator[Sequence[_T]]: + def partitions(self, size: Optional[int] = None) -> Iterator[List[_T]]: return super().partitions(size) - def fetchall(self) -> Sequence[_T]: + def fetchall(self) -> List[_T]: return super().fetchall() - def fetchmany(self, size: Optional[int] = None) -> Sequence[_T]: + def fetchmany(self, size: Optional[int] = None) -> List[_T]: return super().fetchmany(size) def __iter__(self) -> Iterator[_T]: return super().__iter__() def __next__(self) -> _T: - return super().__next__() + return super().__next__() # type: ignore def first(self) -> Optional[_T]: return super().first() - def one_or_none(self) -> Optional[_T]: return super().one_or_none() def one(self) -> _T: - return super().one() + return super().one() # type: ignore + + +class Result(_Result, Generic[_T]): + def scalars(self, index: int = 0) -> ScalarResult[_T]: + return super().scalars(index) # type: ignore + + def __iter__(self) -> Iterator[_T]: # type: ignore + return super().__iter__() # type: ignore + + def __next__(self) -> _T: # type: ignore + return super().__next__() # type: ignore + + def partitions(self, size: Optional[int] = None) -> Iterator[List[_T]]: # type: ignore + return super().partitions(size) # type: ignore + + def fetchall(self) -> List[_T]: # type: ignore + return super().fetchall() # type: ignore + + def fetchone(self) -> Optional[_T]: # type: ignore + return super().fetchone() # type: ignore + + def fetchmany(self, size: Optional[int] = None) -> List[_T]: # type: ignore + return super().fetchmany() # type: ignore + + def all(self) -> List[_T]: # type: ignore + return super().all() # type: ignore + + def first(self) -> Optional[_T]: # type: ignore + return super().first() # type: ignore + + def one_or_none(self) -> Optional[_T]: # type: ignore + return super().one_or_none() # type: ignore + + def scalar_one(self) -> _T: + return super().scalar_one() # type: ignore + + def scalar_one_or_none(self) -> Optional[_T]: + return super().scalar_one_or_none() + def one(self) -> _T: # type: ignore + return super().one() # type: ignore -class Result(_Result[Tuple[_T]], Generic[_T]): - ... + def scalar(self) -> Optional[_T]: + return super().scalar() \ No newline at end of file diff --git a/sqlmodel/ext/asyncio/session.py b/sqlmodel/ext/asyncio/session.py index f500c44dc2..d4678b0370 100644 --- a/sqlmodel/ext/asyncio/session.py +++ b/sqlmodel/ext/asyncio/session.py @@ -91,4 +91,4 @@ async def exec( execution_options=execution_options, bind_arguments=bind_arguments, **kw, - ) + ) \ No newline at end of file diff --git a/sqlmodel/main.py b/sqlmodel/main.py index d5b69b28a8..1100fb7da7 100644 --- a/sqlmodel/main.py +++ b/sqlmodel/main.py @@ -5,7 +5,6 @@ from decimal import Decimal from enum import Enum from pathlib import Path -from types import NoneType from typing import ( AbstractSet, Any, @@ -22,8 +21,6 @@ TypeVar, Union, cast, - get_args, - get_origin, overload, ) @@ -64,6 +61,10 @@ set_config_value, set_empty_defaults, set_fields_set, + is_table, + is_field_noneable, + PydanticModelConfig, + get_annotations ) from .sql.sqltypes import GUID, AutoString @@ -71,6 +72,7 @@ from pydantic.errors import ConfigError, DictError from pydantic.main import validate_model from pydantic.utils import ROOT_KEY + from pydantic.typing import resolve_annotations _T = TypeVar("_T") @@ -412,7 +414,7 @@ def __new__( ) -> Any: relationships: Dict[str, RelationshipInfo] = {} dict_for_pydantic = {} - original_annotations = class_dict.get("__annotations__", {}) + original_annotations = get_annotations(class_dict) pydantic_annotations = {} relationship_annotations = {} for k, v in class_dict.items(): @@ -436,19 +438,16 @@ def __new__( # superclass causing an error allowed_config_kwargs: Set[str] = { key - for key in dir(SQLModelConfig) + for key in dir(PydanticModelConfig) if not ( key.startswith("__") and key.endswith("__") ) # skip dunder methods and attributes } - pydantic_kwargs = kwargs.copy() config_kwargs = { - key: pydantic_kwargs.pop(key) - for key in pydantic_kwargs.keys() & allowed_config_kwargs + key: kwargs[key] + for key in kwargs.keys() & allowed_config_kwargs } - config_table = getattr( - class_dict.get("Config", object()), "table", False - ) or kwargs.get("table", False) + config_table = is_table(class_dict) if config_table: set_empty_defaults(pydantic_annotations, dict_used) @@ -607,7 +606,7 @@ def get_column_from_field(field: ModelField) -> Column: # type: ignore index = getattr(field.field_info, "index", PydanticUndefined) if index is PydanticUndefined: index = False - nullable = not primary_key and _is_field_noneable(field) + nullable = not primary_key and is_field_noneable(field) # Override derived nullability if the nullable property is set explicitly # on the field field_nullable = getattr(field.field_info, "nullable", PydanticUndefined) # noqa: B009 @@ -879,19 +878,3 @@ def _calculate_keys( return keys - -def _is_field_noneable(field: FieldInfo) -> bool: - if getattr(field, "nullable", PydanticUndefined) is not PydanticUndefined: - return field.nullable - if not field.is_required(): - default = getattr(field, "original_default", field.default) - if default is PydanticUndefined: - return False - if field.annotation is None or field.annotation is NoneType: - return True - if get_origin(field.annotation) is Union: - for base in get_args(field.annotation): - if base is NoneType: - return True - return False - return False diff --git a/sqlmodel/orm/session.py b/sqlmodel/orm/session.py index 29aba05eec..db5faf6d7f 100644 --- a/sqlmodel/orm/session.py +++ b/sqlmodel/orm/session.py @@ -1,21 +1,10 @@ -from typing import ( - Any, - Dict, - Literal, - Mapping, - Optional, - Sequence, - Type, - TypeVar, - Union, - overload, -) +from typing import Any, Mapping, Optional, Sequence, Type, TypeVar, Union, overload from sqlalchemy import util -from sqlalchemy.orm import Mapper as _Mapper from sqlalchemy.orm import Query as _Query from sqlalchemy.orm import Session as _Session from sqlalchemy.sql.base import Executable as _Executable +from typing_extensions import Literal from ..engine.result import Result, ScalarResult from ..sql.base import Executable @@ -32,7 +21,7 @@ def exec( *, params: Optional[Union[Mapping[str, Any], Sequence[Mapping[str, Any]]]] = None, execution_options: Mapping[str, Any] = util.EMPTY_DICT, - bind_arguments: Optional[Dict[str, Any]] = None, + bind_arguments: Optional[Mapping[str, Any]] = None, _parent_execute_state: Optional[Any] = None, _add_event: Optional[Any] = None, **kw: Any, @@ -46,7 +35,7 @@ def exec( *, params: Optional[Union[Mapping[str, Any], Sequence[Mapping[str, Any]]]] = None, execution_options: Mapping[str, Any] = util.EMPTY_DICT, - bind_arguments: Optional[Dict[str, Any]] = None, + bind_arguments: Optional[Mapping[str, Any]] = None, _parent_execute_state: Optional[Any] = None, _add_event: Optional[Any] = None, **kw: Any, @@ -63,7 +52,7 @@ def exec( *, params: Optional[Union[Mapping[str, Any], Sequence[Mapping[str, Any]]]] = None, execution_options: Mapping[str, Any] = util.EMPTY_DICT, - bind_arguments: Optional[Dict[str, Any]] = None, + bind_arguments: Optional[Mapping[str, Any]] = None, _parent_execute_state: Optional[Any] = None, _add_event: Optional[Any] = None, **kw: Any, @@ -85,8 +74,8 @@ def execute( self, statement: _Executable, params: Optional[Union[Mapping[str, Any], Sequence[Mapping[str, Any]]]] = None, - execution_options: Mapping[str, Any] = util.EMPTY_DICT, - bind_arguments: Optional[Dict[str, Any]] = None, + execution_options: Optional[Mapping[str, Any]] = util.EMPTY_DICT, + bind_arguments: Optional[Mapping[str, Any]] = None, _parent_execute_state: Optional[Any] = None, _add_event: Optional[Any] = None, **kw: Any, @@ -129,18 +118,17 @@ def query(self, *entities: Any, **kwargs: Any) -> "_Query[Any]": Or otherwise you might want to use `session.execute()` instead of `session.query()`. """ - return super().query(*entities, **kwargs) # type: ignore + return super().query(*entities, **kwargs) def get( self, - entity: Union[Type[_TSelectParam], "_Mapper[_TSelectParam]"], + entity: Type[_TSelectParam], ident: Any, options: Optional[Sequence[Any]] = None, populate_existing: bool = False, with_for_update: Optional[Union[Literal[True], Mapping[str, Any]]] = None, identity_token: Optional[Any] = None, - execution_options: Mapping[Any, Any] = util.EMPTY_DICT, - bind_arguments: Optional[Dict[str, Any]] = None, + execution_options: Optional[Mapping[Any, Any]] = util.EMPTY_DICT, ) -> Optional[_TSelectParam]: return super().get( entity, @@ -150,5 +138,4 @@ def get( with_for_update=with_for_update, identity_token=identity_token, execution_options=execution_options, - bind_arguments=bind_arguments, - ) + ) \ No newline at end of file diff --git a/sqlmodel/sql/expression.py b/sqlmodel/sql/expression.py index 8cb2309228..5d5cddaee3 100644 --- a/sqlmodel/sql/expression.py +++ b/sqlmodel/sql/expression.py @@ -23,15 +23,7 @@ _TSelect = TypeVar("_TSelect") -class Select(_Select[Tuple[_TSelect]], Generic[_TSelect]): - inherit_cache = True - - -# This is not comparable to sqlalchemy.sql.selectable.ScalarSelect, that has a different -# purpose. This is the same as a normal SQLAlchemy Select class where there's only one -# entity, so the result will be converted to a scalar by default. This way writing -# for loops on the results will feel natural. -class SelectOfScalar(_Select[Tuple[_TSelect]], Generic[_TSelect]): +class Select(_Select, Generic[_TSelect]): inherit_cache = True @@ -125,12 +117,12 @@ class SelectOfScalar(_Select, Generic[_TSelect]): @overload -def select(entity_0: _TScalar_0) -> SelectOfScalar[_TScalar_0]: # type: ignore +def select(entity_0: _TScalar_0, **kw: Any) -> SelectOfScalar[_TScalar_0]: # type: ignore ... @overload -def select(entity_0: Type[_TModel_0]) -> SelectOfScalar[_TModel_0]: # type: ignore +def select(entity_0: Type[_TModel_0], **kw: Any) -> SelectOfScalar[_TModel_0]: # type: ignore ... @@ -141,6 +133,7 @@ def select(entity_0: Type[_TModel_0]) -> SelectOfScalar[_TModel_0]: # type: ign def select( # type: ignore entity_0: _TScalar_0, entity_1: _TScalar_1, + **kw: Any, ) -> Select[Tuple[_TScalar_0, _TScalar_1]]: ... @@ -149,6 +142,7 @@ def select( # type: ignore def select( # type: ignore entity_0: _TScalar_0, entity_1: Type[_TModel_1], + **kw: Any, ) -> Select[Tuple[_TScalar_0, _TModel_1]]: ... @@ -157,6 +151,7 @@ def select( # type: ignore def select( # type: ignore entity_0: Type[_TModel_0], entity_1: _TScalar_1, + **kw: Any, ) -> Select[Tuple[_TModel_0, _TScalar_1]]: ... @@ -165,6 +160,7 @@ def select( # type: ignore def select( # type: ignore entity_0: Type[_TModel_0], entity_1: Type[_TModel_1], + **kw: Any, ) -> Select[Tuple[_TModel_0, _TModel_1]]: ... @@ -174,6 +170,7 @@ def select( # type: ignore entity_0: _TScalar_0, entity_1: _TScalar_1, entity_2: _TScalar_2, + **kw: Any, ) -> Select[Tuple[_TScalar_0, _TScalar_1, _TScalar_2]]: ... @@ -183,6 +180,7 @@ def select( # type: ignore entity_0: _TScalar_0, entity_1: _TScalar_1, entity_2: Type[_TModel_2], + **kw: Any, ) -> Select[Tuple[_TScalar_0, _TScalar_1, _TModel_2]]: ... @@ -192,6 +190,7 @@ def select( # type: ignore entity_0: _TScalar_0, entity_1: Type[_TModel_1], entity_2: _TScalar_2, + **kw: Any, ) -> Select[Tuple[_TScalar_0, _TModel_1, _TScalar_2]]: ... @@ -201,6 +200,7 @@ def select( # type: ignore entity_0: _TScalar_0, entity_1: Type[_TModel_1], entity_2: Type[_TModel_2], + **kw: Any, ) -> Select[Tuple[_TScalar_0, _TModel_1, _TModel_2]]: ... @@ -210,6 +210,7 @@ def select( # type: ignore entity_0: Type[_TModel_0], entity_1: _TScalar_1, entity_2: _TScalar_2, + **kw: Any, ) -> Select[Tuple[_TModel_0, _TScalar_1, _TScalar_2]]: ... @@ -219,6 +220,7 @@ def select( # type: ignore entity_0: Type[_TModel_0], entity_1: _TScalar_1, entity_2: Type[_TModel_2], + **kw: Any, ) -> Select[Tuple[_TModel_0, _TScalar_1, _TModel_2]]: ... @@ -228,6 +230,7 @@ def select( # type: ignore entity_0: Type[_TModel_0], entity_1: Type[_TModel_1], entity_2: _TScalar_2, + **kw: Any, ) -> Select[Tuple[_TModel_0, _TModel_1, _TScalar_2]]: ... @@ -237,6 +240,7 @@ def select( # type: ignore entity_0: Type[_TModel_0], entity_1: Type[_TModel_1], entity_2: Type[_TModel_2], + **kw: Any, ) -> Select[Tuple[_TModel_0, _TModel_1, _TModel_2]]: ... @@ -247,6 +251,7 @@ def select( # type: ignore entity_1: _TScalar_1, entity_2: _TScalar_2, entity_3: _TScalar_3, + **kw: Any, ) -> Select[Tuple[_TScalar_0, _TScalar_1, _TScalar_2, _TScalar_3]]: ... @@ -257,6 +262,7 @@ def select( # type: ignore entity_1: _TScalar_1, entity_2: _TScalar_2, entity_3: Type[_TModel_3], + **kw: Any, ) -> Select[Tuple[_TScalar_0, _TScalar_1, _TScalar_2, _TModel_3]]: ... @@ -267,6 +273,7 @@ def select( # type: ignore entity_1: _TScalar_1, entity_2: Type[_TModel_2], entity_3: _TScalar_3, + **kw: Any, ) -> Select[Tuple[_TScalar_0, _TScalar_1, _TModel_2, _TScalar_3]]: ... @@ -277,6 +284,7 @@ def select( # type: ignore entity_1: _TScalar_1, entity_2: Type[_TModel_2], entity_3: Type[_TModel_3], + **kw: Any, ) -> Select[Tuple[_TScalar_0, _TScalar_1, _TModel_2, _TModel_3]]: ... @@ -287,6 +295,7 @@ def select( # type: ignore entity_1: Type[_TModel_1], entity_2: _TScalar_2, entity_3: _TScalar_3, + **kw: Any, ) -> Select[Tuple[_TScalar_0, _TModel_1, _TScalar_2, _TScalar_3]]: ... @@ -297,6 +306,7 @@ def select( # type: ignore entity_1: Type[_TModel_1], entity_2: _TScalar_2, entity_3: Type[_TModel_3], + **kw: Any, ) -> Select[Tuple[_TScalar_0, _TModel_1, _TScalar_2, _TModel_3]]: ... @@ -307,6 +317,7 @@ def select( # type: ignore entity_1: Type[_TModel_1], entity_2: Type[_TModel_2], entity_3: _TScalar_3, + **kw: Any, ) -> Select[Tuple[_TScalar_0, _TModel_1, _TModel_2, _TScalar_3]]: ... @@ -317,6 +328,7 @@ def select( # type: ignore entity_1: Type[_TModel_1], entity_2: Type[_TModel_2], entity_3: Type[_TModel_3], + **kw: Any, ) -> Select[Tuple[_TScalar_0, _TModel_1, _TModel_2, _TModel_3]]: ... @@ -327,6 +339,7 @@ def select( # type: ignore entity_1: _TScalar_1, entity_2: _TScalar_2, entity_3: _TScalar_3, + **kw: Any, ) -> Select[Tuple[_TModel_0, _TScalar_1, _TScalar_2, _TScalar_3]]: ... @@ -337,6 +350,7 @@ def select( # type: ignore entity_1: _TScalar_1, entity_2: _TScalar_2, entity_3: Type[_TModel_3], + **kw: Any, ) -> Select[Tuple[_TModel_0, _TScalar_1, _TScalar_2, _TModel_3]]: ... @@ -347,6 +361,7 @@ def select( # type: ignore entity_1: _TScalar_1, entity_2: Type[_TModel_2], entity_3: _TScalar_3, + **kw: Any, ) -> Select[Tuple[_TModel_0, _TScalar_1, _TModel_2, _TScalar_3]]: ... @@ -357,6 +372,7 @@ def select( # type: ignore entity_1: _TScalar_1, entity_2: Type[_TModel_2], entity_3: Type[_TModel_3], + **kw: Any, ) -> Select[Tuple[_TModel_0, _TScalar_1, _TModel_2, _TModel_3]]: ... @@ -367,6 +383,7 @@ def select( # type: ignore entity_1: Type[_TModel_1], entity_2: _TScalar_2, entity_3: _TScalar_3, + **kw: Any, ) -> Select[Tuple[_TModel_0, _TModel_1, _TScalar_2, _TScalar_3]]: ... @@ -377,6 +394,7 @@ def select( # type: ignore entity_1: Type[_TModel_1], entity_2: _TScalar_2, entity_3: Type[_TModel_3], + **kw: Any, ) -> Select[Tuple[_TModel_0, _TModel_1, _TScalar_2, _TModel_3]]: ... @@ -387,6 +405,7 @@ def select( # type: ignore entity_1: Type[_TModel_1], entity_2: Type[_TModel_2], entity_3: _TScalar_3, + **kw: Any, ) -> Select[Tuple[_TModel_0, _TModel_1, _TModel_2, _TScalar_3]]: ... @@ -397,6 +416,7 @@ def select( # type: ignore entity_1: Type[_TModel_1], entity_2: Type[_TModel_2], entity_3: Type[_TModel_3], + **kw: Any, ) -> Select[Tuple[_TModel_0, _TModel_1, _TModel_2, _TModel_3]]: ... @@ -404,14 +424,14 @@ def select( # type: ignore # Generated overloads end -def select(*entities: Any) -> Union[Select, SelectOfScalar]: # type: ignore +def select(*entities: Any, **kw: Any) -> Union[Select, SelectOfScalar]: # type: ignore if len(entities) == 1: - return SelectOfScalar(*entities) # type: ignore - return Select(*entities) # type: ignore + return SelectOfScalar._create(*entities, **kw) # type: ignore + return Select._create(*entities, **kw) # type: ignore # TODO: add several @overload from Python types to SQLAlchemy equivalents def col(column_expression: Any) -> ColumnClause: # type: ignore if not isinstance(column_expression, (ColumnClause, Column, InstrumentedAttribute)): raise RuntimeError(f"Not a SQLAlchemy column: {column_expression}") - return column_expression # type: ignore + return column_expression \ No newline at end of file diff --git a/sqlmodel/sql/expression.py.jinja2 b/sqlmodel/sql/expression.py.jinja2 index 55f4a1ac3e..b3acb22c95 100644 --- a/sqlmodel/sql/expression.py.jinja2 +++ b/sqlmodel/sql/expression.py.jinja2 @@ -20,14 +20,7 @@ from sqlalchemy.sql.expression import Select as _Select _TSelect = TypeVar("_TSelect") -class Select(_Select[Tuple[_TSelect]], Generic[_TSelect]): - inherit_cache = True - -# This is not comparable to sqlalchemy.sql.selectable.ScalarSelect, that has a different -# purpose. This is the same as a normal SQLAlchemy Select class where there's only one -# entity, so the result will be converted to a scalar by default. This way writing -# for loops on the results will feel natural. -class SelectOfScalar(_Select[Tuple[_TSelect]], Generic[_TSelect]): +class Select(_Select, Generic[_TSelect]): inherit_cache = True # This is not comparable to sqlalchemy.sql.selectable.ScalarSelect, that has a different @@ -42,7 +35,6 @@ if TYPE_CHECKING: # pragma: no cover # Generated TypeVars start - {% for i in range(number_of_types) %} _TScalar_{{ i }} = TypeVar( "_TScalar_{{ i }}", @@ -66,12 +58,12 @@ _TModel_{{ i }} = TypeVar("_TModel_{{ i }}", bound="SQLModel") # Generated TypeVars end @overload -def select(entity_0: _TScalar_0) -> SelectOfScalar[_TScalar_0]: # type: ignore +def select(entity_0: _TScalar_0, **kw: Any) -> SelectOfScalar[_TScalar_0]: # type: ignore ... @overload -def select(entity_0: Type[_TModel_0]) -> SelectOfScalar[_TModel_0]: # type: ignore +def select(entity_0: Type[_TModel_0], **kw: Any) -> SelectOfScalar[_TModel_0]: # type: ignore ... @@ -81,7 +73,7 @@ def select(entity_0: Type[_TModel_0]) -> SelectOfScalar[_TModel_0]: # type: ign @overload def select( # type: ignore - {% for arg in signature[0] %}{{ arg.name }}: {{ arg.annotation }}, {% endfor %} + {% for arg in signature[0] %}{{ arg.name }}: {{ arg.annotation }}, {% endfor %}**kw: Any, ) -> Select[Tuple[{%for ret in signature[1] %}{{ ret }} {% if not loop.last %}, {% endif %}{% endfor %}]]: ... @@ -89,15 +81,14 @@ def select( # type: ignore # Generated overloads end - -def select(*entities: Any) -> Union[Select, SelectOfScalar]: # type: ignore +def select(*entities: Any, **kw: Any) -> Union[Select, SelectOfScalar]: # type: ignore if len(entities) == 1: - return SelectOfScalar(*entities) # type: ignore - return Select(*entities) # type: ignore + return SelectOfScalar._create(*entities, **kw) # type: ignore + return Select._create(*entities, **kw) # type: ignore # TODO: add several @overload from Python types to SQLAlchemy equivalents def col(column_expression: Any) -> ColumnClause: # type: ignore if not isinstance(column_expression, (ColumnClause, Column, InstrumentedAttribute)): raise RuntimeError(f"Not a SQLAlchemy column: {column_expression}") - return column_expression # type: ignore + return column_expression \ No newline at end of file diff --git a/sqlmodel/sql/sqltypes.py b/sqlmodel/sql/sqltypes.py index aa30950702..33bc45cdf4 100644 --- a/sqlmodel/sql/sqltypes.py +++ b/sqlmodel/sql/sqltypes.py @@ -15,7 +15,7 @@ class AutoString(types.TypeDecorator): # type: ignore def load_dialect_impl(self, dialect: Dialect) -> "types.TypeEngine[Any]": impl = cast(types.String, self.impl) if impl.length is None and dialect.name == "mysql": - return dialect.type_descriptor(types.String(self.mysql_default_length)) + return dialect.type_descriptor(types.String(self.mysql_default_length)) # type: ignore return super().load_dialect_impl(dialect) @@ -34,9 +34,9 @@ class GUID(types.TypeDecorator): # type: ignore def load_dialect_impl(self, dialect: Dialect) -> TypeEngine: # type: ignore if dialect.name == "postgresql": - return dialect.type_descriptor(UUID()) + return dialect.type_descriptor(UUID()) # type: ignore else: - return dialect.type_descriptor(CHAR(32)) + return dialect.type_descriptor(CHAR(32)) # type: ignore def process_bind_param(self, value: Any, dialect: Dialect) -> Optional[str]: if value is None: @@ -56,4 +56,4 @@ def process_result_value(self, value: Any, dialect: Dialect) -> Optional[uuid.UU else: if not isinstance(value, uuid.UUID): value = uuid.UUID(value) - return cast(uuid.UUID, value) + return cast(uuid.UUID, value) \ No newline at end of file diff --git a/tests/test_validation.py b/tests/test_validation.py index 4183986a06..be648c1015 100644 --- a/tests/test_validation.py +++ b/tests/test_validation.py @@ -5,8 +5,37 @@ from pydantic.error_wrappers import ValidationError from sqlmodel import SQLModel +from .conftest import needs_pydanticv1, needs_pydanticv2 -def test_validation(clear_sqlmodel): +@needs_pydanticv1 +def test_validation_pydantic_v1(clear_sqlmodel): + """Test validation of implicit and explicit None values. + + # For consistency with pydantic, validators are not to be called on + # arguments that are not explicitly provided. + + https://github.com/tiangolo/sqlmodel/issues/230 + https://github.com/samuelcolvin/pydantic/issues/1223 + + """ + + class Hero(SQLModel): + name: Optional[str] = None + secret_name: Optional[str] = None + age: Optional[int] = None + + @field_validator("name", "secret_name", "age") + def reject_none(cls, v): + assert v is not None + return v + + Hero.from_orm({"age": 25}) + + with pytest.raises(ValidationError): + Hero.from_orm({"name": None, "age": 25}) + +@needs_pydanticv2 +def test_validation_pydantic_v2(clear_sqlmodel): """Test validation of implicit and explicit None values. # For consistency with pydantic, validators are not to be called on From 254fb132eef75d73c29527218354114d8ac52186 Mon Sep 17 00:00:00 2001 From: Anton De Meester Date: Tue, 14 Nov 2023 08:50:46 +0000 Subject: [PATCH 26/29] Liniting --- .../fastapi/app_testing/tutorial001/main.py | 1 + .../tutorial/fastapi/read_one/tutorial001.py | 2 +- sqlmodel/compat.py | 39 ++++++++++++------- sqlmodel/engine/result.py | 3 +- sqlmodel/ext/asyncio/session.py | 2 +- sqlmodel/main.py | 39 +++++++++---------- sqlmodel/orm/session.py | 2 +- sqlmodel/sql/expression.py | 2 +- sqlmodel/sql/sqltypes.py | 2 +- tests/test_validation.py | 2 + 10 files changed, 53 insertions(+), 41 deletions(-) diff --git a/docs_src/tutorial/fastapi/app_testing/tutorial001/main.py b/docs_src/tutorial/fastapi/app_testing/tutorial001/main.py index cf2f4da233..f305f75194 100644 --- a/docs_src/tutorial/fastapi/app_testing/tutorial001/main.py +++ b/docs_src/tutorial/fastapi/app_testing/tutorial001/main.py @@ -52,6 +52,7 @@ def get_session(): def on_startup(): create_db_and_tables() + @app.post("/heroes/", response_model=HeroRead) def create_hero(*, session: Session = Depends(get_session), hero: HeroCreate): if IS_PYDANTIC_V2: diff --git a/docs_src/tutorial/fastapi/read_one/tutorial001.py b/docs_src/tutorial/fastapi/read_one/tutorial001.py index 4cdf898922..aa805d6c8f 100644 --- a/docs_src/tutorial/fastapi/read_one/tutorial001.py +++ b/docs_src/tutorial/fastapi/read_one/tutorial001.py @@ -2,9 +2,9 @@ from fastapi import FastAPI, HTTPException from sqlmodel import Field, Session, SQLModel, create_engine, select - from sqlmodel.compat import IS_PYDANTIC_V2 + class HeroBase(SQLModel): name: str = Field(index=True) secret_name: str diff --git a/sqlmodel/compat.py b/sqlmodel/compat.py index 3ffcd1cd34..37a7c3716a 100644 --- a/sqlmodel/compat.py +++ b/sqlmodel/compat.py @@ -1,3 +1,4 @@ +from types import NoneType from typing import ( TYPE_CHECKING, Any, @@ -11,7 +12,6 @@ get_args, get_origin, ) -from types import NoneType from pydantic import VERSION as PYDANTIC_VERSION @@ -20,11 +20,12 @@ if IS_PYDANTIC_V2: from pydantic import ConfigDict - from pydantic_core import PydanticUndefined as PydanticUndefined, PydanticUndefinedType as PydanticUndefinedType # noqa + from pydantic_core import PydanticUndefined as PydanticUndefined # noqa + from pydantic_core import PydanticUndefinedType as PydanticUndefinedType else: - from pydantic import BaseConfig # noqa - from pydantic.fields import ModelField # noqa - from pydantic.fields import Undefined as PydanticUndefined, UndefinedType as PydanticUndefinedType, SHAPE_SINGLETON + from pydantic import BaseConfig # noqa + from pydantic.fields import ModelField # noqa + from pydantic.fields import Undefined as PydanticUndefined, SHAPE_SINGLETON from pydantic.typing import resolve_annotations if TYPE_CHECKING: @@ -68,26 +69,29 @@ def get_config_value( def set_config_value( - model: InstanceOrType["SQLModel"], parameter: str, value: Any, v1_parameter: str = None + model: InstanceOrType["SQLModel"], + parameter: str, + value: Any, + v1_parameter: str = None, ) -> None: if IS_PYDANTIC_V2: - model.model_config[parameter] = value # type: ignore + model.model_config[parameter] = value # type: ignore else: setattr(model.__config__, v1_parameter or parameter, value) # type: ignore def get_model_fields(model: InstanceOrType["SQLModel"]) -> Dict[str, "FieldInfo"]: if IS_PYDANTIC_V2: - return model.model_fields # type: ignore + return model.model_fields # type: ignore else: - return model.__fields__ # type: ignore + return model.__fields__ # type: ignore def get_fields_set(model: InstanceOrType["SQLModel"]) -> set[str]: if IS_PYDANTIC_V2: return model.__pydantic_fields_set__ else: - return model.__fields_set__ # type: ignore + return model.__fields_set__ # type: ignore def set_fields_set( @@ -103,13 +107,17 @@ def set_attribute_mode(cls: Type["SQLModelMetaclass"]) -> None: if IS_PYDANTIC_V2: cls.model_config["read_from_attributes"] = True else: - cls.__config__.read_with_orm_mode = True # type: ignore + cls.__config__.read_with_orm_mode = True # type: ignore + def get_annotations(class_dict: dict[str, Any]) -> dict[str, Any]: if IS_PYDANTIC_V2: return class_dict.get("__annotations__", {}) else: - return resolve_annotations(class_dict.get("__annotations__", {}),class_dict.get("__module__", None)) + return resolve_annotations( + class_dict.get("__annotations__", {}), class_dict.get("__module__", None) + ) + def is_table(class_dict: dict[str, Any]) -> bool: config: SQLModelConfig = {} @@ -125,6 +133,7 @@ def is_table(class_dict: dict[str, Any]) -> bool: return kw_table return False + def get_relationship_to( name: str, rel_info: "RelationshipInfo", @@ -170,6 +179,7 @@ def set_empty_defaults(annotations: Dict[str, Any], class_dict: Dict[str, Any]) """ if IS_PYDANTIC_V2: from .main import FieldInfo + # Pydantic v2 sets a __pydantic_core_schema__ which is very hard to change. Changing the fields does not do anything for key in annotations.keys(): value = class_dict.get(key, PydanticUndefined) @@ -180,9 +190,10 @@ def set_empty_defaults(annotations: Dict[str, Any], class_dict: Dict[str, Any]) value.default in (PydanticUndefined, Ellipsis) ) and value.default_factory is None: # So we can check for nullable - value.original_default = value.default + value.original_default = value.default value.default = None + def is_field_noneable(field: "FieldInfo") -> bool: if IS_PYDANTIC_V2: if getattr(field, "nullable", PydanticUndefined) is not PydanticUndefined: @@ -205,4 +216,4 @@ def is_field_noneable(field: "FieldInfo") -> bool: return field.allow_none and ( field.shape != SHAPE_SINGLETON or not field.sub_fields ) - return False \ No newline at end of file + return False diff --git a/sqlmodel/engine/result.py b/sqlmodel/engine/result.py index 2401609ae1..7a25422227 100644 --- a/sqlmodel/engine/result.py +++ b/sqlmodel/engine/result.py @@ -27,6 +27,7 @@ def __next__(self) -> _T: def first(self) -> Optional[_T]: return super().first() + def one_or_none(self) -> Optional[_T]: return super().one_or_none() @@ -75,4 +76,4 @@ def one(self) -> _T: # type: ignore return super().one() # type: ignore def scalar(self) -> Optional[_T]: - return super().scalar() \ No newline at end of file + return super().scalar() diff --git a/sqlmodel/ext/asyncio/session.py b/sqlmodel/ext/asyncio/session.py index d4678b0370..f500c44dc2 100644 --- a/sqlmodel/ext/asyncio/session.py +++ b/sqlmodel/ext/asyncio/session.py @@ -91,4 +91,4 @@ async def exec( execution_options=execution_options, bind_arguments=bind_arguments, **kw, - ) \ No newline at end of file + ) diff --git a/sqlmodel/main.py b/sqlmodel/main.py index 1100fb7da7..7a05ef3888 100644 --- a/sqlmodel/main.py +++ b/sqlmodel/main.py @@ -52,19 +52,19 @@ from .compat import ( IS_PYDANTIC_V2, NoArgAnyCallable, + PydanticModelConfig, PydanticUndefined, PydanticUndefinedType, SQLModelConfig, + get_annotations, get_config_value, get_model_fields, get_relationship_to, + is_field_noneable, + is_table, set_config_value, set_empty_defaults, set_fields_set, - is_table, - is_field_noneable, - PydanticModelConfig, - get_annotations ) from .sql.sqltypes import GUID, AutoString @@ -72,7 +72,6 @@ from pydantic.errors import ConfigError, DictError from pydantic.main import validate_model from pydantic.utils import ROOT_KEY - from pydantic.typing import resolve_annotations _T = TypeVar("_T") @@ -444,8 +443,7 @@ def __new__( ) # skip dunder methods and attributes } config_kwargs = { - key: kwargs[key] - for key in kwargs.keys() & allowed_config_kwargs + key: kwargs[key] for key in kwargs.keys() & allowed_config_kwargs } config_table = is_table(class_dict) if config_table: @@ -690,7 +688,7 @@ def __init__(__pydantic_self__, **data: Any) -> None: # settable attribute if IS_PYDANTIC_V2: old_dict = __pydantic_self__.__dict__.copy() - __pydantic_self__.super().__init__(**data) # noqa + __pydantic_self__.super().__init__(**data) # noqa __pydantic_self__.__dict__ = {**old_dict, **__pydantic_self__.__dict__} non_pydantic_keys = data.keys() - __pydantic_self__.model_fields else: @@ -699,7 +697,7 @@ def __init__(__pydantic_self__, **data: Any) -> None: ) # Only raise errors if not a SQLModel model if ( - not getattr(__pydantic_self__.__config__, "table", False) # noqa + not getattr(__pydantic_self__.__config__, "table", False) # noqa and validation_error ): raise validation_error @@ -764,7 +762,7 @@ def from_orm( cls: Type[_TSQLModel], obj: Any, update: Optional[Dict[str, Any]] = None ) -> _TSQLModel: # Duplicated from Pydantic - if not cls.__config__.orm_mode: # noqa: attr-defined + if not cls.__config__.orm_mode: # noqa: attr-defined raise ConfigError( "You must have the config attribute orm_mode=True to use from_orm" ) @@ -777,7 +775,7 @@ def from_orm( if update is not None: obj = {**obj, **update} # End SQLModel support dict - if not getattr(cls.__config__, "table", False): # noqa + if not getattr(cls.__config__, "table", False): # noqa # If not table, normal Pydantic code m: _TSQLModel = cls.__new__(cls) else: @@ -788,21 +786,21 @@ def from_orm( if validation_error: raise validation_error # Updated to trigger SQLAlchemy internal handling - if not getattr(cls.__config__, "table", False): # noqa + if not getattr(cls.__config__, "table", False): # noqa object.__setattr__(m, "__dict__", values) else: for key, value in values.items(): setattr(m, key, value) # Continue with standard Pydantic logic object.__setattr__(m, "__fields_set__", fields_set) - m._init_private_attributes() # noqa + m._init_private_attributes() # noqa return m @classmethod def parse_obj( cls: Type[_TSQLModel], obj: Any, update: Optional[Dict[str, Any]] = None ) -> _TSQLModel: - obj = cls._enforce_dict_if_root(obj) # noqa + obj = cls._enforce_dict_if_root(obj) # noqa # SQLModel, support update dict if update is not None: obj = {**obj, **update} @@ -814,7 +812,7 @@ def parse_obj( def validate(cls: Type[_TSQLModel], value: Any) -> _TSQLModel: if isinstance(value, cls): return ( - value.copy() if cls.__config__.copy_on_model_validation else value # noqa + value.copy() if cls.__config__.copy_on_model_validation else value # noqa ) value = cls._enforce_dict_if_root(value) @@ -826,9 +824,9 @@ def validate(cls: Type[_TSQLModel], value: Any) -> _TSQLModel: # Reset fields set, this would have been done in Pydantic in __init__ object.__setattr__(model, "__fields_set__", fields_set) return model - elif cls.__config__.orm_mode: # noqa + elif cls.__config__.orm_mode: # noqa return cls.from_orm(value) - elif cls.__custom_root_type__: # noqa + elif cls.__custom_root_type__: # noqa return cls.parse_obj(value) else: try: @@ -852,12 +850,12 @@ def _calculate_keys( # Do not include relationships as that would easily lead to infinite # recursion, or traversing the whole database return ( - self.__fields__.keys() # noqa + self.__fields__.keys() # noqa ) # | self.__sqlmodel_relationships__.keys() keys: AbstractSet[str] if exclude_unset: - keys = self.__fields_set__.copy() # noqa + keys = self.__fields_set__.copy() # noqa else: # Original in Pydantic: # keys = self.__dict__.keys() @@ -865,7 +863,7 @@ def _calculate_keys( # Do not include relationships as that would easily lead to infinite # recursion, or traversing the whole database keys = ( - self.__fields__.keys() # noqa + self.__fields__.keys() # noqa ) # | self.__sqlmodel_relationships__.keys() if include is not None: keys &= include.keys() @@ -877,4 +875,3 @@ def _calculate_keys( keys -= {k for k, v in exclude.items() if _value_items_is_true(v)} return keys - diff --git a/sqlmodel/orm/session.py b/sqlmodel/orm/session.py index db5faf6d7f..0c70c290ae 100644 --- a/sqlmodel/orm/session.py +++ b/sqlmodel/orm/session.py @@ -138,4 +138,4 @@ def get( with_for_update=with_for_update, identity_token=identity_token, execution_options=execution_options, - ) \ No newline at end of file + ) diff --git a/sqlmodel/sql/expression.py b/sqlmodel/sql/expression.py index 5d5cddaee3..264e39cba7 100644 --- a/sqlmodel/sql/expression.py +++ b/sqlmodel/sql/expression.py @@ -434,4 +434,4 @@ def select(*entities: Any, **kw: Any) -> Union[Select, SelectOfScalar]: # type: def col(column_expression: Any) -> ColumnClause: # type: ignore if not isinstance(column_expression, (ColumnClause, Column, InstrumentedAttribute)): raise RuntimeError(f"Not a SQLAlchemy column: {column_expression}") - return column_expression \ No newline at end of file + return column_expression diff --git a/sqlmodel/sql/sqltypes.py b/sqlmodel/sql/sqltypes.py index 33bc45cdf4..17d9b06126 100644 --- a/sqlmodel/sql/sqltypes.py +++ b/sqlmodel/sql/sqltypes.py @@ -56,4 +56,4 @@ def process_result_value(self, value: Any, dialect: Dialect) -> Optional[uuid.UU else: if not isinstance(value, uuid.UUID): value = uuid.UUID(value) - return cast(uuid.UUID, value) \ No newline at end of file + return cast(uuid.UUID, value) diff --git a/tests/test_validation.py b/tests/test_validation.py index be648c1015..e200a1e73d 100644 --- a/tests/test_validation.py +++ b/tests/test_validation.py @@ -7,6 +7,7 @@ from .conftest import needs_pydanticv1, needs_pydanticv2 + @needs_pydanticv1 def test_validation_pydantic_v1(clear_sqlmodel): """Test validation of implicit and explicit None values. @@ -34,6 +35,7 @@ def reject_none(cls, v): with pytest.raises(ValidationError): Hero.from_orm({"name": None, "age": 25}) + @needs_pydanticv2 def test_validation_pydantic_v2(clear_sqlmodel): """Test validation of implicit and explicit None values. From ab075145f34b6fea1b6600a63ff00ae0411fce8d Mon Sep 17 00:00:00 2001 From: Anton De Meester Date: Tue, 14 Nov 2023 08:57:30 +0000 Subject: [PATCH 27/29] Linter --- sqlmodel/compat.py | 26 +++++++++++--------------- sqlmodel/main.py | 1 - 2 files changed, 11 insertions(+), 16 deletions(-) diff --git a/sqlmodel/compat.py b/sqlmodel/compat.py index 37a7c3716a..e9a63f6144 100644 --- a/sqlmodel/compat.py +++ b/sqlmodel/compat.py @@ -19,13 +19,13 @@ if IS_PYDANTIC_V2: - from pydantic import ConfigDict + from pydantic import ConfigDict as PydanticModelConfig from pydantic_core import PydanticUndefined as PydanticUndefined # noqa from pydantic_core import PydanticUndefinedType as PydanticUndefinedType else: - from pydantic import BaseConfig # noqa + from pydantic import BaseConfig as PydanticModelConfig from pydantic.fields import ModelField # noqa - from pydantic.fields import Undefined as PydanticUndefined, SHAPE_SINGLETON + from pydantic.fields import Undefined as PydanticUndefined, UndefinedType as PydanticUndefinedType, SHAPE_SINGLETON # noqa from pydantic.typing import resolve_annotations if TYPE_CHECKING: @@ -37,16 +37,12 @@ InstanceOrType = Union[T, Type[T]] if IS_PYDANTIC_V2: - PydanticModelConfig = ConfigDict - - class SQLModelConfig(ConfigDict, total=False): + class SQLModelConfig(PydanticModelConfig, total=False): table: Optional[bool] registry: Optional[Any] else: - PydanticModelConfig = BaseConfig - - class SQLModelConfig(BaseConfig): + class SQLModelConfig(PydanticModelConfig): table: Optional[bool] = None registry: Optional[Any] = None @@ -72,7 +68,7 @@ def set_config_value( model: InstanceOrType["SQLModel"], parameter: str, value: Any, - v1_parameter: str = None, + v1_parameter: Optional[str] = None, ) -> None: if IS_PYDANTIC_V2: model.model_config[parameter] = value # type: ignore @@ -82,14 +78,14 @@ def set_config_value( def get_model_fields(model: InstanceOrType["SQLModel"]) -> Dict[str, "FieldInfo"]: if IS_PYDANTIC_V2: - return model.model_fields # type: ignore + return model.model_fields # type: ignore else: return model.__fields__ # type: ignore def get_fields_set(model: InstanceOrType["SQLModel"]) -> set[str]: if IS_PYDANTIC_V2: - return model.__pydantic_fields_set__ + return model.__pydantic_fields_set__ # type: ignore else: return model.__fields_set__ # type: ignore @@ -127,10 +123,10 @@ def is_table(class_dict: dict[str, Any]) -> bool: config = class_dict.get("__config__", {}) config_table = config.get("table", PydanticUndefined) if config_table is not PydanticUndefined: - return config_table + return config_table # type: ignore kw_table = class_dict.get("table", PydanticUndefined) if kw_table is not PydanticUndefined: - return kw_table + return kw_table # type: ignore return False @@ -197,7 +193,7 @@ def set_empty_defaults(annotations: Dict[str, Any], class_dict: Dict[str, Any]) def is_field_noneable(field: "FieldInfo") -> bool: if IS_PYDANTIC_V2: if getattr(field, "nullable", PydanticUndefined) is not PydanticUndefined: - return field.nullable + return field.nullable # type: ignore if not field.is_required(): default = getattr(field, "original_default", field.default) if default is PydanticUndefined: diff --git a/sqlmodel/main.py b/sqlmodel/main.py index 7a05ef3888..435b9f31f3 100644 --- a/sqlmodel/main.py +++ b/sqlmodel/main.py @@ -739,7 +739,6 @@ def __tablename__(cls) -> str: return cls.__name__.lower() if IS_PYDANTIC_V2: - @classmethod def model_validate( cls: type[_TSQLModel], From 4a4161a52776de966f2b4008a72acb51637f778d Mon Sep 17 00:00:00 2001 From: Anton De Meester Date: Tue, 14 Nov 2023 09:03:11 +0000 Subject: [PATCH 28/29] Move lint to after tests to see tests --- .github/workflows/test.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 979a08b2b2..3b7cfaf95a 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -58,8 +58,6 @@ jobs: - name: Install Pydantic v2 if: matrix.pydantic-version == 'pydantic-v2' run: pip install "pydantic>=2.0.2,<3.0.0" - - name: Lint - run: python -m poetry run bash scripts/lint.sh - run: mkdir coverage - name: Test run: python -m poetry run bash scripts/test.sh @@ -71,6 +69,8 @@ jobs: with: name: coverage path: coverage + - name: Lint + run: python -m poetry run bash scripts/lint.sh coverage-combine: needs: - test From e2d4d1fa833c2ea33c64159ef35d8ce614270e28 Mon Sep 17 00:00:00 2001 From: Anton De Meester Date: Thu, 16 Nov 2023 09:50:27 +0000 Subject: [PATCH 29/29] Make most tests succeed Only need to fix OPEN API things I think --- pyproject.toml | 2 +- sqlmodel/compat.py | 229 +++++++++++++++++++++++++++++++-- sqlmodel/main.py | 143 ++------------------ tests/test_enums.py | 47 ++++++- tests/test_instance_no_args.py | 37 +++++- tests/test_nullable.py | 2 +- tests/test_validation.py | 9 +- 7 files changed, 314 insertions(+), 155 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 6777d69018..f104631655 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -46,7 +46,7 @@ pillow = "^9.3.0" cairosvg = "^2.5.2" mdx-include = "^1.4.1" coverage = {extras = ["toml"], version = ">=6.2,<8.0"} -fastapi = "^0.68.1" +fastapi = "^0.100.0" requests = "^2.26.0" ruff = "^0.1.2" diff --git a/sqlmodel/compat.py b/sqlmodel/compat.py index e9a63f6144..dbd22053a8 100644 --- a/sqlmodel/compat.py +++ b/sqlmodel/compat.py @@ -1,3 +1,9 @@ +import ipaddress +import uuid +from datetime import date, datetime, time, timedelta +from decimal import Decimal +from enum import Enum +from pathlib import Path from types import NoneType from typing import ( TYPE_CHECKING, @@ -6,26 +12,47 @@ Dict, ForwardRef, Optional, + Sequence, Type, TypeVar, Union, + cast, get_args, get_origin, ) from pydantic import VERSION as PYDANTIC_VERSION +from sqlalchemy import ( + Boolean, + Column, + Date, + DateTime, + Float, + ForeignKey, + Integer, + Interval, + Numeric, +) +from sqlalchemy import Enum as sa_Enum +from sqlalchemy.sql.sqltypes import LargeBinary, Time + +from .sql.sqltypes import GUID, AutoString IS_PYDANTIC_V2 = int(PYDANTIC_VERSION.split(".")[0]) >= 2 if IS_PYDANTIC_V2: from pydantic import ConfigDict as PydanticModelConfig + from pydantic._internal._fields import PydanticMetadata + from pydantic._internal._model_construction import ModelMetaclass from pydantic_core import PydanticUndefined as PydanticUndefined # noqa from pydantic_core import PydanticUndefinedType as PydanticUndefinedType else: from pydantic import BaseConfig as PydanticModelConfig - from pydantic.fields import ModelField # noqa - from pydantic.fields import Undefined as PydanticUndefined, UndefinedType as PydanticUndefinedType, SHAPE_SINGLETON # noqa + from pydantic.fields import SHAPE_SINGLETON, ModelField + from pydantic.fields import Undefined as PydanticUndefined # noqa + from pydantic.fields import UndefinedType as PydanticUndefinedType + from pydantic.main import ModelMetaclass as ModelMetaclass from pydantic.typing import resolve_annotations if TYPE_CHECKING: @@ -37,11 +64,13 @@ InstanceOrType = Union[T, Type[T]] if IS_PYDANTIC_V2: + class SQLModelConfig(PydanticModelConfig, total=False): table: Optional[bool] registry: Optional[Any] else: + class SQLModelConfig(PydanticModelConfig): table: Optional[bool] = None registry: Optional[Any] = None @@ -78,14 +107,14 @@ def set_config_value( def get_model_fields(model: InstanceOrType["SQLModel"]) -> Dict[str, "FieldInfo"]: if IS_PYDANTIC_V2: - return model.model_fields # type: ignore + return model.model_fields # type: ignore else: return model.__fields__ # type: ignore def get_fields_set(model: InstanceOrType["SQLModel"]) -> set[str]: if IS_PYDANTIC_V2: - return model.__pydantic_fields_set__ # type: ignore + return model.__pydantic_fields_set__ # type: ignore else: return model.__fields_set__ # type: ignore @@ -115,7 +144,9 @@ def get_annotations(class_dict: dict[str, Any]) -> dict[str, Any]: ) -def is_table(class_dict: dict[str, Any]) -> bool: +def class_dict_is_table( + class_dict: dict[str, Any], class_kwargs: dict[str, Any] +) -> bool: config: SQLModelConfig = {} if IS_PYDANTIC_V2: config = class_dict.get("model_config", {}) @@ -123,13 +154,26 @@ def is_table(class_dict: dict[str, Any]) -> bool: config = class_dict.get("__config__", {}) config_table = config.get("table", PydanticUndefined) if config_table is not PydanticUndefined: - return config_table # type: ignore - kw_table = class_dict.get("table", PydanticUndefined) + return config_table # type: ignore + kw_table = class_kwargs.get("table", PydanticUndefined) if kw_table is not PydanticUndefined: - return kw_table # type: ignore + return kw_table # type: ignore return False +def cls_is_table(cls: Type) -> bool: + if IS_PYDANTIC_V2: + config = getattr(cls, "model_config", None) + if not config: + return False + return config.get("table", False) + else: + config = getattr(cls, "__config__", None) + if not config: + return False + return getattr(config, "table", False) + + def get_relationship_to( name: str, rel_info: "RelationshipInfo", @@ -186,17 +230,15 @@ def set_empty_defaults(annotations: Dict[str, Any], class_dict: Dict[str, Any]) value.default in (PydanticUndefined, Ellipsis) ) and value.default_factory is None: # So we can check for nullable - value.original_default = value.default value.default = None -def is_field_noneable(field: "FieldInfo") -> bool: +def _is_field_noneable(field: "FieldInfo") -> bool: if IS_PYDANTIC_V2: if getattr(field, "nullable", PydanticUndefined) is not PydanticUndefined: - return field.nullable # type: ignore + return field.nullable # type: ignore if not field.is_required(): - default = getattr(field, "original_default", field.default) - if default is PydanticUndefined: + if field.default is PydanticUndefined: return False if field.annotation is None or field.annotation is NoneType: return True @@ -212,4 +254,163 @@ def is_field_noneable(field: "FieldInfo") -> bool: return field.allow_none and ( field.shape != SHAPE_SINGLETON or not field.sub_fields ) - return False + return field.allow_none + + +def get_sqlalchemy_type(field: Any) -> Any: + if IS_PYDANTIC_V2: + field_info = field + else: + field_info = field.field_info + sa_type = getattr(field_info, "sa_type", PydanticUndefined) # noqa: B009 + if sa_type is not PydanticUndefined: + return sa_type + + type_ = get_type_from_field(field) + metadata = get_field_metadata(field) + + # Check enums first as an enum can also be a str, needed by Pydantic/FastAPI + if issubclass(type_, Enum): + return sa_Enum(type_) + if issubclass(type_, str): + max_length = getattr(metadata, "max_length", None) + if max_length: + return AutoString(length=max_length) + return AutoString + if issubclass(type_, float): + return Float + if issubclass(type_, bool): + return Boolean + if issubclass(type_, int): + return Integer + if issubclass(type_, datetime): + return DateTime + if issubclass(type_, date): + return Date + if issubclass(type_, timedelta): + return Interval + if issubclass(type_, time): + return Time + if issubclass(type_, bytes): + return LargeBinary + if issubclass(type_, Decimal): + return Numeric( + precision=getattr(metadata, "max_digits", None), + scale=getattr(metadata, "decimal_places", None), + ) + if issubclass(type_, ipaddress.IPv4Address): + return AutoString + if issubclass(type_, ipaddress.IPv4Network): + return AutoString + if issubclass(type_, ipaddress.IPv6Address): + return AutoString + if issubclass(type_, ipaddress.IPv6Network): + return AutoString + if issubclass(type_, Path): + return AutoString + if issubclass(type_, uuid.UUID): + return GUID + raise ValueError(f"{type_} has no matching SQLAlchemy type") + + +def get_type_from_field(field: Any) -> type: + if IS_PYDANTIC_V2: + type_: type | None = field.annotation + # Resolve Optional fields + if type_ is None: + raise ValueError("Missing field type") + origin = get_origin(type_) + if origin is None: + return type_ + if origin is Union: + bases = get_args(type_) + if len(bases) > 2: + raise ValueError( + "Cannot have a (non-optional) union as a SQL alchemy field" + ) + # Non optional unions are not allowed + if bases[0] is not NoneType and bases[1] is not NoneType: + raise ValueError( + "Cannot have a (non-optional) union as a SQL alchemy field" + ) + # Optional unions are allowed + return bases[0] if bases[0] is not NoneType else bases[1] + return origin + else: + if isinstance(field.type_, type) and field.shape == SHAPE_SINGLETON: + return field.type_ + raise ValueError(f"The field {field.name} has no matching SQLAlchemy type") + + +class FakeMetadata: + max_length: Optional[int] = None + max_digits: Optional[int] = None + decimal_places: Optional[int] = None + + +def get_field_metadata(field: Any) -> Any: + if IS_PYDANTIC_V2: + for meta in field.metadata: + if isinstance(meta, PydanticMetadata): + return meta + return FakeMetadata() + else: + metadata = FakeMetadata() + metadata.max_length = field.field_info.max_length + metadata.max_digits = getattr(field.type_, "max_digits", None) + metadata.decimal_places = getattr(field.type_, "decimal_places", None) + return metadata + + +def get_column_from_field(field: Any) -> Column: # type: ignore + if IS_PYDANTIC_V2: + field_info = field + else: + field_info = field.field_info + sa_column = getattr(field_info, "sa_column", PydanticUndefined) + if isinstance(sa_column, Column): + return sa_column + sa_type = get_sqlalchemy_type(field) + primary_key = getattr(field_info, "primary_key", PydanticUndefined) + if primary_key is PydanticUndefined: + primary_key = False + index = getattr(field_info, "index", PydanticUndefined) + if index is PydanticUndefined: + index = False + nullable = not primary_key and _is_field_noneable(field) + # Override derived nullability if the nullable property is set explicitly + # on the field + field_nullable = getattr(field_info, "nullable", PydanticUndefined) # noqa: B009 + if field_nullable is not PydanticUndefined: + assert not isinstance(field_nullable, PydanticUndefinedType) + nullable = field_nullable + args = [] + foreign_key = getattr(field_info, "foreign_key", PydanticUndefined) + if foreign_key is PydanticUndefined: + foreign_key = None + unique = getattr(field_info, "unique", PydanticUndefined) + if unique is PydanticUndefined: + unique = False + if foreign_key: + assert isinstance(foreign_key, str) + args.append(ForeignKey(foreign_key)) + kwargs = { + "primary_key": primary_key, + "nullable": nullable, + "index": index, + "unique": unique, + } + sa_default = PydanticUndefined + if field_info.default_factory: + sa_default = field_info.default_factory + elif field_info.default is not PydanticUndefined: + sa_default = field_info.default + if sa_default is not PydanticUndefined: + kwargs["default"] = sa_default + sa_column_args = getattr(field_info, "sa_column_args", PydanticUndefined) + if sa_column_args is not PydanticUndefined: + args.extend(list(cast(Sequence[Any], sa_column_args))) + sa_column_kwargs = getattr(field_info, "sa_column_kwargs", PydanticUndefined) + if sa_column_kwargs is not PydanticUndefined: + kwargs.update(cast(Dict[Any, Any], sa_column_kwargs)) + return Column(sa_type, *args, **kwargs) # type: ignore diff --git a/sqlmodel/main.py b/sqlmodel/main.py index 435b9f31f3..cb008bb663 100644 --- a/sqlmodel/main.py +++ b/sqlmodel/main.py @@ -1,10 +1,4 @@ -import ipaddress -import uuid import weakref -from datetime import date, datetime, time, timedelta -from decimal import Decimal -from enum import Enum -from pathlib import Path from typing import ( AbstractSet, Any, @@ -25,48 +19,37 @@ ) from pydantic import BaseModel -from pydantic.fields import SHAPE_SINGLETON, ModelField from pydantic.fields import FieldInfo as PydanticFieldInfo -from pydantic.main import ModelMetaclass from pydantic.utils import Representation from sqlalchemy import ( - Boolean, Column, - Date, - DateTime, - Float, - ForeignKey, - Integer, - Interval, - Numeric, inspect, ) -from sqlalchemy import Enum as sa_Enum from sqlalchemy.orm import RelationshipProperty, declared_attr, registry, relationship from sqlalchemy.orm.attributes import set_attribute from sqlalchemy.orm.decl_api import DeclarativeMeta from sqlalchemy.orm.instrumentation import is_instrumented from sqlalchemy.sql.schema import MetaData -from sqlalchemy.sql.sqltypes import LargeBinary, Time from .compat import ( IS_PYDANTIC_V2, + ModelMetaclass, NoArgAnyCallable, PydanticModelConfig, PydanticUndefined, PydanticUndefinedType, SQLModelConfig, + class_dict_is_table, + cls_is_table, get_annotations, + get_column_from_field, get_config_value, get_model_fields, get_relationship_to, - is_field_noneable, - is_table, set_config_value, set_empty_defaults, set_fields_set, ) -from .sql.sqltypes import GUID, AutoString if not IS_PYDANTIC_V2: from pydantic.errors import ConfigError, DictError @@ -144,7 +127,6 @@ def __init__(self, default: Any = PydanticUndefined, **kwargs: Any) -> None: self.sa_column = sa_column self.sa_column_args = sa_column_args self.sa_column_kwargs = sa_column_kwargs - self.original_default = PydanticUndefined class RelationshipInfo(Representation): @@ -445,8 +427,7 @@ def __new__( config_kwargs = { key: kwargs[key] for key in kwargs.keys() & allowed_config_kwargs } - config_table = is_table(class_dict) - if config_table: + if class_dict_is_table(class_dict, kwargs): set_empty_defaults(pydantic_annotations, dict_used) new_cls: Type["SQLModelMetaclass"] = super().__new__( @@ -501,13 +482,8 @@ def __init__( # this allows FastAPI cloning a SQLModel for the response_model without # trying to create a new SQLAlchemy, for a new table, with the same name, that # triggers an error - base_is_table = False - for base in bases: - config = getattr(base, "__config__") # noqa: B009 - if config and getattr(config, "table", False): - base_is_table = True - break - if getattr(cls.__config__, "table", False) and not base_is_table: + base_is_table = any(cls_is_table(base) for base in bases) + if cls_is_table(cls) and not base_is_table: for rel_name, rel_info in cls.__sqlmodel_relationships__.items(): if rel_info.sa_relationship: # There's a SQLAlchemy relationship declared, that takes precedence @@ -545,104 +521,6 @@ def __init__( ModelMetaclass.__init__(cls, classname, bases, dict_, **kw) -def get_sqlalchemy_type(field: ModelField) -> Any: - sa_type = getattr(field.field_info, "sa_type", PydanticUndefined) # noqa: B009 - if sa_type is not PydanticUndefined: - return sa_type - if isinstance(field.type_, type) and field.shape == SHAPE_SINGLETON: - # Check enums first as an enum can also be a str, needed by Pydantic/FastAPI - if issubclass(field.type_, Enum): - return sa_Enum(field.type_) - if issubclass(field.type_, str): - if field.field_info.max_length: - return AutoString(length=field.field_info.max_length) - return AutoString - if issubclass(field.type_, float): - return Float - if issubclass(field.type_, bool): - return Boolean - if issubclass(field.type_, int): - return Integer - if issubclass(field.type_, datetime): - return DateTime - if issubclass(field.type_, date): - return Date - if issubclass(field.type_, timedelta): - return Interval - if issubclass(field.type_, time): - return Time - if issubclass(field.type_, bytes): - return LargeBinary - if issubclass(field.type_, Decimal): - return Numeric( - precision=getattr(field.type_, "max_digits", None), - scale=getattr(field.type_, "decimal_places", None), - ) - if issubclass(field.type_, ipaddress.IPv4Address): - return AutoString - if issubclass(field.type_, ipaddress.IPv4Network): - return AutoString - if issubclass(field.type_, ipaddress.IPv6Address): - return AutoString - if issubclass(field.type_, ipaddress.IPv6Network): - return AutoString - if issubclass(field.type_, Path): - return AutoString - if issubclass(field.type_, uuid.UUID): - return GUID - raise ValueError(f"The field {field.name} has no matching SQLAlchemy type") - - -def get_column_from_field(field: ModelField) -> Column: # type: ignore - sa_column = getattr(field.field_info, "sa_column", PydanticUndefined) - if isinstance(sa_column, Column): - return sa_column - sa_type = get_sqlalchemy_type(field) - primary_key = getattr(field.field_info, "primary_key", PydanticUndefined) - if primary_key is PydanticUndefined: - primary_key = False - index = getattr(field.field_info, "index", PydanticUndefined) - if index is PydanticUndefined: - index = False - nullable = not primary_key and is_field_noneable(field) - # Override derived nullability if the nullable property is set explicitly - # on the field - field_nullable = getattr(field.field_info, "nullable", PydanticUndefined) # noqa: B009 - if field_nullable != PydanticUndefined: - assert not isinstance(field_nullable, PydanticUndefinedType) - nullable = field_nullable - args = [] - foreign_key = getattr(field.field_info, "foreign_key", PydanticUndefined) - if foreign_key is PydanticUndefined: - foreign_key = None - unique = getattr(field.field_info, "unique", PydanticUndefined) - if unique is PydanticUndefined: - unique = False - if foreign_key: - assert isinstance(foreign_key, str) - args.append(ForeignKey(foreign_key)) - kwargs = { - "primary_key": primary_key, - "nullable": nullable, - "index": index, - "unique": unique, - } - sa_default = PydanticUndefined - if field.field_info.default_factory: - sa_default = field.field_info.default_factory - elif field.field_info.default is not PydanticUndefined: - sa_default = field.field_info.default - if sa_default is not PydanticUndefined: - kwargs["default"] = sa_default - sa_column_args = getattr(field.field_info, "sa_column_args", PydanticUndefined) - if sa_column_args is not PydanticUndefined: - args.extend(list(cast(Sequence[Any], sa_column_args))) - sa_column_kwargs = getattr(field.field_info, "sa_column_kwargs", PydanticUndefined) - if sa_column_kwargs is not PydanticUndefined: - kwargs.update(cast(Dict[Any, Any], sa_column_kwargs)) - return Column(sa_type, *args, **kwargs) # type: ignore - - class_registry = weakref.WeakValueDictionary() # type: ignore default_registry = registry() @@ -688,7 +566,7 @@ def __init__(__pydantic_self__, **data: Any) -> None: # settable attribute if IS_PYDANTIC_V2: old_dict = __pydantic_self__.__dict__.copy() - __pydantic_self__.super().__init__(**data) # noqa + super().__init__(**data) # noqa __pydantic_self__.__dict__ = {**old_dict, **__pydantic_self__.__dict__} non_pydantic_keys = data.keys() - __pydantic_self__.model_fields else: @@ -739,6 +617,7 @@ def __tablename__(cls) -> str: return cls.__name__.lower() if IS_PYDANTIC_V2: + @classmethod def model_validate( cls: type[_TSQLModel], @@ -752,7 +631,7 @@ def model_validate( validated = super().model_validate( obj, strict=strict, from_attributes=from_attributes, context=context ) - return cls(**dict(validated)) + return cls(**validated.model_dump(exclude_unset=True)) else: @@ -761,7 +640,7 @@ def from_orm( cls: Type[_TSQLModel], obj: Any, update: Optional[Dict[str, Any]] = None ) -> _TSQLModel: # Duplicated from Pydantic - if not cls.__config__.orm_mode: # noqa: attr-defined + if not cls.__config__.orm_mode: # noqa raise ConfigError( "You must have the config attribute orm_mode=True to use from_orm" ) diff --git a/tests/test_enums.py b/tests/test_enums.py index 194bdefea1..07a04c686e 100644 --- a/tests/test_enums.py +++ b/tests/test_enums.py @@ -5,6 +5,8 @@ from sqlalchemy.sql.type_api import TypeEngine from sqlmodel import Field, SQLModel +from .conftest import needs_pydanticv1, needs_pydanticv2 + """ Tests related to Enums @@ -72,7 +74,8 @@ def test_sqlite_ddl_sql(capsys): assert "CREATE TYPE" not in captured.out -def test_json_schema_flat_model(): +@needs_pydanticv1 +def test_json_schema_flat_model_pydantic_v1(): assert FlatModel.schema() == { "title": "FlatModel", "type": "object", @@ -92,7 +95,8 @@ def test_json_schema_flat_model(): } -def test_json_schema_inherit_model(): +@needs_pydanticv1 +def test_json_schema_inherit_model_pydantic_v1(): assert InheritModel.schema() == { "title": "InheritModel", "type": "object", @@ -110,3 +114,42 @@ def test_json_schema_inherit_model(): } }, } + + +@needs_pydanticv2 +def test_json_schema_flat_model_pydantic_v2(): + assert FlatModel.model_json_schema() == { + "title": "FlatModel", + "type": "object", + "properties": { + "id": {"default": None, "format": "uuid", "title": "Id", "type": "string"}, + "enum_field": {"allOf": [{"$ref": "#/$defs/MyEnum1"}], "default": None}, + }, + "$defs": { + "MyEnum1": { + "title": "MyEnum1", + "enum": ["A", "B"], + "type": "string", + } + }, + } + + +@needs_pydanticv2 +def test_json_schema_inherit_model_pydantic_v2(): + assert InheritModel.model_json_schema() == { + "title": "InheritModel", + "type": "object", + "properties": { + "id": {"title": "Id", "type": "string", "format": "uuid"}, + "enum_field": {"$ref": "#/$defs/MyEnum2"}, + }, + "required": ["id", "enum_field"], + "$defs": { + "MyEnum2": { + "title": "MyEnum2", + "enum": ["C", "D"], + "type": "string", + } + }, + } diff --git a/tests/test_instance_no_args.py b/tests/test_instance_no_args.py index 14d560628b..e54e8163b3 100644 --- a/tests/test_instance_no_args.py +++ b/tests/test_instance_no_args.py @@ -1,11 +1,16 @@ from typing import Optional +import pytest +from pydantic import ValidationError from sqlalchemy import create_engine, select from sqlalchemy.orm import Session from sqlmodel import Field, SQLModel +from .conftest import needs_pydanticv1, needs_pydanticv2 -def test_allow_instantiation_without_arguments(clear_sqlmodel): + +@needs_pydanticv1 +def test_allow_instantiation_without_arguments_pydantic_v1(clear_sqlmodel): class Item(SQLModel): id: Optional[int] = Field(default=None, primary_key=True) name: str @@ -25,3 +30,33 @@ class Config: assert len(result) == 1 assert isinstance(item.id, int) SQLModel.metadata.clear() + + +def test_not_allow_instantiation_without_arguments_if_not_table(): + class Item(SQLModel): + id: Optional[int] = Field(default=None, primary_key=True) + name: str + description: Optional[str] = None + + with pytest.raises(ValidationError): + Item() + + +@needs_pydanticv2 +def test_allow_instantiation_without_arguments_pydnatic_v2(clear_sqlmodel): + class Item(SQLModel, table=True): + id: Optional[int] = Field(default=None, primary_key=True) + name: str + description: Optional[str] = None + + engine = create_engine("sqlite:///:memory:") + SQLModel.metadata.create_all(engine) + with Session(engine) as db: + item = Item() + item.name = "Rick" + db.add(item) + db.commit() + result = db.execute(select(Item)).scalars().all() + assert len(result) == 1 + assert isinstance(item.id, int) + SQLModel.metadata.clear() diff --git a/tests/test_nullable.py b/tests/test_nullable.py index 1c8b37b218..a40bb5b5f0 100644 --- a/tests/test_nullable.py +++ b/tests/test_nullable.py @@ -58,7 +58,7 @@ class Hero(SQLModel, table=True): ][0] assert "primary_key INTEGER NOT NULL," in create_table_log assert "required_value VARCHAR NOT NULL," in create_table_log - assert "optional_default_ellipsis VARCHAR NOT NULL," in create_table_log + assert "optional_default_ellipsis VARCHAR," in create_table_log assert "optional_default_none VARCHAR," in create_table_log assert "optional_non_nullable VARCHAR NOT NULL," in create_table_log assert "optional_nullable VARCHAR," in create_table_log diff --git a/tests/test_validation.py b/tests/test_validation.py index e200a1e73d..3265922070 100644 --- a/tests/test_validation.py +++ b/tests/test_validation.py @@ -1,7 +1,6 @@ from typing import Optional import pytest -from pydantic import field_validator from pydantic.error_wrappers import ValidationError from sqlmodel import SQLModel @@ -19,21 +18,22 @@ def test_validation_pydantic_v1(clear_sqlmodel): https://github.com/samuelcolvin/pydantic/issues/1223 """ + from pydantic import validator class Hero(SQLModel): name: Optional[str] = None secret_name: Optional[str] = None age: Optional[int] = None - @field_validator("name", "secret_name", "age") + @validator("name", "secret_name", "age") def reject_none(cls, v): assert v is not None return v - Hero.from_orm({"age": 25}) + Hero.validate({"age": 25}) with pytest.raises(ValidationError): - Hero.from_orm({"name": None, "age": 25}) + Hero.validate({"name": None, "age": 25}) @needs_pydanticv2 @@ -47,6 +47,7 @@ def test_validation_pydantic_v2(clear_sqlmodel): https://github.com/samuelcolvin/pydantic/issues/1223 """ + from pydantic import field_validator class Hero(SQLModel): name: Optional[str] = None