Skip to content

GODRIVER-2603 (Contd.) Revised error handling using Go 1.13 error APIs #1476

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Nov 30, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion bson/bsoncodec/default_value_decoders.go
Original file line number Diff line number Diff line change
Expand Up @@ -1787,7 +1787,7 @@ func (DefaultValueDecoders) decodeElemsFromDocumentReader(dc DecodeContext, dr b
elems := make([]reflect.Value, 0)
for {
key, vr, err := dr.ReadElement()
if err == bsonrw.ErrEOD {
if errors.Is(err, bsonrw.ErrEOD) {
break
}
if err != nil {
Expand Down
6 changes: 3 additions & 3 deletions bson/bsoncodec/default_value_encoders.go
Original file line number Diff line number Diff line change
Expand Up @@ -352,7 +352,7 @@ func (dve DefaultValueEncoders) mapEncodeValue(ec EncodeContext, dw bsonrw.Docum
return err
}

if lookupErr == errInvalidValue {
if errors.Is(lookupErr, errInvalidValue) {
err = vw.WriteNull()
if err != nil {
return err
Expand Down Expand Up @@ -427,7 +427,7 @@ func (dve DefaultValueEncoders) ArrayEncodeValue(ec EncodeContext, vw bsonrw.Val
return err
}

if lookupErr == errInvalidValue {
if errors.Is(lookupErr, errInvalidValue) {
err = vw.WriteNull()
if err != nil {
return err
Expand Down Expand Up @@ -496,7 +496,7 @@ func (dve DefaultValueEncoders) SliceEncodeValue(ec EncodeContext, vw bsonrw.Val
return err
}

if lookupErr == errInvalidValue {
if errors.Is(lookupErr, errInvalidValue) {
err = vw.WriteNull()
if err != nil {
return err
Expand Down
5 changes: 3 additions & 2 deletions bson/bsoncodec/map_codec.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ package bsoncodec

import (
"encoding"
"errors"
"fmt"
"reflect"
"strconv"
Expand Down Expand Up @@ -137,7 +138,7 @@ func (mc *MapCodec) mapEncodeValue(ec EncodeContext, dw bsonrw.DocumentWriter, v
return err
}

if lookupErr == errInvalidValue {
if errors.Is(lookupErr, errInvalidValue) {
err = vw.WriteNull()
if err != nil {
return err
Expand Down Expand Up @@ -200,7 +201,7 @@ func (mc *MapCodec) DecodeValue(dc DecodeContext, vr bsonrw.ValueReader, val ref

for {
key, vr, err := dr.ReadElement()
if err == bsonrw.ErrEOD {
if errors.Is(err, bsonrw.ErrEOD) {
break
}
if err != nil {
Expand Down
7 changes: 5 additions & 2 deletions bson/bsoncodec/registry_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
package bsoncodec

import (
"errors"
"reflect"
"testing"

Expand Down Expand Up @@ -351,7 +352,8 @@ func TestRegistryBuilder(t *testing.T) {
})
t.Run("Decoder", func(t *testing.T) {
wanterr := tc.wanterr
if ene, ok := tc.wanterr.(ErrNoEncoder); ok {
var ene ErrNoEncoder
if errors.As(tc.wanterr, &ene) {
wanterr = ErrNoDecoder(ene)
}

Expand Down Expand Up @@ -775,7 +777,8 @@ func TestRegistry(t *testing.T) {
t.Parallel()

wanterr := tc.wanterr
if ene, ok := tc.wanterr.(ErrNoEncoder); ok {
var ene ErrNoEncoder
if errors.As(tc.wanterr, &ene) {
wanterr = ErrNoDecoder(ene)
}

Expand Down
5 changes: 3 additions & 2 deletions bson/bsonrw/copier.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
package bsonrw

import (
"errors"
"fmt"
"io"

Expand Down Expand Up @@ -442,7 +443,7 @@ func (c Copier) copyArray(dst ValueWriter, src ValueReader) error {

for {
vr, err := ar.ReadValue()
if err == ErrEOA {
if errors.Is(err, ErrEOA) {
break
}
if err != nil {
Expand All @@ -466,7 +467,7 @@ func (c Copier) copyArray(dst ValueWriter, src ValueReader) error {
func (c Copier) copyDocumentCore(dw DocumentWriter, dr DocumentReader) error {
for {
key, vr, err := dr.ReadElement()
if err == ErrEOD {
if errors.Is(err, ErrEOD) {
break
}
if err != nil {
Expand Down
2 changes: 1 addition & 1 deletion bson/bsonrw/extjson_parser.go
Original file line number Diff line number Diff line change
Expand Up @@ -313,7 +313,7 @@ func (ejp *extJSONParser) readValue(t bsontype.Type) (*extJSONValue, error) {
// convert hex to bytes
bytes, err := hex.DecodeString(uuidNoHyphens)
if err != nil {
return nil, fmt.Errorf("$uuid value does not follow RFC 4122 format regarding hex bytes: %v", err)
return nil, fmt.Errorf("$uuid value does not follow RFC 4122 format regarding hex bytes: %w", err)
}

ejp.advanceState()
Expand Down
3 changes: 2 additions & 1 deletion bson/bsonrw/extjson_parser_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
package bsonrw

import (
"errors"
"io"
"strings"
"testing"
Expand Down Expand Up @@ -47,7 +48,7 @@ type readKeyValueTestCase struct {

func expectSpecificError(expected error) expectedErrorFunc {
return func(t *testing.T, err error, desc string) {
if err != expected {
if !errors.Is(err, expected) {
t.Helper()
t.Errorf("%s: Expected %v but got: %v", desc, expected, err)
t.FailNow()
Expand Down
5 changes: 3 additions & 2 deletions bson/bsonrw/extjson_reader.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
package bsonrw

import (
"errors"
"fmt"
"io"
"sync"
Expand Down Expand Up @@ -613,7 +614,7 @@ func (ejvr *extJSONValueReader) ReadElement() (string, ValueReader, error) {
name, t, err := ejvr.p.readKey()

if err != nil {
if err == ErrEOD {
if errors.Is(err, ErrEOD) {
if ejvr.stack[ejvr.frame].mode == mCodeWithScope {
_, err := ejvr.p.peekType()
if err != nil {
Expand All @@ -640,7 +641,7 @@ func (ejvr *extJSONValueReader) ReadValue() (ValueReader, error) {

t, err := ejvr.p.peekType()
if err != nil {
if err == ErrEOA {
if errors.Is(err, ErrEOA) {
ejvr.pop()
}

Expand Down
12 changes: 6 additions & 6 deletions bson/bsonrw/json_scanner.go
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ func (js *jsonScanner) nextToken() (*jsonToken, error) {
c, err = js.readNextByte()
}

if err == io.EOF {
if errors.Is(err, io.EOF) {
return &jsonToken{t: jttEOF}, nil
} else if err != nil {
return nil, err
Expand Down Expand Up @@ -198,7 +198,7 @@ func (js *jsonScanner) scanString() (*jsonToken, error) {
for {
c, err = js.readNextByte()
if err != nil {
if err == io.EOF {
if errors.Is(err, io.EOF) {
return nil, errors.New("end of input in JSON string")
}
return nil, err
Expand All @@ -209,7 +209,7 @@ func (js *jsonScanner) scanString() (*jsonToken, error) {
case '\\':
c, err = js.readNextByte()
if err != nil {
if err == io.EOF {
if errors.Is(err, io.EOF) {
return nil, errors.New("end of input in JSON string")
}
return nil, err
Expand Down Expand Up @@ -248,7 +248,7 @@ func (js *jsonScanner) scanString() (*jsonToken, error) {
if utf16.IsSurrogate(rn) {
c, err = js.readNextByte()
if err != nil {
if err == io.EOF {
if errors.Is(err, io.EOF) {
return nil, errors.New("end of input in JSON string")
}
return nil, err
Expand All @@ -264,7 +264,7 @@ func (js *jsonScanner) scanString() (*jsonToken, error) {

c, err = js.readNextByte()
if err != nil {
if err == io.EOF {
if errors.Is(err, io.EOF) {
return nil, errors.New("end of input in JSON string")
}
return nil, err
Expand Down Expand Up @@ -384,7 +384,7 @@ func (js *jsonScanner) scanNumber(first byte) (*jsonToken, error) {
for {
c, err = js.readNextByte()

if err != nil && err != io.EOF {
if err != nil && !errors.Is(err, io.EOF) {
return nil, err
}

Expand Down
7 changes: 4 additions & 3 deletions bson/bsonrw/value_reader_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ package bsonrw

import (
"bytes"
"errors"
"fmt"
"io"
"math"
Expand Down Expand Up @@ -185,7 +186,7 @@ func TestValueReader(t *testing.T) {
// invalid length
vr.d = []byte{0x00, 0x00}
_, err := vr.ReadDocument()
if err != io.EOF {
if !errors.Is(err, io.EOF) {
t.Errorf("Expected io.EOF with document length too small. got %v; want %v", err, io.EOF)
}

Expand Down Expand Up @@ -239,7 +240,7 @@ func TestValueReader(t *testing.T) {

vr.frame--
_, err = vr.ReadDocument()
if err != io.EOF {
if !errors.Is(err, io.EOF) {
t.Errorf("Should return error when attempting to read length with not enough bytes. got %v; want %v", err, io.EOF)
}
})
Expand Down Expand Up @@ -1482,7 +1483,7 @@ func TestValueReader(t *testing.T) {
frame: 0,
}
gotType, got, gotErr := vr.ReadValueBytes(nil)
if gotErr != tc.wantErr {
if !errors.Is(gotErr, tc.wantErr) {
t.Errorf("Did not receive expected error. got %v; want %v", gotErr, tc.wantErr)
}
if tc.wantErr == nil && gotType != tc.wantType {
Expand Down
2 changes: 1 addition & 1 deletion bson/decoder_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -279,7 +279,7 @@ func TestDecoderv2(t *testing.T) {

var got *D
err = dec.Decode(got)
if err != ErrDecodeToNil {
if !errors.Is(err, ErrDecodeToNil) {
t.Fatalf("Decode error mismatch; expected %v, got %v", ErrDecodeToNil, err)
}
})
Expand Down
4 changes: 2 additions & 2 deletions bson/primitive_codecs_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ import (
func bytesFromDoc(doc interface{}) []byte {
b, err := Marshal(doc)
if err != nil {
panic(fmt.Errorf("Couldn't marshal BSON document: %v", err))
panic(fmt.Errorf("Couldn't marshal BSON document: %w", err))
}
return b
}
Expand Down Expand Up @@ -471,7 +471,7 @@ func TestDefaultValueEncoders(t *testing.T) {
enc, err := NewEncoder(vw)
noerr(t, err)
err = enc.Encode(tc.value)
if err != tc.err {
if !errors.Is(err, tc.err) {
t.Errorf("Did not receive expected error. got %v; want %v", err, tc.err)
}
if diff := cmp.Diff([]byte(b), tc.b); diff != "" {
Expand Down
7 changes: 4 additions & 3 deletions bson/raw_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ package bson
import (
"bytes"
"encoding/binary"
"errors"
"fmt"
"io"
"strings"
Expand Down Expand Up @@ -52,7 +53,7 @@ func TestRaw(t *testing.T) {
r := make(Raw, 5)
binary.LittleEndian.PutUint32(r[0:4], 200)
got := r.Validate()
if got != want {
if !errors.Is(got, want) {
t.Errorf("Did not get expected error. got %v; want %v", got, want)
}
})
Expand All @@ -62,7 +63,7 @@ func TestRaw(t *testing.T) {
binary.LittleEndian.PutUint32(r[0:4], 8)
r[4], r[5], r[6], r[7] = '\x02', 'f', 'o', 'o'
got := r.Validate()
if got != want {
if !errors.Is(got, want) {
t.Errorf("Did not get expected error. got %v; want %v", got, want)
}
})
Expand All @@ -72,7 +73,7 @@ func TestRaw(t *testing.T) {
binary.LittleEndian.PutUint32(r[0:4], 9)
r[4], r[5], r[6], r[7], r[8] = '\x0A', 'f', 'o', 'o', '\x00'
got := r.Validate()
if got != want {
if !errors.Is(got, want) {
t.Errorf("Did not get expected error. got %v; want %v", got, want)
}
})
Expand Down
7 changes: 5 additions & 2 deletions examples/documentation_examples/examples.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ package documentation_examples

import (
"context"
"errors"
"fmt"
"io/ioutil"
logger "log"
Expand Down Expand Up @@ -1816,7 +1817,8 @@ func RunTransactionWithRetry(sctx mongo.SessionContext, txnFn func(mongo.Session
log.Println("Transaction aborted. Caught exception during transaction.")

// If transient error, retry the whole transaction
if cmdErr, ok := err.(mongo.CommandError); ok && cmdErr.HasErrorLabel("TransientTransactionError") {
var cmdErr mongo.CommandError
if errors.As(err, &cmdErr) && cmdErr.HasErrorLabel("TransientTransactionError") {
log.Println("TransientTransactionError, retrying transaction...")
continue
}
Expand Down Expand Up @@ -1883,7 +1885,8 @@ func TransactionsExamples(ctx context.Context, client *mongo.Client) error {
log.Println("Transaction aborted. Caught exception during transaction.")

// If transient error, retry the whole transaction
if cmdErr, ok := err.(mongo.CommandError); ok && cmdErr.HasErrorLabel("TransientTransactionError") {
var cmdErr mongo.CommandError
if errors.As(err, &cmdErr) && cmdErr.HasErrorLabel("TransientTransactionError") {
log.Println("TransientTransactionError, retrying transaction...")
continue
}
Expand Down
2 changes: 1 addition & 1 deletion internal/logger/logger.go
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,7 @@ func selectLogSink(sink LogSink) (LogSink, *os.File, error) {
if path != "" {
logFile, err := os.OpenFile(path, os.O_APPEND|os.O_CREATE|os.O_RDWR, 0666)
if err != nil {
return nil, nil, fmt.Errorf("unable to open log file: %v", err)
return nil, nil, fmt.Errorf("unable to open log file: %w", err)
}

return NewIOSink(logFile), logFile, nil
Expand Down
13 changes: 7 additions & 6 deletions mongo/bulk_write.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ package mongo

import (
"context"
"errors"

"go.mongodb.org/mongo-driver/bson/bsoncodec"
"go.mongodb.org/mongo-driver/bson/primitive"
Expand Down Expand Up @@ -108,8 +109,8 @@ func (bw *bulkWrite) runBatch(ctx context.Context, batch bulkWriteBatch) (BulkWr
case *InsertOneModel:
res, err := bw.runInsert(ctx, batch)
if err != nil {
writeErr, ok := err.(driver.WriteCommandError)
if !ok {
var writeErr driver.WriteCommandError
if !errors.As(err, &writeErr) {
return BulkWriteResult{}, batchErr, err
}
writeErrors = writeErr.WriteErrors
Expand All @@ -120,8 +121,8 @@ func (bw *bulkWrite) runBatch(ctx context.Context, batch bulkWriteBatch) (BulkWr
case *DeleteOneModel, *DeleteManyModel:
res, err := bw.runDelete(ctx, batch)
if err != nil {
writeErr, ok := err.(driver.WriteCommandError)
if !ok {
var writeErr driver.WriteCommandError
if !errors.As(err, &writeErr) {
return BulkWriteResult{}, batchErr, err
}
writeErrors = writeErr.WriteErrors
Expand All @@ -132,8 +133,8 @@ func (bw *bulkWrite) runBatch(ctx context.Context, batch bulkWriteBatch) (BulkWr
case *ReplaceOneModel, *UpdateOneModel, *UpdateManyModel:
res, err := bw.runUpdate(ctx, batch)
if err != nil {
writeErr, ok := err.(driver.WriteCommandError)
if !ok {
var writeErr driver.WriteCommandError
if !errors.As(err, &writeErr) {
return BulkWriteResult{}, batchErr, err
}
writeErrors = writeErr.WriteErrors
Expand Down
Loading