diff --git a/netbox/users/api/serializers.py b/netbox/users/api/serializers.py index 1f4bf4ea078..75ab877cf88 100644 --- a/netbox/users/api/serializers.py +++ b/netbox/users/api/serializers.py @@ -1,11 +1,12 @@ from django.conf import settings +from django.contrib.auth import authenticate from django.contrib.auth import get_user_model from django.contrib.auth.models import Group from django.contrib.contenttypes.models import ContentType from drf_spectacular.utils import extend_schema_field from drf_spectacular.types import OpenApiTypes from rest_framework import serializers -from rest_framework.exceptions import PermissionDenied +from rest_framework.exceptions import AuthenticationFailed, PermissionDenied from netbox.api.fields import ContentTypeField, IPNetworkSerializer, SerializedPKRelatedField from netbox.api.serializers import ValidatedModelSerializer @@ -107,9 +108,42 @@ def validate(self, data): return super().validate(data) -class TokenProvisionSerializer(serializers.Serializer): - username = serializers.CharField() - password = serializers.CharField() +class TokenProvisionSerializer(TokenSerializer): + user = NestedUserSerializer( + read_only=True + ) + username = serializers.CharField( + write_only=True + ) + password = serializers.CharField( + write_only=True + ) + last_used = serializers.DateTimeField( + read_only=True + ) + key = serializers.CharField( + read_only=True + ) + + class Meta: + model = Token + fields = ( + 'id', 'url', 'display', 'user', 'created', 'expires', 'last_used', 'key', 'write_enabled', 'description', + 'allowed_ips', 'username', 'password', + ) + + def validate(self, data): + # Validate the username and password + username = data.pop('username') + password = data.pop('password') + user = authenticate(request=self.context.get('request'), username=username, password=password) + if user is None: + raise AuthenticationFailed("Invalid username/password") + + # Inject the user into the validated data + data['user'] = user + + return data class ObjectPermissionSerializer(ValidatedModelSerializer): diff --git a/netbox/users/api/views.py b/netbox/users/api/views.py index 9cf5b1ac5cd..62a32c71b8b 100644 --- a/netbox/users/api/views.py +++ b/netbox/users/api/views.py @@ -1,3 +1,4 @@ +import logging from django.contrib.auth import authenticate from django.contrib.auth import get_user_model from django.contrib.auth.models import Group @@ -63,34 +64,21 @@ class TokenProvisionView(APIView): @extend_schema( request=serializers.TokenProvisionSerializer, responses={ - 201: serializers.TokenSerializer, + 201: serializers.TokenProvisionSerializer, 401: OpenApiTypes.OBJECT, } ) def post(self, request): - serializer = serializers.TokenProvisionSerializer(data=request.data) - serializer.is_valid() - - # Authenticate the user account based on the provided credentials - username = serializer.data.get('username') - password = serializer.data.get('password') - if not username or not password: - raise AuthenticationFailed("Username and password must be provided to provision a token.") - user = authenticate(request=request, username=username, password=password) - if user is None: - raise AuthenticationFailed("Invalid username/password") - - # Create a new Token for the User - token = Token(user=user) - token.save() - data = serializers.TokenSerializer(token, context={'request': request}).data - # Manually append the token key, which is normally write-only - data['key'] = token.key - - return Response(data, status=HTTP_201_CREATED) - - def get_serializer_class(self): - return serializers.TokenSerializer + serializer = serializers.TokenProvisionSerializer(data=request.data, context={'request': request}) + serializer.is_valid(raise_exception=True) + self.perform_create(serializer) + return Response(serializer.data, status=HTTP_201_CREATED) + + def perform_create(self, serializer): + model = serializer.Meta.model + logger = logging.getLogger(f'netbox.api.views.TokenProvisionView') + logger.info(f"Creating new {model._meta.verbose_name}") + serializer.save() # diff --git a/netbox/users/tests/test_api.py b/netbox/users/tests/test_api.py index 859dd0b8336..0011424107b 100644 --- a/netbox/users/tests/test_api.py +++ b/netbox/users/tests/test_api.py @@ -141,17 +141,25 @@ def test_provision_token_valid(self): """ Test the provisioning of a new REST API token given a valid username and password. """ - data = { + user_credentials = { 'username': 'user1', 'password': 'abc123', } - user = User.objects.create_user(**data) + user = User.objects.create_user(**user_credentials) + + data = { + **user_credentials, + 'description': 'My API token', + 'expires': '2099-12-31T23:59:59Z', + } url = reverse('users-api:token_provision') response = self.client.post(url, data, format='json', **self.header) self.assertEqual(response.status_code, 201) self.assertIn('key', response.data) self.assertEqual(len(response.data['key']), 40) + self.assertEqual(response.data['description'], data['description']) + self.assertEqual(response.data['expires'], data['expires']) token = Token.objects.get(user=user) self.assertEqual(token.key, response.data['key'])