Skip to content
This repository was archived by the owner on May 17, 2024. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 33 additions & 8 deletions data_diff/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,16 +6,17 @@
import json
import logging
from itertools import islice
from typing import Optional
from typing import Dict, Optional

import rich
from rich.logging import RichHandler
import click

from data_diff.sqeleton.schema import create_schema
from data_diff.sqeleton.queries.api import current_timestamp

from .dbt import dbt_diff
from .utils import eval_name_template, remove_password_from_url, safezip, match_like
from .utils import eval_name_template, remove_password_from_url, safezip, match_like, LogStatusHandler
from .diff_tables import Algorithm
from .hashdiff_tables import HashDiffer, DEFAULT_BISECTION_THRESHOLD, DEFAULT_BISECTION_FACTOR
from .joindiff_tables import TABLE_WRITE_LIMIT, JoinDiffer
Expand All @@ -27,9 +28,6 @@
from .version import __version__


LOG_FORMAT = "[%(asctime)s] %(levelname)s - %(message)s"
DATE_FORMAT = "%H:%M:%S"

COLOR_SCHEME = {
"+": "green",
"-": "red",
Expand All @@ -38,6 +36,28 @@
set_entrypoint_name("CLI")


def _get_log_handlers(is_dbt: Optional[bool] = False) -> Dict[str, logging.Handler]:
    """Create the logging handlers for the CLI, keyed by a short name.

    The returned dict always contains ``"rich_handler"`` (a ``RichHandler``
    with rich tracebacks, initially at WARN level). It additionally contains
    ``"log_status_handler"`` (a ``LogStatusHandler`` at DEBUG level) only when
    the rich console is attached to a terminal and ``is_dbt`` is truthy.
    """
    time_fmt = "%H:%M:%S"

    console_handler = RichHandler(rich_tracebacks=True)
    console_handler.setFormatter(logging.Formatter("%(message)s", datefmt=time_fmt))
    console_handler.setLevel(logging.WARN)
    result = {"rich_handler": console_handler}

    # The status line only makes sense in an interactive terminal, and it is
    # only requested for dbt runs; otherwise the rich handler is all we need.
    if not (console_handler.console.is_terminal and is_dbt):
        return result

    status_handler = LogStatusHandler()
    # ".100s" truncates each status message to 100 characters (arbitrary limit).
    status_handler.setFormatter(logging.Formatter("%(message).100s", datefmt=time_fmt))
    status_handler.setLevel(logging.DEBUG)
    result["log_status_handler"] = status_handler
    return result


def _remove_passwords_in_dict(d: dict):
for k, v in d.items():
if k == "password":
Expand Down Expand Up @@ -244,6 +264,7 @@ def write_usage(self, prog: str, args: str = "", prefix: Optional[str] = None) -
help="Specify manifest to utilize for 'prod' comparison paths instead of using configuration.",
)
def main(conf, run, **kw):
log_handlers = _get_log_handlers(kw["dbt"])
if kw["table2"] is None and kw["database2"]:
# Use the "database table table" form
kw["table2"] = kw["database2"]
Expand All @@ -263,15 +284,18 @@ def main(conf, run, **kw):
kw["debug"] = True

if kw["debug"]:
logging.basicConfig(level=logging.DEBUG, format=LOG_FORMAT, datefmt=DATE_FORMAT)
log_handlers["rich_handler"].setLevel(logging.DEBUG)
logging.basicConfig(level=logging.DEBUG, handlers=list(log_handlers.values()))
if kw.get("__conf__"):
kw["__conf__"] = deepcopy(kw["__conf__"])
_remove_passwords_in_dict(kw["__conf__"])
logging.debug(f"Applied run configuration: {kw['__conf__']}")
elif kw.get("verbose"):
logging.basicConfig(level=logging.INFO, format=LOG_FORMAT, datefmt=DATE_FORMAT)
log_handlers["rich_handler"].setLevel(logging.INFO)
logging.basicConfig(level=logging.DEBUG, handlers=list(log_handlers.values()))
else:
logging.basicConfig(level=logging.WARNING, format=LOG_FORMAT, datefmt=DATE_FORMAT)
log_handlers["rich_handler"].setLevel(logging.WARNING)
logging.basicConfig(level=logging.DEBUG, handlers=list(log_handlers.values()))

