Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 6 additions & 6 deletions sqlalchemy_crud_plus/crud.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ async def select_model_by_column(self, session: AsyncSession, **kwargs) -> Model
:param kwargs: Query expressions.
:return:
"""
filters = await parse_filters(self.model, **kwargs)
filters = parse_filters(self.model, **kwargs)
stmt = select(self.model).where(*filters)
query = await session.execute(stmt)
return query.scalars().first()
Expand All @@ -87,7 +87,7 @@ async def select_models(self, session: AsyncSession, **kwargs) -> Sequence[Row[A
:param kwargs: Query expressions.
:return:
"""
filters = await parse_filters(self.model, **kwargs)
filters = parse_filters(self.model, **kwargs)
stmt = select(self.model).where(*filters)
query = await session.execute(stmt)
return query.scalars().all()
Expand All @@ -103,9 +103,9 @@ async def select_models_order(
:param sort_orders: more details see apply_sorting
:return:
"""
filters = await parse_filters(self.model, **kwargs)
filters = parse_filters(self.model, **kwargs)
stmt = select(self.model).where(*filters)
stmt_sort = await apply_sorting(self.model, stmt, sort_columns, sort_orders)
stmt_sort = apply_sorting(self.model, stmt, sort_columns, sort_orders)
query = await session.execute(stmt_sort)
return query.scalars().all()

Expand Down Expand Up @@ -149,7 +149,7 @@ async def update_model_by_column(
:param kwargs: Query expressions.
:return:
"""
filters = await parse_filters(self.model, **kwargs)
filters = parse_filters(self.model, **kwargs)
total_count = await count(session, self.model, filters)
if not allow_multiple and total_count > 1:
raise MultipleResultsError(f'Only one record is expected to be update, found {total_count} records.')
Expand Down Expand Up @@ -198,7 +198,7 @@ async def delete_model_by_column(
:param deleted_flag_column: Specify the flag column for logical deletion
:return:
"""
filters = await parse_filters(self.model, **kwargs)
filters = parse_filters(self.model, **kwargs)
total_count = await count(session, self.model, filters)
if not allow_multiple and total_count > 1:
raise MultipleResultsError(f'Only one record is expected to be delete, found {total_count} records.')
Expand Down
22 changes: 11 additions & 11 deletions sqlalchemy_crud_plus/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@
}


async def get_sqlalchemy_filter(
def get_sqlalchemy_filter(
operator: str, value: Any, allow_arithmetic: bool = True
) -> Callable[[str], Callable] | None:
if operator in ['in', 'not_in', 'between']:
Expand All @@ -82,55 +82,55 @@ async def get_sqlalchemy_filter(
return sqlalchemy_filter


async def get_column(model: Type[Model] | AliasedClass, field_name: str):
def get_column(model: Type[Model] | AliasedClass, field_name: str):
column = getattr(model, field_name, None)
if column is None:
raise ModelColumnError(f'Column {field_name} is not found in {model}')
return column


async def parse_filters(model: Type[Model] | AliasedClass, **kwargs) -> list[ColumnElement]:
def parse_filters(model: Type[Model] | AliasedClass, **kwargs) -> list[ColumnElement]:
filters = []

for key, value in kwargs.items():
if '__' in key:
field_name, op = key.rsplit('__', 1)
column = await get_column(model, field_name)
column = get_column(model, field_name)
if op == 'or':
or_filters = [
sqlalchemy_filter(column)(or_value)
for or_op, or_value in value.items()
if (sqlalchemy_filter := await get_sqlalchemy_filter(or_op, or_value)) is not None
if (sqlalchemy_filter := get_sqlalchemy_filter(or_op, or_value)) is not None
]
filters.append(or_(*or_filters))
elif isinstance(value, dict) and {'value', 'condition'}.issubset(value):
advanced_value = value['value']
condition = value['condition']
sqlalchemy_filter = await get_sqlalchemy_filter(op, advanced_value)
sqlalchemy_filter = get_sqlalchemy_filter(op, advanced_value)
if sqlalchemy_filter is not None:
condition_filters = []
for cond_op, cond_value in condition.items():
condition_filter = await get_sqlalchemy_filter(cond_op, cond_value, allow_arithmetic=False)
condition_filter = get_sqlalchemy_filter(cond_op, cond_value, allow_arithmetic=False)
condition_filters.append(
condition_filter(sqlalchemy_filter(column)(advanced_value))(cond_value)
if cond_op != 'between'
else condition_filter(sqlalchemy_filter(column)(advanced_value))(*cond_value)
)
filters.append(and_(*condition_filters))
else:
sqlalchemy_filter = await get_sqlalchemy_filter(op, value)
sqlalchemy_filter = get_sqlalchemy_filter(op, value)
if sqlalchemy_filter is not None:
filters.append(
sqlalchemy_filter(column)(value) if op != 'between' else sqlalchemy_filter(column)(*value)
)
else:
column = await get_column(model, key)
column = get_column(model, key)
filters.append(column == value)

return filters


async def apply_sorting(
def apply_sorting(
model: Type[Model] | AliasedClass,
stmt: Select,
sort_columns: str | list[str],
Expand Down Expand Up @@ -170,7 +170,7 @@ async def apply_sorting(
validated_sort_orders = ['asc'] * len(sort_columns) if not sort_orders else sort_orders

for idx, column_name in enumerate(sort_columns):
column = await get_column(model, column_name)
column = get_column(model, column_name)
order = validated_sort_orders[idx]
stmt = stmt.order_by(asc(column) if order == 'asc' else desc(column))

Expand Down