Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@

import java.io.IOException;

import static org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper.COSINE_MAGNITUDE_FIELD_SUFFIX;
import static org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper.ElementType.BYTE;

/**
Expand Down Expand Up @@ -540,6 +541,11 @@ public AllReader reader(LeafReaderContext context) throws IOException {
case FLOAT -> {
FloatVectorValues floatVectorValues = context.reader().getFloatVectorValues(fieldName);
if (floatVectorValues != null) {
if (fieldType.isNormalized()) {
NumericDocValues magnitudeDocValues = context.reader()
.getNumericDocValues(fieldType.name() + COSINE_MAGNITUDE_FIELD_SUFFIX);
return new FloatDenseVectorNormalizedValuesBlockReader(floatVectorValues, dimensions, magnitudeDocValues);
}
return new FloatDenseVectorValuesBlockReader(floatVectorValues, dimensions);
}
}
Expand Down Expand Up @@ -584,6 +590,9 @@ public void read(int docId, BlockLoader.StoredFields storedFields, Builder build
}

private void read(int doc, BlockLoader.FloatBuilder builder) throws IOException {
assert vectorValues.dimension() == dimensions
: "unexpected dimensions for vector value; expected " + dimensions + " but got " + vectorValues.dimension();

if (iterator.docID() > doc) {
builder.appendNull();
} else if (iterator.docID() == doc || iterator.advance(doc) == doc) {
Expand Down Expand Up @@ -611,8 +620,6 @@ private static class FloatDenseVectorValuesBlockReader extends DenseVectorValues

protected void appendDoc(BlockLoader.FloatBuilder builder) throws IOException {
float[] floats = vectorValues.vectorValue(iterator.index());
assert floats.length == dimensions
: "unexpected dimensions for vector value; expected " + dimensions + " but got " + floats.length;
for (float aFloat : floats) {
builder.appendFloat(aFloat);
}
Expand All @@ -624,15 +631,45 @@ public String toString() {
}
}

/**
 * Block reader for float vectors whose per-document magnitude may be stored in a separate numeric doc values field.
 * Each component read from the vector values is multiplied by that magnitude before being appended, or by
 * {@code 1.0f} when no magnitude is available for the document.
 */
private static class FloatDenseVectorNormalizedValuesBlockReader extends DenseVectorValuesBlockReader<FloatVectorValues> {
    /** Per-document magnitudes; may be {@code null} when the segment has no magnitude doc values. */
    private final NumericDocValues magnitudeDocValues;

    FloatDenseVectorNormalizedValuesBlockReader(
        FloatVectorValues floatVectorValues,
        int dimensions,
        NumericDocValues magnitudeDocValues
    ) {
        super(floatVectorValues, dimensions);
        this.magnitudeDocValues = magnitudeDocValues;
    }

    /**
     * Appends every component of the current document's vector, scaled by the magnitude resolved for that document.
     */
    @Override
    protected void appendDoc(BlockLoader.FloatBuilder builder) throws IOException {
        // Resolve the scale factor first, then read the vector, so the two sources are consulted in a fixed order.
        final float scale = magnitudeForCurrentDoc();
        final float[] components = vectorValues.vectorValue(iterator.index());
        for (int i = 0; i < components.length; i++) {
            builder.appendFloat(components[i] * scale);
        }
    }

    /**
     * Returns the magnitude stored for the document the iterator is positioned on, or {@code 1.0f} when there is none.
     * Doc values are absent if all vectors are normalized, and an individual vector that was already normalized may
     * have no stored magnitude, so a missing value is not an error.
     */
    private float magnitudeForCurrentDoc() throws IOException {
        if (magnitudeDocValues == null) {
            return 1.0f;
        }
        if (magnitudeDocValues.advanceExact(iterator.docID()) == false) {
            return 1.0f;
        }
        // The magnitude is stored as the raw int bits of a float inside a long doc value.
        return Float.intBitsToFloat((int) magnitudeDocValues.longValue());
    }

    @Override
    public String toString() {
        return "BlockDocValuesReader.FloatDenseVectorNormalizedValuesBlockReader";
    }
}

private static class ByteDenseVectorValuesBlockReader extends DenseVectorValuesBlockReader<ByteVectorValues> {
ByteDenseVectorValuesBlockReader(ByteVectorValues floatVectorValues, int dimensions) {
super(floatVectorValues, dimensions);
}

protected void appendDoc(BlockLoader.FloatBuilder builder) throws IOException {
byte[] bytes = vectorValues.vectorValue(iterator.index());
assert bytes.length == dimensions
: "unexpected dimensions for vector value; expected " + dimensions + " but got " + bytes.length;
for (byte aFloat : bytes) {
builder.appendFloat(aFloat);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,29 +24,30 @@

import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Set;
import java.util.stream.Collectors;

import static org.elasticsearch.index.IndexSettings.INDEX_MAPPER_SOURCE_MODE_SETTING;
import static org.elasticsearch.index.mapper.SourceFieldMapper.Mode.SYNTHETIC;
import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertAcked;

public class DenseVectorFieldTypeIT extends AbstractEsqlIntegTestCase {

public static final Set<String> ALL_DENSE_VECTOR_INDEX_TYPES = Set.of(
"int8_hnsw",
"hnsw",
"int4_hnsw",
"bbq_hnsw",
"int8_flat",
"int4_flat",
"bbq_flat",
"flat"
);
public static final Set<String> NON_QUANTIZED_DENSE_VECTOR_INDEX_TYPES = Set.of("hnsw", "flat");
public static final Set<String> ALL_DENSE_VECTOR_INDEX_TYPES = Arrays.stream(DenseVectorFieldMapper.VectorIndexType.values())
.filter(DenseVectorFieldMapper.VectorIndexType::isEnabled)
.map(v -> v.getName().toLowerCase(Locale.ROOT))
.collect(Collectors.toSet());

public static final Set<String> NON_QUANTIZED_DENSE_VECTOR_INDEX_TYPES = Arrays.stream(DenseVectorFieldMapper.VectorIndexType.values())
.filter(t -> t.isEnabled() && t.isQuantized() == false)
.map(v -> v.getName().toLowerCase(Locale.ROOT))
.collect(Collectors.toSet());

public static final float DELTA = 1e-7F;

private final ElementType elementType;
Expand All @@ -57,15 +58,10 @@ public class DenseVectorFieldTypeIT extends AbstractEsqlIntegTestCase {
@ParametersFactory
public static Iterable<Object[]> parameters() throws Exception {
List<Object[]> params = new ArrayList<>();
List<DenseVectorFieldMapper.VectorSimilarity> similarities = List.of(
DenseVectorFieldMapper.VectorSimilarity.DOT_PRODUCT,
DenseVectorFieldMapper.VectorSimilarity.L2_NORM,
DenseVectorFieldMapper.VectorSimilarity.MAX_INNER_PRODUCT
);

for (ElementType elementType : List.of(ElementType.BYTE, ElementType.FLOAT)) {
// Test all similarities for element types
for (DenseVectorFieldMapper.VectorSimilarity similarity : similarities) {
// Test all similarities
for (DenseVectorFieldMapper.VectorSimilarity similarity : DenseVectorFieldMapper.VectorSimilarity.values()) {
params.add(new Object[] { elementType, similarity, true, false });
}

Expand All @@ -74,6 +70,7 @@ public static Iterable<Object[]> parameters() throws Exception {
// No indexing, synthetic source
params.add(new Object[] { elementType, null, false, true });
}

return params;
}

Expand Down