1
1
import os
2
2
import time
3
3
import webbrowser
4
+ import pydantic
4
5
import rich
5
6
from rich .prompt import Confirm
6
7
7
- from dataclasses import dataclass
8
8
from typing import List , Optional , Dict
9
9
from .utils import dbt_diff_string_template , getLogger
10
10
from pathlib import Path
32
32
from . import connect_to_table , diff_tables , Algorithm
33
33
34
34
35
class TDiffVars(pydantic.BaseModel):
    """Per-model settings gathered for one dev-vs-prod table comparison.

    An instance is built from the dbt parser and the model being diffed, then
    handed to either the local or the cloud diff routine.
    """

    # Non-empty parts of the dev table's qualified name (database, schema,
    # alias); consumers join them with "." to get the full name.
    dev_path: List[str]

    # Same as ``dev_path`` but for the production table.
    prod_path: List[str]

    # Key columns for the diff; the local diff removes them from the set of
    # extra columns it compares.
    primary_keys: List[str]

    # Connection settings taken from the dbt parser; individual values may be
    # None.
    connection: Dict[str, Optional[str]]

    # Thread count taken from the dbt parser; None when it is not set.
    threads: Optional[int] = None

    # Optional filter from the model's datadiff config; the cloud diff sends
    # it as the filter for both tables.
    where_filter: Optional[str] = None

    # When non-empty, the local diff keeps only these columns
    # (matched case-insensitively).
    include_columns: List[str]

    # When non-empty, the local diff drops these columns
    # (matched case-insensitively).
    exclude_columns: List[str]
43
44
44
45
45
46
def dbt_diff (
@@ -122,7 +123,7 @@ def _get_diff_vars(
122
123
config_prod_schema : Optional [str ],
123
124
config_prod_custom_schema : Optional [str ],
124
125
model ,
125
- ) -> DiffVars :
126
+ ) -> TDiffVars :
126
127
dev_database = model .database
127
128
dev_schema = model .schema_
128
129
@@ -156,19 +157,21 @@ def _get_diff_vars(
156
157
dev_qualified_list = [x for x in [dev_database , dev_schema , model .alias ] if x ]
157
158
prod_qualified_list = [x for x in [prod_database , prod_schema , model .alias ] if x ]
158
159
159
- where_filter = None
160
- if model .meta :
161
- try :
162
- where_filter = model .meta ["datafold" ]["datadiff" ]["filter" ]
163
- except KeyError :
164
- pass
165
-
166
- return DiffVars (
167
- dev_qualified_list , prod_qualified_list , primary_keys , dbt_parser .connection , dbt_parser .threads , where_filter
160
+ datadiff_model_config = dbt_parser .get_datadiff_model_config (model .meta )
161
+
162
+ return TDiffVars (
163
+ dev_path = dev_qualified_list ,
164
+ prod_path = prod_qualified_list ,
165
+ primary_keys = primary_keys ,
166
+ connection = dbt_parser .connection ,
167
+ threads = dbt_parser .threads ,
168
+ where_filter = datadiff_model_config .where_filter ,
169
+ include_columns = datadiff_model_config .include_columns ,
170
+ exclude_columns = datadiff_model_config .exclude_columns ,
168
171
)
169
172
170
173
171
- def _local_diff (diff_vars : DiffVars ) -> None :
174
+ def _local_diff (diff_vars : TDiffVars ) -> None :
172
175
column_diffs_str = ""
173
176
dev_qualified_str = "." .join (diff_vars .dev_path )
174
177
prod_qualified_str = "." .join (diff_vars .prod_path )
@@ -189,18 +192,25 @@ def _local_diff(diff_vars: DiffVars) -> None:
189
192
rich .print (diff_output_str )
190
193
return
191
194
192
- mutual_set = set (table1_columns ) & set (table2_columns )
193
- table1_set_diff = list (set (table1_columns ) - set (table2_columns ))
194
- table2_set_diff = list (set (table2_columns ) - set (table1_columns ))
195
+ column_set = set (table1_columns ).intersection (table2_columns )
196
+ table1_diff = set (table1_columns ).difference (table2_columns )
197
+ table2_diff = set (table2_columns ).difference (table1_columns )
198
+
199
+ if table1_diff :
200
+ column_diffs_str += f"Column(s) added: { table1_diff } \n "
201
+
202
+ if table2_diff :
203
+ column_diffs_str += f"Column(s) removed: { table2_diff } \n "
204
+
205
+ column_set = column_set - set (diff_vars .primary_keys )
195
206
196
- if table1_set_diff :
197
- column_diffs_str += "Column(s) added: " + str ( table1_set_diff ) + " \n "
207
+ if diff_vars . include_columns :
208
+ column_set = { x for x in column_set if x . upper () in [ y . upper () for y in diff_vars . include_columns ]}
198
209
199
- if table2_set_diff :
200
- column_diffs_str += "Column(s) removed: " + str ( table2_set_diff ) + " \n "
210
+ if diff_vars . exclude_columns :
211
+ column_set = { x for x in column_set if x . upper () not in [ y . upper () for y in diff_vars . exclude_columns ]}
201
212
202
- mutual_set = mutual_set - set (diff_vars .primary_keys )
203
- extra_columns = tuple (mutual_set )
213
+ extra_columns = tuple (column_set )
204
214
205
215
diff = diff_tables (
206
216
table1 ,
@@ -250,7 +260,7 @@ def _initialize_api() -> Optional[DatafoldAPI]:
250
260
return DatafoldAPI (api_key = api_key , host = datafold_host )
251
261
252
262
253
- def _cloud_diff (diff_vars : DiffVars , datasource_id : int , api : DatafoldAPI ) -> None :
263
+ def _cloud_diff (diff_vars : TDiffVars , datasource_id : int , api : DatafoldAPI ) -> None :
254
264
diff_output_str = _diff_output_base ("." .join (diff_vars .dev_path ), "." .join (diff_vars .prod_path ))
255
265
payload = TCloudApiDataDiff (
256
266
data_source1_id = datasource_id ,
@@ -260,6 +270,8 @@ def _cloud_diff(diff_vars: DiffVars, datasource_id: int, api: DatafoldAPI) -> No
260
270
pk_columns = diff_vars .primary_keys ,
261
271
filter1 = diff_vars .where_filter ,
262
272
filter2 = diff_vars .where_filter ,
273
+ include_columns = diff_vars .include_columns ,
274
+ exclude_columns = diff_vars .exclude_columns ,
263
275
)
264
276
265
277
if is_tracking_enabled ():
0 commit comments