@@ -37,102 +37,35 @@
 
 import argparse
 import io
-import numpy
-import os
-import time
+
 import torch
 
-from mpi4py import MPI
-from smartsim._core.mli.infrastructure.storage.dragon_feature_store import (
-    DragonFeatureStore,
-)
-from smartsim._core.mli.message_handler import MessageHandler
 from smartsim.log import get_logger
-from smartsim._core.utils.timings import PerfTimer
 
 torch.set_num_interop_threads(16)
 torch.set_num_threads(1)
 
 logger = get_logger("App")
 logger.info("Started app")
 
-CHECK_RESULTS_AND_MAKE_ALL_SLOWER = False
+from collections import OrderedDict
 
-class ProtoClient:
-    def __init__(self, timing_on: bool):
-        comm = MPI.COMM_WORLD
-        rank = comm.Get_rank()
-        connect_to_infrastructure()
-        ddict_str = os.environ["_SMARTSIM_INFRA_BACKBONE"]
-        self._ddict = DDict.attach(ddict_str)
-        self._backbone_descriptor = DragonFeatureStore(self._ddict).descriptor
-        to_worker_fli_str = None
-        while to_worker_fli_str is None:
-            try:
-                to_worker_fli_str = self._ddict["to_worker_fli"]
-                self._to_worker_fli = fli.FLInterface.attach(to_worker_fli_str)
-            except KeyError:
-                time.sleep(1)
-        self._from_worker_ch = Channel.make_process_local()
-        self._from_worker_ch_serialized = self._from_worker_ch.serialize()
-        self._to_worker_ch = Channel.make_process_local()
-
-        self.perf_timer: PerfTimer = PerfTimer(debug=False, timing_on=timing_on, prefix=f"a{rank}_")
-
-    def run_model(self, model: bytes | str, batch: torch.Tensor):
-        tensors = [batch.numpy()]
-        self.perf_timer.start_timings("batch_size", batch.shape[0])
-        built_tensor_desc = MessageHandler.build_tensor_descriptor(
-            "c", "float32", list(batch.shape)
-        )
-        self.perf_timer.measure_time("build_tensor_descriptor")
-        if isinstance(model, str):
-            model_arg = MessageHandler.build_model_key(model, self._backbone_descriptor)
-        else:
-            model_arg = MessageHandler.build_model(model, "resnet-50", "1.0")
-        request = MessageHandler.build_request(
-            reply_channel=self._from_worker_ch_serialized,
-            model=model_arg,
-            inputs=[built_tensor_desc],
-            outputs=[],
-            output_descriptors=[],
-            custom_attributes=None,
-        )
-        self.perf_timer.measure_time("build_request")
-        request_bytes = MessageHandler.serialize_request(request)
-        self.perf_timer.measure_time("serialize_request")
-        with self._to_worker_fli.sendh(timeout=None, stream_channel=self._to_worker_ch) as to_sendh:
-            to_sendh.send_bytes(request_bytes)
-            self.perf_timer.measure_time("send_request")
-            for tensor in tensors:
-                to_sendh.send_bytes(tensor.tobytes())  #TODO NOT FAST ENOUGH!!!
-                self.perf_timer.measure_time("send_tensors")
-        with self._from_worker_ch.recvh(timeout=None) as from_recvh:
-            resp = from_recvh.recv_bytes(timeout=None)
-            self.perf_timer.measure_time("receive_response")
-            response = MessageHandler.deserialize_response(resp)
-            self.perf_timer.measure_time("deserialize_response")
-            # list of data blobs? recv depending on the len(response.result.descriptors)?
-            data_blob: bytes = from_recvh.recv_bytes(timeout=None)
-            self.perf_timer.measure_time("receive_tensor")
-            result = torch.from_numpy(
-                numpy.frombuffer(
-                    data_blob,
-                    dtype=str(response.result.descriptors[0].dataType),
-                )
-            )
-            self.perf_timer.measure_time("deserialize_tensor")
+from smartsim.log import get_logger, log_to_file
+from smartsim._core.mli.client.protoclient import ProtoClient
 
-        self.perf_timer.end_timings()
-        return result
+logger = get_logger("App")
 
-    def set_model(self, key: str, model: bytes):
-        self._ddict[key] = model
 
+CHECK_RESULTS_AND_MAKE_ALL_SLOWER = False
 
 
 class ResNetWrapper:
63+ """Wrapper around a pre-rained ResNet model."""
     def __init__(self, name: str, model: str):
+        """Initialize the instance.
+
+        :param name: The name to use for the model
+        :param model: The path to the pre-trained PyTorch model"""
         self._model = torch.jit.load(model)
         self._name = name
         buffer = io.BytesIO()
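
[Note, not part of the patch] The first hunk deletes the hand-rolled ProtoClient in favor of the library implementation imported from smartsim._core.mli.client.protoclient. A minimal usage sketch, assuming the relocated class keeps the deleted inline class's interface (the timing_on flag, run_model, set_model, and the perf_timer attribute); the key "resnet-50" and model_bytes are illustrative:

    import torch
    from smartsim._core.mli.client.protoclient import ProtoClient

    client = ProtoClient(timing_on=True)        # attaches to the Dragon backbone, as the deleted __init__ did
    client.set_model("resnet-50", model_bytes)  # model_bytes: a serialized TorchScript model
    result = client.run_model("resnet-50", torch.randn((32, 3, 224, 224)))
    client.perf_timer.print_timings(to_file=True)
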
@@ -141,16 +74,28 @@ def __init__(self, name: str, model: str):
         self._serialized_model = buffer.getvalue()
 
     def get_batch(self, batch_size: int = 32):
+        """Create a random batch of data with the correct dimensions to
+        invoke a ResNet model.
+
+        :param batch_size: The desired number of samples to produce
+        :returns: A PyTorch tensor"""
         return torch.randn((batch_size, 3, 224, 224), dtype=torch.float32)
 
     @property
-    def model(self):
+    def model(self) -> bytes:
+        """The content of a model file.
+
+        :returns: The model bytes"""
         return self._serialized_model
 
     @property
-    def name(self):
+    def name(self) -> str:
+        """The name applied to the model.
+
+        :returns: The name"""
         return self._name
 
+
 if __name__ == "__main__":
 
     parser = argparse.ArgumentParser("Mock application")
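
[Note, not part of the patch] ResNetWrapper holds the model twice: as a loaded TorchScript module and as serialized bytes to ship to the worker. A self-contained sketch of the save-to-buffer/load-from-bytes round trip it relies on; the small Linear module is only a stand-in for the pre-traced ResNet file:

    import io
    import torch

    scripted = torch.jit.script(torch.nn.Linear(4, 2))  # stand-in TorchScript module

    buffer = io.BytesIO()
    torch.jit.save(scripted, buffer)  # what ResNetWrapper.__init__ does
    model_bytes = buffer.getvalue()

    # what the __main__ block does with resnet.model
    reloaded = torch.jit.load(io.BytesIO(initial_bytes=model_bytes))
    print(reloaded(torch.zeros(1, 4)).shape)  # torch.Size([1, 2])
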
@@ -166,24 +111,32 @@ def name(self):
     if CHECK_RESULTS_AND_MAKE_ALL_SLOWER:
         # TODO: adapt to non-Nvidia devices
         torch_device = args.device.replace("gpu", "cuda")
-        pt_model = torch.jit.load(io.BytesIO(initial_bytes=(resnet.model))).to(torch_device)
+        pt_model = torch.jit.load(io.BytesIO(initial_bytes=(resnet.model))).to(
+            torch_device
+        )
 
     TOTAL_ITERATIONS = 100
 
-    for log2_bsize in range(args.log_max_batchsize+1):
+    for log2_bsize in range(args.log_max_batchsize + 1):
         b_size: int = 2**log2_bsize
         logger.info(f"Batch size: {b_size}")
-        for iteration_number in range(TOTAL_ITERATIONS + int(b_size==1)):
+        for iteration_number in range(TOTAL_ITERATIONS + int(b_size == 1)):
             logger.info(f"Iteration: {iteration_number}")
             sample_batch = resnet.get_batch(b_size)
             remote_result = client.run_model(resnet.name, sample_batch)
             logger.info(client.perf_timer.get_last("total_time"))
             if CHECK_RESULTS_AND_MAKE_ALL_SLOWER:
                 local_res = pt_model(sample_batch.to(torch_device))
-                err_norm = torch.linalg.vector_norm(torch.flatten(remote_result).to(torch_device)-torch.flatten(local_res), ord=1).cpu()
+                err_norm = torch.linalg.vector_norm(
+                    torch.flatten(remote_result).to(torch_device)
+                    - torch.flatten(local_res),
+                    ord=1,
+                ).cpu()
                 res_norm = torch.linalg.vector_norm(remote_result, ord=1).item()
                 local_res_norm = torch.linalg.vector_norm(local_res, ord=1).item()
-                logger.info(f"Avg norm of error {err_norm.item()/b_size} compared to result norm of {res_norm/b_size}:{local_res_norm/b_size}")
+                logger.info(
+                    f"Avg norm of error {err_norm.item()/b_size} compared to result norm of {res_norm/b_size}:{local_res_norm/b_size}"
+                )
                 torch.cuda.synchronize()
 
-    client.perf_timer.print_timings(to_file=True)
+    client.perf_timer.print_timings(to_file=True)
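
[Note, not part of the patch] The deleted run_model shipped each tensor as raw bytes and rebuilt the reply with numpy.frombuffer; only the dtype travels in the response descriptor, so the shape has to be recovered separately (the code above compared flattened tensors for that reason). A minimal sketch of that round trip:

    import numpy
    import torch

    batch = torch.randn((4, 3, 224, 224), dtype=torch.float32)
    blob = batch.numpy().tobytes()                  # what to_sendh.send_bytes() shipped
    flat = numpy.frombuffer(blob, dtype="float32")  # what recv_bytes() handed back
    # frombuffer returns a read-only view; copy before wrapping in a tensor
    restored = torch.from_numpy(flat.copy()).reshape(batch.shape)
    assert torch.equal(batch, restored)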