diff --git a/use-cases/athena_ml_workflow_end_to_end/athena_ml_workflow_end_to_end.ipynb b/use-cases/athena_ml_workflow_end_to_end/athena_ml_workflow_end_to_end.ipynb new file mode 100644 index 0000000000..e723f14c7e --- /dev/null +++ b/use-cases/athena_ml_workflow_end_to_end/athena_ml_workflow_end_to_end.ipynb @@ -0,0 +1,1456 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "9fbac6ee", + "metadata": {}, + "source": [ + "# Create an end-to-end machine learning workflow using Amazon Athena\n", + "---\n", + "\n", + "This notebook's CI test result for us-west-2 is as follows. CI test results in other regions can be found at the end of the notebook.\n", + "\n", + "![This us-west-2 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/us-west-2/use-cases|athena_ml_workflow_end_to_end|athena_ml_workflow_end_to_end.ipynb)\n", + "\n", + "---" + ] + }, + { + "cell_type": "markdown", + "id": "ece13bd7-19b2-47b3-976d-cf636fa68003", + "metadata": {}, + "source": [ + "Importing and transforming data can be one of the most challenging tasks in a machine learning workflow. We provide you with a Jupyter notebook that demonstrates a cost-effective strategy for an extract, transform, and load (ETL) workflow. Using Amazon Simple Storage Service (Amazon S3) and Amazon Athena, you learn how to query and transform data from a Jupyter notebook. Amazon S3 is an object storage service that allows you to store data and machine learning artifacts. Amazon Athena enables you to interactively query the data stored in those buckets, saving the results of each query as a CSV file in an Amazon S3 location.\n", + "\n", + "The tutorial imports 16 CSV files for the 2019 NYC taxi dataset from multiple Amazon S3 locations. The goal is to predict the fare amount for each ride. From these 16 files, the notebook creates a single ride fare dataset and a single ride info dataset with deduplicated values. We join the deduplicated datasets into a single dataset.\n", + "\n", + "Amazon Athena stores the query results as a CSV file in the specified location. We provide the output to a SageMaker processing job to split the data into training, validation, and test sets. While data can be split using queries, a processing job ensures that the data is in a format that's parseable by the XGBoost algorithm.\n", + "\n", + "__Prerequisites:__\n", + "\n", + "The notebook must be run in the us-east-1 AWS Region. You also need your own Amazon S3 bucket and a database within Amazon Athena. Without them, you can't run the queries in this tutorial or store their results.\n", + "\n", + "For information about creating a bucket, see [Creating a bucket](https://docs.aws.amazon.com/AmazonS3/latest/userguide/create-bucket-overview.html). For information about creating a database, see [Create a database](https://docs.aws.amazon.com/athena/latest/ug/getting-started.html#step-1-create-a-database).\n", + "\n", + "Amazon Athena uses the AWS Glue Data Catalog to read the data from Amazon S3 into a database. You must have permissions to use AWS Glue. To clean up, you also need permissions to delete the bucket you've created. For a quick guide to providing permissions, see [Setting up](https://docs.aws.amazon.com/sagemaker/latest/dg/create-end-to-end-ml-workflow-athena.html#setting-up).",
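+ "\n",
+ "If you'd rather create these resources programmatically, the following is a minimal sketch that uses boto3. The bucket and database names are placeholders; replace them with your own.\n",
+ "\n",
+ "```python\n",
+ "import boto3\n",
+ "\n",
+ "# Create the bucket that stores your data and Athena query results (no LocationConstraint is needed in us-east-1).\n",
+ "boto3.client(\"s3\", region_name=\"us-east-1\").create_bucket(Bucket=\"example-s3-bucket\")\n",
+ "\n",
+ "# Create the database in the AWS Glue Data Catalog that Athena reads from.\n",
+ "boto3.client(\"glue\", region_name=\"us-east-1\").create_database(DatabaseInput={\"Name\": \"example_database\"})\n",
+ "```"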
+ ] + }, + { + "cell_type": "markdown", + "id": "0b11693f-7c35-41cf-8e4b-4f86eea8f3b0", + "metadata": {}, + "source": [ + "## Solution overview\n", + "\n", + "To create the end-to-end workflow, we do the following:\n", + "\n", + "1. Create an Amazon Athena client within the us-east-1 AWS Region.\n", + "2. Define the `run_athena_query` function, which runs queries and prints their status.\n", + "3. Create the `ride_fare` table within your database using all ride fare tables for the year 2019.\n", + "4. Create the `ride_info` table using all ride info tables for the year 2019.\n", + "5. Create the `ride_info_deduped` and `ride_fare_deduped` tables that have all duplicate values removed from the original tables.\n", + "6. Run test queries to get the first ten rows of each table to see whether they have data.\n", + "7. Define the `get_query_results` function that takes the query ID and returns comma separated values that can be stored as a dataframe.\n", + "8. View the results of the test queries within pandas dataframes.\n", + "9. Join the `ride_info_deduped` and `ride_fare_deduped` tables into the `combined_ride_data_deduped` table.\n", + "10. Select all values in the combined table.\n", + "11. Define the `get_csv_file_location` function to get the Amazon S3 location of the query results.\n", + "12. Download the CSV file to our environment.\n", + "13. Perform Exploratory Data Analysis (EDA) on the data.\n", + "14. Use the results of the EDA to select the relevant features in a query.\n", + "15. Use the `get_csv_file_location` function to get the location of those query results.\n", + "16. Split the data into training, validation, and test sets using a processing job.\n", + "17. Train an XGBoost model on the training and validation sets and deploy it to an endpoint.\n", + "18. Download the test dataset.\n", + "19. Take a 20 row sample from the test dataset and get predictions for it from the endpoint.\n", + "20. Create a dataframe with 20 rows of actual and predicted values.\n", + "21. Calculate the RMSE of the predictions.\n", + "22. Clean up the resources created within the notebook." + ] + }, + { + "cell_type": "markdown", + "id": "54d7468c-c77b-4273-b02d-9e9c4e884d46", + "metadata": {}, + "source": [ + "### Define the run_athena_query function\n", + "\n", + "In the following cell, we define the `run_athena_query` function. It runs an Athena query and waits for its completion.\n", + "\n", + "It takes the following arguments:\n", + "\n", + "- query_string (str): The SQL query to be executed.\n", + "- database_name (str): The name of the Athena database.\n", + "- output_location (str): The S3 location where the query results are stored.\n", + "\n", + "\n", + "It returns the query execution ID string.",
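+ "\n",
+ "For example, after the function is defined, a call looks like the following. The database name and output location shown here are the same placeholders used throughout the notebook; substitute your own values.\n",
+ "\n",
+ "```python\n",
+ "query_id = run_athena_query(\n",
+ "    query_string=\"SELECT 1\",\n",
+ "    database_name=\"example-database-name\",\n",
+ "    output_location=\"s3://example-s3-bucket/example-s3-prefix\",\n",
+ ")\n",
+ "print(query_id)  # the 36 character query execution ID\n",
+ "```"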
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "8ab1ff0e-fcde-4976-a1cd-51e75c18deb2", + "metadata": {}, + "outputs": [], + "source": [ + "# Import required libraries\n", + "import time\n", + "import boto3\n", + "\n", + "\n", + "def run_athena_query(query_string, database_name, output_location):\n", + " # Create an Athena client\n", + " athena_client = boto3.client(\"athena\", region_name=\"us-east-1\")\n", + "\n", + " # Start the query execution\n", + " response = athena_client.start_query_execution(\n", + " QueryString=query_string,\n", + " QueryExecutionContext={\"Database\": database_name},\n", + " ResultConfiguration={\"OutputLocation\": output_location},\n", + " )\n", + "\n", + " query_execution_id = response[\"QueryExecutionId\"]\n", + " print(f\"Query execution ID: {query_execution_id}\")\n", + "\n", + " while True:\n", + " # Check the query execution status\n", + " query_status = athena_client.get_query_execution(QueryExecutionId=query_execution_id)\n", + " state = query_status[\"QueryExecution\"][\"Status\"][\"State\"]\n", + "\n", + " if state == \"SUCCEEDED\":\n", + " print(\"Query executed successfully.\")\n", + " break\n", + " elif state == \"FAILED\":\n", + " print(\n", + " f\"Query failed with error: {query_status['QueryExecution']['Status']['StateChangeReason']}\"\n", + " )\n", + " break\n", + " else:\n", + " print(f\"Query is currently in {state} state. Waiting for completion...\")\n", + " time.sleep(5) # Wait for 5 seconds before checking again\n", + "\n", + " return query_execution_id" + ] + }, + { + "cell_type": "markdown", + "id": "8df0da48-89b3-45c2-a479-af422a51b962", + "metadata": {}, + "source": [ + "### Create the ride_fare table\n", + "\n", + "We've provided you with the query. You must provide the name of the database you created within Amazon Athena and the Amazon S3 output location. If you're not sure how to specify the output location, provide an S3 URI in the bucket that you created; the next cell shows an example. After running the query, you should get a message that says \"Query executed successfully.\" and a 36 character string in single quotes."
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "64131b68-de28-4060-bb75-8148902846f7", + "metadata": {}, + "outputs": [], + "source": [ + "# SQL query to create the 'ride_fare' table\n", + "create_ride_fare_table = \"\"\"\n", + "CREATE EXTERNAL TABLE `ride_fare` (\n", + " `ride_id` bigint, \n", + " `payment_type` smallint, \n", + " `fare_amount` float, \n", + " `extra` float, \n", + " `mta_tax` float, \n", + " `tip_amount` float, \n", + " `tolls_amount` float, \n", + " `total_amount` float\n", + ")\n", + "ROW FORMAT DELIMITED \n", + " FIELDS TERMINATED BY ',' \n", + " LINES TERMINATED BY '\\n' \n", + "STORED AS INPUTFORMAT \n", + " 'org.apache.hadoop.mapred.TextInputFormat' \n", + "OUTPUTFORMAT \n", + " 'org.apache.hadoop.hive.ql.io.HiveIgnoreKeyTextOutputFormat'\n", + "LOCATION\n", + " 's3://dsoaws/nyc-taxi-orig-cleaned-split-csv-with-header-per-year-multiple-files/ride-fare/year=2019'\n", + "TBLPROPERTIES (\n", + " 'skip.header.line.count'='1', \n", + " 'transient_lastDdlTime'='1716908234'\n", + ");\n", + "\"\"\"\n", + "\n", + "# Athena database name\n", + "database = \"example-database-name\"\n", + "\n", + "# S3 location for query results\n", + "s3_output_location = \"s3://example-s3-bucket/example-s3-prefix\"\n", + "\n", + "# Execute the query to create the 'ride_fare' table\n", + "run_athena_query(create_ride_fare_table, database, s3_output_location)" + ] + }, + { + "cell_type": "markdown", + "id": "ebe5920a-4c36-48c0-9cb4-e418c738aa59", + "metadata": {}, + "source": [ + "### Create the ride fare table with the duplicates removed" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "3d249cc5-2d53-4274-8f5e-6ab09ccd3ea6", + "metadata": {}, + "outputs": [], + "source": [ + "# SQL query to create a new table with duplicates removed\n", + "remove_duplicates_from_ride_fare = \"\"\"\n", + "CREATE TABLE ride_fare_deduped\n", + "AS\n", + "SELECT DISTINCT *\n", + "FROM ride_fare\n", + "\"\"\"\n", + "\n", + "# Run the preceding query\n", + "run_athena_query(remove_duplicates_from_ride_fare, database, s3_output_location)" + ] + }, + { + "cell_type": "markdown", + "id": "2ac7fc34-37cb-4c46-993b-38f18576361c", + "metadata": {}, + "source": [ + "### Create the ride_info table" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "2f9a68b9-bd11-49e9-ad72-b44b43d32e47", + "metadata": {}, + "outputs": [], + "source": [ + "# SQL query to create the ride_info table\n", + "create_ride_info_table_query = \"\"\"\n", + "CREATE EXTERNAL TABLE `ride_info` (\n", + " `ride_id` bigint, \n", + " `vendor_id` smallint, \n", + " `passenger_count` smallint, \n", + " `pickup_at` string, \n", + " `dropoff_at` string, \n", + " `trip_distance` float, \n", + " `rate_code_id` int, \n", + " `store_and_fwd_flag` string\n", + ")\n", + "ROW FORMAT DELIMITED \n", + " FIELDS TERMINATED BY ',' \n", + " LINES TERMINATED BY '\\n' \n", + "STORED AS INPUTFORMAT \n", + " 'org.apache.hadoop.mapred.TextInputFormat' \n", + "OUTPUTFORMAT \n", + " 'org.apache.hadoop.hive.ql.io.HiveIgnoreKeyTextOutputFormat'\n", + "LOCATION\n", + " 's3://dsoaws/nyc-taxi-orig-cleaned-split-csv-with-header-per-year-multiple-files/ride-info/year=2019'\n", + "TBLPROPERTIES (\n", + " 'skip.header.line.count'='1', \n", + " 'transient_lastDdlTime'='1716907328'\n", + ");\n", + "\"\"\"\n", + "\n", + "# Run the query to create the ride_info table\n", + "run_athena_query(create_ride_info_table_query, database, s3_output_location)" + ] + }, + { + "cell_type": "markdown", + "id": 
"4c17ea01-2c1e-4c10-a539-0d00e6e4bb1d", + "metadata": {}, + "source": [ + "### Create the ride info table with the duplicates removed" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "263d883c-f189-43c0-9fbd-1a45093984e9", + "metadata": {}, + "outputs": [], + "source": [ + "# SQL query to create table with duplicates removed\n", + "remove_duplicates_from_ride_info = \"\"\"\n", + "CREATE TABLE ride_info_deduped\n", + "AS\n", + "SELECT DISTINCT *\n", + "FROM ride_info\n", + "\"\"\"\n", + "\n", + "# Run the query to create the table with the duplicates removed\n", + "run_athena_query(remove_duplicates_from_ride_info, database, s3_output_location)" + ] + }, + { + "cell_type": "markdown", + "id": "a19f8e17-42c5-4412-96a8-b7bc1a74c73c", + "metadata": {}, + "source": [ + "### Run a test query on ride_info_deduped" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "6db6bb67-44a9-4ff4-b662-ad969a84d3d8", + "metadata": {}, + "outputs": [], + "source": [ + "test_ride_info_query = \"\"\"\n", + "SELECT * FROM ride_info_deduped limit 10\n", + "\"\"\"\n", + "\n", + "run_athena_query(test_ride_info_query, database, s3_output_location)" + ] + }, + { + "cell_type": "markdown", + "id": "b969d31f-e14a-473b-aefa-a1a19bc312f7", + "metadata": {}, + "source": [ + "### Run a test query on ride_fare_deduped" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "92d8be21-3f20-453d-8b84-516571d9854d", + "metadata": {}, + "outputs": [], + "source": [ + "test_ride_fare_query = \"\"\"\n", + "SELECT * FROM ride_fare_deduped limit 10\n", + "\"\"\"\n", + "\n", + "run_athena_query(test_ride_fare_query, database, s3_output_location)" + ] + }, + { + "cell_type": "markdown", + "id": "c86acade-c4b9-4918-860e-11ee5e386a44", + "metadata": {}, + "source": [ + "### Define the `get_query_results` function\n", + "\n", + "In the following cell, we define the `get_query_results` function to get the query results in CSV format. The function gets the 36 character query execution ID string. The end of the output of the preceding cell is an example of a query execution ID string." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "50e87ba6-42e9-4d99-862e-7eae16ad810e", + "metadata": {}, + "outputs": [], + "source": [ + "import io\n", + "\n", + "\n", + "def get_query_results(query_execution_id):\n", + " athena_client = boto3.client(\"athena\", region_name=\"us-east-1\")\n", + " s3 = boto3.client(\"s3\")\n", + "\n", + " # Get the query execution details\n", + " query_execution = athena_client.get_query_execution(QueryExecutionId=query_execution_id)\n", + " s3_location = query_execution[\"QueryExecution\"][\"ResultConfiguration\"][\"OutputLocation\"]\n", + "\n", + " # Extract bucket and key from S3 output location\n", + " bucket_name, key = s3_location.split(\"/\", 2)[2].split(\"/\", 1)\n", + "\n", + " # Get the CSV file location\n", + " obj = s3.get_object(Bucket=bucket_name, Key=key)\n", + " csv_data = obj[\"Body\"].read().decode(\"utf-8\")\n", + " csv_buffer = io.StringIO(csv_data)\n", + "\n", + " return csv_buffer" + ] + }, + { + "cell_type": "markdown", + "id": "d3d2ed4f-d7e6-49dc-9ea1-0dc66f252c76", + "metadata": {}, + "source": [ + "### Read `ride_info_deduped` test query into a dataframe\n", + "\n", + "Specify the query execution ID string in the `get_query_results` function. The output is the head of the dataframe. 
" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "b04abae5-936b-4d96-98e8-d2e2b6a17b9c", + "metadata": {}, + "outputs": [], + "source": [ + "import pandas as pd\n", + "\n", + "# Provide the query execution id of the test_ride_info query to get the query results\n", + "ride_info_sample = get_query_results(\"test_ride_info_query_execution_id\")\n", + "\n", + "df_ride_info_sample = pd.read_csv(ride_info_sample)\n", + "\n", + "df_ride_info_sample.head()" + ] + }, + { + "cell_type": "markdown", + "id": "6d10ebe2-8c17-4f2b-97fe-a5f339cd89d7", + "metadata": {}, + "source": [ + "### Read `ride_fare_deduped` test query into a dataframe\n", + "\n", + "Specify the query execution ID string in the `get_query_results` function. The output is the head of the resulting dataframe. " + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "be89957f-31b1-4710-bfc2-178d6db18592", + "metadata": {}, + "outputs": [], + "source": [ + "# Provide the query execution id of the test_ride_fare query to get the query results\n", + "\n", + "ride_fare_sample = get_query_results(\"test_ride_fare_query_execution_id\")\n", + "\n", + "df_ride_fare_sample = pd.read_csv(ride_fare_sample)\n", + "\n", + "df_ride_fare_sample.head()" + ] + }, + { + "cell_type": "markdown", + "id": "3867e94a-7c89-48ed-86aa-92b09d47740d", + "metadata": {}, + "source": [ + "### Join the deduplicated tables together" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "b8a76635-3c09-4cbc-b1b4-9318dc611250", + "metadata": { + "scrolled": true + }, + "outputs": [], + "source": [ + "# SQL query to join the tables into a single table containing all the data.\n", + "create_ride_joined_deduped = \"\"\"\n", + "CREATE TABLE combined_ride_data_deduped AS\n", + "SELECT \n", + " rfs.ride_id, \n", + " rfs.payment_type, \n", + " rfs.fare_amount, \n", + " rfs.extra, \n", + " rfs.mta_tax, \n", + " rfs.tip_amount, \n", + " rfs.tolls_amount, \n", + " rfs.total_amount,\n", + " ris.vendor_id, \n", + " ris.passenger_count, \n", + " ris.pickup_at, \n", + " ris.dropoff_at, \n", + " ris.trip_distance, \n", + " ris.rate_code_id, \n", + " ris.store_and_fwd_flag\n", + "FROM \n", + " ride_fare_deduped rfs\n", + "JOIN \n", + " ride_info_deduped ris\n", + "ON \n", + " rfs.ride_id = ris.ride_id;\n", + ";\n", + "\"\"\"\n", + "\n", + "# Run the query to create the ride_data_deduped table\n", + "run_athena_query(create_ride_joined_deduped, database, s3_output_location)" + ] + }, + { + "cell_type": "markdown", + "id": "b2f9f6ca-f668-42ab-ac4a-371a82e1786d", + "metadata": {}, + "source": [ + "### Select all values from the deduplicated table" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "b0791e57-4351-4f27-a8f9-ad741441d214", + "metadata": {}, + "outputs": [], + "source": [ + "# SQL query to select all values from the table and create the dataset that we're using for our analysis\n", + "ride_combined_full_table_query = \"\"\"\n", + "SELECT * FROM combined_ride_data_deduped\n", + "\"\"\"\n", + "\n", + "# Run the query to select all values from the combined_ride_data_deduped table\n", + "run_athena_query(ride_combined_full_table_query, database, s3_output_location)" + ] + }, + { + "cell_type": "markdown", + "id": "4492eaa8-b0cc-4a4d-9810-e9f1a39f21c7", + "metadata": {}, + "source": [ + "### Define get_csv_file_location function and get Amazon S3 location of query results\n", + "\n", + "Specify the query ID from the preceding cell in the function call. The output is the Amazon S3 URI of the dataset. 
" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "97373c52-882b-4e44-8d75-a80d8d8c58df", + "metadata": {}, + "outputs": [], + "source": [ + "# Function to get the Amazon S3 URI location of Amazon Athena select statements\n", + "def get_csv_file_location(query_execution_id):\n", + " athena_client = boto3.client(\"athena\", region_name=\"us-east-1\")\n", + " query_execution = athena_client.get_query_execution(QueryExecutionId=query_execution_id)\n", + " s3_location = query_execution[\"QueryExecution\"][\"ResultConfiguration\"][\"OutputLocation\"]\n", + "\n", + " return s3_location\n", + "\n", + "\n", + "# Provide the 36 character string at the end of the output of the preceding cell as the query.\n", + "get_csv_file_location(\"ride_combined_full_table_query_execution_id\")" + ] + }, + { + "cell_type": "markdown", + "id": "c7bf4f25-dc86-4f1f-95de-967c20c5a7af", + "metadata": {}, + "source": [ + "### Download the dataset and rename it\n", + "\n", + "Replace the example S3 path in the following cell with the output of the preceding cell. The second command renames the CSV file it downloads to `nyc-taxi-whole-dataset.csv`." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "954022d5-bdf9-4dbd-be2e-66d0009ce522", + "metadata": {}, + "outputs": [], + "source": [ + "# Use the S3 URI location returned from the preceding cell to download the dataset and rename it.\n", + "!aws s3 cp s3://example-s3-bucket/ride_combined_full_table_query_execution_id.csv .\n", + "!mv ride_combined_full_table_query_execution_id.csv nyc-taxi-whole-dataset.csv" + ] + }, + { + "cell_type": "markdown", + "id": "4d34ca22-8417-46f5-982f-dd22816f1d93", + "metadata": {}, + "source": [ + "### Get a 20,000 row sample and some information about it" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "79d2f2a5-5111-4fb8-90f3-67474f1072c1", + "metadata": {}, + "outputs": [], + "source": [ + "sample_nyc_taxi_combined = pd.read_csv(\"nyc-taxi-whole-dataset.csv\", nrows=20000)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "f9dececa-272d-458c-9f64-baa13eca0832", + "metadata": {}, + "outputs": [], + "source": [ + "print(\"Dataset shape: \", sample_nyc_taxi_combined.shape)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "1c117a0f-429e-4913-aded-c839675f9e17", + "metadata": {}, + "outputs": [], + "source": [ + "df = sample_nyc_taxi_combined\n", + "\n", + "df.head()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "d3c56da9-0a1c-4c58-93e3-77260dfff40b", + "metadata": {}, + "outputs": [], + "source": [ + "df.info()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "dc25bcd9-a4b1-4491-867f-7534336d1ecd", + "metadata": {}, + "outputs": [], + "source": [ + "df.describe()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "18bd92b1-962a-40f2-b15f-7351d869f390", + "metadata": {}, + "outputs": [], + "source": [ + "df[\"vendor_id\"].value_counts()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "e4c4997f-85d8-4f57-a60c-51e3568cfe2e", + "metadata": {}, + "outputs": [], + "source": [ + "df[\"passenger_count\"].value_counts()" + ] + }, + { + "cell_type": "markdown", + "id": "ae527104-9312-498c-b0ee-d1e2303bf500", + "metadata": {}, + "source": [ + "### View the distribution of fare amount values" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "641c278d-8fed-42b8-98d1-becba90d6259", + "metadata": {}, + "outputs": [], + "source": [ 
+ "# Plot to find the distribution of ride fare values\n", + "import matplotlib.pyplot as plt\n", + "\n", + "plt.hist(df[\"fare_amount\"], edgecolor=\"black\", bins=30, range=(0, 100))\n", + "plt.xlabel(\"Fare Amount\")\n", + "plt.ylabel(\"Count\")\n", + "plt.show" + ] + }, + { + "cell_type": "markdown", + "id": "65d141c4-95ba-4176-8794-1475cb8f2a62", + "metadata": {}, + "source": [ + "### Make sure that all rows are unique" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "9d484f57-f150-45b5-9cc5-cc10a6e8e9f1", + "metadata": {}, + "outputs": [], + "source": [ + "df[\"ride_id\"].nunique()" + ] + }, + { + "cell_type": "markdown", + "id": "abc60782-4411-46e0-9d31-55adaa4dd1f5", + "metadata": {}, + "source": [ + "### Drop the store_and_fwd flag\n", + "\n", + "Determining its relevance isn't in scope for this tutorial." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "f627790e-8aed-48e3-9c5d-52775bbb124d", + "metadata": {}, + "outputs": [], + "source": [ + "df.drop(\"store_and_fwd_flag\", axis=1, inplace=True)" + ] + }, + { + "cell_type": "markdown", + "id": "96fc51be-6a0f-44e6-abb8-2a6bf9188367", + "metadata": {}, + "source": [ + "### Drop the time series columns\n", + "\n", + "Analyzing the time series data also isn't in scope for this analysis." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "c359f4db-b503-4d80-bb4c-55dc411f9b5e", + "metadata": {}, + "outputs": [], + "source": [ + "# We're dropping the time series columns to streamline the analysis.\n", + "time_series_columns_to_drop = [\"pickup_at\", \"dropoff_at\"]\n", + "df.drop(columns=time_series_columns_to_drop, inplace=True)" + ] + }, + { + "cell_type": "markdown", + "id": "ad5d1df6-d418-483a-b06d-848205f3f8ed", + "metadata": {}, + "source": [ + "### Install seaborn and create scatterplots" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "05abe8af-bf44-471b-b130-19cee0dd822f", + "metadata": {}, + "outputs": [], + "source": [ + "!pip install seaborn" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "b6a10b9b-e916-48a9-88f5-ae94db2f6576", + "metadata": {}, + "outputs": [], + "source": [ + "# Create visualizations showing correlations between variables.\n", + "import seaborn as sns\n", + "\n", + "target = \"fare_amount\"\n", + "features = [col for col in df.columns if col != target]\n", + "\n", + "# Create a figure with subplots\n", + "fig, axes = plt.subplots(nrows=1, ncols=len(features), figsize=(50, 10))\n", + "\n", + "# Create scatter plots\n", + "for i, feature in enumerate(features):\n", + " sns.scatterplot(x=df[feature], y=df[target], ax=axes[i])\n", + " axes[i].set_title(f\"{feature} vs {target}\")\n", + " axes[i].set_xlabel(feature)\n", + " axes[i].set_ylabel(target)\n", + "\n", + "plt.tight_layout()\n", + "plt.show()" + ] + }, + { + "cell_type": "markdown", + "id": "11c33316-1502-46b1-b265-6cf43d0d8f1d", + "metadata": {}, + "source": [ + "## Calculate the correlation coefficient between each feature and fare amount" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "d8dff114-adb5-4b34-a788-b93e42a2fee4", + "metadata": {}, + "outputs": [], + "source": [ + "# extra and mta_tax seem weakly correlated\n", + "# total_amount is almost perfectly correlated, indicating target leakage.\n", + "continuous_features = [\n", + " \"tip_amount\",\n", + " \"tolls_amount\",\n", + " \"extra\",\n", + " \"mta_tax\",\n", + " \"total_amount\",\n", + " \"trip_distance\",\n", + "]\n", + "\n", + "for i in 
continuous_features:\n", + " correlation = df[\"fare_amount\"].corr(df[i])\n", + " print(i, correlation)" + ] + }, + { + "cell_type": "markdown", + "id": "7ea2dc4f-c366-43f0-8a81-44ecd8289a3d", + "metadata": {}, + "source": [ + "### Calculate a one way ANOVA between the groups\n", + "\n", + "From running the ANOVA, `mta_tax` and `extra` have the most variance between the groups. We're using them as features to train our model." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "3e083025-3312-4fd9-8cd2-4c8e37db5859", + "metadata": {}, + "outputs": [], + "source": [ + "# The mta tax and extra have the most variance between the groups\n", + "from scipy.stats import f_oneway\n", + "\n", + "# Separate features and target variable\n", + "X = df[[\"payment_type\", \"extra\", \"mta_tax\", \"vendor_id\", \"passenger_count\"]]\n", + "y = df[\"fare_amount\"]\n", + "\n", + "# Perform one-way ANOVA for each feature\n", + "for feature in X.columns:\n", + " groups = [y[X[feature] == group] for group in X[feature].unique()]\n", + " if len(groups) > 1:\n", + " f_statistic, p_value = f_oneway(*groups)\n", + " print(f\"Feature: {feature}, F-statistic: {f_statistic:.2f}, p-value: {p_value:.5f}\")" + ] + }, + { + "cell_type": "markdown", + "id": "5b2f3d07-8010-43c4-873e-f462fd0bd94e", + "metadata": {}, + "source": [ + "### Run a query to get the dataset we're using for ML workflow\n", + "\n", + "The XGBoost algorithm on Amazon SageMaker uses the first column as the target column. `fare_amount` must be the first column in our query." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "0dbcf599-076c-468e-9e9b-2e0bd53c3fa7", + "metadata": {}, + "outputs": [], + "source": [ + "# Final select statement has tip_amount, tolls_amount, extra, mta_tax, trip_distance\n", + "ride_combined_notebook_relevant_features_query = \"\"\"\n", + "SELECT fare_amount, tip_amount, tolls_amount, extra, mta_tax, trip_distance FROM combined_ride_data_deduped\n", + "\"\"\"\n", + "\n", + "run_athena_query(ride_combined_notebook_relevant_features_query, database, s3_output_location)" + ] + }, + { + "cell_type": "markdown", + "id": "4bbfeb06-e0e2-4ce0-9e73-98894053592d", + "metadata": {}, + "source": [ + "### Get the Amazon S3 URI of the dataset" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "624a7833-c815-480e-b1da-c29da3d02c76", + "metadata": {}, + "outputs": [], + "source": [ + "get_csv_file_location(\"ride_combined_notebook_relevant_features_query_execution_id\")" + ] + }, + { + "cell_type": "markdown", + "id": "4632047c-eabc-495a-9758-b55b78937f73", + "metadata": {}, + "source": [ + "### Run a SageMaker processing job to split the data\n", + "\n", + "The code in `processing_data_split.py` splits the dataset into training, validation, and test sets. We use a SageMaker processing job to provide the compute needed to transform large volumes of data. For more information about processing jobs, see [Use processing jobs to run data transformation workloads](https://docs.aws.amazon.com/sagemaker/latest/dg/processing-job.html). For more information about running sci-kit scripts, see [Data Processing with scikit-learn](https://docs.aws.amazon.com/sagemaker/latest/dg/use-scikit-learn-processing-container.html). 
\n", + "\n", + "For faster processing, we recommend using an `instance_count` of `2`, but you can use whatever value you prefer.\n", + "\n", + "For `source` within the `ProcessingInput` function, replace `'s3://example-s3-bucket/ride_combined_notebook_relevant_features_query_execution_id.csv'` with the output of the preceding cell. Within `processing_data_split.py`, you specify `/opt/ml/processing/input/query-id` as the `input_path`. The processing job is copying the query results to a location within its own container.\n", + "\n", + "For `Destination` under `ProcessingOutput`, replace `example-s3-bucket` with the Amazon S3 bucket that you've created." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "788cae3c-a34b-4ee0-899e-0a461e21b210", + "metadata": {}, + "outputs": [], + "source": [ + "import sagemaker\n", + "from sagemaker.sklearn.processing import SKLearnProcessor\n", + "from sagemaker.processing import ProcessingInput, ProcessingOutput\n", + "\n", + "\n", + "# Define the SageMaker execution role\n", + "role = sagemaker.get_execution_role()\n", + "\n", + "# Define the SKLearnProcessor\n", + "sklearn_processor = SKLearnProcessor(\n", + " framework_version=\"0.20.0\", role=role, instance_type=\"ml.m5.4xlarge\", instance_count=2\n", + ")\n", + "\n", + "# Run the processing job\n", + "sklearn_processor.run(\n", + " code=\"processing_data_split.py\",\n", + " inputs=[\n", + " ProcessingInput(\n", + " source=\"s3://example-s3-bucket/ride_combined_notebook_relevant_features_query_execution_id.csv\",\n", + " destination=\"/opt/ml/processing/input\",\n", + " )\n", + " ],\n", + " outputs=[\n", + " ProcessingOutput(\n", + " source=\"/opt/ml/processing/output/train\",\n", + " destination=\"s3://ux360-nyc-taxi-dogfooding/output/train\",\n", + " ),\n", + " ProcessingOutput(\n", + " source=\"/opt/ml/processing/output/validation\",\n", + " destination=\"s3://ux360-nyc-taxi-dogfooding/output/validation\",\n", + " ),\n", + " ProcessingOutput(\n", + " source=\"/opt/ml/processing/output/test\",\n", + " destination=\"s3://ux360-nyc-taxi-dogfooding/output/test\",\n", + " ),\n", + " ],\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "bc164657-fd8f-4f96-89ff-23e991945ea4", + "metadata": {}, + "source": [ + "### Verify that train.csv is in the location that you've specified" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "41cb0fb0-079d-421d-a4b8-005ee38fc472", + "metadata": {}, + "outputs": [], + "source": [ + "# Verify that train.csv is in the location that you've specified\n", + "!aws s3 ls s3://ux360-nyc-taxi-dogfooding/output/train/train.csv" + ] + }, + { + "cell_type": "markdown", + "id": "d0d2ba3c-fd6d-4aa0-b75b-92ba5a70ad00", + "metadata": {}, + "source": [ + "### Verify that val.csv is in the location that you've specified" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "ee3f29f1-a135-4bf6-bba5-595fb80c471d", + "metadata": {}, + "outputs": [], + "source": [ + "# Verify that val.csv is in the location that you've specified\n", + "!aws s3 ls s3://ux360-nyc-taxi-dogfooding/output/validation/val.csv" + ] + }, + { + "cell_type": "markdown", + "id": "c92d4b89-65a5-474b-aa22-dcb442c344b9", + "metadata": {}, + "source": [ + "### Specify `train.csv` and `val.csv` as the input for the training job" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "1e4e4113-b76c-49d5-a3b0-2327eb174fdf", + "metadata": {}, + "outputs": [], + "source": [ + "from sagemaker.session import TrainingInput\n", + "\n", + "bucket = 
\"example-s3-bucket\"\n", + "\n", + "train_input = TrainingInput(f\"s3://{bucket}/output/train/train.csv\", content_type=\"csv\")\n", + "validation_input = TrainingInput(f\"s3://{bucket}/output/validation/val.csv\", content_type=\"csv\")" + ] + }, + { + "cell_type": "markdown", + "id": "866262fe-5737-49af-9cde-af55575e07d1", + "metadata": {}, + "source": [ + "### Specify the model container and output location of the model artifact\n", + "\n", + "Specify the S3 location of the trained model artifact. You can access it later.\n", + "\n", + "It also gets the URI of the container image. We used version `1.2-2` of the XGBoost container image, but you can specify a different version. For more information about XGBoost container images, see [Use the XGBoost algorithm with Amazon SageMaker](https://docs.aws.amazon.com/sagemaker/latest/dg/xgboost.html). " + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "d5b6a9b2-54e5-4dfd-9a5e-3c7442f6d5af", + "metadata": {}, + "outputs": [], + "source": [ + "# Getting the XGBoost container that's in us-east-1\n", + "prefix = \"training-output-data\"\n", + "region = \"us-east-1\"\n", + "\n", + "from sagemaker.debugger import Rule, ProfilerRule, rule_configs\n", + "from sagemaker.session import TrainingInput\n", + "\n", + "s3_output_location = f\"s3://{bucket}/{prefix}/xgboost_model\"\n", + "\n", + "container = sagemaker.image_uris.retrieve(\"xgboost\", region, \"1.2-2\")\n", + "print(container)" + ] + }, + { + "cell_type": "markdown", + "id": "d04e189b-6f38-44cf-a046-6791abd32c00", + "metadata": {}, + "source": [ + "### Define the model" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "44efb3a1-acf0-4193-987f-85025c7c3894", + "metadata": {}, + "outputs": [], + "source": [ + "xgb_model = sagemaker.estimator.Estimator(\n", + " image_uri=container,\n", + " role=role,\n", + " instance_count=2,\n", + " region=region,\n", + " instance_type=\"ml.m5.4xlarge\",\n", + " volume_size=5,\n", + " output_path=s3_output_location,\n", + " sagemaker_session=sagemaker.Session(),\n", + " rules=[\n", + " Rule.sagemaker(rule_configs.create_xgboost_report()),\n", + " ProfilerRule.sagemaker(rule_configs.ProfilerReport()),\n", + " ],\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "44f1c8b1-7bf0-4381-9128-b00c2bfcf9f1", + "metadata": {}, + "source": [ + "### Set the model hyperparameters\n", + "\n", + "For the purposes of running the training job more quickly, we set the number of training rounds to 10." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "e28512bf-d246-4a46-a0c8-24d1a8ad65a8", + "metadata": {}, + "outputs": [], + "source": [ + "xgb_model.set_hyperparameters(\n", + " max_depth=5,\n", + " eta=0.2,\n", + " gamma=4,\n", + " min_child_weight=6,\n", + " subsample=0.7,\n", + " objective=\"reg:squarederror\",\n", + " num_round=10,\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "e5b6ed18-990f-4ec7-9d42-6965ec67e2ce", + "metadata": {}, + "source": [ + "### Train the model" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "58b77fc0-407d-4743-ae35-7bc7b04478e6", + "metadata": {}, + "outputs": [], + "source": [ + "xgb_model.fit({\"train\": train_input, \"validation\": validation_input}, wait=True)" + ] + }, + { + "cell_type": "markdown", + "id": "f0f8be08-10a5-4204-8f8b-60235d4b1f04", + "metadata": {}, + "source": [ + "### Deploy the model\n", + "\n", + "Copy the name of the model endpoint. We use it for our model evaluation." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "c1aa7bc3-feee-4602-a64c-8c1e08526d03", + "metadata": {}, + "outputs": [], + "source": [ + "xgb_predictor = xgb_model.deploy(initial_instance_count=1, instance_type=\"ml.m4.xlarge\")" + ] + }, + { + "cell_type": "markdown", + "id": "ddcf330c-8add-437d-af1f-687ed3ebc78d", + "metadata": {}, + "source": [ + "### Download the test.csv file" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "a9cc4eea-a6d0-418f-ab35-db437ce2a99d", + "metadata": {}, + "outputs": [], + "source": [ + "!aws s3 cp s3://example-s3-bucket/output/test/test.csv ." + ] + }, + { + "cell_type": "markdown", + "id": "27b6cc9e-cb1c-43f6-99b8-fc26b38934c3", + "metadata": {}, + "source": [ + "### Create a 20 row test dataframe" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "953f9d9b-04d0-4398-8620-8f9ab4eb407b", + "metadata": {}, + "outputs": [], + "source": [ + "import boto3\n", + "import json\n", + "\n", + "test_df = pd.read_csv(\"test.csv\", nrows=20)\n", + "test_df.head()" + ] + }, + { + "cell_type": "markdown", + "id": "a27e6c58-1abb-41db-ab45-263b97ee01ed", + "metadata": {}, + "source": [ + "### Get predictions from the test dataframe\n", + "\n", + "Define the `get_predictions` function to convert the 20 row dataframe to a CSV string and get predictions from the model endpoint. Provide the `get_predictions` function with the test dataframe and the name of the model endpoint." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "218e7887-f37d-42e1-8f6a-9ee97d3c75c4", + "metadata": {}, + "outputs": [], + "source": [ + "import json\n", + "import pandas as pd\n", + "\n", + "# Initialize the SageMaker runtime client\n", + "runtime = boto3.client(\"runtime.sagemaker\")\n", + "\n", + "# Define the endpoint name (replace with the name of the endpoint that you deployed)\n", + "endpoint_name = \"sagemaker-xgboost-timestamp\"\n", + "\n", + "\n", + "# Function to make predictions\n", + "def get_predictions(data, endpoint_name):\n", + " # Convert the DataFrame to a CSV string and encode it to bytes\n", + " csv_data = data.to_csv(header=False, index=False).encode(\"utf-8\")\n", + "\n", + " response = runtime.invoke_endpoint(\n", + " EndpointName=endpoint_name, ContentType=\"text/csv\", Body=csv_data\n", + " )\n", + "\n", + " # Read the response body\n", + " response_body = response[\"Body\"].read().decode(\"utf-8\")\n", + "\n", + " try:\n", + " # Try to parse the response as JSON\n", + " result = json.loads(response_body)\n", + " except json.JSONDecodeError:\n", + " # If response is not JSON, just return the raw response\n", + " result = response_body\n", + "\n", + " return result\n", + "\n", + "\n", + "# Drop the target column from the test dataframe\n", + "test_df = test_df.drop(test_df.columns[0], axis=1)\n", + "\n", + "# Get predictions\n", + "predictions = get_predictions(test_df, endpoint_name)\n", + "print(predictions)" + ] + }, + { + "cell_type": "markdown", + "id": "a136ae86-efd3-4d4f-9966-6610f445d84c", + "metadata": {}, + "source": [ + "### Create an array from the string of predictions\n", + "\n", + "The endpoint returns the predictions as a single string with the newline character as the separator, so we use the following code to create an array of predictions."
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "58b45ac2-8a18-4d27-8aff-57370696d58f", + "metadata": {}, + "outputs": [], + "source": [ + "predictions_array = predictions.split(\"\\n\")\n", + "predictions_array = predictions_array[:-1]\n", + "predictions_array" + ] + }, + { + "cell_type": "markdown", + "id": "20097b4e-d515-45cf-9677-bd12953b6912", + "metadata": {}, + "source": [ + "### Get the 20 row sample of the test dataframe" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "a5b69119-c58d-401d-a683-345a21451090", + "metadata": {}, + "outputs": [], + "source": [ + "df_with_target_column_values = pd.read_csv(\"test.csv\", nrows=20)\n", + "df_with_target_column_values.head()" + ] + }, + { + "cell_type": "markdown", + "id": "85cd39f3-5f12-4cb1-aab2-6ca658e9d16e", + "metadata": {}, + "source": [ + "### Convert the values of the predictions array from strings to floats" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "75353856-df2f-4c45-9a9b-11e16a856aa6", + "metadata": {}, + "outputs": [], + "source": [ + "predictions_array = [float(x) for x in predictions_array]" + ] + }, + { + "cell_type": "markdown", + "id": "408a6da9-9a0c-4307-8966-acbcc11beacc", + "metadata": {}, + "source": [ + "### Create a dataframe to store the predicted versus actual values" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "9589000e-1ce0-4a08-9d9c-055d29e13639", + "metadata": {}, + "outputs": [], + "source": [ + "comparison_df = pd.DataFrame(predictions_array, columns=[\"predicted_values\"])\n", + "comparison_df" + ] + }, + { + "cell_type": "markdown", + "id": "e0652e07-1677-4fd4-b099-ccc2b1029cfd", + "metadata": {}, + "source": [ + "### Add the actual values to the comparison dataframe" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "adf4f58c-f21c-4abf-b14c-2802cbd399b3", + "metadata": {}, + "outputs": [], + "source": [ + "column_to_add = df_with_target_column_values.iloc[:, 0]\n", + "\n", + "comparison_df[\"actual_values\"] = column_to_add\n", + "\n", + "comparison_df" + ] + }, + { + "cell_type": "markdown", + "id": "a1ee137e-2706-4972-b70a-4d908bb0cb0a", + "metadata": {}, + "source": [ + "### Verify that the datatypes of both columns are floats" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "48f6f988-0de8-4c44-8c10-9845ef4d476d", + "metadata": {}, + "outputs": [], + "source": [ + "comparison_df.dtypes" + ] + }, + { + "cell_type": "markdown", + "id": "8c7cce0b-ce8b-4320-b9a4-9a50b2c732b3", + "metadata": {}, + "source": [ + "### Compute the RMSE" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "781fe125-4a2e-4527-8c45-fcd20558f4bb", + "metadata": {}, + "outputs": [], + "source": [ + "import numpy as np\n", + "\n", + "# Calculate the squared differences between the predicted and actual values\n", + "comparison_df[\"squared_diff\"] = (\n", + " comparison_df[\"actual_values\"] - comparison_df[\"predicted_values\"]\n", + ") ** 2\n", + "\n", + "# Calculate the mean of the squared differences\n", + "mean_squared_diff = comparison_df[\"squared_diff\"].mean()\n", + "\n", + "# Take the square root of the mean to get the RMSE\n", + "rmse = np.sqrt(mean_squared_diff)\n", + "\n", + "print(f\"RMSE: {rmse}\")" + ] + }, + { + "cell_type": "markdown", + "id": "4a21cb4e-d9be-466c-869d-ac0be688700c", + "metadata": {}, + "source": [ + "### Clean up" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "9a6e651d-3e68-4c1b-8a28-3e15604b5ec1", + "metadata": 
{}, + "outputs": [], + "source": [ + "# Delete the S3 bucket\n", + "!aws s3 rb s3://example-s3-bucket --force" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "6c883864-e707-46d2-a183-76e5f2090368", + "metadata": {}, + "outputs": [], + "source": [ + "# Delete the endpoint\n", + "xgb_predictor.delete_endpoint()" + ] + }, + { + "cell_type": "markdown", + "id": "cd9140e5", + "metadata": {}, + "source": [ + "## Notebook CI Test Results\n", + " \n", + "This notebook was tested in multiple regions. The test results are as follows, except for us-west-2 which is shown at the top of the notebook.\n", + "\n", + "![This us-east-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/us-east-1/use-cases|athena_ml_workflow_end_to_end|athena_ml_workflow_end_to_end.ipynb)\n", + "\n", + "![This us-east-2 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/us-east-2/use-cases|athena_ml_workflow_end_to_end|athena_ml_workflow_end_to_end.ipynb)\n", + "\n", + "![This us-west-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/us-west-1/use-cases|athena_ml_workflow_end_to_end|athena_ml_workflow_end_to_end.ipynb)\n", + "\n", + "![This ca-central-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/ca-central-1/use-cases|athena_ml_workflow_end_to_end|athena_ml_workflow_end_to_end.ipynb)\n", + "\n", + "![This sa-east-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/sa-east-1/use-cases|athena_ml_workflow_end_to_end|athena_ml_workflow_end_to_end.ipynb)\n", + "\n", + "![This eu-west-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/eu-west-1/use-cases|athena_ml_workflow_end_to_end|athena_ml_workflow_end_to_end.ipynb)\n", + "\n", + "![This eu-west-2 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/eu-west-2/use-cases|athena_ml_workflow_end_to_end|athena_ml_workflow_end_to_end.ipynb)\n", + "\n", + "![This eu-west-3 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/eu-west-3/use-cases|athena_ml_workflow_end_to_end|athena_ml_workflow_end_to_end.ipynb)\n", + "\n", + "![This eu-central-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/eu-central-1/use-cases|athena_ml_workflow_end_to_end|athena_ml_workflow_end_to_end.ipynb)\n", + "\n", + "![This eu-north-1 badge failed to load. 
Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/eu-north-1/use-cases|athena_ml_workflow_end_to_end|athena_ml_workflow_end_to_end.ipynb)\n", + "\n", + "![This ap-southeast-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/ap-southeast-1/use-cases|athena_ml_workflow_end_to_end|athena_ml_workflow_end_to_end.ipynb)\n", + "\n", + "![This ap-southeast-2 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/ap-southeast-2/use-cases|athena_ml_workflow_end_to_end|athena_ml_workflow_end_to_end.ipynb)\n", + "\n", + "![This ap-northeast-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/ap-northeast-1/use-cases|athena_ml_workflow_end_to_end|athena_ml_workflow_end_to_end.ipynb)\n", + "\n", + "![This ap-northeast-2 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/ap-northeast-2/use-cases|athena_ml_workflow_end_to_end|athena_ml_workflow_end_to_end.ipynb)\n", + "\n", + "![This ap-south-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/ap-south-1/use-cases|athena_ml_workflow_end_to_end|athena_ml_workflow_end_to_end.ipynb)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.14" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/use-cases/athena_ml_workflow_end_to_end/processing_data_split.py b/use-cases/athena_ml_workflow_end_to_end/processing_data_split.py new file mode 100644 index 0000000000..fb8472d011 --- /dev/null +++ b/use-cases/athena_ml_workflow_end_to_end/processing_data_split.py @@ -0,0 +1,32 @@ +import numpy as np +import pandas as pd +from sklearn.model_selection import train_test_split +import os + +# Define the input and output paths +input_path = '/opt/ml/processing/input/feature-selection-query-id.csv' +train_output_path = '/opt/ml/processing/output/train/train.csv' +val_output_path = '/opt/ml/processing/output/validation/val.csv' +test_output_path = '/opt/ml/processing/output/test/test.csv' + +# Read the input data +df = pd.read_csv(input_path, header=None) + +# Split the data into training, validation, and test sets +train, temp = train_test_split(df, test_size=0.3, random_state=42) +val, test = train_test_split(temp, test_size=0.5, random_state=42) + +# Save the splits to the output paths +os.makedirs(os.path.dirname(train_output_path), exist_ok=True) +train.to_csv(train_output_path, index=False) + +os.makedirs(os.path.dirname(val_output_path), exist_ok=True) +val.to_csv(val_output_path, index=False) + +os.makedirs(os.path.dirname(test_output_path), exist_ok=True) +test.to_csv(test_output_path, index=False) + +# Print the sizes of the 
splits +print(f"Training set: {len(train)} samples") +print(f"Validation set: {len(val)} samples") +print(f"Test set: {len(test)} samples") diff --git a/use-cases/pyspark_etl_and_training/pyspark-etl-training.ipynb b/use-cases/pyspark_etl_and_training/pyspark-etl-training.ipynb new file mode 100644 index 0000000000..d441ff4ac6 --- /dev/null +++ b/use-cases/pyspark_etl_and_training/pyspark-etl-training.ipynb @@ -0,0 +1,734 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "3ff2d442", + "metadata": {}, + "source": [ + "# Perform ETL and train a model using PySpark\n", + "---\n", + "\n", + "This notebook's CI test result for us-west-2 is as follows. CI test results in other regions can be found at the end of the notebook.\n", + "\n", + "![This us-west-2 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/us-west-2/use-cases|pyspark_etl_and_training|pyspark-etl-training.ipynb)\n", + "\n", + "---" + ] + }, + { + "cell_type": "markdown", + "id": "0a1828f9-efdc-4d12-a676-a2f3432e9ab0", + "metadata": {}, + "source": [ + "To perform extract, transform, and load (ETL) operations on multiple files, we recommend opening a Jupyter notebook within Amazon SageMaker Studio and using the `Glue PySpark and Ray` kernel. The kernel is connected to an AWS Glue Interactive Session. The session connects your notebook to a cluster that automatically scales up the storage and compute to meet your data processing needs. When you shut down the kernel, the session stops and you're no longer charged for the compute on the cluster.\n", + "\n", + "Within the notebook you can use Spark commands to join and transform your data. Writing Spark commands is both faster and easier than writing SQL queries. For example, you can use the join command to join two tables. Instead of writing a query that can sometimes take minutes to complete, you can join the tables within seconds.\n", + "\n", + "To show the utility of using the PySpark kernel for your ETL and model training workflows, we're predicting the fare amount for rides in the NYC taxi dataset. The notebook imports data from 47 files across 2 different Amazon Simple Storage Service (Amazon S3) locations. Amazon S3 is an object storage service that you can use to save and access data and machine learning artifacts for your models. For more information about Amazon S3, see [What is Amazon S3?](https://docs.aws.amazon.com/AmazonS3/latest/userguide/Welcome.html).\n", + "\n", + "The notebook is not meant to be a comprehensive analysis. Instead, it's meant to be a proof of concept to help you quickly get started.\n", + "\n", + "__Prerequisites:__\n", + "\n", + "This tutorial assumes that you're working in the us-east-1 AWS Region. It also assumes that you've provided the IAM role you're using to run the notebook with permissions to use AWS Glue. For more information, see [Providing AWS Glue permissions](https://docs.aws.amazon.com/sagemaker/latest/dg/perform-etl-and-train-model-pyspark.html#providing-aws-glue-permissions)." + ] + }, + { + "cell_type": "markdown", + "id": "dffc1f72-88d2-442d-97ee-0d1c4e095ffb", + "metadata": {}, + "source": [ + "## Solution overview \n", + "\n", + "To perform ETL on the NYC taxi data and train a model, we do the following:\n", + "\n", + "1. Start a Glue Session and load the SageMaker Python SDK.\n", + "2. Set up the utilities needed to work with AWS Glue.\n", + "3. Load the data from Amazon S3 into Spark dataframes.\n", + "4. 
Verify that we've loaded the data successfully.\n", + "5. Save a 20,000 row sample of the Spark dataframe as a pandas dataframe.\n", + "6. Create a correlation matrix as an example of the types of analyses we can perform.\n", + "7. Split the Spark dataframe into training, validation, and test datasets.\n", + "8. Write the datasets to Amazon S3 locations that can be accessed by an Amazon SageMaker training job.\n", + "9. Use the training and validation datasets to train a model." + ] + }, + { + "cell_type": "markdown", + "id": "e472c953-1625-49df-8df9-9529344783ab", + "metadata": {}, + "source": [ + "### Start a Glue Session and load the SageMaker Python SDK" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "94172c75-f8a9-4590-a443-c872fb5c5d6e", + "metadata": { + "vscode": { + "languageId": "python_glue_session" + } + }, + "outputs": [], + "source": [ + "%additional_python_modules sagemaker" + ] + }, + { + "cell_type": "markdown", + "id": "725bd4b6-82a0-4f02-95b9-261ce62c71b0", + "metadata": {}, + "source": [ + "### Set up the utilities needed to work with AWS Glue\n", + "\n", + "We're importing `Join` to join our Spark dataframes. `GlueContext` provides methods for transforming our dataframes. In the context of the notebook, it reads the data from the Amazon S3 locations and uses the Spark cluster to transform the data. `SparkContext` represents the connection to the Spark cluster. `GlueContext` uses `SparkContext` to transform the data. `getResolvedOptions` lets you resolve configuration options within the Glue interactive session." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "2ea1c3a4-8881-48b0-8888-9319812750e7", + "metadata": { + "vscode": { + "languageId": "python_glue_session" + } + }, + "outputs": [], + "source": [ + "import sys\n", + "from awsglue.transforms import Join\n", + "from awsglue.utils import getResolvedOptions\n", + "from pyspark.context import SparkContext\n", + "from awsglue.context import GlueContext\n", + "from awsglue.job import Job\n", + "\n", + "glueContext = GlueContext(SparkContext.getOrCreate())" + ] + }, + { + "cell_type": "markdown", + "id": "e03664e5-89a2-4296-ba83-3518df4a58f0", + "metadata": {}, + "source": [ + "### Create the `df_ride_info` dataframe\n", + "\n", + "Create a single dataframe from all the ride_info Parquet files for 2019." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "ba577de7-9ffe-4bae-b4c0-b225181306d9", + "metadata": { + "vscode": { + "languageId": "python_glue_session" + } + }, + "outputs": [], + "source": [ + "df_ride_info = glueContext.create_dynamic_frame_from_options(\n", + " connection_type=\"s3\",\n", + " format=\"parquet\",\n", + " connection_options={\n", + " \"paths\": [\n", + " \"s3://dsoaws/nyc-taxi-orig-cleaned-split-parquet-per-year-multiple-files/ride-info/year=2019/\"\n", + " ],\n", + " \"recurse\": True,\n", + " },\n", + ").toDF()" + ] + }, + { + "cell_type": "markdown", + "id": "b04ce553-bf3d-4922-bbb1-4aa264447276", + "metadata": {}, + "source": [ + "### Create the `df_ride_fare` dataframe\n", + "\n", + "Create a single dataframe from all the ride_fare Parquet files for 2019."
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "6efc3d4a-81d7-40f5-bb62-cd206924a0c9", + "metadata": { + "vscode": { + "languageId": "python_glue_session" + } + }, + "outputs": [], + "source": [ + "df_ride_fare = glueContext.create_dynamic_frame_from_options(\n", + " connection_type=\"s3\",\n", + " format=\"parquet\",\n", + " connection_options={\n", + " \"paths\": [\n", + " \"s3://dsoaws/nyc-taxi-orig-cleaned-split-parquet-per-year-multiple-files/ride-fare/year=2019/\"\n", + " ],\n", + " \"recurse\": True,\n", + " },\n", + ").toDF()" + ] + }, + { + "cell_type": "markdown", + "id": "6c8664da-2105-4ada-b480-06d50c59e878", + "metadata": {}, + "source": [ + "### Show the first five rows of `df_ride_fare`" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "d63af3a3-358f-4c6e-97d4-97a1f1a552de", + "metadata": { + "vscode": { + "languageId": "python_glue_session" + } + }, + "outputs": [], + "source": [ + "df_ride_fare.show(5)" + ] + }, + { + "cell_type": "markdown", + "id": "688a17e8-0c83-485d-a328-e89344a0e8bf", + "metadata": {}, + "source": [ + "### Join df_ride_fare and df_ride_info on the `ride_id` column" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "07a3baab-44b0-416a-b12e-049a270af8bd", + "metadata": { + "vscode": { + "languageId": "python_glue_session" + } + }, + "outputs": [], + "source": [ + "df_joined = df_ride_info.join(df_ride_fare, [\"ride_id\"])" + ] + }, + { + "cell_type": "markdown", + "id": "236c2efc-85f8-43f8-b6d3-7f0e61ccefb0", + "metadata": {}, + "source": [ + "### Show the first five rows of the joined dataframe" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "2a456733-4533-4688-8174-368e50f4dd66", + "metadata": { + "vscode": { + "languageId": "python_glue_session" + } + }, + "outputs": [], + "source": [ + "df_joined.show(5)" + ] + }, + { + "cell_type": "markdown", + "id": "1396f6ee-c581-4274-baf8-243d38ec000b", + "metadata": {}, + "source": [ + "### Show the data types of the dataframe" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "9a52a903-f394-4d00-a216-6af8c2132d83", + "metadata": { + "vscode": { + "languageId": "python_glue_session" + } + }, + "outputs": [], + "source": [ + "df_joined.printSchema()" + ] + }, + { + "cell_type": "markdown", + "id": "18bb75a2-eba5-4d06-8a26-f30e31776a02", + "metadata": {}, + "source": [ + "### Count the number of rows" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "c6bcc15f-8d41-4def-ae49-edaef4105343", + "metadata": { + "vscode": { + "languageId": "python_glue_session" + } + }, + "outputs": [], + "source": [ + "df_joined.count()" + ] + }, + { + "cell_type": "markdown", + "id": "d2daa67c-4b21-433a-b46e-eed518ba9ce7", + "metadata": {}, + "source": [ + "### Drop duplicates if there are any" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "7d13d8d9-7eed-4efb-b972-601baf291842", + "metadata": { + "vscode": { + "languageId": "python_glue_session" + } + }, + "outputs": [], + "source": [ + "df_no_dups = df_joined.dropDuplicates([\"ride_id\"])" + ] + }, + { + "cell_type": "markdown", + "id": "657e48dc-1f4a-4550-afe1-d9754e6d0e1e", + "metadata": {}, + "source": [ + "### Count the number of rows after dropping the duplicates\n", + "\n", + "In this case, there were no duplicates in the original dataframe."
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "3e3e82a3-e3db-4752-8bab-f42cbbae4928", + "metadata": { + "vscode": { + "languageId": "python_glue_session" + } + }, + "outputs": [], + "source": [ + "df_no_dups.count()" + ] + }, + { + "cell_type": "markdown", + "id": "ae4c0fc4-7cb5-4b70-8430-965b5fe4506e", + "metadata": {}, + "source": [ + "### Drop columns\n", + "\n", + "Time series and categorical data are outside the scope of this notebook, so we drop the timestamp and categorical columns." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "9dc1d15f-53f6-404d-86fd-5a28f3792db8", + "metadata": { + "vscode": { + "languageId": "python_glue_session" + } + }, + "outputs": [], + "source": [ + "df_cleaned = df_no_dups.drop(\n", + " \"pickup_at\", \"dropoff_at\", \"store_and_fwd_flag\", \"vendor_id\", \"payment_type\"\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "081c81f9-f052-4ddb-b769-4d41b6138f6a", + "metadata": {}, + "source": [ + "### Take a sample from the dataframe and convert it to a pandas dataframe" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "48382726-c767-4b0e-9336-decbf8184938", + "metadata": { + "vscode": { + "languageId": "python_glue_session" + } + }, + "outputs": [], + "source": [ + "# Sample roughly 10% of the rows without replacement, then keep at most 20,000 of them\n", + "df_sample = df_cleaned.sample(False, 0.1, seed=0).limit(20000)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "2bf2f181-0096-4044-8210-7d9de299d966", + "metadata": { + "vscode": { + "languageId": "python_glue_session" + } + }, + "outputs": [], + "source": [ + "df_sample.count()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "a8b2f670-c5f9-4a01-8d9f-6a29a3dae660", + "metadata": { + "vscode": { + "languageId": "python_glue_session" + } + }, + "outputs": [], + "source": [ + "df_pandas = df_sample.toPandas()\n", + "df_pandas.describe()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "246c98e9-64bd-4644-a163-b86a943d6a09", + "metadata": { + "vscode": { + "languageId": "python_glue_session" + } + }, + "outputs": [], + "source": [ + "print(\"Dataset shape: \", df_pandas.shape)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "c5b2727c-de75-4cc0-94e9-d254e235d003", + "metadata": { + "vscode": { + "languageId": "python_glue_session" + } + }, + "outputs": [], + "source": [ + "df_pandas.head()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "d69b48b6-98c2-4851-9c7a-f24f092bae41", + "metadata": { + "vscode": { + "languageId": "python_glue_session" + } + }, + "outputs": [], + "source": [ + "df_pandas.info()" + ] + }, + { + "cell_type": "markdown", + "id": "34222bea-8864-4934-8c93-a71a7e72325b", + "metadata": {}, + "source": [ + "### Create a correlation matrix of the features\n", + "\n", + "We're creating a correlation matrix to see which features are the most predictive. This is an example of an analysis that you can use for your own use case."
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "b7f3e4f7-e04e-41e1-b94b-b32eb3bc3bbf", + "metadata": { + "vscode": { + "languageId": "python_glue_session" + } + }, + "outputs": [], + "source": [ + "from pyspark.ml.stat import Correlation\n", + "from pyspark.ml.feature import VectorAssembler\n", + "import seaborn as sns\n", + "import matplotlib.pyplot as plt\n", + "import pandas as pd  # used to hold the correlation matrix for plotting\n", + "\n", + "# Assemble the feature columns into a single vector column, which Correlation.corr expects\n", + "vector_col = \"corr_features\"\n", + "assembler = VectorAssembler(inputCols=df_sample.columns, outputCol=vector_col)\n", + "df_vector = assembler.transform(df_sample).select(vector_col)\n", + "\n", + "# Compute the correlation matrix and convert it to a pandas dataframe for plotting\n", + "matrix = Correlation.corr(df_vector, vector_col).collect()[0][0]\n", + "corr_matrix = matrix.toArray().tolist()\n", + "corr_matrix_df = pd.DataFrame(data=corr_matrix, columns=df_sample.columns, index=df_sample.columns)\n", + "\n", + "plt.figure(figsize=(16, 10))\n", + "sns.heatmap(\n", + " corr_matrix_df,\n", + " xticklabels=corr_matrix_df.columns.values,\n", + " yticklabels=corr_matrix_df.columns.values,\n", + " cmap=\"Greens\",\n", + " annot=True,\n", + ")\n", + "\n", + "%matplot plt" + ] + }, + { + "cell_type": "markdown", + "id": "cbde3b29-d37d-485a-a114-5313c5a702c7", + "metadata": {}, + "source": [ + "### Split the dataset into train, validation, and test sets" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "6e207c64-2e22-468f-a0c7-948090bcfce2", + "metadata": { + "vscode": { + "languageId": "python_glue_session" + } + }, + "outputs": [], + "source": [ + "df_train, df_val, df_test = df_cleaned.randomSplit([0.7, 0.15, 0.15])" + ] + }, + { + "cell_type": "markdown", + "id": "01a4d181-e2f0-4743-ab35-dd1f68b0fd31", + "metadata": {}, + "source": [ + "### Define the Amazon S3 locations that store the datasets\n", + "\n", + "If you get a module not found error for `sagemaker`, restart the kernel and run all the cells again so that the `%additional_python_modules` setting is applied."
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "f16ea3a1-6d6d-4755-94ad-c743298bd130", + "metadata": { + "vscode": { + "languageId": "python_glue_session" + } + }, + "outputs": [], + "source": [ + "# Define the S3 locations to store the datasets\n", + "import boto3\n", + "import sagemaker\n", + "\n", + "sagemaker_session = sagemaker.Session()\n", + "s3_bucket = sagemaker_session.default_bucket()\n", + "train_data_prefix = \"sandbox/glue-demo/train\"\n", + "validation_data_prefix = \"sandbox/glue-demo/validation\"\n", + "test_data_prefix = \"sandbox/glue-demo/test\"\n", + "region = boto3.Session().region_name" + ] + }, + { + "cell_type": "markdown", + "id": "8899a159-700c-403a-b4f5-a00c62b06e5a", + "metadata": {}, + "source": [ + "### Write the files to the locations" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "64d7ae48-6158-4273-8bb3-2f00abb1c20c", + "metadata": { + "vscode": { + "languageId": "python_glue_session" + } + }, + "outputs": [], + "source": [ + "df_train.write.parquet(f\"s3://{s3_bucket}/{train_data_prefix}\", mode=\"overwrite\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "de3d1190-4717-4944-846d-0169c093cb90", + "metadata": { + "vscode": { + "languageId": "python_glue_session" + } + }, + "outputs": [], + "source": [ + "df_val.write.parquet(f\"s3://{s3_bucket}/{validation_data_prefix}\", mode=\"overwrite\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "9d18ef1c-fc2f-4e34-a692-4a6c48be7cba", + "metadata": { + "vscode": { + "languageId": "python_glue_session" + } + }, + "outputs": [], + "source": [ + "df_test.write.parquet(f\"s3://{s3_bucket}/{test_data_prefix}\", mode=\"overwrite\")" + ] + }, + { + "cell_type": "markdown", + "id": "73c947e4-b4a9-4cc4-aefe-755aa0a713c8", + "metadata": {}, + "source": [ + "### Train a model\n", + "\n", + "The following code uses the `df_train` and `df_val` datasets to train an XGBoost model. 
" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "a31b7742-93df-44c5-8674-b6355032c508", + "metadata": { + "vscode": { + "languageId": "python_glue_session" + } + }, + "outputs": [], + "source": [ + "from sagemaker import image_uris\n", + "from sagemaker.inputs import TrainingInput\n", + "\n", + "hyperparameters = {\n", + " \"max_depth\": \"5\",\n", + " \"eta\": \"0.2\",\n", + " \"gamma\": \"4\",\n", + " \"min_child_weight\": \"6\",\n", + " \"subsample\": \"0.7\",\n", + " \"objective\": \"reg:squarederror\",\n", + " \"num_round\": \"50\",\n", + "}\n", + "\n", + "# Set an output path to save the trained model.\n", + "prefix = \"sandbox/glue-demo\"\n", + "output_path = f\"s3://{s3_bucket}/{prefix}/xgb-built-in-algo/output\"\n", + "\n", + "# The following line looks for the XGBoost image URI and builds an XGBoost container.\n", + "# We use version 1.7-1 of the image URI, you can specify a version that you prefer.\n", + "xgboost_container = sagemaker.image_uris.retrieve(\"xgboost\", region, \"1.7-1\")\n", + "\n", + "# Construct a SageMaker estimator that calls the xgboost-container\n", + "estimator = sagemaker.estimator.Estimator(\n", + " image_uri=xgboost_container,\n", + " hyperparameters=hyperparameters,\n", + " role=sagemaker.get_execution_role(),\n", + " instance_count=1,\n", + " instance_type=\"ml.m5.4xlarge\",\n", + " output_path=output_path,\n", + ")\n", + "\n", + "content_type = \"application/x-parquet\"\n", + "train_input = TrainingInput(f\"s3://{s3_bucket}/{prefix}/train/\", content_type=content_type)\n", + "validation_input = TrainingInput(\n", + " f\"s3://{s3_bucket}/{prefix}/validation/\", content_type=content_type\n", + ")\n", + "\n", + "# Run the XGBoost training job\n", + "estimator.fit({\"train\": train_input, \"validation\": validation_input})" + ] + }, + { + "cell_type": "markdown", + "id": "b1b1d546-1c7e-48f5-9262-939289ada936", + "metadata": {}, + "source": [ + "### Clean up\n", + "\n", + "To clean up, shut down the kernel. Shutting down the kernel, stops the Glue cluster. You won't be charged for any more compute other than what you used to run the tutorial." + ] + }, + { + "cell_type": "markdown", + "id": "99668011", + "metadata": {}, + "source": [ + "## Notebook CI Test Results\n", + " \n", + "This notebook was tested in multiple regions. The test results are as follows, except for us-west-2 which is shown at the top of the notebook.\n", + "\n", + "![This us-east-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/us-east-1/use-cases|pyspark_etl_and_training|pyspark-etl-training.ipynb)\n", + "\n", + "![This us-east-2 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/us-east-2/use-cases|pyspark_etl_and_training|pyspark-etl-training.ipynb)\n", + "\n", + "![This us-west-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/us-west-1/use-cases|pyspark_etl_and_training|pyspark-etl-training.ipynb)\n", + "\n", + "![This ca-central-1 badge failed to load. 
Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/ca-central-1/use-cases|pyspark_etl_and_training|pyspark-etl-training.ipynb)\n", + "\n", + "![This sa-east-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/sa-east-1/use-cases|pyspark_etl_and_training|pyspark-etl-training.ipynb)\n", + "\n", + "![This eu-west-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/eu-west-1/use-cases|pyspark_etl_and_training|pyspark-etl-training.ipynb)\n", + "\n", + "![This eu-west-2 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/eu-west-2/use-cases|pyspark_etl_and_training|pyspark-etl-training.ipynb)\n", + "\n", + "![This eu-west-3 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/eu-west-3/use-cases|pyspark_etl_and_training|pyspark-etl-training.ipynb)\n", + "\n", + "![This eu-central-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/eu-central-1/use-cases|pyspark_etl_and_training|pyspark-etl-training.ipynb)\n", + "\n", + "![This eu-north-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/eu-north-1/use-cases|pyspark_etl_and_training|pyspark-etl-training.ipynb)\n", + "\n", + "![This ap-southeast-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/ap-southeast-1/use-cases|pyspark_etl_and_training|pyspark-etl-training.ipynb)\n", + "\n", + "![This ap-southeast-2 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/ap-southeast-2/use-cases|pyspark_etl_and_training|pyspark-etl-training.ipynb)\n", + "\n", + "![This ap-northeast-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/ap-northeast-1/use-cases|pyspark_etl_and_training|pyspark-etl-training.ipynb)\n", + "\n", + "![This ap-northeast-2 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/ap-northeast-2/use-cases|pyspark_etl_and_training|pyspark-etl-training.ipynb)\n", + "\n", + "![This ap-south-1 badge failed to load. 
Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/ap-south-1/use-cases|pyspark_etl_and_training|pyspark-etl-training.ipynb)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Glue PySpark and Ray", + "language": "python", + "name": "glue_pyspark" + }, + "language_info": { + "codemirror_mode": { + "name": "python", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "Python_Glue_Session", + "pygments_lexer": "python3" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +}