From 6cb8215e54ba10c2596dab7bde0b812bc329baf9 Mon Sep 17 00:00:00 2001 From: Dimitris Athanasiou Date: Tue, 5 Feb 2019 13:45:39 +0200 Subject: [PATCH 1/2] [FEATURE][ML] Allow parsing different types of results from analytics process --- .../process/AnalyticsProcessManager.java | 8 +- .../ml/dataframe/process/AnalyticsResult.java | 34 ++--- .../process/AnalyticsResultProcessor.java | 86 ++---------- .../process/DataFrameRowsJoiner.java | 124 +++++++++++++++++ .../dataframe/process/results/RowResults.java | 73 ++++++++++ .../AnalyticsResultProcessorTests.java | 107 +++----------- .../process/AnalyticsResultTests.java | 16 +-- .../process/DataFrameRowsJoinerTests.java | 131 ++++++++++++++++++ .../process/results/RowResultsTests.java | 42 ++++++ 9 files changed, 421 insertions(+), 200 deletions(-) create mode 100644 x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/DataFrameRowsJoiner.java create mode 100644 x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/results/RowResults.java create mode 100644 x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/DataFrameRowsJoinerTests.java create mode 100644 x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/results/RowResultsTests.java diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsProcessManager.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsProcessManager.java index 168662e5cf3c0..5868cefe3a30e 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsProcessManager.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsProcessManager.java @@ -16,10 +16,10 @@ import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig; import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; import org.elasticsearch.xpack.ml.MachineLearning; -import 
org.elasticsearch.xpack.ml.dataframe.extractor.DataFrameDataExtractor; -import org.elasticsearch.xpack.ml.dataframe.extractor.DataFrameDataExtractorFactory; import org.elasticsearch.xpack.ml.dataframe.analyses.DataFrameAnalysesUtils; import org.elasticsearch.xpack.ml.dataframe.analyses.DataFrameAnalysis; +import org.elasticsearch.xpack.ml.dataframe.extractor.DataFrameDataExtractor; +import org.elasticsearch.xpack.ml.dataframe.extractor.DataFrameDataExtractorFactory; import java.io.IOException; import java.util.List; @@ -51,7 +51,9 @@ public void runJob(DataFrameAnalyticsConfig config, DataFrameDataExtractorFactor DataFrameDataExtractor dataExtractor = dataExtractorFactory.newExtractor(false); AnalyticsProcess process = createProcess(config.getId(), createProcessConfig(config, dataExtractor)); ExecutorService executorService = threadPool.executor(MachineLearning.AUTODETECT_THREAD_POOL_NAME); - AnalyticsResultProcessor resultProcessor = new AnalyticsResultProcessor(client, dataExtractorFactory.newExtractor(true)); + DataFrameRowsJoiner dataFrameRowsJoiner = new DataFrameRowsJoiner(config.getId(), client, + dataExtractorFactory.newExtractor(true)); + AnalyticsResultProcessor resultProcessor = new AnalyticsResultProcessor(dataFrameRowsJoiner); executorService.execute(() -> resultProcessor.process(process)); executorService.execute(() -> processData(config.getId(), dataExtractor, process, resultProcessor, finishHandler)); }); diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsResult.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsResult.java index 3e9c1b8b9cd57..4d15bf89b29e9 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsResult.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsResult.java @@ -9,46 +9,38 @@ import org.elasticsearch.common.xcontent.ConstructingObjectParser; import 
org.elasticsearch.common.xcontent.ToXContentObject; import org.elasticsearch.common.xcontent.XContentBuilder; +import org.elasticsearch.xpack.ml.dataframe.process.results.RowResults; import java.io.IOException; -import java.util.Map; import java.util.Objects; public class AnalyticsResult implements ToXContentObject { public static final ParseField TYPE = new ParseField("analytics_result"); - public static final ParseField CHECKSUM = new ParseField("checksum"); - public static final ParseField RESULTS = new ParseField("results"); static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>(TYPE.getPreferredName(), - a -> new AnalyticsResult((Integer) a[0], (Map) a[1])); + a -> new AnalyticsResult((RowResults) a[0])); static { - PARSER.declareInt(ConstructingObjectParser.constructorArg(), CHECKSUM); - PARSER.declareObject(ConstructingObjectParser.constructorArg(), (p, context) -> p.map(), RESULTS); + PARSER.declareObject(ConstructingObjectParser.optionalConstructorArg(), RowResults.PARSER, RowResults.TYPE); } - private final int checksum; - private final Map results; + private final RowResults rowResults; - public AnalyticsResult(int checksum, Map results) { - this.checksum = Objects.requireNonNull(checksum); - this.results = Objects.requireNonNull(results); + public AnalyticsResult(RowResults rowResults) { + this.rowResults = rowResults; } - public int getChecksum() { - return checksum; - } - - public Map getResults() { - return results; + public RowResults getRowResults() { + return rowResults; } @Override public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { builder.startObject(); - builder.field(CHECKSUM.getPreferredName(), checksum); - builder.field(RESULTS.getPreferredName(), results); + if (rowResults != null) { + builder.field(RowResults.TYPE.getPreferredName(), rowResults); + } builder.endObject(); return builder; } @@ -63,11 +55,11 @@ public boolean equals(Object other) { } AnalyticsResult that = 
(AnalyticsResult) other; - return checksum == that.checksum && Objects.equals(results, that.results); + return Objects.equals(rowResults, that.rowResults); } @Override public int hashCode() { - return Objects.hash(checksum, results); + return Objects.hash(rowResults); } } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsResultProcessor.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsResultProcessor.java index cd6864f049740..1b70d68598df6 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsResultProcessor.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsResultProcessor.java @@ -7,22 +7,10 @@ import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; -import org.elasticsearch.action.DocWriteRequest; -import org.elasticsearch.action.bulk.BulkAction; -import org.elasticsearch.action.bulk.BulkRequest; -import org.elasticsearch.action.bulk.BulkResponse; -import org.elasticsearch.action.index.IndexRequest; -import org.elasticsearch.client.Client; -import org.elasticsearch.search.SearchHit; -import org.elasticsearch.xpack.ml.dataframe.extractor.DataFrameDataExtractor; +import org.elasticsearch.xpack.ml.dataframe.process.results.RowResults; -import java.util.ArrayList; import java.util.Iterator; -import java.util.LinkedHashMap; -import java.util.List; -import java.util.Map; import java.util.Objects; -import java.util.Optional; import java.util.concurrent.CountDownLatch; import java.util.concurrent.TimeUnit; @@ -30,15 +18,11 @@ public class AnalyticsResultProcessor { private static final Logger LOGGER = LogManager.getLogger(AnalyticsResultProcessor.class); - private final Client client; - private final DataFrameDataExtractor dataExtractor; - private List currentDataFrameRows; - private List currentResults; + private final DataFrameRowsJoiner dataFrameRowsJoiner; 
private final CountDownLatch completionLatch = new CountDownLatch(1); - public AnalyticsResultProcessor(Client client, DataFrameDataExtractor dataExtractor) { - this.client = Objects.requireNonNull(client); - this.dataExtractor = Objects.requireNonNull(dataExtractor); + public AnalyticsResultProcessor(DataFrameRowsJoiner dataFrameRowsJoiner) { + this.dataFrameRowsJoiner = Objects.requireNonNull(dataFrameRowsJoiner); } public void awaitForCompletion() { @@ -57,28 +41,8 @@ public void process(AnalyticsProcess process) { try { Iterator iterator = process.readAnalyticsResults(); while (iterator.hasNext()) { - try { - AnalyticsResult result = iterator.next(); - if (dataExtractor.hasNext() == false) { - return; - } - if (currentDataFrameRows == null) { - Optional> nextBatch = dataExtractor.next(); - if (nextBatch.isPresent() == false) { - return; - } - currentDataFrameRows = nextBatch.get(); - currentResults = new ArrayList<>(currentDataFrameRows.size()); - } - currentResults.add(result); - if (currentResults.size() == currentDataFrameRows.size()) { - joinCurrentResults(); - currentDataFrameRows = null; - } - } catch (Exception e) { - LOGGER.warn("Error processing data frame analytics result", e); - } - + AnalyticsResult result = iterator.next(); + processResult(result); } } catch (Exception e) { LOGGER.error("Error parsing data frame analytics output", e); @@ -88,40 +52,10 @@ public void process(AnalyticsProcess process) { } } - private void joinCurrentResults() { - BulkRequest bulkRequest = new BulkRequest(); - for (int i = 0; i < currentDataFrameRows.size(); i++) { - DataFrameDataExtractor.Row row = currentDataFrameRows.get(i); - if (row.shouldSkip()) { - continue; - } - AnalyticsResult result = currentResults.get(i); - checkChecksumsMatch(row, result); - - SearchHit hit = row.getHit(); - Map source = new LinkedHashMap(hit.getSourceAsMap()); - source.putAll(result.getResults()); - IndexRequest indexRequest = new IndexRequest(hit.getIndex(), hit.getType(), 
hit.getId()); - indexRequest.source(source); - indexRequest.opType(DocWriteRequest.OpType.INDEX); - bulkRequest.add(indexRequest); - } - if (bulkRequest.numberOfActions() > 0) { - BulkResponse bulkResponse = client.execute(BulkAction.INSTANCE, bulkRequest).actionGet(); - if (bulkResponse.hasFailures()) { - LOGGER.error("Failures while writing data frame"); - // TODO Better error handling - } - } - } - - private void checkChecksumsMatch(DataFrameDataExtractor.Row row, AnalyticsResult result) { - if (row.getChecksum() != result.getChecksum()) { - String msg = "Detected checksum mismatch for document with id [" + row.getHit().getId() + "]; "; - msg += "expected [" + row.getChecksum() + "] but result had [" + result.getChecksum() + "]; "; - msg += "this implies the data frame index [" + row.getHit().getIndex() + "] was modified while the analysis was running. "; - msg += "We rely on this index being immutable during a running analysis and so the results will be unreliable."; - throw new IllegalStateException(msg); + private void processResult(AnalyticsResult result) { + RowResults rowResults = result.getRowResults(); + if (rowResults != null) { + dataFrameRowsJoiner.processRowResults(rowResults); } } } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/DataFrameRowsJoiner.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/DataFrameRowsJoiner.java new file mode 100644 index 0000000000000..6dbebe5ab6e32 --- /dev/null +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/DataFrameRowsJoiner.java @@ -0,0 +1,124 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. 
+ */ +package org.elasticsearch.xpack.ml.dataframe.process; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.apache.logging.log4j.message.ParameterizedMessage; +import org.elasticsearch.action.DocWriteRequest; +import org.elasticsearch.action.bulk.BulkAction; +import org.elasticsearch.action.bulk.BulkRequest; +import org.elasticsearch.action.bulk.BulkResponse; +import org.elasticsearch.action.index.IndexRequest; +import org.elasticsearch.client.Client; +import org.elasticsearch.search.SearchHit; +import org.elasticsearch.xpack.ml.dataframe.extractor.DataFrameDataExtractor; +import org.elasticsearch.xpack.ml.dataframe.process.results.RowResults; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.Optional; + +public class DataFrameRowsJoiner { + + private static final Logger LOGGER = LogManager.getLogger(DataFrameRowsJoiner.class); + + private final String analyticsId; + private final Client client; + private final DataFrameDataExtractor dataExtractor; + private List currentDataFrameRows; + private List currentResults; + private boolean failed; + + public DataFrameRowsJoiner(String analyticsId, Client client, DataFrameDataExtractor dataExtractor) { + this.analyticsId = Objects.requireNonNull(analyticsId); + this.client = Objects.requireNonNull(client); + this.dataExtractor = Objects.requireNonNull(dataExtractor); + } + + public void processRowResults(RowResults rowResults) { + if (failed) { + // If we are in failed state we drop the results but we let the processor + // parse the output + return; + } + + try { + addResultAndJoinIfEndOfBatch(rowResults); + } catch (Exception e) { + LOGGER.error(new ParameterizedMessage("[{}] Failed to join results", analyticsId), e); + failed = true; + } + } + + private void addResultAndJoinIfEndOfBatch(RowResults rowResults) { + if 
(currentDataFrameRows == null) { + Optional> nextBatch = getNextBatch(); + if (nextBatch.isPresent() == false) { + return; + } + currentDataFrameRows = nextBatch.get(); + currentResults = new ArrayList<>(currentDataFrameRows.size()); + } + currentResults.add(rowResults); + if (currentResults.size() == currentDataFrameRows.size()) { + joinCurrentResults(); + currentDataFrameRows = null; + } + } + + private Optional> getNextBatch() { + try { + return dataExtractor.next(); + } catch (IOException e) { + // TODO Implement recovery strategy or better error reporting + LOGGER.error("Error reading next batch of data frame rows", e); + return Optional.empty(); + } + } + + private void joinCurrentResults() { + BulkRequest bulkRequest = new BulkRequest(); + for (int i = 0; i < currentDataFrameRows.size(); i++) { + DataFrameDataExtractor.Row row = currentDataFrameRows.get(i); + if (row.shouldSkip()) { + continue; + } + RowResults result = currentResults.get(i); + checkChecksumsMatch(row, result); + + SearchHit hit = row.getHit(); + Map source = new LinkedHashMap(hit.getSourceAsMap()); + source.putAll(result.getResults()); + IndexRequest indexRequest = new IndexRequest(hit.getIndex(), hit.getType(), hit.getId()); + indexRequest.source(source); + indexRequest.opType(DocWriteRequest.OpType.INDEX); + bulkRequest.add(indexRequest); + } + if (bulkRequest.numberOfActions() > 0) { + BulkResponse bulkResponse = client.execute(BulkAction.INSTANCE, bulkRequest).actionGet(); + if (bulkResponse.hasFailures()) { + LOGGER.error("Failures while writing data frame"); + // TODO Better error handling + } + } + } + + private void checkChecksumsMatch(DataFrameDataExtractor.Row row, RowResults result) { + if (row.getChecksum() != result.getChecksum()) { + String msg = "Detected checksum mismatch for document with id [" + row.getHit().getId() + "]; "; + msg += "expected [" + row.getChecksum() + "] but result had [" + result.getChecksum() + "]; "; + msg += "this implies the data frame index [" + 
row.getHit().getIndex() + "] was modified while the analysis was running. "; + msg += "We rely on this index being immutable during a running analysis and so the results will be unreliable."; + throw new RuntimeException(msg); + // TODO Communicate this error to the user as effectively the analytics have failed (e.g. FAILED state, audit error, etc.) + } + } +} diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/results/RowResults.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/results/RowResults.java new file mode 100644 index 0000000000000..ba4aebededa2e --- /dev/null +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/results/RowResults.java @@ -0,0 +1,73 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. 
+ */ +package org.elasticsearch.xpack.ml.dataframe.process.results; + +import org.elasticsearch.common.ParseField; +import org.elasticsearch.common.xcontent.ConstructingObjectParser; +import org.elasticsearch.common.xcontent.ToXContentObject; +import org.elasticsearch.common.xcontent.XContentBuilder; + +import java.io.IOException; +import java.util.Map; +import java.util.Objects; + +public class RowResults implements ToXContentObject { + + public static final ParseField TYPE = new ParseField("row_results"); + public static final ParseField CHECKSUM = new ParseField("checksum"); + public static final ParseField RESULTS = new ParseField("results"); + + public static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>(TYPE.getPreferredName(), + a -> new RowResults((Integer) a[0], (Map) a[1])); + + static { + PARSER.declareInt(ConstructingObjectParser.constructorArg(), CHECKSUM); + PARSER.declareObject(ConstructingObjectParser.constructorArg(), (p, context) -> p.map(), RESULTS); + } + + private final int checksum; + private final Map results; + + public RowResults(int checksum, Map results) { + this.checksum = Objects.requireNonNull(checksum); + this.results = Objects.requireNonNull(results); + } + + public int getChecksum() { + return checksum; + } + + public Map getResults() { + return results; + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + builder.field(CHECKSUM.getPreferredName(), checksum); + builder.field(RESULTS.getPreferredName(), results); + builder.endObject(); + return builder; + } + + @Override + public boolean equals(Object other) { + if (this == other) { + return true; + } + if (other == null || getClass() != other.getClass()) { + return false; + } + + RowResults that = (RowResults) other; + return checksum == that.checksum && Objects.equals(results, that.results); + } + + @Override + public int hashCode() { + return Objects.hash(checksum, 
results); + } +} diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsResultProcessorTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsResultProcessorTests.java index 15d2e012080e1..cded955767344 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsResultProcessorTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsResultProcessorTests.java @@ -5,47 +5,29 @@ */ package org.elasticsearch.xpack.ml.dataframe.process; -import org.elasticsearch.action.ActionFuture; -import org.elasticsearch.action.bulk.BulkAction; -import org.elasticsearch.action.bulk.BulkItemResponse; -import org.elasticsearch.action.bulk.BulkRequest; -import org.elasticsearch.action.bulk.BulkResponse; -import org.elasticsearch.action.index.IndexRequest; -import org.elasticsearch.client.Client; -import org.elasticsearch.common.bytes.BytesArray; -import org.elasticsearch.common.text.Text; -import org.elasticsearch.search.SearchHit; import org.elasticsearch.test.ESTestCase; -import org.elasticsearch.xpack.ml.dataframe.extractor.DataFrameDataExtractor; +import org.elasticsearch.xpack.ml.dataframe.process.results.RowResults; import org.junit.Before; -import org.mockito.ArgumentCaptor; +import org.mockito.InOrder; +import org.mockito.Mockito; -import java.io.IOException; import java.util.Arrays; import java.util.Collections; -import java.util.HashMap; import java.util.List; -import java.util.Map; -import java.util.Optional; -import static org.hamcrest.Matchers.equalTo; -import static org.mockito.Matchers.same; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.verifyNoMoreInteractions; import static org.mockito.Mockito.when; public class AnalyticsResultProcessorTests extends ESTestCase { - private Client client; private AnalyticsProcess process; - private DataFrameDataExtractor dataExtractor; - 
private ArgumentCaptor bulkRequestCaptor = ArgumentCaptor.forClass(BulkRequest.class); + private DataFrameRowsJoiner dataFrameRowsJoiner; @Before public void setUpMocks() { - client = mock(Client.class); process = mock(AnalyticsProcess.class); - dataExtractor = mock(DataFrameDataExtractor.class); + dataFrameRowsJoiner = mock(DataFrameRowsJoiner.class); } public void testProcess_GivenNoResults() { @@ -55,93 +37,38 @@ public void testProcess_GivenNoResults() { resultProcessor.process(process); resultProcessor.awaitForCompletion(); - verifyNoMoreInteractions(client); + verifyNoMoreInteractions(dataFrameRowsJoiner); } - public void testProcess_GivenSingleRowAndResult() throws IOException { - givenClientHasNoFailures(); - - String dataDoc = "{\"f_1\": \"foo\", \"f_2\": 42.0}"; - String[] dataValues = {"42.0"}; - DataFrameDataExtractor.Row row = newRow(newHit(dataDoc), dataValues, 1); - givenSingleDataFrameBatch(Arrays.asList(row)); - - Map resultFields = new HashMap<>(); - resultFields.put("a", "1"); - resultFields.put("b", "2"); - AnalyticsResult result = new AnalyticsResult(1, resultFields); - givenProcessResults(Arrays.asList(result)); - + public void testProcess_GivenEmptyResults() { + givenProcessResults(Arrays.asList(new AnalyticsResult(null), new AnalyticsResult(null))); AnalyticsResultProcessor resultProcessor = createResultProcessor(); resultProcessor.process(process); resultProcessor.awaitForCompletion(); - List capturedBulkRequests = bulkRequestCaptor.getAllValues(); - assertThat(capturedBulkRequests.size(), equalTo(1)); - BulkRequest capturedBulkRequest = capturedBulkRequests.get(0); - assertThat(capturedBulkRequest.numberOfActions(), equalTo(1)); - IndexRequest indexRequest = (IndexRequest) capturedBulkRequest.requests().get(0); - Map indexedDocSource = indexRequest.sourceAsMap(); - assertThat(indexedDocSource.size(), equalTo(4)); - assertThat(indexedDocSource.get("f_1"), equalTo("foo")); - assertThat(indexedDocSource.get("f_2"), equalTo(42.0)); - 
assertThat(indexedDocSource.get("a"), equalTo("1")); - assertThat(indexedDocSource.get("b"), equalTo("2")); + Mockito.verifyNoMoreInteractions(dataFrameRowsJoiner); } - public void testProcess_GivenSingleRowAndResultWithMismatchingIdHash() throws IOException { - givenClientHasNoFailures(); - - String dataDoc = "{\"f_1\": \"foo\", \"f_2\": 42.0}"; - String[] dataValues = {"42.0"}; - DataFrameDataExtractor.Row row = newRow(newHit(dataDoc), dataValues, 1); - givenSingleDataFrameBatch(Arrays.asList(row)); - - Map resultFields = new HashMap<>(); - resultFields.put("a", "1"); - resultFields.put("b", "2"); - AnalyticsResult result = new AnalyticsResult(2, resultFields); - givenProcessResults(Arrays.asList(result)); - + public void testProcess_GivenRowResults() { + RowResults rowResults1 = mock(RowResults.class); + RowResults rowResults2 = mock(RowResults.class); + givenProcessResults(Arrays.asList(new AnalyticsResult(rowResults1), new AnalyticsResult(rowResults2))); AnalyticsResultProcessor resultProcessor = createResultProcessor(); resultProcessor.process(process); resultProcessor.awaitForCompletion(); - verifyNoMoreInteractions(client); + InOrder inOrder = Mockito.inOrder(dataFrameRowsJoiner); + inOrder.verify(dataFrameRowsJoiner).processRowResults(rowResults1); + inOrder.verify(dataFrameRowsJoiner).processRowResults(rowResults2); } private void givenProcessResults(List results) { when(process.readAnalyticsResults()).thenReturn(results.iterator()); } - private void givenSingleDataFrameBatch(List batch) throws IOException { - when(dataExtractor.hasNext()).thenReturn(true).thenReturn(true).thenReturn(false); - when(dataExtractor.next()).thenReturn(Optional.of(batch)).thenReturn(Optional.empty()); - } - - private static SearchHit newHit(String json) { - SearchHit hit = new SearchHit(randomInt(), randomAlphaOfLength(10), new Text("doc"), Collections.emptyMap()); - hit.sourceRef(new BytesArray(json)); - return hit; - } - - private static DataFrameDataExtractor.Row 
newRow(SearchHit hit, String[] values, int checksum) { - DataFrameDataExtractor.Row row = mock(DataFrameDataExtractor.Row.class); - when(row.getHit()).thenReturn(hit); - when(row.getValues()).thenReturn(values); - when(row.getChecksum()).thenReturn(checksum); - return row; - } - - private void givenClientHasNoFailures() { - ActionFuture responseFuture = mock(ActionFuture.class); - when(responseFuture.actionGet()).thenReturn(new BulkResponse(new BulkItemResponse[0], 0)); - when(client.execute(same(BulkAction.INSTANCE), bulkRequestCaptor.capture())).thenReturn(responseFuture); - } - private AnalyticsResultProcessor createResultProcessor() { - return new AnalyticsResultProcessor(client, dataExtractor); + return new AnalyticsResultProcessor(dataFrameRowsJoiner); } } diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsResultTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsResultTests.java index fc46b4e984d26..d0d3b4ee5f99c 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsResultTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsResultTests.java @@ -7,24 +7,20 @@ import org.elasticsearch.common.xcontent.XContentParser; import org.elasticsearch.test.AbstractXContentTestCase; +import org.elasticsearch.xpack.ml.dataframe.process.results.RowResults; +import org.elasticsearch.xpack.ml.dataframe.process.results.RowResultsTests; import java.io.IOException; -import java.util.HashMap; -import java.util.Map; public class AnalyticsResultTests extends AbstractXContentTestCase { @Override protected AnalyticsResult createTestInstance() { - int checksum = randomInt(); - Map results = new HashMap<>(); - int resultsSize = randomIntBetween(1, 10); - for (int i = 0; i < resultsSize; i++) { - String resultField = randomAlphaOfLength(20); - Object resultValue = randomBoolean() ? 
randomAlphaOfLength(20) : randomDouble(); - results.put(resultField, resultValue); + RowResults rowResults = null; + if (randomBoolean()) { + rowResults = RowResultsTests.createRandom(); } - return new AnalyticsResult(checksum, results); + return new AnalyticsResult(rowResults); } @Override diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/DataFrameRowsJoinerTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/DataFrameRowsJoinerTests.java new file mode 100644 index 0000000000000..4c6d9e78a9300 --- /dev/null +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/DataFrameRowsJoinerTests.java @@ -0,0 +1,131 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ +package org.elasticsearch.xpack.ml.dataframe.process; + +import org.elasticsearch.action.ActionFuture; +import org.elasticsearch.action.bulk.BulkAction; +import org.elasticsearch.action.bulk.BulkItemResponse; +import org.elasticsearch.action.bulk.BulkRequest; +import org.elasticsearch.action.bulk.BulkResponse; +import org.elasticsearch.action.index.IndexRequest; +import org.elasticsearch.client.Client; +import org.elasticsearch.common.bytes.BytesArray; +import org.elasticsearch.common.text.Text; +import org.elasticsearch.search.SearchHit; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xpack.ml.dataframe.extractor.DataFrameDataExtractor; +import org.elasticsearch.xpack.ml.dataframe.process.results.RowResults; +import org.junit.Before; +import org.mockito.ArgumentCaptor; + +import java.io.IOException; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Optional; + +import static 
org.hamcrest.Matchers.equalTo; +import static org.mockito.Matchers.same; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verifyNoMoreInteractions; +import static org.mockito.Mockito.when; + +public class DataFrameRowsJoinerTests extends ESTestCase { + + private static final String ANALYTICS_ID = "my_analytics"; + + private Client client; + private DataFrameDataExtractor dataExtractor; + private DataFrameRowsJoiner dataFrameRowsJoiner; + private ArgumentCaptor bulkRequestCaptor = ArgumentCaptor.forClass(BulkRequest.class); + + @Before + public void setUpMocks() { + client = mock(Client.class); + dataExtractor = mock(DataFrameDataExtractor.class); + dataFrameRowsJoiner = new DataFrameRowsJoiner(ANALYTICS_ID, client, dataExtractor); + } + + public void testProcess_GivenNoResults() { + givenProcessResults(Collections.emptyList()); + verifyNoMoreInteractions(client); + } + + public void testProcess_GivenSingleRowAndResult() throws IOException { + givenClientHasNoFailures(); + + String dataDoc = "{\"f_1\": \"foo\", \"f_2\": 42.0}"; + String[] dataValues = {"42.0"}; + DataFrameDataExtractor.Row row = newRow(newHit(dataDoc), dataValues, 1); + givenSingleDataFrameBatch(Arrays.asList(row)); + + Map resultFields = new HashMap<>(); + resultFields.put("a", "1"); + resultFields.put("b", "2"); + RowResults result = new RowResults(1, resultFields); + givenProcessResults(Arrays.asList(result)); + + List capturedBulkRequests = bulkRequestCaptor.getAllValues(); + assertThat(capturedBulkRequests.size(), equalTo(1)); + BulkRequest capturedBulkRequest = capturedBulkRequests.get(0); + assertThat(capturedBulkRequest.numberOfActions(), equalTo(1)); + IndexRequest indexRequest = (IndexRequest) capturedBulkRequest.requests().get(0); + Map indexedDocSource = indexRequest.sourceAsMap(); + assertThat(indexedDocSource.size(), equalTo(4)); + assertThat(indexedDocSource.get("f_1"), equalTo("foo")); + assertThat(indexedDocSource.get("f_2"), equalTo(42.0)); + 
assertThat(indexedDocSource.get("a"), equalTo("1")); + assertThat(indexedDocSource.get("b"), equalTo("2")); + } + + public void testProcess_GivenSingleRowAndResultWithMismatchingIdHash() throws IOException { + givenClientHasNoFailures(); + + String dataDoc = "{\"f_1\": \"foo\", \"f_2\": 42.0}"; + String[] dataValues = {"42.0"}; + DataFrameDataExtractor.Row row = newRow(newHit(dataDoc), dataValues, 1); + givenSingleDataFrameBatch(Arrays.asList(row)); + + Map resultFields = new HashMap<>(); + resultFields.put("a", "1"); + resultFields.put("b", "2"); + RowResults result = new RowResults(2, resultFields); + givenProcessResults(Arrays.asList(result)); + + verifyNoMoreInteractions(client); + } + + private void givenProcessResults(List results) { + results.forEach(dataFrameRowsJoiner::processRowResults); + } + + private void givenSingleDataFrameBatch(List batch) throws IOException { + when(dataExtractor.hasNext()).thenReturn(true).thenReturn(true).thenReturn(false); + when(dataExtractor.next()).thenReturn(Optional.of(batch)).thenReturn(Optional.empty()); + } + + private static SearchHit newHit(String json) { + SearchHit hit = new SearchHit(randomInt(), randomAlphaOfLength(10), new Text("doc"), Collections.emptyMap()); + hit.sourceRef(new BytesArray(json)); + return hit; + } + + private static DataFrameDataExtractor.Row newRow(SearchHit hit, String[] values, int checksum) { + DataFrameDataExtractor.Row row = mock(DataFrameDataExtractor.Row.class); + when(row.getHit()).thenReturn(hit); + when(row.getValues()).thenReturn(values); + when(row.getChecksum()).thenReturn(checksum); + return row; + } + + private void givenClientHasNoFailures() { + ActionFuture responseFuture = mock(ActionFuture.class); + when(responseFuture.actionGet()).thenReturn(new BulkResponse(new BulkItemResponse[0], 0)); + when(client.execute(same(BulkAction.INSTANCE), bulkRequestCaptor.capture())).thenReturn(responseFuture); + } +} diff --git 
a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/results/RowResultsTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/results/RowResultsTests.java
new file mode 100644
index 0000000000000..5fdeee90329ae
--- /dev/null
+++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/results/RowResultsTests.java
@@ -0,0 +1,42 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License;
+ * you may not use this file except in compliance with the Elastic License.
+ */
+package org.elasticsearch.xpack.ml.dataframe.process.results;
+
+import org.elasticsearch.common.xcontent.XContentParser;
+import org.elasticsearch.test.AbstractXContentTestCase;
+
+import java.util.HashMap;
+import java.util.Map;
+
+public class RowResultsTests extends AbstractXContentTestCase<RowResults> {
+
+    @Override
+    protected RowResults createTestInstance() {
+        return createRandom();
+    }
+
+    public static RowResults createRandom() {
+        int checksum = randomInt();
+        Map<String, Object> results = new HashMap<>();
+        int resultsSize = randomIntBetween(1, 10); // requests between 1 and 10 random result entries
+        for (int i = 0; i < resultsSize; i++) {
+            String resultField = randomAlphaOfLength(20);
+            Object resultValue = randomBoolean() ? randomAlphaOfLength(20) : randomDouble(); // either a String or a double
+            results.put(resultField, resultValue);
+        }
+        return new RowResults(checksum, results);
+    }
+
+    @Override
+    protected RowResults doParseInstance(XContentParser parser) {
+        return RowResults.PARSER.apply(parser, null);
+    }
+
+    @Override
+    protected boolean supportsUnknownFields() {
+        return false;
+    }
+}

From e59fa6eeb8022075e87abd6e0ae5f327628cfed6 Mon Sep 17 00:00:00 2001
From: Dimitris Athanasiou
Date: Tue, 5 Feb 2019 18:16:10 +0200
Subject: [PATCH 2/2] Remove use of type in index request

---
 .../ml/dataframe/process/DataFrameRowsJoiner.java | 10 ++++++----
 1 file changed, 6 insertions(+), 4 deletions(-)

diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/DataFrameRowsJoiner.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/DataFrameRowsJoiner.java
index 6dbebe5ab6e32..76ebe166a39ad 100644
--- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/DataFrameRowsJoiner.java
+++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/DataFrameRowsJoiner.java
@@ -26,7 +26,7 @@
 import java.util.Objects;
 import java.util.Optional;
 
-public class DataFrameRowsJoiner {
+class DataFrameRowsJoiner {
 
     private static final Logger LOGGER = LogManager.getLogger(DataFrameRowsJoiner.class);
 
@@ -37,13 +37,13 @@ public class DataFrameRowsJoiner {
     private List<RowResults> currentResults;
     private boolean failed;
 
-    public DataFrameRowsJoiner(String analyticsId, Client client, DataFrameDataExtractor dataExtractor) {
+    DataFrameRowsJoiner(String analyticsId, Client client, DataFrameDataExtractor dataExtractor) {
         this.analyticsId = Objects.requireNonNull(analyticsId);
         this.client = Objects.requireNonNull(client);
         this.dataExtractor = Objects.requireNonNull(dataExtractor);
     }
 
-    public void processRowResults(RowResults rowResults) {
+    void processRowResults(RowResults rowResults) {
         if (failed) {
             // If we are in failed state we drop the results but 
we let the processor
             // parse the output
@@ -97,7 +97,8 @@ private void joinCurrentResults() {
             SearchHit hit = row.getHit();
             Map<String, Object> source = new LinkedHashMap(hit.getSourceAsMap());
             source.putAll(result.getResults());
-            IndexRequest indexRequest = new IndexRequest(hit.getIndex(), hit.getType(), hit.getId());
+            IndexRequest indexRequest = new IndexRequest(hit.getIndex()); // typeless request: only the index is passed here
+            indexRequest.id(hit.getId()); // id is set explicitly now that the (index, type, id) constructor is not used
             indexRequest.source(source);
             indexRequest.opType(DocWriteRequest.OpType.INDEX);
             bulkRequest.add(indexRequest);