
Commit 93d2380

Add file system descriptor to tensor & model keys
1 parent eace71e commit 93d2380

20 files changed: +438 −263 lines

doc/changelog.md

Lines changed: 1 addition & 0 deletions
@@ -13,6 +13,7 @@ Jump to:
 
 Description
 
+- Enable dynamic feature store selection
 - Add TorchWorker first implementation and mock inference app example
 - Add EnvironmentConfigLoader for ML Worker Manager
 - Add Model schema with model metadata included

smartsim/_core/mli/infrastructure/control/workermanager.py

Lines changed: 78 additions & 136 deletions
@@ -25,18 +25,9 @@
 # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 
 import sys
-
-# isort: off
-import dragon
-from dragon import fli
-
-# isort: on
-
 import time
 import typing as t
 
-import numpy as np
-
 from .....error import SmartSimError
 from .....log import get_logger
 from ....entrypoints.service import Service
@@ -54,108 +45,28 @@
 from ...mli_schemas.response.response_capnp import Response
 
 if t.TYPE_CHECKING:
-    from dragon.fli import FLInterface
-
-    from smartsim._core.mli.mli_schemas.model.model_capnp import Model
     from smartsim._core.mli.mli_schemas.response.response_capnp import StatusEnum
 
 logger = get_logger(__name__)
 
 
-def deserialize_message(
-    data_blob: bytes,
-    channel_type: t.Type[CommChannelBase],
-    device: t.Literal["cpu", "gpu"],
-) -> InferenceRequest:
-    """Deserialize a message from a byte stream into an InferenceRequest
-    :param data_blob: The byte stream to deserialize"""
-    # todo: consider moving to XxxCore and only making
-    # workers implement the inputs and model conversion?
-
-    # alternatively, consider passing the capnproto models
-    # to this method instead of the data_blob...
-
-    # something is definitely wrong here... client shouldn't have to touch
-    # callback (or batch size)
-
-    request = MessageHandler.deserialize_request(data_blob)
-    # return request
-    model_key: t.Optional[str] = None
-    model_bytes: t.Optional[Model] = None
-
-    if request.model.which() == "key":
-        model_key = request.model.key.key
-    elif request.model.which() == "data":
-        model_bytes = request.model.data
-
-    callback_key = request.replyChannel.reply
-
-    # todo: shouldn't this be `CommChannel.find` instead of `DragonCommChannel`
-    comm_channel = channel_type(callback_key)
-    # comm_channel = DragonCommChannel(request.replyChannel)
-
-    input_keys: t.Optional[t.List[str]] = None
-    input_bytes: t.Optional[t.List[bytes]] = (
-        None  # these will really be tensors already
-    )
-
-    input_meta: t.List[t.Any] = []
-
-    if request.input.which() == "keys":
-        input_keys = [input_key.key for input_key in request.input.keys]
-    elif request.input.which() == "data":
-        input_bytes = [data.blob for data in request.input.data]
-        input_meta = [data.tensorDescriptor for data in request.input.data]
-
-    inference_request = InferenceRequest(
-        model_key=model_key,
-        callback=comm_channel,
-        raw_inputs=input_bytes,
-        input_meta=input_meta,
-        input_keys=input_keys,
-        raw_model=model_bytes,
-        batch_size=0,
-    )
-    return inference_request
-
-
 def build_failure_reply(status: "StatusEnum", message: str) -> Response:
+    """Build a response indicating a failure occurred
+    :param status: The status of the response
+    :param message: The error message to include in the response"""
     return MessageHandler.build_response(
         status=status,  # todo: need to indicate correct status
         message=message,  # todo: decide what these will be
-        result=[],
+        result=None,
         custom_attributes=None,
     )
 
 
-def prepare_outputs(reply: InferenceReply) -> t.List[t.Any]:
-    prepared_outputs: t.List[t.Any] = []
-    if reply.output_keys:
-        for key in reply.output_keys:
-            if not key:
-                continue
-            msg_key = MessageHandler.build_tensor_key(key)
-            prepared_outputs.append(msg_key)
-    elif reply.outputs:
-        arrays: t.List[np.ndarray[t.Any, np.dtype[t.Any]]] = [
-            output.numpy() for output in reply.outputs
-        ]
-        for tensor in arrays:
-            # todo: need to have the output attributes specified in the req?
-            # maybe, add `MessageHandler.dtype_of(tensor)`?
-            # can `build_tensor` do dtype and shape?
-            msg_tensor = MessageHandler.build_tensor(
-                tensor,
-                "c",
-                "float32",
-                [1],
-            )
-            prepared_outputs.append(msg_tensor)
-    return prepared_outputs
-
-
-def build_reply(reply: InferenceReply) -> Response:
-    results = prepare_outputs(reply)
+def build_reply(worker: MachineLearningWorkerBase, reply: InferenceReply) -> Response:
+    """Builds a response for a successful inference request
+    :param worker: A worker to process the reply with
+    :param reply: The internal representation of the reply"""
+    results = worker.prepare_outputs(reply)
 
     return MessageHandler.build_response(
         status="complete",
@@ -191,10 +102,6 @@ def __init__(
 
         self._task_queue: t.Optional[CommChannelBase] = config_loader.get_queue()
         """the queue the manager monitors for new tasks"""
-        self._feature_store: t.Optional[FeatureStore] = (
-            config_loader.get_feature_store()
-        )
-        """a feature store to retrieve models from"""
        self._worker = worker
         """The ML Worker implementation"""
         self._comm_channel_type = comm_channel_type
@@ -203,37 +110,68 @@ def __init__(
         """Device on which workers need to run"""
         self._cached_models: dict[str, t.Any] = {}
         """Dictionary of previously loaded models"""
+        self._feature_stores = config_loader.get_feature_stores()
+        """A collection of attached feature stores"""
+
+    def _check_feature_stores(self, request: InferenceRequest) -> bool:
+        """Ensures that all feature stores required by the request are available
+        :param request: The request to validate"""
+        # collect all feature stores required by the request
+        fs_model = {request.model_key.descriptor}
+        fs_inputs = {key.descriptor for key in request.input_keys}
+        fs_outputs = {key.descriptor for key in request.output_keys}
+
+        # identify which feature stores are requested and unknown
+        fs_desired = fs_model | fs_inputs | fs_outputs
+        fs_actual = {key for key in self._feature_stores}
+        fs_missing = fs_desired - fs_actual
+
+        # exit if all desired feature stores are not available
+        if fs_missing:
+            logger.error(f"Missing feature store(s): {fs_missing}")
+            return False
 
-    def _validate_request(self, request: InferenceRequest) -> bool:
-        """Ensure the request can be processed.
-        :param request: The request to validate
-        :return: True if the request is valid, False otherwise"""
-        if not self._feature_store:
-            if request.model_key:
-                logger.error("Unable to load model by key without feature store")
-                return False
+        return True
 
-            if request.input_keys:
-                logger.error("Unable to load inputs by key without feature store")
-                return False
+    def _check_model(self, request: InferenceRequest) -> bool:
+        """Ensure that a model is available for the request
+        :param request: The request to validate"""
+        if request.model_key or request.raw_model:
+            return True
 
-            if request.output_keys:
-                logger.error("Unable to persist outputs by key without feature store")
-                return False
+        logger.error("Unable to continue without model bytes or feature store key")
+        return False
 
-        if not request.model_key and not request.raw_model:
-            logger.error("Unable to continue without model bytes or feature store key")
-            return False
+    def _check_inputs(self, request: InferenceRequest) -> bool:
+        """Ensure that inputs are available for the request
+        :param request: The request to validate"""
+        if request.input_keys or request.raw_inputs:
+            return True
 
-        if not request.input_keys and not request.raw_inputs:
-            logger.error("Unable to continue without input bytes or feature store keys")
-            return False
+        logger.error("Unable to continue without input bytes or feature store keys")
+        return False
 
-        if request.callback is None:
-            logger.error("No callback channel provided in request")
-            return False
+    def _check_callback(self, request: InferenceRequest) -> bool:
+        """Ensure that a callback channel is available for the request
+        :param request: The request to validate"""
+        if request.callback is not None:
+            return True
 
-        return True
+        logger.error("No callback channel provided in request")
+        return False
+
+    def _validate_request(self, request: InferenceRequest) -> bool:
+        """Ensure the request can be processed.
+        :param request: The request to validate
+        :return: True if the request is valid, False otherwise"""
+        checks = [
+            self._check_feature_stores(request),
+            self._check_model(request),
+            self._check_inputs(request),
+            self._check_callback(request),
+        ]
+
+        return all(checks)
 
     def _on_iteration(self) -> None:
         """Executes calls to the machine learning worker implementation to complete
@@ -249,8 +187,8 @@ def _on_iteration(self) -> None:
         request_bytes: bytes = self._task_queue.recv()
 
         interm = time.perf_counter()  # timing
-        request = deserialize_message(
-            request_bytes, self._comm_channel_type, self._device
+        request = self._worker.deserialize_message(
+            request_bytes, self._comm_channel_type
         )
         if not self._validate_request(request):
             return
@@ -262,18 +200,21 @@ def _on_iteration(self) -> None:
             if request.model_key is None:
                 # A valid request should never get here.
                 raise ValueError("Could not read model key")
-            if request.model_key in self._cached_models:
+
+            if request.model_key.key in self._cached_models:
                 timings.append(time.perf_counter() - interm)  # timing
                 interm = time.perf_counter()  # timing
-                model_result = LoadModelResult(self._cached_models[request.model_key])
+                model_result = LoadModelResult(
+                    self._cached_models[request.model_key.key]
+                )
 
             else:
                 fetch_model_result = None
                 while True:
                     try:
                         interm = time.perf_counter()  # timing
                         fetch_model_result = self._worker.fetch_model(
-                            request, self._feature_store
+                            request, self._feature_stores
                         )
                     except KeyError:
                         time.sleep(0.1)
@@ -287,16 +228,17 @@ def _on_iteration(self) -> None:
                 model_result = self._worker.load_model(
                     request, fetch_model_result, self._device
                 )
-                self._cached_models[request.model_key] = model_result.model
+                self._cached_models[request.model_key.key] = model_result.model
         else:
-            fetch_model_result = self._worker.fetch_model(request, None)
+            fetch_model_result = self._worker.fetch_model(request, {})
             model_result = self._worker.load_model(
                 request, fetch_result=fetch_model_result, device=self._device
             )
 
         timings.append(time.perf_counter() - interm)  # timing
         interm = time.perf_counter()  # timing
-        fetch_input_result = self._worker.fetch_inputs(request, self._feature_store)
+
+        fetch_input_result = self._worker.fetch_inputs(request, self._feature_stores)
 
         timings.append(time.perf_counter() - interm)  # timing
         interm = time.perf_counter()  # timing
@@ -324,7 +266,7 @@ def _on_iteration(self) -> None:
         interm = time.perf_counter()  # timing
         if request.output_keys:
             reply.output_keys = self._worker.place_output(
-                request, transformed_output, self._feature_store
+                request, transformed_output, self._feature_stores
             )
         else:
             reply.outputs = transformed_output.outputs
@@ -341,7 +283,7 @@ def _on_iteration(self) -> None:
         if reply.outputs is None or not reply.outputs:
             response = build_failure_reply("fail", "no-results")
 
-        response = build_reply(reply)
+        response = build_reply(self._worker, reply)
 
         timings.append(time.perf_counter() - interm)  # timing
         interm = time.perf_counter()  # timing
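
The validation rework above reduces the feature store check to set arithmetic: gather the descriptor of every key the request references, subtract the descriptors of the attached stores, and reject the request if anything is left over. A minimal standalone sketch of that idea; the FakeKey/FakeRequest stand-ins are illustrative only and not part of the SmartSim API:

import typing as t
from dataclasses import dataclass, field


@dataclass
class FakeKey:
    """Illustrative stand-in for a key naming both an item and its feature store"""
    key: str
    descriptor: str


@dataclass
class FakeRequest:
    """Illustrative stand-in for an inference request"""
    model_key: FakeKey
    input_keys: t.List[FakeKey] = field(default_factory=list)
    output_keys: t.List[FakeKey] = field(default_factory=list)


def missing_stores(request: FakeRequest, attached: t.Set[str]) -> t.Set[str]:
    """Return the descriptors a request needs but that are not attached"""
    desired = {request.model_key.descriptor}
    desired |= {k.descriptor for k in request.input_keys}
    desired |= {k.descriptor for k in request.output_keys}
    return desired - attached


req = FakeRequest(model_key=FakeKey("my-model", "fs-A"), input_keys=[FakeKey("x", "fs-B")])
print(missing_stores(req, attached={"fs-A"}))  # {'fs-B'} -> the request would be rejected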

smartsim/_core/mli/infrastructure/environmentloader.py

Lines changed: 31 additions & 6 deletions
@@ -31,10 +31,15 @@
 
 from dragon.fli import FLInterface  # pylint: disable=all
 
+from smartsim.error.errors import SmartSimError
+from smartsim.log import get_logger
 from smartsim._core.mli.comm.channel.dragonfli import DragonFLIChannel
 from smartsim._core.mli.infrastructure.storage.featurestore import FeatureStore
 
 
+logger = get_logger(__name__)
+
+
 class EnvironmentConfigLoader:
     """
     Facilitates the loading of a FeatureStore and Queue
@@ -47,15 +52,35 @@ def __init__(self) -> None:
         )
         self._queue_descriptor: t.Optional[str] = os.getenv("SSQueue", None)
         self.feature_store: t.Optional[FeatureStore] = None
+        self.feature_stores: t.Optional[t.Dict[str, FeatureStore]] = None
         self.queue: t.Optional[DragonFLIChannel] = None
 
-    def get_feature_store(self) -> t.Optional[FeatureStore]:
-        """Loads the Feature Store previously set in SSFeatureStore"""
-        if self._feature_store_descriptor is not None:
-            self.feature_store = pickle.loads(
-                base64.b64decode(self._feature_store_descriptor)
+    def _load_feature_store(self, env_var: str) -> FeatureStore:
+        """Load a feature store from a descriptor
+        :param env_var: The environment variable holding the descriptor
+        :returns: The hydrated feature store"""
+        logger.debug(f"Loading feature store from env: {env_var}")
+
+        value = os.getenv(env_var)
+        if not value:
+            raise SmartSimError(f"Empty feature store descriptor in environment: {env_var}")
+
+        try:
+            return pickle.loads(base64.b64decode(value))
+        except Exception:
+            raise SmartSimError(
+                f"Invalid feature store descriptor in environment: {env_var}"
             )
-        return self.feature_store
+
+    def get_feature_stores(self) -> t.Dict[str, FeatureStore]:
+        """Loads multiple Feature Stores by scanning environment for variables
+        prefixed with `SSFeatureStore`"""
+        prefix = "SSFeatureStore"
+        if self.feature_stores is None:
+            env_vars = [var for var in os.environ if var.startswith(prefix)]
+            stores = [self._load_feature_store(var) for var in env_vars]
+            self.feature_stores = {fs.descriptor: fs for fs in stores}
+        return self.feature_stores
 
     def get_queue(self, sender_supplied: bool = True) -> t.Optional[DragonFLIChannel]:
         """Returns the Queue previously set in SSQueue"""

smartsim/_core/mli/infrastructure/storage/dragonfeaturestore.py

Lines changed: 7 additions & 0 deletions
@@ -69,3 +69,10 @@ def __contains__(self, key: str) -> bool:
         Return `True` if the key is found, `False` otherwise
         :param key: Unique key of an item to retrieve from the feature store"""
         return key in self._storage
+
+    @property
+    def descriptor(self) -> str:
+        """Return a unique identifier enabling a client to connect to
+        the feature store
+        :returns: A descriptor encoded as a string"""
+        return str(self._storage.serialize())
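
Because the descriptor is derived from self._storage.serialize(), the same string appears to serve two roles: a client that receives it can connect to the backing storage, and the worker manager can use it directly as the lookup key in the dictionary returned by get_feature_stores(). That is what lets tensor and model keys name the store they live in without pinning the worker manager to a single feature store.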
