Skip to content
This repository was archived by the owner on May 17, 2024. It is now read-only.

Commit 14193b9

Browse files
authored
Merge pull request #543 from datafold/followup-sqeleton
Follow-up the sqeleton-to-datadiff embedding
2 parents 33bf6a9 + 53f01c7 commit 14193b9

File tree

7 files changed

+88
-3
lines changed

7 files changed

+88
-3
lines changed

data_diff/sqeleton/abcs/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,5 +10,6 @@
1010
PrecisionType,
1111
StringType,
1212
Boolean,
13+
JSONType,
1314
)
1415
from .compiler import AbstractCompiler, Compilable

data_diff/sqeleton/abcs/database_types.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -134,6 +134,22 @@ class Text(StringType):
134134
supported = False
135135

136136

137+
class JSONType(ColType):
    """Base column type for JSON-like values.

    Columns whose type is an instance of this class are dispatched to
    ``normalize_json`` by ``normalize_value_by_type``.
    """


class RedShiftSuper(JSONType):
    """JSON-like column type mapped from Redshift's ``super`` type name."""


class PostgresqlJSON(JSONType):
    """JSON column type mapped from PostgreSQL's ``json`` type name."""


class PostgresqlJSONB(JSONType):
    """JSON column type mapped from PostgreSQL's ``jsonb`` type name."""
151+
152+
137153
@dataclass
138154
class Integer(NumericType, IKey):
139155
precision: int = 0

data_diff/sqeleton/abcs/mixins.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from abc import ABC, abstractmethod
2-
from .database_types import TemporalType, FractionalType, ColType_UUID, Boolean, ColType, String_UUID
2+
from .database_types import TemporalType, FractionalType, ColType_UUID, Boolean, ColType, String_UUID, JSONType
33
from .compiler import Compilable
44

55

@@ -49,6 +49,10 @@ def normalize_uuid(self, value: str, coltype: ColType_UUID) -> str:
4949
return f"TRIM({value})"
5050
return self.to_string(value)
5151

52+
def normalize_json(self, value: str, _coltype: JSONType) -> str:
    """Build an SQL expression that renders 'value' as a minified JSON string.

    Args:
        value: SQL expression (as text) that evaluates to the JSON value.
        _coltype: The JSON column type; unused by this base implementation.

    Raises:
        NotImplementedError: Always. This base version has no implementation
            and must be overridden to support JSON columns.
    """
    raise NotImplementedError
