diff --git a/predictionguard/client.py b/predictionguard/client.py
index 293c3f2..6697a71 100644
--- a/predictionguard/client.py
+++ b/predictionguard/client.py
@@ -1,7 +1,7 @@
 import os
 
 import requests
-from typing import Optional
+from typing import Optional, Union
 
 from .src.audio import Audio
 from .src.chat import Chat
@@ -30,11 +30,15 @@ class PredictionGuard:
     """PredictionGuard provides access the Prediction Guard API."""
 
     def __init__(
-        self, api_key: Optional[str] = None, url: Optional[str] = None
+        self,
+        api_key: Optional[str] = None,
+        url: Optional[str] = None,
+        timeout: Optional[Union[int, float]] = None
     ) -> None:
         """
         :param api_key: api_key represents PG api key.
         :param url: url represents the transport and domain:port
+        :param timeout: request timeout in seconds.
         """
 
         # Get the access api_key.
@@ -56,50 +60,67 @@ def __init__(
             url = "https://api.predictionguard.com"
         self.url = url
 
+        if not timeout:
+            timeout = os.environ.get("TIMEOUT")
+        if not timeout:
+            timeout = None
+        if timeout:
+            try:
+                timeout = float(timeout)
+            except ValueError:
+                raise ValueError(
+                    "Timeout must be of type integer or float, not %s." % (type(timeout).__name__,)
+                )
+            except TypeError:
+                raise TypeError(
+                    "Timeout should be of type integer or float, not %s." % (type(timeout).__name__,)
+                )
+        self.timeout = timeout
+
         # Connect to Prediction Guard and set the access api_key.
         self._connect_client()
 
         # Pass Prediction Guard class variables to inner classes
-        self.chat: Chat = Chat(self.api_key, self.url)
+        self.chat: Chat = Chat(self.api_key, self.url, self.timeout)
         """Chat generates chat completions based on a conversation history"""
 
-        self.completions: Completions = Completions(self.api_key, self.url)
+        self.completions: Completions = Completions(self.api_key, self.url, self.timeout)
         """Completions generates text completions based on the provided input"""
 
-        self.embeddings: Embeddings = Embeddings(self.api_key, self.url)
+        self.embeddings: Embeddings = Embeddings(self.api_key, self.url, self.timeout)
         """Embedding generates chat completions based on a conversation history."""
 
-        self.audio: Audio = Audio(self.api_key, self.url)
+        self.audio: Audio = Audio(self.api_key, self.url, self.timeout)
         """Audio allows for the transcription of audio files."""
 
-        self.documents: Documents = Documents(self.api_key, self.url)
+        self.documents: Documents = Documents(self.api_key, self.url, self.timeout)
         """Documents allows you to extract text from various document file types."""
 
-        self.rerank: Rerank = Rerank(self.api_key, self.url)
+        self.rerank: Rerank = Rerank(self.api_key, self.url, self.timeout)
         """Rerank sorts text inputs by semantic relevance to a specified query."""
 
-        self.translate: Translate = Translate(self.api_key, self.url)
+        self.translate: Translate = Translate(self.api_key, self.url, self.timeout)
         """Translate converts text from one language to another."""
 
-        self.factuality: Factuality = Factuality(self.api_key, self.url)
+        self.factuality: Factuality = Factuality(self.api_key, self.url, self.timeout)
         """Factuality checks the factuality of a given text compared to a reference."""
 
-        self.toxicity: Toxicity = Toxicity(self.api_key, self.url)
+        self.toxicity: Toxicity = Toxicity(self.api_key, self.url, self.timeout)
         """Toxicity checks the toxicity of a given text."""
 
-        self.pii: Pii = Pii(self.api_key, self.url)
+        self.pii: Pii = Pii(self.api_key, self.url, self.timeout)
         """Pii replaces personal information such as names, SSNs, and emails in a given text."""
 
-        self.injection: Injection = Injection(self.api_key, self.url)
+        self.injection: Injection = Injection(self.api_key, self.url, self.timeout)
         """Injection detects potential prompt injection attacks in a given prompt."""
 
-        self.tokenize: Tokenize = Tokenize(self.api_key, self.url)
+        self.tokenize: Tokenize = Tokenize(self.api_key, self.url, self.timeout)
         """Tokenize generates tokens for input text."""
 
-        self.detokenize: Detokenize = Detokenize(self.api_key, self.url)
+        self.detokenize: Detokenize = Detokenize(self.api_key, self.url, self.timeout)
         """Detokenizes generates text for input tokens."""
 
-        self.models: Models = Models(self.api_key, self.url)
+        self.models: Models = Models(self.api_key, self.url, self.timeout)
         """Models lists all of the models available in the Prediction Guard API."""
 
     def _connect_client(self) -> None:
@@ -112,7 +133,7 @@ def _connect_client(self) -> None:
         }
 
         # Try listing models to make sure we can connect.
-        response = requests.request("GET", self.url + "/completions", headers=headers)
+        response = requests.request("GET", self.url + "/completions", headers=headers, timeout=self.timeout)
 
         # If the connection was unsuccessful, raise an exception.
         if response.status_code == 200:
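Not part of the patch: a minimal usage sketch of the constructor change above, assuming the package's public import path (from predictionguard import PredictionGuard); the API key value is a placeholder.

    import os

    from predictionguard import PredictionGuard

    # Explicit per-client timeout in seconds; an int, a float, or a numeric
    # string all pass the float() coercion added in __init__ above.
    client = PredictionGuard(api_key="<your-api-key>", timeout=10)

    # With the argument omitted, the TIMEOUT environment variable is used;
    # if neither is set, timeout stays None and requests wait indefinitely.
    os.environ["TIMEOUT"] = "30"
    client = PredictionGuard(api_key="<your-api-key>")

    # A non-numeric value fails fast, before _connect_client() makes any request.
    try:
        PredictionGuard(api_key="<your-api-key>", timeout="ten seconds")
    except ValueError as err:
        print(err)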
diff --git a/predictionguard/src/audio.py b/predictionguard/src/audio.py
index b5b57b9..c0cbce0 100644
--- a/predictionguard/src/audio.py
+++ b/predictionguard/src/audio.py
@@ -39,16 +39,18 @@ class Audio:
     ))
     """
 
-    def __init__(self, api_key, url):
+    def __init__(self, api_key, url, timeout):
         self.api_key = api_key
         self.url = url
+        self.timeout = timeout
 
-        self.transcriptions: AudioTranscriptions = AudioTranscriptions(self.api_key, self.url)
+        self.transcriptions: AudioTranscriptions = AudioTranscriptions(self.api_key, self.url, self.timeout)
 
 
 class AudioTranscriptions:
-    def __init__(self, api_key, url):
+    def __init__(self, api_key, url, timeout):
         self.api_key = api_key
         self.url = url
+        self.timeout = timeout
 
     def create(
         self,
@@ -164,7 +166,7 @@ def _transcribe_audio(
         }
 
         response = requests.request(
-            "POST", self.url + "/audio/transcriptions", headers=headers, files=files, data=data
+            "POST", self.url + "/audio/transcriptions", headers=headers, files=files, data=data, timeout=self.timeout
         )
 
         # If the request was successful, print the proxies.
diff --git a/predictionguard/src/chat.py b/predictionguard/src/chat.py
index 1f1bd4a..3f4c3ce 100644
--- a/predictionguard/src/chat.py
+++ b/predictionguard/src/chat.py
@@ -66,17 +66,19 @@ class Chat:
     ))
     """
 
-    def __init__(self, api_key, url):
+    def __init__(self, api_key, url, timeout):
         self.api_key = api_key
         self.url = url
+        self.timeout = timeout
 
-        self.completions: ChatCompletions = ChatCompletions(self.api_key, self.url)
+        self.completions: ChatCompletions = ChatCompletions(self.api_key, self.url, self.timeout)
 
 
 class ChatCompletions:
-    def __init__(self, api_key, url):
+    def __init__(self, api_key, url, timeout):
         self.api_key = api_key
         self.url = url
+        self.timeout = timeout
 
     def create(
         self,
@@ -192,9 +194,9 @@ def _generate_chat(
         Function to generate a single chat response.
         """
 
-        def return_dict(url, headers, payload):
+        def return_dict(url, headers, payload, timeout):
             response = requests.request(
-                "POST", url + "/chat/completions", headers=headers, data=payload
+                "POST", url + "/chat/completions", headers=headers, data=payload, timeout=timeout
             )
             # If the request was successful, print the proxies.
             if response.status_code == 200:
@@ -215,12 +217,13 @@ def return_dict(url, headers, payload):
                 pass
             raise ValueError("Could not make prediction. " + err)
 
-        def stream_generator(url, headers, payload, stream):
+        def stream_generator(url, headers, payload, stream, timeout):
             with requests.post(
                 url + "/chat/completions",
                 headers=headers,
                 data=payload,
                 stream=stream,
+                timeout=timeout,
             ) as response:
                 response.raise_for_status()
@@ -356,10 +359,10 @@ def stream_generator(url, headers, payload, stream):
         payload = json.dumps(payload_dict)
 
         if stream:
-            return stream_generator(self.url, headers, payload, stream)
+            return stream_generator(self.url, headers, payload, stream, self.timeout)
 
         else:
-            return return_dict(self.url, headers, payload)
+            return return_dict(self.url, headers, payload, self.timeout)
 
     def list_models(self, capability: Optional[str] = "chat-completion") -> List[str]:
         # Get the list of current models.
@@ -376,7 +379,7 @@ def list_models(self, capability: Optional[str] = "chat-completion") -> List[str]:
         else:
             model_path = "/models/" + capability
 
-        response = requests.request("GET", self.url + model_path, headers=headers)
+        response = requests.request("GET", self.url + model_path, headers=headers, timeout=self.timeout)
 
         response_list = []
         for model in response.json()["data"]:
diff --git a/predictionguard/src/completions.py b/predictionguard/src/completions.py
index 54b10c8..83e4ddf 100644
--- a/predictionguard/src/completions.py
+++ b/predictionguard/src/completions.py
@@ -41,9 +41,10 @@ class Completions:
     ))
     """
 
-    def __init__(self, api_key, url):
+    def __init__(self, api_key, url, timeout):
         self.api_key = api_key
         self.url = url
+        self.timeout = timeout
 
     def create(
         self,
@@ -132,9 +133,9 @@ def _generate_completion(
         Function to generate a single completion.
         """
 
-        def return_dict(url, headers, payload):
+        def return_dict(url, headers, payload, timeout):
             response = requests.request(
-                "POST", url + "/completions", headers=headers, data=payload
+                "POST", url + "/completions", headers=headers, data=payload, timeout=timeout
             )
             # If the request was successful, print the proxies.
             if response.status_code == 200:
@@ -155,12 +156,13 @@ def return_dict(url, headers, payload):
                 pass
             raise ValueError("Could not make prediction. " + err)
 
-        def stream_generator(url, headers, payload, stream):
+        def stream_generator(url, headers, payload, stream, timeout):
             with requests.post(
                 url + "/completions",
                 headers=headers,
                 data=payload,
                 stream=stream,
+                timeout=timeout
             ) as response:
                 response.raise_for_status()
@@ -215,10 +217,10 @@ def stream_generator(url, headers, payload, stream):
         payload = json.dumps(payload_dict)
 
         if stream:
-            return stream_generator(self.url, headers, payload, stream)
+            return stream_generator(self.url, headers, payload, stream, self.timeout)
 
         else:
-            return return_dict(self.url, headers, payload)
+            return return_dict(self.url, headers, payload, self.timeout)
 
     def list_models(self) -> List[str]:
         # Get the list of current models.
@@ -228,7 +230,7 @@ def list_models(self) -> List[str]:
             "User-Agent": "Prediction Guard Python Client: " + __version__,
         }
 
-        response = requests.request("GET", self.url + "/models/completion", headers=headers)
+        response = requests.request("GET", self.url + "/models/completion", headers=headers, timeout=self.timeout)
 
         response_list = []
         for model in response.json()["data"]:
diff --git a/predictionguard/src/detokenize.py b/predictionguard/src/detokenize.py
index 37d7ab7..bb5cc5a 100644
--- a/predictionguard/src/detokenize.py
+++ b/predictionguard/src/detokenize.py
@@ -41,9 +41,10 @@ class Detokenize:
     """
 
-    def __init__(self, api_key, url):
+    def __init__(self, api_key, url, timeout):
         self.api_key = api_key
         self.url = url
+        self.timeout = timeout
 
     def create(self, model: str, tokens: List[int]) -> Dict[str, Any]:
         """
@@ -85,7 +86,7 @@ def _create_tokens(self, model, tokens):
         payload = json.dumps(payload)
 
         response = requests.request(
-            "POST", self.url + "/detokenize", headers=headers, data=payload
+            "POST", self.url + "/detokenize", headers=headers, data=payload, timeout=self.timeout
         )
 
         if response.status_code == 200:
@@ -114,7 +115,7 @@ def list_models(self):
             "User-Agent": "Prediction Guard Python Client: " + __version__
         }
 
-        response = requests.request("GET", self.url + "/models/detokenize", headers=headers)
+        response = requests.request("GET", self.url + "/models/detokenize", headers=headers, timeout=self.timeout)
 
         response_list = []
         for model in response.json()["data"]:
diff --git a/predictionguard/src/documents.py b/predictionguard/src/documents.py
index 8231c82..36acef3 100644
--- a/predictionguard/src/documents.py
+++ b/predictionguard/src/documents.py
@@ -37,16 +37,18 @@ class Documents:
     ))
     """
 
-    def __init__(self, api_key, url):
+    def __init__(self, api_key, url, timeout):
         self.api_key = api_key
         self.url = url
+        self.timeout = timeout
 
-        self.extract: DocumentsExtract = DocumentsExtract(self.api_key, self.url)
+        self.extract: DocumentsExtract = DocumentsExtract(self.api_key, self.url, self.timeout)
 
 
 class DocumentsExtract:
-    def __init__(self, api_key, url):
+    def __init__(self, api_key, url, timeout):
         self.api_key = api_key
         self.url = url
+        self.timeout = timeout
 
     def create(
         self,
@@ -117,7 +119,7 @@ def _extract_documents(
 
         response = requests.request(
             "POST", self.url + "/documents/extract",
-            headers=headers, files=files, data=data
+            headers=headers, files=files, data=data, timeout=self.timeout
         )
 
         # If the request was successful, print the proxies.
diff --git a/predictionguard/src/embeddings.py b/predictionguard/src/embeddings.py
index 789f3aa..97375a6 100644
--- a/predictionguard/src/embeddings.py
+++ b/predictionguard/src/embeddings.py
@@ -46,9 +46,10 @@ class Embeddings:
     ))
     """
 
-    def __init__(self, api_key, url):
+    def __init__(self, api_key, url, timeout):
         self.api_key = api_key
         self.url = url
+        self.timeout = timeout
 
     def create(
         self,
@@ -166,7 +167,7 @@ def _generate_embeddings(self, model, input, truncate, truncation_direction):
         payload = json.dumps(payload_dict)
 
         response = requests.request(
-            "POST", self.url + "/embeddings", headers=headers, data=payload
+            "POST", self.url + "/embeddings", headers=headers, data=payload, timeout=self.timeout
         )
 
         # If the request was successful, print the proxies.
@@ -204,7 +205,7 @@ def list_models(self, capability: Optional[str] = "embedding") -> List[str]:
         else:
             model_path = "/models/" + capability
 
-        response = requests.request("GET", self.url + model_path, headers=headers)
+        response = requests.request("GET", self.url + model_path, headers=headers, timeout=self.timeout)
 
         response_list = []
         for model in response.json()["data"]:
diff --git a/predictionguard/src/factuality.py b/predictionguard/src/factuality.py
index bc84bde..3fc6f33 100644
--- a/predictionguard/src/factuality.py
+++ b/predictionguard/src/factuality.py
@@ -41,9 +41,10 @@ class Factuality:
     ))
     """
 
-    def __init__(self, api_key, url):
+    def __init__(self, api_key, url, timeout):
         self.api_key = api_key
         self.url = url
+        self.timeout = timeout
 
     def check(self, reference: str, text: str) -> Dict[str, Any]:
         """
@@ -72,7 +73,7 @@ def _generate_score(self, reference, text):
         payload_dict = {"reference": reference, "text": text}
         payload = json.dumps(payload_dict)
         response = requests.request(
-            "POST", self.url + "/factuality", headers=headers, data=payload
+            "POST", self.url + "/factuality", headers=headers, data=payload, timeout=self.timeout
         )
 
         # If the request was successful, print the proxies.
diff --git a/predictionguard/src/injection.py b/predictionguard/src/injection.py
index 549b281..ca368f2 100644
--- a/predictionguard/src/injection.py
+++ b/predictionguard/src/injection.py
@@ -40,9 +40,10 @@ class Injection:
     ))
     """
 
-    def __init__(self, api_key, url):
+    def __init__(self, api_key, url, timeout):
         self.api_key = api_key
         self.url = url
+        self.timeout = timeout
 
     def check(self, prompt: str, detect: Optional[bool] = False) -> Dict[str, Any]:
         """
@@ -73,7 +74,7 @@ def _check_injection(self, prompt, detect):
         payload = json.dumps(payload)
 
         response = requests.request(
-            "POST", self.url + "/injection", headers=headers, data=payload
+            "POST", self.url + "/injection", headers=headers, data=payload, timeout=self.timeout
         )
 
         if response.status_code == 200:
diff --git a/predictionguard/src/models.py b/predictionguard/src/models.py
index 3e408dc..7793a52 100644
--- a/predictionguard/src/models.py
+++ b/predictionguard/src/models.py
@@ -35,9 +35,10 @@ class Models:
     ))
     """
 
-    def __init__(self, api_key, url):
+    def __init__(self, api_key, url, timeout):
         self.api_key = api_key
         self.url = url
+        self.timeout = timeout
 
     def list(self, capability: Optional[str] = "") -> Dict[str, Any]:
         """
@@ -78,7 +79,7 @@ def _list_models(self, capability):
             models_path += "/" + capability
 
         response = requests.request(
-            "GET", self.url + models_path, headers=headers
+            "GET", self.url + models_path, headers=headers, timeout=self.timeout
         )
 
         if response.status_code == 200:
diff --git a/predictionguard/src/pii.py b/predictionguard/src/pii.py
index d7efce4..f7b46cf 100644
--- a/predictionguard/src/pii.py
+++ b/predictionguard/src/pii.py
@@ -41,9 +41,10 @@ class Pii:
     ))
     """
 
-    def __init__(self, api_key, url):
+    def __init__(self, api_key, url, timeout):
         self.api_key = api_key
         self.url = url
+        self.timeout = timeout
 
     def check(
         self, prompt: str, replace: bool, replace_method: Optional[str] = "random"
@@ -76,7 +77,7 @@ def _check_pii(self, prompt, replace, replace_method):
         payload = json.dumps(payload_dict)
 
         response = requests.request(
-            "POST", self.url + "/PII", headers=headers, data=payload
+            "POST", self.url + "/PII", headers=headers, data=payload, timeout=self.timeout
         )
 
         if response.status_code == 200:
diff --git a/predictionguard/src/rerank.py b/predictionguard/src/rerank.py
index f8187fc..8a5b1f2 100644
--- a/predictionguard/src/rerank.py
+++ b/predictionguard/src/rerank.py
@@ -46,9 +46,10 @@ class Rerank:
     """
 
-    def __init__(self, api_key, url):
+    def __init__(self, api_key, url, timeout):
         self.api_key = api_key
         self.url = url
+        self.timeout = timeout
 
     def create(
         self,
@@ -92,7 +93,7 @@ def _create_rerank(self, model, query, documents, return_documents):
         payload = json.dumps(payload)
 
         response = requests.request(
-            "POST", self.url + "/rerank", headers=headers, data=payload
+            "POST", self.url + "/rerank", headers=headers, data=payload, timeout=self.timeout
         )
 
         if response.status_code == 200:
@@ -121,7 +122,7 @@ def list_models(self):
             "User-Agent": "Prediction Guard Python Client: " + __version__
         }
 
-        response = requests.request("GET", self.url + "/models/rerank", headers=headers)
+        response = requests.request("GET", self.url + "/models/rerank", headers=headers, timeout=self.timeout)
 
         response_list = []
         for model in response.json()["data"]:
diff --git a/predictionguard/src/tokenize.py b/predictionguard/src/tokenize.py
index f69ea9f..4f23b68 100644
--- a/predictionguard/src/tokenize.py
+++ b/predictionguard/src/tokenize.py
@@ -41,9 +41,10 @@ class Tokenize:
     """
 
-    def __init__(self, api_key, url):
+    def __init__(self, api_key, url, timeout):
         self.api_key = api_key
         self.url = url
+        self.timeout = timeout
 
     def create(self, model: str, input: str) -> Dict[str, Any]:
         """
@@ -80,7 +81,7 @@ def _create_tokens(self, model, input):
         payload = json.dumps(payload)
 
         response = requests.request(
-            "POST", self.url + "/tokenize", headers=headers, data=payload
+            "POST", self.url + "/tokenize", headers=headers, data=payload, timeout=self.timeout
         )
 
         if response.status_code == 200:
@@ -109,7 +110,7 @@ def list_models(self):
             "User-Agent": "Prediction Guard Python Client: " + __version__
         }
 
-        response = requests.request("GET", self.url + "/models/tokenize", headers=headers)
+        response = requests.request("GET", self.url + "/models/tokenize", headers=headers, timeout=self.timeout)
 
         response_list = []
         for model in response.json()["data"]:
diff --git a/predictionguard/src/toxicity.py b/predictionguard/src/toxicity.py
index 76df010..66f5ea8 100644
--- a/predictionguard/src/toxicity.py
+++ b/predictionguard/src/toxicity.py
@@ -38,9 +38,10 @@ class Toxicity:
     ))
     """
 
-    def __init__(self, api_key, url):
+    def __init__(self, api_key, url, timeout):
         self.api_key = api_key
         self.url = url
+        self.timeout = timeout
 
     def check(self, text: str) -> Dict[str, Any]:
         """
@@ -68,7 +69,7 @@ def _generate_score(self, text):
         payload_dict = {"text": text}
         payload = json.dumps(payload_dict)
         response = requests.request(
-            "POST", self.url + "/toxicity", headers=headers, data=payload
+            "POST", self.url + "/toxicity", headers=headers, data=payload, timeout=self.timeout
         )
 
         # If the request was successful, print the proxies.
diff --git a/predictionguard/src/translate.py b/predictionguard/src/translate.py
index bd10287..9e6ad4b 100644
--- a/predictionguard/src/translate.py
+++ b/predictionguard/src/translate.py
@@ -4,9 +4,10 @@
 class Translate:
     """No longer supported."""
 
-    def __init__(self, api_key, url):
+    def __init__(self, api_key, url, timeout):
         self.api_key = api_key
         self.url = url
+        self.timeout = timeout
 
     def create(
         self,
diff --git a/predictionguard/version.py b/predictionguard/version.py
index 7f01fe5..48da1d9 100644
--- a/predictionguard/version.py
+++ b/predictionguard/version.py
@@ -1,2 +1,2 @@
 # Setting the package version
-__version__ = "2.9.0"
+__version__ = "2.9.1"
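Also not part of the patch: a sketch of how the threaded-through timeout surfaces at call time. The chat call shape (model plus an OpenAI-style messages list) and the model name are assumptions rather than something this diff shows. With the default timeout=None nothing changes for existing callers; once a timeout is set, a slow endpoint raises requests.exceptions.Timeout, which none of the patched wrappers catch, so it propagates to the caller.

    import requests

    from predictionguard import PredictionGuard

    client = PredictionGuard(api_key="<your-api-key>", timeout=5)

    try:
        result = client.chat.completions.create(
            # Hypothetical model name; pick one from client.chat.completions.list_models().
            model="<chat-model-name>",
            messages=[{"role": "user", "content": "Hello"}],
        )
    except requests.exceptions.Timeout:
        # Raised by requests once the 5-second budget is exceeded; the SDK
        # does not catch it, so callers can handle slow endpoints explicitly.
        print("Prediction Guard request timed out")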