Description
Does PyTorch Lightning provide abstractions for inference? In particular, does it provide ways of automatically handling the transfer to/from GPU when I call model(x), or do I need to roll my own code for that?
Example Use Case
I have a use case where I train a model on slices of a sliding window over an audio spectrogram (e.g., 1-second chunks). When training is finished, I'd like to see how the model performs on an entire file. Pseudocode:
```python
# Generate training data
X, Y = [], []
for audio_file in audio_files:
    for x, y in sliding_window(audio_file):
        X.append(x)
        Y.append(y)
X, Y = shuffle(X, Y)  # shuffle the slices of all files

# Train the model on slices
model = ExampleModel(X, Y)
trainer = Trainer(gpus=1)
trainer.fit(model)

# Plot the performance on a whole test file
test_Y = []
for x, _ in sliding_window(test_file):
    test_Y.append(model(x))
plt.plot(test_Y)
```

Notice that during training, the notion of a file is entirely gone, but when I plot the test file, I reintroduce it. Of course, in my real code, the training data X, Y is split into training, validation and test sets, as usual. The plotting step is an additional verification; sort of like putting the pieces back together.
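For context, sliding_window in the pseudocode is roughly the following (a simplified sketch: load_spectrogram is a placeholder for my actual feature extraction, and the window/hop sizes are made up):

```python
import torch

def sliding_window(audio_file, window_frames=100, hop_frames=100):
    # Simplified sketch: load_spectrogram is a placeholder for the real
    # feature extraction; it returns a (time, freq) spectrogram plus
    # per-frame labels for the whole file.
    spec, labels = load_spectrogram(audio_file)
    for start in range(0, spec.shape[0] - window_frames + 1, hop_frames):
        x = torch.as_tensor(spec[start:start + window_frames]).float()
        y = torch.as_tensor(labels[start:start + window_frames]).float()
        yield x, y
```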
Problem
When the model runs on the GPU, the last part of the code becomes:
```python
# Plot the performance on a whole test file
model.eval()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
test_Y = []
for x, _ in sliding_window(test_file):
    y = model(x.to(device)).cpu()
    test_Y.append(y)
plt.plot(test_Y)
```

This isn't the end of the world, but it's not as nice as the other code that PyTorch Lightning helped me refactor. I also can't call x.type_as(...), since inside that loop I have no reference tensor living on the right device that I could point it at (or maybe I do, but I haven't figured it out).
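One thing I've considered (a sketch, not something the docs prescribe) is to use the model's own weights as the reference tensor, since any nn.Module lets you query where its parameters live:

```python
# Sketch: use the model's own parameters as the device/dtype reference.
# next(model.parameters()) works for any nn.Module; newer Lightning
# versions may also expose a model.device property, but I haven't
# relied on that here.
model.eval()
reference = next(model.parameters())
test_Y = []
with torch.no_grad():
    for x, _ in sliding_window(test_file):
        y = model(x.type_as(reference)).cpu()
        test_Y.append(y)
plt.plot(test_Y)
```

This still feels like boilerplate I'm writing by hand, which is what prompted the question above.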
A workaround is to save the model and load it again on the CPU:
```python
# Train the model on slices
# ...
trainer.fit(model)
trainer.save_checkpoint("model.ckpt")
model = ExampleModel.load_from_checkpoint("model.ckpt")

# Plot the performance on a whole test file
model.eval()
test_Y = []
for x, _ in sliding_window(test_file):
    test_Y.append(model(x))
plt.plot(test_Y)
```

While this removes the noise of the .to(device) and .cpu() calls, it adds the overhead of saving the model to disk every time, and I still have to call model.eval() manually. I run the model on an entire audio file for visual inspection rather than for metrics, and I only ever sample a few files, so doing this inference on the CPU instead of the GPU isn't a problem.
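An alternative I could use instead (a sketch, nothing Lightning-specific) is to move the trained model back to the CPU in place rather than round-tripping through a checkpoint:

```python
# Sketch: skip the checkpoint round-trip and move the trained model
# back to the CPU; torch.no_grad() avoids building autograd graphs
# during this purely visual check.
model = model.cpu().eval()
test_Y = []
with torch.no_grad():
    for x, _ in sliding_window(test_file):
        test_Y.append(model(x))
plt.plot(test_Y)
```

This avoids the disk overhead, but it is still manual device and eval handling on my side.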
Question
Is there a more elegant way to achieve the above?