Skip to content
This repository was archived by the owner on May 17, 2024. It is now read-only.

Commit d4589c7

Browse files
authored
Merge pull request #814 from datafold/collations-alignment
Retrieve collations from the schema (and refactor the column info structures)
2 parents 06579be + 9e83b7d commit d4589c7

17 files changed

+273
-142
lines changed

data_diff/__main__.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,8 @@
1212
from rich.logging import RichHandler
1313
import click
1414

15-
from data_diff import Database
16-
from data_diff.schema import create_schema
15+
from data_diff import Database, DbPath
16+
from data_diff.schema import RawColumnInfo, create_schema
1717
from data_diff.queries.api import current_timestamp
1818

1919
from data_diff.dbt import dbt_diff
@@ -72,7 +72,7 @@ def _remove_passwords_in_dict(d: dict) -> None:
7272
d[k] = remove_password_from_url(v)
7373

7474

75-
def _get_schema(pair):
75+
def _get_schema(pair: Tuple[Database, DbPath]) -> Dict[str, RawColumnInfo]:
7676
db, table_path = pair
7777
return db.query_table_schema(table_path)
7878

data_diff/abcs/database_types.py

Lines changed: 87 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import decimal
22
from abc import ABC, abstractmethod
3-
from typing import List, Optional, Tuple, Type, TypeVar, Union
3+
from typing import Collection, List, Optional, Tuple, Type, TypeVar, Union
44
from datetime import datetime
55

66
import attrs
@@ -15,6 +15,91 @@
1515
N = TypeVar("N")
1616

1717

18+
@attrs.frozen(kw_only=True, eq=False, order=False, unsafe_hash=True)
19+
class Collation:
20+
"""
21+
A pre-parsed or pre-known record about db collation, per column.
22+
23+
The "greater" collation should be used as a target collation for textual PKs
24+
on both sides of the diff — by converting the "lesser" collation to self.
25+
26+
Snowflake easily absorbs the performance losses, so it has a boost to always
27+
be greater than any other collation in non-Snowflake databases.
28+
Other databases need to negotiate which side absorbs the performance impact.
29+
"""
30+
31+
# A boost for special databases that are known to absorb the performance damage well.
32+
absorbs_damage: bool = False
33+
34+
# Ordinal sorting by ASCII/UTF8 (True), or alphabetic as per locale/country/etc (False).
35+
ordinal: Optional[bool] = None
36+
37+
# Lowercase first (aAbBcC or abcABC). Otherwise, uppercase first (AaBbCc or ABCabc).
38+
lower_first: Optional[bool] = None
39+
40+
# 2-letter lower-case locale and upper-case country codes, e.g. en_US. Ignored for ordinals.
41+
language: Optional[str] = None
42+
country: Optional[str] = None
43+
44+
# There are also space-, punctuation-, width-, kana-(in)sensitivity, so on.
45+
# Ignore everything not related to xdb alignment. Only case- & accent-sensitivity are common.
46+
case_sensitive: Optional[bool] = None
47+
accent_sensitive: Optional[bool] = None
48+
49+
# Purely informational, for debugging:
50+
_source: Union[None, str, Collection[str]] = None
51+
52+
def __eq__(self, other: object) -> bool:
53+
if not isinstance(other, Collation):
54+
return NotImplemented
55+
if self.ordinal and other.ordinal:
56+
# TODO: does it depend on language? what does Albanic_BIN mean in MS SQL?
57+
return True
58+
return (
59+
self.language == other.language
60+
and (self.country is None or other.country is None or self.country == other.country)
61+
and self.case_sensitive == other.case_sensitive
62+
and self.accent_sensitive == other.accent_sensitive
63+
and self.lower_first == other.lower_first
64+
)
65+
66+
def __ne__(self, other: object) -> bool:
67+
if not isinstance(other, Collation):
68+
return NotImplemented
69+
return not self.__eq__(other)
70+
71+
def __gt__(self, other: object) -> bool:
72+
if not isinstance(other, Collation):
73+
return NotImplemented
74+
if self == other:
75+
return False
76+
if self.absorbs_damage and not other.absorbs_damage:
77+
return False
78+
if other.absorbs_damage and not self.absorbs_damage:
79+
return True # this one is preferred if it cannot absorb damage as its counterpart can
80+
if self.ordinal and not other.ordinal:
81+
return True
82+
if other.ordinal and not self.ordinal:
83+
return False
84+
# TODO: try to align the languages & countries?
85+
return False
86+
87+
def __ge__(self, other: object) -> bool:
88+
if not isinstance(other, Collation):
89+
return NotImplemented
90+
return self == other or self.__gt__(other)
91+
92+
def __lt__(self, other: object) -> bool:
93+
if not isinstance(other, Collation):
94+
return NotImplemented
95+
return self != other and not self.__gt__(other)
96+
97+
def __le__(self, other: object) -> bool:
98+
if not isinstance(other, Collation):
99+
return NotImplemented
100+
return self == other or not self.__gt__(other)
101+
102+
18103
@attrs.define(frozen=True, kw_only=True)
19104
class ColType:
20105
# Arbitrary metadata added and fetched at runtime.
@@ -112,6 +197,7 @@ def python_type(self) -> type:
112197
@attrs.define(frozen=True)
113198
class StringType(ColType):
114199
python_type = str
200+
collation: Optional[Collation] = attrs.field(default=None, kw_only=True)
115201

