Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -663,7 +663,7 @@ public void onNodeStarted() {

@Override
public Collection<RestHeaderDefinition> getRestHeaders() {
return Set.of(new RestHeaderDefinition(X_ELASTIC_PRODUCT_USE_CASE_HTTP_HEADER, false));
return Set.of(new RestHeaderDefinition(X_ELASTIC_PRODUCT_USE_CASE_HTTP_HEADER, true));
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,12 @@

package org.elasticsearch.xpack.inference.services.elastic;

import org.elasticsearch.inference.InputType;

import java.util.Locale;

import static org.elasticsearch.inference.InputType.INTERNAL_INGEST;

/**
* Specifies the usage context for a request to the Elastic Inference Service.
* This helps to determine the type of resources that are allocated in the Elastic Inference Service for the particular request.
Expand All @@ -24,4 +28,25 @@ public String toString() {
return name().toLowerCase(Locale.ROOT);
}

/**
 * Translates an inference {@link InputType} into the usage context reported to the Elastic Inference Service.
 * <p>
 * {@code SEARCH} and {@code INTERNAL_SEARCH} yield {@link #SEARCH}; {@code INGEST} and {@code INTERNAL_INGEST}
 * yield {@link #INGEST}; every other input type yields {@link #UNSPECIFIED}.
 *
 * @param inputType the input type of the inference request
 * @return the usage context that corresponds to {@code inputType}
 * @throws NullPointerException if {@code inputType} is {@code null} (switching over a null enum reference)
 */
public static ElasticInferenceServiceUsageContext fromInputType(InputType inputType) {
    return switch (inputType) {
        case SEARCH, INTERNAL_SEARCH -> ElasticInferenceServiceUsageContext.SEARCH;
        case INGEST, INTERNAL_INGEST -> ElasticInferenceServiceUsageContext.INGEST;
        default -> ElasticInferenceServiceUsageContext.UNSPECIFIED;
    };
}

/**
 * Returns the string that represents this usage context as a product use case header value.
 *
 * @return {@code "internal_search"} for {@link #SEARCH}, {@code "internal_ingest"} for {@link #INGEST}
 *         and {@code "unspecified"} for {@link #UNSPECIFIED}
 */
public String productUseCaseHeaderValue() {
    // No default branch: the switch stays exhaustive over the enum constants, so a new constant must be handled here.
    final String headerValue = switch (this) {
        case UNSPECIFIED -> "unspecified";
        case INGEST -> "internal_ingest";
        case SEARCH -> "internal_search";
    };
    return headerValue;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import org.elasticsearch.inference.InputType;
import org.elasticsearch.xcontent.XContentType;
import org.elasticsearch.xpack.inference.external.request.Request;
import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceUsageContext;
import org.elasticsearch.xpack.inference.services.elastic.densetextembeddings.ElasticInferenceServiceDenseTextEmbeddingsModel;
import org.elasticsearch.xpack.inference.telemetry.TraceContext;
import org.elasticsearch.xpack.inference.telemetry.TraceContextHandler;
Expand All @@ -25,7 +26,7 @@
import java.util.List;
import java.util.Objects;

import static org.elasticsearch.xpack.inference.services.elastic.request.ElasticInferenceServiceSparseEmbeddingsRequest.inputTypeToUsageContext;
import static org.elasticsearch.xpack.inference.InferencePlugin.X_ELASTIC_PRODUCT_USE_CASE_HTTP_HEADER;

public class ElasticInferenceServiceDenseTextEmbeddingsRequest extends ElasticInferenceServiceRequest {

Expand Down Expand Up @@ -53,7 +54,7 @@ public ElasticInferenceServiceDenseTextEmbeddingsRequest(
@Override
public HttpRequestBase createHttpRequestBase() {
var httpPost = new HttpPost(uri);
var usageContext = inputTypeToUsageContext(inputType);
var usageContext = ElasticInferenceServiceUsageContext.fromInputType(inputType);
var requestEntity = Strings.toString(
new ElasticInferenceServiceDenseTextEmbeddingsRequestEntity(inputs, model.getServiceSettings().modelId(), usageContext)
);
Expand All @@ -63,6 +64,7 @@ public HttpRequestBase createHttpRequestBase() {

traceContextHandler.propagateTraceContext(httpPost);
httpPost.setHeader(new BasicHeader(HttpHeaders.CONTENT_TYPE, XContentType.JSON.mediaType()));
httpPost.setHeader(new BasicHeader(X_ELASTIC_PRODUCT_USE_CASE_HTTP_HEADER, usageContext.productUseCaseHeaderValue()));

return httpPost;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ public final HttpRequest createHttpRequest() {
}

if (Objects.nonNull(productUseCase) && productUseCase.isEmpty() == false) {
request.setHeader(X_ELASTIC_PRODUCT_USE_CASE_HTTP_HEADER, metadata.productUseCase());
request.addHeader(X_ELASTIC_PRODUCT_USE_CASE_HTTP_HEADER, metadata.productUseCase());
}

return new HttpRequest(request, getInferenceEntityId());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@
import java.nio.charset.StandardCharsets;
import java.util.Objects;

import static org.elasticsearch.xpack.inference.InferencePlugin.X_ELASTIC_PRODUCT_USE_CASE_HTTP_HEADER;

public class ElasticInferenceServiceSparseEmbeddingsRequest extends ElasticInferenceServiceRequest {

private final URI uri;
Expand Down Expand Up @@ -55,7 +57,7 @@ public ElasticInferenceServiceSparseEmbeddingsRequest(
@Override
public HttpRequestBase createHttpRequestBase() {
var httpPost = new HttpPost(uri);
var usageContext = inputTypeToUsageContext(inputType);
var usageContext = ElasticInferenceServiceUsageContext.fromInputType(inputType);
var requestEntity = Strings.toString(
new ElasticInferenceServiceSparseEmbeddingsRequestEntity(
truncationResult.input(),
Expand All @@ -69,6 +71,7 @@ public HttpRequestBase createHttpRequestBase() {

traceContextHandler.propagateTraceContext(httpPost);
httpPost.setHeader(new BasicHeader(HttpHeaders.CONTENT_TYPE, XContentType.JSON.mediaType()));
httpPost.setHeader(new BasicHeader(X_ELASTIC_PRODUCT_USE_CASE_HTTP_HEADER, usageContext.productUseCaseHeaderValue()));

return httpPost;
}
Expand Down Expand Up @@ -104,19 +107,4 @@ public Request truncate() {
public boolean[] getTruncationInfo() {
    // Return a copy so that callers cannot modify the array held by the truncation result.
    var truncatedFlags = truncationResult.truncated();
    return truncatedFlags.clone();
}

// visible for testing
static ElasticInferenceServiceUsageContext inputTypeToUsageContext(InputType inputType) {
    // Search-like input types map to SEARCH, ingest-like ones to INGEST, anything else to UNSPECIFIED.
    return switch (inputType) {
        case SEARCH, INTERNAL_SEARCH -> ElasticInferenceServiceUsageContext.SEARCH;
        case INGEST, INTERNAL_INGEST -> ElasticInferenceServiceUsageContext.INGEST;
        default -> ElasticInferenceServiceUsageContext.UNSPECIFIED;
    };
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,7 @@
import static org.elasticsearch.xpack.inference.services.ServiceComponentsTests.createWithEmptySettings;
import static org.hamcrest.CoreMatchers.instanceOf;
import static org.hamcrest.CoreMatchers.is;
import static org.hamcrest.Matchers.contains;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.hasSize;
import static org.hamcrest.Matchers.isA;
Expand Down Expand Up @@ -703,7 +704,8 @@ public void testInfer_PropagatesProductUseCaseHeader() throws IOException {
assertThat(request.getHeader(HttpHeaders.CONTENT_TYPE), Matchers.equalTo(XContentType.JSON.mediaType()));

// Check that the product use case header was set correctly
assertThat(request.getHeader(InferencePlugin.X_ELASTIC_PRODUCT_USE_CASE_HTTP_HEADER), is(productUseCase));
var productUseCaseHeaders = request.getHeaders().get(InferencePlugin.X_ELASTIC_PRODUCT_USE_CASE_HTTP_HEADER);
assertThat(productUseCaseHeaders, contains("internal_search", productUseCase));

// Verify request body
var requestMap = entityAsMap(request.getBody());
Expand Down Expand Up @@ -832,7 +834,8 @@ public void testChunkedInfer_PropagatesProductUseCaseHeader() throws IOException
assertThat(request.getHeader(HttpHeaders.CONTENT_TYPE), equalTo(XContentType.JSON.mediaType()));

// Check that the product use case header was set correctly
assertThat(request.getHeader(InferencePlugin.X_ELASTIC_PRODUCT_USE_CASE_HTTP_HEADER), is(productUseCase));
var productUseCaseHeaders = request.getHeaders().get(InferencePlugin.X_ELASTIC_PRODUCT_USE_CASE_HTTP_HEADER);
assertThat(productUseCaseHeaders, contains("internal_ingest", productUseCase));

} finally {
// Clean up the thread context
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.inference.services.elastic;

import org.elasticsearch.inference.InputType;
import org.elasticsearch.test.ESTestCase;
import org.hamcrest.Matchers;

import static org.hamcrest.Matchers.equalTo;

/**
 * Unit tests for {@link ElasticInferenceServiceUsageContext}: the mapping from {@link InputType} to a usage context
 * and the product use case header value of each usage context.
 */
public class ElasticInferenceServiceUsageContextTests extends ESTestCase {

    /** Both the public and the internal search input types must resolve to the SEARCH usage context. */
    public void testInputTypeToUsageContext_Search() {
        assertThat(
            ElasticInferenceServiceUsageContext.fromInputType(InputType.SEARCH),
            equalTo(ElasticInferenceServiceUsageContext.SEARCH)
        );
        assertThat(
            ElasticInferenceServiceUsageContext.fromInputType(InputType.INTERNAL_SEARCH),
            equalTo(ElasticInferenceServiceUsageContext.SEARCH)
        );
    }

    /** Both the public and the internal ingest input types must resolve to the INGEST usage context. */
    public void testInputTypeToUsageContext_Ingest() {
        assertThat(
            ElasticInferenceServiceUsageContext.fromInputType(InputType.INGEST),
            equalTo(ElasticInferenceServiceUsageContext.INGEST)
        );
        assertThat(
            ElasticInferenceServiceUsageContext.fromInputType(InputType.INTERNAL_INGEST),
            equalTo(ElasticInferenceServiceUsageContext.INGEST)
        );
    }

    /** The UNSPECIFIED input type resolves to the UNSPECIFIED usage context. */
    public void testInputTypeToUsageContext_Unspecified() {
        assertThat(
            ElasticInferenceServiceUsageContext.fromInputType(InputType.UNSPECIFIED),
            equalTo(ElasticInferenceServiceUsageContext.UNSPECIFIED)
        );
    }

    /** Input types without a dedicated mapping fall through to the UNSPECIFIED usage context. */
    public void testInputTypeToUsageContext_Unknown_DefaultToUnspecified() {
        assertThat(
            ElasticInferenceServiceUsageContext.fromInputType(InputType.CLASSIFICATION),
            equalTo(ElasticInferenceServiceUsageContext.UNSPECIFIED)
        );
        assertThat(
            ElasticInferenceServiceUsageContext.fromInputType(InputType.CLUSTERING),
            equalTo(ElasticInferenceServiceUsageContext.UNSPECIFIED)
        );
    }

    /** Pins the exact header value produced for every usage context. */
    public void testProductUseCase() {
        assertThat(ElasticInferenceServiceUsageContext.SEARCH.productUseCaseHeaderValue(), Matchers.is("internal_search"));
        assertThat(ElasticInferenceServiceUsageContext.INGEST.productUseCaseHeaderValue(), Matchers.is("internal_ingest"));
        assertThat(ElasticInferenceServiceUsageContext.UNSPECIFIED.productUseCaseHeaderValue(), Matchers.is("unspecified"));
    }
}
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import java.io.IOException;
import java.util.List;

import static org.elasticsearch.xpack.inference.InferencePlugin.X_ELASTIC_PRODUCT_USE_CASE_HTTP_HEADER;
import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap;
import static org.elasticsearch.xpack.inference.services.elastic.request.ElasticInferenceServiceRequestTests.randomElasticInferenceServiceRequestMetadata;
import static org.hamcrest.Matchers.aMapWithSize;
Expand Down Expand Up @@ -146,6 +147,32 @@ public void testGetTruncationInfo_ReturnsNull() {
assertThat(request.getTruncationInfo(), is(nullValue()));
}

/**
 * For each tested input type, the created HTTP request must carry exactly two product use case headers:
 * first the value derived from the input type, then the value supplied through the request metadata.
 */
public void testDecorate_HttpRequest_WithProductUseCase() {
    final String metadataProductUseCase = "my-product-use-case-from-metadata";
    final List<InputType> inputTypes = List.of(InputType.INTERNAL_SEARCH, InputType.INTERNAL_INGEST, InputType.UNSPECIFIED);

    for (InputType inputType : inputTypes) {
        var model = ElasticInferenceServiceDenseTextEmbeddingsModelTests.createModel("http://eis-gateway.com", "my-model-id");
        var inputs = List.of("elastic");
        var traceContext = new TraceContext(randomAlphaOfLength(10), randomAlphaOfLength(10));
        var metadata = new ElasticInferenceServiceRequestMetadata("my-product-origin", metadataProductUseCase);
        var embeddingsRequest = new ElasticInferenceServiceDenseTextEmbeddingsRequest(
            model,
            inputs,
            traceContext,
            metadata,
            inputType
        );

        var requestBase = embeddingsRequest.createHttpRequest().httpRequestBase();
        assertThat(requestBase, instanceOf(HttpPost.class));

        var useCaseHeaders = ((HttpPost) requestBase).getHeaders(X_ELASTIC_PRODUCT_USE_CASE_HTTP_HEADER);
        assertThat(useCaseHeaders.length, is(2));
        // Order matters: the input-type-derived header is expected before the metadata one.
        assertThat(useCaseHeaders[0].getValue(), is(inputType.toString()));
        assertThat(useCaseHeaders[1].getValue(), is(metadataProductUseCase));
    }
}

private ElasticInferenceServiceDenseTextEmbeddingsRequest createRequest(
String url,
String modelId,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,15 +16,14 @@
import org.elasticsearch.xpack.inference.common.Truncator;
import org.elasticsearch.xpack.inference.common.TruncatorTests;
import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceSparseEmbeddingsModelTests;
import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceUsageContext;
import org.elasticsearch.xpack.inference.telemetry.TraceContext;

import java.io.IOException;
import java.util.List;

import static org.elasticsearch.xpack.inference.InferencePlugin.X_ELASTIC_PRODUCT_USE_CASE_HTTP_HEADER;
import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap;
import static org.elasticsearch.xpack.inference.services.elastic.request.ElasticInferenceServiceRequestTests.randomElasticInferenceServiceRequestMetadata;
import static org.elasticsearch.xpack.inference.services.elastic.request.ElasticInferenceServiceSparseEmbeddingsRequest.inputTypeToUsageContext;
import static org.hamcrest.Matchers.aMapWithSize;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.instanceOf;
Expand Down Expand Up @@ -99,21 +98,31 @@ public void testIsTruncated_ReturnsTrue() {
assertTrue(truncatedRequest.getTruncationInfo()[0]);
}

public void testInputTypeToUsageContext_Search() {
assertThat(inputTypeToUsageContext(InputType.SEARCH), equalTo(ElasticInferenceServiceUsageContext.SEARCH));
}

public void testInputTypeToUsageContext_Ingest() {
assertThat(inputTypeToUsageContext(InputType.INGEST), equalTo(ElasticInferenceServiceUsageContext.INGEST));
}

public void testInputTypeToUsageContext_Unspecified() {
assertThat(inputTypeToUsageContext(InputType.UNSPECIFIED), equalTo(ElasticInferenceServiceUsageContext.UNSPECIFIED));
}
public void testDecorate_HttpRequest_WithProductUseCase() {
var input = "elastic";
var modelId = "my-model-id";
var url = "http://eis-gateway.com";

public void testInputTypeToUsageContext_Unknown_DefaultToUnspecified() {
assertThat(inputTypeToUsageContext(InputType.CLASSIFICATION), equalTo(ElasticInferenceServiceUsageContext.UNSPECIFIED));
assertThat(inputTypeToUsageContext(InputType.CLUSTERING), equalTo(ElasticInferenceServiceUsageContext.UNSPECIFIED));
for (var inputType : List.of(InputType.INTERNAL_SEARCH, InputType.INTERNAL_INGEST, InputType.UNSPECIFIED)) {
var request = new ElasticInferenceServiceSparseEmbeddingsRequest(
TruncatorTests.createTruncator(),
new Truncator.TruncationResult(List.of(input), new boolean[] { false }),
ElasticInferenceServiceSparseEmbeddingsModelTests.createModel(url, modelId),
new TraceContext(randomAlphaOfLength(10), randomAlphaOfLength(10)),
new ElasticInferenceServiceRequestMetadata("my-product-origin", "my-product-use-case-from-metadata"),
inputType
);

var httpRequest = request.createHttpRequest();

assertThat(httpRequest.httpRequestBase(), instanceOf(HttpPost.class));
var httpPost = (HttpPost) httpRequest.httpRequestBase();

var headers = httpPost.getHeaders(X_ELASTIC_PRODUCT_USE_CASE_HTTP_HEADER);
assertThat(headers.length, is(2));
assertThat(headers[0].getValue(), is(inputType.toString()));
assertThat(headers[1].getValue(), is("my-product-use-case-from-metadata"));
}
}

public ElasticInferenceServiceSparseEmbeddingsRequest createRequest(String url, String modelId, String input, InputType inputType) {
Expand Down