
Commit 9b437cb: Squash event integration
1 parent: b4798da

37 files changed (+2874, -559 lines)

doc/changelog.md

Lines changed: 1 addition & 1 deletion

@@ -13,9 +13,9 @@ Jump to:
 
 Description
 
+- Implement asynchronous notifications for shared data
 - Parameterize installation of dragon package with `smart build`
 - Update docstrings
-- Implement asynchronous notifications for shared data
 - Filenames conform to snake case
 - Update SmartSim environment variables using new naming convention
 - Refactor `exception_handler`
ex/high_throughput_inference/mock_app.py

Lines changed: 23 additions & 86 deletions

@@ -37,98 +37,26 @@
 
 import argparse
 import io
-import numpy
-import os
-import time
+
 import torch
 
-from mpi4py import MPI
-from smartsim._core.mli.infrastructure.storage.dragon_feature_store import (
-    DragonFeatureStore,
-)
-from smartsim._core.mli.message_handler import MessageHandler
 from smartsim.log import get_logger
-from smartsim._core.utils.timings import PerfTimer
 
 torch.set_num_interop_threads(16)
 torch.set_num_threads(1)
 
 logger = get_logger("App")
 logger.info("Started app")
 
-CHECK_RESULTS_AND_MAKE_ALL_SLOWER = False
+from collections import OrderedDict
 
-class ProtoClient:
-    def __init__(self, timing_on: bool):
-        comm = MPI.COMM_WORLD
-        rank = comm.Get_rank()
-        connect_to_infrastructure()
-        ddict_str = os.environ["_SMARTSIM_INFRA_BACKBONE"]
-        self._ddict = DDict.attach(ddict_str)
-        self._backbone_descriptor = DragonFeatureStore(self._ddict).descriptor
-        to_worker_fli_str = None
-        while to_worker_fli_str is None:
-            try:
-                to_worker_fli_str = self._ddict["to_worker_fli"]
-                self._to_worker_fli = fli.FLInterface.attach(to_worker_fli_str)
-            except KeyError:
-                time.sleep(1)
-        self._from_worker_ch = Channel.make_process_local()
-        self._from_worker_ch_serialized = self._from_worker_ch.serialize()
-        self._to_worker_ch = Channel.make_process_local()
-
-        self.perf_timer: PerfTimer = PerfTimer(debug=False, timing_on=timing_on, prefix=f"a{rank}_")
-
-    def run_model(self, model: bytes | str, batch: torch.Tensor):
-        tensors = [batch.numpy()]
-        self.perf_timer.start_timings("batch_size", batch.shape[0])
-        built_tensor_desc = MessageHandler.build_tensor_descriptor(
-            "c", "float32", list(batch.shape)
-        )
-        self.perf_timer.measure_time("build_tensor_descriptor")
-        if isinstance(model, str):
-            model_arg = MessageHandler.build_model_key(model, self._backbone_descriptor)
-        else:
-            model_arg = MessageHandler.build_model(model, "resnet-50", "1.0")
-        request = MessageHandler.build_request(
-            reply_channel=self._from_worker_ch_serialized,
-            model=model_arg,
-            inputs=[built_tensor_desc],
-            outputs=[],
-            output_descriptors=[],
-            custom_attributes=None,
-        )
-        self.perf_timer.measure_time("build_request")
-        request_bytes = MessageHandler.serialize_request(request)
-        self.perf_timer.measure_time("serialize_request")
-        with self._to_worker_fli.sendh(timeout=None, stream_channel=self._to_worker_ch) as to_sendh:
-            to_sendh.send_bytes(request_bytes)
-            self.perf_timer.measure_time("send_request")
-            for tensor in tensors:
-                to_sendh.send_bytes(tensor.tobytes()) #TODO NOT FAST ENOUGH!!!
-            self.perf_timer.measure_time("send_tensors")
-        with self._from_worker_ch.recvh(timeout=None) as from_recvh:
-            resp = from_recvh.recv_bytes(timeout=None)
-            self.perf_timer.measure_time("receive_response")
-            response = MessageHandler.deserialize_response(resp)
-            self.perf_timer.measure_time("deserialize_response")
-            # list of data blobs? recv depending on the len(response.result.descriptors)?
-            data_blob: bytes = from_recvh.recv_bytes(timeout=None)
-            self.perf_timer.measure_time("receive_tensor")
-            result = torch.from_numpy(
-                numpy.frombuffer(
-                    data_blob,
-                    dtype=str(response.result.descriptors[0].dataType),
-                )
-            )
-            self.perf_timer.measure_time("deserialize_tensor")
+from smartsim.log import get_logger, log_to_file
+from smartsim.protoclient import ProtoClient
 
-        self.perf_timer.end_timings()
-        return result
+logger = get_logger("App", "DEBUG")
 
-    def set_model(self, key: str, model: bytes):
-        self._ddict[key] = model
 
+CHECK_RESULTS_AND_MAKE_ALL_SLOWER = False
 
 
 class ResNetWrapper:
@@ -151,6 +79,7 @@ def model(self):
     def name(self):
         return self._name
 
+
 if __name__ == "__main__":
 
     parser = argparse.ArgumentParser("Mock application")
@@ -160,30 +89,38 @@ def name(self):
 
     resnet = ResNetWrapper("resnet50", f"resnet50.{args.device}.pt")
 
-    client = ProtoClient(timing_on=True)
-    client.set_model(resnet.name, resnet.model)
+    client = ProtoClient(timing_on=True, wait_timeout=0)
+    # client.set_model(resnet.name, resnet.model)
 
     if CHECK_RESULTS_AND_MAKE_ALL_SLOWER:
         # TODO: adapt to non-Nvidia devices
         torch_device = args.device.replace("gpu", "cuda")
-        pt_model = torch.jit.load(io.BytesIO(initial_bytes=(resnet.model))).to(torch_device)
+        pt_model = torch.jit.load(io.BytesIO(initial_bytes=(resnet.model))).to(
+            torch_device
+        )
 
     TOTAL_ITERATIONS = 100
 
-    for log2_bsize in range(args.log_max_batchsize+1):
+    for log2_bsize in range(args.log_max_batchsize + 1):
         b_size: int = 2**log2_bsize
         logger.info(f"Batch size: {b_size}")
-        for iteration_number in range(TOTAL_ITERATIONS + int(b_size==1)):
+        for iteration_number in range(TOTAL_ITERATIONS + int(b_size == 1)):
            logger.info(f"Iteration: {iteration_number}")
            sample_batch = resnet.get_batch(b_size)
            remote_result = client.run_model(resnet.name, sample_batch)
            logger.info(client.perf_timer.get_last("total_time"))
            if CHECK_RESULTS_AND_MAKE_ALL_SLOWER:
                local_res = pt_model(sample_batch.to(torch_device))
-                err_norm = torch.linalg.vector_norm(torch.flatten(remote_result).to(torch_device)-torch.flatten(local_res), ord=1).cpu()
+                err_norm = torch.linalg.vector_norm(
+                    torch.flatten(remote_result).to(torch_device)
+                    - torch.flatten(local_res),
+                    ord=1,
+                ).cpu()
                res_norm = torch.linalg.vector_norm(remote_result, ord=1).item()
                local_res_norm = torch.linalg.vector_norm(local_res, ord=1).item()
-                logger.info(f"Avg norm of error {err_norm.item()/b_size} compared to result norm of {res_norm/b_size}:{local_res_norm/b_size}")
+                logger.info(
+                    f"Avg norm of error {err_norm.item()/b_size} compared to result norm of {res_norm/b_size}:{local_res_norm/b_size}"
+                )
                torch.cuda.synchronize()
 
-    client.perf_timer.print_timings(to_file=True)
+    client.perf_timer.print_timings(to_file=True)
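
Net effect of this file's diff: the hand-rolled ProtoClient (MPI rank lookup, DDict attach, FLI polling, manual MessageHandler request building) is deleted, and the mock app drives the packaged smartsim.protoclient.ProtoClient instead. A minimal usage sketch, assuming the packaged client keeps the constructor and run_model/perf_timer surface exercised above; only names visible in this diff are relied on, and the batch shape is illustrative. Note that set_model is commented out in the new code, so the "resnet50" key is assumed to be registered with the worker side elsewhere.

# Sketch only: assumes smartsim.protoclient.ProtoClient matches the API
# the mock app uses above (not part of this commit).
import torch

from smartsim.protoclient import ProtoClient

client = ProtoClient(timing_on=True, wait_timeout=0)

batch = torch.randn(4, 3, 224, 224)           # illustrative ResNet-style input
result = client.run_model("resnet50", batch)  # model resolved by key

print(client.perf_timer.get_last("total_time"))
client.perf_timer.print_timings(to_file=True)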

ex/high_throughput_inference/standalone_worker_manager.py

Lines changed: 13 additions & 18 deletions

@@ -37,6 +37,7 @@
 from dragon.globalservices.api_setup import connect_to_infrastructure
 from dragon.managed_memory import MemoryPool
 from dragon.utils import b64decode, b64encode
+
 # pylint enable=import-error
 
 # isort: off
@@ -45,6 +46,7 @@
 import argparse
 import base64
 import multiprocessing as mp
+import optparse
 import os
 import pickle
 import socket
@@ -53,26 +55,24 @@
 import typing as t
 
 import cloudpickle
-import optparse
-import os
 
 from smartsim._core.entrypoints.service import Service
 from smartsim._core.mli.comm.channel.channel import CommChannelBase
 from smartsim._core.mli.comm.channel.dragon_channel import DragonCommChannel
 from smartsim._core.mli.comm.channel.dragon_fli import DragonFLIChannel
-from smartsim._core.mli.infrastructure.storage.dragon_feature_store import (
-    DragonFeatureStore,
-)
 from smartsim._core.mli.infrastructure.control.request_dispatcher import (
     RequestDispatcher,
 )
 from smartsim._core.mli.infrastructure.control.worker_manager import WorkerManager
 from smartsim._core.mli.infrastructure.environment_loader import EnvironmentConfigLoader
+from smartsim._core.mli.infrastructure.storage.backbone_feature_store import (
+    BackboneFeatureStore,
+)
 from smartsim._core.mli.infrastructure.storage.dragon_feature_store import (
     DragonFeatureStore,
 )
+from smartsim._core.mli.infrastructure.storage.feature_store import ReservedKeys
 from smartsim._core.mli.infrastructure.worker.worker import MachineLearningWorkerBase
-
 from smartsim.log import get_logger
 
 logger = get_logger("Worker Manager Entry Point")
@@ -85,7 +85,6 @@
 logger.info(f"CPUS: {os.cpu_count()}")
 
 
-
 def service_as_dragon_proc(
     service: Service, cpu_affinity: list[int], gpu_affinity: list[int]
 ) -> dragon_process.Process:
@@ -108,8 +107,6 @@ def service_as_dragon_proc(
     )
 
 
-
-
 if __name__ == "__main__":
     parser = argparse.ArgumentParser("Worker Manager")
     parser.add_argument(
@@ -144,26 +141,24 @@ def service_as_dragon_proc(
 
     connect_to_infrastructure()
     ddict_str = os.environ["_SMARTSIM_INFRA_BACKBONE"]
-    ddict = DDict.attach(ddict_str)
+
+    backbone = BackboneFeatureStore.from_descriptor(ddict_str)
 
     to_worker_channel = Channel.make_process_local()
     to_worker_fli = fli.FLInterface(main_ch=to_worker_channel, manager_ch=None)
-    to_worker_fli_serialized = to_worker_fli.serialize()
-    ddict["to_worker_fli"] = to_worker_fli_serialized
+    to_worker_fli_comm_channel = DragonFLIChannel(to_worker_fli, True)
+
+    backbone.worker_queue = to_worker_fli_comm_channel.descriptor
 
     arg_worker_type = cloudpickle.loads(
         base64.b64decode(args.worker_class.encode("ascii"))
     )
 
-    dfs = DragonFeatureStore(ddict)
-    comm_channel = DragonFLIChannel(to_worker_fli_serialized)
-
-    descriptor = base64.b64encode(to_worker_fli_serialized).decode("utf-8")
-    os.environ["_SMARTSIM_REQUEST_QUEUE"] = descriptor
+    os.environ["_SMARTSIM_REQUEST_QUEUE"] = to_worker_fli_comm_channel.descriptor
 
     config_loader = EnvironmentConfigLoader(
         featurestore_factory=DragonFeatureStore.from_descriptor,
-        callback_factory=DragonCommChannel,
+        callback_factory=DragonCommChannel.from_descriptor,
        queue_factory=DragonFLIChannel.from_descriptor,
    )