2323from collections import defaultdict
2424
2525import sqlalchemy as sa
26- from sqlalchemy .sql import crud
26+ from sqlalchemy .sql import crud , selectable
2727from sqlalchemy .sql import compiler
2828from .types import MutableDict
29- from .sa_version import SA_1_1 , SA_VERSION
29+ from .sa_version import SA_VERSION , SA_1_1 , SA_1_4
30+
31+
32+ INSERT_SELECT_WITHOUT_PARENTHESES_MIN_VERSION = (1 , 0 , 1 )
3033
3134
3235def rewrite_update (clauseelement , multiparams , params ):
@@ -74,7 +77,18 @@ def rewrite_update(clauseelement, multiparams, params):
7477def crate_before_execute (conn , clauseelement , multiparams , params ):
7578 is_crate = type (conn .dialect ).__name__ == 'CrateDialect'
7679 if is_crate and isinstance (clauseelement , sa .sql .expression .Update ):
77- return rewrite_update (clauseelement , multiparams , params )
80+ if SA_VERSION >= SA_1_4 :
81+ multiparams = ([params ],)
82+ params = {}
83+
84+ clauseelement , multiparams , params = rewrite_update (clauseelement , multiparams , params )
85+
86+ if SA_VERSION >= SA_1_4 :
87+ params = multiparams [0 ]
88+ multiparams = []
89+
90+ return clauseelement , multiparams , params
91+
7892 return clauseelement , multiparams , params
7993
8094
@@ -189,9 +203,23 @@ def visit_insert(self, insert_stmt, asfrom=False, **kw):
189203 used to compile <sql.expression.Insert> expressions.
190204
191205 this function wraps insert_from_select statements inside
192- parentheses to be conform with earlier versions of CreateDB.
206+ parentheses to be conform with earlier versions of CreateDB.
207+
208+ According to the changelog, CrateDB >= 1.0.1 already mitigates this requirement:
209+
210+ ``INSERT`` statements now support ``SELECT`` statements without parentheses.
211+ https://crate.io/docs/crate/reference/en/4.3/appendices/release-notes/1.0.1.html
193212 """
194213
214+ # Only CrateDB <= 1.0.0 needs parentheses for ``INSERT INTO ... SELECT ...``.
215+ if self .dialect .server_version_info >= INSERT_SELECT_WITHOUT_PARENTHESES_MIN_VERSION :
216+ return super (CrateCompiler , self ).visit_insert (insert_stmt , asfrom = asfrom , ** kw )
217+
218+ if SA_VERSION >= SA_1_4 :
219+ raise DeprecationWarning (
220+ "CrateDB version < {} not supported with SQLAlchemy 1.4" .format (
221+ INSERT_SELECT_WITHOUT_PARENTHESES_MIN_VERSION ))
222+
195223 self .stack .append (
196224 {'correlate_froms' : set (),
197225 "asfrom_froms" : set (),
@@ -288,6 +316,9 @@ def visit_update(self, update_stmt, **kw):
288316 Parts are taken from the SQLCompiler base class.
289317 """
290318
319+ if SA_VERSION >= SA_1_4 :
320+ return self .visit_update_14 (update_stmt , ** kw )
321+
291322 if not update_stmt .parameters and \
292323 not hasattr (update_stmt , '_crate_specific' ):
293324 return super (CrateCompiler , self ).visit_update (update_stmt , ** kw )
@@ -311,11 +342,14 @@ def visit_update(self, update_stmt, **kw):
311342 update_stmt , table_text
312343 )
313344
314- crud_params = self ._get_crud_params (update_stmt , ** kw )
345+ # CrateDB amendment.
346+ crud_params = self ._get_crud_params (self , update_stmt , ** kw )
315347
316348 text += table_text
317349
318350 text += ' SET '
351+
352+ # CrateDB amendment begin.
319353 include_table = extra_froms and \
320354 self .render_table_with_column_in_update_from
321355
@@ -333,6 +367,7 @@ def visit_update(self, update_stmt, **kw):
333367 set_clauses .append (k + ' = ' + self .process (bindparam ))
334368
335369 text += ', ' .join (set_clauses )
370+ # CrateDB amendment end.
336371
337372 if self .returning or update_stmt ._returning :
338373 if not self .returning :
@@ -368,7 +403,6 @@ def visit_update(self, update_stmt, **kw):
368403
369404 def _get_crud_params (compiler , stmt , ** kw ):
370405 """ extract values from crud parameters
371-
372406 taken from SQLAlchemy's crud module (since 1.0.x) and
373407 adapted for Crate dialect"""
374408
@@ -428,3 +462,295 @@ def _get_crud_params(compiler, stmt, **kw):
428462 values , kw )
429463
430464 return values
465+
466+ def visit_update_14 (self , update_stmt , ** kw ):
467+
468+ compile_state = update_stmt ._compile_state_factory (
469+ update_stmt , self , ** kw
470+ )
471+ update_stmt = compile_state .statement
472+
473+ toplevel = not self .stack
474+ if toplevel :
475+ self .isupdate = True
476+ if not self .compile_state :
477+ self .compile_state = compile_state
478+
479+ extra_froms = compile_state ._extra_froms
480+ is_multitable = bool (extra_froms )
481+
482+ if is_multitable :
483+ # main table might be a JOIN
484+ main_froms = set (selectable ._from_objects (update_stmt .table ))
485+ render_extra_froms = [
486+ f for f in extra_froms if f not in main_froms
487+ ]
488+ correlate_froms = main_froms .union (extra_froms )
489+ else :
490+ render_extra_froms = []
491+ correlate_froms = {update_stmt .table }
492+
493+ self .stack .append (
494+ {
495+ "correlate_froms" : correlate_froms ,
496+ "asfrom_froms" : correlate_froms ,
497+ "selectable" : update_stmt ,
498+ }
499+ )
500+
501+ text = "UPDATE "
502+
503+ if update_stmt ._prefixes :
504+ text += self ._generate_prefixes (
505+ update_stmt , update_stmt ._prefixes , ** kw
506+ )
507+
508+ table_text = self .update_tables_clause (
509+ update_stmt , update_stmt .table , render_extra_froms , ** kw
510+ )
511+
512+ # CrateDB amendment.
513+ crud_params = _get_crud_params_14 (
514+ self , update_stmt , compile_state , ** kw
515+ )
516+
517+ if update_stmt ._hints :
518+ dialect_hints , table_text = self ._setup_crud_hints (
519+ update_stmt , table_text
520+ )
521+ else :
522+ dialect_hints = None
523+
524+ text += table_text
525+
526+ text += " SET "
527+
528+ # CrateDB amendment begin.
529+ include_table = extra_froms and \
530+ self .render_table_with_column_in_update_from
531+
532+ set_clauses = []
533+
534+ for c , expr , value in crud_params :
535+ key = c ._compiler_dispatch (self , include_table = include_table )
536+ clause = key + ' = ' + value
537+ set_clauses .append (clause )
538+
539+ for k , v in compile_state ._dict_parameters .items ():
540+ if isinstance (k , str ) and '[' in k :
541+ bindparam = sa .sql .bindparam (k , v )
542+ clause = k + ' = ' + self .process (bindparam )
543+ set_clauses .append (clause )
544+
545+ text += ', ' .join (set_clauses )
546+ # CrateDB amendment end.
547+
548+ if self .returning or update_stmt ._returning :
549+ if self .returning_precedes_values :
550+ text += " " + self .returning_clause (
551+ update_stmt , self .returning or update_stmt ._returning
552+ )
553+
554+ if extra_froms :
555+ extra_from_text = self .update_from_clause (
556+ update_stmt ,
557+ update_stmt .table ,
558+ render_extra_froms ,
559+ dialect_hints ,
560+ ** kw
561+ )
562+ if extra_from_text :
563+ text += " " + extra_from_text
564+
565+ if update_stmt ._where_criteria :
566+ t = self ._generate_delimited_and_list (
567+ update_stmt ._where_criteria , ** kw
568+ )
569+ if t :
570+ text += " WHERE " + t
571+
572+ limit_clause = self .update_limit_clause (update_stmt )
573+ if limit_clause :
574+ text += " " + limit_clause
575+
576+ if (
577+ self .returning or update_stmt ._returning
578+ ) and not self .returning_precedes_values :
579+ text += " " + self .returning_clause (
580+ update_stmt , self .returning or update_stmt ._returning
581+ )
582+
583+ if self .ctes and toplevel :
584+ text = self ._render_cte_clause () + text
585+
586+ self .stack .pop (- 1 )
587+
588+ return text
589+
590+
591+ def _get_crud_params_14 (compiler , stmt , compile_state , ** kw ):
592+ """create a set of tuples representing column/string pairs for use
593+ in an INSERT or UPDATE statement.
594+
595+ Also generates the Compiled object's postfetch, prefetch, and
596+ returning column collections, used for default handling and ultimately
597+ populating the CursorResult's prefetch_cols() and postfetch_cols()
598+ collections.
599+
600+ """
601+ from sqlalchemy .sql .crud import _key_getters_for_crud_column
602+ from sqlalchemy .sql .crud import _create_bind_param
603+ from sqlalchemy .sql .crud import REQUIRED
604+ from sqlalchemy .sql .crud import _get_stmt_parameter_tuples_params
605+ from sqlalchemy .sql .crud import _get_multitable_params
606+ from sqlalchemy .sql .crud import _scan_insert_from_select_cols
607+ from sqlalchemy .sql .crud import _scan_cols
608+ from sqlalchemy import exc
609+ from sqlalchemy .sql .crud import _extend_values_for_multiparams
610+
611+ compiler .postfetch = []
612+ compiler .insert_prefetch = []
613+ compiler .update_prefetch = []
614+ compiler .returning = []
615+
616+ # getters - these are normally just column.key,
617+ # but in the case of mysql multi-table update, the rules for
618+ # .key must conditionally take tablename into account
619+ (
620+ _column_as_key ,
621+ _getattr_col_key ,
622+ _col_bind_name ,
623+ ) = getters = _key_getters_for_crud_column (compiler , stmt , compile_state )
624+
625+ compiler ._key_getters_for_crud_column = getters
626+
627+ # no parameters in the statement, no parameters in the
628+ # compiled params - return binds for all columns
629+ if compiler .column_keys is None and compile_state ._no_parameters :
630+ return [
631+ (
632+ c ,
633+ compiler .preparer .format_column (c ),
634+ _create_bind_param (compiler , c , None , required = True ),
635+ )
636+ for c in stmt .table .columns
637+ ]
638+
639+ if compile_state ._has_multi_parameters :
640+ spd = compile_state ._multi_parameters [0 ]
641+ stmt_parameter_tuples = list (spd .items ())
642+ elif compile_state ._ordered_values :
643+ spd = compile_state ._dict_parameters
644+ stmt_parameter_tuples = compile_state ._ordered_values
645+ elif compile_state ._dict_parameters :
646+ spd = compile_state ._dict_parameters
647+ stmt_parameter_tuples = list (spd .items ())
648+ else :
649+ stmt_parameter_tuples = spd = None
650+
651+ # if we have statement parameters - set defaults in the
652+ # compiled params
653+ if compiler .column_keys is None :
654+ parameters = {}
655+ elif stmt_parameter_tuples :
656+ parameters = dict (
657+ (_column_as_key (key ), REQUIRED )
658+ for key in compiler .column_keys
659+ if key not in spd
660+ )
661+ else :
662+ parameters = dict (
663+ (_column_as_key (key ), REQUIRED ) for key in compiler .column_keys
664+ )
665+
666+ # create a list of column assignment clauses as tuples
667+ values = []
668+
669+ if stmt_parameter_tuples is not None :
670+ _get_stmt_parameter_tuples_params (
671+ compiler ,
672+ compile_state ,
673+ parameters ,
674+ stmt_parameter_tuples ,
675+ _column_as_key ,
676+ values ,
677+ kw ,
678+ )
679+
680+ check_columns = {}
681+
682+ # special logic that only occurs for multi-table UPDATE
683+ # statements
684+ if compile_state .isupdate and compile_state .is_multitable :
685+ _get_multitable_params (
686+ compiler ,
687+ stmt ,
688+ compile_state ,
689+ stmt_parameter_tuples ,
690+ check_columns ,
691+ _col_bind_name ,
692+ _getattr_col_key ,
693+ values ,
694+ kw ,
695+ )
696+
697+ if compile_state .isinsert and stmt ._select_names :
698+ _scan_insert_from_select_cols (
699+ compiler ,
700+ stmt ,
701+ compile_state ,
702+ parameters ,
703+ _getattr_col_key ,
704+ _column_as_key ,
705+ _col_bind_name ,
706+ check_columns ,
707+ values ,
708+ kw ,
709+ )
710+ else :
711+ _scan_cols (
712+ compiler ,
713+ stmt ,
714+ compile_state ,
715+ parameters ,
716+ _getattr_col_key ,
717+ _column_as_key ,
718+ _col_bind_name ,
719+ check_columns ,
720+ values ,
721+ kw ,
722+ )
723+
724+ # CrateDB amendment.
725+ """
726+ if parameters and stmt_parameter_tuples:
727+ check = (
728+ set(parameters)
729+ .intersection(_column_as_key(k) for k, v in stmt_parameter_tuples)
730+ .difference(check_columns)
731+ )
732+ if check:
733+ raise exc.CompileError(
734+ "Unconsumed column names: %s"
735+ % (", ".join("%s" % (c,) for c in check))
736+ )
737+ """
738+
739+ if compile_state ._has_multi_parameters :
740+ values = _extend_values_for_multiparams (
741+ compiler , stmt , compile_state , values , kw
742+ )
743+ elif not values and compiler .for_executemany :
744+ # convert an "INSERT DEFAULT VALUES"
745+ # into INSERT (firstcol) VALUES (DEFAULT) which can be turned
746+ # into an in-place multi values. This supports
747+ # insert_executemany_returning mode :)
748+ values = [
749+ (
750+ stmt .table .columns [0 ],
751+ compiler .preparer .format_column (stmt .table .columns [0 ]),
752+ "DEFAULT" ,
753+ )
754+ ]
755+
756+ return values
0 commit comments