Skip to content
This repository was archived by the owner on May 17, 2024. It is now read-only.

Commit 4c68554

Browse files
authored
Merge pull request #535 from dlawin/issue_518_2
support include/exclude meta config
2 parents 0b2dcec + 818ab62 commit 4c68554

File tree

4 files changed

+325
-148
lines changed

4 files changed

+325
-148
lines changed

data_diff/cloud/datafold_api.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -103,6 +103,8 @@ class TCloudApiDataDiff(pydantic.BaseModel):
103103
pk_columns: List[str]
104104
filter1: Optional[str] = None
105105
filter2: Optional[str] = None
106+
include_columns: Optional[List[str]]
107+
exclude_columns: Optional[List[str]]
106108

107109

108110
class TSummaryResultPrimaryKeyStats(pydantic.BaseModel):

data_diff/dbt.py

Lines changed: 38 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
11
import os
22
import time
33
import webbrowser
4+
import pydantic
45
import rich
56
from rich.prompt import Confirm
67

7-
from dataclasses import dataclass
88
from typing import List, Optional, Dict
99
from .utils import dbt_diff_string_template, getLogger
1010
from pathlib import Path
@@ -32,14 +32,15 @@
3232
from . import connect_to_table, diff_tables, Algorithm
3333

3434

35-
@dataclass
36-
class DiffVars:
35+
class TDiffVars(pydantic.BaseModel):
3736
dev_path: List[str]
3837
prod_path: List[str]
3938
primary_keys: List[str]
40-
connection: Dict[str, str]
41-
threads: Optional[int]
39+
connection: Dict[str, Optional[str]]
40+
threads: Optional[int] = None
4241
where_filter: Optional[str] = None
42+
include_columns: List[str]
43+
exclude_columns: List[str]
4344

4445

4546
def dbt_diff(
@@ -122,7 +123,7 @@ def _get_diff_vars(
122123
config_prod_schema: Optional[str],
123124
config_prod_custom_schema: Optional[str],
124125
model,
125-
) -> DiffVars:
126+
) -> TDiffVars:
126127
dev_database = model.database
127128
dev_schema = model.schema_
128129

