forked from phoenix/litellm-mirror
(feat) completion - replicate custom deployments
parent 62e83ce75a
commit d0c9bfd14a

1 changed file with 9 additions and 3 deletions
@@ -74,8 +74,14 @@ class ReplicateConfig():
 
 # Function to start a prediction and get the prediction URL
-def start_prediction(version_id, input_data, api_token, api_base, logging_obj):
+def start_prediction(version_id, input_data, api_token, api_base, logging_obj, print_verbose):
     base_url = api_base
+
+    if "deployments" in version_id:
+        print_verbose("\nLiteLLM: Request to custom replicate deployment")
+        version_id = version_id.replace("deployments/", "")
+        base_url = f"https://api.replicate.com/v1/deployments/{version_id}"
+        print_verbose(f"Deployment base URL: {base_url}\n")
 
     headers = {
         "Authorization": f"Token {api_token}",
         "Content-Type": "application/json"
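For reference, the new branch turns a deployments-style model id into Replicate's per-deployment endpoint. A standalone sketch of that rewriting, using a hypothetical deployment name:

version_id = "deployments/my-org/my-model"  # hypothetical model id
base_url = "https://api.replicate.com/v1"   # stand-in for the api_base default
if "deployments" in version_id:
    version_id = version_id.replace("deployments/", "")
    base_url = f"https://api.replicate.com/v1/deployments/{version_id}"
print(base_url)  # https://api.replicate.com/v1/deployments/my-org/my-model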
@@ -86,7 +92,7 @@ def start_prediction(version_id, input_data, api_token, api_base, logging_obj):
         "input": input_data,
     }
 
     ## LOGGING
     logging_obj.pre_call(
         input=input_data["prompt"],
         api_key="",
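Putting the two hunks together, the request start_prediction issues for a custom deployment should look roughly like the sketch below. The deployment name, the prompt, and the exact payload fields around the visible "input" key are assumptions based on Replicate's predictions API, not taken from this diff:

import requests

api_token = "r8_..."  # hypothetical Replicate API token
base_url = "https://api.replicate.com/v1/deployments/my-org/my-model"  # assumed deployment URL
headers = {
    "Authorization": f"Token {api_token}",
    "Content-Type": "application/json",
}
payload = {"input": {"prompt": "Hello from a custom deployment"}}
# Replicate's deployments API creates predictions under <deployment URL>/predictions.
response = requests.post(f"{base_url}/predictions", json=payload, headers=headers)
prediction_url = response.json()["urls"]["get"]  # URL to poll for the prediction's status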
@@ -209,7 +215,7 @@ def completion(
     ## Step2: Poll prediction url for response
     ## Step2: is handled with and without streaming
     model_response["created"] = time.time()  # for pricing this must remain right before calling api
-    prediction_url = start_prediction(version_id, input_data, api_key, api_base, logging_obj=logging_obj)
+    prediction_url = start_prediction(version_id, input_data, api_key, api_base, logging_obj=logging_obj, print_verbose=print_verbose)
     print_verbose(prediction_url)
 
     # Handle the prediction response (streaming or non-streaming)
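With both changes in place, a custom deployment can be exercised end to end. A minimal usage sketch, assuming the replicate/deployments/<owner>/<name> model-string convention implied by the version_id check above and a hypothetical deployment:

import litellm

response = litellm.completion(
    model="replicate/deployments/my-org/my-model",  # hypothetical custom deployment
    messages=[{"role": "user", "content": "Hello"}],
    api_key="r8_...",  # hypothetical Replicate API token
)
print(response)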