use replicate http requests instead

ishaan-jaff 2023-09-06 09:43:04 -07:00
parent 3d6836417e
commit c45b132675
3 changed files with 188 additions and 80 deletions
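For orientation: the new provider integration calls Replicate's HTTP predictions API directly instead of going through the replicate Python package. A prediction is started with a POST to https://api.replicate.com/v1/predictions, and the URL returned under urls.get is then polled with GET until the status becomes succeeded, failed, or canceled. Below is a minimal sketch of that flow, assuming the same endpoints and Token auth header used by the new module in this commit; the polling interval and helper name are illustrative, not part of the commit.

import time
import requests

REPLICATE_BASE = "https://api.replicate.com/v1"

def run_prediction_sketch(version_id: str, prompt: str, api_token: str) -> str:
    # Start a prediction, then poll its URL until it reaches a terminal status.
    headers = {"Authorization": f"Token {api_token}", "Content-Type": "application/json"}
    start = requests.post(
        f"{REPLICATE_BASE}/predictions",
        json={"version": version_id, "input": {"prompt": prompt}},
        headers=headers,
    )
    start.raise_for_status()
    prediction_url = start.json()["urls"]["get"]

    while True:
        poll = requests.get(prediction_url, headers=headers).json()
        if poll["status"] in ("succeeded", "failed", "canceled"):
            # Replicate returns generated text as a list of string chunks.
            return "".join(poll.get("output") or [])
        time.sleep(0.5)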

litellm/llms/replicate.py (new file, 142 additions)

@@ -0,0 +1,142 @@
import os
import json
import requests
import time
from typing import Callable
from litellm.utils import ModelResponse
import tiktoken

# Function to start a prediction and get the prediction URL
def start_prediction(version_id, input_data, api_token):
    base_url = "https://api.replicate.com/v1"
    headers = {
        "Authorization": f"Token {api_token}",
        "Content-Type": "application/json"
    }

    initial_prediction_data = {
        "version": version_id,
        "input": input_data,
        "max_new_tokens": 500,
    }

    response = requests.post(f"{base_url}/predictions", json=initial_prediction_data, headers=headers)
    if response.status_code == 201:
        response_data = response.json()
        return response_data.get("urls", {}).get("get")
    else:
        raise ValueError(response.status_code, "Failed to start prediction.")

# Function to handle prediction response (non-streaming)
def handle_prediction_response(prediction_url, api_token, print_verbose):
    output_string = ""
    headers = {
        "Authorization": f"Token {api_token}",
        "Content-Type": "application/json"
    }

    status = ""
    while status not in ["succeeded", "failed", "canceled"]:
        print_verbose("making request")
        time.sleep(0.0001)
        response = requests.get(prediction_url, headers=headers)
        if response.status_code == 200:
            response_data = response.json()
            if "output" in response_data:
                output_string = "".join(response_data['output'])
                print_verbose(f"Non-streamed output:{output_string}")
            status = response_data['status']
        else:
            print_verbose("Failed to fetch prediction status and output.")
    return output_string

# Function to handle prediction response (streaming)
def handle_prediction_response_streaming(prediction_url, api_token, print_verbose):
    previous_output = ""
    output_string = ""
    headers = {
        "Authorization": f"Token {api_token}",
        "Content-Type": "application/json"
    }

    status = ""
    while status not in ["succeeded", "failed", "canceled"]:
        time.sleep(0.0001)
        response = requests.get(prediction_url, headers=headers)
        if response.status_code == 200:
            response_data = response.json()
            if "output" in response_data:
                output_string = "".join(response_data['output'])
                new_output = output_string[len(previous_output):]
                yield new_output
                previous_output = output_string
            status = response_data['status']

# Function to extract version ID from model string
def model_to_version_id(model):
    if ":" in model:
        split_model = model.split(":")
        return split_model[1]
    return model

# Main function for prediction completion
def completion(
    model: str,
    messages: list,
    model_response: ModelResponse,
    print_verbose: Callable,
    logging_obj,
    api_key,
    encoding=tiktoken.get_encoding("cl100k_base"),
    optional_params=None,
    litellm_params=None,
    logger_fn=None,
):
    # Convert messages to prompt
    prompt = ""
    for message in messages:
        prompt += message["content"]

    # Start a prediction and get the prediction URL
    version_id = model_to_version_id(model)
    input_data = {
        "prompt": prompt,
        "max_new_tokens": 50,
    }
    prediction_url = start_prediction(version_id, input_data, api_key)
    print_verbose(prediction_url)

    # Handle the prediction response (streaming or non-streaming)
    if "stream" in optional_params and optional_params["stream"] == True:
        return handle_prediction_response_streaming(prediction_url, api_key, print_verbose)
    else:
        result = handle_prediction_response(prediction_url, api_key, print_verbose)
        model_response["choices"][0]["message"]["content"] = result

        # Calculate usage
        prompt_tokens = len(encoding.encode(prompt))
        completion_tokens = len(encoding.encode(model_response["choices"][0]["message"]["content"]))
        model_response["created"] = time.time()
        model_response["model"] = model
        model_response["usage"] = {
            "prompt_tokens": prompt_tokens,
            "completion_tokens": completion_tokens,
            "total_tokens": prompt_tokens + completion_tokens,
        }
        return model_response

# # Example usage:
# response = completion(
#     api_key="",
#     messages=[{"content": "good morning"}],
#     model="replicate/llama-2-70b-chat:2796ee9483c3fd7aa2e171d38f4ca12251a30609463dcfd4cd76703f22e96cdf",
#     model_response=ModelResponse(),
#     print_verbose=print,
#     logging_obj=print, # stub logging_obj
#     optional_params={"stream": False}
# )
# print(response)


@@ -24,6 +24,7 @@ from .llms import ai21
 from .llms import sagemaker
 from .llms import bedrock
 from .llms import huggingface_restapi
+from .llms import replicate
 from .llms import aleph_alpha
 from .llms import baseten
 import tiktoken
@@ -341,10 +342,7 @@ def completion(
             response = model_response
         elif "replicate" in model or custom_llm_provider == "replicate":
             # import replicate/if it fails then pip install replicate
-            try:
-                import replicate
-            except:
-                Exception("Replicate import failed please run `pip install replicate`")
             # Setting the relevant API KEY for replicate, replicate defaults to using os.environ.get("REPLICATE_API_TOKEN")
             replicate_key = os.environ.get("REPLICATE_API_TOKEN")
@@ -358,56 +356,25 @@ def completion(
             )
             # set replicate key
             os.environ["REPLICATE_API_TOKEN"] = str(replicate_key)
-            prompt = " ".join([message["content"] for message in messages])
-            input = {
-                "prompt": prompt
-            }
-            if "max_tokens" in optional_params:
-                input["max_length"] = optional_params['max_tokens'] # for t5 models
-                input["max_new_tokens"] = optional_params['max_tokens'] # for llama2 models
-            ## LOGGING
-            logging.pre_call(
-                input=prompt,
-                api_key=replicate_key,
-                additional_args={
-                    "complete_input_dict": input,
-                    "max_tokens": max_tokens,
-                },
-            )
-            ## COMPLETION CALL
-            output = replicate.run(model, input=input)
+            model_response = replicate.completion(
+                model=model,
+                messages=messages,
+                model_response=model_response,
+                print_verbose=print_verbose,
+                optional_params=optional_params,
+                litellm_params=litellm_params,
+                logger_fn=logger_fn,
+                encoding=encoding, # for calculating input/output tokens
+                api_key=replicate_key,
+                logging_obj=logging,
+            )
             if "stream" in optional_params and optional_params["stream"] == True:
                 # don't try to access stream object,
-                # let the stream handler know this is replicate
-                response = CustomStreamWrapper(output, "replicate", logging_obj=logging)
+                response = CustomStreamWrapper(model_response, model, logging_obj=logging)
                 return response
-            response = ""
-            for item in output:
-                response += item
-            completion_response = response
-            ## LOGGING
-            logging.post_call(
-                input=prompt,
-                api_key=replicate_key,
-                original_response=completion_response,
-                additional_args={
-                    "complete_input_dict": input,
-                    "max_tokens": max_tokens,
-                },
-            )
-            ## USAGE
-            prompt_tokens = len(encoding.encode(prompt))
-            completion_tokens = len(encoding.encode(completion_response))
-            ## RESPONSE OBJECT
-            model_response["choices"][0]["message"]["content"] = completion_response
-            model_response["created"] = time.time()
-            model_response["model"] = model
-            model_response["usage"] = {
-                "prompt_tokens": prompt_tokens,
-                "completion_tokens": completion_tokens,
-                "total_tokens": prompt_tokens + completion_tokens,
-            }
             response = model_response
         elif model in litellm.anthropic_models:
             anthropic_key = (
                 api_key or litellm.anthropic_key or os.environ.get("ANTHROPIC_API_KEY")
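With the wiring above, callers reach the new module through litellm's completion entry point. A hedged usage sketch based on the tests in this commit: the model id is copied from test_completion_replicate_llama_2, the role field on the message is added only for illustration, and REPLICATE_API_TOKEN is assumed to be supplied via the environment, which is where main.py reads it.

import os
from litellm import completion

os.environ["REPLICATE_API_TOKEN"] = "<REPLICATE_API_TOKEN>"  # placeholder; main.py reads this env var

model = "replicate/llama-2-70b-chat:2796ee9483c3fd7aa2e171d38f4ca12251a30609463dcfd4cd76703f22e96cdf"
messages = [{"role": "user", "content": "good morning"}]

# Non-streaming: replicate.completion polls the prediction URL and fills in the ModelResponse.
response = completion(model=model, messages=messages, custom_llm_provider="replicate")
print(response["choices"][0]["message"]["content"])

# Streaming: the generator from replicate.completion is wrapped in CustomStreamWrapper,
# so chunks are consumed the same way the (now commented-out) streaming tests do.
stream = completion(model=model, messages=messages, custom_llm_provider="replicate", stream=True)
for chunk in stream:
    print(chunk["choices"][0]["delta"])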


@@ -349,54 +349,53 @@ def test_completion_azure_deployment_id():
     except Exception as e:
         pytest.fail(f"Error occurred: {e}")

-# Replicate API endpoints are unstable -> throw random CUDA errors -> this means our tests can fail even if our tests weren't incorrect.
-def test_completion_replicate_llama_stream():
-    model_name = "replicate/llama-2-70b-chat:2c1608e18606fad2812020dc541930f2d0495ce32eee50074220b87300bc16e1"
-    try:
-        response = completion(model=model_name, messages=messages, stream=True)
-        # Add any assertions here to check the response
-        for result in response:
-            print(result)
-        print(response)
-    except Exception as e:
-        pytest.fail(f"Error occurred: {e}")
+# # Replicate API endpoints are unstable -> throw random CUDA errors -> this means our tests can fail even if our tests weren't incorrect.
+# def test_completion_replicate_llama_stream():
+#     model_name = "replicate/llama-2-70b-chat:2c1608e18606fad2812020dc541930f2d0495ce32eee50074220b87300bc16e1"
+#     try:
+#         response = completion(model=model_name, messages=messages, stream=True)
+#         # Add any assertions here to check the response
+#         for result in response:
+#             print(result)
+#         print(response)
+#     except Exception as e:
+#         pytest.fail(f"Error occurred: {e}")

-def test_completion_replicate_stability_stream():
-    model_name = "stability-ai/stablelm-tuned-alpha-7b:c49dae362cbaecd2ceabb5bd34fdb68413c4ff775111fea065d259d577757beb"
-    try:
-        response = completion(
-            model=model_name,
-            messages=messages,
-            stream=True,
-            custom_llm_provider="replicate",
-        )
-        # Add any assertions here to check the response
-        for chunk in response:
-            print(chunk["choices"][0]["delta"])
-        print(response)
-    except Exception as e:
-        pytest.fail(f"Error occurred: {e}")
+# def test_completion_replicate_stability_stream():
+#     model_name = "stability-ai/stablelm-tuned-alpha-7b:c49dae362cbaecd2ceabb5bd34fdb68413c4ff775111fea065d259d577757beb"
+#     try:
+#         response = completion(
+#             model=model_name,
+#             messages=messages,
+#             stream=True,
+#             custom_llm_provider="replicate",
+#         )
+#         # Add any assertions here to check the response
+#         for chunk in response:
+#             print(chunk["choices"][0]["delta"])
+#         print(response)
+#     except Exception as e:
+#         pytest.fail(f"Error occurred: {e}")

-def test_completion_replicate_stability():
-    model_name = "stability-ai/stablelm-tuned-alpha-7b:c49dae362cbaecd2ceabb5bd34fdb68413c4ff775111fea065d259d577757beb"
+def test_completion_replicate_llama_2():
+    model_name = "replicate/llama-2-70b-chat:2796ee9483c3fd7aa2e171d38f4ca12251a30609463dcfd4cd76703f22e96cdf"
     try:
         response = completion(
             model=model_name, messages=messages, custom_llm_provider="replicate"
         )
+        print(response)
         # Add any assertions here to check the response
         response_str = response["choices"][0]["message"]["content"]
+        response_str_2 = response.choices[0].message.content
         print(response_str)
+        print(response_str_2)
         if type(response_str) != str:
             pytest.fail(f"Error occurred: {e}")
+        if type(response_str_2) != str:
+            pytest.fail(f"Error occurred: {e}")
     except Exception as e:
         pytest.fail(f"Error occurred: {e}")

+# test_completion_replicate_llama_2()

 ######## Test TogetherAI ########
 def test_completion_together_ai():