Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 32 additions & 1 deletion netbox/extras/lookups.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,39 @@
from django.contrib.postgres.fields import ArrayField
from django.contrib.postgres.fields.ranges import RangeField
from django.db.models import CharField, JSONField, Lookup
from django.db.models.fields.json import KeyTextTransform

from .fields import CachedValueField


class RangeContains(Lookup):
"""
Filter ArrayField(RangeField) columns where ANY element-range contains the scalar RHS.

Usage (ORM):
Model.objects.filter(<range_array_field>__range_contains=<scalar>)

Works with int4range[], int8range[], daterange[], tstzrange[], etc.
"""

lookup_name = 'range_contains'

def as_sql(self, compiler, connection):
# Compile LHS (the array-of-ranges column/expression) and RHS (scalar)
lhs, lhs_params = self.process_lhs(compiler, connection)
rhs, rhs_params = self.process_rhs(compiler, connection)

# Guard: only allow ArrayField whose base_field is a PostgreSQL RangeField
field = getattr(self.lhs, 'output_field', None)
if not (isinstance(field, ArrayField) and isinstance(field.base_field, RangeField)):
raise TypeError('range_contains is only valid for ArrayField(RangeField) columns')

# Range-contains-element using EXISTS + UNNEST keeps the range on the LHS: r @> value
sql = f"EXISTS (SELECT 1 FROM unnest({lhs}) AS r WHERE r @> {rhs})"
params = lhs_params + rhs_params
return sql, params


class Empty(Lookup):
"""
Filter on whether a string is empty.
Expand All @@ -25,7 +55,7 @@ class JSONEmpty(Lookup):

A key is considered empty if it is "", null, or does not exist.
"""
lookup_name = "empty"
lookup_name = 'empty'

def as_sql(self, compiler, connection):
# self.lhs.lhs is the parent expression (could be a JSONField or another KeyTransform)
Expand Down Expand Up @@ -69,6 +99,7 @@ def as_sql(self, qn, connection):
return 'CAST(%s AS INET) >>= %s' % (lhs, rhs), params


ArrayField.register_lookup(RangeContains)
CharField.register_lookup(Empty)
JSONField.register_lookup(JSONEmpty)
CachedValueField.register_lookup(NetHost)
Expand Down
18 changes: 2 additions & 16 deletions netbox/ipam/filtersets.py
Original file line number Diff line number Diff line change
Expand Up @@ -908,7 +908,8 @@ class VLANGroupFilterSet(OrganizationalModelFilterSet, TenancyFilterSet):
method='filter_scope'
)
contains_vid = django_filters.NumberFilter(
method='filter_contains_vid'
field_name='vid_ranges',
lookup_expr='range_contains',
)

class Meta:
Expand All @@ -931,21 +932,6 @@ def filter_scope(self, queryset, name, value):
scope_id=value
)

def filter_contains_vid(self, queryset, name, value):
"""
Return all VLANGroups which contain the given VLAN ID.
"""
table_name = VLANGroup._meta.db_table
# TODO: See if this can be optimized without compromising queryset integrity
# Expand VLAN ID ranges to query by integer
groups = VLANGroup.objects.raw(
f'SELECT id FROM {table_name}, unnest(vid_ranges) vid_range WHERE %s <@ vid_range',
params=(value,)
)
return queryset.filter(
pk__in=[g.id for g in groups]
)


