Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 32 additions & 0 deletions netbox/extras/models/customfields.py
Original file line number Diff line number Diff line change
Expand Up @@ -785,6 +785,12 @@ class Meta:
def __str__(self):
return self.name

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)

# Cache the initial set of choices for comparison under clean()
self._original_extra_choices = self.__dict__.get('extra_choices')

def get_absolute_url(self):
return reverse('extras:customfieldchoiceset', args=[self.pk])

Expand Down Expand Up @@ -818,6 +824,32 @@ def clean(self):
if not self.base_choices and not self.extra_choices:
raise ValidationError(_("Must define base or extra choices."))

# Check whether any choices have been removed. If so, check whether any of the removed
# choices are still set in custom field data for any object.
original_choices = set([
c[0] for c in self._original_extra_choices
]) if self._original_extra_choices else set()
current_choices = set([
c[0] for c in self.extra_choices
]) if self.extra_choices else set()
if removed_choices := original_choices - current_choices:
for custom_field in self.choices_for.all():
for object_type in custom_field.object_types.all():
model = object_type.model_class()
for choice in removed_choices:
# Form the query based on the type of custom field
if custom_field.type == CustomFieldTypeChoices.TYPE_MULTISELECT:
query_args = {f"custom_field_data__{custom_field.name}__contains": choice}
else:
query_args = {f"custom_field_data__{custom_field.name}": choice}
# Raise a ValidationError if there are any objects which still reference the removed choice
if model.objects.filter(models.Q(**query_args)).exists():
raise ValidationError(
_(
"Cannot remove choice {choice} as there are {model} objects which reference it."
).format(choice=choice, model=object_type)
)

def save(self, *args, **kwargs):

# Sort choices if alphabetical ordering is enforced
Expand Down
68 changes: 68 additions & 0 deletions netbox/extras/tests/test_customfields.py
Original file line number Diff line number Diff line change
Expand Up @@ -343,6 +343,74 @@ def test_multiselect_field(self):
instance.refresh_from_db()
self.assertIsNone(instance.custom_field_data.get(cf.name))

def test_remove_selected_choice(self):
"""
Removing a ChoiceSet choice that is referenced by an object should raise
a ValidationError exception.
"""
CHOICES = (
('a', 'Option A'),
('b', 'Option B'),
('c', 'Option C'),
('d', 'Option D'),
)

# Create a set of custom field choices
choice_set = CustomFieldChoiceSet.objects.create(
name='Custom Field Choice Set 1',
extra_choices=CHOICES
)

# Create a select custom field
cf = CustomField.objects.create(
name='select_field',
type=CustomFieldTypeChoices.TYPE_SELECT,
required=False,
choice_set=choice_set
)
cf.object_types.set([self.object_type])

# Create a multi-select custom field
cf_multiselect = CustomField.objects.create(
name='multiselect_field',
type=CustomFieldTypeChoices.TYPE_MULTISELECT,
required=False,
choice_set=choice_set
)
cf_multiselect.object_types.set([self.object_type])

# Assign a choice for both custom fields on an object
instance = Site.objects.first()
instance.custom_field_data[cf.name] = 'a'
instance.custom_field_data[cf_multiselect.name] = ['b', 'c']
instance.save()

# Attempting to delete a selected choice should fail
with self.assertRaises(ValidationError):
choice_set.extra_choices = (
('b', 'Option B'),
('c', 'Option C'),
('d', 'Option D'),
)
choice_set.full_clean()

# Attempting to delete either of the multi-select choices should fail
with self.assertRaises(ValidationError):
choice_set.extra_choices = (
('a', 'Option A'),
('b', 'Option B'),
('d', 'Option D'),
)
choice_set.full_clean()

# Removing a non-selected choice should succeed
choice_set.extra_choices = (
('a', 'Option A'),
('b', 'Option B'),
('c', 'Option C'),
)
choice_set.full_clean()

def test_object_field(self):
value = VLAN.objects.create(name='VLAN 1', vid=1).pk

Expand Down