Skip to content

Commit 2d28f83

Browse files
committed
.
1 parent 3696825 commit 2d28f83

File tree

1 file changed

+5
-6
lines changed

1 file changed

+5
-6
lines changed

tests/helpers/pipelines.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -105,21 +105,20 @@ def run_model_test(
105105
trainer.checkpoint_connector.hpc_load(checkpoint_path, on_gpu=on_gpu)
106106

107107

108+
@torch.no_grad()
108109
def run_prediction_eval_model_template(trained_model, dataloader, dp=False, min_acc=0.50):
109110
# run prediction on 1 batch
110111
batch = next(iter(dataloader))
111112
x, y = batch
112113
x = x.view(x.size(0), -1)
113114

114115
if dp:
115-
with torch.no_grad():
116-
output = trained_model(batch, 0)
117-
acc = output['val_acc']
116+
output = trained_model(batch, 0)
117+
acc = output['val_acc']
118118
acc = torch.mean(acc).item()
119119

120120
else:
121-
with torch.no_grad():
122-
y_hat = trained_model(x)
123-
acc = accuracy(y_hat.cpu(), y.cpu()).item()
121+
y_hat = trained_model(x)
122+
acc = accuracy(y_hat.cpu(), y.cpu(), top_k=2).item()
124123

125124
assert acc >= min_acc, f"This model is expected to get > {min_acc} in test set (it got {acc})"

0 commit comments

Comments
 (0)