From e1049fa78c20afd0e69eaec431d09f9ceef7803f Mon Sep 17 00:00:00 2001 From: Maruo <43961566+maru0123-2004@users.noreply.github.com> Date: Mon, 26 Sep 2022 15:57:32 +0900 Subject: [PATCH 01/10] Add sa_type to Field --- sqlmodel/main.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/sqlmodel/main.py b/sqlmodel/main.py index d343c698e9..2865f7f5bd 100644 --- a/sqlmodel/main.py +++ b/sqlmodel/main.py @@ -372,6 +372,8 @@ def __init__( def get_sqlachemy_type(field: ModelField) -> Any: + if "sa_type" in field.field_info.extra: + return field.field_info.extra["sa_type"] if issubclass(field.type_, str): if field.field_info.max_length: return AutoString(length=field.field_info.max_length) From eb1732b7a1cd2867b4ed9b3e5257d2974bd2bd0f Mon Sep 17 00:00:00 2001 From: Maruo <43961566+maru0123-2004@users.noreply.github.com> Date: Thu, 6 Oct 2022 10:46:03 +0900 Subject: [PATCH 02/10] fix error --- sqlmodel/main.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/sqlmodel/main.py b/sqlmodel/main.py index 2865f7f5bd..a537e3c739 100644 --- a/sqlmodel/main.py +++ b/sqlmodel/main.py @@ -63,6 +63,7 @@ def __init__(self, default: Any = Undefined, **kwargs: Any) -> None: foreign_key = kwargs.pop("foreign_key", Undefined) unique = kwargs.pop("unique", False) index = kwargs.pop("index", Undefined) + sa_type = kwargs.pop("sa_type", Undefined) sa_column = kwargs.pop("sa_column", Undefined) sa_column_args = kwargs.pop("sa_column_args", Undefined) sa_column_kwargs = kwargs.pop("sa_column_kwargs", Undefined) @@ -146,6 +147,7 @@ def Field( unique: bool = False, nullable: Union[bool, UndefinedType] = Undefined, index: Union[bool, UndefinedType] = Undefined, + sa_type: Type[Any], sa_column: Union[Column, UndefinedType] = Undefined, # type: ignore sa_column_args: Union[Sequence[Any], UndefinedType] = Undefined, sa_column_kwargs: Union[Mapping[str, Any], UndefinedType] = Undefined, @@ -177,6 +179,7 @@ def Field( unique=unique, nullable=nullable, index=index, + sa_type=sa_type, sa_column=sa_column, sa_column_args=sa_column_args, sa_column_kwargs=sa_column_kwargs, From a07316da6e76ae54b005fb68b36048a22e8ab2db Mon Sep 17 00:00:00 2001 From: Maruo <43961566+maru0123-2004@users.noreply.github.com> Date: Thu, 6 Oct 2022 11:10:35 +0900 Subject: [PATCH 03/10] Update main.py --- sqlmodel/main.py | 1 + 1 file changed, 1 insertion(+) diff --git a/sqlmodel/main.py b/sqlmodel/main.py index a537e3c739..64bad15708 100644 --- a/sqlmodel/main.py +++ b/sqlmodel/main.py @@ -84,6 +84,7 @@ def __init__(self, default: Any = Undefined, **kwargs: Any) -> None: self.foreign_key = foreign_key self.unique = unique self.index = index + self.sa_type = sa_type self.sa_column = sa_column self.sa_column_args = sa_column_args self.sa_column_kwargs = sa_column_kwargs From 689885ffa4564d3a0db942728f341f60f20db9a8 Mon Sep 17 00:00:00 2001 From: Maruo <43961566+maru0123-2004@users.noreply.github.com> Date: Thu, 6 Oct 2022 11:20:53 +0900 Subject: [PATCH 04/10] Update main.py --- sqlmodel/main.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sqlmodel/main.py b/sqlmodel/main.py index 64bad15708..de09796855 100644 --- a/sqlmodel/main.py +++ b/sqlmodel/main.py @@ -148,7 +148,7 @@ def Field( unique: bool = False, nullable: Union[bool, UndefinedType] = Undefined, index: Union[bool, UndefinedType] = Undefined, - sa_type: Type[Any], + sa_type: Type[Any] = Undefined, sa_column: Union[Column, UndefinedType] = Undefined, # type: ignore sa_column_args: Union[Sequence[Any], UndefinedType] = Undefined, sa_column_kwargs: Union[Mapping[str, Any], UndefinedType] = Undefined, From 793af1a1f19638ce56673fed5428946e26059fe7 Mon Sep 17 00:00:00 2001 From: Maruo <43961566+maru0123-2004@users.noreply.github.com> Date: Thu, 6 Oct 2022 11:35:02 +0900 Subject: [PATCH 05/10] Update main.py --- sqlmodel/main.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sqlmodel/main.py b/sqlmodel/main.py index de09796855..063d670306 100644 --- a/sqlmodel/main.py +++ b/sqlmodel/main.py @@ -376,8 +376,8 @@ def __init__( def get_sqlachemy_type(field: ModelField) -> Any: - if "sa_type" in field.field_info.extra: - return field.field_info.extra["sa_type"] + if not issubclass(type(field.field_info.sa_type), type(Undefined)): + return field.field_info.sa_type if issubclass(field.type_, str): if field.field_info.max_length: return AutoString(length=field.field_info.max_length) From 8fa01150f4b61efc6bedc57dada875d3440b75a3 Mon Sep 17 00:00:00 2001 From: Maruo <43961566+maru0123-2004@users.noreply.github.com> Date: Wed, 14 Dec 2022 23:57:57 +0900 Subject: [PATCH 06/10] fix typo --- sqlmodel/main.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sqlmodel/main.py b/sqlmodel/main.py index 720b2b3df3..b6909158e8 100644 --- a/sqlmodel/main.py +++ b/sqlmodel/main.py @@ -375,7 +375,7 @@ def __init__( ModelMetaclass.__init__(cls, classname, bases, dict_, **kw) -def get_sqlachemy_type(field: ModelField) -> Any: +def get_sqlalchemy_type(field: ModelField) -> Any: if not issubclass(type(field.field_info.sa_type), type(Undefined)): return field.field_info.sa_type if issubclass(field.type_, str): From c4c58cc2be20c29fc4c8f9e1ed06fe133eadaf85 Mon Sep 17 00:00:00 2001 From: Maruo <43961566+maru0123-2004@users.noreply.github.com> Date: Wed, 14 Dec 2022 23:58:43 +0900 Subject: [PATCH 07/10] fix error in no Field Column --- sqlmodel/main.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/sqlmodel/main.py b/sqlmodel/main.py index b6909158e8..46f3f0ee28 100644 --- a/sqlmodel/main.py +++ b/sqlmodel/main.py @@ -376,8 +376,9 @@ def __init__( def get_sqlalchemy_type(field: ModelField) -> Any: - if not issubclass(type(field.field_info.sa_type), type(Undefined)): - return field.field_info.sa_type + if hasattr(field.field_info, "sa_type"): + if not issubclass(type(field.field_info.sa_type), type(Undefined)): + return field.field_info.sa_type if issubclass(field.type_, str): if field.field_info.max_length: return AutoString(length=field.field_info.max_length) From 9509313eaf5841790a0503fbba5527c950ad54ba Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sebasti=C3=A1n=20Ram=C3=ADrez?= Date: Sun, 29 Oct 2023 12:00:37 +0400 Subject: [PATCH 08/10] =?UTF-8?q?=E2=99=BB=EF=B8=8F=20Refactor=20checking?= =?UTF-8?q?=20for=20sa=5Ftype?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- sqlmodel/main.py | 17 +++++++++++------ 1 file changed, 11 insertions(+), 6 deletions(-) diff --git a/sqlmodel/main.py b/sqlmodel/main.py index 8801730bca..266bb6ccd1 100644 --- a/sqlmodel/main.py +++ b/sqlmodel/main.py @@ -105,11 +105,15 @@ def __init__(self, default: Any = Undefined, **kwargs: Any) -> None: ) if unique is not Undefined: raise RuntimeError( - "Passing unique is not supported when " "also passing a sa_column" + "Passing unique is not supported when also passing a sa_column" ) if index is not Undefined: raise RuntimeError( - "Passing index is not supported when " "also passing a sa_column" + "Passing index is not supported when also passing a sa_column" + ) + if sa_type is not Undefined: + raise RuntimeError( + "Passing sa_type is not supported when also passing a sa_column" ) super().__init__(default=default, **kwargs) self.primary_key = primary_key @@ -187,6 +191,7 @@ def Field( unique: Union[bool, UndefinedType] = Undefined, nullable: Union[bool, UndefinedType] = Undefined, index: Union[bool, UndefinedType] = Undefined, + sa_type: Union[Type[Any], UndefinedType] = Undefined, sa_column_args: Union[Sequence[Any], UndefinedType] = Undefined, sa_column_kwargs: Union[Mapping[str, Any], UndefinedType] = Undefined, schema_extra: Optional[Dict[str, Any]] = None, @@ -266,7 +271,7 @@ def Field( unique: Union[bool, UndefinedType] = Undefined, nullable: Union[bool, UndefinedType] = Undefined, index: Union[bool, UndefinedType] = Undefined, - sa_type: Type[Any] = Undefined, + sa_type: Union[Type[Any], UndefinedType] = Undefined, sa_column: Union[Column, UndefinedType] = Undefined, # type: ignore sa_column_args: Union[Sequence[Any], UndefinedType] = Undefined, sa_column_kwargs: Union[Mapping[str, Any], UndefinedType] = Undefined, @@ -519,9 +524,9 @@ def __init__( def get_sqlalchemy_type(field: ModelField) -> Any: - if hasattr(field.field_info, "sa_type"): - if not issubclass(type(field.field_info.sa_type), type(Undefined)): - return field.field_info.sa_type + sa_type = getattr(field.field_info, "sa_type") # noqa: B009 + if sa_type is not Undefined: + return sa_type if isinstance(field.type_, type) and field.shape == SHAPE_SINGLETON: # Check enums first as an enum can also be a str, needed by Pydantic/FastAPI if issubclass(field.type_, Enum): From 00b30c65c9dc12c723ad28d54ebaca034e35d68d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sebasti=C3=A1n=20Ram=C3=ADrez?= Date: Sun, 29 Oct 2023 12:00:58 +0400 Subject: [PATCH 09/10] =?UTF-8?q?=E2=9C=85=20Add=20test=20to=20ensure=20sa?= =?UTF-8?q?=5Ftype=20is=20not=20passed=20with=20sa=5Fcolumn?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- tests/test_field_sa_column.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/tests/test_field_sa_column.py b/tests/test_field_sa_column.py index 51cfdfa797..7384f1fabc 100644 --- a/tests/test_field_sa_column.py +++ b/tests/test_field_sa_column.py @@ -39,6 +39,17 @@ class Item(SQLModel, table=True): ) +def test_sa_column_no_type() -> None: + with pytest.raises(RuntimeError): + + class Item(SQLModel, table=True): + id: Optional[int] = Field( + default=None, + sa_type=Integer, + sa_column=Column(Integer, primary_key=True), + ) + + def test_sa_column_no_primary_key() -> None: with pytest.raises(RuntimeError): From 00923c2ee5c231dc48dd766c4b0b8edec3659e3c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sebasti=C3=A1n=20Ram=C3=ADrez?= Date: Sun, 29 Oct 2023 12:02:43 +0400 Subject: [PATCH 10/10] =?UTF-8?q?=F0=9F=90=9B=20Fix=20default=20sa=5Ftype?= =?UTF-8?q?=20when=20extracting=20it?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- sqlmodel/main.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sqlmodel/main.py b/sqlmodel/main.py index 266bb6ccd1..2b69dd2a75 100644 --- a/sqlmodel/main.py +++ b/sqlmodel/main.py @@ -524,7 +524,7 @@ def __init__( def get_sqlalchemy_type(field: ModelField) -> Any: - sa_type = getattr(field.field_info, "sa_type") # noqa: B009 + sa_type = getattr(field.field_info, "sa_type", Undefined) # noqa: B009 if sa_type is not Undefined: return sa_type if isinstance(field.type_, type) and field.shape == SHAPE_SINGLETON: