Skip to content

Commit 9e47224

Browse files
author
brieuc-mac
committed
Adjust tests and change default arguments
1 parent dffe79b commit 9e47224

File tree

3 files changed

+50
-132
lines changed

3 files changed

+50
-132
lines changed

c/tskit/trees.c

Lines changed: 0 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -2760,27 +2760,17 @@ relatedness_summary_func(size_t state_dim, const double *state,
27602760
const double *x = state;
27612761
tsk_id_t i, j;
27622762
size_t k;
2763-
int c = 0;
27642763
double sumx = 0;
27652764
double meanx;
2766-
double num = 0;
27672765

27682766
for (k = 0; k < state_dim; k++) {
27692767
sumx += x[k];
27702768
}
27712769

2772-
for (k = 0; k < state_dim; k++) {
2773-
num += args.sample_set_sizes[k];
2774-
}
2775-
2776-
if (num != sumx) {
2777-
c = 1;
2778-
}
27792770
meanx = sumx / (double) state_dim;
27802771
for (k = 0; k < result_dim; k++) {
27812772
i = args.set_indexes[2 * k];
27822773
j = args.set_indexes[2 * k + 1];
2783-
// result[k] = x[i] * x[j] * c;
27842774
result[k] = (x[i] - meanx) * (x[j] - meanx);
27852775
}
27862776
return 0;

python/tests/test_covariance.py

Lines changed: 38 additions & 101 deletions
Original file line numberDiff line numberDiff line change
@@ -29,83 +29,22 @@
2929

3030
import msprime
3131
import numpy as np
32-
import pytest
3332

34-
import tests.tsutil as tsutil
3533
import tskit
3634

3735

38-
def check_cov_tree_inputs(tree):
39-
if not len(tree.roots) == 1:
40-
raise ValueError("Trees must have one root")
41-
for u in tree.nodes():
42-
if tree.num_children(u) == 1:
43-
raise ValueError("Unary nodes are not supported")
44-
45-
46-
def naive_tree_covariance(tree):
47-
"""
48-
Returns the (branch) covariance matrix for the sample nodes in a tree. The
49-
covariance between a pair of nodes is the distance from the root to their
50-
most recent common ancestor.
51-
"""
52-
samples = tree.tree_sequence.samples()
53-
check_cov_tree_inputs(tree)
54-
n = samples.shape[0]
55-
cov = np.zeros((n, n))
56-
for n1, n2 in itertools.combinations_with_replacement(range(n), 2):
57-
mrca = tree.mrca(samples[n1], samples[n2])
58-
cov[n1, n2] = tree.time(tree.root) - tree.time(mrca)
59-
cov[n2, n1] = cov[n1, n2]
60-
return cov
61-
62-
63-
def naive_ts_covariance(ts):
64-
"""
65-
Returns the (branch) covariance matrix for the sample nodes in a tree
66-
sequence. The covariance between a pair of nodes is the weighted sum of the
67-
tree covariance, with the weights given by the spans of the trees.
68-
"""
69-
samples = ts.samples()
70-
n = samples.shape[0]
71-
cov = np.zeros((n, n))
72-
for tree in ts.trees():
73-
cov += naive_tree_covariance(tree) * tree.span
74-
return cov / ts.sequence_length
75-
76-
77-
def naive_genotype_covariance(ts):
36+
def naive_genotype_covariance(ts, proportion=False):
7837
G = ts.genotype_matrix()
79-
# p = G.shape[0]
38+
denominator = ts.sequence_length
39+
if proportion:
40+
all_samples = ts.samples()
41+
num = ts.segregating_sites(all_samples)
42+
denominator = denominator * num
8043
G = G.T - np.mean(G, axis=1)
81-
return G @ G.T # / p
82-
83-
84-
def genetic_relatedness(ts, mode="site", polarised=True):
85-
# NOTE: I'm outputting a matrix here just for convenience; the proposal
86-
# is that the tskit method *not* output a matrix, and use the indices argument
87-
sample_sets = [[u] for u in ts.samples()]
88-
# sample_sets = [[0], [1]]
89-
n = len(sample_sets)
90-
num_samples = sum(map(len, sample_sets))
91-
92-
def f(x):
93-
# x[i] gives the number of descendants in sample set i below the branch
94-
return np.array(
95-
[x[i] * x[j] * (sum(x) != num_samples) for i in range(n) for j in range(n)]
96-
)
44+
return G @ G.T / denominator
9745

98-
return ts.sample_count_stat(
99-
sample_sets,
100-
f,
101-
output_dim=n * n,
102-
mode=mode,
103-
span_normalise=True,
104-
polarised=polarised,
105-
).reshape((n, n))
10646

107-
108-
def genotype_relatedness(ts, polarised=False):
47+
def genotype_relatedness(ts, polarised=False, proportion=False):
10948
n = ts.num_samples
11049
sample_sets = [[u] for u in ts.samples()]
11150

@@ -118,20 +57,25 @@ def f(x):
11857
]
11958
)
12059

