Description
Hey there,
Another small issue with the LocalMode stuff :)
If I train a TensorFlow estimator with a 'local' instance type, the resulting `model_data` points to a local path. This works fine as long as you're using a `LocalSession` object as your `sagemaker_session`, but a normal SageMaker `Session` object will complain that it's not an S3/HTTP URI. This comes up when I do the following:
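```python
import sagemaker.tensorflow.model

# model_data, role, entry_point, source_dir, model_name and endpoint_name
# are restored from the separately serialized training data.
# No sagemaker_session is passed, so a regular Session is used and it
# rejects the local model_data path.
model_obj = sagemaker.tensorflow.model.TensorFlowModel(
    model_data,
    role=role,
    entry_point=entry_point,
    source_dir=source_dir,
    name=model_name
)
model_obj.deploy(
    initial_instance_count=1,
    instance_type='local',
    endpoint_name=endpoint_name
)
```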
The reason I'm doing this is that, as far as I can tell, the TensorFlow estimator object can't be serialized, so I extract the data I need from it and serialize that as a separate object. Unfortunately, this means the SageMaker Session (which is not serializable either) is lost. I can work around it by adding this between those two calls:
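```python
import sagemaker.local.local_session

# Local instance types need a LocalSession, because a regular Session
# rejects a local model_data path.
if instance_type in ('local', 'local_gpu'):
    model_obj.sagemaker_session = sagemaker.local.local_session.LocalSession()
```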
This is a bit hacky, though, and I'd prefer it if the `LocalSession` handling were abstracted into the SageMaker API itself.
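A minimal sketch of the workaround described above, assuming only the fields passed to `TensorFlowModel` need to survive serialization. `ModelSpec`, `dump_spec`, `load_spec`, `is_local` and `pick_session` are hypothetical helpers, not part of the SageMaker SDK. The session classes are injected as factories, so the sketch runs without the SDK installed.

```python
import json
from dataclasses import asdict, dataclass
from typing import Callable, TypeVar

# Instance types treated as local mode in the workaround above.
LOCAL_INSTANCE_TYPES = ("local", "local_gpu")

S = TypeVar("S")


@dataclass
class ModelSpec:
    """Hypothetical plain container for the serializable parts of a trained estimator."""
    model_data: str  # S3 URI, or a local path after a local-mode training run
    role: str
    entry_point: str
    source_dir: str
    name: str


def dump_spec(spec: ModelSpec) -> str:
    """Serialize the spec to JSON; no session object is stored."""
    return json.dumps(asdict(spec))


def load_spec(payload: str) -> ModelSpec:
    """Rebuild the spec from JSON produced by dump_spec."""
    return ModelSpec(**json.loads(payload))


def is_local(instance_type: str) -> bool:
    """True for the local-mode instance types."""
    return instance_type in LOCAL_INSTANCE_TYPES


def pick_session(instance_type: str,
                 local_factory: Callable[[], S],
                 remote_factory: Callable[[], S]) -> S:
    """Create a fresh session that matches the instance type at deploy time."""
    return local_factory() if is_local(instance_type) else remote_factory()
```

With the SDK installed, the factories would be `sagemaker.local.local_session.LocalSession` and `sagemaker.Session`. The chosen session could then be passed through the `sagemaker_session` argument that `Model` (and therefore `TensorFlowModel`) accepts, instead of being assigned after construction.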
Are there any suggestions for a better way to do this? Or is this sort of what I'll have to work with?
Thanks!