diff --git a/any_tests/jsoniter_any_object_test.go b/any_tests/jsoniter_any_object_test.go index 5af292da..1a0d55b4 100644 --- a/any_tests/jsoniter_any_object_test.go +++ b/any_tests/jsoniter_any_object_test.go @@ -3,7 +3,7 @@ package any_tests import ( "testing" - "github.com/json-iterator/go" + jsoniter "github.com/json-iterator/go" "github.com/stretchr/testify/require" ) diff --git a/api_tests/marshal_json_test.go b/api_tests/marshal_json_test.go index 635a24ee..1a57effa 100644 --- a/api_tests/marshal_json_test.go +++ b/api_tests/marshal_json_test.go @@ -3,12 +3,13 @@ package test import ( "bytes" "encoding/json" - "github.com/json-iterator/go" + "errors" "testing" + + jsoniter "github.com/json-iterator/go" "github.com/stretchr/testify/require" ) - type Foo struct { Bar interface{} } @@ -19,11 +20,10 @@ func (f Foo) MarshalJSON() ([]byte, error) { return buf.Bytes(), err } - // Standard Encoder has trailing newline. func TestEncodeMarshalJSON(t *testing.T) { - foo := Foo { + foo := Foo{ Bar: 123, } should := require.New(t) @@ -34,3 +34,21 @@ func TestEncodeMarshalJSON(t *testing.T) { stdenc.Encode(foo) should.Equal(stdbuf.Bytes(), buf.Bytes()) } + +func TestMarshalObjectWithCycle(t *testing.T) { + type A struct { + A *A + } + a := A{} + a.A = &a + + api := jsoniter.ConfigCompatibleWithStandardLibrary + + if _, err := jsoniter.Marshal(a); !errors.Is(err, jsoniter.ErrCycleEncountered) { + t.Fatal(err) + } + + if err := api.NewEncoder(nil).Encode(a); !errors.Is(err, jsoniter.ErrCycleEncountered) { + t.Fatal(err) + } +} diff --git a/config.go b/config.go index 2adcdc3b..208b7c59 100644 --- a/config.go +++ b/config.go @@ -284,26 +284,35 @@ func (cfg *frozenConfig) cleanEncoders() { } func (cfg *frozenConfig) MarshalToString(v interface{}) (string, error) { - stream := cfg.BorrowStream(nil) - defer cfg.ReturnStream(stream) - stream.WriteVal(v) - if stream.Error != nil { - return "", stream.Error + result, err := cfg.marshalToStream(v) + if err != nil { + return "", err } - return string(stream.Buffer()), nil + + return string(result), nil } func (cfg *frozenConfig) Marshal(v interface{}) ([]byte, error) { + result, err := cfg.marshalToStream(v) + if err != nil { + return nil, err + } + + copied := make([]byte, len(result)) + copy(copied, result) + return copied, nil +} + +// marshalToStream writes v to a borrowed stream and returns stream.Buffer() with error. +func (cfg *frozenConfig) marshalToStream(v interface{}) ([]byte, error) { stream := cfg.BorrowStream(nil) defer cfg.ReturnStream(stream) + stream.WriteVal(v) if stream.Error != nil { return nil, stream.Error } - result := stream.Buffer() - copied := make([]byte, len(result)) - copy(copied, result) - return copied, nil + return stream.Buffer(), nil } func (cfg *frozenConfig) MarshalIndent(v interface{}, prefix, indent string) ([]byte, error) { diff --git a/reflect.go b/reflect.go index 39acb320..942c84f9 100644 --- a/reflect.go +++ b/reflect.go @@ -1,6 +1,7 @@ package jsoniter import ( + "errors" "fmt" "reflect" "unsafe" @@ -83,12 +84,51 @@ func (iter *Iterator) ReadVal(obj interface{}) { } } +func hasCycle(v interface{}) bool { + visited := make(map[uintptr]bool) + queue := []reflect.Value{reflect.ValueOf(v)} + + for len(queue) > 0 { + val := queue[0] + queue = queue[1:] + + for val.Kind() == reflect.Ptr || val.Kind() == reflect.Interface { + if val.IsNil() { + break + } + if val.Kind() == reflect.Ptr { + ptr := val.Pointer() + if visited[ptr] { + return true + } + visited[ptr] = true + } + val = val.Elem() + } + + if val.Kind() == reflect.Struct { + for i := 0; i < val.NumField(); i++ { + queue = append(queue, val.Field(i)) + } + } + } + return false +} + +var ErrCycleEncountered = errors.New("jsoniter: unsupported type: encountered a cycle") + // WriteVal copy the go interface into underlying JSON, same as json.Marshal func (stream *Stream) WriteVal(val interface{}) { - if nil == val { + if val == nil { stream.WriteNil() return } + + if hasCycle(val) { + stream.Error = ErrCycleEncountered + return + } + cacheKey := reflect2.RTypeOf(val) encoder := stream.cfg.getEncoderFromCache(cacheKey) if encoder == nil {