From 2bd0891b7c6f3a22a8ad7f38178374bee4052320 Mon Sep 17 00:00:00 2001 From: Dimitris Athanasiou Date: Tue, 25 Jun 2019 10:48:27 +0300 Subject: [PATCH] [7.x][ML] Machine learning data frame analytics (#43544) This merges the initial work that adds a framework for performing machine learning analytics on data frames. The feature is currently experimental and requires a platinum license. Note that the original commits can be found in the `feature-ml-data-frame-analytics` branch. A new set of APIs is added which allows the creation of data frame analytics jobs. Configuration allows specifying different types of analysis to be performed on a data frame. Initially, only outlier detection is supported. The APIs are: - PUT _ml/data_frame/analysis/{id} - GET _ml/data_frame/analysis/{id} - GET _ml/data_frame/analysis/{id}/_stats - POST _ml/data_frame/analysis/{id}/_start - POST _ml/data_frame/analysis/{id}/_stop - DELETE _ml/data_frame/analysis/{id} When a data frame analytics job is started, a persistent task is created and started. The main steps of the task are: 1. reindex the source index into the dest index 2. analyze the data through the data_frame_analyzer C++ process 3. 
merge the results of the process back into the destination index In addition, an evaluation API is added which packages commonly used metrics that provide evaluation of various analysis: - POST _ml/data_frame/_evaluate --- .../client/MLRequestConverters.java | 116 +++ .../client/MachineLearningClient.java | 294 ++++++ .../DataFrameNamedXContentProvider.java | 41 + .../transforms/DataFrameTransformConfig.java | 55 +- .../dataframe/transforms/SyncConfig.java | 30 + .../dataframe/transforms/TimeSyncConfig.java | 108 +++ .../ml/DeleteDataFrameAnalyticsRequest.java | 64 ++ .../client/ml/EvaluateDataFrameRequest.java | 136 +++ .../client/ml/EvaluateDataFrameResponse.java | 119 +++ .../ml/GetDataFrameAnalyticsRequest.java | 104 +++ .../ml/GetDataFrameAnalyticsResponse.java | 74 ++ .../ml/GetDataFrameAnalyticsStatsRequest.java | 99 ++ .../GetDataFrameAnalyticsStatsResponse.java | 102 +++ .../client/ml/NodeAttributes.java | 6 + .../ml/PutDataFrameAnalyticsRequest.java | 70 ++ .../ml/PutDataFrameAnalyticsResponse.java | 57 ++ .../ml/StartDataFrameAnalyticsRequest.java | 74 ++ .../ml/StopDataFrameAnalyticsRequest.java | 88 ++ .../ml/StopDataFrameAnalyticsResponse.java | 87 ++ .../ml/dataframe/DataFrameAnalysis.java | 27 + .../dataframe/DataFrameAnalyticsConfig.java | 208 +++++ .../ml/dataframe/DataFrameAnalyticsDest.java | 123 +++ .../dataframe/DataFrameAnalyticsSource.java | 121 +++ .../ml/dataframe/DataFrameAnalyticsState.java | 34 + .../ml/dataframe/DataFrameAnalyticsStats.java | 133 +++ ...ataFrameAnalysisNamedXContentProvider.java | 37 + .../client/ml/dataframe/OutlierDetection.java | 176 ++++ .../client/ml/dataframe/QueryConfig.java | 82 ++ .../ml/dataframe/evaluation/Evaluation.java | 32 + .../evaluation/EvaluationMetric.java | 43 + .../MlEvaluationNamedXContentProvider.java | 57 ++ .../AbstractConfusionMatrixMetric.java | 47 + .../softclassification/AucRocMetric.java | 241 +++++ .../BinarySoftClassification.java | 129 +++ .../ConfusionMatrixMetric.java | 206 +++++ 
.../softclassification/PrecisionMetric.java | 123 +++ .../softclassification/RecallMetric.java | 123 +++ ...icsearch.plugins.spi.NamedXContentProvider | 5 +- .../DataFrameRequestConvertersTests.java | 6 +- .../client/MLRequestConvertersTests.java | 133 ++- .../client/MachineLearningIT.java | 513 +++++++++-- .../client/MlTestStateCleaner.java | 13 + .../client/RestHighLevelClientTests.java | 24 +- .../GetDataFrameTransformResponseTests.java | 6 +- ...PreviewDataFrameTransformRequestTests.java | 6 +- .../PutDataFrameTransformRequestTests.java | 6 +- .../DataFrameTransformConfigTests.java | 12 +- .../transforms/TimeSyncConfigTests.java | 49 + .../transforms/hlrc/TimeSyncConfigTests.java | 59 ++ .../DataFrameTransformDocumentationIT.java | 1 + .../MlClientDocumentationIT.java | 569 +++++++++++- .../ml/AucRocMetricAucRocPointTests.java | 47 + .../client/ml/AucRocMetricResultTests.java | 63 ++ ...usionMatrixMetricConfusionMatrixTests.java | 47 + .../ml/ConfusionMatrixMetricResultTests.java | 62 ++ .../DeleteDataFrameAnalyticsRequestTests.java | 39 + .../ml/EvaluateDataFrameResponseTests.java | 76 ++ .../ml/GetDataFrameAnalyticsRequestTests.java | 39 + ...etDataFrameAnalyticsStatsRequestTests.java | 39 + .../client/ml/PrecisionMetricResultTests.java | 60 ++ .../ml/PutDataFrameAnalyticsRequestTests.java | 74 ++ .../client/ml/RecallMetricResultTests.java | 60 ++ .../StartDataFrameAnalyticsRequestTests.java | 43 + .../StopDataFrameAnalyticsRequestTests.java | 43 + .../StopDataFrameAnalyticsResponseTests.java | 42 + .../DataFrameAnalyticsConfigTests.java | 88 ++ .../DataFrameAnalyticsDestTests.java | 50 + .../DataFrameAnalyticsSourceTests.java | 70 ++ .../DataFrameAnalyticsStatsTests.java | 66 ++ .../ml/dataframe/OutlierDetectionTests.java | 73 ++ .../client/ml/dataframe/QueryConfigTests.java | 62 ++ .../ml/delete-data-frame-analytics.asciidoc | 28 + .../ml/evaluate-data-frame.asciidoc | 45 + .../get-data-frame-analytics-stats.asciidoc | 34 + 
.../ml/get-data-frame-analytics.asciidoc | 34 + .../ml/put-data-frame-analytics.asciidoc | 115 +++ .../ml/start-data-frame-analytics.asciidoc | 28 + .../ml/stop-data-frame-analytics.asciidoc | 28 + .../high-level/supported-apis.asciidoc | 14 + .../xpack/core/XPackClientPlugin.java | 52 ++ .../xpack/core/dataframe/DataFrameField.java | 4 + .../core/dataframe/DataFrameMessages.java | 4 +- .../DataFrameNamedXContentProvider.java | 26 + .../transforms/DataFrameTransformConfig.java | 59 +- .../core/dataframe/transforms/SyncConfig.java | 25 + .../dataframe/transforms/TimeSyncConfig.java | 148 +++ .../pivot/DateHistogramGroupSource.java | 13 + .../pivot/HistogramGroupSource.java | 13 + .../transforms/pivot/PivotConfig.java | 6 +- .../transforms/pivot/SingleGroupSource.java | 6 + .../transforms/pivot/TermsGroupSource.java | 13 + .../xpack/core/ml/MachineLearningField.java | 2 +- .../elasticsearch/xpack/core/ml/MlTasks.java | 31 + .../DeleteDataFrameAnalyticsAction.java | 100 ++ .../ml/action/EvaluateDataFrameAction.java | 215 +++++ .../action/GetDataFrameAnalyticsAction.java | 80 ++ .../GetDataFrameAnalyticsStatsAction.java | 321 +++++++ .../action/PutDataFrameAnalyticsAction.java | 153 ++++ .../action/StartDataFrameAnalyticsAction.java | 223 +++++ .../action/StopDataFrameAnalyticsAction.java | 223 +++++ .../core/ml/datafeed/DatafeedConfig.java | 3 +- .../core/ml/datafeed/DatafeedUpdate.java | 5 +- .../dataframe/DataFrameAnalyticsConfig.java | 312 +++++++ .../ml/dataframe/DataFrameAnalyticsDest.java | 106 +++ .../dataframe/DataFrameAnalyticsSource.java | 144 +++ .../ml/dataframe/DataFrameAnalyticsState.java | 36 + .../DataFrameAnalyticsTaskState.java | 105 +++ .../dataframe/analyses/DataFrameAnalysis.java | 16 + ...ataFrameAnalysisNamedXContentProvider.java | 37 + .../dataframe/analyses/OutlierDetection.java | 169 ++++ .../ml/dataframe/evaluation/Evaluation.java | 37 + .../evaluation/EvaluationMetricResult.java | 20 + .../MlEvaluationNamedXContentProvider.java | 69 ++ 
.../AbstractConfusionMatrixMetric.java | 102 +++ .../evaluation/softclassification/AucRoc.java | 350 +++++++ .../BinarySoftClassification.java | 212 +++++ .../softclassification/ConfusionMatrix.java | 163 ++++ .../softclassification/Precision.java | 91 ++ .../evaluation/softclassification/Recall.java | 91 ++ .../ScoreByThresholdResult.java | 63 ++ .../SoftClassificationMetric.java | 60 ++ .../xpack/core/ml/job/messages/Messages.java | 4 + .../persistence/ElasticsearchMappings.java | 51 ++ .../ml/job/results/ReservedFieldNames.java | 18 + .../core/ml/process/writer/RecordWriter.java | 2 +- .../xpack/core/ml/utils/ExceptionsHelper.java | 13 + .../ml/{datafeed => utils}/QueryProvider.java | 23 +- .../AbstractSerializingDataFrameTestCase.java | 8 + ...tractWireSerializingDataFrameTestCase.java | 8 + ...wDataFrameTransformActionRequestTests.java | 10 +- .../AbstractSerializingDataFrameTestCase.java | 4 + .../DataFrameTransformConfigTests.java | 22 +- .../transforms/TimeSyncConfigTests.java | 38 + .../xpack/core/ml/MlTasksTests.java | 7 + .../EvaluateDataFrameActionRequestTests.java | 58 ++ ...DataFrameAnalyticsActionResponseTests.java | 55 ++ .../GetDataFrameAnalyticsRequestTests.java | 27 + ...rameAnalyticsStatsActionResponseTests.java | 37 + ...etDataFrameAnalyticsStatsRequestTests.java | 26 + ...tDataFrameAnalyticsActionRequestTests.java | 67 ++ ...DataFrameAnalyticsActionResponseTests.java | 48 + .../StartDataFrameAnalyticsRequestTests.java | 28 + ...DataFrameAnalyticsActionResponseTests.java | 23 + .../StopDataFrameAnalyticsRequestTests.java | 40 + .../core/ml/datafeed/DatafeedConfigTests.java | 3 +- .../core/ml/datafeed/DatafeedUpdateTests.java | 3 +- .../DataFrameAnalyticsConfigTests.java | 251 ++++++ .../DataFrameAnalyticsDestTests.java | 55 ++ .../DataFrameAnalyticsSourceTests.java | 64 ++ .../analyses/OutlierDetectionTests.java | 59 ++ .../softclassification/AucRocTests.java | 127 +++ .../BinarySoftClassificationTests.java | 79 ++ .../ConfusionMatrixTests.java 
| 79 ++ .../softclassification/PrecisionTests.java | 93 ++ .../softclassification/RecallTests.java | 93 ++ .../integration/MlRestTestStateCleaner.java | 17 + .../QueryProviderTests.java | 8 +- .../DataFrameTransformProgressIT.java | 3 + .../xpack/dataframe/DataFrame.java | 7 + .../DataFrameTransformsCheckpointService.java | 15 +- .../transforms/DataFrameIndexer.java | 142 ++- ...FrameTransformPersistentTasksExecutor.java | 28 +- .../transforms/DataFrameTransformTask.java | 72 +- .../dataframe/transforms/pivot/Pivot.java | 87 +- .../transforms/DataFrameIndexerTests.java | 14 +- .../ml/qa/ml-with-security/build.gradle | 50 + .../plugin/ml/qa/ml-with-security/roles.yml | 4 +- .../smoketest/MlWithSecurityUserRoleIT.java | 33 +- ...NativeDataFrameAnalyticsIntegTestCase.java | 118 +++ .../ml/integration/MlNativeIntegTestCase.java | 6 + .../integration/RunDataFrameAnalyticsIT.java | 281 ++++++ .../xpack/ml/MachineLearning.java | 82 +- ...ansportDeleteDataFrameAnalyticsAction.java | 112 +++ .../ml/action/TransportDeleteJobAction.java | 2 +- .../TransportEvaluateDataFrameAction.java | 61 ++ .../TransportGetDataFrameAnalyticsAction.java | 82 ++ ...sportGetDataFrameAnalyticsStatsAction.java | 190 ++++ .../ml/action/TransportOpenJobAction.java | 283 +----- .../TransportPutDataFrameAnalyticsAction.java | 160 ++++ ...ransportStartDataFrameAnalyticsAction.java | 452 ++++++++++ ...TransportStopDataFrameAnalyticsAction.java | 247 +++++ .../extractor/fields/ExtractedField.java | 30 + .../extractor/fields/ExtractedFields.java | 9 + .../dataframe/DataFrameAnalyticsFields.java | 13 + .../dataframe/DataFrameAnalyticsManager.java | 257 ++++++ .../ml/dataframe/SourceDestValidator.java | 65 ++ .../extractor/DataFrameDataExtractor.java | 276 ++++++ .../DataFrameDataExtractorContext.java | 35 + .../DataFrameDataExtractorFactory.java | 168 ++++ .../extractor/ExtractedFieldsDetector.java | 162 ++++ .../DataFrameAnalyticsConfigProvider.java | 122 +++ .../dataframe/process/AnalyticsBuilder.java 
| 74 ++ .../AnalyticsControlMessageWriter.java | 38 + .../dataframe/process/AnalyticsProcess.java | 34 + .../process/AnalyticsProcessConfig.java | 76 ++ .../process/AnalyticsProcessFactory.java | 21 + .../process/AnalyticsProcessManager.java | 239 +++++ .../ml/dataframe/process/AnalyticsResult.java | 77 ++ .../process/AnalyticsResultProcessor.java | 79 ++ .../process/DataFrameRowsJoiner.java | 184 ++++ .../process/NativeAnalyticsProcess.java | 50 + .../NativeAnalyticsProcessFactory.java | 84 ++ .../dataframe/process/results/RowResults.java | 73 ++ .../xpack/ml/job/JobNodeSelector.java | 328 +++++++ .../autodetect/NativeAutodetectProcess.java | 19 +- .../NativeAutodetectProcessFactory.java | 9 +- .../writer/AbstractDataToProcessWriter.java | 4 +- .../ml/process/AbstractNativeProcess.java | 13 + .../xpack/ml/process/MlMemoryTracker.java | 191 +++- .../ProcessResultsParser.java} | 34 +- .../RestDeleteDataFrameAnalyticsAction.java | 39 + .../RestEvaluateDataFrameAction.java | 36 + .../RestGetDataFrameAnalyticsAction.java | 51 ++ .../RestGetDataFrameAnalyticsStatsAction.java | 52 ++ .../RestPutDataFrameAnalyticsAction.java | 43 + .../RestStartDataFrameAnalyticsAction.java | 50 + .../RestStopDataFrameAnalyticsAction.java | 54 ++ .../plugin-security-test.policy | 5 + .../action/TransportOpenJobActionTests.java | 394 +------- .../dataframe/SourceDestValidatorTests.java | 176 ++++ .../DataFrameDataExtractorTests.java | 392 ++++++++ .../ExtractedFieldsDetectorTests.java | 319 +++++++ .../AnalyticsControlMessageWriterTests.java | 50 + .../AnalyticsResultProcessorTests.java | 86 ++ .../process/AnalyticsResultTests.java | 39 + .../process/DataFrameRowsJoinerTests.java | 280 ++++++ .../process/results/RowResultsTests.java | 42 + .../xpack/ml/job/JobNodeSelectorTests.java | 575 ++++++++++++ .../NativeAutodetectProcessTests.java | 13 +- .../output/AutodetectResultsParserTests.java | 422 --------- .../ml/process/MlMemoryTrackerTests.java | 62 +- 
.../ml/process/ProcessResultsParserTests.java | 113 +++ .../xpack/ml/support/BaseMlIntegTestCase.java | 21 + .../api/ml.delete_data_frame_analytics.json | 18 + .../api/ml.evaluate_data_frame.json | 15 + .../api/ml.get_data_frame_analytics.json | 38 + .../ml.get_data_frame_analytics_stats.json | 38 + .../api/ml.put_data_frame_analytics.json | 21 + .../api/ml.start_data_frame_analytics.json | 27 + .../api/ml.stop_data_frame_analytics.json | 32 + .../test/ml/data_frame_analytics_crud.yml | 851 ++++++++++++++++++ .../test/ml/evaluate_data_frame.yml | 520 +++++++++++ .../test/ml/start_data_frame_analytics.yml | 74 ++ .../test/ml/stop_data_frame_analytics.yml | 70 ++ 244 files changed, 20932 insertions(+), 1374 deletions(-) create mode 100644 client/rest-high-level/src/main/java/org/elasticsearch/client/dataframe/DataFrameNamedXContentProvider.java create mode 100644 client/rest-high-level/src/main/java/org/elasticsearch/client/dataframe/transforms/SyncConfig.java create mode 100644 client/rest-high-level/src/main/java/org/elasticsearch/client/dataframe/transforms/TimeSyncConfig.java create mode 100644 client/rest-high-level/src/main/java/org/elasticsearch/client/ml/DeleteDataFrameAnalyticsRequest.java create mode 100644 client/rest-high-level/src/main/java/org/elasticsearch/client/ml/EvaluateDataFrameRequest.java create mode 100644 client/rest-high-level/src/main/java/org/elasticsearch/client/ml/EvaluateDataFrameResponse.java create mode 100644 client/rest-high-level/src/main/java/org/elasticsearch/client/ml/GetDataFrameAnalyticsRequest.java create mode 100644 client/rest-high-level/src/main/java/org/elasticsearch/client/ml/GetDataFrameAnalyticsResponse.java create mode 100644 client/rest-high-level/src/main/java/org/elasticsearch/client/ml/GetDataFrameAnalyticsStatsRequest.java create mode 100644 client/rest-high-level/src/main/java/org/elasticsearch/client/ml/GetDataFrameAnalyticsStatsResponse.java create mode 100644 
client/rest-high-level/src/main/java/org/elasticsearch/client/ml/PutDataFrameAnalyticsRequest.java create mode 100644 client/rest-high-level/src/main/java/org/elasticsearch/client/ml/PutDataFrameAnalyticsResponse.java create mode 100644 client/rest-high-level/src/main/java/org/elasticsearch/client/ml/StartDataFrameAnalyticsRequest.java create mode 100644 client/rest-high-level/src/main/java/org/elasticsearch/client/ml/StopDataFrameAnalyticsRequest.java create mode 100644 client/rest-high-level/src/main/java/org/elasticsearch/client/ml/StopDataFrameAnalyticsResponse.java create mode 100644 client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/DataFrameAnalysis.java create mode 100644 client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/DataFrameAnalyticsConfig.java create mode 100644 client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/DataFrameAnalyticsDest.java create mode 100644 client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/DataFrameAnalyticsSource.java create mode 100644 client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/DataFrameAnalyticsState.java create mode 100644 client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/DataFrameAnalyticsStats.java create mode 100644 client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/MlDataFrameAnalysisNamedXContentProvider.java create mode 100644 client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/OutlierDetection.java create mode 100644 client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/QueryConfig.java create mode 100644 client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/evaluation/Evaluation.java create mode 100644 client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/evaluation/EvaluationMetric.java create mode 100644 
client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/evaluation/MlEvaluationNamedXContentProvider.java create mode 100644 client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/evaluation/softclassification/AbstractConfusionMatrixMetric.java create mode 100644 client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/evaluation/softclassification/AucRocMetric.java create mode 100644 client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/evaluation/softclassification/BinarySoftClassification.java create mode 100644 client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/evaluation/softclassification/ConfusionMatrixMetric.java create mode 100644 client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/evaluation/softclassification/PrecisionMetric.java create mode 100644 client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/evaluation/softclassification/RecallMetric.java create mode 100644 client/rest-high-level/src/test/java/org/elasticsearch/client/dataframe/transforms/TimeSyncConfigTests.java create mode 100644 client/rest-high-level/src/test/java/org/elasticsearch/client/dataframe/transforms/hlrc/TimeSyncConfigTests.java create mode 100644 client/rest-high-level/src/test/java/org/elasticsearch/client/ml/AucRocMetricAucRocPointTests.java create mode 100644 client/rest-high-level/src/test/java/org/elasticsearch/client/ml/AucRocMetricResultTests.java create mode 100644 client/rest-high-level/src/test/java/org/elasticsearch/client/ml/ConfusionMatrixMetricConfusionMatrixTests.java create mode 100644 client/rest-high-level/src/test/java/org/elasticsearch/client/ml/ConfusionMatrixMetricResultTests.java create mode 100644 client/rest-high-level/src/test/java/org/elasticsearch/client/ml/DeleteDataFrameAnalyticsRequestTests.java create mode 100644 
client/rest-high-level/src/test/java/org/elasticsearch/client/ml/EvaluateDataFrameResponseTests.java create mode 100644 client/rest-high-level/src/test/java/org/elasticsearch/client/ml/GetDataFrameAnalyticsRequestTests.java create mode 100644 client/rest-high-level/src/test/java/org/elasticsearch/client/ml/GetDataFrameAnalyticsStatsRequestTests.java create mode 100644 client/rest-high-level/src/test/java/org/elasticsearch/client/ml/PrecisionMetricResultTests.java create mode 100644 client/rest-high-level/src/test/java/org/elasticsearch/client/ml/PutDataFrameAnalyticsRequestTests.java create mode 100644 client/rest-high-level/src/test/java/org/elasticsearch/client/ml/RecallMetricResultTests.java create mode 100644 client/rest-high-level/src/test/java/org/elasticsearch/client/ml/StartDataFrameAnalyticsRequestTests.java create mode 100644 client/rest-high-level/src/test/java/org/elasticsearch/client/ml/StopDataFrameAnalyticsRequestTests.java create mode 100644 client/rest-high-level/src/test/java/org/elasticsearch/client/ml/StopDataFrameAnalyticsResponseTests.java create mode 100644 client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/DataFrameAnalyticsConfigTests.java create mode 100644 client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/DataFrameAnalyticsDestTests.java create mode 100644 client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/DataFrameAnalyticsSourceTests.java create mode 100644 client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/DataFrameAnalyticsStatsTests.java create mode 100644 client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/OutlierDetectionTests.java create mode 100644 client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/QueryConfigTests.java create mode 100644 docs/java-rest/high-level/ml/delete-data-frame-analytics.asciidoc create mode 100644 docs/java-rest/high-level/ml/evaluate-data-frame.asciidoc create 
mode 100644 docs/java-rest/high-level/ml/get-data-frame-analytics-stats.asciidoc create mode 100644 docs/java-rest/high-level/ml/get-data-frame-analytics.asciidoc create mode 100644 docs/java-rest/high-level/ml/put-data-frame-analytics.asciidoc create mode 100644 docs/java-rest/high-level/ml/start-data-frame-analytics.asciidoc create mode 100644 docs/java-rest/high-level/ml/stop-data-frame-analytics.asciidoc create mode 100644 x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/dataframe/DataFrameNamedXContentProvider.java create mode 100644 x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/dataframe/transforms/SyncConfig.java create mode 100644 x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/dataframe/transforms/TimeSyncConfig.java create mode 100644 x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/DeleteDataFrameAnalyticsAction.java create mode 100644 x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/EvaluateDataFrameAction.java create mode 100644 x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/GetDataFrameAnalyticsAction.java create mode 100644 x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/GetDataFrameAnalyticsStatsAction.java create mode 100644 x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/PutDataFrameAnalyticsAction.java create mode 100644 x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/StartDataFrameAnalyticsAction.java create mode 100644 x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/StopDataFrameAnalyticsAction.java create mode 100644 x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/DataFrameAnalyticsConfig.java create mode 100644 x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/DataFrameAnalyticsDest.java create mode 100644 
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/DataFrameAnalyticsSource.java create mode 100644 x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/DataFrameAnalyticsState.java create mode 100644 x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/DataFrameAnalyticsTaskState.java create mode 100644 x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/DataFrameAnalysis.java create mode 100644 x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/MlDataFrameAnalysisNamedXContentProvider.java create mode 100644 x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/OutlierDetection.java create mode 100644 x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/Evaluation.java create mode 100644 x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/EvaluationMetricResult.java create mode 100644 x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/MlEvaluationNamedXContentProvider.java create mode 100644 x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/softclassification/AbstractConfusionMatrixMetric.java create mode 100644 x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/softclassification/AucRoc.java create mode 100644 x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/softclassification/BinarySoftClassification.java create mode 100644 x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/softclassification/ConfusionMatrix.java create mode 100644 x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/softclassification/Precision.java create mode 100644 
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/softclassification/Recall.java create mode 100644 x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/softclassification/ScoreByThresholdResult.java create mode 100644 x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/softclassification/SoftClassificationMetric.java rename x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/{datafeed => utils}/QueryProvider.java (86%) create mode 100644 x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/dataframe/transforms/TimeSyncConfigTests.java create mode 100644 x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/EvaluateDataFrameActionRequestTests.java create mode 100644 x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/GetDataFrameAnalyticsActionResponseTests.java create mode 100644 x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/GetDataFrameAnalyticsRequestTests.java create mode 100644 x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/GetDataFrameAnalyticsStatsActionResponseTests.java create mode 100644 x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/GetDataFrameAnalyticsStatsRequestTests.java create mode 100644 x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/PutDataFrameAnalyticsActionRequestTests.java create mode 100644 x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/PutDataFrameAnalyticsActionResponseTests.java create mode 100644 x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/StartDataFrameAnalyticsRequestTests.java create mode 100644 x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/StopDataFrameAnalyticsActionResponseTests.java create mode 100644 
x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/StopDataFrameAnalyticsRequestTests.java create mode 100644 x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/DataFrameAnalyticsConfigTests.java create mode 100644 x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/DataFrameAnalyticsDestTests.java create mode 100644 x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/DataFrameAnalyticsSourceTests.java create mode 100644 x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/OutlierDetectionTests.java create mode 100644 x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/softclassification/AucRocTests.java create mode 100644 x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/softclassification/BinarySoftClassificationTests.java create mode 100644 x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/softclassification/ConfusionMatrixTests.java create mode 100644 x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/softclassification/PrecisionTests.java create mode 100644 x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/softclassification/RecallTests.java rename x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/{datafeed => utils}/QueryProviderTests.java (96%) create mode 100644 x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/MlNativeDataFrameAnalyticsIntegTestCase.java create mode 100644 x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/RunDataFrameAnalyticsIT.java create mode 100644 x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportDeleteDataFrameAnalyticsAction.java create mode 100644 
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportEvaluateDataFrameAction.java create mode 100644 x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportGetDataFrameAnalyticsAction.java create mode 100644 x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportGetDataFrameAnalyticsStatsAction.java create mode 100644 x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportPutDataFrameAnalyticsAction.java create mode 100644 x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportStartDataFrameAnalyticsAction.java create mode 100644 x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportStopDataFrameAnalyticsAction.java create mode 100644 x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/DataFrameAnalyticsFields.java create mode 100644 x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/DataFrameAnalyticsManager.java create mode 100644 x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/SourceDestValidator.java create mode 100644 x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/extractor/DataFrameDataExtractor.java create mode 100644 x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/extractor/DataFrameDataExtractorContext.java create mode 100644 x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/extractor/DataFrameDataExtractorFactory.java create mode 100644 x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/extractor/ExtractedFieldsDetector.java create mode 100644 x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/persistence/DataFrameAnalyticsConfigProvider.java create mode 100644 x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsBuilder.java create mode 100644 x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsControlMessageWriter.java create mode 
100644 x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsProcess.java create mode 100644 x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsProcessConfig.java create mode 100644 x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsProcessFactory.java create mode 100644 x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsProcessManager.java create mode 100644 x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsResult.java create mode 100644 x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsResultProcessor.java create mode 100644 x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/DataFrameRowsJoiner.java create mode 100644 x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/NativeAnalyticsProcess.java create mode 100644 x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/NativeAnalyticsProcessFactory.java create mode 100644 x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/results/RowResults.java create mode 100644 x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/job/JobNodeSelector.java rename x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/{job/process/autodetect/output/AutodetectResultsParser.java => process/ProcessResultsParser.java} (72%) create mode 100644 x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/rest/dataframe/RestDeleteDataFrameAnalyticsAction.java create mode 100644 x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/rest/dataframe/RestEvaluateDataFrameAction.java create mode 100644 x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/rest/dataframe/RestGetDataFrameAnalyticsAction.java create mode 100644 x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/rest/dataframe/RestGetDataFrameAnalyticsStatsAction.java 
create mode 100644 x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/rest/dataframe/RestPutDataFrameAnalyticsAction.java create mode 100644 x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/rest/dataframe/RestStartDataFrameAnalyticsAction.java create mode 100644 x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/rest/dataframe/RestStopDataFrameAnalyticsAction.java create mode 100644 x-pack/plugin/ml/src/main/plugin-metadata/plugin-security-test.policy create mode 100644 x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/SourceDestValidatorTests.java create mode 100644 x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/extractor/DataFrameDataExtractorTests.java create mode 100644 x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/extractor/ExtractedFieldsDetectorTests.java create mode 100644 x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsControlMessageWriterTests.java create mode 100644 x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsResultProcessorTests.java create mode 100644 x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsResultTests.java create mode 100644 x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/DataFrameRowsJoinerTests.java create mode 100644 x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/results/RowResultsTests.java create mode 100644 x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/job/JobNodeSelectorTests.java delete mode 100644 x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/job/process/autodetect/output/AutodetectResultsParserTests.java create mode 100644 x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/process/ProcessResultsParserTests.java create mode 100644 x-pack/plugin/src/test/resources/rest-api-spec/api/ml.delete_data_frame_analytics.json create mode 100644 
x-pack/plugin/src/test/resources/rest-api-spec/api/ml.evaluate_data_frame.json create mode 100644 x-pack/plugin/src/test/resources/rest-api-spec/api/ml.get_data_frame_analytics.json create mode 100644 x-pack/plugin/src/test/resources/rest-api-spec/api/ml.get_data_frame_analytics_stats.json create mode 100644 x-pack/plugin/src/test/resources/rest-api-spec/api/ml.put_data_frame_analytics.json create mode 100644 x-pack/plugin/src/test/resources/rest-api-spec/api/ml.start_data_frame_analytics.json create mode 100644 x-pack/plugin/src/test/resources/rest-api-spec/api/ml.stop_data_frame_analytics.json create mode 100644 x-pack/plugin/src/test/resources/rest-api-spec/test/ml/data_frame_analytics_crud.yml create mode 100644 x-pack/plugin/src/test/resources/rest-api-spec/test/ml/evaluate_data_frame.yml create mode 100644 x-pack/plugin/src/test/resources/rest-api-spec/test/ml/start_data_frame_analytics.yml create mode 100644 x-pack/plugin/src/test/resources/rest-api-spec/test/ml/stop_data_frame_analytics.yml diff --git a/client/rest-high-level/src/main/java/org/elasticsearch/client/MLRequestConverters.java b/client/rest-high-level/src/main/java/org/elasticsearch/client/MLRequestConverters.java index c11e577ef3639..e5a98b4632432 100644 --- a/client/rest-high-level/src/main/java/org/elasticsearch/client/MLRequestConverters.java +++ b/client/rest-high-level/src/main/java/org/elasticsearch/client/MLRequestConverters.java @@ -32,12 +32,14 @@ import org.elasticsearch.client.ml.DeleteCalendarEventRequest; import org.elasticsearch.client.ml.DeleteCalendarJobRequest; import org.elasticsearch.client.ml.DeleteCalendarRequest; +import org.elasticsearch.client.ml.DeleteDataFrameAnalyticsRequest; import org.elasticsearch.client.ml.DeleteDatafeedRequest; import org.elasticsearch.client.ml.DeleteExpiredDataRequest; import org.elasticsearch.client.ml.DeleteFilterRequest; import org.elasticsearch.client.ml.DeleteForecastRequest; import org.elasticsearch.client.ml.DeleteJobRequest; import 
org.elasticsearch.client.ml.DeleteModelSnapshotRequest; +import org.elasticsearch.client.ml.EvaluateDataFrameRequest; import org.elasticsearch.client.ml.FindFileStructureRequest; import org.elasticsearch.client.ml.FlushJobRequest; import org.elasticsearch.client.ml.ForecastJobRequest; @@ -45,6 +47,8 @@ import org.elasticsearch.client.ml.GetCalendarEventsRequest; import org.elasticsearch.client.ml.GetCalendarsRequest; import org.elasticsearch.client.ml.GetCategoriesRequest; +import org.elasticsearch.client.ml.GetDataFrameAnalyticsRequest; +import org.elasticsearch.client.ml.GetDataFrameAnalyticsStatsRequest; import org.elasticsearch.client.ml.GetDatafeedRequest; import org.elasticsearch.client.ml.GetDatafeedStatsRequest; import org.elasticsearch.client.ml.GetFiltersRequest; @@ -61,12 +65,15 @@ import org.elasticsearch.client.ml.PreviewDatafeedRequest; import org.elasticsearch.client.ml.PutCalendarJobRequest; import org.elasticsearch.client.ml.PutCalendarRequest; +import org.elasticsearch.client.ml.PutDataFrameAnalyticsRequest; import org.elasticsearch.client.ml.PutDatafeedRequest; import org.elasticsearch.client.ml.PutFilterRequest; import org.elasticsearch.client.ml.PutJobRequest; import org.elasticsearch.client.ml.RevertModelSnapshotRequest; import org.elasticsearch.client.ml.SetUpgradeModeRequest; +import org.elasticsearch.client.ml.StartDataFrameAnalyticsRequest; import org.elasticsearch.client.ml.StartDatafeedRequest; +import org.elasticsearch.client.ml.StopDataFrameAnalyticsRequest; import org.elasticsearch.client.ml.StopDatafeedRequest; import org.elasticsearch.client.ml.UpdateDatafeedRequest; import org.elasticsearch.client.ml.UpdateFilterRequest; @@ -581,6 +588,115 @@ static Request deleteCalendarEvent(DeleteCalendarEventRequest deleteCalendarEven return new Request(HttpDelete.METHOD_NAME, endpoint); } + static Request putDataFrameAnalytics(PutDataFrameAnalyticsRequest putRequest) throws IOException { + String endpoint = new EndpointBuilder() + 
.addPathPartAsIs("_ml", "data_frame", "analytics") + .addPathPart(putRequest.getConfig().getId()) + .build(); + Request request = new Request(HttpPut.METHOD_NAME, endpoint); + request.setEntity(createEntity(putRequest, REQUEST_BODY_CONTENT_TYPE)); + return request; + } + + static Request getDataFrameAnalytics(GetDataFrameAnalyticsRequest getRequest) { + String endpoint = new EndpointBuilder() + .addPathPartAsIs("_ml", "data_frame", "analytics") + .addPathPart(Strings.collectionToCommaDelimitedString(getRequest.getIds())) + .build(); + Request request = new Request(HttpGet.METHOD_NAME, endpoint); + RequestConverters.Params params = new RequestConverters.Params(); + if (getRequest.getPageParams() != null) { + PageParams pageParams = getRequest.getPageParams(); + if (pageParams.getFrom() != null) { + params.putParam(PageParams.FROM.getPreferredName(), pageParams.getFrom().toString()); + } + if (pageParams.getSize() != null) { + params.putParam(PageParams.SIZE.getPreferredName(), pageParams.getSize().toString()); + } + } + if (getRequest.getAllowNoMatch() != null) { + params.putParam(GetDataFrameAnalyticsRequest.ALLOW_NO_MATCH.getPreferredName(), Boolean.toString(getRequest.getAllowNoMatch())); + } + request.addParameters(params.asMap()); + return request; + } + + static Request getDataFrameAnalyticsStats(GetDataFrameAnalyticsStatsRequest getStatsRequest) { + String endpoint = new EndpointBuilder() + .addPathPartAsIs("_ml", "data_frame", "analytics") + .addPathPart(Strings.collectionToCommaDelimitedString(getStatsRequest.getIds())) + .addPathPartAsIs("_stats") + .build(); + Request request = new Request(HttpGet.METHOD_NAME, endpoint); + RequestConverters.Params params = new RequestConverters.Params(); + if (getStatsRequest.getPageParams() != null) { + PageParams pageParams = getStatsRequest.getPageParams(); + if (pageParams.getFrom() != null) { + params.putParam(PageParams.FROM.getPreferredName(), pageParams.getFrom().toString()); + } + if (pageParams.getSize() != 
null) { + params.putParam(PageParams.SIZE.getPreferredName(), pageParams.getSize().toString()); + } + } + if (getStatsRequest.getAllowNoMatch() != null) { + params.putParam(GetDataFrameAnalyticsStatsRequest.ALLOW_NO_MATCH.getPreferredName(), + Boolean.toString(getStatsRequest.getAllowNoMatch())); + } + request.addParameters(params.asMap()); + return request; + } + + static Request startDataFrameAnalytics(StartDataFrameAnalyticsRequest startRequest) { + String endpoint = new EndpointBuilder() + .addPathPartAsIs("_ml", "data_frame", "analytics") + .addPathPart(startRequest.getId()) + .addPathPartAsIs("_start") + .build(); + Request request = new Request(HttpPost.METHOD_NAME, endpoint); + RequestConverters.Params params = new RequestConverters.Params(); + if (startRequest.getTimeout() != null) { + params.withTimeout(startRequest.getTimeout()); + } + request.addParameters(params.asMap()); + return request; + } + + static Request stopDataFrameAnalytics(StopDataFrameAnalyticsRequest stopRequest) { + String endpoint = new EndpointBuilder() + .addPathPartAsIs("_ml", "data_frame", "analytics") + .addPathPart(stopRequest.getId()) + .addPathPartAsIs("_stop") + .build(); + Request request = new Request(HttpPost.METHOD_NAME, endpoint); + RequestConverters.Params params = new RequestConverters.Params(); + if (stopRequest.getTimeout() != null) { + params.withTimeout(stopRequest.getTimeout()); + } + if (stopRequest.getAllowNoMatch() != null) { + params.putParam( + StopDataFrameAnalyticsRequest.ALLOW_NO_MATCH.getPreferredName(), Boolean.toString(stopRequest.getAllowNoMatch())); + } + request.addParameters(params.asMap()); + return request; + } + + static Request deleteDataFrameAnalytics(DeleteDataFrameAnalyticsRequest deleteRequest) { + String endpoint = new EndpointBuilder() + .addPathPartAsIs("_ml", "data_frame", "analytics") + .addPathPart(deleteRequest.getId()) + .build(); + return new Request(HttpDelete.METHOD_NAME, endpoint); + } + + static Request 
evaluateDataFrame(EvaluateDataFrameRequest evaluateRequest) throws IOException { + String endpoint = new EndpointBuilder() + .addPathPartAsIs("_ml", "data_frame", "_evaluate") + .build(); + Request request = new Request(HttpPost.METHOD_NAME, endpoint); + request.setEntity(createEntity(evaluateRequest, REQUEST_BODY_CONTENT_TYPE)); + return request; + } + static Request putFilter(PutFilterRequest putFilterRequest) throws IOException { String endpoint = new EndpointBuilder() .addPathPartAsIs("_ml") diff --git a/client/rest-high-level/src/main/java/org/elasticsearch/client/MachineLearningClient.java b/client/rest-high-level/src/main/java/org/elasticsearch/client/MachineLearningClient.java index 2e359931c1025..ea72c355a02e7 100644 --- a/client/rest-high-level/src/main/java/org/elasticsearch/client/MachineLearningClient.java +++ b/client/rest-high-level/src/main/java/org/elasticsearch/client/MachineLearningClient.java @@ -25,6 +25,7 @@ import org.elasticsearch.client.ml.DeleteCalendarEventRequest; import org.elasticsearch.client.ml.DeleteCalendarJobRequest; import org.elasticsearch.client.ml.DeleteCalendarRequest; +import org.elasticsearch.client.ml.DeleteDataFrameAnalyticsRequest; import org.elasticsearch.client.ml.DeleteDatafeedRequest; import org.elasticsearch.client.ml.DeleteExpiredDataRequest; import org.elasticsearch.client.ml.DeleteExpiredDataResponse; @@ -33,6 +34,8 @@ import org.elasticsearch.client.ml.DeleteJobRequest; import org.elasticsearch.client.ml.DeleteJobResponse; import org.elasticsearch.client.ml.DeleteModelSnapshotRequest; +import org.elasticsearch.client.ml.EvaluateDataFrameRequest; +import org.elasticsearch.client.ml.EvaluateDataFrameResponse; import org.elasticsearch.client.ml.FindFileStructureRequest; import org.elasticsearch.client.ml.FindFileStructureResponse; import org.elasticsearch.client.ml.FlushJobRequest; @@ -47,6 +50,10 @@ import org.elasticsearch.client.ml.GetCalendarsResponse; import org.elasticsearch.client.ml.GetCategoriesRequest; 
import org.elasticsearch.client.ml.GetCategoriesResponse; +import org.elasticsearch.client.ml.GetDataFrameAnalyticsRequest; +import org.elasticsearch.client.ml.GetDataFrameAnalyticsResponse; +import org.elasticsearch.client.ml.GetDataFrameAnalyticsStatsRequest; +import org.elasticsearch.client.ml.GetDataFrameAnalyticsStatsResponse; import org.elasticsearch.client.ml.GetDatafeedRequest; import org.elasticsearch.client.ml.GetDatafeedResponse; import org.elasticsearch.client.ml.GetDatafeedStatsRequest; @@ -78,6 +85,8 @@ import org.elasticsearch.client.ml.PutCalendarJobRequest; import org.elasticsearch.client.ml.PutCalendarRequest; import org.elasticsearch.client.ml.PutCalendarResponse; +import org.elasticsearch.client.ml.PutDataFrameAnalyticsRequest; +import org.elasticsearch.client.ml.PutDataFrameAnalyticsResponse; import org.elasticsearch.client.ml.PutDatafeedRequest; import org.elasticsearch.client.ml.PutDatafeedResponse; import org.elasticsearch.client.ml.PutFilterRequest; @@ -87,8 +96,11 @@ import org.elasticsearch.client.ml.RevertModelSnapshotRequest; import org.elasticsearch.client.ml.RevertModelSnapshotResponse; import org.elasticsearch.client.ml.SetUpgradeModeRequest; +import org.elasticsearch.client.ml.StartDataFrameAnalyticsRequest; import org.elasticsearch.client.ml.StartDatafeedRequest; import org.elasticsearch.client.ml.StartDatafeedResponse; +import org.elasticsearch.client.ml.StopDataFrameAnalyticsRequest; +import org.elasticsearch.client.ml.StopDataFrameAnalyticsResponse; import org.elasticsearch.client.ml.StopDatafeedRequest; import org.elasticsearch.client.ml.StopDatafeedResponse; import org.elasticsearch.client.ml.UpdateDatafeedRequest; @@ -1877,4 +1889,286 @@ public void setUpgradeModeAsync(SetUpgradeModeRequest request, RequestOptions op listener, Collections.emptySet()); } + + /** + * Creates a new Data Frame Analytics config + *

+ * For additional info + * see PUT Data Frame Analytics documentation + * + * @param request The {@link PutDataFrameAnalyticsRequest} containing the + * {@link org.elasticsearch.client.ml.dataframe.DataFrameAnalyticsConfig} + * @param options Additional request options (e.g. headers), use {@link RequestOptions#DEFAULT} if nothing needs to be customized + * @return The {@link PutDataFrameAnalyticsResponse} containing the created + * {@link org.elasticsearch.client.ml.dataframe.DataFrameAnalyticsConfig} + * @throws IOException when there is a serialization issue sending the request or receiving the response + */ + public PutDataFrameAnalyticsResponse putDataFrameAnalytics(PutDataFrameAnalyticsRequest request, + RequestOptions options) throws IOException { + return restHighLevelClient.performRequestAndParseEntity(request, + MLRequestConverters::putDataFrameAnalytics, + options, + PutDataFrameAnalyticsResponse::fromXContent, + Collections.emptySet()); + } + + /** + * Creates a new Data Frame Analytics config asynchronously and notifies listener upon completion + *

+ * For additional info + * see PUT Data Frame Analytics documentation + * + * @param request The {@link PutDataFrameAnalyticsRequest} containing the + * {@link org.elasticsearch.client.ml.dataframe.DataFrameAnalyticsConfig} + * @param options Additional request options (e.g. headers), use {@link RequestOptions#DEFAULT} if nothing needs to be customized + * @param listener Listener to be notified upon request completion + */ + public void putDataFrameAnalyticsAsync(PutDataFrameAnalyticsRequest request, RequestOptions options, + ActionListener listener) { + restHighLevelClient.performRequestAsyncAndParseEntity(request, + MLRequestConverters::putDataFrameAnalytics, + options, + PutDataFrameAnalyticsResponse::fromXContent, + listener, + Collections.emptySet()); + } + + /** + * Gets a single or multiple Data Frame Analytics configs + *

+ * For additional info + * see GET Data Frame Analytics documentation + * + * @param request The {@link GetDataFrameAnalyticsRequest} + * @param options Additional request options (e.g. headers), use {@link RequestOptions#DEFAULT} if nothing needs to be customized + * @return {@link GetDataFrameAnalyticsResponse} response object containing the + * {@link org.elasticsearch.client.ml.dataframe.DataFrameAnalyticsConfig} objects + */ + public GetDataFrameAnalyticsResponse getDataFrameAnalytics(GetDataFrameAnalyticsRequest request, + RequestOptions options) throws IOException { + return restHighLevelClient.performRequestAndParseEntity(request, + MLRequestConverters::getDataFrameAnalytics, + options, + GetDataFrameAnalyticsResponse::fromXContent, + Collections.emptySet()); + } + + /** + * Gets a single or multiple Data Frame Analytics configs asynchronously and notifies listener upon completion + *

+ * For additional info + * see GET Data Frame Analytics documentation + * + * @param request The {@link GetDataFrameAnalyticsRequest} + * @param options Additional request options (e.g. headers), use {@link RequestOptions#DEFAULT} if nothing needs to be customized + * @param listener Listener to be notified upon request completion + */ + public void getDataFrameAnalyticsAsync(GetDataFrameAnalyticsRequest request, RequestOptions options, + ActionListener listener) { + restHighLevelClient.performRequestAsyncAndParseEntity(request, + MLRequestConverters::getDataFrameAnalytics, + options, + GetDataFrameAnalyticsResponse::fromXContent, + listener, + Collections.emptySet()); + } + + /** + * Gets the running statistics of a Data Frame Analytics + *

+ * For additional info + * see GET Data Frame Analytics Stats documentation + * + * @param request The {@link GetDataFrameAnalyticsStatsRequest} + * @param options Additional request options (e.g. headers), use {@link RequestOptions#DEFAULT} if nothing needs to be customized + * @return {@link GetDataFrameAnalyticsStatsResponse} response object + */ + public GetDataFrameAnalyticsStatsResponse getDataFrameAnalyticsStats(GetDataFrameAnalyticsStatsRequest request, + RequestOptions options) throws IOException { + return restHighLevelClient.performRequestAndParseEntity(request, + MLRequestConverters::getDataFrameAnalyticsStats, + options, + GetDataFrameAnalyticsStatsResponse::fromXContent, + Collections.emptySet()); + } + + /** + * Gets the running statistics of a Data Frame Analytics asynchronously and notifies listener upon completion + *

+ * For additional info + * see GET Data Frame Analytics Stats documentation + * + * @param request The {@link GetDataFrameAnalyticsStatsRequest} + * @param options Additional request options (e.g. headers), use {@link RequestOptions#DEFAULT} if nothing needs to be customized + * @param listener Listener to be notified upon request completion + */ + public void getDataFrameAnalyticsStatsAsync(GetDataFrameAnalyticsStatsRequest request, RequestOptions options, + ActionListener listener) { + restHighLevelClient.performRequestAsyncAndParseEntity(request, + MLRequestConverters::getDataFrameAnalyticsStats, + options, + GetDataFrameAnalyticsStatsResponse::fromXContent, + listener, + Collections.emptySet()); + } + + /** + * Starts Data Frame Analytics + *

+ * For additional info + * see Start Data Frame Analytics documentation + * + * @param request The {@link StartDataFrameAnalyticsRequest} + * @param options Additional request options (e.g. headers), use {@link RequestOptions#DEFAULT} if nothing needs to be customized + * @return action acknowledgement + * @throws IOException when there is a serialization issue sending the request or receiving the response + */ + public AcknowledgedResponse startDataFrameAnalytics(StartDataFrameAnalyticsRequest request, + RequestOptions options) throws IOException { + return restHighLevelClient.performRequestAndParseEntity(request, + MLRequestConverters::startDataFrameAnalytics, + options, + AcknowledgedResponse::fromXContent, + Collections.emptySet()); + } + + /** + * Starts Data Frame Analytics asynchronously and notifies listener upon completion + *

+ * For additional info + * see Start Data Frame Analytics documentation + * + * @param request The {@link StartDataFrameAnalyticsRequest} + * @param options Additional request options (e.g. headers), use {@link RequestOptions#DEFAULT} if nothing needs to be customized + * @param listener Listener to be notified upon request completion + */ + public void startDataFrameAnalyticsAsync(StartDataFrameAnalyticsRequest request, RequestOptions options, + ActionListener listener) { + restHighLevelClient.performRequestAsyncAndParseEntity(request, + MLRequestConverters::startDataFrameAnalytics, + options, + AcknowledgedResponse::fromXContent, + listener, + Collections.emptySet()); + } + + /** + * Stops Data Frame Analytics + *

+ * For additional info + * see Stop Data Frame Analytics documentation + * + * @param request The {@link StopDataFrameAnalyticsRequest} + * @param options Additional request options (e.g. headers), use {@link RequestOptions#DEFAULT} if nothing needs to be customized + * @return {@link StopDataFrameAnalyticsResponse} + * @throws IOException when there is a serialization issue sending the request or receiving the response + */ + public StopDataFrameAnalyticsResponse stopDataFrameAnalytics(StopDataFrameAnalyticsRequest request, + RequestOptions options) throws IOException { + return restHighLevelClient.performRequestAndParseEntity(request, + MLRequestConverters::stopDataFrameAnalytics, + options, + StopDataFrameAnalyticsResponse::fromXContent, + Collections.emptySet()); + } + + /** + * Stops Data Frame Analytics asynchronously and notifies listener upon completion + *

+ * For additional info + * see Stop Data Frame Analytics documentation + * + * @param request The {@link StopDataFrameAnalyticsRequest} + * @param options Additional request options (e.g. headers), use {@link RequestOptions#DEFAULT} if nothing needs to be customized + * @param listener Listener to be notified upon request completion + */ + public void stopDataFrameAnalyticsAsync(StopDataFrameAnalyticsRequest request, RequestOptions options, + ActionListener listener) { + restHighLevelClient.performRequestAsyncAndParseEntity(request, + MLRequestConverters::stopDataFrameAnalytics, + options, + StopDataFrameAnalyticsResponse::fromXContent, + listener, + Collections.emptySet()); + } + + /** + * Deletes the given Data Frame Analytics config + *

+ * For additional info + * see DELETE Data Frame Analytics documentation + * + * @param request The {@link DeleteDataFrameAnalyticsRequest} + * @param options Additional request options (e.g. headers), use {@link RequestOptions#DEFAULT} if nothing needs to be customized + * @return action acknowledgement + * @throws IOException when there is a serialization issue sending the request or receiving the response + */ + public AcknowledgedResponse deleteDataFrameAnalytics(DeleteDataFrameAnalyticsRequest request, + RequestOptions options) throws IOException { + return restHighLevelClient.performRequestAndParseEntity(request, + MLRequestConverters::deleteDataFrameAnalytics, + options, + AcknowledgedResponse::fromXContent, + Collections.emptySet()); + } + + /** + * Deletes the given Data Frame Analytics config asynchronously and notifies listener upon completion + *

+ * For additional info + * see DELETE Data Frame Analytics documentation + * + * @param request The {@link DeleteDataFrameAnalyticsRequest} + * @param options Additional request options (e.g. headers), use {@link RequestOptions#DEFAULT} if nothing needs to be customized + * @param listener Listener to be notified upon request completion + */ + public void deleteDataFrameAnalyticsAsync(DeleteDataFrameAnalyticsRequest request, RequestOptions options, + ActionListener listener) { + restHighLevelClient.performRequestAsyncAndParseEntity(request, + MLRequestConverters::deleteDataFrameAnalytics, + options, + AcknowledgedResponse::fromXContent, + listener, + Collections.emptySet()); + } + + /** + * Evaluates the given Data Frame + *

+ * For additional info + * see Evaluate Data Frame documentation + * + * @param request The {@link EvaluateDataFrameRequest} + * @param options Additional request options (e.g. headers), use {@link RequestOptions#DEFAULT} if nothing needs to be customized + * @return {@link EvaluateDataFrameResponse} response object + * @throws IOException when there is a serialization issue sending the request or receiving the response + */ + public EvaluateDataFrameResponse evaluateDataFrame(EvaluateDataFrameRequest request, + RequestOptions options) throws IOException { + return restHighLevelClient.performRequestAndParseEntity(request, + MLRequestConverters::evaluateDataFrame, + options, + EvaluateDataFrameResponse::fromXContent, + Collections.emptySet()); + } + + /** + * Evaluates the given Data Frame asynchronously and notifies listener upon completion + *

+ * For additional info + * see Evaluate Data Frame documentation + * + * @param request The {@link EvaluateDataFrameRequest} + * @param options Additional request options (e.g. headers), use {@link RequestOptions#DEFAULT} if nothing needs to be customized + * @param listener Listener to be notified upon request completion + */ + public void evaluateDataFrameAsync(EvaluateDataFrameRequest request, RequestOptions options, + ActionListener listener) { + restHighLevelClient.performRequestAsyncAndParseEntity(request, + MLRequestConverters::evaluateDataFrame, + options, + EvaluateDataFrameResponse::fromXContent, + listener, + Collections.emptySet()); + } } diff --git a/client/rest-high-level/src/main/java/org/elasticsearch/client/dataframe/DataFrameNamedXContentProvider.java b/client/rest-high-level/src/main/java/org/elasticsearch/client/dataframe/DataFrameNamedXContentProvider.java new file mode 100644 index 0000000000000..940b136c93daa --- /dev/null +++ b/client/rest-high-level/src/main/java/org/elasticsearch/client/dataframe/DataFrameNamedXContentProvider.java @@ -0,0 +1,41 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.elasticsearch.client.dataframe; + +import org.elasticsearch.client.dataframe.transforms.SyncConfig; +import org.elasticsearch.client.dataframe.transforms.TimeSyncConfig; +import org.elasticsearch.common.ParseField; +import org.elasticsearch.common.xcontent.NamedXContentRegistry; +import org.elasticsearch.plugins.spi.NamedXContentProvider; + +import java.util.Arrays; +import java.util.List; + +public class DataFrameNamedXContentProvider implements NamedXContentProvider { + + @Override + public List getNamedXContentParsers() { + return Arrays.asList( + new NamedXContentRegistry.Entry(SyncConfig.class, + new ParseField(TimeSyncConfig.NAME), + TimeSyncConfig::fromXContent)); + } + +} diff --git a/client/rest-high-level/src/main/java/org/elasticsearch/client/dataframe/transforms/DataFrameTransformConfig.java b/client/rest-high-level/src/main/java/org/elasticsearch/client/dataframe/transforms/DataFrameTransformConfig.java index 34bcb595c206e..355e3ad9bbc0f 100644 --- a/client/rest-high-level/src/main/java/org/elasticsearch/client/dataframe/transforms/DataFrameTransformConfig.java +++ b/client/rest-high-level/src/main/java/org/elasticsearch/client/dataframe/transforms/DataFrameTransformConfig.java @@ -30,6 +30,7 @@ import org.elasticsearch.common.xcontent.ToXContentObject; import org.elasticsearch.common.xcontent.XContentBuilder; import org.elasticsearch.common.xcontent.XContentParser; +import org.elasticsearch.common.xcontent.XContentParserUtils; import java.io.IOException; import java.time.Instant; @@ -44,6 +45,7 @@ public class DataFrameTransformConfig implements ToXContentObject { public static final ParseField SOURCE = new ParseField("source"); public static final ParseField DEST = new ParseField("dest"); public static final ParseField DESCRIPTION = new ParseField("description"); + public static final ParseField SYNC = new ParseField("sync"); public static final ParseField VERSION = new ParseField("version"); public static final ParseField 
CREATE_TIME = new ParseField("create_time"); // types of transforms @@ -52,6 +54,7 @@ public class DataFrameTransformConfig implements ToXContentObject { private final String id; private final SourceConfig source; private final DestConfig dest; + private final SyncConfig syncConfig; private final PivotConfig pivotConfig; private final String description; private final Version transformVersion; @@ -63,17 +66,26 @@ public class DataFrameTransformConfig implements ToXContentObject { String id = (String) args[0]; SourceConfig source = (SourceConfig) args[1]; DestConfig dest = (DestConfig) args[2]; - PivotConfig pivotConfig = (PivotConfig) args[3]; - String description = (String)args[4]; - Instant createTime = (Instant)args[5]; - String transformVersion = (String)args[6]; - return new DataFrameTransformConfig(id, source, dest, pivotConfig, description, createTime, transformVersion); + SyncConfig syncConfig = (SyncConfig) args[3]; + PivotConfig pivotConfig = (PivotConfig) args[4]; + String description = (String)args[5]; + Instant createTime = (Instant)args[6]; + String transformVersion = (String)args[7]; + return new DataFrameTransformConfig(id, + source, + dest, + syncConfig, + pivotConfig, + description, + createTime, + transformVersion); }); static { PARSER.declareString(constructorArg(), ID); PARSER.declareObject(constructorArg(), (p, c) -> SourceConfig.PARSER.apply(p, null), SOURCE); PARSER.declareObject(constructorArg(), (p, c) -> DestConfig.PARSER.apply(p, null), DEST); + PARSER.declareObject(optionalConstructorArg(), (p, c) -> parseSyncConfig(p), SYNC); PARSER.declareObject(optionalConstructorArg(), (p, c) -> PivotConfig.fromXContent(p), PIVOT_TRANSFORM); PARSER.declareString(optionalConstructorArg(), DESCRIPTION); PARSER.declareField(optionalConstructorArg(), @@ -81,6 +93,15 @@ public class DataFrameTransformConfig implements ToXContentObject { PARSER.declareString(optionalConstructorArg(), VERSION); } + private static SyncConfig parseSyncConfig(XContentParser 
parser) throws IOException { + XContentParserUtils.ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser::getTokenLocation); + XContentParserUtils.ensureExpectedToken(XContentParser.Token.FIELD_NAME, parser.nextToken(), parser::getTokenLocation); + SyncConfig syncConfig = parser.namedObject(SyncConfig.class, parser.currentName(), true); + XContentParserUtils.ensureExpectedToken(XContentParser.Token.END_OBJECT, parser.nextToken(), parser::getTokenLocation); + return syncConfig; + } + + public static DataFrameTransformConfig fromXContent(final XContentParser parser) { return PARSER.apply(parser, null); } @@ -97,12 +118,13 @@ public static DataFrameTransformConfig fromXContent(final XContentParser parser) * @return A DataFrameTransformConfig to preview, NOTE it will have a {@code null} id, destination and index. */ public static DataFrameTransformConfig forPreview(final SourceConfig source, final PivotConfig pivotConfig) { - return new DataFrameTransformConfig(null, source, null, pivotConfig, null, null, null); + return new DataFrameTransformConfig(null, source, null, null, pivotConfig, null, null, null); } DataFrameTransformConfig(final String id, final SourceConfig source, final DestConfig dest, + final SyncConfig syncConfig, final PivotConfig pivotConfig, final String description, final Instant createTime, @@ -110,6 +132,7 @@ public static DataFrameTransformConfig forPreview(final SourceConfig source, fin this.id = id; this.source = source; this.dest = dest; + this.syncConfig = syncConfig; this.pivotConfig = pivotConfig; this.description = description; this.createTime = createTime == null ? 
null : Instant.ofEpochMilli(createTime.toEpochMilli()); @@ -128,6 +151,10 @@ public DestConfig getDestination() { return dest; } + public SyncConfig getSyncConfig() { + return syncConfig; + } + public PivotConfig getPivotConfig() { return pivotConfig; } @@ -157,6 +184,11 @@ public XContentBuilder toXContent(final XContentBuilder builder, final Params pa if (dest != null) { builder.field(DEST.getPreferredName(), dest); } + if (syncConfig != null) { + builder.startObject(SYNC.getPreferredName()); + builder.field(syncConfig.getName(), syncConfig); + builder.endObject(); + } if (pivotConfig != null) { builder.field(PIVOT_TRANSFORM.getPreferredName(), pivotConfig); } @@ -189,6 +221,7 @@ public boolean equals(Object other) { && Objects.equals(this.source, that.source) && Objects.equals(this.dest, that.dest) && Objects.equals(this.description, that.description) + && Objects.equals(this.syncConfig, that.syncConfig) && Objects.equals(this.transformVersion, that.transformVersion) && Objects.equals(this.createTime, that.createTime) && Objects.equals(this.pivotConfig, that.pivotConfig); @@ -196,7 +229,7 @@ public boolean equals(Object other) { @Override public int hashCode() { - return Objects.hash(id, source, dest, pivotConfig, description, createTime, transformVersion); + return Objects.hash(id, source, dest, syncConfig, pivotConfig, description); } @Override @@ -213,6 +246,7 @@ public static class Builder { private String id; private SourceConfig source; private DestConfig dest; + private SyncConfig syncConfig; private PivotConfig pivotConfig; private String description; @@ -231,6 +265,11 @@ public Builder setDest(DestConfig dest) { return this; } + public Builder setSyncConfig(SyncConfig syncConfig) { + this.syncConfig = syncConfig; + return this; + } + public Builder setPivotConfig(PivotConfig pivotConfig) { this.pivotConfig = pivotConfig; return this; @@ -242,7 +281,7 @@ public Builder setDescription(String description) { } public DataFrameTransformConfig build() { - 
return new DataFrameTransformConfig(id, source, dest, pivotConfig, description, null, null); + return new DataFrameTransformConfig(id, source, dest, syncConfig, pivotConfig, description, null, null); } } } diff --git a/client/rest-high-level/src/main/java/org/elasticsearch/client/dataframe/transforms/SyncConfig.java b/client/rest-high-level/src/main/java/org/elasticsearch/client/dataframe/transforms/SyncConfig.java new file mode 100644 index 0000000000000..3ead35d0a491a --- /dev/null +++ b/client/rest-high-level/src/main/java/org/elasticsearch/client/dataframe/transforms/SyncConfig.java @@ -0,0 +1,30 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.elasticsearch.client.dataframe.transforms; + +import org.elasticsearch.common.xcontent.ToXContentObject; + +public interface SyncConfig extends ToXContentObject { + + /** + * Returns the name of the writeable object + */ + String getName(); +} diff --git a/client/rest-high-level/src/main/java/org/elasticsearch/client/dataframe/transforms/TimeSyncConfig.java b/client/rest-high-level/src/main/java/org/elasticsearch/client/dataframe/transforms/TimeSyncConfig.java new file mode 100644 index 0000000000000..797ca3f896138 --- /dev/null +++ b/client/rest-high-level/src/main/java/org/elasticsearch/client/dataframe/transforms/TimeSyncConfig.java @@ -0,0 +1,108 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.elasticsearch.client.dataframe.transforms; + +import org.elasticsearch.common.ParseField; +import org.elasticsearch.common.unit.TimeValue; +import org.elasticsearch.common.xcontent.ConstructingObjectParser; +import org.elasticsearch.common.xcontent.ObjectParser; +import org.elasticsearch.common.xcontent.XContentBuilder; +import org.elasticsearch.common.xcontent.XContentParser; + +import java.io.IOException; +import java.util.Objects; + +import static org.elasticsearch.common.xcontent.ConstructingObjectParser.constructorArg; +import static org.elasticsearch.common.xcontent.ConstructingObjectParser.optionalConstructorArg; + +public class TimeSyncConfig implements SyncConfig { + + public static final String NAME = "time"; + + private static final ParseField FIELD = new ParseField("field"); + private static final ParseField DELAY = new ParseField("delay"); + + private final String field; + private final TimeValue delay; + + private static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>("time_sync_config", true, + args -> new TimeSyncConfig((String) args[0], args[1] != null ? 
(TimeValue) args[1] : TimeValue.ZERO)); + + static { + PARSER.declareString(constructorArg(), FIELD); + PARSER.declareField(optionalConstructorArg(), (p, c) -> TimeValue.parseTimeValue(p.textOrNull(), DELAY.getPreferredName()), DELAY, + ObjectParser.ValueType.STRING_OR_NULL); + } + + public static TimeSyncConfig fromXContent(XContentParser parser) { + return PARSER.apply(parser, null); + } + + public TimeSyncConfig(String field, TimeValue delay) { + this.field = field; + this.delay = delay; + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + builder.field(FIELD.getPreferredName(), field); + if (delay.duration() > 0) { + builder.field(DELAY.getPreferredName(), delay.getStringRep()); + } + builder.endObject(); + return builder; + } + + public String getField() { + return field; + } + + public TimeValue getDelay() { + return delay; + } + + @Override + public boolean equals(Object other) { + if (this == other) { + return true; + } + + if (other == null || getClass() != other.getClass()) { + return false; + } + + final TimeSyncConfig that = (TimeSyncConfig) other; + + return Objects.equals(this.field, that.field) + && Objects.equals(this.delay, that.delay); + } + + @Override + public int hashCode() { + return Objects.hash(field, delay); + } + + @Override + public String getName() { + return NAME; + } + +} diff --git a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/DeleteDataFrameAnalyticsRequest.java b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/DeleteDataFrameAnalyticsRequest.java new file mode 100644 index 0000000000000..f03466632304d --- /dev/null +++ b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/DeleteDataFrameAnalyticsRequest.java @@ -0,0 +1,64 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.elasticsearch.client.ml; + +import org.elasticsearch.client.Validatable; +import org.elasticsearch.client.ValidationException; + +import java.util.Objects; +import java.util.Optional; + +/** + * Request to delete a data frame analytics config + */ +public class DeleteDataFrameAnalyticsRequest implements Validatable { + + private final String id; + + public DeleteDataFrameAnalyticsRequest(String id) { + this.id = id; + } + + public String getId() { + return id; + } + + @Override + public Optional validate() { + if (id == null) { + return Optional.of(ValidationException.withError("data frame analytics id must not be null")); + } + return Optional.empty(); + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + + DeleteDataFrameAnalyticsRequest other = (DeleteDataFrameAnalyticsRequest) o; + return Objects.equals(id, other.id); + } + + @Override + public int hashCode() { + return Objects.hash(id); + } +} diff --git a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/EvaluateDataFrameRequest.java b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/EvaluateDataFrameRequest.java new file mode 100644 index 0000000000000..2e3bbb170509c --- 
/dev/null +++ b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/EvaluateDataFrameRequest.java @@ -0,0 +1,136 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.elasticsearch.client.ml; + +import org.elasticsearch.client.Validatable; +import org.elasticsearch.client.ValidationException; +import org.elasticsearch.client.ml.dataframe.evaluation.Evaluation; +import org.elasticsearch.common.ParseField; +import org.elasticsearch.common.xcontent.ConstructingObjectParser; +import org.elasticsearch.common.xcontent.ToXContentObject; +import org.elasticsearch.common.xcontent.XContentBuilder; +import org.elasticsearch.common.xcontent.XContentParser; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.List; +import java.util.Objects; +import java.util.Optional; + +import static org.elasticsearch.common.xcontent.ConstructingObjectParser.constructorArg; +import static org.elasticsearch.common.xcontent.XContentParserUtils.ensureExpectedToken; + +public class EvaluateDataFrameRequest implements ToXContentObject, Validatable { + + private static final ParseField INDEX = new ParseField("index"); + private static final ParseField 
EVALUATION = new ParseField("evaluation"); + + @SuppressWarnings("unchecked") + private static final ConstructingObjectParser PARSER = + new ConstructingObjectParser<>( + "evaluate_data_frame_request", true, args -> new EvaluateDataFrameRequest((List) args[0], (Evaluation) args[1])); + + static { + PARSER.declareStringArray(constructorArg(), INDEX); + PARSER.declareObject(constructorArg(), (p, c) -> parseEvaluation(p), EVALUATION); + } + + private static Evaluation parseEvaluation(XContentParser parser) throws IOException { + ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser::getTokenLocation); + ensureExpectedToken(XContentParser.Token.FIELD_NAME, parser.nextToken(), parser::getTokenLocation); + Evaluation evaluation = parser.namedObject(Evaluation.class, parser.currentName(), null); + ensureExpectedToken(XContentParser.Token.END_OBJECT, parser.nextToken(), parser::getTokenLocation); + return evaluation; + } + + public static EvaluateDataFrameRequest fromXContent(XContentParser parser) { + return PARSER.apply(parser, null); + } + + private List indices; + private Evaluation evaluation; + + public EvaluateDataFrameRequest(String index, Evaluation evaluation) { + this(Arrays.asList(index), evaluation); + } + + public EvaluateDataFrameRequest(List indices, Evaluation evaluation) { + setIndices(indices); + setEvaluation(evaluation); + } + + public List getIndices() { + return Collections.unmodifiableList(indices); + } + + public final void setIndices(List indices) { + Objects.requireNonNull(indices); + this.indices = new ArrayList<>(indices); + } + + public Evaluation getEvaluation() { + return evaluation; + } + + public final void setEvaluation(Evaluation evaluation) { + this.evaluation = evaluation; + } + + @Override + public Optional validate() { + List errors = new ArrayList<>(); + if (indices.isEmpty()) { + errors.add("At least one index must be specified"); + } + if (evaluation == null) { + errors.add("evaluation must not be 
null"); + } + return errors.isEmpty() + ? Optional.empty() + : Optional.of(ValidationException.withErrors(errors)); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + return builder + .startObject() + .array(INDEX.getPreferredName(), indices.toArray()) + .startObject(EVALUATION.getPreferredName()) + .field(evaluation.getName(), evaluation) + .endObject() + .endObject(); + } + + @Override + public int hashCode() { + return Objects.hash(indices, evaluation); + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + EvaluateDataFrameRequest that = (EvaluateDataFrameRequest) o; + return Objects.equals(indices, that.indices) + && Objects.equals(evaluation, that.evaluation); + } +} diff --git a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/EvaluateDataFrameResponse.java b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/EvaluateDataFrameResponse.java new file mode 100644 index 0000000000000..d70bd713bd60a --- /dev/null +++ b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/EvaluateDataFrameResponse.java @@ -0,0 +1,119 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. 
See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.elasticsearch.client.ml; + +import org.elasticsearch.client.ml.dataframe.evaluation.EvaluationMetric; +import org.elasticsearch.common.Strings; +import org.elasticsearch.common.xcontent.NamedObjectNotFoundException; +import org.elasticsearch.common.xcontent.ToXContent; +import org.elasticsearch.common.xcontent.ToXContentObject; +import org.elasticsearch.common.xcontent.XContentBuilder; +import org.elasticsearch.common.xcontent.XContentParser; + +import java.io.IOException; +import java.util.Collections; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.stream.Collectors; + +import static org.elasticsearch.common.xcontent.XContentParserUtils.ensureExpectedToken; + +public class EvaluateDataFrameResponse implements ToXContentObject { + + public static EvaluateDataFrameResponse fromXContent(XContentParser parser) throws IOException { + if (parser.currentToken() == null) { + parser.nextToken(); + } + ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser::getTokenLocation); + ensureExpectedToken(XContentParser.Token.FIELD_NAME, parser.nextToken(), parser::getTokenLocation); + String evaluationName = parser.currentName(); + parser.nextToken(); + Map metrics = parser.map(LinkedHashMap::new, EvaluateDataFrameResponse::parseMetric); + List knownMetrics = + metrics.values().stream() + .filter(Objects::nonNull) // Filter out null values returned by {@link EvaluateDataFrameResponse::parseMetric}. 
+ .collect(Collectors.toList()); + ensureExpectedToken(XContentParser.Token.END_OBJECT, parser.nextToken(), parser::getTokenLocation); + return new EvaluateDataFrameResponse(evaluationName, knownMetrics); + } + + private static EvaluationMetric.Result parseMetric(XContentParser parser) throws IOException { + String metricName = parser.currentName(); + try { + return parser.namedObject(EvaluationMetric.Result.class, metricName, null); + } catch (NamedObjectNotFoundException e) { + parser.skipChildren(); + // Metric name not recognized. Return {@code null} value here and filter it out later. + return null; + } + } + + private final String evaluationName; + private final Map metrics; + + public EvaluateDataFrameResponse(String evaluationName, List metrics) { + this.evaluationName = Objects.requireNonNull(evaluationName); + this.metrics = Collections.unmodifiableMap(Objects.requireNonNull(metrics) + .stream().collect(Collectors.toMap(m -> m.getMetricName(), m -> m))); + } + + public String getEvaluationName() { + return evaluationName; + } + + public List getMetrics() { + return metrics.values().stream().collect(Collectors.toList()); + } + + @SuppressWarnings("unchecked") + public T getMetricByName(String metricName) { + Objects.requireNonNull(metricName); + return (T) metrics.get(metricName); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params params) throws IOException { + return builder + .startObject() + .field(evaluationName, metrics) + .endObject(); + } + + @Override + public boolean equals(Object o) { + if (o == this) return true; + if (o == null || getClass() != o.getClass()) return false; + EvaluateDataFrameResponse that = (EvaluateDataFrameResponse) o; + return Objects.equals(evaluationName, that.evaluationName) + && Objects.equals(metrics, that.metrics); + } + + @Override + public int hashCode() { + return Objects.hash(evaluationName, metrics); + } + + @Override + public final String toString() { + return 
Strings.toString(this); + } +} diff --git a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/GetDataFrameAnalyticsRequest.java b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/GetDataFrameAnalyticsRequest.java new file mode 100644 index 0000000000000..40698c4b528fa --- /dev/null +++ b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/GetDataFrameAnalyticsRequest.java @@ -0,0 +1,104 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.elasticsearch.client.ml; + +import org.elasticsearch.client.Validatable; +import org.elasticsearch.client.ValidationException; +import org.elasticsearch.client.core.PageParams; +import org.elasticsearch.common.Nullable; +import org.elasticsearch.common.ParseField; + +import java.util.Arrays; +import java.util.List; +import java.util.Objects; +import java.util.Optional; + +public class GetDataFrameAnalyticsRequest implements Validatable { + + public static final ParseField ALLOW_NO_MATCH = new ParseField("allow_no_match"); + + private final List ids; + private Boolean allowNoMatch; + private PageParams pageParams; + + /** + * Helper method to create a request that will get ALL Data Frame Analytics + * @return new {@link GetDataFrameAnalyticsRequest} object for the id "_all" + */ + public static GetDataFrameAnalyticsRequest getAllDataFrameAnalyticsRequest() { + return new GetDataFrameAnalyticsRequest("_all"); + } + + public GetDataFrameAnalyticsRequest(String... ids) { + this.ids = Arrays.asList(ids); + } + + public List getIds() { + return ids; + } + + public Boolean getAllowNoMatch() { + return allowNoMatch; + } + + /** + * Whether to ignore if a wildcard expression matches no data frame analytics. 
+ * + * @param allowNoMatch If this is {@code false}, then an error is returned when a wildcard (or {@code _all}) + * does not match any data frame analytics + */ + public GetDataFrameAnalyticsRequest setAllowNoMatch(boolean allowNoMatch) { + this.allowNoMatch = allowNoMatch; + return this; + } + + public PageParams getPageParams() { + return pageParams; + } + + public GetDataFrameAnalyticsRequest setPageParams(@Nullable PageParams pageParams) { + this.pageParams = pageParams; + return this; + } + + @Override + public Optional validate() { + if (ids == null || ids.isEmpty()) { + return Optional.of(ValidationException.withError("data frame analytics id must not be null")); + } + return Optional.empty(); + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + + GetDataFrameAnalyticsRequest other = (GetDataFrameAnalyticsRequest) o; + return Objects.equals(ids, other.ids) + && Objects.equals(allowNoMatch, other.allowNoMatch) + && Objects.equals(pageParams, other.pageParams); + } + + @Override + public int hashCode() { + return Objects.hash(ids, allowNoMatch, pageParams); + } +} diff --git a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/GetDataFrameAnalyticsResponse.java b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/GetDataFrameAnalyticsResponse.java new file mode 100644 index 0000000000000..76996e9d4d0b6 --- /dev/null +++ b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/GetDataFrameAnalyticsResponse.java @@ -0,0 +1,74 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.elasticsearch.client.ml; + +import org.elasticsearch.client.ml.dataframe.DataFrameAnalyticsConfig; +import org.elasticsearch.common.ParseField; +import org.elasticsearch.common.xcontent.ConstructingObjectParser; +import org.elasticsearch.common.xcontent.XContentParser; + +import java.util.List; +import java.util.Objects; + +import static org.elasticsearch.common.xcontent.ConstructingObjectParser.constructorArg; + +public class GetDataFrameAnalyticsResponse { + + public static final ParseField DATA_FRAME_ANALYTICS = new ParseField("data_frame_analytics"); + + @SuppressWarnings("unchecked") + static final ConstructingObjectParser PARSER = + new ConstructingObjectParser<>( + "get_data_frame_analytics", + true, + args -> new GetDataFrameAnalyticsResponse((List) args[0])); + + static { + PARSER.declareObjectArray(constructorArg(), (p, c) -> DataFrameAnalyticsConfig.fromXContent(p), DATA_FRAME_ANALYTICS); + } + + public static GetDataFrameAnalyticsResponse fromXContent(final XContentParser parser) { + return PARSER.apply(parser, null); + } + + private List analytics; + + public GetDataFrameAnalyticsResponse(List analytics) { + this.analytics = analytics; + } + + public List getAnalytics() { + return analytics; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + + GetDataFrameAnalyticsResponse other = (GetDataFrameAnalyticsResponse) o; + return Objects.equals(this.analytics, other.analytics); + } + + @Override + public int hashCode() { + 
return Objects.hash(analytics); + } +} diff --git a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/GetDataFrameAnalyticsStatsRequest.java b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/GetDataFrameAnalyticsStatsRequest.java new file mode 100644 index 0000000000000..f1e4a35fb661b --- /dev/null +++ b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/GetDataFrameAnalyticsStatsRequest.java @@ -0,0 +1,99 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.elasticsearch.client.ml; + +import org.elasticsearch.client.Validatable; +import org.elasticsearch.client.ValidationException; +import org.elasticsearch.client.core.PageParams; +import org.elasticsearch.common.Nullable; +import org.elasticsearch.common.ParseField; + +import java.util.Arrays; +import java.util.List; +import java.util.Objects; +import java.util.Optional; + +/** + * Request to get data frame analytics stats + */ +public class GetDataFrameAnalyticsStatsRequest implements Validatable { + + public static final ParseField ALLOW_NO_MATCH = new ParseField("allow_no_match"); + + private final List ids; + private Boolean allowNoMatch; + private PageParams pageParams; + + public GetDataFrameAnalyticsStatsRequest(String... ids) { + this.ids = Arrays.asList(ids); + } + + public List getIds() { + return ids; + } + + public Boolean getAllowNoMatch() { + return allowNoMatch; + } + + /** + * Whether to ignore if a wildcard expression matches no data frame analytics. + * + * @param allowNoMatch If this is {@code false}, then an error is returned when a wildcard (or {@code _all}) + * does not match any data frame analytics + */ + public GetDataFrameAnalyticsStatsRequest setAllowNoMatch(boolean allowNoMatch) { + this.allowNoMatch = allowNoMatch; + return this; + } + + public PageParams getPageParams() { + return pageParams; + } + + public GetDataFrameAnalyticsStatsRequest setPageParams(@Nullable PageParams pageParams) { + this.pageParams = pageParams; + return this; + } + + @Override + public Optional validate() { + if (ids == null || ids.isEmpty()) { + return Optional.of(ValidationException.withError("data frame analytics id must not be null")); + } + return Optional.empty(); + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + + GetDataFrameAnalyticsStatsRequest other = (GetDataFrameAnalyticsStatsRequest) o; + return Objects.equals(ids, other.ids) + && 
Objects.equals(allowNoMatch, other.allowNoMatch) + && Objects.equals(pageParams, other.pageParams); + } + + @Override + public int hashCode() { + return Objects.hash(ids, allowNoMatch, pageParams); + } +} diff --git a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/GetDataFrameAnalyticsStatsResponse.java b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/GetDataFrameAnalyticsStatsResponse.java new file mode 100644 index 0000000000000..5391a576e98b0 --- /dev/null +++ b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/GetDataFrameAnalyticsStatsResponse.java @@ -0,0 +1,102 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.elasticsearch.client.ml; + +import org.elasticsearch.ElasticsearchException; +import org.elasticsearch.action.TaskOperationFailure; +import org.elasticsearch.client.dataframe.AcknowledgedTasksResponse; +import org.elasticsearch.client.ml.dataframe.DataFrameAnalyticsStats; +import org.elasticsearch.common.Nullable; +import org.elasticsearch.common.ParseField; +import org.elasticsearch.common.xcontent.ConstructingObjectParser; +import org.elasticsearch.common.xcontent.XContentParser; + +import java.util.Collections; +import java.util.List; +import java.util.Objects; + +import static org.elasticsearch.common.xcontent.ConstructingObjectParser.constructorArg; +import static org.elasticsearch.common.xcontent.ConstructingObjectParser.optionalConstructorArg; + +public class GetDataFrameAnalyticsStatsResponse { + + public static GetDataFrameAnalyticsStatsResponse fromXContent(XContentParser parser) { + return GetDataFrameAnalyticsStatsResponse.PARSER.apply(parser, null); + } + + private static final ParseField DATA_FRAME_ANALYTICS = new ParseField("data_frame_analytics"); + + @SuppressWarnings("unchecked") + private static final ConstructingObjectParser PARSER = + new ConstructingObjectParser<>( + "get_data_frame_analytics_stats_response", true, + args -> new GetDataFrameAnalyticsStatsResponse( + (List) args[0], + (List) args[1], + (List) args[2])); + + static { + PARSER.declareObjectArray(constructorArg(), (p, c) -> DataFrameAnalyticsStats.fromXContent(p), DATA_FRAME_ANALYTICS); + PARSER.declareObjectArray( + optionalConstructorArg(), (p, c) -> TaskOperationFailure.fromXContent(p), AcknowledgedTasksResponse.TASK_FAILURES); + PARSER.declareObjectArray( + optionalConstructorArg(), (p, c) -> ElasticsearchException.fromXContent(p), AcknowledgedTasksResponse.NODE_FAILURES); + } + + private final List analyticsStats; + private final List taskFailures; + private final List nodeFailures; + + public GetDataFrameAnalyticsStatsResponse(List analyticsStats, + 
@Nullable List taskFailures, + @Nullable List nodeFailures) { + this.analyticsStats = analyticsStats; + this.taskFailures = taskFailures == null ? Collections.emptyList() : Collections.unmodifiableList(taskFailures); + this.nodeFailures = nodeFailures == null ? Collections.emptyList() : Collections.unmodifiableList(nodeFailures); + } + + public List getAnalyticsStats() { + return analyticsStats; + } + + public List getNodeFailures() { + return nodeFailures; + } + + public List getTaskFailures() { + return taskFailures; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + + GetDataFrameAnalyticsStatsResponse other = (GetDataFrameAnalyticsStatsResponse) o; + return Objects.equals(analyticsStats, other.analyticsStats) + && Objects.equals(nodeFailures, other.nodeFailures) + && Objects.equals(taskFailures, other.taskFailures); + } + + @Override + public int hashCode() { + return Objects.hash(analyticsStats, nodeFailures, taskFailures); + } +} diff --git a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/NodeAttributes.java b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/NodeAttributes.java index 892df340abd6b..a0f0d25f2ca01 100644 --- a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/NodeAttributes.java +++ b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/NodeAttributes.java @@ -19,6 +19,7 @@ package org.elasticsearch.client.ml; import org.elasticsearch.common.ParseField; +import org.elasticsearch.common.Strings; import org.elasticsearch.common.xcontent.ConstructingObjectParser; import org.elasticsearch.common.xcontent.ObjectParser; import org.elasticsearch.common.xcontent.ToXContentObject; @@ -147,4 +148,9 @@ public boolean equals(Object other) { Objects.equals(transportAddress, that.transportAddress) && Objects.equals(attributes, that.attributes); } + + @Override + public String toString() { + return 
Strings.toString(this); + } } diff --git a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/PutDataFrameAnalyticsRequest.java b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/PutDataFrameAnalyticsRequest.java new file mode 100644 index 0000000000000..14950a74c9187 --- /dev/null +++ b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/PutDataFrameAnalyticsRequest.java @@ -0,0 +1,70 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.elasticsearch.client.ml; + +import org.elasticsearch.client.Validatable; +import org.elasticsearch.client.ValidationException; +import org.elasticsearch.client.ml.dataframe.DataFrameAnalyticsConfig; +import org.elasticsearch.common.xcontent.ToXContentObject; +import org.elasticsearch.common.xcontent.XContentBuilder; + +import java.io.IOException; +import java.util.Objects; +import java.util.Optional; + +public class PutDataFrameAnalyticsRequest implements ToXContentObject, Validatable { + + private final DataFrameAnalyticsConfig config; + + public PutDataFrameAnalyticsRequest(DataFrameAnalyticsConfig config) { + this.config = config; + } + + public DataFrameAnalyticsConfig getConfig() { + return config; + } + + @Override + public Optional validate() { + if (config == null) { + return Optional.of(ValidationException.withError("put requires a non-null data frame analytics config")); + } + return Optional.empty(); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + return config.toXContent(builder, params); + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + + PutDataFrameAnalyticsRequest other = (PutDataFrameAnalyticsRequest) o; + return Objects.equals(config, other.config); + } + + @Override + public int hashCode() { + return Objects.hash(config); + } +} diff --git a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/PutDataFrameAnalyticsResponse.java b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/PutDataFrameAnalyticsResponse.java new file mode 100644 index 0000000000000..e6c4be15987d4 --- /dev/null +++ b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/PutDataFrameAnalyticsResponse.java @@ -0,0 +1,57 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.elasticsearch.client.ml; + +import org.elasticsearch.client.ml.dataframe.DataFrameAnalyticsConfig; +import org.elasticsearch.common.xcontent.XContentParser; + +import java.io.IOException; +import java.util.Objects; + +public class PutDataFrameAnalyticsResponse { + + public static PutDataFrameAnalyticsResponse fromXContent(XContentParser parser) throws IOException { + return new PutDataFrameAnalyticsResponse(DataFrameAnalyticsConfig.fromXContent(parser)); + } + + private final DataFrameAnalyticsConfig config; + + public PutDataFrameAnalyticsResponse(DataFrameAnalyticsConfig config) { + this.config = config; + } + + public DataFrameAnalyticsConfig getConfig() { + return config; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + + PutDataFrameAnalyticsResponse other = (PutDataFrameAnalyticsResponse) o; + return Objects.equals(config, other.config); + } + + @Override + public int hashCode() { + return Objects.hash(config); + } +} diff --git a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/StartDataFrameAnalyticsRequest.java b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/StartDataFrameAnalyticsRequest.java new file mode 
100644 index 0000000000000..68a925d15019a --- /dev/null +++ b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/StartDataFrameAnalyticsRequest.java @@ -0,0 +1,74 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */
+
+package org.elasticsearch.client.ml;
+
+import org.elasticsearch.client.Validatable;
+import org.elasticsearch.client.ValidationException;
+import org.elasticsearch.common.Nullable;
+import org.elasticsearch.common.unit.TimeValue;
+
+import java.util.Objects;
+import java.util.Optional;
+
+/**
+ * Request to start the data frame analytics identified by an id, with an optional timeout.
+ */
+public class StartDataFrameAnalyticsRequest implements Validatable {
+
+    private final String id;
+    private TimeValue timeout;
+
+    /**
+     * @param id the data frame analytics id; a {@code null} value is accepted here but reported by {@link #validate()}
+     */
+    public StartDataFrameAnalyticsRequest(String id) {
+        this.id = id;
+    }
+
+    public String getId() {
+        return id;
+    }
+
+    /** @return the timeout, or {@code null} when none has been set */
+    public TimeValue getTimeout() {
+        return timeout;
+    }
+
+    /**
+     * @param timeout the timeout to use, or {@code null} to clear it
+     * @return this request, for chaining
+     */
+    public StartDataFrameAnalyticsRequest setTimeout(@Nullable TimeValue timeout) {
+        this.timeout = timeout;
+        return this;
+    }
+
+    /**
+     * @return a validation error when the id is {@code null}, otherwise an empty optional
+     */
+    @Override
+    public Optional<ValidationException> validate() {
+        if (id == null) {
+            return Optional.of(ValidationException.withError("data frame analytics id must not be null"));
+        }
+        return Optional.empty();
+    }
+
+    @Override
+    public boolean equals(Object o) {
+        if (this == o) return true;
+        if (o == null || getClass() != o.getClass()) return false;
+
+        StartDataFrameAnalyticsRequest other = (StartDataFrameAnalyticsRequest) o;
+        return Objects.equals(id, other.id)
+            && Objects.equals(timeout, other.timeout);
+    }
+
+    @Override
+    public int hashCode() {
+        return Objects.hash(id, timeout);
+    }
+}
diff --git a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/StopDataFrameAnalyticsRequest.java b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/StopDataFrameAnalyticsRequest.java new file mode 100644 index 0000000000000..9608d40fc7d16 --- /dev/null +++ b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/StopDataFrameAnalyticsRequest.java @@ -0,0 +1,88 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License.
+ */ + +package org.elasticsearch.client.ml; + +import org.elasticsearch.client.Validatable; +import org.elasticsearch.client.ValidationException; +import org.elasticsearch.common.Nullable; +import org.elasticsearch.common.ParseField; +import org.elasticsearch.common.unit.TimeValue; + +import java.util.Objects; +import java.util.Optional; + +public class StopDataFrameAnalyticsRequest implements Validatable { + + public static final ParseField ALLOW_NO_MATCH = new ParseField("allow_no_match"); + + private final String id; + private TimeValue timeout; + private Boolean allowNoMatch; + + public StopDataFrameAnalyticsRequest(String id) { + this.id = id; + } + + public String getId() { + return id; + } + + public TimeValue getTimeout() { + return timeout; + } + + public StopDataFrameAnalyticsRequest setTimeout(@Nullable TimeValue timeout) { + this.timeout = timeout; + return this; + } + + public Boolean getAllowNoMatch() { + return allowNoMatch; + } + + public StopDataFrameAnalyticsRequest setAllowNoMatch(boolean allowNoMatch) { + this.allowNoMatch = allowNoMatch; + return this; + } + + @Override + public Optional validate() { + if (id == null) { + return Optional.of(ValidationException.withError("data frame analytics id must not be null")); + } + return Optional.empty(); + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + + StopDataFrameAnalyticsRequest other = (StopDataFrameAnalyticsRequest) o; + return Objects.equals(id, other.id) + && Objects.equals(timeout, other.timeout) + && Objects.equals(allowNoMatch, other.allowNoMatch); + } + + @Override + public int hashCode() { + return Objects.hash(id, timeout, allowNoMatch); + } +} diff --git a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/StopDataFrameAnalyticsResponse.java b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/StopDataFrameAnalyticsResponse.java new file mode 100644 index 
0000000000000..5f45c6f9ea51f --- /dev/null +++ b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/StopDataFrameAnalyticsResponse.java @@ -0,0 +1,87 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */
+package org.elasticsearch.client.ml;
+
+import org.elasticsearch.common.ParseField;
+import org.elasticsearch.common.xcontent.ConstructingObjectParser;
+import org.elasticsearch.common.xcontent.ToXContentObject;
+import org.elasticsearch.common.xcontent.XContentBuilder;
+import org.elasticsearch.common.xcontent.XContentParser;
+
+import java.io.IOException;
+import java.util.Objects;
+
+/**
+ * Response indicating if the Machine Learning Data Frame Analytics is now stopped or not
+ */
+public class StopDataFrameAnalyticsResponse implements ToXContentObject {
+
+    private static final ParseField STOPPED = new ParseField("stopped");
+
+    // Typed so that parse(...) yields a StopDataFrameAnalyticsResponse rather than Object.
+    public static final ConstructingObjectParser<StopDataFrameAnalyticsResponse, Void> PARSER =
+        new ConstructingObjectParser<>(
+            "stop_data_frame_analytics_response",
+            true,
+            args -> new StopDataFrameAnalyticsResponse((Boolean) args[0]));
+
+    static {
+        PARSER.declareBoolean(ConstructingObjectParser.constructorArg(), STOPPED);
+    }
+
+    /**
+     * Parses a response from the given parser.
+     *
+     * @param parser the parser positioned at the response object
+     * @return the parsed response
+     */
+    public static StopDataFrameAnalyticsResponse fromXContent(XContentParser parser) throws IOException {
+        return PARSER.parse(parser, null);
+    }
+
+    private final boolean stopped;
+
+    public StopDataFrameAnalyticsResponse(boolean stopped) {
+        this.stopped = stopped;
+    }
+
+    /**
+     * Has the Data Frame Analytics stopped or not
+     *
+     * @return boolean value indicating the Data Frame Analytics stopped status
+     */
+    public boolean isStopped() {
+        return stopped;
+    }
+
+    @Override
+    public boolean equals(Object o) {
+        if (this == o) return true;
+        if (o == null || getClass() != o.getClass()) return false;
+
+        StopDataFrameAnalyticsResponse other = (StopDataFrameAnalyticsResponse) o;
+        return stopped == other.stopped;
+    }
+
+    @Override
+    public int hashCode() {
+        return Objects.hash(stopped);
+    }
+
+    /** Writes a single-field object holding the stopped flag. */
+    @Override
+    public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
+        return builder
+            .startObject()
+            .field(STOPPED.getPreferredName(), stopped)
+            .endObject();
+    }
+}
diff --git a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/DataFrameAnalysis.java b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/DataFrameAnalysis.java new file mode 100644 index 0000000000000..81b19eefce573 --- /dev/null +++ b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/DataFrameAnalysis.java @@ -0,0 +1,27 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License.
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.elasticsearch.client.ml.dataframe; + +import org.elasticsearch.common.xcontent.ToXContentObject; + +/** Contract for an analysis: an XContent-serializable object that also exposes a name. */ public interface DataFrameAnalysis extends ToXContentObject { + + /** Returns the name of this analysis. */ String getName(); +} diff --git a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/DataFrameAnalyticsConfig.java b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/DataFrameAnalyticsConfig.java new file mode 100644 index 0000000000000..b1309e66afcd4 --- /dev/null +++ b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/DataFrameAnalyticsConfig.java @@ -0,0 +1,208 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License.
+ */ + +package org.elasticsearch.client.ml.dataframe; + +import org.elasticsearch.common.Nullable; +import org.elasticsearch.common.ParseField; +import org.elasticsearch.common.Strings; +import org.elasticsearch.common.unit.ByteSizeValue; +import org.elasticsearch.common.xcontent.ObjectParser; +import org.elasticsearch.common.xcontent.ToXContentObject; +import org.elasticsearch.common.xcontent.XContentBuilder; +import org.elasticsearch.common.xcontent.XContentParser; +import org.elasticsearch.common.xcontent.XContentParserUtils; +import org.elasticsearch.search.fetch.subphase.FetchSourceContext; + +import java.io.IOException; +import java.util.Objects; + +import static org.elasticsearch.common.xcontent.ObjectParser.ValueType.OBJECT_ARRAY_BOOLEAN_OR_STRING; +import static org.elasticsearch.common.xcontent.ObjectParser.ValueType.VALUE; + +public class DataFrameAnalyticsConfig implements ToXContentObject { + + public static DataFrameAnalyticsConfig fromXContent(XContentParser parser) { + return PARSER.apply(parser, null).build(); + } + + public static Builder builder(String id) { + return new Builder().setId(id); + } + + private static final ParseField ID = new ParseField("id"); + private static final ParseField SOURCE = new ParseField("source"); + private static final ParseField DEST = new ParseField("dest"); + private static final ParseField ANALYSIS = new ParseField("analysis"); + private static final ParseField ANALYZED_FIELDS = new ParseField("analyzed_fields"); + private static final ParseField MODEL_MEMORY_LIMIT = new ParseField("model_memory_limit"); + + private static ObjectParser PARSER = new ObjectParser<>("data_frame_analytics_config", true, Builder::new); + + static { + PARSER.declareString(Builder::setId, ID); + PARSER.declareObject(Builder::setSource, (p, c) -> DataFrameAnalyticsSource.fromXContent(p), SOURCE); + PARSER.declareObject(Builder::setDest, (p, c) -> DataFrameAnalyticsDest.fromXContent(p), DEST); + PARSER.declareObject(Builder::setAnalysis, 
(p, c) -> parseAnalysis(p), ANALYSIS); + PARSER.declareField(Builder::setAnalyzedFields, + (p, c) -> FetchSourceContext.fromXContent(p), + ANALYZED_FIELDS, + OBJECT_ARRAY_BOOLEAN_OR_STRING); + PARSER.declareField(Builder::setModelMemoryLimit, + (p, c) -> ByteSizeValue.parseBytesSizeValue(p.text(), MODEL_MEMORY_LIMIT.getPreferredName()), MODEL_MEMORY_LIMIT, VALUE); + } + + private static DataFrameAnalysis parseAnalysis(XContentParser parser) throws IOException { + XContentParserUtils.ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser::getTokenLocation); + XContentParserUtils.ensureExpectedToken(XContentParser.Token.FIELD_NAME, parser.nextToken(), parser::getTokenLocation); + DataFrameAnalysis analysis = parser.namedObject(DataFrameAnalysis.class, parser.currentName(), true); + XContentParserUtils.ensureExpectedToken(XContentParser.Token.END_OBJECT, parser.nextToken(), parser::getTokenLocation); + return analysis; + } + + private final String id; + private final DataFrameAnalyticsSource source; + private final DataFrameAnalyticsDest dest; + private final DataFrameAnalysis analysis; + private final FetchSourceContext analyzedFields; + private final ByteSizeValue modelMemoryLimit; + + private DataFrameAnalyticsConfig(String id, DataFrameAnalyticsSource source, DataFrameAnalyticsDest dest, DataFrameAnalysis analysis, + @Nullable FetchSourceContext analyzedFields, @Nullable ByteSizeValue modelMemoryLimit) { + this.id = Objects.requireNonNull(id); + this.source = Objects.requireNonNull(source); + this.dest = Objects.requireNonNull(dest); + this.analysis = Objects.requireNonNull(analysis); + this.analyzedFields = analyzedFields; + this.modelMemoryLimit = modelMemoryLimit; + } + + public String getId() { + return id; + } + + public DataFrameAnalyticsSource getSource() { + return source; + } + + public DataFrameAnalyticsDest getDest() { + return dest; + } + + public DataFrameAnalysis getAnalysis() { + return analysis; + } + + public 
FetchSourceContext getAnalyzedFields() { + return analyzedFields; + } + + public ByteSizeValue getModelMemoryLimit() { + return modelMemoryLimit; + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + builder.field(ID.getPreferredName(), id); + builder.field(SOURCE.getPreferredName(), source); + builder.field(DEST.getPreferredName(), dest); + builder.startObject(ANALYSIS.getPreferredName()); + builder.field(analysis.getName(), analysis); + builder.endObject(); + if (analyzedFields != null) { + builder.field(ANALYZED_FIELDS.getPreferredName(), analyzedFields); + } + if (modelMemoryLimit != null) { + builder.field(MODEL_MEMORY_LIMIT.getPreferredName(), modelMemoryLimit.getStringRep()); + } + builder.endObject(); + return builder; + } + + @Override + public boolean equals(Object o) { + if (o == this) return true; + if (o == null || getClass() != o.getClass()) return false; + + DataFrameAnalyticsConfig other = (DataFrameAnalyticsConfig) o; + return Objects.equals(id, other.id) + && Objects.equals(source, other.source) + && Objects.equals(dest, other.dest) + && Objects.equals(analysis, other.analysis) + && Objects.equals(analyzedFields, other.analyzedFields) + && Objects.equals(modelMemoryLimit, other.modelMemoryLimit); + } + + @Override + public int hashCode() { + return Objects.hash(id, source, dest, analysis, analyzedFields, getModelMemoryLimit()); + } + + @Override + public String toString() { + return Strings.toString(this); + } + + public static class Builder { + + private String id; + private DataFrameAnalyticsSource source; + private DataFrameAnalyticsDest dest; + private DataFrameAnalysis analysis; + private FetchSourceContext analyzedFields; + private ByteSizeValue modelMemoryLimit; + + private Builder() {} + + public Builder setId(String id) { + this.id = Objects.requireNonNull(id); + return this; + } + + public Builder setSource(DataFrameAnalyticsSource source) { + 
this.source = Objects.requireNonNull(source); + return this; + } + + public Builder setDest(DataFrameAnalyticsDest dest) { + this.dest = Objects.requireNonNull(dest); + return this; + } + + public Builder setAnalysis(DataFrameAnalysis analysis) { + this.analysis = Objects.requireNonNull(analysis); + return this; + } + + public Builder setAnalyzedFields(FetchSourceContext fields) { + this.analyzedFields = fields; + return this; + } + + public Builder setModelMemoryLimit(ByteSizeValue modelMemoryLimit) { + this.modelMemoryLimit = modelMemoryLimit; + return this; + } + + public DataFrameAnalyticsConfig build() { + return new DataFrameAnalyticsConfig(id, source, dest, analysis, analyzedFields, modelMemoryLimit); + } + } +} diff --git a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/DataFrameAnalyticsDest.java b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/DataFrameAnalyticsDest.java new file mode 100644 index 0000000000000..4123f85ee2f43 --- /dev/null +++ b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/DataFrameAnalyticsDest.java @@ -0,0 +1,123 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.elasticsearch.client.ml.dataframe; + +import org.elasticsearch.common.Nullable; +import org.elasticsearch.common.ParseField; +import org.elasticsearch.common.Strings; +import org.elasticsearch.common.xcontent.ObjectParser; +import org.elasticsearch.common.xcontent.ToXContentObject; +import org.elasticsearch.common.xcontent.XContentBuilder; +import org.elasticsearch.common.xcontent.XContentParser; + +import java.io.IOException; +import java.util.Objects; + +import static java.util.Objects.requireNonNull; + +public class DataFrameAnalyticsDest implements ToXContentObject { + + public static DataFrameAnalyticsDest fromXContent(XContentParser parser) { + return PARSER.apply(parser, null).build(); + } + + public static Builder builder() { + return new Builder(); + } + + private static final ParseField INDEX = new ParseField("index"); + private static final ParseField RESULTS_FIELD = new ParseField("results_field"); + + private static ObjectParser PARSER = new ObjectParser<>("data_frame_analytics_dest", true, Builder::new); + + static { + PARSER.declareString(Builder::setIndex, INDEX); + PARSER.declareString(Builder::setResultsField, RESULTS_FIELD); + } + + private final String index; + private final String resultsField; + + private DataFrameAnalyticsDest(String index, @Nullable String resultsField) { + this.index = requireNonNull(index); + this.resultsField = resultsField; + } + + public String getIndex() { + return index; + } + + public String getResultsField() { + return resultsField; + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + builder.field(INDEX.getPreferredName(), index); + if (resultsField != null) { + builder.field(RESULTS_FIELD.getPreferredName(), resultsField); + } + builder.endObject(); + return builder; + } + + @Override + public boolean equals(Object o) { + if (o == this) return true; + if (o == null || getClass() != o.getClass()) return 
false; + + DataFrameAnalyticsDest other = (DataFrameAnalyticsDest) o; + return Objects.equals(index, other.index) + && Objects.equals(resultsField, other.resultsField); + } + + @Override + public int hashCode() { + return Objects.hash(index, resultsField); + } + + @Override + public String toString() { + return Strings.toString(this); + } + + public static class Builder { + + private String index; + private String resultsField; + + private Builder() {} + + public Builder setIndex(String index) { + this.index = index; + return this; + } + + public Builder setResultsField(String resultsField) { + this.resultsField = resultsField; + return this; + } + + public DataFrameAnalyticsDest build() { + return new DataFrameAnalyticsDest(index, resultsField); + } + } +} diff --git a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/DataFrameAnalyticsSource.java b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/DataFrameAnalyticsSource.java new file mode 100644 index 0000000000000..c36799cd3b4a7 --- /dev/null +++ b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/DataFrameAnalyticsSource.java @@ -0,0 +1,121 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.elasticsearch.client.ml.dataframe; + +import org.elasticsearch.common.Nullable; +import org.elasticsearch.common.ParseField; +import org.elasticsearch.common.Strings; +import org.elasticsearch.common.xcontent.ObjectParser; +import org.elasticsearch.common.xcontent.ToXContentObject; +import org.elasticsearch.common.xcontent.XContentBuilder; +import org.elasticsearch.common.xcontent.XContentParser; + +import java.io.IOException; +import java.util.Objects; + +public class DataFrameAnalyticsSource implements ToXContentObject { + + public static DataFrameAnalyticsSource fromXContent(XContentParser parser) { + return PARSER.apply(parser, null).build(); + } + + public static Builder builder() { + return new Builder(); + } + + private static final ParseField INDEX = new ParseField("index"); + private static final ParseField QUERY = new ParseField("query"); + + private static ObjectParser PARSER = new ObjectParser<>("data_frame_analytics_source", true, Builder::new); + + static { + PARSER.declareString(Builder::setIndex, INDEX); + PARSER.declareObject(Builder::setQueryConfig, (p, c) -> QueryConfig.fromXContent(p), QUERY); + } + + private final String index; + private final QueryConfig queryConfig; + + private DataFrameAnalyticsSource(String index, @Nullable QueryConfig queryConfig) { + this.index = Objects.requireNonNull(index); + this.queryConfig = queryConfig; + } + + public String getIndex() { + return index; + } + + public QueryConfig getQueryConfig() { + return queryConfig; + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + builder.field(INDEX.getPreferredName(), index); + if (queryConfig != null) { + builder.field(QUERY.getPreferredName(), queryConfig.getQuery()); + } + builder.endObject(); + return builder; + } + + @Override + public boolean equals(Object o) { + if (o == this) return true; + if (o == null || getClass() != o.getClass()) return false; + + 
DataFrameAnalyticsSource other = (DataFrameAnalyticsSource) o; + return Objects.equals(index, other.index) + && Objects.equals(queryConfig, other.queryConfig); + } + + @Override + public int hashCode() { + return Objects.hash(index, queryConfig); + } + + @Override + public String toString() { + return Strings.toString(this); + } + + public static class Builder { + + private String index; + private QueryConfig queryConfig; + + private Builder() {} + + public Builder setIndex(String index) { + this.index = index; + return this; + } + + public Builder setQueryConfig(QueryConfig queryConfig) { + this.queryConfig = queryConfig; + return this; + } + + public DataFrameAnalyticsSource build() { + return new DataFrameAnalyticsSource(index, queryConfig); + } + } +} diff --git a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/DataFrameAnalyticsState.java b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/DataFrameAnalyticsState.java new file mode 100644 index 0000000000000..6ee349b8e8d38 --- /dev/null +++ b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/DataFrameAnalyticsState.java @@ -0,0 +1,34 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.elasticsearch.client.ml.dataframe; + +import java.util.Locale; + +public enum DataFrameAnalyticsState { + STARTED, REINDEXING, ANALYZING, STOPPING, STOPPED; + + public static DataFrameAnalyticsState fromString(String name) { + return valueOf(name.trim().toUpperCase(Locale.ROOT)); + } + + public String value() { + return name().toLowerCase(Locale.ROOT); + } +} diff --git a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/DataFrameAnalyticsStats.java b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/DataFrameAnalyticsStats.java new file mode 100644 index 0000000000000..5c652f33edb2e --- /dev/null +++ b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/DataFrameAnalyticsStats.java @@ -0,0 +1,133 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.elasticsearch.client.ml.dataframe; + +import org.elasticsearch.client.ml.NodeAttributes; +import org.elasticsearch.common.Nullable; +import org.elasticsearch.common.ParseField; +import org.elasticsearch.common.inject.internal.ToStringBuilder; +import org.elasticsearch.common.xcontent.ConstructingObjectParser; +import org.elasticsearch.common.xcontent.ObjectParser; +import org.elasticsearch.common.xcontent.XContentParser; + +import java.io.IOException; +import java.util.Objects; + +import static org.elasticsearch.common.xcontent.ConstructingObjectParser.constructorArg; +import static org.elasticsearch.common.xcontent.ConstructingObjectParser.optionalConstructorArg; + +public class DataFrameAnalyticsStats { + + public static DataFrameAnalyticsStats fromXContent(XContentParser parser) throws IOException { + return PARSER.parse(parser, null); + } + + static final ParseField ID = new ParseField("id"); + static final ParseField STATE = new ParseField("state"); + static final ParseField PROGRESS_PERCENT = new ParseField("progress_percent"); + static final ParseField NODE = new ParseField("node"); + static final ParseField ASSIGNMENT_EXPLANATION = new ParseField("assignment_explanation"); + + private static final ConstructingObjectParser PARSER = + new ConstructingObjectParser<>("data_frame_analytics_stats", true, + args -> new DataFrameAnalyticsStats( + (String) args[0], + (DataFrameAnalyticsState) args[1], + (Integer) args[2], + (NodeAttributes) args[3], + (String) args[4])); + + static { + PARSER.declareString(constructorArg(), ID); + PARSER.declareField(constructorArg(), p -> { + if (p.currentToken() == XContentParser.Token.VALUE_STRING) { + return DataFrameAnalyticsState.fromString(p.text()); + } + throw new IllegalArgumentException("Unsupported token [" + p.currentToken() + "]"); + }, STATE, ObjectParser.ValueType.STRING); + PARSER.declareInt(optionalConstructorArg(), PROGRESS_PERCENT); + PARSER.declareObject(optionalConstructorArg(), 
NodeAttributes.PARSER, NODE); + PARSER.declareString(optionalConstructorArg(), ASSIGNMENT_EXPLANATION); + } + + private final String id; + private final DataFrameAnalyticsState state; + private final Integer progressPercent; + private final NodeAttributes node; + private final String assignmentExplanation; + + public DataFrameAnalyticsStats(String id, DataFrameAnalyticsState state, @Nullable Integer progressPercent, + @Nullable NodeAttributes node, @Nullable String assignmentExplanation) { + this.id = id; + this.state = state; + this.progressPercent = progressPercent; + this.node = node; + this.assignmentExplanation = assignmentExplanation; + } + + public String getId() { + return id; + } + + public DataFrameAnalyticsState getState() { + return state; + } + + public Integer getProgressPercent() { + return progressPercent; + } + + public NodeAttributes getNode() { + return node; + } + + public String getAssignmentExplanation() { + return assignmentExplanation; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + + DataFrameAnalyticsStats other = (DataFrameAnalyticsStats) o; + return Objects.equals(id, other.id) + && Objects.equals(state, other.state) + && Objects.equals(progressPercent, other.progressPercent) + && Objects.equals(node, other.node) + && Objects.equals(assignmentExplanation, other.assignmentExplanation); + } + + @Override + public int hashCode() { + return Objects.hash(id, state, progressPercent, node, assignmentExplanation); + } + + @Override + public String toString() { + return new ToStringBuilder(getClass()) + .add("id", id) + .add("state", state) + .add("progressPercent", progressPercent) + .add("node", node) + .add("assignmentExplanation", assignmentExplanation) + .toString(); + } +} diff --git a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/MlDataFrameAnalysisNamedXContentProvider.java 
b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/MlDataFrameAnalysisNamedXContentProvider.java new file mode 100644 index 0000000000000..3b78c60be91fd --- /dev/null +++ b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/MlDataFrameAnalysisNamedXContentProvider.java @@ -0,0 +1,37 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.elasticsearch.client.ml.dataframe; + +import org.elasticsearch.common.xcontent.NamedXContentRegistry; +import org.elasticsearch.plugins.spi.NamedXContentProvider; + +import java.util.Arrays; +import java.util.List; + +public class MlDataFrameAnalysisNamedXContentProvider implements NamedXContentProvider { + + @Override + public List getNamedXContentParsers() { + return Arrays.asList( + new NamedXContentRegistry.Entry( + DataFrameAnalysis.class, + OutlierDetection.NAME, + (p, c) -> OutlierDetection.fromXContent(p))); + } +} diff --git a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/OutlierDetection.java b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/OutlierDetection.java new file mode 100644 index 0000000000000..946c01ac5c835 --- /dev/null +++ b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/OutlierDetection.java @@ -0,0 +1,176 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.elasticsearch.client.ml.dataframe; + +import org.elasticsearch.common.Nullable; +import org.elasticsearch.common.ParseField; +import org.elasticsearch.common.Strings; +import org.elasticsearch.common.xcontent.ObjectParser; +import org.elasticsearch.common.xcontent.XContentBuilder; +import org.elasticsearch.common.xcontent.XContentParser; + +import java.io.IOException; +import java.util.Locale; +import java.util.Objects; + +public class OutlierDetection implements DataFrameAnalysis { + + public static OutlierDetection fromXContent(XContentParser parser) { + return PARSER.apply(parser, null).build(); + } + + public static OutlierDetection createDefault() { + return builder().build(); + } + + public static Builder builder() { + return new Builder(); + } + + public static final ParseField NAME = new ParseField("outlier_detection"); + static final ParseField N_NEIGHBORS = new ParseField("n_neighbors"); + static final ParseField METHOD = new ParseField("method"); + public static final ParseField MINIMUM_SCORE_TO_WRITE_FEATURE_INFLUENCE = + new ParseField("minimum_score_to_write_feature_influence"); + + private static ObjectParser PARSER = new ObjectParser<>(NAME.getPreferredName(), true, Builder::new); + + static { + PARSER.declareInt(Builder::setNNeighbors, N_NEIGHBORS); + PARSER.declareField(Builder::setMethod, p -> { + if (p.currentToken() == XContentParser.Token.VALUE_STRING) { + return Method.fromString(p.text()); + } + throw new IllegalArgumentException("Unsupported token [" + p.currentToken() + "]"); + }, METHOD, ObjectParser.ValueType.STRING); + PARSER.declareDouble(Builder::setMinScoreToWriteFeatureInfluence, MINIMUM_SCORE_TO_WRITE_FEATURE_INFLUENCE); + } + + private final Integer nNeighbors; + private final Method method; + private final Double minScoreToWriteFeatureInfluence; + + /** + * Constructs the outlier detection configuration + * @param nNeighbors The number of neighbors. Leave unspecified for dynamic detection. 
+ * @param method The method. Leave unspecified for a dynamic mixture of methods. + * @param minScoreToWriteFeatureInfluence The min outlier score required to calculate feature influence. Defaults to 0.1. + */ + private OutlierDetection(@Nullable Integer nNeighbors, @Nullable Method method, @Nullable Double minScoreToWriteFeatureInfluence) { + this.nNeighbors = nNeighbors; + this.method = method; + this.minScoreToWriteFeatureInfluence = minScoreToWriteFeatureInfluence; + } + + @Override + public String getName() { + return NAME.getPreferredName(); + } + + public Integer getNNeighbors() { + return nNeighbors; + } + + public Method getMethod() { + return method; + } + + public Double getMinScoreToWriteFeatureInfluence() { + return minScoreToWriteFeatureInfluence; + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + if (nNeighbors != null) { + builder.field(N_NEIGHBORS.getPreferredName(), nNeighbors); + } + if (method != null) { + builder.field(METHOD.getPreferredName(), method); + } + if (minScoreToWriteFeatureInfluence != null) { + builder.field(MINIMUM_SCORE_TO_WRITE_FEATURE_INFLUENCE.getPreferredName(), minScoreToWriteFeatureInfluence); + } + builder.endObject(); + return builder; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + + OutlierDetection other = (OutlierDetection) o; + return Objects.equals(nNeighbors, other.nNeighbors) + && Objects.equals(method, other.method) + && Objects.equals(minScoreToWriteFeatureInfluence, other.minScoreToWriteFeatureInfluence); + } + + @Override + public int hashCode() { + return Objects.hash(nNeighbors, method, minScoreToWriteFeatureInfluence); + } + + @Override + public String toString() { + return Strings.toString(this); + } + + public enum Method { + LOF, LDOF, DISTANCE_KTH_NN, DISTANCE_KNN; + + public static Method fromString(String value) { 
+ return Method.valueOf(value.toUpperCase(Locale.ROOT)); + } + + @Override + public String toString() { + return name().toLowerCase(Locale.ROOT); + } + } + + public static class Builder { + + private Integer nNeighbors; + private Method method; + private Double minScoreToWriteFeatureInfluence; + + private Builder() {} + + public Builder setNNeighbors(Integer nNeighbors) { + this.nNeighbors = nNeighbors; + return this; + } + + public Builder setMethod(Method method) { + this.method = method; + return this; + } + + public Builder setMinScoreToWriteFeatureInfluence(Double minScoreToWriteFeatureInfluence) { + this.minScoreToWriteFeatureInfluence = minScoreToWriteFeatureInfluence; + return this; + } + + public OutlierDetection build() { + return new OutlierDetection(nNeighbors, method, minScoreToWriteFeatureInfluence); + } + } +} diff --git a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/QueryConfig.java b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/QueryConfig.java new file mode 100644 index 0000000000000..ae704db9f800e --- /dev/null +++ b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/QueryConfig.java @@ -0,0 +1,82 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. 
See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.elasticsearch.client.ml.dataframe; + +import org.elasticsearch.common.Strings; +import org.elasticsearch.common.xcontent.ToXContentObject; +import org.elasticsearch.common.xcontent.XContentBuilder; +import org.elasticsearch.common.xcontent.XContentParser; +import org.elasticsearch.index.query.AbstractQueryBuilder; +import org.elasticsearch.index.query.QueryBuilder; + +import java.io.IOException; +import java.util.Objects; + +import static java.util.Objects.requireNonNull; + +/** + * Object for encapsulating the desired Query for a DataFrameAnalysis + */ +public class QueryConfig implements ToXContentObject { + + public static QueryConfig fromXContent(XContentParser parser) throws IOException { + QueryBuilder query = AbstractQueryBuilder.parseInnerQueryBuilder(parser); + return new QueryConfig(query); + } + + private final QueryBuilder query; + + public QueryConfig(QueryBuilder query) { + this.query = requireNonNull(query); + } + + public QueryConfig(QueryConfig queryConfig) { + this(requireNonNull(queryConfig).query); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + query.toXContent(builder, params); + return builder; + } + + public QueryBuilder getQuery() { + return query; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + + QueryConfig other = (QueryConfig) o; + return Objects.equals(query, other.query); + } + + @Override + public int hashCode() { + return Objects.hash(query); + } + + @Override + public String toString() { + return Strings.toString(this); + } +} diff --git a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/evaluation/Evaluation.java b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/evaluation/Evaluation.java 
new file mode 100644 index 0000000000000..78578597e195b --- /dev/null +++ b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/evaluation/Evaluation.java @@ -0,0 +1,32 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.elasticsearch.client.ml.dataframe.evaluation; + +import org.elasticsearch.common.xcontent.ToXContentObject; + +/** + * Defines an evaluation + */ +public interface Evaluation extends ToXContentObject { + + /** + * Returns the evaluation name + */ + String getName(); +} diff --git a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/evaluation/EvaluationMetric.java b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/evaluation/EvaluationMetric.java new file mode 100644 index 0000000000000..a0f77838f1fd0 --- /dev/null +++ b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/evaluation/EvaluationMetric.java @@ -0,0 +1,43 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. 
Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.elasticsearch.client.ml.dataframe.evaluation; + +import org.elasticsearch.common.xcontent.ToXContentObject; + +/** + * Defines an evaluation metric + */ +public interface EvaluationMetric extends ToXContentObject { + + /** + * Returns the name of the metric + */ + String getName(); + + /** + * The result of an evaluation metric + */ + interface Result extends ToXContentObject { + + /** + * Returns the name of the metric + */ + String getMetricName(); + } +} diff --git a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/evaluation/MlEvaluationNamedXContentProvider.java b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/evaluation/MlEvaluationNamedXContentProvider.java new file mode 100644 index 0000000000000..764ff41de86e0 --- /dev/null +++ b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/evaluation/MlEvaluationNamedXContentProvider.java @@ -0,0 +1,57 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.elasticsearch.client.ml.dataframe.evaluation; + +import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.BinarySoftClassification; +import org.elasticsearch.common.ParseField; +import org.elasticsearch.common.xcontent.NamedXContentRegistry; +import org.elasticsearch.plugins.spi.NamedXContentProvider; +import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.AucRocMetric; +import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.ConfusionMatrixMetric; +import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.PrecisionMetric; +import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.RecallMetric; + +import java.util.Arrays; +import java.util.List; + +public class MlEvaluationNamedXContentProvider implements NamedXContentProvider { + + @Override + public List getNamedXContentParsers() { + return Arrays.asList( + // Evaluations + new NamedXContentRegistry.Entry( + Evaluation.class, new ParseField(BinarySoftClassification.NAME), BinarySoftClassification::fromXContent), + // Evaluation metrics + new NamedXContentRegistry.Entry(EvaluationMetric.class, new ParseField(AucRocMetric.NAME), AucRocMetric::fromXContent), + new NamedXContentRegistry.Entry(EvaluationMetric.class, new ParseField(PrecisionMetric.NAME), PrecisionMetric::fromXContent), + new NamedXContentRegistry.Entry(EvaluationMetric.class, new ParseField(RecallMetric.NAME), RecallMetric::fromXContent), + new NamedXContentRegistry.Entry( + EvaluationMetric.class, new 
ParseField(ConfusionMatrixMetric.NAME), ConfusionMatrixMetric::fromXContent), + // Evaluation metrics results + new NamedXContentRegistry.Entry( + EvaluationMetric.Result.class, new ParseField(AucRocMetric.NAME), AucRocMetric.Result::fromXContent), + new NamedXContentRegistry.Entry( + EvaluationMetric.Result.class, new ParseField(PrecisionMetric.NAME), PrecisionMetric.Result::fromXContent), + new NamedXContentRegistry.Entry( + EvaluationMetric.Result.class, new ParseField(RecallMetric.NAME), RecallMetric.Result::fromXContent), + new NamedXContentRegistry.Entry( + EvaluationMetric.Result.class, new ParseField(ConfusionMatrixMetric.NAME), ConfusionMatrixMetric.Result::fromXContent)); + } +} diff --git a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/evaluation/softclassification/AbstractConfusionMatrixMetric.java b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/evaluation/softclassification/AbstractConfusionMatrixMetric.java new file mode 100644 index 0000000000000..f41c13f248ab9 --- /dev/null +++ b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/evaluation/softclassification/AbstractConfusionMatrixMetric.java @@ -0,0 +1,47 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. 
See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.elasticsearch.client.ml.dataframe.evaluation.softclassification; + +import org.elasticsearch.client.ml.dataframe.evaluation.EvaluationMetric; +import org.elasticsearch.common.ParseField; +import org.elasticsearch.common.xcontent.ToXContent; +import org.elasticsearch.common.xcontent.XContentBuilder; + +import java.io.IOException; +import java.util.List; +import java.util.Objects; + +abstract class AbstractConfusionMatrixMetric implements EvaluationMetric { + + protected static final ParseField AT = new ParseField("at"); + + protected final double[] thresholds; + + protected AbstractConfusionMatrixMetric(List at) { + this.thresholds = Objects.requireNonNull(at).stream().mapToDouble(Double::doubleValue).toArray(); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params params) throws IOException { + return builder + .startObject() + .field(AT.getPreferredName(), thresholds) + .endObject(); + } +} diff --git a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/evaluation/softclassification/AucRocMetric.java b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/evaluation/softclassification/AucRocMetric.java new file mode 100644 index 0000000000000..78c713c592581 --- /dev/null +++ b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/evaluation/softclassification/AucRocMetric.java @@ -0,0 +1,241 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.elasticsearch.client.ml.dataframe.evaluation.softclassification; + +import org.elasticsearch.client.ml.dataframe.evaluation.EvaluationMetric; +import org.elasticsearch.common.Nullable; +import org.elasticsearch.common.ParseField; +import org.elasticsearch.common.Strings; +import org.elasticsearch.common.xcontent.ConstructingObjectParser; +import org.elasticsearch.common.xcontent.ToXContent; +import org.elasticsearch.common.xcontent.ToXContentObject; +import org.elasticsearch.common.xcontent.XContentBuilder; +import org.elasticsearch.common.xcontent.XContentParser; + +import java.io.IOException; +import java.util.Collections; +import java.util.List; +import java.util.Objects; + +import static org.elasticsearch.common.xcontent.ConstructingObjectParser.constructorArg; +import static org.elasticsearch.common.xcontent.ConstructingObjectParser.optionalConstructorArg; + +/** + * Area under the curve (AUC) of the receiver operating characteristic (ROC). + * The ROC curve is a plot of the TPR (true positive rate) against + * the FPR (false positive rate) over a varying threshold. 
+ */ +public class AucRocMetric implements EvaluationMetric { + + public static final String NAME = "auc_roc"; + + public static final ParseField INCLUDE_CURVE = new ParseField("include_curve"); + + @SuppressWarnings("unchecked") + public static final ConstructingObjectParser<AucRocMetric, Void> PARSER = + new ConstructingObjectParser<>(NAME, args -> new AucRocMetric((Boolean) args[0])); + + static { + PARSER.declareBoolean(optionalConstructorArg(), INCLUDE_CURVE); + } + + /** + * Parses the metric from its object form; the {@code include_curve} field is optional. + */ + public static AucRocMetric fromXContent(XContentParser parser) { + return PARSER.apply(parser, null); + } + + /** + * Creates a metric with {@code include_curve} set to {@code true}. + */ + public static AucRocMetric withCurve() { + return new AucRocMetric(true); + } + + private final boolean includeCurve; + + /** + * @param includeCurve value written as {@code include_curve}; {@code null} is treated as {@code false} + */ + public AucRocMetric(Boolean includeCurve) { + this.includeCurve = includeCurve == null ? false : includeCurve; + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params params) throws IOException { + return builder + .startObject() + .field(INCLUDE_CURVE.getPreferredName(), includeCurve) + .endObject(); + } + + @Override + public String getName() { + return NAME; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + AucRocMetric that = (AucRocMetric) o; + return Objects.equals(includeCurve, that.includeCurve); + } + + @Override + public int hashCode() { + return Objects.hash(includeCurve); + } + + /** + * Result of the {@code auc_roc} metric: a required {@code score} and an optional {@code curve}. + * Unknown fields are ignored when parsing. + */ + public static class Result implements EvaluationMetric.Result { + + public static Result fromXContent(XContentParser parser) { + return PARSER.apply(parser, null); + } + + private static final ParseField SCORE = new ParseField("score"); + private static final ParseField CURVE = new ParseField("curve"); + + @SuppressWarnings("unchecked") + private static final ConstructingObjectParser<Result, Void> PARSER = + new ConstructingObjectParser<>("auc_roc_result", true, args -> new Result((double) args[0], (List<AucRocPoint>) args[1])); + + static { + PARSER.declareDouble(constructorArg(), SCORE); +
PARSER.declareObjectArray(optionalConstructorArg(), (p, c) -> AucRocPoint.fromXContent(p), CURVE); + } + + private final double score; + private final List<AucRocPoint> curve; + + /** + * @param score the value of the {@code score} field + * @param curve the curve points, or {@code null} when the {@code curve} field is absent + */ + public Result(double score, @Nullable List<AucRocPoint> curve) { + this.score = score; + this.curve = curve; + } + + @Override + public String getMetricName() { + return NAME; + } + + public double getScore() { + return score; + } + + /** + * @return an unmodifiable view of the curve points, or {@code null} if no curve was set + */ + public List<AucRocPoint> getCurve() { + return curve == null ? null : Collections.unmodifiableList(curve); + } + + /** + * Writes {@code score} always and {@code curve} only when it is non-null and non-empty. + */ + @Override + public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params params) throws IOException { + builder.startObject(); + builder.field(SCORE.getPreferredName(), score); + if (curve != null && curve.isEmpty() == false) { + builder.field(CURVE.getPreferredName(), curve); + } + builder.endObject(); + return builder; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + Result that = (Result) o; + return Objects.equals(score, that.score) + && Objects.equals(curve, that.curve); + } + + @Override + public int hashCode() { + return Objects.hash(score, curve); + } + + @Override + public String toString() { + return Strings.toString(this); + } + } + + /** + * A single curve point made of the {@code tpr}, {@code fpr} and {@code threshold} fields, all required. + * Unknown fields are ignored when parsing. + */ + public static final class AucRocPoint implements ToXContentObject { + + public static AucRocPoint fromXContent(XContentParser parser) { + return PARSER.apply(parser, null); + } + + private static final ParseField TPR = new ParseField("tpr"); + private static final ParseField FPR = new ParseField("fpr"); + private static final ParseField THRESHOLD = new ParseField("threshold"); + + @SuppressWarnings("unchecked") + private static final ConstructingObjectParser<AucRocPoint, Void> PARSER = + new ConstructingObjectParser<>( + "auc_roc_point", + true, + args -> new AucRocPoint((double) args[0], (double) args[1], (double) args[2])); + + static { + PARSER.declareDouble(constructorArg(), TPR); + PARSER.declareDouble(constructorArg(), FPR); +
PARSER.declareDouble(constructorArg(), THRESHOLD); + } + + private final double tpr; + private final double fpr; + private final double threshold; + + public AucRocPoint(double tpr, double fpr, double threshold) { + this.tpr = tpr; + this.fpr = fpr; + this.threshold = threshold; + } + + public double getTruePositiveRate() { + return tpr; + } + + public double getFalsePositiveRate() { + return fpr; + } + + public double getThreshold() { + return threshold; + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + return builder + .startObject() + .field(TPR.getPreferredName(), tpr) + .field(FPR.getPreferredName(), fpr) + .field(THRESHOLD.getPreferredName(), threshold) + .endObject(); + } + + /** + * Compares the three values with {@link Double#compare(double, double)} rather than {@code ==}, + * so that equal points always produce the same {@link #hashCode()} (which hashes the boxed values): + * {@code 0.0} and {@code -0.0} are distinct, and {@code NaN} equals {@code NaN}. + */ + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + AucRocPoint that = (AucRocPoint) o; + return Double.compare(tpr, that.tpr) == 0 + && Double.compare(fpr, that.fpr) == 0 + && Double.compare(threshold, that.threshold) == 0; + } + + @Override + public int hashCode() { + return Objects.hash(tpr, fpr, threshold); + } + + @Override + public String toString() { + return Strings.toString(this); + } + } +} diff --git a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/evaluation/softclassification/BinarySoftClassification.java b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/evaluation/softclassification/BinarySoftClassification.java new file mode 100644 index 0000000000000..6d5fa04da38e5 --- /dev/null +++ b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/evaluation/softclassification/BinarySoftClassification.java @@ -0,0 +1,129 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership.
Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.elasticsearch.client.ml.dataframe.evaluation.softclassification; + +import org.elasticsearch.client.ml.dataframe.evaluation.Evaluation; +import org.elasticsearch.client.ml.dataframe.evaluation.EvaluationMetric; +import org.elasticsearch.common.Nullable; +import org.elasticsearch.common.ParseField; +import org.elasticsearch.common.xcontent.ConstructingObjectParser; +import org.elasticsearch.common.xcontent.ToXContent; +import org.elasticsearch.common.xcontent.XContentBuilder; +import org.elasticsearch.common.xcontent.XContentParser; + +import java.io.IOException; +import java.util.Arrays; +import java.util.List; +import java.util.Objects; + +import static org.elasticsearch.common.xcontent.ConstructingObjectParser.constructorArg; +import static org.elasticsearch.common.xcontent.ConstructingObjectParser.optionalConstructorArg; + +/** + * Evaluation of binary soft classification methods, e.g. outlier detection. + * This is useful to evaluate problems where a model outputs a probability of whether + * a data frame row belongs to one of two groups. 
+ */ +public class BinarySoftClassification implements Evaluation { + + public static final String NAME = "binary_soft_classification"; + + private static final ParseField ACTUAL_FIELD = new ParseField("actual_field"); + private static final ParseField PREDICTED_PROBABILITY_FIELD = new ParseField("predicted_probability_field"); + private static final ParseField METRICS = new ParseField("metrics"); + + @SuppressWarnings("unchecked") + public static final ConstructingObjectParser<BinarySoftClassification, Void> PARSER = + new ConstructingObjectParser<>( + NAME, + args -> new BinarySoftClassification((String) args[0], (String) args[1], (List<EvaluationMetric>) args[2])); + + static { + PARSER.declareString(constructorArg(), ACTUAL_FIELD); + PARSER.declareString(constructorArg(), PREDICTED_PROBABILITY_FIELD); + PARSER.declareNamedObjects(optionalConstructorArg(), (p, c, n) -> p.namedObject(EvaluationMetric.class, n, null), METRICS); + } + + /** + * Parses the evaluation; {@code actual_field} and {@code predicted_probability_field} are required, + * {@code metrics} is optional. + */ + public static BinarySoftClassification fromXContent(XContentParser parser) { + return PARSER.apply(parser, null); + } + + /** + * The field where the actual class is marked up. + * The value of this field is assumed to either be 1 or 0, or true or false. + */ + private final String actualField; + + /** + * The field of the predicted probability in [0.0, 1.0]. + */ + private final String predictedProbabilityField; + + /** + * The list of metrics to calculate, or {@code null} when no {@code metrics} were given + */ + private final List<EvaluationMetric> metrics; + + public BinarySoftClassification(String actualField, String predictedProbabilityField, EvaluationMetric...
metric) { + this(actualField, predictedProbabilityField, Arrays.asList(metric)); + } + + /** + * @param actualField name of the field holding the actual class; must not be {@code null} + * @param predictedProbabilityField name of the field holding the predicted probability; must not be {@code null} + * @param metrics the metrics to calculate; may be {@code null}, which is what the parser passes + * when the optional {@code metrics} field is absent + */ + public BinarySoftClassification(String actualField, String predictedProbabilityField, + @Nullable List<EvaluationMetric> metrics) { + this.actualField = Objects.requireNonNull(actualField); + this.predictedProbabilityField = Objects.requireNonNull(predictedProbabilityField); + this.metrics = metrics; + } + + @Override + public String getName() { + return NAME; + } + + /** + * Writes both field names and, when metrics were given, a {@code metrics} object keyed by metric name. + * The {@code metrics} object is omitted entirely when the list is {@code null}. + */ + @Override + public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params params) throws IOException { + builder.startObject(); + builder.field(ACTUAL_FIELD.getPreferredName(), actualField); + builder.field(PREDICTED_PROBABILITY_FIELD.getPreferredName(), predictedProbabilityField); + + if (metrics != null) { + builder.startObject(METRICS.getPreferredName()); + for (EvaluationMetric metric : metrics) { + builder.field(metric.getName(), metric); + } + builder.endObject(); + } + + builder.endObject(); + return builder; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + BinarySoftClassification that = (BinarySoftClassification) o; + return Objects.equals(actualField, that.actualField) + && Objects.equals(predictedProbabilityField, that.predictedProbabilityField) + && Objects.equals(metrics, that.metrics); + } + + @Override + public int hashCode() { + return Objects.hash(actualField, predictedProbabilityField, metrics); + } +} diff --git a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/evaluation/softclassification/ConfusionMatrixMetric.java b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/evaluation/softclassification/ConfusionMatrixMetric.java new file mode 100644 index 0000000000000..d5e4307c9cc74 --- /dev/null +++ b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/evaluation/softclassification/ConfusionMatrixMetric.java @@ -0,0 +1,206 @@ +/* + *
Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.elasticsearch.client.ml.dataframe.evaluation.softclassification; + +import org.elasticsearch.client.ml.dataframe.evaluation.EvaluationMetric; +import org.elasticsearch.common.ParseField; +import org.elasticsearch.common.Strings; +import org.elasticsearch.common.xcontent.ConstructingObjectParser; +import org.elasticsearch.common.xcontent.ToXContent; +import org.elasticsearch.common.xcontent.ToXContentObject; +import org.elasticsearch.common.xcontent.XContentBuilder; +import org.elasticsearch.common.xcontent.XContentParser; + +import java.io.IOException; +import java.util.Arrays; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; +import java.util.Objects; + +import static org.elasticsearch.common.xcontent.ConstructingObjectParser.constructorArg; + +/** + * Metric named {@code confusion_matrix}. The values given at construction are handed to + * {@link AbstractConfusionMatrixMetric}; equality is based on its {@code thresholds} array. + */ +public class ConfusionMatrixMetric extends AbstractConfusionMatrixMetric { + + public static final String NAME = "confusion_matrix"; + + @SuppressWarnings("unchecked") + private static final ConstructingObjectParser<ConfusionMatrixMetric, Void> PARSER = + new ConstructingObjectParser<>(NAME, args -> new ConfusionMatrixMetric((List<Double>) args[0])); + + static { + PARSER.declareDoubleArray(constructorArg(), AT); + } + + public static
ConfusionMatrixMetric fromXContent(XContentParser parser) { + return PARSER.apply(parser, null); + } + + /** + * Convenience factory equivalent to {@code new ConfusionMatrixMetric(Arrays.asList(at))}. + */ + public static ConfusionMatrixMetric at(Double... at) { + return new ConfusionMatrixMetric(Arrays.asList(at)); + } + + public ConfusionMatrixMetric(List<Double> at) { + super(at); + } + + @Override + public String getName() { + return NAME; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + ConfusionMatrixMetric that = (ConfusionMatrixMetric) o; + return Arrays.equals(thresholds, that.thresholds); + } + + @Override + public int hashCode() { + return Arrays.hashCode(thresholds); + } + + /** + * Result of the metric: a map from string key to {@link ConfusionMatrix}, kept in parse order. + */ + public static class Result implements EvaluationMetric.Result { + + public static Result fromXContent(XContentParser parser) throws IOException { + return new Result(parser.map(LinkedHashMap::new, ConfusionMatrix::fromXContent)); + } + + private final Map<String, ConfusionMatrix> results; + + /** + * @param results the matrices by key; must not be {@code null} + */ + public Result(Map<String, ConfusionMatrix> results) { + this.results = Objects.requireNonNull(results); + } + + @Override + public String getMetricName() { + return NAME; + } + + /** + * @return the matrix stored under the given key, or {@code null} if there is none + */ + public ConfusionMatrix getScoreByThreshold(String threshold) { + return results.get(threshold); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params params) throws IOException { + return builder.map(results); + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + Result that = (Result) o; + return Objects.equals(results, that.results); + } + + @Override + public int hashCode() { + return Objects.hash(results); + } + + @Override + public String toString() { + return Strings.toString(this); + } + } + + /** + * The four counts {@code tp}, {@code fp}, {@code tn} and {@code fn}, all required when parsing. + * Unknown fields are ignored when parsing. + */ + public static final class ConfusionMatrix implements ToXContentObject { + + public static ConfusionMatrix fromXContent(XContentParser parser) { + return PARSER.apply(parser, null); + } + + private static final ParseField TP = new ParseField("tp"); +
private static final ParseField FP = new ParseField("fp"); + private static final ParseField TN = new ParseField("tn"); + private static final ParseField FN = new ParseField("fn"); + + @SuppressWarnings("unchecked") + private static final ConstructingObjectParser<ConfusionMatrix, Void> PARSER = + new ConstructingObjectParser<>( + "confusion_matrix", true, args -> new ConfusionMatrix((long) args[0], (long) args[1], (long) args[2], (long) args[3])); + + static { + PARSER.declareLong(constructorArg(), TP); + PARSER.declareLong(constructorArg(), FP); + PARSER.declareLong(constructorArg(), TN); + PARSER.declareLong(constructorArg(), FN); + } + + private final long tp; + private final long fp; + private final long tn; + private final long fn; + + public ConfusionMatrix(long tp, long fp, long tn, long fn) { + this.tp = tp; + this.fp = fp; + this.tn = tn; + this.fn = fn; + } + + public long getTruePositives() { + return tp; + } + + public long getFalsePositives() { + return fp; + } + + public long getTrueNegatives() { + return tn; + } + + public long getFalseNegatives() { + return fn; + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + return builder + .startObject() + .field(TP.getPreferredName(), tp) + .field(FP.getPreferredName(), fp) + .field(TN.getPreferredName(), tn) + .field(FN.getPreferredName(), fn) + .endObject(); + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + ConfusionMatrix that = (ConfusionMatrix) o; + return tp == that.tp && fp == that.fp && tn == that.tn && fn == that.fn; + } + + @Override + public int hashCode() { + return Objects.hash(tp, fp, tn, fn); + } + + @Override + public String toString() { + return Strings.toString(this); + } + } +} diff --git a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/evaluation/softclassification/PrecisionMetric.java
b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/evaluation/softclassification/PrecisionMetric.java new file mode 100644 index 0000000000000..2a0f1499461d6 --- /dev/null +++ b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/evaluation/softclassification/PrecisionMetric.java @@ -0,0 +1,123 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.elasticsearch.client.ml.dataframe.evaluation.softclassification; + +import org.elasticsearch.client.ml.dataframe.evaluation.EvaluationMetric; +import org.elasticsearch.common.Strings; +import org.elasticsearch.common.xcontent.ConstructingObjectParser; +import org.elasticsearch.common.xcontent.ToXContent; +import org.elasticsearch.common.xcontent.XContentBuilder; +import org.elasticsearch.common.xcontent.XContentParser; + +import java.io.IOException; +import java.util.Arrays; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; +import java.util.Objects; + +import static org.elasticsearch.common.xcontent.ConstructingObjectParser.constructorArg; + +/** + * Metric named {@code precision}. The values given at construction are handed to + * {@link AbstractConfusionMatrixMetric}; equality is based on its {@code thresholds} array. + */ +public class PrecisionMetric extends AbstractConfusionMatrixMetric { + + public static final String NAME = "precision"; + + @SuppressWarnings("unchecked") + private static final ConstructingObjectParser<PrecisionMetric, Void> PARSER = + new ConstructingObjectParser<>(NAME, args -> new PrecisionMetric((List<Double>) args[0])); + + static { + PARSER.declareDoubleArray(constructorArg(), AT); + } + + public static PrecisionMetric fromXContent(XContentParser parser) { + return PARSER.apply(parser, null); + } + + /** + * Convenience factory equivalent to {@code new PrecisionMetric(Arrays.asList(at))}. + */ + public static PrecisionMetric at(Double...
at) { + return new PrecisionMetric(Arrays.asList(at)); + } + + public PrecisionMetric(List<Double> at) { + super(at); + } + + @Override + public String getName() { + return NAME; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + PrecisionMetric that = (PrecisionMetric) o; + return Arrays.equals(thresholds, that.thresholds); + } + + @Override + public int hashCode() { + return Arrays.hashCode(thresholds); + } + + /** + * Result of the metric: a map from string key to a double value, kept in parse order. + */ + public static class Result implements EvaluationMetric.Result { + + public static Result fromXContent(XContentParser parser) throws IOException { + return new Result(parser.map(LinkedHashMap::new, p -> p.doubleValue())); + } + + private final Map<String, Double> results; + + /** + * @param results the values by key; must not be {@code null} + */ + public Result(Map<String, Double> results) { + this.results = Objects.requireNonNull(results); + } + + @Override + public String getMetricName() { + return NAME; + } + + /** + * @return the value stored under the given key, or {@code null} if there is none + */ + public Double getScoreByThreshold(String threshold) { + return results.get(threshold); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params params) throws IOException { + return builder.map(results); + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + Result that = (Result) o; + return Objects.equals(results, that.results); + } + + @Override + public int hashCode() { + return Objects.hash(results); + } + + @Override + public String toString() { + return Strings.toString(this); + } + } +} diff --git a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/evaluation/softclassification/RecallMetric.java b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/evaluation/softclassification/RecallMetric.java new file mode 100644 index 0000000000000..505ff1b34d7c5 --- /dev/null +++
b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/evaluation/softclassification/RecallMetric.java @@ -0,0 +1,123 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.elasticsearch.client.ml.dataframe.evaluation.softclassification; + +import org.elasticsearch.client.ml.dataframe.evaluation.EvaluationMetric; +import org.elasticsearch.common.Strings; +import org.elasticsearch.common.xcontent.ConstructingObjectParser; +import org.elasticsearch.common.xcontent.ToXContent; +import org.elasticsearch.common.xcontent.XContentBuilder; +import org.elasticsearch.common.xcontent.XContentParser; + +import java.io.IOException; +import java.util.Arrays; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; +import java.util.Objects; + +import static org.elasticsearch.common.xcontent.ConstructingObjectParser.constructorArg; + +/** + * Metric named {@code recall}. The values given at construction are handed to + * {@link AbstractConfusionMatrixMetric}; equality is based on its {@code thresholds} array. + */ +public class RecallMetric extends AbstractConfusionMatrixMetric { + + public static final String NAME = "recall"; + + @SuppressWarnings("unchecked") + private static final ConstructingObjectParser<RecallMetric, Void> PARSER = + new ConstructingObjectParser<>(NAME, args -> new RecallMetric((List<Double>) args[0])); + + static { + PARSER.declareDoubleArray(constructorArg(), AT); + } +
+ public static RecallMetric fromXContent(XContentParser parser) { + return PARSER.apply(parser, null); + } + + /** + * Convenience factory equivalent to {@code new RecallMetric(Arrays.asList(at))}. + */ + public static RecallMetric at(Double... at) { + return new RecallMetric(Arrays.asList(at)); + } + + public RecallMetric(List<Double> at) { + super(at); + } + + @Override + public String getName() { + return NAME; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + RecallMetric that = (RecallMetric) o; + return Arrays.equals(thresholds, that.thresholds); + } + + @Override + public int hashCode() { + return Arrays.hashCode(thresholds); + } + + /** + * Result of the metric: a map from string key to a double value, kept in parse order. + */ + public static class Result implements EvaluationMetric.Result { + + public static Result fromXContent(XContentParser parser) throws IOException { + return new Result(parser.map(LinkedHashMap::new, p -> p.doubleValue())); + } + + private final Map<String, Double> results; + + /** + * @param results the values by key; must not be {@code null} + */ + public Result(Map<String, Double> results) { + this.results = Objects.requireNonNull(results); + } + + @Override + public String getMetricName() { + return NAME; + } + + /** + * @return the value stored under the given key, or {@code null} if there is none + */ + public Double getScoreByThreshold(String threshold) { + return results.get(threshold); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params params) throws IOException { + return builder.map(results); + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + Result that = (Result) o; + return Objects.equals(results, that.results); + } + + @Override + public int hashCode() { + return Objects.hash(results); + } + + @Override + public String toString() { + return Strings.toString(this); + } + } +} diff --git a/client/rest-high-level/src/main/resources/META-INF/services/org.elasticsearch.plugins.spi.NamedXContentProvider b/client/rest-high-level/src/main/resources/META-INF/services/org.elasticsearch.plugins.spi.NamedXContentProvider index 4204a868246a5..dde81e43867d8 100644 ---
a/client/rest-high-level/src/main/resources/META-INF/services/org.elasticsearch.plugins.spi.NamedXContentProvider +++ b/client/rest-high-level/src/main/resources/META-INF/services/org.elasticsearch.plugins.spi.NamedXContentProvider @@ -1 +1,4 @@ -org.elasticsearch.client.indexlifecycle.IndexLifecycleNamedXContentProvider \ No newline at end of file +org.elasticsearch.client.dataframe.DataFrameNamedXContentProvider +org.elasticsearch.client.indexlifecycle.IndexLifecycleNamedXContentProvider +org.elasticsearch.client.ml.dataframe.MlDataFrameAnalysisNamedXContentProvider +org.elasticsearch.client.ml.dataframe.evaluation.MlEvaluationNamedXContentProvider \ No newline at end of file diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/DataFrameRequestConvertersTests.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/DataFrameRequestConvertersTests.java index 7a1e5e2389316..153d9a98d9da1 100644 --- a/client/rest-high-level/src/test/java/org/elasticsearch/client/DataFrameRequestConvertersTests.java +++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/DataFrameRequestConvertersTests.java @@ -24,6 +24,7 @@ import org.apache.http.client.methods.HttpPost; import org.apache.http.client.methods.HttpPut; import org.elasticsearch.client.core.PageParams; +import org.elasticsearch.client.dataframe.DataFrameNamedXContentProvider; import org.elasticsearch.client.dataframe.DeleteDataFrameTransformRequest; import org.elasticsearch.client.dataframe.GetDataFrameTransformRequest; import org.elasticsearch.client.dataframe.GetDataFrameTransformStatsRequest; @@ -43,6 +44,7 @@ import java.io.IOException; import java.util.Collections; +import java.util.List; import static org.hamcrest.Matchers.allOf; import static org.hamcrest.Matchers.equalTo; @@ -53,7 +55,9 @@ public class DataFrameRequestConvertersTests extends ESTestCase { @Override protected NamedXContentRegistry xContentRegistry() { SearchModule searchModule = new 
SearchModule(Settings.EMPTY, false, Collections.emptyList()); - return new NamedXContentRegistry(searchModule.getNamedXContents()); + List namedXContents = searchModule.getNamedXContents(); + namedXContents.addAll(new DataFrameNamedXContentProvider().getNamedXContentParsers()); + return new NamedXContentRegistry(namedXContents); } public void testPutDataFrameTransform() throws IOException { diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/MLRequestConvertersTests.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/MLRequestConvertersTests.java index fd867a12204d0..36d71df5f91bb 100644 --- a/client/rest-high-level/src/test/java/org/elasticsearch/client/MLRequestConvertersTests.java +++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/MLRequestConvertersTests.java @@ -28,12 +28,14 @@ import org.elasticsearch.client.ml.DeleteCalendarEventRequest; import org.elasticsearch.client.ml.DeleteCalendarJobRequest; import org.elasticsearch.client.ml.DeleteCalendarRequest; +import org.elasticsearch.client.ml.DeleteDataFrameAnalyticsRequest; import org.elasticsearch.client.ml.DeleteDatafeedRequest; import org.elasticsearch.client.ml.DeleteExpiredDataRequest; import org.elasticsearch.client.ml.DeleteFilterRequest; import org.elasticsearch.client.ml.DeleteForecastRequest; import org.elasticsearch.client.ml.DeleteJobRequest; import org.elasticsearch.client.ml.DeleteModelSnapshotRequest; +import org.elasticsearch.client.ml.EvaluateDataFrameRequest; import org.elasticsearch.client.ml.FindFileStructureRequest; import org.elasticsearch.client.ml.FindFileStructureRequestTests; import org.elasticsearch.client.ml.FlushJobRequest; @@ -42,6 +44,8 @@ import org.elasticsearch.client.ml.GetCalendarEventsRequest; import org.elasticsearch.client.ml.GetCalendarsRequest; import org.elasticsearch.client.ml.GetCategoriesRequest; +import org.elasticsearch.client.ml.GetDataFrameAnalyticsRequest; +import 
org.elasticsearch.client.ml.GetDataFrameAnalyticsStatsRequest; import org.elasticsearch.client.ml.GetDatafeedRequest; import org.elasticsearch.client.ml.GetDatafeedStatsRequest; import org.elasticsearch.client.ml.GetFiltersRequest; @@ -58,13 +62,16 @@ import org.elasticsearch.client.ml.PreviewDatafeedRequest; import org.elasticsearch.client.ml.PutCalendarJobRequest; import org.elasticsearch.client.ml.PutCalendarRequest; +import org.elasticsearch.client.ml.PutDataFrameAnalyticsRequest; import org.elasticsearch.client.ml.PutDatafeedRequest; import org.elasticsearch.client.ml.PutFilterRequest; import org.elasticsearch.client.ml.PutJobRequest; import org.elasticsearch.client.ml.RevertModelSnapshotRequest; import org.elasticsearch.client.ml.SetUpgradeModeRequest; +import org.elasticsearch.client.ml.StartDataFrameAnalyticsRequest; import org.elasticsearch.client.ml.StartDatafeedRequest; import org.elasticsearch.client.ml.StartDatafeedRequestTests; +import org.elasticsearch.client.ml.StopDataFrameAnalyticsRequest; import org.elasticsearch.client.ml.StopDatafeedRequest; import org.elasticsearch.client.ml.UpdateFilterRequest; import org.elasticsearch.client.ml.UpdateJobRequest; @@ -75,6 +82,12 @@ import org.elasticsearch.client.ml.calendars.ScheduledEventTests; import org.elasticsearch.client.ml.datafeed.DatafeedConfig; import org.elasticsearch.client.ml.datafeed.DatafeedConfigTests; +import org.elasticsearch.client.ml.dataframe.DataFrameAnalyticsConfig; +import org.elasticsearch.client.ml.dataframe.MlDataFrameAnalysisNamedXContentProvider; +import org.elasticsearch.client.ml.dataframe.evaluation.MlEvaluationNamedXContentProvider; +import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.BinarySoftClassification; +import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.PrecisionMetric; +import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.RecallMetric; import 
org.elasticsearch.client.ml.filestructurefinder.FileStructure; import org.elasticsearch.client.ml.job.config.AnalysisConfig; import org.elasticsearch.client.ml.job.config.Detector; @@ -84,23 +97,30 @@ import org.elasticsearch.client.ml.job.config.MlFilter; import org.elasticsearch.client.ml.job.config.MlFilterTests; import org.elasticsearch.common.Strings; +import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.unit.TimeValue; +import org.elasticsearch.common.xcontent.NamedXContentRegistry; import org.elasticsearch.common.xcontent.XContentBuilder; import org.elasticsearch.common.xcontent.XContentParser; import org.elasticsearch.common.xcontent.XContentType; import org.elasticsearch.common.xcontent.json.JsonXContent; +import org.elasticsearch.search.SearchModule; import org.elasticsearch.test.ESTestCase; import java.io.ByteArrayOutputStream; import java.io.IOException; import java.nio.charset.StandardCharsets; +import java.util.ArrayList; import java.util.Arrays; import java.util.Collections; import java.util.HashMap; import java.util.List; import java.util.Map; +import static org.elasticsearch.client.ml.dataframe.DataFrameAnalyticsConfigTests.randomDataFrameAnalyticsConfig; +import static org.hamcrest.Matchers.allOf; import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.hasEntry; import static org.hamcrest.Matchers.is; import static org.hamcrest.core.IsNull.nullValue; @@ -154,7 +174,6 @@ public void testGetJobStats() { assertEquals(Boolean.toString(true), request.getParameters().get("allow_no_jobs")); } - public void testOpenJob() throws Exception { String jobId = "some-job-id"; OpenJobRequest openJobRequest = new OpenJobRequest(jobId); @@ -669,6 +688,109 @@ public void testDeleteCalendarEvent() { assertEquals("/_ml/calendars/" + calendarId + "/events/" + eventId, request.getEndpoint()); } + public void testPutDataFrameAnalytics() throws IOException { + PutDataFrameAnalyticsRequest putRequest = new 
PutDataFrameAnalyticsRequest(randomDataFrameAnalyticsConfig()); + Request request = MLRequestConverters.putDataFrameAnalytics(putRequest); + assertEquals(HttpPut.METHOD_NAME, request.getMethod()); + assertEquals("/_ml/data_frame/analytics/" + putRequest.getConfig().getId(), request.getEndpoint()); + try (XContentParser parser = createParser(JsonXContent.jsonXContent, request.getEntity().getContent())) { + DataFrameAnalyticsConfig parsedConfig = DataFrameAnalyticsConfig.fromXContent(parser); + assertThat(parsedConfig, equalTo(putRequest.getConfig())); + } + } + + public void testGetDataFrameAnalytics() { + String configId1 = randomAlphaOfLength(10); + String configId2 = randomAlphaOfLength(10); + String configId3 = randomAlphaOfLength(10); + GetDataFrameAnalyticsRequest getRequest = new GetDataFrameAnalyticsRequest(configId1, configId2, configId3) + .setAllowNoMatch(false) + .setPageParams(new PageParams(100, 300)); + + Request request = MLRequestConverters.getDataFrameAnalytics(getRequest); + assertEquals(HttpGet.METHOD_NAME, request.getMethod()); + assertEquals("/_ml/data_frame/analytics/" + configId1 + "," + configId2 + "," + configId3, request.getEndpoint()); + assertThat(request.getParameters(), allOf(hasEntry("from", "100"), hasEntry("size", "300"), hasEntry("allow_no_match", "false"))); + assertNull(request.getEntity()); + } + + public void testGetDataFrameAnalyticsStats() { + String configId1 = randomAlphaOfLength(10); + String configId2 = randomAlphaOfLength(10); + String configId3 = randomAlphaOfLength(10); + GetDataFrameAnalyticsStatsRequest getStatsRequest = new GetDataFrameAnalyticsStatsRequest(configId1, configId2, configId3) + .setAllowNoMatch(false) + .setPageParams(new PageParams(100, 300)); + + Request request = MLRequestConverters.getDataFrameAnalyticsStats(getStatsRequest); + assertEquals(HttpGet.METHOD_NAME, request.getMethod()); + assertEquals("/_ml/data_frame/analytics/" + configId1 + "," + configId2 + "," + configId3 + "/_stats", 
request.getEndpoint()); + assertThat(request.getParameters(), allOf(hasEntry("from", "100"), hasEntry("size", "300"), hasEntry("allow_no_match", "false"))); + assertNull(request.getEntity()); + } + + public void testStartDataFrameAnalytics() { + StartDataFrameAnalyticsRequest startRequest = new StartDataFrameAnalyticsRequest(randomAlphaOfLength(10)); + Request request = MLRequestConverters.startDataFrameAnalytics(startRequest); + assertEquals(HttpPost.METHOD_NAME, request.getMethod()); + assertEquals("/_ml/data_frame/analytics/" + startRequest.getId() + "/_start", request.getEndpoint()); + assertNull(request.getEntity()); + } + + public void testStartDataFrameAnalytics_WithTimeout() { + StartDataFrameAnalyticsRequest startRequest = new StartDataFrameAnalyticsRequest(randomAlphaOfLength(10)) + .setTimeout(TimeValue.timeValueMinutes(1)); + Request request = MLRequestConverters.startDataFrameAnalytics(startRequest); + assertEquals(HttpPost.METHOD_NAME, request.getMethod()); + assertEquals("/_ml/data_frame/analytics/" + startRequest.getId() + "/_start", request.getEndpoint()); + assertThat(request.getParameters(), hasEntry("timeout", "1m")); + assertNull(request.getEntity()); + } + + public void testStopDataFrameAnalytics() { + StopDataFrameAnalyticsRequest stopRequest = new StopDataFrameAnalyticsRequest(randomAlphaOfLength(10)); + Request request = MLRequestConverters.stopDataFrameAnalytics(stopRequest); + assertEquals(HttpPost.METHOD_NAME, request.getMethod()); + assertEquals("/_ml/data_frame/analytics/" + stopRequest.getId() + "/_stop", request.getEndpoint()); + assertNull(request.getEntity()); + } + + public void testStopDataFrameAnalytics_WithParams() { + StopDataFrameAnalyticsRequest stopRequest = new StopDataFrameAnalyticsRequest(randomAlphaOfLength(10)) + .setTimeout(TimeValue.timeValueMinutes(1)) + .setAllowNoMatch(false); + Request request = MLRequestConverters.stopDataFrameAnalytics(stopRequest); + assertEquals(HttpPost.METHOD_NAME, request.getMethod()); + 
assertEquals("/_ml/data_frame/analytics/" + stopRequest.getId() + "/_stop", request.getEndpoint()); + assertThat(request.getParameters(), allOf(hasEntry("timeout", "1m"), hasEntry("allow_no_match", "false"))); + assertNull(request.getEntity()); + } + + public void testDeleteDataFrameAnalytics() { + DeleteDataFrameAnalyticsRequest deleteRequest = new DeleteDataFrameAnalyticsRequest(randomAlphaOfLength(10)); + Request request = MLRequestConverters.deleteDataFrameAnalytics(deleteRequest); + assertEquals(HttpDelete.METHOD_NAME, request.getMethod()); + assertEquals("/_ml/data_frame/analytics/" + deleteRequest.getId(), request.getEndpoint()); + assertNull(request.getEntity()); + } + + public void testEvaluateDataFrame() throws IOException { + EvaluateDataFrameRequest evaluateRequest = + new EvaluateDataFrameRequest( + Arrays.asList(generateRandomStringArray(1, 10, false, false)), + new BinarySoftClassification( + randomAlphaOfLengthBetween(1, 10), + randomAlphaOfLengthBetween(1, 10), + PrecisionMetric.at(0.5), RecallMetric.at(0.6, 0.7))); + Request request = MLRequestConverters.evaluateDataFrame(evaluateRequest); + assertEquals(HttpPost.METHOD_NAME, request.getMethod()); + assertEquals("/_ml/data_frame/_evaluate", request.getEndpoint()); + try (XContentParser parser = createParser(JsonXContent.jsonXContent, request.getEntity().getContent())) { + EvaluateDataFrameRequest parsedRequest = EvaluateDataFrameRequest.fromXContent(parser); + assertThat(parsedRequest, equalTo(evaluateRequest)); + } + } + public void testPutFilter() throws IOException { MlFilter filter = MlFilterTests.createRandomBuilder("foo").build(); PutFilterRequest putFilterRequest = new PutFilterRequest(filter); @@ -835,6 +957,15 @@ public void testSetUpgradeMode() { assertThat(request.getParameters().get(SetUpgradeModeRequest.TIMEOUT.getPreferredName()), is("1h")); } + @Override + protected NamedXContentRegistry xContentRegistry() { + List<NamedXContentRegistry.Entry> namedXContent = new ArrayList<>(); + namedXContent.addAll(new 
SearchModule(Settings.EMPTY, false, Collections.emptyList()).getNamedXContents()); + namedXContent.addAll(new MlDataFrameAnalysisNamedXContentProvider().getNamedXContentParsers()); + namedXContent.addAll(new MlEvaluationNamedXContentProvider().getNamedXContentParsers()); + return new NamedXContentRegistry(namedXContent); + } + private static Job createValidJob(String jobId) { AnalysisConfig.Builder analysisConfig = AnalysisConfig.builder(Collections.singletonList( Detector.builder().setFunction("count").build())); diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/MachineLearningIT.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/MachineLearningIT.java index 8ef28733f2e12..77efe43b2e174 100644 --- a/client/rest-high-level/src/test/java/org/elasticsearch/client/MachineLearningIT.java +++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/MachineLearningIT.java @@ -29,11 +29,13 @@ import org.elasticsearch.action.update.UpdateRequest; import org.elasticsearch.client.core.PageParams; import org.elasticsearch.client.indices.CreateIndexRequest; +import org.elasticsearch.client.indices.GetIndexRequest; import org.elasticsearch.client.ml.CloseJobRequest; import org.elasticsearch.client.ml.CloseJobResponse; import org.elasticsearch.client.ml.DeleteCalendarEventRequest; import org.elasticsearch.client.ml.DeleteCalendarJobRequest; import org.elasticsearch.client.ml.DeleteCalendarRequest; +import org.elasticsearch.client.ml.DeleteDataFrameAnalyticsRequest; import org.elasticsearch.client.ml.DeleteDatafeedRequest; import org.elasticsearch.client.ml.DeleteExpiredDataRequest; import org.elasticsearch.client.ml.DeleteExpiredDataResponse; @@ -42,6 +44,8 @@ import org.elasticsearch.client.ml.DeleteJobRequest; import org.elasticsearch.client.ml.DeleteJobResponse; import org.elasticsearch.client.ml.DeleteModelSnapshotRequest; +import org.elasticsearch.client.ml.EvaluateDataFrameRequest; +import 
org.elasticsearch.client.ml.EvaluateDataFrameResponse; import org.elasticsearch.client.ml.FindFileStructureRequest; import org.elasticsearch.client.ml.FindFileStructureResponse; import org.elasticsearch.client.ml.FlushJobRequest; @@ -52,6 +56,10 @@ import org.elasticsearch.client.ml.GetCalendarEventsResponse; import org.elasticsearch.client.ml.GetCalendarsRequest; import org.elasticsearch.client.ml.GetCalendarsResponse; +import org.elasticsearch.client.ml.GetDataFrameAnalyticsRequest; +import org.elasticsearch.client.ml.GetDataFrameAnalyticsResponse; +import org.elasticsearch.client.ml.GetDataFrameAnalyticsStatsRequest; +import org.elasticsearch.client.ml.GetDataFrameAnalyticsStatsResponse; import org.elasticsearch.client.ml.GetDatafeedRequest; import org.elasticsearch.client.ml.GetDatafeedResponse; import org.elasticsearch.client.ml.GetDatafeedStatsRequest; @@ -77,6 +85,8 @@ import org.elasticsearch.client.ml.PutCalendarJobRequest; import org.elasticsearch.client.ml.PutCalendarRequest; import org.elasticsearch.client.ml.PutCalendarResponse; +import org.elasticsearch.client.ml.PutDataFrameAnalyticsRequest; +import org.elasticsearch.client.ml.PutDataFrameAnalyticsResponse; import org.elasticsearch.client.ml.PutDatafeedRequest; import org.elasticsearch.client.ml.PutDatafeedResponse; import org.elasticsearch.client.ml.PutFilterRequest; @@ -86,8 +96,11 @@ import org.elasticsearch.client.ml.RevertModelSnapshotRequest; import org.elasticsearch.client.ml.RevertModelSnapshotResponse; import org.elasticsearch.client.ml.SetUpgradeModeRequest; +import org.elasticsearch.client.ml.StartDataFrameAnalyticsRequest; import org.elasticsearch.client.ml.StartDatafeedRequest; import org.elasticsearch.client.ml.StartDatafeedResponse; +import org.elasticsearch.client.ml.StopDataFrameAnalyticsRequest; +import org.elasticsearch.client.ml.StopDataFrameAnalyticsResponse; import org.elasticsearch.client.ml.StopDatafeedRequest; import org.elasticsearch.client.ml.StopDatafeedResponse; import 
org.elasticsearch.client.ml.UpdateDatafeedRequest; @@ -103,6 +116,18 @@ import org.elasticsearch.client.ml.datafeed.DatafeedState; import org.elasticsearch.client.ml.datafeed.DatafeedStats; import org.elasticsearch.client.ml.datafeed.DatafeedUpdate; +import org.elasticsearch.client.ml.dataframe.DataFrameAnalyticsConfig; +import org.elasticsearch.client.ml.dataframe.DataFrameAnalyticsDest; +import org.elasticsearch.client.ml.dataframe.DataFrameAnalyticsSource; +import org.elasticsearch.client.ml.dataframe.DataFrameAnalyticsState; +import org.elasticsearch.client.ml.dataframe.DataFrameAnalyticsStats; +import org.elasticsearch.client.ml.dataframe.OutlierDetection; +import org.elasticsearch.client.ml.dataframe.QueryConfig; +import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.AucRocMetric; +import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.BinarySoftClassification; +import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.ConfusionMatrixMetric; +import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.PrecisionMetric; +import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.RecallMetric; import org.elasticsearch.client.ml.filestructurefinder.FileStructure; import org.elasticsearch.client.ml.job.config.AnalysisConfig; import org.elasticsearch.client.ml.job.config.DataDescription; @@ -113,9 +138,12 @@ import org.elasticsearch.client.ml.job.config.MlFilter; import org.elasticsearch.client.ml.job.process.ModelSnapshot; import org.elasticsearch.client.ml.job.stats.JobStats; +import org.elasticsearch.common.unit.ByteSizeValue; import org.elasticsearch.common.unit.TimeValue; +import org.elasticsearch.common.xcontent.XContentBuilder; import org.elasticsearch.common.xcontent.XContentFactory; import org.elasticsearch.common.xcontent.XContentType; +import org.elasticsearch.index.query.MatchAllQueryBuilder; import org.elasticsearch.rest.RestStatus; import 
org.elasticsearch.search.SearchHit; import org.junit.After; @@ -136,6 +164,7 @@ import static org.hamcrest.CoreMatchers.hasItem; import static org.hamcrest.CoreMatchers.hasItems; import static org.hamcrest.CoreMatchers.not; +import static org.hamcrest.Matchers.closeTo; import static org.hamcrest.Matchers.contains; import static org.hamcrest.Matchers.containsInAnyOrder; import static org.hamcrest.Matchers.greaterThanOrEqualTo; @@ -528,18 +557,7 @@ public void testStartDatafeed() throws Exception { String indexName = "start_data_1"; // Set up the index and docs - CreateIndexRequest createIndexRequest = new CreateIndexRequest(indexName); - createIndexRequest.mapping(XContentFactory.jsonBuilder().startObject() - .startObject("properties") - .startObject("timestamp") - .field("type", "date") - .endObject() - .startObject("total") - .field("type", "long") - .endObject() - .endObject() - .endObject()); - highLevelClient().indices().create(createIndexRequest, RequestOptions.DEFAULT); + createIndex(indexName, defaultMappingForTest()); BulkRequest bulk = new BulkRequest(); bulk.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE); long now = (System.currentTimeMillis()/1000)*1000; @@ -611,18 +629,7 @@ public void testStopDatafeed() throws Exception { String indexName = "stop_data_1"; // Set up the index - CreateIndexRequest createIndexRequest = new CreateIndexRequest(indexName); - createIndexRequest.mapping(XContentFactory.jsonBuilder().startObject() - .startObject("properties") - .startObject("timestamp") - .field("type", "date") - .endObject() - .startObject("total") - .field("type", "long") - .endObject() - .endObject() - .endObject()); - highLevelClient().indices().create(createIndexRequest, RequestOptions.DEFAULT); + createIndex(indexName, defaultMappingForTest()); // create the job and the datafeed Job job1 = buildJob(jobId1); @@ -684,18 +691,7 @@ public void testGetDatafeedStats() throws Exception { String indexName = "datafeed_stats_data_1"; // Set up the index - 
CreateIndexRequest createIndexRequest = new CreateIndexRequest(indexName); - createIndexRequest.mapping(XContentFactory.jsonBuilder().startObject() - .startObject("properties") - .startObject("timestamp") - .field("type", "date") - .endObject() - .startObject("total") - .field("type", "long") - .endObject() - .endObject() - .endObject()); - highLevelClient().indices().create(createIndexRequest, RequestOptions.DEFAULT); + createIndex(indexName, defaultMappingForTest()); // create the job and the datafeed Job job1 = buildJob(jobId1); @@ -762,18 +758,7 @@ public void testPreviewDatafeed() throws Exception { String indexName = "preview_data_1"; // Set up the index and docs - CreateIndexRequest createIndexRequest = new CreateIndexRequest(indexName); - createIndexRequest.mapping(XContentFactory.jsonBuilder().startObject() - .startObject("properties") - .startObject("timestamp") - .field("type", "date") - .endObject() - .startObject("total") - .field("type", "long") - .endObject() - .endObject() - .endObject()); - highLevelClient().indices().create(createIndexRequest, RequestOptions.DEFAULT); + createIndex(indexName, defaultMappingForTest()); BulkRequest bulk = new BulkRequest(); bulk.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE); long now = (System.currentTimeMillis()/1000)*1000; @@ -826,21 +811,9 @@ public void testDeleteExpiredDataGivenNothingToDelete() throws Exception { } private String createExpiredData(String jobId) throws Exception { - String indexId = jobId + "-data"; + String indexName = jobId + "-data"; // Set up the index and docs - CreateIndexRequest createIndexRequest = new CreateIndexRequest(indexId); - createIndexRequest.mapping(XContentFactory.jsonBuilder().startObject() - .startObject("properties") - .startObject("timestamp") - .field("type", "date") - .field("format", "epoch_millis") - .endObject() - .startObject("total") - .field("type", "long") - .endObject() - .endObject() - .endObject()); - 
highLevelClient().indices().create(createIndexRequest, RequestOptions.DEFAULT); + createIndex(indexName, defaultMappingForTest()); BulkRequest bulk = new BulkRequest(); bulk.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE); @@ -853,7 +826,7 @@ private String createExpiredData(String jobId) throws Exception { long timestamp = nowMillis - TimeValue.timeValueHours(totalBuckets - bucket).getMillis(); int bucketRate = bucket == anomalousBucket ? anomalousRate : normalRate; for (int point = 0; point < bucketRate; point++) { - IndexRequest indexRequest = new IndexRequest(indexId); + IndexRequest indexRequest = new IndexRequest(indexName); indexRequest.source(XContentType.JSON, "timestamp", timestamp, "total", randomInt(1000)); bulk.add(indexRequest); } @@ -872,7 +845,7 @@ private String createExpiredData(String jobId) throws Exception { Job job = buildJobForExpiredDataTests(jobId); putJob(job); openJob(job); - String datafeedId = createAndPutDatafeed(jobId, indexId); + String datafeedId = createAndPutDatafeed(jobId, indexName); startDatafeed(datafeedId, String.valueOf(0), String.valueOf(nowMillis - TimeValue.timeValueHours(24).getMillis())); @@ -1230,6 +1203,418 @@ public void testDeleteCalendarEvent() throws IOException { assertThat(remainingIds, not(hasItem(deletedEvent))); } + public void testPutDataFrameAnalyticsConfig() throws Exception { + MachineLearningClient machineLearningClient = highLevelClient().machineLearning(); + String configId = "put-test-config"; + DataFrameAnalyticsConfig config = DataFrameAnalyticsConfig.builder(configId) + .setSource(DataFrameAnalyticsSource.builder() + .setIndex("put-test-source-index") + .build()) + .setDest(DataFrameAnalyticsDest.builder() + .setIndex("put-test-dest-index") + .build()) + .setAnalysis(OutlierDetection.createDefault()) + .build(); + + createIndex("put-test-source-index", defaultMappingForTest()); + + PutDataFrameAnalyticsResponse putDataFrameAnalyticsResponse = execute( + new 
PutDataFrameAnalyticsRequest(config), + machineLearningClient::putDataFrameAnalytics, machineLearningClient::putDataFrameAnalyticsAsync); + DataFrameAnalyticsConfig createdConfig = putDataFrameAnalyticsResponse.getConfig(); + assertThat(createdConfig.getId(), equalTo(config.getId())); + assertThat(createdConfig.getSource().getIndex(), equalTo(config.getSource().getIndex())); + assertThat(createdConfig.getSource().getQueryConfig(), equalTo(new QueryConfig(new MatchAllQueryBuilder()))); // default value + assertThat(createdConfig.getDest().getIndex(), equalTo(config.getDest().getIndex())); + assertThat(createdConfig.getDest().getResultsField(), equalTo("ml")); // default value + assertThat(createdConfig.getAnalysis(), equalTo(config.getAnalysis())); + assertThat(createdConfig.getAnalyzedFields(), equalTo(config.getAnalyzedFields())); + assertThat(createdConfig.getModelMemoryLimit(), equalTo(ByteSizeValue.parseBytesSizeValue("1gb", ""))); // default value + } + + public void testGetDataFrameAnalyticsConfig_SingleConfig() throws Exception { + MachineLearningClient machineLearningClient = highLevelClient().machineLearning(); + String configId = "get-test-config"; + DataFrameAnalyticsConfig config = DataFrameAnalyticsConfig.builder(configId) + .setSource(DataFrameAnalyticsSource.builder() + .setIndex("get-test-source-index") + .build()) + .setDest(DataFrameAnalyticsDest.builder() + .setIndex("get-test-dest-index") + .build()) + .setAnalysis(OutlierDetection.createDefault()) + .build(); + + createIndex("get-test-source-index", defaultMappingForTest()); + + PutDataFrameAnalyticsResponse putDataFrameAnalyticsResponse = execute( + new PutDataFrameAnalyticsRequest(config), + machineLearningClient::putDataFrameAnalytics, machineLearningClient::putDataFrameAnalyticsAsync); + DataFrameAnalyticsConfig createdConfig = putDataFrameAnalyticsResponse.getConfig(); + + GetDataFrameAnalyticsResponse getDataFrameAnalyticsResponse = execute( + new GetDataFrameAnalyticsRequest(configId), + 
machineLearningClient::getDataFrameAnalytics, machineLearningClient::getDataFrameAnalyticsAsync); + assertThat(getDataFrameAnalyticsResponse.getAnalytics(), hasSize(1)); + assertThat(getDataFrameAnalyticsResponse.getAnalytics(), contains(createdConfig)); + } + + public void testGetDataFrameAnalyticsConfig_MultipleConfigs() throws Exception { + createIndex("get-test-source-index", defaultMappingForTest()); + + MachineLearningClient machineLearningClient = highLevelClient().machineLearning(); + String configIdPrefix = "get-test-config-"; + int numberOfConfigs = 10; + List<DataFrameAnalyticsConfig> createdConfigs = new ArrayList<>(); + for (int i = 0; i < numberOfConfigs; ++i) { + String configId = configIdPrefix + i; + DataFrameAnalyticsConfig config = DataFrameAnalyticsConfig.builder(configId) + .setSource(DataFrameAnalyticsSource.builder() + .setIndex("get-test-source-index") + .build()) + .setDest(DataFrameAnalyticsDest.builder() + .setIndex("get-test-dest-index") + .build()) + .setAnalysis(OutlierDetection.createDefault()) + .build(); + + PutDataFrameAnalyticsResponse putDataFrameAnalyticsResponse = execute( + new PutDataFrameAnalyticsRequest(config), + machineLearningClient::putDataFrameAnalytics, machineLearningClient::putDataFrameAnalyticsAsync); + DataFrameAnalyticsConfig createdConfig = putDataFrameAnalyticsResponse.getConfig(); + createdConfigs.add(createdConfig); + } + + { + GetDataFrameAnalyticsResponse getDataFrameAnalyticsResponse = execute( + GetDataFrameAnalyticsRequest.getAllDataFrameAnalyticsRequest(), + machineLearningClient::getDataFrameAnalytics, machineLearningClient::getDataFrameAnalyticsAsync); + assertThat(getDataFrameAnalyticsResponse.getAnalytics(), hasSize(numberOfConfigs)); + assertThat(getDataFrameAnalyticsResponse.getAnalytics(), containsInAnyOrder(createdConfigs.toArray())); + } + { + GetDataFrameAnalyticsResponse getDataFrameAnalyticsResponse = execute( + new GetDataFrameAnalyticsRequest(configIdPrefix + "*"), + machineLearningClient::getDataFrameAnalytics, 
machineLearningClient::getDataFrameAnalyticsAsync); + assertThat(getDataFrameAnalyticsResponse.getAnalytics(), hasSize(numberOfConfigs)); + assertThat(getDataFrameAnalyticsResponse.getAnalytics(), containsInAnyOrder(createdConfigs.toArray())); + } + { + GetDataFrameAnalyticsResponse getDataFrameAnalyticsResponse = execute( + new GetDataFrameAnalyticsRequest(configIdPrefix + "9", configIdPrefix + "1", configIdPrefix + "4"), + machineLearningClient::getDataFrameAnalytics, machineLearningClient::getDataFrameAnalyticsAsync); + assertThat(getDataFrameAnalyticsResponse.getAnalytics(), hasSize(3)); + assertThat( + getDataFrameAnalyticsResponse.getAnalytics(), + containsInAnyOrder(createdConfigs.get(1), createdConfigs.get(4), createdConfigs.get(9))); + } + { + GetDataFrameAnalyticsRequest getDataFrameAnalyticsRequest = new GetDataFrameAnalyticsRequest(configIdPrefix + "*"); + getDataFrameAnalyticsRequest.setPageParams(new PageParams(3, 4)); + GetDataFrameAnalyticsResponse getDataFrameAnalyticsResponse = execute( + getDataFrameAnalyticsRequest, + machineLearningClient::getDataFrameAnalytics, machineLearningClient::getDataFrameAnalyticsAsync); + assertThat(getDataFrameAnalyticsResponse.getAnalytics(), hasSize(4)); + assertThat( + getDataFrameAnalyticsResponse.getAnalytics(), + containsInAnyOrder(createdConfigs.get(3), createdConfigs.get(4), createdConfigs.get(5), createdConfigs.get(6))); + } + } + + public void testGetDataFrameAnalyticsConfig_ConfigNotFound() { + MachineLearningClient machineLearningClient = highLevelClient().machineLearning(); + GetDataFrameAnalyticsRequest request = new GetDataFrameAnalyticsRequest("config_that_does_not_exist"); + ElasticsearchStatusException exception = expectThrows(ElasticsearchStatusException.class, + () -> execute(request, machineLearningClient::getDataFrameAnalytics, machineLearningClient::getDataFrameAnalyticsAsync)); + assertThat(exception.status().getStatus(), equalTo(404)); + } + + public void testGetDataFrameAnalyticsStats() 
throws Exception { + String sourceIndex = "get-stats-test-source-index"; + String destIndex = "get-stats-test-dest-index"; + createIndex(sourceIndex, defaultMappingForTest()); + highLevelClient().index(new IndexRequest(sourceIndex).source(XContentType.JSON, "total", 10000), RequestOptions.DEFAULT); + + MachineLearningClient machineLearningClient = highLevelClient().machineLearning(); + String configId = "get-stats-test-config"; + DataFrameAnalyticsConfig config = DataFrameAnalyticsConfig.builder(configId) + .setSource(DataFrameAnalyticsSource.builder() + .setIndex(sourceIndex) + .build()) + .setDest(DataFrameAnalyticsDest.builder() + .setIndex(destIndex) + .build()) + .setAnalysis(OutlierDetection.createDefault()) + .build(); + + execute( + new PutDataFrameAnalyticsRequest(config), + machineLearningClient::putDataFrameAnalytics, machineLearningClient::putDataFrameAnalyticsAsync); + + GetDataFrameAnalyticsStatsResponse statsResponse = execute( + new GetDataFrameAnalyticsStatsRequest(configId), + machineLearningClient::getDataFrameAnalyticsStats, machineLearningClient::getDataFrameAnalyticsStatsAsync); + + assertThat(statsResponse.getAnalyticsStats(), hasSize(1)); + DataFrameAnalyticsStats stats = statsResponse.getAnalyticsStats().get(0); + assertThat(stats.getId(), equalTo(configId)); + assertThat(stats.getState(), equalTo(DataFrameAnalyticsState.STOPPED)); + assertNull(stats.getProgressPercent()); + assertNull(stats.getNode()); + assertNull(stats.getAssignmentExplanation()); + assertThat(statsResponse.getNodeFailures(), hasSize(0)); + assertThat(statsResponse.getTaskFailures(), hasSize(0)); + } + + public void testStartDataFrameAnalyticsConfig() throws Exception { + String sourceIndex = "start-test-source-index"; + String destIndex = "start-test-dest-index"; + createIndex(sourceIndex, defaultMappingForTest()); + highLevelClient().index(new IndexRequest(sourceIndex).source(XContentType.JSON, "total", 10000) + .setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE), 
RequestOptions.DEFAULT); + + // Verify that the destination index does not exist. Otherwise, analytics' reindexing step would fail. + assertFalse(highLevelClient().indices().exists(new GetIndexRequest(destIndex), RequestOptions.DEFAULT)); + + MachineLearningClient machineLearningClient = highLevelClient().machineLearning(); + String configId = "start-test-config"; + DataFrameAnalyticsConfig config = DataFrameAnalyticsConfig.builder(configId) + .setSource(DataFrameAnalyticsSource.builder() + .setIndex(sourceIndex) + .build()) + .setDest(DataFrameAnalyticsDest.builder() + .setIndex(destIndex) + .build()) + .setAnalysis(OutlierDetection.createDefault()) + .build(); + + execute( + new PutDataFrameAnalyticsRequest(config), + machineLearningClient::putDataFrameAnalytics, machineLearningClient::putDataFrameAnalyticsAsync); + assertThat(getAnalyticsState(configId), equalTo(DataFrameAnalyticsState.STOPPED)); + + AcknowledgedResponse startDataFrameAnalyticsResponse = execute( + new StartDataFrameAnalyticsRequest(configId), + machineLearningClient::startDataFrameAnalytics, machineLearningClient::startDataFrameAnalyticsAsync); + assertTrue(startDataFrameAnalyticsResponse.isAcknowledged()); + + // Wait for the analytics to stop. + assertBusy(() -> assertThat(getAnalyticsState(configId), equalTo(DataFrameAnalyticsState.STOPPED)), 30, TimeUnit.SECONDS); + + // Verify that the destination index got created. + assertTrue(highLevelClient().indices().exists(new GetIndexRequest(destIndex), RequestOptions.DEFAULT)); + } + + public void testStopDataFrameAnalyticsConfig() throws Exception { + String sourceIndex = "stop-test-source-index"; + String destIndex = "stop-test-dest-index"; + createIndex(sourceIndex, mappingForClassification()); + highLevelClient().index(new IndexRequest(sourceIndex).source(XContentType.JSON, "total", 10000) + .setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE), RequestOptions.DEFAULT); + + // Verify that the destination index does not exist. 
Otherwise, analytics' reindexing step would fail. + assertFalse(highLevelClient().indices().exists(new GetIndexRequest(destIndex), RequestOptions.DEFAULT)); + + MachineLearningClient machineLearningClient = highLevelClient().machineLearning(); + String configId = "stop-test-config"; + DataFrameAnalyticsConfig config = DataFrameAnalyticsConfig.builder(configId) + .setSource(DataFrameAnalyticsSource.builder() + .setIndex(sourceIndex) + .build()) + .setDest(DataFrameAnalyticsDest.builder() + .setIndex(destIndex) + .build()) + .setAnalysis(OutlierDetection.createDefault()) + .build(); + + execute( + new PutDataFrameAnalyticsRequest(config), + machineLearningClient::putDataFrameAnalytics, machineLearningClient::putDataFrameAnalyticsAsync); + assertThat(getAnalyticsState(configId), equalTo(DataFrameAnalyticsState.STOPPED)); + + AcknowledgedResponse startDataFrameAnalyticsResponse = execute( + new StartDataFrameAnalyticsRequest(configId), + machineLearningClient::startDataFrameAnalytics, machineLearningClient::startDataFrameAnalyticsAsync); + assertTrue(startDataFrameAnalyticsResponse.isAcknowledged()); + assertThat(getAnalyticsState(configId), equalTo(DataFrameAnalyticsState.STARTED)); + + StopDataFrameAnalyticsResponse stopDataFrameAnalyticsResponse = execute( + new StopDataFrameAnalyticsRequest(configId), + machineLearningClient::stopDataFrameAnalytics, machineLearningClient::stopDataFrameAnalyticsAsync); + assertTrue(stopDataFrameAnalyticsResponse.isStopped()); + assertThat(getAnalyticsState(configId), equalTo(DataFrameAnalyticsState.STOPPED)); + } + + private DataFrameAnalyticsState getAnalyticsState(String configId) throws IOException { + MachineLearningClient machineLearningClient = highLevelClient().machineLearning(); + GetDataFrameAnalyticsStatsResponse statsResponse = + machineLearningClient.getDataFrameAnalyticsStats(new GetDataFrameAnalyticsStatsRequest(configId), RequestOptions.DEFAULT); + assertThat(statsResponse.getAnalyticsStats(), hasSize(1)); + 
DataFrameAnalyticsStats stats = statsResponse.getAnalyticsStats().get(0); + return stats.getState(); + } + + public void testDeleteDataFrameAnalyticsConfig() throws Exception { + MachineLearningClient machineLearningClient = highLevelClient().machineLearning(); + String configId = "delete-test-config"; + DataFrameAnalyticsConfig config = DataFrameAnalyticsConfig.builder(configId) + .setSource(DataFrameAnalyticsSource.builder() + .setIndex("delete-test-source-index") + .build()) + .setDest(DataFrameAnalyticsDest.builder() + .setIndex("delete-test-dest-index") + .build()) + .setAnalysis(OutlierDetection.createDefault()) + .build(); + + createIndex("delete-test-source-index", defaultMappingForTest()); + + GetDataFrameAnalyticsResponse getDataFrameAnalyticsResponse = execute( + new GetDataFrameAnalyticsRequest(configId + "*"), + machineLearningClient::getDataFrameAnalytics, machineLearningClient::getDataFrameAnalyticsAsync); + assertThat(getDataFrameAnalyticsResponse.getAnalytics(), hasSize(0)); + + execute( + new PutDataFrameAnalyticsRequest(config), + machineLearningClient::putDataFrameAnalytics, machineLearningClient::putDataFrameAnalyticsAsync); + + getDataFrameAnalyticsResponse = execute( + new GetDataFrameAnalyticsRequest(configId + "*"), + machineLearningClient::getDataFrameAnalytics, machineLearningClient::getDataFrameAnalyticsAsync); + assertThat(getDataFrameAnalyticsResponse.getAnalytics(), hasSize(1)); + + AcknowledgedResponse deleteDataFrameAnalyticsResponse = execute( + new DeleteDataFrameAnalyticsRequest(configId), + machineLearningClient::deleteDataFrameAnalytics, machineLearningClient::deleteDataFrameAnalyticsAsync); + assertTrue(deleteDataFrameAnalyticsResponse.isAcknowledged()); + + getDataFrameAnalyticsResponse = execute( + new GetDataFrameAnalyticsRequest(configId + "*"), + machineLearningClient::getDataFrameAnalytics, machineLearningClient::getDataFrameAnalyticsAsync); + assertThat(getDataFrameAnalyticsResponse.getAnalytics(), hasSize(0)); + } + + 
public void testDeleteDataFrameAnalyticsConfig_ConfigNotFound() { + MachineLearningClient machineLearningClient = highLevelClient().machineLearning(); + DeleteDataFrameAnalyticsRequest request = new DeleteDataFrameAnalyticsRequest("config_that_does_not_exist"); + ElasticsearchStatusException exception = expectThrows(ElasticsearchStatusException.class, + () -> execute( + request, machineLearningClient::deleteDataFrameAnalytics, machineLearningClient::deleteDataFrameAnalyticsAsync)); + assertThat(exception.status().getStatus(), equalTo(404)); + } + + public void testEvaluateDataFrame() throws IOException { + String indexName = "evaluate-test-index"; + createIndex(indexName, mappingForClassification()); + BulkRequest bulk = new BulkRequest() + .setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE) + .add(docForClassification(indexName, false, 0.1)) // #0 + .add(docForClassification(indexName, false, 0.2)) // #1 + .add(docForClassification(indexName, false, 0.3)) // #2 + .add(docForClassification(indexName, false, 0.4)) // #3 + .add(docForClassification(indexName, false, 0.7)) // #4 + .add(docForClassification(indexName, true, 0.2)) // #5 + .add(docForClassification(indexName, true, 0.3)) // #6 + .add(docForClassification(indexName, true, 0.4)) // #7 + .add(docForClassification(indexName, true, 0.8)) // #8 + .add(docForClassification(indexName, true, 0.9)); // #9 + highLevelClient().bulk(bulk, RequestOptions.DEFAULT); + + MachineLearningClient machineLearningClient = highLevelClient().machineLearning(); + EvaluateDataFrameRequest evaluateDataFrameRequest = + new EvaluateDataFrameRequest( + indexName, + new BinarySoftClassification( + actualField, + probabilityField, + PrecisionMetric.at(0.4, 0.5, 0.6), RecallMetric.at(0.5, 0.7), ConfusionMatrixMetric.at(0.5), AucRocMetric.withCurve())); + + EvaluateDataFrameResponse evaluateDataFrameResponse = + execute(evaluateDataFrameRequest, machineLearningClient::evaluateDataFrame, machineLearningClient::evaluateDataFrameAsync); 
+ assertThat(evaluateDataFrameResponse.getEvaluationName(), equalTo(BinarySoftClassification.NAME)); + assertThat(evaluateDataFrameResponse.getMetrics().size(), equalTo(4)); + + PrecisionMetric.Result precisionResult = evaluateDataFrameResponse.getMetricByName(PrecisionMetric.NAME); + assertThat(precisionResult.getMetricName(), equalTo(PrecisionMetric.NAME)); + // Precision is 3/5=0.6 as there were 3 true examples (#7, #8, #9) among the 5 positive examples (#3, #4, #7, #8, #9) + assertThat(precisionResult.getScoreByThreshold("0.4"), closeTo(0.6, 1e-9)); + // Precision is 2/3=0.(6) as there were 2 true examples (#8, #9) among the 3 positive examples (#4, #8, #9) + assertThat(precisionResult.getScoreByThreshold("0.5"), closeTo(0.666666666, 1e-9)); + // Precision is 2/3=0.(6) as there were 2 true examples (#8, #9) among the 3 positive examples (#4, #8, #9) + assertThat(precisionResult.getScoreByThreshold("0.6"), closeTo(0.666666666, 1e-9)); + assertNull(precisionResult.getScoreByThreshold("0.1")); + + RecallMetric.Result recallResult = evaluateDataFrameResponse.getMetricByName(RecallMetric.NAME); + assertThat(recallResult.getMetricName(), equalTo(RecallMetric.NAME)); + // Recall is 2/5=0.4 as there were 2 true positive examples (#8, #9) among the 5 true examples (#5, #6, #7, #8, #9) + assertThat(recallResult.getScoreByThreshold("0.5"), closeTo(0.4, 1e-9)); + // Recall is 2/5=0.4 as there were 2 true positive examples (#8, #9) among the 5 true examples (#5, #6, #7, #8, #9) + assertThat(recallResult.getScoreByThreshold("0.7"), closeTo(0.4, 1e-9)); + assertNull(recallResult.getScoreByThreshold("0.1")); + + ConfusionMatrixMetric.Result confusionMatrixResult = evaluateDataFrameResponse.getMetricByName(ConfusionMatrixMetric.NAME); + assertThat(confusionMatrixResult.getMetricName(), equalTo(ConfusionMatrixMetric.NAME)); + ConfusionMatrixMetric.ConfusionMatrix confusionMatrix = confusionMatrixResult.getScoreByThreshold("0.5"); + assertThat(confusionMatrix.getTruePositives(), 
equalTo(2L)); // docs #8 and #9 + assertThat(confusionMatrix.getFalsePositives(), equalTo(1L)); // doc #4 + assertThat(confusionMatrix.getTrueNegatives(), equalTo(4L)); // docs #0, #1, #2 and #3 + assertThat(confusionMatrix.getFalseNegatives(), equalTo(3L)); // docs #5, #6 and #7 + assertNull(confusionMatrixResult.getScoreByThreshold("0.1")); + + AucRocMetric.Result aucRocResult = evaluateDataFrameResponse.getMetricByName(AucRocMetric.NAME); + assertThat(aucRocResult.getMetricName(), equalTo(AucRocMetric.NAME)); + assertThat(aucRocResult.getScore(), closeTo(0.70025, 1e-9)); + assertNotNull(aucRocResult.getCurve()); + List curve = aucRocResult.getCurve(); + AucRocMetric.AucRocPoint curvePointAtThreshold0 = curve.stream().filter(p -> p.getThreshold() == 0.0).findFirst().get(); + assertThat(curvePointAtThreshold0.getTruePositiveRate(), equalTo(1.0)); + assertThat(curvePointAtThreshold0.getFalsePositiveRate(), equalTo(1.0)); + assertThat(curvePointAtThreshold0.getThreshold(), equalTo(0.0)); + AucRocMetric.AucRocPoint curvePointAtThreshold1 = curve.stream().filter(p -> p.getThreshold() == 1.0).findFirst().get(); + assertThat(curvePointAtThreshold1.getTruePositiveRate(), equalTo(0.0)); + assertThat(curvePointAtThreshold1.getFalsePositiveRate(), equalTo(0.0)); + assertThat(curvePointAtThreshold1.getThreshold(), equalTo(1.0)); + } + + private static XContentBuilder defaultMappingForTest() throws IOException { + return XContentFactory.jsonBuilder().startObject() + .startObject("properties") + .startObject("timestamp") + .field("type", "date") + .endObject() + .startObject("total") + .field("type", "long") + .endObject() + .endObject() + .endObject(); + } + + private static final String actualField = "label"; + private static final String probabilityField = "p"; + + private static XContentBuilder mappingForClassification() throws IOException { + return XContentFactory.jsonBuilder().startObject() + .startObject("properties") + .startObject(actualField) + .field("type", 
"keyword") + .endObject() + .startObject(probabilityField) + .field("type", "double") + .endObject() + .endObject() + .endObject(); + } + + private static IndexRequest docForClassification(String indexName, boolean isTrue, double p) { + return new IndexRequest() + .index(indexName) + .source(XContentType.JSON, actualField, Boolean.toString(isTrue), probabilityField, p); + } + + private void createIndex(String indexName, XContentBuilder mapping) throws IOException { + highLevelClient().indices().create(new CreateIndexRequest(indexName).mapping(mapping), RequestOptions.DEFAULT); + } + public void testPutFilter() throws Exception { String filterId = "filter-job-test"; MlFilter mlFilter = MlFilter.builder(filterId) diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/MlTestStateCleaner.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/MlTestStateCleaner.java index c565af7c37202..f5776e99fd0eb 100644 --- a/client/rest-high-level/src/test/java/org/elasticsearch/client/MlTestStateCleaner.java +++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/MlTestStateCleaner.java @@ -20,14 +20,18 @@ import org.apache.logging.log4j.Logger; import org.elasticsearch.client.ml.CloseJobRequest; +import org.elasticsearch.client.ml.DeleteDataFrameAnalyticsRequest; import org.elasticsearch.client.ml.DeleteDatafeedRequest; import org.elasticsearch.client.ml.DeleteJobRequest; +import org.elasticsearch.client.ml.GetDataFrameAnalyticsRequest; +import org.elasticsearch.client.ml.GetDataFrameAnalyticsResponse; import org.elasticsearch.client.ml.GetDatafeedRequest; import org.elasticsearch.client.ml.GetDatafeedResponse; import org.elasticsearch.client.ml.GetJobRequest; import org.elasticsearch.client.ml.GetJobResponse; import org.elasticsearch.client.ml.StopDatafeedRequest; import org.elasticsearch.client.ml.datafeed.DatafeedConfig; +import org.elasticsearch.client.ml.dataframe.DataFrameAnalyticsConfig; import 
org.elasticsearch.client.ml.job.config.Job; import java.io.IOException; @@ -48,6 +52,7 @@ public MlTestStateCleaner(Logger logger, MachineLearningClient mlClient) { public void clearMlMetadata() throws IOException { deleteAllDatafeeds(); deleteAllJobs(); + deleteAllDataFrameAnalytics(); } private void deleteAllDatafeeds() throws IOException { @@ -99,4 +104,12 @@ private void closeAllJobs() { throw new RuntimeException("Had to resort to force-closing jobs, something went wrong?", e1); } } + + private void deleteAllDataFrameAnalytics() throws IOException { + GetDataFrameAnalyticsResponse getDataFrameAnalyticsResponse = + mlClient.getDataFrameAnalytics(GetDataFrameAnalyticsRequest.getAllDataFrameAnalyticsRequest(), RequestOptions.DEFAULT); + for (DataFrameAnalyticsConfig config : getDataFrameAnalyticsResponse.getAnalytics()) { + mlClient.deleteDataFrameAnalytics(new DeleteDataFrameAnalyticsRequest(config.getId()), RequestOptions.DEFAULT); + } + } } diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/RestHighLevelClientTests.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/RestHighLevelClientTests.java index 183bce91f83ed..26e5842019675 100644 --- a/client/rest-high-level/src/test/java/org/elasticsearch/client/RestHighLevelClientTests.java +++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/RestHighLevelClientTests.java @@ -46,6 +46,8 @@ import org.elasticsearch.action.search.ShardSearchFailure; import org.elasticsearch.client.core.MainRequest; import org.elasticsearch.client.core.MainResponse; +import org.elasticsearch.client.dataframe.transforms.SyncConfig; +import org.elasticsearch.client.dataframe.transforms.TimeSyncConfig; import org.elasticsearch.client.indexlifecycle.AllocateAction; import org.elasticsearch.client.indexlifecycle.DeleteAction; import org.elasticsearch.client.indexlifecycle.ForceMergeAction; @@ -56,6 +58,13 @@ import org.elasticsearch.client.indexlifecycle.SetPriorityAction; import 
org.elasticsearch.client.indexlifecycle.ShrinkAction; import org.elasticsearch.client.indexlifecycle.UnfollowAction; +import org.elasticsearch.client.ml.dataframe.DataFrameAnalysis; +import org.elasticsearch.client.ml.dataframe.OutlierDetection; +import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.AucRocMetric; +import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.BinarySoftClassification; +import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.ConfusionMatrixMetric; +import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.PrecisionMetric; +import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.RecallMetric; import org.elasticsearch.common.CheckedFunction; import org.elasticsearch.common.bytes.BytesReference; import org.elasticsearch.common.collect.Tuple; @@ -109,6 +118,7 @@ import static org.hamcrest.CoreMatchers.endsWith; import static org.hamcrest.CoreMatchers.equalTo; import static org.hamcrest.CoreMatchers.instanceOf; +import static org.hamcrest.Matchers.hasItems; import static org.mockito.Matchers.any; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.times; @@ -664,7 +674,7 @@ public void testDefaultNamedXContents() { public void testProvidedNamedXContents() { List namedXContents = RestHighLevelClient.getProvidedNamedXContents(); - assertEquals(20, namedXContents.size()); + assertEquals(31, namedXContents.size()); Map, Integer> categories = new HashMap<>(); List names = new ArrayList<>(); for (NamedXContentRegistry.Entry namedXContent : namedXContents) { @@ -674,7 +684,7 @@ public void testProvidedNamedXContents() { categories.put(namedXContent.categoryClass, counter + 1); } } - assertEquals("Had: " + categories, 4, categories.size()); + assertEquals("Had: " + categories, 9, categories.size()); assertEquals(Integer.valueOf(3), categories.get(Aggregation.class)); assertTrue(names.contains(ChildrenAggregationBuilder.NAME)); 
assertTrue(names.contains(MatrixStatsAggregationBuilder.NAME)); @@ -698,6 +708,16 @@ public void testProvidedNamedXContents() { assertTrue(names.contains(ShrinkAction.NAME)); assertTrue(names.contains(FreezeAction.NAME)); assertTrue(names.contains(SetPriorityAction.NAME)); + assertEquals(Integer.valueOf(1), categories.get(DataFrameAnalysis.class)); + assertTrue(names.contains(OutlierDetection.NAME.getPreferredName())); + assertEquals(Integer.valueOf(1), categories.get(SyncConfig.class)); + assertTrue(names.contains(TimeSyncConfig.NAME)); + assertEquals(Integer.valueOf(1), categories.get(org.elasticsearch.client.ml.dataframe.evaluation.Evaluation.class)); + assertThat(names, hasItems(BinarySoftClassification.NAME)); + assertEquals(Integer.valueOf(4), categories.get(org.elasticsearch.client.ml.dataframe.evaluation.EvaluationMetric.class)); + assertThat(names, hasItems(AucRocMetric.NAME, PrecisionMetric.NAME, RecallMetric.NAME, ConfusionMatrixMetric.NAME)); + assertEquals(Integer.valueOf(4), categories.get(org.elasticsearch.client.ml.dataframe.evaluation.EvaluationMetric.Result.class)); + assertThat(names, hasItems(AucRocMetric.NAME, PrecisionMetric.NAME, RecallMetric.NAME, ConfusionMatrixMetric.NAME)); } public void testApiNamingConventions() throws Exception { diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/dataframe/GetDataFrameTransformResponseTests.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/dataframe/GetDataFrameTransformResponseTests.java index f7386e936301b..2bedb7d095fe0 100644 --- a/client/rest-high-level/src/test/java/org/elasticsearch/client/dataframe/GetDataFrameTransformResponseTests.java +++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/dataframe/GetDataFrameTransformResponseTests.java @@ -35,7 +35,6 @@ import static org.elasticsearch.test.AbstractXContentTestCase.xContentTester; - public class GetDataFrameTransformResponseTests extends ESTestCase { public void testXContentParser() 
throws IOException { @@ -79,6 +78,9 @@ private static void toXContent(GetDataFrameTransformResponse response, XContentB @Override protected NamedXContentRegistry xContentRegistry() { SearchModule searchModule = new SearchModule(Settings.EMPTY, false, Collections.emptyList()); - return new NamedXContentRegistry(searchModule.getNamedXContents()); + List namedXContents = searchModule.getNamedXContents(); + namedXContents.addAll(new DataFrameNamedXContentProvider().getNamedXContentParsers()); + + return new NamedXContentRegistry(namedXContents); } } diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/dataframe/PreviewDataFrameTransformRequestTests.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/dataframe/PreviewDataFrameTransformRequestTests.java index c91e1cbb1dd91..45d5d879d47f9 100644 --- a/client/rest-high-level/src/test/java/org/elasticsearch/client/dataframe/PreviewDataFrameTransformRequestTests.java +++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/dataframe/PreviewDataFrameTransformRequestTests.java @@ -31,6 +31,7 @@ import java.io.IOException; import java.util.Collections; +import java.util.List; import java.util.Optional; import static org.elasticsearch.client.dataframe.transforms.SourceConfigTests.randomSourceConfig; @@ -55,7 +56,10 @@ protected boolean supportsUnknownFields() { @Override protected NamedXContentRegistry xContentRegistry() { SearchModule searchModule = new SearchModule(Settings.EMPTY, false, Collections.emptyList()); - return new NamedXContentRegistry(searchModule.getNamedXContents()); + List namedXContents = searchModule.getNamedXContents(); + namedXContents.addAll(new DataFrameNamedXContentProvider().getNamedXContentParsers()); + + return new NamedXContentRegistry(namedXContents); } public void testValidate() { diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/dataframe/PutDataFrameTransformRequestTests.java 
b/client/rest-high-level/src/test/java/org/elasticsearch/client/dataframe/PutDataFrameTransformRequestTests.java index 28fd92dcf913f..7c7cd3fa151fe 100644 --- a/client/rest-high-level/src/test/java/org/elasticsearch/client/dataframe/PutDataFrameTransformRequestTests.java +++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/dataframe/PutDataFrameTransformRequestTests.java @@ -31,6 +31,7 @@ import java.io.IOException; import java.util.Collections; +import java.util.List; import java.util.Optional; import static org.hamcrest.Matchers.containsString; @@ -71,6 +72,9 @@ protected boolean supportsUnknownFields() { @Override protected NamedXContentRegistry xContentRegistry() { SearchModule searchModule = new SearchModule(Settings.EMPTY, false, Collections.emptyList()); - return new NamedXContentRegistry(searchModule.getNamedXContents()); + List namedXContents = searchModule.getNamedXContents(); + namedXContents.addAll(new DataFrameNamedXContentProvider().getNamedXContentParsers()); + + return new NamedXContentRegistry(namedXContents); } } diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/dataframe/transforms/DataFrameTransformConfigTests.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/dataframe/transforms/DataFrameTransformConfigTests.java index 84782a8a97062..212ff64555ecc 100644 --- a/client/rest-high-level/src/test/java/org/elasticsearch/client/dataframe/transforms/DataFrameTransformConfigTests.java +++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/dataframe/transforms/DataFrameTransformConfigTests.java @@ -19,6 +19,7 @@ package org.elasticsearch.client.dataframe.transforms; +import org.elasticsearch.client.dataframe.DataFrameNamedXContentProvider; import org.elasticsearch.Version; import org.elasticsearch.client.dataframe.transforms.pivot.PivotConfigTests; import org.elasticsearch.common.settings.Settings; @@ -30,6 +31,7 @@ import java.io.IOException; import java.time.Instant; import 
java.util.Collections; +import java.util.List; import java.util.function.Predicate; import static org.elasticsearch.client.dataframe.transforms.DestConfigTests.randomDestConfig; @@ -41,12 +43,17 @@ public static DataFrameTransformConfig randomDataFrameTransformConfig() { return new DataFrameTransformConfig(randomAlphaOfLengthBetween(1, 10), randomSourceConfig(), randomDestConfig(), + randomBoolean() ? null : randomSyncConfig(), PivotConfigTests.randomPivotConfig(), randomBoolean() ? null : randomAlphaOfLengthBetween(1, 100), randomBoolean() ? null : Instant.now(), randomBoolean() ? null : Version.CURRENT.toString()); } + public static SyncConfig randomSyncConfig() { + return TimeSyncConfigTests.randomTimeSyncConfig(); + } + @Override protected DataFrameTransformConfig createTestInstance() { return randomDataFrameTransformConfig(); @@ -71,6 +78,9 @@ protected Predicate getRandomFieldsExcludeFilter() { @Override protected NamedXContentRegistry xContentRegistry() { SearchModule searchModule = new SearchModule(Settings.EMPTY, false, Collections.emptyList()); - return new NamedXContentRegistry(searchModule.getNamedXContents()); + List namedXContents = searchModule.getNamedXContents(); + namedXContents.addAll(new DataFrameNamedXContentProvider().getNamedXContentParsers()); + + return new NamedXContentRegistry(namedXContents); } } diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/dataframe/transforms/TimeSyncConfigTests.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/dataframe/transforms/TimeSyncConfigTests.java new file mode 100644 index 0000000000000..dd2a17eb0260d --- /dev/null +++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/dataframe/transforms/TimeSyncConfigTests.java @@ -0,0 +1,49 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. 
Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.elasticsearch.client.dataframe.transforms; + +import org.elasticsearch.common.unit.TimeValue; +import org.elasticsearch.common.xcontent.XContentParser; +import org.elasticsearch.test.AbstractXContentTestCase; + +import java.io.IOException; + +public class TimeSyncConfigTests extends AbstractXContentTestCase { + + public static TimeSyncConfig randomTimeSyncConfig() { + return new TimeSyncConfig(randomAlphaOfLengthBetween(1, 10), new TimeValue(randomNonNegativeLong())); + } + + @Override + protected TimeSyncConfig createTestInstance() { + return randomTimeSyncConfig(); + } + + @Override + protected TimeSyncConfig doParseInstance(XContentParser parser) throws IOException { + return TimeSyncConfig.fromXContent(parser); + } + + @Override + protected boolean supportsUnknownFields() { + return true; + } + +} diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/dataframe/transforms/hlrc/TimeSyncConfigTests.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/dataframe/transforms/hlrc/TimeSyncConfigTests.java new file mode 100644 index 0000000000000..0c6a0350882a4 --- /dev/null +++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/dataframe/transforms/hlrc/TimeSyncConfigTests.java @@ -0,0 +1,59 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.elasticsearch.client.dataframe.transforms.hlrc; + +import org.elasticsearch.client.AbstractResponseTestCase; +import org.elasticsearch.client.dataframe.transforms.TimeSyncConfig; +import org.elasticsearch.common.unit.TimeValue; +import org.elasticsearch.common.xcontent.XContentParser; + +import java.io.IOException; + +public class TimeSyncConfigTests + extends AbstractResponseTestCase { + + public static org.elasticsearch.xpack.core.dataframe.transforms.TimeSyncConfig randomTimeSyncConfig() { + return new org.elasticsearch.xpack.core.dataframe.transforms.TimeSyncConfig(randomAlphaOfLengthBetween(1, 10), + new TimeValue(randomNonNegativeLong())); + } + + public static void assertHlrcEquals(org.elasticsearch.xpack.core.dataframe.transforms.TimeSyncConfig serverTestInstance, + TimeSyncConfig clientInstance) { + assertEquals(serverTestInstance.getField(), clientInstance.getField()); + assertEquals(serverTestInstance.getDelay(), clientInstance.getDelay()); + } + + @Override + protected org.elasticsearch.xpack.core.dataframe.transforms.TimeSyncConfig createServerTestInstance() { + return randomTimeSyncConfig(); + } + + @Override + protected TimeSyncConfig doParseToClientInstance(XContentParser parser) throws IOException { + return 
TimeSyncConfig.fromXContent(parser); + } + + @Override + protected void assertInstances(org.elasticsearch.xpack.core.dataframe.transforms.TimeSyncConfig serverTestInstance, + TimeSyncConfig clientInstance) { + assertHlrcEquals(serverTestInstance, clientInstance); + } + +} diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/documentation/DataFrameTransformDocumentationIT.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/documentation/DataFrameTransformDocumentationIT.java index 4f94db604f147..b3fa85880b465 100644 --- a/client/rest-high-level/src/test/java/org/elasticsearch/client/documentation/DataFrameTransformDocumentationIT.java +++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/documentation/DataFrameTransformDocumentationIT.java @@ -433,6 +433,7 @@ public void testPreview() throws IOException, InterruptedException { .setQueryConfig(queryConfig) .build(), // <1> pivotConfig); // <2> + PreviewDataFrameTransformRequest request = new PreviewDataFrameTransformRequest(transformConfig); // <3> // end::preview-data-frame-transform-request diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/documentation/MlClientDocumentationIT.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/documentation/MlClientDocumentationIT.java index fe7d04a4e0a8d..526e31a5da1ae 100644 --- a/client/rest-high-level/src/test/java/org/elasticsearch/client/documentation/MlClientDocumentationIT.java +++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/documentation/MlClientDocumentationIT.java @@ -39,6 +39,7 @@ import org.elasticsearch.client.ml.DeleteCalendarEventRequest; import org.elasticsearch.client.ml.DeleteCalendarJobRequest; import org.elasticsearch.client.ml.DeleteCalendarRequest; +import org.elasticsearch.client.ml.DeleteDataFrameAnalyticsRequest; import org.elasticsearch.client.ml.DeleteDatafeedRequest; import org.elasticsearch.client.ml.DeleteExpiredDataRequest; import 
org.elasticsearch.client.ml.DeleteExpiredDataResponse; @@ -47,6 +48,8 @@ import org.elasticsearch.client.ml.DeleteJobRequest; import org.elasticsearch.client.ml.DeleteJobResponse; import org.elasticsearch.client.ml.DeleteModelSnapshotRequest; +import org.elasticsearch.client.ml.EvaluateDataFrameRequest; +import org.elasticsearch.client.ml.EvaluateDataFrameResponse; import org.elasticsearch.client.ml.FindFileStructureRequest; import org.elasticsearch.client.ml.FindFileStructureResponse; import org.elasticsearch.client.ml.FlushJobRequest; @@ -61,8 +64,10 @@ import org.elasticsearch.client.ml.GetCalendarsResponse; import org.elasticsearch.client.ml.GetCategoriesRequest; import org.elasticsearch.client.ml.GetCategoriesResponse; -import org.elasticsearch.client.ml.GetModelSnapshotsRequest; -import org.elasticsearch.client.ml.GetModelSnapshotsResponse; +import org.elasticsearch.client.ml.GetDataFrameAnalyticsRequest; +import org.elasticsearch.client.ml.GetDataFrameAnalyticsResponse; +import org.elasticsearch.client.ml.GetDataFrameAnalyticsStatsRequest; +import org.elasticsearch.client.ml.GetDataFrameAnalyticsStatsResponse; import org.elasticsearch.client.ml.GetDatafeedRequest; import org.elasticsearch.client.ml.GetDatafeedResponse; import org.elasticsearch.client.ml.GetDatafeedStatsRequest; @@ -75,6 +80,8 @@ import org.elasticsearch.client.ml.GetJobResponse; import org.elasticsearch.client.ml.GetJobStatsRequest; import org.elasticsearch.client.ml.GetJobStatsResponse; +import org.elasticsearch.client.ml.GetModelSnapshotsRequest; +import org.elasticsearch.client.ml.GetModelSnapshotsResponse; import org.elasticsearch.client.ml.GetOverallBucketsRequest; import org.elasticsearch.client.ml.GetOverallBucketsResponse; import org.elasticsearch.client.ml.GetRecordsRequest; @@ -92,6 +99,8 @@ import org.elasticsearch.client.ml.PutCalendarJobRequest; import org.elasticsearch.client.ml.PutCalendarRequest; import org.elasticsearch.client.ml.PutCalendarResponse; +import 
org.elasticsearch.client.ml.PutDataFrameAnalyticsRequest; +import org.elasticsearch.client.ml.PutDataFrameAnalyticsResponse; import org.elasticsearch.client.ml.PutDatafeedRequest; import org.elasticsearch.client.ml.PutDatafeedResponse; import org.elasticsearch.client.ml.PutFilterRequest; @@ -101,8 +110,11 @@ import org.elasticsearch.client.ml.RevertModelSnapshotRequest; import org.elasticsearch.client.ml.RevertModelSnapshotResponse; import org.elasticsearch.client.ml.SetUpgradeModeRequest; +import org.elasticsearch.client.ml.StartDataFrameAnalyticsRequest; import org.elasticsearch.client.ml.StartDatafeedRequest; import org.elasticsearch.client.ml.StartDatafeedResponse; +import org.elasticsearch.client.ml.StopDataFrameAnalyticsRequest; +import org.elasticsearch.client.ml.StopDataFrameAnalyticsResponse; import org.elasticsearch.client.ml.StopDatafeedRequest; import org.elasticsearch.client.ml.StopDatafeedResponse; import org.elasticsearch.client.ml.UpdateDatafeedRequest; @@ -118,6 +130,21 @@ import org.elasticsearch.client.ml.datafeed.DatafeedStats; import org.elasticsearch.client.ml.datafeed.DatafeedUpdate; import org.elasticsearch.client.ml.datafeed.DelayedDataCheckConfig; +import org.elasticsearch.client.ml.dataframe.DataFrameAnalysis; +import org.elasticsearch.client.ml.dataframe.DataFrameAnalyticsConfig; +import org.elasticsearch.client.ml.dataframe.DataFrameAnalyticsDest; +import org.elasticsearch.client.ml.dataframe.DataFrameAnalyticsSource; +import org.elasticsearch.client.ml.dataframe.DataFrameAnalyticsState; +import org.elasticsearch.client.ml.dataframe.DataFrameAnalyticsStats; +import org.elasticsearch.client.ml.dataframe.OutlierDetection; +import org.elasticsearch.client.ml.dataframe.QueryConfig; +import org.elasticsearch.client.ml.dataframe.evaluation.EvaluationMetric; +import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.AucRocMetric; +import 
org.elasticsearch.client.ml.dataframe.evaluation.softclassification.BinarySoftClassification; +import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.ConfusionMatrixMetric; +import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.ConfusionMatrixMetric.ConfusionMatrix; +import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.PrecisionMetric; +import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.RecallMetric; import org.elasticsearch.client.ml.filestructurefinder.FileStructure; import org.elasticsearch.client.ml.job.config.AnalysisConfig; import org.elasticsearch.client.ml.job.config.AnalysisLimits; @@ -139,13 +166,18 @@ import org.elasticsearch.client.ml.job.results.OverallBucket; import org.elasticsearch.client.ml.job.stats.JobStats; import org.elasticsearch.common.bytes.BytesReference; +import org.elasticsearch.common.unit.ByteSizeUnit; +import org.elasticsearch.common.unit.ByteSizeValue; import org.elasticsearch.common.unit.TimeValue; import org.elasticsearch.common.xcontent.XContentFactory; import org.elasticsearch.common.xcontent.XContentType; +import org.elasticsearch.index.query.MatchAllQueryBuilder; import org.elasticsearch.index.query.QueryBuilders; import org.elasticsearch.search.aggregations.AggregatorFactories; import org.elasticsearch.search.builder.SearchSourceBuilder; +import org.elasticsearch.search.fetch.subphase.FetchSourceContext; import org.elasticsearch.tasks.TaskId; +import org.hamcrest.CoreMatchers; import org.junit.After; import java.io.IOException; @@ -870,18 +902,7 @@ public void testPreviewDatafeed() throws Exception { client.machineLearning().putJob(new PutJobRequest(job), RequestOptions.DEFAULT); String datafeedId = job.getId() + "-feed"; String indexName = "preview_data_2"; - CreateIndexRequest createIndexRequest = new CreateIndexRequest(indexName); - createIndexRequest.mapping(XContentFactory.jsonBuilder().startObject() - .startObject("properties") - 
.startObject("timestamp") - .field("type", "date") - .endObject() - .startObject("total") - .field("type", "long") - .endObject() - .endObject() - .endObject()); - highLevelClient().indices().create(createIndexRequest, RequestOptions.DEFAULT); + createIndex(indexName); DatafeedConfig datafeed = DatafeedConfig.builder(datafeedId, job.getId()) .setIndices(indexName) .build(); @@ -938,18 +959,7 @@ public void testStartDatafeed() throws Exception { client.machineLearning().putJob(new PutJobRequest(job), RequestOptions.DEFAULT); String datafeedId = job.getId() + "-feed"; String indexName = "start_data_2"; - CreateIndexRequest createIndexRequest = new CreateIndexRequest(indexName); - createIndexRequest.mapping(XContentFactory.jsonBuilder().startObject() - .startObject("properties") - .startObject("timestamp") - .field("type", "date") - .endObject() - .startObject("total") - .field("type", "long") - .endObject() - .endObject() - .endObject()); - highLevelClient().indices().create(createIndexRequest, RequestOptions.DEFAULT); + createIndex(indexName); DatafeedConfig datafeed = DatafeedConfig.builder(datafeedId, job.getId()) .setIndices(indexName) .build(); @@ -1067,18 +1077,7 @@ public void testGetDatafeedStats() throws Exception { client.machineLearning().putJob(new PutJobRequest(secondJob), RequestOptions.DEFAULT); String datafeedId1 = job.getId() + "-feed"; String indexName = "datafeed_stats_data_2"; - CreateIndexRequest createIndexRequest = new CreateIndexRequest(indexName); - createIndexRequest.mapping(XContentFactory.jsonBuilder().startObject() - .startObject("properties") - .startObject("timestamp") - .field("type", "date") - .endObject() - .startObject("total") - .field("type", "long") - .endObject() - .endObject() - .endObject()); - highLevelClient().indices().create(createIndexRequest, RequestOptions.DEFAULT); + createIndex(indexName); DatafeedConfig datafeed = DatafeedConfig.builder(datafeedId1, job.getId()) .setIndices(indexName) .build(); @@ -2802,6 +2801,465 
@@ public void onFailure(Exception e) { } } + public void testGetDataFrameAnalytics() throws Exception { + createIndex(DF_ANALYTICS_CONFIG.getSource().getIndex()); + + RestHighLevelClient client = highLevelClient(); + client.machineLearning().putDataFrameAnalytics(new PutDataFrameAnalyticsRequest(DF_ANALYTICS_CONFIG), RequestOptions.DEFAULT); + { + // tag::get-data-frame-analytics-request + GetDataFrameAnalyticsRequest request = new GetDataFrameAnalyticsRequest("my-analytics-config"); // <1> + // end::get-data-frame-analytics-request + + // tag::get-data-frame-analytics-execute + GetDataFrameAnalyticsResponse response = client.machineLearning().getDataFrameAnalytics(request, RequestOptions.DEFAULT); + // end::get-data-frame-analytics-execute + + // tag::get-data-frame-analytics-response + List configs = response.getAnalytics(); + // end::get-data-frame-analytics-response + + assertThat(configs.size(), equalTo(1)); + } + { + GetDataFrameAnalyticsRequest request = new GetDataFrameAnalyticsRequest("my-analytics-config"); + + // tag::get-data-frame-analytics-execute-listener + ActionListener listener = new ActionListener() { + @Override + public void onResponse(GetDataFrameAnalyticsResponse response) { + // <1> + } + + @Override + public void onFailure(Exception e) { + // <2> + } + }; + // end::get-data-frame-analytics-execute-listener + + // Replace the empty listener by a blocking listener in test + CountDownLatch latch = new CountDownLatch(1); + listener = new LatchedActionListener<>(listener, latch); + + // tag::get-data-frame-analytics-execute-async + client.machineLearning().getDataFrameAnalyticsAsync(request, RequestOptions.DEFAULT, listener); // <1> + // end::get-data-frame-analytics-execute-async + + assertTrue(latch.await(30L, TimeUnit.SECONDS)); + } + } + + public void testGetDataFrameAnalyticsStats() throws Exception { + createIndex(DF_ANALYTICS_CONFIG.getSource().getIndex()); + + RestHighLevelClient client = highLevelClient(); + 
client.machineLearning().putDataFrameAnalytics(new PutDataFrameAnalyticsRequest(DF_ANALYTICS_CONFIG), RequestOptions.DEFAULT); + { + // tag::get-data-frame-analytics-stats-request + GetDataFrameAnalyticsStatsRequest request = new GetDataFrameAnalyticsStatsRequest("my-analytics-config"); // <1> + // end::get-data-frame-analytics-stats-request + + // tag::get-data-frame-analytics-stats-execute + GetDataFrameAnalyticsStatsResponse response = + client.machineLearning().getDataFrameAnalyticsStats(request, RequestOptions.DEFAULT); + // end::get-data-frame-analytics-stats-execute + + // tag::get-data-frame-analytics-stats-response + List stats = response.getAnalyticsStats(); + // end::get-data-frame-analytics-stats-response + + assertThat(stats.size(), equalTo(1)); + } + { + GetDataFrameAnalyticsStatsRequest request = new GetDataFrameAnalyticsStatsRequest("my-analytics-config"); + + // tag::get-data-frame-analytics-stats-execute-listener + ActionListener listener = new ActionListener() { + @Override + public void onResponse(GetDataFrameAnalyticsStatsResponse response) { + // <1> + } + + @Override + public void onFailure(Exception e) { + // <2> + } + }; + // end::get-data-frame-analytics-stats-execute-listener + + // Replace the empty listener by a blocking listener in test + CountDownLatch latch = new CountDownLatch(1); + listener = new LatchedActionListener<>(listener, latch); + + // tag::get-data-frame-analytics-stats-execute-async + client.machineLearning().getDataFrameAnalyticsStatsAsync(request, RequestOptions.DEFAULT, listener); // <1> + // end::get-data-frame-analytics-stats-execute-async + + assertTrue(latch.await(30L, TimeUnit.SECONDS)); + } + } + + public void testPutDataFrameAnalytics() throws Exception { + createIndex(DF_ANALYTICS_CONFIG.getSource().getIndex()); + + RestHighLevelClient client = highLevelClient(); + { + // tag::put-data-frame-analytics-query-config + QueryConfig queryConfig = new QueryConfig(new MatchAllQueryBuilder()); + // 
end::put-data-frame-analytics-query-config + + // tag::put-data-frame-analytics-source-config + DataFrameAnalyticsSource sourceConfig = DataFrameAnalyticsSource.builder() // <1> + .setIndex("put-test-source-index") // <2> + .setQueryConfig(queryConfig) // <3> + .build(); + // end::put-data-frame-analytics-source-config + + // tag::put-data-frame-analytics-dest-config + DataFrameAnalyticsDest destConfig = DataFrameAnalyticsDest.builder() // <1> + .setIndex("put-test-dest-index") // <2> + .build(); + // end::put-data-frame-analytics-dest-config + + // tag::put-data-frame-analytics-analysis-default + DataFrameAnalysis outlierDetection = OutlierDetection.createDefault(); // <1> + // end::put-data-frame-analytics-analysis-default + + // tag::put-data-frame-analytics-analysis-customized + DataFrameAnalysis outlierDetectionCustomized = OutlierDetection.builder() // <1> + .setMethod(OutlierDetection.Method.DISTANCE_KNN) // <2> + .setNNeighbors(5) // <3> + .build(); + // end::put-data-frame-analytics-analysis-customized + + // tag::put-data-frame-analytics-analyzed-fields + FetchSourceContext analyzedFields = + new FetchSourceContext( + true, + new String[] { "included_field_1", "included_field_2" }, + new String[] { "excluded_field" }); + // end::put-data-frame-analytics-analyzed-fields + + // tag::put-data-frame-analytics-config + DataFrameAnalyticsConfig config = DataFrameAnalyticsConfig.builder("my-analytics-config") // <1> + .setSource(sourceConfig) // <2> + .setDest(destConfig) // <3> + .setAnalysis(outlierDetection) // <4> + .setAnalyzedFields(analyzedFields) // <5> + .setModelMemoryLimit(new ByteSizeValue(5, ByteSizeUnit.MB)) // <6> + .build(); + // end::put-data-frame-analytics-config + + // tag::put-data-frame-analytics-request + PutDataFrameAnalyticsRequest request = new PutDataFrameAnalyticsRequest(config); // <1> + // end::put-data-frame-analytics-request + + // tag::put-data-frame-analytics-execute + PutDataFrameAnalyticsResponse response = 
client.machineLearning().putDataFrameAnalytics(request, RequestOptions.DEFAULT); + // end::put-data-frame-analytics-execute + + // tag::put-data-frame-analytics-response + DataFrameAnalyticsConfig createdConfig = response.getConfig(); + // end::put-data-frame-analytics-response + + assertThat(createdConfig.getId(), equalTo("my-analytics-config")); + } + { + PutDataFrameAnalyticsRequest request = new PutDataFrameAnalyticsRequest(DF_ANALYTICS_CONFIG); + // tag::put-data-frame-analytics-execute-listener + ActionListener listener = new ActionListener() { + @Override + public void onResponse(PutDataFrameAnalyticsResponse response) { + // <1> + } + + @Override + public void onFailure(Exception e) { + // <2> + } + }; + // end::put-data-frame-analytics-execute-listener + + // Replace the empty listener by a blocking listener in test + final CountDownLatch latch = new CountDownLatch(1); + listener = new LatchedActionListener<>(listener, latch); + + // tag::put-data-frame-analytics-execute-async + client.machineLearning().putDataFrameAnalyticsAsync(request, RequestOptions.DEFAULT, listener); // <1> + // end::put-data-frame-analytics-execute-async + + assertTrue(latch.await(30L, TimeUnit.SECONDS)); + } + } + + public void testDeleteDataFrameAnalytics() throws Exception { + createIndex(DF_ANALYTICS_CONFIG.getSource().getIndex()); + + RestHighLevelClient client = highLevelClient(); + client.machineLearning().putDataFrameAnalytics(new PutDataFrameAnalyticsRequest(DF_ANALYTICS_CONFIG), RequestOptions.DEFAULT); + { + // tag::delete-data-frame-analytics-request + DeleteDataFrameAnalyticsRequest request = new DeleteDataFrameAnalyticsRequest("my-analytics-config"); // <1> + // end::delete-data-frame-analytics-request + + // tag::delete-data-frame-analytics-execute + AcknowledgedResponse response = client.machineLearning().deleteDataFrameAnalytics(request, RequestOptions.DEFAULT); + // end::delete-data-frame-analytics-execute + + // tag::delete-data-frame-analytics-response + boolean 
acknowledged = response.isAcknowledged(); + // end::delete-data-frame-analytics-response + + assertThat(acknowledged, is(true)); + } + client.machineLearning().putDataFrameAnalytics(new PutDataFrameAnalyticsRequest(DF_ANALYTICS_CONFIG), RequestOptions.DEFAULT); + { + DeleteDataFrameAnalyticsRequest request = new DeleteDataFrameAnalyticsRequest("my-analytics-config"); + + // tag::delete-data-frame-analytics-execute-listener + ActionListener listener = new ActionListener() { + @Override + public void onResponse(AcknowledgedResponse response) { + // <1> + } + + @Override + public void onFailure(Exception e) { + // <2> + } + }; + // end::delete-data-frame-analytics-execute-listener + + // Replace the empty listener by a blocking listener in test + CountDownLatch latch = new CountDownLatch(1); + listener = new LatchedActionListener<>(listener, latch); + + // tag::delete-data-frame-analytics-execute-async + client.machineLearning().deleteDataFrameAnalyticsAsync(request, RequestOptions.DEFAULT, listener); // <1> + // end::delete-data-frame-analytics-execute-async + + assertTrue(latch.await(30L, TimeUnit.SECONDS)); + } + } + + public void testStartDataFrameAnalytics() throws Exception { + createIndex(DF_ANALYTICS_CONFIG.getSource().getIndex()); + highLevelClient().index( + new IndexRequest(DF_ANALYTICS_CONFIG.getSource().getIndex()).source(XContentType.JSON, "total", 10000) + .setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE), RequestOptions.DEFAULT); + RestHighLevelClient client = highLevelClient(); + client.machineLearning().putDataFrameAnalytics(new PutDataFrameAnalyticsRequest(DF_ANALYTICS_CONFIG), RequestOptions.DEFAULT); + { + // tag::start-data-frame-analytics-request + StartDataFrameAnalyticsRequest request = new StartDataFrameAnalyticsRequest("my-analytics-config"); // <1> + // end::start-data-frame-analytics-request + + // tag::start-data-frame-analytics-execute + AcknowledgedResponse response = client.machineLearning().startDataFrameAnalytics(request, 
RequestOptions.DEFAULT); + // end::start-data-frame-analytics-execute + + // tag::start-data-frame-analytics-response + boolean acknowledged = response.isAcknowledged(); + // end::start-data-frame-analytics-response + + assertThat(acknowledged, is(true)); + } + assertBusy( + () -> assertThat(getAnalyticsState(DF_ANALYTICS_CONFIG.getId()), equalTo(DataFrameAnalyticsState.STOPPED)), + 30, TimeUnit.SECONDS); + { + StartDataFrameAnalyticsRequest request = new StartDataFrameAnalyticsRequest("my-analytics-config"); + + // tag::start-data-frame-analytics-execute-listener + ActionListener listener = new ActionListener() { + @Override + public void onResponse(AcknowledgedResponse response) { + // <1> + } + + @Override + public void onFailure(Exception e) { + // <2> + } + }; + // end::start-data-frame-analytics-execute-listener + + // Replace the empty listener by a blocking listener in test + CountDownLatch latch = new CountDownLatch(1); + listener = new LatchedActionListener<>(listener, latch); + + // tag::start-data-frame-analytics-execute-async + client.machineLearning().startDataFrameAnalyticsAsync(request, RequestOptions.DEFAULT, listener); // <1> + // end::start-data-frame-analytics-execute-async + + assertTrue(latch.await(30L, TimeUnit.SECONDS)); + } + assertBusy( + () -> assertThat(getAnalyticsState(DF_ANALYTICS_CONFIG.getId()), equalTo(DataFrameAnalyticsState.STOPPED)), + 30, TimeUnit.SECONDS); + } + + public void testStopDataFrameAnalytics() throws Exception { + createIndex(DF_ANALYTICS_CONFIG.getSource().getIndex()); + highLevelClient().index( + new IndexRequest(DF_ANALYTICS_CONFIG.getSource().getIndex()).source(XContentType.JSON, "total", 10000) + .setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE), RequestOptions.DEFAULT); + RestHighLevelClient client = highLevelClient(); + client.machineLearning().putDataFrameAnalytics(new PutDataFrameAnalyticsRequest(DF_ANALYTICS_CONFIG), RequestOptions.DEFAULT); + { + // tag::stop-data-frame-analytics-request + 
StopDataFrameAnalyticsRequest request = new StopDataFrameAnalyticsRequest("my-analytics-config"); // <1> + // end::stop-data-frame-analytics-request + + // tag::stop-data-frame-analytics-execute + StopDataFrameAnalyticsResponse response = client.machineLearning().stopDataFrameAnalytics(request, RequestOptions.DEFAULT); + // end::stop-data-frame-analytics-execute + + // tag::stop-data-frame-analytics-response + boolean acknowledged = response.isStopped(); + // end::stop-data-frame-analytics-response + + assertThat(acknowledged, is(true)); + } + assertBusy( + () -> assertThat(getAnalyticsState(DF_ANALYTICS_CONFIG.getId()), equalTo(DataFrameAnalyticsState.STOPPED)), + 30, TimeUnit.SECONDS); + { + StopDataFrameAnalyticsRequest request = new StopDataFrameAnalyticsRequest("my-analytics-config"); + + // tag::stop-data-frame-analytics-execute-listener + ActionListener listener = new ActionListener() { + @Override + public void onResponse(StopDataFrameAnalyticsResponse response) { + // <1> + } + + @Override + public void onFailure(Exception e) { + // <2> + } + }; + // end::stop-data-frame-analytics-execute-listener + + // Replace the empty listener by a blocking listener in test + CountDownLatch latch = new CountDownLatch(1); + listener = new LatchedActionListener<>(listener, latch); + + // tag::stop-data-frame-analytics-execute-async + client.machineLearning().stopDataFrameAnalyticsAsync(request, RequestOptions.DEFAULT, listener); // <1> + // end::stop-data-frame-analytics-execute-async + + assertTrue(latch.await(30L, TimeUnit.SECONDS)); + } + assertBusy( + () -> assertThat(getAnalyticsState(DF_ANALYTICS_CONFIG.getId()), equalTo(DataFrameAnalyticsState.STOPPED)), + 30, TimeUnit.SECONDS); + } + + public void testEvaluateDataFrame() throws Exception { + String indexName = "evaluate-test-index"; + CreateIndexRequest createIndexRequest = + new CreateIndexRequest(indexName) + .mapping(XContentFactory.jsonBuilder().startObject() + .startObject("properties") + 
.startObject("label") + .field("type", "keyword") + .endObject() + .startObject("p") + .field("type", "double") + .endObject() + .endObject() + .endObject()); + BulkRequest bulkRequest = + new BulkRequest(indexName) + .setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE) + .add(new IndexRequest().source(XContentType.JSON, "label", false, "p", 0.1)) // #0 + .add(new IndexRequest().source(XContentType.JSON, "label", false, "p", 0.2)) // #1 + .add(new IndexRequest().source(XContentType.JSON, "label", false, "p", 0.3)) // #2 + .add(new IndexRequest().source(XContentType.JSON, "label", false, "p", 0.4)) // #3 + .add(new IndexRequest().source(XContentType.JSON, "label", false, "p", 0.7)) // #4 + .add(new IndexRequest().source(XContentType.JSON, "label", true, "p", 0.2)) // #5 + .add(new IndexRequest().source(XContentType.JSON, "label", true, "p", 0.3)) // #6 + .add(new IndexRequest().source(XContentType.JSON, "label", true, "p", 0.4)) // #7 + .add(new IndexRequest().source(XContentType.JSON, "label", true, "p", 0.8)) // #8 + .add(new IndexRequest().source(XContentType.JSON, "label", true, "p", 0.9)); // #9 + RestHighLevelClient client = highLevelClient(); + client.indices().create(createIndexRequest, RequestOptions.DEFAULT); + client.bulk(bulkRequest, RequestOptions.DEFAULT); + { + // tag::evaluate-data-frame-request + EvaluateDataFrameRequest request = new EvaluateDataFrameRequest( // <1> + indexName, // <2> + new BinarySoftClassification( // <3> + "label", // <4> + "p", // <5> + // Evaluation metrics // <6> + PrecisionMetric.at(0.4, 0.5, 0.6), // <7> + RecallMetric.at(0.5, 0.7), // <8> + ConfusionMatrixMetric.at(0.5), // <9> + AucRocMetric.withCurve())); // <10> + // end::evaluate-data-frame-request + + // tag::evaluate-data-frame-execute + EvaluateDataFrameResponse response = client.machineLearning().evaluateDataFrame(request, RequestOptions.DEFAULT); + // end::evaluate-data-frame-execute + + // tag::evaluate-data-frame-response + List metrics = 
response.getMetrics(); // <1> + + PrecisionMetric.Result precisionResult = response.getMetricByName(PrecisionMetric.NAME); // <2> + double precision = precisionResult.getScoreByThreshold("0.4"); // <3> + + ConfusionMatrixMetric.Result confusionMatrixResult = response.getMetricByName(ConfusionMatrixMetric.NAME); // <4> + ConfusionMatrix confusionMatrix = confusionMatrixResult.getScoreByThreshold("0.5"); // <5> + // end::evaluate-data-frame-response + + assertThat( + metrics.stream().map(m -> m.getMetricName()).collect(Collectors.toList()), + containsInAnyOrder(PrecisionMetric.NAME, RecallMetric.NAME, ConfusionMatrixMetric.NAME, AucRocMetric.NAME)); + assertThat(precision, closeTo(0.6, 1e-9)); + assertThat(confusionMatrix.getTruePositives(), CoreMatchers.equalTo(2L)); // docs #8 and #9 + assertThat(confusionMatrix.getFalsePositives(), CoreMatchers.equalTo(1L)); // doc #4 + assertThat(confusionMatrix.getTrueNegatives(), CoreMatchers.equalTo(4L)); // docs #0, #1, #2 and #3 + assertThat(confusionMatrix.getFalseNegatives(), CoreMatchers.equalTo(3L)); // docs #5, #6 and #7 + } + { + EvaluateDataFrameRequest request = new EvaluateDataFrameRequest( + indexName, + new BinarySoftClassification( + "label", + "p", + PrecisionMetric.at(0.4, 0.5, 0.6), + RecallMetric.at(0.5, 0.7), + ConfusionMatrixMetric.at(0.5), + AucRocMetric.withCurve())); + + // tag::evaluate-data-frame-execute-listener + ActionListener listener = new ActionListener() { + @Override + public void onResponse(EvaluateDataFrameResponse response) { + // <1> + } + + @Override + public void onFailure(Exception e) { + // <2> + } + }; + // end::evaluate-data-frame-execute-listener + + // Replace the empty listener by a blocking listener in test + CountDownLatch latch = new CountDownLatch(1); + listener = new LatchedActionListener<>(listener, latch); + + // tag::evaluate-data-frame-execute-async + client.machineLearning().evaluateDataFrameAsync(request, RequestOptions.DEFAULT, listener); // <1> + // 
end::evaluate-data-frame-execute-async + + assertTrue(latch.await(30L, TimeUnit.SECONDS)); + } + } public void testCreateFilter() throws Exception { RestHighLevelClient client = highLevelClient(); @@ -3140,4 +3598,39 @@ private String createFilter(RestHighLevelClient client) throws IOException { assertThat(createdFilter.getId(), equalTo("my_safe_domains")); return createdFilter.getId(); } + + private void createIndex(String indexName) throws IOException { + CreateIndexRequest createIndexRequest = new CreateIndexRequest(indexName); + createIndexRequest.mapping(XContentFactory.jsonBuilder().startObject() + .startObject("properties") + .startObject("timestamp") + .field("type", "date") + .endObject() + .startObject("total") + .field("type", "long") + .endObject() + .endObject() + .endObject()); + highLevelClient().indices().create(createIndexRequest, RequestOptions.DEFAULT); + } + + private DataFrameAnalyticsState getAnalyticsState(String configId) throws IOException { + GetDataFrameAnalyticsStatsResponse statsResponse = + highLevelClient().machineLearning().getDataFrameAnalyticsStats( + new GetDataFrameAnalyticsStatsRequest(configId), RequestOptions.DEFAULT); + assertThat(statsResponse.getAnalyticsStats(), hasSize(1)); + DataFrameAnalyticsStats stats = statsResponse.getAnalyticsStats().get(0); + return stats.getState(); + } + + private static final DataFrameAnalyticsConfig DF_ANALYTICS_CONFIG = + DataFrameAnalyticsConfig.builder("my-analytics-config") + .setSource(DataFrameAnalyticsSource.builder() + .setIndex("put-test-source-index") + .build()) + .setDest(DataFrameAnalyticsDest.builder() + .setIndex("put-test-dest-index") + .build()) + .setAnalysis(OutlierDetection.createDefault()) + .build(); } diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/AucRocMetricAucRocPointTests.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/AucRocMetricAucRocPointTests.java new file mode 100644 index 0000000000000..825adcd2060f8 --- 
/dev/null
+++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/AucRocMetricAucRocPointTests.java
@@ -0,0 +1,47 @@
+/*
+ * Licensed to Elasticsearch under one or more contributor
+ * license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright
+ * ownership. Elasticsearch licenses this file to you under
+ * the Apache License, Version 2.0 (the "License"); you may
+ * not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.elasticsearch.client.ml;
+
+import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.AucRocMetric;
+import org.elasticsearch.common.xcontent.XContentParser;
+import org.elasticsearch.test.AbstractXContentTestCase;
+
+import java.io.IOException;
+
+public class AucRocMetricAucRocPointTests extends AbstractXContentTestCase<AucRocMetric.AucRocPoint> {
+
+    static AucRocMetric.AucRocPoint randomPoint() {
+        return new AucRocMetric.AucRocPoint(randomDouble(), randomDouble(), randomDouble());
+    }
+
+    @Override
+    protected AucRocMetric.AucRocPoint createTestInstance() {
+        return randomPoint();
+    }
+
+    @Override
+    protected AucRocMetric.AucRocPoint doParseInstance(XContentParser parser) throws IOException {
+        return AucRocMetric.AucRocPoint.fromXContent(parser);
+    }
+
+    @Override
+    protected boolean supportsUnknownFields() {
+        return true;
+    }
+}
diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/AucRocMetricResultTests.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/AucRocMetricResultTests.java
new
file mode 100644
index 0000000000000..9ea7689d60f32
--- /dev/null
+++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/AucRocMetricResultTests.java
@@ -0,0 +1,63 @@
+/*
+ * Licensed to Elasticsearch under one or more contributor
+ * license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright
+ * ownership. Elasticsearch licenses this file to you under
+ * the Apache License, Version 2.0 (the "License"); you may
+ * not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.elasticsearch.client.ml;
+
+import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.AucRocMetric;
+import org.elasticsearch.common.xcontent.XContentParser;
+import org.elasticsearch.test.AbstractXContentTestCase;
+
+import java.io.IOException;
+import java.util.function.Predicate;
+import java.util.stream.Collectors;
+import java.util.stream.Stream;
+
+import static org.elasticsearch.client.ml.AucRocMetricAucRocPointTests.randomPoint;
+
+public class AucRocMetricResultTests extends AbstractXContentTestCase<AucRocMetric.Result> {
+
+    static AucRocMetric.Result randomResult() {
+        return new AucRocMetric.Result(
+            randomDouble(),
+            Stream
+                .generate(() -> randomPoint())
+                .limit(randomIntBetween(1, 10))
+                .collect(Collectors.toList()));
+    }
+
+    @Override
+    protected AucRocMetric.Result createTestInstance() {
+        return randomResult();
+    }
+
+    @Override
+    protected AucRocMetric.Result doParseInstance(XContentParser parser) throws IOException {
+        return AucRocMetric.Result.fromXContent(parser);
+    }
+
+    @Override
+    protected boolean supportsUnknownFields() {
+        return true;
+    }
+
+    @Override
+    protected Predicate<String> getRandomFieldsExcludeFilter() {
+        // allow unknown fields in the root of the object only
+        return field -> !field.isEmpty();
+    }
+}
diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/ConfusionMatrixMetricConfusionMatrixTests.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/ConfusionMatrixMetricConfusionMatrixTests.java
new file mode 100644
index 0000000000000..28eb221b318c6
--- /dev/null
+++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/ConfusionMatrixMetricConfusionMatrixTests.java
@@ -0,0 +1,47 @@
+/*
+ * Licensed to Elasticsearch under one or more contributor
+ * license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright
+ * ownership. Elasticsearch licenses this file to you under
+ * the Apache License, Version 2.0 (the "License"); you may
+ * not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.elasticsearch.client.ml;
+
+import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.ConfusionMatrixMetric;
+import org.elasticsearch.common.xcontent.XContentParser;
+import org.elasticsearch.test.AbstractXContentTestCase;
+
+import java.io.IOException;
+
+public class ConfusionMatrixMetricConfusionMatrixTests extends AbstractXContentTestCase<ConfusionMatrixMetric.ConfusionMatrix> {
+
+    static ConfusionMatrixMetric.ConfusionMatrix randomConfusionMatrix() {
+        return new ConfusionMatrixMetric.ConfusionMatrix(randomInt(), randomInt(), randomInt(), randomInt());
+    }
+
+    @Override
+    protected ConfusionMatrixMetric.ConfusionMatrix createTestInstance() {
+        return randomConfusionMatrix();
+    }
+
+    @Override
+    protected ConfusionMatrixMetric.ConfusionMatrix doParseInstance(XContentParser parser) throws IOException {
+        return ConfusionMatrixMetric.ConfusionMatrix.fromXContent(parser);
+    }
+
+    @Override
+    protected boolean supportsUnknownFields() {
+        return true;
+    }
+}
diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/ConfusionMatrixMetricResultTests.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/ConfusionMatrixMetricResultTests.java
new file mode 100644
index 0000000000000..c4b299a96b536
--- /dev/null
+++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/ConfusionMatrixMetricResultTests.java
@@ -0,0 +1,62 @@
+/*
+ * Licensed to Elasticsearch under one or more contributor
+ * license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright
+ * ownership. Elasticsearch licenses this file to you under
+ * the Apache License, Version 2.0 (the "License"); you may
+ * not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.elasticsearch.client.ml;
+
+import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.ConfusionMatrixMetric;
+import org.elasticsearch.common.xcontent.XContentParser;
+import org.elasticsearch.test.AbstractXContentTestCase;
+
+import java.io.IOException;
+import java.util.function.Predicate;
+import java.util.stream.Collectors;
+import java.util.stream.Stream;
+
+import static org.elasticsearch.client.ml.ConfusionMatrixMetricConfusionMatrixTests.randomConfusionMatrix;
+
+public class ConfusionMatrixMetricResultTests extends AbstractXContentTestCase<ConfusionMatrixMetric.Result> {
+
+    static ConfusionMatrixMetric.Result randomResult() {
+        return new ConfusionMatrixMetric.Result(
+            Stream
+                .generate(() -> randomConfusionMatrix())
+                .limit(randomIntBetween(1, 5))
+                .collect(Collectors.toMap(v -> String.valueOf(randomDouble()), v -> v))); // keys are numeric strings
+    }
+
+    @Override
+    protected ConfusionMatrixMetric.Result createTestInstance() {
+        return randomResult();
+    }
+
+    @Override
+    protected ConfusionMatrixMetric.Result doParseInstance(XContentParser parser) throws IOException {
+        return ConfusionMatrixMetric.Result.fromXContent(parser);
+    }
+
+    @Override
+    protected boolean supportsUnknownFields() {
+        return true;
+    }
+
+    @Override
+    protected Predicate<String> getRandomFieldsExcludeFilter() {
+        // disallow unknown fields in the root of the object as field names must be parsable as numbers
+        return field -> field.isEmpty();
+    }
+}
diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/DeleteDataFrameAnalyticsRequestTests.java
b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/DeleteDataFrameAnalyticsRequestTests.java
new file mode 100644
index 0000000000000..bc2ca2d954e76
--- /dev/null
+++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/DeleteDataFrameAnalyticsRequestTests.java
@@ -0,0 +1,39 @@
+/*
+ * Licensed to Elasticsearch under one or more contributor
+ * license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright
+ * ownership. Elasticsearch licenses this file to you under
+ * the Apache License, Version 2.0 (the "License"); you may
+ * not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.elasticsearch.client.ml;
+
+import org.elasticsearch.test.ESTestCase;
+
+import java.util.Optional;
+
+import static org.hamcrest.Matchers.containsString;
+
+public class DeleteDataFrameAnalyticsRequestTests extends ESTestCase {
+
+    public void testValidate_Ok() { // validate() is expected to report no error for these ids
+        assertEquals(Optional.empty(), new DeleteDataFrameAnalyticsRequest("valid-id").validate());
+        assertEquals(Optional.empty(), new DeleteDataFrameAnalyticsRequest("").validate()); // empty id passes too
+    }
+
+    public void testValidate_Failure() { // a null id is expected to yield a validation error
+        assertThat(new DeleteDataFrameAnalyticsRequest(null).validate().get().getMessage(),
+            containsString("data frame analytics id must not be null"));
+    }
+}
diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/EvaluateDataFrameResponseTests.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/EvaluateDataFrameResponseTests.java
new file mode 100644
index 0000000000000..b41d113686ccf
--- /dev/null
+++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/EvaluateDataFrameResponseTests.java
@@ -0,0 +1,76 @@
+/*
+ * Licensed to Elasticsearch under one or more contributor
+ * license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright
+ * ownership. Elasticsearch licenses this file to you under
+ * the Apache License, Version 2.0 (the "License"); you may
+ * not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */ +package org.elasticsearch.client.ml; + +import org.elasticsearch.client.ml.dataframe.evaluation.EvaluationMetric; +import org.elasticsearch.client.ml.dataframe.evaluation.MlEvaluationNamedXContentProvider; +import org.elasticsearch.common.xcontent.NamedXContentRegistry; +import org.elasticsearch.common.xcontent.XContentParser; +import org.elasticsearch.test.AbstractXContentTestCase; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; +import java.util.function.Predicate; + +/** + * Runs the {@link AbstractXContentTestCase} checks against {@link EvaluateDataFrameResponse}. + */ +public class EvaluateDataFrameResponseTests extends AbstractXContentTestCase<EvaluateDataFrameResponse> { + + /** + * Builds a response from a random five-letter string and a list of metric results; + * each of the four result kinds is added or left out on its own coin flip, so the + * list may be empty. + */ + public static EvaluateDataFrameResponse randomResponse() { + List<EvaluationMetric.Result> metrics = new ArrayList<>(); + if (randomBoolean()) { + metrics.add(AucRocMetricResultTests.randomResult()); + } + if (randomBoolean()) { + metrics.add(PrecisionMetricResultTests.randomResult()); + } + if (randomBoolean()) { + metrics.add(RecallMetricResultTests.randomResult()); + } + if (randomBoolean()) { + metrics.add(ConfusionMatrixMetricResultTests.randomResult()); + } + return new EvaluateDataFrameResponse(randomAlphaOfLength(5), metrics); + } + + @Override + protected EvaluateDataFrameResponse createTestInstance() { + return randomResponse(); + } + + @Override + protected EvaluateDataFrameResponse doParseInstance(XContentParser parser) throws IOException { + return EvaluateDataFrameResponse.fromXContent(parser); + } + + @Override + protected boolean supportsUnknownFields() { + return true; + } + + /** Excludes the root path and every path containing a dot from random-field insertion. */ + @Override + protected Predicate<String> getRandomFieldsExcludeFilter() { + // allow unknown fields in the metrics map (i.e. 
alongside named metrics like "precision" or "recall") + return field -> field.isEmpty() || field.contains("."); + } + + /** Supplies the named parsers registered by {@link MlEvaluationNamedXContentProvider}. */ + @Override + protected NamedXContentRegistry xContentRegistry() { + return new NamedXContentRegistry(new MlEvaluationNamedXContentProvider().getNamedXContentParsers()); + } +} diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/GetDataFrameAnalyticsRequestTests.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/GetDataFrameAnalyticsRequestTests.java new file mode 100644 index 0000000000000..56d87ea6bef49 --- /dev/null +++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/GetDataFrameAnalyticsRequestTests.java @@ -0,0 +1,39 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License.
+ */ + +package org.elasticsearch.client.ml; + +import org.elasticsearch.test.ESTestCase; + +import java.util.Optional; + +import static org.hamcrest.Matchers.containsString; + +/** + * Tests the {@code validate()} check of {@link GetDataFrameAnalyticsRequest}. + */ +public class GetDataFrameAnalyticsRequestTests extends ESTestCase { + + /** A single id, including the empty string, produces no validation error. */ + public void testValidate_Ok() { + assertEquals(Optional.empty(), new GetDataFrameAnalyticsRequest("valid-id").validate()); + assertEquals(Optional.empty(), new GetDataFrameAnalyticsRequest("").validate()); + } + + /** An empty id array produces a validation error whose message names the missing id. */ + public void testValidate_Failure() { + assertThat(new GetDataFrameAnalyticsRequest(new String[0]).validate().get().getMessage(), + containsString("data frame analytics id must not be null")); + } +} diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/GetDataFrameAnalyticsStatsRequestTests.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/GetDataFrameAnalyticsStatsRequestTests.java new file mode 100644 index 0000000000000..4e08d99eaa659 --- /dev/null +++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/GetDataFrameAnalyticsStatsRequestTests.java @@ -0,0 +1,39 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License.
+ */ + +package org.elasticsearch.client.ml; + +import org.elasticsearch.test.ESTestCase; + +import java.util.Optional; + +import static org.hamcrest.Matchers.containsString; + +/** + * Tests the {@code validate()} check of {@link GetDataFrameAnalyticsStatsRequest}. + */ +public class GetDataFrameAnalyticsStatsRequestTests extends ESTestCase { + + /** A single id, including the empty string, produces no validation error. */ + public void testValidate_Ok() { + assertEquals(Optional.empty(), new GetDataFrameAnalyticsStatsRequest("valid-id").validate()); + assertEquals(Optional.empty(), new GetDataFrameAnalyticsStatsRequest("").validate()); + } + + /** An empty id array produces a validation error whose message names the missing id. */ + public void testValidate_Failure() { + assertThat(new GetDataFrameAnalyticsStatsRequest(new String[0]).validate().get().getMessage(), + containsString("data frame analytics id must not be null")); + } +} diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/PrecisionMetricResultTests.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/PrecisionMetricResultTests.java new file mode 100644 index 0000000000000..607adacebb827 --- /dev/null +++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/PrecisionMetricResultTests.java @@ -0,0 +1,60 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License.
+ */ +package org.elasticsearch.client.ml; + +import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.PrecisionMetric; +import org.elasticsearch.common.xcontent.XContentParser; +import org.elasticsearch.test.AbstractXContentTestCase; + +import java.io.IOException; +import java.util.function.Predicate; +import java.util.stream.Collectors; +import java.util.stream.Stream; + +/** + * Runs the {@link AbstractXContentTestCase} checks against {@link PrecisionMetric.Result}. + */ +public class PrecisionMetricResultTests extends AbstractXContentTestCase<PrecisionMetric.Result> { + + /** + * Builds a result from a map of one to five random doubles, each keyed by the + * string form of another random double. + */ + static PrecisionMetric.Result randomResult() { + return new PrecisionMetric.Result( + Stream + .generate(() -> randomDouble()) + .limit(randomIntBetween(1, 5)) + .collect(Collectors.toMap(v -> String.valueOf(randomDouble()), v -> v))); + } + + @Override + protected PrecisionMetric.Result createTestInstance() { + return randomResult(); + } + + @Override + protected PrecisionMetric.Result doParseInstance(XContentParser parser) throws IOException { + return PrecisionMetric.Result.fromXContent(parser); + } + + @Override + protected boolean supportsUnknownFields() { + return true; + } + + /** Excludes only the root path from random-field insertion. */ + @Override + protected Predicate<String> getRandomFieldsExcludeFilter() { + // disallow unknown fields in the root of the object as field names must be parsable as numbers + return field -> field.isEmpty(); + } +} diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/PutDataFrameAnalyticsRequestTests.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/PutDataFrameAnalyticsRequestTests.java new file mode 100644 index 0000000000000..19bc68fa36118 --- /dev/null +++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/PutDataFrameAnalyticsRequestTests.java @@ -0,0 +1,74 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership.
Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.elasticsearch.client.ml; + +import org.elasticsearch.client.ValidationException; +import org.elasticsearch.client.ml.dataframe.DataFrameAnalyticsConfig; +import org.elasticsearch.client.ml.dataframe.DataFrameAnalyticsConfigTests; +import org.elasticsearch.client.ml.dataframe.MlDataFrameAnalysisNamedXContentProvider; +import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.common.xcontent.NamedXContentRegistry; +import org.elasticsearch.common.xcontent.XContentParser; +import org.elasticsearch.search.SearchModule; +import org.elasticsearch.test.AbstractXContentTestCase; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.Optional; + +import static org.hamcrest.Matchers.containsString; + +public class PutDataFrameAnalyticsRequestTests extends AbstractXContentTestCase { + + public void testValidate_Ok() { + assertFalse(createTestInstance().validate().isPresent()); + } + + public void testValidate_Failure() { + Optional exception = new PutDataFrameAnalyticsRequest(null).validate(); + assertTrue(exception.isPresent()); + assertThat(exception.get().getMessage(), containsString("put requires a non-null data frame analytics config")); + } + + @Override + protected PutDataFrameAnalyticsRequest createTestInstance() { + return new 
PutDataFrameAnalyticsRequest(DataFrameAnalyticsConfigTests.randomDataFrameAnalyticsConfig()); + } + + @Override + protected PutDataFrameAnalyticsRequest doParseInstance(XContentParser parser) throws IOException { + return new PutDataFrameAnalyticsRequest(DataFrameAnalyticsConfig.fromXContent(parser)); + } + + @Override + protected boolean supportsUnknownFields() { + return false; + } + + @Override + protected NamedXContentRegistry xContentRegistry() { + List namedXContent = new ArrayList<>(); + namedXContent.addAll(new SearchModule(Settings.EMPTY, false, Collections.emptyList()).getNamedXContents()); + namedXContent.addAll(new MlDataFrameAnalysisNamedXContentProvider().getNamedXContentParsers()); + return new NamedXContentRegistry(namedXContent); + } +} diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/RecallMetricResultTests.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/RecallMetricResultTests.java new file mode 100644 index 0000000000000..138875007e30d --- /dev/null +++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/RecallMetricResultTests.java @@ -0,0 +1,60 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.elasticsearch.client.ml; + +import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.RecallMetric; +import org.elasticsearch.common.xcontent.XContentParser; +import org.elasticsearch.test.AbstractXContentTestCase; + +import java.io.IOException; +import java.util.function.Predicate; +import java.util.stream.Collectors; +import java.util.stream.Stream; + +/** + * Runs the {@link AbstractXContentTestCase} checks against {@link RecallMetric.Result}. + */ +public class RecallMetricResultTests extends AbstractXContentTestCase<RecallMetric.Result> { + + /** + * Builds a result from a map of one to five random doubles, each keyed by the + * string form of another random double. + */ + static RecallMetric.Result randomResult() { + return new RecallMetric.Result( + Stream + .generate(() -> randomDouble()) + .limit(randomIntBetween(1, 5)) + .collect(Collectors.toMap(v -> String.valueOf(randomDouble()), v -> v))); + } + + @Override + protected RecallMetric.Result createTestInstance() { + return randomResult(); + } + + @Override + protected RecallMetric.Result doParseInstance(XContentParser parser) throws IOException { + return RecallMetric.Result.fromXContent(parser); + } + + @Override + protected boolean supportsUnknownFields() { + return true; + } + + /** Excludes only the root path from random-field insertion. */ + @Override + protected Predicate<String> getRandomFieldsExcludeFilter() { + // disallow unknown fields in the root of the object as field names must be parsable as numbers + return field -> field.isEmpty(); + } +} diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/StartDataFrameAnalyticsRequestTests.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/StartDataFrameAnalyticsRequestTests.java new file mode 100644 index 0000000000000..6e43b50bcd12b --- /dev/null +++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/StartDataFrameAnalyticsRequestTests.java @@ -0,0 +1,43 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership.
Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.elasticsearch.client.ml; + +import org.elasticsearch.common.unit.TimeValue; +import org.elasticsearch.test.ESTestCase; + +import java.util.Optional; + +import static org.hamcrest.Matchers.containsString; + +/** + * Tests the {@code validate()} check of {@link StartDataFrameAnalyticsRequest}. + */ +public class StartDataFrameAnalyticsRequestTests extends ESTestCase { + + /** A non-null id validates whether the timeout is left unset, set to null, or set to zero. */ + public void testValidate_Ok() { + assertEquals(Optional.empty(), new StartDataFrameAnalyticsRequest("foo").validate()); + assertEquals(Optional.empty(), new StartDataFrameAnalyticsRequest("foo").setTimeout(null).validate()); + assertEquals(Optional.empty(), new StartDataFrameAnalyticsRequest("foo").setTimeout(TimeValue.ZERO).validate()); + } + + /** A null id produces a validation error, with or without a timeout set. */ + public void testValidate_Failure() { + assertThat(new StartDataFrameAnalyticsRequest(null).validate().get().getMessage(), + containsString("data frame analytics id must not be null")); + assertThat(new StartDataFrameAnalyticsRequest(null).setTimeout(TimeValue.ZERO).validate().get().getMessage(), + containsString("data frame analytics id must not be null")); + } +} diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/StopDataFrameAnalyticsRequestTests.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/StopDataFrameAnalyticsRequestTests.java new file mode 100644 index 0000000000000..57af2083743ae --- /dev/null +++ 
b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/StopDataFrameAnalyticsRequestTests.java @@ -0,0 +1,43 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.elasticsearch.client.ml; + +import org.elasticsearch.common.unit.TimeValue; +import org.elasticsearch.test.ESTestCase; + +import java.util.Optional; + +import static org.hamcrest.Matchers.containsString; + +/** + * Tests the {@code validate()} check of {@link StopDataFrameAnalyticsRequest}. + */ +public class StopDataFrameAnalyticsRequestTests extends ESTestCase { + + /** A non-null id validates whether the timeout is left unset, set to null, or set to zero. */ + public void testValidate_Ok() { + assertEquals(Optional.empty(), new StopDataFrameAnalyticsRequest("foo").validate()); + assertEquals(Optional.empty(), new StopDataFrameAnalyticsRequest("foo").setTimeout(null).validate()); + assertEquals(Optional.empty(), new StopDataFrameAnalyticsRequest("foo").setTimeout(TimeValue.ZERO).validate()); + } + + /** A null id produces a validation error, with or without a timeout set. */ + public void testValidate_Failure() { + assertThat(new StopDataFrameAnalyticsRequest(null).validate().get().getMessage(), + containsString("data frame analytics id must not be null")); + assertThat(new StopDataFrameAnalyticsRequest(null).setTimeout(TimeValue.ZERO).validate().get().getMessage(), + containsString("data frame analytics id must not be null")); + } +} diff --git 
a/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/StopDataFrameAnalyticsResponseTests.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/StopDataFrameAnalyticsResponseTests.java new file mode 100644 index 0000000000000..55ef1aed7534a --- /dev/null +++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/StopDataFrameAnalyticsResponseTests.java @@ -0,0 +1,42 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.elasticsearch.client.ml; + +import org.elasticsearch.common.xcontent.XContentParser; +import org.elasticsearch.test.AbstractXContentTestCase; + +import java.io.IOException; + +/** + * Runs the {@link AbstractXContentTestCase} checks against {@link StopDataFrameAnalyticsResponse}. + */ +public class StopDataFrameAnalyticsResponseTests extends AbstractXContentTestCase<StopDataFrameAnalyticsResponse> { + + /** Builds a response around a random boolean. */ + @Override + protected StopDataFrameAnalyticsResponse createTestInstance() { + return new StopDataFrameAnalyticsResponse(randomBoolean()); + } + + @Override + protected StopDataFrameAnalyticsResponse doParseInstance(XContentParser parser) throws IOException { + return StopDataFrameAnalyticsResponse.fromXContent(parser); + } + + @Override + protected boolean supportsUnknownFields() { + return true; + } +} diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/DataFrameAnalyticsConfigTests.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/DataFrameAnalyticsConfigTests.java new file mode 100644 index 0000000000000..4eba642401054 --- /dev/null +++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/DataFrameAnalyticsConfigTests.java @@ -0,0 +1,88 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License.
+ */ + +package org.elasticsearch.client.ml.dataframe; + +import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.common.unit.ByteSizeUnit; +import org.elasticsearch.common.unit.ByteSizeValue; +import org.elasticsearch.common.xcontent.NamedXContentRegistry; +import org.elasticsearch.common.xcontent.XContentParser; +import org.elasticsearch.search.SearchModule; +import org.elasticsearch.search.fetch.subphase.FetchSourceContext; +import org.elasticsearch.test.AbstractXContentTestCase; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.function.Predicate; + +import static org.elasticsearch.client.ml.dataframe.DataFrameAnalyticsSourceTests.randomSourceConfig; +import static org.elasticsearch.client.ml.dataframe.DataFrameAnalyticsDestTests.randomDestConfig; +import static org.elasticsearch.client.ml.dataframe.OutlierDetectionTests.randomOutlierDetection; + +/** + * Runs the {@link AbstractXContentTestCase} checks against {@link DataFrameAnalyticsConfig}. + */ +public class DataFrameAnalyticsConfigTests extends AbstractXContentTestCase<DataFrameAnalyticsConfig> { + + /** + * Builds a config with a random id of 1 to 10 letters and random source, dest and + * outlier-detection analysis; analyzed fields and a model memory limit of 1 to 16 + * MB or GB are each set on an independent coin flip. + */ + public static DataFrameAnalyticsConfig randomDataFrameAnalyticsConfig() { + DataFrameAnalyticsConfig.Builder builder = + DataFrameAnalyticsConfig.builder(randomAlphaOfLengthBetween(1, 10)) + .setSource(randomSourceConfig()) + .setDest(randomDestConfig()) + .setAnalysis(randomOutlierDetection()); + if (randomBoolean()) { + builder.setAnalyzedFields(new FetchSourceContext(true, + generateRandomStringArray(10, 10, false, false), + generateRandomStringArray(10, 10, false, false))); + } + if (randomBoolean()) { + builder.setModelMemoryLimit(new ByteSizeValue(randomIntBetween(1, 16), randomFrom(ByteSizeUnit.MB, ByteSizeUnit.GB))); + } + return builder.build(); + } + + @Override + protected DataFrameAnalyticsConfig createTestInstance() { + return randomDataFrameAnalyticsConfig(); + } + + @Override + protected DataFrameAnalyticsConfig doParseInstance(XContentParser parser) throws IOException { + return 
DataFrameAnalyticsConfig.fromXContent(parser); + } + + @Override + protected boolean supportsUnknownFields() { + return true; + } + + /** Excludes every non-root path from random-field insertion. */ + @Override + protected Predicate<String> getRandomFieldsExcludeFilter() { + // allow unknown fields in the root of the object only + return field -> !field.isEmpty(); + } + + /** Combines the {@link SearchModule} named parsers with those of {@link MlDataFrameAnalysisNamedXContentProvider}. */ + @Override + protected NamedXContentRegistry xContentRegistry() { + List<NamedXContentRegistry.Entry> namedXContent = new ArrayList<>(); + namedXContent.addAll(new SearchModule(Settings.EMPTY, false, Collections.emptyList()).getNamedXContents()); + namedXContent.addAll(new MlDataFrameAnalysisNamedXContentProvider().getNamedXContentParsers()); + return new NamedXContentRegistry(namedXContent); + } +} diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/DataFrameAnalyticsDestTests.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/DataFrameAnalyticsDestTests.java new file mode 100644 index 0000000000000..dce7ca5204d57 --- /dev/null +++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/DataFrameAnalyticsDestTests.java @@ -0,0 +1,50 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License.
+ */ + +package org.elasticsearch.client.ml.dataframe; + +import org.elasticsearch.common.xcontent.XContentParser; +import org.elasticsearch.test.AbstractXContentTestCase; + +import java.io.IOException; + +/** + * Runs the {@link AbstractXContentTestCase} checks against {@link DataFrameAnalyticsDest}. + */ +public class DataFrameAnalyticsDestTests extends AbstractXContentTestCase<DataFrameAnalyticsDest> { + + /** + * Builds a dest config with a random index name of 1 to 10 letters; the results + * field is null or a random name of 1 to 10 letters, chosen by coin flip. + */ + public static DataFrameAnalyticsDest randomDestConfig() { + return DataFrameAnalyticsDest.builder() + .setIndex(randomAlphaOfLengthBetween(1, 10)) + .setResultsField(randomBoolean() ? null : randomAlphaOfLengthBetween(1, 10)) + .build(); + } + + @Override + protected DataFrameAnalyticsDest doParseInstance(XContentParser parser) throws IOException { + return DataFrameAnalyticsDest.fromXContent(parser); + } + + @Override + protected boolean supportsUnknownFields() { + return true; + } + + @Override + protected DataFrameAnalyticsDest createTestInstance() { + return randomDestConfig(); + } +} diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/DataFrameAnalyticsSourceTests.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/DataFrameAnalyticsSourceTests.java new file mode 100644 index 0000000000000..eb254fd23de09 --- /dev/null +++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/DataFrameAnalyticsSourceTests.java @@ -0,0 +1,70 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License.
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.elasticsearch.client.ml.dataframe; + +import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.common.xcontent.NamedXContentRegistry; +import org.elasticsearch.common.xcontent.XContentParser; +import org.elasticsearch.search.SearchModule; +import org.elasticsearch.test.AbstractXContentTestCase; + +import java.io.IOException; +import java.util.function.Predicate; + +import static java.util.Collections.emptyList; +import static org.elasticsearch.client.ml.dataframe.QueryConfigTests.randomQueryConfig; + + +/** + * Runs the {@link AbstractXContentTestCase} checks against {@link DataFrameAnalyticsSource}. + */ +public class DataFrameAnalyticsSourceTests extends AbstractXContentTestCase<DataFrameAnalyticsSource> { + + /** + * Builds a source config with a random index name of 1 to 10 letters; the query + * config is null or random, chosen by coin flip. + */ + public static DataFrameAnalyticsSource randomSourceConfig() { + return DataFrameAnalyticsSource.builder() + .setIndex(randomAlphaOfLengthBetween(1, 10)) + .setQueryConfig(randomBoolean() ? 
null : randomQueryConfig()) + .build(); + } + + @Override + protected DataFrameAnalyticsSource doParseInstance(XContentParser parser) throws IOException { + return DataFrameAnalyticsSource.fromXContent(parser); + } + + @Override + protected boolean supportsUnknownFields() { + return true; + } + + /** Excludes every non-root path from random-field insertion. */ + @Override + protected Predicate<String> getRandomFieldsExcludeFilter() { + // allow unknown fields in the root of the object only as QueryConfig stores a Map + return field -> !field.isEmpty(); + } + + @Override + protected DataFrameAnalyticsSource createTestInstance() { + return randomSourceConfig(); + } + + /** Supplies the named parsers of a {@link SearchModule} built with empty settings and no plugins. */ + @Override + protected NamedXContentRegistry xContentRegistry() { + SearchModule searchModule = new SearchModule(Settings.EMPTY, false, emptyList()); + return new NamedXContentRegistry(searchModule.getNamedXContents()); + } +} diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/DataFrameAnalyticsStatsTests.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/DataFrameAnalyticsStatsTests.java new file mode 100644 index 0000000000000..ed6e24f754d19 --- /dev/null +++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/DataFrameAnalyticsStatsTests.java @@ -0,0 +1,66 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied.
See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.elasticsearch.client.ml.dataframe; + +import org.elasticsearch.client.ml.NodeAttributesTests; +import org.elasticsearch.common.xcontent.XContentBuilder; +import org.elasticsearch.test.ESTestCase; + +import java.io.IOException; + +import static org.elasticsearch.test.AbstractXContentTestCase.xContentTester; + +/** + * Tests parsing of {@link DataFrameAnalyticsStats} using a hand-written serializer, + * {@link #toXContent}, to produce the content that is parsed back. + */ +public class DataFrameAnalyticsStatsTests extends ESTestCase { + + /** + * Runs the {@code xContentTester} harness with unknown fields enabled, except that + * no random fields are inserted at paths starting with {@code node.attributes}. + */ + public void testFromXContent() throws IOException { + xContentTester(this::createParser, + DataFrameAnalyticsStatsTests::randomDataFrameAnalyticsStats, + DataFrameAnalyticsStatsTests::toXContent, + DataFrameAnalyticsStats::fromXContent) + .supportsUnknownFields(true) + .randomFieldsExcludeFilter(field -> field.startsWith("node.attributes")) + .test(); + } + + /** + * Builds stats with a random id and state; progress percent, node and assignment + * explanation are each null or random, chosen by independent coin flips. + */ + public static DataFrameAnalyticsStats randomDataFrameAnalyticsStats() { + return new DataFrameAnalyticsStats( + randomAlphaOfLengthBetween(1, 10), + randomFrom(DataFrameAnalyticsState.values()), + randomBoolean() ? null : randomIntBetween(0, 100), + randomBoolean() ? null : NodeAttributesTests.createRandom(), + randomBoolean() ? 
null : randomAlphaOfLengthBetween(1, 20)); + } + + /** + * Writes {@code stats} as one object: id and state always, and each of progress + * percent, node and assignment explanation only when it is non-null. + */ + public static void toXContent(DataFrameAnalyticsStats stats, XContentBuilder builder) throws IOException { + builder.startObject(); + builder.field(DataFrameAnalyticsStats.ID.getPreferredName(), stats.getId()); + builder.field(DataFrameAnalyticsStats.STATE.getPreferredName(), stats.getState().value()); + if (stats.getProgressPercent() != null) { + builder.field(DataFrameAnalyticsStats.PROGRESS_PERCENT.getPreferredName(), stats.getProgressPercent()); + } + if (stats.getNode() != null) { + builder.field(DataFrameAnalyticsStats.NODE.getPreferredName(), stats.getNode()); + } + if (stats.getAssignmentExplanation() != null) { + builder.field(DataFrameAnalyticsStats.ASSIGNMENT_EXPLANATION.getPreferredName(), stats.getAssignmentExplanation()); + } + builder.endObject(); + } +} diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/OutlierDetectionTests.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/OutlierDetectionTests.java new file mode 100644 index 0000000000000..de110d92fdee1 --- /dev/null +++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/OutlierDetectionTests.java @@ -0,0 +1,73 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied.
See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.elasticsearch.client.ml.dataframe; + +import org.elasticsearch.common.xcontent.XContentParser; +import org.elasticsearch.test.AbstractXContentTestCase; + +import java.io.IOException; + +import static org.hamcrest.Matchers.closeTo; +import static org.hamcrest.Matchers.equalTo; + +public class OutlierDetectionTests extends AbstractXContentTestCase { + + public static OutlierDetection randomOutlierDetection() { + return OutlierDetection.builder() + .setNNeighbors(randomBoolean() ? null : randomIntBetween(1, 20)) + .setMethod(randomBoolean() ? null : randomFrom(OutlierDetection.Method.values())) + .setMinScoreToWriteFeatureInfluence(randomBoolean() ? null : randomDoubleBetween(0.0, 1.0, true)) + .build(); + } + + @Override + protected OutlierDetection doParseInstance(XContentParser parser) throws IOException { + return OutlierDetection.fromXContent(parser); + } + + @Override + protected boolean supportsUnknownFields() { + return true; + } + + @Override + protected OutlierDetection createTestInstance() { + return randomOutlierDetection(); + } + + public void testGetParams_GivenDefaults() { + OutlierDetection outlierDetection = OutlierDetection.createDefault(); + assertNull(outlierDetection.getNNeighbors()); + assertNull(outlierDetection.getMethod()); + assertNull(outlierDetection.getMinScoreToWriteFeatureInfluence()); + } + + public void testGetParams_GivenExplicitValues() { + OutlierDetection outlierDetection = + OutlierDetection.builder() + .setNNeighbors(42) + .setMethod(OutlierDetection.Method.LDOF) + .setMinScoreToWriteFeatureInfluence(0.5) + .build(); + assertThat(outlierDetection.getNNeighbors(), equalTo(42)); + assertThat(outlierDetection.getMethod(), equalTo(OutlierDetection.Method.LDOF)); + assertThat(outlierDetection.getMinScoreToWriteFeatureInfluence(), closeTo(0.5, 1E-9)); + } +} diff --git 
a/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/QueryConfigTests.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/QueryConfigTests.java new file mode 100644 index 0000000000000..1e66445100b3e --- /dev/null +++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/QueryConfigTests.java @@ -0,0 +1,62 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.elasticsearch.client.ml.dataframe; + +import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.common.xcontent.NamedXContentRegistry; +import org.elasticsearch.common.xcontent.XContentParser; +import org.elasticsearch.index.query.MatchAllQueryBuilder; +import org.elasticsearch.index.query.MatchNoneQueryBuilder; +import org.elasticsearch.index.query.QueryBuilder; +import org.elasticsearch.search.SearchModule; +import org.elasticsearch.test.AbstractXContentTestCase; + +import java.io.IOException; + +import static java.util.Collections.emptyList; + +public class QueryConfigTests extends AbstractXContentTestCase { + + public static QueryConfig randomQueryConfig() { + QueryBuilder queryBuilder = randomBoolean() ? 
new MatchAllQueryBuilder() : new MatchNoneQueryBuilder(); + return new QueryConfig(queryBuilder); + } + + @Override + protected QueryConfig createTestInstance() { + return randomQueryConfig(); + } + + @Override + protected QueryConfig doParseInstance(XContentParser parser) throws IOException { + return QueryConfig.fromXContent(parser); + } + + @Override + protected boolean supportsUnknownFields() { + return false; + } + + @Override + protected NamedXContentRegistry xContentRegistry() { + SearchModule searchModule = new SearchModule(Settings.EMPTY, false, emptyList()); + return new NamedXContentRegistry(searchModule.getNamedXContents()); + } +} diff --git a/docs/java-rest/high-level/ml/delete-data-frame-analytics.asciidoc b/docs/java-rest/high-level/ml/delete-data-frame-analytics.asciidoc new file mode 100644 index 0000000000000..2e5ade37107cf --- /dev/null +++ b/docs/java-rest/high-level/ml/delete-data-frame-analytics.asciidoc @@ -0,0 +1,28 @@ +-- +:api: delete-data-frame-analytics +:request: DeleteDataFrameAnalyticsRequest +:response: AcknowledgedResponse +-- +[id="{upid}-{api}"] +=== Delete Data Frame Analytics API + +The Delete Data Frame Analytics API is used to delete an existing {dataframe-analytics-config}. +The API accepts a +{request}+ object as a request and returns a +{response}+. + +[id="{upid}-{api}-request"] +==== Delete Data Frame Analytics Request + +A +{request}+ object requires a {dataframe-analytics-config} id. + +["source","java",subs="attributes,callouts,macros"] +--------------------------------------------------- +include-tagged::{doc-tests-file}[{api}-request] +--------------------------------------------------- +<1> Constructing a new request referencing an existing {dataframe-analytics-config} + +include::../execution.asciidoc[] + +[id="{upid}-{api}-response"] +==== Response + +The returned +{response}+ object acknowledges the {dataframe-analytics-config} deletion. 
diff --git a/docs/java-rest/high-level/ml/evaluate-data-frame.asciidoc b/docs/java-rest/high-level/ml/evaluate-data-frame.asciidoc new file mode 100644 index 0000000000000..660603d2e38e7 --- /dev/null +++ b/docs/java-rest/high-level/ml/evaluate-data-frame.asciidoc @@ -0,0 +1,45 @@ +-- +:api: evaluate-data-frame +:request: EvaluateDataFrameRequest +:response: EvaluateDataFrameResponse +-- +[id="{upid}-{api}"] +=== Evaluate Data Frame API + +The Evaluate Data Frame API is used to evaluate an ML algorithm that ran on a {dataframe}. +The API accepts an +{request}+ object and returns an +{response}+. + +[id="{upid}-{api}-request"] +==== Evaluate Data Frame Request + +["source","java",subs="attributes,callouts,macros"] +-------------------------------------------------- +include-tagged::{doc-tests-file}[{api}-request] +-------------------------------------------------- +<1> Constructing a new evaluation request +<2> Reference to an existing index +<3> Kind of evaluation to perform +<4> Name of the field in the index. Its value denotes the actual (i.e. ground truth) label for an example. Must be either true or false +<5> Name of the field in the index. Its value denotes the probability (as per some ML algorithm) of the example being classified as positive +<6> The remaining parameters are the metrics to be calculated based on the two fields described above. +<7> https://en.wikipedia.org/wiki/Precision_and_recall[Precision] calculated at thresholds: 0.4, 0.5 and 0.6 +<8> https://en.wikipedia.org/wiki/Precision_and_recall[Recall] calculated at thresholds: 0.5 and 0.7 +<9> https://en.wikipedia.org/wiki/Confusion_matrix[Confusion matrix] calculated at threshold 0.5 +<10> https://en.wikipedia.org/wiki/Receiver_operating_characteristic#Area_under_the_curve[AuC ROC] calculated and the curve points returned + +include::../execution.asciidoc[] + +[id="{upid}-{api}-response"] +==== Response + +The returned +{response}+ contains the requested evaluation metrics. 
+ +["source","java",subs="attributes,callouts,macros"] +-------------------------------------------------- +include-tagged::{doc-tests-file}[{api}-response] +-------------------------------------------------- +<1> Fetching all the calculated metrics results +<2> Fetching precision metric by name +<3> Fetching precision at a given (0.4) threshold +<4> Fetching confusion matrix metric by name +<5> Fetching confusion matrix at a given (0.5) threshold \ No newline at end of file diff --git a/docs/java-rest/high-level/ml/get-data-frame-analytics-stats.asciidoc b/docs/java-rest/high-level/ml/get-data-frame-analytics-stats.asciidoc new file mode 100644 index 0000000000000..e1047e9b3e002 --- /dev/null +++ b/docs/java-rest/high-level/ml/get-data-frame-analytics-stats.asciidoc @@ -0,0 +1,34 @@ +-- +:api: get-data-frame-analytics-stats +:request: GetDataFrameAnalyticsStatsRequest +:response: GetDataFrameAnalyticsStatsResponse +-- +[id="{upid}-{api}"] +=== Get Data Frame Analytics Stats API + +The Get Data Frame Analytics Stats API is used to read the operational statistics of one or more {dataframe-analytics-config}s. +The API accepts a +{request}+ object and returns a +{response}+. + +[id="{upid}-{api}-request"] +==== Get Data Frame Analytics Stats Request + +A +{request}+ requires either a {dataframe-analytics-config} id, a comma separated list of ids or +the special wildcard `_all` to get the statistics for all {dataframe-analytics-config}s. + +["source","java",subs="attributes,callouts,macros"] +-------------------------------------------------- +include-tagged::{doc-tests-file}[{api}-request] +-------------------------------------------------- +<1> Constructing a new GET Stats request referencing an existing {dataframe-analytics-config} + +include::../execution.asciidoc[] + +[id="{upid}-{api}-response"] +==== Response + +The returned +{response}+ contains the requested {dataframe-analytics-config} statistics.
+ +["source","java",subs="attributes,callouts,macros"] +-------------------------------------------------- +include-tagged::{doc-tests-file}[{api}-response] +-------------------------------------------------- \ No newline at end of file diff --git a/docs/java-rest/high-level/ml/get-data-frame-analytics.asciidoc b/docs/java-rest/high-level/ml/get-data-frame-analytics.asciidoc new file mode 100644 index 0000000000000..c6d368efbcae9 --- /dev/null +++ b/docs/java-rest/high-level/ml/get-data-frame-analytics.asciidoc @@ -0,0 +1,34 @@ +-- +:api: get-data-frame-analytics +:request: GetDataFrameAnalyticsRequest +:response: GetDataFrameAnalyticsResponse +-- +[id="{upid}-{api}"] +=== Get Data Frame Analytics API + +The Get Data Frame Analytics API is used to get one or more {dataframe-analytics-config}s. +The API accepts a +{request}+ object and returns a +{response}+. + +[id="{upid}-{api}-request"] +==== Get Data Frame Analytics Request + +A +{request}+ requires either a {dataframe-analytics-config} id, a comma separated list of ids or +the special wildcard `_all` to get all {dataframe-analytics-config}s. + +["source","java",subs="attributes,callouts,macros"] +-------------------------------------------------- +include-tagged::{doc-tests-file}[{api}-request] +-------------------------------------------------- +<1> Constructing a new GET request referencing an existing {dataframe-analytics-config} + +include::../execution.asciidoc[] + +[id="{upid}-{api}-response"] +==== Response + +The returned +{response}+ contains the requested {dataframe-analytics-config}s. 
+ +["source","java",subs="attributes,callouts,macros"] +-------------------------------------------------- +include-tagged::{doc-tests-file}[{api}-response] +-------------------------------------------------- diff --git a/docs/java-rest/high-level/ml/put-data-frame-analytics.asciidoc b/docs/java-rest/high-level/ml/put-data-frame-analytics.asciidoc new file mode 100644 index 0000000000000..05fbd5bc3922a --- /dev/null +++ b/docs/java-rest/high-level/ml/put-data-frame-analytics.asciidoc @@ -0,0 +1,115 @@ +-- +:api: put-data-frame-analytics +:request: PutDataFrameAnalyticsRequest +:response: PutDataFrameAnalyticsResponse +-- +[id="{upid}-{api}"] +=== Put Data Frame Analytics API + +The Put Data Frame Analytics API is used to create a new {dataframe-analytics-config}. +The API accepts a +{request}+ object as a request and returns a +{response}+. + +[id="{upid}-{api}-request"] +==== Put Data Frame Analytics Request + +A +{request}+ requires the following argument: + +["source","java",subs="attributes,callouts,macros"] +-------------------------------------------------- +include-tagged::{doc-tests-file}[{api}-request] +-------------------------------------------------- +<1> The configuration of the {dataframe-job} to create + +[id="{upid}-{api}-config"] +==== Data Frame Analytics Configuration + +The `DataFrameAnalyticsConfig` object contains all the details about the {dataframe-job} +configuration and contains the following arguments: + +["source","java",subs="attributes,callouts,macros"] +-------------------------------------------------- +include-tagged::{doc-tests-file}[{api}-config] +-------------------------------------------------- +<1> The {dataframe-analytics-config} id +<2> The source index and query from which to gather data +<3> The destination index +<4> The analysis to be performed +<5> The fields to be included in / excluded from the analysis +<6> The memory limit for the model created as part of the analysis process + +[id="{upid}-{api}-query-config"] + 
+==== SourceConfig + +The index and the query from which to collect data. + +["source","java",subs="attributes,callouts,macros"] +-------------------------------------------------- +include-tagged::{doc-tests-file}[{api}-source-config] +-------------------------------------------------- +<1> Constructing a new DataFrameAnalyticsSource +<2> The source index +<3> The query from which to gather the data. If query is not set, a `match_all` query is used by default. + +===== QueryConfig + +The query with which to select data from the source. + +["source","java",subs="attributes,callouts,macros"] +-------------------------------------------------- +include-tagged::{doc-tests-file}[{api}-query-config] +-------------------------------------------------- + +==== DestinationConfig + +The index to which data should be written by the {dataframe-job}. + +["source","java",subs="attributes,callouts,macros"] +-------------------------------------------------- +include-tagged::{doc-tests-file}[{api}-dest-config] +-------------------------------------------------- +<1> Constructing a new DataFrameAnalyticsDest +<2> The destination index + +==== Analysis + +The analysis to be performed. +Currently, only one analysis is supported: +OutlierDetection+. 
+ ++OutlierDetection+ analysis can be created in one of two ways: + +["source","java",subs="attributes,callouts,macros"] +-------------------------------------------------- +include-tagged::{doc-tests-file}[{api}-analysis-default] +-------------------------------------------------- +<1> Constructing a new OutlierDetection object with default strategy to determine outliers + +or +["source","java",subs="attributes,callouts,macros"] +-------------------------------------------------- +include-tagged::{doc-tests-file}[{api}-analysis-customized] +-------------------------------------------------- +<1> Constructing a new OutlierDetection object +<2> The method used to perform the analysis +<3> Number of neighbors taken into account during analysis + +==== Analyzed fields + +FetchSourceContext object containing fields to be included in / excluded from the analysis. + +["source","java",subs="attributes,callouts,macros"] +-------------------------------------------------- +include-tagged::{doc-tests-file}[{api}-analyzed-fields] +-------------------------------------------------- + +include::../execution.asciidoc[] + +[id="{upid}-{api}-response"] +==== Response + +The returned +{response}+ contains the newly created {dataframe-analytics-config}.
+ +["source","java",subs="attributes,callouts,macros"] +-------------------------------------------------- +include-tagged::{doc-tests-file}[{api}-response] +-------------------------------------------------- \ No newline at end of file diff --git a/docs/java-rest/high-level/ml/start-data-frame-analytics.asciidoc b/docs/java-rest/high-level/ml/start-data-frame-analytics.asciidoc new file mode 100644 index 0000000000000..610607daba1f8 --- /dev/null +++ b/docs/java-rest/high-level/ml/start-data-frame-analytics.asciidoc @@ -0,0 +1,28 @@ +-- +:api: start-data-frame-analytics +:request: StartDataFrameAnalyticsRequest +:response: AcknowledgedResponse +-- +[id="{upid}-{api}"] +=== Start Data Frame Analytics API + +The Start Data Frame Analytics API is used to start an existing {dataframe-analytics-config}. +It accepts a +{request}+ object and responds with a +{response}+ object. + +[id="{upid}-{api}-request"] +==== Start Data Frame Analytics Request + +A +{request}+ object requires a {dataframe-analytics-config} id. + +["source","java",subs="attributes,callouts,macros"] +--------------------------------------------------- +include-tagged::{doc-tests-file}[{api}-request] +--------------------------------------------------- +<1> Constructing a new start request referencing an existing {dataframe-analytics-config} + +include::../execution.asciidoc[] + +[id="{upid}-{api}-response"] +==== Response + +The returned +{response}+ object acknowledges the {dataframe-job} has started. 
\ No newline at end of file diff --git a/docs/java-rest/high-level/ml/stop-data-frame-analytics.asciidoc b/docs/java-rest/high-level/ml/stop-data-frame-analytics.asciidoc new file mode 100644 index 0000000000000..243c075e18b03 --- /dev/null +++ b/docs/java-rest/high-level/ml/stop-data-frame-analytics.asciidoc @@ -0,0 +1,28 @@ +-- +:api: stop-data-frame-analytics +:request: StopDataFrameAnalyticsRequest +:response: StopDataFrameAnalyticsResponse +-- +[id="{upid}-{api}"] +=== Stop Data Frame Analytics API + +The Stop Data Frame Analytics API is used to stop a running {dataframe-analytics-config}. +It accepts a +{request}+ object and responds with a +{response}+ object. + +[id="{upid}-{api}-request"] +==== Stop Data Frame Analytics Request + +A +{request}+ object requires a {dataframe-analytics-config} id. + +["source","java",subs="attributes,callouts,macros"] +--------------------------------------------------- +include-tagged::{doc-tests-file}[{api}-request] +--------------------------------------------------- +<1> Constructing a new stop request referencing an existing {dataframe-analytics-config} + +include::../execution.asciidoc[] + +[id="{upid}-{api}-response"] +==== Response + +The returned +{response}+ object acknowledges the {dataframe-job} has stopped. 
\ No newline at end of file diff --git a/docs/java-rest/high-level/supported-apis.asciidoc b/docs/java-rest/high-level/supported-apis.asciidoc index 4e28efc2941db..21ebdfab65155 100644 --- a/docs/java-rest/high-level/supported-apis.asciidoc +++ b/docs/java-rest/high-level/supported-apis.asciidoc @@ -285,6 +285,13 @@ The Java High Level REST Client supports the following Machine Learning APIs: * <<{upid}-put-calendar-job>> * <<{upid}-delete-calendar-job>> * <<{upid}-delete-calendar>> +* <<{upid}-get-data-frame-analytics>> +* <<{upid}-get-data-frame-analytics-stats>> +* <<{upid}-put-data-frame-analytics>> +* <<{upid}-delete-data-frame-analytics>> +* <<{upid}-start-data-frame-analytics>> +* <<{upid}-stop-data-frame-analytics>> +* <<{upid}-evaluate-data-frame>> * <<{upid}-put-filter>> * <<{upid}-get-filters>> * <<{upid}-update-filter>> @@ -329,6 +336,13 @@ include::ml/delete-calendar-event.asciidoc[] include::ml/put-calendar-job.asciidoc[] include::ml/delete-calendar-job.asciidoc[] include::ml/delete-calendar.asciidoc[] +include::ml/get-data-frame-analytics.asciidoc[] +include::ml/get-data-frame-analytics-stats.asciidoc[] +include::ml/put-data-frame-analytics.asciidoc[] +include::ml/delete-data-frame-analytics.asciidoc[] +include::ml/start-data-frame-analytics.asciidoc[] +include::ml/stop-data-frame-analytics.asciidoc[] +include::ml/evaluate-data-frame.asciidoc[] include::ml/put-filter.asciidoc[] include::ml/get-filters.asciidoc[] include::ml/update-filter.asciidoc[] diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/XPackClientPlugin.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/XPackClientPlugin.java index 2f3eb27ee8e96..d4a98dbdb9c87 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/XPackClientPlugin.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/XPackClientPlugin.java @@ -54,6 +54,8 @@ import org.elasticsearch.xpack.core.dataframe.action.StopDataFrameTransformAction; 
import org.elasticsearch.xpack.core.dataframe.transforms.DataFrameTransform; import org.elasticsearch.xpack.core.dataframe.transforms.DataFrameTransformState; +import org.elasticsearch.xpack.core.dataframe.transforms.SyncConfig; +import org.elasticsearch.xpack.core.dataframe.transforms.TimeSyncConfig; import org.elasticsearch.xpack.core.deprecation.DeprecationInfoAction; import org.elasticsearch.xpack.core.graph.GraphFeatureSetUsage; import org.elasticsearch.xpack.core.graph.action.GraphExploreAction; @@ -85,12 +87,14 @@ import org.elasticsearch.xpack.core.ml.action.CloseJobAction; import org.elasticsearch.xpack.core.ml.action.DeleteCalendarAction; import org.elasticsearch.xpack.core.ml.action.DeleteCalendarEventAction; +import org.elasticsearch.xpack.core.ml.action.DeleteDataFrameAnalyticsAction; import org.elasticsearch.xpack.core.ml.action.DeleteDatafeedAction; import org.elasticsearch.xpack.core.ml.action.DeleteExpiredDataAction; import org.elasticsearch.xpack.core.ml.action.DeleteFilterAction; import org.elasticsearch.xpack.core.ml.action.DeleteForecastAction; import org.elasticsearch.xpack.core.ml.action.DeleteJobAction; import org.elasticsearch.xpack.core.ml.action.DeleteModelSnapshotAction; +import org.elasticsearch.xpack.core.ml.action.EvaluateDataFrameAction; import org.elasticsearch.xpack.core.ml.action.FinalizeJobExecutionAction; import org.elasticsearch.xpack.core.ml.action.FindFileStructureAction; import org.elasticsearch.xpack.core.ml.action.FlushJobAction; @@ -99,6 +103,8 @@ import org.elasticsearch.xpack.core.ml.action.GetCalendarEventsAction; import org.elasticsearch.xpack.core.ml.action.GetCalendarsAction; import org.elasticsearch.xpack.core.ml.action.GetCategoriesAction; +import org.elasticsearch.xpack.core.ml.action.GetDataFrameAnalyticsAction; +import org.elasticsearch.xpack.core.ml.action.GetDataFrameAnalyticsStatsAction; import org.elasticsearch.xpack.core.ml.action.GetDatafeedsAction; import 
org.elasticsearch.xpack.core.ml.action.GetDatafeedsStatsAction; import org.elasticsearch.xpack.core.ml.action.GetFiltersAction; @@ -117,12 +123,15 @@ import org.elasticsearch.xpack.core.ml.action.PostDataAction; import org.elasticsearch.xpack.core.ml.action.PreviewDatafeedAction; import org.elasticsearch.xpack.core.ml.action.PutCalendarAction; +import org.elasticsearch.xpack.core.ml.action.PutDataFrameAnalyticsAction; import org.elasticsearch.xpack.core.ml.action.PutDatafeedAction; import org.elasticsearch.xpack.core.ml.action.PutFilterAction; import org.elasticsearch.xpack.core.ml.action.PutJobAction; import org.elasticsearch.xpack.core.ml.action.RevertModelSnapshotAction; import org.elasticsearch.xpack.core.ml.action.SetUpgradeModeAction; +import org.elasticsearch.xpack.core.ml.action.StartDataFrameAnalyticsAction; import org.elasticsearch.xpack.core.ml.action.StartDatafeedAction; +import org.elasticsearch.xpack.core.ml.action.StopDataFrameAnalyticsAction; import org.elasticsearch.xpack.core.ml.action.StopDatafeedAction; import org.elasticsearch.xpack.core.ml.action.UpdateCalendarJobAction; import org.elasticsearch.xpack.core.ml.action.UpdateDatafeedAction; @@ -133,6 +142,18 @@ import org.elasticsearch.xpack.core.ml.action.ValidateDetectorAction; import org.elasticsearch.xpack.core.ml.action.ValidateJobConfigAction; import org.elasticsearch.xpack.core.ml.datafeed.DatafeedState; +import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsTaskState; +import org.elasticsearch.xpack.core.ml.dataframe.analyses.DataFrameAnalysis; +import org.elasticsearch.xpack.core.ml.dataframe.analyses.OutlierDetection; +import org.elasticsearch.xpack.core.ml.dataframe.evaluation.Evaluation; +import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetricResult; +import org.elasticsearch.xpack.core.ml.dataframe.evaluation.softclassification.AucRoc; +import org.elasticsearch.xpack.core.ml.dataframe.evaluation.softclassification.BinarySoftClassification; +import 
org.elasticsearch.xpack.core.ml.dataframe.evaluation.softclassification.ConfusionMatrix; +import org.elasticsearch.xpack.core.ml.dataframe.evaluation.softclassification.Precision; +import org.elasticsearch.xpack.core.ml.dataframe.evaluation.softclassification.Recall; +import org.elasticsearch.xpack.core.ml.dataframe.evaluation.softclassification.ScoreByThresholdResult; +import org.elasticsearch.xpack.core.ml.dataframe.evaluation.softclassification.SoftClassificationMetric; import org.elasticsearch.xpack.core.ml.job.config.JobTaskState; import org.elasticsearch.xpack.core.monitoring.MonitoringFeatureSetUsage; import org.elasticsearch.xpack.core.rollup.RollupFeatureSetUsage; @@ -310,6 +331,13 @@ public List> getClientActions() { PersistJobAction.INSTANCE, FindFileStructureAction.INSTANCE, SetUpgradeModeAction.INSTANCE, + PutDataFrameAnalyticsAction.INSTANCE, + GetDataFrameAnalyticsAction.INSTANCE, + GetDataFrameAnalyticsStatsAction.INSTANCE, + DeleteDataFrameAnalyticsAction.INSTANCE, + StartDataFrameAnalyticsAction.INSTANCE, + StopDataFrameAnalyticsAction.INSTANCE, + EvaluateDataFrameAction.INSTANCE, // security ClearRealmCacheAction.INSTANCE, ClearRolesCacheAction.INSTANCE, @@ -402,11 +430,30 @@ public List getNamedWriteables() { StartDatafeedAction.DatafeedParams::new), new NamedWriteableRegistry.Entry(PersistentTaskParams.class, MlTasks.JOB_TASK_NAME, OpenJobAction.JobParams::new), + new NamedWriteableRegistry.Entry(PersistentTaskParams.class, MlTasks.DATA_FRAME_ANALYTICS_TASK_NAME, + StartDataFrameAnalyticsAction.TaskParams::new), // ML - Task states new NamedWriteableRegistry.Entry(PersistentTaskState.class, JobTaskState.NAME, JobTaskState::new), new NamedWriteableRegistry.Entry(PersistentTaskState.class, DatafeedState.NAME, DatafeedState::fromStream), + new NamedWriteableRegistry.Entry(PersistentTaskState.class, DataFrameAnalyticsTaskState.NAME, + DataFrameAnalyticsTaskState::new), new NamedWriteableRegistry.Entry(XPackFeatureSet.Usage.class, 
XPackField.MACHINE_LEARNING, MachineLearningFeatureSetUsage::new), + // ML - Data frame analytics + new NamedWriteableRegistry.Entry(DataFrameAnalysis.class, OutlierDetection.NAME.getPreferredName(), OutlierDetection::new), + // ML - Data frame evaluation + new NamedWriteableRegistry.Entry(Evaluation.class, BinarySoftClassification.NAME.getPreferredName(), + BinarySoftClassification::new), + new NamedWriteableRegistry.Entry(SoftClassificationMetric.class, AucRoc.NAME.getPreferredName(), AucRoc::new), + new NamedWriteableRegistry.Entry(SoftClassificationMetric.class, Precision.NAME.getPreferredName(), Precision::new), + new NamedWriteableRegistry.Entry(SoftClassificationMetric.class, Recall.NAME.getPreferredName(), Recall::new), + new NamedWriteableRegistry.Entry(SoftClassificationMetric.class, ConfusionMatrix.NAME.getPreferredName(), + ConfusionMatrix::new), + new NamedWriteableRegistry.Entry(EvaluationMetricResult.class, AucRoc.NAME.getPreferredName(), AucRoc.Result::new), + new NamedWriteableRegistry.Entry(EvaluationMetricResult.class, ScoreByThresholdResult.NAME, ScoreByThresholdResult::new), + new NamedWriteableRegistry.Entry(EvaluationMetricResult.class, ConfusionMatrix.NAME.getPreferredName(), + ConfusionMatrix.Result::new), + // monitoring new NamedWriteableRegistry.Entry(XPackFeatureSet.Usage.class, XPackField.MONITORING, MonitoringFeatureSetUsage::new), // security @@ -467,6 +514,7 @@ public List getNamedWriteables() { new NamedWriteableRegistry.Entry(PersistentTaskParams.class, DataFrameField.TASK_NAME, DataFrameTransform::new), new NamedWriteableRegistry.Entry(Task.Status.class, DataFrameField.TASK_NAME, DataFrameTransformState::new), new NamedWriteableRegistry.Entry(PersistentTaskState.class, DataFrameField.TASK_NAME, DataFrameTransformState::new), + new NamedWriteableRegistry.Entry(SyncConfig.class, DataFrameField.TIME_BASED_SYNC.getPreferredName(), TimeSyncConfig::new), // Vectors new NamedWriteableRegistry.Entry(XPackFeatureSet.Usage.class, 
XPackField.VECTORS, VectorsFeatureSetUsage::new) ); @@ -483,9 +531,13 @@ public List getNamedXContent() { StartDatafeedAction.DatafeedParams::fromXContent), new NamedXContentRegistry.Entry(PersistentTaskParams.class, new ParseField(MlTasks.JOB_TASK_NAME), OpenJobAction.JobParams::fromXContent), + new NamedXContentRegistry.Entry(PersistentTaskParams.class, new ParseField(MlTasks.DATA_FRAME_ANALYTICS_TASK_NAME), + StartDataFrameAnalyticsAction.TaskParams::fromXContent), // ML - Task states new NamedXContentRegistry.Entry(PersistentTaskState.class, new ParseField(DatafeedState.NAME), DatafeedState::fromXContent), new NamedXContentRegistry.Entry(PersistentTaskState.class, new ParseField(JobTaskState.NAME), JobTaskState::fromXContent), + new NamedXContentRegistry.Entry(PersistentTaskState.class, new ParseField(DataFrameAnalyticsTaskState.NAME), + DataFrameAnalyticsTaskState::fromXContent), // watcher new NamedXContentRegistry.Entry(MetaData.Custom.class, new ParseField(WatcherMetaData.TYPE), WatcherMetaData::fromXContent), diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/dataframe/DataFrameField.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/dataframe/DataFrameField.java index c61ed2ddde8be..71878c4894d6a 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/dataframe/DataFrameField.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/dataframe/DataFrameField.java @@ -28,6 +28,10 @@ public final class DataFrameField { public static final ParseField DESTINATION = new ParseField("dest"); public static final ParseField FORCE = new ParseField("force"); public static final ParseField MAX_PAGE_SEARCH_SIZE = new ParseField("max_page_search_size"); + public static final ParseField FIELD = new ParseField("field"); + public static final ParseField SYNC = new ParseField("sync"); + public static final ParseField TIME_BASED_SYNC = new ParseField("time"); + public static final ParseField DELAY = 
new ParseField("delay"); /** * Fields for checkpointing diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/dataframe/DataFrameMessages.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/dataframe/DataFrameMessages.java index e6e6ac860e37c..7fe51feb2260a 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/dataframe/DataFrameMessages.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/dataframe/DataFrameMessages.java @@ -38,7 +38,9 @@ public class DataFrameMessages { public static final String FAILED_TO_PARSE_TRANSFORM_CONFIGURATION = "Failed to parse transform configuration for data frame transform [{0}]"; public static final String FAILED_TO_PARSE_TRANSFORM_STATISTICS_CONFIGURATION = - "Failed to parse transform statistics for data frame transform [{0}]"; + "Failed to parse transform statistics for data frame transform [{0}]"; + public static final String FAILED_TO_LOAD_TRANSFORM_CHECKPOINT = + "Failed to load data frame transform configuration for transform [{0}]"; public static final String DATA_FRAME_TRANSFORM_CONFIGURATION_NO_TRANSFORM = "Data frame transform configuration must specify exactly 1 function"; public static final String DATA_FRAME_TRANSFORM_CONFIGURATION_PIVOT_NO_GROUP_BY = diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/dataframe/DataFrameNamedXContentProvider.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/dataframe/DataFrameNamedXContentProvider.java new file mode 100644 index 0000000000000..9eacfc5ff1eae --- /dev/null +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/dataframe/DataFrameNamedXContentProvider.java @@ -0,0 +1,26 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. 
+ */ + +package org.elasticsearch.xpack.core.dataframe; + +import org.elasticsearch.common.xcontent.NamedXContentRegistry; +import org.elasticsearch.plugins.spi.NamedXContentProvider; +import org.elasticsearch.xpack.core.dataframe.transforms.SyncConfig; +import org.elasticsearch.xpack.core.dataframe.transforms.TimeSyncConfig; + +import java.util.Arrays; +import java.util.List; + +public class DataFrameNamedXContentProvider implements NamedXContentProvider { + + @Override + public List getNamedXContentParsers() { + return Arrays.asList( + new NamedXContentRegistry.Entry(SyncConfig.class, + DataFrameField.TIME_BASED_SYNC, + TimeSyncConfig::parse)); + } +} diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/dataframe/transforms/DataFrameTransformConfig.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/dataframe/transforms/DataFrameTransformConfig.java index 19d4d6ab6eed1..2762e0507ef06 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/dataframe/transforms/DataFrameTransformConfig.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/dataframe/transforms/DataFrameTransformConfig.java @@ -19,6 +19,7 @@ import org.elasticsearch.common.xcontent.ToXContentObject; import org.elasticsearch.common.xcontent.XContentBuilder; import org.elasticsearch.common.xcontent.XContentParser; +import org.elasticsearch.common.xcontent.XContentParserUtils; import org.elasticsearch.xpack.core.dataframe.DataFrameField; import org.elasticsearch.xpack.core.dataframe.DataFrameMessages; import org.elasticsearch.xpack.core.dataframe.transforms.pivot.PivotConfig; @@ -55,6 +56,7 @@ public class DataFrameTransformConfig extends AbstractDiffable create SourceConfig source = (SourceConfig) args[1]; DestConfig dest = (DestConfig) args[2]; - // ignored, only for internal storage: String docType = (String) args[3]; + SyncConfig syncConfig = (SyncConfig) args[3]; + // ignored, only for internal storage: String docType = 
(String) args[4]; // on strict parsing do not allow injection of headers, transform version, or create time if (lenient == false) { - validateStrictParsingParams(args[4], HEADERS.getPreferredName()); - validateStrictParsingParams(args[7], CREATE_TIME.getPreferredName()); - validateStrictParsingParams(args[8], VERSION.getPreferredName()); + validateStrictParsingParams(args[5], HEADERS.getPreferredName()); + validateStrictParsingParams(args[8], CREATE_TIME.getPreferredName()); + validateStrictParsingParams(args[9], VERSION.getPreferredName()); } @SuppressWarnings("unchecked") - Map headers = (Map) args[4]; + Map headers = (Map) args[5]; - PivotConfig pivotConfig = (PivotConfig) args[5]; - String description = (String)args[6]; + PivotConfig pivotConfig = (PivotConfig) args[6]; + String description = (String)args[7]; return new DataFrameTransformConfig(id, source, dest, + syncConfig, headers, pivotConfig, description, - (Instant)args[7], - (String)args[8]); + (Instant)args[8], + (String)args[9]); }); parser.declareString(optionalConstructorArg(), DataFrameField.ID); parser.declareObject(constructorArg(), (p, c) -> SourceConfig.fromXContent(p, lenient), DataFrameField.SOURCE); parser.declareObject(constructorArg(), (p, c) -> DestConfig.fromXContent(p, lenient), DataFrameField.DESTINATION); + parser.declareObject(optionalConstructorArg(), (p, c) -> parseSyncConfig(p, lenient), DataFrameField.SYNC); + parser.declareString(optionalConstructorArg(), DataFrameField.INDEX_DOC_TYPE); + parser.declareObject(optionalConstructorArg(), (p, c) -> p.mapStrings(), HEADERS); parser.declareObject(optionalConstructorArg(), (p, c) -> PivotConfig.fromXContent(p, lenient), PIVOT_TRANSFORM); parser.declareString(optionalConstructorArg(), DESCRIPTION); @@ -124,6 +131,14 @@ private static ConstructingObjectParser create return parser; } + private static SyncConfig parseSyncConfig(XContentParser parser, boolean ignoreUnknownFields) throws IOException { + 
XContentParserUtils.ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser::getTokenLocation); + XContentParserUtils.ensureExpectedToken(XContentParser.Token.FIELD_NAME, parser.nextToken(), parser::getTokenLocation); + SyncConfig syncConfig = parser.namedObject(SyncConfig.class, parser.currentName(), ignoreUnknownFields); + XContentParserUtils.ensureExpectedToken(XContentParser.Token.END_OBJECT, parser.nextToken(), parser::getTokenLocation); + return syncConfig; + } + public static String documentId(String transformId) { return NAME + "-" + transformId; } @@ -131,6 +146,7 @@ public static String documentId(String transformId) { DataFrameTransformConfig(final String id, final SourceConfig source, final DestConfig dest, + final SyncConfig syncConfig, final Map headers, final PivotConfig pivotConfig, final String description, @@ -139,6 +155,7 @@ public static String documentId(String transformId) { this.id = ExceptionsHelper.requireNonNull(id, DataFrameField.ID.getPreferredName()); this.source = ExceptionsHelper.requireNonNull(source, DataFrameField.SOURCE.getPreferredName()); this.dest = ExceptionsHelper.requireNonNull(dest, DataFrameField.DESTINATION.getPreferredName()); + this.syncConfig = syncConfig; this.setHeaders(headers == null ? 
Collections.emptyMap() : headers); this.pivotConfig = pivotConfig; this.description = description; @@ -157,10 +174,11 @@ public static String documentId(String transformId) { public DataFrameTransformConfig(final String id, final SourceConfig source, final DestConfig dest, + final SyncConfig syncConfig, final Map headers, final PivotConfig pivotConfig, final String description) { - this(id, source, dest, headers, pivotConfig, description, null, null); + this(id, source, dest, syncConfig, headers, pivotConfig, description, null, null); } public DataFrameTransformConfig(final StreamInput in) throws IOException { @@ -171,9 +189,11 @@ public DataFrameTransformConfig(final StreamInput in) throws IOException { pivotConfig = in.readOptionalWriteable(PivotConfig::new); description = in.readOptionalString(); if (in.getVersion().onOrAfter(Version.V_7_3_0)) { + syncConfig = in.readOptionalNamedWriteable(SyncConfig.class); createTime = in.readOptionalInstant(); transformVersion = in.readBoolean() ? 
Version.readVersion(in) : null; } else { + syncConfig = null; createTime = null; transformVersion = null; } @@ -191,6 +211,10 @@ public DestConfig getDestination() { return dest; } + public SyncConfig getSyncConfig() { + return syncConfig; + } + public Map getHeaders() { return headers; } @@ -233,6 +257,10 @@ public boolean isValid() { return false; } + if (syncConfig != null && syncConfig.isValid() == false) { + return false; + } + return source.isValid() && dest.isValid(); } @@ -245,8 +273,9 @@ public void writeTo(final StreamOutput out) throws IOException { out.writeOptionalWriteable(pivotConfig); out.writeOptionalString(description); if (out.getVersion().onOrAfter(Version.V_7_3_0)) { + out.writeOptionalNamedWriteable(syncConfig); out.writeOptionalInstant(createTime); - if (transformVersion != null) { + if (transformVersion != null) { out.writeBoolean(true); Version.writeVersion(transformVersion, out); } else { @@ -261,6 +290,11 @@ public XContentBuilder toXContent(final XContentBuilder builder, final Params pa builder.field(DataFrameField.ID.getPreferredName(), id); builder.field(DataFrameField.SOURCE.getPreferredName(), source); builder.field(DataFrameField.DESTINATION.getPreferredName(), dest); + if (syncConfig != null) { + builder.startObject(DataFrameField.SYNC.getPreferredName()); + builder.field(syncConfig.getWriteableName(), syncConfig); + builder.endObject(); + } if (pivotConfig != null) { builder.field(PIVOT_TRANSFORM.getPreferredName(), pivotConfig); } @@ -298,6 +332,7 @@ public boolean equals(Object other) { return Objects.equals(this.id, that.id) && Objects.equals(this.source, that.source) && Objects.equals(this.dest, that.dest) + && Objects.equals(this.syncConfig, that.syncConfig) && Objects.equals(this.headers, that.headers) && Objects.equals(this.pivotConfig, that.pivotConfig) && Objects.equals(this.description, that.description) @@ -307,7 +342,7 @@ public boolean equals(Object other) { @Override public int hashCode(){ - return Objects.hash(id, 
source, dest, headers, pivotConfig, description, createTime, transformVersion);
+        return Objects.hash(id, source, dest, syncConfig, headers, pivotConfig, description, createTime, transformVersion);
     }
 
     @Override
diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/dataframe/transforms/SyncConfig.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/dataframe/transforms/SyncConfig.java
new file mode 100644
index 0000000000000..19ff79ea7e0ee
--- /dev/null
+++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/dataframe/transforms/SyncConfig.java
@@ -0,0 +1,25 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License;
+ * you may not use this file except in compliance with the Elastic License.
+ */
+
+package org.elasticsearch.xpack.core.dataframe.transforms;
+
+import org.elasticsearch.common.io.stream.NamedWriteable;
+import org.elasticsearch.common.xcontent.ToXContentObject;
+import org.elasticsearch.index.query.QueryBuilder;
+
+public interface SyncConfig extends ToXContentObject, NamedWriteable {
+
+    /**
+     * Validate configuration
+     *
+     * @return true if valid
+     */
+    boolean isValid();
+
+    QueryBuilder getRangeQuery(DataFrameTransformCheckpoint newCheckpoint);
+
+    QueryBuilder getRangeQuery(DataFrameTransformCheckpoint oldCheckpoint, DataFrameTransformCheckpoint newCheckpoint);
+}
diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/dataframe/transforms/TimeSyncConfig.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/dataframe/transforms/TimeSyncConfig.java
new file mode 100644
index 0000000000000..0490394d90b26
--- /dev/null
+++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/dataframe/transforms/TimeSyncConfig.java
@@ -0,0 +1,160 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License;
+ * you may not use this file except in compliance with the Elastic License.
+ */
+
+package org.elasticsearch.xpack.core.dataframe.transforms;
+
+import org.elasticsearch.common.Strings;
+import org.elasticsearch.common.io.stream.StreamInput;
+import org.elasticsearch.common.io.stream.StreamOutput;
+import org.elasticsearch.common.unit.TimeValue;
+import org.elasticsearch.common.xcontent.ConstructingObjectParser;
+import org.elasticsearch.common.xcontent.ObjectParser;
+import org.elasticsearch.common.xcontent.XContentBuilder;
+import org.elasticsearch.common.xcontent.XContentParser;
+import org.elasticsearch.index.query.QueryBuilder;
+import org.elasticsearch.index.query.RangeQueryBuilder;
+import org.elasticsearch.xpack.core.dataframe.DataFrameField;
+import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
+
+import java.io.IOException;
+import java.util.Objects;
+
+import static org.elasticsearch.common.xcontent.ConstructingObjectParser.constructorArg;
+import static org.elasticsearch.common.xcontent.ConstructingObjectParser.optionalConstructorArg;
+
+/**
+ * A {@link SyncConfig} that synchronizes on a single field: it builds range queries on that
+ * field, bounded by the time upper bounds of the checkpoints it is given.
+ */
+public class TimeSyncConfig implements SyncConfig {
+
+    private static final String NAME = "data_frame_transform_pivot_sync_time";
+
+    private final String field;
+    private final TimeValue delay;
+
+    private static final ConstructingObjectParser<TimeSyncConfig, Void> STRICT_PARSER = createParser(false);
+    private static final ConstructingObjectParser<TimeSyncConfig, Void> LENIENT_PARSER = createParser(true);
+
+    private static ConstructingObjectParser<TimeSyncConfig, Void> createParser(boolean lenient) {
+        ConstructingObjectParser<TimeSyncConfig, Void> parser = new ConstructingObjectParser<>(NAME, lenient,
+            args -> {
+                String field = (String) args[0];
+                TimeValue delay = args[1] != null ? (TimeValue) args[1] : TimeValue.ZERO;
+
+                return new TimeSyncConfig(field, delay);
+            });
+
+        parser.declareString(constructorArg(), DataFrameField.FIELD);
+        parser.declareField(optionalConstructorArg(),
+            (p, c) -> TimeValue.parseTimeValue(p.textOrNull(), DataFrameField.DELAY.getPreferredName()), DataFrameField.DELAY,
+            ObjectParser.ValueType.STRING_OR_NULL);
+
+        return parser;
+    }
+
+    // NOTE(review): passes a null field to the validating constructor below, which presumably
+    // rejects it -- confirm whether this constructor can ever succeed or should be removed.
+    public TimeSyncConfig() {
+        this(null, null);
+    }
+
+    /**
+     * @param field the field the range queries are built on; checked with
+     *              {@code ExceptionsHelper.requireNonNull}
+     * @param delay optional; {@code null} is treated as {@link TimeValue#ZERO}, the same value the
+     *              parser substitutes when no delay is given
+     */
+    public TimeSyncConfig(final String field, final TimeValue delay) {
+        this.field = ExceptionsHelper.requireNonNull(field, DataFrameField.FIELD.getPreferredName());
+        this.delay = delay == null ? TimeValue.ZERO : delay;
+    }
+
+    public TimeSyncConfig(StreamInput in) throws IOException {
+        this.field = in.readString();
+        this.delay = in.readTimeValue();
+    }
+
+    public String getField() {
+        return field;
+    }
+
+    public TimeValue getDelay() {
+        return delay;
+    }
+
+    @Override
+    public boolean isValid() {
+        return true;
+    }
+
+    @Override
+    public void writeTo(final StreamOutput out) throws IOException {
+        out.writeString(field);
+        out.writeTimeValue(delay);
+    }
+
+    @Override
+    public XContentBuilder toXContent(final XContentBuilder builder, final Params params) throws IOException {
+        builder.startObject();
+        builder.field(DataFrameField.FIELD.getPreferredName(), field);
+        if (delay.duration() > 0) {
+            builder.field(DataFrameField.DELAY.getPreferredName(), delay.getStringRep());
+        }
+        builder.endObject();
+        return builder;
+    }
+
+    @Override
+    public boolean equals(Object other) {
+        if (this == other) {
+            return true;
+        }
+
+        if (other == null || getClass() != other.getClass()) {
+            return false;
+        }
+
+        final TimeSyncConfig that = (TimeSyncConfig) other;
+
+        return Objects.equals(this.field, that.field)
+                && Objects.equals(this.delay, that.delay);
+    }
+
+    @Override
+    public int hashCode(){
+        return Objects.hash(field, delay);
+    }
+
+    @Override
+    public String toString() {
+        return Strings.toString(this, true, true);
+    }
+
+    public static TimeSyncConfig parse(final XContentParser parser) {
+        return LENIENT_PARSER.apply(parser, null);
+    }
+
+    public static TimeSyncConfig fromXContent(final XContentParser parser, boolean lenient) throws IOException {
+        return lenient ? LENIENT_PARSER.apply(parser, null) : STRICT_PARSER.apply(parser, null);
+    }
+
+    @Override
+    public String getWriteableName() {
+        return DataFrameField.TIME_BASED_SYNC.getPreferredName();
+    }
+
+    @Override
+    public QueryBuilder getRangeQuery(DataFrameTransformCheckpoint newCheckpoint) {
+        return new RangeQueryBuilder(field).lt(newCheckpoint.getTimeUpperBound()).format("epoch_millis");
+    }
+
+    @Override
+    public QueryBuilder getRangeQuery(DataFrameTransformCheckpoint oldCheckpoint, DataFrameTransformCheckpoint newCheckpoint) {
+        return new RangeQueryBuilder(field).gte(oldCheckpoint.getTimeUpperBound()).lt(newCheckpoint.getTimeUpperBound())
+                .format("epoch_millis");
+    }
+}
diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/dataframe/transforms/pivot/DateHistogramGroupSource.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/dataframe/transforms/pivot/DateHistogramGroupSource.java
index a3861ef65f648..e38915c0beac6 100644
--- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/dataframe/transforms/pivot/DateHistogramGroupSource.java
+++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/dataframe/transforms/pivot/DateHistogramGroupSource.java
@@ -14,6 +14,7 @@
 import org.elasticsearch.common.xcontent.ToXContentFragment;
 import org.elasticsearch.common.xcontent.XContentBuilder;
 import org.elasticsearch.common.xcontent.XContentParser;
+import org.elasticsearch.index.query.QueryBuilder;
 import org.elasticsearch.search.aggregations.bucket.histogram.DateHistogramAggregationBuilder;
 import org.elasticsearch.search.aggregations.bucket.histogram.DateHistogramInterval;
 
@@ -21,6 +22,7 @@
 import java.time.ZoneId;
 import 
java.time.ZoneOffset; import java.util.Objects; +import java.util.Set; import static org.elasticsearch.common.xcontent.ConstructingObjectParser.optionalConstructorArg; @@ -320,4 +322,15 @@ public boolean equals(Object other) { public int hashCode() { return Objects.hash(field, interval, timeZone, format); } + + @Override + public QueryBuilder getIncrementalBucketUpdateFilterQuery(Set changedBuckets) { + // no need for an extra range filter as this is already done by checkpoints + return null; + } + + @Override + public boolean supportsIncrementalBucketUpdate() { + return false; + } } diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/dataframe/transforms/pivot/HistogramGroupSource.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/dataframe/transforms/pivot/HistogramGroupSource.java index 737590a0cc197..372f4ad99b608 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/dataframe/transforms/pivot/HistogramGroupSource.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/dataframe/transforms/pivot/HistogramGroupSource.java @@ -11,9 +11,11 @@ import org.elasticsearch.common.xcontent.ConstructingObjectParser; import org.elasticsearch.common.xcontent.XContentBuilder; import org.elasticsearch.common.xcontent.XContentParser; +import org.elasticsearch.index.query.QueryBuilder; import java.io.IOException; import java.util.Objects; +import java.util.Set; import static org.elasticsearch.common.xcontent.ConstructingObjectParser.optionalConstructorArg; @@ -99,4 +101,15 @@ public boolean equals(Object other) { public int hashCode() { return Objects.hash(field, interval); } + + @Override + public QueryBuilder getIncrementalBucketUpdateFilterQuery(Set changedBuckets) { + // histograms are simple and cheap, so we skip this optimization + return null; + } + + @Override + public boolean supportsIncrementalBucketUpdate() { + return false; + } } diff --git 
a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/dataframe/transforms/pivot/PivotConfig.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/dataframe/transforms/pivot/PivotConfig.java index ab2f7d489ac9a..038299bfd8326 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/dataframe/transforms/pivot/PivotConfig.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/dataframe/transforms/pivot/PivotConfig.java @@ -100,12 +100,16 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws return builder; } - public void toCompositeAggXContent(XContentBuilder builder, Params params) throws IOException { + public void toCompositeAggXContent(XContentBuilder builder, boolean forChangeDetection) throws IOException { builder.startObject(); builder.field(CompositeAggregationBuilder.SOURCES_FIELD_NAME.getPreferredName()); builder.startArray(); for (Entry groupBy : groups.getGroups().entrySet()) { + // some group sources do not implement change detection, or it makes no sense for them; skip those + if (forChangeDetection && groupBy.getValue().supportsIncrementalBucketUpdate() == false) { + continue; + } builder.startObject(); builder.startObject(groupBy.getKey()); builder.field(groupBy.getValue().getType().value(), groupBy.getValue()); diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/dataframe/transforms/pivot/SingleGroupSource.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/dataframe/transforms/pivot/SingleGroupSource.java index 0a4cf2579460e..ff1f9c3d54ac8 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/dataframe/transforms/pivot/SingleGroupSource.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/dataframe/transforms/pivot/SingleGroupSource.java @@ -14,10 +14,12 @@ import org.elasticsearch.common.xcontent.AbstractObjectParser; import org.elasticsearch.common.xcontent.ToXContentObject; import 
org.elasticsearch.common.xcontent.XContentBuilder; +import org.elasticsearch.index.query.QueryBuilder; import java.io.IOException; import java.util.Locale; import java.util.Objects; +import java.util.Set; import static org.elasticsearch.common.xcontent.ConstructingObjectParser.optionalConstructorArg; @@ -94,6 +96,10 @@ public void writeTo(StreamOutput out) throws IOException { public abstract Type getType(); + public abstract boolean supportsIncrementalBucketUpdate(); + + public abstract QueryBuilder getIncrementalBucketUpdateFilterQuery(Set changedBuckets); + public String getField() { return field; } diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/dataframe/transforms/pivot/TermsGroupSource.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/dataframe/transforms/pivot/TermsGroupSource.java index d4585a611b367..891b160da0762 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/dataframe/transforms/pivot/TermsGroupSource.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/dataframe/transforms/pivot/TermsGroupSource.java @@ -9,8 +9,11 @@ import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.xcontent.ConstructingObjectParser; import org.elasticsearch.common.xcontent.XContentParser; +import org.elasticsearch.index.query.QueryBuilder; +import org.elasticsearch.index.query.TermsQueryBuilder; import java.io.IOException; +import java.util.Set; /* * A terms aggregation source for group_by @@ -47,4 +50,14 @@ public Type getType() { public static TermsGroupSource fromXContent(final XContentParser parser, boolean lenient) throws IOException { return lenient ? 
LENIENT_PARSER.apply(parser, null) : STRICT_PARSER.apply(parser, null); } + + @Override + public QueryBuilder getIncrementalBucketUpdateFilterQuery(Set changedBuckets) { + return new TermsQueryBuilder(field, changedBuckets); + } + + @Override + public boolean supportsIncrementalBucketUpdate() { + return true; + } } diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/MachineLearningField.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/MachineLearningField.java index 6b5ba086c6fe0..5c3da41df7349 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/MachineLearningField.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/MachineLearningField.java @@ -13,7 +13,7 @@ public final class MachineLearningField { public static final Setting AUTODETECT_PROCESS = Setting.boolSetting("xpack.ml.autodetect_process", true, Setting.Property.NodeScope); public static final Setting MAX_MODEL_MEMORY_LIMIT = - Setting.memorySizeSetting("xpack.ml.max_model_memory_limit", new ByteSizeValue(0), + Setting.memorySizeSetting("xpack.ml.max_model_memory_limit", ByteSizeValue.ZERO, Setting.Property.Dynamic, Setting.Property.NodeScope); public static final TimeValue STATE_PERSIST_RESTORE_TIMEOUT = TimeValue.timeValueMinutes(30); diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/MlTasks.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/MlTasks.java index cd32505a48e3e..9ac63f026b089 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/MlTasks.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/MlTasks.java @@ -11,6 +11,8 @@ import org.elasticsearch.persistent.PersistentTasksClusterService; import org.elasticsearch.persistent.PersistentTasksCustomMetaData; import org.elasticsearch.xpack.core.ml.datafeed.DatafeedState; +import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsState; +import 
org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsTaskState;
 import org.elasticsearch.xpack.core.ml.job.config.JobState;
 import org.elasticsearch.xpack.core.ml.job.config.JobTaskState;
 
@@ -23,9 +25,11 @@ public final class MlTasks {
 
     public static final String JOB_TASK_NAME = "xpack/ml/job";
     public static final String DATAFEED_TASK_NAME = "xpack/ml/datafeed";
+    public static final String DATA_FRAME_ANALYTICS_TASK_NAME = "xpack/ml/data_frame/analytics";
 
     public static final String JOB_TASK_ID_PREFIX = "job-";
     public static final String DATAFEED_TASK_ID_PREFIX = "datafeed-";
+    public static final String DATA_FRAME_ANALYTICS_TASK_ID_PREFIX = "data_frame_analytics-";
 
     public static final PersistentTasksCustomMetaData.Assignment AWAITING_UPGRADE =
         new PersistentTasksCustomMetaData.Assignment(null,
@@ -50,6 +54,23 @@ public static String datafeedTaskId(String datafeedId) {
         return DATAFEED_TASK_ID_PREFIX + datafeedId;
     }
 
+    /**
+     * Namespaces the task ids for data frame analytics.
+     */
+    public static String dataFrameAnalyticsTaskId(String id) {
+        return DATA_FRAME_ANALYTICS_TASK_ID_PREFIX + id;
+    }
+
+    /**
+     * Inverse of {@link #dataFrameAnalyticsTaskId(String)}: removes the task id prefix.
+     * The prefix is matched literally and only at the start of the string; a task id
+     * that does not start with the prefix is returned unchanged.
+     */
+    public static String dataFrameAnalyticsIdFromTaskId(String taskId) {
+        return taskId.startsWith(DATA_FRAME_ANALYTICS_TASK_ID_PREFIX)
+            ? taskId.substring(DATA_FRAME_ANALYTICS_TASK_ID_PREFIX.length()) : taskId;
+    }
+
     @Nullable
     public static PersistentTasksCustomMetaData.PersistentTask getJobTask(String jobId, @Nullable PersistentTasksCustomMetaData tasks) {
         return tasks == null ? null : tasks.getTask(jobTaskId(jobId));
@@ -61,6 +82,12 @@ public static PersistentTasksCustomMetaData.PersistentTask getDatafeedTask(St
         return tasks == null ? null : tasks.getTask(datafeedTaskId(datafeedId));
     }
 
+    @Nullable
+    public static PersistentTasksCustomMetaData.PersistentTask getDataFrameAnalyticsTask(String analyticsId,
+                                                                                     @Nullable PersistentTasksCustomMetaData tasks) {
+        return tasks == null ? null : tasks.getTask(dataFrameAnalyticsTaskId(analyticsId));
+    }
+
     /**
      * Note that the return value of this method does NOT take node relocations into account.
      * Use {@link #getJobStateModifiedForReassignments} to return a value adjusted to the most
@@ -120,6 +147,16 @@ public static DatafeedState getDatafeedState(String datafeedId, @Nullable Persis
         }
     }
 
+    public static DataFrameAnalyticsState getDataFrameAnalyticsState(String analyticsId, @Nullable PersistentTasksCustomMetaData tasks) {
+        PersistentTasksCustomMetaData.PersistentTask task = getDataFrameAnalyticsTask(analyticsId, tasks);
+        if (task != null && task.getState() != null) {
+            DataFrameAnalyticsTaskState taskState = (DataFrameAnalyticsTaskState) task.getState();
+            return taskState.getState();
+        } else {
+            return DataFrameAnalyticsState.STOPPED;
+        }
+    }
+
     /**
      * The job Ids of anomaly detector job tasks.
      * All anomaly detector jobs are returned regardless of the status of the
diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/DeleteDataFrameAnalyticsAction.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/DeleteDataFrameAnalyticsAction.java
new file mode 100644
index 0000000000000..9a777b23a4bb8
--- /dev/null
+++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/DeleteDataFrameAnalyticsAction.java
@@ -0,0 +1,100 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License;
+ * you may not use this file except in compliance with the Elastic License.
+ */
+package org.elasticsearch.xpack.core.ml.action;
+
+import org.elasticsearch.action.Action;
+import org.elasticsearch.action.ActionRequestValidationException;
+import org.elasticsearch.action.support.master.AcknowledgedRequest;
+import org.elasticsearch.action.support.master.AcknowledgedResponse;
+import org.elasticsearch.action.support.master.MasterNodeOperationRequestBuilder;
+import org.elasticsearch.client.ElasticsearchClient;
+import org.elasticsearch.common.io.stream.StreamInput;
+import org.elasticsearch.common.io.stream.StreamOutput;
+import org.elasticsearch.common.io.stream.Writeable;
+import org.elasticsearch.common.xcontent.ToXContentFragment;
+import org.elasticsearch.common.xcontent.XContentBuilder;
+import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig;
+import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
+
+import java.io.IOException;
+import java.util.Objects;
+/** Action descriptor for the "delete data frame analytics" operation; bundles the action name with its request and builder types. */
+public class DeleteDataFrameAnalyticsAction extends Action {
+    /** Shared instance of this action; the constructor is private, so no other instance can be created from outside. */
+    public static final DeleteDataFrameAnalyticsAction INSTANCE = new DeleteDataFrameAnalyticsAction();
+    public static final String NAME = "cluster:admin/xpack/ml/data_frame/analytics/delete";
+    /** Passes {@link #NAME} to the {@code Action} super constructor. */
+    private DeleteDataFrameAnalyticsAction() {
+        super(NAME);
+    }
+    /** Always throws {@link UnsupportedOperationException}; see {@link #getResponseReader()} for the Writeable-based reader. */
+    @Override
+    public AcknowledgedResponse newResponse() {
+        throw new UnsupportedOperationException("usage of Streamable is to be replaced by Writeable");
+    }
+    /** Returns a reader backed by the {@code AcknowledgedResponse} constructor reference. */
+    @Override
+    public Writeable.Reader getResponseReader() {
+        return AcknowledgedResponse::new;
+    }
+    /** Request that carries the id of the data frame analytics this action targets. */
+    public static class Request extends AcknowledgedRequest implements ToXContentFragment {
+        /** Analytics id; stays {@code null} when the no-arg constructor is used. */
+        private String id;
+        /** Stream constructor: reads the superclass state first, then the id string (mirrors {@link #writeTo}). */
+        public Request(StreamInput in) throws IOException {
+            super(in);
+            id = in.readString();
+        }
+        /** Creates a request without an id; this is what {@code RequestBuilder} below starts from. */
+        public Request() {}
+        /** Creates a request for {@code id}, which is passed through {@code ExceptionsHelper.requireNonNull}. */
+        public Request(String id) {
+            this.id = ExceptionsHelper.requireNonNull(id, DataFrameAnalyticsConfig.ID);
+        }
+        /** Returns the analytics id held by this request. */
+        public String getId() {
+            return id;
+        }
+        /** Performs no checks; always returns {@code null}. */
+        @Override
+        public ActionRequestValidationException validate() {
+            return null;
+        }
+        /** Writes only the id field; being a fragment, it neither starts nor ends an object on the builder. */
+        @Override
+        public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
+            builder.field(DataFrameAnalyticsConfig.ID.getPreferredName(), id);
+            return builder;
+        }
+        /** Requests are equal when they have the same class and equal ids. */
+        @Override
+        public boolean equals(Object o) {
+            if (this == o) return true;
+            if (o == null || getClass() != o.getClass()) return false;
+            DeleteDataFrameAnalyticsAction.Request request = (DeleteDataFrameAnalyticsAction.Request) o;
+            return Objects.equals(id, request.id);
+        }
+        /** Writes the superclass state followed by the id string. */
+        @Override
+        public void writeTo(StreamOutput out) throws IOException {
+            super.writeTo(out);
+            out.writeString(id);
+        }
+        /** Hash computed from the id only, matching the field compared in {@link #equals}. */
+        @Override
+        public int hashCode() {
+            return Objects.hash(id);
+        }
+    }
+    /** Request builder that starts from an empty {@link Request}. */
+    public static class RequestBuilder extends MasterNodeOperationRequestBuilder {
+        /** Hands the client, the action and a fresh no-arg {@link Request} to the super constructor. */
+        protected RequestBuilder(ElasticsearchClient client, DeleteDataFrameAnalyticsAction action) {
+            super(client, action, new Request());
+        }
+    }
+}
diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/EvaluateDataFrameAction.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/EvaluateDataFrameAction.java
new file mode 100644
index 0000000000000..eec58428d55cd
--- /dev/null
+++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/EvaluateDataFrameAction.java
@@ -0,0 +1,215 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License;
+ * you may not use this file except in compliance with the Elastic License.
+ */
+package org.elasticsearch.xpack.core.ml.action;
+
+import org.elasticsearch.action.Action;
+import org.elasticsearch.action.ActionRequest;
+import org.elasticsearch.action.ActionRequestBuilder;
+import org.elasticsearch.action.ActionRequestValidationException;
+import org.elasticsearch.action.ActionResponse;
+import org.elasticsearch.client.ElasticsearchClient;
+import org.elasticsearch.common.ParseField;
+import org.elasticsearch.common.Strings;
+import org.elasticsearch.common.io.stream.StreamInput;
+import org.elasticsearch.common.io.stream.StreamOutput;
+import org.elasticsearch.common.xcontent.ConstructingObjectParser;
+import org.elasticsearch.common.xcontent.ToXContentObject;
+import org.elasticsearch.common.xcontent.XContentBuilder;
+import org.elasticsearch.common.xcontent.XContentParser;
+import org.elasticsearch.common.xcontent.XContentParserUtils;
+import org.elasticsearch.xpack.core.ml.dataframe.evaluation.Evaluation;
+import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetricResult;
+import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
+
+import java.io.IOException;
+import java.util.Arrays;
+import java.util.List;
+import java.util.Objects;
+/** Action descriptor for the data frame evaluate operation; declares the request, request builder and response types. */
+public class EvaluateDataFrameAction extends Action {
+    /** Shared instance of this action; the constructor is private. */
+    public static final EvaluateDataFrameAction INSTANCE = new EvaluateDataFrameAction();
+    public static final String NAME = "cluster:monitor/xpack/ml/data_frame/evaluate";
+    /** Passes {@link #NAME} to the {@code Action} super constructor. */
+    private EvaluateDataFrameAction() {
+        super(NAME);
+    }
+    /** Returns an empty {@link Response}, to be populated through {@code readFrom}. */
+    @Override
+    public Response newResponse() {
+        return new Response();
+    }
+    /** Request naming the indices to evaluate and the single {@link Evaluation} to run on them. */
+    public static class Request extends ActionRequest implements ToXContentObject {
+        /** Field names used both when parsing the request body and when rendering it in {@link #toXContent}. */
+        private static final ParseField INDEX = new ParseField("index");
+        private static final ParseField EVALUATION = new ParseField("evaluation");
+        /** Builds a request from two constructor arguments, in declaration order: the index list, then the evaluation. */
+        private static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>(NAME,
+            a -> new Request((List) a[0], (Evaluation) a[1]));
+
+        static {
+            PARSER.declareStringArray(ConstructingObjectParser.constructorArg(), INDEX);
+            PARSER.declareObject(ConstructingObjectParser.constructorArg(), (p, c) -> parseEvaluation(p), EVALUATION);
+        }
+        /** Parses an object holding exactly one field: START_OBJECT, a FIELD_NAME naming the evaluation, its body, END_OBJECT. */
+        private static Evaluation parseEvaluation(XContentParser parser) throws IOException {
+            XContentParserUtils.ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser::getTokenLocation);
+            XContentParserUtils.ensureExpectedToken(XContentParser.Token.FIELD_NAME, parser.nextToken(), parser::getTokenLocation);
+            Evaluation evaluation = parser.namedObject(Evaluation.class, parser.currentName(), null);
+            XContentParserUtils.ensureExpectedToken(XContentParser.Token.END_OBJECT, parser.nextToken(), parser::getTokenLocation);
+            return evaluation;
+        }
+        /** Parses a request from {@code parser} using {@link #PARSER} with a {@code null} context. */
+        public static Request parseRequest(XContentParser parser) {
+            return PARSER.apply(parser, null);
+        }
+        /** State of the request; both fields stay {@code null} after the no-arg constructor until set or read from a stream. */
+        private String[] indices;
+        private Evaluation evaluation;
+        /** Used by {@link #PARSER}; goes through the setters so their null/empty checks apply. */
+        private Request(List indices, Evaluation evaluation) {
+            setIndices(indices);
+            setEvaluation(evaluation);
+        }
+        /** Creates an empty request; this is what {@code RequestBuilder} below starts from. */
+        public Request() {
+        }
+        /** Returns the internal index array itself (no copy is made). */
+        public String[] getIndices() {
+            return indices;
+        }
+        /** Stores a copy of {@code indices} as an array; the list is null-checked and an empty list raises a bad-request exception. */
+        public final void setIndices(List indices) {
+            ExceptionsHelper.requireNonNull(indices, INDEX);
+            if (indices.isEmpty()) {
+                throw ExceptionsHelper.badRequestException("At least one index must be specified");
+            }
+            this.indices = indices.toArray(new String[indices.size()]);
+        }
+        /** Returns the evaluation carried by this request. */
+        public Evaluation getEvaluation() {
+            return evaluation;
+        }
+        /** Sets the evaluation after passing it through {@code ExceptionsHelper.requireNonNull}. */
+        public final void setEvaluation(Evaluation evaluation) {
+            this.evaluation = ExceptionsHelper.requireNonNull(evaluation, EVALUATION);
+        }
+        /** Performs no checks; always returns {@code null}. */
+        @Override
+        public ActionRequestValidationException validate() {
+            return null;
+        }
+        /** Reads superclass state, then the index array, then the evaluation as a named writeable (same order as {@link #writeTo}). */
+        @Override
+        public void readFrom(StreamInput in) throws IOException {
+            super.readFrom(in);
+            indices = in.readStringArray();
+            evaluation = in.readNamedWriteable(Evaluation.class);
+        }
+        /** Writes superclass state, then the index array, then the evaluation as a named writeable. */
+        @Override
+        public void writeTo(StreamOutput out) throws IOException {
+            super.writeTo(out);
+            out.writeStringArray(indices);
+            out.writeNamedWriteable(evaluation);
+        }
+        /** Renders an object with the index array and an "evaluation" object holding one field keyed by the evaluation's name. */
+        @Override
+        public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
+            builder.startObject();
+            builder.array(INDEX.getPreferredName(), indices);
+            builder.startObject(EVALUATION.getPreferredName());
+            builder.field(evaluation.getName(), evaluation);
+            builder.endObject();
+            builder.endObject();
+            return builder;
+        }
+        /** Hashes the array contents (not the array identity) together with the evaluation. */
+        @Override
+        public int hashCode() {
+            return Objects.hash(Arrays.hashCode(indices), evaluation);
+        }
+        /** Requests are equal when they have the same class, element-wise equal index arrays and equal evaluations. */
+        @Override
+        public boolean equals(Object o) {
+            if (this == o) return true;
+            if (o == null || getClass() != o.getClass()) return false;
+            Request that = (Request) o;
+            return Arrays.equals(indices, that.indices) && Objects.equals(evaluation, that.evaluation);
+        }
+    }
+    /** Package-private request builder bound to {@link #INSTANCE}. */
+    static class RequestBuilder extends ActionRequestBuilder {
+        /** Hands the client, {@link #INSTANCE} and a fresh empty {@link Request} to the super constructor. */
+        RequestBuilder(ElasticsearchClient client) {
+            super(client, INSTANCE, new Request());
+        }
+    }
+    /** Response holding the evaluation's name and the list of metric results computed for it. */
+    public static class Response extends ActionResponse implements ToXContentObject {
+        /** Both fields stay {@code null} after the no-arg constructor until {@link #readFrom} fills them. */
+        private String evaluationName;
+        private List metrics;
+        /** Creates an empty response, as returned by {@code newResponse()} above. */
+        public Response() {
+        }
+        /** Creates a populated response; both arguments are null-checked with {@code Objects.requireNonNull}. */
+        public Response(String evaluationName, List metrics) {
+            this.evaluationName = Objects.requireNonNull(evaluationName);
+            this.metrics = Objects.requireNonNull(metrics);
+        }
+        /** Reads superclass state, the evaluation name, then the metrics as a list of named writeables. */
+        @Override
+        public void readFrom(StreamInput in) throws IOException {
+            super.readFrom(in);
+            this.evaluationName = in.readString();
+            this.metrics = in.readNamedWriteableList(EvaluationMetricResult.class);
+        }
+        /** Writes superclass state, the evaluation name, then the metrics list. */
+        @Override
+        public void writeTo(StreamOutput out) throws IOException {
+            super.writeTo(out);
+            out.writeString(evaluationName);
+            out.writeList(metrics);
+        }
+        /** Renders an object containing one object named after the evaluation, with one field per metric keyed by the metric's name. */
+        @Override
+        public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
+            builder.startObject();
+            builder.startObject(evaluationName);
+            for (EvaluationMetricResult metric : metrics) {
+                builder.field(metric.getName(), metric);
+            }
+            builder.endObject();
+            builder.endObject();
+            return builder;
+        }
+        /** Hash over the evaluation name and the metrics list, the same fields compared in {@link #equals}. */
+        @Override
+        public int hashCode() {
+            return Objects.hash(evaluationName, metrics);
+        }
+        /** Responses are equal when they have the same class, equal evaluation names and equal metrics lists. */
+        @Override
+        public boolean equals(Object obj) {
+            if (obj == null) {
+                return false;
+            }
+            if (getClass() != obj.getClass()) {
+                return false;
+            }
+            Response other = (Response) obj;
+            return Objects.equals(evaluationName, other.evaluationName) && Objects.equals(metrics, other.metrics);
+        }
+        /** Returns the string produced by {@code Strings.toString} for this XContent object. */
+        @Override
+        public final String toString() {
+            return Strings.toString(this);
+        }
+    }
+
+}
diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/GetDataFrameAnalyticsAction.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/GetDataFrameAnalyticsAction.java
new file mode 100644
index 0000000000000..92233fbb27692
--- /dev/null
+++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/GetDataFrameAnalyticsAction.java
@@ -0,0 +1,80 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License;
+ * you may not use this file except in compliance with the Elastic License.
+ */
+package org.elasticsearch.xpack.core.ml.action;
+
+import org.elasticsearch.action.Action;
+import org.elasticsearch.action.ActionRequestBuilder;
+import org.elasticsearch.client.ElasticsearchClient;
+import org.elasticsearch.common.ParseField;
+import org.elasticsearch.common.io.stream.StreamInput;
+import org.elasticsearch.xpack.core.action.AbstractGetResourcesRequest;
+import org.elasticsearch.xpack.core.action.AbstractGetResourcesResponse;
+import org.elasticsearch.xpack.core.action.util.QueryPage;
+import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig;
+
+import java.io.IOException;
+import java.util.Collections;
+/** Action descriptor for the "get data frame analytics" operation; declares the request, response and builder types. */
+public class GetDataFrameAnalyticsAction extends Action {
+    /** Shared instance of this action; the constructor is private. */
+    public static final GetDataFrameAnalyticsAction INSTANCE = new GetDataFrameAnalyticsAction();
+    public static final String NAME = "cluster:admin/xpack/ml/data_frame/analytics/get";
+    /** Passes {@link #NAME} to the {@code Action} super constructor. */
+    private GetDataFrameAnalyticsAction() {
+        super(NAME);
+    }
+    /** Returns a response wrapping an empty page: no results, a count of 0, labelled with {@code Response.RESULTS_FIELD}. */
+    @Override
+    public Response newResponse() {
+        return new Response(new QueryPage<>(Collections.emptyList(), 0, Response.RESULTS_FIELD));
+    }
+    /** Get-resources request whose resource id field is the data frame analytics config id. */
+    public static class Request extends AbstractGetResourcesRequest {
+        /** Parse field named "allow_no_match"; it is only declared here and not read elsewhere in this class. */
+        public static final ParseField ALLOW_NO_MATCH = new ParseField("allow_no_match");
+        /** Creates a request with no resource id set and allow-no-resources switched on. */
+        public Request() {
+            setAllowNoResources(true);
+        }
+        /** Creates a request for the given resource id with allow-no-resources switched on. */
+        public Request(String id) {
+            setResourceId(id);
+            setAllowNoResources(true);
+        }
+        /** Stream constructor; delegates to the inherited {@code readFrom(in)}. */
+        public Request(StreamInput in) throws IOException {
+            readFrom(in);
+        }
+        /** Returns the preferred name of {@code DataFrameAnalyticsConfig.ID}. */
+        @Override
+        public String getResourceIdField() {
+            return DataFrameAnalyticsConfig.ID.getPreferredName();
+        }
+    }
+    /** Get-resources response whose elements are read with the {@code DataFrameAnalyticsConfig} stream constructor. */
+    public static class Response extends AbstractGetResourcesResponse {
+        /** Results field named "data_frame_analytics"; also used for the empty page built in {@code newResponse()}. */
+        public static final ParseField RESULTS_FIELD = new ParseField("data_frame_analytics");
+        /** Creates a response without passing a page to the superclass. */
+        public Response() {}
+        /** Creates a response around the given page of analytics. */
+        public Response(QueryPage analytics) {
+            super(analytics);
+        }
+        /** Returns the {@code DataFrameAnalyticsConfig} constructor reference as the element reader. */
+        @Override
+        protected Reader getReader() {
+            return DataFrameAnalyticsConfig::new;
+        }
+    }
+    /** Request builder bound to {@link #INSTANCE}. */
+    public static class RequestBuilder extends ActionRequestBuilder {
+        /** Hands the client, {@link #INSTANCE} and a fresh no-arg {@link Request} to the super constructor. */
+        public RequestBuilder(ElasticsearchClient client) {
+            super(client, INSTANCE, new Request());
+        }
+    }
+}
diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/GetDataFrameAnalyticsStatsAction.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/GetDataFrameAnalyticsStatsAction.java
new file mode 100644
index 0000000000000..b14feaa8839f5
--- /dev/null
+++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/GetDataFrameAnalyticsStatsAction.java
@@ -0,0 +1,321 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License;
+ * you may not use this file except in compliance with the Elastic License.
+ */
+package org.elasticsearch.xpack.core.ml.action;
+
+import org.elasticsearch.ElasticsearchException;
+import org.elasticsearch.action.Action;
+import org.elasticsearch.action.ActionRequestBuilder;
+import org.elasticsearch.action.ActionRequestValidationException;
+import org.elasticsearch.action.TaskOperationFailure;
+import org.elasticsearch.action.support.tasks.BaseTasksRequest;
+import org.elasticsearch.action.support.tasks.BaseTasksResponse;
+import org.elasticsearch.client.ElasticsearchClient;
+import org.elasticsearch.cluster.node.DiscoveryNode;
+import org.elasticsearch.common.Nullable;
+import org.elasticsearch.common.ParseField;
+import org.elasticsearch.common.Strings;
+import org.elasticsearch.common.io.stream.StreamInput;
+import org.elasticsearch.common.io.stream.StreamOutput;
+import org.elasticsearch.common.io.stream.Writeable;
+import org.elasticsearch.common.xcontent.ToXContentObject;
+import org.elasticsearch.common.xcontent.XContentBuilder;
+import org.elasticsearch.tasks.Task;
+import org.elasticsearch.xpack.core.action.util.PageParams;
+import org.elasticsearch.xpack.core.action.util.QueryPage;
+import
org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig; +import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsState; +import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; + +import java.io.IOException; +import java.util.Collections; +import java.util.List; +import java.util.Map; +import java.util.Objects; + +public class GetDataFrameAnalyticsStatsAction extends Action { + + public static final GetDataFrameAnalyticsStatsAction INSTANCE = new GetDataFrameAnalyticsStatsAction(); + public static final String NAME = "cluster:monitor/xpack/ml/data_frame/analytics/stats/get"; + + private GetDataFrameAnalyticsStatsAction() { + super(NAME); + } + + @Override + public Response newResponse() { + throw new UnsupportedOperationException("usage of Streamable is to be replaced by Writeable"); + } + + @Override + public Writeable.Reader getResponseReader() { + return Response::new; + } + + public static class Request extends BaseTasksRequest { + + public static final ParseField ALLOW_NO_MATCH = new ParseField("allow_no_match"); + + private String id; + private boolean allowNoMatch = true; + private PageParams pageParams = PageParams.defaultParams(); + + // Used internally to store the expanded IDs + private List expandedIds = Collections.emptyList(); + + public Request(String id) { + this.id = ExceptionsHelper.requireNonNull(id, DataFrameAnalyticsConfig.ID.getPreferredName()); + this.expandedIds = Collections.singletonList(id); + } + + public Request() {} + + public Request(StreamInput in) throws IOException { + super(in); + id = in.readString(); + allowNoMatch = in.readBoolean(); + pageParams = in.readOptionalWriteable(PageParams::new); + expandedIds = in.readStringList(); + } + + public void setExpandedIds(List expandedIds) { + this.expandedIds = Objects.requireNonNull(expandedIds); + } + + public List getExpandedIds() { + return expandedIds; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + super.writeTo(out); + 
out.writeString(id); + out.writeBoolean(allowNoMatch); + out.writeOptionalWriteable(pageParams); + out.writeStringCollection(expandedIds); + } + + public void setId(String id) { + this.id = id; + } + + public String getId() { + return id; + } + + public boolean isAllowNoMatch() { + return allowNoMatch; + } + + public void setAllowNoMatch(boolean allowNoMatch) { + this.allowNoMatch = allowNoMatch; + } + + public void setPageParams(PageParams pageParams) { + this.pageParams = pageParams; + } + + public PageParams getPageParams() { + return pageParams; + } + + @Override + public boolean match(Task task) { + return expandedIds.stream().anyMatch(expandedId -> StartDataFrameAnalyticsAction.TaskMatcher.match(task, expandedId)); + } + + @Override + public ActionRequestValidationException validate() { + return null; + } + + @Override + public int hashCode() { + return Objects.hash(id, allowNoMatch, pageParams); + } + + @Override + public boolean equals(Object obj) { + if (obj == null) { + return false; + } + if (getClass() != obj.getClass()) { + return false; + } + Request other = (Request) obj; + return Objects.equals(id, other.id) && allowNoMatch == other.allowNoMatch && Objects.equals(pageParams, other.pageParams); + } + } + + public static class RequestBuilder extends ActionRequestBuilder { + + public RequestBuilder(ElasticsearchClient client, GetDataFrameAnalyticsStatsAction action) { + super(client, action, new Request()); + } + } + + public static class Response extends BaseTasksResponse implements ToXContentObject { + + public static class Stats implements ToXContentObject, Writeable { + + private final String id; + private final DataFrameAnalyticsState state; + @Nullable + private final Integer progressPercentage; + @Nullable + private final DiscoveryNode node; + @Nullable + private final String assignmentExplanation; + + public Stats(String id, DataFrameAnalyticsState state, @Nullable Integer progressPercentage, + @Nullable DiscoveryNode node, @Nullable String 
assignmentExplanation) { + this.id = Objects.requireNonNull(id); + this.state = Objects.requireNonNull(state); + this.progressPercentage = progressPercentage; + this.node = node; + this.assignmentExplanation = assignmentExplanation; + } + + public Stats(StreamInput in) throws IOException { + id = in.readString(); + state = DataFrameAnalyticsState.fromStream(in); + progressPercentage = in.readOptionalInt(); + node = in.readOptionalWriteable(DiscoveryNode::new); + assignmentExplanation = in.readOptionalString(); + } + + public String getId() { + return id; + } + + public DataFrameAnalyticsState getState() { + return state; + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + // TODO: Have callers wrap the content with an object as they choose rather than forcing it upon them + builder.startObject(); + { + toUnwrappedXContent(builder); + } + return builder.endObject(); + } + + public XContentBuilder toUnwrappedXContent(XContentBuilder builder) throws IOException { + builder.field(DataFrameAnalyticsConfig.ID.getPreferredName(), id); + builder.field("state", state.toString()); + if (progressPercentage != null) { + builder.field("progress_percent", progressPercentage); + } + if (node != null) { + builder.startObject("node"); + builder.field("id", node.getId()); + builder.field("name", node.getName()); + builder.field("ephemeral_id", node.getEphemeralId()); + builder.field("transport_address", node.getAddress().toString()); + + builder.startObject("attributes"); + for (Map.Entry entry : node.getAttributes().entrySet()) { + builder.field(entry.getKey(), entry.getValue()); + } + builder.endObject(); + builder.endObject(); + } + if (assignmentExplanation != null) { + builder.field("assignment_explanation", assignmentExplanation); + } + return builder; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeString(id); + state.writeTo(out); + 
out.writeOptionalInt(progressPercentage); + out.writeOptionalWriteable(node); + out.writeOptionalString(assignmentExplanation); + } + + @Override + public int hashCode() { + return Objects.hash(id, state, progressPercentage, node, assignmentExplanation); + } + + @Override + public boolean equals(Object obj) { + if (obj == null) { + return false; + } + if (getClass() != obj.getClass()) { + return false; + } + Stats other = (Stats) obj; + return Objects.equals(id, other.id) + && Objects.equals(this.state, other.state) + && Objects.equals(this.node, other.node) + && Objects.equals(this.assignmentExplanation, other.assignmentExplanation); + } + } + + private QueryPage stats; + + public Response(QueryPage stats) { + this(Collections.emptyList(), Collections.emptyList(), stats); + } + + public Response(List taskFailures, List nodeFailures, + QueryPage stats) { + super(taskFailures, nodeFailures); + this.stats = stats; + } + + public Response(StreamInput in) throws IOException { + super(in); + stats = new QueryPage<>(in, Stats::new); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + super.writeTo(out); + stats.writeTo(out); + } + + public QueryPage getResponse() { + return stats; + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + stats.doXContentBody(builder, params); + builder.endObject(); + return builder; + } + + @Override + public int hashCode() { + return Objects.hash(stats); + } + + @Override + public boolean equals(Object obj) { + if (obj == null) { + return false; + } + if (getClass() != obj.getClass()) { + return false; + } + Response other = (Response) obj; + return Objects.equals(stats, other.stats); + } + + @Override + public final String toString() { + return Strings.toString(this); + } + } + +} diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/PutDataFrameAnalyticsAction.java 
b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/PutDataFrameAnalyticsAction.java
new file mode 100644
index 0000000000000..e447aa70109e7
--- /dev/null
+++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/PutDataFrameAnalyticsAction.java
@@ -0,0 +1,153 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License;
+ * you may not use this file except in compliance with the Elastic License.
+ */
+package org.elasticsearch.xpack.core.ml.action;
+
+import org.elasticsearch.action.Action;
+import org.elasticsearch.action.ActionRequestValidationException;
+import org.elasticsearch.action.ActionResponse;
+import org.elasticsearch.action.support.master.AcknowledgedRequest;
+import org.elasticsearch.action.support.master.MasterNodeOperationRequestBuilder;
+import org.elasticsearch.client.ElasticsearchClient;
+import org.elasticsearch.common.Strings;
+import org.elasticsearch.common.io.stream.StreamInput;
+import org.elasticsearch.common.io.stream.StreamOutput;
+import org.elasticsearch.common.xcontent.ToXContentObject;
+import org.elasticsearch.common.xcontent.XContentBuilder;
+import org.elasticsearch.common.xcontent.XContentParser;
+import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig;
+import org.elasticsearch.xpack.core.ml.job.messages.Messages;
+
+import java.io.IOException;
+import java.util.Objects;
+/** Action descriptor for the "put data frame analytics" operation; request and response both wrap a DataFrameAnalyticsConfig. */
+public class PutDataFrameAnalyticsAction extends Action {
+    /** Shared instance of this action; the constructor is private. */
+    public static final PutDataFrameAnalyticsAction INSTANCE = new PutDataFrameAnalyticsAction();
+    public static final String NAME = "cluster:admin/xpack/ml/data_frame/analytics/put";
+    /** Passes {@link #NAME} to the {@code Action} super constructor. */
+    private PutDataFrameAnalyticsAction() {
+        super(NAME);
+    }
+    /** Returns an empty {@link Response}, to be populated through {@code readFrom}. */
+    @Override
+    public Response newResponse() {
+        return new Response();
+    }
+    /** Request carrying the analytics config to store. */
+    public static class Request extends AcknowledgedRequest implements ToXContentObject {
+        /** Parses a config with the strict parser; uses {@code id} if the body has none, and rejects a body id differing from a non-empty {@code id}. */
+        public static Request parseRequest(String id, XContentParser parser) {
+            DataFrameAnalyticsConfig.Builder config = DataFrameAnalyticsConfig.STRICT_PARSER.apply(parser, null);
+            if (config.getId() == null) {
+                config.setId(id);
+            } else if (!Strings.isNullOrEmpty(id) && !id.equals(config.getId())) {
+                // If we have both URI and body ID, they must be identical
+                throw new IllegalArgumentException(Messages.getMessage(Messages.INCONSISTENT_ID, DataFrameAnalyticsConfig.ID,
+                    config.getId(), id));
+            }
+
+            return new PutDataFrameAnalyticsAction.Request(config.build());
+        }
+        /** Config held by this request; stays {@code null} after the no-arg constructor until {@link #readFrom} fills it. */
+        private DataFrameAnalyticsConfig config;
+        /** Creates an empty request; this is what {@code RequestBuilder} below starts from. */
+        public Request() {}
+        /** Creates a request around {@code config}; no null check is applied here. */
+        public Request(DataFrameAnalyticsConfig config) {
+            this.config = config;
+        }
+        /** Reads superclass state, then the config via its stream constructor. */
+        @Override
+        public void readFrom(StreamInput in) throws IOException {
+            super.readFrom(in);
+            config = new DataFrameAnalyticsConfig(in);
+        }
+        /** Writes superclass state followed by the config. */
+        @Override
+        public void writeTo(StreamOutput out) throws IOException {
+            super.writeTo(out);
+            config.writeTo(out);
+        }
+        /** Returns the config held by this request. */
+        public DataFrameAnalyticsConfig getConfig() {
+            return config;
+        }
+        /** Performs no checks; always returns {@code null}. */
+        @Override
+        public ActionRequestValidationException validate() {
+            return null;
+        }
+        /** Delegates rendering entirely to the config. */
+        @Override
+        public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
+            config.toXContent(builder, params);
+            return builder;
+        }
+        /** Requests are equal when they have the same class and equal configs. */
+        @Override
+        public boolean equals(Object o) {
+            if (this == o) return true;
+            if (o == null || getClass() != o.getClass()) return false;
+            PutDataFrameAnalyticsAction.Request request = (PutDataFrameAnalyticsAction.Request) o;
+            return Objects.equals(config, request.config);
+        }
+        /** Hash computed from the config only, matching the field compared in {@link #equals}. */
+        @Override
+        public int hashCode() {
+            return Objects.hash(config);
+        }
+    }
+    /** Response carrying an analytics config. */
+    public static class Response extends ActionResponse implements ToXContentObject {
+        /** Config held by this response; stays {@code null} after the no-arg constructor until {@link #readFrom} fills it. */
+        private DataFrameAnalyticsConfig config;
+        /** Creates a response around {@code config}. */
+        public Response(DataFrameAnalyticsConfig config) {
+            this.config = config;
+        }
+        /** Package-private empty constructor, used by {@code newResponse()} above. */
+        Response() {}
+        /** Reads superclass state, then the config via its stream constructor. */
+        @Override
+        public void readFrom(StreamInput in) throws IOException {
+            super.readFrom(in);
+            config = new DataFrameAnalyticsConfig(in);
+        }
+        /** Writes superclass state followed by the config. */
+        @Override
+        public void writeTo(StreamOutput out) throws IOException {
+            super.writeTo(out);
+            config.writeTo(out);
+        }
+        /** Delegates rendering entirely to the config. */
+        @Override
+        public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
+            config.toXContent(builder, params);
+            return builder;
+        }
+        /** Responses are equal when they have the same class and equal configs. */
+        @Override
+        public boolean equals(Object o) {
+            if (this == o) return true;
+            if (o == null || getClass() != o.getClass()) return false;
+            Response response = (Response) o;
+            return Objects.equals(config, response.config);
+        }
+        /** Hash computed from the config only, matching the field compared in {@link #equals}. */
+        @Override
+        public int hashCode() {
+            return Objects.hash(config);
+        }
+    }
+    /** Request builder that starts from an empty {@link Request}. */
+    public static class RequestBuilder extends MasterNodeOperationRequestBuilder {
+        /** Hands the client, the action and a fresh no-arg {@link Request} to the super constructor. */
+        protected RequestBuilder(ElasticsearchClient client, PutDataFrameAnalyticsAction action) {
+            super(client, action, new Request());
+        }
+    }
+
+}
diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/StartDataFrameAnalyticsAction.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/StartDataFrameAnalyticsAction.java
new file mode 100644
index 0000000000000..d722198bdfae6
--- /dev/null
+++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/StartDataFrameAnalyticsAction.java
@@ -0,0 +1,223 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License;
+ * you may not use this file except in compliance with the Elastic License.
+ */ +package org.elasticsearch.xpack.core.ml.action; + +import org.elasticsearch.Version; +import org.elasticsearch.action.Action; +import org.elasticsearch.action.ActionRequestBuilder; +import org.elasticsearch.action.ActionRequestValidationException; +import org.elasticsearch.action.support.master.AcknowledgedResponse; +import org.elasticsearch.action.support.master.MasterNodeRequest; +import org.elasticsearch.client.ElasticsearchClient; +import org.elasticsearch.cluster.metadata.MetaData; +import org.elasticsearch.common.ParseField; +import org.elasticsearch.common.Strings; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.common.unit.TimeValue; +import org.elasticsearch.common.xcontent.ConstructingObjectParser; +import org.elasticsearch.common.xcontent.ObjectParser; +import org.elasticsearch.common.xcontent.ToXContentObject; +import org.elasticsearch.common.xcontent.XContentBuilder; +import org.elasticsearch.common.xcontent.XContentParser; +import org.elasticsearch.tasks.Task; +import org.elasticsearch.xpack.core.XPackPlugin; +import org.elasticsearch.xpack.core.ml.MlTasks; +import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig; +import org.elasticsearch.xpack.core.ml.job.messages.Messages; +import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; + +import java.io.IOException; +import java.util.Objects; + +public class StartDataFrameAnalyticsAction extends Action { + + public static final StartDataFrameAnalyticsAction INSTANCE = new StartDataFrameAnalyticsAction(); + public static final String NAME = "cluster:admin/xpack/ml/data_frame/analytics/start"; + + private StartDataFrameAnalyticsAction() { + super(NAME); + } + + @Override + public AcknowledgedResponse newResponse() { + throw new UnsupportedOperationException("usage of Streamable is to be replaced by Writeable"); + } + + @Override + 
public Writeable.Reader getResponseReader() { + return AcknowledgedResponse::new; + } + + public static class Request extends MasterNodeRequest implements ToXContentObject { + + public static final ParseField TIMEOUT = new ParseField("timeout"); + + private static final ObjectParser PARSER = new ObjectParser<>(NAME, Request::new); + + static { + PARSER.declareString((request, id) -> request.id = id, DataFrameAnalyticsConfig.ID); + PARSER.declareString((request, val) -> request.setTimeout(TimeValue.parseTimeValue(val, TIMEOUT.getPreferredName())), TIMEOUT); + } + + public static Request parseRequest(String id, XContentParser parser) { + Request request = PARSER.apply(parser, null); + if (request.getId() == null) { + request.setId(id); + } else if (!Strings.isNullOrEmpty(id) && !id.equals(request.getId())) { + throw new IllegalArgumentException(Messages.getMessage(Messages.INCONSISTENT_ID, DataFrameAnalyticsConfig.ID, + request.getId(), id)); + } + return request; + } + + private String id; + private TimeValue timeout = TimeValue.timeValueSeconds(20); + + public Request(String id) { + setId(id); + } + + public Request(StreamInput in) throws IOException { + super(in); + id = in.readString(); + timeout = in.readTimeValue(); + } + + public Request() {} + + public final void setId(String id) { + this.id = ExceptionsHelper.requireNonNull(id, DataFrameAnalyticsConfig.ID); + } + + public String getId() { + return id; + } + + public void setTimeout(TimeValue timeout) { + this.timeout = timeout; + } + + public TimeValue getTimeout() { + return timeout; + } + + @Override + public ActionRequestValidationException validate() { + return null; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + super.writeTo(out); + out.writeString(id); + out.writeTimeValue(timeout); + } + + /** + * Writes the id (when set) and the timeout as one complete JSON object. + * The class implements ToXContentObject, so the object must be opened and closed here; + * toString() relies on that to produce valid JSON. + */ + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + if (id != null) { + 
builder.field(DataFrameAnalyticsConfig.ID.getPreferredName(), id); + } + builder.field(TIMEOUT.getPreferredName(), timeout.getStringRep()); + builder.endObject(); + return builder; + } + + @Override + public int hashCode() { + return Objects.hash(id, timeout); + } + + @Override + public boolean equals(Object obj) { + if (this == obj) { + return true; + } + if (obj == null || obj.getClass() != getClass()) { + return false; + } + StartDataFrameAnalyticsAction.Request other = (StartDataFrameAnalyticsAction.Request) obj; + return Objects.equals(id, other.id) && Objects.equals(timeout, other.timeout); + } + + @Override + public String toString() { + return Strings.toString(this); + } + } + + static class RequestBuilder extends ActionRequestBuilder { + + RequestBuilder(ElasticsearchClient client, StartDataFrameAnalyticsAction action) { + super(client, action, new Request()); + } + } + + public static class TaskParams implements XPackPlugin.XPackPersistentTaskParams { + + // TODO Update to first released version + public static final Version VERSION_INTRODUCED = Version.V_7_1_0; + + public static ConstructingObjectParser PARSER = new ConstructingObjectParser<>( + MlTasks.DATA_FRAME_ANALYTICS_TASK_NAME, true, a -> new TaskParams((String) a[0])); + + static { + // The constructor lambda above reads a[0], so the id has to be registered as a constructor argument; + // without this declaration the argument array is empty and parsing cannot build a TaskParams. + PARSER.declareString(ConstructingObjectParser.constructorArg(), DataFrameAnalyticsConfig.ID); + } + + public static TaskParams fromXContent(XContentParser parser) { + return PARSER.apply(parser, null); + } + + private String id; + + public TaskParams(String id) { + this.id = Objects.requireNonNull(id); + } + + public TaskParams(StreamInput in) throws IOException { + this.id = in.readString(); + } + + public String getId() { + return id; + } + + @Override + public String getWriteableName() { + return MlTasks.DATA_FRAME_ANALYTICS_TASK_NAME; + } + + @Override + public Version getMinimalSupportedVersion() { + return VERSION_INTRODUCED; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeString(id); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + 
builder.startObject(); + builder.field(DataFrameAnalyticsConfig.ID.getPreferredName(), id); + builder.endObject(); + return builder; + } + } + + public interface TaskMatcher { + + static boolean match(Task task, String expectedId) { + if (task instanceof TaskMatcher) { + if (MetaData.ALL.equals(expectedId)) { + return true; + } + String expectedDescription = MlTasks.DATA_FRAME_ANALYTICS_TASK_ID_PREFIX + expectedId; + return expectedDescription.equals(task.getDescription()); + } + return false; + } + } +} diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/StopDataFrameAnalyticsAction.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/StopDataFrameAnalyticsAction.java new file mode 100644 index 0000000000000..43d382147fd64 --- /dev/null +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/StopDataFrameAnalyticsAction.java @@ -0,0 +1,223 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. 
+ */ +package org.elasticsearch.xpack.core.ml.action; + +import org.elasticsearch.action.Action; +import org.elasticsearch.action.ActionRequestBuilder; +import org.elasticsearch.action.ActionRequestValidationException; +import org.elasticsearch.action.support.tasks.BaseTasksRequest; +import org.elasticsearch.action.support.tasks.BaseTasksResponse; +import org.elasticsearch.client.ElasticsearchClient; +import org.elasticsearch.common.Nullable; +import org.elasticsearch.common.ParseField; +import org.elasticsearch.common.Strings; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.common.unit.TimeValue; +import org.elasticsearch.common.xcontent.ObjectParser; +import org.elasticsearch.common.xcontent.ToXContentObject; +import org.elasticsearch.common.xcontent.XContentBuilder; +import org.elasticsearch.common.xcontent.XContentParser; +import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig; +import org.elasticsearch.xpack.core.ml.job.messages.Messages; +import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; + +import java.io.IOException; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashSet; +import java.util.Objects; +import java.util.Set; + +public class StopDataFrameAnalyticsAction extends Action { + + public static final StopDataFrameAnalyticsAction INSTANCE = new StopDataFrameAnalyticsAction(); + public static final String NAME = "cluster:admin/xpack/ml/data_frame/analytics/stop"; + + private StopDataFrameAnalyticsAction() { + super(NAME); + } + + @Override + public Response newResponse() { + throw new UnsupportedOperationException("usage of Streamable is to be replaced by Writeable"); + } + + @Override + public Writeable.Reader getResponseReader() { + return Response::new; + } + + public static class Request extends BaseTasksRequest implements ToXContentObject { + + 
public static final ParseField TIMEOUT = new ParseField("timeout"); + public static final ParseField ALLOW_NO_MATCH = new ParseField("allow_no_match"); + + private static final ObjectParser PARSER = new ObjectParser<>(NAME, Request::new); + + static { + PARSER.declareString((request, id) -> request.id = id, DataFrameAnalyticsConfig.ID); + PARSER.declareString((request, val) -> request.setTimeout(TimeValue.parseTimeValue(val, TIMEOUT.getPreferredName())), TIMEOUT); + // toXContent emits allow_no_match; declare it here so this strict parser accepts the request's own output. + PARSER.declareBoolean(Request::setAllowNoMatch, ALLOW_NO_MATCH); + } + + /** + * Parses a stop request from a request body. + * The id from the path is used when the body carries none; if both are present they must agree, + * otherwise an IllegalArgumentException is thrown. + */ + public static Request parseRequest(String id, XContentParser parser) { + Request request = PARSER.apply(parser, null); + if (request.getId() == null) { + request.setId(id); + } else if (!Strings.isNullOrEmpty(id) && !id.equals(request.getId())) { + throw new IllegalArgumentException(Messages.getMessage(Messages.INCONSISTENT_ID, DataFrameAnalyticsConfig.ID, + request.getId(), id)); + } + return request; + } + + private String id; + private Set expandedIds = Collections.emptySet(); + private boolean allowNoMatch = true; + + public Request(String id) { + setId(id); + } + + public Request(StreamInput in) throws IOException { + super(in); + id = in.readString(); + expandedIds = new HashSet<>(Arrays.asList(in.readStringArray())); + allowNoMatch = in.readBoolean(); + } + + public Request() {} + + public final void setId(String id) { + this.id = ExceptionsHelper.requireNonNull(id, DataFrameAnalyticsConfig.ID); + } + + public String getId() { + return id; + } + + @Nullable + public Set getExpandedIds() { + return expandedIds; + } + + public void setExpandedIds(Set expandedIds) { + this.expandedIds = Objects.requireNonNull(expandedIds); + } + + public boolean allowNoMatch() { + return allowNoMatch; + } + + public void setAllowNoMatch(boolean allowNoMatch) { + this.allowNoMatch = allowNoMatch; + } + + @Override + public ActionRequestValidationException validate() { + return null; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + super.writeTo(out); + out.writeString(id); + 
out.writeStringArray(expandedIds.toArray(new String[0])); + out.writeBoolean(allowNoMatch); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + return builder + .startObject() + .field(DataFrameAnalyticsConfig.ID.getPreferredName(), id) + .field(ALLOW_NO_MATCH.getPreferredName(), allowNoMatch) + .endObject(); + } + + @Override + public int hashCode() { + return Objects.hash(id, getTimeout(), expandedIds, allowNoMatch); + } + + @Override + public boolean equals(Object obj) { + if (this == obj) { + return true; + } + if (obj == null || obj.getClass() != getClass()) { + return false; + } + StopDataFrameAnalyticsAction.Request other = (StopDataFrameAnalyticsAction.Request) obj; + return Objects.equals(id, other.id) + && Objects.equals(getTimeout(), other.getTimeout()) + && Objects.equals(expandedIds, other.expandedIds) + && allowNoMatch == other.allowNoMatch; + } + + @Override + public String toString() { + return Strings.toString(this); + } + } + + public static class Response extends BaseTasksResponse implements Writeable, ToXContentObject { + + private final boolean stopped; + + public Response(boolean stopped) { + super(null, null); + this.stopped = stopped; + } + + public Response(StreamInput in) throws IOException { + super(in); + stopped = in.readBoolean(); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + super.writeTo(out); + out.writeBoolean(stopped); + } + + public boolean isStopped() { + return stopped; + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + toXContentCommon(builder, params); + builder.field("stopped", stopped); + builder.endObject(); + return builder; + } + + @Override + public boolean equals(Object o) { + if (this == o) + return true; + if (o == null || getClass() != o.getClass()) + return false; + Response response = (Response) o; + return stopped == 
response.stopped; + } + + @Override + public int hashCode() { + return Objects.hash(stopped); + } + } + + static class RequestBuilder extends ActionRequestBuilder { + + RequestBuilder(ElasticsearchClient client, StopDataFrameAnalyticsAction action) { + super(client, action, new Request()); + } + } +} diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/datafeed/DatafeedConfig.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/datafeed/DatafeedConfig.java index 810d97df34636..8c5e86b602cef 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/datafeed/DatafeedConfig.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/datafeed/DatafeedConfig.java @@ -31,6 +31,7 @@ import org.elasticsearch.xpack.core.ml.job.messages.Messages; import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; import org.elasticsearch.xpack.core.ml.utils.MlStrings; +import org.elasticsearch.xpack.core.ml.utils.QueryProvider; import org.elasticsearch.xpack.core.ml.utils.ToXContentParams; import org.elasticsearch.xpack.core.ml.utils.XContentObjectTransformer; import org.elasticsearch.xpack.core.ml.utils.time.TimeUtils; @@ -123,7 +124,7 @@ private static ObjectParser createParser(boolean ignoreUnknownFie parser.declareString((builder, val) -> builder.setFrequency(TimeValue.parseTimeValue(val, FREQUENCY.getPreferredName())), FREQUENCY); parser.declareObject(Builder::setQueryProvider, - (p, c) -> QueryProvider.fromXContent(p, ignoreUnknownFields), + (p, c) -> QueryProvider.fromXContent(p, ignoreUnknownFields, Messages.DATAFEED_CONFIG_QUERY_BAD_FORMAT), QUERY); parser.declareObject(Builder::setAggregationsSafe, (p, c) -> AggProvider.fromXContent(p, ignoreUnknownFields), diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/datafeed/DatafeedUpdate.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/datafeed/DatafeedUpdate.java index ccbb516197217..1ae4159cddb9d 
100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/datafeed/DatafeedUpdate.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/datafeed/DatafeedUpdate.java @@ -22,7 +22,9 @@ import org.elasticsearch.search.builder.SearchSourceBuilder; import org.elasticsearch.xpack.core.ClientHelper; import org.elasticsearch.xpack.core.ml.job.config.Job; +import org.elasticsearch.xpack.core.ml.job.messages.Messages; import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; +import org.elasticsearch.xpack.core.ml.utils.QueryProvider; import org.elasticsearch.xpack.core.ml.utils.XContentObjectTransformer; import java.io.IOException; @@ -53,7 +55,8 @@ public class DatafeedUpdate implements Writeable, ToXContentObject { TimeValue.parseTimeValue(val, DatafeedConfig.QUERY_DELAY.getPreferredName())), DatafeedConfig.QUERY_DELAY); PARSER.declareString((builder, val) -> builder.setFrequency( TimeValue.parseTimeValue(val, DatafeedConfig.FREQUENCY.getPreferredName())), DatafeedConfig.FREQUENCY); - PARSER.declareObject(Builder::setQuery, (p, c) -> QueryProvider.fromXContent(p, false), DatafeedConfig.QUERY); + PARSER.declareObject(Builder::setQuery, (p, c) -> QueryProvider.fromXContent(p, false, Messages.DATAFEED_CONFIG_QUERY_BAD_FORMAT), + DatafeedConfig.QUERY); PARSER.declareObject(Builder::setAggregationsSafe, (p, c) -> AggProvider.fromXContent(p, false), DatafeedConfig.AGGREGATIONS); diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/DataFrameAnalyticsConfig.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/DataFrameAnalyticsConfig.java new file mode 100644 index 0000000000000..0e9acdd44a2fe --- /dev/null +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/DataFrameAnalyticsConfig.java @@ -0,0 +1,312 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. 
Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ +package org.elasticsearch.xpack.core.ml.dataframe; + +import org.elasticsearch.common.ParseField; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.common.unit.ByteSizeUnit; +import org.elasticsearch.common.unit.ByteSizeValue; +import org.elasticsearch.common.xcontent.ObjectParser; +import org.elasticsearch.common.xcontent.ToXContentObject; +import org.elasticsearch.common.xcontent.XContentBuilder; +import org.elasticsearch.common.xcontent.XContentParser; +import org.elasticsearch.common.xcontent.XContentParserUtils; +import org.elasticsearch.search.fetch.subphase.FetchSourceContext; +import org.elasticsearch.xpack.core.ml.dataframe.analyses.DataFrameAnalysis; +import org.elasticsearch.xpack.core.ml.job.messages.Messages; +import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; +import org.elasticsearch.xpack.core.ml.utils.ToXContentParams; + +import java.io.IOException; +import java.util.Collections; +import java.util.HashMap; +import java.util.Map; +import java.util.Objects; + +import static org.elasticsearch.common.xcontent.ObjectParser.ValueType.OBJECT_ARRAY_BOOLEAN_OR_STRING; +import static org.elasticsearch.common.xcontent.ObjectParser.ValueType.VALUE; + +public class DataFrameAnalyticsConfig implements ToXContentObject, Writeable { + + public static final String TYPE = "data_frame_analytics_config"; + + public static final ByteSizeValue DEFAULT_MODEL_MEMORY_LIMIT = new ByteSizeValue(1, ByteSizeUnit.GB); + public static final ByteSizeValue MIN_MODEL_MEMORY_LIMIT = new ByteSizeValue(1, ByteSizeUnit.MB); + public static final ByteSizeValue PROCESS_MEMORY_OVERHEAD = new ByteSizeValue(20, ByteSizeUnit.MB); + + public static final ParseField ID = new ParseField("id"); + public static final ParseField 
SOURCE = new ParseField("source"); + public static final ParseField DEST = new ParseField("dest"); + public static final ParseField ANALYSIS = new ParseField("analysis"); + public static final ParseField CONFIG_TYPE = new ParseField("config_type"); + public static final ParseField ANALYZED_FIELDS = new ParseField("analyzed_fields"); + public static final ParseField MODEL_MEMORY_LIMIT = new ParseField("model_memory_limit"); + public static final ParseField HEADERS = new ParseField("headers"); + + public static final ObjectParser STRICT_PARSER = createParser(false); + public static final ObjectParser LENIENT_PARSER = createParser(true); + + public static ObjectParser createParser(boolean ignoreUnknownFields) { + ObjectParser parser = new ObjectParser<>(TYPE, ignoreUnknownFields, Builder::new); + + parser.declareString((c, s) -> {}, CONFIG_TYPE); + parser.declareString(Builder::setId, ID); + parser.declareObject(Builder::setSource, DataFrameAnalyticsSource.createParser(ignoreUnknownFields), SOURCE); + parser.declareObject(Builder::setDest, DataFrameAnalyticsDest.createParser(ignoreUnknownFields), DEST); + parser.declareObject(Builder::setAnalysis, (p, c) -> parseAnalysis(p, ignoreUnknownFields), ANALYSIS); + parser.declareField(Builder::setAnalyzedFields, + (p, c) -> FetchSourceContext.fromXContent(p), + ANALYZED_FIELDS, + OBJECT_ARRAY_BOOLEAN_OR_STRING); + parser.declareField(Builder::setModelMemoryLimit, + (p, c) -> ByteSizeValue.parseBytesSizeValue(p.text(), MODEL_MEMORY_LIMIT.getPreferredName()), MODEL_MEMORY_LIMIT, VALUE); + if (ignoreUnknownFields) { + // Headers are not parsed by the strict (config) parser, so headers supplied in the _body_ of a REST request will be rejected. + // (For config, headers are explicitly transferred from the auth headers by code in the put data frame actions.) 
+ parser.declareObject(Builder::setHeaders, (p, c) -> p.mapStrings(), HEADERS); + } + return parser; + } + + private static DataFrameAnalysis parseAnalysis(XContentParser parser, boolean ignoreUnknownFields) throws IOException { + XContentParserUtils.ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser::getTokenLocation); + XContentParserUtils.ensureExpectedToken(XContentParser.Token.FIELD_NAME, parser.nextToken(), parser::getTokenLocation); + DataFrameAnalysis analysis = parser.namedObject(DataFrameAnalysis.class, parser.currentName(), ignoreUnknownFields); + XContentParserUtils.ensureExpectedToken(XContentParser.Token.END_OBJECT, parser.nextToken(), parser::getTokenLocation); + return analysis; + } + + private final String id; + private final DataFrameAnalyticsSource source; + private final DataFrameAnalyticsDest dest; + private final DataFrameAnalysis analysis; + private final FetchSourceContext analyzedFields; + /** + * This may be null up to the point of persistence, as the relationship with xpack.ml.max_model_memory_limit + * depends on whether the user explicitly set the value or if the default was requested. null indicates + * the default was requested, which in turn means a default higher than the maximum is silently capped. + * A non-null value higher than xpack.ml.max_model_memory_limit will cause a + * validation error even if it is equal to the default value. This behaviour matches what is done in + * {@link org.elasticsearch.xpack.core.ml.job.config.AnalysisLimits}. 
+ */ + private final ByteSizeValue modelMemoryLimit; + private final Map headers; + + public DataFrameAnalyticsConfig(String id, DataFrameAnalyticsSource source, DataFrameAnalyticsDest dest, + DataFrameAnalysis analysis, Map headers, ByteSizeValue modelMemoryLimit, + FetchSourceContext analyzedFields) { + this.id = ExceptionsHelper.requireNonNull(id, ID); + this.source = ExceptionsHelper.requireNonNull(source, SOURCE); + this.dest = ExceptionsHelper.requireNonNull(dest, DEST); + this.analysis = ExceptionsHelper.requireNonNull(analysis, ANALYSIS); + this.analyzedFields = analyzedFields; + this.modelMemoryLimit = modelMemoryLimit; + this.headers = Collections.unmodifiableMap(headers); + } + + public DataFrameAnalyticsConfig(StreamInput in) throws IOException { + id = in.readString(); + source = new DataFrameAnalyticsSource(in); + dest = new DataFrameAnalyticsDest(in); + analysis = in.readNamedWriteable(DataFrameAnalysis.class); + this.analyzedFields = in.readOptionalWriteable(FetchSourceContext::new); + this.modelMemoryLimit = in.readOptionalWriteable(ByteSizeValue::new); + this.headers = Collections.unmodifiableMap(in.readMap(StreamInput::readString, StreamInput::readString)); + } + + public String getId() { + return id; + } + + public DataFrameAnalyticsSource getSource() { + return source; + } + + public DataFrameAnalyticsDest getDest() { + return dest; + } + + public DataFrameAnalysis getAnalysis() { + return analysis; + } + + public FetchSourceContext getAnalyzedFields() { + return analyzedFields; + } + + public ByteSizeValue getModelMemoryLimit() { + return modelMemoryLimit != null ? 
modelMemoryLimit : DEFAULT_MODEL_MEMORY_LIMIT; + } + + public Map getHeaders() { + return headers; + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + builder.field(ID.getPreferredName(), id); + builder.field(SOURCE.getPreferredName(), source); + builder.field(DEST.getPreferredName(), dest); + + builder.startObject(ANALYSIS.getPreferredName()); + builder.field(analysis.getWriteableName(), analysis); + builder.endObject(); + + if (params.paramAsBoolean(ToXContentParams.INCLUDE_TYPE, false)) { + builder.field(CONFIG_TYPE.getPreferredName(), TYPE); + } + if (analyzedFields != null) { + builder.field(ANALYZED_FIELDS.getPreferredName(), analyzedFields); + } + builder.field(MODEL_MEMORY_LIMIT.getPreferredName(), getModelMemoryLimit().getStringRep()); + if (headers.isEmpty() == false && params.paramAsBoolean(ToXContentParams.FOR_INTERNAL_STORAGE, false)) { + builder.field(HEADERS.getPreferredName(), headers); + } + builder.endObject(); + return builder; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeString(id); + source.writeTo(out); + dest.writeTo(out); + out.writeNamedWriteable(analysis); + out.writeOptionalWriteable(analyzedFields); + out.writeOptionalWriteable(modelMemoryLimit); + out.writeMap(headers, StreamOutput::writeString, StreamOutput::writeString); + } + + @Override + public boolean equals(Object o) { + if (o == this) return true; + if (o == null || getClass() != o.getClass()) return false; + + DataFrameAnalyticsConfig other = (DataFrameAnalyticsConfig) o; + return Objects.equals(id, other.id) + && Objects.equals(source, other.source) + && Objects.equals(dest, other.dest) + && Objects.equals(analysis, other.analysis) + && Objects.equals(headers, other.headers) + && Objects.equals(getModelMemoryLimit(), other.getModelMemoryLimit()) + && Objects.equals(analyzedFields, other.analyzedFields); + } + + @Override + public int hashCode() 
{ + return Objects.hash(id, source, dest, analysis, headers, getModelMemoryLimit(), analyzedFields); + } + + public static String documentId(String id) { + return TYPE + "-" + id; + } + + public static class Builder { + + private String id; + private DataFrameAnalyticsSource source; + private DataFrameAnalyticsDest dest; + private DataFrameAnalysis analysis; + private FetchSourceContext analyzedFields; + private ByteSizeValue modelMemoryLimit; + private ByteSizeValue maxModelMemoryLimit; + private Map headers = Collections.emptyMap(); + + public Builder() {} + + public Builder(String id) { + setId(id); + } + + public Builder(ByteSizeValue maxModelMemoryLimit) { + this.maxModelMemoryLimit = maxModelMemoryLimit; + } + + public Builder(DataFrameAnalyticsConfig config) { + this(config, null); + } + + public Builder(DataFrameAnalyticsConfig config, ByteSizeValue maxModelMemoryLimit) { + this.id = config.id; + this.source = new DataFrameAnalyticsSource(config.source); + this.dest = new DataFrameAnalyticsDest(config.dest); + this.analysis = config.analysis; + this.headers = new HashMap<>(config.headers); + this.modelMemoryLimit = config.modelMemoryLimit; + this.maxModelMemoryLimit = maxModelMemoryLimit; + if (config.analyzedFields != null) { + this.analyzedFields = new FetchSourceContext(true, config.analyzedFields.includes(), config.analyzedFields.excludes()); + } + } + + public String getId() { + return id; + } + + public Builder setId(String id) { + this.id = ExceptionsHelper.requireNonNull(id, ID); + return this; + } + + public Builder setSource(DataFrameAnalyticsSource source) { + this.source = ExceptionsHelper.requireNonNull(source, SOURCE); + return this; + } + + public Builder setDest(DataFrameAnalyticsDest dest) { + this.dest = ExceptionsHelper.requireNonNull(dest, DEST); + return this; + } + + public Builder setAnalysis(DataFrameAnalysis analysis) { + this.analysis = ExceptionsHelper.requireNonNull(analysis, ANALYSIS); + return this; + } + + public Builder 
setAnalyzedFields(FetchSourceContext fields) { + this.analyzedFields = fields; + return this; + } + + public Builder setHeaders(Map headers) { + this.headers = headers; + return this; + } + + public Builder setModelMemoryLimit(ByteSizeValue modelMemoryLimit) { + if (modelMemoryLimit != null && modelMemoryLimit.compareTo(MIN_MODEL_MEMORY_LIMIT) < 0) { + throw new IllegalArgumentException("[" + MODEL_MEMORY_LIMIT.getPreferredName() + + "] must be at least [" + MIN_MODEL_MEMORY_LIMIT.getStringRep() + "]"); + } + this.modelMemoryLimit = modelMemoryLimit; + return this; + } + + private void applyMaxModelMemoryLimit() { + + boolean maxModelMemoryIsSet = maxModelMemoryLimit != null && maxModelMemoryLimit.getMb() > 0; + + if (modelMemoryLimit == null) { + // Default is silently capped if higher than limit + if (maxModelMemoryIsSet && DEFAULT_MODEL_MEMORY_LIMIT.compareTo(maxModelMemoryLimit) > 0) { + modelMemoryLimit = maxModelMemoryLimit; + } + } else if (maxModelMemoryIsSet && modelMemoryLimit.compareTo(maxModelMemoryLimit) > 0) { + // Explicit setting higher than limit is an error + throw ExceptionsHelper.badRequestException(Messages.getMessage(Messages.JOB_CONFIG_MODEL_MEMORY_LIMIT_GREATER_THAN_MAX, + modelMemoryLimit, maxModelMemoryLimit)); + } + } + + public DataFrameAnalyticsConfig build() { + applyMaxModelMemoryLimit(); + return new DataFrameAnalyticsConfig(id, source, dest, analysis, headers, modelMemoryLimit, analyzedFields); + } + } +} diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/DataFrameAnalyticsDest.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/DataFrameAnalyticsDest.java new file mode 100644 index 0000000000000..3bc435336f062 --- /dev/null +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/DataFrameAnalyticsDest.java @@ -0,0 +1,106 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. 
under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ +package org.elasticsearch.xpack.core.ml.dataframe; + +import org.elasticsearch.common.Nullable; +import org.elasticsearch.common.ParseField; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.common.xcontent.ConstructingObjectParser; +import org.elasticsearch.common.xcontent.ToXContentObject; +import org.elasticsearch.common.xcontent.XContentBuilder; +import org.elasticsearch.indices.InvalidIndexNameException; +import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; + +import java.io.IOException; +import java.util.Locale; +import java.util.Objects; + +import static org.elasticsearch.cluster.metadata.MetaDataCreateIndexService.validateIndexOrAliasName; + +public class DataFrameAnalyticsDest implements Writeable, ToXContentObject { + + public static final ParseField INDEX = new ParseField("index"); + public static final ParseField RESULTS_FIELD = new ParseField("results_field"); + + private static final String DEFAULT_RESULTS_FIELD = "ml"; + + public static ConstructingObjectParser createParser(boolean ignoreUnknownFields) { + ConstructingObjectParser parser = new ConstructingObjectParser<>("data_frame_analytics_dest", + ignoreUnknownFields, a -> new DataFrameAnalyticsDest((String) a[0], (String) a[1])); + parser.declareString(ConstructingObjectParser.constructorArg(), INDEX); + parser.declareString(ConstructingObjectParser.optionalConstructorArg(), RESULTS_FIELD); + return parser; + } + + private final String index; + private final String resultsField; + + public DataFrameAnalyticsDest(String index, @Nullable String resultsField) { + this.index = ExceptionsHelper.requireNonNull(index, INDEX); + if (index.isEmpty()) { + throw 
ExceptionsHelper.badRequestException("[{}] must be non-empty", INDEX); + } + this.resultsField = resultsField == null ? DEFAULT_RESULTS_FIELD : resultsField; + } + + public DataFrameAnalyticsDest(StreamInput in) throws IOException { + index = in.readString(); + resultsField = in.readString(); + } + + public DataFrameAnalyticsDest(DataFrameAnalyticsDest other) { + this.index = other.index; + this.resultsField = other.resultsField; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeString(index); + out.writeString(resultsField); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + builder.field(INDEX.getPreferredName(), index); + builder.field(RESULTS_FIELD.getPreferredName(), resultsField); + builder.endObject(); + return builder; + } + + @Override + public boolean equals(Object o) { + if (o == this) return true; + if (o == null || getClass() != o.getClass()) return false; + + DataFrameAnalyticsDest other = (DataFrameAnalyticsDest) o; + return Objects.equals(index, other.index) && Objects.equals(resultsField, other.resultsField); + } + + @Override + public int hashCode() { + return Objects.hash(index, resultsField); + } + + public String getIndex() { + return index; + } + + public String getResultsField() { + return resultsField; + } + + public void validate() { + if (index != null) { + validateIndexOrAliasName(index, InvalidIndexNameException::new); + if (index.toLowerCase(Locale.ROOT).equals(index) == false) { + throw new InvalidIndexNameException(index, "dest.index must be lowercase"); + } + } + } +} diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/DataFrameAnalyticsSource.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/DataFrameAnalyticsSource.java new file mode 100644 index 0000000000000..a57de375f3989 --- /dev/null +++ 
b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/DataFrameAnalyticsSource.java @@ -0,0 +1,144 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ +package org.elasticsearch.xpack.core.ml.dataframe; + +import org.elasticsearch.ElasticsearchException; +import org.elasticsearch.common.Nullable; +import org.elasticsearch.common.ParseField; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.common.xcontent.ConstructingObjectParser; +import org.elasticsearch.common.xcontent.NamedXContentRegistry; +import org.elasticsearch.common.xcontent.ToXContentObject; +import org.elasticsearch.common.xcontent.XContentBuilder; +import org.elasticsearch.index.query.QueryBuilder; +import org.elasticsearch.xpack.core.ml.job.messages.Messages; +import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; +import org.elasticsearch.xpack.core.ml.utils.QueryProvider; +import org.elasticsearch.xpack.core.ml.utils.XContentObjectTransformer; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; +import java.util.Map; +import java.util.Objects; + +public class DataFrameAnalyticsSource implements Writeable, ToXContentObject { + + public static final ParseField INDEX = new ParseField("index"); + public static final ParseField QUERY = new ParseField("query"); + + public static ConstructingObjectParser createParser(boolean ignoreUnknownFields) { + ConstructingObjectParser parser = new ConstructingObjectParser<>("data_frame_analytics_source", + ignoreUnknownFields, a -> new DataFrameAnalyticsSource((String) a[0], (QueryProvider) a[1])); + parser.declareString(ConstructingObjectParser.constructorArg(), 
INDEX); + parser.declareObject(ConstructingObjectParser.optionalConstructorArg(), + (p, c) -> QueryProvider.fromXContent(p, ignoreUnknownFields, Messages.DATA_FRAME_ANALYTICS_BAD_QUERY_FORMAT), QUERY); + return parser; + } + + private final String index; + private final QueryProvider queryProvider; + + public DataFrameAnalyticsSource(String index, @Nullable QueryProvider queryProvider) { + this.index = ExceptionsHelper.requireNonNull(index, INDEX); + if (index.isEmpty()) { + throw ExceptionsHelper.badRequestException("[{}] must be non-empty", INDEX); + } + this.queryProvider = queryProvider == null ? QueryProvider.defaultQuery() : queryProvider; + } + + public DataFrameAnalyticsSource(StreamInput in) throws IOException { + index = in.readString(); + queryProvider = QueryProvider.fromStream(in); + } + + public DataFrameAnalyticsSource(DataFrameAnalyticsSource other) { + this.index = other.index; + this.queryProvider = new QueryProvider(other.queryProvider); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeString(index); + queryProvider.writeTo(out); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + builder.field(INDEX.getPreferredName(), index); + builder.field(QUERY.getPreferredName(), queryProvider.getQuery()); + builder.endObject(); + return builder; + } + + @Override + public boolean equals(Object o) { + if (o == this) return true; + if (o == null || getClass() != o.getClass()) return false; + + DataFrameAnalyticsSource other = (DataFrameAnalyticsSource) o; + return Objects.equals(index, other.index) + && Objects.equals(queryProvider, other.queryProvider); + } + + @Override + public int hashCode() { + return Objects.hash(index, queryProvider); + } + + public String getIndex() { + return index; + } + + /** + * Get the fully parsed query from the semi-parsed stored {@code Map} + * + * @return Fully parsed query + */ + public 
QueryBuilder getParsedQuery() { + Exception exception = queryProvider.getParsingException(); + if (exception != null) { + if (exception instanceof RuntimeException) { + throw (RuntimeException) exception; + } else { + throw new ElasticsearchException(queryProvider.getParsingException()); + } + } + return queryProvider.getParsedQuery(); + } + + Exception getQueryParsingException() { + return queryProvider.getParsingException(); + } + + /** + * Calls the parser and returns any gathered deprecations + * + * @param namedXContentRegistry XContent registry to transform the lazily parsed query + * @return The deprecations from parsing the query + */ + public List getQueryDeprecations(NamedXContentRegistry namedXContentRegistry) { + List deprecations = new ArrayList<>(); + try { + XContentObjectTransformer.queryBuilderTransformer(namedXContentRegistry).fromMap(queryProvider.getQuery(), + deprecations); + } catch (Exception exception) { + // Certain thrown exceptions wrap up the real Illegal argument making it hard to determine cause for the user + if (exception.getCause() instanceof IllegalArgumentException) { + exception = (Exception) exception.getCause(); + } + throw ExceptionsHelper.badRequestException(Messages.DATA_FRAME_ANALYTICS_BAD_QUERY_FORMAT, exception); + } + return deprecations; + } + + public Map getQuery() { + return queryProvider.getQuery(); + } +} diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/DataFrameAnalyticsState.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/DataFrameAnalyticsState.java new file mode 100644 index 0000000000000..d40df259eec57 --- /dev/null +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/DataFrameAnalyticsState.java @@ -0,0 +1,36 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. 
Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ +package org.elasticsearch.xpack.core.ml.dataframe; + +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.common.io.stream.Writeable; + +import java.io.IOException; +import java.util.Locale; + +public enum DataFrameAnalyticsState implements Writeable { + + STARTED, REINDEXING, ANALYZING, STOPPING, STOPPED; + + public static DataFrameAnalyticsState fromString(String name) { + return valueOf(name.trim().toUpperCase(Locale.ROOT)); + } + + public static DataFrameAnalyticsState fromStream(StreamInput in) throws IOException { + return in.readEnum(DataFrameAnalyticsState.class); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeEnum(this); + } + + @Override + public String toString() { + return name().toLowerCase(Locale.ROOT); + } +} diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/DataFrameAnalyticsTaskState.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/DataFrameAnalyticsTaskState.java new file mode 100644 index 0000000000000..994faaaee6cc2 --- /dev/null +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/DataFrameAnalyticsTaskState.java @@ -0,0 +1,105 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. 
+ */ +package org.elasticsearch.xpack.core.ml.dataframe; + +import org.elasticsearch.common.ParseField; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.common.xcontent.ConstructingObjectParser; +import org.elasticsearch.common.xcontent.ObjectParser; +import org.elasticsearch.common.xcontent.XContentBuilder; +import org.elasticsearch.common.xcontent.XContentParser; +import org.elasticsearch.persistent.PersistentTaskState; +import org.elasticsearch.persistent.PersistentTasksCustomMetaData; +import org.elasticsearch.xpack.core.ml.MlTasks; + +import java.io.IOException; +import java.util.Objects; + +public class DataFrameAnalyticsTaskState implements PersistentTaskState { + + public static final String NAME = MlTasks.DATA_FRAME_ANALYTICS_TASK_NAME; + + private static ParseField STATE = new ParseField("state"); + private static ParseField ALLOCATION_ID = new ParseField("allocation_id"); + + private final DataFrameAnalyticsState state; + private final long allocationId; + + private static final ConstructingObjectParser PARSER = + new ConstructingObjectParser<>(NAME, true, + a -> new DataFrameAnalyticsTaskState((DataFrameAnalyticsState) a[0], (long) a[1])); + + static { + PARSER.declareField(ConstructingObjectParser.constructorArg(), p -> { + if (p.currentToken() == XContentParser.Token.VALUE_STRING) { + return DataFrameAnalyticsState.fromString(p.text()); + } + throw new IllegalArgumentException("Unsupported token [" + p.currentToken() + "]"); + }, STATE, ObjectParser.ValueType.STRING); + PARSER.declareLong(ConstructingObjectParser.constructorArg(), ALLOCATION_ID); + } + + public static DataFrameAnalyticsTaskState fromXContent(XContentParser parser) { + try { + return PARSER.parse(parser, null); + } catch (IOException e) { + throw new RuntimeException(e); + } + } + + public DataFrameAnalyticsTaskState(DataFrameAnalyticsState state, long allocationId) { + this.state = 
Objects.requireNonNull(state); + this.allocationId = allocationId; + } + + public DataFrameAnalyticsTaskState(StreamInput in) throws IOException { + this.state = DataFrameAnalyticsState.fromStream(in); + this.allocationId = in.readLong(); + } + + public DataFrameAnalyticsState getState() { + return state; + } + + public boolean isStatusStale(PersistentTasksCustomMetaData.PersistentTask task) { + return allocationId != task.getAllocationId(); + } + + @Override + public String getWriteableName() { + return NAME; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + state.writeTo(out); + out.writeLong(allocationId); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + builder.field(STATE.getPreferredName(), state.toString()); + builder.field(ALLOCATION_ID.getPreferredName(), allocationId); + builder.endObject(); + return builder; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + DataFrameAnalyticsTaskState that = (DataFrameAnalyticsTaskState) o; + return allocationId == that.allocationId && + state == that.state; + } + + @Override + public int hashCode() { + return Objects.hash(state, allocationId); + } +} diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/DataFrameAnalysis.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/DataFrameAnalysis.java new file mode 100644 index 0000000000000..f21533d917602 --- /dev/null +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/DataFrameAnalysis.java @@ -0,0 +1,16 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. 
Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ +package org.elasticsearch.xpack.core.ml.dataframe.analyses; + +import org.elasticsearch.common.io.stream.NamedWriteable; +import org.elasticsearch.common.xcontent.ToXContentObject; + +import java.util.Map; + +public interface DataFrameAnalysis extends ToXContentObject, NamedWriteable { + + Map<String, Object> getParams(); +} diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/MlDataFrameAnalysisNamedXContentProvider.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/MlDataFrameAnalysisNamedXContentProvider.java new file mode 100644 index 0000000000000..a48a23e4a8393 --- /dev/null +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/MlDataFrameAnalysisNamedXContentProvider.java @@ -0,0 +1,37 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License.
+ */ +package org.elasticsearch.xpack.core.ml.dataframe.analyses; + +import org.elasticsearch.common.io.stream.NamedWriteableRegistry; +import org.elasticsearch.common.xcontent.NamedXContentRegistry; +import org.elasticsearch.plugins.spi.NamedXContentProvider; + +import java.util.ArrayList; +import java.util.List; + +public class MlDataFrameAnalysisNamedXContentProvider implements NamedXContentProvider { + + @Override + public List getNamedXContentParsers() { + List namedXContent = new ArrayList<>(); + + namedXContent.add(new NamedXContentRegistry.Entry(DataFrameAnalysis.class, OutlierDetection.NAME, (p, c) -> { + boolean ignoreUnknownFields = (boolean) c; + return OutlierDetection.fromXContent(p, ignoreUnknownFields); + })); + + return namedXContent; + } + + public List getNamedWriteables() { + List namedWriteables = new ArrayList<>(); + + namedWriteables.add(new NamedWriteableRegistry.Entry(DataFrameAnalysis.class, OutlierDetection.NAME.getPreferredName(), + OutlierDetection::new)); + + return namedWriteables; + } +} diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/OutlierDetection.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/OutlierDetection.java new file mode 100644 index 0000000000000..91eb02b7bcdfe --- /dev/null +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/OutlierDetection.java @@ -0,0 +1,169 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. 
+ */ +package org.elasticsearch.xpack.core.ml.dataframe.analyses; + +import org.elasticsearch.common.Nullable; +import org.elasticsearch.common.ParseField; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.common.xcontent.ConstructingObjectParser; +import org.elasticsearch.common.xcontent.ObjectParser; +import org.elasticsearch.common.xcontent.XContentBuilder; +import org.elasticsearch.common.xcontent.XContentParser; +import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; + +import java.io.IOException; +import java.util.HashMap; +import java.util.Locale; +import java.util.Map; +import java.util.Objects; + +public class OutlierDetection implements DataFrameAnalysis { + + public static final ParseField NAME = new ParseField("outlier_detection"); + + public static final ParseField N_NEIGHBORS = new ParseField("n_neighbors"); + public static final ParseField METHOD = new ParseField("method"); + public static final ParseField MINIMUM_SCORE_TO_WRITE_FEATURE_INFLUENCE = + new ParseField("minimum_score_to_write_feature_influence"); + + private static final ConstructingObjectParser LENIENT_PARSER = createParser(true); + private static final ConstructingObjectParser STRICT_PARSER = createParser(false); + + private static ConstructingObjectParser createParser(boolean lenient) { + ConstructingObjectParser parser = new ConstructingObjectParser<>(NAME.getPreferredName(), lenient, + a -> new OutlierDetection((Integer) a[0], (Method) a[1], (Double) a[2])); + parser.declareInt(ConstructingObjectParser.optionalConstructorArg(), N_NEIGHBORS); + parser.declareField(ConstructingObjectParser.optionalConstructorArg(), p -> { + if (p.currentToken() == XContentParser.Token.VALUE_STRING) { + return Method.fromString(p.text()); + } + throw new IllegalArgumentException("Unsupported token [" + p.currentToken() + "]"); + }, METHOD, ObjectParser.ValueType.STRING); + 
parser.declareDouble(ConstructingObjectParser.optionalConstructorArg(), MINIMUM_SCORE_TO_WRITE_FEATURE_INFLUENCE); + return parser; + } + + public static OutlierDetection fromXContent(XContentParser parser, boolean ignoreUnknownFields) { + return ignoreUnknownFields ? LENIENT_PARSER.apply(parser, null) : STRICT_PARSER.apply(parser, null); + } + + private final Integer nNeighbors; + private final Method method; + private final Double minScoreToWriteFeatureInfluence; + + /** + * Constructs the outlier detection configuration + * @param nNeighbors The number of neighbors. Leave unspecified for dynamic detection. + * @param method The method. Leave unspecified for a dynamic mixture of methods. + * @param minScoreToWriteFeatureInfluence The min outlier score required to calculate feature influence. Defaults to 0.1. + */ + public OutlierDetection(@Nullable Integer nNeighbors, @Nullable Method method, @Nullable Double minScoreToWriteFeatureInfluence) { + if (nNeighbors != null && nNeighbors <= 0) { + throw ExceptionsHelper.badRequestException("[{}] must be a positive integer", N_NEIGHBORS.getPreferredName()); + } + + if (minScoreToWriteFeatureInfluence != null && (minScoreToWriteFeatureInfluence < 0.0 || minScoreToWriteFeatureInfluence > 1.0)) { + throw ExceptionsHelper.badRequestException("[{}] must be in [0, 1]", + MINIMUM_SCORE_TO_WRITE_FEATURE_INFLUENCE.getPreferredName()); + } + + this.nNeighbors = nNeighbors; + this.method = method; + this.minScoreToWriteFeatureInfluence = minScoreToWriteFeatureInfluence; + } + + /** + * Constructs the default outlier detection configuration + */ + public OutlierDetection() { + this(null, null, null); + } + + public OutlierDetection(StreamInput in) throws IOException { + nNeighbors = in.readOptionalVInt(); + method = in.readBoolean() ? 
in.readEnum(Method.class) : null; + minScoreToWriteFeatureInfluence = in.readOptionalDouble(); + } + + @Override + public String getWriteableName() { + return NAME.getPreferredName(); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeOptionalVInt(nNeighbors); + + if (method != null) { + out.writeBoolean(true); + out.writeEnum(method); + } else { + out.writeBoolean(false); + } + + out.writeOptionalDouble(minScoreToWriteFeatureInfluence); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + if (nNeighbors != null) { + builder.field(N_NEIGHBORS.getPreferredName(), nNeighbors); + } + if (method != null) { + builder.field(METHOD.getPreferredName(), method); + } + if (minScoreToWriteFeatureInfluence != null) { + builder.field(MINIMUM_SCORE_TO_WRITE_FEATURE_INFLUENCE.getPreferredName(), minScoreToWriteFeatureInfluence); + } + builder.endObject(); + return builder; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + OutlierDetection that = (OutlierDetection) o; + return Objects.equals(nNeighbors, that.nNeighbors) + && Objects.equals(method, that.method) + && Objects.equals(minScoreToWriteFeatureInfluence, that.minScoreToWriteFeatureInfluence); + } + + @Override + public int hashCode() { + return Objects.hash(nNeighbors, method, minScoreToWriteFeatureInfluence); + } + + @Override + public Map getParams() { + Map params = new HashMap<>(); + if (nNeighbors != null) { + params.put(N_NEIGHBORS.getPreferredName(), nNeighbors); + } + if (method != null) { + params.put(METHOD.getPreferredName(), method); + } + if (minScoreToWriteFeatureInfluence != null) { + params.put(MINIMUM_SCORE_TO_WRITE_FEATURE_INFLUENCE.getPreferredName(), minScoreToWriteFeatureInfluence); + } + return params; + } + + public enum Method { + LOF, LDOF, DISTANCE_KTH_NN, DISTANCE_KNN; + + 
public static Method fromString(String value) { + return Method.valueOf(value.toUpperCase(Locale.ROOT)); + } + + @Override + public String toString() { + return name().toLowerCase(Locale.ROOT); + } + } +} diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/Evaluation.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/Evaluation.java new file mode 100644 index 0000000000000..c01c19e33e865 --- /dev/null +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/Evaluation.java @@ -0,0 +1,37 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ +package org.elasticsearch.xpack.core.ml.dataframe.evaluation; + +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.action.search.SearchResponse; +import org.elasticsearch.common.io.stream.NamedWriteable; +import org.elasticsearch.common.xcontent.ToXContentObject; +import org.elasticsearch.search.builder.SearchSourceBuilder; + +import java.util.List; + +/** + * Defines an evaluation + */ +public interface Evaluation extends ToXContentObject, NamedWriteable { + + /** + * Returns the evaluation name + */ + String getName(); + + /** + * Builds the search required to collect data to compute the evaluation result + */ + SearchSourceBuilder buildSearch(); + + /** + * Computes the evaluation result + * @param searchResponse The search response required to compute the result + * @param listener A listener of the results + */ + void evaluate(SearchResponse searchResponse, ActionListener<List<EvaluationMetricResult>> listener); +} diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/EvaluationMetricResult.java
b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/EvaluationMetricResult.java new file mode 100644 index 0000000000000..36b8adf9d4ea3 --- /dev/null +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/EvaluationMetricResult.java @@ -0,0 +1,20 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ +package org.elasticsearch.xpack.core.ml.dataframe.evaluation; + +import org.elasticsearch.common.io.stream.NamedWriteable; +import org.elasticsearch.common.xcontent.ToXContentObject; + +/** + * The result of an evaluation metric + */ +public interface EvaluationMetricResult extends ToXContentObject, NamedWriteable { + + /** + * Returns the name of the metric + */ + String getName(); +} diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/MlEvaluationNamedXContentProvider.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/MlEvaluationNamedXContentProvider.java new file mode 100644 index 0000000000000..f4a6dba88e3b1 --- /dev/null +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/MlEvaluationNamedXContentProvider.java @@ -0,0 +1,69 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. 
+ */ +package org.elasticsearch.xpack.core.ml.dataframe.evaluation; + +import org.elasticsearch.common.io.stream.NamedWriteableRegistry; +import org.elasticsearch.common.xcontent.NamedXContentRegistry; +import org.elasticsearch.plugins.spi.NamedXContentProvider; +import org.elasticsearch.xpack.core.ml.dataframe.evaluation.softclassification.AucRoc; +import org.elasticsearch.xpack.core.ml.dataframe.evaluation.softclassification.BinarySoftClassification; +import org.elasticsearch.xpack.core.ml.dataframe.evaluation.softclassification.ConfusionMatrix; +import org.elasticsearch.xpack.core.ml.dataframe.evaluation.softclassification.Precision; +import org.elasticsearch.xpack.core.ml.dataframe.evaluation.softclassification.Recall; +import org.elasticsearch.xpack.core.ml.dataframe.evaluation.softclassification.ScoreByThresholdResult; +import org.elasticsearch.xpack.core.ml.dataframe.evaluation.softclassification.SoftClassificationMetric; + +import java.util.ArrayList; +import java.util.List; + +public class MlEvaluationNamedXContentProvider implements NamedXContentProvider { + + @Override + public List getNamedXContentParsers() { + List namedXContent = new ArrayList<>(); + + // Evaluations + namedXContent.add(new NamedXContentRegistry.Entry(Evaluation.class, BinarySoftClassification.NAME, + BinarySoftClassification::fromXContent)); + + // Soft classification metrics + namedXContent.add(new NamedXContentRegistry.Entry(SoftClassificationMetric.class, AucRoc.NAME, AucRoc::fromXContent)); + namedXContent.add(new NamedXContentRegistry.Entry(SoftClassificationMetric.class, Precision.NAME, Precision::fromXContent)); + namedXContent.add(new NamedXContentRegistry.Entry(SoftClassificationMetric.class, Recall.NAME, Recall::fromXContent)); + namedXContent.add(new NamedXContentRegistry.Entry(SoftClassificationMetric.class, ConfusionMatrix.NAME, + ConfusionMatrix::fromXContent)); + + return namedXContent; + } + + public List getNamedWriteables() { + List namedWriteables = new 
ArrayList<>(); + + // Evaluations + namedWriteables.add(new NamedWriteableRegistry.Entry(Evaluation.class, BinarySoftClassification.NAME.getPreferredName(), + BinarySoftClassification::new)); + + // Evaluation Metrics + namedWriteables.add(new NamedWriteableRegistry.Entry(SoftClassificationMetric.class, AucRoc.NAME.getPreferredName(), + AucRoc::new)); + namedWriteables.add(new NamedWriteableRegistry.Entry(SoftClassificationMetric.class, Precision.NAME.getPreferredName(), + Precision::new)); + namedWriteables.add(new NamedWriteableRegistry.Entry(SoftClassificationMetric.class, Recall.NAME.getPreferredName(), + Recall::new)); + namedWriteables.add(new NamedWriteableRegistry.Entry(SoftClassificationMetric.class, ConfusionMatrix.NAME.getPreferredName(), + ConfusionMatrix::new)); + + // Evaluation Metrics Results + namedWriteables.add(new NamedWriteableRegistry.Entry(EvaluationMetricResult.class, AucRoc.NAME.getPreferredName(), + AucRoc.Result::new)); + namedWriteables.add(new NamedWriteableRegistry.Entry(EvaluationMetricResult.class, ScoreByThresholdResult.NAME, + ScoreByThresholdResult::new)); + namedWriteables.add(new NamedWriteableRegistry.Entry(EvaluationMetricResult.class, ConfusionMatrix.NAME.getPreferredName(), + ConfusionMatrix.Result::new)); + + return namedWriteables; + } +} diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/softclassification/AbstractConfusionMatrixMetric.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/softclassification/AbstractConfusionMatrixMetric.java new file mode 100644 index 0000000000000..facdcceea194f --- /dev/null +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/softclassification/AbstractConfusionMatrixMetric.java @@ -0,0 +1,102 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. 
Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ +package org.elasticsearch.xpack.core.ml.dataframe.evaluation.softclassification; + +import org.elasticsearch.common.ParseField; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.common.xcontent.XContentBuilder; +import org.elasticsearch.index.query.BoolQueryBuilder; +import org.elasticsearch.index.query.QueryBuilders; +import org.elasticsearch.search.aggregations.AggregationBuilder; +import org.elasticsearch.search.aggregations.AggregationBuilders; +import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; + +abstract class AbstractConfusionMatrixMetric implements SoftClassificationMetric { + + public static final ParseField AT = new ParseField("at"); + + protected final double[] thresholds; + + protected AbstractConfusionMatrixMetric(double[] thresholds) { + this.thresholds = ExceptionsHelper.requireNonNull(thresholds, AT); + if (thresholds.length == 0) { + throw ExceptionsHelper.badRequestException("[" + getMetricName() + "." + AT.getPreferredName() + + "] must have at least one value"); + } + for (double threshold : thresholds) { + if (threshold < 0 || threshold > 1.0) { + throw ExceptionsHelper.badRequestException("[" + getMetricName() + "." 
+ AT.getPreferredName() + + "] values must be in [0.0, 1.0]"); + } + } + } + + protected AbstractConfusionMatrixMetric(StreamInput in) throws IOException { + this.thresholds = in.readDoubleArray(); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeDoubleArray(thresholds); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + builder.field(AT.getPreferredName(), thresholds); + builder.endObject(); + return builder; + } + + @Override + public final List aggs(String actualField, List classInfos) { + List aggs = new ArrayList<>(); + for (double threshold : thresholds) { + aggs.addAll(aggsAt(actualField, classInfos, threshold)); + } + return aggs; + } + + protected abstract List aggsAt(String labelField, List classInfos, double threshold); + + protected enum Condition { + TP, FP, TN, FN; + } + + protected String aggName(ClassInfo classInfo, double threshold, Condition condition) { + return getMetricName() + "_" + classInfo.getName() + "_at_" + threshold + "_" + condition.name(); + } + + protected AggregationBuilder buildAgg(ClassInfo classInfo, double threshold, Condition condition) { + BoolQueryBuilder boolQuery = QueryBuilders.boolQuery(); + switch (condition) { + case TP: + boolQuery.must(classInfo.matchingQuery()); + boolQuery.must(QueryBuilders.rangeQuery(classInfo.getProbabilityField()).gte(threshold)); + break; + case FP: + boolQuery.mustNot(classInfo.matchingQuery()); + boolQuery.must(QueryBuilders.rangeQuery(classInfo.getProbabilityField()).gte(threshold)); + break; + case TN: + boolQuery.mustNot(classInfo.matchingQuery()); + boolQuery.must(QueryBuilders.rangeQuery(classInfo.getProbabilityField()).lt(threshold)); + break; + case FN: + boolQuery.must(classInfo.matchingQuery()); + boolQuery.must(QueryBuilders.rangeQuery(classInfo.getProbabilityField()).lt(threshold)); + break; + default: + throw new IllegalArgumentException("Unknown enum 
value: " + condition); + } + return AggregationBuilders.filter(aggName(classInfo, threshold, condition), boolQuery); + } +} diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/softclassification/AucRoc.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/softclassification/AucRoc.java new file mode 100644 index 0000000000000..228dac00bfb68 --- /dev/null +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/softclassification/AucRoc.java @@ -0,0 +1,350 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ +package org.elasticsearch.xpack.core.ml.dataframe.evaluation.softclassification; + +import org.elasticsearch.common.ParseField; +import org.elasticsearch.common.Strings; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.common.xcontent.ConstructingObjectParser; +import org.elasticsearch.common.xcontent.ToXContentObject; +import org.elasticsearch.common.xcontent.XContentBuilder; +import org.elasticsearch.common.xcontent.XContentParser; +import org.elasticsearch.index.query.QueryBuilders; +import org.elasticsearch.search.aggregations.AggregationBuilder; +import org.elasticsearch.search.aggregations.AggregationBuilders; +import org.elasticsearch.search.aggregations.Aggregations; +import org.elasticsearch.search.aggregations.bucket.filter.Filter; +import org.elasticsearch.search.aggregations.metrics.Percentiles; +import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetricResult; +import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; + +import java.io.IOException; +import 
java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.Comparator; +import java.util.List; +import java.util.Objects; +import java.util.stream.IntStream; + +/** + * Area under the curve (AUC) of the receiver operating characteristic (ROC). + * The ROC curve is a plot of the TPR (true positive rate) against + * the FPR (false positive rate) over a varying threshold. + * + * This particular implementation is making use of ES aggregations + * to calculate the curve. It then uses the trapezoidal rule to calculate + * the AUC. + * + * In particular, in order to calculate the ROC, we get percentiles of TP + * and FP against the predicted probability. We call those Rate-Threshold + * curves. We then scan ROC points from each Rate-Threshold curve against the + * other using interpolation. This gives us an approximation of the ROC curve + * that has the advantage of being efficient and resilient to some edge cases. + * + * When this is used for multi-class classification, it will calculate the ROC + * curve of each class versus the rest. + */ +public class AucRoc implements SoftClassificationMetric { + + public static final ParseField NAME = new ParseField("auc_roc"); + + public static final ParseField INCLUDE_CURVE = new ParseField("include_curve"); + + public static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>(NAME.getPreferredName(), + a -> new AucRoc((Boolean) a[0])); + + static { + PARSER.declareBoolean(ConstructingObjectParser.optionalConstructorArg(), INCLUDE_CURVE); + } + + private static final String PERCENTILES = "percentiles"; + + public static AucRoc fromXContent(XContentParser parser) { + return PARSER.apply(parser, null); + } + + private final boolean includeCurve; + + public AucRoc(Boolean includeCurve) { + this.includeCurve = includeCurve == null ? 
false : includeCurve; + } + + public AucRoc(StreamInput in) throws IOException { + this.includeCurve = in.readBoolean(); + } + + @Override + public String getWriteableName() { + return NAME.getPreferredName(); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeBoolean(includeCurve); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + builder.field(INCLUDE_CURVE.getPreferredName(), includeCurve); + builder.endObject(); + return builder; + } + + @Override + public String getMetricName() { + return NAME.getPreferredName(); + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + AucRoc that = (AucRoc) o; + return Objects.equals(includeCurve, that.includeCurve); + } + + @Override + public int hashCode() { + return Objects.hash(includeCurve); + } + + @Override + public List aggs(String actualField, List classInfos) { + double[] percentiles = IntStream.range(1, 100).mapToDouble(v -> (double) v).toArray(); + List aggs = new ArrayList<>(); + for (ClassInfo classInfo : classInfos) { + AggregationBuilder percentilesForClassValueAgg = AggregationBuilders + .filter(evaluatedLabelAggName(classInfo), classInfo.matchingQuery()) + .subAggregation( + AggregationBuilders.percentiles(PERCENTILES).field(classInfo.getProbabilityField()).percentiles(percentiles)); + AggregationBuilder percentilesForRestAgg = AggregationBuilders + .filter(restLabelsAggName(classInfo), QueryBuilders.boolQuery().mustNot(classInfo.matchingQuery())) + .subAggregation( + AggregationBuilders.percentiles(PERCENTILES).field(classInfo.getProbabilityField()).percentiles(percentiles)); + aggs.add(percentilesForClassValueAgg); + aggs.add(percentilesForRestAgg); + } + return aggs; + } + + private String evaluatedLabelAggName(ClassInfo classInfo) { + return getMetricName() + "_" + classInfo.getName(); + 
} + + private String restLabelsAggName(ClassInfo classInfo) { + return getMetricName() + "_non_" + classInfo.getName(); + } + + @Override + public EvaluationMetricResult evaluate(ClassInfo classInfo, Aggregations aggs) { + Filter classAgg = aggs.get(evaluatedLabelAggName(classInfo)); + Filter restAgg = aggs.get(restLabelsAggName(classInfo)); + double[] tpPercentiles = percentilesArray(classAgg.getAggregations().get(PERCENTILES), + "[" + getMetricName() + "] requires at least one actual_field to have the value [" + classInfo.getName() + "]"); + double[] fpPercentiles = percentilesArray(restAgg.getAggregations().get(PERCENTILES), + "[" + getMetricName() + "] requires at least one actual_field to have a different value than [" + classInfo.getName() + "]"); + List aucRocCurve = buildAucRocCurve(tpPercentiles, fpPercentiles); + double aucRocScore = calculateAucScore(aucRocCurve); + return new Result(aucRocScore, includeCurve ? aucRocCurve : Collections.emptyList()); + } + + private static double[] percentilesArray(Percentiles percentiles, String errorIfUndefined) { + double[] result = new double[99]; + percentiles.forEach(percentile -> { + if (Double.isNaN(percentile.getValue())) { + throw ExceptionsHelper.badRequestException(errorIfUndefined); + } + result[((int) percentile.getPercent()) - 1] = percentile.getValue(); + }); + return result; + } + + /** + * Visible for testing + */ + static List buildAucRocCurve(double[] tpPercentiles, double[] fpPercentiles) { + assert tpPercentiles.length == fpPercentiles.length; + assert tpPercentiles.length == 99; + + List aucRocCurve = new ArrayList<>(); + aucRocCurve.add(new AucRocPoint(0.0, 0.0, 1.0)); + aucRocCurve.add(new AucRocPoint(1.0, 1.0, 0.0)); + RateThresholdCurve tpCurve = new RateThresholdCurve(tpPercentiles, true); + RateThresholdCurve fpCurve = new RateThresholdCurve(fpPercentiles, false); + aucRocCurve.addAll(tpCurve.scanPoints(fpCurve)); + aucRocCurve.addAll(fpCurve.scanPoints(tpCurve)); + 
Collections.sort(aucRocCurve); + return aucRocCurve; + } + + /** + * Visible for testing + */ + static double calculateAucScore(List rocCurve) { + // Calculates AUC based on the trapezoid rule + double aucRoc = 0.0; + for (int i = 1; i < rocCurve.size(); i++) { + AucRocPoint left = rocCurve.get(i - 1); + AucRocPoint right = rocCurve.get(i); + aucRoc += (right.fpr - left.fpr) * (right.tpr + left.tpr) / 2; + } + return aucRoc; + } + + private static class RateThresholdCurve { + + private final double[] percentiles; + private final boolean isTp; + + private RateThresholdCurve(double[] percentiles, boolean isTp) { + this.percentiles = percentiles; + this.isTp = isTp; + } + + private double getRate(int index) { + return 1 - 0.01 * (index + 1); + } + + private double getThreshold(int index) { + return percentiles[index]; + } + + private double interpolateRate(double threshold) { + int binarySearchResult = Arrays.binarySearch(percentiles, threshold); + if (binarySearchResult >= 0) { + return getRate(binarySearchResult); + } else { + int right = (binarySearchResult * -1) -1; + int left = right - 1; + if (right >= percentiles.length) { + return 0.0; + } else if (left < 0) { + return 1.0; + } else { + double rightRate = getRate(right); + double leftRate = getRate(left); + return interpolate(threshold, percentiles[left], leftRate, percentiles[right], rightRate); + } + } + } + + private List scanPoints(RateThresholdCurve againstCurve) { + List points = new ArrayList<>(); + for (int index = 0; index < percentiles.length; index++) { + double rate = getRate(index); + double scannedThreshold = getThreshold(index); + double againstRate = againstCurve.interpolateRate(scannedThreshold); + AucRocPoint point; + if (isTp) { + point = new AucRocPoint(rate, againstRate, scannedThreshold); + } else { + point = new AucRocPoint(againstRate, rate, scannedThreshold); + } + points.add(point); + } + return points; + } + } + + public static final class AucRocPoint implements Comparable, 
ToXContentObject, Writeable { + double tpr; + double fpr; + double threshold; + + private AucRocPoint(double tpr, double fpr, double threshold) { + this.tpr = tpr; + this.fpr = fpr; + this.threshold = threshold; + } + + private AucRocPoint(StreamInput in) throws IOException { + this.tpr = in.readDouble(); + this.fpr = in.readDouble(); + this.threshold = in.readDouble(); + } + + @Override + public int compareTo(AucRocPoint o) { + return Comparator.comparingDouble((AucRocPoint p) -> p.threshold).reversed() + .thenComparing(p -> p.fpr) + .thenComparing(p -> p.tpr) + .compare(this, o); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeDouble(tpr); + out.writeDouble(fpr); + out.writeDouble(threshold); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + builder.field("tpr", tpr); + builder.field("fpr", fpr); + builder.field("threshold", threshold); + builder.endObject(); + return builder; + } + + @Override + public String toString() { + return Strings.toString(this); + } + } + + private static double interpolate(double x, double x1, double y1, double x2, double y2) { + return y1 + (x - x1) * (y2 - y1) / (x2 - x1); + } + + public static class Result implements EvaluationMetricResult { + + private final double score; + private final List curve; + + public Result(double score, List curve) { + this.score = score; + this.curve = Objects.requireNonNull(curve); + } + + public Result(StreamInput in) throws IOException { + this.score = in.readDouble(); + this.curve = in.readList(AucRocPoint::new); + } + + @Override + public String getWriteableName() { + return NAME.getPreferredName(); + } + + @Override + public String getName() { + return NAME.getPreferredName(); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeDouble(score); + out.writeList(curve); + } + + @Override + public XContentBuilder 
toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + builder.field("score", score); + if (curve.isEmpty() == false) { + builder.field("curve", curve); + } + builder.endObject(); + return builder; + } + } +} diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/softclassification/BinarySoftClassification.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/softclassification/BinarySoftClassification.java new file mode 100644 index 0000000000000..f594e7598fc20 --- /dev/null +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/softclassification/BinarySoftClassification.java @@ -0,0 +1,212 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. 
+ */ +package org.elasticsearch.xpack.core.ml.dataframe.evaluation.softclassification; + +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.action.search.SearchResponse; +import org.elasticsearch.common.Nullable; +import org.elasticsearch.common.ParseField; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.common.xcontent.ConstructingObjectParser; +import org.elasticsearch.common.xcontent.XContentBuilder; +import org.elasticsearch.common.xcontent.XContentParser; +import org.elasticsearch.index.query.BoolQueryBuilder; +import org.elasticsearch.index.query.QueryBuilder; +import org.elasticsearch.index.query.QueryBuilders; +import org.elasticsearch.search.aggregations.AggregationBuilder; +import org.elasticsearch.search.aggregations.Aggregations; +import org.elasticsearch.search.builder.SearchSourceBuilder; +import org.elasticsearch.xpack.core.ml.dataframe.evaluation.Evaluation; +import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetricResult; +import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.Comparator; +import java.util.List; +import java.util.Objects; + +/** + * Evaluation of binary soft classification methods, e.g. outlier detection. + * This is useful to evaluate problems where a model outputs a probability of whether + * a data frame row belongs to one of two groups. 
+ */ +public class BinarySoftClassification implements Evaluation { + + public static final ParseField NAME = new ParseField("binary_soft_classification"); + + private static final ParseField ACTUAL_FIELD = new ParseField("actual_field"); + private static final ParseField PREDICTED_PROBABILITY_FIELD = new ParseField("predicted_probability_field"); + private static final ParseField METRICS = new ParseField("metrics"); + + public static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>( + NAME.getPreferredName(), a -> new BinarySoftClassification((String) a[0], (String) a[1], (List) a[2])); + + static { + PARSER.declareString(ConstructingObjectParser.constructorArg(), ACTUAL_FIELD); + PARSER.declareString(ConstructingObjectParser.constructorArg(), PREDICTED_PROBABILITY_FIELD); + PARSER.declareNamedObjects(ConstructingObjectParser.optionalConstructorArg(), + (p, c, n) -> p.namedObject(SoftClassificationMetric.class, n, null), METRICS); + } + + public static BinarySoftClassification fromXContent(XContentParser parser) { + return PARSER.apply(parser, null); + } + + /** + * The field where the actual class is marked up. + * The value of this field is assumed to either be 1 or 0, or true or false. + */ + private final String actualField; + + /** + * The field of the predicted probability in [0.0, 1.0]. + */ + private final String predictedProbabilityField; + + /** + * The list of metrics to calculate + */ + private final List metrics; + + public BinarySoftClassification(String actualField, String predictedProbabilityField, + @Nullable List metrics) { + this.actualField = ExceptionsHelper.requireNonNull(actualField, ACTUAL_FIELD); + this.predictedProbabilityField = ExceptionsHelper.requireNonNull(predictedProbabilityField, PREDICTED_PROBABILITY_FIELD); + this.metrics = initMetrics(metrics); + } + + private static List initMetrics(@Nullable List parsedMetrics) { + List metrics = parsedMetrics == null ? 
defaultMetrics() : parsedMetrics; + if (metrics.isEmpty()) { + throw ExceptionsHelper.badRequestException("[{}] must have one or more metrics", NAME.getPreferredName()); + } + Collections.sort(metrics, Comparator.comparing(SoftClassificationMetric::getMetricName)); + return metrics; + } + + private static List defaultMetrics() { + List defaultMetrics = new ArrayList<>(4); + defaultMetrics.add(new AucRoc(false)); + defaultMetrics.add(new Precision(Arrays.asList(0.25, 0.5, 0.75))); + defaultMetrics.add(new Recall(Arrays.asList(0.25, 0.5, 0.75))); + defaultMetrics.add(new ConfusionMatrix(Arrays.asList(0.25, 0.5, 0.75))); + return defaultMetrics; + } + + public BinarySoftClassification(StreamInput in) throws IOException { + this.actualField = in.readString(); + this.predictedProbabilityField = in.readString(); + this.metrics = in.readNamedWriteableList(SoftClassificationMetric.class); + } + + @Override + public String getWriteableName() { + return NAME.getPreferredName(); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeString(actualField); + out.writeString(predictedProbabilityField); + out.writeNamedWriteableList(metrics); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + builder.field(ACTUAL_FIELD.getPreferredName(), actualField); + builder.field(PREDICTED_PROBABILITY_FIELD.getPreferredName(), predictedProbabilityField); + + builder.startObject(METRICS.getPreferredName()); + for (SoftClassificationMetric metric : metrics) { + builder.field(metric.getMetricName(), metric); + } + builder.endObject(); + + builder.endObject(); + return builder; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + BinarySoftClassification that = (BinarySoftClassification) o; + return Objects.equals(actualField, that.actualField) + && 
Objects.equals(predictedProbabilityField, that.predictedProbabilityField) + && Objects.equals(metrics, that.metrics); + } + + @Override + public int hashCode() { + return Objects.hash(actualField, predictedProbabilityField, metrics); + } + + @Override + public String getName() { + return NAME.getPreferredName(); + } + + @Override + public SearchSourceBuilder buildSearch() { + SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder(); + searchSourceBuilder.size(0); + searchSourceBuilder.query(buildQuery()); + for (SoftClassificationMetric metric : metrics) { + List aggs = metric.aggs(actualField, Collections.singletonList(new BinaryClassInfo())); + aggs.forEach(searchSourceBuilder::aggregation); + } + return searchSourceBuilder; + } + + private QueryBuilder buildQuery() { + BoolQueryBuilder boolQuery = QueryBuilders.boolQuery(); + boolQuery.filter(QueryBuilders.existsQuery(actualField)); + boolQuery.filter(QueryBuilders.existsQuery(predictedProbabilityField)); + return boolQuery; + } + + @Override + public void evaluate(SearchResponse searchResponse, ActionListener> listener) { + if (searchResponse.getHits().getTotalHits().value == 0) { + listener.onFailure(ExceptionsHelper.badRequestException("No documents found containing both [{}, {}] fields", actualField, + predictedProbabilityField)); + return; + } + + List results = new ArrayList<>(); + Aggregations aggs = searchResponse.getAggregations(); + BinaryClassInfo binaryClassInfo = new BinaryClassInfo(); + for (SoftClassificationMetric metric : metrics) { + results.add(metric.evaluate(binaryClassInfo, aggs)); + } + listener.onResponse(results); + } + + private class BinaryClassInfo implements SoftClassificationMetric.ClassInfo { + + private QueryBuilder matchingQuery = QueryBuilders.queryStringQuery(actualField + ": (1 OR true)"); + + @Override + public String getName() { + return String.valueOf(true); + } + + @Override + public QueryBuilder matchingQuery() { + return matchingQuery; + } + + @Override + 
public String getProbabilityField() { + return predictedProbabilityField; + } + } +} diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/softclassification/ConfusionMatrix.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/softclassification/ConfusionMatrix.java new file mode 100644 index 0000000000000..54f245962d515 --- /dev/null +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/softclassification/ConfusionMatrix.java @@ -0,0 +1,163 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ +package org.elasticsearch.xpack.core.ml.dataframe.evaluation.softclassification; + +import org.elasticsearch.common.ParseField; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.common.xcontent.ConstructingObjectParser; +import org.elasticsearch.common.xcontent.XContentBuilder; +import org.elasticsearch.common.xcontent.XContentParser; +import org.elasticsearch.search.aggregations.AggregationBuilder; +import org.elasticsearch.search.aggregations.Aggregations; +import org.elasticsearch.search.aggregations.bucket.filter.Filter; +import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetricResult; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; + +public class ConfusionMatrix extends AbstractConfusionMatrixMetric { + + public static final ParseField NAME = new ParseField("confusion_matrix"); + + private static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>(NAME.getPreferredName(), + a -> new ConfusionMatrix((List) a[0])); + + static { + 
PARSER.declareDoubleArray(ConstructingObjectParser.constructorArg(), AT); + } + + public static ConfusionMatrix fromXContent(XContentParser parser) { + return PARSER.apply(parser, null); + } + + public ConfusionMatrix(List at) { + super(at.stream().mapToDouble(Double::doubleValue).toArray()); + } + + public ConfusionMatrix(StreamInput in) throws IOException { + super(in); + } + + @Override + public String getWriteableName() { + return NAME.getPreferredName(); + } + + @Override + public String getMetricName() { + return NAME.getPreferredName(); + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + ConfusionMatrix that = (ConfusionMatrix) o; + return Arrays.equals(thresholds, that.thresholds); + } + + @Override + public int hashCode() { + return Arrays.hashCode(thresholds); + } + + @Override + protected List aggsAt(String labelField, List classInfos, double threshold) { + List aggs = new ArrayList<>(); + for (ClassInfo classInfo : classInfos) { + aggs.add(buildAgg(classInfo, threshold, Condition.TP)); + aggs.add(buildAgg(classInfo, threshold, Condition.FP)); + aggs.add(buildAgg(classInfo, threshold, Condition.TN)); + aggs.add(buildAgg(classInfo, threshold, Condition.FN)); + } + return aggs; + } + + @Override + public EvaluationMetricResult evaluate(ClassInfo classInfo, Aggregations aggs) { + long[] tp = new long[thresholds.length]; + long[] fp = new long[thresholds.length]; + long[] tn = new long[thresholds.length]; + long[] fn = new long[thresholds.length]; + for (int i = 0; i < thresholds.length; i++) { + Filter tpAgg = aggs.get(aggName(classInfo, thresholds[i], Condition.TP)); + Filter fpAgg = aggs.get(aggName(classInfo, thresholds[i], Condition.FP)); + Filter tnAgg = aggs.get(aggName(classInfo, thresholds[i], Condition.TN)); + Filter fnAgg = aggs.get(aggName(classInfo, thresholds[i], Condition.FN)); + tp[i] = tpAgg.getDocCount(); + fp[i] = fpAgg.getDocCount(); + tn[i] = 
tnAgg.getDocCount(); + fn[i] = fnAgg.getDocCount(); + } + return new Result(thresholds, tp, fp, tn, fn); + } + + public static class Result implements EvaluationMetricResult { + + private final double[] thresholds; + private final long[] tp; + private final long[] fp; + private final long[] tn; + private final long[] fn; + + public Result(double[] thresholds, long[] tp, long[] fp, long[] tn, long[] fn) { + assert thresholds.length == tp.length; + assert thresholds.length == fp.length; + assert thresholds.length == tn.length; + assert thresholds.length == fn.length; + this.thresholds = thresholds; + this.tp = tp; + this.fp = fp; + this.tn = tn; + this.fn = fn; + } + + public Result(StreamInput in) throws IOException { + this.thresholds = in.readDoubleArray(); + this.tp = in.readLongArray(); + this.fp = in.readLongArray(); + this.tn = in.readLongArray(); + this.fn = in.readLongArray(); + } + + @Override + public String getWriteableName() { + return NAME.getPreferredName(); + } + + @Override + public String getName() { + return NAME.getPreferredName(); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeDoubleArray(thresholds); + out.writeLongArray(tp); + out.writeLongArray(fp); + out.writeLongArray(tn); + out.writeLongArray(fn); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + for (int i = 0; i < thresholds.length; i++) { + builder.startObject(String.valueOf(thresholds[i])); + builder.field("tp", tp[i]); + builder.field("fp", fp[i]); + builder.field("tn", tn[i]); + builder.field("fn", fn[i]); + builder.endObject(); + } + builder.endObject(); + return builder; + } + } +} diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/softclassification/Precision.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/softclassification/Precision.java new file mode 
100644 index 0000000000000..d38a52bb203e8 --- /dev/null +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/softclassification/Precision.java @@ -0,0 +1,91 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ +package org.elasticsearch.xpack.core.ml.dataframe.evaluation.softclassification; + +import org.elasticsearch.common.ParseField; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.xcontent.ConstructingObjectParser; +import org.elasticsearch.common.xcontent.XContentParser; +import org.elasticsearch.search.aggregations.AggregationBuilder; +import org.elasticsearch.search.aggregations.Aggregations; +import org.elasticsearch.search.aggregations.bucket.filter.Filter; +import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetricResult; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; + +public class Precision extends AbstractConfusionMatrixMetric { + + public static final ParseField NAME = new ParseField("precision"); + + private static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>(NAME.getPreferredName(), + a -> new Precision((List) a[0])); + + static { + PARSER.declareDoubleArray(ConstructingObjectParser.constructorArg(), AT); + } + + public static Precision fromXContent(XContentParser parser) { + return PARSER.apply(parser, null); + } + + public Precision(List at) { + super(at.stream().mapToDouble(Double::doubleValue).toArray()); + } + + public Precision(StreamInput in) throws IOException { + super(in); + } + + @Override + public String getWriteableName() { + return NAME.getPreferredName(); + } + + @Override + public String getMetricName() { + return NAME.getPreferredName(); + } + + @Override 
+ public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + Precision that = (Precision) o; + return Arrays.equals(thresholds, that.thresholds); + } + + @Override + public int hashCode() { + return Arrays.hashCode(thresholds); + } + + @Override + protected List aggsAt(String labelField, List classInfos, double threshold) { + List aggs = new ArrayList<>(); + for (ClassInfo classInfo : classInfos) { + aggs.add(buildAgg(classInfo, threshold, Condition.TP)); + aggs.add(buildAgg(classInfo, threshold, Condition.FP)); + } + return aggs; + } + + @Override + public EvaluationMetricResult evaluate(ClassInfo classInfo, Aggregations aggs) { + double[] precisions = new double[thresholds.length]; + for (int i = 0; i < precisions.length; i++) { + double threshold = thresholds[i]; + Filter tpAgg = aggs.get(aggName(classInfo, threshold, Condition.TP)); + Filter fpAgg = aggs.get(aggName(classInfo, threshold, Condition.FP)); + long tp = tpAgg.getDocCount(); + long fp = fpAgg.getDocCount(); + precisions[i] = tp + fp == 0 ? 0.0 : (double) tp / (tp + fp); + } + return new ScoreByThresholdResult(NAME.getPreferredName(), thresholds, precisions); + } +} diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/softclassification/Recall.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/softclassification/Recall.java new file mode 100644 index 0000000000000..5c4ab57241d95 --- /dev/null +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/softclassification/Recall.java @@ -0,0 +1,91 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. 
+ */ +package org.elasticsearch.xpack.core.ml.dataframe.evaluation.softclassification; + +import org.elasticsearch.common.ParseField; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.xcontent.ConstructingObjectParser; +import org.elasticsearch.common.xcontent.XContentParser; +import org.elasticsearch.search.aggregations.AggregationBuilder; +import org.elasticsearch.search.aggregations.Aggregations; +import org.elasticsearch.search.aggregations.bucket.filter.Filter; +import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetricResult; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; + +public class Recall extends AbstractConfusionMatrixMetric { + + public static final ParseField NAME = new ParseField("recall"); + + private static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>(NAME.getPreferredName(), + a -> new Recall((List) a[0])); + + static { + PARSER.declareDoubleArray(ConstructingObjectParser.constructorArg(), AT); + } + + public static Recall fromXContent(XContentParser parser) { + return PARSER.apply(parser, null); + } + + public Recall(List at) { + super(at.stream().mapToDouble(Double::doubleValue).toArray()); + } + + public Recall(StreamInput in) throws IOException { + super(in); + } + + @Override + public String getWriteableName() { + return NAME.getPreferredName(); + } + + @Override + public String getMetricName() { + return NAME.getPreferredName(); + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + Recall that = (Recall) o; + return Arrays.equals(thresholds, that.thresholds); + } + + @Override + public int hashCode() { + return Arrays.hashCode(thresholds); + } + + @Override + protected List aggsAt(String actualField, List classInfos, double threshold) { + List aggs = new ArrayList<>(); + for (ClassInfo classInfo: 
classInfos) { + aggs.add(buildAgg(classInfo, threshold, Condition.TP)); + aggs.add(buildAgg(classInfo, threshold, Condition.FN)); + } + return aggs; + } + + @Override + public EvaluationMetricResult evaluate(ClassInfo classInfo, Aggregations aggs) { + double[] recalls = new double[thresholds.length]; + for (int i = 0; i < recalls.length; i++) { + double threshold = thresholds[i]; + Filter tpAgg = aggs.get(aggName(classInfo, threshold, Condition.TP)); + Filter fnAgg =aggs.get(aggName(classInfo, threshold, Condition.FN)); + long tp = tpAgg.getDocCount(); + long fn = fnAgg.getDocCount(); + recalls[i] = tp + fn == 0 ? 0.0 : (double) tp / (tp + fn); + } + return new ScoreByThresholdResult(NAME.getPreferredName(), thresholds, recalls); + } +} diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/softclassification/ScoreByThresholdResult.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/softclassification/ScoreByThresholdResult.java new file mode 100644 index 0000000000000..bd6b6e7db25a1 --- /dev/null +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/softclassification/ScoreByThresholdResult.java @@ -0,0 +1,63 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. 
+ */ +package org.elasticsearch.xpack.core.ml.dataframe.evaluation.softclassification; + +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.common.xcontent.XContentBuilder; +import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetricResult; + +import java.io.IOException; +import java.util.Objects; + +public class ScoreByThresholdResult implements EvaluationMetricResult { + + public static final String NAME = "score_by_threshold_result"; + + private final String name; + private final double[] thresholds; + private final double[] scores; + + public ScoreByThresholdResult(String name, double[] thresholds, double[] scores) { + assert thresholds.length == scores.length; + this.name = Objects.requireNonNull(name); + this.thresholds = thresholds; + this.scores = scores; + } + + public ScoreByThresholdResult(StreamInput in) throws IOException { + this.name = in.readString(); + this.thresholds = in.readDoubleArray(); + this.scores = in.readDoubleArray(); + } + + @Override + public String getWriteableName() { + return NAME; + } + + @Override + public String getName() { + return name; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeString(name); + out.writeDoubleArray(thresholds); + out.writeDoubleArray(scores); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + for (int i = 0; i < thresholds.length; i++) { + builder.field(String.valueOf(thresholds[i]), scores[i]); + } + builder.endObject(); + return builder; + } +} diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/softclassification/SoftClassificationMetric.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/softclassification/SoftClassificationMetric.java new file mode 100644 index 
0000000000000..dfb256e9b52f2 --- /dev/null +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/softclassification/SoftClassificationMetric.java @@ -0,0 +1,60 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ +package org.elasticsearch.xpack.core.ml.dataframe.evaluation.softclassification; + +import org.elasticsearch.common.io.stream.NamedWriteable; +import org.elasticsearch.common.xcontent.ToXContentObject; +import org.elasticsearch.index.query.QueryBuilder; +import org.elasticsearch.search.aggregations.AggregationBuilder; +import org.elasticsearch.search.aggregations.Aggregations; +import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetricResult; + +import java.util.List; + +public interface SoftClassificationMetric extends ToXContentObject, NamedWriteable { + + /** + * The information of a specific class + */ + interface ClassInfo { + + /** + * Returns the class name + */ + String getName(); + + /** + * Returns a query that matches documents of the class + */ + QueryBuilder matchingQuery(); + + /** + * Returns the field that has the probability to be of the class + */ + String getProbabilityField(); + } + + /** + * Returns the name of the metric (which may differ to the writeable name) + */ + String getMetricName(); + + /** + * Builds the aggregation that collect required data to compute the metric + * @param actualField the field that stores the actual class + * @param classInfos the information of each class to compute the metric for + * @return the aggregations required to compute the metric + */ + List aggs(String actualField, List classInfos); + + /** + * Calculates the metric result for a given class + * @param classInfo the class to calculate the metric for + * @param aggs the aggregations + * @return 
the metric result + */ + EvaluationMetricResult evaluate(ClassInfo classInfo, Aggregations aggs); +} diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/job/messages/Messages.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/job/messages/Messages.java index 22eb0dc357bed..417184f8a752b 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/job/messages/Messages.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/job/messages/Messages.java @@ -50,6 +50,10 @@ public final class Messages { "Datafeed frequency [{0}] must be a multiple of the aggregation interval [{1}]"; public static final String DATAFEED_ID_ALREADY_TAKEN = "A datafeed with id [{0}] already exists"; + public static final String DATA_FRAME_ANALYTICS_BAD_QUERY_FORMAT = "Data Frame Analytics config query is not parsable"; + public static final String DATA_FRAME_ANALYTICS_BAD_FIELD_FILTER = + "No compatible fields could be detected in index [{0}] with name [{1}]"; + public static final String FILTER_CANNOT_DELETE = "Cannot delete filter [{0}] currently used by jobs {1}"; public static final String FILTER_CONTAINS_TOO_MANY_ITEMS = "Filter [{0}] contains too many items; up to [{1}] items are allowed"; public static final String FILTER_NOT_FOUND = "No filter with id [{0}] exists"; diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/job/persistence/ElasticsearchMappings.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/job/persistence/ElasticsearchMappings.java index 19cb985c588bd..bc69f4b5d5e20 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/job/persistence/ElasticsearchMappings.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/job/persistence/ElasticsearchMappings.java @@ -26,6 +26,10 @@ import org.elasticsearch.xpack.core.ml.datafeed.ChunkingConfig; import org.elasticsearch.xpack.core.ml.datafeed.DatafeedConfig; 
import org.elasticsearch.xpack.core.ml.datafeed.DelayedDataCheckConfig; +import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig; +import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsDest; +import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsSource; +import org.elasticsearch.xpack.core.ml.dataframe.analyses.OutlierDetection; import org.elasticsearch.xpack.core.ml.job.config.AnalysisConfig; import org.elasticsearch.xpack.core.ml.job.config.AnalysisLimits; import org.elasticsearch.xpack.core.ml.job.config.DataDescription; @@ -144,6 +148,7 @@ public static XContentBuilder configMapping() throws IOException { addJobConfigFields(builder); addDatafeedConfigFields(builder); + addDataFrameAnalyticsFields(builder); builder.endObject() .endObject() @@ -386,6 +391,52 @@ public static void addDatafeedConfigFields(XContentBuilder builder) throws IOExc .endObject(); } + public static void addDataFrameAnalyticsFields(XContentBuilder builder) throws IOException { + builder.startObject(DataFrameAnalyticsConfig.ID.getPreferredName()) + .field(TYPE, KEYWORD) + .endObject() + .startObject(DataFrameAnalyticsConfig.SOURCE.getPreferredName()) + .startObject(PROPERTIES) + .startObject(DataFrameAnalyticsSource.INDEX.getPreferredName()) + .field(TYPE, KEYWORD) + .endObject() + .startObject(DataFrameAnalyticsSource.QUERY.getPreferredName()) + .field(ENABLED, false) + .endObject() + .endObject() + .endObject() + .startObject(DataFrameAnalyticsConfig.DEST.getPreferredName()) + .startObject(PROPERTIES) + .startObject(DataFrameAnalyticsDest.INDEX.getPreferredName()) + .field(TYPE, KEYWORD) + .endObject() + .startObject(DataFrameAnalyticsDest.RESULTS_FIELD.getPreferredName()) + .field(TYPE, KEYWORD) + .endObject() + .endObject() + .endObject() + .startObject(DataFrameAnalyticsConfig.ANALYZED_FIELDS.getPreferredName()) + .field(ENABLED, false) + .endObject() + .startObject(DataFrameAnalyticsConfig.ANALYSIS.getPreferredName()) + 
.startObject(PROPERTIES) + .startObject(OutlierDetection.NAME.getPreferredName()) + .startObject(PROPERTIES) + .startObject(OutlierDetection.N_NEIGHBORS.getPreferredName()) + .field(TYPE, INTEGER) + .endObject() + .startObject(OutlierDetection.METHOD.getPreferredName()) + .field(TYPE, KEYWORD) + .endObject() + .startObject(OutlierDetection.MINIMUM_SCORE_TO_WRITE_FEATURE_INFLUENCE.getPreferredName()) + .field(TYPE, DOUBLE) + .endObject() + .endObject() + .endObject() + .endObject() + .endObject(); + } + /** * Creates a default mapping which has a dynamic template that * treats all dynamically added fields as keywords. This is needed diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/job/results/ReservedFieldNames.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/job/results/ReservedFieldNames.java index bb16436df35c3..39036abb693b0 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/job/results/ReservedFieldNames.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/job/results/ReservedFieldNames.java @@ -9,6 +9,10 @@ import org.elasticsearch.xpack.core.ml.datafeed.ChunkingConfig; import org.elasticsearch.xpack.core.ml.datafeed.DatafeedConfig; import org.elasticsearch.xpack.core.ml.datafeed.DelayedDataCheckConfig; +import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig; +import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsDest; +import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsSource; +import org.elasticsearch.xpack.core.ml.dataframe.analyses.OutlierDetection; import org.elasticsearch.xpack.core.ml.job.config.AnalysisConfig; import org.elasticsearch.xpack.core.ml.job.config.AnalysisLimits; import org.elasticsearch.xpack.core.ml.job.config.DataDescription; @@ -268,6 +272,20 @@ public final class ReservedFieldNames { ChunkingConfig.MODE_FIELD.getPreferredName(), ChunkingConfig.TIME_SPAN_FIELD.getPreferredName(), + 
DataFrameAnalyticsConfig.ID.getPreferredName(), + DataFrameAnalyticsConfig.SOURCE.getPreferredName(), + DataFrameAnalyticsConfig.DEST.getPreferredName(), + DataFrameAnalyticsConfig.ANALYSIS.getPreferredName(), + DataFrameAnalyticsConfig.ANALYZED_FIELDS.getPreferredName(), + DataFrameAnalyticsDest.INDEX.getPreferredName(), + DataFrameAnalyticsDest.RESULTS_FIELD.getPreferredName(), + DataFrameAnalyticsSource.INDEX.getPreferredName(), + DataFrameAnalyticsSource.QUERY.getPreferredName(), + OutlierDetection.NAME.getPreferredName(), + OutlierDetection.N_NEIGHBORS.getPreferredName(), + OutlierDetection.METHOD.getPreferredName(), + OutlierDetection.MINIMUM_SCORE_TO_WRITE_FEATURE_INFLUENCE.getPreferredName(), + ElasticsearchMappings.CONFIG_TYPE, GetResult._ID, diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/process/writer/RecordWriter.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/process/writer/RecordWriter.java index b66fd948a5a83..2d4c636172eca 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/process/writer/RecordWriter.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/process/writer/RecordWriter.java @@ -10,7 +10,7 @@ /** * Interface for classes that write arrays of strings to the - * Ml analytics processes. + * Ml data frame analytics processes. 
*/ public interface RecordWriter { /** diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/utils/ExceptionsHelper.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/utils/ExceptionsHelper.java index 47c0d4f64f96f..320eace983590 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/utils/ExceptionsHelper.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/utils/ExceptionsHelper.java @@ -10,6 +10,7 @@ import org.elasticsearch.ResourceAlreadyExistsException; import org.elasticsearch.ResourceNotFoundException; import org.elasticsearch.action.search.ShardSearchFailure; +import org.elasticsearch.common.ParseField; import org.elasticsearch.rest.RestStatus; import org.elasticsearch.search.SearchShardTarget; import org.elasticsearch.xpack.core.ml.job.messages.Messages; @@ -34,6 +35,14 @@ public static ResourceAlreadyExistsException datafeedAlreadyExists(String datafe return new ResourceAlreadyExistsException(Messages.getMessage(Messages.DATAFEED_ID_ALREADY_TAKEN, datafeedId)); } + public static ResourceNotFoundException missingDataFrameAnalytics(String id) { + return new ResourceNotFoundException("No known data frame analytics with id [{}]", id); + } + + public static ResourceAlreadyExistsException dataFrameAnalyticsAlreadyExists(String id) { + return new ResourceAlreadyExistsException("A data frame analytics with id [{}] already exists", id); + } + public static ElasticsearchException serverError(String msg) { return new ElasticsearchException(msg); } @@ -86,4 +95,8 @@ public static T requireNonNull(T obj, String paramName) { } return obj; } + + public static T requireNonNull(T obj, ParseField paramName) { + return requireNonNull(obj, paramName.getPreferredName()); + } } diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/datafeed/QueryProvider.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/utils/QueryProvider.java similarity 
index 86% rename from x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/datafeed/QueryProvider.java rename to x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/utils/QueryProvider.java index ff6d2f595af81..d20b64a4ce8b5 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/datafeed/QueryProvider.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/utils/QueryProvider.java @@ -3,7 +3,7 @@ * or more contributor license agreements. Licensed under the Elastic License; * you may not use this file except in compliance with the Elastic License. */ -package org.elasticsearch.xpack.core.ml.datafeed; +package org.elasticsearch.xpack.core.ml.utils; import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; @@ -19,9 +19,6 @@ import org.elasticsearch.index.query.MatchAllQueryBuilder; import org.elasticsearch.index.query.QueryBuilder; import org.elasticsearch.index.query.QueryBuilders; -import org.elasticsearch.xpack.core.ml.job.messages.Messages; -import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; -import org.elasticsearch.xpack.core.ml.utils.XContentObjectTransformer; import java.io.IOException; import java.util.Collections; @@ -29,22 +26,22 @@ import java.util.Map; import java.util.Objects; -class QueryProvider implements Writeable, ToXContentObject { +public class QueryProvider implements Writeable, ToXContentObject { - private static final Logger logger = LogManager.getLogger(AggProvider.class); + private static final Logger logger = LogManager.getLogger(QueryProvider.class); private Exception parsingException; private QueryBuilder parsedQuery; private Map query; - static QueryProvider defaultQuery() { + public static QueryProvider defaultQuery() { return new QueryProvider( Collections.singletonMap(MatchAllQueryBuilder.NAME, Collections.emptyMap()), QueryBuilders.matchAllQuery(), null); } - static QueryProvider fromXContent(XContentParser parser, boolean 
lenient) throws IOException { + public static QueryProvider fromXContent(XContentParser parser, boolean lenient, String failureMessage) throws IOException { Map query = parser.mapOrdered(); QueryBuilder parsedQuery = null; Exception exception = null; @@ -56,15 +53,15 @@ static QueryProvider fromXContent(XContentParser parser, boolean lenient) throws } exception = ex; if (lenient) { - logger.warn(Messages.DATAFEED_CONFIG_QUERY_BAD_FORMAT, ex); + logger.warn(failureMessage, ex); } else { - throw ExceptionsHelper.badRequestException(Messages.DATAFEED_CONFIG_QUERY_BAD_FORMAT, ex); + throw ExceptionsHelper.badRequestException(failureMessage, ex); } } return new QueryProvider(query, parsedQuery, exception); } - static QueryProvider fromParsedQuery(QueryBuilder parsedQuery) throws IOException { + public static QueryProvider fromParsedQuery(QueryBuilder parsedQuery) throws IOException { return parsedQuery == null ? null : new QueryProvider( @@ -73,7 +70,7 @@ static QueryProvider fromParsedQuery(QueryBuilder parsedQuery) throws IOExceptio null); } - static QueryProvider fromStream(StreamInput in) throws IOException { + public static QueryProvider fromStream(StreamInput in) throws IOException { if (in.getVersion().onOrAfter(Version.V_6_7_0)) { // Has our bug fix for query/agg providers return new QueryProvider(in.readMap(), in.readOptionalNamedWriteable(QueryBuilder.class), in.readException()); } else if (in.getVersion().onOrAfter(Version.V_6_6_0)) { // Has the bug, but supports lazy objects @@ -89,7 +86,7 @@ static QueryProvider fromStream(StreamInput in) throws IOException { this.parsingException = parsingException; } - QueryProvider(QueryProvider other) { + public QueryProvider(QueryProvider other) { this(other.query, other.parsedQuery, other.parsingException); } diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/dataframe/action/AbstractSerializingDataFrameTestCase.java 
b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/dataframe/action/AbstractSerializingDataFrameTestCase.java index 8b633cdfc26d5..14cbdef148ca4 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/dataframe/action/AbstractSerializingDataFrameTestCase.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/dataframe/action/AbstractSerializingDataFrameTestCase.java @@ -13,6 +13,10 @@ import org.elasticsearch.common.xcontent.ToXContent; import org.elasticsearch.search.SearchModule; import org.elasticsearch.test.AbstractSerializingTestCase; +import org.elasticsearch.xpack.core.dataframe.DataFrameField; +import org.elasticsearch.xpack.core.dataframe.DataFrameNamedXContentProvider; +import org.elasticsearch.xpack.core.dataframe.transforms.SyncConfig; +import org.elasticsearch.xpack.core.dataframe.transforms.TimeSyncConfig; import org.junit.Before; import java.util.List; @@ -30,7 +34,11 @@ public void registerNamedObjects() { SearchModule searchModule = new SearchModule(Settings.EMPTY, false, emptyList()); List namedWriteables = searchModule.getNamedWriteables(); + namedWriteables.add(new NamedWriteableRegistry.Entry(SyncConfig.class, DataFrameField.TIME_BASED_SYNC.getPreferredName(), + TimeSyncConfig::new)); + List namedXContents = searchModule.getNamedXContents(); + namedXContents.addAll(new DataFrameNamedXContentProvider().getNamedXContentParsers()); namedWriteableRegistry = new NamedWriteableRegistry(namedWriteables); namedXContentRegistry = new NamedXContentRegistry(namedXContents); diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/dataframe/action/AbstractWireSerializingDataFrameTestCase.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/dataframe/action/AbstractWireSerializingDataFrameTestCase.java index 91a7ec54dd256..47d7860b71da0 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/dataframe/action/AbstractWireSerializingDataFrameTestCase.java 
+++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/dataframe/action/AbstractWireSerializingDataFrameTestCase.java @@ -12,6 +12,10 @@ import org.elasticsearch.common.xcontent.NamedXContentRegistry; import org.elasticsearch.search.SearchModule; import org.elasticsearch.test.AbstractWireSerializingTestCase; +import org.elasticsearch.xpack.core.dataframe.DataFrameField; +import org.elasticsearch.xpack.core.dataframe.DataFrameNamedXContentProvider; +import org.elasticsearch.xpack.core.dataframe.transforms.SyncConfig; +import org.elasticsearch.xpack.core.dataframe.transforms.TimeSyncConfig; import org.junit.Before; import java.util.List; @@ -30,7 +34,11 @@ public void registerNamedObjects() { SearchModule searchModule = new SearchModule(Settings.EMPTY, false, emptyList()); List namedWriteables = searchModule.getNamedWriteables(); + namedWriteables.add(new NamedWriteableRegistry.Entry(SyncConfig.class, DataFrameField.TIME_BASED_SYNC.getPreferredName(), + TimeSyncConfig::new)); + List namedXContents = searchModule.getNamedXContents(); + namedXContents.addAll(new DataFrameNamedXContentProvider().getNamedXContentParsers()); namedWriteableRegistry = new NamedWriteableRegistry(namedWriteables); namedXContentRegistry = new NamedXContentRegistry(namedXContents); diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/dataframe/action/PreviewDataFrameTransformActionRequestTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/dataframe/action/PreviewDataFrameTransformActionRequestTests.java index c3a921a90d26b..ea6f2a47f4692 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/dataframe/action/PreviewDataFrameTransformActionRequestTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/dataframe/action/PreviewDataFrameTransformActionRequestTests.java @@ -13,6 +13,7 @@ import org.elasticsearch.common.xcontent.json.JsonXContent; import 
org.elasticsearch.xpack.core.dataframe.action.PreviewDataFrameTransformAction.Request; import org.elasticsearch.xpack.core.dataframe.transforms.DataFrameTransformConfig; +import org.elasticsearch.xpack.core.dataframe.transforms.DataFrameTransformConfigTests; import org.elasticsearch.xpack.core.dataframe.transforms.DestConfig; import org.elasticsearch.xpack.core.dataframe.transforms.pivot.PivotConfigTests; @@ -39,9 +40,14 @@ protected boolean supportsUnknownFields() { @Override protected Request createTestInstance() { - DataFrameTransformConfig config = new DataFrameTransformConfig("transform-preview", randomSourceConfig(), + DataFrameTransformConfig config = new DataFrameTransformConfig( + "transform-preview", + randomSourceConfig(), new DestConfig("unused-transform-preview-index", null), - null, PivotConfigTests.randomPivotConfig(), null); + randomBoolean() ? DataFrameTransformConfigTests.randomSyncConfig() : null, + null, + PivotConfigTests.randomPivotConfig(), + null); return new Request(config); } diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/dataframe/transforms/AbstractSerializingDataFrameTestCase.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/dataframe/transforms/AbstractSerializingDataFrameTestCase.java index 2b64fadac051a..79edb8084551d 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/dataframe/transforms/AbstractSerializingDataFrameTestCase.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/dataframe/transforms/AbstractSerializingDataFrameTestCase.java @@ -19,6 +19,7 @@ import org.elasticsearch.search.aggregations.BaseAggregationBuilder; import org.elasticsearch.test.AbstractSerializingTestCase; import org.elasticsearch.xpack.core.dataframe.DataFrameField; +import org.elasticsearch.xpack.core.dataframe.DataFrameNamedXContentProvider; import org.junit.Before; import java.util.Collections; @@ -48,12 +49,15 @@ public void registerAggregationNamedObjects() 
throws Exception { MockDeprecatedQueryBuilder::new)); namedWriteables.add(new NamedWriteableRegistry.Entry(AggregationBuilder.class, MockDeprecatedAggregationBuilder.NAME, MockDeprecatedAggregationBuilder::new)); + namedWriteables.add(new NamedWriteableRegistry.Entry(SyncConfig.class, DataFrameField.TIME_BASED_SYNC.getPreferredName(), + TimeSyncConfig::new)); List namedXContents = searchModule.getNamedXContents(); namedXContents.add(new NamedXContentRegistry.Entry(QueryBuilder.class, new ParseField(MockDeprecatedQueryBuilder.NAME), (p, c) -> MockDeprecatedQueryBuilder.fromXContent(p))); namedXContents.add(new NamedXContentRegistry.Entry(BaseAggregationBuilder.class, new ParseField(MockDeprecatedAggregationBuilder.NAME), (p, c) -> MockDeprecatedAggregationBuilder.fromXContent(p))); + namedXContents.addAll(new DataFrameNamedXContentProvider().getNamedXContentParsers()); namedWriteableRegistry = new NamedWriteableRegistry(namedWriteables); namedXContentRegistry = new NamedXContentRegistry(namedXContents); diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/dataframe/transforms/DataFrameTransformConfigTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/dataframe/transforms/DataFrameTransformConfigTests.java index 907c8eb98e69f..dd5b5c9ff8841 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/dataframe/transforms/DataFrameTransformConfigTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/dataframe/transforms/DataFrameTransformConfigTests.java @@ -46,6 +46,7 @@ public static DataFrameTransformConfig randomDataFrameTransformConfigWithoutHead return new DataFrameTransformConfig(id, randomSourceConfig(), randomDestConfig(), + randomBoolean() ? null : randomSyncConfig(), null, PivotConfigTests.randomPivotConfig(), randomBoolean() ? 
null : randomAlphaOfLengthBetween(1, 1000), @@ -57,6 +58,7 @@ public static DataFrameTransformConfig randomDataFrameTransformConfig(String id) return new DataFrameTransformConfig(id, randomSourceConfig(), randomDestConfig(), + randomBoolean() ? null : randomSyncConfig(), randomHeaders(), PivotConfigTests.randomPivotConfig(), randomBoolean() ? null : randomAlphaOfLengthBetween(1, 1000), @@ -66,13 +68,17 @@ public static DataFrameTransformConfig randomDataFrameTransformConfig(String id) public static DataFrameTransformConfig randomInvalidDataFrameTransformConfig() { if (randomBoolean()) { - return new DataFrameTransformConfig(randomAlphaOfLengthBetween(1, 10), randomInvalidSourceConfig(), - randomDestConfig(), randomHeaders(), PivotConfigTests.randomPivotConfig(), - randomBoolean() ? null : randomAlphaOfLengthBetween(1, 100)); + return new DataFrameTransformConfig(randomAlphaOfLengthBetween(1, 10), randomInvalidSourceConfig(), randomDestConfig(), + randomBoolean() ? randomSyncConfig() : null, randomHeaders(), PivotConfigTests.randomPivotConfig(), + randomBoolean() ? null : randomAlphaOfLengthBetween(1, 1000)); } // else - return new DataFrameTransformConfig(randomAlphaOfLengthBetween(1, 10), randomSourceConfig(), - randomDestConfig(), randomHeaders(), PivotConfigTests.randomInvalidPivotConfig(), - randomBoolean() ? null : randomAlphaOfLengthBetween(1, 100)); + return new DataFrameTransformConfig(randomAlphaOfLengthBetween(1, 10), randomSourceConfig(), randomDestConfig(), + randomBoolean() ? randomSyncConfig() : null, randomHeaders(), PivotConfigTests.randomInvalidPivotConfig(), + randomBoolean() ? 
null : randomAlphaOfLengthBetween(1, 1000)); + } + + public static SyncConfig randomSyncConfig() { + return TimeSyncConfigTests.randomTimeSyncConfig(); } @Before @@ -223,11 +229,11 @@ public void testXContentForInternalStorage() throws IOException { public void testMaxLengthDescription() { IllegalArgumentException exception = expectThrows(IllegalArgumentException.class, () -> new DataFrameTransformConfig("id", - randomSourceConfig(), randomDestConfig(), null, PivotConfigTests.randomPivotConfig(), randomAlphaOfLength(1001))); + randomSourceConfig(), randomDestConfig(), null, null, PivotConfigTests.randomPivotConfig(), randomAlphaOfLength(1001))); assertThat(exception.getMessage(), equalTo("[description] must be less than 1000 characters in length.")); String description = randomAlphaOfLength(1000); DataFrameTransformConfig config = new DataFrameTransformConfig("id", - randomSourceConfig(), randomDestConfig(), null, PivotConfigTests.randomPivotConfig(), description); + randomSourceConfig(), randomDestConfig(), null, null, PivotConfigTests.randomPivotConfig(), description); assertThat(description, equalTo(config.getDescription())); } diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/dataframe/transforms/TimeSyncConfigTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/dataframe/transforms/TimeSyncConfigTests.java new file mode 100644 index 0000000000000..763e13e77aee0 --- /dev/null +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/dataframe/transforms/TimeSyncConfigTests.java @@ -0,0 +1,38 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. 
+ */ + +package org.elasticsearch.xpack.core.dataframe.transforms; + +import org.elasticsearch.common.io.stream.Writeable.Reader; +import org.elasticsearch.common.unit.TimeValue; +import org.elasticsearch.common.xcontent.XContentParser; +import org.elasticsearch.test.AbstractSerializingTestCase; +import org.elasticsearch.xpack.core.dataframe.transforms.TimeSyncConfig; + +import java.io.IOException; + +public class TimeSyncConfigTests extends AbstractSerializingTestCase { + + public static TimeSyncConfig randomTimeSyncConfig() { + return new TimeSyncConfig(randomAlphaOfLengthBetween(1, 10), new TimeValue(randomNonNegativeLong())); + } + + @Override + protected TimeSyncConfig doParseInstance(XContentParser parser) throws IOException { + return TimeSyncConfig.fromXContent(parser, false); + } + + @Override + protected TimeSyncConfig createTestInstance() { + return randomTimeSyncConfig(); + } + + @Override + protected Reader instanceReader() { + return TimeSyncConfig::new; + } + +} diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/MlTasksTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/MlTasksTests.java index 3afe76b8b171f..f2015b1a2bbb5 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/MlTasksTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/MlTasksTests.java @@ -22,6 +22,7 @@ import static org.hamcrest.Matchers.containsInAnyOrder; import static org.hamcrest.Matchers.empty; +import static org.hamcrest.Matchers.equalTo; public class MlTasksTests extends ESTestCase { public void testGetJobState() { @@ -161,4 +162,10 @@ public void testUnallocatedDatafeedIds() { assertThat(MlTasks.unallocatedDatafeedIds(tasksBuilder.build(), nodes), containsInAnyOrder("datafeed_without_assignment", "datafeed_without_node")); } + + public void testDataFrameAnalyticsTaskIds() { + String taskId = MlTasks.dataFrameAnalyticsTaskId("foo"); + assertThat(taskId, 
equalTo("data_frame_analytics-foo")); + assertThat(MlTasks.dataFrameAnalyticsIdFromTaskId(taskId), equalTo("foo")); + } } diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/EvaluateDataFrameActionRequestTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/EvaluateDataFrameActionRequestTests.java new file mode 100644 index 0000000000000..e899b7e6642da --- /dev/null +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/EvaluateDataFrameActionRequestTests.java @@ -0,0 +1,58 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ +package org.elasticsearch.xpack.core.ml.action; + +import org.elasticsearch.common.io.stream.NamedWriteableRegistry; +import org.elasticsearch.common.xcontent.NamedXContentRegistry; +import org.elasticsearch.common.xcontent.XContentParser; +import org.elasticsearch.test.AbstractStreamableXContentTestCase; +import org.elasticsearch.xpack.core.ml.action.EvaluateDataFrameAction.Request; +import org.elasticsearch.xpack.core.ml.dataframe.evaluation.MlEvaluationNamedXContentProvider; +import org.elasticsearch.xpack.core.ml.dataframe.evaluation.softclassification.BinarySoftClassificationTests; + +import java.util.ArrayList; +import java.util.List; + +public class EvaluateDataFrameActionRequestTests extends AbstractStreamableXContentTestCase { + + @Override + protected NamedWriteableRegistry getNamedWriteableRegistry() { + return new NamedWriteableRegistry(new MlEvaluationNamedXContentProvider().getNamedWriteables()); + } + + @Override + protected NamedXContentRegistry xContentRegistry() { + return new NamedXContentRegistry(new MlEvaluationNamedXContentProvider().getNamedXContentParsers()); + } + + @Override + protected Request createTestInstance() { + Request request = 
new Request(); + int indicesCount = randomIntBetween(1, 5); + List indices = new ArrayList<>(indicesCount); + for (int i = 0; i < indicesCount; i++) { + indices.add(randomAlphaOfLength(10)); + } + request.setIndices(indices); + request.setEvaluation(BinarySoftClassificationTests.createRandom()); + return request; + } + + @Override + protected boolean supportsUnknownFields() { + return false; + } + + @Override + protected Request createBlankInstance() { + return new Request(); + } + + @Override + protected Request doParseInstance(XContentParser parser) { + return Request.parseRequest(parser); + } +} diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/GetDataFrameAnalyticsActionResponseTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/GetDataFrameAnalyticsActionResponseTests.java new file mode 100644 index 0000000000000..38a3396316602 --- /dev/null +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/GetDataFrameAnalyticsActionResponseTests.java @@ -0,0 +1,55 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. 
+ */ +package org.elasticsearch.xpack.core.ml.action; + +import org.elasticsearch.common.io.stream.NamedWriteableRegistry; +import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.common.xcontent.NamedXContentRegistry; +import org.elasticsearch.search.SearchModule; +import org.elasticsearch.test.AbstractStreamableTestCase; +import org.elasticsearch.xpack.core.action.util.QueryPage; +import org.elasticsearch.xpack.core.ml.action.GetDataFrameAnalyticsAction.Response; +import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig; +import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfigTests; +import org.elasticsearch.xpack.core.ml.dataframe.analyses.MlDataFrameAnalysisNamedXContentProvider; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; + +public class GetDataFrameAnalyticsActionResponseTests extends AbstractStreamableTestCase { + + @Override + protected NamedWriteableRegistry getNamedWriteableRegistry() { + List namedWriteables = new ArrayList<>(); + namedWriteables.addAll(new MlDataFrameAnalysisNamedXContentProvider().getNamedWriteables()); + namedWriteables.addAll(new SearchModule(Settings.EMPTY, false, Collections.emptyList()).getNamedWriteables()); + return new NamedWriteableRegistry(namedWriteables); + } + + @Override + protected NamedXContentRegistry xContentRegistry() { + List namedXContent = new ArrayList<>(); + namedXContent.addAll(new MlDataFrameAnalysisNamedXContentProvider().getNamedXContentParsers()); + namedXContent.addAll(new SearchModule(Settings.EMPTY, false, Collections.emptyList()).getNamedXContents()); + return new NamedXContentRegistry(namedXContent); + } + + @Override + protected Response createTestInstance() { + int listSize = randomInt(10); + List analytics = new ArrayList<>(listSize); + for (int j = 0; j < listSize; j++) { + analytics.add(DataFrameAnalyticsConfigTests.createRandom(DataFrameAnalyticsConfigTests.randomValidId())); + } + return new 
Response(new QueryPage<>(analytics, analytics.size(), Response.RESULTS_FIELD)); + } + + @Override + protected Response createBlankInstance() { + return new Response(); + } +} diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/GetDataFrameAnalyticsRequestTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/GetDataFrameAnalyticsRequestTests.java new file mode 100644 index 0000000000000..438474076c910 --- /dev/null +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/GetDataFrameAnalyticsRequestTests.java @@ -0,0 +1,27 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ +package org.elasticsearch.xpack.core.ml.action; + +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.test.AbstractWireSerializingTestCase; +import org.elasticsearch.xpack.core.action.util.PageParams; +import org.elasticsearch.xpack.core.ml.action.GetDataFrameAnalyticsAction.Request; + +public class GetDataFrameAnalyticsRequestTests extends AbstractWireSerializingTestCase { + + @Override + protected Request createTestInstance() { + Request request = new Request(); + request.setResourceId(randomAlphaOfLength(20)); + request.setPageParams(new PageParams(randomIntBetween(0, 100), randomIntBetween(0, 100))); + return request; + } + + @Override + protected Writeable.Reader instanceReader() { + return Request::new; + } +} diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/GetDataFrameAnalyticsStatsActionResponseTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/GetDataFrameAnalyticsStatsActionResponseTests.java new file mode 100644 index 0000000000000..e01618520f5a8 --- /dev/null +++ 
b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/GetDataFrameAnalyticsStatsActionResponseTests.java
@@ -0,0 +1,37 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License;
+ * you may not use this file except in compliance with the Elastic License.
+ */
+package org.elasticsearch.xpack.core.ml.action;
+
+import org.elasticsearch.common.io.stream.Writeable;
+import org.elasticsearch.test.AbstractWireSerializingTestCase;
+import org.elasticsearch.xpack.core.action.util.QueryPage;
+import org.elasticsearch.xpack.core.ml.action.GetDataFrameAnalyticsStatsAction.Response;
+import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfigTests;
+import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsState;
+
+import java.util.ArrayList;
+import java.util.List;
+/** Supplies random stats {@link Response}s, some entries lacking a progress percentage, to the serialization harness. */
+public class GetDataFrameAnalyticsStatsActionResponseTests extends AbstractWireSerializingTestCase<Response> {
+
+    @Override
+    protected Response createTestInstance() {
+        int listSize = randomInt(10);
+        List<Response.Stats> analytics = new ArrayList<>(listSize);
+        for (int j = 0; j < listSize; j++) {
+            Integer progressPercentage = randomBoolean() ? null : randomIntBetween(0, 100); // null exercises the optional field
+            Response.Stats stats = new Response.Stats(DataFrameAnalyticsConfigTests.randomValidId(),
+                randomFrom(DataFrameAnalyticsState.values()), progressPercentage, null, randomAlphaOfLength(20));
+            analytics.add(stats);
+        }
+        return new Response(new QueryPage<>(analytics, analytics.size(), GetDataFrameAnalyticsAction.Response.RESULTS_FIELD));
+    }
+
+    @Override
+    protected Writeable.Reader<Response> instanceReader() {
+        return Response::new;
+    }
+}
diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/GetDataFrameAnalyticsStatsRequestTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/GetDataFrameAnalyticsStatsRequestTests.java
new file mode 100644
index 0000000000000..918d04873ef2c
--- /dev/null
+++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/GetDataFrameAnalyticsStatsRequestTests.java
@@ -0,0 +1,26 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License;
+ * you may not use this file except in compliance with the Elastic License.
+ */
+package org.elasticsearch.xpack.core.ml.action;
+
+import org.elasticsearch.common.io.stream.Writeable;
+import org.elasticsearch.test.AbstractWireSerializingTestCase;
+import org.elasticsearch.xpack.core.action.util.PageParams;
+import org.elasticsearch.xpack.core.ml.action.GetDataFrameAnalyticsStatsAction.Request;
+/** Supplies random stats {@link Request}s (random id and page params) and their stream reader to the serialization harness. */
+public class GetDataFrameAnalyticsStatsRequestTests extends AbstractWireSerializingTestCase<Request> {
+
+    @Override
+    protected Request createTestInstance() {
+        Request request = new Request(randomAlphaOfLength(20));
+        request.setPageParams(new PageParams(randomIntBetween(0, 100), randomIntBetween(0, 100)));
+        return request;
+    }
+
+    @Override
+    protected Writeable.Reader<Request> instanceReader() {
+        return Request::new;
+    }
+}
diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/PutDataFrameAnalyticsActionRequestTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/PutDataFrameAnalyticsActionRequestTests.java
new file mode 100644
index 0000000000000..d00fa4384be8a
--- /dev/null
+++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/PutDataFrameAnalyticsActionRequestTests.java
@@ -0,0 +1,67 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License;
+ * you may not use this file except in compliance with the Elastic License.
+ */
+package org.elasticsearch.xpack.core.ml.action;
+
+import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
+import org.elasticsearch.common.settings.Settings;
+import org.elasticsearch.common.xcontent.NamedXContentRegistry;
+import org.elasticsearch.common.xcontent.XContentParser;
+import org.elasticsearch.search.SearchModule;
+import org.elasticsearch.test.AbstractStreamableXContentTestCase;
+import org.elasticsearch.xpack.core.ml.action.PutDataFrameAnalyticsAction.Request;
+import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfigTests;
+import org.elasticsearch.xpack.core.ml.dataframe.analyses.MlDataFrameAnalysisNamedXContentProvider;
+import org.junit.Before;
+
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.List;
+/** Serialization and parsing fixture for {@link Request}; a fresh random valid id is picked before each test by setUpId(). */
+public class PutDataFrameAnalyticsActionRequestTests extends AbstractStreamableXContentTestCase<Request> {
+
+    private String id; // shared by createTestInstance() and doParseInstance() so parsed and generated requests agree
+
+    @Before
+    public void setUpId() {
+        id = DataFrameAnalyticsConfigTests.randomValidId();
+    }
+
+    @Override
+    protected NamedWriteableRegistry getNamedWriteableRegistry() {
+        List<NamedWriteableRegistry.Entry> namedWriteables = new ArrayList<>();
+        namedWriteables.addAll(new MlDataFrameAnalysisNamedXContentProvider().getNamedWriteables());
+        namedWriteables.addAll(new SearchModule(Settings.EMPTY, false, Collections.emptyList()).getNamedWriteables());
+        return new NamedWriteableRegistry(namedWriteables);
+    }
+
+    @Override
+    protected NamedXContentRegistry xContentRegistry() {
+        List<NamedXContentRegistry.Entry> namedXContent = new ArrayList<>();
+        namedXContent.addAll(new MlDataFrameAnalysisNamedXContentProvider().getNamedXContentParsers());
+        namedXContent.addAll(new SearchModule(Settings.EMPTY, false, Collections.emptyList()).getNamedXContents());
+        return new NamedXContentRegistry(namedXContent);
+    }
+
+    @Override
+    protected Request createTestInstance() {
+        return new Request(DataFrameAnalyticsConfigTests.createRandom(id));
+    }
+
+    @Override
+    protected boolean supportsUnknownFields() {
+        return false;
+    }
+
+    @Override
+    protected Request createBlankInstance() {
+        return new Request();
+    }
+
+    @Override
+    protected Request doParseInstance(XContentParser parser) {
+        return Request.parseRequest(id, parser);
+    }
+}
diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/PutDataFrameAnalyticsActionResponseTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/PutDataFrameAnalyticsActionResponseTests.java
new file mode 100644
index 0000000000000..c9f678b13df2a
--- /dev/null
+++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/PutDataFrameAnalyticsActionResponseTests.java
@@ -0,0 +1,48 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License;
+ * you may not use this file except in compliance with the Elastic License.
+ */
+package org.elasticsearch.xpack.core.ml.action;
+
+import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
+import org.elasticsearch.common.settings.Settings;
+import org.elasticsearch.common.xcontent.NamedXContentRegistry;
+import org.elasticsearch.search.SearchModule;
+import org.elasticsearch.test.AbstractStreamableTestCase;
+import org.elasticsearch.xpack.core.ml.action.PutDataFrameAnalyticsAction.Response;
+import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfigTests;
+import org.elasticsearch.xpack.core.ml.dataframe.analyses.MlDataFrameAnalysisNamedXContentProvider;
+
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.List;
+/** Supplies {@link Response}s wrapping a random analytics config, plus the registries used to (de)serialize it. */
+public class PutDataFrameAnalyticsActionResponseTests extends AbstractStreamableTestCase<Response> {
+
+    @Override
+    protected NamedWriteableRegistry getNamedWriteableRegistry() {
+        List<NamedWriteableRegistry.Entry> namedWriteables = new ArrayList<>();
+        namedWriteables.addAll(new MlDataFrameAnalysisNamedXContentProvider().getNamedWriteables());
+        namedWriteables.addAll(new SearchModule(Settings.EMPTY, false, Collections.emptyList()).getNamedWriteables());
+        return new NamedWriteableRegistry(namedWriteables);
+    }
+
+    @Override
+    protected NamedXContentRegistry xContentRegistry() {
+        List<NamedXContentRegistry.Entry> namedXContent = new ArrayList<>();
+        namedXContent.addAll(new MlDataFrameAnalysisNamedXContentProvider().getNamedXContentParsers());
+        namedXContent.addAll(new SearchModule(Settings.EMPTY, false, Collections.emptyList()).getNamedXContents());
+        return new NamedXContentRegistry(namedXContent);
+    }
+
+    @Override
+    protected Response createTestInstance() {
+        return new Response(DataFrameAnalyticsConfigTests.createRandom(DataFrameAnalyticsConfigTests.randomValidId()));
+    }
+
+    @Override
+    protected Response createBlankInstance() {
+        return new Response();
+    }
+}
diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/StartDataFrameAnalyticsRequestTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/StartDataFrameAnalyticsRequestTests.java
new file mode 100644
index 0000000000000..a3db5833b820d
--- /dev/null
+++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/StartDataFrameAnalyticsRequestTests.java
@@ -0,0 +1,28 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License;
+ * you may not use this file except in compliance with the Elastic License.
+ */
+package org.elasticsearch.xpack.core.ml.action;
+
+import org.elasticsearch.common.io.stream.Writeable;
+import org.elasticsearch.common.unit.TimeValue;
+import org.elasticsearch.test.AbstractWireSerializingTestCase;
+import org.elasticsearch.xpack.core.ml.action.StartDataFrameAnalyticsAction.Request;
+/** Supplies random start {@link Request}s, with the timeout set only some of the time, to the serialization harness. */
+public class StartDataFrameAnalyticsRequestTests extends AbstractWireSerializingTestCase<Request> {
+
+    @Override
+    protected Request createTestInstance() {
+        Request request = new Request(randomAlphaOfLength(20));
+        if (randomBoolean()) {
+            request.setTimeout(TimeValue.timeValueMillis(randomNonNegativeLong()));
+        }
+        return request;
+    }
+
+    @Override
+    protected Writeable.Reader<Request> instanceReader() {
+        return Request::new;
+    }
+}
diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/StopDataFrameAnalyticsActionResponseTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/StopDataFrameAnalyticsActionResponseTests.java
new file mode 100644
index 0000000000000..d06d8cb1a1860
--- /dev/null
+++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/StopDataFrameAnalyticsActionResponseTests.java
@@ -0,0 +1,23 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License;
+ * you may not use this file except in compliance with the Elastic License.
+ */
+package org.elasticsearch.xpack.core.ml.action;
+
+import org.elasticsearch.common.io.stream.Writeable;
+import org.elasticsearch.test.AbstractWireSerializingTestCase;
+import org.elasticsearch.xpack.core.ml.action.StopDataFrameAnalyticsAction.Response;
+/** Supplies stop {@link Response}s built from a random boolean, and their stream reader, to the serialization harness. */
+public class StopDataFrameAnalyticsActionResponseTests extends AbstractWireSerializingTestCase<Response> {
+
+    @Override
+    protected Response createTestInstance() {
+        return new Response(randomBoolean());
+    }
+
+    @Override
+    protected Writeable.Reader<Response> instanceReader() {
+        return Response::new;
+    }
+}
diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/StopDataFrameAnalyticsRequestTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/StopDataFrameAnalyticsRequestTests.java
new file mode 100644
index 0000000000000..9c61164c5f02a
--- /dev/null
+++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/StopDataFrameAnalyticsRequestTests.java
@@ -0,0 +1,40 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License;
+ * you may not use this file except in compliance with the Elastic License.
+ */
+package org.elasticsearch.xpack.core.ml.action;
+
+import org.elasticsearch.common.io.stream.Writeable;
+import org.elasticsearch.common.unit.TimeValue;
+import org.elasticsearch.test.AbstractWireSerializingTestCase;
+import org.elasticsearch.xpack.core.ml.action.StopDataFrameAnalyticsAction.Request;
+
+import java.util.HashSet;
+import java.util.Set;
+/** Supplies random stop {@link Request}s (optional timeout and allow-no-match, up to ten expanded ids) to the harness. */
+public class StopDataFrameAnalyticsRequestTests extends AbstractWireSerializingTestCase<Request> {
+
+    @Override
+    protected Request createTestInstance() {
+        Request request = new Request(randomAlphaOfLength(20));
+        if (randomBoolean()) {
+            request.setTimeout(TimeValue.timeValueMillis(randomNonNegativeLong()));
+        }
+        if (randomBoolean()) {
+            request.setAllowNoMatch(randomBoolean());
+        }
+        int expandedIdsCount = randomIntBetween(0, 10);
+        Set<String> expandedIds = new HashSet<>();
+        for (int i = 0; i < expandedIdsCount; i++) {
+            expandedIds.add(randomAlphaOfLength(20));
+        }
+        request.setExpandedIds(expandedIds);
+        return request;
+    }
+
+    @Override
+    protected Writeable.Reader<Request> instanceReader() {
+        return Request::new;
+    }
+}
diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/datafeed/DatafeedConfigTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/datafeed/DatafeedConfigTests.java
index 6b664777a2d86..7afcc9799f770 100644
--- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/datafeed/DatafeedConfigTests.java
+++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/datafeed/DatafeedConfigTests.java
@@ -46,6 +46,7 @@
 import org.elasticsearch.test.ESTestCase;
 import org.elasticsearch.xpack.core.ml.datafeed.ChunkingConfig.Mode;
 import org.elasticsearch.xpack.core.ml.job.messages.Messages;
+import org.elasticsearch.xpack.core.ml.utils.QueryProvider;
 import org.elasticsearch.xpack.core.ml.utils.ToXContentParams;
 
 import java.io.IOException;
@@ -57,7 +58,7 @@
 import java.util.List;
 import java.util.Map;
 
-import static 
org.elasticsearch.xpack.core.ml.datafeed.QueryProviderTests.createRandomValidQueryProvider; +import static org.elasticsearch.xpack.core.ml.utils.QueryProviderTests.createRandomValidQueryProvider; import static org.elasticsearch.xpack.core.ml.job.messages.Messages.DATAFEED_AGGREGATIONS_INTERVAL_MUST_BE_GREATER_THAN_ZERO; import static org.hamcrest.Matchers.containsString; import static org.hamcrest.Matchers.equalTo; diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/datafeed/DatafeedUpdateTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/datafeed/DatafeedUpdateTests.java index 571c9e81a9068..969b4aef5ae9a 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/datafeed/DatafeedUpdateTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/datafeed/DatafeedUpdateTests.java @@ -38,6 +38,7 @@ import org.elasticsearch.test.AbstractSerializingTestCase; import org.elasticsearch.xpack.core.ml.datafeed.ChunkingConfig.Mode; import org.elasticsearch.xpack.core.ml.job.config.JobTests; +import org.elasticsearch.xpack.core.ml.utils.QueryProvider; import org.elasticsearch.xpack.core.ml.utils.XContentObjectTransformer; import java.io.IOException; @@ -47,7 +48,7 @@ import java.util.List; import static org.elasticsearch.xpack.core.ml.datafeed.AggProviderTests.createRandomValidAggProvider; -import static org.elasticsearch.xpack.core.ml.datafeed.QueryProviderTests.createRandomValidQueryProvider; +import static org.elasticsearch.xpack.core.ml.utils.QueryProviderTests.createRandomValidQueryProvider; import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.is; import static org.hamcrest.Matchers.not; diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/DataFrameAnalyticsConfigTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/DataFrameAnalyticsConfigTests.java new file mode 100644 index 
0000000000000..a5df1f83c3d37
--- /dev/null
+++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/DataFrameAnalyticsConfigTests.java
@@ -0,0 +1,251 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License;
+ * you may not use this file except in compliance with the Elastic License.
+ */
+package org.elasticsearch.xpack.core.ml.dataframe;
+
+import com.carrotsearch.randomizedtesting.generators.CodepointSetGenerator;
+import org.elasticsearch.ElasticsearchException;
+import org.elasticsearch.ElasticsearchStatusException;
+import org.elasticsearch.common.bytes.BytesReference;
+import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
+import org.elasticsearch.common.io.stream.Writeable;
+import org.elasticsearch.common.settings.Settings;
+import org.elasticsearch.common.unit.ByteSizeUnit;
+import org.elasticsearch.common.unit.ByteSizeValue;
+import org.elasticsearch.common.xcontent.DeprecationHandler;
+import org.elasticsearch.common.xcontent.LoggingDeprecationHandler;
+import org.elasticsearch.common.xcontent.NamedXContentRegistry;
+import org.elasticsearch.common.xcontent.ToXContent;
+import org.elasticsearch.common.xcontent.XContentFactory;
+import org.elasticsearch.common.xcontent.XContentHelper;
+import org.elasticsearch.common.xcontent.XContentParseException;
+import org.elasticsearch.common.xcontent.XContentParser;
+import org.elasticsearch.common.xcontent.XContentType;
+import org.elasticsearch.index.query.MatchAllQueryBuilder;
+import org.elasticsearch.search.SearchModule;
+import org.elasticsearch.search.fetch.subphase.FetchSourceContext;
+import org.elasticsearch.test.AbstractSerializingTestCase;
+import org.elasticsearch.xpack.core.ml.dataframe.analyses.MlDataFrameAnalysisNamedXContentProvider;
+import org.elasticsearch.xpack.core.ml.dataframe.analyses.OutlierDetectionTests;
+import org.elasticsearch.xpack.core.ml.utils.ToXContentParams;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+import static org.hamcrest.CoreMatchers.equalTo;
+import static org.hamcrest.CoreMatchers.is;
+import static org.hamcrest.Matchers.containsString;
+import static org.hamcrest.Matchers.hasEntry;
+import static org.hamcrest.Matchers.hasSize;
+import static org.hamcrest.Matchers.startsWith;
+/** Serialization, parsing and model-memory-limit tests for {@link DataFrameAnalyticsConfig}; also hosts random-config factories. */
+public class DataFrameAnalyticsConfigTests extends AbstractSerializingTestCase<DataFrameAnalyticsConfig> {
+
+    @Override
+    protected DataFrameAnalyticsConfig doParseInstance(XContentParser parser) throws IOException {
+        return DataFrameAnalyticsConfig.STRICT_PARSER.apply(parser, null).build();
+    }
+
+    @Override
+    protected NamedWriteableRegistry getNamedWriteableRegistry() {
+        List<NamedWriteableRegistry.Entry> namedWriteables = new ArrayList<>();
+        namedWriteables.addAll(new MlDataFrameAnalysisNamedXContentProvider().getNamedWriteables());
+        namedWriteables.addAll(new SearchModule(Settings.EMPTY, false, Collections.emptyList()).getNamedWriteables());
+        return new NamedWriteableRegistry(namedWriteables);
+    }
+
+    @Override
+    protected NamedXContentRegistry xContentRegistry() {
+        List<NamedXContentRegistry.Entry> namedXContent = new ArrayList<>();
+        namedXContent.addAll(new MlDataFrameAnalysisNamedXContentProvider().getNamedXContentParsers());
+        namedXContent.addAll(new SearchModule(Settings.EMPTY, false, Collections.emptyList()).getNamedXContents());
+        return new NamedXContentRegistry(namedXContent);
+    }
+
+    @Override
+    protected DataFrameAnalyticsConfig createTestInstance() {
+        return createRandom(randomValidId());
+    }
+
+    @Override
+    protected Writeable.Reader<DataFrameAnalyticsConfig> instanceReader() {
+        return DataFrameAnalyticsConfig::new;
+    }
+
+    public static DataFrameAnalyticsConfig createRandom(String id) {
+        return createRandomBuilder(id).build();
+    }
+
+    public static DataFrameAnalyticsConfig.Builder createRandomBuilder(String id) {
+        DataFrameAnalyticsSource source = DataFrameAnalyticsSourceTests.createRandom();
+        DataFrameAnalyticsDest dest = DataFrameAnalyticsDestTests.createRandom();
+        DataFrameAnalyticsConfig.Builder builder = new DataFrameAnalyticsConfig.Builder()
+            .setId(id)
+            .setAnalysis(OutlierDetectionTests.createRandom())
+            .setSource(source)
+            .setDest(dest);
+        if (randomBoolean()) {
+            builder.setAnalyzedFields(new FetchSourceContext(true,
+                generateRandomStringArray(10, 10, false, false),
+                generateRandomStringArray(10, 10, false, false)));
+        }
+        if (randomBoolean()) {
+            builder.setModelMemoryLimit(new ByteSizeValue(randomIntBetween(1, 16), randomFrom(ByteSizeUnit.MB, ByteSizeUnit.GB)));
+        }
+        return builder;
+    }
+
+    public static String randomValidId() {
+        CodepointSetGenerator generator = new CodepointSetGenerator("abcdefghijklmnopqrstuvwxyz".toCharArray());
+        return generator.ofCodePointsLength(random(), 10, 10);
+    }
+
+    private static final String ANACHRONISTIC_QUERY_DATA_FRAME_ANALYTICS = "{\n" +
+        " \"id\": \"old-data-frame\",\n" +
+        //query:match:type stopped being supported in 6.x
+        " \"source\": {\"index\":\"my-index\", \"query\": {\"match\" : {\"query\":\"fieldName\", \"type\": \"phrase\"}}},\n" +
+        " \"dest\": {\"index\":\"dest-index\"},\n" +
+        " \"analysis\": {\"outlier_detection\": {\"n_neighbors\": 10}}\n" +
+        "}";
+
+    private static final String MODERN_QUERY_DATA_FRAME_ANALYTICS = "{\n" +
+        " \"id\": \"data-frame\",\n" +
+        // match_all if parsed, adds default values in the options
+        " \"source\": {\"index\":\"my-index\", \"query\": {\"match_all\" : {}}},\n" +
+        " \"dest\": {\"index\":\"dest-index\"},\n" +
+        " \"analysis\": {\"outlier_detection\": {\"n_neighbors\": 10}}\n" +
+        "}";
+
+    public void testQueryConfigStoresUserInputOnly() throws IOException {
+        try (XContentParser parser = XContentFactory.xContent(XContentType.JSON)
+            .createParser(xContentRegistry(),
+                DeprecationHandler.THROW_UNSUPPORTED_OPERATION,
+                MODERN_QUERY_DATA_FRAME_ANALYTICS)) {
+
+            DataFrameAnalyticsConfig config = DataFrameAnalyticsConfig.LENIENT_PARSER.apply(parser, null).build();
+            assertThat(config.getSource().getQuery(), equalTo(Collections.singletonMap(MatchAllQueryBuilder.NAME, Collections.emptyMap())));
+        }
+
+        try (XContentParser parser = XContentFactory.xContent(XContentType.JSON)
+            .createParser(xContentRegistry(),
+                DeprecationHandler.THROW_UNSUPPORTED_OPERATION,
+                MODERN_QUERY_DATA_FRAME_ANALYTICS)) {
+
+            DataFrameAnalyticsConfig config = DataFrameAnalyticsConfig.STRICT_PARSER.apply(parser, null).build();
+            assertThat(config.getSource().getQuery(), equalTo(Collections.singletonMap(MatchAllQueryBuilder.NAME, Collections.emptyMap())));
+        }
+    }
+
+    public void testPastQueryConfigParse() throws IOException {
+        try (XContentParser parser = XContentFactory.xContent(XContentType.JSON)
+            .createParser(xContentRegistry(),
+                DeprecationHandler.THROW_UNSUPPORTED_OPERATION,
+                ANACHRONISTIC_QUERY_DATA_FRAME_ANALYTICS)) {
+
+            DataFrameAnalyticsConfig config = DataFrameAnalyticsConfig.LENIENT_PARSER.apply(parser, null).build();
+            ElasticsearchException e = expectThrows(ElasticsearchException.class, () -> config.getSource().getParsedQuery());
+            assertEquals("[match] query doesn't support multiple fields, found [query] and [type]", e.getMessage());
+        }
+
+        try (XContentParser parser = XContentFactory.xContent(XContentType.JSON)
+            .createParser(xContentRegistry(),
+                DeprecationHandler.THROW_UNSUPPORTED_OPERATION,
+                ANACHRONISTIC_QUERY_DATA_FRAME_ANALYTICS)) {
+
+            XContentParseException e = expectThrows(XContentParseException.class,
+                () -> DataFrameAnalyticsConfig.STRICT_PARSER.apply(parser, null).build());
+            assertThat(e.getMessage(), containsString("[data_frame_analytics_config] failed to parse field [source]"));
+        }
+    }
+
+    public void testToXContentForInternalStorage() throws IOException {
+        DataFrameAnalyticsConfig.Builder builder = createRandomBuilder("foo");
+
+        // headers are only persisted to cluster state
+        Map<String, String> headers = new HashMap<>();
+        headers.put("header-name", "header-value");
+        builder.setHeaders(headers);
+        DataFrameAnalyticsConfig config = builder.build();
+
+        ToXContent.MapParams params = new ToXContent.MapParams(Collections.singletonMap(ToXContentParams.FOR_INTERNAL_STORAGE, "true"));
+
+        BytesReference forClusterstateXContent = XContentHelper.toXContent(config, XContentType.JSON, params, false);
+        try (XContentParser parser = XContentFactory.xContent(XContentType.JSON)
+            .createParser(xContentRegistry(), LoggingDeprecationHandler.INSTANCE, forClusterstateXContent.streamInput())) {
+            DataFrameAnalyticsConfig parsedConfig = DataFrameAnalyticsConfig.LENIENT_PARSER.apply(parser, null).build();
+            assertThat(parsedConfig.getHeaders(), hasEntry("header-name", "header-value"));
+        }
+
+        // headers are not written without the FOR_INTERNAL_STORAGE param
+        BytesReference nonClusterstateXContent = XContentHelper.toXContent(config, XContentType.JSON, ToXContent.EMPTY_PARAMS, false);
+        try (XContentParser parser = XContentFactory.xContent(XContentType.JSON)
+            .createParser(xContentRegistry(), LoggingDeprecationHandler.INSTANCE, nonClusterstateXContent.streamInput())) {
+            DataFrameAnalyticsConfig parsedConfig = DataFrameAnalyticsConfig.LENIENT_PARSER.apply(parser, null).build();
+            assertThat(parsedConfig.getHeaders().entrySet(), hasSize(0));
+        }
+    }
+
+    public void testInvalidModelMemoryLimits() {
+
+        DataFrameAnalyticsConfig.Builder builder = new DataFrameAnalyticsConfig.Builder();
+
+        // All these are different ways of specifying a limit that is lower than the minimum
+        assertTooSmall(expectThrows(IllegalArgumentException.class,
+            () -> builder.setModelMemoryLimit(new ByteSizeValue(1048575, ByteSizeUnit.BYTES))));
+        assertTooSmall(expectThrows(IllegalArgumentException.class,
+            () -> builder.setModelMemoryLimit(new ByteSizeValue(0, ByteSizeUnit.BYTES))));
+        assertTooSmall(expectThrows(IllegalArgumentException.class,
+            () -> builder.setModelMemoryLimit(new ByteSizeValue(-1, ByteSizeUnit.BYTES))));
+        assertTooSmall(expectThrows(IllegalArgumentException.class,
+            () -> builder.setModelMemoryLimit(new ByteSizeValue(1023, ByteSizeUnit.KB))));
+        assertTooSmall(expectThrows(IllegalArgumentException.class,
+            () -> builder.setModelMemoryLimit(new ByteSizeValue(0, ByteSizeUnit.KB))));
+        assertTooSmall(expectThrows(IllegalArgumentException.class,
+            () -> builder.setModelMemoryLimit(new ByteSizeValue(0, ByteSizeUnit.MB))));
+    }
+
+    public void testNoMemoryCapping() {
+        // Rebuilding with a null or zero max limit is expected to leave the model memory limit unchanged
+        DataFrameAnalyticsConfig uncapped = createRandom("foo");
+
+        ByteSizeValue unlimited = randomBoolean() ? null : ByteSizeValue.ZERO;
+        assertThat(uncapped.getModelMemoryLimit(),
+            equalTo(new DataFrameAnalyticsConfig.Builder(uncapped, unlimited).build().getModelMemoryLimit()));
+    }
+
+    public void testMemoryCapping() {
+        // With no explicit limit, the rebuilt config is expected to carry the smaller of the default limit and the max limit
+        DataFrameAnalyticsConfig defaultLimitConfig = createRandomBuilder("foo").setModelMemoryLimit(null).build();
+
+        ByteSizeValue maxLimit = new ByteSizeValue(randomIntBetween(500, 1000), ByteSizeUnit.MB);
+        if (maxLimit.compareTo(defaultLimitConfig.getModelMemoryLimit()) < 0) {
+            assertThat(maxLimit,
+                equalTo(new DataFrameAnalyticsConfig.Builder(defaultLimitConfig, maxLimit).build().getModelMemoryLimit()));
+        } else {
+            assertThat(defaultLimitConfig.getModelMemoryLimit(),
+                equalTo(new DataFrameAnalyticsConfig.Builder(defaultLimitConfig, maxLimit).build().getModelMemoryLimit()));
+        }
+    }
+
+    public void testExplicitModelMemoryLimitTooHigh() {
+        // An explicitly configured limit above the max is expected to be rejected rather than silently capped
+        ByteSizeValue configuredLimit = new ByteSizeValue(randomIntBetween(5, 10), ByteSizeUnit.GB);
+        DataFrameAnalyticsConfig explicitLimitConfig = createRandomBuilder("foo").setModelMemoryLimit(configuredLimit).build();
+
+        ByteSizeValue maxLimit = new ByteSizeValue(randomIntBetween(500, 1000), ByteSizeUnit.MB);
+        ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class,
+            () -> new DataFrameAnalyticsConfig.Builder(explicitLimitConfig, maxLimit).build());
+        assertThat(e.getMessage(), startsWith("model_memory_limit"));
+        assertThat(e.getMessage(), containsString("must be less than the value of the xpack.ml.max_model_memory_limit setting"));
+    }
+
+    public void assertTooSmall(IllegalArgumentException e) {
+        assertThat(e.getMessage(), is("[model_memory_limit] must be at least [1mb]"));
+    }
+}
diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/DataFrameAnalyticsDestTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/DataFrameAnalyticsDestTests.java
new file mode 100644
index 0000000000000..bf8ce4c8a99b0
--- /dev/null
+++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/DataFrameAnalyticsDestTests.java
@@ -0,0 +1,55 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License;
+ * you may not use this file except in compliance with the Elastic License.
+ */
+package org.elasticsearch.xpack.core.ml.dataframe;
+
+import org.elasticsearch.common.io.stream.Writeable;
+import org.elasticsearch.common.xcontent.XContentParser;
+import org.elasticsearch.indices.InvalidIndexNameException;
+import org.elasticsearch.rest.RestStatus;
+import org.elasticsearch.test.AbstractSerializingTestCase;
+
+import java.io.IOException;
+
+import static org.hamcrest.Matchers.equalTo;
+
+public class DataFrameAnalyticsDestTests extends AbstractSerializingTestCase {
+
+    @Override
+    protected DataFrameAnalyticsDest doParseInstance(XContentParser parser) throws IOException {
+        return DataFrameAnalyticsDest.createParser(false).apply(parser, null);
+    }
+
+    @Override
+    protected DataFrameAnalyticsDest createTestInstance() {
+        return createRandom();
+    }
+
+    public static DataFrameAnalyticsDest createRandom() {
+        String index = randomAlphaOfLength(10);
+        String resultsField = randomBoolean() ? 
null : randomAlphaOfLength(10); + return new DataFrameAnalyticsDest(index, resultsField); + } + + @Override + protected Writeable.Reader instanceReader() { + return DataFrameAnalyticsDest::new; + } + + public void testValidate_GivenIndexWithFunkyChars() { + expectThrows(InvalidIndexNameException.class, () -> new DataFrameAnalyticsDest("