116202

117203
@attrs.define(frozen=True)

data_diff/databases/base.py

Lines changed: 35 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919

2020
from data_diff.abcs.compiler import AbstractCompiler, Compilable
2121
from data_diff.queries.extras import ApplyFuncAndNormalizeAsString, Checksum, NormalizeAsString
22+
from data_diff.schema import RawColumnInfo
2223
from data_diff.utils import ArithString, is_uuid, join_iter, safezip
2324
from data_diff.queries.api import Expr, table, Select, SKIP, Explain, Code, this
2425
from data_diff.queries.ast_classes import (
@@ -707,27 +708,18 @@ def type_repr(self, t) -> str:
707708
datetime: "TIMESTAMP",
708709
}[t]
709710

710-
def _parse_type_repr(self, type_repr: str) -> Optional[Type[ColType]]:
711-
return self.TYPE_CLASSES.get(type_repr)
712-
713-
def parse_type(
714-
self,
715-
table_path: DbPath,
716-
col_name: str,
717-
type_repr: str,
718-
datetime_precision: int = None,
719-
numeric_precision: int = None,
720-
numeric_scale: int = None,
721-
) -> ColType:
711+
def parse_type(self, table_path: DbPath, info: RawColumnInfo) -> ColType:
722712
"Parse type info as returned by the database"
723713

724-
cls = self._parse_type_repr(type_repr)
714+
cls = self.TYPE_CLASSES.get(info.data_type)
725715
if cls is None:
726-
return UnknownColType(type_repr)
716+
return UnknownColType(info.data_type)
727717

728718
if issubclass(cls, TemporalType):
729719
return cls(
730-
precision=datetime_precision if datetime_precision is not None else DEFAULT_DATETIME_PRECISION,
720+
precision=info.datetime_precision
721+
if info.datetime_precision is not None
722+
else DEFAULT_DATETIME_PRECISION,
731723
rounds=self.ROUNDS_ON_PREC_LOSS,
732724
)
733725

@@ -738,22 +730,22 @@ def parse_type(
738730
return cls()
739731

740732
elif issubclass(cls, Decimal):
741-
if numeric_scale is None:
742-
numeric_scale = 0 # Needed for Oracle.
743-
return cls(precision=numeric_scale)
733+
if info.numeric_scale is None:
734+
return cls(precision=0) # Needed for Oracle.
735+
return cls(precision=info.numeric_scale)
744736

745737
elif issubclass(cls, Float):
746738
# assert numeric_scale is None
747739
return cls(
748740
precision=self._convert_db_precision_to_digits(
749-
numeric_precision if numeric_precision is not None else DEFAULT_NUMERIC_PRECISION
741+
info.numeric_precision if info.numeric_precision is not None else DEFAULT_NUMERIC_PRECISION
750742
)
751743
)
752744

753745
elif issubclass(cls, (JSON, Array, Struct, Text, Native_UUID)):
754746
return cls()
755747

756-
raise TypeError(f"Parsing {type_repr} returned an unknown type '{cls}'.")
748+
raise TypeError(f"Parsing {info.data_type} returned an unknown type {cls!r}.")
757749

758750
def _convert_db_precision_to_digits(self, p: int) -> int:
759751
"""Convert from binary precision, used by floats, to decimal precision."""
@@ -1018,7 +1010,7 @@ def select_table_schema(self, path: DbPath) -> str:
10181010
f"WHERE table_name = '{name}' AND table_schema = '{schema}'"
10191011
)
10201012

