Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions docs/changelog/132689.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
pr: 132689
summary: Add support for dimensions in google vertex ai request
area: Machine Learning
type: enhancement
issues: []
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ public GoogleVertexAiEmbeddingsModel(GoogleVertexAiEmbeddingsModel model, Google
}

// Should only be used directly for testing
GoogleVertexAiEmbeddingsModel(
public GoogleVertexAiEmbeddingsModel(
String inferenceEntityId,
TaskType taskType,
String service,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,8 +49,14 @@ public HttpRequest createHttpRequest() {
HttpPost httpPost = new HttpPost(model.nonStreamingUri());

ByteArrayEntity byteEntity = new ByteArrayEntity(
Strings.toString(new GoogleVertexAiEmbeddingsRequestEntity(truncationResult.input(), inputType, model.getTaskSettings()))
.getBytes(StandardCharsets.UTF_8)
Strings.toString(
new GoogleVertexAiEmbeddingsRequestEntity(
truncationResult.input(),
inputType,
model.getTaskSettings(),
model.getServiceSettings()
)
).getBytes(StandardCharsets.UTF_8)
);

httpPost.setEntity(byteEntity);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import org.elasticsearch.inference.InputType;
import org.elasticsearch.xcontent.ToXContentObject;
import org.elasticsearch.xcontent.XContentBuilder;
import org.elasticsearch.xpack.inference.services.googlevertexai.embeddings.GoogleVertexAiEmbeddingsServiceSettings;
import org.elasticsearch.xpack.inference.services.googlevertexai.embeddings.GoogleVertexAiEmbeddingsTaskSettings;

import java.io.IOException;
Expand All @@ -21,13 +22,15 @@
public record GoogleVertexAiEmbeddingsRequestEntity(
List<String> inputs,
InputType inputType,
GoogleVertexAiEmbeddingsTaskSettings taskSettings
GoogleVertexAiEmbeddingsTaskSettings taskSettings,
GoogleVertexAiEmbeddingsServiceSettings serviceSettings
) implements ToXContentObject {

private static final String INSTANCES_FIELD = "instances";
private static final String CONTENT_FIELD = "content";
private static final String PARAMETERS_FIELD = "parameters";
private static final String AUTO_TRUNCATE_FIELD = "autoTruncate";
private static final String OUTPUT_DIMENSIONALITY_FIELD = "outputDimensionality";
private static final String TASK_TYPE_FIELD = "task_type";

private static final String CLASSIFICATION_TASK_TYPE = "CLASSIFICATION";
Expand All @@ -38,6 +41,7 @@ public record GoogleVertexAiEmbeddingsRequestEntity(
public GoogleVertexAiEmbeddingsRequestEntity {
Objects.requireNonNull(inputs);
Objects.requireNonNull(taskSettings);
Objects.requireNonNull(serviceSettings);
}

@Override
Expand All @@ -62,15 +66,19 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws

builder.endArray();

if (taskSettings.autoTruncate() != null) {
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This change will allow an empty object for parameters to be sent. I verified that sending

"parameters": {}

succeeds.

builder.startObject(PARAMETERS_FIELD);
{
builder.startObject(PARAMETERS_FIELD);
{
if (taskSettings.autoTruncate() != null) {
builder.field(AUTO_TRUNCATE_FIELD, taskSettings.autoTruncate());
}
builder.endObject();
if (serviceSettings.dimensionsSetByUser()) {
builder.field(OUTPUT_DIMENSIONALITY_FIELD, serviceSettings.dimensions());
}
}
builder.endObject();

builder.endObject();

return builder;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
import org.elasticsearch.xcontent.XContentBuilder;
import org.elasticsearch.xcontent.XContentFactory;
import org.elasticsearch.xcontent.XContentType;
import org.elasticsearch.xpack.inference.services.googlevertexai.embeddings.GoogleVertexAiEmbeddingsServiceSettings;
import org.elasticsearch.xpack.inference.services.googlevertexai.embeddings.GoogleVertexAiEmbeddingsTaskSettings;

import java.io.IOException;
Expand All @@ -26,7 +27,8 @@ public void testToXContent_SingleEmbeddingRequest_WritesAllFields() throws IOExc
var entity = new GoogleVertexAiEmbeddingsRequestEntity(
List.of("abc"),
null,
new GoogleVertexAiEmbeddingsTaskSettings(true, InputType.CLUSTERING)
new GoogleVertexAiEmbeddingsTaskSettings(true, InputType.CLUSTERING),
new GoogleVertexAiEmbeddingsServiceSettings("location", "projectId", "modelId", true, null, 10, null, null)
);

XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON);
Expand All @@ -42,17 +44,19 @@ public void testToXContent_SingleEmbeddingRequest_WritesAllFields() throws IOExc
}
],
"parameters": {
"autoTruncate": true
"autoTruncate": true,
"outputDimensionality": 10
}
}
"""));
}

public void testToXContent_SingleEmbeddingRequest_DoesNotWriteAutoTruncationIfNotDefined() throws IOException {
public void testToXContent_SingleEmbeddingRequest_DoesNotWriteUndefinedFields() throws IOException {
var entity = new GoogleVertexAiEmbeddingsRequestEntity(
List.of("abc"),
InputType.INTERNAL_INGEST,
new GoogleVertexAiEmbeddingsTaskSettings(null, null)
new GoogleVertexAiEmbeddingsTaskSettings(null, null),
new GoogleVertexAiEmbeddingsServiceSettings("location", "projectId", "modelId", false, null, null, null, null)
);

XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON);
Expand All @@ -66,13 +70,45 @@ public void testToXContent_SingleEmbeddingRequest_DoesNotWriteAutoTruncationIfNo
"content": "abc",
"task_type": "RETRIEVAL_DOCUMENT"
}
]
],
"parameters": {
}
}
"""));
}

