@@ -199,6 +199,253 @@ def test_conditional_pytorch_training_model_registration(
             pass
 
 
+def test_conditional_pytorch_training_model_registration_without_instance_types(
+    sagemaker_session,
+    role,
+    cpu_instance_type,
+    pipeline_name,
+    region_name,
+):
+    base_dir = os.path.join(DATA_DIR, "pytorch_mnist")
+    entry_point = os.path.join(base_dir, "mnist.py")
+    input_path = sagemaker_session.upload_data(
+        path=os.path.join(base_dir, "training"),
+        key_prefix="integ-test-data/pytorch_mnist/training",
+    )
+    inputs = TrainingInput(s3_data=input_path)
+
+    instance_count = ParameterInteger(name="InstanceCount", default_value=1)
+    instance_type = "ml.m5.xlarge"
+    good_enough_input = ParameterInteger(name="GoodEnoughInput", default_value=1)
+    in_condition_input = ParameterString(name="Foo", default_value="Foo")
+
+    task = "IMAGE_CLASSIFICATION"
+    sample_payload_url = "s3://test-bucket/model"
+    framework = "TENSORFLOW"
+    framework_version = "2.9"
+    nearest_model_name = "resnet50"
+    data_input_configuration = '{"input_1":[1,224,224,3]}'
+
+    # If image_uri is not provided, instance_type should not be a pipeline variable,
+    # since instance_type is used to retrieve image_uri at compile time (PySDK)
+    pytorch_estimator = PyTorch(
+        entry_point=entry_point,
+        role=role,
+        framework_version="1.5.0",
+        py_version="py3",
+        instance_count=instance_count,
+        instance_type=instance_type,
+        sagemaker_session=sagemaker_session,
+    )
+    step_train = TrainingStep(
+        name="pytorch-train",
+        estimator=pytorch_estimator,
+        inputs=inputs,
+    )
+
+    step_register = RegisterModel(
+        name="pytorch-register-model",
+        estimator=pytorch_estimator,
+        model_data=step_train.properties.ModelArtifacts.S3ModelArtifacts,
+        content_types=["*"],
+        response_types=["*"],
+        description="test-description",
+        sample_payload_url=sample_payload_url,
+        task=task,
+        framework=framework,
+        framework_version=framework_version,
+        nearest_model_name=nearest_model_name,
+        data_input_configuration=data_input_configuration,
+    )
+
+    model = Model(
+        image_uri=pytorch_estimator.training_image_uri(),
+        model_data=step_train.properties.ModelArtifacts.S3ModelArtifacts,
+        sagemaker_session=sagemaker_session,
+        role=role,
+    )
+    model_inputs = CreateModelInput(
+        instance_type="ml.m5.large",
+        accelerator_type="ml.eia1.medium",
+    )
+    step_model = CreateModelStep(
+        name="pytorch-model",
+        model=model,
+        inputs=model_inputs,
+    )
+
+    step_cond = ConditionStep(
+        name="cond-good-enough",
+        conditions=[
+            ConditionGreaterThanOrEqualTo(left=good_enough_input, right=1),
+            ConditionIn(value=in_condition_input, in_values=["foo", "bar"]),
+        ],
+        if_steps=[step_register],
+        else_steps=[step_model],
+        depends_on=[step_train],
+    )
+
+    pipeline = Pipeline(
+        name=pipeline_name,
+        parameters=[
+            in_condition_input,
+            good_enough_input,
+            instance_count,
+        ],
+        steps=[step_train, step_cond],
+        sagemaker_session=sagemaker_session,
+    )
+
+    try:
+        response = pipeline.create(role)
+        create_arn = response["PipelineArn"]
+        assert re.match(
+            rf"arn:aws:sagemaker:{region_name}:\d{{12}}:pipeline/{pipeline_name}",
+            create_arn,
+        )
+
+        execution = pipeline.start(parameters={})
+        assert re.match(
+            rf"arn:aws:sagemaker:{region_name}:\d{{12}}:pipeline/{pipeline_name}/execution/",
+            execution.arn,
+        )
+
+        execution = pipeline.start(parameters={"GoodEnoughInput": 0})
+        assert re.match(
+            rf"arn:aws:sagemaker:{region_name}:\d{{12}}:pipeline/{pipeline_name}/execution/",
+            execution.arn,
+        )
+    finally:
+        try:
+            pipeline.delete()
+        except Exception:
+            pass
+
+
+def test_conditional_pytorch_training_model_registration_with_one_instance_types(
+    sagemaker_session,
+    role,
+    cpu_instance_type,
+    pipeline_name,
+    region_name,
+):
+    base_dir = os.path.join(DATA_DIR, "pytorch_mnist")
+    entry_point = os.path.join(base_dir, "mnist.py")
+    input_path = sagemaker_session.upload_data(
+        path=os.path.join(base_dir, "training"),
+        key_prefix="integ-test-data/pytorch_mnist/training",
+    )
+    inputs = TrainingInput(s3_data=input_path)
+
+    instance_count = ParameterInteger(name="InstanceCount", default_value=1)
+    instance_type = "ml.m5.xlarge"
+    good_enough_input = ParameterInteger(name="GoodEnoughInput", default_value=1)
+    in_condition_input = ParameterString(name="Foo", default_value="Foo")
+
+    task = "IMAGE_CLASSIFICATION"
+    sample_payload_url = "s3://test-bucket/model"
+    framework = "TENSORFLOW"
+    framework_version = "2.9"
+    nearest_model_name = "resnet50"
+    data_input_configuration = '{"input_1":[1,224,224,3]}'
+
+    # If image_uri is not provided, instance_type should not be a pipeline variable,
+    # since instance_type is used to retrieve image_uri at compile time (PySDK)
+    pytorch_estimator = PyTorch(
+        entry_point=entry_point,
+        role=role,
+        framework_version="1.5.0",
+        py_version="py3",
+        instance_count=instance_count,
+        instance_type=instance_type,
+        sagemaker_session=sagemaker_session,
+    )
+    step_train = TrainingStep(
+        name="pytorch-train",
+        estimator=pytorch_estimator,
+        inputs=inputs,
+    )
+
+    step_register = RegisterModel(
+        name="pytorch-register-model",
+        estimator=pytorch_estimator,
+        model_data=step_train.properties.ModelArtifacts.S3ModelArtifacts,
+        content_types=["*"],
+        response_types=["*"],
+        inference_instances=["*"],
+        description="test-description",
+        sample_payload_url=sample_payload_url,
+        task=task,
+        framework=framework,
+        framework_version=framework_version,
+        nearest_model_name=nearest_model_name,
+        data_input_configuration=data_input_configuration,
+    )
+
+    model = Model(
+        image_uri=pytorch_estimator.training_image_uri(),
+        model_data=step_train.properties.ModelArtifacts.S3ModelArtifacts,
+        sagemaker_session=sagemaker_session,
+        role=role,
+    )
+    model_inputs = CreateModelInput(
+        instance_type="ml.m5.large",
+        accelerator_type="ml.eia1.medium",
+    )
+    step_model = CreateModelStep(
+        name="pytorch-model",
+        model=model,
+        inputs=model_inputs,
+    )
+
+    step_cond = ConditionStep(
+        name="cond-good-enough",
+        conditions=[
+            ConditionGreaterThanOrEqualTo(left=good_enough_input, right=1),
+            ConditionIn(value=in_condition_input, in_values=["foo", "bar"]),
+        ],
+        if_steps=[step_register],
+        else_steps=[step_model],
+        depends_on=[step_train],
+    )
+
+    pipeline = Pipeline(
+        name=pipeline_name,
+        parameters=[
+            in_condition_input,
+            good_enough_input,
+            instance_count,
+        ],
+        steps=[step_train, step_cond],
+        sagemaker_session=sagemaker_session,
+    )
+
+    try:
+        response = pipeline.create(role)
+        create_arn = response["PipelineArn"]
+        assert re.match(
+            rf"arn:aws:sagemaker:{region_name}:\d{{12}}:pipeline/{pipeline_name}",
+            create_arn,
+        )
+
+        execution = pipeline.start(parameters={})
+        assert re.match(
+            rf"arn:aws:sagemaker:{region_name}:\d{{12}}:pipeline/{pipeline_name}/execution/",
+            execution.arn,
+        )
+
+        execution = pipeline.start(parameters={"GoodEnoughInput": 0})
+        assert re.match(
+            rf"arn:aws:sagemaker:{region_name}:\d{{12}}:pipeline/{pipeline_name}/execution/",
+            execution.arn,
+        )
+    finally:
+        try:
+            pipeline.delete()
+        except Exception:
+            pass
+
+
 def test_mxnet_model_registration(
     sagemaker_session,
     role,
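
Not part of the commit, but for context on the comment repeated above each PyTorch estimator in this diff: because no image_uri is passed, instance_type is kept as a concrete string so the SDK can resolve the framework image while the pipeline definition is built. A minimal sketch of the contrast, under that assumption and reusing this module's imports and fixtures (the names instance_type_param, estimator_plain, and estimator_parameterized are illustrative only, not part of the test code):

# Concrete instance_type: the framework image can be looked up when the
# pipeline definition is compiled, so no image_uri is required.
estimator_plain = PyTorch(
    entry_point=entry_point,
    role=role,
    framework_version="1.5.0",
    py_version="py3",
    instance_count=1,
    instance_type="ml.m5.xlarge",
    sagemaker_session=sagemaker_session,
)

# instance_type as a pipeline parameter: its value is only known at execution
# time, so (per the comment in the tests above) the image cannot be derived
# from it; an explicit image_uri is supplied instead.
instance_type_param = ParameterString(
    name="TrainingInstanceType", default_value="ml.m5.xlarge"
)
estimator_parameterized = PyTorch(
    entry_point=entry_point,
    role=role,
    image_uri=estimator_plain.training_image_uri(),
    instance_count=1,
    instance_type=instance_type_param,
    sagemaker_session=sagemaker_session,
)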