11import math
22import threading
3- from datetime import datetime , timezone
43from pathlib import Path
54from typing import Callable
65
7- import boto3
6+ from boto3 import Session
7+ from botocore .credentials import RefreshableCredentials
88from tqdm import tqdm
99
1010from pubweb .models .auth import Creds
@@ -21,11 +21,14 @@ def convert_size(size):
2121 return '%.2f %s' % (s , size_name [i ])
2222
2323
24- def build_client (creds : Creds ):
25- return boto3 .client ('s3' ,
26- aws_access_key_id = creds ['AccessKeyId' ],
27- aws_secret_access_key = creds ['SecretAccessKey' ],
28- aws_session_token = creds ['SessionToken' ])
24+ def format_creds_for_session (creds : Creds ):
25+ expiration = parse_json_date (creds ['Expiration' ])
26+ return {
27+ 'access_key' : creds ['AccessKeyId' ],
28+ 'secret_key' : creds ['SecretAccessKey' ],
29+ 'token' : creds ['SessionToken' ],
30+ 'expiry_time' : expiration .isoformat ()
31+ }
2932
3033
3134class ProgressPercentage :
@@ -40,13 +43,10 @@ def __call__(self, bytes_amount):
4043
4144class S3Client :
4245 def __init__ (self , creds_getter : Callable [[], Creds ]):
43- creds = creds_getter ()
4446 self ._creds_getter = creds_getter
45- self ._creds_expiration = creds ['Expiration' ]
46- self ._client = build_client (creds )
47+ self ._client = self ._build_session_client ()
4748
4849 def upload_file (self , local_path : Path , bucket : str , key : str ):
49- self ._check_credentials ()
5050 file_size = local_path .stat ().st_size
5151 file_name = local_path .name
5252
@@ -59,7 +59,6 @@ def upload_file(self, local_path: Path, bucket: str, key: str):
5959 self ._client .upload_file (absolute_path , bucket , key , Callback = ProgressPercentage (progress ))
6060
6161 def download_file (self , local_path : Path , bucket : str , key : str ):
62- self ._check_credentials ()
6362 file_size = self .get_file_stats (bucket , key )['ContentLength' ]
6463 file_name = local_path .name
6564
@@ -72,7 +71,6 @@ def download_file(self, local_path: Path, bucket: str, key: str):
7271 self ._client .download_file (bucket , key , absolute_path , Callback = ProgressPercentage (progress ))
7372
7473 def create_object (self , bucket : str , key : str , contents : str , content_type : str ):
75- self ._check_credentials ()
7674 self ._client .put_object (
7775 Bucket = bucket ,
7876 Key = key ,
@@ -82,22 +80,31 @@ def create_object(self, bucket: str, key: str, contents: str, content_type: str)
8280 )
8381
8482 def get_file (self , bucket : str , key : str ) -> str :
85- self ._check_credentials ()
8683 resp = self ._client .get_object (Bucket = bucket , Key = key )
8784 file_body = resp ['Body' ]
8885 return file_body .read ().decode ('utf-8' )
8986
9087 def get_file_stats (self , bucket : str , key : str ):
91- self ._check_credentials ()
9288 return self ._client .head_object (Bucket = bucket , Key = key )
9389
94- def _check_credentials (self ):
95- if not self ._creds_expiration :
96- return
97-
98- expiration = parse_json_date (self ._creds_expiration )
99-
100- if expiration < datetime .now (timezone .utc ):
101- new_creds = self ._creds_getter ()
102- self ._client = build_client (new_creds )
103- self ._creds_expiration = new_creds ['Expiration' ]
90+ def _build_session_client (self ):
91+ creds = self ._creds_getter ()
92+
93+ if creds ['Expiration' ]:
94+ session = Session ()
95+ session ._credentials = RefreshableCredentials .create_from_metadata (
96+ metadata = format_creds_for_session (creds ),
97+ refresh_using = self ._refresh_credentials (),
98+ method = 'sts'
99+ )
100+ else :
101+ session = Session (
102+ aws_access_key_id = creds ['AccessKeyId' ],
103+ aws_secret_access_key = creds ['SecretAccessKey' ],
104+ aws_session_token = creds ['SessionToken' ]
105+ )
106+ return session .client ('s3' )
107+
108+ def _refresh_credentials (self ):
109+ new_creds = self ._creds_getter ()
110+ return format_creds_for_session (new_creds )
0 commit comments