Skip to content

Commit 795be20

Browse files
authored
implement creating of datasets using IAM Auth (#30)
1 parent 25b64e1 commit 795be20

File tree

7 files changed

+90
-5
lines changed

7 files changed

+90
-5
lines changed

pubweb/auth/base.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,3 +7,7 @@ class AuthInfo(ABC):
77
@abstractmethod
def get_request_auth(self) -> AuthBase:
    """
    Returns the AuthBase object for this auth provider.

    Abstract: the base implementation always raises NotImplementedError,
    so concrete providers must override it.
    """
    raise NotImplementedError()
10+
11+
@abstractmethod
def get_current_user(self) -> str:
    """
    Returns a string identifying the user behind this auth provider.

    Abstract: the base implementation always raises NotImplementedError,
    so concrete providers must override it.
    """
    raise NotImplementedError()

pubweb/auth/iam.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,3 +34,12 @@ def get_request_auth(self) -> AuthBase:
3434
config.region,
3535
'appsync',
3636
session_token=self.creds['SessionToken'])
37+
38+
def get_current_user(self) -> str:
    """
    Returns an identifier for the caller behind the stored IAM credentials.

    Asks STS for the caller identity using the credentials in self.creds
    and returns the last component of the identity ARN prefixed with 'iam-'.
    """
    sts_client = boto3.client('sts',
                              aws_access_key_id=self.creds['AccessKeyId'],
                              aws_secret_access_key=self.creds['SecretAccessKey'],
                              aws_session_token=self.creds['SessionToken'])
    identity_arn = sts_client.get_caller_identity()['Arn']
    # Most identity ARNs end in ".../<name>", but some (for example the
    # account root, "arn:aws:iam::<account>:root") contain no "/".
    # Fall back to the text after the last ":" so the whole ARN is never
    # returned as the username.
    username = identity_arn.rsplit('/', 1)[-1].rsplit(':', 1)[-1]
    return f'iam-{username}'

pubweb/auth/username.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,9 @@ def __init__(self, username, password):
1919
def get_request_auth(self) -> AuthBase:
    """
    Returns a RequestAuth built from the 'AccessToken' entry of a token
    obtained via _get_token().
    """
    return self.RequestAuth(self._get_token()['AccessToken'])
2121

22+
def get_current_user(self) -> str:
    """
    Returns the username stored on this auth object.
    """
    return self.username
