|
55 | 55 | from ..others.test_utils import TOKEN, USER, is_staging_test |
56 | 56 |
|
57 | 57 |
|
| 58 | +def caculate_expected_num_shards(index_map_path): |
| 59 | + with open(index_map_path) as f: |
| 60 | + weight_map_dict = json.load(f)["weight_map"] |
| 61 | + first_key = list(weight_map_dict.keys())[0] |
| 62 | + weight_loc = weight_map_dict[first_key] # e.g., diffusion_pytorch_model-00001-of-00002.safetensors |
| 63 | + expected_num_shards = int(weight_loc.split("-")[-1].split(".")[0]) |
| 64 | + return expected_num_shards |
| 65 | + |
| 66 | + |
58 | 67 | # Will be run via run_test_in_subprocess |
59 | 68 | def _test_from_save_pretrained_dynamo(in_queue, out_queue, timeout): |
60 | 69 | error = None |
@@ -888,12 +897,7 @@ def test_sharded_checkpoints(self): |
888 | 897 | # Now check if the right number of shards exists. First, let's get the number of shards. |
889 | 898 | # Since this number can be dependent on the model being tested, it's important that we calculate it |
890 | 899 | # instead of hardcoding it. |
891 | | - with open(os.path.join(tmp_dir, SAFE_WEIGHTS_INDEX_NAME)) as f: |
892 | | - weight_map_dict = json.load(f)["weight_map"] |
893 | | - first_key = list(weight_map_dict.keys())[0] |
894 | | - weight_loc = weight_map_dict[first_key] # e.g., diffusion_pytorch_model-00001-of-00002.safetensors |
895 | | - expected_num_shards = int(weight_loc.split("-")[-1].split(".")[0]) |
896 | | - |
| 900 | + expected_num_shards = caculate_expected_num_shards(os.path.join(tmp_dir, SAFE_WEIGHTS_INDEX_NAME)) |
897 | 901 | actual_num_shards = len([file for file in os.listdir(tmp_dir) if file.endswith(".safetensors")]) |
898 | 902 | self.assertTrue(actual_num_shards == expected_num_shards) |
899 | 903 |
|
@@ -924,12 +928,7 @@ def test_sharded_checkpoints_device_map(self): |
924 | 928 | # Now check if the right number of shards exists. First, let's get the number of shards. |
925 | 929 | # Since this number can be dependent on the model being tested, it's important that we calculate it |
926 | 930 | # instead of hardcoding it. |
927 | | - with open(os.path.join(tmp_dir, SAFE_WEIGHTS_INDEX_NAME)) as f: |
928 | | - weight_map_dict = json.load(f)["weight_map"] |
929 | | - first_key = list(weight_map_dict.keys())[0] |
930 | | - weight_loc = weight_map_dict[first_key] # e.g., diffusion_pytorch_model-00001-of-00002.safetensors |
931 | | - expected_num_shards = int(weight_loc.split("-")[-1].split(".")[0]) |
932 | | - |
| 931 | + expected_num_shards = caculate_expected_num_shards(os.path.join(tmp_dir, SAFE_WEIGHTS_INDEX_NAME)) |
933 | 932 | actual_num_shards = len([file for file in os.listdir(tmp_dir) if file.endswith(".safetensors")]) |
934 | 933 | self.assertTrue(actual_num_shards == expected_num_shards) |
935 | 934 |
|
|
0 commit comments