From 5616c01dd5a0c81728954bdcd91ce72c9b8bf218 Mon Sep 17 00:00:00 2001 From: yuye-aws Date: Wed, 21 May 2025 17:01:42 +0800 Subject: [PATCH 01/10] POC: Support sparse model return token_id Signed-off-by: yuye-aws --- .../ml/common/input/nlp/TextDocsMLInput.java | 5 +- .../SparseEncodingParameters.java | 95 +++++++++++++++++++ .../engine/algorithms/TextEmbeddingModel.java | 5 + .../SparseEncodingTranslator.java | 30 ++++-- .../ml/plugin/MachineLearningPlugin.java | 4 +- 5 files changed, 127 insertions(+), 12 deletions(-) create mode 100644 common/src/main/java/org/opensearch/ml/common/input/parameter/textembedding/SparseEncodingParameters.java diff --git a/common/src/main/java/org/opensearch/ml/common/input/nlp/TextDocsMLInput.java b/common/src/main/java/org/opensearch/ml/common/input/nlp/TextDocsMLInput.java index 1a2f201dd5..a818137ad7 100644 --- a/common/src/main/java/org/opensearch/ml/common/input/nlp/TextDocsMLInput.java +++ b/common/src/main/java/org/opensearch/ml/common/input/nlp/TextDocsMLInput.java @@ -27,10 +27,7 @@ * ML input class which supports a list fo text docs. * This class can be used for TEXT_EMBEDDING model. 
*/ -@org.opensearch.ml.common.annotation.MLInput(functionNames = { - FunctionName.TEXT_EMBEDDING, - FunctionName.SPARSE_ENCODING, - FunctionName.SPARSE_TOKENIZE }) +@org.opensearch.ml.common.annotation.MLInput(functionNames = { FunctionName.TEXT_EMBEDDING }) public class TextDocsMLInput extends MLInput { public static final String TEXT_DOCS_FIELD = "text_docs"; public static final String RESULT_FILTER_FIELD = "result_filter"; diff --git a/common/src/main/java/org/opensearch/ml/common/input/parameter/textembedding/SparseEncodingParameters.java b/common/src/main/java/org/opensearch/ml/common/input/parameter/textembedding/SparseEncodingParameters.java new file mode 100644 index 0000000000..f5bf85d331 --- /dev/null +++ b/common/src/main/java/org/opensearch/ml/common/input/parameter/textembedding/SparseEncodingParameters.java @@ -0,0 +1,95 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.input.parameter.textembedding; + +import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; + +import java.io.IOException; +import java.util.Locale; + +import org.opensearch.core.ParseField; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.ml.common.FunctionName; +import org.opensearch.ml.common.annotation.MLAlgoParameter; +import org.opensearch.ml.common.input.parameter.MLAlgoParams; + +import lombok.Builder; + +@MLAlgoParameter(algorithms = { FunctionName.SPARSE_ENCODING }) +public class SparseEncodingParameters implements MLAlgoParams { + + public static final String PARSE_FIELD_NAME = FunctionName.SPARSE_ENCODING.name(); + public static final String SPARSE_ENCODING_FORMAT_FIELD = "sparse_encoding_format"; + + @Override + public int getVersion() { + return 1; + } + + @Override + public 
String getWriteableName() { + return PARSE_FIELD_NAME; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeOptionalString(sparseEncodingType.name()); + } + + public static final NamedXContentRegistry.Entry XCONTENT_REGISTRY = new NamedXContentRegistry.Entry( + MLAlgoParams.class, + new ParseField(PARSE_FIELD_NAME), + SparseEncodingParameters::parse + ); + + @Override + public XContentBuilder toXContent(XContentBuilder xContentBuilder, Params params) throws IOException { + xContentBuilder.startObject(); + if (sparseEncodingType != null) { + xContentBuilder.field(SPARSE_ENCODING_FORMAT_FIELD, sparseEncodingType.name()); + } + xContentBuilder.endObject(); + return xContentBuilder; + } + + public enum SparseEncodingFormat { + WORD, + INT + } + + // The type of the content to be embedded + private final SparseEncodingFormat sparseEncodingType; + + @Builder(toBuilder = true) + public SparseEncodingParameters(SparseEncodingFormat sparseEncodingType) { + this.sparseEncodingType = sparseEncodingType; + } + + public SparseEncodingFormat getSparseEncodingType() { + return sparseEncodingType; + } + + public static MLAlgoParams parse(XContentParser parser) throws IOException { + SparseEncodingFormat sparseEncodingType = null; + + ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser); + while (parser.nextToken() != XContentParser.Token.END_OBJECT) { + String fieldName = parser.currentName(); + parser.nextToken(); + + if (fieldName.equals(SPARSE_ENCODING_FORMAT_FIELD)) { + String contentType = parser.text(); + sparseEncodingType = SparseEncodingFormat.valueOf(contentType.toUpperCase(Locale.ROOT)); + } else { + parser.skipChildren(); + } + } + return new SparseEncodingParameters(sparseEncodingType); + } +} diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/TextEmbeddingModel.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/TextEmbeddingModel.java index 
63c11ca79d..3d29d3fbef 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/TextEmbeddingModel.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/TextEmbeddingModel.java @@ -12,6 +12,7 @@ import org.opensearch.ml.common.input.parameter.MLAlgoParams; import org.opensearch.ml.common.input.parameter.textembedding.AsymmetricTextEmbeddingParameters; import org.opensearch.ml.common.input.parameter.textembedding.AsymmetricTextEmbeddingParameters.EmbeddingContentType; +import org.opensearch.ml.common.input.parameter.textembedding.SparseEncodingParameters; import org.opensearch.ml.common.model.MLModelConfig; import org.opensearch.ml.common.model.TextEmbeddingModelConfig; import org.opensearch.ml.common.output.model.ModelResultFilter; @@ -40,6 +41,10 @@ public ModelTensorOutput predict(String modelId, MLInput mlInput) throws Transla for (String doc : textDocsInput.getDocs()) { Input input = new Input(); input.add(doc); + if (mlParams instanceof SparseEncodingParameters) { + input.add("sparse_encoding_format", ((SparseEncodingParameters) mlParams).getSparseEncodingType().name()); + } + output = getPredictor().predict(input); tensorOutputs.add(parseModelTensorOutput(output, resultFilter)); } diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/sparse_encoding/SparseEncodingTranslator.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/sparse_encoding/SparseEncodingTranslator.java index baebbe1972..f67e124695 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/sparse_encoding/SparseEncodingTranslator.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/sparse_encoding/SparseEncodingTranslator.java @@ -6,34 +6,48 @@ package org.opensearch.ml.engine.algorithms.sparse_encoding; import static org.opensearch.ml.common.CommonValue.ML_MAP_RESPONSE_KEY; +import static 
org.opensearch.ml.common.input.parameter.textembedding.SparseEncodingParameters.SPARSE_ENCODING_FORMAT_FIELD; import java.util.ArrayList; import java.util.Collections; import java.util.HashMap; -import java.util.Iterator; import java.util.List; import java.util.Map; +import org.opensearch.ml.common.input.parameter.textembedding.SparseEncodingParameters; import org.opensearch.ml.common.output.model.ModelTensor; import org.opensearch.ml.common.output.model.ModelTensors; import org.opensearch.ml.engine.algorithms.SentenceTransformerTranslator; +import ai.djl.modality.Input; import ai.djl.modality.Output; import ai.djl.ndarray.NDArray; import ai.djl.ndarray.NDList; import ai.djl.translate.TranslatorContext; public class SparseEncodingTranslator extends SentenceTransformerTranslator { + + @Override + public NDList processInput(TranslatorContext ctx, Input input) { + String sparse_encoding_format = input.getAsString(SPARSE_ENCODING_FORMAT_FIELD); + if (sparse_encoding_format != null) { + ctx.setAttachment(SPARSE_ENCODING_FORMAT_FIELD, sparse_encoding_format); + } + return super.processInput(ctx, input); + } + @Override public Output processOutput(TranslatorContext ctx, NDList list) { Output output = new Output(200, "OK"); + Object sparseEncodingFormatObject = ctx.getAttachment(SPARSE_ENCODING_FORMAT_FIELD); + String sparseEncodingFormatString = sparseEncodingFormatObject != null + ? 
sparseEncodingFormatObject.toString() + : SparseEncodingParameters.SparseEncodingFormat.WORD.name(); List outputs = new ArrayList<>(); - Iterator iterator = list.iterator(); - while (iterator.hasNext()) { - NDArray ndArray = iterator.next(); + for (NDArray ndArray : list) { String name = ndArray.getName(); - Map tokenWeightsMap = convertOutput(ndArray); + Map tokenWeightsMap = convertOutput(ndArray, sparseEncodingFormatString); Map wrappedMap = Map.of(ML_MAP_RESPONSE_KEY, Collections.singletonList(tokenWeightsMap)); ModelTensor tensor = ModelTensor.builder().name(name).dataAsMap(wrappedMap).build(); outputs.add(tensor); @@ -44,12 +58,14 @@ public Output processOutput(TranslatorContext ctx, NDList list) { return output; } - private Map convertOutput(NDArray array) { + private Map convertOutput(NDArray array, String sparseEncodingFormat) { Map map = new HashMap<>(); NDArray nonZeroIndices = array.nonzero().squeeze(); for (long index : nonZeroIndices.toLongArray()) { - String s = this.tokenizer.decode(new long[] { index }, true); + String s = sparseEncodingFormat.equals(SparseEncodingParameters.SparseEncodingFormat.INT.name()) + ? 
Long.toString(index) + : this.tokenizer.decode(new long[] { index }, true); if (!s.isEmpty()) { map.put(s, array.getFloat(index)); } diff --git a/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java b/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java index e1d7e78d2b..fb4aeb51a7 100644 --- a/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java +++ b/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java @@ -150,6 +150,7 @@ import org.opensearch.ml.common.input.parameter.regression.LogisticRegressionParams; import org.opensearch.ml.common.input.parameter.sample.SampleAlgoParams; import org.opensearch.ml.common.input.parameter.textembedding.AsymmetricTextEmbeddingParameters; +import org.opensearch.ml.common.input.parameter.textembedding.SparseEncodingParameters; import org.opensearch.ml.common.model.TextEmbeddingModelConfig; import org.opensearch.ml.common.settings.MLCommonsSettings; import org.opensearch.ml.common.settings.MLFeatureEnabledSetting; @@ -1094,7 +1095,8 @@ public List getNamedXContent() { RCFSummarizeParams.XCONTENT_REGISTRY, LogisticRegressionParams.XCONTENT_REGISTRY, TextEmbeddingModelConfig.XCONTENT_REGISTRY, - AsymmetricTextEmbeddingParameters.XCONTENT_REGISTRY + AsymmetricTextEmbeddingParameters.XCONTENT_REGISTRY, + SparseEncodingParameters.XCONTENT_REGISTRY ); } From 1262a65df57d5be68b48c10a1a5c54aea3838e3d Mon Sep 17 00:00:00 2001 From: zhichao-aws Date: Wed, 2 Jul 2025 13:59:44 +0800 Subject: [PATCH 02/10] params Signed-off-by: zhichao-aws --- .../AbstractSparseEncodingParameters.java | 98 +++++++++++++++++++ .../SparseEncodingParameters.java | 72 ++++---------- .../SparseTokenizeParameters.java | 57 +++++++++++ .../engine/algorithms/TextEmbeddingModel.java | 10 +- .../SparseEncodingTranslator.java | 22 ++--- .../ml/plugin/MachineLearningPlugin.java | 4 +- 6 files changed, 193 insertions(+), 70 deletions(-) create mode 100644 
common/src/main/java/org/opensearch/ml/common/input/parameter/textembedding/AbstractSparseEncodingParameters.java create mode 100644 common/src/main/java/org/opensearch/ml/common/input/parameter/textembedding/SparseTokenizeParameters.java diff --git a/common/src/main/java/org/opensearch/ml/common/input/parameter/textembedding/AbstractSparseEncodingParameters.java b/common/src/main/java/org/opensearch/ml/common/input/parameter/textembedding/AbstractSparseEncodingParameters.java new file mode 100644 index 0000000000..6f3102a409 --- /dev/null +++ b/common/src/main/java/org/opensearch/ml/common/input/parameter/textembedding/AbstractSparseEncodingParameters.java @@ -0,0 +1,98 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.input.parameter.textembedding; + +import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; + +import java.io.IOException; +import java.util.Locale; + +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.ml.common.input.parameter.MLAlgoParams; + +import lombok.Getter; + +/** + * Abstract base class for sparse encoding related parameters. + * Contains common logic shared between SPARSE_ENCODING and SPARSE_TOKENIZE algorithms. + */ +@Getter +public abstract class AbstractSparseEncodingParameters implements MLAlgoParams { + + public static final String EMBEDDING_FORMAT_FIELD = "embedding_format"; + + public enum EmbeddingFormat { + LEXICAL, + VECTOR + } + + // The type of the content to be encoded + protected final EmbeddingFormat embeddingFormat; + + protected AbstractSparseEncodingParameters(EmbeddingFormat embeddingFormat) { + // Set default to LEXICAL if null + this.embeddingFormat = embeddingFormat != null ? 
embeddingFormat : EmbeddingFormat.LEXICAL; + } + + /** + * Constructor for deserialization from StreamInput + */ + protected AbstractSparseEncodingParameters(StreamInput in) throws IOException { + String formatName = in.readOptionalString(); + this.embeddingFormat = formatName != null ? EmbeddingFormat.valueOf(formatName) : EmbeddingFormat.LEXICAL; + } + + @Override + public int getVersion() { + return 1; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeOptionalString(embeddingFormat.name()); + } + + @Override + public XContentBuilder toXContent(XContentBuilder xContentBuilder, Params params) throws IOException { + xContentBuilder.startObject(); + xContentBuilder.field(EMBEDDING_FORMAT_FIELD, embeddingFormat.name()); + xContentBuilder.endObject(); + return xContentBuilder; + } + + /** + * Common parsing method that can be used by subclasses. + * + * @param parser XContentParser to parse from + * @return parsed EmbeddingFormat, defaults to LEXICAL if not specified + * @throws IOException if parsing fails + */ + protected static EmbeddingFormat parseCommon(XContentParser parser) throws IOException { + EmbeddingFormat sparseEncodingType = null; + + ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser); + while (parser.nextToken() != XContentParser.Token.END_OBJECT) { + String fieldName = parser.currentName(); + parser.nextToken(); + + if (fieldName.equals(EMBEDDING_FORMAT_FIELD)) { + String contentType = parser.text(); + sparseEncodingType = EmbeddingFormat.valueOf(contentType.toUpperCase(Locale.ROOT)); + } else { + parser.skipChildren(); + } + } + // Return LEXICAL as default if not specified + return sparseEncodingType != null ? 
sparseEncodingType : EmbeddingFormat.LEXICAL; + } + + public EmbeddingFormat getEmbeddingFormat() { + return embeddingFormat; + } +} diff --git a/common/src/main/java/org/opensearch/ml/common/input/parameter/textembedding/SparseEncodingParameters.java b/common/src/main/java/org/opensearch/ml/common/input/parameter/textembedding/SparseEncodingParameters.java index f5bf85d331..8f4db602b3 100644 --- a/common/src/main/java/org/opensearch/ml/common/input/parameter/textembedding/SparseEncodingParameters.java +++ b/common/src/main/java/org/opensearch/ml/common/input/parameter/textembedding/SparseEncodingParameters.java @@ -5,15 +5,11 @@ package org.opensearch.ml.common.input.parameter.textembedding; -import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; - import java.io.IOException; -import java.util.Locale; import org.opensearch.core.ParseField; -import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.xcontent.NamedXContentRegistry; -import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.core.xcontent.XContentParser; import org.opensearch.ml.common.FunctionName; import org.opensearch.ml.common.annotation.MLAlgoParameter; @@ -22,25 +18,9 @@ import lombok.Builder; @MLAlgoParameter(algorithms = { FunctionName.SPARSE_ENCODING }) -public class SparseEncodingParameters implements MLAlgoParams { +public class SparseEncodingParameters extends AbstractSparseEncodingParameters { public static final String PARSE_FIELD_NAME = FunctionName.SPARSE_ENCODING.name(); - public static final String SPARSE_ENCODING_FORMAT_FIELD = "sparse_encoding_format"; - - @Override - public int getVersion() { - return 1; - } - - @Override - public String getWriteableName() { - return PARSE_FIELD_NAME; - } - - @Override - public void writeTo(StreamOutput out) throws IOException { - out.writeOptionalString(sparseEncodingType.name()); - } public static final 
NamedXContentRegistry.Entry XCONTENT_REGISTRY = new NamedXContentRegistry.Entry( MLAlgoParams.class, @@ -48,48 +28,30 @@ public void writeTo(StreamOutput out) throws IOException { SparseEncodingParameters::parse ); - @Override - public XContentBuilder toXContent(XContentBuilder xContentBuilder, Params params) throws IOException { - xContentBuilder.startObject(); - if (sparseEncodingType != null) { - xContentBuilder.field(SPARSE_ENCODING_FORMAT_FIELD, sparseEncodingType.name()); - } - xContentBuilder.endObject(); - return xContentBuilder; + // Default constructor with LEXICAL format + public SparseEncodingParameters() { + super(EmbeddingFormat.LEXICAL); } - public enum SparseEncodingFormat { - WORD, - INT + @Builder(toBuilder = true) + public SparseEncodingParameters(EmbeddingFormat sparseEncodingType) { + super(sparseEncodingType); } - // The type of the content to be embedded - private final SparseEncodingFormat sparseEncodingType; - - @Builder(toBuilder = true) - public SparseEncodingParameters(SparseEncodingFormat sparseEncodingType) { - this.sparseEncodingType = sparseEncodingType; + /** + * Constructor for deserialization from StreamInput + */ + public SparseEncodingParameters(StreamInput in) throws IOException { + super(in); } - public SparseEncodingFormat getSparseEncodingType() { - return sparseEncodingType; + @Override + public String getWriteableName() { + return PARSE_FIELD_NAME; } public static MLAlgoParams parse(XContentParser parser) throws IOException { - SparseEncodingFormat sparseEncodingType = null; - - ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser); - while (parser.nextToken() != XContentParser.Token.END_OBJECT) { - String fieldName = parser.currentName(); - parser.nextToken(); - - if (fieldName.equals(SPARSE_ENCODING_FORMAT_FIELD)) { - String contentType = parser.text(); - sparseEncodingType = SparseEncodingFormat.valueOf(contentType.toUpperCase(Locale.ROOT)); - } else { - parser.skipChildren(); - } - } + 
EmbeddingFormat sparseEncodingType = parseCommon(parser); return new SparseEncodingParameters(sparseEncodingType); } } diff --git a/common/src/main/java/org/opensearch/ml/common/input/parameter/textembedding/SparseTokenizeParameters.java b/common/src/main/java/org/opensearch/ml/common/input/parameter/textembedding/SparseTokenizeParameters.java new file mode 100644 index 0000000000..7a029930b8 --- /dev/null +++ b/common/src/main/java/org/opensearch/ml/common/input/parameter/textembedding/SparseTokenizeParameters.java @@ -0,0 +1,57 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.input.parameter.textembedding; + +import java.io.IOException; + +import org.opensearch.core.ParseField; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.ml.common.FunctionName; +import org.opensearch.ml.common.annotation.MLAlgoParameter; +import org.opensearch.ml.common.input.parameter.MLAlgoParams; + +import lombok.Builder; + +@MLAlgoParameter(algorithms = { FunctionName.SPARSE_TOKENIZE }) +public class SparseTokenizeParameters extends AbstractSparseEncodingParameters { + + public static final String PARSE_FIELD_NAME = FunctionName.SPARSE_TOKENIZE.name(); + + public static final NamedXContentRegistry.Entry XCONTENT_REGISTRY = new NamedXContentRegistry.Entry( + MLAlgoParams.class, + new ParseField(PARSE_FIELD_NAME), + SparseTokenizeParameters::parse + ); + + // Default constructor with LEXICAL format + public SparseTokenizeParameters() { + super(EmbeddingFormat.LEXICAL); + } + + @Builder(toBuilder = true) + public SparseTokenizeParameters(EmbeddingFormat sparseEncodingType) { + super(sparseEncodingType); + } + + /** + * Constructor for deserialization from StreamInput + */ + public SparseTokenizeParameters(StreamInput in) throws IOException { + super(in); + } + + @Override 
+ public String getWriteableName() { + return PARSE_FIELD_NAME; + } + + public static MLAlgoParams parse(XContentParser parser) throws IOException { + EmbeddingFormat sparseEncodingType = parseCommon(parser); + return new SparseTokenizeParameters(sparseEncodingType); + } +} diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/TextEmbeddingModel.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/TextEmbeddingModel.java index 3d29d3fbef..eb2adb3e03 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/TextEmbeddingModel.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/TextEmbeddingModel.java @@ -10,9 +10,9 @@ import org.opensearch.ml.common.dataset.TextDocsInputDataSet; import org.opensearch.ml.common.input.MLInput; import org.opensearch.ml.common.input.parameter.MLAlgoParams; +import org.opensearch.ml.common.input.parameter.textembedding.AbstractSparseEncodingParameters; import org.opensearch.ml.common.input.parameter.textembedding.AsymmetricTextEmbeddingParameters; import org.opensearch.ml.common.input.parameter.textembedding.AsymmetricTextEmbeddingParameters.EmbeddingContentType; -import org.opensearch.ml.common.input.parameter.textembedding.SparseEncodingParameters; import org.opensearch.ml.common.model.MLModelConfig; import org.opensearch.ml.common.model.TextEmbeddingModelConfig; import org.opensearch.ml.common.output.model.ModelResultFilter; @@ -41,8 +41,12 @@ public ModelTensorOutput predict(String modelId, MLInput mlInput) throws Transla for (String doc : textDocsInput.getDocs()) { Input input = new Input(); input.add(doc); - if (mlParams instanceof SparseEncodingParameters) { - input.add("sparse_encoding_format", ((SparseEncodingParameters) mlParams).getSparseEncodingType().name()); + if (mlParams instanceof AbstractSparseEncodingParameters) { + input + .add( + AbstractSparseEncodingParameters.EMBEDDING_FORMAT_FIELD, + ((AbstractSparseEncodingParameters) 
mlParams).getEmbeddingFormat().name() + ); } output = getPredictor().predict(input); diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/sparse_encoding/SparseEncodingTranslator.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/sparse_encoding/SparseEncodingTranslator.java index f67e124695..30e002749c 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/sparse_encoding/SparseEncodingTranslator.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/sparse_encoding/SparseEncodingTranslator.java @@ -6,7 +6,7 @@ package org.opensearch.ml.engine.algorithms.sparse_encoding; import static org.opensearch.ml.common.CommonValue.ML_MAP_RESPONSE_KEY; -import static org.opensearch.ml.common.input.parameter.textembedding.SparseEncodingParameters.SPARSE_ENCODING_FORMAT_FIELD; +import static org.opensearch.ml.common.input.parameter.textembedding.SparseEncodingParameters.EMBEDDING_FORMAT_FIELD; import java.util.ArrayList; import java.util.Collections; @@ -14,7 +14,7 @@ import java.util.List; import java.util.Map; -import org.opensearch.ml.common.input.parameter.textembedding.SparseEncodingParameters; +import org.opensearch.ml.common.input.parameter.textembedding.AbstractSparseEncodingParameters; import org.opensearch.ml.common.output.model.ModelTensor; import org.opensearch.ml.common.output.model.ModelTensors; import org.opensearch.ml.engine.algorithms.SentenceTransformerTranslator; @@ -29,9 +29,9 @@ public class SparseEncodingTranslator extends SentenceTransformerTranslator { @Override public NDList processInput(TranslatorContext ctx, Input input) { - String sparse_encoding_format = input.getAsString(SPARSE_ENCODING_FORMAT_FIELD); + String sparse_encoding_format = input.getAsString(EMBEDDING_FORMAT_FIELD); if (sparse_encoding_format != null) { - ctx.setAttachment(SPARSE_ENCODING_FORMAT_FIELD, sparse_encoding_format); + ctx.setAttachment(EMBEDDING_FORMAT_FIELD, sparse_encoding_format); } 
return super.processInput(ctx, input); } @@ -39,15 +39,15 @@ public NDList processInput(TranslatorContext ctx, Input input) { @Override public Output processOutput(TranslatorContext ctx, NDList list) { Output output = new Output(200, "OK"); - Object sparseEncodingFormatObject = ctx.getAttachment(SPARSE_ENCODING_FORMAT_FIELD); - String sparseEncodingFormatString = sparseEncodingFormatObject != null - ? sparseEncodingFormatObject.toString() - : SparseEncodingParameters.SparseEncodingFormat.WORD.name(); + Object embeddingFormatObject = ctx.getAttachment(EMBEDDING_FORMAT_FIELD); + String embeddingFormatString = embeddingFormatObject != null + ? embeddingFormatObject.toString() + : AbstractSparseEncodingParameters.EmbeddingFormat.LEXICAL.name(); List outputs = new ArrayList<>(); for (NDArray ndArray : list) { String name = ndArray.getName(); - Map tokenWeightsMap = convertOutput(ndArray, sparseEncodingFormatString); + Map tokenWeightsMap = convertOutput(ndArray, embeddingFormatString); Map wrappedMap = Map.of(ML_MAP_RESPONSE_KEY, Collections.singletonList(tokenWeightsMap)); ModelTensor tensor = ModelTensor.builder().name(name).dataAsMap(wrappedMap).build(); outputs.add(tensor); @@ -58,12 +58,12 @@ public Output processOutput(TranslatorContext ctx, NDList list) { return output; } - private Map convertOutput(NDArray array, String sparseEncodingFormat) { + private Map convertOutput(NDArray array, String embeddingFormat) { Map map = new HashMap<>(); NDArray nonZeroIndices = array.nonzero().squeeze(); for (long index : nonZeroIndices.toLongArray()) { - String s = sparseEncodingFormat.equals(SparseEncodingParameters.SparseEncodingFormat.INT.name()) + String s = embeddingFormat.equals(AbstractSparseEncodingParameters.EmbeddingFormat.VECTOR.name()) ? 
Long.toString(index) : this.tokenizer.decode(new long[] { index }, true); if (!s.isEmpty()) { diff --git a/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java b/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java index fb4aeb51a7..45ab4cec19 100644 --- a/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java +++ b/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java @@ -151,6 +151,7 @@ import org.opensearch.ml.common.input.parameter.sample.SampleAlgoParams; import org.opensearch.ml.common.input.parameter.textembedding.AsymmetricTextEmbeddingParameters; import org.opensearch.ml.common.input.parameter.textembedding.SparseEncodingParameters; +import org.opensearch.ml.common.input.parameter.textembedding.SparseTokenizeParameters; import org.opensearch.ml.common.model.TextEmbeddingModelConfig; import org.opensearch.ml.common.settings.MLCommonsSettings; import org.opensearch.ml.common.settings.MLFeatureEnabledSetting; @@ -1096,7 +1097,8 @@ public List getNamedXContent() { LogisticRegressionParams.XCONTENT_REGISTRY, TextEmbeddingModelConfig.XCONTENT_REGISTRY, AsymmetricTextEmbeddingParameters.XCONTENT_REGISTRY, - SparseEncodingParameters.XCONTENT_REGISTRY + SparseEncodingParameters.XCONTENT_REGISTRY, + SparseTokenizeParameters.XCONTENT_REGISTRY ); } From 2ed473424bac7500ba49eb6aadd650402a769c63 Mon Sep 17 00:00:00 2001 From: zhichao-aws Date: Wed, 2 Jul 2025 14:03:49 +0800 Subject: [PATCH 03/10] rename Signed-off-by: zhichao-aws --- .../textembedding/AbstractSparseEncodingParameters.java | 9 +++------ .../textembedding/SparseEncodingParameters.java | 8 ++++---- .../textembedding/SparseTokenizeParameters.java | 8 ++++---- .../sparse_encoding/SparseEncodingTranslator.java | 6 +++--- 4 files changed, 14 insertions(+), 17 deletions(-) diff --git a/common/src/main/java/org/opensearch/ml/common/input/parameter/textembedding/AbstractSparseEncodingParameters.java 
b/common/src/main/java/org/opensearch/ml/common/input/parameter/textembedding/AbstractSparseEncodingParameters.java index 6f3102a409..ff90f00bdd 100644 --- a/common/src/main/java/org/opensearch/ml/common/input/parameter/textembedding/AbstractSparseEncodingParameters.java +++ b/common/src/main/java/org/opensearch/ml/common/input/parameter/textembedding/AbstractSparseEncodingParameters.java @@ -74,7 +74,7 @@ public XContentBuilder toXContent(XContentBuilder xContentBuilder, Params params * @throws IOException if parsing fails */ protected static EmbeddingFormat parseCommon(XContentParser parser) throws IOException { - EmbeddingFormat sparseEncodingType = null; + EmbeddingFormat embeddingFormat = null; ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser); while (parser.nextToken() != XContentParser.Token.END_OBJECT) { @@ -83,16 +83,13 @@ protected static EmbeddingFormat parseCommon(XContentParser parser) throws IOExc if (fieldName.equals(EMBEDDING_FORMAT_FIELD)) { String contentType = parser.text(); - sparseEncodingType = EmbeddingFormat.valueOf(contentType.toUpperCase(Locale.ROOT)); + embeddingFormat = EmbeddingFormat.valueOf(contentType.toUpperCase(Locale.ROOT)); } else { parser.skipChildren(); } } // Return LEXICAL as default if not specified - return sparseEncodingType != null ? sparseEncodingType : EmbeddingFormat.LEXICAL; + return embeddingFormat != null ? 
embeddingFormat : EmbeddingFormat.LEXICAL; } - public EmbeddingFormat getEmbeddingFormat() { - return embeddingFormat; - } } diff --git a/common/src/main/java/org/opensearch/ml/common/input/parameter/textembedding/SparseEncodingParameters.java b/common/src/main/java/org/opensearch/ml/common/input/parameter/textembedding/SparseEncodingParameters.java index 8f4db602b3..7bfb8cae70 100644 --- a/common/src/main/java/org/opensearch/ml/common/input/parameter/textembedding/SparseEncodingParameters.java +++ b/common/src/main/java/org/opensearch/ml/common/input/parameter/textembedding/SparseEncodingParameters.java @@ -34,8 +34,8 @@ public SparseEncodingParameters() { } @Builder(toBuilder = true) - public SparseEncodingParameters(EmbeddingFormat sparseEncodingType) { - super(sparseEncodingType); + public SparseEncodingParameters(EmbeddingFormat embeddingFormat) { + super(embeddingFormat); } /** @@ -51,7 +51,7 @@ public String getWriteableName() { } public static MLAlgoParams parse(XContentParser parser) throws IOException { - EmbeddingFormat sparseEncodingType = parseCommon(parser); - return new SparseEncodingParameters(sparseEncodingType); + EmbeddingFormat embeddingFormat = parseCommon(parser); + return new SparseEncodingParameters(embeddingFormat); } } diff --git a/common/src/main/java/org/opensearch/ml/common/input/parameter/textembedding/SparseTokenizeParameters.java b/common/src/main/java/org/opensearch/ml/common/input/parameter/textembedding/SparseTokenizeParameters.java index 7a029930b8..62e35891c8 100644 --- a/common/src/main/java/org/opensearch/ml/common/input/parameter/textembedding/SparseTokenizeParameters.java +++ b/common/src/main/java/org/opensearch/ml/common/input/parameter/textembedding/SparseTokenizeParameters.java @@ -34,8 +34,8 @@ public SparseTokenizeParameters() { } @Builder(toBuilder = true) - public SparseTokenizeParameters(EmbeddingFormat sparseEncodingType) { - super(sparseEncodingType); + public SparseTokenizeParameters(EmbeddingFormat 
embeddingFormat) { + super(embeddingFormat); } /** @@ -51,7 +51,7 @@ public String getWriteableName() { } public static MLAlgoParams parse(XContentParser parser) throws IOException { - EmbeddingFormat sparseEncodingType = parseCommon(parser); - return new SparseTokenizeParameters(sparseEncodingType); + EmbeddingFormat embeddingFormat = parseCommon(parser); + return new SparseTokenizeParameters(embeddingFormat); } } diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/sparse_encoding/SparseEncodingTranslator.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/sparse_encoding/SparseEncodingTranslator.java index 30e002749c..ebc4f08727 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/sparse_encoding/SparseEncodingTranslator.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/sparse_encoding/SparseEncodingTranslator.java @@ -29,9 +29,9 @@ public class SparseEncodingTranslator extends SentenceTransformerTranslator { @Override public NDList processInput(TranslatorContext ctx, Input input) { - String sparse_encoding_format = input.getAsString(EMBEDDING_FORMAT_FIELD); - if (sparse_encoding_format != null) { - ctx.setAttachment(EMBEDDING_FORMAT_FIELD, sparse_encoding_format); + String embedding_format = input.getAsString(EMBEDDING_FORMAT_FIELD); + if (embedding_format != null) { + ctx.setAttachment(EMBEDDING_FORMAT_FIELD, embedding_format); } return super.processInput(ctx, input); } From 47171bc060429b59db87f73e8fade7f50d4766ef Mon Sep 17 00:00:00 2001 From: zhichao-aws Date: Wed, 2 Jul 2025 16:36:50 +0800 Subject: [PATCH 04/10] ml algorithms Signed-off-by: zhichao-aws --- .../ml/common/output/model/ModelTensor.java | 60 +++++++++++++++- .../engine/algorithms/TextEmbeddingModel.java | 8 +-- .../SparseEncodingTranslator.java | 46 +++++++++---- .../tokenize/SparseTokenizerModel.java | 69 ++++++++++++++----- 4 files changed, 145 insertions(+), 38 deletions(-) diff --git 
a/common/src/main/java/org/opensearch/ml/common/output/model/ModelTensor.java b/common/src/main/java/org/opensearch/ml/common/output/model/ModelTensor.java index 6d075ab205..7269d30e70 100644 --- a/common/src/main/java/org/opensearch/ml/common/output/model/ModelTensor.java +++ b/common/src/main/java/org/opensearch/ml/common/output/model/ModelTensor.java @@ -15,6 +15,7 @@ import java.security.PrivilegedActionException; import java.security.PrivilegedExceptionAction; import java.util.ArrayList; +import java.util.HashMap; import java.util.List; import java.util.Map; @@ -26,6 +27,12 @@ import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.core.xcontent.XContentParser; +import com.google.gson.JsonArray; +import com.google.gson.JsonElement; +import com.google.gson.JsonObject; +import com.google.gson.JsonParser; +import com.google.gson.JsonPrimitive; + import lombok.Builder; import lombok.Data; @@ -232,7 +239,7 @@ public ModelTensor(StreamInput in) throws IOException { this.result = in.readOptionalString(); if (in.readBoolean()) { String mapStr = in.readString(); - this.dataAsMap = gson.fromJson(mapStr, Map.class); + this.dataAsMap = parseMapPreservingNumberTypes(mapStr); } } @@ -289,4 +296,55 @@ public void writeTo(StreamOutput out) throws IOException { out.writeBoolean(false); } } + + /** + * Parse JSON string to Map while preserving number types (int vs double) + */ + private static Map parseMapPreservingNumberTypes(String jsonStr) { + JsonElement element = JsonParser.parseString(jsonStr); + return convertJsonElementToObject(element); + } + + @SuppressWarnings("unchecked") + private static T convertJsonElementToObject(JsonElement element) { + if (element.isJsonNull()) { + return null; + } else if (element.isJsonPrimitive()) { + JsonPrimitive primitive = element.getAsJsonPrimitive(); + if (primitive.isNumber()) { + // Preserve integer types + Number number = primitive.getAsNumber(); + if (number.toString().contains(".")) { + return (T) 
Double.valueOf(number.doubleValue()); + } else { + // Check if it fits in an int, otherwise use long + long longValue = number.longValue(); + if (longValue >= Integer.MIN_VALUE && longValue <= Integer.MAX_VALUE) { + return (T) Integer.valueOf((int) longValue); + } else { + return (T) Long.valueOf(longValue); + } + } + } else if (primitive.isBoolean()) { + return (T) Boolean.valueOf(primitive.getAsBoolean()); + } else { + return (T) primitive.getAsString(); + } + } else if (element.isJsonArray()) { + JsonArray array = element.getAsJsonArray(); + List list = new ArrayList<>(); + for (JsonElement arrayElement : array) { + list.add(convertJsonElementToObject(arrayElement)); + } + return (T) list; + } else if (element.isJsonObject()) { + JsonObject object = element.getAsJsonObject(); + Map map = new HashMap<>(); + for (Map.Entry entry : object.entrySet()) { + map.put(entry.getKey(), convertJsonElementToObject(entry.getValue())); + } + return (T) map; + } + return null; + } } diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/TextEmbeddingModel.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/TextEmbeddingModel.java index eb2adb3e03..b14cc1261f 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/TextEmbeddingModel.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/TextEmbeddingModel.java @@ -10,9 +10,9 @@ import org.opensearch.ml.common.dataset.TextDocsInputDataSet; import org.opensearch.ml.common.input.MLInput; import org.opensearch.ml.common.input.parameter.MLAlgoParams; -import org.opensearch.ml.common.input.parameter.textembedding.AbstractSparseEncodingParameters; import org.opensearch.ml.common.input.parameter.textembedding.AsymmetricTextEmbeddingParameters; import org.opensearch.ml.common.input.parameter.textembedding.AsymmetricTextEmbeddingParameters.EmbeddingContentType; +import org.opensearch.ml.common.input.parameter.textembedding.SparseEncodingParameters; import 
org.opensearch.ml.common.model.MLModelConfig; import org.opensearch.ml.common.model.TextEmbeddingModelConfig; import org.opensearch.ml.common.output.model.ModelResultFilter; @@ -41,11 +41,11 @@ public ModelTensorOutput predict(String modelId, MLInput mlInput) throws Transla for (String doc : textDocsInput.getDocs()) { Input input = new Input(); input.add(doc); - if (mlParams instanceof AbstractSparseEncodingParameters) { + if (mlParams instanceof SparseEncodingParameters) { input .add( - AbstractSparseEncodingParameters.EMBEDDING_FORMAT_FIELD, - ((AbstractSparseEncodingParameters) mlParams).getEmbeddingFormat().name() + SparseEncodingParameters.EMBEDDING_FORMAT_FIELD, + ((SparseEncodingParameters) mlParams).getEmbeddingFormat().name() ); } diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/sparse_encoding/SparseEncodingTranslator.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/sparse_encoding/SparseEncodingTranslator.java index ebc4f08727..aacb377b27 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/sparse_encoding/SparseEncodingTranslator.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/sparse_encoding/SparseEncodingTranslator.java @@ -29,9 +29,9 @@ public class SparseEncodingTranslator extends SentenceTransformerTranslator { @Override public NDList processInput(TranslatorContext ctx, Input input) { - String embedding_format = input.getAsString(EMBEDDING_FORMAT_FIELD); - if (embedding_format != null) { - ctx.setAttachment(EMBEDDING_FORMAT_FIELD, embedding_format); + String embeddingFormat = input.getAsString(EMBEDDING_FORMAT_FIELD); + if (embeddingFormat != null) { + ctx.setAttachment(EMBEDDING_FORMAT_FIELD, embeddingFormat); } return super.processInput(ctx, input); } @@ -47,8 +47,8 @@ public Output processOutput(TranslatorContext ctx, NDList list) { List outputs = new ArrayList<>(); for (NDArray ndArray : list) { String name = ndArray.getName(); - Map 
tokenWeightsMap = convertOutput(ndArray, embeddingFormatString); - Map wrappedMap = Map.of(ML_MAP_RESPONSE_KEY, Collections.singletonList(tokenWeightsMap)); + Object result = convertOutput(ndArray, embeddingFormatString); + Map wrappedMap = Map.of(ML_MAP_RESPONSE_KEY, Collections.singletonList(result)); ModelTensor tensor = ModelTensor.builder().name(name).dataAsMap(wrappedMap).build(); outputs.add(tensor); } @@ -58,18 +58,36 @@ public Output processOutput(TranslatorContext ctx, NDList list) { return output; } - private Map convertOutput(NDArray array, String embeddingFormat) { - Map map = new HashMap<>(); + private Object convertOutput(NDArray array, String embeddingFormat) { NDArray nonZeroIndices = array.nonzero().squeeze(); + long[] indices = nonZeroIndices.toLongArray(); - for (long index : nonZeroIndices.toLongArray()) { - String s = embeddingFormat.equals(AbstractSparseEncodingParameters.EmbeddingFormat.VECTOR.name()) - ? Long.toString(index) - : this.tokenizer.decode(new long[] { index }, true); - if (!s.isEmpty()) { - map.put(s, array.getFloat(index)); + if (embeddingFormat.equals(AbstractSparseEncodingParameters.EmbeddingFormat.VECTOR.name())) { + // Return vector format: {"indices": [...], "values": [...]} + // Sort indices for vector format + java.util.Arrays.sort(indices); + List indicesList = new ArrayList<>(); + List valuesList = new ArrayList<>(); + + for (long index : indices) { + indicesList.add(index); + valuesList.add(array.getFloat(index)); + } + + Map vectorFormat = new HashMap<>(); + vectorFormat.put("indices", indicesList); + vectorFormat.put("values", valuesList); + return vectorFormat; + } else { + // Return lexical format: {"token": weight, ...} + Map tokenWeights = new HashMap<>(); + for (long index : indices) { + String token = this.tokenizer.decode(new long[] { index }, true); + if (!token.isEmpty()) { + tokenWeights.put(token, array.getFloat(index)); + } } + return tokenWeights; } - return map; } } diff --git 
a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/tokenize/SparseTokenizerModel.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/tokenize/SparseTokenizerModel.java index 9b56076248..7bc28e5c5e 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/tokenize/SparseTokenizerModel.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/tokenize/SparseTokenizerModel.java @@ -22,8 +22,10 @@ import org.opensearch.ml.common.dataset.MLInputDataset; import org.opensearch.ml.common.dataset.TextDocsInputDataSet; import org.opensearch.ml.common.input.MLInput; +import org.opensearch.ml.common.input.parameter.MLAlgoParams; +import org.opensearch.ml.common.input.parameter.textembedding.AbstractSparseEncodingParameters; +import org.opensearch.ml.common.input.parameter.textembedding.SparseTokenizeParameters; import org.opensearch.ml.common.model.MLModelConfig; -import org.opensearch.ml.common.output.model.ModelResultFilter; import org.opensearch.ml.common.output.model.ModelTensor; import org.opensearch.ml.common.output.model.ModelTensorOutput; import org.opensearch.ml.common.output.model.ModelTensors; @@ -67,28 +69,57 @@ public ModelTensorOutput predict(String modelId, MLInput mlInput) throws Transla MLInputDataset inputDataSet = mlInput.getInputDataset(); List tensorOutputs = new ArrayList<>(); TextDocsInputDataSet textDocsInput = (TextDocsInputDataSet) inputDataSet; - ModelResultFilter resultFilter = textDocsInput.getResultFilter(); + + // Get the embedding format from parameters + MLAlgoParams parameters = mlInput.getParameters(); + AbstractSparseEncodingParameters.EmbeddingFormat embeddingFormat = AbstractSparseEncodingParameters.EmbeddingFormat.LEXICAL; // default + + if (parameters instanceof SparseTokenizeParameters) { + SparseTokenizeParameters sparseParams = (SparseTokenizeParameters) parameters; + embeddingFormat = sparseParams.getEmbeddingFormat(); + } + for (String doc : 
textDocsInput.getDocs()) { - Output output = new Output(200, "OK"); Encoding encodings = tokenizer.encode(doc); long[] indices = encodings.getIds(); - List outputs = new ArrayList<>(); - String[] tokens = Arrays - .stream(indices) - .distinct() - .mapToObj(value -> new long[] { value }) - .map(value -> this.tokenizer.decode(value, true)) - .filter(s -> !s.isEmpty()) - .toArray(String[]::new); - Map tokenWeights = Arrays - .stream(tokens) - .collect(Collectors.toMap(token -> token, token -> idf.getOrDefault(token, 1.0f))); - Map wrappedMap = Map.of(ML_MAP_RESPONSE_KEY, Collections.singletonList(tokenWeights)); + + Map wrappedMap; + if (embeddingFormat == AbstractSparseEncodingParameters.EmbeddingFormat.VECTOR) { + // Return vector format: {"indices": [...], "values": [...]} + // Get distinct token IDs and sort them for vector format + long[] uniqueIndices = Arrays.stream(indices).distinct().sorted().toArray(); + List indicesList = new ArrayList<>(); + List valuesList = new ArrayList<>(); + + for (long index : uniqueIndices) { + String token = this.tokenizer.decode(new long[] { index }, true); + if (!token.isEmpty()) { + indicesList.add(index); + valuesList.add(idf.getOrDefault(token, 1.0f)); + } + } + + Map vectorFormat = new HashMap<>(); + vectorFormat.put("indices", indicesList); + vectorFormat.put("values", valuesList); + wrappedMap = Map.of(ML_MAP_RESPONSE_KEY, Collections.singletonList(vectorFormat)); + } else { + // Return lexical format: {"token": weight, ...} + String[] tokens = Arrays + .stream(indices) + .distinct() + .mapToObj(value -> new long[] { value }) + .map(value -> this.tokenizer.decode(value, true)) + .filter(s -> !s.isEmpty()) + .toArray(String[]::new); + Map tokenWeights = Arrays + .stream(tokens) + .collect(Collectors.toMap(token -> token, token -> idf.getOrDefault(token, 1.0f))); + wrappedMap = Map.of(ML_MAP_RESPONSE_KEY, Collections.singletonList(tokenWeights)); + } + ModelTensor tensor = ModelTensor.builder().dataAsMap(wrappedMap).build(); 
- outputs.add(tensor); - ModelTensors modelTensorOutput = new ModelTensors(outputs); - output.add(modelTensorOutput.toBytes()); - tensorOutputs.add(parseModelTensorOutput(output, resultFilter)); + tensorOutputs.add(new ModelTensors(List.of(tensor))); } return new ModelTensorOutput(tensorOutputs); } From 3e73931f59ae93ddb7964a65e2fc7d8b13fa74bb Mon Sep 17 00:00:00 2001 From: zhichao-aws Date: Mon, 7 Jul 2025 15:56:14 +0800 Subject: [PATCH 05/10] update using AsymmetricTextEmbeddingParameters Signed-off-by: zhichao-aws --- .../ml/common/input/nlp/TextDocsMLInput.java | 5 +- .../AbstractSparseEncodingParameters.java | 95 ------------------- .../AsymmetricTextEmbeddingParameters.java | 54 ++++++++++- .../SparseEncodingParameters.java | 57 ----------- .../SparseTokenizeParameters.java | 57 ----------- .../engine/algorithms/TextEmbeddingModel.java | 13 +-- .../SparseEncodingTranslator.java | 29 +++--- .../TextEmbeddingSparseEncodingModel.java | 5 + .../tokenize/SparseTokenizerModel.java | 59 ++++-------- .../ml/plugin/MachineLearningPlugin.java | 6 +- 10 files changed, 99 insertions(+), 281 deletions(-) delete mode 100644 common/src/main/java/org/opensearch/ml/common/input/parameter/textembedding/AbstractSparseEncodingParameters.java delete mode 100644 common/src/main/java/org/opensearch/ml/common/input/parameter/textembedding/SparseEncodingParameters.java delete mode 100644 common/src/main/java/org/opensearch/ml/common/input/parameter/textembedding/SparseTokenizeParameters.java diff --git a/common/src/main/java/org/opensearch/ml/common/input/nlp/TextDocsMLInput.java b/common/src/main/java/org/opensearch/ml/common/input/nlp/TextDocsMLInput.java index a818137ad7..1a2f201dd5 100644 --- a/common/src/main/java/org/opensearch/ml/common/input/nlp/TextDocsMLInput.java +++ b/common/src/main/java/org/opensearch/ml/common/input/nlp/TextDocsMLInput.java @@ -27,7 +27,10 @@ * ML input class which supports a list fo text docs. * This class can be used for TEXT_EMBEDDING model. 
*/ -@org.opensearch.ml.common.annotation.MLInput(functionNames = { FunctionName.TEXT_EMBEDDING }) +@org.opensearch.ml.common.annotation.MLInput(functionNames = { + FunctionName.TEXT_EMBEDDING, + FunctionName.SPARSE_ENCODING, + FunctionName.SPARSE_TOKENIZE }) public class TextDocsMLInput extends MLInput { public static final String TEXT_DOCS_FIELD = "text_docs"; public static final String RESULT_FILTER_FIELD = "result_filter"; diff --git a/common/src/main/java/org/opensearch/ml/common/input/parameter/textembedding/AbstractSparseEncodingParameters.java b/common/src/main/java/org/opensearch/ml/common/input/parameter/textembedding/AbstractSparseEncodingParameters.java deleted file mode 100644 index ff90f00bdd..0000000000 --- a/common/src/main/java/org/opensearch/ml/common/input/parameter/textembedding/AbstractSparseEncodingParameters.java +++ /dev/null @@ -1,95 +0,0 @@ -/* - * Copyright OpenSearch Contributors - * SPDX-License-Identifier: Apache-2.0 - */ - -package org.opensearch.ml.common.input.parameter.textembedding; - -import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; - -import java.io.IOException; -import java.util.Locale; - -import org.opensearch.core.common.io.stream.StreamInput; -import org.opensearch.core.common.io.stream.StreamOutput; -import org.opensearch.core.xcontent.XContentBuilder; -import org.opensearch.core.xcontent.XContentParser; -import org.opensearch.ml.common.input.parameter.MLAlgoParams; - -import lombok.Getter; - -/** - * Abstract base class for sparse encoding related parameters. - * Contains common logic shared between SPARSE_ENCODING and SPARSE_TOKENIZE algorithms. 
- */ -@Getter -public abstract class AbstractSparseEncodingParameters implements MLAlgoParams { - - public static final String EMBEDDING_FORMAT_FIELD = "embedding_format"; - - public enum EmbeddingFormat { - LEXICAL, - VECTOR - } - - // The type of the content to be encoded - protected final EmbeddingFormat embeddingFormat; - - protected AbstractSparseEncodingParameters(EmbeddingFormat embeddingFormat) { - // Set default to LEXICAL if null - this.embeddingFormat = embeddingFormat != null ? embeddingFormat : EmbeddingFormat.LEXICAL; - } - - /** - * Constructor for deserialization from StreamInput - */ - protected AbstractSparseEncodingParameters(StreamInput in) throws IOException { - String formatName = in.readOptionalString(); - this.embeddingFormat = formatName != null ? EmbeddingFormat.valueOf(formatName) : EmbeddingFormat.LEXICAL; - } - - @Override - public int getVersion() { - return 1; - } - - @Override - public void writeTo(StreamOutput out) throws IOException { - out.writeOptionalString(embeddingFormat.name()); - } - - @Override - public XContentBuilder toXContent(XContentBuilder xContentBuilder, Params params) throws IOException { - xContentBuilder.startObject(); - xContentBuilder.field(EMBEDDING_FORMAT_FIELD, embeddingFormat.name()); - xContentBuilder.endObject(); - return xContentBuilder; - } - - /** - * Common parsing method that can be used by subclasses. 
- * - * @param parser XContentParser to parse from - * @return parsed EmbeddingFormat, defaults to LEXICAL if not specified - * @throws IOException if parsing fails - */ - protected static EmbeddingFormat parseCommon(XContentParser parser) throws IOException { - EmbeddingFormat embeddingFormat = null; - - ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser); - while (parser.nextToken() != XContentParser.Token.END_OBJECT) { - String fieldName = parser.currentName(); - parser.nextToken(); - - if (fieldName.equals(EMBEDDING_FORMAT_FIELD)) { - String contentType = parser.text(); - embeddingFormat = EmbeddingFormat.valueOf(contentType.toUpperCase(Locale.ROOT)); - } else { - parser.skipChildren(); - } - } - // Return LEXICAL as default if not specified - return embeddingFormat != null ? embeddingFormat : EmbeddingFormat.LEXICAL; - } - -} diff --git a/common/src/main/java/org/opensearch/ml/common/input/parameter/textembedding/AsymmetricTextEmbeddingParameters.java b/common/src/main/java/org/opensearch/ml/common/input/parameter/textembedding/AsymmetricTextEmbeddingParameters.java index f73b83e106..214273b35f 100644 --- a/common/src/main/java/org/opensearch/ml/common/input/parameter/textembedding/AsymmetricTextEmbeddingParameters.java +++ b/common/src/main/java/org/opensearch/ml/common/input/parameter/textembedding/AsymmetricTextEmbeddingParameters.java @@ -10,6 +10,7 @@ import java.io.IOException; import java.util.Locale; +import org.opensearch.Version; import org.opensearch.core.ParseField; import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.io.stream.StreamOutput; @@ -31,9 +32,11 @@ *

* Use this parameter only if the model is asymmetric and has been registered with the corresponding * `query_prefix` and `passage_prefix` configuration parameters. + *

+ * Also supports embedding format control for sparse encoding algorithms. */ @Data -@MLAlgoParameter(algorithms = { FunctionName.TEXT_EMBEDDING }) +@MLAlgoParameter(algorithms = { FunctionName.TEXT_EMBEDDING, FunctionName.SPARSE_ENCODING, FunctionName.SPARSE_TOKENIZE }) public class AsymmetricTextEmbeddingParameters implements MLAlgoParams { public enum EmbeddingContentType { @@ -41,24 +44,54 @@ public enum EmbeddingContentType { PASSAGE } + public enum SparseEmbeddingFormat { + LEXICAL, + TOKEN_ID + } + public static final String PARSE_FIELD_NAME = FunctionName.TEXT_EMBEDDING.name(); public static final NamedXContentRegistry.Entry XCONTENT_REGISTRY = new NamedXContentRegistry.Entry( MLAlgoParams.class, new ParseField(PARSE_FIELD_NAME), it -> parse(it) ); + public static final NamedXContentRegistry.Entry XCONTENT_REGISTRY_SPARSE_ENCODING = new NamedXContentRegistry.Entry( + MLAlgoParams.class, + new ParseField(FunctionName.SPARSE_ENCODING.name()), + it -> parse(it) + ); + public static final NamedXContentRegistry.Entry XCONTENT_REGISTRY_SPARSE_TOKENIZE = new NamedXContentRegistry.Entry( + MLAlgoParams.class, + new ParseField(FunctionName.SPARSE_TOKENIZE.name()), + it -> parse(it) + ); @Builder(toBuilder = true) + public AsymmetricTextEmbeddingParameters(EmbeddingContentType embeddingContentType, SparseEmbeddingFormat sparseEmbeddingFormat) { + this.embeddingContentType = embeddingContentType; + this.sparseEmbeddingFormat = sparseEmbeddingFormat != null ? 
sparseEmbeddingFormat : SparseEmbeddingFormat.LEXICAL; + } + + // Constructor for backward compatibility public AsymmetricTextEmbeddingParameters(EmbeddingContentType embeddingContentType) { this.embeddingContentType = embeddingContentType; + this.sparseEmbeddingFormat = SparseEmbeddingFormat.LEXICAL; } public AsymmetricTextEmbeddingParameters(StreamInput in) throws IOException { + Version streamInputVersion = in.getVersion(); this.embeddingContentType = EmbeddingContentType.valueOf(in.readOptionalString()); + if (streamInputVersion.onOrAfter(Version.V_3_2_0)) { + String formatName = in.readOptionalString(); + this.sparseEmbeddingFormat = formatName != null ? SparseEmbeddingFormat.valueOf(formatName) : SparseEmbeddingFormat.LEXICAL; + } else { + this.sparseEmbeddingFormat = SparseEmbeddingFormat.LEXICAL; + } } public static MLAlgoParams parse(XContentParser parser) throws IOException { EmbeddingContentType embeddingContentType = null; + SparseEmbeddingFormat sparseEmbeddingFormat = null; ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser); while (parser.nextToken() != XContentParser.Token.END_OBJECT) { @@ -70,19 +103,27 @@ public static MLAlgoParams parse(XContentParser parser) throws IOException { String contentType = parser.text(); embeddingContentType = EmbeddingContentType.valueOf(contentType.toUpperCase(Locale.ROOT)); break; + case SPARSE_EMBEDDING_FORMAT_FIELD: + String formatType = parser.text(); + sparseEmbeddingFormat = SparseEmbeddingFormat.valueOf(formatType.toUpperCase(Locale.ROOT)); + break; default: parser.skipChildren(); break; } } - return new AsymmetricTextEmbeddingParameters(embeddingContentType); + return new AsymmetricTextEmbeddingParameters(embeddingContentType, sparseEmbeddingFormat); } public static final String EMBEDDING_CONTENT_TYPE_FIELD = "content_type"; + public static final String SPARSE_EMBEDDING_FORMAT_FIELD = "sparse_embedding_format"; // The type of the content to be embedded private 
EmbeddingContentType embeddingContentType; + // The format of the embedding output + private SparseEmbeddingFormat sparseEmbeddingFormat; + @Override public int getVersion() { return 1; @@ -95,7 +136,11 @@ public String getWriteableName() { @Override public void writeTo(StreamOutput out) throws IOException { + Version streamOutputVersion = out.getVersion(); out.writeOptionalString(embeddingContentType.name()); + if (streamOutputVersion.onOrAfter(Version.V_3_2_0)) { + out.writeOptionalString(sparseEmbeddingFormat.name()); + } } @Override @@ -104,6 +149,7 @@ public XContentBuilder toXContent(XContentBuilder xContentBuilder, Params params if (embeddingContentType != null) { xContentBuilder.field(EMBEDDING_CONTENT_TYPE_FIELD, embeddingContentType.name()); } + xContentBuilder.field(SPARSE_EMBEDDING_FORMAT_FIELD, sparseEmbeddingFormat.name()); xContentBuilder.endObject(); return xContentBuilder; } @@ -111,4 +157,8 @@ public XContentBuilder toXContent(XContentBuilder xContentBuilder, Params params public EmbeddingContentType getEmbeddingContentType() { return embeddingContentType; } + + public SparseEmbeddingFormat getSparseEmbeddingFormat() { + return sparseEmbeddingFormat; + } } diff --git a/common/src/main/java/org/opensearch/ml/common/input/parameter/textembedding/SparseEncodingParameters.java b/common/src/main/java/org/opensearch/ml/common/input/parameter/textembedding/SparseEncodingParameters.java deleted file mode 100644 index 7bfb8cae70..0000000000 --- a/common/src/main/java/org/opensearch/ml/common/input/parameter/textembedding/SparseEncodingParameters.java +++ /dev/null @@ -1,57 +0,0 @@ -/* - * Copyright OpenSearch Contributors - * SPDX-License-Identifier: Apache-2.0 - */ - -package org.opensearch.ml.common.input.parameter.textembedding; - -import java.io.IOException; - -import org.opensearch.core.ParseField; -import org.opensearch.core.common.io.stream.StreamInput; -import org.opensearch.core.xcontent.NamedXContentRegistry; -import 
org.opensearch.core.xcontent.XContentParser; -import org.opensearch.ml.common.FunctionName; -import org.opensearch.ml.common.annotation.MLAlgoParameter; -import org.opensearch.ml.common.input.parameter.MLAlgoParams; - -import lombok.Builder; - -@MLAlgoParameter(algorithms = { FunctionName.SPARSE_ENCODING }) -public class SparseEncodingParameters extends AbstractSparseEncodingParameters { - - public static final String PARSE_FIELD_NAME = FunctionName.SPARSE_ENCODING.name(); - - public static final NamedXContentRegistry.Entry XCONTENT_REGISTRY = new NamedXContentRegistry.Entry( - MLAlgoParams.class, - new ParseField(PARSE_FIELD_NAME), - SparseEncodingParameters::parse - ); - - // Default constructor with LEXICAL format - public SparseEncodingParameters() { - super(EmbeddingFormat.LEXICAL); - } - - @Builder(toBuilder = true) - public SparseEncodingParameters(EmbeddingFormat embeddingFormat) { - super(embeddingFormat); - } - - /** - * Constructor for deserialization from StreamInput - */ - public SparseEncodingParameters(StreamInput in) throws IOException { - super(in); - } - - @Override - public String getWriteableName() { - return PARSE_FIELD_NAME; - } - - public static MLAlgoParams parse(XContentParser parser) throws IOException { - EmbeddingFormat embeddingFormat = parseCommon(parser); - return new SparseEncodingParameters(embeddingFormat); - } -} diff --git a/common/src/main/java/org/opensearch/ml/common/input/parameter/textembedding/SparseTokenizeParameters.java b/common/src/main/java/org/opensearch/ml/common/input/parameter/textembedding/SparseTokenizeParameters.java deleted file mode 100644 index 62e35891c8..0000000000 --- a/common/src/main/java/org/opensearch/ml/common/input/parameter/textembedding/SparseTokenizeParameters.java +++ /dev/null @@ -1,57 +0,0 @@ -/* - * Copyright OpenSearch Contributors - * SPDX-License-Identifier: Apache-2.0 - */ - -package org.opensearch.ml.common.input.parameter.textembedding; - -import java.io.IOException; - -import 
org.opensearch.core.ParseField; -import org.opensearch.core.common.io.stream.StreamInput; -import org.opensearch.core.xcontent.NamedXContentRegistry; -import org.opensearch.core.xcontent.XContentParser; -import org.opensearch.ml.common.FunctionName; -import org.opensearch.ml.common.annotation.MLAlgoParameter; -import org.opensearch.ml.common.input.parameter.MLAlgoParams; - -import lombok.Builder; - -@MLAlgoParameter(algorithms = { FunctionName.SPARSE_TOKENIZE }) -public class SparseTokenizeParameters extends AbstractSparseEncodingParameters { - - public static final String PARSE_FIELD_NAME = FunctionName.SPARSE_TOKENIZE.name(); - - public static final NamedXContentRegistry.Entry XCONTENT_REGISTRY = new NamedXContentRegistry.Entry( - MLAlgoParams.class, - new ParseField(PARSE_FIELD_NAME), - SparseTokenizeParameters::parse - ); - - // Default constructor with LEXICAL format - public SparseTokenizeParameters() { - super(EmbeddingFormat.LEXICAL); - } - - @Builder(toBuilder = true) - public SparseTokenizeParameters(EmbeddingFormat embeddingFormat) { - super(embeddingFormat); - } - - /** - * Constructor for deserialization from StreamInput - */ - public SparseTokenizeParameters(StreamInput in) throws IOException { - super(in); - } - - @Override - public String getWriteableName() { - return PARSE_FIELD_NAME; - } - - public static MLAlgoParams parse(XContentParser parser) throws IOException { - EmbeddingFormat embeddingFormat = parseCommon(parser); - return new SparseTokenizeParameters(embeddingFormat); - } -} diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/TextEmbeddingModel.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/TextEmbeddingModel.java index b14cc1261f..f6cedafa45 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/TextEmbeddingModel.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/TextEmbeddingModel.java @@ -12,7 +12,6 @@ import 
org.opensearch.ml.common.input.parameter.MLAlgoParams; import org.opensearch.ml.common.input.parameter.textembedding.AsymmetricTextEmbeddingParameters; import org.opensearch.ml.common.input.parameter.textembedding.AsymmetricTextEmbeddingParameters.EmbeddingContentType; -import org.opensearch.ml.common.input.parameter.textembedding.SparseEncodingParameters; import org.opensearch.ml.common.model.MLModelConfig; import org.opensearch.ml.common.model.TextEmbeddingModelConfig; import org.opensearch.ml.common.output.model.ModelResultFilter; @@ -25,6 +24,7 @@ import ai.djl.translate.TranslateException; public abstract class TextEmbeddingModel extends DLModel { + protected boolean isSparseModel = false; @Override public ModelTensorOutput predict(String modelId, MLInput mlInput) throws TranslateException { @@ -41,12 +41,9 @@ public ModelTensorOutput predict(String modelId, MLInput mlInput) throws Transla for (String doc : textDocsInput.getDocs()) { Input input = new Input(); input.add(doc); - if (mlParams instanceof SparseEncodingParameters) { - input - .add( - SparseEncodingParameters.EMBEDDING_FORMAT_FIELD, - ((SparseEncodingParameters) mlParams).getEmbeddingFormat().name() - ); + if (mlParams instanceof AsymmetricTextEmbeddingParameters) { + AsymmetricTextEmbeddingParameters params = (AsymmetricTextEmbeddingParameters) mlParams; + input.add(AsymmetricTextEmbeddingParameters.SPARSE_EMBEDDING_FORMAT_FIELD, params.getSparseEmbeddingFormat().name()); } output = getPredictor().predict(input); @@ -55,7 +52,7 @@ public ModelTensorOutput predict(String modelId, MLInput mlInput) throws Transla return new ModelTensorOutput(tensorOutputs); } - private boolean isAsymmetricModel(MLAlgoParams mlParams) { + protected boolean isAsymmetricModel(MLAlgoParams mlParams) { if (mlParams instanceof AsymmetricTextEmbeddingParameters) { // Check for the necessary prefixes in modelConfig if (modelConfig == null diff --git 
a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/sparse_encoding/SparseEncodingTranslator.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/sparse_encoding/SparseEncodingTranslator.java index aacb377b27..af1b994d67 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/sparse_encoding/SparseEncodingTranslator.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/sparse_encoding/SparseEncodingTranslator.java @@ -6,7 +6,7 @@ package org.opensearch.ml.engine.algorithms.sparse_encoding; import static org.opensearch.ml.common.CommonValue.ML_MAP_RESPONSE_KEY; -import static org.opensearch.ml.common.input.parameter.textembedding.SparseEncodingParameters.EMBEDDING_FORMAT_FIELD; +import static org.opensearch.ml.common.input.parameter.textembedding.AsymmetricTextEmbeddingParameters.SPARSE_EMBEDDING_FORMAT_FIELD; import java.util.ArrayList; import java.util.Collections; @@ -14,7 +14,7 @@ import java.util.List; import java.util.Map; -import org.opensearch.ml.common.input.parameter.textembedding.AbstractSparseEncodingParameters; +import org.opensearch.ml.common.input.parameter.textembedding.AsymmetricTextEmbeddingParameters; import org.opensearch.ml.common.output.model.ModelTensor; import org.opensearch.ml.common.output.model.ModelTensors; import org.opensearch.ml.engine.algorithms.SentenceTransformerTranslator; @@ -29,9 +29,9 @@ public class SparseEncodingTranslator extends SentenceTransformerTranslator { @Override public NDList processInput(TranslatorContext ctx, Input input) { - String embeddingFormat = input.getAsString(EMBEDDING_FORMAT_FIELD); + String embeddingFormat = input.getAsString(SPARSE_EMBEDDING_FORMAT_FIELD); if (embeddingFormat != null) { - ctx.setAttachment(EMBEDDING_FORMAT_FIELD, embeddingFormat); + ctx.setAttachment(SPARSE_EMBEDDING_FORMAT_FIELD, embeddingFormat); } return super.processInput(ctx, input); } @@ -39,10 +39,10 @@ public NDList processInput(TranslatorContext ctx, 
Input input) { @Override public Output processOutput(TranslatorContext ctx, NDList list) { Output output = new Output(200, "OK"); - Object embeddingFormatObject = ctx.getAttachment(EMBEDDING_FORMAT_FIELD); + Object embeddingFormatObject = ctx.getAttachment(SPARSE_EMBEDDING_FORMAT_FIELD); String embeddingFormatString = embeddingFormatObject != null ? embeddingFormatObject.toString() - : AbstractSparseEncodingParameters.EmbeddingFormat.LEXICAL.name(); + : AsymmetricTextEmbeddingParameters.SparseEmbeddingFormat.LEXICAL.name(); List outputs = new ArrayList<>(); for (NDArray ndArray : list) { @@ -62,22 +62,15 @@ private Object convertOutput(NDArray array, String embeddingFormat) { NDArray nonZeroIndices = array.nonzero().squeeze(); long[] indices = nonZeroIndices.toLongArray(); - if (embeddingFormat.equals(AbstractSparseEncodingParameters.EmbeddingFormat.VECTOR.name())) { - // Return vector format: {"indices": [...], "values": [...]} - // Sort indices for vector format - java.util.Arrays.sort(indices); - List indicesList = new ArrayList<>(); - List valuesList = new ArrayList<>(); + if (embeddingFormat.equals(AsymmetricTextEmbeddingParameters.SparseEmbeddingFormat.TOKEN_ID.name())) { + // Return token_id format: {"123": 1.1, "456": 2.2} + Map tokenIdWeights = new HashMap<>(); for (long index : indices) { - indicesList.add(index); - valuesList.add(array.getFloat(index)); + tokenIdWeights.put(String.valueOf(index), array.getFloat(index)); } - Map vectorFormat = new HashMap<>(); - vectorFormat.put("indices", indicesList); - vectorFormat.put("values", valuesList); - return vectorFormat; + return tokenIdWeights; } else { // Return lexical format: {"token": weight, ...} Map tokenWeights = new HashMap<>(); diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/sparse_encoding/TextEmbeddingSparseEncodingModel.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/sparse_encoding/TextEmbeddingSparseEncodingModel.java index 
11221c840c..e6aad630dd 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/sparse_encoding/TextEmbeddingSparseEncodingModel.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/sparse_encoding/TextEmbeddingSparseEncodingModel.java @@ -6,6 +6,7 @@ package org.opensearch.ml.engine.algorithms.sparse_encoding; import org.opensearch.ml.common.FunctionName; +import org.opensearch.ml.common.input.parameter.MLAlgoParams; import org.opensearch.ml.common.model.MLModelConfig; import org.opensearch.ml.engine.algorithms.TextEmbeddingModel; import org.opensearch.ml.engine.annotation.Function; @@ -30,4 +31,8 @@ public TranslatorFactory getTranslatorFactory(String engine, MLModelConfig model return null; } + @Override + protected boolean isAsymmetricModel(MLAlgoParams mlParams) { + return false; + } } diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/tokenize/SparseTokenizerModel.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/tokenize/SparseTokenizerModel.java index 7bc28e5c5e..71bebc2cf6 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/tokenize/SparseTokenizerModel.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/tokenize/SparseTokenizerModel.java @@ -23,8 +23,7 @@ import org.opensearch.ml.common.dataset.TextDocsInputDataSet; import org.opensearch.ml.common.input.MLInput; import org.opensearch.ml.common.input.parameter.MLAlgoParams; -import org.opensearch.ml.common.input.parameter.textembedding.AbstractSparseEncodingParameters; -import org.opensearch.ml.common.input.parameter.textembedding.SparseTokenizeParameters; +import org.opensearch.ml.common.input.parameter.textembedding.AsymmetricTextEmbeddingParameters; import org.opensearch.ml.common.model.MLModelConfig; import org.opensearch.ml.common.output.model.ModelTensor; import org.opensearch.ml.common.output.model.ModelTensorOutput; @@ -72,52 +71,34 @@ public ModelTensorOutput 
predict(String modelId, MLInput mlInput) throws Transla // Get the embedding format from parameters MLAlgoParams parameters = mlInput.getParameters(); - AbstractSparseEncodingParameters.EmbeddingFormat embeddingFormat = AbstractSparseEncodingParameters.EmbeddingFormat.LEXICAL; // default + AsymmetricTextEmbeddingParameters.SparseEmbeddingFormat sparseEmbeddingFormat = + AsymmetricTextEmbeddingParameters.SparseEmbeddingFormat.LEXICAL; // default - if (parameters instanceof SparseTokenizeParameters) { - SparseTokenizeParameters sparseParams = (SparseTokenizeParameters) parameters; - embeddingFormat = sparseParams.getEmbeddingFormat(); + if (parameters instanceof AsymmetricTextEmbeddingParameters) { + AsymmetricTextEmbeddingParameters sparseParams = (AsymmetricTextEmbeddingParameters) parameters; + sparseEmbeddingFormat = sparseParams.getSparseEmbeddingFormat(); } for (String doc : textDocsInput.getDocs()) { Encoding encodings = tokenizer.encode(doc); long[] indices = encodings.getIds(); - - Map wrappedMap; - if (embeddingFormat == AbstractSparseEncodingParameters.EmbeddingFormat.VECTOR) { - // Return vector format: {"indices": [...], "values": [...]} - // Get distinct token IDs and sort them for vector format - long[] uniqueIndices = Arrays.stream(indices).distinct().sorted().toArray(); - List indicesList = new ArrayList<>(); - List valuesList = new ArrayList<>(); - - for (long index : uniqueIndices) { - String token = this.tokenizer.decode(new long[] { index }, true); - if (!token.isEmpty()) { - indicesList.add(index); - valuesList.add(idf.getOrDefault(token, 1.0f)); - } + long[] uniqueIndices = Arrays.stream(indices).distinct().toArray(); + String[] tokens = Arrays.stream(uniqueIndices).mapToObj(value -> this.tokenizer.decode(new long[] { value }, true)).toArray(String[]::new); + + Map tokenWeights = new HashMap<>(); + for (int i = 0; i < uniqueIndices.length; i++) { + String token = tokens[i]; + if (token.isEmpty()) { + continue; + } + if (sparseEmbeddingFormat == 
AsymmetricTextEmbeddingParameters.SparseEmbeddingFormat.TOKEN_ID) { + tokenWeights.put(String.valueOf(uniqueIndices[i]), idf.getOrDefault(token, 1.0f)); + } else { + tokenWeights.put(token, idf.getOrDefault(token, 1.0f)); } - - Map vectorFormat = new HashMap<>(); - vectorFormat.put("indices", indicesList); - vectorFormat.put("values", valuesList); - wrappedMap = Map.of(ML_MAP_RESPONSE_KEY, Collections.singletonList(vectorFormat)); - } else { - // Return lexical format: {"token": weight, ...} - String[] tokens = Arrays - .stream(indices) - .distinct() - .mapToObj(value -> new long[] { value }) - .map(value -> this.tokenizer.decode(value, true)) - .filter(s -> !s.isEmpty()) - .toArray(String[]::new); - Map tokenWeights = Arrays - .stream(tokens) - .collect(Collectors.toMap(token -> token, token -> idf.getOrDefault(token, 1.0f))); - wrappedMap = Map.of(ML_MAP_RESPONSE_KEY, Collections.singletonList(tokenWeights)); } + Map wrappedMap = Map.of(ML_MAP_RESPONSE_KEY, Collections.singletonList(tokenWeights)); ModelTensor tensor = ModelTensor.builder().dataAsMap(wrappedMap).build(); tensorOutputs.add(new ModelTensors(List.of(tensor))); } diff --git a/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java b/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java index 45ab4cec19..f662f8a28f 100644 --- a/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java +++ b/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java @@ -150,8 +150,6 @@ import org.opensearch.ml.common.input.parameter.regression.LogisticRegressionParams; import org.opensearch.ml.common.input.parameter.sample.SampleAlgoParams; import org.opensearch.ml.common.input.parameter.textembedding.AsymmetricTextEmbeddingParameters; -import org.opensearch.ml.common.input.parameter.textembedding.SparseEncodingParameters; -import org.opensearch.ml.common.input.parameter.textembedding.SparseTokenizeParameters; import 
org.opensearch.ml.common.model.TextEmbeddingModelConfig; import org.opensearch.ml.common.settings.MLCommonsSettings; import org.opensearch.ml.common.settings.MLFeatureEnabledSetting; @@ -1097,8 +1095,8 @@ public List getNamedXContent() { LogisticRegressionParams.XCONTENT_REGISTRY, TextEmbeddingModelConfig.XCONTENT_REGISTRY, AsymmetricTextEmbeddingParameters.XCONTENT_REGISTRY, - SparseEncodingParameters.XCONTENT_REGISTRY, - SparseTokenizeParameters.XCONTENT_REGISTRY + AsymmetricTextEmbeddingParameters.XCONTENT_REGISTRY_SPARSE_ENCODING, + AsymmetricTextEmbeddingParameters.XCONTENT_REGISTRY_SPARSE_TOKENIZE ); } From d2f760a15f54dcaa3fb8e1e552bba6d2763b60f2 Mon Sep 17 00:00:00 2001 From: zhichao-aws Date: Mon, 7 Jul 2025 16:01:09 +0800 Subject: [PATCH 06/10] restore changes in modeltensor Signed-off-by: zhichao-aws --- .../ml/common/output/model/ModelTensor.java | 60 +------------------ 1 file changed, 1 insertion(+), 59 deletions(-) diff --git a/common/src/main/java/org/opensearch/ml/common/output/model/ModelTensor.java b/common/src/main/java/org/opensearch/ml/common/output/model/ModelTensor.java index 7269d30e70..6d075ab205 100644 --- a/common/src/main/java/org/opensearch/ml/common/output/model/ModelTensor.java +++ b/common/src/main/java/org/opensearch/ml/common/output/model/ModelTensor.java @@ -15,7 +15,6 @@ import java.security.PrivilegedActionException; import java.security.PrivilegedExceptionAction; import java.util.ArrayList; -import java.util.HashMap; import java.util.List; import java.util.Map; @@ -27,12 +26,6 @@ import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.core.xcontent.XContentParser; -import com.google.gson.JsonArray; -import com.google.gson.JsonElement; -import com.google.gson.JsonObject; -import com.google.gson.JsonParser; -import com.google.gson.JsonPrimitive; - import lombok.Builder; import lombok.Data; @@ -239,7 +232,7 @@ public ModelTensor(StreamInput in) throws IOException { this.result = in.readOptionalString(); if 
(in.readBoolean()) { String mapStr = in.readString(); - this.dataAsMap = parseMapPreservingNumberTypes(mapStr); + this.dataAsMap = gson.fromJson(mapStr, Map.class); } } @@ -296,55 +289,4 @@ public void writeTo(StreamOutput out) throws IOException { out.writeBoolean(false); } } - - /** - * Parse JSON string to Map while preserving number types (int vs double) - */ - private static Map parseMapPreservingNumberTypes(String jsonStr) { - JsonElement element = JsonParser.parseString(jsonStr); - return convertJsonElementToObject(element); - } - - @SuppressWarnings("unchecked") - private static T convertJsonElementToObject(JsonElement element) { - if (element.isJsonNull()) { - return null; - } else if (element.isJsonPrimitive()) { - JsonPrimitive primitive = element.getAsJsonPrimitive(); - if (primitive.isNumber()) { - // Preserve integer types - Number number = primitive.getAsNumber(); - if (number.toString().contains(".")) { - return (T) Double.valueOf(number.doubleValue()); - } else { - // Check if it fits in an int, otherwise use long - long longValue = number.longValue(); - if (longValue >= Integer.MIN_VALUE && longValue <= Integer.MAX_VALUE) { - return (T) Integer.valueOf((int) longValue); - } else { - return (T) Long.valueOf(longValue); - } - } - } else if (primitive.isBoolean()) { - return (T) Boolean.valueOf(primitive.getAsBoolean()); - } else { - return (T) primitive.getAsString(); - } - } else if (element.isJsonArray()) { - JsonArray array = element.getAsJsonArray(); - List list = new ArrayList<>(); - for (JsonElement arrayElement : array) { - list.add(convertJsonElementToObject(arrayElement)); - } - return (T) list; - } else if (element.isJsonObject()) { - JsonObject object = element.getAsJsonObject(); - Map map = new HashMap<>(); - for (Map.Entry entry : object.entrySet()) { - map.put(entry.getKey(), convertJsonElementToObject(entry.getValue())); - } - return (T) map; - } - return null; - } } From 7a31a8cb0b39ee264925a7561effe9718db5e88d Mon Sep 17 00:00:00 
2001 From: zhichao-aws Date: Mon, 7 Jul 2025 17:01:43 +0800 Subject: [PATCH 07/10] add test Signed-off-by: zhichao-aws --- .../AsymmetricTextEmbeddingParameters.java | 21 ++- ...AsymmetricTextEmbeddingParametersTest.java | 126 ++++++++++++++ .../tokenize/SparseTokenizerModel.java | 6 +- .../TextEmbeddingSparseEncodingModelTest.java | 150 +++++++++++++++++ .../tokenize/SparseTokenizerModelTest.java | 159 ++++++++++++++++++ 5 files changed, 457 insertions(+), 5 deletions(-) diff --git a/common/src/main/java/org/opensearch/ml/common/input/parameter/textembedding/AsymmetricTextEmbeddingParameters.java b/common/src/main/java/org/opensearch/ml/common/input/parameter/textembedding/AsymmetricTextEmbeddingParameters.java index 214273b35f..c7a30a4541 100644 --- a/common/src/main/java/org/opensearch/ml/common/input/parameter/textembedding/AsymmetricTextEmbeddingParameters.java +++ b/common/src/main/java/org/opensearch/ml/common/input/parameter/textembedding/AsymmetricTextEmbeddingParameters.java @@ -9,6 +9,7 @@ import java.io.IOException; import java.util.Locale; +import java.util.Objects; import org.opensearch.Version; import org.opensearch.core.ParseField; @@ -80,7 +81,8 @@ public AsymmetricTextEmbeddingParameters(EmbeddingContentType embeddingContentTy public AsymmetricTextEmbeddingParameters(StreamInput in) throws IOException { Version streamInputVersion = in.getVersion(); - this.embeddingContentType = EmbeddingContentType.valueOf(in.readOptionalString()); + String contentType = in.readOptionalString(); + this.embeddingContentType = contentType != null ? EmbeddingContentType.valueOf(contentType) : null; if (streamInputVersion.onOrAfter(Version.V_3_2_0)) { String formatName = in.readOptionalString(); this.sparseEmbeddingFormat = formatName != null ? 
SparseEmbeddingFormat.valueOf(formatName) : SparseEmbeddingFormat.LEXICAL; @@ -137,9 +139,9 @@ public String getWriteableName() { @Override public void writeTo(StreamOutput out) throws IOException { Version streamOutputVersion = out.getVersion(); - out.writeOptionalString(embeddingContentType.name()); + out.writeOptionalString(embeddingContentType != null ? embeddingContentType.name() : null); if (streamOutputVersion.onOrAfter(Version.V_3_2_0)) { - out.writeOptionalString(sparseEmbeddingFormat.name()); + out.writeOptionalString(sparseEmbeddingFormat != null ? sparseEmbeddingFormat.name() : null); } } @@ -161,4 +163,17 @@ public EmbeddingContentType getEmbeddingContentType() { public SparseEmbeddingFormat getSparseEmbeddingFormat() { return sparseEmbeddingFormat; } + + @Override + public boolean equals(Object obj) { + if (this == obj) { + return true; + } + if (obj == null || getClass() != obj.getClass()) { + return false; + } + AsymmetricTextEmbeddingParameters other = (AsymmetricTextEmbeddingParameters) obj; + return Objects.equals(embeddingContentType, other.embeddingContentType) + && Objects.equals(sparseEmbeddingFormat, other.sparseEmbeddingFormat); + } } diff --git a/common/src/test/java/org/opensearch/ml/common/dataset/AsymmetricTextEmbeddingParametersTest.java b/common/src/test/java/org/opensearch/ml/common/dataset/AsymmetricTextEmbeddingParametersTest.java index b949208472..808a84aa10 100644 --- a/common/src/test/java/org/opensearch/ml/common/dataset/AsymmetricTextEmbeddingParametersTest.java +++ b/common/src/test/java/org/opensearch/ml/common/dataset/AsymmetricTextEmbeddingParametersTest.java @@ -1,6 +1,7 @@ package org.opensearch.ml.common.dataset; import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNull; import static org.opensearch.ml.common.TestHelper.contentObjectToString; import static org.opensearch.ml.common.TestHelper.testParseFromString; @@ -11,12 +12,14 @@ import org.junit.Rule; import org.junit.Test; import 
org.junit.rules.ExpectedException; +import org.opensearch.Version; import org.opensearch.common.io.stream.BytesStreamOutput; import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.xcontent.XContentParser; import org.opensearch.ml.common.TestHelper; import org.opensearch.ml.common.input.parameter.textembedding.AsymmetricTextEmbeddingParameters; import org.opensearch.ml.common.input.parameter.textembedding.AsymmetricTextEmbeddingParameters.EmbeddingContentType; +import org.opensearch.ml.common.input.parameter.textembedding.AsymmetricTextEmbeddingParameters.SparseEmbeddingFormat; public class AsymmetricTextEmbeddingParametersTest { @@ -74,6 +77,129 @@ public void readInputStream_Success_EmptyParams() throws IOException { readInputStream(AsymmetricTextEmbeddingParameters.builder().embeddingContentType(EmbeddingContentType.PASSAGE).build()); } + @Test + public void parse_AsymmetricTextEmbeddingParameters_WithSparseEmbeddingFormat_LEXICAL() throws IOException { + AsymmetricTextEmbeddingParameters params = AsymmetricTextEmbeddingParameters + .builder() + .embeddingContentType(EmbeddingContentType.QUERY) + .sparseEmbeddingFormat(SparseEmbeddingFormat.LEXICAL) + .build(); + TestHelper.testParse(params, function); + } + + @Test + public void parse_AsymmetricTextEmbeddingParameters_WithSparseEmbeddingFormat_TOKEN_ID() throws IOException { + AsymmetricTextEmbeddingParameters params = AsymmetricTextEmbeddingParameters + .builder() + .embeddingContentType(EmbeddingContentType.PASSAGE) + .sparseEmbeddingFormat(SparseEmbeddingFormat.TOKEN_ID) + .build(); + TestHelper.testParse(params, function); + } + + @Test + public void parse_AsymmetricTextEmbeddingParameters_OnlySparseEmbeddingFormat() throws IOException { + AsymmetricTextEmbeddingParameters params = AsymmetricTextEmbeddingParameters + .builder() + .sparseEmbeddingFormat(SparseEmbeddingFormat.TOKEN_ID) + .build(); + TestHelper.testParse(params, function); + } + + @Test + public void 
parse_AsymmetricTextEmbeddingParameters_SparseEmbeddingFormat_Invalid() throws IOException { + exceptionRule.expect(IllegalArgumentException.class); + exceptionRule + .expectMessage( + "No enum constant org.opensearch.ml.common.input.parameter.textembedding.AsymmetricTextEmbeddingParameters.SparseEmbeddingFormat.INVALID" + ); + String jsonWithInvalidFormat = "{\"content_type\": \"QUERY\", \"sparse_embedding_format\": \"INVALID\"}"; + testParseFromString(params, jsonWithInvalidFormat, function); + } + + @Test + public void constructor_BackwardCompatibility() { + AsymmetricTextEmbeddingParameters params = new AsymmetricTextEmbeddingParameters(EmbeddingContentType.QUERY); + assertEquals(EmbeddingContentType.QUERY, params.getEmbeddingContentType()); + assertEquals(SparseEmbeddingFormat.LEXICAL, params.getSparseEmbeddingFormat()); + } + + @Test + public void constructor_WithSparseEmbeddingFormat() { + AsymmetricTextEmbeddingParameters params = new AsymmetricTextEmbeddingParameters( + EmbeddingContentType.PASSAGE, + SparseEmbeddingFormat.TOKEN_ID + ); + assertEquals(EmbeddingContentType.PASSAGE, params.getEmbeddingContentType()); + assertEquals(SparseEmbeddingFormat.TOKEN_ID, params.getSparseEmbeddingFormat()); + } + + @Test + public void constructor_WithNullSparseEmbeddingFormat_DefaultsToLexical() { + AsymmetricTextEmbeddingParameters params = new AsymmetricTextEmbeddingParameters(EmbeddingContentType.QUERY, null); + assertEquals(EmbeddingContentType.QUERY, params.getEmbeddingContentType()); + assertEquals(SparseEmbeddingFormat.LEXICAL, params.getSparseEmbeddingFormat()); + } + + @Test + public void constructor_NullContentType_WithSparseEmbeddingFormat() { + AsymmetricTextEmbeddingParameters params = new AsymmetricTextEmbeddingParameters(null, SparseEmbeddingFormat.TOKEN_ID); + assertNull(params.getEmbeddingContentType()); + assertEquals(SparseEmbeddingFormat.TOKEN_ID, params.getSparseEmbeddingFormat()); + } + + @Test + public void 
readInputStream_WithSparseEmbeddingFormat() throws IOException { + AsymmetricTextEmbeddingParameters params = AsymmetricTextEmbeddingParameters + .builder() + .embeddingContentType(EmbeddingContentType.PASSAGE) + .sparseEmbeddingFormat(SparseEmbeddingFormat.TOKEN_ID) + .build(); + readInputStream(params); + } + + @Test + public void readInputStream_OnlySparseEmbeddingFormat() throws IOException { + AsymmetricTextEmbeddingParameters params = AsymmetricTextEmbeddingParameters + .builder() + .sparseEmbeddingFormat(SparseEmbeddingFormat.TOKEN_ID) + .build(); + readInputStream(params); + } + + @Test + public void readInputStream_VersionCompatibility_Pre_V_3_2_0() throws IOException { + AsymmetricTextEmbeddingParameters params = AsymmetricTextEmbeddingParameters + .builder() + .embeddingContentType(EmbeddingContentType.QUERY) + .sparseEmbeddingFormat(SparseEmbeddingFormat.TOKEN_ID) + .build(); + + BytesStreamOutput bytesStreamOutput = new BytesStreamOutput(); + bytesStreamOutput.setVersion(Version.V_3_1_0); + params.writeTo(bytesStreamOutput); + + StreamInput streamInput = bytesStreamOutput.bytes().streamInput(); + streamInput.setVersion(Version.V_3_1_0); + AsymmetricTextEmbeddingParameters parsedParams = new AsymmetricTextEmbeddingParameters(streamInput); + + assertEquals(EmbeddingContentType.QUERY, parsedParams.getEmbeddingContentType()); + assertEquals(SparseEmbeddingFormat.LEXICAL, parsedParams.getSparseEmbeddingFormat()); + } + + @Test + public void toXContent_IncludesSparseEmbeddingFormat() throws IOException { + AsymmetricTextEmbeddingParameters params = AsymmetricTextEmbeddingParameters + .builder() + .embeddingContentType(EmbeddingContentType.QUERY) + .sparseEmbeddingFormat(SparseEmbeddingFormat.TOKEN_ID) + .build(); + + String jsonStr = contentObjectToString(params); + assert (jsonStr.contains("\"content_type\":\"QUERY\"")); + assert (jsonStr.contains("\"sparse_embedding_format\":\"TOKEN_ID\"")); + } + private void 
readInputStream(AsymmetricTextEmbeddingParameters params) throws IOException { BytesStreamOutput bytesStreamOutput = new BytesStreamOutput(); params.writeTo(bytesStreamOutput); diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/tokenize/SparseTokenizerModel.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/tokenize/SparseTokenizerModel.java index 71bebc2cf6..6326a8e772 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/tokenize/SparseTokenizerModel.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/tokenize/SparseTokenizerModel.java @@ -16,7 +16,6 @@ import java.util.HashMap; import java.util.List; import java.util.Map; -import java.util.stream.Collectors; import org.opensearch.ml.common.FunctionName; import org.opensearch.ml.common.dataset.MLInputDataset; @@ -83,7 +82,10 @@ public ModelTensorOutput predict(String modelId, MLInput mlInput) throws Transla Encoding encodings = tokenizer.encode(doc); long[] indices = encodings.getIds(); long[] uniqueIndices = Arrays.stream(indices).distinct().toArray(); - String[] tokens = Arrays.stream(uniqueIndices).mapToObj(value -> this.tokenizer.decode(new long[] { value }, true)).toArray(String[]::new); + String[] tokens = Arrays + .stream(uniqueIndices) + .mapToObj(value -> this.tokenizer.decode(new long[] { value }, true)) + .toArray(String[]::new); Map tokenWeights = new HashMap<>(); for (int i = 0; i < uniqueIndices.length; i++) { diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/sparse_encoding/TextEmbeddingSparseEncodingModelTest.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/sparse_encoding/TextEmbeddingSparseEncodingModelTest.java index 2fc6d9f89a..484644ae99 100644 --- a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/sparse_encoding/TextEmbeddingSparseEncodingModelTest.java +++ 
b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/sparse_encoding/TextEmbeddingSparseEncodingModelTest.java @@ -1,6 +1,7 @@ package org.opensearch.ml.engine.algorithms.sparse_encoding; import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertNotNull; import static org.mockito.Mockito.*; import static org.opensearch.ml.engine.algorithms.DLModel.*; @@ -29,6 +30,8 @@ import org.opensearch.ml.common.dataset.TextDocsInputDataSet; import org.opensearch.ml.common.exception.MLException; import org.opensearch.ml.common.input.MLInput; +import org.opensearch.ml.common.input.parameter.textembedding.AsymmetricTextEmbeddingParameters; +import org.opensearch.ml.common.input.parameter.textembedding.AsymmetricTextEmbeddingParameters.SparseEmbeddingFormat; import org.opensearch.ml.common.model.MLModelFormat; import org.opensearch.ml.common.model.MLModelState; import org.opensearch.ml.common.output.model.ModelResultFilter; @@ -187,6 +190,153 @@ public void initModel_predict_TorchScript_SparseEncoding_ResultFilter() { textEmbeddingSparseEncodingModel.close(); } + // Test AsymmetricTextEmbeddingParameters with LEXICAL format + @Test + public void initModel_predict_SparseEncoding_WithLexicalFormat() { + textEmbeddingSparseEncodingModel.initModel(model, params, encryptor); + + AsymmetricTextEmbeddingParameters parameters = AsymmetricTextEmbeddingParameters + .builder() + .sparseEmbeddingFormat(SparseEmbeddingFormat.LEXICAL) + .build(); + + MLInput mlInput = MLInput + .builder() + .algorithm(FunctionName.SPARSE_ENCODING) + .inputDataset(inputDataSet) + .parameters(parameters) + .build(); + + ModelTensorOutput output = (ModelTensorOutput) textEmbeddingSparseEncodingModel.predict(mlInput); + List mlModelOutputs = output.getMlModelOutputs(); + assertEquals(2, mlModelOutputs.size()); + + for (ModelTensors tensors : mlModelOutputs) { + List mlModelTensors = tensors.getMlModelTensors(); + assertEquals(1, 
mlModelTensors.size()); + ModelTensor tensor = mlModelTensors.get(0); + assertNotNull(tensor.getDataAsMap()); + } + textEmbeddingSparseEncodingModel.close(); + } + + // Test AsymmetricTextEmbeddingParameters with TOKEN_ID format + @Test + public void initModel_predict_SparseEncoding_WithTokenIdFormat() { + textEmbeddingSparseEncodingModel.initModel(model, params, encryptor); + + AsymmetricTextEmbeddingParameters parameters = AsymmetricTextEmbeddingParameters + .builder() + .sparseEmbeddingFormat(SparseEmbeddingFormat.TOKEN_ID) + .build(); + + MLInput mlInput = MLInput + .builder() + .algorithm(FunctionName.SPARSE_ENCODING) + .inputDataset(inputDataSet) + .parameters(parameters) + .build(); + + ModelTensorOutput output = (ModelTensorOutput) textEmbeddingSparseEncodingModel.predict(mlInput); + List mlModelOutputs = output.getMlModelOutputs(); + assertEquals(2, mlModelOutputs.size()); + + for (ModelTensors tensors : mlModelOutputs) { + List mlModelTensors = tensors.getMlModelTensors(); + assertEquals(1, mlModelTensors.size()); + ModelTensor tensor = mlModelTensors.get(0); + assertNotNull(tensor.getDataAsMap()); + } + textEmbeddingSparseEncodingModel.close(); + } + + // Test both content_type and sparse_embedding_format parameters + @Test + public void initModel_predict_SparseEncoding_WithBothParameters() { + textEmbeddingSparseEncodingModel.initModel(model, params, encryptor); + + AsymmetricTextEmbeddingParameters parameters = AsymmetricTextEmbeddingParameters + .builder() + .embeddingContentType(AsymmetricTextEmbeddingParameters.EmbeddingContentType.QUERY) + .sparseEmbeddingFormat(SparseEmbeddingFormat.TOKEN_ID) + .build(); + + MLInput mlInput = MLInput + .builder() + .algorithm(FunctionName.SPARSE_ENCODING) + .inputDataset(inputDataSet) + .parameters(parameters) + .build(); + + ModelTensorOutput output = (ModelTensorOutput) textEmbeddingSparseEncodingModel.predict(mlInput); + List mlModelOutputs = output.getMlModelOutputs(); + assertEquals(2, mlModelOutputs.size()); 
+ + for (ModelTensors tensors : mlModelOutputs) { + List mlModelTensors = tensors.getMlModelTensors(); + assertEquals(1, mlModelTensors.size()); + ModelTensor tensor = mlModelTensors.get(0); + assertNotNull(tensor.getDataAsMap()); + } + textEmbeddingSparseEncodingModel.close(); + } + + // Test default parameters behavior (no AsymmetricTextEmbeddingParameters) + @Test + public void initModel_predict_SparseEncoding_WithoutParameters() { + textEmbeddingSparseEncodingModel.initModel(model, params, encryptor); + + MLInput mlInput = MLInput.builder().algorithm(FunctionName.SPARSE_ENCODING).inputDataset(inputDataSet).build(); + + ModelTensorOutput output = (ModelTensorOutput) textEmbeddingSparseEncodingModel.predict(mlInput); + List mlModelOutputs = output.getMlModelOutputs(); + assertEquals(2, mlModelOutputs.size()); + + for (ModelTensors tensors : mlModelOutputs) { + List mlModelTensors = tensors.getMlModelTensors(); + assertEquals(1, mlModelTensors.size()); + ModelTensor tensor = mlModelTensors.get(0); + assertNotNull(tensor.getDataAsMap()); + } + textEmbeddingSparseEncodingModel.close(); + } + + // Test isAsymmetricModel method override returns false + @Test + public void test_isAsymmetricModel_ReturnsFalse() { + AsymmetricTextEmbeddingParameters parameters = AsymmetricTextEmbeddingParameters + .builder() + .embeddingContentType(AsymmetricTextEmbeddingParameters.EmbeddingContentType.QUERY) + .sparseEmbeddingFormat(SparseEmbeddingFormat.LEXICAL) + .build(); + + // Test that isAsymmetricModel returns false even with AsymmetricTextEmbeddingParameters + boolean isAsymmetric = textEmbeddingSparseEncodingModel.isAsymmetricModel(parameters); + assertFalse("isAsymmetricModel should return false for sparse encoding model", isAsymmetric); + } + + // Test isAsymmetricModel with null parameters + @Test + public void test_isAsymmetricModel_WithNullParameters() { + boolean isAsymmetric = textEmbeddingSparseEncodingModel.isAsymmetricModel(null); + assertFalse("isAsymmetricModel 
should return false with null parameters", isAsymmetric); + } + + // Test isSparseModel field default value + @Test + public void test_isSparseModel_DefaultValue() { + // Test that the protected field isSparseModel defaults to false + // This indirectly tests the field existence and default value + textEmbeddingSparseEncodingModel.initModel(model, params, encryptor); + + // The field is tested implicitly through model behavior + MLInput mlInput = MLInput.builder().algorithm(FunctionName.SPARSE_ENCODING).inputDataset(inputDataSet).build(); + ModelTensorOutput output = (ModelTensorOutput) textEmbeddingSparseEncodingModel.predict(mlInput); + assertNotNull("Model should predict successfully with default isSparseModel value", output); + + textEmbeddingSparseEncodingModel.close(); + } + @Test public void initModel_NullModelZipFile() { exceptionRule.expect(IllegalArgumentException.class); diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/tokenize/SparseTokenizerModelTest.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/tokenize/SparseTokenizerModelTest.java index eac4af205c..fb9ae65fbc 100644 --- a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/tokenize/SparseTokenizerModelTest.java +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/tokenize/SparseTokenizerModelTest.java @@ -2,6 +2,7 @@ import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertTrue; import static org.mockito.ArgumentMatchers.any; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.when; @@ -26,6 +27,8 @@ import org.opensearch.ml.common.dataset.TextDocsInputDataSet; import org.opensearch.ml.common.exception.MLException; import org.opensearch.ml.common.input.MLInput; +import org.opensearch.ml.common.input.parameter.textembedding.AsymmetricTextEmbeddingParameters; +import 
org.opensearch.ml.common.input.parameter.textembedding.AsymmetricTextEmbeddingParameters.SparseEmbeddingFormat; import org.opensearch.ml.common.model.MLModelConfig; import org.opensearch.ml.common.model.MLModelFormat; import org.opensearch.ml.common.model.MLModelState; @@ -163,6 +166,147 @@ public void initModel_predict_Tokenize() throws URISyntaxException, TranslateExc } } + // Test default LEXICAL format (no parameters provided) + @Test + public void initModel_predict_Tokenize_DefaultLexicalFormat() throws URISyntaxException, TranslateException { + sparseTokenizerModel.initModel(model, params, encryptor); + MLInput mlInput = MLInput.builder().algorithm(FunctionName.SPARSE_TOKENIZE).inputDataset(inputDataSet).build(); + ModelTensorOutput output = (ModelTensorOutput) sparseTokenizerModel.predict(mlInput); + + List mlModelOutputs = output.getMlModelOutputs(); + assertEquals(2, mlModelOutputs.size()); + + // Verify output format is LEXICAL (token strings as keys) + for (ModelTensors tensors : mlModelOutputs) { + List mlModelTensors = tensors.getMlModelTensors(); + assertEquals(1, mlModelTensors.size()); + ModelTensor tensor = mlModelTensors.get(0); + Map resultMap = tensor.getDataAsMap(); + List> resultList = (List>) resultMap.get("response"); + Map result = resultList.get(0); + + // Verify keys are token strings rather than numeric IDs + for (String key : result.keySet()) { + assertTrue("Key should be a token string, not numeric ID", !isNumeric(key)); + } + } + } + + // Test LEXICAL format with explicit parameter + @Test + public void initModel_predict_Tokenize_WithLexicalFormat() throws URISyntaxException, TranslateException { + sparseTokenizerModel.initModel(model, params, encryptor); + + AsymmetricTextEmbeddingParameters parameters = AsymmetricTextEmbeddingParameters + .builder() + .sparseEmbeddingFormat(SparseEmbeddingFormat.LEXICAL) + .build(); + + MLInput mlInput = MLInput + .builder() + .algorithm(FunctionName.SPARSE_TOKENIZE) + .inputDataset(inputDataSet) + 
.parameters(parameters) + .build(); + + ModelTensorOutput output = (ModelTensorOutput) sparseTokenizerModel.predict(mlInput); + + List mlModelOutputs = output.getMlModelOutputs(); + assertEquals(2, mlModelOutputs.size()); + + // Verify output format is LEXICAL + for (ModelTensors tensors : mlModelOutputs) { + List mlModelTensors = tensors.getMlModelTensors(); + assertEquals(1, mlModelTensors.size()); + ModelTensor tensor = mlModelTensors.get(0); + Map resultMap = tensor.getDataAsMap(); + List> resultList = (List>) resultMap.get("response"); + Map result = resultList.get(0); + + // Verify keys are token strings + for (String key : result.keySet()) { + assertTrue("Key should be a token string for LEXICAL format", !isNumeric(key)); + } + } + } + + // Test TOKEN_ID format + @Test + public void initModel_predict_Tokenize_WithTokenIdFormat() throws URISyntaxException, TranslateException { + sparseTokenizerModel.initModel(model, params, encryptor); + + AsymmetricTextEmbeddingParameters parameters = AsymmetricTextEmbeddingParameters + .builder() + .sparseEmbeddingFormat(SparseEmbeddingFormat.TOKEN_ID) + .build(); + + MLInput mlInput = MLInput + .builder() + .algorithm(FunctionName.SPARSE_TOKENIZE) + .inputDataset(inputDataSet) + .parameters(parameters) + .build(); + + ModelTensorOutput output = (ModelTensorOutput) sparseTokenizerModel.predict(mlInput); + + List mlModelOutputs = output.getMlModelOutputs(); + assertEquals(2, mlModelOutputs.size()); + + // Verify output format is TOKEN_ID + for (ModelTensors tensors : mlModelOutputs) { + List mlModelTensors = tensors.getMlModelTensors(); + assertEquals(1, mlModelTensors.size()); + ModelTensor tensor = mlModelTensors.get(0); + Map resultMap = tensor.getDataAsMap(); + List> resultList = (List>) resultMap.get("response"); + Map result = resultList.get(0); + + // Verify keys are numeric token ID strings + for (String key : result.keySet()) { + assertTrue("Key should be a numeric token ID for TOKEN_ID format", isNumeric(key)); + } 
+ } + } + + // Test both content_type and sparse_embedding_format parameters + @Test + public void initModel_predict_Tokenize_WithBothContentTypeAndFormat() throws URISyntaxException, TranslateException { + sparseTokenizerModel.initModel(model, params, encryptor); + + AsymmetricTextEmbeddingParameters parameters = AsymmetricTextEmbeddingParameters + .builder() + .embeddingContentType(AsymmetricTextEmbeddingParameters.EmbeddingContentType.QUERY) + .sparseEmbeddingFormat(SparseEmbeddingFormat.TOKEN_ID) + .build(); + + MLInput mlInput = MLInput + .builder() + .algorithm(FunctionName.SPARSE_TOKENIZE) + .inputDataset(inputDataSet) + .parameters(parameters) + .build(); + + ModelTensorOutput output = (ModelTensorOutput) sparseTokenizerModel.predict(mlInput); + + List mlModelOutputs = output.getMlModelOutputs(); + assertEquals(2, mlModelOutputs.size()); + + // Verify output format is TOKEN_ID (sparse_embedding_format parameter takes effect) + for (ModelTensors tensors : mlModelOutputs) { + List mlModelTensors = tensors.getMlModelTensors(); + assertEquals(1, mlModelTensors.size()); + ModelTensor tensor = mlModelTensors.get(0); + Map resultMap = tensor.getDataAsMap(); + List> resultList = (List>) resultMap.get("response"); + Map result = resultList.get(0); + + // Verify keys are numeric token ID strings + for (String key : result.keySet()) { + assertTrue("Key should be a numeric token ID for TOKEN_ID format", isNumeric(key)); + } + } + } + @Test public void initModel_NullModelHelper() throws URISyntaxException { exceptionRule.expect(IllegalArgumentException.class); @@ -257,6 +401,21 @@ public void predict_BeforeInitingModel() { sparseTokenizerModel.predict(MLInput.builder().algorithm(FunctionName.SPARSE_TOKENIZE).inputDataset(inputDataSet).build(), model); } + /** + * Helper method to check if a string is numeric + */ + private boolean isNumeric(String str) { + if (str == null || str.isEmpty()) { + return false; + } + try { + Integer.parseInt(str); + return true; + } catch 
(NumberFormatException e) { + return false; + } + } + @After public void tearDown() { FileUtils.deleteFileQuietly(mlCachePath); From 87ecf95332d9d7ed15e8e8ebb3830c7b53c70b8c Mon Sep 17 00:00:00 2001 From: zhichao-aws Date: Wed, 9 Jul 2025 09:05:33 +0800 Subject: [PATCH 08/10] address comments Signed-off-by: zhichao-aws --- .../sparse_encoding/SparseEncodingTranslator.java | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/sparse_encoding/SparseEncodingTranslator.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/sparse_encoding/SparseEncodingTranslator.java index af1b994d67..590aeceb4e 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/sparse_encoding/SparseEncodingTranslator.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/sparse_encoding/SparseEncodingTranslator.java @@ -40,14 +40,14 @@ public NDList processInput(TranslatorContext ctx, Input input) { public Output processOutput(TranslatorContext ctx, NDList list) { Output output = new Output(200, "OK"); Object embeddingFormatObject = ctx.getAttachment(SPARSE_EMBEDDING_FORMAT_FIELD); - String embeddingFormatString = embeddingFormatObject != null - ? embeddingFormatObject.toString() - : AsymmetricTextEmbeddingParameters.SparseEmbeddingFormat.LEXICAL.name(); + AsymmetricTextEmbeddingParameters.SparseEmbeddingFormat embeddingFormat = embeddingFormatObject != null + ? 
AsymmetricTextEmbeddingParameters.SparseEmbeddingFormat.valueOf(embeddingFormatObject.toString()) + : AsymmetricTextEmbeddingParameters.SparseEmbeddingFormat.LEXICAL; List outputs = new ArrayList<>(); for (NDArray ndArray : list) { String name = ndArray.getName(); - Object result = convertOutput(ndArray, embeddingFormatString); + Object result = convertOutput(ndArray, embeddingFormat); Map wrappedMap = Map.of(ML_MAP_RESPONSE_KEY, Collections.singletonList(result)); ModelTensor tensor = ModelTensor.builder().name(name).dataAsMap(wrappedMap).build(); outputs.add(tensor); @@ -58,11 +58,11 @@ public Output processOutput(TranslatorContext ctx, NDList list) { return output; } - private Object convertOutput(NDArray array, String embeddingFormat) { + private Object convertOutput(NDArray array, AsymmetricTextEmbeddingParameters.SparseEmbeddingFormat embeddingFormat) { NDArray nonZeroIndices = array.nonzero().squeeze(); long[] indices = nonZeroIndices.toLongArray(); - if (embeddingFormat.equals(AsymmetricTextEmbeddingParameters.SparseEmbeddingFormat.TOKEN_ID.name())) { + if (embeddingFormat == AsymmetricTextEmbeddingParameters.SparseEmbeddingFormat.TOKEN_ID) { // Return token_id format: {"123": 1.1, "456": 2.2} Map tokenIdWeights = new HashMap<>(); From 1440324d6128ae84cf842eda1fa68c61159bf0be Mon Sep 17 00:00:00 2001 From: zhichao-aws Date: Mon, 14 Jul 2025 16:53:29 +0800 Subject: [PATCH 09/10] move sparse embedding format a standalone enum; add param attribute for analyzer Signed-off-by: zhichao-aws --- .../AsymmetricTextEmbeddingParameters.java | 13 ++++------- .../textembedding/SparseEmbeddingFormat.java | 14 +++++++++++ ...AsymmetricTextEmbeddingParametersTest.java | 14 +++++------ .../SparseEncodingTranslator.java | 14 +++++------ .../tokenize/SparseTokenizerModel.java | 6 ++--- .../ml/engine/analysis/HFModelTokenizer.java | 23 +++++++++++++++++-- .../TextEmbeddingSparseEncodingModelTest.java | 8 +++---- .../tokenize/SparseTokenizerModelTest.java | 14 +++++------ 8 
files changed, 66 insertions(+), 40 deletions(-) create mode 100644 common/src/main/java/org/opensearch/ml/common/input/parameter/textembedding/SparseEmbeddingFormat.java diff --git a/common/src/main/java/org/opensearch/ml/common/input/parameter/textembedding/AsymmetricTextEmbeddingParameters.java b/common/src/main/java/org/opensearch/ml/common/input/parameter/textembedding/AsymmetricTextEmbeddingParameters.java index c7a30a4541..bbeceed464 100644 --- a/common/src/main/java/org/opensearch/ml/common/input/parameter/textembedding/AsymmetricTextEmbeddingParameters.java +++ b/common/src/main/java/org/opensearch/ml/common/input/parameter/textembedding/AsymmetricTextEmbeddingParameters.java @@ -45,11 +45,6 @@ public enum EmbeddingContentType { PASSAGE } - public enum SparseEmbeddingFormat { - LEXICAL, - TOKEN_ID - } - public static final String PARSE_FIELD_NAME = FunctionName.TEXT_EMBEDDING.name(); public static final NamedXContentRegistry.Entry XCONTENT_REGISTRY = new NamedXContentRegistry.Entry( MLAlgoParams.class, @@ -70,13 +65,13 @@ public enum SparseEmbeddingFormat { @Builder(toBuilder = true) public AsymmetricTextEmbeddingParameters(EmbeddingContentType embeddingContentType, SparseEmbeddingFormat sparseEmbeddingFormat) { this.embeddingContentType = embeddingContentType; - this.sparseEmbeddingFormat = sparseEmbeddingFormat != null ? sparseEmbeddingFormat : SparseEmbeddingFormat.LEXICAL; + this.sparseEmbeddingFormat = sparseEmbeddingFormat != null ? 
sparseEmbeddingFormat : SparseEmbeddingFormat.WORD; } // Constructor for backward compatibility public AsymmetricTextEmbeddingParameters(EmbeddingContentType embeddingContentType) { this.embeddingContentType = embeddingContentType; - this.sparseEmbeddingFormat = SparseEmbeddingFormat.LEXICAL; + this.sparseEmbeddingFormat = SparseEmbeddingFormat.WORD; } public AsymmetricTextEmbeddingParameters(StreamInput in) throws IOException { @@ -85,9 +80,9 @@ public AsymmetricTextEmbeddingParameters(StreamInput in) throws IOException { this.embeddingContentType = contentType != null ? EmbeddingContentType.valueOf(contentType) : null; if (streamInputVersion.onOrAfter(Version.V_3_2_0)) { String formatName = in.readOptionalString(); - this.sparseEmbeddingFormat = formatName != null ? SparseEmbeddingFormat.valueOf(formatName) : SparseEmbeddingFormat.LEXICAL; + this.sparseEmbeddingFormat = formatName != null ? SparseEmbeddingFormat.valueOf(formatName) : SparseEmbeddingFormat.WORD; } else { - this.sparseEmbeddingFormat = SparseEmbeddingFormat.LEXICAL; + this.sparseEmbeddingFormat = SparseEmbeddingFormat.WORD; } } diff --git a/common/src/main/java/org/opensearch/ml/common/input/parameter/textembedding/SparseEmbeddingFormat.java b/common/src/main/java/org/opensearch/ml/common/input/parameter/textembedding/SparseEmbeddingFormat.java new file mode 100644 index 0000000000..1e66f825a8 --- /dev/null +++ b/common/src/main/java/org/opensearch/ml/common/input/parameter/textembedding/SparseEmbeddingFormat.java @@ -0,0 +1,14 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.input.parameter.textembedding; + +/** + * Enum defining the format of sparse embeddings. 
+ */ +public enum SparseEmbeddingFormat { + WORD, + TOKEN_ID +} diff --git a/common/src/test/java/org/opensearch/ml/common/dataset/AsymmetricTextEmbeddingParametersTest.java b/common/src/test/java/org/opensearch/ml/common/dataset/AsymmetricTextEmbeddingParametersTest.java index 808a84aa10..dc494030ca 100644 --- a/common/src/test/java/org/opensearch/ml/common/dataset/AsymmetricTextEmbeddingParametersTest.java +++ b/common/src/test/java/org/opensearch/ml/common/dataset/AsymmetricTextEmbeddingParametersTest.java @@ -19,7 +19,7 @@ import org.opensearch.ml.common.TestHelper; import org.opensearch.ml.common.input.parameter.textembedding.AsymmetricTextEmbeddingParameters; import org.opensearch.ml.common.input.parameter.textembedding.AsymmetricTextEmbeddingParameters.EmbeddingContentType; -import org.opensearch.ml.common.input.parameter.textembedding.AsymmetricTextEmbeddingParameters.SparseEmbeddingFormat; +import org.opensearch.ml.common.input.parameter.textembedding.SparseEmbeddingFormat; public class AsymmetricTextEmbeddingParametersTest { @@ -82,7 +82,7 @@ public void parse_AsymmetricTextEmbeddingParameters_WithSparseEmbeddingFormat_LE AsymmetricTextEmbeddingParameters params = AsymmetricTextEmbeddingParameters .builder() .embeddingContentType(EmbeddingContentType.QUERY) - .sparseEmbeddingFormat(SparseEmbeddingFormat.LEXICAL) + .sparseEmbeddingFormat(SparseEmbeddingFormat.WORD) .build(); TestHelper.testParse(params, function); } @@ -110,9 +110,7 @@ public void parse_AsymmetricTextEmbeddingParameters_OnlySparseEmbeddingFormat() public void parse_AsymmetricTextEmbeddingParameters_SparseEmbeddingFormat_Invalid() throws IOException { exceptionRule.expect(IllegalArgumentException.class); exceptionRule - .expectMessage( - "No enum constant org.opensearch.ml.common.input.parameter.textembedding.AsymmetricTextEmbeddingParameters.SparseEmbeddingFormat.INVALID" - ); + .expectMessage("No enum constant 
org.opensearch.ml.common.input.parameter.textembedding.SparseEmbeddingFormat.INVALID"); String jsonWithInvalidFormat = "{\"content_type\": \"QUERY\", \"sparse_embedding_format\": \"INVALID\"}"; testParseFromString(params, jsonWithInvalidFormat, function); } @@ -121,7 +119,7 @@ public void parse_AsymmetricTextEmbeddingParameters_SparseEmbeddingFormat_Invali public void constructor_BackwardCompatibility() { AsymmetricTextEmbeddingParameters params = new AsymmetricTextEmbeddingParameters(EmbeddingContentType.QUERY); assertEquals(EmbeddingContentType.QUERY, params.getEmbeddingContentType()); - assertEquals(SparseEmbeddingFormat.LEXICAL, params.getSparseEmbeddingFormat()); + assertEquals(SparseEmbeddingFormat.WORD, params.getSparseEmbeddingFormat()); } @Test @@ -138,7 +136,7 @@ public void constructor_WithSparseEmbeddingFormat() { public void constructor_WithNullSparseEmbeddingFormat_DefaultsToLexical() { AsymmetricTextEmbeddingParameters params = new AsymmetricTextEmbeddingParameters(EmbeddingContentType.QUERY, null); assertEquals(EmbeddingContentType.QUERY, params.getEmbeddingContentType()); - assertEquals(SparseEmbeddingFormat.LEXICAL, params.getSparseEmbeddingFormat()); + assertEquals(SparseEmbeddingFormat.WORD, params.getSparseEmbeddingFormat()); } @Test @@ -184,7 +182,7 @@ public void readInputStream_VersionCompatibility_Pre_V_3_2_0() throws IOExceptio AsymmetricTextEmbeddingParameters parsedParams = new AsymmetricTextEmbeddingParameters(streamInput); assertEquals(EmbeddingContentType.QUERY, parsedParams.getEmbeddingContentType()); - assertEquals(SparseEmbeddingFormat.LEXICAL, parsedParams.getSparseEmbeddingFormat()); + assertEquals(SparseEmbeddingFormat.WORD, parsedParams.getSparseEmbeddingFormat()); } @Test diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/sparse_encoding/SparseEncodingTranslator.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/sparse_encoding/SparseEncodingTranslator.java index 
590aeceb4e..7e26fbec55 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/sparse_encoding/SparseEncodingTranslator.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/sparse_encoding/SparseEncodingTranslator.java @@ -14,7 +14,7 @@ import java.util.List; import java.util.Map; -import org.opensearch.ml.common.input.parameter.textembedding.AsymmetricTextEmbeddingParameters; +import org.opensearch.ml.common.input.parameter.textembedding.SparseEmbeddingFormat; import org.opensearch.ml.common.output.model.ModelTensor; import org.opensearch.ml.common.output.model.ModelTensors; import org.opensearch.ml.engine.algorithms.SentenceTransformerTranslator; @@ -40,9 +40,9 @@ public NDList processInput(TranslatorContext ctx, Input input) { public Output processOutput(TranslatorContext ctx, NDList list) { Output output = new Output(200, "OK"); Object embeddingFormatObject = ctx.getAttachment(SPARSE_EMBEDDING_FORMAT_FIELD); - AsymmetricTextEmbeddingParameters.SparseEmbeddingFormat embeddingFormat = embeddingFormatObject != null - ? AsymmetricTextEmbeddingParameters.SparseEmbeddingFormat.valueOf(embeddingFormatObject.toString()) - : AsymmetricTextEmbeddingParameters.SparseEmbeddingFormat.LEXICAL; + SparseEmbeddingFormat embeddingFormat = embeddingFormatObject != null + ? 
SparseEmbeddingFormat.valueOf(embeddingFormatObject.toString()) + : SparseEmbeddingFormat.WORD; List outputs = new ArrayList<>(); for (NDArray ndArray : list) { @@ -58,11 +58,11 @@ public Output processOutput(TranslatorContext ctx, NDList list) { return output; } - private Object convertOutput(NDArray array, AsymmetricTextEmbeddingParameters.SparseEmbeddingFormat embeddingFormat) { + private Object convertOutput(NDArray array, SparseEmbeddingFormat embeddingFormat) { NDArray nonZeroIndices = array.nonzero().squeeze(); long[] indices = nonZeroIndices.toLongArray(); - if (embeddingFormat == AsymmetricTextEmbeddingParameters.SparseEmbeddingFormat.TOKEN_ID) { + if (embeddingFormat == SparseEmbeddingFormat.TOKEN_ID) { // Return token_id format: {"123": 1.1, "456": 2.2} Map tokenIdWeights = new HashMap<>(); @@ -72,7 +72,7 @@ private Object convertOutput(NDArray array, AsymmetricTextEmbeddingParameters.Sp return tokenIdWeights; } else { - // Return lexical format: {"token": weight, ...} + // Return word format: {"token": weight, ...} Map tokenWeights = new HashMap<>(); for (long index : indices) { String token = this.tokenizer.decode(new long[] { index }, true); diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/tokenize/SparseTokenizerModel.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/tokenize/SparseTokenizerModel.java index 6326a8e772..5219f1cfb5 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/tokenize/SparseTokenizerModel.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/tokenize/SparseTokenizerModel.java @@ -23,6 +23,7 @@ import org.opensearch.ml.common.input.MLInput; import org.opensearch.ml.common.input.parameter.MLAlgoParams; import org.opensearch.ml.common.input.parameter.textembedding.AsymmetricTextEmbeddingParameters; +import org.opensearch.ml.common.input.parameter.textembedding.SparseEmbeddingFormat; import org.opensearch.ml.common.model.MLModelConfig; 
import org.opensearch.ml.common.output.model.ModelTensor; import org.opensearch.ml.common.output.model.ModelTensorOutput; @@ -70,8 +71,7 @@ public ModelTensorOutput predict(String modelId, MLInput mlInput) throws Transla // Get the embedding format from parameters MLAlgoParams parameters = mlInput.getParameters(); - AsymmetricTextEmbeddingParameters.SparseEmbeddingFormat sparseEmbeddingFormat = - AsymmetricTextEmbeddingParameters.SparseEmbeddingFormat.LEXICAL; // default + SparseEmbeddingFormat sparseEmbeddingFormat = SparseEmbeddingFormat.WORD; // default if (parameters instanceof AsymmetricTextEmbeddingParameters) { AsymmetricTextEmbeddingParameters sparseParams = (AsymmetricTextEmbeddingParameters) parameters; @@ -93,7 +93,7 @@ public ModelTensorOutput predict(String modelId, MLInput mlInput) throws Transla if (token.isEmpty()) { continue; } - if (sparseEmbeddingFormat == AsymmetricTextEmbeddingParameters.SparseEmbeddingFormat.TOKEN_ID) { + if (sparseEmbeddingFormat == SparseEmbeddingFormat.TOKEN_ID) { tokenWeights.put(String.valueOf(uniqueIndices[i]), idf.getOrDefault(token, 1.0f)); } else { tokenWeights.put(token, idf.getOrDefault(token, 1.0f)); diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/analysis/HFModelTokenizer.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/analysis/HFModelTokenizer.java index 52715233fd..53cc56bfc5 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/analysis/HFModelTokenizer.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/analysis/HFModelTokenizer.java @@ -15,7 +15,9 @@ import org.apache.lucene.analysis.tokenattributes.CharTermAttribute; import org.apache.lucene.analysis.tokenattributes.OffsetAttribute; import org.apache.lucene.analysis.tokenattributes.PayloadAttribute; +import org.apache.lucene.analysis.tokenattributes.TypeAttribute; import org.apache.lucene.util.BytesRef; +import org.opensearch.ml.common.input.parameter.textembedding.SparseEmbeddingFormat; import 
com.google.common.io.CharStreams; @@ -35,6 +37,7 @@ public class HFModelTokenizer extends Tokenizer { private final CharTermAttribute termAtt; private final PayloadAttribute payloadAtt; private final OffsetAttribute offsetAtt; + private final TypeAttribute typeAtt; private final Supplier tokenizerSupplier; private final Supplier> tokenWeightsSupplier; @@ -49,6 +52,7 @@ public HFModelTokenizer(Supplier huggingFaceTokenizerSuppl public HFModelTokenizer(Supplier huggingFaceTokenizerSupplier, Supplier> weightsSupplier) { termAtt = addAttribute(CharTermAttribute.class); offsetAtt = addAttribute(OffsetAttribute.class); + typeAtt = addAttribute(TypeAttribute.class); if (Objects.nonNull(weightsSupplier)) { payloadAtt = addAttribute(PayloadAttribute.class); } else { @@ -81,9 +85,19 @@ public static float bytesToFloat(byte[] bytes) { return ByteBuffer.wrap(bytes).getFloat(); } + /** + * Clears all attributes except type. The type selects the sparse embedding format for emitted terms, + * so its current value is saved before clearAttributes() and restored afterwards for every token. + */ + private void clearAttributesExceptType() { + String type = typeAtt.type(); + clearAttributes(); + typeAtt.setType(type); + } + @Override final public boolean incrementToken() throws IOException { - clearAttributes(); + clearAttributesExceptType(); if (Objects.isNull(encoding)) return false; Encoding curEncoding = overflowingIdx == -1 ? 
encoding : encoding.getOverflowing()[overflowingIdx]; @@ -99,7 +113,12 @@ final public boolean incrementToken() throws IOException { } curEncoding = encoding.getOverflowing()[overflowingIdx]; } else { - termAtt.append(curEncoding.getTokens()[tokenIdx]); + SparseEmbeddingFormat sparseEmbeddingFormat = SparseEmbeddingFormat.valueOf(typeAtt.type().toUpperCase(java.util.Locale.ROOT)); + if (sparseEmbeddingFormat == SparseEmbeddingFormat.WORD) { + termAtt.append(curEncoding.getTokens()[tokenIdx]); + } else { + termAtt.append(String.valueOf(curEncoding.getIds()[tokenIdx])); + } offsetAtt .setOffset(curEncoding.getCharTokenSpans()[tokenIdx].getStart(), curEncoding.getCharTokenSpans()[tokenIdx].getEnd()); if (Objects.nonNull(tokenWeightsSupplier)) { diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/sparse_encoding/TextEmbeddingSparseEncodingModelTest.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/sparse_encoding/TextEmbeddingSparseEncodingModelTest.java index 484644ae99..4dabe79dac 100644 --- a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/sparse_encoding/TextEmbeddingSparseEncodingModelTest.java +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/sparse_encoding/TextEmbeddingSparseEncodingModelTest.java @@ -31,7 +31,7 @@ import org.opensearch.ml.common.exception.MLException; import org.opensearch.ml.common.input.MLInput; import org.opensearch.ml.common.input.parameter.textembedding.AsymmetricTextEmbeddingParameters; -import org.opensearch.ml.common.input.parameter.textembedding.AsymmetricTextEmbeddingParameters.SparseEmbeddingFormat; +import org.opensearch.ml.common.input.parameter.textembedding.SparseEmbeddingFormat; import org.opensearch.ml.common.model.MLModelFormat; import org.opensearch.ml.common.model.MLModelState; import org.opensearch.ml.common.output.model.ModelResultFilter; @@ -190,14 +190,14 @@ public void initModel_predict_TorchScript_SparseEncoding_ResultFilter() { 
textEmbeddingSparseEncodingModel.close(); } - // Test AsymmetricTextEmbeddingParameters with LEXICAL format + // Test AsymmetricTextEmbeddingParameters with WORD format @Test public void initModel_predict_SparseEncoding_WithLexicalFormat() { textEmbeddingSparseEncodingModel.initModel(model, params, encryptor); AsymmetricTextEmbeddingParameters parameters = AsymmetricTextEmbeddingParameters .builder() - .sparseEmbeddingFormat(SparseEmbeddingFormat.LEXICAL) + .sparseEmbeddingFormat(SparseEmbeddingFormat.WORD) .build(); MLInput mlInput = MLInput @@ -307,7 +307,7 @@ public void test_isAsymmetricModel_ReturnsFalse() { AsymmetricTextEmbeddingParameters parameters = AsymmetricTextEmbeddingParameters .builder() .embeddingContentType(AsymmetricTextEmbeddingParameters.EmbeddingContentType.QUERY) - .sparseEmbeddingFormat(SparseEmbeddingFormat.LEXICAL) + .sparseEmbeddingFormat(SparseEmbeddingFormat.WORD) .build(); // Test that isAsymmetricModel returns false even with AsymmetricTextEmbeddingParameters diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/tokenize/SparseTokenizerModelTest.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/tokenize/SparseTokenizerModelTest.java index fb9ae65fbc..79130b62ab 100644 --- a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/tokenize/SparseTokenizerModelTest.java +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/tokenize/SparseTokenizerModelTest.java @@ -28,7 +28,7 @@ import org.opensearch.ml.common.exception.MLException; import org.opensearch.ml.common.input.MLInput; import org.opensearch.ml.common.input.parameter.textembedding.AsymmetricTextEmbeddingParameters; -import org.opensearch.ml.common.input.parameter.textembedding.AsymmetricTextEmbeddingParameters.SparseEmbeddingFormat; +import org.opensearch.ml.common.input.parameter.textembedding.SparseEmbeddingFormat; import org.opensearch.ml.common.model.MLModelConfig; import 
org.opensearch.ml.common.model.MLModelFormat; import org.opensearch.ml.common.model.MLModelState; @@ -166,7 +166,7 @@ public void initModel_predict_Tokenize() throws URISyntaxException, TranslateExc } } - // Test default LEXICAL format (no parameters provided) + // Test default WORD format (no parameters provided) @Test public void initModel_predict_Tokenize_DefaultLexicalFormat() throws URISyntaxException, TranslateException { sparseTokenizerModel.initModel(model, params, encryptor); @@ -176,7 +176,7 @@ public void initModel_predict_Tokenize_DefaultLexicalFormat() throws URISyntaxEx List mlModelOutputs = output.getMlModelOutputs(); assertEquals(2, mlModelOutputs.size()); - // Verify output format is LEXICAL (token strings as keys) + // Verify output format is WORD (token strings as keys) for (ModelTensors tensors : mlModelOutputs) { List mlModelTensors = tensors.getMlModelTensors(); assertEquals(1, mlModelTensors.size()); @@ -192,14 +192,14 @@ public void initModel_predict_Tokenize_DefaultLexicalFormat() throws URISyntaxEx } } - // Test LEXICAL format with explicit parameter + // Test WORD format with explicit parameter @Test public void initModel_predict_Tokenize_WithLexicalFormat() throws URISyntaxException, TranslateException { sparseTokenizerModel.initModel(model, params, encryptor); AsymmetricTextEmbeddingParameters parameters = AsymmetricTextEmbeddingParameters .builder() - .sparseEmbeddingFormat(SparseEmbeddingFormat.LEXICAL) + .sparseEmbeddingFormat(SparseEmbeddingFormat.WORD) .build(); MLInput mlInput = MLInput @@ -214,7 +214,7 @@ public void initModel_predict_Tokenize_WithLexicalFormat() throws URISyntaxExcep List mlModelOutputs = output.getMlModelOutputs(); assertEquals(2, mlModelOutputs.size()); - // Verify output format is LEXICAL + // Verify output format is WORD for (ModelTensors tensors : mlModelOutputs) { List mlModelTensors = tensors.getMlModelTensors(); assertEquals(1, mlModelTensors.size()); @@ -225,7 +225,7 @@ public void 
initModel_predict_Tokenize_WithLexicalFormat() throws URISyntaxExcep // Verify keys are token strings for (String key : result.keySet()) { - assertTrue("Key should be a token string for LEXICAL format", !isNumeric(key)); + assertTrue("Key should be a token string for WORD format", !isNumeric(key)); } } } From 3ad0c36c72c7a16bb636da8a2356fd4e6ca1a76b Mon Sep 17 00:00:00 2001 From: zhichao-aws Date: Mon, 14 Jul 2025 17:06:49 +0800 Subject: [PATCH 10/10] add UT Signed-off-by: zhichao-aws --- .../analysis/HFModelTokenizerTests.java | 112 ++++++++++++++++++ 1 file changed, 112 insertions(+) diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/analysis/HFModelTokenizerTests.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/analysis/HFModelTokenizerTests.java index 1a88d27bf8..aea5332e02 100644 --- a/ml-algorithms/src/test/java/org/opensearch/ml/engine/analysis/HFModelTokenizerTests.java +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/analysis/HFModelTokenizerTests.java @@ -9,6 +9,7 @@ import static org.junit.Assert.assertNotNull; import static org.junit.Assert.assertNull; import static org.junit.Assert.assertTrue; +import static org.junit.Assert.fail; import java.io.StringReader; import java.util.HashMap; @@ -16,6 +17,7 @@ import org.apache.lucene.analysis.tokenattributes.CharTermAttribute; import org.apache.lucene.analysis.tokenattributes.PayloadAttribute; +import org.apache.lucene.analysis.tokenattributes.TypeAttribute; import org.apache.lucene.util.BytesRef; import org.junit.Before; import org.junit.Test; @@ -74,24 +76,28 @@ public void testTokenizeWithWeights() { CharTermAttribute termAtt = tokenizer.addAttribute(CharTermAttribute.class); PayloadAttribute payloadAtt = tokenizer.addAttribute(PayloadAttribute.class); + TypeAttribute typeAtt = tokenizer.addAttribute(TypeAttribute.class); assertTrue(tokenizer.incrementToken()); assertEquals("hello", termAtt.toString()); BytesRef payload = payloadAtt.getPayload(); assertNotNull(payload); 
assertEquals(0.5f, HFModelTokenizer.bytesToFloat(payload.bytes), 0.0001f); + assertEquals("word", typeAtt.type()); assertTrue(tokenizer.incrementToken()); assertEquals("world", termAtt.toString()); payload = payloadAtt.getPayload(); assertNotNull(payload); assertEquals(0.3f, HFModelTokenizer.bytesToFloat(payload.bytes), 0.0001f); + assertEquals("word", typeAtt.type()); assertTrue(tokenizer.incrementToken()); assertEquals("a", termAtt.toString()); payload = payloadAtt.getPayload(); assertNotNull(payload); assertEquals(1f, HFModelTokenizer.bytesToFloat(payload.bytes), 0f); + assertEquals("word", typeAtt.type()); // No more tokens assertFalse(tokenizer.incrementToken()); @@ -125,4 +131,110 @@ public void testFloatBytesConversion() { float convertedValue = HFModelTokenizer.bytesToFloat(bytes); assertEquals(originalValue, convertedValue, 0f); } + + @SneakyThrows + @Test + public void testTokenizeWithTypeWord() { + HFModelTokenizer tokenizer = new HFModelTokenizer(() -> huggingFaceTokenizer); + tokenizer.setReader(new StringReader("hello world")); + tokenizer.reset(); + + // Add and set the type attribute after reset, before consuming any tokens + TypeAttribute typeAtt = tokenizer.addAttribute(TypeAttribute.class); + typeAtt.setType("word"); + CharTermAttribute termAtt = tokenizer.addAttribute(CharTermAttribute.class); + + assertTrue(tokenizer.incrementToken()); + assertEquals("hello", termAtt.toString()); + assertEquals("word", typeAtt.type()); + + assertTrue(tokenizer.incrementToken()); + assertEquals("world", termAtt.toString()); + assertEquals("word", typeAtt.type()); + + // No more tokens + assertFalse(tokenizer.incrementToken()); + } + + @SneakyThrows + @Test + public void testTokenizeWithTypeTokenId() { + HFModelTokenizer tokenizer = new HFModelTokenizer(() -> huggingFaceTokenizer); + tokenizer.setReader(new StringReader("hello world")); + tokenizer.reset(); + + // Add and set the type attribute after reset, before consuming any tokens + TypeAttribute typeAtt = tokenizer.addAttribute(TypeAttribute.class); + 
typeAtt.setType("token_id"); + CharTermAttribute termAtt = tokenizer.addAttribute(CharTermAttribute.class); + + assertTrue(tokenizer.incrementToken()); + String firstToken = termAtt.toString(); + assertEquals("7592", firstToken); + + assertTrue(tokenizer.incrementToken()); + String secondToken = termAtt.toString(); + assertEquals("2088", secondToken); + + // No more tokens + assertFalse(tokenizer.incrementToken()); + } + + @SneakyThrows + @Test + public void testTokenizeWithTypeTokenIdAndWeights() { + HFModelTokenizer tokenizer = new HFModelTokenizer(() -> huggingFaceTokenizer, () -> tokenWeights); + tokenizer.setReader(new StringReader("hello world")); + tokenizer.reset(); + + // Add and set type attribute after reset + TypeAttribute typeAtt = tokenizer.addAttribute(TypeAttribute.class); + typeAtt.setType("token_id"); + CharTermAttribute termAtt = tokenizer.addAttribute(CharTermAttribute.class); + PayloadAttribute payloadAtt = tokenizer.addAttribute(PayloadAttribute.class); + + assertTrue(tokenizer.incrementToken()); + String firstToken = termAtt.toString(); + assertEquals("7592", firstToken); + BytesRef payload = payloadAtt.getPayload(); + assertNotNull(payload); + assertEquals(0.5f, HFModelTokenizer.bytesToFloat(payload.bytes), 0.0001f); + assertEquals("token_id", typeAtt.type()); + + assertTrue(tokenizer.incrementToken()); + String secondToken = termAtt.toString(); + assertEquals("2088", secondToken); + payload = payloadAtt.getPayload(); + assertNotNull(payload); + assertEquals(0.3f, HFModelTokenizer.bytesToFloat(payload.bytes), 0.0001f); + assertEquals("token_id", typeAtt.type()); + + // No more tokens + assertFalse(tokenizer.incrementToken()); + } + + @SneakyThrows + @Test + public void testTokenizeWithInvalidType() { + HFModelTokenizer tokenizer = new HFModelTokenizer(() -> huggingFaceTokenizer); + tokenizer.setReader(new StringReader("hello world")); + + // Add and set invalid type attribute before reset + TypeAttribute
typeAtt = tokenizer.addAttribute(TypeAttribute.class); + typeAtt.setType("invalid_type"); + CharTermAttribute termAtt = tokenizer.addAttribute(CharTermAttribute.class); + + tokenizer.reset(); + + // Should throw IllegalArgumentException when incrementToken() is called with invalid type + try { + tokenizer.incrementToken(); + fail("Expected IllegalArgumentException for invalid type"); + } catch (IllegalArgumentException e) { + assertTrue( + "Exception message should contain enum error", + e.getMessage().contains("No enum constant") && e.getMessage().contains("SparseEmbeddingFormat.INVALID_TYPE") + ); + } + } }