public void testToXContent_SingleEmbeddingRequest_DoesNotWriteUndefinedFields_DimensionsSetByUserFalse() throws IOException {
var entity = new GoogleVertexAiEmbeddingsRequestEntity(
List.of("abc"),
InputType.INTERNAL_INGEST,
new GoogleVertexAiEmbeddingsTaskSettings(null, null),
new GoogleVertexAiEmbeddingsServiceSettings("location", "projectId", "modelId", false, null, 10, null, null)
);

XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON);
entity.toXContent(builder, null);
String xContentResult = Strings.toString(builder);

assertThat(xContentResult, equalToIgnoringWhitespaceInJsonString("""
{
"instances": [
{
"content": "abc",
"task_type": "RETRIEVAL_DOCUMENT"
}
],
"parameters": {}
}
"""));
}

public void testToXContent_SingleEmbeddingRequest_DoesNotWriteInputTypeIfNotDefined() throws IOException {
var entity = new GoogleVertexAiEmbeddingsRequestEntity(List.of("abc"), null, new GoogleVertexAiEmbeddingsTaskSettings(false, null));
var entity = new GoogleVertexAiEmbeddingsRequestEntity(
List.of("abc"),
null,
new GoogleVertexAiEmbeddingsTaskSettings(false, null),
new GoogleVertexAiEmbeddingsServiceSettings("location", "projectId", "modelId", false, null, null, null, null)
);

XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON);
entity.toXContent(builder, null);
Expand All @@ -96,7 +132,8 @@ public void testToXContent_MultipleEmbeddingsRequest_WritesAllFields() throws IO
var entity = new GoogleVertexAiEmbeddingsRequestEntity(
List.of("abc", "def"),
InputType.INTERNAL_SEARCH,
new GoogleVertexAiEmbeddingsTaskSettings(true, InputType.CLUSTERING)
new GoogleVertexAiEmbeddingsTaskSettings(true, InputType.CLUSTERING),
new GoogleVertexAiEmbeddingsServiceSettings("location", "projectId", "modelId", true, null, 10, null, null)
);

XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON);
Expand All @@ -116,7 +153,8 @@ public void testToXContent_MultipleEmbeddingsRequest_WritesAllFields() throws IO
}
],
"parameters": {
"autoTruncate": true
"autoTruncate": true,
"outputDimensionality": 10
}
}
"""));
Expand All @@ -126,7 +164,8 @@ public void testToXContent_MultipleEmbeddingsRequest_DoesNotWriteInputTypeIfNotD
var entity = new GoogleVertexAiEmbeddingsRequestEntity(
List.of("abc", "def"),
null,
new GoogleVertexAiEmbeddingsTaskSettings(true, null)
new GoogleVertexAiEmbeddingsTaskSettings(true, null),
new GoogleVertexAiEmbeddingsServiceSettings("location", "projectId", "modelId", false, null, null, null, null)
);

XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON);
Expand Down Expand Up @@ -154,7 +193,8 @@ public void testToXContent_MultipleEmbeddingsRequest_DoesNotWriteAutoTruncationI
var entity = new GoogleVertexAiEmbeddingsRequestEntity(
List.of("abc", "def"),
null,
new GoogleVertexAiEmbeddingsTaskSettings(null, InputType.CLASSIFICATION)
new GoogleVertexAiEmbeddingsTaskSettings(null, InputType.CLASSIFICATION),
new GoogleVertexAiEmbeddingsServiceSettings("location", "projectId", "modelId", false, null, null, null, null)
);

XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON);
Expand All @@ -172,12 +212,14 @@ public void testToXContent_MultipleEmbeddingsRequest_DoesNotWriteAutoTruncationI
"content": "def",
"task_type": "CLASSIFICATION"
}
]
],
"parameters": {
}
}
"""));
}

public void testToXContent_ThrowsIfTaskSettingsIsNull() {
expectThrows(NullPointerException.class, () -> new GoogleVertexAiEmbeddingsRequestEntity(List.of("abc", "def"), null, null));
expectThrows(NullPointerException.class, () -> new GoogleVertexAiEmbeddingsRequestEntity(List.of("abc", "def"), null, null, null));
}
}
Loading