use replicate http requests instead

ishaan-jaff 2023-09-06 09:43:04 -07:00
parent 3d6836417e
commit c45b132675
3 changed files with 188 additions and 80 deletions
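
The new provider module below talks to Replicate's HTTP API directly instead of going through the `replicate` Python package: it POSTs to create a prediction, then polls the prediction's `get` URL until the status is terminal. A minimal standalone sketch of that flow, not part of the diff, assuming `REPLICATE_API_TOKEN` is set and using a placeholder version id:

import os
import time
import requests

api_token = os.environ["REPLICATE_API_TOKEN"]
headers = {"Authorization": f"Token {api_token}", "Content-Type": "application/json"}

# 1) Create the prediction (placeholder version id and prompt).
create = requests.post(
    "https://api.replicate.com/v1/predictions",
    json={"version": "<model-version-id>", "input": {"prompt": "good morning"}},
    headers=headers,
)
prediction_url = create.json()["urls"]["get"]

# 2) Poll the prediction's `get` URL until it reaches a terminal status.
while True:
    poll = requests.get(prediction_url, headers=headers).json()
    if poll["status"] in ["succeeded", "failed", "canceled"]:
        break
    time.sleep(0.5)

# The output arrives as a list of text chunks; join them for the full completion.
print("".join(poll.get("output", [])))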

litellm/llms/replicate.py  (new file, 142 additions)

@@ -0,0 +1,142 @@
import os
import json
import requests
import time
from typing import Callable
from litellm.utils import ModelResponse
import tiktoken

# Function to start a prediction and get the prediction URL
def start_prediction(version_id, input_data, api_token):
    base_url = "https://api.replicate.com/v1"
    headers = {
        "Authorization": f"Token {api_token}",
        "Content-Type": "application/json"
    }

    initial_prediction_data = {
        "version": version_id,
        "input": input_data,
        "max_new_tokens": 500,
    }

    response = requests.post(f"{base_url}/predictions", json=initial_prediction_data, headers=headers)
    if response.status_code == 201:
        response_data = response.json()
        return response_data.get("urls", {}).get("get")
    else:
        raise ValueError(response.status_code, "Failed to start prediction.")

# Function to handle prediction response (non-streaming)
def handle_prediction_response(prediction_url, api_token, print_verbose):
    output_string = ""
    headers = {
        "Authorization": f"Token {api_token}",
        "Content-Type": "application/json"
    }

    status = ""
    while status not in ["succeeded", "failed", "canceled"]:
        print_verbose("making request")
        time.sleep(0.0001)
        response = requests.get(prediction_url, headers=headers)
        if response.status_code == 200:
            response_data = response.json()
            if "output" in response_data:
                output_string = "".join(response_data['output'])
                print_verbose(f"Non-streamed output:{output_string}")
            status = response_data['status']
        else:
            print_verbose("Failed to fetch prediction status and output.")
    return output_string

# Function to handle prediction response (streaming)
def handle_prediction_response_streaming(prediction_url, api_token, print_verbose):
    previous_output = ""
    output_string = ""
    headers = {
        "Authorization": f"Token {api_token}",
        "Content-Type": "application/json"
    }

    status = ""
    while status not in ["succeeded", "failed", "canceled"]:
        time.sleep(0.0001)
        response = requests.get(prediction_url, headers=headers)
        if response.status_code == 200:
            response_data = response.json()
            if "output" in response_data:
                output_string = "".join(response_data['output'])
                new_output = output_string[len(previous_output):]
                yield new_output
                previous_output = output_string
            status = response_data['status']

# Function to extract version ID from model string
def model_to_version_id(model):
    if ":" in model:
        split_model = model.split(":")
        return split_model[1]
    return model

# Main function for prediction completion
def completion(
    model: str,
    messages: list,
    model_response: ModelResponse,
    print_verbose: Callable,
    logging_obj,
    api_key,
    encoding=tiktoken.get_encoding("cl100k_base"),
    optional_params=None,
    litellm_params=None,
    logger_fn=None,
):
    # Convert messages to prompt
    prompt = ""
    for message in messages:
        prompt += message["content"]

    # Start a prediction and get the prediction URL
    version_id = model_to_version_id(model)
    input_data = {
        "prompt": prompt,
        "max_new_tokens": 50,
    }
    prediction_url = start_prediction(version_id, input_data, api_key)
    print_verbose(prediction_url)

    # Handle the prediction response (streaming or non-streaming)
    if "stream" in optional_params and optional_params["stream"] == True:
        return handle_prediction_response_streaming(prediction_url, api_key, print_verbose)
    else:
        result = handle_prediction_response(prediction_url, api_key, print_verbose)
        model_response["choices"][0]["message"]["content"] = result

        # Calculate usage
        prompt_tokens = len(encoding.encode(prompt))
        completion_tokens = len(encoding.encode(model_response["choices"][0]["message"]["content"]))
        model_response["created"] = time.time()
        model_response["model"] = model
        model_response["usage"] = {
            "prompt_tokens": prompt_tokens,
            "completion_tokens": completion_tokens,
            "total_tokens": prompt_tokens + completion_tokens,
        }
        return model_response

# # Example usage:
# response = completion(
#     api_key="",
#     messages=[{"content": "good morning"}],
#     model="replicate/llama-2-70b-chat:2796ee9483c3fd7aa2e171d38f4ca12251a30609463dcfd4cd76703f22e96cdf",
#     model_response=ModelResponse(),
#     print_verbose=print,
#     logging_obj=print, # stub logging_obj
#     optional_params={"stream": False}
# )
# print(response)
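
When `stream` is set, `completion` simply returns the generator from `handle_prediction_response_streaming`, which yields only the text appended since the previous poll. A rough usage sketch of that path, not part of the diff, assuming `REPLICATE_API_TOKEN` is exported (the version string is the illustrative one from the commented example above):

# Sketch only: consuming the streaming generator returned by completion().
import os
from litellm.utils import ModelResponse
from litellm.llms import replicate

chunks = replicate.completion(
    model="replicate/llama-2-70b-chat:2796ee9483c3fd7aa2e171d38f4ca12251a30609463dcfd4cd76703f22e96cdf",
    messages=[{"content": "good morning"}],
    model_response=ModelResponse(),
    print_verbose=print,
    logging_obj=None,  # accepted but not used by this module yet
    api_key=os.environ.get("REPLICATE_API_TOKEN"),
    optional_params={"stream": True},
)
for delta in chunks:
    print(delta, end="")  # each item is only the newly generated text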

litellm/main.py

@@ -24,6 +24,7 @@ from .llms import ai21
from .llms import sagemaker
from .llms import bedrock
from .llms import huggingface_restapi
from .llms import replicate
from .llms import aleph_alpha
from .llms import baseten
import tiktoken
@@ -341,10 +342,7 @@ def completion(
            response = model_response
        elif "replicate" in model or custom_llm_provider == "replicate":
            # import replicate/if it fails then pip install replicate
            try:
                import replicate
            except:
                Exception("Replicate import failed please run `pip install replicate`")
            # Setting the relevant API KEY for replicate, replicate defaults to using os.environ.get("REPLICATE_API_TOKEN")
            replicate_key = os.environ.get("REPLICATE_API_TOKEN")
@@ -358,56 +356,25 @@
            )
            # set replicate key
            os.environ["REPLICATE_API_TOKEN"] = str(replicate_key)
            prompt = " ".join([message["content"] for message in messages])
            input = {
                "prompt": prompt
            }
            if "max_tokens" in optional_params:
                input["max_length"] = optional_params['max_tokens'] # for t5 models
                input["max_new_tokens"] = optional_params['max_tokens'] # for llama2 models
            ## LOGGING
            logging.pre_call(
                input=prompt,
            model_response = replicate.completion(
                model=model,
                messages=messages,
                model_response=model_response,
                print_verbose=print_verbose,
                optional_params=optional_params,
                litellm_params=litellm_params,
                logger_fn=logger_fn,
                encoding=encoding, # for calculating input/output tokens
                api_key=replicate_key,
                additional_args={
                    "complete_input_dict": input,
                    "max_tokens": max_tokens,
                },
                logging_obj=logging,
            )
            ## COMPLETION CALL
            output = replicate.run(model, input=input)
            if "stream" in optional_params and optional_params["stream"] == True:
                # don't try to access stream object,
                # let the stream handler know this is replicate
                response = CustomStreamWrapper(output, "replicate", logging_obj=logging)
                response = CustomStreamWrapper(model_response, model, logging_obj=logging)
                return response
            response = ""
            for item in output:
                response += item
            completion_response = response
            ## LOGGING
            logging.post_call(
                input=prompt,
                api_key=replicate_key,
                original_response=completion_response,
                additional_args={
                    "complete_input_dict": input,
                    "max_tokens": max_tokens,
                },
            )
            ## USAGE
            prompt_tokens = len(encoding.encode(prompt))
            completion_tokens = len(encoding.encode(completion_response))
            ## RESPONSE OBJECT
            model_response["choices"][0]["message"]["content"] = completion_response
            model_response["created"] = time.time()
            model_response["model"] = model
            model_response["usage"] = {
                "prompt_tokens": prompt_tokens,
                "completion_tokens": completion_tokens,
                "total_tokens": prompt_tokens + completion_tokens,
            }
            response = model_response
        elif model in litellm.anthropic_models:
            anthropic_key = (
                api_key or litellm.anthropic_key or os.environ.get("ANTHROPIC_API_KEY")
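
From the caller's side the interface is unchanged: `completion` in main.py still reads the key from `REPLICATE_API_TOKEN` and now routes replicate models through the HTTP-based module, wrapping streaming output in `CustomStreamWrapper`. A hedged end-to-end sketch, not part of the diff (the model/version string is the illustrative one used elsewhere in this commit):

# Sketch only: calling the new code path through the public litellm API.
# Assumes REPLICATE_API_TOKEN is already exported in the environment.
from litellm import completion

response = completion(
    model="replicate/llama-2-70b-chat:2796ee9483c3fd7aa2e171d38f4ca12251a30609463dcfd4cd76703f22e96cdf",
    messages=[{"role": "user", "content": "good morning"}],
)
print(response["choices"][0]["message"]["content"])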

litellm/tests/test_completion.py

@@ -349,54 +349,53 @@ def test_completion_azure_deployment_id():
    except Exception as e:
        pytest.fail(f"Error occurred: {e}")

# Replicate API endpoints are unstable -> throw random CUDA errors -> this means our tests can fail even if our tests weren't incorrect.
def test_completion_replicate_llama_stream():
    model_name = "replicate/llama-2-70b-chat:2c1608e18606fad2812020dc541930f2d0495ce32eee50074220b87300bc16e1"
    try:
        response = completion(model=model_name, messages=messages, stream=True)
        # Add any assertions here to check the response
        for result in response:
            print(result)
        print(response)
    except Exception as e:
        pytest.fail(f"Error occurred: {e}")

# # Replicate API endpoints are unstable -> throw random CUDA errors -> this means our tests can fail even if our tests weren't incorrect.
# def test_completion_replicate_llama_stream():
#     model_name = "replicate/llama-2-70b-chat:2c1608e18606fad2812020dc541930f2d0495ce32eee50074220b87300bc16e1"
#     try:
#         response = completion(model=model_name, messages=messages, stream=True)
#         # Add any assertions here to check the response
#         for result in response:
#             print(result)
#         print(response)
#     except Exception as e:
#         pytest.fail(f"Error occurred: {e}")
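
The streaming tests above are disabled because the Replicate endpoints are unstable; a sketch of an alternative, not part of this commit, is an explicit skip marker so the reason still shows up in the pytest report:

# Sketch only (not part of this commit): skip flaky tests instead of commenting them out.
@pytest.mark.skip(reason="Replicate endpoints are unstable -> random CUDA errors")
def test_completion_replicate_llama_stream():
    ...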
def test_completion_replicate_stability_stream():
    model_name = "stability-ai/stablelm-tuned-alpha-7b:c49dae362cbaecd2ceabb5bd34fdb68413c4ff775111fea065d259d577757beb"
    try:
        response = completion(
            model=model_name,
            messages=messages,
            stream=True,
            custom_llm_provider="replicate",
        )
        # Add any assertions here to check the response
        for chunk in response:
            print(chunk["choices"][0]["delta"])
        print(response)
    except Exception as e:
        pytest.fail(f"Error occurred: {e}")

# def test_completion_replicate_stability_stream():
#     model_name = "stability-ai/stablelm-tuned-alpha-7b:c49dae362cbaecd2ceabb5bd34fdb68413c4ff775111fea065d259d577757beb"
#     try:
#         response = completion(
#             model=model_name,
#             messages=messages,
#             stream=True,
#             custom_llm_provider="replicate",
#         )
#         # Add any assertions here to check the response
#         for chunk in response:
#             print(chunk["choices"][0]["delta"])
#         print(response)
#     except Exception as e:
#         pytest.fail(f"Error occurred: {e}")
def test_completion_replicate_stability():
    model_name = "stability-ai/stablelm-tuned-alpha-7b:c49dae362cbaecd2ceabb5bd34fdb68413c4ff775111fea065d259d577757beb"
def test_completion_replicate_llama_2():
    model_name = "replicate/llama-2-70b-chat:2796ee9483c3fd7aa2e171d38f4ca12251a30609463dcfd4cd76703f22e96cdf"
    try:
        response = completion(
            model=model_name, messages=messages, custom_llm_provider="replicate"
        )
        print(response)
        # Add any assertions here to check the response
        response_str = response["choices"][0]["message"]["content"]
        response_str_2 = response.choices[0].message.content
        print(response_str)
        print(response_str_2)
        if type(response_str) != str:
            pytest.fail("response_str is not a string")
        if type(response_str_2) != str:
            pytest.fail("response_str_2 is not a string")
    except Exception as e:
        pytest.fail(f"Error occurred: {e}")
# test_completion_replicate_llama_2()
######## Test TogetherAI ########
def test_completion_together_ai():