@@ -25,18 +25,9 @@
 # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 
 import sys
-
-# isort: off
-import dragon
-from dragon import fli
-
-# isort: on
-
 import time
 import typing as t
 
-import numpy as np
-
 from .....error import SmartSimError
 from .....log import get_logger
 from ....entrypoints.service import Service
@@ -54,108 +45,28 @@
 from ...mli_schemas.response.response_capnp import Response
 
 if t.TYPE_CHECKING:
-    from dragon.fli import FLInterface
-
-    from smartsim._core.mli.mli_schemas.model.model_capnp import Model
     from smartsim._core.mli.mli_schemas.response.response_capnp import StatusEnum
 
 logger = get_logger(__name__)
 
 
-def deserialize_message(
-    data_blob: bytes,
-    channel_type: t.Type[CommChannelBase],
-    device: t.Literal["cpu", "gpu"],
-) -> InferenceRequest:
-    """Deserialize a message from a byte stream into an InferenceRequest
-    :param data_blob: The byte stream to deserialize"""
-    # todo: consider moving to XxxCore and only making
-    # workers implement the inputs and model conversion?
-
-    # alternatively, consider passing the capnproto models
-    # to this method instead of the data_blob...
-
-    # something is definitely wrong here... client shouldn't have to touch
-    # callback (or batch size)
-
-    request = MessageHandler.deserialize_request(data_blob)
-    # return request
-    model_key: t.Optional[str] = None
-    model_bytes: t.Optional[Model] = None
-
-    if request.model.which() == "key":
-        model_key = request.model.key.key
-    elif request.model.which() == "data":
-        model_bytes = request.model.data
-
-    callback_key = request.replyChannel.reply
-
-    # todo: shouldn't this be `CommChannel.find` instead of `DragonCommChannel`
-    comm_channel = channel_type(callback_key)
-    # comm_channel = DragonCommChannel(request.replyChannel)
-
-    input_keys: t.Optional[t.List[str]] = None
-    input_bytes: t.Optional[t.List[bytes]] = (
-        None  # these will really be tensors already
-    )
-
-    input_meta: t.List[t.Any] = []
-
-    if request.input.which() == "keys":
-        input_keys = [input_key.key for input_key in request.input.keys]
-    elif request.input.which() == "data":
-        input_bytes = [data.blob for data in request.input.data]
-        input_meta = [data.tensorDescriptor for data in request.input.data]
-
-    inference_request = InferenceRequest(
-        model_key=model_key,
-        callback=comm_channel,
-        raw_inputs=input_bytes,
-        input_meta=input_meta,
-        input_keys=input_keys,
-        raw_model=model_bytes,
-        batch_size=0,
-    )
-    return inference_request
-
-
 def build_failure_reply(status: "StatusEnum", message: str) -> Response:
+    """Build a response indicating a failure occurred
+    :param status: The status of the response
+    :param message: The error message to include in the response"""
     return MessageHandler.build_response(
         status=status,  # todo: need to indicate correct status
         message=message,  # todo: decide what these will be
-        result=[],
+        result=None,
         custom_attributes=None,
     )
 
 
-def prepare_outputs(reply: InferenceReply) -> t.List[t.Any]:
-    prepared_outputs: t.List[t.Any] = []
-    if reply.output_keys:
-        for key in reply.output_keys:
-            if not key:
-                continue
-            msg_key = MessageHandler.build_tensor_key(key)
-            prepared_outputs.append(msg_key)
-    elif reply.outputs:
-        arrays: t.List[np.ndarray[t.Any, np.dtype[t.Any]]] = [
-            output.numpy() for output in reply.outputs
-        ]
-        for tensor in arrays:
-            # todo: need to have the output attributes specified in the req?
-            # maybe, add `MessageHandler.dtype_of(tensor)`?
-            # can `build_tensor` do dtype and shape?
-            msg_tensor = MessageHandler.build_tensor(
-                tensor,
-                "c",
-                "float32",
-                [1],
-            )
-            prepared_outputs.append(msg_tensor)
-    return prepared_outputs
-
-
-def build_reply(reply: InferenceReply) -> Response:
-    results = prepare_outputs(reply)
+def build_reply(worker: MachineLearningWorkerBase, reply: InferenceReply) -> Response:
+    """Builds a response for a successful inference request
+    :param worker: A worker to process the reply with
+    :param reply: The internal representation of the reply"""
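+    # the worker serializes its own outputs (tensor keys or built tensors)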
+    results = worker.prepare_outputs(reply)
 
     return MessageHandler.build_response(
         status="complete",
@@ -191,10 +102,6 @@ def __init__(
 
         self._task_queue: t.Optional[CommChannelBase] = config_loader.get_queue()
         """the queue the manager monitors for new tasks"""
-        self._feature_store: t.Optional[FeatureStore] = (
-            config_loader.get_feature_store()
-        )
-        """a feature store to retrieve models from"""
         self._worker = worker
         """The ML Worker implementation"""
         self._comm_channel_type = comm_channel_type
@@ -203,37 +110,68 @@ def __init__(
203110 """Device on which workers need to run"""
204111 self ._cached_models : dict [str , t .Any ] = {}
205112 """Dictionary of previously loaded models"""
113+ self ._feature_stores = config_loader .get_feature_stores ()
114+ """A collection of attached feature stores"""
115+
116+ def _check_feature_stores (self , request : InferenceRequest ) -> bool :
117+ """Ensures that all feature stores required by the request are available
118+ :param request: The request to validate"""
119+ # collect all feature stores required by the request
120+ fs_model = {request .model_key .descriptor }
121+ fs_inputs = {key .descriptor for key in request .input_keys }
122+ fs_outputs = {key .descriptor for key in request .output_keys }
123+
124+ # identify which feature stores are requested and unknown
125+ fs_desired = fs_model + fs_inputs + fs_outputs
126+ fs_actual = {key for key in self ._feature_stores }
127+ fs_missing = fs_desired - fs_actual
128+
129+ # exit if all desired feature stores are not available
130+ if fs_missing :
131+ logger .error (f"Missing feature store(s): { fs_missing } " )
132+ return False
 
-    def _validate_request(self, request: InferenceRequest) -> bool:
-        """Ensure the request can be processed.
-        :param request: The request to validate
-        :return: True if the request is valid, False otherwise"""
-        if not self._feature_store:
-            if request.model_key:
-                logger.error("Unable to load model by key without feature store")
-                return False
+        return True
 
-            if request.input_keys:
-                logger.error("Unable to load inputs by key without feature store")
-                return False
+    def _check_model(self, request: InferenceRequest) -> bool:
+        """Ensure that a model is available for the request
+        :param request: The request to validate"""
+        if request.model_key or request.raw_model:
+            return True
 
-            if request.output_keys:
-                logger.error("Unable to persist outputs by key without feature store")
-                return False
+        logger.error("Unable to continue without model bytes or feature store key")
+        return False
 
-        if not request.model_key and not request.raw_model:
-            logger.error("Unable to continue without model bytes or feature store key")
-            return False
+    def _check_inputs(self, request: InferenceRequest) -> bool:
+        """Ensure that inputs are available for the request
+        :param request: The request to validate"""
+        if request.input_keys or request.raw_inputs:
+            return True
 
-        if not request.input_keys and not request.raw_inputs:
-            logger.error("Unable to continue without input bytes or feature store keys")
-            return False
+        logger.error("Unable to continue without input bytes or feature store keys")
+        return False
 
-        if request.callback is None:
-            logger.error("No callback channel provided in request")
-            return False
+    def _check_callback(self, request: InferenceRequest) -> bool:
+        """Ensure that a callback channel is available for the request
+        :param request: The request to validate"""
+        if request.callback is not None:
+            return True
 
-        return True
+        logger.error("No callback channel provided in request")
+        return False
+
+    def _validate_request(self, request: InferenceRequest) -> bool:
+        """Ensure the request can be processed.
+        :param request: The request to validate
+        :return: True if the request is valid, False otherwise"""
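+        # evaluate every check up front (no short-circuit) so each failure is logged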
+        checks = [
+            self._check_feature_stores(request),
+            self._check_model(request),
+            self._check_inputs(request),
+            self._check_callback(request),
+        ]
+
+        return all(checks)
 
     def _on_iteration(self) -> None:
         """Executes calls to the machine learning worker implementation to complete
@@ -249,8 +187,8 @@ def _on_iteration(self) -> None:
         request_bytes: bytes = self._task_queue.recv()
 
         interm = time.perf_counter()  # timing
-        request = deserialize_message(
-            request_bytes, self._comm_channel_type, self._device
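+        # request deserialization is delegated to the worker implementation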
+        request = self._worker.deserialize_message(
+            request_bytes, self._comm_channel_type
         )
         if not self._validate_request(request):
             return
@@ -262,18 +200,21 @@ def _on_iteration(self) -> None:
             if request.model_key is None:
                 # A valid request should never get here.
                 raise ValueError("Could not read model key")
-            if request.model_key in self._cached_models:
+
+            if request.model_key.key in self._cached_models:
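+                # model already loaded on this device; reuse the cached copy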
                 timings.append(time.perf_counter() - interm)  # timing
                 interm = time.perf_counter()  # timing
-                model_result = LoadModelResult(self._cached_models[request.model_key])
+                model_result = LoadModelResult(
+                    self._cached_models[request.model_key.key]
+                )
 
             else:
                 fetch_model_result = None
                 while True:
                     try:
                         interm = time.perf_counter()  # timing
                         fetch_model_result = self._worker.fetch_model(
-                            request, self._feature_store
+                            request, self._feature_stores
                         )
                     except KeyError:
                         time.sleep(0.1)
@@ -287,16 +228,17 @@ def _on_iteration(self) -> None:
                 model_result = self._worker.load_model(
                     request, fetch_model_result, self._device
                 )
-                self._cached_models[request.model_key] = model_result.model
+                self._cached_models[request.model_key.key] = model_result.model
         else:
-            fetch_model_result = self._worker.fetch_model(request, None)
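+            # raw model bytes arrived with the request; no feature store lookup needed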
+            fetch_model_result = self._worker.fetch_model(request, {})
             model_result = self._worker.load_model(
                 request, fetch_result=fetch_model_result, device=self._device
             )
 
         timings.append(time.perf_counter() - interm)  # timing
         interm = time.perf_counter()  # timing
-        fetch_input_result = self._worker.fetch_inputs(request, self._feature_store)
+
+        fetch_input_result = self._worker.fetch_inputs(request, self._feature_stores)
 
         timings.append(time.perf_counter() - interm)  # timing
         interm = time.perf_counter()  # timing
@@ -324,7 +266,7 @@ def _on_iteration(self) -> None:
         interm = time.perf_counter()  # timing
         if request.output_keys:
             reply.output_keys = self._worker.place_output(
-                request, transformed_output, self._feature_store
+                request, transformed_output, self._feature_stores
             )
         else:
             reply.outputs = transformed_output.outputs
@@ -341,7 +283,7 @@ def _on_iteration(self) -> None:
         if reply.outputs is None or not reply.outputs:
            response = build_failure_reply("fail", "no-results")
 
-        response = build_reply(reply)
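+        # the worker prepares the serialized outputs used in the response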
+        response = build_reply(self._worker, reply)
 
         timings.append(time.perf_counter() - interm)  # timing
         interm = time.perf_counter()  # timing