|
61 | 61 | import java.util.concurrent.atomic.AtomicLong; |
62 | 62 | import java.util.function.Consumer; |
63 | 63 |
|
| 64 | +import static org.elasticsearch.core.Strings.format; |
64 | 65 | import static org.elasticsearch.xpack.core.ClientHelper.ML_ORIGIN; |
65 | 66 | import static org.elasticsearch.xpack.core.ClientHelper.executeAsyncWithOrigin; |
66 | 67 |
|
@@ -149,34 +150,46 @@ private void doStartDeployment(TrainedModelDeploymentTask task, ActionListener<T |
149 | 150 | TrainedModelConfig modelConfig = getModelResponse.getResources().results().get(0); |
150 | 151 | processContext.modelInput.set(modelConfig.getInput()); |
151 | 152 |
|
152 | | - assert modelConfig.getInferenceConfig() instanceof NlpConfig; |
153 | | - NlpConfig nlpConfig = (NlpConfig) modelConfig.getInferenceConfig(); |
154 | | - task.init(nlpConfig); |
155 | | - |
156 | | - SearchRequest searchRequest = vocabSearchRequest(nlpConfig.getVocabularyConfig(), modelConfig.getModelId()); |
157 | | - executeAsyncWithOrigin(client, ML_ORIGIN, SearchAction.INSTANCE, searchRequest, ActionListener.wrap(searchVocabResponse -> { |
158 | | - if (searchVocabResponse.getHits().getHits().length == 0) { |
159 | | - listener.onFailure( |
160 | | - new ResourceNotFoundException( |
161 | | - Messages.getMessage( |
162 | | - Messages.VOCABULARY_NOT_FOUND, |
163 | | - task.getModelId(), |
164 | | - VocabularyConfig.docId(modelConfig.getModelId()) |
 | 153 | +                        if (modelConfig.getInferenceConfig() instanceof NlpConfig nlpConfig) { 
| 154 | + task.init(nlpConfig); |
| 155 | + |
| 156 | + SearchRequest searchRequest = vocabSearchRequest(nlpConfig.getVocabularyConfig(), modelConfig.getModelId()); |
| 157 | + executeAsyncWithOrigin(client, ML_ORIGIN, SearchAction.INSTANCE, searchRequest, ActionListener.wrap(searchVocabResponse -> { |
| 158 | + if (searchVocabResponse.getHits().getHits().length == 0) { |
| 159 | + listener.onFailure( |
| 160 | + new ResourceNotFoundException( |
| 161 | + Messages.getMessage( |
| 162 | + Messages.VOCABULARY_NOT_FOUND, |
| 163 | + task.getModelId(), |
| 164 | + VocabularyConfig.docId(modelConfig.getModelId()) |
| 165 | + ) |
165 | 166 | ) |
166 | | - ) |
| 167 | + ); |
| 168 | + return; |
| 169 | + } |
| 170 | + |
| 171 | + Vocabulary vocabulary = parseVocabularyDocLeniently(searchVocabResponse.getHits().getAt(0)); |
| 172 | + NlpTask nlpTask = new NlpTask(nlpConfig, vocabulary); |
| 173 | + NlpTask.Processor processor = nlpTask.createProcessor(); |
| 174 | + processContext.nlpTaskProcessor.set(processor); |
| 175 | + // here, we are being called back on the searching thread, which MAY be a network thread |
| 176 | + // `startAndLoad` creates named pipes, blocking the calling thread, better to execute that in our utility |
| 177 | + // executor. |
| 178 | + executorServiceForDeployment.execute( |
| 179 | + () -> startAndLoad(processContext, modelConfig.getLocation(), modelLoadedListener) |
167 | 180 | ); |
168 | | - return; |
169 | | - } |
170 | | - |
171 | | - Vocabulary vocabulary = parseVocabularyDocLeniently(searchVocabResponse.getHits().getAt(0)); |
172 | | - NlpTask nlpTask = new NlpTask(nlpConfig, vocabulary); |
173 | | - NlpTask.Processor processor = nlpTask.createProcessor(); |
174 | | - processContext.nlpTaskProcessor.set(processor); |
175 | | - // here, we are being called back on the searching thread, which MAY be a network thread |
176 | | - // `startAndLoad` creates named pipes, blocking the calling thread, better to execute that in our utility |
177 | | - // executor. |
178 | | - executorServiceForDeployment.execute(() -> startAndLoad(processContext, modelConfig.getLocation(), modelLoadedListener)); |
179 | | - }, listener::onFailure)); |
| 181 | + }, listener::onFailure)); |
| 182 | + } else { |
| 183 | + listener.onFailure( |
| 184 | + new IllegalArgumentException( |
| 185 | + format( |
| 186 | + "[%s] must be a pytorch model; found inference config of kind [%s]", |
| 187 | + modelConfig.getModelId(), |
| 188 | + modelConfig.getInferenceConfig().getWriteableName() |
| 189 | + ) |
| 190 | + ) |
| 191 | + ); |
| 192 | + } |
180 | 193 | }, listener::onFailure); |
181 | 194 |
|
182 | 195 | executeAsyncWithOrigin( |
@@ -404,10 +417,12 @@ private Consumer<String> onProcessCrash() { |
404 | 417 | } |
405 | 418 |
|
406 | 419 | void loadModel(TrainedModelLocation modelLocation, ActionListener<Boolean> listener) { |
407 | | - if (modelLocation instanceof IndexLocation) { |
408 | | - process.get().loadModel(task.getModelId(), ((IndexLocation) modelLocation).getIndexName(), stateStreamer, listener); |
| 420 | + if (modelLocation instanceof IndexLocation indexLocation) { |
| 421 | + process.get().loadModel(task.getModelId(), indexLocation.getIndexName(), stateStreamer, listener); |
409 | 422 | } else { |
410 | | - throw new IllegalStateException("unsupported trained model location [" + modelLocation.getClass().getSimpleName() + "]"); |
| 423 | + listener.onFailure( |
| 424 | + new IllegalStateException("unsupported trained model location [" + modelLocation.getClass().getSimpleName() + "]") |
| 425 | + ); |
411 | 426 | } |
412 | 427 | } |
413 | 428 |
|
|
0 commit comments