Skip to content

Commit 74a8f16

Browse files
committed
Test ragged arrays with ragged data
1 parent acf29ef commit 74a8f16

File tree

2 files changed

+46
-23
lines changed

2 files changed

+46
-23
lines changed

python/tests/test_tables.py

Lines changed: 42 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,8 @@ def get_input(self, n):
7272

7373
class CharColumn(Column):
7474
def get_input(self, n):
75-
return np.zeros(n, dtype=np.int8)
75+
rng = np.random.RandomState(42)
76+
return rng.randint(low=0, high=127, size=n, dtype=np.int8)
7677

7778

7879
class DoubleColumn(Column):
@@ -87,13 +88,32 @@ class CommonTestsMixin:
8788
"""
8889

8990
def make_input_data(self, num_rows):
91+
rng = np.random.RandomState(42)
9092
input_data = {col.name: col.get_input(num_rows) for col in self.columns}
9193
for list_col, offset_col in self.ragged_list_columns:
92-
value = list_col.get_input(num_rows)
93-
input_data[list_col.name] = value
94-
input_data[offset_col.name] = np.arange(num_rows + 1, dtype=np.uint32)
94+
lengths = rng.randint(low=0, high=10, size=num_rows)
95+
input_data[list_col.name] = list_col.get_input(sum(lengths))
96+
input_data[offset_col.name] = np.zeros(num_rows + 1, dtype=np.uint32)
97+
input_data[offset_col.name][1:] = np.cumsum(lengths, dtype=np.uint32)
9598
return input_data
9699

100+
def make_transposed_input_data(self, num_rows):
101+
cols = self.make_input_data(num_rows)
102+
return [
103+
{
104+
col: data[j]
105+
if len(data) == num_rows
106+
else (
107+
bytes(data[cols[f"{col}_offset"][j] : cols[f"{col}_offset"][j + 1]])
108+
if "metadata" in col
109+
else data[cols[f"{col}_offset"][j] : cols[f"{col}_offset"][j + 1]]
110+
)
111+
for col, data in cols.items()
112+
if "offset" not in col
113+
}
114+
for j in range(num_rows)
115+
]
116+
97117
def test_max_rows_increment(self):
98118
for bad_value in [-1, -(2 ** 10)]:
99119
with pytest.raises(ValueError):
@@ -279,15 +299,17 @@ def test_set_column_attributes_data(self):
279299
setattr(table, list_col.name, list_data)
280300
assert np.array_equal(getattr(table, list_col.name), list_data)
281301
list_value = getattr(table[0], list_col.name)
282-
assert len(list_value) == 1
302+
assert len(list_value) == input_data[offset_col.name][1]
283303

284304
# Reset the offsets so that all the full array is associated with the
285305
# first element.
286-
offset_data = np.zeros(num_rows + 1, dtype=np.uint32) + num_rows
306+
offset_data = np.zeros(num_rows + 1, dtype=np.uint32) + len(
307+
input_data[list_col.name]
308+
)
287309
offset_data[0] = 0
288310
setattr(table, offset_col.name, offset_data)
289311
list_value = getattr(table[0], list_col.name)
290-
assert len(list_value) == num_rows
312+
assert len(list_value) == len(input_data[list_col.name])
291313

292314
del input_data[list_col.name]
293315
del input_data[offset_col.name]
@@ -338,17 +360,11 @@ def test_defaults(self):
338360

339361
def test_add_row_data(self):
340362
for num_rows in [0, 10, 100]:
341-
input_data = {col.name: col.get_input(num_rows) for col in self.columns}
342363
table = self.table_class()
343-
for j in range(num_rows):
344-
kwargs = {col: data[j] for col, data in input_data.items()}
345-
for col in self.string_colnames:
346-
kwargs[col] = "x"
347-
for col in self.binary_colnames:
348-
kwargs[col] = b"x"
349-
k = table.add_row(**kwargs)
364+
for j, row in enumerate(self.make_transposed_input_data(num_rows)):
365+
k = table.add_row(**row)
350366
assert k == j
351-
for colname, input_array in input_data.items():
367+
for colname, input_array in self.make_input_data(num_rows).items():
352368
output_array = getattr(table, colname)
353369
assert input_array.shape == output_array.shape
354370
assert np.all(input_array == output_array)
@@ -573,6 +589,9 @@ def test_equality(self):
573589
value = list_col.get_input(num_rows)
574590
input_data_copy = dict(input_data)
575591
input_data_copy[list_col.name] = value + 1
592+
input_data_copy[offset_col.name] = np.arange(
593+
num_rows + 1, dtype=np.uint32
594+
)
576595
t2.set_columns(**input_data_copy)
577596
assert t1 != t2
578597
assert t1[0] != t2[0]
@@ -607,35 +626,36 @@ def test_bad_offsets(self):
607626
t.set_columns(**input_data)
608627

609628
for _list_col, offset_col in self.ragged_list_columns:
629+
original_offset = np.copy(input_data[offset_col.name])
610630
input_data[offset_col.name][0] = -1
611631
with pytest.raises(ValueError):
612632
t.set_columns(**input_data)
613-
input_data[offset_col.name] = np.arange(num_rows + 1, dtype=np.uint32)
633+
input_data[offset_col.name] = np.copy(original_offset)
614634
t.set_columns(**input_data)
615635
input_data[offset_col.name][-1] = 0
616636
with pytest.raises(ValueError):
617637
t.set_columns(**input_data)
618-
input_data[offset_col.name] = np.arange(num_rows + 1, dtype=np.uint32)
638+
input_data[offset_col.name] = np.copy(original_offset)
619639
t.set_columns(**input_data)
620640
input_data[offset_col.name][num_rows // 2] = 2 ** 31
621641
with pytest.raises(ValueError):
622642
t.set_columns(**input_data)
623-
input_data[offset_col.name] = np.arange(num_rows + 1, dtype=np.uint32)
643+
input_data[offset_col.name] = np.copy(original_offset)
624644

625645
input_data[offset_col.name][0] = -1
626646
with pytest.raises(ValueError):
627647
t.append_columns(**input_data)
628-
input_data[offset_col.name] = np.arange(num_rows + 1, dtype=np.uint32)
648+
input_data[offset_col.name] = np.copy(original_offset)
629649
t.append_columns(**input_data)
630650
input_data[offset_col.name][-1] = 0
631651
with pytest.raises(ValueError):
632652
t.append_columns(**input_data)
633-
input_data[offset_col.name] = np.arange(num_rows + 1, dtype=np.uint32)
653+
input_data[offset_col.name] = np.copy(original_offset)
634654
t.append_columns(**input_data)
635655
input_data[offset_col.name][num_rows // 2] = 2 ** 31
636656
with pytest.raises(ValueError):
637657
t.append_columns(**input_data)
638-
input_data[offset_col.name] = np.arange(num_rows + 1, dtype=np.uint32)
658+
input_data[offset_col.name] = np.copy(original_offset)
639659

640660

641661
class MetadataTestsMixin:

python/tskit/tables.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -268,7 +268,7 @@ def __setattr__(self, name, value):
268268

269269
def __getitem__(self, index):
270270
"""
271-
Return the specifed row of this table, decoding metadata if it is present.
271+
Return the specified row of this table, decoding metadata if it is present.
272272
Supports negative indexing, e.g. ``table[-5]``.
273273
274274
:param int index: the zero-index of the desired row
@@ -285,6 +285,9 @@ def __getitem__(self, index):
285285
pass
286286
return self.row_class(*row)
287287

288+
def append(self, row):
289+
return self.ll_table.add_row(**row)
290+
288291
def clear(self):
289292
"""
290293
Deletes all rows in this table.

0 commit comments

Comments
 (0)