WEIGHT_FORMAT_WARN in RNN.cpp does not get set on rocm  #1077

@bmedishe

Description

🚀 The feature, motivation and pitch

I am working on enabling the test_nn.py test test_cudnn_weight_format on ROCm, and observed that the test passes if the following change is applied:

diff --git a/test/test_nn.py b/test/test_nn.py
index c8311c91d7..85b391e880 100644
--- a/test/test_nn.py
+++ b/test/test_nn.py
@@ -8460,8 +8460,9 @@ tensor(..., device='meta', size=(1,), requires_grad=True)""")
                 with warnings.catch_warnings(record=True) as w:
                     output_noncontig = rnn(input, hx)
                 if first_warn:
-                    self.assertEqual(len(w), 1)
-                    self.assertIn('weights are not part of single contiguous chunk of memory', w[0].message.args[0])
+                    if (torch.version.hip is None):
+                        self.assertEqual(len(w), 1)
+                        self.assertIn('weights are not part of single contiguous chunk of memory', w[0].message.args[0])
                     first_warn = False
                     warnings.resetwarnings()
                 output_noncontig[0].sum().backward()

This warning is generated from aten/src/ATen/native/cudnn/RNN.cpp

How can this test pass on ROCm without bypassing the checks above?
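For context, a minimal sketch of what the test exercises — assuming a CUDA/ROCm build of PyTorch is available (this is not the upstream test itself, just an illustration): replacing an RNN parameter's storage breaks the single flat weight buffer that the cuDNN/MIOpen path expects, which is what should make RNN.cpp emit the "weights are not part of single contiguous chunk of memory" warning.

```python
import warnings
import torch

def noncontig_weight_warnings():
    """Hedged sketch: try to trigger the non-contiguous-weights warning
    from aten/src/ATen/native/cudnn/RNN.cpp and return what was caught."""
    if not torch.cuda.is_available():
        # The cuDNN/MIOpen RNN path is never exercised on a CPU-only
        # build, so there is nothing to observe.
        return None
    rnn = torch.nn.LSTM(10, 20, num_layers=1).cuda()
    inp = torch.randn(5, 3, 10, device="cuda")
    hx = (torch.randn(1, 3, 20, device="cuda"),
          torch.randn(1, 3, 20, device="cuda"))
    # Replacing the parameter's storage detaches it from the contiguous
    # flat-weight buffer, which is the condition WEIGHT_FORMAT_WARN guards.
    with torch.no_grad():
        rnn.weight_ih_l0.set_(rnn.weight_ih_l0.clone())
    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        rnn(inp, hx)
    return [str(w.message) for w in caught]
```

On a CUDA build this list is expected to contain the contiguous-chunk warning; the question in this issue is why the equivalent warning never appears on ROCm.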

Alternatives

No response

Additional context

No response
