Skip to content
This repository was archived by the owner on May 17, 2024. It is now read-only.

Commit 5b4dc4d

Browse files
committed
Support for varying alphanums, with special characters
Re-implementation of alphanums, segmented without the use of intermediary ints
1 parent b54e925 commit 5b4dc4d

File tree

7 files changed

+185
-77
lines changed

7 files changed

+185
-77
lines changed

data_diff/databases/base.py

Lines changed: 17 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,10 @@
1616
Float,
1717
ColType_UUID,
1818
Native_UUID,
19-
String_Alphanum,
2019
String_UUID,
20+
String_Alphanum,
21+
String_FixedAlphanum,
22+
String_VaryingAlphanum,
2123
TemporalType,
2224
UnknownColType,
2325
Text,
@@ -79,6 +81,7 @@ class Database(AbstractDatabase):
7981

8082
TYPE_CLASSES: Dict[str, type] = {}
8183
default_schema: str = None
84+
SUPPORTS_ALPHANUMS = True
8285

8386
@property
8487
def name(self):
@@ -229,23 +232,22 @@ def _refine_coltypes(self, table_path: DbPath, col_dict: Dict[str, ColType], whe
229232
col_dict[col_name] = String_UUID()
230233
continue
231234

232-
alphanum_samples = [s for s in samples if s and String_Alphanum.test_value(s)]
233-
if alphanum_samples:
234-
if len(alphanum_samples) != len(samples):
235-
logger.warning(
236-
f"Mixed Alphanum/Non-Alphanum values detected in column {'.'.join(table_path)}.{col_name}, disabling Alphanum support."
237-
)
238-
else:
239-
assert col_name in col_dict
240-
lens = set(map(len, alphanum_samples))
241-
if len(lens) > 1:
235+
if self.SUPPORTS_ALPHANUMS: # Anything but MySQL (so far)
236+
alphanum_samples = [s for s in samples if String_Alphanum.test_value(s)]
237+
if alphanum_samples:
238+
if len(alphanum_samples) != len(samples):
242239
logger.warning(
243-
f"Mixed Alphanum lengths detected in column {'.'.join(table_path)}.{col_name}, disabling Alphanum support."
240+
f"Mixed Alphanum/Non-Alphanum values detected in column {'.'.join(table_path)}.{col_name}. It cannot be used as a key."
244241
)
245242
else:
246-
(length,) = lens
247-
col_dict[col_name] = String_Alphanum(length=length)
248-
continue
243+
assert col_name in col_dict
244+
lens = set(map(len, alphanum_samples))
245+
if len(lens) > 1:
246+
col_dict[col_name] = String_VaryingAlphanum()
247+
else:
248+
(length,) = lens
249+
col_dict[col_name] = String_FixedAlphanum(length=length)
250+
continue
249251

250252
# @lru_cache()
251253
# def get_table_schema(self, path: DbPath) -> Dict[str, ColType]:

data_diff/databases/database_types.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -92,10 +92,7 @@ class String_UUID(StringType, ColType_UUID):
9292
pass
9393

9494

95-
@dataclass
9695
class String_Alphanum(StringType, ColType_Alphanum):
97-
length: int
98-
9996
@staticmethod
10097
def test_value(value: str) -> bool:
10198
try:
@@ -104,6 +101,18 @@ def test_value(value: str) -> bool:
104101
except ValueError:
105102
return False
106103

104+
def make_value(self, value):
105+
return self.python_type(value)
106+
107+
108+
class String_VaryingAlphanum(String_Alphanum):
109+
pass
110+
111+
112+
@dataclass
113+
class String_FixedAlphanum(String_Alphanum):
114+
length: int
115+
107116
def make_value(self, value):
108117
if len(value) != self.length:
109118
raise ValueError(f"Expected alphanumeric value of length {self.length}, but got '{value}'.")

data_diff/databases/mysql.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ class MySQL(ThreadedDatabase):
2828
"binary": Text,
2929
}
3030
ROUNDS_ON_PREC_LOSS = True
31+
SUPPORTS_ALPHANUMS = False
3132

3233
def __init__(self, *, thread_count, **kw):
3334
self._args = kw

data_diff/diff_tables.py

