Skip to content
This repository was archived by the owner on May 17, 2024. It is now read-only.

Compare JSON, ARRAY, STRUCT types in BigQuery (simplistically) #533

Merged
merged 1 commit into from
May 5, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 20 additions & 0 deletions data_diff/sqeleton/abcs/database_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
DbTime = datetime


@dataclass
class ColType:
supported = True

Expand Down Expand Up @@ -140,6 +141,21 @@ class JSON(ColType):
pass


# Column type for array values.
# ``item_type`` holds the column type of the array's elements.
@dataclass
class Array(ColType):
item_type: ColType


# Unlike JSON, structs are not free-form: they have a specific set of fields, each with its own type.
# We do not parse or use those fields yet, but this can be added later.
# For example, in BigQuery:
# - https://cloud.google.com/bigquery/docs/reference/standard-sql/data-types#struct_type
# - https://cloud.google.com/bigquery/docs/reference/standard-sql/lexical#struct_literals
@dataclass
class Struct(ColType):
pass


@dataclass
class Integer(NumericType, IKey):
precision: int = 0
Expand Down Expand Up @@ -227,6 +243,10 @@ def parse_type(
) -> ColType:
"Parse type info as returned by the database"

@abstractmethod
def to_comparable(self, value: str, coltype: ColType) -> str:
"""Ensure that the expression is comparable in ``IS DISTINCT FROM``.

Takes the SQL expression ``value`` (as a string) together with its column type,
and returns an SQL expression (as a string) to use as an ``IS DISTINCT FROM`` operand.
"""


from typing import TypeVar, Generic

Expand Down
21 changes: 19 additions & 2 deletions data_diff/sqeleton/abcs/mixins.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from abc import ABC, abstractmethod
from .database_types import TemporalType, FractionalType, ColType_UUID, Boolean, ColType, String_UUID, JSON
from .database_types import Array, TemporalType, FractionalType, ColType_UUID, Boolean, ColType, String_UUID, JSON, Struct
from .compiler import Compilable


Expand All @@ -8,6 +8,11 @@ class AbstractMixin(ABC):


class AbstractMixin_NormalizeValue(AbstractMixin):

@abstractmethod
def to_comparable(self, value: str, coltype: ColType) -> str:
"""Ensure that the expression is comparable in ``IS DISTINCT FROM``.

Takes the SQL expression ``value`` (as a string) together with its column type,
and returns an SQL expression (as a string) to use as an ``IS DISTINCT FROM`` operand.
"""

@abstractmethod
def normalize_timestamp(self, value: str, coltype: TemporalType) -> str:
"""Creates an SQL expression, that converts 'value' to a normalized timestamp.
Expand Down Expand Up @@ -51,7 +56,15 @@ def normalize_uuid(self, value: str, coltype: ColType_UUID) -> str:

def normalize_json(self, value: str, _coltype: JSON) -> str:
    """Creates an SQL expression, that converts 'value' to its json string representation.

    The default implementation delegates to ``to_string``, exactly like
    ``normalize_array`` and ``normalize_struct``; dialects may override it
    with a database-specific serialization.

    ``value`` is the SQL expression to normalize; ``_coltype`` is accepted for
    interface symmetry and is not used here. Returns an SQL expression as a string.
    """
    # Default: delegate to the dialect's generic string conversion.
    return self.to_string(value)

def normalize_array(self, value: str, _coltype: Array) -> str:
    """Creates an SQL expression that serializes an array into a string (intended to be JSON).

    This default simply delegates to ``to_string``; ``_coltype`` is not used.
    """
    as_text = self.to_string(value)
    return as_text

def normalize_struct(self, value: str, _coltype: Struct) -> str:
    """Creates an SQL expression that serializes a typed struct into a string (intended to be JSON).

    This default simply delegates to ``to_string``; ``_coltype`` is not used.
    """
    as_text = self.to_string(value)
    return as_text

def normalize_value_by_type(self, value: str, coltype: ColType) -> str:
"""Creates an SQL expression, that converts 'value' to a normalized representation.
Expand Down Expand Up @@ -79,6 +92,10 @@ def normalize_value_by_type(self, value: str, coltype: ColType) -> str:
return self.normalize_boolean(value, coltype)
elif isinstance(coltype, JSON):
return self.normalize_json(value, coltype)
elif isinstance(coltype, Array):
return self.normalize_array(value, coltype)
elif isinstance(coltype, Struct):
return self.normalize_struct(value, coltype)
return self.to_string(value)


