Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
224 changes: 222 additions & 2 deletions netbox/netbox/tests/test_import.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,8 @@
import csv
import io
import json
import yaml

from django.contrib.contenttypes.models import ContentType
from django.test import override_settings

Expand All @@ -17,8 +22,16 @@ def setUpTestData(cls):
def _get_csv_data(self, csv_data):
return '\n'.join(csv_data)

def _get_yaml_data(self, csv_data):
data = [*csv.DictReader(io.StringIO(self._get_csv_data(csv_data)))]
return yaml.dump(data)

def _get_json_data(self, csv_data):
data = [*csv.DictReader(io.StringIO(self._get_csv_data(csv_data)))]
return json.dumps(data)

@override_settings(EXEMPT_VIEW_PERMISSIONS=['*'])
def test_valid_tags(self):
def test_valid_tags_csv(self):
csv_data = (
'name,slug,tags',
'Region 1,region-1,"alpha,bravo"',
Expand Down Expand Up @@ -62,7 +75,93 @@ def test_valid_tags(self):
self.assertEqual(regions[3].tags.count(), 0)

@override_settings(EXEMPT_VIEW_PERMISSIONS=['*'])
def test_invalid_tags(self):
def test_valid_tags_yaml(self):
csv_data = (
'name,slug,tags',
'Region 1,region-1,"alpha,bravo"',
'Region 2,region-2,"charlie,delta"',
'Region 3,region-3,echo',
'Region 4,region-4,',
)

data = {
'format': ImportFormatChoices.YAML,
'data': self._get_yaml_data(csv_data),
}

# Assign model-level permission
obj_perm = ObjectPermission(name='Test permission', actions=['add'])
obj_perm.save()
obj_perm.users.add(self.user)
obj_perm.object_types.add(ContentType.objects.get_for_model(self.model))

# Try GET with model-level permission
self.assertHttpStatus(self.client.get(self._get_url('import')), 200)

# Test POST with permission
self.assertHttpStatus(self.client.post(self._get_url('import'), data), 302)
regions = Region.objects.all()
self.assertEqual(regions.count(), 4)
region = Region.objects.get(slug="region-4")
self.assertEqual(
list(regions[0].tags.values_list('name', flat=True)),
['Alpha', 'Bravo']
)
self.assertEqual(
list(regions[1].tags.values_list('name', flat=True)),
['Charlie', 'Delta']
)
self.assertEqual(
list(regions[2].tags.values_list('name', flat=True)),
['Echo']
)
self.assertEqual(regions[3].tags.count(), 0)

@override_settings(EXEMPT_VIEW_PERMISSIONS=['*'])
def test_valid_tags_json(self):
csv_data = (
'name,slug,tags',
'Region 1,region-1,"alpha,bravo"',
'Region 2,region-2,"charlie,delta"',
'Region 3,region-3,echo',
'Region 4,region-4,',
)

data = {
'format': ImportFormatChoices.JSON,
'data': self._get_json_data(csv_data),
}

# Assign model-level permission
obj_perm = ObjectPermission(name='Test permission', actions=['add'])
obj_perm.save()
obj_perm.users.add(self.user)
obj_perm.object_types.add(ContentType.objects.get_for_model(self.model))

# Try GET with model-level permission
self.assertHttpStatus(self.client.get(self._get_url('import')), 200)

# Test POST with permission
self.assertHttpStatus(self.client.post(self._get_url('import'), data), 302)
regions = Region.objects.all()
self.assertEqual(regions.count(), 4)
region = Region.objects.get(slug="region-4")
self.assertEqual(
list(regions[0].tags.values_list('name', flat=True)),
['Alpha', 'Bravo']
)
self.assertEqual(
list(regions[1].tags.values_list('name', flat=True)),
['Charlie', 'Delta']
)
self.assertEqual(
list(regions[2].tags.values_list('name', flat=True)),
['Echo']
)
self.assertEqual(regions[3].tags.count(), 0)

@override_settings(EXEMPT_VIEW_PERMISSIONS=['*'])
def test_invalid_tags_csv(self):
csv_data = (
'name,slug,tags',
'Region 1,region-1,"Alpha,Bravo"', # Valid
Expand All @@ -86,3 +185,124 @@ def test_invalid_tags(self):
# Test POST with permission
self.assertHttpStatus(self.client.post(self._get_url('import'), data), 200)
self.assertEqual(Region.objects.count(), 0)

@override_settings(EXEMPT_VIEW_PERMISSIONS=['*'])
def test_invalid_tags_yaml(self):
csv_data = (
'name,slug,tags',
'Region 1,region-1,"Alpha,Bravo"', # Valid
'Region 2,region-2,"Alpha,Tango"', # Invalid
)

data = {
'format': ImportFormatChoices.YAML,
'data': self._get_yaml_data(csv_data),
}

# Assign model-level permission
obj_perm = ObjectPermission(name='Test permission', actions=['add'])
obj_perm.save()
obj_perm.users.add(self.user)
obj_perm.object_types.add(ContentType.objects.get_for_model(self.model))

# Try GET with model-level permission
self.assertHttpStatus(self.client.get(self._get_url('import')), 200)

# Test POST with permission
self.assertHttpStatus(self.client.post(self._get_url('import'), data), 200)
self.assertEqual(Region.objects.count(), 0)

@override_settings(EXEMPT_VIEW_PERMISSIONS=['*'])
def test_invalid_tags_json(self):
csv_data = (
'name,slug,tags',
'Region 1,region-1,"Alpha,Bravo"', # Valid
'Region 2,region-2,"Alpha,Tango"', # Invalid
)

data = {
'format': ImportFormatChoices.JSON,
'data': self._get_json_data(csv_data),
}

# Assign model-level permission
obj_perm = ObjectPermission(name='Test permission', actions=['add'])
obj_perm.save()
obj_perm.users.add(self.user)
obj_perm.object_types.add(ContentType.objects.get_for_model(self.model))

# Try GET with model-level permission
self.assertHttpStatus(self.client.get(self._get_url('import')), 200)

# Test POST with permission
self.assertHttpStatus(self.client.post(self._get_url('import'), data), 200)
self.assertEqual(Region.objects.count(), 0)

@override_settings(EXEMPT_VIEW_PERMISSIONS=['*'])
def test_invalid_header_csv(self):
csv_data = (
'name,slug,xxx',
'Region 1,region-1,yyy',
)

data = {
'format': ImportFormatChoices.CSV,
'data': self._get_csv_data(csv_data),
}

# Assign model-level permission
obj_perm = ObjectPermission(name='Test permission', actions=['add'])
obj_perm.save()
obj_perm.users.add(self.user)
obj_perm.object_types.add(ContentType.objects.get_for_model(self.model))

# Test POST with permission
ret = self.client.post(self._get_url('import'), data)
self.assertHttpStatus(self.client.post(self._get_url('import'), data), 200)
self.assertEqual(Region.objects.count(), 0)

@override_settings(EXEMPT_VIEW_PERMISSIONS=['*'])
def test_invalid_header_yaml(self):
csv_data = (
'name,slug,xxx',
'Region 1,region-1,yyy',
)

data = {
'format': ImportFormatChoices.YAML,
'data': self._get_yaml_data(csv_data),
}

# Assign model-level permission
obj_perm = ObjectPermission(name='Test permission', actions=['add'])
obj_perm.save()
obj_perm.users.add(self.user)
obj_perm.object_types.add(ContentType.objects.get_for_model(self.model))

# Test POST with permission
ret = self.client.post(self._get_url('import'), data)
self.assertHttpStatus(self.client.post(self._get_url('import'), data), 200)
self.assertEqual(Region.objects.count(), 0)

@override_settings(EXEMPT_VIEW_PERMISSIONS=['*'])
def test_invalid_header_json(self):
csv_data = (
'name,slug,xxx',
'Region 1,region-1,yyy',
)

data = {
'format': ImportFormatChoices.JSON,
'data': self._get_json_data(csv_data),
}

# Assign model-level permission
obj_perm = ObjectPermission(name='Test permission', actions=['add'])
obj_perm.save()
obj_perm.users.add(self.user)
obj_perm.object_types.add(ContentType.objects.get_for_model(self.model))

# Test POST with permission
ret = self.client.post(self._get_url('import'), data)
self.assertHttpStatus(self.client.post(self._get_url('import'), data), 200)
self.assertEqual(Region.objects.count(), 0)
20 changes: 19 additions & 1 deletion netbox/netbox/views/generic/bulk_views.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,14 @@
from django.utils.safestring import mark_safe
from django_tables2.export import TableExport

from dcim.forms.bulk_import import DeviceTypeImportForm, ModuleTypeImportForm
from extras.models import ExportTemplate
from extras.signals import clear_webhooks
from utilities.error_handlers import handle_protectederror
from utilities.exceptions import AbortRequest, AbortTransaction, PermissionsViolation
from utilities.forms import BulkRenameForm, ConfirmationForm, restrict_form_fields
from utilities.forms.bulk_import import BulkImportForm
from utilities.forms.utils import headers_to_dict, validate_import_headers
from utilities.htmx import is_embedded, is_htmx
from utilities.permissions import get_permission_for_model
from utilities.utils import get_viewname
Expand Down Expand Up @@ -398,10 +400,26 @@ def create_and_update_objects(self, form, request):
'data': record,
'instance': instance,
}
headers = None
if hasattr(form, '_csv_headers'):
model_form_kwargs['headers'] = form._csv_headers # Add CSV headers
headers = form._csv_headers
model_form_kwargs['headers'] = headers # Add CSV headers
model_form = self.model_form(**model_form_kwargs)

# validate the fields (required fields are present and no unknown fields)
form_fields = model_form.fields
required_fields = [
name for name, field in form_fields.items() if field.required
]

if not headers:
keys = list(record.keys())
if object_id:
keys.append("id")
headers = headers_to_dict(keys)

validate_import_headers(headers, form_fields, required_fields, allow_extra_columns=isinstance(model_form, (DeviceTypeImportForm, ModuleTypeImportForm)))

# When updating, omit all form fields other than those specified in the record. (No
# fields are required when modifying an existing object.)
if object_id:
Expand Down
1 change: 1 addition & 0 deletions netbox/utilities/forms/forms.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from django import forms
from django.utils.translation import gettext as _

from .mixins import BootstrapMixin

__all__ = (
Expand Down
45 changes: 27 additions & 18 deletions netbox/utilities/forms/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
'parse_numeric_range',
'restrict_form_fields',
'parse_csv',
'validate_csv',
'validate_import_headers',
)


Expand Down Expand Up @@ -205,29 +205,38 @@ def restrict_form_fields(form, user, action='view'):
field.queryset = field.queryset.restrict(user, action)


def parse_csv(reader):
def headers_to_dict(headers):
"""
Parse a csv_reader object into a headers dictionary and a list of records dictionaries. Raise an error
if the records are formatted incorrectly. Return headers and records as a tuple.
Create a dictionary mapping each header to an optional "to" field specifying how
the related object is being referenced. For example, importing a Device might use a
`site.slug` header, to indicate the related site is being referenced by its slug.
"""
records = []
headers = {}

# Consume the first line of CSV data as column headers. Create a dictionary mapping each header to an optional
# "to" field specifying how the related object is being referenced. For example, importing a Device might use a
# `site.slug` header, to indicate the related site is being referenced by its slug.

for header in next(reader):
header_dict = {}
for header in headers:
header = header.strip()
if '.' in header:
field, to_field = header.split('.', 1)
if field in headers:
if field in header_dict:
raise forms.ValidationError(f'Duplicate or conflicting column header for "{field}"')
headers[field] = to_field
header_dict[field] = to_field
else:
if header in headers:
if header in header_dict:
raise forms.ValidationError(f'Duplicate or conflicting column header for "{header}"')
headers[header] = None
header_dict[header] = None

return header_dict


def parse_csv(reader):
"""
Parse a csv_reader object into a headers dictionary and a list of records dictionaries. Raise an error
if the records are formatted incorrectly. Return headers and records as a tuple.
"""
records = []
headers = {}

# Consume the first line of CSV data as column headers.
headers = headers_to_dict(list(next(reader)))

# Parse CSV rows into a list of dictionaries mapped from the column headers.
for i, row in enumerate(reader, start=1):
Expand All @@ -242,7 +251,7 @@ def parse_csv(reader):
return headers, records


def validate_csv(headers, fields, required_fields):
def validate_import_headers(headers, fields, required_fields, allow_extra_columns=False):
"""
Validate that parsed csv data conforms to the object's available fields. Raise validation errors
if parsed csv data contains invalid headers or does not contain required headers.
Expand All @@ -253,7 +262,7 @@ def validate_csv(headers, fields, required_fields):
if field == "id":
is_update = True
continue
if field not in fields:
if (not allow_extra_columns) and (field not in fields):
raise forms.ValidationError(f'Unexpected column header "{field}" found.')
if to_field and not hasattr(fields[field], 'to_field_name'):
raise forms.ValidationError(f'Column "{field}" is not a related object; cannot use dots')
Expand Down
Loading