diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java index 8d0b8486ba437..60592c5dd1dbd 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java @@ -663,7 +663,7 @@ public void onNodeStarted() { @Override public Collection getRestHeaders() { - return Set.of(new RestHeaderDefinition(X_ELASTIC_PRODUCT_USE_CASE_HTTP_HEADER, false)); + return Set.of(new RestHeaderDefinition(X_ELASTIC_PRODUCT_USE_CASE_HTTP_HEADER, true)); } @Override diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceUsageContext.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceUsageContext.java index 7303f0c6e4436..25653a7594b3b 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceUsageContext.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceUsageContext.java @@ -7,8 +7,12 @@ package org.elasticsearch.xpack.inference.services.elastic; +import org.elasticsearch.inference.InputType; + import java.util.Locale; +import static org.elasticsearch.inference.InputType.INTERNAL_INGEST; + /** * Specifies the usage context for a request to the Elastic Inference Service. * This helps to determine the type of resources that are allocated in the Elastic Inference Service for the particular request. @@ -24,4 +28,25 @@ public String toString() { return name().toLowerCase(Locale.ROOT); } + public static ElasticInferenceServiceUsageContext fromInputType(InputType inputType) { + switch (inputType) { + case SEARCH, INTERNAL_SEARCH -> { + return ElasticInferenceServiceUsageContext.SEARCH; + } + case INGEST, INTERNAL_INGEST -> { + return ElasticInferenceServiceUsageContext.INGEST; + } + default -> { + return ElasticInferenceServiceUsageContext.UNSPECIFIED; + } + } + } + + public String productUseCaseHeaderValue() { + return switch (this) { + case SEARCH -> "internal_search"; + case INGEST -> "internal_ingest"; + case UNSPECIFIED -> "unspecified"; + }; + } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/request/ElasticInferenceServiceDenseTextEmbeddingsRequest.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/request/ElasticInferenceServiceDenseTextEmbeddingsRequest.java index 8a873504ee128..95408f4d04054 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/request/ElasticInferenceServiceDenseTextEmbeddingsRequest.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/request/ElasticInferenceServiceDenseTextEmbeddingsRequest.java @@ -16,6 +16,7 @@ import org.elasticsearch.inference.InputType; import org.elasticsearch.xcontent.XContentType; import org.elasticsearch.xpack.inference.external.request.Request; +import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceUsageContext; import org.elasticsearch.xpack.inference.services.elastic.densetextembeddings.ElasticInferenceServiceDenseTextEmbeddingsModel; import org.elasticsearch.xpack.inference.telemetry.TraceContext; import org.elasticsearch.xpack.inference.telemetry.TraceContextHandler; @@ -25,7 +26,7 @@ import java.util.List; import java.util.Objects; -import static org.elasticsearch.xpack.inference.services.elastic.request.ElasticInferenceServiceSparseEmbeddingsRequest.inputTypeToUsageContext; +import static org.elasticsearch.xpack.inference.InferencePlugin.X_ELASTIC_PRODUCT_USE_CASE_HTTP_HEADER; public class ElasticInferenceServiceDenseTextEmbeddingsRequest extends ElasticInferenceServiceRequest { @@ -53,7 +54,7 @@ public ElasticInferenceServiceDenseTextEmbeddingsRequest( @Override public HttpRequestBase createHttpRequestBase() { var httpPost = new HttpPost(uri); - var usageContext = inputTypeToUsageContext(inputType); + var usageContext = ElasticInferenceServiceUsageContext.fromInputType(inputType); var requestEntity = Strings.toString( new ElasticInferenceServiceDenseTextEmbeddingsRequestEntity(inputs, model.getServiceSettings().modelId(), usageContext) ); @@ -63,6 +64,7 @@ public HttpRequestBase createHttpRequestBase() { traceContextHandler.propagateTraceContext(httpPost); httpPost.setHeader(new BasicHeader(HttpHeaders.CONTENT_TYPE, XContentType.JSON.mediaType())); + httpPost.setHeader(new BasicHeader(X_ELASTIC_PRODUCT_USE_CASE_HTTP_HEADER, usageContext.productUseCaseHeaderValue())); return httpPost; } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/request/ElasticInferenceServiceRequest.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/request/ElasticInferenceServiceRequest.java index 8d2db92b953fe..d1e60220126e2 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/request/ElasticInferenceServiceRequest.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/request/ElasticInferenceServiceRequest.java @@ -42,7 +42,7 @@ public final HttpRequest createHttpRequest() { } if (Objects.nonNull(productUseCase) && productUseCase.isEmpty() == false) { - request.setHeader(X_ELASTIC_PRODUCT_USE_CASE_HTTP_HEADER, metadata.productUseCase()); + request.addHeader(X_ELASTIC_PRODUCT_USE_CASE_HTTP_HEADER, metadata.productUseCase()); } return new HttpRequest(request, getInferenceEntityId()); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/request/ElasticInferenceServiceSparseEmbeddingsRequest.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/request/ElasticInferenceServiceSparseEmbeddingsRequest.java index ae52955c1d98f..fa6e9ef5ae935 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/request/ElasticInferenceServiceSparseEmbeddingsRequest.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/request/ElasticInferenceServiceSparseEmbeddingsRequest.java @@ -26,6 +26,8 @@ import java.nio.charset.StandardCharsets; import java.util.Objects; +import static org.elasticsearch.xpack.inference.InferencePlugin.X_ELASTIC_PRODUCT_USE_CASE_HTTP_HEADER; + public class ElasticInferenceServiceSparseEmbeddingsRequest extends ElasticInferenceServiceRequest { private final URI uri; @@ -55,7 +57,7 @@ public ElasticInferenceServiceSparseEmbeddingsRequest( @Override public HttpRequestBase createHttpRequestBase() { var httpPost = new HttpPost(uri); - var usageContext = inputTypeToUsageContext(inputType); + var usageContext = ElasticInferenceServiceUsageContext.fromInputType(inputType); var requestEntity = Strings.toString( new ElasticInferenceServiceSparseEmbeddingsRequestEntity( truncationResult.input(), @@ -69,6 +71,7 @@ public HttpRequestBase createHttpRequestBase() { traceContextHandler.propagateTraceContext(httpPost); httpPost.setHeader(new BasicHeader(HttpHeaders.CONTENT_TYPE, XContentType.JSON.mediaType())); + httpPost.setHeader(new BasicHeader(X_ELASTIC_PRODUCT_USE_CASE_HTTP_HEADER, usageContext.productUseCaseHeaderValue())); return httpPost; } @@ -104,19 +107,4 @@ public Request truncate() { public boolean[] getTruncationInfo() { return truncationResult.truncated().clone(); } - - // visible for testing - static ElasticInferenceServiceUsageContext inputTypeToUsageContext(InputType inputType) { - switch (inputType) { - case SEARCH, INTERNAL_SEARCH -> { - return ElasticInferenceServiceUsageContext.SEARCH; - } - case INGEST, INTERNAL_INGEST -> { - return ElasticInferenceServiceUsageContext.INGEST; - } - default -> { - return ElasticInferenceServiceUsageContext.UNSPECIFIED; - } - } - } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceTests.java index d861d5b2bb47b..a44a692e9b912 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceTests.java @@ -93,6 +93,7 @@ import static org.elasticsearch.xpack.inference.services.ServiceComponentsTests.createWithEmptySettings; import static org.hamcrest.CoreMatchers.instanceOf; import static org.hamcrest.CoreMatchers.is; +import static org.hamcrest.Matchers.contains; import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.hasSize; import static org.hamcrest.Matchers.isA; @@ -703,7 +704,8 @@ public void testInfer_PropagatesProductUseCaseHeader() throws IOException { assertThat(request.getHeader(HttpHeaders.CONTENT_TYPE), Matchers.equalTo(XContentType.JSON.mediaType())); // Check that the product use case header was set correctly - assertThat(request.getHeader(InferencePlugin.X_ELASTIC_PRODUCT_USE_CASE_HTTP_HEADER), is(productUseCase)); + var productUseCaseHeaders = request.getHeaders().get(InferencePlugin.X_ELASTIC_PRODUCT_USE_CASE_HTTP_HEADER); + assertThat(productUseCaseHeaders, contains("internal_search", productUseCase)); // Verify request body var requestMap = entityAsMap(request.getBody()); @@ -832,7 +834,8 @@ public void testChunkedInfer_PropagatesProductUseCaseHeader() throws IOException assertThat(request.getHeader(HttpHeaders.CONTENT_TYPE), equalTo(XContentType.JSON.mediaType())); // Check that the product use case header was set correctly - assertThat(request.getHeader(InferencePlugin.X_ELASTIC_PRODUCT_USE_CASE_HTTP_HEADER), is(productUseCase)); + var productUseCaseHeaders = request.getHeaders().get(InferencePlugin.X_ELASTIC_PRODUCT_USE_CASE_HTTP_HEADER); + assertThat(productUseCaseHeaders, contains("internal_ingest", productUseCase)); } finally { // Clean up the thread context diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceUsageContextTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceUsageContextTests.java new file mode 100644 index 0000000000000..619a0438d4b88 --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceUsageContextTests.java @@ -0,0 +1,55 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.elastic; + +import org.elasticsearch.inference.InputType; +import org.elasticsearch.test.ESTestCase; +import org.hamcrest.Matchers; + +import static org.hamcrest.Matchers.equalTo; + +public class ElasticInferenceServiceUsageContextTests extends ESTestCase { + + public void testInputTypeToUsageContext_Search() { + assertThat( + ElasticInferenceServiceUsageContext.fromInputType(InputType.SEARCH), + equalTo(ElasticInferenceServiceUsageContext.SEARCH) + ); + } + + public void testInputTypeToUsageContext_Ingest() { + assertThat( + ElasticInferenceServiceUsageContext.fromInputType(InputType.INGEST), + equalTo(ElasticInferenceServiceUsageContext.INGEST) + ); + } + + public void testInputTypeToUsageContext_Unspecified() { + assertThat( + ElasticInferenceServiceUsageContext.fromInputType(InputType.UNSPECIFIED), + equalTo(ElasticInferenceServiceUsageContext.UNSPECIFIED) + ); + } + + public void testInputTypeToUsageContext_Unknown_DefaultToUnspecified() { + assertThat( + ElasticInferenceServiceUsageContext.fromInputType(InputType.CLASSIFICATION), + equalTo(ElasticInferenceServiceUsageContext.UNSPECIFIED) + ); + assertThat( + ElasticInferenceServiceUsageContext.fromInputType(InputType.CLUSTERING), + equalTo(ElasticInferenceServiceUsageContext.UNSPECIFIED) + ); + } + + public void testProductUseCase() { + assertThat(ElasticInferenceServiceUsageContext.SEARCH.productUseCaseHeaderValue(), Matchers.is("internal_search")); + assertThat(ElasticInferenceServiceUsageContext.INGEST.productUseCaseHeaderValue(), Matchers.is("internal_ingest")); + assertThat(ElasticInferenceServiceUsageContext.UNSPECIFIED.productUseCaseHeaderValue(), Matchers.is("unspecified")); + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/request/ElasticInferenceServiceDenseTextEmbeddingsRequestTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/request/ElasticInferenceServiceDenseTextEmbeddingsRequestTests.java index 86687980acdf6..6705ff744debc 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/request/ElasticInferenceServiceDenseTextEmbeddingsRequestTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/request/ElasticInferenceServiceDenseTextEmbeddingsRequestTests.java @@ -19,6 +19,7 @@ import java.io.IOException; import java.util.List; +import static org.elasticsearch.xpack.inference.InferencePlugin.X_ELASTIC_PRODUCT_USE_CASE_HTTP_HEADER; import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap; import static org.elasticsearch.xpack.inference.services.elastic.request.ElasticInferenceServiceRequestTests.randomElasticInferenceServiceRequestMetadata; import static org.hamcrest.Matchers.aMapWithSize; @@ -146,6 +147,32 @@ public void testGetTruncationInfo_ReturnsNull() { assertThat(request.getTruncationInfo(), is(nullValue())); } + public void testDecorate_HttpRequest_WithProductUseCase() { + var input = "elastic"; + var modelId = "my-model-id"; + var url = "http://eis-gateway.com"; + + for (var inputType : List.of(InputType.INTERNAL_SEARCH, InputType.INTERNAL_INGEST, InputType.UNSPECIFIED)) { + var request = new ElasticInferenceServiceDenseTextEmbeddingsRequest( + ElasticInferenceServiceDenseTextEmbeddingsModelTests.createModel(url, modelId), + List.of(input), + new TraceContext(randomAlphaOfLength(10), randomAlphaOfLength(10)), + new ElasticInferenceServiceRequestMetadata("my-product-origin", "my-product-use-case-from-metadata"), + inputType + ); + + var httpRequest = request.createHttpRequest(); + + assertThat(httpRequest.httpRequestBase(), instanceOf(HttpPost.class)); + var httpPost = (HttpPost) httpRequest.httpRequestBase(); + + var headers = httpPost.getHeaders(X_ELASTIC_PRODUCT_USE_CASE_HTTP_HEADER); + assertThat(headers.length, is(2)); + assertThat(headers[0].getValue(), is(inputType.toString())); + assertThat(headers[1].getValue(), is("my-product-use-case-from-metadata")); + } + } + private ElasticInferenceServiceDenseTextEmbeddingsRequest createRequest( String url, String modelId, diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/request/ElasticInferenceServiceSparseEmbeddingsRequestTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/request/ElasticInferenceServiceSparseEmbeddingsRequestTests.java index 100f327225293..f739f5ca8b379 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/request/ElasticInferenceServiceSparseEmbeddingsRequestTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/request/ElasticInferenceServiceSparseEmbeddingsRequestTests.java @@ -16,15 +16,14 @@ import org.elasticsearch.xpack.inference.common.Truncator; import org.elasticsearch.xpack.inference.common.TruncatorTests; import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceSparseEmbeddingsModelTests; -import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceUsageContext; import org.elasticsearch.xpack.inference.telemetry.TraceContext; import java.io.IOException; import java.util.List; +import static org.elasticsearch.xpack.inference.InferencePlugin.X_ELASTIC_PRODUCT_USE_CASE_HTTP_HEADER; import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap; import static org.elasticsearch.xpack.inference.services.elastic.request.ElasticInferenceServiceRequestTests.randomElasticInferenceServiceRequestMetadata; -import static org.elasticsearch.xpack.inference.services.elastic.request.ElasticInferenceServiceSparseEmbeddingsRequest.inputTypeToUsageContext; import static org.hamcrest.Matchers.aMapWithSize; import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.instanceOf; @@ -99,21 +98,31 @@ public void testIsTruncated_ReturnsTrue() { assertTrue(truncatedRequest.getTruncationInfo()[0]); } - public void testInputTypeToUsageContext_Search() { - assertThat(inputTypeToUsageContext(InputType.SEARCH), equalTo(ElasticInferenceServiceUsageContext.SEARCH)); - } - - public void testInputTypeToUsageContext_Ingest() { - assertThat(inputTypeToUsageContext(InputType.INGEST), equalTo(ElasticInferenceServiceUsageContext.INGEST)); - } - - public void testInputTypeToUsageContext_Unspecified() { - assertThat(inputTypeToUsageContext(InputType.UNSPECIFIED), equalTo(ElasticInferenceServiceUsageContext.UNSPECIFIED)); - } + public void testDecorate_HttpRequest_WithProductUseCase() { + var input = "elastic"; + var modelId = "my-model-id"; + var url = "http://eis-gateway.com"; - public void testInputTypeToUsageContext_Unknown_DefaultToUnspecified() { - assertThat(inputTypeToUsageContext(InputType.CLASSIFICATION), equalTo(ElasticInferenceServiceUsageContext.UNSPECIFIED)); - assertThat(inputTypeToUsageContext(InputType.CLUSTERING), equalTo(ElasticInferenceServiceUsageContext.UNSPECIFIED)); + for (var inputType : List.of(InputType.INTERNAL_SEARCH, InputType.INTERNAL_INGEST, InputType.UNSPECIFIED)) { + var request = new ElasticInferenceServiceSparseEmbeddingsRequest( + TruncatorTests.createTruncator(), + new Truncator.TruncationResult(List.of(input), new boolean[] { false }), + ElasticInferenceServiceSparseEmbeddingsModelTests.createModel(url, modelId), + new TraceContext(randomAlphaOfLength(10), randomAlphaOfLength(10)), + new ElasticInferenceServiceRequestMetadata("my-product-origin", "my-product-use-case-from-metadata"), + inputType + ); + + var httpRequest = request.createHttpRequest(); + + assertThat(httpRequest.httpRequestBase(), instanceOf(HttpPost.class)); + var httpPost = (HttpPost) httpRequest.httpRequestBase(); + + var headers = httpPost.getHeaders(X_ELASTIC_PRODUCT_USE_CASE_HTTP_HEADER); + assertThat(headers.length, is(2)); + assertThat(headers[0].getValue(), is(inputType.toString())); + assertThat(headers[1].getValue(), is("my-product-use-case-from-metadata")); + } } public ElasticInferenceServiceSparseEmbeddingsRequest createRequest(String url, String modelId, String input, InputType inputType) {