diff --git a/c/CHANGELOG.rst b/c/CHANGELOG.rst index 39310f5c1e..9e316de8b7 100644 --- a/c/CHANGELOG.rst +++ b/c/CHANGELOG.rst @@ -44,6 +44,12 @@ In development. **New features** +- New methods to perform set operations on table collections. + ``tsk_table_collection_subset`` subsets and reorders table collections by nodes + (:user:`mufernando`, :user:`petrelharp`, :pr:`663`, :pr:`690`). + ``tsk_table_collection_union`` forms the node-wise union of two table collections + (:user:`mufernando`, :user:`petrelharp`, :issue:`381`, :pr:`623`). + - Mutations now have an optional double-precision floating-point ``time`` column. If not specified, this defaults to a particular NaN value (``TSK_UNKNOWN_TIME``) indicating that the time is unknown. For a tree sequence to be considered valid diff --git a/c/tests/test_tables.c b/c/tests/test_tables.c index 55129e0bc4..226af068d3 100644 --- a/c/tests/test_tables.c +++ b/c/tests/test_tables.c @@ -3986,10 +3986,10 @@ test_table_collection_subset_with_options(tsk_flags_t options) // four nodes from two diploids; the first is from pop 0 ret = tsk_node_table_add_row(&tables.nodes, TSK_NODE_IS_SAMPLE, 0.0, 0, 0, NULL, 0); CU_ASSERT_FATAL(ret >= 0); - ret = tsk_node_table_add_row(&tables.nodes, TSK_NODE_IS_SAMPLE, 0.0, 0, 0, NULL, 0); + ret = tsk_node_table_add_row(&tables.nodes, TSK_NODE_IS_SAMPLE, 1.0, 0, 0, NULL, 0); CU_ASSERT_FATAL(ret >= 0); ret = tsk_node_table_add_row( - &tables.nodes, TSK_NODE_IS_SAMPLE, 0.0, TSK_NULL, 1, NULL, 0); + &tables.nodes, TSK_NODE_IS_SAMPLE, 2.0, TSK_NULL, 1, NULL, 0); CU_ASSERT_FATAL(ret >= 0); ret = tsk_node_table_add_row( &tables.nodes, TSK_NODE_IS_SAMPLE, 0.0, TSK_NULL, 1, NULL, 0); @@ -4009,13 +4009,16 @@ test_table_collection_subset_with_options(tsk_flags_t options) ret = tsk_site_table_add_row(&tables.sites, 0.4, "A", 1, NULL, 0); CU_ASSERT_FATAL(ret >= 0); ret = tsk_mutation_table_add_row( - &tables.mutations, 0, 0, TSK_NULL, NAN, NULL, 0, NULL, 0); + &tables.mutations, 0, 0, TSK_NULL, TSK_UNKNOWN_TIME, NULL, 0, NULL, 0); CU_ASSERT_FATAL(ret >= 0); - ret = tsk_mutation_table_add_row(&tables.mutations, 0, 0, 0, NAN, NULL, 0, NULL, 0); + ret = tsk_mutation_table_add_row( + &tables.mutations, 0, 0, 0, TSK_UNKNOWN_TIME, NULL, 0, NULL, 0); CU_ASSERT_FATAL(ret >= 0); ret = tsk_mutation_table_add_row( - &tables.mutations, 1, 1, TSK_NULL, NAN, NULL, 0, NULL, 0); + &tables.mutations, 1, 1, TSK_NULL, TSK_UNKNOWN_TIME, NULL, 0, NULL, 0); CU_ASSERT_FATAL(ret >= 0); + ret = tsk_table_collection_build_index(&tables, 0); + CU_ASSERT_EQUAL_FATAL(ret, 0); // empty nodes should get empty tables ret = tsk_table_collection_copy(&tables, &tables_copy, TSK_NO_INIT | options); @@ -4069,16 +4072,17 @@ test_table_collection_subset_errors(void) ret = tsk_table_collection_init(&tables, 0); CU_ASSERT_EQUAL_FATAL(ret, 0); + tables.sequence_length = 1; ret = tsk_table_collection_init(&tables_copy, 0); CU_ASSERT_EQUAL_FATAL(ret, 0); // four nodes from two diploids; the first is from pop 0 ret = tsk_node_table_add_row(&tables.nodes, TSK_NODE_IS_SAMPLE, 0.0, 0, 0, NULL, 0); CU_ASSERT_FATAL(ret >= 0); - ret = tsk_node_table_add_row(&tables.nodes, TSK_NODE_IS_SAMPLE, 0.0, 0, 0, NULL, 0); + ret = tsk_node_table_add_row(&tables.nodes, TSK_NODE_IS_SAMPLE, 1.0, 0, 0, NULL, 0); CU_ASSERT_FATAL(ret >= 0); ret = tsk_node_table_add_row( - &tables.nodes, TSK_NODE_IS_SAMPLE, 0.0, TSK_NULL, 1, NULL, 0); + &tables.nodes, TSK_NODE_IS_SAMPLE, 2.0, TSK_NULL, 1, NULL, 0); CU_ASSERT_FATAL(ret >= 0); ret = tsk_node_table_add_row( &tables.nodes, TSK_NODE_IS_SAMPLE, 0.0, TSK_NULL, 1, NULL, 0); @@ -4091,6 +4095,8 @@ test_table_collection_subset_errors(void) CU_ASSERT_FATAL(ret >= 0); ret = tsk_edge_table_add_row(&tables.edges, 0.0, 1.0, 1, 0, NULL, 0); CU_ASSERT_FATAL(ret >= 0); + ret = tsk_table_collection_build_index(&tables, 0); + CU_ASSERT_EQUAL_FATAL(ret, 0); /* Migrations are not supported */ ret = tsk_table_collection_copy(&tables, &tables_copy, TSK_NO_INIT); @@ -4101,15 +4107,248 @@ test_table_collection_subset_errors(void) CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_MIGRATIONS_NOT_SUPPORTED); // test out of bounds nodes + ret = tsk_table_collection_copy(&tables, &tables_copy, TSK_NO_INIT); + CU_ASSERT_EQUAL_FATAL(ret, 0); nodes[0] = -1; - ret = tsk_table_collection_subset(&tables, nodes, 4); + ret = tsk_table_collection_subset(&tables_copy, nodes, 4); CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_NODE_OUT_OF_BOUNDS); nodes[0] = 6; - ret = tsk_table_collection_subset(&tables, nodes, 4); + ret = tsk_table_collection_copy(&tables, &tables_copy, TSK_NO_INIT); + CU_ASSERT_EQUAL_FATAL(ret, 0); + ret = tsk_table_collection_subset(&tables_copy, nodes, 4); CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_NODE_OUT_OF_BOUNDS); + // check integrity + nodes[0] = 0; + nodes[1] = 1; + ret = tsk_table_collection_copy(&tables, &tables_copy, TSK_NO_INIT); + CU_ASSERT_EQUAL_FATAL(ret, 0); + ret = tsk_node_table_truncate(&tables_copy.nodes, 3); + CU_ASSERT_EQUAL_FATAL(ret, 0); + ret = tsk_node_table_add_row( + &tables_copy.nodes, TSK_NODE_IS_SAMPLE, 0.0, -2, 0, NULL, 0); + CU_ASSERT_FATAL(ret >= 0); + ret = tsk_table_collection_subset(&tables_copy, nodes, 4); + CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_POPULATION_OUT_OF_BOUNDS); + + tsk_table_collection_free(&tables); + tsk_table_collection_free(&tables_copy); +} + +static void +test_table_collection_union(void) +{ + int ret; + tsk_table_collection_t tables; + tsk_table_collection_t tables_empty; + tsk_table_collection_t tables_copy; + tsk_id_t node_mapping[3]; + + memset(node_mapping, 0xff, sizeof(node_mapping)); + + ret = tsk_table_collection_init(&tables, 0); + CU_ASSERT_EQUAL_FATAL(ret, 0); + tables.sequence_length = 1; + ret = tsk_table_collection_init(&tables_empty, 0); + CU_ASSERT_EQUAL_FATAL(ret, 0); + tables_empty.sequence_length = 1; + ret = tsk_table_collection_init(&tables_copy, 0); + CU_ASSERT_EQUAL_FATAL(ret, 0); + + // does not error on empty tables + ret = tsk_table_collection_union( + &tables, &tables_empty, node_mapping, TSK_UNION_NO_CHECK_SHARED); + CU_ASSERT_EQUAL_FATAL(ret, 0); + + // three nodes, two pop, three ind, two edge, two site, two mut + ret = tsk_node_table_add_row(&tables.nodes, TSK_NODE_IS_SAMPLE, 0.0, 0, 0, NULL, 0); + CU_ASSERT_FATAL(ret >= 0); + ret = tsk_node_table_add_row(&tables.nodes, TSK_NODE_IS_SAMPLE, 0.0, 1, 1, NULL, 0); + CU_ASSERT_FATAL(ret >= 0); + ret = tsk_node_table_add_row(&tables.nodes, TSK_NODE_IS_SAMPLE, 0.5, 1, 2, NULL, 0); + CU_ASSERT_FATAL(ret >= 0); + ret = tsk_individual_table_add_row(&tables.individuals, 0, NULL, 0, NULL, 0); + CU_ASSERT_FATAL(ret >= 0); + ret = tsk_individual_table_add_row(&tables.individuals, 0, NULL, 0, NULL, 0); + CU_ASSERT_FATAL(ret >= 0); + ret = tsk_individual_table_add_row(&tables.individuals, 0, NULL, 0, NULL, 0); + CU_ASSERT_FATAL(ret >= 0); + ret = tsk_population_table_add_row(&tables.populations, NULL, 0); + CU_ASSERT_FATAL(ret >= 0); + ret = tsk_population_table_add_row(&tables.populations, NULL, 0); + CU_ASSERT_FATAL(ret >= 0); + ret = tsk_edge_table_add_row(&tables.edges, 0.0, 1.0, 2, 0, NULL, 0); + CU_ASSERT_FATAL(ret >= 0); + ret = tsk_edge_table_add_row(&tables.edges, 0.0, 1.0, 2, 1, NULL, 0); + CU_ASSERT_FATAL(ret >= 0); + ret = tsk_site_table_add_row(&tables.sites, 0.4, "T", 1, NULL, 0); + CU_ASSERT_FATAL(ret >= 0); + ret = tsk_site_table_add_row(&tables.sites, 0.2, "A", 1, NULL, 0); + CU_ASSERT_FATAL(ret >= 0); + ret = tsk_mutation_table_add_row( + &tables.mutations, 0, 0, TSK_NULL, TSK_UNKNOWN_TIME, NULL, 0, NULL, 0); + ret = tsk_mutation_table_add_row( + &tables.mutations, 1, 1, TSK_NULL, TSK_UNKNOWN_TIME, NULL, 0, NULL, 0); + CU_ASSERT_FATAL(ret >= 0); + ret = tsk_table_collection_build_index(&tables, 0); + CU_ASSERT_EQUAL_FATAL(ret, 0); + ret = tsk_table_collection_sort(&tables, NULL, 0); + CU_ASSERT_EQUAL_FATAL(ret, 0); + + // union with empty should not change + // other is empty + ret = tsk_table_collection_copy(&tables, &tables_copy, TSK_NO_INIT); + CU_ASSERT_EQUAL_FATAL(ret, 0); + ret = tsk_table_collection_union( + &tables_copy, &tables_empty, node_mapping, TSK_UNION_NO_CHECK_SHARED); + CU_ASSERT_FATAL(tsk_table_collection_equals(&tables, &tables_copy)); + // self is empty + ret = tsk_table_collection_clear(&tables_copy); + CU_ASSERT_EQUAL_FATAL(ret, 0); + ret = tsk_table_collection_union( + &tables_copy, &tables, node_mapping, TSK_UNION_NO_CHECK_SHARED); + CU_ASSERT_EQUAL_FATAL(ret, 0); + CU_ASSERT_FATAL(tsk_table_collection_equals(&tables, &tables_copy)); + + // union all shared nodes + subset original nodes = original table + ret = tsk_table_collection_copy(&tables, &tables_copy, TSK_NO_INIT); + CU_ASSERT_EQUAL_FATAL(ret, 0); + ret = tsk_table_collection_union( + &tables_copy, &tables, node_mapping, TSK_UNION_NO_CHECK_SHARED); + CU_ASSERT_EQUAL_FATAL(ret, 0); + node_mapping[0] = 0; + node_mapping[1] = 1; + node_mapping[2] = 2; + ret = tsk_table_collection_subset(&tables_copy, node_mapping, 3); + CU_ASSERT_EQUAL_FATAL(ret, 0); + CU_ASSERT_FATAL(tsk_table_collection_equals(&tables, &tables_copy)); + + // union with one shared node + ret = tsk_table_collection_copy(&tables, &tables_copy, TSK_NO_INIT); + CU_ASSERT_EQUAL_FATAL(ret, 0); + node_mapping[0] = TSK_NULL; + node_mapping[1] = TSK_NULL; + node_mapping[2] = 2; + ret = tsk_table_collection_union(&tables_copy, &tables, node_mapping, 0); + CU_ASSERT_EQUAL_FATAL(ret, 0); + CU_ASSERT_EQUAL_FATAL( + tables_copy.populations.num_rows, tables.populations.num_rows + 2); + CU_ASSERT_EQUAL_FATAL( + tables_copy.individuals.num_rows, tables.individuals.num_rows + 2); + CU_ASSERT_EQUAL_FATAL(tables_copy.nodes.num_rows, tables.nodes.num_rows + 2); + CU_ASSERT_EQUAL_FATAL(tables_copy.edges.num_rows, tables.edges.num_rows + 2); + CU_ASSERT_EQUAL_FATAL(tables_copy.sites.num_rows, tables.sites.num_rows); + CU_ASSERT_EQUAL_FATAL(tables_copy.mutations.num_rows, tables.mutations.num_rows + 2); + + // union with one shared node, but no add pop + ret = tsk_table_collection_copy(&tables, &tables_copy, TSK_NO_INIT); + CU_ASSERT_EQUAL_FATAL(ret, 0); + node_mapping[0] = TSK_NULL; + node_mapping[1] = TSK_NULL; + node_mapping[2] = 2; + ret = tsk_table_collection_union( + &tables_copy, &tables, node_mapping, TSK_UNION_NO_ADD_POP); + CU_ASSERT_EQUAL_FATAL(ret, 0); + CU_ASSERT_EQUAL_FATAL(tables_copy.populations.num_rows, tables.populations.num_rows); + CU_ASSERT_EQUAL_FATAL( + tables_copy.individuals.num_rows, tables.individuals.num_rows + 2); + CU_ASSERT_EQUAL_FATAL(tables_copy.nodes.num_rows, tables.nodes.num_rows + 2); + CU_ASSERT_EQUAL_FATAL(tables_copy.edges.num_rows, tables.edges.num_rows + 2); + CU_ASSERT_EQUAL_FATAL(tables_copy.sites.num_rows, tables.sites.num_rows); + CU_ASSERT_EQUAL_FATAL(tables_copy.mutations.num_rows, tables.mutations.num_rows + 2); + + tsk_table_collection_free(&tables_copy); + tsk_table_collection_free(&tables_empty); tsk_table_collection_free(&tables); +} + +static void +test_table_collection_union_errors(void) +{ + int ret; + tsk_table_collection_t tables; + tsk_table_collection_t tables_copy; + tsk_id_t node_mapping[] = { 0, 1 }; + + ret = tsk_table_collection_init(&tables, 0); + CU_ASSERT_EQUAL_FATAL(ret, 0); + tables.sequence_length = 1; + ret = tsk_table_collection_init(&tables_copy, 0); + CU_ASSERT_EQUAL_FATAL(ret, 0); + + // two nodes, two pop, two ind, one edge, one site, one mut + ret = tsk_node_table_add_row(&tables.nodes, TSK_NODE_IS_SAMPLE, 0.0, 0, 0, NULL, 0); + CU_ASSERT_FATAL(ret >= 0); + ret = tsk_node_table_add_row(&tables.nodes, TSK_NODE_IS_SAMPLE, 0.5, 1, 1, NULL, 0); + CU_ASSERT_FATAL(ret >= 0); + ret = tsk_individual_table_add_row(&tables.individuals, 0, NULL, 0, NULL, 0); + CU_ASSERT_FATAL(ret >= 0); + ret = tsk_individual_table_add_row(&tables.individuals, 0, NULL, 0, NULL, 0); + CU_ASSERT_FATAL(ret >= 0); + ret = tsk_population_table_add_row(&tables.populations, NULL, 0); + CU_ASSERT_FATAL(ret >= 0); + ret = tsk_population_table_add_row(&tables.populations, NULL, 0); + CU_ASSERT_FATAL(ret >= 0); + ret = tsk_edge_table_add_row(&tables.edges, 0.0, 1.0, 1, 0, NULL, 0); + CU_ASSERT_FATAL(ret >= 0); + ret = tsk_site_table_add_row(&tables.sites, 0.2, "A", 1, NULL, 0); + CU_ASSERT_FATAL(ret >= 0); + ret = tsk_mutation_table_add_row( + &tables.mutations, 0, 0, TSK_NULL, TSK_UNKNOWN_TIME, NULL, 0, NULL, 0); + CU_ASSERT_FATAL(ret >= 0); + + // trigger diff histories error + ret = tsk_table_collection_copy(&tables, &tables_copy, TSK_NO_INIT); + CU_ASSERT_EQUAL_FATAL(ret, 0); + ret = tsk_mutation_table_add_row( + &tables_copy.mutations, 0, 1, TSK_NULL, TSK_UNKNOWN_TIME, NULL, 0, NULL, 0); + CU_ASSERT_FATAL(ret >= 0); + ret = tsk_table_collection_union(&tables_copy, &tables, node_mapping, 0); + CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_UNION_DIFF_HISTORIES); + + // Migrations are not supported + ret = tsk_table_collection_copy(&tables, &tables_copy, TSK_NO_INIT); + CU_ASSERT_EQUAL_FATAL(ret, 0); + tsk_migration_table_add_row(&tables_copy.migrations, 0, 1, 0, 0, 0, 0, NULL, 0); + CU_ASSERT_EQUAL_FATAL(tables_copy.migrations.num_rows, 1); + ret = tsk_table_collection_union( + &tables_copy, &tables, node_mapping, TSK_UNION_NO_CHECK_SHARED); + CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_MIGRATIONS_NOT_SUPPORTED); + + // unsuported union - child shared parent not shared + ret = tsk_table_collection_copy(&tables, &tables_copy, TSK_NO_INIT); + CU_ASSERT_EQUAL_FATAL(ret, 0); + node_mapping[0] = 0; + node_mapping[1] = TSK_NULL; + ret = tsk_table_collection_union( + &tables_copy, &tables, node_mapping, TSK_UNION_NO_ADD_POP); + CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_UNION_NOT_SUPPORTED); + + // test out of bounds node_mapping + node_mapping[0] = -4; + node_mapping[1] = 6; + ret = tsk_table_collection_copy(&tables, &tables_copy, TSK_NO_INIT); + CU_ASSERT_EQUAL_FATAL(ret, 0); + ret = tsk_table_collection_union(&tables_copy, &tables, node_mapping, 0); + CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_UNION_BAD_MAP); + + // check integrity + node_mapping[0] = 0; + node_mapping[1] = 1; + ret = tsk_node_table_add_row( + &tables_copy.nodes, TSK_NODE_IS_SAMPLE, 0.0, -2, 0, NULL, 0); + CU_ASSERT_FATAL(ret >= 0); + ret = tsk_table_collection_union(&tables_copy, &tables, node_mapping, 0); + CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_POPULATION_OUT_OF_BOUNDS); + ret = tsk_table_collection_copy(&tables, &tables_copy, TSK_NO_INIT); + CU_ASSERT_EQUAL_FATAL(ret, 0); + ret = tsk_node_table_add_row(&tables.nodes, TSK_NODE_IS_SAMPLE, 0.0, -2, 0, NULL, 0); + CU_ASSERT_FATAL(ret >= 0); + ret = tsk_table_collection_union(&tables, &tables_copy, node_mapping, 0); + CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_POPULATION_OUT_OF_BOUNDS); + tsk_table_collection_free(&tables_copy); + tsk_table_collection_free(&tables); } int @@ -4168,6 +4407,8 @@ main(int argc, char **argv) test_table_collection_check_integrity_no_populations }, { "test_table_collection_subset", test_table_collection_subset }, { "test_table_collection_subset_errors", test_table_collection_subset_errors }, + { "test_table_collection_union", test_table_collection_union }, + { "test_table_collection_union_errors", test_table_collection_union_errors }, { NULL, NULL }, }; diff --git a/c/tskit/core.c b/c/tskit/core.c index 3e58f7bc8a..7458f0deb8 100644 --- a/c/tskit/core.c +++ b/c/tskit/core.c @@ -352,6 +352,10 @@ tsk_strerror_internal(int err) case TSK_ERR_NONBINARY_MUTATIONS_UNSUPPORTED: ret = "Only binary mutations are supported for this operation"; break; + case TSK_ERR_UNION_NOT_SUPPORTED: + ret = "Union is not supported for cases where there is non-shared" + "history older than the shared history of the two Table Collections"; + break; /* Stats errors */ case TSK_ERR_BAD_NUM_WINDOWS: @@ -441,6 +445,16 @@ tsk_strerror_internal(int err) case TSK_ERR_TOO_MANY_VALUES: ret = "Too many values to compress"; break; + + /* Union errors */ + case TSK_ERR_UNION_BAD_MAP: + ret = "Node map contains an entry of a node not present in this table " + "collection."; + break; + case TSK_ERR_UNION_DIFF_HISTORIES: + // histories could be equivalent, because subset does not reorder + // edges (if not sorted) or mutations. + ret = "Shared portions of the tree sequences are not equal."; } return ret; } diff --git a/c/tskit/core.h b/c/tskit/core.h index 0e963e1225..78fb3add34 100644 --- a/c/tskit/core.h +++ b/c/tskit/core.h @@ -267,6 +267,7 @@ not found in the file. #define TSK_ERR_SORT_OFFSET_NOT_SUPPORTED -803 #define TSK_ERR_NONBINARY_MUTATIONS_UNSUPPORTED -804 #define TSK_ERR_MIGRATIONS_NOT_SUPPORTED -805 +#define TSK_ERR_UNION_NOT_SUPPORTED -806 /* Stats errors */ #define TSK_ERR_BAD_NUM_WINDOWS -900 @@ -303,6 +304,11 @@ not found in the file. #define TSK_ERR_MATCH_IMPOSSIBLE -1301 #define TSK_ERR_BAD_COMPRESSED_MATRIX_NODE -1302 #define TSK_ERR_TOO_MANY_VALUES -1303 + +/* Union errors */ +#define TSK_ERR_UNION_BAD_MAP -1400 +#define TSK_ERR_UNION_DIFF_HISTORIES -1401 + // clang-format on /* This bit is 0 for any errors originating from kastore */ diff --git a/c/tskit/tables.c b/c/tskit/tables.c index 8e52e5ca70..1c6dc8ff2c 100644 --- a/c/tskit/tables.c +++ b/c/tskit/tables.c @@ -8209,15 +8209,68 @@ tsk_table_collection_clear(tsk_table_collection_t *self) return tsk_table_collection_truncate(self, &start); } -int TSK_WARN_UNUSED -tsk_table_collection_subset( - tsk_table_collection_t *self, tsk_id_t *nodes, tsk_size_t num_nodes) +static int +tsk_table_collection_add_and_remap_node(tsk_table_collection_t *self, + tsk_table_collection_t *other, tsk_id_t node_id, tsk_id_t *individual_map, + tsk_id_t *population_map, tsk_id_t *node_map, bool add_populations) { int ret = 0; - tsk_id_t k, i, new_ind, new_pop, new_parent, new_child, new_node; + tsk_id_t new_ind, new_pop; tsk_node_t node; tsk_individual_t ind; tsk_population_t pop; + + ret = tsk_node_table_get_row(&other->nodes, node_id, &node); + if (ret < 0) { + goto out; + } + new_ind = TSK_NULL; + if (node.individual != TSK_NULL) { + if (individual_map[node.individual] == TSK_NULL) { + tsk_individual_table_get_row(&other->individuals, node.individual, &ind); + ret = tsk_individual_table_add_row(&self->individuals, ind.flags, + ind.location, ind.location_length, ind.metadata, ind.metadata_length); + if (ret < 0) { + goto out; + } + individual_map[node.individual] = ret; + } + new_ind = individual_map[node.individual]; + } + new_pop = TSK_NULL; + if (node.population != TSK_NULL) { + // keep same pops if add_populations is False + if (!add_populations) { + population_map[node.population] = node.population; + } + if (population_map[node.population] == TSK_NULL) { + tsk_population_table_get_row(&other->populations, node.population, &pop); + ret = tsk_population_table_add_row( + &self->populations, pop.metadata, pop.metadata_length); + if (ret < 0) { + goto out; + } + population_map[node.population] = ret; + } + new_pop = population_map[node.population]; + } + ret = tsk_node_table_add_row(&self->nodes, node.flags, node.time, new_pop, new_ind, + node.metadata, node.metadata_length); + if (ret < 0) { + goto out; + } + node_map[node.id] = ret; + +out: + return ret; +} + +int TSK_WARN_UNUSED +tsk_table_collection_subset( + tsk_table_collection_t *self, tsk_id_t *nodes, tsk_size_t num_nodes) +{ + int ret = 0; + tsk_id_t k, i, new_parent, new_child, new_node; tsk_edge_t edge; tsk_mutation_t mut; tsk_site_t site; @@ -8232,6 +8285,10 @@ tsk_table_collection_subset( if (ret != 0) { goto out; } + ret = tsk_table_collection_check_integrity(self, 0); + if (ret != 0) { + goto out; + } ret = tsk_table_collection_clear(self); if (ret != 0) { goto out; @@ -8255,43 +8312,11 @@ tsk_table_collection_subset( // nodes, individuals, populations for (k = 0; k < (tsk_id_t) num_nodes; k++) { - ret = tsk_node_table_get_row(&tables.nodes, nodes[k], &node); + ret = tsk_table_collection_add_and_remap_node( + self, &tables, nodes[k], individual_map, population_map, node_map, true); if (ret < 0) { goto out; } - new_ind = TSK_NULL; - if (node.individual != TSK_NULL) { - if (individual_map[node.individual] == TSK_NULL) { - tsk_individual_table_get_row(&tables.individuals, node.individual, &ind); - ret = tsk_individual_table_add_row(&self->individuals, ind.flags, - ind.location, ind.location_length, ind.metadata, - ind.metadata_length); - if (ret < 0) { - goto out; - } - individual_map[node.individual] = ret; - } - new_ind = individual_map[node.individual]; - } - new_pop = TSK_NULL; - if (node.population != TSK_NULL) { - if (population_map[node.population] == TSK_NULL) { - tsk_population_table_get_row(&tables.populations, node.population, &pop); - ret = tsk_population_table_add_row( - &self->populations, pop.metadata, pop.metadata_length); - if (ret < 0) { - goto out; - } - population_map[node.population] = ret; - } - new_pop = population_map[node.population]; - } - ret = tsk_node_table_add_row(&self->nodes, node.flags, node.time, new_pop, - new_ind, node.metadata, node.metadata_length); - if (ret < 0) { - goto out; - } - node_map[node.id] = ret; } // edges @@ -8367,6 +8392,238 @@ tsk_table_collection_subset( return ret; } +static int +tsk_check_subset_equality(tsk_table_collection_t *self, tsk_table_collection_t *other, + tsk_id_t *other_node_mapping, tsk_size_t num_shared_nodes) +{ + int ret = 0; + tsk_id_t k, i; + tsk_id_t *self_nodes = NULL; + tsk_id_t *other_nodes = NULL; + tsk_table_collection_t self_copy; + tsk_table_collection_t other_copy; + + memset(&self_copy, 0, sizeof(self_copy)); + memset(&other_copy, 0, sizeof(other_copy)); + self_nodes = malloc(num_shared_nodes * sizeof(*self_nodes)); + other_nodes = malloc(num_shared_nodes * sizeof(*other_nodes)); + if (self_nodes == NULL || other_nodes == NULL) { + ret = TSK_ERR_NO_MEMORY; + goto out; + } + + i = 0; + for (k = 0; k < (tsk_id_t) other->nodes.num_rows; k++) { + if (other_node_mapping[k] != TSK_NULL) { + self_nodes[i] = other_node_mapping[k]; + other_nodes[i] = k; + i++; + } + } + + // TODO: strict sort before checking equality + ret = tsk_table_collection_copy(self, &self_copy, 0); + if (ret != 0) { + goto out; + } + ret = tsk_table_collection_copy(other, &other_copy, 0); + if (ret != 0) { + goto out; + } + ret = tsk_provenance_table_clear(&other_copy.provenances); + if (ret != 0) { + goto out; + } + ret = tsk_provenance_table_clear(&self_copy.provenances); + if (ret != 0) { + goto out; + } + ret = tsk_table_collection_subset(&self_copy, self_nodes, num_shared_nodes); + if (ret != 0) { + goto out; + } + ret = tsk_table_collection_subset(&other_copy, other_nodes, num_shared_nodes); + if (ret != 0) { + goto out; + } + if (!tsk_table_collection_equals(&self_copy, &other_copy)) { + ret = TSK_ERR_UNION_DIFF_HISTORIES; + goto out; + } + +out: + tsk_table_collection_free(&self_copy); + tsk_table_collection_free(&other_copy); + tsk_safe_free(other_nodes); + tsk_safe_free(self_nodes); + return ret; +} + +int TSK_WARN_UNUSED +tsk_table_collection_union(tsk_table_collection_t *self, tsk_table_collection_t *other, + tsk_id_t *other_node_mapping, tsk_flags_t options) +{ + int ret = 0; + tsk_id_t k, i, new_parent, new_child; + tsk_size_t num_shared_nodes = 0; + tsk_edge_t edge; + tsk_mutation_t mut; + tsk_site_t site; + tsk_id_t *node_map = NULL; + tsk_id_t *individual_map = NULL; + tsk_id_t *population_map = NULL; + tsk_id_t *site_map = NULL; + bool add_populations = !(options & TSK_UNION_NO_ADD_POP); + bool check_shared_portion = !(options & TSK_UNION_NO_CHECK_SHARED); + + ret = tsk_table_collection_check_integrity(self, 0); + if (ret != 0) { + goto out; + } + ret = tsk_table_collection_check_integrity(other, 0); + if (ret != 0) { + goto out; + } + for (k = 0; k < (tsk_id_t) other->nodes.num_rows; k++) { + if (other_node_mapping[k] >= (tsk_id_t) self->nodes.num_rows + || other_node_mapping[k] < TSK_NULL) { + ret = TSK_ERR_UNION_BAD_MAP; + goto out; + } + if (other_node_mapping[k] != TSK_NULL) { + num_shared_nodes++; + } + } + + if (check_shared_portion) { + ret = tsk_check_subset_equality( + self, other, other_node_mapping, num_shared_nodes); + if (ret != 0) { + goto out; + } + } + + // Maps relating the IDs in other to the new IDs in self. + node_map = malloc(other->nodes.num_rows * sizeof(*node_map)); + individual_map = malloc(other->individuals.num_rows * sizeof(*individual_map)); + population_map = malloc(other->populations.num_rows * sizeof(*population_map)); + site_map = malloc(other->sites.num_rows * sizeof(*site_map)); + if (node_map == NULL || individual_map == NULL || population_map == NULL + || site_map == NULL) { + ret = TSK_ERR_NO_MEMORY; + goto out; + } + memset(node_map, 0xff, other->nodes.num_rows * sizeof(*node_map)); + memset(individual_map, 0xff, other->individuals.num_rows * sizeof(*individual_map)); + memset(population_map, 0xff, other->populations.num_rows * sizeof(*population_map)); + memset(site_map, 0xff, other->sites.num_rows * sizeof(*site_map)); + + // nodes, individuals, populations + for (k = 0; k < (tsk_id_t) other->nodes.num_rows; k++) { + if (other_node_mapping[k] != TSK_NULL) { + node_map[k] = other_node_mapping[k]; + } else { + ret = tsk_table_collection_add_and_remap_node(self, other, k, individual_map, + population_map, node_map, add_populations); + if (ret < 0) { + goto out; + } + } + } + + // edges + for (k = 0; k < (tsk_id_t) other->edges.num_rows; k++) { + tsk_edge_table_get_row(&other->edges, k, &edge); + if ((other_node_mapping[edge.parent] == TSK_NULL) + || (other_node_mapping[edge.child] == TSK_NULL)) { + /* TODO: union does not support case where non-shared bits of + * other are above the shared bits of self and other. This will be + * resolved when the Mutation Table has a time attribute and + * the Mutation Table is sorted on time. */ + if (other_node_mapping[edge.parent] == TSK_NULL + && other_node_mapping[edge.child] != TSK_NULL) { + ret = TSK_ERR_UNION_NOT_SUPPORTED; + goto out; + } + new_parent = node_map[edge.parent]; + new_child = node_map[edge.child]; + ret = tsk_edge_table_add_row(&self->edges, edge.left, edge.right, new_parent, + new_child, edge.metadata, edge.metadata_length); + if (ret < 0) { + goto out; + } + } + } + + // mutations and sites + i = 0; + for (k = 0; k < (tsk_id_t) other->sites.num_rows; k++) { + tsk_site_table_get_row(&other->sites, k, &site); + while ((i < (tsk_id_t) other->mutations.num_rows) + && (other->mutations.site[i] == site.id)) { + tsk_mutation_table_get_row(&other->mutations, i, &mut); + if (other_node_mapping[mut.node] == TSK_NULL) { + if (site_map[site.id] == TSK_NULL) { + ret = tsk_site_table_add_row(&self->sites, site.position, + site.ancestral_state, site.ancestral_state_length, site.metadata, + site.metadata_length); + if (ret < 0) { + goto out; + } + site_map[site.id] = ret; + } + // the parents will be recomputed later + new_parent = TSK_NULL; + ret = tsk_mutation_table_add_row(&self->mutations, site_map[site.id], + node_map[mut.node], new_parent, mut.time, mut.derived_state, + mut.derived_state_length, mut.metadata, mut.metadata_length); + if (ret < 0) { + goto out; + } + } + i++; + } + } + + /* TODO: Union of the Migrations Table. The only hindrance to performing the + * union operation on Migrations Tables is that tsk_table_collection_sort + * does not sort migrations by time, and instead throws an error. */ + if (self->migrations.num_rows != 0 || other->migrations.num_rows != 0) { + ret = TSK_ERR_MIGRATIONS_NOT_SUPPORTED; + goto out; + } + + // provenance (new record is added in python) + + // deduplicating, sorting, and computing parents + ret = tsk_table_collection_sort(self, 0, 0); + if (ret < 0) { + goto out; + } + + ret = tsk_table_collection_deduplicate_sites(self, 0); + if (ret < 0) { + goto out; + } + + ret = tsk_table_collection_build_index(self, 0); + if (ret < 0) { + goto out; + } + + ret = tsk_table_collection_compute_mutation_parents(self, 0); + if (ret < 0) { + goto out; + } + +out: + tsk_safe_free(node_map); + tsk_safe_free(individual_map); + tsk_safe_free(population_map); + tsk_safe_free(site_map); + return ret; +} + static int cmp_edge_cl(const void *a, const void *b) { diff --git a/c/tskit/tables.h b/c/tskit/tables.h index 83cdb26d68..3cadf6bdc4 100644 --- a/c/tskit/tables.h +++ b/c/tskit/tables.h @@ -683,6 +683,10 @@ typedef struct _tsk_table_sorter_t { /* Flags for table init. */ #define TSK_NO_METADATA (1 << 0) +/* Flags for union() */ +#define TSK_UNION_NO_CHECK_SHARED (1 << 0) +#define TSK_UNION_NO_ADD_POP (1 << 1) + /****************************************************************************/ /* Function signatures */ /****************************************************************************/ @@ -2515,6 +2519,61 @@ nodes (and individuals and populations) reordered. int tsk_table_collection_subset( tsk_table_collection_t *self, tsk_id_t *nodes, tsk_size_t num_nodes); +/** +@brief Forms the node-wise union of two table collections. + +@rst +Expands this table collection by adding the non-shared portions of another table +collection to itself. The ``other_node_mapping`` encodes which nodes in ``other`` are +equivalent to a node in ``self``. The positions in the ``other_node_mapping`` array +correspond to node ids in ``other``, and the elements encode the equivalent +node id in ``self`` or TSK_NULL if the node is exclusive to ``other``. Nodes +that are exclusive ``other`` are added to ``self``, along with: + +1. Individuals which are new to ``self``. +2. Edges whose parent or child are new to ``self``. +3. Sites which were not present in ``self``. +4. Mutations whose nodes are new to ``self``. + +By default, populations of newly added nodes are assumed to be new populations, +and added to the population table as well. + +This operation will also sort the resulting tables, so the tables may change +even if nothing new is added, if the original tables were not sorted. + +**Options**: + +Options can be specified by providing one or more of the following bitwise +flags: + +TSK_UNION_NO_CHECK_SHARED + By default, union checks that the portion of shared history between + ``self`` and ``other``, as implied by ``other_node_mapping``, are indeed + equivalent. It does so by subsetting both ``self`` and ``other`` on the + equivalent nodes specified in ``other_node_mapping``, and then checking for + equality of the subsets. +TSK_UNION_NO_ADD_POP + By default, all nodes new to ``self`` are assigned new populations. If this + option is specified, nodes that are added to ``self`` will retain the + population IDs they have in ``other``. + +.. note:: Migrations are currently not supported by union, and an error will + be raised if we attempt call union on a table collection with migrations. +@endrst + +@param self A pointer to a tsk_table_collection_t object. +@param other A pointer to a tsk_table_collection_t object. +@param other_node_mapping An array of node IDs that relate nodes in other to nodes in +self: the k-th element of other_node_mapping should be the index of the equivalent +node in self, or TSK_NULL if the node is not present in self (in which case it +will be added to self). +@param options Union options; see above for the available bitwise flags. + For the default behaviour, a value of 0 should be provided. +@return Return 0 on success or a negative value on failure. +*/ +int tsk_table_collection_union(tsk_table_collection_t *self, + tsk_table_collection_t *other, tsk_id_t *other_node_mapping, tsk_flags_t options); + /** @brief Set the metadata @rst diff --git a/python/CHANGELOG.rst b/python/CHANGELOG.rst index 2ba14642f3..f1b1ce75ee 100644 --- a/python/CHANGELOG.rst +++ b/python/CHANGELOG.rst @@ -21,6 +21,12 @@ In development **New features** +- New methods to perform set operations on TableCollections and TreeSequences. + ``TableCollection.subset`` subsets and reorders table collections by nodes + (:user:`mufernando`, :user:`petrelharp`, :pr:`663`, :pr:`690`). + ``TableCollection.union`` forms the node-wise union of two table collections + (:user:`mufernando`, :user:`petrelharp`, :issue:`381` :pr:`623`). + - Mutations now have an optional double-precision floating-point ``time`` column. If not specified, this defaults to a particular NaN value (``tskit.UNKNOWN_TIME``) indicating that the time is unknown. For a tree sequence to be considered valid diff --git a/python/_tskitmodule.c b/python/_tskitmodule.c index d0deab7de4..488994008a 100644 --- a/python/_tskitmodule.c +++ b/python/_tskitmodule.c @@ -6575,6 +6575,59 @@ TableCollection_subset(TableCollection *self, PyObject *args) return ret; } +/* Forward declaration */ +static PyTypeObject TableCollectionType; + +static PyObject * +TableCollection_union(TableCollection *self, PyObject *args, PyObject *kwds) { + int err; + TableCollection *other = NULL; + PyObject *ret = NULL; + PyObject *other_node_mapping = NULL; + PyArrayObject *nmap_array = NULL; + npy_intp *shape; + tsk_flags_t options = 0; + int check_shared = true; + int add_populations = true; + static char *kwlist[] = {"other", "other_node_mapping", "check_shared_equality", + "add_populations", NULL}; + + if (!PyArg_ParseTupleAndKeywords( + args, kwds, "O!O|ii", kwlist, &TableCollectionType, &other, + &other_node_mapping, &check_shared, &add_populations)) { + goto out; + } + nmap_array = (PyArrayObject *)PyArray_FROMANY(other_node_mapping, NPY_INT32, + 1, 1, NPY_ARRAY_IN_ARRAY); + if (nmap_array == NULL) { + goto out; + } + shape = PyArray_DIMS(nmap_array); + if (other->tables->nodes.num_rows != (tsk_size_t) shape[0]) { + PyErr_SetString( + PyExc_ValueError, + "The length of the node mapping array should be equal to the" + " number of nodes in the other tree sequence."); + goto out; + } + if (!check_shared) { + options |= TSK_UNION_NO_CHECK_SHARED; + } + if (!add_populations) { + options |= TSK_UNION_NO_ADD_POP; + } + err = tsk_table_collection_union(self->tables, other->tables, + PyArray_DATA(nmap_array), options); + if (err != 0) { + handle_library_error(err); + goto out; + } + ret = Py_BuildValue(""); +out: + Py_XDECREF(nmap_array); + return ret; +} + static PyObject * TableCollection_sort(TableCollection *self, PyObject *args, PyObject *kwds) { @@ -6687,9 +6740,6 @@ TableCollection_has_index(TableCollection *self) return Py_BuildValue("i", (int) has_index); } -/* Forward declaration */ -static PyTypeObject TableCollectionType; - static PyObject * TableCollection_equals(TableCollection *self, PyObject *args) { @@ -6734,7 +6784,9 @@ static PyMethodDef TableCollection_methods[] = { METH_VARARGS|METH_KEYWORDS, "Returns an edge table linking samples to a set of specified ancestors." }, {"subset", (PyCFunction) TableCollection_subset, METH_VARARGS, - "Subsets the tree sequence to a set of nodes." }, + "Subsets the table collection to a set of nodes." }, + {"union", (PyCFunction) TableCollection_union, METH_VARARGS|METH_KEYWORDS, + "Adds to this table collection the portions of another table collection that are not shared with this one." }, {"sort", (PyCFunction) TableCollection_sort, METH_VARARGS|METH_KEYWORDS, "Sorts the tables to satisfy tree sequence requirements." }, {"equals", (PyCFunction) TableCollection_equals, METH_VARARGS, diff --git a/python/tests/test_lowlevel.py b/python/tests/test_lowlevel.py index c39ee9d4a7..0564f92d6b 100644 --- a/python/tests/test_lowlevel.py +++ b/python/tests/test_lowlevel.py @@ -277,6 +277,28 @@ def test_subset_bad_args(self): with self.assertRaises(_tskit.LibraryError): tc.subset(np.array([100, 200], dtype="int32")) + def test_union_bad_args(self): + ts = msprime.simulate(10, random_seed=1) + tc = ts.tables.ll_tables + tc2 = tc + with self.assertRaises(TypeError): + tc.union(tc2, np.array(["a"])) + with self.assertRaises(ValueError): + tc.union(tc2, np.array([0], dtype="int32")) + with self.assertRaises(TypeError): + tc.union(tc2) + with self.assertRaises(TypeError): + tc.union() + node_mapping = np.arange(ts.num_nodes, dtype="int32") + node_mapping[0] = 1200 + with self.assertRaises(_tskit.LibraryError): + tc.union(tc2, node_mapping) + node_mapping = np.array( + [node_mapping.tolist(), node_mapping.tolist()], dtype="int32" + ) + with self.assertRaises(ValueError): + tc.union(tc2, np.array([[1], [2]], dtype="int32")) + class TestTreeSequence(LowLevelTestCase, MetadataTestMixin): """ diff --git a/python/tests/test_tables.py b/python/tests/test_tables.py index 1469505701..9ae67e386a 100644 --- a/python/tests/test_tables.py +++ b/python/tests/test_tables.py @@ -26,6 +26,7 @@ """ import io import itertools +import json import math import pickle import random @@ -2655,3 +2656,142 @@ def test_empty_nodes(self): self.assertEqual(subset.sites.num_rows, 0) self.assertEqual(subset.mutations.num_rows, 0) self.assertEqual(subset.provenances, tables.provenances) + + +class TestUnion(unittest.TestCase): + """ + Tests for the TableCollection.union method + """ + + def get_msprime_example(self, sample_size=3, T=5, seed=1239): + # we assume after the split the ts are completely independent + M = [[0, 0], [0, 0]] + population_configurations = [ + msprime.PopulationConfiguration(sample_size=sample_size), + msprime.PopulationConfiguration(sample_size=sample_size), + ] + demographic_events = [ + msprime.CensusEvent(time=T), + msprime.MassMigration(T, source=1, dest=0, proportion=1), + ] + ts = msprime.simulate( + population_configurations=population_configurations, + demographic_events=demographic_events, + migration_matrix=M, + length=2e5, + recombination_rate=1e-8, + mutation_rate=1e-7, + record_migrations=False, + random_seed=seed, + ) + ts = tsutil.add_random_metadata(ts, seed) + ts = tsutil.insert_random_ploidy_individuals(ts, max_ploidy=1) + return ts + + def get_wf_example(self, N=5, T=5, seed=1249): + twopop_tables = wf.wf_sim(N, T, num_pops=2, seed=seed, deep_history=True) + twopop_tables.sort() + ts = twopop_tables.tree_sequence() + ts = ts.simplify() + # adding muts + ts = tsutil.jukes_cantor(ts, 1, 10, seed=seed) + ts = tsutil.add_random_metadata(ts, seed) + ts = tsutil.insert_random_ploidy_individuals(ts, max_ploidy=2) + return ts + + def split_example(self, ts, T): + # splitting two pop ts into disjoint ts + shared_nodes = [n.id for n in ts.nodes() if n.time >= T] + pop1 = list(ts.samples(population=0)) + pop2 = list(ts.samples(population=1)) + tables1 = ts.simplify(shared_nodes + pop1, record_provenance=False).tables + tables2 = ts.simplify(shared_nodes + pop2, record_provenance=False).tables + node_mapping = [ + i if i < len(shared_nodes) else tskit.NULL + for i in range(tables2.nodes.num_rows) + ] + return tables1, tables2, node_mapping + + def verify_union_equality(self, tables, other, node_mapping, add_populations=True): + # verifying against py impl + uni1 = tables.copy() + uni2 = tables.copy() + uni1.union( + other, + node_mapping, + record_provenance=False, + add_populations=add_populations, + ) + tsutil.py_union( + uni2, + other, + node_mapping, + record_provenance=False, + add_populations=add_populations, + ) + self.assertEqual(uni1, uni2) + # verifying that subsetting to original nodes return the same table + orig_nodes = [j for i, j in enumerate(node_mapping) if j != tskit.NULL] + uni1.subset(orig_nodes) + # subsetting tables just to make sure order is the same + tables.subset(orig_nodes) + uni1.provenances.clear() + tables.provenances.clear() + self.assertEqual(uni1, tables) + + def test_noshared_example(self): + ts1 = self.get_msprime_example(sample_size=3, T=2, seed=9328) + ts2 = self.get_msprime_example(sample_size=3, T=2, seed=2125) + node_mapping = np.full(ts2.num_nodes, tskit.NULL, dtype="int32") + uni1 = ts1.union(ts2, node_mapping, record_provenance=False) + uni2_tables = ts1.dump_tables() + tsutil.py_union(uni2_tables, ts2.tables, node_mapping, record_provenance=False) + self.assertEqual(uni1.tables, uni2_tables) + + def test_all_shared_example(self): + tables = self.get_wf_example(N=5, T=5, seed=11349).dump_tables() + uni = tables.copy() + node_mapping = np.arange(tables.nodes.num_rows) + uni.union(tables, node_mapping, record_provenance=False) + self.assertEqual(uni, tables) + + def test_no_add_pop(self): + self.verify_union_equality( + *self.split_example(self.get_msprime_example(10, 10), 10), + add_populations=False, + ) + self.verify_union_equality( + *self.split_example(self.get_wf_example(10, 10), 10), add_populations=False + ) + + def test_provenance(self): + tables, other, node_mapping = self.split_example( + self.get_msprime_example(5, 2, seed=928), 2 + ) + tables_copy = tables.copy() + tables.union(other, node_mapping) + uni_other_dict = json.loads(tables.provenances[-1].record)["parameters"][ + "other" + ] + recovered_prov_table = tskit.ProvenanceTable() + self.assertEqual( + len(uni_other_dict["timestamp"]), len(uni_other_dict["record"]) + ) + for timestamp, record in zip( + uni_other_dict["timestamp"], uni_other_dict["record"] + ): + recovered_prov_table.add_row(record, timestamp) + self.assertEqual(recovered_prov_table, other.provenances) + tables.provenances.truncate(tables.provenances.num_rows - 1) + self.assertEqual(tables.provenances, tables_copy.provenances) + + def test_examples(self): + for N in [2, 4, 5]: + for T in [2, 5, 20]: + with self.subTest(N=N, T=T): + self.verify_union_equality( + *self.split_example(self.get_msprime_example(N, T), T) + ) + self.verify_union_equality( + *self.split_example(self.get_wf_example(N, T), T) + ) diff --git a/python/tests/test_wright_fisher.py b/python/tests/test_wright_fisher.py index cbabf2859b..6e30c833dd 100644 --- a/python/tests/test_wright_fisher.py +++ b/python/tests/test_wright_fisher.py @@ -38,12 +38,16 @@ class WrightFisherSimulator: """ - SIMPLE simulation of a bisexual, haploid Wright-Fisher population of size N - for ngens generations, in which each individual survives with probability - survival and only those who die are replaced. If num_loci is None, - the chromosome is 1.0 Morgans long, and the mutation rate is in units of - mutations/Morgan/generation. If num_loci not None, a discrete recombination - model is used where breakpoints are chosen uniformly from 1 to num_loci - 1. + SIMPLE simulation of `num_pops` bisexual, haploid Wright-Fisher populations + of size `N` for `ngens` generations, in which each individual survives with + probability `survival` and only those who die are replaced. If `num_pops` is + greater than 1, the individual to be replaced has a chance `mig_rate` of + being the offspring of nodes from a different and randomly chosen + population. If `num_loci` is None, the chromosome is 1.0 Morgans long. If + `num_loci` not None, a discrete recombination model is used where + breakpoints are chosen uniformly from 1 to `num_loci` - 1. If + `deep_history` is True, a history to coalescence of just one population of + `self.N` samples is added at the beginning. """ def __init__( @@ -55,10 +59,16 @@ def __init__( debug=False, initial_generation_samples=False, num_loci=None, + num_pops=1, + mig_rate=0.0, + record_migrations=False, ): self.N = N + self.num_pops = num_pops self.num_loci = num_loci self.survival = survival + self.mig_rate = mig_rate + self.record_migrations = record_migrations self.deep_history = deep_history self.debug = debug self.initial_generation_samples = initial_generation_samples @@ -76,11 +86,18 @@ def run(self, ngens): if self.num_loci is not None: L = self.num_loci tables = tskit.TableCollection(sequence_length=L) - tables.populations.add_row() + for _ in range(self.num_pops): + tables.populations.add_row() if self.deep_history: # initial population + population_configurations = [ + msprime.PopulationConfiguration(sample_size=self.N) + ] init_ts = msprime.simulate( - self.N, recombination_rate=1.0, length=L, random_seed=self.seed + population_configurations=population_configurations, + recombination_rate=1.0, + length=L, + random_seed=self.seed, ) init_tables = init_ts.dump_tables() flags = init_tables.nodes.flags @@ -97,48 +114,74 @@ def run(self, ngens): flags = 0 if self.initial_generation_samples: flags = tskit.NODE_IS_SAMPLE - for _ in range(self.N): - tables.nodes.add_row(flags=flags, time=ngens, population=0) - - pop = list(range(self.N)) + for p in range(self.num_pops): + for _ in range(self.N): + tables.nodes.add_row(flags=flags, time=ngens, population=p) + + pops = [ + list(range(p * self.N, (p * self.N) + self.N)) for p in range(self.num_pops) + ] + pop_ids = list(range(self.num_pops)) for t in range(ngens - 1, -1, -1): if self.debug: print("t:", t) - print("pop:", pop) - - dead = [self.rng.random() > self.survival for k in pop] + print("pops:", pops) + dead = [[self.rng.random() > self.survival for _ in pop] for pop in pops] # sample these first so that all parents are from the previous gen - new_parents = [ - (self.rng.choice(pop), self.rng.choice(pop)) for k in range(sum(dead)) - ] - k = 0 + parent_pop = [] + new_parents = [] + for p in pop_ids: + w = [ + 1 - self.mig_rate if i == p else self.mig_rate / (self.num_pops - 1) + for i in pop_ids + ] + parent_pop.append(self.rng.choices(pop_ids, w, k=sum(dead[p]))) + new_parents.append( + [ + self.rng.choices(pops[parent_pop[p][k]], k=2) + for k in range(sum(dead[p])) + ] + ) + if self.debug: - print("Replacing", sum(dead), "individuals.") - for j in range(self.N): - if dead[j]: - # this is: offspring ID, lparent, rparent, breakpoint - offspring = len(tables.nodes) - tables.nodes.add_row(time=t, population=0) - lparent, rparent = new_parents[k] - k += 1 - bp = self.random_breakpoint() - if self.debug: - print("--->", offspring, lparent, rparent, bp) - pop[j] = offspring - if bp > 0.0: - tables.edges.add_row( - left=0.0, right=bp, parent=lparent, child=offspring - ) - if bp < L: - tables.edges.add_row( - left=bp, right=L, parent=rparent, child=offspring - ) + for p in pop_ids: + print("Replacing", sum(dead[p]), "individuals from pop", p) + for p in pop_ids: + k = 0 + for j in range(self.N): + if dead[p][j]: + # this is: offspring ID, lparent, rparent, breakpoint + offspring = tables.nodes.add_row(time=t, population=p) + if parent_pop[p][k] != p and self.record_migrations: + tables.migrations.add_row( + left=0.0, + right=L, + node=offspring, + source=parent_pop[p][k], + dest=p, + time=t, + ) + lparent, rparent = new_parents[p][k] + k += 1 + bp = self.random_breakpoint() + if self.debug: + print("--->", offspring, lparent, rparent, bp) + pops[p][j] = offspring + if bp > 0.0: + tables.edges.add_row( + left=0.0, right=bp, parent=lparent, child=offspring + ) + if bp < L: + tables.edges.add_row( + left=bp, right=L, parent=rparent, child=offspring + ) if self.debug: print("Done! Final pop:") - print(pop) + print(pops) flags = tables.nodes.flags - flags[pop] = tskit.NODE_IS_SAMPLE + flattened = [n for pop in pops for n in pop] + flags[flattened] = tskit.NODE_IS_SAMPLE tables.nodes.set_columns( flags=flags, time=tables.nodes.time, population=tables.nodes.population ) @@ -154,6 +197,9 @@ def wf_sim( seed=None, initial_generation_samples=False, num_loci=None, + num_pops=1, + mig_rate=0.0, + record_migrations=False, ): sim = WrightFisherSimulator( N, @@ -163,6 +209,9 @@ def wf_sim( seed=seed, initial_generation_samples=initial_generation_samples, num_loci=num_loci, + num_pops=num_pops, + mig_rate=mig_rate, + record_migrations=record_migrations, ) return sim.run(ngens) @@ -174,6 +223,75 @@ class TestSimulation(unittest.TestCase): random_seed = 5678 + def test_one_gen_multipop_mig_no_deep(self): + tables = wf_sim( + N=5, + ngens=1, + num_pops=4, + mig_rate=1.0, + deep_history=False, + seed=self.random_seed, + record_migrations=True, + ) + self.assertEqual(tables.nodes.num_rows, 5 * 4 * (1 + 1)) + self.assertGreater(tables.edges.num_rows, 0) + self.assertEqual(tables.migrations.num_rows, 5 * 4) + + def test_multipop_mig_deep(self): + N = 10 + ngens = 20 + num_pops = 3 + tables = wf_sim( + N=N, + ngens=ngens, + num_pops=num_pops, + mig_rate=1.0, + seed=self.random_seed, + record_migrations=True, + ) + self.assertGreater(tables.nodes.num_rows, (num_pops * N * ngens) + N) + self.assertGreater(tables.edges.num_rows, 0) + self.assertEqual(tables.sites.num_rows, 0) + self.assertEqual(tables.mutations.num_rows, 0) + self.assertGreaterEqual(tables.migrations.num_rows, N * num_pops * ngens) + self.assertEqual(tables.populations.num_rows, num_pops) + # sort does not support mig + tables.migrations.clear() + # making sure trees are valid + tables.sort() + tables.simplify() + ts = tables.tree_sequence() + sample_pops = tables.nodes.population[ts.samples()] + self.assertEqual(np.unique(sample_pops).size, num_pops) + + def test_multipop_mig_no_deep(self): + N = 5 + ngens = 5 + num_pops = 2 + tables = wf_sim( + N=N, + ngens=ngens, + num_pops=num_pops, + mig_rate=1.0, + deep_history=False, + seed=self.random_seed, + record_migrations=True, + ) + self.assertEqual(tables.nodes.num_rows, num_pops * N * (ngens + 1)) + self.assertGreater(tables.edges.num_rows, 0) + self.assertEqual(tables.sites.num_rows, 0) + self.assertEqual(tables.mutations.num_rows, 0) + self.assertEqual(tables.migrations.num_rows, N * num_pops * ngens) + self.assertEqual(tables.populations.num_rows, num_pops) + # sort does not support mig + tables.migrations.clear() + # making sure trees are valid + tables.sort() + tables.simplify() + ts = tables.tree_sequence() + sample_pops = tables.nodes.population[ts.samples()] + self.assertEqual(np.unique(sample_pops).size, num_pops) + def test_non_overlapping_generations(self): tables = wf_sim(N=10, ngens=10, survival=0.0, seed=self.random_seed) self.assertGreater(tables.nodes.num_rows, 0) diff --git a/python/tests/tsutil.py b/python/tests/tsutil.py index b76acdad9c..30b04c286f 100644 --- a/python/tests/tsutil.py +++ b/python/tests/tsutil.py @@ -638,6 +638,84 @@ def py_subset(tables, nodes, record_provenance=True): mutation_map[i] = new_mut +def py_union(tables, other, nodes, record_provenance=True, add_populations=True): + """ + Python implementation of TableCollection.union(). + """ + # mappings of id in other to new id in tables + # the +1 is to take care of mapping tskit.NULL(-1) to tskit.NULL + pop_map = [tskit.NULL for _ in range(other.populations.num_rows + 1)] + ind_map = [tskit.NULL for _ in range(other.individuals.num_rows + 1)] + node_map = [tskit.NULL for _ in range(other.nodes.num_rows + 1)] + site_map = [tskit.NULL for _ in range(other.sites.num_rows + 1)] + mut_map = [tskit.NULL for _ in range(other.mutations.num_rows + 1)] + for other_id, node in enumerate(other.nodes): + if nodes[other_id] != tskit.NULL: + node_map[other_id] = nodes[other_id] + else: + if ind_map[node.individual] == tskit.NULL and node.individual != tskit.NULL: + ind = other.individuals[node.individual] + ind_id = tables.individuals.add_row( + flags=ind.flags, location=ind.location, metadata=ind.metadata + ) + ind_map[node.individual] = ind_id + if pop_map[node.population] == tskit.NULL and node.population != tskit.NULL: + if not add_populations: + pop_map[node.population] = node.population + else: + pop = other.populations[node.population] + pop_id = tables.populations.add_row(metadata=pop.metadata) + pop_map[node.population] = pop_id + node_id = tables.nodes.add_row( + time=node.time, + population=pop_map[node.population], + individual=ind_map[node.individual], + metadata=node.metadata, + flags=node.flags, + ) + node_map[other_id] = node_id + for edge in other.edges: + if (nodes[edge.parent] == tskit.NULL) or (nodes[edge.child] == tskit.NULL): + # can't do this right not because of sorting of mutations + if (nodes[edge.parent] == tskit.NULL) and (nodes[edge.child] != tskit.NULL): + raise ValueError("Cannot graft nodes above existing nodes.") + tables.edges.add_row( + left=edge.left, + right=edge.right, + parent=node_map[edge.parent], + child=node_map[edge.child], + metadata=edge.metadata, + ) + for other_id, mut in enumerate(other.mutations): + if nodes[mut.node] == tskit.NULL: + # add site: may already be in tables, but we deduplicate + if site_map[mut.site] == tskit.NULL: + site = other.sites[mut.site] + site_id = tables.sites.add_row( + position=site.position, + ancestral_state=site.ancestral_state, + metadata=site.metadata, + ) + site_map[mut.site] = site_id + mut_id = tables.mutations.add_row( + site=site_map[mut.site], + node=node_map[mut.node], + derived_state=mut.derived_state, + parent=tskit.NULL, + metadata=mut.metadata, + ) + mut_map[other_id] = mut_id + # migration table + # grafting provenance table + if record_provenance: + pass + # sorting, deduplicating sites, and re-computing mutation parents + tables.sort() + tables.deduplicate_sites() + tables.build_index() + tables.compute_mutation_parents() + + def compute_mutation_times(ts): """ Compute the `time` column of a MutationTable in a TableCollection. diff --git a/python/tskit/tables.py b/python/tskit/tables.py index 547f84e279..00b6b22d5a 100644 --- a/python/tskit/tables.py +++ b/python/tskit/tables.py @@ -2602,3 +2602,51 @@ def subset(self, nodes, record_provenance=True): self.provenances.add_row( record=json.dumps(provenance.get_provenance_dict(parameters)) ) + + def union( + self, + other, + node_mapping, + check_shared_equality=True, + add_populations=True, + record_provenance=True, + ): + """ + Modifies the table collection in place by adding the non-shared + portions of ``other`` to itself. To perform the node-wise union, + the method relies on a ``node_mapping`` array, that maps nodes in + ``other`` to its equivalent node in ``self`` or ``tskit.NULL`` if + the node is exclusive to ``other``. See :meth:`TreeSequence.union` for a more + detailed description. + + :param TableCollection other: Another table collection. + :param list node_mapping: An array of node IDs that relate nodes in + ``other`` to nodes in ``self``: the k-th element of ``node_mapping`` + should be the index of the equivalent node in ``self``, or + ``tskit.NULL`` if the node is not present in ``self`` (in which case it + will be added to self). + :param bool check_shared_equality: If True, the shared portions of the + table collections will be checked for equality. + :param bool add_populations: If True, nodes new to ``self`` will be + assigned new population IDs. + :param bool record_provenance: Whether to record a provenance entry + in the provenance table for this operation. + """ + node_mapping = util.safe_np_int_cast(node_mapping, np.int32) + self.ll_tables.union( + other.ll_tables, + node_mapping, + check_shared_equality=check_shared_equality, + add_populations=add_populations, + ) + if record_provenance: + other_records = [prov.record for prov in other.provenances] + other_timestamps = [prov.timestamp for prov in other.provenances] + parameters = { + "command": "union", + "other": {"timestamp": other_timestamps, "record": other_records}, + "node_mapping": node_mapping.tolist(), + } + self.provenances.add_row( + record=json.dumps(provenance.get_provenance_dict(parameters)) + ) diff --git a/python/tskit/trees.py b/python/tskit/trees.py index aa0e8c8d19..e0a0617f87 100644 --- a/python/tskit/trees.py +++ b/python/tskit/trees.py @@ -4549,6 +4549,62 @@ def subset(self, nodes, record_provenance=True): tables.subset(nodes, record_provenance) return tables.tree_sequence() + def union( + self, + other, + node_mapping, + check_shared_equality=True, + add_populations=True, + record_provenance=True, + ): + """ + Returns an expanded tree sequence which contains the node-wise union of + ``self`` and ``other``, obtained by adding the non-shared portions of + ``other`` onto ``self``. The "shared" portions are specified using a + map that specifies which nodes in ``other`` are equivalent to those in + ``self``: the ``node_mapping`` argument should be an array of length + equal to the number of nodes in ``other`` and whose entries are the ID + of the matching node in ``self``, or ``tskit.NULL`` if there is no + matching node. Those nodes in ``other`` that map to ``tskit.NULL`` will + be added to ``self``, along with: + + 1. Individuals whose nodes are new to ``self``. + 2. Edges whose parent or child are new to ``self``. + 3. Mutations whose nodes are new to ``self``. + 4. Sites which were not present in ``self``, if the site contains a newly + added mutation. + + By default, populations of newly added nodes are assumed to be new + populations, and added to the population table as well. + + Note that this operation also sorts the resulting tables, so the + resulting tree sequence may not be equal to ``self`` even if nothing + new was added (although it would differ only in ordering of the tables). + + :param TableCollection other: Another table collection. + :param list node_mapping: An array of node IDs that relate nodes in + ``other`` to nodes in ``self``. + :param bool check_shared_equality: If True, the shared portions of the + tree sequences will be checked for equality. It does so by + subsetting both ``self`` and ``other`` on the equivalent nodes + specified in ``node_mapping``, and then checking for equality of + the subsets. + :param bool add_populations: If True, nodes new to ``self`` will be + assigned new population IDs. + :param bool record_provenance: Whether to record a provenance entry + in the provenance table for this operation. + """ + tables = self.dump_tables() + other_tables = other.dump_tables() + tables.union( + other_tables, + node_mapping, + check_shared_equality=check_shared_equality, + add_populations=add_populations, + record_provenance=record_provenance, + ) + return tables.tree_sequence() + def draw_svg( self, path=None,