Skip to content

Commit 0f0162e

Browse files
committed
Rename Predictor to BasePredictor
Signed-off-by: Ben Firshman <[email protected]>
1 parent 0fccc84 commit 0f0162e

File tree

24 files changed

+86
-91
lines changed

24 files changed

+86
-91
lines changed

README.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,10 +27,10 @@ predict: "predict.py:Predictor"
2727
And define how predictions are run on your model with `predict.py`:
2828

2929
```python
30-
from cog import Predictor, Input, Path
30+
from cog import BasePredictor, Input, Path
3131
import torch
3232
33-
class ColorizationPredictor(Predictor):
33+
class Predictor(BasePredictor):
3434
def setup(self):
3535
"""Load the model into memory to make running multiple predictions efficient"""
3636
self.model = torch.load("./weights.pth")

docs/getting-started-own-model.md

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -66,11 +66,10 @@ With `cog.yaml`, you can also install system packages and other things. [Take a
6666
The next step is to update `predict.py` to define the interface for running predictions on your model. The `predict.py` generated by `cog init` looks something like this:
6767
6868
```python
69-
import cog
70-
from cog import Path, Input
69+
from cog import BasePredictor, Path, Input
7170
import torch
7271
73-
class Predictor(cog.Predictor):
72+
class Predictor(BasePredictor):
7473
def setup(self):
7574
"""Load the model into memory to make running multiple predictions efficient"""
7675
self.net = torch.load("weights.pth")

docs/getting-started.md

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -60,15 +60,14 @@ Then, we need to write some code to describe how predictions are run on the mode
6060

6161
```python
6262
from typing import Any
63-
import cog
64-
from cog import Input, Path
63+
from cog import BasePredictor, Input, Path
6564
from tensorflow.keras.applications.resnet50 import ResNet50
6665
from tensorflow.keras.preprocessing import image as keras_image
6766
from tensorflow.keras.applications.resnet50 import preprocess_input, decode_predictions
6867
import numpy as np
6968
7069
71-
class ResNetPredictor(cog.Predictor):
70+
class ResNetPredictor(BasePredictor):
7271
def setup(self):
7372
"""Load the model into memory to make running multiple predictions efficient"""
7473
self.model = ResNet50(weights='resnet50_weights_tf_dim_ordering_tf_kernels.h5')

docs/python.md

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,12 @@
11
# Prediction interface reference
22

3-
You define how Cog runs predictions on your model by defining a class that inherits from `cog.Predictor`. It looks something like this:
3+
You define how Cog runs predictions on your model by defining a class that inherits from `BasePredictor`. It looks something like this:
44

55
```python
6-
import cog
7-
from cog import Path, Input
6+
from cog import BasePredictor, Path, Input
87
import torch
98

10-
class ImageScalingPredictor(cog.Predictor):
9+
class Predictor(BasePredictor):
1110
def setup(self):
1211
"""Load the model into memory to make running multiple predictions efficient"""
1312
self.model = torch.load("weights.pth")

docs/yaml.md

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ build:
1212
system_packages:
1313
- "ffmpeg"
1414
- "libavcodec-dev"
15-
predict: "predict.py:JazzSoloComposerPredictor"
15+
predict: "predict.py:Predictor"
1616
```
1717
1818
Tip: Run [`cog init`](getting-started-own-model#initialization) to generate an annotated `cog.yaml` file that can be used as a starting point for setting up your model.
@@ -102,12 +102,12 @@ If you don't provide this, a name will be generated from the directory name.
102102

103103
## `predict`
104104

105-
The pointer to the `cog.Predictor` object in your code, which defines how predictions are run on your model.
105+
The pointer to the `Predictor` object in your code, which defines how predictions are run on your model.
106106

107107
For example:
108108

109109
```yaml
110-
predict: "predict.py:HotdogPredictor"
110+
predict: "predict.py:Predictor"
111111
```
112112

113113
See [the Python API documentation for more information](python.md).

pkg/config/config.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -122,7 +122,7 @@ func (c *Config) ValidateAndCompleteConfig() error {
122122
}
123123
if c.Predict != "" {
124124
if len(strings.Split(c.Predict, ".py:")) != 2 {
125-
return fmt.Errorf("'predict' in cog.yaml must be in the form 'predict.py:PredictorClass")
125+
return fmt.Errorf("'predict' in cog.yaml must be in the form 'predict.py:Predictor")
126126
}
127127
}
128128

python/cog/__init__.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,11 @@
1-
from .predictor import Predictor
1+
from .predictor import BasePredictor
22
from .types import File, Input, Path
33

4+
# Backwards compatibility. Will be deprecated before 1.0.0.
5+
Predictor = BasePredictor
46

57
__all__ = [
8+
"BasePredictor",
69
"File",
710
"Input",
811
"Path",

python/cog/predictor.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
from .types import Input
1818

1919

20-
class Predictor(ABC):
20+
class BasePredictor(ABC):
2121
def setup(self):
2222
pass
2323

@@ -63,7 +63,7 @@ def load_predictor():
6363
return predictor_class()
6464

6565

66-
def get_input_type(predictor: Predictor):
66+
def get_input_type(predictor: BasePredictor):
6767
signature = inspect.signature(predictor.predict)
6868
create_model_kwargs = {}
6969

@@ -106,7 +106,7 @@ def get_input_type(predictor: Predictor):
106106
return create_model("Input", **create_model_kwargs)
107107

108108

109-
def get_output_type(predictor: Predictor):
109+
def get_output_type(predictor: BasePredictor):
110110
signature = inspect.signature(predictor.predict)
111111
if signature.return_annotation is inspect.Signature.empty:
112112
OutputType = Literal[None]

python/cog/server/http.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,13 +10,13 @@
1010

1111
from ..files import upload_file
1212
from ..json import encode_json
13-
from ..predictor import Predictor, get_input_type, get_output_type, load_predictor
13+
from ..predictor import BasePredictor, get_input_type, get_output_type, load_predictor
1414
from ..response import Status, get_response_type
1515

1616
logger = logging.getLogger("cog")
1717

1818

19-
def create_app(predictor: Predictor) -> FastAPI:
19+
def create_app(predictor: BasePredictor) -> FastAPI:
2020
app = FastAPI(
2121
title="Cog", # TODO: mention model name?
2222
# version=None # TODO

python/cog/server/redis_queue.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,9 +14,8 @@
1414
import requests
1515

1616
from .redis_log_capture import capture_log
17-
from ..predictor import Predictor, get_input_type, load_predictor
17+
from ..predictor import BasePredictor, get_input_type, load_predictor
1818
from ..json import encode_json
19-
from ..predictor import Predictor, load_predictor
2019

2120

2221
class timeout:
@@ -53,7 +52,7 @@ class RedisQueueWorker:
5352

5453
def __init__(
5554
self,
56-
predictor: Predictor,
55+
predictor: BasePredictor,
5756
redis_host: str,
5857
redis_port: int,
5958
input_queue: str,

0 commit comments

Comments
 (0)