Skip to content

Add function for covariance matrix #898

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 9 commits into from
Nov 4, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 32 additions & 0 deletions c/tests/test_stats.c
Original file line number Diff line number Diff line change
Expand Up @@ -1333,6 +1333,36 @@ test_paper_ex_divergence(void)
tsk_treeseq_free(&ts);
}

/* Smoke test for tsk_treeseq_relatedness on the paper example tree sequence:
 * runs the statistic in site mode for a single pair of sample sets and checks
 * only that the call reports success (the computed value is not inspected). */
static void
test_paper_ex_relatedness(void)
{
    /* Samples 0..3 split into two sample sets of two samples each. */
    tsk_id_t sample_ids[] = { 0, 1, 2, 3 };
    tsk_size_t set_sizes[] = { 2, 2 };
    /* One index pair: sample set 0 against itself. */
    tsk_id_t index_pair[] = { 0, 0 };
    tsk_treeseq_t treeseq;
    double relatedness;
    int status;

    tsk_treeseq_from_text(&treeseq, 10, paper_ex_nodes, paper_ex_edges, NULL,
        paper_ex_sites, paper_ex_mutations, paper_ex_individuals, NULL, 0);

    /* Zero windows with a NULL windows array are passed through unchanged. */
    status = tsk_treeseq_relatedness(&treeseq, 2, set_sizes, sample_ids, 1,
        index_pair, 0, NULL, &relatedness, TSK_STAT_SITE);
    CU_ASSERT_EQUAL_FATAL(status, 0);

    tsk_treeseq_free(&treeseq);
}

static void
test_paper_ex_relatedness_errors(void)
{
tsk_treeseq_t ts;

tsk_treeseq_from_text(&ts, 10, paper_ex_nodes, paper_ex_edges, NULL, paper_ex_sites,
paper_ex_mutations, paper_ex_individuals, NULL, 0);
verify_two_way_stat_func_errors(&ts, tsk_treeseq_relatedness);
tsk_treeseq_free(&ts);
}

