diff --git a/core/src/main/java/com/redis/vl/query/MultiVectorQuery.java b/core/src/main/java/com/redis/vl/query/MultiVectorQuery.java new file mode 100644 index 0000000..72f77e9 --- /dev/null +++ b/core/src/main/java/com/redis/vl/query/MultiVectorQuery.java @@ -0,0 +1,303 @@ +package com.redis.vl.query; + +import com.redis.vl.utils.ArrayUtils; +import java.util.*; +import lombok.Getter; + +/** + * MultiVectorQuery allows for search over multiple vector fields in a document simultaneously. + * + *

The final score will be a weighted combination of the individual vector similarity scores + * following the formula: + * + *

score = (w_1 * score_1 + w_2 * score_2 + w_3 * score_3 + ... ) + * + *

Vectors may be of different size and datatype, but must be indexed using the 'cosine' + * distance_metric. + * + *

Ported from Python: redisvl/query/aggregate.py:257-400 (MultiVectorQuery class) + * + *

Python equivalent: + * + *

+ * from redisvl.query import MultiVectorQuery, Vector
+ *
+ * vector_1 = Vector(vector=[0.1, 0.2, 0.3], field_name="text_vector", dtype="float32", weight=0.7)
+ * vector_2 = Vector(vector=[0.5, 0.5], field_name="image_vector", dtype="bfloat16", weight=0.2)
+ *
+ * query = MultiVectorQuery(
+ *     vectors=[vector_1, vector_2],
+ *     filter_expression=None,
+ *     num_results=10,
+ *     return_fields=["field1", "field2"],
+ *     dialect=2
+ * )
+ * 
+ * + * Java equivalent: + * + *
+ * Vector vector1 = Vector.builder()
+ *     .vector(new float[]{0.1f, 0.2f, 0.3f})
+ *     .fieldName("text_vector")
+ *     .dtype("float32")
+ *     .weight(0.7)
+ *     .build();
+ *
+ * Vector vector2 = Vector.builder()
+ *     .vector(new float[]{0.5f, 0.5f})
+ *     .fieldName("image_vector")
+ *     .dtype("bfloat16")
+ *     .weight(0.2)
+ *     .build();
+ *
+ * MultiVectorQuery query = MultiVectorQuery.builder()
+ *     .vectors(Arrays.asList(vector1, vector2))
+ *     .numResults(10)
+ *     .returnFields(Arrays.asList("field1", "field2"))
+ *     .build();
+ * 
+ */ +@Getter +public final class MultiVectorQuery { + + /** Distance threshold for VECTOR_RANGE (hardcoded at 2.0 to include all eligible documents) */ + private static final double DISTANCE_THRESHOLD = 2.0; + + private final List vectors; + private final Filter filterExpression; + private final List returnFields; + private final int numResults; + private final int dialect; + + private MultiVectorQuery(Builder builder) { + // Validate before modifying state + if (builder.vectors == null || builder.vectors.isEmpty()) { + throw new IllegalArgumentException("At least one Vector is required"); + } + + // Validate all elements are Vector objects + for (Vector v : builder.vectors) { + if (v == null) { + throw new IllegalArgumentException("Vector list cannot contain null values"); + } + } + + this.vectors = List.copyOf(builder.vectors); + this.filterExpression = builder.filterExpression; + this.returnFields = + builder.returnFields != null ? List.copyOf(builder.returnFields) : List.of(); + this.numResults = builder.numResults; + this.dialect = builder.dialect; + } + + /** + * Create a new Builder for MultiVectorQuery. + * + * @return A new Builder instance + */ + public static Builder builder() { + return new Builder(); + } + + /** + * Build the Redis query string for multi-vector search. + * + *

Format: {@code @field1:[VECTOR_RANGE 2.0 $vector_0]=>{$YIELD_DISTANCE_AS: distance_0} | + * @field2:[VECTOR_RANGE 2.0 $vector_1]=>{$YIELD_DISTANCE_AS: distance_1}} + * + * @return Query string + */ + public String toQueryString() { + List rangeQueries = new ArrayList<>(); + + for (int i = 0; i < vectors.size(); i++) { + Vector v = vectors.get(i); + String rangeQuery = + String.format( + "@%s:[VECTOR_RANGE %.1f $vector_%d]=>{$YIELD_DISTANCE_AS: distance_%d}", + v.getFieldName(), DISTANCE_THRESHOLD, i, i); + rangeQueries.add(rangeQuery); + } + + String baseQuery = String.join(" | ", rangeQueries); + + // Add filter expression if present + if (filterExpression != null) { + String filterStr = filterExpression.build(); + return String.format("(%s) AND (%s)", baseQuery, filterStr); + } + + return baseQuery; + } + + /** + * Convert to parameter map for query execution. + * + *

Returns map with vector_0, vector_1, etc. as byte arrays + * + * @return Parameters map + */ + public Map toParams() { + Map params = new HashMap<>(); + + for (int i = 0; i < vectors.size(); i++) { + Vector v = vectors.get(i); + byte[] vectorBytes = ArrayUtils.floatArrayToBytes(v.getVector()); + params.put(String.format("vector_%d", i), vectorBytes); + } + + return params; + } + + /** + * Get the scoring formula for combining vector similarities. + * + *

Formula: w_1 * score_1 + w_2 * score_2 + ... + * + *

Where score_i = (2 - distance_i) / 2 + * + * @return Scoring formula string + */ + public String getScoringFormula() { + List scoreTerms = new ArrayList<>(); + + for (int i = 0; i < vectors.size(); i++) { + Vector v = vectors.get(i); + scoreTerms.add(String.format("%.2f * score_%d", v.getWeight(), i)); + } + + return String.join(" + ", scoreTerms); + } + + /** + * Get individual score calculations. + * + *

Returns map of score_0=(2-distance_0)/2, score_1=(2-distance_1)/2, etc. + * + * @return Map of score names to calculation formulas + */ + public Map getScoreCalculations() { + Map calculations = new LinkedHashMap<>(); + + for (int i = 0; i < vectors.size(); i++) { + calculations.put(String.format("score_%d", i), String.format("(2 - distance_%d)/2", i)); + } + + return calculations; + } + + @Override + public String toString() { + return toQueryString(); + } + + /** Builder for creating MultiVectorQuery instances. */ + public static class Builder { + private List vectors; + private Filter filterExpression; + private List returnFields; + private int numResults = 10; // Default from Python + private int dialect = 2; // Default from Python + + Builder() {} + + /** + * Set the vectors to search (accepts a single Vector). + * + * @param vector Single Vector for search + * @return This builder + */ + public Builder vector(Vector vector) { + this.vectors = vector != null ? List.of(vector) : null; + return this; + } + + /** + * Set the vectors to search (accepts multiple Vectors as varargs). + * + * @param vectors Vectors for multi-vector search + * @return This builder + */ + public Builder vectors(Vector... vectors) { + this.vectors = vectors != null ? Arrays.asList(vectors) : null; + return this; + } + + /** + * Set the vectors to search (accepts a List of Vectors). + * + * @param vectors List of Vectors for multi-vector search + * @return This builder + */ + public Builder vectors(List vectors) { + this.vectors = vectors != null ? new ArrayList<>(vectors) : null; + return this; + } + + /** + * Set the filter expression. + * + * @param filterExpression Filter to apply + * @return This builder + */ + public Builder filterExpression(Filter filterExpression) { + this.filterExpression = filterExpression; + return this; + } + + /** + * Set the fields to return in results (varargs). + * + * @param fields Field names to return + * @return This builder + */ + public Builder returnFields(String... fields) { + this.returnFields = Arrays.asList(fields); + return this; + } + + /** + * Set the fields to return in results (list). + * + * @param fields List of field names to return + * @return This builder + */ + public Builder returnFields(List fields) { + this.returnFields = fields != null ? new ArrayList<>(fields) : null; + return this; + } + + /** + * Set the maximum number of results to return. + * + * @param numResults Maximum number of results + * @return This builder + */ + public Builder numResults(int numResults) { + this.numResults = numResults; + return this; + } + + /** + * Set the query dialect. + * + * @param dialect RediSearch dialect version + * @return This builder + */ + public Builder dialect(int dialect) { + this.dialect = dialect; + return this; + } + + /** + * Build the MultiVectorQuery instance. + * + * @return Configured MultiVectorQuery + * @throws IllegalArgumentException if vectors is null/empty or contains null values + */ + public MultiVectorQuery build() { + return new MultiVectorQuery(this); + } + } +} diff --git a/core/src/main/java/com/redis/vl/query/Vector.java b/core/src/main/java/com/redis/vl/query/Vector.java new file mode 100644 index 0000000..599201e --- /dev/null +++ b/core/src/main/java/com/redis/vl/query/Vector.java @@ -0,0 +1,197 @@ +package com.redis.vl.query; + +import java.util.Arrays; +import java.util.HashSet; +import java.util.Set; +import lombok.Getter; + +/** + * Simple object containing the necessary arguments to perform a multi-vector query. + * + *

Ported from Python: redisvl/query/aggregate.py:16-36 (Vector class) + * + *

Python equivalent: + * + *

+ * from redisvl.query import Vector
+ *
+ * vector = Vector(
+ *     vector=[0.1, 0.2, 0.3],
+ *     field_name="text_embedding",
+ *     dtype="float32",
+ *     weight=0.7
+ * )
+ * 
+ * + * Java equivalent: + * + *
+ * Vector vector = Vector.builder()
+ *     .vector(new float[]{0.1f, 0.2f, 0.3f})
+ *     .fieldName("text_embedding")
+ *     .dtype("float32")
+ *     .weight(0.7)
+ *     .build();
+ * 
+ */ +@Getter +public final class Vector { + + private static final Set VALID_DTYPES = + new HashSet<>( + Arrays.asList( + "BFLOAT16", + "bfloat16", + "FLOAT16", + "float16", + "FLOAT32", + "float32", + "FLOAT64", + "float64", + "INT8", + "int8", + "UINT8", + "uint8")); + + private final float[] vector; + private final String fieldName; + private final String dtype; + private final double weight; + + private Vector(Builder builder) { + // Validate before modifying state + if (builder.vector == null || builder.vector.length == 0) { + throw new IllegalArgumentException("Vector cannot be null or empty"); + } + if (builder.fieldName == null || builder.fieldName.trim().isEmpty()) { + throw new IllegalArgumentException("Field name cannot be null or empty"); + } + if (!VALID_DTYPES.contains(builder.dtype)) { + throw new IllegalArgumentException( + String.format( + "Invalid data type: %s. Supported types are: %s", builder.dtype, VALID_DTYPES)); + } + if (builder.weight <= 0) { + throw new IllegalArgumentException("Weight must be positive, got " + builder.weight); + } + + this.vector = Arrays.copyOf(builder.vector, builder.vector.length); + this.fieldName = builder.fieldName.trim(); + this.dtype = builder.dtype; + this.weight = builder.weight; + } + + /** + * Create a new Builder for Vector. + * + * @return A new Builder instance + */ + public static Builder builder() { + return new Builder(); + } + + /** + * Get a copy of the vector. + * + * @return Copy of the vector + */ + public float[] getVector() { + return Arrays.copyOf(vector, vector.length); + } + + /** Builder for creating Vector instances. */ + public static class Builder { + private float[] vector; + private String fieldName; + private String dtype = "float32"; // Default from Python + private double weight = 1.0; // Default from Python + + Builder() {} + + /** + * Set the query vector. + * + * @param vector Query vector for similarity search + * @return This builder + */ + public Builder vector(float[] vector) { + this.vector = vector != null ? vector.clone() : null; + return this; + } + + /** + * Set the vector field name to search. + * + * @param fieldName Name of the vector field + * @return This builder + */ + public Builder fieldName(String fieldName) { + this.fieldName = fieldName; + return this; + } + + /** + * Set the vector data type. + * + *

Supported types: BFLOAT16, FLOAT16, FLOAT32, FLOAT64, INT8, UINT8 (case-insensitive) + * + * @param dtype Vector data type + * @return This builder + */ + public Builder dtype(String dtype) { + this.dtype = dtype; + return this; + } + + /** + * Set the weight for this vector in multi-vector scoring. + * + *

The final score will be a weighted combination: w_1 * score_1 + w_2 * score_2 + ... + * + * @param weight Weight value (must be positive) + * @return This builder + */ + public Builder weight(double weight) { + this.weight = weight; + return this; + } + + /** + * Build the Vector instance. + * + * @return Configured Vector + * @throws IllegalArgumentException if vector or fieldName is null/empty, dtype is invalid, or + * weight is non-positive + */ + public Vector build() { + return new Vector(this); + } + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + Vector vector1 = (Vector) o; + return Double.compare(vector1.weight, weight) == 0 + && Arrays.equals(vector, vector1.vector) + && fieldName.equals(vector1.fieldName) + && dtype.equals(vector1.dtype); + } + + @Override + public int hashCode() { + int result = Arrays.hashCode(vector); + result = 31 * result + fieldName.hashCode(); + result = 31 * result + dtype.hashCode(); + result = 31 * result + Double.hashCode(weight); + return result; + } + + @Override + public String toString() { + return String.format( + "Vector[fieldName=%s, dtype=%s, weight=%.2f, dimensions=%d]", + fieldName, dtype, weight, vector.length); + } +} diff --git a/core/src/test/java/com/redis/vl/query/MultiVectorQueryTest.java b/core/src/test/java/com/redis/vl/query/MultiVectorQueryTest.java new file mode 100644 index 0000000..9cae573 --- /dev/null +++ b/core/src/test/java/com/redis/vl/query/MultiVectorQueryTest.java @@ -0,0 +1,372 @@ +package com.redis.vl.query; + +import static org.assertj.core.api.Assertions.*; + +import java.util.Arrays; +import java.util.List; +import java.util.Map; +import org.junit.jupiter.api.DisplayName; +import org.junit.jupiter.api.Test; + +/** + * Unit tests for Vector and MultiVectorQuery classes (issue #402). + * + *

Ported from Python: tests/unit/test_aggregation_types.py + * + *

Python reference: PR #402 - Multi-vector query support + */ +@DisplayName("Multi-Vector Query Tests") +class MultiVectorQueryTest { + + private static final float[] SAMPLE_VECTOR = {0.1f, 0.2f, 0.3f}; + private static final float[] SAMPLE_VECTOR_2 = {0.4f, 0.5f}; + private static final float[] SAMPLE_VECTOR_3 = {0.6f, 0.7f, 0.8f, 0.9f}; + private static final float[] SAMPLE_VECTOR_4 = {0.2f, 0.3f}; + + // ========== Vector Class Tests ========== + + @Test + @DisplayName("Vector: Should create with required fields") + void testVectorCreation() { + Vector vector = Vector.builder().vector(SAMPLE_VECTOR).fieldName("text_embedding").build(); + + assertThat(vector.getVector()).isEqualTo(SAMPLE_VECTOR); + assertThat(vector.getFieldName()).isEqualTo("text_embedding"); + assertThat(vector.getDtype()).isEqualTo("float32"); // Default + assertThat(vector.getWeight()).isEqualTo(1.0); // Default + } + + @Test + @DisplayName("Vector: Should create with all fields") + void testVectorCreationAllFields() { + Vector vector = + Vector.builder() + .vector(SAMPLE_VECTOR) + .fieldName("text_embedding") + .dtype("float64") + .weight(0.7) + .build(); + + assertThat(vector.getVector()).isEqualTo(SAMPLE_VECTOR); + assertThat(vector.getFieldName()).isEqualTo("text_embedding"); + assertThat(vector.getDtype()).isEqualTo("float64"); + assertThat(vector.getWeight()).isEqualTo(0.7); + } + + @Test + @DisplayName("Vector: Should validate dtype") + void testVectorDtypeValidation() { + // Valid dtypes (case-insensitive) + assertThatCode( + () -> + Vector.builder().vector(SAMPLE_VECTOR).fieldName("field").dtype("float32").build()) + .doesNotThrowAnyException(); + + assertThatCode( + () -> + Vector.builder().vector(SAMPLE_VECTOR).fieldName("field").dtype("FLOAT64").build()) + .doesNotThrowAnyException(); + + assertThatCode( + () -> + Vector.builder().vector(SAMPLE_VECTOR).fieldName("field").dtype("bfloat16").build()) + .doesNotThrowAnyException(); + + // Invalid dtype + assertThatThrownBy( + () -> Vector.builder().vector(SAMPLE_VECTOR).fieldName("field").dtype("float").build()) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("Invalid data type"); + } + + @Test + @DisplayName("Vector: Should require non-null vector") + void testVectorRequiresVector() { + assertThatThrownBy(() -> Vector.builder().fieldName("field").build()) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("Vector cannot be null or empty"); + } + + @Test + @DisplayName("Vector: Should require non-empty vector") + void testVectorRequiresNonEmptyVector() { + assertThatThrownBy(() -> Vector.builder().vector(new float[] {}).fieldName("field").build()) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("Vector cannot be null or empty"); + } + + @Test + @DisplayName("Vector: Should require non-null field name") + void testVectorRequiresFieldName() { + assertThatThrownBy(() -> Vector.builder().vector(SAMPLE_VECTOR).build()) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("Field name cannot be null or empty"); + } + + @Test + @DisplayName("Vector: Should require positive weight") + void testVectorRequiresPositiveWeight() { + assertThatThrownBy( + () -> Vector.builder().vector(SAMPLE_VECTOR).fieldName("field").weight(0.0).build()) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("Weight must be positive"); + + assertThatThrownBy( + () -> Vector.builder().vector(SAMPLE_VECTOR).fieldName("field").weight(-0.5).build()) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("Weight must be positive"); + } + + @Test + @DisplayName("Vector: Should be immutable (defensive copy)") + void testVectorImmutability() { + float[] original = SAMPLE_VECTOR.clone(); + Vector vector = Vector.builder().vector(original).fieldName("field").build(); + + // Modify original + original[0] = 999.0f; + + // Vector should not be affected + assertThat(vector.getVector()[0]).isEqualTo(SAMPLE_VECTOR[0]); + } + + // ========== MultiVectorQuery Class Tests ========== + + @Test + @DisplayName("MultiVectorQuery: Should require at least one vector") + void testMultiVectorQueryRequiresVectors() { + assertThatThrownBy(() -> MultiVectorQuery.builder().build()) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("At least one Vector is required"); + } + + @Test + @DisplayName("MultiVectorQuery: Should reject null vectors in list") + void testMultiVectorQueryRejectsNullInList() { + Vector vector1 = Vector.builder().vector(SAMPLE_VECTOR).fieldName("field1").build(); + + assertThatThrownBy( + () -> MultiVectorQuery.builder().vectors(Arrays.asList(vector1, null)).build()) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("cannot contain null values"); + } + + @Test + @DisplayName("MultiVectorQuery: Should accept single vector") + void testMultiVectorQuerySingleVector() { + Vector vector = Vector.builder().vector(SAMPLE_VECTOR).fieldName("field_1").build(); + + MultiVectorQuery query = MultiVectorQuery.builder().vector(vector).build(); + + assertThat(query.getVectors()).hasSize(1); + assertThat(query.getVectors().get(0)).isEqualTo(vector); + assertThat(query.getNumResults()).isEqualTo(10); // Default + assertThat(query.getDialect()).isEqualTo(2); // Default + assertThat(query.getFilterExpression()).isNull(); + assertThat(query.getReturnFields()).isEmpty(); + } + + @Test + @DisplayName("MultiVectorQuery: Should accept multiple vectors") + void testMultiVectorQueryMultipleVectors() { + Vector vector1 = + Vector.builder() + .vector(SAMPLE_VECTOR) + .fieldName("field_1") + .weight(0.2) + .dtype("float32") + .build(); + + Vector vector2 = + Vector.builder() + .vector(SAMPLE_VECTOR_2) + .fieldName("field_2") + .weight(0.5) + .dtype("float32") + .build(); + + Vector vector3 = + Vector.builder() + .vector(SAMPLE_VECTOR_3) + .fieldName("field_3") + .weight(0.6) + .dtype("float32") + .build(); + + Vector vector4 = + Vector.builder() + .vector(SAMPLE_VECTOR_4) + .fieldName("field_4") + .weight(0.1) + .dtype("float32") + .build(); + + List vectors = Arrays.asList(vector1, vector2, vector3, vector4); + MultiVectorQuery query = MultiVectorQuery.builder().vectors(vectors).build(); + + assertThat(query.getVectors()).hasSize(4); + assertThat(query.getVectors()).isEqualTo(vectors); + } + + @Test + @DisplayName("MultiVectorQuery: Should override defaults") + void testMultiVectorQueryOverrideDefaults() { + Vector vector = Vector.builder().vector(SAMPLE_VECTOR).fieldName("field_1").build(); + + Filter filter = Filter.tag("user_group", "group_A", "group_C"); + + MultiVectorQuery query = + MultiVectorQuery.builder() + .vector(vector) + .filterExpression(filter) + .numResults(5) + .returnFields("field_1", "user_name", "address") + .dialect(3) + .build(); + + assertThat(query.getFilterExpression()).isEqualTo(filter); + assertThat(query.getNumResults()).isEqualTo(5); + assertThat(query.getReturnFields()).containsExactly("field_1", "user_name", "address"); + assertThat(query.getDialect()).isEqualTo(3); + } + + @Test + @DisplayName("MultiVectorQuery: Should build correct query string") + void testMultiVectorQueryString() { + // Python test: test_multi_vector_query_string + String field1 = "text embedding"; + String field2 = "image embedding"; + double weight1 = 0.2; + double weight2 = 0.7; + + Vector vector1 = + Vector.builder().vector(SAMPLE_VECTOR_2).fieldName(field1).weight(weight1).build(); + + Vector vector2 = + Vector.builder().vector(SAMPLE_VECTOR_3).fieldName(field2).weight(weight2).build(); + + MultiVectorQuery query = MultiVectorQuery.builder().vectors(vector1, vector2).build(); + + String queryString = query.toQueryString(); + + // Expected format: + // @field1:[VECTOR_RANGE 2.0 $vector_0]=>{$YIELD_DISTANCE_AS: distance_0} | + // @field2:[VECTOR_RANGE 2.0 $vector_1]=>{$YIELD_DISTANCE_AS: distance_1} + assertThat(queryString) + .contains(String.format("@%s:[VECTOR_RANGE 2.0 $vector_0]", field1)) + .contains("{$YIELD_DISTANCE_AS: distance_0}") + .contains(String.format("@%s:[VECTOR_RANGE 2.0 $vector_1]", field2)) + .contains("{$YIELD_DISTANCE_AS: distance_1}") + .contains(" | "); + } + + @Test + @DisplayName("MultiVectorQuery: Should build params map") + void testMultiVectorQueryParams() { + Vector vector1 = Vector.builder().vector(SAMPLE_VECTOR).fieldName("field_1").build(); + + Vector vector2 = Vector.builder().vector(SAMPLE_VECTOR_2).fieldName("field_2").build(); + + MultiVectorQuery query = MultiVectorQuery.builder().vectors(vector1, vector2).build(); + + Map params = query.toParams(); + + assertThat(params).containsKeys("vector_0", "vector_1"); + assertThat(params.get("vector_0")).isInstanceOf(byte[].class); + assertThat(params.get("vector_1")).isInstanceOf(byte[].class); + } + + @Test + @DisplayName("MultiVectorQuery: Should generate scoring formula") + void testMultiVectorQueryScoringFormula() { + Vector vector1 = + Vector.builder().vector(SAMPLE_VECTOR).fieldName("field_1").weight(0.2).build(); + + Vector vector2 = + Vector.builder().vector(SAMPLE_VECTOR_2).fieldName("field_2").weight(0.7).build(); + + MultiVectorQuery query = MultiVectorQuery.builder().vectors(vector1, vector2).build(); + + String formula = query.getScoringFormula(); + + // Expected: "0.20 * score_0 + 0.70 * score_1" + assertThat(formula).contains("0.20 * score_0").contains("0.70 * score_1").contains(" + "); + } + + @Test + @DisplayName("MultiVectorQuery: Should generate score calculations") + void testMultiVectorQueryScoreCalculations() { + Vector vector1 = Vector.builder().vector(SAMPLE_VECTOR).fieldName("field_1").build(); + + Vector vector2 = Vector.builder().vector(SAMPLE_VECTOR_2).fieldName("field_2").build(); + + MultiVectorQuery query = MultiVectorQuery.builder().vectors(vector1, vector2).build(); + + Map calculations = query.getScoreCalculations(); + + assertThat(calculations).hasSize(2); + assertThat(calculations.get("score_0")).isEqualTo("(2 - distance_0)/2"); + assertThat(calculations.get("score_1")).isEqualTo("(2 - distance_1)/2"); + } + + @Test + @DisplayName("MultiVectorQuery: Should include filter in query string") + void testMultiVectorQueryWithFilter() { + Vector vector = Vector.builder().vector(SAMPLE_VECTOR).fieldName("field").build(); + + Filter filter = Filter.tag("category", "electronics"); + + MultiVectorQuery query = + MultiVectorQuery.builder().vector(vector).filterExpression(filter).build(); + + String queryString = query.toQueryString(); + + // Should wrap base query and filter with AND + assertThat(queryString).contains(" AND "); + assertThat(queryString).contains(filter.build()); + } + + @Test + @DisplayName("MultiVectorQuery: Should support varargs vectors") + void testMultiVectorQueryVarargs() { + Vector vector1 = Vector.builder().vector(SAMPLE_VECTOR).fieldName("field_1").build(); + + Vector vector2 = Vector.builder().vector(SAMPLE_VECTOR_2).fieldName("field_2").build(); + + MultiVectorQuery query = MultiVectorQuery.builder().vectors(vector1, vector2).build(); + + assertThat(query.getVectors()).hasSize(2); + } + + @Test + @DisplayName("Vector: Should implement equals and hashCode") + void testVectorEqualsAndHashCode() { + Vector vector1 = + Vector.builder() + .vector(SAMPLE_VECTOR) + .fieldName("field") + .dtype("float32") + .weight(0.7) + .build(); + + Vector vector2 = + Vector.builder() + .vector(SAMPLE_VECTOR) + .fieldName("field") + .dtype("float32") + .weight(0.7) + .build(); + + Vector vector3 = + Vector.builder() + .vector(SAMPLE_VECTOR_2) + .fieldName("field") + .dtype("float32") + .weight(0.7) + .build(); + + assertThat(vector1).isEqualTo(vector2); + assertThat(vector1.hashCode()).isEqualTo(vector2.hashCode()); + assertThat(vector1).isNotEqualTo(vector3); + } +}