Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
134 changes: 104 additions & 30 deletions sqlalchemy_crud_plus/crud.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,17 @@
# -*- coding: utf-8 -*-
from __future__ import annotations

from typing import Any, Generic, Iterable, Sequence
from typing import Any, Generic, Sequence

from sqlalchemy import (
Column,
ColumnExpressionArgument,
Row,
RowMapping,
Select,
delete,
func,
insert,
inspect,
select,
update,
Expand All @@ -19,10 +22,10 @@
from sqlalchemy_crud_plus.errors import CompositePrimaryKeysError, ModelColumnError, MultipleResultsError
from sqlalchemy_crud_plus.types import (
CreateSchema,
JoinConditionsConfig,
LoadStrategiesConfig,
JoinConditions,
LoadOptions,
LoadStrategies,
Model,
QueryOptions,
SortColumns,
SortOrders,
UpdateSchema,
Expand Down Expand Up @@ -95,7 +98,7 @@ async def create_model(
async def create_models(
self,
session: AsyncSession,
objs: Iterable[CreateSchema],
objs: list[CreateSchema],
flush: bool = False,
commit: bool = False,
**kwargs,
Expand Down Expand Up @@ -127,11 +130,41 @@ async def create_models(

return ins_list

async def bulk_create_models(
self,
session: AsyncSession,
objs: list[dict[str, Any]],
render_nulls: bool = False,
flush: bool = False,
commit: bool = False,
**kwargs,
) -> Sequence[Row[Any] | RowMapping | Any]:
"""
Create new instances of a model.

:param session: The SQLAlchemy async session
:param objs: The dict list containing data to be saved,The dict data should be aligned with the model column
:param render_nulls: render null values instead of ignoring them
:param flush: If `True`, flush all object changes to the database
:param commit: If `True`, commits the transaction immediately
:param kwargs: Additional model data not included in the dict
:return:
"""
stmt = insert(self.model).values(**kwargs).execution_options(render_nulls=render_nulls).returning(self.model)
result = await session.execute(stmt, objs)

if flush:
await session.flush()
if commit:
await session.commit()

return result.scalars().all()

async def count(
self,
session: AsyncSession,
*whereclause: ColumnExpressionArgument[bool],
join_conditions: JoinConditionsConfig | None = None,
join_conditions: JoinConditions | None = None,
**kwargs,
) -> int:
"""
Expand Down Expand Up @@ -163,7 +196,7 @@ async def exists(
self,
session: AsyncSession,
*whereclause: ColumnExpressionArgument[bool],
join_conditions: JoinConditionsConfig | None = None,
join_conditions: JoinConditions | None = None,
**kwargs,
) -> bool:
"""
Expand Down Expand Up @@ -193,9 +226,9 @@ async def select_model(
session: AsyncSession,
pk: Any | Sequence[Any],
*whereclause: ColumnExpressionArgument[bool],
load_options: QueryOptions | None = None,
load_strategies: LoadStrategiesConfig | None = None,
join_conditions: JoinConditionsConfig | None = None,
load_options: LoadOptions | None = None,
load_strategies: LoadStrategies | None = None,
join_conditions: JoinConditions | None = None,
**kwargs: Any,
) -> Model | None:
"""
Expand Down Expand Up @@ -236,9 +269,9 @@ async def select_model_by_column(
self,
session: AsyncSession,
*whereclause: ColumnExpressionArgument[bool],
load_options: QueryOptions | None = None,
load_strategies: LoadStrategiesConfig | None = None,
join_conditions: JoinConditionsConfig | None = None,
load_options: LoadOptions | None = None,
load_strategies: LoadStrategies | None = None,
join_conditions: JoinConditions | None = None,
**kwargs: Any,
) -> Model | None:
"""
Expand Down Expand Up @@ -266,9 +299,9 @@ async def select_model_by_column(
async def select(
self,
*whereclause: ColumnExpressionArgument[bool],
load_options: QueryOptions | None = None,
load_strategies: LoadStrategiesConfig | None = None,
join_conditions: JoinConditionsConfig | None = None,
load_options: LoadOptions | None = None,
load_strategies: LoadStrategies | None = None,
join_conditions: JoinConditions | None = None,
**kwargs,
) -> Select:
"""
Expand Down Expand Up @@ -303,9 +336,9 @@ async def select_order(
sort_columns: SortColumns,
sort_orders: SortOrders = None,
*whereclause: ColumnExpressionArgument[bool],
load_options: QueryOptions | None = None,
load_strategies: LoadStrategiesConfig | None = None,
join_conditions: JoinConditionsConfig | None = None,
load_options: LoadOptions | None = None,
load_strategies: LoadStrategies | None = None,
join_conditions: JoinConditions | None = None,
**kwargs: Any,
) -> Select:
"""
Expand Down Expand Up @@ -334,9 +367,9 @@ async def select_models(
self,
session: AsyncSession,
*whereclause: ColumnExpressionArgument[bool],
load_options: QueryOptions | None = None,
load_strategies: LoadStrategiesConfig | None = None,
join_conditions: JoinConditionsConfig | None = None,
load_options: LoadOptions | None = None,
load_strategies: LoadStrategies | None = None,
join_conditions: JoinConditions | None = None,
limit: int | None = None,
offset: int | None = None,
**kwargs: Any,
Expand Down Expand Up @@ -376,9 +409,9 @@ async def select_models_order(
sort_columns: SortColumns,
sort_orders: SortOrders = None,
*whereclause: ColumnExpressionArgument[bool],
load_options: QueryOptions | None = None,
load_strategies: LoadStrategiesConfig | None = None,
join_conditions: JoinConditionsConfig | None = None,
load_options: LoadOptions | None = None,
load_strategies: LoadStrategies | None = None,
join_conditions: JoinConditions | None = None,
limit: int | None = None,
offset: int | None = None,
**kwargs: Any,
Expand Down Expand Up @@ -438,9 +471,9 @@ async def update_model(
:return:
"""
filters = self._get_pk_filter(pk)
instance_data = obj if isinstance(obj, dict) else obj.model_dump(exclude_unset=True)
instance_data.update(kwargs)
stmt = update(self.model).where(*filters).values(**instance_data)
data = obj if isinstance(obj, dict) else obj.model_dump(exclude_unset=True)
data.update(kwargs)
stmt = update(self.model).where(*filters).values(**data)
result = await session.execute(stmt)

if flush:
Expand Down Expand Up @@ -480,8 +513,8 @@ async def update_model_by_column(
if total_count > 1:
raise MultipleResultsError(f'Only one record is expected to be updated, found {total_count} records.')

instance_data = obj if isinstance(obj, dict) else obj.model_dump(exclude_unset=True)
stmt = update(self.model).where(*filters).values(**instance_data)
data = obj if isinstance(obj, dict) else obj.model_dump(exclude_unset=True)
stmt = update(self.model).where(*filters).values(**data)
result = await session.execute(stmt)

if flush:
Expand All @@ -491,6 +524,47 @@ async def update_model_by_column(

return result.rowcount

async def bulk_update_models(
self,
session: AsyncSession,
objs: list[UpdateSchema | dict[str, Any]],
pk_mode: bool = True,
flush: bool = False,
commit: bool = False,
**kwargs,
) -> int:
"""
Bulk update multiple instances with different data for each record.
Each update item should have 'pk' key and other fields to update.

:param session: The SQLAlchemy async session
:param objs: To save a list of Pydantic schemas or dict for data
:param pk_mode: Primary key mode, when enabled, the data must contain the primary key data
:param flush: If `True`, flush all object changes to the database
:param commit: If `True`, commits the transaction immediately
:return: Total number of updated records
"""
if not pk_mode:
filters = parse_filters(self.model, **kwargs)

if not filters:
raise ValueError('At least one filter condition must be provided for update operation')

datas = [obj if isinstance(obj, dict) else obj.model_dump(exclude_unset=True) for obj in objs]
stmt = update(self.model).where(*filters)
conn = await session.connection()
await conn.execute(stmt, datas)
else:
datas = [obj if isinstance(obj, dict) else obj.model_dump(exclude_unset=True) for obj in objs]
await session.execute(update(self.model), datas)

if flush:
await session.flush()
if commit:
await session.commit()

return len(datas)

async def delete_model(
self,
session: AsyncSession,
Expand Down
56 changes: 36 additions & 20 deletions sqlalchemy_crud_plus/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,36 +8,52 @@
from sqlalchemy.orm import DeclarativeBase
from sqlalchemy.sql.base import ExecutableOption

# Base type variables for generic CRUD operations
Model = TypeVar('Model', bound=DeclarativeBase)
CreateSchema = TypeVar('CreateSchema', bound=BaseModel)
UpdateSchema = TypeVar('UpdateSchema', bound=BaseModel)

# SQLAlchemy relationship loading strategies
LoadingStrategy = Literal[
'selectinload', # SELECT IN loading (recommended for one-to-many)
'joinedload', # JOIN loading (recommended for one-to-one)
'subqueryload', # Subquery loading (for large datasets)
'contains_eager', # Use with explicit JOINs
'raiseload', # Prevent lazy loading
'noload', # Don't load relationship
# https://docs.sqlalchemy.org/en/20/orm/queryguide/relationships.html#relationship-loader-api
RelationshipLoadingStrategyType = Literal[
'contains_eager',
'defaultload',
'immediateload',
'joinedload',
'lazyload',
'noload',
'raiseload',
'selectinload',
'subqueryload',
# Load
'defer',
'load_only',
'selectin_polymorphic',
'undefer',
'undefer_group',
'with_expression',
]

# SQL JOIN types
# https://docs.sqlalchemy.org/en/20/orm/queryguide/columns.html#column-loading-api
ColumnLoadingStrategyType = Literal[
'defer',
'deferred',
'load_only',
'query_expression',
'undefer',
'undefer_group',
'with_expression',
]

LoadStrategies = list[str] | dict[str, RelationshipLoadingStrategyType] | dict[str, ColumnLoadingStrategyType]

JoinType = Literal[
'inner', # INNER JOIN
'left', # LEFT OUTER JOIN
'right', # RIGHT OUTER JOIN
'full', # FULL OUTER JOIN
'inner',
'left',
'full',
]

# Configuration for relationship loading strategies
LoadStrategiesConfig = list[str] | dict[str, LoadingStrategy]
JoinConditions = list[str] | dict[str, JoinType]

# Configuration for JOIN conditions
JoinConditionsConfig = list[str] | dict[str, JoinType]
LoadOptions = list[ExecutableOption]

# Query configuration types
SortColumns = str | list[str]
SortOrders = str | list[str] | None
QueryOptions = list[ExecutableOption]
Loading