Skip to content

Commit 36bcfaa

Browse files
authored
fix: remove scipy, read csr matrix manually (#117)
* fix: remove scipy, read csr matrix manually
1 parent 218c775 commit 36bcfaa

File tree

3 files changed

+68
-46
lines changed

3 files changed

+68
-46
lines changed

dataset_reader/base_reader.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,11 @@
11
from dataclasses import dataclass
22
from typing import Iterator, List, Optional
33

4-
import numpy as np
5-
64

75
@dataclass
86
class SparseVector:
9-
indices: np.array
10-
values: np.array
7+
indices: List[int]
8+
values: List[float]
119

1210

1311
@dataclass

dataset_reader/sparse_reader.py

Lines changed: 65 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -1,48 +1,62 @@
11
import os
2-
from typing import Iterator
2+
from pathlib import Path
3+
from typing import Iterator, List, Tuple, Union
34

45
import numpy as np
5-
from scipy.sparse import csr_matrix
66

77
from dataset_reader.base_reader import BaseReader, Query, Record, SparseVector
88

9-
# credit: code extracted from neuIPS 2023 benchmarks
109

10+
def read_sparse_matrix_fields(
11+
filename: Union[Path, str]
12+
) -> Tuple[np.array, np.array, np.array]:
13+
"""Read the fields of a CSR matrix without instantiating it"""
1114

12-
def read_sparse_matrix_fields(fname):
13-
"""read the fields of a CSR matrix without instantiating it"""
14-
with open(fname, "rb") as f:
15+
with open(filename, "rb") as f:
1516
sizes = np.fromfile(f, dtype="int64", count=3)
16-
nrow, ncol, nnz = sizes
17-
indptr = np.fromfile(f, dtype="int64", count=nrow + 1)
18-
assert nnz == indptr[-1]
19-
indices = np.fromfile(f, dtype="int32", count=nnz)
20-
assert np.all(indices >= 0) and np.all(indices < ncol)
21-
data = np.fromfile(f, dtype="float32", count=nnz)
22-
return data, indices, indptr, ncol
23-
24-
25-
def read_sparse_matrix(fname) -> Iterator[SparseVector]:
26-
"""read a CSR matrix in spmat format"""
27-
data, indices, indptr, ncol = read_sparse_matrix_fields(fname)
28-
# Need scipy csr_matrix to parse spmat format and easily take out rows
29-
csr_mat = csr_matrix((data, indices, indptr), shape=(len(indptr) - 1, ncol))
30-
num_vectors = csr_mat.shape[0]
31-
32-
for i in range(num_vectors):
33-
indices = csr_mat[i].indices.tolist()
34-
values = csr_mat[i].data.tolist()
35-
yield SparseVector(indices=indices, values=values)
36-
37-
38-
def knn_result_read(fname):
39-
n, d = map(int, np.fromfile(fname, dtype="uint32", count=2))
40-
assert os.stat(fname).st_size == 8 + n * d * (4 + 4)
41-
f = open(fname, "rb")
42-
f.seek(4 + 4)
43-
ids = np.fromfile(f, dtype="int32", count=n * d).reshape(n, d)
44-
scores = np.fromfile(f, dtype="float32", count=n * d).reshape(n, d)
45-
f.close()
17+
n_row, n_col, n_non_zero = sizes
18+
index_pointer = np.fromfile(f, dtype="int64", count=n_row + 1)
19+
assert n_non_zero == index_pointer[-1]
20+
columns = np.fromfile(f, dtype="int32", count=n_non_zero)
21+
assert np.all(columns >= 0) and np.all(columns < n_col)
22+
values = np.fromfile(f, dtype="float32", count=n_non_zero)
23+
return values, columns, index_pointer
24+
25+
26+
def csr_to_sparse_vectors(
27+
values: List[float], columns: List[int], index_pointer: List[int]
28+
) -> Iterator[SparseVector]:
29+
num_rows = len(index_pointer) - 1
30+
31+
for i in range(num_rows):
32+
start = index_pointer[i]
33+
end = index_pointer[i + 1]
34+
row_values, row_indices = [], []
35+
for j in range(start, end):
36+
row_values.append(values[j])
37+
row_indices.append(columns[j])
38+
yield SparseVector(indices=row_indices, values=row_values)
39+
40+
41+
def read_csr_matrix(filename: Union[Path, str]) -> Iterator[SparseVector]:
42+
"""Read a CSR matrix in spmat format"""
43+
values, columns, index_pointer = read_sparse_matrix_fields(filename)
44+
values = values.tolist()
45+
columns = columns.tolist()
46+
index_pointer = index_pointer.tolist()
47+
48+
yield from csr_to_sparse_vectors(values, columns, index_pointer)
49+
50+
51+
def knn_result_read(
52+
filename: Union[Path, str]
53+
) -> Tuple[List[List[int]], List[List[float]]]:
54+
n, d = map(int, np.fromfile(filename, dtype="uint32", count=2))
55+
assert os.stat(filename).st_size == 8 + n * d * (4 + 4)
56+
with open(filename, "rb") as f:
57+
f.seek(4 + 4)
58+
ids = np.fromfile(f, dtype="int32", count=n * d).reshape(n, d).tolist()
59+
scores = np.fromfile(f, dtype="float32", count=n * d).reshape(n, d).tolist()
4660
return ids, scores
4761

4862

@@ -53,7 +67,7 @@ def __init__(self, path, normalize=False):
5367

5468
def read_queries(self) -> Iterator[Query]:
5569
queries_path = self.path / "queries.csr"
56-
X = read_sparse_matrix(queries_path)
70+
X = read_csr_matrix(queries_path)
5771

5872
gt_path = self.path / "results.gt"
5973
gt_indices, _ = knn_result_read(gt_path)
@@ -63,12 +77,24 @@ def read_queries(self) -> Iterator[Query]:
6377
vector=None,
6478
sparse_vector=sparse_vector,
6579
meta_conditions=None,
66-
expected_result=gt_indices[i].tolist(),
80+
expected_result=gt_indices[i],
6781
)
6882

