|
| 1 | +import os |
| 2 | +from pathlib import Path |
| 3 | +from typing import Iterator, List, Tuple, Union |
| 4 | + |
| 5 | +import numpy as np |
| 6 | + |
| 7 | +from dataset_reader.base_reader import BaseReader, Query, Record, SparseVector |
| 8 | + |
| 9 | + |
def read_sparse_matrix_fields(
    filename: Union[Path, str]
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Read the raw fields of a CSR matrix without instantiating it.

    The layout read here is: three int64 sizes (rows, columns, non-zero
    count), then ``rows + 1`` int64 row pointers, then one int32 column index
    per non-zero entry, then one float32 value per non-zero entry.

    Args:
        filename: Path of the binary file to read.

    Returns:
        A ``(values, columns, index_pointer)`` tuple of NumPy arrays.

    Raises:
        ValueError: If the file is truncated or its fields are inconsistent
            (negative sizes, last row pointer different from the non-zero
            count, or a column index outside ``[0, columns)``).
    """
    with open(filename, "rb") as f:
        # np.fromfile returns fewer items than requested when the file is
        # short, so each read is followed by an explicit length check.
        sizes = np.fromfile(f, dtype="int64", count=3)
        if sizes.size != 3:
            raise ValueError(f"{filename}: truncated header, expected 3 sizes")
        n_row, n_col, n_non_zero = (int(size) for size in sizes)
        if n_row < 0 or n_col < 0 or n_non_zero < 0:
            # A negative count would make np.fromfile read the rest of the file.
            raise ValueError(
                f"{filename}: negative size in header "
                f"({n_row}, {n_col}, {n_non_zero})"
            )

        index_pointer = np.fromfile(f, dtype="int64", count=n_row + 1)
        if index_pointer.size != n_row + 1:
            raise ValueError(
                f"{filename}: expected {n_row + 1} row pointers, "
                f"got {index_pointer.size}"
            )
        if int(index_pointer[-1]) != n_non_zero:
            raise ValueError(
                f"{filename}: last row pointer {int(index_pointer[-1])} "
                f"does not match non-zero count {n_non_zero}"
            )

        columns = np.fromfile(f, dtype="int32", count=n_non_zero)
        if columns.size != n_non_zero:
            raise ValueError(
                f"{filename}: expected {n_non_zero} column indices, "
                f"got {columns.size}"
            )
        if np.any(columns < 0) or np.any(columns >= n_col):
            raise ValueError(
                f"{filename}: column index outside the range [0, {n_col})"
            )

        values = np.fromfile(f, dtype="float32", count=n_non_zero)
        if values.size != n_non_zero:
            raise ValueError(
                f"{filename}: expected {n_non_zero} values, got {values.size}"
            )
    return values, columns, index_pointer
| 24 | + |
| 25 | + |
def csr_to_sparse_vectors(
    values: List[float], columns: List[int], index_pointer: List[int]
) -> Iterator[SparseVector]:
    """Yield one ``SparseVector`` per row described by raw CSR fields.

    Row ``r`` owns the entries at positions ``index_pointer[r]`` up to, but
    not including, ``index_pointer[r + 1]`` of ``values`` and ``columns``.
    """
    row_count = len(index_pointer) - 1

    for row in range(row_count):
        first = index_pointer[row]
        past_last = index_pointer[row + 1]
        positions = range(first, past_last)
        entry_values = [values[position] for position in positions]
        entry_columns = [columns[position] for position in positions]
        yield SparseVector(indices=entry_columns, values=entry_values)
| 39 | + |
| 40 | + |
def read_csr_matrix(filename: Union[Path, str]) -> Iterator[SparseVector]:
    """Read a CSR matrix in spmat format, yielding one sparse vector per row."""
    fields = read_sparse_matrix_fields(filename)
    # Convert each NumPy array to a plain Python list before building vectors.
    values, columns, index_pointer = (field.tolist() for field in fields)

    yield from csr_to_sparse_vectors(values, columns, index_pointer)
| 49 | + |
| 50 | + |
def knn_result_read(
    filename: Union[Path, str]
) -> Tuple[List[List[int]], List[List[float]]]:
    """Read nearest-neighbour results from a binary file.

    The layout read here is: two uint32 numbers ``n`` and ``d``, then
    ``n * d`` int32 ids, then ``n * d`` float32 scores.

    Args:
        filename: Path of the binary file to read.

    Returns:
        An ``(ids, scores)`` tuple; each is a list of ``n`` lists of length
        ``d``.

    Raises:
        ValueError: If the header is incomplete or the file size does not
            match the size implied by the header.
    """
    with open(filename, "rb") as f:
        header = np.fromfile(f, dtype="uint32", count=2)
        if header.size != 2:
            raise ValueError(f"{filename}: truncated header, expected n and d")
        n, d = map(int, header)

        # 8 header bytes, then a 4-byte id and a 4-byte score per entry.
        expected_size = 8 + n * d * (4 + 4)
        actual_size = os.stat(filename).st_size
        if actual_size != expected_size:
            raise ValueError(
                f"{filename}: file size {actual_size} does not match the "
                f"{expected_size} bytes implied by n={n}, d={d}"
            )

        ids = np.fromfile(f, dtype="int32", count=n * d).reshape(n, d).tolist()
        scores = np.fromfile(f, dtype="float32", count=n * d).reshape(n, d).tolist()
    return ids, scores
| 61 | + |
| 62 | + |
class SparseReader(BaseReader):
    """Reader for a sparse-vector dataset stored in one directory.

    The methods read ``queries.csr``, ``results.gt`` and ``data.csr`` from
    ``path``, which must support the ``/`` operator.
    """

    def __init__(self, path, normalize=False):
        # `normalize` is only stored; no method of this class reads it.
        self.normalize = normalize
        self.path = path

    def read_queries(self) -> Iterator[Query]:
        """Yield one ``Query`` per row of ``queries.csr``.

        Each query's expected result is the row of ids at the same position
        in ``results.gt``.
        """
        query_vectors = read_csr_matrix(self.path / "queries.csr")
        # Only the neighbour ids are used; the scores are discarded.
        neighbour_ids, _ = knn_result_read(self.path / "results.gt")

        for position, query_vector in enumerate(query_vectors):
            expected = neighbour_ids[position]
            yield Query(
                vector=None,
                sparse_vector=query_vector,
                meta_conditions=None,
                expected_result=expected,
            )

    def read_data(self) -> Iterator[Record]:
        """Yield one ``Record`` per row of ``data.csr``, using the row number as id."""
        rows = read_csr_matrix(self.path / "data.csr")

        yield from (
            Record(id=row_number, vector=None, sparse_vector=row, metadata=None)
            for row_number, row in enumerate(rows)
        )
| 89 | + |
| 90 | + |
if __name__ == "__main__":
    # Self-check: convert a small 4-row CSR matrix and compare every row.
    sample_values = [1, 3, 2, 3, 6, 4, 5]
    sample_columns = [0, 2, 2, 1, 3, 0, 2]
    sample_pointers = [0, 2, 3, 5, 7]
    produced = list(
        csr_to_sparse_vectors(sample_values, sample_columns, sample_pointers)
    )

    # (indices, values) expected for each row, in row order.
    expected_rows = [
        ([0, 2], [1, 3]),
        ([2], [2]),
        ([1, 3], [3, 6]),
        ([0, 2], [4, 5]),
    ]
    for row_number, (row_indices, row_values) in enumerate(expected_rows):
        assert produced[row_number] == SparseVector(
            indices=row_indices, values=row_values
        )
0 commit comments