diff --git a/test/quantization/test_quant_api.py b/test/quantization/test_quant_api.py
index 1d476319dd..8d949fdf84 100644
--- a/test/quantization/test_quant_api.py
+++ b/test/quantization/test_quant_api.py
@@ -635,6 +635,28 @@ def test_quantized_model_to_device(self):
         cuda_res = m(*example_inputs_cuda)
         self.assertEqual(cuda_res.cpu(), ref)
 
+    @unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "Test only enabled for 2.4+")
+    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
+    def test_quantized_tensor_subclass_save_load_map_location(self):
+        m = ToyLinearModel().eval().to(dtype=torch.bfloat16, device="cuda")
+        example_inputs = m.example_inputs(dtype=torch.bfloat16, device="cuda")
+
+        quantize_(m, int8_weight_only())
+        ref = m(*example_inputs)
+        with tempfile.NamedTemporaryFile() as f:
+            torch.save(m.state_dict(), f)
+            f.seek(0)
+            state_dict = torch.load(f.name, map_location="cpu", mmap=True)
+
+        with torch.device('meta'):
+            m_copy = ToyLinearModel().eval()
+
+        m_copy.load_state_dict(state_dict, assign=True)
+        m_copy.to(dtype=torch.bfloat16, device="cuda")
+
+        res = m_copy(*example_inputs)
+        self.assertEqual(res, ref)
+
 
 if __name__ == "__main__":
     unittest.main()