6767 ConditionLessThanOrEqualTo ,
6868)
6969from sagemaker .workflow .condition_step import ConditionStep
70- from sagemaker .workflow .callback_step import CallbackStep , CallbackOutput , CallbackOutputTypeEnum
71- from sagemaker .workflow .lambda_step import LambdaStep , LambdaOutput , LambdaOutputTypeEnum
70+ from sagemaker .workflow .callback_step import (
71+ CallbackStep ,
72+ CallbackOutput ,
73+ CallbackOutputTypeEnum ,
74+ )
75+ from sagemaker .workflow .lambda_step import (
76+ LambdaStep ,
77+ LambdaOutput ,
78+ LambdaOutputTypeEnum ,
79+ )
7280from sagemaker .workflow .emr_step import EMRStep , EMRStepConfig
7381from sagemaker .wrangler .processing import DataWranglerProcessor
74- from sagemaker .dataset_definition .inputs import DatasetDefinition , AthenaDatasetDefinition
82+ from sagemaker .dataset_definition .inputs import (
83+ DatasetDefinition ,
84+ AthenaDatasetDefinition ,
85+ )
7586from sagemaker .workflow .execution_variables import ExecutionVariables
7687from sagemaker .workflow .functions import Join , JsonGet
7788from sagemaker .wrangler .ingestion import generate_data_ingestion_flow_from_s3_input
92103from sagemaker .workflow .step_collections import RegisterModel
93104from sagemaker .workflow .pipeline import Pipeline
94105from sagemaker .lambda_helper import Lambda
95- from sagemaker .feature_store .feature_group import FeatureGroup , FeatureDefinition , FeatureTypeEnum
106+ from sagemaker .feature_store .feature_group import (
107+ FeatureGroup ,
108+ FeatureDefinition ,
109+ FeatureTypeEnum ,
110+ )
96111from tests .integ import DATA_DIR
97112from tests .integ .kms_utils import get_or_create_kms_key
98113from tests .integ .retry import retries
@@ -262,7 +277,10 @@ def build_jar():
262277 )
263278 else :
264279 subprocess .run (
265- ["javac" , os .path .join (jar_file_path , java_file_path , "HelloJavaSparkApp.java" )]
280+ [
281+ "javac" ,
282+ os .path .join (jar_file_path , java_file_path , "HelloJavaSparkApp.java" ),
283+ ]
266284 )
267285
268286 subprocess .run (
@@ -383,10 +401,20 @@ def test_three_step_definition(
383401 assert set (tuple (param .items ()) for param in definition ["Parameters" ]) == set (
384402 [
385403 tuple (
386- {"Name" : "InstanceType" , "Type" : "String" , "DefaultValue" : "ml.m5.xlarge" }.items ()
404+ {
405+ "Name" : "InstanceType" ,
406+ "Type" : "String" ,
407+ "DefaultValue" : "ml.m5.xlarge" ,
408+ }.items ()
387409 ),
388410 tuple ({"Name" : "InstanceCount" , "Type" : "Integer" , "DefaultValue" : 1 }.items ()),
389- tuple ({"Name" : "OutputPrefix" , "Type" : "String" , "DefaultValue" : "output" }.items ()),
411+ tuple (
412+ {
413+ "Name" : "OutputPrefix" ,
414+ "Type" : "String" ,
415+ "DefaultValue" : "output" ,
416+ }.items ()
417+ ),
390418 ]
391419 )
392420
@@ -740,7 +768,13 @@ def test_one_step_pyspark_processing_pipeline(
740768
741769
742770def test_one_step_sparkjar_processing_pipeline (
743- sagemaker_session , role , cpu_instance_type , pipeline_name , region_name , configuration , build_jar
771+ sagemaker_session ,
772+ role ,
773+ cpu_instance_type ,
774+ pipeline_name ,
775+ region_name ,
776+ configuration ,
777+ build_jar ,
744778):
745779 instance_count = ParameterInteger (name = "InstanceCount" , default_value = 2 )
746780 cache_config = CacheConfig (enable_caching = True , expire_after = "T30m" )
@@ -758,7 +792,9 @@ def test_one_step_sparkjar_processing_pipeline(
758792 body = data .read ()
759793 input_data_uri = f"s3://{ bucket } /spark/input/data.jsonl"
760794 S3Uploader .upload_string_as_file_body (
761- body = body , desired_s3_uri = input_data_uri , sagemaker_session = sagemaker_session
795+ body = body ,
796+ desired_s3_uri = input_data_uri ,
797+ sagemaker_session = sagemaker_session ,
762798 )
763799 output_data_uri = f"s3://{ bucket } /spark/output/sales/{ datetime .now ().isoformat ()} "
764800
@@ -877,7 +913,12 @@ def test_one_step_callback_pipeline(sagemaker_session, role, pipeline_name, regi
877913
878914
879915def test_steps_with_map_params_pipeline (
880- sagemaker_session , role , script_dir , pipeline_name , region_name , athena_dataset_definition
916+ sagemaker_session ,
917+ role ,
918+ script_dir ,
919+ pipeline_name ,
920+ region_name ,
921+ athena_dataset_definition ,
881922):
882923 instance_count = ParameterInteger (name = "InstanceCount" , default_value = 2 )
883924 framework_version = "0.20.0"
@@ -1184,7 +1225,8 @@ def test_two_steps_emr_pipeline(sagemaker_session, role, pipeline_name, region_n
11841225 response = pipeline .create (role )
11851226 create_arn = response ["PipelineArn" ]
11861227 assert re .match (
1187- rf"arn:aws:sagemaker:{ region_name } :\d{{12}}:pipeline/{ pipeline_name } " , create_arn
1228+ rf"arn:aws:sagemaker:{ region_name } :\d{{12}}:pipeline/{ pipeline_name } " ,
1229+ create_arn ,
11881230 )
11891231 finally :
11901232 try :
@@ -1267,7 +1309,12 @@ def test_conditional_pytorch_training_model_registration(
12671309
12681310 pipeline = Pipeline (
12691311 name = pipeline_name ,
1270- parameters = [in_condition_input , good_enough_input , instance_count , instance_type ],
1312+ parameters = [
1313+ in_condition_input ,
1314+ good_enough_input ,
1315+ instance_count ,
1316+ instance_type ,
1317+ ],
12711318 steps = [step_cond ],
12721319 sagemaker_session = sagemaker_session ,
12731320 )
@@ -1276,7 +1323,8 @@ def test_conditional_pytorch_training_model_registration(
12761323 response = pipeline .create (role )
12771324 create_arn = response ["PipelineArn" ]
12781325 assert re .match (
1279- rf"arn:aws:sagemaker:{ region_name } :\d{{12}}:pipeline/{ pipeline_name } " , create_arn
1326+ rf"arn:aws:sagemaker:{ region_name } :\d{{12}}:pipeline/{ pipeline_name } " ,
1327+ create_arn ,
12801328 )
12811329
12821330 execution = pipeline .start (parameters = {})
@@ -1395,7 +1443,8 @@ def test_tuning_single_algo(
13951443 response = pipeline .create (role )
13961444 create_arn = response ["PipelineArn" ]
13971445 assert re .match (
1398- rf"arn:aws:sagemaker:{ region_name } :\d{{12}}:pipeline/{ pipeline_name } " , create_arn
1446+ rf"arn:aws:sagemaker:{ region_name } :\d{{12}}:pipeline/{ pipeline_name } " ,
1447+ create_arn ,
13991448 )
14001449
14011450 execution = pipeline .start (parameters = {})
@@ -1522,7 +1571,7 @@ def test_tuning_multi_algos(
15221571 response = pipeline .create (role )
15231572 create_arn = response ["PipelineArn" ]
15241573 assert re .match (
1525- rf "arn:aws:sagemaker:{ region_name } :\d{{12}}:pipeline/{ pipeline_name } " , create_arn
1574+ fr "arn:aws:sagemaker:{ region_name } :\d{{12}}:pipeline/{ pipeline_name } " , create_arn
15261575 )
15271576
15281577 execution = pipeline .start (parameters = {})
@@ -1583,7 +1632,8 @@ def test_mxnet_model_registration(
15831632 response = pipeline .create (role )
15841633 create_arn = response ["PipelineArn" ]
15851634 assert re .match (
1586- rf"arn:aws:sagemaker:{ region_name } :\d{{12}}:pipeline/{ pipeline_name } " , create_arn
1635+ rf"arn:aws:sagemaker:{ region_name } :\d{{12}}:pipeline/{ pipeline_name } " ,
1636+ create_arn ,
15871637 )
15881638
15891639 execution = pipeline .start (parameters = {})
@@ -1655,10 +1705,14 @@ def test_sklearn_xgboost_sip_model_registration(
16551705 destination = train_data_path_param ,
16561706 ),
16571707 ProcessingOutput (
1658- output_name = "val_data" , source = "/opt/ml/processing/val" , destination = val_data_path_param
1708+ output_name = "val_data" ,
1709+ source = "/opt/ml/processing/val" ,
1710+ destination = val_data_path_param ,
16591711 ),
16601712 ProcessingOutput (
1661- output_name = "model" , source = "/opt/ml/processing/model" , destination = model_path_param
1713+ output_name = "model" ,
1714+ source = "/opt/ml/processing/model" ,
1715+ destination = model_path_param ,
16621716 ),
16631717 ]
16641718
@@ -1775,7 +1829,8 @@ def test_sklearn_xgboost_sip_model_registration(
17751829 response = pipeline .upsert (role_arn = role )
17761830 create_arn = response ["PipelineArn" ]
17771831 assert re .match (
1778- rf"arn:aws:sagemaker:{ region_name } :\d{{12}}:pipeline/{ pipeline_name } " , create_arn
1832+ rf"arn:aws:sagemaker:{ region_name } :\d{{12}}:pipeline/{ pipeline_name } " ,
1833+ create_arn ,
17791834 )
17801835
17811836 execution = pipeline .start (parameters = {})
@@ -1831,7 +1886,9 @@ def test_model_registration_with_drift_check_baselines(
18311886 utils .unique_name_from_base ("metrics" ),
18321887 )
18331888 metrics_uri = S3Uploader .upload_string_as_file_body (
1834- body = metrics_data , desired_s3_uri = metrics_base_uri , sagemaker_session = sagemaker_session
1889+ body = metrics_data ,
1890+ desired_s3_uri = metrics_base_uri ,
1891+ sagemaker_session = sagemaker_session ,
18351892 )
18361893 metrics_uri_param = ParameterString (name = "metrics_uri" , default_value = metrics_uri )
18371894
@@ -2070,7 +2127,8 @@ def test_model_registration_with_model_repack(
20702127 response = pipeline .create (role )
20712128 create_arn = response ["PipelineArn" ]
20722129 assert re .match (
2073- rf"arn:aws:sagemaker:{ region_name } :\d{{12}}:pipeline/{ pipeline_name } " , create_arn
2130+ rf"arn:aws:sagemaker:{ region_name } :\d{{12}}:pipeline/{ pipeline_name } " ,
2131+ create_arn ,
20742132 )
20752133
20762134 execution = pipeline .start (parameters = {})
@@ -2417,13 +2475,17 @@ def test_one_step_ingestion_pipeline(
24172475 input_name = "features.csv"
24182476 input_file_path = os .path .join (DATA_DIR , "workflow" , "features.csv" )
24192477 input_data_uri = os .path .join (
2420- "s3://" , sagemaker_session .default_bucket (), "py-sdk-ingestion-test-input/features.csv"
2478+ "s3://" ,
2479+ sagemaker_session .default_bucket (),
2480+ "py-sdk-ingestion-test-input/features.csv" ,
24212481 )
24222482
24232483 with open (input_file_path , "r" ) as data :
24242484 body = data .read ()
24252485 S3Uploader .upload_string_as_file_body (
2426- body = body , desired_s3_uri = input_data_uri , sagemaker_session = sagemaker_session
2486+ body = body ,
2487+ desired_s3_uri = input_data_uri ,
2488+ sagemaker_session = sagemaker_session ,
24272489 )
24282490
24292491 inputs = [
@@ -2735,7 +2797,9 @@ def test_end_to_end_pipeline_successful_execution(
27352797 sagemaker_session = sagemaker_session ,
27362798 )
27372799 step_transform = TransformStep (
2738- name = "AbaloneTransform" , transformer = transformer , inputs = TransformInput (data = batch_data )
2800+ name = "AbaloneTransform" ,
2801+ transformer = transformer ,
2802+ inputs = TransformInput (data = batch_data ),
27392803 )
27402804
27412805 # define register model step
0 commit comments