22# -*- coding: utf-8 -*-
33from __future__ import annotations
44
5- from typing import Any , Generic , Iterable , Sequence
5+ from typing import Any , Generic , Sequence
66
77from sqlalchemy import (
88 Column ,
99 ColumnExpressionArgument ,
10+ Row ,
11+ RowMapping ,
1012 Select ,
1113 delete ,
1214 func ,
15+ insert ,
1316 inspect ,
1417 select ,
1518 update ,
1922from sqlalchemy_crud_plus .errors import CompositePrimaryKeysError , ModelColumnError , MultipleResultsError
2023from sqlalchemy_crud_plus .types import (
2124 CreateSchema ,
22- JoinConditionsConfig ,
23- LoadStrategiesConfig ,
25+ JoinConditions ,
26+ LoadOptions ,
27+ LoadStrategies ,
2428 Model ,
25- QueryOptions ,
2629 SortColumns ,
2730 SortOrders ,
2831 UpdateSchema ,
@@ -95,7 +98,7 @@ async def create_model(
9598 async def create_models (
9699 self ,
97100 session : AsyncSession ,
98- objs : Iterable [CreateSchema ],
101+ objs : list [CreateSchema ],
99102 flush : bool = False ,
100103 commit : bool = False ,
101104 ** kwargs ,
@@ -127,11 +130,41 @@ async def create_models(
127130
128131 return ins_list
129132
133+ async def bulk_create_models (
134+ self ,
135+ session : AsyncSession ,
136+ objs : list [dict [str , Any ]],
137+ render_nulls : bool = False ,
138+ flush : bool = False ,
139+ commit : bool = False ,
140+ ** kwargs ,
141+ ) -> Sequence [Row [Any ] | RowMapping | Any ]:
142+ """
143+ Create new instances of a model.
144+
145+ :param session: The SQLAlchemy async session
146+ :param objs: The dict list containing data to be saved,The dict data should be aligned with the model column
147+ :param render_nulls: render null values instead of ignoring them
148+ :param flush: If `True`, flush all object changes to the database
149+ :param commit: If `True`, commits the transaction immediately
150+ :param kwargs: Additional model data not included in the dict
151+ :return:
152+ """
153+ stmt = insert (self .model ).values (** kwargs ).execution_options (render_nulls = render_nulls ).returning (self .model )
154+ result = await session .execute (stmt , objs )
155+
156+ if flush :
157+ await session .flush ()
158+ if commit :
159+ await session .commit ()
160+
161+ return result .scalars ().all ()
162+
130163 async def count (
131164 self ,
132165 session : AsyncSession ,
133166 * whereclause : ColumnExpressionArgument [bool ],
134- join_conditions : JoinConditionsConfig | None = None ,
167+ join_conditions : JoinConditions | None = None ,
135168 ** kwargs ,
136169 ) -> int :
137170 """
@@ -163,7 +196,7 @@ async def exists(
163196 self ,
164197 session : AsyncSession ,
165198 * whereclause : ColumnExpressionArgument [bool ],
166- join_conditions : JoinConditionsConfig | None = None ,
199+ join_conditions : JoinConditions | None = None ,
167200 ** kwargs ,
168201 ) -> bool :
169202 """
@@ -193,9 +226,9 @@ async def select_model(
193226 session : AsyncSession ,
194227 pk : Any | Sequence [Any ],
195228 * whereclause : ColumnExpressionArgument [bool ],
196- load_options : QueryOptions | None = None ,
197- load_strategies : LoadStrategiesConfig | None = None ,
198- join_conditions : JoinConditionsConfig | None = None ,
229+ load_options : LoadOptions | None = None ,
230+ load_strategies : LoadStrategies | None = None ,
231+ join_conditions : JoinConditions | None = None ,
199232 ** kwargs : Any ,
200233 ) -> Model | None :
201234 """
@@ -236,9 +269,9 @@ async def select_model_by_column(
236269 self ,
237270 session : AsyncSession ,
238271 * whereclause : ColumnExpressionArgument [bool ],
239- load_options : QueryOptions | None = None ,
240- load_strategies : LoadStrategiesConfig | None = None ,
241- join_conditions : JoinConditionsConfig | None = None ,
272+ load_options : LoadOptions | None = None ,
273+ load_strategies : LoadStrategies | None = None ,
274+ join_conditions : JoinConditions | None = None ,
242275 ** kwargs : Any ,
243276 ) -> Model | None :
244277 """
@@ -266,9 +299,9 @@ async def select_model_by_column(
266299 async def select (
267300 self ,
268301 * whereclause : ColumnExpressionArgument [bool ],
269- load_options : QueryOptions | None = None ,
270- load_strategies : LoadStrategiesConfig | None = None ,
271- join_conditions : JoinConditionsConfig | None = None ,
302+ load_options : LoadOptions | None = None ,
303+ load_strategies : LoadStrategies | None = None ,
304+ join_conditions : JoinConditions | None = None ,
272305 ** kwargs ,
273306 ) -> Select :
274307 """
@@ -303,9 +336,9 @@ async def select_order(
303336 sort_columns : SortColumns ,
304337 sort_orders : SortOrders = None ,
305338 * whereclause : ColumnExpressionArgument [bool ],
306- load_options : QueryOptions | None = None ,
307- load_strategies : LoadStrategiesConfig | None = None ,
308- join_conditions : JoinConditionsConfig | None = None ,
339+ load_options : LoadOptions | None = None ,
340+ load_strategies : LoadStrategies | None = None ,
341+ join_conditions : JoinConditions | None = None ,
309342 ** kwargs : Any ,
310343 ) -> Select :
311344 """
@@ -334,9 +367,9 @@ async def select_models(
334367 self ,
335368 session : AsyncSession ,
336369 * whereclause : ColumnExpressionArgument [bool ],
337- load_options : QueryOptions | None = None ,
338- load_strategies : LoadStrategiesConfig | None = None ,
339- join_conditions : JoinConditionsConfig | None = None ,
370+ load_options : LoadOptions | None = None ,
371+ load_strategies : LoadStrategies | None = None ,
372+ join_conditions : JoinConditions | None = None ,
340373 limit : int | None = None ,
341374 offset : int | None = None ,
342375 ** kwargs : Any ,
@@ -376,9 +409,9 @@ async def select_models_order(
376409 sort_columns : SortColumns ,
377410 sort_orders : SortOrders = None ,
378411 * whereclause : ColumnExpressionArgument [bool ],
379- load_options : QueryOptions | None = None ,
380- load_strategies : LoadStrategiesConfig | None = None ,
381- join_conditions : JoinConditionsConfig | None = None ,
412+ load_options : LoadOptions | None = None ,
413+ load_strategies : LoadStrategies | None = None ,
414+ join_conditions : JoinConditions | None = None ,
382415 limit : int | None = None ,
383416 offset : int | None = None ,
384417 ** kwargs : Any ,
@@ -438,9 +471,9 @@ async def update_model(
438471 :return:
439472 """
440473 filters = self ._get_pk_filter (pk )
441- instance_data = obj if isinstance (obj , dict ) else obj .model_dump (exclude_unset = True )
442- instance_data .update (kwargs )
443- stmt = update (self .model ).where (* filters ).values (** instance_data )
474+ data = obj if isinstance (obj , dict ) else obj .model_dump (exclude_unset = True )
475+ data .update (kwargs )
476+ stmt = update (self .model ).where (* filters ).values (** data )
444477 result = await session .execute (stmt )
445478
446479 if flush :
@@ -480,8 +513,8 @@ async def update_model_by_column(
480513 if total_count > 1 :
481514 raise MultipleResultsError (f'Only one record is expected to be updated, found { total_count } records.' )
482515
483- instance_data = obj if isinstance (obj , dict ) else obj .model_dump (exclude_unset = True )
484- stmt = update (self .model ).where (* filters ).values (** instance_data )
516+ data = obj if isinstance (obj , dict ) else obj .model_dump (exclude_unset = True )
517+ stmt = update (self .model ).where (* filters ).values (** data )
485518 result = await session .execute (stmt )
486519
487520 if flush :
@@ -491,6 +524,47 @@ async def update_model_by_column(
491524
492525 return result .rowcount
493526
527+ async def bulk_update_models (
528+ self ,
529+ session : AsyncSession ,
530+ objs : list [UpdateSchema | dict [str , Any ]],
531+ pk_mode : bool = True ,
532+ flush : bool = False ,
533+ commit : bool = False ,
534+ ** kwargs ,
535+ ) -> int :
536+ """
537+ Bulk update multiple instances with different data for each record.
538+ Each update item should have 'pk' key and other fields to update.
539+
540+ :param session: The SQLAlchemy async session
541+ :param objs: To save a list of Pydantic schemas or dict for data
542+ :param pk_mode: Primary key mode, when enabled, the data must contain the primary key data
543+ :param flush: If `True`, flush all object changes to the database
544+ :param commit: If `True`, commits the transaction immediately
545+ :return: Total number of updated records
546+ """
547+ if not pk_mode :
548+ filters = parse_filters (self .model , ** kwargs )
549+
550+ if not filters :
551+ raise ValueError ('At least one filter condition must be provided for update operation' )
552+
553+ datas = [obj if isinstance (obj , dict ) else obj .model_dump (exclude_unset = True ) for obj in objs ]
554+ stmt = update (self .model ).where (* filters )
555+ conn = await session .connection ()
556+ await conn .execute (stmt , datas )
557+ else :
558+ datas = [obj if isinstance (obj , dict ) else obj .model_dump (exclude_unset = True ) for obj in objs ]
559+ await session .execute (update (self .model ), datas )
560+
561+ if flush :
562+ await session .flush ()
563+ if commit :
564+ await session .commit ()
565+
566+ return len (datas )
567+
494568 async def delete_model (
495569 self ,
496570 session : AsyncSession ,
0 commit comments