60+
denominator = 2 - polarised
61+
if proportion:
62+
all_samples = list({u for s in sample_sets for u in s})
63+
num = ts.segregating_sites(all_samples)
64+
denominator = denominator * num
12165
return (
12266
ts.sample_count_stat(
12367
sample_sets,
12468
f,
12569
output_dim=n * n,
12670
mode="site",
127-
span_normalise=False,
71+
span_normalise=True,
12872
polarised=polarised,
12973
).reshape((n, n))
130-
/ 2
74+
/ denominator
13175
)
13276

13377

134-
def c_genotype_relatedness(ts, sample_sets, indexes):
78+
def c_genotype_relatedness(ts, sample_sets, indexes, polarised=False, proportion=False):
13579
m = len(indexes)
13680
state_dim = len(sample_sets)
13781

@@ -144,17 +88,25 @@ def f(x):
14488
for k in range(m):
14589
i = indexes[k][0]
14690
j = indexes[k][1]
147-
result[k] = (x[i] - meanx) * (x[j] - meanx) / 2
91+
result[k] = (x[i] - meanx) * (x[j] - meanx)
14892
return result
14993

150-
return ts.sample_count_stat(
151-
sample_sets,
152-
f,
153-
output_dim=m,
154-
mode="site",
155-
span_normalise=False,
156-
polarised=False,
157-
strict=False,
94+
denominator = 2 - polarised
95+
if proportion:
96+
all_samples = list({u for s in sample_sets for u in s})
97+
num = ts.segregating_sites(all_samples)
98+
denominator = denominator * num
99+
return (
100+
ts.sample_count_stat(
101+
sample_sets,
102+
f,
103+
output_dim=m,
104+
mode="site",
105+
span_normalise=True,
106+
polarised=False,
107+
strict=False,
108+
)
109+
/ denominator
158110
)
159111

160112

@@ -164,9 +116,7 @@ class TestCovariance(unittest.TestCase):
164116
"""
165117

166118
def verify(self, ts):
167-
# cov1 = naive_ts_covariance(ts)
168119
cov1 = naive_genotype_covariance(ts)
169-
# cov2 = genetic_relatedness(ts)
170120
cov2 = genotype_relatedness(ts)
171121
sample_sets = [[u] for u in ts.samples()]
172122
n = len(sample_sets)
@@ -176,28 +126,15 @@ def verify(self, ts):
176126
cov3 = np.zeros((n, n))
177127
cov4 = np.zeros((n, n))
178128
i_upper = np.triu_indices(n)
179-
cov3[i_upper] = (
180-
ts.genetic_relatedness(
181-
sample_sets, indexes, mode="site", span_normalise=False
182-
)
183-
/ 2
184-
) # NOTE: divided by 2 to reflect unpolarised
129+
cov3[i_upper] = c_genotype_relatedness(ts, sample_sets, indexes)
185130
cov3 = cov3 + cov3.T - np.diag(cov3.diagonal())
186-
cov4[i_upper] = c_genotype_relatedness(ts, sample_sets, indexes)
131+
cov4[i_upper] = ts.genetic_relatedness(
132+
sample_sets, indexes, mode="site", span_normalise=True
133+
)
187134
cov4 = cov4 + cov4.T - np.diag(cov4.diagonal())
188-
# assert np.allclose(cov2, cov3)
189135
assert np.allclose(cov1, cov2)
190-
assert np.allclose(cov1, cov4)
191136
assert np.allclose(cov1, cov3)
192-
193-
def verify_errors(self, ts):
194-
with pytest.raises(ValueError):
195-
naive_ts_covariance(ts)
196-
197-
def test_errors_multiroot_tree(self):
198-
ts = msprime.simulate(15, random_seed=10, mutation_rate=1)
199-
ts = tsutil.decapitate(ts, ts.num_edges // 2)
200-
self.verify_errors(ts)
137+
assert np.allclose(cov1, cov4)
201138

202139
def test_single_coalescent_tree(self):
203140
ts = msprime.simulate(10, random_seed=1, length=10, mutation_rate=1)

python/tskit/trees.py

Lines changed: 12 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -5499,7 +5499,14 @@ def divergence(
54995499
# return A
55005500

55015501
def genetic_relatedness(
5502-
self, sample_sets, indexes=None, windows=None, mode="site", span_normalise=True
5502+
self,
5503+
sample_sets,
5504+
indexes=None,
5505+
windows=None,
5506+
mode="site",
5507+
span_normalise=True,
5508+
polarised=False,
5509+
proportion=True,
55035510
):
55045511
"""
55055512
Computes genetic relatedness between (and within) pairs of
@@ -5512,27 +5519,9 @@ def genetic_relatedness(
55125519
:ref:`windows <sec_stats_windows>`,
55135520
:ref:`mode <sec_stats_mode>`,
55145521
:ref:`span normalise <sec_stats_span_normalise>`,
5522+
:ref:`polarised <sec_stats_polarisation>`,
55155523
and :ref:`return value <sec_stats_output_format>`.
55165524
5517-
What is computed depends on ``mode``:
5518-
5519-
"site"
5520-
Mean pairwise genetic divergence: the average across distinct,
5521-
randomly chosen pairs of chromosomes (one from each sample set), of
5522-
the density of sites at which the two carry different alleles, per
5523-
unit of chromosome length.
5524-
5525-
"branch"
5526-
Mean distance in the tree: the average across distinct, randomly
5527-
chosen pairs of chromosomes (one from each sample set) and locations
5528-
in the window, of the mean distance in the tree between the two
5529-
samples (in units of time).
5530-
5531-
"node"
5532-
For each node, the proportion of genome on which the node is an ancestor to
5533-
only one of a random pair (one from each sample set), averaged over
5534-
choices of pair.
5535-
55365525
:param list sample_sets: A list of lists of Node IDs, specifying the
55375526
groups of nodes to compute the statistic with.
55385527
:param list indexes: A list of 2-tuples, or None.
@@ -5542,6 +5531,8 @@ def genetic_relatedness(
55425531
(defaults to "site").
55435532
:param bool span_normalise: Whether to divide the result by the span of the
55445533
window (defaults to True).
5534+
:param bool proportion: Whether to divide the result by the number of
5535+
segregating sites (defaults to True).
55455536
:return: A ndarray with shape equal to (num windows, num statistics).
55465537
"""
55475538
return self.__k_way_sample_set_stat(
@@ -5552,7 +5543,7 @@ def genetic_relatedness(
55525543
windows=windows,
55535544
mode=mode,
55545545
span_normalise=span_normalise,
5555-
polarised=False,
5546+
polarised=polarised,
55565547
)
55575548

55585549
def trait_covariance(self, W, windows=None, mode="site", span_normalise=True):

0 commit comments

Comments
 (0)