Expand Down
14 changes: 8 additions & 6 deletions data_diff/sqeleton/databases/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,8 @@
from ..queries.ast_classes import Random
from ..abcs.database_types import (
AbstractDatabase,
T_Dialect,
Array,
Struct,
AbstractDialect,
AbstractTable,
ColType,
Expand Down Expand Up @@ -165,6 +166,10 @@ def concat(self, items: List[str]) -> str:
joined_exprs = ", ".join(items)
return f"concat({joined_exprs})"

def to_comparable(self, value: str, coltype: ColType) -> str:
"""Ensure that the expression is comparable in ``IS DISTINCT FROM``.

This implementation returns ``value`` unchanged and does not look at ``coltype``.
"""
return value

def is_distinct_from(self, a: str, b: str) -> str:
    """Build the SQL predicate ``<a> is distinct from <b>`` from two SQL expressions."""
    predicate = "{} is distinct from {}".format(a, b)
    return predicate

Expand Down Expand Up @@ -229,7 +234,7 @@ def parse_type(
""" """

cls = self._parse_type_repr(type_repr)
if not cls:
if cls is None:
return UnknownColType(type_repr)

if issubclass(cls, TemporalType):
Expand Down Expand Up @@ -257,10 +262,7 @@ def parse_type(
)
)

elif issubclass(cls, (Text, Native_UUID)):
return cls()

elif issubclass(cls, JSON):
elif issubclass(cls, (JSON, Array, Struct, Text, Native_UUID)):
return cls()

raise TypeError(f"Parsing {type_repr} returned an unknown type '{cls}'.")
Expand Down
69 changes: 66 additions & 3 deletions data_diff/sqeleton/databases/bigquery.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,10 @@
from typing import List, Union
import re
from typing import Any, List, Union
from ..abcs.database_types import (
ColType,
Array,
JSON,
Struct,
Timestamp,
Datetime,
Integer,
Expand All @@ -10,6 +15,7 @@
FractionalType,
TemporalType,
Boolean,
UnknownColType,
)
from ..abcs.mixins import (
AbstractMixin_MD5,
Expand All @@ -36,6 +42,7 @@ def md5_as_int(self, s: str) -> str:


class Mixin_NormalizeValue(AbstractMixin_NormalizeValue):

def normalize_timestamp(self, value: str, coltype: TemporalType) -> str:
if coltype.rounds:
timestamp = f"timestamp_micros(cast(round(unix_micros(cast({value} as timestamp))/1000000, {coltype.precision})*1000000 as int))"
Expand All @@ -57,6 +64,27 @@ def normalize_number(self, value: str, coltype: FractionalType) -> str:
def normalize_boolean(self, value: str, _coltype: Boolean) -> str:
    """Cast the boolean expression to ``int`` and pass the result to ``to_string``.

    ``_coltype`` is not used.
    """
    as_int = f"cast({value} as int)"
    return self.to_string(as_int)

def normalize_json(self, value: str, _coltype: JSON) -> str:
    """Wrap the expression in ``to_json_string(...)`` so it is compared as text.

    ``_coltype`` is not used.
    """
    # BigQuery cannot compare arrays & structs with ==/!=/distinct from, e.g.:
    #   Got error: 400 Grouping is not defined for arguments of type ARRAY<INT64> at …
    # Best effort: compare the values as strings, hoping that the JSON forms match
    # on both sides (properly ordered keys, same spacing, same quotes, etc.).
    return "to_json_string({})".format(value)

def normalize_array(self, value: str, _coltype: Array) -> str:
    """Wrap the array expression in ``to_json_string(...)`` so it is compared as text.

    ``_coltype`` is not used.
    """
    # BigQuery cannot compare arrays & structs with ==/!=/distinct from, e.g.:
    #   Got error: 400 Grouping is not defined for arguments of type ARRAY<INT64> at …
    # Best effort: compare the values as strings, hoping that the JSON forms match
    # on both sides (properly ordered keys, same spacing, same quotes, etc.).
    return "to_json_string({})".format(value)

def normalize_struct(self, value: str, _coltype: Struct) -> str:
    """Wrap the struct expression in ``to_json_string(...)`` so it is compared as text.

    ``_coltype`` is not used.
    """
    # BigQuery cannot compare arrays & structs with ==/!=/distinct from, e.g.:
    #   Got error: 400 Grouping is not defined for arguments of type ARRAY<INT64> at …
    # Best effort: compare the values as strings, hoping that the JSON forms match
    # on both sides (properly ordered keys, same spacing, same quotes, etc.).
    return "to_json_string({})".format(value)


class Mixin_Schema(AbstractMixin_Schema):
def list_tables(self, table_schema: str, like: Compilable = None) -> Compilable:
Expand Down Expand Up @@ -112,11 +140,12 @@ class Dialect(BaseDialect, Mixin_Schema):
"BIGNUMERIC": Decimal,
"FLOAT64": Float,
"FLOAT32": Float,
# Text
"STRING": Text,
# Boolean
"BOOL": Boolean,
"JSON": JSON,
}
TYPE_ARRAY_RE = re.compile(r'ARRAY<(.+)>')
TYPE_STRUCT_RE = re.compile(r'STRUCT<(.+)>')
MIXINS = {Mixin_Schema, Mixin_MD5, Mixin_NormalizeValue, Mixin_TimeTravel, Mixin_RandomSample}

def random(self) -> str:
Expand All @@ -134,6 +163,40 @@ def type_repr(self, t) -> str:
except KeyError:
return super().type_repr(t)

def parse_type(
    self,
    table_path: DbPath,
    col_name: str,
    type_repr: str,
    *args: Any,  # pass-through args
    **kwargs: Any,  # pass-through args
) -> ColType:
    """Parse a column type, additionally recognising ``ARRAY<...>`` and ``STRUCT<...>``.

    The parent implementation is tried first. Only when it yields an
    ``UnknownColType`` is ``type_repr`` matched against the array and struct
    patterns. Extra positional/keyword arguments are forwarded untouched.
    """
    parsed = super().parse_type(table_path, col_name, type_repr, *args, **kwargs)
    if not isinstance(parsed, UnknownColType):
        return parsed

    array_match = self.TYPE_ARRAY_RE.fullmatch(type_repr)
    if array_match:
        # Parse the element type recursively, using the text between the angle brackets.
        element_type = self.parse_type(table_path, col_name, array_match.group(1), *args, **kwargs)
        parsed = Array(item_type=element_type)

    # We currently ignore structs' structure, but later can parse it too. Examples:
    # - STRUCT<INT64, STRING(10)> (unnamed)
    # - STRUCT<foo INT64, bar STRING(10)> (named)
    # - STRUCT<foo INT64, bar ARRAY<INT64>> (with complex fields)
    # - STRUCT<foo INT64, bar STRUCT<a INT64, b INT64>> (nested)
    struct_match = self.TYPE_STRUCT_RE.fullmatch(type_repr)
    if struct_match:
        parsed = Struct()

    return parsed

def to_comparable(self, value: str, coltype: ColType) -> str:
    """Ensure that the expression is comparable in ``IS DISTINCT FROM``.

    JSON, array and struct expressions are routed through
    ``normalize_value_by_type``; every other type is handled by the parent
    implementation.
    """
    if not isinstance(coltype, (JSON, Array, Struct)):
        return super().to_comparable(value, coltype)
    return self.normalize_value_by_type(value, coltype)

def set_timezone_to_utc(self) -> str:
# Not implemented here: calling this always raises NotImplementedError.
raise NotImplementedError()

Expand Down
4 changes: 3 additions & 1 deletion data_diff/sqeleton/queries/ast_classes.py
Original file line number Diff line number Diff line change
Expand Up @@ -352,7 +352,9 @@ class IsDistinctFrom(ExprNode, LazyOps):
type = bool

def compile(self, c: Compiler) -> str:
    """Compile this node into an ``IS DISTINCT FROM`` SQL predicate.

    Each operand is compiled and then passed, together with its ``type``,
    through the dialect's ``to_comparable`` so the dialect can rewrite
    operands it cannot compare directly. The two results are combined by
    the dialect's ``is_distinct_from``.

    ``c`` provides ``compile`` and the target ``dialect``. Returns the SQL
    text of the predicate.
    """
    dialect = c.dialect
    a = dialect.to_comparable(c.compile(self.a), self.a.type)
    b = dialect.to_comparable(c.compile(self.b), self.b.type)
    return dialect.is_distinct_from(a, b)


@dataclass(eq=False, order=False)
Expand Down
3 changes: 3 additions & 0 deletions tests/sqeleton/test_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,9 @@ def concat(self, l: List[str]) -> str:
s = ", ".join(l)
return f"concat({s})"

def to_comparable(self, s: str, coltype=None) -> str:
    """Return the SQL expression unchanged; this test dialect needs no conversion.

    ``coltype`` is accepted and ignored. ``IsDistinctFrom.compile`` calls
    ``dialect.to_comparable(expr, type)`` with two arguments, so a
    one-argument signature would raise ``TypeError`` there. The default keeps
    existing one-argument calls working.
    """
    return s

def to_string(self, s: str) -> str:
    """Wrap the SQL expression in a cast to ``varchar``."""
    casted = "cast({} as varchar)".format(s)
    return casted

Expand Down