Skip to content
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,9 @@

import java.io.IOException;
import java.util.Locale;
import java.util.Objects;

import org.opensearch.Version;
import org.opensearch.core.ParseField;
import org.opensearch.core.common.io.stream.StreamInput;
import org.opensearch.core.common.io.stream.StreamOutput;
Expand All @@ -31,9 +33,11 @@
* <p>
* Use this parameter only if the model is asymmetric and has been registered with the corresponding
* `query_prefix` and `passage_prefix` configuration parameters.
* <p>
* Also supports embedding format control for sparse encoding algorithms.
*/
@Data
@MLAlgoParameter(algorithms = { FunctionName.TEXT_EMBEDDING })
@MLAlgoParameter(algorithms = { FunctionName.TEXT_EMBEDDING, FunctionName.SPARSE_ENCODING, FunctionName.SPARSE_TOKENIZE })
public class AsymmetricTextEmbeddingParameters implements MLAlgoParams {

public enum EmbeddingContentType {
Expand All @@ -47,18 +51,44 @@ public enum EmbeddingContentType {
new ParseField(PARSE_FIELD_NAME),
it -> parse(it)
);
// Registry entry that maps the SPARSE_ENCODING function name to this class's parse(XContentParser),
// so MLAlgoParams declared under that name are parsed as AsymmetricTextEmbeddingParameters.
public static final NamedXContentRegistry.Entry XCONTENT_REGISTRY_SPARSE_ENCODING = new NamedXContentRegistry.Entry(
MLAlgoParams.class,
new ParseField(FunctionName.SPARSE_ENCODING.name()),
it -> parse(it)
);
// Same mapping for the SPARSE_TOKENIZE function name; it shares the parser above.
public static final NamedXContentRegistry.Entry XCONTENT_REGISTRY_SPARSE_TOKENIZE = new NamedXContentRegistry.Entry(
MLAlgoParams.class,
new ParseField(FunctionName.SPARSE_TOKENIZE.name()),
it -> parse(it)
);

/**
 * Creates parameters with an explicit content type and sparse embedding format.
 * This is also the constructor the generated builder invokes.
 *
 * @param embeddingContentType  type of the content to embed; may be {@code null}
 * @param sparseEmbeddingFormat requested sparse output format; {@code null} is replaced by
 *                              {@link SparseEmbeddingFormat#WORD}
 */
@Builder(toBuilder = true)
public AsymmetricTextEmbeddingParameters(EmbeddingContentType embeddingContentType, SparseEmbeddingFormat sparseEmbeddingFormat) {
    this.embeddingContentType = embeddingContentType;
    // The format field is never left null by this constructor: absent means WORD.
    if (sparseEmbeddingFormat == null) {
        this.sparseEmbeddingFormat = SparseEmbeddingFormat.WORD;
    } else {
        this.sparseEmbeddingFormat = sparseEmbeddingFormat;
    }
}

/**
 * Single-argument constructor kept for backward compatibility with callers that
 * predate the sparse embedding format option; the format is fixed to
 * {@link SparseEmbeddingFormat#WORD}.
 *
 * @param embeddingContentType type of the content to embed; may be {@code null}
 */
public AsymmetricTextEmbeddingParameters(EmbeddingContentType embeddingContentType) {
    this(embeddingContentType, SparseEmbeddingFormat.WORD);
}

/**
 * Deserializes parameters from a transport stream.
 * <p>
 * The content type is read as an optional string and may be absent ({@code null}).
 * The sparse embedding format is only present on the wire when the stream version is
 * {@code V_3_2_0} or later; for older streams, or when the value is absent, the format
 * is {@link SparseEmbeddingFormat#WORD}.
 *
 * @param in stream to read from; its version decides whether the format field is read
 * @throws IOException if reading from the stream fails
 */
public AsymmetricTextEmbeddingParameters(StreamInput in) throws IOException {
    Version streamInputVersion = in.getVersion();

    // Read the content type exactly once. A second unconditional read would consume the
    // next field and misalign the stream, and valueOf(null) would throw a NullPointerException.
    String contentType = in.readOptionalString();
    this.embeddingContentType = contentType != null ? EmbeddingContentType.valueOf(contentType) : null;

    if (streamInputVersion.onOrAfter(Version.V_3_2_0)) {
        String formatName = in.readOptionalString();
        this.sparseEmbeddingFormat = formatName != null ? SparseEmbeddingFormat.valueOf(formatName) : SparseEmbeddingFormat.WORD;
    } else {
        // Senders older than 3.2.0 never wrote the format field.
        this.sparseEmbeddingFormat = SparseEmbeddingFormat.WORD;
    }
}

public static MLAlgoParams parse(XContentParser parser) throws IOException {
EmbeddingContentType embeddingContentType = null;
SparseEmbeddingFormat sparseEmbeddingFormat = null;

ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser);
while (parser.nextToken() != XContentParser.Token.END_OBJECT) {
Expand All @@ -70,19 +100,27 @@ public static MLAlgoParams parse(XContentParser parser) throws IOException {
String contentType = parser.text();
embeddingContentType = EmbeddingContentType.valueOf(contentType.toUpperCase(Locale.ROOT));
break;
case SPARSE_EMBEDDING_FORMAT_FIELD:
String formatType = parser.text();
sparseEmbeddingFormat = SparseEmbeddingFormat.valueOf(formatType.toUpperCase(Locale.ROOT));
break;
default:
parser.skipChildren();
break;
}
}
return new AsymmetricTextEmbeddingParameters(embeddingContentType);
return new AsymmetricTextEmbeddingParameters(embeddingContentType, sparseEmbeddingFormat);
}

// XContent field name under which the content type is written by toXContent.
public static final String EMBEDDING_CONTENT_TYPE_FIELD = "content_type";
// XContent field name for the sparse embedding format; matched by parse and written by toXContent.
public static final String SPARSE_EMBEDDING_FORMAT_FIELD = "sparse_embedding_format";

// The type of the content to be embedded. May be null: the stream constructor, writeTo and
// toXContent all handle an absent value.
private EmbeddingContentType embeddingContentType;

// The format of the embedding output. The constructors in this class substitute
// SparseEmbeddingFormat.WORD when no format is supplied.
private SparseEmbeddingFormat sparseEmbeddingFormat;

@Override
public int getVersion() {
return 1;
Expand All @@ -95,7 +133,11 @@ public String getWriteableName() {

/**
 * Serializes these parameters to a transport stream.
 * <p>
 * The content type is always written as an optional string ({@code null} when unset).
 * The sparse embedding format is written only when the stream version is
 * {@code V_3_2_0} or later, mirroring what the {@code StreamInput} constructor reads.
 *
 * @param out stream to write to; its version decides whether the format field is written
 * @throws IOException if writing to the stream fails
 */
@Override
public void writeTo(StreamOutput out) throws IOException {
    Version streamOutputVersion = out.getVersion();

    // Write the content type exactly once and null-safely. A second unguarded write
    // would duplicate the field on the wire and throw when the content type is unset.
    out.writeOptionalString(embeddingContentType != null ? embeddingContentType.name() : null);

    if (streamOutputVersion.onOrAfter(Version.V_3_2_0)) {
        out.writeOptionalString(sparseEmbeddingFormat != null ? sparseEmbeddingFormat.name() : null);
    }
}

@Override
Expand All @@ -104,11 +146,29 @@ public XContentBuilder toXContent(XContentBuilder xContentBuilder, Params params
if (embeddingContentType != null) {
xContentBuilder.field(EMBEDDING_CONTENT_TYPE_FIELD, embeddingContentType.name());
}
xContentBuilder.field(SPARSE_EMBEDDING_FORMAT_FIELD, sparseEmbeddingFormat.name());
xContentBuilder.endObject();
return xContentBuilder;
}

/**
 * Returns the type of content to be embedded.
 *
 * @return the content type, or {@code null} when none was supplied
 */
public EmbeddingContentType getEmbeddingContentType() {
return embeddingContentType;
}

/**
 * Returns the requested format of the sparse embedding output.
 *
 * @return the sparse embedding format; the constructors in this class set it to
 *         {@link SparseEmbeddingFormat#WORD} when no format is supplied
 */
public SparseEmbeddingFormat getSparseEmbeddingFormat() {
return sparseEmbeddingFormat;
}

/**
 * Two parameter objects are equal when they are of the same class and carry the same
 * content type and sparse embedding format.
 *
 * @param obj object to compare with
 * @return {@code true} if {@code obj} is an equal parameter object
 */
@Override
public boolean equals(Object obj) {
    if (this == obj) {
        return true;
    }
    if (obj == null || getClass() != obj.getClass()) {
        return false;
    }
    AsymmetricTextEmbeddingParameters other = (AsymmetricTextEmbeddingParameters) obj;
    return Objects.equals(embeddingContentType, other.embeddingContentType)
        && Objects.equals(sparseEmbeddingFormat, other.sparseEmbeddingFormat);
}

/**
 * Hash code derived from the same fields {@link #equals(Object)} compares.
 * It must be declared explicitly: once {@code equals} is hand-written, Lombok's
 * {@code @Data} no longer generates {@code hashCode}, and the inherited identity hash
 * would give equal objects different hash codes.
 *
 * @return hash of the content type and sparse embedding format
 */
@Override
public int hashCode() {
    return Objects.hash(embeddingContentType, sparseEmbeddingFormat);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.ml.common.input.parameter.textembedding;

/**
 * Enum defining the format of sparse embeddings.
 * <p>
 * AsymmetricTextEmbeddingParameters parses these constants case-insensitively from the
 * {@code sparse_embedding_format} field and serializes them by {@code name()}, so renaming
 * a constant changes the wire and request format.
 */
public enum SparseEmbeddingFormat {
// Default format: AsymmetricTextEmbeddingParameters falls back to it when none is supplied.
// NOTE(review): presumably keys each sparse entry by the token's text; confirm against the model translator.
WORD,
// NOTE(review): presumably keys each sparse entry by the tokenizer's numeric token id; confirm against the model translator.
TOKEN_ID
}
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package org.opensearch.ml.common.dataset;

import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNull;
import static org.junit.Assert.assertTrue;
import static org.opensearch.ml.common.TestHelper.contentObjectToString;
import static org.opensearch.ml.common.TestHelper.testParseFromString;

Expand All @@ -11,12 +12,14 @@
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.ExpectedException;
import org.opensearch.Version;
import org.opensearch.common.io.stream.BytesStreamOutput;
import org.opensearch.core.common.io.stream.StreamInput;
import org.opensearch.core.xcontent.XContentParser;
import org.opensearch.ml.common.TestHelper;
import org.opensearch.ml.common.input.parameter.textembedding.AsymmetricTextEmbeddingParameters;
import org.opensearch.ml.common.input.parameter.textembedding.AsymmetricTextEmbeddingParameters.EmbeddingContentType;
import org.opensearch.ml.common.input.parameter.textembedding.SparseEmbeddingFormat;

public class AsymmetricTextEmbeddingParametersTest {

Expand Down Expand Up @@ -74,6 +77,127 @@ public void readInputStream_Success_EmptyParams() throws IOException {
readInputStream(AsymmetricTextEmbeddingParameters.builder().embeddingContentType(EmbeddingContentType.PASSAGE).build());
}

// Round-trips QUERY parameters that carry the WORD sparse format through parsing.
@Test
public void parse_AsymmetricTextEmbeddingParameters_WithSparseEmbeddingFormat_LEXICAL() throws IOException {
    TestHelper
        .testParse(
            AsymmetricTextEmbeddingParameters
                .builder()
                .embeddingContentType(EmbeddingContentType.QUERY)
                .sparseEmbeddingFormat(SparseEmbeddingFormat.WORD)
                .build(),
            function
        );
}

// Round-trips PASSAGE parameters that carry the TOKEN_ID sparse format through parsing.
@Test
public void parse_AsymmetricTextEmbeddingParameters_WithSparseEmbeddingFormat_TOKEN_ID() throws IOException {
    TestHelper
        .testParse(
            AsymmetricTextEmbeddingParameters
                .builder()
                .embeddingContentType(EmbeddingContentType.PASSAGE)
                .sparseEmbeddingFormat(SparseEmbeddingFormat.TOKEN_ID)
                .build(),
            function
        );
}

// Round-trips parameters that set only the sparse format and leave the content type unset.
@Test
public void parse_AsymmetricTextEmbeddingParameters_OnlySparseEmbeddingFormat() throws IOException {
    TestHelper
        .testParse(
            AsymmetricTextEmbeddingParameters.builder().sparseEmbeddingFormat(SparseEmbeddingFormat.TOKEN_ID).build(),
            function
        );
}

// An unknown sparse_embedding_format value must surface the enum's IllegalArgumentException.
@Test
public void parse_AsymmetricTextEmbeddingParameters_SparseEmbeddingFormat_Invalid() throws IOException {
    exceptionRule.expect(IllegalArgumentException.class);
    exceptionRule.expectMessage("No enum constant org.opensearch.ml.common.input.parameter.textembedding.SparseEmbeddingFormat.INVALID");

    String malformedJson = "{\"content_type\": \"QUERY\", \"sparse_embedding_format\": \"INVALID\"}";
    testParseFromString(params, malformedJson, function);
}

// The legacy single-argument constructor keeps the content type and defaults the format to WORD.
@Test
public void constructor_BackwardCompatibility() {
    AsymmetricTextEmbeddingParameters legacy = new AsymmetricTextEmbeddingParameters(EmbeddingContentType.QUERY);

    assertEquals(EmbeddingContentType.QUERY, legacy.getEmbeddingContentType());
    assertEquals(SparseEmbeddingFormat.WORD, legacy.getSparseEmbeddingFormat());
}

// The two-argument constructor stores both values exactly as given.
@Test
public void constructor_WithSparseEmbeddingFormat() {
    EmbeddingContentType givenType = EmbeddingContentType.PASSAGE;
    SparseEmbeddingFormat givenFormat = SparseEmbeddingFormat.TOKEN_ID;

    AsymmetricTextEmbeddingParameters built = new AsymmetricTextEmbeddingParameters(givenType, givenFormat);

    assertEquals(EmbeddingContentType.PASSAGE, built.getEmbeddingContentType());
    assertEquals(SparseEmbeddingFormat.TOKEN_ID, built.getSparseEmbeddingFormat());
}

// Passing a null format to the two-argument constructor yields the WORD default.
@Test
public void constructor_WithNullSparseEmbeddingFormat_DefaultsToLexical() {
    AsymmetricTextEmbeddingParameters defaulted = new AsymmetricTextEmbeddingParameters(EmbeddingContentType.QUERY, null);

    assertEquals(EmbeddingContentType.QUERY, defaulted.getEmbeddingContentType());
    assertEquals(SparseEmbeddingFormat.WORD, defaulted.getSparseEmbeddingFormat());
}

// A null content type is preserved as null while the explicit format is kept.
@Test
public void constructor_NullContentType_WithSparseEmbeddingFormat() {
    AsymmetricTextEmbeddingParameters formatOnly = new AsymmetricTextEmbeddingParameters(null, SparseEmbeddingFormat.TOKEN_ID);

    assertNull(formatOnly.getEmbeddingContentType());
    assertEquals(SparseEmbeddingFormat.TOKEN_ID, formatOnly.getSparseEmbeddingFormat());
}

// Stream round-trip for parameters carrying both a content type and a TOKEN_ID format.
@Test
public void readInputStream_WithSparseEmbeddingFormat() throws IOException {
    readInputStream(
        AsymmetricTextEmbeddingParameters
            .builder()
            .embeddingContentType(EmbeddingContentType.PASSAGE)
            .sparseEmbeddingFormat(SparseEmbeddingFormat.TOKEN_ID)
            .build()
    );
}

// Stream round-trip for parameters that set only the sparse format.
@Test
public void readInputStream_OnlySparseEmbeddingFormat() throws IOException {
    readInputStream(AsymmetricTextEmbeddingParameters.builder().sparseEmbeddingFormat(SparseEmbeddingFormat.TOKEN_ID).build());
}

// On a pre-3.2.0 stream the format is not transmitted, so a TOKEN_ID sender is read back as WORD.
@Test
public void readInputStream_VersionCompatibility_Pre_V_3_2_0() throws IOException {
    AsymmetricTextEmbeddingParameters sent = AsymmetricTextEmbeddingParameters
        .builder()
        .embeddingContentType(EmbeddingContentType.QUERY)
        .sparseEmbeddingFormat(SparseEmbeddingFormat.TOKEN_ID)
        .build();

    // Serialize as if talking to a 3.1.0 node.
    BytesStreamOutput legacyOut = new BytesStreamOutput();
    legacyOut.setVersion(Version.V_3_1_0);
    sent.writeTo(legacyOut);

    // Deserialize with the same legacy version.
    StreamInput legacyIn = legacyOut.bytes().streamInput();
    legacyIn.setVersion(Version.V_3_1_0);
    AsymmetricTextEmbeddingParameters received = new AsymmetricTextEmbeddingParameters(legacyIn);

    assertEquals(EmbeddingContentType.QUERY, received.getEmbeddingContentType());
    assertEquals(SparseEmbeddingFormat.WORD, received.getSparseEmbeddingFormat());
}

// The XContent output must contain both the content type and the sparse embedding format.
@Test
public void toXContent_IncludesSparseEmbeddingFormat() throws IOException {
    AsymmetricTextEmbeddingParameters params = AsymmetricTextEmbeddingParameters
        .builder()
        .embeddingContentType(EmbeddingContentType.QUERY)
        .sparseEmbeddingFormat(SparseEmbeddingFormat.TOKEN_ID)
        .build();

    String jsonStr = contentObjectToString(params);

    // JUnit assertions instead of the `assert` keyword: `assert` is skipped when the JVM
    // runs without -ea, which would let this test pass without checking anything.
    assertTrue("content_type missing from: " + jsonStr, jsonStr.contains("\"content_type\":\"QUERY\""));
    assertTrue("sparse_embedding_format missing from: " + jsonStr, jsonStr.contains("\"sparse_embedding_format\":\"TOKEN_ID\""));
}

private void readInputStream(AsymmetricTextEmbeddingParameters params) throws IOException {
BytesStreamOutput bytesStreamOutput = new BytesStreamOutput();
params.writeTo(bytesStreamOutput);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import ai.djl.translate.TranslateException;

public abstract class TextEmbeddingModel extends DLModel {
protected boolean isSparseModel = false;

@Override
public ModelTensorOutput predict(String modelId, MLInput mlInput) throws TranslateException {
Expand All @@ -40,13 +41,18 @@ public ModelTensorOutput predict(String modelId, MLInput mlInput) throws Transla
for (String doc : textDocsInput.getDocs()) {
Input input = new Input();
input.add(doc);
if (mlParams instanceof AsymmetricTextEmbeddingParameters) {
AsymmetricTextEmbeddingParameters params = (AsymmetricTextEmbeddingParameters) mlParams;
input.add(AsymmetricTextEmbeddingParameters.SPARSE_EMBEDDING_FORMAT_FIELD, params.getSparseEmbeddingFormat().name());
}

output = getPredictor().predict(input);
tensorOutputs.add(parseModelTensorOutput(output, resultFilter));
}
return new ModelTensorOutput(tensorOutputs);
}

private boolean isAsymmetricModel(MLAlgoParams mlParams) {
protected boolean isAsymmetricModel(MLAlgoParams mlParams) {
if (mlParams instanceof AsymmetricTextEmbeddingParameters) {
// Check for the necessary prefixes in modelConfig
if (modelConfig == null
Expand Down
Loading
Loading