29
29
30
30
import msprime
31
31
import numpy as np
32
- import pytest
33
32
34
- import tests .tsutil as tsutil
35
33
import tskit
36
34
37
35
38
- def check_cov_tree_inputs (tree ):
39
- if not len (tree .roots ) == 1 :
40
- raise ValueError ("Trees must have one root" )
41
- for u in tree .nodes ():
42
- if tree .num_children (u ) == 1 :
43
- raise ValueError ("Unary nodes are not supported" )
44
-
45
-
46
- def naive_tree_covariance (tree ):
47
- """
48
- Returns the (branch) covariance matrix for the sample nodes in a tree. The
49
- covariance between a pair of nodes is the distance from the root to their
50
- most recent common ancestor.
51
- """
52
- samples = tree .tree_sequence .samples ()
53
- check_cov_tree_inputs (tree )
54
- n = samples .shape [0 ]
55
- cov = np .zeros ((n , n ))
56
- for n1 , n2 in itertools .combinations_with_replacement (range (n ), 2 ):
57
- mrca = tree .mrca (samples [n1 ], samples [n2 ])
58
- cov [n1 , n2 ] = tree .time (tree .root ) - tree .time (mrca )
59
- cov [n2 , n1 ] = cov [n1 , n2 ]
60
- return cov
61
-
62
-
63
- def naive_ts_covariance (ts ):
64
- """
65
- Returns the (branch) covariance matrix for the sample nodes in a tree
66
- sequence. The covariance between a pair of nodes is the weighted sum of the
67
- tree covariance, with the weights given by the spans of the trees.
68
- """
69
- samples = ts .samples ()
70
- n = samples .shape [0 ]
71
- cov = np .zeros ((n , n ))
72
- for tree in ts .trees ():
73
- cov += naive_tree_covariance (tree ) * tree .span
74
- return cov / ts .sequence_length
75
-
76
-
77
- def naive_genotype_covariance (ts ):
36
+ def naive_genotype_covariance (ts , proportion = False ):
78
37
G = ts .genotype_matrix ()
79
- # p = G.shape[0]
38
+ denominator = ts .sequence_length
39
+ if proportion :
40
+ all_samples = ts .samples ()
41
+ num = ts .segregating_sites (all_samples )
42
+ denominator = denominator * num
80
43
G = G .T - np .mean (G , axis = 1 )
81
- return G @ G .T # / p
82
-
83
-
84
- def genetic_relatedness (ts , mode = "site" , polarised = True ):
85
- # NOTE: I'm outputting a matrix here just for convenience; the proposal
86
- # is that the tskit method *not* output a matrix, and use the indices argument
87
- sample_sets = [[u ] for u in ts .samples ()]
88
- # sample_sets = [[0], [1]]
89
- n = len (sample_sets )
90
- num_samples = sum (map (len , sample_sets ))
91
-
92
- def f (x ):
93
- # x[i] gives the number of descendants in sample set i below the branch
94
- return np .array (
95
- [x [i ] * x [j ] * (sum (x ) != num_samples ) for i in range (n ) for j in range (n )]
96
- )
44
+ return G @ G .T / denominator
97
45
98
- return ts .sample_count_stat (
99
- sample_sets ,
100
- f ,
101
- output_dim = n * n ,
102
- mode = mode ,
103
- span_normalise = True ,
104
- polarised = polarised ,
105
- ).reshape ((n , n ))
106
46
107
-
108
- def genotype_relatedness (ts , polarised = False ):
47
+ def genotype_relatedness (ts , polarised = False , proportion = False ):
109
48
n = ts .num_samples
110
49
sample_sets = [[u ] for u in ts .samples ()]
111
50
@@ -118,20 +57,25 @@ def f(x):
118
57
]
119
58
)
120
59
60
+ denominator = 2 - polarised
61
+ if proportion :
62
+ all_samples = list ({u for s in sample_sets for u in s })
63
+ num = ts .segregating_sites (all_samples )
64
+ denominator = denominator * num
121
65
return (
122
66
ts .sample_count_stat (
123
67
sample_sets ,
124
68
f ,
125
69
output_dim = n * n ,
126
70
mode = "site" ,
127
- span_normalise = False ,
71
+ span_normalise = True ,
128
72
polarised = polarised ,
129
73
).reshape ((n , n ))
130
- / 2
74
+ / denominator
131
75
)
132
76
133
77
134
- def c_genotype_relatedness (ts , sample_sets , indexes ):
78
+ def c_genotype_relatedness (ts , sample_sets , indexes , polarised = False , proportion = False ):
135
79
m = len (indexes )
136
80
state_dim = len (sample_sets )
137
81
@@ -144,17 +88,25 @@ def f(x):
144
88
for k in range (m ):
145
89
i = indexes [k ][0 ]
146
90
j = indexes [k ][1 ]
147
- result [k ] = (x [i ] - meanx ) * (x [j ] - meanx ) / 2
91
+ result [k ] = (x [i ] - meanx ) * (x [j ] - meanx )
148
92
return result
149
93
150
- return ts .sample_count_stat (
151
- sample_sets ,
152
- f ,
153
- output_dim = m ,
154
- mode = "site" ,
155
- span_normalise = False ,
156
- polarised = False ,
157
- strict = False ,
94
+ denominator = 2 - polarised
95
+ if proportion :
96
+ all_samples = list ({u for s in sample_sets for u in s })
97
+ num = ts .segregating_sites (all_samples )
98
+ denominator = denominator * num
99
+ return (
100
+ ts .sample_count_stat (
101
+ sample_sets ,
102
+ f ,
103
+ output_dim = m ,
104
+ mode = "site" ,
105
+ span_normalise = True ,
106
+ polarised = False ,
107
+ strict = False ,
108
+ )
109
+ / denominator
158
110
)
159
111
160
112
@@ -164,9 +116,7 @@ class TestCovariance(unittest.TestCase):
164
116
"""
165
117
166
118
def verify (self , ts ):
167
- # cov1 = naive_ts_covariance(ts)
168
119
cov1 = naive_genotype_covariance (ts )
169
- # cov2 = genetic_relatedness(ts)
170
120
cov2 = genotype_relatedness (ts )
171
121
sample_sets = [[u ] for u in ts .samples ()]
172
122
n = len (sample_sets )
@@ -176,28 +126,15 @@ def verify(self, ts):
176
126
cov3 = np .zeros ((n , n ))
177
127
cov4 = np .zeros ((n , n ))
178
128
i_upper = np .triu_indices (n )
179
- cov3 [i_upper ] = (
180
- ts .genetic_relatedness (
181
- sample_sets , indexes , mode = "site" , span_normalise = False
182
- )
183
- / 2
184
- ) # NOTE: divided by 2 to reflect unpolarised
129
+ cov3 [i_upper ] = c_genotype_relatedness (ts , sample_sets , indexes )
185
130
cov3 = cov3 + cov3 .T - np .diag (cov3 .diagonal ())
186
- cov4 [i_upper ] = c_genotype_relatedness (ts , sample_sets , indexes )
131
+ cov4 [i_upper ] = ts .genetic_relatedness (
132
+ sample_sets , indexes , mode = "site" , span_normalise = True
133
+ )
187
134
cov4 = cov4 + cov4 .T - np .diag (cov4 .diagonal ())
188
- # assert np.allclose(cov2, cov3)
189
135
assert np .allclose (cov1 , cov2 )
190
- assert np .allclose (cov1 , cov4 )
191
136
assert np .allclose (cov1 , cov3 )
192
-
193
- def verify_errors (self , ts ):
194
- with pytest .raises (ValueError ):
195
- naive_ts_covariance (ts )
196
-
197
- def test_errors_multiroot_tree (self ):
198
- ts = msprime .simulate (15 , random_seed = 10 , mutation_rate = 1 )
199
- ts = tsutil .decapitate (ts , ts .num_edges // 2 )
200
- self .verify_errors (ts )
137
+ assert np .allclose (cov1 , cov4 )
201
138
202
139
def test_single_coalescent_tree (self ):
203
140
ts = msprime .simulate (10 , random_seed = 1 , length = 10 , mutation_rate = 1 )
0 commit comments