Skip to content
This repository was archived by the owner on May 17, 2024. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 17 additions & 7 deletions data_diff/dbt.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,8 @@
columns_added_template,
columns_removed_template,
no_differences_template,
columns_type_changed_template,
)
from pathlib import Path

import keyring

Expand Down Expand Up @@ -191,26 +191,36 @@ def _local_diff(diff_vars: TDiffVars) -> None:
diff_vars.connection, prod_qualified_str, tuple(diff_vars.primary_keys), diff_vars.threads
)

table1_columns = list(table1.get_schema())
table1_columns = table1.get_schema()
try:
table2_columns = list(table2.get_schema())
table2_columns = table2.get_schema()
# Not ideal, but we don't have more specific exceptions yet
except Exception as ex:
logger.debug(ex)
diff_output_str += "[red]New model or no access to prod table.[/] \n"
rich.print(diff_output_str)
return

column_set = set(table1_columns).intersection(table2_columns)
columns_added = set(table1_columns).difference(table2_columns)
columns_removed = set(table2_columns).difference(table1_columns)
table1_column_names = set(table1_columns.keys())
table2_column_names = set(table2_columns.keys())
column_set = table1_column_names.intersection(table2_column_names)
columns_added = table1_column_names.difference(table2_column_names)
columns_removed = table2_column_names.difference(table1_column_names)
# The column type is the element at index 1 of each schema tuple
columns_type_changed = {
k for k, v in table1_columns.items() if k in table2_columns and v[1] != table2_columns[k][1]
}

if columns_added:
diff_output_str += columns_added_template(columns_added)

if columns_removed:
diff_output_str += columns_removed_template(columns_removed)

if columns_type_changed:
diff_output_str += columns_type_changed_template(columns_type_changed)
column_set = column_set.difference(columns_type_changed)

column_set = column_set - set(diff_vars.primary_keys)