1021-
def query_table_schema(self, path: DbPath) -> Dict[str, tuple]:
1013+
def query_table_schema(self, path: DbPath) -> Dict[str, RawColumnInfo]:
10221014
"""Query the table for its schema for table in 'path', and return {column: tuple}
10231015
where the tuple is (table_name, col_name, type_repr, datetime_precision?, numeric_precision?, numeric_scale?)
10241016
@@ -1029,7 +1021,17 @@ def query_table_schema(self, path: DbPath) -> Dict[str, tuple]:
10291021
if not rows:
10301022
raise RuntimeError(f"{self.name}: Table '{'.'.join(path)}' does not exist, or has no columns")
10311023

1032-
d = {r[0]: r for r in rows}
1024+
d = {
1025+
r[0]: RawColumnInfo(
1026+
column_name=r[0],
1027+
data_type=r[1],
1028+
datetime_precision=r[2],
1029+
numeric_precision=r[3],
1030+
numeric_scale=r[4],
1031+
collation_name=r[5] if len(r) > 5 else None,
1032+
)
1033+
for r in rows
1034+
}
10331035
assert len(d) == len(rows)
10341036
return d
10351037

@@ -1051,7 +1053,11 @@ def query_table_unique_columns(self, path: DbPath) -> List[str]:
10511053
return list(res)
10521054

10531055
def _process_table_schema(
1054-
self, path: DbPath, raw_schema: Dict[str, tuple], filter_columns: Sequence[str] = None, where: str = None
1056+
self,
1057+
path: DbPath,
1058+
raw_schema: Dict[str, RawColumnInfo],
1059+
filter_columns: Sequence[str] = None,
1060+
where: str = None,
10551061
):
10561062
"""Process the result of query_table_schema().
10571063
@@ -1067,7 +1073,7 @@ def _process_table_schema(
10671073
accept = {i.lower() for i in filter_columns}
10681074
filtered_schema = {name: row for name, row in raw_schema.items() if name.lower() in accept}
10691075

1070-
col_dict = {row[0]: self.dialect.parse_type(path, *row) for _name, row in filtered_schema.items()}
1076+
col_dict = {info.column_name: self.dialect.parse_type(path, info) for info in filtered_schema.values()}
10711077

10721078
self._refine_coltypes(path, col_dict, where)
10731079

@@ -1076,15 +1082,15 @@ def _process_table_schema(
10761082

10771083
def _refine_coltypes(
10781084
self, table_path: DbPath, col_dict: Dict[str, ColType], where: Optional[str] = None, sample_size=64
1079-
):
1085+
) -> Dict[str, ColType]:
10801086
"""Refine the types in the column dict, by querying the database for a sample of their values
10811087
10821088
'where' restricts the rows to be sampled.
10831089
"""
10841090

10851091
text_columns = [k for k, v in col_dict.items() if isinstance(v, Text)]
10861092
if not text_columns:
1087-
return
1093+
return col_dict
10881094

10891095
fields = [Code(self.dialect.normalize_uuid(self.dialect.quote(c), String_UUID())) for c in text_columns]
10901096

@@ -1116,7 +1122,9 @@ def _refine_coltypes(
11161122
)
11171123
else:
11181124
assert col_name in col_dict
1119-
col_dict[col_name] = String_VaryingAlphanum()
1125+
col_dict[col_name] = String_VaryingAlphanum(collation=col_dict[col_name].collation)
1126+
1127+
return col_dict
11201128

11211129
def _normalize_table_path(self, path: DbPath) -> DbPath:
11221130
if len(path) == 1:

data_diff/databases/bigquery.py

Lines changed: 7 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
MD5_HEXDIGITS,
3434
)
3535
from data_diff.databases.base import TIMESTAMP_PRECISION_POS, ThreadLocalInterpreter
36+
from data_diff.schema import RawColumnInfo
3637

