add Replicate Error class

This commit is contained in:
ishaan-jaff 2023-09-06 10:25:40 -07:00
parent 1c61b7b229
commit bc9b629726

View file

@ -4,7 +4,14 @@ import requests
import time
from typing import Callable
from litellm.utils import ModelResponse
import tiktoken
class ReplicateError(Exception):
def __init__(self, status_code, message):
self.status_code = status_code
self.message = message
super().__init__(
self.message
) # Call the base class constructor with the parameters it needs
# Function to start a prediction and get the prediction URL
def start_prediction(version_id, input_data, api_token):
@ -25,7 +32,7 @@ def start_prediction(version_id, input_data, api_token):
response_data = response.json()
return response_data.get("urls", {}).get("get")
else:
raise ValueError(response.status_code, "Failed to start prediction.")
raise ReplicateError(response.status_code, "Failed to start prediction.")
# Function to handle prediction response (non-streaming)
def handle_prediction_response(prediction_url, api_token, print_verbose):
@ -87,7 +94,7 @@ def completion(
print_verbose: Callable,
logging_obj,
api_key,
encoding=tiktoken.get_encoding("cl100k_base"),
encoding,
optional_params=None,
litellm_params=None,
logger_fn=None,