diff --git a/python/CHANGELOG.rst b/python/CHANGELOG.rst index 32b00b8f04..0d9be1df7d 100644 --- a/python/CHANGELOG.rst +++ b/python/CHANGELOG.rst @@ -112,6 +112,9 @@ (:user:`jeetsukumaran`, :user:`jeromekelleher`, :issue:`1785`, :pr:`1835`, :pr:`1836`, :pr:`1838`) +- `load_text` created additional populations even if the population table was specified, + and didn't strip newlines from input text (:user:`hyanwong`, :issue:`1909`, :pr:`1910`) + -------------------- [0.3.7] - 2021-07-08 diff --git a/python/tests/test_highlevel.py b/python/tests/test_highlevel.py index 9da13b2a25..025aae3528 100644 --- a/python/tests/test_highlevel.py +++ b/python/tests/test_highlevel.py @@ -2736,6 +2736,7 @@ def verify_approximate_equality(self, ts1, ts2): assert ts1.num_edges == ts2.num_edges assert ts1.num_sites == ts2.num_sites assert ts1.num_mutations == ts2.num_mutations + assert ts1.num_populations == ts2.num_populations checked = 0 for n1, n2 in zip(ts1.nodes(), ts2.nodes()): @@ -2848,6 +2849,28 @@ def test_empty_files_sequence_length(self): assert ts.num_sites == 0 assert ts.num_edges == 0 + def test_load_text_no_populations(self): + nodes_file = io.StringIO("is_sample\ttime\tpopulation\n1\t0\t2\n") + edges_file = io.StringIO("left\tright\tparent\tchild\n") + ts = tskit.load_text(nodes_file, edges_file, sequence_length=100) + assert ts.num_nodes == 1 + assert ts.num_populations == 3 + + def test_load_text_populations(self): + nodes_file = io.StringIO("is_sample\ttime\tpopulation\n") + edges_file = io.StringIO("left\tright\tparent\tchild\n") + populations_file = io.StringIO("metadata\nmetadata_1\nmetadata_2\n") + ts = tskit.load_text( + nodes_file, + edges_file, + populations=populations_file, + sequence_length=100, + base64_metadata=False, + ) + assert ts.num_populations == 2 + assert ts.tables.populations[0].metadata == b"metadata_1" + assert ts.tables.populations[1].metadata == b"metadata_2" + class TestTree(HighLevelTestCase): """ diff --git a/python/tskit/trees.py 
b/python/tskit/trees.py index a0d38b1e7d..c279c19858 100644 --- a/python/tskit/trees.py +++ b/python/tskit/trees.py @@ -2978,7 +2978,7 @@ def parse_individuals( if table is None: table = tables.IndividualTable() # Read the header and find the indexes of the required fields. - header = source.readline().strip("\n").split(sep) + header = source.readline().rstrip("\n").split(sep) flags_index = header.index("flags") location_index = None parents_index = None @@ -2996,7 +2996,7 @@ def parse_individuals( except ValueError: pass for line in source: - tokens = line.split(sep) + tokens = line.rstrip("\n").split(sep) if len(tokens) >= 1: flags = int(tokens[flags_index]) location = () @@ -3047,7 +3047,7 @@ def parse_nodes(source, strict=True, encoding="utf8", base64_metadata=True, tabl if table is None: table = tables.NodeTable() # Read the header and find the indexes of the required fields. - header = source.readline().strip("\n").split(sep) + header = source.readline().rstrip("\n").split(sep) is_sample_index = header.index("is_sample") time_index = header.index("time") population_index = None @@ -3066,7 +3066,7 @@ def parse_nodes(source, strict=True, encoding="utf8", base64_metadata=True, tabl except ValueError: pass for line in source: - tokens = line.split(sep) + tokens = line.rstrip("\n").split(sep) if len(tokens) >= 2: is_sample = int(tokens[is_sample_index]) time = float(tokens[time_index]) @@ -3116,13 +3116,13 @@ def parse_edges(source, strict=True, table=None): sep = "\t" if table is None: table = tables.EdgeTable() - header = source.readline().strip("\n").split(sep) + header = source.readline().rstrip("\n").split(sep) left_index = header.index("left") right_index = header.index("right") parent_index = header.index("parent") children_index = header.index("child") for line in source: - tokens = line.split(sep) + tokens = line.rstrip("\n").split(sep) if len(tokens) >= 4: left = float(tokens[left_index]) right = float(tokens[right_index]) @@ -3159,7 +3159,7 @@ def 
parse_sites(source, strict=True, encoding="utf8", base64_metadata=True, tabl sep = "\t" if table is None: table = tables.SiteTable() - header = source.readline().strip("\n").split(sep) + header = source.readline().rstrip("\n").split(sep) position_index = header.index("position") ancestral_state_index = header.index("ancestral_state") metadata_index = None @@ -3168,7 +3168,7 @@ def parse_sites(source, strict=True, encoding="utf8", base64_metadata=True, tabl except ValueError: pass for line in source: - tokens = line.split(sep) + tokens = line.rstrip("\n").split(sep) if len(tokens) >= 2: position = float(tokens[position_index]) ancestral_state = tokens[ancestral_state_index] @@ -3212,7 +3212,7 @@ def parse_mutations( sep = "\t" if table is None: table = tables.MutationTable() - header = source.readline().strip("\n").split(sep) + header = source.readline().rstrip("\n").split(sep) site_index = header.index("site") node_index = header.index("node") try: @@ -3232,7 +3232,7 @@ def parse_mutations( except ValueError: pass for line in source: - tokens = line.split(sep) + tokens = line.rstrip("\n").split(sep) if len(tokens) >= 3: site = int(tokens[site_index]) node = int(tokens[node_index]) @@ -3289,10 +3289,10 @@ def parse_populations( if table is None: table = tables.PopulationTable() # Read the header and find the indexes of the required fields. - header = source.readline().strip("\n").split(sep) + header = source.readline().rstrip("\n").split(sep) metadata_index = header.index("metadata") for line in source: - tokens = line.split(sep) + tokens = line.rstrip("\n").split(sep) if len(tokens) >= 1: metadata = tokens[metadata_index].encode(encoding) if base64_metadata: @@ -3329,7 +3329,10 @@ def load_text( :func:`parse_nodes` and :func:`parse_edges`, respectively. 
``sites``, ``mutations``, ``individuals`` and ``populations`` are optional, and must be parsable by :func:`parse_sites`, :func:`parse_individuals`, - :func:`parse_populations`, and :func:`parse_mutations`, respectively. + :func:`parse_populations`, and :func:`parse_mutations`, respectively. For + convenience, if the node table refers to populations, but the ``populations`` + parameter is not provided, a minimal set of rows is added to the + population table, so that a valid tree sequence can be returned. The ``sequence_length`` parameter determines the :attr:`TreeSequence.sequence_length` of the returned tree sequence. If it @@ -3394,12 +3397,6 @@ def load_text( base64_metadata=base64_metadata, table=tc.nodes, ) - # We need to add populations any referenced in the node table. - if len(tc.nodes) > 0: - max_population = tc.nodes.population.max() - if max_population != NULL: - for _ in range(max_population + 1): - tc.populations.add_row() if sites is not None: parse_sites( sites, @@ -3424,7 +3421,14 @@ def load_text( base64_metadata=base64_metadata, table=tc.individuals, ) - if populations is not None: + if populations is None: + # As a convenience we add any populations referenced in the node table. + if len(tc.nodes) > 0: + max_population = tc.nodes.population.max() + if max_population != NULL: + for _ in range(max_population + 1): + tc.populations.add_row() + else: parse_populations( populations, strict=strict,