Skip to content

Commit a3d7557

Browse files
authored
Add token refresh for S3 client and API calls (#31)
* fix token refresh on UsernameAndPasswordAuth * add refreshable credentials to s3 client so it can refresh on large multipart uploads * update cognito app IDs * bump version
1 parent 3be765b commit a3d7557

File tree

4 files changed

+54
-33
lines changed

4 files changed

+54
-33
lines changed

pubweb/auth/username.py

Lines changed: 19 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,16 @@
1+
import logging
2+
from datetime import datetime, timedelta
3+
from typing import Callable
4+
15
import boto3
26
from pycognito import AWSSRP
37
from requests.auth import AuthBase
48

59
from pubweb.auth.base import AuthInfo
610
from pubweb.config import config
711

12+
logger = logging.getLogger()
13+
814

915
class UsernameAndPasswordAuth(AuthInfo):
1016
"""
@@ -15,27 +21,35 @@ class UsernameAndPasswordAuth(AuthInfo):
1521
def __init__(self, username, password):
1622
self.username = username
1723
self.password = password
24+
self.auth_result = None
25+
self.token_expiry = None
1826

1927
def get_request_auth(self) -> AuthBase:
20-
return self.RequestAuth(self._get_token()['AccessToken'])
28+
return self.RequestAuth(lambda: self._get_token()['AccessToken'])
2129

2230
def get_current_user(self) -> str:
2331
return self.username
2432

2533
def _get_token(self):
34+
if self.token_expiry and self.token_expiry > datetime.now():
35+
return self.auth_result
36+
37+
logger.debug('Fetching new token from cognito')
2638
cognito = boto3.client('cognito-idp', region_name=config.region)
2739
aws = AWSSRP(username=self.username,
2840
password=self.password,
2941
pool_id=config.user_pool_id,
3042
client_id=config.app_id,
3143
client=cognito)
3244
resp = aws.authenticate_user()
33-
return resp['AuthenticationResult']
45+
self.auth_result = resp['AuthenticationResult']
46+
self.token_expiry = datetime.now() + timedelta(seconds=self.auth_result['ExpiresIn'])
47+
return self.auth_result
3448

3549
class RequestAuth(AuthBase):
36-
def __init__(self, token):
37-
self.token = token
50+
def __init__(self, token_getter: Callable[..., str]):
51+
self.token_getter = token_getter
3852

3953
def __call__(self, request):
40-
request.headers['Authorization'] = self.token
54+
request.headers['Authorization'] = self.token_getter()
4155
return request

pubweb/clients/s3.py

Lines changed: 32 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
11
import math
22
import threading
3-
from datetime import datetime, timezone
43
from pathlib import Path
54
from typing import Callable
65

7-
import boto3
6+
from boto3 import Session
7+
from botocore.credentials import RefreshableCredentials
88
from tqdm import tqdm
99

1010
from pubweb.models.auth import Creds
@@ -21,11 +21,14 @@ def convert_size(size):
2121
return '%.2f %s' % (s, size_name[i])
2222

2323

24-
def build_client(creds: Creds):
25-
return boto3.client('s3',
26-
aws_access_key_id=creds['AccessKeyId'],
27-
aws_secret_access_key=creds['SecretAccessKey'],
28-
aws_session_token=creds['SessionToken'])
24+
def format_creds_for_session(creds: Creds):
25+
expiration = parse_json_date(creds['Expiration'])
26+
return {
27+
'access_key': creds['AccessKeyId'],
28+
'secret_key': creds['SecretAccessKey'],
29+
'token': creds['SessionToken'],
30+
'expiry_time': expiration.isoformat()
31+
}
2932

3033

3134
class ProgressPercentage:
@@ -40,13 +43,10 @@ def __call__(self, bytes_amount):
4043

4144
class S3Client:
4245
def __init__(self, creds_getter: Callable[[], Creds]):
43-
creds = creds_getter()
4446
self._creds_getter = creds_getter
45-
self._creds_expiration = creds['Expiration']
46-
self._client = build_client(creds)
47+
self._client = self._build_session_client()
4748

4849
def upload_file(self, local_path: Path, bucket: str, key: str):
49-
self._check_credentials()
5050
file_size = local_path.stat().st_size
5151
file_name = local_path.name
5252

@@ -59,7 +59,6 @@ def upload_file(self, local_path: Path, bucket: str, key: str):
5959
self._client.upload_file(absolute_path, bucket, key, Callback=ProgressPercentage(progress))
6060

6161
def download_file(self, local_path: Path, bucket: str, key: str):
62-
self._check_credentials()
6362
file_size = self.get_file_stats(bucket, key)['ContentLength']
6463
file_name = local_path.name
6564

@@ -72,7 +71,6 @@ def download_file(self, local_path: Path, bucket: str, key: str):
7271
self._client.download_file(bucket, key, absolute_path, Callback=ProgressPercentage(progress))
7372

7473
def create_object(self, bucket: str, key: str, contents: str, content_type: str):
75-
self._check_credentials()
7674
self._client.put_object(
7775
Bucket=bucket,
7876
Key=key,
@@ -82,22 +80,31 @@ def create_object(self, bucket: str, key: str, contents: str, content_type: str)
8280
)
8381

8482
def get_file(self, bucket: str, key: str) -> str:
85-
self._check_credentials()
8683
resp = self._client.get_object(Bucket=bucket, Key=key)
8784
file_body = resp['Body']
8885
return file_body.read().decode('utf-8')
8986

9087
def get_file_stats(self, bucket: str, key: str):
91-
self._check_credentials()
9288
return self._client.head_object(Bucket=bucket, Key=key)
9389

94-
def _check_credentials(self):
95-
if not self._creds_expiration:
96-
return
97-
98-
expiration = parse_json_date(self._creds_expiration)
99-
100-
if expiration < datetime.now(timezone.utc):
101-
new_creds = self._creds_getter()
102-
self._client = build_client(new_creds)
103-
self._creds_expiration = new_creds['Expiration']
90+
def _build_session_client(self):
91+
creds = self._creds_getter()
92+
93+
if creds['Expiration']:
94+
session = Session()
95+
session._credentials = RefreshableCredentials.create_from_metadata(
96+
metadata=format_creds_for_session(creds),
97+
refresh_using=self._refresh_credentials(),
98+
method='sts'
99+
)
100+
else:
101+
session = Session(
102+
aws_access_key_id=creds['AccessKeyId'],
103+
aws_secret_access_key=creds['SecretAccessKey'],
104+
aws_session_token=creds['SessionToken']
105+
)
106+
return session.client('s3')
107+
108+
def _refresh_credentials(self):
109+
new_creds = self._creds_getter()
110+
return format_creds_for_session(new_creds)

pubweb/config.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ class AuthConfig(NamedTuple):
1313

1414
class DevelopmentConfig:
1515
user_pool_id = 'us-west-2_ViB3UFcvp'
16-
app_id = '2g2eg0g7tbjhbaa45diohmvqhs'
16+
app_id = '39jl0uud4d1i337q7gc5l03r98'
1717
data_endpoint = 'https://drdt2z4kljdbte5s4zx623kyk4.appsync-api.us-west-2.amazonaws.com/graphql'
1818
region = 'us-west-2'
1919
resources_bucket = 'pubweb-resources-dev'
@@ -22,7 +22,7 @@ class DevelopmentConfig:
2222

2323
class ProductionConfig:
2424
user_pool_id = 'us-west-2_LQnstneoZ'
25-
app_id = '7ic2n55r9h4fj0qej5q9ikr2o1'
25+
app_id = '2seju0a0p55hmdajb61ftm4edc'
2626
data_endpoint = 'https://22boctowkfbuzaidvbvwjxfnai.appsync-api.us-west-2.amazonaws.com/graphql'
2727
region = 'us-west-2'
2828
resources_bucket = 'pubweb-resources-prd'

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@
3333

3434
setup(
3535
name='pubweb',
36-
version='0.3.2',
36+
version='0.3.3',
3737
author='Fred Hutch',
3838
license='MIT',
3939
author_email='[email protected]',

0 commit comments

Comments
 (0)