forked from pytorch/pytorch
Open
Description
🚀 The feature, motivation and pitch
I am working on enabling test_cudnn_weight_format from test_nn.py on ROCm, and observed that the test passes if the following change is applied:
```diff
diff --git a/test/test_nn.py b/test/test_nn.py
index c8311c91d7..85b391e880 100644
--- a/test/test_nn.py
+++ b/test/test_nn.py
@@ -8460,8 +8460,9 @@ tensor(..., device='meta', size=(1,), requires_grad=True)""")
                 with warnings.catch_warnings(record=True) as w:
                     output_noncontig = rnn(input, hx)
                 if first_warn:
-                    self.assertEqual(len(w), 1)
-                    self.assertIn('weights are not part of single contiguous chunk of memory', w[0].message.args[0])
+                    if (torch.version.hip is None):
+                        self.assertEqual(len(w), 1)
+                        self.assertIn('weights are not part of single contiguous chunk of memory', w[0].message.args[0])
                     first_warn = False
                     warnings.resetwarnings()
                 output_noncontig[0].sum().backward()
```
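The diff above skips both assertions whenever torch.version.hip is set, so on ROCm nothing about the warning is checked at all. A minimal stdlib-only sketch of a stricter pattern is to assert an explicit per-backend expectation (one warning where the layout check exists, zero where it does not) instead of dropping the check. Everything here is hypothetical: fake_rnn_forward stands in for rnn(input, hx), and backend_checks_layout is an assumed flag, not a PyTorch API; only the matched phrase comes from the test.

```python
import warnings
from typing import List

# Phrase the real test matches (taken from the assertIn call in the diff).
WARN_PHRASE = "weights are not part of single contiguous chunk of memory"


def fake_rnn_forward(weights_contiguous: bool, backend_checks_layout: bool) -> None:
    """Hypothetical stand-in for rnn(input, hx).

    Emits a warning containing WARN_PHRASE only when the weights are not
    contiguous AND the backend performs the layout check.
    """
    if not weights_contiguous and backend_checks_layout:
        warnings.warn(f"RNN module {WARN_PHRASE}.", UserWarning)


def layout_warnings(weights_contiguous: bool, backend_checks_layout: bool) -> List[str]:
    """Run the fake forward pass and return the matching warning messages."""
    with warnings.catch_warnings(record=True) as w:
        # "always" stops Python's per-location warning registry from
        # swallowing a repeat of a warning that was already shown once.
        warnings.simplefilter("always")
        fake_rnn_forward(weights_contiguous, backend_checks_layout)
    return [x.message.args[0] for x in w if WARN_PHRASE in x.message.args[0]]


def check_layout_warning(backend_checks_layout: bool) -> None:
    """Assert an explicit expectation per backend instead of skipping the check."""
    msgs = layout_warnings(weights_contiguous=False,
                           backend_checks_layout=backend_checks_layout)
    expected = 1 if backend_checks_layout else 0
    assert len(msgs) == expected, f"expected {expected} warning(s), got {msgs}"


check_layout_warning(backend_checks_layout=True)   # backend that performs the check
check_layout_warning(backend_checks_layout=False)  # hypothetical backend without it
```

The trade-off is that a zero-warning expectation documents the current ROCm behavior rather than hiding it, so the test fails loudly if that behavior ever changes; it still does not make ROCm emit the warning.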
This warning is emitted by aten/src/ATen/native/cudnn/RNN.cpp.
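The check in RNN.cpp that produces this warning is not shown here. As a rough illustration only, assuming it amounts to verifying that each weight starts exactly where the previous one ends inside a single flat buffer, the idea can be sketched in plain Python. The function name and the (offset, length) representation are hypothetical and are not ATen's API; a real check works on raw data pointers.

```python
from typing import List, Tuple

# Hypothetical model of a parameter: (start_offset, num_elements), both in
# elements relative to some base address.
Param = Tuple[int, int]


def is_single_contiguous_chunk(params: List[Param]) -> bool:
    """Return True if the parameters, in order, sit back to back in one buffer.

    Simplification: ignores any alignment padding a real weight layout
    may insert between parameters.
    """
    if not params:
        return True
    expected_start = params[0][0]
    for start, numel in params:
        if start != expected_start:
            return False  # gap, overlap, or a tensor living in another allocation
        expected_start = start + numel
    return True


# Four parameters packed into one flat buffer: 0..3, 4..9, 10..11, 12..19.
packed = [(0, 4), (4, 6), (10, 2), (12, 8)]
# Same shapes, but the third parameter was replaced by a copy stored elsewhere.
one_copied = [(0, 4), (4, 6), (1000, 2), (12, 8)]

print(is_single_contiguous_chunk(packed))      # True
print(is_single_contiguous_chunk(one_copied))  # False
```

Under this model, replacing any single parameter with a copy in a separate allocation is enough to break the chunk, which is the situation the test sets up before expecting the warning.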
How can this test pass on ROCm without bypassing the checks above?
Alternatives
No response
Additional context
No response