Skip to content

Commit 3fc1ed6

Browse files
committed
Use cassandra-java-driver's QueryBuidler vector support
Salvaged changes from #1817 Fixes #1817 Signed-off-by: Eric Bottard <[email protected]>
1 parent 107ab68 commit 3fc1ed6

File tree

1 file changed

+39
-76
lines changed

1 file changed

+39
-76
lines changed

vector-stores/spring-ai-cassandra-store/src/main/java/org/springframework/ai/vectorstore/cassandra/CassandraVectorStore.java

Lines changed: 39 additions & 76 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@
3939
import com.datastax.oss.driver.api.core.cql.BoundStatement;
4040
import com.datastax.oss.driver.api.core.cql.BoundStatementBuilder;
4141
import com.datastax.oss.driver.api.core.cql.PreparedStatement;
42+
import com.datastax.oss.driver.api.core.cql.ResultSet;
4243
import com.datastax.oss.driver.api.core.cql.Row;
4344
import com.datastax.oss.driver.api.core.cql.SimpleStatement;
4445
import com.datastax.oss.driver.api.core.data.CqlVector;
@@ -49,7 +50,6 @@
4950
import com.datastax.oss.driver.api.core.type.DataTypes;
5051
import com.datastax.oss.driver.api.core.type.codec.registry.CodecRegistry;
5152
import com.datastax.oss.driver.api.core.type.reflect.GenericType;
52-
import com.datastax.oss.driver.api.querybuilder.BuildableQuery;
5353
import com.datastax.oss.driver.api.querybuilder.QueryBuilder;
5454
import com.datastax.oss.driver.api.querybuilder.SchemaBuilder;
5555
import com.datastax.oss.driver.api.querybuilder.delete.Delete;
@@ -60,6 +60,8 @@
6060
import com.datastax.oss.driver.api.querybuilder.schema.AlterTableAddColumnEnd;
6161
import com.datastax.oss.driver.api.querybuilder.schema.CreateTable;
6262
import com.datastax.oss.driver.api.querybuilder.schema.CreateTableStart;
63+
import com.datastax.oss.driver.api.querybuilder.select.Select;
64+
import com.datastax.oss.driver.api.querybuilder.select.Selector;
6365
import com.datastax.oss.driver.shaded.guava.common.annotations.VisibleForTesting;
6466
import com.datastax.oss.driver.shaded.guava.common.base.Preconditions;
6567
import org.slf4j.Logger;
@@ -191,8 +193,6 @@ public class CassandraVectorStore extends AbstractObservationVectorStore impleme
191193

192194
public static final String DRIVER_PROFILE_SEARCH = "spring-ai-search";
193195

194-
private static final String QUERY_FORMAT = "select %s,%s,%s%s from %s.%s ? order by %s ann of ? limit ?";
195-
196196
private static final Logger logger = LoggerFactory.getLogger(CassandraVectorStore.class);
197197