class VLANFilterSet(NetBoxModelFilterSet, TenancyFilterSet):
region_id = TreeNodeMultipleChoiceFilter(
Expand Down
4 changes: 2 additions & 2 deletions netbox/ipam/graphql/filters.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from virtualization.models import VMInterface

if TYPE_CHECKING:
from netbox.graphql.filter_lookups import IntegerArrayLookup, IntegerLookup
from netbox.graphql.filter_lookups import IntegerLookup, IntegerRangeArrayLookup
from circuits.graphql.filters import ProviderFilter
from core.graphql.filters import ContentTypeFilter
from dcim.graphql.filters import SiteFilter
Expand Down Expand Up @@ -340,7 +340,7 @@ class VLANFilter(TenancyFilterMixin, PrimaryModelFilterMixin):

@strawberry_django.filter_type(models.VLANGroup, lookups=True)
class VLANGroupFilter(ScopedFilterMixin, OrganizationalModelFilterMixin):
vid_ranges: Annotated['IntegerArrayLookup', strawberry.lazy('netbox.graphql.filter_lookups')] | None = (
vid_ranges: Annotated['IntegerRangeArrayLookup', strawberry.lazy('netbox.graphql.filter_lookups')] | None = (
strawberry_django.filter_field()
)

Expand Down
4 changes: 4 additions & 0 deletions netbox/ipam/tests/test_filtersets.py
Original file line number Diff line number Diff line change
Expand Up @@ -1723,6 +1723,10 @@ def test_contains_vid(self):
self.assertEqual(self.filterset(params, self.queryset).qs.count(), 2)
params = {'contains_vid': 1}
self.assertEqual(self.filterset(params, self.queryset).qs.count(), 8)
params = {'contains_vid': 12} # 11 is NOT in [1,11)
self.assertEqual(self.filterset(params, self.queryset).qs.count(), 1)
params = {'contains_vid': 4095}
self.assertEqual(self.filterset(params, self.queryset).qs.count(), 0)

def test_region(self):
params = {'region': Region.objects.first().pk}
Expand Down
66 changes: 66 additions & 0 deletions netbox/ipam/tests/test_lookups.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
from django.test import TestCase
from django.db.backends.postgresql.psycopg_any import NumericRange
from ipam.models import VLANGroup


class VLANGroupRangeContainsLookupTests(TestCase):
@classmethod
def setUpTestData(cls):
# Two ranges: [1,11) and [20,31)
cls.g1 = VLANGroup.objects.create(
name='VlanGroup-A',
slug='VlanGroup-A',
vid_ranges=[NumericRange(1, 11), NumericRange(20, 31)],
)
# One range: [100,201)
cls.g2 = VLANGroup.objects.create(
name='VlanGroup-B',
slug='VlanGroup-B',
vid_ranges=[NumericRange(100, 201)],
)
cls.g_empty = VLANGroup.objects.create(
name='VlanGroup-empty',
slug='VlanGroup-empty',
vid_ranges=[],
)

def test_contains_value_in_first_range(self):
"""
Tests whether a specific value is contained within the first range in a queried
set of VLANGroup objects.
"""
names = list(
VLANGroup.objects.filter(vid_ranges__range_contains=10).values_list('name', flat=True).order_by('name')
)
self.assertEqual(names, ['VlanGroup-A'])

def test_contains_value_in_second_range(self):
"""
Tests if a value exists in the second range of VLANGroup objects and
validates the result against the expected list of names.
"""
names = list(
VLANGroup.objects.filter(vid_ranges__range_contains=25).values_list('name', flat=True).order_by('name')
)
self.assertEqual(names, ['VlanGroup-A'])

def test_upper_bound_is_exclusive(self):
"""
Tests if the upper bound of the range is exclusive in the filter method.
"""
# 11 is NOT in [1,11)
self.assertFalse(VLANGroup.objects.filter(vid_ranges__range_contains=11).exists())

def test_no_match_far_outside(self):
"""
Tests that no VLANGroup contains a VID within a specified range far outside
common VID bounds and returns `False`.
"""
self.assertFalse(VLANGroup.objects.filter(vid_ranges__range_contains=4095).exists())

def test_empty_array_never_matches(self):
"""
Tests the behavior of VLANGroup objects when an empty array is used to match a
specific condition.
"""
self.assertFalse(VLANGroup.objects.filter(pk=self.g_empty.pk, vid_ranges__range_contains=1).exists())
28 changes: 28 additions & 0 deletions netbox/netbox/graphql/filter_lookups.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
'FloatLookup',
'IntegerArrayLookup',
'IntegerLookup',
'IntegerRangeArrayLookup',
'JSONFilter',
'StringArrayLookup',
'TreeNodeFilter',
Expand Down Expand Up @@ -217,3 +218,30 @@ class FloatArrayLookup(ArrayLookup[float]):
@strawberry.input(one_of=True, description='Lookup for Array fields. Only one of the lookup fields can be set.')
class StringArrayLookup(ArrayLookup[str]):
pass


@strawberry.input(one_of=True, description='Lookups for an ArrayField(RangeField). Only one may be set.')
class RangeArrayValueLookup(Generic[T]):
"""
class for Array field of Range fields lookups
"""

contains: T | None = strawberry.field(
default=strawberry.UNSET, description='Return rows where any stored range contains this value.'
)

@strawberry_django.filter_field
def filter(self, info: Info, queryset: QuerySet, prefix: str = '') -> Tuple[QuerySet, Q]:
"""
Map GraphQL: { <field>: { contains: <T> } } To Django ORM: <field>__range_contains=<T>
"""
if self.contains is strawberry.UNSET or self.contains is None:
return queryset, Q()

# Build '<prefix>range_contains' so it works for nested paths too
return queryset, Q(**{f'{prefix}range_contains': self.contains})


@strawberry.input(one_of=True, description='Lookups for an ArrayField(IntegerRangeField). Only one may be set.')
class IntegerRangeArrayLookup(RangeArrayValueLookup[int]):
pass