
Callback metrics not being populated during multi-gpu training #7671

@jacanchaplais

Description

🐛 Bug

When pl.Trainer(gpus=N, ...) is used with N > 1, the callback_metrics dictionary is not populated. I've tried to integrate both Optuna and Ray Tune, and have failed with both as a result.

To confirm this was the issue, I printed the trainer.callback_metrics attribute:

(pid=144372) callback metrics are:
(pid=144372)  {}

It does, however, work with 1 GPU. Unfortunately, my datasets are graphs so large that only one fits into memory at a time, so the number of GPUs sets the effective batch size, and tuning with a batch size of 1 might not be very indicative. Any help much appreciated!
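
To check that this isn't specific to my model, a stripped-down module along these lines should be enough to reproduce the check (a minimal sketch; BoringNet and RandomDataset are hypothetical stand-ins, not code from my project):

import torch
import pytorch_lightning as pl
from torch.utils.data import DataLoader, Dataset

class RandomDataset(Dataset):
    # 64 random feature vectors, just to give the trainer something to iterate
    def __len__(self):
        return 64

    def __getitem__(self, idx):
        return torch.randn(32)

class BoringNet(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 1)

    def training_step(self, batch, batch_idx):
        loss = self.layer(batch).sum()
        self.log('train_loss', loss)
        return loss

    def validation_step(self, batch, batch_idx):
        loss = self.layer(batch).sum()
        self.log('val_loss', loss, sync_dist=True)

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)

if __name__ == '__main__':
    trainer = pl.Trainer(gpus=2, max_epochs=1)
    trainer.fit(BoringNet(),
                DataLoader(RandomDataset(), batch_size=2),
                DataLoader(RandomDataset(), batch_size=2))
    # expectation: populated with gpus=1, but empty ({}) with gpus=2
    print(trainer.callback_metrics)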

Environment

Hardware

  • 2x RTX 8000 GPUs

Software

  • Python 3.8.10
  • PyTorch 1.7.1
  • PyTorch Lightning 1.3.1
  • Ray[Tune] 1.1.0
  • cudatoolkit 10.2
  • OS: Linux version 3.10.0-693.11.6.el7.x86_64, Red Hat 4.8.5-16
  • Installed via conda
  • Any other relevant information: using SLURM

Code

I attach the LightningModule below, which uses TorchMetrics and self.log(), as per the docs. In desperation, I tried setting the callback_metrics dictionary myself in validation_epoch_end(), but that didn't work, and neither did setting sync_dist=True in the self.log() calls.

import torch
import torchmetrics
import torch_geometric as pyg
import pytorch_lightning as pl

class Interaction(pyg.nn.MessagePassing):
    def __init__(self, in_edge, in_node, out_edge, out_node):
        super(Interaction, self).__init__(
            aggr='add',
            flow="source_to_target")
        self.in_edge = 2 * in_node + in_edge
        self.in_node = in_node + out_edge
        self.mlp_edge = torch.nn.Sequential(
            torch.nn.Linear(self.in_edge, out_edge, bias=True),
            torch.nn.ReLU(),
            torch.nn.Linear(out_edge, out_edge, bias=True)
        )
        self.mlp_node = torch.nn.Sequential(
            torch.nn.Linear(self.in_node, out_node, bias=True),
            torch.nn.ReLU(),
            torch.nn.Linear(out_node, out_node, bias=True)
        )

    def forward(self, x, edge_index, edge_attrs):
        return self.propagate(
            x=x,
            edge_index=edge_index,
            edge_attrs=edge_attrs
        )

    def message(self, x_i, x_j, edge_index, edge_attrs):
        recv_send = [x_i, x_j]
        if edge_attrs is not None:
            recv_send.append(edge_attrs)
        recv_send = torch.cat(recv_send, dim=1)
        self.edge_embed = self.mlp_edge(recv_send)
        return self.edge_embed

    def update(self, aggr_out, x):
        node_embed = self.mlp_node(torch.cat([x, aggr_out], dim=1))
        return (self.edge_embed, node_embed)


class Net(pl.LightningModule):
    def __init__(self, dim_node: int = 4, dim_edge: int = 0,
                 dim_embed_edge: int = 64, dim_embed_node: int = 32,
                 num_hidden: int = 3, final_bias: bool = False,
                 pos_weight: float = 80.0,
                 learn_rate: float = 1e-4, weight_decay: float = 5e-4,
                 infer_thresh: float = 0.5):
        super(Net, self).__init__()
        # define the architecture
        self.encode = Interaction(dim_edge, dim_node,
                                  dim_embed_edge, dim_embed_node)
        self.message = pyg.nn.Sequential('x, edge_index, edge_attrs', [
            (Interaction(dim_embed_edge, dim_embed_node,
                         dim_embed_edge, dim_embed_node),
             'x, edge_index, edge_attrs -> edge_attrs, x')
            for _ in range(num_hidden)
        ])
        self.classify = torch.nn.Linear(dim_embed_edge, 1, bias=final_bias)
        # optimiser args
        self.lr = learn_rate
        self.decay = weight_decay
        # configure the loss
        self.criterion = torch.nn.BCEWithLogitsLoss(
                pos_weight=torch.tensor(pos_weight, device=self.device),
                reduction='mean')
        # add metrics
        self.train_ACC = torchmetrics.Accuracy(threshold=infer_thresh)
        self.train_F1 = torchmetrics.F1(
                num_classes=1, threshold=infer_thresh)
        self.val_ACC = torchmetrics.Accuracy(threshold=infer_thresh)
        self.val_F1 = torchmetrics.F1(
                num_classes=1, threshold=infer_thresh)
        self.val_PR = torchmetrics.BinnedPrecisionRecallCurve(
                num_classes=1, num_thresholds=5)

    def forward(self, data, sigmoid=True):
        node_attrs, edge_attrs = data.x, data.edge_attr
        edge_attrs, node_attrs = self.encode(node_attrs, data.edge_index,
                                             edge_attrs)
        edge_attrs, node_attrs = self.message(node_attrs, data.edge_index,
                                              edge_attrs)
        pred = self.classify(edge_attrs)
        if sigmoid:
            pred = torch.sigmoid(pred)
        return pred

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(
                self.parameters(),
                lr=self.lr,
                weight_decay=self.decay
            )
        return optimizer

    def _train_av_loss(self, outputs):
        return torch.stack([x['loss'] for x in outputs]).mean()

    def _val_av_loss(self, losses):
        return torch.stack(losses).mean()

    def training_step(self, batch, batch_idx):
        edge_pred = self(batch, sigmoid=False)
        loss = self.criterion(edge_pred, batch.y.view(-1, 1))
        return {'loss': loss,
                'preds': torch.sigmoid(edge_pred),
                'target': batch.y.view(-1, 1).int()}

    def training_step_end(self, outputs):
        self.train_ACC(outputs['preds'], outputs['target'])
        self.train_F1(outputs['preds'], outputs['target'])
        self.log('ptl/train_loss', outputs['loss'], on_step=True)
        return outputs['loss']

    def training_epoch_end(self, outputs):
        self.log('ptl/train_loss', self._train_av_loss(outputs))
        self.log('ptl/train_accuracy', self.train_ACC.compute())
        self.log('ptl/train_f', self.train_F1.compute())

    def validation_step(self, batch, batch_idx):
        edge_pred = self(batch, sigmoid=False)
        loss = self.criterion(edge_pred, batch.y.view(-1, 1))
        return {'loss': loss,
                'preds': torch.sigmoid(edge_pred),
                'target': batch.y.view(-1, 1).int()}

    def validation_step_end(self, outputs):
        self.val_ACC(outputs['preds'], outputs['target'])
        self.val_F1(outputs['preds'], outputs['target'])
        self.val_PR(outputs['preds'], outputs['target'])
        self.log('ptl/val_loss', outputs['loss'], on_step=True)
        return outputs['loss']

    def validation_epoch_end(self, outputs):
        metrics = {
            'ptl/val_loss': self._val_av_loss(outputs),
            'ptl/val_accuracy': self.val_ACC.compute(),
            'ptl/val_f': self.val_F1.compute(),
            }
        self.log_dict(metrics, sync_dist=True)
        prec, recall, thresh = self.val_PR.compute()
        for i, t in enumerate(thresh):
            self.log(f'ptl/val_prec_thresh_{t:.3f}', prec[i])
            self.log(f'ptl/val_recall_thresh_{t:.3f}', recall[i])
        self.trainer.callback_metrics = metrics  # manual override (desperation); has no effect
        return metrics
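
To narrow down where the metrics disappear, a minimal debugging callback can print the dictionary on every rank (a sketch against the standard pl.Callback hooks; DebugMetricsCallback is a hypothetical name, not part of my code):

class DebugMetricsCallback(pl.Callback):
    # prints trainer.callback_metrics after every validation run
    def on_validation_end(self, trainer, pl_module):
        print('rank {}: {}'.format(trainer.global_rank,
                                   trainer.callback_metrics))

As far as I can tell, TuneReportCheckpointCallback(on='validation_end') reads from this same dictionary, so if it prints empty here, Ray Tune has nothing to report.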

Here is the tuning script using Ray[Tune], following their docs: https://docs.ray.io/en/master/tune/tutorials/tune-pytorch-lightning.html.

import os

import pytorch_lightning as pl
from ray import tune
from ray.tune import CLIReporter
from ray.tune.schedulers import ASHAScheduler
from ray.tune.suggest.hyperopt import HyperOptSearch
from ray.tune.integration.pytorch_lightning import TuneReportCallback
from ray.tune.integration.pytorch_lightning import TuneReportCheckpointCallback

from cluster_gnn.models import gnn
from cluster_gnn.data import loader

# SLURM hack: setting the job name to "bash" stops Lightning from
# auto-detecting the SLURM environment inside the Ray workers
os.environ["SLURM_JOB_NAME"] = "bash"

ROOT_DIR = '/home/jlc1n20/projects/cluster_gnn/'
MODEL_DIR = os.path.join(ROOT_DIR, 'models/tune/')

def train_gnn(config, data_module, num_epochs=10, num_gpus=4, callbacks=None,
              checkpoint_dir=None):
    logger = pl.loggers.TensorBoardLogger(
        save_dir=tune.get_trial_dir(), name="", version=".")
    if checkpoint_dir:
        ckpt = pl.utilities.cloud_io.pl_load(
            os.path.join(checkpoint_dir, 'checkpoint'),
            map_location=lambda storage, loc: storage)
        model = gnn.Net._load_model_state(
            checkpoint=ckpt,
            num_hidden=6, dim_embed_edge=64, dim_embed_node=32,
            learn_rate=config['learn_rate'],
            pos_weight=config['pos_weight'])
    else:
        model = gnn.Net(num_hidden=6, dim_embed_edge=64, dim_embed_node=32,
                        learn_rate=config['learn_rate'],
                        pos_weight=config['pos_weight'])
    trainer = pl.Trainer(gpus=num_gpus, num_nodes=1, max_epochs=num_epochs,
                         progress_bar_refresh_rate=0,
                         limit_train_batches=0.1,
                         logger=logger,
                         callbacks=callbacks)
    trainer.fit(model, data_module)
    print('callback metrics are:\n {}'.format(trainer.callback_metrics))

def tune_gnn(data_module, num_samples=10, num_epochs=10, gpus_per_trial=2,
             init_params=None, checkpoint_dir=None):
    config = {
        'learn_rate': tune.loguniform(1e-6, 1e-1),
        'pos_weight': tune.uniform(1.0, 100.0),
        }
    metrics = ['ptl/val_loss', 'ptl/val_accuracy', 'ptl/val_f']
    callbacks = [
        TuneReportCheckpointCallback(
            metrics,
            filename='checkpoint',
            on='validation_end')
        ]
    scheduler = ASHAScheduler(
        time_attr='epoch',
        max_t=num_epochs,
        )
    search_alg = HyperOptSearch(points_to_evaluate=init_params)
    reporter = CLIReporter(
        parameter_columns=[
            'learn_rate',
            'pos_weight',
            ],
        )
    trainable = tune.with_parameters(
        train_gnn,
        data_module=data_module,
        num_epochs=num_epochs,
        num_gpus=gpus_per_trial,
        callbacks=callbacks,
        checkpoint_dir=checkpoint_dir,
        )
    analysis = tune.run(
        trainable,
        resources_per_trial={
            'cpu': 1,
            'gpu': gpus_per_trial,
            },
        metric='ptl/val_f',
        mode='max',
        config=config,
        num_samples=num_samples,
        search_alg=search_alg,
        scheduler=scheduler,
        progress_reporter=reporter,
        local_dir=MODEL_DIR,
        verbose=3,
        name='tune_gnn')
    print('Best hp found: ', analysis.best_config)


if __name__ == '__main__':
    num_gpus = 2
    cur_best_params = [{
        'learn_rate': 3.75e-5,
        'pos_weight': 21.5,
        }]
    graph_data = loader.GraphDataModule(
        '/home/jlc1n20/projects/cluster_gnn/data/', num_workers=num_gpus)

    tune_gnn(data_module=graph_data, num_samples=1, num_epochs=1,
             gpus_per_trial=num_gpus, init_params=cur_best_params)
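
For reference, the single-GPU run that does work can be checked outside Ray with the same module and data (a sketch reusing the current-best hyperparameters above; run after constructing graph_data):

model = gnn.Net(num_hidden=6, dim_embed_edge=64, dim_embed_node=32,
                learn_rate=3.75e-5, pos_weight=21.5)
trainer = pl.Trainer(gpus=1, max_epochs=1, limit_train_batches=0.1)
trainer.fit(model, graph_data)
# with gpus=1 this prints a populated dict, unlike the multi-GPU case
print(trainer.callback_metrics)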

Labels

  • bug (Something isn't working)
  • distributed (Generic distributed-related topic)
  • help wanted (Open to be worked on)
  • logging (Related to the `LoggerConnector` and `log()`)
  • priority: 1 (Medium priority task)
  • waiting on author (Waiting on user action, correction, or update)
