 
 import argparse
 import io
-import numpy
-import os
-import time
+
 import torch
 
-from mpi4py import MPI
-from smartsim._core.mli.infrastructure.storage.dragon_feature_store import (
-    DragonFeatureStore,
-)
-from smartsim._core.mli.message_handler import MessageHandler
 from smartsim.log import get_logger
-from smartsim._core.utils.timings import PerfTimer
 
 torch.set_num_interop_threads(16)
 torch.set_num_threads(1)
 
 logger = get_logger("App")
 logger.info("Started app")
 
-CHECK_RESULTS_AND_MAKE_ALL_SLOWER = False
+from collections import OrderedDict
 
-class ProtoClient:
-    def __init__(self, timing_on: bool):
-        comm = MPI.COMM_WORLD
-        rank = comm.Get_rank()
-        connect_to_infrastructure()
-        ddict_str = os.environ["_SMARTSIM_INFRA_BACKBONE"]
-        self._ddict = DDict.attach(ddict_str)
-        self._backbone_descriptor = DragonFeatureStore(self._ddict).descriptor
-        to_worker_fli_str = None
-        while to_worker_fli_str is None:
-            try:
-                to_worker_fli_str = self._ddict["to_worker_fli"]
-                self._to_worker_fli = fli.FLInterface.attach(to_worker_fli_str)
-            except KeyError:
-                time.sleep(1)
-        self._from_worker_ch = Channel.make_process_local()
-        self._from_worker_ch_serialized = self._from_worker_ch.serialize()
-        self._to_worker_ch = Channel.make_process_local()
-
-        self.perf_timer: PerfTimer = PerfTimer(debug=False, timing_on=timing_on, prefix=f"a{rank}_")
-
-    def run_model(self, model: bytes | str, batch: torch.Tensor):
-        tensors = [batch.numpy()]
-        self.perf_timer.start_timings("batch_size", batch.shape[0])
-        built_tensor_desc = MessageHandler.build_tensor_descriptor(
-            "c", "float32", list(batch.shape)
-        )
-        self.perf_timer.measure_time("build_tensor_descriptor")
-        if isinstance(model, str):
-            model_arg = MessageHandler.build_model_key(model, self._backbone_descriptor)
-        else:
-            model_arg = MessageHandler.build_model(model, "resnet-50", "1.0")
-        request = MessageHandler.build_request(
-            reply_channel=self._from_worker_ch_serialized,
-            model=model_arg,
-            inputs=[built_tensor_desc],
-            outputs=[],
-            output_descriptors=[],
-            custom_attributes=None,
-        )
-        self.perf_timer.measure_time("build_request")
-        request_bytes = MessageHandler.serialize_request(request)
-        self.perf_timer.measure_time("serialize_request")
-        with self._to_worker_fli.sendh(timeout=None, stream_channel=self._to_worker_ch) as to_sendh:
-            to_sendh.send_bytes(request_bytes)
-            self.perf_timer.measure_time("send_request")
-            for tensor in tensors:
-                to_sendh.send_bytes(tensor.tobytes())  # TODO NOT FAST ENOUGH!!!
-            self.perf_timer.measure_time("send_tensors")
-        with self._from_worker_ch.recvh(timeout=None) as from_recvh:
-            resp = from_recvh.recv_bytes(timeout=None)
-            self.perf_timer.measure_time("receive_response")
-            response = MessageHandler.deserialize_response(resp)
-            self.perf_timer.measure_time("deserialize_response")
-            # list of data blobs? recv depending on the len(response.result.descriptors)?
-            data_blob: bytes = from_recvh.recv_bytes(timeout=None)
-            self.perf_timer.measure_time("receive_tensor")
-            result = torch.from_numpy(
-                numpy.frombuffer(
-                    data_blob,
-                    dtype=str(response.result.descriptors[0].dataType),
-                )
-            )
-            self.perf_timer.measure_time("deserialize_tensor")
+from smartsim.log import get_logger, log_to_file
+from smartsim.protoclient import ProtoClient
 
-        self.perf_timer.end_timings()
-        return result
+logger = get_logger("App", "DEBUG")
 
-    def set_model(self, key: str, model: bytes):
-        self._ddict[key] = model
-
+CHECK_RESULTS_AND_MAKE_ALL_SLOWER = False
 
 
 class ResNetWrapper:
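
Note: everything the hunk above deletes (the MPI rank lookup, the Dragon DDict/FLI channel plumbing, and the hand-built MessageHandler request/response loop) now lives behind `smartsim.protoclient.ProtoClient`, so the app shrinks to the sketch below. This is a minimal usage outline inferred from this diff's call sites, not the client's full API; reading `wait_timeout=0` as "fail fast instead of polling for the backbone" is an assumption.

    # Minimal sketch of the slimmed-down client usage (wait_timeout semantics assumed).
    import torch
    from smartsim.protoclient import ProtoClient

    client = ProtoClient(timing_on=True, wait_timeout=0)
    batch = torch.randn(4, 3, 224, 224)           # hypothetical ImageNet-shaped input
    result = client.run_model("resnet50", batch)  # model resolved by key on the worker side
    print(client.perf_timer.get_last("total_time"))
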
@@ -151,6 +79,7 @@ def model(self):
     def name(self):
         return self._name
 
+
 if __name__ == "__main__":
 
     parser = argparse.ArgumentParser("Mock application")
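
The diff elides most of ResNetWrapper; from its call sites (`resnet.model`, `resnet.name`, `resnet.get_batch(b_size)`, and the `resnet50.{args.device}.pt` path) it is a thin holder for serialized TorchScript bytes plus a synthetic-batch generator. A speculative reconstruction, with the 3x224x224 input shape assumed:

    # Hypothetical sketch inferred from call sites in this diff; not the actual class body.
    import torch

    class ResNetWrapper:
        def __init__(self, name: str, model_path: str):
            self._name = name
            with open(model_path, "rb") as model_file:
                self._model = model_file.read()  # TorchScript bytes shipped to the worker

        def get_batch(self, batch_size: int) -> torch.Tensor:
            # Random ImageNet-shaped batch; the real generator may differ.
            return torch.randn(batch_size, 3, 224, 224, dtype=torch.float32)

        @property
        def model(self) -> bytes:
            return self._model

        @property
        def name(self) -> str:
            return self._name
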
@@ -160,30 +89,38 @@ def name(self):
 
     resnet = ResNetWrapper("resnet50", f"resnet50.{args.device}.pt")
 
-    client = ProtoClient(timing_on=True)
-    client.set_model(resnet.name, resnet.model)
+    client = ProtoClient(timing_on=True, wait_timeout=0)
+    # client.set_model(resnet.name, resnet.model)
 
     if CHECK_RESULTS_AND_MAKE_ALL_SLOWER:
         # TODO: adapt to non-Nvidia devices
         torch_device = args.device.replace("gpu", "cuda")
-        pt_model = torch.jit.load(io.BytesIO(initial_bytes=(resnet.model))).to(torch_device)
+        pt_model = torch.jit.load(io.BytesIO(initial_bytes=(resnet.model))).to(
+            torch_device
+        )
 
     TOTAL_ITERATIONS = 100
 
-    for log2_bsize in range(args.log_max_batchsize+1):
+    for log2_bsize in range(args.log_max_batchsize + 1):
         b_size: int = 2**log2_bsize
         logger.info(f"Batch size: {b_size}")
-        for iteration_number in range(TOTAL_ITERATIONS + int(b_size==1)):
+        for iteration_number in range(TOTAL_ITERATIONS + int(b_size == 1)):
             logger.info(f"Iteration: {iteration_number}")
             sample_batch = resnet.get_batch(b_size)
             remote_result = client.run_model(resnet.name, sample_batch)
             logger.info(client.perf_timer.get_last("total_time"))
             if CHECK_RESULTS_AND_MAKE_ALL_SLOWER:
                 local_res = pt_model(sample_batch.to(torch_device))
-                err_norm = torch.linalg.vector_norm(torch.flatten(remote_result).to(torch_device)-torch.flatten(local_res), ord=1).cpu()
+                err_norm = torch.linalg.vector_norm(
+                    torch.flatten(remote_result).to(torch_device)
+                    - torch.flatten(local_res),
+                    ord=1,
+                ).cpu()
                 res_norm = torch.linalg.vector_norm(remote_result, ord=1).item()
                 local_res_norm = torch.linalg.vector_norm(local_res, ord=1).item()
-                logger.info(f"Avg norm of error {err_norm.item()/b_size} compared to result norm of {res_norm/b_size}:{local_res_norm/b_size}")
+                logger.info(
+                    f"Avg norm of error {err_norm.item()/b_size} compared to result norm of {res_norm/b_size}:{local_res_norm/b_size}"
+                )
                 torch.cuda.synchronize()
 
-    client.perf_timer.print_timings(to_file=True)
+    client.perf_timer.print_timings(to_file=True)
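
The sweep above doubles the batch size each step and, at batch size 1 only, runs one extra iteration, presumably so the first call's warm-up cost (model upload, first inference) does not distort the per-batch timings. Under an assumed `--log_max_batchsize` of 7, the iteration budget works out as:

    # Iteration budget implied by the loop bounds (log_max_batchsize=7 is hypothetical):
    for log2_bsize in range(7 + 1):
        b_size = 2**log2_bsize            # 1, 2, 4, ..., 128
        iters = 100 + int(b_size == 1)    # 101 iterations at b_size=1 (warm-up padding), else 100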