Lines changed: 14 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -121,7 +121,7 @@ def diff_tables(self, table1: TableSegment, table2: TableSegment) -> DiffResult:
121121
logger.info(
122122
f"Diffing tables | segments: {self.bisection_factor}, bisection threshold: {self.bisection_threshold}. "
123123
f"key-range: {table1.min_key}..{table2.max_key}, "
124-
f"size: {table1.approximate_size()}"
124+
f"size: table1 <= {table1.approximate_size()}, table2 <= {table2.approximate_size()}"
125125
)
126126

127127
# Bisect (split) the table into segments, and diff them recursively.
@@ -218,12 +218,12 @@ def _validate_and_adjust_columns(self, table1, table2):
218218
"If encoding/formatting differs between databases, it may result in false positives."
219219
)
220220

221-
def _bisect_and_diff_tables(self, table1, table2, level=0, max_rows=None):
221+
def _bisect_and_diff_tables(self, table1: TableSegment, table2: TableSegment, level=0, max_rows=None):
222222
assert table1.is_bounded and table2.is_bounded
223223

224224
if max_rows is None:
225225
# We can be sure that row_count <= max_rows
226-
max_rows = table1.max_key - table1.min_key
226+
max_rows = max(table1.approximate_size(), table2.approximate_size())
227227

228228
# If count is below the threshold, just download and compare the columns locally
229229
# This saves time, as bisection speed is limited by ping and query performance.
@@ -254,37 +254,38 @@ def _bisect_and_diff_tables(self, table1, table2, level=0, max_rows=None):
254254

255255
# Recursively compare each pair of corresponding segments between table1 and table2
256256
diff_iters = [
257-
self._diff_tables(t1, t2, level + 1, i + 1, len(segmented1))
257+
self._diff_tables(t1, t2, max_rows, level + 1, i + 1, len(segmented1))
258258
for i, (t1, t2) in enumerate(safezip(segmented1, segmented2))
259259
]
260260

261261
for res in self._thread_map(list, diff_iters):
262262
yield from res
263263

264-
def _diff_tables(self, table1, table2, level=0, segment_index=None, segment_count=None):
264+
def _diff_tables(
265+
self, table1: TableSegment, table2: TableSegment, max_rows: int, level=0, segment_index=None, segment_count=None
266+
):
265267
logger.info(
266268
". " * level + f"Diffing segment {segment_index}/{segment_count}, "
267269
f"key-range: {table1.min_key}..{table2.max_key}, "
268-
f"size: {table2.max_key-table1.min_key}"
270+
f"size <= {max_rows}"
269271
)
270272

271273
# When benchmarking, we want the ability to skip checksumming. This
272274
# allows us to download all rows for comparison in performance. By
273275
# default, data-diff will checksum the section first (when it's below
274276
# the threshold) and _then_ download it.
275277
if BENCHMARK:
276-
max_rows_from_keys = max(table1.max_key - table1.min_key, table2.max_key - table2.min_key)
277-
if max_rows_from_keys < self.bisection_threshold:
278-
yield from self._bisect_and_diff_tables(table1, table2, level=level, max_rows=max_rows_from_keys)
278+
if max_rows < self.bisection_threshold:
279+
yield from self._bisect_and_diff_tables(table1, table2, level=level, max_rows=max_rows)
279280
return
280281

281282
(count1, checksum1), (count2, checksum2) = self._threaded_call("count_and_checksum", [table1, table2])
282283

283284
if count1 == 0 and count2 == 0:
284-
logger.warning(
285-
"Uneven distribution of keys detected. (big gaps in the key column). "
286-
"For better performance, we recommend to increase the bisection-threshold."
287-
)
285+
# logger.warning(
286+
# f"Uneven distribution of keys detected in segment {table1.min_key}..{table2.max_key}. (big gaps in the key column). "
287+
# "For better performance, we recommend to increase the bisection-threshold."
288+
# )
288289
assert checksum1 is None and checksum2 is None
289290
return
290291

data_diff/table_segment.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55
from runtype import dataclass
66

7-
from .utils import ArithString, split_space
7+
from .utils import ArithString, split_space, ArithAlphanumeric
88

99
from .databases.base import Database
1010
from .databases.database_types import DbPath, DbKey, DbTime, Native_UUID, Schema, create_schema
@@ -149,8 +149,9 @@ def choose_checkpoints(self, count: int) -> List[DbKey]:
149149
assert self.is_bounded
150150
if isinstance(self.min_key, ArithString):
151151
assert type(self.min_key) is type(self.max_key)
152-
checkpoints = split_space(self.min_key.int, self.max_key.int, count)
153-
return [self.min_key.new(int=i) for i in checkpoints]
152+
checkpoints = self.min_key.range(self.max_key, count)
153+
assert all(self.min_key <= x <= self.max_key for x in checkpoints)
154+
return checkpoints
154155

