@@ -119,23 +119,29 @@ def get_queryset(self, *args, **kwargs):
119119 included_model = None
120120 levels = included .split ('.' )
121121 level_model = qs .model
122+ # Suppose we can do select_related by default
123+ can_select_related = True
122124 for level in levels :
123125 if not hasattr (level_model , level ):
124126 break
125127 field = getattr (level_model , level )
126128 field_class = field .__class__
127129
128130 is_forward_relation = (
129- issubclass (field_class , ForwardManyToOneDescriptor ) or
130- issubclass (field_class , ManyToManyDescriptor )
131+ issubclass (field_class , (ForwardManyToOneDescriptor , ManyToManyDescriptor ))
131132 )
132133 is_reverse_relation = (
133- issubclass (field_class , ReverseManyToOneDescriptor ) or
134- issubclass (field_class , ReverseOneToOneDescriptor )
134+ issubclass (field_class , (ReverseManyToOneDescriptor , ReverseOneToOneDescriptor ))
135135 )
136136 if not (is_forward_relation or is_reverse_relation ):
137137 break
138138
139+ # Figuring out if relation should be select related rather than prefetch_related
140+ # If at least one relation in the chain is not "selectable" then use "prefetch"
141+ can_select_related &= (
142+ issubclass (field_class , (ForwardManyToOneDescriptor , ReverseOneToOneDescriptor ))
143+ )
144+
139145 if level == levels [- 1 ]:
140146 included_model = field
141147 else :
@@ -151,7 +157,10 @@ def get_queryset(self, *args, **kwargs):
151157 level_model = model_field .model
152158
153159 if included_model is not None :
154- qs = qs .prefetch_related (included .replace ('.' , '__' ))
160+ if can_select_related :
161+ qs = qs .select_related (included .replace ('.' , '__' ))
162+ else :
163+ qs = qs .prefetch_related (included .replace ('.' , '__' ))
155164
156165 return qs
157166
0 commit comments