Skip to content

Commit 188d56a

Browse files
authored
[ML] address potential bug where trained models get stuck in starting after being allocated to node (#88945) (#88992)
When a model is starting, it has been rarely observed that it will lock up while trying to restore the model objects to the native process. This would manifest as a trained model being stuck in "starting" while also being assigned to a node. So, there is a native process started and a task available on the assigned node, but the model state never gets out of "starting".
1 parent 3b9c452 commit 188d56a

File tree

4 files changed

+71
-42
lines changed

4 files changed

+71
-42
lines changed

docs/changelog/88945.yaml

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
pr: 88945
2+
summary: Address potential bug where trained models get stuck in starting after being
3+
allocated to node
4+
area: Machine Learning
5+
type: bug
6+
issues: []

x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/assignment/TrainedModelAssignmentNodeService.java

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -225,14 +225,13 @@ void loadQueuedModels() {
225225
} catch (Exception ex) {
226226
logger.warn(() -> "[" + modelId + "] Start deployment failed", ex);
227227
if (ExceptionsHelper.unwrapCause(ex) instanceof ResourceNotFoundException) {
228-
logger.warn(() -> "[" + modelId + "] Start deployment failed", ex);
228+
logger.debug(() -> "[" + modelId + "] Start deployment failed as model was not found", ex);
229229
handleLoadFailure(loadingTask, ExceptionsHelper.missingTrainedModel(modelId, ex));
230230
} else if (ExceptionsHelper.unwrapCause(ex) instanceof SearchPhaseExecutionException) {
231-
logger.trace(() -> "[" + modelId + "] Start deployment failed, will retry", ex);
231+
logger.debug(() -> "[" + modelId + "] Start deployment failed, will retry", ex);
232232
// A search phase execution failure should be retried, push task back to the queue
233233
loadingToRetry.add(loadingTask);
234234
} else {
235-
logger.warn(() -> "[" + modelId + "] Start deployment failed", ex);
236235
handleLoadFailure(loadingTask, ex);
237236
}
238237
}
@@ -413,7 +412,7 @@ private void updateNumberOfAllocations(TrainedModelAssignmentMetadata assignment
413412
for (TrainedModelAssignment assignment : modelsToUpdate) {
414413
TrainedModelDeploymentTask task = modelIdToTask.get(assignment.getModelId());
415414
if (task == null) {
416-
logger.debug(() -> format("[%s] task was removed whilst updating number of allocations", task.getModelId()));
415+
logger.debug(() -> format("[%s] task was removed whilst updating number of allocations", assignment.getModelId()));
417416
continue;
418417
}
419418
RoutingInfo routingInfo = assignment.getNodeRoutingTable().get(nodeId);

x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/deployment/DeploymentManager.java

Lines changed: 44 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,7 @@
6161
import java.util.concurrent.atomic.AtomicLong;
6262
import java.util.function.Consumer;
6363

64+
import static org.elasticsearch.core.Strings.format;
6465
import static org.elasticsearch.xpack.core.ClientHelper.ML_ORIGIN;
6566
import static org.elasticsearch.xpack.core.ClientHelper.executeAsyncWithOrigin;
6667

@@ -149,34 +150,46 @@ private void doStartDeployment(TrainedModelDeploymentTask task, ActionListener<T
149150
TrainedModelConfig modelConfig = getModelResponse.getResources().results().get(0);
150151
processContext.modelInput.set(modelConfig.getInput());
151152

152-
assert modelConfig.getInferenceConfig() instanceof NlpConfig;
153-
NlpConfig nlpConfig = (NlpConfig) modelConfig.getInferenceConfig();
154-
task.init(nlpConfig);
155-
156-
SearchRequest searchRequest = vocabSearchRequest(nlpConfig.getVocabularyConfig(), modelConfig.getModelId());
157-
executeAsyncWithOrigin(client, ML_ORIGIN, SearchAction.INSTANCE, searchRequest, ActionListener.wrap(searchVocabResponse -> {
158-
if (searchVocabResponse.getHits().getHits().length == 0) {
159-
listener.onFailure(
160-
new ResourceNotFoundException(
161-
Messages.getMessage(
162-
Messages.VOCABULARY_NOT_FOUND,
163-
task.getModelId(),
164-
VocabularyConfig.docId(modelConfig.getModelId())
153+
if (modelConfig.getInferenceConfig()instanceof NlpConfig nlpConfig) {
154+
task.init(nlpConfig);
155+
156+
SearchRequest searchRequest = vocabSearchRequest(nlpConfig.getVocabularyConfig(), modelConfig.getModelId());
157+
executeAsyncWithOrigin(client, ML_ORIGIN, SearchAction.INSTANCE, searchRequest, ActionListener.wrap(searchVocabResponse -> {
158+
if (searchVocabResponse.getHits().getHits().length == 0) {
159+
listener.onFailure(
160+
new ResourceNotFoundException(
161+
Messages.getMessage(
162+
Messages.VOCABULARY_NOT_FOUND,
163+
task.getModelId(),
164+
VocabularyConfig.docId(modelConfig.getModelId())
165+
)
165166
)
166-
)
167+
);
168+
return;
169+
}
170+
171+
Vocabulary vocabulary = parseVocabularyDocLeniently(searchVocabResponse.getHits().getAt(0));
172+
NlpTask nlpTask = new NlpTask(nlpConfig, vocabulary);
173+
NlpTask.Processor processor = nlpTask.createProcessor();
174+
processContext.nlpTaskProcessor.set(processor);
175+
// here, we are being called back on the searching thread, which MAY be a network thread
176+
// `startAndLoad` creates named pipes, blocking the calling thread, better to execute that in our utility
177+
// executor.
178+
executorServiceForDeployment.execute(
179+
() -> startAndLoad(processContext, modelConfig.getLocation(), modelLoadedListener)
167180
);
168-
return;
169-
}
170-
171-
Vocabulary vocabulary = parseVocabularyDocLeniently(searchVocabResponse.getHits().getAt(0));
172-
NlpTask nlpTask = new NlpTask(nlpConfig, vocabulary);
173-
NlpTask.Processor processor = nlpTask.createProcessor();
174-
processContext.nlpTaskProcessor.set(processor);
175-
// here, we are being called back on the searching thread, which MAY be a network thread
176-
// `startAndLoad` creates named pipes, blocking the calling thread, better to execute that in our utility
177-
// executor.
178-
executorServiceForDeployment.execute(() -> startAndLoad(processContext, modelConfig.getLocation(), modelLoadedListener));
179-
}, listener::onFailure));
181+
}, listener::onFailure));
182+
} else {
183+
listener.onFailure(
184+
new IllegalArgumentException(
185+
format(
186+
"[%s] must be a pytorch model; found inference config of kind [%s]",
187+
modelConfig.getModelId(),
188+
modelConfig.getInferenceConfig().getWriteableName()
189+
)
190+
)
191+
);
192+
}
180193
}, listener::onFailure);
181194

182195
executeAsyncWithOrigin(
@@ -404,10 +417,12 @@ private Consumer<String> onProcessCrash() {
404417
}
405418

406419
void loadModel(TrainedModelLocation modelLocation, ActionListener<Boolean> listener) {
407-
if (modelLocation instanceof IndexLocation) {
408-
process.get().loadModel(task.getModelId(), ((IndexLocation) modelLocation).getIndexName(), stateStreamer, listener);
420+
if (modelLocation instanceof IndexLocation indexLocation) {
421+
process.get().loadModel(task.getModelId(), indexLocation.getIndexName(), stateStreamer, listener);
409422
} else {
410-
throw new IllegalStateException("unsupported trained model location [" + modelLocation.getClass().getSimpleName() + "]");
423+
listener.onFailure(
424+
new IllegalStateException("unsupported trained model location [" + modelLocation.getClass().getSimpleName() + "]")
425+
);
411426
}
412427
}
413428

x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/persistence/ChunkedTrainedModelRestorer.java

Lines changed: 18 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -10,11 +10,11 @@
1010
import org.apache.logging.log4j.LogManager;
1111
import org.apache.logging.log4j.Logger;
1212
import org.elasticsearch.ResourceNotFoundException;
13-
import org.elasticsearch.action.ActionListener;
14-
import org.elasticsearch.action.search.SearchAction;
1513
import org.elasticsearch.action.search.SearchRequest;
1614
import org.elasticsearch.action.search.SearchRequestBuilder;
15+
import org.elasticsearch.action.search.SearchResponse;
1716
import org.elasticsearch.client.internal.Client;
17+
import org.elasticsearch.client.internal.OriginSettingClient;
1818
import org.elasticsearch.common.bytes.BytesReference;
1919
import org.elasticsearch.common.xcontent.LoggingDeprecationHandler;
2020
import org.elasticsearch.core.CheckedFunction;
@@ -38,8 +38,10 @@
3838
import java.util.concurrent.ExecutorService;
3939
import java.util.function.Consumer;
4040

41+
import static org.elasticsearch.core.Strings.format;
4142
import static org.elasticsearch.xpack.core.ClientHelper.ML_ORIGIN;
42-
import static org.elasticsearch.xpack.core.ClientHelper.executeAsyncWithOrigin;
43+
import static org.elasticsearch.xpack.ml.MachineLearning.NATIVE_INFERENCE_COMMS_THREAD_POOL_NAME;
44+
import static org.elasticsearch.xpack.ml.MachineLearning.UTILITY_THREAD_POOL_NAME;
4345

4446
/**
4547
* Searches for and emits {@link TrainedModelDefinitionDoc}s in
@@ -71,7 +73,7 @@ public ChunkedTrainedModelRestorer(
7173
ExecutorService executorService,
7274
NamedXContentRegistry xContentRegistry
7375
) {
74-
this.client = client;
76+
this.client = new OriginSettingClient(client, ML_ORIGIN);
7577
this.executorService = executorService;
7678
this.xContentRegistry = xContentRegistry;
7779
this.modelId = modelId;
@@ -122,7 +124,6 @@ public void restoreModelDefinition(
122124

123125
logger.debug("[{}] restoring model", modelId);
124126
SearchRequest searchRequest = buildSearch(client, modelId, index, searchSize, null);
125-
126127
executorService.execute(() -> doSearch(searchRequest, modelConsumer, successConsumer, errorConsumer));
127128
}
128129

@@ -132,8 +133,16 @@ private void doSearch(
132133
Consumer<Boolean> successConsumer,
133134
Consumer<Exception> errorConsumer
134135
) {
135-
136-
executeAsyncWithOrigin(client, ML_ORIGIN, SearchAction.INSTANCE, searchRequest, ActionListener.wrap(searchResponse -> {
136+
try {
137+
assert Thread.currentThread().getName().contains(NATIVE_INFERENCE_COMMS_THREAD_POOL_NAME)
138+
|| Thread.currentThread().getName().contains(UTILITY_THREAD_POOL_NAME)
139+
: format(
140+
"Must execute from [%s] or [%s] but thread is [%s]",
141+
NATIVE_INFERENCE_COMMS_THREAD_POOL_NAME,
142+
UTILITY_THREAD_POOL_NAME,
143+
Thread.currentThread().getName()
144+
);
145+
SearchResponse searchResponse = client.search(searchRequest).actionGet();
137146
if (searchResponse.getHits().getHits().length == 0) {
138147
errorConsumer.accept(new ResourceNotFoundException(Messages.getMessage(Messages.MODEL_DEFINITION_NOT_FOUND, modelId)));
139148
return;
@@ -182,13 +191,13 @@ private void doSearch(
182191
searchRequestBuilder.searchAfter(new Object[] { lastHit.getIndex(), lastNum });
183192
executorService.execute(() -> doSearch(searchRequestBuilder.request(), modelConsumer, successConsumer, errorConsumer));
184193
}
185-
}, e -> {
194+
} catch (Exception e) {
186195
if (ExceptionsHelper.unwrapCause(e) instanceof ResourceNotFoundException) {
187196
errorConsumer.accept(new ResourceNotFoundException(Messages.getMessage(Messages.MODEL_DEFINITION_NOT_FOUND, modelId)));
188197
} else {
189198
errorConsumer.accept(e);
190199
}
191-
}));
200+
}
192201
}
193202

194203
private static SearchRequestBuilder buildSearchBuilder(Client client, String modelId, String index, int searchSize) {

0 commit comments

Comments (0)