Skip to content

"Append row" for table classes #1254

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Mar 24, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -311,6 +311,7 @@ def handle_item(fieldarg, content):
# TODO these have been triaged here to make the docs compile, but we should
# sort them out properly. https://github.com/tskit-dev/tskit/issues/336
("py:class", "array_like"),
("py:class", "row-like"),
("py:class", "array-like"),
("py:class", "dtype=np.uint32"),
("py:class", "dtype=np.uint32."),
Expand Down
3 changes: 3 additions & 0 deletions python/CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,9 @@

**Features**

- Add ``Table.append`` method for adding rows from classes such as ``SiteTableRow`` and
``Site`` (:user:`benjeffery`, :issue:`1111`, :pr:`1254`).

- SVG visualization of a single tree allows all mutations on an edge to be plotted
via the ``all_edge_mutations`` param (:user:`hyanwong`, :issue:`1253`, :pr:`1258`).

Expand Down
2 changes: 1 addition & 1 deletion python/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ def ts_fixture():
for name, table in tables.name_map.items():
if name != "provenances":
table.metadata_schema = tskit.MetadataSchema({"codec": "json"})
metadatas = [f"n_{name}_{u}" for u in range(len(table))]
metadatas = [f'{{"foo":"n_{name}_{u}"}}' for u in range(len(table))]
metadata, metadata_offset = tskit.pack_strings(metadatas)
table.set_columns(
**{
Expand Down
49 changes: 13 additions & 36 deletions python/tests/simplify.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,13 +158,7 @@ def record_node(self, input_id, is_sample=False):
flags &= ~tskit.NODE_IS_SAMPLE
if is_sample:
flags |= tskit.NODE_IS_SAMPLE
output_id = self.tables.nodes.add_row(
flags=flags,
time=node.time,
population=node.population,
metadata=node.metadata,
individual=node.individual,
)
output_id = self.tables.nodes.append(node.replace(flags=flags))
self.node_id_map[input_id] = output_id
return output_id

Expand All @@ -186,9 +180,7 @@ def flush_edges(self):
num_edges = 0
for child in sorted(self.edge_buffer.keys()):
for edge in self.edge_buffer[child]:
self.tables.edges.add_row(
edge.left, edge.right, edge.parent, edge.child
)
self.tables.edges.append(edge)
num_edges += 1
self.edge_buffer.clear()
return num_edges
Expand Down Expand Up @@ -413,19 +405,14 @@ def finalise_sites(self):
mapped_parent = -1
if mut.parent != -1:
mapped_parent = mutation_id_map[mut.parent]
self.tables.mutations.add_row(
site=len(self.tables.sites),
node=self.mutation_node_map[mut.id],
time=mut.time,
parent=mapped_parent,
derived_state=mut.derived_state,
metadata=mut.metadata,
self.tables.mutations.append(
mut.replace(
site=len(self.tables.sites),
node=self.mutation_node_map[mut.id],
parent=mapped_parent,
)
)
self.tables.sites.add_row(
position=site.position,
ancestral_state=site.ancestral_state,
metadata=site.metadata,
)
self.tables.sites.append(site)

def finalise_references(self):
input_populations = self.ts.tables.populations
Expand Down Expand Up @@ -455,17 +442,12 @@ def finalise_references(self):
for input_id, count in enumerate(population_ref_count):
if count > 0:
row = input_populations[input_id]
output_id = self.tables.populations.add_row(metadata=row.metadata)
output_id = self.tables.populations.append(row)
population_id_map[input_id] = output_id
for input_id, count in enumerate(individual_ref_count):
if count > 0:
row = input_individuals[input_id]
output_id = self.tables.individuals.add_row(
flags=row.flags,
location=row.location,
parents=row.parents,
metadata=row.metadata,
)
output_id = self.tables.individuals.append(row)
individual_id_map[input_id] = output_id

# Remap the population ID references for nodes.
Expand All @@ -489,12 +471,7 @@ def finalise_references(self):
mapped_parents.append(-1)
else:
mapped_parents.append(individual_id_map[p])
self.tables.individuals.add_row(
flags=row.flags,
location=row.location,
parents=mapped_parents,
metadata=row.metadata,
)
self.tables.individuals.append(row.replace(parents=mapped_parents))

# We don't support migrations for now. We'll need to remap these as well.
assert self.ts.num_migrations == 0
Expand Down Expand Up @@ -710,7 +687,7 @@ def flush_edges(self):
num_edges = 0
for child in sorted(self.edge_buffer.keys()):
for edge in self.edge_buffer[child]:
self.table.add_row(edge.left, edge.right, edge.parent, edge.child)
self.table.append(edge)
num_edges += 1
self.edge_buffer.clear()
return num_edges
Expand Down
36 changes: 26 additions & 10 deletions python/tests/test_highlevel.py
Original file line number Diff line number Diff line change
Expand Up @@ -2854,15 +2854,10 @@ def verify_random_permutation(self, ts):
inv_node_map = {v: k for k, v in node_map.items()}
for j in range(ts.num_nodes):
node = ts.node(inv_node_map[j])
other_tables.nodes.add_row(
flags=node.flags, time=node.time, population=node.population
)
other_tables.nodes.append(node)
for e in ts.edges():
other_tables.edges.add_row(
left=e.left,
right=e.right,
parent=node_map[e.parent],
child=node_map[e.child],
other_tables.edges.append(
e.replace(parent=node_map[e.parent], child=node_map[e.child])
)
for _ in range(ts.num_populations):
other_tables.populations.add_row()
Expand Down Expand Up @@ -2971,8 +2966,6 @@ def test_metadata(self):
(inst,) = self.get_instances(1)
(inst2,) = self.get_instances(1)
assert inst == inst2
inst.metadata
assert inst == inst2
inst._metadata = "different"
assert inst != inst2

Expand Down Expand Up @@ -3170,6 +3163,29 @@ def get_instances(self, n):
]


class TestContainersAppend:
    def test_containers_append(self, ts_fixture):
        """
        Check that the row containers returned by the tree sequence
        accessors can be passed straight to `Table.append`, rebuilding
        tables equal to the fixture's own.
        """
        plural_names = (
            "individuals",
            "nodes",
            "edges",
            "migrations",
            "sites",
            "mutations",
            "populations",
            "provenances",
        )
        # Start from a copy of the fixture's tables with the rows cleared
        # (provenance included), then refill it one row at a time.
        rebuilt = ts_fixture.dump_tables()
        rebuilt.clear(clear_provenance=True)
        for plural in plural_names:
            destination = getattr(rebuilt, plural)
            # The per-row accessor on the tree sequence is the table name
            # with its trailing "s" removed, e.g. "nodes" -> ts.node(i).
            singular = plural[:-1]
            num_rows = len(getattr(ts_fixture.tables, plural))
            for row_id in range(num_rows):
                row = getattr(ts_fixture, singular)(row_id)
                destination.append(row)
        assert ts_fixture.tables == rebuilt


class TestTskitConversionOutput(unittest.TestCase):
"""
Tests conversion output to ensure it is correct.
Expand Down
32 changes: 10 additions & 22 deletions python/tests/test_parsimony.py
Original file line number Diff line number Diff line change
Expand Up @@ -549,18 +549,16 @@ def verify(self, ts):
ancestral_state, mutations = self.do_map_mutations(
tree, G[site.id], alleles[site.id]
)
site_id = tables.sites.add_row(site.position, ancestral_state)
site_id = tables.sites.append(
site.replace(ancestral_state=ancestral_state)
)
parent_offset = len(tables.mutations)
for mutation in mutations:
parent = mutation.parent
if parent != tskit.NULL:
parent += parent_offset
tables.mutations.add_row(
site_id,
node=mutation.node,
time=mutation.time,
parent=parent,
derived_state=mutation.derived_state,
tables.mutations.append(
mutation.replace(site=site_id, parent=parent)
)
other_ts = tables.tree_sequence()
for h1, h2 in zip(
Expand Down Expand Up @@ -715,19 +713,15 @@ def verify(self, ts):
ancestral_state, mutations = self.do_map_mutations(
tree, G[site.id], alleles[site.id]
)
site_id = tables.sites.add_row(site.position, ancestral_state)
site_id = tables.sites.append(
site.replace(ancestral_state=ancestral_state)
)
parent_offset = len(tables.mutations)
for m in mutations:
parent = m.parent
if m.parent != tskit.NULL:
parent = m.parent + parent_offset
tables.mutations.add_row(
site_id,
node=m.node,
time=m.time,
parent=parent,
derived_state=m.derived_state,
)
tables.mutations.append(m.replace(site=site_id, parent=parent))
other_ts = tables.tree_sequence()
assert ts.num_samples == other_ts.num_samples
H1 = list(ts.haplotypes(isolated_as_missing=False))
Expand Down Expand Up @@ -1206,13 +1200,7 @@ def verify(self, ts, k):
parent = mutation.parent
if parent != tskit.NULL:
parent += parent_offset
tables.mutations.add_row(
j,
node=mutation.node,
time=mutation.time,
parent=parent,
derived_state=mutation.derived_state,
)
tables.mutations.append(mutation.replace(site=j, parent=parent))

ts2 = tables.tree_sequence()
G2 = np.zeros((m, n), dtype=np.int8)
Expand Down
Loading