
Auto move input to proper device for inference #1412

@tcwalther

Description


Does PyTorch Lightning provide abstractions for inference? In particular, does it provide ways of automatically handling the transfer to/from GPU when I call model(x), or do I need to roll my own code for that?

Example Use Case

I have a use case where I train a model on slices of a sliding window of an audio spectrogram (say, 1-second chunks). When training is finished, I'd like to see how the model performs on an entire file. Pseudocode:

# generate training data
X, Y = [], []
for audio_file in audio_files:
    for x, y in sliding_window(audio_file):
        X.append(x); Y.append(y)
X, Y = shuffle(X, Y)  # shuffle the slices of all files

# Train model on slices
model = ExampleModel(X, Y)
trainer = Trainer(gpus=1)
trainer.fit(model)

# Plot the performance on a whole test file:
test_Y = []
for x, _ in sliding_window(test_file):
    test_Y.append(model(x))
plt.plot(test_Y)

Notice that during training, the notion of a file is gone entirely, but when I plot my test file, I reintroduce it. Of course, in my real code the training data X, Y is split into training, validation and test sets, as usual. The plotting step is an additional verification; sort of like putting the pieces back together.

Problem

When the model runs on the GPU, the last part of the code becomes:

# Plot the performance on a whole test file:
model.eval()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
test_Y = []
for x, _ in sliding_window(test_file):
    y = model(x.to(device)).cpu()
    test_Y.append(y)
plt.plot(test_Y)

This isn't the end of the world, but it's not as nice as the rest of the code that PyTorch Lightning helped me refactor. I also can't call x.type_as(...), since within that loop I have no reference tensor living on the GPU that I could refer to (or maybe I can, but I haven't figured it out).
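For illustration, the closest I've come is to use one of the model's own parameters as the reference; a rough sketch of what I mean (plain PyTorch, nothing Lightning-specific):

# Sketch: take any parameter of the model as a device/dtype reference,
# so the inference loop doesn't hard-code "cuda"/"cpu".
model.eval()
reference = next(model.parameters())

test_Y = []
with torch.no_grad():
    for x, _ in sliding_window(test_file):
        test_Y.append(model(x.type_as(reference)).cpu())
plt.plot(test_Y)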

A workaround is to save the model and load it again on the CPU.

# Train model on slices
# ...
trainer.fit(model)
trainer.save_checkpoint("model.ckpt")
model = ExampleModel.load_from_checkpoint("model.ckpt")

# Plot the performance on a whole test file:
model.eval()
test_Y = []
for x, _ in sliding_window(test_file):
    test_Y.append(model(x))
plt.plot(test_Y)

While this removes the noise of the .to(device) and .cpu() calls, it adds the overhead of saving the model every time, and I still have to call model.eval() manually. The point of running my model on an entire audio file is not to compute metrics but to inspect the result visually, so I only ever sample a few audio files; running inference on the CPU instead of the GPU therefore isn't a problem.
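For completeness, the other option I can think of is to move the trained model back to the CPU in place, which at least avoids the checkpoint round-trip. A minimal sketch (again plain PyTorch, not a Lightning abstraction):

# Sketch: move the trained model back to the CPU after training,
# then run inference without any .to(device) noise in the loop.
model = model.cpu()
model.eval()

test_Y = []
with torch.no_grad():
    for x, _ in sliding_window(test_file):
        test_Y.append(model(x))
plt.plot(test_Y)

But this still feels like boilerplate that the framework could take care of for me.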

Question

Is there a more elegant way to achieve the above?


Labels

discussion (In a discussion stage), feature (Is an improvement or enhancement), help wanted (Open to be worked on), let's do it! (approved to implement)
