diff --git a/go/core/core.go b/go/core/core.go index beb7207f..f5a71f34 100644 --- a/go/core/core.go +++ b/go/core/core.go @@ -61,10 +61,16 @@ type StaticTags struct { // CommenterOptions contains all options regarding SQLCommenter library. // This includes the configurations as well as any static tags. type CommenterOptions struct { + Driver DriverOptions Config CommenterConfig Tags StaticTags } +type DriverOptions struct { + // Setting this to true means your underlying driver supports the ConnBeginTx interface + WithBeginTX bool +} + func encodeURL(k string) string { return url.QueryEscape(k) } diff --git a/go/database/sql/connection.go b/go/database/sql/connection.go index 2dbb5f1d..faac4201 100644 --- a/go/database/sql/connection.go +++ b/go/database/sql/connection.go @@ -31,6 +31,19 @@ type sqlCommenterConn struct { options core.CommenterOptions } +type sqlCommenterConnWithTx struct { + sqlCommenterConn +} + +func newConn(conn driver.Conn, options core.CommenterOptions) driver.Conn { + commenterConn := newSQLCommenterConn(conn, options) + if options.Driver.WithBeginTX { + return &sqlCommenterConnWithTx{*commenterConn} + } + + return commenterConn +} + func newSQLCommenterConn(conn driver.Conn, options core.CommenterOptions) *sqlCommenterConn { return &sqlCommenterConn{ Conn: conn, @@ -89,6 +102,15 @@ func (s *sqlCommenterConn) Raw() driver.Conn { return s.Conn } +func (s *sqlCommenterConnWithTx) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, error) { + begintxer, ok := s.Conn.(driver.ConnBeginTx) + if !ok { + return nil, driver.ErrSkip + } + + return begintxer.BeginTx(ctx, opts) +} + // ***** Commenter Functions ***** func (conn *sqlCommenterConn) withComment(ctx context.Context, query string) string { diff --git a/go/database/sql/gosql.go b/go/database/sql/gosql.go index 525ddd37..24f68396 100644 --- a/go/database/sql/gosql.go +++ b/go/database/sql/gosql.go @@ -43,7 +43,7 @@ func (d *sqlCommenterDriver) Open(name string) (driver.Conn, error) { if err != nil { return nil, err } - return newSQLCommenterConn(rawConn, d.options), nil + return newConn(rawConn, d.options), nil } func (d *sqlCommenterDriver) OpenConnector(name string) (driver.Connector, error) { @@ -73,7 +73,7 @@ func (c *sqlCommenterConnector) Connect(ctx context.Context) (connection driver. if err != nil { return nil, err } - return newSQLCommenterConn(connection, c.options), nil + return newConn(connection, c.options), nil } func (c *sqlCommenterConnector) Driver() driver.Driver { diff --git a/go/go.work b/go/go.work new file mode 100644 index 00000000..086e1927 --- /dev/null +++ b/go/go.work @@ -0,0 +1,9 @@ +go 1.19 + +use ( + ./core + ./database/sql + ./gorrila/mux + ./net/http + ./samples/http +)