From 763fe6392701e5612784c98307febc75974eea43 Mon Sep 17 00:00:00 2001 From: cdelgado Date: Tue, 12 Aug 2025 14:00:04 +0200 Subject: [PATCH 01/10] Take into account normalization for dense vector support --- .../index/mapper/BlockDocValuesReader.java | 100 ++++++++++++++- .../vectors/DenseVectorFieldMapper.java | 62 +++++----- .../xpack/esql/DenseVectorFieldTypeIT.java | 100 ++++++++++----- .../xpack/esql/plugin/KnnFunctionIT.java | 117 ++++++++++++------ 4 files changed, 270 insertions(+), 109 deletions(-) diff --git a/server/src/main/java/org/elasticsearch/index/mapper/BlockDocValuesReader.java b/server/src/main/java/org/elasticsearch/index/mapper/BlockDocValuesReader.java index 809bad5145fe6..3ad251af5ef47 100644 --- a/server/src/main/java/org/elasticsearch/index/mapper/BlockDocValuesReader.java +++ b/server/src/main/java/org/elasticsearch/index/mapper/BlockDocValuesReader.java @@ -29,11 +29,15 @@ import org.elasticsearch.index.mapper.BlockLoader.DoubleBuilder; import org.elasticsearch.index.mapper.BlockLoader.IntBuilder; import org.elasticsearch.index.mapper.BlockLoader.LongBuilder; +import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper; +import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper.ElementType; import org.elasticsearch.index.mapper.vectors.VectorEncoderDecoder; import org.elasticsearch.search.fetch.StoredFieldsSpec; import java.io.IOException; +import static org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper.COSINE_MAGNITUDE_FIELD_SUFFIX; + /** * A reader that supports reading doc-values from a Lucene segment in Block fashion. */ @@ -511,10 +515,12 @@ public String toString() { public static class DenseVectorBlockLoader extends DocValuesBlockLoader { private final String fieldName; private final int dimensions; + private final DenseVectorFieldMapper.DenseVectorFieldType fieldType; - public DenseVectorBlockLoader(String fieldName, int dimensions) { + public DenseVectorBlockLoader(String fieldName, int dimensions, DenseVectorFieldMapper.DenseVectorFieldType fieldType) { this.fieldName = fieldName; this.dimensions = dimensions; + this.fieldType = fieldType; } @Override @@ -524,9 +530,26 @@ public Builder builder(BlockFactory factory, int expectedCount) { @Override public AllReader reader(LeafReaderContext context) throws IOException { - FloatVectorValues floatVectorValues = context.reader().getFloatVectorValues(fieldName); - if (floatVectorValues != null) { - return new DenseVectorValuesBlockReader(floatVectorValues, dimensions); + switch (fieldType.getElementType()) { + case FLOAT -> { + FloatVectorValues floatVectorValues = context.reader().getFloatVectorValues(fieldName); + if (floatVectorValues != null) { + if (fieldType.isNormalized()) { + return new FloatDenseVectorNormalizedValuesBlockReader( + floatVectorValues, + dimensions, + context.reader().getNumericDocValues(fieldType.name() + COSINE_MAGNITUDE_FIELD_SUFFIX) + ); + } + return new FloatDenseVectorValuesBlockReader(floatVectorValues, dimensions); + } + } + case BYTE -> { + ByteVectorValues byteVectorValues = context.reader().getByteVectorValues(fieldName); + if (byteVectorValues != null) { + return new ByteDenseVectorValuesBlockReader(byteVectorValues, dimensions); + } + } } return new ConstantNullsReader(); } @@ -580,10 +603,77 @@ private void read(int doc, BlockLoader.FloatBuilder builder) throws IOException public int docId() { return iterator.docID(); } + } + + private static class FloatDenseVectorValuesBlockReader extends DenseVectorValuesBlockReader { + + FloatDenseVectorValuesBlockReader(FloatVectorValues floatVectorValues, int dimensions) { + super(floatVectorValues, dimensions); + } + + protected void appendDoc(BlockLoader.FloatBuilder builder) throws IOException { + float[] floats = vectorValues.vectorValue(iterator.index()); + assert floats.length == dimensions + : "unexpected dimensions for vector value; expected " + dimensions + " but got " + floats.length; + for (float aFloat : floats) { + builder.appendFloat(aFloat); + } + } + + @Override + public String toString() { + return "BlockDocValuesReader.FloatDenseVectorValuesBlockReader"; + } + } + + private static class FloatDenseVectorNormalizedValuesBlockReader extends DenseVectorValuesBlockReader { + private final NumericDocValues magnitudeDocValues; + + FloatDenseVectorNormalizedValuesBlockReader( + FloatVectorValues floatVectorValues, + int dimensions, + NumericDocValues magnitudeDocValues + ) { + super(floatVectorValues, dimensions); + this.magnitudeDocValues = magnitudeDocValues; + } + + @Override + protected void appendDoc(BlockLoader.FloatBuilder builder) throws IOException { + float[] floats = vectorValues.vectorValue(iterator.index()); + assert floats.length == dimensions + : "unexpected dimensions for vector value; expected " + dimensions + " but got " + floats.length; + + assert magnitudeDocValues.advanceExact(iterator.docID()); + float magnitude = Float.intBitsToFloat((int) magnitudeDocValues.longValue()); + for (float aFloat : floats) { + builder.appendFloat(aFloat * magnitude); + } + } + + @Override + public String toString() { + return "BlockDocValuesReader.FloatDenseVectorNormalizedValuesBlockReader"; + } + } + + private static class ByteDenseVectorValuesBlockReader extends DenseVectorValuesBlockReader { + ByteDenseVectorValuesBlockReader(ByteVectorValues floatVectorValues, int dimensions) { + super(floatVectorValues, dimensions); + } + + protected void appendDoc(BlockLoader.FloatBuilder builder) throws IOException { + byte[] bytes = vectorValues.vectorValue(iterator.index()); + assert bytes.length == dimensions + : "unexpected dimensions for vector value; expected " + dimensions + " but got " + bytes.length; + for (byte aFloat : bytes) { + builder.appendFloat(aFloat); + } + } @Override public String toString() { - return "BlockDocValuesReader.FloatVectorValuesBlockReader"; + return "BlockDocValuesReader.ByteDenseVectorValuesBlockReader"; } } diff --git a/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java b/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java index c9c14d027ebfd..2a3655a1100dd 100644 --- a/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java +++ b/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java @@ -734,31 +734,29 @@ IndexFieldData.Builder fielddataBuilder(DenseVectorFieldType denseVectorFieldTyp this, denseVectorFieldType.dims, denseVectorFieldType.indexed, - denseVectorFieldType.indexVersionCreated.onOrAfter(NORMALIZE_COSINE) - && denseVectorFieldType.indexed - && denseVectorFieldType.similarity.equals(VectorSimilarity.COSINE) ? r -> new FilterLeafReader(r) { - @Override - public CacheHelper getCoreCacheHelper() { - return r.getCoreCacheHelper(); - } + denseVectorFieldType.isNormalized() && denseVectorFieldType.indexed ? r -> new FilterLeafReader(r) { + @Override + public CacheHelper getCoreCacheHelper() { + return r.getCoreCacheHelper(); + } - @Override - public CacheHelper getReaderCacheHelper() { - return r.getReaderCacheHelper(); - } + @Override + public CacheHelper getReaderCacheHelper() { + return r.getReaderCacheHelper(); + } - @Override - public FloatVectorValues getFloatVectorValues(String fieldName) throws IOException { - FloatVectorValues values = in.getFloatVectorValues(fieldName); - if (values == null) { - return null; - } - return new DenormalizedCosineFloatVectorValues( - values, - in.getNumericDocValues(fieldName + COSINE_MAGNITUDE_FIELD_SUFFIX) - ); + @Override + public FloatVectorValues getFloatVectorValues(String fieldName) throws IOException { + FloatVectorValues values = in.getFloatVectorValues(fieldName); + if (values == null) { + return null; } - } : r -> r + return new DenormalizedCosineFloatVectorValues( + values, + in.getNumericDocValues(fieldName + COSINE_MAGNITUDE_FIELD_SUFFIX) + ); + } + } : r -> r ); } @@ -820,9 +818,7 @@ public void parseKnnVectorAndIndex(DocumentParserContext context, DenseVectorFie fieldMapper.checkDimensionMatches(index, context); checkVectorBounds(vector); checkVectorMagnitude(fieldMapper.fieldType().similarity, errorFloatElementsAppender(vector), squaredMagnitude); - if (fieldMapper.indexCreatedVersion.onOrAfter(NORMALIZE_COSINE) - && fieldMapper.fieldType().similarity.equals(VectorSimilarity.COSINE) - && isNotUnitVector(squaredMagnitude)) { + if (fieldMapper.fieldType().isNormalized() && isNotUnitVector(squaredMagnitude)) { float length = (float) Math.sqrt(squaredMagnitude); for (int i = 0; i < vector.length; i++) { vector[i] /= length; @@ -2491,6 +2487,10 @@ public Query createExactKnnQuery(VectorData queryVector, Float vectorSimilarity) return knnQuery; } + public boolean isNormalized() { + return indexVersionCreated.onOrAfter(NORMALIZE_COSINE) && VectorSimilarity.COSINE.equals(similarity); + } + private Query createExactKnnBitQuery(byte[] queryVector) { elementType.checkDimensions(dims, queryVector.length); return new DenseVectorQuery.Bytes(queryVector, name()); @@ -2511,9 +2511,7 @@ private Query createExactKnnFloatQuery(float[] queryVector) { if (similarity == VectorSimilarity.DOT_PRODUCT || similarity == VectorSimilarity.COSINE) { float squaredMagnitude = VectorUtil.dotProduct(queryVector, queryVector); elementType.checkVectorMagnitude(similarity, ElementType.errorFloatElementsAppender(queryVector), squaredMagnitude); - if (similarity == VectorSimilarity.COSINE - && indexVersionCreated.onOrAfter(NORMALIZE_COSINE) - && isNotUnitVector(squaredMagnitude)) { + if (isNormalized() && isNotUnitVector(squaredMagnitude)) { float length = (float) Math.sqrt(squaredMagnitude); queryVector = Arrays.copyOf(queryVector, queryVector.length); for (int i = 0; i < queryVector.length; i++) { @@ -2703,9 +2701,7 @@ private Query createKnnFloatQuery( if (similarity == VectorSimilarity.DOT_PRODUCT || similarity == VectorSimilarity.COSINE) { float squaredMagnitude = VectorUtil.dotProduct(queryVector, queryVector); elementType.checkVectorMagnitude(similarity, ElementType.errorFloatElementsAppender(queryVector), squaredMagnitude); - if (similarity == VectorSimilarity.COSINE - && indexVersionCreated.onOrAfter(NORMALIZE_COSINE) - && isNotUnitVector(squaredMagnitude)) { + if (isNormalized() && isNotUnitVector(squaredMagnitude)) { float length = (float) Math.sqrt(squaredMagnitude); queryVector = Arrays.copyOf(queryVector, queryVector.length); for (int i = 0; i < queryVector.length; i++) { @@ -2795,7 +2791,7 @@ int getVectorDimensions() { return dims; } - ElementType getElementType() { + public ElementType getElementType() { return elementType; } @@ -2816,7 +2812,7 @@ public BlockLoader blockLoader(MappedFieldType.BlockLoaderContext blContext) { } if (indexed) { - return new BlockDocValuesReader.DenseVectorBlockLoader(name(), dims); + return new BlockDocValuesReader.DenseVectorBlockLoader(name(), dims, this); } if (hasDocValues() && (blContext.fieldExtractPreference() != FieldExtractPreference.STORED || isSyntheticSource)) { diff --git a/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/DenseVectorFieldTypeIT.java b/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/DenseVectorFieldTypeIT.java index 0673856cbcc3b..87d2583e2285b 100644 --- a/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/DenseVectorFieldTypeIT.java +++ b/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/DenseVectorFieldTypeIT.java @@ -13,6 +13,9 @@ import org.elasticsearch.action.index.IndexRequestBuilder; import org.elasticsearch.cluster.metadata.IndexMetadata; import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper; +import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper.ElementType; +import org.elasticsearch.script.field.vectors.DenseVector; import org.elasticsearch.xcontent.XContentBuilder; import org.elasticsearch.xcontent.XContentFactory; import org.elasticsearch.xpack.esql.action.AbstractEsqlIntegTestCase; @@ -23,8 +26,11 @@ import java.util.ArrayList; import java.util.HashMap; import java.util.List; +import java.util.Locale; import java.util.Map; import java.util.Set; +import java.util.function.Function; +import java.util.function.Supplier; import static org.elasticsearch.index.IndexSettings.INDEX_MAPPER_SOURCE_MODE_SETTING; import static org.elasticsearch.index.mapper.SourceFieldMapper.Mode.SYNTHETIC; @@ -32,7 +38,7 @@ public class DenseVectorFieldTypeIT extends AbstractEsqlIntegTestCase { - private static final Set DENSE_VECTOR_INDEX_TYPES = Set.of( + public static final Set ALL_DENSE_VECTOR_INDEX_TYPES = Set.of( "int8_hnsw", "hnsw", "int4_hnsw", @@ -43,31 +49,46 @@ public class DenseVectorFieldTypeIT extends AbstractEsqlIntegTestCase { "flat" ); + public static final float DELTA = 1e-7F; + + private final ElementType elementType; + private final DenseVectorFieldMapper.VectorSimilarity similarity; + private final boolean synthetic; private final String indexType; private final boolean index; - private final boolean synthetic; @ParametersFactory public static Iterable parameters() throws Exception { List params = new ArrayList<>(); // Indexed field types - for (String indexType : DENSE_VECTOR_INDEX_TYPES) { - params.add(new Object[] { indexType, true, false }); - } + Supplier elementTypeProvider = () -> ElementType.FLOAT; + Function indexTypeProvider = e -> randomFrom(ALL_DENSE_VECTOR_INDEX_TYPES); + Supplier vectorSimilarityProvider = () -> randomFrom( + DenseVectorFieldMapper.VectorSimilarity.values() + ); + params.add(new Object[] { elementTypeProvider, indexTypeProvider, vectorSimilarityProvider, true, false }); // No indexing - params.add(new Object[] { null, false, false }); + params.add(new Object[] { elementTypeProvider, null, null, false, false }); // No indexing, synthetic source - params.add(new Object[] { null, false, true }); + params.add(new Object[] { elementTypeProvider, null, null, false, true }); return params; } - public DenseVectorFieldTypeIT(@Name("indexType") String indexType, @Name("index") boolean index, @Name("synthetic") boolean synthetic) { - this.indexType = indexType; + public DenseVectorFieldTypeIT( + @Name("elementType") Supplier elementTypeProvider, + @Name("indexType") Function indexTypeProvider, + @Name("similarity") Supplier similarityProvider, + @Name("index") boolean index, + @Name("synthetic") boolean synthetic + ) { + this.elementType = elementTypeProvider.get(); + this.indexType = indexTypeProvider == null ? null : indexTypeProvider.apply(this.elementType); + this.similarity = similarityProvider == null ? null : similarityProvider.get(); this.index = index; this.synthetic = synthetic; } - private final Map> indexedVectors = new HashMap<>(); + private final Map> indexedVectors = new HashMap<>(); public void testRetrieveFieldType() { var query = """ @@ -90,17 +111,17 @@ public void testRetrieveTopNDenseVectorFieldData() { try (var resp = run(query)) { List> valuesList = EsqlTestUtils.getValuesList(resp); - indexedVectors.forEach((id, vector) -> { + indexedVectors.forEach((id, expectedVector) -> { var values = valuesList.get(id); assertEquals(id, values.get(0)); - List vectors = (List) values.get(1); - if (vector == null) { - assertNull(vectors); + List actualVector = (List) values.get(1); + if (expectedVector == null) { + assertNull(actualVector); } else { - assertNotNull(vectors); - assertEquals(vector.size(), vectors.size()); - for (int i = 0; i < vector.size(); i++) { - assertEquals(vector.get(i), vectors.get(i), 0F); + assertNotNull(actualVector); + assertEquals(expectedVector.size(), actualVector.size()); + for (int i = 0; i < expectedVector.size(); i++) { + assertEquals(expectedVector.get(i).floatValue(), actualVector.get(i).floatValue(), DELTA); } } }); @@ -114,24 +135,31 @@ public void testRetrieveDenseVectorFieldData() { | KEEP id, vector """; + indexedVectors.forEach((i, v) -> { + System.out.println("ID: " + i + ", Vector: " + v); + }); + try (var resp = run(query)) { List> valuesList = EsqlTestUtils.getValuesList(resp); assertEquals(valuesList.size(), indexedVectors.size()); + // print all values for debugging valuesList.forEach(value -> { - ; assertEquals(2, value.size()); Integer id = (Integer) value.get(0); - List expectedVector = indexedVectors.get(id); - List vector = (List) value.get(1); + List expectedVector = indexedVectors.get(id); + List actualVector = (List) value.get(1); if (expectedVector == null) { - assertNull(vector); + assertNull(actualVector); } else { - assertNotNull(vector); - assertEquals(expectedVector.size(), vector.size()); - assertNotNull(vector); - assertNotNull(expectedVector); - for (int i = 0; i < vector.size(); i++) { - assertEquals(expectedVector.get(i), vector.get(i), 0F); + assertNotNull(actualVector); + assertEquals(expectedVector.size(), actualVector.size()); + for (int i = 0; i < actualVector.size(); i++) { + assertEquals( + "Actual: " + actualVector + "; expected: " + expectedVector, + expectedVector.get(i).floatValue(), + actualVector.get(i).floatValue(), + DELTA + ); } } }); @@ -177,13 +205,19 @@ public void setup() throws IOException { int numDocs = randomIntBetween(10, 100); IndexRequestBuilder[] docs = new IndexRequestBuilder[numDocs]; for (int i = 0; i < numDocs; i++) { - List vector = new ArrayList<>(numDims); + List vector = new ArrayList<>(numDims); if (rarely()) { docs[i] = prepareIndex("test").setId("" + i).setSource("id", String.valueOf(i)); indexedVectors.put(i, null); } else { for (int j = 0; j < numDims; j++) { - vector.add(randomFloat()); + vector.add(randomFloatBetween(0F, 1F, true)); + } + if (similarity == DenseVectorFieldMapper.VectorSimilarity.DOT_PRODUCT) { + // Normalize the vector + float magnitude = DenseVector.getMagnitude(vector); + vector.replaceAll(number -> number.floatValue() / magnitude); + } } docs[i] = prepareIndex("test").setId("" + i).setSource("id", String.valueOf(i), "vector", vector); indexedVectors.put(i, vector); @@ -203,9 +237,13 @@ private void createIndexWithDenseVector(String indexName) throws IOException { .endObject() .startObject("vector") .field("type", "dense_vector") + .field("element_type", elementType.toString().toLowerCase(Locale.ROOT)) .field("index", index); if (index) { - mapping.field("similarity", "l2_norm"); + mapping.field( + "similarity", + similarity.name().toLowerCase(Locale.ROOT) + ); } if (indexType != null) { mapping.startObject("index_options").field("type", indexType).endObject(); diff --git a/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/plugin/KnnFunctionIT.java b/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/plugin/KnnFunctionIT.java index 9ae1c980337f1..c2f8662ac502b 100644 --- a/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/plugin/KnnFunctionIT.java +++ b/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/plugin/KnnFunctionIT.java @@ -7,11 +7,15 @@ package org.elasticsearch.xpack.esql.plugin; +import com.carrotsearch.randomizedtesting.annotations.Name; +import com.carrotsearch.randomizedtesting.annotations.ParametersFactory; + import org.elasticsearch.action.index.IndexRequestBuilder; import org.elasticsearch.client.internal.IndicesAdminClient; import org.elasticsearch.cluster.metadata.IndexMetadata; import org.elasticsearch.common.settings.Settings; import org.elasticsearch.index.IndexSettings; +import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper; import org.elasticsearch.xcontent.XContentBuilder; import org.elasticsearch.xcontent.XContentFactory; import org.elasticsearch.xpack.esql.EsqlTestUtils; @@ -30,42 +34,73 @@ import static org.elasticsearch.index.IndexMode.LOOKUP; import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertAcked; +import static org.elasticsearch.xpack.esql.DenseVectorFieldTypeIT.ALL_DENSE_VECTOR_INDEX_TYPES; +import static org.elasticsearch.xpack.esql.DenseVectorFieldTypeIT.NON_QUANTIZED_DENSE_VECTOR_INDEX_TYPES; import static org.hamcrest.CoreMatchers.containsString; +import static org.hamcrest.Matchers.lessThanOrEqualTo; public class KnnFunctionIT extends AbstractEsqlIntegTestCase { - private final Map> indexedVectors = new HashMap<>(); + private final Map> indexedVectors = new HashMap<>(); private int numDocs; private int numDims; + private final DenseVectorFieldMapper.ElementType elementType; + private final String indexType; + + @ParametersFactory + public static Iterable parameters() throws Exception { + List params = new ArrayList<>(); + for (String indexType : ALL_DENSE_VECTOR_INDEX_TYPES) { + params.add(new Object[] { DenseVectorFieldMapper.ElementType.FLOAT, indexType }); + } + for (String indexType : NON_QUANTIZED_DENSE_VECTOR_INDEX_TYPES) { + params.add(new Object[] { DenseVectorFieldMapper.ElementType.BYTE, indexType }); + } + + // Remove flat index types, as knn does not do a top k for flat + params.removeIf(param -> param[1] != null && ((String) param[1]).contains("flat")); + return params; + } + + public KnnFunctionIT(@Name("elementType") DenseVectorFieldMapper.ElementType elementType, @Name("indexType") String indexType) { + this.elementType = elementType; + this.indexType = indexType; + } + public void testKnnDefaults() { float[] queryVector = new float[numDims]; - Arrays.fill(queryVector, 1.0f); + Arrays.fill(queryVector, 0.0f); var query = String.format(Locale.ROOT, """ FROM test METADATA _score | WHERE knn(vector, %s, 10) - | KEEP id, floats, _score, vector + | KEEP id, _score, vector | SORT _score DESC """, Arrays.toString(queryVector)); try (var resp = run(query)) { - assertColumnNames(resp.columns(), List.of("id", "floats", "_score", "vector")); - assertColumnTypes(resp.columns(), List.of("integer", "double", "double", "dense_vector")); + assertColumnNames(resp.columns(), List.of("id", "_score", "vector")); + assertColumnTypes(resp.columns(), List.of("integer", "double", "dense_vector")); List> valuesList = EsqlTestUtils.getValuesList(resp); assertEquals(Math.min(indexedVectors.size(), 10), valuesList.size()); - for (int i = 0; i < valuesList.size(); i++) { - List row = valuesList.get(i); - // Vectors should be in order of ID, as they're less similar than the query vector as the ID increases - assertEquals(i, row.getFirst()); + double previousScore = Float.MAX_VALUE; + for (List row : valuesList) { + // Vectors should be in score order + double currentScore = (Double) row.get(1); + assertThat(currentScore, lessThanOrEqualTo(previousScore)); + previousScore = currentScore; @SuppressWarnings("unchecked") // Vectors should be the same - List floats = (List) row.get(1); - for (int j = 0; j < floats.size(); j++) { - assertEquals(floats.get(j).floatValue(), indexedVectors.get(i).get(j), 0f); + List actualVector = (List) row.get(2); + List expectedVector = indexedVectors.get(row.get(0)); + for (int j = 0; j < actualVector.size(); j++) { + float expected = expectedVector.get(j).floatValue(); + float actual = actualVector.get(j).floatValue(); + assertEquals(expected, actual, 0f); } - var score = (Double) row.get(2); + var score = (Double) row.get(1); assertNotNull(score); assertTrue(score > 0.0); } @@ -74,18 +109,18 @@ public void testKnnDefaults() { public void testKnnOptions() { float[] queryVector = new float[numDims]; - Arrays.fill(queryVector, 1.0f); + Arrays.fill(queryVector, 0.0f); var query = String.format(Locale.ROOT, """ FROM test METADATA _score | WHERE knn(vector, %s, 5) - | KEEP id, floats, _score, vector + | KEEP id, _score, vector | SORT _score DESC """, Arrays.toString(queryVector)); try (var resp = run(query)) { - assertColumnNames(resp.columns(), List.of("id", "floats", "_score", "vector")); - assertColumnTypes(resp.columns(), List.of("integer", "double", "double", "dense_vector")); + assertColumnNames(resp.columns(), List.of("id", "_score", "vector")); + assertColumnTypes(resp.columns(), List.of("integer", "double", "dense_vector")); List> valuesList = EsqlTestUtils.getValuesList(resp); assertEquals(5, valuesList.size()); @@ -94,42 +129,41 @@ public void testKnnOptions() { public void testKnnNonPushedDown() { float[] queryVector = new float[numDims]; - Arrays.fill(queryVector, 1.0f); + Arrays.fill(queryVector, 0.0f); // TODO we need to decide what to do when / if user uses k for limit, as no more than k results will be returned from knn query var query = String.format(Locale.ROOT, """ FROM test METADATA _score - | WHERE knn(vector, %s, 5) OR id > 10 - | KEEP id, floats, _score, vector + | WHERE knn(vector, %s, 5) OR id > 100 + | KEEP id, _score, vector | SORT _score DESC """, Arrays.toString(queryVector)); try (var resp = run(query)) { - assertColumnNames(resp.columns(), List.of("id", "floats", "_score", "vector")); - assertColumnTypes(resp.columns(), List.of("integer", "double", "double", "dense_vector")); + assertColumnNames(resp.columns(), List.of("id", "_score", "vector")); + assertColumnTypes(resp.columns(), List.of("integer", "double", "dense_vector")); List> valuesList = EsqlTestUtils.getValuesList(resp); - // K = 5, 1 more for every id > 10 - assertEquals(5 + Math.max(0, numDocs - 10 - 1), valuesList.size()); + assertEquals(5, valuesList.size()); } } public void testKnnWithPrefilters() { float[] queryVector = new float[numDims]; - Arrays.fill(queryVector, 1.0f); + Arrays.fill(queryVector, 0.0f); // We retrieve 5 from knn, but must be prefiltered with id > 5 or no result will be returned as it would be post-filtered var query = String.format(Locale.ROOT, """ FROM test METADATA _score - | WHERE knn(vector, %s, 5) AND id > 5 - | KEEP id, floats, _score, vector + | WHERE knn(vector, %s, 5) AND id > 5 AND id <= 10 + | KEEP id, _score, vector | SORT _score DESC | LIMIT 5 """, Arrays.toString(queryVector)); try (var resp = run(query)) { - assertColumnNames(resp.columns(), List.of("id", "floats", "_score", "vector")); - assertColumnTypes(resp.columns(), List.of("integer", "double", "double", "dense_vector")); + assertColumnNames(resp.columns(), List.of("id", "_score", "vector")); + assertColumnTypes(resp.columns(), List.of("integer", "double", "dense_vector")); List> valuesList = EsqlTestUtils.getValuesList(resp); // K = 5, 1 more for every id > 10 @@ -139,12 +173,12 @@ public void testKnnWithPrefilters() { public void testKnnWithLookupJoin() { float[] queryVector = new float[numDims]; - Arrays.fill(queryVector, 1.0f); + Arrays.fill(queryVector, 0.0f); var query = String.format(Locale.ROOT, """ FROM test | LOOKUP JOIN test_lookup ON id - | WHERE KNN(lookup_vector, %s, 5) OR id > 10 + | WHERE KNN(lookup_vector, %s, 5) OR id > 100 """, Arrays.toString(queryVector)); var error = expectThrows(VerificationException.class, () -> run(query)); @@ -171,10 +205,14 @@ public void setup() throws IOException { .endObject() .startObject("vector") .field("type", "dense_vector") - .field("similarity", "l2_norm") + .field( + "similarity", + // Let's not use others to avoid vector normalization + randomFrom("l2_norm", "max_inner_product") + ) + .startObject("index_options") + .field("type", indexType) .endObject() - .startObject("floats") - .field("type", "float") .endObject() .endObject() .endObject(); @@ -186,16 +224,15 @@ public void setup() throws IOException { var createRequest = client.prepareCreate(indexName).setMapping(mapping).setSettings(settingsBuilder.build()); assertAcked(createRequest); - numDocs = randomIntBetween(15, 25); - numDims = randomIntBetween(3, 10); + numDocs = randomIntBetween(20, 35); + numDims = 64 + randomIntBetween(1, 10) * 2; IndexRequestBuilder[] docs = new IndexRequestBuilder[numDocs]; - float value = 0.0f; for (int i = 0; i < numDocs; i++) { - List vector = new ArrayList<>(numDims); + List vector = new ArrayList<>(numDims); for (int j = 0; j < numDims; j++) { - vector.add(value++); + vector.add(randomFloatBetween(0F, 1F, true)); } - docs[i] = prepareIndex("test").setId("" + i).setSource("id", String.valueOf(i), "floats", vector, "vector", vector); + docs[i] = prepareIndex("test").setId(String.valueOf(i)).setSource("id", String.valueOf(i), "vector", vector); indexedVectors.put(i, vector); } From 80b48cf03b5c11f005c6f19134782a9546c67a9b Mon Sep 17 00:00:00 2001 From: cdelgado Date: Tue, 12 Aug 2025 14:20:02 +0200 Subject: [PATCH 02/10] Fix cherry pick --- .../index/mapper/BlockDocValuesReader.java | 74 ++++++------------- .../xpack/esql/DenseVectorFieldTypeIT.java | 3 +- .../xpack/esql/plugin/KnnFunctionIT.java | 4 - 3 files changed, 22 insertions(+), 59 deletions(-) diff --git a/server/src/main/java/org/elasticsearch/index/mapper/BlockDocValuesReader.java b/server/src/main/java/org/elasticsearch/index/mapper/BlockDocValuesReader.java index 3ad251af5ef47..5ae2162192e28 100644 --- a/server/src/main/java/org/elasticsearch/index/mapper/BlockDocValuesReader.java +++ b/server/src/main/java/org/elasticsearch/index/mapper/BlockDocValuesReader.java @@ -30,7 +30,6 @@ import org.elasticsearch.index.mapper.BlockLoader.IntBuilder; import org.elasticsearch.index.mapper.BlockLoader.LongBuilder; import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper; -import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper.ElementType; import org.elasticsearch.index.mapper.vectors.VectorEncoderDecoder; import org.elasticsearch.search.fetch.StoredFieldsSpec; @@ -530,39 +529,31 @@ public Builder builder(BlockFactory factory, int expectedCount) { @Override public AllReader reader(LeafReaderContext context) throws IOException { - switch (fieldType.getElementType()) { - case FLOAT -> { - FloatVectorValues floatVectorValues = context.reader().getFloatVectorValues(fieldName); - if (floatVectorValues != null) { - if (fieldType.isNormalized()) { - return new FloatDenseVectorNormalizedValuesBlockReader( - floatVectorValues, - dimensions, - context.reader().getNumericDocValues(fieldType.name() + COSINE_MAGNITUDE_FIELD_SUFFIX) - ); - } - return new FloatDenseVectorValuesBlockReader(floatVectorValues, dimensions); - } - } - case BYTE -> { - ByteVectorValues byteVectorValues = context.reader().getByteVectorValues(fieldName); - if (byteVectorValues != null) { - return new ByteDenseVectorValuesBlockReader(byteVectorValues, dimensions); - } + FloatVectorValues floatVectorValues = context.reader().getFloatVectorValues(fieldName); + if (floatVectorValues != null) { + if (fieldType.isNormalized()) { + return new FloatDenseVectorNormalizedValuesBlockReader( + floatVectorValues, + dimensions, + context.reader().getNumericDocValues(fieldType.name() + COSINE_MAGNITUDE_FIELD_SUFFIX) + ); } + return new FloatDenseVectorValuesBlockReader(floatVectorValues, dimensions); } + return new ConstantNullsReader(); } } - private static class DenseVectorValuesBlockReader extends BlockDocValuesReader { - private final FloatVectorValues floatVectorValues; - private final KnnVectorValues.DocIndexIterator iterator; - private final int dimensions; + private abstract static class DenseVectorValuesBlockReader extends BlockDocValuesReader { - DenseVectorValuesBlockReader(FloatVectorValues floatVectorValues, int dimensions) { - this.floatVectorValues = floatVectorValues; - iterator = floatVectorValues.iterator(); + protected final T vectorValues; + protected final KnnVectorValues.DocIndexIterator iterator; + protected final int dimensions; + + DenseVectorValuesBlockReader(T vectorValues, int dimensions) { + this.vectorValues = vectorValues; + iterator = vectorValues.iterator(); this.dimensions = dimensions; } @@ -587,18 +578,15 @@ private void read(int doc, BlockLoader.FloatBuilder builder) throws IOException builder.appendNull(); } else if (iterator.docID() == doc || iterator.advance(doc) == doc) { builder.beginPositionEntry(); - float[] floats = floatVectorValues.vectorValue(iterator.index()); - assert floats.length == dimensions - : "unexpected dimensions for vector value; expected " + dimensions + " but got " + floats.length; - for (float aFloat : floats) { - builder.appendFloat(aFloat); - } + appendDoc(builder); builder.endPositionEntry(); } else { builder.appendNull(); } } + protected abstract void appendDoc(BlockLoader.FloatBuilder builder) throws IOException; + @Override public int docId() { return iterator.docID(); @@ -657,26 +645,6 @@ public String toString() { } } - private static class ByteDenseVectorValuesBlockReader extends DenseVectorValuesBlockReader { - ByteDenseVectorValuesBlockReader(ByteVectorValues floatVectorValues, int dimensions) { - super(floatVectorValues, dimensions); - } - - protected void appendDoc(BlockLoader.FloatBuilder builder) throws IOException { - byte[] bytes = vectorValues.vectorValue(iterator.index()); - assert bytes.length == dimensions - : "unexpected dimensions for vector value; expected " + dimensions + " but got " + bytes.length; - for (byte aFloat : bytes) { - builder.appendFloat(aFloat); - } - } - - @Override - public String toString() { - return "BlockDocValuesReader.ByteDenseVectorValuesBlockReader"; - } - } - public static class BytesRefsFromOrdsBlockLoader extends DocValuesBlockLoader { private final String fieldName; diff --git a/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/DenseVectorFieldTypeIT.java b/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/DenseVectorFieldTypeIT.java index 87d2583e2285b..97addbecd4564 100644 --- a/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/DenseVectorFieldTypeIT.java +++ b/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/DenseVectorFieldTypeIT.java @@ -216,8 +216,7 @@ public void setup() throws IOException { if (similarity == DenseVectorFieldMapper.VectorSimilarity.DOT_PRODUCT) { // Normalize the vector float magnitude = DenseVector.getMagnitude(vector); - vector.replaceAll(number -> number.floatValue() / magnitude); - } + vector.replaceAll(number -> number.floatValue() / magnitude); } docs[i] = prepareIndex("test").setId("" + i).setSource("id", String.valueOf(i), "vector", vector); indexedVectors.put(i, vector); diff --git a/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/plugin/KnnFunctionIT.java b/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/plugin/KnnFunctionIT.java index c2f8662ac502b..4d630d95b264b 100644 --- a/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/plugin/KnnFunctionIT.java +++ b/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/plugin/KnnFunctionIT.java @@ -35,7 +35,6 @@ import static org.elasticsearch.index.IndexMode.LOOKUP; import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertAcked; import static org.elasticsearch.xpack.esql.DenseVectorFieldTypeIT.ALL_DENSE_VECTOR_INDEX_TYPES; -import static org.elasticsearch.xpack.esql.DenseVectorFieldTypeIT.NON_QUANTIZED_DENSE_VECTOR_INDEX_TYPES; import static org.hamcrest.CoreMatchers.containsString; import static org.hamcrest.Matchers.lessThanOrEqualTo; @@ -54,9 +53,6 @@ public static Iterable parameters() throws Exception { for (String indexType : ALL_DENSE_VECTOR_INDEX_TYPES) { params.add(new Object[] { DenseVectorFieldMapper.ElementType.FLOAT, indexType }); } - for (String indexType : NON_QUANTIZED_DENSE_VECTOR_INDEX_TYPES) { - params.add(new Object[] { DenseVectorFieldMapper.ElementType.BYTE, indexType }); - } // Remove flat index types, as knn does not do a top k for flat params.removeIf(param -> param[1] != null && ((String) param[1]).contains("flat")); From 40edca3f893b232dfb6170093775ed0438e6cdc4 Mon Sep 17 00:00:00 2001 From: elasticsearchmachine Date: Tue, 12 Aug 2025 12:35:32 +0000 Subject: [PATCH 03/10] [CI] Auto commit changes from spotless --- .../elasticsearch/xpack/esql/DenseVectorFieldTypeIT.java | 9 ++------- 1 file changed, 2 insertions(+), 7 deletions(-) diff --git a/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/DenseVectorFieldTypeIT.java b/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/DenseVectorFieldTypeIT.java index 97addbecd4564..00f890b816aea 100644 --- a/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/DenseVectorFieldTypeIT.java +++ b/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/DenseVectorFieldTypeIT.java @@ -135,9 +135,7 @@ public void testRetrieveDenseVectorFieldData() { | KEEP id, vector """; - indexedVectors.forEach((i, v) -> { - System.out.println("ID: " + i + ", Vector: " + v); - }); + indexedVectors.forEach((i, v) -> { System.out.println("ID: " + i + ", Vector: " + v); }); try (var resp = run(query)) { List> valuesList = EsqlTestUtils.getValuesList(resp); @@ -239,10 +237,7 @@ private void createIndexWithDenseVector(String indexName) throws IOException { .field("element_type", elementType.toString().toLowerCase(Locale.ROOT)) .field("index", index); if (index) { - mapping.field( - "similarity", - similarity.name().toLowerCase(Locale.ROOT) - ); + mapping.field("similarity", similarity.name().toLowerCase(Locale.ROOT)); } if (indexType != null) { mapping.startObject("index_options").field("type", indexType).endObject(); From 8bd7f7914e2f01b2d3dc56a9221f070cd229fe9b Mon Sep 17 00:00:00 2001 From: cdelgado Date: Tue, 12 Aug 2025 15:21:55 +0200 Subject: [PATCH 04/10] Remove debugging code --- .../org/elasticsearch/xpack/esql/DenseVectorFieldTypeIT.java | 4 ---- 1 file changed, 4 deletions(-) diff --git a/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/DenseVectorFieldTypeIT.java b/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/DenseVectorFieldTypeIT.java index 97addbecd4564..cd6a00d613de1 100644 --- a/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/DenseVectorFieldTypeIT.java +++ b/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/DenseVectorFieldTypeIT.java @@ -135,10 +135,6 @@ public void testRetrieveDenseVectorFieldData() { | KEEP id, vector """; - indexedVectors.forEach((i, v) -> { - System.out.println("ID: " + i + ", Vector: " + v); - }); - try (var resp = run(query)) { List> valuesList = EsqlTestUtils.getValuesList(resp); assertEquals(valuesList.size(), indexedVectors.size()); From 7d2625cee55b5b05257fb61eaa4543742dd6eb14 Mon Sep 17 00:00:00 2001 From: cdelgado Date: Tue, 12 Aug 2025 16:52:38 +0200 Subject: [PATCH 05/10] Check that we may not have magnitudes at all, or for normalized vectors --- .../index/mapper/BlockDocValuesReader.java | 16 +++++++++------- .../xpack/esql/DenseVectorFieldTypeIT.java | 2 +- 2 files changed, 10 insertions(+), 8 deletions(-) diff --git a/server/src/main/java/org/elasticsearch/index/mapper/BlockDocValuesReader.java b/server/src/main/java/org/elasticsearch/index/mapper/BlockDocValuesReader.java index 5ae2162192e28..6d869c4d394f8 100644 --- a/server/src/main/java/org/elasticsearch/index/mapper/BlockDocValuesReader.java +++ b/server/src/main/java/org/elasticsearch/index/mapper/BlockDocValuesReader.java @@ -532,11 +532,9 @@ public AllReader reader(LeafReaderContext context) throws IOException { FloatVectorValues floatVectorValues = context.reader().getFloatVectorValues(fieldName); if (floatVectorValues != null) { if (fieldType.isNormalized()) { - return new FloatDenseVectorNormalizedValuesBlockReader( - floatVectorValues, - dimensions, - context.reader().getNumericDocValues(fieldType.name() + COSINE_MAGNITUDE_FIELD_SUFFIX) - ); + NumericDocValues magnitudeDocValues = context.reader() + .getNumericDocValues(fieldType.name() + COSINE_MAGNITUDE_FIELD_SUFFIX); + return new FloatDenseVectorNormalizedValuesBlockReader(floatVectorValues, dimensions, magnitudeDocValues); } return new FloatDenseVectorValuesBlockReader(floatVectorValues, dimensions); } @@ -632,8 +630,12 @@ protected void appendDoc(BlockLoader.FloatBuilder builder) throws IOException { assert floats.length == dimensions : "unexpected dimensions for vector value; expected " + dimensions + " but got " + floats.length; - assert magnitudeDocValues.advanceExact(iterator.docID()); - float magnitude = Float.intBitsToFloat((int) magnitudeDocValues.longValue()); + float magnitude = 1.0f; + // If all vectors are normalized, no doc values will be present. The vector may be normalized already, so we may not have a + // stored magnitude for all docs + if ((magnitudeDocValues != null) && magnitudeDocValues.advanceExact(iterator.docID())) { + magnitude = Float.intBitsToFloat((int) magnitudeDocValues.longValue()); + } for (float aFloat : floats) { builder.appendFloat(aFloat * magnitude); } diff --git a/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/DenseVectorFieldTypeIT.java b/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/DenseVectorFieldTypeIT.java index 566f130dd9cc5..346fe51daebb6 100644 --- a/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/DenseVectorFieldTypeIT.java +++ b/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/DenseVectorFieldTypeIT.java @@ -209,7 +209,7 @@ public void setup() throws IOException { for (int j = 0; j < numDims; j++) { vector.add(randomFloatBetween(0F, 1F, true)); } - if (similarity == DenseVectorFieldMapper.VectorSimilarity.DOT_PRODUCT) { + if ((similarity == DenseVectorFieldMapper.VectorSimilarity.DOT_PRODUCT) || rarely()) { // Normalize the vector float magnitude = DenseVector.getMagnitude(vector); vector.replaceAll(number -> number.floatValue() / magnitude); From 4371b69ed2398292dcdd81bd243e78d8161a634a Mon Sep 17 00:00:00 2001 From: cdelgado Date: Wed, 13 Aug 2025 11:39:25 +0200 Subject: [PATCH 06/10] Better parameterized test --- .../xpack/esql/DenseVectorFieldTypeIT.java | 40 +++++++++---------- 1 file changed, 19 insertions(+), 21 deletions(-) diff --git a/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/DenseVectorFieldTypeIT.java b/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/DenseVectorFieldTypeIT.java index 346fe51daebb6..3291f5303ba2c 100644 --- a/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/DenseVectorFieldTypeIT.java +++ b/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/DenseVectorFieldTypeIT.java @@ -29,8 +29,6 @@ import java.util.Locale; import java.util.Map; import java.util.Set; -import java.util.function.Function; -import java.util.function.Supplier; import static org.elasticsearch.index.IndexSettings.INDEX_MAPPER_SOURCE_MODE_SETTING; import static org.elasticsearch.index.mapper.SourceFieldMapper.Mode.SYNTHETIC; @@ -54,36 +52,33 @@ public class DenseVectorFieldTypeIT extends AbstractEsqlIntegTestCase { private final ElementType elementType; private final DenseVectorFieldMapper.VectorSimilarity similarity; private final boolean synthetic; - private final String indexType; private final boolean index; @ParametersFactory public static Iterable parameters() throws Exception { List params = new ArrayList<>(); - // Indexed field types - Supplier elementTypeProvider = () -> ElementType.FLOAT; - Function indexTypeProvider = e -> randomFrom(ALL_DENSE_VECTOR_INDEX_TYPES); - Supplier vectorSimilarityProvider = () -> randomFrom( - DenseVectorFieldMapper.VectorSimilarity.values() - ); - params.add(new Object[] { elementTypeProvider, indexTypeProvider, vectorSimilarityProvider, true, false }); + + // Test all similarities + for (DenseVectorFieldMapper.VectorSimilarity similarity : DenseVectorFieldMapper.VectorSimilarity.values()) { + params.add(new Object[] { ElementType.FLOAT, similarity, true, false }); + } + // No indexing - params.add(new Object[] { elementTypeProvider, null, null, false, false }); + params.add(new Object[] { ElementType.FLOAT, null, false, false }); // No indexing, synthetic source - params.add(new Object[] { elementTypeProvider, null, null, false, true }); + params.add(new Object[] { ElementType.FLOAT, null, false, true }); + return params; } public DenseVectorFieldTypeIT( - @Name("elementType") Supplier elementTypeProvider, - @Name("indexType") Function indexTypeProvider, - @Name("similarity") Supplier similarityProvider, + @Name("elementType") ElementType elementType, + @Name("similarity") DenseVectorFieldMapper.VectorSimilarity similarity, @Name("index") boolean index, @Name("synthetic") boolean synthetic ) { - this.elementType = elementTypeProvider.get(); - this.indexType = indexTypeProvider == null ? null : indexTypeProvider.apply(this.elementType); - this.similarity = similarityProvider == null ? null : similarityProvider.get(); + this.elementType = elementType; + this.similarity = similarity; this.index = index; this.synthetic = synthetic; } @@ -207,7 +202,11 @@ public void setup() throws IOException { indexedVectors.put(i, null); } else { for (int j = 0; j < numDims; j++) { - vector.add(randomFloatBetween(0F, 1F, true)); + switch (elementType) { + case FLOAT -> vector.add(randomFloatBetween(0F, 1F, true)); + case BYTE -> vector.add((byte) (randomFloatBetween(0F, 1F, true) * 127.0f)); + default -> throw new IllegalArgumentException("Unexpected element type: " + elementType); + } } if ((similarity == DenseVectorFieldMapper.VectorSimilarity.DOT_PRODUCT) || rarely()) { // Normalize the vector @@ -236,8 +235,7 @@ private void createIndexWithDenseVector(String indexName) throws IOException { .field("index", index); if (index) { mapping.field("similarity", similarity.name().toLowerCase(Locale.ROOT)); - } - if (indexType != null) { + String indexType = randomFrom(ALL_DENSE_VECTOR_INDEX_TYPES); mapping.startObject("index_options").field("type", indexType).endObject(); } mapping.endObject().endObject().endObject(); From e98bbcb784c273e7d51d4264d8b6bed74b58e909 Mon Sep 17 00:00:00 2001 From: cdelgado Date: Wed, 13 Aug 2025 16:51:15 +0200 Subject: [PATCH 07/10] Refactor dimension check --- .../index/mapper/BlockDocValuesReader.java | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/server/src/main/java/org/elasticsearch/index/mapper/BlockDocValuesReader.java b/server/src/main/java/org/elasticsearch/index/mapper/BlockDocValuesReader.java index 473f9a368070f..4a17ccc46ddac 100644 --- a/server/src/main/java/org/elasticsearch/index/mapper/BlockDocValuesReader.java +++ b/server/src/main/java/org/elasticsearch/index/mapper/BlockDocValuesReader.java @@ -577,6 +577,9 @@ public void read(int docId, BlockLoader.StoredFields storedFields, Builder build } private void read(int doc, BlockLoader.FloatBuilder builder) throws IOException { + assert vectorValues.dimension() == dimensions + : "unexpected dimensions for vector value; expected " + dimensions + " but got " + vectorValues.dimension(); + if (iterator.docID() > doc) { builder.appendNull(); } else if (iterator.docID() == doc || iterator.advance(doc) == doc) { @@ -604,8 +607,6 @@ private static class FloatDenseVectorValuesBlockReader extends DenseVectorValues protected void appendDoc(BlockLoader.FloatBuilder builder) throws IOException { float[] floats = vectorValues.vectorValue(iterator.index()); - assert floats.length == dimensions - : "unexpected dimensions for vector value; expected " + dimensions + " but got " + floats.length; for (float aFloat : floats) { builder.appendFloat(aFloat); } @@ -631,16 +632,13 @@ private static class FloatDenseVectorNormalizedValuesBlockReader extends DenseVe @Override protected void appendDoc(BlockLoader.FloatBuilder builder) throws IOException { - float[] floats = vectorValues.vectorValue(iterator.index()); - assert floats.length == dimensions - : "unexpected dimensions for vector value; expected " + dimensions + " but got " + floats.length; - float magnitude = 1.0f; // If all vectors are normalized, no doc values will be present. The vector may be normalized already, so we may not have a // stored magnitude for all docs if ((magnitudeDocValues != null) && magnitudeDocValues.advanceExact(iterator.docID())) { magnitude = Float.intBitsToFloat((int) magnitudeDocValues.longValue()); } + float[] floats = vectorValues.vectorValue(iterator.index()); for (float aFloat : floats) { builder.appendFloat(aFloat * magnitude); } From 0cb587a0f817340de93654a3a718a08824e33e1e Mon Sep 17 00:00:00 2001 From: cdelgado Date: Wed, 13 Aug 2025 16:57:33 +0200 Subject: [PATCH 08/10] Refactor index names --- .../xpack/esql/DenseVectorFieldTypeIT.java | 16 ++++++---------- 1 file changed, 6 insertions(+), 10 deletions(-) diff --git a/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/DenseVectorFieldTypeIT.java b/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/DenseVectorFieldTypeIT.java index 3291f5303ba2c..a5ecb9335525d 100644 --- a/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/DenseVectorFieldTypeIT.java +++ b/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/DenseVectorFieldTypeIT.java @@ -24,11 +24,13 @@ import java.io.IOException; import java.util.ArrayList; +import java.util.Arrays; import java.util.HashMap; import java.util.List; import java.util.Locale; import java.util.Map; import java.util.Set; +import java.util.stream.Collectors; import static org.elasticsearch.index.IndexSettings.INDEX_MAPPER_SOURCE_MODE_SETTING; import static org.elasticsearch.index.mapper.SourceFieldMapper.Mode.SYNTHETIC; @@ -36,16 +38,10 @@ public class DenseVectorFieldTypeIT extends AbstractEsqlIntegTestCase { - public static final Set ALL_DENSE_VECTOR_INDEX_TYPES = Set.of( - "int8_hnsw", - "hnsw", - "int4_hnsw", - "bbq_hnsw", - "int8_flat", - "int4_flat", - "bbq_flat", - "flat" - ); + public static final Set ALL_DENSE_VECTOR_INDEX_TYPES = Arrays.stream(DenseVectorFieldMapper.VectorIndexType.values()) + .filter(DenseVectorFieldMapper.VectorIndexType::isEnabled) + .map(v -> v.getName().toLowerCase(Locale.ROOT)) + .collect(Collectors.toSet()); public static final float DELTA = 1e-7F; From e16c7cfecb1fe3aa84bfaf29b6eef9f52c993a34 Mon Sep 17 00:00:00 2001 From: cdelgado Date: Wed, 13 Aug 2025 18:00:24 +0200 Subject: [PATCH 09/10] Remove comment --- .../org/elasticsearch/xpack/esql/DenseVectorFieldTypeIT.java | 1 - 1 file changed, 1 deletion(-) diff --git a/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/DenseVectorFieldTypeIT.java b/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/DenseVectorFieldTypeIT.java index a5ecb9335525d..94af0da65f74e 100644 --- a/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/DenseVectorFieldTypeIT.java +++ b/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/DenseVectorFieldTypeIT.java @@ -129,7 +129,6 @@ public void testRetrieveDenseVectorFieldData() { try (var resp = run(query)) { List> valuesList = EsqlTestUtils.getValuesList(resp); assertEquals(valuesList.size(), indexedVectors.size()); - // print all values for debugging valuesList.forEach(value -> { assertEquals(2, value.size()); Integer id = (Integer) value.get(0); From 8936d3328fe87997080e8b0cae2cc6ea6cc9ee19 Mon Sep 17 00:00:00 2001 From: cdelgado Date: Thu, 14 Aug 2025 13:29:49 +0200 Subject: [PATCH 10/10] Fix test --- .../xpack/esql/DenseVectorFieldTypeIT.java | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/DenseVectorFieldTypeIT.java b/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/DenseVectorFieldTypeIT.java index 743f96dc83dcb..ad482bfa1b60c 100644 --- a/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/DenseVectorFieldTypeIT.java +++ b/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/DenseVectorFieldTypeIT.java @@ -43,8 +43,10 @@ public class DenseVectorFieldTypeIT extends AbstractEsqlIntegTestCase { .map(v -> v.getName().toLowerCase(Locale.ROOT)) .collect(Collectors.toSet()); - public static final Set NON_QUANTIZED_DENSE_VECTOR_INDEX_TYPES = Set.of("hnsw", "flat"); - + public static final Set NON_QUANTIZED_DENSE_VECTOR_INDEX_TYPES = Arrays.stream(DenseVectorFieldMapper.VectorIndexType.values()) + .filter(t -> t.isEnabled() && t.isQuantized() == false) + .map(v -> v.getName().toLowerCase(Locale.ROOT)) + .collect(Collectors.toSet()); public static final float DELTA = 1e-7F; @@ -58,15 +60,15 @@ public static Iterable parameters() throws Exception { List params = new ArrayList<>(); for (ElementType elementType : List.of(ElementType.BYTE, ElementType.FLOAT)) { - // Test all similarities + // Test all similarities for (DenseVectorFieldMapper.VectorSimilarity similarity : DenseVectorFieldMapper.VectorSimilarity.values()) { - params.add(new Object[] { ElementType.FLOAT, similarity, true, false }); + params.add(new Object[] { elementType, similarity, true, false }); } // No indexing - params.add(new Object[] { ElementType.FLOAT, null, false, false }); + params.add(new Object[] { elementType, null, false, false }); // No indexing, synthetic source - params.add(new Object[] { ElementType.FLOAT, null, false, true }); + params.add(new Object[] { elementType, null, false, true }); } return params;