if diff_vars.include_columns:
Expand Down Expand Up @@ -321,7 +331,7 @@ def _cloud_diff(diff_vars: TDiffVars, datasource_id: int, api: DatafoldAPI, org_
diff_output_str += columns_removed_template(columns_removed)

if column_type_changes:
diff_output_str += "Type change: " + str(column_type_changes) + "\n"
diff_output_str += columns_type_changed_template(column_type_changes)

if any([rows_added_count, rows_removed_count, rows_updated]):
diff_output = dbt_diff_string_template(
Expand Down
17 changes: 11 additions & 6 deletions data_diff/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,14 +171,19 @@ def diffs_are_equiv_jsons(diff: list, json_cols: dict):
return match, overriden_diff_cols


def columns_removed_template(columns_removed) -> str:
    """Build the "Column(s) removed" line of the diff summary.

    Args:
        columns_removed: The removed column names. The value is interpolated
            using its default string representation, so a set renders as
            ``{'a'}`` and a list as ``['a']``.

    Returns:
        The summary line, terminated by a newline character.
    """
    return f"Column(s) removed: {columns_removed}\n"


def columns_added_template(columns_added) -> str:
    """Build the "Column(s) added" line of the diff summary.

    Args:
        columns_added: The added column names. The value is interpolated
            using its default string representation, so a set renders as
            ``{'a'}`` and a list as ``['a']``.

    Returns:
        The summary line, terminated by a newline character.
    """
    return f"Column(s) added: {columns_added}\n"


def columns_type_changed_template(columns_type_changed) -> str:
    """Build the "Type change" line of the diff summary.

    Args:
        columns_type_changed: The columns whose type changed. The value is
            interpolated using its default string representation, so a set
            renders as ``{'a'}`` and a list as ``['a']``.

    Returns:
        The summary line, terminated by a newline character.
    """
    return f"Type change: {columns_type_changed}\n"


def no_differences_template() -> str:
Expand Down
69 changes: 48 additions & 21 deletions tests/test_dbt.py
Original file line number Diff line number Diff line change
Expand Up @@ -442,21 +442,6 @@ def test_get_connection_no_type(self, mock_open, mock_profile_renderer, mock_yam
_, _ = DbtParser.get_connection_creds(mock_self)


# Example diff-results payload: a "pks" section listing exclusive keys and a
# "values" section with row counts and a per-column match value.
# NOTE(review): the hunk header above (21 lines -> 6 lines) suggests this
# constant is deleted by the change under review -- confirm before relying on it.
EXAMPLE_DIFF_RESULTS = {
"pks": {"exclusives": [5, 3]},
"values": {
"rows_with_differences": 2,
"total_rows": 10,
"columns_diff_stats": [
{"column_name": "name", "match": 80.0},
{"column_name": "age", "match": 100.0},
{"column_name": "city", "match": 0.0},
{"column_name": "country", "match": 100.0},
],
},
}


class TestDbtDiffer(unittest.TestCase):
# Set DATA_DIFF_DBT_PROJ to use your own dbt project, otherwise uses the duckdb project in tests/dbt_artifacts
def test_integration_basic_dbt(self):
Expand Down Expand Up @@ -488,10 +473,10 @@ def test_integration_cloud_dbt(self):
def test_local_diff(self, mock_diff_tables):
connection = {}
mock_table1 = Mock()
column_set = {"col1", "col2"}
mock_table1.get_schema.return_value = column_set
column_dictionary = {"col1": ("col1", "type"), "col2": ("col2", "type")}
mock_table1.get_schema.return_value = column_dictionary
mock_table2 = Mock()
mock_table2.get_schema.return_value = column_set
mock_table2.get_schema.return_value = column_dictionary
mock_diff = MagicMock()
mock_diff_tables.return_value = mock_diff
mock_diff.__iter__.return_value = [1, 2, 3]
Expand Down Expand Up @@ -527,14 +512,56 @@ def test_local_diff(self, mock_diff_tables):
mock_connect.assert_any_call(connection, ".".join(prod_qualified_list), tuple(expected_primary_keys), threads)
mock_diff.get_stats_string.assert_called_once()

# Exercises _local_diff when both tables share the same column names but one
# column ("col2") reports a different type in the second table's schema.
@patch("data_diff.dbt.diff_tables")
def test_local_diff_types_differ(self, mock_diff_tables):
connection = {}
mock_table1 = Mock()
mock_table2 = Mock()
# Schemas are dicts of column name -> tuple; the two dicts differ only in
# the second tuple element of "col2" ("type" vs "differing_type").
table1_column_dictionary = {"col1": ("col1", "type"), "col2": ("col2", "type")}
table2_column_dictionary = {"col1": ("col1", "type"), "col2": ("col2", "differing_type")}
mock_table1.get_schema.return_value = table1_column_dictionary
mock_table2.get_schema.return_value = table2_column_dictionary
# The patched diff_tables returns a MagicMock that yields three items when iterated.
mock_diff = MagicMock()
mock_diff_tables.return_value = mock_diff
mock_diff.__iter__.return_value = [1, 2, 3]
threads = None
where = "a_string"
dev_qualified_list = ["dev_db", "dev_schema", "dev_table"]
prod_qualified_list = ["prod_db", "prod_schema", "prod_table"]
expected_primary_keys = ["key"]
diff_vars = TDiffVars(
dev_path=dev_qualified_list,
prod_path=prod_qualified_list,
primary_keys=expected_primary_keys,
connection=connection,
threads=threads,
where_filter=where,
include_columns=[],
exclude_columns=[],
)
# side_effect hands out mock_table1 on the first connect_to_table call and
# mock_table2 on the second.
with patch("data_diff.dbt.connect_to_table", side_effect=[mock_table1, mock_table2]) as mock_connect:
_local_diff(diff_vars)

# diff_tables must be called exactly once with both tables; extra_columns is
# matched loosely here (ANY) and its length is checked separately below.
mock_diff_tables.assert_called_once_with(
mock_table1,
mock_table2,
threaded=True,
algorithm=Algorithm.JOINDIFF,
extra_columns=ANY,
where=where,
)
# Exactly one extra column is expected -- presumably "col1", since the
# type-changed "col2" is dropped from the compared columns; confirm against _local_diff.
self.assertEqual(len(mock_diff_tables.call_args[1]["extra_columns"]), 1)
self.assertEqual(mock_connect.call_count, 2)
mock_diff.get_stats_string.assert_called_once()

@patch("data_diff.dbt.diff_tables")
def test_local_diff_no_diffs(self, mock_diff_tables):
connection = {}
column_set = {"col1", "col2"}
column_dictionary = {"col1": ("col1", "type"), "col2": ("col2", "type")}
mock_table1 = Mock()
mock_table1.get_schema.return_value = column_set
mock_table1.get_schema.return_value = column_dictionary
mock_table2 = Mock()
mock_table2.get_schema.return_value = column_set
mock_table2.get_schema.return_value = column_dictionary
mock_diff = MagicMock()
mock_diff_tables.return_value = mock_diff
mock_diff.__iter__.return_value = []
Expand Down