try:
state = kw.pop("state", None)
Expand All @@ -285,6 +309,7 @@ def main(conf, run, **kw):
project_dir_override = os.path.expanduser(project_dir_override)
if kw["dbt"]:
dbt_diff(
log_status_handler=log_handlers.get("log_status_handler"),
profiles_dir_override=profiles_dir_override,
project_dir_override=project_dir_override,
is_cloud=kw["cloud"],
Expand Down
2 changes: 1 addition & 1 deletion data_diff/cloud/datafold_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -246,7 +246,7 @@ def poll_data_diff_results(self, diff_id: int) -> TCloudApiDataDiffSummaryResult

diff_url = f"{self.host}/datadiffs/{diff_id}/overview"
while not summary_results:
logger.debug(f"Polling: {diff_url}")
logger.debug("Polling Datafold for results...")
response = self.make_get_request(url=f"api/v1/datadiffs/{diff_id}/summary_results")
response_json = response.json()
if response_json["status"] == "success":
Expand Down
104 changes: 61 additions & 43 deletions data_diff/dbt.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from contextlib import nullcontext
import json
import os
import re
Expand Down Expand Up @@ -42,6 +43,7 @@
run_as_daemon,
truncate_error,
print_version_info,
LogStatusHandler,
)

logger = getLogger(__name__)
Expand All @@ -67,6 +69,7 @@ def dbt_diff(
dbt_selection: Optional[str] = None,
json_output: bool = False,
state: Optional[str] = None,
log_status_handler: Optional[LogStatusHandler] = None,
where_flag: Optional[str] = None,
) -> None:
print_version_info()
Expand All @@ -89,7 +92,6 @@ def dbt_diff(
if not api:
return
org_meta = api.get_org_meta()

if config.datasource_id is None:
rich.print("[red]Data source ID not found in dbt_project.yml")
raise DataDiffNoDatasourceIdError(
Expand All @@ -103,48 +105,54 @@ def dbt_diff(
else:
dbt_parser.set_connection()

for model in models:
diff_vars = _get_diff_vars(dbt_parser, config, model, where_flag)

# we won't always have a prod path when using state
# when the model DNE in prod manifest, skip the model diff
if (
state and len(diff_vars.prod_path) < 2
): # < 2 because some providers like databricks can legitimately have *only* 2
diff_output_str = _diff_output_base(".".join(diff_vars.dev_path), ".".join(diff_vars.prod_path))
diff_output_str += "[green]New model: nothing to diff![/] \n"
rich.print(diff_output_str)
continue

if diff_vars.primary_keys:
if is_cloud:
diff_thread = run_as_daemon(_cloud_diff, diff_vars, config.datasource_id, api, org_meta)
diff_threads.append(diff_thread)
else:
_local_diff(diff_vars, json_output)
else:
if json_output:
print(
json.dumps(
jsonify_error(
table1=diff_vars.prod_path,
table2=diff_vars.dev_path,
dbt_model=diff_vars.dbt_model,
error="No primary key found. Add uniqueness tests, meta, or tags.",
)
),
flush=True,
)
with log_status_handler.status if log_status_handler else nullcontext():
for model in models:
if log_status_handler:
log_status_handler.set_prefix(f"Diffing {model.alias} \n")

diff_vars = _get_diff_vars(dbt_parser, config, model, where_flag)

# we won't always have a prod path when using state
# when the model DNE in prod manifest, skip the model diff
if (
state and len(diff_vars.prod_path) < 2
): # < 2 because some providers like databricks can legitimately have *only* 2
diff_output_str = _diff_output_base(".".join(diff_vars.dev_path), ".".join(diff_vars.prod_path))
diff_output_str += "[green]New model: nothing to diff![/] \n"
rich.print(diff_output_str)
continue

if diff_vars.primary_keys:
if is_cloud:
diff_thread = run_as_daemon(
_cloud_diff, diff_vars, config.datasource_id, api, org_meta, log_status_handler
)
diff_threads.append(diff_thread)
else:
_local_diff(diff_vars, json_output)
else:
rich.print(
_diff_output_base(".".join(diff_vars.dev_path), ".".join(diff_vars.prod_path))
+ "Skipped due to unknown primary key. Add uniqueness tests, meta, or tags.\n"
)

# wait for all threads
if diff_threads:
for thread in diff_threads:
thread.join()
if json_output:
print(
json.dumps(
jsonify_error(
table1=diff_vars.prod_path,
table2=diff_vars.dev_path,
dbt_model=diff_vars.dbt_model,
error="No primary key found. Add uniqueness tests, meta, or tags.",
)
),
flush=True,
)
else:
rich.print(
_diff_output_base(".".join(diff_vars.dev_path), ".".join(diff_vars.prod_path))
+ "Skipped due to unknown primary key. Add uniqueness tests, meta, or tags.\n"
)

# wait for all threads
if diff_threads:
for thread in diff_threads:
thread.join()


def _get_diff_vars(
Expand Down Expand Up @@ -348,7 +356,15 @@ def _initialize_api() -> Optional[DatafoldAPI]:
return DatafoldAPI(api_key=api_key, host=datafold_host)


def _cloud_diff(diff_vars: TDiffVars, datasource_id: int, api: DatafoldAPI, org_meta: TCloudApiOrgMeta) -> None:
def _cloud_diff(
diff_vars: TDiffVars,
datasource_id: int,
api: DatafoldAPI,
org_meta: TCloudApiOrgMeta,
log_status_handler: Optional[LogStatusHandler] = None,
) -> None:
if log_status_handler:
log_status_handler.cloud_diff_started(diff_vars.dev_path[-1])
diff_output_str = _diff_output_base(".".join(diff_vars.dev_path), ".".join(diff_vars.prod_path))
payload = TCloudApiDataDiff(
data_source1_id=datasource_id,
Expand Down Expand Up @@ -417,6 +433,8 @@ def _cloud_diff(diff_vars: TDiffVars, datasource_id: int, api: DatafoldAPI, org_
diff_output_str += f"\n{diff_url}\n{no_differences_template()}\n"
rich.print(diff_output_str)

if log_status_handler:
log_status_handler.cloud_diff_finished(diff_vars.dev_path[-1])
except BaseException as ex: # Catch KeyboardInterrupt too
error = ex
finally:
Expand Down
37 changes: 37 additions & 0 deletions data_diff/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import requests
from tabulate import tabulate
from .version import __version__
from rich.status import Status


def safezip(*args):
Expand Down Expand Up @@ -211,3 +212,39 @@ def print_version_info() -> None:
print(f"{base_version_string} (Update {latest_version} is available!)")
else:
print(base_version_string)


class LogStatusHandler(logging.Handler):
    """
    Log handler that updates a ``rich.status`` every time a log is emitted.

    Two display modes are supported:

    * Default: the status text is ``prefix + <formatted log message>``.
    * Cloud: once at least one model has been registered through
      ``cloud_diff_started``, the status shows one line per registered model
      (its state followed by its name), followed by the latest log message.
      The prefix is not used in this mode.
    """

    def __init__(self):
        super().__init__()
        self.status = Status("")
        self.prefix = ""
        # Maps model name -> rich-markup state string ("In Progress"/"Finished").
        self.cloud_diff_status = {}

    def emit(self, record):
        """Render ``record`` into the status line.

        Any failure while formatting or rendering is routed to
        ``handleError`` (the standard ``logging.Handler`` convention) so a
        malformed log call cannot crash the code that issued it.
        """
        try:
            log_entry = self.format(record)
            if self.cloud_diff_status:
                self._update_cloud_status(log_entry)
            else:
                self.status.update(self.prefix + log_entry)
        except RecursionError:
            # Mirrors the stdlib handlers: never swallow runaway recursion.
            raise
        except Exception:
            self.handleError(record)

    def set_prefix(self, prefix_string):
        """Set the text shown before each log message in default mode."""
        self.prefix = prefix_string

    def cloud_diff_started(self, model_name):
        """Mark ``model_name`` as in progress and refresh the status."""
        self.cloud_diff_status[model_name] = "[yellow]In Progress[/]"
        self._update_cloud_status()

    def cloud_diff_finished(self, model_name):
        """Mark ``model_name`` as finished and refresh the status."""
        self.cloud_diff_status[model_name] = "[green]Finished [/]"
        self._update_cloud_status()

    def _update_cloud_status(self, log=None):
        """Redraw the per-model status lines, optionally followed by ``log``."""
        # Iterate over a snapshot: models may be registered from worker threads
        # while another thread is rendering, and iterating the live dict could
        # raise "dictionary changed size during iteration".
        snapshot = list(self.cloud_diff_status.items())
        model_lines = "".join(f"{state} {model_name}\n" for model_name, state in snapshot)
        self.status.update(f"\n{model_lines}{log or ''}")
2 changes: 1 addition & 1 deletion tests/test_dbt.py
Original file line number Diff line number Diff line change
Expand Up @@ -274,7 +274,7 @@ def test_diff_is_cloud(

mock_initialize_api.assert_called_once()
mock_api.get_data_source.assert_called_once_with(1)
mock_cloud_diff.assert_called_once_with(diff_vars, 1, mock_api, org_meta)
mock_cloud_diff.assert_called_once_with(diff_vars, 1, mock_api, org_meta, None)
mock_local_diff.assert_not_called()
mock_print.assert_called_once()

Expand Down