 import logging
 from os import path
 from abc import ABCMeta
-from typing import List, Mapping, Optional
+from typing import List, Mapping, Optional, Dict
 from glob import glob
 
 import numpy as np
 import tensorflow as tf
 import emloop as el
 
 from .third_party.tensorflow.freeze_graph import freeze_graph
 from .third_party.tensorflow.average_gradients import average_gradients
-from .utils import create_optimizer
+from .utils import create_optimizer, Profiler
 from .graph_tower import GraphTower
 
 DEFAULT_LOSS_NAME = 'loss'
@@ -44,8 +44,8 @@ def __init__(self, # pylint: disable=too-many-arguments
                  dataset: Optional[el.AbstractDataset], log_dir: Optional[str], inputs: List[str], outputs: List[str],
                  session_config: Optional[dict]=None, n_gpus: int=0, restore_from: Optional[str]=None,
                  optimizer=None, freeze=False, loss_name: str=DEFAULT_LOSS_NAME, monitor: Optional[str]=None,
-                 restore_fallback: Optional[str]=None, clip_gradient: Optional[float]=None,
-                 **kwargs):
+                 restore_fallback: Optional[str]=None, clip_gradient: Optional[float]=None, profile: bool=False,
+                 keep_profiles: int=5, **kwargs):
         """
         Create new emloop trainable TensorFlow model.
 
@@ -82,6 +82,8 @@ def __init__(self, # pylint: disable=too-many-arguments
         :param monitor: monitor signal mean and variance of the tensors which names contain the specified value
         :param restore_fallback: ignored arg. (allows training from configs saved by emloop where it is added)
         :param clip_gradient: limit the absolute value of the gradient; set to None for no clipping
+        :param profile: if true, profile the speed of model inference and save profiles to the specified log_dir
+        :param keep_profiles: number of the most recent profiles to keep in the log_dir
         :param kwargs: additional kwargs forwarded to :py:meth:`_create_model`
         """
         super().__init__(dataset=dataset, log_dir=log_dir, restore_from=restore_from)
@@ -97,10 +99,17 @@ def __init__(self, # pylint: disable=too-many-arguments
         self._towers = [GraphTower(i, inputs, outputs, loss_name) for i in range(n_gpus)]
         if n_gpus == 0:
             self._towers.append(GraphTower(-1, inputs, outputs, loss_name))
-
         logging.info('\tCreating TF model on %s GPU devices', n_gpus)
         self._graph = tf.Graph()
         self._session = self._create_session(session_config)
+
+        if profile and not log_dir:
+            raise ValueError('log_dir has to be specified with profile set to True')
+
+        self._profile = profile
+        if profile:
+            self._profiler = Profiler(log_dir, keep_profiles, self._session)
+
         dependencies = []
         with self._graph.as_default():
             if restore_from is None:
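
For reference, a hypothetical construction call using the new arguments might look as follows; the `BaseModel` class name and all argument values here are illustrative assumptions, not taken from this commit:

# Hypothetical usage of the new profiling arguments; names and values are illustrative only.
model = BaseModel(dataset=None,
                  log_dir='/tmp/model',              # required when profile=True
                  inputs=['images'],
                  outputs=['predictions', 'loss'],
                  profile=True,                      # route session.run through the Profiler
                  keep_profiles=10)                  # keep only the 10 most recent profiles
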
@@ -223,12 +232,14 @@ def run(self, batch: el.Batch, train: bool=False, stream: el.datasets.StreamWrap
             for output_name in self.output_names:
                 fetches.append(tower[output_name])
 
+        run_fn = self._profiler.run if self._profile else self._session.run
+
         # run the computational graph for one batch and allow buffering in the meanwhile
         if stream is not None:
             with stream.allow_buffering:
-                outputs = self._session.run(fetches=fetches, feed_dict=feed_dict)
+                outputs = run_fn(fetches, feed_dict)
         else:
-            outputs = self._session.run(fetches=fetches, feed_dict=feed_dict)
+            outputs = run_fn(fetches, feed_dict)
 
         if train:
             outputs = outputs[1:]
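
The commit only shows the Profiler call sites (`Profiler(log_dir, keep_profiles, self._session)` and `run_fn(fetches, feed_dict)`). A minimal sketch of a compatible wrapper, assuming it traces `session.run` with TensorFlow's chrome-trace timeline and cycles through at most `keep_profiles` files, could be:

# Sketch of a Profiler compatible with the call sites above; the actual
# .utils.Profiler may differ - this only illustrates the idea of tracing
# session.run and keeping the last `keep_profiles` chrome traces.
import os
import tensorflow as tf
from tensorflow.python.client import timeline


class Profiler:

    def __init__(self, log_dir: str, keep_profiles: int, session: tf.Session):
        self._log_dir = log_dir
        self._keep_profiles = keep_profiles
        self._session = session
        self._counter = 0

    def run(self, fetches, feed_dict):
        # run the fetches with full tracing enabled
        run_options = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE)
        run_metadata = tf.RunMetadata()
        outputs = self._session.run(fetches, feed_dict=feed_dict,
                                    options=run_options, run_metadata=run_metadata)
        # dump a trace viewable in chrome://tracing, cycling through at most
        # `keep_profiles` files so that older profiles get overwritten
        trace = timeline.Timeline(step_stats=run_metadata.step_stats)
        path = os.path.join(self._log_dir, 'profile_{}.json'.format(self._counter))
        with open(path, 'w') as ofile:
            ofile.write(trace.generate_chrome_trace_format())
        self._counter = (self._counter + 1) % self._keep_profiles
        return outputs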