 import os
 import time
 import torch
-import numbers
 
-from collections import OrderedDict
+from mpi4py import MPI
 from smartsim._core.mli.infrastructure.storage.dragonfeaturestore import (
     DragonFeatureStore,
 )
 from smartsim._core.mli.message_handler import MessageHandler
 from smartsim.log import get_logger
+from smartsim._core.utils.timings import PerfTimer
+
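+# Thread tuning: allow up to 16 inter-op threads but a single intra-op thread per process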
+torch.set_num_interop_threads(16)
+torch.set_num_threads(1)
 
 logger = get_logger("App")
+logger.info("Started app")
 
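+# When True, also run the model locally and compare against the remote result (slows everything down)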
+CHECK_RESULTS_AND_MAKE_ALL_SLOWER = False
 
 class ProtoClient:
     def __init__(self, timing_on: bool):
+        comm = MPI.COMM_WORLD
+        rank = comm.Get_rank()
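+        # The MPI rank is used below to give each client's timing output a unique prefix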
         connect_to_infrastructure()
         ddict_str = os.environ["_SMARTSIM_INFRA_BACKBONE"]
         self._ddict = DDict.attach(ddict_str)
@@ -70,61 +77,15 @@ def __init__(self, timing_on: bool):
         self._from_worker_ch_serialized = self._from_worker_ch.serialize()
         self._to_worker_ch = Channel.make_process_local()
 
-        self._start = None
-        self._interm = None
-        self._timings: OrderedDict[str, list[numbers.Number]] = OrderedDict()
-        self._timing_on = timing_on
-
-    def _add_label_to_timings(self, label: str):
-        if label not in self._timings:
-            self._timings[label] = []
-
-    @staticmethod
-    def _format_number(number: numbers.Number):
-        return f"{number:0.4e}"
-
-    def start_timings(self, batch_size: int):
-        if self._timing_on:
-            self._add_label_to_timings("batch_size")
-            self._timings["batch_size"].append(batch_size)
-            self._start = time.perf_counter()
-            self._interm = time.perf_counter()
-
-    def end_timings(self):
-        if self._timing_on:
-            self._add_label_to_timings("total_time")
-            self._timings["total_time"].append(
-                self._format_number(time.perf_counter() - self._start)
-            )
-
-    def measure_time(self, label: str):
-        if self._timing_on:
-            self._add_label_to_timings(label)
-            self._timings[label].append(
-                self._format_number(time.perf_counter() - self._interm)
-            )
-            self._interm = time.perf_counter()
-
-    def print_timings(self, to_file: bool = False):
-        print(" ".join(self._timings.keys()))
-        value_array = numpy.array(
-            [value for value in self._timings.values()], dtype=float
-        )
-        value_array = numpy.transpose(value_array)
-        for i in range(value_array.shape[0]):
-            print(" ".join(self._format_number(value) for value in value_array[i]))
-        if to_file:
-            numpy.save("timings.npy", value_array)
-            numpy.savetxt("timings.txt", value_array)
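+        # Shared timing utility; the per-rank prefix keeps each client's timing output separate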
+        self.perf_timer: PerfTimer = PerfTimer(debug=False, timing_on=timing_on, prefix=f"a{rank}_")
 
     def run_model(self, model: bytes | str, batch: torch.Tensor):
         tensors = [batch.numpy()]
-        self.start_timings(batch.shape[0])
+        self.perf_timer.start_timings("batch_size", batch.shape[0])
         built_tensor_desc = MessageHandler.build_tensor_descriptor(
             "c", "float32", list(batch.shape)
         )
-        self.measure_time("build_tensor_descriptor")
-        built_model = None
+        self.perf_timer.measure_time("build_tensor_descriptor")
         if isinstance(model, str):
             model_arg = MessageHandler.build_model_key(model, self._backbone_descriptor)
         else:
@@ -137,39 +98,39 @@ def run_model(self, model: bytes | str, batch: torch.Tensor):
             output_descriptors=[],
             custom_attributes=None,
         )
-        self.measure_time("build_request")
+        self.perf_timer.measure_time("build_request")
         request_bytes = MessageHandler.serialize_request(request)
-        self.measure_time("serialize_request")
-        with self._to_worker_fli.sendh(
-            timeout=None, stream_channel=self._to_worker_ch
-        ) as to_sendh:
+        self.perf_timer.measure_time("serialize_request")
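+        # Protocol: the serialized request goes first, then each tensor as a raw byte blob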
+        with self._to_worker_fli.sendh(timeout=None, stream_channel=self._to_worker_ch) as to_sendh:
             to_sendh.send_bytes(request_bytes)
-            for t in tensors:
-                to_sendh.send_bytes(t.tobytes())  # TODO NOT FAST ENOUGH!!!
-                # to_sendh.send_bytes(bytes(t.data))
-        logger.info(f"Message size: {len(request_bytes)} bytes")
-
-        self.measure_time("send")
+            self.perf_timer.measure_time("send_request")
+            for tensor in tensors:
+                to_sendh.send_bytes(tensor.tobytes())  # TODO NOT FAST ENOUGH!!!
+            self.perf_timer.measure_time("send_tensors")
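+        # The response header is received first; the result tensor payload follows on the same channel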
         with self._from_worker_ch.recvh(timeout=None) as from_recvh:
             resp = from_recvh.recv_bytes(timeout=None)
-            self.measure_time("receive")
+            self.perf_timer.measure_time("receive_response")
             response = MessageHandler.deserialize_response(resp)
-            self.measure_time("deserialize_response")
+            self.perf_timer.measure_time("deserialize_response")
             # list of data blobs? recv depending on the len(response.result.descriptors)?
-            data_blob = from_recvh.recv_bytes(timeout=None)
-            result = numpy.frombuffer(
-                data_blob,
-                dtype=str(response.result.descriptors[0].dataType),
+            data_blob: bytes = from_recvh.recv_bytes(timeout=None)
+            self.perf_timer.measure_time("receive_tensor")
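+            # Rebuild the result tensor from the raw buffer using the dtype from the response descriptor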
+            result = torch.from_numpy(
+                numpy.frombuffer(
+                    data_blob,
+                    dtype=str(response.result.descriptors[0].dataType),
+                )
             )
-            self.measure_time("deserialize_tensor")
+            self.perf_timer.measure_time("deserialize_tensor")
 
-        self.end_timings()
+        self.perf_timer.end_timings()
         return result
 
     def set_model(self, key: str, model: bytes):
         self._ddict[key] = model
 
 
+
 class ResNetWrapper:
     def __init__(self, name: str, model: str):
         self._model = torch.jit.load(model)
@@ -190,24 +151,39 @@ def model(self):
     def name(self):
         return self._name
 
-
 if __name__ == "__main__":
 
     parser = argparse.ArgumentParser("Mock application")
-    parser.add_argument("--device", default="cpu")
+    parser.add_argument("--device", default="cpu", type=str)
+    parser.add_argument("--log_max_batchsize", default=8, type=int)
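+    # Batch sizes below sweep powers of two from 1 up to 2**log_max_batchsize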
     args = parser.parse_args()
 
-    resnet = ResNetWrapper("resnet50", f"resnet50.{args.device.upper()}.pt")
+    resnet = ResNetWrapper("resnet50", f"resnet50.{args.device}.pt")
 
     client = ProtoClient(timing_on=True)
     client.set_model(resnet.name, resnet.model)
 
-    total_iterations = 100
+    if CHECK_RESULTS_AND_MAKE_ALL_SLOWER:
+        # TODO: adapt to non-Nvidia devices
+        torch_device = args.device.replace("gpu", "cuda")
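+        # Load the same TorchScript model locally for verification against remote results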
+        pt_model = torch.jit.load(io.BytesIO(initial_bytes=resnet.model)).to(torch_device)
 
-    for batch_size in [1, 2, 4, 8, 16, 32, 64, 128]:
-        logger.info(f"Batch size: {batch_size}")
-        for iteration_number in range(total_iterations + int(batch_size == 1)):
-            logger.info(f"Iteration: {iteration_number}")
-            client.run_model(resnet.name, resnet.get_batch(batch_size))
+    TOTAL_ITERATIONS = 100
 
-    client.print_timings(to_file=True)
+    for log2_bsize in range(args.log_max_batchsize + 1):
+        b_size: int = 2**log2_bsize
+        logger.info(f"Batch size: {b_size}")
+        for iteration_number in range(TOTAL_ITERATIONS + int(b_size == 1)):
+            logger.info(f"Iteration: {iteration_number}")
+            sample_batch = resnet.get_batch(b_size)
+            remote_result = client.run_model(resnet.name, sample_batch)
+            logger.info(client.perf_timer.get_last("total_time"))
+            if CHECK_RESULTS_AND_MAKE_ALL_SLOWER:
+                local_res = pt_model(sample_batch.to(torch_device))
+                err_norm = torch.linalg.vector_norm(
+                    torch.flatten(remote_result).to(torch_device) - torch.flatten(local_res), ord=1
+                ).cpu()
+                res_norm = torch.linalg.vector_norm(remote_result, ord=1).item()
+                local_res_norm = torch.linalg.vector_norm(local_res, ord=1).item()
+                logger.info(
+                    f"Avg error norm {err_norm.item()/b_size} vs. average result norms "
+                    f"{res_norm/b_size} (remote) and {local_res_norm/b_size} (local)"
+                )
+                torch.cuda.synchronize()
+
+    client.perf_timer.print_timings(to_file=True)