Skip to content
This repository was archived by the owner on May 17, 2024. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions data_diff/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,6 +228,13 @@ def write_usage(self, prog: str, args: str = "", prefix: Optional[str] = None) -
metavar="PATH",
help="Which directory to look in for the dbt_project.yml file. Default is the current working directory and its parents.",
)
@click.option(
"--select",
"-s",
default=None,
metavar="PATH",
help="select dbt resources to compare using dbt selection syntax",
)
def main(conf, run, **kw):
if kw["table2"] is None and kw["database2"]:
# Use the "database table table" form
Expand Down Expand Up @@ -264,6 +271,7 @@ def main(conf, run, **kw):
profiles_dir_override=kw["dbt_profiles_dir"],
project_dir_override=kw["dbt_project_dir"],
is_cloud=kw["cloud"],
dbt_selection=kw["select"],
)
else:
return _data_diff(**kw)
Expand Down Expand Up @@ -306,6 +314,7 @@ def _data_diff(
cloud,
dbt_profiles_dir,
dbt_project_dir,
select,
threads1=None,
threads2=None,
__conf__=None,
Expand Down
18 changes: 5 additions & 13 deletions data_diff/dbt.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,17 +18,6 @@
logger = getLogger(__name__)


def import_dbt():
    """Lazily import the optional dbt-related packages and hand them back.

    The imports live inside the function so that merely importing this module
    does not require the optional dependencies to be installed.

    Returns:
        A 4-tuple ``(parse_run_results, parse_manifest, ProfileRenderer, yaml)``.

    Raises:
        RuntimeError: If any of the packages cannot be imported. The original
            ``ImportError`` is attached as ``__cause__`` so the missing module
            is still visible in the traceback.
    """
    try:
        from dbt_artifacts_parser.parser import parse_run_results, parse_manifest
        from dbt.config.renderer import ProfileRenderer
        import yaml
    except ImportError as err:
        # Chain explicitly so the user sees which import actually failed.
        raise RuntimeError(
            "Could not import 'dbt' package. You can install it using: pip install 'data-diff[dbt]'."
        ) from err

    return parse_run_results, parse_manifest, ProfileRenderer, yaml


from .tracking import (
set_entrypoint_name,
set_dbt_user_id,
Expand All @@ -54,12 +43,15 @@ class DiffVars:


def dbt_diff(
profiles_dir_override: Optional[str] = None, project_dir_override: Optional[str] = None, is_cloud: bool = False
profiles_dir_override: Optional[str] = None,
project_dir_override: Optional[str] = None,
is_cloud: bool = False,
dbt_selection: Optional[str] = None,
) -> None:
diff_threads = []
set_entrypoint_name("CLI-dbt")
dbt_parser = DbtParser(profiles_dir_override, project_dir_override)
models = dbt_parser.get_models()
models = dbt_parser.get_models(dbt_selection)
datadiff_variables = dbt_parser.get_datadiff_variables()
config_prod_database = datadiff_variables.get("prod_database")
config_prod_schema = datadiff_variables.get("prod_schema")
Expand Down
91 changes: 81 additions & 10 deletions data_diff/dbt_parser.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
from collections import defaultdict
import json
import os
from pathlib import Path
from typing import List, Dict, Tuple, Set
from typing import List, Dict, Tuple, Set, Optional

from packaging.version import parse as parse_version

Expand All @@ -12,23 +13,34 @@
logger = getLogger(__name__)


def import_dbt():
def import_dbt_dependencies():
    """Lazily import the optional dbt-related packages and hand them back.

    Returns:
        A 5-tuple ``(parse_run_results, parse_manifest, ProfileRenderer, yaml,
        dbt_runner)``. ``dbt_runner`` is a ``dbtRunner()`` instance when
        ``dbt.cli.main.dbtRunner`` can be imported, otherwise ``None``.

    Raises:
        RuntimeError: If any of the required packages cannot be imported. The
            original ``ImportError`` is attached as ``__cause__``.
    """
    try:
        from dbt_artifacts_parser.parser import parse_run_results, parse_manifest
        from dbt.config.renderer import ProfileRenderer
        import yaml
    except ImportError as err:
        # Chain explicitly so the user sees which import actually failed.
        raise RuntimeError(
            "Could not import 'dbt' package. You can install it using: pip install 'data-diff[dbt]'."
        ) from err

    # dbt 1.5+ specific stuff to power selection of models
    try:
        from dbt.cli.main import dbtRunner
    except ImportError:
        # The runner is optional: callers must cope with None.
        dbt_runner = None
    else:
        dbt_runner = dbtRunner()

    return parse_run_results, parse_manifest, ProfileRenderer, yaml, dbt_runner


RUN_RESULTS_PATH = "target/run_results.json"
MANIFEST_PATH = "target/manifest.json"
PROJECT_FILE = "dbt_project.yml"
PROFILES_FILE = "profiles.yml"
LOWER_DBT_V = "1.0.0"
UPPER_DBT_V = "1.4.7"
UPPER_DBT_V = "1.6.0"


# https://github.com/dbt-labs/dbt-core/blob/c952d44ec5c2506995fbad75320acbae49125d3d/core/dbt/cli/resolvers.py#L6
Expand All @@ -49,7 +61,13 @@ def legacy_profiles_dir() -> Path:

class DbtParser:
def __init__(self, profiles_dir_override: str, project_dir_override: str) -> None:
self.parse_run_results, self.parse_manifest, self.ProfileRenderer, self.yaml = import_dbt()
(
self.parse_run_results,
self.parse_manifest,
self.ProfileRenderer,
self.yaml,
self.dbt_runner,
) = import_dbt_dependencies()
self.profiles_dir = Path(profiles_dir_override or default_profiles_dir())
self.project_dir = Path(project_dir_override or default_project_dir())
self.connection = None
Expand All @@ -68,7 +86,60 @@ def get_datadiff_variables(self) -> dict:
vars = get_from_dict_with_raise(self.project_dict, "vars", error_message)
return get_from_dict_with_raise(vars, "data_diff", error_message)

def get_models(self):
def get_models(self, dbt_selection: Optional[str] = None):
    """Return the models to diff, either from a selection or from run results.

    Args:
        dbt_selection: Selection string passed on to
            ``get_dbt_selection_models``. When falsy, the models come from
            ``get_run_results_models`` instead.

    Raises:
        Exception: If a selection is given but the parsed ``self.dbt_version``
            is older than 1.5, or if ``self.dbt_runner`` is falsy.
    """
    # Parsed up front, before any branching, exactly as the selection path needs it.
    dbt_version = parse_version(self.dbt_version)

    if not dbt_selection:
        return self.get_run_results_models()

    if (dbt_version.major, dbt_version.minor) < (1, 5):
        raise Exception(
            f"Use of the `--select` feature requires dbt >= 1.5. Found dbt manifest: v{dbt_version}"
        )

    # edge case if running data-diff from a separate env than dbt (likely local development)
    if not self.dbt_runner:
        raise Exception(
            "data-diff is using a dbt-core version < 1.5, update the environment's dbt-core version via pip install 'dbt-core>=1.5' in order to use `--select`"
        )

    return self.get_dbt_selection_models(dbt_selection)

def get_dbt_selection_models(self, dbt_selection: str) -> List:
    """Resolve a dbt selection string to manifest nodes via ``dbt ls``.

    Args:
        dbt_selection: Value forwarded to dbt's ``--select`` option.

    Returns:
        One entry per line of the runner's result, looked up with
        ``self.manifest_obj.nodes.get(unique_id)``. An entry is ``None`` when
        the unique_id is not present in the manifest nodes.

    Raises:
        Exception: Whatever exception the runner result carries, re-raised
            as-is. Otherwise a generic ``Exception`` when the result is empty
            or the invocation was unsuccessful.
    """
    # log level and format settings needed to prevent dbt from printing to stdout
    # ls command is used to get the list of model unique_ids
    results = self.dbt_runner.invoke(
        [
            "--log-format",
            "json",
            "--log-level",
            "none",
            "ls",
            "--select",
            dbt_selection,
            "--resource-type",
            "model",
            "--output",
            "json",
            "--output-keys",
            "unique_id",
            "--project-dir",
            # The argument list is a CLI argv; pass the path as text.
            str(self.project_dir),
        ]
    )

    # Surface the runner's own failure first. Checking for an empty result
    # before this would hide the real error behind "No dbt models found".
    if results.exception:
        raise results.exception

    if results.success and results.result:
        unique_ids = [json.loads(model)["unique_id"] for model in results.result]
        return [self.manifest_obj.nodes.get(unique_id) for unique_id in unique_ids]

    if not results.result:
        raise Exception(f"No dbt models found for `--select {dbt_selection}`")

    # Unsuccessful, non-empty result and no exception attached: log for debugging.
    logger.debug(str(results))
    raise Exception("Encountered an error while finding `--select` models")

def get_run_results_models(self):
with open(self.project_dir / RUN_RESULTS_PATH) as run_results:
logger.info(f"Parsing file {RUN_RESULTS_PATH}")
run_results_dict = json.load(run_results)
Expand All @@ -80,11 +151,11 @@ def get_models(self):
self.profiles_dir = legacy_profiles_dir()

if dbt_version < parse_version(LOWER_DBT_V):
raise Exception(
f"Found dbt: v{dbt_version} Expected the dbt project's version to be >= {LOWER_DBT_V}"
)
raise Exception(f"Found dbt: v{dbt_version} Expected the dbt project's version to be >= {LOWER_DBT_V}")
elif dbt_version >= parse_version(UPPER_DBT_V):
logger.warning(f"{dbt_version} is a recent version of dbt and may not be fully tested with data-diff! \nPlease report any issues to https://github.com/datafold/data-diff/issues")
logger.warning(
f"{dbt_version} is a recent version of dbt and may not be fully tested with data-diff! \nPlease report any issues to https://github.com/datafold/data-diff/issues"
)

success_models = [x.unique_id for x in run_results_obj.results if x.status.name == "success"]
models = [self.manifest_obj.nodes.get(x) for x in success_models]
Expand Down
Loading