diff --git a/sqlmodel/_compat.py b/sqlmodel/_compat.py index 38dd501c4a..6295662aec 100644 --- a/sqlmodel/_compat.py +++ b/sqlmodel/_compat.py @@ -21,7 +21,7 @@ from pydantic import VERSION as P_VERSION from pydantic import BaseModel from pydantic.fields import FieldInfo -from typing_extensions import Annotated, get_args, get_origin +from typing_extensions import Annotated, Literal, get_args, get_origin # Reassign variable to make it reexported for mypy PYDANTIC_VERSION = P_VERSION @@ -208,6 +208,13 @@ def get_sa_type_from_type_annotation(annotation: Any) -> Any: # Optional unions are allowed use_type = bases[0] if bases[0] is not NoneType else bases[1] return get_sa_type_from_type_annotation(use_type) + if origin is Literal: + literal_args = get_args(annotation) + if all(isinstance(arg, bool) for arg in literal_args): # all bools + return bool + if all(isinstance(arg, int) for arg in literal_args): # all ints + return int + return str return origin def get_sa_type_from_field(field: Any) -> Any: @@ -459,6 +466,14 @@ def is_field_noneable(field: "FieldInfo") -> bool: return field.allow_none # type: ignore[no-any-return, attr-defined] def get_sa_type_from_field(field: Any) -> Any: + if get_origin(field.type_) is Literal: + literal_args = get_args(field.type_) + if all(isinstance(arg, bool) for arg in literal_args): # all bools + return bool + if all(isinstance(arg, int) for arg in literal_args): # all ints + return int + return str + if isinstance(field.type_, type) and field.shape == SHAPE_SINGLETON: return field.type_ raise ValueError(f"The field {field.name} has no matching SQLAlchemy type") diff --git a/tests/test_main.py b/tests/test_main.py index 60d5c40ebb..188bf9df66 100644 --- a/tests/test_main.py +++ b/tests/test_main.py @@ -4,6 +4,7 @@ from sqlalchemy.exc import IntegrityError from sqlalchemy.orm import RelationshipProperty from sqlmodel import Field, Relationship, Session, SQLModel, create_engine, select +from typing_extensions import Literal def test_should_allow_duplicate_row_if_unique_constraint_is_not_passed(clear_sqlmodel): @@ -125,3 +126,50 @@ class Hero(SQLModel, table=True): # The next statement should not raise an AttributeError assert hero_rusty_man.team assert hero_rusty_man.team.name == "Preventers" + + +def test_literal_str(clear_sqlmodel, caplog): + """Test https://github.com/fastapi/sqlmodel/issues/57""" + + class Model(SQLModel, table=True): + id: Optional[int] = Field(default=None, primary_key=True) + all_str: Literal["a", "b", "c"] + mixed: Literal["yes", "no", 1, 0] + all_int: Literal[1, 2, 3] + int_bool: Literal[0, 1, True, False] + all_bool: Literal[True, False] + + obj = Model( + all_str="a", + mixed="yes", + all_int=1, + int_bool=True, + all_bool=False, + ) + + engine = create_engine("sqlite://", echo=True) + + SQLModel.metadata.create_all(engine) + + # Check DDL + assert "all_str VARCHAR NOT NULL" in caplog.text + assert "mixed VARCHAR NOT NULL" in caplog.text + assert "all_int INTEGER NOT NULL" in caplog.text + assert "int_bool INTEGER NOT NULL" in caplog.text + assert "all_bool BOOLEAN NOT NULL" in caplog.text + + # Check query + with Session(engine) as session: + session.add(obj) + session.commit() + session.refresh(obj) + assert isinstance(obj.all_str, str) + assert obj.all_str == "a" + assert isinstance(obj.mixed, str) + assert obj.mixed == "yes" + assert isinstance(obj.all_int, int) + assert obj.all_int == 1 + assert isinstance(obj.int_bool, int) + assert obj.int_bool == 1 + assert isinstance(obj.all_bool, bool) + assert obj.all_bool is False