55 changes: 38 additions & 17 deletions predictionguard/client.py
@@ -1,7 +1,7 @@
import os

import requests
from typing import Optional
from typing import Optional, Union

from .src.audio import Audio
from .src.chat import Chat
@@ -30,11 +30,15 @@ class PredictionGuard:
"""PredictionGuard provides access the Prediction Guard API."""

def __init__(
self, api_key: Optional[str] = None, url: Optional[str] = None
self,
api_key: Optional[str] = None,
url: Optional[str] = None,
timeout: Optional[Union[int, float]] = None
) -> None:
"""
:param api_key: api_key represents the Prediction Guard API key.
:param url: url represents the transport and domain:port of the Prediction Guard API.
:param timeout: request timeout in seconds.
"""

# Get the access api_key.
@@ -56,50 +60,67 @@ def __init__(
url = "https://api.predictionguard.com"
self.url = url

if not timeout:
timeout = os.environ.get("TIMEOUT")
if not timeout:
timeout = None
if timeout:
try:
timeout = float(timeout)
except ValueError:
raise ValueError(
"Timeout must be of type integer or float, not %s." % (type(timeout).__name__,)
)
except TypeError:
raise TypeError(
"Timeout should be of type integer or float, not %s." % (type(timeout).__name__,)
)
self.timeout = timeout

# Connect to Prediction Guard and set the access api_key.
self._connect_client()

# Pass Prediction Guard class variables to inner classes
self.chat: Chat = Chat(self.api_key, self.url)
self.chat: Chat = Chat(self.api_key, self.url, self.timeout)
"""Chat generates chat completions based on a conversation history"""

self.completions: Completions = Completions(self.api_key, self.url)
self.completions: Completions = Completions(self.api_key, self.url, self.timeout)
"""Completions generates text completions based on the provided input"""

self.embeddings: Embeddings = Embeddings(self.api_key, self.url)
self.embeddings: Embeddings = Embeddings(self.api_key, self.url, self.timeout)
"""Embedding generates chat completions based on a conversation history."""

self.audio: Audio = Audio(self.api_key, self.url)
self.audio: Audio = Audio(self.api_key, self.url, self.timeout)
"""Audio allows for the transcription of audio files."""

self.documents: Documents = Documents(self.api_key, self.url)
self.documents: Documents = Documents(self.api_key, self.url, self.timeout)
"""Documents allows you to extract text from various document file types."""

self.rerank: Rerank = Rerank(self.api_key, self.url)
self.rerank: Rerank = Rerank(self.api_key, self.url, self.timeout)
"""Rerank sorts text inputs by semantic relevance to a specified query."""

self.translate: Translate = Translate(self.api_key, self.url)
self.translate: Translate = Translate(self.api_key, self.url, self.timeout)
"""Translate converts text from one language to another."""

self.factuality: Factuality = Factuality(self.api_key, self.url)
self.factuality: Factuality = Factuality(self.api_key, self.url, self.timeout)
"""Factuality checks the factuality of a given text compared to a reference."""

self.toxicity: Toxicity = Toxicity(self.api_key, self.url)
self.toxicity: Toxicity = Toxicity(self.api_key, self.url, self.timeout)
"""Toxicity checks the toxicity of a given text."""

self.pii: Pii = Pii(self.api_key, self.url)
self.pii: Pii = Pii(self.api_key, self.url, self.timeout)
"""Pii replaces personal information such as names, SSNs, and emails in a given text."""

self.injection: Injection = Injection(self.api_key, self.url)
self.injection: Injection = Injection(self.api_key, self.url, self.timeout)
"""Injection detects potential prompt injection attacks in a given prompt."""

self.tokenize: Tokenize = Tokenize(self.api_key, self.url)
self.tokenize: Tokenize = Tokenize(self.api_key, self.url, self.timeout)
"""Tokenize generates tokens for input text."""

self.detokenize: Detokenize = Detokenize(self.api_key, self.url)
self.detokenize: Detokenize = Detokenize(self.api_key, self.url, self.timeout)
"""Detokenizes generates text for input tokens."""

self.models: Models = Models(self.api_key, self.url)
self.models: Models = Models(self.api_key, self.url, self.timeout)
"""Models lists all of the models available in the Prediction Guard API."""

def _connect_client(self) -> None:
@@ -112,7 +133,7 @@ def _connect_client(self) -> None:
}

# Try listing models to make sure we can connect.
response = requests.request("GET", self.url + "/completions", headers=headers)
response = requests.request("GET", self.url + "/completions", headers=headers, timeout=self.timeout)

# If the connection was unsuccessful, raise an exception.
if response.status_code == 200:
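A minimal usage sketch of the new timeout handling (the top-level import path and the api_key value are assumptions for illustration, not part of this diff):

import os

from predictionguard import PredictionGuard  # assumed public import path

# Explicit per-client timeout: every request made by the sub-clients
# (chat, completions, embeddings, audio, ...) is sent with timeout=30 seconds.
client = PredictionGuard(api_key="<your api key>", timeout=30)

# Alternatively, rely on the TIMEOUT environment variable read in __init__;
# the value is converted with float(), so "2.5" is accepted and a
# non-numeric value raises ValueError.
os.environ["TIMEOUT"] = "2.5"
client = PredictionGuard(api_key="<your api key>")

# With neither the argument nor the variable set, self.timeout stays None and
# requests waits indefinitely, matching the previous behavior.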
10 changes: 6 additions & 4 deletions predictionguard/src/audio.py
@@ -39,16 +39,18 @@ class Audio:
))
"""

def __init__(self, api_key, url):
def __init__(self, api_key, url, timeout):
self.api_key = api_key
self.url = url
self.timeout = timeout

self.transcriptions: AudioTranscriptions = AudioTranscriptions(self.api_key, self.url)
self.transcriptions: AudioTranscriptions = AudioTranscriptions(self.api_key, self.url, self.timeout)

class AudioTranscriptions:
def __init__(self, api_key, url):
def __init__(self, api_key, url, timeout):
self.api_key = api_key
self.url = url
self.timeout = timeout

def create(
self,
@@ -164,7 +166,7 @@ def _transcribe_audio(
}

response = requests.request(
"POST", self.url + "/audio/transcriptions", headers=headers, files=files, data=data
"POST", self.url + "/audio/transcriptions", headers=headers, files=files, data=data, timeout=self.timeout
)

# If the request was successful, return the response.
21 changes: 12 additions & 9 deletions predictionguard/src/chat.py
@@ -66,17 +66,19 @@ class Chat:
))
"""

def __init__(self, api_key, url):
def __init__(self, api_key, url, timeout):
self.api_key = api_key
self.url = url
self.timeout = timeout

self.completions: ChatCompletions = ChatCompletions(self.api_key, self.url)
self.completions: ChatCompletions = ChatCompletions(self.api_key, self.url, self.timeout)


class ChatCompletions:
def __init__(self, api_key, url):
def __init__(self, api_key, url, timeout):
self.api_key = api_key
self.url = url
self.timeout = timeout

def create(
self,
@@ -192,9 +194,9 @@ def _generate_chat(
Function to generate a single chat response.
"""

def return_dict(url, headers, payload):
def return_dict(url, headers, payload, timeout):
response = requests.request(
"POST", url + "/chat/completions", headers=headers, data=payload
"POST", url + "/chat/completions", headers=headers, data=payload, timeout=timeout
)
# If the request was successful, return the response.
if response.status_code == 200:
@@ -215,12 +217,13 @@ def return_dict(url, headers, payload):
pass
raise ValueError("Could not make prediction. " + err)

def stream_generator(url, headers, payload, stream):
def stream_generator(url, headers, payload, stream, timeout):
with requests.post(
url + "/chat/completions",
headers=headers,
data=payload,
stream=stream,
timeout=timeout,
) as response:
response.raise_for_status()

@@ -356,10 +359,10 @@ def stream_generator(url, headers, payload, stream):
payload = json.dumps(payload_dict)

if stream:
return stream_generator(self.url, headers, payload, stream)
return stream_generator(self.url, headers, payload, stream, self.timeout)

else:
return return_dict(self.url, headers, payload)
return return_dict(self.url, headers, payload, self.timeout)

def list_models(self, capability: Optional[str] = "chat-completion") -> List[str]:
# Get the list of current models.
@@ -376,7 +379,7 @@ def list_models(self, capability: Optional[str] = "chat-completion") -> List[str
else:
model_path = "/models/" + capability

response = requests.request("GET", self.url + model_path, headers=headers)
response = requests.request("GET", self.url + model_path, headers=headers, timeout=self.timeout)

response_list = []
for model in response.json()["data"]:
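A note on how the timeout behaves in stream_generator: requests treats timeout as a connect limit plus a read limit rather than a total deadline, so with stream=True the read timeout caps the silence between received chunks, not the duration of the whole stream. The sketch below uses requests directly against a public test endpoint (not the Prediction Guard API) to show the two limits; a single number, as this client passes, applies to both, while a (connect, read) tuple sets them independently.

import requests

try:
    # Up to 3.05 s to establish the connection, then up to 5 s per read.
    # httpbin's /delay/10 stalls for 10 s, so this raises ReadTimeout.
    response = requests.post("https://httpbin.org/delay/10", timeout=(3.05, 5))
except requests.exceptions.ConnectTimeout:
    print("could not connect within 3.05 s")
except requests.exceptions.ReadTimeout:
    print("connected, but the server was silent for more than 5 s")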
16 changes: 9 additions & 7 deletions predictionguard/src/completions.py
@@ -41,9 +41,10 @@ class Completions:
))
"""

def __init__(self, api_key, url):
def __init__(self, api_key, url, timeout):
self.api_key = api_key
self.url = url
self.timeout = timeout

def create(
self,
@@ -132,9 +133,9 @@ def _generate_completion(
Function to generate a single completion.
"""

def return_dict(url, headers, payload):
def return_dict(url, headers, payload, timeout):
response = requests.request(
"POST", url + "/completions", headers=headers, data=payload
"POST", url + "/completions", headers=headers, data=payload, timeout=timeout
)
# If the request was successful, return the response.
if response.status_code == 200:
@@ -155,12 +156,13 @@ def return_dict(url, headers, payload):
pass
raise ValueError("Could not make prediction. " + err)

def stream_generator(url, headers, payload, stream):
def stream_generator(url, headers, payload, stream, timeout):
with requests.post(
url + "/completions",
headers=headers,
data=payload,
stream=stream,
timeout=timeout
) as response:
response.raise_for_status()

@@ -215,10 +217,10 @@ def stream_generator(url, headers, payload, stream):
payload = json.dumps(payload_dict)

if stream:
return stream_generator(self.url, headers, payload, stream)
return stream_generator(self.url, headers, payload, stream, self.timeout)

else:
return return_dict(self.url, headers, payload)
return return_dict(self.url, headers, payload, self.timeout)

def list_models(self) -> List[str]:
# Get the list of current models.
@@ -228,7 +230,7 @@ def list_models(self) -> List[str]:
"User-Agent": "Prediction Guard Python Client: " + __version__,
}

response = requests.request("GET", self.url + "/models/completion", headers=headers)
response = requests.request("GET", self.url + "/models/completion", headers=headers, timeout=self.timeout)

response_list = []
for model in response.json()["data"]:
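Since neither return_dict nor stream_generator catches timeout errors, callers that set a timeout should expect requests.exceptions.Timeout (ConnectTimeout or ReadTimeout) to propagate out of the SDK. A rough sketch, assuming the public import path and that completions.create() accepts model and prompt keyword arguments (both values are placeholders):

import requests

from predictionguard import PredictionGuard  # assumed public import path

client = PredictionGuard(api_key="<your api key>", timeout=5)

try:
    result = client.completions.create(
        model="<model name>",   # placeholder
        prompt="Hello, world!",
    )
except requests.exceptions.Timeout as exc:
    # Raised by requests when the 5-second limit is exceeded; the SDK does
    # not catch it, so the original exception reaches the caller.
    print(f"Prediction Guard request timed out: {exc}")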
7 changes: 4 additions & 3 deletions predictionguard/src/detokenize.py
@@ -41,9 +41,10 @@ class Detokenize:
"""


def __init__(self, api_key, url):
def __init__(self, api_key, url, timeout):
self.api_key = api_key
self.url = url
self.timeout = timeout

def create(self, model: str, tokens: List[int]) -> Dict[str, Any]:
"""
@@ -85,7 +86,7 @@ def _create_tokens(self, model, tokens):
payload = json.dumps(payload)

response = requests.request(
"POST", self.url + "/detokenize", headers=headers, data=payload
"POST", self.url + "/detokenize", headers=headers, data=payload, timeout=self.timeout
)

if response.status_code == 200:
@@ -114,7 +115,7 @@ def list_models(self):
"User-Agent": "Prediction Guard Python Client: " + __version__
}

response = requests.request("GET", self.url + "/models/detokenize", headers=headers)
response = requests.request("GET", self.url + "/models/detokenize", headers=headers, timeout=self.timeout)

response_list = []
for model in response.json()["data"]:
10 changes: 6 additions & 4 deletions predictionguard/src/documents.py
@@ -37,16 +37,18 @@ class Documents:
))
"""

def __init__(self, api_key, url):
def __init__(self, api_key, url, timeout):
self.api_key = api_key
self.url = url
self.timeout = timeout

self.extract: DocumentsExtract = DocumentsExtract(self.api_key, self.url)
self.extract: DocumentsExtract = DocumentsExtract(self.api_key, self.url, self.timeout)

class DocumentsExtract:
def __init__(self, api_key, url):
def __init__(self, api_key, url, timeout):
self.api_key = api_key
self.url = url
self.timeout = timeout

def create(
self,
@@ -117,7 +119,7 @@ def _extract_documents(

response = requests.request(
"POST", self.url + "/documents/extract",
headers=headers, files=files, data=data
headers=headers, files=files, data=data, timeout=self.timeout
)

# If the request was successful, return the response.
7 changes: 4 additions & 3 deletions predictionguard/src/embeddings.py
@@ -46,9 +46,10 @@ class Embeddings:
))
"""

def __init__(self, api_key, url):
def __init__(self, api_key, url, timeout):
self.api_key = api_key
self.url = url
self.timeout = timeout

def create(
self,
@@ -166,7 +167,7 @@ def _generate_embeddings(self, model, input, truncate, truncation_direction):

payload = json.dumps(payload_dict)
response = requests.request(
"POST", self.url + "/embeddings", headers=headers, data=payload
"POST", self.url + "/embeddings", headers=headers, data=payload, timeout=self.timeout
)

# If the request was successful, return the response.
@@ -204,7 +205,7 @@ def list_models(self, capability: Optional[str] = "embedding") -> List[str]:
else:
model_path = "/models/" + capability

response = requests.request("GET", self.url + model_path, headers=headers)
response = requests.request("GET", self.url + model_path, headers=headers, timeout=self.timeout)

response_list = []
for model in response.json()["data"]:
5 changes: 3 additions & 2 deletions predictionguard/src/factuality.py
@@ -41,9 +41,10 @@ class Factuality:
))
"""

def __init__(self, api_key, url):
def __init__(self, api_key, url, timeout):
self.api_key = api_key
self.url = url
self.timeout = timeout

def check(self, reference: str, text: str) -> Dict[str, Any]:
"""
@@ -72,7 +73,7 @@ def _generate_score(self, reference, text):
payload_dict = {"reference": reference, "text": text}
payload = json.dumps(payload_dict)
response = requests.request(
"POST", self.url + "/factuality", headers=headers, data=payload
"POST", self.url + "/factuality", headers=headers, data=payload, timeout=self.timeout
)

# If the request was successful, return the response.
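The factuality logic itself is unchanged; for completeness, a small sketch of exercising it with the new client-level timeout, based on the check(reference, text) signature and the /factuality payload shown above (api_key and the example texts are placeholders):

from predictionguard import PredictionGuard  # assumed public import path

client = PredictionGuard(api_key="<your api key>", timeout=10)

# check() posts {"reference": ..., "text": ...} to /factuality with the
# client's timeout and returns the parsed response when the call succeeds.
score = client.factuality.check(
    reference="The sky is blue.",
    text="The sky is green.",
)
print(score)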