static void
test_paper_ex_Y2_errors(void)
{
Expand Down Expand Up @@ -1679,6 +1709,8 @@ main(int argc, char **argv)
{ "test_paper_ex_Y1", test_paper_ex_Y1 },
{ "test_paper_ex_divergence_errors", test_paper_ex_divergence_errors },
{ "test_paper_ex_divergence", test_paper_ex_divergence },
{ "test_paper_ex_relatedness_errors", test_paper_ex_relatedness_errors },
{ "test_paper_ex_relatedness", test_paper_ex_relatedness },
{ "test_paper_ex_Y2_errors", test_paper_ex_Y2_errors },
{ "test_paper_ex_Y2", test_paper_ex_Y2 },
{ "test_paper_ex_f2_errors", test_paper_ex_f2_errors },
Expand Down
42 changes: 42 additions & 0 deletions c/tskit/trees.c
Original file line number Diff line number Diff line change
Expand Up @@ -2752,6 +2752,48 @@ tsk_treeseq_divergence(const tsk_treeseq_t *self, tsk_size_t num_sample_sets,
return ret;
}

/* Summary function for the relatedness statistic.
 *
 * First averages the state_dim entries of state. Then, for each of the
 * result_dim outputs k, reads the index pair (set_indexes[2k], set_indexes[2k+1])
 * from the sample_count_stat_params_t passed via params and stores the product
 * of the two corresponding state entries' deviations from that average in
 * result[k].
 *
 * Always returns 0.
 */
static int
relatedness_summary_func(size_t state_dim, const double *state, size_t result_dim,
    double *result, void *params)
{
    const sample_count_stat_params_t *stat_params = params;
    tsk_id_t first, second;
    size_t n;
    double total = 0;
    double mean;

    /* Accumulate in index order so the floating point sum is reproducible. */
    for (n = 0; n < state_dim; n++) {
        total += state[n];
    }
    mean = total / (double) state_dim;

    for (n = 0; n < result_dim; n++) {
        /* Index pairs are stored flattened: two consecutive entries per output. */
        first = stat_params->set_indexes[2 * n];
        second = stat_params->set_indexes[2 * n + 1];
        result[n] = (state[first] - mean) * (state[second] - mean);
    }
    return 0;
}

/* Computes the relatedness statistic for pairs of sample sets.
 *
 * The index tuples are first validated as tuples of width 2 by
 * check_sample_stat_inputs; if that check fails its error code is returned
 * and nothing else is done. Otherwise the work is delegated to
 * tsk_treeseq_sample_count_stat with relatedness_summary_func as the summary
 * function, and that call's return value is returned.
 */
int
tsk_treeseq_relatedness(const tsk_treeseq_t *self, tsk_size_t num_sample_sets,
    const tsk_size_t *sample_set_sizes, const tsk_id_t *sample_sets,
    tsk_size_t num_index_tuples, const tsk_id_t *index_tuples, tsk_size_t num_windows,
    const double *windows, double *result, tsk_flags_t options)
{
    int ret
        = check_sample_stat_inputs(num_sample_sets, 2, num_index_tuples, index_tuples);

    if (ret == 0) {
        /* Inputs are valid: run the generic sample-count statistic. */
        ret = tsk_treeseq_sample_count_stat(self, num_sample_sets, sample_set_sizes,
            sample_sets, num_index_tuples, index_tuples, relatedness_summary_func,
            num_windows, windows, result, options);
    }
    return ret;
}

static int
Y2_summary_func(size_t TSK_UNUSED(state_dim), const double *state, size_t result_dim,
double *result, void *params)
Expand Down
6 changes: 4 additions & 2 deletions c/tskit/trees.h
Original file line number Diff line number Diff line change
Expand Up @@ -338,8 +338,6 @@ int tsk_treeseq_allele_frequency_spectrum(const tsk_treeseq_t *self,
const tsk_id_t *sample_sets, tsk_size_t num_windows, const double *windows,
double *result, tsk_flags_t options);

/* Two way sample set stats */

typedef int general_sample_stat_method(const tsk_treeseq_t *self,
tsk_size_t num_sample_sets, const tsk_size_t *sample_set_sizes,
const tsk_id_t *sample_sets, tsk_size_t num_indexes, const tsk_id_t *indexes,
Expand All @@ -357,6 +355,10 @@ int tsk_treeseq_f2(const tsk_treeseq_t *self, tsk_size_t num_sample_sets,
const tsk_size_t *sample_set_sizes, const tsk_id_t *sample_sets,
tsk_size_t num_index_tuples, const tsk_id_t *index_tuples, tsk_size_t num_windows,
const double *windows, double *result, tsk_flags_t options);
/* Relatedness statistic over pairs of sample sets. index_tuples holds
 * num_index_tuples pairs of sample set indexes (the implementation validates
 * them as tuples of width 2). Returns 0 on success, or the nonzero error code
 * from input validation.
 * NOTE(review): result presumably receives one value per window and index
 * pair -- confirm against tsk_treeseq_sample_count_stat. */
int tsk_treeseq_relatedness(const tsk_treeseq_t *self, tsk_size_t num_sample_sets,
const tsk_size_t *sample_set_sizes, const tsk_id_t *sample_sets,
tsk_size_t num_index_tuples, const tsk_id_t *index_tuples, tsk_size_t num_windows,
const double *windows, double *result, tsk_flags_t options);

/* Three way sample set stats */
int tsk_treeseq_Y3(const tsk_treeseq_t *self, tsk_size_t num_sample_sets,
Expand Down
22 changes: 18 additions & 4 deletions python/_tskitmodule.c
Original file line number Diff line number Diff line change
Expand Up @@ -6946,7 +6946,7 @@ TreeSequence_k_way_stat_method(TreeSequence *self, PyObject *args, PyObject *kwd
{
PyObject *ret = NULL;
static char *kwlist[] = { "sample_set_sizes", "sample_sets", "indexes", "windows",
"mode", "span_normalise", NULL };
"mode", "span_normalise", "polarised", NULL };
PyObject *sample_set_sizes = NULL;
PyObject *sample_sets = NULL;
PyObject *indexes = NULL;
Expand All @@ -6960,14 +6960,15 @@ TreeSequence_k_way_stat_method(TreeSequence *self, PyObject *args, PyObject *kwd
npy_intp *shape;
tsk_flags_t options = 0;
char *mode = NULL;
int span_normalise = 1;
int span_normalise = true;
int polarised = false;
int err;

if (TreeSequence_check_tree_sequence(self) != 0) {
goto out;
}
if (!PyArg_ParseTupleAndKeywords(args, kwds, "OOOO|si", kwlist, &sample_set_sizes,
&sample_sets, &indexes, &windows, &mode, &span_normalise)) {
if (!PyArg_ParseTupleAndKeywords(args, kwds, "OOOO|sii", kwlist, &sample_set_sizes,
&sample_sets, &indexes, &windows, &mode, &span_normalise, &polarised)) {
goto out;
}
if (parse_stats_mode(mode, &options) != 0) {
Expand All @@ -6976,6 +6977,9 @@ TreeSequence_k_way_stat_method(TreeSequence *self, PyObject *args, PyObject *kwd
if (span_normalise) {
options |= TSK_STAT_SPAN_NORMALISE;
}
if (polarised) {
options |= TSK_STAT_POLARISED;
}
if (parse_sample_sets(sample_set_sizes, &sample_set_sizes_array, sample_sets,
&sample_sets_array, &num_sample_sets)
!= 0) {
Expand Down Expand Up @@ -7028,6 +7032,12 @@ TreeSequence_divergence(TreeSequence *self, PyObject *args, PyObject *kwds)
return TreeSequence_k_way_stat_method(self, args, kwds, 2, tsk_treeseq_divergence);
}

/* Python-facing entry point for the relatedness statistic. Forwards the call
 * unchanged to the shared k-way stat helper, passing 2 (presumably the index
 * tuple size -- confirm against TreeSequence_k_way_stat_method) together with
 * tsk_treeseq_relatedness as the underlying C implementation. */
static PyObject *
TreeSequence_relatedness(TreeSequence *self, PyObject *args, PyObject *kwds)
{
    PyObject *stat_result
        = TreeSequence_k_way_stat_method(self, args, kwds, 2, tsk_treeseq_relatedness);

    return stat_result;
}

static PyObject *
TreeSequence_Y2(TreeSequence *self, PyObject *args, PyObject *kwds)
{
Expand Down Expand Up @@ -7345,6 +7355,10 @@ static PyMethodDef TreeSequence_methods[] = {
.ml_meth = (PyCFunction) TreeSequence_divergence,
.ml_flags = METH_VARARGS | METH_KEYWORDS,
.ml_doc = "Computes diveregence between sample sets." },
{ .ml_name = "relatedness",
.ml_meth = (PyCFunction) TreeSequence_relatedness,
.ml_flags = METH_VARARGS | METH_KEYWORDS,
.ml_doc = "Computes genetic relatedness between sample sets." },
{ .ml_name = "Y1",
.ml_meth = (PyCFunction) TreeSequence_Y1,
.ml_flags = METH_VARARGS | METH_KEYWORDS,
Expand Down
Loading