diff --git a/netbox/netbox/tests/test_import.py b/netbox/netbox/tests/test_import.py index 6594409f20c..56650ea7fa9 100644 --- a/netbox/netbox/tests/test_import.py +++ b/netbox/netbox/tests/test_import.py @@ -1,3 +1,8 @@ +import csv +import io +import json +import yaml + from django.contrib.contenttypes.models import ContentType from django.test import override_settings @@ -17,8 +22,16 @@ def setUpTestData(cls): def _get_csv_data(self, csv_data): return '\n'.join(csv_data) + def _get_yaml_data(self, csv_data): + data = [*csv.DictReader(io.StringIO(self._get_csv_data(csv_data)))] + return yaml.dump(data) + + def _get_json_data(self, csv_data): + data = [*csv.DictReader(io.StringIO(self._get_csv_data(csv_data)))] + return json.dumps(data) + @override_settings(EXEMPT_VIEW_PERMISSIONS=['*']) - def test_valid_tags(self): + def test_valid_tags_csv(self): csv_data = ( 'name,slug,tags', 'Region 1,region-1,"alpha,bravo"', @@ -62,7 +75,93 @@ def test_valid_tags(self): self.assertEqual(regions[3].tags.count(), 0) @override_settings(EXEMPT_VIEW_PERMISSIONS=['*']) - def test_invalid_tags(self): + def test_valid_tags_yaml(self): + csv_data = ( + 'name,slug,tags', + 'Region 1,region-1,"alpha,bravo"', + 'Region 2,region-2,"charlie,delta"', + 'Region 3,region-3,echo', + 'Region 4,region-4,', + ) + + data = { + 'format': ImportFormatChoices.YAML, + 'data': self._get_yaml_data(csv_data), + } + + # Assign model-level permission + obj_perm = ObjectPermission(name='Test permission', actions=['add']) + obj_perm.save() + obj_perm.users.add(self.user) + obj_perm.object_types.add(ContentType.objects.get_for_model(self.model)) + + # Try GET with model-level permission + self.assertHttpStatus(self.client.get(self._get_url('import')), 200) + + # Test POST with permission + self.assertHttpStatus(self.client.post(self._get_url('import'), data), 302) + regions = Region.objects.all() + self.assertEqual(regions.count(), 4) + region = Region.objects.get(slug="region-4") + self.assertEqual( + list(regions[0].tags.values_list('name', flat=True)), + ['Alpha', 'Bravo'] + ) + self.assertEqual( + list(regions[1].tags.values_list('name', flat=True)), + ['Charlie', 'Delta'] + ) + self.assertEqual( + list(regions[2].tags.values_list('name', flat=True)), + ['Echo'] + ) + self.assertEqual(regions[3].tags.count(), 0) + + @override_settings(EXEMPT_VIEW_PERMISSIONS=['*']) + def test_valid_tags_json(self): + csv_data = ( + 'name,slug,tags', + 'Region 1,region-1,"alpha,bravo"', + 'Region 2,region-2,"charlie,delta"', + 'Region 3,region-3,echo', + 'Region 4,region-4,', + ) + + data = { + 'format': ImportFormatChoices.JSON, + 'data': self._get_json_data(csv_data), + } + + # Assign model-level permission + obj_perm = ObjectPermission(name='Test permission', actions=['add']) + obj_perm.save() + obj_perm.users.add(self.user) + obj_perm.object_types.add(ContentType.objects.get_for_model(self.model)) + + # Try GET with model-level permission + self.assertHttpStatus(self.client.get(self._get_url('import')), 200) + + # Test POST with permission + self.assertHttpStatus(self.client.post(self._get_url('import'), data), 302) + regions = Region.objects.all() + self.assertEqual(regions.count(), 4) + region = Region.objects.get(slug="region-4") + self.assertEqual( + list(regions[0].tags.values_list('name', flat=True)), + ['Alpha', 'Bravo'] + ) + self.assertEqual( + list(regions[1].tags.values_list('name', flat=True)), + ['Charlie', 'Delta'] + ) + self.assertEqual( + list(regions[2].tags.values_list('name', flat=True)), + ['Echo'] + ) + self.assertEqual(regions[3].tags.count(), 0) + + @override_settings(EXEMPT_VIEW_PERMISSIONS=['*']) + def test_invalid_tags_csv(self): csv_data = ( 'name,slug,tags', 'Region 1,region-1,"Alpha,Bravo"', # Valid @@ -86,3 +185,124 @@ def test_invalid_tags(self): # Test POST with permission self.assertHttpStatus(self.client.post(self._get_url('import'), data), 200) self.assertEqual(Region.objects.count(), 0) + + @override_settings(EXEMPT_VIEW_PERMISSIONS=['*']) + def test_invalid_tags_yaml(self): + csv_data = ( + 'name,slug,tags', + 'Region 1,region-1,"Alpha,Bravo"', # Valid + 'Region 2,region-2,"Alpha,Tango"', # Invalid + ) + + data = { + 'format': ImportFormatChoices.YAML, + 'data': self._get_yaml_data(csv_data), + } + + # Assign model-level permission + obj_perm = ObjectPermission(name='Test permission', actions=['add']) + obj_perm.save() + obj_perm.users.add(self.user) + obj_perm.object_types.add(ContentType.objects.get_for_model(self.model)) + + # Try GET with model-level permission + self.assertHttpStatus(self.client.get(self._get_url('import')), 200) + + # Test POST with permission + self.assertHttpStatus(self.client.post(self._get_url('import'), data), 200) + self.assertEqual(Region.objects.count(), 0) + + @override_settings(EXEMPT_VIEW_PERMISSIONS=['*']) + def test_invalid_tags_json(self): + csv_data = ( + 'name,slug,tags', + 'Region 1,region-1,"Alpha,Bravo"', # Valid + 'Region 2,region-2,"Alpha,Tango"', # Invalid + ) + + data = { + 'format': ImportFormatChoices.JSON, + 'data': self._get_json_data(csv_data), + } + + # Assign model-level permission + obj_perm = ObjectPermission(name='Test permission', actions=['add']) + obj_perm.save() + obj_perm.users.add(self.user) + obj_perm.object_types.add(ContentType.objects.get_for_model(self.model)) + + # Try GET with model-level permission + self.assertHttpStatus(self.client.get(self._get_url('import')), 200) + + # Test POST with permission + self.assertHttpStatus(self.client.post(self._get_url('import'), data), 200) + self.assertEqual(Region.objects.count(), 0) + + @override_settings(EXEMPT_VIEW_PERMISSIONS=['*']) + def test_invalid_header_csv(self): + csv_data = ( + 'name,slug,xxx', + 'Region 1,region-1,yyy', + ) + + data = { + 'format': ImportFormatChoices.CSV, + 'data': self._get_csv_data(csv_data), + } + + # Assign model-level permission + obj_perm = ObjectPermission(name='Test permission', actions=['add']) + obj_perm.save() + obj_perm.users.add(self.user) + obj_perm.object_types.add(ContentType.objects.get_for_model(self.model)) + + # Test POST with permission + ret = self.client.post(self._get_url('import'), data) + self.assertHttpStatus(self.client.post(self._get_url('import'), data), 200) + self.assertEqual(Region.objects.count(), 0) + + @override_settings(EXEMPT_VIEW_PERMISSIONS=['*']) + def test_invalid_header_yaml(self): + csv_data = ( + 'name,slug,xxx', + 'Region 1,region-1,yyy', + ) + + data = { + 'format': ImportFormatChoices.YAML, + 'data': self._get_yaml_data(csv_data), + } + + # Assign model-level permission + obj_perm = ObjectPermission(name='Test permission', actions=['add']) + obj_perm.save() + obj_perm.users.add(self.user) + obj_perm.object_types.add(ContentType.objects.get_for_model(self.model)) + + # Test POST with permission + ret = self.client.post(self._get_url('import'), data) + self.assertHttpStatus(self.client.post(self._get_url('import'), data), 200) + self.assertEqual(Region.objects.count(), 0) + + @override_settings(EXEMPT_VIEW_PERMISSIONS=['*']) + def test_invalid_header_json(self): + csv_data = ( + 'name,slug,xxx', + 'Region 1,region-1,yyy', + ) + + data = { + 'format': ImportFormatChoices.JSON, + 'data': self._get_json_data(csv_data), + } + + # Assign model-level permission + obj_perm = ObjectPermission(name='Test permission', actions=['add']) + obj_perm.save() + obj_perm.users.add(self.user) + obj_perm.object_types.add(ContentType.objects.get_for_model(self.model)) + + # Test POST with permission + ret = self.client.post(self._get_url('import'), data) + self.assertHttpStatus(self.client.post(self._get_url('import'), data), 200) + self.assertEqual(Region.objects.count(), 0) diff --git a/netbox/netbox/views/generic/bulk_views.py b/netbox/netbox/views/generic/bulk_views.py index bef524bcedc..e1168e2c859 100644 --- a/netbox/netbox/views/generic/bulk_views.py +++ b/netbox/netbox/views/generic/bulk_views.py @@ -15,12 +15,14 @@ from django.utils.safestring import mark_safe from django_tables2.export import TableExport +from dcim.forms.bulk_import import DeviceTypeImportForm, ModuleTypeImportForm from extras.models import ExportTemplate from extras.signals import clear_webhooks from utilities.error_handlers import handle_protectederror from utilities.exceptions import AbortRequest, AbortTransaction, PermissionsViolation from utilities.forms import BulkRenameForm, ConfirmationForm, restrict_form_fields from utilities.forms.bulk_import import BulkImportForm +from utilities.forms.utils import headers_to_dict, validate_import_headers from utilities.htmx import is_embedded, is_htmx from utilities.permissions import get_permission_for_model from utilities.utils import get_viewname @@ -398,10 +400,26 @@ def create_and_update_objects(self, form, request): 'data': record, 'instance': instance, } + headers = None if hasattr(form, '_csv_headers'): - model_form_kwargs['headers'] = form._csv_headers # Add CSV headers + headers = form._csv_headers + model_form_kwargs['headers'] = headers # Add CSV headers model_form = self.model_form(**model_form_kwargs) + # validate the fields (required fields are present and no unknown fields) + form_fields = model_form.fields + required_fields = [ + name for name, field in form_fields.items() if field.required + ] + + if not headers: + keys = list(record.keys()) + if object_id: + keys.append("id") + headers = headers_to_dict(keys) + + validate_import_headers(headers, form_fields, required_fields, allow_extra_columns=isinstance(model_form, (DeviceTypeImportForm, ModuleTypeImportForm))) + # When updating, omit all form fields other than those specified in the record. (No # fields are required when modifying an existing object.) if object_id: diff --git a/netbox/utilities/forms/forms.py b/netbox/utilities/forms/forms.py index 9f84e100f67..6d593cbf55a 100644 --- a/netbox/utilities/forms/forms.py +++ b/netbox/utilities/forms/forms.py @@ -2,6 +2,7 @@ from django import forms from django.utils.translation import gettext as _ + from .mixins import BootstrapMixin __all__ = ( diff --git a/netbox/utilities/forms/utils.py b/netbox/utilities/forms/utils.py index 4d737f16321..eb880c4fedb 100644 --- a/netbox/utilities/forms/utils.py +++ b/netbox/utilities/forms/utils.py @@ -18,7 +18,7 @@ 'parse_numeric_range', 'restrict_form_fields', 'parse_csv', - 'validate_csv', + 'validate_import_headers', ) @@ -205,29 +205,38 @@ def restrict_form_fields(form, user, action='view'): field.queryset = field.queryset.restrict(user, action) -def parse_csv(reader): +def headers_to_dict(headers): """ - Parse a csv_reader object into a headers dictionary and a list of records dictionaries. Raise an error - if the records are formatted incorrectly. Return headers and records as a tuple. + Create a dictionary mapping each header to an optional "to" field specifying how + the related object is being referenced. For example, importing a Device might use a + `site.slug` header, to indicate the related site is being referenced by its slug. """ - records = [] - headers = {} - - # Consume the first line of CSV data as column headers. Create a dictionary mapping each header to an optional - # "to" field specifying how the related object is being referenced. For example, importing a Device might use a - # `site.slug` header, to indicate the related site is being referenced by its slug. - - for header in next(reader): + header_dict = {} + for header in headers: header = header.strip() if '.' in header: field, to_field = header.split('.', 1) - if field in headers: + if field in header_dict: raise forms.ValidationError(f'Duplicate or conflicting column header for "{field}"') - headers[field] = to_field + header_dict[field] = to_field else: - if header in headers: + if header in header_dict: raise forms.ValidationError(f'Duplicate or conflicting column header for "{header}"') - headers[header] = None + header_dict[header] = None + + return header_dict + + +def parse_csv(reader): + """ + Parse a csv_reader object into a headers dictionary and a list of records dictionaries. Raise an error + if the records are formatted incorrectly. Return headers and records as a tuple. + """ + records = [] + headers = {} + + # Consume the first line of CSV data as column headers. + headers = headers_to_dict(list(next(reader))) # Parse CSV rows into a list of dictionaries mapped from the column headers. for i, row in enumerate(reader, start=1): @@ -242,7 +251,7 @@ def parse_csv(reader): return headers, records -def validate_csv(headers, fields, required_fields): +def validate_import_headers(headers, fields, required_fields, allow_extra_columns=False): """ Validate that parsed csv data conforms to the object's available fields. Raise validation errors if parsed csv data contains invalid headers or does not contain required headers. @@ -253,7 +262,7 @@ def validate_csv(headers, fields, required_fields): if field == "id": is_update = True continue - if field not in fields: + if (not allow_extra_columns) and (field not in fields): raise forms.ValidationError(f'Unexpected column header "{field}" found.') if to_field and not hasattr(fields[field], 'to_field_name'): raise forms.ValidationError(f'Column "{field}" is not a related object; cannot use dots') diff --git a/netbox/utilities/testing/views.py b/netbox/utilities/testing/views.py index 0a84c5d1b18..59e1f35a920 100644 --- a/netbox/utilities/testing/views.py +++ b/netbox/utilities/testing/views.py @@ -1,4 +1,7 @@ import csv +import io +import json +import yaml from django.conf import settings from django.contrib.contenttypes.models import ContentType @@ -574,10 +577,29 @@ class BulkImportObjectsViewTestCase(ModelViewTestCase): def _get_csv_data(self): return '\n'.join(self.csv_data) + def _get_yaml_data(self): + data = [*csv.DictReader(io.StringIO(self._get_csv_data()))] + return yaml.dump(data) + + def _get_json_data(self): + data = [*csv.DictReader(io.StringIO(self._get_csv_data()))] + return json.dumps(data) + def _get_update_csv_data(self): - return self.csv_update_data, '\n'.join(self.csv_update_data) + return '\n'.join(self.csv_update_data) + + def _get_update_yaml_data(self): + data = [*csv.DictReader(io.StringIO(self._get_update_csv_data()))] + return yaml.dump(data) - def test_bulk_import_objects_without_permission(self): + def _get_update_json_data(self): + data = [*csv.DictReader(io.StringIO(self._get_update_csv_data()))] + return json.dumps(data) + + def _get_update_array(self): + return self.csv_update_data + + def test_bulk_import_objects_without_permission_csv(self): data = { 'data': self._get_csv_data(), 'format': ImportFormatChoices.CSV, @@ -593,8 +615,38 @@ def test_bulk_import_objects_without_permission(self): with disable_warnings('django.request'): self.assertHttpStatus(response, 403) + def test_bulk_import_objects_without_permission_yaml(self): + data = { + 'data': self._get_yaml_data(), + 'format': 'yaml', + } + + # Test GET without permission + with disable_warnings('django.request'): + self.assertHttpStatus(self.client.get(self._get_url('import')), 403) + + # Try POST without permission + response = self.client.post(self._get_url('import'), data) + with disable_warnings('django.request'): + self.assertHttpStatus(response, 403) + + def test_bulk_import_objects_without_permission_json(self): + data = { + 'data': self._get_json_data(), + 'format': 'json', + } + + # Test GET without permission + with disable_warnings('django.request'): + self.assertHttpStatus(self.client.get(self._get_url('import')), 403) + + # Try POST without permission + response = self.client.post(self._get_url('import'), data) + with disable_warnings('django.request'): + self.assertHttpStatus(response, 403) + @override_settings(EXEMPT_VIEW_PERMISSIONS=['*']) - def test_bulk_import_objects_with_permission(self): + def test_bulk_import_objects_with_permission_csv(self): initial_count = self._get_queryset().count() data = { 'data': self._get_csv_data(), @@ -619,15 +671,64 @@ def test_bulk_import_objects_with_permission(self): self.assertEqual(self._get_queryset().count(), initial_count + len(self.csv_data) - 1) @override_settings(EXEMPT_VIEW_PERMISSIONS=['*']) - def test_bulk_update_objects_with_permission(self): + def test_bulk_import_objects_with_permission_yaml(self): + initial_count = self._get_queryset().count() + data = { + 'data': self._get_yaml_data(), + 'format': 'yaml', + } + + # Assign model-level permission + obj_perm = ObjectPermission( + name='Test permission', + actions=['add'] + ) + obj_perm.save() + obj_perm.users.add(self.user) + obj_perm.object_types.add(ContentType.objects.get_for_model(self.model)) + + # Try GET with model-level permission + self.assertHttpStatus(self.client.get(self._get_url('import')), 200) + + # Test POST with permission + self.assertHttpStatus(self.client.post(self._get_url('import'), data), 302) + self.assertEqual(self._get_queryset().count(), initial_count + len(self.csv_data) - 1) + + @override_settings(EXEMPT_VIEW_PERMISSIONS=['*']) + def test_bulk_import_objects_with_permission_json(self): + initial_count = self._get_queryset().count() + data = { + 'data': self._get_json_data(), + 'format': 'json', + } + + # Assign model-level permission + obj_perm = ObjectPermission( + name='Test permission', + actions=['add'] + ) + obj_perm.save() + obj_perm.users.add(self.user) + obj_perm.object_types.add(ContentType.objects.get_for_model(self.model)) + + # Try GET with model-level permission + self.assertHttpStatus(self.client.get(self._get_url('import')), 200) + + # Test POST with permission + self.assertHttpStatus(self.client.post(self._get_url('import'), data), 302) + self.assertEqual(self._get_queryset().count(), initial_count + len(self.csv_data) - 1) + + @override_settings(EXEMPT_VIEW_PERMISSIONS=['*']) + def test_bulk_update_objects_with_permission_csv(self): if not hasattr(self, 'csv_update_data'): raise NotImplementedError("The test must define csv_update_data.") initial_count = self._get_queryset().count() - array, csv_data = self._get_update_csv_data() + array = self._get_update_array() + update_data = self._get_update_csv_data() data = { 'format': ImportFormatChoices.CSV, - 'data': csv_data, + 'data': update_data, 'csv_delimiter': CSVDelimiterChoices.AUTO, } @@ -657,7 +758,83 @@ def test_bulk_update_objects_with_permission(self): self.assertEqual(value, value) @override_settings(EXEMPT_VIEW_PERMISSIONS=['*']) - def test_bulk_import_objects_with_constrained_permission(self): + def test_bulk_update_objects_with_permission_yaml(self): + if not hasattr(self, 'csv_update_data'): + raise NotImplementedError("The test must define csv_update_data.") + + initial_count = self._get_queryset().count() + array = self._get_update_array() + update_data = self._get_update_yaml_data() + data = { + 'format': ImportFormatChoices.YAML, + 'data': update_data, + } + + # Assign model-level permission + obj_perm = ObjectPermission( + name='Test permission', + actions=['add'] + ) + obj_perm.save() + obj_perm.users.add(self.user) + obj_perm.object_types.add(ContentType.objects.get_for_model(self.model)) + + # Test POST with permission + self.assertHttpStatus(self.client.post(self._get_url('import'), data), 302) + self.assertEqual(initial_count, self._get_queryset().count()) + + reader = csv.DictReader(array, delimiter=',') + check_data = list(reader) + for line in check_data: + obj = self.model.objects.get(id=line["id"]) + for attr, value in line.items(): + if attr != "id": + field = self.model._meta.get_field(attr) + value = getattr(obj, attr) + # cannot verify FK fields as don't know what name the CSV maps to + if value is not None and not isinstance(field, ForeignKey): + self.assertEqual(value, value) + + @override_settings(EXEMPT_VIEW_PERMISSIONS=['*']) + def test_bulk_update_objects_with_permission_json(self): + if not hasattr(self, 'csv_update_data'): + raise NotImplementedError("The test must define csv_update_data.") + + initial_count = self._get_queryset().count() + array = self._get_update_array() + update_data = self._get_update_json_data() + data = { + 'format': ImportFormatChoices.JSON, + 'data': update_data, + } + + # Assign model-level permission + obj_perm = ObjectPermission( + name='Test permission', + actions=['add'] + ) + obj_perm.save() + obj_perm.users.add(self.user) + obj_perm.object_types.add(ContentType.objects.get_for_model(self.model)) + + # Test POST with permission + self.assertHttpStatus(self.client.post(self._get_url('import'), data), 302) + self.assertEqual(initial_count, self._get_queryset().count()) + + reader = csv.DictReader(array, delimiter=',') + check_data = list(reader) + for line in check_data: + obj = self.model.objects.get(id=line["id"]) + for attr, value in line.items(): + if attr != "id": + field = self.model._meta.get_field(attr) + value = getattr(obj, attr) + # cannot verify FK fields as don't know what name the CSV maps to + if value is not None and not isinstance(field, ForeignKey): + self.assertEqual(value, value) + + @override_settings(EXEMPT_VIEW_PERMISSIONS=['*']) + def test_bulk_import_objects_with_constrained_permission_csv(self): initial_count = self._get_queryset().count() data = { 'data': self._get_csv_data(), @@ -687,6 +864,66 @@ def test_bulk_import_objects_with_constrained_permission(self): self.assertHttpStatus(self.client.post(self._get_url('import'), data), 302) self.assertEqual(self._get_queryset().count(), initial_count + len(self.csv_data) - 1) + @override_settings(EXEMPT_VIEW_PERMISSIONS=['*']) + def test_bulk_import_objects_with_constrained_permission_yaml(self): + initial_count = self._get_queryset().count() + data = { + 'data': self._get_yaml_data(), + 'format': 'yaml', + } + + # Assign constrained permission + obj_perm = ObjectPermission( + name='Test permission', + constraints={'pk': 0}, # Dummy permission to deny all + actions=['add'] + ) + obj_perm.save() + obj_perm.users.add(self.user) + obj_perm.object_types.add(ContentType.objects.get_for_model(self.model)) + + # Attempt to import non-permitted objects + self.assertHttpStatus(self.client.post(self._get_url('import'), data), 200) + self.assertEqual(self._get_queryset().count(), initial_count) + + # Update permission constraints + obj_perm.constraints = {'pk__gt': 0} # Dummy permission to allow all + obj_perm.save() + + # Import permitted objects + self.assertHttpStatus(self.client.post(self._get_url('import'), data), 302) + self.assertEqual(self._get_queryset().count(), initial_count + len(self.csv_data) - 1) + + @override_settings(EXEMPT_VIEW_PERMISSIONS=['*']) + def test_bulk_import_objects_with_constrained_permission_json(self): + initial_count = self._get_queryset().count() + data = { + 'data': self._get_json_data(), + 'format': 'json', + } + + # Assign constrained permission + obj_perm = ObjectPermission( + name='Test permission', + constraints={'pk': 0}, # Dummy permission to deny all + actions=['add'] + ) + obj_perm.save() + obj_perm.users.add(self.user) + obj_perm.object_types.add(ContentType.objects.get_for_model(self.model)) + + # Attempt to import non-permitted objects + self.assertHttpStatus(self.client.post(self._get_url('import'), data), 200) + self.assertEqual(self._get_queryset().count(), initial_count) + + # Update permission constraints + obj_perm.constraints = {'pk__gt': 0} # Dummy permission to allow all + obj_perm.save() + + # Import permitted objects + self.assertHttpStatus(self.client.post(self._get_url('import'), data), 302) + self.assertEqual(self._get_queryset().count(), initial_count + len(self.csv_data) - 1) + class BulkEditObjectsViewTestCase(ModelViewTestCase): """ Edit multiple instances.