155156
return split_space(self.min_key, self.max_key, count)
156157

data_diff/utils.py

Lines changed: 62 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
import string
1010
import threading
1111

12-
alphanums = string.digits + string.ascii_lowercase
12+
alphanums = " -" + string.digits + string.ascii_uppercase + "_" + string.ascii_lowercase
1313

1414

1515
def safezip(*args):
@@ -29,6 +29,11 @@ class ArithString:
2929
def new(cls, *args, **kw):
3030
return cls(*args, **kw)
3131

32+
def range(self, other: "ArithString", count: int):
    """Return checkpoints between this value and ``other`` as new instances.

    Both endpoints are converted through their ``int`` attribute, the integer
    interval is divided by ``split_space`` and every resulting integer is
    wrapped again with ``self.new(int=...)``.

    NOTE(review): how many points are produced and whether the endpoints are
    included is decided by ``split_space`` -- confirm against its definition.
    """
    assert isinstance(other, ArithString)
    return [self.new(int=point) for point in split_space(self.int, other.int, count)]
36+
3237

3338
class ArithUUID(UUID, ArithString):
3439
"A UUID that supports basic arithmetic (add, sub)"
@@ -49,70 +54,96 @@ def __sub__(self, other: Union[UUID, int]):
4954
return NotImplemented
5055

5156

52-
def numberToBase(num, base):
57+
def numberToAlphanum(num: int, base: str = alphanums) -> str:
5358
digits = []
5459
while num > 0:
55-
num, remainder = divmod(num, base)
60+
num, remainder = divmod(num, len(base))
5661
digits.append(remainder)
57-
return "".join(alphanums[i] for i in digits[::-1])
62+
return "".join(base[i] for i in digits[::-1])
5863

5964

60-
class ArithAlphanumeric(ArithString):
61-
def __init__(self, str: str = None, int: int = None, max_len=None):
62-
if str is None:
63-
str = numberToBase(int, len(alphanums))
64-
else:
65-
assert int is None
65+
def alphanumToNumber(alphanum: str, base: str) -> int:
    """Interpret ``alphanum`` as a positional number whose digits are ``base``.

    The position of a character inside ``base`` is its digit value and
    ``len(base)`` is the radix. The leftmost character is the most
    significant digit. An empty ``alphanum`` yields ``0``.

    Raises:
        ValueError: if ``alphanum`` contains a character that is not in ``base``.
    """
    radix = len(base)
    num = 0
    try:
        for c in alphanum:
            num = num * radix + base.index(c)
    except ValueError:
        # str.index only reports "substring not found"; name the offending
        # character so a bad key value can actually be diagnosed.
        raise ValueError(
            f"Character {c!r} in {alphanum!r} is not part of the alphanum base {base!r}"
        ) from None
    return num
70+
71+
72+
def justify_alphanums(s1: str, s2: str):
    """Return both strings right-padded with spaces to the longer one's length.

    ``str.ljust`` pads with a space by default, so the shorter string gets
    trailing spaces; the longer one is returned unchanged.
    """
    width = len(s1) if len(s1) >= len(s2) else len(s2)
    return s1.ljust(width), s2.ljust(width)
6677

67-
if max_len and len(str) > max_len:
78+
79+
def alphanums_to_numbers(s1: str, s2: str):
    """Convert two alphanum strings to integers on a common scale.

    The strings are first padded to equal length with ``justify_alphanums``
    and then each is converted with ``alphanumToNumber`` using the
    module-level ``alphanums`` alphabet. Returns the pair ``(n1, n2)``.
    """
    padded = justify_alphanums(s1, s2)
    first, second = (alphanumToNumber(text, alphanums) for text in padded)
    return first, second
84+
85+
86+
class ArithAlphanumeric(ArithString):
87+
def __init__(self, s: str, max_len=None):
    """Validate and store an alphanumeric string.

    Args:
        s: the alphanumeric value; every character must be in ``alphanums``.
        max_len: optional maximum length; a falsy value disables the check.

    Raises:
        ValueError: if ``s`` is None, is longer than ``max_len``, or contains
            a character that is not in ``alphanums``.
    """
    if s is None:
        raise ValueError("Alphanum string cannot be None")
    if max_len and len(s) > max_len:
        # Interpolate the value itself: the parameter is named `s`, and `str`
        # would render the builtin type ("<class 'str'>") in the message.
        raise ValueError(f"Length of alphanum value '{s}' is longer than the expected {max_len}")

    for ch in s:
        if ch not in alphanums:
            raise ValueError(f"Unexpected character {ch} in alphanum string")

    self._str = s
    self._max_len = max_len