55+
5256
def normalize_value_by_type(self, value: str, coltype: ColType) -> str:
5357
"""Creates an SQL expression, that converts 'value' to a normalized representation.
5458
@@ -73,6 +77,8 @@ def normalize_value_by_type(self, value: str, coltype: ColType) -> str:
7377
return self.normalize_uuid(value, coltype)
7478
elif isinstance(coltype, Boolean):
7579
return self.normalize_boolean(value, coltype)
80+
elif isinstance(coltype, JSONType):
81+
return self.normalize_json(value, coltype)
7682
return self.to_string(value)
7783

7884

data_diff/sqeleton/databases/base.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535
DbTime,
3636
DbPath,
3737
Boolean,
38+
JSONType
3839
)
3940
from ..abcs.mixins import Compilable
4041
from ..abcs.mixins import (
@@ -259,6 +260,9 @@ def parse_type(
259260
elif issubclass(cls, (Text, Native_UUID)):
260261
return cls()
261262

263+
elif issubclass(cls, JSONType):
264+
return cls()
265+
262266
raise TypeError(f"Parsing {type_repr} returned an unknown type '{cls}'.")
263267

264268
def _convert_db_precision_to_digits(self, p: int) -> int:

data_diff/sqeleton/databases/oracle.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,7 @@ class Dialect(BaseDialect, Mixin_Schema, Mixin_OptimizerHints):
9292
"NCHAR": Text,
9393
"NVARCHAR2": Text,
9494
"VARCHAR2": Text,
95+
"DATE": Timestamp,
9596
}
9697
ROUNDS_ON_PREC_LOSS = True
9798
PLACEHOLDER_TABLE = "DUAL"

data_diff/sqeleton/databases/postgresql.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,8 @@
1111
FractionalType,
1212
Boolean,
1313
Date,
14+
PostgresqlJSON,
15+
PostgresqlJSONB
1416
)
1517
from ..abcs.mixins import AbstractMixin_MD5, AbstractMixin_NormalizeValue
1618
from .base import BaseDialect, ThreadedDatabase, import_helper, ConnectError, Mixin_Schema
@@ -49,6 +51,9 @@ def normalize_number(self, value: str, coltype: FractionalType) -> str:
4951
def normalize_boolean(self, value: str, _coltype: Boolean) -> str:
    """Return an SQL expression that renders the boolean 'value' as a string.

    The value is cast to ``int`` first, and the resulting expression is then
    wrapped by ``self.to_string``.

    Args:
        value: SQL expression (as text) that evaluates to a boolean.
        _coltype: The boolean column type; unused here.
    """
    as_int = f"{value}::int"
    return self.to_string(as_int)
5153

54+
def normalize_json(self, value: str, _coltype: PostgresqlJSON) -> str:
    """Return an SQL expression that casts the JSON 'value' to ``text``.

    Args:
        value: SQL expression (as text) that evaluates to a JSON value.
        _coltype: The JSON column type; unused here.
    """
    return "{}::text".format(value)
56+
5257

5358
class PostgresqlDialect(BaseDialect, Mixin_Schema):
5459
name = "PostgreSQL"
@@ -76,6 +81,9 @@ class PostgresqlDialect(BaseDialect, Mixin_Schema):
7681
"character varying": Text,
7782
"varchar": Text,
7883
"text": Text,
84+
# JSON
85+
"json": PostgresqlJSON,
86+
"jsonb": PostgresqlJSONB,
7987
# UUID
8088
"uuid": Native_UUID,
8189
# Boolean

data_diff/sqeleton/databases/redshift.py

Lines changed: 51 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,12 @@
11
from typing import List, Dict
2-
from ..abcs.database_types import Float, TemporalType, FractionalType, DbPath, TimestampTZ
2+
from ..abcs.database_types import (
3+
Float,
4+
TemporalType,
5+
FractionalType,
6+
DbPath,
7+
TimestampTZ,
8+
RedShiftSuper
9+
)
310
from ..abcs.mixins import AbstractMixin_MD5
411
from .postgresql import (
512
PostgreSQL,
@@ -40,13 +47,18 @@ def normalize_timestamp(self, value: str, coltype: TemporalType) -> str:
4047
def normalize_number(self, value: str, coltype: FractionalType) -> str:
    """Return an SQL expression that renders a fractional 'value' as a string.

    The value is cast to ``decimal(38, <coltype.precision>)`` and the resulting
    expression is wrapped by ``self.to_string``.

    Args:
        value: SQL expression (as text) that evaluates to a number.
        coltype: Column type whose ``precision`` is used as the decimal scale.
    """
    decimal_cast = f"{value}::decimal(38,{coltype.precision})"
    return self.to_string(decimal_cast)
4249

50+
def normalize_json(self, value: str, _coltype: RedShiftSuper) -> str:
    """Return an SQL expression that serializes the 'value' to a JSON string.

    The expression yields ``json_serialize(value)`` when 'value' is not NULL,
    and NULL otherwise (via ``nvl2``).

    Args:
        value: SQL expression (as text) that evaluates to the JSON-like value.
        _coltype: The column type; unused here.
    """
    return "nvl2({0}, json_serialize({0}), NULL)".format(value)
52+
4353

4454
class Dialect(PostgresqlDialect):
4555
name = "Redshift"
4656
TYPE_CLASSES = {
4757
**PostgresqlDialect.TYPE_CLASSES,
4858
"double": Float,
4959
"real": Float,
60+
# JSON
61+
"super": RedShiftSuper
5062
}
5163
SUPPORTS_INDEXES = False
5264

@@ -109,11 +121,48 @@ def query_external_table_schema(self, path: DbPath) -> Dict[str, tuple]:
109121
assert len(d) == len(rows)
110122
return d
111123

124+
def select_view_columns(self, path: DbPath) -> str:
    """Build the SQL query that lists a view's columns through ``pg_get_cols``.

    Args:
        path: Path of the view; it is resolved to a schema and a name with
            ``self._normalize_table_path``.

    Returns:
        A ``select`` statement over ``pg_get_cols('<schema>.<name>')`` whose
        result columns are declared as (view_schema, view_name, col_name,
        col_type, col_num).
    """
    _, schema, table = self._normalize_table_path(path)

    # The qualified name is embedded inside a single-quoted SQL literal, so any
    # single quote it contains must be doubled to keep the literal well-formed.
    qualified_name = f"{schema}.{table}".replace("'", "''")

    return (
        """select * from pg_get_cols('{}')
            cols(view_schema name, view_name name, col_name name, col_type varchar, col_num int)
        """.format(qualified_name)
    )
132+
133+
def query_pg_get_cols(self, path: DbPath) -> Dict[str, tuple]:
    """Query a view's column schema using the ``select_view_columns`` query.

    Args:
        path: Path of the view to inspect.

    Returns:
        A dict mapping each column name to the tuple
        ``(col_name, base_type, None, precision, scale)``. ``base_type`` is the
        reported type with any parenthesized arguments removed. ``precision``
        and ``scale`` are integers for ``numeric`` columns that carry
        arguments, and ``None`` otherwise.
        NOTE(review): the third slot is always ``None``; presumably it mirrors
        a slot of the tuples returned by the other schema queries - confirm.

    Raises:
        RuntimeError: If the query returns no rows.
    """
    rows = self.query(self.select_view_columns(path), list)

    if not rows:
        raise RuntimeError(f"{self.name}: View '{'.'.join(path)}' does not exist, or has no columns")

    output = {}
    for row in rows:
        col_name, col_type = row[2], row[3]
        # Split e.g. "numeric(10,2)" into "numeric" and "10,2)".
        base_type, _, type_args = col_type.partition("(")
        precision = scale = None

        if base_type == "numeric" and type_args:
            args = type_args.rstrip(")").split(",")
            precision = int(args[0])
            # "numeric(p)" carries no scale; SQL defines that as scale 0.
            scale = int(args[1]) if len(args) > 1 else 0

        output[col_name] = (col_name, base_type, None, precision, scale)

    return output
157+
112158
def query_table_schema(self, path: DbPath) -> Dict[str, tuple]:
    """Query the schema of a table, an external table, or a view.

    Lookups are attempted in order, each one only if the previous raised
    ``RuntimeError``:

    1. the parent class's ``query_table_schema``;
    2. ``query_external_table_schema``;
    3. ``query_pg_get_cols``.

    Args:
        path: Path of the table or view to inspect.

    Returns:
        The column-name-to-tuple mapping returned by the first lookup that
        succeeds.

    Raises:
        RuntimeError: Propagated from ``query_pg_get_cols`` when every lookup
            fails.
    """
    try:
        return super().query_table_schema(path)
    except RuntimeError:
        try:
            return self.query_external_table_schema(path)
        except RuntimeError:
            # The last fallback needs the path too; calling it without
            # arguments raised TypeError instead of returning the schema.
            return self.query_pg_get_cols(path)
117166

118167
def _normalize_table_path(self, path: DbPath) -> DbPath:
119168
if len(path) == 1:

0 commit comments

Comments
 (0)