3738

3839
@import_helper(text="Please install BigQuery and configure your google-cloud access.")
@@ -91,27 +92,21 @@ def type_repr(self, t) -> str:
9192
except KeyError:
9293
return super().type_repr(t)
9394

94-
def parse_type(
95-
self,
96-
table_path: DbPath,
97-
col_name: str,
98-
type_repr: str,
99-
*args: Any, # pass-through args
100-
**kwargs: Any, # pass-through args
101-
) -> ColType:
102-
col_type = super().parse_type(table_path, col_name, type_repr, *args, **kwargs)
95+
def parse_type(self, table_path: DbPath, info: RawColumnInfo) -> ColType:
96+
col_type = super().parse_type(table_path, info)
10397
if isinstance(col_type, UnknownColType):
104-
m = self.TYPE_ARRAY_RE.fullmatch(type_repr)
98+
m = self.TYPE_ARRAY_RE.fullmatch(info.data_type)
10599
if m:
106-
item_type = self.parse_type(table_path, col_name, m.group(1), *args, **kwargs)
100+
item_info = attrs.evolve(info, data_type=m.group(1))
101+
item_type = self.parse_type(table_path, item_info)
107102
col_type = Array(item_type=item_type)
108103

109104
# We currently ignore structs' structure, but later can parse it too. Examples:
110105
# - STRUCT<INT64, STRING(10)> (unnamed)
111106
# - STRUCT<foo INT64, bar STRING(10)> (named)
112107
# - STRUCT<foo INT64, bar ARRAY<INT64>> (with complex fields)
113108
# - STRUCT<foo INT64, bar STRUCT<a INT64, b INT64>> (nested)
114-
m = self.TYPE_STRUCT_RE.fullmatch(type_repr)
109+
m = self.TYPE_STRUCT_RE.fullmatch(info.data_type)
115110
if m:
116111
col_type = Struct()
117112

data_diff/databases/clickhouse.py

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
)
1515
from data_diff.abcs.database_types import (
1616
ColType,
17+
DbPath,
1718
Decimal,
1819
Float,
1920
Integer,
@@ -24,6 +25,7 @@
2425
Timestamp,
2526
Boolean,
2627
)
28+
from data_diff.schema import RawColumnInfo
2729

2830
# https://clickhouse.com/docs/en/operations/server-configuration-parameters/settings/#default-database
2931
DEFAULT_DATABASE = "default"
@@ -75,19 +77,19 @@ def _convert_db_precision_to_digits(self, p: int) -> int:
7577
# because it does not help for float with a big integer part.
7678
return super()._convert_db_precision_to_digits(p) - 2
7779

78-
def _parse_type_repr(self, type_repr: str) -> Optional[Type[ColType]]:
80+
def parse_type(self, table_path: DbPath, info: RawColumnInfo) -> ColType:
7981
nullable_prefix = "Nullable("
80-
if type_repr.startswith(nullable_prefix):
81-
type_repr = type_repr[len(nullable_prefix) :].rstrip(")")
82+
if info.data_type.startswith(nullable_prefix):
83+
info = attrs.evolve(info, data_type=info.data_type[len(nullable_prefix) :].rstrip(")"))
8284

83-
if type_repr.startswith("Decimal"):
84-
type_repr = "Decimal"
85-
elif type_repr.startswith("FixedString"):
86-
type_repr = "FixedString"
87-
elif type_repr.startswith("DateTime64"):
88-
type_repr = "DateTime64"
85+
if info.data_type.startswith("Decimal"):
86+
info = attrs.evolve(info, data_type="Decimal")
87+
elif info.data_type.startswith("FixedString"):
88+
info = attrs.evolve(info, data_type="FixedString")
89+
elif info.data_type.startswith("DateTime64"):
90+
info = attrs.evolve(info, data_type="DateTime64")
8991

90-
return self.TYPE_CLASSES.get(type_repr)
92+
return super().parse_type(table_path, info)
9193

9294
# def timestamp_value(self, t: DbTime) -> str:
9395
# # return f"'{t}'"

0 commit comments

Comments
 (0)