24+
2225
def _get_token(self):
2326
cognito = boto3.client('cognito-idp', region_name=config.region)
2427
aws = AWSSRP(username=self.username,

pubweb/clients/api.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55

66
from pubweb import config
77
from pubweb.auth.base import AuthInfo
8+
from pubweb.auth.iam import IAMAuth
89

910
HEADERS = {
1011
'Accept': 'application/json',
@@ -19,8 +20,20 @@ def _build_gql_client(auth_info: AuthInfo, endpoint: str):
1920

2021
class ApiClient:
    """
    Wraps a GraphQL client built from the supplied auth info and
    exposes a few facts about that auth info to callers.
    """

    def __init__(self, auth_info: AuthInfo):
        self._auth_info = auth_info
        self._gql_client = _build_gql_client(auth_info, config.data_endpoint)

    def query(self, query: str, variables=None) -> Dict:
        """
        Executes the given GraphQL document with optional variables
        and returns the client's result.
        """
        return self._gql_client.execute(gql(query), variable_values=variables)

    @property
    def has_iam_creds(self) -> bool:
        """
        True when the client was constructed with IAMAuth.
        """
        return isinstance(self._auth_info, IAMAuth)

    @property
    def current_user(self) -> str:
        """
        The user identifier reported by the underlying auth info.
        """
        return self._auth_info.get_current_user()

    def get_iam_creds(self):
        """
        Returns the IAM credentials held by the auth info,
        or None when the client is not using IAM auth.
        """
        if not self.has_iam_creds:
            return None
        return self._auth_info.creds

pubweb/clients/s3.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,16 @@ def download_file(self, local_path: Path, bucket: str, key: str):
7171
absolute_path = str(local_path.absolute())
7272
self._client.download_file(bucket, key, absolute_path, Callback=ProgressPercentage(progress))
7373

74+
def create_object(self, bucket: str, key: str, contents: str, content_type: str):
    """
    Writes a text object to the given bucket and key.

    The text is encoded as UTF-8 before upload and stored with the
    supplied content type. Callers that need to advertise the character
    set can include it in content_type (e.g. 'text/plain; charset=utf-8').
    """
    self._check_credentials()
    # Content-Encoding names a content coding such as gzip, not a character
    # set, so it is deliberately not sent; 'utf-8' there is an invalid value
    # that clients may try (and fail) to decode.
    self._client.put_object(
        Bucket=bucket,
        Key=key,
        ContentType=content_type,
        Body=contents.encode('utf-8')
    )
83+
7484
def get_file(self, bucket: str, key: str) -> str:
7585
self._check_credentials()
7686
resp = self._client.get_object(Bucket=bucket, Key=key)

pubweb/services/dataset.py

Lines changed: 37 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
1+
import json
12
import logging
3+
import uuid
24
from typing import List, Union
35

46
from pubweb.clients.utils import filter_deleted
@@ -65,9 +67,14 @@ def find_by_project(self, project_id: str, name: str = None) -> List[Dataset]:
6567

6668
def create(self, create_request: CreateIngestDatasetInput) -> DatasetCreateResponse:
6769
"""
68-
Creates an ingest dataset
70+
Creates an ingest dataset.
71+
This only registers into the system, does not upload any files
6972
"""
7073
logger.info(f"Creating dataset {create_request.name}")
74+
75+
if self._api_client.has_iam_creds:
76+
return self._write_dataset_manifest(create_request)
77+
7178
query = '''
7279
mutation CreateIngestDataset($input: CreateIngestDatasetInput!) {
7380
createIngestDataset(input: $input) {
@@ -112,3 +119,32 @@ def download_files(self, project_id: str, dataset_id: str, download_location: st
112119
files = [file.relative_path for file in files]
113120

114121
self._file_service.download_files(access_context, download_location, files)
122+
123+
def _write_dataset_manifest(self, request: CreateIngestDatasetInput) -> DatasetCreateResponse:
    """
    Internal method for registering a dataset without API access.
    To be used for machine or service accounts.

    Generates a new dataset id, writes a JSON manifest describing the
    request to 'datasets/<id>/artifacts/manifest.json' through the file
    service, and returns the new id together with the dataset's data path.
    """
    manifest = {
        'project': request.project_id,
        'process': request.process_id,
        'name': request.name,
        'desc': request.description,
        'infoJson': {
            # Record who performed the ingest
            'ingestedBy': self._api_client.current_user
        },
        'files': [{'name': file} for file in request.files]
    }
    # The id is generated client-side (random UUID)
    dataset_id = str(uuid.uuid4())
    manifest_path = f'datasets/{dataset_id}/artifacts/manifest.json'
    manifest_json = json.dumps(manifest, indent=4)
    access_context = FileAccessContext.upload_dataset(dataset_id=dataset_id,
                                                      project_id=request.project_id)
    self._file_service.create_file(access_context,
                                   key=manifest_path,
                                   contents=manifest_json,
                                   content_type='application/json')
    return {
        'datasetId': dataset_id,
        'dataPath': f'datasets/{dataset_id}/artifacts/data'
    }

pubweb/services/file.py

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22
from functools import partial
33
from typing import List
44

5-
from pubweb.auth.iam import IAMAuth
65
from pubweb.clients import ApiClient, S3Client
76
from pubweb.file_utils import upload_directory, download_directory
87
from pubweb.models.auth import Creds
@@ -14,8 +13,8 @@ class FileService(BaseService):
1413
def get_access_credentials(self, access_context: FileAccessContext) -> Creds:
    """
    Returns credentials for accessing files in the given access context.

    When the API client holds IAM credentials they are returned directly;
    otherwise the API is queried for a file access token.
    """
    # Special case:
    # we do not need to call the API to get IAM creds if we are using IAM creds
    if self._api_client.has_iam_creds:
        return self._api_client.get_iam_creds()
    # Call API to get temporary credentials
    credentials_response = self._api_client.query(*access_context.get_token_query)
    return credentials_response['getFileAccessToken']
@@ -34,6 +33,17 @@ def get_file_from_path(self, access_context: FileAccessContext, file_path: str)
3433
full_path = f'{access_context.path_prefix}/{file_path}'.lstrip('/')
3534
return s3_client.get_file(access_context.bucket, full_path)
3635

36+
def create_file(self, access_context: FileAccessContext, key: str,
                contents: str, content_type: str):
    """
    Creates a file at the specified path

    The object is written to the access context's bucket under the given
    key, with the supplied text contents and content type.
    """
    # The S3 client is given a callable rather than credentials, bound to
    # this access context.
    s3_client = S3Client(partial(self.get_access_credentials, access_context))
    s3_client.create_object(key=key,
                            contents=contents,
                            content_type=content_type,
                            bucket=access_context.bucket)
46+
3747
def upload_files(self, access_context: FileAccessContext, directory: str, files: List[str]):
3848
"""
3949
Uploads a list of files from the specified directory

0 commit comments

Comments
 (0)