Skip to content

Commit daa0cf7

Browse files
committed
load_text only creates populations if none given
Fixes #1909
1 parent 1fcb3f6 commit daa0cf7

File tree

3 files changed

+50
-20
lines changed

3 files changed

+50
-20
lines changed

python/CHANGELOG.rst

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -112,6 +112,9 @@
112112
(:user:`jeetsukumaran`, :user:`jeromekelleher`, :issue:`1785`, :pr:`1835`,
113113
:pr:`1836`, :pr:`1838`)
114114

115+
- ``load_text`` no longer creates additional populations when the population table is specified
116+
(:user:`hyanwong`, :issue:`1909`, :pr:`1910`)
117+
115118

116119
--------------------
117120
[0.3.7] - 2021-07-08

python/tests/test_highlevel.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2736,6 +2736,7 @@ def verify_approximate_equality(self, ts1, ts2):
27362736
assert ts1.num_edges == ts2.num_edges
27372737
assert ts1.num_sites == ts2.num_sites
27382738
assert ts1.num_mutations == ts2.num_mutations
2739+
assert ts1.num_populations == ts2.num_populations
27392740

27402741
checked = 0
27412742
for n1, n2 in zip(ts1.nodes(), ts2.nodes()):
@@ -2848,6 +2849,28 @@ def test_empty_files_sequence_length(self):
28482849
assert ts.num_sites == 0
28492850
assert ts.num_edges == 0
28502851

2852+
def test_load_text_no_populations(self):
2853+
nodes_file = io.StringIO("is_sample\ttime\tpopulation\n1\t0\t2\n")
2854+
edges_file = io.StringIO("left\tright\tparent\tchild\n")
2855+
ts = tskit.load_text(nodes_file, edges_file, sequence_length=100)
2856+
assert ts.num_nodes == 1
2857+
assert ts.num_populations == 3
2858+
2859+
def test_load_text_populations(self):
2860+
nodes_file = io.StringIO("is_sample\ttime\tpopulation\n")
2861+
edges_file = io.StringIO("left\tright\tparent\tchild\n")
2862+
populations_file = io.StringIO("metadata\nmetadata_1\nmetadata_2\n")
2863+
ts = tskit.load_text(
2864+
nodes_file,
2865+
edges_file,
2866+
populations=populations_file,
2867+
sequence_length=100,
2868+
base64_metadata=False,
2869+
)
2870+
assert ts.num_populations == 2
2871+
assert ts.tables.populations[0].metadata == b"metadata_1"
2872+
assert ts.tables.populations[1].metadata == b"metadata_2"
2873+
28512874

