This repository was archived by the owner on May 17, 2024. It is now read-only.
Merged
32 changes: 16 additions & 16 deletions README.md
@@ -116,22 +116,22 @@ $ data-diff \

## Supported Databases

| Database | Connection string | Status |
|---------------|------------------------------------------------------------------------------------------------------------------------------------|--------|
| PostgreSQL | `postgresql://<user>:<password>@<hostname>:5432/<database>` | 💚 |
| MySQL | `mysql://<user>:<password>@<hostname>:5432/<database>` | 💚 |
| Snowflake | `"snowflake://<user>[:<password>]@<account>/<database>/<SCHEMA>?warehouse=<WAREHOUSE>&role=<role>[&authenticator=externalbrowser]"`| 💚 |
| Oracle | `oracle://<username>:<password>@<hostname>/database` | 💛 |
| BigQuery | `bigquery://<project>/<dataset>` | 💛 |
| Redshift | `redshift://<username>:<password>@<hostname>:5439/<database>` | 💛 |
| Presto | `presto://<username>:<password>@<hostname>:8080/<database>` | 💛 |
| ElasticSearch | | 📝 |
| Databricks | | 📝 |
| Planetscale | | 📝 |
| Clickhouse | | 📝 |
| Pinot | | 📝 |
| Druid | | 📝 |
| Kafka | | 📝 |
| Database | Connection string | Status |
|---------------|-------------------------------------------------------------------------------------------------------------------------------------|--------|
| PostgreSQL | `postgresql://<user>:<password>@<hostname>:5432/<database>` | 💚 |
| MySQL | `mysql://<user>:<password>@<hostname>:5432/<database>` | 💚 |
| Snowflake | `"snowflake://<user>[:<password>]@<account>/<database>/<SCHEMA>?warehouse=<WAREHOUSE>&role=<role>[&authenticator=externalbrowser]"` | 💚 |
| Oracle | `oracle://<username>:<password>@<hostname>/database` | 💛 |
| BigQuery | `bigquery://<project>/<dataset>` | 💛 |
| Redshift | `redshift://<username>:<password>@<hostname>:5439/<database>` | 💛 |
| Presto | `presto://<username>:<password>@<hostname>:8080/<database>` | 💛 |
| Databricks | `databricks://<http_path>:<access_token>@<server_hostname>/<catalog>/<schema>` | 💛 |
| ElasticSearch |                                                                                                                                       | 📝 |
| Planetscale | | 📝 |
| Clickhouse | | 📝 |
| Pinot | | 📝 |
| Druid | | 📝 |
| Kafka | | 📝 |

* 💚: Implemented and thoroughly tested.
* 💛: Implemented, but not thoroughly tested yet.
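For illustration, here is a minimal sketch of how one of the connection strings above might be used programmatically, assuming the `connect_to_uri` helper touched later in this PR (host, credentials, and database name are placeholders):

```python
from data_diff.databases import connect_to_uri

# Placeholder values following the PostgreSQL row of the table above.
uri = "postgresql://user:password@localhost:5432/mydb"

db = connect_to_uri(uri, thread_count=4)  # returns a Database instance for the matched scheme
print(type(db).__name__)
```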
1 change: 1 addition & 0 deletions data_diff/databases/__init__.py
@@ -7,5 +7,6 @@
from .bigquery import BigQuery
from .redshift import Redshift
from .presto import Presto
from .databricks import Databricks

from .connect import connect_to_uri
43 changes: 29 additions & 14 deletions data_diff/databases/connect.py
@@ -12,6 +12,7 @@
from .bigquery import BigQuery
from .redshift import Redshift
from .presto import Presto
from .databricks import Databricks


@dataclass
@@ -77,6 +78,9 @@ def match_path(self, dsn):
),
"presto": MatchUriPath(Presto, ["catalog", "schema"], help_str="presto://<user>@<host>/<catalog>/<schema>"),
"bigquery": MatchUriPath(BigQuery, ["dataset"], help_str="bigquery://<project>/<dataset>"),
"databricks": MatchUriPath(
Databricks, ["catalog", "schema"], help_str="databricks://:access_token@server_name/http_path",
)
}


@@ -100,6 +104,7 @@ def connect_to_uri(db_uri: str, thread_count: Optional[int] = 1) -> Database:
- bigquery
- redshift
- presto
- databricks
"""

dsn = dsnparse.parse(db_uri)
@@ -113,23 +118,33 @@ def connect_to_uri(db_uri: str, thread_count: Optional[int] = 1) -> Database:
raise NotImplementedError(f"Scheme {scheme} currently not supported")

cls = matcher.database_cls
kw = matcher.match_path(dsn)

if scheme == "bigquery":
kw["project"] = dsn.host
return cls(**kw)

if scheme == "snowflake":
kw["account"] = dsn.host
assert not dsn.port
kw["user"] = dsn.user
kw["password"] = dsn.password
if scheme == "databricks":
assert not dsn.user
kw = {}
kw['access_token'] = dsn.password
kw['http_path'] = dsn.path
kw['server_hostname'] = dsn.host
kw.update(dsn.query)
else:
kw["host"] = dsn.host
kw["port"] = dsn.port
kw["user"] = dsn.user
if dsn.password:
kw = matcher.match_path(dsn)

if scheme == "bigquery":
kw["project"] = dsn.host
return cls(**kw)

if scheme == "snowflake":
kw["account"] = dsn.host
assert not dsn.port
kw["user"] = dsn.user
kw["password"] = dsn.password
else:
kw["host"] = dsn.host
kw["port"] = dsn.port
kw["user"] = dsn.user
if dsn.password:
kw["password"] = dsn.password

kw = {k: v for k, v in kw.items() if v is not None}

if issubclass(cls, ThreadedDatabase):
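To make the new `databricks` branch above concrete, here is a small sketch (placeholder workspace values, not from the PR) of how a DSN in the registered `help_str` form is decomposed before the keyword arguments reach the `Databricks` constructor:

```python
import dsnparse

# Placeholder DSN following help_str: databricks://:access_token@server_name/http_path
dsn = dsnparse.parse(
    "databricks://:my-access-token@my-workspace.cloud.databricks.com/sql/1.0/endpoints/abc123"
)

kw = {
    "access_token": dsn.password,   # the token sits in the password slot; the user slot must be empty
    "http_path": dsn.path,          # "/sql/1.0/endpoints/abc123"
    "server_hostname": dsn.host,    # "my-workspace.cloud.databricks.com"
}
kw.update(dsn.query)                # any query parameters pass through as extra keyword arguments
```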
135 changes: 135 additions & 0 deletions data_diff/databases/databricks.py
@@ -0,0 +1,135 @@
import logging
import math

from .database_types import *
from .base import MD5_HEXDIGITS, CHECKSUM_HEXDIGITS, Database, import_helper, _query_conn, parse_table_name


@import_helper("databricks")
def import_databricks():
import databricks.sql
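    # `import databricks.sql` also binds the parent `databricks` package in this scope; that binding is what gets returned below.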

return databricks


class Databricks(Database):
TYPE_CLASSES = {
# Numbers
"INT": Integer,
"SMALLINT": Integer,
"TINYINT": Integer,
"BIGINT": Integer,
"FLOAT": Float,
"DOUBLE": Float,
"DECIMAL": Decimal,
# Timestamps
"TIMESTAMP": Timestamp,
# Text
"STRING": Text,
}

ROUNDS_ON_PREC_LOSS = True

def __init__(
self,
http_path: str,
access_token: str,
server_hostname: str,
catalog: str = "hive_metastore",
schema: str = "default",
**kwargs,
):
databricks = import_databricks()

self._conn = databricks.sql.connect(
server_hostname=server_hostname, http_path=http_path, access_token=access_token
)

logging.getLogger("databricks.sql").setLevel(logging.WARNING)

self.catalog = catalog
self.default_schema = schema
self.kwargs = kwargs

def _query(self, sql_code: str) -> list:
"Uses the standard SQL cursor interface"
return _query_conn(self._conn, sql_code)

def quote(self, s: str):
return f"`{s}`"

def md5_to_int(self, s: str) -> str:
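        # Keep only the last CHECKSUM_HEXDIGITS hex characters of the MD5 and convert them from base 16 to base 10, cast to DECIMAL(38, 0).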
return f"cast(conv(substr(md5({s}), {1+MD5_HEXDIGITS-CHECKSUM_HEXDIGITS}), 16, 10) as decimal(38, 0))"

def to_string(self, s: str) -> str:
return f"cast({s} as string)"

def _convert_db_precision_to_digits(self, p: int) -> int:
        # Subtracting 1 due to weird precision issues
return max(super()._convert_db_precision_to_digits(p) - 1, 0)

def query_table_schema(self, path: DbPath, filter_columns: Optional[Sequence[str]] = None) -> Dict[str, ColType]:
# Databricks has INFORMATION_SCHEMA only for Databricks Runtime, not for Databricks SQL.
# https://docs.databricks.com/spark/latest/spark-sql/language-manual/information-schema/columns.html
        # So, to obtain schema information, we use the cursor's columns() metadata call instead.

schema, table = self._normalize_table_path(path)
with self._conn.cursor() as cursor:
cursor.columns(catalog_name=self.catalog, schema_name=schema, table_name=table)
rows = cursor.fetchall()
if not rows:
raise RuntimeError(f"{self.name}: Table '{'.'.join(path)}' does not exist, or has no columns")

if filter_columns is not None:
accept = {i.lower() for i in filter_columns}
rows = [r for r in rows if r.COLUMN_NAME.lower() in accept]

resulted_rows = []
for row in rows:
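            # DATA_TYPE == 3 is the standard SQL/JDBC type code for DECIMAL; TYPE_NAME alone would carry the "(x,y)" precision suffix.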
row_type = "DECIMAL" if row.DATA_TYPE == 3 else row.TYPE_NAME
type_cls = self.TYPE_CLASSES.get(row_type, UnknownColType)

if issubclass(type_cls, Integer):
row = (row.COLUMN_NAME, row_type, None, None, 0)

elif issubclass(type_cls, Float):
numeric_precision = self._convert_db_precision_to_digits(row.DECIMAL_DIGITS)
row = (row.COLUMN_NAME, row_type, None, numeric_precision, None)

elif issubclass(type_cls, Decimal):
# TYPE_NAME has a format DECIMAL(x,y)
items = row.TYPE_NAME[8:].rstrip(")").split(",")
numeric_precision, numeric_scale = int(items[0]), int(items[1])
row = (row.COLUMN_NAME, row_type, None, numeric_precision, numeric_scale)

elif issubclass(type_cls, Timestamp):
row = (row.COLUMN_NAME, row_type, row.DECIMAL_DIGITS, None, None)

else:
row = (row.COLUMN_NAME, row_type, None, None, None)

resulted_rows.append(row)
col_dict: Dict[str, ColType] = {row[0]: self._parse_type(path, *row) for row in resulted_rows}

self._refine_coltypes(path, col_dict)
return col_dict

def normalize_timestamp(self, value: str, coltype: TemporalType) -> str:
"""Databricks timestamp contains no more than 6 digits in precision"""

if coltype.rounds:
timestamp = f"cast(round(unix_micros({value}) / 1000000, {coltype.precision}) * 1000000 as bigint)"
return f"date_format(timestamp_micros({timestamp}), 'yyyy-MM-dd HH:mm:ss.SSSSSS')"
else:
precision_format = "S" * coltype.precision + "0" * (6 - coltype.precision)
return f"date_format({value}, 'yyyy-MM-dd HH:mm:ss.{precision_format}')"

def normalize_number(self, value: str, coltype: NumericType) -> str:
return self.to_string(f"cast({value} as decimal(38, {coltype.precision}))")

def parse_table_name(self, name: str) -> DbPath:
path = parse_table_name(name)
return self._normalize_table_path(path)

def close(self):
self._conn.close()
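As a final illustration, a minimal usage sketch of the new adapter with placeholder workspace details (everything here is hypothetical and only exercises methods defined in this file):

```python
from data_diff.databases import Databricks

# Placeholder credentials -- substitute real workspace values.
db = Databricks(
    server_hostname="<server_hostname>",
    http_path="<http_path>",
    access_token="<access_token>",
    catalog="hive_metastore",  # defaults shown explicitly
    schema="default",
)

path = db.parse_table_name("default.my_table")   # -> normalized (schema, table) path
columns = db.query_table_schema(path)            # -> {column_name: ColType, ...}
db.close()
```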