From 11e5dada8324eaad44cf9673478a102abc2997f2 Mon Sep 17 00:00:00 2001 From: Benjamin Trent <4357155+benwtrent@users.noreply.github.com> Date: Fri, 2 Apr 2021 15:48:57 -0400 Subject: [PATCH 01/20] [ML] Log categorization multi-bucket agg --- docs/build.gradle | 34 ++ docs/reference/aggregations/bucket.asciidoc | 2 + .../categorize-text-aggregation.asciidoc | 396 +++++++++++++++++ .../elasticsearch/search/SearchService.java | 1 + .../ParsedMultiBucketAggregation.java | 2 +- .../support/AggregationContext.java | 13 + .../aggregations/AggregatorTestCase.java | 20 + .../test/AbstractBuilderTestCase.java | 10 +- .../ml/qa/ml-with-security/build.gradle | 3 + .../CategorizationAggregationIT.java | 160 +++++++ .../xpack/ml/MachineLearning.java | 15 + .../CategorizationTokenTree.java | 161 +++++++ .../CategorizeTextAggregationBuilder.java | 275 ++++++++++++ .../CategorizeTextAggregator.java | 208 +++++++++ .../CategorizeTextAggregatorFactory.java | 90 ++++ .../InternalCategorizationAggregation.java | 415 ++++++++++++++++++ .../ml/aggs/categorization/LogGroup.java | 105 +++++ .../ml/aggs/categorization/TreeNode.java | 398 +++++++++++++++++ .../aggs/categorization/TreeNodeFactory.java | 16 + .../CategorizationAnalyzer.java | 5 + .../xpack/ml/LocalStateMachineLearning.java | 18 + ...CategorizeTextAggregationBuilderTests.java | 55 +++ .../CategorizeTextAggregatorTests.java | 297 +++++++++++++ .../categorization/InnerTreeNodeTests.java | 109 +++++ ...nternalCategorizationAggregationTests.java | 117 +++++ .../categorization/LeafTreeNodeTests.java | 68 +++ .../ml/aggs/categorization/LogGroupTests.java | 51 +++ .../categorization/ParsedCategorization.java | 113 +++++ .../test/ml/categorization_agg.yml | 136 ++++++ 29 files changed, 3290 insertions(+), 3 deletions(-) create mode 100644 docs/reference/aggregations/bucket/categorize-text-aggregation.asciidoc create mode 100644 x-pack/plugin/ml/src/internalClusterTest/java/org/elasticsearch/xpack/ml/integration/CategorizationAggregationIT.java create mode 100644 x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/CategorizationTokenTree.java create mode 100644 x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/CategorizeTextAggregationBuilder.java create mode 100644 x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/CategorizeTextAggregator.java create mode 100644 x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/CategorizeTextAggregatorFactory.java create mode 100644 x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/InternalCategorizationAggregation.java create mode 100644 x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/LogGroup.java create mode 100644 x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/TreeNode.java create mode 100644 x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/TreeNodeFactory.java create mode 100644 x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/aggs/categorization/CategorizeTextAggregationBuilderTests.java create mode 100644 x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/aggs/categorization/CategorizeTextAggregatorTests.java create mode 100644 x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/aggs/categorization/InnerTreeNodeTests.java create mode 100644 x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/aggs/categorization/InternalCategorizationAggregationTests.java create mode 100644 x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/aggs/categorization/LeafTreeNodeTests.java create mode 100644 x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/aggs/categorization/LogGroupTests.java create mode 100644 x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/aggs/categorization/ParsedCategorization.java create mode 100644 x-pack/plugin/src/yamlRestTest/resources/rest-api-spec/test/ml/categorization_agg.yml diff --git a/docs/build.gradle b/docs/build.gradle index ec7168061815e..c9e55a651dfa3 100644 --- a/docs/build.gradle +++ b/docs/build.gradle @@ -1071,6 +1071,40 @@ buildRestTests.setups['farequote_datafeed'] = buildRestTests.setups['farequote_j "indexes":"farequote" } ''' +buildRestTests.setups['categorize_text'] = ''' + - do: + indices.create: + index: log-messages + body: + settings: + number_of_shards: 1 + number_of_replicas: 0 + mappings: + metric: + properties: + time: + type: date + message: + type: text + + - do: + bulk: + index: log-messages + refresh: true + body: | + {"index": {}} + {"time":"2016-02-07T00:00:00+0000", "message": "2016-02-07T00:00:00+0000 Node 3 shutting down"} + {"index": {}} + {"time":"2016-02-07T00:00:00+0000", "message": "2016-02-07T00:00:00+0000 Node 5 starting up"} + {"index": {}} + {"time":"2016-02-07T00:00:00+0000", "message": "2016-02-07T00:00:00+0000 Node 4 shutting down"} + {"index": {}} + {"time":"2016-02-08T00:00:00+0000", "message": "2016-02-08T00:00:00+0000 Node 5 shutting down"} + {"index": {}} + {"time":"2016-02-08T00:00:00+0000", "message": "2016-02-08T00:00:00+0000 User foo_325 logging on"} + {"index": {}} + {"time":"2016-02-08T00:00:00+0000", "message": "2016-02-08T00:00:00+0000 User foo_864 logged off"} +''' buildRestTests.setups['server_metrics_index'] = ''' - do: indices.create: diff --git a/docs/reference/aggregations/bucket.asciidoc b/docs/reference/aggregations/bucket.asciidoc index 302e196caf3ce..dfdaca18e6cfb 100644 --- a/docs/reference/aggregations/bucket.asciidoc +++ b/docs/reference/aggregations/bucket.asciidoc @@ -20,6 +20,8 @@ include::bucket/adjacency-matrix-aggregation.asciidoc[] include::bucket/autodatehistogram-aggregation.asciidoc[] +include::bucket/categorize-text-aggregation.asciidoc[] + include::bucket/children-aggregation.asciidoc[] include::bucket/composite-aggregation.asciidoc[] diff --git a/docs/reference/aggregations/bucket/categorize-text-aggregation.asciidoc b/docs/reference/aggregations/bucket/categorize-text-aggregation.asciidoc new file mode 100644 index 0000000000000..5aa486b9dcf40 --- /dev/null +++ b/docs/reference/aggregations/bucket/categorize-text-aggregation.asciidoc @@ -0,0 +1,396 @@ +[[search-aggregations-bucket-categorize-text-aggregation]] +=== Categorize text aggregation +++++ +Categorize text +++++ + +experimental::[] + +A multi-bucket aggregation that groups semi-structured text into buckets. Each `text` field is re-analyzed +using a custom analyzer. The resulting tokens are then categorized creating buckets of similarly formatted +text values. This aggregation works best with machine generated text like system logs. + +WARNING: Re-analyzing _large_ result sets will require a lot of time and memory. This aggregation should be + used in conjunction with <>. + +[[bucket-categorize-text-agg-syntax]] +==== Parameters + +`field`:: +(Required, string) +The semi-structured text field to categorize. + +`max_children`:: +(Optional, integer, default: `100`) +The maximum number of unique tokens at any given layer of the tokenization tree. +Must be larger than 1. The smaller the value, the more broad the text categories. +Larger values may cause the aggregation to more memory and run more slowly + +`max_depth`:: +(Optional, integer, default: `5`) +The maximum number of tokens matched on before attempting to merge categories. +Larger values may cause the aggregation to more memory and run more slowly. + +`similarity_threshold`:: +(Optional, double, default: `0.5`) +The minimum percentage of tokens that must match for text to be added to the +category bucket. +Must be between 0.1 and 1.0. The larger the value the more restrictive the log categories. +Larger values may increase memory usage. + +`categorization_filters`:: +(Optional, array of strings, default: `[]`) +This property expects an array of regular expressions. The expressions +are used to filter out matching sequences from the categorization field values. +You can use this functionality to fine tune the categorization by excluding +sequences from consideration when categories are defined. For example, you can +exclude SQL statements that appear in your log files. + +`shard_size`:: +(Optional, integer) +The number of categorization buckets to return from each shard before merging +all the results. + +`size`:: +(Optional, integer, default: `10`) +The number of buckets to return. + +`min_doc_count`:: +(Optional, integer) +The minimum number of documents for a bucket to be returned to the results. + +`shard_min_doc_count`:: +(Optional, integer) +The minimum number of documents for a bucket to be returned from the shard before +merging. + +==== Basic use + +Example: + +[source,console,id=categorize-text-aggregation-example] +-------------------------------------------------- +POST log-messages/_search?filter_path=aggregations +{ + "aggs": { + "categories": { + "categorize_text": { + "field": "message" + } + } + } +} +-------------------------------------------------- +// TEST[setup:categorize_text] + +Response: + +[source,console-result] +-------------------------------------------------- +{ + "aggregations" : { + "categories" : { + "buckets" : [ + { + "doc_count" : 3, + "key" : "Node shutting down" + }, + { + "doc_count" : 1, + "key" : "User foo_864 logged off" + }, + { + "doc_count" : 1, + "key" : "User foo_325 logging on" + }, + { + "doc_count" : 1, + "key" : "Node starting up" + } + ] + } + } +} +-------------------------------------------------- +// TESTRESPONSE + + +Here is an example using `categorization_filters` + +[source,console,id=categorize-text-aggregation-with-filters-example] +-------------------------------------------------- +POST log-messages/_search?filter_path=aggregations +{ + "aggs": { + "categories": { + "categorize_text": { + "field": "message", + "categorization_filters": ["\\w+\\_\\d{3}"] <1> + } + } + } +} +-------------------------------------------------- +// TEST[setup:categorize_text] +<1> The filters to apply to the analyzed tokens. It filters + out tokens like `bar_123`. +Note how the `foo_` tokens are not part of the +category results + +[source,console-result] +-------------------------------------------------- +{ + "aggregations" : { + "categories" : { + "buckets" : [ + { + "doc_count" : 3, + "key" : "Node shutting down" + }, + { + "doc_count" : 1, + "key" : "User logged off" + }, + { + "doc_count" : 1, + "key" : "User logging on" + }, + { + "doc_count" : 1, + "key" : "Node starting up" + } + ] + } + } +} +-------------------------------------------------- +// TESTRESPONSE + +Here is an example using `categorization_filters` + +[source,console,id=categorize-text-aggregation-with-broad-categories-example] +-------------------------------------------------- +POST log-messages/_search?filter_path=aggregations +{ + "aggs": { + "categories": { + "categorize_text": { + "field": "message", + "categorization_filters": ["\\w+\\_\\d{3}"], <1> + "max_depth": 2, <2> + "similarity_threshold": 0.3 <3> + } + } + } +} +-------------------------------------------------- +// TEST[setup:categorize_text] +<1> The filters to apply to the analyzed tokens. It filters +out tokens like `bar_123`. +<2> Only the token tree to have 2 tokens before the log categories + attempt to merge together +<3> Require 30% of the tokens to match before expanding a log categories + to add a new log entry + +The resulting categories are now broad, matching the first token +and merging the log groups. + +[source,console-result] +-------------------------------------------------- +{ + "aggregations" : { + "categories" : { + "buckets" : [ + { + "doc_count" : 4, + "key" : "Node *" + }, + { + "doc_count" : 2, + "key" : "User *" + } + ] + } + } +} +-------------------------------------------------- +// TESTRESPONSE + +This aggregation can have both sub-aggregations and itself be a sub-aggregation. + +[source,console,id=categorize-text-aggregation-with-broad-categories-sub-aggs-example] +-------------------------------------------------- +POST log-messages/_search?filter_path=aggregations +{ + "aggs": { + "daily": { + "date_histogram": { + "field": "time", + "fixed_interval": "1d" + }, + "aggs": { + "categories": { + "categorize_text": { + "field": "message", + "categorization_filters": ["\\w+\\_\\d{3}"] + }, + "aggs": { + "hit": { + "top_hits": { + "size": 1, + "_source": "message" + } + } + } + } + } + } + } +} +-------------------------------------------------- +[source,console-result] +-------------------------------------------------- +{ + "aggregations" : { + "daily" : { + "buckets" : [ + { + "key_as_string" : "2016-02-07T00:00:00.000Z", + "key" : 1454803200000, + "doc_count" : 3, + "categories" : { + "buckets" : [ + { + "doc_count" : 2, + "key" : "Node shutting down", + "hit" : { + "hits" : { + "total" : { + "value" : 2, + "relation" : "eq" + }, + "max_score" : 1.0, + "hits" : [ + { + "_index" : "log-messages", + "_id" : "DU9q4HsBtGA51sVjTrac", + "_score" : 1.0, + "_source" : { + "message" : "2016-02-07T00:00:00+0000 Node 3 shutting down" + } + } + ] + } + } + }, + { + "doc_count" : 1, + "key" : "Node starting up", + "hit" : { + "hits" : { + "total" : { + "value" : 1, + "relation" : "eq" + }, + "max_score" : 1.0, + "hits" : [ + { + "_index" : "log-messages", + "_id" : "Dk9q4HsBtGA51sVjTrac", + "_score" : 1.0, + "_source" : { + "message" : "2016-02-07T00:00:00+0000 Node 5 starting up" + } + } + ] + } + } + } + ] + } + }, + { + "key_as_string" : "2016-02-08T00:00:00.000Z", + "key" : 1454889600000, + "doc_count" : 3, + "categories" : { + "buckets" : [ + { + "doc_count" : 1, + "key" : "User logged off", + "hit" : { + "hits" : { + "total" : { + "value" : 1, + "relation" : "eq" + }, + "max_score" : 1.0, + "hits" : [ + { + "_index" : "log-messages", + "_id" : "Ek9q4HsBtGA51sVjTrac", + "_score" : 1.0, + "_source" : { + "message" : "2016-02-08T00:00:00+0000 User foo_864 logged off" + } + } + ] + } + } + }, + { + "doc_count" : 1, + "key" : "User logging on", + "hit" : { + "hits" : { + "total" : { + "value" : 1, + "relation" : "eq" + }, + "max_score" : 1.0, + "hits" : [ + { + "_index" : "log-messages", + "_id" : "EU9q4HsBtGA51sVjTrac", + "_score" : 1.0, + "_source" : { + "message" : "2016-02-08T00:00:00+0000 User foo_325 logging on" + } + } + ] + } + } + }, + { + "doc_count" : 1, + "key" : "Node shutting down", + "hit" : { + "hits" : { + "total" : { + "value" : 1, + "relation" : "eq" + }, + "max_score" : 1.0, + "hits" : [ + { + "_index" : "log-messages", + "_id" : "EE9q4HsBtGA51sVjTrac", + "_score" : 1.0, + "_source" : { + "message" : "2016-02-08T00:00:00+0000 Node 5 shutting down" + } + } + ] + } + } + } + ] + } + } + ] + } + } +} +-------------------------------------------------- +// TESTRESPONSE + diff --git a/server/src/main/java/org/elasticsearch/search/SearchService.java b/server/src/main/java/org/elasticsearch/search/SearchService.java index 6b3c456ef18bc..46a23468505a6 100644 --- a/server/src/main/java/org/elasticsearch/search/SearchService.java +++ b/server/src/main/java/org/elasticsearch/search/SearchService.java @@ -985,6 +985,7 @@ private void parseSource(DefaultSearchContext context, SearchSourceBuilder sourc context.terminateAfter(source.terminateAfter()); if (source.aggregations() != null && includeAggregations) { AggregationContext aggContext = new ProductionAggregationContext( + indicesService.getAnalysis(), context.getSearchExecutionContext(), bigArrays, source.aggregations().bytesToPreallocate(), diff --git a/server/src/main/java/org/elasticsearch/search/aggregations/ParsedMultiBucketAggregation.java b/server/src/main/java/org/elasticsearch/search/aggregations/ParsedMultiBucketAggregation.java index 76ca0a917fb5d..48bd678ce5f80 100644 --- a/server/src/main/java/org/elasticsearch/search/aggregations/ParsedMultiBucketAggregation.java +++ b/server/src/main/java/org/elasticsearch/search/aggregations/ParsedMultiBucketAggregation.java @@ -48,7 +48,7 @@ protected XContentBuilder doXContentBody(XContentBuilder builder, Params params) return builder; } - protected static , T extends ParsedBucket> void declareMultiBucketAggregationFields( + public static , T extends ParsedBucket> void declareMultiBucketAggregationFields( final ObjectParser objectParser, final CheckedFunction bucketParser, final CheckedFunction keyedBucketParser diff --git a/server/src/main/java/org/elasticsearch/search/aggregations/support/AggregationContext.java b/server/src/main/java/org/elasticsearch/search/aggregations/support/AggregationContext.java index 5008fbe08eac4..ad6366d75ce64 100644 --- a/server/src/main/java/org/elasticsearch/search/aggregations/support/AggregationContext.java +++ b/server/src/main/java/org/elasticsearch/search/aggregations/support/AggregationContext.java @@ -18,6 +18,7 @@ import org.elasticsearch.core.Releasable; import org.elasticsearch.core.Releasables; import org.elasticsearch.index.IndexSettings; +import org.elasticsearch.index.analysis.AnalysisRegistry; import org.elasticsearch.index.analysis.NamedAnalyzer; import org.elasticsearch.index.cache.bitset.BitsetFilterCache; import org.elasticsearch.index.fielddata.IndexFieldData; @@ -95,6 +96,10 @@ public final FieldContext buildFieldContext(String field) { return new FieldContext(field, buildFieldData(ft), ft); } + public AnalysisRegistry getAnalysisRegistry() { + return null; + } + /** * Lookup the context for an already resolved field type. */ @@ -277,10 +282,12 @@ public static class ProductionAggregationContext extends AggregationContext { private final Supplier isCancelled; private final Function filterQuery; private final boolean enableRewriteToFilterByFilter; + private final AnalysisRegistry analysisRegistry; private final List releaseMe = new ArrayList<>(); public ProductionAggregationContext( + AnalysisRegistry analysisRegistry, SearchExecutionContext context, BigArrays bigArrays, long bytesToPreallocate, @@ -295,6 +302,7 @@ public ProductionAggregationContext( Function filterQuery, boolean enableRewriteToFilterByFilter ) { + this.analysisRegistry = analysisRegistry; this.context = context; if (bytesToPreallocate == 0) { /* @@ -327,6 +335,11 @@ public ProductionAggregationContext( this.enableRewriteToFilterByFilter = enableRewriteToFilterByFilter; } + @Override + public AnalysisRegistry getAnalysisRegistry() { + return this.analysisRegistry; + } + @Override public Query query() { return topLevelQuery.get(); diff --git a/test/framework/src/main/java/org/elasticsearch/search/aggregations/AggregatorTestCase.java b/test/framework/src/main/java/org/elasticsearch/search/aggregations/AggregatorTestCase.java index 0a0fe066677a4..906fc8b50366d 100644 --- a/test/framework/src/main/java/org/elasticsearch/search/aggregations/AggregatorTestCase.java +++ b/test/framework/src/main/java/org/elasticsearch/search/aggregations/AggregatorTestCase.java @@ -56,6 +56,8 @@ import org.elasticsearch.core.CheckedConsumer; import org.elasticsearch.core.Releasable; import org.elasticsearch.core.Releasables; +import org.elasticsearch.env.Environment; +import org.elasticsearch.env.TestEnvironment; import org.elasticsearch.index.Index; import org.elasticsearch.index.IndexSettings; import org.elasticsearch.index.analysis.AnalysisRegistry; @@ -94,8 +96,10 @@ import org.elasticsearch.index.shard.IndexShard; import org.elasticsearch.index.shard.ShardId; import org.elasticsearch.indices.IndicesModule; +import org.elasticsearch.indices.analysis.AnalysisModule; import org.elasticsearch.indices.breaker.CircuitBreakerService; import org.elasticsearch.indices.breaker.NoneCircuitBreakerService; +import org.elasticsearch.plugins.AnalysisPlugin; import org.elasticsearch.plugins.SearchPlugin; import org.elasticsearch.script.ScriptCompiler; import org.elasticsearch.script.ScriptService; @@ -159,6 +163,7 @@ public abstract class AggregatorTestCase extends ESTestCase { private List releasables = new ArrayList<>(); protected ValuesSourceRegistry valuesSourceRegistry; + protected AnalysisModule analysisModule; // A list of field types that should not be tested, or are not currently supported private static final List TYPE_TEST_BLACKLIST = List.of( @@ -178,6 +183,16 @@ public void initValuesSourceRegistry() { valuesSourceRegistry = searchModule.getValuesSourceRegistry(); } + @Before + public void initAnalysisRegistry() throws IOException { + analysisModule = new AnalysisModule( + TestEnvironment.newEnvironment( + Settings.builder().put(Environment.PATH_HOME_SETTING.getKey(), createTempDir().toString()).build() + ), + getAnalysisPlugins() + ); + } + /** * Test cases should override this if they have plugins that need to be loaded, e.g. the plugins their aggregators are in. */ @@ -185,6 +200,10 @@ protected List getSearchPlugins() { return List.of(); } + protected List getAnalysisPlugins() { + return List.of(); + } + protected A createAggregator(AggregationBuilder aggregationBuilder, IndexSearcher searcher, MappedFieldType... fieldTypes) throws IOException { @@ -283,6 +302,7 @@ public void onCache(ShardId shardId, Accountable accountable) {} MultiBucketConsumer consumer = new MultiBucketConsumer(maxBucket, breakerService.getBreaker(CircuitBreaker.REQUEST)); AggregationContext context = new ProductionAggregationContext( + analysisModule.getAnalysisRegistry(), searchExecutionContext, new MockBigArrays(new MockPageCacheRecycler(Settings.EMPTY), breakerService), bytesToPreallocate, diff --git a/test/framework/src/main/java/org/elasticsearch/test/AbstractBuilderTestCase.java b/test/framework/src/main/java/org/elasticsearch/test/AbstractBuilderTestCase.java index 30bcfd62a8e99..f7475f8536153 100644 --- a/test/framework/src/main/java/org/elasticsearch/test/AbstractBuilderTestCase.java +++ b/test/framework/src/main/java/org/elasticsearch/test/AbstractBuilderTestCase.java @@ -142,6 +142,10 @@ protected Collection> getPlugins() { return Collections.singletonList(TestGeoShapeFieldMapperPlugin.class); } + protected Collection> getExtraPlugins() { + return Collections.emptyList(); + } + protected void initializeAdditionalMappings(MapperService mapperService) throws IOException { } @@ -208,9 +212,11 @@ public void beforeTest() throws Exception { // this setup long masterSeed = SeedUtils.parseSeed(RandomizedTest.getContext().getRunnerSeedAsString()); RandomizedTest.getContext().runWithPrivateRandomness(masterSeed, (Callable) () -> { - serviceHolder = new ServiceHolder(nodeSettings, createTestIndexSettings(), getPlugins(), nowInMillis, + Collection> plugins = new ArrayList<>(getPlugins()); + plugins.addAll(getExtraPlugins()); + serviceHolder = new ServiceHolder(nodeSettings, createTestIndexSettings(), plugins, nowInMillis, AbstractBuilderTestCase.this, true); - serviceHolderWithNoType = new ServiceHolder(nodeSettings, createTestIndexSettings(), getPlugins(), nowInMillis, + serviceHolderWithNoType = new ServiceHolder(nodeSettings, createTestIndexSettings(), plugins, nowInMillis, AbstractBuilderTestCase.this, false); return null; }); diff --git a/x-pack/plugin/ml/qa/ml-with-security/build.gradle b/x-pack/plugin/ml/qa/ml-with-security/build.gradle index 0207b69ee9972..02cbd1ba601a9 100644 --- a/x-pack/plugin/ml/qa/ml-with-security/build.gradle +++ b/x-pack/plugin/ml/qa/ml-with-security/build.gradle @@ -35,6 +35,9 @@ tasks.named("yamlRestTest").configure { 'ml/calendar_crud/Test put calendar given id contains invalid chars', 'ml/calendar_crud/Test delete event from non existing calendar', 'ml/calendar_crud/Test delete job from non existing calendar', + // These are searching tests with aggregations, and do not call any ML endpoints + 'ml/categorization_agg/Test categorization agg simple', + 'ml/categorization_agg/Test categorization aggregation with poor settings', 'ml/custom_all_field/Test querying custom all field', 'ml/datafeeds_crud/Test delete datafeed with missing id', 'ml/datafeeds_crud/Test put datafeed referring to missing job_id', diff --git a/x-pack/plugin/ml/src/internalClusterTest/java/org/elasticsearch/xpack/ml/integration/CategorizationAggregationIT.java b/x-pack/plugin/ml/src/internalClusterTest/java/org/elasticsearch/xpack/ml/integration/CategorizationAggregationIT.java new file mode 100644 index 0000000000000..3eba0d46d0516 --- /dev/null +++ b/x-pack/plugin/ml/src/internalClusterTest/java/org/elasticsearch/xpack/ml/integration/CategorizationAggregationIT.java @@ -0,0 +1,160 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.ml.integration; + +import org.elasticsearch.action.bulk.BulkRequestBuilder; +import org.elasticsearch.action.bulk.BulkResponse; +import org.elasticsearch.action.index.IndexRequest; +import org.elasticsearch.action.search.SearchResponse; +import org.elasticsearch.action.support.WriteRequest; +import org.elasticsearch.core.TimeValue; +import org.elasticsearch.search.aggregations.AggregationBuilders; +import org.elasticsearch.search.aggregations.metrics.Max; +import org.elasticsearch.search.aggregations.metrics.Min; +import org.elasticsearch.xpack.ml.aggs.categorization.CategorizeTextAggregationBuilder; +import org.elasticsearch.xpack.ml.aggs.categorization.InternalCategorizationAggregation; +import org.elasticsearch.xpack.ml.support.BaseMlIntegTestCase; +import org.junit.Before; + +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.hasSize; +import static org.hamcrest.Matchers.is; +import static org.hamcrest.Matchers.not; +import static org.hamcrest.Matchers.notANumber; + +public class CategorizationAggregationIT extends BaseMlIntegTestCase { + + private static final String DATA_INDEX = "categorization-agg-data"; + + @Before + public void setupCluster() { + internalCluster().ensureAtLeastNumDataNodes(3); + ensureStableCluster(); + createSourceData(); + } + + public void testAggregation() { + SearchResponse response = client().prepareSearch(DATA_INDEX) + .setSize(0) + .setTrackTotalHits(false) + .addAggregation( + new CategorizeTextAggregationBuilder("categorize", "msg") + .subAggregation(AggregationBuilders.max("max").field("time")) + .subAggregation(AggregationBuilders.min("min").field("time")) + ).get(); + + InternalCategorizationAggregation agg = response.getAggregations().get("categorize"); + assertThat(agg.getBuckets(), hasSize(3)); + + assertCategorizationBucket(agg.getBuckets().get(0), "Node started", 3); + assertCategorizationBucket( + agg.getBuckets().get(1), + "Failed to shutdown error org.aaaa.bbbb.Cccc line caused by foo exception", + 2 + ); + assertCategorizationBucket(agg.getBuckets().get(2), "Node stopped", 1); + } + + public void testAggregationWithOnlyOneBucket() { + SearchResponse response = client().prepareSearch(DATA_INDEX) + .setSize(0) + .setTrackTotalHits(false) + .addAggregation( + new CategorizeTextAggregationBuilder("categorize", "msg") + .size(1) + .subAggregation(AggregationBuilders.max("max").field("time")) + .subAggregation(AggregationBuilders.min("min").field("time")) + ).get(); + InternalCategorizationAggregation agg = response.getAggregations().get("categorize"); + assertThat(agg.getBuckets(), hasSize(1)); + + assertCategorizationBucket(agg.getBuckets().get(0), "Node started", 3); + } + + public void testAggregationWithBroadCategories() { + SearchResponse response = client().prepareSearch(DATA_INDEX) + .setSize(0) + .setTrackTotalHits(false) + .addAggregation( + new CategorizeTextAggregationBuilder("categorize", "msg") + .setSimilarityThreshold(0.11) + .setMaxChildren(2) + .setMaxDepth(1) + .subAggregation(AggregationBuilders.max("max").field("time")) + .subAggregation(AggregationBuilders.min("min").field("time")) + ).get(); + InternalCategorizationAggregation agg = response.getAggregations().get("categorize"); + assertThat(agg.getBuckets(), hasSize(2)); + + assertCategorizationBucket(agg.getBuckets().get(0), "Node *", 4); + assertCategorizationBucket( + agg.getBuckets().get(1), + "Failed to shutdown error org.aaaa.bbbb.Cccc line caused by foo exception", + 2 + ); + } + + private void assertCategorizationBucket(InternalCategorizationAggregation.Bucket bucket, String key, long docCount) { + assertThat(bucket.getKeyAsString(), equalTo(key)); + assertThat(bucket.getDocCount(), equalTo(docCount)); + assertThat(((Max)bucket.getAggregations().get("max")).getValue(), not(notANumber())); + assertThat(((Min)bucket.getAggregations().get("min")).getValue(), not(notANumber())); + } + + private void ensureStableCluster() { + ensureStableCluster(internalCluster().getNodeNames().length, TimeValue.timeValueSeconds(60)); + } + + private void createSourceData() { + client().admin().indices().prepareCreate(DATA_INDEX) + .setMapping("time", "type=date,format=epoch_millis", + "msg", "type=text") + .get(); + + long nowMillis = System.currentTimeMillis(); + + BulkRequestBuilder bulkRequestBuilder = client().prepareBulk(); + IndexRequest indexRequest = new IndexRequest(DATA_INDEX); + indexRequest.source("time", nowMillis - TimeValue.timeValueHours(2).millis(), + "msg", "Node 1 started", + "part", "nodes"); + bulkRequestBuilder.add(indexRequest); + indexRequest = new IndexRequest(DATA_INDEX); + indexRequest.source("time", nowMillis - TimeValue.timeValueHours(2).millis() + 1, + "msg", "Failed to shutdown [error org.aaaa.bbbb.Cccc line 54 caused by foo exception]", + "part", "shutdowns"); + bulkRequestBuilder.add(indexRequest); + indexRequest = new IndexRequest(DATA_INDEX); + indexRequest.source("time", nowMillis - TimeValue.timeValueHours(2).millis() + 1, + "msg", "Failed to shutdown [error org.aaaa.bbbb.Cccc line 55 caused by foo exception]", + "part", "shutdowns"); + bulkRequestBuilder.add(indexRequest); + indexRequest = new IndexRequest(DATA_INDEX); + indexRequest.source("time", nowMillis - TimeValue.timeValueHours(1).millis(), + "msg", "Node 2 started", + "part", "nodes"); + bulkRequestBuilder.add(indexRequest); + indexRequest = new IndexRequest(DATA_INDEX); + indexRequest.source("time", nowMillis, + "msg", "Node 3 started", + "part", "nodes"); + bulkRequestBuilder.add(indexRequest); + + indexRequest = new IndexRequest(DATA_INDEX); + indexRequest.source("time", nowMillis, + "msg", "Node 3 stopped", + "part", "nodes"); + bulkRequestBuilder.add(indexRequest); + + BulkResponse bulkResponse = bulkRequestBuilder + .setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE) + .get(); + assertThat(bulkResponse.hasFailures(), is(false)); + } + +} diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java index c1a25a6544bfa..330eb43029814 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java @@ -264,6 +264,8 @@ import org.elasticsearch.xpack.ml.action.TransportUpgradeJobModelSnapshotAction; import org.elasticsearch.xpack.ml.action.TransportValidateDetectorAction; import org.elasticsearch.xpack.ml.action.TransportValidateJobConfigAction; +import org.elasticsearch.xpack.ml.aggs.categorization.CategorizeTextAggregationBuilder; +import org.elasticsearch.xpack.ml.aggs.categorization.InternalCategorizationAggregation; import org.elasticsearch.xpack.ml.aggs.correlation.BucketCorrelationAggregationBuilder; import org.elasticsearch.xpack.ml.aggs.correlation.CorrelationNamedContentProvider; import org.elasticsearch.xpack.ml.aggs.heuristic.PValueScore; @@ -1218,6 +1220,7 @@ public List> getExecutorBuilders(Settings settings) { return Arrays.asList(jobComms, utility, datafeed); } + @Override public Map> getCharFilters() { return MapBuilder.>newMapBuilder() .put(FirstNonBlankLineCharFilter.NAME, FirstNonBlankLineCharFilterFactory::new) @@ -1247,6 +1250,18 @@ public List> getSignificanceHeuristics() { ); } + @Override + public List getAggregations() { + return Arrays.asList( + new AggregationSpec( + CategorizeTextAggregationBuilder.NAME, + CategorizeTextAggregationBuilder::new, + CategorizeTextAggregationBuilder.PARSER + ).addResultReader(InternalCategorizationAggregation::new) + .setAggregatorRegistrar(s -> s.registerUsage(CategorizeTextAggregationBuilder.NAME)) + ); + } + @Override public UnaryOperator> getIndexTemplateMetadataUpgrader() { return UnaryOperator.identity(); diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/CategorizationTokenTree.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/CategorizationTokenTree.java new file mode 100644 index 0000000000000..0de2299be7f45 --- /dev/null +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/CategorizationTokenTree.java @@ -0,0 +1,161 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.ml.aggs.categorization; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.apache.lucene.util.Accountable; +import org.apache.lucene.util.BytesRef; +import org.apache.lucene.util.RamUsageEstimator; +import org.elasticsearch.common.Strings; +import org.elasticsearch.search.aggregations.InternalAggregations; + +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.concurrent.atomic.AtomicLong; +import java.util.stream.Collectors; + +import static org.apache.lucene.util.RamUsageEstimator.NUM_BYTES_OBJECT_REF; + +/** + * Categorized semi-structured text utilizing the drain algorithm: https://arxiv.org/pdf/1806.04356.pdf + * With the following key differntiators + * - This structure keeps track of the "smallest" sub-tree. So, instead of naively adding a new "*" node, the smallest sub-tree + * is transformed if the incoming token has a higher doc_count. + * - Additionally, similarities are weighted, which allows for nicer merging of existing log categories + * - An optional tree reduction step is available to collapse together tiny sub-trees + * + * + * The main implementation is a fixed-sized prefix tree. + * Consequently, this assumes that splits that give us more information come earlier in the text. + * + * Examples: + * + * Given log values: + * + * Node is online + * Node is offline + * + * With a fixed tree depth of 2 we would get the following splits + * 3 // initial root is the number of tokens + * | + * "Node" // first prefix node of value "Node" + * | + * "is" + * / \ + * [Node is online] [Node is offline] //the individual categories for this simple case + * + * If the similarityThreshold was less than 0.6, the result would be a single category [Node is *] + * + */ +public class CategorizationTokenTree implements Accountable, TreeNodeFactory { + + static final BytesRef WILD_CARD = new BytesRef("*"); + private static final Logger LOGGER = LogManager.getLogger(CategorizationTokenTree.class); + + private final int maxDepth; + private final int maxChildren; + private final double similarityThreshold; + private final AtomicLong idGen = new AtomicLong(); + // TODO statically allocate an array like DuplicateByteSequenceSpotter ??? + private final Map root = new HashMap<>(); + private long sizeInBytes; + + public CategorizationTokenTree(int maxChildren, int maxDepth, double similarityThreshold) { + assert maxChildren > 0 && maxDepth >= 0; + this.maxChildren = maxChildren; + this.maxDepth = maxDepth; + this.similarityThreshold = similarityThreshold; + this.sizeInBytes = Integer.BYTES // maxDepth + + Integer.BYTES // maxChildren + + Double.BYTES // similarityThreshold + + NUM_BYTES_OBJECT_REF + Long.BYTES // idGen + + NUM_BYTES_OBJECT_REF // tree map + + Long.BYTES; // sizeInBytes + } + + public List toIntermediateBuckets() { + return root.values().stream().flatMap(c -> c.getAllChildrenLogGroups().stream()).map(lg -> { + InternalCategorizationAggregation.Bucket bucket = new InternalCategorizationAggregation.Bucket( + new InternalCategorizationAggregation.BucketKey(lg.getLogEvent()), + lg.getCount(), + InternalAggregations.EMPTY + ); + bucket.bucketOrd = lg.bucketOrd; + return bucket; + }).collect(Collectors.toList()); + } + + public List toBuckets(Map internalAggregations) { + return root.values() + .stream() + .flatMap(c -> c.getAllChildrenLogGroups().stream()) + .map( + lg -> new InternalCategorizationAggregation.Bucket( + new InternalCategorizationAggregation.BucketKey(lg.getLogEvent()), + lg.getCount(), + internalAggregations.get(lg.getId()) + ) + ) + .sorted() + .collect(Collectors.toList()); + } + + void mergeSmallestChildren() { + root.values().forEach(TreeNode::collapseTinyChildren); + } + + public LogGroup parseLogLine(final BytesRef[] logTokens) { + return parseLogLine(logTokens, 1); + } + + public LogGroup parseLogLineConst(final BytesRef[] logTokens) { + TreeNode currentNode = this.root.get(logTokens.length); + if (currentNode == null) { // we are missing an entire sub tree. New log length found + return null; + } + return currentNode.getLogGroup(logTokens); + } + + public LogGroup parseLogLine(final BytesRef[] logTokens, long docCount) { + LOGGER.trace("parsing tokens [{}]", Strings.arrayToDelimitedString(logTokens, " ")); + TreeNode currentNode = this.root.get(logTokens.length); + if (currentNode == null) { // we are missing an entire sub tree. New log length found + currentNode = newNode(docCount, 0, logTokens); + this.root.put(logTokens.length, currentNode); + } else { + currentNode.incCount(docCount); + } + return currentNode.addLog(logTokens, docCount, this); + } + + @Override + public TreeNode newNode(long docCount, int tokenPos, BytesRef[] tokens) { + TreeNode node = tokenPos < maxDepth - 1 && tokenPos < tokens.length + ? new TreeNode.InnerTreeNode(docCount, tokenPos, maxChildren) + : new TreeNode.LeafTreeNode(docCount, similarityThreshold); + // The size of the node + entry (since it is a map entry) + extra reference for priority queue + sizeInBytes += node.ramBytesUsed() + RamUsageEstimator.HASHTABLE_RAM_BYTES_PER_ENTRY + RamUsageEstimator.NUM_BYTES_OBJECT_REF; + return node; + } + + @Override + public LogGroup newGroup(long docCount, BytesRef[] logTokens) { + LogGroup group = new LogGroup(logTokens, docCount, idGen.incrementAndGet()); + // Get the regular size bytes from the LogGroup and how much it costs to reference it + sizeInBytes += group.ramBytesUsed() + RamUsageEstimator.NUM_BYTES_OBJECT_REF; + return group; + } + + @Override + public long ramBytesUsed() { + return sizeInBytes; + } + +} diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/CategorizeTextAggregationBuilder.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/CategorizeTextAggregationBuilder.java new file mode 100644 index 0000000000000..1f6507039f328 --- /dev/null +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/CategorizeTextAggregationBuilder.java @@ -0,0 +1,275 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.ml.aggs.categorization; + +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.common.xcontent.ObjectParser; +import org.elasticsearch.common.xcontent.ParseField; +import org.elasticsearch.common.xcontent.XContentBuilder; +import org.elasticsearch.common.xcontent.XContentParser; +import org.elasticsearch.search.aggregations.AbstractAggregationBuilder; +import org.elasticsearch.search.aggregations.AggregationBuilder; +import org.elasticsearch.search.aggregations.AggregatorFactories; +import org.elasticsearch.search.aggregations.AggregatorFactory; +import org.elasticsearch.search.aggregations.bucket.terms.TermsAggregationBuilder; +import org.elasticsearch.search.aggregations.bucket.terms.TermsAggregator; +import org.elasticsearch.search.aggregations.support.AggregationContext; +import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; +import java.util.Map; + +public class CategorizeTextAggregationBuilder extends AbstractAggregationBuilder { + + static final TermsAggregator.BucketCountThresholds DEFAULT_BUCKET_COUNT_THRESHOLDS = new TermsAggregator.BucketCountThresholds( + 1, + 0, + 10, + -1 + ); + public static final String NAME = "categorize_text"; + + static final ParseField FIELD_NAME = new ParseField("field"); + static final ParseField MAX_CHILDREN = new ParseField("max_children"); + static final ParseField SIMILARITY_THRESHOLD = new ParseField("similarity_threshold"); + static final ParseField MAX_DEPTH = new ParseField("max_depth"); + static final ParseField CATEGORIZATION_FILTERS = new ParseField("categorization_filters"); + + public static final ObjectParser PARSER = ObjectParser.fromBuilder( + CategorizeTextAggregationBuilder.NAME, + CategorizeTextAggregationBuilder::new + ); + static { + PARSER.declareString(CategorizeTextAggregationBuilder::setFieldName, FIELD_NAME); + PARSER.declareInt(CategorizeTextAggregationBuilder::setMaxChildren, MAX_CHILDREN); + PARSER.declareInt(CategorizeTextAggregationBuilder::setMaxDepth, MAX_DEPTH); + PARSER.declareDouble(CategorizeTextAggregationBuilder::setSimilarityThreshold, SIMILARITY_THRESHOLD); + PARSER.declareStringArray(CategorizeTextAggregationBuilder::setCategorizationFilters, CATEGORIZATION_FILTERS); + PARSER.declareInt(CategorizeTextAggregationBuilder::shardSize, TermsAggregationBuilder.SHARD_SIZE_FIELD_NAME); + PARSER.declareLong(CategorizeTextAggregationBuilder::minDocCount, TermsAggregationBuilder.MIN_DOC_COUNT_FIELD_NAME); + PARSER.declareLong(CategorizeTextAggregationBuilder::shardMinDocCount, TermsAggregationBuilder.SHARD_MIN_DOC_COUNT_FIELD_NAME); + PARSER.declareInt(CategorizeTextAggregationBuilder::size, TermsAggregationBuilder.REQUIRED_SIZE_FIELD_NAME); + } + + public static CategorizeTextAggregationBuilder parse(String aggregationName, XContentParser parser) throws IOException { + return PARSER.parse(parser, new CategorizeTextAggregationBuilder(aggregationName), null); + } + + private TermsAggregator.BucketCountThresholds bucketCountThresholds = new TermsAggregator.BucketCountThresholds( + DEFAULT_BUCKET_COUNT_THRESHOLDS + ); + private List categorizationFilters = new ArrayList<>(); + private String fieldName; + private int maxChildren = 100; + private double similarityThreshold = 0.5; + private int maxDepth = 5; + + private CategorizeTextAggregationBuilder(String name) { + super(name); + } + + public CategorizeTextAggregationBuilder(String name, String fieldName) { + super(name); + this.fieldName = ExceptionsHelper.requireNonNull(fieldName, FIELD_NAME); + } + + public String getFieldName() { + return fieldName; + } + + public CategorizeTextAggregationBuilder setFieldName(String fieldName) { + this.fieldName = ExceptionsHelper.requireNonNull(fieldName, FIELD_NAME); + return this; + } + + public CategorizeTextAggregationBuilder(StreamInput in) throws IOException { + super(in); + this.bucketCountThresholds = new TermsAggregator.BucketCountThresholds(in); + this.fieldName = in.readString(); + this.maxChildren = in.readVInt(); + this.maxDepth = in.readVInt(); + this.similarityThreshold = in.readDouble(); + this.categorizationFilters = in.readStringList(); + } + + public int getMaxChildren() { + return maxChildren; + } + + public CategorizeTextAggregationBuilder setMaxChildren(int maxChildren) { + this.maxChildren = maxChildren; + if (maxChildren <= 0) { + throw new IllegalArgumentException("[" + MAX_CHILDREN.getPreferredName() + "] must be greater than 0"); + } + return this; + } + + public double getSimilarityThreshold() { + return similarityThreshold; + } + + public CategorizeTextAggregationBuilder setSimilarityThreshold(double similarityThreshold) { + this.similarityThreshold = similarityThreshold; + if (similarityThreshold < 0.1 || similarityThreshold > 1.0) { + throw new IllegalArgumentException("[" + SIMILARITY_THRESHOLD.getPreferredName() + "] must be in the range [0.1, 1.0]"); + } + return this; + } + + public List getCategorizationFilters() { + return categorizationFilters; + } + + public CategorizeTextAggregationBuilder setCategorizationFilters(List categorizationFilters) { + this.categorizationFilters = ExceptionsHelper.requireNonNull(categorizationFilters, CATEGORIZATION_FILTERS); + return this; + } + + public int getMaxDepth() { + return maxDepth; + } + + public CategorizeTextAggregationBuilder setMaxDepth(int maxDepth) { + this.maxDepth = maxDepth; + if (maxDepth <= 0) { + throw new IllegalArgumentException("[" + MAX_DEPTH.getPreferredName() + "] must be greater than 0"); + } + return this; + } + + /** + * @param size indicating how many buckets should be returned + */ + public CategorizeTextAggregationBuilder size(int size) { + if (size <= 0) { + throw new IllegalArgumentException("[size] must be greater than 0. Found [" + size + "] in [" + name + "]"); + } + bucketCountThresholds.setRequiredSize(size); + return this; + } + + /** + * @param shardSize - indicating the number of buckets each shard + * will return to the coordinating node (the node that coordinates the + * search execution). The higher the shard size is, the more accurate the + * results are. + */ + public CategorizeTextAggregationBuilder shardSize(int shardSize) { + if (shardSize <= 0) { + throw new IllegalArgumentException("[shardSize] must be greater than 0. Found [" + shardSize + "] in [" + name + "]"); + } + bucketCountThresholds.setShardSize(shardSize); + return this; + } + + /** + * @param minDocCount the minimum document count a text category should have in order to appear in + * the response. + */ + public CategorizeTextAggregationBuilder minDocCount(long minDocCount) { + if (minDocCount < 0) { + throw new IllegalArgumentException( + "[minDocCount] must be greater than or equal to 0. Found [" + minDocCount + "] in [" + name + "]" + ); + } + bucketCountThresholds.setMinDocCount(minDocCount); + return this; + } + + /** + * @param shardMinDocCount the minimum document count a text category should have on the shard in order to + * appear in the response. + */ + public CategorizeTextAggregationBuilder shardMinDocCount(long shardMinDocCount) { + if (shardMinDocCount < 0) { + throw new IllegalArgumentException( + "[shardMinDocCount] must be greater than or equal to 0. Found [" + shardMinDocCount + "] in [" + name + "]" + ); + } + bucketCountThresholds.setShardMinDocCount(shardMinDocCount); + return this; + } + + protected CategorizeTextAggregationBuilder( + CategorizeTextAggregationBuilder clone, + AggregatorFactories.Builder factoriesBuilder, + Map metadata + ) { + super(clone, factoriesBuilder, metadata); + this.bucketCountThresholds = new TermsAggregator.BucketCountThresholds(clone.bucketCountThresholds); + this.fieldName = clone.fieldName; + this.maxChildren = clone.maxChildren; + this.maxDepth = clone.maxDepth; + this.similarityThreshold = clone.similarityThreshold; + this.categorizationFilters = clone.categorizationFilters; + } + + @Override + protected void doWriteTo(StreamOutput out) throws IOException { + bucketCountThresholds.writeTo(out); + out.writeString(fieldName); + out.writeVInt(maxChildren); + out.writeVInt(maxDepth); + out.writeDouble(similarityThreshold); + out.writeStringCollection(categorizationFilters); + } + + @Override + protected AggregatorFactory doBuild( + AggregationContext context, + AggregatorFactory parent, + AggregatorFactories.Builder subfactoriesBuilder + ) throws IOException { + return new CategorizeTextAggregatorFactory( + name, + fieldName, + maxChildren, + maxDepth, + similarityThreshold, + bucketCountThresholds, + categorizationFilters, + context, + parent, + subfactoriesBuilder, + metadata + ); + } + + @Override + protected XContentBuilder internalXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + bucketCountThresholds.toXContent(builder, params); + builder.field(FIELD_NAME.getPreferredName(), fieldName); + builder.field(MAX_CHILDREN.getPreferredName(), maxChildren); + builder.field(MAX_DEPTH.getPreferredName(), maxDepth); + builder.field(SIMILARITY_THRESHOLD.getPreferredName(), similarityThreshold); + if (categorizationFilters.isEmpty() == false) { + builder.field(CATEGORIZATION_FILTERS.getPreferredName(), categorizationFilters); + } + builder.endObject(); + return null; + } + + @Override + protected AggregationBuilder shallowCopy(AggregatorFactories.Builder factoriesBuilder, Map metadata) { + return new CategorizeTextAggregationBuilder(this, factoriesBuilder, metadata); + } + + @Override + public BucketCardinality bucketCardinality() { + return BucketCardinality.MANY; + } + + @Override + public String getType() { + return NAME; + } +} diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/CategorizeTextAggregator.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/CategorizeTextAggregator.java new file mode 100644 index 0000000000000..151bfd26798bd --- /dev/null +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/CategorizeTextAggregator.java @@ -0,0 +1,208 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.ml.aggs.categorization; + +import org.apache.lucene.analysis.TokenStream; +import org.apache.lucene.analysis.tokenattributes.CharTermAttribute; +import org.apache.lucene.index.LeafReaderContext; +import org.apache.lucene.util.BytesRef; +import org.apache.lucene.util.PriorityQueue; +import org.elasticsearch.common.util.BigArrays; +import org.elasticsearch.common.util.ObjectArray; +import org.elasticsearch.index.mapper.MappedFieldType; +import org.elasticsearch.search.aggregations.Aggregator; +import org.elasticsearch.search.aggregations.AggregatorFactories; +import org.elasticsearch.search.aggregations.CardinalityUpperBound; +import org.elasticsearch.search.aggregations.InternalAggregation; +import org.elasticsearch.search.aggregations.LeafBucketCollector; +import org.elasticsearch.search.aggregations.LeafBucketCollectorBase; +import org.elasticsearch.search.aggregations.bucket.DeferableBucketAggregator; +import org.elasticsearch.search.aggregations.bucket.terms.LongKeyedBucketOrds; +import org.elasticsearch.search.aggregations.bucket.terms.TermsAggregator; +import org.elasticsearch.search.aggregations.support.AggregationContext; +import org.elasticsearch.search.lookup.SourceLookup; +import org.elasticsearch.xpack.core.ml.job.config.CategorizationAnalyzerConfig; +import org.elasticsearch.xpack.ml.job.categorization.CategorizationAnalyzer; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Comparator; +import java.util.Iterator; +import java.util.List; +import java.util.Map; + +public class CategorizeTextAggregator extends DeferableBucketAggregator { + + private final TermsAggregator.BucketCountThresholds bucketCountThresholds; + private final SourceLookup sourceLookup; + private final BigArrays bigArrays; + private final MappedFieldType fieldType; + private final CategorizationAnalyzer analyzer; + private final String sourceFieldName; + private ObjectArray categorizers; + private final int maxChildren; + private final int maxDepth; + private final double similarityThreshold; + private final LongKeyedBucketOrds bucketOrds; + + protected CategorizeTextAggregator( + String name, + AggregatorFactories factories, + AggregationContext context, + Aggregator parent, + String sourceFieldName, + MappedFieldType fieldType, + TermsAggregator.BucketCountThresholds bucketCountThresholds, + int maxChildren, + int maxDepth, + double similarityThreshold, + List categorizationFilters, + Map metadata + ) throws IOException { + super(name, factories, context, parent, metadata); + this.sourceLookup = context.lookup().source(); + this.sourceFieldName = sourceFieldName; + this.fieldType = fieldType; + CategorizationAnalyzerConfig categorizationAnalyzerConfig = CategorizationAnalyzerConfig.buildStandardCategorizationAnalyzer( + categorizationFilters + ); + this.analyzer = new CategorizationAnalyzer(context.getAnalysisRegistry(), categorizationAnalyzerConfig); + this.bigArrays = context.bigArrays(); + this.categorizers = bigArrays().newObjectArray(1); + this.maxChildren = maxChildren; + this.maxDepth = maxDepth; + this.similarityThreshold = similarityThreshold; + this.bucketOrds = LongKeyedBucketOrds.build(bigArrays(), CardinalityUpperBound.MANY); + this.bucketCountThresholds = bucketCountThresholds; + } + + @Override + protected void doClose() { + super.doClose(); + this.analyzer.close(); + } + + @Override + public InternalAggregation[] buildAggregations(long[] ordsToCollect) throws IOException { + InternalCategorizationAggregation.Bucket[][] topBucketsPerOrd = + new InternalCategorizationAggregation.Bucket[ordsToCollect.length][]; + for (int ordIdx = 0; ordIdx < ordsToCollect.length; ordIdx++) { + int size = (int) Math.min(bucketOrds.size(), bucketCountThresholds.getShardSize()); + PriorityQueue ordered = + new InternalCategorizationAggregation.BucketCountPriorityQueue(size); + CategorizationTokenTree categorizationTokenTree = categorizers.get(ordsToCollect[ordIdx]); + for (InternalCategorizationAggregation.Bucket bucket : categorizationTokenTree.toIntermediateBuckets()) { + if (bucket.docCount < bucketCountThresholds.getShardMinDocCount()) { + continue; + } + ordered.insertWithOverflow(bucket); + } + topBucketsPerOrd[ordIdx] = new InternalCategorizationAggregation.Bucket[ordered.size()]; + for (int i = ordered.size() - 1; i >= 0; --i) { + topBucketsPerOrd[ordIdx][i] = ordered.pop(); + } + } + buildSubAggsForAllBuckets(topBucketsPerOrd, b -> b.bucketOrd, (b, a) -> b.aggregations = a); + InternalAggregation[] results = new InternalAggregation[ordsToCollect.length]; + for (int ordIdx = 0; ordIdx < ordsToCollect.length; ordIdx++) { + InternalCategorizationAggregation.Bucket[] bucketArray = topBucketsPerOrd[ordIdx]; + Arrays.sort(bucketArray, Comparator.naturalOrder()); + results[ordIdx] = new InternalCategorizationAggregation( + name, + bucketCountThresholds.getRequiredSize(), + bucketCountThresholds.getMinDocCount(), + maxChildren, + maxDepth, + similarityThreshold, + metadata(), + Arrays.asList(bucketArray) + ); + } + return results; + } + + @Override + public InternalAggregation buildEmptyAggregation() { + return new InternalCategorizationAggregation( + name, + bucketCountThresholds.getRequiredSize(), + bucketCountThresholds.getMinDocCount(), + maxChildren, + maxDepth, + similarityThreshold, + metadata() + ); + } + + @Override + protected LeafBucketCollector getLeafCollector(LeafReaderContext ctx, LeafBucketCollector sub) throws IOException { + return new LeafBucketCollectorBase(sub, null) { + + @Override + public void collect(int doc, long owningBucketOrd) throws IOException { + collectFromSource(doc, owningBucketOrd); + } + + private void collectFromSource(int doc, long owningBucketOrd) throws IOException { + sourceLookup.setSegmentAndDocument(ctx, doc); + Iterator itr = sourceLookup.extractRawValues(sourceFieldName).stream().map(obj -> { + if (obj == null) { + return null; + } + if (obj instanceof BytesRef) { + return fieldType.valueForDisplay(obj).toString(); + } + return obj.toString(); + }).iterator(); + while (itr.hasNext()) { + TokenStream ts = analyzer.tokenStream(fieldType.name(), itr.next()); + processTokenStream(owningBucketOrd, ts, doc); + } + } + + private void processTokenStream(long owningBucketOrd, TokenStream ts, int doc) throws IOException { + try { + CharTermAttribute termAtt = ts.addAttribute(CharTermAttribute.class); + ts.reset(); + ArrayList tokens = new ArrayList<>(); + while (ts.incrementToken()) { + tokens.add(new BytesRef(termAtt)); + } + if (tokens.isEmpty()) { + return; + } + categorizers = bigArrays.grow(categorizers, owningBucketOrd + 1); + CategorizationTokenTree categorizer = categorizers.get(owningBucketOrd); + if (categorizer == null) { + categorizer = new CategorizationTokenTree(maxChildren, maxDepth, similarityThreshold); + addRequestCircuitBreakerBytes(categorizer.ramBytesUsed()); + categorizers.set(owningBucketOrd, categorizer); + } + long previousSize = categorizer.ramBytesUsed(); + LogGroup lg = categorizer.parseLogLine(tokens.toArray(BytesRef[]::new), docCountProvider.getDocCount(doc)); + long newSize = categorizer.ramBytesUsed(); + if (newSize - previousSize > 0) { + addRequestCircuitBreakerBytes(newSize - previousSize); + } + + long bucketOrd = bucketOrds.add(owningBucketOrd, lg.getId()); + if (bucketOrd < 0) { // already seen + bucketOrd = -1 - bucketOrd; + collectExistingBucket(sub, doc, bucketOrd); + } else { + lg.bucketOrd = bucketOrd; + collectBucket(sub, doc, bucketOrd); + } + } finally { + ts.close(); + } + } + }; + } +} diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/CategorizeTextAggregatorFactory.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/CategorizeTextAggregatorFactory.java new file mode 100644 index 0000000000000..2974dadde4ecf --- /dev/null +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/CategorizeTextAggregatorFactory.java @@ -0,0 +1,90 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.ml.aggs.categorization; + +import org.elasticsearch.index.mapper.MappedFieldType; +import org.elasticsearch.search.aggregations.Aggregator; +import org.elasticsearch.search.aggregations.AggregatorFactories; +import org.elasticsearch.search.aggregations.AggregatorFactory; +import org.elasticsearch.search.aggregations.CardinalityUpperBound; +import org.elasticsearch.search.aggregations.bucket.BucketUtils; +import org.elasticsearch.search.aggregations.bucket.terms.TermsAggregator; +import org.elasticsearch.search.aggregations.support.AggregationContext; + +import java.io.IOException; +import java.util.Collections; +import java.util.List; +import java.util.Map; + +public class CategorizeTextAggregatorFactory extends AggregatorFactory { + + private final MappedFieldType fieldType; + private final String indexedFieldName; + private final int maxChildren; + private final int maxDepth; + private final double similarityThreshold; + private final List categorizationFilters; + private final TermsAggregator.BucketCountThresholds bucketCountThresholds; + + public CategorizeTextAggregatorFactory( + String name, + String fieldName, + int maxChildren, + int maxDepth, + double similarityThreshold, + TermsAggregator.BucketCountThresholds bucketCountThresholds, + List categorizationFilters, + AggregationContext context, + AggregatorFactory parent, + AggregatorFactories.Builder subFactoriesBuilder, + Map metadata + ) throws IOException { + super(name, context, parent, subFactoriesBuilder, metadata); + this.fieldType = context.getFieldType(fieldName); + if (fieldType != null) { + this.indexedFieldName = fieldType.name(); + } else { + throw new IllegalArgumentException("Only works on indexed fields, cannot find field [" + fieldName + "]"); + } + this.maxChildren = maxChildren; + this.maxDepth = maxDepth; + this.similarityThreshold = similarityThreshold; + this.categorizationFilters = categorizationFilters == null ? Collections.emptyList() : categorizationFilters; + this.bucketCountThresholds = bucketCountThresholds; + } + + @Override + protected Aggregator createInternal(Aggregator parent, CardinalityUpperBound cardinality, Map metadata) + throws IOException { + TermsAggregator.BucketCountThresholds bucketCountThresholds = new TermsAggregator.BucketCountThresholds(this.bucketCountThresholds); + if (bucketCountThresholds.getShardSize() == CategorizeTextAggregationBuilder.DEFAULT_BUCKET_COUNT_THRESHOLDS.getShardSize()) { + // The user has not made a shardSize selection. Use default + // heuristic to avoid any wrong-ranking caused by distributed + // counting + // TODO significant text does a 2x here, should we as well? + bucketCountThresholds.setShardSize(BucketUtils.suggestShardSideQueueSize(bucketCountThresholds.getRequiredSize())); + } + bucketCountThresholds.ensureValidity(); + + return new CategorizeTextAggregator( + name, + factories, + context, + parent, + indexedFieldName, + fieldType, + bucketCountThresholds, + maxChildren, + maxDepth, + similarityThreshold, + categorizationFilters, + metadata + ); + } + +} diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/InternalCategorizationAggregation.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/InternalCategorizationAggregation.java new file mode 100644 index 0000000000000..89ff6bd302eaf --- /dev/null +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/InternalCategorizationAggregation.java @@ -0,0 +1,415 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.ml.aggs.categorization; + +import org.apache.lucene.util.BytesRef; +import org.apache.lucene.util.PriorityQueue; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.common.xcontent.ToXContentFragment; +import org.elasticsearch.common.xcontent.XContentBuilder; +import org.elasticsearch.search.aggregations.AggregationExecutionException; +import org.elasticsearch.search.aggregations.Aggregations; +import org.elasticsearch.search.aggregations.InternalAggregation; +import org.elasticsearch.search.aggregations.InternalAggregations; +import org.elasticsearch.search.aggregations.InternalMultiBucketAggregation; +import org.elasticsearch.search.aggregations.bucket.MultiBucketsAggregation; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import static org.elasticsearch.xpack.ml.aggs.categorization.CategorizationTokenTree.WILD_CARD; + +public class InternalCategorizationAggregation extends InternalMultiBucketAggregation< + InternalCategorizationAggregation, + InternalCategorizationAggregation.Bucket> { + + // Carries state allowing for delayed reduction of the bucket + // This allows us to keep from accidentally calling "reduce" on the sub-aggs more than once + private static class DelayedCategorizationBucket { + private final BucketKey key; + private long docCount; + private final List toReduce; + + DelayedCategorizationBucket(BucketKey key, List toReduce, long docCount) { + this.key = key; + this.toReduce = new ArrayList<>(toReduce); + this.docCount = docCount; + } + + public long getDocCount() { + return docCount; + } + + public Bucket reduce(BucketKey key, ReduceContext reduceContext) { + List innerAggs = new ArrayList<>(toReduce.size()); + long docCount = 0; + for (Bucket bucket : toReduce) { + innerAggs.add(bucket.aggregations); + docCount += bucket.docCount; + } + return new Bucket(key, docCount, InternalAggregations.reduce(innerAggs, reduceContext)); + } + + public DelayedCategorizationBucket add(Bucket bucket) { + this.docCount += bucket.docCount; + this.toReduce.add(bucket); + return this; + } + + public DelayedCategorizationBucket add(DelayedCategorizationBucket bucket) { + this.docCount += bucket.docCount; + this.toReduce.addAll(bucket.toReduce); + return this; + } + } + + static class BucketCountPriorityQueue extends PriorityQueue { + BucketCountPriorityQueue(int size) { + super(size); + } + + @Override + protected boolean lessThan(Bucket a, Bucket b) { + return a.docCount < b.docCount; + } + } + + static class BucketKey implements ToXContentFragment, Writeable, Comparable { + + private final BytesRef[] key; + + BucketKey(BytesRef[] key) { + this.key = key; + } + + BucketKey(StreamInput in) throws IOException { + key = in.readArray(StreamInput::readBytesRef, BytesRef[]::new); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + return builder.value(asString()); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeArray(StreamOutput::writeBytesRef, key); + } + + public String asString() { + StringBuilder builder = new StringBuilder(); + for (int i = 0; i < key.length - 1; i++) { + builder.append(key[i].utf8ToString()).append(" "); + } + builder.append(key[key.length - 1].utf8ToString()); + return builder.toString(); + } + + @Override + public String toString() { + return asString(); + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + BucketKey bucketKey = (BucketKey) o; + return Arrays.equals(key, bucketKey.key); + } + + @Override + public int hashCode() { + return Arrays.hashCode(key); + } + + public BytesRef[] keyAsTokens() { + return key; + } + + @Override + public int compareTo(BucketKey o) { + return Arrays.compare(key, o.key); + } + + private BucketKey collapseWildCards() { + if (key.length <= 1) { + return this; + } + List collapsedWildCards = new ArrayList<>(); + boolean previousTokenWildCard = false; + for (BytesRef token : key) { + if (token.equals(WILD_CARD)) { + if (previousTokenWildCard == false) { + previousTokenWildCard = true; + collapsedWildCards.add(WILD_CARD); + } + } else { + previousTokenWildCard = false; + collapsedWildCards.add(token); + } + } + if (collapsedWildCards.size() == key.length) { + return this; + } + return new BucketKey(collapsedWildCards.toArray(BytesRef[]::new)); + } + } + + public static class Bucket extends InternalBucket implements MultiBucketsAggregation.Bucket, Comparable { + // Used on the shard level to keep track of sub aggregations + long bucketOrd; + + final BucketKey key; + final long docCount; + InternalAggregations aggregations; + + public Bucket(BucketKey key, long docCount, InternalAggregations aggregations) { + this.key = key; + this.docCount = docCount; + this.aggregations = aggregations; + } + + public Bucket(StreamInput in) throws IOException { + key = new BucketKey(in); + docCount = in.readVLong(); + aggregations = InternalAggregations.readFrom(in); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + builder.field(CommonFields.DOC_COUNT.getPreferredName(), docCount); + builder.field(CommonFields.KEY.getPreferredName()); + key.toXContent(builder, params); + aggregations.toXContentInternal(builder, params); + builder.endObject(); + return builder; + } + + @Override + public Object getKey() { + return key; + } + + @Override + public String getKeyAsString() { + return key.asString(); + } + + @Override + public long getDocCount() { + return docCount; + } + + @Override + public Aggregations getAggregations() { + return aggregations; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + key.writeTo(out); + out.writeVLong(getDocCount()); + aggregations.writeTo(out); + } + + @Override + public String toString() { + return "Bucket{" + "key=" + getKeyAsString() + ", docCount=" + docCount + ", aggregations=" + aggregations.asMap() + "}\n"; + } + + @Override + public int compareTo(Bucket o) { + return key.compareTo(o.key); + } + + } + + private final List buckets; + private final int maxChildren; + private final double similarityThreshold; + private final int maxDepth; + protected final int requiredSize; + protected final long minDocCount; + + protected InternalCategorizationAggregation( + String name, + int requiredSize, + long minDocCount, + int maxChildren, + int maxDepth, + double similarityThreshold, + Map metadata + ) { + this(name, requiredSize, minDocCount, maxChildren, maxDepth, similarityThreshold, metadata, new ArrayList<>()); + } + + protected InternalCategorizationAggregation( + String name, + int requiredSize, + long minDocCount, + int maxChildren, + int maxDepth, + double similarityThreshold, + Map metadata, + List buckets + ) { + super(name, metadata); + this.buckets = buckets; + this.maxChildren = maxChildren; + this.maxDepth = maxDepth; + this.similarityThreshold = similarityThreshold; + this.minDocCount = minDocCount; + this.requiredSize = requiredSize; + } + + public InternalCategorizationAggregation(StreamInput in) throws IOException { + super(in); + this.maxChildren = in.readVInt(); + this.maxDepth = in.readVInt(); + this.similarityThreshold = in.readDouble(); + this.buckets = in.readList(Bucket::new); + this.requiredSize = readSize(in); + this.minDocCount = in.readVLong(); + } + + @Override + public InternalCategorizationAggregation create(List buckets) { + return new InternalCategorizationAggregation( + name, + requiredSize, + minDocCount, + maxChildren, + maxDepth, + similarityThreshold, + super.metadata, + buckets + ); + } + + @Override + public Bucket createBucket(InternalAggregations aggregations, Bucket prototype) { + return new Bucket(prototype.key, prototype.docCount, aggregations); + } + + @Override + protected Bucket reduceBucket(List buckets, ReduceContext context) { + throw new IllegalArgumentException("For optimization purposes, typical bucket path is not supported"); + } + + @Override + public List getBuckets() { + return buckets; + } + + @Override + public String getWriteableName() { + return CategorizeTextAggregationBuilder.NAME; + } + + @Override + protected void doWriteTo(StreamOutput out) throws IOException { + out.writeVInt(maxChildren); + out.writeVInt(maxDepth); + out.writeDouble(similarityThreshold); + out.writeList(buckets); + writeSize(requiredSize, out); + out.writeVLong(minDocCount); + } + + @Override + public InternalAggregation reduce(List aggregations, ReduceContext reduceContext) { + CategorizationTokenTree categorizationTokenTree = new CategorizationTokenTree(maxChildren, maxDepth, similarityThreshold); + // TODO: Could we do a merge sort similar to terms? + // It would require us returning partial reductions sorted by key, not by doc_count + // First, make sure we have all the counts for equal log groups + Map reduced = new HashMap<>(aggregations.size(), 1.0f); + for (InternalAggregation aggregation : aggregations) { + InternalCategorizationAggregation categorizationAggregation = (InternalCategorizationAggregation) aggregation; + for (Bucket bucket : categorizationAggregation.buckets) { + reduced.computeIfAbsent(bucket.key, key -> new DelayedCategorizationBucket(key, new ArrayList<>(1), 0L)).add(bucket); + } + } + + for (DelayedCategorizationBucket bucket : reduced.values()) { + // Parse log line takes document count into account and merging on smallest groups + categorizationTokenTree.parseLogLine(bucket.key.keyAsTokens(), bucket.docCount); + } + // Collapse tiny groups together, this may result in new bucket keys for already known buckets + categorizationTokenTree.mergeSmallestChildren(); + Map mergedBuckets = new HashMap<>(aggregations.size(), 1.0f); + for (DelayedCategorizationBucket delayedBucket : reduced.values()) { + LogGroup group = categorizationTokenTree.parseLogLineConst(delayedBucket.key.keyAsTokens()); + if (group == null) { + throw new AggregationExecutionException( + "Unexpected null categorization group for bucket [" + delayedBucket.key.asString() + "]" + ); + } + BucketKey key = new BucketKey(group.getLogEvent()); + mergedBuckets.computeIfAbsent( + reduceContext.isFinalReduce() ? key.collapseWildCards() : key, + k -> new DelayedCategorizationBucket(k, new ArrayList<>(delayedBucket.toReduce.size()), 0L) + ).add(delayedBucket); + } + + final int size = reduceContext.isFinalReduce() == false ? buckets.size() : Math.min(requiredSize, buckets.size()); + final PriorityQueue pq = new BucketCountPriorityQueue(size); + for (Map.Entry keyAndBuckets : mergedBuckets.entrySet()) { + final BucketKey key = keyAndBuckets.getKey(); + DelayedCategorizationBucket bucket = keyAndBuckets.getValue(); + Bucket newBucket = bucket.reduce(key, reduceContext); + if ((newBucket.docCount >= minDocCount) || reduceContext.isFinalReduce() == false) { + Bucket removed = pq.insertWithOverflow(newBucket); + if (removed == null) { + reduceContext.consumeBucketsAndMaybeBreak(1); + } else { + reduceContext.consumeBucketsAndMaybeBreak(-countInnerBucket(removed)); + } + } else { + reduceContext.consumeBucketsAndMaybeBreak(-countInnerBucket(newBucket)); + } + } + Bucket[] bucketList = new Bucket[pq.size()]; + for (int i = pq.size() - 1; i >= 0; i--) { + bucketList[i] = pq.pop(); + } + return new InternalCategorizationAggregation( + name, + requiredSize, + minDocCount, + maxChildren, + maxDepth, + similarityThreshold, + metadata, + Arrays.asList(bucketList) + ); + } + + @Override + public Object getProperty(List path) { + // TODO anything special? + return super.getProperty(path); + } + + @Override + public XContentBuilder doXContentBody(XContentBuilder builder, Params params) throws IOException { + builder.startArray(CommonFields.BUCKETS.getPreferredName()); + for (Bucket bucket : buckets) { + bucket.toXContent(builder, params); + } + builder.endArray(); + return builder; + } +} diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/LogGroup.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/LogGroup.java new file mode 100644 index 0000000000000..b6c81235c9652 --- /dev/null +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/LogGroup.java @@ -0,0 +1,105 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.ml.aggs.categorization; + +import org.apache.lucene.util.Accountable; +import org.apache.lucene.util.BytesRef; +import org.apache.lucene.util.RamUsageEstimator; +import org.elasticsearch.core.Tuple; + +import java.util.Arrays; +import java.util.stream.Collectors; + +import static org.elasticsearch.xpack.ml.aggs.categorization.CategorizationTokenTree.WILD_CARD; + +/** + * A log group that provides methods for: + * - calculating similarity between it and a new log + * - expanding the existing log group by adding a new log + */ +class LogGroup implements Accountable { + + private final long id; + private final BytesRef[] logEvent; + private final long[] tokenCounts; + private long count; + + // Used at the shard level for tracking the bucket ordinal for collecting sub aggregations + long bucketOrd; + + @Override + public String toString() { + return "LogGroup{" + + "id=" + + id + + ", logEvent=" + + Arrays.stream(logEvent).map(BytesRef::utf8ToString).collect(Collectors.joining(", ", "[", "]")) + + ", count=" + + count + + '}'; + } + + LogGroup(BytesRef[] logTokens, long count, long id) { + this.id = id; + this.logEvent = logTokens; + this.count = count; + this.tokenCounts = new long[logTokens.length]; + Arrays.fill(this.tokenCounts, count); + } + + public long getId() { + return id; + } + + BytesRef[] getLogEvent() { + return logEvent; + } + + public long getCount() { + return count; + } + + Tuple calculateSimilarity(BytesRef[] logEvent) { + assert logEvent.length == this.logEvent.length; + int eqParams = 0; + long tokenCount = 0; + long tokensKept = 0; + for (int i = 0; i < logEvent.length; i++) { + if (logEvent[i].equals(this.logEvent[i])) { + tokensKept += tokenCounts[i]; + tokenCount += tokenCounts[i]; + } else if (this.logEvent[i].equals(WILD_CARD)) { + eqParams++; + } else { + tokenCount += tokenCounts[i]; + } + } + return new Tuple<>((double) tokensKept / tokenCount, eqParams); + } + + void addLog(BytesRef[] logEvent, long docCount) { + assert logEvent.length == this.logEvent.length; + for (int i = 0; i < logEvent.length; i++) { + if (logEvent[i].equals(this.logEvent[i]) == false) { + this.logEvent[i] = WILD_CARD; + } else { + tokenCounts[i] += docCount; + } + } + this.count += docCount; + } + + @Override + public long ramBytesUsed() { + return Long.BYTES // id + + (long) logEvent.length * RamUsageEstimator.NUM_BYTES_ARRAY_HEADER // logEvent + + RamUsageEstimator.NUM_BYTES_OBJECT_REF + ((long) logEvent.length * Long.BYTES) + RamUsageEstimator.NUM_BYTES_ARRAY_HEADER + + RamUsageEstimator.NUM_BYTES_OBJECT_REF + Long.BYTES; // count + } + +} diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/TreeNode.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/TreeNode.java new file mode 100644 index 0000000000000..6628b2273719c --- /dev/null +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/TreeNode.java @@ -0,0 +1,398 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.ml.aggs.categorization; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.apache.logging.log4j.message.ParameterizedMessage; +import org.apache.lucene.util.Accountable; +import org.apache.lucene.util.BytesRef; +import org.elasticsearch.core.Tuple; +import org.elasticsearch.search.aggregations.AggregationExecutionException; + +import java.util.ArrayList; +import java.util.Comparator; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.Optional; +import java.util.PriorityQueue; +import java.util.stream.Collectors; + +import static org.apache.lucene.util.RamUsageEstimator.NUM_BYTES_OBJECT_REF; +import static org.apache.lucene.util.RamUsageEstimator.sizeOfCollection; +import static org.apache.lucene.util.RamUsageEstimator.sizeOfMap; +import static org.elasticsearch.xpack.ml.aggs.categorization.CategorizationTokenTree.WILD_CARD; + +/** + * Tree node classes for the categorization token tree. + * + * Two major node types exist: + * - Inner: which are nodes that have children token nodes + * - Leaf: Which collection multiple {@link LogGroup} based on similarity restrictions + */ +abstract class TreeNode implements Accountable { + + private static final Logger LOGGER = LogManager.getLogger(TreeNode.class); + + private long count; + + TreeNode(long count) { + this.count = count; + } + + abstract void mergeWith(TreeNode otherNode); + + abstract boolean isLeaf(); + + final void incCount(long count) { + this.count += count; + } + + final long getCount() { + return count; + } + + // TODO add option for calculating the cost of adding the new group + abstract LogGroup addLog(BytesRef[] logTokens, long docCount, TreeNodeFactory treeNodeFactory); + + abstract LogGroup getLogGroup(BytesRef[] logTokens); + + abstract List getAllChildrenLogGroups(); + + abstract void collapseTinyChildren(); + + @Override + public long ramBytesUsed() { + return Long.BYTES; // count + } + + static class LeafTreeNode extends TreeNode { + private final List logGroups; + private final double similarityThreshold; + + LeafTreeNode(long count, double similarityThreshold) { + super(count); + this.logGroups = new ArrayList<>(); + this.similarityThreshold = similarityThreshold; + } + + public boolean isLeaf() { + return true; + } + + @Override + void mergeWith(TreeNode treeNode) { + if (treeNode == null) { + return; + } + if (treeNode.isLeaf() == false) { + throw new UnsupportedOperationException( + "cannot merge leaf node with non-leaf node in categorization tree \n[" + this + "]\n[" + treeNode + "]" + ); + } + incCount(treeNode.getCount()); + LeafTreeNode otherLeaf = (LeafTreeNode) treeNode; + for (LogGroup group : otherLeaf.logGroups) { + if (getAndUpdateLogGroup(group.getLogEvent(), group.getCount()).isPresent() == false) { + putNewLogGroup(group); + } + } + } + + @Override + public long ramBytesUsed() { + return super.ramBytesUsed() + NUM_BYTES_OBJECT_REF // list reference + + Double.BYTES // similarityThreshold + + sizeOfCollection(logGroups); + } + + @Override + public LogGroup addLog(BytesRef[] logTokens, long docCount, TreeNodeFactory treeNodeFactory) { + return getAndUpdateLogGroup(logTokens, docCount).orElseGet(() -> { + // Need to update the tree if possible + LogGroup group = treeNodeFactory.newGroup(docCount, logTokens); + LOGGER.trace(() -> new ParameterizedMessage("created group! [{}]", group)); + return putNewLogGroup(group); + }); + } + + @Override + List getAllChildrenLogGroups() { + return logGroups; + } + + @Override + void collapseTinyChildren() {} + + private Optional getAndUpdateLogGroup(BytesRef[] logTokens, long docCount) { + return getBestLogGroup(logTokens).map(bestGroupAndSimilarity -> { + if (bestGroupAndSimilarity.v2() >= similarityThreshold) { + bestGroupAndSimilarity.v1().addLog(logTokens, docCount); + return bestGroupAndSimilarity.v1(); + } + return null; + }); + } + + LogGroup putNewLogGroup(LogGroup group) { + logGroups.add(group); + return group; + } + + private Optional> getBestLogGroup(BytesRef[] logTokens) { + if (logGroups.isEmpty()) { + return Optional.empty(); + } + if (logGroups.size() == 1) { + return Optional.of(new Tuple<>(logGroups.get(0), logGroups.get(0).calculateSimilarity(logTokens).v1())); + } + double maxSimilarity = 0.0; + int maxParamMatch = 0; + LogGroup bestGroup = null; + for (LogGroup logGroup : this.logGroups) { + Tuple groupSimilarity = logGroup.calculateSimilarity(logTokens); + if (groupSimilarity.v1() > maxSimilarity) { + maxSimilarity = groupSimilarity.v1(); + maxParamMatch = groupSimilarity.v2(); + bestGroup = logGroup; + } else if (groupSimilarity.v1() == maxSimilarity && groupSimilarity.v2() > maxParamMatch) { + maxParamMatch = groupSimilarity.v2(); + bestGroup = logGroup; + } + } + return Optional.of(new Tuple<>(bestGroup, maxSimilarity)); + } + + @Override + public LogGroup getLogGroup(final BytesRef[] logTokens) { + return getBestLogGroup(logTokens).map(Tuple::v1).orElse(null); + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + LeafTreeNode that = (LeafTreeNode) o; + return Double.compare(that.similarityThreshold, similarityThreshold) == 0 && Objects.equals(logGroups, that.logGroups); + } + + @Override + public int hashCode() { + return Objects.hash(logGroups, similarityThreshold); + } + } + + static class InnerTreeNode extends TreeNode { + + private final Map children; + private final int childrenTokenPos; + private final int maxChildren; + private final PriorityQueue> smallestChild; + + InnerTreeNode(long count, int childrenTokenPos, int maxChildren) { + super(count); + children = new HashMap<>(); + this.childrenTokenPos = childrenTokenPos; + this.maxChildren = maxChildren; + this.smallestChild = new PriorityQueue<>(maxChildren, Comparator.comparing(Tuple::v2)); + } + + boolean isLeaf() { + return false; + } + + @Override + public LogGroup getLogGroup(final BytesRef[] logTokens) { + return getChild(logTokens[childrenTokenPos]).or(() -> getChild(WILD_CARD)) + .map(node -> node.getLogGroup(logTokens)) + .orElse(null); + } + + @Override + public long ramBytesUsed() { + return super.ramBytesUsed() + NUM_BYTES_OBJECT_REF // children reference + + Integer.BYTES // childrenTokenPos + + Integer.BYTES // maxChildren + + NUM_BYTES_OBJECT_REF // smallestChildReference + + sizeOfMap(children, 0) + // Number of items in the queue, reference to tuple, and then the tuple references + + (long) smallestChild.size() * (NUM_BYTES_OBJECT_REF + NUM_BYTES_OBJECT_REF + NUM_BYTES_OBJECT_REF + Long.BYTES); + } + + @Override + public LogGroup addLog(final BytesRef[] logTokens, final long docCount, final TreeNodeFactory treeNodeFactory) { + BytesRef currentToken = logTokens[childrenTokenPos]; + TreeNode child = getChild(currentToken).map(node -> { + node.incCount(docCount); + if (smallestChild.isEmpty() == false && smallestChild.peek().v1().equals(currentToken)) { + smallestChild.add(smallestChild.poll()); + } + return node; + }).orElseGet(() -> { + if (docCount > 1) { + LOGGER.trace( + () -> new ParameterizedMessage( + "got a token [{}] with doc_count [{}] percentage [{}]", + logTokens[childrenTokenPos].utf8ToString(), + docCount, + (double) docCount / docCount + ) + ); + } + TreeNode newNode = treeNodeFactory.newNode(docCount, childrenTokenPos + 1, logTokens); + return addChild(currentToken, newNode); + }); + return child.addLog(logTokens, docCount, treeNodeFactory); + } + + @Override + void collapseTinyChildren() { + if (this.isLeaf()) { + return; + } + if (children.size() <= 1) { + return; + } + Optional maybeWildChild = getChild(WILD_CARD).or(() -> { + if ((double) smallestChild.peek().v2() / this.getCount() <= 1.0 / maxChildren) { + TreeNode tinyChild = children.remove(smallestChild.poll().v1()); + return Optional.of(addChild(WILD_CARD, tinyChild)); + } + return Optional.empty(); + }); + if (maybeWildChild.isPresent()) { + TreeNode wildChild = maybeWildChild.get(); + Tuple tinyNode; + while ((tinyNode = smallestChild.poll()) != null) { + // If we have no more tiny nodes, stop iterating over them + if ((double) tinyNode.v2() / this.getCount() > 1.0 / maxChildren) { + smallestChild.add(tinyNode); + break; + } else { + wildChild.mergeWith(children.remove(tinyNode.v1())); + } + } + } + children.values().forEach(TreeNode::collapseTinyChildren); + } + + @Override + void mergeWith(TreeNode treeNode) { + if (treeNode == null) { + return; + } + incCount(treeNode.count); + if (treeNode.isLeaf()) { + throw new UnsupportedOperationException( + "cannot merge non-leaf node with leaf node in categorization tree \n[" + this + "]\n[" + treeNode + "]" + ); + } + InnerTreeNode innerTreeNode = (InnerTreeNode) treeNode; + TreeNode siblingWildChild = innerTreeNode.children.remove(WILD_CARD); + addChild(WILD_CARD, siblingWildChild); + Tuple siblingChild; + while ((siblingChild = innerTreeNode.smallestChild.poll()) != null) { + TreeNode nephewNode = innerTreeNode.children.remove(siblingChild.v1()); + addChild(siblingChild.v1(), nephewNode); + } + } + + private TreeNode addChild(BytesRef token, TreeNode node) { + if (node == null || token == null) { + return null; + } + Optional existingChild = getChild(token).map(existingNode -> { + existingNode.mergeWith(node); + if (smallestChild.isEmpty() == false && smallestChild.peek().v1().equals(token)) { + smallestChild.poll(); + smallestChild.add(Tuple.tuple(token, existingNode.getCount())); + } + return existingNode; + }); + if (existingChild.isPresent()) { + return existingChild.get(); + } + if (children.size() == maxChildren) { + return getChild(WILD_CARD).map(wildChild -> { + final TreeNode toMerge; + final TreeNode toReturn; + if (smallestChild.isEmpty() == false && node.getCount() > smallestChild.peek().v2()) { + toMerge = children.remove(smallestChild.poll().v1()); + addChildAndUpdateSmallest(token, node); + toReturn = node; + } else { + toMerge = node; + toReturn = wildChild; + } + wildChild.mergeWith(toMerge); + return toReturn; + }).orElseThrow(() -> new AggregationExecutionException("Missing wild_card child even though maximum children reached")); + } + // we are about to hit the limit, add a wild card if we need to and then add the new child as appropriate + if (children.size() == maxChildren - 1) { + // If we already have a wild token, simply adding the new token is acceptable as we won't breach our limit + if (children.containsKey(WILD_CARD)) { + addChildAndUpdateSmallest(token, node); + } else { // if we don't have a wild card child, we need to add one now + if (token.equals(WILD_CARD)) { + addChildAndUpdateSmallest(token, node); + } else { + if (smallestChild.isEmpty() == false && node.count > smallestChild.peek().v2()) { + addChildAndUpdateSmallest(WILD_CARD, children.remove(smallestChild.poll().v1())); + addChildAndUpdateSmallest(token, node); + } else { + addChildAndUpdateSmallest(WILD_CARD, node); + } + } + } + } else { + addChildAndUpdateSmallest(token, node); + } + return node; + } + + private void addChildAndUpdateSmallest(BytesRef token, TreeNode node) { + children.put(token, node); + if (token.equals(WILD_CARD) == false) { + smallestChild.add(Tuple.tuple(token, node.count)); + } + } + + private Optional getChild(BytesRef token) { + TreeNode node = children.get(token); + return node == null ? Optional.empty() : Optional.of(node); + } + + public List getAllChildrenLogGroups() { + return children.values().stream().flatMap(c -> c.getAllChildrenLogGroups().stream()).collect(Collectors.toList()); + } + + boolean hasChild(BytesRef value) { + return children.containsKey(value); + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + InnerTreeNode treeNode = (InnerTreeNode) o; + return childrenTokenPos == treeNode.childrenTokenPos + && getCount() == treeNode.getCount() + && Objects.equals(children, treeNode.children) + && Objects.equals(smallestChild, treeNode.smallestChild); + } + + @Override + public int hashCode() { + return Objects.hash(children, childrenTokenPos, smallestChild, getCount()); + } + } + +} diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/TreeNodeFactory.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/TreeNodeFactory.java new file mode 100644 index 0000000000000..b710d2ad9c946 --- /dev/null +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/TreeNodeFactory.java @@ -0,0 +1,16 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.ml.aggs.categorization; + +import org.apache.lucene.util.BytesRef; + +interface TreeNodeFactory { + TreeNode newNode(long docCount, int tokenPos, BytesRef[] logTokens); + + LogGroup newGroup(long docCount, BytesRef[] logTokens); +} diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/job/categorization/CategorizationAnalyzer.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/job/categorization/CategorizationAnalyzer.java index 4fc99e1502851..229b505c21783 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/job/categorization/CategorizationAnalyzer.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/job/categorization/CategorizationAnalyzer.java @@ -38,6 +38,11 @@ public CategorizationAnalyzer(AnalysisRegistry analysisRegistry, closeAnalyzer = tuple.v2(); } + public final TokenStream tokenStream(final String fieldName, + final String text) { + return analyzer.tokenStream(fieldName, text); + } + /** * Release resources held by the analyzer (unless it's global). */ diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/LocalStateMachineLearning.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/LocalStateMachineLearning.java index 750f36863a141..d6f147f831df7 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/LocalStateMachineLearning.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/LocalStateMachineLearning.java @@ -17,6 +17,9 @@ import org.elasticsearch.common.breaker.NoopCircuitBreaker; import org.elasticsearch.common.inject.Inject; import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.index.analysis.CharFilterFactory; +import org.elasticsearch.index.analysis.TokenizerFactory; +import org.elasticsearch.indices.analysis.AnalysisModule; import org.elasticsearch.license.LicenseService; import org.elasticsearch.license.XPackLicenseState; import org.elasticsearch.plugins.ActionPlugin; @@ -33,6 +36,7 @@ import java.util.Collections; import java.util.HashSet; import java.util.List; +import java.util.Map; import static org.elasticsearch.xpack.ml.MachineLearning.TRAINED_MODEL_CIRCUIT_BREAKER_NAME; @@ -84,6 +88,20 @@ public void cleanUpFeature( mlPlugin.cleanUpFeature(clusterService, client, finalListener); } + @Override + public List getAggregations() { + return mlPlugin.getAggregations(); + } + + @Override + public Map> getCharFilters() { + return mlPlugin.getCharFilters(); + } + + @Override + public Map> getTokenizers() { + return mlPlugin.getTokenizers(); + } /** * This is only required as we now have to have the GetRollupIndexCapsAction as a valid action in our node. diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/aggs/categorization/CategorizeTextAggregationBuilderTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/aggs/categorization/CategorizeTextAggregationBuilderTests.java new file mode 100644 index 0000000000000..3484e7c90540e --- /dev/null +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/aggs/categorization/CategorizeTextAggregationBuilderTests.java @@ -0,0 +1,55 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.ml.aggs.categorization; + +import org.elasticsearch.plugins.Plugin; +import org.elasticsearch.search.aggregations.BaseAggregationTestCase; +import org.elasticsearch.xpack.ml.MachineLearning; + +import java.util.Collection; +import java.util.Collections; +import java.util.stream.Collectors; +import java.util.stream.Stream; + +public class CategorizeTextAggregationBuilderTests extends BaseAggregationTestCase { + + @Override + protected Collection> getExtraPlugins() { + return Collections.singletonList(MachineLearning.class); + } + + @Override + protected CategorizeTextAggregationBuilder createTestAggregatorBuilder() { + CategorizeTextAggregationBuilder builder = new CategorizeTextAggregationBuilder(randomAlphaOfLength(10), randomAlphaOfLength(10)); + if (randomBoolean()) { + builder.setCategorizationFilters(Stream.generate(() -> randomAlphaOfLength(10)).limit(5).collect(Collectors.toList())); + } + if (randomBoolean()) { + builder.setMaxChildren(randomIntBetween(1, 500)); + } + if (randomBoolean()) { + builder.setMaxDepth(randomIntBetween(1, 10)); + } + if (randomBoolean()) { + builder.setSimilarityThreshold(randomDoubleBetween(0.1, 1.0, true)); + } + if (randomBoolean()) { + builder.minDocCount(randomLongBetween(1, 100)); + } + if (randomBoolean()) { + builder.shardMinDocCount(randomLongBetween(1, 100)); + } + if (randomBoolean()) { + builder.size(randomIntBetween(1, 100)); + } + if (randomBoolean()) { + builder.shardSize(randomIntBetween(1, 100)); + } + return builder; + } +} diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/aggs/categorization/CategorizeTextAggregatorTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/aggs/categorization/CategorizeTextAggregatorTests.java new file mode 100644 index 0000000000000..1caaf4d685081 --- /dev/null +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/aggs/categorization/CategorizeTextAggregatorTests.java @@ -0,0 +1,297 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.ml.aggs.categorization; + +import org.apache.lucene.document.SortedNumericDocValuesField; +import org.apache.lucene.document.StoredField; +import org.apache.lucene.index.IndexReader; +import org.apache.lucene.index.RandomIndexWriter; +import org.apache.lucene.search.IndexSearcher; +import org.apache.lucene.search.MatchAllDocsQuery; +import org.apache.lucene.store.Directory; +import org.apache.lucene.util.BytesRef; +import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.index.mapper.TextFieldMapper; +import org.elasticsearch.plugins.AnalysisPlugin; +import org.elasticsearch.plugins.SearchPlugin; +import org.elasticsearch.search.aggregations.AggregatorTestCase; +import org.elasticsearch.search.aggregations.bucket.histogram.Histogram; +import org.elasticsearch.search.aggregations.bucket.histogram.HistogramAggregationBuilder; +import org.elasticsearch.search.aggregations.metrics.Avg; +import org.elasticsearch.search.aggregations.metrics.AvgAggregationBuilder; +import org.elasticsearch.search.aggregations.metrics.Max; +import org.elasticsearch.search.aggregations.metrics.MaxAggregationBuilder; +import org.elasticsearch.search.aggregations.metrics.Min; +import org.elasticsearch.search.aggregations.metrics.MinAggregationBuilder; +import org.elasticsearch.xpack.ml.MachineLearning; + +import java.io.IOException; +import java.util.Arrays; +import java.util.List; + +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.hasSize; + +public class CategorizeTextAggregatorTests extends AggregatorTestCase { + + @Override + protected List getSearchPlugins() { + return List.of(new MachineLearning(Settings.EMPTY, null)); + } + + @Override + protected List getAnalysisPlugins() { + return List.of(new MachineLearning(Settings.EMPTY, null)); + } + + private static final String TEXT_FIELD_NAME = "text"; + private static final String NUMERIC_FIELD_NAME = "value"; + + public void testCategorizationWithoutSubAggs() throws Exception { + try (Directory dir = newDirectory(); RandomIndexWriter w = new RandomIndexWriter(random(), dir)) { + writeTestDocs(w); + CategorizeTextAggregationBuilder aggBuilder = new CategorizeTextAggregationBuilder("my_agg", TEXT_FIELD_NAME); + try (IndexReader reader = w.getReader()) { + IndexSearcher searcher = new IndexSearcher(reader); + InternalCategorizationAggregation result = searchAndReduce( + searcher, + new MatchAllDocsQuery(), + aggBuilder, + new TextFieldMapper.TextFieldType(TEXT_FIELD_NAME) + ); + assertThat(result.getBuckets(), hasSize(2)); + assertThat(result.getBuckets().get(0).docCount, equalTo(6L)); + assertThat(result.getBuckets().get(0).getKeyAsString(), equalTo("Node started")); + assertThat(result.getBuckets().get(1).docCount, equalTo(2L)); + assertThat( + result.getBuckets().get(1).getKeyAsString(), + equalTo("Failed to shutdown error org.aaaa.bbbb.Cccc line caused by foo exception") + ); + } + } + } + + public void testCategorizationWithSubAggs() throws Exception { + try (Directory dir = newDirectory(); RandomIndexWriter w = new RandomIndexWriter(random(), dir)) { + writeTestDocs(w); + CategorizeTextAggregationBuilder aggBuilder = new CategorizeTextAggregationBuilder("my_agg", TEXT_FIELD_NAME).subAggregation( + new MaxAggregationBuilder("max").field(NUMERIC_FIELD_NAME) + ) + .subAggregation(new AvgAggregationBuilder("avg").field(NUMERIC_FIELD_NAME)) + .subAggregation(new MinAggregationBuilder("min").field(NUMERIC_FIELD_NAME)); + try (IndexReader reader = w.getReader()) { + IndexSearcher searcher = new IndexSearcher(reader); + InternalCategorizationAggregation result = searchAndReduce( + searcher, + new MatchAllDocsQuery(), + aggBuilder, + new TextFieldMapper.TextFieldType(TEXT_FIELD_NAME), + longField(NUMERIC_FIELD_NAME) + ); + assertThat(result.getBuckets(), hasSize(2)); + assertThat(result.getBuckets().get(0).docCount, equalTo(6L)); + assertThat(result.getBuckets().get(0).getKeyAsString(), equalTo("Node started")); + assertThat(((Max) result.getBuckets().get(0).aggregations.get("max")).getValue(), equalTo(5.0)); + assertThat(((Min) result.getBuckets().get(0).aggregations.get("min")).getValue(), equalTo(0.0)); + assertThat(((Avg) result.getBuckets().get(0).aggregations.get("avg")).getValue(), equalTo(2.5)); + + assertThat(result.getBuckets().get(1).docCount, equalTo(2L)); + assertThat( + result.getBuckets().get(1).getKeyAsString(), + equalTo("Failed to shutdown error org.aaaa.bbbb.Cccc line caused by foo exception") + ); + assertThat(((Max) result.getBuckets().get(1).aggregations.get("max")).getValue(), equalTo(4.0)); + assertThat(((Min) result.getBuckets().get(1).aggregations.get("min")).getValue(), equalTo(0.0)); + assertThat(((Avg) result.getBuckets().get(1).aggregations.get("avg")).getValue(), equalTo(2.0)); + } + } + } + + public void testCategorizationWithMultiBucketSubAggs() throws Exception { + try (Directory dir = newDirectory(); RandomIndexWriter w = new RandomIndexWriter(random(), dir)) { + writeTestDocs(w); + CategorizeTextAggregationBuilder aggBuilder = new CategorizeTextAggregationBuilder("my_agg", TEXT_FIELD_NAME).subAggregation( + new HistogramAggregationBuilder("histo").field(NUMERIC_FIELD_NAME) + .interval(2) + .subAggregation(new MaxAggregationBuilder("max").field(NUMERIC_FIELD_NAME)) + .subAggregation(new AvgAggregationBuilder("avg").field(NUMERIC_FIELD_NAME)) + .subAggregation(new MinAggregationBuilder("min").field(NUMERIC_FIELD_NAME)) + ); + try (IndexReader reader = w.getReader()) { + IndexSearcher searcher = new IndexSearcher(reader); + InternalCategorizationAggregation result = searchAndReduce( + searcher, + new MatchAllDocsQuery(), + aggBuilder, + new TextFieldMapper.TextFieldType(TEXT_FIELD_NAME), + longField(NUMERIC_FIELD_NAME) + ); + assertThat(result.getBuckets(), hasSize(2)); + assertThat(result.getBuckets().get(0).docCount, equalTo(6L)); + assertThat(result.getBuckets().get(0).getKeyAsString(), equalTo("Node started")); + Histogram histo = result.getBuckets().get(0).aggregations.get("histo"); + assertThat(histo.getBuckets(), hasSize(3)); + for (Histogram.Bucket bucket : histo.getBuckets()) { + assertThat(bucket.getDocCount(), equalTo(2L)); + } + assertThat(((Max) histo.getBuckets().get(0).getAggregations().get("max")).getValue(), equalTo(1.0)); + assertThat(((Min) histo.getBuckets().get(0).getAggregations().get("min")).getValue(), equalTo(0.0)); + assertThat(((Avg) histo.getBuckets().get(0).getAggregations().get("avg")).getValue(), equalTo(0.5)); + assertThat(((Max) histo.getBuckets().get(1).getAggregations().get("max")).getValue(), equalTo(3.0)); + assertThat(((Min) histo.getBuckets().get(1).getAggregations().get("min")).getValue(), equalTo(2.0)); + assertThat(((Avg) histo.getBuckets().get(1).getAggregations().get("avg")).getValue(), equalTo(2.5)); + assertThat(((Max) histo.getBuckets().get(2).getAggregations().get("max")).getValue(), equalTo(5.0)); + assertThat(((Min) histo.getBuckets().get(2).getAggregations().get("min")).getValue(), equalTo(4.0)); + assertThat(((Avg) histo.getBuckets().get(2).getAggregations().get("avg")).getValue(), equalTo(4.5)); + + assertThat(result.getBuckets().get(1).docCount, equalTo(2L)); + assertThat( + result.getBuckets().get(1).getKeyAsString(), + equalTo("Failed to shutdown error org.aaaa.bbbb.Cccc line caused by foo exception") + ); + histo = result.getBuckets().get(1).aggregations.get("histo"); + assertThat(histo.getBuckets(), hasSize(3)); + assertThat(histo.getBuckets().get(0).getDocCount(), equalTo(1L)); + assertThat(histo.getBuckets().get(1).getDocCount(), equalTo(0L)); + assertThat(histo.getBuckets().get(2).getDocCount(), equalTo(1L)); + assertThat(((Avg) histo.getBuckets().get(0).getAggregations().get("avg")).getValue(), equalTo(0.0)); + assertThat(((Avg) histo.getBuckets().get(2).getAggregations().get("avg")).getValue(), equalTo(4.0)); + } + } + } + + public void testCategorizationAsSubAgg() throws Exception { + try (Directory dir = newDirectory(); RandomIndexWriter w = new RandomIndexWriter(random(), dir)) { + writeTestDocs(w); + HistogramAggregationBuilder aggBuilder = new HistogramAggregationBuilder("histo").field(NUMERIC_FIELD_NAME) + .interval(2) + .subAggregation( + new CategorizeTextAggregationBuilder("my_agg", TEXT_FIELD_NAME) + .subAggregation(new MaxAggregationBuilder("max").field(NUMERIC_FIELD_NAME)) + .subAggregation(new AvgAggregationBuilder("avg").field(NUMERIC_FIELD_NAME)) + .subAggregation(new MinAggregationBuilder("min").field(NUMERIC_FIELD_NAME)) + ); + try (IndexReader reader = w.getReader()) { + IndexSearcher searcher = new IndexSearcher(reader); + Histogram result = searchAndReduce( + searcher, + new MatchAllDocsQuery(), + aggBuilder, + new TextFieldMapper.TextFieldType(TEXT_FIELD_NAME), + longField(NUMERIC_FIELD_NAME) + ); + assertThat(result.getBuckets(), hasSize(3)); + + // First histo bucket + assertThat(result.getBuckets().get(0).getDocCount(), equalTo(3L)); + InternalCategorizationAggregation categorizationAggregation = result.getBuckets().get(0).getAggregations().get("my_agg"); + assertThat(categorizationAggregation.getBuckets(), hasSize(2)); + assertThat(categorizationAggregation.getBuckets().get(0).docCount, equalTo(2L)); + assertThat(categorizationAggregation.getBuckets().get(0).getKeyAsString(), equalTo("Node started")); + assertThat(((Max) categorizationAggregation.getBuckets().get(0).aggregations.get("max")).getValue(), equalTo(1.0)); + assertThat(((Min) categorizationAggregation.getBuckets().get(0).aggregations.get("min")).getValue(), equalTo(0.0)); + assertThat(((Avg) categorizationAggregation.getBuckets().get(0).aggregations.get("avg")).getValue(), equalTo(0.5)); + + assertThat(categorizationAggregation.getBuckets().get(1).docCount, equalTo(1L)); + assertThat( + categorizationAggregation.getBuckets().get(1).getKeyAsString(), + equalTo("Failed to shutdown error org.aaaa.bbbb.Cccc line caused by foo exception") + ); + assertThat(((Max) categorizationAggregation.getBuckets().get(1).aggregations.get("max")).getValue(), equalTo(0.0)); + assertThat(((Min) categorizationAggregation.getBuckets().get(1).aggregations.get("min")).getValue(), equalTo(0.0)); + assertThat(((Avg) categorizationAggregation.getBuckets().get(1).aggregations.get("avg")).getValue(), equalTo(0.0)); + + // Second histo bucket + assertThat(result.getBuckets().get(1).getDocCount(), equalTo(2L)); + categorizationAggregation = result.getBuckets().get(1).getAggregations().get("my_agg"); + assertThat(categorizationAggregation.getBuckets(), hasSize(1)); + assertThat(categorizationAggregation.getBuckets().get(0).docCount, equalTo(2L)); + assertThat(categorizationAggregation.getBuckets().get(0).getKeyAsString(), equalTo("Node started")); + assertThat(((Max) categorizationAggregation.getBuckets().get(0).aggregations.get("max")).getValue(), equalTo(3.0)); + assertThat(((Min) categorizationAggregation.getBuckets().get(0).aggregations.get("min")).getValue(), equalTo(2.0)); + assertThat(((Avg) categorizationAggregation.getBuckets().get(0).aggregations.get("avg")).getValue(), equalTo(2.5)); + + // Third histo bucket + assertThat(result.getBuckets().get(2).getDocCount(), equalTo(3L)); + categorizationAggregation = result.getBuckets().get(2).getAggregations().get("my_agg"); + assertThat(categorizationAggregation.getBuckets(), hasSize(2)); + assertThat(categorizationAggregation.getBuckets().get(0).docCount, equalTo(2L)); + assertThat(categorizationAggregation.getBuckets().get(0).getKeyAsString(), equalTo("Node started")); + assertThat(((Max) categorizationAggregation.getBuckets().get(0).aggregations.get("max")).getValue(), equalTo(5.0)); + assertThat(((Min) categorizationAggregation.getBuckets().get(0).aggregations.get("min")).getValue(), equalTo(4.0)); + assertThat(((Avg) categorizationAggregation.getBuckets().get(0).aggregations.get("avg")).getValue(), equalTo(4.5)); + + assertThat(categorizationAggregation.getBuckets().get(1).docCount, equalTo(1L)); + assertThat( + categorizationAggregation.getBuckets().get(1).getKeyAsString(), + equalTo("Failed to shutdown error org.aaaa.bbbb.Cccc line caused by foo exception") + ); + assertThat(((Max) categorizationAggregation.getBuckets().get(1).aggregations.get("max")).getValue(), equalTo(4.0)); + assertThat(((Min) categorizationAggregation.getBuckets().get(1).aggregations.get("min")).getValue(), equalTo(4.0)); + assertThat(((Avg) categorizationAggregation.getBuckets().get(1).aggregations.get("avg")).getValue(), equalTo(4.0)); + } + } + } + + private static void writeTestDocs(RandomIndexWriter w) throws IOException { + w.addDocument( + Arrays.asList( + new StoredField("_source", new BytesRef("{\"text\":\"Node 1 started\"}")), + new SortedNumericDocValuesField(NUMERIC_FIELD_NAME, 0) + ) + ); + w.addDocument( + Arrays.asList( + new StoredField("_source", new BytesRef("{\"text\":\"Node 1 started\"}")), + new SortedNumericDocValuesField(NUMERIC_FIELD_NAME, 1) + ) + ); + w.addDocument( + Arrays.asList( + new StoredField( + "_source", + new BytesRef("{\"text\":\"Failed to shutdown [error org.aaaa.bbbb.Cccc line 54 caused by foo exception]\"}") + ), + new SortedNumericDocValuesField(NUMERIC_FIELD_NAME, 0) + ) + ); + w.addDocument( + Arrays.asList( + new StoredField( + "_source", + new BytesRef("{\"text\":\"Failed to shutdown [error org.aaaa.bbbb.Cccc line 54 caused by foo exception]\"}") + ), + new SortedNumericDocValuesField(NUMERIC_FIELD_NAME, 4) + ) + ); + w.addDocument( + Arrays.asList( + new StoredField("_source", new BytesRef("{\"text\":\"Node 2 started\"}")), + new SortedNumericDocValuesField(NUMERIC_FIELD_NAME, 2) + ) + ); + w.addDocument( + Arrays.asList( + new StoredField("_source", new BytesRef("{\"text\":\"Node 2 started\"}")), + new SortedNumericDocValuesField(NUMERIC_FIELD_NAME, 3) + ) + ); + w.addDocument( + Arrays.asList( + new StoredField("_source", new BytesRef("{\"text\":\"Node 3 started\"}")), + new SortedNumericDocValuesField(NUMERIC_FIELD_NAME, 4) + ) + ); + w.addDocument( + Arrays.asList( + new StoredField("_source", new BytesRef("{\"text\":\"Node 3 started\"}")), + new SortedNumericDocValuesField(NUMERIC_FIELD_NAME, 5) + ) + ); + } +} diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/aggs/categorization/InnerTreeNodeTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/aggs/categorization/InnerTreeNodeTests.java new file mode 100644 index 0000000000000..731d83c9b31f3 --- /dev/null +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/aggs/categorization/InnerTreeNodeTests.java @@ -0,0 +1,109 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.ml.aggs.categorization; + +import org.apache.lucene.util.BytesRef; +import org.elasticsearch.test.ESTestCase; + +import static org.elasticsearch.xpack.ml.aggs.categorization.LogGroupTests.getTokens; +import static org.hamcrest.Matchers.arrayContaining; +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.is; + +public class InnerTreeNodeTests extends ESTestCase { + + private final TreeNodeFactory factory = new CategorizationTokenTree(3, 4, 0.6); + + public void testAddLog() { + TreeNode.InnerTreeNode innerTreeNode = new TreeNode.InnerTreeNode(1, 0, 3); + LogGroup group = innerTreeNode.addLog(getTokens("foo", "bar", "baz", "biz"), 1, factory); + assertThat(group.getLogEvent(), arrayContaining(getTokens("foo", "bar", "baz", "biz"))); + + assertThat( + innerTreeNode.addLog(getTokens("foo2", "bar", "baz", "biz"), 1, factory).getLogEvent(), + arrayContaining(getTokens("foo2", "bar", "baz", "biz")) + ); + assertThat( + innerTreeNode.addLog(getTokens("foo3", "bar", "baz", "biz"), 1, factory).getLogEvent(), + arrayContaining(getTokens("foo3", "bar", "baz", "biz")) + ); + assertThat( + innerTreeNode.addLog(getTokens("foo4", "bar", "baz", "biz"), 1, factory).getLogEvent(), + arrayContaining(getTokens("*", "bar", "baz", "biz")) + ); + assertThat( + innerTreeNode.addLog(getTokens("foo", "bar", "baz", "bizzy"), 1, factory).getLogEvent(), + arrayContaining(getTokens("foo", "bar", "baz", "*")) + ); + } + + public void testAddLogWithLargerIncoming() { + TreeNode.InnerTreeNode innerTreeNode = new TreeNode.InnerTreeNode(1, 0, 3); + LogGroup group = innerTreeNode.addLog(getTokens("foo", "bar", "baz", "biz"), 100, factory); + assertThat(group.getLogEvent(), arrayContaining(getTokens("foo", "bar", "baz", "biz"))); + + assertThat( + innerTreeNode.addLog(getTokens("foo2", "bar", "baz", "biz"), 100, factory).getLogEvent(), + arrayContaining(getTokens("foo2", "bar", "baz", "biz")) + ); + assertThat( + innerTreeNode.addLog(getTokens("foosmall", "bar", "baz", "biz"), 1, factory).getLogEvent(), + arrayContaining(getTokens("foosmall", "bar", "baz", "biz")) + ); + assertThat( + innerTreeNode.addLog(getTokens("foobigun", "bar", "baz", "biz"), 1000, factory).getLogEvent(), + arrayContaining(getTokens("foobigun", "bar", "baz", "biz")) + ); + assertThat( + innerTreeNode.getLogGroup(getTokens("foosmall", "bar", "baz", "biz")).getLogEvent(), + equalTo(getTokens("*", "bar", "baz", "biz")) + ); + } + + public void testCollapseTinyChildren() { + TreeNode.InnerTreeNode innerTreeNode = new TreeNode.InnerTreeNode(1000, 0, 4); + LogGroup group = innerTreeNode.addLog(getTokens("foo", "bar", "baz", "biz"), 1000, factory); + assertThat(group.getLogEvent(), arrayContaining(getTokens("foo", "bar", "baz", "biz"))); + + assertThat( + innerTreeNode.addLog(getTokens("foo2", "bar", "baz", "biz"), 1000, factory).getLogEvent(), + arrayContaining(getTokens("foo2", "bar", "baz", "biz")) + ); + innerTreeNode.incCount(1000); + assertThat( + innerTreeNode.addLog(getTokens("foosmall", "bar", "baz", "biz"), 1, factory).getLogEvent(), + arrayContaining(getTokens("foosmall", "bar", "baz", "biz")) + ); + innerTreeNode.incCount(1); + innerTreeNode.collapseTinyChildren(); + assertThat(innerTreeNode.hasChild(new BytesRef("foosmall")), is(false)); + assertThat(innerTreeNode.hasChild(new BytesRef("*")), is(true)); + } + + public void testMergeWith() { + TreeNode.InnerTreeNode innerTreeNode = new TreeNode.InnerTreeNode(1000, 0, 3); + innerTreeNode.addLog(getTokens("foo", "bar", "baz", "biz"), 1000, factory); + innerTreeNode.incCount(1000); + innerTreeNode.addLog(getTokens("foo2", "bar", "baz", "biz"), 1000, factory); + + expectThrows(UnsupportedOperationException.class, () -> innerTreeNode.mergeWith(new TreeNode.LeafTreeNode(1, 0.6))); + + + TreeNode.InnerTreeNode mergeWith = new TreeNode.InnerTreeNode(1, 0, 3); + innerTreeNode.addLog(getTokens("foosmall", "bar", "baz", "biz"), 1, factory); + innerTreeNode.incCount(1); + innerTreeNode.addLog(getTokens("footiny", "bar", "baz", "biz"), 1, factory); + + innerTreeNode.mergeWith(mergeWith); + assertThat(innerTreeNode.hasChild(new BytesRef("*")), is(true)); + assertThat( + innerTreeNode.getLogGroup(getTokens("footiny", "bar", "baz", "biz")).getLogEvent(), + arrayContaining(getTokens("*", "bar", "baz", "biz")) + ); + } +} diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/aggs/categorization/InternalCategorizationAggregationTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/aggs/categorization/InternalCategorizationAggregationTests.java new file mode 100644 index 0000000000000..c3be4b0b8b6bb --- /dev/null +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/aggs/categorization/InternalCategorizationAggregationTests.java @@ -0,0 +1,117 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.ml.aggs.categorization; + +import org.apache.lucene.util.BytesRef; +import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.common.util.CollectionUtils; +import org.elasticsearch.common.xcontent.NamedXContentRegistry; +import org.elasticsearch.common.xcontent.ParseField; +import org.elasticsearch.plugins.SearchPlugin; +import org.elasticsearch.search.aggregations.Aggregation; +import org.elasticsearch.search.aggregations.InternalAggregations; +import org.elasticsearch.search.aggregations.ParsedMultiBucketAggregation; +import org.elasticsearch.test.InternalMultiBucketAggregationTestCase; +import org.elasticsearch.xpack.ml.MachineLearning; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.function.Predicate; +import java.util.stream.Collectors; +import java.util.stream.Stream; + +public class InternalCategorizationAggregationTests extends InternalMultiBucketAggregationTestCase { + + @Override + protected SearchPlugin registerPlugin() { + return new MachineLearning(Settings.EMPTY, null); + } + + @Override + protected List getNamedXContents() { + return CollectionUtils.appendToCopy( + super.getNamedXContents(), + new NamedXContentRegistry.Entry( + Aggregation.class, + new ParseField(CategorizeTextAggregationBuilder.NAME), + (p, c) -> ParsedCategorization.fromXContent(p, (String) c) + ) + ); + } + + @Override + protected void assertReduced(InternalCategorizationAggregation reduced, List inputs) { + Map reducedCounts = toCounts(reduced.getBuckets().stream()); + Map totalCounts = toCounts(inputs.stream().map(InternalCategorizationAggregation::getBuckets).flatMap(List::stream)); + + Map expectedReducedCounts = new HashMap<>(totalCounts); + expectedReducedCounts.keySet().retainAll(reducedCounts.keySet()); + assertEquals(expectedReducedCounts, reducedCounts); + } + + @Override + protected Predicate excludePathsFromXContentInsertion() { + return p -> p.contains("key"); + } + + static InternalCategorizationAggregation.BucketKey randomKey() { + int numVals = randomIntBetween(1, 50); + return new InternalCategorizationAggregation.BucketKey( + Stream.generate(() -> randomAlphaOfLength(10)).limit(numVals).map(BytesRef::new).toArray(BytesRef[]::new) + ); + } + + @Override + protected InternalCategorizationAggregation createTestInstance( + String name, + Map metadata, + InternalAggregations aggregations + ) { + List buckets = new ArrayList<>(); + final int numBuckets = randomNumberOfBuckets(); + HashSet keys = new HashSet<>(); + for (int i = 0; i < numBuckets; ++i) { + InternalCategorizationAggregation.BucketKey key = randomValueOtherThanMany( + l -> keys.add(l) == false, + InternalCategorizationAggregationTests::randomKey + ); + int docCount = randomIntBetween(1, 100); + buckets.add(new InternalCategorizationAggregation.Bucket(key, docCount, aggregations)); + } + Collections.sort(buckets); + return new InternalCategorizationAggregation( + name, + randomIntBetween(10, 100), + randomLongBetween(1, 10), + randomIntBetween(1, 500), + randomIntBetween(1, 10), + randomDoubleBetween(0.1, 1.0, true), + metadata, + buckets + ); + } + + @Override + protected Class> implementationClass() { + return ParsedCategorization.class; + } + + private static Map toCounts(Stream buckets) { + return buckets.collect( + Collectors.toMap( + InternalCategorizationAggregation.Bucket::getKey, + InternalCategorizationAggregation.Bucket::getDocCount, + Long::sum + ) + ); + } +} diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/aggs/categorization/LeafTreeNodeTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/aggs/categorization/LeafTreeNodeTests.java new file mode 100644 index 0000000000000..74b72b784632b --- /dev/null +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/aggs/categorization/LeafTreeNodeTests.java @@ -0,0 +1,68 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.ml.aggs.categorization; + +import org.elasticsearch.test.ESTestCase; + +import static org.elasticsearch.xpack.ml.aggs.categorization.LogGroupTests.getTokens; +import static org.hamcrest.Matchers.arrayContaining; +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.greaterThan; +import static org.hamcrest.Matchers.hasSize; + +public class LeafTreeNodeTests extends ESTestCase { + + private final TreeNodeFactory factory = new CategorizationTokenTree(10, 10, 0.6); + + public void testAddGroup() { + TreeNode.LeafTreeNode leafTreeNode = new TreeNode.LeafTreeNode(0, 0.6); + LogGroup group = leafTreeNode.addLog(getTokens("foo", "bar", "baz", "biz"), 1, factory); + + assertThat(group.getLogEvent(), arrayContaining(getTokens("foo", "bar", "baz", "biz"))); + assertThat(group.getCount(), equalTo(1L)); + assertThat(leafTreeNode.getAllChildrenLogGroups(), hasSize(1)); + long previousBytesUsed = leafTreeNode.ramBytesUsed(); + + group = leafTreeNode.addLog(getTokens("foo", "bar", "bozo", "bizzy"), 1, factory); + assertThat(group.getLogEvent(), arrayContaining(getTokens("foo", "bar", "bozo", "bizzy"))); + assertThat(group.getCount(), equalTo(1L)); + assertThat(leafTreeNode.getAllChildrenLogGroups(), hasSize(2)); + assertThat(leafTreeNode.ramBytesUsed(), greaterThan(previousBytesUsed)); + previousBytesUsed = leafTreeNode.ramBytesUsed(); + + + group = leafTreeNode.addLog(getTokens("foo", "bar", "baz", "different"), 3, factory); + assertThat(group.getLogEvent(), arrayContaining(getTokens("foo", "bar", "baz", "*"))); + assertThat(group.getCount(), equalTo(4L)); + assertThat(leafTreeNode.getAllChildrenLogGroups(), hasSize(2)); + assertThat(previousBytesUsed, equalTo(leafTreeNode.ramBytesUsed())); + } + + public void testMergeWith() { + TreeNode.LeafTreeNode leafTreeNode = new TreeNode.LeafTreeNode(0, 0.6); + leafTreeNode.mergeWith(null); + assertThat(leafTreeNode, equalTo(new TreeNode.LeafTreeNode(0, 0.6))); + + expectThrows(UnsupportedOperationException.class, () -> leafTreeNode.mergeWith(new TreeNode.InnerTreeNode(1, 2, 3))); + + leafTreeNode.incCount(5); + leafTreeNode.addLog(getTokens("foo", "bar", "baz", "biz"), 5, factory); + + TreeNode.LeafTreeNode toMerge = new TreeNode.LeafTreeNode(0, 0.6); + leafTreeNode.incCount(1); + toMerge.addLog(getTokens("foo", "bar", "baz", "bizzy"), 1, factory); + leafTreeNode.incCount(1); + toMerge.addLog(getTokens("foo", "bart", "bat", "built"), 1, factory); + leafTreeNode.mergeWith(toMerge); + + assertThat(leafTreeNode.getAllChildrenLogGroups(), hasSize(2)); + assertThat(leafTreeNode.getCount(), equalTo(7L)); + assertThat(leafTreeNode.getAllChildrenLogGroups().get(0).getLogEvent(), arrayContaining(getTokens("foo", "bar", "baz", "*"))); + assertThat(leafTreeNode.getAllChildrenLogGroups().get(1).getLogEvent(), arrayContaining(getTokens("foo", "bart", "bat", "built"))); + } +} diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/aggs/categorization/LogGroupTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/aggs/categorization/LogGroupTests.java new file mode 100644 index 0000000000000..66bae64ac305a --- /dev/null +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/aggs/categorization/LogGroupTests.java @@ -0,0 +1,51 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.ml.aggs.categorization; + +import org.apache.lucene.util.BytesRef; +import org.elasticsearch.core.Tuple; +import org.elasticsearch.test.ESTestCase; + +import static org.hamcrest.Matchers.arrayContaining; +import static org.hamcrest.Matchers.closeTo; +import static org.hamcrest.Matchers.equalTo; + +public class LogGroupTests extends ESTestCase { + + public void testSimilarity() { + LogGroup lg = new LogGroup(getTokens("foo", "bar", "baz", "biz"), 1, 1); + Tuple sims = lg.calculateSimilarity(getTokens("not", "matching", "anything", "nope")); + assertThat(sims.v1(), equalTo(0.0)); + assertThat(sims.v2(), equalTo(0)); + + sims = lg.calculateSimilarity(getTokens("foo", "bar", "baz", "biz")); + assertThat(sims.v1(), equalTo(1.0)); + assertThat(sims.v2(), equalTo(0)); + + sims = lg.calculateSimilarity(getTokens("foo", "fooagain", "notbar", "biz")); + assertThat(sims.v1(), closeTo(0.5, 0.0001)); + assertThat(sims.v2(), equalTo(0)); + } + + public void testAddLog() { + LogGroup lg = new LogGroup(getTokens("foo", "bar", "baz", "biz"), 1, 1); + lg.addLog(getTokens("foo", "bar", "baz", "bozo"), 2); + assertThat(lg.getCount(), equalTo(3L)); + assertThat(lg.getLogEvent(), arrayContaining(getTokens("foo", "bar", "baz", "*"))); + } + + static BytesRef[] getTokens(String... tokens) { + BytesRef[] refs = new BytesRef[tokens.length]; + int i = 0; + for (String token: tokens) { + refs[i++] = new BytesRef(token); + } + return refs; + } + +} diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/aggs/categorization/ParsedCategorization.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/aggs/categorization/ParsedCategorization.java new file mode 100644 index 0000000000000..b554f9cfc43e1 --- /dev/null +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/aggs/categorization/ParsedCategorization.java @@ -0,0 +1,113 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.ml.aggs.categorization; + +import org.apache.lucene.util.BytesRef; +import org.elasticsearch.common.Strings; +import org.elasticsearch.common.xcontent.ObjectParser; +import org.elasticsearch.common.xcontent.XContentBuilder; +import org.elasticsearch.common.xcontent.XContentParser; +import org.elasticsearch.common.xcontent.XContentParserUtils; +import org.elasticsearch.search.aggregations.ParsedMultiBucketAggregation; +import org.elasticsearch.search.aggregations.bucket.MultiBucketsAggregation; + +import java.io.IOException; +import java.util.Arrays; +import java.util.List; + +class ParsedCategorization extends ParsedMultiBucketAggregation { + + @Override + public String getType() { + return CategorizeTextAggregationBuilder.NAME; + } + + private static final ObjectParser PARSER = new ObjectParser<>( + ParsedCategorization.class.getSimpleName(), + true, + ParsedCategorization::new + ); + static { + declareMultiBucketAggregationFields(PARSER, ParsedBucket::fromXContent, ParsedBucket::fromXContent); + } + + public static ParsedCategorization fromXContent(XContentParser parser, String name) throws IOException { + ParsedCategorization aggregation = PARSER.parse(parser, null); + aggregation.setName(name); + return aggregation; + } + + @Override + public List getBuckets() { + return buckets; + } + + public static class ParsedBucket extends ParsedMultiBucketAggregation.ParsedBucket implements MultiBucketsAggregation.Bucket { + + private InternalCategorizationAggregation.BucketKey key; + + protected void setKeyAsString(String keyAsString) { + if (keyAsString == null) { + key = null; + return; + } + if (keyAsString.isEmpty()) { + key = new InternalCategorizationAggregation.BucketKey(new BytesRef[0]); + return; + } + String[] split = Strings.tokenizeToStringArray(keyAsString, " "); + key = new InternalCategorizationAggregation.BucketKey( + split == null + ? new BytesRef[] { new BytesRef(keyAsString) } + : Arrays.stream(split).map(BytesRef::new).toArray(BytesRef[]::new) + ); + } + + @Override + public Object getKey() { + return key; + } + + @Override + public String getKeyAsString() { + return key.asString(); + } + + @Override + protected XContentBuilder keyToXContent(XContentBuilder builder) throws IOException { + return builder.field(CommonFields.KEY.getPreferredName(), getKey()); + } + + static InternalCategorizationAggregation.BucketKey parsedKey(final XContentParser parser) throws IOException { + if (parser.currentToken() == XContentParser.Token.VALUE_STRING) { + String toSplit = parser.text(); + String[] split = Strings.tokenizeToStringArray(toSplit, " "); + return new InternalCategorizationAggregation.BucketKey( + split == null + ? new BytesRef[] { new BytesRef(toSplit) } + : Arrays.stream(split).map(BytesRef::new).toArray(BytesRef[]::new) + ); + } else { + return new InternalCategorizationAggregation.BucketKey( + XContentParserUtils.parseList(parser, p -> new BytesRef(p.binaryValue())).toArray(BytesRef[]::new) + ); + } + } + + static ParsedBucket fromXContent(final XContentParser parser) throws IOException { + return ParsedMultiBucketAggregation.ParsedBucket.parseXContent( + parser, + false, + ParsedBucket::new, + (p, bucket) -> bucket.key = parsedKey(p) + ); + } + + } + +} diff --git a/x-pack/plugin/src/yamlRestTest/resources/rest-api-spec/test/ml/categorization_agg.yml b/x-pack/plugin/src/yamlRestTest/resources/rest-api-spec/test/ml/categorization_agg.yml new file mode 100644 index 0000000000000..3b69194263f69 --- /dev/null +++ b/x-pack/plugin/src/yamlRestTest/resources/rest-api-spec/test/ml/categorization_agg.yml @@ -0,0 +1,136 @@ +setup: + - skip: + features: headers + - do: + headers: + Authorization: "Basic eF9wYWNrX3Jlc3RfdXNlcjp4LXBhY2stdGVzdC1wYXNzd29yZA==" # run as x_pack_rest_user, i.e. the test setup superuser + indices.create: + index: to_categorize + body: + mappings: + properties: + kind: + type: keyword + text: + type: text + + - do: + headers: + Authorization: "Basic eF9wYWNrX3Jlc3RfdXNlcjp4LXBhY2stdGVzdC1wYXNzd29yZA==" # run as x_pack_rest_user, i.e. the test setup superuser + Content-Type: application/json + bulk: + index: to_categorize + refresh: true + body: | + { "index": {} } + { "product": "server","text": "Node 2 stopping" } + { "index": {} } + { "product": "server", "text": "Node 2 starting"} + { "index": {} } + { "product": "server", "text": "Node 4 stopping"} + { "index": {} } + { "product": "server", "text": "Node 5 stopping"} + { "index": {} } + { "product": "user", "text": "User Foo logging on" } + { "index": {} } + { "product": "user", "text": "User Foo logging on" } + { "index": {} } + { "product": "user", "text": "User Foo logging off" } + +--- +"Test categorization agg simple": + + - do: + search: + index: to_categorize + body: > + { + "size": 0, + "aggs": { + "categories": { + "categorize_text": { + "field": "text", + "size": 10 + } + } + } + } + - length: { aggregations.categories.buckets: 4} + - match: {aggregations.categories.buckets.0.doc_count: 3} + - match: {aggregations.categories.buckets.0.key: "Node stopping" } + + - do: + search: + index: to_categorize + body: > + { + "size": 0, + "aggs": { + "categories": { + "categorize_text": { + "field": "text", + "size": 10, + "max_children": 2, + "max_depth": 1, + "similarity_threshold": 0.11 + } + } + } + } + + - length: { aggregations.categories.buckets: 2 } + - match: { aggregations.categories.buckets.0.doc_count: 4 } + - match: { aggregations.categories.buckets.0.key: "Node *" } + - match: { aggregations.categories.buckets.1.doc_count: 3 } + - match: { aggregations.categories.buckets.1.key: "User Foo logging *" } +--- +"Test categorization aggregation with poor settings": + + - do: + catch: /\[max_children\] must be greater than 0/ + search: + index: to_categorize + body: > + { + "size": 0, + "aggs": { + "categories": { + "categorize_text": { + "field": "text", + "max_children": -2 + } + } + } + } + - do: + catch: /\[max_depth\] must be greater than 0/ + search: + index: to_categorize + body: > + { + "size": 0, + "aggs": { + "categories": { + "categorize_text": { + "field": "text", + "max_depth": -2 + } + } + } + } + - do: + catch: /\[similarity_threshold\] must be in the range \[0.1, 1.0\]/ + search: + index: to_categorize + body: > + { + "size": 0, + "aggs": { + "categories": { + "categorize_text": { + "field": "text", + "similarity_threshold": 0.0 + } + } + } + } From f384c164417ccceaace50cb35f45a0b5d37ee625 Mon Sep 17 00:00:00 2001 From: Benjamin Trent <4357155+benwtrent@users.noreply.github.com> Date: Tue, 14 Sep 2021 08:11:50 -0400 Subject: [PATCH 02/20] addressing PR comments --- .../support/AggregationContext.java | 7 +- .../index/mapper/MapperServiceTestCase.java | 6 ++ .../aggregations/AggregatorTestCase.java | 29 +++---- .../test/AbstractBuilderTestCase.java | 4 + .../CategorizationTokenTree.java | 33 +++---- .../CategorizeTextAggregator.java | 51 +++++------ .../InternalCategorizationAggregation.java | 66 +++++++------- ...{LogGroup.java => TextCategorization.java} | 66 ++++++++++---- .../ml/aggs/categorization/TreeNode.java | 87 +++++++++---------- .../aggs/categorization/TreeNodeFactory.java | 2 +- .../CategorizeTextAggregatorTests.java | 15 +++- .../categorization/InnerTreeNodeTests.java | 36 ++++---- .../categorization/LeafTreeNodeTests.java | 17 ++-- ...ests.java => TextCategorizationTests.java} | 23 +++-- .../test/ml/categorization_agg.yml | 31 ++++--- 15 files changed, 246 insertions(+), 227 deletions(-) rename x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/{LogGroup.java => TextCategorization.java} (52%) rename x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/aggs/categorization/{LogGroupTests.java => TextCategorizationTests.java} (58%) diff --git a/server/src/main/java/org/elasticsearch/search/aggregations/support/AggregationContext.java b/server/src/main/java/org/elasticsearch/search/aggregations/support/AggregationContext.java index ad6366d75ce64..9ebb30db75cd1 100644 --- a/server/src/main/java/org/elasticsearch/search/aggregations/support/AggregationContext.java +++ b/server/src/main/java/org/elasticsearch/search/aggregations/support/AggregationContext.java @@ -96,9 +96,10 @@ public final FieldContext buildFieldContext(String field) { return new FieldContext(field, buildFieldData(ft), ft); } - public AnalysisRegistry getAnalysisRegistry() { - return null; - } + /** + * @return The analysis registry for the node. Allows specialized aggregations to build custom analyzers for tokenizing text + */ + public abstract AnalysisRegistry getAnalysisRegistry(); /** * Lookup the context for an already resolved field type. diff --git a/test/framework/src/main/java/org/elasticsearch/index/mapper/MapperServiceTestCase.java b/test/framework/src/main/java/org/elasticsearch/index/mapper/MapperServiceTestCase.java index f341ae905ce0b..ae42718c241a0 100644 --- a/test/framework/src/main/java/org/elasticsearch/index/mapper/MapperServiceTestCase.java +++ b/test/framework/src/main/java/org/elasticsearch/index/mapper/MapperServiceTestCase.java @@ -35,6 +35,7 @@ import org.elasticsearch.common.xcontent.XContentType; import org.elasticsearch.common.xcontent.json.JsonXContent; import org.elasticsearch.index.IndexSettings; +import org.elasticsearch.index.analysis.AnalysisRegistry; import org.elasticsearch.index.analysis.AnalyzerScope; import org.elasticsearch.index.analysis.IndexAnalyzers; import org.elasticsearch.index.analysis.NamedAnalyzer; @@ -352,6 +353,11 @@ public long nowInMillis() { return 0; } + @Override + public AnalysisRegistry getAnalysisRegistry() { + return null; + } + @Override public boolean isFieldMapped(String field) { throw new UnsupportedOperationException(); diff --git a/test/framework/src/main/java/org/elasticsearch/search/aggregations/AggregatorTestCase.java b/test/framework/src/main/java/org/elasticsearch/search/aggregations/AggregatorTestCase.java index 906fc8b50366d..473afb1a7e903 100644 --- a/test/framework/src/main/java/org/elasticsearch/search/aggregations/AggregatorTestCase.java +++ b/test/framework/src/main/java/org/elasticsearch/search/aggregations/AggregatorTestCase.java @@ -56,8 +56,6 @@ import org.elasticsearch.core.CheckedConsumer; import org.elasticsearch.core.Releasable; import org.elasticsearch.core.Releasables; -import org.elasticsearch.env.Environment; -import org.elasticsearch.env.TestEnvironment; import org.elasticsearch.index.Index; import org.elasticsearch.index.IndexSettings; import org.elasticsearch.index.analysis.AnalysisRegistry; @@ -99,7 +97,6 @@ import org.elasticsearch.indices.analysis.AnalysisModule; import org.elasticsearch.indices.breaker.CircuitBreakerService; import org.elasticsearch.indices.breaker.NoneCircuitBreakerService; -import org.elasticsearch.plugins.AnalysisPlugin; import org.elasticsearch.plugins.SearchPlugin; import org.elasticsearch.script.ScriptCompiler; import org.elasticsearch.script.ScriptService; @@ -140,6 +137,7 @@ import java.util.HashSet; import java.util.List; import java.util.Map; +import java.util.Optional; import java.util.Set; import java.util.function.Consumer; import java.util.function.Supplier; @@ -163,7 +161,7 @@ public abstract class AggregatorTestCase extends ESTestCase { private List releasables = new ArrayList<>(); protected ValuesSourceRegistry valuesSourceRegistry; - protected AnalysisModule analysisModule; + private AnalysisModule analysisModule; // A list of field types that should not be tested, or are not currently supported private static final List TYPE_TEST_BLACKLIST = List.of( @@ -184,23 +182,22 @@ public void initValuesSourceRegistry() { } @Before - public void initAnalysisRegistry() throws IOException { - analysisModule = new AnalysisModule( - TestEnvironment.newEnvironment( - Settings.builder().put(Environment.PATH_HOME_SETTING.getKey(), createTempDir().toString()).build() - ), - getAnalysisPlugins() - ); + public void initAnalysisRegistry() throws Exception { + analysisModule = createAnalysisModule(); } /** - * Test cases should override this if they have plugins that need to be loaded, e.g. the plugins their aggregators are in. + * @return a new analysis module. Tests that require a fully constructed analysis module (used to create an analysis registry) + * should override this method */ - protected List getSearchPlugins() { - return List.of(); + protected AnalysisModule createAnalysisModule() throws Exception { + return null; } - protected List getAnalysisPlugins() { + /** + * Test cases should override this if they have plugins that need to be loaded, e.g. the plugins their aggregators are in. + */ + protected List getSearchPlugins() { return List.of(); } @@ -302,7 +299,7 @@ public void onCache(ShardId shardId, Accountable accountable) {} MultiBucketConsumer consumer = new MultiBucketConsumer(maxBucket, breakerService.getBreaker(CircuitBreaker.REQUEST)); AggregationContext context = new ProductionAggregationContext( - analysisModule.getAnalysisRegistry(), + Optional.ofNullable(analysisModule).map(AnalysisModule::getAnalysisRegistry).orElse(null), searchExecutionContext, new MockBigArrays(new MockPageCacheRecycler(Settings.EMPTY), breakerService), bytesToPreallocate, diff --git a/test/framework/src/main/java/org/elasticsearch/test/AbstractBuilderTestCase.java b/test/framework/src/main/java/org/elasticsearch/test/AbstractBuilderTestCase.java index f7475f8536153..73eef9442246e 100644 --- a/test/framework/src/main/java/org/elasticsearch/test/AbstractBuilderTestCase.java +++ b/test/framework/src/main/java/org/elasticsearch/test/AbstractBuilderTestCase.java @@ -142,6 +142,10 @@ protected Collection> getPlugins() { return Collections.singletonList(TestGeoShapeFieldMapperPlugin.class); } + /** + * Allows additional plugins other than the required `TestGeoShapeFieldMapperPlugin` + * Could probably be removed when dependencies against geo_shape is decoupled + */ protected Collection> getExtraPlugins() { return Collections.emptyList(); } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/CategorizationTokenTree.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/CategorizationTokenTree.java index 0de2299be7f45..dc5b6003200fd 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/CategorizationTokenTree.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/CategorizationTokenTree.java @@ -12,9 +12,9 @@ import org.apache.lucene.util.Accountable; import org.apache.lucene.util.BytesRef; import org.apache.lucene.util.RamUsageEstimator; -import org.elasticsearch.common.Strings; import org.elasticsearch.search.aggregations.InternalAggregations; +import java.util.Arrays; import java.util.HashMap; import java.util.List; import java.util.Map; @@ -83,7 +83,7 @@ public CategorizationTokenTree(int maxChildren, int maxDepth, double similarityT public List toIntermediateBuckets() { return root.values().stream().flatMap(c -> c.getAllChildrenLogGroups().stream()).map(lg -> { InternalCategorizationAggregation.Bucket bucket = new InternalCategorizationAggregation.Bucket( - new InternalCategorizationAggregation.BucketKey(lg.getLogEvent()), + new InternalCategorizationAggregation.BucketKey(lg.getCategorization()), lg.getCount(), InternalAggregations.EMPTY ); @@ -92,30 +92,15 @@ public List toIntermediateBuckets() { }).collect(Collectors.toList()); } - public List toBuckets(Map internalAggregations) { - return root.values() - .stream() - .flatMap(c -> c.getAllChildrenLogGroups().stream()) - .map( - lg -> new InternalCategorizationAggregation.Bucket( - new InternalCategorizationAggregation.BucketKey(lg.getLogEvent()), - lg.getCount(), - internalAggregations.get(lg.getId()) - ) - ) - .sorted() - .collect(Collectors.toList()); - } - void mergeSmallestChildren() { root.values().forEach(TreeNode::collapseTinyChildren); } - public LogGroup parseLogLine(final BytesRef[] logTokens) { + public TextCategorization parseLogLine(final BytesRef[] logTokens) { return parseLogLine(logTokens, 1); } - public LogGroup parseLogLineConst(final BytesRef[] logTokens) { + public TextCategorization parseLogLineConst(final BytesRef[] logTokens) { TreeNode currentNode = this.root.get(logTokens.length); if (currentNode == null) { // we are missing an entire sub tree. New log length found return null; @@ -123,8 +108,10 @@ public LogGroup parseLogLineConst(final BytesRef[] logTokens) { return currentNode.getLogGroup(logTokens); } - public LogGroup parseLogLine(final BytesRef[] logTokens, long docCount) { - LOGGER.trace("parsing tokens [{}]", Strings.arrayToDelimitedString(logTokens, " ")); + public TextCategorization parseLogLine(final BytesRef[] logTokens, long docCount) { + if (LOGGER.isTraceEnabled()) { + LOGGER.trace("parsing tokens [{}]", Arrays.stream(logTokens).map(BytesRef::utf8ToString).collect(Collectors.joining(" "))); + } TreeNode currentNode = this.root.get(logTokens.length); if (currentNode == null) { // we are missing an entire sub tree. New log length found currentNode = newNode(docCount, 0, logTokens); @@ -146,8 +133,8 @@ public TreeNode newNode(long docCount, int tokenPos, BytesRef[] tokens) { } @Override - public LogGroup newGroup(long docCount, BytesRef[] logTokens) { - LogGroup group = new LogGroup(logTokens, docCount, idGen.incrementAndGet()); + public TextCategorization newGroup(long docCount, BytesRef[] logTokens) { + TextCategorization group = new TextCategorization(logTokens, docCount, idGen.incrementAndGet()); // Get the regular size bytes from the LogGroup and how much it costs to reference it sizeInBytes += group.ramBytesUsed() + RamUsageEstimator.NUM_BYTES_OBJECT_REF; return group; diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/CategorizeTextAggregator.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/CategorizeTextAggregator.java index 151bfd26798bd..fb257469d63c9 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/CategorizeTextAggregator.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/CategorizeTextAggregator.java @@ -12,7 +12,6 @@ import org.apache.lucene.index.LeafReaderContext; import org.apache.lucene.util.BytesRef; import org.apache.lucene.util.PriorityQueue; -import org.elasticsearch.common.util.BigArrays; import org.elasticsearch.common.util.ObjectArray; import org.elasticsearch.index.mapper.MappedFieldType; import org.elasticsearch.search.aggregations.Aggregator; @@ -41,7 +40,6 @@ public class CategorizeTextAggregator extends DeferableBucketAggregator { private final TermsAggregator.BucketCountThresholds bucketCountThresholds; private final SourceLookup sourceLookup; - private final BigArrays bigArrays; private final MappedFieldType fieldType; private final CategorizationAnalyzer analyzer; private final String sourceFieldName; @@ -73,7 +71,6 @@ protected CategorizeTextAggregator( categorizationFilters ); this.analyzer = new CategorizationAnalyzer(context.getAnalysisRegistry(), categorizationAnalyzerConfig); - this.bigArrays = context.bigArrays(); this.categorizers = bigArrays().newObjectArray(1); this.maxChildren = maxChildren; this.maxDepth = maxDepth; @@ -93,7 +90,7 @@ public InternalAggregation[] buildAggregations(long[] ordsToCollect) throws IOEx InternalCategorizationAggregation.Bucket[][] topBucketsPerOrd = new InternalCategorizationAggregation.Bucket[ordsToCollect.length][]; for (int ordIdx = 0; ordIdx < ordsToCollect.length; ordIdx++) { - int size = (int) Math.min(bucketOrds.size(), bucketCountThresholds.getShardSize()); + int size = (int) Math.min(bucketOrds.bucketsInOrd(ordIdx), bucketCountThresholds.getShardSize()); PriorityQueue ordered = new InternalCategorizationAggregation.BucketCountPriorityQueue(size); CategorizationTokenTree categorizationTokenTree = categorizers.get(ordsToCollect[ordIdx]); @@ -167,41 +164,41 @@ private void collectFromSource(int doc, long owningBucketOrd) throws IOException } private void processTokenStream(long owningBucketOrd, TokenStream ts, int doc) throws IOException { + ArrayList tokens = new ArrayList<>(); try { CharTermAttribute termAtt = ts.addAttribute(CharTermAttribute.class); ts.reset(); - ArrayList tokens = new ArrayList<>(); while (ts.incrementToken()) { tokens.add(new BytesRef(termAtt)); } if (tokens.isEmpty()) { return; } - categorizers = bigArrays.grow(categorizers, owningBucketOrd + 1); - CategorizationTokenTree categorizer = categorizers.get(owningBucketOrd); - if (categorizer == null) { - categorizer = new CategorizationTokenTree(maxChildren, maxDepth, similarityThreshold); - addRequestCircuitBreakerBytes(categorizer.ramBytesUsed()); - categorizers.set(owningBucketOrd, categorizer); - } - long previousSize = categorizer.ramBytesUsed(); - LogGroup lg = categorizer.parseLogLine(tokens.toArray(BytesRef[]::new), docCountProvider.getDocCount(doc)); - long newSize = categorizer.ramBytesUsed(); - if (newSize - previousSize > 0) { - addRequestCircuitBreakerBytes(newSize - previousSize); - } - - long bucketOrd = bucketOrds.add(owningBucketOrd, lg.getId()); - if (bucketOrd < 0) { // already seen - bucketOrd = -1 - bucketOrd; - collectExistingBucket(sub, doc, bucketOrd); - } else { - lg.bucketOrd = bucketOrd; - collectBucket(sub, doc, bucketOrd); - } } finally { ts.close(); } + categorizers = bigArrays().grow(categorizers, owningBucketOrd + 1); + CategorizationTokenTree categorizer = categorizers.get(owningBucketOrd); + if (categorizer == null) { + categorizer = new CategorizationTokenTree(maxChildren, maxDepth, similarityThreshold); + addRequestCircuitBreakerBytes(categorizer.ramBytesUsed()); + categorizers.set(owningBucketOrd, categorizer); + } + long previousSize = categorizer.ramBytesUsed(); + TextCategorization lg = categorizer.parseLogLine(tokens.toArray(BytesRef[]::new), docCountProvider.getDocCount(doc)); + long newSize = categorizer.ramBytesUsed(); + if (newSize - previousSize > 0) { + addRequestCircuitBreakerBytes(newSize - previousSize); + } + + long bucketOrd = bucketOrds.add(owningBucketOrd, lg.getId()); + if (bucketOrd < 0) { // already seen + bucketOrd = -1 - bucketOrd; + collectExistingBucket(sub, doc, bucketOrd); + } else { + lg.bucketOrd = bucketOrd; + collectBucket(sub, doc, bucketOrd); + } } }; } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/InternalCategorizationAggregation.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/InternalCategorizationAggregation.java index 89ff6bd302eaf..be65ae70711f7 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/InternalCategorizationAggregation.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/InternalCategorizationAggregation.java @@ -89,6 +89,29 @@ static class BucketKey implements ToXContentFragment, Writeable, Comparable collapsedWildCards = new ArrayList<>(); + boolean previousTokenWildCard = false; + for (BytesRef token : key) { + if (token.equals(WILD_CARD)) { + if (previousTokenWildCard == false) { + previousTokenWildCard = true; + collapsedWildCards.add(WILD_CARD); + } + } else { + previousTokenWildCard = false; + collapsedWildCards.add(token); + } + } + if (collapsedWildCards.size() == key.length) { + return new BucketKey(key); + } + return new BucketKey(collapsedWildCards.toArray(BytesRef[]::new)); + } + BucketKey(BytesRef[] key) { this.key = key; } @@ -143,28 +166,6 @@ public int compareTo(BucketKey o) { return Arrays.compare(key, o.key); } - private BucketKey collapseWildCards() { - if (key.length <= 1) { - return this; - } - List collapsedWildCards = new ArrayList<>(); - boolean previousTokenWildCard = false; - for (BytesRef token : key) { - if (token.equals(WILD_CARD)) { - if (previousTokenWildCard == false) { - previousTokenWildCard = true; - collapsedWildCards.add(WILD_CARD); - } - } else { - previousTokenWildCard = false; - collapsedWildCards.add(token); - } - } - if (collapsedWildCards.size() == key.length) { - return this; - } - return new BucketKey(collapsedWildCards.toArray(BytesRef[]::new)); - } } public static class Bucket extends InternalBucket implements MultiBucketsAggregation.Bucket, Comparable { @@ -351,20 +352,21 @@ public InternalAggregation reduce(List aggregations, Reduce categorizationTokenTree.mergeSmallestChildren(); Map mergedBuckets = new HashMap<>(aggregations.size(), 1.0f); for (DelayedCategorizationBucket delayedBucket : reduced.values()) { - LogGroup group = categorizationTokenTree.parseLogLineConst(delayedBucket.key.keyAsTokens()); + TextCategorization group = categorizationTokenTree.parseLogLineConst(delayedBucket.key.keyAsTokens()); if (group == null) { throw new AggregationExecutionException( "Unexpected null categorization group for bucket [" + delayedBucket.key.asString() + "]" ); } - BucketKey key = new BucketKey(group.getLogEvent()); - mergedBuckets.computeIfAbsent( - reduceContext.isFinalReduce() ? key.collapseWildCards() : key, - k -> new DelayedCategorizationBucket(k, new ArrayList<>(delayedBucket.toReduce.size()), 0L) - ).add(delayedBucket); + + BucketKey key = reduceContext.isFinalReduce() ? + BucketKey.withCollapsedWildcards(group.getCategorization()) : + new BucketKey(group.getCategorization()); + mergedBuckets.computeIfAbsent(key, k -> new DelayedCategorizationBucket(k, new ArrayList<>(delayedBucket.toReduce.size()), 0L)) + .add(delayedBucket); } - final int size = reduceContext.isFinalReduce() == false ? buckets.size() : Math.min(requiredSize, buckets.size()); + final int size = reduceContext.isFinalReduce() == false ? mergedBuckets.size() : Math.min(requiredSize, mergedBuckets.size()); final PriorityQueue pq = new BucketCountPriorityQueue(size); for (Map.Entry keyAndBuckets : mergedBuckets.entrySet()) { final BucketKey key = keyAndBuckets.getKey(); @@ -397,12 +399,6 @@ public InternalAggregation reduce(List aggregations, Reduce ); } - @Override - public Object getProperty(List path) { - // TODO anything special? - return super.getProperty(path); - } - @Override public XContentBuilder doXContentBody(XContentBuilder builder, Params params) throws IOException { builder.startArray(CommonFields.BUCKETS.getPreferredName()); diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/LogGroup.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/TextCategorization.java similarity index 52% rename from x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/LogGroup.java rename to x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/TextCategorization.java index b6c81235c9652..10d1cc4dbd04d 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/LogGroup.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/TextCategorization.java @@ -10,7 +10,6 @@ import org.apache.lucene.util.Accountable; import org.apache.lucene.util.BytesRef; import org.apache.lucene.util.RamUsageEstimator; -import org.elasticsearch.core.Tuple; import java.util.Arrays; import java.util.stream.Collectors; @@ -18,14 +17,14 @@ import static org.elasticsearch.xpack.ml.aggs.categorization.CategorizationTokenTree.WILD_CARD; /** - * A log group that provides methods for: + * A text categorization group that provides methods for: * - calculating similarity between it and a new log * - expanding the existing log group by adding a new log */ -class LogGroup implements Accountable { +class TextCategorization implements Accountable { private final long id; - private final BytesRef[] logEvent; + private final BytesRef[] categorization; private final long[] tokenCounts; private long count; @@ -38,15 +37,15 @@ public String toString() { + "id=" + id + ", logEvent=" - + Arrays.stream(logEvent).map(BytesRef::utf8ToString).collect(Collectors.joining(", ", "[", "]")) + + Arrays.stream(categorization).map(BytesRef::utf8ToString).collect(Collectors.joining(", ", "[", "]")) + ", count=" + count + '}'; } - LogGroup(BytesRef[] logTokens, long count, long id) { + TextCategorization(BytesRef[] logTokens, long count, long id) { this.id = id; - this.logEvent = logTokens; + this.categorization = logTokens; this.count = count; this.tokenCounts = new long[logTokens.length]; Arrays.fill(this.tokenCounts, count); @@ -56,37 +55,37 @@ public long getId() { return id; } - BytesRef[] getLogEvent() { - return logEvent; + BytesRef[] getCategorization() { + return categorization; } public long getCount() { return count; } - Tuple calculateSimilarity(BytesRef[] logEvent) { - assert logEvent.length == this.logEvent.length; + Similarity calculateSimilarity(BytesRef[] logEvent) { + assert logEvent.length == this.categorization.length; int eqParams = 0; long tokenCount = 0; long tokensKept = 0; for (int i = 0; i < logEvent.length; i++) { - if (logEvent[i].equals(this.logEvent[i])) { + if (logEvent[i].equals(this.categorization[i])) { tokensKept += tokenCounts[i]; tokenCount += tokenCounts[i]; - } else if (this.logEvent[i].equals(WILD_CARD)) { + } else if (this.categorization[i].equals(WILD_CARD)) { eqParams++; } else { tokenCount += tokenCounts[i]; } } - return new Tuple<>((double) tokensKept / tokenCount, eqParams); + return new Similarity((double) tokensKept / tokenCount, eqParams); } void addLog(BytesRef[] logEvent, long docCount) { - assert logEvent.length == this.logEvent.length; + assert logEvent.length == this.categorization.length; for (int i = 0; i < logEvent.length; i++) { - if (logEvent[i].equals(this.logEvent[i]) == false) { - this.logEvent[i] = WILD_CARD; + if (logEvent[i].equals(this.categorization[i]) == false) { + this.categorization[i] = WILD_CARD; } else { tokenCounts[i] += docCount; } @@ -97,9 +96,38 @@ void addLog(BytesRef[] logEvent, long docCount) { @Override public long ramBytesUsed() { return Long.BYTES // id - + (long) logEvent.length * RamUsageEstimator.NUM_BYTES_ARRAY_HEADER // logEvent - + RamUsageEstimator.NUM_BYTES_OBJECT_REF + ((long) logEvent.length * Long.BYTES) + RamUsageEstimator.NUM_BYTES_ARRAY_HEADER + + (long) categorization.length * RamUsageEstimator.NUM_BYTES_ARRAY_HEADER // logEvent + + RamUsageEstimator.NUM_BYTES_OBJECT_REF + + ((long) categorization.length * Long.BYTES) + + RamUsageEstimator.NUM_BYTES_ARRAY_HEADER + RamUsageEstimator.NUM_BYTES_OBJECT_REF + Long.BYTES; // count } + static class Similarity implements Comparable { + private final double similarity; + private final int wildCardCount; + + private Similarity(double similarity, int wildCardCount) { + this.similarity = similarity; + this.wildCardCount = wildCardCount; + } + + @Override + public int compareTo(Similarity o) { + int d = Double.compare(similarity, o.similarity); + if (d != 0) { + return d; + } + return Integer.compare(wildCardCount, o.wildCardCount); + } + + public double getSimilarity() { + return similarity; + } + + public int getWildCardCount() { + return wildCardCount; + } + } + } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/TreeNode.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/TreeNode.java index 6628b2273719c..6f7034d2a4623 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/TreeNode.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/TreeNode.java @@ -35,7 +35,7 @@ * * Two major node types exist: * - Inner: which are nodes that have children token nodes - * - Leaf: Which collection multiple {@link LogGroup} based on similarity restrictions + * - Leaf: Which collection multiple {@link TextCategorization} based on similarity restrictions */ abstract class TreeNode implements Accountable { @@ -60,26 +60,21 @@ final long getCount() { } // TODO add option for calculating the cost of adding the new group - abstract LogGroup addLog(BytesRef[] logTokens, long docCount, TreeNodeFactory treeNodeFactory); + abstract TextCategorization addLog(BytesRef[] logTokens, long docCount, TreeNodeFactory treeNodeFactory); - abstract LogGroup getLogGroup(BytesRef[] logTokens); + abstract TextCategorization getLogGroup(BytesRef[] logTokens); - abstract List getAllChildrenLogGroups(); + abstract List getAllChildrenLogGroups(); abstract void collapseTinyChildren(); - @Override - public long ramBytesUsed() { - return Long.BYTES; // count - } - static class LeafTreeNode extends TreeNode { - private final List logGroups; + private final List textCategorizations; private final double similarityThreshold; LeafTreeNode(long count, double similarityThreshold) { super(count); - this.logGroups = new ArrayList<>(); + this.textCategorizations = new ArrayList<>(); this.similarityThreshold = similarityThreshold; } @@ -99,8 +94,8 @@ void mergeWith(TreeNode treeNode) { } incCount(treeNode.getCount()); LeafTreeNode otherLeaf = (LeafTreeNode) treeNode; - for (LogGroup group : otherLeaf.logGroups) { - if (getAndUpdateLogGroup(group.getLogEvent(), group.getCount()).isPresent() == false) { + for (TextCategorization group : otherLeaf.textCategorizations) { + if (getAndUpdateLogGroup(group.getCategorization(), group.getCount()).isPresent() == false) { putNewLogGroup(group); } } @@ -108,30 +103,31 @@ void mergeWith(TreeNode treeNode) { @Override public long ramBytesUsed() { - return super.ramBytesUsed() + NUM_BYTES_OBJECT_REF // list reference + return Long.BYTES // count + + NUM_BYTES_OBJECT_REF // list reference + Double.BYTES // similarityThreshold - + sizeOfCollection(logGroups); + + sizeOfCollection(textCategorizations); } @Override - public LogGroup addLog(BytesRef[] logTokens, long docCount, TreeNodeFactory treeNodeFactory) { + public TextCategorization addLog(BytesRef[] logTokens, long docCount, TreeNodeFactory treeNodeFactory) { return getAndUpdateLogGroup(logTokens, docCount).orElseGet(() -> { // Need to update the tree if possible - LogGroup group = treeNodeFactory.newGroup(docCount, logTokens); + TextCategorization group = treeNodeFactory.newGroup(docCount, logTokens); LOGGER.trace(() -> new ParameterizedMessage("created group! [{}]", group)); return putNewLogGroup(group); }); } @Override - List getAllChildrenLogGroups() { - return logGroups; + List getAllChildrenLogGroups() { + return textCategorizations; } @Override void collapseTinyChildren() {} - private Optional getAndUpdateLogGroup(BytesRef[] logTokens, long docCount) { + private Optional getAndUpdateLogGroup(BytesRef[] logTokens, long docCount) { return getBestLogGroup(logTokens).map(bestGroupAndSimilarity -> { if (bestGroupAndSimilarity.v2() >= similarityThreshold) { bestGroupAndSimilarity.v1().addLog(logTokens, docCount); @@ -141,37 +137,34 @@ private Optional getAndUpdateLogGroup(BytesRef[] logTokens, long docCo }); } - LogGroup putNewLogGroup(LogGroup group) { - logGroups.add(group); + TextCategorization putNewLogGroup(TextCategorization group) { + textCategorizations.add(group); return group; } - private Optional> getBestLogGroup(BytesRef[] logTokens) { - if (logGroups.isEmpty()) { + private Optional> getBestLogGroup(BytesRef[] logTokens) { + if (textCategorizations.isEmpty()) { return Optional.empty(); } - if (logGroups.size() == 1) { - return Optional.of(new Tuple<>(logGroups.get(0), logGroups.get(0).calculateSimilarity(logTokens).v1())); + if (textCategorizations.size() == 1) { + return Optional.of( + new Tuple<>(textCategorizations.get(0), textCategorizations.get(0).calculateSimilarity(logTokens).getSimilarity()) + ); } - double maxSimilarity = 0.0; - int maxParamMatch = 0; - LogGroup bestGroup = null; - for (LogGroup logGroup : this.logGroups) { - Tuple groupSimilarity = logGroup.calculateSimilarity(logTokens); - if (groupSimilarity.v1() > maxSimilarity) { - maxSimilarity = groupSimilarity.v1(); - maxParamMatch = groupSimilarity.v2(); - bestGroup = logGroup; - } else if (groupSimilarity.v1() == maxSimilarity && groupSimilarity.v2() > maxParamMatch) { - maxParamMatch = groupSimilarity.v2(); - bestGroup = logGroup; + TextCategorization.Similarity maxSimilarity = null; + TextCategorization bestGroup = null; + for (TextCategorization textCategorization : this.textCategorizations) { + TextCategorization.Similarity groupSimilarity = textCategorization.calculateSimilarity(logTokens); + if (maxSimilarity == null || groupSimilarity.compareTo(maxSimilarity) > 0) { + maxSimilarity = groupSimilarity; + bestGroup = textCategorization; } } - return Optional.of(new Tuple<>(bestGroup, maxSimilarity)); + return Optional.of(new Tuple<>(bestGroup, maxSimilarity.getSimilarity())); } @Override - public LogGroup getLogGroup(final BytesRef[] logTokens) { + public TextCategorization getLogGroup(final BytesRef[] logTokens) { return getBestLogGroup(logTokens).map(Tuple::v1).orElse(null); } @@ -180,12 +173,13 @@ public boolean equals(Object o) { if (this == o) return true; if (o == null || getClass() != o.getClass()) return false; LeafTreeNode that = (LeafTreeNode) o; - return Double.compare(that.similarityThreshold, similarityThreshold) == 0 && Objects.equals(logGroups, that.logGroups); + return Double.compare(that.similarityThreshold, similarityThreshold) == 0 + && Objects.equals(textCategorizations, that.textCategorizations); } @Override public int hashCode() { - return Objects.hash(logGroups, similarityThreshold); + return Objects.hash(textCategorizations, similarityThreshold); } } @@ -209,7 +203,7 @@ boolean isLeaf() { } @Override - public LogGroup getLogGroup(final BytesRef[] logTokens) { + public TextCategorization getLogGroup(final BytesRef[] logTokens) { return getChild(logTokens[childrenTokenPos]).or(() -> getChild(WILD_CARD)) .map(node -> node.getLogGroup(logTokens)) .orElse(null); @@ -217,7 +211,8 @@ public LogGroup getLogGroup(final BytesRef[] logTokens) { @Override public long ramBytesUsed() { - return super.ramBytesUsed() + NUM_BYTES_OBJECT_REF // children reference + return Long.BYTES // count + + NUM_BYTES_OBJECT_REF // children reference + Integer.BYTES // childrenTokenPos + Integer.BYTES // maxChildren + NUM_BYTES_OBJECT_REF // smallestChildReference @@ -227,7 +222,7 @@ public long ramBytesUsed() { } @Override - public LogGroup addLog(final BytesRef[] logTokens, final long docCount, final TreeNodeFactory treeNodeFactory) { + public TextCategorization addLog(final BytesRef[] logTokens, final long docCount, final TreeNodeFactory treeNodeFactory) { BytesRef currentToken = logTokens[childrenTokenPos]; TreeNode child = getChild(currentToken).map(node -> { node.incCount(docCount); @@ -370,7 +365,7 @@ private Optional getChild(BytesRef token) { return node == null ? Optional.empty() : Optional.of(node); } - public List getAllChildrenLogGroups() { + public List getAllChildrenLogGroups() { return children.values().stream().flatMap(c -> c.getAllChildrenLogGroups().stream()).collect(Collectors.toList()); } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/TreeNodeFactory.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/TreeNodeFactory.java index b710d2ad9c946..eade0ecad3240 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/TreeNodeFactory.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/TreeNodeFactory.java @@ -12,5 +12,5 @@ interface TreeNodeFactory { TreeNode newNode(long docCount, int tokenPos, BytesRef[] logTokens); - LogGroup newGroup(long docCount, BytesRef[] logTokens); + TextCategorization newGroup(long docCount, BytesRef[] logTokens); } diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/aggs/categorization/CategorizeTextAggregatorTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/aggs/categorization/CategorizeTextAggregatorTests.java index 1caaf4d685081..58dbad7b0bb20 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/aggs/categorization/CategorizeTextAggregatorTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/aggs/categorization/CategorizeTextAggregatorTests.java @@ -16,8 +16,10 @@ import org.apache.lucene.store.Directory; import org.apache.lucene.util.BytesRef; import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.env.Environment; +import org.elasticsearch.env.TestEnvironment; import org.elasticsearch.index.mapper.TextFieldMapper; -import org.elasticsearch.plugins.AnalysisPlugin; +import org.elasticsearch.indices.analysis.AnalysisModule; import org.elasticsearch.plugins.SearchPlugin; import org.elasticsearch.search.aggregations.AggregatorTestCase; import org.elasticsearch.search.aggregations.bucket.histogram.Histogram; @@ -40,12 +42,17 @@ public class CategorizeTextAggregatorTests extends AggregatorTestCase { @Override - protected List getSearchPlugins() { - return List.of(new MachineLearning(Settings.EMPTY, null)); + protected AnalysisModule createAnalysisModule() throws Exception { + return new AnalysisModule( + TestEnvironment.newEnvironment( + Settings.builder().put(Environment.PATH_HOME_SETTING.getKey(), createTempDir().toString()).build() + ), + List.of(new MachineLearning(Settings.EMPTY, null)) + ); } @Override - protected List getAnalysisPlugins() { + protected List getSearchPlugins() { return List.of(new MachineLearning(Settings.EMPTY, null)); } diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/aggs/categorization/InnerTreeNodeTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/aggs/categorization/InnerTreeNodeTests.java index 731d83c9b31f3..2eb1f26aa5770 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/aggs/categorization/InnerTreeNodeTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/aggs/categorization/InnerTreeNodeTests.java @@ -10,7 +10,7 @@ import org.apache.lucene.util.BytesRef; import org.elasticsearch.test.ESTestCase; -import static org.elasticsearch.xpack.ml.aggs.categorization.LogGroupTests.getTokens; +import static org.elasticsearch.xpack.ml.aggs.categorization.TextCategorizationTests.getTokens; import static org.hamcrest.Matchers.arrayContaining; import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.is; @@ -21,62 +21,62 @@ public class InnerTreeNodeTests extends ESTestCase { public void testAddLog() { TreeNode.InnerTreeNode innerTreeNode = new TreeNode.InnerTreeNode(1, 0, 3); - LogGroup group = innerTreeNode.addLog(getTokens("foo", "bar", "baz", "biz"), 1, factory); - assertThat(group.getLogEvent(), arrayContaining(getTokens("foo", "bar", "baz", "biz"))); + TextCategorization group = innerTreeNode.addLog(getTokens("foo", "bar", "baz", "biz"), 1, factory); + assertThat(group.getCategorization(), arrayContaining(getTokens("foo", "bar", "baz", "biz"))); assertThat( - innerTreeNode.addLog(getTokens("foo2", "bar", "baz", "biz"), 1, factory).getLogEvent(), + innerTreeNode.addLog(getTokens("foo2", "bar", "baz", "biz"), 1, factory).getCategorization(), arrayContaining(getTokens("foo2", "bar", "baz", "biz")) ); assertThat( - innerTreeNode.addLog(getTokens("foo3", "bar", "baz", "biz"), 1, factory).getLogEvent(), + innerTreeNode.addLog(getTokens("foo3", "bar", "baz", "biz"), 1, factory).getCategorization(), arrayContaining(getTokens("foo3", "bar", "baz", "biz")) ); assertThat( - innerTreeNode.addLog(getTokens("foo4", "bar", "baz", "biz"), 1, factory).getLogEvent(), + innerTreeNode.addLog(getTokens("foo4", "bar", "baz", "biz"), 1, factory).getCategorization(), arrayContaining(getTokens("*", "bar", "baz", "biz")) ); assertThat( - innerTreeNode.addLog(getTokens("foo", "bar", "baz", "bizzy"), 1, factory).getLogEvent(), + innerTreeNode.addLog(getTokens("foo", "bar", "baz", "bizzy"), 1, factory).getCategorization(), arrayContaining(getTokens("foo", "bar", "baz", "*")) ); } public void testAddLogWithLargerIncoming() { TreeNode.InnerTreeNode innerTreeNode = new TreeNode.InnerTreeNode(1, 0, 3); - LogGroup group = innerTreeNode.addLog(getTokens("foo", "bar", "baz", "biz"), 100, factory); - assertThat(group.getLogEvent(), arrayContaining(getTokens("foo", "bar", "baz", "biz"))); + TextCategorization group = innerTreeNode.addLog(getTokens("foo", "bar", "baz", "biz"), 100, factory); + assertThat(group.getCategorization(), arrayContaining(getTokens("foo", "bar", "baz", "biz"))); assertThat( - innerTreeNode.addLog(getTokens("foo2", "bar", "baz", "biz"), 100, factory).getLogEvent(), + innerTreeNode.addLog(getTokens("foo2", "bar", "baz", "biz"), 100, factory).getCategorization(), arrayContaining(getTokens("foo2", "bar", "baz", "biz")) ); assertThat( - innerTreeNode.addLog(getTokens("foosmall", "bar", "baz", "biz"), 1, factory).getLogEvent(), + innerTreeNode.addLog(getTokens("foosmall", "bar", "baz", "biz"), 1, factory).getCategorization(), arrayContaining(getTokens("foosmall", "bar", "baz", "biz")) ); assertThat( - innerTreeNode.addLog(getTokens("foobigun", "bar", "baz", "biz"), 1000, factory).getLogEvent(), + innerTreeNode.addLog(getTokens("foobigun", "bar", "baz", "biz"), 1000, factory).getCategorization(), arrayContaining(getTokens("foobigun", "bar", "baz", "biz")) ); assertThat( - innerTreeNode.getLogGroup(getTokens("foosmall", "bar", "baz", "biz")).getLogEvent(), + innerTreeNode.getLogGroup(getTokens("foosmall", "bar", "baz", "biz")).getCategorization(), equalTo(getTokens("*", "bar", "baz", "biz")) ); } public void testCollapseTinyChildren() { TreeNode.InnerTreeNode innerTreeNode = new TreeNode.InnerTreeNode(1000, 0, 4); - LogGroup group = innerTreeNode.addLog(getTokens("foo", "bar", "baz", "biz"), 1000, factory); - assertThat(group.getLogEvent(), arrayContaining(getTokens("foo", "bar", "baz", "biz"))); + TextCategorization group = innerTreeNode.addLog(getTokens("foo", "bar", "baz", "biz"), 1000, factory); + assertThat(group.getCategorization(), arrayContaining(getTokens("foo", "bar", "baz", "biz"))); assertThat( - innerTreeNode.addLog(getTokens("foo2", "bar", "baz", "biz"), 1000, factory).getLogEvent(), + innerTreeNode.addLog(getTokens("foo2", "bar", "baz", "biz"), 1000, factory).getCategorization(), arrayContaining(getTokens("foo2", "bar", "baz", "biz")) ); innerTreeNode.incCount(1000); assertThat( - innerTreeNode.addLog(getTokens("foosmall", "bar", "baz", "biz"), 1, factory).getLogEvent(), + innerTreeNode.addLog(getTokens("foosmall", "bar", "baz", "biz"), 1, factory).getCategorization(), arrayContaining(getTokens("foosmall", "bar", "baz", "biz")) ); innerTreeNode.incCount(1); @@ -102,7 +102,7 @@ public void testMergeWith() { innerTreeNode.mergeWith(mergeWith); assertThat(innerTreeNode.hasChild(new BytesRef("*")), is(true)); assertThat( - innerTreeNode.getLogGroup(getTokens("footiny", "bar", "baz", "biz")).getLogEvent(), + innerTreeNode.getLogGroup(getTokens("footiny", "bar", "baz", "biz")).getCategorization(), arrayContaining(getTokens("*", "bar", "baz", "biz")) ); } diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/aggs/categorization/LeafTreeNodeTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/aggs/categorization/LeafTreeNodeTests.java index 74b72b784632b..0cce200fe1526 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/aggs/categorization/LeafTreeNodeTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/aggs/categorization/LeafTreeNodeTests.java @@ -9,7 +9,7 @@ import org.elasticsearch.test.ESTestCase; -import static org.elasticsearch.xpack.ml.aggs.categorization.LogGroupTests.getTokens; +import static org.elasticsearch.xpack.ml.aggs.categorization.TextCategorizationTests.getTokens; import static org.hamcrest.Matchers.arrayContaining; import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.greaterThan; @@ -21,15 +21,15 @@ public class LeafTreeNodeTests extends ESTestCase { public void testAddGroup() { TreeNode.LeafTreeNode leafTreeNode = new TreeNode.LeafTreeNode(0, 0.6); - LogGroup group = leafTreeNode.addLog(getTokens("foo", "bar", "baz", "biz"), 1, factory); + TextCategorization group = leafTreeNode.addLog(getTokens("foo", "bar", "baz", "biz"), 1, factory); - assertThat(group.getLogEvent(), arrayContaining(getTokens("foo", "bar", "baz", "biz"))); + assertThat(group.getCategorization(), arrayContaining(getTokens("foo", "bar", "baz", "biz"))); assertThat(group.getCount(), equalTo(1L)); assertThat(leafTreeNode.getAllChildrenLogGroups(), hasSize(1)); long previousBytesUsed = leafTreeNode.ramBytesUsed(); group = leafTreeNode.addLog(getTokens("foo", "bar", "bozo", "bizzy"), 1, factory); - assertThat(group.getLogEvent(), arrayContaining(getTokens("foo", "bar", "bozo", "bizzy"))); + assertThat(group.getCategorization(), arrayContaining(getTokens("foo", "bar", "bozo", "bizzy"))); assertThat(group.getCount(), equalTo(1L)); assertThat(leafTreeNode.getAllChildrenLogGroups(), hasSize(2)); assertThat(leafTreeNode.ramBytesUsed(), greaterThan(previousBytesUsed)); @@ -37,7 +37,7 @@ public void testAddGroup() { group = leafTreeNode.addLog(getTokens("foo", "bar", "baz", "different"), 3, factory); - assertThat(group.getLogEvent(), arrayContaining(getTokens("foo", "bar", "baz", "*"))); + assertThat(group.getCategorization(), arrayContaining(getTokens("foo", "bar", "baz", "*"))); assertThat(group.getCount(), equalTo(4L)); assertThat(leafTreeNode.getAllChildrenLogGroups(), hasSize(2)); assertThat(previousBytesUsed, equalTo(leafTreeNode.ramBytesUsed())); @@ -62,7 +62,10 @@ public void testMergeWith() { assertThat(leafTreeNode.getAllChildrenLogGroups(), hasSize(2)); assertThat(leafTreeNode.getCount(), equalTo(7L)); - assertThat(leafTreeNode.getAllChildrenLogGroups().get(0).getLogEvent(), arrayContaining(getTokens("foo", "bar", "baz", "*"))); - assertThat(leafTreeNode.getAllChildrenLogGroups().get(1).getLogEvent(), arrayContaining(getTokens("foo", "bart", "bat", "built"))); + assertThat(leafTreeNode.getAllChildrenLogGroups().get(0).getCategorization(), arrayContaining(getTokens("foo", "bar", "baz", "*"))); + assertThat( + leafTreeNode.getAllChildrenLogGroups().get(1).getCategorization(), + arrayContaining(getTokens("foo", "bart", "bat", "built")) + ); } } diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/aggs/categorization/LogGroupTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/aggs/categorization/TextCategorizationTests.java similarity index 58% rename from x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/aggs/categorization/LogGroupTests.java rename to x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/aggs/categorization/TextCategorizationTests.java index 66bae64ac305a..d18b5f2a68c26 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/aggs/categorization/LogGroupTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/aggs/categorization/TextCategorizationTests.java @@ -8,35 +8,34 @@ package org.elasticsearch.xpack.ml.aggs.categorization; import org.apache.lucene.util.BytesRef; -import org.elasticsearch.core.Tuple; import org.elasticsearch.test.ESTestCase; import static org.hamcrest.Matchers.arrayContaining; import static org.hamcrest.Matchers.closeTo; import static org.hamcrest.Matchers.equalTo; -public class LogGroupTests extends ESTestCase { +public class TextCategorizationTests extends ESTestCase { public void testSimilarity() { - LogGroup lg = new LogGroup(getTokens("foo", "bar", "baz", "biz"), 1, 1); - Tuple sims = lg.calculateSimilarity(getTokens("not", "matching", "anything", "nope")); - assertThat(sims.v1(), equalTo(0.0)); - assertThat(sims.v2(), equalTo(0)); + TextCategorization lg = new TextCategorization(getTokens("foo", "bar", "baz", "biz"), 1, 1); + TextCategorization.Similarity sims = lg.calculateSimilarity(getTokens("not", "matching", "anything", "nope")); + assertThat(sims.getSimilarity(), equalTo(0.0)); + assertThat(sims.getWildCardCount(), equalTo(0)); sims = lg.calculateSimilarity(getTokens("foo", "bar", "baz", "biz")); - assertThat(sims.v1(), equalTo(1.0)); - assertThat(sims.v2(), equalTo(0)); + assertThat(sims.getSimilarity(), equalTo(1.0)); + assertThat(sims.getWildCardCount(), equalTo(0)); sims = lg.calculateSimilarity(getTokens("foo", "fooagain", "notbar", "biz")); - assertThat(sims.v1(), closeTo(0.5, 0.0001)); - assertThat(sims.v2(), equalTo(0)); + assertThat(sims.getSimilarity(), closeTo(0.5, 0.0001)); + assertThat(sims.getWildCardCount(), equalTo(0)); } public void testAddLog() { - LogGroup lg = new LogGroup(getTokens("foo", "bar", "baz", "biz"), 1, 1); + TextCategorization lg = new TextCategorization(getTokens("foo", "bar", "baz", "biz"), 1, 1); lg.addLog(getTokens("foo", "bar", "baz", "bozo"), 2); assertThat(lg.getCount(), equalTo(3L)); - assertThat(lg.getLogEvent(), arrayContaining(getTokens("foo", "bar", "baz", "*"))); + assertThat(lg.getCategorization(), arrayContaining(getTokens("foo", "bar", "baz", "*"))); } static BytesRef[] getTokens(String... tokens) { diff --git a/x-pack/plugin/src/yamlRestTest/resources/rest-api-spec/test/ml/categorization_agg.yml b/x-pack/plugin/src/yamlRestTest/resources/rest-api-spec/test/ml/categorization_agg.yml index 3b69194263f69..c003ffca0e4c7 100644 --- a/x-pack/plugin/src/yamlRestTest/resources/rest-api-spec/test/ml/categorization_agg.yml +++ b/x-pack/plugin/src/yamlRestTest/resources/rest-api-spec/test/ml/categorization_agg.yml @@ -22,20 +22,20 @@ setup: index: to_categorize refresh: true body: | - { "index": {} } - { "product": "server","text": "Node 2 stopping" } - { "index": {} } - { "product": "server", "text": "Node 2 starting"} - { "index": {} } - { "product": "server", "text": "Node 4 stopping"} - { "index": {} } - { "product": "server", "text": "Node 5 stopping"} - { "index": {} } - { "product": "user", "text": "User Foo logging on" } - { "index": {} } - { "product": "user", "text": "User Foo logging on" } - { "index": {} } - { "product": "user", "text": "User Foo logging off" } + {"index": {}} + {"product": "server","text": "Node 2 stopping"} + {"index": {}} + {"product": "server", "text": "Node 2 starting"} + {"index": {}} + {"product": "server", "text": "Node 4 stopping"} + {"index": {}} + {"product": "server", "text": "Node 5 stopping"} + {"index": {}} + {"product": "user", "text": "User Foo logging on"} + {"index": {}} + {"product": "user", "text": "User Foo logging on"} + {"index": {}} + {"product": "user", "text": "User Foo logging off"} --- "Test categorization agg simple": @@ -49,8 +49,7 @@ setup: "aggs": { "categories": { "categorize_text": { - "field": "text", - "size": 10 + "field": "text" } } } From 397c18ee47044c09f489d681e3bd0729a7da120e Mon Sep 17 00:00:00 2001 From: Benjamin Trent <4357155+benwtrent@users.noreply.github.com> Date: Tue, 14 Sep 2021 08:13:57 -0400 Subject: [PATCH 03/20] fixing docs --- .../aggregations/bucket/categorize-text-aggregation.asciidoc | 5 ----- 1 file changed, 5 deletions(-) diff --git a/docs/reference/aggregations/bucket/categorize-text-aggregation.asciidoc b/docs/reference/aggregations/bucket/categorize-text-aggregation.asciidoc index 5aa486b9dcf40..c3639edb6cbe5 100644 --- a/docs/reference/aggregations/bucket/categorize-text-aggregation.asciidoc +++ b/docs/reference/aggregations/bucket/categorize-text-aggregation.asciidoc @@ -112,7 +112,6 @@ Response: } } -------------------------------------------------- -// TESTRESPONSE Here is an example using `categorization_filters` @@ -164,7 +163,6 @@ category results } } -------------------------------------------------- -// TESTRESPONSE Here is an example using `categorization_filters` @@ -214,7 +212,6 @@ and merging the log groups. } } -------------------------------------------------- -// TESTRESPONSE This aggregation can have both sub-aggregations and itself be a sub-aggregation. @@ -392,5 +389,3 @@ POST log-messages/_search?filter_path=aggregations } } -------------------------------------------------- -// TESTRESPONSE - From 0999c36bfe6766621f74494702eae4bcbfed9c43 Mon Sep 17 00:00:00 2001 From: Benjamin Trent <4357155+benwtrent@users.noreply.github.com> Date: Tue, 14 Sep 2021 08:22:39 -0400 Subject: [PATCH 04/20] fixing docs --- .../bucket/categorize-text-aggregation.asciidoc | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/docs/reference/aggregations/bucket/categorize-text-aggregation.asciidoc b/docs/reference/aggregations/bucket/categorize-text-aggregation.asciidoc index c3639edb6cbe5..88e6c15ae8b9e 100644 --- a/docs/reference/aggregations/bucket/categorize-text-aggregation.asciidoc +++ b/docs/reference/aggregations/bucket/categorize-text-aggregation.asciidoc @@ -68,7 +68,7 @@ merging. Example: -[source,console,id=categorize-text-aggregation-example] +[source,console] -------------------------------------------------- POST log-messages/_search?filter_path=aggregations { @@ -116,7 +116,7 @@ Response: Here is an example using `categorization_filters` -[source,console,id=categorize-text-aggregation-with-filters-example] +[source,console] -------------------------------------------------- POST log-messages/_search?filter_path=aggregations { @@ -131,8 +131,10 @@ POST log-messages/_search?filter_path=aggregations } -------------------------------------------------- // TEST[setup:categorize_text] + <1> The filters to apply to the analyzed tokens. It filters out tokens like `bar_123`. + Note how the `foo_` tokens are not part of the category results @@ -166,7 +168,7 @@ category results Here is an example using `categorization_filters` -[source,console,id=categorize-text-aggregation-with-broad-categories-example] +[source,console] -------------------------------------------------- POST log-messages/_search?filter_path=aggregations { @@ -215,7 +217,7 @@ and merging the log groups. This aggregation can have both sub-aggregations and itself be a sub-aggregation. -[source,console,id=categorize-text-aggregation-with-broad-categories-sub-aggs-example] +[source,console] -------------------------------------------------- POST log-messages/_search?filter_path=aggregations { @@ -245,6 +247,7 @@ POST log-messages/_search?filter_path=aggregations } } -------------------------------------------------- + [source,console-result] -------------------------------------------------- { From 580eb3b97fcb3a9d3ca17dcfd0345c48d096973c Mon Sep 17 00:00:00 2001 From: Benjamin Trent <4357155+benwtrent@users.noreply.github.com> Date: Tue, 14 Sep 2021 09:21:44 -0400 Subject: [PATCH 05/20] setting maximum values for aggregation --- .../AggConstructionContentionBenchmark.java | 6 ++ .../CategorizeTextAggregationBuilder.java | 61 ++++++++++++++++--- .../test/ml/categorization_agg.yml | 4 +- 3 files changed, 59 insertions(+), 12 deletions(-) diff --git a/benchmarks/src/main/java/org/elasticsearch/benchmark/search/aggregations/AggConstructionContentionBenchmark.java b/benchmarks/src/main/java/org/elasticsearch/benchmark/search/aggregations/AggConstructionContentionBenchmark.java index f496608e1c273..996ab9dc66850 100644 --- a/benchmarks/src/main/java/org/elasticsearch/benchmark/search/aggregations/AggConstructionContentionBenchmark.java +++ b/benchmarks/src/main/java/org/elasticsearch/benchmark/search/aggregations/AggConstructionContentionBenchmark.java @@ -22,6 +22,7 @@ import org.elasticsearch.core.Releasables; import org.elasticsearch.index.Index; import org.elasticsearch.index.IndexSettings; +import org.elasticsearch.index.analysis.AnalysisRegistry; import org.elasticsearch.index.analysis.NamedAnalyzer; import org.elasticsearch.index.cache.bitset.BitsetFilterCache; import org.elasticsearch.index.fielddata.IndexFieldData; @@ -197,6 +198,11 @@ public long nowInMillis() { return 0; } + @Override + public AnalysisRegistry getAnalysisRegistry() { + return null; + } + @Override protected IndexFieldData buildFieldData(MappedFieldType ft) { IndexFieldDataCache indexFieldDataCache = indicesFieldDataCache.buildIndexFieldDataCache(new IndexFieldDataCache.Listener() { diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/CategorizeTextAggregationBuilder.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/CategorizeTextAggregationBuilder.java index 1f6507039f328..b47dbcd166146 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/CategorizeTextAggregationBuilder.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/CategorizeTextAggregationBuilder.java @@ -27,6 +27,11 @@ import java.util.List; import java.util.Map; +import static org.elasticsearch.search.aggregations.bucket.terms.TermsAggregationBuilder.MIN_DOC_COUNT_FIELD_NAME; +import static org.elasticsearch.search.aggregations.bucket.terms.TermsAggregationBuilder.REQUIRED_SIZE_FIELD_NAME; +import static org.elasticsearch.search.aggregations.bucket.terms.TermsAggregationBuilder.SHARD_MIN_DOC_COUNT_FIELD_NAME; +import static org.elasticsearch.search.aggregations.bucket.terms.TermsAggregationBuilder.SHARD_SIZE_FIELD_NAME; + public class CategorizeTextAggregationBuilder extends AbstractAggregationBuilder { static final TermsAggregator.BucketCountThresholds DEFAULT_BUCKET_COUNT_THRESHOLDS = new TermsAggregator.BucketCountThresholds( @@ -35,6 +40,9 @@ public class CategorizeTextAggregationBuilder extends AbstractAggregationBuilder 10, -1 ); + + static final int MAX_MAX_CHILDREN = 100; + static final int MAX_MAX_DEPTH = 100; public static final String NAME = "categorize_text"; static final ParseField FIELD_NAME = new ParseField("field"); @@ -56,7 +64,7 @@ public class CategorizeTextAggregationBuilder extends AbstractAggregationBuilder PARSER.declareInt(CategorizeTextAggregationBuilder::shardSize, TermsAggregationBuilder.SHARD_SIZE_FIELD_NAME); PARSER.declareLong(CategorizeTextAggregationBuilder::minDocCount, TermsAggregationBuilder.MIN_DOC_COUNT_FIELD_NAME); PARSER.declareLong(CategorizeTextAggregationBuilder::shardMinDocCount, TermsAggregationBuilder.SHARD_MIN_DOC_COUNT_FIELD_NAME); - PARSER.declareInt(CategorizeTextAggregationBuilder::size, TermsAggregationBuilder.REQUIRED_SIZE_FIELD_NAME); + PARSER.declareInt(CategorizeTextAggregationBuilder::size, REQUIRED_SIZE_FIELD_NAME); } public static CategorizeTextAggregationBuilder parse(String aggregationName, XContentParser parser) throws IOException { @@ -107,7 +115,13 @@ public int getMaxChildren() { public CategorizeTextAggregationBuilder setMaxChildren(int maxChildren) { this.maxChildren = maxChildren; if (maxChildren <= 0) { - throw new IllegalArgumentException("[" + MAX_CHILDREN.getPreferredName() + "] must be greater than 0"); + throw ExceptionsHelper.badRequestException( + "[{}] must be greater than 0 and less than [{}]. Found [{}] in [{}]", + MAX_CHILDREN.getPreferredName(), + MAX_MAX_CHILDREN, + maxChildren, + name + ); } return this; } @@ -119,7 +133,12 @@ public double getSimilarityThreshold() { public CategorizeTextAggregationBuilder setSimilarityThreshold(double similarityThreshold) { this.similarityThreshold = similarityThreshold; if (similarityThreshold < 0.1 || similarityThreshold > 1.0) { - throw new IllegalArgumentException("[" + SIMILARITY_THRESHOLD.getPreferredName() + "] must be in the range [0.1, 1.0]"); + throw ExceptionsHelper.badRequestException( + "[{}] must be in the range [0.1, 1.0]. Found [{}] in [{}]", + SIMILARITY_THRESHOLD.getPreferredName(), + similarityThreshold, + name + ); } return this; } @@ -140,7 +159,13 @@ public int getMaxDepth() { public CategorizeTextAggregationBuilder setMaxDepth(int maxDepth) { this.maxDepth = maxDepth; if (maxDepth <= 0) { - throw new IllegalArgumentException("[" + MAX_DEPTH.getPreferredName() + "] must be greater than 0"); + throw ExceptionsHelper.badRequestException( + "[{}] must be greater than 0 and less than [{}]. Found [{}] in [{}]", + MAX_DEPTH.getPreferredName(), + MAX_MAX_DEPTH, + maxDepth, + name + ); } return this; } @@ -150,7 +175,12 @@ public CategorizeTextAggregationBuilder setMaxDepth(int maxDepth) { */ public CategorizeTextAggregationBuilder size(int size) { if (size <= 0) { - throw new IllegalArgumentException("[size] must be greater than 0. Found [" + size + "] in [" + name + "]"); + throw ExceptionsHelper.badRequestException( + "[{}] must be greater than 0. Found [{}] in [{}]", + REQUIRED_SIZE_FIELD_NAME.getPreferredName(), + size, + name + ); } bucketCountThresholds.setRequiredSize(size); return this; @@ -164,7 +194,12 @@ public CategorizeTextAggregationBuilder size(int size) { */ public CategorizeTextAggregationBuilder shardSize(int shardSize) { if (shardSize <= 0) { - throw new IllegalArgumentException("[shardSize] must be greater than 0. Found [" + shardSize + "] in [" + name + "]"); + throw ExceptionsHelper.badRequestException( + "[{}] must be greater than 0. Found [{}] in [{}]", + SHARD_SIZE_FIELD_NAME.getPreferredName(), + shardSize, + name + ); } bucketCountThresholds.setShardSize(shardSize); return this; @@ -176,8 +211,11 @@ public CategorizeTextAggregationBuilder shardSize(int shardSize) { */ public CategorizeTextAggregationBuilder minDocCount(long minDocCount) { if (minDocCount < 0) { - throw new IllegalArgumentException( - "[minDocCount] must be greater than or equal to 0. Found [" + minDocCount + "] in [" + name + "]" + throw ExceptionsHelper.badRequestException( + "[{}] must be greater than or equal to 0. Found [{}] in [{}]", + MIN_DOC_COUNT_FIELD_NAME.getPreferredName(), + minDocCount, + name ); } bucketCountThresholds.setMinDocCount(minDocCount); @@ -190,8 +228,11 @@ public CategorizeTextAggregationBuilder minDocCount(long minDocCount) { */ public CategorizeTextAggregationBuilder shardMinDocCount(long shardMinDocCount) { if (shardMinDocCount < 0) { - throw new IllegalArgumentException( - "[shardMinDocCount] must be greater than or equal to 0. Found [" + shardMinDocCount + "] in [" + name + "]" + throw ExceptionsHelper.badRequestException( + "[{}] must be greater than or equal to 0. Found [{}] in [{}]", + SHARD_MIN_DOC_COUNT_FIELD_NAME.getPreferredName(), + shardMinDocCount, + name ); } bucketCountThresholds.setShardMinDocCount(shardMinDocCount); diff --git a/x-pack/plugin/src/yamlRestTest/resources/rest-api-spec/test/ml/categorization_agg.yml b/x-pack/plugin/src/yamlRestTest/resources/rest-api-spec/test/ml/categorization_agg.yml index c003ffca0e4c7..7ecbb2f353daf 100644 --- a/x-pack/plugin/src/yamlRestTest/resources/rest-api-spec/test/ml/categorization_agg.yml +++ b/x-pack/plugin/src/yamlRestTest/resources/rest-api-spec/test/ml/categorization_agg.yml @@ -86,7 +86,7 @@ setup: "Test categorization aggregation with poor settings": - do: - catch: /\[max_children\] must be greater than 0/ + catch: /\[max_children\] must be greater than 0 and less than \[100\]/ search: index: to_categorize body: > @@ -102,7 +102,7 @@ setup: } } - do: - catch: /\[max_depth\] must be greater than 0/ + catch: /\[max_depth\] must be greater than 0 and less than \[100\]/ search: index: to_categorize body: > From 1b70178804e2574ba7a71764617b1bfc1af2c6d9 Mon Sep 17 00:00:00 2001 From: Benjamin Trent <4357155+benwtrent@users.noreply.github.com> Date: Tue, 14 Sep 2021 11:11:03 -0400 Subject: [PATCH 06/20] adding analyzer etc. --- .../categorize-text-aggregation.asciidoc | 54 +++++++++++--- .../core/ml/job/config/AnalysisConfig.java | 2 +- .../config/CategorizationAnalyzerConfig.java | 6 +- .../CategorizationAggregationIT.java | 2 +- .../CategorizationTokenTree.java | 4 +- .../CategorizeTextAggregationBuilder.java | 71 ++++++++++++++----- .../CategorizeTextAggregator.java | 16 +++-- .../CategorizeTextAggregatorFactory.java | 15 ++-- .../InternalCategorizationAggregation.java | 10 +-- .../ml/aggs/categorization/TreeNode.java | 13 ++-- ...CategorizeTextAggregationBuilderTests.java | 9 ++- .../categorization/InnerTreeNodeTests.java | 4 +- ...nternalCategorizationAggregationTests.java | 2 +- .../categorization/LeafTreeNodeTests.java | 10 +-- .../test/ml/categorization_agg.yml | 24 ++++++- 15 files changed, 170 insertions(+), 72 deletions(-) diff --git a/docs/reference/aggregations/bucket/categorize-text-aggregation.asciidoc b/docs/reference/aggregations/bucket/categorize-text-aggregation.asciidoc index 88e6c15ae8b9e..da8bda3bebdf7 100644 --- a/docs/reference/aggregations/bucket/categorize-text-aggregation.asciidoc +++ b/docs/reference/aggregations/bucket/categorize-text-aggregation.asciidoc @@ -21,30 +21,65 @@ WARNING: Re-analyzing _large_ result sets will require a lot of time and memory. The semi-structured text field to categorize. `max_children`:: -(Optional, integer, default: `100`) +(Optional, integer, default: `50`) The maximum number of unique tokens at any given layer of the tokenization tree. -Must be larger than 1. The smaller the value, the more broad the text categories. -Larger values may cause the aggregation to more memory and run more slowly +Must be larger than 1. Smaller values use less memory and create wider text categories. +Larger values will use more memory and create narrower categories. `max_depth`:: (Optional, integer, default: `5`) The maximum number of tokens matched on before attempting to merge categories. -Larger values may cause the aggregation to more memory and run more slowly. +Larger values will use more memory and create narrower categories. `similarity_threshold`:: -(Optional, double, default: `0.5`) +(Optional, integer, default: `50`) The minimum percentage of tokens that must match for text to be added to the category bucket. -Must be between 0.1 and 1.0. The larger the value the more restrictive the log categories. -Larger values may increase memory usage. +Must be between 1 and 100. The larger the value the narrower the categories. +Larger values will increase memory usage and create narrower categories. `categorization_filters`:: -(Optional, array of strings, default: `[]`) +(Optional, array of strings) This property expects an array of regular expressions. The expressions are used to filter out matching sequences from the categorization field values. You can use this functionality to fine tune the categorization by excluding sequences from consideration when categories are defined. For example, you can -exclude SQL statements that appear in your log files. +exclude SQL statements that appear in your log files. This +property cannot be used at the same time as `categorization_analyzer`. If you +only want to define simple regular expression filters that are applied prior to +tokenization, setting this property is the easiest method. If you also want to +customize the tokenizer or post-tokenization filtering, use the +`categorization_analyzer` property instead and include the filters as +`pattern_replace` character filters. + +`categorization_analyzer`:: +(Optional, (object or string) +The categorization analyzer specifies how the text is analyzed and tokenized before +being categorized. The syntax is very similar to that used to define the `analyzer` in the +<>. This +property cannot be used at the same time as `categorization_filters`. ++ +The `categorization_analyzer` field can be specified either as a string or as an +object. If it is a string it must refer to a +<> or one added by another plugin. If it +is an object it has the following properties: ++ +.Properties of `categorization_analyzer` +[%collapsible%open] +===== +`char_filter`:::: +(array of strings or objects) +include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=char-filter] + +`tokenizer`:::: +(string or object) +include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=tokenizer] + +`filter`:::: +(array of strings or objects) +include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=filter] +===== +end::categorization-analyzer[] `shard_size`:: (Optional, integer) @@ -247,6 +282,7 @@ POST log-messages/_search?filter_path=aggregations } } -------------------------------------------------- +// TEST[setup:categorize_text] [source,console-result] -------------------------------------------------- diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/job/config/AnalysisConfig.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/job/config/AnalysisConfig.java index a27aa10215ef9..6e8cb3845de63 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/job/config/AnalysisConfig.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/job/config/AnalysisConfig.java @@ -736,7 +736,7 @@ private void verifyCategorizationFiltersAreValidRegex() { } } - private static boolean isValidRegex(String exp) { + public static boolean isValidRegex(String exp) { try { Pattern.compile(exp); return true; diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/job/config/CategorizationAnalyzerConfig.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/job/config/CategorizationAnalyzerConfig.java index d9430762df5d7..cedc046b26674 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/job/config/CategorizationAnalyzerConfig.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/job/config/CategorizationAnalyzerConfig.java @@ -86,8 +86,10 @@ public static CategorizationAnalyzerConfig buildFromXContentObject(XContentParse * * The parser is strict when parsing config and lenient when parsing cluster state. */ - static CategorizationAnalyzerConfig buildFromXContentFragment(XContentParser parser, boolean ignoreUnknownFields) throws IOException { - + public static CategorizationAnalyzerConfig buildFromXContentFragment( + XContentParser parser, + boolean ignoreUnknownFields + ) throws IOException { CategorizationAnalyzerConfig.Builder builder = new CategorizationAnalyzerConfig.Builder(); XContentParser.Token token = parser.currentToken(); diff --git a/x-pack/plugin/ml/src/internalClusterTest/java/org/elasticsearch/xpack/ml/integration/CategorizationAggregationIT.java b/x-pack/plugin/ml/src/internalClusterTest/java/org/elasticsearch/xpack/ml/integration/CategorizationAggregationIT.java index 3eba0d46d0516..0f289a0419db0 100644 --- a/x-pack/plugin/ml/src/internalClusterTest/java/org/elasticsearch/xpack/ml/integration/CategorizationAggregationIT.java +++ b/x-pack/plugin/ml/src/internalClusterTest/java/org/elasticsearch/xpack/ml/integration/CategorizationAggregationIT.java @@ -82,7 +82,7 @@ public void testAggregationWithBroadCategories() { .setTrackTotalHits(false) .addAggregation( new CategorizeTextAggregationBuilder("categorize", "msg") - .setSimilarityThreshold(0.11) + .setSimilarityThreshold(11) .setMaxChildren(2) .setMaxDepth(1) .subAggregation(AggregationBuilders.max("max").field("time")) diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/CategorizationTokenTree.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/CategorizationTokenTree.java index dc5b6003200fd..6573433efc31d 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/CategorizationTokenTree.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/CategorizationTokenTree.java @@ -61,13 +61,13 @@ public class CategorizationTokenTree implements Accountable, TreeNodeFactory { private final int maxDepth; private final int maxChildren; - private final double similarityThreshold; + private final int similarityThreshold; private final AtomicLong idGen = new AtomicLong(); // TODO statically allocate an array like DuplicateByteSequenceSpotter ??? private final Map root = new HashMap<>(); private long sizeInBytes; - public CategorizationTokenTree(int maxChildren, int maxDepth, double similarityThreshold) { + public CategorizationTokenTree(int maxChildren, int maxDepth, int similarityThreshold) { assert maxChildren > 0 && maxDepth >= 0; this.maxChildren = maxChildren; this.maxDepth = maxDepth; diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/CategorizeTextAggregationBuilder.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/CategorizeTextAggregationBuilder.java index b47dbcd166146..e4916ca46feaf 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/CategorizeTextAggregationBuilder.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/CategorizeTextAggregationBuilder.java @@ -20,10 +20,11 @@ import org.elasticsearch.search.aggregations.bucket.terms.TermsAggregationBuilder; import org.elasticsearch.search.aggregations.bucket.terms.TermsAggregator; import org.elasticsearch.search.aggregations.support.AggregationContext; +import org.elasticsearch.xpack.core.ml.job.config.CategorizationAnalyzerConfig; +import org.elasticsearch.xpack.core.ml.job.messages.Messages; import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; import java.io.IOException; -import java.util.ArrayList; import java.util.List; import java.util.Map; @@ -31,6 +32,7 @@ import static org.elasticsearch.search.aggregations.bucket.terms.TermsAggregationBuilder.REQUIRED_SIZE_FIELD_NAME; import static org.elasticsearch.search.aggregations.bucket.terms.TermsAggregationBuilder.SHARD_MIN_DOC_COUNT_FIELD_NAME; import static org.elasticsearch.search.aggregations.bucket.terms.TermsAggregationBuilder.SHARD_SIZE_FIELD_NAME; +import static org.elasticsearch.xpack.core.ml.job.config.AnalysisConfig.Builder.isValidRegex; public class CategorizeTextAggregationBuilder extends AbstractAggregationBuilder { @@ -50,6 +52,7 @@ public class CategorizeTextAggregationBuilder extends AbstractAggregationBuilder static final ParseField SIMILARITY_THRESHOLD = new ParseField("similarity_threshold"); static final ParseField MAX_DEPTH = new ParseField("max_depth"); static final ParseField CATEGORIZATION_FILTERS = new ParseField("categorization_filters"); + static final ParseField CATEGORIZATION_ANALYZER = new ParseField("categorization_analyzer"); public static final ObjectParser PARSER = ObjectParser.fromBuilder( CategorizeTextAggregationBuilder.NAME, @@ -59,7 +62,13 @@ public class CategorizeTextAggregationBuilder extends AbstractAggregationBuilder PARSER.declareString(CategorizeTextAggregationBuilder::setFieldName, FIELD_NAME); PARSER.declareInt(CategorizeTextAggregationBuilder::setMaxChildren, MAX_CHILDREN); PARSER.declareInt(CategorizeTextAggregationBuilder::setMaxDepth, MAX_DEPTH); - PARSER.declareDouble(CategorizeTextAggregationBuilder::setSimilarityThreshold, SIMILARITY_THRESHOLD); + PARSER.declareInt(CategorizeTextAggregationBuilder::setSimilarityThreshold, SIMILARITY_THRESHOLD); + PARSER.declareField( + CategorizeTextAggregationBuilder::setCategorizationAnalyzerConfig, + (p, c) -> CategorizationAnalyzerConfig.buildFromXContentFragment(p, false), + CATEGORIZATION_ANALYZER, + ObjectParser.ValueType.OBJECT_OR_STRING + ); PARSER.declareStringArray(CategorizeTextAggregationBuilder::setCategorizationFilters, CATEGORIZATION_FILTERS); PARSER.declareInt(CategorizeTextAggregationBuilder::shardSize, TermsAggregationBuilder.SHARD_SIZE_FIELD_NAME); PARSER.declareLong(CategorizeTextAggregationBuilder::minDocCount, TermsAggregationBuilder.MIN_DOC_COUNT_FIELD_NAME); @@ -74,10 +83,10 @@ public static CategorizeTextAggregationBuilder parse(String aggregationName, XCo private TermsAggregator.BucketCountThresholds bucketCountThresholds = new TermsAggregator.BucketCountThresholds( DEFAULT_BUCKET_COUNT_THRESHOLDS ); - private List categorizationFilters = new ArrayList<>(); + private CategorizationAnalyzerConfig categorizationAnalyzerConfig; private String fieldName; - private int maxChildren = 100; - private double similarityThreshold = 0.5; + private int maxChildren = 50; + private int similarityThreshold = 50; private int maxDepth = 5; private CategorizeTextAggregationBuilder(String name) { @@ -104,8 +113,8 @@ public CategorizeTextAggregationBuilder(StreamInput in) throws IOException { this.fieldName = in.readString(); this.maxChildren = in.readVInt(); this.maxDepth = in.readVInt(); - this.similarityThreshold = in.readDouble(); - this.categorizationFilters = in.readStringList(); + this.similarityThreshold = in.readVInt(); + this.categorizationAnalyzerConfig = in.readOptionalWriteable(CategorizationAnalyzerConfig::new); } public int getMaxChildren() { @@ -130,11 +139,11 @@ public double getSimilarityThreshold() { return similarityThreshold; } - public CategorizeTextAggregationBuilder setSimilarityThreshold(double similarityThreshold) { + public CategorizeTextAggregationBuilder setSimilarityThreshold(int similarityThreshold) { this.similarityThreshold = similarityThreshold; - if (similarityThreshold < 0.1 || similarityThreshold > 1.0) { + if (similarityThreshold < 1 || similarityThreshold > 100) { throw ExceptionsHelper.badRequestException( - "[{}] must be in the range [0.1, 1.0]. Found [{}] in [{}]", + "[{}] must be in the range [1, 100]. Found [{}] in [{}]", SIMILARITY_THRESHOLD.getPreferredName(), similarityThreshold, name @@ -143,12 +152,36 @@ public CategorizeTextAggregationBuilder setSimilarityThreshold(double similarity return this; } - public List getCategorizationFilters() { - return categorizationFilters; + public CategorizeTextAggregationBuilder setCategorizationAnalyzerConfig(CategorizationAnalyzerConfig categorizationAnalyzerConfig) { + this.categorizationAnalyzerConfig = categorizationAnalyzerConfig; + return this; } public CategorizeTextAggregationBuilder setCategorizationFilters(List categorizationFilters) { - this.categorizationFilters = ExceptionsHelper.requireNonNull(categorizationFilters, CATEGORIZATION_FILTERS); + if (categorizationFilters == null || categorizationFilters.isEmpty()) { + return this; + } + if (categorizationAnalyzerConfig != null) { + throw ExceptionsHelper.badRequestException( + "[{}] cannot be used with [{}] - instead specify them as pattern_replace char_filters in the analyzer", + CATEGORIZATION_FILTERS.getPreferredName(), + CATEGORIZATION_ANALYZER.getPreferredName() + ); + } + if (categorizationFilters.stream().distinct().count() != categorizationFilters.size()) { + throw ExceptionsHelper.badRequestException(Messages.JOB_CONFIG_CATEGORIZATION_FILTERS_CONTAINS_DUPLICATES); + } + if (categorizationFilters.stream().anyMatch(String::isEmpty)) { + throw ExceptionsHelper.badRequestException(Messages.getMessage(Messages.JOB_CONFIG_CATEGORIZATION_FILTERS_CONTAINS_EMPTY)); + } + for (String filter : categorizationFilters) { + if (isValidRegex(filter) == false) { + throw ExceptionsHelper.badRequestException( + Messages.getMessage(Messages.JOB_CONFIG_CATEGORIZATION_FILTERS_CONTAINS_INVALID_REGEX, filter) + ); + } + } + this.categorizationAnalyzerConfig = CategorizationAnalyzerConfig.buildStandardCategorizationAnalyzer(categorizationFilters); return this; } @@ -250,7 +283,7 @@ protected CategorizeTextAggregationBuilder( this.maxChildren = clone.maxChildren; this.maxDepth = clone.maxDepth; this.similarityThreshold = clone.similarityThreshold; - this.categorizationFilters = clone.categorizationFilters; + this.categorizationAnalyzerConfig = clone.categorizationAnalyzerConfig; } @Override @@ -259,8 +292,8 @@ protected void doWriteTo(StreamOutput out) throws IOException { out.writeString(fieldName); out.writeVInt(maxChildren); out.writeVInt(maxDepth); - out.writeDouble(similarityThreshold); - out.writeStringCollection(categorizationFilters); + out.writeVInt(similarityThreshold); + out.writeOptionalWriteable(categorizationAnalyzerConfig); } @Override @@ -276,7 +309,7 @@ protected AggregatorFactory doBuild( maxDepth, similarityThreshold, bucketCountThresholds, - categorizationFilters, + categorizationAnalyzerConfig, context, parent, subfactoriesBuilder, @@ -292,8 +325,8 @@ protected XContentBuilder internalXContent(XContentBuilder builder, Params param builder.field(MAX_CHILDREN.getPreferredName(), maxChildren); builder.field(MAX_DEPTH.getPreferredName(), maxDepth); builder.field(SIMILARITY_THRESHOLD.getPreferredName(), similarityThreshold); - if (categorizationFilters.isEmpty() == false) { - builder.field(CATEGORIZATION_FILTERS.getPreferredName(), categorizationFilters); + if (categorizationAnalyzerConfig != null) { + categorizationAnalyzerConfig.toXContent(builder, params); } builder.endObject(); return null; diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/CategorizeTextAggregator.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/CategorizeTextAggregator.java index fb257469d63c9..c9b8033f9ae79 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/CategorizeTextAggregator.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/CategorizeTextAggregator.java @@ -31,10 +31,11 @@ import java.io.IOException; import java.util.ArrayList; import java.util.Arrays; +import java.util.Collections; import java.util.Comparator; import java.util.Iterator; -import java.util.List; import java.util.Map; +import java.util.Optional; public class CategorizeTextAggregator extends DeferableBucketAggregator { @@ -46,7 +47,7 @@ public class CategorizeTextAggregator extends DeferableBucketAggregator { private ObjectArray categorizers; private final int maxChildren; private final int maxDepth; - private final double similarityThreshold; + private final int similarityThreshold; private final LongKeyedBucketOrds bucketOrds; protected CategorizeTextAggregator( @@ -59,18 +60,19 @@ protected CategorizeTextAggregator( TermsAggregator.BucketCountThresholds bucketCountThresholds, int maxChildren, int maxDepth, - double similarityThreshold, - List categorizationFilters, + int similarityThreshold, + CategorizationAnalyzerConfig categorizationAnalyzerConfig, Map metadata ) throws IOException { super(name, factories, context, parent, metadata); this.sourceLookup = context.lookup().source(); this.sourceFieldName = sourceFieldName; this.fieldType = fieldType; - CategorizationAnalyzerConfig categorizationAnalyzerConfig = CategorizationAnalyzerConfig.buildStandardCategorizationAnalyzer( - categorizationFilters + this.analyzer = new CategorizationAnalyzer( + context.getAnalysisRegistry(), + Optional.ofNullable(categorizationAnalyzerConfig) + .orElse(CategorizationAnalyzerConfig.buildStandardCategorizationAnalyzer(Collections.emptyList())) ); - this.analyzer = new CategorizationAnalyzer(context.getAnalysisRegistry(), categorizationAnalyzerConfig); this.categorizers = bigArrays().newObjectArray(1); this.maxChildren = maxChildren; this.maxDepth = maxDepth; diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/CategorizeTextAggregatorFactory.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/CategorizeTextAggregatorFactory.java index 2974dadde4ecf..1103efd0a97a5 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/CategorizeTextAggregatorFactory.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/CategorizeTextAggregatorFactory.java @@ -15,10 +15,9 @@ import org.elasticsearch.search.aggregations.bucket.BucketUtils; import org.elasticsearch.search.aggregations.bucket.terms.TermsAggregator; import org.elasticsearch.search.aggregations.support.AggregationContext; +import org.elasticsearch.xpack.core.ml.job.config.CategorizationAnalyzerConfig; import java.io.IOException; -import java.util.Collections; -import java.util.List; import java.util.Map; public class CategorizeTextAggregatorFactory extends AggregatorFactory { @@ -27,8 +26,8 @@ public class CategorizeTextAggregatorFactory extends AggregatorFactory { private final String indexedFieldName; private final int maxChildren; private final int maxDepth; - private final double similarityThreshold; - private final List categorizationFilters; + private final int similarityThreshold; + private final CategorizationAnalyzerConfig categorizationAnalyzerConfig; private final TermsAggregator.BucketCountThresholds bucketCountThresholds; public CategorizeTextAggregatorFactory( @@ -36,9 +35,9 @@ public CategorizeTextAggregatorFactory( String fieldName, int maxChildren, int maxDepth, - double similarityThreshold, + int similarityThreshold, TermsAggregator.BucketCountThresholds bucketCountThresholds, - List categorizationFilters, + CategorizationAnalyzerConfig categorizationAnalyzerConfig, AggregationContext context, AggregatorFactory parent, AggregatorFactories.Builder subFactoriesBuilder, @@ -54,7 +53,7 @@ public CategorizeTextAggregatorFactory( this.maxChildren = maxChildren; this.maxDepth = maxDepth; this.similarityThreshold = similarityThreshold; - this.categorizationFilters = categorizationFilters == null ? Collections.emptyList() : categorizationFilters; + this.categorizationAnalyzerConfig = categorizationAnalyzerConfig; this.bucketCountThresholds = bucketCountThresholds; } @@ -82,7 +81,7 @@ protected Aggregator createInternal(Aggregator parent, CardinalityUpperBound car maxChildren, maxDepth, similarityThreshold, - categorizationFilters, + categorizationAnalyzerConfig, metadata ); } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/InternalCategorizationAggregation.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/InternalCategorizationAggregation.java index be65ae70711f7..7dbb8235dcebb 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/InternalCategorizationAggregation.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/InternalCategorizationAggregation.java @@ -240,7 +240,7 @@ public int compareTo(Bucket o) { private final List buckets; private final int maxChildren; - private final double similarityThreshold; + private final int similarityThreshold; private final int maxDepth; protected final int requiredSize; protected final long minDocCount; @@ -251,7 +251,7 @@ protected InternalCategorizationAggregation( long minDocCount, int maxChildren, int maxDepth, - double similarityThreshold, + int similarityThreshold, Map metadata ) { this(name, requiredSize, minDocCount, maxChildren, maxDepth, similarityThreshold, metadata, new ArrayList<>()); @@ -263,7 +263,7 @@ protected InternalCategorizationAggregation( long minDocCount, int maxChildren, int maxDepth, - double similarityThreshold, + int similarityThreshold, Map metadata, List buckets ) { @@ -280,7 +280,7 @@ public InternalCategorizationAggregation(StreamInput in) throws IOException { super(in); this.maxChildren = in.readVInt(); this.maxDepth = in.readVInt(); - this.similarityThreshold = in.readDouble(); + this.similarityThreshold = in.readVInt(); this.buckets = in.readList(Bucket::new); this.requiredSize = readSize(in); this.minDocCount = in.readVLong(); @@ -324,7 +324,7 @@ public String getWriteableName() { protected void doWriteTo(StreamOutput out) throws IOException { out.writeVInt(maxChildren); out.writeVInt(maxDepth); - out.writeDouble(similarityThreshold); + out.writeVInt(similarityThreshold); out.writeList(buckets); writeSize(requiredSize, out); out.writeVLong(minDocCount); diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/TreeNode.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/TreeNode.java index 6f7034d2a4623..a67f654fc8c4e 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/TreeNode.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/TreeNode.java @@ -70,12 +70,15 @@ final long getCount() { static class LeafTreeNode extends TreeNode { private final List textCategorizations; - private final double similarityThreshold; + private final int similarityThreshold; - LeafTreeNode(long count, double similarityThreshold) { + LeafTreeNode(long count, int similarityThreshold) { super(count); this.textCategorizations = new ArrayList<>(); this.similarityThreshold = similarityThreshold; + if (similarityThreshold < 1 || similarityThreshold > 100) { + throw new IllegalArgumentException("similarityThreshold must be between 1 and 100"); + } } public boolean isLeaf() { @@ -105,7 +108,7 @@ void mergeWith(TreeNode treeNode) { public long ramBytesUsed() { return Long.BYTES // count + NUM_BYTES_OBJECT_REF // list reference - + Double.BYTES // similarityThreshold + + Integer.BYTES // similarityThreshold + sizeOfCollection(textCategorizations); } @@ -129,7 +132,7 @@ void collapseTinyChildren() {} private Optional getAndUpdateLogGroup(BytesRef[] logTokens, long docCount) { return getBestLogGroup(logTokens).map(bestGroupAndSimilarity -> { - if (bestGroupAndSimilarity.v2() >= similarityThreshold) { + if ((bestGroupAndSimilarity.v2() * 100) >= similarityThreshold) { bestGroupAndSimilarity.v1().addLog(logTokens, docCount); return bestGroupAndSimilarity.v1(); } @@ -173,7 +176,7 @@ public boolean equals(Object o) { if (this == o) return true; if (o == null || getClass() != o.getClass()) return false; LeafTreeNode that = (LeafTreeNode) o; - return Double.compare(that.similarityThreshold, similarityThreshold) == 0 + return that.similarityThreshold == similarityThreshold && Objects.equals(textCategorizations, that.textCategorizations); } diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/aggs/categorization/CategorizeTextAggregationBuilderTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/aggs/categorization/CategorizeTextAggregationBuilderTests.java index 3484e7c90540e..809ae245d13f1 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/aggs/categorization/CategorizeTextAggregationBuilderTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/aggs/categorization/CategorizeTextAggregationBuilderTests.java @@ -10,6 +10,7 @@ import org.elasticsearch.plugins.Plugin; import org.elasticsearch.search.aggregations.BaseAggregationTestCase; import org.elasticsearch.xpack.ml.MachineLearning; +import org.elasticsearch.xpack.ml.job.config.CategorizationAnalyzerConfigTests; import java.util.Collection; import java.util.Collections; @@ -26,9 +27,13 @@ protected Collection> getExtraPlugins() { @Override protected CategorizeTextAggregationBuilder createTestAggregatorBuilder() { CategorizeTextAggregationBuilder builder = new CategorizeTextAggregationBuilder(randomAlphaOfLength(10), randomAlphaOfLength(10)); - if (randomBoolean()) { + final boolean setFilters = randomBoolean(); + if (setFilters) { builder.setCategorizationFilters(Stream.generate(() -> randomAlphaOfLength(10)).limit(5).collect(Collectors.toList())); } + if (setFilters == false) { + builder.setCategorizationAnalyzerConfig(CategorizationAnalyzerConfigTests.createRandomized().build()); + } if (randomBoolean()) { builder.setMaxChildren(randomIntBetween(1, 500)); } @@ -36,7 +41,7 @@ protected CategorizeTextAggregationBuilder createTestAggregatorBuilder() { builder.setMaxDepth(randomIntBetween(1, 10)); } if (randomBoolean()) { - builder.setSimilarityThreshold(randomDoubleBetween(0.1, 1.0, true)); + builder.setSimilarityThreshold(randomIntBetween(1, 100)); } if (randomBoolean()) { builder.minDocCount(randomLongBetween(1, 100)); diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/aggs/categorization/InnerTreeNodeTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/aggs/categorization/InnerTreeNodeTests.java index 2eb1f26aa5770..c9e66e3153c7b 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/aggs/categorization/InnerTreeNodeTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/aggs/categorization/InnerTreeNodeTests.java @@ -17,7 +17,7 @@ public class InnerTreeNodeTests extends ESTestCase { - private final TreeNodeFactory factory = new CategorizationTokenTree(3, 4, 0.6); + private final TreeNodeFactory factory = new CategorizationTokenTree(3, 4, 60); public void testAddLog() { TreeNode.InnerTreeNode innerTreeNode = new TreeNode.InnerTreeNode(1, 0, 3); @@ -91,7 +91,7 @@ public void testMergeWith() { innerTreeNode.incCount(1000); innerTreeNode.addLog(getTokens("foo2", "bar", "baz", "biz"), 1000, factory); - expectThrows(UnsupportedOperationException.class, () -> innerTreeNode.mergeWith(new TreeNode.LeafTreeNode(1, 0.6))); + expectThrows(UnsupportedOperationException.class, () -> innerTreeNode.mergeWith(new TreeNode.LeafTreeNode(1, 60))); TreeNode.InnerTreeNode mergeWith = new TreeNode.InnerTreeNode(1, 0, 3); diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/aggs/categorization/InternalCategorizationAggregationTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/aggs/categorization/InternalCategorizationAggregationTests.java index c3be4b0b8b6bb..749863b336fa2 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/aggs/categorization/InternalCategorizationAggregationTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/aggs/categorization/InternalCategorizationAggregationTests.java @@ -94,7 +94,7 @@ protected InternalCategorizationAggregation createTestInstance( randomLongBetween(1, 10), randomIntBetween(1, 500), randomIntBetween(1, 10), - randomDoubleBetween(0.1, 1.0, true), + randomIntBetween(1, 100), metadata, buckets ); diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/aggs/categorization/LeafTreeNodeTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/aggs/categorization/LeafTreeNodeTests.java index 0cce200fe1526..3c465774ca733 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/aggs/categorization/LeafTreeNodeTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/aggs/categorization/LeafTreeNodeTests.java @@ -17,10 +17,10 @@ public class LeafTreeNodeTests extends ESTestCase { - private final TreeNodeFactory factory = new CategorizationTokenTree(10, 10, 0.6); + private final TreeNodeFactory factory = new CategorizationTokenTree(10, 10, 60); public void testAddGroup() { - TreeNode.LeafTreeNode leafTreeNode = new TreeNode.LeafTreeNode(0, 0.6); + TreeNode.LeafTreeNode leafTreeNode = new TreeNode.LeafTreeNode(0, 60); TextCategorization group = leafTreeNode.addLog(getTokens("foo", "bar", "baz", "biz"), 1, factory); assertThat(group.getCategorization(), arrayContaining(getTokens("foo", "bar", "baz", "biz"))); @@ -44,16 +44,16 @@ public void testAddGroup() { } public void testMergeWith() { - TreeNode.LeafTreeNode leafTreeNode = new TreeNode.LeafTreeNode(0, 0.6); + TreeNode.LeafTreeNode leafTreeNode = new TreeNode.LeafTreeNode(0, 60); leafTreeNode.mergeWith(null); - assertThat(leafTreeNode, equalTo(new TreeNode.LeafTreeNode(0, 0.6))); + assertThat(leafTreeNode, equalTo(new TreeNode.LeafTreeNode(0, 60))); expectThrows(UnsupportedOperationException.class, () -> leafTreeNode.mergeWith(new TreeNode.InnerTreeNode(1, 2, 3))); leafTreeNode.incCount(5); leafTreeNode.addLog(getTokens("foo", "bar", "baz", "biz"), 5, factory); - TreeNode.LeafTreeNode toMerge = new TreeNode.LeafTreeNode(0, 0.6); + TreeNode.LeafTreeNode toMerge = new TreeNode.LeafTreeNode(0, 60); leafTreeNode.incCount(1); toMerge.addLog(getTokens("foo", "bar", "baz", "bizzy"), 1, factory); leafTreeNode.incCount(1); diff --git a/x-pack/plugin/src/yamlRestTest/resources/rest-api-spec/test/ml/categorization_agg.yml b/x-pack/plugin/src/yamlRestTest/resources/rest-api-spec/test/ml/categorization_agg.yml index 7ecbb2f353daf..e6190185fe920 100644 --- a/x-pack/plugin/src/yamlRestTest/resources/rest-api-spec/test/ml/categorization_agg.yml +++ b/x-pack/plugin/src/yamlRestTest/resources/rest-api-spec/test/ml/categorization_agg.yml @@ -71,7 +71,7 @@ setup: "size": 10, "max_children": 2, "max_depth": 1, - "similarity_threshold": 0.11 + "similarity_threshold": 11 } } } @@ -118,7 +118,7 @@ setup: } } - do: - catch: /\[similarity_threshold\] must be in the range \[0.1, 1.0\]/ + catch: /\[similarity_threshold\] must be in the range \[1, 100\]/ search: index: to_categorize body: > @@ -128,7 +128,25 @@ setup: "categories": { "categorize_text": { "field": "text", - "similarity_threshold": 0.0 + "similarity_threshold": 0 + } + } + } + } + + - do: + catch: /\[categorization_filters\] cannot be used with \[categorization_analyzer\]/ + search: + index: to_categorize + body: > + { + "size": 0, + "aggs": { + "categories": { + "categorize_text": { + "field": "text", + "categorization_filters": ["foo"], + "categorization_analyzer": "english" } } } From 0fe077802697c6c696118924c035907a249bf53d Mon Sep 17 00:00:00 2001 From: Benjamin Trent <4357155+benwtrent@users.noreply.github.com> Date: Tue, 14 Sep 2021 11:54:59 -0400 Subject: [PATCH 07/20] fixing docs --- docs/build.gradle | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/docs/build.gradle b/docs/build.gradle index c9e55a651dfa3..236dc20e9440d 100644 --- a/docs/build.gradle +++ b/docs/build.gradle @@ -1080,12 +1080,11 @@ buildRestTests.setups['categorize_text'] = ''' number_of_shards: 1 number_of_replicas: 0 mappings: - metric: - properties: - time: - type: date - message: - type: text + properties: + time: + type: date + message: + type: text - do: bulk: From b4eb65b8a21c76c58ba78e9b25403bab7cadd03b Mon Sep 17 00:00:00 2001 From: Benjamin Trent <4357155+benwtrent@users.noreply.github.com> Date: Tue, 14 Sep 2021 13:10:40 -0400 Subject: [PATCH 08/20] fixing result consistency and docs --- .../categorize-text-aggregation.asciidoc | 34 +++++++++---------- .../InternalCategorizationAggregation.java | 13 +++++-- 2 files changed, 28 insertions(+), 19 deletions(-) diff --git a/docs/reference/aggregations/bucket/categorize-text-aggregation.asciidoc b/docs/reference/aggregations/bucket/categorize-text-aggregation.asciidoc index da8bda3bebdf7..919fd3c178854 100644 --- a/docs/reference/aggregations/bucket/categorize-text-aggregation.asciidoc +++ b/docs/reference/aggregations/bucket/categorize-text-aggregation.asciidoc @@ -132,7 +132,7 @@ Response: }, { "doc_count" : 1, - "key" : "User foo_864 logged off" + "key" : "Node starting up" }, { "doc_count" : 1, @@ -140,7 +140,7 @@ Response: }, { "doc_count" : 1, - "key" : "Node starting up" + "key" : "User foo_864 logged off" } ] } @@ -185,15 +185,15 @@ category results }, { "doc_count" : 1, - "key" : "User logged off" + "key" : "Node starting up" }, { "doc_count" : 1, - "key" : "User logging on" + "key" : "User logged off" }, { "doc_count" : 1, - "key" : "Node starting up" + "key" : "User logging on" } ] } @@ -213,7 +213,7 @@ POST log-messages/_search?filter_path=aggregations "field": "message", "categorization_filters": ["\\w+\\_\\d{3}"], <1> "max_depth": 2, <2> - "similarity_threshold": 0.3 <3> + "similarity_threshold": 30 <3> } } } @@ -309,7 +309,7 @@ POST log-messages/_search?filter_path=aggregations "hits" : [ { "_index" : "log-messages", - "_id" : "DU9q4HsBtGA51sVjTrac", + "_id" : "-u5F5XsBST1JKaSDH8OW", "_score" : 1.0, "_source" : { "message" : "2016-02-07T00:00:00+0000 Node 3 shutting down" @@ -332,7 +332,7 @@ POST log-messages/_search?filter_path=aggregations "hits" : [ { "_index" : "log-messages", - "_id" : "Dk9q4HsBtGA51sVjTrac", + "_id" : "--5F5XsBST1JKaSDH8OW", "_score" : 1.0, "_source" : { "message" : "2016-02-07T00:00:00+0000 Node 5 starting up" @@ -353,7 +353,7 @@ POST log-messages/_search?filter_path=aggregations "buckets" : [ { "doc_count" : 1, - "key" : "User logged off", + "key" : "Node shutting down", "hit" : { "hits" : { "total" : { @@ -364,10 +364,10 @@ POST log-messages/_search?filter_path=aggregations "hits" : [ { "_index" : "log-messages", - "_id" : "Ek9q4HsBtGA51sVjTrac", + "_id" : "_e5F5XsBST1JKaSDH8OW", "_score" : 1.0, "_source" : { - "message" : "2016-02-08T00:00:00+0000 User foo_864 logged off" + "message" : "2016-02-08T00:00:00+0000 Node 5 shutting down" } } ] @@ -376,7 +376,7 @@ POST log-messages/_search?filter_path=aggregations }, { "doc_count" : 1, - "key" : "User logging on", + "key" : "User logged off", "hit" : { "hits" : { "total" : { @@ -387,10 +387,10 @@ POST log-messages/_search?filter_path=aggregations "hits" : [ { "_index" : "log-messages", - "_id" : "EU9q4HsBtGA51sVjTrac", + "_id" : "_-5F5XsBST1JKaSDH8OW", "_score" : 1.0, "_source" : { - "message" : "2016-02-08T00:00:00+0000 User foo_325 logging on" + "message" : "2016-02-08T00:00:00+0000 User foo_864 logged off" } } ] @@ -399,7 +399,7 @@ POST log-messages/_search?filter_path=aggregations }, { "doc_count" : 1, - "key" : "Node shutting down", + "key" : "User logging on", "hit" : { "hits" : { "total" : { @@ -410,10 +410,10 @@ POST log-messages/_search?filter_path=aggregations "hits" : [ { "_index" : "log-messages", - "_id" : "EE9q4HsBtGA51sVjTrac", + "_id" : "_u5F5XsBST1JKaSDH8OW", "_score" : 1.0, "_source" : { - "message" : "2016-02-08T00:00:00+0000 Node 5 shutting down" + "message" : "2016-02-08T00:00:00+0000 User foo_325 logging on" } } ] diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/InternalCategorizationAggregation.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/InternalCategorizationAggregation.java index 7dbb8235dcebb..71eb913e1be43 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/InternalCategorizationAggregation.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/InternalCategorizationAggregation.java @@ -24,6 +24,7 @@ import java.io.IOException; import java.util.ArrayList; import java.util.Arrays; +import java.util.Comparator; import java.util.HashMap; import java.util.List; import java.util.Map; @@ -199,6 +200,10 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws return builder; } + BucketKey getRawKey() { + return key; + } + @Override public Object getKey() { return key; @@ -336,7 +341,7 @@ public InternalAggregation reduce(List aggregations, Reduce // TODO: Could we do a merge sort similar to terms? // It would require us returning partial reductions sorted by key, not by doc_count // First, make sure we have all the counts for equal log groups - Map reduced = new HashMap<>(aggregations.size(), 1.0f); + Map reduced = new HashMap<>(); for (InternalAggregation aggregation : aggregations) { InternalCategorizationAggregation categorizationAggregation = (InternalCategorizationAggregation) aggregation; for (Bucket bucket : categorizationAggregation.buckets) { @@ -350,7 +355,7 @@ public InternalAggregation reduce(List aggregations, Reduce } // Collapse tiny groups together, this may result in new bucket keys for already known buckets categorizationTokenTree.mergeSmallestChildren(); - Map mergedBuckets = new HashMap<>(aggregations.size(), 1.0f); + Map mergedBuckets = new HashMap<>(); for (DelayedCategorizationBucket delayedBucket : reduced.values()) { TextCategorization group = categorizationTokenTree.parseLogLineConst(delayedBucket.key.keyAsTokens()); if (group == null) { @@ -387,6 +392,10 @@ public InternalAggregation reduce(List aggregations, Reduce for (int i = pq.size() - 1; i >= 0; i--) { bucketList[i] = pq.pop(); } + // Keep the top categories top, but then sort by the key for those with duplicate counts + if (reduceContext.isFinalReduce()) { + Arrays.sort(bucketList, Comparator.comparing(Bucket::getDocCount).reversed().thenComparing(Bucket::getRawKey)); + } return new InternalCategorizationAggregation( name, requiredSize, From d81facae27be4e4c1c7891d67af58b147dfea972 Mon Sep 17 00:00:00 2001 From: Benjamin Trent <4357155+benwtrent@users.noreply.github.com> Date: Tue, 14 Sep 2021 13:30:39 -0400 Subject: [PATCH 09/20] more doc fixes --- docs/build.gradle | 24 ++++---- .../categorize-text-aggregation.asciidoc | 56 ++++++++++++------- 2 files changed, 48 insertions(+), 32 deletions(-) diff --git a/docs/build.gradle b/docs/build.gradle index 236dc20e9440d..3ec926257a0f7 100644 --- a/docs/build.gradle +++ b/docs/build.gradle @@ -1091,18 +1091,18 @@ buildRestTests.setups['categorize_text'] = ''' index: log-messages refresh: true body: | - {"index": {}} - {"time":"2016-02-07T00:00:00+0000", "message": "2016-02-07T00:00:00+0000 Node 3 shutting down"} - {"index": {}} - {"time":"2016-02-07T00:00:00+0000", "message": "2016-02-07T00:00:00+0000 Node 5 starting up"} - {"index": {}} - {"time":"2016-02-07T00:00:00+0000", "message": "2016-02-07T00:00:00+0000 Node 4 shutting down"} - {"index": {}} - {"time":"2016-02-08T00:00:00+0000", "message": "2016-02-08T00:00:00+0000 Node 5 shutting down"} - {"index": {}} - {"time":"2016-02-08T00:00:00+0000", "message": "2016-02-08T00:00:00+0000 User foo_325 logging on"} - {"index": {}} - {"time":"2016-02-08T00:00:00+0000", "message": "2016-02-08T00:00:00+0000 User foo_864 logged off"} + {"index": {"_id":"1"}} + {"time":"2016-02-07T00:01:00+0000", "message": "2016-02-07T00:00:00+0000 Node 3 shutting down"} + {"index": {"_id":"2"}} + {"time":"2016-02-07T00:02:00+0000", "message": "2016-02-07T00:00:00+0000 Node 5 starting up"} + {"index": {"_id":"3"}} + {"time":"2016-02-07T00:03:00+0000", "message": "2016-02-07T00:00:00+0000 Node 4 shutting down"} + {"index": {"_id":"4"}} + {"time":"2016-02-08T00:01:00+0000", "message": "2016-02-08T00:00:00+0000 Node 5 shutting down"} + {"index": {"_id":"5"}} + {"time":"2016-02-08T00:02:00+0000", "message": "2016-02-08T00:00:00+0000 User foo_325 logging on"} + {"index": {"_id":"6"}} + {"time":"2016-02-08T00:04:00+0000", "message": "2016-02-08T00:00:00+0000 User foo_864 logged off"} ''' buildRestTests.setups['server_metrics_index'] = ''' - do: diff --git a/docs/reference/aggregations/bucket/categorize-text-aggregation.asciidoc b/docs/reference/aggregations/bucket/categorize-text-aggregation.asciidoc index 919fd3c178854..3aaeeda107f89 100644 --- a/docs/reference/aggregations/bucket/categorize-text-aggregation.asciidoc +++ b/docs/reference/aggregations/bucket/categorize-text-aggregation.asciidoc @@ -272,6 +272,7 @@ POST log-messages/_search?filter_path=aggregations "hit": { "top_hits": { "size": 1, + "sort": ["time"], "_source": "message" } } @@ -305,15 +306,18 @@ POST log-messages/_search?filter_path=aggregations "value" : 2, "relation" : "eq" }, - "max_score" : 1.0, + "max_score" : null, "hits" : [ { "_index" : "log-messages", - "_id" : "-u5F5XsBST1JKaSDH8OW", - "_score" : 1.0, + "_id" : "1", + "_score" : null, "_source" : { "message" : "2016-02-07T00:00:00+0000 Node 3 shutting down" - } + }, + "sort" : [ + 1454803260000 + ] } ] } @@ -328,15 +332,18 @@ POST log-messages/_search?filter_path=aggregations "value" : 1, "relation" : "eq" }, - "max_score" : 1.0, + "max_score" : null, "hits" : [ { "_index" : "log-messages", - "_id" : "--5F5XsBST1JKaSDH8OW", - "_score" : 1.0, + "_id" : "2", + "_score" : null, "_source" : { "message" : "2016-02-07T00:00:00+0000 Node 5 starting up" - } + }, + "sort" : [ + 1454803320000 + ] } ] } @@ -360,15 +367,18 @@ POST log-messages/_search?filter_path=aggregations "value" : 1, "relation" : "eq" }, - "max_score" : 1.0, + "max_score" : null, "hits" : [ { "_index" : "log-messages", - "_id" : "_e5F5XsBST1JKaSDH8OW", - "_score" : 1.0, + "_id" : "4", + "_score" : null, "_source" : { "message" : "2016-02-08T00:00:00+0000 Node 5 shutting down" - } + }, + "sort" : [ + 1454889660000 + ] } ] } @@ -383,15 +393,18 @@ POST log-messages/_search?filter_path=aggregations "value" : 1, "relation" : "eq" }, - "max_score" : 1.0, + "max_score" : null, "hits" : [ { "_index" : "log-messages", - "_id" : "_-5F5XsBST1JKaSDH8OW", - "_score" : 1.0, + "_id" : "6", + "_score" : null, "_source" : { "message" : "2016-02-08T00:00:00+0000 User foo_864 logged off" - } + }, + "sort" : [ + 1454889840000 + ] } ] } @@ -406,15 +419,18 @@ POST log-messages/_search?filter_path=aggregations "value" : 1, "relation" : "eq" }, - "max_score" : 1.0, + "max_score" : null, "hits" : [ { "_index" : "log-messages", - "_id" : "_u5F5XsBST1JKaSDH8OW", - "_score" : 1.0, + "_id" : "5", + "_score" : null, "_source" : { "message" : "2016-02-08T00:00:00+0000 User foo_325 logging on" - } + }, + "sort" : [ + 1454889720000 + ] } ] } From c67812d9f8e9b1b75d817e511cb8567cbae06896 Mon Sep 17 00:00:00 2001 From: Benjamin Trent Date: Wed, 15 Sep 2021 07:19:37 -0400 Subject: [PATCH 10/20] Apply suggestions from code review Co-authored-by: David Roberts --- .../bucket/categorize-text-aggregation.asciidoc | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/docs/reference/aggregations/bucket/categorize-text-aggregation.asciidoc b/docs/reference/aggregations/bucket/categorize-text-aggregation.asciidoc index 3aaeeda107f89..4f054049a03a3 100644 --- a/docs/reference/aggregations/bucket/categorize-text-aggregation.asciidoc +++ b/docs/reference/aggregations/bucket/categorize-text-aggregation.asciidoc @@ -23,8 +23,8 @@ The semi-structured text field to categorize. `max_children`:: (Optional, integer, default: `50`) The maximum number of unique tokens at any given layer of the tokenization tree. -Must be larger than 1. Smaller values use less memory and create wider text categories. -Larger values will use more memory and create narrower categories. +Must be larger than 1. Smaller values use less memory and create fewer categories. +Larger values will use more memory and create more categories. `max_depth`:: (Optional, integer, default: `5`) @@ -53,7 +53,7 @@ customize the tokenizer or post-tokenization filtering, use the `pattern_replace` character filters. `categorization_analyzer`:: -(Optional, (object or string) +(Optional, object or string) The categorization analyzer specifies how the text is analyzed and tokenized before being categorized. The syntax is very similar to that used to define the `analyzer` in the <>. This From a01595fdb31a389aa9953b2e02e1e199aa3461e1 Mon Sep 17 00:00:00 2001 From: Benjamin Trent <4357155+benwtrent@users.noreply.github.com> Date: Wed, 22 Sep 2021 12:44:17 -0400 Subject: [PATCH 11/20] moving to using token IDs and hashes --- .../CategorizationBytesRefHash.java | 73 +++++++++++ .../CategorizationTokenTree.java | 46 +++---- .../CategorizeTextAggregator.java | 11 +- .../CategorizeTextAggregatorFactory.java | 25 +++- .../InternalCategorizationAggregation.java | 24 ++-- .../categorization/TextCategorization.java | 45 +++---- .../ml/aggs/categorization/TreeNode.java | 122 ++++++++---------- .../aggs/categorization/TreeNodeFactory.java | 5 +- .../UnmappedCategorizationAggregation.java | 71 ++++++++++ .../categorization/InnerTreeNodeTests.java | 88 ++++++++----- .../categorization/LeafTreeNodeTests.java | 43 ++++-- .../TextCategorizationTests.java | 44 +++++-- 12 files changed, 402 insertions(+), 195 deletions(-) create mode 100644 x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/CategorizationBytesRefHash.java create mode 100644 x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/UnmappedCategorizationAggregation.java diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/CategorizationBytesRefHash.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/CategorizationBytesRefHash.java new file mode 100644 index 0000000000000..a45ccc4f4c4ef --- /dev/null +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/CategorizationBytesRefHash.java @@ -0,0 +1,73 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.ml.aggs.categorization; + +import org.apache.lucene.util.BytesRef; +import org.elasticsearch.common.util.BytesRefHash; + +import java.io.Closeable; +import java.io.IOException; + +class CategorizationBytesRefHash implements Closeable { + + static final BytesRef WILD_CARD_REF = new BytesRef("*"); + static final long WILD_CARD_ID = -1; + private final BytesRefHash bytesRefHash; + + CategorizationBytesRefHash(BytesRefHash bytesRefHash) { + this.bytesRefHash = bytesRefHash; + } + + BytesRef getShallow(long id) { + if (id == WILD_CARD_ID) { + return WILD_CARD_REF; + } + return bytesRefHash.get(id, new BytesRef()); + } + + Long[] getIds(BytesRef[] tokens) { + Long[] ids = new Long[tokens.length]; + for (int i = 0; i < tokens.length; i++) { + ids[i] = put(tokens[i]); + } + return ids; + } + + BytesRef[] getShallows(Long[] ids) { + BytesRef[] tokens = new BytesRef[ids.length]; + for (int i = 0; i < tokens.length; i++) { + tokens[i] = getShallow(ids[i]); + } + return tokens; + } + + BytesRef getDeep(long id) { + if (id == WILD_CARD_ID) { + return WILD_CARD_REF; + } + BytesRef shallow = bytesRefHash.get(id, new BytesRef()); + return BytesRef.deepCopyOf(shallow); + } + + long put(BytesRef bytesRef) { + if (WILD_CARD_REF.equals(bytesRef)) { + return WILD_CARD_ID; + } + long hash = bytesRefHash.add(bytesRef); + if (hash < 0) { + return -1 - hash; + } else { + return hash; + } + } + + @Override + public void close() throws IOException { + bytesRefHash.close(); + } +} diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/CategorizationTokenTree.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/CategorizationTokenTree.java index 6573433efc31d..2e8e09e43794c 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/CategorizationTokenTree.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/CategorizationTokenTree.java @@ -7,14 +7,11 @@ package org.elasticsearch.xpack.ml.aggs.categorization; -import org.apache.logging.log4j.LogManager; -import org.apache.logging.log4j.Logger; import org.apache.lucene.util.Accountable; import org.apache.lucene.util.BytesRef; import org.apache.lucene.util.RamUsageEstimator; import org.elasticsearch.search.aggregations.InternalAggregations; -import java.util.Arrays; import java.util.HashMap; import java.util.List; import java.util.Map; @@ -56,9 +53,6 @@ */ public class CategorizationTokenTree implements Accountable, TreeNodeFactory { - static final BytesRef WILD_CARD = new BytesRef("*"); - private static final Logger LOGGER = LogManager.getLogger(CategorizationTokenTree.class); - private final int maxDepth; private final int maxChildren; private final int similarityThreshold; @@ -80,10 +74,15 @@ public CategorizationTokenTree(int maxChildren, int maxDepth, int similarityThre + Long.BYTES; // sizeInBytes } - public List toIntermediateBuckets() { + public List toIntermediateBuckets(CategorizationBytesRefHash hash) { return root.values().stream().flatMap(c -> c.getAllChildrenLogGroups().stream()).map(lg -> { + Long[] categoryTokenIds = lg.getCategorization(); + BytesRef[] bytesRefs = new BytesRef[categoryTokenIds.length]; + for (int i = 0; i < categoryTokenIds.length; i++) { + bytesRefs[i] = hash.getShallow(categoryTokenIds[i]); + } InternalCategorizationAggregation.Bucket bucket = new InternalCategorizationAggregation.Bucket( - new InternalCategorizationAggregation.BucketKey(lg.getCategorization()), + new InternalCategorizationAggregation.BucketKey(bytesRefs), lg.getCount(), InternalAggregations.EMPTY ); @@ -96,35 +95,32 @@ void mergeSmallestChildren() { root.values().forEach(TreeNode::collapseTinyChildren); } - public TextCategorization parseLogLine(final BytesRef[] logTokens) { - return parseLogLine(logTokens, 1); + public TextCategorization parseLogLine(final Long[] logTokenIds) { + return parseLogLine(logTokenIds, 1); } - public TextCategorization parseLogLineConst(final BytesRef[] logTokens) { - TreeNode currentNode = this.root.get(logTokens.length); + public TextCategorization parseLogLineConst(final Long[] logTokenIds) { + TreeNode currentNode = this.root.get(logTokenIds.length); if (currentNode == null) { // we are missing an entire sub tree. New log length found return null; } - return currentNode.getLogGroup(logTokens); + return currentNode.getLogGroup(logTokenIds); } - public TextCategorization parseLogLine(final BytesRef[] logTokens, long docCount) { - if (LOGGER.isTraceEnabled()) { - LOGGER.trace("parsing tokens [{}]", Arrays.stream(logTokens).map(BytesRef::utf8ToString).collect(Collectors.joining(" "))); - } - TreeNode currentNode = this.root.get(logTokens.length); + public TextCategorization parseLogLine(final Long[] logTokenIds, long docCount) { + TreeNode currentNode = this.root.get(logTokenIds.length); if (currentNode == null) { // we are missing an entire sub tree. New log length found - currentNode = newNode(docCount, 0, logTokens); - this.root.put(logTokens.length, currentNode); + currentNode = newNode(docCount, 0, logTokenIds); + this.root.put(logTokenIds.length, currentNode); } else { currentNode.incCount(docCount); } - return currentNode.addLog(logTokens, docCount, this); + return currentNode.addLog(logTokenIds, docCount, this); } @Override - public TreeNode newNode(long docCount, int tokenPos, BytesRef[] tokens) { - TreeNode node = tokenPos < maxDepth - 1 && tokenPos < tokens.length + public TreeNode newNode(long docCount, int tokenPos, Long[] logTokenIds) { + TreeNode node = tokenPos < maxDepth - 1 && tokenPos < logTokenIds.length ? new TreeNode.InnerTreeNode(docCount, tokenPos, maxChildren) : new TreeNode.LeafTreeNode(docCount, similarityThreshold); // The size of the node + entry (since it is a map entry) + extra reference for priority queue @@ -133,8 +129,8 @@ public TreeNode newNode(long docCount, int tokenPos, BytesRef[] tokens) { } @Override - public TextCategorization newGroup(long docCount, BytesRef[] logTokens) { - TextCategorization group = new TextCategorization(logTokens, docCount, idGen.incrementAndGet()); + public TextCategorization newGroup(long docCount, Long[] logTokenIds) { + TextCategorization group = new TextCategorization(logTokenIds, docCount, idGen.incrementAndGet()); // Get the regular size bytes from the LogGroup and how much it costs to reference it sizeInBytes += group.ramBytesUsed() + RamUsageEstimator.NUM_BYTES_OBJECT_REF; return group; diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/CategorizeTextAggregator.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/CategorizeTextAggregator.java index c9b8033f9ae79..168a75c10eacb 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/CategorizeTextAggregator.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/CategorizeTextAggregator.java @@ -12,6 +12,7 @@ import org.apache.lucene.index.LeafReaderContext; import org.apache.lucene.util.BytesRef; import org.apache.lucene.util.PriorityQueue; +import org.elasticsearch.common.util.BytesRefHash; import org.elasticsearch.common.util.ObjectArray; import org.elasticsearch.index.mapper.MappedFieldType; import org.elasticsearch.search.aggregations.Aggregator; @@ -49,6 +50,7 @@ public class CategorizeTextAggregator extends DeferableBucketAggregator { private final int maxDepth; private final int similarityThreshold; private final LongKeyedBucketOrds bucketOrds; + private final CategorizationBytesRefHash bytesRefHash; protected CategorizeTextAggregator( String name, @@ -79,6 +81,7 @@ protected CategorizeTextAggregator( this.similarityThreshold = similarityThreshold; this.bucketOrds = LongKeyedBucketOrds.build(bigArrays(), CardinalityUpperBound.MANY); this.bucketCountThresholds = bucketCountThresholds; + this.bytesRefHash = new CategorizationBytesRefHash(new BytesRefHash(1000, bigArrays())); } @Override @@ -96,7 +99,7 @@ public InternalAggregation[] buildAggregations(long[] ordsToCollect) throws IOEx PriorityQueue ordered = new InternalCategorizationAggregation.BucketCountPriorityQueue(size); CategorizationTokenTree categorizationTokenTree = categorizers.get(ordsToCollect[ordIdx]); - for (InternalCategorizationAggregation.Bucket bucket : categorizationTokenTree.toIntermediateBuckets()) { + for (InternalCategorizationAggregation.Bucket bucket : categorizationTokenTree.toIntermediateBuckets(bytesRefHash)) { if (bucket.docCount < bucketCountThresholds.getShardMinDocCount()) { continue; } @@ -166,12 +169,12 @@ private void collectFromSource(int doc, long owningBucketOrd) throws IOException } private void processTokenStream(long owningBucketOrd, TokenStream ts, int doc) throws IOException { - ArrayList tokens = new ArrayList<>(); + ArrayList tokens = new ArrayList<>(); try { CharTermAttribute termAtt = ts.addAttribute(CharTermAttribute.class); ts.reset(); while (ts.incrementToken()) { - tokens.add(new BytesRef(termAtt)); + tokens.add(bytesRefHash.put(new BytesRef(termAtt))); } if (tokens.isEmpty()) { return; @@ -187,7 +190,7 @@ private void processTokenStream(long owningBucketOrd, TokenStream ts, int doc) t categorizers.set(owningBucketOrd, categorizer); } long previousSize = categorizer.ramBytesUsed(); - TextCategorization lg = categorizer.parseLogLine(tokens.toArray(BytesRef[]::new), docCountProvider.getDocCount(doc)); + TextCategorization lg = categorizer.parseLogLine(tokens.toArray(Long[]::new), docCountProvider.getDocCount(doc)); long newSize = categorizer.ramBytesUsed(); if (newSize - previousSize > 0) { addRequestCircuitBreakerBytes(newSize - previousSize); diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/CategorizeTextAggregatorFactory.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/CategorizeTextAggregatorFactory.java index 1103efd0a97a5..ceb3b6c97e4ac 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/CategorizeTextAggregatorFactory.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/CategorizeTextAggregatorFactory.java @@ -12,6 +12,8 @@ import org.elasticsearch.search.aggregations.AggregatorFactories; import org.elasticsearch.search.aggregations.AggregatorFactory; import org.elasticsearch.search.aggregations.CardinalityUpperBound; +import org.elasticsearch.search.aggregations.InternalAggregation; +import org.elasticsearch.search.aggregations.NonCollectingAggregator; import org.elasticsearch.search.aggregations.bucket.BucketUtils; import org.elasticsearch.search.aggregations.bucket.terms.TermsAggregator; import org.elasticsearch.search.aggregations.support.AggregationContext; @@ -48,7 +50,7 @@ public CategorizeTextAggregatorFactory( if (fieldType != null) { this.indexedFieldName = fieldType.name(); } else { - throw new IllegalArgumentException("Only works on indexed fields, cannot find field [" + fieldName + "]"); + this.indexedFieldName = null; } this.maxChildren = maxChildren; this.maxDepth = maxDepth; @@ -57,9 +59,30 @@ public CategorizeTextAggregatorFactory( this.bucketCountThresholds = bucketCountThresholds; } + protected Aggregator createUnmapped(Aggregator parent, Map metadata) throws IOException { + final InternalAggregation aggregation = new UnmappedCategorizationAggregation( + name, + bucketCountThresholds.getRequiredSize(), + bucketCountThresholds.getMinDocCount(), + maxChildren, + maxDepth, + similarityThreshold, + metadata + ); + return new NonCollectingAggregator(name, context, parent, factories, metadata) { + @Override + public InternalAggregation buildEmptyAggregation() { + return aggregation; + } + }; + } + @Override protected Aggregator createInternal(Aggregator parent, CardinalityUpperBound cardinality, Map metadata) throws IOException { + if (fieldType == null) { + return createUnmapped(parent, metadata); + } TermsAggregator.BucketCountThresholds bucketCountThresholds = new TermsAggregator.BucketCountThresholds(this.bucketCountThresholds); if (bucketCountThresholds.getShardSize() == CategorizeTextAggregationBuilder.DEFAULT_BUCKET_COUNT_THRESHOLDS.getShardSize()) { // The user has not made a shardSize selection. Use default diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/InternalCategorizationAggregation.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/InternalCategorizationAggregation.java index 71eb913e1be43..d5ee7c9d65b25 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/InternalCategorizationAggregation.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/InternalCategorizationAggregation.java @@ -12,6 +12,7 @@ import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.common.util.BytesRefHash; import org.elasticsearch.common.xcontent.ToXContentFragment; import org.elasticsearch.common.xcontent.XContentBuilder; import org.elasticsearch.search.aggregations.AggregationExecutionException; @@ -29,7 +30,8 @@ import java.util.List; import java.util.Map; -import static org.elasticsearch.xpack.ml.aggs.categorization.CategorizationTokenTree.WILD_CARD; +import static org.elasticsearch.xpack.ml.aggs.categorization.CategorizationBytesRefHash.WILD_CARD_REF; + public class InternalCategorizationAggregation extends InternalMultiBucketAggregation< InternalCategorizationAggregation, @@ -97,10 +99,10 @@ static BucketKey withCollapsedWildcards(BytesRef[] key) { List collapsedWildCards = new ArrayList<>(); boolean previousTokenWildCard = false; for (BytesRef token : key) { - if (token.equals(WILD_CARD)) { + if (token.equals(WILD_CARD_REF)) { if (previousTokenWildCard == false) { previousTokenWildCard = true; - collapsedWildCards.add(WILD_CARD); + collapsedWildCards.add(WILD_CARD_REF); } } else { previousTokenWildCard = false; @@ -244,9 +246,9 @@ public int compareTo(Bucket o) { } private final List buckets; - private final int maxChildren; - private final int similarityThreshold; - private final int maxDepth; + protected final int maxChildren; + protected final int similarityThreshold; + protected final int maxDepth; protected final int requiredSize; protected final long minDocCount; @@ -337,6 +339,7 @@ protected void doWriteTo(StreamOutput out) throws IOException { @Override public InternalAggregation reduce(List aggregations, ReduceContext reduceContext) { + CategorizationBytesRefHash hash = new CategorizationBytesRefHash(new BytesRefHash(1L, reduceContext.bigArrays())); CategorizationTokenTree categorizationTokenTree = new CategorizationTokenTree(maxChildren, maxDepth, similarityThreshold); // TODO: Could we do a merge sort similar to terms? // It would require us returning partial reductions sorted by key, not by doc_count @@ -351,22 +354,23 @@ public InternalAggregation reduce(List aggregations, Reduce for (DelayedCategorizationBucket bucket : reduced.values()) { // Parse log line takes document count into account and merging on smallest groups - categorizationTokenTree.parseLogLine(bucket.key.keyAsTokens(), bucket.docCount); + categorizationTokenTree.parseLogLine(hash.getIds(bucket.key.keyAsTokens()), bucket.docCount); } // Collapse tiny groups together, this may result in new bucket keys for already known buckets categorizationTokenTree.mergeSmallestChildren(); Map mergedBuckets = new HashMap<>(); for (DelayedCategorizationBucket delayedBucket : reduced.values()) { - TextCategorization group = categorizationTokenTree.parseLogLineConst(delayedBucket.key.keyAsTokens()); + TextCategorization group = categorizationTokenTree.parseLogLineConst(hash.getIds(delayedBucket.key.keyAsTokens())); if (group == null) { throw new AggregationExecutionException( "Unexpected null categorization group for bucket [" + delayedBucket.key.asString() + "]" ); } + BytesRef[] categoryTokens = hash.getShallows(group.getCategorization()); BucketKey key = reduceContext.isFinalReduce() ? - BucketKey.withCollapsedWildcards(group.getCategorization()) : - new BucketKey(group.getCategorization()); + BucketKey.withCollapsedWildcards(categoryTokens) : + new BucketKey(categoryTokens); mergedBuckets.computeIfAbsent(key, k -> new DelayedCategorizationBucket(k, new ArrayList<>(delayedBucket.toReduce.size()), 0L)) .add(delayedBucket); } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/TextCategorization.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/TextCategorization.java index 10d1cc4dbd04d..ddc85bf6e34e4 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/TextCategorization.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/TextCategorization.java @@ -8,13 +8,11 @@ package org.elasticsearch.xpack.ml.aggs.categorization; import org.apache.lucene.util.Accountable; -import org.apache.lucene.util.BytesRef; import org.apache.lucene.util.RamUsageEstimator; import java.util.Arrays; -import java.util.stream.Collectors; -import static org.elasticsearch.xpack.ml.aggs.categorization.CategorizationTokenTree.WILD_CARD; +import static org.elasticsearch.xpack.ml.aggs.categorization.CategorizationBytesRefHash.WILD_CARD_ID; /** * A text categorization group that provides methods for: @@ -24,30 +22,19 @@ class TextCategorization implements Accountable { private final long id; - private final BytesRef[] categorization; + // TODO Do we want to just make this native arrays? + private final Long[] categorization; private final long[] tokenCounts; private long count; // Used at the shard level for tracking the bucket ordinal for collecting sub aggregations long bucketOrd; - @Override - public String toString() { - return "LogGroup{" - + "id=" - + id - + ", logEvent=" - + Arrays.stream(categorization).map(BytesRef::utf8ToString).collect(Collectors.joining(", ", "[", "]")) - + ", count=" - + count - + '}'; - } - - TextCategorization(BytesRef[] logTokens, long count, long id) { + TextCategorization(Long[] logTokenIds, long count, long id) { this.id = id; - this.categorization = logTokens; + this.categorization = logTokenIds; this.count = count; - this.tokenCounts = new long[logTokens.length]; + this.tokenCounts = new long[logTokenIds.length]; Arrays.fill(this.tokenCounts, count); } @@ -55,7 +42,7 @@ public long getId() { return id; } - BytesRef[] getCategorization() { + Long[] getCategorization() { return categorization; } @@ -63,7 +50,7 @@ public long getCount() { return count; } - Similarity calculateSimilarity(BytesRef[] logEvent) { + Similarity calculateSimilarity(Long[] logEvent) { assert logEvent.length == this.categorization.length; int eqParams = 0; long tokenCount = 0; @@ -72,7 +59,7 @@ Similarity calculateSimilarity(BytesRef[] logEvent) { if (logEvent[i].equals(this.categorization[i])) { tokensKept += tokenCounts[i]; tokenCount += tokenCounts[i]; - } else if (this.categorization[i].equals(WILD_CARD)) { + } else if (this.categorization[i].equals(WILD_CARD_ID)) { eqParams++; } else { tokenCount += tokenCounts[i]; @@ -81,11 +68,11 @@ Similarity calculateSimilarity(BytesRef[] logEvent) { return new Similarity((double) tokensKept / tokenCount, eqParams); } - void addLog(BytesRef[] logEvent, long docCount) { + void addLog(Long[] logEvent, long docCount) { assert logEvent.length == this.categorization.length; for (int i = 0; i < logEvent.length; i++) { if (logEvent[i].equals(this.categorization[i]) == false) { - this.categorization[i] = WILD_CARD; + this.categorization[i] = WILD_CARD_ID; } else { tokenCounts[i] += docCount; } @@ -96,11 +83,11 @@ void addLog(BytesRef[] logEvent, long docCount) { @Override public long ramBytesUsed() { return Long.BYTES // id - + (long) categorization.length * RamUsageEstimator.NUM_BYTES_ARRAY_HEADER // logEvent - + RamUsageEstimator.NUM_BYTES_OBJECT_REF - + ((long) categorization.length * Long.BYTES) - + RamUsageEstimator.NUM_BYTES_ARRAY_HEADER - + RamUsageEstimator.NUM_BYTES_OBJECT_REF + Long.BYTES; // count + + RamUsageEstimator.NUM_BYTES_OBJECT_REF // categorization reference + + RamUsageEstimator.shallowSizeOf(categorization) // categorization we don't deep copy the token ids + + RamUsageEstimator.NUM_BYTES_OBJECT_REF // tokenCounts reference + + RamUsageEstimator.sizeOf(tokenCounts) // tokenCounts + + Long.BYTES; // count } static class Similarity implements Comparable { diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/TreeNode.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/TreeNode.java index a67f654fc8c4e..f9eacf469843c 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/TreeNode.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/TreeNode.java @@ -9,9 +9,7 @@ import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; -import org.apache.logging.log4j.message.ParameterizedMessage; import org.apache.lucene.util.Accountable; -import org.apache.lucene.util.BytesRef; import org.elasticsearch.core.Tuple; import org.elasticsearch.search.aggregations.AggregationExecutionException; @@ -28,7 +26,7 @@ import static org.apache.lucene.util.RamUsageEstimator.NUM_BYTES_OBJECT_REF; import static org.apache.lucene.util.RamUsageEstimator.sizeOfCollection; import static org.apache.lucene.util.RamUsageEstimator.sizeOfMap; -import static org.elasticsearch.xpack.ml.aggs.categorization.CategorizationTokenTree.WILD_CARD; +import static org.elasticsearch.xpack.ml.aggs.categorization.CategorizationBytesRefHash.WILD_CARD_ID; /** * Tree node classes for the categorization token tree. @@ -60,9 +58,9 @@ final long getCount() { } // TODO add option for calculating the cost of adding the new group - abstract TextCategorization addLog(BytesRef[] logTokens, long docCount, TreeNodeFactory treeNodeFactory); + abstract TextCategorization addLog(Long[] logTokenIds, long docCount, TreeNodeFactory treeNodeFactory); - abstract TextCategorization getLogGroup(BytesRef[] logTokens); + abstract TextCategorization getLogGroup(Long[] logTokens); abstract List getAllChildrenLogGroups(); @@ -113,12 +111,10 @@ public long ramBytesUsed() { } @Override - public TextCategorization addLog(BytesRef[] logTokens, long docCount, TreeNodeFactory treeNodeFactory) { - return getAndUpdateLogGroup(logTokens, docCount).orElseGet(() -> { + public TextCategorization addLog(Long[] logTokenIds, long docCount, TreeNodeFactory treeNodeFactory) { + return getAndUpdateLogGroup(logTokenIds, docCount).orElseGet(() -> { // Need to update the tree if possible - TextCategorization group = treeNodeFactory.newGroup(docCount, logTokens); - LOGGER.trace(() -> new ParameterizedMessage("created group! [{}]", group)); - return putNewLogGroup(group); + return putNewLogGroup(treeNodeFactory.newGroup(docCount, logTokenIds)); }); } @@ -130,10 +126,10 @@ List getAllChildrenLogGroups() { @Override void collapseTinyChildren() {} - private Optional getAndUpdateLogGroup(BytesRef[] logTokens, long docCount) { - return getBestLogGroup(logTokens).map(bestGroupAndSimilarity -> { + private Optional getAndUpdateLogGroup(Long[] logTokenIds, long docCount) { + return getBestLogGroup(logTokenIds).map(bestGroupAndSimilarity -> { if ((bestGroupAndSimilarity.v2() * 100) >= similarityThreshold) { - bestGroupAndSimilarity.v1().addLog(logTokens, docCount); + bestGroupAndSimilarity.v1().addLog(logTokenIds, docCount); return bestGroupAndSimilarity.v1(); } return null; @@ -145,19 +141,19 @@ TextCategorization putNewLogGroup(TextCategorization group) { return group; } - private Optional> getBestLogGroup(BytesRef[] logTokens) { + private Optional> getBestLogGroup(Long[] logTokenIds) { if (textCategorizations.isEmpty()) { return Optional.empty(); } if (textCategorizations.size() == 1) { return Optional.of( - new Tuple<>(textCategorizations.get(0), textCategorizations.get(0).calculateSimilarity(logTokens).getSimilarity()) + new Tuple<>(textCategorizations.get(0), textCategorizations.get(0).calculateSimilarity(logTokenIds).getSimilarity()) ); } TextCategorization.Similarity maxSimilarity = null; TextCategorization bestGroup = null; for (TextCategorization textCategorization : this.textCategorizations) { - TextCategorization.Similarity groupSimilarity = textCategorization.calculateSimilarity(logTokens); + TextCategorization.Similarity groupSimilarity = textCategorization.calculateSimilarity(logTokenIds); if (maxSimilarity == null || groupSimilarity.compareTo(maxSimilarity) > 0) { maxSimilarity = groupSimilarity; bestGroup = textCategorization; @@ -167,8 +163,8 @@ private Optional> getBestLogGroup(BytesRef[] l } @Override - public TextCategorization getLogGroup(final BytesRef[] logTokens) { - return getBestLogGroup(logTokens).map(Tuple::v1).orElse(null); + public TextCategorization getLogGroup(final Long[] logTokenIds) { + return getBestLogGroup(logTokenIds).map(Tuple::v1).orElse(null); } @Override @@ -188,10 +184,11 @@ public int hashCode() { static class InnerTreeNode extends TreeNode { - private final Map children; + // TODO: Change to LongObjectMap? + private final Map children; private final int childrenTokenPos; private final int maxChildren; - private final PriorityQueue> smallestChild; + private final PriorityQueue> smallestChild; InnerTreeNode(long count, int childrenTokenPos, int maxChildren) { super(count); @@ -206,9 +203,9 @@ boolean isLeaf() { } @Override - public TextCategorization getLogGroup(final BytesRef[] logTokens) { - return getChild(logTokens[childrenTokenPos]).or(() -> getChild(WILD_CARD)) - .map(node -> node.getLogGroup(logTokens)) + public TextCategorization getLogGroup(final Long[] logTokenIds) { + return getChild(logTokenIds[childrenTokenPos]).or(() -> getChild(WILD_CARD_ID)) + .map(node -> node.getLogGroup(logTokenIds)) .orElse(null); } @@ -219,14 +216,14 @@ public long ramBytesUsed() { + Integer.BYTES // childrenTokenPos + Integer.BYTES // maxChildren + NUM_BYTES_OBJECT_REF // smallestChildReference - + sizeOfMap(children, 0) + + sizeOfMap(children, NUM_BYTES_OBJECT_REF) // children, // Number of items in the queue, reference to tuple, and then the tuple references + (long) smallestChild.size() * (NUM_BYTES_OBJECT_REF + NUM_BYTES_OBJECT_REF + NUM_BYTES_OBJECT_REF + Long.BYTES); } @Override - public TextCategorization addLog(final BytesRef[] logTokens, final long docCount, final TreeNodeFactory treeNodeFactory) { - BytesRef currentToken = logTokens[childrenTokenPos]; + public TextCategorization addLog(final Long[] logTokenIds, final long docCount, final TreeNodeFactory treeNodeFactory) { + Long currentToken = logTokenIds[childrenTokenPos]; TreeNode child = getChild(currentToken).map(node -> { node.incCount(docCount); if (smallestChild.isEmpty() == false && smallestChild.peek().v1().equals(currentToken)) { @@ -234,20 +231,10 @@ public TextCategorization addLog(final BytesRef[] logTokens, final long docCount } return node; }).orElseGet(() -> { - if (docCount > 1) { - LOGGER.trace( - () -> new ParameterizedMessage( - "got a token [{}] with doc_count [{}] percentage [{}]", - logTokens[childrenTokenPos].utf8ToString(), - docCount, - (double) docCount / docCount - ) - ); - } - TreeNode newNode = treeNodeFactory.newNode(docCount, childrenTokenPos + 1, logTokens); + TreeNode newNode = treeNodeFactory.newNode(docCount, childrenTokenPos + 1, logTokenIds); return addChild(currentToken, newNode); }); - return child.addLog(logTokens, docCount, treeNodeFactory); + return child.addLog(logTokenIds, docCount, treeNodeFactory); } @Override @@ -258,16 +245,16 @@ void collapseTinyChildren() { if (children.size() <= 1) { return; } - Optional maybeWildChild = getChild(WILD_CARD).or(() -> { + Optional maybeWildChild = getChild(WILD_CARD_ID).or(() -> { if ((double) smallestChild.peek().v2() / this.getCount() <= 1.0 / maxChildren) { TreeNode tinyChild = children.remove(smallestChild.poll().v1()); - return Optional.of(addChild(WILD_CARD, tinyChild)); + return Optional.of(addChild(WILD_CARD_ID, tinyChild)); } return Optional.empty(); }); if (maybeWildChild.isPresent()) { TreeNode wildChild = maybeWildChild.get(); - Tuple tinyNode; + Tuple tinyNode; while ((tinyNode = smallestChild.poll()) != null) { // If we have no more tiny nodes, stop iterating over them if ((double) tinyNode.v2() / this.getCount() > 1.0 / maxChildren) { @@ -293,24 +280,24 @@ void mergeWith(TreeNode treeNode) { ); } InnerTreeNode innerTreeNode = (InnerTreeNode) treeNode; - TreeNode siblingWildChild = innerTreeNode.children.remove(WILD_CARD); - addChild(WILD_CARD, siblingWildChild); - Tuple siblingChild; + TreeNode siblingWildChild = innerTreeNode.children.remove(WILD_CARD_ID); + addChild(WILD_CARD_ID, siblingWildChild); + Tuple siblingChild; while ((siblingChild = innerTreeNode.smallestChild.poll()) != null) { TreeNode nephewNode = innerTreeNode.children.remove(siblingChild.v1()); addChild(siblingChild.v1(), nephewNode); } } - private TreeNode addChild(BytesRef token, TreeNode node) { - if (node == null || token == null) { + private TreeNode addChild(Long tokenId, TreeNode node) { + if (node == null || tokenId == null) { return null; } - Optional existingChild = getChild(token).map(existingNode -> { + Optional existingChild = getChild(tokenId).map(existingNode -> { existingNode.mergeWith(node); - if (smallestChild.isEmpty() == false && smallestChild.peek().v1().equals(token)) { + if (smallestChild.isEmpty() == false && smallestChild.peek().v1().equals(tokenId)) { smallestChild.poll(); - smallestChild.add(Tuple.tuple(token, existingNode.getCount())); + smallestChild.add(Tuple.tuple(tokenId, existingNode.getCount())); } return existingNode; }); @@ -318,12 +305,12 @@ private TreeNode addChild(BytesRef token, TreeNode node) { return existingChild.get(); } if (children.size() == maxChildren) { - return getChild(WILD_CARD).map(wildChild -> { + return getChild(WILD_CARD_ID).map(wildChild -> { final TreeNode toMerge; final TreeNode toReturn; if (smallestChild.isEmpty() == false && node.getCount() > smallestChild.peek().v2()) { toMerge = children.remove(smallestChild.poll().v1()); - addChildAndUpdateSmallest(token, node); + addChildAndUpdateSmallest(tokenId, node); toReturn = node; } else { toMerge = node; @@ -336,44 +323,43 @@ private TreeNode addChild(BytesRef token, TreeNode node) { // we are about to hit the limit, add a wild card if we need to and then add the new child as appropriate if (children.size() == maxChildren - 1) { // If we already have a wild token, simply adding the new token is acceptable as we won't breach our limit - if (children.containsKey(WILD_CARD)) { - addChildAndUpdateSmallest(token, node); + if (children.containsKey(WILD_CARD_ID)) { + addChildAndUpdateSmallest(tokenId, node); } else { // if we don't have a wild card child, we need to add one now - if (token.equals(WILD_CARD)) { - addChildAndUpdateSmallest(token, node); + if (tokenId.equals(WILD_CARD_ID)) { + addChildAndUpdateSmallest(tokenId, node); } else { if (smallestChild.isEmpty() == false && node.count > smallestChild.peek().v2()) { - addChildAndUpdateSmallest(WILD_CARD, children.remove(smallestChild.poll().v1())); - addChildAndUpdateSmallest(token, node); + addChildAndUpdateSmallest(WILD_CARD_ID, children.remove(smallestChild.poll().v1())); + addChildAndUpdateSmallest(tokenId, node); } else { - addChildAndUpdateSmallest(WILD_CARD, node); + addChildAndUpdateSmallest(WILD_CARD_ID, node); } } } } else { - addChildAndUpdateSmallest(token, node); + addChildAndUpdateSmallest(tokenId, node); } return node; } - private void addChildAndUpdateSmallest(BytesRef token, TreeNode node) { - children.put(token, node); - if (token.equals(WILD_CARD) == false) { - smallestChild.add(Tuple.tuple(token, node.count)); + private void addChildAndUpdateSmallest(Long tokenId, TreeNode node) { + children.put(tokenId, node); + if (tokenId.equals(WILD_CARD_ID) == false) { + smallestChild.add(Tuple.tuple(tokenId, node.count)); } } - private Optional getChild(BytesRef token) { - TreeNode node = children.get(token); - return node == null ? Optional.empty() : Optional.of(node); + private Optional getChild(Long tokenId) { + return Optional.ofNullable(children.get(tokenId)); } public List getAllChildrenLogGroups() { return children.values().stream().flatMap(c -> c.getAllChildrenLogGroups().stream()).collect(Collectors.toList()); } - boolean hasChild(BytesRef value) { - return children.containsKey(value); + boolean hasChild(Long tokenId) { + return children.containsKey(tokenId); } @Override diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/TreeNodeFactory.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/TreeNodeFactory.java index eade0ecad3240..548a2fdb8e803 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/TreeNodeFactory.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/TreeNodeFactory.java @@ -7,10 +7,9 @@ package org.elasticsearch.xpack.ml.aggs.categorization; -import org.apache.lucene.util.BytesRef; interface TreeNodeFactory { - TreeNode newNode(long docCount, int tokenPos, BytesRef[] logTokens); + TreeNode newNode(long docCount, int tokenPos, Long[] logTokenIds); - TextCategorization newGroup(long docCount, BytesRef[] logTokens); + TextCategorization newGroup(long docCount, Long[] logTokenIds); } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/UnmappedCategorizationAggregation.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/UnmappedCategorizationAggregation.java new file mode 100644 index 0000000000000..113399a833959 --- /dev/null +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/UnmappedCategorizationAggregation.java @@ -0,0 +1,71 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.ml.aggs.categorization; + +import org.elasticsearch.search.aggregations.InternalAggregation; +import org.elasticsearch.search.aggregations.InternalAggregations; + +import java.util.List; +import java.util.Map; + + +class UnmappedCategorizationAggregation extends InternalCategorizationAggregation { + protected UnmappedCategorizationAggregation( + String name, + int requiredSize, + long minDocCount, + int maxChildren, + int maxDepth, + int similarityThreshold, + Map metadata + ) { + super(name, requiredSize, minDocCount, maxChildren, maxDepth, similarityThreshold, metadata); + } + + @Override + public InternalCategorizationAggregation create(List buckets) { + return new UnmappedCategorizationAggregation( + name, + requiredSize, + minDocCount, + maxChildren, + maxDepth, + similarityThreshold, + super.metadata + ); + } + + @Override + public Bucket createBucket(InternalAggregations aggregations, Bucket prototype) { + throw new UnsupportedOperationException("not supported for UnmappedCategorizationAggregation"); + } + + @Override + public InternalAggregation reduce(List aggregations, ReduceContext reduceContext) { + return new UnmappedCategorizationAggregation( + name, + requiredSize, + minDocCount, + maxChildren, + maxDepth, + similarityThreshold, + super.metadata + ); + } + + @Override + public boolean isMapped() { + return false; + } + + @Override + public List getBuckets() { + return List.of(); + } + +} diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/aggs/categorization/InnerTreeNodeTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/aggs/categorization/InnerTreeNodeTests.java index c9e66e3153c7b..bb8d62f1aaeba 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/aggs/categorization/InnerTreeNodeTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/aggs/categorization/InnerTreeNodeTests.java @@ -8,9 +8,16 @@ package org.elasticsearch.xpack.ml.aggs.categorization; import org.apache.lucene.util.BytesRef; +import org.elasticsearch.common.util.BytesRefHash; import org.elasticsearch.test.ESTestCase; +import org.junit.After; +import org.junit.Before; +import java.io.IOException; + +import static org.elasticsearch.xpack.ml.aggs.categorization.CategorizationBytesRefHash.WILD_CARD_ID; import static org.elasticsearch.xpack.ml.aggs.categorization.TextCategorizationTests.getTokens; +import static org.elasticsearch.xpack.ml.aggs.categorization.TextCategorizationTests.mockBigArrays; import static org.hamcrest.Matchers.arrayContaining; import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.is; @@ -18,92 +25,103 @@ public class InnerTreeNodeTests extends ESTestCase { private final TreeNodeFactory factory = new CategorizationTokenTree(3, 4, 60); + private CategorizationBytesRefHash bytesRefHash; + + @Before + public void createRefHash() { + bytesRefHash = new CategorizationBytesRefHash(new BytesRefHash(1L, mockBigArrays())); + } + + @After + public void closeRefHash() throws IOException { + bytesRefHash.close(); + } public void testAddLog() { TreeNode.InnerTreeNode innerTreeNode = new TreeNode.InnerTreeNode(1, 0, 3); - TextCategorization group = innerTreeNode.addLog(getTokens("foo", "bar", "baz", "biz"), 1, factory); - assertThat(group.getCategorization(), arrayContaining(getTokens("foo", "bar", "baz", "biz"))); + TextCategorization group = innerTreeNode.addLog(getTokens(bytesRefHash, "foo", "bar", "baz", "biz"), 1, factory); + assertThat(group.getCategorization(), arrayContaining(getTokens(bytesRefHash, "foo", "bar", "baz", "biz"))); assertThat( - innerTreeNode.addLog(getTokens("foo2", "bar", "baz", "biz"), 1, factory).getCategorization(), - arrayContaining(getTokens("foo2", "bar", "baz", "biz")) + innerTreeNode.addLog(getTokens(bytesRefHash, "foo2", "bar", "baz", "biz"), 1, factory).getCategorization(), + arrayContaining(getTokens(bytesRefHash, "foo2", "bar", "baz", "biz")) ); assertThat( - innerTreeNode.addLog(getTokens("foo3", "bar", "baz", "biz"), 1, factory).getCategorization(), - arrayContaining(getTokens("foo3", "bar", "baz", "biz")) + innerTreeNode.addLog(getTokens(bytesRefHash, "foo3", "bar", "baz", "biz"), 1, factory).getCategorization(), + arrayContaining(getTokens(bytesRefHash, "foo3", "bar", "baz", "biz")) ); assertThat( - innerTreeNode.addLog(getTokens("foo4", "bar", "baz", "biz"), 1, factory).getCategorization(), - arrayContaining(getTokens("*", "bar", "baz", "biz")) + innerTreeNode.addLog(getTokens(bytesRefHash, "foo4", "bar", "baz", "biz"), 1, factory).getCategorization(), + arrayContaining(getTokens(bytesRefHash, "*", "bar", "baz", "biz")) ); assertThat( - innerTreeNode.addLog(getTokens("foo", "bar", "baz", "bizzy"), 1, factory).getCategorization(), - arrayContaining(getTokens("foo", "bar", "baz", "*")) + innerTreeNode.addLog(getTokens(bytesRefHash, "foo", "bar", "baz", "bizzy"), 1, factory).getCategorization(), + arrayContaining(getTokens(bytesRefHash, "foo", "bar", "baz", "*")) ); } public void testAddLogWithLargerIncoming() { TreeNode.InnerTreeNode innerTreeNode = new TreeNode.InnerTreeNode(1, 0, 3); - TextCategorization group = innerTreeNode.addLog(getTokens("foo", "bar", "baz", "biz"), 100, factory); - assertThat(group.getCategorization(), arrayContaining(getTokens("foo", "bar", "baz", "biz"))); + TextCategorization group = innerTreeNode.addLog(getTokens(bytesRefHash, "foo", "bar", "baz", "biz"), 100, factory); + assertThat(group.getCategorization(), arrayContaining(getTokens(bytesRefHash, "foo", "bar", "baz", "biz"))); assertThat( - innerTreeNode.addLog(getTokens("foo2", "bar", "baz", "biz"), 100, factory).getCategorization(), - arrayContaining(getTokens("foo2", "bar", "baz", "biz")) + innerTreeNode.addLog(getTokens(bytesRefHash, "foo2", "bar", "baz", "biz"), 100, factory).getCategorization(), + arrayContaining(getTokens(bytesRefHash, "foo2", "bar", "baz", "biz")) ); assertThat( - innerTreeNode.addLog(getTokens("foosmall", "bar", "baz", "biz"), 1, factory).getCategorization(), - arrayContaining(getTokens("foosmall", "bar", "baz", "biz")) + innerTreeNode.addLog(getTokens(bytesRefHash, "foosmall", "bar", "baz", "biz"), 1, factory).getCategorization(), + arrayContaining(getTokens(bytesRefHash, "foosmall", "bar", "baz", "biz")) ); assertThat( - innerTreeNode.addLog(getTokens("foobigun", "bar", "baz", "biz"), 1000, factory).getCategorization(), - arrayContaining(getTokens("foobigun", "bar", "baz", "biz")) + innerTreeNode.addLog(getTokens(bytesRefHash, "foobigun", "bar", "baz", "biz"), 1000, factory).getCategorization(), + arrayContaining(getTokens(bytesRefHash, "foobigun", "bar", "baz", "biz")) ); assertThat( - innerTreeNode.getLogGroup(getTokens("foosmall", "bar", "baz", "biz")).getCategorization(), - equalTo(getTokens("*", "bar", "baz", "biz")) + innerTreeNode.getLogGroup(getTokens(bytesRefHash, "foosmall", "bar", "baz", "biz")).getCategorization(), + equalTo(getTokens(bytesRefHash, "*", "bar", "baz", "biz")) ); } public void testCollapseTinyChildren() { TreeNode.InnerTreeNode innerTreeNode = new TreeNode.InnerTreeNode(1000, 0, 4); - TextCategorization group = innerTreeNode.addLog(getTokens("foo", "bar", "baz", "biz"), 1000, factory); - assertThat(group.getCategorization(), arrayContaining(getTokens("foo", "bar", "baz", "biz"))); + TextCategorization group = innerTreeNode.addLog(getTokens(bytesRefHash, "foo", "bar", "baz", "biz"), 1000, factory); + assertThat(group.getCategorization(), arrayContaining(getTokens(bytesRefHash, "foo", "bar", "baz", "biz"))); assertThat( - innerTreeNode.addLog(getTokens("foo2", "bar", "baz", "biz"), 1000, factory).getCategorization(), - arrayContaining(getTokens("foo2", "bar", "baz", "biz")) + innerTreeNode.addLog(getTokens(bytesRefHash, "foo2", "bar", "baz", "biz"), 1000, factory).getCategorization(), + arrayContaining(getTokens(bytesRefHash, "foo2", "bar", "baz", "biz")) ); innerTreeNode.incCount(1000); assertThat( - innerTreeNode.addLog(getTokens("foosmall", "bar", "baz", "biz"), 1, factory).getCategorization(), - arrayContaining(getTokens("foosmall", "bar", "baz", "biz")) + innerTreeNode.addLog(getTokens(bytesRefHash, "foosmall", "bar", "baz", "biz"), 1, factory).getCategorization(), + arrayContaining(getTokens(bytesRefHash, "foosmall", "bar", "baz", "biz")) ); innerTreeNode.incCount(1); innerTreeNode.collapseTinyChildren(); - assertThat(innerTreeNode.hasChild(new BytesRef("foosmall")), is(false)); - assertThat(innerTreeNode.hasChild(new BytesRef("*")), is(true)); + assertThat(innerTreeNode.hasChild(bytesRefHash.put(new BytesRef("foosmall"))), is(false)); + assertThat(innerTreeNode.hasChild(WILD_CARD_ID), is(true)); } public void testMergeWith() { TreeNode.InnerTreeNode innerTreeNode = new TreeNode.InnerTreeNode(1000, 0, 3); - innerTreeNode.addLog(getTokens("foo", "bar", "baz", "biz"), 1000, factory); + innerTreeNode.addLog(getTokens(bytesRefHash, "foo", "bar", "baz", "biz"), 1000, factory); innerTreeNode.incCount(1000); - innerTreeNode.addLog(getTokens("foo2", "bar", "baz", "biz"), 1000, factory); + innerTreeNode.addLog(getTokens(bytesRefHash, "foo2", "bar", "baz", "biz"), 1000, factory); expectThrows(UnsupportedOperationException.class, () -> innerTreeNode.mergeWith(new TreeNode.LeafTreeNode(1, 60))); TreeNode.InnerTreeNode mergeWith = new TreeNode.InnerTreeNode(1, 0, 3); - innerTreeNode.addLog(getTokens("foosmall", "bar", "baz", "biz"), 1, factory); + innerTreeNode.addLog(getTokens(bytesRefHash, "foosmall", "bar", "baz", "biz"), 1, factory); innerTreeNode.incCount(1); - innerTreeNode.addLog(getTokens("footiny", "bar", "baz", "biz"), 1, factory); + innerTreeNode.addLog(getTokens(bytesRefHash, "footiny", "bar", "baz", "biz"), 1, factory); innerTreeNode.mergeWith(mergeWith); - assertThat(innerTreeNode.hasChild(new BytesRef("*")), is(true)); + assertThat(innerTreeNode.hasChild(WILD_CARD_ID), is(true)); assertThat( - innerTreeNode.getLogGroup(getTokens("footiny", "bar", "baz", "biz")).getCategorization(), - arrayContaining(getTokens("*", "bar", "baz", "biz")) + innerTreeNode.getLogGroup(getTokens(bytesRefHash, "footiny", "bar", "baz", "biz")).getCategorization(), + arrayContaining(getTokens(bytesRefHash, "*", "bar", "baz", "biz")) ); } } diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/aggs/categorization/LeafTreeNodeTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/aggs/categorization/LeafTreeNodeTests.java index 3c465774ca733..1b748cfc7aa49 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/aggs/categorization/LeafTreeNodeTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/aggs/categorization/LeafTreeNodeTests.java @@ -7,9 +7,15 @@ package org.elasticsearch.xpack.ml.aggs.categorization; +import org.elasticsearch.common.util.BytesRefHash; import org.elasticsearch.test.ESTestCase; +import org.junit.After; +import org.junit.Before; + +import java.io.IOException; import static org.elasticsearch.xpack.ml.aggs.categorization.TextCategorizationTests.getTokens; +import static org.elasticsearch.xpack.ml.aggs.categorization.TextCategorizationTests.mockBigArrays; import static org.hamcrest.Matchers.arrayContaining; import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.greaterThan; @@ -19,25 +25,37 @@ public class LeafTreeNodeTests extends ESTestCase { private final TreeNodeFactory factory = new CategorizationTokenTree(10, 10, 60); + private CategorizationBytesRefHash bytesRefHash; + + @Before + public void createRefHash() { + bytesRefHash = new CategorizationBytesRefHash(new BytesRefHash(1L, mockBigArrays())); + } + + @After + public void closeRefHash() throws IOException { + bytesRefHash.close(); + } + public void testAddGroup() { TreeNode.LeafTreeNode leafTreeNode = new TreeNode.LeafTreeNode(0, 60); - TextCategorization group = leafTreeNode.addLog(getTokens("foo", "bar", "baz", "biz"), 1, factory); + TextCategorization group = leafTreeNode.addLog(getTokens(bytesRefHash, "foo", "bar", "baz", "biz"), 1, factory); - assertThat(group.getCategorization(), arrayContaining(getTokens("foo", "bar", "baz", "biz"))); + assertThat(group.getCategorization(), arrayContaining(getTokens(bytesRefHash, "foo", "bar", "baz", "biz"))); assertThat(group.getCount(), equalTo(1L)); assertThat(leafTreeNode.getAllChildrenLogGroups(), hasSize(1)); long previousBytesUsed = leafTreeNode.ramBytesUsed(); - group = leafTreeNode.addLog(getTokens("foo", "bar", "bozo", "bizzy"), 1, factory); - assertThat(group.getCategorization(), arrayContaining(getTokens("foo", "bar", "bozo", "bizzy"))); + group = leafTreeNode.addLog(getTokens(bytesRefHash, "foo", "bar", "bozo", "bizzy"), 1, factory); + assertThat(group.getCategorization(), arrayContaining(getTokens(bytesRefHash, "foo", "bar", "bozo", "bizzy"))); assertThat(group.getCount(), equalTo(1L)); assertThat(leafTreeNode.getAllChildrenLogGroups(), hasSize(2)); assertThat(leafTreeNode.ramBytesUsed(), greaterThan(previousBytesUsed)); previousBytesUsed = leafTreeNode.ramBytesUsed(); - group = leafTreeNode.addLog(getTokens("foo", "bar", "baz", "different"), 3, factory); - assertThat(group.getCategorization(), arrayContaining(getTokens("foo", "bar", "baz", "*"))); + group = leafTreeNode.addLog(getTokens(bytesRefHash, "foo", "bar", "baz", "different"), 3, factory); + assertThat(group.getCategorization(), arrayContaining(getTokens(bytesRefHash, "foo", "bar", "baz", "*"))); assertThat(group.getCount(), equalTo(4L)); assertThat(leafTreeNode.getAllChildrenLogGroups(), hasSize(2)); assertThat(previousBytesUsed, equalTo(leafTreeNode.ramBytesUsed())); @@ -51,21 +69,24 @@ public void testMergeWith() { expectThrows(UnsupportedOperationException.class, () -> leafTreeNode.mergeWith(new TreeNode.InnerTreeNode(1, 2, 3))); leafTreeNode.incCount(5); - leafTreeNode.addLog(getTokens("foo", "bar", "baz", "biz"), 5, factory); + leafTreeNode.addLog(getTokens(bytesRefHash, "foo", "bar", "baz", "biz"), 5, factory); TreeNode.LeafTreeNode toMerge = new TreeNode.LeafTreeNode(0, 60); leafTreeNode.incCount(1); - toMerge.addLog(getTokens("foo", "bar", "baz", "bizzy"), 1, factory); + toMerge.addLog(getTokens(bytesRefHash, "foo", "bar", "baz", "bizzy"), 1, factory); leafTreeNode.incCount(1); - toMerge.addLog(getTokens("foo", "bart", "bat", "built"), 1, factory); + toMerge.addLog(getTokens(bytesRefHash, "foo", "bart", "bat", "built"), 1, factory); leafTreeNode.mergeWith(toMerge); assertThat(leafTreeNode.getAllChildrenLogGroups(), hasSize(2)); assertThat(leafTreeNode.getCount(), equalTo(7L)); - assertThat(leafTreeNode.getAllChildrenLogGroups().get(0).getCategorization(), arrayContaining(getTokens("foo", "bar", "baz", "*"))); + assertThat( + leafTreeNode.getAllChildrenLogGroups().get(0).getCategorization(), + arrayContaining(getTokens(bytesRefHash, "foo", "bar", "baz", "*")) + ); assertThat( leafTreeNode.getAllChildrenLogGroups().get(1).getCategorization(), - arrayContaining(getTokens("foo", "bart", "bat", "built")) + arrayContaining(getTokens(bytesRefHash, "foo", "bart", "bat", "built")) ); } } diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/aggs/categorization/TextCategorizationTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/aggs/categorization/TextCategorizationTests.java index d18b5f2a68c26..e1250e6aa9b60 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/aggs/categorization/TextCategorizationTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/aggs/categorization/TextCategorizationTests.java @@ -8,7 +8,17 @@ package org.elasticsearch.xpack.ml.aggs.categorization; import org.apache.lucene.util.BytesRef; +import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.common.util.BigArrays; +import org.elasticsearch.common.util.BytesRefHash; +import org.elasticsearch.common.util.MockBigArrays; +import org.elasticsearch.common.util.MockPageCacheRecycler; +import org.elasticsearch.indices.breaker.NoneCircuitBreakerService; import org.elasticsearch.test.ESTestCase; +import org.junit.After; +import org.junit.Before; + +import java.io.IOException; import static org.hamcrest.Matchers.arrayContaining; import static org.hamcrest.Matchers.closeTo; @@ -16,35 +26,51 @@ public class TextCategorizationTests extends ESTestCase { + static BigArrays mockBigArrays() { + return new MockBigArrays(new MockPageCacheRecycler(Settings.EMPTY), new NoneCircuitBreakerService()); + } + + private CategorizationBytesRefHash bytesRefHash; + + @Before + public void createRefHash() { + bytesRefHash = new CategorizationBytesRefHash(new BytesRefHash(1L, mockBigArrays())); + } + + @After + public void closeRefHash() throws IOException { + bytesRefHash.close(); + } + public void testSimilarity() { - TextCategorization lg = new TextCategorization(getTokens("foo", "bar", "baz", "biz"), 1, 1); - TextCategorization.Similarity sims = lg.calculateSimilarity(getTokens("not", "matching", "anything", "nope")); + TextCategorization lg = new TextCategorization(getTokens(bytesRefHash,"foo", "bar", "baz", "biz"), 1, 1); + TextCategorization.Similarity sims = lg.calculateSimilarity(getTokens(bytesRefHash,"not", "matching", "anything", "nope")); assertThat(sims.getSimilarity(), equalTo(0.0)); assertThat(sims.getWildCardCount(), equalTo(0)); - sims = lg.calculateSimilarity(getTokens("foo", "bar", "baz", "biz")); + sims = lg.calculateSimilarity(getTokens(bytesRefHash,"foo", "bar", "baz", "biz")); assertThat(sims.getSimilarity(), equalTo(1.0)); assertThat(sims.getWildCardCount(), equalTo(0)); - sims = lg.calculateSimilarity(getTokens("foo", "fooagain", "notbar", "biz")); + sims = lg.calculateSimilarity(getTokens(bytesRefHash,"foo", "fooagain", "notbar", "biz")); assertThat(sims.getSimilarity(), closeTo(0.5, 0.0001)); assertThat(sims.getWildCardCount(), equalTo(0)); } public void testAddLog() { - TextCategorization lg = new TextCategorization(getTokens("foo", "bar", "baz", "biz"), 1, 1); - lg.addLog(getTokens("foo", "bar", "baz", "bozo"), 2); + TextCategorization lg = new TextCategorization(getTokens(bytesRefHash,"foo", "bar", "baz", "biz"), 1, 1); + lg.addLog(getTokens(bytesRefHash,"foo", "bar", "baz", "bozo"), 2); assertThat(lg.getCount(), equalTo(3L)); - assertThat(lg.getCategorization(), arrayContaining(getTokens("foo", "bar", "baz", "*"))); + assertThat(lg.getCategorization(), arrayContaining(getTokens(bytesRefHash,"foo", "bar", "baz", "*"))); } - static BytesRef[] getTokens(String... tokens) { + static Long[] getTokens(CategorizationBytesRefHash bytesRefHash, String... tokens) { BytesRef[] refs = new BytesRef[tokens.length]; int i = 0; for (String token: tokens) { refs[i++] = new BytesRef(token); } - return refs; + return bytesRefHash.getIds(refs); } } From 9ba60fb2d99d1a3f012e9e59d157877f6834368a Mon Sep 17 00:00:00 2001 From: Benjamin Trent <4357155+benwtrent@users.noreply.github.com> Date: Wed, 22 Sep 2021 16:33:58 -0400 Subject: [PATCH 12/20] moving to bytes ref hash and fixing two bugs --- .../CategorizationBytesRefHash.java | 12 +- .../CategorizationTokenTree.java | 14 +- .../CategorizeTextAggregator.java | 20 ++- .../InternalCategorizationAggregation.java | 132 ++++++++++-------- .../categorization/TextCategorization.java | 18 +-- .../ml/aggs/categorization/TreeNode.java | 92 +++++++----- .../aggs/categorization/TreeNodeFactory.java | 4 +- .../CategorizeTextAggregatorTests.java | 5 +- .../categorization/InnerTreeNodeTests.java | 48 +++---- .../categorization/LeafTreeNodeTests.java | 16 +-- .../TextCategorizationTests.java | 19 ++- 11 files changed, 213 insertions(+), 167 deletions(-) diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/CategorizationBytesRefHash.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/CategorizationBytesRefHash.java index a45ccc4f4c4ef..35bc15ceb02af 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/CategorizationBytesRefHash.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/CategorizationBytesRefHash.java @@ -7,6 +7,8 @@ package org.elasticsearch.xpack.ml.aggs.categorization; +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; import org.apache.lucene.util.BytesRef; import org.elasticsearch.common.util.BytesRefHash; @@ -15,6 +17,7 @@ class CategorizationBytesRefHash implements Closeable { + private static final Logger logger = LogManager.getLogger(CategorizationBytesRefHash.class); static final BytesRef WILD_CARD_REF = new BytesRef("*"); static final long WILD_CARD_ID = -1; private final BytesRefHash bytesRefHash; @@ -30,15 +33,15 @@ BytesRef getShallow(long id) { return bytesRefHash.get(id, new BytesRef()); } - Long[] getIds(BytesRef[] tokens) { - Long[] ids = new Long[tokens.length]; + long[] getIds(BytesRef[] tokens) { + long[] ids = new long[tokens.length]; for (int i = 0; i < tokens.length; i++) { ids[i] = put(tokens[i]); } return ids; } - BytesRef[] getShallows(Long[] ids) { + BytesRef[] getShallows(long[] ids) { BytesRef[] tokens = new BytesRef[ids.length]; for (int i = 0; i < tokens.length; i++) { tokens[i] = getShallow(ids[i]); @@ -62,6 +65,9 @@ long put(BytesRef bytesRef) { if (hash < 0) { return -1 - hash; } else { + if (hash > Integer.MAX_VALUE) { + logger.error("More than Integer.MAX_VALUE unique terms"); + } return hash; } } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/CategorizationTokenTree.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/CategorizationTokenTree.java index 2e8e09e43794c..fd9ff219ff471 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/CategorizationTokenTree.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/CategorizationTokenTree.java @@ -76,10 +76,10 @@ public CategorizationTokenTree(int maxChildren, int maxDepth, int similarityThre public List toIntermediateBuckets(CategorizationBytesRefHash hash) { return root.values().stream().flatMap(c -> c.getAllChildrenLogGroups().stream()).map(lg -> { - Long[] categoryTokenIds = lg.getCategorization(); + long[] categoryTokenIds = lg.getCategorization(); BytesRef[] bytesRefs = new BytesRef[categoryTokenIds.length]; for (int i = 0; i < categoryTokenIds.length; i++) { - bytesRefs[i] = hash.getShallow(categoryTokenIds[i]); + bytesRefs[i] = hash.getDeep(categoryTokenIds[i]); } InternalCategorizationAggregation.Bucket bucket = new InternalCategorizationAggregation.Bucket( new InternalCategorizationAggregation.BucketKey(bytesRefs), @@ -95,11 +95,11 @@ void mergeSmallestChildren() { root.values().forEach(TreeNode::collapseTinyChildren); } - public TextCategorization parseLogLine(final Long[] logTokenIds) { + public TextCategorization parseLogLine(final long[] logTokenIds) { return parseLogLine(logTokenIds, 1); } - public TextCategorization parseLogLineConst(final Long[] logTokenIds) { + public TextCategorization parseLogLineConst(final long[] logTokenIds) { TreeNode currentNode = this.root.get(logTokenIds.length); if (currentNode == null) { // we are missing an entire sub tree. New log length found return null; @@ -107,7 +107,7 @@ public TextCategorization parseLogLineConst(final Long[] logTokenIds) { return currentNode.getLogGroup(logTokenIds); } - public TextCategorization parseLogLine(final Long[] logTokenIds, long docCount) { + public TextCategorization parseLogLine(final long[] logTokenIds, long docCount) { TreeNode currentNode = this.root.get(logTokenIds.length); if (currentNode == null) { // we are missing an entire sub tree. New log length found currentNode = newNode(docCount, 0, logTokenIds); @@ -119,7 +119,7 @@ public TextCategorization parseLogLine(final Long[] logTokenIds, long docCount) } @Override - public TreeNode newNode(long docCount, int tokenPos, Long[] logTokenIds) { + public TreeNode newNode(long docCount, int tokenPos, long[] logTokenIds) { TreeNode node = tokenPos < maxDepth - 1 && tokenPos < logTokenIds.length ? new TreeNode.InnerTreeNode(docCount, tokenPos, maxChildren) : new TreeNode.LeafTreeNode(docCount, similarityThreshold); @@ -129,7 +129,7 @@ public TreeNode newNode(long docCount, int tokenPos, Long[] logTokenIds) { } @Override - public TextCategorization newGroup(long docCount, Long[] logTokenIds) { + public TextCategorization newGroup(long docCount, long[] logTokenIds) { TextCategorization group = new TextCategorization(logTokenIds, docCount, idGen.incrementAndGet()); // Get the regular size bytes from the LogGroup and how much it costs to reference it sizeInBytes += group.ramBytesUsed() + RamUsageEstimator.NUM_BYTES_OBJECT_REF; diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/CategorizeTextAggregator.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/CategorizeTextAggregator.java index 168a75c10eacb..9cd44467669fe 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/CategorizeTextAggregator.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/CategorizeTextAggregator.java @@ -30,6 +30,7 @@ import org.elasticsearch.xpack.ml.job.categorization.CategorizationAnalyzer; import java.io.IOException; +import java.io.UncheckedIOException; import java.util.ArrayList; import java.util.Arrays; import java.util.Collections; @@ -81,13 +82,19 @@ protected CategorizeTextAggregator( this.similarityThreshold = similarityThreshold; this.bucketOrds = LongKeyedBucketOrds.build(bigArrays(), CardinalityUpperBound.MANY); this.bucketCountThresholds = bucketCountThresholds; - this.bytesRefHash = new CategorizationBytesRefHash(new BytesRefHash(1000, bigArrays())); + this.bytesRefHash = new CategorizationBytesRefHash(new BytesRefHash(2048, bigArrays())); } @Override protected void doClose() { super.doClose(); this.analyzer.close(); + try { + this.bytesRefHash.close(); + } catch (IOException ex) { + //TODO Should we just eat the exception? + throw new UncheckedIOException(ex); + } } @Override @@ -95,10 +102,14 @@ public InternalAggregation[] buildAggregations(long[] ordsToCollect) throws IOEx InternalCategorizationAggregation.Bucket[][] topBucketsPerOrd = new InternalCategorizationAggregation.Bucket[ordsToCollect.length][]; for (int ordIdx = 0; ordIdx < ordsToCollect.length; ordIdx++) { + final CategorizationTokenTree categorizationTokenTree = categorizers.get(ordsToCollect[ordIdx]); + if (categorizationTokenTree == null) { + topBucketsPerOrd[ordIdx] = new InternalCategorizationAggregation.Bucket[0]; + continue; + } int size = (int) Math.min(bucketOrds.bucketsInOrd(ordIdx), bucketCountThresholds.getShardSize()); PriorityQueue ordered = new InternalCategorizationAggregation.BucketCountPriorityQueue(size); - CategorizationTokenTree categorizationTokenTree = categorizers.get(ordsToCollect[ordIdx]); for (InternalCategorizationAggregation.Bucket bucket : categorizationTokenTree.toIntermediateBuckets(bytesRefHash)) { if (bucket.docCount < bucketCountThresholds.getShardMinDocCount()) { continue; @@ -190,7 +201,10 @@ private void processTokenStream(long owningBucketOrd, TokenStream ts, int doc) t categorizers.set(owningBucketOrd, categorizer); } long previousSize = categorizer.ramBytesUsed(); - TextCategorization lg = categorizer.parseLogLine(tokens.toArray(Long[]::new), docCountProvider.getDocCount(doc)); + TextCategorization lg = categorizer.parseLogLine( + tokens.stream().mapToLong(Long::valueOf).toArray(), + docCountProvider.getDocCount(doc) + ); long newSize = categorizer.ramBytesUsed(); if (newSize - previousSize > 0) { addRequestCircuitBreakerBytes(newSize - previousSize); diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/InternalCategorizationAggregation.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/InternalCategorizationAggregation.java index d5ee7c9d65b25..84f19ccff5e6d 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/InternalCategorizationAggregation.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/InternalCategorizationAggregation.java @@ -23,6 +23,7 @@ import org.elasticsearch.search.aggregations.bucket.MultiBucketsAggregation; import java.io.IOException; +import java.io.UncheckedIOException; import java.util.ArrayList; import java.util.Arrays; import java.util.Comparator; @@ -339,77 +340,84 @@ protected void doWriteTo(StreamOutput out) throws IOException { @Override public InternalAggregation reduce(List aggregations, ReduceContext reduceContext) { - CategorizationBytesRefHash hash = new CategorizationBytesRefHash(new BytesRefHash(1L, reduceContext.bigArrays())); - CategorizationTokenTree categorizationTokenTree = new CategorizationTokenTree(maxChildren, maxDepth, similarityThreshold); - // TODO: Could we do a merge sort similar to terms? - // It would require us returning partial reductions sorted by key, not by doc_count - // First, make sure we have all the counts for equal log groups - Map reduced = new HashMap<>(); - for (InternalAggregation aggregation : aggregations) { - InternalCategorizationAggregation categorizationAggregation = (InternalCategorizationAggregation) aggregation; - for (Bucket bucket : categorizationAggregation.buckets) { - reduced.computeIfAbsent(bucket.key, key -> new DelayedCategorizationBucket(key, new ArrayList<>(1), 0L)).add(bucket); + try (CategorizationBytesRefHash hash = new CategorizationBytesRefHash(new BytesRefHash(1L, reduceContext.bigArrays()))) { + CategorizationTokenTree categorizationTokenTree = new CategorizationTokenTree(maxChildren, maxDepth, similarityThreshold); + // TODO: Could we do a merge sort similar to terms? + // It would require us returning partial reductions sorted by key, not by doc_count + // First, make sure we have all the counts for equal log groups + Map reduced = new HashMap<>(); + for (InternalAggregation aggregation : aggregations) { + InternalCategorizationAggregation categorizationAggregation = (InternalCategorizationAggregation) aggregation; + for (Bucket bucket : categorizationAggregation.buckets) { + reduced.computeIfAbsent(bucket.key, key -> new DelayedCategorizationBucket(key, new ArrayList<>(1), 0L)).add(bucket); + } } - } - for (DelayedCategorizationBucket bucket : reduced.values()) { - // Parse log line takes document count into account and merging on smallest groups - categorizationTokenTree.parseLogLine(hash.getIds(bucket.key.keyAsTokens()), bucket.docCount); - } - // Collapse tiny groups together, this may result in new bucket keys for already known buckets - categorizationTokenTree.mergeSmallestChildren(); - Map mergedBuckets = new HashMap<>(); - for (DelayedCategorizationBucket delayedBucket : reduced.values()) { - TextCategorization group = categorizationTokenTree.parseLogLineConst(hash.getIds(delayedBucket.key.keyAsTokens())); - if (group == null) { - throw new AggregationExecutionException( - "Unexpected null categorization group for bucket [" + delayedBucket.key.asString() + "]" + reduced.values() + .stream() + .sorted(Comparator.comparing(DelayedCategorizationBucket::getDocCount).reversed()) + .forEach(bucket -> + // Parse log line takes document count into account and merging on smallest groups + categorizationTokenTree.parseLogLine(hash.getIds(bucket.key.keyAsTokens()), bucket.docCount) ); + categorizationTokenTree.mergeSmallestChildren(); + Map mergedBuckets = new HashMap<>(); + for (DelayedCategorizationBucket delayedBucket : reduced.values()) { + TextCategorization group = categorizationTokenTree.parseLogLineConst(hash.getIds(delayedBucket.key.keyAsTokens())); + if (group == null) { + throw new AggregationExecutionException( + "Unexpected null categorization group for bucket [" + delayedBucket.key.asString() + "]" + ); + } + BytesRef[] categoryTokens = hash.getShallows(group.getCategorization()); + + BucketKey key = reduceContext.isFinalReduce() ? + BucketKey.withCollapsedWildcards(categoryTokens) : + new BucketKey(categoryTokens); + mergedBuckets.computeIfAbsent( + key, + k -> new DelayedCategorizationBucket(k, new ArrayList<>(delayedBucket.toReduce.size()), 0L) + ).add(delayedBucket); } - BytesRef[] categoryTokens = hash.getShallows(group.getCategorization()); - - BucketKey key = reduceContext.isFinalReduce() ? - BucketKey.withCollapsedWildcards(categoryTokens) : - new BucketKey(categoryTokens); - mergedBuckets.computeIfAbsent(key, k -> new DelayedCategorizationBucket(k, new ArrayList<>(delayedBucket.toReduce.size()), 0L)) - .add(delayedBucket); - } - final int size = reduceContext.isFinalReduce() == false ? mergedBuckets.size() : Math.min(requiredSize, mergedBuckets.size()); - final PriorityQueue pq = new BucketCountPriorityQueue(size); - for (Map.Entry keyAndBuckets : mergedBuckets.entrySet()) { - final BucketKey key = keyAndBuckets.getKey(); - DelayedCategorizationBucket bucket = keyAndBuckets.getValue(); - Bucket newBucket = bucket.reduce(key, reduceContext); - if ((newBucket.docCount >= minDocCount) || reduceContext.isFinalReduce() == false) { - Bucket removed = pq.insertWithOverflow(newBucket); - if (removed == null) { - reduceContext.consumeBucketsAndMaybeBreak(1); + final int size = reduceContext.isFinalReduce() == false ? mergedBuckets.size() : Math.min(requiredSize, mergedBuckets.size()); + final PriorityQueue pq = new BucketCountPriorityQueue(size); + for (Map.Entry keyAndBuckets : mergedBuckets.entrySet()) { + final BucketKey key = keyAndBuckets.getKey(); + DelayedCategorizationBucket bucket = keyAndBuckets.getValue(); + Bucket newBucket = bucket.reduce(key, reduceContext); + if ((newBucket.docCount >= minDocCount) || reduceContext.isFinalReduce() == false) { + Bucket removed = pq.insertWithOverflow(newBucket); + if (removed == null) { + reduceContext.consumeBucketsAndMaybeBreak(1); + } else { + reduceContext.consumeBucketsAndMaybeBreak(-countInnerBucket(removed)); + } } else { - reduceContext.consumeBucketsAndMaybeBreak(-countInnerBucket(removed)); + reduceContext.consumeBucketsAndMaybeBreak(-countInnerBucket(newBucket)); } - } else { - reduceContext.consumeBucketsAndMaybeBreak(-countInnerBucket(newBucket)); } + Bucket[] bucketList = new Bucket[pq.size()]; + for (int i = pq.size() - 1; i >= 0; i--) { + bucketList[i] = pq.pop(); + } + // Keep the top categories top, but then sort by the key for those with duplicate counts + if (reduceContext.isFinalReduce()) { + Arrays.sort(bucketList, Comparator.comparing(Bucket::getDocCount).reversed().thenComparing(Bucket::getRawKey)); + } + return new InternalCategorizationAggregation( + name, + requiredSize, + minDocCount, + maxChildren, + maxDepth, + similarityThreshold, + metadata, + Arrays.asList(bucketList) + ); + } catch (IOException ex) { + throw new UncheckedIOException(ex); } - Bucket[] bucketList = new Bucket[pq.size()]; - for (int i = pq.size() - 1; i >= 0; i--) { - bucketList[i] = pq.pop(); - } - // Keep the top categories top, but then sort by the key for those with duplicate counts - if (reduceContext.isFinalReduce()) { - Arrays.sort(bucketList, Comparator.comparing(Bucket::getDocCount).reversed().thenComparing(Bucket::getRawKey)); - } - return new InternalCategorizationAggregation( - name, - requiredSize, - minDocCount, - maxChildren, - maxDepth, - similarityThreshold, - metadata, - Arrays.asList(bucketList) - ); } @Override diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/TextCategorization.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/TextCategorization.java index ddc85bf6e34e4..6c5e153026458 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/TextCategorization.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/TextCategorization.java @@ -23,14 +23,14 @@ class TextCategorization implements Accountable { private final long id; // TODO Do we want to just make this native arrays? - private final Long[] categorization; + private final long[] categorization; private final long[] tokenCounts; private long count; // Used at the shard level for tracking the bucket ordinal for collecting sub aggregations long bucketOrd; - TextCategorization(Long[] logTokenIds, long count, long id) { + TextCategorization(long[] logTokenIds, long count, long id) { this.id = id; this.categorization = logTokenIds; this.count = count; @@ -42,7 +42,7 @@ public long getId() { return id; } - Long[] getCategorization() { + long[] getCategorization() { return categorization; } @@ -50,16 +50,16 @@ public long getCount() { return count; } - Similarity calculateSimilarity(Long[] logEvent) { + Similarity calculateSimilarity(long[] logEvent) { assert logEvent.length == this.categorization.length; int eqParams = 0; long tokenCount = 0; long tokensKept = 0; for (int i = 0; i < logEvent.length; i++) { - if (logEvent[i].equals(this.categorization[i])) { + if (logEvent[i] == this.categorization[i]) { tokensKept += tokenCounts[i]; tokenCount += tokenCounts[i]; - } else if (this.categorization[i].equals(WILD_CARD_ID)) { + } else if (this.categorization[i] == WILD_CARD_ID) { eqParams++; } else { tokenCount += tokenCounts[i]; @@ -68,10 +68,10 @@ Similarity calculateSimilarity(Long[] logEvent) { return new Similarity((double) tokensKept / tokenCount, eqParams); } - void addLog(Long[] logEvent, long docCount) { + void addLog(long[] logEvent, long docCount) { assert logEvent.length == this.categorization.length; for (int i = 0; i < logEvent.length; i++) { - if (logEvent[i].equals(this.categorization[i]) == false) { + if (logEvent[i] != this.categorization[i]) { this.categorization[i] = WILD_CARD_ID; } else { tokenCounts[i] += docCount; @@ -84,7 +84,7 @@ void addLog(Long[] logEvent, long docCount) { public long ramBytesUsed() { return Long.BYTES // id + RamUsageEstimator.NUM_BYTES_OBJECT_REF // categorization reference - + RamUsageEstimator.shallowSizeOf(categorization) // categorization we don't deep copy the token ids + + RamUsageEstimator.sizeOf(categorization) // categorization token Ids + RamUsageEstimator.NUM_BYTES_OBJECT_REF // tokenCounts reference + RamUsageEstimator.sizeOf(tokenCounts) // tokenCounts + Long.BYTES; // count diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/TreeNode.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/TreeNode.java index f9eacf469843c..6a6a5ec51c712 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/TreeNode.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/TreeNode.java @@ -58,9 +58,9 @@ final long getCount() { } // TODO add option for calculating the cost of adding the new group - abstract TextCategorization addLog(Long[] logTokenIds, long docCount, TreeNodeFactory treeNodeFactory); + abstract TextCategorization addLog(long[] logTokenIds, long docCount, TreeNodeFactory treeNodeFactory); - abstract TextCategorization getLogGroup(Long[] logTokens); + abstract TextCategorization getLogGroup(long[] logTokens); abstract List getAllChildrenLogGroups(); @@ -111,7 +111,7 @@ public long ramBytesUsed() { } @Override - public TextCategorization addLog(Long[] logTokenIds, long docCount, TreeNodeFactory treeNodeFactory) { + public TextCategorization addLog(long[] logTokenIds, long docCount, TreeNodeFactory treeNodeFactory) { return getAndUpdateLogGroup(logTokenIds, docCount).orElseGet(() -> { // Need to update the tree if possible return putNewLogGroup(treeNodeFactory.newGroup(docCount, logTokenIds)); @@ -126,7 +126,7 @@ List getAllChildrenLogGroups() { @Override void collapseTinyChildren() {} - private Optional getAndUpdateLogGroup(Long[] logTokenIds, long docCount) { + private Optional getAndUpdateLogGroup(long[] logTokenIds, long docCount) { return getBestLogGroup(logTokenIds).map(bestGroupAndSimilarity -> { if ((bestGroupAndSimilarity.v2() * 100) >= similarityThreshold) { bestGroupAndSimilarity.v1().addLog(logTokenIds, docCount); @@ -141,7 +141,7 @@ TextCategorization putNewLogGroup(TextCategorization group) { return group; } - private Optional> getBestLogGroup(Long[] logTokenIds) { + private Optional> getBestLogGroup(long[] logTokenIds) { if (textCategorizations.isEmpty()) { return Optional.empty(); } @@ -163,7 +163,7 @@ private Optional> getBestLogGroup(Long[] logTo } @Override - public TextCategorization getLogGroup(final Long[] logTokenIds) { + public TextCategorization getLogGroup(final long[] logTokenIds) { return getBestLogGroup(logTokenIds).map(Tuple::v1).orElse(null); } @@ -188,14 +188,14 @@ static class InnerTreeNode extends TreeNode { private final Map children; private final int childrenTokenPos; private final int maxChildren; - private final PriorityQueue> smallestChild; + private final PriorityQueue smallestChild; InnerTreeNode(long count, int childrenTokenPos, int maxChildren) { super(count); children = new HashMap<>(); this.childrenTokenPos = childrenTokenPos; this.maxChildren = maxChildren; - this.smallestChild = new PriorityQueue<>(maxChildren, Comparator.comparing(Tuple::v2)); + this.smallestChild = new PriorityQueue<>(maxChildren, Comparator.comparing(NativeLongPair::count)); } boolean isLeaf() { @@ -203,7 +203,7 @@ boolean isLeaf() { } @Override - public TextCategorization getLogGroup(final Long[] logTokenIds) { + public TextCategorization getLogGroup(final long[] logTokenIds) { return getChild(logTokenIds[childrenTokenPos]).or(() -> getChild(WILD_CARD_ID)) .map(node -> node.getLogGroup(logTokenIds)) .orElse(null); @@ -218,15 +218,15 @@ public long ramBytesUsed() { + NUM_BYTES_OBJECT_REF // smallestChildReference + sizeOfMap(children, NUM_BYTES_OBJECT_REF) // children, // Number of items in the queue, reference to tuple, and then the tuple references - + (long) smallestChild.size() * (NUM_BYTES_OBJECT_REF + NUM_BYTES_OBJECT_REF + NUM_BYTES_OBJECT_REF + Long.BYTES); + + (long) smallestChild.size() * (NUM_BYTES_OBJECT_REF + Long.BYTES + Long.BYTES); } @Override - public TextCategorization addLog(final Long[] logTokenIds, final long docCount, final TreeNodeFactory treeNodeFactory) { - Long currentToken = logTokenIds[childrenTokenPos]; + public TextCategorization addLog(final long[] logTokenIds, final long docCount, final TreeNodeFactory treeNodeFactory) { + final long currentToken = logTokenIds[childrenTokenPos]; TreeNode child = getChild(currentToken).map(node -> { node.incCount(docCount); - if (smallestChild.isEmpty() == false && smallestChild.peek().v1().equals(currentToken)) { + if (smallestChild.isEmpty() == false && smallestChild.peek().tokenId == currentToken) { smallestChild.add(smallestChild.poll()); } return node; @@ -246,22 +246,22 @@ void collapseTinyChildren() { return; } Optional maybeWildChild = getChild(WILD_CARD_ID).or(() -> { - if ((double) smallestChild.peek().v2() / this.getCount() <= 1.0 / maxChildren) { - TreeNode tinyChild = children.remove(smallestChild.poll().v1()); + if ((double) smallestChild.peek().count / this.getCount() <= 1.0 / maxChildren) { + TreeNode tinyChild = children.remove(smallestChild.poll().tokenId); return Optional.of(addChild(WILD_CARD_ID, tinyChild)); } return Optional.empty(); }); if (maybeWildChild.isPresent()) { TreeNode wildChild = maybeWildChild.get(); - Tuple tinyNode; + NativeLongPair tinyNode; while ((tinyNode = smallestChild.poll()) != null) { // If we have no more tiny nodes, stop iterating over them - if ((double) tinyNode.v2() / this.getCount() > 1.0 / maxChildren) { + if ((double) tinyNode.count / this.getCount() > 1.0 / maxChildren) { smallestChild.add(tinyNode); break; } else { - wildChild.mergeWith(children.remove(tinyNode.v1())); + wildChild.mergeWith(children.remove(tinyNode.count)); } } } @@ -282,22 +282,22 @@ void mergeWith(TreeNode treeNode) { InnerTreeNode innerTreeNode = (InnerTreeNode) treeNode; TreeNode siblingWildChild = innerTreeNode.children.remove(WILD_CARD_ID); addChild(WILD_CARD_ID, siblingWildChild); - Tuple siblingChild; + NativeLongPair siblingChild; while ((siblingChild = innerTreeNode.smallestChild.poll()) != null) { - TreeNode nephewNode = innerTreeNode.children.remove(siblingChild.v1()); - addChild(siblingChild.v1(), nephewNode); + TreeNode nephewNode = innerTreeNode.children.remove(siblingChild.tokenId); + addChild(siblingChild.tokenId, nephewNode); } } - private TreeNode addChild(Long tokenId, TreeNode node) { - if (node == null || tokenId == null) { + private TreeNode addChild(long tokenId, TreeNode node) { + if (node == null) { return null; } Optional existingChild = getChild(tokenId).map(existingNode -> { existingNode.mergeWith(node); - if (smallestChild.isEmpty() == false && smallestChild.peek().v1().equals(tokenId)) { + if (smallestChild.isEmpty() == false && smallestChild.peek().tokenId == tokenId) { smallestChild.poll(); - smallestChild.add(Tuple.tuple(tokenId, existingNode.getCount())); + smallestChild.add(NativeLongPair.of(tokenId, existingNode.getCount())); } return existingNode; }); @@ -308,8 +308,8 @@ private TreeNode addChild(Long tokenId, TreeNode node) { return getChild(WILD_CARD_ID).map(wildChild -> { final TreeNode toMerge; final TreeNode toReturn; - if (smallestChild.isEmpty() == false && node.getCount() > smallestChild.peek().v2()) { - toMerge = children.remove(smallestChild.poll().v1()); + if (smallestChild.isEmpty() == false && node.getCount() > smallestChild.peek().count) { + toMerge = children.remove(smallestChild.poll().tokenId); addChildAndUpdateSmallest(tokenId, node); toReturn = node; } else { @@ -326,11 +326,11 @@ private TreeNode addChild(Long tokenId, TreeNode node) { if (children.containsKey(WILD_CARD_ID)) { addChildAndUpdateSmallest(tokenId, node); } else { // if we don't have a wild card child, we need to add one now - if (tokenId.equals(WILD_CARD_ID)) { + if (tokenId == WILD_CARD_ID) { addChildAndUpdateSmallest(tokenId, node); } else { - if (smallestChild.isEmpty() == false && node.count > smallestChild.peek().v2()) { - addChildAndUpdateSmallest(WILD_CARD_ID, children.remove(smallestChild.poll().v1())); + if (smallestChild.isEmpty() == false && node.count > smallestChild.peek().count) { + addChildAndUpdateSmallest(WILD_CARD_ID, children.remove(smallestChild.poll().tokenId)); addChildAndUpdateSmallest(tokenId, node); } else { addChildAndUpdateSmallest(WILD_CARD_ID, node); @@ -343,14 +343,14 @@ private TreeNode addChild(Long tokenId, TreeNode node) { return node; } - private void addChildAndUpdateSmallest(Long tokenId, TreeNode node) { + private void addChildAndUpdateSmallest(long tokenId, TreeNode node) { children.put(tokenId, node); - if (tokenId.equals(WILD_CARD_ID) == false) { - smallestChild.add(Tuple.tuple(tokenId, node.count)); + if (tokenId != WILD_CARD_ID) { + smallestChild.add(NativeLongPair.of(tokenId, node.count)); } } - private Optional getChild(Long tokenId) { + private Optional getChild(long tokenId) { return Optional.ofNullable(children.get(tokenId)); } @@ -358,7 +358,7 @@ public List getAllChildrenLogGroups() { return children.values().stream().flatMap(c -> c.getAllChildrenLogGroups().stream()).collect(Collectors.toList()); } - boolean hasChild(Long tokenId) { + boolean hasChild(long tokenId) { return children.containsKey(tokenId); } @@ -379,4 +379,26 @@ public int hashCode() { } } + private static class NativeLongPair { + private final long tokenId; + private final long count; + + static NativeLongPair of(long tokenId, long count) { + return new NativeLongPair(tokenId, count); + } + + NativeLongPair(long tokenId, long count) { + this.tokenId = tokenId; + this.count = count; + } + + public long tokenId() { + return tokenId; + } + + public long count() { + return count; + } + } + } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/TreeNodeFactory.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/TreeNodeFactory.java index 548a2fdb8e803..1533987ae5954 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/TreeNodeFactory.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/TreeNodeFactory.java @@ -9,7 +9,7 @@ interface TreeNodeFactory { - TreeNode newNode(long docCount, int tokenPos, Long[] logTokenIds); + TreeNode newNode(long docCount, int tokenPos, long[] logTokenIds); - TextCategorization newGroup(long docCount, Long[] logTokenIds); + TextCategorization newGroup(long docCount, long[] logTokenIds); } diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/aggs/categorization/CategorizeTextAggregatorTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/aggs/categorization/CategorizeTextAggregatorTests.java index 58dbad7b0bb20..8560da8f76190 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/aggs/categorization/CategorizeTextAggregatorTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/aggs/categorization/CategorizeTextAggregatorTests.java @@ -178,8 +178,9 @@ public void testCategorizationAsSubAgg() throws Exception { HistogramAggregationBuilder aggBuilder = new HistogramAggregationBuilder("histo").field(NUMERIC_FIELD_NAME) .interval(2) .subAggregation( - new CategorizeTextAggregationBuilder("my_agg", TEXT_FIELD_NAME) - .subAggregation(new MaxAggregationBuilder("max").field(NUMERIC_FIELD_NAME)) + new CategorizeTextAggregationBuilder("my_agg", TEXT_FIELD_NAME).subAggregation( + new MaxAggregationBuilder("max").field(NUMERIC_FIELD_NAME) + ) .subAggregation(new AvgAggregationBuilder("avg").field(NUMERIC_FIELD_NAME)) .subAggregation(new MinAggregationBuilder("min").field(NUMERIC_FIELD_NAME)) ); diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/aggs/categorization/InnerTreeNodeTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/aggs/categorization/InnerTreeNodeTests.java index bb8d62f1aaeba..ac52eab68c826 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/aggs/categorization/InnerTreeNodeTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/aggs/categorization/InnerTreeNodeTests.java @@ -18,7 +18,6 @@ import static org.elasticsearch.xpack.ml.aggs.categorization.CategorizationBytesRefHash.WILD_CARD_ID; import static org.elasticsearch.xpack.ml.aggs.categorization.TextCategorizationTests.getTokens; import static org.elasticsearch.xpack.ml.aggs.categorization.TextCategorizationTests.mockBigArrays; -import static org.hamcrest.Matchers.arrayContaining; import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.is; @@ -40,42 +39,42 @@ public void closeRefHash() throws IOException { public void testAddLog() { TreeNode.InnerTreeNode innerTreeNode = new TreeNode.InnerTreeNode(1, 0, 3); TextCategorization group = innerTreeNode.addLog(getTokens(bytesRefHash, "foo", "bar", "baz", "biz"), 1, factory); - assertThat(group.getCategorization(), arrayContaining(getTokens(bytesRefHash, "foo", "bar", "baz", "biz"))); + assertArrayEquals(group.getCategorization(), getTokens(bytesRefHash, "foo", "bar", "baz", "biz")); - assertThat( + assertArrayEquals( innerTreeNode.addLog(getTokens(bytesRefHash, "foo2", "bar", "baz", "biz"), 1, factory).getCategorization(), - arrayContaining(getTokens(bytesRefHash, "foo2", "bar", "baz", "biz")) + getTokens(bytesRefHash, "foo2", "bar", "baz", "biz") ); - assertThat( + assertArrayEquals( innerTreeNode.addLog(getTokens(bytesRefHash, "foo3", "bar", "baz", "biz"), 1, factory).getCategorization(), - arrayContaining(getTokens(bytesRefHash, "foo3", "bar", "baz", "biz")) + getTokens(bytesRefHash, "foo3", "bar", "baz", "biz") ); - assertThat( + assertArrayEquals( innerTreeNode.addLog(getTokens(bytesRefHash, "foo4", "bar", "baz", "biz"), 1, factory).getCategorization(), - arrayContaining(getTokens(bytesRefHash, "*", "bar", "baz", "biz")) + getTokens(bytesRefHash, "*", "bar", "baz", "biz") ); - assertThat( + assertArrayEquals( innerTreeNode.addLog(getTokens(bytesRefHash, "foo", "bar", "baz", "bizzy"), 1, factory).getCategorization(), - arrayContaining(getTokens(bytesRefHash, "foo", "bar", "baz", "*")) + getTokens(bytesRefHash, "foo", "bar", "baz", "*") ); } public void testAddLogWithLargerIncoming() { TreeNode.InnerTreeNode innerTreeNode = new TreeNode.InnerTreeNode(1, 0, 3); TextCategorization group = innerTreeNode.addLog(getTokens(bytesRefHash, "foo", "bar", "baz", "biz"), 100, factory); - assertThat(group.getCategorization(), arrayContaining(getTokens(bytesRefHash, "foo", "bar", "baz", "biz"))); + assertArrayEquals(group.getCategorization(), getTokens(bytesRefHash, "foo", "bar", "baz", "biz")); - assertThat( + assertArrayEquals( innerTreeNode.addLog(getTokens(bytesRefHash, "foo2", "bar", "baz", "biz"), 100, factory).getCategorization(), - arrayContaining(getTokens(bytesRefHash, "foo2", "bar", "baz", "biz")) + getTokens(bytesRefHash, "foo2", "bar", "baz", "biz") ); - assertThat( + assertArrayEquals( innerTreeNode.addLog(getTokens(bytesRefHash, "foosmall", "bar", "baz", "biz"), 1, factory).getCategorization(), - arrayContaining(getTokens(bytesRefHash, "foosmall", "bar", "baz", "biz")) + getTokens(bytesRefHash, "foosmall", "bar", "baz", "biz") ); - assertThat( + assertArrayEquals( innerTreeNode.addLog(getTokens(bytesRefHash, "foobigun", "bar", "baz", "biz"), 1000, factory).getCategorization(), - arrayContaining(getTokens(bytesRefHash, "foobigun", "bar", "baz", "biz")) + getTokens(bytesRefHash, "foobigun", "bar", "baz", "biz") ); assertThat( innerTreeNode.getLogGroup(getTokens(bytesRefHash, "foosmall", "bar", "baz", "biz")).getCategorization(), @@ -86,16 +85,16 @@ public void testAddLogWithLargerIncoming() { public void testCollapseTinyChildren() { TreeNode.InnerTreeNode innerTreeNode = new TreeNode.InnerTreeNode(1000, 0, 4); TextCategorization group = innerTreeNode.addLog(getTokens(bytesRefHash, "foo", "bar", "baz", "biz"), 1000, factory); - assertThat(group.getCategorization(), arrayContaining(getTokens(bytesRefHash, "foo", "bar", "baz", "biz"))); + assertArrayEquals(group.getCategorization(), getTokens(bytesRefHash, "foo", "bar", "baz", "biz")); - assertThat( + assertArrayEquals( innerTreeNode.addLog(getTokens(bytesRefHash, "foo2", "bar", "baz", "biz"), 1000, factory).getCategorization(), - arrayContaining(getTokens(bytesRefHash, "foo2", "bar", "baz", "biz")) + getTokens(bytesRefHash, "foo2", "bar", "baz", "biz") ); innerTreeNode.incCount(1000); - assertThat( + assertArrayEquals( innerTreeNode.addLog(getTokens(bytesRefHash, "foosmall", "bar", "baz", "biz"), 1, factory).getCategorization(), - arrayContaining(getTokens(bytesRefHash, "foosmall", "bar", "baz", "biz")) + getTokens(bytesRefHash, "foosmall", "bar", "baz", "biz") ); innerTreeNode.incCount(1); innerTreeNode.collapseTinyChildren(); @@ -111,7 +110,6 @@ public void testMergeWith() { expectThrows(UnsupportedOperationException.class, () -> innerTreeNode.mergeWith(new TreeNode.LeafTreeNode(1, 60))); - TreeNode.InnerTreeNode mergeWith = new TreeNode.InnerTreeNode(1, 0, 3); innerTreeNode.addLog(getTokens(bytesRefHash, "foosmall", "bar", "baz", "biz"), 1, factory); innerTreeNode.incCount(1); @@ -119,9 +117,9 @@ public void testMergeWith() { innerTreeNode.mergeWith(mergeWith); assertThat(innerTreeNode.hasChild(WILD_CARD_ID), is(true)); - assertThat( + assertArrayEquals( innerTreeNode.getLogGroup(getTokens(bytesRefHash, "footiny", "bar", "baz", "biz")).getCategorization(), - arrayContaining(getTokens(bytesRefHash, "*", "bar", "baz", "biz")) + getTokens(bytesRefHash, "*", "bar", "baz", "biz") ); } } diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/aggs/categorization/LeafTreeNodeTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/aggs/categorization/LeafTreeNodeTests.java index 1b748cfc7aa49..1e8ac11217812 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/aggs/categorization/LeafTreeNodeTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/aggs/categorization/LeafTreeNodeTests.java @@ -16,7 +16,6 @@ import static org.elasticsearch.xpack.ml.aggs.categorization.TextCategorizationTests.getTokens; import static org.elasticsearch.xpack.ml.aggs.categorization.TextCategorizationTests.mockBigArrays; -import static org.hamcrest.Matchers.arrayContaining; import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.greaterThan; import static org.hamcrest.Matchers.hasSize; @@ -41,21 +40,20 @@ public void testAddGroup() { TreeNode.LeafTreeNode leafTreeNode = new TreeNode.LeafTreeNode(0, 60); TextCategorization group = leafTreeNode.addLog(getTokens(bytesRefHash, "foo", "bar", "baz", "biz"), 1, factory); - assertThat(group.getCategorization(), arrayContaining(getTokens(bytesRefHash, "foo", "bar", "baz", "biz"))); + assertArrayEquals(group.getCategorization(), getTokens(bytesRefHash, "foo", "bar", "baz", "biz")); assertThat(group.getCount(), equalTo(1L)); assertThat(leafTreeNode.getAllChildrenLogGroups(), hasSize(1)); long previousBytesUsed = leafTreeNode.ramBytesUsed(); group = leafTreeNode.addLog(getTokens(bytesRefHash, "foo", "bar", "bozo", "bizzy"), 1, factory); - assertThat(group.getCategorization(), arrayContaining(getTokens(bytesRefHash, "foo", "bar", "bozo", "bizzy"))); + assertArrayEquals(group.getCategorization(), getTokens(bytesRefHash, "foo", "bar", "bozo", "bizzy")); assertThat(group.getCount(), equalTo(1L)); assertThat(leafTreeNode.getAllChildrenLogGroups(), hasSize(2)); assertThat(leafTreeNode.ramBytesUsed(), greaterThan(previousBytesUsed)); previousBytesUsed = leafTreeNode.ramBytesUsed(); - group = leafTreeNode.addLog(getTokens(bytesRefHash, "foo", "bar", "baz", "different"), 3, factory); - assertThat(group.getCategorization(), arrayContaining(getTokens(bytesRefHash, "foo", "bar", "baz", "*"))); + assertArrayEquals(group.getCategorization(), getTokens(bytesRefHash, "foo", "bar", "baz", "*")); assertThat(group.getCount(), equalTo(4L)); assertThat(leafTreeNode.getAllChildrenLogGroups(), hasSize(2)); assertThat(previousBytesUsed, equalTo(leafTreeNode.ramBytesUsed())); @@ -80,13 +78,13 @@ public void testMergeWith() { assertThat(leafTreeNode.getAllChildrenLogGroups(), hasSize(2)); assertThat(leafTreeNode.getCount(), equalTo(7L)); - assertThat( + assertArrayEquals( leafTreeNode.getAllChildrenLogGroups().get(0).getCategorization(), - arrayContaining(getTokens(bytesRefHash, "foo", "bar", "baz", "*")) + getTokens(bytesRefHash, "foo", "bar", "baz", "*") ); - assertThat( + assertArrayEquals( leafTreeNode.getAllChildrenLogGroups().get(1).getCategorization(), - arrayContaining(getTokens(bytesRefHash, "foo", "bart", "bat", "built")) + getTokens(bytesRefHash, "foo", "bart", "bat", "built") ); } } diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/aggs/categorization/TextCategorizationTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/aggs/categorization/TextCategorizationTests.java index e1250e6aa9b60..bde8e3a8a6d42 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/aggs/categorization/TextCategorizationTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/aggs/categorization/TextCategorizationTests.java @@ -20,7 +20,6 @@ import java.io.IOException; -import static org.hamcrest.Matchers.arrayContaining; import static org.hamcrest.Matchers.closeTo; import static org.hamcrest.Matchers.equalTo; @@ -43,31 +42,31 @@ public void closeRefHash() throws IOException { } public void testSimilarity() { - TextCategorization lg = new TextCategorization(getTokens(bytesRefHash,"foo", "bar", "baz", "biz"), 1, 1); - TextCategorization.Similarity sims = lg.calculateSimilarity(getTokens(bytesRefHash,"not", "matching", "anything", "nope")); + TextCategorization lg = new TextCategorization(getTokens(bytesRefHash, "foo", "bar", "baz", "biz"), 1, 1); + TextCategorization.Similarity sims = lg.calculateSimilarity(getTokens(bytesRefHash, "not", "matching", "anything", "nope")); assertThat(sims.getSimilarity(), equalTo(0.0)); assertThat(sims.getWildCardCount(), equalTo(0)); - sims = lg.calculateSimilarity(getTokens(bytesRefHash,"foo", "bar", "baz", "biz")); + sims = lg.calculateSimilarity(getTokens(bytesRefHash, "foo", "bar", "baz", "biz")); assertThat(sims.getSimilarity(), equalTo(1.0)); assertThat(sims.getWildCardCount(), equalTo(0)); - sims = lg.calculateSimilarity(getTokens(bytesRefHash,"foo", "fooagain", "notbar", "biz")); + sims = lg.calculateSimilarity(getTokens(bytesRefHash, "foo", "fooagain", "notbar", "biz")); assertThat(sims.getSimilarity(), closeTo(0.5, 0.0001)); assertThat(sims.getWildCardCount(), equalTo(0)); } public void testAddLog() { - TextCategorization lg = new TextCategorization(getTokens(bytesRefHash,"foo", "bar", "baz", "biz"), 1, 1); - lg.addLog(getTokens(bytesRefHash,"foo", "bar", "baz", "bozo"), 2); + TextCategorization lg = new TextCategorization(getTokens(bytesRefHash, "foo", "bar", "baz", "biz"), 1, 1); + lg.addLog(getTokens(bytesRefHash, "foo", "bar", "baz", "bozo"), 2); assertThat(lg.getCount(), equalTo(3L)); - assertThat(lg.getCategorization(), arrayContaining(getTokens(bytesRefHash,"foo", "bar", "baz", "*"))); + assertArrayEquals(lg.getCategorization(), getTokens(bytesRefHash, "foo", "bar", "baz", "*")); } - static Long[] getTokens(CategorizationBytesRefHash bytesRefHash, String... tokens) { + static long[] getTokens(CategorizationBytesRefHash bytesRefHash, String... tokens) { BytesRef[] refs = new BytesRef[tokens.length]; int i = 0; - for (String token: tokens) { + for (String token : tokens) { refs[i++] = new BytesRef(token); } return bytesRefHash.getIds(refs); From fc1656a4dfc77a744d56d266732348452460e9f1 Mon Sep 17 00:00:00 2001 From: Benjamin Trent <4357155+benwtrent@users.noreply.github.com> Date: Thu, 23 Sep 2021 08:02:42 -0400 Subject: [PATCH 13/20] fixing tokenization bug and addressing PR comments --- .../support/AggregationContext.java | 45 ++++++++++++++++--- .../index/mapper/MapperServiceTestCase.java | 15 ++++++- .../CategorizationBytesRefHash.java | 11 +---- .../CategorizeTextAggregator.java | 31 ++++++++++--- .../InternalCategorizationAggregation.java | 2 +- .../CategorizationAnalyzer.java | 5 +++ 6 files changed, 85 insertions(+), 24 deletions(-) diff --git a/server/src/main/java/org/elasticsearch/search/aggregations/support/AggregationContext.java b/server/src/main/java/org/elasticsearch/search/aggregations/support/AggregationContext.java index 9ebb30db75cd1..8ce69009f09e0 100644 --- a/server/src/main/java/org/elasticsearch/search/aggregations/support/AggregationContext.java +++ b/server/src/main/java/org/elasticsearch/search/aggregations/support/AggregationContext.java @@ -19,6 +19,7 @@ import org.elasticsearch.core.Releasables; import org.elasticsearch.index.IndexSettings; import org.elasticsearch.index.analysis.AnalysisRegistry; +import org.elasticsearch.index.analysis.NameOrDefinition; import org.elasticsearch.index.analysis.NamedAnalyzer; import org.elasticsearch.index.cache.bitset.BitsetFilterCache; import org.elasticsearch.index.fielddata.IndexFieldData; @@ -97,9 +98,28 @@ public final FieldContext buildFieldContext(String field) { } /** - * @return The analysis registry for the node. Allows specialized aggregations to build custom analyzers for tokenizing text + * Returns an existing registered analyzer that should NOT be closed when finished being used. + * @param analyzer The custom analyzer name + * @return The existing named analyzer. */ - public abstract AnalysisRegistry getAnalysisRegistry(); + public abstract Analyzer getNamedAnalyzer(String analyzer) throws IOException; + + /** + * Creates a new custom analyzer that should be closed when finished being used. + * @param indexSettings The current index settings or null + * @param normalizer Is a normalizer + * @param tokenizer The tokenizer name or definition to use + * @param charFilters The char filter name or definition to use + * @param tokenFilters The token filter name or definition to use + * @return A new custom analyzer + */ + public abstract Analyzer buildCustomAnalyzer( + IndexSettings indexSettings, + boolean normalizer, + NameOrDefinition tokenizer, + List charFilters, + List tokenFilters + ) throws IOException; /** * Lookup the context for an already resolved field type. @@ -336,11 +356,6 @@ public ProductionAggregationContext( this.enableRewriteToFilterByFilter = enableRewriteToFilterByFilter; } - @Override - public AnalysisRegistry getAnalysisRegistry() { - return this.analysisRegistry; - } - @Override public Query query() { return topLevelQuery.get(); @@ -364,6 +379,22 @@ public long nowInMillis() { return context.nowInMillis(); } + @Override + public Analyzer getNamedAnalyzer(String analyzer) throws IOException { + return analysisRegistry.getAnalyzer(analyzer); + } + + @Override + public Analyzer buildCustomAnalyzer( + IndexSettings indexSettings, + boolean normalizer, + NameOrDefinition tokenizer, + List charFilters, + List tokenFilters + ) throws IOException { + return analysisRegistry.buildCustomAnalyzer(indexSettings, normalizer, tokenizer, charFilters, tokenFilters); + } + @Override protected IndexFieldData buildFieldData(MappedFieldType ft) { return context.getForField(ft); diff --git a/test/framework/src/main/java/org/elasticsearch/index/mapper/MapperServiceTestCase.java b/test/framework/src/main/java/org/elasticsearch/index/mapper/MapperServiceTestCase.java index ae42718c241a0..c9d1529e45430 100644 --- a/test/framework/src/main/java/org/elasticsearch/index/mapper/MapperServiceTestCase.java +++ b/test/framework/src/main/java/org/elasticsearch/index/mapper/MapperServiceTestCase.java @@ -35,9 +35,9 @@ import org.elasticsearch.common.xcontent.XContentType; import org.elasticsearch.common.xcontent.json.JsonXContent; import org.elasticsearch.index.IndexSettings; -import org.elasticsearch.index.analysis.AnalysisRegistry; import org.elasticsearch.index.analysis.AnalyzerScope; import org.elasticsearch.index.analysis.IndexAnalyzers; +import org.elasticsearch.index.analysis.NameOrDefinition; import org.elasticsearch.index.analysis.NamedAnalyzer; import org.elasticsearch.index.cache.bitset.BitsetFilterCache; import org.elasticsearch.index.fielddata.IndexFieldData; @@ -354,7 +354,18 @@ public long nowInMillis() { } @Override - public AnalysisRegistry getAnalysisRegistry() { + public Analyzer getNamedAnalyzer(String analyzer) { + return null; + } + + @Override + public Analyzer buildCustomAnalyzer( + IndexSettings indexSettings, + boolean normalizer, + NameOrDefinition tokenizer, + List charFilters, + List tokenFilters + ) { return null; } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/CategorizationBytesRefHash.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/CategorizationBytesRefHash.java index 35bc15ceb02af..c143a4d72c11a 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/CategorizationBytesRefHash.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/CategorizationBytesRefHash.java @@ -26,13 +26,6 @@ class CategorizationBytesRefHash implements Closeable { this.bytesRefHash = bytesRefHash; } - BytesRef getShallow(long id) { - if (id == WILD_CARD_ID) { - return WILD_CARD_REF; - } - return bytesRefHash.get(id, new BytesRef()); - } - long[] getIds(BytesRef[] tokens) { long[] ids = new long[tokens.length]; for (int i = 0; i < tokens.length; i++) { @@ -41,10 +34,10 @@ long[] getIds(BytesRef[] tokens) { return ids; } - BytesRef[] getShallows(long[] ids) { + BytesRef[] getDeeps(long[] ids) { BytesRef[] tokens = new BytesRef[ids.length]; for (int i = 0; i < tokens.length; i++) { - tokens[i] = getShallow(ids[i]); + tokens[i] = getDeep(ids[i]); } return tokens; } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/CategorizeTextAggregator.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/CategorizeTextAggregator.java index 9cd44467669fe..c6ec1a93ecdc2 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/CategorizeTextAggregator.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/CategorizeTextAggregator.java @@ -7,10 +7,12 @@ package org.elasticsearch.xpack.ml.aggs.categorization; +import org.apache.lucene.analysis.Analyzer; import org.apache.lucene.analysis.TokenStream; import org.apache.lucene.analysis.tokenattributes.CharTermAttribute; import org.apache.lucene.index.LeafReaderContext; import org.apache.lucene.util.BytesRef; +import org.apache.lucene.util.BytesRefBuilder; import org.apache.lucene.util.PriorityQueue; import org.elasticsearch.common.util.BytesRefHash; import org.elasticsearch.common.util.ObjectArray; @@ -71,11 +73,29 @@ protected CategorizeTextAggregator( this.sourceLookup = context.lookup().source(); this.sourceFieldName = sourceFieldName; this.fieldType = fieldType; - this.analyzer = new CategorizationAnalyzer( - context.getAnalysisRegistry(), - Optional.ofNullable(categorizationAnalyzerConfig) - .orElse(CategorizationAnalyzerConfig.buildStandardCategorizationAnalyzer(Collections.emptyList())) - ); + CategorizationAnalyzerConfig analyzerConfig = Optional.ofNullable(categorizationAnalyzerConfig) + .orElse(CategorizationAnalyzerConfig.buildStandardCategorizationAnalyzer(Collections.emptyList())); + String analyzer = analyzerConfig.getAnalyzer(); + final boolean shouldClose; + final Analyzer innerAnalyzer; + if (analyzer != null) { + Analyzer globalAnalyzer = context.getNamedAnalyzer(analyzer); + if (globalAnalyzer == null) { + throw new IllegalArgumentException("Failed to find global analyzer [" + analyzer + "]"); + } + innerAnalyzer = globalAnalyzer; + shouldClose = false; + } else { + innerAnalyzer = context.buildCustomAnalyzer( + context.getIndexSettings(), + false, + analyzerConfig.getTokenizer(), + analyzerConfig.getCharFilters(), + analyzerConfig.getTokenFilters() + ); + shouldClose = true; + } + this.analyzer = new CategorizationAnalyzer(innerAnalyzer, shouldClose); this.categorizers = bigArrays().newObjectArray(1); this.maxChildren = maxChildren; this.maxDepth = maxDepth; @@ -156,6 +176,7 @@ public InternalAggregation buildEmptyAggregation() { @Override protected LeafBucketCollector getLeafCollector(LeafReaderContext ctx, LeafBucketCollector sub) throws IOException { return new LeafBucketCollectorBase(sub, null) { + private final BytesRefBuilder scratch = new BytesRefBuilder(); @Override public void collect(int doc, long owningBucketOrd) throws IOException { diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/InternalCategorizationAggregation.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/InternalCategorizationAggregation.java index 84f19ccff5e6d..58f3d8df81d7f 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/InternalCategorizationAggregation.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/InternalCategorizationAggregation.java @@ -369,7 +369,7 @@ public InternalAggregation reduce(List aggregations, Reduce "Unexpected null categorization group for bucket [" + delayedBucket.key.asString() + "]" ); } - BytesRef[] categoryTokens = hash.getShallows(group.getCategorization()); + BytesRef[] categoryTokens = hash.getDeeps(group.getCategorization()); BucketKey key = reduceContext.isFinalReduce() ? BucketKey.withCollapsedWildcards(categoryTokens) : diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/job/categorization/CategorizationAnalyzer.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/job/categorization/CategorizationAnalyzer.java index 229b505c21783..d7e403aa59034 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/job/categorization/CategorizationAnalyzer.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/job/categorization/CategorizationAnalyzer.java @@ -38,6 +38,11 @@ public CategorizationAnalyzer(AnalysisRegistry analysisRegistry, closeAnalyzer = tuple.v2(); } + public CategorizationAnalyzer(Analyzer analyzer, boolean closeAnalyzer) { + this.analyzer = analyzer; + this.closeAnalyzer = closeAnalyzer; + } + public final TokenStream tokenStream(final String fieldName, final String text) { return analyzer.tokenStream(fieldName, text); From 5deb2034f43a885963c61f7bf96e903c14f74967 Mon Sep 17 00:00:00 2001 From: Benjamin Trent <4357155+benwtrent@users.noreply.github.com> Date: Thu, 23 Sep 2021 10:05:21 -0400 Subject: [PATCH 14/20] updating docs, restricting to 2billion unique tokens --- .../AggConstructionContentionBenchmark.java | 14 ++++++++- .../categorize-text-aggregation.asciidoc | 12 ++++++-- .../CategorizationBytesRefHash.java | 29 ++++++++++++------- .../CategorizationTokenTree.java | 12 ++++---- .../CategorizeTextAggregator.java | 4 +-- .../categorization/TextCategorization.java | 11 ++++--- .../ml/aggs/categorization/TreeNode.java | 16 +++++----- .../aggs/categorization/TreeNodeFactory.java | 4 +-- .../TextCategorizationTests.java | 2 +- 9 files changed, 65 insertions(+), 39 deletions(-) diff --git a/benchmarks/src/main/java/org/elasticsearch/benchmark/search/aggregations/AggConstructionContentionBenchmark.java b/benchmarks/src/main/java/org/elasticsearch/benchmark/search/aggregations/AggConstructionContentionBenchmark.java index 996ab9dc66850..61608a4d30ef8 100644 --- a/benchmarks/src/main/java/org/elasticsearch/benchmark/search/aggregations/AggConstructionContentionBenchmark.java +++ b/benchmarks/src/main/java/org/elasticsearch/benchmark/search/aggregations/AggConstructionContentionBenchmark.java @@ -23,6 +23,7 @@ import org.elasticsearch.index.Index; import org.elasticsearch.index.IndexSettings; import org.elasticsearch.index.analysis.AnalysisRegistry; +import org.elasticsearch.index.analysis.NameOrDefinition; import org.elasticsearch.index.analysis.NamedAnalyzer; import org.elasticsearch.index.cache.bitset.BitsetFilterCache; import org.elasticsearch.index.fielddata.IndexFieldData; @@ -199,7 +200,18 @@ public long nowInMillis() { } @Override - public AnalysisRegistry getAnalysisRegistry() { + public Analyzer getNamedAnalyzer(String analyzer) { + return null; + } + + @Override + public Analyzer buildCustomAnalyzer( + IndexSettings indexSettings, + boolean normalizer, + NameOrDefinition tokenizer, + List charFilters, + List tokenFilters + ) { return null; } diff --git a/docs/reference/aggregations/bucket/categorize-text-aggregation.asciidoc b/docs/reference/aggregations/bucket/categorize-text-aggregation.asciidoc index 4f054049a03a3..030cd536ace63 100644 --- a/docs/reference/aggregations/bucket/categorize-text-aggregation.asciidoc +++ b/docs/reference/aggregations/bucket/categorize-text-aggregation.asciidoc @@ -11,7 +11,15 @@ using a custom analyzer. The resulting tokens are then categorized creating buck text values. This aggregation works best with machine generated text like system logs. WARNING: Re-analyzing _large_ result sets will require a lot of time and memory. This aggregation should be - used in conjunction with <>. + used in conjunction with <>. Additionally, you may consider + using the aggregation as a child of either the <> or + <> aggregation. + This will typically improve speed and memory use. + +NOTE: If you have considerable memory allocated to your JVM but are receiving circuit breaker exceptions from this + aggregation, you may be attempting to categorize text that is poorly formatted for categorization. Consider + adding `categorization_filters` or running under <> or + <> to explore the created categories. [[bucket-categorize-text-agg-syntax]] ==== Parameters @@ -24,7 +32,7 @@ The semi-structured text field to categorize. (Optional, integer, default: `50`) The maximum number of unique tokens at any given layer of the tokenization tree. Must be larger than 1. Smaller values use less memory and create fewer categories. -Larger values will use more memory and create more categories. +Larger values will use more memory and create narrower categories. `max_depth`:: (Optional, integer, default: `5`) diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/CategorizationBytesRefHash.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/CategorizationBytesRefHash.java index c143a4d72c11a..1e4143993d8f7 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/CategorizationBytesRefHash.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/CategorizationBytesRefHash.java @@ -7,34 +7,33 @@ package org.elasticsearch.xpack.ml.aggs.categorization; -import org.apache.logging.log4j.LogManager; -import org.apache.logging.log4j.Logger; import org.apache.lucene.util.BytesRef; +import org.elasticsearch.common.logging.LoggerMessageFormat; import org.elasticsearch.common.util.BytesRefHash; +import org.elasticsearch.search.aggregations.AggregationExecutionException; import java.io.Closeable; import java.io.IOException; class CategorizationBytesRefHash implements Closeable { - private static final Logger logger = LogManager.getLogger(CategorizationBytesRefHash.class); static final BytesRef WILD_CARD_REF = new BytesRef("*"); - static final long WILD_CARD_ID = -1; + static final int WILD_CARD_ID = -1; private final BytesRefHash bytesRefHash; CategorizationBytesRefHash(BytesRefHash bytesRefHash) { this.bytesRefHash = bytesRefHash; } - long[] getIds(BytesRef[] tokens) { - long[] ids = new long[tokens.length]; + int[] getIds(BytesRef[] tokens) { + int[] ids = new int[tokens.length]; for (int i = 0; i < tokens.length; i++) { ids[i] = put(tokens[i]); } return ids; } - BytesRef[] getDeeps(long[] ids) { + BytesRef[] getDeeps(int[] ids) { BytesRef[] tokens = new BytesRef[ids.length]; for (int i = 0; i < tokens.length; i++) { tokens[i] = getDeep(ids[i]); @@ -50,18 +49,26 @@ BytesRef getDeep(long id) { return BytesRef.deepCopyOf(shallow); } - long put(BytesRef bytesRef) { + int put(BytesRef bytesRef) { if (WILD_CARD_REF.equals(bytesRef)) { return WILD_CARD_ID; } long hash = bytesRefHash.add(bytesRef); if (hash < 0) { - return -1 - hash; + return (int) (-1L - hash); } else { if (hash > Integer.MAX_VALUE) { - logger.error("More than Integer.MAX_VALUE unique terms"); + throw new AggregationExecutionException( + LoggerMessageFormat.format( + "more than [{}] unique terms encountered. " + + "Consider restricting the documents queried or adding [{}] in the {} configuration", + Integer.MAX_VALUE, + CategorizeTextAggregationBuilder.CATEGORIZATION_FILTERS.getPreferredName(), + CategorizeTextAggregationBuilder.NAME + ) + ); } - return hash; + return (int) hash; } } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/CategorizationTokenTree.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/CategorizationTokenTree.java index fd9ff219ff471..d8f5f315b1a6e 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/CategorizationTokenTree.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/CategorizationTokenTree.java @@ -76,7 +76,7 @@ public CategorizationTokenTree(int maxChildren, int maxDepth, int similarityThre public List toIntermediateBuckets(CategorizationBytesRefHash hash) { return root.values().stream().flatMap(c -> c.getAllChildrenLogGroups().stream()).map(lg -> { - long[] categoryTokenIds = lg.getCategorization(); + int[] categoryTokenIds = lg.getCategorization(); BytesRef[] bytesRefs = new BytesRef[categoryTokenIds.length]; for (int i = 0; i < categoryTokenIds.length; i++) { bytesRefs[i] = hash.getDeep(categoryTokenIds[i]); @@ -95,11 +95,11 @@ void mergeSmallestChildren() { root.values().forEach(TreeNode::collapseTinyChildren); } - public TextCategorization parseLogLine(final long[] logTokenIds) { + public TextCategorization parseLogLine(final int[] logTokenIds) { return parseLogLine(logTokenIds, 1); } - public TextCategorization parseLogLineConst(final long[] logTokenIds) { + public TextCategorization parseLogLineConst(final int[] logTokenIds) { TreeNode currentNode = this.root.get(logTokenIds.length); if (currentNode == null) { // we are missing an entire sub tree. New log length found return null; @@ -107,7 +107,7 @@ public TextCategorization parseLogLineConst(final long[] logTokenIds) { return currentNode.getLogGroup(logTokenIds); } - public TextCategorization parseLogLine(final long[] logTokenIds, long docCount) { + public TextCategorization parseLogLine(final int[] logTokenIds, long docCount) { TreeNode currentNode = this.root.get(logTokenIds.length); if (currentNode == null) { // we are missing an entire sub tree. New log length found currentNode = newNode(docCount, 0, logTokenIds); @@ -119,7 +119,7 @@ public TextCategorization parseLogLine(final long[] logTokenIds, long docCount) } @Override - public TreeNode newNode(long docCount, int tokenPos, long[] logTokenIds) { + public TreeNode newNode(long docCount, int tokenPos, int[] logTokenIds) { TreeNode node = tokenPos < maxDepth - 1 && tokenPos < logTokenIds.length ? new TreeNode.InnerTreeNode(docCount, tokenPos, maxChildren) : new TreeNode.LeafTreeNode(docCount, similarityThreshold); @@ -129,7 +129,7 @@ public TreeNode newNode(long docCount, int tokenPos, long[] logTokenIds) { } @Override - public TextCategorization newGroup(long docCount, long[] logTokenIds) { + public TextCategorization newGroup(long docCount, int[] logTokenIds) { TextCategorization group = new TextCategorization(logTokenIds, docCount, idGen.incrementAndGet()); // Get the regular size bytes from the LogGroup and how much it costs to reference it sizeInBytes += group.ramBytesUsed() + RamUsageEstimator.NUM_BYTES_OBJECT_REF; diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/CategorizeTextAggregator.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/CategorizeTextAggregator.java index c6ec1a93ecdc2..03ae9ba3d6783 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/CategorizeTextAggregator.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/CategorizeTextAggregator.java @@ -201,7 +201,7 @@ private void collectFromSource(int doc, long owningBucketOrd) throws IOException } private void processTokenStream(long owningBucketOrd, TokenStream ts, int doc) throws IOException { - ArrayList tokens = new ArrayList<>(); + ArrayList tokens = new ArrayList<>(); try { CharTermAttribute termAtt = ts.addAttribute(CharTermAttribute.class); ts.reset(); @@ -223,7 +223,7 @@ private void processTokenStream(long owningBucketOrd, TokenStream ts, int doc) t } long previousSize = categorizer.ramBytesUsed(); TextCategorization lg = categorizer.parseLogLine( - tokens.stream().mapToLong(Long::valueOf).toArray(), + tokens.stream().mapToInt(Integer::valueOf).toArray(), docCountProvider.getDocCount(doc) ); long newSize = categorizer.ramBytesUsed(); diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/TextCategorization.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/TextCategorization.java index 6c5e153026458..84c37c2486ef3 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/TextCategorization.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/TextCategorization.java @@ -22,15 +22,14 @@ class TextCategorization implements Accountable { private final long id; - // TODO Do we want to just make this native arrays? - private final long[] categorization; + private final int[] categorization; private final long[] tokenCounts; private long count; // Used at the shard level for tracking the bucket ordinal for collecting sub aggregations long bucketOrd; - TextCategorization(long[] logTokenIds, long count, long id) { + TextCategorization(int[] logTokenIds, long count, long id) { this.id = id; this.categorization = logTokenIds; this.count = count; @@ -42,7 +41,7 @@ public long getId() { return id; } - long[] getCategorization() { + int[] getCategorization() { return categorization; } @@ -50,7 +49,7 @@ public long getCount() { return count; } - Similarity calculateSimilarity(long[] logEvent) { + Similarity calculateSimilarity(int[] logEvent) { assert logEvent.length == this.categorization.length; int eqParams = 0; long tokenCount = 0; @@ -68,7 +67,7 @@ Similarity calculateSimilarity(long[] logEvent) { return new Similarity((double) tokensKept / tokenCount, eqParams); } - void addLog(long[] logEvent, long docCount) { + void addLog(int[] logEvent, long docCount) { assert logEvent.length == this.categorization.length; for (int i = 0; i < logEvent.length; i++) { if (logEvent[i] != this.categorization[i]) { diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/TreeNode.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/TreeNode.java index 6a6a5ec51c712..badde8f152a32 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/TreeNode.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/TreeNode.java @@ -58,9 +58,9 @@ final long getCount() { } // TODO add option for calculating the cost of adding the new group - abstract TextCategorization addLog(long[] logTokenIds, long docCount, TreeNodeFactory treeNodeFactory); + abstract TextCategorization addLog(int[] logTokenIds, long docCount, TreeNodeFactory treeNodeFactory); - abstract TextCategorization getLogGroup(long[] logTokens); + abstract TextCategorization getLogGroup(int[] logTokens); abstract List getAllChildrenLogGroups(); @@ -111,7 +111,7 @@ public long ramBytesUsed() { } @Override - public TextCategorization addLog(long[] logTokenIds, long docCount, TreeNodeFactory treeNodeFactory) { + public TextCategorization addLog(int[] logTokenIds, long docCount, TreeNodeFactory treeNodeFactory) { return getAndUpdateLogGroup(logTokenIds, docCount).orElseGet(() -> { // Need to update the tree if possible return putNewLogGroup(treeNodeFactory.newGroup(docCount, logTokenIds)); @@ -126,7 +126,7 @@ List getAllChildrenLogGroups() { @Override void collapseTinyChildren() {} - private Optional getAndUpdateLogGroup(long[] logTokenIds, long docCount) { + private Optional getAndUpdateLogGroup(int[] logTokenIds, long docCount) { return getBestLogGroup(logTokenIds).map(bestGroupAndSimilarity -> { if ((bestGroupAndSimilarity.v2() * 100) >= similarityThreshold) { bestGroupAndSimilarity.v1().addLog(logTokenIds, docCount); @@ -141,7 +141,7 @@ TextCategorization putNewLogGroup(TextCategorization group) { return group; } - private Optional> getBestLogGroup(long[] logTokenIds) { + private Optional> getBestLogGroup(int[] logTokenIds) { if (textCategorizations.isEmpty()) { return Optional.empty(); } @@ -163,7 +163,7 @@ private Optional> getBestLogGroup(long[] logTo } @Override - public TextCategorization getLogGroup(final long[] logTokenIds) { + public TextCategorization getLogGroup(final int[] logTokenIds) { return getBestLogGroup(logTokenIds).map(Tuple::v1).orElse(null); } @@ -203,7 +203,7 @@ boolean isLeaf() { } @Override - public TextCategorization getLogGroup(final long[] logTokenIds) { + public TextCategorization getLogGroup(final int[] logTokenIds) { return getChild(logTokenIds[childrenTokenPos]).or(() -> getChild(WILD_CARD_ID)) .map(node -> node.getLogGroup(logTokenIds)) .orElse(null); @@ -222,7 +222,7 @@ public long ramBytesUsed() { } @Override - public TextCategorization addLog(final long[] logTokenIds, final long docCount, final TreeNodeFactory treeNodeFactory) { + public TextCategorization addLog(final int[] logTokenIds, final long docCount, final TreeNodeFactory treeNodeFactory) { final long currentToken = logTokenIds[childrenTokenPos]; TreeNode child = getChild(currentToken).map(node -> { node.incCount(docCount); diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/TreeNodeFactory.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/TreeNodeFactory.java index 1533987ae5954..adbc7997cb5c3 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/TreeNodeFactory.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/TreeNodeFactory.java @@ -9,7 +9,7 @@ interface TreeNodeFactory { - TreeNode newNode(long docCount, int tokenPos, long[] logTokenIds); + TreeNode newNode(long docCount, int tokenPos, int[] logTokenIds); - TextCategorization newGroup(long docCount, long[] logTokenIds); + TextCategorization newGroup(long docCount, int[] logTokenIds); } diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/aggs/categorization/TextCategorizationTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/aggs/categorization/TextCategorizationTests.java index bde8e3a8a6d42..827019ea86843 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/aggs/categorization/TextCategorizationTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/aggs/categorization/TextCategorizationTests.java @@ -63,7 +63,7 @@ public void testAddLog() { assertArrayEquals(lg.getCategorization(), getTokens(bytesRefHash, "foo", "bar", "baz", "*")); } - static long[] getTokens(CategorizationBytesRefHash bytesRefHash, String... tokens) { + static int[] getTokens(CategorizationBytesRefHash bytesRefHash, String... tokens) { BytesRef[] refs = new BytesRef[tokens.length]; int i = 0; for (String token : tokens) { From 1cf93ebc38dd95c189d99a57fc24fa242c1d82e3 Mon Sep 17 00:00:00 2001 From: Benjamin Trent <4357155+benwtrent@users.noreply.github.com> Date: Thu, 23 Sep 2021 16:03:13 -0400 Subject: [PATCH 15/20] renaming parameters and updating docs --- .../categorize-text-aggregation.asciidoc | 22 ++++-- .../query-dsl/has-child-query.asciidoc | 4 +- .../CategorizationAggregationIT.java | 4 +- .../CategorizationTokenTree.java | 16 ++--- .../CategorizeTextAggregationBuilder.java | 68 +++++++++---------- .../CategorizeTextAggregator.java | 22 +++--- .../CategorizeTextAggregatorFactory.java | 20 +++--- .../InternalCategorizationAggregation.java | 40 ++++++----- .../UnmappedCategorizationAggregation.java | 8 +-- ...CategorizeTextAggregationBuilderTests.java | 4 +- .../test/ml/categorization_agg.yml | 12 ++-- 11 files changed, 116 insertions(+), 104 deletions(-) diff --git a/docs/reference/aggregations/bucket/categorize-text-aggregation.asciidoc b/docs/reference/aggregations/bucket/categorize-text-aggregation.asciidoc index 030cd536ace63..636b6c32d1f9d 100644 --- a/docs/reference/aggregations/bucket/categorize-text-aggregation.asciidoc +++ b/docs/reference/aggregations/bucket/categorize-text-aggregation.asciidoc @@ -28,17 +28,26 @@ NOTE: If you have considerable memory allocated to your JVM but are receiving ci (Required, string) The semi-structured text field to categorize. -`max_children`:: +`max_unique_tokens`:: (Optional, integer, default: `50`) -The maximum number of unique tokens at any given layer of the tokenization tree. +The maximum number of unique tokens at any position up to `max_matched_tokens`. Must be larger than 1. Smaller values use less memory and create fewer categories. Larger values will use more memory and create narrower categories. -`max_depth`:: +`max_matched_tokens`:: (Optional, integer, default: `5`) -The maximum number of tokens matched on before attempting to merge categories. +The maximum number of token positions to match on before attempting to merge categories. Larger values will use more memory and create narrower categories. +Example: +`max_matched_tokens` of 2 would disallow merging of the categories +[`foo` `bar` `baz`] +[`foo` `baz` `bozo`] +As the first 2 tokens are required to match for the category. + +NOTE: Once `max_unique_tokens` is reached at a given position, a new `*` token is +added and all new tokens at that position are matched by the `*` token. + `similarity_threshold`:: (Optional, integer, default: `50`) The minimum percentage of tokens that must match for text to be added to the @@ -220,7 +229,7 @@ POST log-messages/_search?filter_path=aggregations "categorize_text": { "field": "message", "categorization_filters": ["\\w+\\_\\d{3}"], <1> - "max_depth": 2, <2> + "max_matched_tokens": 2, <2> "similarity_threshold": 30 <3> } } @@ -230,8 +239,7 @@ POST log-messages/_search?filter_path=aggregations // TEST[setup:categorize_text] <1> The filters to apply to the analyzed tokens. It filters out tokens like `bar_123`. -<2> Only the token tree to have 2 tokens before the log categories - attempt to merge together +<2> Require at least 2 tokens before the log categories attempt to merge together <3> Require 30% of the tokens to match before expanding a log categories to add a new log entry diff --git a/docs/reference/query-dsl/has-child-query.asciidoc b/docs/reference/query-dsl/has-child-query.asciidoc index 7b8158aaab5bf..9021f04c762b2 100644 --- a/docs/reference/query-dsl/has-child-query.asciidoc +++ b/docs/reference/query-dsl/has-child-query.asciidoc @@ -16,7 +16,7 @@ unique parent documents increases. Each `has_child` query in a search can increase query time significantly. If you care about query performance, do not use this query. If you need to use -the `has_child` query, use it as rarely as possible. +the `has_child` query, use it as rarely as possible. ==== [[has-child-query-ex-request]] @@ -59,7 +59,7 @@ GET /_search "query": { "match_all": {} }, - "max_children": 10, + "max_unique_tokens": 10, "min_children": 2, "score_mode": "min" } diff --git a/x-pack/plugin/ml/src/internalClusterTest/java/org/elasticsearch/xpack/ml/integration/CategorizationAggregationIT.java b/x-pack/plugin/ml/src/internalClusterTest/java/org/elasticsearch/xpack/ml/integration/CategorizationAggregationIT.java index 0f289a0419db0..bd85984b3619f 100644 --- a/x-pack/plugin/ml/src/internalClusterTest/java/org/elasticsearch/xpack/ml/integration/CategorizationAggregationIT.java +++ b/x-pack/plugin/ml/src/internalClusterTest/java/org/elasticsearch/xpack/ml/integration/CategorizationAggregationIT.java @@ -83,8 +83,8 @@ public void testAggregationWithBroadCategories() { .addAggregation( new CategorizeTextAggregationBuilder("categorize", "msg") .setSimilarityThreshold(11) - .setMaxChildren(2) - .setMaxDepth(1) + .setMaxUniqueTokens(2) + .setMaxMatchedTokens(1) .subAggregation(AggregationBuilders.max("max").field("time")) .subAggregation(AggregationBuilders.min("min").field("time")) ).get(); diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/CategorizationTokenTree.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/CategorizationTokenTree.java index d8f5f315b1a6e..560bcddee8bf1 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/CategorizationTokenTree.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/CategorizationTokenTree.java @@ -53,18 +53,18 @@ */ public class CategorizationTokenTree implements Accountable, TreeNodeFactory { - private final int maxDepth; - private final int maxChildren; + private final int maxMatchTokens; + private final int maxUniqueTokens; private final int similarityThreshold; private final AtomicLong idGen = new AtomicLong(); // TODO statically allocate an array like DuplicateByteSequenceSpotter ??? private final Map root = new HashMap<>(); private long sizeInBytes; - public CategorizationTokenTree(int maxChildren, int maxDepth, int similarityThreshold) { - assert maxChildren > 0 && maxDepth >= 0; - this.maxChildren = maxChildren; - this.maxDepth = maxDepth; + public CategorizationTokenTree(int maxUniqueTokens, int maxMatchTokens, int similarityThreshold) { + assert maxUniqueTokens > 0 && maxMatchTokens >= 0; + this.maxUniqueTokens = maxUniqueTokens; + this.maxMatchTokens = maxMatchTokens; this.similarityThreshold = similarityThreshold; this.sizeInBytes = Integer.BYTES // maxDepth + Integer.BYTES // maxChildren @@ -120,8 +120,8 @@ public TextCategorization parseLogLine(final int[] logTokenIds, long docCount) { @Override public TreeNode newNode(long docCount, int tokenPos, int[] logTokenIds) { - TreeNode node = tokenPos < maxDepth - 1 && tokenPos < logTokenIds.length - ? new TreeNode.InnerTreeNode(docCount, tokenPos, maxChildren) + TreeNode node = tokenPos < maxMatchTokens - 1 && tokenPos < logTokenIds.length + ? new TreeNode.InnerTreeNode(docCount, tokenPos, maxUniqueTokens) : new TreeNode.LeafTreeNode(docCount, similarityThreshold); // The size of the node + entry (since it is a map entry) + extra reference for priority queue sizeInBytes += node.ramBytesUsed() + RamUsageEstimator.HASHTABLE_RAM_BYTES_PER_ENTRY + RamUsageEstimator.NUM_BYTES_OBJECT_REF; diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/CategorizeTextAggregationBuilder.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/CategorizeTextAggregationBuilder.java index e4916ca46feaf..e84167fd1d0a9 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/CategorizeTextAggregationBuilder.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/CategorizeTextAggregationBuilder.java @@ -43,14 +43,14 @@ public class CategorizeTextAggregationBuilder extends AbstractAggregationBuilder -1 ); - static final int MAX_MAX_CHILDREN = 100; - static final int MAX_MAX_DEPTH = 100; + static final int MAX_MAX_UNIQUE_TOKENS = 100; + static final int MAX_MAX_MATCHED_TOKENS = 100; public static final String NAME = "categorize_text"; static final ParseField FIELD_NAME = new ParseField("field"); - static final ParseField MAX_CHILDREN = new ParseField("max_children"); + static final ParseField MAX_UNIQUE_TOKENS = new ParseField("max_unique_tokens"); static final ParseField SIMILARITY_THRESHOLD = new ParseField("similarity_threshold"); - static final ParseField MAX_DEPTH = new ParseField("max_depth"); + static final ParseField MAX_MATCHED_TOKENS = new ParseField("max_matched_tokens"); static final ParseField CATEGORIZATION_FILTERS = new ParseField("categorization_filters"); static final ParseField CATEGORIZATION_ANALYZER = new ParseField("categorization_analyzer"); @@ -60,8 +60,8 @@ public class CategorizeTextAggregationBuilder extends AbstractAggregationBuilder ); static { PARSER.declareString(CategorizeTextAggregationBuilder::setFieldName, FIELD_NAME); - PARSER.declareInt(CategorizeTextAggregationBuilder::setMaxChildren, MAX_CHILDREN); - PARSER.declareInt(CategorizeTextAggregationBuilder::setMaxDepth, MAX_DEPTH); + PARSER.declareInt(CategorizeTextAggregationBuilder::setMaxUniqueTokens, MAX_UNIQUE_TOKENS); + PARSER.declareInt(CategorizeTextAggregationBuilder::setMaxMatchedTokens, MAX_MATCHED_TOKENS); PARSER.declareInt(CategorizeTextAggregationBuilder::setSimilarityThreshold, SIMILARITY_THRESHOLD); PARSER.declareField( CategorizeTextAggregationBuilder::setCategorizationAnalyzerConfig, @@ -85,9 +85,9 @@ public static CategorizeTextAggregationBuilder parse(String aggregationName, XCo ); private CategorizationAnalyzerConfig categorizationAnalyzerConfig; private String fieldName; - private int maxChildren = 50; + private int maxUniqueTokens = 50; private int similarityThreshold = 50; - private int maxDepth = 5; + private int maxMatchedTokens = 5; private CategorizeTextAggregationBuilder(String name) { super(name); @@ -111,24 +111,24 @@ public CategorizeTextAggregationBuilder(StreamInput in) throws IOException { super(in); this.bucketCountThresholds = new TermsAggregator.BucketCountThresholds(in); this.fieldName = in.readString(); - this.maxChildren = in.readVInt(); - this.maxDepth = in.readVInt(); + this.maxUniqueTokens = in.readVInt(); + this.maxMatchedTokens = in.readVInt(); this.similarityThreshold = in.readVInt(); this.categorizationAnalyzerConfig = in.readOptionalWriteable(CategorizationAnalyzerConfig::new); } - public int getMaxChildren() { - return maxChildren; + public int getMaxUniqueTokens() { + return maxUniqueTokens; } - public CategorizeTextAggregationBuilder setMaxChildren(int maxChildren) { - this.maxChildren = maxChildren; - if (maxChildren <= 0) { + public CategorizeTextAggregationBuilder setMaxUniqueTokens(int maxUniqueTokens) { + this.maxUniqueTokens = maxUniqueTokens; + if (maxUniqueTokens <= 0) { throw ExceptionsHelper.badRequestException( "[{}] must be greater than 0 and less than [{}]. Found [{}] in [{}]", - MAX_CHILDREN.getPreferredName(), - MAX_MAX_CHILDREN, - maxChildren, + MAX_UNIQUE_TOKENS.getPreferredName(), + MAX_MAX_UNIQUE_TOKENS, + maxUniqueTokens, name ); } @@ -185,18 +185,18 @@ public CategorizeTextAggregationBuilder setCategorizationFilters(List ca return this; } - public int getMaxDepth() { - return maxDepth; + public int getMaxMatchedTokens() { + return maxMatchedTokens; } - public CategorizeTextAggregationBuilder setMaxDepth(int maxDepth) { - this.maxDepth = maxDepth; - if (maxDepth <= 0) { + public CategorizeTextAggregationBuilder setMaxMatchedTokens(int maxMatchedTokens) { + this.maxMatchedTokens = maxMatchedTokens; + if (maxMatchedTokens <= 0) { throw ExceptionsHelper.badRequestException( "[{}] must be greater than 0 and less than [{}]. Found [{}] in [{}]", - MAX_DEPTH.getPreferredName(), - MAX_MAX_DEPTH, - maxDepth, + MAX_MATCHED_TOKENS.getPreferredName(), + MAX_MAX_MATCHED_TOKENS, + maxMatchedTokens, name ); } @@ -280,8 +280,8 @@ protected CategorizeTextAggregationBuilder( super(clone, factoriesBuilder, metadata); this.bucketCountThresholds = new TermsAggregator.BucketCountThresholds(clone.bucketCountThresholds); this.fieldName = clone.fieldName; - this.maxChildren = clone.maxChildren; - this.maxDepth = clone.maxDepth; + this.maxUniqueTokens = clone.maxUniqueTokens; + this.maxMatchedTokens = clone.maxMatchedTokens; this.similarityThreshold = clone.similarityThreshold; this.categorizationAnalyzerConfig = clone.categorizationAnalyzerConfig; } @@ -290,8 +290,8 @@ protected CategorizeTextAggregationBuilder( protected void doWriteTo(StreamOutput out) throws IOException { bucketCountThresholds.writeTo(out); out.writeString(fieldName); - out.writeVInt(maxChildren); - out.writeVInt(maxDepth); + out.writeVInt(maxUniqueTokens); + out.writeVInt(maxMatchedTokens); out.writeVInt(similarityThreshold); out.writeOptionalWriteable(categorizationAnalyzerConfig); } @@ -305,8 +305,8 @@ protected AggregatorFactory doBuild( return new CategorizeTextAggregatorFactory( name, fieldName, - maxChildren, - maxDepth, + maxUniqueTokens, + maxMatchedTokens, similarityThreshold, bucketCountThresholds, categorizationAnalyzerConfig, @@ -322,8 +322,8 @@ protected XContentBuilder internalXContent(XContentBuilder builder, Params param builder.startObject(); bucketCountThresholds.toXContent(builder, params); builder.field(FIELD_NAME.getPreferredName(), fieldName); - builder.field(MAX_CHILDREN.getPreferredName(), maxChildren); - builder.field(MAX_DEPTH.getPreferredName(), maxDepth); + builder.field(MAX_UNIQUE_TOKENS.getPreferredName(), maxUniqueTokens); + builder.field(MAX_MATCHED_TOKENS.getPreferredName(), maxMatchedTokens); builder.field(SIMILARITY_THRESHOLD.getPreferredName(), similarityThreshold); if (categorizationAnalyzerConfig != null) { categorizationAnalyzerConfig.toXContent(builder, params); diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/CategorizeTextAggregator.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/CategorizeTextAggregator.java index 03ae9ba3d6783..c2d46a511847a 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/CategorizeTextAggregator.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/CategorizeTextAggregator.java @@ -49,8 +49,8 @@ public class CategorizeTextAggregator extends DeferableBucketAggregator { private final CategorizationAnalyzer analyzer; private final String sourceFieldName; private ObjectArray categorizers; - private final int maxChildren; - private final int maxDepth; + private final int maxUniqueTokens; + private final int maxMatchTokens; private final int similarityThreshold; private final LongKeyedBucketOrds bucketOrds; private final CategorizationBytesRefHash bytesRefHash; @@ -63,8 +63,8 @@ protected CategorizeTextAggregator( String sourceFieldName, MappedFieldType fieldType, TermsAggregator.BucketCountThresholds bucketCountThresholds, - int maxChildren, - int maxDepth, + int maxUniqueTokens, + int maxMatchTokens, int similarityThreshold, CategorizationAnalyzerConfig categorizationAnalyzerConfig, Map metadata @@ -97,8 +97,8 @@ protected CategorizeTextAggregator( } this.analyzer = new CategorizationAnalyzer(innerAnalyzer, shouldClose); this.categorizers = bigArrays().newObjectArray(1); - this.maxChildren = maxChildren; - this.maxDepth = maxDepth; + this.maxUniqueTokens = maxUniqueTokens; + this.maxMatchTokens = maxMatchTokens; this.similarityThreshold = similarityThreshold; this.bucketOrds = LongKeyedBucketOrds.build(bigArrays(), CardinalityUpperBound.MANY); this.bucketCountThresholds = bucketCountThresholds; @@ -150,8 +150,8 @@ public InternalAggregation[] buildAggregations(long[] ordsToCollect) throws IOEx name, bucketCountThresholds.getRequiredSize(), bucketCountThresholds.getMinDocCount(), - maxChildren, - maxDepth, + maxUniqueTokens, + maxMatchTokens, similarityThreshold, metadata(), Arrays.asList(bucketArray) @@ -166,8 +166,8 @@ public InternalAggregation buildEmptyAggregation() { name, bucketCountThresholds.getRequiredSize(), bucketCountThresholds.getMinDocCount(), - maxChildren, - maxDepth, + maxUniqueTokens, + maxMatchTokens, similarityThreshold, metadata() ); @@ -217,7 +217,7 @@ private void processTokenStream(long owningBucketOrd, TokenStream ts, int doc) t categorizers = bigArrays().grow(categorizers, owningBucketOrd + 1); CategorizationTokenTree categorizer = categorizers.get(owningBucketOrd); if (categorizer == null) { - categorizer = new CategorizationTokenTree(maxChildren, maxDepth, similarityThreshold); + categorizer = new CategorizationTokenTree(maxUniqueTokens, maxMatchTokens, similarityThreshold); addRequestCircuitBreakerBytes(categorizer.ramBytesUsed()); categorizers.set(owningBucketOrd, categorizer); } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/CategorizeTextAggregatorFactory.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/CategorizeTextAggregatorFactory.java index ceb3b6c97e4ac..f63b4ba1f802b 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/CategorizeTextAggregatorFactory.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/CategorizeTextAggregatorFactory.java @@ -26,8 +26,8 @@ public class CategorizeTextAggregatorFactory extends AggregatorFactory { private final MappedFieldType fieldType; private final String indexedFieldName; - private final int maxChildren; - private final int maxDepth; + private final int maxUniqueTokens; + private final int maxMatchTokens; private final int similarityThreshold; private final CategorizationAnalyzerConfig categorizationAnalyzerConfig; private final TermsAggregator.BucketCountThresholds bucketCountThresholds; @@ -35,8 +35,8 @@ public class CategorizeTextAggregatorFactory extends AggregatorFactory { public CategorizeTextAggregatorFactory( String name, String fieldName, - int maxChildren, - int maxDepth, + int maxUniqueTokens, + int maxMatchTokens, int similarityThreshold, TermsAggregator.BucketCountThresholds bucketCountThresholds, CategorizationAnalyzerConfig categorizationAnalyzerConfig, @@ -52,8 +52,8 @@ public CategorizeTextAggregatorFactory( } else { this.indexedFieldName = null; } - this.maxChildren = maxChildren; - this.maxDepth = maxDepth; + this.maxUniqueTokens = maxUniqueTokens; + this.maxMatchTokens = maxMatchTokens; this.similarityThreshold = similarityThreshold; this.categorizationAnalyzerConfig = categorizationAnalyzerConfig; this.bucketCountThresholds = bucketCountThresholds; @@ -64,8 +64,8 @@ protected Aggregator createUnmapped(Aggregator parent, Map metad name, bucketCountThresholds.getRequiredSize(), bucketCountThresholds.getMinDocCount(), - maxChildren, - maxDepth, + maxUniqueTokens, + maxMatchTokens, similarityThreshold, metadata ); @@ -101,8 +101,8 @@ protected Aggregator createInternal(Aggregator parent, CardinalityUpperBound car indexedFieldName, fieldType, bucketCountThresholds, - maxChildren, - maxDepth, + maxUniqueTokens, + maxMatchTokens, similarityThreshold, categorizationAnalyzerConfig, metadata diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/InternalCategorizationAggregation.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/InternalCategorizationAggregation.java index 58f3d8df81d7f..16cad0731d500 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/InternalCategorizationAggregation.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/InternalCategorizationAggregation.java @@ -247,9 +247,9 @@ public int compareTo(Bucket o) { } private final List buckets; - protected final int maxChildren; + protected final int maxUniqueTokens; protected final int similarityThreshold; - protected final int maxDepth; + protected final int maxMatchTokens; protected final int requiredSize; protected final long minDocCount; @@ -257,28 +257,28 @@ protected InternalCategorizationAggregation( String name, int requiredSize, long minDocCount, - int maxChildren, - int maxDepth, + int maxUniqueTokens, + int maxMatchTokens, int similarityThreshold, Map metadata ) { - this(name, requiredSize, minDocCount, maxChildren, maxDepth, similarityThreshold, metadata, new ArrayList<>()); + this(name, requiredSize, minDocCount, maxUniqueTokens, maxMatchTokens, similarityThreshold, metadata, new ArrayList<>()); } protected InternalCategorizationAggregation( String name, int requiredSize, long minDocCount, - int maxChildren, - int maxDepth, + int maxUniqueTokens, + int maxMatchTokens, int similarityThreshold, Map metadata, List buckets ) { super(name, metadata); this.buckets = buckets; - this.maxChildren = maxChildren; - this.maxDepth = maxDepth; + this.maxUniqueTokens = maxUniqueTokens; + this.maxMatchTokens = maxMatchTokens; this.similarityThreshold = similarityThreshold; this.minDocCount = minDocCount; this.requiredSize = requiredSize; @@ -286,8 +286,8 @@ protected InternalCategorizationAggregation( public InternalCategorizationAggregation(StreamInput in) throws IOException { super(in); - this.maxChildren = in.readVInt(); - this.maxDepth = in.readVInt(); + this.maxUniqueTokens = in.readVInt(); + this.maxMatchTokens = in.readVInt(); this.similarityThreshold = in.readVInt(); this.buckets = in.readList(Bucket::new); this.requiredSize = readSize(in); @@ -300,8 +300,8 @@ public InternalCategorizationAggregation create(List buckets) { name, requiredSize, minDocCount, - maxChildren, - maxDepth, + maxUniqueTokens, + maxMatchTokens, similarityThreshold, super.metadata, buckets @@ -330,8 +330,8 @@ public String getWriteableName() { @Override protected void doWriteTo(StreamOutput out) throws IOException { - out.writeVInt(maxChildren); - out.writeVInt(maxDepth); + out.writeVInt(maxUniqueTokens); + out.writeVInt(maxMatchTokens); out.writeVInt(similarityThreshold); out.writeList(buckets); writeSize(requiredSize, out); @@ -341,7 +341,11 @@ protected void doWriteTo(StreamOutput out) throws IOException { @Override public InternalAggregation reduce(List aggregations, ReduceContext reduceContext) { try (CategorizationBytesRefHash hash = new CategorizationBytesRefHash(new BytesRefHash(1L, reduceContext.bigArrays()))) { - CategorizationTokenTree categorizationTokenTree = new CategorizationTokenTree(maxChildren, maxDepth, similarityThreshold); + CategorizationTokenTree categorizationTokenTree = new CategorizationTokenTree( + maxUniqueTokens, + maxMatchTokens, + similarityThreshold + ); // TODO: Could we do a merge sort similar to terms? // It would require us returning partial reductions sorted by key, not by doc_count // First, make sure we have all the counts for equal log groups @@ -409,8 +413,8 @@ public InternalAggregation reduce(List aggregations, Reduce name, requiredSize, minDocCount, - maxChildren, - maxDepth, + maxUniqueTokens, + maxMatchTokens, similarityThreshold, metadata, Arrays.asList(bucketList) diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/UnmappedCategorizationAggregation.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/UnmappedCategorizationAggregation.java index 113399a833959..22975b97f3b9a 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/UnmappedCategorizationAggregation.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/UnmappedCategorizationAggregation.java @@ -33,8 +33,8 @@ public InternalCategorizationAggregation create(List buckets) { name, requiredSize, minDocCount, - maxChildren, - maxDepth, + maxUniqueTokens, + maxMatchTokens, similarityThreshold, super.metadata ); @@ -51,8 +51,8 @@ public InternalAggregation reduce(List aggregations, Reduce name, requiredSize, minDocCount, - maxChildren, - maxDepth, + maxUniqueTokens, + maxMatchTokens, similarityThreshold, super.metadata ); diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/aggs/categorization/CategorizeTextAggregationBuilderTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/aggs/categorization/CategorizeTextAggregationBuilderTests.java index 809ae245d13f1..7b907ea3ecd29 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/aggs/categorization/CategorizeTextAggregationBuilderTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/aggs/categorization/CategorizeTextAggregationBuilderTests.java @@ -35,10 +35,10 @@ protected CategorizeTextAggregationBuilder createTestAggregatorBuilder() { builder.setCategorizationAnalyzerConfig(CategorizationAnalyzerConfigTests.createRandomized().build()); } if (randomBoolean()) { - builder.setMaxChildren(randomIntBetween(1, 500)); + builder.setMaxUniqueTokens(randomIntBetween(1, 500)); } if (randomBoolean()) { - builder.setMaxDepth(randomIntBetween(1, 10)); + builder.setMaxMatchedTokens(randomIntBetween(1, 10)); } if (randomBoolean()) { builder.setSimilarityThreshold(randomIntBetween(1, 100)); diff --git a/x-pack/plugin/src/yamlRestTest/resources/rest-api-spec/test/ml/categorization_agg.yml b/x-pack/plugin/src/yamlRestTest/resources/rest-api-spec/test/ml/categorization_agg.yml index e6190185fe920..c2d5e0dbf09f1 100644 --- a/x-pack/plugin/src/yamlRestTest/resources/rest-api-spec/test/ml/categorization_agg.yml +++ b/x-pack/plugin/src/yamlRestTest/resources/rest-api-spec/test/ml/categorization_agg.yml @@ -69,8 +69,8 @@ setup: "categorize_text": { "field": "text", "size": 10, - "max_children": 2, - "max_depth": 1, + "max_unique_tokens": 2, + "max_matched_tokens": 1, "similarity_threshold": 11 } } @@ -86,7 +86,7 @@ setup: "Test categorization aggregation with poor settings": - do: - catch: /\[max_children\] must be greater than 0 and less than \[100\]/ + catch: /\[max_unique_tokens\] must be greater than 0 and less than \[100\]/ search: index: to_categorize body: > @@ -96,13 +96,13 @@ setup: "categories": { "categorize_text": { "field": "text", - "max_children": -2 + "max_unique_tokens": -2 } } } } - do: - catch: /\[max_depth\] must be greater than 0 and less than \[100\]/ + catch: /\[max_matched_tokens\] must be greater than 0 and less than \[100\]/ search: index: to_categorize body: > @@ -112,7 +112,7 @@ setup: "categories": { "categorize_text": { "field": "text", - "max_depth": -2 + "max_matched_tokens": -2 } } } From 6f940e3171762d034fda3b36397560f7e91b05f7 Mon Sep 17 00:00:00 2001 From: Benjamin Trent <4357155+benwtrent@users.noreply.github.com> Date: Mon, 27 Sep 2021 07:23:50 -0400 Subject: [PATCH 16/20] fixing accidental code change --- docs/reference/query-dsl/has-child-query.asciidoc | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/reference/query-dsl/has-child-query.asciidoc b/docs/reference/query-dsl/has-child-query.asciidoc index 9021f04c762b2..7b8158aaab5bf 100644 --- a/docs/reference/query-dsl/has-child-query.asciidoc +++ b/docs/reference/query-dsl/has-child-query.asciidoc @@ -16,7 +16,7 @@ unique parent documents increases. Each `has_child` query in a search can increase query time significantly. If you care about query performance, do not use this query. If you need to use -the `has_child` query, use it as rarely as possible. +the `has_child` query, use it as rarely as possible. ==== [[has-child-query-ex-request]] @@ -59,7 +59,7 @@ GET /_search "query": { "match_all": {} }, - "max_unique_tokens": 10, + "max_children": 10, "min_children": 2, "score_mode": "min" } From 191e9ad442c48e51381d0649eca1a6ffc69a6090 Mon Sep 17 00:00:00 2001 From: Benjamin Trent <4357155+benwtrent@users.noreply.github.com> Date: Mon, 27 Sep 2021 16:27:21 -0400 Subject: [PATCH 17/20] Addressing PR comments --- .../categorize-text-aggregation.asciidoc | 10 +- .../CategorizationBytesRefHash.java | 14 +- .../CategorizationTokenTree.java | 48 ++--- .../CategorizeTextAggregator.java | 70 ++++--- .../InternalCategorizationAggregation.java | 93 +++++---- .../categorization/TextCategorization.java | 8 +- .../ml/aggs/categorization/TreeNode.java | 16 +- .../aggs/categorization/TreeNodeFactory.java | 15 -- .../UnmappedCategorizationAggregation.java | 20 +- .../CategorizationAnalyzer.java | 3 +- .../CategorizeTextAggregatorTests.java | 181 +++++++++++------- .../categorization/InnerTreeNodeTests.java | 6 +- .../categorization/LeafTreeNodeTests.java | 6 +- 13 files changed, 261 insertions(+), 229 deletions(-) delete mode 100644 x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/TreeNodeFactory.java diff --git a/docs/reference/aggregations/bucket/categorize-text-aggregation.asciidoc b/docs/reference/aggregations/bucket/categorize-text-aggregation.asciidoc index 636b6c32d1f9d..091d2111b8070 100644 --- a/docs/reference/aggregations/bucket/categorize-text-aggregation.asciidoc +++ b/docs/reference/aggregations/bucket/categorize-text-aggregation.asciidoc @@ -218,7 +218,12 @@ category results } -------------------------------------------------- -Here is an example using `categorization_filters` +Here is an example using `categorization_filters`. +The default analyzer is a whitespace analyzer with a custom token filter +which filters out tokens that start with any number. +But, it may be that a token is a known highly-variable token (formatted usernames, emails, etc.). In that case, it is good to supply +custom `categorization_filters` to filter out those tokens for better categories. These filters will also reduce memory usage as fewer +tokens are held in memory for the categories. [source,console] -------------------------------------------------- @@ -266,7 +271,8 @@ and merging the log groups. } -------------------------------------------------- -This aggregation can have both sub-aggregations and itself be a sub-aggregation. +This aggregation can have both sub-aggregations and itself be a sub-aggregation. This allows gathering the top daily categories and the +top sample doc as below. [source,console] -------------------------------------------------- diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/CategorizationBytesRefHash.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/CategorizationBytesRefHash.java index 1e4143993d8f7..6246683ddfe6e 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/CategorizationBytesRefHash.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/CategorizationBytesRefHash.java @@ -10,14 +10,18 @@ import org.apache.lucene.util.BytesRef; import org.elasticsearch.common.logging.LoggerMessageFormat; import org.elasticsearch.common.util.BytesRefHash; +import org.elasticsearch.core.Releasable; import org.elasticsearch.search.aggregations.AggregationExecutionException; -import java.io.Closeable; -import java.io.IOException; - -class CategorizationBytesRefHash implements Closeable { +class CategorizationBytesRefHash implements Releasable { + /** + * Our special wild card value. + */ static final BytesRef WILD_CARD_REF = new BytesRef("*"); + /** + * For all WILD_CARD references, the token ID is always -1 + */ static final int WILD_CARD_ID = -1; private final BytesRefHash bytesRefHash; @@ -73,7 +77,7 @@ int put(BytesRef bytesRef) { } @Override - public void close() throws IOException { + public void close() { bytesRefHash.close(); } } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/CategorizationTokenTree.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/CategorizationTokenTree.java index 560bcddee8bf1..a4d552d35f4e3 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/CategorizationTokenTree.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/CategorizationTokenTree.java @@ -15,14 +15,13 @@ import java.util.HashMap; import java.util.List; import java.util.Map; -import java.util.concurrent.atomic.AtomicLong; +import java.util.Optional; import java.util.stream.Collectors; -import static org.apache.lucene.util.RamUsageEstimator.NUM_BYTES_OBJECT_REF; /** * Categorized semi-structured text utilizing the drain algorithm: https://arxiv.org/pdf/1806.04356.pdf - * With the following key differntiators + * With the following key differences * - This structure keeps track of the "smallest" sub-tree. So, instead of naively adding a new "*" node, the smallest sub-tree * is transformed if the incoming token has a higher doc_count. * - Additionally, similarities are weighted, which allows for nicer merging of existing log categories @@ -51,13 +50,13 @@ * If the similarityThreshold was less than 0.6, the result would be a single category [Node is *] * */ -public class CategorizationTokenTree implements Accountable, TreeNodeFactory { +public class CategorizationTokenTree implements Accountable { + private static final long SHALLOW_SIZE = RamUsageEstimator.shallowSizeOfInstance(CategorizationTokenTree.class); private final int maxMatchTokens; private final int maxUniqueTokens; private final int similarityThreshold; - private final AtomicLong idGen = new AtomicLong(); - // TODO statically allocate an array like DuplicateByteSequenceSpotter ??? + private long idGenerator; private final Map root = new HashMap<>(); private long sizeInBytes; @@ -66,12 +65,7 @@ public CategorizationTokenTree(int maxUniqueTokens, int maxMatchTokens, int simi this.maxUniqueTokens = maxUniqueTokens; this.maxMatchTokens = maxMatchTokens; this.similarityThreshold = similarityThreshold; - this.sizeInBytes = Integer.BYTES // maxDepth - + Integer.BYTES // maxChildren - + Double.BYTES // similarityThreshold - + NUM_BYTES_OBJECT_REF + Long.BYTES // idGen - + NUM_BYTES_OBJECT_REF // tree map - + Long.BYTES; // sizeInBytes + this.sizeInBytes = SHALLOW_SIZE; } public List toIntermediateBuckets(CategorizationBytesRefHash hash) { @@ -95,18 +89,26 @@ void mergeSmallestChildren() { root.values().forEach(TreeNode::collapseTinyChildren); } - public TextCategorization parseLogLine(final int[] logTokenIds) { - return parseLogLine(logTokenIds, 1); - } - - public TextCategorization parseLogLineConst(final int[] logTokenIds) { + /** + * This method does not mutate the underlying structure. Meaning, if a matching categories isn't found, it may return empty. + * + * @param logTokenIds The tokens to categorize + * @return The log category or `Optional.empty()` if one doesn't exist + */ + public Optional parseLogLineConst(final int[] logTokenIds) { TreeNode currentNode = this.root.get(logTokenIds.length); if (currentNode == null) { // we are missing an entire sub tree. New log length found - return null; + return Optional.empty(); } - return currentNode.getLogGroup(logTokenIds); + return Optional.ofNullable(currentNode.getLogGroup(logTokenIds)); } + /** + * This categorizes the passed tokens, potentially mutating the structure by expanding an existing category or adding a new one. + * @param logTokenIds The log tokens to categorize + * @param docCount The count of docs for the given tokens + * @return An existing categorization or a newly created one + */ public TextCategorization parseLogLine(final int[] logTokenIds, long docCount) { TreeNode currentNode = this.root.get(logTokenIds.length); if (currentNode == null) { // we are missing an entire sub tree. New log length found @@ -118,8 +120,7 @@ public TextCategorization parseLogLine(final int[] logTokenIds, long docCount) { return currentNode.addLog(logTokenIds, docCount, this); } - @Override - public TreeNode newNode(long docCount, int tokenPos, int[] logTokenIds) { + TreeNode newNode(long docCount, int tokenPos, int[] logTokenIds) { TreeNode node = tokenPos < maxMatchTokens - 1 && tokenPos < logTokenIds.length ? new TreeNode.InnerTreeNode(docCount, tokenPos, maxUniqueTokens) : new TreeNode.LeafTreeNode(docCount, similarityThreshold); @@ -128,9 +129,8 @@ public TreeNode newNode(long docCount, int tokenPos, int[] logTokenIds) { return node; } - @Override - public TextCategorization newGroup(long docCount, int[] logTokenIds) { - TextCategorization group = new TextCategorization(logTokenIds, docCount, idGen.incrementAndGet()); + TextCategorization newGroup(long docCount, int[] logTokenIds) { + TextCategorization group = new TextCategorization(logTokenIds, docCount, idGenerator++); // Get the regular size bytes from the LogGroup and how much it costs to reference it sizeInBytes += group.ramBytesUsed() + RamUsageEstimator.NUM_BYTES_OBJECT_REF; return group; diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/CategorizeTextAggregator.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/CategorizeTextAggregator.java index c2d46a511847a..acd42eb1c6dda 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/CategorizeTextAggregator.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/CategorizeTextAggregator.java @@ -12,10 +12,10 @@ import org.apache.lucene.analysis.tokenattributes.CharTermAttribute; import org.apache.lucene.index.LeafReaderContext; import org.apache.lucene.util.BytesRef; -import org.apache.lucene.util.BytesRefBuilder; import org.apache.lucene.util.PriorityQueue; import org.elasticsearch.common.util.BytesRefHash; import org.elasticsearch.common.util.ObjectArray; +import org.elasticsearch.core.Releasables; import org.elasticsearch.index.mapper.MappedFieldType; import org.elasticsearch.search.aggregations.Aggregator; import org.elasticsearch.search.aggregations.AggregatorFactories; @@ -32,7 +32,6 @@ import org.elasticsearch.xpack.ml.job.categorization.CategorizationAnalyzer; import java.io.IOException; -import java.io.UncheckedIOException; import java.util.ArrayList; import java.util.Arrays; import java.util.Collections; @@ -75,27 +74,25 @@ protected CategorizeTextAggregator( this.fieldType = fieldType; CategorizationAnalyzerConfig analyzerConfig = Optional.ofNullable(categorizationAnalyzerConfig) .orElse(CategorizationAnalyzerConfig.buildStandardCategorizationAnalyzer(Collections.emptyList())); - String analyzer = analyzerConfig.getAnalyzer(); - final boolean shouldClose; - final Analyzer innerAnalyzer; - if (analyzer != null) { - Analyzer globalAnalyzer = context.getNamedAnalyzer(analyzer); + final String analyzerName = analyzerConfig.getAnalyzer(); + if (analyzerName != null) { + Analyzer globalAnalyzer = context.getNamedAnalyzer(analyzerName); if (globalAnalyzer == null) { - throw new IllegalArgumentException("Failed to find global analyzer [" + analyzer + "]"); + throw new IllegalArgumentException("Failed to find global analyzer [" + analyzerName + "]"); } - innerAnalyzer = globalAnalyzer; - shouldClose = false; + this.analyzer = new CategorizationAnalyzer(globalAnalyzer, false); } else { - innerAnalyzer = context.buildCustomAnalyzer( - context.getIndexSettings(), - false, - analyzerConfig.getTokenizer(), - analyzerConfig.getCharFilters(), - analyzerConfig.getTokenFilters() + this.analyzer = new CategorizationAnalyzer( + context.buildCustomAnalyzer( + context.getIndexSettings(), + false, + analyzerConfig.getTokenizer(), + analyzerConfig.getCharFilters(), + analyzerConfig.getTokenFilters() + ), + true ); - shouldClose = true; } - this.analyzer = new CategorizationAnalyzer(innerAnalyzer, shouldClose); this.categorizers = bigArrays().newObjectArray(1); this.maxUniqueTokens = maxUniqueTokens; this.maxMatchTokens = maxMatchTokens; @@ -108,13 +105,7 @@ protected CategorizeTextAggregator( @Override protected void doClose() { super.doClose(); - this.analyzer.close(); - try { - this.bytesRefHash.close(); - } catch (IOException ex) { - //TODO Should we just eat the exception? - throw new UncheckedIOException(ex); - } + Releasables.close(this.analyzer, this.bytesRefHash); } @Override @@ -176,14 +167,19 @@ public InternalAggregation buildEmptyAggregation() { @Override protected LeafBucketCollector getLeafCollector(LeafReaderContext ctx, LeafBucketCollector sub) throws IOException { return new LeafBucketCollectorBase(sub, null) { - private final BytesRefBuilder scratch = new BytesRefBuilder(); - @Override public void collect(int doc, long owningBucketOrd) throws IOException { - collectFromSource(doc, owningBucketOrd); + categorizers = bigArrays().grow(categorizers, owningBucketOrd + 1); + CategorizationTokenTree categorizer = categorizers.get(owningBucketOrd); + if (categorizer == null) { + categorizer = new CategorizationTokenTree(maxUniqueTokens, maxMatchTokens, similarityThreshold); + addRequestCircuitBreakerBytes(categorizer.ramBytesUsed()); + categorizers.set(owningBucketOrd, categorizer); + } + collectFromSource(doc, owningBucketOrd, categorizer); } - private void collectFromSource(int doc, long owningBucketOrd) throws IOException { + private void collectFromSource(int doc, long owningBucketOrd, CategorizationTokenTree categorizer) throws IOException { sourceLookup.setSegmentAndDocument(ctx, doc); Iterator itr = sourceLookup.extractRawValues(sourceFieldName).stream().map(obj -> { if (obj == null) { @@ -196,11 +192,16 @@ private void collectFromSource(int doc, long owningBucketOrd) throws IOException }).iterator(); while (itr.hasNext()) { TokenStream ts = analyzer.tokenStream(fieldType.name(), itr.next()); - processTokenStream(owningBucketOrd, ts, doc); + processTokenStream(owningBucketOrd, ts, doc, categorizer); } } - private void processTokenStream(long owningBucketOrd, TokenStream ts, int doc) throws IOException { + private void processTokenStream( + long owningBucketOrd, + TokenStream ts, + int doc, + CategorizationTokenTree categorizer + ) throws IOException { ArrayList tokens = new ArrayList<>(); try { CharTermAttribute termAtt = ts.addAttribute(CharTermAttribute.class); @@ -214,13 +215,6 @@ private void processTokenStream(long owningBucketOrd, TokenStream ts, int doc) t } finally { ts.close(); } - categorizers = bigArrays().grow(categorizers, owningBucketOrd + 1); - CategorizationTokenTree categorizer = categorizers.get(owningBucketOrd); - if (categorizer == null) { - categorizer = new CategorizationTokenTree(maxUniqueTokens, maxMatchTokens, similarityThreshold); - addRequestCircuitBreakerBytes(categorizer.ramBytesUsed()); - categorizers.set(owningBucketOrd, categorizer); - } long previousSize = categorizer.ramBytesUsed(); TextCategorization lg = categorizer.parseLogLine( tokens.stream().mapToInt(Integer::valueOf).toArray(), diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/InternalCategorizationAggregation.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/InternalCategorizationAggregation.java index 16cad0731d500..b8530f4988668 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/InternalCategorizationAggregation.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/InternalCategorizationAggregation.java @@ -23,7 +23,6 @@ import org.elasticsearch.search.aggregations.bucket.MultiBucketsAggregation; import java.io.IOException; -import java.io.UncheckedIOException; import java.util.ArrayList; import java.util.Arrays; import java.util.Comparator; @@ -192,6 +191,13 @@ public Bucket(StreamInput in) throws IOException { aggregations = InternalAggregations.readFrom(in); } + @Override + public void writeTo(StreamOutput out) throws IOException { + key.writeTo(out); + out.writeVLong(getDocCount()); + aggregations.writeTo(out); + } + @Override public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { builder.startObject(); @@ -227,13 +233,6 @@ public Aggregations getAggregations() { return aggregations; } - @Override - public void writeTo(StreamOutput out) throws IOException { - key.writeTo(out); - out.writeVLong(getDocCount()); - aggregations.writeTo(out); - } - @Override public String toString() { return "Bucket{" + "key=" + getKeyAsString() + ", docCount=" + docCount + ", aggregations=" + aggregations.asMap() + "}\n"; @@ -247,11 +246,11 @@ public int compareTo(Bucket o) { } private final List buckets; - protected final int maxUniqueTokens; - protected final int similarityThreshold; - protected final int maxMatchTokens; - protected final int requiredSize; - protected final long minDocCount; + private final int maxUniqueTokens; + private final int similarityThreshold; + private final int maxMatchTokens; + private final int requiredSize; + private final long minDocCount; protected InternalCategorizationAggregation( String name, @@ -294,6 +293,26 @@ public InternalCategorizationAggregation(StreamInput in) throws IOException { this.minDocCount = in.readVLong(); } + @Override + protected void doWriteTo(StreamOutput out) throws IOException { + out.writeVInt(maxUniqueTokens); + out.writeVInt(maxMatchTokens); + out.writeVInt(similarityThreshold); + out.writeList(buckets); + writeSize(requiredSize, out); + out.writeVLong(minDocCount); + } + + @Override + public XContentBuilder doXContentBody(XContentBuilder builder, Params params) throws IOException { + builder.startArray(CommonFields.BUCKETS.getPreferredName()); + for (Bucket bucket : buckets) { + bucket.toXContent(builder, params); + } + builder.endArray(); + return builder; + } + @Override public InternalCategorizationAggregation create(List buckets) { return new InternalCategorizationAggregation( @@ -328,16 +347,6 @@ public String getWriteableName() { return CategorizeTextAggregationBuilder.NAME; } - @Override - protected void doWriteTo(StreamOutput out) throws IOException { - out.writeVInt(maxUniqueTokens); - out.writeVInt(maxMatchTokens); - out.writeVInt(similarityThreshold); - out.writeList(buckets); - writeSize(requiredSize, out); - out.writeVLong(minDocCount); - } - @Override public InternalAggregation reduce(List aggregations, ReduceContext reduceContext) { try (CategorizationBytesRefHash hash = new CategorizationBytesRefHash(new BytesRefHash(1L, reduceContext.bigArrays()))) { @@ -367,12 +376,12 @@ public InternalAggregation reduce(List aggregations, Reduce categorizationTokenTree.mergeSmallestChildren(); Map mergedBuckets = new HashMap<>(); for (DelayedCategorizationBucket delayedBucket : reduced.values()) { - TextCategorization group = categorizationTokenTree.parseLogLineConst(hash.getIds(delayedBucket.key.keyAsTokens())); - if (group == null) { - throw new AggregationExecutionException( - "Unexpected null categorization group for bucket [" + delayedBucket.key.asString() + "]" + TextCategorization group = categorizationTokenTree.parseLogLineConst(hash.getIds(delayedBucket.key.keyAsTokens())) + .orElseThrow( + () -> new AggregationExecutionException( + "Unexpected null categorization group for bucket [" + delayedBucket.key.asString() + "]" + ) ); - } BytesRef[] categoryTokens = hash.getDeeps(group.getCategorization()); BucketKey key = reduceContext.isFinalReduce() ? @@ -419,18 +428,26 @@ public InternalAggregation reduce(List aggregations, Reduce metadata, Arrays.asList(bucketList) ); - } catch (IOException ex) { - throw new UncheckedIOException(ex); } } - @Override - public XContentBuilder doXContentBody(XContentBuilder builder, Params params) throws IOException { - builder.startArray(CommonFields.BUCKETS.getPreferredName()); - for (Bucket bucket : buckets) { - bucket.toXContent(builder, params); - } - builder.endArray(); - return builder; + public int getMaxUniqueTokens() { + return maxUniqueTokens; + } + + public int getSimilarityThreshold() { + return similarityThreshold; + } + + public int getMaxMatchTokens() { + return maxMatchTokens; + } + + public int getRequiredSize() { + return requiredSize; + } + + public long getMinDocCount() { + return minDocCount; } } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/TextCategorization.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/TextCategorization.java index 84c37c2486ef3..e4370e2f49776 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/TextCategorization.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/TextCategorization.java @@ -21,6 +21,7 @@ */ class TextCategorization implements Accountable { + private static final long SHALLOW_SIZE = RamUsageEstimator.shallowSizeOfInstance(TextCategorization.class); private final long id; private final int[] categorization; private final long[] tokenCounts; @@ -81,12 +82,9 @@ void addLog(int[] logEvent, long docCount) { @Override public long ramBytesUsed() { - return Long.BYTES // id - + RamUsageEstimator.NUM_BYTES_OBJECT_REF // categorization reference + return SHALLOW_SIZE + RamUsageEstimator.sizeOf(categorization) // categorization token Ids - + RamUsageEstimator.NUM_BYTES_OBJECT_REF // tokenCounts reference - + RamUsageEstimator.sizeOf(tokenCounts) // tokenCounts - + Long.BYTES; // count + + RamUsageEstimator.sizeOf(tokenCounts); // tokenCounts } static class Similarity implements Comparable { diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/TreeNode.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/TreeNode.java index badde8f152a32..907946e7d1790 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/TreeNode.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/TreeNode.java @@ -7,8 +7,6 @@ package org.elasticsearch.xpack.ml.aggs.categorization; -import org.apache.logging.log4j.LogManager; -import org.apache.logging.log4j.Logger; import org.apache.lucene.util.Accountable; import org.elasticsearch.core.Tuple; import org.elasticsearch.search.aggregations.AggregationExecutionException; @@ -37,8 +35,6 @@ */ abstract class TreeNode implements Accountable { - private static final Logger LOGGER = LogManager.getLogger(TreeNode.class); - private long count; TreeNode(long count) { @@ -58,7 +54,7 @@ final long getCount() { } // TODO add option for calculating the cost of adding the new group - abstract TextCategorization addLog(int[] logTokenIds, long docCount, TreeNodeFactory treeNodeFactory); + abstract TextCategorization addLog(int[] logTokenIds, long docCount, CategorizationTokenTree treeNodeFactory); abstract TextCategorization getLogGroup(int[] logTokens); @@ -79,6 +75,7 @@ static class LeafTreeNode extends TreeNode { } } + @Override public boolean isLeaf() { return true; } @@ -111,7 +108,7 @@ public long ramBytesUsed() { } @Override - public TextCategorization addLog(int[] logTokenIds, long docCount, TreeNodeFactory treeNodeFactory) { + public TextCategorization addLog(int[] logTokenIds, long docCount, CategorizationTokenTree treeNodeFactory) { return getAndUpdateLogGroup(logTokenIds, docCount).orElseGet(() -> { // Need to update the tree if possible return putNewLogGroup(treeNodeFactory.newGroup(docCount, logTokenIds)); @@ -198,6 +195,7 @@ static class InnerTreeNode extends TreeNode { this.smallestChild = new PriorityQueue<>(maxChildren, Comparator.comparing(NativeLongPair::count)); } + @Override boolean isLeaf() { return false; } @@ -222,7 +220,7 @@ public long ramBytesUsed() { } @Override - public TextCategorization addLog(final int[] logTokenIds, final long docCount, final TreeNodeFactory treeNodeFactory) { + public TextCategorization addLog(final int[] logTokenIds, final long docCount, final CategorizationTokenTree treeNodeFactory) { final long currentToken = logTokenIds[childrenTokenPos]; TreeNode child = getChild(currentToken).map(node -> { node.incCount(docCount); @@ -392,10 +390,6 @@ static NativeLongPair of(long tokenId, long count) { this.count = count; } - public long tokenId() { - return tokenId; - } - public long count() { return count; } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/TreeNodeFactory.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/TreeNodeFactory.java deleted file mode 100644 index adbc7997cb5c3..0000000000000 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/TreeNodeFactory.java +++ /dev/null @@ -1,15 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0; you may not use this file except in compliance with the Elastic License - * 2.0. - */ - -package org.elasticsearch.xpack.ml.aggs.categorization; - - -interface TreeNodeFactory { - TreeNode newNode(long docCount, int tokenPos, int[] logTokenIds); - - TextCategorization newGroup(long docCount, int[] logTokenIds); -} diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/UnmappedCategorizationAggregation.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/UnmappedCategorizationAggregation.java index 22975b97f3b9a..ae1081f66d09f 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/UnmappedCategorizationAggregation.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/UnmappedCategorizationAggregation.java @@ -31,11 +31,11 @@ protected UnmappedCategorizationAggregation( public InternalCategorizationAggregation create(List buckets) { return new UnmappedCategorizationAggregation( name, - requiredSize, - minDocCount, - maxUniqueTokens, - maxMatchTokens, - similarityThreshold, + getRequiredSize(), + getMinDocCount(), + getMaxUniqueTokens(), + getMaxMatchTokens(), + getSimilarityThreshold(), super.metadata ); } @@ -49,11 +49,11 @@ public Bucket createBucket(InternalAggregations aggregations, Bucket prototype) public InternalAggregation reduce(List aggregations, ReduceContext reduceContext) { return new UnmappedCategorizationAggregation( name, - requiredSize, - minDocCount, - maxUniqueTokens, - maxMatchTokens, - similarityThreshold, + getRequiredSize(), + getMinDocCount(), + getMaxUniqueTokens(), + getMaxMatchTokens(), + getSimilarityThreshold(), super.metadata ); } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/job/categorization/CategorizationAnalyzer.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/job/categorization/CategorizationAnalyzer.java index d7e403aa59034..a3753c4105eab 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/job/categorization/CategorizationAnalyzer.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/job/categorization/CategorizationAnalyzer.java @@ -10,6 +10,7 @@ import org.apache.lucene.analysis.TokenStream; import org.apache.lucene.analysis.tokenattributes.CharTermAttribute; import org.elasticsearch.ElasticsearchException; +import org.elasticsearch.core.Releasable; import org.elasticsearch.core.Tuple; import org.elasticsearch.index.analysis.AnalysisRegistry; import org.elasticsearch.xpack.core.ml.job.config.CategorizationAnalyzerConfig; @@ -25,7 +26,7 @@ * Converts messages to lists of tokens that will be fed to the ML categorization algorithm. * */ -public class CategorizationAnalyzer implements Closeable { +public class CategorizationAnalyzer implements Closeable, Releasable { private final Analyzer analyzer; private final boolean closeAnalyzer; diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/aggs/categorization/CategorizeTextAggregatorTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/aggs/categorization/CategorizeTextAggregatorTests.java index 8560da8f76190..95cfdcb0f8f8f 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/aggs/categorization/CategorizeTextAggregatorTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/aggs/categorization/CategorizeTextAggregatorTests.java @@ -9,11 +9,8 @@ import org.apache.lucene.document.SortedNumericDocValuesField; import org.apache.lucene.document.StoredField; -import org.apache.lucene.index.IndexReader; import org.apache.lucene.index.RandomIndexWriter; -import org.apache.lucene.search.IndexSearcher; import org.apache.lucene.search.MatchAllDocsQuery; -import org.apache.lucene.store.Directory; import org.apache.lucene.util.BytesRef; import org.elasticsearch.common.settings.Settings; import org.elasticsearch.env.Environment; @@ -24,6 +21,7 @@ import org.elasticsearch.search.aggregations.AggregatorTestCase; import org.elasticsearch.search.aggregations.bucket.histogram.Histogram; import org.elasticsearch.search.aggregations.bucket.histogram.HistogramAggregationBuilder; +import org.elasticsearch.search.aggregations.bucket.histogram.InternalHistogram; import org.elasticsearch.search.aggregations.metrics.Avg; import org.elasticsearch.search.aggregations.metrics.AvgAggregationBuilder; import org.elasticsearch.search.aggregations.metrics.Max; @@ -60,17 +58,11 @@ protected List getSearchPlugins() { private static final String NUMERIC_FIELD_NAME = "value"; public void testCategorizationWithoutSubAggs() throws Exception { - try (Directory dir = newDirectory(); RandomIndexWriter w = new RandomIndexWriter(random(), dir)) { - writeTestDocs(w); - CategorizeTextAggregationBuilder aggBuilder = new CategorizeTextAggregationBuilder("my_agg", TEXT_FIELD_NAME); - try (IndexReader reader = w.getReader()) { - IndexSearcher searcher = new IndexSearcher(reader); - InternalCategorizationAggregation result = searchAndReduce( - searcher, - new MatchAllDocsQuery(), - aggBuilder, - new TextFieldMapper.TextFieldType(TEXT_FIELD_NAME) - ); + testCase( + new CategorizeTextAggregationBuilder("my_agg", TEXT_FIELD_NAME), + new MatchAllDocsQuery(), + CategorizeTextAggregatorTests::writeTestDocs, + (InternalCategorizationAggregation result) -> { assertThat(result.getBuckets(), hasSize(2)); assertThat(result.getBuckets().get(0).docCount, equalTo(6L)); assertThat(result.getBuckets().get(0).getKeyAsString(), equalTo("Node started")); @@ -79,27 +71,23 @@ public void testCategorizationWithoutSubAggs() throws Exception { result.getBuckets().get(1).getKeyAsString(), equalTo("Failed to shutdown error org.aaaa.bbbb.Cccc line caused by foo exception") ); - } - } + }, + new TextFieldMapper.TextFieldType(TEXT_FIELD_NAME), + longField(NUMERIC_FIELD_NAME) + ); } public void testCategorizationWithSubAggs() throws Exception { - try (Directory dir = newDirectory(); RandomIndexWriter w = new RandomIndexWriter(random(), dir)) { - writeTestDocs(w); - CategorizeTextAggregationBuilder aggBuilder = new CategorizeTextAggregationBuilder("my_agg", TEXT_FIELD_NAME).subAggregation( + CategorizeTextAggregationBuilder aggBuilder = new CategorizeTextAggregationBuilder("my_agg", TEXT_FIELD_NAME).subAggregation( new MaxAggregationBuilder("max").field(NUMERIC_FIELD_NAME) ) - .subAggregation(new AvgAggregationBuilder("avg").field(NUMERIC_FIELD_NAME)) - .subAggregation(new MinAggregationBuilder("min").field(NUMERIC_FIELD_NAME)); - try (IndexReader reader = w.getReader()) { - IndexSearcher searcher = new IndexSearcher(reader); - InternalCategorizationAggregation result = searchAndReduce( - searcher, - new MatchAllDocsQuery(), - aggBuilder, - new TextFieldMapper.TextFieldType(TEXT_FIELD_NAME), - longField(NUMERIC_FIELD_NAME) - ); + .subAggregation(new AvgAggregationBuilder("avg").field(NUMERIC_FIELD_NAME)) + .subAggregation(new MinAggregationBuilder("min").field(NUMERIC_FIELD_NAME)); + testCase( + aggBuilder, + new MatchAllDocsQuery(), + CategorizeTextAggregatorTests::writeTestDocs, + (InternalCategorizationAggregation result) -> { assertThat(result.getBuckets(), hasSize(2)); assertThat(result.getBuckets().get(0).docCount, equalTo(6L)); assertThat(result.getBuckets().get(0).getKeyAsString(), equalTo("Node started")); @@ -115,29 +103,25 @@ public void testCategorizationWithSubAggs() throws Exception { assertThat(((Max) result.getBuckets().get(1).aggregations.get("max")).getValue(), equalTo(4.0)); assertThat(((Min) result.getBuckets().get(1).aggregations.get("min")).getValue(), equalTo(0.0)); assertThat(((Avg) result.getBuckets().get(1).aggregations.get("avg")).getValue(), equalTo(2.0)); - } - } + }, + new TextFieldMapper.TextFieldType(TEXT_FIELD_NAME), + longField(NUMERIC_FIELD_NAME) + ); } public void testCategorizationWithMultiBucketSubAggs() throws Exception { - try (Directory dir = newDirectory(); RandomIndexWriter w = new RandomIndexWriter(random(), dir)) { - writeTestDocs(w); - CategorizeTextAggregationBuilder aggBuilder = new CategorizeTextAggregationBuilder("my_agg", TEXT_FIELD_NAME).subAggregation( - new HistogramAggregationBuilder("histo").field(NUMERIC_FIELD_NAME) - .interval(2) - .subAggregation(new MaxAggregationBuilder("max").field(NUMERIC_FIELD_NAME)) - .subAggregation(new AvgAggregationBuilder("avg").field(NUMERIC_FIELD_NAME)) - .subAggregation(new MinAggregationBuilder("min").field(NUMERIC_FIELD_NAME)) - ); - try (IndexReader reader = w.getReader()) { - IndexSearcher searcher = new IndexSearcher(reader); - InternalCategorizationAggregation result = searchAndReduce( - searcher, - new MatchAllDocsQuery(), - aggBuilder, - new TextFieldMapper.TextFieldType(TEXT_FIELD_NAME), - longField(NUMERIC_FIELD_NAME) - ); + CategorizeTextAggregationBuilder aggBuilder = new CategorizeTextAggregationBuilder("my_agg", TEXT_FIELD_NAME).subAggregation( + new HistogramAggregationBuilder("histo").field(NUMERIC_FIELD_NAME) + .interval(2) + .subAggregation(new MaxAggregationBuilder("max").field(NUMERIC_FIELD_NAME)) + .subAggregation(new AvgAggregationBuilder("avg").field(NUMERIC_FIELD_NAME)) + .subAggregation(new MinAggregationBuilder("min").field(NUMERIC_FIELD_NAME)) + ); + testCase( + aggBuilder, + new MatchAllDocsQuery(), + CategorizeTextAggregatorTests::writeTestDocs, + (InternalCategorizationAggregation result) -> { assertThat(result.getBuckets(), hasSize(2)); assertThat(result.getBuckets().get(0).docCount, equalTo(6L)); assertThat(result.getBuckets().get(0).getKeyAsString(), equalTo("Node started")); @@ -168,31 +152,27 @@ public void testCategorizationWithMultiBucketSubAggs() throws Exception { assertThat(histo.getBuckets().get(2).getDocCount(), equalTo(1L)); assertThat(((Avg) histo.getBuckets().get(0).getAggregations().get("avg")).getValue(), equalTo(0.0)); assertThat(((Avg) histo.getBuckets().get(2).getAggregations().get("avg")).getValue(), equalTo(4.0)); - } - } + }, + new TextFieldMapper.TextFieldType(TEXT_FIELD_NAME), + longField(NUMERIC_FIELD_NAME) + ); } public void testCategorizationAsSubAgg() throws Exception { - try (Directory dir = newDirectory(); RandomIndexWriter w = new RandomIndexWriter(random(), dir)) { - writeTestDocs(w); - HistogramAggregationBuilder aggBuilder = new HistogramAggregationBuilder("histo").field(NUMERIC_FIELD_NAME) - .interval(2) - .subAggregation( - new CategorizeTextAggregationBuilder("my_agg", TEXT_FIELD_NAME).subAggregation( + HistogramAggregationBuilder aggBuilder = new HistogramAggregationBuilder("histo").field(NUMERIC_FIELD_NAME) + .interval(2) + .subAggregation( + new CategorizeTextAggregationBuilder("my_agg", TEXT_FIELD_NAME).subAggregation( new MaxAggregationBuilder("max").field(NUMERIC_FIELD_NAME) ) - .subAggregation(new AvgAggregationBuilder("avg").field(NUMERIC_FIELD_NAME)) - .subAggregation(new MinAggregationBuilder("min").field(NUMERIC_FIELD_NAME)) - ); - try (IndexReader reader = w.getReader()) { - IndexSearcher searcher = new IndexSearcher(reader); - Histogram result = searchAndReduce( - searcher, - new MatchAllDocsQuery(), - aggBuilder, - new TextFieldMapper.TextFieldType(TEXT_FIELD_NAME), - longField(NUMERIC_FIELD_NAME) - ); + .subAggregation(new AvgAggregationBuilder("avg").field(NUMERIC_FIELD_NAME)) + .subAggregation(new MinAggregationBuilder("min").field(NUMERIC_FIELD_NAME)) + ); + testCase( + aggBuilder, + new MatchAllDocsQuery(), + CategorizeTextAggregatorTests::writeTestDocs, + (InternalHistogram result) -> { assertThat(result.getBuckets(), hasSize(3)); // First histo bucket @@ -242,8 +222,59 @@ public void testCategorizationAsSubAgg() throws Exception { assertThat(((Max) categorizationAggregation.getBuckets().get(1).aggregations.get("max")).getValue(), equalTo(4.0)); assertThat(((Min) categorizationAggregation.getBuckets().get(1).aggregations.get("min")).getValue(), equalTo(4.0)); assertThat(((Avg) categorizationAggregation.getBuckets().get(1).aggregations.get("avg")).getValue(), equalTo(4.0)); - } - } + }, + new TextFieldMapper.TextFieldType(TEXT_FIELD_NAME), + longField(NUMERIC_FIELD_NAME) + ); + } + + public void testCategorizationWithSubAggsManyDocs() throws Exception { + CategorizeTextAggregationBuilder aggBuilder = new CategorizeTextAggregationBuilder("my_agg", TEXT_FIELD_NAME).subAggregation( + new HistogramAggregationBuilder("histo").field(NUMERIC_FIELD_NAME) + .interval(2) + .subAggregation(new MaxAggregationBuilder("max").field(NUMERIC_FIELD_NAME)) + .subAggregation(new AvgAggregationBuilder("avg").field(NUMERIC_FIELD_NAME)) + .subAggregation(new MinAggregationBuilder("min").field(NUMERIC_FIELD_NAME)) + ); + testCase( + aggBuilder, + new MatchAllDocsQuery(), + CategorizeTextAggregatorTests::writeManyTestDocs, + (InternalCategorizationAggregation result) -> { + assertThat(result.getBuckets(), hasSize(2)); + assertThat(result.getBuckets().get(0).docCount, equalTo(30_000L)); + assertThat(result.getBuckets().get(0).getKeyAsString(), equalTo("Node started")); + Histogram histo = result.getBuckets().get(0).aggregations.get("histo"); + assertThat(histo.getBuckets(), hasSize(3)); + for (Histogram.Bucket bucket : histo.getBuckets()) { + assertThat(bucket.getDocCount(), equalTo(10_000L)); + } + assertThat(((Max) histo.getBuckets().get(0).getAggregations().get("max")).getValue(), equalTo(1.0)); + assertThat(((Min) histo.getBuckets().get(0).getAggregations().get("min")).getValue(), equalTo(0.0)); + assertThat(((Avg) histo.getBuckets().get(0).getAggregations().get("avg")).getValue(), equalTo(0.5)); + assertThat(((Max) histo.getBuckets().get(1).getAggregations().get("max")).getValue(), equalTo(3.0)); + assertThat(((Min) histo.getBuckets().get(1).getAggregations().get("min")).getValue(), equalTo(2.0)); + assertThat(((Avg) histo.getBuckets().get(1).getAggregations().get("avg")).getValue(), equalTo(2.5)); + assertThat(((Max) histo.getBuckets().get(2).getAggregations().get("max")).getValue(), equalTo(5.0)); + assertThat(((Min) histo.getBuckets().get(2).getAggregations().get("min")).getValue(), equalTo(4.0)); + assertThat(((Avg) histo.getBuckets().get(2).getAggregations().get("avg")).getValue(), equalTo(4.5)); + + assertThat(result.getBuckets().get(1).docCount, equalTo(10_000L)); + assertThat( + result.getBuckets().get(1).getKeyAsString(), + equalTo("Failed to shutdown error org.aaaa.bbbb.Cccc line caused by foo exception") + ); + histo = result.getBuckets().get(1).aggregations.get("histo"); + assertThat(histo.getBuckets(), hasSize(3)); + assertThat(histo.getBuckets().get(0).getDocCount(), equalTo(5_000L)); + assertThat(histo.getBuckets().get(1).getDocCount(), equalTo(0L)); + assertThat(histo.getBuckets().get(2).getDocCount(), equalTo(5_000L)); + assertThat(((Avg) histo.getBuckets().get(0).getAggregations().get("avg")).getValue(), equalTo(0.0)); + assertThat(((Avg) histo.getBuckets().get(2).getAggregations().get("avg")).getValue(), equalTo(4.0)); + }, + new TextFieldMapper.TextFieldType(TEXT_FIELD_NAME), + longField(NUMERIC_FIELD_NAME) + ); } private static void writeTestDocs(RandomIndexWriter w) throws IOException { @@ -302,4 +333,10 @@ private static void writeTestDocs(RandomIndexWriter w) throws IOException { ) ); } + + private static void writeManyTestDocs(RandomIndexWriter w) throws IOException { + for (int i = 0; i < 5_000; i++) { + writeTestDocs(w); + } + } } diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/aggs/categorization/InnerTreeNodeTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/aggs/categorization/InnerTreeNodeTests.java index ac52eab68c826..bf7c49b6dd511 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/aggs/categorization/InnerTreeNodeTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/aggs/categorization/InnerTreeNodeTests.java @@ -13,8 +13,6 @@ import org.junit.After; import org.junit.Before; -import java.io.IOException; - import static org.elasticsearch.xpack.ml.aggs.categorization.CategorizationBytesRefHash.WILD_CARD_ID; import static org.elasticsearch.xpack.ml.aggs.categorization.TextCategorizationTests.getTokens; import static org.elasticsearch.xpack.ml.aggs.categorization.TextCategorizationTests.mockBigArrays; @@ -23,7 +21,7 @@ public class InnerTreeNodeTests extends ESTestCase { - private final TreeNodeFactory factory = new CategorizationTokenTree(3, 4, 60); + private final CategorizationTokenTree factory = new CategorizationTokenTree(3, 4, 60); private CategorizationBytesRefHash bytesRefHash; @Before @@ -32,7 +30,7 @@ public void createRefHash() { } @After - public void closeRefHash() throws IOException { + public void closeRefHash() { bytesRefHash.close(); } diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/aggs/categorization/LeafTreeNodeTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/aggs/categorization/LeafTreeNodeTests.java index 1e8ac11217812..3e5730e67dfe2 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/aggs/categorization/LeafTreeNodeTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/aggs/categorization/LeafTreeNodeTests.java @@ -12,8 +12,6 @@ import org.junit.After; import org.junit.Before; -import java.io.IOException; - import static org.elasticsearch.xpack.ml.aggs.categorization.TextCategorizationTests.getTokens; import static org.elasticsearch.xpack.ml.aggs.categorization.TextCategorizationTests.mockBigArrays; import static org.hamcrest.Matchers.equalTo; @@ -22,7 +20,7 @@ public class LeafTreeNodeTests extends ESTestCase { - private final TreeNodeFactory factory = new CategorizationTokenTree(10, 10, 60); + private final CategorizationTokenTree factory = new CategorizationTokenTree(10, 10, 60); private CategorizationBytesRefHash bytesRefHash; @@ -32,7 +30,7 @@ public void createRefHash() { } @After - public void closeRefHash() throws IOException { + public void closeRefHash() { bytesRefHash.close(); } From 3210d545ca42e3589ee9f129aff853145468f20c Mon Sep 17 00:00:00 2001 From: Benjamin Trent <4357155+benwtrent@users.noreply.github.com> Date: Tue, 28 Sep 2021 10:25:55 -0400 Subject: [PATCH 18/20] fixing benchmark spotless --- .../search/aggregations/AggConstructionContentionBenchmark.java | 1 - 1 file changed, 1 deletion(-) diff --git a/benchmarks/src/main/java/org/elasticsearch/benchmark/search/aggregations/AggConstructionContentionBenchmark.java b/benchmarks/src/main/java/org/elasticsearch/benchmark/search/aggregations/AggConstructionContentionBenchmark.java index 61608a4d30ef8..43a2d1930d0e8 100644 --- a/benchmarks/src/main/java/org/elasticsearch/benchmark/search/aggregations/AggConstructionContentionBenchmark.java +++ b/benchmarks/src/main/java/org/elasticsearch/benchmark/search/aggregations/AggConstructionContentionBenchmark.java @@ -22,7 +22,6 @@ import org.elasticsearch.core.Releasables; import org.elasticsearch.index.Index; import org.elasticsearch.index.IndexSettings; -import org.elasticsearch.index.analysis.AnalysisRegistry; import org.elasticsearch.index.analysis.NameOrDefinition; import org.elasticsearch.index.analysis.NamedAnalyzer; import org.elasticsearch.index.cache.bitset.BitsetFilterCache; From 6cf0c299930265c539bba0ff6dbebe548bd0d4c8 Mon Sep 17 00:00:00 2001 From: Benjamin Trent <4357155+benwtrent@users.noreply.github.com> Date: Wed, 29 Sep 2021 07:22:17 -0400 Subject: [PATCH 19/20] addressing PR comments --- .../bucket/categorize-text-aggregation.asciidoc | 13 +++++++------ .../categorization/CategorizationTokenTree.java | 15 +++++++-------- .../xpack/ml/aggs/categorization/TreeNode.java | 10 +++++++++- 3 files changed, 23 insertions(+), 15 deletions(-) diff --git a/docs/reference/aggregations/bucket/categorize-text-aggregation.asciidoc b/docs/reference/aggregations/bucket/categorize-text-aggregation.asciidoc index 091d2111b8070..cc0a0e787f844 100644 --- a/docs/reference/aggregations/bucket/categorize-text-aggregation.asciidoc +++ b/docs/reference/aggregations/bucket/categorize-text-aggregation.asciidoc @@ -10,12 +10,6 @@ A multi-bucket aggregation that groups semi-structured text into buckets. Each ` using a custom analyzer. The resulting tokens are then categorized creating buckets of similarly formatted text values. This aggregation works best with machine generated text like system logs. -WARNING: Re-analyzing _large_ result sets will require a lot of time and memory. This aggregation should be - used in conjunction with <>. Additionally, you may consider - using the aggregation as a child of either the <> or - <> aggregation. - This will typically improve speed and memory use. - NOTE: If you have considerable memory allocated to your JVM but are receiving circuit breaker exceptions from this aggregation, you may be attempting to categorize text that is poorly formatted for categorization. Consider adding `categorization_filters` or running under <> or @@ -118,6 +112,13 @@ merging. ==== Basic use + +WARNING: Re-analyzing _large_ result sets will require a lot of time and memory. This aggregation should be + used in conjunction with <>. Additionally, you may consider + using the aggregation as a child of either the <> or + <> aggregation. + This will typically improve speed and memory use. + Example: [source,console] diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/CategorizationTokenTree.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/CategorizationTokenTree.java index a4d552d35f4e3..4a64a28fd0001 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/CategorizationTokenTree.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/CategorizationTokenTree.java @@ -113,6 +113,7 @@ public TextCategorization parseLogLine(final int[] logTokenIds, long docCount) { TreeNode currentNode = this.root.get(logTokenIds.length); if (currentNode == null) { // we are missing an entire sub tree. New log length found currentNode = newNode(docCount, 0, logTokenIds); + incSize(currentNode.ramBytesUsed() + RamUsageEstimator.HASHTABLE_RAM_BYTES_PER_ENTRY + RamUsageEstimator.NUM_BYTES_OBJECT_REF); this.root.put(logTokenIds.length, currentNode); } else { currentNode.incCount(docCount); @@ -121,19 +122,17 @@ public TextCategorization parseLogLine(final int[] logTokenIds, long docCount) { } TreeNode newNode(long docCount, int tokenPos, int[] logTokenIds) { - TreeNode node = tokenPos < maxMatchTokens - 1 && tokenPos < logTokenIds.length + return tokenPos < maxMatchTokens - 1 && tokenPos < logTokenIds.length ? new TreeNode.InnerTreeNode(docCount, tokenPos, maxUniqueTokens) : new TreeNode.LeafTreeNode(docCount, similarityThreshold); - // The size of the node + entry (since it is a map entry) + extra reference for priority queue - sizeInBytes += node.ramBytesUsed() + RamUsageEstimator.HASHTABLE_RAM_BYTES_PER_ENTRY + RamUsageEstimator.NUM_BYTES_OBJECT_REF; - return node; } TextCategorization newGroup(long docCount, int[] logTokenIds) { - TextCategorization group = new TextCategorization(logTokenIds, docCount, idGenerator++); - // Get the regular size bytes from the LogGroup and how much it costs to reference it - sizeInBytes += group.ramBytesUsed() + RamUsageEstimator.NUM_BYTES_OBJECT_REF; - return group; + return new TextCategorization(logTokenIds, docCount, idGenerator++); + } + + void incSize(long size) { + sizeInBytes += size; } @Override diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/TreeNode.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/TreeNode.java index 907946e7d1790..dd68e9064468b 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/TreeNode.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/TreeNode.java @@ -8,6 +8,7 @@ package org.elasticsearch.xpack.ml.aggs.categorization; import org.apache.lucene.util.Accountable; +import org.apache.lucene.util.RamUsageEstimator; import org.elasticsearch.core.Tuple; import org.elasticsearch.search.aggregations.AggregationExecutionException; @@ -111,7 +112,10 @@ public long ramBytesUsed() { public TextCategorization addLog(int[] logTokenIds, long docCount, CategorizationTokenTree treeNodeFactory) { return getAndUpdateLogGroup(logTokenIds, docCount).orElseGet(() -> { // Need to update the tree if possible - return putNewLogGroup(treeNodeFactory.newGroup(docCount, logTokenIds)); + TextCategorization group = putNewLogGroup(treeNodeFactory.newGroup(docCount, logTokenIds)); + // Get the regular size bytes from the LogGroup and how much it costs to reference it + treeNodeFactory.incSize(group.ramBytesUsed() + RamUsageEstimator.NUM_BYTES_OBJECT_REF); + return group; }); } @@ -230,6 +234,10 @@ public TextCategorization addLog(final int[] logTokenIds, final long docCount, f return node; }).orElseGet(() -> { TreeNode newNode = treeNodeFactory.newNode(docCount, childrenTokenPos + 1, logTokenIds); + // The size of the node + entry (since it is a map entry) + extra reference for priority queue + treeNodeFactory.incSize( + newNode.ramBytesUsed() + RamUsageEstimator.HASHTABLE_RAM_BYTES_PER_ENTRY + RamUsageEstimator.NUM_BYTES_OBJECT_REF + ); return addChild(currentToken, newNode); }); return child.addLog(logTokenIds, docCount, treeNodeFactory); From 9aba6ae325e86c33c6f007cd791eeb92a53e580d Mon Sep 17 00:00:00 2001 From: Benjamin Trent <4357155+benwtrent@users.noreply.github.com> Date: Thu, 30 Sep 2021 17:00:12 -0400 Subject: [PATCH 20/20] Addressing PR comments and fixing bug --- .../CategorizationTokenTree.java | 40 ++++---- .../CategorizeTextAggregator.java | 2 +- .../InternalCategorizationAggregation.java | 8 +- .../categorization/TextCategorization.java | 26 ++--- .../ml/aggs/categorization/TreeNode.java | 98 +++++++++---------- .../CategorizationAnalyzer.java | 3 +- .../categorization/InnerTreeNodeTests.java | 40 ++++---- .../categorization/LeafTreeNodeTests.java | 24 ++--- .../TextCategorizationTests.java | 4 +- 9 files changed, 122 insertions(+), 123 deletions(-) diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/CategorizationTokenTree.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/CategorizationTokenTree.java index 4a64a28fd0001..f5b5e6daea956 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/CategorizationTokenTree.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/CategorizationTokenTree.java @@ -24,7 +24,7 @@ * With the following key differences * - This structure keeps track of the "smallest" sub-tree. So, instead of naively adding a new "*" node, the smallest sub-tree * is transformed if the incoming token has a higher doc_count. - * - Additionally, similarities are weighted, which allows for nicer merging of existing log categories + * - Additionally, similarities are weighted, which allows for nicer merging of existing categories * - An optional tree reduction step is available to collapse together tiny sub-trees * * @@ -33,7 +33,7 @@ * * Examples: * - * Given log values: + * Given token values: * * Node is online * Node is offline @@ -69,7 +69,7 @@ public CategorizationTokenTree(int maxUniqueTokens, int maxMatchTokens, int simi } public List toIntermediateBuckets(CategorizationBytesRefHash hash) { - return root.values().stream().flatMap(c -> c.getAllChildrenLogGroups().stream()).map(lg -> { + return root.values().stream().flatMap(c -> c.getAllChildrenTextCategorizations().stream()).map(lg -> { int[] categoryTokenIds = lg.getCategorization(); BytesRef[] bytesRefs = new BytesRef[categoryTokenIds.length]; for (int i = 0; i < categoryTokenIds.length; i++) { @@ -92,43 +92,43 @@ void mergeSmallestChildren() { /** * This method does not mutate the underlying structure. Meaning, if a matching categories isn't found, it may return empty. * - * @param logTokenIds The tokens to categorize - * @return The log category or `Optional.empty()` if one doesn't exist + * @param tokenIds The tokens to categorize + * @return The category or `Optional.empty()` if one doesn't exist */ - public Optional parseLogLineConst(final int[] logTokenIds) { - TreeNode currentNode = this.root.get(logTokenIds.length); - if (currentNode == null) { // we are missing an entire sub tree. New log length found + public Optional parseTokensConst(final int[] tokenIds) { + TreeNode currentNode = this.root.get(tokenIds.length); + if (currentNode == null) { // we are missing an entire sub tree. New token length found return Optional.empty(); } - return Optional.ofNullable(currentNode.getLogGroup(logTokenIds)); + return Optional.ofNullable(currentNode.getCategorization(tokenIds)); } /** * This categorizes the passed tokens, potentially mutating the structure by expanding an existing category or adding a new one. - * @param logTokenIds The log tokens to categorize + * @param tokenIds The tokens to categorize * @param docCount The count of docs for the given tokens * @return An existing categorization or a newly created one */ - public TextCategorization parseLogLine(final int[] logTokenIds, long docCount) { - TreeNode currentNode = this.root.get(logTokenIds.length); - if (currentNode == null) { // we are missing an entire sub tree. New log length found - currentNode = newNode(docCount, 0, logTokenIds); + public TextCategorization parseTokens(final int[] tokenIds, long docCount) { + TreeNode currentNode = this.root.get(tokenIds.length); + if (currentNode == null) { // we are missing an entire sub tree. New token length found + currentNode = newNode(docCount, 0, tokenIds); incSize(currentNode.ramBytesUsed() + RamUsageEstimator.HASHTABLE_RAM_BYTES_PER_ENTRY + RamUsageEstimator.NUM_BYTES_OBJECT_REF); - this.root.put(logTokenIds.length, currentNode); + this.root.put(tokenIds.length, currentNode); } else { currentNode.incCount(docCount); } - return currentNode.addLog(logTokenIds, docCount, this); + return currentNode.addText(tokenIds, docCount, this); } - TreeNode newNode(long docCount, int tokenPos, int[] logTokenIds) { - return tokenPos < maxMatchTokens - 1 && tokenPos < logTokenIds.length + TreeNode newNode(long docCount, int tokenPos, int[] tokenIds) { + return tokenPos < maxMatchTokens - 1 && tokenPos < tokenIds.length ? new TreeNode.InnerTreeNode(docCount, tokenPos, maxUniqueTokens) : new TreeNode.LeafTreeNode(docCount, similarityThreshold); } - TextCategorization newGroup(long docCount, int[] logTokenIds) { - return new TextCategorization(logTokenIds, docCount, idGenerator++); + TextCategorization newCategorization(long docCount, int[] tokenIds) { + return new TextCategorization(tokenIds, docCount, idGenerator++); } void incSize(long size) { diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/CategorizeTextAggregator.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/CategorizeTextAggregator.java index acd42eb1c6dda..16058fbdae4f2 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/CategorizeTextAggregator.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/CategorizeTextAggregator.java @@ -216,7 +216,7 @@ private void processTokenStream( ts.close(); } long previousSize = categorizer.ramBytesUsed(); - TextCategorization lg = categorizer.parseLogLine( + TextCategorization lg = categorizer.parseTokens( tokens.stream().mapToInt(Integer::valueOf).toArray(), docCountProvider.getDocCount(doc) ); diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/InternalCategorizationAggregation.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/InternalCategorizationAggregation.java index b8530f4988668..92c51b8d75b4c 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/InternalCategorizationAggregation.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/InternalCategorizationAggregation.java @@ -357,7 +357,7 @@ public InternalAggregation reduce(List aggregations, Reduce ); // TODO: Could we do a merge sort similar to terms? // It would require us returning partial reductions sorted by key, not by doc_count - // First, make sure we have all the counts for equal log groups + // First, make sure we have all the counts for equal categorizations Map reduced = new HashMap<>(); for (InternalAggregation aggregation : aggregations) { InternalCategorizationAggregation categorizationAggregation = (InternalCategorizationAggregation) aggregation; @@ -370,13 +370,13 @@ public InternalAggregation reduce(List aggregations, Reduce .stream() .sorted(Comparator.comparing(DelayedCategorizationBucket::getDocCount).reversed()) .forEach(bucket -> - // Parse log line takes document count into account and merging on smallest groups - categorizationTokenTree.parseLogLine(hash.getIds(bucket.key.keyAsTokens()), bucket.docCount) + // Parse tokens takes document count into account and merging on smallest groups + categorizationTokenTree.parseTokens(hash.getIds(bucket.key.keyAsTokens()), bucket.docCount) ); categorizationTokenTree.mergeSmallestChildren(); Map mergedBuckets = new HashMap<>(); for (DelayedCategorizationBucket delayedBucket : reduced.values()) { - TextCategorization group = categorizationTokenTree.parseLogLineConst(hash.getIds(delayedBucket.key.keyAsTokens())) + TextCategorization group = categorizationTokenTree.parseTokensConst(hash.getIds(delayedBucket.key.keyAsTokens())) .orElseThrow( () -> new AggregationExecutionException( "Unexpected null categorization group for bucket [" + delayedBucket.key.asString() + "]" diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/TextCategorization.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/TextCategorization.java index e4370e2f49776..7ea72f489ae2d 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/TextCategorization.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/TextCategorization.java @@ -16,8 +16,8 @@ /** * A text categorization group that provides methods for: - * - calculating similarity between it and a new log - * - expanding the existing log group by adding a new log + * - calculating similarity between it and a new text + * - expanding the existing categorization by adding a new array of tokens */ class TextCategorization implements Accountable { @@ -30,11 +30,11 @@ class TextCategorization implements Accountable { // Used at the shard level for tracking the bucket ordinal for collecting sub aggregations long bucketOrd; - TextCategorization(int[] logTokenIds, long count, long id) { + TextCategorization(int[] tokenIds, long count, long id) { this.id = id; - this.categorization = logTokenIds; + this.categorization = tokenIds; this.count = count; - this.tokenCounts = new long[logTokenIds.length]; + this.tokenCounts = new long[tokenIds.length]; Arrays.fill(this.tokenCounts, count); } @@ -50,13 +50,13 @@ public long getCount() { return count; } - Similarity calculateSimilarity(int[] logEvent) { - assert logEvent.length == this.categorization.length; + Similarity calculateSimilarity(int[] tokenIds) { + assert tokenIds.length == this.categorization.length; int eqParams = 0; long tokenCount = 0; long tokensKept = 0; - for (int i = 0; i < logEvent.length; i++) { - if (logEvent[i] == this.categorization[i]) { + for (int i = 0; i < tokenIds.length; i++) { + if (tokenIds[i] == this.categorization[i]) { tokensKept += tokenCounts[i]; tokenCount += tokenCounts[i]; } else if (this.categorization[i] == WILD_CARD_ID) { @@ -68,10 +68,10 @@ Similarity calculateSimilarity(int[] logEvent) { return new Similarity((double) tokensKept / tokenCount, eqParams); } - void addLog(int[] logEvent, long docCount) { - assert logEvent.length == this.categorization.length; - for (int i = 0; i < logEvent.length; i++) { - if (logEvent[i] != this.categorization[i]) { + void addTokens(int[] tokenIds, long docCount) { + assert tokenIds.length == this.categorization.length; + for (int i = 0; i < tokenIds.length; i++) { + if (tokenIds[i] != this.categorization[i]) { this.categorization[i] = WILD_CARD_ID; } else { tokenCounts[i] += docCount; diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/TreeNode.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/TreeNode.java index dd68e9064468b..7b13e93d8f1ea 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/TreeNode.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/TreeNode.java @@ -55,11 +55,11 @@ final long getCount() { } // TODO add option for calculating the cost of adding the new group - abstract TextCategorization addLog(int[] logTokenIds, long docCount, CategorizationTokenTree treeNodeFactory); + abstract TextCategorization addText(int[] tokenIds, long docCount, CategorizationTokenTree treeNodeFactory); - abstract TextCategorization getLogGroup(int[] logTokens); + abstract TextCategorization getCategorization(int[] tokenIds); - abstract List getAllChildrenLogGroups(); + abstract List getAllChildrenTextCategorizations(); abstract void collapseTinyChildren(); @@ -94,8 +94,8 @@ void mergeWith(TreeNode treeNode) { incCount(treeNode.getCount()); LeafTreeNode otherLeaf = (LeafTreeNode) treeNode; for (TextCategorization group : otherLeaf.textCategorizations) { - if (getAndUpdateLogGroup(group.getCategorization(), group.getCount()).isPresent() == false) { - putNewLogGroup(group); + if (getAndUpdateTextCategorization(group.getCategorization(), group.getCount()).isPresent() == false) { + putNewTextCategorization(group); } } } @@ -109,52 +109,52 @@ public long ramBytesUsed() { } @Override - public TextCategorization addLog(int[] logTokenIds, long docCount, CategorizationTokenTree treeNodeFactory) { - return getAndUpdateLogGroup(logTokenIds, docCount).orElseGet(() -> { + public TextCategorization addText(int[] tokenIds, long docCount, CategorizationTokenTree treeNodeFactory) { + return getAndUpdateTextCategorization(tokenIds, docCount).orElseGet(() -> { // Need to update the tree if possible - TextCategorization group = putNewLogGroup(treeNodeFactory.newGroup(docCount, logTokenIds)); - // Get the regular size bytes from the LogGroup and how much it costs to reference it - treeNodeFactory.incSize(group.ramBytesUsed() + RamUsageEstimator.NUM_BYTES_OBJECT_REF); - return group; + TextCategorization categorization = putNewTextCategorization(treeNodeFactory.newCategorization(docCount, tokenIds)); + // Get the regular size bytes from the TextCategorization and how much it costs to reference it + treeNodeFactory.incSize(categorization.ramBytesUsed() + RamUsageEstimator.NUM_BYTES_OBJECT_REF); + return categorization; }); } @Override - List getAllChildrenLogGroups() { + List getAllChildrenTextCategorizations() { return textCategorizations; } @Override void collapseTinyChildren() {} - private Optional getAndUpdateLogGroup(int[] logTokenIds, long docCount) { - return getBestLogGroup(logTokenIds).map(bestGroupAndSimilarity -> { + private Optional getAndUpdateTextCategorization(int[] tokenIds, long docCount) { + return getBestCategorization(tokenIds).map(bestGroupAndSimilarity -> { if ((bestGroupAndSimilarity.v2() * 100) >= similarityThreshold) { - bestGroupAndSimilarity.v1().addLog(logTokenIds, docCount); + bestGroupAndSimilarity.v1().addTokens(tokenIds, docCount); return bestGroupAndSimilarity.v1(); } return null; }); } - TextCategorization putNewLogGroup(TextCategorization group) { - textCategorizations.add(group); - return group; + TextCategorization putNewTextCategorization(TextCategorization categorization) { + textCategorizations.add(categorization); + return categorization; } - private Optional> getBestLogGroup(int[] logTokenIds) { + private Optional> getBestCategorization(int[] tokenIds) { if (textCategorizations.isEmpty()) { return Optional.empty(); } if (textCategorizations.size() == 1) { return Optional.of( - new Tuple<>(textCategorizations.get(0), textCategorizations.get(0).calculateSimilarity(logTokenIds).getSimilarity()) + new Tuple<>(textCategorizations.get(0), textCategorizations.get(0).calculateSimilarity( tokenIds).getSimilarity()) ); } TextCategorization.Similarity maxSimilarity = null; TextCategorization bestGroup = null; for (TextCategorization textCategorization : this.textCategorizations) { - TextCategorization.Similarity groupSimilarity = textCategorization.calculateSimilarity(logTokenIds); + TextCategorization.Similarity groupSimilarity = textCategorization.calculateSimilarity( tokenIds); if (maxSimilarity == null || groupSimilarity.compareTo(maxSimilarity) > 0) { maxSimilarity = groupSimilarity; bestGroup = textCategorization; @@ -164,8 +164,8 @@ private Optional> getBestLogGroup(int[] logTok } @Override - public TextCategorization getLogGroup(final int[] logTokenIds) { - return getBestLogGroup(logTokenIds).map(Tuple::v1).orElse(null); + public TextCategorization getCategorization(final int[] tokenIds) { + return getBestCategorization(tokenIds).map(Tuple::v1).orElse(null); } @Override @@ -186,17 +186,17 @@ public int hashCode() { static class InnerTreeNode extends TreeNode { // TODO: Change to LongObjectMap? - private final Map children; + private final Map children; private final int childrenTokenPos; private final int maxChildren; - private final PriorityQueue smallestChild; + private final PriorityQueue smallestChild; InnerTreeNode(long count, int childrenTokenPos, int maxChildren) { super(count); children = new HashMap<>(); this.childrenTokenPos = childrenTokenPos; this.maxChildren = maxChildren; - this.smallestChild = new PriorityQueue<>(maxChildren, Comparator.comparing(NativeLongPair::count)); + this.smallestChild = new PriorityQueue<>(maxChildren, Comparator.comparing(NativeIntLongPair::count)); } @Override @@ -205,9 +205,9 @@ boolean isLeaf() { } @Override - public TextCategorization getLogGroup(final int[] logTokenIds) { - return getChild(logTokenIds[childrenTokenPos]).or(() -> getChild(WILD_CARD_ID)) - .map(node -> node.getLogGroup(logTokenIds)) + public TextCategorization getCategorization(final int[] tokenIds) { + return getChild(tokenIds[childrenTokenPos]).or(() -> getChild(WILD_CARD_ID)) + .map(node -> node.getCategorization(tokenIds)) .orElse(null); } @@ -220,12 +220,12 @@ public long ramBytesUsed() { + NUM_BYTES_OBJECT_REF // smallestChildReference + sizeOfMap(children, NUM_BYTES_OBJECT_REF) // children, // Number of items in the queue, reference to tuple, and then the tuple references - + (long) smallestChild.size() * (NUM_BYTES_OBJECT_REF + Long.BYTES + Long.BYTES); + + (long) smallestChild.size() * (NUM_BYTES_OBJECT_REF + Integer.BYTES + Long.BYTES); } @Override - public TextCategorization addLog(final int[] logTokenIds, final long docCount, final CategorizationTokenTree treeNodeFactory) { - final long currentToken = logTokenIds[childrenTokenPos]; + public TextCategorization addText(final int[] tokenIds, final long docCount, final CategorizationTokenTree treeNodeFactory) { + final int currentToken = tokenIds[childrenTokenPos]; TreeNode child = getChild(currentToken).map(node -> { node.incCount(docCount); if (smallestChild.isEmpty() == false && smallestChild.peek().tokenId == currentToken) { @@ -233,14 +233,14 @@ public TextCategorization addLog(final int[] logTokenIds, final long docCount, f } return node; }).orElseGet(() -> { - TreeNode newNode = treeNodeFactory.newNode(docCount, childrenTokenPos + 1, logTokenIds); + TreeNode newNode = treeNodeFactory.newNode(docCount, childrenTokenPos + 1, tokenIds); // The size of the node + entry (since it is a map entry) + extra reference for priority queue treeNodeFactory.incSize( newNode.ramBytesUsed() + RamUsageEstimator.HASHTABLE_RAM_BYTES_PER_ENTRY + RamUsageEstimator.NUM_BYTES_OBJECT_REF ); return addChild(currentToken, newNode); }); - return child.addLog(logTokenIds, docCount, treeNodeFactory); + return child.addText(tokenIds, docCount, treeNodeFactory); } @Override @@ -260,7 +260,7 @@ void collapseTinyChildren() { }); if (maybeWildChild.isPresent()) { TreeNode wildChild = maybeWildChild.get(); - NativeLongPair tinyNode; + NativeIntLongPair tinyNode; while ((tinyNode = smallestChild.poll()) != null) { // If we have no more tiny nodes, stop iterating over them if ((double) tinyNode.count / this.getCount() > 1.0 / maxChildren) { @@ -288,14 +288,14 @@ void mergeWith(TreeNode treeNode) { InnerTreeNode innerTreeNode = (InnerTreeNode) treeNode; TreeNode siblingWildChild = innerTreeNode.children.remove(WILD_CARD_ID); addChild(WILD_CARD_ID, siblingWildChild); - NativeLongPair siblingChild; + NativeIntLongPair siblingChild; while ((siblingChild = innerTreeNode.smallestChild.poll()) != null) { TreeNode nephewNode = innerTreeNode.children.remove(siblingChild.tokenId); addChild(siblingChild.tokenId, nephewNode); } } - private TreeNode addChild(long tokenId, TreeNode node) { + private TreeNode addChild(int tokenId, TreeNode node) { if (node == null) { return null; } @@ -303,7 +303,7 @@ private TreeNode addChild(long tokenId, TreeNode node) { existingNode.mergeWith(node); if (smallestChild.isEmpty() == false && smallestChild.peek().tokenId == tokenId) { smallestChild.poll(); - smallestChild.add(NativeLongPair.of(tokenId, existingNode.getCount())); + smallestChild.add(NativeIntLongPair.of(tokenId, existingNode.getCount())); } return existingNode; }); @@ -349,22 +349,22 @@ private TreeNode addChild(long tokenId, TreeNode node) { return node; } - private void addChildAndUpdateSmallest(long tokenId, TreeNode node) { + private void addChildAndUpdateSmallest(int tokenId, TreeNode node) { children.put(tokenId, node); if (tokenId != WILD_CARD_ID) { - smallestChild.add(NativeLongPair.of(tokenId, node.count)); + smallestChild.add(NativeIntLongPair.of(tokenId, node.count)); } } - private Optional getChild(long tokenId) { + private Optional getChild(int tokenId) { return Optional.ofNullable(children.get(tokenId)); } - public List getAllChildrenLogGroups() { - return children.values().stream().flatMap(c -> c.getAllChildrenLogGroups().stream()).collect(Collectors.toList()); + public List getAllChildrenTextCategorizations() { + return children.values().stream().flatMap(c -> c.getAllChildrenTextCategorizations().stream()).collect(Collectors.toList()); } - boolean hasChild(long tokenId) { + boolean hasChild(int tokenId) { return children.containsKey(tokenId); } @@ -385,15 +385,15 @@ public int hashCode() { } } - private static class NativeLongPair { - private final long tokenId; + private static class NativeIntLongPair { + private final int tokenId; private final long count; - static NativeLongPair of(long tokenId, long count) { - return new NativeLongPair(tokenId, count); + static NativeIntLongPair of(int tokenId, long count) { + return new NativeIntLongPair(tokenId, count); } - NativeLongPair(long tokenId, long count) { + NativeIntLongPair(int tokenId, long count) { this.tokenId = tokenId; this.count = count; } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/job/categorization/CategorizationAnalyzer.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/job/categorization/CategorizationAnalyzer.java index a3753c4105eab..6147bc0256ca5 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/job/categorization/CategorizationAnalyzer.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/job/categorization/CategorizationAnalyzer.java @@ -15,7 +15,6 @@ import org.elasticsearch.index.analysis.AnalysisRegistry; import org.elasticsearch.xpack.core.ml.job.config.CategorizationAnalyzerConfig; -import java.io.Closeable; import java.io.IOException; import java.util.ArrayList; import java.util.List; @@ -26,7 +25,7 @@ * Converts messages to lists of tokens that will be fed to the ML categorization algorithm. * */ -public class CategorizationAnalyzer implements Closeable, Releasable { +public class CategorizationAnalyzer implements Releasable { private final Analyzer analyzer; private final boolean closeAnalyzer; diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/aggs/categorization/InnerTreeNodeTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/aggs/categorization/InnerTreeNodeTests.java index bf7c49b6dd511..e7f78a01d0130 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/aggs/categorization/InnerTreeNodeTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/aggs/categorization/InnerTreeNodeTests.java @@ -34,64 +34,64 @@ public void closeRefHash() { bytesRefHash.close(); } - public void testAddLog() { + public void testAddText() { TreeNode.InnerTreeNode innerTreeNode = new TreeNode.InnerTreeNode(1, 0, 3); - TextCategorization group = innerTreeNode.addLog(getTokens(bytesRefHash, "foo", "bar", "baz", "biz"), 1, factory); + TextCategorization group = innerTreeNode.addText(getTokens(bytesRefHash, "foo", "bar", "baz", "biz"), 1, factory); assertArrayEquals(group.getCategorization(), getTokens(bytesRefHash, "foo", "bar", "baz", "biz")); assertArrayEquals( - innerTreeNode.addLog(getTokens(bytesRefHash, "foo2", "bar", "baz", "biz"), 1, factory).getCategorization(), + innerTreeNode.addText(getTokens(bytesRefHash, "foo2", "bar", "baz", "biz"), 1, factory).getCategorization(), getTokens(bytesRefHash, "foo2", "bar", "baz", "biz") ); assertArrayEquals( - innerTreeNode.addLog(getTokens(bytesRefHash, "foo3", "bar", "baz", "biz"), 1, factory).getCategorization(), + innerTreeNode.addText(getTokens(bytesRefHash, "foo3", "bar", "baz", "biz"), 1, factory).getCategorization(), getTokens(bytesRefHash, "foo3", "bar", "baz", "biz") ); assertArrayEquals( - innerTreeNode.addLog(getTokens(bytesRefHash, "foo4", "bar", "baz", "biz"), 1, factory).getCategorization(), + innerTreeNode.addText(getTokens(bytesRefHash, "foo4", "bar", "baz", "biz"), 1, factory).getCategorization(), getTokens(bytesRefHash, "*", "bar", "baz", "biz") ); assertArrayEquals( - innerTreeNode.addLog(getTokens(bytesRefHash, "foo", "bar", "baz", "bizzy"), 1, factory).getCategorization(), + innerTreeNode.addText(getTokens(bytesRefHash, "foo", "bar", "baz", "bizzy"), 1, factory).getCategorization(), getTokens(bytesRefHash, "foo", "bar", "baz", "*") ); } - public void testAddLogWithLargerIncoming() { + public void testAddTokensWithLargerIncoming() { TreeNode.InnerTreeNode innerTreeNode = new TreeNode.InnerTreeNode(1, 0, 3); - TextCategorization group = innerTreeNode.addLog(getTokens(bytesRefHash, "foo", "bar", "baz", "biz"), 100, factory); + TextCategorization group = innerTreeNode.addText(getTokens(bytesRefHash, "foo", "bar", "baz", "biz"), 100, factory); assertArrayEquals(group.getCategorization(), getTokens(bytesRefHash, "foo", "bar", "baz", "biz")); assertArrayEquals( - innerTreeNode.addLog(getTokens(bytesRefHash, "foo2", "bar", "baz", "biz"), 100, factory).getCategorization(), + innerTreeNode.addText(getTokens(bytesRefHash, "foo2", "bar", "baz", "biz"), 100, factory).getCategorization(), getTokens(bytesRefHash, "foo2", "bar", "baz", "biz") ); assertArrayEquals( - innerTreeNode.addLog(getTokens(bytesRefHash, "foosmall", "bar", "baz", "biz"), 1, factory).getCategorization(), + innerTreeNode.addText(getTokens(bytesRefHash, "foosmall", "bar", "baz", "biz"), 1, factory).getCategorization(), getTokens(bytesRefHash, "foosmall", "bar", "baz", "biz") ); assertArrayEquals( - innerTreeNode.addLog(getTokens(bytesRefHash, "foobigun", "bar", "baz", "biz"), 1000, factory).getCategorization(), + innerTreeNode.addText(getTokens(bytesRefHash, "foobigun", "bar", "baz", "biz"), 1000, factory).getCategorization(), getTokens(bytesRefHash, "foobigun", "bar", "baz", "biz") ); assertThat( - innerTreeNode.getLogGroup(getTokens(bytesRefHash, "foosmall", "bar", "baz", "biz")).getCategorization(), + innerTreeNode.getCategorization(getTokens(bytesRefHash, "foosmall", "bar", "baz", "biz")).getCategorization(), equalTo(getTokens(bytesRefHash, "*", "bar", "baz", "biz")) ); } public void testCollapseTinyChildren() { TreeNode.InnerTreeNode innerTreeNode = new TreeNode.InnerTreeNode(1000, 0, 4); - TextCategorization group = innerTreeNode.addLog(getTokens(bytesRefHash, "foo", "bar", "baz", "biz"), 1000, factory); + TextCategorization group = innerTreeNode.addText(getTokens(bytesRefHash, "foo", "bar", "baz", "biz"), 1000, factory); assertArrayEquals(group.getCategorization(), getTokens(bytesRefHash, "foo", "bar", "baz", "biz")); assertArrayEquals( - innerTreeNode.addLog(getTokens(bytesRefHash, "foo2", "bar", "baz", "biz"), 1000, factory).getCategorization(), + innerTreeNode.addText(getTokens(bytesRefHash, "foo2", "bar", "baz", "biz"), 1000, factory).getCategorization(), getTokens(bytesRefHash, "foo2", "bar", "baz", "biz") ); innerTreeNode.incCount(1000); assertArrayEquals( - innerTreeNode.addLog(getTokens(bytesRefHash, "foosmall", "bar", "baz", "biz"), 1, factory).getCategorization(), + innerTreeNode.addText(getTokens(bytesRefHash, "foosmall", "bar", "baz", "biz"), 1, factory).getCategorization(), getTokens(bytesRefHash, "foosmall", "bar", "baz", "biz") ); innerTreeNode.incCount(1); @@ -102,21 +102,21 @@ public void testCollapseTinyChildren() { public void testMergeWith() { TreeNode.InnerTreeNode innerTreeNode = new TreeNode.InnerTreeNode(1000, 0, 3); - innerTreeNode.addLog(getTokens(bytesRefHash, "foo", "bar", "baz", "biz"), 1000, factory); + innerTreeNode.addText(getTokens(bytesRefHash, "foo", "bar", "baz", "biz"), 1000, factory); innerTreeNode.incCount(1000); - innerTreeNode.addLog(getTokens(bytesRefHash, "foo2", "bar", "baz", "biz"), 1000, factory); + innerTreeNode.addText(getTokens(bytesRefHash, "foo2", "bar", "baz", "biz"), 1000, factory); expectThrows(UnsupportedOperationException.class, () -> innerTreeNode.mergeWith(new TreeNode.LeafTreeNode(1, 60))); TreeNode.InnerTreeNode mergeWith = new TreeNode.InnerTreeNode(1, 0, 3); - innerTreeNode.addLog(getTokens(bytesRefHash, "foosmall", "bar", "baz", "biz"), 1, factory); + innerTreeNode.addText(getTokens(bytesRefHash, "foosmall", "bar", "baz", "biz"), 1, factory); innerTreeNode.incCount(1); - innerTreeNode.addLog(getTokens(bytesRefHash, "footiny", "bar", "baz", "biz"), 1, factory); + innerTreeNode.addText(getTokens(bytesRefHash, "footiny", "bar", "baz", "biz"), 1, factory); innerTreeNode.mergeWith(mergeWith); assertThat(innerTreeNode.hasChild(WILD_CARD_ID), is(true)); assertArrayEquals( - innerTreeNode.getLogGroup(getTokens(bytesRefHash, "footiny", "bar", "baz", "biz")).getCategorization(), + innerTreeNode.getCategorization(getTokens(bytesRefHash, "footiny", "bar", "baz", "biz")).getCategorization(), getTokens(bytesRefHash, "*", "bar", "baz", "biz") ); } diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/aggs/categorization/LeafTreeNodeTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/aggs/categorization/LeafTreeNodeTests.java index 3e5730e67dfe2..2bef18993e019 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/aggs/categorization/LeafTreeNodeTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/aggs/categorization/LeafTreeNodeTests.java @@ -36,24 +36,24 @@ public void closeRefHash() { public void testAddGroup() { TreeNode.LeafTreeNode leafTreeNode = new TreeNode.LeafTreeNode(0, 60); - TextCategorization group = leafTreeNode.addLog(getTokens(bytesRefHash, "foo", "bar", "baz", "biz"), 1, factory); + TextCategorization group = leafTreeNode.addText(getTokens(bytesRefHash, "foo", "bar", "baz", "biz"), 1, factory); assertArrayEquals(group.getCategorization(), getTokens(bytesRefHash, "foo", "bar", "baz", "biz")); assertThat(group.getCount(), equalTo(1L)); - assertThat(leafTreeNode.getAllChildrenLogGroups(), hasSize(1)); + assertThat(leafTreeNode.getAllChildrenTextCategorizations(), hasSize(1)); long previousBytesUsed = leafTreeNode.ramBytesUsed(); - group = leafTreeNode.addLog(getTokens(bytesRefHash, "foo", "bar", "bozo", "bizzy"), 1, factory); + group = leafTreeNode.addText(getTokens(bytesRefHash, "foo", "bar", "bozo", "bizzy"), 1, factory); assertArrayEquals(group.getCategorization(), getTokens(bytesRefHash, "foo", "bar", "bozo", "bizzy")); assertThat(group.getCount(), equalTo(1L)); - assertThat(leafTreeNode.getAllChildrenLogGroups(), hasSize(2)); + assertThat(leafTreeNode.getAllChildrenTextCategorizations(), hasSize(2)); assertThat(leafTreeNode.ramBytesUsed(), greaterThan(previousBytesUsed)); previousBytesUsed = leafTreeNode.ramBytesUsed(); - group = leafTreeNode.addLog(getTokens(bytesRefHash, "foo", "bar", "baz", "different"), 3, factory); + group = leafTreeNode.addText(getTokens(bytesRefHash, "foo", "bar", "baz", "different"), 3, factory); assertArrayEquals(group.getCategorization(), getTokens(bytesRefHash, "foo", "bar", "baz", "*")); assertThat(group.getCount(), equalTo(4L)); - assertThat(leafTreeNode.getAllChildrenLogGroups(), hasSize(2)); + assertThat(leafTreeNode.getAllChildrenTextCategorizations(), hasSize(2)); assertThat(previousBytesUsed, equalTo(leafTreeNode.ramBytesUsed())); } @@ -65,23 +65,23 @@ public void testMergeWith() { expectThrows(UnsupportedOperationException.class, () -> leafTreeNode.mergeWith(new TreeNode.InnerTreeNode(1, 2, 3))); leafTreeNode.incCount(5); - leafTreeNode.addLog(getTokens(bytesRefHash, "foo", "bar", "baz", "biz"), 5, factory); + leafTreeNode.addText(getTokens(bytesRefHash, "foo", "bar", "baz", "biz"), 5, factory); TreeNode.LeafTreeNode toMerge = new TreeNode.LeafTreeNode(0, 60); leafTreeNode.incCount(1); - toMerge.addLog(getTokens(bytesRefHash, "foo", "bar", "baz", "bizzy"), 1, factory); + toMerge.addText(getTokens(bytesRefHash, "foo", "bar", "baz", "bizzy"), 1, factory); leafTreeNode.incCount(1); - toMerge.addLog(getTokens(bytesRefHash, "foo", "bart", "bat", "built"), 1, factory); + toMerge.addText(getTokens(bytesRefHash, "foo", "bart", "bat", "built"), 1, factory); leafTreeNode.mergeWith(toMerge); - assertThat(leafTreeNode.getAllChildrenLogGroups(), hasSize(2)); + assertThat(leafTreeNode.getAllChildrenTextCategorizations(), hasSize(2)); assertThat(leafTreeNode.getCount(), equalTo(7L)); assertArrayEquals( - leafTreeNode.getAllChildrenLogGroups().get(0).getCategorization(), + leafTreeNode.getAllChildrenTextCategorizations().get(0).getCategorization(), getTokens(bytesRefHash, "foo", "bar", "baz", "*") ); assertArrayEquals( - leafTreeNode.getAllChildrenLogGroups().get(1).getCategorization(), + leafTreeNode.getAllChildrenTextCategorizations().get(1).getCategorization(), getTokens(bytesRefHash, "foo", "bart", "bat", "built") ); } diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/aggs/categorization/TextCategorizationTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/aggs/categorization/TextCategorizationTests.java index 827019ea86843..59129f8801937 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/aggs/categorization/TextCategorizationTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/aggs/categorization/TextCategorizationTests.java @@ -56,9 +56,9 @@ public void testSimilarity() { assertThat(sims.getWildCardCount(), equalTo(0)); } - public void testAddLog() { + public void testAddTokens() { TextCategorization lg = new TextCategorization(getTokens(bytesRefHash, "foo", "bar", "baz", "biz"), 1, 1); - lg.addLog(getTokens(bytesRefHash, "foo", "bar", "baz", "bozo"), 2); + lg.addTokens(getTokens(bytesRefHash, "foo", "bar", "baz", "bozo"), 2); assertThat(lg.getCount(), equalTo(3L)); assertArrayEquals(lg.getCategorization(), getTokens(bytesRefHash, "foo", "bar", "baz", "*")); }