diff --git a/.gitignore b/.gitignore index e4001c0..0e5426a 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,4 @@ /examples/blog/blog /examples/orders/orders /examples/basic/basic +.idea/ diff --git a/expectations.go b/expectations.go index ae2a47f..5c82c7b 100644 --- a/expectations.go +++ b/expectations.go @@ -339,21 +339,6 @@ type queryBasedExpectation struct { args []driver.Value } -func (e *queryBasedExpectation) attemptArgMatch(args []namedValue) (err error) { - // catch panic - defer func() { - if e := recover(); e != nil { - _, ok := e.(error) - if !ok { - err = fmt.Errorf(e.(string)) - } - } - }() - - err = e.argsMatches(args) - return -} - // ExpectedPing is used to manage *sql.DB.Ping expectations. // Returned by *Sqlmock.ExpectPing. type ExpectedPing struct { diff --git a/expectations_before_go18.go b/expectations_before_go18.go index e368e04..f6e7b4e 100644 --- a/expectations_before_go18.go +++ b/expectations_before_go18.go @@ -50,3 +50,18 @@ func (e *queryBasedExpectation) argsMatches(args []namedValue) error { } return nil } + +func (e *queryBasedExpectation) attemptArgMatch(args []namedValue) (err error) { + // catch panic + defer func() { + if e := recover(); e != nil { + _, ok := e.(error) + if !ok { + err = fmt.Errorf(e.(string)) + } + } + }() + + err = e.argsMatches(args) + return +} diff --git a/expectations_before_go18_test.go b/expectations_before_go18_test.go new file mode 100644 index 0000000..897ebff --- /dev/null +++ b/expectations_before_go18_test.go @@ -0,0 +1,118 @@ +// +build !go1.8 + +package sqlmock + +import ( + "database/sql/driver" + "testing" + "time" +) + +func TestQueryExpectationArgComparison(t *testing.T) { + e := &queryBasedExpectation{converter: driver.DefaultParameterConverter} + against := []namedValue{{Value: int64(5), Ordinal: 1}} + if err := e.argsMatches(against); err != nil { + t.Errorf("arguments should match, since the no expectation was set, but got err: %s", err) + } + + e.args = []driver.Value{5, "str"} + + against = []namedValue{{Value: int64(5), Ordinal: 1}} + if err := e.argsMatches(against); err == nil { + t.Error("arguments should not match, since the size is not the same") + } + + against = []namedValue{ + {Value: int64(3), Ordinal: 1}, + {Value: "str", Ordinal: 2}, + } + if err := e.argsMatches(against); err == nil { + t.Error("arguments should not match, since the first argument (int value) is different") + } + + against = []namedValue{ + {Value: int64(5), Ordinal: 1}, + {Value: "st", Ordinal: 2}, + } + if err := e.argsMatches(against); err == nil { + t.Error("arguments should not match, since the second argument (string value) is different") + } + + against = []namedValue{ + {Value: int64(5), Ordinal: 1}, + {Value: "str", Ordinal: 2}, + } + if err := e.argsMatches(against); err != nil { + t.Errorf("arguments should match, but it did not: %s", err) + } + + const longForm = "Jan 2, 2006 at 3:04pm (MST)" + tm, _ := time.Parse(longForm, "Feb 3, 2013 at 7:54pm (PST)") + e.args = []driver.Value{5, tm} + + against = []namedValue{ + {Value: int64(5), Ordinal: 1}, + {Value: tm, Ordinal: 2}, + } + if err := e.argsMatches(against); err != nil { + t.Error("arguments should match, but it did not") + } + + e.args = []driver.Value{5, AnyArg()} + if err := e.argsMatches(against); err != nil { + t.Errorf("arguments should match, but it did not: %s", err) + } +} + +func TestQueryExpectationArgComparisonBool(t *testing.T) { + var e *queryBasedExpectation + + e = &queryBasedExpectation{args: []driver.Value{true}, converter: driver.DefaultParameterConverter} + against := []namedValue{ + {Value: true, Ordinal: 1}, + } + if err := e.argsMatches(against); err != nil { + t.Error("arguments should match, since arguments are the same") + } + + e = &queryBasedExpectation{args: []driver.Value{false}, converter: driver.DefaultParameterConverter} + against = []namedValue{ + {Value: false, Ordinal: 1}, + } + if err := e.argsMatches(against); err != nil { + t.Error("arguments should match, since argument are the same") + } + + e = &queryBasedExpectation{args: []driver.Value{true}, converter: driver.DefaultParameterConverter} + against = []namedValue{ + {Value: false, Ordinal: 1}, + } + if err := e.argsMatches(against); err == nil { + t.Error("arguments should not match, since argument is different") + } + + e = &queryBasedExpectation{args: []driver.Value{false}, converter: driver.DefaultParameterConverter} + against = []namedValue{ + {Value: true, Ordinal: 1}, + } + if err := e.argsMatches(against); err == nil { + t.Error("arguments should not match, since argument is different") + } +} + +type panicConverter struct { +} + +func (s panicConverter) ConvertValue(v interface{}) (driver.Value, error) { + panic(v) +} + +func Test_queryBasedExpectation_attemptArgMatch(t *testing.T) { + e := &queryBasedExpectation{converter: new(panicConverter), args: []driver.Value{"test"}} + values := []namedValue{ + {Ordinal: 1, Name: "test", Value: "test"}, + } + if err := e.attemptArgMatch(values); err == nil { + t.Errorf("error expected") + } +} diff --git a/expectations_go18.go b/expectations_go18.go index 6ee8adf..172bb6c 100644 --- a/expectations_go18.go +++ b/expectations_go18.go @@ -4,6 +4,7 @@ package sqlmock import ( "database/sql" + "database/sql/driver" "fmt" "reflect" ) @@ -19,7 +20,7 @@ func (e *ExpectedQuery) WillReturnRows(rows ...*Rows) *ExpectedQuery { return e } -func (e *queryBasedExpectation) argsMatches(args []namedValue) error { +func (e *queryBasedExpectation) argsMatches(args []driver.NamedValue) error { if nil == e.args { return nil } @@ -59,3 +60,18 @@ func (e *queryBasedExpectation) argsMatches(args []namedValue) error { } return nil } + +func (e *queryBasedExpectation) attemptArgMatch(args []driver.NamedValue) (err error) { + // catch panic + defer func() { + if e := recover(); e != nil { + _, ok := e.(error) + if !ok { + err = fmt.Errorf(e.(string)) + } + } + }() + + err = e.argsMatches(args) + return +} diff --git a/expectations_go18_test.go b/expectations_go18_test.go index 2b85db3..1974721 100644 --- a/expectations_go18_test.go +++ b/expectations_go18_test.go @@ -6,11 +6,104 @@ import ( "database/sql" "database/sql/driver" "testing" + "time" ) +func TestQueryExpectationArgComparison(t *testing.T) { + e := &queryBasedExpectation{converter: driver.DefaultParameterConverter} + against := []driver.NamedValue{{Value: int64(5), Ordinal: 1}} + if err := e.argsMatches(against); err != nil { + t.Errorf("arguments should match, since the no expectation was set, but got err: %s", err) + } + + e.args = []driver.Value{5, "str"} + + against = []driver.NamedValue{{Value: int64(5), Ordinal: 1}} + if err := e.argsMatches(against); err == nil { + t.Error("arguments should not match, since the size is not the same") + } + + against = []driver.NamedValue{ + {Value: int64(3), Ordinal: 1}, + {Value: "str", Ordinal: 2}, + } + if err := e.argsMatches(against); err == nil { + t.Error("arguments should not match, since the first argument (int value) is different") + } + + against = []driver.NamedValue{ + {Value: int64(5), Ordinal: 1}, + {Value: "st", Ordinal: 2}, + } + if err := e.argsMatches(against); err == nil { + t.Error("arguments should not match, since the second argument (string value) is different") + } + + against = []driver.NamedValue{ + {Value: int64(5), Ordinal: 1}, + {Value: "str", Ordinal: 2}, + } + if err := e.argsMatches(against); err != nil { + t.Errorf("arguments should match, but it did not: %s", err) + } + + const longForm = "Jan 2, 2006 at 3:04pm (MST)" + tm, _ := time.Parse(longForm, "Feb 3, 2013 at 7:54pm (PST)") + e.args = []driver.Value{5, tm} + + against = []driver.NamedValue{ + {Value: int64(5), Ordinal: 1}, + {Value: tm, Ordinal: 2}, + } + if err := e.argsMatches(against); err != nil { + t.Error("arguments should match, but it did not") + } + + e.args = []driver.Value{5, AnyArg()} + if err := e.argsMatches(against); err != nil { + t.Errorf("arguments should match, but it did not: %s", err) + } +} + +func TestQueryExpectationArgComparisonBool(t *testing.T) { + var e *queryBasedExpectation + + e = &queryBasedExpectation{args: []driver.Value{true}, converter: driver.DefaultParameterConverter} + against := []driver.NamedValue{ + {Value: true, Ordinal: 1}, + } + if err := e.argsMatches(against); err != nil { + t.Error("arguments should match, since arguments are the same") + } + + e = &queryBasedExpectation{args: []driver.Value{false}, converter: driver.DefaultParameterConverter} + against = []driver.NamedValue{ + {Value: false, Ordinal: 1}, + } + if err := e.argsMatches(against); err != nil { + t.Error("arguments should match, since argument are the same") + } + + e = &queryBasedExpectation{args: []driver.Value{true}, converter: driver.DefaultParameterConverter} + against = []driver.NamedValue{ + {Value: false, Ordinal: 1}, + } + if err := e.argsMatches(against); err == nil { + t.Error("arguments should not match, since argument is different") + } + + e = &queryBasedExpectation{args: []driver.Value{false}, converter: driver.DefaultParameterConverter} + against = []driver.NamedValue{ + {Value: true, Ordinal: 1}, + } + if err := e.argsMatches(against); err == nil { + t.Error("arguments should not match, since argument is different") + } +} + func TestQueryExpectationNamedArgComparison(t *testing.T) { e := &queryBasedExpectation{converter: driver.DefaultParameterConverter} - against := []namedValue{{Value: int64(5), Name: "id"}} + against := []driver.NamedValue{{Value: int64(5), Name: "id"}} if err := e.argsMatches(against); err != nil { t.Errorf("arguments should match, since the no expectation was set, but got err: %s", err) } @@ -24,7 +117,7 @@ func TestQueryExpectationNamedArgComparison(t *testing.T) { t.Error("arguments should not match, since the size is not the same") } - against = []namedValue{ + against = []driver.NamedValue{ {Value: int64(5), Name: "id"}, {Value: "str", Name: "s"}, } @@ -33,7 +126,7 @@ func TestQueryExpectationNamedArgComparison(t *testing.T) { t.Errorf("arguments should have matched, but it did not: %v", err) } - against = []namedValue{ + against = []driver.NamedValue{ {Value: int64(5), Name: "id"}, {Value: "str", Name: "username"}, } @@ -44,7 +137,7 @@ func TestQueryExpectationNamedArgComparison(t *testing.T) { e.args = []driver.Value{int64(5), "str"} - against = []namedValue{ + against = []driver.NamedValue{ {Value: int64(5), Ordinal: 0}, {Value: "str", Ordinal: 1}, } @@ -53,7 +146,7 @@ func TestQueryExpectationNamedArgComparison(t *testing.T) { t.Error("arguments matched, but it should have not due to wrong Ordinal position") } - against = []namedValue{ + against = []driver.NamedValue{ {Value: int64(5), Ordinal: 1}, {Value: "str", Ordinal: 2}, } @@ -62,3 +155,20 @@ func TestQueryExpectationNamedArgComparison(t *testing.T) { t.Errorf("arguments should have matched, but it did not: %v", err) } } + +type panicConverter struct { +} + +func (s panicConverter) ConvertValue(v interface{}) (driver.Value, error) { + panic(v) +} + +func Test_queryBasedExpectation_attemptArgMatch(t *testing.T) { + e := &queryBasedExpectation{converter: new(panicConverter), args: []driver.Value{"test"}} + values := []driver.NamedValue{ + {Ordinal: 1, Name: "test", Value: "test"}, + } + if err := e.attemptArgMatch(values); err == nil { + t.Errorf("error expected") + } +} diff --git a/expectations_test.go b/expectations_test.go index c6889c3..afda582 100644 --- a/expectations_test.go +++ b/expectations_test.go @@ -6,98 +6,20 @@ import ( "fmt" "reflect" "testing" - "time" ) -func TestQueryExpectationArgComparison(t *testing.T) { - e := &queryBasedExpectation{converter: driver.DefaultParameterConverter} - against := []namedValue{{Value: int64(5), Ordinal: 1}} - if err := e.argsMatches(against); err != nil { - t.Errorf("arguments should match, since the no expectation was set, but got err: %s", err) - } - - e.args = []driver.Value{5, "str"} - - against = []namedValue{{Value: int64(5), Ordinal: 1}} - if err := e.argsMatches(against); err == nil { - t.Error("arguments should not match, since the size is not the same") - } - - against = []namedValue{ - {Value: int64(3), Ordinal: 1}, - {Value: "str", Ordinal: 2}, - } - if err := e.argsMatches(against); err == nil { - t.Error("arguments should not match, since the first argument (int value) is different") - } - - against = []namedValue{ - {Value: int64(5), Ordinal: 1}, - {Value: "st", Ordinal: 2}, - } - if err := e.argsMatches(against); err == nil { - t.Error("arguments should not match, since the second argument (string value) is different") - } - - against = []namedValue{ - {Value: int64(5), Ordinal: 1}, - {Value: "str", Ordinal: 2}, - } - if err := e.argsMatches(against); err != nil { - t.Errorf("arguments should match, but it did not: %s", err) - } - - const longForm = "Jan 2, 2006 at 3:04pm (MST)" - tm, _ := time.Parse(longForm, "Feb 3, 2013 at 7:54pm (PST)") - e.args = []driver.Value{5, tm} - - against = []namedValue{ - {Value: int64(5), Ordinal: 1}, - {Value: tm, Ordinal: 2}, - } - if err := e.argsMatches(against); err != nil { - t.Error("arguments should match, but it did not") - } - - e.args = []driver.Value{5, AnyArg()} - if err := e.argsMatches(against); err != nil { - t.Errorf("arguments should match, but it did not: %s", err) - } -} - -func TestQueryExpectationArgComparisonBool(t *testing.T) { - var e *queryBasedExpectation - - e = &queryBasedExpectation{args: []driver.Value{true}, converter: driver.DefaultParameterConverter} - against := []namedValue{ - {Value: true, Ordinal: 1}, - } - if err := e.argsMatches(against); err != nil { - t.Error("arguments should match, since arguments are the same") - } - - e = &queryBasedExpectation{args: []driver.Value{false}, converter: driver.DefaultParameterConverter} - against = []namedValue{ - {Value: false, Ordinal: 1}, - } - if err := e.argsMatches(against); err != nil { - t.Error("arguments should match, since argument are the same") - } - - e = &queryBasedExpectation{args: []driver.Value{true}, converter: driver.DefaultParameterConverter} - against = []namedValue{ - {Value: false, Ordinal: 1}, - } - if err := e.argsMatches(against); err == nil { - t.Error("arguments should not match, since argument is different") - } +type CustomConverter struct{} - e = &queryBasedExpectation{args: []driver.Value{false}, converter: driver.DefaultParameterConverter} - against = []namedValue{ - {Value: true, Ordinal: 1}, - } - if err := e.argsMatches(against); err == nil { - t.Error("arguments should not match, since argument is different") +func (s CustomConverter) ConvertValue(v interface{}) (driver.Value, error) { + switch v.(type) { + case string: + return v.(string), nil + case []string: + return v.([]string), nil + case int: + return v.(int), nil + default: + return nil, errors.New(fmt.Sprintf("cannot convert %T with value %v", v, v)) } } @@ -140,20 +62,6 @@ func TestBuildQuery(t *testing.T) { } } -type CustomConverter struct{} - -func (s CustomConverter) ConvertValue(v interface{}) (driver.Value, error) { - switch v.(type) { - case string: - return v.(string), nil - case []string: - return v.([]string), nil - case int: - return v.(int), nil - default: - return nil, errors.New(fmt.Sprintf("cannot convert %T with value %v", v, v)) - } -} func TestCustomValueConverterQueryScan(t *testing.T) { db, mock, _ := New(ValueConverterOption(CustomConverter{})) query := ` diff --git a/sqlmock.go b/sqlmock.go index 9431d0e..90f789b 100644 --- a/sqlmock.go +++ b/sqlmock.go @@ -265,88 +265,6 @@ func (c *sqlmock) ExpectBegin() *ExpectedBegin { return e } -// Exec meets http://golang.org/pkg/database/sql/driver/#Execer -func (c *sqlmock) Exec(query string, args []driver.Value) (driver.Result, error) { - namedArgs := make([]namedValue, len(args)) - for i, v := range args { - namedArgs[i] = namedValue{ - Ordinal: i + 1, - Value: v, - } - } - - ex, err := c.exec(query, namedArgs) - if ex != nil { - time.Sleep(ex.delay) - } - if err != nil { - return nil, err - } - - return ex.result, nil -} - -func (c *sqlmock) exec(query string, args []namedValue) (*ExpectedExec, error) { - var expected *ExpectedExec - var fulfilled int - var ok bool - for _, next := range c.expected { - next.Lock() - if next.fulfilled() { - next.Unlock() - fulfilled++ - continue - } - - if c.ordered { - if expected, ok = next.(*ExpectedExec); ok { - break - } - next.Unlock() - return nil, fmt.Errorf("call to ExecQuery '%s' with args %+v, was not expected, next expectation is: %s", query, args, next) - } - if exec, ok := next.(*ExpectedExec); ok { - if err := c.queryMatcher.Match(exec.expectSQL, query); err != nil { - next.Unlock() - continue - } - - if err := exec.attemptArgMatch(args); err == nil { - expected = exec - break - } - } - next.Unlock() - } - if expected == nil { - msg := "call to ExecQuery '%s' with args %+v was not expected" - if fulfilled == len(c.expected) { - msg = "all expectations were already fulfilled, " + msg - } - return nil, fmt.Errorf(msg, query, args) - } - defer expected.Unlock() - - if err := c.queryMatcher.Match(expected.expectSQL, query); err != nil { - return nil, fmt.Errorf("ExecQuery: %v", err) - } - - if err := expected.argsMatches(args); err != nil { - return nil, fmt.Errorf("ExecQuery '%s', arguments do not match: %s", query, err) - } - - expected.triggered = true - if expected.err != nil { - return expected, expected.err // mocked to return error - } - - if expected.result == nil { - return nil, fmt.Errorf("ExecQuery '%s' with args %+v, must return a database/sql/driver.Result, but it was not set for expectation %T as %+v", query, args, expected, expected) - } - - return expected, nil -} - func (c *sqlmock) ExpectExec(expectedSQL string) *ExpectedExec { e := &ExpectedExec{} e.expectSQL = expectedSQL @@ -421,94 +339,6 @@ func (c *sqlmock) ExpectPrepare(expectedSQL string) *ExpectedPrepare { return e } -type namedValue struct { - Name string - Ordinal int - Value driver.Value -} - -// Query meets http://golang.org/pkg/database/sql/driver/#Queryer -func (c *sqlmock) Query(query string, args []driver.Value) (driver.Rows, error) { - namedArgs := make([]namedValue, len(args)) - for i, v := range args { - namedArgs[i] = namedValue{ - Ordinal: i + 1, - Value: v, - } - } - - ex, err := c.query(query, namedArgs) - if ex != nil { - time.Sleep(ex.delay) - } - if err != nil { - return nil, err - } - - return ex.rows, nil -} - -func (c *sqlmock) query(query string, args []namedValue) (*ExpectedQuery, error) { - var expected *ExpectedQuery - var fulfilled int - var ok bool - for _, next := range c.expected { - next.Lock() - if next.fulfilled() { - next.Unlock() - fulfilled++ - continue - } - - if c.ordered { - if expected, ok = next.(*ExpectedQuery); ok { - break - } - next.Unlock() - return nil, fmt.Errorf("call to Query '%s' with args %+v, was not expected, next expectation is: %s", query, args, next) - } - if qr, ok := next.(*ExpectedQuery); ok { - if err := c.queryMatcher.Match(qr.expectSQL, query); err != nil { - next.Unlock() - continue - } - if err := qr.attemptArgMatch(args); err == nil { - expected = qr - break - } - } - next.Unlock() - } - - if expected == nil { - msg := "call to Query '%s' with args %+v was not expected" - if fulfilled == len(c.expected) { - msg = "all expectations were already fulfilled, " + msg - } - return nil, fmt.Errorf(msg, query, args) - } - - defer expected.Unlock() - - if err := c.queryMatcher.Match(expected.expectSQL, query); err != nil { - return nil, fmt.Errorf("Query: %v", err) - } - - if err := expected.argsMatches(args); err != nil { - return nil, fmt.Errorf("Query '%s', arguments do not match: %s", query, err) - } - - expected.triggered = true - if expected.err != nil { - return expected, expected.err // mocked to return error - } - - if expected.rows == nil { - return nil, fmt.Errorf("Query '%s' with args %+v, must return a database/sql/driver.Rows, but it was not set for expectation %T as %+v", query, args, expected, expected) - } - return expected, nil -} - func (c *sqlmock) ExpectQuery(expectedSQL string) *ExpectedQuery { e := &ExpectedQuery{} e.expectSQL = expectedSQL diff --git a/sqlmock_before_go18.go b/sqlmock_before_go18.go index 88b7aa0..1a5b63a 100644 --- a/sqlmock_before_go18.go +++ b/sqlmock_before_go18.go @@ -2,9 +2,184 @@ package sqlmock -import "log" +import ( + "database/sql/driver" + "fmt" + "log" + "time" +) + +type namedValue struct { + Name string + Ordinal int + Value driver.Value +} func (c *sqlmock) ExpectPing() *ExpectedPing { log.Println("ExpectPing has no effect on Go 1.7 or below") return &ExpectedPing{} } + +// Query meets http://golang.org/pkg/database/sql/driver/#Queryer +func (c *sqlmock) Query(query string, args []driver.Value) (driver.Rows, error) { + namedArgs := make([]namedValue, len(args)) + for i, v := range args { + namedArgs[i] = namedValue{ + Ordinal: i + 1, + Value: v, + } + } + + ex, err := c.query(query, namedArgs) + if ex != nil { + time.Sleep(ex.delay) + } + if err != nil { + return nil, err + } + + return ex.rows, nil +} + +func (c *sqlmock) query(query string, args []namedValue) (*ExpectedQuery, error) { + var expected *ExpectedQuery + var fulfilled int + var ok bool + for _, next := range c.expected { + next.Lock() + if next.fulfilled() { + next.Unlock() + fulfilled++ + continue + } + + if c.ordered { + if expected, ok = next.(*ExpectedQuery); ok { + break + } + next.Unlock() + return nil, fmt.Errorf("call to Query '%s' with args %+v, was not expected, next expectation is: %s", query, args, next) + } + if qr, ok := next.(*ExpectedQuery); ok { + if err := c.queryMatcher.Match(qr.expectSQL, query); err != nil { + next.Unlock() + continue + } + if err := qr.attemptArgMatch(args); err == nil { + expected = qr + break + } + } + next.Unlock() + } + + if expected == nil { + msg := "call to Query '%s' with args %+v was not expected" + if fulfilled == len(c.expected) { + msg = "all expectations were already fulfilled, " + msg + } + return nil, fmt.Errorf(msg, query, args) + } + + defer expected.Unlock() + + if err := c.queryMatcher.Match(expected.expectSQL, query); err != nil { + return nil, fmt.Errorf("Query: %v", err) + } + + if err := expected.argsMatches(args); err != nil { + return nil, fmt.Errorf("Query '%s', arguments do not match: %s", query, err) + } + + expected.triggered = true + if expected.err != nil { + return expected, expected.err // mocked to return error + } + + if expected.rows == nil { + return nil, fmt.Errorf("Query '%s' with args %+v, must return a database/sql/driver.Rows, but it was not set for expectation %T as %+v", query, args, expected, expected) + } + return expected, nil +} + +// Exec meets http://golang.org/pkg/database/sql/driver/#Execer +func (c *sqlmock) Exec(query string, args []driver.Value) (driver.Result, error) { + namedArgs := make([]namedValue, len(args)) + for i, v := range args { + namedArgs[i] = namedValue{ + Ordinal: i + 1, + Value: v, + } + } + + ex, err := c.exec(query, namedArgs) + if ex != nil { + time.Sleep(ex.delay) + } + if err != nil { + return nil, err + } + + return ex.result, nil +} + +func (c *sqlmock) exec(query string, args []namedValue) (*ExpectedExec, error) { + var expected *ExpectedExec + var fulfilled int + var ok bool + for _, next := range c.expected { + next.Lock() + if next.fulfilled() { + next.Unlock() + fulfilled++ + continue + } + + if c.ordered { + if expected, ok = next.(*ExpectedExec); ok { + break + } + next.Unlock() + return nil, fmt.Errorf("call to ExecQuery '%s' with args %+v, was not expected, next expectation is: %s", query, args, next) + } + if exec, ok := next.(*ExpectedExec); ok { + if err := c.queryMatcher.Match(exec.expectSQL, query); err != nil { + next.Unlock() + continue + } + + if err := exec.attemptArgMatch(args); err == nil { + expected = exec + break + } + } + next.Unlock() + } + if expected == nil { + msg := "call to ExecQuery '%s' with args %+v was not expected" + if fulfilled == len(c.expected) { + msg = "all expectations were already fulfilled, " + msg + } + return nil, fmt.Errorf(msg, query, args) + } + defer expected.Unlock() + + if err := c.queryMatcher.Match(expected.expectSQL, query); err != nil { + return nil, fmt.Errorf("ExecQuery: %v", err) + } + + if err := expected.argsMatches(args); err != nil { + return nil, fmt.Errorf("ExecQuery '%s', arguments do not match: %s", query, err) + } + + expected.triggered = true + if expected.err != nil { + return expected, expected.err // mocked to return error + } + + if expected.result == nil { + return nil, fmt.Errorf("ExecQuery '%s' with args %+v, must return a database/sql/driver.Result, but it was not set for expectation %T as %+v", query, args, expected, expected) + } + + return expected, nil +} diff --git a/sqlmock_go18.go b/sqlmock_go18.go index 43fbb5d..dc37b18 100644 --- a/sqlmock_go18.go +++ b/sqlmock_go18.go @@ -17,12 +17,7 @@ var ErrCancelled = errors.New("canceling query due to user request") // Implement the "QueryerContext" interface func (c *sqlmock) QueryContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Rows, error) { - namedArgs := make([]namedValue, len(args)) - for i, nv := range args { - namedArgs[i] = namedValue(nv) - } - - ex, err := c.query(query, namedArgs) + ex, err := c.query(query, args) if ex != nil { select { case <-time.After(ex.delay): @@ -40,12 +35,7 @@ func (c *sqlmock) QueryContext(ctx context.Context, query string, args []driver. // Implement the "ExecerContext" interface func (c *sqlmock) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) { - namedArgs := make([]namedValue, len(args)) - for i, nv := range args { - namedArgs[i] = namedValue(nv) - } - - ex, err := c.exec(query, namedArgs) + ex, err := c.exec(query, args) if ex != nil { select { case <-time.After(ex.delay): @@ -170,4 +160,170 @@ func (c *sqlmock) ExpectPing() *ExpectedPing { return e } +// Query meets http://golang.org/pkg/database/sql/driver/#Queryer +// Deprecated: Drivers should implement QueryerContext instead. +func (c *sqlmock) Query(query string, args []driver.Value) (driver.Rows, error) { + namedArgs := make([]driver.NamedValue, len(args)) + for i, v := range args { + namedArgs[i] = driver.NamedValue{ + Ordinal: i + 1, + Value: v, + } + } + + ex, err := c.query(query, namedArgs) + if ex != nil { + time.Sleep(ex.delay) + } + if err != nil { + return nil, err + } + + return ex.rows, nil +} + +func (c *sqlmock) query(query string, args []driver.NamedValue) (*ExpectedQuery, error) { + var expected *ExpectedQuery + var fulfilled int + var ok bool + for _, next := range c.expected { + next.Lock() + if next.fulfilled() { + next.Unlock() + fulfilled++ + continue + } + + if c.ordered { + if expected, ok = next.(*ExpectedQuery); ok { + break + } + next.Unlock() + return nil, fmt.Errorf("call to Query '%s' with args %+v, was not expected, next expectation is: %s", query, args, next) + } + if qr, ok := next.(*ExpectedQuery); ok { + if err := c.queryMatcher.Match(qr.expectSQL, query); err != nil { + next.Unlock() + continue + } + if err := qr.attemptArgMatch(args); err == nil { + expected = qr + break + } + } + next.Unlock() + } + + if expected == nil { + msg := "call to Query '%s' with args %+v was not expected" + if fulfilled == len(c.expected) { + msg = "all expectations were already fulfilled, " + msg + } + return nil, fmt.Errorf(msg, query, args) + } + + defer expected.Unlock() + + if err := c.queryMatcher.Match(expected.expectSQL, query); err != nil { + return nil, fmt.Errorf("Query: %v", err) + } + + if err := expected.argsMatches(args); err != nil { + return nil, fmt.Errorf("Query '%s', arguments do not match: %s", query, err) + } + + expected.triggered = true + if expected.err != nil { + return expected, expected.err // mocked to return error + } + + if expected.rows == nil { + return nil, fmt.Errorf("Query '%s' with args %+v, must return a database/sql/driver.Rows, but it was not set for expectation %T as %+v", query, args, expected, expected) + } + return expected, nil +} + +// Exec meets http://golang.org/pkg/database/sql/driver/#Execer +// Deprecated: Drivers should implement ExecerContext instead. +func (c *sqlmock) Exec(query string, args []driver.Value) (driver.Result, error) { + namedArgs := make([]driver.NamedValue, len(args)) + for i, v := range args { + namedArgs[i] = driver.NamedValue{ + Ordinal: i + 1, + Value: v, + } + } + + ex, err := c.exec(query, namedArgs) + if ex != nil { + time.Sleep(ex.delay) + } + if err != nil { + return nil, err + } + + return ex.result, nil +} + +func (c *sqlmock) exec(query string, args []driver.NamedValue) (*ExpectedExec, error) { + var expected *ExpectedExec + var fulfilled int + var ok bool + for _, next := range c.expected { + next.Lock() + if next.fulfilled() { + next.Unlock() + fulfilled++ + continue + } + + if c.ordered { + if expected, ok = next.(*ExpectedExec); ok { + break + } + next.Unlock() + return nil, fmt.Errorf("call to ExecQuery '%s' with args %+v, was not expected, next expectation is: %s", query, args, next) + } + if exec, ok := next.(*ExpectedExec); ok { + if err := c.queryMatcher.Match(exec.expectSQL, query); err != nil { + next.Unlock() + continue + } + + if err := exec.attemptArgMatch(args); err == nil { + expected = exec + break + } + } + next.Unlock() + } + if expected == nil { + msg := "call to ExecQuery '%s' with args %+v was not expected" + if fulfilled == len(c.expected) { + msg = "all expectations were already fulfilled, " + msg + } + return nil, fmt.Errorf(msg, query, args) + } + defer expected.Unlock() + + if err := c.queryMatcher.Match(expected.expectSQL, query); err != nil { + return nil, fmt.Errorf("ExecQuery: %v", err) + } + + if err := expected.argsMatches(args); err != nil { + return nil, fmt.Errorf("ExecQuery '%s', arguments do not match: %s", query, err) + } + + expected.triggered = true + if expected.err != nil { + return expected, expected.err // mocked to return error + } + + if expected.result == nil { + return nil, fmt.Errorf("ExecQuery '%s' with args %+v, must return a database/sql/driver.Result, but it was not set for expectation %T as %+v", query, args, expected, expected) + } + + return expected, nil +} + // @TODO maybe add ExpectedBegin.WithOptions(driver.TxOptions) diff --git a/sqlmock_go19_test.go b/sqlmock_go19_test.go index 6c69559..910d704 100644 --- a/sqlmock_go19_test.go +++ b/sqlmock_go19_test.go @@ -3,6 +3,8 @@ package sqlmock import ( + "database/sql" + "database/sql/driver" "errors" "testing" ) @@ -37,3 +39,32 @@ func TestStatementTX(t *testing.T) { t.Fatalf("unexpected result: %v", err) } } + +func Test_sqlmock_CheckNamedValue(t *testing.T) { + db, mock, err := New() + if err != nil { + t.Errorf("an error '%s' was not expected when opening a stub database connection", err) + } + defer db.Close() + tests := []struct { + name string + arg *driver.NamedValue + wantErr bool + }{ + { + arg: &driver.NamedValue{Name: "test", Value: "test"}, + wantErr: false, + }, + { + arg: &driver.NamedValue{Name: "test", Value: sql.Out{}}, + wantErr: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if err := mock.(*sqlmock).CheckNamedValue(tt.arg); (err != nil) != tt.wantErr { + t.Errorf("CheckNamedValue() error = %v, wantErr %v", err, tt.wantErr) + } + }) + } +} diff --git a/sqlmock_test.go b/sqlmock_test.go index 522ea42..ee6b516 100644 --- a/sqlmock_test.go +++ b/sqlmock_test.go @@ -2,8 +2,10 @@ package sqlmock import ( "database/sql" + "database/sql/driver" "errors" "fmt" + "reflect" "strconv" "sync" "testing" @@ -1217,3 +1219,124 @@ func queryWithTimeout(t time.Duration, db *sql.DB, query string, args ...interfa return nil, fmt.Errorf("query timed out after %v", t) } } + +func Test_sqlmock_Prepare_and_Exec(t *testing.T) { + db, mock, err := New() + if err != nil { + t.Errorf("an error '%s' was not expected when opening a stub database connection", err) + } + defer db.Close() + query := "SELECT name, email FROM users WHERE name = ?" + + mock.ExpectPrepare("SELECT (.+) FROM users WHERE (.+)") + expected := NewResult(1, 1) + mock.ExpectExec("SELECT (.+) FROM users WHERE (.+)"). + WillReturnResult(expected) + expectedRows := mock.NewRows([]string{"id", "name", "email"}).AddRow(1, "test", "test@example.com") + mock.ExpectQuery("SELECT (.+) FROM users WHERE (.+)").WillReturnRows(expectedRows) + + got, err := mock.(*sqlmock).Prepare(query) + if err != nil { + t.Error(err) + return + } + if got == nil { + t.Error("Prepare () stmt must not be nil") + return + } + result, err := got.Exec([]driver.Value{"test"}) + if err != nil { + t.Error(err) + return + } + if !reflect.DeepEqual(result, expected) { + t.Errorf("Results are not equal. Expected: %v, Actual: %v", expected, result) + return + } + rows, err := got.Query([]driver.Value{"test"}) + if err != nil { + t.Error(err) + return + } + defer rows.Close() +} + +type failArgument struct{} + +func (f failArgument) Match(_ driver.Value) bool { + return false +} + +func Test_sqlmock_Exec(t *testing.T) { + db, mock, err := New() + if err != nil { + t.Errorf("an error '%s' was not expected when opening a stub database connection", err) + } + defer db.Close() + + mock.ExpectBegin() + _, err = mock.(*sqlmock).Exec("", []driver.Value{}) + if err == nil { + t.Errorf("error expected") + return + } + + expected := NewResult(1, 1) + mock.ExpectExec("SELECT (.+) FROM users WHERE (.+)"). + WillReturnResult(expected). + WithArgs("test") + + matchErr := errors.New("matcher sqlmock.failArgument could not match 0 argument driver.NamedValue - {Name: Ordinal:1 Value:{}}") + mock.ExpectExec("SELECT (.+) FROM animals WHERE (.+)"). + WillReturnError(matchErr). + WithArgs(failArgument{}) + + mock.ExpectExec("").WithArgs(failArgument{}) + + mock.(*sqlmock).expected = mock.(*sqlmock).expected[1:] + query := "SELECT name, email FROM users WHERE name = ?" + result, err := mock.(*sqlmock).Exec(query, []driver.Value{"test"}) + if err != nil { + t.Error(err) + return + } + if !reflect.DeepEqual(result, expected) { + t.Errorf("Results are not equal. Expected: %v, Actual: %v", expected, result) + return + } + + failQuery := "SELECT name, sex FROM animals WHERE sex = ?" + _, err = mock.(*sqlmock).Exec(failQuery, []driver.Value{failArgument{}}) + if err == nil { + t.Errorf("error expected") + return + } + mock.(*sqlmock).ordered = false + _, err = mock.(*sqlmock).Exec("", []driver.Value{failArgument{}}) + if err == nil { + t.Errorf("error expected") + return + } +} + +func Test_sqlmock_Query(t *testing.T) { + db, mock, err := New() + if err != nil { + t.Errorf("an error '%s' was not expected when opening a stub database connection", err) + } + defer db.Close() + expectedRows := mock.NewRows([]string{"id", "name", "email"}).AddRow(1, "test", "test@example.com") + mock.ExpectQuery("SELECT (.+) FROM users WHERE (.+)").WillReturnRows(expectedRows) + query := "SELECT name, email FROM users WHERE name = ?" + rows, err := mock.(*sqlmock).Query(query, []driver.Value{"test"}) + if err != nil { + t.Error(err) + return + } + defer rows.Close() + _, err = mock.(*sqlmock).Query(query, []driver.Value{failArgument{}}) + if err == nil { + t.Errorf("error expected") + return + } +} diff --git a/statement.go b/statement.go index 570efd9..852b8f3 100644 --- a/statement.go +++ b/statement.go @@ -1,9 +1,5 @@ package sqlmock -import ( - "database/sql/driver" -) - type statement struct { conn *sqlmock ex *ExpectedPrepare @@ -18,11 +14,3 @@ func (stmt *statement) Close() error { func (stmt *statement) NumInput() int { return -1 } - -func (stmt *statement) Exec(args []driver.Value) (driver.Result, error) { - return stmt.conn.Exec(stmt.query, args) -} - -func (stmt *statement) Query(args []driver.Value) (driver.Rows, error) { - return stmt.conn.Query(stmt.query, args) -} diff --git a/statement_before_go18.go b/statement_before_go18.go new file mode 100644 index 0000000..e2cac2b --- /dev/null +++ b/statement_before_go18.go @@ -0,0 +1,17 @@ +// +build !go1.8 + +package sqlmock + +import ( + "database/sql/driver" +) + +// Deprecated: Drivers should implement ExecerContext instead. +func (stmt *statement) Exec(args []driver.Value) (driver.Result, error) { + return stmt.conn.Exec(stmt.query, args) +} + +// Deprecated: Drivers should implement StmtQueryContext instead (or additionally). +func (stmt *statement) Query(args []driver.Value) (driver.Rows, error) { + return stmt.conn.Query(stmt.query, args) +} diff --git a/statement_go18.go b/statement_go18.go new file mode 100644 index 0000000..e083051 --- /dev/null +++ b/statement_go18.go @@ -0,0 +1,26 @@ +// +build go1.8 + +package sqlmock + +import ( + "context" + "database/sql/driver" +) + +// Deprecated: Drivers should implement ExecerContext instead. +func (stmt *statement) Exec(args []driver.Value) (driver.Result, error) { + return stmt.conn.ExecContext(context.Background(), stmt.query, convertValueToNamedValue(args)) +} + +// Deprecated: Drivers should implement StmtQueryContext instead (or additionally). +func (stmt *statement) Query(args []driver.Value) (driver.Rows, error) { + return stmt.conn.QueryContext(context.Background(), stmt.query, convertValueToNamedValue(args)) +} + +func convertValueToNamedValue(args []driver.Value) []driver.NamedValue { + namedArgs := make([]driver.NamedValue, len(args)) + for i, v := range args { + namedArgs[i] = driver.NamedValue{Ordinal: i + 1, Value: v} + } + return namedArgs +}