# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
25+ from __future__ import absolute_import
26+
27+ from sagemaker import (
28+ AutoMLTabularConfig ,
29+ AutoMLImageClassificationConfig ,
30+ AutoMLTextGenerationConfig ,
31+ AutoMLTextClassificationConfig ,
32+ AutoMLTimeSeriesForecastingConfig ,
33+ )
34+
# ---------------------------------------------------------------------------
# Values shared by every problem-type fixture below.
# ---------------------------------------------------------------------------
MAX_CANDIDATES = 10
MAX_RUNTIME_PER_TRAINING_JOB = 3600
TOTAL_JOB_RUNTIME = 36000
BUCKET_NAME = "mybucket"
FEATURE_SPECIFICATION_S3_URI = "s3://{bucket}/features.json".format(bucket=BUCKET_NAME)


def _completion_criteria():
    """Return a new ``CompletionCriteria`` dict built from the shared limits.

    A fresh dict is created on every call so that each fixture owns its own
    nested object, exactly as if the literal had been written out in place.
    """
    return {
        "MaxCandidates": MAX_CANDIDATES,
        "MaxRuntimePerTrainingJobInSeconds": MAX_RUNTIME_PER_TRAINING_JOB,
        "MaxAutoMLJobRuntimeInSeconds": TOTAL_JOB_RUNTIME,
    }


# ---------------------------------------------------------------------------
# Tabular fixture.
# ---------------------------------------------------------------------------
AUTO_ML_TABULAR_ALGORITHMS = "xgboost"
MODE = "ENSEMBLING"
GENERATE_CANDIDATE_DEFINITIONS_ONLY = True
PROBLEM_TYPE = "BinaryClassification"
TARGET_ATTRIBUTE_NAME = "target"
SAMPLE_WEIGHT_ATTRIBUTE_NAME = "sampleWeight"

TABULAR_PROBLEM_CONFIG = {
    "CompletionCriteria": _completion_criteria(),
    "CandidateGenerationConfig": {
        "AlgorithmsConfig": [
            {"AutoMLAlgorithms": AUTO_ML_TABULAR_ALGORITHMS},
        ],
    },
    "FeatureSpecificationS3Uri": FEATURE_SPECIFICATION_S3_URI,
    "Mode": MODE,
    "GenerateCandidateDefinitionsOnly": GENERATE_CANDIDATE_DEFINITIONS_ONLY,
    "ProblemType": PROBLEM_TYPE,
    "TargetAttributeName": TARGET_ATTRIBUTE_NAME,
    "SampleWeightAttributeName": SAMPLE_WEIGHT_ATTRIBUTE_NAME,
}

# ---------------------------------------------------------------------------
# Image classification fixture (completion criteria only).
# ---------------------------------------------------------------------------
IMAGE_CLASSIFICATION_PROBLEM_CONFIG = {
    "CompletionCriteria": _completion_criteria(),
}

# ---------------------------------------------------------------------------
# Text classification fixture.
# ---------------------------------------------------------------------------
# NOTE(review): the constant is spelled CONTEXT_COLUMN, yet it is the value
# stored under the "ContentColumn" key below - confirm the name is intended.
CONTEXT_COLUMN = "text"
TARGET_LABEL_COLUMN = "class"

TEXT_CLASSIFICATION_PROBLEM_CONFIG = {
    "CompletionCriteria": _completion_criteria(),
    "ContentColumn": CONTEXT_COLUMN,
    "TargetLabelColumn": TARGET_LABEL_COLUMN,
}

# ---------------------------------------------------------------------------
# Text generation fixture.
# ---------------------------------------------------------------------------
BASE_MODEL_NAME = "base_model"
TEXT_GENERATION_HYPER_PARAMS = {"test": 1}
ACCEPT_EULA = True

TEXT_GENERATION_PROBLEM_CONFIG = {
    "CompletionCriteria": _completion_criteria(),
    "BaseModelName": BASE_MODEL_NAME,
    "TextGenerationHyperParameters": TEXT_GENERATION_HYPER_PARAMS,
    "ModelAccessConfig": {"AcceptEula": ACCEPT_EULA},
}

# ---------------------------------------------------------------------------
# Time series forecasting fixture.
# ---------------------------------------------------------------------------
FORECAST_FREQUENCY = "1D"
FORECAST_HORIZON = 5
ITEM_IDENTIFIER_ATTRIBUTE_NAME = "identifier_attribute"
TIMESTAMP_ATTRIBUTE_NAME = "timestamp_attribute"
FORECAST_QUANTILES = ["p1"]
HOLIDAY_CONFIG = "DE"

