File tree Expand file tree Collapse file tree 2 files changed +3
-27
lines changed Expand file tree Collapse file tree 2 files changed +3
-27
lines changed Original file line number Diff line number Diff line change 2222import tempfile
2323from collections import namedtuple
2424from typing import Optional , Union , Dict
25- import yaml
2625
2726import sagemaker .image_uris
2827from sagemaker .session_settings import SessionSettings
@@ -202,7 +201,7 @@ def parse_mp_parameters(params):
202201
203202 Raises:
204203 ValueError: if params is not a string or a dict, or
205- the config file cannot be parsed as json or yaml .
204+ the config file cannot be parsed as json.
206205 """
207206 parsed = None
208207 if isinstance (params , dict ):
@@ -212,19 +211,15 @@ def parse_mp_parameters(params):
212211 with open (params , "r" ) as fp :
213212 parsed = json .load (fp )
214213 except json .decoder .JSONDecodeError :
215- try :
216- with open (params , "r" ) as fp :
217- parsed = yaml .load (fp )
218- except yaml .YAMLError :
219- pass
214+ pass
220215 else :
221216 raise ValueError (
222217 f"Expected a string path to an existing modelparallel config, or a dictionary. "
223218 f"Received: { params } ."
224219 )
225220
226221 if parsed is None :
227- raise ValueError (f"Cannot parse { params } as a json or yaml file." )
222+ raise ValueError (f"Cannot parse { params } as a json file." )
228223
229224 return parsed
230225
Original file line number Diff line number Diff line change 1818import tarfile
1919from contextlib import contextmanager
2020from itertools import product
21- import yaml
2221
2322import pytest
2423
@@ -235,24 +234,6 @@ def test_parse_mp_parameters_input_str_json():
235234 os .remove (json_file_path )
236235
237236
238- def test_parse_mp_parameters_input_str_yaml ():
239- mp_parameters = {
240- "partitions" : 1 ,
241- "tensor_parallel_degree" : 2 ,
242- "microbatches" : 1 ,
243- "optimize" : "speed" ,
244- "pipeline" : "interleaved" ,
245- "ddp" : 1 ,
246- "auto_partition" : False ,
247- "default_partition" : 0 ,
248- }
249- yaml_file_path = "./params.yaml"
250- with open (yaml_file_path , "x" ) as fp :
251- yaml .dump (mp_parameters , fp )
252- assert mp_parameters == fw_utils .parse_mp_parameters (yaml_file_path )
253- os .remove (yaml_file_path )
254-
255-
256237def test_parse_mp_parameters_input_not_exit ():
257238 with pytest .raises (ValueError ):
258239 fw_utils .parse_mp_parameters (" !@#$%^&*()path probably in not there.!@#$%^&*()" )
You can’t perform that action at this time.
0 commit comments