From 69155ec263810e70554b0a48f053770d2d091c89 Mon Sep 17 00:00:00 2001
From: Patrick Leask
Date: Wed, 18 Sep 2024 16:00:49 +0100
Subject: [PATCH] add wandb artifact logging

---
 training.py | 57 +++++++++++++++++++++++++++++++++++++++++++----------
 1 file changed, 47 insertions(+), 10 deletions(-)

diff --git a/training.py b/training.py
index f100fee..79de6a0 100644
--- a/training.py
+++ b/training.py
@@ -1,7 +1,3 @@
-"""
-Training dictionaries
-"""
-
 import json
 import multiprocessing as mp
 import os
@@ -17,6 +13,20 @@ from .trainers.standard import StandardTrainer
 
 
 
+def save_checkpoint(wandb_run, model_path, config_path, name, step):
+    # Create and log artifact
+    artifact = wandb.Artifact(
+        name=name,
+        type="model",
+        description=f"Model checkpoint at step {step}",
+    )
+    artifact.add_file(model_path)
+    artifact.add_file(config_path)
+    wandb_run.log_artifact(artifact)
+
+    print(f"Model and config saved as artifact at step {step}")
+
+
 def new_wandb_process(config, log_queue, entity, project):
     wandb.init(entity=entity, project=project, config=config, name=config["wandb_name"])
     while True:
@@ -24,7 +34,12 @@ def new_wandb_process(config, log_queue, entity, project):
             log = log_queue.get(timeout=1)
             if log == "DONE":
                 break
-            wandb.log(log)
+            if isinstance(log, dict) and log.get("artifact", False):
+                # Handle artifact saving
+                artifact_data = log["artifact_data"]
+                save_checkpoint(wandb_run=wandb.run, **artifact_data)
+            else:
+                wandb.log(log)
         except Empty:
             continue
     wandb.finish()
@@ -88,7 +103,7 @@ def trainSAE(
     run_cfg={},
 ):
     """
-    Train SAEs using the given trainers
+    Train SAEs using the given trainers and save them as wandb artifacts
     """
     trainers = []
     for config in trainer_configs:
@@ -141,23 +156,45 @@ def trainSAE(
 
         # saving
         if save_steps is not None and step % save_steps == 0:
-            for dir, trainer in zip(save_dirs, trainers):
+            for dir, trainer, log_queue in zip(save_dirs, trainers, log_queues):
                 if dir is not None:
                     if not os.path.exists(os.path.join(dir, "checkpoints")):
                         os.mkdir(os.path.join(dir, "checkpoints"))
+                    save_path = os.path.join(dir, "checkpoints", f"ae_{step}.pt")
                     t.save(
                         trainer.ae.state_dict(),
-                        os.path.join(dir, "checkpoints", f"ae_{step}.pt"),
+                        save_path,
                     )
+                    config_path = os.path.join(dir, "config.json")
+                    # Send message to wandb process to save artifact
+                    if use_wandb:
+                        # Prepare artifact data
+                        artifact_data = {
+                            "model_path": save_path,
+                            "config_path": config_path,
+                            "name": f"{trainer.config.get('wandb_name', 'trainer')}_{step}",
+                            "step": step,
+                        }
+                        log_queue.put({"artifact": True, "artifact_data": artifact_data})
 
         # training
         for trainer in trainers:
             trainer.update(step, act)
 
     # save final SAEs
-    for save_dir, trainer in zip(save_dirs, trainers):
+    for save_dir, trainer, log_queue in zip(save_dirs, trainers, log_queues):
         if save_dir is not None:
-            t.save(trainer.ae.state_dict(), os.path.join(save_dir, "ae.pt"))
+            save_path = os.path.join(save_dir, "ae.pt")
+            t.save(trainer.ae.state_dict(), save_path)
+            config_path = os.path.join(save_dir, "config.json")
+            if use_wandb:
+                artifact_data = {
+                    "model_path": save_path,
+                    "config_path": config_path,
+                    "name": f"{trainer.config.get('wandb_name', 'trainer')}_final",
+                    "step": 'final',
+                }
+                log_queue.put({"artifact": True, "artifact_data": artifact_data})
 
     # Signal wandb processes to finish
     if use_wandb: