Skip to content

Commit b613fef

Browse files
committed
Simple implementation of DSL inline fragments
1 parent ca4021d commit b613fef

File tree

2 files changed

+48
-8
lines changed

2 files changed

+48
-8
lines changed

gql/dsl.py

Lines changed: 29 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
GraphQLObjectType,
1818
GraphQLSchema,
1919
GraphQLWrappingType,
20+
InlineFragmentNode,
2021
ListTypeNode,
2122
ListValueNode,
2223
NamedTypeNode,
@@ -407,6 +408,10 @@ class DSLField:
407408
method.
408409
"""
409410

411+
_type: Union[GraphQLObjectType, GraphQLInterfaceType]
412+
ast_field: FieldNode
413+
field: GraphQLField
414+
410415
def __init__(
411416
self,
412417
name: str,
@@ -423,11 +428,9 @@ def __init__(
423428
:param graphql_type: the GraphQL type definition from the schema
424429
:param graphql_field: the GraphQL field definition from the schema
425430
"""
426-
self._type: Union[GraphQLObjectType, GraphQLInterfaceType] = graphql_type
427-
self.field: GraphQLField = graphql_field
428-
self.ast_field: FieldNode = FieldNode(
429-
name=NameNode(value=name), arguments=FrozenList()
430-
)
431+
self._type = graphql_type
432+
self.field = graphql_field
433+
self.ast_field = FieldNode(name=NameNode(value=name), arguments=FrozenList())
431434
log.debug(f"Creating {self!r}")
432435

433436
@staticmethod
@@ -585,7 +588,25 @@ def __str__(self) -> str:
585588
return print_ast(self.ast_field)
586589

587590
def __repr__(self) -> str:
588-
return (
589-
f"<{self.__class__.__name__} {self._type.name}"
590-
f"::{self.ast_field.name.value}>"
591+
name = self._type.name
592+
try:
593+
name += f"::{self.ast_field.name.value}"
594+
except AttributeError:
595+
pass
596+
return f"<{self.__class__.__name__} {name}>"
597+
598+
599+
class DSLFragment(DSLField):
600+
def __init__(
601+
self, type_condition: Optional[DSLType] = None,
602+
):
603+
self.ast_field = InlineFragmentNode() # type: ignore
604+
if type_condition:
605+
self.on(type_condition)
606+
607+
def on(self, type_condition: DSLType):
608+
self._type = type_condition._type
609+
self.ast_field.type_condition = NamedTypeNode( # type: ignore
610+
name=NameNode(value=self._type.name)
591611
)
612+
return self

tests/starwars/test_dsl.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515

1616
from gql import Client
1717
from gql.dsl import (
18+
DSLFragment,
1819
DSLMutation,
1920
DSLQuery,
2021
DSLSchema,
@@ -416,6 +417,24 @@ def test_multiple_operations(ds):
416417
)
417418

418419

420+
def test_inline_fragments(ds):
421+
query = """hero(episode: JEDI) {
422+
name
423+
... on Droid {
424+
primaryFunction
425+
}
426+
... on Human {
427+
homePlanet
428+
}
429+
}"""
430+
query_dsl = ds.Query.hero.args(episode=6).select(
431+
ds.Character.name,
432+
DSLFragment().on(ds.Droid).select(ds.Droid.primaryFunction),
433+
DSLFragment().on(ds.Human).select(ds.Human.homePlanet),
434+
)
435+
assert query == str(query_dsl)
436+
437+
419438
def test_dsl_query_all_fields_should_be_instances_of_DSLField():
420439
with pytest.raises(
421440
TypeError, match="fields must be instances of DSLField. Received type:"

0 commit comments

Comments
 (0)