From 40b80537ed76f668686c621bebdc129d0f301764 Mon Sep 17 00:00:00 2001 From: Lisa Cawley Date: Mon, 18 Oct 2021 09:23:41 -0700 Subject: [PATCH] Revert "[ML] Add queue_capacity setting to start deployment API (#79369)" This reverts commit 637a2993095da2d2fa65adee971c552cc8dc5c14. --- .../StartTrainedModelDeploymentAction.java | 47 +++---------------- ...inedModelAllocationActionRequestTests.java | 9 +++- ...artTrainedModelDeploymentRequestTests.java | 43 ----------------- ...TrainedModelDeploymentTaskParamsTests.java | 3 +- .../TrainedModelAllocationTests.java | 9 ++-- ...portStartTrainedModelDeploymentAction.java | 5 +- .../deployment/DeploymentManager.java | 7 +-- ...RestStartTrainedModelDeploymentAction.java | 2 - ...nedModelAllocationClusterServiceTests.java | 2 +- .../TrainedModelAllocationMetadataTests.java | 3 +- ...rainedModelAllocationNodeServiceTests.java | 2 +- .../xpack/ml/job/NodeLoadDetectorTests.java | 3 +- 12 files changed, 27 insertions(+), 108 deletions(-) diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/StartTrainedModelDeploymentAction.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/StartTrainedModelDeploymentAction.java index 09af5c073f9a5..0aa300346c58d 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/StartTrainedModelDeploymentAction.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/StartTrainedModelDeploymentAction.java @@ -60,7 +60,6 @@ public static class Request extends MasterNodeRequest implements ToXCon public static final ParseField WAIT_FOR = new ParseField("wait_for"); public static final ParseField INFERENCE_THREADS = TaskParams.INFERENCE_THREADS; public static final ParseField MODEL_THREADS = TaskParams.MODEL_THREADS; - public static final ParseField QUEUE_CAPACITY = TaskParams.QUEUE_CAPACITY; public static final ObjectParser PARSER = new ObjectParser<>(NAME, Request::new); @@ -70,7 +69,6 @@ public static class Request extends MasterNodeRequest implements ToXCon PARSER.declareString((request, waitFor) -> request.setWaitForState(AllocationStatus.State.fromString(waitFor)), WAIT_FOR); PARSER.declareInt(Request::setInferenceThreads, INFERENCE_THREADS); PARSER.declareInt(Request::setModelThreads, MODEL_THREADS); - PARSER.declareInt(Request::setQueueCapacity, QUEUE_CAPACITY); } public static Request parseRequest(String modelId, XContentParser parser) { @@ -89,7 +87,6 @@ public static Request parseRequest(String modelId, XContentParser parser) { private AllocationStatus.State waitForState = AllocationStatus.State.STARTED; private int modelThreads = 1; private int inferenceThreads = 1; - private int queueCapacity = 1024; private Request() {} @@ -104,7 +101,6 @@ public Request(StreamInput in) throws IOException { waitForState = in.readEnum(AllocationStatus.State.class); modelThreads = in.readVInt(); inferenceThreads = in.readVInt(); - queueCapacity = in.readVInt(); } public final void setModelId(String modelId) { @@ -148,14 +144,6 @@ public void setInferenceThreads(int inferenceThreads) { this.inferenceThreads = inferenceThreads; } - public int getQueueCapacity() { - return queueCapacity; - } - - public void setQueueCapacity(int queueCapacity) { - this.queueCapacity = queueCapacity; - } - @Override public void writeTo(StreamOutput out) throws IOException { super.writeTo(out); @@ -164,7 +152,6 @@ public void writeTo(StreamOutput out) throws IOException { out.writeEnum(waitForState); out.writeVInt(modelThreads); out.writeVInt(inferenceThreads); - out.writeVInt(queueCapacity); } @Override @@ -175,7 +162,6 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws builder.field(WAIT_FOR.getPreferredName(), waitForState); builder.field(MODEL_THREADS.getPreferredName(), modelThreads); builder.field(INFERENCE_THREADS.getPreferredName(), inferenceThreads); - builder.field(QUEUE_CAPACITY.getPreferredName(), queueCapacity); builder.endObject(); return builder; } @@ -197,15 +183,12 @@ public ActionRequestValidationException validate() { if (inferenceThreads < 1) { validationException.addValidationError("[" + INFERENCE_THREADS + "] must be a positive integer"); } - if (queueCapacity < 1 || queueCapacity > 10000) { - validationException.addValidationError("[" + QUEUE_CAPACITY + "] must be in [1, 10000]"); - } return validationException.validationErrors().isEmpty() ? null : validationException; } @Override public int hashCode() { - return Objects.hash(modelId, timeout, waitForState, modelThreads, inferenceThreads, queueCapacity); + return Objects.hash(modelId, timeout, waitForState, modelThreads, inferenceThreads); } @Override @@ -221,8 +204,7 @@ public boolean equals(Object obj) { && Objects.equals(timeout, other.timeout) && Objects.equals(waitForState, other.waitForState) && modelThreads == other.modelThreads - && inferenceThreads == other.inferenceThreads - && queueCapacity == other.queueCapacity; + && inferenceThreads == other.inferenceThreads; } @Override @@ -244,20 +226,16 @@ public static boolean mayAllocateToNode(DiscoveryNode node) { private static final ParseField MODEL_BYTES = new ParseField("model_bytes"); public static final ParseField MODEL_THREADS = new ParseField("model_threads"); public static final ParseField INFERENCE_THREADS = new ParseField("inference_threads"); - public static final ParseField QUEUE_CAPACITY = new ParseField("queue_capacity"); - private static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>( "trained_model_deployment_params", true, - a -> new TaskParams((String)a[0], (Long)a[1], (int) a[2], (int) a[3], (int) a[4]) + a -> new TaskParams((String)a[0], (Long)a[1], (int) a[2], (int) a[3]) ); - static { PARSER.declareString(ConstructingObjectParser.constructorArg(), TrainedModelConfig.MODEL_ID); PARSER.declareLong(ConstructingObjectParser.constructorArg(), MODEL_BYTES); PARSER.declareInt(ConstructingObjectParser.constructorArg(), INFERENCE_THREADS); PARSER.declareInt(ConstructingObjectParser.constructorArg(), MODEL_THREADS); - PARSER.declareInt(ConstructingObjectParser.constructorArg(), QUEUE_CAPACITY); } public static TaskParams fromXContent(XContentParser parser) { @@ -275,9 +253,8 @@ public static TaskParams fromXContent(XContentParser parser) { private final long modelBytes; private final int inferenceThreads; private final int modelThreads; - private final int queueCapacity; - public TaskParams(String modelId, long modelBytes, int inferenceThreads, int modelThreads, int queueCapacity) { + public TaskParams(String modelId, long modelBytes, int inferenceThreads, int modelThreads) { this.modelId = Objects.requireNonNull(modelId); this.modelBytes = modelBytes; if (modelBytes < 0) { @@ -291,10 +268,6 @@ public TaskParams(String modelId, long modelBytes, int inferenceThreads, int mod if (modelThreads < 1) { throw new IllegalArgumentException(MODEL_THREADS + " must be positive"); } - this.queueCapacity = queueCapacity; - if (queueCapacity < 1 || queueCapacity > 10000) { - throw new IllegalArgumentException(QUEUE_CAPACITY + " must be in [1, 10000]"); - } } public TaskParams(StreamInput in) throws IOException { @@ -302,7 +275,6 @@ public TaskParams(StreamInput in) throws IOException { this.modelBytes = in.readVLong(); this.inferenceThreads = in.readVInt(); this.modelThreads = in.readVInt(); - this.queueCapacity = in.readVInt(); } public String getModelId() { @@ -324,7 +296,6 @@ public void writeTo(StreamOutput out) throws IOException { out.writeVLong(modelBytes); out.writeVInt(inferenceThreads); out.writeVInt(modelThreads); - out.writeVInt(queueCapacity); } @Override @@ -334,14 +305,13 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws builder.field(MODEL_BYTES.getPreferredName(), modelBytes); builder.field(INFERENCE_THREADS.getPreferredName(), inferenceThreads); builder.field(MODEL_THREADS.getPreferredName(), modelThreads); - builder.field(QUEUE_CAPACITY.getPreferredName(), queueCapacity); builder.endObject(); return builder; } @Override public int hashCode() { - return Objects.hash(modelId, modelBytes, inferenceThreads, modelThreads, queueCapacity); + return Objects.hash(modelId, modelBytes, inferenceThreads, modelThreads); } @Override @@ -353,8 +323,7 @@ public boolean equals(Object o) { return Objects.equals(modelId, other.modelId) && modelBytes == other.modelBytes && inferenceThreads == other.inferenceThreads - && modelThreads == other.modelThreads - && queueCapacity == other.queueCapacity; + && modelThreads == other.modelThreads; } @Override @@ -373,10 +342,6 @@ public int getInferenceThreads() { public int getModelThreads() { return modelThreads; } - - public int getQueueCapacity() { - return queueCapacity; - } } public interface TaskMatcher { diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/CreateTrainedModelAllocationActionRequestTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/CreateTrainedModelAllocationActionRequestTests.java index 978139c44e142..6a6fa0453ff7e 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/CreateTrainedModelAllocationActionRequestTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/CreateTrainedModelAllocationActionRequestTests.java @@ -14,7 +14,14 @@ public class CreateTrainedModelAllocationActionRequestTests extends AbstractWire @Override protected Request createTestInstance() { - return new Request(StartTrainedModelDeploymentTaskParamsTests.createRandom()); + return new Request( + new StartTrainedModelDeploymentAction.TaskParams( + randomAlphaOfLength(10), + randomNonNegativeLong(), + randomIntBetween(1, 8), + randomIntBetween(1, 8) + ) + ); } @Override diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/StartTrainedModelDeploymentRequestTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/StartTrainedModelDeploymentRequestTests.java index 0365ea45c2b17..6bd27634dcf69 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/StartTrainedModelDeploymentRequestTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/StartTrainedModelDeploymentRequestTests.java @@ -18,7 +18,6 @@ import java.io.IOException; import static org.hamcrest.Matchers.containsString; -import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.is; import static org.hamcrest.Matchers.not; import static org.hamcrest.Matchers.nullValue; @@ -54,9 +53,6 @@ public static Request createRandom() { if (randomBoolean()) { request.setModelThreads(randomIntBetween(1, 8)); } - if (randomBoolean()) { - request.setQueueCapacity(randomIntBetween(1, 10000)); - } return request; } @@ -99,43 +95,4 @@ public void testValidate_GivenModelThreadsIsNegative() { assertThat(e, is(not(nullValue()))); assertThat(e.getMessage(), containsString("[model_threads] must be a positive integer")); } - - public void testValidate_GivenQueueCapacityIsZero() { - Request request = createRandom(); - request.setQueueCapacity(0); - - ActionRequestValidationException e = request.validate(); - - assertThat(e, is(not(nullValue()))); - assertThat(e.getMessage(), containsString("[queue_capacity] must be in [1, 10000]")); - } - - public void testValidate_GivenQueueCapacityIsNegative() { - Request request = createRandom(); - request.setQueueCapacity(randomIntBetween(Integer.MIN_VALUE, -1)); - - ActionRequestValidationException e = request.validate(); - - assertThat(e, is(not(nullValue()))); - assertThat(e.getMessage(), containsString("[queue_capacity] must be in [1, 10000]")); - } - - public void testValidate_GivenQueueCapacityIsGreaterThan10000() { - Request request = createRandom(); - request.setQueueCapacity(randomIntBetween(10001, Integer.MAX_VALUE)); - - ActionRequestValidationException e = request.validate(); - - assertThat(e, is(not(nullValue()))); - assertThat(e.getMessage(), containsString("[queue_capacity] must be in [1, 10000]")); - } - - public void testDefaults() { - Request request = new Request(randomAlphaOfLength(10)); - assertThat(request.getTimeout(), equalTo(TimeValue.timeValueSeconds(20))); - assertThat(request.getWaitForState(), equalTo(AllocationStatus.State.STARTED)); - assertThat(request.getInferenceThreads(), equalTo(1)); - assertThat(request.getModelThreads(), equalTo(1)); - assertThat(request.getQueueCapacity(), equalTo(1024)); - } } diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/StartTrainedModelDeploymentTaskParamsTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/StartTrainedModelDeploymentTaskParamsTests.java index c5160f96663a3..95a529d3ccc1e 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/StartTrainedModelDeploymentTaskParamsTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/StartTrainedModelDeploymentTaskParamsTests.java @@ -36,8 +36,7 @@ public static StartTrainedModelDeploymentAction.TaskParams createRandom() { randomAlphaOfLength(10), randomNonNegativeLong(), randomIntBetween(1, 8), - randomIntBetween(1, 8), - randomIntBetween(1, 10000) + randomIntBetween(1, 8) ); } } diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/allocation/TrainedModelAllocationTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/allocation/TrainedModelAllocationTests.java index 82ca307f0e024..473730901cac7 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/allocation/TrainedModelAllocationTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/allocation/TrainedModelAllocationTests.java @@ -13,10 +13,9 @@ import org.elasticsearch.cluster.node.DiscoveryNode; import org.elasticsearch.cluster.node.DiscoveryNodeRole; import org.elasticsearch.common.io.stream.Writeable; -import org.elasticsearch.test.AbstractSerializingTestCase; import org.elasticsearch.xcontent.XContentParser; +import org.elasticsearch.test.AbstractSerializingTestCase; import org.elasticsearch.xpack.core.ml.action.StartTrainedModelDeploymentAction; -import org.elasticsearch.xpack.core.ml.action.StartTrainedModelDeploymentTaskParamsTests; import java.io.IOException; import java.util.List; @@ -32,7 +31,9 @@ public class TrainedModelAllocationTests extends AbstractSerializingTestCase { public static TrainedModelAllocation randomInstance() { - TrainedModelAllocation.Builder builder = TrainedModelAllocation.Builder.empty(randomParams()); + TrainedModelAllocation.Builder builder = TrainedModelAllocation.Builder.empty( + new StartTrainedModelDeploymentAction.TaskParams(randomAlphaOfLength(10), randomNonNegativeLong(), 1, 1) + ); List nodes = Stream.generate(() -> randomAlphaOfLength(10)).limit(randomInt(5)).collect(Collectors.toList()); for (String node : nodes) { if (randomBoolean()) { @@ -248,7 +249,7 @@ private static DiscoveryNode buildNode() { } private static StartTrainedModelDeploymentAction.TaskParams randomParams() { - return StartTrainedModelDeploymentTaskParamsTests.createRandom(); + return new StartTrainedModelDeploymentAction.TaskParams(randomAlphaOfLength(10), randomNonNegativeLong(), 1, 1); } private static void assertUnchanged( diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportStartTrainedModelDeploymentAction.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportStartTrainedModelDeploymentAction.java index f6d260c61fd2f..682fab432fbf6 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportStartTrainedModelDeploymentAction.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportStartTrainedModelDeploymentAction.java @@ -26,6 +26,7 @@ import org.elasticsearch.cluster.service.ClusterService; import org.elasticsearch.common.inject.Inject; import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.xcontent.NamedXContentRegistry; import org.elasticsearch.core.TimeValue; import org.elasticsearch.license.LicenseUtils; import org.elasticsearch.license.XPackLicenseState; @@ -34,7 +35,6 @@ import org.elasticsearch.tasks.Task; import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.transport.TransportService; -import org.elasticsearch.xcontent.NamedXContentRegistry; import org.elasticsearch.xpack.core.XPackField; import org.elasticsearch.xpack.core.ml.MachineLearningField; import org.elasticsearch.xpack.core.ml.action.CreateTrainedModelAllocationAction; @@ -161,8 +161,7 @@ protected void masterOperation(Task task, StartTrainedModelDeploymentAction.Requ trainedModelConfig.getModelId(), modelBytes, request.getInferenceThreads(), - request.getModelThreads(), - request.getQueueCapacity() + request.getModelThreads() ); PersistentTasksCustomMetadata persistentTasks = clusterService.state().getMetadata().custom( PersistentTasksCustomMetadata.TYPE); diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/deployment/DeploymentManager.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/deployment/DeploymentManager.java index 21ac5dec07bec..8ae0aeb3f4dde 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/deployment/DeploymentManager.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/deployment/DeploymentManager.java @@ -307,7 +307,6 @@ public void onFailure(Exception e) { @Override protected void doRun() throws Exception { - logger.info("Request [{}] running", requestId); final String requestIdStr = String.valueOf(requestId); try { // The request builder expect a list of inputs which are then batched. @@ -393,11 +392,7 @@ class ProcessContext { this.task = Objects.requireNonNull(task); resultProcessor = new PyTorchResultProcessor(task.getModelId()); this.stateStreamer = new PyTorchStateStreamer(client, executorService, xContentRegistry); - this.executorService = new ProcessWorkerExecutorService( - threadPool.getThreadContext(), - "pytorch_inference", - task.getParams().getQueueCapacity() - ); + this.executorService = new ProcessWorkerExecutorService(threadPool.getThreadContext(), "pytorch_inference", 1024); } PyTorchResultProcessor getResultProcessor() { diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/rest/inference/RestStartTrainedModelDeploymentAction.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/rest/inference/RestStartTrainedModelDeploymentAction.java index f3cf06bf70684..6fa6405b461b0 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/rest/inference/RestStartTrainedModelDeploymentAction.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/rest/inference/RestStartTrainedModelDeploymentAction.java @@ -23,7 +23,6 @@ import static org.elasticsearch.rest.RestRequest.Method.POST; import static org.elasticsearch.xpack.core.ml.action.StartTrainedModelDeploymentAction.Request.INFERENCE_THREADS; import static org.elasticsearch.xpack.core.ml.action.StartTrainedModelDeploymentAction.Request.MODEL_THREADS; -import static org.elasticsearch.xpack.core.ml.action.StartTrainedModelDeploymentAction.Request.QUEUE_CAPACITY; import static org.elasticsearch.xpack.core.ml.action.StartTrainedModelDeploymentAction.Request.TIMEOUT; import static org.elasticsearch.xpack.core.ml.action.StartTrainedModelDeploymentAction.Request.WAIT_FOR; @@ -60,7 +59,6 @@ protected RestChannelConsumer prepareRequest(RestRequest restRequest, NodeClient )); request.setInferenceThreads(restRequest.paramAsInt(INFERENCE_THREADS.getPreferredName(), request.getInferenceThreads())); request.setModelThreads(restRequest.paramAsInt(MODEL_THREADS.getPreferredName(), request.getModelThreads())); - request.setQueueCapacity(restRequest.paramAsInt(QUEUE_CAPACITY.getPreferredName(), request.getQueueCapacity())); } return channel -> client.execute(StartTrainedModelDeploymentAction.INSTANCE, request, new RestToXContentListener<>(channel)); diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/allocation/TrainedModelAllocationClusterServiceTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/allocation/TrainedModelAllocationClusterServiceTests.java index e6b82316b9663..0c974d17fbce1 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/allocation/TrainedModelAllocationClusterServiceTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/allocation/TrainedModelAllocationClusterServiceTests.java @@ -940,7 +940,7 @@ private static DiscoveryNode buildOldNode(String name, boolean isML, long native } private static StartTrainedModelDeploymentAction.TaskParams newParams(String modelId, long modelSize) { - return new StartTrainedModelDeploymentAction.TaskParams(modelId, modelSize, 1, 1, 1024); + return new StartTrainedModelDeploymentAction.TaskParams(modelId, modelSize, 1, 1); } private static void assertNodeState(TrainedModelAllocationMetadata metadata, String modelId, String nodeId, RoutingState routingState) { diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/allocation/TrainedModelAllocationMetadataTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/allocation/TrainedModelAllocationMetadataTests.java index ccb9b27be591e..1de97cc5991f3 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/allocation/TrainedModelAllocationMetadataTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/allocation/TrainedModelAllocationMetadataTests.java @@ -99,8 +99,7 @@ private static StartTrainedModelDeploymentAction.TaskParams randomParams(String modelId, randomNonNegativeLong(), randomIntBetween(1, 8), - randomIntBetween(1, 8), - randomIntBetween(1, 10000) + randomIntBetween(1, 8) ); } diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/allocation/TrainedModelAllocationNodeServiceTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/allocation/TrainedModelAllocationNodeServiceTests.java index 0b8bffdf302aa..88cd412c30147 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/allocation/TrainedModelAllocationNodeServiceTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/allocation/TrainedModelAllocationNodeServiceTests.java @@ -497,7 +497,7 @@ private void withSearchingLoadFailure(String modelId) { } private static StartTrainedModelDeploymentAction.TaskParams newParams(String modelId) { - return new StartTrainedModelDeploymentAction.TaskParams(modelId, randomNonNegativeLong(), 1, 1, 1024); + return new StartTrainedModelDeploymentAction.TaskParams(modelId, randomNonNegativeLong(), 1, 1); } private TrainedModelAllocationNodeService createService() { diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/job/NodeLoadDetectorTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/job/NodeLoadDetectorTests.java index f8cd99164af07..8ef497c4aa956 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/job/NodeLoadDetectorTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/job/NodeLoadDetectorTests.java @@ -91,8 +91,7 @@ public void testNodeLoadDetection() { .addNewAllocation( "model1", TrainedModelAllocation.Builder - .empty(new StartTrainedModelDeploymentAction.TaskParams( - "model1", MODEL_MEMORY_REQUIREMENT, 1, 1, 1024)) + .empty(new StartTrainedModelDeploymentAction.TaskParams("model1", MODEL_MEMORY_REQUIREMENT, 1, 1)) .addNewRoutingEntry("_node_id4") .addNewFailedRoutingEntry("_node_id2", "test") .addNewRoutingEntry("_node_id1")