TIME_SERIES_FORECASTING_PROBLEM_CONFIG = {
    "CompletionCriteria": _completion_criteria(),
    "FeatureSpecificationS3Uri": FEATURE_SPECIFICATION_S3_URI,
    "ForecastFrequency": FORECAST_FREQUENCY,
    "ForecastHorizon": FORECAST_HORIZON,
    "TimeSeriesConfig": {
        "ItemIdentifierAttributeName": ITEM_IDENTIFIER_ATTRIBUTE_NAME,
        "TargetAttributeName": TARGET_ATTRIBUTE_NAME,
        "TimestampAttributeName": TIMESTAMP_ATTRIBUTE_NAME,
    },
    "ForecastQuantiles": FORECAST_QUANTILES,
    # The fixture wraps the single country code in a one-element list.
    "HolidayConfig": [{"CountryCode": HOLIDAY_CONFIG}],
}
137+
def test_tabular_problem_config_from_response():
    """Check each attribute of a tabular config parsed from ``TABULAR_PROBLEM_CONFIG``."""
    parsed = AutoMLTabularConfig.from_response_dict(TABULAR_PROBLEM_CONFIG)

    assert parsed.algorithms_config == AUTO_ML_TABULAR_ALGORITHMS
    assert parsed.feature_specification_s3_uri == FEATURE_SPECIFICATION_S3_URI
    assert parsed.generate_candidate_definitions_only == GENERATE_CANDIDATE_DEFINITIONS_ONLY
    # The three completion-criteria limits.
    assert parsed.max_candidates == MAX_CANDIDATES
    assert parsed.max_runtime_per_training_job_in_seconds == MAX_RUNTIME_PER_TRAINING_JOB
    assert parsed.max_total_job_runtime_in_seconds == TOTAL_JOB_RUNTIME
    assert parsed.mode == MODE
    assert parsed.problem_type == PROBLEM_TYPE
    assert parsed.sample_weight_attribute_name == SAMPLE_WEIGHT_ATTRIBUTE_NAME
    assert parsed.target_attribute_name == TARGET_ATTRIBUTE_NAME
150+
def test_tabular_problem_config_to_request():
    """Check that a tabular config's request dict matches ``TABULAR_PROBLEM_CONFIG``."""
    init_kwargs = dict(
        target_attribute_name=TARGET_ATTRIBUTE_NAME,
        algorithms_config=AUTO_ML_TABULAR_ALGORITHMS,
        feature_specification_s3_uri=FEATURE_SPECIFICATION_S3_URI,
        generate_candidate_definitions_only=GENERATE_CANDIDATE_DEFINITIONS_ONLY,
        mode=MODE,
        problem_type=PROBLEM_TYPE,
        sample_weight_attribute_name=SAMPLE_WEIGHT_ATTRIBUTE_NAME,
        max_candidates=MAX_CANDIDATES,
        max_total_job_runtime_in_seconds=TOTAL_JOB_RUNTIME,
        max_runtime_per_training_job_in_seconds=MAX_RUNTIME_PER_TRAINING_JOB,
    )

    request = AutoMLTabularConfig(**init_kwargs).to_request_dict()

    # The payload is compared under its problem-type key.
    assert request["TabularJobConfig"] == TABULAR_PROBLEM_CONFIG
166+
def test_image_classification_problem_config_from_response():
    """Check the completion limits parsed from ``IMAGE_CLASSIFICATION_PROBLEM_CONFIG``."""
    parsed = AutoMLImageClassificationConfig.from_response_dict(
        IMAGE_CLASSIFICATION_PROBLEM_CONFIG
    )

    assert parsed.max_candidates == MAX_CANDIDATES
    assert parsed.max_runtime_per_training_job_in_seconds == MAX_RUNTIME_PER_TRAINING_JOB
    assert parsed.max_total_job_runtime_in_seconds == TOTAL_JOB_RUNTIME
172+
def test_image_classification_problem_config_to_request():
    """Check the image-classification request dict against its fixture."""
    init_kwargs = dict(
        max_candidates=MAX_CANDIDATES,
        max_total_job_runtime_in_seconds=TOTAL_JOB_RUNTIME,
        max_runtime_per_training_job_in_seconds=MAX_RUNTIME_PER_TRAINING_JOB,
    )

    request = AutoMLImageClassificationConfig(**init_kwargs).to_request_dict()

    # The payload is compared under its problem-type key.
    assert request["ImageClassificationJobConfig"] == IMAGE_CLASSIFICATION_PROBLEM_CONFIG
181+
def test_text_classification_problem_config_from_response():
    """Check each attribute parsed from ``TEXT_CLASSIFICATION_PROBLEM_CONFIG``."""
    parsed = AutoMLTextClassificationConfig.from_response_dict(
        TEXT_CLASSIFICATION_PROBLEM_CONFIG
    )

    # Column selection.
    assert parsed.content_column == CONTEXT_COLUMN
    assert parsed.target_label_column == TARGET_LABEL_COLUMN
    # Completion limits.
    assert parsed.max_candidates == MAX_CANDIDATES
    assert parsed.max_runtime_per_training_job_in_seconds == MAX_RUNTIME_PER_TRAINING_JOB
    assert parsed.max_total_job_runtime_in_seconds == TOTAL_JOB_RUNTIME
