|
145 | 145 | ], |
146 | 146 | } |
147 | 147 |
|
148 | | -PYTORCHDDP_SUPPORTED_FRAMEWORK_VERSIONS = [ |
149 | | - "1.10", |
150 | | - "1.10.0", |
151 | | - "1.10.2", |
152 | | - "1.11", |
153 | | - "1.11.0", |
154 | | - "1.12", |
155 | | - "1.12.0", |
156 | | - "1.12.1", |
157 | | - "1.13.1", |
158 | | - "2.0.0", |
159 | | - "2.0.1", |
160 | | - "2.1.0", |
161 | | - "2.2.0", |
162 | | -] |
163 | | - |
164 | 148 | TORCH_DISTRIBUTED_GPU_SUPPORTED_FRAMEWORK_VERSIONS = [ |
165 | 149 | "1.13.1", |
166 | 150 | "2.0.0", |
@@ -915,13 +899,6 @@ def validate_distribution( |
915 | 899 | ) |
916 | 900 | if framework_name and framework_name == "pytorch": |
917 | 901 | # We need to validate only for PyTorch framework |
918 | | - validate_pytorch_distribution( |
919 | | - distribution=validated_distribution, |
920 | | - framework_name=framework_name, |
921 | | - framework_version=framework_version, |
922 | | - py_version=py_version, |
923 | | - image_uri=image_uri, |
924 | | - ) |
925 | 902 | validate_torch_distributed_distribution( |
926 | 903 | instance_type=instance_type, |
927 | 904 | distribution=validated_distribution, |
@@ -955,13 +932,6 @@ def validate_distribution( |
955 | 932 | ) |
956 | 933 | if framework_name and framework_name == "pytorch": |
957 | 934 | # We need to validate only for PyTorch framework |
958 | | - validate_pytorch_distribution( |
959 | | - distribution=validated_distribution, |
960 | | - framework_name=framework_name, |
961 | | - framework_version=framework_version, |
962 | | - py_version=py_version, |
963 | | - image_uri=image_uri, |
964 | | - ) |
965 | 935 | validate_torch_distributed_distribution( |
966 | 936 | instance_type=instance_type, |
967 | 937 | distribution=validated_distribution, |
@@ -1010,63 +980,6 @@ def validate_distribution_for_instance_type(instance_type, distribution): |
1010 | 980 | raise ValueError(err_msg) |
1011 | 981 |
|
1012 | 982 |
|
1013 | | -def validate_pytorch_distribution( |
1014 | | - distribution, framework_name, framework_version, py_version, image_uri |
1015 | | -): |
1016 | | - """Check if pytorch distribution strategy is correctly invoked by the user. |
1017 | | -
|
1018 | | - Args: |
1019 | | - distribution (dict): A dictionary with information to enable distributed training. |
1020 | | - (Defaults to None if distributed training is not enabled.) For example: |
1021 | | -
|
1022 | | - .. code:: python |
1023 | | -
|
1024 | | - { |
1025 | | - "pytorchddp": { |
1026 | | - "enabled": True |
1027 | | - } |
1028 | | - } |
1029 | | - framework_name (str): A string representing the name of framework selected. |
1030 | | - framework_version (str): A string representing the framework version selected. |
1031 | | - py_version (str): A string representing the python version selected. |
1032 | | - image_uri (str): A string representing a Docker image URI. |
1033 | | -
|
1034 | | - Raises: |
1035 | | - ValueError: if |
1036 | | - `py_version` is not python3 or |
1037 | | - `framework_version` is not in PYTORCHDDP_SUPPORTED_FRAMEWORK_VERSIONS |
1038 | | - """ |
1039 | | - if framework_name and framework_name != "pytorch": |
1040 | | - # We need to validate only for PyTorch framework |
1041 | | - return |
1042 | | - |
1043 | | - pytorch_ddp_enabled = False |
1044 | | - if "pytorchddp" in distribution: |
1045 | | - pytorch_ddp_enabled = distribution.get("pytorchddp").get("enabled", False) |
1046 | | - if not pytorch_ddp_enabled: |
1047 | | - # Distribution strategy other than pytorchddp is selected |
1048 | | - return |
1049 | | - |
1050 | | - err_msg = "" |
1051 | | - if not image_uri: |
1052 | | - # ignore framework_version and py_version if image_uri is set |
1053 | | - # in case image_uri is not set, then both are mandatory |
1054 | | - if framework_version not in PYTORCHDDP_SUPPORTED_FRAMEWORK_VERSIONS: |
1055 | | - err_msg += ( |
1056 | | - f"Provided framework_version {framework_version} is not supported by" |
1057 | | - " pytorchddp.\n" |
1058 | | - "Please specify one of the supported framework versions:" |
1059 | | - f" {PYTORCHDDP_SUPPORTED_FRAMEWORK_VERSIONS} \n" |
1060 | | - ) |
1061 | | - if "py3" not in py_version: |
1062 | | - err_msg += ( |
1063 | | - f"Provided py_version {py_version} is not supported by pytorchddp.\n" |
1064 | | - "Please specify py_version>=py3" |
1065 | | - ) |
1066 | | - if err_msg: |
1067 | | - raise ValueError(err_msg) |
1068 | | - |
1069 | | - |
1070 | 983 | def validate_torch_distributed_distribution( |
1071 | 984 | instance_type, |
1072 | 985 | distribution, |
|
0 commit comments