Skip to content

Commit e5dc0f6

Browse files
authored
Merge pull request #1799 from jeromekelleher/use-virtual-root
Use virtual root
2 parents e622480 + 02d4771 commit e5dc0f6

File tree

8 files changed

+211
-130
lines changed

8 files changed

+211
-130
lines changed

c/tests/test_trees.c

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4329,6 +4329,44 @@ test_single_tree_is_descendant(void)
43294329
tsk_treeseq_free(&ts);
43304330
}
43314331

4332+
static void
4333+
test_single_tree_total_branch_length(void)
4334+
{
4335+
int ret;
4336+
tsk_treeseq_t ts;
4337+
tsk_tree_t tree;
4338+
double length;
4339+
4340+
tsk_treeseq_from_text(&ts, 1, single_tree_ex_nodes, single_tree_ex_edges, NULL, NULL,
4341+
NULL, NULL, NULL, 0);
4342+
ret = tsk_tree_init(&tree, &ts, 0);
4343+
CU_ASSERT_EQUAL_FATAL(ret, 0);
4344+
ret = tsk_tree_first(&tree);
4345+
CU_ASSERT_EQUAL_FATAL(ret, 1);
4346+
4347+
CU_ASSERT_EQUAL_FATAL(tsk_tree_get_total_branch_length(&tree, TSK_NULL, &length), 0);
4348+
CU_ASSERT_EQUAL_FATAL(length, 9);
4349+
CU_ASSERT_EQUAL_FATAL(tsk_tree_get_total_branch_length(&tree, 7, &length), 0);
4350+
CU_ASSERT_EQUAL_FATAL(length, 9);
4351+
CU_ASSERT_EQUAL_FATAL(
4352+
tsk_tree_get_total_branch_length(&tree, tree.virtual_root, &length), 0);
4353+
CU_ASSERT_EQUAL_FATAL(length, 9);
4354+
CU_ASSERT_EQUAL_FATAL(tsk_tree_get_total_branch_length(&tree, 4, &length), 0);
4355+
CU_ASSERT_EQUAL_FATAL(length, 2);
4356+
CU_ASSERT_EQUAL_FATAL(tsk_tree_get_total_branch_length(&tree, 0, &length), 0);
4357+
CU_ASSERT_EQUAL_FATAL(length, 0);
4358+
CU_ASSERT_EQUAL_FATAL(tsk_tree_get_total_branch_length(&tree, 5, &length), 0);
4359+
CU_ASSERT_EQUAL_FATAL(length, 4);
4360+
4361+
CU_ASSERT_EQUAL_FATAL(tsk_tree_get_total_branch_length(&tree, -2, &length),
4362+
TSK_ERR_NODE_OUT_OF_BOUNDS);
4363+
CU_ASSERT_EQUAL_FATAL(
4364+
tsk_tree_get_total_branch_length(&tree, 8, &length), TSK_ERR_NODE_OUT_OF_BOUNDS);
4365+
4366+
tsk_tree_free(&tree);
4367+
tsk_treeseq_free(&ts);
4368+
}
4369+
43324370
static void
43334371
test_single_tree_map_mutations(void)
43344372
{
@@ -6605,6 +6643,7 @@ main(int argc, char **argv)
66056643
{ "test_single_tree_compute_mutation_times",
66066644
test_single_tree_compute_mutation_times },
66076645
{ "test_single_tree_is_descendant", test_single_tree_is_descendant },
6646+
{ "test_single_tree_total_branch_length", test_single_tree_total_branch_length },
66086647
{ "test_single_tree_map_mutations", test_single_tree_map_mutations },
66096648
{ "test_single_tree_map_mutations_internal_samples",
66106649
test_single_tree_map_mutations_internal_samples },

c/tskit/convert.c

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -154,8 +154,7 @@ tsk_newick_converter_init(tsk_newick_converter_t *self, const tsk_tree_t *tree,
154154
self->options = options;
155155
self->tree = tree;
156156
self->traversal_stack
157-
= tsk_malloc(tsk_treeseq_get_num_nodes(self->tree->tree_sequence)
158-
* sizeof(*self->traversal_stack));
157+
= tsk_malloc(tsk_tree_get_size_bound(tree) * sizeof(*self->traversal_stack));
159158
if (self->traversal_stack == NULL) {
160159
ret = TSK_ERR_NO_MEMORY;
161160
goto out;

c/tskit/trees.c

Lines changed: 91 additions & 90 deletions
Original file line numberDiff line numberDiff line change
@@ -3509,34 +3509,29 @@ tsk_tree_get_num_samples_by_traversal(
35093509
const tsk_tree_t *self, tsk_id_t u, tsk_size_t *num_samples)
35103510
{
35113511
int ret = 0;
3512-
tsk_id_t *stack = NULL;
3513-
tsk_id_t v;
3512+
tsk_size_t num_nodes, j;
35143513
tsk_size_t count = 0;
3515-
int stack_top = 0;
3514+
const tsk_flags_t *restrict flags = self->tree_sequence->tables->nodes.flags;
3515+
tsk_id_t *nodes = tsk_malloc(tsk_tree_get_size_bound(self) * sizeof(*nodes));
3516+
tsk_id_t v;
35163517

3517-
stack = tsk_malloc(self->num_nodes * sizeof(*stack));
3518-
if (stack == NULL) {
3518+
if (nodes == NULL) {
35193519
ret = TSK_ERR_NO_MEMORY;
35203520
goto out;
35213521
}
3522-
3523-
stack[0] = u;
3524-
while (stack_top >= 0) {
3525-
v = stack[stack_top];
3526-
stack_top--;
3527-
if (tsk_treeseq_is_sample(self->tree_sequence, v)) {
3522+
ret = tsk_tree_preorder(self, u, nodes, &num_nodes);
3523+
if (ret != 0) {
3524+
goto out;
3525+
}
3526+
for (j = 0; j < num_nodes; j++) {
3527+
v = nodes[j];
3528+
if (flags[v] & TSK_NODE_IS_SAMPLE) {
35283529
count++;
35293530
}
3530-
v = self->left_child[v];
3531-
while (v != TSK_NULL) {
3532-
stack_top++;
3533-
stack[stack_top] = v;
3534-
v = self->right_sib[v];
3535-
}
35363531
}
35373532
*num_samples = count;
35383533
out:
3539-
tsk_safe_free(stack);
3534+
tsk_safe_free(nodes);
35403535
return ret;
35413536
}
35423537

@@ -3636,6 +3631,40 @@ tsk_tree_get_time(const tsk_tree_t *self, tsk_id_t u, double *t)
36363631
return ret;
36373632
}
36383633

3634+
int
3635+
tsk_tree_get_total_branch_length(const tsk_tree_t *self, tsk_id_t node, double *ret_tbl)
3636+
{
3637+
int ret = 0;
3638+
tsk_size_t j, num_nodes;
3639+
tsk_id_t u, v;
3640+
const tsk_id_t *restrict parent = self->parent;
3641+
const double *restrict time = self->tree_sequence->tables->nodes.time;
3642+
tsk_id_t *nodes = tsk_malloc(tsk_tree_get_size_bound(self) * sizeof(*nodes));
3643+
double sum = 0;
3644+
3645+
if (nodes == NULL) {
3646+
ret = TSK_ERR_NO_MEMORY;
3647+
goto out;
3648+
}
3649+
ret = tsk_tree_preorder(self, node, nodes, &num_nodes);
3650+
if (ret != 0) {
3651+
goto out;
3652+
}
3653+
/* We always skip the first node because we don't return the branch length
3654+
* over the input node. */
3655+
for (j = 1; j < num_nodes; j++) {
3656+
u = nodes[j];
3657+
v = parent[u];
3658+
if (v != TSK_NULL) {
3659+
sum += time[v] - time[u];
3660+
}
3661+
}
3662+
*ret_tbl = sum;
3663+
out:
3664+
tsk_safe_free(nodes);
3665+
return ret;
3666+
}
3667+
36393668
int TSK_WARN_UNUSED
36403669
tsk_tree_get_sites(
36413670
const tsk_tree_t *self, const tsk_site_t **sites, tsk_size_t *sites_length)
@@ -3649,14 +3678,14 @@ tsk_tree_get_sites(
36493678
static int
36503679
tsk_tree_get_depth_unsafe(const tsk_tree_t *self, tsk_id_t u)
36513680
{
3652-
36533681
tsk_id_t v;
3682+
const tsk_id_t *restrict parent = self->parent;
36543683
int depth = 0;
36553684

36563685
if (u == self->virtual_root) {
36573686
return -1;
36583687
}
3659-
for (v = self->parent[u]; v != TSK_NULL; v = self->parent[v]) {
3688+
for (v = parent[u]; v != TSK_NULL; v = parent[v]) {
36603689
depth++;
36613690
}
36623691
return depth;
@@ -4443,6 +4472,9 @@ get_smallest_set_bit(uint64_t v)
44434472
* use a general cost matrix, in which case we'll use the Sankoff algorithm. For
44444473
* now this is unused.
44454474
*
4475+
* We should also vectorise the function so that several sites can be processed
4476+
* at once.
4477+
*
44464478
* The algorithm used here is Hartigan parsimony, "Minimum Mutation Fits to a
44474479
* Given Tree", Biometrics 1973.
44484480
*/
@@ -4458,30 +4490,34 @@ tsk_tree_map_mutations(tsk_tree_t *self, int8_t *genotypes,
44584490
int8_t state;
44594491
};
44604492
const tsk_size_t num_samples = self->tree_sequence->num_samples;
4461-
const tsk_size_t num_nodes = self->num_nodes;
44624493
const tsk_id_t *restrict left_child = self->left_child;
44634494
const tsk_id_t *restrict right_sib = self->right_sib;
4464-
const tsk_id_t *restrict parent = self->parent;
4495+
const tsk_size_t N = tsk_treeseq_get_num_nodes(self->tree_sequence);
44654496
const tsk_flags_t *restrict node_flags = self->tree_sequence->tables->nodes.flags;
4466-
uint64_t optimal_root_set;
4467-
uint64_t *restrict optimal_set = tsk_calloc(num_nodes, sizeof(*optimal_set));
4468-
tsk_id_t *restrict postorder_stack
4469-
= tsk_malloc(num_nodes * sizeof(*postorder_stack));
4497+
tsk_id_t *nodes = tsk_malloc(tsk_tree_get_size_bound(self) * sizeof(*nodes));
4498+
/* Note: to use less memory here and to improve cache performance we should
4499+
* probably change to allocating exactly the number of nodes returned by
4500+
* a preorder traversal, and then lay the memory out in this order. So, we'd
4501+
* need a map from node ID to its index in the preorder traversal, but this
4502+
* is trivial to compute. Probably doesn't matter so much at the moment
4503+
* when we're doing a single site, but it would make a big difference if
4504+
* we were vectorising over lots of sites. */
4505+
uint64_t *restrict optimal_set = tsk_calloc(N + 1, sizeof(*optimal_set));
44704506
struct stack_elem *restrict preorder_stack
4471-
= tsk_malloc(num_nodes * sizeof(*preorder_stack));
4472-
tsk_id_t postorder_parent, root, u, v;
4507+
= tsk_malloc(tsk_tree_get_size_bound(self) * sizeof(*preorder_stack));
4508+
tsk_id_t root, u, v;
44734509
/* The largest possible number of transitions is one over every sample */
44744510
tsk_state_transition_t *transitions = tsk_malloc(num_samples * sizeof(*transitions));
44754511
int8_t allele, ancestral_state;
44764512
int stack_top;
44774513
struct stack_elem s;
4478-
tsk_size_t j, num_transitions, max_allele_count;
4514+
tsk_size_t j, num_transitions, max_allele_count, num_nodes;
44794515
tsk_size_t allele_count[HARTIGAN_MAX_ALLELES];
44804516
tsk_size_t non_missing = 0;
44814517
int8_t num_alleles = 0;
44824518

4483-
if (optimal_set == NULL || preorder_stack == NULL || postorder_stack == NULL
4484-
|| transitions == NULL) {
4519+
if (optimal_set == NULL || preorder_stack == NULL || transitions == NULL
4520+
|| nodes == NULL) {
44854521
ret = TSK_ERR_NO_MEMORY;
44864522
goto out;
44874523
}
@@ -4518,68 +4554,33 @@ tsk_tree_map_mutations(tsk_tree_t *self, int8_t *genotypes,
45184554
}
45194555
}
45204556

4521-
for (root = self->left_root; root != TSK_NULL; root = self->right_sib[root]) {
4522-
/* Do a post order traversal */
4523-
postorder_stack[0] = root;
4524-
stack_top = 0;
4525-
postorder_parent = TSK_NULL;
4526-
while (stack_top >= 0) {
4527-
u = postorder_stack[stack_top];
4528-
if (left_child[u] != TSK_NULL && u != postorder_parent) {
4529-
for (v = left_child[u]; v != TSK_NULL; v = right_sib[v]) {
4530-
stack_top++;
4531-
postorder_stack[stack_top] = v;
4532-
}
4533-
} else {
4534-
stack_top--;
4535-
postorder_parent = parent[u];
4536-
4537-
/* Visit u */
4538-
tsk_memset(
4539-
allele_count, 0, ((size_t) num_alleles) * sizeof(*allele_count));
4540-
for (v = left_child[u]; v != TSK_NULL; v = right_sib[v]) {
4541-
for (allele = 0; allele < num_alleles; allele++) {
4542-
allele_count[allele] += bit_is_set(optimal_set[v], allele);
4543-
}
4544-
}
4545-
if (!(node_flags[u] & TSK_NODE_IS_SAMPLE)) {
4546-
max_allele_count = 0;
4547-
for (allele = 0; allele < num_alleles; allele++) {
4548-
max_allele_count
4549-
= TSK_MAX(max_allele_count, allele_count[allele]);
4550-
}
4551-
for (allele = 0; allele < num_alleles; allele++) {
4552-
if (allele_count[allele] == max_allele_count) {
4553-
optimal_set[u] = set_bit(optimal_set[u], allele);
4554-
}
4555-
}
4556-
}
4557-
}
4558-
}
4557+
ret = tsk_tree_postorder(self, self->virtual_root, nodes, &num_nodes);
4558+
if (ret != 0) {
4559+
goto out;
45594560
}
4560-
4561-
if (!(options & TSK_MM_FIXED_ANCESTRAL_STATE)) {
4562-
optimal_root_set = 0;
4563-
/* TODO it's annoying that this is essentially the same as the
4564-
* visit function above. It would be nice if we had an extra
4565-
* node that was the parent of all roots, then the algorithm
4566-
* would work as-is */
4561+
for (j = 0; j < num_nodes; j++) {
4562+
u = nodes[j];
45674563
tsk_memset(allele_count, 0, ((size_t) num_alleles) * sizeof(*allele_count));
4568-
for (root = self->left_root; root != TSK_NULL; root = right_sib[root]) {
4564+
for (v = left_child[u]; v != TSK_NULL; v = right_sib[v]) {
45694565
for (allele = 0; allele < num_alleles; allele++) {
4570-
allele_count[allele] += bit_is_set(optimal_set[root], allele);
4566+
allele_count[allele] += bit_is_set(optimal_set[v], allele);
45714567
}
45724568
}
4573-
max_allele_count = 0;
4574-
for (allele = 0; allele < num_alleles; allele++) {
4575-
max_allele_count = TSK_MAX(max_allele_count, allele_count[allele]);
4576-
}
4577-
for (allele = 0; allele < num_alleles; allele++) {
4578-
if (allele_count[allele] == max_allele_count) {
4579-
optimal_root_set = set_bit(optimal_root_set, allele);
4569+
/* the virtual root has no flags defined */
4570+
if (u == (tsk_id_t) N || !(node_flags[u] & TSK_NODE_IS_SAMPLE)) {
4571+
max_allele_count = 0;
4572+
for (allele = 0; allele < num_alleles; allele++) {
4573+
max_allele_count = TSK_MAX(max_allele_count, allele_count[allele]);
4574+
}
4575+
for (allele = 0; allele < num_alleles; allele++) {
4576+
if (allele_count[allele] == max_allele_count) {
4577+
optimal_set[u] = set_bit(optimal_set[u], allele);
4578+
}
45804579
}
45814580
}
4582-
ancestral_state = get_smallest_set_bit(optimal_root_set);
4581+
}
4582+
if (!(options & TSK_MM_FIXED_ANCESTRAL_STATE)) {
4583+
ancestral_state = get_smallest_set_bit(optimal_set[N]);
45834584
}
45844585

45854586
num_transitions = 0;
@@ -4622,8 +4623,8 @@ tsk_tree_map_mutations(tsk_tree_t *self, int8_t *genotypes,
46224623
if (preorder_stack != NULL) {
46234624
free(preorder_stack);
46244625
}
4625-
if (postorder_stack != NULL) {
4626-
free(postorder_stack);
4626+
if (nodes != NULL) {
4627+
free(nodes);
46274628
}
46284629
return ret;
46294630
}
@@ -4888,7 +4889,7 @@ fill_kc_vectors(const tsk_tree_t *t, kc_vectors *kc_vecs)
48884889
int ret = 0;
48894890
const tsk_treeseq_t *ts = t->tree_sequence;
48904891

4891-
stack = tsk_malloc(t->num_nodes * sizeof(*stack));
4892+
stack = tsk_malloc(tsk_tree_get_size_bound(t) * sizeof(*stack));
48924893
if (stack == NULL) {
48934894
ret = TSK_ERR_NO_MEMORY;
48944895
goto out;
@@ -5094,7 +5095,7 @@ update_kc_subtree_state(
50945095
tsk_id_t *stack = NULL;
50955096
int ret = 0;
50965097

5097-
stack = tsk_malloc(t->num_nodes * sizeof(*stack));
5098+
stack = tsk_malloc(tsk_tree_get_size_bound(t) * sizeof(*stack));
50985099
if (stack == NULL) {
50995100
ret = TSK_ERR_NO_MEMORY;
51005101
goto out;

c/tskit/trees.h

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -436,6 +436,31 @@ be greater than or equal to ``num_nodes``.
436436
*/
437437
tsk_size_t tsk_tree_get_size_bound(const tsk_tree_t *self);
438438

439+
/**
440+
@brief Returns the sum of the lengths of all branches reachable from
441+
the specified node, or from all roots if node=TSK_NULL.
442+
443+
@rst
444+
Return the total branch length in a particular subtree or of the
445+
entire tree. If the specified node is TSK_NULL (or the virtual
446+
root) the sum of the lengths of all branches reachable from roots
447+
is returned. Branch length is defined as difference between the time
448+
of a node and its parent. The branch length of a root is zero.
449+
450+
Note that if the specified node is internal its branch length is
451+
*not* included, so that, e.g., the total branch length of a
452+
leaf node is zero.
453+
@endrst
454+
455+
@param self A pointer to a tsk_tree_t object.
456+
@param node The tree node to compute branch length or TSK_NULL to return the
457+
total branch length of the tree.
458+
@param ret_tbl A double pointer to store the returned total branch length.
459+
@return 0 on success or a negative value on failure.
460+
*/
461+
int tsk_tree_get_total_branch_length(
462+
const tsk_tree_t *self, tsk_id_t node, double *ret_tbl);
463+
439464
/** @} */
440465

441466
int tsk_tree_set_root_threshold(tsk_tree_t *self, tsk_size_t root_threshold);

python/CHANGELOG.rst

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,8 @@
6060
Roughly a 10X performance increase for "preorder", "postorder", "timeasc"
6161
and "timedesc" (:user:`jeromekelleher`, :pr:`1704`).
6262

63+
- Substantial performance improvement for ``Tree.total_branch_length``
64+
(:user:`jeromekelleher`, :issue:`1794` :pr:`1799`)
6365

6466
--------------------
6567
[0.3.7] - 2021-07-08

0 commit comments

Comments
 (0)