ts-PCA performance is slow compared to scikit-allel #1743

@brieuclehmann

Building on #898 and using the 'matrix multiplication' in WIP #1246 (i.e. genetic_relatedness_weighted), we're trying to implement PCA for tskit. This appears to be working 🎉 but is rather slow compared to scikit-allel. The following code is a small reprex, in which scikit-allel is approximately 20 times faster than our current tskit implementation (about 5.8 s versus 0.26 s).

import time

import allel
import msprime
import numpy as np
import stdpopsim
import tskit
from scipy.sparse.linalg import eigsh
from scipy.sparse.linalg import LinearOperator

# Get realistic simulation parameters
species = stdpopsim.get_species("HomSap")
chrom = species.genome.get_chromosome("chr20")
gm = None
model = species.get_demographic_model("OutOfAfrica_3G09")
engine = stdpopsim.get_engine("msprime")
n_pop = model.num_populations

# Set parameters
seed = 1
n_ind = 64
mult = 0.05
mut_rate = 1e-8
recomb_rate = chrom.recombination_rate

n_tot = n_pop * n_ind
n_hap = 2 * n_ind
samples = model.get_samples(n_hap, n_hap, n_hap)

recomb_map = msprime.RecombinationMap.uniform_map(
    chrom.length * mult, recomb_rate
)

contig = stdpopsim.Contig(
    recombination_map=recomb_map, mutation_rate=mut_rate, genetic_map=gm
)

# Simulate tree sequence
ts = engine.simulate(model, contig, samples)

sample_sets = [(2 * i, (2 * i) + 1) for i in range(n_tot)]
W_samples = np.array(
    [[float(u in A) for A in sample_sets] for u in ts.samples()]
)
indexes = [(i, n_tot) for i in range(len(sample_sets))]

def mat_mul_stat(a):
    W = np.c_[W_samples, W_samples @ a]
    return ts.genetic_relatedness_weighted(
        W, indexes=indexes, mode="site", span_normalise=False
    )

# Linear operator
start_time = time.time()
A = LinearOperator((n_tot, n_tot), matvec=mat_mul_stat)
# k=6 is the eigsh default; stated explicitly to match n_components=6 below
eigval_linop, eigvec_linop = eigsh(A, k=6)
linop_pc = (A @ eigvec_linop[:, ::-1]) / np.sqrt(eigval_linop[::-1])
end_time = time.time()
time_linop = end_time - start_time

# Direct genotype matrix
haps = allel.HaplotypeArray(ts.genotype_matrix())
gns = haps.to_genotypes(ploidy=2).to_n_ref()
start_time = time.time()
allel_pc = allel.pca(gns, n_components=6, scaler=None)[0]
end_time = time.time()
time_allel = end_time - start_time

print(time_linop)
# 5.768216848373413
print(time_allel)
# 0.2641010284423828
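
One way to check where the time might be going is to count how often eigsh calls the matvec function, since each call in the reprex above triggers a full genetic_relatedness_weighted computation. The following is only a minimal sketch under stated assumptions: a random dense matrix G stands in for the tree sequence statistic, and the names G, matvec_calls and counting_matvec are hypothetical and not part of tskit or scikit-allel.

```python
import numpy as np
from scipy.sparse.linalg import LinearOperator, eigsh

rng = np.random.default_rng(1)
n_ind, n_sites = 96, 2000

# Hypothetical stand-in for genotype data (individuals x sites), centred per site
G = rng.integers(0, 3, size=(n_ind, n_sites)).astype(float)
G -= G.mean(axis=0)

matvec_calls = 0


def counting_matvec(v):
    """Apply G @ G.T to v without forming G @ G.T, counting each call."""
    global matvec_calls
    matvec_calls += 1
    return G @ (G.T @ v)


# Passing dtype explicitly avoids an extra probing matvec call at construction
A = LinearOperator((n_ind, n_ind), matvec=counting_matvec, dtype=float)
eigval, eigvec = eigsh(A, k=6)

# Reference: the six largest eigenvalues from a dense decomposition
dense_eigval = np.linalg.eigvalsh(G @ G.T)[-6:]

print("matvec calls:", matvec_calls)
print("max abs difference:", np.max(np.abs(np.sort(eigval) - dense_eigval)))
```

If the call count turns out to be large in the real reprex, the total time would be roughly that count multiplied by the cost of one genetic_relatedness_weighted call, which would suggest looking at the per-call cost rather than at eigsh itself.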
