Skip to content

Commit d62c3ec

Browse files
committed
Add table append method
1 parent dc76c8c commit d62c3ec

File tree

10 files changed

+267
-270
lines changed

10 files changed

+267
-270
lines changed

docs/conf.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -311,6 +311,7 @@ def handle_item(fieldarg, content):
311311
# TODO these have been triaged here to make the docs compile, but we should
312312
# sort them out properly. https://github.com/tskit-dev/tskit/issues/336
313313
("py:class", "array_like"),
314+
("py:class", "row-like"),
314315
("py:class", "array-like"),
315316
("py:class", "dtype=np.uint32"),
316317
("py:class", "dtype=np.uint32."),

python/CHANGELOG.rst

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,9 @@
99

1010
**Features**
1111

12+
- Add ``Table.append`` method for adding rows from classes such as ``SiteTableRow`` and
13+
``Site`` (:user:`benjeffery`, :issue:`1111`, :pr:`1254`).
14+
1215
- SVG visualization of a single tree allows all mutations on an edge to be plotted
1316
via the ``all_edge_mutations`` param (:user:`hyanwong`, :issue:`1253`, :pr:`1258`).
1417

python/tests/conftest.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -128,7 +128,7 @@ def ts_fixture():
128128
for name, table in tables.name_map.items():
129129
if name != "provenances":
130130
table.metadata_schema = tskit.MetadataSchema({"codec": "json"})
131-
metadatas = [f"n_{name}_{u}" for u in range(len(table))]
131+
metadatas = [f'{{"foo":"n_{name}_{u}"}}' for u in range(len(table))]
132132
metadata, metadata_offset = tskit.pack_strings(metadatas)
133133
table.set_columns(
134134
**{

python/tests/simplify.py

Lines changed: 23 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
Python implementation of the simplify algorithm.
2525
"""
2626
import sys
27+
from dataclasses import replace
2728

2829
import numpy as np
2930
import portion
@@ -158,12 +159,11 @@ def record_node(self, input_id, is_sample=False):
158159
flags &= ~tskit.NODE_IS_SAMPLE
159160
if is_sample:
160161
flags |= tskit.NODE_IS_SAMPLE
161-
output_id = self.tables.nodes.add_row(
162-
flags=flags,
163-
time=node.time,
164-
population=node.population,
165-
metadata=node.metadata,
166-
individual=node.individual,
162+
output_id = self.tables.nodes.append(
163+
replace(
164+
node,
165+
flags=flags,
166+
)
167167
)
168168
self.node_id_map[input_id] = output_id
169169
return output_id
@@ -186,9 +186,7 @@ def flush_edges(self):
186186
num_edges = 0
187187
for child in sorted(self.edge_buffer.keys()):
188188
for edge in self.edge_buffer[child]:
189-
self.tables.edges.add_row(
190-
edge.left, edge.right, edge.parent, edge.child
191-
)
189+
self.tables.edges.append(edge)
192190
num_edges += 1
193191
self.edge_buffer.clear()
194192
return num_edges
@@ -413,19 +411,15 @@ def finalise_sites(self):
413411
mapped_parent = -1
414412
if mut.parent != -1:
415413
mapped_parent = mutation_id_map[mut.parent]
416-
self.tables.mutations.add_row(
417-
site=len(self.tables.sites),
418-
node=self.mutation_node_map[mut.id],
419-
time=mut.time,
420-
parent=mapped_parent,
421-
derived_state=mut.derived_state,
422-
metadata=mut.metadata,
414+
self.tables.mutations.append(
415+
replace(
416+
mut,
417+
site=len(self.tables.sites),
418+
node=self.mutation_node_map[mut.id],
419+
parent=mapped_parent,
420+
)
423421
)
424-
self.tables.sites.add_row(
425-
position=site.position,
426-
ancestral_state=site.ancestral_state,
427-
metadata=site.metadata,
428-
)
422+
self.tables.sites.append(site)
429423

430424
def finalise_references(self):
431425
input_populations = self.ts.tables.populations
@@ -455,17 +449,12 @@ def finalise_references(self):
455449
for input_id, count in enumerate(population_ref_count):
456450
if count > 0:
457451
row = input_populations[input_id]
458-
output_id = self.tables.populations.add_row(metadata=row.metadata)
452+
output_id = self.tables.populations.append(row)
459453
population_id_map[input_id] = output_id
460454
for input_id, count in enumerate(individual_ref_count):
461455
if count > 0:
462456
row = input_individuals[input_id]
463-
output_id = self.tables.individuals.add_row(
464-
flags=row.flags,
465-
location=row.location,
466-
parents=row.parents,
467-
metadata=row.metadata,
468-
)
457+
output_id = self.tables.individuals.append(row)
469458
individual_id_map[input_id] = output_id
470459

471460
# Remap the population ID references for nodes.
@@ -489,11 +478,11 @@ def finalise_references(self):
489478
mapped_parents.append(-1)
490479
else:
491480
mapped_parents.append(individual_id_map[p])
492-
self.tables.individuals.add_row(
493-
flags=row.flags,
494-
location=row.location,
495-
parents=mapped_parents,
496-
metadata=row.metadata,
481+
self.tables.individuals.append(
482+
replace(
483+
row,
484+
parents=mapped_parents,
485+
)
497486
)
498487

499488
# We don't support migrations for now. We'll need to remap these as well.
@@ -710,7 +699,7 @@ def flush_edges(self):
710699
num_edges = 0
711700
for child in sorted(self.edge_buffer.keys()):
712701
for edge in self.edge_buffer[child]:
713-
self.table.add_row(edge.left, edge.right, edge.parent, edge.child)
702+
self.table.append(edge)
714703
num_edges += 1
715704
self.edge_buffer.clear()
716705
return num_edges

python/tests/test_highlevel.py

Lines changed: 32 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@
4040
import unittest
4141
import uuid as _uuid
4242
import warnings
43+
from dataclasses import replace
4344

4445
import kastore
4546
import msprime
@@ -2854,15 +2855,14 @@ def verify_random_permutation(self, ts):
28542855
inv_node_map = {v: k for k, v in node_map.items()}
28552856
for j in range(ts.num_nodes):
28562857
node = ts.node(inv_node_map[j])
2857-
other_tables.nodes.add_row(
2858-
flags=node.flags, time=node.time, population=node.population
2859-
)
2858+
other_tables.nodes.append(node)
28602859
for e in ts.edges():
2861-
other_tables.edges.add_row(
2862-
left=e.left,
2863-
right=e.right,
2864-
parent=node_map[e.parent],
2865-
child=node_map[e.child],
2860+
other_tables.edges.append(
2861+
replace(
2862+
e,
2863+
parent=node_map[e.parent],
2864+
child=node_map[e.child],
2865+
)
28662866
)
28672867
for _ in range(ts.num_populations):
28682868
other_tables.populations.add_row()
@@ -3168,6 +3168,30 @@ def get_instances(self, n):
31683168
]
31693169

31703170

3171+
class TestContainersAppend:
3172+
def test_containers_append(self, ts_fixture):
3173+
"""
3174+
Test that the containers work with `Table.append`
3175+
"""
3176+
tables = ts_fixture.dump_tables()
3177+
tables.clear(clear_provenance=True)
3178+
for table_name in [
3179+
"individuals",
3180+
"nodes",
3181+
"edges",
3182+
"migrations",
3183+
"sites",
3184+
"mutations",
3185+
"populations",
3186+
"provenances",
3187+
]:
3188+
table = getattr(tables, table_name)
3189+
for i in range(len(getattr(ts_fixture.tables, table_name))):
3190+
table.append(getattr(ts_fixture, table_name[:-1])(i))
3191+
print(ts_fixture.tables, tables)
3192+
assert ts_fixture.tables == tables
3193+
3194+
31713195
class TestTskitConversionOutput(unittest.TestCase):
31723196
"""
31733197
Tests conversion output to ensure it is correct.

python/tests/test_parsimony.py

Lines changed: 25 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
"""
2525
import io
2626
import itertools
27+
from dataclasses import replace
2728

2829
import attr
2930
import Bio.Phylo.TreeConstruction
@@ -549,18 +550,20 @@ def verify(self, ts):
549550
ancestral_state, mutations = self.do_map_mutations(
550551
tree, G[site.id], alleles[site.id]
551552
)
552-
site_id = tables.sites.add_row(site.position, ancestral_state)
553+
site_id = tables.sites.append(
554+
replace(site, ancestral_state=ancestral_state)
555+
)
553556
parent_offset = len(tables.mutations)
554557
for mutation in mutations:
555558
parent = mutation.parent
556559
if parent != tskit.NULL:
557560
parent += parent_offset
558-
tables.mutations.add_row(
559-
site_id,
560-
node=mutation.node,
561-
time=mutation.time,
562-
parent=parent,
563-
derived_state=mutation.derived_state,
561+
tables.mutations.append(
562+
replace(
563+
mutation,
564+
site=site_id,
565+
parent=parent,
566+
)
564567
)
565568
other_ts = tables.tree_sequence()
566569
for h1, h2 in zip(
@@ -715,18 +718,20 @@ def verify(self, ts):
715718
ancestral_state, mutations = self.do_map_mutations(
716719
tree, G[site.id], alleles[site.id]
717720
)
718-
site_id = tables.sites.add_row(site.position, ancestral_state)
721+
site_id = tables.sites.append(
722+
replace(site, ancestral_state=ancestral_state)
723+
)
719724
parent_offset = len(tables.mutations)
720725
for m in mutations:
721726
parent = m.parent
722727
if m.parent != tskit.NULL:
723728
parent = m.parent + parent_offset
724-
tables.mutations.add_row(
725-
site_id,
726-
node=m.node,
727-
time=m.time,
728-
parent=parent,
729-
derived_state=m.derived_state,
729+
tables.mutations.append(
730+
replace(
731+
m,
732+
site=site_id,
733+
parent=parent,
734+
)
730735
)
731736
other_ts = tables.tree_sequence()
732737
assert ts.num_samples == other_ts.num_samples
@@ -1206,12 +1211,12 @@ def verify(self, ts, k):
12061211
parent = mutation.parent
12071212
if parent != tskit.NULL:
12081213
parent += parent_offset
1209-
tables.mutations.add_row(
1210-
j,
1211-
node=mutation.node,
1212-
time=mutation.time,
1213-
parent=parent,
1214-
derived_state=mutation.derived_state,
1214+
tables.mutations.append(
1215+
replace(
1216+
mutation,
1217+
site=j,
1218+
parent=parent,
1219+
)
12151220
)
12161221

12171222
ts2 = tables.tree_sequence()

python/tests/test_tables.py

Lines changed: 38 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -386,6 +386,42 @@ def test_add_row_round_trip(self):
386386
t2.add_row(**dataclasses.asdict(row))
387387
assert t1 == t2
388388

389+
def test_append_row(self):
390+
for num_rows in [0, 10, 100]:
391+
table = self.table_class()
392+
for j, row in enumerate(self.make_transposed_input_data(num_rows)):
393+
k = table.append(table.row_class(**row))
394+
assert k == j
395+
for colname, input_array in self.make_input_data(num_rows).items():
396+
output_array = getattr(table, colname)
397+
assert input_array.shape == output_array.shape
398+
assert np.all(input_array == output_array)
399+
table.clear()
400+
assert table.num_rows == 0
401+
assert len(table) == 0
402+
403+
def test_append_duck_type(self):
404+
class Duck:
405+
pass
406+
407+
table = self.table_class()
408+
for j, row in enumerate(self.make_transposed_input_data(20)):
409+
duck = Duck()
410+
for k, v in row.items():
411+
setattr(duck, k, v)
412+
k = table.append(duck)
413+
assert k == j
414+
for colname, input_array in self.make_input_data(20).items():
415+
output_array = getattr(table, colname)
416+
assert np.array_equal(input_array, output_array)
417+
418+
def test_append_error(self):
419+
class NotADuck:
420+
pass
421+
422+
with pytest.raises(AttributeError, match="'NotADuck' object has no attribute"):
423+
self.table_class().append(NotADuck())
424+
389425
def test_set_columns_data(self):
390426
for num_rows in [0, 10, 100, 1000]:
391427
input_data = {col.name: col.get_input(num_rows) for col in self.columns}
@@ -1627,7 +1663,7 @@ def verify_edge_sort_offset(self, ts):
16271663
all_edges = keep + reversed_edges
16281664
tables.edges.clear()
16291665
for e in all_edges:
1630-
tables.edges.add_row(e.left, e.right, e.parent, e.child)
1666+
tables.edges.append(e)
16311667
# Verify that import fails for randomised edges
16321668
with pytest.raises(_tskit.LibraryError):
16331669
tables.tree_sequence()
@@ -1638,7 +1674,7 @@ def verify_edge_sort_offset(self, ts):
16381674
# Sorting from the correct index should give us back the original table.
16391675
tables.edges.clear()
16401676
for e in all_edges:
1641-
tables.edges.add_row(e.left, e.right, e.parent, e.child)
1677+
tables.edges.append(e)
16421678
tables.sort(edge_start=start)
16431679
# Verify the new and old edges are equal.
16441680
assert edges == tables.edges

0 commit comments

Comments
 (0)