3939import com .datastax .oss .driver .api .core .cql .BoundStatement ;
4040import com .datastax .oss .driver .api .core .cql .BoundStatementBuilder ;
4141import com .datastax .oss .driver .api .core .cql .PreparedStatement ;
42+ import com .datastax .oss .driver .api .core .cql .ResultSet ;
4243import com .datastax .oss .driver .api .core .cql .Row ;
4344import com .datastax .oss .driver .api .core .cql .SimpleStatement ;
4445import com .datastax .oss .driver .api .core .data .CqlVector ;
4950import com .datastax .oss .driver .api .core .type .DataTypes ;
5051import com .datastax .oss .driver .api .core .type .codec .registry .CodecRegistry ;
5152import com .datastax .oss .driver .api .core .type .reflect .GenericType ;
52- import com .datastax .oss .driver .api .querybuilder .BuildableQuery ;
5353import com .datastax .oss .driver .api .querybuilder .QueryBuilder ;
5454import com .datastax .oss .driver .api .querybuilder .SchemaBuilder ;
5555import com .datastax .oss .driver .api .querybuilder .delete .Delete ;
6060import com .datastax .oss .driver .api .querybuilder .schema .AlterTableAddColumnEnd ;
6161import com .datastax .oss .driver .api .querybuilder .schema .CreateTable ;
6262import com .datastax .oss .driver .api .querybuilder .schema .CreateTableStart ;
63+ import com .datastax .oss .driver .api .querybuilder .select .Select ;
64+ import com .datastax .oss .driver .api .querybuilder .select .Selector ;
6365import com .datastax .oss .driver .shaded .guava .common .annotations .VisibleForTesting ;
6466import com .datastax .oss .driver .shaded .guava .common .base .Preconditions ;
6567import org .slf4j .Logger ;
@@ -191,8 +193,6 @@ public class CassandraVectorStore extends AbstractObservationVectorStore impleme
191193
192194 public static final String DRIVER_PROFILE_SEARCH = "spring-ai-search" ;
193195
194- private static final String QUERY_FORMAT = "select %s,%s,%s%s from %s.%s ? order by %s ann of ? limit ?" ;
195-
196196 private static final Logger logger = LoggerFactory .getLogger (CassandraVectorStore .class );
197197
198198 private static final Map <Similarity , VectorStoreSimilarityMetric > SIMILARITY_TYPE_MAPPING = Map .of (
@@ -219,8 +219,6 @@ public class CassandraVectorStore extends AbstractObservationVectorStore impleme
219219
220220 private final PreparedStatement deleteStmt ;
221221
222- private final String similarityStmt ;
223-
224222 private final Similarity similarity ;
225223
226224 protected CassandraVectorStore (Builder builder ) {
@@ -247,7 +245,6 @@ protected CassandraVectorStore(Builder builder) {
247245 .get ();
248246
249247 this .similarity = getIndexSimilarity (cassandraMetadata );
250- this .similarityStmt = similaritySearchStatement ();
251248
252249 this .filterExpressionConverter = builder .filterExpressionConverter != null ? builder .filterExpressionConverter
253250 : new CassandraFilterExpressionConverter (cassandraMetadata .getColumns ().values ());
@@ -353,21 +350,13 @@ public List<Document> doSimilaritySearch(SearchRequest request) {
353350 Preconditions .checkArgument (request .getTopK () <= 1000 );
354351 var embedding = toFloatArray (this .embeddingModel .embed (request .getQuery ()));
355352 CqlVector <Float > cqlVector = CqlVector .newInstance (embedding );
353+ String cql = createSimilaritySearchCql (request , cqlVector , request .getTopK ());
356354
357- String whereClause = "" ;
358- if (request .hasFilterExpression ()) {
359- String expression = this .filterExpressionConverter .convertExpression (request .getFilterExpression ());
360- if (!expression .isBlank ()) {
361- whereClause = String .format ("where %s" , expression );
362- }
363- }
364-
365- String query = String .format (this .similarityStmt , cqlVector , whereClause , cqlVector , request .getTopK ());
366355 List <Document > documents = new ArrayList <>();
367- logger . trace ( "Executing {}" , query );
368- SimpleStatement s = SimpleStatement .newInstance (query ).setExecutionProfileName (DRIVER_PROFILE_SEARCH );
356+ ResultSet result = this . session
357+ . execute ( SimpleStatement .newInstance (cql ).setExecutionProfileName (DRIVER_PROFILE_SEARCH ) );
369358
370- for (Row row : this . session . execute ( s ) ) {
359+ for (Row row : result ) {
371360 float score = row .getFloat (0 );
372361 if (score < request .getSimilarityThreshold ()) {
373362 break ;
@@ -455,34 +444,34 @@ private PreparedStatement prepareAddStatement(Set<String> metadataFields) {
455444 });
456445 }
457446
458- private String similaritySearchStatement () {
459- StringBuilder ids = new StringBuilder ();
460- for (var m : this .schema .partitionKeys ()) {
461- ids .append (m .name ()).append (',' );
462- }
463- for (var m : this .schema .clusteringKeys ()) {
464- ids .append (m .name ()).append (',' );
465- }
466- ids .deleteCharAt (ids .length () - 1 );
447+ private String createSimilaritySearchCql (SearchRequest request , CqlVector <Float > cqlVector , int topK ) {
467448
468- String similarityFunction = new StringBuilder ("similarity_" ).append (this .similarity .toString ().toLowerCase ())
469- .append ('(' )
470- .append (this .schema .embedding ())
471- .append (",?)" )
472- .toString ();
449+ Select stmt = QueryBuilder .selectFrom (this .schema .keyspace (), this .schema .table ())
450+ .function ("similarity_" + this .similarity .toString ().toLowerCase (),
451+ Selector .column (this .schema .embedding ()), QueryBuilder .literal (cqlVector ));
473452
474- StringBuilder extraSelectFields = new StringBuilder ();
453+ for (var c : this .schema .partitionKeys ()) {
454+ stmt = stmt .column (c .name ());
455+ }
456+ for (var c : this .schema .clusteringKeys ()) {
457+ stmt = stmt .column (c .name ());
458+ }
459+ stmt = stmt .column (this .schema .content ());
475460 for (var m : this .schema .metadataColumns ()) {
476- extraSelectFields . append ( ',' ). append (m .name ());
461+ stmt = stmt . column (m .name ());
477462 }
463+ stmt = stmt .column (this .schema .embedding ());
478464
479- // java-driver-query-builder doesn't support orderByAnnOf yet
480- String query = String .format (QUERY_FORMAT , similarityFunction , ids .toString (), this .schema .content (),
481- extraSelectFields .toString (), this .schema .keyspace (), this .schema .table (), this .schema .embedding ());
482-
483- query = query .replace ("?" , "%s" );
484- logger .debug ("preparing {}" , query );
485- return query ;
465+ // the filterExpression is a string so we go back to building a CQL string
466+ String whereClause = "" ;
467+ if (request .hasFilterExpression ()) {
468+ String expression = this .filterExpressionConverter .convertExpression (request .getFilterExpression ());
469+ if (!expression .isBlank ()) {
470+ whereClause = String .format (" WHERE %s" , expression );
471+ }
472+ }
473+ String cql = stmt .orderByAnnOf (this .schema .embedding (), cqlVector ).limit (topK ).asCql ();
474+ return cql .replace (" ORDER " , whereClause + " ORDER " );
486475 }
487476
488477 private String getDocumentId (Row row ) {
@@ -631,25 +620,15 @@ private void ensureTableExists(int vectorDimension) {
631620 createTable = createTable .withClusteringColumn (clusteringKey .name , clusteringKey .type );
632621 }
633622
634- createTable = createTable .withColumn (this .schema .content , DataTypes .TEXT );
623+ createTable = createTable .withColumn (this .schema .content , DataTypes .TEXT )
624+ .withColumn (this .schema .embedding , DataTypes .vectorOf (DataTypes .FLOAT , vectorDimension ));
635625
636626 for (SchemaColumn metadata : this .schema .metadataColumns ) {
637627 createTable = createTable .withColumn (metadata .name (), metadata .type ());
638628 }
639629
640- // https://datastax-oss.atlassian.net/browse/JAVA-3118
641- // .withColumn(config.embedding, new DefaultVectorType(DataTypes.FLOAT,
642- // vectorDimension));
643-
644- StringBuilder tableStmt = new StringBuilder (createTable .asCql ());
645- tableStmt .setLength (tableStmt .length () - 1 );
646- tableStmt .append (',' )
647- .append (this .schema .embedding )
648- .append (" vector<float," )
649- .append (vectorDimension )
650- .append (">)" );
651- logger .debug ("Executing {}" , tableStmt .toString ());
652- this .session .execute (tableStmt .toString ());
630+ logger .debug ("Executing {}" , createTable .asCql ());
631+ this .session .execute (createTable .build ());
653632 }
654633 }
655634
@@ -687,28 +666,12 @@ private void ensureTableColumnsExist(int vectorDimension) {
687666 alterTable = alterTable .addColumn (this .schema .content , DataTypes .TEXT );
688667 }
689668 if (addEmbedding ) {
690- // special case for embedding column, bc JAVA-3118, as above
691- StringBuilder alterTableStmt = new StringBuilder (((BuildableQuery ) alterTable ).asCql ());
692- if (newColumns .isEmpty () && !addContent ) {
693- alterTableStmt .append (" ADD (" );
694- }
695- else {
696- alterTableStmt .setLength (alterTableStmt .length () - 1 );
697- alterTableStmt .append (',' );
698- }
699- alterTableStmt .append (this .schema .embedding )
700- .append (" vector<float," )
701- .append (vectorDimension )
702- .append (">)" );
703-
704- logger .debug ("Executing {}" , alterTableStmt .toString ());
705- this .session .execute (alterTableStmt .toString ());
706- }
707- else {
708- SimpleStatement stmt = ((AlterTableAddColumnEnd ) alterTable ).build ();
709- logger .debug ("Executing {}" , stmt .getQuery ());
710- this .session .execute (stmt );
669+ alterTable = alterTable .addColumn (this .schema .embedding ,
670+ DataTypes .vectorOf (DataTypes .FLOAT , vectorDimension ));
711671 }
672+ SimpleStatement stmt = ((AlterTableAddColumnEnd ) alterTable ).build ();
673+ logger .debug ("Executing {}" , stmt .getQuery ());
674+ this .session .execute (stmt );
712675 }
713676 }
714677
0 commit comments