Skip to content
21 changes: 21 additions & 0 deletions test/integration/sagemaker/test_mnist.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,27 @@ def test_tuning(sagemaker_session, ecr_image, instance_type, framework_version):
tuner.wait()


Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Needs this test fixture:

@pytest.mark.skip_py2_containers

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this required in the tf2 branch as well?

Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes. There are py2 containers for tf2, so this will be required for tf2 as well.

@pytest.mark.skip_py2_containers
def test_tf1x_smdebug(sagemaker_session, ecr_image, instance_type, framework_version):
    """Run the smdebug-instrumented MNIST script as a training job.

    Uploads the local MNIST data, fits a script-mode ``TensorFlow`` estimator
    on a single instance with ``smdebug_path`` passed as a hyperparameter, and
    then calls ``_assert_s3_file_exists`` on the estimator's ``model_data``.
    """
    mnist_dir = os.path.join(os.path.dirname(__file__), '..', '..', 'resources', 'mnist')

    estimator_kwargs = dict(
        entry_point=os.path.join(mnist_dir, 'tf1x_mnist_smdebug.py'),
        role='SageMakerRole',
        train_instance_type=instance_type,
        train_instance_count=1,
        sagemaker_session=sagemaker_session,
        image_name=ecr_image,
        framework_version=framework_version,
        script_mode=True,
        # The training script reads this value to decide where the hook writes tensors.
        hyperparameters={'smdebug_path': '/opt/ml/output/tensors'},
    )
    estimator = TensorFlow(**estimator_kwargs)

    data_location = estimator.sagemaker_session.upload_data(
        path=os.path.join(mnist_dir, 'data'),
        key_prefix='scriptmode/mnist_smdebug',
    )
    job_name = unique_name_from_base('test-sagemaker-mnist-smdebug')
    estimator.fit(data_location, job_name=job_name)

    _assert_s3_file_exists(sagemaker_session.boto_region_name, estimator.model_data)


