 from data_diff.query_utils import drop_table
 from data_diff.utils import accumulate
 from data_diff.sqeleton.utils import number_to_human
-from data_diff.sqeleton.queries import table, commit
+from data_diff.sqeleton.queries import table, commit, this, Code
+from data_diff.sqeleton.queries.api import insert_rows_in_batches
 from data_diff.hashdiff_tables import HashDiffer, DEFAULT_BISECTION_THRESHOLD
 from data_diff.table_segment import TableSegment
 from .common import (
@@ -362,32 +363,25 @@ class PaginatedTable:
     # much memory.
     RECORDS_PER_BATCH = 1000000
 
-    def __init__(self, table, conn):
-        self.table = table
+    def __init__(self, table_path, conn):
+        self.table_path = table_path
         self.conn = conn
 
     def __iter__(self):
-        iter = PaginatedTable(self.table, self.conn)
-        iter.last_id = 0
-        iter.values = []
-        iter.value_index = 0
-        return iter
-
-    def __next__(self) -> str:
-        if self.value_index == len(self.values):  # end of current batch
-            query = f"SELECT id, col FROM {self.table} WHERE id > {self.last_id} ORDER BY id ASC LIMIT {self.RECORDS_PER_BATCH}"
-            if isinstance(self.conn, db.Oracle):
-                query = f"SELECT id, col FROM {self.table} WHERE id > {self.last_id} ORDER BY id ASC OFFSET 0 ROWS FETCH NEXT {self.RECORDS_PER_BATCH} ROWS ONLY"
-
-            self.values = self.conn.query(query, list)
-            if len(self.values) == 0:  # we must be done!
-                raise StopIteration
-            self.last_id = self.values[-1][0]
-            self.value_index = 0
-
-        this_value = self.values[self.value_index]
-        self.value_index += 1
-        return this_value
+        last_id = 0
+        while True:
+            query = (
+                table(self.table_path)
+                .select(this.id, this.col)
+                .where(this.id > last_id)
+                .order_by(this.id)
+                .limit(self.RECORDS_PER_BATCH)
+            )
+            rows = self.conn.query(query, list)
+            if not rows:
+                break
+            last_id = rows[-1][0]
+            yield from rows
 
 
 class DateTimeFaker:
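
Note: the commit keeps the keyset pagination pattern (resume each batch strictly after the last `id` seen) but moves it from a stateful `__next__` into a generator, and builds the query with sqeleton's composable API, which is why the Oracle-specific FETCH NEXT branch disappears. A self-contained sketch of the same pattern over sqlite3, with the table name, columns, and tiny batch size invented for illustration:

    import sqlite3

    RECORDS_PER_BATCH = 3  # deliberately tiny, to force several batches

    def paginate(conn):
        # Keyset pagination: every query resumes after the largest id already
        # yielded, so each batch is an index seek rather than an OFFSET scan,
        # and memory stays bounded by one batch.
        last_id = 0
        while True:
            rows = conn.execute(
                "SELECT id, col FROM t WHERE id > ? ORDER BY id LIMIT ?",
                (last_id, RECORDS_PER_BATCH),
            ).fetchall()
            if not rows:
                break
            last_id = rows[-1][0]
            yield from rows

    conn = sqlite3.connect(":memory:")
    conn.execute("CREATE TABLE t (id INTEGER PRIMARY KEY, col TEXT)")
    conn.executemany("INSERT INTO t VALUES (?, ?)", [(i, f"v{i}") for i in range(1, 8)])
    assert [r[0] for r in paginate(conn)] == list(range(1, 8))
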
@@ -560,90 +554,42 @@ def expand_params(testcase_func, param_num, param):
         return name
 
 
-def _insert_to_table(conn, table, values, type):
-    current_n_rows = conn.query(f"SELECT COUNT(*) FROM {table}", int)
+def _insert_to_table(conn, table_path, values, type):
+    tbl = table(table_path)
+
+    current_n_rows = conn.query(tbl.count(), int)
     if current_n_rows == N_SAMPLES:
         assert BENCHMARK, "Table should've been deleted, or we should be in BENCHMARK mode"
         return
     elif current_n_rows > 0:
-        conn.query(drop_table(table))
-        _create_table_with_indexes(conn, table, type)
-
-    if BENCHMARK and N_SAMPLES > 10_000:
-        description = f"{conn.name}: {table}"
-        values = rich.progress.track(values, total=N_SAMPLES, description=description)
-
-    default_insertion_query = f"INSERT INTO {table} (id, col) VALUES "
-    if isinstance(conn, db.Oracle):
-        default_insertion_query = f"INSERT INTO {table} (id, col)"
-
-    batch_size = 8000
-    if isinstance(conn, db.BigQuery):
-        batch_size = 1000
-
-    insertion_query = default_insertion_query
-    selects = []
-    for j, sample in values:
-        if re.search(r"(time zone|tz)", type):
-            sample = sample.replace(tzinfo=timezone.utc)
+        drop_table(conn, table_path)
+        _create_table_with_indexes(conn, table_path, type)
 
-        if isinstance(sample, bytearray):
-            value = f"'{sample.decode()}'"
+    # if BENCHMARK and N_SAMPLES > 10_000:
+    #     description = f"{conn.name}: {table}"
+    #     values = rich.progress.track(values, total=N_SAMPLES, description=description)
 
-        elif type == "boolean":
-            value = str(bool(sample))
+    if type == "boolean":
+        values = [(i, bool(sample)) for i, sample in values]
+    elif re.search(r"(time zone|tz)", type):
+        values = [(i, sample.replace(tzinfo=timezone.utc)) for i, sample in values]
 
-        elif isinstance(conn, db.Clickhouse):
-            if type.startswith("DateTime64"):
-                value = f"'{sample.replace(tzinfo=None)}'"
+    if isinstance(conn, db.Clickhouse):
+        if type.startswith("DateTime64"):
+            values = [(i, f"{sample.replace(tzinfo=None)}") for i, sample in values]
 
-            elif type == "DateTime":
-                sample = sample.replace(tzinfo=None)
-                # Clickhouse's DateTime does not allow to store micro/milli/nano seconds
-                value = f"'{str(sample)[:19]}'"
+        elif type == "DateTime":
+            # Clickhouse's DateTime does not allow to store micro/milli/nano seconds
+            values = [(i, str(sample)[:19]) for i, sample in values]
 
-            elif type.startswith("Decimal"):
-                precision = int(type[8:].rstrip(")").split(",")[1])
-                value = round(sample, precision)
+        elif type.startswith("Decimal("):
+            precision = int(type[8:].rstrip(")").split(",")[1])
+            values = [(i, round(sample, precision)) for i, sample in values]
+    elif isinstance(conn, db.BigQuery) and type == "datetime":
+        values = [(i, Code(f"cast(timestamp '{sample}' as datetime)")) for i, sample in values]
 
-            else:
-                value = f"'{sample}'"
-
-        elif isinstance(sample, (float, Decimal, int)):
-            value = str(sample)
-        elif isinstance(sample, datetime) and isinstance(conn, (db.Presto, db.Oracle, db.Trino)):
-            value = f"timestamp '{sample}'"
-        elif isinstance(sample, datetime) and isinstance(conn, db.BigQuery) and type == "datetime":
-            value = f"cast(timestamp '{sample}' as datetime)"
-
-        else:
-            value = f"'{sample}'"
-
-        if isinstance(conn, db.Oracle):
-            selects.append(f"SELECT {j}, {value} FROM dual")
-        else:
-            insertion_query += f"({j}, {value}),"
-
-        # Some databases want small batch sizes...
-        # Need to also insert on the last row, might not divide cleanly!
-        if j % batch_size == 0 or j == N_SAMPLES:
-            if isinstance(conn, db.Oracle):
-                insertion_query += " UNION ALL ".join(selects)
-                conn.query(insertion_query, None)
-                selects = []
-                insertion_query = default_insertion_query
-            else:
-                conn.query(insertion_query[0:-1], None)
-                insertion_query = default_insertion_query
-
-    if insertion_query != default_insertion_query:
-        # Very bad, but this whole function needs to go
-        if isinstance(conn, db.Oracle):
-            insertion_query += " UNION ALL ".join(selects)
-            conn.query(insertion_query, None)
-        else:
-            conn.query(insertion_query[0:-1], None)
 
+    insert_rows_in_batches(conn, tbl, values, columns=["id", "col"])
     conn.query(commit)
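
Note: `_insert_to_table` no longer splices values into SQL strings per dialect (the `UNION ALL ... FROM dual` contortions for Oracle, manual batch flushing, and quote handling are all gone); values are normalized in plain Python up front and handed to sqeleton's `insert_rows_in_batches` in a single call. A rough stand-in for what such a batching helper might do, sketched with sqlite3 (the helper name, table, and batch size are invented):

    import sqlite3
    from itertools import islice

    def insert_in_batches(conn, rows, batch_size=1000):
        # Chunk a (possibly lazy) row stream and issue one parameterized
        # executemany per chunk, so the whole input never sits in memory
        # and no values are spliced into the SQL text.
        rows = iter(rows)
        while True:
            batch = list(islice(rows, batch_size))
            if not batch:
                break
            conn.executemany("INSERT INTO t (id, col) VALUES (?, ?)", batch)

    conn = sqlite3.connect(":memory:")
    conn.execute("CREATE TABLE t (id INTEGER, col TEXT)")
    insert_in_batches(conn, ((i, f"v{i}") for i in range(1, 10_001)), batch_size=800)
    assert conn.execute("SELECT COUNT(*) FROM t").fetchone()[0] == 10_000
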
@@ -676,17 +622,27 @@ def _create_indexes(conn, table):
         raise (err)
 
 
-def _create_table_with_indexes(conn, table, type):
+def _create_table_with_indexes(conn, table_path, type_):
+    table_name = ".".join(map(conn.dialect.quote, table_path))
+
+    tbl = table(
+        table_path,
+        schema={
+            "id": int,
+            "col": type_,
+        },
+    )
+
     if isinstance(conn, db.Oracle):
-        already_exists = conn.query(f"SELECT COUNT(*) from tab where tname='{table.upper()}'", int) > 0
+        already_exists = conn.query(f"SELECT COUNT(*) from tab where tname='{table_name.upper()}'", int) > 0
         if not already_exists:
-            conn.query(f"CREATE TABLE {table} (id int, col {type})", None)
+            conn.query(tbl.create())
     elif isinstance(conn, db.Clickhouse):
-        conn.query(f"CREATE TABLE {table} (id int, col {type}) engine = Memory;", None)
+        conn.query(f"CREATE TABLE {table_name} (id int, col {type_}) engine = Memory;", None)
     else:
-        conn.query(f"CREATE TABLE IF NOT EXISTS {table} (id int, col {type})", None)
+        conn.query(tbl.create(if_not_exists=True))
 
-    _create_indexes(conn, table)
+    _create_indexes(conn, table_name)
     conn.query(commit)
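
Note: `tbl.create(if_not_exists=True)` lets the dialect emit `CREATE TABLE IF NOT EXISTS`, but Oracle has no such clause, which is why the function still probes Oracle's `tab` catalog view first. The same probe-then-create guard, sketched against sqlite3's `sqlite_master` catalog (the helper name is invented):

    import sqlite3

    def create_if_absent(conn, name, ddl):
        # Emulate CREATE TABLE IF NOT EXISTS by checking the catalog first,
        # the way the diff checks Oracle's `tab` view.
        (count,) = conn.execute(
            "SELECT COUNT(*) FROM sqlite_master WHERE type='table' AND name=?",
            (name,),
        ).fetchone()
        if count == 0:
            conn.execute(ddl)

    conn = sqlite3.connect(":memory:")
    create_if_absent(conn, "t", "CREATE TABLE t (id int, col text)")
    create_if_absent(conn, "t", "CREATE TABLE t (id int, col text)")  # no-op the second time
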
@@ -725,17 +681,15 @@ def test_types(self, source_db, target_db, source_type, target_type, type_category):
         self.src_table_path = src_table_path = src_conn.parse_table_name(src_table_name)
         self.dst_table_path = dst_table_path = dst_conn.parse_table_name(dst_table_name)
-        self.src_table = src_table = ".".join(map(src_conn.dialect.quote, src_table_path))
-        self.dst_table = dst_table = ".".join(map(dst_conn.dialect.quote, dst_table_path))
 
         start = time.monotonic()
         if not BENCHMARK:
             drop_table(src_conn, src_table_path)
-            _create_table_with_indexes(src_conn, src_table, source_type)
-            _insert_to_table(src_conn, src_table, enumerate(sample_values, 1), source_type)
+            _create_table_with_indexes(src_conn, src_table_path, source_type)
+            _insert_to_table(src_conn, src_table_path, enumerate(sample_values, 1), source_type)
         insertion_source_duration = time.monotonic() - start
 
-        values_in_source = PaginatedTable(src_table, src_conn)
+        values_in_source = PaginatedTable(src_table_path, src_conn)
         if source_db is db.Presto or source_db is db.Trino:
             if source_type.startswith("decimal"):
                 values_in_source = ((a, Decimal(b)) for a, b in values_in_source)
@@ -745,8 +699,8 @@ def test_types(self, source_db, target_db, source_type, target_type, type_category):
         start = time.monotonic()
         if not BENCHMARK:
             drop_table(dst_conn, dst_table_path)
-            _create_table_with_indexes(dst_conn, dst_table, target_type)
-            _insert_to_table(dst_conn, dst_table, values_in_source, target_type)
+            _create_table_with_indexes(dst_conn, dst_table_path, target_type)
+            _insert_to_table(dst_conn, dst_table_path, values_in_source, target_type)
         insertion_target_duration = time.monotonic() - start
 
         if type_category == "uuid":
@@ -813,8 +767,8 @@ def test_types(self, source_db, target_db, source_type, target_type, type_category):
             "rows": N_SAMPLES,
             "rows_human": number_to_human(N_SAMPLES),
             "name_human": f"{source_db.__name__}/{sanitize(source_type)} <-> {target_db.__name__}/{sanitize(target_type)}",
-            "src_table": src_table[1:-1],  # remove quotes
-            "target_table": dst_table[1:-1],
+            "src_table": src_table_path,
+            "target_table": dst_table_path,
             "source_type": source_type,
             "target_type": target_type,
             "insertion_source_sec": round(insertion_source_duration, 3),