Commit 9aefadf

Add a test for map_location="cpu"
Summary: torchtune is using torch.load(file_name, map_location="cpu", mmap=True), so we add a test to make sure this works with the tensor subclass API.

Test Plan: python test/quantization/test_quant_api.py -k test_quantized_tensor_subclass_save_load_map_location

Reviewers:
Subscribers:
Tasks:
Tags:
1 parent 05038a1 commit 9aefadf
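
For context, a minimal sketch of the loading pattern this commit is testing; the checkpoint path and the placeholder nn.Linear module below are illustrative, not part of the commit:

import torch
import torch.nn as nn

# Placeholder module standing in for a quantized torchtune checkpoint.
model = nn.Linear(16, 16, dtype=torch.bfloat16)

torch.save(model.state_dict(), "checkpoint.pt")

# torchtune-style load: map_location="cpu" keeps tensors on CPU and
# mmap=True memory-maps the file instead of reading it eagerly into RAM.
state_dict = torch.load("checkpoint.pt", map_location="cpu", mmap=True)

# assign=True swaps in the loaded tensors directly, which is what lets
# tensor-subclass (quantized) weights survive the round trip in the test below.
model.load_state_dict(state_dict, assign=True)
if torch.cuda.is_available():
    model.to(dtype=torch.bfloat16, device="cuda")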

1 file changed: +20 -0 lines

test/quantization/test_quant_api.py

Lines changed: 20 additions & 0 deletions
@@ -635,6 +635,26 @@ def test_quantized_model_to_device(self):
         cuda_res = m(*example_inputs_cuda)
         self.assertEqual(cuda_res.cpu(), ref)
 
+    @unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "Test only enabled for 2.4+")
+    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
+    def test_quantized_tensor_subclass_save_load_map_location(self):
+        m = ToyLinearModel().eval().to(dtype=torch.bfloat16, device="cuda")
+        m_copy = copy.deepcopy(m)
+        example_inputs = m.example_inputs(dtype=torch.bfloat16, device="cuda")
+
+        quantize_(m, int8_weight_only())
+        ref = m(*example_inputs)
+        with tempfile.NamedTemporaryFile() as f:
+            torch.save(m.state_dict(), f)
+            f.seek(0)
+            state_dict = torch.load(f.name, map_location="cpu", mmap=True)
+
+        m_copy.load_state_dict(state_dict, assign=True)
+        m_copy.to(dtype=torch.bfloat16, device="cuda")
+
+        res = m_copy(*example_inputs)
+        self.assertEqual(res, ref)
+
 
 if __name__ == "__main__":
     unittest.main()
