from functools import partial
from typing import Dict, List, Optional

from botocore.client import BaseClient
from cirro_api_client import CirroApiClient
from cirro_api_client.v1.api.file import generate_project_file_access_token
from cirro_api_client.v1.models import AWSCredentials, ProjectAccessType
@@ -69,6 +70,17 @@ def _get_project_read_credentials(self, access_context: FileAccessContext):
6970
7071 return self ._read_token_cache [project_id ]
7172
def get_aws_s3_client(self, access_context: FileAccessContext) -> BaseClient:
    """
    Returns the raw AWS S3 client that backs file operations.

    A Cirro S3 client is generated for the given access_context and its
    underlying AWS client is handed back to the caller.

    Useful for advanced operations that are not wrapped by this service,
    such as CopyObject or S3 Select.

    Args:
        access_context (FileAccessContext): Context used to generate the client

    Returns:
        The AWS client obtained from the generated Cirro S3 client
    """
    return self._generate_s3_client(access_context).get_aws_client()
83+
7284 def get_file (self , file : File ) -> bytes :
7385 """
7486 Gets the contents of a file
@@ -92,11 +104,7 @@ def get_file_from_path(self, access_context: FileAccessContext, file_path: str)
92104 Returns:
93105 The raw bytes of the file
94106 """
95-
96- s3_client = S3Client (
97- partial (self .get_access_credentials , access_context ),
98- self .enable_additional_checksum
99- )
107+ s3_client = self ._generate_s3_client (access_context )
100108
101109 full_path = f'{ access_context .prefix } /{ file_path } ' .lstrip ('/' )
102110
@@ -113,11 +121,7 @@ def create_file(self, access_context: FileAccessContext, key: str,
113121 contents (str): Content of object
114122 content_type (str):
115123 """
116-
117- s3_client = S3Client (
118- partial (self .get_access_credentials , access_context ),
119- self .enable_additional_checksum
120- )
124+ s3_client = self ._generate_s3_client (access_context )
121125
122126 s3_client .create_object (
123127 key = key ,
@@ -129,7 +133,8 @@ def create_file(self, access_context: FileAccessContext, key: str,
def upload_files(self,
                 access_context: FileAccessContext,
                 directory: PathLike,
                 files: List[PathLike],
                 file_path_map: Optional[Dict[PathLike, str]] = None) -> None:
    """
    Uploads a list of files from the specified directory

    Args:
        access_context (FileAccessContext): Supplies the destination bucket and prefix,
         and is used to generate the S3 client that performs the upload.
        directory (str|Path): Path to directory
        files (typing.List[str|Path]): List of paths to files within the directory
         must be the same type as directory.
        file_path_map (typing.Dict[str|Path, str]): Optional mapping of file paths to upload
         from source path to destination path, used to "re-write" paths within the dataset.
         When omitted, an empty mapping is passed to the upload helper.
    """
    s3_client = self._generate_s3_client(access_context)

    # The mapping is documented as optional, so callers that still use the
    # previous four-argument form must keep working; an omitted mapping is
    # forwarded as an empty dict rather than None.
    effective_path_map = {} if file_path_map is None else file_path_map

    upload_directory(
        directory=directory,
        files=files,
        file_path_map=effective_path_map,
        s3_client=s3_client,
        bucket=access_context.bucket,
        prefix=access_context.prefix,
        max_retries=self.transfer_retries
    )
156160
@@ -163,10 +167,7 @@ def download_files(self, access_context: FileAccessContext, directory: str, file
163167 directory (str): download location
164168 files (List[str]): relative path of files to download
165169 """
166- s3_client = S3Client (
167- partial (self .get_access_credentials , access_context ),
168- self .enable_additional_checksum
169- )
170+ s3_client = self ._generate_s3_client (access_context )
170171
171172 download_directory (
172173 directory ,
@@ -176,6 +177,15 @@ def download_files(self, access_context: FileAccessContext, directory: str, file
176177 access_context .prefix
177178 )
178179
def _generate_s3_client(self, access_context: FileAccessContext):
    """
    Builds the Cirro-S3 client used by the file operations of this service.

    The client receives a zero-argument credentials callback (this service's
    get_access_credentials bound to access_context) together with the
    service's enable_additional_checksum setting.
    """
    credentials_callback = partial(self.get_access_credentials, access_context)
    return S3Client(credentials_callback, self.enable_additional_checksum)
188+
179189
180190class FileEnabledService (BaseService ):
181191 """
0 commit comments