File tree Expand file tree Collapse file tree 2 files changed +3
-27
lines changed Expand file tree Collapse file tree 2 files changed +3
-27
lines changed Original file line number Diff line number Diff line change 2222import tempfile
2323from collections import namedtuple
2424from typing import Optional , Union , Dict
25- import yaml
2625
2726import sagemaker .image_uris
2827from sagemaker .session_settings import SessionSettings
@@ -248,7 +247,7 @@ def parse_mp_parameters(params):
248247
249248 Raises:
250249 ValueError: if params is not a string or a dict, or
251- the config file cannot be parsed as json or yaml .
250+ the config file cannot be parsed as json.
252251 """
253252 parsed = None
254253 if isinstance (params , dict ):
@@ -258,19 +257,15 @@ def parse_mp_parameters(params):
258257 with open (params , "r" ) as fp :
259258 parsed = json .load (fp )
260259 except json .decoder .JSONDecodeError :
261- try :
262- with open (params , "r" ) as fp :
263- parsed = yaml .load (fp )
264- except yaml .YAMLError :
265- pass
260+ pass
266261 else :
267262 raise ValueError (
268263 f"Expected a string path to an existing modelparallel config, or a dictionary. "
269264 f"Received: { params } ."
270265 )
271266
272267 if parsed is None :
273- raise ValueError (f"Cannot parse { params } as a json or yaml file." )
268+ raise ValueError (f"Cannot parse { params } as a json file." )
274269
275270 return parsed
276271
Original file line number Diff line number Diff line change 1818import tarfile
1919from contextlib import contextmanager
2020from itertools import product
21- import yaml
2221
2322import pytest
2423
@@ -226,24 +225,6 @@ def test_parse_mp_parameters_input_str_json():
226225 os .remove (json_file_path )
227226
228227
229- def test_parse_mp_parameters_input_str_yaml ():
230- mp_parameters = {
231- "partitions" : 1 ,
232- "tensor_parallel_degree" : 2 ,
233- "microbatches" : 1 ,
234- "optimize" : "speed" ,
235- "pipeline" : "interleaved" ,
236- "ddp" : 1 ,
237- "auto_partition" : False ,
238- "default_partition" : 0 ,
239- }
240- yaml_file_path = "./params.yaml"
241- with open (yaml_file_path , "x" ) as fp :
242- yaml .dump (mp_parameters , fp )
243- assert mp_parameters == fw_utils .parse_mp_parameters (yaml_file_path )
244- os .remove (yaml_file_path )
245-
246-
247228def test_parse_mp_parameters_input_not_exit ():
248229 with pytest .raises (ValueError ):
249230 fw_utils .parse_mp_parameters (" !@#$%^&*()path probably in not there.!@#$%^&*()" )
You can’t perform that action at this time.
0 commit comments