Skip to content

Commit 4f679d0

Browse files
authored
Add bulk create and update methods (#53)
* Add bulk create and update methods * Update types and tests
1 parent 69c60cf commit 4f679d0

File tree

7 files changed

+358
-143
lines changed

7 files changed

+358
-143
lines changed

sqlalchemy_crud_plus/crud.py

Lines changed: 104 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -2,14 +2,17 @@
22
# -*- coding: utf-8 -*-
33
from __future__ import annotations
44

5-
from typing import Any, Generic, Iterable, Sequence
5+
from typing import Any, Generic, Sequence
66

77
from sqlalchemy import (
88
Column,
99
ColumnExpressionArgument,
10+
Row,
11+
RowMapping,
1012
Select,
1113
delete,
1214
func,
15+
insert,
1316
inspect,
1417
select,
1518
update,
@@ -19,10 +22,10 @@
1922
from sqlalchemy_crud_plus.errors import CompositePrimaryKeysError, ModelColumnError, MultipleResultsError
2023
from sqlalchemy_crud_plus.types import (
2124
CreateSchema,
22-
JoinConditionsConfig,
23-
LoadStrategiesConfig,
25+
JoinConditions,
26+
LoadOptions,
27+
LoadStrategies,
2428
Model,
25-
QueryOptions,
2629
SortColumns,
2730
SortOrders,
2831
UpdateSchema,
@@ -95,7 +98,7 @@ async def create_model(
9598
async def create_models(
9699
self,
97100
session: AsyncSession,
98-
objs: Iterable[CreateSchema],
101+
objs: list[CreateSchema],
99102
flush: bool = False,
100103
commit: bool = False,
101104
**kwargs,
@@ -127,11 +130,41 @@ async def create_models(
127130

128131
return ins_list
129132

133+
async def bulk_create_models(
134+
self,
135+
session: AsyncSession,
136+
objs: list[dict[str, Any]],
137+
render_nulls: bool = False,
138+
flush: bool = False,
139+
commit: bool = False,
140+
**kwargs,
141+
) -> Sequence[Row[Any] | RowMapping | Any]:
142+
"""
143+
Create new instances of a model.
144+
145+
:param session: The SQLAlchemy async session
146+
:param objs: The dict list containing data to be saved,The dict data should be aligned with the model column
147+
:param render_nulls: render null values instead of ignoring them
148+
:param flush: If `True`, flush all object changes to the database
149+
:param commit: If `True`, commits the transaction immediately
150+
:param kwargs: Additional model data not included in the dict
151+
:return:
152+
"""
153+
stmt = insert(self.model).values(**kwargs).execution_options(render_nulls=render_nulls).returning(self.model)
154+
result = await session.execute(stmt, objs)
155+
156+
if flush:
157+
await session.flush()
158+
if commit:
159+
await session.commit()
160+
161+
return result.scalars().all()
162+
130163
async def count(
131164
self,
132165
session: AsyncSession,
133166
*whereclause: ColumnExpressionArgument[bool],
134-
join_conditions: JoinConditionsConfig | None = None,
167+
join_conditions: JoinConditions | None = None,
135168
**kwargs,
136169
) -> int:
137170
"""
@@ -163,7 +196,7 @@ async def exists(
163196
self,
164197
session: AsyncSession,
165198
*whereclause: ColumnExpressionArgument[bool],
166-
join_conditions: JoinConditionsConfig | None = None,
199+
join_conditions: JoinConditions | None = None,
167200
**kwargs,
168201
) -> bool:
169202
"""
@@ -193,9 +226,9 @@ async def select_model(
193226
session: AsyncSession,
194227
pk: Any | Sequence[Any],
195228
*whereclause: ColumnExpressionArgument[bool],
196-
load_options: QueryOptions | None = None,
197-
load_strategies: LoadStrategiesConfig | None = None,
198-
join_conditions: JoinConditionsConfig | None = None,
229+
load_options: LoadOptions | None = None,
230+
load_strategies: LoadStrategies | None = None,
231+
join_conditions: JoinConditions | None = None,
199232
**kwargs: Any,
200233
) -> Model | None:
201234
"""
@@ -236,9 +269,9 @@ async def select_model_by_column(
236269
self,
237270
session: AsyncSession,
238271
*whereclause: ColumnExpressionArgument[bool],
239-
load_options: QueryOptions | None = None,
240-
load_strategies: LoadStrategiesConfig | None = None,
241-
join_conditions: JoinConditionsConfig | None = None,
272+
load_options: LoadOptions | None = None,
273+
load_strategies: LoadStrategies | None = None,
274+
join_conditions: JoinConditions | None = None,
242275
**kwargs: Any,
243276
) -> Model | None:
244277
"""
@@ -266,9 +299,9 @@ async def select_model_by_column(
266299
async def select(
267300
self,
268301
*whereclause: ColumnExpressionArgument[bool],
269-
load_options: QueryOptions | None = None,
270-
load_strategies: LoadStrategiesConfig | None = None,
271-
join_conditions: JoinConditionsConfig | None = None,
302+
load_options: LoadOptions | None = None,
303+
load_strategies: LoadStrategies | None = None,
304+
join_conditions: JoinConditions | None = None,
272305
**kwargs,
273306
) -> Select:
274307
"""
@@ -303,9 +336,9 @@ async def select_order(
303336
sort_columns: SortColumns,
304337
sort_orders: SortOrders = None,
305338
*whereclause: ColumnExpressionArgument[bool],
306-
load_options: QueryOptions | None = None,
307-
load_strategies: LoadStrategiesConfig | None = None,
308-
join_conditions: JoinConditionsConfig | None = None,
339+
load_options: LoadOptions | None = None,
340+
load_strategies: LoadStrategies | None = None,
341+
join_conditions: JoinConditions | None = None,
309342
**kwargs: Any,
310343
) -> Select:
311344
"""
@@ -334,9 +367,9 @@ async def select_models(
334367
self,
335368
session: AsyncSession,
336369
*whereclause: ColumnExpressionArgument[bool],
337-
load_options: QueryOptions | None = None,
338-
load_strategies: LoadStrategiesConfig | None = None,
339-
join_conditions: JoinConditionsConfig | None = None,
370+
load_options: LoadOptions | None = None,
371+
load_strategies: LoadStrategies | None = None,
372+
join_conditions: JoinConditions | None = None,
340373
limit: int | None = None,
341374
offset: int | None = None,
342375
**kwargs: Any,
@@ -376,9 +409,9 @@ async def select_models_order(
376409
sort_columns: SortColumns,
377410
sort_orders: SortOrders = None,
378411
*whereclause: ColumnExpressionArgument[bool],
379-
load_options: QueryOptions | None = None,
380-
load_strategies: LoadStrategiesConfig | None = None,
381-
join_conditions: JoinConditionsConfig | None = None,
412+
load_options: LoadOptions | None = None,
413+
load_strategies: LoadStrategies | None = None,
414+
join_conditions: JoinConditions | None = None,
382415
limit: int | None = None,
383416
offset: int | None = None,
384417
**kwargs: Any,
@@ -438,9 +471,9 @@ async def update_model(
438471
:return:
439472
"""
440473
filters = self._get_pk_filter(pk)
441-
instance_data = obj if isinstance(obj, dict) else obj.model_dump(exclude_unset=True)
442-
instance_data.update(kwargs)
443-
stmt = update(self.model).where(*filters).values(**instance_data)
474+
data = obj if isinstance(obj, dict) else obj.model_dump(exclude_unset=True)
475+
data.update(kwargs)
476+
stmt = update(self.model).where(*filters).values(**data)
444477
result = await session.execute(stmt)
445478

446479
if flush:
@@ -480,8 +513,8 @@ async def update_model_by_column(
480513
if total_count > 1:
481514
raise MultipleResultsError(f'Only one record is expected to be updated, found {total_count} records.')
482515

483-
instance_data = obj if isinstance(obj, dict) else obj.model_dump(exclude_unset=True)
484-
stmt = update(self.model).where(*filters).values(**instance_data)
516+
data = obj if isinstance(obj, dict) else obj.model_dump(exclude_unset=True)
517+
stmt = update(self.model).where(*filters).values(**data)
485518
result = await session.execute(stmt)
486519

487520
if flush:
@@ -491,6 +524,47 @@ async def update_model_by_column(
491524

492525
return result.rowcount
493526

527+
async def bulk_update_models(
528+
self,
529+
session: AsyncSession,
530+
objs: list[UpdateSchema | dict[str, Any]],
531+
pk_mode: bool = True,
532+
flush: bool = False,
533+
commit: bool = False,
534+
**kwargs,
535+
) -> int:
536+
"""
537+
Bulk update multiple instances with different data for each record.
538+
Each update item should have 'pk' key and other fields to update.
539+
540+
:param session: The SQLAlchemy async session
541+
:param objs: To save a list of Pydantic schemas or dict for data
542+
:param pk_mode: Primary key mode, when enabled, the data must contain the primary key data
543+
:param flush: If `True`, flush all object changes to the database
544+
:param commit: If `True`, commits the transaction immediately
545+
:return: Total number of updated records
546+
"""
547+
if not pk_mode:
548+
filters = parse_filters(self.model, **kwargs)
549+
550+
if not filters:
551+
raise ValueError('At least one filter condition must be provided for update operation')
552+
553+
datas = [obj if isinstance(obj, dict) else obj.model_dump(exclude_unset=True) for obj in objs]
554+
stmt = update(self.model).where(*filters)
555+
conn = await session.connection()
556+
await conn.execute(stmt, datas)
557+
else:
558+
datas = [obj if isinstance(obj, dict) else obj.model_dump(exclude_unset=True) for obj in objs]
559+
await session.execute(update(self.model), datas)
560+
561+
if flush:
562+
await session.flush()
563+
if commit:
564+
await session.commit()
565+
566+
return len(datas)
567+
494568
async def delete_model(
495569
self,
496570
session: AsyncSession,

sqlalchemy_crud_plus/types.py

Lines changed: 36 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -8,36 +8,52 @@
88
from sqlalchemy.orm import DeclarativeBase
99
from sqlalchemy.sql.base import ExecutableOption
1010

11-
# Base type variables for generic CRUD operations
1211
Model = TypeVar('Model', bound=DeclarativeBase)
1312
CreateSchema = TypeVar('CreateSchema', bound=BaseModel)
1413
UpdateSchema = TypeVar('UpdateSchema', bound=BaseModel)
1514

16-
# SQLAlchemy relationship loading strategies
17-
LoadingStrategy = Literal[
18-
'selectinload', # SELECT IN loading (recommended for one-to-many)
19-
'joinedload', # JOIN loading (recommended for one-to-one)
20-
'subqueryload', # Subquery loading (for large datasets)
21-
'contains_eager', # Use with explicit JOINs
22-
'raiseload', # Prevent lazy loading
23-
'noload', # Don't load relationship
15+
# https://docs.sqlalchemy.org/en/20/orm/queryguide/relationships.html#relationship-loader-api
16+
RelationshipLoadingStrategyType = Literal[
17+
'contains_eager',
18+
'defaultload',
19+
'immediateload',
20+
'joinedload',
21+
'lazyload',
22+
'noload',
23+
'raiseload',
24+
'selectinload',
25+
'subqueryload',
26+
# Load
27+
'defer',
28+
'load_only',
29+
'selectin_polymorphic',
30+
'undefer',
31+
'undefer_group',
32+
'with_expression',
2433
]
2534

26-
# SQL JOIN types
35+
# https://docs.sqlalchemy.org/en/20/orm/queryguide/columns.html#column-loading-api
36+
ColumnLoadingStrategyType = Literal[
37+
'defer',
38+
'deferred',
39+
'load_only',
40+
'query_expression',
41+
'undefer',
42+
'undefer_group',
43+
'with_expression',
44+
]
45+
46+
LoadStrategies = list[str] | dict[str, RelationshipLoadingStrategyType] | dict[str, ColumnLoadingStrategyType]
47+
2748
JoinType = Literal[
28-
'inner', # INNER JOIN
29-
'left', # LEFT OUTER JOIN
30-
'right', # RIGHT OUTER JOIN
31-
'full', # FULL OUTER JOIN
49+
'inner',
50+
'left',
51+
'full',
3252
]
3353

34-
# Configuration for relationship loading strategies
35-
LoadStrategiesConfig = list[str] | dict[str, LoadingStrategy]
54+
JoinConditions = list[str] | dict[str, JoinType]
3655

37-
# Configuration for JOIN conditions
38-
JoinConditionsConfig = list[str] | dict[str, JoinType]
56+
LoadOptions = list[ExecutableOption]
3957

40-
# Query configuration types
4158
SortColumns = str | list[str]
4259
SortOrders = str | list[str] | None
43-
QueryOptions = list[ExecutableOption]

0 commit comments

Comments
 (0)