2323from collections import defaultdict
2424
2525import sqlalchemy as sa
26- from sqlalchemy .sql import crud
26+ from sqlalchemy .sql import crud , selectable
2727from sqlalchemy .sql import compiler
2828from .types import MutableDict
29+ from .sa_version import SA_VERSION , SA_1_4
2930
3031
3132def rewrite_update (clauseelement , multiparams , params ):
@@ -73,7 +74,16 @@ def rewrite_update(clauseelement, multiparams, params):
7374def crate_before_execute (conn , clauseelement , multiparams , params ):
7475 is_crate = type (conn .dialect ).__name__ == 'CrateDialect'
7576 if is_crate and isinstance (clauseelement , sa .sql .expression .Update ):
76- return rewrite_update (clauseelement , multiparams , params )
77+ if SA_VERSION >= SA_1_4 :
78+ multiparams = ([params ],)
79+ params = {}
80+
81+ clauseelement , multiparams , params = rewrite_update (clauseelement , multiparams , params )
82+
83+ if SA_VERSION >= SA_1_4 :
84+ params = multiparams [0 ]
85+ multiparams = []
86+
7787 return clauseelement , multiparams , params
7888
7989
@@ -189,6 +199,9 @@ def visit_update(self, update_stmt, **kw):
189199 Parts are taken from the SQLCompiler base class.
190200 """
191201
202+ if SA_VERSION >= SA_1_4 :
203+ return self .visit_update_14 (update_stmt , ** kw )
204+
192205 if not update_stmt .parameters and \
193206 not hasattr (update_stmt , '_crate_specific' ):
194207 return super (CrateCompiler , self ).visit_update (update_stmt , ** kw )
@@ -212,11 +225,14 @@ def visit_update(self, update_stmt, **kw):
212225 update_stmt , table_text
213226 )
214227
228+ # CrateDB amendment.
215229 crud_params = self ._get_crud_params (update_stmt , ** kw )
216230
217231 text += table_text
218232
219233 text += ' SET '
234+
235+ # CrateDB amendment begin.
220236 include_table = extra_froms and \
221237 self .render_table_with_column_in_update_from
222238
@@ -234,6 +250,7 @@ def visit_update(self, update_stmt, **kw):
234250 set_clauses .append (k + ' = ' + self .process (bindparam ))
235251
236252 text += ', ' .join (set_clauses )
253+ # CrateDB amendment end.
237254
238255 if self .returning or update_stmt ._returning :
239256 if not self .returning :
@@ -269,7 +286,6 @@ def visit_update(self, update_stmt, **kw):
269286
270287 def _get_crud_params (compiler , stmt , ** kw ):
271288 """ extract values from crud parameters
272-
273289 taken from SQLAlchemy's crud module (since 1.0.x) and
274290 adapted for Crate dialect"""
275291
@@ -325,3 +341,298 @@ def _get_crud_params(compiler, stmt, **kw):
325341 values , kw )
326342
327343 return values
344+
345+ def visit_update_14 (self , update_stmt , ** kw ):
346+
347+ compile_state = update_stmt ._compile_state_factory (
348+ update_stmt , self , ** kw
349+ )
350+ update_stmt = compile_state .statement
351+
352+ toplevel = not self .stack
353+ if toplevel :
354+ self .isupdate = True
355+ if not self .compile_state :
356+ self .compile_state = compile_state
357+
358+ extra_froms = compile_state ._extra_froms
359+ is_multitable = bool (extra_froms )
360+
361+ if is_multitable :
362+ # main table might be a JOIN
363+ main_froms = set (selectable ._from_objects (update_stmt .table ))
364+ render_extra_froms = [
365+ f for f in extra_froms if f not in main_froms
366+ ]
367+ correlate_froms = main_froms .union (extra_froms )
368+ else :
369+ render_extra_froms = []
370+ correlate_froms = {update_stmt .table }
371+
372+ self .stack .append (
373+ {
374+ "correlate_froms" : correlate_froms ,
375+ "asfrom_froms" : correlate_froms ,
376+ "selectable" : update_stmt ,
377+ }
378+ )
379+
380+ text = "UPDATE "
381+
382+ if update_stmt ._prefixes :
383+ text += self ._generate_prefixes (
384+ update_stmt , update_stmt ._prefixes , ** kw
385+ )
386+
387+ table_text = self .update_tables_clause (
388+ update_stmt , update_stmt .table , render_extra_froms , ** kw
389+ )
390+
391+ # CrateDB amendment.
392+ crud_params = _get_crud_params_14 (
393+ self , update_stmt , compile_state , ** kw
394+ )
395+
396+ if update_stmt ._hints :
397+ dialect_hints , table_text = self ._setup_crud_hints (
398+ update_stmt , table_text
399+ )
400+ else :
401+ dialect_hints = None
402+
403+ text += table_text
404+
405+ text += " SET "
406+
407+ # CrateDB amendment begin.
408+ include_table = extra_froms and \
409+ self .render_table_with_column_in_update_from
410+
411+ set_clauses = []
412+
413+ for c , expr , value in crud_params :
414+ key = c ._compiler_dispatch (self , include_table = include_table )
415+ clause = key + ' = ' + value
416+ set_clauses .append (clause )
417+
418+ for k , v in compile_state ._dict_parameters .items ():
419+ if isinstance (k , str ) and '[' in k :
420+ bindparam = sa .sql .bindparam (k , v )
421+ clause = k + ' = ' + self .process (bindparam )
422+ set_clauses .append (clause )
423+
424+ text += ', ' .join (set_clauses )
425+ # CrateDB amendment end.
426+
427+ if self .returning or update_stmt ._returning :
428+ if self .returning_precedes_values :
429+ text += " " + self .returning_clause (
430+ update_stmt , self .returning or update_stmt ._returning
431+ )
432+
433+ if extra_froms :
434+ extra_from_text = self .update_from_clause (
435+ update_stmt ,
436+ update_stmt .table ,
437+ render_extra_froms ,
438+ dialect_hints ,
439+ ** kw
440+ )
441+ if extra_from_text :
442+ text += " " + extra_from_text
443+
444+ if update_stmt ._where_criteria :
445+ t = self ._generate_delimited_and_list (
446+ update_stmt ._where_criteria , ** kw
447+ )
448+ if t :
449+ text += " WHERE " + t
450+
451+ limit_clause = self .update_limit_clause (update_stmt )
452+ if limit_clause :
453+ text += " " + limit_clause
454+
455+ if (
456+ self .returning or update_stmt ._returning
457+ ) and not self .returning_precedes_values :
458+ text += " " + self .returning_clause (
459+ update_stmt , self .returning or update_stmt ._returning
460+ )
461+
462+ if self .ctes and toplevel :
463+ text = self ._render_cte_clause () + text
464+
465+ self .stack .pop (- 1 )
466+
467+ return text
468+
469+
470+ def _get_crud_params_14 (compiler , stmt , compile_state , ** kw ):
471+ """create a set of tuples representing column/string pairs for use
472+ in an INSERT or UPDATE statement.
473+
474+ Also generates the Compiled object's postfetch, prefetch, and
475+ returning column collections, used for default handling and ultimately
476+ populating the CursorResult's prefetch_cols() and postfetch_cols()
477+ collections.
478+
479+ """
480+ from sqlalchemy .sql .crud import _key_getters_for_crud_column
481+ from sqlalchemy .sql .crud import _create_bind_param
482+ from sqlalchemy .sql .crud import REQUIRED
483+ from sqlalchemy .sql .crud import _get_stmt_parameter_tuples_params
484+ from sqlalchemy .sql .crud import _get_multitable_params
485+ from sqlalchemy .sql .crud import _scan_insert_from_select_cols
486+ from sqlalchemy .sql .crud import _scan_cols
487+ from sqlalchemy import exc # noqa: F401
488+ from sqlalchemy .sql .crud import _extend_values_for_multiparams
489+
490+ compiler .postfetch = []
491+ compiler .insert_prefetch = []
492+ compiler .update_prefetch = []
493+ compiler .returning = []
494+
495+ # getters - these are normally just column.key,
496+ # but in the case of mysql multi-table update, the rules for
497+ # .key must conditionally take tablename into account
498+ (
499+ _column_as_key ,
500+ _getattr_col_key ,
501+ _col_bind_name ,
502+ ) = getters = _key_getters_for_crud_column (compiler , stmt , compile_state )
503+
504+ compiler ._key_getters_for_crud_column = getters
505+
506+ # no parameters in the statement, no parameters in the
507+ # compiled params - return binds for all columns
508+ if compiler .column_keys is None and compile_state ._no_parameters :
509+ return [
510+ (
511+ c ,
512+ compiler .preparer .format_column (c ),
513+ _create_bind_param (compiler , c , None , required = True ),
514+ )
515+ for c in stmt .table .columns
516+ ]
517+
518+ if compile_state ._has_multi_parameters :
519+ spd = compile_state ._multi_parameters [0 ]
520+ stmt_parameter_tuples = list (spd .items ())
521+ elif compile_state ._ordered_values :
522+ spd = compile_state ._dict_parameters
523+ stmt_parameter_tuples = compile_state ._ordered_values
524+ elif compile_state ._dict_parameters :
525+ spd = compile_state ._dict_parameters
526+ stmt_parameter_tuples = list (spd .items ())
527+ else :
528+ stmt_parameter_tuples = spd = None
529+
530+ # if we have statement parameters - set defaults in the
531+ # compiled params
532+ if compiler .column_keys is None :
533+ parameters = {}
534+ elif stmt_parameter_tuples :
535+ parameters = dict (
536+ (_column_as_key (key ), REQUIRED )
537+ for key in compiler .column_keys
538+ if key not in spd
539+ )
540+ else :
541+ parameters = dict (
542+ (_column_as_key (key ), REQUIRED ) for key in compiler .column_keys
543+ )
544+
545+ # create a list of column assignment clauses as tuples
546+ values = []
547+
548+ if stmt_parameter_tuples is not None :
549+ _get_stmt_parameter_tuples_params (
550+ compiler ,
551+ compile_state ,
552+ parameters ,
553+ stmt_parameter_tuples ,
554+ _column_as_key ,
555+ values ,
556+ kw ,
557+ )
558+
559+ check_columns = {}
560+
561+ # special logic that only occurs for multi-table UPDATE
562+ # statements
563+ if compile_state .isupdate and compile_state .is_multitable :
564+ _get_multitable_params (
565+ compiler ,
566+ stmt ,
567+ compile_state ,
568+ stmt_parameter_tuples ,
569+ check_columns ,
570+ _col_bind_name ,
571+ _getattr_col_key ,
572+ values ,
573+ kw ,
574+ )
575+
576+ if compile_state .isinsert and stmt ._select_names :
577+ _scan_insert_from_select_cols (
578+ compiler ,
579+ stmt ,
580+ compile_state ,
581+ parameters ,
582+ _getattr_col_key ,
583+ _column_as_key ,
584+ _col_bind_name ,
585+ check_columns ,
586+ values ,
587+ kw ,
588+ )
589+ else :
590+ _scan_cols (
591+ compiler ,
592+ stmt ,
593+ compile_state ,
594+ parameters ,
595+ _getattr_col_key ,
596+ _column_as_key ,
597+ _col_bind_name ,
598+ check_columns ,
599+ values ,
600+ kw ,
601+ )
602+
603+ # CrateDB amendment.
604+ # The rewriting logic in `rewrite_update` and `visit_update` needs
605+ # adjustments here in order to prevent `sqlalchemy.exc.CompileError:
606+ # Unconsumed column names: characters_name, data['nested']`
607+ """
608+ if parameters and stmt_parameter_tuples:
609+ check = (
610+ set(parameters)
611+ .intersection(_column_as_key(k) for k, v in stmt_parameter_tuples)
612+ .difference(check_columns)
613+ )
614+ if check:
615+ raise exc.CompileError(
616+ "Unconsumed column names: %s"
617+ % (", ".join("%s" % (c,) for c in check))
618+ )
619+ """
620+
621+ if compile_state ._has_multi_parameters :
622+ values = _extend_values_for_multiparams (
623+ compiler , stmt , compile_state , values , kw
624+ )
625+ elif not values and compiler .for_executemany :
626+ # convert an "INSERT DEFAULT VALUES"
627+ # into INSERT (firstcol) VALUES (DEFAULT) which can be turned
628+ # into an in-place multi values. This supports
629+ # insert_executemany_returning mode :)
630+ values = [
631+ (
632+ stmt .table .columns [0 ],
633+ compiler .preparer .format_column (stmt .table .columns [0 ]),
634+ "DEFAULT" ,
635+ )
636+ ]
637+
638+ return values
0 commit comments