Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 19 additions & 5 deletions pubweb/auth/username.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,16 @@
import logging
from datetime import datetime, timedelta
from typing import Callable

import boto3
from pycognito import AWSSRP
from requests.auth import AuthBase

from pubweb.auth.base import AuthInfo
from pubweb.config import config

logger = logging.getLogger()


class UsernameAndPasswordAuth(AuthInfo):
"""
Expand All @@ -15,27 +21,35 @@ class UsernameAndPasswordAuth(AuthInfo):
def __init__(self, username, password):
self.username = username
self.password = password
self.auth_result = None
self.token_expiry = None

def get_request_auth(self) -> AuthBase:
return self.RequestAuth(self._get_token()['AccessToken'])
return self.RequestAuth(lambda: self._get_token()['AccessToken'])

def get_current_user(self) -> str:
return self.username

def _get_token(self):
if self.token_expiry and self.token_expiry > datetime.now():
return self.auth_result

logger.debug('Fetching new token from cognito')
cognito = boto3.client('cognito-idp', region_name=config.region)
aws = AWSSRP(username=self.username,
password=self.password,
pool_id=config.user_pool_id,
client_id=config.app_id,
client=cognito)
resp = aws.authenticate_user()
return resp['AuthenticationResult']
self.auth_result = resp['AuthenticationResult']
self.token_expiry = datetime.now() + timedelta(seconds=self.auth_result['ExpiresIn'])
return self.auth_result

class RequestAuth(AuthBase):
def __init__(self, token):
self.token = token
def __init__(self, token_getter: Callable[..., str]):
self.token_getter = token_getter

def __call__(self, request):
request.headers['Authorization'] = self.token
request.headers['Authorization'] = self.token_getter()
return request
57 changes: 32 additions & 25 deletions pubweb/clients/s3.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
import math
import threading
from datetime import datetime, timezone
from pathlib import Path
from typing import Callable

import boto3
from boto3 import Session
from botocore.credentials import RefreshableCredentials
from tqdm import tqdm

from pubweb.models.auth import Creds
Expand All @@ -21,11 +21,14 @@ def convert_size(size):
return '%.2f %s' % (s, size_name[i])


def build_client(creds: Creds):
return boto3.client('s3',
aws_access_key_id=creds['AccessKeyId'],
aws_secret_access_key=creds['SecretAccessKey'],
aws_session_token=creds['SessionToken'])
def format_creds_for_session(creds: Creds):
expiration = parse_json_date(creds['Expiration'])
return {
'access_key': creds['AccessKeyId'],
'secret_key': creds['SecretAccessKey'],
'token': creds['SessionToken'],
'expiry_time': expiration.isoformat()
}


class ProgressPercentage:
Expand All @@ -40,13 +43,10 @@ def __call__(self, bytes_amount):

class S3Client:
def __init__(self, creds_getter: Callable[[], Creds]):
creds = creds_getter()
self._creds_getter = creds_getter
self._creds_expiration = creds['Expiration']
self._client = build_client(creds)
self._client = self._build_session_client()

def upload_file(self, local_path: Path, bucket: str, key: str):
self._check_credentials()
file_size = local_path.stat().st_size
file_name = local_path.name

Expand All @@ -59,7 +59,6 @@ def upload_file(self, local_path: Path, bucket: str, key: str):
self._client.upload_file(absolute_path, bucket, key, Callback=ProgressPercentage(progress))

def download_file(self, local_path: Path, bucket: str, key: str):
self._check_credentials()
file_size = self.get_file_stats(bucket, key)['ContentLength']
file_name = local_path.name

Expand All @@ -72,7 +71,6 @@ def download_file(self, local_path: Path, bucket: str, key: str):
self._client.download_file(bucket, key, absolute_path, Callback=ProgressPercentage(progress))

def create_object(self, bucket: str, key: str, contents: str, content_type: str):
self._check_credentials()
self._client.put_object(
Bucket=bucket,
Key=key,
Expand All @@ -82,22 +80,31 @@ def create_object(self, bucket: str, key: str, contents: str, content_type: str)
)

def get_file(self, bucket: str, key: str) -> str:
self._check_credentials()
resp = self._client.get_object(Bucket=bucket, Key=key)
file_body = resp['Body']
return file_body.read().decode('utf-8')

def get_file_stats(self, bucket: str, key: str):
self._check_credentials()
return self._client.head_object(Bucket=bucket, Key=key)

def _check_credentials(self):
if not self._creds_expiration:
return

expiration = parse_json_date(self._creds_expiration)

if expiration < datetime.now(timezone.utc):
new_creds = self._creds_getter()
self._client = build_client(new_creds)
self._creds_expiration = new_creds['Expiration']
def _build_session_client(self):
creds = self._creds_getter()

if creds['Expiration']:
session = Session()
session._credentials = RefreshableCredentials.create_from_metadata(
metadata=format_creds_for_session(creds),
refresh_using=self._refresh_credentials(),
method='sts'
)
else:
session = Session(
aws_access_key_id=creds['AccessKeyId'],
aws_secret_access_key=creds['SecretAccessKey'],
aws_session_token=creds['SessionToken']
)
return session.client('s3')

def _refresh_credentials(self):
new_creds = self._creds_getter()
return format_creds_for_session(new_creds)
4 changes: 2 additions & 2 deletions pubweb/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ class AuthConfig(NamedTuple):

class DevelopmentConfig:
user_pool_id = 'us-west-2_ViB3UFcvp'
app_id = '2g2eg0g7tbjhbaa45diohmvqhs'
app_id = '39jl0uud4d1i337q7gc5l03r98'
data_endpoint = 'https://drdt2z4kljdbte5s4zx623kyk4.appsync-api.us-west-2.amazonaws.com/graphql'
region = 'us-west-2'
resources_bucket = 'pubweb-resources-dev'
Expand All @@ -22,7 +22,7 @@ class DevelopmentConfig:

class ProductionConfig:
user_pool_id = 'us-west-2_LQnstneoZ'
app_id = '7ic2n55r9h4fj0qej5q9ikr2o1'
app_id = '2seju0a0p55hmdajb61ftm4edc'
data_endpoint = 'https://22boctowkfbuzaidvbvwjxfnai.appsync-api.us-west-2.amazonaws.com/graphql'
region = 'us-west-2'
resources_bucket = 'pubweb-resources-prd'
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@

setup(
name='pubweb',
version='0.3.2',
version='0.3.3',
author='Fred Hutch',
license='MIT',
author_email='[email protected]',
Expand Down