 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.utilities.warnings import WarningCache

+warning_cache = WarningCache()
+
 _WANDB_AVAILABLE = _module_available("wandb")

 try:
@@ -56,7 +58,6 @@ class WandbLogger(LightningLoggerBase):
         project: The name of the project to which this run will belong.
         log_model: Save checkpoints in wandb dir to upload on W&B servers.
         prefix: A string to put at the beginning of metric keys.
-        sync_step: Sync Trainer step with wandb step.
         experiment: WandB experiment object. Automatically set when creating a run.
         \**kwargs: Additional arguments like `entity`, `group`, `tags`, etc. used by
             :func:`wandb.init` can be passed as keyword arguments in this logger.
@@ -98,7 +99,7 @@ def __init__(
         log_model: Optional[bool] = False,
         experiment=None,
         prefix: Optional[str] = '',
-        sync_step: Optional[bool] = True,
+        sync_step: Optional[bool] = None,
         **kwargs
     ):
         if wandb is None:
@@ -114,6 +115,12 @@ def __init__(
                 'Hint: Set `offline=False` to log your model.'
             )

+        if sync_step is not None:
+            warning_cache.warn(
+                "`WandbLogger(sync_step=(True|False))` is deprecated in v1.2.1 and will be removed in v1.5."
+                " Metrics are now logged separately and automatically synchronized.", DeprecationWarning
+            )
+
         super().__init__()
         self._name = name
         self._save_dir = save_dir
@@ -123,12 +130,8 @@ def __init__(
         self._project = project
         self._log_model = log_model
         self._prefix = prefix
-        self._sync_step = sync_step
         self._experiment = experiment
         self._kwargs = kwargs
-        # logging multiple Trainer on a single W&B run (k-fold, resuming, etc)
-        self._step_offset = 0
-        self.warning_cache = WarningCache()

     def __getstate__(self):
         state = self.__dict__.copy()
@@ -165,12 +168,15 @@ def experiment(self) -> Run:
                 **self._kwargs
             ) if wandb.run is None else wandb.run

-        # offset logging step when resuming a run
-        self._step_offset = self._experiment.step
-
         # save checkpoints in wandb dir to upload on W&B servers
         if self._save_dir is None:
             self._save_dir = self._experiment.dir
+
+        # define default x-axis (for latest wandb versions)
+        if getattr(self._experiment, "define_metric", None):
+            self._experiment.define_metric("trainer/global_step")
+            self._experiment.define_metric("*", step_metric='trainer/global_step', step_sync=True)
+
         return self._experiment

     def watch(self, model: nn.Module, log: str = 'gradients', log_freq: int = 100):
@@ -188,15 +194,8 @@ def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None) ->
         assert rank_zero_only.rank == 0, 'experiment tried to log from global_rank != 0'

         metrics = self._add_prefix(metrics)
-        if self._sync_step and step is not None and step + self._step_offset < self.experiment.step:
-            self.warning_cache.warn(
-                'Trying to log at a previous step. Use `WandbLogger(sync_step=False)`'
-                ' or try logging with `commit=False` when calling manually `wandb.log`.'
-            )
-        if self._sync_step:
-            self.experiment.log(metrics, step=(step + self._step_offset) if step is not None else None)
-        elif step is not None:
-            self.experiment.log({**metrics, 'trainer_step': (step + self._step_offset)})
+        if step is not None:
+            self.experiment.log({**metrics, 'trainer/global_step': step})
         else:
             self.experiment.log(metrics)
@@ -216,10 +215,6 @@ def version(self) -> Optional[str]:

     @rank_zero_only
     def finalize(self, status: str) -> None:
-        # offset future training logged on same W&B run
-        if self._experiment is not None:
-            self._step_offset = self._experiment.step
-
         # upload all checkpoints from saving dir
         if self._log_model:
             wandb.save(os.path.join(self.save_dir, "*.ckpt"))
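
For context, the `experiment` property and `log_metrics` changes above amount to the following standalone wandb usage. This is a minimal sketch against the public wandb API (`wandb.init`, `Run.define_metric`, `Run.log`); the project name and metric values are illustrative, and `define_metric` only exists in recent wandb releases, which is why the diff guards it with `getattr`:

import wandb

run = wandb.init(project="demo")  # illustrative project name

# Use the Trainer step as the default x-axis for every chart instead of
# wandb's internal step counter (only on wandb versions shipping define_metric).
if getattr(run, "define_metric", None):
    run.define_metric("trainer/global_step")
    run.define_metric("*", step_metric="trainer/global_step", step_sync=True)

# The Trainer step travels inside the payload as a regular metric, so wandb's
# own step only ever increments, and logging an older Trainer step can no
# longer conflict with it.
run.log({"train/loss": 0.42, "trainer/global_step": 10})
run.log({"val/loss": 0.40, "trainer/global_step": 10})  # same Trainer step, new wandb step
run.finish()

This is also why the `_step_offset` bookkeeping could be deleted: nothing passes `step=` to `wandb.log` anymore, so resumed runs and multiple Trainers sharing one W&B run need no offset arithmetic.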
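
On the caller side, the deprecation means `sync_step` can simply be dropped. A sketch of the migration, assuming the standard Lightning entry points:

from pytorch_lightning import Trainer
from pytorch_lightning.loggers import WandbLogger

# Before (now emits a DeprecationWarning; removal is planned for v1.5):
# logger = WandbLogger(project="demo", sync_step=False)

# After: no argument needed; each payload carries 'trainer/global_step',
# so Trainer and wandb steps stay aligned automatically.
logger = WandbLogger(project="demo")
trainer = Trainer(logger=logger)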