3434import torch
3535
3636import smartsim .error as sse
37- from smartsim ._core .mli .infrastructure .storage .feature_store import FeatureStoreKey
37+ from smartsim ._core .mli .infrastructure .storage .feature_store import ModelKey , TensorKey
3838from smartsim ._core .mli .infrastructure .worker .worker import (
3939 InferenceRequest ,
4040 MachineLearningWorkerCore ,
@@ -98,7 +98,7 @@ def test_fetch_model_disk(persist_torch_model: pathlib.Path, test_dir: str) -> N
9898 fsd = feature_store .descriptor
9999 feature_store [str (persist_torch_model )] = persist_torch_model .read_bytes ()
100100
101- model_key = FeatureStoreKey (key = key , descriptor = fsd )
101+ model_key = ModelKey (key = key , descriptor = fsd )
102102 request = InferenceRequest (model_key = model_key )
103103 batch = RequestBatch ([request ], None , model_key )
104104
@@ -116,7 +116,7 @@ def test_fetch_model_disk_missing() -> None:
116116
117117 key = "/path/that/doesnt/exist"
118118
119- model_key = FeatureStoreKey (key = key , descriptor = fsd )
119+ model_key = ModelKey (key = key , descriptor = fsd )
120120 request = InferenceRequest (model_key = model_key )
121121 batch = RequestBatch ([request ], None , model_key )
122122
@@ -141,7 +141,7 @@ def test_fetch_model_feature_store(persist_torch_model: pathlib.Path) -> None:
141141 fsd = feature_store .descriptor
142142 feature_store [key ] = persist_torch_model .read_bytes ()
143143
144- model_key = FeatureStoreKey (key = key , descriptor = feature_store .descriptor )
144+ model_key = ModelKey (key = key , descriptor = feature_store .descriptor )
145145 request = InferenceRequest (model_key = model_key )
146146 batch = RequestBatch ([request ], None , model_key )
147147
@@ -159,7 +159,7 @@ def test_fetch_model_feature_store_missing() -> None:
159159 feature_store = MemoryFeatureStore ()
160160 fsd = feature_store .descriptor
161161
162- model_key = FeatureStoreKey (key = key , descriptor = feature_store .descriptor )
162+ model_key = ModelKey (key = key , descriptor = feature_store .descriptor )
163163 request = InferenceRequest (model_key = model_key )
164164 batch = RequestBatch ([request ], None , model_key )
165165
@@ -182,7 +182,7 @@ def test_fetch_model_memory(persist_torch_model: pathlib.Path) -> None:
182182 fsd = feature_store .descriptor
183183 feature_store [key ] = persist_torch_model .read_bytes ()
184184
185- model_key = FeatureStoreKey (key = key , descriptor = feature_store .descriptor )
185+ model_key = ModelKey (key = key , descriptor = feature_store .descriptor )
186186 request = InferenceRequest (model_key = model_key )
187187 batch = RequestBatch ([request ], None , model_key )
188188
@@ -199,11 +199,9 @@ def test_fetch_input_disk(persist_torch_tensor: pathlib.Path) -> None:
199199
200200 feature_store = MemoryFeatureStore ()
201201 fsd = feature_store .descriptor
202- request = InferenceRequest (
203- input_keys = [FeatureStoreKey (key = tensor_name , descriptor = fsd )]
204- )
202+ request = InferenceRequest (input_keys = [TensorKey (key = tensor_name , descriptor = fsd )])
205203
206- model_key = FeatureStoreKey (key = "test-model" , descriptor = fsd )
204+ model_key = ModelKey (key = "test-model" , descriptor = fsd )
207205 batch = RequestBatch ([request ], None , model_key )
208206
209207 worker = MachineLearningWorkerCore
@@ -223,9 +221,9 @@ def test_fetch_input_disk_missing() -> None:
223221 fsd = feature_store .descriptor
224222 key = "/path/that/doesnt/exist"
225223
226- request = InferenceRequest (input_keys = [FeatureStoreKey (key = key , descriptor = fsd )])
224+ request = InferenceRequest (input_keys = [TensorKey (key = key , descriptor = fsd )])
227225
228- model_key = FeatureStoreKey (key = "test-model" , descriptor = fsd )
226+ model_key = ModelKey (key = "test-model" , descriptor = fsd )
229227 batch = RequestBatch ([request ], None , model_key )
230228
231229 with pytest .raises (sse .SmartSimError ) as ex :
@@ -245,14 +243,12 @@ def test_fetch_input_feature_store(persist_torch_tensor: pathlib.Path) -> None:
245243 feature_store = MemoryFeatureStore ()
246244 fsd = feature_store .descriptor
247245
248- request = InferenceRequest (
249- input_keys = [FeatureStoreKey (key = tensor_name , descriptor = fsd )]
250- )
246+ request = InferenceRequest (input_keys = [TensorKey (key = tensor_name , descriptor = fsd )])
251247
252248 # put model bytes into the feature store
253249 feature_store [tensor_name ] = persist_torch_tensor .read_bytes ()
254250
255- model_key = FeatureStoreKey (key = "test-model" , descriptor = fsd )
251+ model_key = ModelKey (key = "test-model" , descriptor = fsd )
256252 batch = RequestBatch ([request ], None , model_key )
257253
258254 fetch_result = worker .fetch_inputs (batch , {fsd : feature_store })
@@ -284,13 +280,13 @@ def test_fetch_multi_input_feature_store(persist_torch_tensor: pathlib.Path) ->
284280
285281 request = InferenceRequest (
286282 input_keys = [
287- FeatureStoreKey (key = tensor_name + "1" , descriptor = fsd ),
288- FeatureStoreKey (key = tensor_name + "2" , descriptor = fsd ),
289- FeatureStoreKey (key = tensor_name + "3" , descriptor = fsd ),
283+ TensorKey (key = tensor_name + "1" , descriptor = fsd ),
284+ TensorKey (key = tensor_name + "2" , descriptor = fsd ),
285+ TensorKey (key = tensor_name + "3" , descriptor = fsd ),
290286 ]
291287 )
292288
293- model_key = FeatureStoreKey (key = "test-model" , descriptor = fsd )
289+ model_key = ModelKey (key = "test-model" , descriptor = fsd )
294290 batch = RequestBatch ([request ], None , model_key )
295291
296292 fetch_result = worker .fetch_inputs (batch , {fsd : feature_store })
@@ -310,9 +306,9 @@ def test_fetch_input_feature_store_missing() -> None:
310306 key = "bad-key"
311307 feature_store = MemoryFeatureStore ()
312308 fsd = feature_store .descriptor
313- request = InferenceRequest (input_keys = [FeatureStoreKey (key = key , descriptor = fsd )])
309+ request = InferenceRequest (input_keys = [TensorKey (key = key , descriptor = fsd )])
314310
315- model_key = FeatureStoreKey (key = "test-model" , descriptor = fsd )
311+ model_key = ModelKey (key = "test-model" , descriptor = fsd )
316312 batch = RequestBatch ([request ], None , model_key )
317313
318314 with pytest .raises (sse .SmartSimError ) as ex :
@@ -332,9 +328,9 @@ def test_fetch_input_memory(persist_torch_tensor: pathlib.Path) -> None:
332328
333329 key = "test-model"
334330 feature_store [key ] = persist_torch_tensor .read_bytes ()
335- request = InferenceRequest (input_keys = [FeatureStoreKey (key = key , descriptor = fsd )])
331+ request = InferenceRequest (input_keys = [TensorKey (key = key , descriptor = fsd )])
336332
337- model_key = FeatureStoreKey (key = "test-model" , descriptor = fsd )
333+ model_key = ModelKey (key = "test-model" , descriptor = fsd )
338334 batch = RequestBatch ([request ], None , model_key )
339335
340336 fetch_result = worker .fetch_inputs (batch , {fsd : feature_store })
@@ -351,9 +347,9 @@ def test_place_outputs() -> None:
351347
352348 # create a key to retrieve from the feature store
353349 keys = [
354- FeatureStoreKey (key = key_name + "1" , descriptor = fsd ),
355- FeatureStoreKey (key = key_name + "2" , descriptor = fsd ),
356- FeatureStoreKey (key = key_name + "3" , descriptor = fsd ),
350+ TensorKey (key = key_name + "1" , descriptor = fsd ),
351+ TensorKey (key = key_name + "2" , descriptor = fsd ),
352+ TensorKey (key = key_name + "3" , descriptor = fsd ),
357353 ]
358354 data = [b"abcdef" , b"ghijkl" , b"mnopqr" ]
359355
@@ -376,6 +372,6 @@ def test_place_outputs() -> None:
376372 pytest .param ("key" , "" , id = "invalid descriptor" ),
377373 ],
378374)
379- def test_invalid_featurestorekey (key , descriptor ) -> None :
375+ def test_invalid_tensorkey (key , descriptor ) -> None :
380376 with pytest .raises (ValueError ):
381- fsk = FeatureStoreKey (key , descriptor )
377+ fsk = TensorKey (key , descriptor )
0 commit comments