Skip to content

Commit ab900b8

Browse files
authored
Remove device attribute from schemas (#619)
This PR removes `device` from the schemas, MessageHandler, and tests.
1 parent 38081da commit ab900b8

File tree

9 files changed

+24
-175
lines changed

9 files changed

+24
-175
lines changed

doc/changelog.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ Description
1616
- Add ML worker manager, sample worker, and feature store
1717
- Added schemas and MessageHandler class for de/serialization of
1818
inference requests and response messages
19+
- Removed device from schemas, MessageHandler, and tests
1920

2021

2122
### Development branch

smartsim/_core/mli/infrastructure/control/workermanager.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,6 @@ def deserialize_message(
6464

6565
request = MessageHandler.deserialize_request(data_blob)
6666
# return request
67-
device = request.device
6867
model_key: t.Optional[str] = None
6968
model_bytes: t.Optional[bytes] = None
7069

@@ -106,7 +105,6 @@ def deserialize_message(
106105
input_keys=input_keys,
107106
raw_model=model_bytes,
108107
batch_size=0,
109-
device=device,
110108
)
111109
return inference_request
112110

smartsim/_core/mli/infrastructure/worker/worker.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,6 @@ def __init__(
5050
output_keys: t.Optional[t.List[str]] = None,
5151
raw_model: t.Optional[bytes] = None,
5252
batch_size: int = 0,
53-
device: t.Optional[str] = None,
5453
):
5554
"""Initialize the object"""
5655
self.model_key = model_key
@@ -61,7 +60,6 @@ def __init__(
6160
self.input_meta = input_meta or []
6261
self.output_keys = output_keys or []
6362
self.batch_size = batch_size
64-
self.device = device
6563

6664

6765
class InferenceReply:

smartsim/_core/mli/message_handler.py

Lines changed: 0 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -220,22 +220,6 @@ def _assign_reply_channel(
220220
except Exception as e:
221221
raise ValueError("Error building reply channel portion of request.") from e
222222

223-
@staticmethod
224-
def _assign_device(
225-
request: request_capnp.Request, device: "request_capnp.Device"
226-
) -> None:
227-
"""
228-
Assigns a device to the supplied request.
229-
230-
:param request: Request being built
231-
:param device: Device to be assigned
232-
:raises ValueError: if building fails
233-
"""
234-
try:
235-
request.device = device
236-
except Exception as e:
237-
raise ValueError("Error building device portion of request.") from e
238-
239223
@staticmethod
240224
def _assign_inputs(
241225
request: request_capnp.Request,
@@ -342,7 +326,6 @@ def _assign_custom_request_attributes(
342326
def build_request(
343327
reply_channel: t.ByteString,
344328
model: t.Union[data_references_capnp.ModelKey, t.ByteString],
345-
device: "request_capnp.Device",
346329
inputs: t.Union[
347330
t.List[data_references_capnp.TensorKey], t.List[tensor_capnp.Tensor]
348331
],
@@ -359,7 +342,6 @@ def build_request(
359342
360343
:param reply_channel: Reply channel to be assigned to request
361344
:param model: Model to be assigned to request
362-
:param device: Device to be assigned to request
363345
:param inputs: Inputs to be assigned to request
364346
:param outputs: Outputs to be assigned to request
365347
:param output_descriptors: Output descriptors to be assigned to request
@@ -368,7 +350,6 @@ def build_request(
368350
request = request_capnp.Request.new_message()
369351
MessageHandler._assign_reply_channel(request, reply_channel)
370352
MessageHandler._assign_model(request, model)
371-
MessageHandler._assign_device(request, device)
372353
MessageHandler._assign_inputs(request, inputs)
373354
MessageHandler._assign_outputs(request, outputs)
374355
MessageHandler._assign_output_descriptors(request, output_descriptors)

smartsim/_core/mli/mli_schemas/request/request.capnp

Lines changed: 7 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -30,12 +30,6 @@ using Tensors = import "../tensor/tensor.capnp";
3030
using RequestAttributes = import "request_attributes/request_attributes.capnp";
3131
using DataRef = import "../data/data_references.capnp";
3232

33-
enum Device {
34-
cpu @0;
35-
gpu @1;
36-
auto @2;
37-
}
38-
3933
struct ChannelDescriptor {
4034
reply @0 :Data;
4135
}
@@ -46,16 +40,15 @@ struct Request {
4640
modelKey @1 :DataRef.ModelKey;
4741
modelData @2 :Data;
4842
}
49-
device @3 :Device;
5043
input :union {
51-
inputKeys @4 :List(DataRef.TensorKey);
52-
inputData @5 :List(Tensors.Tensor);
44+
inputKeys @3 :List(DataRef.TensorKey);
45+
inputData @4 :List(Tensors.Tensor);
5346
}
54-
output @6 :List(DataRef.TensorKey);
55-
outputDescriptors @7 :List(Tensors.OutputDescriptor);
47+
output @5 :List(DataRef.TensorKey);
48+
outputDescriptors @6 :List(Tensors.OutputDescriptor);
5649
customAttributes :union {
57-
torch @8 :RequestAttributes.TorchRequestAttributes;
58-
tf @9 :RequestAttributes.TensorFlowRequestAttributes;
59-
none @10 :Void;
50+
torch @7 :RequestAttributes.TorchRequestAttributes;
51+
tf @8 :RequestAttributes.TensorFlowRequestAttributes;
52+
none @9 :Void;
6053
}
6154
}

smartsim/_core/mli/mli_schemas/request/request_capnp.pyi

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -33,8 +33,6 @@ from .request_attributes.request_attributes_capnp import (
3333
TorchRequestAttributesReader,
3434
)
3535

36-
Device = Literal["cpu", "gpu", "auto"]
37-
3836
class ChannelDescriptor:
3937
reply: bytes
4038
@staticmethod
@@ -215,7 +213,6 @@ class Request:
215213
def write_packed(file: BufferedWriter) -> None: ...
216214
replyChannel: ChannelDescriptor | ChannelDescriptorBuilder | ChannelDescriptorReader
217215
model: Request.Model | Request.ModelBuilder | Request.ModelReader
218-
device: Device
219216
input: Request.Input | Request.InputBuilder | Request.InputReader
220217
output: Sequence[TensorKey | TensorKeyBuilder | TensorKeyReader]
221218
outputDescriptors: Sequence[

tests/mli/test_integrated_torch_worker.py

Lines changed: 0 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,6 @@ def persist_torch_model(test_dir: str) -> pathlib.Path:
6666
# model_bytes = persist_torch_model.read_bytes()
6767
# input_tensor = torch.randn(2)
6868

69-
# expected_device = "cpu"
7069
# expected_callback_channel = b"faux_channel_descriptor_bytes"
7170
# callback_channel = mli.DragonCommChannel.find(expected_callback_channel)
7271

@@ -77,7 +76,6 @@ def persist_torch_model(test_dir: str) -> pathlib.Path:
7776
# request = MessageHandler.build_request(
7877
# reply_channel=callback_channel.descriptor,
7978
# model=model_bytes,
80-
# device=expected_device,
8179
# inputs=[message_tensor_input],
8280
# outputs=[],
8381
# custom_attributes=None,
@@ -86,7 +84,6 @@ def persist_torch_model(test_dir: str) -> pathlib.Path:
8684
# msg_bytes = MessageHandler.serialize_request(request)
8785

8886
# inference_request = worker.deserialize(msg_bytes)
89-
# assert inference_request.device == expected_device
9087
# assert inference_request.callback._descriptor == expected_callback_channel
9188

9289

@@ -104,7 +101,6 @@ def persist_torch_model(test_dir: str) -> pathlib.Path:
104101
# # input_tensor = torch.randn(2)
105102
# # feature_store[input_key] = input_tensor
106103

107-
# expected_device = "cpu"
108104
# expected_callback_channel = b"faux_channel_descriptor_bytes"
109105
# callback_channel = mli.DragonCommChannel.find(expected_callback_channel)
110106

@@ -117,7 +113,6 @@ def persist_torch_model(test_dir: str) -> pathlib.Path:
117113
# request = MessageHandler.build_request(
118114
# reply_channel=callback_channel.descriptor,
119115
# model=message_model_key,
120-
# device=expected_device,
121116
# inputs=[message_tensor_input_key],
122117
# outputs=[message_tensor_output_key],
123118
# custom_attributes=None,
@@ -126,7 +121,6 @@ def persist_torch_model(test_dir: str) -> pathlib.Path:
126121
# msg_bytes = MessageHandler.serialize_request(request)
127122

128123
# inference_request = worker.deserialize(msg_bytes)
129-
# assert inference_request.device == expected_device
130124
# assert inference_request.callback._descriptor == expected_callback_channel
131125

132126

@@ -147,7 +141,6 @@ def persist_torch_model(test_dir: str) -> pathlib.Path:
147141
# # input_tensor = torch.randn(2)
148142
# # feature_store[input_key] = input_tensor
149143

150-
# expected_device = "cpu"
151144
# expected_callback_channel = b"faux_channel_descriptor_bytes"
152145
# callback_channel = mli.DragonCommChannel.find(expected_callback_channel)
153146

@@ -160,7 +153,6 @@ def persist_torch_model(test_dir: str) -> pathlib.Path:
160153
# request = MessageHandler.build_request(
161154
# reply_channel=callback_channel.descriptor,
162155
# model=model_bytes,
163-
# device=expected_device,
164156
# inputs=[message_tensor_input_key],
165157
# # outputs=[message_tensor_output_key],
166158
# outputs=[],
@@ -170,7 +162,6 @@ def persist_torch_model(test_dir: str) -> pathlib.Path:
170162
# msg_bytes = MessageHandler.serialize_request(request)
171163

172164
# inference_request = worker.deserialize(msg_bytes)
173-
# assert inference_request.device == expected_device
174165
# assert inference_request.callback._descriptor == expected_callback_channel
175166

176167

@@ -191,7 +182,6 @@ def persist_torch_model(test_dir: str) -> pathlib.Path:
191182
# input_tensor = torch.randn(2)
192183
# # feature_store[input_key] = input_tensor
193184

194-
# expected_device = "cpu"
195185
# expected_callback_channel = b"faux_channel_descriptor_bytes"
196186
# callback_channel = mli.DragonCommChannel.find(expected_callback_channel)
197187

@@ -207,7 +197,6 @@ def persist_torch_model(test_dir: str) -> pathlib.Path:
207197
# request = MessageHandler.build_request(
208198
# reply_channel=callback_channel.descriptor,
209199
# model=model_bytes,
210-
# device=expected_device,
211200
# inputs=[message_tensor_input],
212201
# # outputs=[message_tensor_output_key],
213202
# outputs=[message_tensor_output_key],
@@ -217,7 +206,6 @@ def persist_torch_model(test_dir: str) -> pathlib.Path:
217206
# msg_bytes = MessageHandler.serialize_request(request)
218207

219208
# inference_request = worker.deserialize(msg_bytes)
220-
# assert inference_request.device == expected_device
221209
# assert inference_request.callback._descriptor == expected_callback_channel
222210

223211

@@ -238,7 +226,6 @@ def persist_torch_model(test_dir: str) -> pathlib.Path:
238226
# input_tensor = torch.randn(2)
239227
# # feature_store[input_key] = input_tensor
240228

241-
# expected_device = "cpu"
242229
# expected_callback_channel = b"faux_channel_descriptor_bytes"
243230
# callback_channel = mli.DragonCommChannel.find(expected_callback_channel)
244231

@@ -254,7 +241,6 @@ def persist_torch_model(test_dir: str) -> pathlib.Path:
254241
# request = MessageHandler.build_request(
255242
# reply_channel=callback_channel.descriptor,
256243
# model=message_model_key,
257-
# device=expected_device,
258244
# inputs=[message_tensor_input],
259245
# # outputs=[message_tensor_output_key],
260246
# outputs=[],
@@ -264,7 +250,6 @@ def persist_torch_model(test_dir: str) -> pathlib.Path:
264250
# msg_bytes = MessageHandler.serialize_request(request)
265251

266252
# inference_request = worker.deserialize(msg_bytes)
267-
# assert inference_request.device == expected_device
268253
# assert inference_request.callback._descriptor == expected_callback_channel
269254

270255

tests/mli/test_worker_manager.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -122,7 +122,6 @@ def mock_messages(
122122
# working set size > 1 has side-effects
123123
# only incurs cost when working set size has been exceeded
124124

125-
expected_device: t.Literal["cpu", "gpu"] = "cpu"
126125
channel_key = comm_channel_root_dir / f"{iteration_number}/channel.txt"
127126
callback_channel = FileSystemCommChannel(pathlib.Path(channel_key))
128127

@@ -144,7 +143,6 @@ def mock_messages(
144143
request = MessageHandler.build_request(
145144
reply_channel=callback_channel.descriptor,
146145
model=message_model_key,
147-
device=expected_device,
148146
inputs=[message_tensor_input_key],
149147
outputs=[message_tensor_output_key],
150148
custom_attributes=None,

0 commit comments

Comments (0)