
Commit ca01cb1

Add integration of dragon-based event broadcasting (#710)

Integrates event publishers and consumers in `ProtoClient` and `DragonBackend`. Committed by @ankona; reviewed by @al-rigazzi, @mellis13, and @amandarichardsonn.

1 parent 5ec287c · commit ca01cb1
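
The diffs below carry the plumbing for this integration: descriptors are exchanged through the backbone feature store so publishers and consumers can find one another. The publisher and consumer classes themselves live in files collapsed in this view; as a rough illustration of the broadcast pattern being integrated (all names below are hypothetical placeholders, not this PR's API):

# Hypothetical sketch only -- not the PR's API. Illustrates the general
# publish/consume contract over a shared dictionary such as a Dragon DDict.
import typing as t


class EventPublisher:
    """Placeholder publisher appending events to a shared store."""

    def __init__(self, shared: t.MutableMapping[str, t.Any]) -> None:
        self._shared = shared
        self._shared.setdefault("events", [])

    def publish(self, event: t.Any) -> None:
        # Real implementations broadcast through channels, not a list append;
        # value replacement mimics distributed-dict write semantics
        self._shared["events"] = self._shared["events"] + [event]


class EventConsumer:
    """Placeholder consumer draining events it has not yet seen."""

    def __init__(self, shared: t.MutableMapping[str, t.Any]) -> None:
        self._shared = shared
        self._seen = 0

    def poll(self) -> t.List[t.Any]:
        events = self._shared.get("events", [])
        fresh, self._seen = events[self._seen :], len(events)
        return fresh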


64 files changed: +5328 −1573 lines

conftest.py

Lines changed: 1 addition & 1 deletion
@@ -93,6 +93,7 @@
 test_hostlist = None
 has_aprun = shutil.which("aprun") is not None
 
+
 def get_account() -> str:
     return test_account
 
@@ -227,7 +228,6 @@ def kill_all_test_spawned_processes() -> None:
         print("Not all processes were killed after test")
 
 
-
 def get_hostlist() -> t.Optional[t.List[str]]:
     global test_hostlist
     if not test_hostlist:

doc/changelog.md

Lines changed: 1 addition & 1 deletion
@@ -13,12 +13,12 @@ Jump to:
 
 Description
 
+- Implement asynchronous notifications for shared data
 - Quick bug fix in _validate
 - Add helper methods to MLI classes
 - Update error handling for consistency
 - Parameterize installation of dragon package with `smart build`
 - Update docstrings
-- Implement asynchronous notifications for shared data
 - Filenames conform to snake case
 - Update SmartSim environment variables using new naming convention
 - Refactor `exception_handler`

ex/high_throughput_inference/mock_app.py

Lines changed: 39 additions & 86 deletions
@@ -37,102 +37,35 @@
 
 import argparse
 import io
-import numpy
-import os
-import time
+
 import torch
 
-from mpi4py import MPI
-from smartsim._core.mli.infrastructure.storage.dragon_feature_store import (
-    DragonFeatureStore,
-)
-from smartsim._core.mli.message_handler import MessageHandler
 from smartsim.log import get_logger
-from smartsim._core.utils.timings import PerfTimer
 
 torch.set_num_interop_threads(16)
 torch.set_num_threads(1)
 
 logger = get_logger("App")
 logger.info("Started app")
 
-CHECK_RESULTS_AND_MAKE_ALL_SLOWER = False
+from collections import OrderedDict
 
-class ProtoClient:
-    def __init__(self, timing_on: bool):
-        comm = MPI.COMM_WORLD
-        rank = comm.Get_rank()
-        connect_to_infrastructure()
-        ddict_str = os.environ["_SMARTSIM_INFRA_BACKBONE"]
-        self._ddict = DDict.attach(ddict_str)
-        self._backbone_descriptor = DragonFeatureStore(self._ddict).descriptor
-        to_worker_fli_str = None
-        while to_worker_fli_str is None:
-            try:
-                to_worker_fli_str = self._ddict["to_worker_fli"]
-                self._to_worker_fli = fli.FLInterface.attach(to_worker_fli_str)
-            except KeyError:
-                time.sleep(1)
-        self._from_worker_ch = Channel.make_process_local()
-        self._from_worker_ch_serialized = self._from_worker_ch.serialize()
-        self._to_worker_ch = Channel.make_process_local()
-
-        self.perf_timer: PerfTimer = PerfTimer(debug=False, timing_on=timing_on, prefix=f"a{rank}_")
-
-    def run_model(self, model: bytes | str, batch: torch.Tensor):
-        tensors = [batch.numpy()]
-        self.perf_timer.start_timings("batch_size", batch.shape[0])
-        built_tensor_desc = MessageHandler.build_tensor_descriptor(
-            "c", "float32", list(batch.shape)
-        )
-        self.perf_timer.measure_time("build_tensor_descriptor")
-        if isinstance(model, str):
-            model_arg = MessageHandler.build_model_key(model, self._backbone_descriptor)
-        else:
-            model_arg = MessageHandler.build_model(model, "resnet-50", "1.0")
-        request = MessageHandler.build_request(
-            reply_channel=self._from_worker_ch_serialized,
-            model=model_arg,
-            inputs=[built_tensor_desc],
-            outputs=[],
-            output_descriptors=[],
-            custom_attributes=None,
-        )
-        self.perf_timer.measure_time("build_request")
-        request_bytes = MessageHandler.serialize_request(request)
-        self.perf_timer.measure_time("serialize_request")
-        with self._to_worker_fli.sendh(timeout=None, stream_channel=self._to_worker_ch) as to_sendh:
-            to_sendh.send_bytes(request_bytes)
-            self.perf_timer.measure_time("send_request")
-            for tensor in tensors:
-                to_sendh.send_bytes(tensor.tobytes()) #TODO NOT FAST ENOUGH!!!
-            self.perf_timer.measure_time("send_tensors")
-        with self._from_worker_ch.recvh(timeout=None) as from_recvh:
-            resp = from_recvh.recv_bytes(timeout=None)
-            self.perf_timer.measure_time("receive_response")
-            response = MessageHandler.deserialize_response(resp)
-            self.perf_timer.measure_time("deserialize_response")
-            # list of data blobs? recv depending on the len(response.result.descriptors)?
-            data_blob: bytes = from_recvh.recv_bytes(timeout=None)
-            self.perf_timer.measure_time("receive_tensor")
-            result = torch.from_numpy(
-                numpy.frombuffer(
-                    data_blob,
-                    dtype=str(response.result.descriptors[0].dataType),
-                )
-            )
-            self.perf_timer.measure_time("deserialize_tensor")
+from smartsim.log import get_logger, log_to_file
+from smartsim._core.mli.client.protoclient import ProtoClient
 
-        self.perf_timer.end_timings()
-        return result
+logger = get_logger("App")
 
-    def set_model(self, key: str, model: bytes):
-        self._ddict[key] = model
 
+CHECK_RESULTS_AND_MAKE_ALL_SLOWER = False
 
 
 class ResNetWrapper:
+    """Wrapper around a pre-trained ResNet model."""
     def __init__(self, name: str, model: str):
+        """Initialize the instance.
+
+        :param name: The name to use for the model
+        :param model: The path to the pre-trained PyTorch model"""
         self._model = torch.jit.load(model)
         self._name = name
         buffer = io.BytesIO()
@@ -141,16 +74,28 @@ def __init__(self, name: str, model: str):
         self._serialized_model = buffer.getvalue()
 
     def get_batch(self, batch_size: int = 32):
+        """Create a random batch of data with the correct dimensions to
+        invoke a ResNet model.
+
+        :param batch_size: The desired number of samples to produce
+        :returns: A PyTorch tensor"""
         return torch.randn((batch_size, 3, 224, 224), dtype=torch.float32)
 
     @property
-    def model(self):
+    def model(self) -> bytes:
+        """The content of a model file.
+
+        :returns: The model bytes"""
         return self._serialized_model
 
     @property
-    def name(self):
+    def name(self) -> str:
+        """The name applied to the model.
+
+        :returns: The name"""
         return self._name
 
+
 if __name__ == "__main__":
 
     parser = argparse.ArgumentParser("Mock application")
@@ -166,24 +111,32 @@ def name(self):
     if CHECK_RESULTS_AND_MAKE_ALL_SLOWER:
         # TODO: adapt to non-Nvidia devices
         torch_device = args.device.replace("gpu", "cuda")
-        pt_model = torch.jit.load(io.BytesIO(initial_bytes=(resnet.model))).to(torch_device)
+        pt_model = torch.jit.load(io.BytesIO(initial_bytes=(resnet.model))).to(
+            torch_device
+        )
 
     TOTAL_ITERATIONS = 100
 
-    for log2_bsize in range(args.log_max_batchsize+1):
+    for log2_bsize in range(args.log_max_batchsize + 1):
         b_size: int = 2**log2_bsize
         logger.info(f"Batch size: {b_size}")
-        for iteration_number in range(TOTAL_ITERATIONS + int(b_size==1)):
+        for iteration_number in range(TOTAL_ITERATIONS + int(b_size == 1)):
             logger.info(f"Iteration: {iteration_number}")
             sample_batch = resnet.get_batch(b_size)
             remote_result = client.run_model(resnet.name, sample_batch)
             logger.info(client.perf_timer.get_last("total_time"))
             if CHECK_RESULTS_AND_MAKE_ALL_SLOWER:
                 local_res = pt_model(sample_batch.to(torch_device))
-                err_norm = torch.linalg.vector_norm(torch.flatten(remote_result).to(torch_device)-torch.flatten(local_res), ord=1).cpu()
+                err_norm = torch.linalg.vector_norm(
+                    torch.flatten(remote_result).to(torch_device)
+                    - torch.flatten(local_res),
+                    ord=1,
+                ).cpu()
                 res_norm = torch.linalg.vector_norm(remote_result, ord=1).item()
                 local_res_norm = torch.linalg.vector_norm(local_res, ord=1).item()
-                logger.info(f"Avg norm of error {err_norm.item()/b_size} compared to result norm of {res_norm/b_size}:{local_res_norm/b_size}")
+                logger.info(
+                    f"Avg norm of error {err_norm.item()/b_size} compared to result norm of {res_norm/b_size}:{local_res_norm/b_size}"
+                )
             torch.cuda.synchronize()
 
-    client.perf_timer.print_timings(to_file=True)
+client.perf_timer.print_timings(to_file=True)
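
With the inline class gone, the example drives inference through the packaged client instead. A minimal sketch of the resulting call pattern, assuming the relocated `ProtoClient` keeps the constructor and `run_model` signatures visible in this diff (the `set_model` call is carried over from the deleted inline class and is an assumption about the new module):

# Sketch based on this diff; signatures of the relocated class are assumed.
import io

import torch

from smartsim._core.mli.client.protoclient import ProtoClient

# Serialize a TorchScript model to bytes, as ResNetWrapper does above
scripted = torch.jit.script(torch.nn.Linear(8, 2))
buffer = io.BytesIO()
torch.jit.save(scripted, buffer)

client = ProtoClient(timing_on=True)
client.set_model("toy-model", buffer.getvalue())  # assumed to survive the move

# A str model argument is treated as a feature-store key; bytes are sent inline
batch = torch.randn((32, 8), dtype=torch.float32)
result = client.run_model("toy-model", batch)
client.perf_timer.print_timings(to_file=True)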

ex/high_throughput_inference/standalone_worker_manager.py

Lines changed: 28 additions & 33 deletions
@@ -37,6 +37,7 @@
 from dragon.globalservices.api_setup import connect_to_infrastructure
 from dragon.managed_memory import MemoryPool
 from dragon.utils import b64decode, b64encode
+
 # pylint enable=import-error
 
 # isort: off
@@ -46,33 +47,27 @@
 import base64
 import multiprocessing as mp
 import os
-import pickle
 import socket
-import sys
 import time
 import typing as t
 
 import cloudpickle
-import optparse
-import os
 
 from smartsim._core.entrypoints.service import Service
-from smartsim._core.mli.comm.channel.channel import CommChannelBase
 from smartsim._core.mli.comm.channel.dragon_channel import DragonCommChannel
 from smartsim._core.mli.comm.channel.dragon_fli import DragonFLIChannel
-from smartsim._core.mli.infrastructure.storage.dragon_feature_store import (
-    DragonFeatureStore,
-)
+from smartsim._core.mli.comm.channel.dragon_util import create_local
 from smartsim._core.mli.infrastructure.control.request_dispatcher import (
     RequestDispatcher,
 )
 from smartsim._core.mli.infrastructure.control.worker_manager import WorkerManager
 from smartsim._core.mli.infrastructure.environment_loader import EnvironmentConfigLoader
+from smartsim._core.mli.infrastructure.storage.backbone_feature_store import (
+    BackboneFeatureStore,
+)
 from smartsim._core.mli.infrastructure.storage.dragon_feature_store import (
     DragonFeatureStore,
 )
-from smartsim._core.mli.infrastructure.worker.worker import MachineLearningWorkerBase
-
 from smartsim.log import get_logger
 
 logger = get_logger("Worker Manager Entry Point")
@@ -85,7 +80,6 @@
 logger.info(f"CPUS: {os.cpu_count()}")
 
 
-
 def service_as_dragon_proc(
     service: Service, cpu_affinity: list[int], gpu_affinity: list[int]
 ) -> dragon_process.Process:
@@ -108,8 +102,6 @@ def service_as_dragon_proc(
     )
 
 
-
-
 if __name__ == "__main__":
     parser = argparse.ArgumentParser("Worker Manager")
     parser.add_argument(
@@ -143,27 +135,26 @@ def service_as_dragon_proc(
     args = parser.parse_args()
 
     connect_to_infrastructure()
-    ddict_str = os.environ["_SMARTSIM_INFRA_BACKBONE"]
-    ddict = DDict.attach(ddict_str)
+    ddict_str = os.environ[BackboneFeatureStore.MLI_BACKBONE]
+
+    backbone = BackboneFeatureStore.from_descriptor(ddict_str)
 
-    to_worker_channel = Channel.make_process_local()
+    to_worker_channel = create_local()
     to_worker_fli = fli.FLInterface(main_ch=to_worker_channel, manager_ch=None)
-    to_worker_fli_serialized = to_worker_fli.serialize()
-    ddict["to_worker_fli"] = to_worker_fli_serialized
+    to_worker_fli_comm_ch = DragonFLIChannel(to_worker_fli)
+
+    backbone.worker_queue = to_worker_fli_comm_ch.descriptor
+
+    os.environ[BackboneFeatureStore.MLI_WORKER_QUEUE] = to_worker_fli_comm_ch.descriptor
+    os.environ[BackboneFeatureStore.MLI_BACKBONE] = backbone.descriptor
 
     arg_worker_type = cloudpickle.loads(
        base64.b64decode(args.worker_class.encode("ascii"))
    )
 
-    dfs = DragonFeatureStore(ddict)
-    comm_channel = DragonFLIChannel(to_worker_fli_serialized)
-
-    descriptor = base64.b64encode(to_worker_fli_serialized).decode("utf-8")
-    os.environ["_SMARTSIM_REQUEST_QUEUE"] = descriptor
-
     config_loader = EnvironmentConfigLoader(
         featurestore_factory=DragonFeatureStore.from_descriptor,
-        callback_factory=DragonCommChannel,
+        callback_factory=DragonCommChannel.from_descriptor,
         queue_factory=DragonFLIChannel.from_descriptor,
     )
 
@@ -178,7 +169,7 @@ def service_as_dragon_proc(
     worker_device = args.device
     for wm_idx in range(args.num_workers):
 
-         worker_manager = WorkerManager(
+        worker_manager = WorkerManager(
            config_loader=config_loader,
            worker_type=arg_worker_type,
            as_service=True,
@@ -196,21 +187,25 @@ def service_as_dragon_proc(
     # the GPU-to-CPU mapping is taken from the nvidia-smi tool
     # TODO can this be computed on the fly?
     gpu_to_cpu_aff: dict[int, list[int]] = {}
-    gpu_to_cpu_aff[0] = list(range(48,64)) + list(range(112,128))
-    gpu_to_cpu_aff[1] = list(range(32,48)) + list(range(96,112))
-    gpu_to_cpu_aff[2] = list(range(16,32)) + list(range(80,96))
-    gpu_to_cpu_aff[3] = list(range(0,16)) + list(range(64,80))
+    gpu_to_cpu_aff[0] = list(range(48, 64)) + list(range(112, 128))
+    gpu_to_cpu_aff[1] = list(range(32, 48)) + list(range(96, 112))
+    gpu_to_cpu_aff[2] = list(range(16, 32)) + list(range(80, 96))
+    gpu_to_cpu_aff[3] = list(range(0, 16)) + list(range(64, 80))
 
     worker_manager_procs = []
     for worker_idx in range(args.num_workers):
         wm_cpus = len(gpu_to_cpu_aff[worker_idx]) - 4
         wm_affinity = gpu_to_cpu_aff[worker_idx][:wm_cpus]
         disp_affinity.extend(gpu_to_cpu_aff[worker_idx][wm_cpus:])
-        worker_manager_procs.append(service_as_dragon_proc(
+        worker_manager_procs.append(
+            service_as_dragon_proc(
                worker_manager, cpu_affinity=wm_affinity, gpu_affinity=[worker_idx]
-        ))
+            )
+        )
 
-    dispatcher_proc = service_as_dragon_proc(dispatcher, cpu_affinity=disp_affinity, gpu_affinity=[])
+    dispatcher_proc = service_as_dragon_proc(
+        dispatcher, cpu_affinity=disp_affinity, gpu_affinity=[]
+    )
 
     # TODO: use ProcessGroup and restart=True?
     all_procs = [dispatcher_proc, *worker_manager_procs]
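
Note the handshake change in this file: instead of storing a pickled FLI descriptor under the ad-hoc `to_worker_fli` key (which clients had to poll for, as the deleted `ProtoClient` did), the entry point now records the queue descriptor on the backbone and exports both descriptors through well-known environment variables. A sketch of the consumer side of that contract, using the same factories wired into `EnvironmentConfigLoader` above (exactly how the loader resolves these variables is an assumption):

import os

from smartsim._core.mli.comm.channel.dragon_fli import DragonFLIChannel
from smartsim._core.mli.infrastructure.storage.backbone_feature_store import (
    BackboneFeatureStore,
)

# Processes spawned by this entry point inherit both descriptors, so they can
# re-attach directly instead of sleeping in a lookup loop on the dictionary.
backbone = BackboneFeatureStore.from_descriptor(
    os.environ[BackboneFeatureStore.MLI_BACKBONE]
)
worker_queue = DragonFLIChannel.from_descriptor(
    os.environ[BackboneFeatureStore.MLI_WORKER_QUEUE]
)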

smartsim/_core/_cli/scripts/dragon_install.py

Lines changed: 6 additions & 6 deletions
@@ -57,7 +57,7 @@ def __init__(
     def _check(self) -> None:
         """Perform validation of this instance
 
-        :raises: ValueError if any value fails validation"""
+        :raises ValueError: if any value fails validation"""
         if not self.repo_name or len(self.repo_name.split("/")) != 2:
             raise ValueError(
                 f"Invalid dragon repository name. Example: `dragonhpc/dragon`"
@@ -95,13 +95,13 @@ def get_auth_token(request: DragonInstallRequest) -> t.Optional[Token]:
 def create_dotenv(dragon_root_dir: pathlib.Path, dragon_version: str) -> None:
     """Create a .env file with required environment variables for the Dragon runtime"""
     dragon_root = str(dragon_root_dir)
-    dragon_inc_dir = str(dragon_root_dir / "include")
-    dragon_lib_dir = str(dragon_root_dir / "lib")
-    dragon_bin_dir = str(dragon_root_dir / "bin")
+    dragon_inc_dir = dragon_root + "/include"
+    dragon_lib_dir = dragon_root + "/lib"
+    dragon_bin_dir = dragon_root + "/bin"
 
     dragon_vars = {
         "DRAGON_BASE_DIR": dragon_root,
-        "DRAGON_ROOT_DIR": dragon_root,  # note: same as base_dir
+        "DRAGON_ROOT_DIR": dragon_root,
         "DRAGON_INCLUDE_DIR": dragon_inc_dir,
         "DRAGON_LIB_DIR": dragon_lib_dir,
         "DRAGON_VERSION": dragon_version,
@@ -286,7 +286,7 @@ def retrieve_asset(
     :param request: details of a request for the installation of the dragon package
     :param asset: GitHub release asset to retrieve
     :returns: path to the directory containing the extracted release asset
-    :raises: SmartSimCLIActionCancelled if the asset cannot be downloaded or extracted
+    :raises SmartSimCLIActionCancelled: if the asset cannot be downloaded or extracted
     """
     download_dir = request.working_dir / str(asset.id)
 
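
The docstring edits here follow the same pattern as the other files: the exception type moves into the field name (`:raises ValueError: ...` instead of `:raises: ValueError ...`), which is the form Sphinx's info field lists parse, so rendered docs attach the description to the exception type. For example:

# Recognized by Sphinx: the exception type is part of the field name.
def _check(self) -> None:
    """Perform validation of this instance

    :raises ValueError: if any value fails validation
    """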
