🐛 Bug
When using pl.Trainer(gpus=N, ...) with N > 1, the callback_metrics dictionary appears not to be populated. I've tried to integrate both Optuna and Ray Tune, and have failed with both as a result.
To confirm that this was the issue, I printed the trainer.callback_metrics attribute:
(pid=144372) callback metrics are:
(pid=144372) {}
It does, however, work with a single GPU. Unfortunately, my data are graphs so large that only one fits into memory at a time, so the effective batch size equals the number of GPUs, and tuning with a batch size of 1 is unlikely to be representative. Any help much appreciated!
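To make the comparison concrete, here is a stripped-down sketch of the check, reusing the Net and GraphDataModule from the code attached below (everything except the relevant Trainer arguments has been removed):
import pytorch_lightning as pl
from cluster_gnn.models import gnn
from cluster_gnn.data import loader

data = loader.GraphDataModule('/home/jlc1n20/projects/cluster_gnn/data/',
                              num_workers=2)

# single GPU: callback_metrics contains the logged 'ptl/*' values after fit
trainer = pl.Trainer(gpus=1, max_epochs=1)
trainer.fit(gnn.Net(), data)
print(trainer.callback_metrics)

# two GPUs: callback_metrics comes back as an empty dict
trainer = pl.Trainer(gpus=2, num_nodes=1, max_epochs=1)
trainer.fit(gnn.Net(), data)
print(trainer.callback_metrics)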
Environment
Hardware
- 2x RTX 8000 GPUs
Software
- Python 3.8.10
- PyTorch 1.7.1
- PyTorch Lightning 1.3.1
- Ray[Tune] 1.1.0
- cudatoolkit 10.2
- OS: Linux version 3.10.0-693.11.6.el7.x86_64, Red Hat 4.8.5-16
- Installed via conda
- Any other relevant information: using SLURM
Code
I attach the LightningModule below, which uses TorchMetrics and the self.log API, as per the docs. In desperation I also tried setting the callback_metrics dictionary myself in validation_epoch_end(), but that didn't work, and neither did setting sync_dist=True in the logging calls; both attempts are still visible in the code.
import torch
import torchmetrics
import torch_geometric as pyg
import pytorch_lightning as pl
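# Message-passing block (PyTorch Geometric): embeds edge and node attributes
# with small MLPs and returns updated (edge, node) embeddings.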
class Interaction(pyg.nn.MessagePassing):
def __init__(self, in_edge, in_node, out_edge, out_node):
super(Interaction, self).__init__(
aggr='add',
flow="source_to_target")
self.in_edge = 2 * in_node + in_edge
self.in_node = in_node + out_edge
self.mlp_edge = torch.nn.Sequential(
torch.nn.Linear(self.in_edge, out_edge, bias=True),
torch.nn.ReLU(),
torch.nn.Linear(out_edge, out_edge, bias=True)
)
self.mlp_node = torch.nn.Sequential(
torch.nn.Linear(self.in_node, out_node, bias=True),
torch.nn.ReLU(),
torch.nn.Linear(out_node, out_node, bias=True)
)
def forward(self, x, edge_index, edge_attrs):
return self.propagate(
x=x,
edge_index=edge_index,
edge_attrs=edge_attrs
)
def message(self, x_i, x_j, edge_index, edge_attrs):
recv_send = [x_i, x_j]
if edge_attrs is not None:
recv_send.append(edge_attrs)
recv_send = torch.cat(recv_send, dim=1)
self.edge_embed = self.mlp_edge(recv_send)
return self.edge_embed
def update(self, aggr_out, x):
node_embed = self.mlp_node(torch.cat([x, aggr_out], dim=1))
return (self.edge_embed, node_embed)
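# The LightningModule under test: metrics are TorchMetrics objects and are
# logged through self.log / self.log_dict, as per the docs.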
class Net(pl.LightningModule):
def __init__(self, dim_node: int = 4, dim_edge: int = 0,
dim_embed_edge: int = 64, dim_embed_node: int = 32,
num_hidden: int = 3, final_bias: bool = False,
pos_weight: float = 80.0,
learn_rate: float = 1e-4, weight_decay: float = 5e-4,
infer_thresh: float = 0.5):
super(Net, self).__init__()
# define the architecture
self.encode = Interaction(dim_edge, dim_node,
dim_embed_edge, dim_embed_node)
self.message = pyg.nn.Sequential('x, edge_index, edge_attrs', [
(Interaction(dim_embed_edge, dim_embed_node,
dim_embed_edge, dim_embed_node),
'x, edge_index, edge_attrs -> edge_attrs, x')
for i in range(num_hidden)
])
self.classify = torch.nn.Linear(dim_embed_edge, 1, bias=final_bias)
# optimiser args
self.lr = learn_rate
self.decay = weight_decay
# configure the loss
self.criterion = torch.nn.BCEWithLogitsLoss(
pos_weight=torch.tensor(pos_weight, device=self.device),
reduction='mean')
# add metrics
self.train_ACC = torchmetrics.Accuracy(threshold=infer_thresh)
self.train_F1 = torchmetrics.F1(
num_classes=1, threshold=infer_thresh)
self.val_ACC = torchmetrics.Accuracy(threshold=infer_thresh)
self.val_F1 = torchmetrics.F1(
num_classes=1, threshold=infer_thresh)
self.val_PR = torchmetrics.BinnedPrecisionRecallCurve(
num_classes=1, num_thresholds=5)
def forward(self, data, sigmoid=True):
node_attrs, edge_attrs = data.x, data.edge_attr
edge_attrs, node_attrs = self.encode(node_attrs, data.edge_index,
edge_attrs)
edge_attrs, node_attrs = self.message(node_attrs, data.edge_index,
edge_attrs)
pred = self.classify(edge_attrs)
if sigmoid:
pred = torch.sigmoid(pred)
return pred
def configure_optimizers(self):
optimizer = torch.optim.Adam(
self.parameters(),
lr=self.lr,
weight_decay=self.decay
)
return optimizer
def _train_av_loss(self, outputs):
return torch.stack([x['loss'] for x in outputs]).mean()
def _val_av_loss(self, losses):
return torch.stack(losses).mean()
def training_step(self, batch, batch_idx):
edge_pred = self(batch, sigmoid=False)
loss = self.criterion(edge_pred, batch.y.view(-1, 1))
return {'loss': loss,
'preds': torch.sigmoid(edge_pred),
'target': batch.y.view(-1, 1).int()}
def training_step_end(self, outputs):
self.train_ACC(outputs['preds'], outputs['target'])
self.train_F1(outputs['preds'], outputs['target'])
self.log('ptl/train_loss', outputs['loss'], on_step=True)
return outputs['loss']
def training_epoch_end(self, outputs):
self.log('ptl/train_loss', self._train_av_loss(outputs))
self.log('ptl/train_accuracy', self.train_ACC.compute())
self.log('ptl/train_f', self.train_F1.compute())
def validation_step(self, batch, batch_idx):
edge_pred = self(batch, sigmoid=False)
loss = self.criterion(edge_pred, batch.y.view(-1, 1))
return {'loss': loss,
'preds': torch.sigmoid(edge_pred),
'target': batch.y.view(-1, 1).int()}
def validation_step_end(self, outputs):
self.val_ACC(outputs['preds'], outputs['target'])
self.val_F1(outputs['preds'], outputs['target'])
self.val_PR(outputs['preds'], outputs['target'])
self.log('ptl/val_loss', outputs['loss'], on_step=True)
return outputs['loss']
def validation_epoch_end(self, outputs):
metrics = {
'ptl/val_loss': self._val_av_loss(outputs),
'ptl/val_accuracy': self.val_ACC.compute(),
'ptl/val_f': self.val_F1.compute(),
}
self.log_dict(metrics, sync_dist=True)
prec, recall, thresh = self.val_PR.compute()
for i, t in enumerate(thresh):
self.log(f'ptl/val_prec_thresh_{t:.3f}', prec[i])
self.log(f'ptl/val_recall_thresh_{t:.3f}', recall[i])
self.trainer.callback_metrics = metrics
return metrics
Here I attach the tuning script using Ray[Tune], as per their docs: https://docs.ray.io/en/master/tune/tutorials/tune-pytorch-lightning.html
import os
import pytorch_lightning as pl
from ray import tune
from ray.tune import CLIReporter
from ray.tune.schedulers import ASHAScheduler
from ray.tune.suggest.hyperopt import HyperOptSearch
from ray.tune.integration.pytorch_lightning import TuneReportCallback
from ray.tune.integration.pytorch_lightning import TuneReportCheckpointCallback
from cluster_gnn.models import gnn
from cluster_gnn.data import loader
# slurm hack
os.environ["SLURM_JOB_NAME"] = "bash"
ROOT_DIR = '/home/jlc1n20/projects/cluster_gnn/'
MODEL_DIR = ROOT_DIR + '/models/tune/'
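# Per-trial training function executed by Ray Tune.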
def train_gnn(config, data_module, num_epochs=10, num_gpus=4, callbacks=None,
checkpoint_dir=None):
logger = pl.loggers.TensorBoardLogger(
save_dir=tune.get_trial_dir(), name="", version=".")
if checkpoint_dir:
ckpt = pl.utilities.cloud_io.pl_load(
os.path.join(checkpoint_dir, 'checkpoint'),
map_location=lambda storage, loc: storage)
model = gnn.Net._load_model_state(
checkpoint=ckpt,
num_hidden=6, dim_embed_edge=64, dim_embed_node=32,
learn_rate=config['learn_rate'],
pos_weight=config['pos_weight'])
else:
model = gnn.Net(num_hidden=6, dim_embed_edge=64, dim_embed_node=32,
learn_rate=config['learn_rate'],
pos_weight=config['pos_weight'])
trainer = pl.Trainer(gpus=num_gpus, num_nodes=1, max_epochs=num_epochs,
progress_bar_refresh_rate=0,
limit_train_batches=0.1,
logger=logger,
callbacks=callbacks)
trainer.fit(model, data_module)
print('callback metrics are:\n {}'.format(trainer.callback_metrics))
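# Builds the search space, scheduler, search algorithm and reporter,
# then launches the trials with tune.run.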
def tune_gnn(data_module, num_samples=10, num_epochs=10, gpus_per_trial=2,
init_params=None, checkpoint_dir=None):
config = {
'learn_rate': tune.loguniform(1e-6, 1e-1),
'pos_weight': tune.uniform(1.0, 100.0),
}
metrics = ['ptl/val_loss', 'ptl/val_accuracy', 'ptl/val_f']
callbacks = [
TuneReportCheckpointCallback(
metrics,
filename='checkpoint',
on='validation_end')
]
scheduler = ASHAScheduler(
time_attr='epoch',
max_t=num_epochs,
)
search_alg = HyperOptSearch(points_to_evaluate=init_params)
reporter = CLIReporter(
parameter_columns=[
'learn_rate',
'pos_weight',
],
)
trainable = tune.with_parameters(
train_gnn,
data_module=data_module,
num_epochs=num_epochs,
num_gpus=gpus_per_trial,
callbacks=callbacks,
checkpoint_dir=checkpoint_dir,
)
analysis = tune.run(
trainable,
resources_per_trial={
'cpu': 1,
'gpu': gpus_per_trial,
},
metric='ptl/val_f',
mode='max',
config=config,
num_samples=num_samples,
search_alg=search_alg,
scheduler=scheduler,
progress_reporter=reporter,
local_dir=MODEL_DIR,
verbose=3,
name='tune_gnn')
print('Best hp found: ', analysis.best_config)
if __name__ == '__main__':
num_gpus = 2
cur_best_params = [{
'learn_rate': 3.75e-5,
'pos_weight': 21.5,
}]
graph_data = loader.GraphDataModule(
'/home/jlc1n20/projects/cluster_gnn/data/', num_workers=num_gpus)
tune_gnn(data_module=graph_data, num_samples=1, num_epochs=1,
gpus_per_trial=num_gpus, init_params=cur_best_params)
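For context on why the empty dictionary breaks the tuning loop: as far as I can tell from the Ray docs, the TuneReport*Callback classes read the requested keys out of trainer.callback_metrics at the end of validation and pass them to tune.report(), so an empty dictionary means the trial never reports anything. A rough, hypothetical stand-in (not the actual Ray implementation) would look like this:
import pytorch_lightning as pl
from ray import tune

class ReportToTune(pl.Callback):
    # Hypothetical stand-in for TuneReportCallback: forward a selection of
    # trainer.callback_metrics to Ray Tune after each validation run.
    def __init__(self, metrics):
        self.metrics = metrics

    def on_validation_end(self, trainer, pl_module):
        logged = trainer.callback_metrics
        # With gpus > 1, `logged` is {} here, so nothing is reported to Tune
        # and the scheduler/search algorithm have nothing to act on.
        report = {name: logged[name].item() for name in self.metrics
                  if name in logged}
        if report:
            tune.report(**report)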