# Mirror of https://github.com/BerriAI/litellm.git (synced 2025-04-25)
import os
import json
import requests
import time
from typing import Callable

from litellm.utils import ModelResponse


class ReplicateError(Exception):
    def __init__(self, status_code, message):
        self.status_code = status_code
        self.message = message
        super().__init__(
            self.message
        )  # Call the base class constructor with the parameters it needs


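# Illustrative note (not part of the upstream module): ReplicateError carries the
# HTTP status code and response body of a failed Replicate API call, so callers can
# catch it around the helpers below, e.g.
#
#     try:
#         url = start_prediction(version_id, input_data, api_token, logging_obj)
#     except ReplicateError as e:
#         print(e.status_code, e.message)
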
# Function to start a prediction and get the prediction URL
def start_prediction(version_id, input_data, api_token, logging_obj):
    base_url = "https://api.replicate.com/v1"
    headers = {
        "Authorization": f"Token {api_token}",
        "Content-Type": "application/json"
    }

    initial_prediction_data = {
        "version": version_id,
        "input": input_data,
        "max_new_tokens": 500,
    }

    ## LOGGING
    logging_obj.pre_call(
        input=input_data["prompt"],
        api_key="",
        additional_args={"complete_input_dict": initial_prediction_data, "headers": headers},
    )

    response = requests.post(f"{base_url}/predictions", json=initial_prediction_data, headers=headers)
    if response.status_code == 201:
        response_data = response.json()
        return response_data.get("urls", {}).get("get")
    else:
        raise ReplicateError(response.status_code, message=response.text)


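# Illustrative sketch (not part of the upstream module): start_prediction returns the
# polling URL taken from the "urls.get" field of the 201 response, roughly of the form
# "https://api.replicate.com/v1/predictions/<prediction-id>" (shape assumed from the
# Replicate API, not guaranteed here), e.g.
#
#     prediction_url = start_prediction(version_id, {"prompt": "hi"}, api_token, logging_obj)
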
# Function to handle prediction response (non-streaming)
def handle_prediction_response(prediction_url, api_token, print_verbose):
    output_string = ""
    headers = {
        "Authorization": f"Token {api_token}",
        "Content-Type": "application/json"
    }

    status = ""
    # Poll the prediction URL until it reaches a terminal state
    while status not in ["succeeded", "failed", "canceled"]:
        print_verbose("making request")
        time.sleep(0.0001)
        response = requests.get(prediction_url, headers=headers)
        if response.status_code == 200:
            response_data = response.json()
            if "output" in response_data:
                output_string = "".join(response_data['output'])
                print_verbose(f"Non-streamed output: {output_string}")
            status = response_data['status']
        else:
            print_verbose("Failed to fetch prediction status and output.")
    return output_string


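# Note (illustrative, not an exhaustive schema): the poller above only reads two fields
# of the prediction payload, "status" (terminal values: "succeeded"/"failed"/"canceled")
# and "output" (an iterable of text chunks that gets joined). A minimal payload it can
# handle looks roughly like:
#
#     {"status": "succeeded", "output": ["Hello", " world"]}
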
# Function to handle prediction response (streaming)
def handle_prediction_response_streaming(prediction_url, api_token, print_verbose):
    previous_output = ""
    output_string = ""

    headers = {
        "Authorization": f"Token {api_token}",
        "Content-Type": "application/json"
    }
    status = ""
    # Poll the prediction URL and yield only the text generated since the last poll
    while status not in ["succeeded", "failed", "canceled"]:
        time.sleep(0.0001)
        response = requests.get(prediction_url, headers=headers)
        if response.status_code == 200:
            response_data = response.json()
            if "output" in response_data:
                output_string = "".join(response_data['output'])
                new_output = output_string[len(previous_output):]
                yield new_output
                previous_output = output_string
            status = response_data['status']


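# Illustrative usage (not part of the upstream module): the streaming handler is a
# generator, so callers iterate over the incremental text it yields, e.g.
#
#     for chunk in handle_prediction_response_streaming(prediction_url, api_token, print):
#         print(chunk, end="")
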
# Function to extract version ID from model string
def model_to_version_id(model):
    if ":" in model:
        split_model = model.split(":")
        return split_model[1]
    return model


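# Illustrative examples of the mapping above (placeholder names, not real models):
#
#     model_to_version_id("owner/model-name:<version-hash>")  # -> "<version-hash>"
#     model_to_version_id("model-without-version")            # -> "model-without-version"
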
# Main function for prediction completion
def completion(
    model: str,
    messages: list,
    model_response: ModelResponse,
    print_verbose: Callable,
    logging_obj,
    api_key,
    encoding,
    optional_params=None,
    litellm_params=None,
    logger_fn=None,
):
    # Convert messages to prompt
    prompt = ""
    for message in messages:
        prompt += message["content"]

    # Start a prediction and get the prediction URL
    version_id = model_to_version_id(model)
    input_data = {
        "prompt": prompt,
        **optional_params
    }

    ## COMPLETION CALL
    ## Replicate Completion calls have 2 steps
    ## Step 1: Start Prediction: gets a prediction url
    ## Step 2: Poll prediction url for response
    ## Step 2 is handled with and without streaming
    prediction_url = start_prediction(version_id, input_data, api_key, logging_obj=logging_obj)
    print_verbose(prediction_url)

    # Handle the prediction response (streaming or non-streaming)
    if "stream" in optional_params and optional_params["stream"] == True:
        print_verbose("streaming request")
        return handle_prediction_response_streaming(prediction_url, api_key, print_verbose)
    else:
        result = handle_prediction_response(prediction_url, api_key, print_verbose)

        ## LOGGING
        logging_obj.post_call(
            input=prompt,
            api_key="",
            original_response=result,
            additional_args={"complete_input_dict": input_data},
        )

        print_verbose(f"raw model_response: {result}")

        if len(result) == 0:  # edge case, where result from replicate is empty
            result = " "

        ## Building RESPONSE OBJECT
        model_response["choices"][0]["message"]["content"] = result

        # Calculate usage
        prompt_tokens = len(encoding.encode(prompt))
        completion_tokens = len(encoding.encode(model_response["choices"][0]["message"]["content"]))
        model_response["created"] = time.time()
        model_response["model"] = model
        model_response["usage"] = {
            "prompt_tokens": prompt_tokens,
            "completion_tokens": completion_tokens,
            "total_tokens": prompt_tokens + completion_tokens,
        }
        return model_response


# # Example usage (non-streaming). Note: `logging_obj` must expose pre_call()/post_call()
# # (print is only a stand-in here), and this non-streaming path also needs an `encoding`
# # object with an .encode() method for token counting.
# response = completion(
#     api_key="",
#     messages=[{"content": "good morning"}],
#     model="replicate/llama-2-70b-chat:2796ee9483c3fd7aa2e171d38f4ca12251a30609463dcfd4cd76703f22e96cdf",
#     model_response=ModelResponse(),
#     print_verbose=print,
#     logging_obj=print,  # stub logging_obj
#     optional_params={"stream": False}
# )

# print(response)
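
# # A streaming variant of the example above (illustrative sketch; stub arguments assumed).
# # With optional_params={"stream": True}, completion() returns a generator of text chunks;
# # `encoding` is never used on this path, so a placeholder is enough, and a real
# # logging_obj still needs pre_call()/post_call() methods.
# stream = completion(
#     api_key="",
#     messages=[{"content": "good morning"}],
#     model="replicate/llama-2-70b-chat:2796ee9483c3fd7aa2e171d38f4ca12251a30609463dcfd4cd76703f22e96cdf",
#     model_response=ModelResponse(),
#     print_verbose=print,
#     logging_obj=print,  # stub logging_obj
#     encoding=None,  # unused on the streaming path
#     optional_params={"stream": True},
# )

# for chunk in stream:
#     print(chunk, end="")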