Skip to content

load_text only creates populations if none given #1910

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Nov 12, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions python/CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,9 @@
(:user:`jeetsukumaran`, :user:`jeromekelleher`, :issue:`1785`, :pr:`1835`,
:pr:`1836`, :pr:`1838`)

- `load_text` created additional populations even if the population table was specified,
and didn't strip newlines from input text (:user:`hyanwong`, :issue:`1909`, :pr:`1910`)


--------------------
[0.3.7] - 2021-07-08
Expand Down
23 changes: 23 additions & 0 deletions python/tests/test_highlevel.py
Original file line number Diff line number Diff line change
Expand Up @@ -2736,6 +2736,7 @@ def verify_approximate_equality(self, ts1, ts2):
assert ts1.num_edges == ts2.num_edges
assert ts1.num_sites == ts2.num_sites
assert ts1.num_mutations == ts2.num_mutations
assert ts1.num_populations == ts2.num_populations

checked = 0
for n1, n2 in zip(ts1.nodes(), ts2.nodes()):
Expand Down Expand Up @@ -2848,6 +2849,28 @@ def test_empty_files_sequence_length(self):
assert ts.num_sites == 0
assert ts.num_edges == 0

def test_load_text_no_populations(self):
    """
    Load a single node that refers to population 2 without supplying a
    populations file, and check that three populations (IDs 0 to 2) exist
    in the resulting tree sequence.
    """
    node_source = io.StringIO("is_sample\ttime\tpopulation\n1\t0\t2\n")
    # Header only: the edge table is deliberately empty.
    edge_source = io.StringIO("left\tright\tparent\tchild\n")
    loaded = tskit.load_text(
        node_source,
        edge_source,
        sequence_length=100,
    )
    assert loaded.num_nodes == 1
    assert loaded.num_populations == 3

def test_load_text_populations(self):
    """
    Load an explicit populations file holding two plain-text metadata rows
    and check that exactly those two populations are present, with their
    metadata stored as the raw bytes of each row.
    """
    # Header-only node and edge tables: every population must come from
    # the populations file itself.
    node_source = io.StringIO("is_sample\ttime\tpopulation\n")
    edge_source = io.StringIO("left\tright\tparent\tchild\n")
    population_source = io.StringIO("metadata\nmetadata_1\nmetadata_2\n")
    load_options = {
        "populations": population_source,
        "sequence_length": 100,
        "base64_metadata": False,
    }
    ts = tskit.load_text(node_source, edge_source, **load_options)
    assert ts.num_populations == 2
    expected_metadata = (b"metadata_1", b"metadata_2")
    for population_id, expected in enumerate(expected_metadata):
        assert ts.tables.populations[population_id].metadata == expected


