@@ -133,14 +133,18 @@ def data_location(self, data_location: str):
133133
134134 if not data_location .startswith ("s3://" ):
135135 raise ValueError (
136- 'Expecting an S3 URL beginning with "s3://". Got "{}"' .format (data_location )
136+ 'Expecting an S3 URL beginning with "s3://". Got "{}"' .format (
137+ data_location
138+ )
137139 )
138140 if data_location [- 1 ] != "/" :
139141 data_location = data_location + "/"
140142 self ._data_location = data_location
141143
142144 @classmethod
143- def _prepare_init_params_from_job_description (cls , job_details , model_channel_name = None ):
145+ def _prepare_init_params_from_job_description (
146+ cls , job_details , model_channel_name = None
147+ ):
144148 """Convert the job description to init params that can be handled by the class constructor.
145149
146150 Args:
@@ -168,7 +172,9 @@ def _prepare_init_params_from_job_description(cls, job_details, model_channel_na
168172 del init_params ["image_uri" ]
169173 return init_params
170174
171- def prepare_workflow_for_training (self , records = None , mini_batch_size = None , job_name = None ):
175+ def prepare_workflow_for_training (
176+ self , records = None , mini_batch_size = None , job_name = None
177+ ):
172178 """Calls _prepare_for_training. Used when setting up a workflow.
173179
174180 Args:
@@ -194,7 +200,9 @@ def _prepare_for_training(self, records, mini_batch_size=None, job_name=None):
194200 specified, one is generated, using the base name given to the
195201 constructor if applicable.
196202 """
197- super (AmazonAlgorithmEstimatorBase , self )._prepare_for_training (job_name = job_name )
203+ super (AmazonAlgorithmEstimatorBase , self )._prepare_for_training (
204+ job_name = job_name
205+ )
198206
199207 feature_dim = None
200208
@@ -260,7 +268,9 @@ def fit(
260268 will be unassociated.
261269 * `TrialComponentDisplayName` is used for display in Studio.
262270 """
263- self ._prepare_for_training (records , job_name = job_name , mini_batch_size = mini_batch_size )
271+ self ._prepare_for_training (
272+ records , job_name = job_name , mini_batch_size = mini_batch_size
273+ )
264274
265275 experiment_config = check_and_get_run_experiment_config (experiment_config )
266276 self .latest_training_job = _TrainingJob .start_new (
@@ -269,12 +279,14 @@ def fit(
269279 if wait :
270280 self .latest_training_job .wait (logs = logs )
271281
272- def record_set (self ,
273- train ,
274- labels = None ,
275- channel = "train" ,
276- encrypt = False ,
277- distribution = "ShardedByS3Key" ):
282+ def record_set (
283+ self ,
284+ train ,
285+ labels = None ,
286+ channel = "train" ,
287+ encrypt = False ,
288+ distribution = "ShardedByS3Key" ,
289+ ):
278290 """Build a :class:`~RecordSet` from a numpy :class:`~ndarray` matrix and label vector.
279291
280292 For the 2D ``ndarray`` ``train``, each row is converted to a
@@ -311,7 +323,9 @@ def record_set(self,
311323 )
312324 parsed_s3_url = urlparse (self .data_location )
313325 bucket , key_prefix = parsed_s3_url .netloc , parsed_s3_url .path
314- key_prefix = key_prefix + "{}-{}/" .format (type (self ).__name__ , sagemaker_timestamp ())
326+ key_prefix = key_prefix + "{}-{}/" .format (
327+ type (self ).__name__ , sagemaker_timestamp ()
328+ )
315329 key_prefix = key_prefix .lstrip ("/" )
316330 logger .debug ("Uploading to bucket %s and key_prefix %s" , bucket , key_prefix )
317331 manifest_s3_file = upload_numpy_to_s3_shards (
@@ -338,7 +352,9 @@ def _get_default_mini_batch_size(self, num_records: int):
338352 )
339353 return 1
340354
341- return min (self .DEFAULT_MINI_BATCH_SIZE , max (1 , int (num_records / self .instance_count )))
355+ return min (
356+ self .DEFAULT_MINI_BATCH_SIZE , max (1 , int (num_records / self .instance_count ))
357+ )
342358
343359
344360class RecordSet (object ):
@@ -447,7 +463,10 @@ def _build_shards(num_shards, array):
447463 shard_size = int (array .shape [0 ] / num_shards )
448464 if shard_size == 0 :
449465 raise ValueError ("Array length is less than num shards" )
450- shards = [array [i * shard_size : i * shard_size + shard_size ] for i in range (num_shards - 1 )]
466+ shards = [
467+ array [i * shard_size : i * shard_size + shard_size ]
468+ for i in range (num_shards - 1 )
469+ ]
451470 shards .append (array [(num_shards - 1 ) * shard_size :])
452471 return shards
453472
@@ -494,7 +513,9 @@ def upload_numpy_to_s3_shards(
494513 manifest_str = json .dumps (
495514 [{"prefix" : "s3://{}/{}" .format (bucket , key_prefix )}] + uploaded_files
496515 )
497- s3 .Object (bucket , manifest_key ).put (Body = manifest_str .encode ("utf-8" ), ** extra_put_kwargs )
516+ s3 .Object (bucket , manifest_key ).put (
517+ Body = manifest_str .encode ("utf-8" ), ** extra_put_kwargs
518+ )
498519 return "s3://{}/{}" .format (bucket , manifest_key )
499520 except Exception as ex : # pylint: disable=broad-except
500521 try :
0 commit comments