2323
2424import dateutil .tz
2525
26+ from sagemaker .session import Session
27+
2628METRICS_DIR = os .environ .get ("SAGEMAKER_METRICS_DIRECTORY" , "." )
2729METRIC_TS_LOWER_BOUND_TO_NOW = 1209600 # on seconds
2830METRIC_TS_UPPER_BOUND_FROM_NOW = 7200 # on seconds
2931
32+ BATCH_SIZE = 10
33+
3034logging .basicConfig (level = logging .INFO )
3135logger = logging .getLogger (__name__ )
3236
@@ -171,7 +175,7 @@ def to_raw_metric_data(self):
171175 "Timestamp" : int (self .Timestamp ),
172176 }
173177 if self .Step is not None :
174- raw_metric_data ["IterationNumber " ] = int (self .Step )
178+ raw_metric_data ["Step " ] = int (self .Step )
175179 return raw_metric_data
176180
177181 def __str__ (self ):
@@ -185,36 +189,27 @@ def __repr__(self):
185189 "," .join (["{}={}" .format (k , repr (v )) for k , v in vars (self ).items ()]),
186190 )
187191
188- def to_request_item (self ):
189- """Transform a RawMetricData item to a list item for BatchPutMetrics request."""
190- item = {
191- "MetricName" : self .MetricName ,
192- "Timestamp" : int (self .Timestamp ),
193- "Value" : self .Value ,
194- }
195-
196- if self .Step is not None :
197- item ["IterationNumber" ] = self .Step
198-
199- return item
200-
201192
202193class _MetricsManager (object ):
203194 """Collects metrics and sends them directly to SageMaker Metrics data plane APIs."""
204195
205- _BATCH_SIZE = 10
206-
207- def __init__ (self , resource_arn , sagemaker_session ) -> None :
208- """Initiate a `_MetricsManager` instance
196+ def __init__ (self , trial_component_name : str , sagemaker_session : Session , sink = None ) -> None :
197+ """Initialize a `_MetricsManager` instance
209198
210199 Args:
211- resource_arn (str): The ARN of the resource to log metrics to
200+ trial_component_name (str): The Name of the Trial Component to log metrics to
212201 sagemaker_session (sagemaker.session.Session): Session object which
213202 manages interactions with Amazon SageMaker APIs and any other
214203 AWS services needed. If not specified, one is created using the
215204 default AWS configuration chain.
205+ sink (object): The metrics sink to use.
216206 """
217- self .sink = _SyncMetricsSink (resource_arn , sagemaker_session .sagemaker_metrics_client )
207+ if sink is None :
208+ self .sink = _SyncMetricsSink (
209+ trial_component_name , sagemaker_session .sagemaker_metrics_client
210+ )
211+ else :
212+ self .sink = sink
218213
219214 def log_metric (self , metric_name , value , timestamp = None , step = None ):
220215 """Sends a metric to metrics service."""
@@ -238,16 +233,14 @@ def close(self):
238233class _SyncMetricsSink (object ):
239234 """Collects metrics and sends them directly to metrics service."""
240235
241- _BATCH_SIZE = 10
242-
243- def __init__ (self , resource_arn , metrics_client ) -> None :
244- """Initiate a `_MetricsManager` instance
236+ def __init__ (self , trial_component_name , metrics_client ) -> None :
237+ """Initialize a `_SyncMetricsSink` instance
245238
246239 Args:
247- resource_arn (str): The ARN of a Trial Component to log metrics.
240+ trial_component_name (str): The Name of the Trial Component to log metrics.
248241 metrics_client (boto3.client): boto client for metrics service
249242 """
250- self ._resource_arn = resource_arn
243+ self ._trial_component_name = trial_component_name
251244 self ._metrics_client = metrics_client
252245 self ._buffer = []
253246
@@ -265,7 +258,7 @@ def _drain(self, close=False):
265258 if not self ._buffer :
266259 return
267260
268- if len (self ._buffer ) < self . _BATCH_SIZE and not close :
261+ if len (self ._buffer ) < BATCH_SIZE and not close :
269262 return
270263
271264 # pop all the available metrics
@@ -276,7 +269,10 @@ def _drain(self, close=False):
276269 def _send_metrics (self , metrics ):
277270 """Calls BatchPutMetrics directly on the metrics service."""
278271 while metrics :
279- batch , metrics = metrics [: self ._BATCH_SIZE ], metrics [self ._BATCH_SIZE :]
272+ batch , metrics = (
273+ metrics [:BATCH_SIZE ],
274+ metrics [BATCH_SIZE :],
275+ )
280276 request = self ._construct_batch_put_metrics_request (batch )
281277 response = self ._metrics_client .batch_put_metrics (** request )
282278 errors = response ["Errors" ] if "Errors" in response else None
@@ -287,7 +283,7 @@ def _send_metrics(self, metrics):
287283 def _construct_batch_put_metrics_request (self , batch ):
288284 """Creates dictionary object used as request to metrics service."""
289285 return {
290- "ResourceArn " : self ._resource_arn ,
286+ "TrialComponentName " : self ._trial_component_name ,
291287 "MetricData" : list (map (lambda x : x .to_raw_metric_data (), batch )),
292288 }
293289
@@ -300,23 +296,21 @@ class _MetricQueue(object):
300296 """A thread safe queue for sending metrics to SageMaker.
301297
302298 Args:
303- resource_arn (str): the ARN of the resource
299+ trial_component_name (str): the ARN of the resource
304300 metric_name (str): the name of the metric
305301 metrics_client (boto_client): the boto client for SageMaker Metrics service
306302 """
307303
308- _BATCH_SIZE = 10
309-
310304 _CONSUMER_SLEEP_SECONDS = 5
311305
312- def __init__ (self , resource_arn , metric_name , metrics_client ):
306+ def __init__ (self , trial_component_name , metric_name , metrics_client ):
313307 # infinite queue size
314308 self ._queue = queue .Queue ()
315309 self ._buffer = []
316310 self ._thread = threading .Thread (target = self ._run )
317311 self ._started = False
318312 self ._finished = False
319- self ._resource_arn = resource_arn
313+ self ._trial_component_name = trial_component_name
320314 self ._metrics_client = metrics_client
321315 self ._metric_name = metric_name
322316 self ._logged_metrics = 0
@@ -325,7 +319,7 @@ def log_metric(self, metric_data):
325319 """Adds a metric data point to the queue"""
326320 self ._buffer .append (metric_data )
327321
328- if len (self ._buffer ) < self . _BATCH_SIZE :
322+ if len (self ._buffer ) < BATCH_SIZE :
329323 return
330324
331325 self ._enqueue_all ()
@@ -354,7 +348,7 @@ def _construct_batch_put_metrics_request(self, batch):
354348 """Creates dictionary object used as request to metrics service."""
355349
356350 return {
357- "ResourceArn " : self ._resource_arn ,
351+ "TrialComponentName " : self ._trial_component_name ,
358352 "MetricData" : list (map (lambda x : x .to_raw_metric_data (), batch )),
359353 }
360354
@@ -382,14 +376,14 @@ class _AsyncMetricsSink(object):
382376
383377 _COMPLETE_SLEEP_SECONDS = 1.0
384378
385- def __init__ (self , resource_arn , metrics_client ) -> None :
386- """Initiate a `_MetricsManager ` instance
379+ def __init__ (self , trial_component_name , metrics_client ) -> None :
380+ """Initialize a `_AsyncMetricsSink ` instance
387381
388382 Args:
389- resource_arn (str): The ARN of a Trial Component to log metrics.
383+ trial_component_name (str): The Name of the Trial Component to log metrics to .
390384 metrics_client (boto3.client): boto client for metrics service
391385 """
392- self ._resource_arn = resource_arn
386+ self ._trial_component_name = trial_component_name
393387 self ._metrics_client = metrics_client
394388 self ._buffer = []
395389 self ._is_draining = False
@@ -402,7 +396,7 @@ def log_metric(self, metric_data):
402396 self ._metric_queues [metric_data .MetricName ].log_metric (metric_data )
403397 else :
404398 cur_metric_queue = _MetricQueue (
405- self ._resource_arn , metric_data .MetricName , self ._metrics_client
399+ self ._trial_component_name , metric_data .MetricName , self ._metrics_client
406400 )
407401 self ._metric_queues [metric_data .MetricName ] = cur_metric_queue
408402 cur_metric_queue .log_metric (metric_data )
0 commit comments