diff --git a/definition.go b/definition.go index 2a1779f5..b387e72e 100644 --- a/definition.go +++ b/definition.go @@ -424,11 +424,13 @@ func (gt *Object) Fields() FieldDefinitionMap { } var configureFields Fields - switch gt.typeConfig.Fields.(type) { + switch fields := gt.typeConfig.Fields.(type) { case Fields: - configureFields = gt.typeConfig.Fields.(Fields) + configureFields = fields case FieldsThunk: - configureFields = gt.typeConfig.Fields.(FieldsThunk)() + configureFields = fields() + case func() Fields: + configureFields = fields() } fields, err := defineFieldMap(gt, configureFields) @@ -444,11 +446,13 @@ func (gt *Object) Interfaces() []*Interface { } var configInterfaces []*Interface - switch gt.typeConfig.Interfaces.(type) { + switch ifaces := gt.typeConfig.Interfaces.(type) { case InterfacesThunk: - configInterfaces = gt.typeConfig.Interfaces.(InterfacesThunk)() + configInterfaces = ifaces() + case func() []*Interface: + configInterfaces = ifaces() case []*Interface: - configInterfaces = gt.typeConfig.Interfaces.([]*Interface) + configInterfaces = ifaces case nil: default: gt.err = fmt.Errorf("Unknown Object.Interfaces type: %T", gt.typeConfig.Interfaces) @@ -754,11 +758,13 @@ func (it *Interface) Fields() (fields FieldDefinitionMap) { } var configureFields Fields - switch it.typeConfig.Fields.(type) { + switch fields := it.typeConfig.Fields.(type) { case Fields: - configureFields = it.typeConfig.Fields.(Fields) + configureFields = fields case FieldsThunk: - configureFields = it.typeConfig.Fields.(FieldsThunk)() + configureFields = fields() + case func() Fields: + configureFields = fields() } fields, err := defineFieldMap(it, configureFields) @@ -1140,11 +1146,13 @@ func NewInputObject(config InputObjectConfig) *InputObject { func (gt *InputObject) defineFieldMap() InputObjectFieldMap { var fieldMap InputObjectConfigFieldMap - switch gt.typeConfig.Fields.(type) { + switch fields := gt.typeConfig.Fields.(type) { case InputObjectConfigFieldMap: - fieldMap = gt.typeConfig.Fields.(InputObjectConfigFieldMap) + fieldMap = fields + case func() InputObjectConfigFieldMap: + fieldMap = fields() case InputObjectConfigFieldMapThunk: - fieldMap = gt.typeConfig.Fields.(InputObjectConfigFieldMapThunk)() + fieldMap = fields() } resultFieldMap := InputObjectFieldMap{} diff --git a/executor.go b/executor.go index a7a7535d..5717f6d0 100644 --- a/executor.go +++ b/executor.go @@ -9,6 +9,7 @@ import ( "github.com/graphql-go/graphql/gqlerrors" "github.com/graphql-go/graphql/language/ast" + "sync" ) type ExecuteParams struct { @@ -111,6 +112,14 @@ type executionContext struct { VariableValues map[string]interface{} Errors []gqlerrors.FormattedError Context context.Context + + errorsMutex sync.Mutex +} + +func (eCtx *executionContext) addError(err gqlerrors.FormattedError) { + eCtx.errorsMutex.Lock() + defer eCtx.errorsMutex.Unlock() + eCtx.Errors = append(eCtx.Errors, err) } func buildExecutionContext(p buildExecutionCtxParams) (*executionContext, error) { @@ -279,13 +288,40 @@ func executeFields(p executeFieldsParams) *Result { p.Fields = map[string][]*ast.Field{} } + var numberOfDeferredFunctions int + recoverChan := make(chan interface{}, len(p.Fields)) + + var resultsMutex sync.Mutex finalResults := map[string]interface{}{} for responseName, fieldASTs := range p.Fields { resolved, state := resolveField(p.ExecutionContext, p.ParentType, p.Source, fieldASTs) if state.hasNoFieldDefs { continue } - finalResults[responseName] = resolved + if state.isDeferred { + numberOfDeferredFunctions += 1 + go func(responseName string) { + defer func() { + recoverChan <- recover() + }() + + res := resolved.(deferredResolveFunction)() + + resultsMutex.Lock() + defer resultsMutex.Unlock() + finalResults[responseName] = res + }(responseName) + } else { + resultsMutex.Lock() + finalResults[responseName] = resolved + resultsMutex.Unlock() + } + } + + for i := 0; i < numberOfDeferredFunctions; i++ { + if r := <-recoverChan; r != nil { + panic(r) + } } return &Result{ @@ -502,8 +538,11 @@ func getFieldEntryKey(node *ast.Field) string { // Internal resolveField state type resolveFieldResultState struct { hasNoFieldDefs bool + isDeferred bool } +type deferredResolveFunction func() interface{} + // Resolves the field on the given source object. In particular, this // figures out the value that the field returns by calling its resolve function, // then calls completeValue to complete promises, serialize scalars, or execute @@ -511,25 +550,27 @@ type resolveFieldResultState struct { func resolveField(eCtx *executionContext, parentType *Object, source interface{}, fieldASTs []*ast.Field) (result interface{}, resultState resolveFieldResultState) { // catch panic from resolveFn var returnType Output + handleRecover := func(r interface{}) { + var err error + if r, ok := r.(string); ok { + err = NewLocatedError( + fmt.Sprintf("%v", r), + FieldASTsToNodeASTs(fieldASTs), + ) + } + if r, ok := r.(error); ok { + err = gqlerrors.FormatError(r) + } + // send panic upstream + if _, ok := returnType.(*NonNull); ok { + panic(gqlerrors.FormatError(err)) + } + eCtx.addError(gqlerrors.FormatError(err)) + } + defer func() (interface{}, resolveFieldResultState) { if r := recover(); r != nil { - - var err error - if r, ok := r.(string); ok { - err = NewLocatedError( - fmt.Sprintf("%v", r), - FieldASTsToNodeASTs(fieldASTs), - ) - } - if r, ok := r.(error); ok { - err = gqlerrors.FormatError(r) - } - // send panic upstream - if _, ok := returnType.(*NonNull); ok { - panic(gqlerrors.FormatError(err)) - } - eCtx.Errors = append(eCtx.Errors, gqlerrors.FormatError(err)) - return result, resultState + handleRecover(r) } return result, resultState }() @@ -581,6 +622,26 @@ func resolveField(eCtx *executionContext, parentType *Object, source interface{} panic(gqlerrors.FormatError(resolveFnError)) } + if deferredResolveFn, ok := result.(func() (interface{}, error)); ok { + resultState.isDeferred = true + return deferredResolveFunction(func() (result interface{}) { + defer func() interface{} { + if r := recover(); r != nil { + handleRecover(r) + } + + return result + }() + + result, resolveFnError = deferredResolveFn() + if resolveFnError != nil { + panic(gqlerrors.FormatError(resolveFnError)) + } + + return completeValueCatchingError(eCtx, returnType, fieldASTs, info, result) + }), resultState + } + completed := completeValueCatchingError(eCtx, returnType, fieldASTs, info, result) return completed, resultState } @@ -594,7 +655,7 @@ func completeValueCatchingError(eCtx *executionContext, returnType Type, fieldAS panic(r) } if err, ok := r.(gqlerrors.FormattedError); ok { - eCtx.Errors = append(eCtx.Errors, err) + eCtx.addError(err) } return completed } diff --git a/executor_resolve_test.go b/executor_resolve_test.go index 7430cd86..59b063fe 100644 --- a/executor_resolve_test.go +++ b/executor_resolve_test.go @@ -114,6 +114,55 @@ func TestExecutesResolveFunction_UsesProvidedResolveFunction(t *testing.T) { } } +func TestExecutesResolveFunction_UsesProvidedResolveFunction_ResolveFunctionIsDeferred(t *testing.T) { + schema := testSchema(t, &graphql.Field{ + Type: graphql.String, + Args: graphql.FieldConfigArgument{ + "aStr": &graphql.ArgumentConfig{Type: graphql.String}, + "aInt": &graphql.ArgumentConfig{Type: graphql.Int}, + }, + Resolve: func(p graphql.ResolveParams) (interface{}, error) { + return func() (interface{}, error) { + b, err := json.Marshal(p.Args) + return string(b), err + }, nil + }, + }) + + expected := map[string]interface{}{ + "test": "{}", + } + result := graphql.Do(graphql.Params{ + Schema: schema, + RequestString: `{ test }`, + }) + if !reflect.DeepEqual(expected, result.Data) { + t.Fatalf("Unexpected result, Diff: %v", testutil.Diff(expected, result.Data)) + } + + expected = map[string]interface{}{ + "test": `{"aStr":"String!"}`, + } + result = graphql.Do(graphql.Params{ + Schema: schema, + RequestString: `{ test(aStr: "String!") }`, + }) + if !reflect.DeepEqual(expected, result.Data) { + t.Fatalf("Unexpected result, Diff: %v", testutil.Diff(expected, result.Data)) + } + + expected = map[string]interface{}{ + "test": `{"aInt":-123,"aStr":"String!"}`, + } + result = graphql.Do(graphql.Params{ + Schema: schema, + RequestString: `{ test(aInt: -123, aStr: "String!") }`, + }) + if !reflect.DeepEqual(expected, result.Data) { + t.Fatalf("Unexpected result, Diff: %v", testutil.Diff(expected, result.Data)) + } +} + func TestExecutesResolveFunction_UsesProvidedResolveFunction_SourceIsStruct_WithoutJSONTags(t *testing.T) { // For structs without JSON tags, it will map to upper-cased exported field names diff --git a/executor_test.go b/executor_test.go index 954d6d30..3ba357bd 100644 --- a/executor_test.go +++ b/executor_test.go @@ -1483,6 +1483,48 @@ func TestQuery_ExecutionDoesNotAddErrorsFromFieldResolveFn(t *testing.T) { } } +func TestQuery_DeferredResolveFn_ExecutionAddsErrorsFromFieldResolveFn(t *testing.T) { + qError := errors.New("queryError") + q := graphql.NewObject(graphql.ObjectConfig{ + Name: "Query", + Fields: graphql.Fields{ + "a": &graphql.Field{ + Type: graphql.String, + Resolve: func(p graphql.ResolveParams) (interface{}, error) { + return func() (interface{}, error) { + return nil, qError + }, nil + }, + }, + "b": &graphql.Field{ + Type: graphql.String, + Resolve: func(p graphql.ResolveParams) (interface{}, error) { + return func() (interface{}, error) { + return "ok", nil + }, nil + }, + }, + }, + }) + blogSchema, err := graphql.NewSchema(graphql.SchemaConfig{ + Query: q, + }) + if err != nil { + t.Fatalf("unexpected error, got: %v", err) + } + query := "{ a }" + result := graphql.Do(graphql.Params{ + Schema: blogSchema, + RequestString: query, + }) + if len(result.Errors) == 0 { + t.Fatal("wrong result, expected errors, got no errors") + } + if result.Errors[0].Error() != qError.Error() { + t.Fatalf("wrong result, unexpected error, got: %v, expected: %v", result.Errors[0], qError) + } +} + func TestQuery_InputObjectUsesFieldDefaultValueFn(t *testing.T) { inputType := graphql.NewInputObject(graphql.InputObjectConfig{ Name: "Input",