Skip to content

Commit 5a9de9b

Browse files
committed
Update examples for new Cog types
replicate/cog#378
1 parent 1680c8e commit 5a9de9b

File tree

4 files changed

+29
-28
lines changed

4 files changed

+29
-28
lines changed

blur/predict.py

Lines changed: 8 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,18 @@
11
import tempfile
2-
from pathlib import Path
32

3+
from cog import BasePredictor, Input, Path
44
from PIL import Image, ImageFilter
5-
import cog
65

76

8-
class Predictor(cog.Predictor):
9-
def setup(self):
10-
pass
11-
12-
@cog.input("input", type=Path, help="Input image")
13-
@cog.input("blur", type=float, help="Blur radius", default=5)
14-
def predict(self, input, blur):
7+
class Predictor(BasePredictor):
8+
def predict(
9+
self,
10+
image: Path = Input(description="Input image"),
11+
blur: float = Input(description="Blur radius", default=5),
12+
) -> Path:
1513
if blur == 0:
1614
return input
17-
im = Image.open(str(input))
15+
im = Image.open(str(image))
1816
im = im.filter(ImageFilter.BoxBlur(blur))
1917
out_path = Path(tempfile.mkdtemp()) / "out.png"
2018
im.save(str(out_path))

hello-world/predict.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
1-
import cog
1+
from cog import BasePredictor, Input
22

3-
class Predictor(cog.Predictor):
3+
4+
class Predictor(BasePredictor):
45
def setup(self):
56
self.prefix = "hello"
67

7-
@cog.input("input", type=str, help="Text that will get prefixed by 'hello '")
8-
def predict(self, input):
9-
return f"\n\n{self.prefix} {input}\n\n"
8+
def predict(self, text: str = Input(description="Text to prefix with 'hello '")) -> str:
9+
return self.prefix + " " + text

resnet/cog.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,4 +3,4 @@ build:
33
python_packages:
44
- "pillow==8.3.1"
55
- "tensorflow==2.5.0"
6-
predict: "predict.py:ResNetPredictor"
6+
predict: "predict.py:Predictor"

resnet/predict.py

Lines changed: 15 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,26 +1,29 @@
1-
import cog
2-
from pathlib import Path
3-
from tensorflow.keras.applications.resnet50 import ResNet50
4-
from tensorflow.keras.preprocessing import image
5-
from tensorflow.keras.applications.resnet50 import preprocess_input, decode_predictions
1+
from typing import Any
2+
63
import numpy as np
4+
from cog import BasePredictor, Input, Path
5+
from tensorflow.keras.applications.resnet50 import (
6+
ResNet50,
7+
decode_predictions,
8+
preprocess_input,
9+
)
10+
from tensorflow.keras.preprocessing import image as keras_image
711

812

9-
class ResNetPredictor(cog.Predictor):
13+
class Predictor(BasePredictor):
1014
def setup(self):
1115
"""Load the model into memory to make running multiple predictions efficient"""
12-
self.model = ResNet50(weights='resnet50_weights_tf_dim_ordering_tf_kernels.h5')
16+
self.model = ResNet50(weights="resnet50_weights_tf_dim_ordering_tf_kernels.h5")
1317

1418
# Define the arguments and types the model takes as input
15-
@cog.input("input", type=Path, help="Image to classify")
16-
def predict(self, input):
19+
def predict(self, image: Path = Input(description="Image to classify")) -> Any:
1720
"""Run a single prediction on the model"""
1821
# Preprocess the image
19-
img = image.load_img(input, target_size=(224, 224))
20-
x = image.img_to_array(img)
22+
img = keras_image.load_img(image, target_size=(224, 224))
23+
x = keras_image.img_to_array(img)
2124
x = np.expand_dims(x, axis=0)
2225
x = preprocess_input(x)
2326
# Run the prediction
2427
preds = self.model.predict(x)
2528
# Return the top 3 predictions
26-
return str(decode_predictions(preds, top=3)[0])
29+
return decode_predictions(preds, top=3)[0]

0 commit comments

Comments
 (0)