# Directory holding the test resources, located two levels above this module's directory.
_THIS_DIR = os.path.dirname(__file__)
RESOURCE_PATH = os.path.join(_THIS_DIR, os.pardir, os.pardir, "resources")
2222
2323
24- def test_multi_node (sagemaker_session , instance_type , image_uri , tmpdir , framework_version , capsys ):
24+ def test_keras_example (
25+ sagemaker_session , instance_type , image_uri , tmpdir , framework_version , capsys
26+ ):
2527 estimator = TensorFlow (
2628 entry_point = os .path .join (RESOURCE_PATH , "multi_worker_mirrored" , "train_dummy.py" ),
2729 role = "SageMakerRole" ,
@@ -40,3 +42,57 @@ def test_multi_node(sagemaker_session, instance_type, image_uri, tmpdir, framewo
4042 logs = captured .out + captured .err
4143 assert "Running distributed training job with multi_worker_mirrored_strategy setup" in logs
4244 assert "TF_CONFIG=" in logs
45+
46+
def test_tf_model_garden(
    sagemaker_session, instance_type, image_uri, tmpdir, framework_version, capsys
):
    """Run the TF Model Garden vision trainer on two instances and check the logs.

    The estimator clones ``tensorflow/models`` at ``v2.9.2``, uses
    ``official/vision/train.py`` as the entry point and sets the
    ``sagemaker_multi_worker_mirrored_strategy_enabled`` hyperparameter.
    After ``fit`` returns, the captured stdout/stderr must contain the
    multi-worker-mirrored setup message and a ``TF_CONFIG=`` entry.

    All parameters are pytest fixtures. Only ``instance_type``, ``image_uri``
    and ``capsys`` are referenced in the body; the remaining ones are kept so
    the test signature stays unchanged.
    """
    epochs = 10
    batch_size = 512
    # NOTE(review): 1024 is presumably the number of examples consumed per
    # epoch -- confirm against the dataset used below.
    train_steps = int(1024 * epochs / batch_size)
    steps_per_loop = train_steps // 10

    # Comma-separated ``key=value`` list passed verbatim as ``params_override``.
    # It is joined from a list so no entry can pick up stray whitespace around
    # the separator, and only entries that interpolate a value are f-strings.
    overrides = ",".join(
        [
            "runtime.enable_xla=False",
            "runtime.num_gpus=1",
            "runtime.distribution_strategy=multi_worker_mirrored",
            "runtime.mixed_precision_dtype=float16",
            f"task.train_data.global_batch_size={batch_size}",
            "task.train_data.input_path=/opt/ml/input/data/training/validation*",
            "task.train_data.cache=True",
            f"trainer.train_steps={train_steps}",
            f"trainer.steps_per_loop={steps_per_loop}",
            f"trainer.summary_interval={steps_per_loop}",
            # Checkpoint interval equals the total number of training steps.
            f"trainer.checkpoint_interval={train_steps}",
            "task.model.backbone.type=resnet",
            "task.model.backbone.resnet.model_id=50",
        ]
    )

    estimator = TensorFlow(
        git_config={
            "repo": "https://github.com/tensorflow/models.git",
            "branch": "v2.9.2",
        },
        source_dir=".",
        entry_point="official/vision/train.py",
        # The estimator-level model_dir is disabled; the training script gets
        # its own "model_dir" through the hyperparameters below.
        model_dir=False,
        instance_type=instance_type,
        instance_count=2,
        image_uri=image_uri,
        hyperparameters={
            "sagemaker_multi_worker_mirrored_strategy_enabled": True,
            "experiment": "resnet_imagenet",
            "config_file": "official/vision/configs/experiments/image_classification/imagenet_resnet50_gpu.yaml",
            "mode": "train",
            "model_dir": "/opt/ml/model",
            "params_override": overrides,
        },
        max_run=60 * 60 * 1,  # 1 hour
        role="SageMakerRole",
    )
    estimator.fit(
        inputs="s3://collection-of-ml-datasets/Imagenet/TFRecords/validation",
        job_name=unique_name_from_base("test-tf-mwms"),
    )

    # The job output is read back from the captured stdout/stderr of fit().
    captured = capsys.readouterr()
    logs = captured.out + captured.err
    assert "Running distributed training job with multi_worker_mirrored_strategy setup" in logs
    assert "TF_CONFIG=" in logs
0 commit comments