diff --git a/doc/changelog.md b/doc/changelog.md index 964e62b49d..ac09ecf604 100644 --- a/doc/changelog.md +++ b/doc/changelog.md @@ -13,6 +13,7 @@ Jump to: Description +- Add RequestDispatcher and the possibility of batching inference requests - Enable hostname selection for dragon tasks - Remove pydantic dependency from MLI code - Update MLI environment variables using new naming convention diff --git a/ex/high_throughput_inference/mli_driver.py b/ex/high_throughput_inference/mli_driver.py index 0cf87ef2e2..807a70b219 100644 --- a/ex/high_throughput_inference/mli_driver.py +++ b/ex/high_throughput_inference/mli_driver.py @@ -1,4 +1,3 @@ -import argparse import os import base64 import cloudpickle @@ -6,14 +5,17 @@ from smartsim import Experiment from smartsim._core.mli.infrastructure.worker.torch_worker import TorchWorker from smartsim.status import TERMINAL_STATUSES +from smartsim.settings import DragonRunSettings import time import typing as t -device = "gpu" +DEVICE = "gpu" +NUM_RANKS = 4 +NUM_WORKERS = 1 filedir = os.path.dirname(__file__) worker_manager_script_name = os.path.join(filedir, "standalone_workermanager.py") app_script_name = os.path.join(filedir, "mock_app.py") -model_name = os.path.join(filedir, f"resnet50.{device.upper()}.pt") +model_name = os.path.join(filedir, f"resnet50.{DEVICE}.pt") transport: t.Literal["hsta", "tcp"] = "hsta" @@ -25,37 +27,51 @@ torch_worker_str = base64.b64encode(cloudpickle.dumps(TorchWorker)).decode("ascii") -worker_manager_rs = exp.create_run_settings( +worker_manager_rs: DragonRunSettings = exp.create_run_settings( sys.executable, [ worker_manager_script_name, "--device", - device, + DEVICE, "--worker_class", torch_worker_str, + "--batch_size", + str(NUM_RANKS//NUM_WORKERS), + "--batch_timeout", + str(0.00), + "--num_workers", + str(NUM_WORKERS) ], ) + +aff = [] + +worker_manager_rs.set_cpu_affinity(aff) + worker_manager = exp.create_model("worker_manager", run_settings=worker_manager_rs) worker_manager.attach_generator_files(to_copy=[worker_manager_script_name]) -app_rs = exp.create_run_settings( +app_rs: DragonRunSettings = exp.create_run_settings( sys.executable, - exe_args=[app_script_name, "--device", device], + exe_args=[app_script_name, "--device", DEVICE, "--log_max_batchsize", str(6)], ) +app_rs.set_tasks_per_node(NUM_RANKS) + + app = exp.create_model("app", run_settings=app_rs) app.attach_generator_files(to_copy=[app_script_name], to_symlink=[model_name]) - exp.generate(worker_manager, app, overwrite=True) exp.start(worker_manager, app, block=False) while True: if exp.get_status(app)[0] in TERMINAL_STATUSES: + time.sleep(10) exp.stop(worker_manager) break if exp.get_status(worker_manager)[0] in TERMINAL_STATUSES: + time.sleep(10) exp.stop(app) break - time.sleep(5) print("Exiting.") diff --git a/ex/high_throughput_inference/mock_app.py b/ex/high_throughput_inference/mock_app.py index 44db70b71d..517d18fb2f 100644 --- a/ex/high_throughput_inference/mock_app.py +++ b/ex/high_throughput_inference/mock_app.py @@ -41,20 +41,27 @@ import os import time import torch -import numbers -from collections import OrderedDict +from mpi4py import MPI from smartsim._core.mli.infrastructure.storage.dragonfeaturestore import ( DragonFeatureStore, ) from smartsim._core.mli.message_handler import MessageHandler from smartsim.log import get_logger +from smartsim._core.utils.timings import PerfTimer + +torch.set_num_interop_threads(16) +torch.set_num_threads(1) logger = get_logger("App") +logger.info("Started app") +CHECK_RESULTS_AND_MAKE_ALL_SLOWER = 
False class ProtoClient: def __init__(self, timing_on: bool): + comm = MPI.COMM_WORLD + rank = comm.Get_rank() connect_to_infrastructure() ddict_str = os.environ["_SMARTSIM_INFRA_BACKBONE"] self._ddict = DDict.attach(ddict_str) @@ -70,61 +77,15 @@ def __init__(self, timing_on: bool): self._from_worker_ch_serialized = self._from_worker_ch.serialize() self._to_worker_ch = Channel.make_process_local() - self._start = None - self._interm = None - self._timings: OrderedDict[str, list[numbers.Number]] = OrderedDict() - self._timing_on = timing_on - - def _add_label_to_timings(self, label: str): - if label not in self._timings: - self._timings[label] = [] - - @staticmethod - def _format_number(number: numbers.Number): - return f"{number:0.4e}" - - def start_timings(self, batch_size: int): - if self._timing_on: - self._add_label_to_timings("batch_size") - self._timings["batch_size"].append(batch_size) - self._start = time.perf_counter() - self._interm = time.perf_counter() - - def end_timings(self): - if self._timing_on: - self._add_label_to_timings("total_time") - self._timings["total_time"].append( - self._format_number(time.perf_counter() - self._start) - ) - - def measure_time(self, label: str): - if self._timing_on: - self._add_label_to_timings(label) - self._timings[label].append( - self._format_number(time.perf_counter() - self._interm) - ) - self._interm = time.perf_counter() - - def print_timings(self, to_file: bool = False): - print(" ".join(self._timings.keys())) - value_array = numpy.array( - [value for value in self._timings.values()], dtype=float - ) - value_array = numpy.transpose(value_array) - for i in range(value_array.shape[0]): - print(" ".join(self._format_number(value) for value in value_array[i])) - if to_file: - numpy.save("timings.npy", value_array) - numpy.savetxt("timings.txt", value_array) + self.perf_timer: PerfTimer = PerfTimer(debug=False, timing_on=timing_on, prefix=f"a{rank}_") def run_model(self, model: bytes | str, batch: torch.Tensor): tensors = [batch.numpy()] - self.start_timings(batch.shape[0]) + self.perf_timer.start_timings("batch_size", batch.shape[0]) built_tensor_desc = MessageHandler.build_tensor_descriptor( "c", "float32", list(batch.shape) ) - self.measure_time("build_tensor_descriptor") - built_model = None + self.perf_timer.measure_time("build_tensor_descriptor") if isinstance(model, str): model_arg = MessageHandler.build_model_key(model, self._backbone_descriptor) else: @@ -137,39 +98,39 @@ def run_model(self, model: bytes | str, batch: torch.Tensor): output_descriptors=[], custom_attributes=None, ) - self.measure_time("build_request") + self.perf_timer.measure_time("build_request") request_bytes = MessageHandler.serialize_request(request) - self.measure_time("serialize_request") - with self._to_worker_fli.sendh( - timeout=None, stream_channel=self._to_worker_ch - ) as to_sendh: + self.perf_timer.measure_time("serialize_request") + with self._to_worker_fli.sendh(timeout=None, stream_channel=self._to_worker_ch) as to_sendh: to_sendh.send_bytes(request_bytes) - for t in tensors: - to_sendh.send_bytes(t.tobytes()) # TODO NOT FAST ENOUGH!!! - # to_sendh.send_bytes(bytes(t.data)) - logger.info(f"Message size: {len(request_bytes)} bytes") - - self.measure_time("send") + self.perf_timer.measure_time("send_request") + for tensor in tensors: + to_sendh.send_bytes(tensor.tobytes()) #TODO NOT FAST ENOUGH!!! 
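The multi-part message assembled above arrives at the dispatcher as a flat list of byte frames: the serialized request first, then one raw buffer per input tensor. Below is a minimal sketch of the expected unpacking, mirroring what `RequestDispatcher._on_iteration` does later in this diff; the helper name `split_request_frames` is illustrative only and not part of the SmartSim API.

```python
import typing as t


def split_request_frames(frames: t.List[bytes]) -> t.Tuple[bytes, t.List[bytes]]:
    """Split received frames into (serialized request, raw tensor buffers).

    The client sends the capnp-serialized request as the first frame and one
    frame per input tensor afterwards, so frames[0] is always the request.
    """
    if not frames:
        raise ValueError("No request data found")
    return frames[0], frames[1:]


# Example with dummy payloads standing in for real capnp/tensor bytes
request_bytes, tensor_frames = split_request_frames([b"<capnp request>", b"\x00" * 16])
```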
+ self.perf_timer.measure_time("send_tensors") with self._from_worker_ch.recvh(timeout=None) as from_recvh: resp = from_recvh.recv_bytes(timeout=None) - self.measure_time("receive") + self.perf_timer.measure_time("receive_response") response = MessageHandler.deserialize_response(resp) - self.measure_time("deserialize_response") + self.perf_timer.measure_time("deserialize_response") # list of data blobs? recv depending on the len(response.result.descriptors)? - data_blob = from_recvh.recv_bytes(timeout=None) - result = numpy.frombuffer( - data_blob, - dtype=str(response.result.descriptors[0].dataType), + data_blob: bytes = from_recvh.recv_bytes(timeout=None) + self.perf_timer.measure_time("receive_tensor") + result = torch.from_numpy( + numpy.frombuffer( + data_blob, + dtype=str(response.result.descriptors[0].dataType), + ) ) - self.measure_time("deserialize_tensor") + self.perf_timer.measure_time("deserialize_tensor") - self.end_timings() + self.perf_timer.end_timings() return result def set_model(self, key: str, model: bytes): self._ddict[key] = model + class ResNetWrapper: def __init__(self, name: str, model: str): self._model = torch.jit.load(model) @@ -190,24 +151,39 @@ def model(self): def name(self): return self._name - if __name__ == "__main__": parser = argparse.ArgumentParser("Mock application") - parser.add_argument("--device", default="cpu") + parser.add_argument("--device", default="cpu", type=str) + parser.add_argument("--log_max_batchsize", default=8, type=int) args = parser.parse_args() - resnet = ResNetWrapper("resnet50", f"resnet50.{args.device.upper()}.pt") + resnet = ResNetWrapper("resnet50", f"resnet50.{args.device}.pt") client = ProtoClient(timing_on=True) client.set_model(resnet.name, resnet.model) - total_iterations = 100 + if CHECK_RESULTS_AND_MAKE_ALL_SLOWER: + # TODO: adapt to non-Nvidia devices + torch_device = args.device.replace("gpu", "cuda") + pt_model = torch.jit.load(io.BytesIO(initial_bytes=(resnet.model))).to(torch_device) - for batch_size in [1, 2, 4, 8, 16, 32, 64, 128]: - logger.info(f"Batch size: {batch_size}") - for iteration_number in range(total_iterations + int(batch_size == 1)): - logger.info(f"Iteration: {iteration_number}") - client.run_model(resnet.name, resnet.get_batch(batch_size)) + TOTAL_ITERATIONS = 100 - client.print_timings(to_file=True) + for log2_bsize in range(args.log_max_batchsize+1): + b_size: int = 2**log2_bsize + logger.info(f"Batch size: {b_size}") + for iteration_number in range(TOTAL_ITERATIONS + int(b_size==1)): + logger.info(f"Iteration: {iteration_number}") + sample_batch = resnet.get_batch(b_size) + remote_result = client.run_model(resnet.name, sample_batch) + logger.info(client.perf_timer.get_last("total_time")) + if CHECK_RESULTS_AND_MAKE_ALL_SLOWER: + local_res = pt_model(sample_batch.to(torch_device)) + err_norm = torch.linalg.vector_norm(torch.flatten(remote_result).to(torch_device)-torch.flatten(local_res), ord=1).cpu() + res_norm = torch.linalg.vector_norm(remote_result, ord=1).item() + local_res_norm = torch.linalg.vector_norm(local_res, ord=1).item() + logger.info(f"Avg norm of error {err_norm.item()/b_size} compared to result norm of {res_norm/b_size}:{local_res_norm/b_size}") + torch.cuda.synchronize() + + client.perf_timer.print_timings(to_file=True) \ No newline at end of file diff --git a/ex/high_throughput_inference/mock_app_redis.py b/ex/high_throughput_inference/mock_app_redis.py index c56b4fb8b4..8978bcea23 100644 --- a/ex/high_throughput_inference/mock_app_redis.py +++ 
b/ex/high_throughput_inference/mock_app_redis.py @@ -29,7 +29,9 @@ import numpy import time import torch +from mpi4py import MPI from smartsim.log import get_logger +from smartsim._core.utils.timings import PerfTimer from smartredis import Client logger = get_logger("App") @@ -56,6 +58,9 @@ def name(self): if __name__ == "__main__": + comm = MPI.COMM_WORLD + rank = comm.Get_rank() + parser = argparse.ArgumentParser("Mock application") parser.add_argument("--device", default="cpu") args = parser.parse_args() @@ -65,24 +70,21 @@ def name(self): client = Client(cluster=False, address=None) client.set_model(resnet.name, resnet.model, backend='TORCH', device=args.device.upper()) + perf_timer: PerfTimer = PerfTimer(debug=False, timing_on=timing_on, prefix=f"redis{rank}_") + total_iterations = 100 timings=[] for batch_size in [1, 2, 4, 8, 16, 32, 64, 128]: logger.info(f"Batch size: {batch_size}") for iteration_number in range(total_iterations + int(batch_size==1)): - timing = [batch_size] + perf_timer.start_timings("batch_size", batch_size) logger.info(f"Iteration: {iteration_number}") - start = time.perf_counter() - client.put_tensor(name="batch", data=resnet.get_batch(batch_size).numpy()) - client.run_model(name=resnet.name, inputs=["batch"], outputs=["result"]) - result = client.get_tensor(name="result") - end = time.perf_counter() - timing.append(end-start) - timings.append(timing) - + input_name = f"batch_{rank}" + output_name = f"result_{rank}" + client.put_tensor(name=input_name, data=resnet.get_batch(batch_size).numpy()) + client.run_model(name=resnet.name, inputs=[input_name], outputs=[output_name]) + result = client.get_tensor(name=output_name) + perf_timer.end_timings() - timings_np = numpy.asarray(timings) - numpy.save("timings.npy", timings_np) - for timing in timings: - print(" ".join(str(t) for t in timing)) + perf_timer.print_timings(True) diff --git a/ex/high_throughput_inference/redis_driver.py b/ex/high_throughput_inference/redis_driver.py index ceddba4ef7..ff57725d40 100644 --- a/ex/high_throughput_inference/redis_driver.py +++ b/ex/high_throughput_inference/redis_driver.py @@ -29,23 +29,24 @@ from smartsim import Experiment from smartsim.status import TERMINAL_STATUSES import time -import typing as t -device = "gpu" +DEVICE = "gpu" filedir = os.path.dirname(__file__) app_script_name = os.path.join(filedir, "mock_app_redis.py") -model_name = os.path.join(filedir, f"resnet50.{device.upper()}.pt") +model_name = os.path.join(filedir, f"resnet50.{DEVICE}.pt") -exp_path = os.path.join(filedir, "redis_ai") +exp_path = os.path.join(filedir, "redis_ai_multi") os.makedirs(exp_path, exist_ok=True) -exp = Experiment("redis_ai", launcher="slurm", exp_path=exp_path) +exp = Experiment("redis_ai_multi", launcher="slurm", exp_path=exp_path) db = exp.create_database(interface="hsn0") -app_rs = exp.create_run_settings(sys.executable, exe_args = [app_script_name, "--device", device]) +app_rs = exp.create_run_settings( + sys.executable, exe_args = [app_script_name, "--device", DEVICE] + ) app_rs.set_nodes(1) -app_rs.set_tasks(1) +app_rs.set_tasks(4) app = exp.create_model("app", run_settings=app_rs) app.attach_generator_files(to_copy=[app_script_name], to_symlink=[model_name]) diff --git a/ex/high_throughput_inference/standalone_workermanager.py b/ex/high_throughput_inference/standalone_workermanager.py index 982cb6cc38..0b8c61251b 100644 --- a/ex/high_throughput_inference/standalone_workermanager.py +++ b/ex/high_throughput_inference/standalone_workermanager.py @@ -24,28 +24,90 @@ # OR TORT 
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -# isort: off + import dragon + +# pylint disable=import-error +import dragon.infrastructure.policy as dragon_policy +import dragon.infrastructure.process_desc as dragon_process_desc +import dragon.native.process as dragon_process from dragon import fli from dragon.channels import Channel from dragon.data.ddict.ddict import DDict -from dragon.utils import b64decode, b64encode from dragon.globalservices.api_setup import connect_to_infrastructure +from dragon.managed_memory import MemoryPool +from dragon.utils import b64decode, b64encode +# pylint enable=import-error +# isort: off # isort: on + import argparse import base64 +import multiprocessing as mp +import os +import pickle +import socket +import sys +import time +import typing as t + import cloudpickle import optparse import os +from smartsim._core.entrypoints.service import Service +from smartsim._core.mli.comm.channel.channel import CommChannelBase from smartsim._core.mli.comm.channel.dragonchannel import DragonCommChannel from smartsim._core.mli.comm.channel.dragonfli import DragonFLIChannel from smartsim._core.mli.infrastructure.storage.dragonfeaturestore import ( DragonFeatureStore, ) +from smartsim._core.mli.infrastructure.control.requestdispatcher import ( + RequestDispatcher, +) from smartsim._core.mli.infrastructure.control.workermanager import WorkerManager from smartsim._core.mli.infrastructure.environmentloader import EnvironmentConfigLoader +from smartsim._core.mli.infrastructure.storage.dragonfeaturestore import ( + DragonFeatureStore, +) +from smartsim._core.mli.infrastructure.worker.worker import MachineLearningWorkerBase + +from smartsim.log import get_logger + +logger = get_logger("Worker Manager Entry Point") + +mp.set_start_method("dragon") + +pid = os.getpid() +affinity = os.sched_getaffinity(pid) +logger.info(f"Entry point: {socket.gethostname()}, {affinity}") +logger.info(f"CPUS: {os.cpu_count()}") + + + +def service_as_dragon_proc( + service: Service, cpu_affinity: list[int], gpu_affinity: list[int] +) -> dragon_process.Process: + + options = dragon_process_desc.ProcessOptions(make_inf_channels=True) + local_policy = dragon_policy.Policy( + placement=dragon_policy.Policy.Placement.HOST_NAME, + host_name=socket.gethostname(), + cpu_affinity=cpu_affinity, + gpu_affinity=gpu_affinity, + ) + return dragon_process.Process( + target=service.execute, + args=[], + cwd=os.getcwd(), + policy=local_policy, + options=options, + stderr=dragon_process.Popen.STDOUT, + stdout=dragon_process.Popen.STDOUT, + ) + + if __name__ == "__main__": @@ -66,8 +128,20 @@ parser.add_argument( "--num_workers", type=int, default=1, help="Number of workers to run" ) - + parser.add_argument( + "--batch_size", + type=int, + default=1, + help="How many requests the workers will try to aggregate before processing them", + ) + parser.add_argument( + "--batch_timeout", + type=float, + default=0.001, + help="How much time (in seconds) should be waited before processing an incomplete aggregated request", + ) args = parser.parse_args() + connect_to_infrastructure() ddict_str = os.environ["_SMARTSIM_INFRA_BACKBONE"] ddict = DDict.attach(ddict_str) @@ -77,8 +151,12 @@ to_worker_fli_serialized = to_worker_fli.serialize() ddict["to_worker_fli"] = to_worker_fli_serialized - worker_type_name = base64.b64decode(args.worker_class.encode("ascii")) - torch_worker = cloudpickle.loads(worker_type_name)() + arg_worker_type = 
cloudpickle.loads( + base64.b64decode(args.worker_class.encode("ascii")) + ) + + dfs = DragonFeatureStore(ddict) + comm_channel = DragonFLIChannel(to_worker_fli_serialized) descriptor = base64.b64encode(to_worker_fli_serialized).decode("utf-8") os.environ["_SMARTSIM_REQUEST_QUEUE"] = descriptor @@ -89,11 +167,57 @@ queue_factory=DragonFLIChannel.from_descriptor, ) - worker_manager = WorkerManager( + dispatcher = RequestDispatcher( + batch_timeout=args.batch_timeout, + batch_size=args.batch_size, config_loader=config_loader, - worker=torch_worker, - as_service=True, - cooldown=10, - device=args.device, + worker_type=arg_worker_type, ) - worker_manager.execute() + + wms = [] + worker_device = args.device + for wm_idx in range(args.num_workers): + + worker_manager = WorkerManager( + config_loader=config_loader, + worker_type=arg_worker_type, + as_service=True, + cooldown=10, + device=worker_device, + dispatcher_queue=dispatcher.task_queue, + ) + + wms.append(worker_manager) + + wm_affinity: list[int] = [] + disp_affinity: list[int] = [] + + # This is hardcoded for a specific type of node: + # the GPU-to-CPU mapping is taken from the nvidia-smi tool + # TODO can this be computed on the fly? + gpu_to_cpu_aff: dict[int, list[int]] = {} + gpu_to_cpu_aff[0] = list(range(48,64)) + list(range(112,128)) + gpu_to_cpu_aff[1] = list(range(32,48)) + list(range(96,112)) + gpu_to_cpu_aff[2] = list(range(16,32)) + list(range(80,96)) + gpu_to_cpu_aff[3] = list(range(0,16)) + list(range(64,80)) + + worker_manager_procs = [] + for worker_idx in range(args.num_workers): + wm_cpus = len(gpu_to_cpu_aff[worker_idx]) - 4 + wm_affinity = gpu_to_cpu_aff[worker_idx][:wm_cpus] + disp_affinity.extend(gpu_to_cpu_aff[worker_idx][wm_cpus:]) + worker_manager_procs.append(service_as_dragon_proc( + worker_manager, cpu_affinity=wm_affinity, gpu_affinity=[worker_idx] + )) + + dispatcher_proc = service_as_dragon_proc(dispatcher, cpu_affinity=disp_affinity, gpu_affinity=[]) + + # TODO: use ProcessGroup and restart=True? 
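The `gpu_to_cpu_aff` map above is hardcoded for one node type, and the TODO asks whether it could be computed on the fly. One possible fallback, sketched below, is a naive even split of the CPUs the process may run on; unlike the nvidia-smi-derived map it ignores NUMA/GPU locality, and `split_cpu_affinity` is a hypothetical helper, not SmartSim API.

```python
import os


def split_cpu_affinity(
    num_workers: int, dispatcher_cpus_per_worker: int = 4
) -> tuple[list[list[int]], list[int]]:
    """Evenly divide available CPUs among workers, reserving a few CPUs per
    worker for the dispatcher (mirrors the 4-CPU reservation used above)."""
    cpus = sorted(os.sched_getaffinity(0))  # Linux-only
    per_worker = len(cpus) // num_workers
    worker_affinities: list[list[int]] = []
    dispatcher_affinity: list[int] = []
    for idx in range(num_workers):
        chunk = cpus[idx * per_worker : (idx + 1) * per_worker]
        # Assumes per_worker > dispatcher_cpus_per_worker
        worker_affinities.append(chunk[: per_worker - dispatcher_cpus_per_worker])
        dispatcher_affinity.extend(chunk[per_worker - dispatcher_cpus_per_worker :])
    return worker_affinities, dispatcher_affinity


# Example: four workers, four dispatcher CPUs taken from each worker's share
wm_affinities, disp_affinity_alt = split_cpu_affinity(num_workers=4)
```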
+ all_procs = [dispatcher_proc, *worker_manager_procs] + + print(f"Dispatcher proc: {dispatcher_proc}") + for proc in all_procs: + proc.start() + + while all(proc.is_alive for proc in all_procs): + time.sleep(1) diff --git a/setup.py b/setup.py index 512da78de9..709913eda8 100644 --- a/setup.py +++ b/setup.py @@ -177,7 +177,7 @@ class BuildError(Exception): "filelock>=3.4.2", "protobuf~=3.20", "jinja2>=3.1.2", - "watchdog>=4.0.0", + "watchdog>=4.0.0,<5", "pycapnp==2.0.0", "pydantic==1.10.14", "pyzmq>=25.1.2", diff --git a/smartsim/_core/entrypoints/service.py b/smartsim/_core/entrypoints/service.py index df9c2bbef6..6b4ef74b67 100644 --- a/smartsim/_core/entrypoints/service.py +++ b/smartsim/_core/entrypoints/service.py @@ -103,23 +103,6 @@ def execute(self) -> None: running = True cooldown_start: t.Optional[datetime.datetime] = None - headers = [ - "batch_size", - "w_deserialize", - "w_fetch_model", - "w_load_model", - "w_fetch_input", - "w_transform_input", - "w_execute", - "w_transform_output", - "w_assign_output", - "w_build_reply", - "w_serialize_resp", - "w_send", - ] - - print(",".join(headers)) - while running: self._on_iteration() diff --git a/smartsim/_core/launcher/dragon/dragonBackend.py b/smartsim/_core/launcher/dragon/dragonBackend.py index 6cf39be0fb..7526af14ad 100644 --- a/smartsim/_core/launcher/dragon/dragonBackend.py +++ b/smartsim/_core/launcher/dragon/dragonBackend.py @@ -605,10 +605,7 @@ def _start_steps(self) -> None: logger.debug(f"Step id {step_id} allocated on {hosts}") - global_policy = dragon_policy.Policy( - placement=dragon_policy.Policy.Placement.HOST_NAME, - host_name=hosts[0], - ) + global_policy = self.create_run_policy(request, hosts[0]) options = dragon_process_desc.ProcessOptions(make_inf_channels=True) grp = dragon_process_group.ProcessGroup( restart=False, pmi_enabled=request.pmi_enabled, policy=global_policy diff --git a/smartsim/_core/mli/comm/channel/dragonchannel.py b/smartsim/_core/mli/comm/channel/dragonchannel.py index 80fdd9cdc6..89b90f2e62 100644 --- a/smartsim/_core/mli/comm/channel/dragonchannel.py +++ b/smartsim/_core/mli/comm/channel/dragonchannel.py @@ -33,11 +33,7 @@ logger = get_logger(__name__) -try: - import dragon.channels as dch -except ImportError as exc: - if not "pytest" in sys.modules: - raise exc from None +import dragon.channels as dch class DragonCommChannel(cch.CommChannelBase): diff --git a/smartsim/_core/mli/comm/channel/dragonfli.py b/smartsim/_core/mli/comm/channel/dragonfli.py index 4636894bdd..130c5cf5eb 100644 --- a/smartsim/_core/mli/comm/channel/dragonfli.py +++ b/smartsim/_core/mli/comm/channel/dragonfli.py @@ -68,12 +68,12 @@ def recv(self) -> t.List[bytes]: :returns: the received message""" messages = [] eot = False - with self._fli.recvh(timeout=None) as recvh: + with self._fli.recvh(timeout=0.001) as recvh: while not eot: try: message, _ = recvh.recv_bytes(timeout=None) messages.append(message) - except fli.FLIEOT as exc: + except fli.FLIEOT: eot = True return messages diff --git a/smartsim/_core/mli/infrastructure/control/__init__.py b/smartsim/_core/mli/infrastructure/control/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/smartsim/_core/mli/infrastructure/control/devicemanager.py b/smartsim/_core/mli/infrastructure/control/devicemanager.py new file mode 100644 index 0000000000..3570bd51ed --- /dev/null +++ b/smartsim/_core/mli/infrastructure/control/devicemanager.py @@ -0,0 +1,146 @@ +# BSD 2-Clause License +# +# Copyright (c) 2021-2024, Hewlett Packard Enterprise +# All rights 
reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +import typing as t +from contextlib import _GeneratorContextManager, contextmanager + +from .....log import get_logger +from ...infrastructure.storage.featurestore import FeatureStore +from ..worker.worker import MachineLearningWorkerBase, RequestBatch + +logger = get_logger(__name__) + + +class WorkerDevice: + def __init__(self, name: str) -> None: + """Wrapper around a device to keep track of loaded Models and availability + :param name: name used by the toolkit to identify this device, e.g. ``cuda:0`` + """ + self._name = name + """The name used by the toolkit to identify this device""" + self._models: dict[str, t.Any] = {} + """Dict of keys to models which are loaded on this device""" + + @property + def name(self) -> str: + """The identifier of the device represented by this object""" + return self._name + + def add_model(self, key: str, model: t.Any) -> None: + """Add a reference to a model loaded on this device and assign it a key + + :param key: The key under which the model is saved + :param model: The model which is added + """ + self._models[key] = model + + def remove_model(self, key: str) -> None: + """Remove the reference to a model loaded on this device + + :param key: The key of the model to remove + """ + self._models.pop(key) + + def get_model(self, key: str) -> t.Any: + """Get the model corresponding to a given key + + :param key: the model key + :returns: the model for the given key + """ + return self._models[key] + + def __contains__(self, key: str) -> bool: + """Check if model with a given key is available on the device + + :param key: the key of the model to check for existence + :returns: whether the model is available on the device + """ + return key in self._models + + @contextmanager + def get(self, key_to_remove: t.Optional[str]) -> t.Iterator["WorkerDevice"]: + yield self + if key_to_remove is not None: + self.remove_model(key_to_remove) + + +class DeviceManager: + def __init__(self, device: WorkerDevice): + """An object to manage devices such as GPUs and CPUs. 
+
+        The main goal of the ``DeviceManager`` is to ensure that
+        the managed device is ready to be used by a worker to
+        run a given model.
+        :param device: The managed device
+        """
+        self._device = device
+        """Device managed by this object"""
+
+    def _load_model_on_device(
+        self,
+        worker: MachineLearningWorkerBase,
+        batch: RequestBatch,
+        feature_stores: dict[str, FeatureStore],
+    ) -> None:
+        """Load the model needed to execute a batch on the managed device.
+
+        The model is loaded by the worker.
+
+        :param worker: the worker that loads the model
+        :param batch: the batch for which the model is needed
+        :param feature_stores: feature stores where the model could be stored
+        """
+
+        model_bytes = worker.fetch_model(batch, feature_stores)
+        loaded_model = worker.load_model(batch, model_bytes, self._device.name)
+        self._device.add_model(batch.model_id.key, loaded_model.model)
+
+    def get_device(
+        self,
+        worker: MachineLearningWorkerBase,
+        batch: RequestBatch,
+        feature_stores: dict[str, FeatureStore],
+    ) -> _GeneratorContextManager[WorkerDevice]:
+        """Get the device managed by this object.
+
+        The model needed to run the batch of requests is
+        guaranteed to be available on the device.
+
+        :param worker: The worker that wants to access the device
+        :param batch: The batch of requests
+        :param feature_stores: The feature stores on which part of the
+        data needed by the request may be stored
+        :return: A generator yielding the device
+        """
+        model_in_request = batch.has_raw_model
+
+        # Load the model if it is not already loaded, or
+        # if it was sent with the request
+        if model_in_request or batch.model_id.key not in self._device:
+            self._load_model_on_device(worker, batch, feature_stores)
+
+        key_to_remove = batch.model_id.key if model_in_request else None
+        return self._device.get(key_to_remove)
diff --git a/smartsim/_core/mli/infrastructure/control/error_handling.py b/smartsim/_core/mli/infrastructure/control/error_handling.py
new file mode 100644
index 0000000000..e2c5bcd9e1
--- /dev/null
+++ b/smartsim/_core/mli/infrastructure/control/error_handling.py
@@ -0,0 +1,70 @@
+# BSD 2-Clause License
+#
+# Copyright (c) 2021-2024, Hewlett Packard Enterprise
+# All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are met:
+#
+# 1. Redistributions of source code must retain the above copyright notice, this
+#    list of conditions and the following disclaimer.
+#
+# 2. Redistributions in binary form must reproduce the above copyright notice,
+#    this list of conditions and the following disclaimer in the documentation
+#    and/or other materials provided with the distribution.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
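A minimal sketch of the call pattern the `DeviceManager` above is designed for, mirroring how `WorkerManager._on_iteration` (later in this diff) acquires the device. The function name `run_batch_on_device` is illustrative, not SmartSim API, and `worker`, `batch`, and `feature_stores` are assumed to come from the surrounding service.

```python
import typing as t

from smartsim._core.mli.infrastructure.control.devicemanager import (
    DeviceManager,
    WorkerDevice,
)
from smartsim._core.mli.infrastructure.storage.featurestore import FeatureStore
from smartsim._core.mli.infrastructure.worker.worker import (
    MachineLearningWorkerBase,
    RequestBatch,
)


def run_batch_on_device(
    worker: MachineLearningWorkerBase,
    batch: RequestBatch,
    feature_stores: t.Dict[str, FeatureStore],
    device_manager: DeviceManager,
) -> t.Any:
    """Ensure the batch's model is loaded, then return it for execution."""
    # get_device loads the model on a cache miss (or when it was shipped inline
    # with the request) and yields the device; inline models are evicted on exit.
    with device_manager.get_device(
        worker=worker, batch=batch, feature_stores=feature_stores
    ) as device:
        return device.get_model(batch.model_id.key)


# Example wiring (mirrors WorkerManager._on_start):
# manager = DeviceManager(WorkerDevice("gpu"))
# model = run_batch_on_device(worker, batch, feature_stores, manager)
```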
+ +import typing as t + +from .....log import get_logger +from ...comm.channel.channel import CommChannelBase +from ...message_handler import MessageHandler +from ...mli_schemas.response.response_capnp import ResponseBuilder + +if t.TYPE_CHECKING: + from smartsim._core.mli.mli_schemas.response.response_capnp import Status + +logger = get_logger(__file__) + + +def build_failure_reply(status: "Status", message: str) -> ResponseBuilder: + return MessageHandler.build_response( + status=status, + message=message, + result=[], + custom_attributes=None, + ) + + +def exception_handler( + exc: Exception, reply_channel: t.Optional[CommChannelBase], failure_message: str +) -> None: + """ + Logs exceptions and sends a failure response. + + :param exc: The exception to be logged + :param reply_channel: The channel used to send replies + :param failure_message: Failure message to log and send back + """ + logger.exception( + f"{failure_message}\n" + f"Exception type: {type(exc).__name__}\n" + f"Exception message: {str(exc)}" + ) + serialized_resp = MessageHandler.serialize_response( + build_failure_reply("fail", failure_message) + ) + if reply_channel: + reply_channel.send(serialized_resp) + else: + logger.warning("Unable to notify client of error without reply_channel") diff --git a/smartsim/_core/mli/infrastructure/control/requestdispatcher.py b/smartsim/_core/mli/infrastructure/control/requestdispatcher.py new file mode 100644 index 0000000000..d56912a8f0 --- /dev/null +++ b/smartsim/_core/mli/infrastructure/control/requestdispatcher.py @@ -0,0 +1,504 @@ +# BSD 2-Clause License +# +# Copyright (c) 2021-2024, Hewlett Packard Enterprise +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
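The `exception_handler` defined above is used the same way throughout the dispatcher and worker manager: catch, notify the client over its callback channel if one exists, and move on. Below is a small sketch of that pattern; `try_step` is an illustrative helper, not SmartSim API.

```python
import typing as t

from smartsim._core.mli.infrastructure.control.error_handling import exception_handler
from smartsim._core.mli.infrastructure.worker.worker import InferenceRequest


def try_step(request: InferenceRequest, step: t.Callable[[], t.Any]) -> t.Any:
    """Run one pipeline step; on failure, log, notify the client, return None."""
    try:
        return step()
    except Exception as exc:
        # Logs the exception and, when the request carries a callback channel,
        # sends back a serialized "fail" response built by build_failure_reply.
        exception_handler(exc, request.callback, "Failed while executing a pipeline step.")
        return None
```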
+ +# pylint: disable=import-error +# pylint: disable-next=unused-import +import dragon +import dragon.globalservices.pool as dragon_gs_pool +from dragon.managed_memory import MemoryPool +from dragon.mpbridge.queues import DragonQueue + +# pylint: enable=import-error + +# isort: off +# isort: on + +import multiprocessing as mp +import time +import typing as t +import uuid +from queue import Empty, Full, Queue + +from smartsim._core.entrypoints.service import Service + +from .....error import SmartSimError +from .....log import get_logger +from ....utils.timings import PerfTimer +from ...infrastructure.environmentloader import EnvironmentConfigLoader +from ...infrastructure.storage.featurestore import FeatureStore +from ...infrastructure.worker.worker import ( + InferenceRequest, + MachineLearningWorkerBase, + ModelIdentifier, + RequestBatch, +) +from .error_handling import exception_handler + +if t.TYPE_CHECKING: + from smartsim._core.mli.mli_schemas.response.response_capnp import Status + +logger = get_logger("Request Dispatcher") + + +class BatchQueue(Queue[InferenceRequest]): + def __init__( + self, batch_timeout: float, batch_size: int, model_id: ModelIdentifier + ) -> None: + """Queue used to store inference requests waiting to be batched and + sent to Worker Managers. + :param batch_timeout: Time in seconds that has to be waited before flushing a + non-full queue. The time of the first item put is 0 seconds. + :param batch_size: Total capacity of the queue. + :param model_id: Key of the model which needs to be executed on the queued + requests + """ + super().__init__(maxsize=batch_size) + self._batch_timeout = batch_timeout + """Time in seconds that has to be waited before flushing a non-full queue. + The time of the first item put is 0 seconds.""" + self._batch_size = batch_size + """Total capacity of the queue.""" + self._first_put: t.Optional[float] = None + """Time at which the first item was put on the queue""" + self._disposable = False + """Whether the queue will not be used again and can be deleted. 
+ A disposable queue is always full.""" + self._model_id: ModelIdentifier = model_id + """Key of the model which needs to be executed on the queued requests""" + self._uid = str(uuid.uuid4()) + """Unique ID of queue""" + + @property + def uid(self) -> str: + """ID of this queue""" + return self._uid + + @property + def model_id(self) -> ModelIdentifier: + """Key of the model which needs to be run on the queued requests""" + return self._model_id + + def put( + self, + item: InferenceRequest, + block: bool = False, + timeout: t.Optional[float] = 0.0, + ) -> None: + """Put an inference request in the queue + :param item: The request + :param block: Whether to block when trying to put the item + :param timeout: Time (in seconds) to wait if block==True + :raises Full: If an item cannot be put on the queue + """ + super().put(item, block=block, timeout=timeout) + if self._first_put is None: + self._first_put = time.time() + + @property + def _elapsed_time(self) -> float: + """Time elapsed since the first item was put on this queue""" + if self.empty() or self._first_put is None: + return 0 + return time.time() - self._first_put + + @property + def ready(self) -> bool: + """True if the queue can be flushed""" + if self.empty(): + return False + + timed_out = ( + self._batch_timeout > 0 and self._elapsed_time >= self._batch_timeout + ) + logger.debug(f"Is full: {self.full()} or has timed out: {timed_out}") + return self.full() or timed_out + + def make_disposable(self) -> None: + """Set this queue as disposable, and never use it again after it gets flushed""" + self._disposable = True + + @property + def can_be_removed(self) -> bool: + """Whether this queue can be deleted and garbage collected""" + return self.empty() and self._disposable + + def flush(self) -> list[t.Any]: + """Get all requests from queue + :return: Requests waiting to be executed + """ + num_items = self.qsize() + self._first_put = None + items = [] + for _ in range(num_items): + try: + items.append(self.get()) + except Empty: + break + + return items + + def full(self) -> bool: + """Return True if the queue has reached its maximum capacity""" + if self._disposable: + return True + return self.qsize() >= self._batch_size + + def empty(self) -> bool: + """Return True if the queue has 0 elements""" + return self.qsize() == 0 + + +class RequestDispatcher(Service): + def __init__( + self, + batch_timeout: float, + batch_size: int, + config_loader: EnvironmentConfigLoader, + worker_type: t.Type[MachineLearningWorkerBase], + mem_pool_size: int = 2 * 1024**3, + ) -> None: + """The RequestDispatcher intercepts inference requests, stages them in + queues and batches them together before making them available to Worker + Managers. + :param batch_timeout: Maximum elapsed time before flushing a complete or + incomplete batch + :param batch_size: Total capacity of each batch queue. 
+ :param mem_pool: Memory pool used to share batched input tensors with worker + managers + :param config_loader: Object to load configuration from environment + :param worker_type: Type of worker to instantiate to batch inputs + :param mem_pool_size: Size of the memory pool used to allocate tensors + :raises SmartSimError: If config_loaded.get_queue() does not return a channel + """ + super().__init__(as_service=True, cooldown=1) + self._queues: dict[str, list[BatchQueue]] = {} + """Dict of all batch queues available for a given model id""" + self._active_queues: dict[str, BatchQueue] = {} + """Mapping telling which queue is the recipient of requests for a given model + key""" + self._batch_timeout = batch_timeout + """Time in seconds that has to be waited before flushing a non-full queue""" + self._batch_size = batch_size + """Total capacity of each batch queue.""" + incoming_channel = config_loader.get_queue() + if incoming_channel is None: + raise SmartSimError("No incoming channel for dispatcher") + self._incoming_channel = incoming_channel + """The channel the dispatcher monitors for new tasks""" + self._outgoing_queue: DragonQueue = mp.Queue(maxsize=0) + """The queue on which batched inference requests are placed""" + self._feature_stores: t.Dict[str, FeatureStore] = {} + """A collection of attached feature stores""" + self._featurestore_factory = config_loader._featurestore_factory + """A factory method to create a desired feature store client type""" + self._backbone: t.Optional[FeatureStore] = config_loader.get_backbone() + """A standalone, system-created feature store used to share internal + information among MLI components""" + self._callback_factory = config_loader._callback_factory + """The type of communication channel to construct for callbacks""" + self._worker = worker_type() + """The worker used to batch inputs""" + self._mem_pool = MemoryPool.attach(dragon_gs_pool.create(mem_pool_size).sdesc) + """Memory pool used to share batched input tensors with the Worker Managers""" + self._perf_timer = PerfTimer(prefix="r_", debug=False, timing_on=True) + """Performance timer""" + + def _check_feature_stores(self, request: InferenceRequest) -> bool: + """Ensures that all feature stores required by the request are available + + :param request: The request to validate + :returns: False if feature store validation fails for the request, True + otherwise + """ + # collect all feature stores required by the request + fs_model: t.Set[str] = set() + if request.model_key: + fs_model = {request.model_key.descriptor} + fs_inputs = {key.descriptor for key in request.input_keys} + fs_outputs = {key.descriptor for key in request.output_keys} + + # identify which feature stores are requested and unknown + fs_desired = fs_model.union(fs_inputs).union(fs_outputs) + fs_actual = {item.descriptor for item in self._feature_stores.values()} + fs_missing = fs_desired - fs_actual + + if self._featurestore_factory is None: + logger.error("No feature store factory configured") + return False + + # create the feature stores we need to service request + if fs_missing: + logger.debug(f"Adding feature store(s): {fs_missing}") + for descriptor in fs_missing: + feature_store = self._featurestore_factory(descriptor) + self._feature_stores[descriptor] = feature_store + + return True + + # pylint: disable-next=no-self-use + def _check_model(self, request: InferenceRequest) -> bool: + """Ensure that a model is available for the request + + :param request: The request to validate + :returns: False if model 
validation fails for the request, True otherwise + """ + if request.model_key or request.raw_model: + return True + + logger.error("Unable to continue without model bytes or feature store key") + return False + + # pylint: disable-next=no-self-use + def _check_inputs(self, request: InferenceRequest) -> bool: + """Ensure that inputs are available for the request + + :param request: The request to validate + :returns: False if input validation fails for the request, True otherwise + """ + if request.input_keys or request.raw_inputs: + return True + + logger.error("Unable to continue without input bytes or feature store keys") + return False + + # pylint: disable-next=no-self-use + def _check_callback(self, request: InferenceRequest) -> bool: + """Ensure that a callback channel is available for the request + + :param request: The request to validate + :returns: False if callback validation fails for the request, True otherwise + """ + if request.callback is not None: + return True + + logger.error("No callback channel provided in request") + return False + + def _validate_request(self, request: InferenceRequest) -> bool: + """Ensure the request can be processed + + :param request: The request to validate + :return: False if the request fails any validation checks, True otherwise""" + checks = [ + self._check_feature_stores(request), + self._check_model(request), + self._check_inputs(request), + self._check_callback(request), + ] + + return all(checks) + + def _on_iteration(self) -> None: + """This method is executed repeatedly until ``Service`` shutdown + conditions are satisfied and cooldown is elapsed. + """ + try: + self._perf_timer.set_active(True) + bytes_list: t.List[bytes] = self._incoming_channel.recv() + except Exception: + self._perf_timer.set_active(False) + else: + if not bytes_list: + exception_handler( + ValueError("No request data found"), + None, + "No request data found.", + ) + + request_bytes = bytes_list[0] + tensor_bytes_list = bytes_list[1:] + self._perf_timer.start_timings() + + request = self._worker.deserialize_message( + request_bytes, self._callback_factory + ) + if request.input_meta and tensor_bytes_list: + request.raw_inputs = tensor_bytes_list + + self._perf_timer.measure_time("deserialize_message") + + if not self._validate_request(request): + exception_handler( + ValueError("Error validating the request"), + request.callback, + "Error validating the request.", + ) + self._perf_timer.measure_time("validate_request") + else: + self._perf_timer.measure_time("validate_request") + self.dispatch(request) + self._perf_timer.measure_time("dispatch") + finally: + self.flush_requests() + self.remove_queues() + + self._perf_timer.end_timings() + + if self._perf_timer.max_length == 801 and self._perf_timer.is_active: + self._perf_timer.print_timings(True) + + def remove_queues(self) -> None: + """Remove references to queues that can be removed + and allow them to be garbage collected""" + queue_lists_to_remove = [] + for key, queues in self._queues.items(): + queues_to_remove = [] + for queue in queues: + if queue.can_be_removed: + queues_to_remove.append(queue) + + for queue_to_remove in queues_to_remove: + queues.remove(queue_to_remove) + if ( + key in self._active_queues + and self._active_queues[key] == queue_to_remove + ): + del self._active_queues[key] + + if len(queues) == 0: + queue_lists_to_remove.append(key) + + for key in queue_lists_to_remove: + del self._queues[key] + + @property + def task_queue(self) -> DragonQueue: + """The queue on which batched requests 
are placed""" + return self._outgoing_queue + + def _swap_queue(self, model_id: ModelIdentifier) -> None: + """Get an empty queue or create a new one + + and make it the active one for a given model. + :param model_id: The id of the model for which the + queue has to be swapped + """ + if model_id.key in self._queues: + for queue in self._queues[model_id.key]: + if not queue.full(): + self._active_queues[model_id.key] = queue + return + + new_queue = BatchQueue(self._batch_timeout, self._batch_size, model_id) + if model_id.key in self._queues: + self._queues[model_id.key].append(new_queue) + else: + self._queues[model_id.key] = [new_queue] + self._active_queues[model_id.key] = new_queue + return + + def dispatch(self, request: InferenceRequest) -> None: + """Assign a request to a batch queue + :param request: the request to place + """ + if request.raw_model is not None: + logger.debug("Direct inference requested, creating tmp queue") + tmp_id = f"_tmp_{str(uuid.uuid4())}" + tmp_queue: BatchQueue = BatchQueue( + batch_timeout=0, + batch_size=1, + model_id=ModelIdentifier(key=tmp_id, descriptor="TMP"), + ) + self._active_queues[tmp_id] = tmp_queue + self._queues[tmp_id] = [tmp_queue] + tmp_queue.put_nowait(request) + tmp_queue.make_disposable() + return + + if request.model_key: + success = False + while not success: + try: + self._active_queues[request.model_key.key].put_nowait(request) + success = True + except (Full, KeyError): + self._swap_queue(request.model_key) + + def flush_requests(self) -> None: + """Get all requests from queues which are ready to be flushed. Place all + avaliable request batches in the outgoing queue. + """ + for queue_list in self._queues.values(): + for queue in queue_list: + if queue.ready: + self._perf_timer.measure_time("find_queue") + try: + batch = RequestBatch( + requests=queue.flush(), + inputs=None, + model_id=queue.model_id, + ) + finally: + self._perf_timer.measure_time("flush_requests") + try: + fetch_results = self._worker.fetch_inputs( + batch=batch, feature_stores=self._feature_stores + ) + except Exception as exc: + exception_handler( + exc, + None, + "Error fetching input.", + ) + continue + self._perf_timer.measure_time("fetch_input") + try: + transformed_inputs = self._worker.transform_input( + batch=batch, + fetch_results=fetch_results, + mem_pool=self._mem_pool, + ) + except Exception as exc: + exception_handler( + exc, + None, + "Error Transforming input.", + ) + continue + + self._perf_timer.measure_time("transform_input") + batch.inputs = transformed_inputs + for request in batch.requests: + request.raw_inputs = [] + request.input_meta = [] + + try: + self._outgoing_queue.put(batch) + except Exception as exc: + exception_handler( + exc, + None, + "Error placing batch on task queue.", + ) + continue + self._perf_timer.measure_time("put") + + def _can_shutdown(self) -> bool: + """Whether the Service can be shut down""" + return False + + def __del__(self) -> None: + self._mem_pool.destroy() diff --git a/smartsim/_core/mli/infrastructure/control/workermanager.py b/smartsim/_core/mli/infrastructure/control/workermanager.py index dcc35ae831..54a245b813 100644 --- a/smartsim/_core/mli/infrastructure/control/workermanager.py +++ b/smartsim/_core/mli/infrastructure/control/workermanager.py @@ -24,67 +24,42 @@ # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
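The flush policy implemented by `BatchQueue` above (a batch is ready once the queue is full, or once `batch_timeout` seconds have passed since the first request was enqueued) can be illustrated with a self-contained toy; the class below mimics the logic only and is not the SmartSim implementation.

```python
import time
import typing as t
from queue import Empty, Queue


class ToyBatchQueue(Queue):
    def __init__(self, batch_timeout: float, batch_size: int) -> None:
        super().__init__(maxsize=batch_size)
        self._batch_timeout = batch_timeout
        self._first_put: t.Optional[float] = None

    def put(
        self, item: t.Any, block: bool = False, timeout: t.Optional[float] = 0.0
    ) -> None:
        super().put(item, block=block, timeout=timeout)
        if self._first_put is None:
            self._first_put = time.time()  # timeout clock starts at first item

    @property
    def ready(self) -> bool:
        if self.empty():
            return False
        timed_out = (
            self._batch_timeout > 0
            and self._first_put is not None
            and time.time() - self._first_put >= self._batch_timeout
        )
        return self.full() or timed_out

    def flush(self) -> t.List[t.Any]:
        self._first_put = None
        items = []
        while True:
            try:
                items.append(self.get_nowait())
            except Empty:
                return items


q = ToyBatchQueue(batch_timeout=0.05, batch_size=4)
q.put("req-1")
assert not q.ready  # not full and the timeout has not elapsed yet
time.sleep(0.06)
assert q.ready      # timed out with a partial batch
print(q.flush())    # ['req-1']
```

In the real dispatcher, `flush_requests` turns each ready queue into a `RequestBatch`, fetches and transforms the inputs, and places the batch on the outgoing queue consumed by the worker managers.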
+# pylint: disable=import-error +# pylint: disable-next=unused-import +import dragon + +# pylint: enable=import-error + +# isort: off +# isort: on + +import multiprocessing as mp import time import typing as t +from queue import Empty from smartsim._core.mli.infrastructure.storage.featurestore import FeatureStore from .....log import get_logger from ....entrypoints.service import Service -from ...comm.channel.channel import CommChannelBase -from ...comm.channel.dragonchannel import DragonCommChannel +from ....utils.timings import PerfTimer from ...infrastructure.environmentloader import EnvironmentConfigLoader from ...infrastructure.worker.worker import ( InferenceReply, - InferenceRequest, LoadModelResult, MachineLearningWorkerBase, + RequestBatch, ) from ...message_handler import MessageHandler -from ...mli_schemas.response.response_capnp import ResponseBuilder +from .devicemanager import DeviceManager, WorkerDevice +from .error_handling import build_failure_reply, exception_handler if t.TYPE_CHECKING: - from dragon.fli import FLInterface - from smartsim._core.mli.mli_schemas.response.response_capnp import Status logger = get_logger(__name__) -def build_failure_reply(status: "Status", message: str) -> ResponseBuilder: - """Build a response indicating a failure occurred - :param status: The status of the response - :param message: The error message to include in the response""" - return MessageHandler.build_response( - status=status, - message=message, - result=None, - custom_attributes=None, - ) - - -def exception_handler( - exc: Exception, reply_channel: t.Optional[CommChannelBase], failure_message: str -) -> None: - """ - Logs exceptions and sends a failure response. - - :param exc: The exception to be logged - :param reply_channel: The channel used to send replies - :param failure_message: Failure message to log and send back - """ - logger.exception( - f"{failure_message}\n" - f"Exception type: {type(exc).__name__}\n" - f"Exception message: {str(exc)}" - ) - serialized_resp = MessageHandler.serialize_response( - build_failure_reply("fail", failure_message) - ) - if reply_channel: - reply_channel.send(serialized_resp) - - class WorkerManager(Service): """An implementation of a service managing distribution of tasks to machine learning workers""" @@ -92,26 +67,29 @@ class WorkerManager(Service): def __init__( self, config_loader: EnvironmentConfigLoader, - worker: MachineLearningWorkerBase, + worker_type: t.Type[MachineLearningWorkerBase], + dispatcher_queue: "mp.Queue[RequestBatch]", as_service: bool = False, cooldown: int = 0, device: t.Literal["cpu", "gpu"] = "cpu", ) -> None: """Initialize the WorkerManager - :param config_loader: Environment config loader that loads the task queue and - feature store - :param workers: A worker to manage + :param config_loader: Environment config loader for loading queues + and feature stores + :param worker_type: The type of worker to manage + :param dispatcher_queue: Queue from which the batched requests are pulled :param as_service: Specifies run-once or run-until-complete behavior of service :param cooldown: Number of seconds to wait before shutting down after shutdown criteria are met - :param device: The type of hardware the workers must be executed on + :param device: The device on which the Worker should run. Every worker manager + is assigned one single GPU (if available), thus the device should have no index. 
""" super().__init__(as_service, cooldown) - self._task_queue: t.Optional[CommChannelBase] = config_loader.get_queue() - """the queue the manager monitors for new tasks""" - self._worker = worker + self._dispatcher_queue = dispatcher_queue + """The Dispatcher queue that the WorkerManager monitors for new batches""" + self._worker = worker_type() """The ML Worker implementation""" self._callback_factory = config_loader._callback_factory """The type of communication channel to construct for callbacks""" @@ -126,19 +104,28 @@ def __init__( self._backbone: t.Optional[FeatureStore] = config_loader.get_backbone() """A standalone, system-created feature store used to share internal information among MLI components""" + self._device_manager: t.Optional[DeviceManager] = None + """Object responsible for model caching and device access""" + self._perf_timer = PerfTimer(prefix="w_", debug=False, timing_on=True) + """Performance timer""" - def _check_feature_stores(self, request: InferenceRequest) -> bool: + def _on_start(self) -> None: + """Called on initial entry into Service `execute` event loop before + `_on_iteration` is invoked.""" + self._device_manager = DeviceManager(WorkerDevice(self._device)) + + def _check_feature_stores(self, batch: RequestBatch) -> bool: """Ensures that all feature stores required by the request are available - :param request: The request to validate - :returns: False if feature store validation fails for the request, True otherwise + :param batch: The batch of requests to validate + :returns: False if feature store validation fails for the batch, True otherwise """ # collect all feature stores required by the request fs_model: t.Set[str] = set() - if request.model_key: - fs_model = {request.model_key.descriptor} - fs_inputs = {key.descriptor for key in request.input_keys} - fs_outputs = {key.descriptor for key in request.output_keys} + if batch.model_id.key: + fs_model = {batch.model_id.descriptor} + fs_inputs = {key.descriptor for key in batch.input_keys} + fs_outputs = {key.descriptor for key in batch.output_keys} # identify which feature stores are requested and unknown fs_desired = fs_model.union(fs_inputs).union(fs_outputs) @@ -158,269 +145,169 @@ def _check_feature_stores(self, request: InferenceRequest) -> bool: return True - def _check_model(self, request: InferenceRequest) -> bool: - """Ensure that a model is available for the request - - :param request: The request to validate - :returns: False if model validation fails for the request, True otherwise - """ - if request.model_key or request.raw_model: - return True - - logger.error("Unable to continue without model bytes or feature store key") - return False - - def _check_inputs(self, request: InferenceRequest) -> bool: - """Ensure that inputs are available for the request - - :param request: The request to validate - :returns: False if input validation fails for the request, True otherwise - """ - if request.input_keys or request.raw_inputs: - return True - - logger.error("Unable to continue without input bytes or feature store keys") - return False - - def _check_callback(self, request: InferenceRequest) -> bool: - """Ensure that a callback channel is available for the request - - :param request: The request to validate - :returns: False if callback validation fails for the request, True otherwise - """ - if request.callback is not None: - return True - - logger.error("No callback channel provided in request") - return False - - def _validate_request(self, request: InferenceRequest) -> bool: + def 
_validate_batch(self, batch: RequestBatch) -> bool: """Ensure the request can be processed - :param request: The request to validate + :param batch: The batch of requests to validate :return: False if the request fails any validation checks, True otherwise""" - checks = [ - self._check_feature_stores(request), - self._check_model(request), - self._check_inputs(request), - self._check_callback(request), - ] - return all(checks) + if batch is None or len(batch.requests) == 0: + return False + + return self._check_feature_stores(batch) + # remove this when we are done with time measurements + # pylint: disable-next=too-many-statements def _on_iteration(self) -> None: """Executes calls to the machine learning worker implementation to complete the inference pipeline""" - logger.debug("executing worker manager pipeline") - if self._task_queue is None: - logger.error("No queue to check for tasks") + pre_batch_time = time.perf_counter() + try: + batch: RequestBatch = self._dispatcher_queue.get(timeout=0.0001) + except Empty: return - timings = [] # timing - - bytes_list: t.List[bytes] = self._task_queue.recv() + self._perf_timer.start_timings( + "flush_requests", time.perf_counter() - pre_batch_time + ) - if not bytes_list: + if not self._validate_batch(batch): exception_handler( - ValueError("No request data found"), + ValueError("An invalid batch was received"), None, - "No request data found.", + "Error batching inputs, the batch was invalid.", ) return - request_bytes = bytes_list[0] - tensor_bytes_list = bytes_list[1:] - - interm = time.perf_counter() # timing - request = self._worker.deserialize_message( - request_bytes, self._callback_factory - ) - - if request.input_meta and tensor_bytes_list: - request.raw_inputs = tensor_bytes_list + if self._device_manager is None: + for request in batch.requests: + msg = "No Device Manager found. WorkerManager._on_start() " + "must be called after initialization. If possible, " + "you should use `WorkerManager.execute()` instead of " + "directly calling `_on_iteration()`." + try: + self._dispatcher_queue.put(batch) + except Exception: + msg += "\nThe batch could not be put back in the queue " + "and will not be processed." 
+ exception_handler( + RuntimeError(msg), + request.callback, + "Error acquiring device manager", + ) + return - if not self._validate_request(request): - exception_handler( - ValueError("Error validating the request"), - request.callback, - "Error validating the request.", + try: + device_cm = self._device_manager.get_device( + worker=self._worker, + batch=batch, + feature_stores=self._feature_stores, ) - - timings.append(time.perf_counter() - interm) # timing - interm = time.perf_counter() # timing - - reply = InferenceReply() - - if not request.raw_model: - if request.model_key is None: + except Exception as exc: + for request in batch.requests: exception_handler( - ValueError("Could not find model key or model"), + exc, request.callback, - "Could not find model key or model.", + "Error loading model on device or getting device.", ) - return + return + self._perf_timer.measure_time("fetch_model") - if request.model_key.key in self._cached_models: - timings.append(time.perf_counter() - interm) # timing - interm = time.perf_counter() # timing - model_result = LoadModelResult( - self._cached_models[request.model_key.key] - ) + with device_cm as device: - else: - timings.append(time.perf_counter() - interm) # timing - interm = time.perf_counter() # timing - try: - fetch_model_result = self._worker.fetch_model( - request, self._feature_stores - ) - except Exception as e: + try: + model_result = LoadModelResult(device.get_model(batch.model_id.key)) + except Exception as exc: + for request in batch.requests: exception_handler( - e, request.callback, "Failed while fetching the model." + exc, request.callback, "Error getting model from device." ) - return + return + self._perf_timer.measure_time("load_model") - timings.append(time.perf_counter() - interm) # timing - interm = time.perf_counter() # timing - try: - model_result = self._worker.load_model( - request, - fetch_result=fetch_model_result, - device=self._device, - ) - self._cached_models[request.model_key.key] = model_result.model - except Exception as e: + if batch.inputs is None: + for request in batch.requests: exception_handler( - e, + ValueError("Error batching inputs"), request.callback, - "Failed while loading model from feature store.", + "Error batching inputs.", ) - return + return + transformed_input = batch.inputs - else: - timings.append(time.perf_counter() - interm) # timing - interm = time.perf_counter() # timing try: - fetch_model_result = self._worker.fetch_model( - request, self._feature_stores + execute_result = self._worker.execute( + batch, model_result, transformed_input, device.name ) except Exception as e: - exception_handler( - e, request.callback, "Failed while fetching the model." - ) + for request in batch.requests: + exception_handler(e, request.callback, "Failed while executing.") return + self._perf_timer.measure_time("execute") - timings.append(time.perf_counter() - interm) # timing - interm = time.perf_counter() # timing try: - model_result = self._worker.load_model( - request, fetch_result=fetch_model_result, device=self._device + transformed_outputs = self._worker.transform_output( + batch, execute_result ) except Exception as e: - exception_handler( - e, - request.callback, - "Failed while loading model from feature store.", - ) + for request in batch.requests: + exception_handler( + e, request.callback, "Failed while transforming the output." 
+ ) return - timings.append(time.perf_counter() - interm) # timing - interm = time.perf_counter() # timing - try: - fetch_input_result = self._worker.fetch_inputs( - request, self._feature_stores - ) - except Exception as e: - exception_handler(e, request.callback, "Failed while fetching the inputs.") - return - - timings.append(time.perf_counter() - interm) # timing - interm = time.perf_counter() # timing - try: - transformed_input = self._worker.transform_input( - request, fetch_input_result, self._device - ) - except Exception as e: - exception_handler( - e, request.callback, "Failed while transforming the input." - ) - return - - timings.append(time.perf_counter() - interm) # timing - interm = time.perf_counter() # timing - try: - execute_result = self._worker.execute( - request, model_result, transformed_input - ) - except Exception as e: - exception_handler(e, request.callback, "Failed while executing.") - return - - timings.append(time.perf_counter() - interm) # timing - interm = time.perf_counter() # timing - try: - transformed_output = self._worker.transform_output( - request, execute_result, self._device - ) - except Exception as e: - exception_handler( - e, request.callback, "Failed while transforming the output." - ) - return + for request, transformed_output in zip(batch.requests, transformed_outputs): + reply = InferenceReply() + if request.output_keys: + try: + reply.output_keys = self._worker.place_output( + request, + transformed_output, + self._feature_stores, + ) + except Exception as e: + exception_handler( + e, request.callback, "Failed while placing the output." + ) + continue + else: + reply.outputs = transformed_output.outputs + self._perf_timer.measure_time("assign_output") + + if reply.outputs is None or not reply.outputs: + response = build_failure_reply("fail", "Outputs not found.") + else: + reply.status_enum = "complete" + reply.message = "Success" + + results = self._worker.prepare_outputs(reply) + response = MessageHandler.build_response( + status=reply.status_enum, + message=reply.message, + result=results, + custom_attributes=None, + ) - timings.append(time.perf_counter() - interm) # timing - interm = time.perf_counter() # timing - if request.output_keys: - try: - reply.output_keys = self._worker.place_output( - request, transformed_output, self._feature_stores - ) - except Exception as e: - exception_handler( - e, request.callback, "Failed while placing the output." 
- ) - return - else: - reply.outputs = transformed_output.outputs - - timings.append(time.perf_counter() - interm) # timing - interm = time.perf_counter() # timing - - if reply.outputs is None or not reply.outputs: - response = build_failure_reply("fail", "Outputs not found.") - else: - reply.status_enum = "complete" - reply.message = "Success" - - results = self._worker.prepare_outputs(reply) - response = MessageHandler.build_response( - status=reply.status_enum, - message=reply.message, - result=results, - custom_attributes=None, - ) + self._perf_timer.measure_time("build_reply") - timings.append(time.perf_counter() - interm) # timing - interm = time.perf_counter() # timing + serialized_resp = MessageHandler.serialize_response(response) - serialized_resp = MessageHandler.serialize_response(response) + self._perf_timer.measure_time("serialize_resp") - timings.append(time.perf_counter() - interm) # timing - interm = time.perf_counter() # timing - if request.callback: - # send serialized response - request.callback.send(serialized_resp) - if reply.outputs: - # send tensor data after response - for output in reply.outputs: - request.callback.send(output) + if request.callback: + request.callback.send(serialized_resp) + if reply.outputs: + # send tensor data after response + for output in reply.outputs: + request.callback.send(output) + self._perf_timer.measure_time("send") - timings.append(time.perf_counter() - interm) # timing - interm = time.perf_counter() # timing + self._perf_timer.end_timings() - print(" ".join(str(time) for time in timings)) # timing + if self._perf_timer.max_length == 801: + self._perf_timer.print_timings(True) def _can_shutdown(self) -> bool: """Return true when the criteria to shut down the service are met.""" diff --git a/smartsim/_core/mli/infrastructure/worker/torch_worker.py b/smartsim/_core/mli/infrastructure/worker/torch_worker.py index eea349894c..0639d59696 100644 --- a/smartsim/_core/mli/infrastructure/worker/torch_worker.py +++ b/smartsim/_core/mli/infrastructure/worker/torch_worker.py @@ -29,6 +29,9 @@ import numpy as np import torch +# pylint: disable=import-error +from dragon.managed_memory import MemoryAlloc, MemoryPool + from .....error import SmartSimError from .....log import get_logger from ...mli_schemas.tensor import tensor_capnp @@ -36,13 +39,18 @@ ExecuteResult, FetchInputResult, FetchModelResult, - InferenceRequest, LoadModelResult, MachineLearningWorkerBase, + RequestBatch, TransformInputResult, TransformOutputResult, ) +# pylint: enable=import-error + + +torch.set_num_threads(1) +torch.set_num_interop_threads(4) logger = get_logger(__name__) @@ -51,75 +59,150 @@ class TorchWorker(MachineLearningWorkerBase): @staticmethod def load_model( - request: InferenceRequest, fetch_result: FetchModelResult, device: str + batch: RequestBatch, fetch_result: FetchModelResult, device: str ) -> LoadModelResult: if fetch_result.model_bytes: model_bytes = fetch_result.model_bytes - elif request.raw_model and request.raw_model.data: - model_bytes = request.raw_model.data + elif batch.raw_model and batch.raw_model.data: + model_bytes = batch.raw_model.data else: raise ValueError("Unable to load model without reference object") device_to_torch = {"cpu": "cpu", "gpu": "cuda"} - device = device_to_torch[device] + for old, new in device_to_torch.items(): + device = device.replace(old, new) + buffer = io.BytesIO(initial_bytes=model_bytes) - model = torch.jit.load(buffer, map_location=device) # type: ignore + with torch.no_grad(): + model = torch.jit.load(buffer, 
map_location=device) # type: ignore + model.eval() result = LoadModelResult(model) return result @staticmethod def transform_input( - request: InferenceRequest, fetch_result: FetchInputResult, device: str + batch: RequestBatch, + fetch_results: list[FetchInputResult], + mem_pool: MemoryPool, ) -> TransformInputResult: - result = [] + results: list[torch.Tensor] = [] + total_samples = 0 + slices: list[slice] = [] - device_to_torch = {"cpu": "cpu", "gpu": "cuda"} - device = device_to_torch[device] - if fetch_result.meta is None: + all_dims: list[list[int]] = [] + all_dtypes: list[str] = [] + if fetch_results[0].meta is None: raise ValueError("Cannot reconstruct tensor without meta information") - for item, item_meta in zip(fetch_result.inputs, fetch_result.meta): - tensor_desc: tensor_capnp.TensorDescriptor = item_meta - result.append( - torch.from_numpy(np.frombuffer(item, dtype=str(tensor_desc.dataType))) - .to(device) - .reshape(tuple(dim for dim in tensor_desc.dimensions)) + # Traverse inputs to get total number of samples and compute slices + # Assumption: first dimension is samples, all tensors in the same input + # have same number of samples + # thus we only look at the first tensor for each input + for res_idx, fetch_result in enumerate(fetch_results): + if fetch_result.meta is None or any( + item_meta is None for item_meta in fetch_result.meta + ): + raise ValueError("Cannot reconstruct tensor without meta information") + first_tensor_desc: tensor_capnp.TensorDescriptor = fetch_result.meta[0] + num_samples = first_tensor_desc.dimensions[0] + slices.append(slice(total_samples, total_samples + num_samples)) + total_samples = total_samples + num_samples + + if res_idx == len(fetch_results) - 1: + # For each tensor in the last input, get remaining dimensions + # Assumptions: all inputs have the same number of tensors and + # last N-1 dimensions match across inputs for corresponding tensors + # thus: resulting array will be of size (num_samples, all_other_dims) + for item_meta in fetch_result.meta: + tensor_desc: tensor_capnp.TensorDescriptor = item_meta + tensor_dims = list(tensor_desc.dimensions) + all_dims.append([total_samples, *tensor_dims[1:]]) + all_dtypes.append(str(tensor_desc.dataType)) + + for result_tensor_idx, (dims, dtype) in enumerate(zip(all_dims, all_dtypes)): + itemsize = np.empty((1), dtype=dtype).itemsize + alloc_size = int(np.prod(dims) * itemsize) + mem_alloc = mem_pool.alloc(alloc_size) + mem_view = mem_alloc.get_memview() + mem_view[:alloc_size] = b"".join( + [ + fetch_result.inputs[result_tensor_idx] + for fetch_result in fetch_results + ] ) - return TransformInputResult(result) - # return data # note: this fails copy test! 
+ results.append(mem_alloc.serialize()) + + return TransformInputResult(results, slices, all_dims, all_dtypes) + + # pylint: disable-next=unused-argument @staticmethod def execute( - request: InferenceRequest, + batch: RequestBatch, load_result: LoadModelResult, transform_result: TransformInputResult, + device: str, ) -> ExecuteResult: if not load_result.model: raise SmartSimError("Model must be loaded to execute") + device_to_torch = {"cpu": "cpu", "gpu": "cuda"} + for old, new in device_to_torch.items(): + device = device.replace(old, new) + + tensors = [] + mem_allocs = [] + for transformed, dims, dtype in zip( + transform_result.transformed, transform_result.dims, transform_result.dtypes + ): + mem_alloc = MemoryAlloc.attach(transformed) + mem_allocs.append(mem_alloc) + itemsize = np.empty((1), dtype=dtype).itemsize + tensors.append( + torch.from_numpy( + np.frombuffer( + mem_alloc.get_memview()[0 : np.prod(dims) * itemsize], + dtype=dtype, + ).reshape(dims) + ) + ) model: torch.nn.Module = load_result.model - model.eval() - results = [model(tensor).detach() for tensor in transform_result.transformed] + with torch.no_grad(): + model.eval() + results = [ + model( + *[ + tensor.to(device, non_blocking=True).detach() + for tensor in tensors + ] + ) + ] + + transform_result.transformed = [] - execute_result = ExecuteResult(results) + execute_result = ExecuteResult(results, transform_result.slices) + for mem_alloc in mem_allocs: + mem_alloc.free() return execute_result @staticmethod def transform_output( - request: InferenceRequest, + batch: RequestBatch, execute_result: ExecuteResult, - result_device: str, - ) -> TransformOutputResult: - if result_device != "cpu": - transformed = [ - item.to("cpu").numpy().tobytes() for item in execute_result.predictions - ] - - # todo: need the shape from latest schemas added here. - return TransformOutputResult(transformed, None, "c", "float32") # fixme - - return TransformOutputResult( - [item.numpy().tobytes() for item in execute_result.predictions], - None, - "c", - "float32", - ) # fixme + ) -> list[TransformOutputResult]: + transformed_list: list[TransformOutputResult] = [] + cpu_predictions = [ + prediction.cpu() for prediction in execute_result.predictions + ] + for result_slice in execute_result.slices: + transformed = [] + for cpu_item in cpu_predictions: + transformed.append(cpu_item[result_slice].numpy().tobytes()) + + # todo: need the shape from latest schemas added here. + transformed_list.append( + TransformOutputResult(transformed, None, "c", "float32") + ) # fixme + + execute_result.predictions = [] + + return transformed_list diff --git a/smartsim/_core/mli/infrastructure/worker/worker.py b/smartsim/_core/mli/infrastructure/worker/worker.py index 89fb635247..25e4dc49f7 100644 --- a/smartsim/_core/mli/infrastructure/worker/worker.py +++ b/smartsim/_core/mli/infrastructure/worker/worker.py @@ -24,8 +24,15 @@ # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
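The packing-and-slicing idea behind the batched `transform_input`/`execute`/`transform_output` path above can be illustrated without Dragon's `MemoryPool`. The helpers below (`pack_inputs`, `run_and_split`) are hypothetical names for a minimal sketch in which plain byte buffers stand in for managed memory allocations; they are not part of this patch.

```python
# Minimal sketch only: byte buffers stand in for Dragon MemoryPool allocations.
import numpy as np
import torch


def pack_inputs(per_request: list[np.ndarray]) -> tuple[bytes, list[slice], list[int], str]:
    """Concatenate per-request tensors along the sample axis, recording a slice per request."""
    slices, total = [], 0
    for arr in per_request:
        slices.append(slice(total, total + arr.shape[0]))
        total += arr.shape[0]
    dims = [total, *per_request[0].shape[1:]]
    payload = b"".join(arr.tobytes() for arr in per_request)
    return payload, slices, dims, str(per_request[0].dtype)


def run_and_split(
    model: torch.nn.Module, payload: bytes, slices: list[slice], dims: list[int], dtype: str
) -> list[torch.Tensor]:
    """Rebuild the stacked batch, run inference once, then split results back per request."""
    stacked = torch.from_numpy(np.frombuffer(payload, dtype=dtype).reshape(dims).copy())
    with torch.no_grad():
        predictions = model(stacked)
    return [predictions[s] for s in slices]


# Example: two requests with 3 and 2 samples each share a single forward pass
requests = [np.ones((3, 2), dtype=np.float32), 2 * np.ones((2, 2), dtype=np.float32)]
payload, slices, dims, dtype = pack_inputs(requests)
outputs = run_and_split(torch.nn.Linear(2, 1), payload, slices, dims, dtype)
assert [o.shape[0] for o in outputs] == [3, 2]
```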
+# pylint: disable=import-error +from dragon.managed_memory import MemoryPool + +# isort: off +# isort: on + import typing as t from abc import ABC, abstractmethod +from dataclasses import dataclass from .....error import SmartSimError from .....log import get_logger @@ -40,6 +47,9 @@ logger = get_logger(__name__) +# Placeholder +ModelIdentifier = FeatureStoreKey + class InferenceRequest: """Internal representation of an inference request from a client""" @@ -100,19 +110,34 @@ def __init__(self, model: t.Any) -> None: class TransformInputResult: - """A wrapper around a transformed input""" + """A wrapper around a transformed batch of input tensors""" - def __init__(self, result: t.Any) -> None: + def __init__( + self, + result: t.Any, + slices: list[slice], + dims: list[list[int]], + dtypes: list[str], + ) -> None: """Initialize the object""" self.transformed = result + """List of Dragon MemoryAlloc objects on which the tensors are stored""" + self.slices = slices + """Each slice represents which portion of the input tensors belongs to + which request""" + self.dims = dims + """Dimension of the transformed tensors""" + self.dtypes = dtypes + """Data type of transformed tensors""" class ExecuteResult: """A wrapper around inference results""" - def __init__(self, result: t.Any) -> None: + def __init__(self, result: t.Any, slices: list[slice]) -> None: """Initialize the object""" self.predictions = result + self.slices = slices class FetchInputResult: @@ -153,6 +178,62 @@ def __init__(self, result: bytes) -> None: self.model_bytes: bytes = result +@dataclass +class RequestBatch: + """A batch of aggregated inference requests""" + + requests: list[InferenceRequest] + inputs: t.Optional[TransformInputResult] + model_id: ModelIdentifier + + @property + def has_valid_requests(self) -> bool: + """Returns whether the batch contains at least one request. + + :return: True if at least one request is available + """ + return len(self.requests) > 0 + + @property + def has_raw_model(self) -> bool: + """Returns whether the batch has a raw model + + :return: True if the batch has a raw model + """ + return self.raw_model is not None + + @property + def raw_model(self) -> t.Optional[t.Any]: + """Returns the raw model to use to execute for this batch + if it is available. 
+ :return: A model if available, otherwise None""" + if self.has_valid_requests: + return self.requests[0].raw_model + return None + + @property + def input_keys(self) -> t.List[FeatureStoreKey]: + """All input keys available in this batch's requests + + :return: All input keys belonging to requests in this batch""" + keys = [] + for request in self.requests: + keys.extend(request.input_keys) + + return keys + + @property + def output_keys(self) -> t.List[FeatureStoreKey]: + """All output keys available in this batch's requests + + :return: All output keys belonging to requests in this batch""" + keys = [] + for request in self.requests: + keys.extend(request.output_keys) + + return keys + + class MachineLearningWorkerCore: """Basic functionality of ML worker that is shared across all worker types""" @@ -233,29 +314,30 @@ def prepare_outputs(reply: InferenceReply) -> t.List[t.Any]: @staticmethod def fetch_model( - request: InferenceRequest, feature_stores: t.Dict[str, FeatureStore] + batch: RequestBatch, feature_stores: t.Dict[str, FeatureStore] ) -> FetchModelResult: """Given a resource key, retrieve the raw model from a feature store - :param request: The request that triggered the pipeline + :param batch: The batch of requests that triggered the pipeline :param feature_stores: Available feature stores used for persistence - :return: Raw bytes of the model""" + :return: Raw bytes of the model + :raises SmartSimError: if neither a key or a model are provided or the + model cannot be retrieved from the feature store + :raises ValueError: if a feature store is not available and a raw + model is not provided""" - if request.raw_model: - # Should we cache model in the feature store? - # model_key = hash(request.raw_model) - # feature_store[model_key] = request.raw_model - # short-circuit and return the directly supplied model - return FetchModelResult(request.raw_model.data) + # All requests in the same batch share the model + if batch.raw_model: + return FetchModelResult(batch.raw_model.data) if not feature_stores: raise ValueError("Feature store is required for model retrieval") - if not request.model_key: + if batch.model_id is None: raise SmartSimError( "Key must be provided to retrieve model from feature store" ) - key, fsd = request.model_key.key, request.model_key.descriptor + key, fsd = batch.model_id.key, batch.model_id.descriptor try: feature_store = feature_stores[fsd] @@ -267,51 +349,47 @@ def fetch_model( @staticmethod def fetch_inputs( - request: InferenceRequest, feature_stores: t.Dict[str, FeatureStore] - ) -> FetchInputResult: + batch: RequestBatch, feature_stores: t.Dict[str, FeatureStore] + ) -> t.List[FetchInputResult]: """Given a collection of ResourceKeys, identify the physical location and input metadata - :param request: The request that triggered the pipeline + :param batch: The batch of requests that triggered the pipeline :param feature_stores: Available feature stores used for persistence - :return: the fetched input""" + :return: the fetched input + :raises ValueError: If neither an input key or an input tensor are provided + :raises SmartSimError: If a tensor for a given key cannot be retrieved""" + fetch_results = [] + for request in batch.requests: + if request.raw_inputs: + fetch_results.append( + FetchInputResult(request.raw_inputs, request.input_meta) + ) + continue - if request.raw_inputs: - return FetchInputResult(request.raw_inputs, request.input_meta) + if not feature_stores: + raise ValueError("No input and no feature store provided") - if not 
feature_stores: - raise ValueError("No input and no feature store provided") - - if request.input_keys: - data: t.List[bytes] = [] - - for fs_key in request.input_keys: - try: - feature_store = feature_stores[fs_key.descriptor] - tensor_bytes = t.cast(bytes, feature_store[fs_key.key]) - data.append(tensor_bytes) - except KeyError as ex: - logger.exception(ex) - raise SmartSimError( - f"Model could not be retrieved with key {fs_key.key}" - ) from ex - return FetchInputResult( - data, meta=None - ) # fixme: need to get both tensor and descriptor - - raise ValueError("No input source") + if request.input_keys: + data: t.List[bytes] = [] - @staticmethod - def batch_requests( - request: InferenceRequest, transform_result: TransformInputResult - ) -> CreateInputBatchResult: - """Create a batch of requests. Return the batch when batch_size datum have been - collected or a configured batch duration has elapsed. - :param request: The request that triggered the pipeline - :param transform_result: Transformed inputs ready for batching - :return: `None` if batch size has not been reached and timeout not exceeded.""" - if transform_result is not None or request.batch_size: - raise NotImplementedError("Batching is not yet supported") - return CreateInputBatchResult(None) + for fs_key in request.input_keys: + try: + feature_store = feature_stores[fs_key.descriptor] + tensor_bytes = t.cast(bytes, feature_store[fs_key.key]) + data.append(tensor_bytes) + except KeyError as ex: + logger.exception(ex) + raise SmartSimError( + f"Tensor could not be retrieved with key {fs_key.key}" + ) from ex + fetch_results.append( + FetchInputResult(data, meta=None) + ) # fixme: need to get both tensor and descriptor + continue + + raise ValueError("No input source") + + return fetch_results @staticmethod def place_output( @@ -324,7 +402,9 @@ def place_output( :param request: The request that triggered the pipeline :param execute_result: Results from inference :param feature_stores: Available feature stores used for persistence - :return: A collection of keys that were placed in the feature store""" + :return: A collection of keys that were placed in the feature store + :raises ValueError: If a feature store is not provided + """ if not feature_stores: raise ValueError("Feature store is required for output persistence") @@ -342,13 +422,13 @@ def place_output( class MachineLearningWorkerBase(MachineLearningWorkerCore, ABC): - """Abstrct base class providing contract for a machine learning + """Abstract base class providing contract for a machine learning worker implementation.""" @staticmethod @abstractmethod def load_model( - request: InferenceRequest, fetch_result: FetchModelResult, device: str + batch: RequestBatch, fetch_result: FetchModelResult, device: str ) -> LoadModelResult: """Given a loaded MachineLearningModel, ensure it is loaded into device memory @@ -359,35 +439,39 @@ def load_model( @staticmethod @abstractmethod def transform_input( - request: InferenceRequest, fetch_result: FetchInputResult, device: str + batch: RequestBatch, + fetch_results: list[FetchInputResult], + mem_pool: MemoryPool, ) -> TransformInputResult: - """Given a collection of data, perform a transformation on the data + """Given a collection of data, perform a transformation on the data and put + the raw tensor data on a MemoryPool allocation. 
:param request: The request that triggered the pipeline - :param fetch_result: Raw output from fetching inputs out of a feature store - :param device: The device on which the transformed input must be placed + :param fetch_result: Raw outputs from fetching inputs out of a feature store + :param mem_pool: The memory pool used to access batched input tensors :return: The transformed inputs wrapped in a InputTransformResult""" @staticmethod @abstractmethod def execute( - request: InferenceRequest, + batch: RequestBatch, load_result: LoadModelResult, transform_result: TransformInputResult, + device: str, ) -> ExecuteResult: """Execute an ML model on inputs transformed for use by the model - :param request: The request that triggered the pipeline + :param batch: The batch of requests that triggered the pipeline :param load_result: The result of loading the model onto device memory :param transform_result: The result of transforming inputs for model consumption + :param device: The device on which the model will be executed :return: The result of inference wrapped in an ExecuteResult""" @staticmethod @abstractmethod def transform_output( - request: InferenceRequest, execute_result: ExecuteResult, result_device: str - ) -> TransformOutputResult: + batch: RequestBatch, execute_result: ExecuteResult + ) -> t.List[TransformOutputResult]: """Given inference results, perform transformations required to transmit results to the requestor. - :param request: The request that triggered the pipeline + :param batch: The batch of requests that triggered the pipeline :param execute_result: The result of inference wrapped in an ExecuteResult - :param result_device: The device on which the result of inference is placed - :return:""" + :return: A list of transformed outputs""" diff --git a/smartsim/_core/mli/mli_schemas/model/__init__.py b/smartsim/_core/mli/mli_schemas/model/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/smartsim/_core/utils/timings.py b/smartsim/_core/utils/timings.py new file mode 100644 index 0000000000..a61a243220 --- /dev/null +++ b/smartsim/_core/utils/timings.py @@ -0,0 +1,143 @@ +# BSD 2-Clause License +# +# Copyright (c) 2021-2024, Hewlett Packard Enterprise +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
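Taken together, the reworked `MachineLearningWorkerBase` contract is exercised in two halves: the dispatcher fetches and transforms inputs for a whole `RequestBatch`, while the worker manager loads the shared model, executes once per batch, and produces one transformed output per request. Below is a hedged sketch of that call order; `process_batch` and the concrete `worker`, `mem_pool`, and `device` values are assumptions for illustration, not the actual dispatcher or manager code.

```python
# Hedged sketch of the batched call order; not the actual dispatcher/manager implementation.
def process_batch(worker, batch, feature_stores, mem_pool, device):
    # dispatcher side: one FetchInputResult per request, then a single packed batch
    fetch_results = worker.fetch_inputs(batch, feature_stores)
    transformed = worker.transform_input(batch, fetch_results, mem_pool)

    # worker manager side: the whole batch shares one model
    fetched_model = worker.fetch_model(batch, feature_stores)
    loaded = worker.load_model(batch, fetched_model, device)
    execute_result = worker.execute(batch, loaded, transformed, device)

    # one TransformOutputResult per request, matched up via the recorded slices
    return worker.transform_output(batch, execute_result)
```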
+ +import time +import typing as t +from collections import OrderedDict + +import numpy as np + +from ...log import get_logger + +logger = get_logger("PerfTimer") + + +class PerfTimer: + def __init__( + self, + filename: str = "timings", + prefix: str = "", + timing_on: bool = True, + debug: bool = False, + ): + self._start: t.Optional[float] = None + self._interm: t.Optional[float] = None + self._timings: OrderedDict[str, list[t.Union[float, int, str]]] = OrderedDict() + self._timing_on = timing_on + self._filename = filename + self._prefix = prefix + self._debug = debug + + def _add_label_to_timings(self, label: str) -> None: + if label not in self._timings: + self._timings[label] = [] + + @staticmethod + def _format_number(number: t.Union[float, int]) -> str: + return f"{number:0.4e}" + + def start_timings( + self, + first_label: t.Optional[str] = None, + first_value: t.Optional[t.Union[float, int]] = None, + ) -> None: + if self._timing_on: + if first_label is not None and first_value is not None: + mod_label = self._make_label(first_label) + value = self._format_number(first_value) + self._log(f"Started timing: {first_label}: {value}") + self._add_label_to_timings(mod_label) + self._timings[mod_label].append(value) + self._start = time.perf_counter() + self._interm = time.perf_counter() + + def end_timings(self) -> None: + if self._timing_on and self._start is not None: + mod_label = self._make_label("total_time") + self._add_label_to_timings(mod_label) + delta = self._format_number(time.perf_counter() - self._start) + self._timings[self._make_label("total_time")].append(delta) + self._log(f"Finished timing: {mod_label}: {delta}") + self._interm = None + + def _make_label(self, label: str) -> str: + return self._prefix + label + + def _get_delta(self) -> t.Union[float, int]: + if self._interm is None: + return 0 + return time.perf_counter() - self._interm + + def get_last(self, label: str) -> str: + mod_label = self._make_label(label) + if mod_label in self._timings: + value = self._timings[mod_label][-1] + if value: + return f"{label}: {value}" + + return "Not measured yet" + + def measure_time(self, label: str) -> None: + if self._timing_on and self._interm is not None: + mod_label = self._make_label(label) + self._add_label_to_timings(mod_label) + delta = self._format_number(self._get_delta()) + self._timings[mod_label].append(delta) + self._log(f"{mod_label}: {delta}") + self._interm = time.perf_counter() + + def _log(self, msg: str) -> None: + if self._debug: + logger.info(msg) + + @property + def max_length(self) -> int: + if len(self._timings) == 0: + return 0 + return max(len(value) for value in self._timings.values()) + + def print_timings(self, to_file: bool = False) -> None: + print(" ".join(self._timings.keys())) + try: + value_array = np.array(list(self._timings.values()), dtype=float) + except Exception as e: + logger.exception(e) + return + value_array = np.transpose(value_array) + if self._debug: + for i in range(value_array.shape[0]): + print(" ".join(self._format_number(value) for value in value_array[i])) + if to_file: + np.save(self._prefix + self._filename + ".npy", value_array) + + def set_active(self, active: bool = True) -> None: + """Set whether the timer will record time""" + self._timing_on = active + + @property + def is_active(self) -> bool: + """Returns true if the timer will record time""" + return self._timing_on diff --git a/tests/mli/test_core_machine_learning_worker.py b/tests/dragon/test_core_machine_learning_worker.py similarity index 80% rename 
from tests/mli/test_core_machine_learning_worker.py rename to tests/dragon/test_core_machine_learning_worker.py index 7ef4ab259b..231a971241 100644 --- a/tests/mli/test_core_machine_learning_worker.py +++ b/tests/dragon/test_core_machine_learning_worker.py @@ -28,6 +28,9 @@ import time import pytest + +dragon = pytest.importorskip("dragon") + import torch import smartsim.error as sse @@ -35,6 +38,7 @@ from smartsim._core.mli.infrastructure.worker.worker import ( InferenceRequest, MachineLearningWorkerCore, + RequestBatch, TransformInputResult, TransformOutputResult, ) @@ -42,8 +46,8 @@ from .featurestore import FileSystemFeatureStore, MemoryFeatureStore -# The tests in this file belong to the group_a group -pytestmark = pytest.mark.group_b +# The tests in this file belong to the dragon group +pytestmark = pytest.mark.dragon # retrieved from pytest fixtures is_dragon = ( @@ -94,9 +98,11 @@ def test_fetch_model_disk(persist_torch_model: pathlib.Path, test_dir: str) -> N fsd = feature_store.descriptor feature_store[str(persist_torch_model)] = persist_torch_model.read_bytes() - request = InferenceRequest(model_key=FeatureStoreKey(key=key, descriptor=fsd)) + model_key = FeatureStoreKey(key=key, descriptor=fsd) + request = InferenceRequest(model_key=model_key) + batch = RequestBatch([request], None, model_key) - fetch_result = worker.fetch_model(request, {fsd: feature_store}) + fetch_result = worker.fetch_model(batch, {fsd: feature_store}) assert fetch_result.model_bytes assert fetch_result.model_bytes == persist_torch_model.read_bytes() @@ -110,10 +116,12 @@ def test_fetch_model_disk_missing() -> None: key = "/path/that/doesnt/exist" - request = InferenceRequest(model_key=FeatureStoreKey(key=key, descriptor=fsd)) + model_key = FeatureStoreKey(key=key, descriptor=fsd) + request = InferenceRequest(model_key=model_key) + batch = RequestBatch([request], None, model_key) with pytest.raises(sse.SmartSimError) as ex: - worker.fetch_model(request, {fsd: feature_store}) + worker.fetch_model(batch, {fsd: feature_store}) # ensure the error message includes key-identifying information assert key in ex.value.args[0] @@ -133,10 +141,11 @@ def test_fetch_model_feature_store(persist_torch_model: pathlib.Path) -> None: fsd = feature_store.descriptor feature_store[key] = persist_torch_model.read_bytes() - request = InferenceRequest( - model_key=FeatureStoreKey(key=key, descriptor=feature_store.descriptor) - ) - fetch_result = worker.fetch_model(request, {fsd: feature_store}) + model_key = FeatureStoreKey(key=key, descriptor=feature_store.descriptor) + request = InferenceRequest(model_key=model_key) + batch = RequestBatch([request], None, model_key) + + fetch_result = worker.fetch_model(batch, {fsd: feature_store}) assert fetch_result.model_bytes assert fetch_result.model_bytes == persist_torch_model.read_bytes() @@ -150,13 +159,13 @@ def test_fetch_model_feature_store_missing() -> None: feature_store = MemoryFeatureStore() fsd = feature_store.descriptor - request = InferenceRequest( - model_key=FeatureStoreKey(key=key, descriptor=feature_store.descriptor) - ) + model_key = FeatureStoreKey(key=key, descriptor=feature_store.descriptor) + request = InferenceRequest(model_key=model_key) + batch = RequestBatch([request], None, model_key) # todo: consider that raising this exception shows impl. replace... 
with pytest.raises(sse.SmartSimError) as ex: - worker.fetch_model(request, {fsd: feature_store}) + worker.fetch_model(batch, {fsd: feature_store}) # ensure the error message includes key-identifying information assert key in ex.value.args[0] @@ -173,11 +182,11 @@ def test_fetch_model_memory(persist_torch_model: pathlib.Path) -> None: fsd = feature_store.descriptor feature_store[key] = persist_torch_model.read_bytes() - request = InferenceRequest( - model_key=FeatureStoreKey(key=key, descriptor=feature_store.descriptor) - ) + model_key = FeatureStoreKey(key=key, descriptor=feature_store.descriptor) + request = InferenceRequest(model_key=model_key) + batch = RequestBatch([request], None, model_key) - fetch_result = worker.fetch_model(request, {fsd: feature_store}) + fetch_result = worker.fetch_model(batch, {fsd: feature_store}) assert fetch_result.model_bytes assert fetch_result.model_bytes == persist_torch_model.read_bytes() @@ -193,12 +202,16 @@ def test_fetch_input_disk(persist_torch_tensor: pathlib.Path) -> None: request = InferenceRequest( input_keys=[FeatureStoreKey(key=tensor_name, descriptor=fsd)] ) + + model_key = FeatureStoreKey(key="test-model", descriptor=fsd) + batch = RequestBatch([request], None, model_key) + worker = MachineLearningWorkerCore feature_store[tensor_name] = persist_torch_tensor.read_bytes() - fetch_result = worker.fetch_inputs(request, {fsd: feature_store}) - assert fetch_result.inputs is not None + fetch_result = worker.fetch_inputs(batch, {fsd: feature_store}) + assert fetch_result[0].inputs is not None def test_fetch_input_disk_missing() -> None: @@ -212,8 +225,11 @@ def test_fetch_input_disk_missing() -> None: request = InferenceRequest(input_keys=[FeatureStoreKey(key=key, descriptor=fsd)]) + model_key = FeatureStoreKey(key="test-model", descriptor=fsd) + batch = RequestBatch([request], None, model_key) + with pytest.raises(sse.SmartSimError) as ex: - worker.fetch_inputs(request, {fsd: feature_store}) + worker.fetch_inputs(batch, {fsd: feature_store}) # ensure the error message includes key-identifying information assert key[0] in ex.value.args[0] @@ -236,9 +252,14 @@ def test_fetch_input_feature_store(persist_torch_tensor: pathlib.Path) -> None: # put model bytes into the feature store feature_store[tensor_name] = persist_torch_tensor.read_bytes() - fetch_result = worker.fetch_inputs(request, {fsd: feature_store}) - assert fetch_result.inputs - assert list(fetch_result.inputs)[0][:10] == persist_torch_tensor.read_bytes()[:10] + model_key = FeatureStoreKey(key="test-model", descriptor=fsd) + batch = RequestBatch([request], None, model_key) + + fetch_result = worker.fetch_inputs(batch, {fsd: feature_store}) + assert fetch_result[0].inputs + assert ( + list(fetch_result[0].inputs)[0][:10] == persist_torch_tensor.read_bytes()[:10] + ) @pytest.mark.skipif(not torch_available, reason="Torch backend is not installed") @@ -269,9 +290,12 @@ def test_fetch_multi_input_feature_store(persist_torch_tensor: pathlib.Path) -> ] ) - fetch_result = worker.fetch_inputs(request, {fsd: feature_store}) + model_key = FeatureStoreKey(key="test-model", descriptor=fsd) + batch = RequestBatch([request], None, model_key) + + fetch_result = worker.fetch_inputs(batch, {fsd: feature_store}) - raw_bytes = list(fetch_result.inputs) + raw_bytes = list(fetch_result[0].inputs) assert raw_bytes assert raw_bytes[0][:10] == persist_torch_tensor.read_bytes()[:10] assert raw_bytes[1][:10] == body2[:10] @@ -288,8 +312,11 @@ def test_fetch_input_feature_store_missing() -> None: fsd = 
feature_store.descriptor request = InferenceRequest(input_keys=[FeatureStoreKey(key=key, descriptor=fsd)]) + model_key = FeatureStoreKey(key="test-model", descriptor=fsd) + batch = RequestBatch([request], None, model_key) + with pytest.raises(sse.SmartSimError) as ex: - worker.fetch_inputs(request, {fsd: feature_store}) + worker.fetch_inputs(batch, {fsd: feature_store}) # ensure the error message includes key-identifying information assert key in ex.value.args[0] @@ -307,21 +334,11 @@ def test_fetch_input_memory(persist_torch_tensor: pathlib.Path) -> None: feature_store[key] = persist_torch_tensor.read_bytes() request = InferenceRequest(input_keys=[FeatureStoreKey(key=key, descriptor=fsd)]) - fetch_result = worker.fetch_inputs(request, {fsd: feature_store}) - assert fetch_result.inputs is not None - - -def test_batch_requests() -> None: - """Verify batch requests handles an empty data set gracefully""" - worker = MachineLearningWorkerCore - result = TransformInputResult([]) - - request = InferenceRequest(batch_size=10) + model_key = FeatureStoreKey(key="test-model", descriptor=fsd) + batch = RequestBatch([request], None, model_key) - with pytest.raises(NotImplementedError): - # NOTE: we expect this to fail since it's not yet implemented. - # TODO: once implemented, replace this expectation of failure... - worker.batch_requests(request, result) + fetch_result = worker.fetch_inputs(batch, {fsd: feature_store}) + assert fetch_result[0].inputs is not None def test_place_outputs() -> None: diff --git a/tests/dragon/test_device_manager.py b/tests/dragon/test_device_manager.py new file mode 100644 index 0000000000..8edeb60fbb --- /dev/null +++ b/tests/dragon/test_device_manager.py @@ -0,0 +1,185 @@ +# BSD 2-Clause License +# +# Copyright (c) 2021-2024, Hewlett Packard Enterprise +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
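The tests below exercise `DeviceManager.get_device`. Since the device manager's internals are not part of this diff, the following is only a sketch of the observable behaviour those tests check; `get_device_sketch` is a hypothetical name, while the `WorkerDevice` API (`add_model`, `remove_model`, `name`, membership checks) is taken from the tests that follow.

```python
# Hedged outline of the behaviour checked by the tests below; not the real DeviceManager.
from contextlib import contextmanager


@contextmanager
def get_device_sketch(worker_device, worker, batch, feature_stores):
    # ensure the batch's model is resident on the device before yielding it
    if batch.model_id.key not in worker_device:
        fetch_result = worker.fetch_model(batch, feature_stores)
        load_result = worker.load_model(batch, fetch_result, worker_device.name)
        worker_device.add_model(batch.model_id.key, load_result.model)
    try:
        yield worker_device
    finally:
        # models sent inline with the request are evicted after use;
        # models referenced by key stay cached on the device
        if batch.has_raw_model:
            worker_device.remove_model(batch.model_id.key)
```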
+ +import typing as t + +import pytest + +dragon = pytest.importorskip("dragon") + +from smartsim._core.mli.infrastructure.control.devicemanager import ( + DeviceManager, + WorkerDevice, +) +from smartsim._core.mli.infrastructure.storage.featurestore import ( + FeatureStore, + FeatureStoreKey, +) +from smartsim._core.mli.infrastructure.worker.worker import ( + ExecuteResult, + FetchInputResult, + FetchModelResult, + InferenceRequest, + LoadModelResult, + MachineLearningWorkerBase, + RequestBatch, + TransformInputResult, + TransformOutputResult, +) + +# The tests in this file belong to the dragon group +pytestmark = pytest.mark.dragon + + +class MockWorker(MachineLearningWorkerBase): + @staticmethod + def fetch_model( + batch: RequestBatch, feature_stores: t.Dict[str, FeatureStore] + ) -> FetchModelResult: + if batch.has_raw_model: + return FetchModelResult(batch.raw_model) + return FetchModelResult(b"fetched_model") + + @staticmethod + def load_model( + batch: RequestBatch, fetch_result: FetchModelResult, device: str + ) -> LoadModelResult: + return LoadModelResult(fetch_result.model_bytes) + + @staticmethod + def transform_input( + batch: RequestBatch, + fetch_results: list[FetchInputResult], + mem_pool: "MemoryPool", + ) -> TransformInputResult: + return TransformInputResult(b"result", [slice(0, 1)], [[1, 2]], ["float32"]) + + @staticmethod + def execute( + batch: RequestBatch, + load_result: LoadModelResult, + transform_result: TransformInputResult, + device: str, + ) -> ExecuteResult: + return ExecuteResult(b"result", [slice(0, 1)]) + + @staticmethod + def transform_output( + batch: RequestBatch, execute_result: ExecuteResult + ) -> t.List[TransformOutputResult]: + return [TransformOutputResult(b"result", None, "c", "float32")] + + +def test_worker_device(): + worker_device = WorkerDevice("gpu:0") + assert worker_device.name == "gpu:0" + + model_key = "my_model_key" + model = b"the model" + + worker_device.add_model(model_key, model) + + assert model_key in worker_device + assert worker_device.get_model(model_key) == model + worker_device.remove_model(model_key) + + assert model_key not in worker_device + + +def test_device_manager_model_in_request(): + + worker_device = WorkerDevice("gpu:0") + device_manager = DeviceManager(worker_device) + + worker = MockWorker() + + tensor_key = FeatureStoreKey(key="key", descriptor="desc") + output_key = FeatureStoreKey(key="key", descriptor="desc") + model_key = FeatureStoreKey(key="model key", descriptor="desc") + + request = InferenceRequest( + model_key=model_key, + callback=None, + raw_inputs=None, + input_keys=[tensor_key], + input_meta=None, + output_keys=[output_key], + raw_model=b"raw model", + batch_size=0, + ) + + request_batch = RequestBatch( + [request], + TransformInputResult(b"transformed", [slice(0, 1)], [[1, 2]], ["float32"]), + model_id=model_key, + ) + + with device_manager.get_device( + worker=worker, batch=request_batch, feature_stores={} + ) as returned_device: + + assert returned_device == worker_device + assert worker_device.get_model(model_key.key) == b"raw model" + + assert model_key.key not in worker_device + + +def test_device_manager_model_key(): + + worker_device = WorkerDevice("gpu:0") + device_manager = DeviceManager(worker_device) + + worker = MockWorker() + + tensor_key = FeatureStoreKey(key="key", descriptor="desc") + output_key = FeatureStoreKey(key="key", descriptor="desc") + model_key = FeatureStoreKey(key="model key", descriptor="desc") + + request = InferenceRequest( + model_key=model_key, + callback=None, + 
raw_inputs=None, + input_keys=[tensor_key], + input_meta=None, + output_keys=[output_key], + raw_model=None, + batch_size=0, + ) + + request_batch = RequestBatch( + [request], + TransformInputResult(b"transformed", [slice(0, 1)], [[1, 2]], ["float32"]), + model_id=model_key, + ) + + with device_manager.get_device( + worker=worker, batch=request_batch, feature_stores={} + ) as returned_device: + + assert returned_device == worker_device + assert worker_device.get_model(model_key.key) == b"fetched_model" + + assert model_key.key in worker_device diff --git a/tests/dragon/test_error_handling.py b/tests/dragon/test_error_handling.py index 5603269b2f..b20424866a 100644 --- a/tests/dragon/test_error_handling.py +++ b/tests/dragon/test_error_handling.py @@ -30,12 +30,19 @@ dragon = pytest.importorskip("dragon") +import multiprocessing as mp + import dragon.utils as du from dragon.channels import Channel from dragon.data.ddict.ddict import DDict from dragon.fli import FLInterface +from dragon.mpbridge.queues import DragonQueue from smartsim._core.mli.comm.channel.dragonfli import DragonFLIChannel +from smartsim._core.mli.infrastructure.control.devicemanager import WorkerDevice +from smartsim._core.mli.infrastructure.control.requestdispatcher import ( + RequestDispatcher, +) from smartsim._core.mli.infrastructure.control.workermanager import ( WorkerManager, exception_handler, @@ -44,13 +51,18 @@ from smartsim._core.mli.infrastructure.storage.dragonfeaturestore import ( DragonFeatureStore, ) -from smartsim._core.mli.infrastructure.storage.featurestore import FeatureStore +from smartsim._core.mli.infrastructure.storage.featurestore import ( + FeatureStore, + FeatureStoreKey, +) from smartsim._core.mli.infrastructure.worker.worker import ( ExecuteResult, FetchInputResult, FetchModelResult, InferenceReply, + InferenceRequest, LoadModelResult, + RequestBatch, TransformInputResult, TransformOutputResult, ) @@ -85,7 +97,7 @@ def setup_worker_manager_model_bytes( backbone_descriptor: str, app_feature_store: FeatureStore, ): - integrated_worker = IntegratedTorchWorker() + integrated_worker_type = IntegratedTorchWorker chan = Channel.make_process_local() queue = FLInterface(main_ch=chan) @@ -95,17 +107,136 @@ def setup_worker_manager_model_bytes( # Put backbone descriptor into env var for the `EnvironmentConfigLoader` monkeypatch.setenv("_SMARTSIM_INFRA_BACKBONE", backbone_descriptor) + config_loader = EnvironmentConfigLoader( + featurestore_factory=DragonFeatureStore.from_descriptor, + callback_factory=FileSystemCommChannel.from_descriptor, + queue_factory=DragonFLIChannel.from_descriptor, + ) + + dispatcher_task_queue = mp.Queue(maxsize=0) + worker_manager = WorkerManager( - EnvironmentConfigLoader( - featurestore_factory=DragonFeatureStore.from_descriptor, - callback_factory=FileSystemCommChannel.from_descriptor, - queue_factory=DragonFLIChannel.from_descriptor, - ), - integrated_worker, + config_loader=config_loader, + worker_type=integrated_worker_type, + dispatcher_queue=dispatcher_task_queue, as_service=False, cooldown=3, ) + tensor_key = FeatureStoreKey(key="key", descriptor=app_feature_store.descriptor) + output_key = FeatureStoreKey(key="key", descriptor=app_feature_store.descriptor) + + request = InferenceRequest( + model_key=None, + callback=None, + raw_inputs=None, + input_keys=[tensor_key], + input_meta=None, + output_keys=[output_key], + raw_model=b"model", + batch_size=0, + ) + + model_id = FeatureStoreKey(key="key", descriptor=app_feature_store.descriptor) + + request_batch = RequestBatch( 
+ [request], + TransformInputResult(b"transformed", [slice(0, 1)], [[1, 2]], ["float32"]), + model_id=model_id, + ) + + dispatcher_task_queue.put(request_batch) + return worker_manager, integrated_worker_type + + +@pytest.fixture +def setup_worker_manager_model_key( + test_dir: str, + monkeypatch: pytest.MonkeyPatch, + backbone_descriptor: str, + app_feature_store: FeatureStore, +): + integrated_worker_type = IntegratedTorchWorker + + chan = Channel.make_process_local() + queue = FLInterface(main_ch=chan) + monkeypatch.setenv( + "_SMARTSIM_REQUEST_QUEUE", du.B64.bytes_to_str(queue.serialize()) + ) + # Put backbone descriptor into env var for the `EnvironmentConfigLoader` + monkeypatch.setenv("_SMARTSIM_INFRA_BACKBONE", backbone_descriptor) + + config_loader = EnvironmentConfigLoader( + featurestore_factory=DragonFeatureStore.from_descriptor, + callback_factory=FileSystemCommChannel.from_descriptor, + queue_factory=DragonFLIChannel.from_descriptor, + ) + + dispatcher_task_queue = mp.Queue(maxsize=0) + + worker_manager = WorkerManager( + config_loader=config_loader, + worker_type=integrated_worker_type, + dispatcher_queue=dispatcher_task_queue, + as_service=False, + cooldown=3, + ) + + tensor_key = FeatureStoreKey(key="key", descriptor=app_feature_store.descriptor) + output_key = FeatureStoreKey(key="key", descriptor=app_feature_store.descriptor) + model_id = FeatureStoreKey(key="model key", descriptor=app_feature_store.descriptor) + + request = InferenceRequest( + model_key=model_id, + callback=None, + raw_inputs=None, + input_keys=[tensor_key], + input_meta=None, + output_keys=[output_key], + raw_model=b"model", + batch_size=0, + ) + request_batch = RequestBatch( + [request], + TransformInputResult(b"transformed", [slice(0, 1)], [[1, 2]], ["float32"]), + model_id=model_id, + ) + + dispatcher_task_queue.put(request_batch) + return worker_manager, integrated_worker_type + + +@pytest.fixture +def setup_request_dispatcher_model_bytes( + test_dir, + monkeypatch: pytest.MonkeyPatch, + backbone_descriptor: str, + app_feature_store: FeatureStore, +): + integrated_worker_type = IntegratedTorchWorker + + chan = Channel.make_process_local() + queue = FLInterface(main_ch=chan) + monkeypatch.setenv( + "_SMARTSIM_REQUEST_QUEUE", du.B64.bytes_to_str(queue.serialize()) + ) + # Put backbone descriptor into env var for the `EnvironmentConfigLoader` + monkeypatch.setenv("_SMARTSIM_INFRA_BACKBONE", backbone_descriptor) + + config_loader = EnvironmentConfigLoader( + featurestore_factory=DragonFeatureStore.from_descriptor, + callback_factory=FileSystemCommChannel.from_descriptor, + queue_factory=DragonFLIChannel.from_descriptor, + ) + + request_dispatcher = RequestDispatcher( + batch_timeout=0, + batch_size=0, + config_loader=config_loader, + worker_type=integrated_worker_type, + ) + request_dispatcher._on_start() + tensor_key = MessageHandler.build_tensor_key("key", app_feature_store.descriptor) output_key = MessageHandler.build_tensor_key("key", app_feature_store.descriptor) model = MessageHandler.build_model(b"model", "model name", "v 0.0.1") @@ -113,19 +244,20 @@ def setup_worker_manager_model_bytes( test_dir, model, [tensor_key], [output_key], [], None ) ser_request = MessageHandler.serialize_request(request) - worker_manager._task_queue.send(ser_request) - return worker_manager, integrated_worker + request_dispatcher._incoming_channel.send(ser_request) + + return request_dispatcher, integrated_worker_type @pytest.fixture -def setup_worker_manager_model_key( - test_dir: str, +def 
setup_request_dispatcher_model_key( + test_dir, monkeypatch: pytest.MonkeyPatch, backbone_descriptor: str, app_feature_store: FeatureStore, ): - integrated_worker = IntegratedTorchWorker() + integrated_worker_type = IntegratedTorchWorker chan = Channel.make_process_local() queue = FLInterface(main_ch=chan) @@ -135,29 +267,33 @@ def setup_worker_manager_model_key( # Put backbone descriptor into env var for the `EnvironmentConfigLoader` monkeypatch.setenv("_SMARTSIM_INFRA_BACKBONE", backbone_descriptor) - worker_manager = WorkerManager( - EnvironmentConfigLoader( - featurestore_factory=DragonFeatureStore.from_descriptor, - callback_factory=FileSystemCommChannel.from_descriptor, - queue_factory=DragonFLIChannel.from_descriptor, - ), - integrated_worker, - as_service=False, - cooldown=3, + config_loader = EnvironmentConfigLoader( + featurestore_factory=DragonFeatureStore.from_descriptor, + callback_factory=FileSystemCommChannel.from_descriptor, + queue_factory=DragonFLIChannel.from_descriptor, + ) + + request_dispatcher = RequestDispatcher( + batch_timeout=0, + batch_size=0, + config_loader=config_loader, + worker_type=integrated_worker_type, ) + request_dispatcher._on_start() tensor_key = MessageHandler.build_tensor_key("key", app_feature_store.descriptor) output_key = MessageHandler.build_tensor_key("key", app_feature_store.descriptor) model_key = MessageHandler.build_model_key( - "model key", app_feature_store.descriptor + key="model key", feature_store_descriptor=app_feature_store.descriptor ) request = MessageHandler.build_request( test_dir, model_key, [tensor_key], [output_key], [], None ) ser_request = MessageHandler.serialize_request(request) - worker_manager._task_queue.send(ser_request) - return worker_manager, integrated_worker + request_dispatcher._incoming_channel.send(ser_request) + + return request_dispatcher, integrated_worker_type def mock_pipeline_stage(monkeypatch: pytest.MonkeyPatch, integrated_worker, stage): @@ -167,7 +303,7 @@ def mock_stage(*args, **kwargs): monkeypatch.setattr(integrated_worker, stage, mock_stage) mock_reply_fn = MagicMock() monkeypatch.setattr( - "smartsim._core.mli.infrastructure.control.workermanager.build_failure_reply", + "smartsim._core.mli.infrastructure.control.error_handling.build_failure_reply", mock_reply_fn, ) @@ -193,21 +329,15 @@ def mock_exception_handler(exc, reply_channel, failure_message): "stage, error_message", [ pytest.param( - "fetch_model", "Failed while fetching the model.", id="fetch model" + "fetch_model", + "Error loading model on device or getting device.", + id="fetch model", ), pytest.param( "load_model", - "Failed while loading model from feature store.", + "Error loading model on device or getting device.", id="load model", ), - pytest.param( - "fetch_inputs", "Failed while fetching the inputs.", id="fetch inputs" - ), - pytest.param( - "transform_input", - "Failed while transforming the input.", - id="transform inputs", - ), pytest.param("execute", "Failed while executing.", id="execute"), pytest.param( "transform_output", @@ -219,7 +349,7 @@ def mock_exception_handler(exc, reply_channel, failure_message): ), ], ) -def test_pipeline_stage_errors_handled( +def test_wm_pipeline_stage_errors_handled( request, setup_worker_manager, monkeypatch: pytest.MonkeyPatch, @@ -227,7 +357,13 @@ def test_pipeline_stage_errors_handled( error_message: str, ): """Ensures that the worker manager does not crash after a failure in various pipeline stages""" - worker_manager, integrated_worker = request.getfixturevalue(setup_worker_manager) 
+ worker_manager, integrated_worker_type = request.getfixturevalue( + setup_worker_manager + ) + integrated_worker = worker_manager._worker + + worker_manager._on_start() + device = worker_manager._device_manager._device mock_reply_fn = mock_pipeline_stage(monkeypatch, integrated_worker, stage) if stage not in ["fetch_model"]: @@ -236,42 +372,28 @@ def test_pipeline_stage_errors_handled( "fetch_model", MagicMock(return_value=FetchModelResult(b"result_bytes")), ) - if stage not in ["fetch_model", "load_model"]: monkeypatch.setattr( integrated_worker, "load_model", MagicMock(return_value=LoadModelResult(b"result_bytes")), ) - if stage not in ["fetch_model", "load_model", "fetch_inputs"]: monkeypatch.setattr( - integrated_worker, - "fetch_inputs", - MagicMock(return_value=FetchInputResult([b"result_bytes"], None)), - ) - if stage not in ["fetch_model", "load_model", "fetch_inputs", "transform_input"]: - monkeypatch.setattr( - integrated_worker, - "transform_input", - MagicMock(return_value=TransformInputResult(b"result_bytes")), + device, + "get_model", + MagicMock(return_value=b"result_bytes"), ) if stage not in [ "fetch_model", - "load_model", - "fetch_inputs", - "transform_input", "execute", ]: monkeypatch.setattr( integrated_worker, "execute", - MagicMock(return_value=ExecuteResult(b"result_bytes")), + MagicMock(return_value=ExecuteResult(b"result_bytes", [slice(0, 1)])), ) if stage not in [ "fetch_model", - "load_model", - "fetch_inputs", - "transform_input", "execute", "transform_output", ]: @@ -279,7 +401,7 @@ def test_pipeline_stage_errors_handled( integrated_worker, "transform_output", MagicMock( - return_value=TransformOutputResult(b"result", [], "c", "float32") + return_value=[TransformOutputResult(b"result", [], "c", "float32")] ), ) @@ -289,6 +411,56 @@ def test_pipeline_stage_errors_handled( mock_reply_fn.assert_called_with("fail", error_message) +@pytest.mark.parametrize( + "setup_request_dispatcher", + [ + pytest.param("setup_request_dispatcher_model_bytes"), + pytest.param("setup_request_dispatcher_model_key"), + ], +) +@pytest.mark.parametrize( + "stage, error_message", + [ + pytest.param( + "fetch_inputs", + "Error fetching input.", + id="fetch input", + ), + pytest.param( + "transform_input", + "Error Transforming input.", + id="transform input", + ), + ], +) +def test_dispatcher_pipeline_stage_errors_handled( + request, + setup_request_dispatcher, + monkeypatch: pytest.MonkeyPatch, + stage: str, + error_message: str, +): + """Ensures that the request dispatcher does not crash after a failure in various pipeline stages""" + request_dispatcher, integrated_worker_type = request.getfixturevalue( + setup_request_dispatcher + ) + integrated_worker = request_dispatcher._worker + + mock_reply_fn = mock_pipeline_stage(monkeypatch, integrated_worker, stage) + + if stage not in ["fetch_inputs"]: + monkeypatch.setattr( + integrated_worker, + "fetch_inputs", + MagicMock(return_value=[FetchInputResult(result=[b"result"], meta=None)]), + ) + + request_dispatcher._on_iteration() + + mock_reply_fn.assert_called_once() + mock_reply_fn.assert_called_with("fail", error_message) + + def test_exception_handling_helper(monkeypatch: pytest.MonkeyPatch): """Ensures that the worker manager does not crash after a failure in the execute pipeline stage""" @@ -296,7 +468,7 @@ def test_exception_handling_helper(monkeypatch: pytest.MonkeyPatch): mock_reply_fn = MagicMock() monkeypatch.setattr( - "smartsim._core.mli.infrastructure.control.workermanager.build_failure_reply", + 
"smartsim._core.mli.infrastructure.control.error_handling.build_failure_reply", mock_reply_fn, ) diff --git a/tests/dragon/test_request_dispatcher.py b/tests/dragon/test_request_dispatcher.py new file mode 100644 index 0000000000..c8d97dd7ed --- /dev/null +++ b/tests/dragon/test_request_dispatcher.py @@ -0,0 +1,331 @@ +# BSD 2-Clause License +# +# Copyright (c) 2021-2024, Hewlett Packard Enterprise +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +import gc +import io +import logging +import pathlib +import socket +import time +import typing as t +from queue import Empty + +import numpy as np +import pytest + +torch = pytest.importorskip("torch") +dragon = pytest.importorskip("dragon") + +import base64 +import multiprocessing as mp + +try: + mp.set_start_method("dragon") +except Exception: + pass + +import os + +import dragon.channels as dch +import dragon.infrastructure.policy as dragon_policy +import dragon.infrastructure.process_desc as dragon_process_desc +import dragon.native.process as dragon_process +from dragon import fli +from dragon.channels import Channel +from dragon.data.ddict.ddict import DDict +from dragon.managed_memory import MemoryAlloc, MemoryPool +from dragon.mpbridge.queues import DragonQueue + +from smartsim._core.entrypoints.service import Service +from smartsim._core.mli.comm.channel.channel import CommChannelBase +from smartsim._core.mli.comm.channel.dragonchannel import DragonCommChannel +from smartsim._core.mli.comm.channel.dragonfli import DragonFLIChannel +from smartsim._core.mli.infrastructure.control.requestdispatcher import ( + RequestBatch, + RequestDispatcher, +) +from smartsim._core.mli.infrastructure.control.workermanager import ( + EnvironmentConfigLoader, +) +from smartsim._core.mli.infrastructure.storage.dragonfeaturestore import ( + DragonFeatureStore, +) +from smartsim._core.mli.infrastructure.storage.featurestore import FeatureStore +from smartsim._core.mli.infrastructure.worker.torch_worker import TorchWorker +from smartsim._core.mli.message_handler import MessageHandler +from smartsim.log import get_logger + +from .featurestore import FileSystemFeatureStore +from .utils.channel import FileSystemCommChannel + +logger = get_logger(__name__) +# The tests in this file belong to the dragon group 
+pytestmark = pytest.mark.dragon + + +def persist_model_file(model_path: pathlib.Path) -> pathlib.Path: + """Create a simple torch model and persist to disk for + testing purposes. + + TODO: remove once unit tests are in place""" + # test_path = pathlib.Path(work_dir) + if not model_path.parent.exists(): + model_path.parent.mkdir(parents=True, exist_ok=True) + + model_path.unlink(missing_ok=True) + + model = torch.nn.Linear(2, 1) + torch.save(model, model_path) + + return model_path + + +def mock_messages( + request_dispatcher_queue: DragonFLIChannel, + feature_store: FeatureStore, + feature_store_root_dir: pathlib.Path, + comm_channel_root_dir: pathlib.Path, +) -> None: + """Mock event producer for triggering the inference pipeline""" + feature_store_root_dir.mkdir(parents=True, exist_ok=True) + comm_channel_root_dir.mkdir(parents=True, exist_ok=True) + + model_path = persist_model_file(feature_store_root_dir.parent / "model_original.pt") + model_bytes = model_path.read_bytes() + model_key = str(feature_store_root_dir / "model_fs.pt") + + feature_store[model_key] = model_bytes + + for iteration_number in range(2): + + channel_key = Channel.make_process_local().serialize() + callback_channel = DragonCommChannel(channel_key) + + input_path = feature_store_root_dir / f"{iteration_number}/input.pt" + output_path = feature_store_root_dir / f"{iteration_number}/output.pt" + + input_key = str(input_path) + output_key = str(output_path) + + tensor = ( + (iteration_number + 1) * torch.ones((1, 2), dtype=torch.float32) + ).numpy() + fsd = feature_store.descriptor + + tensor_desc = MessageHandler.build_tensor_descriptor( + "c", "float32", list(tensor.shape) + ) + + message_tensor_output_key = MessageHandler.build_tensor_key(output_key, fsd) + message_tensor_input_key = MessageHandler.build_tensor_key(input_key, fsd) + message_model_key = MessageHandler.build_model_key(model_key, fsd) + + request = MessageHandler.build_request( + reply_channel=base64.b64encode(callback_channel.descriptor).decode("utf-8"), + model=message_model_key, + inputs=[tensor_desc], + outputs=[message_tensor_output_key], + output_descriptors=[], + custom_attributes=None, + ) + request_bytes = MessageHandler.serialize_request(request) + with request_dispatcher_queue._fli.sendh( + timeout=None, stream_channel=request_dispatcher_queue._channel + ) as sendh: + sendh.send_bytes(request_bytes) + sendh.send_bytes(tensor.tobytes()) + time.sleep(1) + + +@pytest.fixture +def prepare_environment(test_dir: str) -> pathlib.Path: + """Cleanup prior outputs to run demo repeatedly""" + path = pathlib.Path(f"{test_dir}/workermanager.log") + logging.basicConfig(filename=path.absolute(), level=logging.DEBUG) + return path + + +def service_as_dragon_proc( + service: Service, cpu_affinity: list[int], gpu_affinity: list[int] +) -> dragon_process.Process: + + options = dragon_process_desc.ProcessOptions(make_inf_channels=True) + local_policy = dragon_policy.Policy( + placement=dragon_policy.Policy.Placement.HOST_NAME, + host_name=socket.gethostname(), + cpu_affinity=cpu_affinity, + gpu_affinity=gpu_affinity, + ) + return dragon_process.Process( + target=service.execute, + args=[], + cwd=os.getcwd(), + policy=local_policy, + options=options, + stderr=dragon_process.Popen.STDOUT, + stdout=dragon_process.Popen.STDOUT, + ) + + +def test_request_dispatcher(prepare_environment: pathlib.Path) -> None: + """Test the request dispatcher batching and queueing system + + This also includes setting a queue to disposable, checking that it is no + longer referenced 
by the dispatcher. + """ + + test_path = prepare_environment + fs_path = test_path / "feature_store" + comm_path = test_path / "comm_store" + + to_worker_channel = dch.Channel.make_process_local() + to_worker_fli = fli.FLInterface(main_ch=to_worker_channel, manager_ch=None) + to_worker_fli_serialized = to_worker_fli.serialize() + + # NOTE: env vars should be set prior to instantiating EnvironmentConfigLoader + # or test environment may be unable to send messages w/queue + descriptor = base64.b64encode(to_worker_fli_serialized).decode("utf-8") + os.environ["_SMARTSIM_REQUEST_QUEUE"] = descriptor + + ddict = DDict(1, 2, 4 * 1024**2) + dragon_fs = DragonFeatureStore(ddict) + + config_loader = EnvironmentConfigLoader( + featurestore_factory=DragonFeatureStore.from_descriptor, + callback_factory=DragonCommChannel.from_descriptor, + queue_factory=DragonFLIChannel.from_descriptor, + ) + integrated_worker_type = TorchWorker + + request_dispatcher = RequestDispatcher( + batch_timeout=0, + batch_size=2, + config_loader=config_loader, + worker_type=integrated_worker_type, + mem_pool_size=2 * 1024**2, + ) + + worker_queue = config_loader.get_queue() + if worker_queue is None: + logger.warn( + "FLI input queue not loaded correctly from config_loader: " + f"{config_loader._queue_descriptor}" + ) + + request_dispatcher._on_start() + + for _ in range(2): + batch: t.Optional[RequestBatch] = None + mem_allocs = [] + tensors = [] + fs_path = test_path / f"feature_store" + comm_path = test_path / f"comm_store" + model_key = str(fs_path / "model_fs.pt") + + # create a mock client application to populate the request queue + msg_pump = mp.Process( + target=mock_messages, + args=( + worker_queue, + dragon_fs, + fs_path, + comm_path, + ), + ) + + msg_pump.start() + + time.sleep(1) + + for attempts in range(15): + try: + request_dispatcher._on_iteration() + batch = request_dispatcher.task_queue.get(timeout=1) + break + except Empty: + continue + except Exception as exc: + raise exc + + try: + assert batch is not None + assert batch.has_valid_requests + + transform_result = batch.inputs + for transformed, dims, dtype in zip( + transform_result.transformed, + transform_result.dims, + transform_result.dtypes, + ): + mem_alloc = MemoryAlloc.attach(transformed) + mem_allocs.append(mem_alloc) + itemsize = np.empty((1), dtype=dtype).itemsize + tensors.append( + torch.from_numpy( + np.frombuffer( + mem_alloc.get_memview()[0 : np.prod(dims) * itemsize], + dtype=dtype, + ).reshape(dims) + ) + ) + + assert len(batch.requests) == 2 + assert batch.model_id.key == model_key + assert model_key in request_dispatcher._queues + assert model_key in request_dispatcher._active_queues + assert len(request_dispatcher._queues[model_key]) == 1 + assert request_dispatcher._queues[model_key][0].empty() + assert request_dispatcher._queues[model_key][0].model_id.key == model_key + assert len(tensors) == 1 + assert tensors[0].shape == torch.Size([2, 2]) + + for tensor in tensors: + for sample_idx in range(tensor.shape[0]): + tensor_in = tensor[sample_idx] + tensor_out = (sample_idx + 1) * torch.ones( + (2,), dtype=torch.float32 + ) + assert torch.equal(tensor_in, tensor_out) + + except Exception as exc: + raise exc + finally: + for mem_alloc in mem_allocs: + mem_alloc.free() + + msg_pump.kill() + + request_dispatcher._active_queues[model_key].make_disposable() + assert request_dispatcher._active_queues[model_key].can_be_removed + + request_dispatcher._on_iteration() + + assert model_key not in request_dispatcher._active_queues + assert 
model_key not in request_dispatcher._queues + + # Try to remove the dispatcher and free the memory + del request_dispatcher + gc.collect() diff --git a/tests/mli/test_torch_worker.py b/tests/dragon/test_torch_worker.py similarity index 61% rename from tests/mli/test_torch_worker.py rename to tests/dragon/test_torch_worker.py index 1e8bba7e33..88e800240f 100644 --- a/tests/mli/test_torch_worker.py +++ b/tests/dragon/test_torch_worker.py @@ -25,9 +25,15 @@ # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. import io +import typing as t +import numpy as np import pytest import torch + +dragon = pytest.importorskip("dragon") +import dragon.globalservices.pool as dragon_gs_pool +from dragon.managed_memory import MemoryAlloc, MemoryPool from torch import nn from torch.nn import functional as F @@ -39,14 +45,15 @@ FetchModelResult, InferenceRequest, LoadModelResult, + RequestBatch, TransformInputResult, ) from smartsim._core.mli.message_handler import MessageHandler from smartsim.log import get_logger logger = get_logger(__name__) -# The tests in this file belong to the group_a group -pytestmark = pytest.mark.group_a +# The tests in this file belong to the dragon group +pytestmark = pytest.mark.dragon # simple MNIST in PyTorch @@ -60,7 +67,7 @@ def __init__(self): self.fc1 = nn.Linear(9216, 128) self.fc2 = nn.Linear(128, 10) - def forward(self, x): + def forward(self, x, y): x = self.conv1(x) x = F.relu(x) x = self.conv2(x) @@ -86,7 +93,7 @@ def get_batch() -> torch.Tensor: def create_torch_model(): n = Net() example_forward_input = get_batch() - module = torch.jit.trace(n, example_forward_input) + module = torch.jit.trace(n, [example_forward_input, example_forward_input]) model_buffer = io.BytesIO() torch.jit.save(module, model_buffer) return model_buffer.getvalue() @@ -113,18 +120,27 @@ def get_request() -> InferenceRequest: ) +def get_request_batch_from_request( + request: InferenceRequest, inputs: t.Optional[TransformInputResult] = None +) -> RequestBatch: + + return RequestBatch([request], inputs, request.model_key) + + sample_request: InferenceRequest = get_request() +sample_request_batch: RequestBatch = get_request_batch_from_request(sample_request) worker = TorchWorker() def test_load_model(mlutils) -> None: fetch_model_result = FetchModelResult(sample_request.raw_model) load_model_result = worker.load_model( - sample_request, fetch_model_result, mlutils.get_test_device().lower() + sample_request_batch, fetch_model_result, mlutils.get_test_device().lower() ) assert load_model_result.model( - get_batch().to(torch_device[mlutils.get_test_device().lower()]) + get_batch().to(torch_device[mlutils.get_test_device().lower()]), + get_batch().to(torch_device[mlutils.get_test_device().lower()]), ).shape == torch.Size((20, 10)) @@ -133,44 +149,73 @@ def test_transform_input(mlutils) -> None: sample_request.raw_inputs, sample_request.input_meta ) + mem_pool = MemoryPool.attach(dragon_gs_pool.create(1024**2).sdesc) + transform_input_result = worker.transform_input( - sample_request, fetch_input_result, mlutils.get_test_device().lower() + sample_request_batch, [fetch_input_result], mem_pool ) - assert all( - transformed.shape == get_batch().shape - for transformed in transform_input_result.transformed - ) + batch = get_batch().numpy() + assert transform_input_result.slices[0] == slice(0, batch.shape[0]) + + for tensor_index in range(2): + assert torch.Size(transform_input_result.dims[tensor_index]) == batch.shape + assert transform_input_result.dtypes[tensor_index] == 
str(batch.dtype) + mem_alloc = MemoryAlloc.attach(transform_input_result.transformed[tensor_index]) + itemsize = batch.itemsize + tensor = torch.from_numpy( + np.frombuffer( + mem_alloc.get_memview()[ + 0 : np.prod(transform_input_result.dims[tensor_index]) * itemsize + ], + dtype=transform_input_result.dtypes[tensor_index], + ).reshape(transform_input_result.dims[tensor_index]) + ) + + assert torch.equal( + tensor, torch.from_numpy(sample_request.raw_inputs[tensor_index]) + ) + + mem_pool.destroy() def test_execute(mlutils) -> None: load_model_result = LoadModelResult( Net().to(torch_device[mlutils.get_test_device().lower()]) ) - transform_result = TransformInputResult( - [ - get_batch().to(torch_device[mlutils.get_test_device().lower()]) - for _ in range(2) - ] + fetch_input_result = FetchInputResult( + sample_request.raw_inputs, sample_request.input_meta + ) + + request_batch = get_request_batch_from_request(sample_request, fetch_input_result) + + mem_pool = MemoryPool.attach(dragon_gs_pool.create(1024**2).sdesc) + + transform_result = worker.transform_input( + request_batch, [fetch_input_result], mem_pool ) - execute_result = worker.execute(sample_request, load_model_result, transform_result) + execute_result = worker.execute( + request_batch, + load_model_result, + transform_result, + mlutils.get_test_device().lower(), + ) assert all( result.shape == torch.Size((20, 10)) for result in execute_result.predictions ) + mem_pool.destroy() + def test_transform_output(mlutils): - execute_result = ExecuteResult([torch.rand((20, 10)) for _ in range(2)]) + tensors = [torch.rand((20, 10)) for _ in range(2)] + execute_result = ExecuteResult(tensors, [slice(0, 20)]) - transformed_output = worker.transform_output( - sample_request, execute_result, torch_device[mlutils.get_test_device().lower()] - ) + transformed_output = worker.transform_output(sample_request_batch, execute_result) - assert transformed_output.outputs == [ - item.numpy().tobytes() for item in execute_result.predictions - ] - assert transformed_output.shape == None - assert transformed_output.order == "c" - assert transformed_output.dtype == "float32" + assert transformed_output[0].outputs == [item.numpy().tobytes() for item in tensors] + assert transformed_output[0].shape == None + assert transformed_output[0].order == "c" + assert transformed_output[0].dtype == "float32" diff --git a/tests/dragon/test_worker_manager.py b/tests/dragon/test_worker_manager.py index c8332c260f..a334164257 100644 --- a/tests/dragon/test_worker_manager.py +++ b/tests/dragon/test_worker_manager.py @@ -26,7 +26,6 @@ import io import logging -import multiprocessing as mp import pathlib import time @@ -36,10 +35,18 @@ dragon = pytest.importorskip("dragon") import base64 +import multiprocessing as mp + +try: + mp.set_start_method("dragon") +except Exception: + pass + import os import dragon.channels as dch from dragon import fli +from dragon.mpbridge.queues import DragonQueue from smartsim._core.mli.comm.channel.channel import CommChannelBase from smartsim._core.mli.comm.channel.dragonfli import DragonFLIChannel @@ -174,14 +181,15 @@ def test_worker_manager(prepare_environment: pathlib.Path) -> None: callback_factory=FileSystemCommChannel.from_descriptor, queue_factory=DragonFLIChannel.from_descriptor, ) - integrated_worker = TorchWorker() + integrated_worker_type = TorchWorker worker_manager = WorkerManager( config_loader, - integrated_worker, + integrated_worker_type, as_service=True, cooldown=5, device="cpu", + dispatcher_queue=mp.Queue(maxsize=0), ) 
worker_queue = config_loader.get_queue()
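
Taken together, the tests above exercise the two halves of the new batching path separately: the RequestDispatcher groups incoming inference requests into RequestBatch objects, and the WorkerManager consumes batches and runs them through the worker pipeline. The short sketch below shows how those pieces fit in one place. It is illustrative only: it reuses the constructor arguments and hooks that appear in these tests, assumes the _SMARTSIM_REQUEST_QUEUE and _SMARTSIM_INFRA_BACKBONE environment variables are already exported (as in test_request_dispatcher), takes the WorkerManager import location from its module name, and the direct hand-off of the dispatcher's task_queue to the worker manager is an assumption; test_worker_manager.py substitutes a plain multiprocessing queue.

# Minimal wiring sketch based only on the calls exercised in the tests above.
# Requires a Dragon runtime with _SMARTSIM_REQUEST_QUEUE and
# _SMARTSIM_INFRA_BACKBONE already set, as test_request_dispatcher does.
from smartsim._core.mli.comm.channel.dragonchannel import DragonCommChannel
from smartsim._core.mli.comm.channel.dragonfli import DragonFLIChannel
from smartsim._core.mli.infrastructure.control.requestdispatcher import RequestDispatcher
from smartsim._core.mli.infrastructure.control.workermanager import (
    EnvironmentConfigLoader,
    WorkerManager,  # import path assumed from the module name
)
from smartsim._core.mli.infrastructure.storage.dragonfeaturestore import DragonFeatureStore
from smartsim._core.mli.infrastructure.worker.torch_worker import TorchWorker

config_loader = EnvironmentConfigLoader(
    featurestore_factory=DragonFeatureStore.from_descriptor,
    callback_factory=DragonCommChannel.from_descriptor,
    queue_factory=DragonFLIChannel.from_descriptor,
)

# The dispatcher drains the FLI request queue and groups requests by model
# into RequestBatch objects that it places on its task_queue.
dispatcher = RequestDispatcher(
    batch_timeout=0.0,
    batch_size=2,
    config_loader=config_loader,
    worker_type=TorchWorker,
)

# Handing task_queue to the worker manager directly is an assumption made for
# this sketch; test_worker_manager.py passes a plain multiprocessing queue.
worker_manager = WorkerManager(
    config_loader,
    TorchWorker,
    as_service=True,
    cooldown=5,
    device="cpu",
    dispatcher_queue=dispatcher.task_queue,
)

# In the tests both components are driven through their service hooks: start
# once, then iterate; each dispatcher iteration can yield a ready batch.
dispatcher._on_start()
dispatcher._on_iteration()
batch = dispatcher.task_queue.get(timeout=1)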