7299

73-
@property
74-
def int(self):
75-
return int(self._str, len(alphanums))
100+
# @property
101+
# def int(self):
102+
# return alphanumToNumber(self._str, alphanums)
76103

77104
def __str__(self):
78105
s = self._str
79106
if self._max_len:
80-
s = s.rjust(self._max_len, "0")
107+
s = s.rjust(self._max_len, alphanums[0])
81108
return s
82109

83110
def __len__(self):
84111
return len(self._str)
85112

86-
def __int__(self):
87-
return self.int
88-
89113
def __repr__(self):
90114
return f'alphanum"{self._str}"'
91115

92-
def __add__(self, other: "Union[ArithAlphanumeric, int]"):
116+
def __add__(self, other: "Union[ArithAlphanumeric, int]") -> "ArithAlphanumeric":
    """Return the value obtained by incrementing this alphanum by one.

    Only ``other == 1`` is supported. The last character is advanced to the
    next character of ``alphanums``. When the last character is already the
    final one of the alphabet, the increment carries into the preceding
    position (the exhausted position wraps to ``alphanums[0]``), so the
    result keeps the same length. An empty value is treated as a single
    ``alphanums[0]`` character.

    Raises:
        NotImplementedError: if ``other`` is an int other than 1.
        ValueError: if every character is already the final character of
            ``alphanums``, i.e. there is no larger value of the same length.
    """
    if not isinstance(other, int):
        return NotImplemented
    if other != 1:
        raise NotImplementedError("not implemented for arbitrary numbers")

    chars = list(self._str) if self._str else [alphanums[0]]

    # Carry: trailing characters that are already the highest symbol wrap to
    # the lowest one. Indexing alphanums one past the end would raise IndexError.
    pos = len(chars) - 1
    while pos >= 0 and chars[pos] == alphanums[-1]:
        chars[pos] = alphanums[0]
        pos -= 1
    if pos < 0:
        raise ValueError("Overflow error when adding to alphanumeric")

    chars[pos] = alphanums[alphanums.index(chars[pos]) + 1]
    return self.new("".join(chars))
99124

100-
def __sub__(self, other: "Union[ArithAlphanumeric, int]"):
101-
if isinstance(other, int):
102-
return type(self)(int=self.int - other)
103-
elif isinstance(other, ArithAlphanumeric):
104-
return self.int - other.int
125+
def range(self, other: "ArithAlphanumeric", count: int):
    """Return checkpoints between this value and ``other`` as new alphanums.

    Both strings are mapped to integers on a common scale with
    ``alphanums_to_numbers``, the integer interval is divided by
    ``split_space``, and each resulting integer is converted back with
    ``numberToAlphanum`` and wrapped via ``self.new``.

    NOTE(review): the number of points and endpoint handling come from
    ``split_space`` -- confirm against its definition.
    """
    assert isinstance(other, ArithAlphanumeric)
    low, high = alphanums_to_numbers(self._str, other._str)
    return [self.new(numberToAlphanum(point)) for point in split_space(low, high, count)]
130+
131+
def __sub__(self, other: "Union[ArithAlphanumeric, int]") -> int:
    """Return the integer distance between this alphanum and ``other``.

    Both strings are converted with ``alphanums_to_numbers`` (which pads them
    to a common length first) and the two integers are subtracted. Any
    operand that is not an ``ArithAlphanumeric`` yields ``NotImplemented``.
    """
    if not isinstance(other, ArithAlphanumeric):
        return NotImplemented

    n1, n2 = alphanums_to_numbers(self._str, other._str)
    return n1 - n2
106137

107138
def __ge__(self, other):
108139
if not isinstance(other, type(self)):
109140
return NotImplemented
110-
return self.int >= other.int
141+
return self._str >= other._str
111142

112143
def __lt__(self, other):
113144
if not isinstance(other, type(self)):
114145
return NotImplemented
115-
return self.int < other.int
146+
return self._str < other._str
116147

117148
def new(self, *args, **kw):
118149
return type(self)(*args, **kw, max_len=self._max_len)

0 commit comments

Comments
 (0)