From 35c55be94e0110d588a1c927ea122e19e0730c84 Mon Sep 17 00:00:00 2001 From: Nathan Thorpe Date: Mon, 11 Jul 2022 14:59:59 -0700 Subject: [PATCH] implement creating of datasets using IAM Auth --- pubweb/auth/base.py | 4 ++++ pubweb/auth/iam.py | 9 +++++++++ pubweb/auth/username.py | 3 +++ pubweb/clients/api.py | 15 ++++++++++++++- pubweb/clients/s3.py | 10 ++++++++++ pubweb/services/dataset.py | 38 +++++++++++++++++++++++++++++++++++++- pubweb/services/file.py | 16 +++++++++++++--- 7 files changed, 90 insertions(+), 5 deletions(-) diff --git a/pubweb/auth/base.py b/pubweb/auth/base.py index e5b88227..5c6ffc95 100644 --- a/pubweb/auth/base.py +++ b/pubweb/auth/base.py @@ -7,3 +7,7 @@ class AuthInfo(ABC): @abstractmethod def get_request_auth(self) -> AuthBase: raise NotImplementedError() + + @abstractmethod + def get_current_user(self) -> str: + raise NotImplementedError() diff --git a/pubweb/auth/iam.py b/pubweb/auth/iam.py index 167c22e4..8b4e8552 100644 --- a/pubweb/auth/iam.py +++ b/pubweb/auth/iam.py @@ -34,3 +34,12 @@ def get_request_auth(self) -> AuthBase: config.region, 'appsync', session_token=self.creds['SessionToken']) + + def get_current_user(self) -> str: + sts_client = boto3.client('sts', + aws_access_key_id=self.creds['AccessKeyId'], + aws_secret_access_key=self.creds['SecretAccessKey'], + aws_session_token=self.creds['SessionToken']) + identity_arn = sts_client.get_caller_identity()['Arn'] + username = identity_arn.split('/')[-1] + return f'iam-{username}' diff --git a/pubweb/auth/username.py b/pubweb/auth/username.py index f319150a..9057895d 100644 --- a/pubweb/auth/username.py +++ b/pubweb/auth/username.py @@ -19,6 +19,9 @@ def __init__(self, username, password): def get_request_auth(self) -> AuthBase: return self.RequestAuth(self._get_token()['AccessToken']) + def get_current_user(self) -> str: + return self.username + def _get_token(self): cognito = boto3.client('cognito-idp', region_name=config.region) aws = AWSSRP(username=self.username, diff --git a/pubweb/clients/api.py b/pubweb/clients/api.py index 154271ce..b80f77fd 100644 --- a/pubweb/clients/api.py +++ b/pubweb/clients/api.py @@ -5,6 +5,7 @@ from pubweb import config from pubweb.auth.base import AuthInfo +from pubweb.auth.iam import IAMAuth HEADERS = { 'Accept': 'application/json', @@ -19,8 +20,20 @@ def _build_gql_client(auth_info: AuthInfo, endpoint: str): class ApiClient: def __init__(self, auth_info: AuthInfo): - self.auth_info = auth_info + self._auth_info = auth_info self._gql_client = _build_gql_client(auth_info, config.data_endpoint) def query(self, query: str, variables=None) -> Dict: return self._gql_client.execute(gql(query), variable_values=variables) + + @property + def has_iam_creds(self) -> bool: + return isinstance(self._auth_info, IAMAuth) + + @property + def current_user(self) -> str: + return self._auth_info.get_current_user() + + def get_iam_creds(self): + if self.has_iam_creds: + return self._auth_info.creds diff --git a/pubweb/clients/s3.py b/pubweb/clients/s3.py index d6eedd3f..d8494d27 100644 --- a/pubweb/clients/s3.py +++ b/pubweb/clients/s3.py @@ -71,6 +71,16 @@ def download_file(self, local_path: Path, bucket: str, key: str): absolute_path = str(local_path.absolute()) self._client.download_file(bucket, key, absolute_path, Callback=ProgressPercentage(progress)) + def create_object(self, bucket: str, key: str, contents: str, content_type: str): + self._check_credentials() + self._client.put_object( + Bucket=bucket, + Key=key, + ContentType=content_type, + ContentEncoding='utf-8', + Body=bytes(contents, "UTF-8") + ) + def get_file(self, bucket: str, key: str) -> str: self._check_credentials() resp = self._client.get_object(Bucket=bucket, Key=key) diff --git a/pubweb/services/dataset.py b/pubweb/services/dataset.py index 633a37a6..c4e9e40a 100644 --- a/pubweb/services/dataset.py +++ b/pubweb/services/dataset.py @@ -1,4 +1,6 @@ +import json import logging +import uuid from typing import List, Union from pubweb.clients.utils import filter_deleted @@ -65,9 +67,14 @@ def find_by_project(self, project_id: str, name: str = None) -> List[Dataset]: def create(self, create_request: CreateIngestDatasetInput) -> DatasetCreateResponse: """ - Creates an ingest dataset + Creates an ingest dataset. + This only registers into the system, does not upload any files """ logger.info(f"Creating dataset {create_request.name}") + + if self._api_client.has_iam_creds: + return self._write_dataset_manifest(create_request) + query = ''' mutation CreateIngestDataset($input: CreateIngestDatasetInput!) { createIngestDataset(input: $input) { @@ -112,3 +119,32 @@ def download_files(self, project_id: str, dataset_id: str, download_location: st files = [file.relative_path for file in files] self._file_service.download_files(access_context, download_location, files) + + def _write_dataset_manifest(self, request: CreateIngestDatasetInput) -> DatasetCreateResponse: + """ + Internal method for registering a dataset without API access. + To be used for machine or service accounts + """ + manifest = { + 'project': request.project_id, + 'process': request.process_id, + 'name': request.name, + 'desc': request.description, + 'infoJson': { + 'ingestedBy': self._api_client.current_user + }, + 'files': [{'name': file} for file in request.files] + } + dataset_id = str(uuid.uuid4()) + manifest_path = f'datasets/{dataset_id}/artifacts/manifest.json' + manifest_json = json.dumps(manifest, indent=4) + access_context = FileAccessContext.upload_dataset(dataset_id=dataset_id, + project_id=request.project_id) + self._file_service.create_file(access_context, + key=manifest_path, + contents=manifest_json, + content_type='application/json') + return { + 'datasetId': dataset_id, + 'dataPath': f'datasets/{dataset_id}/artifacts/data' + } diff --git a/pubweb/services/file.py b/pubweb/services/file.py index 6171cb8e..10937186 100644 --- a/pubweb/services/file.py +++ b/pubweb/services/file.py @@ -2,7 +2,6 @@ from functools import partial from typing import List -from pubweb.auth.iam import IAMAuth from pubweb.clients import ApiClient, S3Client from pubweb.file_utils import upload_directory, download_directory from pubweb.models.auth import Creds @@ -14,8 +13,8 @@ class FileService(BaseService): def get_access_credentials(self, access_context: FileAccessContext) -> Creds: # Special case: # we do not need to call the API to get IAM creds if we are using IAM creds - if isinstance(self._api_client.auth_info, IAMAuth): - return self._api_client.auth_info.creds + if self._api_client.has_iam_creds: + return self._api_client.get_iam_creds() # Call API to get temporary credentials credentials_response = self._api_client.query(*access_context.get_token_query) return credentials_response['getFileAccessToken'] @@ -34,6 +33,17 @@ def get_file_from_path(self, access_context: FileAccessContext, file_path: str) full_path = f'{access_context.path_prefix}/{file_path}'.lstrip('/') return s3_client.get_file(access_context.bucket, full_path) + def create_file(self, access_context: FileAccessContext, key: str, + contents: str, content_type: str): + """ + Creates a file at the specified path + """ + s3_client = S3Client(partial(self.get_access_credentials, access_context)) + s3_client.create_object(key=key, + contents=contents, + content_type=content_type, + bucket=access_context.bucket) + def upload_files(self, access_context: FileAccessContext, directory: str, files: List[str]): """ Uploads a list of files from the specified directory