diff --git a/litellm/llms/replicate.py b/litellm/llms/replicate.py
index afa56d978..fc5293eb2 100644
--- a/litellm/llms/replicate.py
+++ b/litellm/llms/replicate.py
@@ -74,8 +74,14 @@ class ReplicateConfig():
 
 
 # Function to start a prediction and get the prediction URL
-def start_prediction(version_id, input_data, api_token, api_base, logging_obj):
+def start_prediction(version_id, input_data, api_token, api_base, logging_obj, print_verbose):
     base_url = api_base
+    if "deployments" in version_id:
+        print_verbose("\nLiteLLM: Request to custom replicate deployment")
+        version_id = version_id.replace("deployments/", "")
+        base_url = f"https://api.replicate.com/v1/deployments/{version_id}"
+        print_verbose(f"Deployment base URL: {base_url}\n")
+
     headers = {
         "Authorization": f"Token {api_token}",
         "Content-Type": "application/json"
@@ -86,7 +92,7 @@ def start_prediction(version_id, input_data, api_token, api_base, logging_obj):
         "input": input_data,
     }
 
-    ## LOGGING
+    ## LOGGING
     logging_obj.pre_call(
             input=input_data["prompt"],
             api_key="",
@@ -209,7 +215,7 @@ def completion(
     ## Step2: Poll prediction url for response
     ## Step2: is handled with and without streaming
     model_response["created"] = time.time() # for pricing this must remain right before calling api
-    prediction_url = start_prediction(version_id, input_data, api_key, api_base, logging_obj=logging_obj)
+    prediction_url = start_prediction(version_id, input_data, api_key, api_base, logging_obj=logging_obj, print_verbose=print_verbose)
     print_verbose(prediction_url)
 
     # Handle the prediction response (streaming or non-streaming)
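
For reference, a minimal standalone sketch of the routing this patch adds to `start_prediction` (the deployment name `my-org/my-deployment` is illustrative, not from the patch):

```python
# Sketch of the deployment routing added above.
# "my-org/my-deployment" is a hypothetical deployment name.
version_id = "deployments/my-org/my-deployment"
base_url = "https://api.replicate.com/v1"  # default Replicate API base

if "deployments" in version_id:
    # Strip the "deployments/" prefix and target the deployments endpoint
    version_id = version_id.replace("deployments/", "")
    base_url = f"https://api.replicate.com/v1/deployments/{version_id}"

print(base_url)
# -> https://api.replicate.com/v1/deployments/my-org/my-deployment
```

In practice, the caller's model string would presumably carry the `deployments/` prefix (e.g. `replicate/deployments/<owner>/<name>`) so that `version_id` arrives at `start_prediction` in this form.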