Skip to content

Commit 4ad7a1f

Browse files
authored
[Chore] create a utility for calculating the expected number of shards. (#8692)
create a utility for calculating the expected number of shards.
1 parent 1f81fbe commit 4ad7a1f

File tree

1 file changed

+11
-12
lines changed

1 file changed

+11
-12
lines changed

tests/models/test_modeling_common.py

Lines changed: 11 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,15 @@
5555
from ..others.test_utils import TOKEN, USER, is_staging_test
5656

5757

58+
def caculate_expected_num_shards(index_map_path):
    """Return the total shard count encoded in a sharded-checkpoint index file.

    The JSON file at ``index_map_path`` must contain a ``"weight_map"`` object
    whose values are shard file names ending in ``-<total>.<extension>``, e.g.
    ``diffusion_pytorch_model-00001-of-00002.safetensors``. The total is parsed
    from the first entry of that mapping.

    NOTE: the function name keeps its existing spelling because the tests in
    this file call it by that name.

    Args:
        index_map_path: Path to the weights index JSON file.

    Returns:
        The total number of shards as an ``int``.

    Raises:
        ValueError: If ``"weight_map"`` is empty, or if the shard file name
            does not carry an integer total in its last ``-``-separated part.
        KeyError: If the JSON document has no ``"weight_map"`` entry.
    """
    # JSON is UTF-8 by specification; do not depend on the platform locale.
    with open(index_map_path, encoding="utf-8") as f:
        weight_map_dict = json.load(f)["weight_map"]

    if not weight_map_dict:
        raise ValueError(f"No entries found in the weight map of {index_map_path}.")

    # Any entry works: every shard file name carries the same total.
    weight_loc = next(iter(weight_map_dict.values()))  # e.g., diffusion_pytorch_model-00001-of-00002.safetensors
    # Last "-"-separated part is "<total>.<extension>"; keep what precedes the first dot.
    expected_num_shards = int(weight_loc.split("-")[-1].split(".")[0])
    return expected_num_shards
65+
66+
5867
# Will be run via run_test_in_subprocess
5968
def _test_from_save_pretrained_dynamo(in_queue, out_queue, timeout):
6069
error = None
@@ -888,12 +897,7 @@ def test_sharded_checkpoints(self):
888897
# Now check if the right number of shards exists. First, let's get the number of shards.
889898
# Since this number can be dependent on the model being tested, it's important that we calculate it
890899
# instead of hardcoding it.
891-
with open(os.path.join(tmp_dir, SAFE_WEIGHTS_INDEX_NAME)) as f:
892-
weight_map_dict = json.load(f)["weight_map"]
893-
first_key = list(weight_map_dict.keys())[0]
894-
weight_loc = weight_map_dict[first_key] # e.g., diffusion_pytorch_model-00001-of-00002.safetensors
895-
expected_num_shards = int(weight_loc.split("-")[-1].split(".")[0])
896-
900+
expected_num_shards = caculate_expected_num_shards(os.path.join(tmp_dir, SAFE_WEIGHTS_INDEX_NAME))
897901
actual_num_shards = len([file for file in os.listdir(tmp_dir) if file.endswith(".safetensors")])
898902
self.assertTrue(actual_num_shards == expected_num_shards)
899903

@@ -924,12 +928,7 @@ def test_sharded_checkpoints_device_map(self):
924928
# Now check if the right number of shards exists. First, let's get the number of shards.
925929
# Since this number can be dependent on the model being tested, it's important that we calculate it
926930
# instead of hardcoding it.
927-
with open(os.path.join(tmp_dir, SAFE_WEIGHTS_INDEX_NAME)) as f:
928-
weight_map_dict = json.load(f)["weight_map"]
929-
first_key = list(weight_map_dict.keys())[0]
930-
weight_loc = weight_map_dict[first_key] # e.g., diffusion_pytorch_model-00001-of-00002.safetensors
931-
expected_num_shards = int(weight_loc.split("-")[-1].split(".")[0])
932-
931+
expected_num_shards = caculate_expected_num_shards(os.path.join(tmp_dir, SAFE_WEIGHTS_INDEX_NAME))
933932
actual_num_shards = len([file for file in os.listdir(tmp_dir) if file.endswith(".safetensors")])
934933
self.assertTrue(actual_num_shards == expected_num_shards)
935934

0 commit comments

Comments
 (0)