@@ -180,15 +180,15 @@ def test_create_model(sagemaker_session, mxnet_version):
180180 container_log_level = '"logging.INFO"'
181181 source_dir = "s3://mybucket/source"
182182 mx = MXNet (
183- entry_point = SCRIPT_PATH ,
183+ entry_point = SCRIPT_NAME ,
184+ source_dir = source_dir ,
184185 role = ROLE ,
185186 sagemaker_session = sagemaker_session ,
186187 train_instance_count = INSTANCE_COUNT ,
187188 train_instance_type = INSTANCE_TYPE ,
188189 framework_version = mxnet_version ,
189190 container_log_level = container_log_level ,
190191 base_job_name = "job" ,
191- source_dir = source_dir ,
192192 )
193193
194194 job_name = "new_name"
@@ -198,6 +198,7 @@ def test_create_model(sagemaker_session, mxnet_version):
198198 assert model .sagemaker_session == sagemaker_session
199199 assert model .framework_version == mxnet_version
200200 assert model .py_version == mx .py_version
201+ assert model .entry_point == SCRIPT_NAME
201202 assert model .role == ROLE
202203 assert model .name == job_name
203204 assert model .container_log_level == container_log_level
@@ -206,55 +207,19 @@ def test_create_model(sagemaker_session, mxnet_version):
206207 assert model .vpc_config is None
207208
208209
209- @patch ("sagemaker.utils.create_tar_file" , MagicMock ())
210- def test_create_model_default_entry_with_mms (
211- sagemaker_session , mxnet_version , skip_if_not_mms_version
212- ):
213- mx = MXNet (
214- entry_point = SCRIPT_PATH ,
215- role = ROLE ,
216- sagemaker_session = sagemaker_session ,
217- train_instance_count = INSTANCE_COUNT ,
218- train_instance_type = INSTANCE_TYPE ,
219- framework_version = mxnet_version ,
220- )
221-
222- mx .fit ()
223- model = mx .create_model ()
224-
225- assert model .entry_point == SCRIPT_PATH
226-
227-
228- @patch ("sagemaker.utils.create_tar_file" , MagicMock ())
229- def test_create_model_default_entry_not_mms (sagemaker_session , mxnet_version , skip_if_mms_version ):
230- mx = MXNet (
231- entry_point = SCRIPT_PATH ,
232- role = ROLE ,
233- sagemaker_session = sagemaker_session ,
234- train_instance_count = INSTANCE_COUNT ,
235- train_instance_type = INSTANCE_TYPE ,
236- framework_version = mxnet_version ,
237- )
238-
239- mx .fit ()
240- model = mx .create_model ()
241-
242- assert model .entry_point == SCRIPT_NAME
243-
244-
245210def test_create_model_with_optional_params (sagemaker_session ):
246211 container_log_level = '"logging.INFO"'
247212 source_dir = "s3://mybucket/source"
248213 enable_cloudwatch_metrics = "true"
249214 mx = MXNet (
250- entry_point = SCRIPT_PATH ,
215+ entry_point = SCRIPT_NAME ,
216+ source_dir = source_dir ,
251217 role = ROLE ,
252218 sagemaker_session = sagemaker_session ,
253219 train_instance_count = INSTANCE_COUNT ,
254220 train_instance_type = INSTANCE_TYPE ,
255221 container_log_level = container_log_level ,
256222 base_job_name = "job" ,
257- source_dir = source_dir ,
258223 enable_cloudwatch_metrics = enable_cloudwatch_metrics ,
259224 )
260225
@@ -286,15 +251,15 @@ def test_create_model_with_custom_image(sagemaker_session):
286251 source_dir = "s3://mybucket/source"
287252 custom_image = "mxnet:2.0"
288253 mx = MXNet (
289- entry_point = SCRIPT_PATH ,
254+ entry_point = SCRIPT_NAME ,
255+ source_dir = source_dir ,
290256 role = ROLE ,
291257 sagemaker_session = sagemaker_session ,
292258 train_instance_count = INSTANCE_COUNT ,
293259 train_instance_type = INSTANCE_TYPE ,
294260 image_name = custom_image ,
295261 container_log_level = container_log_level ,
296262 base_job_name = "job" ,
297- source_dir = source_dir ,
298263 )
299264
300265 job_name = "new_name"
@@ -303,7 +268,7 @@ def test_create_model_with_custom_image(sagemaker_session):
303268
304269 assert model .sagemaker_session == sagemaker_session
305270 assert model .image == custom_image
306- assert model .entry_point == SCRIPT_PATH
271+ assert model .entry_point == SCRIPT_NAME
307272 assert model .role == ROLE
308273 assert model .name == job_name
309274 assert model .container_log_level == container_log_level
@@ -823,7 +788,6 @@ def test_create_model_with_custom_hosting_image(sagemaker_session):
823788 image_name = custom_image ,
824789 container_log_level = container_log_level ,
825790 base_job_name = "job" ,
826- source_dir = source_dir ,
827791 )
828792
829793 mx .fit (inputs = "s3://mybucket/train" , job_name = "new_name" )
0 commit comments