Skip to content

Commit 170c9ea

Browse files
committed
fix send-multiple items behavior with no sender supplied FLI factory
1 parent 608d6bd commit 170c9ea

File tree

5 files changed

+59
-35
lines changed

5 files changed

+59
-35
lines changed

smartsim/_core/mli/comm/channel/dragon_fli.py

Lines changed: 22 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,6 @@
2626

2727
# isort: off
2828
from dragon import fli
29-
import dragon.channels as dch
3029

3130
# isort: on
3231

@@ -59,9 +58,6 @@ def __init__(
5958

6059
self._fli = fli_
6160
"""The underlying dragon FLInterface used by this CommChannel for communications"""
62-
self._channel: t.Optional["dch.Channel"] = None
63-
"""The underlying dragon Channel used by a sender-side DragonFLIChannel
64-
to attach to the main FLI channel"""
6561
self._buffer_size: int = buffer_size
6662
"""Maximum number of messages that can be buffered before sending"""
6763

@@ -73,18 +69,36 @@ def send(self, value: bytes, timeout: float = 0.001) -> None:
7369
:raises SmartSimError: If sending message fails
7470
"""
7571
try:
76-
if self._channel is None:
77-
self._channel = drg_util.create_local(self._buffer_size)
72+
channel = drg_util.create_local(self._buffer_size)
7873

79-
with self._fli.sendh(timeout=None, stream_channel=self._channel) as sendh:
74+
with self._fli.sendh(timeout=None, stream_channel=channel) as sendh:
8075
sendh.send_bytes(value, timeout=timeout)
8176
logger.debug(f"DragonFLIChannel {self.descriptor} sent message")
8277
except Exception as e:
83-
self._channel = None
8478
raise SmartSimError(
8579
f"Error sending via DragonFLIChannel {self.descriptor}"
8680
) from e
8781

82+
def send_multiple(self, values: t.Sequence[bytes], timeout: float = 0.001) -> None:
83+
"""Send a message through the underlying communication channel.
84+
85+
:param values: The values to send
86+
:param timeout: Maximum time to wait (in seconds) for messages to send
87+
:raises SmartSimError: If sending message fails
88+
"""
89+
try:
90+
channel = drg_util.create_local(self._buffer_size)
91+
92+
with self._fli.sendh(timeout=None, stream_channel=channel) as sendh:
93+
for value in values:
94+
sendh.send_bytes(value)
95+
logger.debug(f"DragonFLIChannel {self.descriptor} sent message")
96+
except Exception as e:
97+
self._channel = None
98+
raise SmartSimError(
99+
f"Error sending via DragonFLIChannel {self.descriptor} {e}"
100+
) from e
101+
88102
def recv(self, timeout: float = 0.001) -> t.List[bytes]:
89103
"""Receives message(s) through the underlying communication channel.
90104

smartsim/_core/mli/infrastructure/control/request_dispatcher.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -371,6 +371,7 @@ def _on_iteration(self) -> None:
371371
None,
372372
)
373373

374+
logger.debug(f"Dispatcher is processing {len(bytes_list)} messages")
374375
request_bytes = bytes_list[0]
375376
tensor_bytes_list = bytes_list[1:]
376377
self._perf_timer.start_timings()

tests/dragon/conftest.py

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -50,9 +50,10 @@
5050
from smartsim._core.mli.infrastructure.storage.backbone_feature_store import (
5151
BackboneFeatureStore,
5252
)
53-
from smartsim._core.mli.infrastructure.storage.dragon_feature_store import (
54-
DragonFeatureStore,
55-
)
53+
from smartsim.log import get_logger
54+
55+
logger = get_logger(__name__)
56+
msg_pump_path = pathlib.Path(__file__).parent / "utils" / "msg_pump.py"
5657