6983
def read_data(self) -> Iterator[Record]:
7084
data_path = self.path / "data.csr"
71-
X = read_sparse_matrix(data_path)
85+
X = read_csr_matrix(data_path)
7286

7387
for i, sparse_vector in enumerate(X):
7488
yield Record(id=i, vector=None, sparse_vector=sparse_vector, metadata=None)
89+
90+
91+
if __name__ == "__main__":
92+
vals = [1, 3, 2, 3, 6, 4, 5]
93+
cols = [0, 2, 2, 1, 3, 0, 2]
94+
pointers = [0, 2, 3, 5, 7]
95+
vecs = [vec for vec in csr_to_sparse_vectors(vals, cols, pointers)]
96+
97+
assert vecs[0] == SparseVector(indices=[0, 2], values=[1, 3])
98+
assert vecs[1] == SparseVector(indices=[2], values=[2])
99+
assert vecs[2] == SparseVector(indices=[1, 3], values=[3, 6])
100+
assert vecs[3] == SparseVector(indices=[0, 2], values=[4, 5])

pyproject.toml

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
name = "vector-db-benchmark"
33
version = "0.1.0"
44
description = ""
5-
authors = ["Kacper Łukawski <kacper.lukawski@qdrant.com>"]
5+
authors = ["Qdrant Team <info@qdrant.tech>"]
66

77
[tool.poetry.dependencies]
88
python = ">=3.9,<3.12"
@@ -20,8 +20,6 @@ opensearch-py = "^2.3.2"
2020
tqdm = "^4.66.1"
2121
psycopg = {extras = ["binary"], version = "^3.1.17"}
2222
pgvector = "^0.2.4"
23-
scipy = "^1.12.0"
24-
2523

2624
[tool.poetry.dev-dependencies]
2725
pre-commit = "^2.20.0"

0 commit comments

Comments
 (0)