3232
3333try :
3434 import wandb
35+ from wandb .sdk .lib import RunDisabled
3536 from wandb .wandb_run import Run
3637except ModuleNotFoundError :
3738 # needed for test mocks, these tests shall be updated
38- wandb , Run = None , None
39+ wandb , Run , RunDisabled = None , None , None # type: ignore
3940
4041
4142class WandbLogger (Logger ):
@@ -251,18 +252,18 @@ def __init__(
251252 self ,
252253 name : Optional [str ] = None ,
253254 save_dir : Optional [str ] = None ,
254- offline : Optional [ bool ] = False ,
255+ offline : bool = False ,
255256 id : Optional [str ] = None ,
256257 anonymous : Optional [bool ] = None ,
257258 version : Optional [str ] = None ,
258259 project : Optional [str ] = None ,
259260 log_model : Union [str , bool ] = False ,
260- experiment = None ,
261- prefix : Optional [ str ] = "" ,
261+ experiment : Union [ Run , RunDisabled , None ] = None ,
262+ prefix : str = "" ,
262263 agg_key_funcs : Optional [Mapping [str , Callable [[Sequence [float ]], float ]]] = None ,
263264 agg_default_func : Optional [Callable [[Sequence [float ]], float ]] = None ,
264- ** kwargs ,
265- ):
265+ ** kwargs : Any ,
266+ ) -> None :
266267 if wandb is None :
267268 raise ModuleNotFoundError (
268269 "You want to use `wandb` logger which is not installed yet,"
@@ -288,17 +289,16 @@ def __init__(
288289 self ._log_model = log_model
289290 self ._prefix = prefix
290291 self ._experiment = experiment
291- self ._logged_model_time = {}
292- self ._checkpoint_callback = None
292+ self ._logged_model_time : Dict [ str , float ] = {}
293+ self ._checkpoint_callback : Optional [ "ReferenceType[Checkpoint]" ] = None
293294 # set wandb init arguments
294- anonymous_lut = {True : "allow" , False : None }
295- self ._wandb_init = dict (
295+ self ._wandb_init : Dict [str , Any ] = dict (
296296 name = name or project ,
297297 project = project ,
298298 id = version or id ,
299299 dir = save_dir ,
300300 resume = "allow" ,
301- anonymous = anonymous_lut . get ( anonymous , anonymous ),
301+ anonymous = ( "allow" if anonymous else None ),
302302 )
303303 self ._wandb_init .update (** kwargs )
304304 # extract parameters
@@ -310,7 +310,7 @@ def __init__(
310310 wandb .require ("service" )
311311 _ = self .experiment
312312
313- def __getstate__ (self ):
313+ def __getstate__ (self ) -> Dict [ str , Any ] :
314314 state = self .__dict__ .copy ()
315315 # args needed to reload correct experiment
316316 if self ._experiment is not None :
@@ -322,7 +322,7 @@ def __getstate__(self):
322322 state ["_experiment" ] = None
323323 return state
324324
325- @property
325+ @property # type: ignore[misc]
326326 @rank_zero_experiment
327327 def experiment (self ) -> Run :
328328 r"""
@@ -357,13 +357,14 @@ def experiment(self) -> Run:
357357 self ._experiment = wandb .init (** self ._wandb_init )
358358
359359 # define default x-axis
360- if getattr (self ._experiment , "define_metric" , None ):
360+ if isinstance ( self . _experiment , Run ) and getattr (self ._experiment , "define_metric" , None ):
361361 self ._experiment .define_metric ("trainer/global_step" )
362362 self ._experiment .define_metric ("*" , step_metric = "trainer/global_step" , step_sync = True )
363363
364+ assert isinstance (self ._experiment , Run )
364365 return self ._experiment
365366
366- def watch (self , model : nn .Module , log : str = "gradients" , log_freq : int = 100 , log_graph : bool = True ):
367+ def watch (self , model : nn .Module , log : str = "gradients" , log_freq : int = 100 , log_graph : bool = True ) -> None :
367368 self .experiment .watch (model , log = log , log_freq = log_freq , log_graph = log_graph )
368369
369370 @rank_zero_only
@@ -379,7 +380,7 @@ def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None) ->
379380
380381 metrics = _add_prefix (metrics , self ._prefix , self .LOGGER_JOIN_CHAR )
381382 if step is not None :
382- self .experiment .log ({ ** metrics , "trainer/global_step" : step })
383+ self .experiment .log (dict ( metrics , ** { "trainer/global_step" : step }) )
383384 else :
384385 self .experiment .log (metrics )
385386
@@ -417,7 +418,7 @@ def log_text(
417418 self .log_table (key , columns , data , dataframe , step )
418419
419420 @rank_zero_only
420- def log_image (self , key : str , images : List [Any ], step : Optional [int ] = None , ** kwargs : str ) -> None :
421+ def log_image (self , key : str , images : List [Any ], step : Optional [int ] = None , ** kwargs : Any ) -> None :
421422 """Log images (tensors, numpy arrays, PIL Images or file paths).
422423
423424 Optional kwargs are lists passed to each image (ex: caption, masks, boxes).
0 commit comments