Closed
Description
Building on #898 and using the 'matrix multiplication' from WIP #1246 (i.e. genetic_relatedness_weighted), we're trying to implement PCA for tskit. This appears to be working 🎉 but is rather slow compared to scikit-allel. The code below is a small reprex: on this example scikit-allel takes about 0.26 s and our current tskit implementation about 5.8 s, so scikit-allel is approximately 20 times faster.
import itertools
import time
import allel
import msprime
import numpy as np
import stdpopsim
import tskit
from scipy.sparse.linalg import eigsh
from scipy.sparse.linalg import LinearOperator
# Get realistic simulation parameters
species = stdpopsim.get_species("HomSap")
chrom = species.genome.get_chromosome("chr20")
gm = None
model = species.get_demographic_model("OutOfAfrica_3G09")
engine = stdpopsim.get_engine("msprime")
n_pop = model.num_populations
# Set parameters
seed = 1
n_ind = 64
mult = 0.05
mut_rate = 1e-8
recomb_rate = chrom.recombination_rate
n_tot = n_pop * n_ind
n_hap = 2 * n_ind
samples = model.get_samples(n_hap, n_hap, n_hap)
recomb_map = msprime.RecombinationMap.uniform_map(
    chrom.length * mult, recomb_rate
)
contig = stdpopsim.Contig(
    recombination_map=recomb_map, mutation_rate=mut_rate, genetic_map=gm
)
# Simulate tree sequence
ts = engine.simulate(model, contig, samples)
sample_sets = [(2 * i, (2 * i) + 1) for i in range(n_tot)]
W_samples = np.array(
    [[float(u in A) for A in sample_sets] for u in ts.samples()]
)
indexes = [(i, n_tot) for i in range(len(sample_sets))]
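def mat_mul_stat(a):
    # Columns 0..n_tot-1 of W are per-individual indicator weights; the last
    # column is the combination of individuals weighted by a. Pairing column i
    # with column n_tot gives entry i of (relatedness matrix) @ a.
    W = np.c_[W_samples, W_samples @ a]
    return ts.genetic_relatedness_weighted(
        W, indexes=indexes, mode="site", span_normalise=False
    )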
# Linear operator
start_time = time.time()
A = LinearOperator((n_tot, n_tot), matvec=mat_mul_stat)
eigval_linop, eigvec_linop = eigsh(A, k=6)  # k=6 is the eigsh default; matches n_components=6 below
linop_pc = (A @ eigvec_linop[:, ::-1]) / np.sqrt(eigval_linop[::-1])
end_time = time.time()
time_linop = end_time - start_time
# Direct genotype matrix
haps = allel.HaplotypeArray(ts.genotype_matrix())
gns = haps.to_genotypes(ploidy=2).to_n_ref()
start_time = time.time()
allel_pc = allel.pca(gns, n_components=6, scaler=None)[0]
end_time = time.time()
time_allel = end_time - start_time
print(time_linop)
# 5.768216848373413
print(time_allel)
# 0.2641010284423828
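To help narrow down where the time goes, here is a minimal sketch of the same matrix-free pattern with tskit taken out of the picture. It is an illustration under stated assumptions, not the tskit implementation: G is a hypothetical random stand-in for a centred genotype matrix, and the operator applies G.T @ G without ever forming it. The sketch checks that eigsh on the implicit operator recovers the same top eigenvalues as a dense decomposition, and it counts how many times the matvec is called. In the reprex above every such call is one genetic_relatedness_weighted evaluation, so the call count multiplied by the per-call cost may account for much of the runtime.

```python
import numpy as np
from scipy.sparse.linalg import LinearOperator, eigsh

rng = np.random.default_rng(1)
n_samples, n_sites = 50, 400

# Hypothetical stand-in for a genotype matrix (sites x samples), centred per site.
G = rng.integers(0, 3, size=(n_sites, n_samples)).astype(float)
G -= G.mean(axis=1, keepdims=True)

calls = {"n": 0}  # mutable counter for the number of matvec evaluations


def matvec(v):
    """Apply (G.T @ G) to v without forming the n_samples x n_samples matrix."""
    calls["n"] += 1
    v = np.asarray(v).ravel()
    return G.T @ (G @ v)


A = LinearOperator((n_samples, n_samples), matvec=matvec, dtype=float)

# Top 6 eigenpairs of the implicit operator (eigenvalues returned in ascending order).
eigval, eigvec = eigsh(A, k=6)

# Dense reference: the 6 largest eigenvalues of the explicit matrix.
dense_eigval = np.linalg.eigvalsh(G.T @ G)[-6:]

print("eigenvalues agree:", np.allclose(eigval, dense_eigval))
print("matvec calls:", calls["n"])
```

Timing a single mat_mul_stat(a) call in the reprex and multiplying by a call count obtained this way would give a rough estimate of how much of the total is spent inside the statistic rather than in eigsh itself.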