diff --git a/netbox/extras/tests/test_views.py b/netbox/extras/tests/test_views.py index 9da6f047a54..91444e2ce94 100644 --- a/netbox/extras/tests/test_views.py +++ b/netbox/extras/tests/test_views.py @@ -1,11 +1,14 @@ from django.contrib.contenttypes.models import ContentType from django.urls import reverse +from django.test import tag +from core.choices import ManagedFileRootPathChoices from core.events import * from core.models import ObjectType from dcim.models import DeviceType, Manufacturer, Site from extras.choices import * from extras.models import * +from extras.scripts import Script as PythonClass, IntegerVar, BooleanVar from users.models import Group, User from utilities.testing import ViewTestCases, TestCase @@ -897,3 +900,70 @@ def test_script_list_embedded_parameter(self): response = self.client.get(url, {'embedded': 'true'}) self.assertEqual(response.status_code, 200) self.assertTemplateUsed(response, 'extras/inc/script_list_content.html') + + +class ScriptValidationErrorTest(TestCase): + user_permissions = ['extras.view_script', 'extras.run_script'] + + class TestScriptMixin: + bar = IntegerVar(min_value=0, max_value=30, default=30) + + class TestScriptClass(TestScriptMixin, PythonClass): + class Meta: + name = 'Test script' + commit_default = False + fieldsets = (("Logging", ("debug_mode",)),) + + debug_mode = BooleanVar(default=False) + + def run(self, data, commit): + return "Complete" + + @classmethod + def setUpTestData(cls): + module = ScriptModule.objects.create(file_root=ManagedFileRootPathChoices.SCRIPTS, file_path='test_script.py') + cls.script = Script.objects.create(module=module, name='Test script', is_executable=True) + + def setUp(self): + super().setUp() + Script.python_class = property(lambda self: ScriptValidationErrorTest.TestScriptClass) + + @tag('regression') + def test_script_validation_error_displays_message(self): + from unittest.mock import patch + + url = reverse('extras:script', kwargs={'pk': self.script.pk}) + + with patch('extras.views.get_workers_for_queue', return_value=['worker']): + response = self.client.post(url, {'debug_mode': 'true', '_commit': 'true'}) + + self.assertEqual(response.status_code, 200) + messages = list(response.context['messages']) + self.assertEqual(len(messages), 1) + self.assertEqual(str(messages[0]), "bar: This field is required.") + + @tag('regression') + def test_script_validation_error_no_toast_for_fieldset_fields(self): + from unittest.mock import patch, PropertyMock + + class FieldsetScript(PythonClass): + class Meta: + name = 'Fieldset test' + commit_default = False + fieldsets = (("Fields", ("required_field",)),) + + required_field = IntegerVar(min_value=10) + + def run(self, data, commit): + return "Complete" + + url = reverse('extras:script', kwargs={'pk': self.script.pk}) + + with patch.object(Script, 'python_class', new_callable=PropertyMock) as mock_python_class: + mock_python_class.return_value = FieldsetScript + with patch('extras.views.get_workers_for_queue', return_value=['worker']): + response = self.client.post(url, {'required_field': '5', '_commit': 'true'}) + + self.assertEqual(response.status_code, 200) + messages = list(response.context['messages']) + self.assertEqual(len(messages), 0) diff --git a/netbox/extras/views.py b/netbox/extras/views.py index 32d19674b86..32f87fb9708 100644 --- a/netbox/extras/views.py +++ b/netbox/extras/views.py @@ -1485,6 +1485,15 @@ def post(self, request, **kwargs): ) return redirect('extras:script_result', job_pk=job.pk) + else: + fieldset_fields = {field for _, fields in script_class.get_fieldsets() for field in fields} + hidden_errors = { + field: errors for field, errors in form.errors.items() + if field not in fieldset_fields + } + if hidden_errors: + error_msg = '; '.join(f"{field}: {', '.join(errors)}" for field, errors in hidden_errors.items()) + messages.error(request, error_msg) return render(request, 'extras/script.html', { 'object': script,