@@ -156,19 +157,21 @@ def _get_diff_vars(
156157
dev_qualified_list = [x for x in [dev_database, dev_schema, model.alias] if x]
157158
prod_qualified_list = [x for x in [prod_database, prod_schema, model.alias] if x]
158159

159-
where_filter = None
160-
if model.meta:
161-
try:
162-
where_filter = model.meta["datafold"]["datadiff"]["filter"]
163-
except KeyError:
164-
pass
165-
166-
return DiffVars(
167-
dev_qualified_list, prod_qualified_list, primary_keys, dbt_parser.connection, dbt_parser.threads, where_filter
160+
datadiff_model_config = dbt_parser.get_datadiff_model_config(model.meta)
161+
162+
return TDiffVars(
163+
dev_path=dev_qualified_list,
164+
prod_path=prod_qualified_list,
165+
primary_keys=primary_keys,
166+
connection=dbt_parser.connection,
167+
threads=dbt_parser.threads,
168+
where_filter=datadiff_model_config.where_filter,
169+
include_columns=datadiff_model_config.include_columns,
170+
exclude_columns=datadiff_model_config.exclude_columns,
168171
)
169172

170173

171-
def _local_diff(diff_vars: DiffVars) -> None:
174+
def _local_diff(diff_vars: TDiffVars) -> None:
172175
column_diffs_str = ""
173176
dev_qualified_str = ".".join(diff_vars.dev_path)
174177
prod_qualified_str = ".".join(diff_vars.prod_path)
@@ -189,18 +192,25 @@ def _local_diff(diff_vars: DiffVars) -> None:
189192
rich.print(diff_output_str)
190193
return
191194

192-
mutual_set = set(table1_columns) & set(table2_columns)
193-
table1_set_diff = list(set(table1_columns) - set(table2_columns))
194-
table2_set_diff = list(set(table2_columns) - set(table1_columns))
195+
column_set = set(table1_columns).intersection(table2_columns)
196+
table1_diff = set(table1_columns).difference(table2_columns)
197+
table2_diff = set(table2_columns).difference(table1_columns)
198+
199+
if table1_diff:
200+
column_diffs_str += f"Column(s) added: {table1_diff}\n"
201+
202+
if table2_diff:
203+
column_diffs_str += f"Column(s) removed: {table2_diff}\n"
204+
205+
column_set = column_set - set(diff_vars.primary_keys)
195206

196-
if table1_set_diff:
197-
column_diffs_str += "Column(s) added: " + str(table1_set_diff) + "\n"
207+
if diff_vars.include_columns:
208+
column_set = {x for x in column_set if x.upper() in [y.upper() for y in diff_vars.include_columns]}
198209

199-
if table2_set_diff:
200-
column_diffs_str += "Column(s) removed: " + str(table2_set_diff) + "\n"
210+
if diff_vars.exclude_columns:
211+
column_set = {x for x in column_set if x.upper() not in [y.upper() for y in diff_vars.exclude_columns]}
201212

202-
mutual_set = mutual_set - set(diff_vars.primary_keys)
203-
extra_columns = tuple(mutual_set)
213+
extra_columns = tuple(column_set)
204214

205215
diff = diff_tables(
206216
table1,
@@ -250,7 +260,7 @@ def _initialize_api() -> Optional[DatafoldAPI]:
250260
return DatafoldAPI(api_key=api_key, host=datafold_host)
251261

252262

253-
def _cloud_diff(diff_vars: DiffVars, datasource_id: int, api: DatafoldAPI) -> None:
263+
def _cloud_diff(diff_vars: TDiffVars, datasource_id: int, api: DatafoldAPI) -> None:
254264
diff_output_str = _diff_output_base(".".join(diff_vars.dev_path), ".".join(diff_vars.prod_path))
255265
payload = TCloudApiDataDiff(
256266
data_source1_id=datasource_id,
@@ -260,6 +270,8 @@ def _cloud_diff(diff_vars: DiffVars, datasource_id: int, api: DatafoldAPI) -> No
260270
pk_columns=diff_vars.primary_keys,
261271
filter1=diff_vars.where_filter,
262272
filter2=diff_vars.where_filter,
273+
include_columns=diff_vars.include_columns,
274+
exclude_columns=diff_vars.exclude_columns,
263275
)
264276

265277
if is_tracking_enabled():

data_diff/dbt_parser.py

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from typing import List, Dict, Tuple, Set, Optional
77

88
from packaging.version import parse as parse_version
9+
import pydantic
910

1011
from .utils import getLogger, get_from_dict_with_raise
1112
from .version import __version__
@@ -68,6 +69,12 @@ def legacy_profiles_dir() -> Path:
6869
return Path.home() / ".dbt"
6970

7071

72+
class TDatadiffModelConfig(pydantic.BaseModel):
73+
where_filter: Optional[str] = None
74+
include_columns: List[str] = []
75+
exclude_columns: List[str] = []
76+
77+
7178
class DbtParser:
7279
def __init__(self, profiles_dir_override: str, project_dir_override: str) -> None:
7380
(
@@ -79,7 +86,7 @@ def __init__(self, profiles_dir_override: str, project_dir_override: str) -> Non
7986
) = import_dbt_dependencies()
8087
self.profiles_dir = Path(profiles_dir_override or default_profiles_dir())
8188
self.project_dir = Path(project_dir_override or default_project_dir())
82-
self.connection = None
89+
self.connection = {}
8390
self.project_dict = self.get_project_dict()
8491
self.manifest_obj = self.get_manifest_obj()
8592
self.dbt_user_id = self.manifest_obj.metadata.user_id
@@ -95,6 +102,21 @@ def get_datadiff_variables(self) -> dict:
95102
vars = get_from_dict_with_raise(self.project_dict, "vars", error_message)
96103
return get_from_dict_with_raise(vars, "data_diff", error_message)
97104

105+
def get_datadiff_model_config(self, model_meta: dict) -> TDatadiffModelConfig:
106+
where_filter = None
107+
include_columns = []
108+
exclude_columns = []
109+
110+
if "datafold" in model_meta and "datadiff" in model_meta["datafold"]:
111+
config = model_meta["datafold"]["datadiff"]
112+
where_filter = config.get("filter")
113+
include_columns = config.get("include_columns") or []
114+
exclude_columns = config.get("exclude_columns") or []
115+
116+
return TDatadiffModelConfig(
117+
where_filter=where_filter, include_columns=include_columns, exclude_columns=exclude_columns
118+
)
119+
98120
def get_models(self, dbt_selection: Optional[str] = None):
99121
dbt_version = parse_version(self.dbt_version)
100122
if dbt_selection:

0 commit comments

Comments
 (0)