28522875
class TestTree(HighLevelTestCase):
28532876
"""

python/tskit/trees.py

Lines changed: 24 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -2978,7 +2978,7 @@ def parse_individuals(
29782978
if table is None:
29792979
table = tables.IndividualTable()
29802980
# Read the header and find the indexes of the required fields.
2981-
header = source.readline().strip("\n").split(sep)
2981+
header = source.readline().rstrip("\n").split(sep)
29822982
flags_index = header.index("flags")
29832983
location_index = None
29842984
parents_index = None
@@ -2996,7 +2996,7 @@ def parse_individuals(
29962996
except ValueError:
29972997
pass
29982998
for line in source:
2999-
tokens = line.split(sep)
2999+
tokens = line.rstrip("\n").split(sep)
30003000
if len(tokens) >= 1:
30013001
flags = int(tokens[flags_index])
30023002
location = ()
@@ -3047,7 +3047,7 @@ def parse_nodes(source, strict=True, encoding="utf8", base64_metadata=True, tabl
30473047
if table is None:
30483048
table = tables.NodeTable()
30493049
# Read the header and find the indexes of the required fields.
3050-
header = source.readline().strip("\n").split(sep)
3050+
header = source.readline().rstrip("\n").split(sep)
30513051
is_sample_index = header.index("is_sample")
30523052
time_index = header.index("time")
30533053
population_index = None
@@ -3066,7 +3066,7 @@ def parse_nodes(source, strict=True, encoding="utf8", base64_metadata=True, tabl
30663066
except ValueError:
30673067
pass
30683068
for line in source:
3069-
tokens = line.split(sep)
3069+
tokens = line.rstrip("\n").split(sep)
30703070
if len(tokens) >= 2:
30713071
is_sample = int(tokens[is_sample_index])
30723072
time = float(tokens[time_index])
@@ -3116,13 +3116,13 @@ def parse_edges(source, strict=True, table=None):
31163116
sep = "\t"
31173117
if table is None:
31183118
table = tables.EdgeTable()
3119-
header = source.readline().strip("\n").split(sep)
3119+
header = source.readline().rstrip("\n").split(sep)
31203120
left_index = header.index("left")
31213121
right_index = header.index("right")
31223122
parent_index = header.index("parent")
31233123
children_index = header.index("child")
31243124
for line in source:
3125-
tokens = line.split(sep)
3125+
tokens = line.rstrip("\n").split(sep)
31263126
if len(tokens) >= 4:
31273127
left = float(tokens[left_index])
31283128
right = float(tokens[right_index])
@@ -3159,7 +3159,7 @@ def parse_sites(source, strict=True, encoding="utf8", base64_metadata=True, tabl
31593159
sep = "\t"
31603160
if table is None:
31613161
table = tables.SiteTable()
3162-
header = source.readline().strip("\n").split(sep)
3162+
header = source.readline().rstrip("\n").split(sep)
31633163
position_index = header.index("position")
31643164
ancestral_state_index = header.index("ancestral_state")
31653165
metadata_index = None
@@ -3168,7 +3168,7 @@ def parse_sites(source, strict=True, encoding="utf8", base64_metadata=True, tabl
31683168
except ValueError:
31693169
pass
31703170
for line in source:
3171-
tokens = line.split(sep)
3171+
tokens = line.rstrip("\n").split(sep)
31723172
if len(tokens) >= 2:
31733173
position = float(tokens[position_index])
31743174
ancestral_state = tokens[ancestral_state_index]
@@ -3212,7 +3212,7 @@ def parse_mutations(
32123212
sep = "\t"
32133213
if table is None:
32143214
table = tables.MutationTable()
3215-
header = source.readline().strip("\n").split(sep)
3215+
header = source.readline().rstrip("\n").split(sep)
32163216
site_index = header.index("site")
32173217
node_index = header.index("node")
32183218
try:
@@ -3232,7 +3232,7 @@ def parse_mutations(
32323232
except ValueError:
32333233
pass
32343234
for line in source:
3235-
tokens = line.split(sep)
3235+
tokens = line.rstrip("\n").split(sep)
32363236
if len(tokens) >= 3:
32373237
site = int(tokens[site_index])
32383238
node = int(tokens[node_index])
@@ -3289,10 +3289,10 @@ def parse_populations(
32893289
if table is None:
32903290
table = tables.PopulationTable()
32913291
# Read the header and find the indexes of the required fields.
3292-
header = source.readline().strip("\n").split(sep)
3292+
header = source.readline().rstrip("\n").split(sep)
32933293
metadata_index = header.index("metadata")
32943294
for line in source:
3295-
tokens = line.split(sep)
3295+
tokens = line.rstrip("\n").split(sep)
32963296
if len(tokens) >= 1:
32973297
metadata = tokens[metadata_index].encode(encoding)
32983298
if base64_metadata:
@@ -3329,7 +3329,10 @@ def load_text(
33293329
:func:`parse_nodes` and :func:`parse_edges`, respectively. ``sites``,
33303330
``mutations``, ``individuals`` and ``populations`` are optional, and must
33313331
be parsable by :func:`parse_sites`, :func:`parse_individuals`,
3332-
:func:`parse_populations`, and :func:`parse_mutations`, respectively.
3332+
:func:`parse_populations`, and :func:`parse_mutations`, respectively. For
3333+
convenience, if the node table refers to populations, but the ``populations``
3334+
parameter is not provided, a minimal set of rows is added to the
3335+
population table, so that a valid tree sequence can be returned.
33333336
33343337
The ``sequence_length`` parameter determines the
33353338
:attr:`TreeSequence.sequence_length` of the returned tree sequence. If it
@@ -3394,12 +3397,6 @@ def load_text(
33943397
base64_metadata=base64_metadata,
33953398
table=tc.nodes,
33963399
)
3397-
# We need to add populations any referenced in the node table.
3398-
if len(tc.nodes) > 0:
3399-
max_population = tc.nodes.population.max()
3400-
if max_population != NULL:
3401-
for _ in range(max_population + 1):
3402-
tc.populations.add_row()
34033400
if sites is not None:
34043401
parse_sites(
34053402
sites,
@@ -3424,7 +3421,14 @@ def load_text(
34243421
base64_metadata=base64_metadata,
34253422
table=tc.individuals,
34263423
)
3427-
if populations is not None:
3424+
if populations is None:
3425+
# As a convenience we add any populations referenced in the node table.
3426+
if len(tc.nodes) > 0:
3427+
max_population = tc.nodes.population.max()
3428+
if max_population != NULL:
3429+
for _ in range(max_population + 1):
3430+
tc.populations.add_row()
3431+
else:
34283432
parse_populations(
34293433
populations,
34303434
strict=strict,

0 commit comments

Comments
 (0)