(feat) completion - replicate custom deployments

This commit is contained in:
ishaan-jaff 2023-11-09 18:03:07 -08:00
parent 62e83ce75a
commit d0c9bfd14a

View file

@@ -74,8 +74,14 @@ class ReplicateConfig():
 # Function to start a prediction and get the prediction URL
-def start_prediction(version_id, input_data, api_token, api_base, logging_obj):
+def start_prediction(version_id, input_data, api_token, api_base, logging_obj, print_verbose):
     base_url = api_base
+    if "deployments" in version_id:
+        print_verbose("\nLiteLLM: Request to custom replicate deployment")
+        version_id = version_id.replace("deployments/", "")
+        base_url = f"https://api.replicate.com/v1/deployments/{version_id}"
+        print_verbose(f"Deployment base URL: {base_url}\n")
     headers = {
         "Authorization": f"Token {api_token}",
         "Content-Type": "application/json"
@@ -209,7 +215,7 @@ def completion(
     ## Step2: Poll prediction url for response
     ## Step2: is handled with and without streaming
     model_response["created"] = time.time()  # for pricing this must remain right before calling api
-    prediction_url = start_prediction(version_id, input_data, api_key, api_base, logging_obj=logging_obj)
+    prediction_url = start_prediction(version_id, input_data, api_key, api_base, logging_obj=logging_obj, print_verbose=print_verbose)
     print_verbose(prediction_url)
     # Handle the prediction response (streaming or non-streaming)