diff --git a/src/sagemaker/tensorflow/predictor.py b/src/sagemaker/tensorflow/predictor.py index ae7277e9ac..caf2959f4c 100644 --- a/src/sagemaker/tensorflow/predictor.py +++ b/src/sagemaker/tensorflow/predictor.py @@ -18,20 +18,31 @@ import google.protobuf.json_format as json_format from google.protobuf.message import DecodeError from protobuf_to_dict import protobuf_to_dict -from tensorflow.core.framework import tensor_pb2 # pylint: disable=no-name-in-module -from tensorflow.python.framework import tensor_util # pylint: disable=no-name-in-module from sagemaker.content_types import CONTENT_TYPE_JSON, CONTENT_TYPE_OCTET_STREAM, CONTENT_TYPE_CSV from sagemaker.predictor import json_serializer, csv_serializer -from tensorflow_serving.apis import predict_pb2, classification_pb2, inference_pb2, regression_pb2 -_POSSIBLE_RESPONSES = [ - predict_pb2.PredictResponse, - classification_pb2.ClassificationResponse, - inference_pb2.MultiInferenceResponse, - regression_pb2.RegressionResponse, - tensor_pb2.TensorProto, -] + +def _possible_responses(): + """ + Returns: Possible available request types. + """ + from tensorflow.core.framework import tensor_pb2 # pylint: disable=no-name-in-module + from tensorflow_serving.apis import ( + predict_pb2, + classification_pb2, + inference_pb2, + regression_pb2, + ) + + return [ + predict_pb2.PredictResponse, + classification_pb2.ClassificationResponse, + inference_pb2.MultiInferenceResponse, + regression_pb2.RegressionResponse, + tensor_pb2.TensorProto, + ] + REGRESSION_REQUEST = "RegressionRequest" MULTI_INFERENCE_REQUEST = "MultiInferenceRequest" @@ -88,7 +99,7 @@ def __call__(self, stream, content_type): finally: stream.close() - for possible_response in _POSSIBLE_RESPONSES: + for possible_response in _possible_responses(): try: response = possible_response() response.ParseFromString(data) @@ -114,6 +125,9 @@ def __call__(self, data): Args: data: """ + + from tensorflow.core.framework import tensor_pb2 # pylint: disable=no-name-in-module + if isinstance(data, tensor_pb2.TensorProto): return json_format.MessageToJson(data) return json_serializer(data) @@ -139,7 +153,7 @@ def __call__(self, stream, content_type): finally: stream.close() - for possible_response in _POSSIBLE_RESPONSES: + for possible_response in _possible_responses(): try: return protobuf_to_dict(json_format.Parse(data, possible_response())) except (UnicodeDecodeError, DecodeError, json_format.ParseError): @@ -164,6 +178,10 @@ def __call__(self, data): data: """ to_serialize = data + + from tensorflow.core.framework import tensor_pb2 # pylint: disable=no-name-in-module + from tensorflow.python.framework import tensor_util # pylint: disable=no-name-in-module + if isinstance(data, tensor_pb2.TensorProto): to_serialize = tensor_util.MakeNdarray(data) return csv_serializer(to_serialize)