5758
class MsgPumpRequest(t.NamedTuple):
5859
"""Fields required for starting a simulated inference request producer."""
@@ -116,17 +117,22 @@ def run_message_pump(request: MsgPumpRequest) -> subprocess.Popen:
116117
:param request: A request containing all parameters required to
117118
invoke the message pump entrypoint
118119
:returns: The Popen object for the subprocess that was started"""
119-
# <smartsim_dir>/tests/dragon/utils/msg_pump.py
120-
msg_pump_script = "tests/dragon/utils/msg_pump.py"
121-
msg_pump_path = pathlib.Path(__file__).parent / msg_pump_script
120+
assert request.backbone_descriptor
121+
assert request.callback_descriptor
122+
assert request.work_queue_descriptor
122123

124+
# <smartsim_dir>/tests/dragon/utils/msg_pump.py
123125
cmd = [sys.executable, str(msg_pump_path.absolute()), *request.as_command()]
126+
logger.info(f"Executing msg_pump with command: {cmd}")
124127

125128
popen = subprocess.Popen(
126129
args=cmd,
127130
stdout=subprocess.PIPE,
128131
stderr=subprocess.PIPE,
129132
)
133+
134+
assert popen is not None
135+
assert popen.returncode is None
130136
return popen
131137

132138
return run_message_pump

tests/dragon/test_request_dispatcher.py

Lines changed: 7 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,6 @@
7373
from smartsim.log import get_logger
7474

7575
logger = get_logger(__name__)
76-
mock_msg_pump_path = pathlib.Path(__file__).parent / "utils" / "msg_pump.py"
7776
_MsgPumpFactory = t.Callable[[conftest.MsgPumpRequest], sp.Popen]
7877

7978
# The tests in this file belong to the dragon group
@@ -129,8 +128,8 @@ def test_request_dispatcher(
129128
)
130129

131130
request_dispatcher._on_start()
132-
pump_processes: t.List[sp.Popen] = []
133131

132+
# put some messages into the back queue for the dispatcher to pickup
134133
for i in range(num_iterations):
135134
batch: t.Optional[RequestBatch] = None
136135
mem_allocs = []
@@ -149,18 +148,22 @@ def test_request_dispatcher(
149148
)
150149

151150
msg_pump = msg_pump_factory(request)
152-
pump_processes.append(msg_pump)
151+
152+
assert msg_pump is not None, "Msg Pump Process Creation Failed"
153+
assert msg_pump.wait() == 0
153154

154155
time.sleep(1)
155156

156-
for _ in range(200):
157+
for i in range(15):
157158
try:
158159
request_dispatcher._on_iteration()
159160
batch = request_dispatcher.task_queue.get(timeout=0.1)
160161
break
161162
except Empty:
163+
logger.warning(f"Task queue is empty on iteration {i}")
162164
continue
163165
except Exception as exc:
166+
logger.error(f"Task queue exception on iteration {i}")
164167
raise exc
165168

166169
assert batch is not None
@@ -219,13 +222,6 @@ def test_request_dispatcher(
219222
assert model_key not in request_dispatcher._active_queues
220223
assert model_key not in request_dispatcher._queues
221224

222-
msg_pump.wait()
223-
224-
for msg_pump in pump_processes:
225-
if msg_pump.returncode is not None:
226-
continue
227-
msg_pump.terminate()
228-
229225
# Try to remove the dispatcher and free the memory
230226
del request_dispatcher
231227
gc.collect()

tests/dragon/utils/msg_pump.py

Lines changed: 17 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@
2727
import io
2828
import logging
2929
import pathlib
30-
import time
30+
import sys
3131
import typing as t
3232

3333
import pytest
@@ -44,7 +44,6 @@
4444

4545
# isort: on
4646

47-
from smartsim._core.mli.comm.channel.dragon_channel import DragonCommChannel
4847
from smartsim._core.mli.comm.channel.dragon_fli import DragonFLIChannel
4948
from smartsim._core.mli.infrastructure.storage.backbone_feature_store import (
5049
BackboneFeatureStore,
@@ -124,6 +123,8 @@ def mock_messages(
124123
feature_store = BackboneFeatureStore.from_descriptor(fs_descriptor)
125124
request_dispatcher_queue = DragonFLIChannel.from_descriptor(dispatch_fli_descriptor)
126125

126+
feature_store[model_key] = load_model()
127+
127128
for iteration_number in range(2):
128129
logged_iteration = offset + iteration_number
129130
logger.debug(f"Sending mock message {logged_iteration}")
@@ -163,9 +164,9 @@ def mock_messages(
163164

164165
logger.info(
165166
f"Retrieving {iteration_number} from callback channel: {callback_descriptor}"
166-
)
167-
callback_channel = DragonCommChannel.from_descriptor(callback_descriptor)
168167

168+
# send the header & body together so they arrive together
169+
request_dispatcher_queue.send_multiple([request_bytes, tensor.tobytes()])
169170
# Results will be empty. The test pulls messages off the queue before they
170171
# can be serviced by a worker. Just ensure the callback channel works.
171172
results = callback_channel.recv(timeout=0.1)
@@ -185,9 +186,15 @@ def mock_messages(
185186

186187
args = args.parse_args()
187188

188-
mock_messages(
189-
args.dispatch_fli_descriptor,
190-
args.fs_descriptor,
191-
args.parent_iteration,
192-
args.callback_descriptor,
193-
)
189+
try:
190+
mock_messages(
191+
args.dispatch_fli_descriptor,
192+
args.fs_descriptor,
193+
args.parent_iteration,
194+
args.callback_descriptor,
195+
)
196+
except Exception as ex:
197+
logger.exception("The message pump did not execute properly")
198+
sys.exit(100)
199+
200+
sys.exit(0)

0 commit comments

Comments
 (0)