198198
private static final Map<Similarity, VectorStoreSimilarityMetric> SIMILARITY_TYPE_MAPPING = Map.of(
@@ -219,8 +219,6 @@ public class CassandraVectorStore extends AbstractObservationVectorStore impleme
219219

220220
private final PreparedStatement deleteStmt;
221221

222-
private final String similarityStmt;
223-
224222
private final Similarity similarity;
225223

226224
protected CassandraVectorStore(Builder builder) {
@@ -247,7 +245,6 @@ protected CassandraVectorStore(Builder builder) {
247245
.get();
248246

249247
this.similarity = getIndexSimilarity(cassandraMetadata);
250-
this.similarityStmt = similaritySearchStatement();
251248

252249
this.filterExpressionConverter = builder.filterExpressionConverter != null ? builder.filterExpressionConverter
253250
: new CassandraFilterExpressionConverter(cassandraMetadata.getColumns().values());
@@ -353,21 +350,13 @@ public List<Document> doSimilaritySearch(SearchRequest request) {
353350
Preconditions.checkArgument(request.getTopK() <= 1000);
354351
var embedding = toFloatArray(this.embeddingModel.embed(request.getQuery()));
355352
CqlVector<Float> cqlVector = CqlVector.newInstance(embedding);
353+
String cql = createSimilaritySearchCql(request, cqlVector, request.getTopK());
356354

357-
String whereClause = "";
358-
if (request.hasFilterExpression()) {
359-
String expression = this.filterExpressionConverter.convertExpression(request.getFilterExpression());
360-
if (!expression.isBlank()) {
361-
whereClause = String.format("where %s", expression);
362-
}
363-
}
364-
365-
String query = String.format(this.similarityStmt, cqlVector, whereClause, cqlVector, request.getTopK());
366355
List<Document> documents = new ArrayList<>();
367-
logger.trace("Executing {}", query);
368-
SimpleStatement s = SimpleStatement.newInstance(query).setExecutionProfileName(DRIVER_PROFILE_SEARCH);
356+
ResultSet result = this.session
357+
.execute(SimpleStatement.newInstance(cql).setExecutionProfileName(DRIVER_PROFILE_SEARCH));
369358

370-
for (Row row : this.session.execute(s)) {
359+
for (Row row : result) {
371360
float score = row.getFloat(0);
372361
if (score < request.getSimilarityThreshold()) {
373362
break;
@@ -455,34 +444,34 @@ private PreparedStatement prepareAddStatement(Set<String> metadataFields) {
455444
});
456445
}
457446

458-
private String similaritySearchStatement() {
459-
StringBuilder ids = new StringBuilder();
460-
for (var m : this.schema.partitionKeys()) {
461-
ids.append(m.name()).append(',');
462-
}
463-
for (var m : this.schema.clusteringKeys()) {
464-
ids.append(m.name()).append(',');
465-
}
466-
ids.deleteCharAt(ids.length() - 1);
447+
private String createSimilaritySearchCql(SearchRequest request, CqlVector<Float> cqlVector, int topK) {
467448

468-
String similarityFunction = new StringBuilder("similarity_").append(this.similarity.toString().toLowerCase())
469-
.append('(')
470-
.append(this.schema.embedding())
471-
.append(",?)")
472-
.toString();
449+
Select stmt = QueryBuilder.selectFrom(this.schema.keyspace(), this.schema.table())
450+
.function("similarity_" + this.similarity.toString().toLowerCase(),
451+
Selector.column(this.schema.embedding()), QueryBuilder.literal(cqlVector));
473452

474-
StringBuilder extraSelectFields = new StringBuilder();
453+
for (var c : this.schema.partitionKeys()) {
454+
stmt = stmt.column(c.name());
455+
}
456+
for (var c : this.schema.clusteringKeys()) {
457+
stmt = stmt.column(c.name());
458+
}
459+
stmt = stmt.column(this.schema.content());
475460
for (var m : this.schema.metadataColumns()) {
476-
extraSelectFields.append(',').append(m.name());
461+
stmt = stmt.column(m.name());
477462
}
463+
stmt = stmt.column(this.schema.embedding());
478464

479-
// java-driver-query-builder doesn't support orderByAnnOf yet
480-
String query = String.format(QUERY_FORMAT, similarityFunction, ids.toString(), this.schema.content(),
481-
extraSelectFields.toString(), this.schema.keyspace(), this.schema.table(), this.schema.embedding());
482-
483-
query = query.replace("?", "%s");
484-
logger.debug("preparing {}", query);
485-
return query;
465+
// the filterExpression is a string so we go back to building a CQL string
466+
String whereClause = "";
467+
if (request.hasFilterExpression()) {
468+
String expression = this.filterExpressionConverter.convertExpression(request.getFilterExpression());
469+
if (!expression.isBlank()) {
470+
whereClause = String.format(" WHERE %s", expression);
471+
}
472+
}
473+
String cql = stmt.orderByAnnOf(this.schema.embedding(), cqlVector).limit(topK).asCql();
474+
return cql.replace(" ORDER ", whereClause + " ORDER ");
486475
}
487476

488477
private String getDocumentId(Row row) {
@@ -631,25 +620,15 @@ private void ensureTableExists(int vectorDimension) {
631620
createTable = createTable.withClusteringColumn(clusteringKey.name, clusteringKey.type);
632621
}
633622

634-
createTable = createTable.withColumn(this.schema.content, DataTypes.TEXT);
623+
createTable = createTable.withColumn(this.schema.content, DataTypes.TEXT)
624+
.withColumn(this.schema.embedding, DataTypes.vectorOf(DataTypes.FLOAT, vectorDimension));
635625

636626
for (SchemaColumn metadata : this.schema.metadataColumns) {
637627
createTable = createTable.withColumn(metadata.name(), metadata.type());
638628
}
639629

640-
// https://datastax-oss.atlassian.net/browse/JAVA-3118
641-
// .withColumn(config.embedding, new DefaultVectorType(DataTypes.FLOAT,
642-
// vectorDimension));
643-
644-
StringBuilder tableStmt = new StringBuilder(createTable.asCql());
645-
tableStmt.setLength(tableStmt.length() - 1);
646-
tableStmt.append(',')
647-
.append(this.schema.embedding)
648-
.append(" vector<float,")
649-
.append(vectorDimension)
650-
.append(">)");
651-
logger.debug("Executing {}", tableStmt.toString());
652-
this.session.execute(tableStmt.toString());
630+
logger.debug("Executing {}", createTable.asCql());
631+
this.session.execute(createTable.build());
653632
}
654633
}
655634

@@ -687,28 +666,12 @@ private void ensureTableColumnsExist(int vectorDimension) {
687666
alterTable = alterTable.addColumn(this.schema.content, DataTypes.TEXT);
688667
}
689668
if (addEmbedding) {
690-
// special case for embedding column, bc JAVA-3118, as above
691-
StringBuilder alterTableStmt = new StringBuilder(((BuildableQuery) alterTable).asCql());
692-
if (newColumns.isEmpty() && !addContent) {
693-
alterTableStmt.append(" ADD (");
694-
}
695-
else {
696-
alterTableStmt.setLength(alterTableStmt.length() - 1);
697-
alterTableStmt.append(',');
698-
}
699-
alterTableStmt.append(this.schema.embedding)
700-
.append(" vector<float,")
701-
.append(vectorDimension)
702-
.append(">)");
703-
704-
logger.debug("Executing {}", alterTableStmt.toString());
705-
this.session.execute(alterTableStmt.toString());
706-
}
707-
else {
708-
SimpleStatement stmt = ((AlterTableAddColumnEnd) alterTable).build();
709-
logger.debug("Executing {}", stmt.getQuery());
710-
this.session.execute(stmt);
669+
alterTable = alterTable.addColumn(this.schema.embedding,
670+
DataTypes.vectorOf(DataTypes.FLOAT, vectorDimension));
711671
}
672+
SimpleStatement stmt = ((AlterTableAddColumnEnd) alterTable).build();
673+
logger.debug("Executing {}", stmt.getQuery());
674+
this.session.execute(stmt);
712675
}
713676
}
714677

0 commit comments

Comments
 (0)