From b38e0487e18b38599f1c0fefe683feb9baf24c2f Mon Sep 17 00:00:00 2001 From: Sourcery AI <> Date: Fri, 24 Jun 2022 07:28:04 +0000 Subject: [PATCH] 'Refactored by Sourcery' --- sqlmodel/main.py | 68 +++++++++++++++++---------------------------- tests/test_enums.py | 10 +++---- tests/test_types.py | 18 ++++++++---- 3 files changed, 42 insertions(+), 54 deletions(-) diff --git a/sqlmodel/main.py b/sqlmodel/main.py index f175c8aba5..6694d9f7c2 100644 --- a/sqlmodel/main.py +++ b/sqlmodel/main.py @@ -204,14 +204,13 @@ def Relationship( sa_relationship_args: Optional[Sequence[Any]] = None, sa_relationship_kwargs: Optional[Mapping[str, Any]] = None, ) -> Any: - relationship_info = RelationshipInfo( + return RelationshipInfo( back_populates=back_populates, link_model=link_model, sa_relationship=sa_relationship, sa_relationship_args=sa_relationship_args, sa_relationship_kwargs=sa_relationship_kwargs, ) - return relationship_info @__dataclass_transform__(kw_only_default=True, field_descriptors=(Field, FieldInfo)) @@ -328,9 +327,7 @@ def get_config(name: str) -> Any: return new_cls # Override SQLAlchemy, allow both SQLAlchemy and plain Pydantic models - def __init__( - cls, classname: str, bases: Tuple[type, ...], dict_: Dict[str, Any], **kw: Any - ) -> None: + def __init__(self, classname: str, bases: Tuple[type, ...], dict_: Dict[str, Any], **kw: Any) -> None: # Only one of the base classes (or the current one) should be a table model # this allows FastAPI cloning a SQLModel for the response_model without # trying to create a new SQLAlchemy, for a new table, with the same name, that @@ -341,17 +338,17 @@ def __init__( if config and getattr(config, "table", False): base_is_table = True break - if getattr(cls.__config__, "table", False) and not base_is_table: + if getattr(self.__config__, "table", False) and not base_is_table: dict_used = dict_.copy() - for field_name, field_value in cls.__fields__.items(): - dict_used[field_name] = get_column_from_field(field_value, cls=cls) - for rel_name, rel_info in cls.__sqlmodel_relationships__.items(): + for field_name, field_value in self.__fields__.items(): + dict_used[field_name] = get_column_from_field(field_value, cls=self) + for rel_name, rel_info in self.__sqlmodel_relationships__.items(): if rel_info.sa_relationship: # There's a SQLAlchemy relationship declared, that takes precedence # over anything else, use that and continue with the next attribute dict_used[rel_name] = rel_info.sa_relationship continue - ann = cls.__annotations__[rel_name] + ann = self.__annotations__[rel_name] temp_field = ModelField.infer( name=rel_name, value=rel_info, @@ -378,15 +375,15 @@ def __init__( if rel_info.sa_relationship_args: rel_args.extend(rel_info.sa_relationship_args) if rel_info.sa_relationship_kwargs: - rel_kwargs.update(rel_info.sa_relationship_kwargs) + rel_kwargs |= rel_info.sa_relationship_kwargs rel_value: RelationshipProperty = relationship( # type: ignore relationship_to, *rel_args, **rel_kwargs ) dict_used[rel_name] = rel_value - setattr(cls, rel_name, rel_value) # Fix #315 - DeclarativeMeta.__init__(cls, classname, bases, dict_used, **kw) + setattr(self, rel_name, rel_value) + DeclarativeMeta.__init__(self, classname, bases, dict_used, **kw) else: - ModelMetaclass.__init__(cls, classname, bases, dict_, **kw) + ModelMetaclass.__init__(self, classname, bases, dict_, **kw) def get_pydantic_root_model_engine_type( @@ -411,11 +408,10 @@ def process_result_value(self, value: Any, dialect: Any) -> Any: def get_sqlachemy_type(field: ModelField) -> Any: - if issubclass(field.type_, BaseModel): - if field.type_.__custom_root_type__: - return get_pydantic_root_model_engine_type( - get_sqlachemy_type(field.type_.__fields__["__root__"]), field.type_ - ) + if issubclass(field.type_, BaseModel) and field.type_.__custom_root_type__: + return get_pydantic_root_model_engine_type( + get_sqlachemy_type(field.type_.__fields__["__root__"]), field.type_ + ) if issubclass(field.type_, str): if field.field_info.max_length: return AutoString(length=field.field_info.max_length) @@ -505,7 +501,7 @@ def get_column_from_field( args.extend(list(cast(Sequence[Any], sa_column_args))) sa_column_kwargs = getattr(field.field_info, "sa_column_kwargs", Undefined) if sa_column_kwargs is not Undefined: - kwargs.update(cast(Dict[Any, Any], sa_column_kwargs)) + kwargs |= cast(Dict[Any, Any], sa_column_kwargs) return Column(sa_type, *args, **kwargs) # type: ignore @@ -570,9 +566,10 @@ def __setattr__(self, name: str, value: Any) -> None: return else: # Set in SQLAlchemy, before Pydantic to trigger events and updates - if getattr(self.__config__, "table", False): - if is_instrumented(self, name): - set_attribute(self, name, value) + if getattr(self.__config__, "table", False) and is_instrumented( + self, name + ): + set_attribute(self, name, value) # Set in Pydantic model to trigger possible validation changes, only for # non relationship values if name not in self.__sqlmodel_relationships__: @@ -592,13 +589,7 @@ def from_orm( if update is not None: obj = {**obj, **update} # End SQLModel support dict - if not getattr(cls.__config__, "table", False): - # If not table, normal Pydantic code - m: _TSQLModel = cls.__new__(cls) - else: - # If table, create the new instance normally to make SQLAlchemy create - # the _sa_instance_state attribute - m = cls() + m = cls() if getattr(cls.__config__, "table", False) else cls.__new__(cls) values, fields_set, validation_error = validate_model(cls, obj) if validation_error: raise validation_error @@ -662,7 +653,7 @@ def _calculate_keys( exclude_unset: bool, update: Optional[Dict[str, Any]] = None, ) -> Optional[AbstractSet[str]]: - if include is None and exclude is None and exclude_unset is False: + if include is None and exclude is None and not exclude_unset: # Original in Pydantic: # return None # Updated to not return SQLAlchemy attributes @@ -671,16 +662,7 @@ def _calculate_keys( return self.__fields__.keys() # | self.__sqlmodel_relationships__.keys() keys: AbstractSet[str] - if exclude_unset: - keys = self.__fields_set__.copy() - else: - # Original in Pydantic: - # keys = self.__dict__.keys() - # Updated to not return SQLAlchemy attributes - # Do not include relationships as that would easily lead to infinite - # recursion, or traversing the whole database - keys = self.__fields__.keys() # | self.__sqlmodel_relationships__.keys() - + keys = self.__fields_set__.copy() if exclude_unset else self.__fields__.keys() if include is not None: keys &= include.keys() @@ -693,8 +675,8 @@ def _calculate_keys( return keys @declared_attr # type: ignore - def __tablename__(cls) -> str: - return cls.__name__.lower() + def __tablename__(self) -> str: + return self.__name__.lower() def create_model( diff --git a/tests/test_enums.py b/tests/test_enums.py index 1cf9028b1c..5b9ebb0115 100644 --- a/tests/test_enums.py +++ b/tests/test_enums.py @@ -17,16 +17,14 @@ def pg_dump(sql: TypeEngine, *args, **kwargs): dialect = sql.compile(dialect=postgres_engine.dialect) - sql_str = str(dialect).rstrip() - if sql_str: - print(sql_str + ";") + if sql_str := str(dialect).rstrip(): + print(f"{sql_str};") def sqlite_dump(sql: TypeEngine, *args, **kwargs): dialect = sql.compile(dialect=sqlite_engine.dialect) - sql_str = str(dialect).rstrip() - if sql_str: - print(sql_str + ";") + if sql_str := str(dialect).rstrip(): + print(f"{sql_str};") postgres_engine = create_mock_engine("postgresql://", pg_dump) diff --git a/tests/test_types.py b/tests/test_types.py index 436e7c6f52..da043486de 100644 --- a/tests/test_types.py +++ b/tests/test_types.py @@ -17,13 +17,21 @@ class HeroEnum(Enum): types_values = [ - str("Hero"), - float(0.5), - int(5), - datetime(year=2020, month=2, day=2, hour=2, minute=2, second=2, microsecond=2), + "Hero", + 0.5, + 5, + datetime( + year=2020, month=2, day=2, hour=2, minute=2, second=2, microsecond=2 + ), date(year=2020, month=2, day=2), timedelta( - days=2, seconds=2, microseconds=2, milliseconds=2, minutes=2, hours=2, weeks=2 + days=2, + seconds=2, + microseconds=2, + milliseconds=2, + minutes=2, + hours=2, + weeks=2, ), time(hour=2, minute=2, second=2, microsecond=2), HeroEnum.SPIDER_MAN,