Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
97 changes: 21 additions & 76 deletions tests/models/cwm/test_modeling_cwm.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from transformers import is_torch_available
from transformers.testing_utils import (
cleanup,
require_read_token,
require_torch,
require_torch_accelerator,
slow,
Expand Down Expand Up @@ -85,6 +86,7 @@ class CwmModelTest(CausalLMModelTest, unittest.TestCase):

@require_torch_accelerator
@slow
@require_read_token
class CwmIntegrationTest(unittest.TestCase):
def setUp(self):
cleanup(torch_device, gc_collect=True)
Expand Down Expand Up @@ -116,45 +118,14 @@ def test_cwm_integration(self):
with torch.no_grad():
out = model(**inputs)

# fmt: off
expected_logits = torch.tensor(
[
0.5625,
2.9531,
9.1875,
0.4746,
-0.3613,
2.2031,
2.9844,
1.5312,
0.5859,
1.5391,
2.7500,
3.4375,
2.0156,
2.1719,
1.5469,
2.5469,
2.8438,
1.8203,
1.7188,
1.3984,
1.0469,
0.1748,
0.4453,
0.1533,
-0.1157,
0.8516,
2.2344,
5.2188,
1.2891,
1.5234,
0.8555,
0.6992,
],
[0.5625, 2.9531, 9.1875, 0.5039, -0.3262, 2.2344, 3.0312, 1.5312, 0.5664, 1.5625, 2.7656, 3.4219, 2.0312, 2.1719, 1.5391, 2.5469, 2.8281, 1.8125, 1.7109, 1.3906, 1.0391, 0.1621, 0.4277, 0.1455, -0.1230, 0.8477, 2.2344, 5.2188, 1.2969, 1.5547, 0.8516, 0.7148],
dtype=torch.bfloat16,
).to(model.device)
# fmt: on

self.assertTrue(torch.allclose(out.logits[0, -1, :32], expected_logits, atol=1e-2, rtol=1e-2))
torch.testing.assert_close(out.logits[0, -1, :32], expected_logits, atol=1e-2, rtol=1e-2)

self.assertEqual(out.logits.shape[1], inputs.input_ids.shape[1])
self.assertEqual(out.logits.shape[2], model.config.vocab_size)
Expand All @@ -166,10 +137,13 @@ def test_cwm_sliding_window_long_sequence(self):
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("facebook/cwm")
model = CwmForCausalLM.from_pretrained("facebook/cwm", device_map="auto", dtype=torch.bfloat16)
# The original `sliding_window` is `8192`, but that causes a GPU OOM on an A10, so the test overrides it with `4096`.
model = CwmForCausalLM.from_pretrained(
"facebook/cwm", device_map="auto", dtype=torch.bfloat16, sliding_window=4096
)

sliding_window = model.config.sliding_window
long_text = "for i in range(1000):\n print(f'iteration {i}')\n" * 600
long_text = "for i in range(1000):\n print(f'iteration {i}')\n" * 270

inputs = tokenizer(long_text, return_tensors="pt").to(model.device)
seq_len = inputs.input_ids.shape[1]
Expand All @@ -182,50 +156,21 @@ def test_cwm_sliding_window_long_sequence(self):
with torch.no_grad():
out = model(**inputs)

# fmt: off
expected_logits = torch.tensor(
[
4.7812,
6.1875,
13.1875,
4.4062,
5.0312,
3.9844,
6.6875,
4.8438,
2.3125,
6.5000,
4.4688,
0.5195,
5.6562,
3.3125,
2.7500,
4.9062,
5.5938,
4.1562,
3.9531,
2.4062,
3.2812,
2.8594,
3.4688,
2.9688,
2.6875,
3.4531,
2.7344,
7.2812,
4.5000,
5.7500,
2.3438,
5.9688,
],
[5.2812, 6.4688, 12.8125, 4.6875, 5.2500, 4.2500, 6.9688, 4.9375, 2.7656, 6.5938, 4.9688, 1.1016, 5.9375, 3.7500, 3.1094, 5.5312, 6.1250, 4.7500, 4.5312, 2.8281, 4.0625, 3.3125, 3.9219, 3.3906, 3.1406, 3.6719, 3.2031, 7.0938, 4.8750, 6.0000, 2.7188, 6.2500],
dtype=torch.bfloat16,
).to(model.device)
# fmt: on

self.assertTrue(torch.allclose(out.logits[0, -1, :32], expected_logits, atol=1e-2, rtol=1e-2))
torch.testing.assert_close(out.logits[0, -1, :32], expected_logits, atol=1e-2, rtol=1e-2)

self.assertEqual(out.logits.shape[1], seq_len)
self.assertEqual(out.logits.shape[2], model.config.vocab_size)
self.assertFalse(torch.isnan(out.logits).any())
self.assertFalse(torch.isinf(out.logits).any())
logits = out.logits.to("cpu")

self.assertEqual(logits.shape[1], seq_len)
self.assertEqual(logits.shape[2], model.config.vocab_size)
self.assertFalse(torch.isnan(logits).any())
self.assertFalse(torch.isinf(logits).any())

for i, layer in enumerate(model.model.layers):
if model.config.layer_types[i] == "sliding_attention":
Expand Down