diff --git a/core/src/main/java/com/redis/vl/query/MultiVectorQuery.java b/core/src/main/java/com/redis/vl/query/MultiVectorQuery.java new file mode 100644 index 0000000..72f77e9 --- /dev/null +++ b/core/src/main/java/com/redis/vl/query/MultiVectorQuery.java @@ -0,0 +1,303 @@ +package com.redis.vl.query; + +import com.redis.vl.utils.ArrayUtils; +import java.util.*; +import lombok.Getter; + +/** + * MultiVectorQuery allows for search over multiple vector fields in a document simultaneously. + * + *
The final score will be a weighted combination of the individual vector similarity scores + * following the formula: + * + *
score = (w_1 * score_1 + w_2 * score_2 + w_3 * score_3 + ... ) + * + *
Vectors may be of different size and datatype, but must be indexed using the 'cosine' + * distance_metric. + * + *
Ported from Python: redisvl/query/aggregate.py:257-400 (MultiVectorQuery class) + * + *
Python equivalent: + * + *
+ * from redisvl.query import MultiVectorQuery, Vector + * + * vector_1 = Vector(vector=[0.1, 0.2, 0.3], field_name="text_vector", dtype="float32", weight=0.7) + * vector_2 = Vector(vector=[0.5, 0.5], field_name="image_vector", dtype="bfloat16", weight=0.2) + * + * query = MultiVectorQuery( + * vectors=[vector_1, vector_2], + * filter_expression=None, + * num_results=10, + * return_fields=["field1", "field2"], + * dialect=2 + * ) + *+ * + * Java equivalent: + * + *
+ * Vector vector1 = Vector.builder()
+ * .vector(new float[]{0.1f, 0.2f, 0.3f})
+ * .fieldName("text_vector")
+ * .dtype("float32")
+ * .weight(0.7)
+ * .build();
+ *
+ * Vector vector2 = Vector.builder()
+ * .vector(new float[]{0.5f, 0.5f})
+ * .fieldName("image_vector")
+ * .dtype("bfloat16")
+ * .weight(0.2)
+ * .build();
+ *
+ * MultiVectorQuery query = MultiVectorQuery.builder()
+ * .vectors(Arrays.asList(vector1, vector2))
+ * .numResults(10)
+ * .returnFields(Arrays.asList("field1", "field2"))
+ * .build();
+ *
+ */
+@Getter
+public final class MultiVectorQuery {
+
+ /** Distance threshold for VECTOR_RANGE (hardcoded at 2.0 to include all eligible documents) */
+ private static final double DISTANCE_THRESHOLD = 2.0;
+
+ private final ListFormat: {@code @field1:[VECTOR_RANGE 2.0 $vector_0]=>{$YIELD_DISTANCE_AS: distance_0} |
+ * @field2:[VECTOR_RANGE 2.0 $vector_1]=>{$YIELD_DISTANCE_AS: distance_1}}
+ *
+ * @return Query string
+ */
+ public String toQueryString() {
+ List Returns map with vector_0, vector_1, etc. as byte arrays
+ *
+ * @return Parameters map
+ */
+ public Map Formula: w_1 * score_1 + w_2 * score_2 + ...
+ *
+ * Where score_i = (2 - distance_i) / 2
+ *
+ * @return Scoring formula string
+ */
+ public String getScoringFormula() {
+ List Returns map of score_0=(2-distance_0)/2, score_1=(2-distance_1)/2, etc.
+ *
+ * @return Map of score names to calculation formulas
+ */
+ public Map Ported from Python: redisvl/query/aggregate.py:16-36 (Vector class)
+ *
+ * Python equivalent:
+ *
+ * Supported types: BFLOAT16, FLOAT16, FLOAT32, FLOAT64, INT8, UINT8 (case-insensitive)
+ *
+ * @param dtype Vector data type
+ * @return This builder
+ */
+ public Builder dtype(String dtype) {
+ this.dtype = dtype;
+ return this;
+ }
+
+ /**
+ * Set the weight for this vector in multi-vector scoring.
+ *
+ * The final score will be a weighted combination: w_1 * score_1 + w_2 * score_2 + ...
+ *
+ * @param weight Weight value (must be positive)
+ * @return This builder
+ */
+ public Builder weight(double weight) {
+ this.weight = weight;
+ return this;
+ }
+
+ /**
+ * Build the Vector instance.
+ *
+ * @return Configured Vector
+ * @throws IllegalArgumentException if vector or fieldName is null/empty, dtype is invalid, or
+ * weight is non-positive
+ */
+ public Vector build() {
+ return new Vector(this);
+ }
+ }
+
+ @Override
+ public boolean equals(Object o) {
+ if (this == o) return true;
+ if (o == null || getClass() != o.getClass()) return false;
+ Vector vector1 = (Vector) o;
+ return Double.compare(vector1.weight, weight) == 0
+ && Arrays.equals(vector, vector1.vector)
+ && fieldName.equals(vector1.fieldName)
+ && dtype.equals(vector1.dtype);
+ }
+
+ @Override
+ public int hashCode() {
+ int result = Arrays.hashCode(vector);
+ result = 31 * result + fieldName.hashCode();
+ result = 31 * result + dtype.hashCode();
+ result = 31 * result + Double.hashCode(weight);
+ return result;
+ }
+
+ @Override
+ public String toString() {
+ return String.format(
+ "Vector[fieldName=%s, dtype=%s, weight=%.2f, dimensions=%d]",
+ fieldName, dtype, weight, vector.length);
+ }
+}
diff --git a/core/src/test/java/com/redis/vl/query/MultiVectorQueryTest.java b/core/src/test/java/com/redis/vl/query/MultiVectorQueryTest.java
new file mode 100644
index 0000000..9cae573
--- /dev/null
+++ b/core/src/test/java/com/redis/vl/query/MultiVectorQueryTest.java
@@ -0,0 +1,372 @@
+package com.redis.vl.query;
+
+import static org.assertj.core.api.Assertions.*;
+
+import java.util.Arrays;
+import java.util.List;
+import java.util.Map;
+import org.junit.jupiter.api.DisplayName;
+import org.junit.jupiter.api.Test;
+
+/**
+ * Unit tests for Vector and MultiVectorQuery classes (issue #402).
+ *
+ * Ported from Python: tests/unit/test_aggregation_types.py
+ *
+ * Python reference: PR #402 - Multi-vector query support
+ */
+@DisplayName("Multi-Vector Query Tests")
+class MultiVectorQueryTest {
+
+ private static final float[] SAMPLE_VECTOR = {0.1f, 0.2f, 0.3f};
+ private static final float[] SAMPLE_VECTOR_2 = {0.4f, 0.5f};
+ private static final float[] SAMPLE_VECTOR_3 = {0.6f, 0.7f, 0.8f, 0.9f};
+ private static final float[] SAMPLE_VECTOR_4 = {0.2f, 0.3f};
+
+ // ========== Vector Class Tests ==========
+
+ @Test
+ @DisplayName("Vector: Should create with required fields")
+ void testVectorCreation() {
+ Vector vector = Vector.builder().vector(SAMPLE_VECTOR).fieldName("text_embedding").build();
+
+ assertThat(vector.getVector()).isEqualTo(SAMPLE_VECTOR);
+ assertThat(vector.getFieldName()).isEqualTo("text_embedding");
+ assertThat(vector.getDtype()).isEqualTo("float32"); // Default
+ assertThat(vector.getWeight()).isEqualTo(1.0); // Default
+ }
+
+ @Test
+ @DisplayName("Vector: Should create with all fields")
+ void testVectorCreationAllFields() {
+ Vector vector =
+ Vector.builder()
+ .vector(SAMPLE_VECTOR)
+ .fieldName("text_embedding")
+ .dtype("float64")
+ .weight(0.7)
+ .build();
+
+ assertThat(vector.getVector()).isEqualTo(SAMPLE_VECTOR);
+ assertThat(vector.getFieldName()).isEqualTo("text_embedding");
+ assertThat(vector.getDtype()).isEqualTo("float64");
+ assertThat(vector.getWeight()).isEqualTo(0.7);
+ }
+
+ @Test
+ @DisplayName("Vector: Should validate dtype")
+ void testVectorDtypeValidation() {
+ // Valid dtypes (case-insensitive)
+ assertThatCode(
+ () ->
+ Vector.builder().vector(SAMPLE_VECTOR).fieldName("field").dtype("float32").build())
+ .doesNotThrowAnyException();
+
+ assertThatCode(
+ () ->
+ Vector.builder().vector(SAMPLE_VECTOR).fieldName("field").dtype("FLOAT64").build())
+ .doesNotThrowAnyException();
+
+ assertThatCode(
+ () ->
+ Vector.builder().vector(SAMPLE_VECTOR).fieldName("field").dtype("bfloat16").build())
+ .doesNotThrowAnyException();
+
+ // Invalid dtype
+ assertThatThrownBy(
+ () -> Vector.builder().vector(SAMPLE_VECTOR).fieldName("field").dtype("float").build())
+ .isInstanceOf(IllegalArgumentException.class)
+ .hasMessageContaining("Invalid data type");
+ }
+
+ @Test
+ @DisplayName("Vector: Should require non-null vector")
+ void testVectorRequiresVector() {
+ assertThatThrownBy(() -> Vector.builder().fieldName("field").build())
+ .isInstanceOf(IllegalArgumentException.class)
+ .hasMessageContaining("Vector cannot be null or empty");
+ }
+
+ @Test
+ @DisplayName("Vector: Should require non-empty vector")
+ void testVectorRequiresNonEmptyVector() {
+ assertThatThrownBy(() -> Vector.builder().vector(new float[] {}).fieldName("field").build())
+ .isInstanceOf(IllegalArgumentException.class)
+ .hasMessageContaining("Vector cannot be null or empty");
+ }
+
+ @Test
+ @DisplayName("Vector: Should require non-null field name")
+ void testVectorRequiresFieldName() {
+ assertThatThrownBy(() -> Vector.builder().vector(SAMPLE_VECTOR).build())
+ .isInstanceOf(IllegalArgumentException.class)
+ .hasMessageContaining("Field name cannot be null or empty");
+ }
+
+ @Test
+ @DisplayName("Vector: Should require positive weight")
+ void testVectorRequiresPositiveWeight() {
+ assertThatThrownBy(
+ () -> Vector.builder().vector(SAMPLE_VECTOR).fieldName("field").weight(0.0).build())
+ .isInstanceOf(IllegalArgumentException.class)
+ .hasMessageContaining("Weight must be positive");
+
+ assertThatThrownBy(
+ () -> Vector.builder().vector(SAMPLE_VECTOR).fieldName("field").weight(-0.5).build())
+ .isInstanceOf(IllegalArgumentException.class)
+ .hasMessageContaining("Weight must be positive");
+ }
+
+ @Test
+ @DisplayName("Vector: Should be immutable (defensive copy)")
+ void testVectorImmutability() {
+ float[] original = SAMPLE_VECTOR.clone();
+ Vector vector = Vector.builder().vector(original).fieldName("field").build();
+
+ // Modify original
+ original[0] = 999.0f;
+
+ // Vector should not be affected
+ assertThat(vector.getVector()[0]).isEqualTo(SAMPLE_VECTOR[0]);
+ }
+
+ // ========== MultiVectorQuery Class Tests ==========
+
+ @Test
+ @DisplayName("MultiVectorQuery: Should require at least one vector")
+ void testMultiVectorQueryRequiresVectors() {
+ assertThatThrownBy(() -> MultiVectorQuery.builder().build())
+ .isInstanceOf(IllegalArgumentException.class)
+ .hasMessageContaining("At least one Vector is required");
+ }
+
+ @Test
+ @DisplayName("MultiVectorQuery: Should reject null vectors in list")
+ void testMultiVectorQueryRejectsNullInList() {
+ Vector vector1 = Vector.builder().vector(SAMPLE_VECTOR).fieldName("field1").build();
+
+ assertThatThrownBy(
+ () -> MultiVectorQuery.builder().vectors(Arrays.asList(vector1, null)).build())
+ .isInstanceOf(IllegalArgumentException.class)
+ .hasMessageContaining("cannot contain null values");
+ }
+
+ @Test
+ @DisplayName("MultiVectorQuery: Should accept single vector")
+ void testMultiVectorQuerySingleVector() {
+ Vector vector = Vector.builder().vector(SAMPLE_VECTOR).fieldName("field_1").build();
+
+ MultiVectorQuery query = MultiVectorQuery.builder().vector(vector).build();
+
+ assertThat(query.getVectors()).hasSize(1);
+ assertThat(query.getVectors().get(0)).isEqualTo(vector);
+ assertThat(query.getNumResults()).isEqualTo(10); // Default
+ assertThat(query.getDialect()).isEqualTo(2); // Default
+ assertThat(query.getFilterExpression()).isNull();
+ assertThat(query.getReturnFields()).isEmpty();
+ }
+
+ @Test
+ @DisplayName("MultiVectorQuery: Should accept multiple vectors")
+ void testMultiVectorQueryMultipleVectors() {
+ Vector vector1 =
+ Vector.builder()
+ .vector(SAMPLE_VECTOR)
+ .fieldName("field_1")
+ .weight(0.2)
+ .dtype("float32")
+ .build();
+
+ Vector vector2 =
+ Vector.builder()
+ .vector(SAMPLE_VECTOR_2)
+ .fieldName("field_2")
+ .weight(0.5)
+ .dtype("float32")
+ .build();
+
+ Vector vector3 =
+ Vector.builder()
+ .vector(SAMPLE_VECTOR_3)
+ .fieldName("field_3")
+ .weight(0.6)
+ .dtype("float32")
+ .build();
+
+ Vector vector4 =
+ Vector.builder()
+ .vector(SAMPLE_VECTOR_4)
+ .fieldName("field_4")
+ .weight(0.1)
+ .dtype("float32")
+ .build();
+
+ List
+ * from redisvl.query import Vector
+ *
+ * vector = Vector(
+ * vector=[0.1, 0.2, 0.3],
+ * field_name="text_embedding",
+ * dtype="float32",
+ * weight=0.7
+ * )
+ *
+ *
+ * Java equivalent:
+ *
+ *
+ * Vector vector = Vector.builder()
+ * .vector(new float[]{0.1f, 0.2f, 0.3f})
+ * .fieldName("text_embedding")
+ * .dtype("float32")
+ * .weight(0.7)
+ * .build();
+ *
+ */
+@Getter
+public final class Vector {
+
+ private static final Set