@@ -199,253 +199,6 @@ def test_conditional_pytorch_training_model_registration(
199199 pass
200200
201201
202- def test_conditional_pytorch_training_model_registration_without_instance_types (
203- sagemaker_session ,
204- role ,
205- cpu_instance_type ,
206- pipeline_name ,
207- region_name ,
208- ):
209- base_dir = os .path .join (DATA_DIR , "pytorch_mnist" )
210- entry_point = os .path .join (base_dir , "mnist.py" )
211- input_path = sagemaker_session .upload_data (
212- path = os .path .join (base_dir , "training" ),
213- key_prefix = "integ-test-data/pytorch_mnist/training" ,
214- )
215- inputs = TrainingInput (s3_data = input_path )
216-
217- instance_count = ParameterInteger (name = "InstanceCount" , default_value = 1 )
218- instance_type = "ml.m5.xlarge"
219- good_enough_input = ParameterInteger (name = "GoodEnoughInput" , default_value = 1 )
220- in_condition_input = ParameterString (name = "Foo" , default_value = "Foo" )
221-
222- task = "IMAGE_CLASSIFICATION"
223- sample_payload_url = "s3://test-bucket/model"
224- framework = "TENSORFLOW"
225- framework_version = "2.9"
226- nearest_model_name = "resnet50"
227- data_input_configuration = '{"input_1":[1,224,224,3]}'
228-
229- # If image_uri is not provided, the instance_type should not be a pipeline variable
230- # since instance_type is used to retrieve image_uri in compile time (PySDK)
231- pytorch_estimator = PyTorch (
232- entry_point = entry_point ,
233- role = role ,
234- framework_version = "1.5.0" ,
235- py_version = "py3" ,
236- instance_count = instance_count ,
237- instance_type = instance_type ,
238- sagemaker_session = sagemaker_session ,
239- )
240- step_train = TrainingStep (
241- name = "pytorch-train" ,
242- estimator = pytorch_estimator ,
243- inputs = inputs ,
244- )
245-
246- step_register = RegisterModel (
247- name = "pytorch-register-model" ,
248- estimator = pytorch_estimator ,
249- model_data = step_train .properties .ModelArtifacts .S3ModelArtifacts ,
250- content_types = ["*" ],
251- response_types = ["*" ],
252- description = "test-description" ,
253- sample_payload_url = sample_payload_url ,
254- task = task ,
255- framework = framework ,
256- framework_version = framework_version ,
257- nearest_model_name = nearest_model_name ,
258- data_input_configuration = data_input_configuration ,
259- )
260-
261- model = Model (
262- image_uri = pytorch_estimator .training_image_uri (),
263- model_data = step_train .properties .ModelArtifacts .S3ModelArtifacts ,
264- sagemaker_session = sagemaker_session ,
265- role = role ,
266- )
267- model_inputs = CreateModelInput (
268- instance_type = "ml.m5.large" ,
269- accelerator_type = "ml.eia1.medium" ,
270- )
271- step_model = CreateModelStep (
272- name = "pytorch-model" ,
273- model = model ,
274- inputs = model_inputs ,
275- )
276-
277- step_cond = ConditionStep (
278- name = "cond-good-enough" ,
279- conditions = [
280- ConditionGreaterThanOrEqualTo (left = good_enough_input , right = 1 ),
281- ConditionIn (value = in_condition_input , in_values = ["foo" , "bar" ]),
282- ],
283- if_steps = [step_register ],
284- else_steps = [step_model ],
285- depends_on = [step_train ],
286- )
287-
288- pipeline = Pipeline (
289- name = pipeline_name ,
290- parameters = [
291- in_condition_input ,
292- good_enough_input ,
293- instance_count ,
294- ],
295- steps = [step_train , step_cond ],
296- sagemaker_session = sagemaker_session ,
297- )
298-
299- try :
300- response = pipeline .create (role )
301- create_arn = response ["PipelineArn" ]
302- assert re .match (
303- rf"arn:aws:sagemaker:{ region_name } :\d{{12}}:pipeline/{ pipeline_name } " ,
304- create_arn ,
305- )
306-
307- execution = pipeline .start (parameters = {})
308- assert re .match (
309- rf"arn:aws:sagemaker:{ region_name } :\d{{12}}:pipeline/{ pipeline_name } /execution/" ,
310- execution .arn ,
311- )
312-
313- execution = pipeline .start (parameters = {"GoodEnoughInput" : 0 })
314- assert re .match (
315- rf"arn:aws:sagemaker:{ region_name } :\d{{12}}:pipeline/{ pipeline_name } /execution/" ,
316- execution .arn ,
317- )
318- finally :
319- try :
320- pipeline .delete ()
321- except Exception :
322- pass
323-
324-
325- def test_conditional_pytorch_training_model_registration_with_one_instance_types (
326- sagemaker_session ,
327- role ,
328- cpu_instance_type ,
329- pipeline_name ,
330- region_name ,
331- ):
332- base_dir = os .path .join (DATA_DIR , "pytorch_mnist" )
333- entry_point = os .path .join (base_dir , "mnist.py" )
334- input_path = sagemaker_session .upload_data (
335- path = os .path .join (base_dir , "training" ),
336- key_prefix = "integ-test-data/pytorch_mnist/training" ,
337- )
338- inputs = TrainingInput (s3_data = input_path )
339-
340- instance_count = ParameterInteger (name = "InstanceCount" , default_value = 1 )
341- instance_type = "ml.m5.xlarge"
342- good_enough_input = ParameterInteger (name = "GoodEnoughInput" , default_value = 1 )
343- in_condition_input = ParameterString (name = "Foo" , default_value = "Foo" )
344-
345- task = "IMAGE_CLASSIFICATION"
346- sample_payload_url = "s3://test-bucket/model"
347- framework = "TENSORFLOW"
348- framework_version = "2.9"
349- nearest_model_name = "resnet50"
350- data_input_configuration = '{"input_1":[1,224,224,3]}'
351-
352- # If image_uri is not provided, the instance_type should not be a pipeline variable
353- # since instance_type is used to retrieve image_uri in compile time (PySDK)
354- pytorch_estimator = PyTorch (
355- entry_point = entry_point ,
356- role = role ,
357- framework_version = "1.5.0" ,
358- py_version = "py3" ,
359- instance_count = instance_count ,
360- instance_type = instance_type ,
361- sagemaker_session = sagemaker_session ,
362- )
363- step_train = TrainingStep (
364- name = "pytorch-train" ,
365- estimator = pytorch_estimator ,
366- inputs = inputs ,
367- )
368-
369- step_register = RegisterModel (
370- name = "pytorch-register-model" ,
371- estimator = pytorch_estimator ,
372- model_data = step_train .properties .ModelArtifacts .S3ModelArtifacts ,
373- content_types = ["*" ],
374- response_types = ["*" ],
375- inference_instances = ["*" ],
376- description = "test-description" ,
377- sample_payload_url = sample_payload_url ,
378- task = task ,
379- framework = framework ,
380- framework_version = framework_version ,
381- nearest_model_name = nearest_model_name ,
382- data_input_configuration = data_input_configuration ,
383- )
384-
385- model = Model (
386- image_uri = pytorch_estimator .training_image_uri (),
387- model_data = step_train .properties .ModelArtifacts .S3ModelArtifacts ,
388- sagemaker_session = sagemaker_session ,
389- role = role ,
390- )
391- model_inputs = CreateModelInput (
392- instance_type = "ml.m5.large" ,
393- accelerator_type = "ml.eia1.medium" ,
394- )
395- step_model = CreateModelStep (
396- name = "pytorch-model" ,
397- model = model ,
398- inputs = model_inputs ,
399- )
400-
401- step_cond = ConditionStep (
402- name = "cond-good-enough" ,
403- conditions = [
404- ConditionGreaterThanOrEqualTo (left = good_enough_input , right = 1 ),
405- ConditionIn (value = in_condition_input , in_values = ["foo" , "bar" ]),
406- ],
407- if_steps = [step_register ],
408- else_steps = [step_model ],
409- depends_on = [step_train ],
410- )
411-
412- pipeline = Pipeline (
413- name = pipeline_name ,
414- parameters = [
415- in_condition_input ,
416- good_enough_input ,
417- instance_count ,
418- ],
419- steps = [step_train , step_cond ],
420- sagemaker_session = sagemaker_session ,
421- )
422-
423- try :
424- response = pipeline .create (role )
425- create_arn = response ["PipelineArn" ]
426- assert re .match (
427- rf"arn:aws:sagemaker:{ region_name } :\d{{12}}:pipeline/{ pipeline_name } " ,
428- create_arn ,
429- )
430-
431- execution = pipeline .start (parameters = {})
432- assert re .match (
433- rf"arn:aws:sagemaker:{ region_name } :\d{{12}}:pipeline/{ pipeline_name } /execution/" ,
434- execution .arn ,
435- )
436-
437- execution = pipeline .start (parameters = {"GoodEnoughInput" : 0 })
438- assert re .match (
439- rf"arn:aws:sagemaker:{ region_name } :\d{{12}}:pipeline/{ pipeline_name } /execution/" ,
440- execution .arn ,
441- )
442- finally :
443- try :
444- pipeline .delete ()
445- except Exception :
446- pass
447-
448-
449202def test_mxnet_model_registration (
450203 sagemaker_session ,
451204 role ,
0 commit comments