11import os
2- from typing import Iterator
2+ from pathlib import Path
3+ from typing import Iterator , List , Tuple , Union
34
45import numpy as np
5- from scipy .sparse import csr_matrix
66
77from dataset_reader .base_reader import BaseReader , Query , Record , SparseVector
88
9- # credit: code extracted from neuIPS 2023 benchmarks
109
def read_sparse_matrix_fields(
    filename: Union[Path, str]
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Read the fields of a CSR matrix without instantiating it.

    The file is read in this order: three ``int64`` sizes (number of rows,
    number of columns, number of non-zero entries), then ``rows + 1``
    ``int64`` row pointers, then one ``int32`` column index and one
    ``float32`` value per non-zero entry.

    Args:
        filename: Path of the binary file to read.

    Returns:
        A ``(values, columns, index_pointer)`` tuple of NumPy arrays.

    Raises:
        ValueError: If the file is shorter than its header announces, if the
            last row pointer disagrees with the declared number of non-zero
            entries, or if a column index falls outside ``[0, n_col)``.
    """
    with open(filename, "rb") as f:
        sizes = np.fromfile(f, dtype="int64", count=3)
        if sizes.size != 3:
            raise ValueError(f"{filename}: truncated header, expected 3 sizes")
        n_row, n_col, n_non_zero = (int(size) for size in sizes)
        # A negative count would make np.fromfile read the rest of the file.
        if min(n_row, n_col, n_non_zero) < 0:
            raise ValueError(f"{filename}: negative size in header")

        index_pointer = np.fromfile(f, dtype="int64", count=n_row + 1)
        if index_pointer.size != n_row + 1:
            raise ValueError(f"{filename}: truncated row pointers")
        # Explicit raise instead of assert: the check must survive `python -O`.
        if n_non_zero != index_pointer[-1]:
            raise ValueError(
                f"{filename}: header declares {n_non_zero} non-zero entries "
                f"but the last row pointer is {index_pointer[-1]}"
            )

        columns = np.fromfile(f, dtype="int32", count=n_non_zero)
        if columns.size != n_non_zero:
            raise ValueError(f"{filename}: truncated column indices")
        if not (np.all(columns >= 0) and np.all(columns < n_col)):
            raise ValueError(f"{filename}: column index outside [0, {n_col})")

        values = np.fromfile(f, dtype="float32", count=n_non_zero)
        if values.size != n_non_zero:
            raise ValueError(f"{filename}: truncated values")
    return values, columns, index_pointer
24+
25+
def csr_to_sparse_vectors(
    values: List[float], columns: List[int], index_pointer: List[int]
) -> Iterator[SparseVector]:
    """Split flat CSR lists into one ``SparseVector`` per row.

    Row ``i`` owns the entries ``index_pointer[i]`` up to (but excluding)
    ``index_pointer[i + 1]`` of both ``values`` and ``columns``.

    Args:
        values: Non-zero values of all rows, concatenated in row order.
        columns: Column index of each entry in ``values``.
        index_pointer: Row boundaries; one more entry than there are rows.

    Yields:
        A ``SparseVector`` for each row, in row order. A row whose two
        boundaries are equal yields empty ``indices`` and ``values``.
    """
    # Consecutive pointer pairs delimit a row; slicing copies the row's
    # entries in one step instead of appending them one by one.
    for start, end in zip(index_pointer, index_pointer[1:]):
        yield SparseVector(
            indices=columns[start:end],
            values=values[start:end],
        )
39+
40+
def read_csr_matrix(filename: Union[Path, str]) -> Iterator[SparseVector]:
    """Read a CSR matrix in spmat format and yield its rows.

    The three CSR arrays are loaded with ``read_sparse_matrix_fields`` and
    converted to plain Python lists once, before being passed to
    ``csr_to_sparse_vectors``. Because this is a generator, nothing is read
    until iteration starts.

    Args:
        filename: Path of the CSR file to read.

    Yields:
        One ``SparseVector`` per matrix row, in row order.
    """
    values, columns, index_pointer = read_sparse_matrix_fields(filename)
    yield from csr_to_sparse_vectors(
        values.tolist(), columns.tolist(), index_pointer.tolist()
    )
49+
50+
def knn_result_read(
    filename: Union[Path, str]
) -> Tuple[List[List[int]], List[List[float]]]:
    """Read a k-NN result file into nested Python lists.

    The file is read as two ``uint32`` values ``n`` and ``d``, followed by
    ``n * d`` ``int32`` ids and then ``n * d`` ``float32`` scores.

    Args:
        filename: Path of the binary result file.

    Returns:
        An ``(ids, scores)`` tuple; each element is a list of ``n`` rows with
        ``d`` entries per row.

    Raises:
        ValueError: If the header is incomplete or the file size does not
            equal ``8 + n * d * 8`` bytes.
    """
    with open(filename, "rb") as f:
        # Reading the header from the open handle leaves the position right
        # after it, so no second open or explicit seek is needed.
        header = np.fromfile(f, dtype="uint32", count=2)
        if header.size != 2:
            raise ValueError(f"{filename}: truncated header, expected n and d")
        n, d = map(int, header)

        # Explicit raise instead of assert: the check must survive `python -O`.
        expected_size = 8 + n * d * (4 + 4)
        actual_size = os.stat(filename).st_size
        if actual_size != expected_size:
            raise ValueError(
                f"{filename}: expected {expected_size} bytes for n={n}, d={d}, "
                f"found {actual_size}"
            )

        ids = np.fromfile(f, dtype="int32", count=n * d).reshape(n, d).tolist()
        scores = np.fromfile(f, dtype="float32", count=n * d).reshape(n, d).tolist()
    return ids, scores
4761
4862
@@ -53,7 +67,7 @@ def __init__(self, path, normalize=False):
5367
5468 def read_queries (self ) -> Iterator [Query ]:
5569 queries_path = self .path / "queries.csr"
56- X = read_sparse_matrix (queries_path )
70+ X = read_csr_matrix (queries_path )
5771
5872 gt_path = self .path / "results.gt"
5973 gt_indices , _ = knn_result_read (gt_path )
@@ -63,12 +77,24 @@ def read_queries(self) -> Iterator[Query]:
6377 vector = None ,
6478 sparse_vector = sparse_vector ,
6579 meta_conditions = None ,
66- expected_result = gt_indices [i ]. tolist () ,
80+ expected_result = gt_indices [i ],
6781 )
6882
6983 def read_data (self ) -> Iterator [Record ]:
7084 data_path = self .path / "data.csr"
71- X = read_sparse_matrix (data_path )
85+ X = read_csr_matrix (data_path )
7286
7387 for i , sparse_vector in enumerate (X ):
7488 yield Record (id = i , vector = None , sparse_vector = sparse_vector , metadata = None )
89+
90+
if __name__ == "__main__":
    # Self-check: a 4-row CSR matrix given as flat lists must split back into
    # the expected per-row sparse vectors.
    vals = [1, 3, 2, 3, 6, 4, 5]
    cols = [0, 2, 2, 1, 3, 0, 2]
    pointers = [0, 2, 3, 5, 7]
    vecs = list(csr_to_sparse_vectors(vals, cols, pointers))

    assert len(vecs) == 4
    assert vecs[0] == SparseVector(indices=[0, 2], values=[1, 3])
    assert vecs[1] == SparseVector(indices=[2], values=[2])
    assert vecs[2] == SparseVector(indices=[1, 3], values=[3, 6])
    assert vecs[3] == SparseVector(indices=[0, 2], values=[4, 5])
0 commit comments