From 6d4490bdcdb8c134320cf80fb2ae2f4874f03b11 Mon Sep 17 00:00:00 2001 From: Wu Clan Date: Mon, 25 Aug 2025 00:29:28 +0800 Subject: [PATCH 1/2] Add bulk create and update methods --- sqlalchemy_crud_plus/crud.py | 120 +++++++++++++++++++++++++++------- sqlalchemy_crud_plus/types.py | 49 ++++++++------ sqlalchemy_crud_plus/utils.py | 90 ++++++++++++++----------- 3 files changed, 180 insertions(+), 79 deletions(-) diff --git a/sqlalchemy_crud_plus/crud.py b/sqlalchemy_crud_plus/crud.py index a244b30..2d54511 100644 --- a/sqlalchemy_crud_plus/crud.py +++ b/sqlalchemy_crud_plus/crud.py @@ -2,14 +2,17 @@ # -*- coding: utf-8 -*- from __future__ import annotations -from typing import Any, Generic, Iterable, Sequence +from typing import Any, Generic, Sequence from sqlalchemy import ( Column, ColumnExpressionArgument, + Row, + RowMapping, Select, delete, func, + insert, inspect, select, update, @@ -19,8 +22,8 @@ from sqlalchemy_crud_plus.errors import CompositePrimaryKeysError, ModelColumnError, MultipleResultsError from sqlalchemy_crud_plus.types import ( CreateSchema, - JoinConditionsConfig, - LoadStrategiesConfig, + JoinConditions, + LoadStrategies, Model, QueryOptions, SortColumns, @@ -95,7 +98,7 @@ async def create_model( async def create_models( self, session: AsyncSession, - objs: Iterable[CreateSchema], + objs: list[CreateSchema], flush: bool = False, commit: bool = False, **kwargs, @@ -127,11 +130,41 @@ async def create_models( return ins_list + async def bulk_create_models( + self, + session: AsyncSession, + objs: list[dict[str, Any]], + render_nulls: bool = False, + flush: bool = False, + commit: bool = False, + **kwargs, + ) -> Sequence[Row[Any] | RowMapping | Any]: + """ + Create new instances of a model. + + :param session: The SQLAlchemy async session + :param objs: The dict list containing data to be saved,The dict data should be aligned with the model column + :param render_nulls: render null values instead of ignoring them + :param flush: If `True`, flush all object changes to the database + :param commit: If `True`, commits the transaction immediately + :param kwargs: Additional model data not included in the dict + :return: + """ + stmt = insert(self.model).values(**kwargs).execution_options(render_nulls=render_nulls).returning(self.model) + result = await session.execute(stmt, objs) + + if flush: + await session.flush() + if commit: + await session.commit() + + return result.scalars().all() + async def count( self, session: AsyncSession, *whereclause: ColumnExpressionArgument[bool], - join_conditions: JoinConditionsConfig | None = None, + join_conditions: JoinConditions | None = None, **kwargs, ) -> int: """ @@ -163,7 +196,7 @@ async def exists( self, session: AsyncSession, *whereclause: ColumnExpressionArgument[bool], - join_conditions: JoinConditionsConfig | None = None, + join_conditions: JoinConditions | None = None, **kwargs, ) -> bool: """ @@ -194,8 +227,8 @@ async def select_model( pk: Any | Sequence[Any], *whereclause: ColumnExpressionArgument[bool], load_options: QueryOptions | None = None, - load_strategies: LoadStrategiesConfig | None = None, - join_conditions: JoinConditionsConfig | None = None, + load_strategies: LoadStrategies | None = None, + join_conditions: JoinConditions | None = None, **kwargs: Any, ) -> Model | None: """ @@ -237,8 +270,8 @@ async def select_model_by_column( session: AsyncSession, *whereclause: ColumnExpressionArgument[bool], load_options: QueryOptions | None = None, - load_strategies: LoadStrategiesConfig | None = None, - join_conditions: JoinConditionsConfig | None = None, + load_strategies: LoadStrategies | None = None, + join_conditions: JoinConditions | None = None, **kwargs: Any, ) -> Model | None: """ @@ -267,8 +300,8 @@ async def select( self, *whereclause: ColumnExpressionArgument[bool], load_options: QueryOptions | None = None, - load_strategies: LoadStrategiesConfig | None = None, - join_conditions: JoinConditionsConfig | None = None, + load_strategies: LoadStrategies | None = None, + join_conditions: JoinConditions | None = None, **kwargs, ) -> Select: """ @@ -304,8 +337,8 @@ async def select_order( sort_orders: SortOrders = None, *whereclause: ColumnExpressionArgument[bool], load_options: QueryOptions | None = None, - load_strategies: LoadStrategiesConfig | None = None, - join_conditions: JoinConditionsConfig | None = None, + load_strategies: LoadStrategies | None = None, + join_conditions: JoinConditions | None = None, **kwargs: Any, ) -> Select: """ @@ -335,8 +368,8 @@ async def select_models( session: AsyncSession, *whereclause: ColumnExpressionArgument[bool], load_options: QueryOptions | None = None, - load_strategies: LoadStrategiesConfig | None = None, - join_conditions: JoinConditionsConfig | None = None, + load_strategies: LoadStrategies | None = None, + join_conditions: JoinConditions | None = None, limit: int | None = None, offset: int | None = None, **kwargs: Any, @@ -377,8 +410,8 @@ async def select_models_order( sort_orders: SortOrders = None, *whereclause: ColumnExpressionArgument[bool], load_options: QueryOptions | None = None, - load_strategies: LoadStrategiesConfig | None = None, - join_conditions: JoinConditionsConfig | None = None, + load_strategies: LoadStrategies | None = None, + join_conditions: JoinConditions | None = None, limit: int | None = None, offset: int | None = None, **kwargs: Any, @@ -438,9 +471,9 @@ async def update_model( :return: """ filters = self._get_pk_filter(pk) - instance_data = obj if isinstance(obj, dict) else obj.model_dump(exclude_unset=True) - instance_data.update(kwargs) - stmt = update(self.model).where(*filters).values(**instance_data) + data = obj if isinstance(obj, dict) else obj.model_dump(exclude_unset=True) + data.update(kwargs) + stmt = update(self.model).where(*filters).values(**data) result = await session.execute(stmt) if flush: @@ -480,8 +513,8 @@ async def update_model_by_column( if total_count > 1: raise MultipleResultsError(f'Only one record is expected to be updated, found {total_count} records.') - instance_data = obj if isinstance(obj, dict) else obj.model_dump(exclude_unset=True) - stmt = update(self.model).where(*filters).values(**instance_data) + data = obj if isinstance(obj, dict) else obj.model_dump(exclude_unset=True) + stmt = update(self.model).where(*filters).values(**data) result = await session.execute(stmt) if flush: @@ -491,6 +524,47 @@ async def update_model_by_column( return result.rowcount + async def bulk_update_models( + self, + session: AsyncSession, + objs: list[UpdateSchema | dict[str, Any]], + pk_mode: bool = True, + flush: bool = False, + commit: bool = False, + **kwargs, + ) -> int: + """ + Bulk update multiple instances with different data for each record. + Each update item should have 'pk' key and other fields to update. + + :param session: The SQLAlchemy async session + :param objs: To save a list of Pydantic schemas or dict for data + :param pk_mode: Primary key mode, when enabled, the data must contain the primary key data + :param flush: If `True`, flush all object changes to the database + :param commit: If `True`, commits the transaction immediately + :return: Total number of updated records + """ + if not pk_mode: + filters = parse_filters(self.model, **kwargs) + + if not filters: + raise ValueError('At least one filter condition must be provided for update operation') + + datas = [obj if isinstance(obj, dict) else obj.model_dump(exclude_unset=True) for obj in objs] + stmt = update(self.model).where(*filters) + conn = await session.connection() + await conn.execute(stmt, datas) + else: + datas = [obj if isinstance(obj, dict) else obj.model_dump(exclude_unset=True) for obj in objs] + await session.execute(update(self.model), datas) + + if flush: + await session.flush() + if commit: + await session.commit() + + return len(datas) + async def delete_model( self, session: AsyncSession, diff --git a/sqlalchemy_crud_plus/types.py b/sqlalchemy_crud_plus/types.py index aa6e125..e5a1c68 100644 --- a/sqlalchemy_crud_plus/types.py +++ b/sqlalchemy_crud_plus/types.py @@ -13,31 +13,42 @@ CreateSchema = TypeVar('CreateSchema', bound=BaseModel) UpdateSchema = TypeVar('UpdateSchema', bound=BaseModel) -# SQLAlchemy relationship loading strategies -LoadingStrategy = Literal[ - 'selectinload', # SELECT IN loading (recommended for one-to-many) - 'joinedload', # JOIN loading (recommended for one-to-one) - 'subqueryload', # Subquery loading (for large datasets) - 'contains_eager', # Use with explicit JOINs - 'raiseload', # Prevent lazy loading - 'noload', # Don't load relationship +# https://docs.sqlalchemy.org/en/20/orm/queryguide/relationships.html#relationship-loader-api +RelationshipLoadingStrategyType = Literal[ + 'contains_eager', + 'defaultload', + 'immediateload', + 'joinedload', + 'lazyload', + 'noload', + 'raiseload', + 'selectinload', + 'subqueryload', + # Load + 'defer', + 'load_only', + 'selectin_polymorphic', + 'undefer', + 'undefer_group', + 'with_expression', +] + +LoadStrategies = list[str] | dict[str, RelationshipLoadingStrategyType] + +# https://docs.sqlalchemy.org/en/20/orm/queryguide/columns.html#column-loading-api +ColumnLoadingStrategyType = Literal[ + 'defer', 'deferred', 'load_only', 'query_expression', 'undefer', 'undefer_group', 'with_expression' ] -# SQL JOIN types JoinType = Literal[ - 'inner', # INNER JOIN - 'left', # LEFT OUTER JOIN - 'right', # RIGHT OUTER JOIN - 'full', # FULL OUTER JOIN + 'inner', + 'left', + 'full', ] -# Configuration for relationship loading strategies -LoadStrategiesConfig = list[str] | dict[str, LoadingStrategy] +JoinConditions = list[str] | dict[str, JoinType] -# Configuration for JOIN conditions -JoinConditionsConfig = list[str] | dict[str, JoinType] +QueryOptions = list[ExecutableOption] -# Query configuration types SortColumns = str | list[str] SortOrders = str | list[str] | None -QueryOptions = list[ExecutableOption] diff --git a/sqlalchemy_crud_plus/utils.py b/sqlalchemy_crud_plus/utils.py index 5a17310..aea0d2f 100644 --- a/sqlalchemy_crud_plus/utils.py +++ b/sqlalchemy_crud_plus/utils.py @@ -10,11 +10,20 @@ from sqlalchemy.orm import ( InstrumentedAttribute, contains_eager, + defaultload, + defer, + immediateload, joinedload, + lazyload, + load_only, noload, raiseload, + selectin_polymorphic, selectinload, subqueryload, + undefer, + undefer_group, + with_expression, ) from sqlalchemy.orm.util import AliasedClass @@ -25,7 +34,7 @@ ModelColumnError, SelectOperatorError, ) -from sqlalchemy_crud_plus.types import JoinConditionsConfig, LoadStrategiesConfig, Model +from sqlalchemy_crud_plus.types import JoinConditions, LoadStrategies, Model _SUPPORTED_FILTERS = { # Comparison: https://docs.sqlalchemy.org/en/20/core/operators.html#comparison-operators @@ -297,7 +306,7 @@ def apply_sorting( return stmt -def build_load_strategies(model: type[Model], load_strategies: LoadStrategiesConfig | None) -> list: +def build_load_strategies(model: type[Model], load_strategies: LoadStrategies | None) -> list: """ Build relationship loading strategy options. @@ -306,44 +315,54 @@ def build_load_strategies(model: type[Model], load_strategies: LoadStrategiesCon :return: """ - strategy_map = { - 'selectinload': selectinload, - 'joinedload': joinedload, - 'subqueryload': subqueryload, + strategies_map = { 'contains_eager': contains_eager, - 'raiseload': raiseload, + 'defaultload': defaultload, + 'immediateload': immediateload, + 'joinedload': joinedload, + 'lazyload': lazyload, 'noload': noload, + 'raiseload': raiseload, + 'selectinload': selectinload, + 'subqueryload': subqueryload, + # Load + 'defer': defer, + 'load_only': load_only, + 'selectin_polymorphic': selectin_polymorphic, + 'undefer': undefer, + 'undefer_group': undefer_group, + 'with_expression': with_expression, } options = [] default_strategy = 'selectinload' if isinstance(load_strategies, list): - for rel_name in load_strategies: + for column in load_strategies: try: - rel_attr = getattr(model, rel_name) - strategy_func = strategy_map[default_strategy] - options.append(strategy_func(rel_attr)) + attr = getattr(model, column) + strategy_func = strategies_map[default_strategy] + options.append(strategy_func(attr)) except AttributeError: - continue + raise ModelColumnError(f'Invalid relationship column: {column}') elif isinstance(load_strategies, dict): - for rel_name, strategy_name in load_strategies.items(): - if strategy_name not in strategy_map: + for column, strategy_name in load_strategies.items(): + if strategy_name not in strategies_map: raise LoadingStrategyError( - f'Invalid loading strategy: {strategy_name}, only supports {list(strategy_map.keys())}' + f'Invalid loading strategy: {strategy_name}, only supports {list(strategies_map.keys())}' ) try: - rel_attr = getattr(model, rel_name) - strategy_func = strategy_map.get(strategy_name) - options.append(strategy_func(rel_attr)) + attr = getattr(model, column) + strategy_func = strategies_map.get(strategy_name) + options.append(strategy_func(attr)) except AttributeError: - continue + raise ModelColumnError(f'Invalid relationship column: {column}') return options -def apply_join_conditions(model: type[Model], stmt: Select, join_conditions: JoinConditionsConfig | None): +def apply_join_conditions(model: type[Model], stmt: Select, join_conditions: JoinConditions | None): """ Apply JOIN conditions to the query statement. @@ -353,32 +372,29 @@ def apply_join_conditions(model: type[Model], stmt: Select, join_conditions: Joi :return: """ if isinstance(join_conditions, list): - for rel_name in join_conditions: + for column in join_conditions: try: - rel_attr = getattr(model, rel_name) - stmt = stmt.join(rel_attr) + attr = getattr(model, column) + stmt = stmt.join(attr) except AttributeError: - continue + raise ModelColumnError(f'Invalid model column: {column}') elif isinstance(join_conditions, dict): - for rel_name, join_type in join_conditions.items(): - if join_type not in ['left', 'inner', 'right', 'full']: - raise JoinConditionError( - f'Invalid join type: {join_type}, only supports `left`, `inner`, `right`, `full`' - ) + for column, join_type in join_conditions.items(): + allowed_join_types = ['inner', 'left', 'full'] + if join_type not in allowed_join_types: + raise JoinConditionError(f'Invalid join type: {join_type}, only supports {allowed_join_types}') try: - rel_attr = getattr(model, rel_name) + attr = getattr(model, column) if join_type == 'left': - stmt = stmt.join(rel_attr, isouter=True) + stmt = stmt.join(attr, isouter=True) elif join_type == 'inner': - stmt = stmt.join(rel_attr) - elif join_type == 'right': - stmt = stmt.join(rel_attr, isouter=True) + stmt = stmt.join(attr) elif join_type == 'full': - stmt = stmt.join(rel_attr, full=True) + stmt = stmt.join(attr, full=True) else: - stmt = stmt.join(rel_attr) + stmt = stmt.join(attr) except AttributeError: - continue + raise ModelColumnError(f'Invalid model column: {column}') return stmt From 9b0d2d40606256ef15057ec7f96ae6263b8aee19 Mon Sep 17 00:00:00 2001 From: Wu Clan Date: Mon, 25 Aug 2025 17:39:50 +0800 Subject: [PATCH 2/2] Update types and tests --- sqlalchemy_crud_plus/crud.py | 14 ++--- sqlalchemy_crud_plus/types.py | 15 +++-- sqlalchemy_crud_plus/utils.py | 19 +++--- tests/test_create.py | 40 +++++++++++- tests/test_relationships.py | 36 ----------- tests/test_update.py | 115 +++++++++++++++++++++++++++++++++- tests/test_utils.py | 13 ++-- 7 files changed, 183 insertions(+), 69 deletions(-) diff --git a/sqlalchemy_crud_plus/crud.py b/sqlalchemy_crud_plus/crud.py index 2d54511..8fe8ad9 100644 --- a/sqlalchemy_crud_plus/crud.py +++ b/sqlalchemy_crud_plus/crud.py @@ -23,9 +23,9 @@ from sqlalchemy_crud_plus.types import ( CreateSchema, JoinConditions, + LoadOptions, LoadStrategies, Model, - QueryOptions, SortColumns, SortOrders, UpdateSchema, @@ -226,7 +226,7 @@ async def select_model( session: AsyncSession, pk: Any | Sequence[Any], *whereclause: ColumnExpressionArgument[bool], - load_options: QueryOptions | None = None, + load_options: LoadOptions | None = None, load_strategies: LoadStrategies | None = None, join_conditions: JoinConditions | None = None, **kwargs: Any, @@ -269,7 +269,7 @@ async def select_model_by_column( self, session: AsyncSession, *whereclause: ColumnExpressionArgument[bool], - load_options: QueryOptions | None = None, + load_options: LoadOptions | None = None, load_strategies: LoadStrategies | None = None, join_conditions: JoinConditions | None = None, **kwargs: Any, @@ -299,7 +299,7 @@ async def select_model_by_column( async def select( self, *whereclause: ColumnExpressionArgument[bool], - load_options: QueryOptions | None = None, + load_options: LoadOptions | None = None, load_strategies: LoadStrategies | None = None, join_conditions: JoinConditions | None = None, **kwargs, @@ -336,7 +336,7 @@ async def select_order( sort_columns: SortColumns, sort_orders: SortOrders = None, *whereclause: ColumnExpressionArgument[bool], - load_options: QueryOptions | None = None, + load_options: LoadOptions | None = None, load_strategies: LoadStrategies | None = None, join_conditions: JoinConditions | None = None, **kwargs: Any, @@ -367,7 +367,7 @@ async def select_models( self, session: AsyncSession, *whereclause: ColumnExpressionArgument[bool], - load_options: QueryOptions | None = None, + load_options: LoadOptions | None = None, load_strategies: LoadStrategies | None = None, join_conditions: JoinConditions | None = None, limit: int | None = None, @@ -409,7 +409,7 @@ async def select_models_order( sort_columns: SortColumns, sort_orders: SortOrders = None, *whereclause: ColumnExpressionArgument[bool], - load_options: QueryOptions | None = None, + load_options: LoadOptions | None = None, load_strategies: LoadStrategies | None = None, join_conditions: JoinConditions | None = None, limit: int | None = None, diff --git a/sqlalchemy_crud_plus/types.py b/sqlalchemy_crud_plus/types.py index e5a1c68..bb18345 100644 --- a/sqlalchemy_crud_plus/types.py +++ b/sqlalchemy_crud_plus/types.py @@ -8,7 +8,6 @@ from sqlalchemy.orm import DeclarativeBase from sqlalchemy.sql.base import ExecutableOption -# Base type variables for generic CRUD operations Model = TypeVar('Model', bound=DeclarativeBase) CreateSchema = TypeVar('CreateSchema', bound=BaseModel) UpdateSchema = TypeVar('UpdateSchema', bound=BaseModel) @@ -33,13 +32,19 @@ 'with_expression', ] -LoadStrategies = list[str] | dict[str, RelationshipLoadingStrategyType] - # https://docs.sqlalchemy.org/en/20/orm/queryguide/columns.html#column-loading-api ColumnLoadingStrategyType = Literal[ - 'defer', 'deferred', 'load_only', 'query_expression', 'undefer', 'undefer_group', 'with_expression' + 'defer', + 'deferred', + 'load_only', + 'query_expression', + 'undefer', + 'undefer_group', + 'with_expression', ] +LoadStrategies = list[str] | dict[str, RelationshipLoadingStrategyType] | dict[str, ColumnLoadingStrategyType] + JoinType = Literal[ 'inner', 'left', @@ -48,7 +53,7 @@ JoinConditions = list[str] | dict[str, JoinType] -QueryOptions = list[ExecutableOption] +LoadOptions = list[ExecutableOption] SortColumns = str | list[str] SortOrders = str | list[str] | None diff --git a/sqlalchemy_crud_plus/utils.py b/sqlalchemy_crud_plus/utils.py index aea0d2f..c90ee7d 100644 --- a/sqlalchemy_crud_plus/utils.py +++ b/sqlalchemy_crud_plus/utils.py @@ -8,7 +8,6 @@ from sqlalchemy import ColumnElement, Select, and_, asc, desc, or_ from sqlalchemy.orm import ( - InstrumentedAttribute, contains_eager, defaultload, defer, @@ -26,6 +25,8 @@ with_expression, ) from sqlalchemy.orm.util import AliasedClass +from sqlalchemy.sql.operators import ColumnOperators +from sqlalchemy.sql.schema import Column from sqlalchemy_crud_plus.errors import ( ColumnSortError, @@ -98,9 +99,7 @@ ] -def get_sqlalchemy_filter( - operator: str, value: Any, allow_arithmetic: bool = True -) -> Callable[[InstrumentedAttribute], Callable] | None: +def get_sqlalchemy_filter(operator: str, value: Any, allow_arithmetic: bool = True) -> Callable[..., Any] | None: if operator in ['in', 'not_in', 'between']: if not isinstance(value, (tuple, list, set)): raise SelectOperatorError(f'The value of the <{operator}> filter must be tuple, list or set') @@ -119,7 +118,7 @@ def get_sqlalchemy_filter( return sqlalchemy_filter -def get_column(model: type[Model] | AliasedClass, field_name: str) -> InstrumentedAttribute: +def get_column(model: type[Model] | AliasedClass, field_name: str) -> Column: """ Get column from model with validation. @@ -138,7 +137,7 @@ def get_column(model: type[Model] | AliasedClass, field_name: str) -> Instrument return column -def _create_or_filters(column: InstrumentedAttribute, op: str, value: dict[str, Any]) -> list[ColumnElement | None]: +def _create_or_filters(column: Column, op: str, value: dict[str, Any]) -> list[ColumnOperators | None]: """ Create OR filter expressions. @@ -156,9 +155,7 @@ def _create_or_filters(column: InstrumentedAttribute, op: str, value: dict[str, return or_filters -def _create_arithmetic_filters( - column: InstrumentedAttribute, op: str, value: dict[str, Any] -) -> list[ColumnElement | None]: +def _create_arithmetic_filters(column: Column, op: str, value: dict[str, Any]) -> list[ColumnOperators | None]: """ Create arithmetic filter expressions. @@ -184,7 +181,7 @@ def _create_arithmetic_filters( return arithmetic_filters -def _create_and_filters(column: InstrumentedAttribute, op: str, value: Any) -> list[ColumnElement | None]: +def _create_and_filters(column: Column, op: str, value: Any) -> list[ColumnElement | None]: """ Create AND filter expressions. @@ -381,7 +378,7 @@ def apply_join_conditions(model: type[Model], stmt: Select, join_conditions: Joi elif isinstance(join_conditions, dict): for column, join_type in join_conditions.items(): - allowed_join_types = ['inner', 'left', 'full'] + allowed_join_types = ['inner', 'left', 'full'] # SQLAlchemy doesn't support right join if join_type not in allowed_join_types: raise JoinConditionError(f'Invalid join type: {join_type}, only supports {allowed_join_types}') try: diff --git a/tests/test_create.py b/tests/test_create.py index 9ded275..db99d07 100644 --- a/tests/test_create.py +++ b/tests/test_create.py @@ -1,11 +1,13 @@ #!/usr/bin/env python3 # -*- coding: utf-8 -*- +from datetime import datetime + import pytest from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy_crud_plus import CRUDPlus -from tests.models.basic import Ins +from tests.models.basic import Ins, InsPks from tests.schemas.basic import InsCreate @@ -94,3 +96,39 @@ async def test_create_models_with_kwargs(async_db_session: AsyncSession, crud_in assert len(results) == 2 assert all(r.del_flag is True for r in results) + + +@pytest.mark.asyncio +async def test_bulk_create_models_basic(async_db_session: AsyncSession, crud_ins: CRUDPlus[Ins]): + async with async_db_session.begin(): + data = [ + {'name': 'bulk_item_1', 'del_flag': False, 'created_time': datetime.now()}, + {'name': 'bulk_item_2', 'del_flag': True, 'created_time': datetime.now()}, + {'name': 'bulk_item_3', 'del_flag': False, 'created_time': datetime.now()}, + ] + results = await crud_ins.bulk_create_models(async_db_session, data) + + assert len(results) == 3 + assert results[0].name == 'bulk_item_1' + assert results[1].name == 'bulk_item_2' + assert results[2].name == 'bulk_item_3' + assert results[0].del_flag is False + assert results[1].del_flag is True + assert results[2].del_flag is False + + +@pytest.mark.asyncio +async def test_bulk_create_models_composite_keys(async_db_session: AsyncSession, crud_ins_pks: CRUDPlus[InsPks]): + data = [ + {'id': 1000, 'name': 'bulk_pks_1', 'sex': 'male', 'created_time': datetime.now()}, + {'id': 1001, 'name': 'bulk_pks_2', 'sex': 'female', 'created_time': datetime.now()}, + {'id': 1002, 'name': 'bulk_pks_3', 'sex': 'male', 'created_time': datetime.now()}, + ] + + async with async_db_session.begin(): + results = await crud_ins_pks.bulk_create_models(async_db_session, data) + + assert len(results) == 3 + assert results[0].id == 1000 + assert results[0].name == 'bulk_pks_1' + assert results[0].sex == 'male' diff --git a/tests/test_relationships.py b/tests/test_relationships.py index 645b04d..e9ad3cc 100644 --- a/tests/test_relationships.py +++ b/tests/test_relationships.py @@ -100,18 +100,6 @@ async def test_load_strategies_subqueryload( assert user is not None -@pytest.mark.asyncio -async def test_load_strategies_nested_relationship( - async_db_session: AsyncSession, rel_sample_data: dict, rel_crud_user: CRUDPlus[RelUser] -): - users = rel_sample_data['users'] - user = await rel_crud_user.select_model( - async_db_session, users[0].id, load_strategies={'posts.category': 'joinedload'} - ) - - assert user is not None - - @pytest.mark.asyncio async def test_load_strategies_with_select_models( async_db_session: AsyncSession, rel_sample_data: dict, rel_crud_user: CRUDPlus[RelUser] @@ -197,30 +185,6 @@ async def test_self_referencing_parent_load( assert category is not None -@pytest.mark.asyncio -async def test_two_level_nested_relationship( - async_db_session: AsyncSession, rel_sample_data: dict, rel_crud_user: CRUDPlus[RelUser] -): - users = rel_sample_data['users'] - user = await rel_crud_user.select_model( - async_db_session, users[0].id, load_strategies={'posts.category': 'joinedload'} - ) - - assert user is not None - - -@pytest.mark.asyncio -async def test_three_level_nested_relationship( - async_db_session: AsyncSession, rel_sample_data: dict, rel_crud_user: CRUDPlus[RelUser] -): - users = rel_sample_data['users'] - user = await rel_crud_user.select_model( - async_db_session, users[0].id, load_strategies={'posts.category.parent': 'joinedload'} - ) - - assert user is not None - - @pytest.mark.asyncio async def test_combined_load_strategies_and_join_conditions( async_db_session: AsyncSession, rel_sample_data: dict, rel_crud_user: CRUDPlus[RelUser] diff --git a/tests/test_update.py b/tests/test_update.py index fdf5619..5636227 100644 --- a/tests/test_update.py +++ b/tests/test_update.py @@ -5,8 +5,8 @@ from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy_crud_plus import CRUDPlus -from tests.models.basic import Ins -from tests.schemas.basic import InsUpdate +from tests.models.basic import Ins, InsPks +from tests.schemas.basic import InsCreate, InsPksCreate, InsUpdate @pytest.mark.asyncio @@ -172,3 +172,114 @@ async def test_update_model_by_column_multiple_results_error( with pytest.raises(Exception): async with async_db_session.begin(): await crud_ins.update_model_by_column(async_db_session, update_data, del_flag=False) + + +@pytest.mark.asyncio +async def test_bulk_update_models_pk_mode_true(async_db_session: AsyncSession, crud_ins: CRUDPlus[Ins]): + create_data = [ + InsCreate(name='update_test_1', del_flag=False), + InsCreate(name='update_test_2', del_flag=False), + InsCreate(name='update_test_3', del_flag=False), + ] + + async with async_db_session.begin(): + created_items = await crud_ins.create_models(async_db_session, create_data) + + update_data = [ + {'id': created_items[0].id, 'name': 'updated_test_1', 'del_flag': True}, + {'id': created_items[1].id, 'name': 'updated_test_2', 'del_flag': True}, + {'id': created_items[2].id, 'name': 'updated_test_3', 'del_flag': True}, + ] + + async with async_db_session.begin(): + result = await crud_ins.bulk_update_models(async_db_session, update_data, pk_mode=True) + + assert result == 3 + + async with async_db_session.begin(): + updated_item1 = await crud_ins.select_model(async_db_session, created_items[0].id) + updated_item2 = await crud_ins.select_model(async_db_session, created_items[1].id) + updated_item3 = await crud_ins.select_model(async_db_session, created_items[2].id) + + assert updated_item1.name == 'updated_test_1' + assert updated_item1.del_flag is True + assert updated_item2.name == 'updated_test_2' + assert updated_item2.del_flag is True + assert updated_item3.name == 'updated_test_3' + assert updated_item3.del_flag is True + + +@pytest.mark.asyncio +async def test_bulk_update_models_pk_mode_false(async_db_session: AsyncSession, crud_ins: CRUDPlus[Ins]): + create_data = [ + InsCreate(name='filter_test_1', del_flag=False), + InsCreate(name='filter_test_2', del_flag=False), + ] + + async with async_db_session.begin(): + await crud_ins.create_models(async_db_session, create_data) + + update_data = [ + {'name': 'bulk_updated_1'}, + {'name': 'bulk_updated_2'}, + ] + + async with async_db_session.begin(): + result = await crud_ins.bulk_update_models(async_db_session, update_data, pk_mode=False, del_flag=False) + + assert result == 2 + + +@pytest.mark.asyncio +async def test_bulk_update_models_with_pydantic_schema(async_db_session: AsyncSession, crud_ins: CRUDPlus[Ins]): + create_data = [InsCreate(name='schema_test')] + + async with async_db_session.begin(): + created_items = await crud_ins.create_models(async_db_session, create_data) + + update_data = [InsUpdate(name='schema_updated')] + + async with async_db_session.begin(): + result = await crud_ins.bulk_update_models(async_db_session, update_data, pk_mode=False, id=created_items[0].id) + + assert result == 1 + + +@pytest.mark.asyncio +async def test_bulk_update_models_pk_mode_false_no_filters_error( + async_db_session: AsyncSession, crud_ins: CRUDPlus[Ins] +): + """测试 bulk_update_models pk_mode=False 时没有过滤条件的错误""" + update_data = [{'name': 'no_filters'}] + + with pytest.raises(ValueError, match='At least one filter condition must be provided'): + async with async_db_session.begin(): + await crud_ins.bulk_update_models(async_db_session, update_data, pk_mode=False) + + +@pytest.mark.asyncio +async def test_bulk_update_models_composite_keys(async_db_session: AsyncSession, crud_ins_pks: CRUDPlus[InsPks]): + create_data = [ + InsPksCreate(id=2000, name='update_pks_1', sex='male'), + InsPksCreate(id=2001, name='update_pks_2', sex='female'), + ] + + async with async_db_session.begin(): + await crud_ins_pks.create_models(async_db_session, create_data) + + update_data = [ + {'id': 2000, 'sex': 'male', 'name': 'updated_pks_1'}, + {'id': 2001, 'sex': 'female', 'name': 'updated_pks_2'}, + ] + + async with async_db_session.begin(): + result = await crud_ins_pks.bulk_update_models(async_db_session, update_data, pk_mode=True) + + assert result == 2 + + async with async_db_session.begin(): + updated_item1 = await crud_ins_pks.select_model(async_db_session, (2000, 'male')) + updated_item2 = await crud_ins_pks.select_model(async_db_session, (2001, 'female')) + + assert updated_item1.name == 'updated_pks_1' + assert updated_item2.name == 'updated_pks_2' diff --git a/tests/test_utils.py b/tests/test_utils.py index 06bbc7c..bf2d240 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -356,9 +356,8 @@ def test_build_load_strategies_invalid_strategy(): def test_build_load_strategies_invalid_relationship(): - options = build_load_strategies(RelUser, ['nonexistent']) - - assert len(options) == 0 + with pytest.raises(ModelColumnError): + build_load_strategies(RelUser, ['nonexistent']) def test_apply_join_conditions_list(): @@ -384,9 +383,9 @@ def test_apply_join_conditions_dict_left(): def test_apply_join_conditions_right_join(): stmt = select(RelUser) - joined_stmt = apply_join_conditions(RelUser, stmt, {'posts': 'right'}) - assert joined_stmt is not None + with pytest.raises(JoinConditionError): + apply_join_conditions(RelUser, stmt, {'posts': 'right'}) def test_apply_join_conditions_full_join(): @@ -425,6 +424,6 @@ def test_apply_join_conditions_invalid_join_type(): def test_apply_join_conditions_invalid_relationship(): stmt = select(RelUser) - joined_stmt = apply_join_conditions(RelUser, stmt, ['nonexistent']) - assert joined_stmt is not None + with pytest.raises(ModelColumnError): + apply_join_conditions(RelUser, stmt, ['nonexistent'])