diff --git a/README.md b/README.md index 8cf9e7a8..f6ceed14 100644 --- a/README.md +++ b/README.md @@ -93,7 +93,7 @@ This project aims for [high test coverage](https://github.com/ncruces/go-sqlite3 It also benefits greatly from [SQLite's](https://sqlite.org/testing.html) and [wazero's](https://tetrate.io/blog/introducing-wazero-from-tetrate/#:~:text=Rock%2Dsolid%20test%20approach) thorough testing. -Every commit is [tested](.github/workflows/test.yml) on +Every commit is [tested](https://github.com/ncruces/go-sqlite3/wiki/Test-matrix) on Linux (amd64/arm64/386/riscv64/s390x), macOS (amd64/arm64), Windows (amd64), FreeBSD (amd64), OpenBSD (amd64), NetBSD (amd64), illumos (amd64), and Solaris (amd64). diff --git a/config.go b/config.go index 84f0881c..3f60b8fe 100644 --- a/config.go +++ b/config.go @@ -162,7 +162,7 @@ func (c *Conn) Limit(id LimitCategory, value int) int { // SetAuthorizer registers an authorizer callback with the database connection. // // https://sqlite.org/c3ref/set_authorizer.html -func (c *Conn) SetAuthorizer(cb func(action AuthorizerActionCode, name3rd, name4th, schema, nameInner string) AuthorizerReturnCode) error { +func (c *Conn) SetAuthorizer(cb func(action AuthorizerActionCode, name3rd, name4th, schema, inner string) AuthorizerReturnCode) error { var enable uint64 if cb != nil { enable = 1 @@ -176,9 +176,9 @@ func (c *Conn) SetAuthorizer(cb func(action AuthorizerActionCode, name3rd, name4 } -func authorizerCallback(ctx context.Context, mod api.Module, pDB uint32, action AuthorizerActionCode, zName3rd, zName4th, zSchema, zNameInner uint32) (rc AuthorizerReturnCode) { +func authorizerCallback(ctx context.Context, mod api.Module, pDB uint32, action AuthorizerActionCode, zName3rd, zName4th, zSchema, zInner uint32) (rc AuthorizerReturnCode) { if c, ok := ctx.Value(connKey{}).(*Conn); ok && c.handle == pDB && c.authorizer != nil { - var name3rd, name4th, schema, nameInner string + var name3rd, name4th, schema, inner string if zName3rd != 0 { name3rd = util.ReadString(mod, zName3rd, _MAX_NAME) } @@ -188,10 +188,48 @@ func authorizerCallback(ctx context.Context, mod api.Module, pDB uint32, action if zSchema != 0 { schema = util.ReadString(mod, zSchema, _MAX_NAME) } - if zNameInner != 0 { - nameInner = util.ReadString(mod, zNameInner, _MAX_NAME) + if zInner != 0 { + inner = util.ReadString(mod, zInner, _MAX_NAME) + } + rc = c.authorizer(action, name3rd, name4th, schema, inner) + } + return rc +} + +// Trace registers a trace callback function against the database connection. +// +// https://sqlite.org/c3ref/trace_v2.html +func (c *Conn) Trace(mask TraceEvent, cb func(evt TraceEvent, arg1 any, arg2 any) error) error { + r := c.call("sqlite3_trace_go", uint64(c.handle), uint64(mask)) + if err := c.error(r); err != nil { + return err + } + c.trace = cb + return nil +} + +func traceCallback(ctx context.Context, mod api.Module, evt TraceEvent, pDB, pArg1, pArg2 uint32) (rc uint32) { + if c, ok := ctx.Value(connKey{}).(*Conn); ok && c.handle == pDB && c.trace != nil { + var arg1, arg2 any + if evt == TRACE_CLOSE { + arg1 = c + } else { + for _, s := range c.stmts { + if pArg1 == s.handle { + arg1 = s + switch evt { + case TRACE_STMT: + arg2 = s.SQL() + case TRACE_PROFILE: + arg2 = int64(util.ReadUint64(mod, pArg2)) + } + break + } + } + } + if arg1 != nil { + _, rc = errorCode(c.trace(evt, arg1, arg2), ERROR) } - rc = c.authorizer(action, name3rd, name4th, schema, nameInner) } return rc } diff --git a/conn.go b/conn.go index a492256a..b4335f4c 100644 --- a/conn.go +++ b/conn.go @@ -22,14 +22,16 @@ type Conn struct { interrupt context.Context pending *Stmt + stmts []*Stmt busy func(int) bool log func(xErrorCode, string) collation func(*Conn, string) + wal func(*Conn, string, int) error + trace func(TraceEvent, any, any) error authorizer func(AuthorizerActionCode, string, string, string, string) AuthorizerReturnCode update func(AuthorizerActionCode, string, string, int64) commit func() bool rollback func() - wal func(*Conn, string, int) error arena arena handle uint32 @@ -202,6 +204,7 @@ func (c *Conn) PrepareFlags(sql string, flags PrepareFlag) (stmt *Stmt, tail str if stmt.handle == 0 { return nil, "", nil } + c.stmts = append(c.stmts, stmt) return stmt, tail, nil } @@ -326,7 +329,12 @@ func (c *Conn) SetInterrupt(ctx context.Context) (old context.Context) { // A busy SQL statement prevents SQLite from ignoring an interrupt // that comes before any other statements are started. if c.pending == nil { - c.pending, _, _ = c.Prepare(`WITH RECURSIVE c(x) AS (VALUES(0) UNION ALL SELECT x FROM c) SELECT x FROM c`) + defer c.arena.mark()() + stmtPtr := c.arena.new(ptrlen) + loopPtr := c.arena.string(`WITH RECURSIVE c(x) AS (VALUES(0) UNION ALL SELECT x FROM c) SELECT x FROM c`) + c.call("sqlite3_prepare_v3", uint64(c.handle), uint64(loopPtr), math.MaxUint64, 0, uint64(stmtPtr), 0) + c.pending = &Stmt{c: c} + c.pending.handle = util.ReadUint32(c.mod, stmtPtr) } old = c.interrupt @@ -414,10 +422,74 @@ func busyCallback(ctx context.Context, mod api.Module, pDB uint32, count int32) return retry } +// Status retrieves runtime status information about a database connection. +// +// https://sqlite.org/c3ref/db_status.html +func (c *Conn) Status(op DBStatus, reset bool) (current, highwater int, err error) { + defer c.arena.mark()() + hiPtr := c.arena.new(4) + curPtr := c.arena.new(4) + + var i uint64 + if reset { + i = 1 + } + + r := c.call("sqlite3_db_status", uint64(c.handle), + uint64(op), uint64(curPtr), uint64(hiPtr), i) + if err = c.error(r); err == nil { + current = int(util.ReadUint32(c.mod, curPtr)) + highwater = int(util.ReadUint32(c.mod, hiPtr)) + } + return +} + +// TableColumnMetadata extracts metadata about a column of a table. +// +// https://sqlite.org/c3ref/table_column_metadata.html +func (c *Conn) TableColumnMetadata(schema, table, column string) (declType, collSeq string, notNull, primaryKey, autoInc bool, err error) { + defer c.arena.mark()() + + var schemaPtr, columnPtr uint32 + declTypePtr := c.arena.new(ptrlen) + collSeqPtr := c.arena.new(ptrlen) + notNullPtr := c.arena.new(ptrlen) + primaryKeyPtr := c.arena.new(ptrlen) + autoIncPtr := c.arena.new(ptrlen) + if schema != "" { + schemaPtr = c.arena.string(schema) + } + tablePtr := c.arena.string(table) + if column != "" { + columnPtr = c.arena.string(column) + } + + r := c.call("sqlite3_table_column_metadata", uint64(c.handle), + uint64(schemaPtr), uint64(tablePtr), uint64(columnPtr), + uint64(declTypePtr), uint64(collSeqPtr), + uint64(notNullPtr), uint64(primaryKeyPtr), uint64(autoIncPtr)) + if err = c.error(r); err == nil && column != "" { + declType = util.ReadString(c.mod, util.ReadUint32(c.mod, declTypePtr), _MAX_NAME) + collSeq = util.ReadString(c.mod, util.ReadUint32(c.mod, collSeqPtr), _MAX_NAME) + notNull = util.ReadUint32(c.mod, notNullPtr) != 0 + autoInc = util.ReadUint32(c.mod, autoIncPtr) != 0 + primaryKey = util.ReadUint32(c.mod, primaryKeyPtr) != 0 + } + return +} + func (c *Conn) error(rc uint64, sql ...string) error { return c.sqlite.error(rc, c.handle, sql...) } +func (c *Conn) stmtsIter(yield func(*Stmt) bool) { + for _, s := range c.stmts { + if !yield(s) { + break + } + } +} + // DriverConn is implemented by the SQLite [database/sql] driver connection. // // It can be used to access SQLite features like [online backup]. diff --git a/conn_iter.go b/conn_iter.go new file mode 100644 index 00000000..e5a50e2e --- /dev/null +++ b/conn_iter.go @@ -0,0 +1,11 @@ +//go:build go1.23 || goexperiment.rangefunc + +package sqlite3 + +import "iter" + +// Stmts returns an iterator for the prepared statements +// associated with the database connection. +// +// https://sqlite.org/c3ref/next_stmt.html +func (c *Conn) Stmts() iter.Seq[*Stmt] { return c.stmtsIter } diff --git a/conn_old.go b/conn_old.go new file mode 100644 index 00000000..ce0c9c52 --- /dev/null +++ b/conn_old.go @@ -0,0 +1,9 @@ +//go:build !(go1.23 || goexperiment.rangefunc) + +package sqlite3 + +// Stmts returns an iterator for the prepared statements +// associated with the database connection. +// +// https://sqlite.org/c3ref/next_stmt.html +func (c *Conn) Stmts() func(func(*Stmt) bool) { return c.stmtsIter } diff --git a/const.go b/const.go index a3bd395c..11eb33c8 100644 --- a/const.go +++ b/const.go @@ -109,7 +109,7 @@ const ( CANTOPEN_ISDIR ExtendedErrorCode = xErrorCode(CANTOPEN) | (2 << 8) CANTOPEN_FULLPATH ExtendedErrorCode = xErrorCode(CANTOPEN) | (3 << 8) CANTOPEN_CONVPATH ExtendedErrorCode = xErrorCode(CANTOPEN) | (4 << 8) - CANTOPEN_DIRTYWAL ExtendedErrorCode = xErrorCode(CANTOPEN) | (5 << 8) /* Not Used */ + // CANTOPEN_DIRTYWAL ExtendedErrorCode = xErrorCode(CANTOPEN) | (5 << 8) /* Not Used */ CANTOPEN_SYMLINK ExtendedErrorCode = xErrorCode(CANTOPEN) | (6 << 8) CORRUPT_VTAB ExtendedErrorCode = xErrorCode(CORRUPT) | (1 << 8) CORRUPT_SEQUENCE ExtendedErrorCode = xErrorCode(CORRUPT) | (2 << 8) @@ -177,11 +177,11 @@ const ( type FunctionFlag uint32 const ( - DETERMINISTIC FunctionFlag = 0x000000800 - DIRECTONLY FunctionFlag = 0x000080000 - SUBTYPE FunctionFlag = 0x000100000 - INNOCUOUS FunctionFlag = 0x000200000 - RESULT_SUBTYPE FunctionFlag = 0x001000000 + DETERMINISTIC FunctionFlag = 0x000000800 + DIRECTONLY FunctionFlag = 0x000080000 + INNOCUOUS FunctionFlag = 0x000200000 + // SUBTYPE FunctionFlag = 0x000100000 + // RESULT_SUBTYPE FunctionFlag = 0x001000000 ) // StmtStatus name counter values associated with the [Stmt.Status] method. @@ -201,6 +201,27 @@ const ( STMTSTATUS_MEMUSED StmtStatus = 99 ) +// DBStatus are the available "verbs" that can be passed to the [Conn.Status] method. +// +// https://sqlite.org/c3ref/c_dbstatus_options.html +type DBStatus uint32 + +const ( + DBSTATUS_LOOKASIDE_USED DBStatus = 0 + DBSTATUS_CACHE_USED DBStatus = 1 + DBSTATUS_SCHEMA_USED DBStatus = 2 + DBSTATUS_STMT_USED DBStatus = 3 + DBSTATUS_LOOKASIDE_HIT DBStatus = 4 + DBSTATUS_LOOKASIDE_MISS_SIZE DBStatus = 5 + DBSTATUS_LOOKASIDE_MISS_FULL DBStatus = 6 + DBSTATUS_CACHE_HIT DBStatus = 7 + DBSTATUS_CACHE_MISS DBStatus = 8 + DBSTATUS_CACHE_WRITE DBStatus = 9 + DBSTATUS_DEFERRED_FKS DBStatus = 10 + DBSTATUS_CACHE_USED_SHARED DBStatus = 11 + DBSTATUS_CACHE_SPILL DBStatus = 12 +) + // DBConfig are the available database connection configuration options. // // https://sqlite.org/c3ref/c_dbconfig_defensive.html @@ -307,8 +328,8 @@ const ( AUTH_DROP_VTABLE AuthorizerActionCode = 30 /* Table Name Module Name */ AUTH_FUNCTION AuthorizerActionCode = 31 /* NULL Function Name */ AUTH_SAVEPOINT AuthorizerActionCode = 32 /* Operation Savepoint Name */ - AUTH_COPY AuthorizerActionCode = 0 /* No longer used */ AUTH_RECURSIVE AuthorizerActionCode = 33 /* NULL NULL */ + // AUTH_COPY AuthorizerActionCode = 0 /* No longer used */ ) // AuthorizerReturnCode are the integer codes @@ -346,6 +367,18 @@ const ( TXN_WRITE TxnState = 2 ) +// TraceEvent identify classes of events that can be monitored with [Conn.Trace]. +// +// https://sqlite.org/c3ref/c_trace.html +type TraceEvent uint32 + +const ( + TRACE_STMT TraceEvent = 0x01 + TRACE_PROFILE TraceEvent = 0x02 + TRACE_ROW TraceEvent = 0x04 + TRACE_CLOSE TraceEvent = 0x08 +) + // Datatype is a fundamental datatype of SQLite. // // https://sqlite.org/c3ref/c_blob.html diff --git a/driver/driver.go b/driver/driver.go index e49a3313..c4080e41 100644 --- a/driver/driver.go +++ b/driver/driver.go @@ -74,11 +74,22 @@ func init() { // Open opens the SQLite database specified by dataSourceName as a [database/sql.DB]. // -// The init function is called by the driver on new connections. +// Open accepts zero, one, or two callbacks (nil callbacks are ignored). +// The first callback is called when the driver opens a new connection. +// The second callback is called before the driver closes a connection. // The [sqlite3.Conn] can be used to execute queries, register functions, etc. -// Any error returned closes the connection and is returned to [database/sql]. -func Open(dataSourceName string, init func(*sqlite3.Conn) error) (*sql.DB, error) { - c, err := (&SQLite{init}).OpenConnector(dataSourceName) +func Open(dataSourceName string, fn ...func(*sqlite3.Conn) error) (*sql.DB, error) { + var drv SQLite + if len(fn) > 2 { + return nil, util.ArgErr + } + if len(fn) > 1 { + drv.term = fn[1] + } + if len(fn) > 0 { + drv.init = fn[0] + } + c, err := drv.OpenConnector(dataSourceName) if err != nil { return nil, err } @@ -88,6 +99,7 @@ func Open(dataSourceName string, init func(*sqlite3.Conn) error) (*sql.DB, error // SQLite implements [database/sql/driver.Driver]. type SQLite struct { init func(*sqlite3.Conn) error + term func(*sqlite3.Conn) error } // Open implements [database/sql/driver.Driver]. @@ -204,6 +216,14 @@ func (n *connector) Connect(ctx context.Context) (_ driver.Conn, err error) { return nil, err } } + if n.driver.term != nil { + err = c.Conn.Trace(sqlite3.TRACE_CLOSE, func(sqlite3.TraceEvent, any, any) error { + return n.driver.term(c.Conn) + }) + if err != nil { + return nil, err + } + } return c, nil } diff --git a/driver/json_test.go b/driver/json_test.go index e1cce5fd..e7604030 100644 --- a/driver/json_test.go +++ b/driver/json_test.go @@ -11,7 +11,7 @@ import ( ) func Example_json() { - db, err := driver.Open("file:/test.db?vfs=memdb", nil) + db, err := driver.Open("file:/test.db?vfs=memdb") if err != nil { log.Fatal(err) } diff --git a/driver/savepoint_test.go b/driver/savepoint_test.go index a95cb35a..9939b692 100644 --- a/driver/savepoint_test.go +++ b/driver/savepoint_test.go @@ -10,7 +10,7 @@ import ( ) func ExampleSavepoint() { - db, err := driver.Open("file:/test.db?vfs=memdb", nil) + db, err := driver.Open("file:/test.db?vfs=memdb") if err != nil { log.Fatal(err) } diff --git a/embed/exports.txt b/embed/exports.txt index f9c47611..e7882cb5 100644 --- a/embed/exports.txt +++ b/embed/exports.txt @@ -55,17 +55,20 @@ sqlite3_create_function_go sqlite3_create_module_go sqlite3_create_window_function_go sqlite3_database_file_object +sqlite3_db_cacheflush sqlite3_db_config sqlite3_db_filename sqlite3_db_name sqlite3_db_readonly sqlite3_db_release_memory +sqlite3_db_status sqlite3_declare_vtab sqlite3_errcode sqlite3_errmsg sqlite3_error_offset sqlite3_errstr sqlite3_exec +sqlite3_expanded_sql sqlite3_file_control sqlite3_filename_database sqlite3_filename_journal @@ -101,16 +104,18 @@ sqlite3_step sqlite3_stmt_busy sqlite3_stmt_readonly sqlite3_stmt_status +sqlite3_table_column_metadata sqlite3_total_changes64 +sqlite3_trace_go sqlite3_txn_state sqlite3_update_hook_go sqlite3_uri_key -sqlite3_uri_parameter sqlite3_value_blob sqlite3_value_bytes sqlite3_value_double sqlite3_value_dup sqlite3_value_free +sqlite3_value_frombind sqlite3_value_int64 sqlite3_value_nochange sqlite3_value_numeric_type diff --git a/embed/sqlite3.wasm b/embed/sqlite3.wasm index 4700bbf9..99671f14 100755 Binary files a/embed/sqlite3.wasm and b/embed/sqlite3.wasm differ diff --git a/gormlite/sqlite.go b/gormlite/sqlite.go index cbbff742..57062814 100644 --- a/gormlite/sqlite.go +++ b/gormlite/sqlite.go @@ -37,7 +37,7 @@ func (dialector _Dialector) Initialize(db *gorm.DB) (err error) { if dialector.Conn != nil { db.ConnPool = dialector.Conn } else { - conn, err := driver.Open(dialector.DSN, nil) + conn, err := driver.Open(dialector.DSN) if err != nil { return err } diff --git a/internal/util/error.go b/internal/util/error.go index 2aecac96..8ac52c54 100644 --- a/internal/util/error.go +++ b/internal/util/error.go @@ -12,6 +12,7 @@ func (e ErrorString) Error() string { return string(e) } const ( NilErr = ErrorString("sqlite3: invalid memory address or null pointer dereference") OOMErr = ErrorString("sqlite3: out of memory") + ArgErr = ErrorString("sqlite3: invalid argument") RangeErr = ErrorString("sqlite3: index out of range") NoNulErr = ErrorString("sqlite3: missing NUL terminator") NoBinaryErr = ErrorString("sqlite3: no SQLite binary embed/set/loaded") diff --git a/internal/util/mmap.go b/internal/util/mmap.go index 434cc12a..b091e38b 100644 --- a/internal/util/mmap.go +++ b/internal/util/mmap.go @@ -32,7 +32,7 @@ func (s *mmapState) new(ctx context.Context, mod api.Module, size int32) *Mapped // Allocate page aligned memmory. alloc := mod.ExportedFunction("aligned_alloc") - stack := [2]uint64{ + stack := [...]uint64{ uint64(unix.Getpagesize()), uint64(size), } diff --git a/sqlite.go b/sqlite.go index e698fc81..712ad516 100644 --- a/sqlite.go +++ b/sqlite.go @@ -85,7 +85,7 @@ type sqlite struct { id [32]*byte mask uint32 } - stack [8]uint64 + stack [9]uint64 freer uint32 } @@ -306,6 +306,7 @@ func exportCallbacks(env wazero.HostModuleBuilder) wazero.HostModuleBuilder { util.ExportFuncVI(env, "go_rollback_hook", rollbackCallback) util.ExportFuncVIIIIJ(env, "go_update_hook", updateCallback) util.ExportFuncIIIII(env, "go_wal_hook", walCallback) + util.ExportFuncIIIII(env, "go_trace", traceCallback) util.ExportFuncIIIIII(env, "go_autovacuum_pages", autoVacuumCallback) util.ExportFuncIIIIIII(env, "go_authorizer", authorizerCallback) util.ExportFuncVIII(env, "go_log", logCallback) diff --git a/sqlite3/hooks.c b/sqlite3/hooks.c index c8721317..4acc2fd8 100644 --- a/sqlite3/hooks.c +++ b/sqlite3/hooks.c @@ -10,7 +10,7 @@ int go_commit_hook(void *); void go_rollback_hook(void *); void go_update_hook(void *, int, char const *, char const *, sqlite3_int64); int go_wal_hook(void *, sqlite3 *, const char *, int); - +int go_trace(unsigned, void *, void *, void *); int go_authorizer(void *, int, const char *, const char *, const char *, const char *); @@ -47,6 +47,10 @@ int sqlite3_set_authorizer_go(sqlite3 *db, bool enable) { return sqlite3_set_authorizer(db, enable ? go_authorizer : NULL, /*arg=*/db); } +int sqlite3_trace_go(sqlite3 *db, unsigned mask) { + return sqlite3_trace_v2(db, mask, go_trace, /*arg=*/db); +} + int sqlite3_config_log_go(bool enable) { return sqlite3_config(SQLITE_CONFIG_LOG, enable ? go_log : NULL, /*arg=*/NULL); diff --git a/stmt.go b/stmt.go index 381a7d06..6fb83688 100644 --- a/stmt.go +++ b/stmt.go @@ -15,6 +15,7 @@ import ( type Stmt struct { c *Conn err error + sql string handle uint32 } @@ -29,6 +30,15 @@ func (s *Stmt) Close() error { } r := s.c.call("sqlite3_finalize", uint64(s.handle)) + for i := range s.c.stmts { + if s == s.c.stmts[i] { + l := len(s.c.stmts) - 1 + s.c.stmts[i] = s.c.stmts[l] + s.c.stmts[l] = nil + s.c.stmts = s.c.stmts[:l] + break + } + } s.handle = 0 return s.c.error(r) @@ -41,6 +51,24 @@ func (s *Stmt) Conn() *Conn { return s.c } +// SQL returns the SQL text used to create the prepared statement. +// +// https://sqlite.org/c3ref/expanded_sql.html +func (s *Stmt) SQL() string { + return s.sql +} + +// ExpandedSQL returns the SQL text of the prepared statement +// with bound parameters expanded. +// +// https://sqlite.org/c3ref/expanded_sql.html +func (s *Stmt) ExpandedSQL() string { + r := s.c.call("sqlite3_expanded_sql", uint64(s.handle)) + sql := util.ReadString(s.c.mod, uint32(r), _MAX_SQL_LENGTH) + s.c.free(uint32(r)) + return sql +} + // ReadOnly returns true if and only if the statement // makes no direct changes to the content of the database file. // diff --git a/tests/conn_test.go b/tests/conn_test.go index 0b85f794..e59ad476 100644 --- a/tests/conn_test.go +++ b/tests/conn_test.go @@ -543,6 +543,78 @@ func TestConn_SetAuthorizer(t *testing.T) { } } +func TestConn_Trace(t *testing.T) { + t.Parallel() + + db, err := sqlite3.Open(":memory:") + if err != nil { + t.Fatal(err) + } + defer db.Close() + + rows := 0 + closed := false + err = db.Trace(math.MaxUint32, func(evt sqlite3.TraceEvent, a1 any, a2 any) error { + switch evt { + case sqlite3.TRACE_CLOSE: + closed = true + _ = a1.(*sqlite3.Conn) + return db.Exec(`PRAGMA optimize`) + case sqlite3.TRACE_STMT: + stmt := a1.(*sqlite3.Stmt) + if sql := a2.(string); sql != stmt.SQL() { + t.Errorf("got %q, want %q", sql, stmt.SQL()) + } + if sql := stmt.ExpandedSQL(); sql != `SELECT 1` { + t.Errorf("got %q", sql) + } + case sqlite3.TRACE_PROFILE: + _ = a1.(*sqlite3.Stmt) + if ns := a2.(int64); ns < 0 { + t.Errorf("got %d", ns) + } + case sqlite3.TRACE_ROW: + _ = a1.(*sqlite3.Stmt) + if a2 != nil { + t.Errorf("got %v", a2) + } + rows++ + } + return nil + }) + if err != nil { + t.Fatal(err) + } + + stmt, _, err := db.Prepare(`SELECT ?`) + if err != nil { + t.Fatal(err) + } + err = stmt.BindInt(1, 1) + if err != nil { + t.Fatal(err) + } + err = stmt.Exec() + if err != nil { + t.Fatal(err) + } + err = stmt.Close() + if err != nil { + t.Fatal(err) + } + if rows != 1 { + t.Error("want 1") + } + + err = db.Close() + if err != nil { + t.Fatal(err) + } + if !closed { + t.Error("want closed") + } +} + func TestConn_ReleaseMemory(t *testing.T) { t.Parallel() @@ -684,3 +756,96 @@ func TestConn_AutoVacuumPages(t *testing.T) { t.Fatal(err) } } + +func TestConn_Status(t *testing.T) { + t.Parallel() + + db, err := sqlite3.Open(":memory:") + if err != nil { + t.Fatal(err) + } + defer db.Close() + + err = db.Exec(`CREATE TABLE test (col)`) + if err != nil { + t.Fatal(err) + } + + cr, hi, err := db.Status(sqlite3.DBSTATUS_SCHEMA_USED, true) + if err != nil { + t.Error("want nil") + } + if cr == 0 { + t.Error("want something") + } + if hi != 0 { + t.Error("want zero") + } + + cr, hi, err = db.Status(sqlite3.DBSTATUS_LOOKASIDE_HIT, true) + if err != nil { + t.Error("want nil") + } + if cr != 0 { + t.Error("want zero") + } + if hi == 0 { + t.Error("want something") + } + + cr, hi, err = db.Status(sqlite3.DBSTATUS_LOOKASIDE_HIT, true) + if err != nil { + t.Error("want nil") + } + if cr != 0 { + t.Error("want zero") + } + if hi != 0 { + t.Error("want zero") + } +} + +func TestConn_TableColumnMetadata(t *testing.T) { + t.Parallel() + + db, err := sqlite3.Open(":memory:") + if err != nil { + t.Fatal(err) + } + defer db.Close() + + err = db.Exec(`CREATE TABLE test (col)`) + if err != nil { + t.Fatal(err) + } + + _, _, _, _, _, err = db.TableColumnMetadata("", "table", "") + if err == nil { + t.Error("want error") + } + + _, _, _, _, _, err = db.TableColumnMetadata("", "test", "") + if err != nil { + t.Error("want nil") + } + + typ, ord, nn, pk, ai, err := db.TableColumnMetadata("", "test", "rowid") + if err != nil { + t.Error("want nil") + } + if typ != "INTEGER" { + t.Error("want INTEGER") + } + if ord != "BINARY" { + t.Error("want BINARY") + } + if nn != false { + t.Error("want false") + } + if pk != true { + t.Error("want true") + } + if ai != false { + t.Error("want false") + } +} diff --git a/tests/driver_test.go b/tests/driver_test.go index e176ea1f..f6a4f097 100644 --- a/tests/driver_test.go +++ b/tests/driver_test.go @@ -4,6 +4,7 @@ import ( "context" "testing" + "github.com/ncruces/go-sqlite3" "github.com/ncruces/go-sqlite3/driver" _ "github.com/ncruces/go-sqlite3/embed" _ "github.com/ncruces/go-sqlite3/internal/testcfg" @@ -15,7 +16,9 @@ func TestDriver(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() - db, err := driver.Open(":memory:", nil) + db, err := driver.Open(":memory:", nil, func(c *sqlite3.Conn) error { + return c.Exec(`PRAGMA optimize`) + }) if err != nil { t.Fatal(err) } diff --git a/tests/func_test.go b/tests/func_test.go index 921b3efe..acedc868 100644 --- a/tests/func_test.go +++ b/tests/func_test.go @@ -48,7 +48,9 @@ func TestCreateFunction(t *testing.T) { case 10: ctx.ResultNull() case 11: - arg.NoChange() + if arg.NoChange() || arg.FromBind() { + t.Error() + } ctx.ResultError(sqlite3.FULL) } }) diff --git a/tests/json_test.go b/tests/json_test.go index 9d442ad3..d8328d77 100644 --- a/tests/json_test.go +++ b/tests/json_test.go @@ -20,7 +20,7 @@ func TestJSON(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() - db, err := driver.Open(":memory:", nil) + db, err := driver.Open(":memory:") if err != nil { t.Fatal(err) } diff --git a/tests/stmt_test.go b/tests/stmt_test.go index 7542673a..c0cb2675 100644 --- a/tests/stmt_test.go +++ b/tests/stmt_test.go @@ -503,6 +503,13 @@ func TestStmt(t *testing.T) { } } + db.Stmts()(func(s *sqlite3.Stmt) bool { + if s != stmt { + t.Error() + } + return false + }) + if err := stmt.Close(); err != nil { t.Fatal(err) } diff --git a/tests/time_test.go b/tests/time_test.go index 3c67beec..529355a0 100644 --- a/tests/time_test.go +++ b/tests/time_test.go @@ -136,7 +136,7 @@ func TestTimeFormat_Scanner(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() - db, err := driver.Open(":memory:", nil) + db, err := driver.Open(":memory:") if err != nil { t.Fatal(err) } diff --git a/tests/txn_test.go b/tests/txn_test.go index 3f11af3e..a56abe16 100644 --- a/tests/txn_test.go +++ b/tests/txn_test.go @@ -184,6 +184,10 @@ func TestConn_Transaction_interrupt(t *testing.T) { if err != nil { t.Fatal(err) } + err = db.CacheFlush() + if err != nil { + t.Fatal(err) + } cancel() _, err = db.BeginImmediate() diff --git a/tests/wal_test.go b/tests/wal_test.go index b3d21010..432841eb 100644 --- a/tests/wal_test.go +++ b/tests/wal_test.go @@ -54,13 +54,13 @@ func TestWAL_readonly(t *testing.T) { tmp := filepath.ToSlash(filepath.Join(t.TempDir(), "test.db")) - db1, err := driver.Open("file:"+tmp+"?_pragma=journal_mode(wal)&_txlock=immediate", nil) + db1, err := driver.Open("file:" + tmp + "?_pragma=journal_mode(wal)&_txlock=immediate") if err != nil { t.Fatal(err) } defer db1.Close() - db2, err := driver.Open("file:"+tmp+"?_pragma=journal_mode(wal)&mode=ro", nil) + db2, err := driver.Open("file:" + tmp + "?_pragma=journal_mode(wal)&mode=ro") if err != nil { t.Fatal(err) } diff --git a/txn.go b/txn.go index 1badf676..6fe288e1 100644 --- a/txn.go +++ b/txn.go @@ -305,3 +305,11 @@ func updateCallback(ctx context.Context, mod api.Module, pDB uint32, action Auth c.update(action, schema, table, int64(rowid)) } } + +// CacheFlush flushes caches to disk mid-transaction. +// +// https://sqlite.org/c3ref/db_cacheflush.html +func (c *Conn) CacheFlush() error { + r := c.call("sqlite3_db_cacheflush", uint64(c.handle)) + return c.error(r) +} diff --git a/value.go b/value.go index 1894ff4f..86f6689d 100644 --- a/value.go +++ b/value.go @@ -201,6 +201,14 @@ func (v Value) NoChange() bool { return r != 0 } +// FromBind returns true if value originated from a bound parameter. +// +// https://sqlite.org/c3ref/value_blob.html +func (v Value) FromBind() bool { + r := v.c.call("sqlite3_value_frombind", v.protected()) + return r != 0 +} + // InFirst returns the first element // on the right-hand side of an IN constraint. // diff --git a/vfs/adiantum/adiantum_test.go b/vfs/adiantum/adiantum_test.go index 0f1d289c..1096a358 100644 --- a/vfs/adiantum/adiantum_test.go +++ b/vfs/adiantum/adiantum_test.go @@ -23,7 +23,7 @@ func Test_fileformat(t *testing.T) { readervfs.Create("test.db", ioutil.NewSizeReaderAt(strings.NewReader(testDB))) adiantum.Register("radiantum", vfs.Find("reader"), nil) - db, err := driver.Open("file:test.db?vfs=radiantum", nil) + db, err := driver.Open("file:test.db?vfs=radiantum") if err != nil { t.Fatal(err) }