 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.utilities.warnings import WarningCache
 
+warning_cache = WarningCache()
+
 _WANDB_AVAILABLE = _module_available("wandb")
 
 try:
@@ -56,7 +58,6 @@ class WandbLogger(LightningLoggerBase):
         project: The name of the project to which this run will belong.
         log_model: Save checkpoints in wandb dir to upload on W&B servers.
         prefix: A string to put at the beginning of metric keys.
-        sync_step: Sync Trainer step with wandb step.
         experiment: WandB experiment object. Automatically set when creating a run.
         \**kwargs: Additional arguments like `entity`, `group`, `tags`, etc. used by
             :func:`wandb.init` can be passed as keyword arguments in this logger.
@@ -92,7 +93,7 @@ def __init__(
         log_model: Optional[bool] = False,
         experiment=None,
         prefix: Optional[str] = '',
-        sync_step: Optional[bool] = True,
+        sync_step: Optional[bool] = None,
         **kwargs
     ):
         if wandb is None:
@@ -108,6 +109,12 @@ def __init__(
                 'Hint: Set `offline=False` to log your model.'
             )
 
+        if sync_step is not None:
+            warning_cache.warn(
+                "`WandbLogger(sync_step=(True|False))` is deprecated in v1.2.1 and will be removed in v1.5."
+                " Metrics are now logged separately and automatically synchronized.", DeprecationWarning
+            )
+
         super().__init__()
         self._name = name
         self._save_dir = save_dir
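With this hunk applied, passing `sync_step` still works but emits the deprecation above at construction time. A minimal sketch of the caller-facing behavior (the project name is a placeholder; assumes wandb is installed):

```python
from pytorch_lightning.loggers import WandbLogger

logger = WandbLogger(project="demo")  # new-style usage: no sync_step at all
# passing either value now triggers the DeprecationWarning added above
legacy = WandbLogger(project="demo", sync_step=True)
```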
@@ -117,12 +124,8 @@ def __init__(
         self._project = project
         self._log_model = log_model
         self._prefix = prefix
-        self._sync_step = sync_step
         self._experiment = experiment
         self._kwargs = kwargs
-        # logging multiple Trainer on a single W&B run (k-fold, resuming, etc)
-        self._step_offset = 0
-        self.warning_cache = WarningCache()
 
     def __getstate__(self):
         state = self.__dict__.copy()
@@ -159,12 +162,15 @@ def experiment(self) -> Run:
                 **self._kwargs
             ) if wandb.run is None else wandb.run
 
-            # offset logging step when resuming a run
-            self._step_offset = self._experiment.step
-
             # save checkpoints in wandb dir to upload on W&B servers
             if self._save_dir is None:
                 self._save_dir = self._experiment.dir
+
+            # define default x-axis (for latest wandb versions)
+            if getattr(self._experiment, "define_metric", None):
+                self._experiment.define_metric("trainer/global_step")
+                self._experiment.define_metric("*", step_metric='trainer/global_step', step_sync=True)
+
         return self._experiment
 
     def watch(self, model: nn.Module, log: str = 'gradients', log_freq: int = 100):
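The `define_metric` calls above make `trainer/global_step` the default x-axis for every logged chart. A standalone sketch of the same pattern against the raw wandb API (assuming a wandb version recent enough to expose `Run.define_metric`; the project name is a placeholder):

```python
import wandb

run = wandb.init(project="demo")
# guard mirrors the diff: older wandb clients have no define_metric
if getattr(run, "define_metric", None):
    run.define_metric("trainer/global_step")
    run.define_metric("*", step_metric="trainer/global_step", step_sync=True)
```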
@@ -182,15 +188,8 @@ def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None) ->
         assert rank_zero_only.rank == 0, 'experiment tried to log from global_rank != 0'
 
         metrics = self._add_prefix(metrics)
-        if self._sync_step and step is not None and step + self._step_offset < self.experiment.step:
-            self.warning_cache.warn(
-                'Trying to log at a previous step. Use `WandbLogger(sync_step=False)`'
-                ' or try logging with `commit=False` when calling manually `wandb.log`.'
-            )
-        if self._sync_step:
-            self.experiment.log(metrics, step=(step + self._step_offset) if step is not None else None)
-        elif step is not None:
-            self.experiment.log({**metrics, 'trainer_step': (step + self._step_offset)})
+        if step is not None:
+            self.experiment.log({**metrics, 'trainer/global_step': step})
         else:
             self.experiment.log(metrics)
 
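Because the Trainer step is now logged as an ordinary metric rather than as wandb's own monotonic step, logging at an earlier Trainer step (k-fold, resuming, manual `wandb.log` calls) no longer trips the "previous step" warning that the removed branch used to emit. Continuing the sketch above:

```python
# wandb's internal step keeps incrementing; charts are plotted against
# trainer/global_step, so out-of-order Trainer steps are accepted
run.log({"train/loss": 0.42, "trainer/global_step": 100})
run.log({"val/acc": 0.91, "trainer/global_step": 50})  # earlier Trainer step: no warning
run.finish()
```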
@@ -210,10 +209,6 @@ def version(self) -> Optional[str]:
 
     @rank_zero_only
     def finalize(self, status: str) -> None:
-        # offset future training logged on same W&B run
-        if self._experiment is not None:
-            self._step_offset = self._experiment.step
-
         # upload all checkpoints from saving dir
         if self._log_model:
             wandb.save(os.path.join(self.save_dir, "*.ckpt"))
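For reference, `wandb.save` with a glob pattern registers every matching file for upload to the W&B servers, which is how `finalize` ships the checkpoints. A minimal standalone sketch (project name and directory are placeholders):

```python
import os
import wandb

run = wandb.init(project="demo")
# registers every .ckpt file under the run's files directory for upload
wandb.save(os.path.join(run.dir, "*.ckpt"))
run.finish()
```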