mirror of https://github.com/BerriAI/litellm.git
synced 2025-04-24 18:24:20 +00:00

use replicate http requests instead

This commit is contained in:
parent 3d6836417e
commit c45b132675

3 changed files with 188 additions and 80 deletions
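This commit drops the replicate SDK call path and talks to Replicate's REST API directly with requests: a POST to /v1/predictions starts a prediction, and the returned "get" URL is polled until the prediction reaches a terminal status. Below is a minimal sketch of that two-step flow, condensed from the new module in this commit; the version id is a placeholder and a valid REPLICATE_API_TOKEN is assumed.

import os
import time
import requests

# Sketch of the prediction flow implemented by litellm/llms/replicate.py below.
# "<model-version-id>" is a placeholder for a real Replicate model version id.
api_token = os.environ["REPLICATE_API_TOKEN"]
headers = {"Authorization": f"Token {api_token}", "Content-Type": "application/json"}

# 1) Start the prediction.
start = requests.post(
    "https://api.replicate.com/v1/predictions",
    json={"version": "<model-version-id>", "input": {"prompt": "good morning", "max_new_tokens": 50}},
    headers=headers,
)
prediction_url = start.json()["urls"]["get"]

# 2) Poll the prediction URL until it reaches a terminal status.
status = ""
output = []
while status not in ["succeeded", "failed", "canceled"]:
    time.sleep(0.5)
    data = requests.get(prediction_url, headers=headers).json()
    status = data["status"]
    output = data.get("output") or []  # "output" stays empty until tokens arrive

print("".join(output))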
litellm/llms/replicate.py  (new file, 142 lines)

@@ -0,0 +1,142 @@
import os
import json
import requests
import time
from typing import Callable
from litellm.utils import ModelResponse
import tiktoken

# Function to start a prediction and get the prediction URL
def start_prediction(version_id, input_data, api_token):
    base_url = "https://api.replicate.com/v1"
    headers = {
        "Authorization": f"Token {api_token}",
        "Content-Type": "application/json"
    }

    initial_prediction_data = {
        "version": version_id,
        "input": input_data,
        "max_new_tokens": 500,
    }

    response = requests.post(f"{base_url}/predictions", json=initial_prediction_data, headers=headers)
    if response.status_code == 201:
        response_data = response.json()
        return response_data.get("urls", {}).get("get")
    else:
        raise ValueError(response.status_code, "Failed to start prediction.")

# Function to handle prediction response (non-streaming)
def handle_prediction_response(prediction_url, api_token, print_verbose):
    output_string = ""
    headers = {
        "Authorization": f"Token {api_token}",
        "Content-Type": "application/json"
    }

    status = ""
    while True and (status not in ["succeeded", "failed", "canceled"]):
        print_verbose("making request")
        time.sleep(0.0001)
        response = requests.get(prediction_url, headers=headers)
        if response.status_code == 200:
            response_data = response.json()
            if "output" in response_data:
                output_string = "".join(response_data['output'])
                print_verbose(f"Non-streamed output:{output_string}")
            status = response_data['status']
        else:
            print_verbose("Failed to fetch prediction status and output.")
    return output_string

# Function to handle prediction response (streaming)
def handle_prediction_response_streaming(prediction_url, api_token, print_verbose):
    previous_output = ""
    output_string = ""

    headers = {
        "Authorization": f"Token {api_token}",
        "Content-Type": "application/json"
    }
    status = ""
    while True and (status not in ["succeeded", "failed", "canceled"]):
        time.sleep(0.0001)
        response = requests.get(prediction_url, headers=headers)
        if response.status_code == 200:
            response_data = response.json()
            if "output" in response_data:
                output_string = "".join(response_data['output'])
                new_output = output_string[len(previous_output):]
                yield new_output
                previous_output = output_string
            status = response_data['status']

# Function to extract version ID from model string
def model_to_version_id(model):
    if ":" in model:
        split_model = model.split(":")
        return split_model[1]
    return model

# Main function for prediction completion
def completion(
    model: str,
    messages: list,
    model_response: ModelResponse,
    print_verbose: Callable,
    logging_obj,
    api_key,
    encoding=tiktoken.get_encoding("cl100k_base"),
    optional_params=None,
    litellm_params=None,
    logger_fn=None,
):
    # Convert messages to prompt
    prompt = ""
    for message in messages:
        prompt += message["content"]

    # Start a prediction and get the prediction URL
    version_id = model_to_version_id(model)
    input_data = {
        "prompt": prompt,
        "max_new_tokens": 50,
    }

    prediction_url = start_prediction(version_id, input_data, api_key)
    print_verbose(prediction_url)

    # Handle the prediction response (streaming or non-streaming)
    if "stream" in optional_params and optional_params["stream"] == True:
        return handle_prediction_response_streaming(prediction_url, api_key, print_verbose)
    else:
        result = handle_prediction_response(prediction_url, api_key, print_verbose)
        model_response["choices"][0]["message"]["content"] = result

    # Calculate usage
    prompt_tokens = len(encoding.encode(prompt))
    completion_tokens = len(encoding.encode(model_response["choices"][0]["message"]["content"]))
    model_response["created"] = time.time()
    model_response["model"] = model
    model_response["usage"] = {
        "prompt_tokens": prompt_tokens,
        "completion_tokens": completion_tokens,
        "total_tokens": prompt_tokens + completion_tokens,
    }
    return model_response


# # Example usage:
# response = completion(
#     api_key="",
#     messages=[{"content": "good morning"}],
#     model="replicate/llama-2-70b-chat:2796ee9483c3fd7aa2e171d38f4ca12251a30609463dcfd4cd76703f22e96cdf",
#     model_response=ModelResponse(),
#     print_verbose=print,
#     logging_obj=print, # stub logging_obj
#     optional_params={"stream": False}
# )

# print(response)
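When stream=True is passed, completion() above returns the generator from handle_prediction_response_streaming directly, so each item a caller receives is just the text produced since the previous poll. A minimal usage sketch for that path, assuming the module above is importable and a valid REPLICATE_API_TOKEN is set; the model string is the one from the commented example above.

# Hypothetical streaming usage of the module above.
stream = completion(
    api_key=os.environ.get("REPLICATE_API_TOKEN", ""),
    messages=[{"content": "good morning"}],
    model="replicate/llama-2-70b-chat:2796ee9483c3fd7aa2e171d38f4ca12251a30609463dcfd4cd76703f22e96cdf",
    model_response=ModelResponse(),
    print_verbose=print,
    logging_obj=print,  # stub logging_obj, as in the commented example
    optional_params={"stream": True},
)
for delta in stream:
    # each delta is the newly generated text since the last poll
    print(delta, end="")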

@@ -24,6 +24,7 @@ from .llms import ai21
from .llms import sagemaker
from .llms import bedrock
from .llms import huggingface_restapi
from .llms import replicate
from .llms import aleph_alpha
from .llms import baseten
import tiktoken

@@ -341,10 +342,7 @@ def completion(
        response = model_response
    elif "replicate" in model or custom_llm_provider == "replicate":
        # import replicate/if it fails then pip install replicate
        try:
            import replicate
        except:
            Exception("Replicate import failed please run `pip install replicate`")

        # Setting the relevant API KEY for replicate, replicate defaults to using os.environ.get("REPLICATE_API_TOKEN")
        replicate_key = os.environ.get("REPLICATE_API_TOKEN")

@@ -358,56 +356,25 @@ def completion(
            )
        # set replicate key
        os.environ["REPLICATE_API_TOKEN"] = str(replicate_key)
        prompt = " ".join([message["content"] for message in messages])
        input = {
            "prompt": prompt
        }
        if "max_tokens" in optional_params:
            input["max_length"] = optional_params['max_tokens'] # for t5 models
            input["max_new_tokens"] = optional_params['max_tokens'] # for llama2 models
        ## LOGGING
        logging.pre_call(
            input=prompt,

        model_response = replicate.completion(
            model=model,
            messages=messages,
            model_response=model_response,
            print_verbose=print_verbose,
            optional_params=optional_params,
            litellm_params=litellm_params,
            logger_fn=logger_fn,
            encoding=encoding, # for calculating input/output tokens
            api_key=replicate_key,
            additional_args={
                "complete_input_dict": input,
                "max_tokens": max_tokens,
            },
            logging_obj=logging,
        )
        ## COMPLETION CALL
        output = replicate.run(model, input=input)
        if "stream" in optional_params and optional_params["stream"] == True:
            # don't try to access stream object,
            # let the stream handler know this is replicate
            response = CustomStreamWrapper(output, "replicate", logging_obj=logging)
            response = CustomStreamWrapper(model_response, model, logging_obj=logging)
            return response
        response = ""
        for item in output:
            response += item
        completion_response = response
        ## LOGGING
        logging.post_call(
            input=prompt,
            api_key=replicate_key,
            original_response=completion_response,
            additional_args={
                "complete_input_dict": input,
                "max_tokens": max_tokens,
            },
        )
        ## USAGE
        prompt_tokens = len(encoding.encode(prompt))
        completion_tokens = len(encoding.encode(completion_response))
        ## RESPONSE OBJECT
        model_response["choices"][0]["message"]["content"] = completion_response
        model_response["created"] = time.time()
        model_response["model"] = model
        model_response["usage"] = {
            "prompt_tokens": prompt_tokens,
            "completion_tokens": completion_tokens,
            "total_tokens": prompt_tokens + completion_tokens,
        }
        response = model_response

    elif model in litellm.anthropic_models:
        anthropic_key = (
            api_key or litellm.anthropic_key or os.environ.get("ANTHROPIC_API_KEY")

@@ -349,54 +349,53 @@ def test_completion_azure_deployment_id():
    except Exception as e:
        pytest.fail(f"Error occurred: {e}")

# Replicate API endpoints are unstable -> throw random CUDA errors -> this means our tests can fail even if our tests weren't incorrect.
def test_completion_replicate_llama_stream():
    model_name = "replicate/llama-2-70b-chat:2c1608e18606fad2812020dc541930f2d0495ce32eee50074220b87300bc16e1"
    try:
        response = completion(model=model_name, messages=messages, stream=True)
        # Add any assertions here to check the response
        for result in response:
            print(result)
        print(response)
    except Exception as e:
        pytest.fail(f"Error occurred: {e}")
# # Replicate API endpoints are unstable -> throw random CUDA errors -> this means our tests can fail even if our tests weren't incorrect.
# def test_completion_replicate_llama_stream():
#     model_name = "replicate/llama-2-70b-chat:2c1608e18606fad2812020dc541930f2d0495ce32eee50074220b87300bc16e1"
#     try:
#         response = completion(model=model_name, messages=messages, stream=True)
#         # Add any assertions here to check the response
#         for result in response:
#             print(result)
#         print(response)
#     except Exception as e:
#         pytest.fail(f"Error occurred: {e}")


def test_completion_replicate_stability_stream():
    model_name = "stability-ai/stablelm-tuned-alpha-7b:c49dae362cbaecd2ceabb5bd34fdb68413c4ff775111fea065d259d577757beb"
    try:
        response = completion(
            model=model_name,
            messages=messages,
            stream=True,
            custom_llm_provider="replicate",
        )
        # Add any assertions here to check the response
        for chunk in response:
            print(chunk["choices"][0]["delta"])
        print(response)
    except Exception as e:
        pytest.fail(f"Error occurred: {e}")
# def test_completion_replicate_stability_stream():
#     model_name = "stability-ai/stablelm-tuned-alpha-7b:c49dae362cbaecd2ceabb5bd34fdb68413c4ff775111fea065d259d577757beb"
#     try:
#         response = completion(
#             model=model_name,
#             messages=messages,
#             stream=True,
#             custom_llm_provider="replicate",
#         )
#         # Add any assertions here to check the response
#         for chunk in response:
#             print(chunk["choices"][0]["delta"])
#         print(response)
#     except Exception as e:
#         pytest.fail(f"Error occurred: {e}")


def test_completion_replicate_stability():
    model_name = "stability-ai/stablelm-tuned-alpha-7b:c49dae362cbaecd2ceabb5bd34fdb68413c4ff775111fea065d259d577757beb"
def test_completion_replicate_llama_2():
    model_name = "replicate/llama-2-70b-chat:2796ee9483c3fd7aa2e171d38f4ca12251a30609463dcfd4cd76703f22e96cdf"
    try:
        response = completion(
            model=model_name, messages=messages, custom_llm_provider="replicate"
        )
        print(response)
        # Add any assertions here to check the response
        response_str = response["choices"][0]["message"]["content"]
        response_str_2 = response.choices[0].message.content
        print(response_str)
        print(response_str_2)
        if type(response_str) != str:
            pytest.fail(f"Error occurred: {e}")
        if type(response_str_2) != str:
            pytest.fail(f"Error occurred: {e}")
    except Exception as e:
        pytest.fail(f"Error occurred: {e}")

# test_completion_replicate_llama_2()


######## Test TogetherAI ########
def test_completion_together_ai():