Commit 5d85995

Queue-based Worker Manager (#647)
This PR adds the `RequestDispatcher` to the MLI. The `RequestDispatcher` batches incoming inference requests and dispatches the batches to `WorkerManager`s. [ committed by @al-rigazzi ] [ reviewed by @mellis13 @ankona @AlyssaCote ]
1 parent 6d5518b · commit 5d85995

26 files changed (+2426 / -655 lines)
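
For orientation, here is a minimal sketch of the queue-based batching pattern the description refers to. It is illustrative only: the names `dispatch_batches` and `submit_batch` and the queue-draining logic are assumptions, not the MLI's actual `RequestDispatcher`; only the `batch_size`/`batch_timeout` knobs mirror the flags this commit adds.

```python
# Illustrative sketch only -- not the MLI RequestDispatcher. It shows the
# timeout-bounded batching idea behind the --batch_size/--batch_timeout flags.
import queue
import time
import typing as t


def dispatch_batches(
    requests: "queue.Queue[t.Any]",
    submit_batch: t.Callable[[list], None],  # hypothetical hook to a WorkerManager
    batch_size: int,
    batch_timeout: float,
) -> None:
    """Group requests into batches of up to batch_size, flushing a partial
    batch once batch_timeout seconds have elapsed since the first request."""
    while True:
        batch = [requests.get()]  # block until at least one request arrives
        deadline = time.perf_counter() + batch_timeout
        while len(batch) < batch_size:
            remaining = deadline - time.perf_counter()
            if remaining <= 0:
                break
            try:
                batch.append(requests.get(timeout=remaining))
            except queue.Empty:
                break
        submit_batch(batch)  # hand the completed batch to a worker manager
```

With `batch_timeout=0.0`, as the driver below configures it, partial batches are flushed without waiting for stragglers.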

doc/changelog.md

Lines changed: 1 addition & 0 deletions
```diff
@@ -13,6 +13,7 @@ Jump to:
 
 Description
 
+- Add RequestDispatcher and the possibility of batching inference requests
 - Enable hostname selection for dragon tasks
 - Remove pydantic dependency from MLI code
 - Update MLI environment variables using new naming convention
```
Lines changed: 25 additions & 9 deletions
```diff
@@ -1,19 +1,21 @@
-import argparse
 import os
 import base64
 import cloudpickle
 import sys
 from smartsim import Experiment
 from smartsim._core.mli.infrastructure.worker.torch_worker import TorchWorker
 from smartsim.status import TERMINAL_STATUSES
+from smartsim.settings import DragonRunSettings
 import time
 import typing as t
 
-device = "gpu"
+DEVICE = "gpu"
+NUM_RANKS = 4
+NUM_WORKERS = 1
 filedir = os.path.dirname(__file__)
 worker_manager_script_name = os.path.join(filedir, "standalone_workermanager.py")
 app_script_name = os.path.join(filedir, "mock_app.py")
-model_name = os.path.join(filedir, f"resnet50.{device.upper()}.pt")
+model_name = os.path.join(filedir, f"resnet50.{DEVICE}.pt")
 
 transport: t.Literal["hsta", "tcp"] = "hsta"
 
@@ -25,37 +27,51 @@
 
 torch_worker_str = base64.b64encode(cloudpickle.dumps(TorchWorker)).decode("ascii")
 
-worker_manager_rs = exp.create_run_settings(
+worker_manager_rs: DragonRunSettings = exp.create_run_settings(
     sys.executable,
     [
         worker_manager_script_name,
         "--device",
-        device,
+        DEVICE,
         "--worker_class",
         torch_worker_str,
+        "--batch_size",
+        str(NUM_RANKS//NUM_WORKERS),
+        "--batch_timeout",
+        str(0.00),
+        "--num_workers",
+        str(NUM_WORKERS)
     ],
 )
+
+aff = []
+
+worker_manager_rs.set_cpu_affinity(aff)
+
 worker_manager = exp.create_model("worker_manager", run_settings=worker_manager_rs)
 worker_manager.attach_generator_files(to_copy=[worker_manager_script_name])
 
-app_rs = exp.create_run_settings(
+app_rs: DragonRunSettings = exp.create_run_settings(
     sys.executable,
-    exe_args=[app_script_name, "--device", device],
+    exe_args=[app_script_name, "--device", DEVICE, "--log_max_batchsize", str(6)],
 )
+app_rs.set_tasks_per_node(NUM_RANKS)
+
+
 app = exp.create_model("app", run_settings=app_rs)
 app.attach_generator_files(to_copy=[app_script_name], to_symlink=[model_name])
 
-
 exp.generate(worker_manager, app, overwrite=True)
 exp.start(worker_manager, app, block=False)
 
 while True:
     if exp.get_status(app)[0] in TERMINAL_STATUSES:
+        time.sleep(10)
         exp.stop(worker_manager)
         break
     if exp.get_status(worker_manager)[0] in TERMINAL_STATUSES:
+        time.sleep(10)
        exp.stop(app)
         break
-    time.sleep(5)
 
 print("Exiting.")
```

ex/high_throughput_inference/mock_app.py

Lines changed: 56 additions & 80 deletions
```diff
@@ -41,20 +41,27 @@
 import os
 import time
 import torch
-import numbers
 
-from collections import OrderedDict
+from mpi4py import MPI
 from smartsim._core.mli.infrastructure.storage.dragonfeaturestore import (
     DragonFeatureStore,
 )
 from smartsim._core.mli.message_handler import MessageHandler
 from smartsim.log import get_logger
+from smartsim._core.utils.timings import PerfTimer
+
+torch.set_num_interop_threads(16)
+torch.set_num_threads(1)
 
 logger = get_logger("App")
+logger.info("Started app")
 
+CHECK_RESULTS_AND_MAKE_ALL_SLOWER = False
 
 class ProtoClient:
     def __init__(self, timing_on: bool):
+        comm = MPI.COMM_WORLD
+        rank = comm.Get_rank()
         connect_to_infrastructure()
         ddict_str = os.environ["_SMARTSIM_INFRA_BACKBONE"]
         self._ddict = DDict.attach(ddict_str)
```
```diff
@@ -70,61 +77,15 @@ def __init__(self, timing_on: bool):
         self._from_worker_ch_serialized = self._from_worker_ch.serialize()
         self._to_worker_ch = Channel.make_process_local()
 
-        self._start = None
-        self._interm = None
-        self._timings: OrderedDict[str, list[numbers.Number]] = OrderedDict()
-        self._timing_on = timing_on
-
-    def _add_label_to_timings(self, label: str):
-        if label not in self._timings:
-            self._timings[label] = []
-
-    @staticmethod
-    def _format_number(number: numbers.Number):
-        return f"{number:0.4e}"
-
-    def start_timings(self, batch_size: int):
-        if self._timing_on:
-            self._add_label_to_timings("batch_size")
-            self._timings["batch_size"].append(batch_size)
-            self._start = time.perf_counter()
-            self._interm = time.perf_counter()
-
-    def end_timings(self):
-        if self._timing_on:
-            self._add_label_to_timings("total_time")
-            self._timings["total_time"].append(
-                self._format_number(time.perf_counter() - self._start)
-            )
-
-    def measure_time(self, label: str):
-        if self._timing_on:
-            self._add_label_to_timings(label)
-            self._timings[label].append(
-                self._format_number(time.perf_counter() - self._interm)
-            )
-            self._interm = time.perf_counter()
-
-    def print_timings(self, to_file: bool = False):
-        print(" ".join(self._timings.keys()))
-        value_array = numpy.array(
-            [value for value in self._timings.values()], dtype=float
-        )
-        value_array = numpy.transpose(value_array)
-        for i in range(value_array.shape[0]):
-            print(" ".join(self._format_number(value) for value in value_array[i]))
-        if to_file:
-            numpy.save("timings.npy", value_array)
-            numpy.savetxt("timings.txt", value_array)
+        self.perf_timer: PerfTimer = PerfTimer(debug=False, timing_on=timing_on, prefix=f"a{rank}_")
 
     def run_model(self, model: bytes | str, batch: torch.Tensor):
         tensors = [batch.numpy()]
-        self.start_timings(batch.shape[0])
+        self.perf_timer.start_timings("batch_size", batch.shape[0])
         built_tensor_desc = MessageHandler.build_tensor_descriptor(
             "c", "float32", list(batch.shape)
         )
-        self.measure_time("build_tensor_descriptor")
-        built_model = None
+        self.perf_timer.measure_time("build_tensor_descriptor")
         if isinstance(model, str):
             model_arg = MessageHandler.build_model_key(model, self._backbone_descriptor)
         else:
```
```diff
@@ -137,39 +98,39 @@ def run_model(self, model: bytes | str, batch: torch.Tensor):
             output_descriptors=[],
             custom_attributes=None,
         )
-        self.measure_time("build_request")
+        self.perf_timer.measure_time("build_request")
         request_bytes = MessageHandler.serialize_request(request)
-        self.measure_time("serialize_request")
-        with self._to_worker_fli.sendh(
-            timeout=None, stream_channel=self._to_worker_ch
-        ) as to_sendh:
+        self.perf_timer.measure_time("serialize_request")
+        with self._to_worker_fli.sendh(timeout=None, stream_channel=self._to_worker_ch) as to_sendh:
             to_sendh.send_bytes(request_bytes)
-            for t in tensors:
-                to_sendh.send_bytes(t.tobytes())  # TODO NOT FAST ENOUGH!!!
-                # to_sendh.send_bytes(bytes(t.data))
-        logger.info(f"Message size: {len(request_bytes)} bytes")
-
-        self.measure_time("send")
+            self.perf_timer.measure_time("send_request")
+            for tensor in tensors:
+                to_sendh.send_bytes(tensor.tobytes())  # TODO NOT FAST ENOUGH!!!
+            self.perf_timer.measure_time("send_tensors")
         with self._from_worker_ch.recvh(timeout=None) as from_recvh:
             resp = from_recvh.recv_bytes(timeout=None)
-            self.measure_time("receive")
+            self.perf_timer.measure_time("receive_response")
             response = MessageHandler.deserialize_response(resp)
-            self.measure_time("deserialize_response")
+            self.perf_timer.measure_time("deserialize_response")
             # list of data blobs? recv depending on the len(response.result.descriptors)?
-            data_blob = from_recvh.recv_bytes(timeout=None)
-            result = numpy.frombuffer(
-                data_blob,
-                dtype=str(response.result.descriptors[0].dataType),
+            data_blob: bytes = from_recvh.recv_bytes(timeout=None)
+            self.perf_timer.measure_time("receive_tensor")
+            result = torch.from_numpy(
+                numpy.frombuffer(
+                    data_blob,
+                    dtype=str(response.result.descriptors[0].dataType),
+                )
             )
-            self.measure_time("deserialize_tensor")
+            self.perf_timer.measure_time("deserialize_tensor")
 
-        self.end_timings()
+        self.perf_timer.end_timings()
         return result
 
     def set_model(self, key: str, model: bytes):
         self._ddict[key] = model
 
 
+
 class ResNetWrapper:
     def __init__(self, name: str, model: str):
         self._model = torch.jit.load(model)
```
```diff
@@ -190,24 +151,39 @@ def model(self):
     def name(self):
         return self._name
 
-
 if __name__ == "__main__":
 
     parser = argparse.ArgumentParser("Mock application")
-    parser.add_argument("--device", default="cpu")
+    parser.add_argument("--device", default="cpu", type=str)
+    parser.add_argument("--log_max_batchsize", default=8, type=int)
     args = parser.parse_args()
 
-    resnet = ResNetWrapper("resnet50", f"resnet50.{args.device.upper()}.pt")
+    resnet = ResNetWrapper("resnet50", f"resnet50.{args.device}.pt")
 
     client = ProtoClient(timing_on=True)
     client.set_model(resnet.name, resnet.model)
 
-    total_iterations = 100
+    if CHECK_RESULTS_AND_MAKE_ALL_SLOWER:
+        # TODO: adapt to non-Nvidia devices
+        torch_device = args.device.replace("gpu", "cuda")
+        pt_model = torch.jit.load(io.BytesIO(initial_bytes=(resnet.model))).to(torch_device)
 
-    for batch_size in [1, 2, 4, 8, 16, 32, 64, 128]:
-        logger.info(f"Batch size: {batch_size}")
-        for iteration_number in range(total_iterations + int(batch_size == 1)):
-            logger.info(f"Iteration: {iteration_number}")
-            client.run_model(resnet.name, resnet.get_batch(batch_size))
+    TOTAL_ITERATIONS = 100
 
-    client.print_timings(to_file=True)
+    for log2_bsize in range(args.log_max_batchsize+1):
+        b_size: int = 2**log2_bsize
+        logger.info(f"Batch size: {b_size}")
+        for iteration_number in range(TOTAL_ITERATIONS + int(b_size==1)):
+            logger.info(f"Iteration: {iteration_number}")
+            sample_batch = resnet.get_batch(b_size)
+            remote_result = client.run_model(resnet.name, sample_batch)
+            logger.info(client.perf_timer.get_last("total_time"))
+            if CHECK_RESULTS_AND_MAKE_ALL_SLOWER:
+                local_res = pt_model(sample_batch.to(torch_device))
+                err_norm = torch.linalg.vector_norm(torch.flatten(remote_result).to(torch_device)-torch.flatten(local_res), ord=1).cpu()
+                res_norm = torch.linalg.vector_norm(remote_result, ord=1).item()
+                local_res_norm = torch.linalg.vector_norm(local_res, ord=1).item()
+                logger.info(f"Avg norm of error {err_norm.item()/b_size} compared to result norm of {res_norm/b_size}:{local_res_norm/b_size}")
+                torch.cuda.synchronize()
+
+    client.perf_timer.print_timings(to_file=True)
```
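
Stripped of the messaging code, the `PerfTimer` calls in `run_model` follow one pattern per request. The condensed sketch below shows that usage; the method semantics are inferred from the call sites in this diff, not from `smartsim._core.utils.timings` itself.

```python
# Condensed usage pattern of PerfTimer as exercised by run_model above;
# method semantics are inferred from the call sites in this diff.
from smartsim._core.utils.timings import PerfTimer

perf_timer = PerfTimer(debug=False, timing_on=True, prefix="a0_")

perf_timer.start_timings("batch_size", 32)    # open a timing row, tagged with the batch size
# ... build and send the request ...
perf_timer.measure_time("serialize_request")  # record time elapsed since the previous mark
# ... receive and decode the response ...
perf_timer.end_timings()                      # close the row, recording "total_time"

print(perf_timer.get_last("total_time"))      # latest value recorded under a label
perf_timer.print_timings(to_file=True)        # dump all rows, optionally to disk
```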

ex/high_throughput_inference/mock_app_redis.py

Lines changed: 15 additions & 13 deletions
```diff
@@ -29,7 +29,9 @@
 import numpy
 import time
 import torch
+from mpi4py import MPI
 from smartsim.log import get_logger
+from smartsim._core.utils.timings import PerfTimer
 from smartredis import Client
 
 logger = get_logger("App")
```
```diff
@@ -56,6 +58,9 @@ def name(self):
 
 if __name__ == "__main__":
 
+    comm = MPI.COMM_WORLD
+    rank = comm.Get_rank()
+
     parser = argparse.ArgumentParser("Mock application")
     parser.add_argument("--device", default="cpu")
     args = parser.parse_args()
```
```diff
@@ -65,24 +70,21 @@ def name(self):
     client = Client(cluster=False, address=None)
     client.set_model(resnet.name, resnet.model, backend='TORCH', device=args.device.upper())
 
+    perf_timer: PerfTimer = PerfTimer(debug=False, timing_on=timing_on, prefix=f"redis{rank}_")
+
     total_iterations = 100
     timings=[]
     for batch_size in [1, 2, 4, 8, 16, 32, 64, 128]:
         logger.info(f"Batch size: {batch_size}")
         for iteration_number in range(total_iterations + int(batch_size==1)):
-            timing = [batch_size]
+            perf_timer.start_timings("batch_size", batch_size)
             logger.info(f"Iteration: {iteration_number}")
-            start = time.perf_counter()
-            client.put_tensor(name="batch", data=resnet.get_batch(batch_size).numpy())
-            client.run_model(name=resnet.name, inputs=["batch"], outputs=["result"])
-            result = client.get_tensor(name="result")
-            end = time.perf_counter()
-            timing.append(end-start)
-            timings.append(timing)
-
+            input_name = f"batch_{rank}"
+            output_name = f"result_{rank}"
+            client.put_tensor(name=input_name, data=resnet.get_batch(batch_size).numpy())
+            client.run_model(name=resnet.name, inputs=[input_name], outputs=[output_name])
+            result = client.get_tensor(name=output_name)
+            perf_timer.end_timings()
 
 
-    timings_np = numpy.asarray(timings)
-    numpy.save("timings.npy", timings_np)
-    for timing in timings:
-        print(" ".join(str(t) for t in timing))
+    perf_timer.print_timings(True)
```
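
The switch to rank-suffixed tensor names is what keeps this benchmark correct at more than one task: all MPI ranks share a single database, so an unsuffixed key such as `"batch"` would be overwritten concurrently. The naming pattern in isolation (a sketch):

```python
# Sketch of the rank-unique key scheme used above: ranks share one database,
# so keys are disambiguated per rank to avoid overwriting each other.
from mpi4py import MPI

rank = MPI.COMM_WORLD.Get_rank()
input_name = f"batch_{rank}"    # e.g. "batch_0", "batch_1", ...
output_name = f"result_{rank}"  # each rank reads back only its own result
```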

ex/high_throughput_inference/redis_driver.py

Lines changed: 8 additions & 7 deletions
```diff
@@ -29,23 +29,24 @@
 from smartsim import Experiment
 from smartsim.status import TERMINAL_STATUSES
 import time
-import typing as t
 
-device = "gpu"
+DEVICE = "gpu"
 filedir = os.path.dirname(__file__)
 app_script_name = os.path.join(filedir, "mock_app_redis.py")
-model_name = os.path.join(filedir, f"resnet50.{device.upper()}.pt")
+model_name = os.path.join(filedir, f"resnet50.{DEVICE}.pt")
 
 
-exp_path = os.path.join(filedir, "redis_ai")
+exp_path = os.path.join(filedir, "redis_ai_multi")
 os.makedirs(exp_path, exist_ok=True)
-exp = Experiment("redis_ai", launcher="slurm", exp_path=exp_path)
+exp = Experiment("redis_ai_multi", launcher="slurm", exp_path=exp_path)
 
 db = exp.create_database(interface="hsn0")
 
-app_rs = exp.create_run_settings(sys.executable, exe_args = [app_script_name, "--device", device])
+app_rs = exp.create_run_settings(
+    sys.executable, exe_args = [app_script_name, "--device", DEVICE]
+)
 app_rs.set_nodes(1)
-app_rs.set_tasks(1)
+app_rs.set_tasks(4)
 app = exp.create_model("app", run_settings=app_rs)
 app.attach_generator_files(to_copy=[app_script_name], to_symlink=[model_name])
```