class TestTree(HighLevelTestCase):
"""
Expand Down
44 changes: 24 additions & 20 deletions python/tskit/trees.py
Original file line number Diff line number Diff line change
Expand Up @@ -2978,7 +2978,7 @@ def parse_individuals(
if table is None:
table = tables.IndividualTable()
# Read the header and find the indexes of the required fields.
header = source.readline().strip("\n").split(sep)
header = source.readline().rstrip("\n").split(sep)
flags_index = header.index("flags")
location_index = None
parents_index = None
Expand All @@ -2996,7 +2996,7 @@ def parse_individuals(
except ValueError:
pass
for line in source:
tokens = line.split(sep)
tokens = line.rstrip("\n").split(sep)
if len(tokens) >= 1:
flags = int(tokens[flags_index])
location = ()
Expand Down Expand Up @@ -3047,7 +3047,7 @@ def parse_nodes(source, strict=True, encoding="utf8", base64_metadata=True, tabl
if table is None:
table = tables.NodeTable()
# Read the header and find the indexes of the required fields.
header = source.readline().strip("\n").split(sep)
header = source.readline().rstrip("\n").split(sep)
is_sample_index = header.index("is_sample")
time_index = header.index("time")
population_index = None
Expand All @@ -3066,7 +3066,7 @@ def parse_nodes(source, strict=True, encoding="utf8", base64_metadata=True, tabl
except ValueError:
pass
for line in source:
tokens = line.split(sep)
tokens = line.rstrip("\n").split(sep)
if len(tokens) >= 2:
is_sample = int(tokens[is_sample_index])
time = float(tokens[time_index])
Expand Down Expand Up @@ -3116,13 +3116,13 @@ def parse_edges(source, strict=True, table=None):
sep = "\t"
if table is None:
table = tables.EdgeTable()
header = source.readline().strip("\n").split(sep)
header = source.readline().rstrip("\n").split(sep)
left_index = header.index("left")
right_index = header.index("right")
parent_index = header.index("parent")
children_index = header.index("child")
for line in source:
tokens = line.split(sep)
tokens = line.rstrip("\n").split(sep)
if len(tokens) >= 4:
left = float(tokens[left_index])
right = float(tokens[right_index])
Expand Down Expand Up @@ -3159,7 +3159,7 @@ def parse_sites(source, strict=True, encoding="utf8", base64_metadata=True, tabl
sep = "\t"
if table is None:
table = tables.SiteTable()
header = source.readline().strip("\n").split(sep)
header = source.readline().rstrip("\n").split(sep)
position_index = header.index("position")
ancestral_state_index = header.index("ancestral_state")
metadata_index = None
Expand All @@ -3168,7 +3168,7 @@ def parse_sites(source, strict=True, encoding="utf8", base64_metadata=True, tabl
except ValueError:
pass
for line in source:
tokens = line.split(sep)
tokens = line.rstrip("\n").split(sep)
if len(tokens) >= 2:
position = float(tokens[position_index])
ancestral_state = tokens[ancestral_state_index]
Expand Down Expand Up @@ -3212,7 +3212,7 @@ def parse_mutations(
sep = "\t"
if table is None:
table = tables.MutationTable()
header = source.readline().strip("\n").split(sep)
header = source.readline().rstrip("\n").split(sep)
site_index = header.index("site")
node_index = header.index("node")
try:
Expand All @@ -3232,7 +3232,7 @@ def parse_mutations(
except ValueError:
pass
for line in source:
tokens = line.split(sep)
tokens = line.rstrip("\n").split(sep)
if len(tokens) >= 3:
site = int(tokens[site_index])
node = int(tokens[node_index])
Expand Down Expand Up @@ -3289,10 +3289,10 @@ def parse_populations(
if table is None:
table = tables.PopulationTable()
# Read the header and find the indexes of the required fields.
header = source.readline().strip("\n").split(sep)
header = source.readline().rstrip("\n").split(sep)
metadata_index = header.index("metadata")
for line in source:
tokens = line.split(sep)
tokens = line.rstrip("\n").split(sep)
if len(tokens) >= 1:
metadata = tokens[metadata_index].encode(encoding)
if base64_metadata:
Expand Down Expand Up @@ -3329,7 +3329,10 @@ def load_text(
:func:`parse_nodes` and :func:`parse_edges`, respectively. ``sites``,
``individuals``, ``populations`` and ``mutations`` are optional, and must
be parsable by :func:`parse_sites`, :func:`parse_individuals`,
:func:`parse_populations`, and :func:`parse_mutations`, respectively.
:func:`parse_populations`, and :func:`parse_mutations`, respectively. For
convenience, if the node table refers to populations, but the ``populations``
parameter is not provided, a minimal set of rows is added to the
population table, so that a valid tree sequence can be returned.

The ``sequence_length`` parameter determines the
:attr:`TreeSequence.sequence_length` of the returned tree sequence. If it
Expand Down Expand Up @@ -3394,12 +3397,6 @@ def load_text(
base64_metadata=base64_metadata,
table=tc.nodes,
)
# We need to add populations any referenced in the node table.
if len(tc.nodes) > 0:
max_population = tc.nodes.population.max()
if max_population != NULL:
for _ in range(max_population + 1):
tc.populations.add_row()
if sites is not None:
parse_sites(
sites,
Expand All @@ -3424,7 +3421,14 @@ def load_text(
base64_metadata=base64_metadata,
table=tc.individuals,
)
if populations is not None:
if populations is None:
# As a convenience we add any populations referenced in the node table.
if len(tc.nodes) > 0:
max_population = tc.nodes.population.max()
if max_population != NULL:
for _ in range(max_population + 1):
tc.populations.add_row()
else:
parse_populations(
populations,
strict=strict,
Expand Down