189+
def test_text_classification_to_request():
    """Check the text-classification request dict against its fixture."""
    init_kwargs = dict(
        content_column=CONTEXT_COLUMN,
        target_label_column=TARGET_LABEL_COLUMN,
        max_candidates=MAX_CANDIDATES,
        max_total_job_runtime_in_seconds=TOTAL_JOB_RUNTIME,
        max_runtime_per_training_job_in_seconds=MAX_RUNTIME_PER_TRAINING_JOB,
    )

    request = AutoMLTextClassificationConfig(**init_kwargs).to_request_dict()

    # The payload is compared under its problem-type key.
    assert request["TextClassificationJobConfig"] == TEXT_CLASSIFICATION_PROBLEM_CONFIG
200+
def test_text_generation_problem_config_from_response():
    """Check each attribute parsed from ``TEXT_GENERATION_PROBLEM_CONFIG``."""
    parsed = AutoMLTextGenerationConfig.from_response_dict(TEXT_GENERATION_PROBLEM_CONFIG)

    # Model selection and EULA flag.
    assert parsed.accept_eula == ACCEPT_EULA
    assert parsed.base_model_name == BASE_MODEL_NAME
    # Completion limits.
    assert parsed.max_candidates == MAX_CANDIDATES
    assert parsed.max_runtime_per_training_job_in_seconds == MAX_RUNTIME_PER_TRAINING_JOB
    assert parsed.max_total_job_runtime_in_seconds == TOTAL_JOB_RUNTIME
    assert parsed.text_generation_hyper_params == TEXT_GENERATION_HYPER_PARAMS
209+
def test_text_generation_problem_config_to_request():
    """Check the text-generation request dict against its fixture."""
    init_kwargs = dict(
        accept_eula=ACCEPT_EULA,
        base_model_name=BASE_MODEL_NAME,
        text_generation_hyper_params=TEXT_GENERATION_HYPER_PARAMS,
        max_candidates=MAX_CANDIDATES,
        max_total_job_runtime_in_seconds=TOTAL_JOB_RUNTIME,
        max_runtime_per_training_job_in_seconds=MAX_RUNTIME_PER_TRAINING_JOB,
    )

    request = AutoMLTextGenerationConfig(**init_kwargs).to_request_dict()

    # The payload is compared under its problem-type key.
    assert request["TextGenerationJobConfig"] == TEXT_GENERATION_PROBLEM_CONFIG
221+
def test_time_series_forecasting_problem_config_from_response():
    """Check each attribute parsed from ``TIME_SERIES_FORECASTING_PROBLEM_CONFIG``."""
    parsed = AutoMLTimeSeriesForecastingConfig.from_response_dict(
        TIME_SERIES_FORECASTING_PROBLEM_CONFIG
    )

    assert parsed.forecast_frequency == FORECAST_FREQUENCY
    assert parsed.forecast_horizon == FORECAST_HORIZON
    # Values the fixture nests under "TimeSeriesConfig".
    assert parsed.item_identifier_attribute_name == ITEM_IDENTIFIER_ATTRIBUTE_NAME
    assert parsed.target_attribute_name == TARGET_ATTRIBUTE_NAME
    assert parsed.timestamp_attribute_name == TIMESTAMP_ATTRIBUTE_NAME
    # Completion limits.
    assert parsed.max_candidates == MAX_CANDIDATES
    assert parsed.max_runtime_per_training_job_in_seconds == MAX_RUNTIME_PER_TRAINING_JOB
    assert parsed.max_total_job_runtime_in_seconds == TOTAL_JOB_RUNTIME
    assert parsed.forecast_quantiles == FORECAST_QUANTILES
    # Compared with the bare country code, not the list-of-dicts the fixture holds.
    assert parsed.holiday_config == HOLIDAY_CONFIG
    assert parsed.feature_specification_s3_uri == FEATURE_SPECIFICATION_S3_URI
235+
def test_time_series_forecasting_problem_config_to_request():
    """Check the time-series-forecasting request dict against its fixture."""
    init_kwargs = dict(
        forecast_frequency=FORECAST_FREQUENCY,
        forecast_horizon=FORECAST_HORIZON,
        item_identifier_attribute_name=ITEM_IDENTIFIER_ATTRIBUTE_NAME,
        target_attribute_name=TARGET_ATTRIBUTE_NAME,
        timestamp_attribute_name=TIMESTAMP_ATTRIBUTE_NAME,
        forecast_quantiles=FORECAST_QUANTILES,
        holiday_config=HOLIDAY_CONFIG,
        feature_specification_s3_uri=FEATURE_SPECIFICATION_S3_URI,
        max_candidates=MAX_CANDIDATES,
        max_total_job_runtime_in_seconds=TOTAL_JOB_RUNTIME,
        max_runtime_per_training_job_in_seconds=MAX_RUNTIME_PER_TRAINING_JOB,
    )

    request = AutoMLTimeSeriesForecastingConfig(**init_kwargs).to_request_dict()

    # The payload is compared under its problem-type key.
    assert request["TimeSeriesForecastingJobConfig"] == TIME_SERIES_FORECASTING_PROBLEM_CONFIG
0 commit comments