-
Notifications
You must be signed in to change notification settings - Fork 1.2k
documentation: update PyTorch BYOM topic #1457
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 2 commits
9659e05
31c8ff1
7c4dfa2
db84ea7
92d4fec
cd80ca7
fa22a81
e2cf89c
8f7c92a
ad29c6c
596fc57
16a61fe
20b21bc
badad47
4fded1d
6a9d152
9695a03
8ea4170
009624f
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -4,9 +4,13 @@ Using PyTorch with the SageMaker Python SDK | |||||||||||||||||||||||
|
|
||||||||||||||||||||||||
| With PyTorch Estimators and Models, you can train and host PyTorch models on Amazon SageMaker. | ||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
| <<<<<<< HEAD | ||||||||||||||||||||||||
| * Supported versions of PyTorch: ``0.4.0``, ``1.0.0``, ``1.1.0``, ``1.2.0``, ``1.3.1``. | ||||||||||||||||||||||||
| ======= | ||||||||||||||||||||||||
| Supported versions of PyTorch: ``0.4.0``, ``1.0.0``, ``1.1.0``, ``1.2.0``, ``1.3.1``, ``1.4.0``. | ||||||||||||||||||||||||
| >>>>>>> 53fe1dc2025a1ba6e7fe4f16f120dfcc245ed465 | ||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
| Supported versions of PyTorch for Elastic Inference: ``1.3.1``. | ||||||||||||||||||||||||
| * Supported versions of PyTorch for Elastic Inference: ``1.3.1``. | ||||||||||||||||||||||||
|
||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
| We recommend that you use the latest supported version because that's where we focus our development efforts. | ||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
|
|
@@ -90,7 +94,7 @@ Note that SageMaker doesn't support argparse actions. If you want to use, for ex | |||||||||||||||||||||||
| you need to specify `type` as `bool` in your script and provide an explicit `True` or `False` value for this hyperparameter | ||||||||||||||||||||||||
| when instantiating PyTorch Estimator. | ||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
| For more on training environment variables, please visit `SageMaker Containers <https://github.com/aws/sagemaker-containers>`_. | ||||||||||||||||||||||||
| For more on training environment variables, see `SageMaker Containers <https://github.com/aws/sagemaker-containers>`_. | ||||||||||||||||||||||||
eslesar-aws marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
| Save the Model | ||||||||||||||||||||||||
| -------------- | ||||||||||||||||||||||||
|
|
@@ -115,7 +119,7 @@ to a certain filesystem path called ``model_dir``. This value is accessible thro | |||||||||||||||||||||||
| with open(os.path.join(args.model_dir, 'model.pth'), 'wb') as f: | ||||||||||||||||||||||||
| torch.save(model.state_dict(), f) | ||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
| After your training job is complete, SageMaker will compress and upload the serialized model to S3, and your model data | ||||||||||||||||||||||||
| After your training job is complete, SageMaker compresses and uploads the serialized model to S3, and your model data | ||||||||||||||||||||||||
| will be available in the S3 ``output_path`` you specified when you created the PyTorch Estimator. | ||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
| If you are using Elastic Inference, you must convert your models to the TorchScript format and use ``torch.jit.save`` to save the model. | ||||||||||||||||||||||||
|
|
@@ -566,11 +570,91 @@ The function should return a byte array of data serialized to content_type. | |||||||||||||||||||||||
| The default implementation expects ``prediction`` to be a torch.Tensor and can serialize the result to JSON, CSV, or NPY. | ||||||||||||||||||||||||
| It accepts response content types of "application/json", "text/csv", and "application/x-npy". | ||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
| Working with Existing Model Data and Training Jobs | ||||||||||||||||||||||||
| ================================================== | ||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
| Attach to existing training jobs | ||||||||||||||||||||||||
| -------------------------------- | ||||||||||||||||||||||||
| Bring your own model | ||||||||||||||||||||||||
| ==================== | ||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
| You can deploy a PyTorch model that you trained outside of SageMaker by using the ``PyTorchModel`` class. | ||||||||||||||||||||||||
| Typically, you save a PyTorch model as a file with extension ``.pt`` or ``.pth``. | ||||||||||||||||||||||||
| To do this, you need to: | ||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
| * Write an inference script. | ||||||||||||||||||||||||
| * Package the model artifacts into a tar.gz file. | ||||||||||||||||||||||||
eslesar-aws marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||||||||||||||||||||||||
| * Upload the tar.gz file to an S3 bucket. | ||||||||||||||||||||||||
eslesar-aws marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||||||||||||||||||||||||
| * Create the ``PyTorchModel`` object. | ||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
| Write an inference script | ||||||||||||||||||||||||
| ------------------------- | ||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
| You must create an inference script that implements (at least) the ``predict_fn`` function that calls the loaded model to get a prediction. | ||||||||||||||||||||||||
eslesar-aws marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||||||||||||||||||||||||
| Optionally, you can also implement ``input_fn`` and ``output_fn`` to process input and output. | ||||||||||||||||||||||||
| For information about how to write an inference script, see `Serve a PyTorch Model <#serve-a-pytorch-model>`_. | ||||||||||||||||||||||||
| Save the inference script as ``inference.py`` in the same folder where you saved your PyTorch model. | ||||||||||||||||||||||||
|
||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
| Package model artifacts into a tar.gz file | ||||||||||||||||||||||||
| ------------------------------------------ | ||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
| The directory structure where you saved your PyTorch model should look something like the following: | ||||||||||||||||||||||||
eslesar-aws marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
| | my_model | ||||||||||||||||||||||||
| | |--model.pth | ||||||||||||||||||||||||
| | | ||||||||||||||||||||||||
| | code | ||||||||||||||||||||||||
| | |--inference.py | ||||||||||||||||||||||||
| | |--requirements.txt | ||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
| Where ``requirments.txt`` is an optional file that specifies dependencies on third-party libraries. | ||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
| With this file structure, run the following command to package your model as a ``tar.gz`` file: | ||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
| ``tar -czf model.tar.gz my_model code`` | ||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
| Upload model.tar.gz to S3 | ||||||||||||||||||||||||
| ------------------------- | ||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
| After you package your model into a ``tar.gz`` file, upload it to an S3 bucket by running the following python code: | ||||||||||||||||||||||||
eslesar-aws marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
| .. code:: python | ||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
| import boto3 | ||||||||||||||||||||||||
| import sagemaker | ||||||||||||||||||||||||
| s3 = boto3.client('s3') | ||||||||||||||||||||||||
|
||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
| from sagemaker import get_execution_role | ||||||||||||||||||||||||
| role = get_execution_role() | ||||||||||||||||||||||||
|
||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
| response = s3.upload_file('model.tar.gz', 'my-bucket', '%s/%s' %('my-path', 'model.tar.gz')) | ||||||||||||||||||||||||
|
||||||||||||||||||||||||
| import boto3 | |
| import sagemaker | |
| s3 = boto3.client('s3') | |
| from sagemaker import get_execution_role | |
| role = get_execution_role() | |
| response = s3.upload_file('model.tar.gz', 'my-bucket', '%s/%s' %('my-path', 'model.tar.gz')) | |
| from sagemaker.s3 import S3Uploader | |
| S3Uploader.upload('model.tar.gz', 's3://my-bucket/my-path/model.tar.gz') |
Outdated
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Uh oh!
There was an error while loading. Please reload this page.