From d0c9bfd14a319cc3e670a635225884f5ec10edd2 Mon Sep 17 00:00:00 2001
From: ishaan-jaff
Date: Thu, 9 Nov 2023 18:03:07 -0800
Subject: [PATCH] (feat) completion - replicate custom deployments

---
 litellm/llms/replicate.py | 12 +++++++++---
 1 file changed, 9 insertions(+), 3 deletions(-)

diff --git a/litellm/llms/replicate.py b/litellm/llms/replicate.py
index afa56d978..fc5293eb2 100644
--- a/litellm/llms/replicate.py
+++ b/litellm/llms/replicate.py
@@ -74,8 +74,14 @@ class ReplicateConfig():
 
 
 # Function to start a prediction and get the prediction URL
-def start_prediction(version_id, input_data, api_token, api_base, logging_obj):
+def start_prediction(version_id, input_data, api_token, api_base, logging_obj, print_verbose):
     base_url = api_base
+    if "deployments" in version_id:
+        print_verbose("\nLiteLLM: Request to custom replicate deployment")
+        version_id = version_id.replace("deployments/", "")
+        base_url = f"https://api.replicate.com/v1/deployments/{version_id}"
+        print_verbose(f"Deployment base URL: {base_url}\n")
+
     headers = {
         "Authorization": f"Token {api_token}",
         "Content-Type": "application/json"
@@ -86,7 +92,7 @@ def start_prediction(version_id, input_data, api_token, api_base, logging_obj):
         "input": input_data,
     }
 
-    ## LOGGING 
+    ## LOGGING
     logging_obj.pre_call(
         input=input_data["prompt"],
         api_key="",
@@ -209,7 +215,7 @@ def completion(
     ## Step2: Poll prediction url for response
     ## Step2: is handled with and without streaming
     model_response["created"] = time.time() # for pricing this must remain right before calling api
-    prediction_url = start_prediction(version_id, input_data, api_key, api_base, logging_obj=logging_obj)
+    prediction_url = start_prediction(version_id, input_data, api_key, api_base, logging_obj=logging_obj, print_verbose=print_verbose)
     print_verbose(prediction_url)
 
     # Handle the prediction response (streaming or non-streaming)
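
Below the patch, a minimal usage sketch (not part of the commit) of the path this change enables. It assumes litellm's public completion() entrypoint, a REPLICATE_API_KEY in the environment, and an illustrative deployment name; with this patch, a model string containing "deployments/" makes start_prediction target the deployment-specific URL https://api.replicate.com/v1/deployments/{owner}/{name} instead of the default predictions endpoint.

# Minimal usage sketch, assuming litellm.completion() routes
# "replicate/..." models to this code path. The deployment name
# "my-org/my-mistral" is an illustrative assumption, not from the
# patch. Requires REPLICATE_API_KEY to be set in the environment.
import litellm

response = litellm.completion(
    model="replicate/deployments/my-org/my-mistral",
    messages=[{"role": "user", "content": "Hello from a custom deployment!"}],
)
print(response)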