@@ -96,7 +96,7 @@ def __init__(
9696 sub_dir : Optional [str ] = None ,
9797 agg_key_funcs : Optional [Mapping [str , Callable [[Sequence [float ]], float ]]] = None ,
9898 agg_default_func : Optional [Callable [[Sequence [float ]], float ]] = None ,
99- ** kwargs ,
99+ ** kwargs : Any ,
100100 ):
101101 super ().__init__ (agg_key_funcs = agg_key_funcs , agg_default_func = agg_default_func )
102102 self ._save_dir = save_dir
@@ -108,8 +108,8 @@ def __init__(
108108 self ._prefix = prefix
109109 self ._fs = get_filesystem (save_dir )
110110
111- self ._experiment = None
112- self .hparams = {}
111+ self ._experiment : Optional [ "SummaryWriter" ] = None
112+ self .hparams : Union [ Dict [ str , Any ], Namespace ] = {}
113113 self ._kwargs = kwargs
114114
115115 @property
@@ -138,7 +138,7 @@ def log_dir(self) -> str:
138138 return log_dir
139139
140140 @property
141- def save_dir (self ) -> Optional [ str ] :
141+ def save_dir (self ) -> str :
142142 """Gets the save directory where the TensorBoard experiments are saved.
143143
144144 Returns:
@@ -155,7 +155,7 @@ def sub_dir(self) -> Optional[str]:
155155 """
156156 return self ._sub_dir
157157
158- @property
158+ @property # type: ignore[misc]
159159 @rank_zero_experiment
160160 def experiment (self ) -> SummaryWriter :
161161 r"""
@@ -236,7 +236,7 @@ def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None) ->
236236 raise ValueError (m ) from ex
237237
238238 @rank_zero_only
239- def log_graph (self , model : "pl.LightningModule" , input_array = None ):
239+ def log_graph (self , model : "pl.LightningModule" , input_array : Optional [ Tensor ] = None ) -> None :
240240 if self ._log_graph :
241241 if input_array is None :
242242 input_array = model .example_input_array
@@ -281,7 +281,7 @@ def name(self) -> str:
281281 return self ._name
282282
283283 @property
284- def version (self ) -> int :
284+ def version (self ) -> Union [ int , str ] :
285285 """Get the experiment version.
286286
287287 Returns:
@@ -291,7 +291,7 @@ def version(self) -> int:
291291 self ._version = self ._get_next_version ()
292292 return self ._version
293293
294- def _get_next_version (self ):
294+ def _get_next_version (self ) -> int :
295295 root_dir = self .root_dir
296296
297297 try :
@@ -318,7 +318,7 @@ def _sanitize_params(params: Dict[str, Any]) -> Dict[str, Any]:
318318 # logging of arrays with dimension > 1 is not supported, sanitize as string
319319 return {k : str (v ) if isinstance (v , (Tensor , np .ndarray )) and v .ndim > 1 else v for k , v in params .items ()}
320320
321- def __getstate__ (self ):
321+ def __getstate__ (self ) -> Dict [ str , Any ] :
322322 state = self .__dict__ .copy ()
323323 state ["_experiment" ] = None
324324 return state
0 commit comments