diff --git a/c/tests/test_stats.c b/c/tests/test_stats.c index 63ac3427cf..b6cd6d770c 100644 --- a/c/tests/test_stats.c +++ b/c/tests/test_stats.c @@ -1333,6 +1333,36 @@ test_paper_ex_divergence(void) tsk_treeseq_free(&ts); } +static void +test_paper_ex_relatedness(void) +{ + tsk_treeseq_t ts; + tsk_id_t samples[] = { 0, 1, 2, 3 }; + tsk_size_t sample_set_sizes[] = { 2, 2 }; + tsk_id_t set_indexes[] = { 0, 0 }; + double result; + int ret; + + tsk_treeseq_from_text(&ts, 10, paper_ex_nodes, paper_ex_edges, NULL, paper_ex_sites, + paper_ex_mutations, paper_ex_individuals, NULL, 0); + + ret = tsk_treeseq_relatedness(&ts, 2, sample_set_sizes, samples, 1, set_indexes, 0, + NULL, &result, TSK_STAT_SITE); + CU_ASSERT_EQUAL_FATAL(ret, 0); + tsk_treeseq_free(&ts); +} + +static void +test_paper_ex_relatedness_errors(void) +{ + tsk_treeseq_t ts; + + tsk_treeseq_from_text(&ts, 10, paper_ex_nodes, paper_ex_edges, NULL, paper_ex_sites, + paper_ex_mutations, paper_ex_individuals, NULL, 0); + verify_two_way_stat_func_errors(&ts, tsk_treeseq_relatedness); + tsk_treeseq_free(&ts); +} + static void test_paper_ex_Y2_errors(void) { @@ -1679,6 +1709,8 @@ main(int argc, char **argv) { "test_paper_ex_Y1", test_paper_ex_Y1 }, { "test_paper_ex_divergence_errors", test_paper_ex_divergence_errors }, { "test_paper_ex_divergence", test_paper_ex_divergence }, + { "test_paper_ex_relatedness_errors", test_paper_ex_relatedness_errors }, + { "test_paper_ex_relatedness", test_paper_ex_relatedness }, { "test_paper_ex_Y2_errors", test_paper_ex_Y2_errors }, { "test_paper_ex_Y2", test_paper_ex_Y2 }, { "test_paper_ex_f2_errors", test_paper_ex_f2_errors }, diff --git a/c/tskit/trees.c b/c/tskit/trees.c index 11d2890f97..4a130c0d47 100644 --- a/c/tskit/trees.c +++ b/c/tskit/trees.c @@ -2752,6 +2752,48 @@ tsk_treeseq_divergence(const tsk_treeseq_t *self, tsk_size_t num_sample_sets, return ret; } +static int +relatedness_summary_func(size_t state_dim, const double *state, size_t result_dim, + double *result, void *params) +{ + sample_count_stat_params_t args = *(sample_count_stat_params_t *) params; + const double *x = state; + tsk_id_t i, j; + size_t k; + double sumx = 0; + double meanx; + + for (k = 0; k < state_dim; k++) { + sumx += x[k]; + } + + meanx = sumx / (double) state_dim; + for (k = 0; k < result_dim; k++) { + i = args.set_indexes[2 * k]; + j = args.set_indexes[2 * k + 1]; + result[k] = (x[i] - meanx) * (x[j] - meanx); + } + return 0; +} + +int +tsk_treeseq_relatedness(const tsk_treeseq_t *self, tsk_size_t num_sample_sets, + const tsk_size_t *sample_set_sizes, const tsk_id_t *sample_sets, + tsk_size_t num_index_tuples, const tsk_id_t *index_tuples, tsk_size_t num_windows, + const double *windows, double *result, tsk_flags_t options) +{ + int ret = 0; + ret = check_sample_stat_inputs(num_sample_sets, 2, num_index_tuples, index_tuples); + if (ret != 0) { + goto out; + } + ret = tsk_treeseq_sample_count_stat(self, num_sample_sets, sample_set_sizes, + sample_sets, num_index_tuples, index_tuples, relatedness_summary_func, + num_windows, windows, result, options); +out: + return ret; +} + static int Y2_summary_func(size_t TSK_UNUSED(state_dim), const double *state, size_t result_dim, double *result, void *params) diff --git a/c/tskit/trees.h b/c/tskit/trees.h index 8eb8427478..762d8afcda 100644 --- a/c/tskit/trees.h +++ b/c/tskit/trees.h @@ -338,8 +338,6 @@ int tsk_treeseq_allele_frequency_spectrum(const tsk_treeseq_t *self, const tsk_id_t *sample_sets, tsk_size_t num_windows, const double *windows, double *result, tsk_flags_t options); -/* Two way sample set stats */ - typedef int general_sample_stat_method(const tsk_treeseq_t *self, tsk_size_t num_sample_sets, const tsk_size_t *sample_set_sizes, const tsk_id_t *sample_sets, tsk_size_t num_indexes, const tsk_id_t *indexes, @@ -357,6 +355,10 @@ int tsk_treeseq_f2(const tsk_treeseq_t *self, tsk_size_t num_sample_sets, const tsk_size_t *sample_set_sizes, const tsk_id_t *sample_sets, tsk_size_t num_index_tuples, const tsk_id_t *index_tuples, tsk_size_t num_windows, const double *windows, double *result, tsk_flags_t options); +int tsk_treeseq_relatedness(const tsk_treeseq_t *self, tsk_size_t num_sample_sets, + const tsk_size_t *sample_set_sizes, const tsk_id_t *sample_sets, + tsk_size_t num_index_tuples, const tsk_id_t *index_tuples, tsk_size_t num_windows, + const double *windows, double *result, tsk_flags_t options); /* Three way sample set stats */ int tsk_treeseq_Y3(const tsk_treeseq_t *self, tsk_size_t num_sample_sets, diff --git a/python/_tskitmodule.c b/python/_tskitmodule.c index 12754d7ea5..bfc695ac65 100644 --- a/python/_tskitmodule.c +++ b/python/_tskitmodule.c @@ -6946,7 +6946,7 @@ TreeSequence_k_way_stat_method(TreeSequence *self, PyObject *args, PyObject *kwd { PyObject *ret = NULL; static char *kwlist[] = { "sample_set_sizes", "sample_sets", "indexes", "windows", - "mode", "span_normalise", NULL }; + "mode", "span_normalise", "polarised", NULL }; PyObject *sample_set_sizes = NULL; PyObject *sample_sets = NULL; PyObject *indexes = NULL; @@ -6960,14 +6960,15 @@ TreeSequence_k_way_stat_method(TreeSequence *self, PyObject *args, PyObject *kwd npy_intp *shape; tsk_flags_t options = 0; char *mode = NULL; - int span_normalise = 1; + int span_normalise = true; + int polarised = false; int err; if (TreeSequence_check_tree_sequence(self) != 0) { goto out; } - if (!PyArg_ParseTupleAndKeywords(args, kwds, "OOOO|si", kwlist, &sample_set_sizes, - &sample_sets, &indexes, &windows, &mode, &span_normalise)) { + if (!PyArg_ParseTupleAndKeywords(args, kwds, "OOOO|sii", kwlist, &sample_set_sizes, + &sample_sets, &indexes, &windows, &mode, &span_normalise, &polarised)) { goto out; } if (parse_stats_mode(mode, &options) != 0) { @@ -6976,6 +6977,9 @@ TreeSequence_k_way_stat_method(TreeSequence *self, PyObject *args, PyObject *kwd if (span_normalise) { options |= TSK_STAT_SPAN_NORMALISE; } + if (polarised) { + options |= TSK_STAT_POLARISED; + } if (parse_sample_sets(sample_set_sizes, &sample_set_sizes_array, sample_sets, &sample_sets_array, &num_sample_sets) != 0) { @@ -7028,6 +7032,12 @@ TreeSequence_divergence(TreeSequence *self, PyObject *args, PyObject *kwds) return TreeSequence_k_way_stat_method(self, args, kwds, 2, tsk_treeseq_divergence); } +static PyObject * +TreeSequence_relatedness(TreeSequence *self, PyObject *args, PyObject *kwds) +{ + return TreeSequence_k_way_stat_method(self, args, kwds, 2, tsk_treeseq_relatedness); +} + static PyObject * TreeSequence_Y2(TreeSequence *self, PyObject *args, PyObject *kwds) { @@ -7345,6 +7355,10 @@ static PyMethodDef TreeSequence_methods[] = { .ml_meth = (PyCFunction) TreeSequence_divergence, .ml_flags = METH_VARARGS | METH_KEYWORDS, .ml_doc = "Computes diveregence between sample sets." }, + { .ml_name = "relatedness", + .ml_meth = (PyCFunction) TreeSequence_relatedness, + .ml_flags = METH_VARARGS | METH_KEYWORDS, + .ml_doc = "Computes genetic relatedness between sample sets." }, { .ml_name = "Y1", .ml_meth = (PyCFunction) TreeSequence_Y1, .ml_flags = METH_VARARGS | METH_KEYWORDS, diff --git a/python/tests/test_covariance.py b/python/tests/test_covariance.py new file mode 100644 index 0000000000..11891cc29e --- /dev/null +++ b/python/tests/test_covariance.py @@ -0,0 +1,242 @@ +# MIT License +# +# Copyright (c) 2018-2020 Tskit Developers +# Copyright (c) 2016-2017 University of Oxford +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. +""" +Test cases for covariance computation. +""" +import io +import itertools +import unittest + +import msprime +import numpy as np + +import tskit + + +def naive_genotype_covariance(ts, proportion=False): + G = ts.genotype_matrix() + denominator = ts.sequence_length + if proportion: + all_samples = ts.samples() + num = ts.segregating_sites(all_samples) + denominator = denominator * num + G = G.T - np.mean(G, axis=1) + return G @ G.T / denominator + + +def genotype_relatedness(ts, polarised=False, proportion=False): + n = ts.num_samples + sample_sets = [[u] for u in ts.samples()] + + def f(x): + return np.array( + [ + (x[i] - sum(x) / n) * (x[j] - sum(x) / n) + for i in range(n) + for j in range(n) + ] + ) + + denominator = 2 - polarised + if proportion: + all_samples = list({u for s in sample_sets for u in s}) + num = ts.segregating_sites(all_samples) + denominator = denominator * num + return ( + ts.sample_count_stat( + sample_sets, + f, + output_dim=n * n, + mode="site", + span_normalise=True, + polarised=polarised, + ).reshape((n, n)) + / denominator + ) + + +def c_genotype_relatedness(ts, sample_sets, indexes, polarised=False, proportion=False): + m = len(indexes) + state_dim = len(sample_sets) + + def f(x): + sumx = 0 + for k in range(state_dim): + sumx += x[k] + meanx = sumx / state_dim + result = np.zeros(m) + for k in range(m): + i = indexes[k][0] + j = indexes[k][1] + result[k] = (x[i] - meanx) * (x[j] - meanx) + return result + + denominator = 2 - polarised + if proportion: + all_samples = list({u for s in sample_sets for u in s}) + num = ts.segregating_sites(all_samples) + denominator = denominator * num + return ( + ts.sample_count_stat( + sample_sets, + f, + output_dim=m, + mode="site", + span_normalise=True, + polarised=False, + strict=False, + ) + / denominator + ) + + +class TestCovariance(unittest.TestCase): + """ + Tests on covariance matrix computation + """ + + def verify(self, ts): + cov1 = naive_genotype_covariance(ts) + cov2 = genotype_relatedness(ts) + sample_sets = [[u] for u in ts.samples()] + n = len(sample_sets) + indexes = [ + (n1, n2) for n1, n2 in itertools.combinations_with_replacement(range(n), 2) + ] + cov3 = np.zeros((n, n)) + cov4 = np.zeros((n, n)) + i_upper = np.triu_indices(n) + cov3[i_upper] = c_genotype_relatedness(ts, sample_sets, indexes) + cov3 = cov3 + cov3.T - np.diag(cov3.diagonal()) + cov4[i_upper] = ts.genetic_relatedness( + sample_sets, indexes, mode="site", span_normalise=True + ) + cov4 = cov4 + cov4.T - np.diag(cov4.diagonal()) + assert np.allclose(cov1, cov2) + assert np.allclose(cov1, cov3) + assert np.allclose(cov1, cov4) + + def test_single_coalescent_tree(self): + ts = msprime.simulate(10, random_seed=1, length=10, mutation_rate=1) + self.verify(ts) + + def test_coalescent_trees(self): + ts = msprime.simulate( + 8, recombination_rate=5, random_seed=1, length=2, mutation_rate=1 + ) + assert ts.num_trees > 2 + self.verify(ts) + + def test_internal_samples(self): + nodes = io.StringIO( + """\ + id is_sample time + 0 0 0 + 1 1 0.1 + 2 1 0.1 + 3 1 0.2 + 4 0 0.4 + 5 1 0.5 + 6 0 0.7 + 7 0 1.0 + 8 0 0.8 + """ + ) + edges = io.StringIO( + """\ + left right parent child + 0.0 0.2 4 2,3 + 0.2 0.8 4 0,2 + 0.8 1.0 4 2,3 + 0.0 1.0 5 1,4 + 0.8 1.0 6 0,5 + 0.2 0.8 8 3,5 + 0.0 0.2 7 0,5 + """ + ) + sites = io.StringIO( + """\ + position ancestral_state + 0.1 0 + 0.5 0 + 0.9 0 + """ + ) + mutations = io.StringIO( + """\ + site node derived_state + 0 1 1 + 1 3 1 + 2 5 1 + """ + ) + ts = tskit.load_text( + nodes=nodes, edges=edges, sites=sites, mutations=mutations, strict=False + ) + self.verify(ts) + + def validate_trees(self, n): + for seed in range(1, 10): + ts = msprime.simulate( + n, random_seed=seed, recombination_rate=1, mutation_rate=1 + ) + self.verify(ts) + + def test_sample_5(self): + self.validate_trees(5) + + def test_sample_10(self): + self.validate_trees(10) + + def test_sample_20(self): + self.validate_trees(20) + + def validate_nonbinary_trees(self, n): + demographic_events = [ + msprime.SimpleBottleneck(0.02, 0, proportion=0.25), + msprime.SimpleBottleneck(0.2, 0, proportion=1), + ] + + for seed in range(1, 10): + ts = msprime.simulate( + n, + random_seed=seed, + demographic_events=demographic_events, + recombination_rate=1, + mutation_rate=5, + ) + # Check if this is really nonbinary + found = False + for edgeset in ts.edgesets(): + if len(edgeset.children) > 2: + found = True + break + assert found + + self.verify(ts) + + def test_non_binary_sample_10(self): + self.validate_nonbinary_trees(10) + + def test_non_binary_sample_20(self): + self.validate_nonbinary_trees(20) diff --git a/python/tskit/trees.py b/python/tskit/trees.py index 3e597148ce..741597ff35 100644 --- a/python/tskit/trees.py +++ b/python/tskit/trees.py @@ -5308,6 +5308,7 @@ def __k_way_sample_set_stat( windows=None, mode=None, span_normalise=True, + polarised=False, ): sample_set_sizes = np.array( [len(sample_set) for sample_set in sample_sets], dtype=np.uint32 @@ -5340,6 +5341,7 @@ def __k_way_sample_set_stat( indexes, mode=mode, span_normalise=span_normalise, + polarised=polarised, ) if drop_dimension: stat = stat.reshape(stat.shape[:-1]) @@ -5496,6 +5498,57 @@ def divergence( # k += 1 # return A + def genetic_relatedness( + self, + sample_sets, + indexes=None, + windows=None, + mode="site", + span_normalise=True, + polarised=False, + proportion=True, + ): + """ + Computes genetic relatedness between (and within) pairs of + sets of nodes from ``sample_sets``. + Operates on ``k = 2`` sample sets at a time; please see the + :ref:`multi-way statistics ` + section for details on how the ``sample_sets`` and ``indexes`` arguments are + interpreted and how they interact with the dimensions of the output array. + See the :ref:`statistics interface ` section for details on + :ref:`windows `, + :ref:`mode `, + :ref:`span normalise `, + :ref:`polarised `, + and :ref:`return value `. + + :param list sample_sets: A list of lists of Node IDs, specifying the + groups of nodes to compute the statistic with. + :param list indexes: A list of 2-tuples, or None. + :param list windows: An increasing list of breakpoints between the windows + to compute the statistic in. + :param str mode: A string giving the "type" of the statistic to be computed + (defaults to "site"). + :param bool span_normalise: Whether to divide the result by the span of the + window (defaults to True). + :param bool proportion: Whether to divide the result by the number of + segregating sites (defaults to True). + :return: A ndarray with shape equal to (num windows, num statistics). + """ + return ( + self.__k_way_sample_set_stat( + self._ll_tree_sequence.relatedness, + 2, + sample_sets, + indexes=indexes, + windows=windows, + mode=mode, + span_normalise=span_normalise, + polarised=polarised, + ) + / 2 + ) + def trait_covariance(self, W, windows=None, mode="site", span_normalise=True): """ Computes the mean squared covariances between each of the columns of ``W``