Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
68 changes: 25 additions & 43 deletions sqlmodel/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,14 +204,13 @@ def Relationship(
sa_relationship_args: Optional[Sequence[Any]] = None,
sa_relationship_kwargs: Optional[Mapping[str, Any]] = None,
) -> Any:
relationship_info = RelationshipInfo(
return RelationshipInfo(
back_populates=back_populates,
link_model=link_model,
sa_relationship=sa_relationship,
sa_relationship_args=sa_relationship_args,
sa_relationship_kwargs=sa_relationship_kwargs,
)
return relationship_info
Comment on lines -207 to -214
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Function Relationship refactored with the following changes:



@__dataclass_transform__(kw_only_default=True, field_descriptors=(Field, FieldInfo))
Expand Down Expand Up @@ -328,9 +327,7 @@ def get_config(name: str) -> Any:
return new_cls

# Override SQLAlchemy, allow both SQLAlchemy and plain Pydantic models
def __init__(
cls, classname: str, bases: Tuple[type, ...], dict_: Dict[str, Any], **kw: Any
) -> None:
def __init__(self, classname: str, bases: Tuple[type, ...], dict_: Dict[str, Any], **kw: Any) -> None:
Comment on lines -331 to +330
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Function SQLModelMetaclass.__init__ refactored with the following changes:

This removes the following comments ( why? ):

# Fix #315

# Only one of the base classes (or the current one) should be a table model
# this allows FastAPI cloning a SQLModel for the response_model without
# trying to create a new SQLAlchemy, for a new table, with the same name, that
Expand All @@ -341,17 +338,17 @@ def __init__(
if config and getattr(config, "table", False):
base_is_table = True
break
if getattr(cls.__config__, "table", False) and not base_is_table:
if getattr(self.__config__, "table", False) and not base_is_table:
dict_used = dict_.copy()
for field_name, field_value in cls.__fields__.items():
dict_used[field_name] = get_column_from_field(field_value, cls=cls)
for rel_name, rel_info in cls.__sqlmodel_relationships__.items():
for field_name, field_value in self.__fields__.items():
dict_used[field_name] = get_column_from_field(field_value, cls=self)
for rel_name, rel_info in self.__sqlmodel_relationships__.items():
if rel_info.sa_relationship:
# There's a SQLAlchemy relationship declared, that takes precedence
# over anything else, use that and continue with the next attribute
dict_used[rel_name] = rel_info.sa_relationship
continue
ann = cls.__annotations__[rel_name]
ann = self.__annotations__[rel_name]
temp_field = ModelField.infer(
name=rel_name,
value=rel_info,
Expand All @@ -378,15 +375,15 @@ def __init__(
if rel_info.sa_relationship_args:
rel_args.extend(rel_info.sa_relationship_args)
if rel_info.sa_relationship_kwargs:
rel_kwargs.update(rel_info.sa_relationship_kwargs)
rel_kwargs |= rel_info.sa_relationship_kwargs
rel_value: RelationshipProperty = relationship( # type: ignore
relationship_to, *rel_args, **rel_kwargs
)
dict_used[rel_name] = rel_value
setattr(cls, rel_name, rel_value) # Fix #315
DeclarativeMeta.__init__(cls, classname, bases, dict_used, **kw)
setattr(self, rel_name, rel_value)
DeclarativeMeta.__init__(self, classname, bases, dict_used, **kw)
else:
ModelMetaclass.__init__(cls, classname, bases, dict_, **kw)
ModelMetaclass.__init__(self, classname, bases, dict_, **kw)


def get_pydantic_root_model_engine_type(
Expand All @@ -411,11 +408,10 @@ def process_result_value(self, value: Any, dialect: Any) -> Any:


def get_sqlachemy_type(field: ModelField) -> Any:
if issubclass(field.type_, BaseModel):
if field.type_.__custom_root_type__:
return get_pydantic_root_model_engine_type(
get_sqlachemy_type(field.type_.__fields__["__root__"]), field.type_
)
if issubclass(field.type_, BaseModel) and field.type_.__custom_root_type__:
return get_pydantic_root_model_engine_type(
get_sqlachemy_type(field.type_.__fields__["__root__"]), field.type_
)
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Function get_sqlachemy_type refactored with the following changes:

if issubclass(field.type_, str):
if field.field_info.max_length:
return AutoString(length=field.field_info.max_length)
Expand Down Expand Up @@ -505,7 +501,7 @@ def get_column_from_field(
args.extend(list(cast(Sequence[Any], sa_column_args)))
sa_column_kwargs = getattr(field.field_info, "sa_column_kwargs", Undefined)
if sa_column_kwargs is not Undefined:
kwargs.update(cast(Dict[Any, Any], sa_column_kwargs))
kwargs |= cast(Dict[Any, Any], sa_column_kwargs)
Comment on lines -508 to +504
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Function get_column_from_field refactored with the following changes:

return Column(sa_type, *args, **kwargs) # type: ignore


Expand Down Expand Up @@ -570,9 +566,10 @@ def __setattr__(self, name: str, value: Any) -> None:
return
else:
# Set in SQLAlchemy, before Pydantic to trigger events and updates
if getattr(self.__config__, "table", False):
if is_instrumented(self, name):
set_attribute(self, name, value)
if getattr(self.__config__, "table", False) and is_instrumented(
self, name
):
set_attribute(self, name, value)
Comment on lines -573 to +572
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Function SQLModel.__setattr__ refactored with the following changes:

# Set in Pydantic model to trigger possible validation changes, only for
# non relationship values
if name not in self.__sqlmodel_relationships__:
Expand All @@ -592,13 +589,7 @@ def from_orm(
if update is not None:
obj = {**obj, **update}
# End SQLModel support dict
if not getattr(cls.__config__, "table", False):
# If not table, normal Pydantic code
m: _TSQLModel = cls.__new__(cls)
else:
# If table, create the new instance normally to make SQLAlchemy create
# the _sa_instance_state attribute
m = cls()
m = cls() if getattr(cls.__config__, "table", False) else cls.__new__(cls)
Comment on lines -595 to +592
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Function SQLModel.from_orm refactored with the following changes:

This removes the following comments ( why? ):

# the _sa_instance_state attribute
# If table, create the new instance normally to make SQLAlchemy create
# If not table, normal Pydantic code

values, fields_set, validation_error = validate_model(cls, obj)
if validation_error:
raise validation_error
Expand Down Expand Up @@ -662,7 +653,7 @@ def _calculate_keys(
exclude_unset: bool,
update: Optional[Dict[str, Any]] = None,
) -> Optional[AbstractSet[str]]:
if include is None and exclude is None and exclude_unset is False:
if include is None and exclude is None and not exclude_unset:
Comment on lines -665 to +656
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Function SQLModel._calculate_keys refactored with the following changes:

This removes the following comments ( why? ):

# Do not include relationships as that would easily lead to infinite
# recursion, or traversing the whole database
# | self.__sqlmodel_relationships__.keys()
# Original in Pydantic:
# keys = self.__dict__.keys()
# Updated to not return SQLAlchemy attributes

# Original in Pydantic:
# return None
# Updated to not return SQLAlchemy attributes
Expand All @@ -671,16 +662,7 @@ def _calculate_keys(
return self.__fields__.keys() # | self.__sqlmodel_relationships__.keys()

keys: AbstractSet[str]
if exclude_unset:
keys = self.__fields_set__.copy()
else:
# Original in Pydantic:
# keys = self.__dict__.keys()
# Updated to not return SQLAlchemy attributes
# Do not include relationships as that would easily lead to infinite
# recursion, or traversing the whole database
keys = self.__fields__.keys() # | self.__sqlmodel_relationships__.keys()

keys = self.__fields_set__.copy() if exclude_unset else self.__fields__.keys()
if include is not None:
keys &= include.keys()

Expand All @@ -693,8 +675,8 @@ def _calculate_keys(
return keys

@declared_attr # type: ignore
def __tablename__(cls) -> str:
return cls.__name__.lower()
def __tablename__(self) -> str:
return self.__name__.lower()
Comment on lines -696 to +679
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Function SQLModel.__tablename__ refactored with the following changes:



def create_model(
Expand Down
10 changes: 4 additions & 6 deletions tests/test_enums.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,16 +17,14 @@

def pg_dump(sql: TypeEngine, *args, **kwargs):
dialect = sql.compile(dialect=postgres_engine.dialect)
sql_str = str(dialect).rstrip()
if sql_str:
print(sql_str + ";")
if sql_str := str(dialect).rstrip():
print(f"{sql_str};")
Comment on lines -20 to +21
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Function pg_dump refactored with the following changes:



def sqlite_dump(sql: TypeEngine, *args, **kwargs):
dialect = sql.compile(dialect=sqlite_engine.dialect)
sql_str = str(dialect).rstrip()
if sql_str:
print(sql_str + ";")
if sql_str := str(dialect).rstrip():
print(f"{sql_str};")
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Function sqlite_dump refactored with the following changes:



postgres_engine = create_mock_engine("postgresql://", pg_dump)
Expand Down
18 changes: 13 additions & 5 deletions tests/test_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,21 @@ class HeroEnum(Enum):


types_values = [
str("Hero"),
float(0.5),
int(5),
datetime(year=2020, month=2, day=2, hour=2, minute=2, second=2, microsecond=2),
"Hero",
0.5,
5,
datetime(
year=2020, month=2, day=2, hour=2, minute=2, second=2, microsecond=2
),
date(year=2020, month=2, day=2),
timedelta(
days=2, seconds=2, microseconds=2, milliseconds=2, minutes=2, hours=2, weeks=2
days=2,
seconds=2,
microseconds=2,
milliseconds=2,
minutes=2,
hours=2,
weeks=2,
Comment on lines -20 to +34
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Lines 20-26 refactored with the following changes:

),
time(hour=2, minute=2, second=2, microsecond=2),
HeroEnum.SPIDER_MAN,
Expand Down