1919import pytest
2020
2121from sagemaker .tensorflow import TensorFlow
22- from sagemaker .utils import unique_name_from_base
22+ from sagemaker .utils import unique_name_from_base , sagemaker_timestamp
2323
2424import tests .integ
2525from tests .integ import timeout
3939TAGS = [{"Key" : "some-key" , "Value" : "some-value" }]
4040
4141
42- def test_mnist (sagemaker_session , instance_type ):
42+ def test_mnist_with_checkpoint_config (sagemaker_session , instance_type ):
43+ checkpoint_s3_uri = "s3://{}/tf-{}" .format (
44+ sagemaker_session .default_bucket (), sagemaker_timestamp ()
45+ )
46+ checkpoint_local_path = "/test/checkpoint/path"
4347 estimator = TensorFlow (
4448 entry_point = SCRIPT ,
4549 role = "SageMakerRole" ,
@@ -50,13 +54,16 @@ def test_mnist(sagemaker_session, instance_type):
5054 framework_version = TensorFlow .LATEST_VERSION ,
5155 py_version = tests .integ .PYTHON_VERSION ,
5256 metric_definitions = [{"Name" : "train:global_steps" , "Regex" : r"global_step\/sec:\s(.*)" }],
57+ checkpoint_s3_uri = checkpoint_s3_uri ,
58+ checkpoint_local_path = checkpoint_local_path
5359 )
5460 inputs = estimator .sagemaker_session .upload_data (
5561 path = os .path .join (MNIST_RESOURCE_PATH , "data" ), key_prefix = "scriptmode/mnist"
5662 )
5763
64+ training_job_name = unique_name_from_base ("test-tf-sm-mnist" )
5865 with tests .integ .timeout .timeout (minutes = tests .integ .TRAINING_DEFAULT_TIMEOUT_MINUTES ):
59- estimator .fit (inputs = inputs , job_name = unique_name_from_base ( "test-tf-sm-mnist" ) )
66+ estimator .fit (inputs = inputs , job_name = training_job_name )
6067 assert_s3_files_exist (
6168 sagemaker_session ,
6269 estimator .model_dir ,
@@ -65,29 +72,6 @@ def test_mnist(sagemaker_session, instance_type):
6572 df = estimator .training_job_analytics .dataframe ()
6673 assert df .size > 0
6774
68-
69- def test_checkpoint_config (sagemaker_session , instance_type ):
70- checkpoint_s3_uri = "s3://{}" .format (sagemaker_session .default_bucket ())
71- checkpoint_local_path = "/test/checkpoint/path"
72- estimator = TensorFlow (
73- entry_point = SCRIPT ,
74- role = "SageMakerRole" ,
75- train_instance_count = 1 ,
76- train_instance_type = instance_type ,
77- sagemaker_session = sagemaker_session ,
78- script_mode = True ,
79- framework_version = TensorFlow .LATEST_VERSION ,
80- py_version = tests .integ .PYTHON_VERSION ,
81- checkpoint_s3_uri = checkpoint_s3_uri ,
82- checkpoint_local_path = checkpoint_local_path ,
83- )
84- inputs = estimator .sagemaker_session .upload_data (
85- path = os .path .join (MNIST_RESOURCE_PATH , "data" ), key_prefix = "script/mnist"
86- )
87- training_job_name = unique_name_from_base ("test-tf-sm-checkpoint" )
88- with tests .integ .timeout .timeout (minutes = tests .integ .TRAINING_DEFAULT_TIMEOUT_MINUTES ):
89- estimator .fit (inputs = inputs , job_name = training_job_name )
90-
9175 expected_training_checkpoint_config = {
9276 "S3Uri" : checkpoint_s3_uri ,
9377 "LocalPath" : checkpoint_local_path ,
0 commit comments