Skip to content

Commit c3fd9c2

Browse files
committed
.
1 parent 5a17564 commit c3fd9c2

File tree

1 file changed

+5
-6
lines changed

1 file changed

+5
-6
lines changed

tests/helpers/pipelines.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -101,21 +101,20 @@ def run_model_test(
101101
trainer.checkpoint_connector.hpc_load(checkpoint_path, on_gpu=on_gpu)
102102

103103

104+
@torch.no_grad()
104105
def run_prediction_eval_model_template(trained_model, dataloader, dp=False, min_acc=0.50):
105106
# run prediction on 1 batch
106107
batch = next(iter(dataloader))
107108
x, y = batch
108109
x = x.view(x.size(0), -1)
109110

110111
if dp:
111-
with torch.no_grad():
112-
output = trained_model(batch, 0)
113-
acc = output['val_acc']
112+
output = trained_model(batch, 0)
113+
acc = output['val_acc']
114114
acc = torch.mean(acc).item()
115115

116116
else:
117-
with torch.no_grad():
118-
y_hat = trained_model(x)
119-
acc = accuracy(y_hat.cpu(), y.cpu()).item()
117+
y_hat = trained_model(x)
118+
acc = accuracy(y_hat.cpu(), y.cpu(), top_k=2).item()
120119

121120
assert acc >= min_acc, f"This model is expected to get > {min_acc} in test set (it got {acc})"

0 commit comments

Comments
 (0)