def _assert_checkpoint_exists(region, model_dir, checkpoint_number):
_assert_s3_file_exists(region, os.path.join(model_dir, 'graph.pbtxt'))
_assert_s3_file_exists(region,
Expand Down
103 changes: 103 additions & 0 deletions test/resources/mnist/tf1x_mnist_smdebug.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
import argparse
import json
import os
import sys

import numpy as np
import tensorflow.compat.v1 as tf

import smdebug.tensorflow as smd
from smdebug.core.collection import CollectionKeys
from smdebug.core.reduction_config import ALLOWED_NORMS, ALLOWED_REDUCTIONS
from smdebug.tensorflow import ReductionConfig, SaveConfig
from smdebug.trials import create_trial


def _parse_args():
    """Parse the command-line arguments of the training script.

    Defaults for ``--model-dir``, ``--train``, ``--hosts`` and
    ``--current-host`` are read from the ``SM_MODEL_DIR``,
    ``SM_CHANNEL_TRAINING``, ``SM_HOSTS`` and ``SM_CURRENT_HOST`` environment
    variables, so all four must be set when this function is called.

    Returns:
        The ``(namespace, unknown_args)`` pair produced by
        ``ArgumentParser.parse_known_args``; unrecognised arguments are
        returned in the second element instead of causing an error.

    Raises:
        KeyError: if one of the required environment variables is missing.
    """
    parser = argparse.ArgumentParser()

    # hyperparameters sent by the client are passed as command-line arguments to the script.
    parser.add_argument('--epochs', type=int, default=1)
    # Data, model, and output directories
    parser.add_argument('--model-dir', type=str, default=os.environ['SM_MODEL_DIR'])
    parser.add_argument(
        "--smdebug_path",
        type=str,
        default=None,
        help="S3 URI of the bucket where tensor data will be stored.",
    )
    parser.add_argument('--train', type=str, default=os.environ['SM_CHANNEL_TRAINING'])
    # argparse applies `type` to the raw argument string. `list` would split the
    # value into single characters, so decode it as JSON, matching the default.
    parser.add_argument('--hosts', type=json.loads, default=json.loads(os.environ['SM_HOSTS']))
    parser.add_argument('--current-host', type=str, default=os.environ['SM_CURRENT_HOST'])

    return parser.parse_known_args()


def _load_training_data(base_dir):
    """Load the training arrays stored under ``<base_dir>/train``.

    Args:
        base_dir: Directory containing a ``train`` sub-directory with
            ``x_train.npy`` and ``y_train.npy``.

    Returns:
        A ``(x_train, y_train)`` tuple holding the results of ``np.load`` on
        the two files.
    """
    x_train = np.load(os.path.join(base_dir, 'train', 'x_train.npy'))
    y_train = np.load(os.path.join(base_dir, 'train', 'y_train.npy'))
    return x_train, y_train


def _load_testing_data(base_dir):
    """Load the evaluation arrays stored under ``<base_dir>/test``.

    Args:
        base_dir: Directory containing a ``test`` sub-directory with
            ``x_test.npy`` and ``y_test.npy``.

    Returns:
        A ``(x_test, y_test)`` tuple holding the results of ``np.load`` on
        the two files.
    """
    x_test = np.load(os.path.join(base_dir, 'test', 'x_test.npy'))
    y_test = np.load(os.path.join(base_dir, 'test', 'y_test.npy'))
    return x_test, y_test


def create_smdebug_hook(out_dir):
    """Build the smdebug ``KerasHook`` used as a Keras callback by this script.

    Args:
        out_dir: Output location handed to ``smd.KerasHook`` as its first
            argument (the script passes ``--smdebug_path`` here).

    Returns:
        A ``smd.KerasHook`` configured with ``SaveConfig(save_interval=3)``,
        the collections listed below, and a ``ReductionConfig`` built from
        ``ALLOWED_NORMS`` and ``ALLOWED_REDUCTIONS``.
    """
    # Each collection is listed once; the original list repeated LOSSES.
    include_collections = [
        CollectionKeys.WEIGHTS,
        CollectionKeys.BIASES,
        CollectionKeys.GRADIENTS,
        CollectionKeys.LOSSES,
        CollectionKeys.OUTPUTS,
        CollectionKeys.METRICS,
        CollectionKeys.OPTIMIZER_VARIABLES,
    ]
    save_config = SaveConfig(save_interval=3)
    hook = smd.KerasHook(
        out_dir,
        save_config=save_config,
        include_collections=include_collections,
        reduction_config=ReductionConfig(norms=ALLOWED_NORMS, reductions=ALLOWED_REDUCTIONS),
    )
    return hook


# Script body: parse arguments, train and evaluate a small Keras MNIST model
# with the smdebug hook attached, then check that the hook registered tensors.
args, unknown = _parse_args()

# The same hook instance is used as a callback for both fit() and evaluate().
hook = create_smdebug_hook(args.smdebug_path)
hooks = [hook]

mnist_layers = [
    tf.keras.layers.Flatten(input_shape=(28, 28)),
    tf.keras.layers.Dense(512, activation=tf.nn.relu),
    tf.keras.layers.Dropout(0.2),
    tf.keras.layers.Dense(10, activation=tf.nn.softmax),
]
model = tf.keras.models.Sequential(mnist_layers)
model.compile(
    optimizer='adam',
    loss='sparse_categorical_crossentropy',
    metrics=['accuracy'],
)

# Both the train and test arrays are read from the directory given by --train.
x_train, y_train = _load_training_data(args.train)
x_test, y_test = _load_testing_data(args.train)

model.fit(x_train, y_train, epochs=args.epochs, callbacks=hooks)
model.evaluate(x_test, y_test, callbacks=hooks)

# Only the first host in the hosts list writes the model file.
if args.current_host == args.hosts[0]:
    saved_model_path = os.path.join('/opt/ml/model', 'my_model.h5')
    model.save(saved_model_path)

# Open a trial on the hook's output location and sanity-check what was recorded.
print("Created the trial with out_dir {0}".format(args.smdebug_path))
trial = create_trial(args.smdebug_path)
assert trial

print(f"trial.tensor_names() = {trial.tensor_names()}")

weights_tensors = hook.collection_manager.get("weights").tensor_names
assert len(weights_tensors) > 0

losses_tensors = hook.collection_manager.get("losses").tensor_names
assert len(losses_tensors) > 0