From 91e93539f234fba0d6de544b78d881f5504e6508 Mon Sep 17 00:00:00 2001 From: Ben Firshman Date: Tue, 4 Jan 2022 15:51:46 -0800 Subject: [PATCH 01/14] Document new Cog Signed-off-by: Ben Firshman --- README.md | 86 +++++++++++++++---------------- docs/getting-started-own-model.md | 35 ++++++++----- docs/getting-started.md | 14 ++--- docs/python.md | 13 +++-- 4 files changed, 80 insertions(+), 68 deletions(-) diff --git a/README.md b/README.md index e62ca0f692..e7cb437f1b 100644 --- a/README.md +++ b/README.md @@ -1,16 +1,14 @@ # Cog: Containers for machine learning -Use Docker for machine learning, without all the pain. +Put your machine learning model in a standard, production-ready Docker container without having to know how Docker works. -Cog gives you a consistent environment to run your model in – for developing on your laptop, training on GPU machines, and for other people working on the model. Then, when the model is trained and you want to share or deploy it, you can bake the model into a Docker image that serves a standard HTTP API. +Cog does a few things for you: -Cog does a few handy things beyond normal Docker: +- **Automatic Docker image.** Define your environment with a simple configuration file, and Cog generates a `Dockerfile` with all the best practices. +- **Standard, production-ready HTTP and AMQP interface.** Automatically generate APIs for integrating with production systems, battle hardened on Replicate. +- **No more CUDA hell.** Cog knows which CUDA/cuDNN/PyTorch/Tensorflow/Python combos are compatible and will set it all up correctly for you. -- **Automatic Docker image.** Define your environment with a simple configuration file, then Cog will generate Dockerfiles with best practices and do all the GPU configuration for you. -- **Automatic HTTP service.** Cog will generate an HTTP service from the definition of your model, so you don't need to write a Flask server in the right way. -- **No more CUDA hell.** Cog knows which CUDA/cuDNN/PyTorch/Tensorflow/Python combos are compatible and will pick the right versions for you. - -## Develop and train in a consistent environment +## How it works Define the Docker environment your model runs in with `cog.yaml`: @@ -23,46 +21,24 @@ build: python_version: "3.8" python_packages: - "torch==1.8.1" +predict: "predict.py:Predictor" ``` -Now, you can run commands inside this environment: - -``` -$ cog run python train.py -... -``` - -This will: - -- Generate a `Dockerfile` with best practices -- Pick the right CUDA version -- Build an image -- Run `python train.py` in the image with the current directory mounted as a volume and GPUs hooked up correctly - - - -## Put a trained model in a Docker image - -First, you define how predictions are run on your model: +And define how predictions are run on your model with `predict.py`: ```python -import cog +from cog import Predictor, Input, Path import torch -class ColorizationPredictor(cog.Predictor): +class ColorizationPredictor(Predictor): def setup(self): """Load the model into memory to make running multiple predictions efficient""" self.model = torch.load("./weights.pth") - + # The arguments and types the model takes as input - @cog.input("input", type=cog.Path, help="Grayscale input image") - def predict(self, input): + def predict(self, + input: Path = Input(title="Grayscale input image") + ) -> Path: """Run a single prediction on the model""" processed_input = preprocess(input) output = self.model(processed_input) @@ -87,10 +63,34 @@ $ cog build -t my-colorization-model $ docker run -d -p 5000:5000 --gpus all my-colorization-model -$ curl http://localhost:5000/predict -X POST -F input=@image.png +$ curl http://localhost:5000/predictions -X POST \ + --data '{"input": "https://.../input.jpg"}' ``` -That's it! Your model will now run forever in this reproducible Docker environment. + + + + +## Deploying models to production + +Cog does a number of things out of the box to help you deploy models to production: + +- **Standard interface.** Put models inside Cog containers, and they'll run anywhere that runs Docker containers. +- **HTTP prediction server, based on FastAPI.** +- **Type checking, based on Pydantic.** Cog models define their input and output with JSON Schema, and the HTTP server is defined with OpenAPI. +- **AMQP RPC interface.** Long-running deep learning models or batch processing is best architected with a queue. Cog models can do this out of the box. +- **Read/write files from cloud storage.** Files can be read and written directly on Amazon S3 and Google Cloud Storage for efficiency. ## Why are we building this? @@ -129,10 +129,10 @@ sudo chmod +x /usr/local/bin/cog - [Get started with your own model](docs/getting-started-own-model.md) - [Take a look at some examples of using Cog](https://github.com/replicate/cog-examples) - [`cog.yaml` reference](docs/yaml.md) to learn how to define your model's environment -- [Prediction interface reference](docs/python.md) to learn how the `cog.Predictor` interface works +- [Prediction interface reference](docs/python.md) to learn how the `Predictor` interface works ## Need help? - + [Join us in #cog on Discord.](https://discord.gg/QmzJApGjyE) ## Contributors ✨ @@ -166,4 +166,4 @@ Thanks goes to these wonderful people ([emoji key](https://allcontributors.org/d -This project follows the [all-contributors](https://github.com/all-contributors/all-contributors) specification. Contributions of any kind welcome! \ No newline at end of file +This project follows the [all-contributors](https://github.com/all-contributors/all-contributors) specification. Contributions of any kind welcome! diff --git a/docs/getting-started-own-model.md b/docs/getting-started-own-model.md index a7e43be394..2d8dfca4da 100644 --- a/docs/getting-started-own-model.md +++ b/docs/getting-started-own-model.md @@ -67,7 +67,7 @@ The next step is to update `predict.py` to define the interface for running pred ```python import cog -from pathlib import Path +from cog import Path, Input import torch class Predictor(cog.Predictor): @@ -75,10 +75,10 @@ class Predictor(cog.Predictor): """Load the model into memory to make running multiple predictions efficient""" self.net = torch.load("weights.pth") - # Define the input types for a prediction - @cog.input("input", type=Path, help="Image to enlarge") - @cog.input("scale", type=float, default=1.5, help="Factor to scale image by") - def predict(self, input, scale): + def predict(self, + image: Path = Input(description="Image to enlarge"), + scale: float = Input(description="Factor to scale image by", default=1.5) + ) -> Path: """Run a single prediction on the model""" # ... pre-processing ... output = self.net(input) @@ -88,16 +88,25 @@ class Predictor(cog.Predictor): Edit your `predict.py` file and fill in the functions with your own model's setup and prediction code. You might need to import parts of your model from another file. -You also need to define the inputs to your model using the `@cog.input()` decorator, as demonstrated above. The first argument maps to the name of the argument in the `predict()` function, and it also takes these other arguments: +You also need to define the inputs to your model as arguments to the `predict()` function, as demonstrated above. For each argument, you need to annotate with a type. The supported types are: -- `type`: Either `str`, `int`, `float`, `bool`, or `Path` (be sure to add the import, as in the example above). `Path` is used for files. For more complex inputs, save it to a file and use `Path`. -- `help`: A description of what to pass to this input for users of the model +- `str`: a string +- `int`: an integer +- `float`: a floating point number +- `bool`: a boolean +- `cog.File`: a file-like object representing a file +- `cog.Path`: a path to a file on disk + +You can provide more information about the input with the `Input()` function, as shown above. It takes these basic arguments: + +- `description`: A description of what to pass to this input for users of the model - `default`: A default value to set the input to. If this argument is not passed, the input is required. If it is explicitly set to `None`, the input is optional. -- `min`: A minimum value for `int` or `float` types. -- `max`: A maximum value for `int` or `float` types. -- `options`: A list of values to limit the input to. It can be used with `str`, `int`, and `float` inputs. +- `gt`: For `int` or `float` types, the value should be greater than this number. +- `ge`: For `int` or `float` types, the value should be greater than or equal to this number. +- `lt`: For `int` or `float` types, the value should be less than this number. +- `le`: For `int` or `float` types, the value should be less than or equal to this number. -For more details about writing your model interface, [take a look at the prediction interface documentation](python.md). +There are some more advanced options you can pass, too. For more details, [take a look at the prediction interface documentation](python.md). Next, add the line `predict: "predict.py:Predictor"` to your `cog.yaml`, so it looks something like this: @@ -122,7 +131,7 @@ Written output to output.png To pass more inputs to the model, you can add more `-i` options: ``` -$ cog predict -i input=@input.jpg -i scale=2.0 +$ cog predict -i image=@image.jpg -i scale=2.0 ``` In this case it is just a number, not a file, so you don't need the `@` prefix. diff --git a/docs/getting-started.md b/docs/getting-started.md index aacca905e2..2f61c341af 100644 --- a/docs/getting-started.md +++ b/docs/getting-started.md @@ -59,10 +59,11 @@ First, run this to get some pre-trained model weights: Then, we need to write some code to describe how predictions are run on the model. Save this to `predict.py`: ```python +from typing import Any import cog -from pathlib import Path +from cog import Input, Path from tensorflow.keras.applications.resnet50 import ResNet50 -from tensorflow.keras.preprocessing import image +from tensorflow.keras.preprocessing import image as keras_image from tensorflow.keras.applications.resnet50 import preprocess_input, decode_predictions import numpy as np @@ -73,18 +74,17 @@ class ResNetPredictor(cog.Predictor): self.model = ResNet50(weights='resnet50_weights_tf_dim_ordering_tf_kernels.h5') # Define the arguments and types the model takes as input - @cog.input("input", type=Path, help="Image to classify") - def predict(self, input): + def predict(self, image: Path = Input(description="Image to classify")) -> Any: """Run a single prediction on the model""" # Preprocess the image - img = image.load_img(input, target_size=(224, 224)) - x = image.img_to_array(img) + img = keras_image.load_img(image, target_size=(224, 224)) + x = keras_image.img_to_array(img) x = np.expand_dims(x, axis=0) x = preprocess_input(x) # Run the prediction preds = self.model.predict(x) # Return the top 3 predictions - return str(decode_predictions(preds, top=3)[0]) + return decode_predictions(preds, top=3)[0] ``` We also need to point Cog at this, and tell it what Python dependencies to install. Update `cog.yaml` to look like this: diff --git a/docs/python.md b/docs/python.md index 0a1d9488fa..eb52458c41 100644 --- a/docs/python.md +++ b/docs/python.md @@ -4,18 +4,21 @@ You define how Cog runs predictions on your model by defining a class that inher ```python import cog -from pathlib import Path +from cog import Path, Input import torch class ImageScalingPredictor(cog.Predictor): def setup(self): + """Load the model into memory to make running multiple predictions efficient""" self.model = torch.load("weights.pth") - @cog.input("input", type=Path, help="Image to enlarge") - @cog.input("scale", type=float, default=1.5, help="Factor to scale image by") - def predict(self, input): + def predict(self, + image: Path = Input(description="Image to enlarge"), + scale: float = Input(description="Factor to scale image by", default=1.5) + ) -> Path: + """Run a single prediction on the model""" # ... pre-processing ... - output = self.model(input) + output = self.model(image) # ... post-processing ... return output ``` From 782d2f4e6e2c56672c221ba93e0583d35e79085e Mon Sep 17 00:00:00 2001 From: Ben Firshman Date: Tue, 11 Jan 2022 12:48:16 -0800 Subject: [PATCH 02/14] New Cog Signed-off-by: Ben Firshman --- go.mod | 4 +- go.sum | 16 +- pkg/cli/predict.go | 81 +++-- pkg/image/build.go | 21 +- pkg/image/openapi_schema.go | 66 ++++ pkg/image/type_signature.go | 74 ---- pkg/predict/input.go | 20 + pkg/predict/output.go | 14 - pkg/predict/predictor.go | 135 +++---- python/cog/__init__.py | 12 +- .../{type_signature.py => openapi_schema.py} | 9 +- python/cog/files.py | 29 ++ python/cog/input.py | 202 ---------- python/cog/json.py | 76 ++-- python/cog/predictor.py | 88 +++-- python/cog/response.py | 24 ++ python/cog/server/http.py | 182 +++++---- python/cog/types.py | 128 +++++++ python/setup.py | 9 +- python/tests/server/test_http.py | 242 ++++++++---- python/tests/server/test_http_input.py | 344 +++++++++--------- python/tests/server/test_http_output.py | 119 ++++-- python/tests/test_json.py | 59 +++ requirements-dev.txt | 7 +- .../fixtures/failing-project/predict.py | 6 +- .../fixtures/file-input-project/cog.yaml | 3 + .../fixtures/file-input-project/predict.py | 8 + .../fixtures/file-output-project/cog.yaml | 5 + .../fixtures/file-output-project/predict.py | 13 + .../fixtures/file-project/predict.py | 6 +- .../fixtures/int-project/cog.yaml | 3 + .../fixtures/int-project/predict.py | 6 + .../fixtures/logging-project/predict.py | 3 +- .../fixtures/string-project/predict.py | 3 +- .../subdirectory-project/my-subdir/predict.py | 3 +- .../fixtures/timeout-project/predict.py | 6 +- .../fixtures/yielding-project/predict.py | 8 +- .../yielding-timeout-project/predict.py | 10 +- .../test_integration/test_build.py | 24 +- .../test_integration/test_predict.py | 58 +++ 40 files changed, 1206 insertions(+), 920 deletions(-) create mode 100644 pkg/image/openapi_schema.go delete mode 100644 pkg/image/type_signature.go delete mode 100644 pkg/predict/output.go rename python/cog/command/{type_signature.py => openapi_schema.py} (78%) create mode 100644 python/cog/files.py delete mode 100644 python/cog/input.py create mode 100644 python/cog/response.py create mode 100644 python/cog/types.py create mode 100644 python/tests/test_json.py create mode 100644 test-integration/test_integration/fixtures/file-input-project/cog.yaml create mode 100644 test-integration/test_integration/fixtures/file-input-project/predict.py create mode 100644 test-integration/test_integration/fixtures/file-output-project/cog.yaml create mode 100644 test-integration/test_integration/fixtures/file-output-project/predict.py create mode 100644 test-integration/test_integration/fixtures/int-project/cog.yaml create mode 100644 test-integration/test_integration/fixtures/int-project/predict.py diff --git a/go.mod b/go.mod index 90092083d6..1b7c390caf 100644 --- a/go.mod +++ b/go.mod @@ -3,15 +3,14 @@ module github.com/replicate/cog go 1.16 require ( - github.com/TylerBrock/colorjson v0.0.0-20200706003622-8a50f05110d2 github.com/anaskhan96/soup v1.2.5 github.com/docker/cli v20.10.12+incompatible github.com/docker/docker v20.10.12+incompatible github.com/docker/docker-credential-helpers v0.6.4 // indirect github.com/docker/go-connections v0.4.0 // indirect github.com/docker/go-units v0.4.0 // indirect + github.com/getkin/kin-openapi v0.89.0 github.com/golangci/golangci-lint v1.44.0 - github.com/hokaccha/go-prettyjson v0.0.0-20210113012101-fb4e108d2519 // indirect github.com/logrusorgru/aurora v2.0.3+incompatible github.com/mattn/go-isatty v0.0.14 github.com/mattn/go-runewidth v0.0.13 // indirect @@ -22,6 +21,7 @@ require ( github.com/spf13/cobra v1.3.0 github.com/stretchr/objx v0.2.0 // indirect github.com/stretchr/testify v1.7.0 + github.com/vincent-petithory/dataurl v1.0.0 github.com/xeipuuv/gojsonschema v1.2.0 github.com/xeonx/timeago v1.0.0-rc4 golang.org/x/sys v0.0.0-20211216021012-1d35b9e2eb4e diff --git a/go.sum b/go.sum index 0e889416dc..ec56cd56bf 100644 --- a/go.sum +++ b/go.sum @@ -76,8 +76,6 @@ github.com/Masterminds/sprig v2.22.0+incompatible/go.mod h1:y6hNFY5UBTIWBxnzTeuN github.com/OneOfOne/xxhash v1.2.2/go.mod h1:HSdplMjZKSmBqAxg5vPj2TmRDmfkzw+cTzAElWljhcU= github.com/OpenPeeDeeP/depguard v1.1.0 h1:pjK9nLPS1FwQYGGpPxoMYpe7qACHOhAWQMQzV71i49o= github.com/OpenPeeDeeP/depguard v1.1.0/go.mod h1:JtAMzWkmFEzDPyAd+W0NHl1lvpQKTvT9jnRVsohBKpc= -github.com/TylerBrock/colorjson v0.0.0-20200706003622-8a50f05110d2 h1:ZBbLwSJqkHBuFDA6DUhhse0IGJ7T5bemHyNILUjvOq4= -github.com/TylerBrock/colorjson v0.0.0-20200706003622-8a50f05110d2/go.mod h1:VSw57q4QFiWDbRnjdX8Cb3Ow0SFncRw+bA/ofY6Q83w= github.com/alecthomas/template v0.0.0-20160405071501-a0175ee3bccc/go.mod h1:LOuyumcjzFXgccqObfd/Ljyb9UuFJ6TxHnclSeseNhc= github.com/alecthomas/template v0.0.0-20190718012654-fb15b899a751/go.mod h1:LOuyumcjzFXgccqObfd/Ljyb9UuFJ6TxHnclSeseNhc= github.com/alecthomas/units v0.0.0-20151022065526-2efee857e7cf/go.mod h1:ybxpYRFXyAe+OPACYpWeL0wqObRcbAqCMya13uyzqw0= @@ -222,6 +220,9 @@ github.com/fsnotify/fsnotify v1.5.1/go.mod h1:T3375wBYaZdLLcVNkcVbzGHY7f1l/uK5T5 github.com/fullstorydev/grpcurl v1.6.0/go.mod h1:ZQ+ayqbKMJNhzLmbpCiurTVlaK2M/3nqZCxaQ2Ze/sM= github.com/fzipp/gocyclo v0.4.0 h1:IykTnjwh2YLyYkGa0y92iTTEQcnyAz0r9zOo15EbJ7k= github.com/fzipp/gocyclo v0.4.0/go.mod h1:rXPyn8fnlpa0R2csP/31uerbiVBugk5whMdlyaLkLoA= +github.com/getkin/kin-openapi v0.89.0 h1:p4nagHchUKGn85z/f+pse4aSh50nIBOYjOhMIku2hiA= +github.com/getkin/kin-openapi v0.89.0/go.mod h1:660oXbgy5JFMKreazJaQTw7o+X00qeSyhcnluiMv+Xg= +github.com/ghodss/yaml v1.0.0 h1:wQHKEahhL6wmXdzwWG11gIVCkOv05bNOh+Rxn0yngAk= github.com/ghodss/yaml v1.0.0/go.mod h1:4dBDuWmgqj2HViK6kFavaiC9ZROes6MMH2rRYeMEF04= github.com/go-critic/go-critic v0.6.2 h1:L5SDut1N4ZfsWZY0sH4DCrsHLHnhuuWak2wa165t9gs= github.com/go-critic/go-critic v0.6.2/go.mod h1:td1s27kfmLpe5G/DPjlnFI7o1UCzePptwU7Az0V5iCM= @@ -233,6 +234,10 @@ github.com/go-kit/kit v0.9.0/go.mod h1:xBxKIO96dXMWWy0MnWVtmwkA9/13aqxPnvrjFYMA2 github.com/go-logfmt/logfmt v0.3.0/go.mod h1:Qt1PoO58o5twSAckw1HlFXLmHsOX5/0LbT9GBnD5lWE= github.com/go-logfmt/logfmt v0.4.0/go.mod h1:3RMwSq7FuexP4Kalkev3ejPJsZTpXXBr9+V4qmtdjCk= github.com/go-ole/go-ole v1.2.6/go.mod h1:pprOEPIfldk/42T2oK7lQ4v4JSDwmV0As9GaiUsvbm0= +github.com/go-openapi/jsonpointer v0.19.5 h1:gZr+CIYByUqjcgeLXnQu2gHYQC9o73G2XUeOFYEICuY= +github.com/go-openapi/jsonpointer v0.19.5/go.mod h1:Pl9vOtqEWErmShwVjC8pYs9cog34VGT37dQOVbmoatg= +github.com/go-openapi/swag v0.19.5 h1:lTz6Ys4CmqqCQmZPBlbQENR1/GucA2bzYTE12Pw4tFY= +github.com/go-openapi/swag v0.19.5/go.mod h1:POnQmlKehdgb5mhVOsnJFsivZCEZ/vjK9gh66Z9tfKk= github.com/go-redis/redis v6.15.8+incompatible/go.mod h1:NAIEuMOZ/fxfXJIrKDQDz8wamY7mA7PouImQ2Jvg6kA= github.com/go-sql-driver/mysql v1.4.0/go.mod h1:zAC/RDZ24gD3HViQzih4MyKcchzm+sOG5ZlKdlhCg5w= github.com/go-sql-driver/mysql v1.5.0/go.mod h1:DCzpHaOWr8IXmIStZouvnhqoel9Qv2LBy8hT2VhHyBg= @@ -444,8 +449,6 @@ github.com/hashicorp/memberlist v0.2.2/go.mod h1:MS2lj3INKhZjWNqd3N0m3J+Jxf3DAOn github.com/hashicorp/memberlist v0.3.0/go.mod h1:MS2lj3INKhZjWNqd3N0m3J+Jxf3DAOnAH9VT3Sh9MUE= github.com/hashicorp/serf v0.9.5/go.mod h1:UWDWwZeL5cuWDJdl0C6wrvrUwEqtQ4ZKBKKENpqIUyk= github.com/hashicorp/serf v0.9.6/go.mod h1:TXZNMjZQijwlDvp+r0b63xZ45H7JmCmgg4gpTwn9UV4= -github.com/hokaccha/go-prettyjson v0.0.0-20210113012101-fb4e108d2519 h1:nqAlWFEdqI0ClbTDrhDvE/8LeQ4pftrqKUX9w5k0j3s= -github.com/hokaccha/go-prettyjson v0.0.0-20210113012101-fb4e108d2519/go.mod h1:pFlLw2CfqZiIBOx6BuCeRLCrfxBJipTY0nIOF/VbGcI= github.com/hpcloud/tail v1.0.0/go.mod h1:ab1qPbhIpdTxEkNHXyeSf5vhxWSCs/tWer42PpOxQnU= github.com/huandu/xstrings v1.0.0/go.mod h1:4qWG/gcEcfX4z/mBDHJ++3ReCw9ibxbsNJbcucJdbSo= github.com/huandu/xstrings v1.2.0/go.mod h1:DvyZB1rfVYsBIigL8HwpZgxHwXozlTgGqn63UyNX5k4= @@ -534,6 +537,9 @@ github.com/lyft/protoc-gen-star v0.5.3/go.mod h1:V0xaHgaf5oCCqmcxYcWiDfTiKsZsRc8 github.com/magiconair/properties v1.8.0/go.mod h1:PppfXfuXeibc/6YijjN8zIbojt8czPbwD3XqdrwzmxQ= github.com/magiconair/properties v1.8.5 h1:b6kJs+EmPFMYGkow9GiUyCyOvIwYetYJ3fSaWak/Gls= github.com/magiconair/properties v1.8.5/go.mod h1:y3VJvCyxH9uVvJTWEGAELF3aiYNyPKd5NZ3oSwXrF60= +github.com/mailru/easyjson v0.0.0-20190614124828-94de47d64c63/go.mod h1:C1wdFJiN94OJF2b5HbByQZoLdCWB1Yqtg26g4irojpc= +github.com/mailru/easyjson v0.0.0-20190626092158-b2ccc519800e h1:hB2xlXdHp/pmPZq0y3QnmWAArdw9PqbmotexnWx/FU8= +github.com/mailru/easyjson v0.0.0-20190626092158-b2ccc519800e/go.mod h1:C1wdFJiN94OJF2b5HbByQZoLdCWB1Yqtg26g4irojpc= github.com/maratori/testpackage v1.0.1 h1:QtJ5ZjqapShm0w5DosRjg0PRlSdAdlx+W6cCKoALdbQ= github.com/maratori/testpackage v1.0.1/go.mod h1:ddKdw+XG0Phzhx8BFDTKgpWP4i7MpApTE5fXSKAqwDU= github.com/matoous/godox v0.0.0-20210227103229-6504466cf951 h1:pWxk9e//NbPwfxat7RXkts09K+dEBJWakUWwICVqYbA= @@ -826,6 +832,8 @@ github.com/valyala/fasthttp v1.30.0/go.mod h1:2rsYD01CKFrjjsvFxx75KlEUNpWNBY9JWD github.com/valyala/quicktemplate v1.7.0/go.mod h1:sqKJnoaOF88V07vkO+9FL8fb9uZg/VPSJnLYn+LmLk8= github.com/valyala/tcplisten v1.0.0/go.mod h1:T0xQ8SeCZGxckz9qRXTfG43PvQ/mcWh7FwZEA7Ioqkc= github.com/viki-org/dnscache v0.0.0-20130720023526-c70c1f23c5d8/go.mod h1:dniwbG03GafCjFohMDmz6Zc6oCuiqgH6tGNyXTkHzXE= +github.com/vincent-petithory/dataurl v1.0.0 h1:cXw+kPto8NLuJtlMsI152irrVw9fRDX8AbShPRpg2CI= +github.com/vincent-petithory/dataurl v1.0.0/go.mod h1:FHafX5vmDzyP+1CQATJn7WFKc9CvnvxyvZy6I1MrG/U= github.com/xeipuuv/gojsonpointer v0.0.0-20180127040702-4e3ac2762d5f h1:J9EGpcZtP0E/raorCMxlFGSTBrsSlaDGf3jU/qvAE2c= github.com/xeipuuv/gojsonpointer v0.0.0-20180127040702-4e3ac2762d5f/go.mod h1:N2zxlSyiKSe5eX1tZViRH5QA0qijqEDrYZiPEAiq3wU= github.com/xeipuuv/gojsonreference v0.0.0-20180127040603-bd5ef7bd5415 h1:EzJWgHovont7NscjpAxXsDA8S8BMYve8Y5+7cuRE7R0= diff --git a/pkg/cli/predict.go b/pkg/cli/predict.go index 31c3a7021a..503dacf76b 100644 --- a/pkg/cli/predict.go +++ b/pkg/cli/predict.go @@ -1,15 +1,15 @@ package cli import ( + "bytes" "encoding/json" "fmt" - "io" "os" "strings" - "github.com/TylerBrock/colorjson" "github.com/mitchellh/go-homedir" "github.com/spf13/cobra" + "github.com/vincent-petithory/dataurl" "github.com/replicate/cog/pkg/config" "github.com/replicate/cog/pkg/docker" @@ -110,7 +110,7 @@ func cmdPredict(cmd *cobra.Command, args []string) error { // FIXME: will not run on signal defer func() { - console.Infof("Stopping container...") + console.Debugf("Stopping container...") if err := predictor.Stop(); err != nil { console.Warnf("Failed to stop container: %s", err) } @@ -122,44 +122,60 @@ func cmdPredict(cmd *cobra.Command, args []string) error { func predictIndividualInputs(predictor predict.Predictor, inputFlags []string, outputPath string) error { console.Info("Running prediction...") inputs := parseInputFlags(inputFlags) - result, err := predictor.Predict(inputs) + prediction, err := predictor.Predict(inputs) + if err != nil { + return err + } + schema, err := predictor.GetSchema() if err != nil { return err } - // TODO(andreas): support multiple outputs? - output := result.Values["output"] - - // Write to stdout - if outputPath == "" { - // Is it something we can sensibly write to stdout? - if output.MimeType == "text/plain" { - output, err := io.ReadAll(output.Buffer) - if err != nil { - return err - } - console.Output(string(output)) - return nil - } else if output.MimeType == "application/json" { - var obj interface{} - dec := json.NewDecoder(output.Buffer) - if err := dec.Decode(&obj); err != nil { - return err - } - f := colorjson.NewFormatter() - f.Indent = 2 - s, _ := f.Marshal(obj) - console.Output(string(s)) - return nil + // Generate output depending on type in schema + var out []byte + outputSchema := schema.Components.Schemas["Response"].Value.Properties["output"].Value + if outputSchema.Type == "string" && outputSchema.Format == "uri" { + dataurlObj, err := dataurl.DecodeString((*prediction.Output).(string)) + if err != nil { + return fmt.Errorf("Failed to decode dataurl: %w", err) } - // Otherwise, fall back to writing file + out = dataurlObj.Data outputPath = "output" - extension := mime.ExtensionByType(output.MimeType) + extension := mime.ExtensionByType(dataurlObj.ContentType()) if extension != "" { outputPath += extension } + } else if outputSchema.Type == "string" { + // Handle strings separately because if we encode it to JSON it will be surrounded by quotes. + s := (*prediction.Output).(string) + out = []byte(s) + } else { + // Treat everything else as JSON -- ints, floats, bools will all convert correctly. + rawJSON, err := json.Marshal(prediction.Output) + if err != nil { + return fmt.Errorf("Failed to encode prediction output as JSON: %w", err) + } + var indentedJSON bytes.Buffer + if err := json.Indent(&indentedJSON, rawJSON, "", " "); err != nil { + return err + } + out = indentedJSON.Bytes() + + // FIXME: this stopped working + // f := colorjson.NewFormatter() + // f.Indent = 2 + // s, _ := f.Marshal(obj) + } + // Write to stdout + if outputPath == "" { + console.Output(string(out)) + return nil + } + + // Fall back to writing file + // Ignore @, to make it behave the same as -i outputPath = strings.TrimPrefix(outputPath, "@") @@ -174,7 +190,10 @@ func predictIndividualInputs(predictor predict.Predictor, inputFlags []string, o return err } - if _, err := io.Copy(outFile, output.Buffer); err != nil { + if _, err := outFile.Write(out); err != nil { + return err + } + if err := outFile.Close(); err != nil { return err } diff --git a/pkg/image/build.go b/pkg/image/build.go index 3ceaa4ff52..3ff80cf4db 100644 --- a/pkg/image/build.go +++ b/pkg/image/build.go @@ -38,14 +38,10 @@ func Build(cfg *config.Config, dir, imageName string, progressOutput string) err } console.Info("Adding labels to image...") - signature, err := GetTypeSignature(imageName, cfg.Build.GPU) + schema, err := GenerateOpenAPISchema(imageName, cfg.Build.GPU) if err != nil { return fmt.Errorf("Failed to get type signature: %w", err) } - signatureJSON, err := json.Marshal(signature) - if err != nil { - return fmt.Errorf("Failed to convert type signature to JSON: %w", err) - } configJSON, err := json.Marshal(cfg) if err != nil { return fmt.Errorf("Failed to convert config to JSON: %w", err) @@ -54,10 +50,19 @@ func Build(cfg *config.Config, dir, imageName string, progressOutput string) err // built image to get those. But, the escaping of JSON inside a label inside a Dockerfile was gnarly, and // doesn't seem to be a problem here, so do it here instead. labels := map[string]string{ - global.LabelNamespace + "cog_version": global.Version, - global.LabelNamespace + "config": string(bytes.TrimSpace(configJSON)), - global.LabelNamespace + "type_signature": string(signatureJSON), + global.LabelNamespace + "cog_version": global.Version, + global.LabelNamespace + "config": string(bytes.TrimSpace(configJSON)), + } + + // OpenAPI schema is not set if there is no predictor. + if len((*schema).(map[string]interface{})) != 0 { + schemaJSON, err := json.Marshal(schema) + if err != nil { + return fmt.Errorf("Failed to convert type signature to JSON: %w", err) + } + labels[global.LabelNamespace+"openapi_schema"] = string(schemaJSON) } + if err := docker.BuildAddLabelsToImage(imageName, labels); err != nil { return fmt.Errorf("Failed to add labels to image: %w", err) } diff --git a/pkg/image/openapi_schema.go b/pkg/image/openapi_schema.go new file mode 100644 index 0000000000..d8d6e7cd55 --- /dev/null +++ b/pkg/image/openapi_schema.go @@ -0,0 +1,66 @@ +package image + +import ( + "bytes" + "encoding/json" + "fmt" + + "github.com/getkin/kin-openapi/openapi3" + "github.com/replicate/cog/pkg/docker" + "github.com/replicate/cog/pkg/util/console" +) + +// GenerateOpenAPISchema by running the image and executing Cog +// This will be run as part of the build process then added as a label to the image. It can be retrieved more efficiently with the label by using GetOpenAPISchema +func GenerateOpenAPISchema(imageName string, enableGPU bool) (*interface{}, error) { + var stdout bytes.Buffer + var stderr bytes.Buffer + + // FIXME(bfirsh): we could detect this by reading the config label on the image + gpus := "" + if enableGPU { + gpus = "all" + } + + err := docker.RunWithIO(docker.RunOptions{ + Image: imageName, + Args: []string{ + "python", "-m", "cog.command.openapi_schema", + }, + GPUs: gpus, + }, nil, &stdout, &stderr) + + if enableGPU && err == docker.ErrMissingDeviceDriver { + console.Debug(stdout.String()) + console.Debug(stderr.String()) + console.Debug("Missing device driver, re-trying without GPU") + return GenerateOpenAPISchema(imageName, false) + } + + if err != nil { + console.Info(stdout.String()) + console.Info(stderr.String()) + return nil, err + } + var schema *interface{} + if err := json.Unmarshal(stdout.Bytes(), &schema); err != nil { + // Exit code was 0, but JSON was not returned. + // This is verbose, but print so anything that gets printed in Python bubbles up here. + console.Info(stdout.String()) + console.Info(stderr.String()) + return nil, err + } + return schema, nil +} + +func GetOpenAPISchema(imageName string) (*openapi3.T, error) { + image, err := docker.ImageInspect(imageName) + if err != nil { + return nil, fmt.Errorf("Failed to inspect %s: %w", imageName, err) + } + schemaString := image.Config.Labels["org.cogmodel.openapi_schema"] + if schemaString == "" { + return nil, fmt.Errorf("Image %s does not appear to be a Cog model", imageName) + } + return openapi3.NewLoader().LoadFromData([]byte(schemaString)) +} diff --git a/pkg/image/type_signature.go b/pkg/image/type_signature.go deleted file mode 100644 index 9e6867b24f..0000000000 --- a/pkg/image/type_signature.go +++ /dev/null @@ -1,74 +0,0 @@ -package image - -import ( - "bytes" - "encoding/json" - - "github.com/replicate/cog/pkg/docker" - "github.com/replicate/cog/pkg/util/console" -) - -type InputType string - -const ( - InputTypeString InputType = "str" - InputTypeInt InputType = "int" - InputTypeFloat InputType = "float" - InputTypeBool InputType = "bool" - InputTypePath InputType = "Path" -) - -type Input struct { - Name string `json:"name"` - Type InputType `json:"type,omitempty"` - Default *string `json:"default,omitempty"` - Min *string `json:"min,omitempty"` - Max *string `json:"max,omitempty"` - Options *[]string `json:"options,omitempty"` - Help *string `json:"help,omitempty"` -} - -type TypeSignature struct { - Inputs []Input `json:"inputs,omitempty"` -} - -func GetTypeSignature(imageName string, enableGPU bool) (*TypeSignature, error) { - var stdout bytes.Buffer - var stderr bytes.Buffer - - // FIXME(bfirsh): we could detect this by reading the config label on the image - gpus := "" - if enableGPU { - gpus = "all" - } - - err := docker.RunWithIO(docker.RunOptions{ - Image: imageName, - Args: []string{ - "python", "-m", "cog.command.type_signature", - }, - GPUs: gpus, - }, nil, &stdout, &stderr) - - if enableGPU && err == docker.ErrMissingDeviceDriver { - console.Debug(stdout.String()) - console.Debug(stderr.String()) - console.Debug("Missing device driver, re-trying without GPU") - return GetTypeSignature(imageName, false) - } - - if err != nil { - console.Info(stdout.String()) - console.Info(stderr.String()) - return nil, err - } - var signature *TypeSignature - if err := json.Unmarshal(stdout.Bytes(), &signature); err != nil { - // Exit code was 0, but JSON was not returned. - // This is verbose, but print so anything that gets printed in Python bubbles up here. - console.Info(stdout.String()) - console.Info(stderr.String()) - return nil, err - } - return signature, nil -} diff --git a/pkg/predict/input.go b/pkg/predict/input.go index f2b0f0c83f..6d1bda6a1c 100644 --- a/pkg/predict/input.go +++ b/pkg/predict/input.go @@ -1,11 +1,14 @@ package predict import ( + "io/ioutil" + "mime" "path/filepath" "strings" "github.com/mitchellh/go-homedir" "github.com/replicate/cog/pkg/util/console" + "github.com/vincent-petithory/dataurl" ) type Input struct { @@ -50,3 +53,20 @@ func NewInputsWithBaseDir(keyVals map[string]string, baseDir string) Inputs { } return input } + +func (inputs *Inputs) toMap() (map[string]string, error) { + keyVals := map[string]string{} + for key, input := range *inputs { + if input.String != nil { + keyVals[key] = *input.String + } else if input.File != nil { + content, err := ioutil.ReadFile(*input.File) + if err != nil { + return keyVals, err + } + mimeType := mime.TypeByExtension(filepath.Ext(*input.File)) + keyVals[key] = dataurl.New(content, mimeType).String() + } + } + return keyVals, nil +} diff --git a/pkg/predict/output.go b/pkg/predict/output.go deleted file mode 100644 index 3c630b4950..0000000000 --- a/pkg/predict/output.go +++ /dev/null @@ -1,14 +0,0 @@ -package predict - -import "io" - -type OutputValue struct { - Buffer io.Reader - MimeType string -} - -type Output struct { - Values map[string]OutputValue - SetupTime float64 - RunTime float64 -} diff --git a/pkg/predict/predictor.go b/pkg/predict/predictor.go index bb4494588b..d6d3b17867 100644 --- a/pkg/predict/predictor.go +++ b/pkg/predict/predictor.go @@ -6,20 +6,29 @@ import ( "fmt" "io" "math/rand" - "mime/multipart" "net/http" - "os" - "path/filepath" - "strconv" - "strings" "time" + "github.com/getkin/kin-openapi/openapi3" "github.com/replicate/cog/pkg/docker" "github.com/replicate/cog/pkg/global" "github.com/replicate/cog/pkg/util/console" "github.com/replicate/cog/pkg/util/shell" ) +type status string + +type Request struct { + // TODO: could this be Inputs? + Input map[string]string `json:"input"` +} + +type Response struct { + Status status `json:"status"` + Output *interface{} `json:"output"` + Error string `json:"error"` +} + type Predictor struct { runOptions docker.RunOptions @@ -29,6 +38,9 @@ type Predictor struct { } func NewPredictor(runOptions docker.RunOptions) Predictor { + if global.Debug { + runOptions.Env = append(runOptions.Env, "COG_DEBUG=1") + } return Predictor{runOptions: runOptions} } @@ -57,7 +69,7 @@ func (p *Predictor) Start(logsWriter io.Writer) error { } func (p *Predictor) waitForContainerReady() error { - url := fmt.Sprintf("http://localhost:%d/ping", p.port) + url := fmt.Sprintf("http://localhost:%d/", p.port) start := time.Now() for { @@ -91,46 +103,23 @@ func (p *Predictor) Stop() error { return docker.Stop(p.containerID) } -func (p *Predictor) Predict(inputs Inputs) (*Output, error) { - bodyBuffer := new(bytes.Buffer) - - mwriter := multipart.NewWriter(bodyBuffer) - for key, val := range inputs { - if val.File != nil { - w, err := mwriter.CreateFormFile(key, filepath.Base(*val.File)) - if err != nil { - return nil, err - } - file, err := os.Open(*val.File) - if err != nil { - return nil, err - } - if _, err := io.Copy(w, file); err != nil { - return nil, err - } - if err := file.Close(); err != nil { - return nil, err - } - } else { - w, err := mwriter.CreateFormField(key) - if err != nil { - return nil, err - } - if _, err = w.Write([]byte(*val.String)); err != nil { - return nil, err - } - } +func (p *Predictor) Predict(inputs Inputs) (*Response, error) { + inputMap, err := inputs.toMap() + if err != nil { + return nil, err } - if err := mwriter.Close(); err != nil { - return nil, fmt.Errorf("Failed to close form mime writer: %w", err) + request := Request{Input: inputMap} + requestBody, err := json.Marshal(request) + if err != nil { + return nil, err } - url := fmt.Sprintf("http://localhost:%d/predict", p.port) - req, err := http.NewRequest(http.MethodPost, url, bodyBuffer) + url := fmt.Sprintf("http://localhost:%d/predictions", p.port) + req, err := http.NewRequest(http.MethodPost, url, bytes.NewBuffer(requestBody)) if err != nil { return nil, fmt.Errorf("Failed to create HTTP request to %s: %w", url, err) } - req.Header.Set("Content-Type", mwriter.FormDataContentType()) + req.Header.Set("Content-Type", "application/json") req.Close = true httpClient := &http.Client{} @@ -140,6 +129,7 @@ func (p *Predictor) Predict(inputs Inputs) (*Output, error) { } defer resp.Body.Close() + // TODO if resp.StatusCode == http.StatusBadRequest { body := struct { Message string `json:"message"` @@ -154,67 +144,28 @@ func (p *Predictor) Predict(inputs Inputs) (*Output, error) { } if resp.StatusCode != http.StatusOK { - return nil, fmt.Errorf("/predict call returned status %d", resp.StatusCode) + return nil, fmt.Errorf("/predictions call returned status %d", resp.StatusCode) } - contentType := resp.Header.Get("Content-Type") - mimeType := strings.Split(contentType, ";")[0] - - buf := new(bytes.Buffer) - if _, err := io.Copy(buf, resp.Body); err != nil { - return nil, fmt.Errorf("Failed to read response: %w", err) + prediction := &Response{} + if err = json.NewDecoder(resp.Body).Decode(prediction); err != nil { + return nil, fmt.Errorf("Failed to decode prediction response: %w", err) } - - setupTime := -1.0 - runTime := -1.0 - setupTimeStr := resp.Header.Get("X-Setup-Time") - if setupTimeStr != "" { - setupTime, err = strconv.ParseFloat(setupTimeStr, 64) - if err != nil { - console.Errorf("Failed to parse setup time '%s' as float: %s", setupTimeStr, err) - } - } - runTimeStr := resp.Header.Get("X-Run-Time") - if runTimeStr != "" { - runTime, err = strconv.ParseFloat(runTimeStr, 64) - if err != nil { - console.Errorf("Failed to parse run time '%s' as float: %s", runTimeStr, err) - } - } - - output := &Output{ - Values: map[string]OutputValue{ - // TODO(andreas): support multiple outputs? - "output": { - Buffer: buf, - MimeType: mimeType, - }, - }, - SetupTime: setupTime, - RunTime: runTime, - } - return output, nil + return prediction, nil } -func (p *Predictor) Help() (*HelpResponse, error) { - req, err := http.NewRequest(http.MethodGet, fmt.Sprintf("http://localhost:%d/help", p.port), nil) +func (p *Predictor) GetSchema() (*openapi3.T, error) { + resp, err := http.Get(fmt.Sprintf("http://localhost:%d/openapi.json", p.port)) if err != nil { - return nil, fmt.Errorf("Failed to create GET request: %w", err) + return nil, err } - resp, err := http.DefaultClient.Do(req) - if err != nil { - return nil, fmt.Errorf("Failed to GET /help: %w", err) - } - defer resp.Body.Close() - if resp.StatusCode != http.StatusOK { - return nil, fmt.Errorf("/help call returned status %d", resp.StatusCode) + return nil, fmt.Errorf("Failed to get OpenAPI schema: %d", resp.StatusCode) } - help := new(HelpResponse) - if err := json.NewDecoder(resp.Body).Decode(help); err != nil { - return nil, fmt.Errorf("Failed to parse /help body: %w", err) + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, err } - - return help, nil + return openapi3.NewLoader().LoadFromData(body) } diff --git a/python/cog/__init__.py b/python/cog/__init__.py index 0aedf8697b..8ececd7369 100644 --- a/python/cog/__init__.py +++ b/python/cog/__init__.py @@ -1,14 +1,10 @@ -from pathlib import Path - - -from .server.redis_queue import RedisQueueWorker -from .input import input from .predictor import Predictor +from .types import File, Input, Path __all__ = [ - "Predictor", - "input", + "File", + "Input", "Path", - "RedisQueueWorker", + "Predictor", ] diff --git a/python/cog/command/type_signature.py b/python/cog/command/openapi_schema.py similarity index 78% rename from python/cog/command/type_signature.py rename to python/cog/command/openapi_schema.py index 36448af69c..3262b7fe66 100644 --- a/python/cog/command/type_signature.py +++ b/python/cog/command/openapi_schema.py @@ -4,13 +4,15 @@ This prints a JSON object describing the inputs of the model. """ import json + from ..errors import ConfigDoesNotExist, PredictorNotSet from ..suppress_output import suppress_output from ..predictor import load_predictor +from ..server.http import create_app if __name__ == "__main__": - obj = {} + schema = {} try: with suppress_output(): predictor = load_predictor() @@ -19,5 +21,6 @@ # Not an error, there just isn't anything. pass else: - obj = predictor.get_type_signature() - print(json.dumps(obj, indent=2)) + app = create_app(predictor) + schema = app.openapi() + print(json.dumps(schema, indent=2)) diff --git a/python/cog/files.py b/python/cog/files.py new file mode 100644 index 0000000000..03cd9b31b9 --- /dev/null +++ b/python/cog/files.py @@ -0,0 +1,29 @@ +import base64 +import io +import mimetypes +import os + +import requests + + +def upload_file(fh: io.IOBase, output_file_prefix: str = None) -> str: + fh.seek(0) + + if output_file_prefix is not None: + name = getattr(fh, "name", "output") + url = output_file_prefix + os.path.basename(name) + resp = requests.put(url, files={"file": fh}) + resp.raise_for_status() + return url + + b = fh.read() + # The file handle is strings, not bytes + if isinstance(b, str): + b = b.encode("utf-8") + encoded_body = base64.b64encode(b) + if getattr(fh, "name", None): + mime_type = mimetypes.guess_type(fh.name)[0] + else: + mime_type = "application/octet-stream" + s = encoded_body.decode("utf-8") + return f"data:{mime_type};base64,{s}" diff --git a/python/cog/input.py b/python/cog/input.py deleted file mode 100644 index b0063c2124..0000000000 --- a/python/cog/input.py +++ /dev/null @@ -1,202 +0,0 @@ -from dataclasses import dataclass -import functools -from numbers import Number -import os -from pathlib import Path -import shutil -import tempfile -from typing import Any, Optional, List, Callable, Dict, Type - -from werkzeug.datastructures import FileStorage - -from .predictor import Predictor - -_VALID_INPUT_TYPES = frozenset([str, int, float, bool, Path]) -UNSPECIFIED = object() - - -@dataclass -class InputSpec: - name: str - type: Type - default: Any = UNSPECIFIED - min: Optional[Number] = None - max: Optional[Number] = None - options: Optional[List[Any]] = None - help: Optional[str] = None - - -class InputValidationError(Exception): - pass - - -def input(name, type, default=UNSPECIFIED, min=None, max=None, options=None, help=None): - """ - A decorator that defines an input for a predict() method. - """ - type_name = get_type_name(type) - if type not in _VALID_INPUT_TYPES: - type_list = ", ".join([type_name(t) for t in _VALID_INPUT_TYPES]) - raise ValueError( - f"{type_name} is not a valid input type. Valid types are: {type_list}" - ) - if (min is not None or max is not None) and not _is_numeric_type(type): - raise ValueError(f"Non-numeric type {type_name} cannot have min and max values") - - if options is not None and type == Path: - raise ValueError(f"File type cannot have options") - - if options is not None and len(options) < 2: - raise ValueError(f"Options list must have at least two items") - - def wrapper(f): - if not hasattr(f, "_inputs"): - f._inputs = [] - - if name in (i.name for i in f._inputs): - raise ValueError(f"{name} is already defined as an argument") - - if type == Path and default is not UNSPECIFIED and default is not None: - raise TypeError("Cannot use default with Path type") - - # Insert at start of list because decorators are run bottom up - f._inputs.insert( - 0, - InputSpec( - name=name, - type=type, - default=default, - min=min, - max=max, - options=options, - help=help, - ), - ) - - @functools.wraps(f) - def wraps(self, **kwargs): - if not isinstance(self, Predictor): - raise TypeError("{self} is not an instance of cog.Predictor") - return f(self, **kwargs) - - return wraps - - return wrapper - - -def get_type_name(typ: Type) -> str: - if typ == str: - return "str" - if typ == int: - return "int" - if typ == float: - return "float" - if typ == bool: - return "bool" - if typ == Path: - return "Path" - return str(typ) - - -def _is_numeric_type(typ: Type) -> bool: - return typ in (int, float) - - -def validate_and_convert_inputs( - predictor: Predictor, raw_inputs: Dict[str, Any], cleanup_functions: List[Callable] -) -> Dict[str, Any]: - input_specs = predictor.predict._inputs - inputs = {} - - for input_spec in input_specs: - if input_spec.name in raw_inputs: - val = raw_inputs[input_spec.name] - - if input_spec.type == Path: - if not isinstance(val, FileStorage): - raise InputValidationError( - f"Could not convert file input {input_spec.name} to {get_type_name(input_spec.type)}", - ) - if val.filename is None: - raise InputValidationError( - f"No filename is provided for file input {input_spec.name}" - ) - - temp_dir = tempfile.mkdtemp() - cleanup_functions.append(lambda: shutil.rmtree(temp_dir)) - - temp_path = os.path.join(temp_dir, val.filename) - with open(temp_path, "wb") as f: - f.write(val.stream.read()) - converted = Path(temp_path) - - elif input_spec.type == int: - try: - converted = int(val) - except ValueError: - raise InputValidationError( - f"Could not convert {input_spec.name}={val} to int" - ) - - elif input_spec.type == float: - try: - converted = float(val) - except ValueError: - raise InputValidationError( - f"Could not convert {input_spec.name}={val} to float" - ) - - elif input_spec.type == bool: - if val.lower() not in ["true", "false"]: - raise InputValidationError( - f"{input_spec.name}={val} is not a boolean" - ) - converted = val.lower() == "true" - - elif input_spec.type == str: - if isinstance(val, FileStorage): - raise InputValidationError( - f"Could not convert file input {input_spec.name} to str" - ) - converted = val - - else: - raise TypeError( - f"Internal error: Input type {input_spec} is not a valid input type" - ) - - if _is_numeric_type(input_spec.type): - if input_spec.max is not None and converted > input_spec.max: - raise InputValidationError( - f"Value {converted} is greater than the max value {input_spec.max}" - ) - if input_spec.min is not None and converted < input_spec.min: - raise InputValidationError( - f"Value {converted} is less than the min value {input_spec.min}" - ) - - if input_spec.options is not None: - if converted not in input_spec.options: - valid_str = ", ".join([str(o) for o in input_spec.options]) - raise InputValidationError( - f"Value {converted} is not an option. Valid options are: {valid_str}" - ) - - else: - if input_spec.default is not UNSPECIFIED: - converted = input_spec.default - else: - raise InputValidationError( - f"Missing expected argument: {input_spec.name}" - ) - inputs[input_spec.name] = converted - - expected_names = set(s.name for s in input_specs) - raw_keys = set(raw_inputs.keys()) - extraneous_keys = raw_keys - expected_names - if extraneous_keys: - raise InputValidationError( - f"Extraneous input keys: {', '.join(extraneous_keys)}" - ) - - return inputs diff --git a/python/cog/json.py b/python/cog/json.py index 9a48cae7db..e179800ae4 100644 --- a/python/cog/json.py +++ b/python/cog/json.py @@ -1,46 +1,44 @@ -import json +from enum import Enum +import io -# Based on keepsake.json +from pydantic import BaseModel + +from .types import Path -# We load numpy but not torch or tensorflow because numpy loads very fast and -# they're probably using it anyway -# fmt: off try: import numpy as np # type: ignore + has_numpy = True except ImportError: has_numpy = False -# fmt: on - -# Tensorflow takes a solid 10 seconds to import on a modern Macbook Pro, so instead of importing, -# do this instead -def _is_tensorflow_tensor(obj): - # e.g. __module__='tensorflow.python.framework.ops', __name__='EagerTensor' - return ( - obj.__class__.__module__.split(".")[0] == "tensorflow" - and "Tensor" in obj.__class__.__name__ - ) - - -def _is_torch_tensor(obj): - return (obj.__class__.__module__, obj.__class__.__name__) == ("torch", "Tensor") - - -class CustomJSONEncoder(json.JSONEncoder): - def default(self, o): - if has_numpy: - if isinstance(o, np.integer): - return int(o) - elif isinstance(o, np.floating): - return float(o) - elif isinstance(o, np.ndarray): - return o.tolist() - if _is_torch_tensor(o): - return o.detach().tolist() - if _is_tensorflow_tensor(o): - return o.numpy().tolist() - return json.JSONEncoder.default(self, o) - - -def to_json(obj): - return json.dumps(obj, cls=CustomJSONEncoder) + + +def encode_json(obj, upload_file): + """ + Returns a JSON-compatible version of the object. It will encode any Pydantic models and custom types. + + When a file is encountered, it will be passed to upload_file. Any paths will be opened and converted to files. + + Somewhat based on FastAPI's jsonable_encoder(). + """ + if isinstance(obj, BaseModel): + return encode_json(obj.dict(exclude_unset=True), upload_file) + if isinstance(obj, dict): + return {key: encode_json(value, upload_file) for key, value in obj.items()} + if isinstance(obj, list): + return [encode_json(value, upload_file) for value in obj] + if isinstance(obj, Enum): + return obj.value + if isinstance(obj, Path): + with obj.open("rb") as f: + return upload_file(f) + if isinstance(obj, io.IOBase): + return upload_file(obj) + if has_numpy: + if isinstance(obj, np.integer): + return int(obj) + if isinstance(obj, np.floating): + return float(obj) + if isinstance(obj, np.ndarray): + return obj.tolist() + return obj diff --git a/python/cog/predictor.py b/python/cog/predictor.py index a7abd3c1cb..c7130fdbed 100644 --- a/python/cog/predictor.py +++ b/python/cog/predictor.py @@ -1,17 +1,19 @@ from abc import ABC, abstractmethod +from collections.abc import Generator import importlib +import inspect import os.path from pathlib import Path -from typing import Dict, Any +import typing +from pydantic import create_model +from pydantic.fields import FieldInfo +# Added in Python 3.8. Can be from typing if we drop support for <3.8. +from typing_extensions import Literal, get_origin, get_args import yaml from .errors import ConfigDoesNotExist, PredictorNotSet - - -# TODO(andreas): handle directory input -# TODO(andreas): handle List[Dict[str, int]], etc. -# TODO(andreas): model-level documentation +from .types import Input class Predictor(ABC): @@ -22,36 +24,6 @@ def setup(self): def predict(self, **kwargs): pass - def get_type_signature(self): - """ - Returns a dict describing the inputs of the model. - """ - from .input import ( - get_type_name, - UNSPECIFIED, - ) - - inputs = [] - if hasattr(self.predict, "_inputs"): - input_specs = self.predict._inputs - for spec in input_specs: - arg: Dict[str, Any] = { - "name": spec.name, - "type": get_type_name(spec.type), - } - if spec.help: - arg["help"] = spec.help - if spec.default is not UNSPECIFIED: - arg["default"] = str(spec.default) # TODO: don't string this - if spec.min is not None: - arg["min"] = str(spec.min) # TODO: don't string this - if spec.max is not None: - arg["max"] = str(spec.max) # TODO: don't string this - if spec.options is not None: - arg["options"] = [str(o) for o in spec.options] - inputs.append(arg) - return {"inputs": inputs} - def run_prediction(predictor, inputs, cleanup_functions): """ @@ -88,3 +60,47 @@ def load_predictor(): spec.loader.exec_module(module) predictor_class = getattr(module, class_name) return predictor_class() + + +def get_input_type(predictor: Predictor): + signature = inspect.signature(predictor.predict) + create_model_kwargs = {} + + order = 0 + + for name, parameter in signature.parameters.items(): + if not parameter.annotation: + # TODO: perhaps should throw error if there are arguments not annotated? + continue + + # if no default is specified, create an empty, required input + if parameter.default is inspect.Signature.empty: + default = Input() + else: + default = parameter.default + # If user hasn't used `Input`, then wrap it in that + if not isinstance(default, FieldInfo): + default = Input(default=default) + + # Fields aren't ordered, so use this pattern to ensure defined order + # https://github.com/go-openapi/spec/pull/116 + default.extra["x-order"] = order + order += 1 + + create_model_kwargs[name] = (parameter.annotation, default) + + return create_model("Input", **create_model_kwargs) + + +def get_output_type(predictor: Predictor): + signature = inspect.signature(predictor.predict) + if signature.return_annotation is inspect.Signature.empty: + OutputType = Literal[None] + else: + OutputType = signature.return_annotation + + # The type that goes in the response is the type that is yielded + if get_origin(OutputType) is Generator: + OutputType = get_args(OutputType)[0] + + return OutputType diff --git a/python/cog/response.py b/python/cog/response.py new file mode 100644 index 0000000000..982caf335d --- /dev/null +++ b/python/cog/response.py @@ -0,0 +1,24 @@ +import enum +from typing import Any + +from pydantic import BaseModel, Field + + +class Status(enum.Enum): + PROCESSING = "processing" + SUCCESS = "success" + FAILED = "failed" # FIXME: "failure"? + + +def get_response_type(OutputType: Any): + class Response(BaseModel): + """The status of a prediction.""" + + status: Status = Field(...) + output: OutputType = None + error: str = None + + class Config: + arbitrary_types_allowed = True + + return Response diff --git a/python/cog/server/http.py b/python/cog/server/http.py index 0b14c60a2c..4277266c81 100644 --- a/python/cog/server/http.py +++ b/python/cog/server/http.py @@ -1,106 +1,102 @@ -from pathlib import Path -import sys -import time +import logging +import os import types -from flask import Flask, send_file, request, jsonify, Response - -from ..input import ( - validate_and_convert_inputs, - InputValidationError, -) -from ..json import to_json -from ..predictor import Predictor, run_prediction, load_predictor - - -class HTTPServer: - def __init__(self, predictor: Predictor): - self.predictor = predictor - - def make_app(self) -> Flask: - start_time = time.time() - self.predictor.setup() - app = Flask(__name__) - setup_time = time.time() - start_time - - @app.route("/predict", methods=["POST"]) - @app.route("/infer", methods=["POST"]) # deprecated - def handle_request(): - start_time = time.time() - - cleanup_functions = [] - try: - raw_inputs = {} - for key, val in request.form.items(): - raw_inputs[key] = val - for key, val in request.files.items(): - if key in raw_inputs: - return _abort400( - f"Duplicated argument name in form and files: {key}" - ) - raw_inputs[key] = val - - if hasattr(self.predictor.predict, "_inputs"): - try: - inputs = validate_and_convert_inputs( - self.predictor, raw_inputs, cleanup_functions - ) - except InputValidationError as e: - return _abort400(str(e)) - else: - inputs = raw_inputs - - result = run_prediction(self.predictor, inputs, cleanup_functions) - run_time = time.time() - start_time - return self.create_response(result, setup_time, run_time) - finally: - for cleanup_function in cleanup_functions: - try: - cleanup_function() - except Exception as e: - sys.stderr.write(f"Cleanup function caught error: {e}") - - @app.route("/ping") - def ping(): - return "PONG" - - @app.route("/type-signature") - def type_signature(): - return jsonify(self.predictor.get_type_signature()) - - return app - - def start_server(self): - app = self.make_app() - app.run(host="0.0.0.0", port=5000, threaded=False, processes=1) - - def create_response(self, result, setup_time, run_time): +from fastapi import Body, FastAPI, HTTPException +from fastapi.responses import JSONResponse +from pydantic import BaseModel, ValidationError +import uvicorn + + +from ..files import upload_file +from ..json import encode_json +from ..predictor import Predictor, get_input_type, get_output_type, load_predictor +from ..response import Status, get_response_type + +logger = logging.getLogger("cog") + + +def create_app(predictor: Predictor) -> FastAPI: + app = FastAPI( + title="Cog", # TODO: mention model name? + # version=None # TODO + ) + app.on_event("startup")(predictor.setup) + + @app.get("/") + def root(): + return { + # "cog_version": "", # TODO + "docs_url": "/docs", + "openapi_url": "/openapi.json", + } + + InputType = get_input_type(predictor) + + class Request(BaseModel): + input: InputType = None + output_file_prefix: str = None + + def predict(request: Request = Body(default=None)): + if request is None or request.input is None: + output = predictor.predict() + else: + output = predictor.predict(**request.input.dict()) + output_file_prefix = None + if request: + output_file_prefix = request.output_file_prefix + # loop over generator function to get the last result - if isinstance(result, types.GeneratorType): + if isinstance(output, types.GeneratorType): last_result = None - for iteration in enumerate(result): + for iteration in enumerate(output): last_result = iteration # last result is a tuple with (index, value) - result = last_result[1] + output = last_result[1] - if isinstance(result, Path): - resp = send_file(str(result)) - elif isinstance(result, str): - resp = Response(result, mimetype="text/plain") - else: - resp = Response(to_json(result), mimetype="application/json") - resp.headers["X-Setup-Time"] = setup_time - resp.headers["X-Run-Time"] = run_time - return resp + OutputType = get_output_type(predictor) + Response = get_response_type(OutputType) + + try: + response = Response(status=Status.SUCCESS, output=output) + except ValidationError as e: + logger.error( + f"""The return value of predict() was not valid: + +{e} + +Check that your predict function is in this form, where `output_type` is the same as the type you are returning (e.g. `str`): + + def predict(...) -> output_type: + ... +""" + ) + raise HTTPException(status_code=500) + encoded_response = encode_json( + response, upload_file=lambda fh: upload_file(fh, output_file_prefix) + ) + return JSONResponse(content=encoded_response) + # response_model is purely for generating schema. + # We generate Response again in the request so we can set file output paths correctly, etc. + OutputType = get_output_type(predictor) + app.post( + "/predictions", + response_model=get_response_type(OutputType), + response_model_exclude_unset=True, + )(predict) -def _abort400(message): - resp = jsonify({"message": message}) - resp.status_code = 400 - return resp + return app if __name__ == "__main__": predictor = load_predictor() - server = HTTPServer(predictor) - server.start_server() + app = create_app(predictor) + uvicorn.run( + app, + host="0.0.0.0", + port=5000, + log_level="debug" if os.environ.get("COG_DEBUG") else "warning", + # Single worker to safely run on GPUs. + workers=1, + ) diff --git a/python/cog/types.py b/python/cog/types.py new file mode 100644 index 0000000000..e502f1f4f4 --- /dev/null +++ b/python/cog/types.py @@ -0,0 +1,128 @@ +import io +import mimetypes +import os +import base64 +import pathlib +import requests +import shutil +import tempfile +from typing import Any, Optional +from urllib.parse import urlparse + +from pydantic import Field +from pydantic.typing import NoArgAnyCallable + + +def Input( + default=..., + default_factory: Optional[NoArgAnyCallable] = None, + alias: str = None, + title: str = None, + description: str = None, + const: bool = None, + gt: float = None, + ge: float = None, + lt: float = None, + le: float = None, + multiple_of: float = None, + min_items: int = None, + max_items: int = None, + min_length: int = None, + max_length: int = None, + allow_mutation: bool = True, + regex: str = None, + **kwargs: Any, +): + """Input is similar to pydantic.Field, but doesn't require a default value to be the first argument.""" + return Field( + default, + default_factory=default_factory, + alias=alias, + title=title, + description=description, + const=const, + gt=gt, + ge=ge, + lt=lt, + le=le, + multiple_of=multiple_of, + min_items=min_items, + max_items=max_items, + min_length=min_length, + max_length=max_length, + allow_mutation=allow_mutation, + regex=regex, + **kwargs, + ) + + +def get_filename(url): + parsed_url = urlparse(url) + if parsed_url.scheme == "data": + header, _ = parsed_url.path.split(",", 1) + mime_type, _ = header.split(";", 1) + return "file" + mimetypes.guess_extension(mime_type) + return os.path.basename(parsed_url.path) + + +class File(io.IOBase): + validate_always = True + + @classmethod + def __get_validators__(cls): + yield cls.validate + + @classmethod + def validate(cls, value: Any) -> io.IOBase: + if isinstance(value, io.IOBase): + return value + + parsed_url = urlparse(value) + if parsed_url.scheme == "data": + # https://developer.mozilla.org/en-US/docs/Web/HTTP/Basics_of_HTTP/Data_URIs + # TODO: decode properly. this maybe? https://github.com/fcurella/python-datauri/ + header, encoded = parsed_url.path.split(",", 1) + return io.BytesIO(base64.b64decode(encoded)) + elif parsed_url.scheme == "http" or parsed_url.scheme == "https": + resp = requests.get(value, stream=True) + resp.raise_for_status() + resp.raw.decode_content = True + return resp.raw + else: + raise ValueError( + f"'{parsed_url.scheme}' is not a valid URL scheme. 'data', 'http', or 'https' is supported." + ) + + @classmethod + def __modify_schema__(cls, field_schema): + """Defines what this type should be in openapi.json""" + # https://json-schema.org/understanding-json-schema/reference/string.html#uri-template + field_schema.update(type="string", format="uri") + + +class Path(pathlib.PosixPath): + validate_always = True + + @classmethod + def __get_validators__(cls): + yield cls.validate + + @classmethod + def validate(cls, value: Any) -> pathlib.Path: + if isinstance(value, pathlib.Path): + return value + + src = File.validate(value) + # TODO: cleanup! + temp_dir = tempfile.mkdtemp() + temp_path = os.path.join(temp_dir, get_filename(value)) + with open(temp_path, "wb") as dest: + shutil.copyfileobj(src, dest) + + return cls(dest.name) + + @classmethod + def __modify_schema__(cls, field_schema): + """Defines what this type should be in openapi.json""" + # https://json-schema.org/understanding-json-schema/reference/string.html#uri-template + field_schema.update(type="string", format="uri") diff --git a/python/setup.py b/python/setup.py index 8f9f54a417..d60fccb9e3 100644 --- a/python/setup.py +++ b/python/setup.py @@ -15,11 +15,14 @@ license="Apache License 2.0", python_requires=">=3.6.0", install_requires=[ - # intionally loose. perhaps these should be vendored to not collide with user code? - "flask>=2,<3", + # intentionally loose. perhaps these should be vendored to not collide with user code? + "fastapi>=0.6,<1", + "pydantic>=1,<2", + "PyYAML", "redis>=4,<5", "requests>=2,<3", - "PyYAML", + "typing_extensions>=3.7.4", + "uvicorn[standard]>=0.12,<1", ], packages=setuptools.find_packages(), ) diff --git a/python/tests/server/test_http.py b/python/tests/server/test_http.py index 618f36f669..043fa1e08c 100644 --- a/python/tests/server/test_http.py +++ b/python/tests/server/test_http.py @@ -1,19 +1,23 @@ -from flask.testing import FlaskClient +import base64 import io import os -from pathlib import Path import tempfile +from typing import Generator from unittest import mock +from fastapi.testclient import TestClient from PIL import Image +import pytest + import cog -from cog.server.http import HTTPServer +from cog import Input, File, Path + +from cog.server.http import create_app -def make_client(version) -> FlaskClient: - app = HTTPServer(version).make_app() - app.config["TESTING"] = True - with app.test_client() as client: +def make_client(predictor: cog.Predictor, **kwargs) -> TestClient: + app = create_app(predictor) + with TestClient(app, **kwargs) as client: return client @@ -22,89 +26,192 @@ class Predictor(cog.Predictor): def setup(self): self.foo = "bar" - def predict(self): + def predict(self) -> str: return self.foo client = make_client(Predictor()) - resp = client.post("/predict") + resp = client.post("/predictions") assert resp.status_code == 200 - assert resp.data == b"bar" + assert resp.json() == {"status": "success", "output": "bar"} -def test_type_signature(): +def test_openapi_specification(): class Predictor(cog.Predictor): - @cog.input("text", type=str, help="Some text") - @cog.input("num1", type=int, help="First number") - @cog.input("num2", type=int, default=10, help="Second number") - @cog.input("path", type=Path, help="A file path") - def predict(self, text, num1, num2, path): + def predict( + self, + no_default: str, + default_without_input: str = "default", + input_with_default: int = Input(title="Some number", default=10), + path: Path = Input(title="Some path"), + image: File = Input(title="Some path"), + ) -> str: pass client = make_client(Predictor()) - resp = client.get("/type-signature") + resp = client.get("/openapi.json") assert resp.status_code == 200 - assert resp.json == { - "inputs": [ - { - "name": "text", - "type": "str", - "help": "Some text", - }, - { - "name": "num1", - "type": "int", - "help": "First number", + print(resp.json()) + assert resp.json() == { + "openapi": "3.0.2", + "info": {"title": "Cog", "version": "0.1.0"}, + "paths": { + "/": { + "get": { + "summary": "Root", + "operationId": "root__get", + "responses": { + "200": { + "description": "Successful Response", + "content": {"application/json": {"schema": {}}}, + } + }, + } }, - { - "name": "num2", - "type": "int", - "help": "Second number", - "default": "10", + "/predictions": { + "post": { + "summary": "Predict", + "operationId": "predict_predictions_post", + "requestBody": { + "content": { + "application/json": { + "schema": {"$ref": "#/components/schemas/Request"} + } + } + }, + "responses": { + "200": { + "description": "Successful Response", + "content": { + "application/json": { + "schema": {"$ref": "#/components/schemas/Response"} + } + }, + }, + "422": { + "description": "Validation Error", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/HTTPValidationError" + } + } + }, + }, + }, + } }, - { - "name": "path", - "type": "Path", - "help": "A file path", - }, - ] + }, + "components": { + "schemas": { + "HTTPValidationError": { + "title": "HTTPValidationError", + "type": "object", + "properties": { + "detail": { + "title": "Detail", + "type": "array", + "items": {"$ref": "#/components/schemas/ValidationError"}, + } + }, + }, + "Input": { + "title": "Input", + "required": ["no_default", "path", "image"], + "type": "object", + "properties": { + "no_default": { + "title": "No Default", + "type": "string", + "x-order": 0, + }, + "default_without_input": { + "title": "Default Without Input", + "type": "string", + "default": "default", + "x-order": 1, + }, + "input_with_default": { + "title": "Some number", + "type": "integer", + "default": 10, + "x-order": 2, + }, + "path": { + "title": "Some path", + "type": "string", + "format": "uri", + "x-order": 3, + }, + "image": { + "title": "Some path", + "type": "string", + "format": "uri", + "x-order": 4, + }, + }, + }, + "Request": { + "title": "Request", + "type": "object", + "properties": { + "input": {"$ref": "#/components/schemas/Input"}, + "output_file_prefix": { + "title": "Output File Prefix", + "type": "string", + }, + }, + }, + "Response": { + "title": "Response", + "required": ["status"], + "type": "object", + "properties": { + "status": {"$ref": "#/components/schemas/Status"}, + "output": {"title": "Output", "type": "string"}, + "error": {"title": "Error", "type": "string"}, + }, + "description": "The status of a prediction.", + }, + "Status": { + "title": "Status", + "enum": ["processing", "success", "failed"], + "description": "An enumeration.", + }, + "ValidationError": { + "title": "ValidationError", + "required": ["loc", "msg", "type"], + "type": "object", + "properties": { + "loc": { + "title": "Location", + "type": "array", + "items": {"type": "string"}, + }, + "msg": {"title": "Message", "type": "string"}, + "type": {"title": "Error Type", "type": "string"}, + }, + }, + } + }, } def test_yielding_strings_from_generator_predictors(): class Predictor(cog.Predictor): - def predict(self): + def predict(self) -> Generator[str, None, None]: predictions = ["foo", "bar", "baz"] for prediction in predictions: yield prediction client = make_client(Predictor()) - resp = client.post("/predict") - assert resp.status_code == 200 - assert resp.content_type == "text/plain; charset=utf-8" - assert resp.data == b"baz" - - -def test_yielding_json_from_generator_predictors(): - class Predictor(cog.Predictor): - def predict(self): - predictions = [ - {"meaning_of_life": 40}, - {"meaning_of_life": 41}, - {"meaning_of_life": 42}, - ] - for prediction in predictions: - yield prediction - - client = make_client(Predictor()) - resp = client.post("/predict") + resp = client.post("/predictions") assert resp.status_code == 200 - assert resp.content_type == "application/json" - assert resp.data == b'{"meaning_of_life": 42}' + assert resp.json() == {"status": "success", "output": "baz"} def test_yielding_files_from_generator_predictors(): class Predictor(cog.Predictor): - def predict(self): + def predict(self) -> Generator[cog.Path, None, None]: colors = ["red", "blue", "yellow"] for i, color in enumerate(colors): temp_dir = tempfile.mkdtemp() @@ -114,16 +221,17 @@ def predict(self): yield Path(temp_path) client = make_client(Predictor()) - resp = client.post("/predict") + resp = client.post("/predictions") assert resp.status_code == 200 - # need both image/bmp and image/x-ms-bmp until https://bugs.python.org/issue44211 is fixed - assert resp.content_type in ["image/bmp", "image/x-ms-bmp"] - image = Image.open(io.BytesIO(resp.data)) + header, b64data = resp.json()["output"].split(",", 1) + image = Image.open(io.BytesIO(base64.b64decode(b64data))) image_color = Image.Image.getcolors(image)[0][1] assert image_color == (255, 255, 0) # yellow +# TODO: timing +@pytest.mark.skip @mock.patch("time.time", return_value=0.0) def test_timing(time_mock): class Predictor(cog.Predictor): @@ -135,7 +243,7 @@ def predict(self): return "" client = make_client(Predictor()) - resp = client.post("/predict") + resp = client.post("/predictions") assert resp.status_code == 200 assert float(resp.headers["X-Setup-Time"]) == 1.0 assert float(resp.headers["X-Run-Time"]) == 2.0 diff --git a/python/tests/server/test_http_input.py b/python/tests/server/test_http_input.py index 9080835707..bb1db2ada6 100644 --- a/python/tests/server/test_http_input.py +++ b/python/tests/server/test_http_input.py @@ -1,281 +1,269 @@ -import io +import base64 +from enum import Enum import os -from pathlib import Path +import tempfile -import pytest +from PIL import Image +from pydantic import BaseModel +import responses import cog +from cog import Input, Path, File from .test_http import make_client def test_no_input(): class Predictor(cog.Predictor): - def predict(self): + def predict(self) -> str: return "foobar" client = make_client(Predictor()) - resp = client.post("/predict") + resp = client.post("/predictions") assert resp.status_code == 200 - assert resp.data == b"foobar" + assert resp.json() == {"status": "success", "output": "foobar"} def test_good_str_input(): class Predictor(cog.Predictor): - @cog.input("text", type=str) - def predict(self, text): + def predict(self, text: str) -> str: return text client = make_client(Predictor()) - resp = client.post("/predict", data={"text": "baz"}) + resp = client.post("/predictions", json={"input": {"text": "baz"}}) assert resp.status_code == 200 - assert resp.data == b"baz" + assert resp.json() == {"status": "success", "output": "baz"} def test_good_int_input(): class Predictor(cog.Predictor): - @cog.input("num", type=int) - def predict(self, num): - return str(num ** 3) + def predict(self, num: int) -> int: + return num ** 3 client = make_client(Predictor()) - resp = client.post("/predict", data={"num": 3}) + resp = client.post("/predictions", json={"input": {"num": 3}}) assert resp.status_code == 200 - assert resp.data == b"27" - resp = client.post("/predict", data={"num": -3}) + assert resp.json() == {"output": 27, "status": "success"} + resp = client.post("/predictions", json={"input": {"num": -3}}) assert resp.status_code == 200 - assert resp.data == b"-27" + assert resp.json() == {"output": -27, "status": "success"} def test_bad_int_input(): class Predictor(cog.Predictor): - @cog.input("num", type=int) - def predict(self, num): - return str(num ** 2) + def predict(self, num: int) -> int: + return num ** 2 client = make_client(Predictor()) - resp = client.post("/predict", data={"num": "foo"}) - assert resp.status_code == 400 + resp = client.post("/predictions", json={"input": {"num": "foo"}}) + assert resp.json() == { + "detail": [ + { + "loc": ["body", "input", "num"], + "msg": "value is not a valid integer", + "type": "type_error.integer", + } + ] + } + assert resp.status_code == 422 def test_default_int_input(): class Predictor(cog.Predictor): - @cog.input("num", type=int, default=5) - def predict(self, num): - return str(num ** 2) + def predict(self, num: int = Input(default=5)) -> int: + return num ** 2 client = make_client(Predictor()) - resp = client.post("/predict", data={"num": 3}) + + resp = client.post("/predictions", json={"input": {}}) assert resp.status_code == 200 - assert resp.data == b"9" - resp = client.post("/predict") + assert resp.json() == {"output": 25, "status": "success"} + + resp = client.post("/predictions", json={"input": {"num": 3}}) assert resp.status_code == 200 - assert resp.data == b"25" + assert resp.json() == {"output": 9, "status": "success"} -def test_good_float_input(): +def test_file_input_data_url(): class Predictor(cog.Predictor): - @cog.input("num", type=float) - def predict(self, num): - return str(num ** 3) + def predict(self, file: File) -> str: + return file.read() client = make_client(Predictor()) - resp = client.post("/predict", data={"num": 3}) - assert resp.status_code == 200 - assert resp.data == b"27.0" - resp = client.post("/predict", data={"num": 3.5}) - assert resp.status_code == 200 - assert resp.data == b"42.875" - resp = client.post("/predict", data={"num": -3.5}) + resp = client.post( + "/predictions", + json={ + "input": { + "file": "data:text/plain;base64," + + base64.b64encode(b"bar").decode("utf-8") + } + }, + ) + assert resp.json() == {"output": "bar", "status": "success"} assert resp.status_code == 200 - assert resp.data == b"-42.875" -def test_bad_float_input(): +@responses.activate +def test_file_input_with_http_url(): class Predictor(cog.Predictor): - @cog.input("num", type=float) - def predict(self, num): - return str(num ** 2) - - client = make_client(Predictor()) - resp = client.post("/predict", data={"num": "foo"}) - assert resp.status_code == 400 + def predict(self, file: File) -> str: + return file.read() - -def test_good_bool_input(): - class Predictor(cog.Predictor): - @cog.input("flag", type=bool) - def predict(self, flag): - if flag: - return "yes" - else: - return "no" + responses.add(responses.GET, "http://example.com/foo.txt", body="hello") client = make_client(Predictor()) - resp = client.post("/predict", data={"flag": True}) - assert resp.status_code == 200 - assert resp.data == b"yes" - resp = client.post("/predict", data={"flag": False}) - assert resp.status_code == 200 - assert resp.data == b"no" + resp = client.post( + "/predictions", + json={"input": {"file": "http://example.com/foo.txt"}}, + ) + assert resp.json() == {"output": "hello", "status": "success"} -def test_good_path_input(): +def test_path_input_data_url(): class Predictor(cog.Predictor): - @cog.input("path", type=Path) - def predict(self, path): - with open(path) as f: - return f.read() + " " + os.path.basename(path) + def predict(self, path: Path) -> str: + with open(path) as fh: + extension = fh.name.split(".")[-1] + return f"{extension} {fh.read()}" client = make_client(Predictor()) - path_data = (io.BytesIO(b"bar"), "foo.txt") resp = client.post( - "/predict", data={"path": path_data}, content_type="multipart/form-data" + "/predictions", + json={ + "input": { + "path": "data:text/plain;base64," + + base64.b64encode(b"bar").decode("utf-8") + } + }, ) + assert resp.json() == {"output": "txt bar", "status": "success"} assert resp.status_code == 200 - assert resp.data == b"bar foo.txt" -def test_bad_path_input(): +@responses.activate +def test_file_input_with_http_url(): class Predictor(cog.Predictor): - @cog.input("path", type=Path) - def predict(self, path): - with open(path) as f: - return f.read() + " " + os.path.basename(path) + def predict(self, path: Path) -> str: + with open(path) as fh: + extension = fh.name.split(".")[-1] + return f"{extension} {fh.read()}" + + responses.add(responses.GET, "http://example.com/foo.txt", body="hello") client = make_client(Predictor()) - resp = client.post("/predict", data={"path": "bar"}) - assert resp.status_code == 400 + resp = client.post( + "/predictions", + json={"input": {"path": "http://example.com/foo.txt"}}, + ) + assert resp.json() == {"output": "txt hello", "status": "success"} -def test_default_path_input(): +def test_file_bad_input(): class Predictor(cog.Predictor): - @cog.input("path", type=Path, default=None) - def predict(self, path): - if path is None: - return "noneee" - with open(path) as f: - return f.read() + " " + os.path.basename(path) + def predict(self, file: File) -> str: + return file.read() client = make_client(Predictor()) - path_data = (io.BytesIO(b"bar"), "foo.txt") resp = client.post( - "/predict", data={"path": path_data}, content_type="multipart/form-data" + "/predictions", + json={"input": {"file": "foo"}}, ) - assert resp.status_code == 200 - assert resp.data == b"bar foo.txt" - resp = client.post("/predict", data={}) - assert resp.status_code == 200 - assert resp.data == b"noneee" - + assert resp.status_code == 422 -@pytest.mark.skip("This should work but doesn't at the moment") -def test_bad_input_name(): - with pytest.raises(TypeError): - class Predictor(cog.Predictor): - @cog.input("text", type=str) - def predict(self, bad): - return "bar" - - -def test_extranous_input_keys(): +def test_path_output_file(): class Predictor(cog.Predictor): - @cog.input("text", type=str) - def predict(self, text): - return text + def predict(self) -> Path: + temp_dir = tempfile.mkdtemp() + temp_path = os.path.join(temp_dir, "my_file.bmp") + img = Image.new("RGB", (255, 255), "red") + img.save(temp_path) + return Path(temp_path) client = make_client(Predictor()) - resp = client.post("/predict", data={"text": "baz", "text2": "qux"}) - assert resp.status_code == 400 + res = client.post("/predictions") + assert res.status_code == 200 + header, b64data = res.json()["output"].split(",", 1) + # need both image/bmp and image/x-ms-bmp until https://bugs.python.org/issue44211 is fixed + assert header in ["data:image/bmp;base64", "data:image/x-ms-bmp;base64"] + assert len(base64.b64decode(b64data)) == 195894 + +def test_extranous_input_keys(): + class Input(BaseModel): + text: str -def test_min_max(): class Predictor(cog.Predictor): - @cog.input("num1", type=float, min=3, max=10.5) - @cog.input("num2", type=float, min=-4) - @cog.input("num3", type=int, max=-4) - def predict(self, num1, num2, num3): - return num1 + num2 + num3 + def predict(self, input: Input): + return input.text client = make_client(Predictor()) - resp = client.post("/predict", data={"num1": 3, "num2": -4, "num3": -4}) - assert resp.status_code == 200 - assert resp.data == b"-5.0" - resp = client.post("/predict", data={"num1": 2, "num2": -4, "num3": -4}) - assert resp.status_code == 400 - resp = client.post("/predict", data={"num1": 3, "num2": -4.1, "num3": -4}) - assert resp.status_code == 400 - resp = client.post("/predict", data={"num1": 3, "num2": -4, "num3": -3}) - assert resp.status_code == 400 + resp = client.post("/predictions", json={"input": {"text": "baz", "text2": "qux"}}) + assert resp.status_code == 422 -def test_good_options(): +def test_multiple_arguments(): class Predictor(cog.Predictor): - @cog.input("text", type=str, options=["foo", "bar"]) - @cog.input("num", type=int, options=[1, 2, 3]) - def predict(self, text, num): - return text + ("a" * num) + def predict( + self, + text: str, + path: Path, + num1: int, + num2: int = Input(default=10), + ) -> str: + with open(path) as fh: + return text + " " + str(num1 * num2) + " " + fh.read() client = make_client(Predictor()) - resp = client.post("/predict", data={"text": "foo", "num": 2}) + resp = client.post( + "/predictions", + json={ + "input": { + "text": "baz", + "num1": 5, + "path": "data:text/plain;base64," + + base64.b64encode(b"wibble").decode("utf-8"), + } + }, + ) assert resp.status_code == 200 - assert resp.data == b"fooaa" + assert resp.json() == {"output": "baz 50 wibble", "status": "success"} -def test_bad_options(): +def test_gt_lt(): class Predictor(cog.Predictor): - @cog.input("text", type=str, options=["foo", "bar"]) - @cog.input("num", type=int, options=[1, 2, 3]) - def predict(self, text, num): - return text + ("a" * num) + def predict(self, num: float = Input(gt=3, lt=10.5)) -> float: + return num client = make_client(Predictor()) - resp = client.post("/predict", data={"text": "baz", "num": 2}) - assert resp.status_code == 400 - resp = client.post("/predict", data={"text": "bar", "num": 4}) - assert resp.status_code == 400 - - -def test_bad_options_type(): - with pytest.raises(ValueError): - - class Predictor(cog.Predictor): - @cog.input("text", type=str, options=[]) - def predict(self, text): - return text - - with pytest.raises(ValueError): - - class Predictor(cog.Predictor): - @cog.input("text", type=str, options=["foo"]) - def predict(self, text): - return text - - with pytest.raises(ValueError): + resp = client.post("/predictions", json={"input": {"num": 2}}) + assert resp.json() == { + "detail": [ + { + "ctx": {"limit_value": 3}, + "loc": ["body", "input", "num"], + "msg": "ensure this value is greater than 3", + "type": "value_error.number.not_gt", + } + ] + } + assert resp.status_code == 422 + + resp = client.post("/predictions", json={"input": {"num": 5}}) + assert resp.status_code == 200 - class Predictor(cog.Predictor): - @cog.input("text", type=Path, options=["foo"]) - def predict(self, text): - return text +def test_options(): + # TODO: choices + class Options(Enum): + foo = "foo" + bar = "bar" -def test_multiple_arguments(): class Predictor(cog.Predictor): - @cog.input("text", type=str) - @cog.input("num1", type=int) - @cog.input("num2", type=int, default=10) - @cog.input("path", type=Path) - def predict(self, text, num1, num2, path): - with open(path) as f: - path_contents = f.read() - return text + " " + str(num1 * num2) + " " + path_contents + def predict(self, text: Options) -> str: + return str(text) client = make_client(Predictor()) - path_data = (io.BytesIO(b"bar"), "foo.txt") - resp = client.post( - "/predict", - data={"text": "baz", "num1": 5, "path": path_data}, - content_type="multipart/form-data", - ) + resp = client.post("/predictions", json={"input": {"text": "foo"}}) assert resp.status_code == 200 - assert resp.data == b"baz 50 bar" + resp = client.post("/predictions", json={"input": {"text": "baz", "num": 2}}) + assert resp.status_code == 422 diff --git a/python/tests/server/test_http_output.py b/python/tests/server/test_http_output.py index 4986bdc28e..e2d9268b68 100644 --- a/python/tests/server/test_http_output.py +++ b/python/tests/server/test_http_output.py @@ -1,34 +1,31 @@ -import tempfile +import base64 +import io import os -from pathlib import Path +import tempfile import numpy as np from PIL import Image +import responses +from responses.matchers import multipart_matcher import cog +from cog import Path, File from .test_http import make_client -def test_path_output_str(): +def test_return_wrong_type(): class Predictor(cog.Predictor): - @cog.input("text", type=str) - def predict(self, text): - temp_dir = tempfile.mkdtemp() - temp_path = os.path.join(temp_dir, "my_file.txt") - with open(temp_path, "w") as f: - f.write(text) - return Path(temp_path) + def predict(self) -> int: + return "foo" - client = make_client(Predictor()) - resp = client.post("/predict", data={"text": "baz"}) - assert resp.status_code == 200 - assert resp.content_type == "text/plain; charset=utf-8" - assert resp.data == b"baz" + client = make_client(Predictor(), raise_server_exceptions=False) + resp = client.post("/predictions") + assert resp.status_code == 500 -def test_path_output_image(): +def test_path_output_path(): class Predictor(cog.Predictor): - def predict(self): + def predict(self) -> Path: temp_dir = tempfile.mkdtemp() temp_path = os.path.join(temp_dir, "my_file.bmp") img = Image.new("RGB", (255, 255), "red") @@ -36,20 +33,90 @@ def predict(self): return Path(temp_path) client = make_client(Predictor()) - resp = client.post("/predict") - assert resp.status_code == 200 + res = client.post("/predictions") + assert res.status_code == 200 + header, b64data = res.json()["output"].split(",", 1) # need both image/bmp and image/x-ms-bmp until https://bugs.python.org/issue44211 is fixed - assert resp.content_type in ["image/bmp", "image/x-ms-bmp"] - assert resp.content_length == 195894 + assert header in ["data:image/bmp;base64", "data:image/x-ms-bmp;base64"] + assert len(base64.b64decode(b64data)) == 195894 + + +@responses.activate +def test_output_path_to_http(): + class Predictor(cog.Predictor): + def predict(self) -> Path: + temp_dir = tempfile.mkdtemp() + temp_path = os.path.join(temp_dir, "file.txt") + with open(temp_path, "w") as fh: + fh.write("hello") + return Path(temp_path) + + fh = io.BytesIO(b"hello") + fh.name = "file.txt" + responses.add( + responses.PUT, + "http://example.com/upload/file.txt", + status=201, + match=[multipart_matcher({"file": fh})], + ) + + client = make_client(Predictor()) + res = client.post( + "/predictions", json={"output_file_prefix": "http://example.com/upload/"} + ) + assert res.json() == { + "status": "success", + "output": "http://example.com/upload/file.txt", + } + assert res.status_code == 200 + + +def test_path_output_file(): + class Predictor(cog.Predictor): + def predict(self) -> File: + return io.StringIO("hello") + + client = make_client(Predictor()) + res = client.post("/predictions") + assert res.status_code == 200 + assert res.json() == { + "status": "success", + "output": "data:application/octet-stream;base64,aGVsbG8=", # hello + } + + +@responses.activate +def test_output_file_to_http(): + class Predictor(cog.Predictor): + def predict(self) -> File: + fh = io.StringIO("hello") + fh.name = "foo.txt" + return fh + + responses.add( + responses.PUT, + "http://example.com/upload/foo.txt", + status=201, + match=[multipart_matcher({"file": ("foo.txt", b"hello")})], + ) + + client = make_client(Predictor()) + res = client.post( + "/predictions", json={"output_file_prefix": "http://example.com/upload/"} + ) + assert res.json() == { + "status": "success", + "output": "http://example.com/upload/foo.txt", + } + assert res.status_code == 200 def test_json_output_numpy(): class Predictor(cog.Predictor): - def predict(self): - return {"foo": np.float32(1.0)} + def predict(self) -> np.float64: + return np.float64(1.0) client = make_client(Predictor()) - resp = client.post("/predict") + resp = client.post("/predictions") assert resp.status_code == 200 - assert resp.content_type == "application/json" - assert resp.data == b'{"foo": 1.0}' + assert resp.json() == {"output": 1.0, "status": "success"} diff --git a/python/tests/test_json.py b/python/tests/test_json.py new file mode 100644 index 0000000000..5bf8a47fd7 --- /dev/null +++ b/python/tests/test_json.py @@ -0,0 +1,59 @@ +import json +import os +import tempfile + +import cog +from cog.json import encode_json +import numpy as np +from pydantic import BaseModel + + +def test_encode_json_encodes_pydantic_models(): + class Model(BaseModel): + text: str + number: int + + assert encode_json(Model(text="hello", number=5), None) == { + "text": "hello", + "number": 5, + } + + +# TODO +# def test_file(): +# class Model(BaseModel): +# path: cog.Path + +# class Config: +# json_encoders = get_json_encoders() + +# temp_dir = tempfile.mkdtemp() +# temp_path = os.path.join(temp_dir, "my_file.txt") +# with open(temp_path, "w") as fh: +# fh.write("file content") +# model = Model(path=cog.Path(temp_path)) +# assert json.loads(model.json()) == { +# "path": "data:text/plain;base64,ZmlsZSBjb250ZW50" +# } + + +# def test_numpy(): +# class Model(BaseModel): +# ndarray: np.ndarray +# npfloat: np.float64 +# npinteger: np.integer + +# class Config: +# json_encoders = get_json_encoders() +# arbitrary_types_allowed = True + +# model = Model( +# ndarray=np.array([[1, 2], [3, 4]]), +# npfloat=np.float64(1.3), +# npinteger=np.int32(5), +# ) +# assert json.loads(model.json()) == { +# "ndarray": [[1, 2], [3, 4]], +# "npfloat": 1.3, +# "npinteger": 5, +# } diff --git a/requirements-dev.txt b/requirements-dev.txt index 5a3954a598..be17f12b2d 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -1,9 +1,12 @@ -flask==2.0.1 +fastapi==0.70.1 numpy==1.21.1 pillow==9.0.0 +pydantic==1.8.2 pytest==6.2.4 PyYAML==5.4.1 redis==4.1.0 requests==2.25.1 -waiting==1.4.1 +responses==0.16.0 +uvicorn[standard]==0.16.0 wheel==0.36.2 + diff --git a/test-integration/test_integration/fixtures/failing-project/predict.py b/test-integration/test_integration/fixtures/failing-project/predict.py index 04f58ea3a8..db0326c554 100644 --- a/test-integration/test_integration/fixtures/failing-project/predict.py +++ b/test-integration/test_integration/fixtures/failing-project/predict.py @@ -2,9 +2,5 @@ class Predictor(cog.Predictor): - def setup(self): - pass - - @cog.input("text", type=str) - def predict(self, text): + def predict(self, text: str): raise Exception("over budget") diff --git a/test-integration/test_integration/fixtures/file-input-project/cog.yaml b/test-integration/test_integration/fixtures/file-input-project/cog.yaml new file mode 100644 index 0000000000..ce622845eb --- /dev/null +++ b/test-integration/test_integration/fixtures/file-input-project/cog.yaml @@ -0,0 +1,3 @@ +build: + python_version: "3.8" +predict: "predict.py:Predictor" diff --git a/test-integration/test_integration/fixtures/file-input-project/predict.py b/test-integration/test_integration/fixtures/file-input-project/predict.py new file mode 100644 index 0000000000..1fb0b73ff2 --- /dev/null +++ b/test-integration/test_integration/fixtures/file-input-project/predict.py @@ -0,0 +1,8 @@ +import cog +from cog import Path + + +class Predictor(cog.Predictor): + def predict(self, path: Path) -> str: + with open(path) as f: + return f.read() diff --git a/test-integration/test_integration/fixtures/file-output-project/cog.yaml b/test-integration/test_integration/fixtures/file-output-project/cog.yaml new file mode 100644 index 0000000000..ea59756412 --- /dev/null +++ b/test-integration/test_integration/fixtures/file-output-project/cog.yaml @@ -0,0 +1,5 @@ +build: + python_version: "3.8" + python_packages: + - "pillow==8.3.2" +predict: "predict.py:Predictor" diff --git a/test-integration/test_integration/fixtures/file-output-project/predict.py b/test-integration/test_integration/fixtures/file-output-project/predict.py new file mode 100644 index 0000000000..81651136c8 --- /dev/null +++ b/test-integration/test_integration/fixtures/file-output-project/predict.py @@ -0,0 +1,13 @@ +from PIL import Image +import os +import tempfile + +import cog + +class Predictor(cog.Predictor): + def predict(self) -> cog.Path: + temp_dir = tempfile.mkdtemp() + temp_path = os.path.join(temp_dir, f"prediction.bmp") + img = Image.new("RGB", (255, 255), "red") + img.save(temp_path) + return cog.Path(temp_path) diff --git a/test-integration/test_integration/fixtures/file-project/predict.py b/test-integration/test_integration/fixtures/file-project/predict.py index 7e7e0ae4d0..bd5fc493b9 100644 --- a/test-integration/test_integration/fixtures/file-project/predict.py +++ b/test-integration/test_integration/fixtures/file-project/predict.py @@ -1,15 +1,13 @@ -from pathlib import Path import tempfile import cog +from cog import Path class Predictor(cog.Predictor): def setup(self): self.foo = "foo" - @cog.input("text", type=str) - @cog.input("path", type=Path) - def predict(self, text, path): + def predict(self, text: str, path: Path) -> Path: with open(path) as f: output = self.foo + text + f.read() tmpdir = Path(tempfile.mkdtemp()) diff --git a/test-integration/test_integration/fixtures/int-project/cog.yaml b/test-integration/test_integration/fixtures/int-project/cog.yaml new file mode 100644 index 0000000000..ce622845eb --- /dev/null +++ b/test-integration/test_integration/fixtures/int-project/cog.yaml @@ -0,0 +1,3 @@ +build: + python_version: "3.8" +predict: "predict.py:Predictor" diff --git a/test-integration/test_integration/fixtures/int-project/predict.py b/test-integration/test_integration/fixtures/int-project/predict.py new file mode 100644 index 0000000000..2b8daee8ab --- /dev/null +++ b/test-integration/test_integration/fixtures/int-project/predict.py @@ -0,0 +1,6 @@ +import cog + + +class Predictor(cog.Predictor): + def predict(self, input: int) -> int: + return input * 2 diff --git a/test-integration/test_integration/fixtures/logging-project/predict.py b/test-integration/test_integration/fixtures/logging-project/predict.py index 2588249ed8..bedf6bc2a6 100644 --- a/test-integration/test_integration/fixtures/logging-project/predict.py +++ b/test-integration/test_integration/fixtures/logging-project/predict.py @@ -20,8 +20,7 @@ def setup(self): print("setting up predictor") self.foo = "foo" - @cog.input("text", type=str, default="") - def predict(self, text): + def predict(self, text: str = "") -> str: logging.warn("writing log message") time.sleep(0.1) libc.puts(b"writing from C") diff --git a/test-integration/test_integration/fixtures/string-project/predict.py b/test-integration/test_integration/fixtures/string-project/predict.py index 047c4042dc..144a728fcf 100644 --- a/test-integration/test_integration/fixtures/string-project/predict.py +++ b/test-integration/test_integration/fixtures/string-project/predict.py @@ -2,6 +2,5 @@ class Predictor(cog.Predictor): - @cog.input("input", type=str) - def predict(self, input): + def predict(self, input: str) -> str: return "hello " + input diff --git a/test-integration/test_integration/fixtures/subdirectory-project/my-subdir/predict.py b/test-integration/test_integration/fixtures/subdirectory-project/my-subdir/predict.py index 5d495b459a..29ef1b544e 100644 --- a/test-integration/test_integration/fixtures/subdirectory-project/my-subdir/predict.py +++ b/test-integration/test_integration/fixtures/subdirectory-project/my-subdir/predict.py @@ -4,6 +4,5 @@ class Predictor(cog.Predictor): - @cog.input("input", type=str) - def predict(self, input): + def predict(self, input: str) -> str: return concat("hello", input) diff --git a/test-integration/test_integration/fixtures/timeout-project/predict.py b/test-integration/test_integration/fixtures/timeout-project/predict.py index 01c30c07f4..fc8ed60886 100644 --- a/test-integration/test_integration/fixtures/timeout-project/predict.py +++ b/test-integration/test_integration/fixtures/timeout-project/predict.py @@ -3,10 +3,6 @@ class Predictor(cog.Predictor): - def setup(self): - pass - - @cog.input("sleep_time", type=float) - def predict(self, sleep_time): + def predict(self, sleep_time: float) -> str: time.sleep(sleep_time) return "it worked!" diff --git a/test-integration/test_integration/fixtures/yielding-project/predict.py b/test-integration/test_integration/fixtures/yielding-project/predict.py index 72d52b12d3..e6fbd436c9 100644 --- a/test-integration/test_integration/fixtures/yielding-project/predict.py +++ b/test-integration/test_integration/fixtures/yielding-project/predict.py @@ -1,11 +1,9 @@ +from typing import Generator import cog -class Predictor(cog.Predictor): - def setup(self): - pass - @cog.input("text", type=str) - def predict(self, text): +class Predictor(cog.Predictor): + def predict(self, text: str) -> Generator[str, None, None]: predictions = ["foo", "bar", "baz"] for prediction in predictions: yield prediction diff --git a/test-integration/test_integration/fixtures/yielding-timeout-project/predict.py b/test-integration/test_integration/fixtures/yielding-timeout-project/predict.py index 6b4928a8f6..dd8986fc62 100644 --- a/test-integration/test_integration/fixtures/yielding-timeout-project/predict.py +++ b/test-integration/test_integration/fixtures/yielding-timeout-project/predict.py @@ -1,14 +1,12 @@ +from typing import Generator import cog import time class Predictor(cog.Predictor): - def setup(self): - pass - - @cog.input("sleep_time", type=float) - @cog.input("n_iterations", type=int) - def predict(self, sleep_time, n_iterations): + def predict( + self, sleep_time: float, n_iterations: int + ) -> Generator[str, None, None]: for i in range(n_iterations): time.sleep(sleep_time) yield f"yield {i}" diff --git a/test-integration/test_integration/test_build.py b/test-integration/test_integration/test_build.py index 3ab9ebdff8..a5c6b41539 100644 --- a/test-integration/test_integration/test_build.py +++ b/test-integration/test_integration/test_build.py @@ -26,7 +26,7 @@ def test_build_without_predictor(docker_image): assert json.loads(labels["org.cogmodel.config"]) == { "build": {"python_version": "3.8"} } - assert json.loads(labels["org.cogmodel.type_signature"]) == {} + assert "org.cogmodel.openapi_schema" not in labels def test_build_names_uses_image_option_in_cog_yaml(tmpdir_factory, docker_image): @@ -64,11 +64,21 @@ def test_build_with_model(docker_image): ).stdout ) labels = image[0]["Config"]["Labels"] - assert json.loads(labels["org.cogmodel.type_signature"]) == { - "inputs": [ - {"name": "text", "type": "str"}, - {"name": "path", "type": "Path"}, - ] + schema = json.loads(labels["org.cogmodel.openapi_schema"]) + + assert schema["components"]["schemas"]["Input"] == { + "title": "Input", + "required": ["text", "path"], + "type": "object", + "properties": { + "text": {"title": "Text", "type": "string", "x-order": 0}, + "path": { + "title": "Path", + "type": "string", + "format": "uri", + "x-order": 1, + }, + }, } @@ -107,4 +117,4 @@ def test_build_gpu_model_on_cpu(tmpdir_factory, docker_image): "cudnn": "8", } } - assert json.loads(labels["org.cogmodel.type_signature"]) == {} + assert "org.cogmodel.openapi_schema" not in labels diff --git a/test-integration/test_integration/test_predict.py b/test-integration/test_integration/test_predict.py index fffa7888df..262b239b20 100644 --- a/test-integration/test_integration/test_predict.py +++ b/test-integration/test_integration/test_predict.py @@ -1,4 +1,6 @@ from pathlib import Path +import pathlib +import shutil import pytest import subprocess @@ -17,6 +19,62 @@ def test_predict_takes_string_inputs_and_returns_strings_to_stdout(): assert result.stdout == b"hello world\n" +def test_predict_takes_int_inputs_and_returns_ints_to_stdout(): + project_dir = Path(__file__).parent / "fixtures/int-project" + result = subprocess.run( + ["cog", "predict", "-i", "2"], + cwd=project_dir, + check=True, + capture_output=True, + ) + # stdout should be clean without any log messages so it can be piped to other commands + assert result.stdout == b"4\n" + + +def test_predict_takes_file_inputs(tmpdir_factory): + project_dir = Path(__file__).parent / "fixtures/file-input-project" + out_dir = pathlib.Path(tmpdir_factory.mktemp("project")) + shutil.copytree(project_dir, out_dir, dirs_exist_ok=True) + with open(out_dir / "input.txt", "w") as fh: + fh.write("what up") + result = subprocess.run( + ["cog", "predict", "-i", "path=@" + str(out_dir / "input.txt")], + cwd=out_dir, + check=True, + capture_output=True, + ) + assert result.stdout == b"what up\n" + + +def test_predict_writes_files_to_files(tmpdir_factory): + project_dir = Path(__file__).parent / "fixtures/file-output-project" + out_dir = pathlib.Path(tmpdir_factory.mktemp("project")) + shutil.copytree(project_dir, out_dir, dirs_exist_ok=True) + result = subprocess.run( + ["cog", "predict", "-i", "world", "-o", out_dir / "out.txt"], + cwd=out_dir, + check=True, + capture_output=True, + ) + assert result.stdout == b"" + with open(out_dir / "output.bmp", "rb") as f: + assert len(f.read()) == 195894 + + +def test_predict_writes_strings_to_files(tmpdir_factory): + project_dir = Path(__file__).parent / "fixtures/string-project" + out_dir = pathlib.Path(tmpdir_factory.mktemp("project")) + result = subprocess.run( + ["cog", "predict", "-i", "world", "-o", out_dir / "out.txt"], + cwd=project_dir, + check=True, + capture_output=True, + ) + assert result.stdout == b"" + with open(out_dir / "out.txt") as f: + assert f.read() == "hello world" + + def test_predict_runs_an_existing_image(tmpdir_factory): project_dir = Path(__file__).parent / "fixtures/string-project" image_name = "cog-test-" + random_string(10) From 0ffb4ad22a70c8c954979068d83b6784ae45494f Mon Sep 17 00:00:00 2001 From: Ben Firshman Date: Tue, 11 Jan 2022 19:07:47 -0800 Subject: [PATCH 03/14] Add support for redis queue worker Signed-off-by: Ben Firshman --- python/cog/server/redis_queue.py | 63 +++++++------------ .../test_integration/test_redis_queue.py | 5 +- 2 files changed, 25 insertions(+), 43 deletions(-) diff --git a/python/cog/server/redis_queue.py b/python/cog/server/redis_queue.py index b3cac1da9e..13e685c57b 100644 --- a/python/cog/server/redis_queue.py +++ b/python/cog/server/redis_queue.py @@ -1,4 +1,4 @@ -from io import BytesIO +import io import json from pathlib import Path from typing import Optional @@ -9,13 +9,13 @@ import types import contextlib +from pydantic import ValidationError import redis import requests -from werkzeug.datastructures import FileStorage from .redis_log_capture import capture_log -from ..input import InputValidationError, validate_and_convert_inputs -from ..json import to_json +from ..predictor import Predictor, get_input_type, load_predictor +from ..json import encode_json from ..predictor import Predictor, load_predictor @@ -76,6 +76,10 @@ def __init__( self.redis_db = redis_db # TODO: respect max_processing_time in message handling self.max_processing_time = 10 * 60 # timeout after 10 minutes + + # Set up types + self.InputType = get_input_type(self.predictor) + self.redis = redis.Redis( host=self.redis_host, port=self.redis_port, db=self.redis_db ) @@ -193,21 +197,16 @@ def handle_message(self, response_queue, message, cleanup_functions): raw_inputs = message["inputs"] prediction_id = message["id"] + # Flatten the incoming object. The schema and Pydantic will handle downloading files from URLs (see cog/types.py) for k, v in raw_inputs.items(): if "value" in v and v["value"] != "": inputs[k] = v["value"] else: - file_url = v["file"]["url"] - sys.stderr.write(f"Downloading file from {file_url}\n") - value_bytes = self.download(file_url) - inputs[k] = FileStorage( - stream=BytesIO(value_bytes), filename=v["file"]["name"] - ) + inputs[k] = v["file"]["url"] + try: - inputs = validate_and_convert_inputs( - self.predictor, inputs, cleanup_functions - ) - except InputValidationError as e: + input_obj = self.InputType(**inputs) + except ValidationError as e: tb = traceback.format_exc() sys.stderr.write(tb) self.push_error(response_queue, e) @@ -217,7 +216,7 @@ def handle_message(self, response_queue, message, cleanup_functions): with self.capture_log(self.STAGE_RUN, prediction_id), timeout( seconds=self.predict_timeout ): - return_value = self.predictor.predict(**inputs) + return_value = self.predictor.predict(**input_obj.dict()) if isinstance(return_value, types.GeneratorType): last_result = None @@ -260,36 +259,22 @@ def push_error(self, response_queue, error): self.redis.rpush(response_queue, message) def push_result(self, response_queue, result, status): - if isinstance(result, Path): - message = { - "file": { - "url": self.upload_to_temp(result), - "name": result.name, - } - } - elif isinstance(result, str): - message = { - "value": result, - } - else: - message = { - "value": to_json(result), - } + message = { + "value": self.encode_json(result), + } message["status"] = status sys.stderr.write(f"Pushing successful result to {response_queue}\n") self.redis.rpush(response_queue, json.dumps(message)) - def upload_to_temp(self, path: Path) -> str: - sys.stderr.write( - f"Uploading {path.name} to temporary storage at {self.upload_url}\n" - ) - resp = requests.put( - self.upload_url, files={"file": (path.name, path.open("rb"))} - ) - resp.raise_for_status() - return resp.json()["url"] + def encode_json(self, obj): + def upload_file(fh: io.IOBase) -> str: + resp = requests.put(self.upload_url, files={"file": fh}) + resp.raise_for_status() + return resp.json()["url"] + + return encode_json(obj, upload_file) @contextlib.contextmanager def capture_log(self, stage, prediction_id): diff --git a/test-integration/test_integration/test_redis_queue.py b/test-integration/test_integration/test_redis_queue.py index 5a86f4e6af..47f9a61d44 100644 --- a/test-integration/test_integration/test_redis_queue.py +++ b/test-integration/test_integration/test_redis_queue.py @@ -56,10 +56,7 @@ def test_queue_worker_files(docker_image, docker_network, redis_client, upload_s ) response = json.loads(redis_client.brpop("response-queue", timeout=10)[1]) assert response == { - "file": { - "name": "output.txt", - "url": "http://upload-server:5000/download/output.txt", - }, + "value": "http://upload-server:5000/download/output.txt", "status": "success", } From 0348f4827b8c34a890eb27b5b3f6547f6a457ead Mon Sep 17 00:00:00 2001 From: Ben Firshman Date: Wed, 12 Jan 2022 18:23:31 -0800 Subject: [PATCH 04/14] Implement choices option on input Signed-off-by: Ben Firshman --- docs/getting-started-own-model.md | 1 + python/cog/predictor.py | 18 ++++++++++++++++-- python/cog/types.py | 4 +++- python/tests/server/test_http.py | 9 ++++++++- python/tests/server/test_http_input.py | 11 +++-------- 5 files changed, 31 insertions(+), 12 deletions(-) diff --git a/docs/getting-started-own-model.md b/docs/getting-started-own-model.md index 2d8dfca4da..11d18aa4fb 100644 --- a/docs/getting-started-own-model.md +++ b/docs/getting-started-own-model.md @@ -105,6 +105,7 @@ You can provide more information about the input with the `Input()` function, as - `ge`: For `int` or `float` types, the value should be greater than or equal to this number. - `lt`: For `int` or `float` types, the value should be less than this number. - `le`: For `int` or `float` types, the value should be less than or equal to this number. +- `choices`: A list of possible values for this input. There are some more advanced options you can pass, too. For more details, [take a look at the prediction interface documentation](python.md). diff --git a/python/cog/predictor.py b/python/cog/predictor.py index c7130fdbed..9d24af5f44 100644 --- a/python/cog/predictor.py +++ b/python/cog/predictor.py @@ -1,5 +1,6 @@ from abc import ABC, abstractmethod from collections.abc import Generator +import enum import importlib import inspect import os.path @@ -69,7 +70,9 @@ def get_input_type(predictor: Predictor): order = 0 for name, parameter in signature.parameters.items(): - if not parameter.annotation: + annotation = parameter.annotation + + if not annotation: # TODO: perhaps should throw error if there are arguments not annotated? continue @@ -87,7 +90,18 @@ def get_input_type(predictor: Predictor): default.extra["x-order"] = order order += 1 - create_model_kwargs[name] = (parameter.annotation, default) + # Choices! + if default.extra.get("choices"): + choices = default.extra["choices"] + # It will be passed automatically as 'enum' in the schema, so remove it as an extra field. + del default.extra["choices"] + if annotation != str: + raise TypeError( + f"The input {name} uses the option choices. Choices can only be used with str types." + ) + annotation = enum.Enum(name, {value: value for value in choices}) + + create_model_kwargs[name] = (annotation, default) return create_model("Input", **create_model_kwargs) diff --git a/python/cog/types.py b/python/cog/types.py index e502f1f4f4..ec2137cedb 100644 --- a/python/cog/types.py +++ b/python/cog/types.py @@ -6,7 +6,7 @@ import requests import shutil import tempfile -from typing import Any, Optional +from typing import Any, List, Optional from urllib.parse import urlparse from pydantic import Field @@ -31,6 +31,7 @@ def Input( max_length: int = None, allow_mutation: bool = True, regex: str = None, + choices: List[str] = None, **kwargs: Any, ): """Input is similar to pydantic.Field, but doesn't require a default value to be the first argument.""" @@ -52,6 +53,7 @@ def Input( max_length=max_length, allow_mutation=allow_mutation, regex=regex, + choices=choices, **kwargs, ) diff --git a/python/tests/server/test_http.py b/python/tests/server/test_http.py index 043fa1e08c..46c4b81e64 100644 --- a/python/tests/server/test_http.py +++ b/python/tests/server/test_http.py @@ -44,6 +44,7 @@ def predict( input_with_default: int = Input(title="Some number", default=10), path: Path = Input(title="Some path"), image: File = Input(title="Some path"), + choices: str = Input(choices=["foo", "bar"]), ) -> str: pass @@ -116,7 +117,7 @@ def predict( }, "Input": { "title": "Input", - "required": ["no_default", "path", "image"], + "required": ["no_default", "path", "image", "choices"], "type": "object", "properties": { "no_default": { @@ -148,6 +149,7 @@ def predict( "format": "uri", "x-order": 4, }, + "choices": {"$ref": "#/components/schemas/choices"}, }, }, "Request": { @@ -191,6 +193,11 @@ def predict( "type": {"title": "Error Type", "type": "string"}, }, }, + "choices": { + "title": "choices", + "enum": ["foo", "bar"], + "description": "An enumeration.", + }, } }, } diff --git a/python/tests/server/test_http_input.py b/python/tests/server/test_http_input.py index bb1db2ada6..abbaea0ba0 100644 --- a/python/tests/server/test_http_input.py +++ b/python/tests/server/test_http_input.py @@ -252,18 +252,13 @@ def predict(self, num: float = Input(gt=3, lt=10.5)) -> float: assert resp.status_code == 200 -def test_options(): - # TODO: choices - class Options(Enum): - foo = "foo" - bar = "bar" - +def test_choices(): class Predictor(cog.Predictor): - def predict(self, text: Options) -> str: + def predict(self, text: str = Input(choices=["foo", "bar"])) -> str: return str(text) client = make_client(Predictor()) resp = client.post("/predictions", json={"input": {"text": "foo"}}) assert resp.status_code == 200 - resp = client.post("/predictions", json={"input": {"text": "baz", "num": 2}}) + resp = client.post("/predictions", json={"input": {"text": "baz"}}) assert resp.status_code == 422 From 7f4242bedb4a90c9b272d5a000c129aba9e80d60 Mon Sep 17 00:00:00 2001 From: Ben Firshman Date: Thu, 13 Jan 2022 06:35:29 -0800 Subject: [PATCH 05/14] Test complex output Signed-off-by: Ben Firshman --- python/tests/server/test_http_output.py | 22 ++++++++++++++++++++++ 1 file changed, 22 insertions(+) diff --git a/python/tests/server/test_http_output.py b/python/tests/server/test_http_output.py index e2d9268b68..ef756cda6c 100644 --- a/python/tests/server/test_http_output.py +++ b/python/tests/server/test_http_output.py @@ -5,6 +5,7 @@ import numpy as np from PIL import Image +from pydantic import BaseModel import responses from responses.matchers import multipart_matcher @@ -120,3 +121,24 @@ def predict(self) -> np.float64: resp = client.post("/predictions") assert resp.status_code == 200 assert resp.json() == {"output": 1.0, "status": "success"} + + +def test_complex_output(): + class Output(BaseModel): + text: str + file: File + + class Predictor(cog.Predictor): + def predict(self) -> Output: + return Output(text="hello", file=io.StringIO("hello")) + + client = make_client(Predictor()) + resp = client.post("/predictions") + assert resp.json() == { + "output": { + "file": "data:application/octet-stream;base64,aGVsbG8=", + "text": "hello", + }, + "status": "success", + } + assert resp.status_code == 200 From c4df85610deeb2c986d3822565ba1e28a58e28cd Mon Sep 17 00:00:00 2001 From: Zeke Sikelianos Date: Thu, 13 Jan 2022 11:30:37 -0800 Subject: [PATCH 06/14] add pip install to CONTRIBUTING Signed-off-by: Zeke Sikelianos --- CONTRIBUTING.md | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 2d2f16367b..25aecfdd7a 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -96,7 +96,11 @@ Common contribution types include: `doc`, `code`, `bug`, and `ideas`. See the fu You'll need to [install Go 1.16](https://golang.org/doc/install). If you're using a newer Mac with an M1 chip, be sure to download the `darwin-arm64` installer package. Alternatively you can run `brew install go` which will automatically detect and use the appropriate installer for your system architecture. -Once you have Go installed, then run: +Install the Python dependencies: + + pip install -r requirements-dev.txt + +Once you have Go installed, run: make install From 955b455cc2efda3b209a25a0081ed058997eb4ad Mon Sep 17 00:00:00 2001 From: Zeke Sikelianos Date: Thu, 13 Jan 2022 12:00:28 -0800 Subject: [PATCH 07/14] allow server log level to be configured with COG_LOG_LEVEL Signed-off-by: Zeke Sikelianos --- pkg/predict/predictor.go | 4 +++- python/cog/server/http.py | 6 +++++- 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/pkg/predict/predictor.go b/pkg/predict/predictor.go index d6d3b17867..0221b9d0b4 100644 --- a/pkg/predict/predictor.go +++ b/pkg/predict/predictor.go @@ -39,7 +39,9 @@ type Predictor struct { func NewPredictor(runOptions docker.RunOptions) Predictor { if global.Debug { - runOptions.Env = append(runOptions.Env, "COG_DEBUG=1") + runOptions.Env = append(runOptions.Env, "COG_LOG_LEVEL=debug") + } else { + runOptions.Env = append(runOptions.Env, "COG_LOG_LEVEL=warning") } return Predictor{runOptions: runOptions} } diff --git a/python/cog/server/http.py b/python/cog/server/http.py index 4277266c81..89d79d5068 100644 --- a/python/cog/server/http.py +++ b/python/cog/server/http.py @@ -96,7 +96,11 @@ def predict(...) -> output_type: app, host="0.0.0.0", port=5000, - log_level="debug" if os.environ.get("COG_DEBUG") else "warning", + # log level is configurable so we can make it quiet or verbose for `cog predict` + # cog predict --debug # -> debug + # cog predict # -> warning + # docker run # -> info (default) + log_level=os.environ.get("COG_LOG_LEVEL", "info"), # Single worker to safely run on GPUs. workers=1, ) From 884aff5066c65a3832f20e09874f8db5f39f92f6 Mon Sep 17 00:00:00 2001 From: Zeke Sikelianos Date: Thu, 13 Jan 2022 14:01:54 -0800 Subject: [PATCH 08/14] always define components.schemas.Output ..even when the user-defined predict function returns a simple type like a string. Co-Authored-By: Ben Firshman Signed-off-by: Zeke Sikelianos --- python/cog/predictor.py | 8 +- python/cog/server/http.py | 21 +++-- python/tests/server/test_http.py | 151 ++++++++++++++++++++++++++++++- 3 files changed, 168 insertions(+), 12 deletions(-) diff --git a/python/cog/predictor.py b/python/cog/predictor.py index 9d24af5f44..ff5d3aadef 100644 --- a/python/cog/predictor.py +++ b/python/cog/predictor.py @@ -6,7 +6,7 @@ import os.path from pathlib import Path import typing -from pydantic import create_model +from pydantic import create_model, BaseModel from pydantic.fields import FieldInfo # Added in Python 3.8. Can be from typing if we drop support for <3.8. @@ -117,4 +117,8 @@ def get_output_type(predictor: Predictor): if get_origin(OutputType) is Generator: OutputType = get_args(OutputType)[0] - return OutputType + # Wrap the type in a model so Pydantic can document it in component schema + class Output(BaseModel): + __root__: OutputType + + return Output diff --git a/python/cog/server/http.py b/python/cog/server/http.py index 89d79d5068..d39bac5c0a 100644 --- a/python/cog/server/http.py +++ b/python/cog/server/http.py @@ -37,6 +37,18 @@ class Request(BaseModel): input: InputType = None output_file_prefix: str = None + # response_model is purely for generating schema. + # We generate Response again in the request so we can set file output paths correctly, etc. + OutputType = get_output_type(predictor) + + @app.post( + "/predictions", + response_model=get_response_type(OutputType), + response_model_exclude_unset=True, + ) + + # The signature of this function is used by FastAPI to generate the schema. + # The function body is not used to generate the schema. def predict(request: Request = Body(default=None)): if request is None or request.input is None: output = predictor.predict() @@ -77,15 +89,6 @@ def predict(...) -> output_type: ) return JSONResponse(content=encoded_response) - # response_model is purely for generating schema. - # We generate Response again in the request so we can set file output paths correctly, etc. - OutputType = get_output_type(predictor) - app.post( - "/predictions", - response_model=get_response_type(OutputType), - response_model_exclude_unset=True, - )(predict) - return app diff --git a/python/tests/server/test_http.py b/python/tests/server/test_http.py index 46c4b81e64..f8dbc03470 100644 --- a/python/tests/server/test_http.py +++ b/python/tests/server/test_http.py @@ -6,6 +6,7 @@ from unittest import mock from fastapi.testclient import TestClient +from pydantic import BaseModel from PIL import Image import pytest @@ -152,6 +153,10 @@ def predict( "choices": {"$ref": "#/components/schemas/choices"}, }, }, + "Output": { + "title": "Output", + "type": "string", + }, "Request": { "title": "Request", "type": "object", @@ -169,7 +174,7 @@ def predict( "type": "object", "properties": { "status": {"$ref": "#/components/schemas/Status"}, - "output": {"title": "Output", "type": "string"}, + "output": {"$ref": "#/components/schemas/Output"}, "error": {"title": "Error", "type": "string"}, }, "description": "The status of a prediction.", @@ -203,6 +208,150 @@ def predict( } +def test_openapi_specification_with_custom_user_defined_output_type(): + # Calling this `MyOutput` to test if cog renames it to `Output` in the schema + class MyOutput(BaseModel): + foo_number: int = "42" + foo_string: str = "meaning of life" + + class Predictor(cog.Predictor): + def predict( + self, + ) -> MyOutput: + pass + + client = make_client(Predictor()) + resp = client.get("/openapi.json") + assert resp.status_code == 200 + print(resp.json()) + + assert resp.json() == { + "openapi": "3.0.2", + "info": {"title": "Cog", "version": "0.1.0"}, + "paths": { + "/": { + "get": { + "summary": "Root", + "operationId": "root__get", + "responses": { + "200": { + "description": "Successful Response", + "content": {"application/json": {"schema": {}}}, + } + }, + } + }, + "/predictions": { + "post": { + "summary": "Predict", + "operationId": "predict_predictions_post", + "requestBody": { + "content": { + "application/json": { + "schema": {"$ref": "#/components/schemas/Request"} + } + } + }, + "responses": { + "200": { + "description": "Successful Response", + "content": { + "application/json": { + "schema": {"$ref": "#/components/schemas/Response"} + } + }, + }, + "422": { + "description": "Validation Error", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/HTTPValidationError" + } + } + }, + }, + }, + } + }, + }, + "components": { + "schemas": { + "HTTPValidationError": { + "title": "HTTPValidationError", + "type": "object", + "properties": { + "detail": { + "title": "Detail", + "type": "array", + "items": {"$ref": "#/components/schemas/ValidationError"}, + } + }, + }, + "Input": {"title": "Input", "type": "object", "properties": {}}, + "MyOutput": { + "title": "MyOutput", + "type": "object", + "properties": { + "foo_number": { + "title": "Foo Number", + "type": "integer", + "default": "42", + }, + "foo_string": { + "title": "Foo String", + "type": "string", + "default": "meaning of life", + }, + }, + }, + "Output": {"$ref": "#/components/schemas/MyOutput", "title": "Output"}, + "Request": { + "title": "Request", + "type": "object", + "properties": { + "input": {"$ref": "#/components/schemas/Input"}, + "output_file_prefix": { + "title": "Output File Prefix", + "type": "string", + }, + }, + }, + "Response": { + "title": "Response", + "required": ["status"], + "type": "object", + "properties": { + "status": {"$ref": "#/components/schemas/Status"}, + "output": {"$ref": "#/components/schemas/Output"}, + "error": {"title": "Error", "type": "string"}, + }, + "description": "The status of a prediction.", + }, + "Status": { + "title": "Status", + "enum": ["processing", "success", "failed"], + "description": "An enumeration.", + }, + "ValidationError": { + "title": "ValidationError", + "required": ["loc", "msg", "type"], + "type": "object", + "properties": { + "loc": { + "title": "Location", + "type": "array", + "items": {"type": "string"}, + }, + "msg": {"title": "Message", "type": "string"}, + "type": {"title": "Error Type", "type": "string"}, + }, + }, + } + }, + } + + def test_yielding_strings_from_generator_predictors(): class Predictor(cog.Predictor): def predict(self) -> Generator[str, None, None]: From 825d88120d3435ff455e1e7318e20ded7a14efc0 Mon Sep 17 00:00:00 2001 From: Ben Firshman Date: Thu, 13 Jan 2022 16:29:29 -0800 Subject: [PATCH 09/14] Rename Predictor to BasePredictor Signed-off-by: Ben Firshman --- README.md | 4 +-- docs/getting-started-own-model.md | 5 ++- docs/getting-started.md | 5 ++- docs/python.md | 7 ++-- docs/yaml.md | 6 ++-- pkg/config/config.go | 2 +- python/cog/__init__.py | 5 ++- python/cog/predictor.py | 6 ++-- python/cog/server/http.py | 4 +-- python/cog/server/redis_queue.py | 5 ++- python/tests/server/test_http.py | 19 +++++------ python/tests/server/test_http_input.py | 34 +++++++++---------- python/tests/server/test_http_output.py | 17 +++++----- .../fixtures/failing-project/predict.py | 4 +-- .../fixtures/file-input-project/predict.py | 5 ++- .../fixtures/file-output-project/predict.py | 8 ++--- .../fixtures/file-project/predict.py | 5 ++- .../fixtures/int-project/predict.py | 4 +-- .../fixtures/logging-project/predict.py | 9 +++-- .../fixtures/string-project/predict.py | 4 +-- .../subdirectory-project/my-subdir/predict.py | 4 +-- .../fixtures/timeout-project/predict.py | 5 +-- .../fixtures/yielding-project/predict.py | 5 +-- .../yielding-timeout-project/predict.py | 7 ++-- 24 files changed, 87 insertions(+), 92 deletions(-) diff --git a/README.md b/README.md index e7cb437f1b..72f8afd021 100644 --- a/README.md +++ b/README.md @@ -27,10 +27,10 @@ predict: "predict.py:Predictor" And define how predictions are run on your model with `predict.py`: ```python -from cog import Predictor, Input, Path +from cog import BasePredictor, Input, Path import torch -class ColorizationPredictor(Predictor): +class Predictor(BasePredictor): def setup(self): """Load the model into memory to make running multiple predictions efficient""" self.model = torch.load("./weights.pth") diff --git a/docs/getting-started-own-model.md b/docs/getting-started-own-model.md index 11d18aa4fb..097156d682 100644 --- a/docs/getting-started-own-model.md +++ b/docs/getting-started-own-model.md @@ -66,11 +66,10 @@ With `cog.yaml`, you can also install system packages and other things. [Take a The next step is to update `predict.py` to define the interface for running predictions on your model. The `predict.py` generated by `cog init` looks something like this: ```python -import cog -from cog import Path, Input +from cog import BasePredictor, Path, Input import torch -class Predictor(cog.Predictor): +class Predictor(BasePredictor): def setup(self): """Load the model into memory to make running multiple predictions efficient""" self.net = torch.load("weights.pth") diff --git a/docs/getting-started.md b/docs/getting-started.md index 2f61c341af..2ea30fc0a1 100644 --- a/docs/getting-started.md +++ b/docs/getting-started.md @@ -60,15 +60,14 @@ Then, we need to write some code to describe how predictions are run on the mode ```python from typing import Any -import cog -from cog import Input, Path +from cog import BasePredictor, Input, Path from tensorflow.keras.applications.resnet50 import ResNet50 from tensorflow.keras.preprocessing import image as keras_image from tensorflow.keras.applications.resnet50 import preprocess_input, decode_predictions import numpy as np -class ResNetPredictor(cog.Predictor): +class ResNetPredictor(BasePredictor): def setup(self): """Load the model into memory to make running multiple predictions efficient""" self.model = ResNet50(weights='resnet50_weights_tf_dim_ordering_tf_kernels.h5') diff --git a/docs/python.md b/docs/python.md index eb52458c41..379f5934b3 100644 --- a/docs/python.md +++ b/docs/python.md @@ -1,13 +1,12 @@ # Prediction interface reference -You define how Cog runs predictions on your model by defining a class that inherits from `cog.Predictor`. It looks something like this: +You define how Cog runs predictions on your model by defining a class that inherits from `BasePredictor`. It looks something like this: ```python -import cog -from cog import Path, Input +from cog import BasePredictor, Path, Input import torch -class ImageScalingPredictor(cog.Predictor): +class Predictor(BasePredictor): def setup(self): """Load the model into memory to make running multiple predictions efficient""" self.model = torch.load("weights.pth") diff --git a/docs/yaml.md b/docs/yaml.md index b09357f294..6d844289b7 100644 --- a/docs/yaml.md +++ b/docs/yaml.md @@ -12,7 +12,7 @@ build: system_packages: - "ffmpeg" - "libavcodec-dev" -predict: "predict.py:JazzSoloComposerPredictor" +predict: "predict.py:Predictor" ``` Tip: Run [`cog init`](getting-started-own-model#initialization) to generate an annotated `cog.yaml` file that can be used as a starting point for setting up your model. @@ -102,12 +102,12 @@ If you don't provide this, a name will be generated from the directory name. ## `predict` -The pointer to the `cog.Predictor` object in your code, which defines how predictions are run on your model. +The pointer to the `Predictor` object in your code, which defines how predictions are run on your model. For example: ```yaml -predict: "predict.py:HotdogPredictor" +predict: "predict.py:Predictor" ``` See [the Python API documentation for more information](python.md). diff --git a/pkg/config/config.go b/pkg/config/config.go index 3fa9efe0f6..d40ae66b6a 100644 --- a/pkg/config/config.go +++ b/pkg/config/config.go @@ -122,7 +122,7 @@ func (c *Config) ValidateAndCompleteConfig() error { } if c.Predict != "" { if len(strings.Split(c.Predict, ".py:")) != 2 { - return fmt.Errorf("'predict' in cog.yaml must be in the form 'predict.py:PredictorClass") + return fmt.Errorf("'predict' in cog.yaml must be in the form 'predict.py:Predictor") } } diff --git a/python/cog/__init__.py b/python/cog/__init__.py index 8ececd7369..8c37d1c309 100644 --- a/python/cog/__init__.py +++ b/python/cog/__init__.py @@ -1,8 +1,11 @@ -from .predictor import Predictor +from .predictor import BasePredictor from .types import File, Input, Path +# Backwards compatibility. Will be deprecated before 1.0.0. +Predictor = BasePredictor __all__ = [ + "BasePredictor", "File", "Input", "Path", diff --git a/python/cog/predictor.py b/python/cog/predictor.py index ff5d3aadef..08f087897f 100644 --- a/python/cog/predictor.py +++ b/python/cog/predictor.py @@ -17,7 +17,7 @@ from .types import Input -class Predictor(ABC): +class BasePredictor(ABC): def setup(self): pass @@ -63,7 +63,7 @@ def load_predictor(): return predictor_class() -def get_input_type(predictor: Predictor): +def get_input_type(predictor: BasePredictor): signature = inspect.signature(predictor.predict) create_model_kwargs = {} @@ -106,7 +106,7 @@ def get_input_type(predictor: Predictor): return create_model("Input", **create_model_kwargs) -def get_output_type(predictor: Predictor): +def get_output_type(predictor: BasePredictor): signature = inspect.signature(predictor.predict) if signature.return_annotation is inspect.Signature.empty: OutputType = Literal[None] diff --git a/python/cog/server/http.py b/python/cog/server/http.py index d39bac5c0a..3ed9b734af 100644 --- a/python/cog/server/http.py +++ b/python/cog/server/http.py @@ -10,13 +10,13 @@ from ..files import upload_file from ..json import encode_json -from ..predictor import Predictor, get_input_type, get_output_type, load_predictor +from ..predictor import BasePredictor, get_input_type, get_output_type, load_predictor from ..response import Status, get_response_type logger = logging.getLogger("cog") -def create_app(predictor: Predictor) -> FastAPI: +def create_app(predictor: BasePredictor) -> FastAPI: app = FastAPI( title="Cog", # TODO: mention model name? # version=None # TODO diff --git a/python/cog/server/redis_queue.py b/python/cog/server/redis_queue.py index 13e685c57b..67db4296bc 100644 --- a/python/cog/server/redis_queue.py +++ b/python/cog/server/redis_queue.py @@ -14,9 +14,8 @@ import requests from .redis_log_capture import capture_log -from ..predictor import Predictor, get_input_type, load_predictor +from ..predictor import BasePredictor, get_input_type, load_predictor from ..json import encode_json -from ..predictor import Predictor, load_predictor class timeout: @@ -53,7 +52,7 @@ class RedisQueueWorker: def __init__( self, - predictor: Predictor, + predictor: BasePredictor, redis_host: str, redis_port: int, input_queue: str, diff --git a/python/tests/server/test_http.py b/python/tests/server/test_http.py index f8dbc03470..c9aa37a904 100644 --- a/python/tests/server/test_http.py +++ b/python/tests/server/test_http.py @@ -10,20 +10,19 @@ from PIL import Image import pytest -import cog -from cog import Input, File, Path +from cog import BasePredictor, Input, File, Path from cog.server.http import create_app -def make_client(predictor: cog.Predictor, **kwargs) -> TestClient: +def make_client(predictor: BasePredictor, **kwargs) -> TestClient: app = create_app(predictor) with TestClient(app, **kwargs) as client: return client def test_setup_is_called(): - class Predictor(cog.Predictor): + class Predictor(BasePredictor): def setup(self): self.foo = "bar" @@ -37,7 +36,7 @@ def predict(self) -> str: def test_openapi_specification(): - class Predictor(cog.Predictor): + class Predictor(BasePredictor): def predict( self, no_default: str, @@ -214,7 +213,7 @@ class MyOutput(BaseModel): foo_number: int = "42" foo_string: str = "meaning of life" - class Predictor(cog.Predictor): + class Predictor(BasePredictor): def predict( self, ) -> MyOutput: @@ -353,7 +352,7 @@ def predict( def test_yielding_strings_from_generator_predictors(): - class Predictor(cog.Predictor): + class Predictor(BasePredictor): def predict(self) -> Generator[str, None, None]: predictions = ["foo", "bar", "baz"] for prediction in predictions: @@ -366,8 +365,8 @@ def predict(self) -> Generator[str, None, None]: def test_yielding_files_from_generator_predictors(): - class Predictor(cog.Predictor): - def predict(self) -> Generator[cog.Path, None, None]: + class Predictor(BasePredictor): + def predict(self) -> Generator[Path, None, None]: colors = ["red", "blue", "yellow"] for i, color in enumerate(colors): temp_dir = tempfile.mkdtemp() @@ -390,7 +389,7 @@ def predict(self) -> Generator[cog.Path, None, None]: @pytest.mark.skip @mock.patch("time.time", return_value=0.0) def test_timing(time_mock): - class Predictor(cog.Predictor): + class Predictor(BasePredictor): def setup(self): time_mock.return_value = 1.0 diff --git a/python/tests/server/test_http_input.py b/python/tests/server/test_http_input.py index abbaea0ba0..f63f570729 100644 --- a/python/tests/server/test_http_input.py +++ b/python/tests/server/test_http_input.py @@ -1,5 +1,4 @@ import base64 -from enum import Enum import os import tempfile @@ -7,13 +6,12 @@ from pydantic import BaseModel import responses -import cog -from cog import Input, Path, File +from cog import BasePredictor, Input, Path, File from .test_http import make_client def test_no_input(): - class Predictor(cog.Predictor): + class Predictor(BasePredictor): def predict(self) -> str: return "foobar" @@ -24,7 +22,7 @@ def predict(self) -> str: def test_good_str_input(): - class Predictor(cog.Predictor): + class Predictor(BasePredictor): def predict(self, text: str) -> str: return text @@ -35,7 +33,7 @@ def predict(self, text: str) -> str: def test_good_int_input(): - class Predictor(cog.Predictor): + class Predictor(BasePredictor): def predict(self, num: int) -> int: return num ** 3 @@ -49,7 +47,7 @@ def predict(self, num: int) -> int: def test_bad_int_input(): - class Predictor(cog.Predictor): + class Predictor(BasePredictor): def predict(self, num: int) -> int: return num ** 2 @@ -68,7 +66,7 @@ def predict(self, num: int) -> int: def test_default_int_input(): - class Predictor(cog.Predictor): + class Predictor(BasePredictor): def predict(self, num: int = Input(default=5)) -> int: return num ** 2 @@ -84,7 +82,7 @@ def predict(self, num: int = Input(default=5)) -> int: def test_file_input_data_url(): - class Predictor(cog.Predictor): + class Predictor(BasePredictor): def predict(self, file: File) -> str: return file.read() @@ -104,7 +102,7 @@ def predict(self, file: File) -> str: @responses.activate def test_file_input_with_http_url(): - class Predictor(cog.Predictor): + class Predictor(BasePredictor): def predict(self, file: File) -> str: return file.read() @@ -119,7 +117,7 @@ def predict(self, file: File) -> str: def test_path_input_data_url(): - class Predictor(cog.Predictor): + class Predictor(BasePredictor): def predict(self, path: Path) -> str: with open(path) as fh: extension = fh.name.split(".")[-1] @@ -141,7 +139,7 @@ def predict(self, path: Path) -> str: @responses.activate def test_file_input_with_http_url(): - class Predictor(cog.Predictor): + class Predictor(BasePredictor): def predict(self, path: Path) -> str: with open(path) as fh: extension = fh.name.split(".")[-1] @@ -158,7 +156,7 @@ def predict(self, path: Path) -> str: def test_file_bad_input(): - class Predictor(cog.Predictor): + class Predictor(BasePredictor): def predict(self, file: File) -> str: return file.read() @@ -171,7 +169,7 @@ def predict(self, file: File) -> str: def test_path_output_file(): - class Predictor(cog.Predictor): + class Predictor(BasePredictor): def predict(self) -> Path: temp_dir = tempfile.mkdtemp() temp_path = os.path.join(temp_dir, "my_file.bmp") @@ -192,7 +190,7 @@ def test_extranous_input_keys(): class Input(BaseModel): text: str - class Predictor(cog.Predictor): + class Predictor(BasePredictor): def predict(self, input: Input): return input.text @@ -202,7 +200,7 @@ def predict(self, input: Input): def test_multiple_arguments(): - class Predictor(cog.Predictor): + class Predictor(BasePredictor): def predict( self, text: str, @@ -230,7 +228,7 @@ def predict( def test_gt_lt(): - class Predictor(cog.Predictor): + class Predictor(BasePredictor): def predict(self, num: float = Input(gt=3, lt=10.5)) -> float: return num @@ -253,7 +251,7 @@ def predict(self, num: float = Input(gt=3, lt=10.5)) -> float: def test_choices(): - class Predictor(cog.Predictor): + class Predictor(BasePredictor): def predict(self, text: str = Input(choices=["foo", "bar"])) -> str: return str(text) diff --git a/python/tests/server/test_http_output.py b/python/tests/server/test_http_output.py index ef756cda6c..92ae3ad499 100644 --- a/python/tests/server/test_http_output.py +++ b/python/tests/server/test_http_output.py @@ -9,13 +9,12 @@ import responses from responses.matchers import multipart_matcher -import cog -from cog import Path, File +from cog import BasePredictor, Path, File from .test_http import make_client def test_return_wrong_type(): - class Predictor(cog.Predictor): + class Predictor(BasePredictor): def predict(self) -> int: return "foo" @@ -25,7 +24,7 @@ def predict(self) -> int: def test_path_output_path(): - class Predictor(cog.Predictor): + class Predictor(BasePredictor): def predict(self) -> Path: temp_dir = tempfile.mkdtemp() temp_path = os.path.join(temp_dir, "my_file.bmp") @@ -44,7 +43,7 @@ def predict(self) -> Path: @responses.activate def test_output_path_to_http(): - class Predictor(cog.Predictor): + class Predictor(BasePredictor): def predict(self) -> Path: temp_dir = tempfile.mkdtemp() temp_path = os.path.join(temp_dir, "file.txt") @@ -73,7 +72,7 @@ def predict(self) -> Path: def test_path_output_file(): - class Predictor(cog.Predictor): + class Predictor(BasePredictor): def predict(self) -> File: return io.StringIO("hello") @@ -88,7 +87,7 @@ def predict(self) -> File: @responses.activate def test_output_file_to_http(): - class Predictor(cog.Predictor): + class Predictor(BasePredictor): def predict(self) -> File: fh = io.StringIO("hello") fh.name = "foo.txt" @@ -113,7 +112,7 @@ def predict(self) -> File: def test_json_output_numpy(): - class Predictor(cog.Predictor): + class Predictor(BasePredictor): def predict(self) -> np.float64: return np.float64(1.0) @@ -128,7 +127,7 @@ class Output(BaseModel): text: str file: File - class Predictor(cog.Predictor): + class Predictor(BasePredictor): def predict(self) -> Output: return Output(text="hello", file=io.StringIO("hello")) diff --git a/test-integration/test_integration/fixtures/failing-project/predict.py b/test-integration/test_integration/fixtures/failing-project/predict.py index db0326c554..6e390c96a2 100644 --- a/test-integration/test_integration/fixtures/failing-project/predict.py +++ b/test-integration/test_integration/fixtures/failing-project/predict.py @@ -1,6 +1,6 @@ -import cog +from cog import BasePredictor -class Predictor(cog.Predictor): +class Predictor(BasePredictor): def predict(self, text: str): raise Exception("over budget") diff --git a/test-integration/test_integration/fixtures/file-input-project/predict.py b/test-integration/test_integration/fixtures/file-input-project/predict.py index 1fb0b73ff2..84d08050a4 100644 --- a/test-integration/test_integration/fixtures/file-input-project/predict.py +++ b/test-integration/test_integration/fixtures/file-input-project/predict.py @@ -1,8 +1,7 @@ -import cog -from cog import Path +from cog import BasePredictor, Path -class Predictor(cog.Predictor): +class Predictor(BasePredictor): def predict(self, path: Path) -> str: with open(path) as f: return f.read() diff --git a/test-integration/test_integration/fixtures/file-output-project/predict.py b/test-integration/test_integration/fixtures/file-output-project/predict.py index 81651136c8..ff8e913c3a 100644 --- a/test-integration/test_integration/fixtures/file-output-project/predict.py +++ b/test-integration/test_integration/fixtures/file-output-project/predict.py @@ -2,12 +2,12 @@ import os import tempfile -import cog +from cog import BasePredictor, Path -class Predictor(cog.Predictor): - def predict(self) -> cog.Path: +class Predictor(BasePredictor): + def predict(self) -> Path: temp_dir = tempfile.mkdtemp() temp_path = os.path.join(temp_dir, f"prediction.bmp") img = Image.new("RGB", (255, 255), "red") img.save(temp_path) - return cog.Path(temp_path) + return Path(temp_path) diff --git a/test-integration/test_integration/fixtures/file-project/predict.py b/test-integration/test_integration/fixtures/file-project/predict.py index bd5fc493b9..f902f0b9c2 100644 --- a/test-integration/test_integration/fixtures/file-project/predict.py +++ b/test-integration/test_integration/fixtures/file-project/predict.py @@ -1,9 +1,8 @@ import tempfile -import cog -from cog import Path +from cog import BasePredictor, Path -class Predictor(cog.Predictor): +class Predictor(BasePredictor): def setup(self): self.foo = "foo" diff --git a/test-integration/test_integration/fixtures/int-project/predict.py b/test-integration/test_integration/fixtures/int-project/predict.py index 2b8daee8ab..51a338613b 100644 --- a/test-integration/test_integration/fixtures/int-project/predict.py +++ b/test-integration/test_integration/fixtures/int-project/predict.py @@ -1,6 +1,6 @@ -import cog +from cog import BasePredictor -class Predictor(cog.Predictor): +class Predictor(BasePredictor): def predict(self, input: int) -> int: return input * 2 diff --git a/test-integration/test_integration/fixtures/logging-project/predict.py b/test-integration/test_integration/fixtures/logging-project/predict.py index bedf6bc2a6..641e84b002 100644 --- a/test-integration/test_integration/fixtures/logging-project/predict.py +++ b/test-integration/test_integration/fixtures/logging-project/predict.py @@ -1,11 +1,10 @@ -import logging import ctypes +import logging import sys -import tempfile -from pathlib import Path -import cog import time +from cog import BasePredictor + libc = ctypes.CDLL(None) # test that we can still capture type signature even if we write @@ -15,7 +14,7 @@ sys.stderr.write("writing to stderr at import time\n") -class Predictor(cog.Predictor): +class Predictor(BasePredictor): def setup(self): print("setting up predictor") self.foo = "foo" diff --git a/test-integration/test_integration/fixtures/string-project/predict.py b/test-integration/test_integration/fixtures/string-project/predict.py index 144a728fcf..a74db19adb 100644 --- a/test-integration/test_integration/fixtures/string-project/predict.py +++ b/test-integration/test_integration/fixtures/string-project/predict.py @@ -1,6 +1,6 @@ -import cog +from cog import BasePredictor -class Predictor(cog.Predictor): +class Predictor(BasePredictor): def predict(self, input: str) -> str: return "hello " + input diff --git a/test-integration/test_integration/fixtures/subdirectory-project/my-subdir/predict.py b/test-integration/test_integration/fixtures/subdirectory-project/my-subdir/predict.py index 29ef1b544e..e442cc1c18 100644 --- a/test-integration/test_integration/fixtures/subdirectory-project/my-subdir/predict.py +++ b/test-integration/test_integration/fixtures/subdirectory-project/my-subdir/predict.py @@ -1,8 +1,8 @@ -import cog +from cog import BasePredictor from mylib import concat -class Predictor(cog.Predictor): +class Predictor(BasePredictor): def predict(self, input: str) -> str: return concat("hello", input) diff --git a/test-integration/test_integration/fixtures/timeout-project/predict.py b/test-integration/test_integration/fixtures/timeout-project/predict.py index fc8ed60886..501d66cd2d 100644 --- a/test-integration/test_integration/fixtures/timeout-project/predict.py +++ b/test-integration/test_integration/fixtures/timeout-project/predict.py @@ -1,8 +1,9 @@ -import cog import time +from cog import BasePredictor -class Predictor(cog.Predictor): + +class Predictor(BasePredictor): def predict(self, sleep_time: float) -> str: time.sleep(sleep_time) return "it worked!" diff --git a/test-integration/test_integration/fixtures/yielding-project/predict.py b/test-integration/test_integration/fixtures/yielding-project/predict.py index e6fbd436c9..445e38c812 100644 --- a/test-integration/test_integration/fixtures/yielding-project/predict.py +++ b/test-integration/test_integration/fixtures/yielding-project/predict.py @@ -1,8 +1,9 @@ from typing import Generator -import cog +from cog import BasePredictor -class Predictor(cog.Predictor): + +class Predictor(BasePredictor): def predict(self, text: str) -> Generator[str, None, None]: predictions = ["foo", "bar", "baz"] for prediction in predictions: diff --git a/test-integration/test_integration/fixtures/yielding-timeout-project/predict.py b/test-integration/test_integration/fixtures/yielding-timeout-project/predict.py index dd8986fc62..661b922359 100644 --- a/test-integration/test_integration/fixtures/yielding-timeout-project/predict.py +++ b/test-integration/test_integration/fixtures/yielding-timeout-project/predict.py @@ -1,9 +1,10 @@ -from typing import Generator -import cog import time +from typing import Generator + +from cog import BasePredictor -class Predictor(cog.Predictor): +class Predictor(BasePredictor): def predict( self, sleep_time: float, n_iterations: int ) -> Generator[str, None, None]: From 9c4ca1e9395d7a12a8368f66458af65f60e0559f Mon Sep 17 00:00:00 2001 From: Zeke Sikelianos Date: Fri, 14 Jan 2022 10:31:53 -0800 Subject: [PATCH 10/14] document how to run tests Signed-off-by: Zeke Sikelianos --- CONTRIBUTING.md | 33 +++++++++++++++++++++++++++++++++ 1 file changed, 33 insertions(+) diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 25aecfdd7a..9a578c3821 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -131,6 +131,39 @@ As much as possible, this is attempting to follow the [Standard Go Project Layou - `python/` - The Cog Python library. - `test-integration/` - High-level integration tests for Cog. +## Runnings tests + +To run the entire test suite: + +```sh +make test +``` + +To run just the Golang tests: + +```sh +make test-go +``` + +To run just the Python tests: + +```sh +make test-python +``` + +To stand up a server for one of the integration tests: + +```sh +make install +pip install -r requirements-dev.txt +make test +cd test-integration/test_integration/fixtures/file-project +cog build +docker run -p 5001:5000 --init --platform=linux/amd64 cog-file-project +``` + +Then visit [localhost:5001](http://localhost:5001) in your browser. + ## Publishing a release This project has a [GitHub Actions workflow](https://github.com/replicate/cog/blob/39cfc5c44ab81832886c9139ee130296f1585b28/.github/workflows/ci.yaml#L107) that uses [goreleaser](https://goreleaser.com/quick-start/#quick-start) to facilitate the process of publishing new releases. The release process is triggered by manually creating and pushing a new git tag. From e58e1b97ec65c517f04a791959a6fcef0df6b6d0 Mon Sep 17 00:00:00 2001 From: Ben Firshman Date: Fri, 14 Jan 2022 16:00:48 -0800 Subject: [PATCH 11/14] Add cog.BaseModel As a convenience for `from pydantic import BaseModel` Signed-off-by: Ben Firshman --- python/cog/__init__.py | 3 +++ python/tests/server/test_http_output.py | 3 +-- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/python/cog/__init__.py b/python/cog/__init__.py index 8c37d1c309..a56666511a 100644 --- a/python/cog/__init__.py +++ b/python/cog/__init__.py @@ -1,3 +1,5 @@ +from pydantic import BaseModel + from .predictor import BasePredictor from .types import File, Input, Path @@ -5,6 +7,7 @@ Predictor = BasePredictor __all__ = [ + "BaseModel", "BasePredictor", "File", "Input", diff --git a/python/tests/server/test_http_output.py b/python/tests/server/test_http_output.py index 92ae3ad499..39ac18656c 100644 --- a/python/tests/server/test_http_output.py +++ b/python/tests/server/test_http_output.py @@ -5,11 +5,10 @@ import numpy as np from PIL import Image -from pydantic import BaseModel import responses from responses.matchers import multipart_matcher -from cog import BasePredictor, Path, File +from cog import BaseModel, BasePredictor, Path, File from .test_http import make_client From 0bbe391979b444ce2eabb52ba52b9b796044da29 Mon Sep 17 00:00:00 2001 From: Zeke Sikelianos Date: Thu, 13 Jan 2022 22:32:39 -0800 Subject: [PATCH 12/14] document new Python API Signed-off-by: Zeke Sikelianos --- docs/python.md | 86 ++++++++++++++++++++++++++++++++++++++++---------- 1 file changed, 69 insertions(+), 17 deletions(-) diff --git a/docs/python.md b/docs/python.md index 379f5934b3..0ef6bfdc43 100644 --- a/docs/python.md +++ b/docs/python.md @@ -1,5 +1,19 @@ # Prediction interface reference +This document defines the API of the `cog` Python module, which is used to define the interface for running predictions on your model. + +Tip: Run [`cog init`](getting-started-own-model#initialization) to generate an annotated `predict.py` file that can be used as a starting point for setting up your model. + +## Contents + +- [`BasePredictor`](#basepredictor) + - [`Predictor.setup()`](#predictorsetup) + - [`Predictor.predict(**kwargs)`](#predictorpredictkwargs) +- [`Input(**kwargs)`](#inputkwargs) +- [`Output(BaseModel)`](#outputbasemodel) + +## `BasePredictor` + You define how Cog runs predictions on your model by defining a class that inherits from `BasePredictor`. It looks something like this: ```python @@ -22,23 +36,25 @@ class Predictor(BasePredictor): return output ``` -Tip: Run [`cog init`](getting-started-own-model#initialization) to generate an annotated `predict.py` file that can be used as a starting point for setting up your model. - -You need to override two functions: `setup()` and `predict()`. +Your Predictor class should define two methods: `setup()` and `predict()`. ### `Predictor.setup()` -Set up the model for prediction so multiple predictions run efficiently. Include any expensive one-off operations in here like loading trained models, instantiate data transformations, etc. +Prepare the model so multiple predictions run efficiently. + +Use this _optional_ method to include any expensive one-off operations in here like loading trained models, instantiate data transformations, etc. It's best not to download model weights or any other files in this function. You should bake these into the image when you build it. This means your model doesn't depend on any other system being available and accessible. It also means the Docker image ID becomes an immutable identifier for the precise model you're running, instead of the combination of the image ID and whatever files it might have downloaded. ### `Predictor.predict(**kwargs)` -Run a single prediction. This is where you call the model that was loaded during `setup()`, but you may also want to add pre- and post-processing code here. +Run a single prediction. + +This _required_ method is where you call the model that was loaded during `setup()`, but you may also want to add pre- and post-processing code here. -The `predict()` function takes an arbitrary list of named arguments, where each argument name must correspond to a `@cog.input()` annotation. +The `predict()` method takes an arbitrary list of named arguments, where each argument name must correspond to an [`Input()`](#inputkwargs) annotation. -`predict()` can output strings, numbers, `pathlib.Path` objects, or lists or dicts of those types. We are working on support for other types of output, but for now we recommend using base-64 encoded strings or `pathlib.Path`s for more complex outputs. +`predict()` can return strings, numbers, `pathlib.Path` objects, or lists or dicts of those types. You can also define a custom [`Output()`](#outputbasemodel) for more complex return types. #### Returning `pathlib.Path` objects @@ -47,22 +63,58 @@ If the output is a `pathlib.Path` object, that will be returned by the built-in To output `pathlib.Path` objects the file needs to exist, which means that you probably need to create a temporary file first. This file will automatically be deleted by Cog after it has been returned. For example: ```python -def predict(self, input): - output = do_some_processing(input) +def predict(self, image: Path = Input(description="Image to enlarge")) -> Path: + output = do_some_processing(image) out_path = Path(tempfile.mkdtemp()) / "my-file.txt" out_path.write_text(output) return out_path ``` -### `@cog.input(name, type, help, default=None, min=None, max=None, options=None)` +## `Input(**kwargs)` -The `@cog.input()` annotation describes a single input to the `predict()` function. The `name` must correspond to an argument name in `predict()`. +Use cog's `Input()` function to define each of the parameters in your `predict()` method: + +```py +class Predictor(BasePredictor): + def predict(self, + image: Path = Input(description="Image to enlarge"), + scale: float = Input(description="Factor to scale image by", default=1.5, gt=0, lt=10) + ) -> Path: +``` -It takes these arguments: +The `Input()` function takes these keyword arguments: -- `type`: Either `str`, `int`, `float`, `bool`, or `Path` (be sure to add the import, as in the example above). `Path` is used for files. For more complex inputs, save it to a file and use `Path`. -- `help`: A description of what to pass to this input for users of the model +- `description`: A description of what to pass to this input for users of the model. - `default`: A default value to set the input to. If this argument is not passed, the input is required. If it is explicitly set to `None`, the input is optional. -- `min`: A minimum value for `int` or `float` types. -- `max`: A maximum value for `int` or `float` types. -- `options`: A list of values to limit the input to. It can be used with `str`, `int`, and `float` inputs. +- `gt`: For `int` or `float` types, the value must be greater than this number. +- `ge`: For `int` or `float` types, the value must be greater than or equal to this number. +- `lt`: For `int` or `float` types, the value must be less than this number. +- `le`: For `int` or `float` types, the value must be less than or equal to this number. +- `choices`: A list of possible values for this input. + +Each parameter of the `predict()` method must be annotated with a type. The supported types are: + +- `str`: a string +- `int`: an integer +- `float`: a floating point number +- `bool`: a boolean +- `cog.File`: a file-like object representing a file +- `cog.Path`: a path to a file on disk + +## `Output(BaseModel)` + +Your `predict()` method can return a simple data type like a string or a number, or a more complex object with multiple values. + +You can optionally use cog's `Output()` function to define the object returned by your `predict()` method: + +```py +from cog import BasePredictor, BaseModel + +class Output(BaseModel): + text: str + file: File + +class Predictor(BasePredictor): + def predict(self) -> Output: + return Output(text="hello", file=io.StringIO("hello")) +``` From b212209196771d272c22b4cd92d262c8964acd15 Mon Sep 17 00:00:00 2001 From: Zeke Sikelianos Date: Tue, 18 Jan 2022 17:21:18 -0800 Subject: [PATCH 13/14] add docstrings to new python functions Signed-off-by: Zeke Sikelianos --- python/cog/json.py | 3 ++- python/cog/predictor.py | 30 +++++++++++++++++++++++++++--- python/cog/server/http.py | 3 +++ python/tests/server/test_http.py | 7 +++---- 4 files changed, 35 insertions(+), 8 deletions(-) diff --git a/python/cog/json.py b/python/cog/json.py index e179800ae4..cbef9493c3 100644 --- a/python/cog/json.py +++ b/python/cog/json.py @@ -1,5 +1,6 @@ from enum import Enum import io +from typing import Any from pydantic import BaseModel @@ -13,7 +14,7 @@ has_numpy = False -def encode_json(obj, upload_file): +def encode_json(obj: Any, upload_file) -> Any: """ Returns a JSON-compatible version of the object. It will encode any Pydantic models and custom types. diff --git a/python/cog/predictor.py b/python/cog/predictor.py index 08f087897f..7d6a266705 100644 --- a/python/cog/predictor.py +++ b/python/cog/predictor.py @@ -5,7 +5,6 @@ import inspect import os.path from pathlib import Path -import typing from pydantic import create_model, BaseModel from pydantic.fields import FieldInfo @@ -19,11 +18,15 @@ class BasePredictor(ABC): def setup(self): - pass + """ + An optional method to prepare the model so multiple predictions run efficiently. + """ @abstractmethod def predict(self, **kwargs): - pass + """ + Run a single prediction on the model. + """ def run_prediction(predictor, inputs, cleanup_functions): @@ -38,6 +41,10 @@ def run_prediction(predictor, inputs, cleanup_functions): def load_predictor(): + """ + Reads cog.yaml and constructs an instance of the user-defined Predictor class. + """ + # Assumes the working directory is /src config_path = os.path.abspath("cog.yaml") try: @@ -64,6 +71,19 @@ def load_predictor(): def get_input_type(predictor: BasePredictor): + """ + Creates a Pydantic Input model from the arguments of a Predictor's predict() method. + + class Predictor(BasePredictor): + def predict(self, text: str): + ... + + programmatically creates a model like this: + + class Input(BaseModel): + text: str + """ + signature = inspect.signature(predictor.predict) create_model_kwargs = {} @@ -107,6 +127,10 @@ def get_input_type(predictor: BasePredictor): def get_output_type(predictor: BasePredictor): + """ + Creates a Pydantic Output model from the return type annotation of a Predictor's predict() method. + """ + signature = inspect.signature(predictor.predict) if signature.return_annotation is inspect.Signature.empty: OutputType = Literal[None] diff --git a/python/cog/server/http.py b/python/cog/server/http.py index 3ed9b734af..a83e5642e8 100644 --- a/python/cog/server/http.py +++ b/python/cog/server/http.py @@ -50,6 +50,9 @@ class Request(BaseModel): # The signature of this function is used by FastAPI to generate the schema. # The function body is not used to generate the schema. def predict(request: Request = Body(default=None)): + """ + Run a single prediction on the model. + """ if request is None or request.input is None: output = predictor.predict() else: diff --git a/python/tests/server/test_http.py b/python/tests/server/test_http.py index c9aa37a904..653a754d15 100644 --- a/python/tests/server/test_http.py +++ b/python/tests/server/test_http.py @@ -71,6 +71,7 @@ def predict( "/predictions": { "post": { "summary": "Predict", + "description": "Run a single prediction on the model.", "operationId": "predict_predictions_post", "requestBody": { "content": { @@ -152,10 +153,7 @@ def predict( "choices": {"$ref": "#/components/schemas/choices"}, }, }, - "Output": { - "title": "Output", - "type": "string", - }, + "Output": {"title": "Output", "type": "string"}, "Request": { "title": "Request", "type": "object", @@ -243,6 +241,7 @@ def predict( "/predictions": { "post": { "summary": "Predict", + "description": "Run a single prediction on the model.", "operationId": "predict_predictions_post", "requestBody": { "content": { From 1fdff8272b8ed5de38c003f42ef01945e942aa55 Mon Sep 17 00:00:00 2001 From: Zeke Sikelianos Date: Fri, 21 Jan 2022 10:23:49 -0800 Subject: [PATCH 14/14] update `cog init` predictor template Signed-off-by: Zeke Sikelianos --- pkg/cli/init-templates/predict.py | 28 ++++++++++++++++------------ 1 file changed, 16 insertions(+), 12 deletions(-) diff --git a/pkg/cli/init-templates/predict.py b/pkg/cli/init-templates/predict.py index c59fce56e6..2c5a12f371 100644 --- a/pkg/cli/init-templates/predict.py +++ b/pkg/cli/init-templates/predict.py @@ -1,18 +1,22 @@ # Prediction interface for Cog ⚙️ -# Reference: https://github.com/replicate/cog/blob/main/docs/python.md +# https://github.com/replicate/cog/blob/main/docs/python.md -import cog -# import torch +from cog import BasePredictor, Input, Path -class Predictor(cog.Predictor): + +class Predictor(BasePredictor): def setup(self): - """Load the model into memory to make running multiple predictions efficient""" - # self.model = torch.load("./weights.pth") + """Load the model into memory to make running multiple predictions efficient""" + # self.model = torch.load("./weights.pth") - @cog.input("image", type=cog.Path, help="Grayscale input image") - @cog.input("scale", type=float, default=1.5, help="Factor to scale image by") - def predict(self, image): + def predict( + self, + input: Path = Input(title="Grayscale input image"), + scale: float = Input( + title="Factor to scale image by", gt=0, lt=10, default=1.5 + ), + ) -> Path: """Run a single prediction on the model""" - # processed_input = preprocess(image) - # output = self.model(processed_input) - # return post_processing(output) + # processed_input = preprocess(input) + # output = self.model(processed_input, scale) + # return postprocess(output)