@@ -832,10 +832,14 @@ static char *generate_memory_quant_table (const char *table_name, const char *co
832832 return sqlite3_snprintf (STATIC_SQL_SIZE , sql , "SELECT SUM(LENGTH(data)) FROM vector0_%q_%q;" , table_name , column_name );
833833}
834834
835- static char * generate_insert_quant_table (const char * table_name , const char * column_name , char sql [STATIC_SQL_SIZE ]) {
835+ static char * generate_insert_quant_table (const char * table_name , const char * column_name , char sql [STATIC_SQL_SIZE ]) {
836836 return sqlite3_snprintf (STATIC_SQL_SIZE , sql , "INSERT INTO vector0_%q_%q (rowid1, rowid2, counter, data) VALUES (?, ?, ?, ?);" , table_name , column_name );
837837}
838838
839+ static char * generate_quant_table_name (const char * table_name , const char * column_name , char sql [STATIC_SQL_SIZE ]) {
840+ return sqlite3_snprintf (STATIC_SQL_SIZE , sql , "vector0_%q_%q" , table_name , column_name );
841+ }
842+
839843// MARK: - Vector Context and Options -
840844
841845void * vector_context_create (void ) {
@@ -1513,7 +1517,7 @@ static void vector_convert_i8 (sqlite3_context *context, int argc, sqlite3_value
15131517
15141518// MARK: - Modules -
15151519
1516- static int vCursorFilterCommon (sqlite3_vtab_cursor * cur , int idxNum , const char * idxStr , int argc , sqlite3_value * * argv , const char * fname , vcursor_run_callback run_callback , vcursor_sort_callback sort_callback ) {
1520+ static int vCursorFilterCommon (sqlite3_vtab_cursor * cur , int idxNum , const char * idxStr , int argc , sqlite3_value * * argv , const char * fname , vcursor_run_callback run_callback , vcursor_sort_callback sort_callback , bool check_quant ) {
15171521
15181522 vFullScanCursor * c = (vFullScanCursor * )cur ;
15191523 vFullScan * vtab = (vFullScan * )cur -> pVtab ;
@@ -1562,6 +1566,15 @@ static int vCursorFilterCommon (sqlite3_vtab_cursor *cur, int idxNum, const char
15621566 vsize = sqlite3_value_bytes (argv [2 ]);
15631567 }
15641568
1569+ if (check_quant ) {
1570+ char buffer [STATIC_SQL_SIZE ];
1571+ char * name = generate_quant_table_name (table_name , column_name , buffer );
1572+ if (!name || !sqlite_table_exists (vtab -> db , name )) {
1573+ sqlite_vtab_set_error (& vtab -> base , "Quantization table not found for table '%s' and column '%s'. Ensure that vector_quantize() has been called before using vector_quantize_scan()." );
1574+ return SQLITE_ERROR ;
1575+ }
1576+ }
1577+
15651578 int k = sqlite3_value_int (argv [3 ]);
15661579
15671580 // nothing needs to be returned
@@ -1785,7 +1798,7 @@ static int vFullScanRun (sqlite3 *db, vFullScanCursor *c, const void *v1, int v1
17851798}
17861799
17871800static int vFullScanCursorFilter (sqlite3_vtab_cursor * cur , int idxNum , const char * idxStr , int argc , sqlite3_value * * argv ) {
1788- return vCursorFilterCommon (cur , idxNum , idxStr , argc , argv , "vector_full_scan" , vFullScanRun , vFullScanSortSlots );
1801+ return vCursorFilterCommon (cur , idxNum , idxStr , argc , argv , "vector_full_scan" , vFullScanRun , vFullScanSortSlots , false );
17891802}
17901803
17911804// MARK: -
@@ -1897,7 +1910,7 @@ static int vQuantRun (sqlite3 *db, vFullScanCursor *c, const void *v1, int v1siz
18971910
18981911
18991912static int vQuantCursorFilter (sqlite3_vtab_cursor * cur , int idxNum , const char * idxStr , int argc , sqlite3_value * * argv ) {
1900- return vCursorFilterCommon (cur , idxNum , idxStr , argc , argv , "vector_quantize_scan" , vQuantRun , vFullScanSortSlots );
1913+ return vCursorFilterCommon (cur , idxNum , idxStr , argc , argv , "vector_quantize_scan" , vQuantRun , vFullScanSortSlots , true );
19011914}
19021915
19031916static sqlite3_module vFullScanModule = {
0 commit comments