Merged

95 commits
e98e2fe
Initial FLI-based implementation
al-rigazzi Jun 25, 2024
043f0e7
Add inference example stub
al-rigazzi Jun 25, 2024
efc9e83
Lint, style, black magic
al-rigazzi Jun 25, 2024
35ec45e
Merge branch 'mli-feature' of https://github.com/CrayLabs/SmartSim in…
al-rigazzi Jun 25, 2024
ed3c42a
Bring up to feature branch
al-rigazzi Jun 25, 2024
e5be26b
Update example
al-rigazzi Jun 25, 2024
a23010f
Change the changelog
al-rigazzi Jun 25, 2024
3c20f46
Make style
al-rigazzi Jun 25, 2024
b9ed5ba
Attempt to mitigate import dragon error
al-rigazzi Jun 26, 2024
0de06f3
Import dragon optional
al-rigazzi Jun 26, 2024
d051385
isort
al-rigazzi Jun 26, 2024
e77b1cd
Fix imports in dragon backend tests
al-rigazzi Jun 26, 2024
a90888d
Style
al-rigazzi Jun 26, 2024
b431221
Fix type
al-rigazzi Jun 26, 2024
23efebc
Rename examples dir
al-rigazzi Jun 26, 2024
09b9d24
Remove old dir
al-rigazzi Jun 26, 2024
56d8e50
Add tests for torch worker
al-rigazzi Jun 26, 2024
6cec83e
Switch to sender-supplied channels in app example
al-rigazzi Jun 27, 2024
3ad6d44
Add prototype client for mock app
al-rigazzi Jun 27, 2024
bd5f133
Update mock app
al-rigazzi Jun 28, 2024
3e343ee
Changes to feature store
al-rigazzi Jul 4, 2024
a0525e5
Merge upstream
al-rigazzi Jul 5, 2024
a2bed26
Make style
al-rigazzi Jul 5, 2024
36e92d9
Fix typing
al-rigazzi Jul 5, 2024
59840a3
Fix lint
al-rigazzi Jul 5, 2024
b35b37d
Remove duplicated/useless comments
al-rigazzi Jul 5, 2024
51e0b17
Bring up to date with new schema
al-rigazzi Jul 9, 2024
1fcf17d
Add feature store prototype caching
al-rigazzi Jul 10, 2024
d76f880
Add redis driver, fix FLI
al-rigazzi Jul 10, 2024
0564d01
Merge branch 'mli-feature' of https://github.com/CrayLabs/SmartSim in…
al-rigazzi Jul 11, 2024
3938ec8
Update post-merge
al-rigazzi Jul 11, 2024
273a7d9
Fix typing
al-rigazzi Jul 11, 2024
a12d923
isort
al-rigazzi Jul 11, 2024
38b0de1
Update envloader test
al-rigazzi Jul 11, 2024
8223f96
Input not concatenated correctly
al-rigazzi Jul 15, 2024
4a83abe
Changes to entrypoint
al-rigazzi Jul 15, 2024
6ea0671
Use batch where needed
al-rigazzi Jul 16, 2024
d26e5f0
Adjustments, get back to one thread
al-rigazzi Jul 16, 2024
293e977
Move timing
al-rigazzi Jul 17, 2024
40c0471
multiprocess solution
al-rigazzi Jul 17, 2024
5893da5
Merge branch 'queue-wm' of https://github.com/al-rigazzi/SmartSim int…
al-rigazzi Jul 17, 2024
0bb1487
Constrain torch threads in worker
al-rigazzi Jul 17, 2024
7b9e00c
Affinity and correct process
al-rigazzi Jul 18, 2024
94a5263
Fixes to example
al-rigazzi Jul 18, 2024
f337de9
Merge branch 'develop' of https://github.com/CrayLabs/SmartSim into n…
al-rigazzi Jul 18, 2024
0819d2b
Merge branch 'new-merger' into queue-wm
al-rigazzi Jul 18, 2024
a7b5262
Add request dispatcher post-merge changes
al-rigazzi Jul 18, 2024
717ef88
Misc fixes
al-rigazzi Jul 19, 2024
05b49f3
Correct exception_handler behavior on batch
al-rigazzi Jul 19, 2024
14c3e9f
Style
al-rigazzi Jul 19, 2024
f93522f
Working post-merge version
al-rigazzi Jul 19, 2024
1bd7388
Fix indexing in multi-output
al-rigazzi Jul 20, 2024
d1e9639
Almost good results
al-rigazzi Jul 21, 2024
91ffaee
New timings API
al-rigazzi Jul 21, 2024
b9e9796
Pre-cleanup, best results so far
al-rigazzi Jul 23, 2024
8958aa1
Make dispatcher a service and refactor
al-rigazzi Jul 24, 2024
79eb936
Fixes for batched requests
al-rigazzi Jul 24, 2024
8759e9f
Pre-PR
al-rigazzi Jul 25, 2024
99d568d
Merge branch 'mli-feature' of https://github.com/CrayLabs/SmartSim in…
al-rigazzi Jul 25, 2024
63a0f31
Remove unused fake versioning function
al-rigazzi Jul 25, 2024
6fb3efd
Fix
al-rigazzi Jul 25, 2024
a0cd4ab
Address review
al-rigazzi Aug 13, 2024
4b66e4b
Merge branch 'mli-feature' of https://github.com/CrayLabs/SmartSim in…
al-rigazzi Aug 13, 2024
af8b639
Static checker passes
al-rigazzi Aug 14, 2024
e4a9db0
Working version, still slow
al-rigazzi Aug 14, 2024
0c0637c
Last fixes
al-rigazzi Aug 14, 2024
7dbeded
Fixing tests
al-rigazzi Aug 17, 2024
0eadc63
MLI driver multi-client
al-rigazzi Aug 17, 2024
8e178d9
Fixed broken test
al-rigazzi Aug 20, 2024
5fb8224
MyPy
al-rigazzi Aug 20, 2024
b6ea732
Fix WM test and add dispatcher error handling
al-rigazzi Aug 21, 2024
9e97e1c
Merge branch 'mli-feature' into queue-wm
al-rigazzi Aug 21, 2024
67242ec
Add RequestDispatcher tests
al-rigazzi Aug 26, 2024
42a00c1
Merge branch 'mli-feature' of https://github.com/CrayLabs/SmartSim in…
al-rigazzi Aug 26, 2024
4a5185b
Added tests for device manager
al-rigazzi Aug 26, 2024
9d0ba30
Fix tests
al-rigazzi Aug 26, 2024
99da355
Style and type
al-rigazzi Aug 26, 2024
c3646d7
Fix mock app
al-rigazzi Aug 26, 2024
c54e880
Small change to app
al-rigazzi Aug 26, 2024
01c6fa9
Merge branch 'mli-feature' into queue-wm
al-rigazzi Aug 26, 2024
093d706
Small change to app
al-rigazzi Aug 26, 2024
d9de5c1
Last fixes!
al-rigazzi Aug 27, 2024
eb03f08
Avoid using t.Self
al-rigazzi Aug 27, 2024
1e1b8c9
Remove unused timing
al-rigazzi Aug 27, 2024
be0b8e0
Split timing for request and tensors
al-rigazzi Aug 27, 2024
bc11d92
Pin watchdog to <5
al-rigazzi Aug 27, 2024
b04f4c1
Style
al-rigazzi Aug 27, 2024
47088f0
Other styling fixes
al-rigazzi Aug 27, 2024
0609eec
Move tests that require dragon.MemoryPool
al-rigazzi Aug 27, 2024
275e102
Update tests
al-rigazzi Aug 27, 2024
b220d99
Style
al-rigazzi Aug 27, 2024
d3ab796
Import or skip dragon
al-rigazzi Aug 27, 2024
14e627e
Isort
al-rigazzi Aug 27, 2024
bbe97ff
Fix pytest import
al-rigazzi Aug 27, 2024
eea793e
Adapt syntax for python 3.9
al-rigazzi Aug 27, 2024
1 change: 1 addition & 0 deletions doc/changelog.md
@@ -13,6 +13,7 @@ Jump to:

Description

- Add RequestDispatcher and support for batching inference requests
- Enable hostname selection for dragon tasks
- Remove pydantic dependency from MLI code
- Update MLI environment variables using new naming convention
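The RequestDispatcher named in the changelog batches inference requests by size and timeout; the driver below passes the corresponding --batch_size and --batch_timeout flags. A minimal sketch of that flush policy, assuming a plain queue of requests (illustrative names only, not the SmartSim implementation):

import queue
import time

def collect_batch(requests: queue.Queue, batch_size: int, batch_timeout: float) -> list:
    # Block for the first request, then drain until the batch is full
    # or the timeout window since the first request has elapsed.
    batch = [requests.get()]
    deadline = time.perf_counter() + batch_timeout
    while len(batch) < batch_size:
        remaining = max(deadline - time.perf_counter(), 0.0)
        try:
            batch.append(requests.get(timeout=remaining))
        except queue.Empty:
            break
    return batch

With batch_timeout=0, as the driver passes below, only requests already sitting in the queue are coalesced into a batch.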
34 changes: 25 additions & 9 deletions ex/high_throughput_inference/mli_driver.py
@@ -1,19 +1,21 @@
import argparse
import os
import base64
import cloudpickle
import sys
from smartsim import Experiment
from smartsim._core.mli.infrastructure.worker.torch_worker import TorchWorker
from smartsim.status import TERMINAL_STATUSES
from smartsim.settings import DragonRunSettings
import time
import typing as t

device = "gpu"
DEVICE = "gpu"
NUM_RANKS = 4
NUM_WORKERS = 1
filedir = os.path.dirname(__file__)
worker_manager_script_name = os.path.join(filedir, "standalone_workermanager.py")
app_script_name = os.path.join(filedir, "mock_app.py")
model_name = os.path.join(filedir, f"resnet50.{device.upper()}.pt")
model_name = os.path.join(filedir, f"resnet50.{DEVICE}.pt")

transport: t.Literal["hsta", "tcp"] = "hsta"

@@ -25,37 +27,51 @@

torch_worker_str = base64.b64encode(cloudpickle.dumps(TorchWorker)).decode("ascii")

worker_manager_rs = exp.create_run_settings(
worker_manager_rs: DragonRunSettings = exp.create_run_settings(
sys.executable,
[
worker_manager_script_name,
"--device",
device,
DEVICE,
"--worker_class",
torch_worker_str,
"--batch_size",
str(NUM_RANKS//NUM_WORKERS),
"--batch_timeout",
str(0.00),
"--num_workers",
str(NUM_WORKERS)
],
)

aff = []

worker_manager_rs.set_cpu_affinity(aff)

worker_manager = exp.create_model("worker_manager", run_settings=worker_manager_rs)
worker_manager.attach_generator_files(to_copy=[worker_manager_script_name])

app_rs = exp.create_run_settings(
app_rs: DragonRunSettings = exp.create_run_settings(
sys.executable,
exe_args=[app_script_name, "--device", device],
exe_args=[app_script_name, "--device", DEVICE, "--log_max_batchsize", str(6)],
)
app_rs.set_tasks_per_node(NUM_RANKS)


app = exp.create_model("app", run_settings=app_rs)
app.attach_generator_files(to_copy=[app_script_name], to_symlink=[model_name])


exp.generate(worker_manager, app, overwrite=True)
exp.start(worker_manager, app, block=False)

while True:
if exp.get_status(app)[0] in TERMINAL_STATUSES:
time.sleep(10)
exp.stop(worker_manager)
break
if exp.get_status(worker_manager)[0] in TERMINAL_STATUSES:
time.sleep(10)
exp.stop(app)
break
time.sleep(5)

print("Exiting.")
136 changes: 56 additions & 80 deletions ex/high_throughput_inference/mock_app.py
@@ -41,20 +41,27 @@
import os
import time
import torch
import numbers

from collections import OrderedDict
from mpi4py import MPI
from smartsim._core.mli.infrastructure.storage.dragonfeaturestore import (
DragonFeatureStore,
)
from smartsim._core.mli.message_handler import MessageHandler
from smartsim.log import get_logger
from smartsim._core.utils.timings import PerfTimer

torch.set_num_interop_threads(16)
torch.set_num_threads(1)

logger = get_logger("App")
logger.info("Started app")

CHECK_RESULTS_AND_MAKE_ALL_SLOWER = False

class ProtoClient:
def __init__(self, timing_on: bool):
comm = MPI.COMM_WORLD
rank = comm.Get_rank()
connect_to_infrastructure()
ddict_str = os.environ["_SMARTSIM_INFRA_BACKBONE"]
self._ddict = DDict.attach(ddict_str)
@@ -70,61 +77,15 @@ def __init__(self, timing_on: bool):
self._from_worker_ch_serialized = self._from_worker_ch.serialize()
self._to_worker_ch = Channel.make_process_local()

self._start = None
self._interm = None
self._timings: OrderedDict[str, list[numbers.Number]] = OrderedDict()
self._timing_on = timing_on

def _add_label_to_timings(self, label: str):
if label not in self._timings:
self._timings[label] = []

@staticmethod
def _format_number(number: numbers.Number):
return f"{number:0.4e}"

def start_timings(self, batch_size: int):
if self._timing_on:
self._add_label_to_timings("batch_size")
self._timings["batch_size"].append(batch_size)
self._start = time.perf_counter()
self._interm = time.perf_counter()

def end_timings(self):
if self._timing_on:
self._add_label_to_timings("total_time")
self._timings["total_time"].append(
self._format_number(time.perf_counter() - self._start)
)

def measure_time(self, label: str):
if self._timing_on:
self._add_label_to_timings(label)
self._timings[label].append(
self._format_number(time.perf_counter() - self._interm)
)
self._interm = time.perf_counter()

def print_timings(self, to_file: bool = False):
print(" ".join(self._timings.keys()))
value_array = numpy.array(
[value for value in self._timings.values()], dtype=float
)
value_array = numpy.transpose(value_array)
for i in range(value_array.shape[0]):
print(" ".join(self._format_number(value) for value in value_array[i]))
if to_file:
numpy.save("timings.npy", value_array)
numpy.savetxt("timings.txt", value_array)
self.perf_timer: PerfTimer = PerfTimer(debug=False, timing_on=timing_on, prefix=f"a{rank}_")

def run_model(self, model: bytes | str, batch: torch.Tensor):
tensors = [batch.numpy()]
self.start_timings(batch.shape[0])
self.perf_timer.start_timings("batch_size", batch.shape[0])
built_tensor_desc = MessageHandler.build_tensor_descriptor(
"c", "float32", list(batch.shape)
)
self.measure_time("build_tensor_descriptor")
built_model = None
self.perf_timer.measure_time("build_tensor_descriptor")
if isinstance(model, str):
model_arg = MessageHandler.build_model_key(model, self._backbone_descriptor)
else:
@@ -137,39 +98,39 @@ def run_model(self, model: bytes | str, batch: torch.Tensor):
output_descriptors=[],
custom_attributes=None,
)
self.measure_time("build_request")
self.perf_timer.measure_time("build_request")
request_bytes = MessageHandler.serialize_request(request)
self.measure_time("serialize_request")
with self._to_worker_fli.sendh(
timeout=None, stream_channel=self._to_worker_ch
) as to_sendh:
self.perf_timer.measure_time("serialize_request")
with self._to_worker_fli.sendh(timeout=None, stream_channel=self._to_worker_ch) as to_sendh:
to_sendh.send_bytes(request_bytes)
for t in tensors:
to_sendh.send_bytes(t.tobytes()) # TODO NOT FAST ENOUGH!!!
# to_sendh.send_bytes(bytes(t.data))
logger.info(f"Message size: {len(request_bytes)} bytes")

self.measure_time("send")
self.perf_timer.measure_time("send_request")
for tensor in tensors:
to_sendh.send_bytes(tensor.tobytes()) #TODO NOT FAST ENOUGH!!!
self.perf_timer.measure_time("send_tensors")
with self._from_worker_ch.recvh(timeout=None) as from_recvh:
resp = from_recvh.recv_bytes(timeout=None)
self.measure_time("receive")
self.perf_timer.measure_time("receive_response")
response = MessageHandler.deserialize_response(resp)
self.measure_time("deserialize_response")
self.perf_timer.measure_time("deserialize_response")
# list of data blobs? recv depending on the len(response.result.descriptors)?
data_blob = from_recvh.recv_bytes(timeout=None)
result = numpy.frombuffer(
data_blob,
dtype=str(response.result.descriptors[0].dataType),
data_blob: bytes = from_recvh.recv_bytes(timeout=None)
self.perf_timer.measure_time("receive_tensor")
result = torch.from_numpy(
numpy.frombuffer(
data_blob,
dtype=str(response.result.descriptors[0].dataType),
)
)
self.measure_time("deserialize_tensor")
self.perf_timer.measure_time("deserialize_tensor")

self.end_timings()
self.perf_timer.end_timings()
return result

def set_model(self, key: str, model: bytes):
self._ddict[key] = model



class ResNetWrapper:
def __init__(self, name: str, model: str):
self._model = torch.jit.load(model)
@@ -190,24 +151,39 @@ def model(self):
def name(self):
return self._name


if __name__ == "__main__":

parser = argparse.ArgumentParser("Mock application")
parser.add_argument("--device", default="cpu")
parser.add_argument("--device", default="cpu", type=str)
parser.add_argument("--log_max_batchsize", default=8, type=int)
args = parser.parse_args()

resnet = ResNetWrapper("resnet50", f"resnet50.{args.device.upper()}.pt")
resnet = ResNetWrapper("resnet50", f"resnet50.{args.device}.pt")

client = ProtoClient(timing_on=True)
client.set_model(resnet.name, resnet.model)

total_iterations = 100
if CHECK_RESULTS_AND_MAKE_ALL_SLOWER:
# TODO: adapt to non-Nvidia devices
torch_device = args.device.replace("gpu", "cuda")
pt_model = torch.jit.load(io.BytesIO(initial_bytes=(resnet.model))).to(torch_device)

for batch_size in [1, 2, 4, 8, 16, 32, 64, 128]:
logger.info(f"Batch size: {batch_size}")
for iteration_number in range(total_iterations + int(batch_size == 1)):
logger.info(f"Iteration: {iteration_number}")
client.run_model(resnet.name, resnet.get_batch(batch_size))
TOTAL_ITERATIONS = 100

client.print_timings(to_file=True)
for log2_bsize in range(args.log_max_batchsize+1):
b_size: int = 2**log2_bsize
logger.info(f"Batch size: {b_size}")
for iteration_number in range(TOTAL_ITERATIONS + int(b_size==1)):
logger.info(f"Iteration: {iteration_number}")
sample_batch = resnet.get_batch(b_size)
remote_result = client.run_model(resnet.name, sample_batch)
logger.info(client.perf_timer.get_last("total_time"))
if CHECK_RESULTS_AND_MAKE_ALL_SLOWER:
local_res = pt_model(sample_batch.to(torch_device))
err_norm = torch.linalg.vector_norm(torch.flatten(remote_result).to(torch_device)-torch.flatten(local_res), ord=1).cpu()
res_norm = torch.linalg.vector_norm(remote_result, ord=1).item()
local_res_norm = torch.linalg.vector_norm(local_res, ord=1).item()
logger.info(f"Avg norm of error {err_norm.item()/b_size} compared to result norm of {res_norm/b_size}:{local_res_norm/b_size}")
torch.cuda.synchronize()

client.perf_timer.print_timings(to_file=True)
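mock_app.py swaps its hand-rolled timing code for PerfTimer from smartsim._core.utils.timings. The removed lines above show roughly what that helper does; a condensed approximation of the interface the client now calls (start_timings, measure_time, end_timings, get_last, print_timings), not the actual class:

import time
from collections import OrderedDict

import numpy

class MiniPerfTimer:
    """Approximation of the PerfTimer interface used by ProtoClient."""

    def __init__(self, timing_on: bool = True, prefix: str = "", debug: bool = False) -> None:
        self._timings: "OrderedDict[str, list]" = OrderedDict()
        self._timing_on = timing_on
        self._prefix = prefix
        self._start = self._interm = 0.0

    def _append(self, label: str, value) -> None:
        self._timings.setdefault(label, []).append(value)

    def start_timings(self, label: str, value) -> None:
        # Record a per-iteration attribute (e.g. batch_size) and reset clocks.
        if self._timing_on:
            self._append(label, value)
            self._start = self._interm = time.perf_counter()

    def measure_time(self, label: str) -> None:
        # Time elapsed since the previous checkpoint.
        if self._timing_on:
            now = time.perf_counter()
            self._append(label, f"{now - self._interm:0.4e}")
            self._interm = now

    def end_timings(self) -> None:
        if self._timing_on:
            self._append("total_time", f"{time.perf_counter() - self._start:0.4e}")

    def get_last(self, label: str):
        values = self._timings.get(label)
        return values[-1] if values else None

    def print_timings(self, to_file: bool = False) -> None:
        # One column per label, one row per iteration.
        print(" ".join(self._timings.keys()))
        table = numpy.array(list(self._timings.values()), dtype=float).T
        for row in table:
            print(" ".join(f"{v:0.4e}" for v in row))
        if to_file:
            numpy.save(f"{self._prefix}timings.npy", table)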
28 changes: 15 additions & 13 deletions ex/high_throughput_inference/mock_app_redis.py
@@ -29,7 +29,9 @@
import numpy
import time
import torch
from mpi4py import MPI
from smartsim.log import get_logger
from smartsim._core.utils.timings import PerfTimer
from smartredis import Client

logger = get_logger("App")
@@ -56,6 +58,9 @@ def name(self):

if __name__ == "__main__":

comm = MPI.COMM_WORLD
rank = comm.Get_rank()

parser = argparse.ArgumentParser("Mock application")
parser.add_argument("--device", default="cpu")
args = parser.parse_args()
@@ -65,24 +70,21 @@ def name(self):
client = Client(cluster=False, address=None)
client.set_model(resnet.name, resnet.model, backend='TORCH', device=args.device.upper())

perf_timer: PerfTimer = PerfTimer(debug=False, timing_on=True, prefix=f"redis{rank}_")

total_iterations = 100
timings=[]
for batch_size in [1, 2, 4, 8, 16, 32, 64, 128]:
logger.info(f"Batch size: {batch_size}")
for iteration_number in range(total_iterations + int(batch_size==1)):
timing = [batch_size]
perf_timer.start_timings("batch_size", batch_size)
logger.info(f"Iteration: {iteration_number}")
start = time.perf_counter()
client.put_tensor(name="batch", data=resnet.get_batch(batch_size).numpy())
client.run_model(name=resnet.name, inputs=["batch"], outputs=["result"])
result = client.get_tensor(name="result")
end = time.perf_counter()
timing.append(end-start)
timings.append(timing)

input_name = f"batch_{rank}"
output_name = f"result_{rank}"
client.put_tensor(name=input_name, data=resnet.get_batch(batch_size).numpy())
client.run_model(name=resnet.name, inputs=[input_name], outputs=[output_name])
result = client.get_tensor(name=output_name)
perf_timer.end_timings()


timings_np = numpy.asarray(timings)
numpy.save("timings.npy", timings_np)
for timing in timings:
print(" ".join(str(t) for t in timing))
perf_timer.print_timings(True)
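With the Redis app now launched as four MPI ranks (see set_tasks(4) in the driver below), the rank suffix on tensor keys keeps clients from overwriting one another in the shared database. A minimal sketch of the per-rank key scheme:

from mpi4py import MPI

rank = MPI.COMM_WORLD.Get_rank()
# Without the suffix, every rank would read and write the same
# "batch"/"result" keys and the results would interleave.
input_name = f"batch_{rank}"
output_name = f"result_{rank}"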
15 changes: 8 additions & 7 deletions ex/high_throughput_inference/redis_driver.py
@@ -29,23 +29,24 @@
from smartsim import Experiment
from smartsim.status import TERMINAL_STATUSES
import time
import typing as t

device = "gpu"
DEVICE = "gpu"
filedir = os.path.dirname(__file__)
app_script_name = os.path.join(filedir, "mock_app_redis.py")
model_name = os.path.join(filedir, f"resnet50.{device.upper()}.pt")
model_name = os.path.join(filedir, f"resnet50.{DEVICE}.pt")


exp_path = os.path.join(filedir, "redis_ai")
exp_path = os.path.join(filedir, "redis_ai_multi")
os.makedirs(exp_path, exist_ok=True)
exp = Experiment("redis_ai", launcher="slurm", exp_path=exp_path)
exp = Experiment("redis_ai_multi", launcher="slurm", exp_path=exp_path)

db = exp.create_database(interface="hsn0")

app_rs = exp.create_run_settings(sys.executable, exe_args = [app_script_name, "--device", device])
app_rs = exp.create_run_settings(
sys.executable, exe_args = [app_script_name, "--device", DEVICE]
)
app_rs.set_nodes(1)
app_rs.set_tasks(1)
app_rs.set_tasks(4)
app = exp.create_model("app", run_settings=app_rs)
app.attach_generator_files(to_copy=[app_script_name], to_symlink=[model_name])
