@@ -32,11 +32,16 @@ def mxnet_training_job(
3232 sagemaker_session , mxnet_full_version , mxnet_full_py_version , cpu_instance_type
3333):
3434 with timeout (minutes = TRAINING_DEFAULT_TIMEOUT_MINUTES ):
35- script_path = os . path . join ( DATA_DIR , " mxnet_mnist", "mnist.py" )
35+ s3_prefix = "integ-test-data/ mxnet_mnist"
3636 data_path = os .path .join (DATA_DIR , "mxnet_mnist" )
3737
38+ s3_source = sagemaker_session .upload_data (
39+ path = os .path .join (data_path , "sourcedir.tar.gz" ), key_prefix = "{}/src" .format (s3_prefix )
40+ )
41+
3842 mx = MXNet (
39- entry_point = script_path ,
43+ entry_point = "mxnet_mnist/mnist.py" ,
44+ source_dir = s3_source ,
4045 role = "SageMakerRole" ,
4146 framework_version = mxnet_full_version ,
4247 py_version = mxnet_full_py_version ,
@@ -46,10 +51,10 @@ def mxnet_training_job(
4651 )
4752
4853 train_input = mx .sagemaker_session .upload_data (
49- path = os .path .join (data_path , "train" ), key_prefix = "integ-test-data/mxnet_mnist/ train"
54+ path = os .path .join (data_path , "train" ), key_prefix = "{}/ train" . format ( s3_prefix )
5055 )
5156 test_input = mx .sagemaker_session .upload_data (
52- path = os .path .join (data_path , "test" ), key_prefix = "integ-test-data/mxnet_mnist/ test"
57+ path = os .path .join (data_path , "test" ), key_prefix = "{}/ test" . format ( s3_prefix )
5358 )
5459
5560 mx .fit ({"train" : train_input , "test" : test_input })
@@ -62,7 +67,13 @@ def test_attach_deploy(mxnet_training_job, sagemaker_session, cpu_instance_type)
6267
6368 with timeout_and_delete_endpoint_by_name (endpoint_name , sagemaker_session ):
6469 estimator = MXNet .attach (mxnet_training_job , sagemaker_session = sagemaker_session )
65- predictor = estimator .deploy (1 , cpu_instance_type , endpoint_name = endpoint_name )
70+ predictor = estimator .deploy (
71+ 1 ,
72+ cpu_instance_type ,
73+ entry_point = "mnist.py" ,
74+ source_dir = os .path .join (DATA_DIR , "mxnet_mnist" ),
75+ endpoint_name = endpoint_name ,
76+ )
6677 data = numpy .zeros (shape = (1 , 1 , 28 , 28 ))
6778 result = predictor .predict (data )
6879 assert result is not None
0 commit comments