From 6451bb73eb664543e73893b1748eff9f2bbdaab8 Mon Sep 17 00:00:00 2001 From: "lingfeng.wu@mihoyo.com" Date: Fri, 12 May 2023 12:04:39 +0800 Subject: [PATCH] feat(context): export wrapped parent struct and add stmt context --- conn.go | 18 +++++++++--------- stmt.go | 11 ++++++----- 2 files changed, 15 insertions(+), 14 deletions(-) diff --git a/conn.go b/conn.go index be89a4a..66da441 100755 --- a/conn.go +++ b/conn.go @@ -43,7 +43,7 @@ func (c wrappedConn) Begin() (driver.Tx, error) { } func (c wrappedConn) BeginTx(ctx context.Context, opts driver.TxOptions) (tx driver.Tx, err error) { - wrappedParent := wrappedParentConn{c.parent} + wrappedParent := WrappedParentConn{c.parent} ctx, tx, err = c.intr.ConnBeginTx(ctx, wrappedParent, opts) if err != nil { return nil, err @@ -52,7 +52,7 @@ func (c wrappedConn) BeginTx(ctx context.Context, opts driver.TxOptions) (tx dri } func (c wrappedConn) PrepareContext(ctx context.Context, query string) (stmt driver.Stmt, err error) { - wrappedParent := wrappedParentConn{c.parent} + wrappedParent := WrappedParentConn{c.parent} ctx, stmt, err = c.intr.ConnPrepareContext(ctx, wrappedParent, query) if err != nil { return nil, err @@ -73,7 +73,7 @@ func (c wrappedConn) Exec(query string, args []driver.Value) (driver.Result, err } func (c wrappedConn) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (r driver.Result, err error) { - wrappedParent := wrappedParentConn{c.parent} + wrappedParent := WrappedParentConn{c.parent} r, err = c.intr.ConnExecContext(ctx, wrappedParent, query, args) if err != nil { return nil, err @@ -107,7 +107,7 @@ func (c wrappedConn) QueryContext(ctx context.Context, query string, args []driv return nil, driver.ErrSkip } - wrappedParent := wrappedParentConn{c.parent} + wrappedParent := WrappedParentConn{c.parent} ctx, rows, err = c.intr.ConnQueryContext(ctx, wrappedParent, query, args) if err != nil { return nil, err @@ -116,11 +116,11 @@ func (c wrappedConn) QueryContext(ctx context.Context, query string, args []driv return wrapRows(ctx, c.intr, rows), nil } -type wrappedParentConn struct { +type WrappedParentConn struct { driver.Conn } -func (c wrappedParentConn) BeginTx(ctx context.Context, opts driver.TxOptions) (tx driver.Tx, err error) { +func (c WrappedParentConn) BeginTx(ctx context.Context, opts driver.TxOptions) (tx driver.Tx, err error) { if connBeginTx, ok := c.Conn.(driver.ConnBeginTx); ok { return connBeginTx.BeginTx(ctx, opts) } @@ -133,7 +133,7 @@ func (c wrappedParentConn) BeginTx(ctx context.Context, opts driver.TxOptions) ( } } -func (c wrappedParentConn) PrepareContext(ctx context.Context, query string) (driver.Stmt, error) { +func (c WrappedParentConn) PrepareContext(ctx context.Context, query string) (driver.Stmt, error) { if connPrepareCtx, ok := c.Conn.(driver.ConnPrepareContext); ok { return connPrepareCtx.PrepareContext(ctx, query) } @@ -146,7 +146,7 @@ func (c wrappedParentConn) PrepareContext(ctx context.Context, query string) (dr } } -func (c wrappedParentConn) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (r driver.Result, err error) { +func (c WrappedParentConn) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (r driver.Result, err error) { if execContext, ok := c.Conn.(driver.ExecerContext); ok { return execContext.ExecContext(ctx, query, args) } @@ -163,7 +163,7 @@ func (c wrappedParentConn) ExecContext(ctx context.Context, query string, args [ } } -func (c wrappedParentConn) QueryContext(ctx context.Context, query string, args []driver.NamedValue) (rows driver.Rows, err error) { +func (c WrappedParentConn) QueryContext(ctx context.Context, query string, args []driver.NamedValue) (rows driver.Rows, err error) { if queryerContext, ok := c.Conn.(driver.QueryerContext); ok { return queryerContext.QueryContext(ctx, query, args) } diff --git a/stmt.go b/stmt.go index f8c89ba..e6bb2ec 100755 --- a/stmt.go +++ b/stmt.go @@ -46,7 +46,7 @@ func (s wrappedStmt) Query(args []driver.Value) (rows driver.Rows, err error) { } func (s wrappedStmt) ExecContext(ctx context.Context, args []driver.NamedValue) (res driver.Result, err error) { - wrappedParent := wrappedParentStmt{Stmt: s.parent} + wrappedParent := WrappedParentStmt{Stmt: s.parent, Context: s.ctx} res, err = s.intr.StmtExecContext(ctx, wrappedParent, s.query, args) if err != nil { return nil, err @@ -55,7 +55,7 @@ func (s wrappedStmt) ExecContext(ctx context.Context, args []driver.NamedValue) } func (s wrappedStmt) QueryContext(ctx context.Context, args []driver.NamedValue) (rows driver.Rows, err error) { - wrappedParent := wrappedParentStmt{Stmt: s.parent} + wrappedParent := WrappedParentStmt{Stmt: s.parent, Context: s.ctx} ctx, rows, err = s.intr.StmtQueryContext(ctx, wrappedParent, s.query, args) if err != nil { return nil, err @@ -71,11 +71,12 @@ func (s wrappedStmt) ColumnConverter(idx int) driver.ValueConverter { return driver.DefaultParameterConverter } -type wrappedParentStmt struct { +type WrappedParentStmt struct { driver.Stmt + Context context.Context } -func (s wrappedParentStmt) QueryContext(ctx context.Context, args []driver.NamedValue) (rows driver.Rows, err error) { +func (s WrappedParentStmt) QueryContext(ctx context.Context, args []driver.NamedValue) (rows driver.Rows, err error) { if stmtQueryContext, ok := s.Stmt.(driver.StmtQueryContext); ok { return stmtQueryContext.QueryContext(ctx, args) } @@ -92,7 +93,7 @@ func (s wrappedParentStmt) QueryContext(ctx context.Context, args []driver.Named return s.Stmt.Query(dargs) } -func (s wrappedParentStmt) ExecContext(ctx context.Context, args []driver.NamedValue) (res driver.Result, err error) { +func (s WrappedParentStmt) ExecContext(ctx context.Context, args []driver.NamedValue) (res driver.Result, err error) { if stmtExecContext, ok := s.Stmt.(driver.StmtExecContext); ok { return stmtExecContext.ExecContext(ctx, args) }