forked from phoenix/litellm-mirror

streaming replicate tests

commit d66bda43d3 (parent c55adebcfd)

3 changed files with 78 additions and 44 deletions
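For context, a minimal usage sketch of the streaming path this commit adds, mirroring the new tests further down. The `messages` list here is illustrative (the tests use a module-level `messages` not shown in this diff), and `completion` is assumed to be importable from the package root as elsewhere in litellm:

```python
# Sketch only: follows the new streaming tests in this commit.
from litellm import completion

# Illustrative prompt; not part of the diff.
messages = [{"role": "user", "content": "Hey, how's it going?"}]

response = completion(
    model="replicate/llama-2-70b-chat:2c1608e18606fad2812020dc541930f2d0495ce32eee50074220b87300bc16e1",
    messages=messages,
    stream=True,
)

# With stream=True the replicate path returns a CustomStreamWrapper, which
# yields OpenAI-style chunks of the form {"choices": [{"delta": <text>}]}.
for chunk in response:
    print(chunk["choices"][0]["delta"], end="")
```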
@@ -7,32 +7,9 @@ import litellm
 from litellm import client, logging, exception_type, timeout, get_optional_params
 import tiktoken
 encoding = tiktoken.get_encoding("cl100k_base")
-from litellm.utils import get_secret, install_and_import
+from litellm.utils import get_secret, install_and_import, CustomStreamWrapper
 ####### ENVIRONMENT VARIABLES ###################
 dotenv.load_dotenv() # Loading env variables using dotenv
-
-# TODO this will evolve to accepting models
-# replicate/anthropic/cohere
-class CustomStreamWrapper:
-  def __init__(self, completion_stream, model):
-    self.model = model
-    if model in litellm.cohere_models:
-      # cohere does not return an iterator, so we need to wrap it in one
-      self.completion_stream = iter(completion_stream)
-    else:
-      self.completion_stream = completion_stream
-
-  def __iter__(self):
-    return self
-
-  def __next__(self):
-    if self.model in litellm.anthropic_models:
-      chunk = next(self.completion_stream)
-      return {"choices": [{"delta": chunk.completion}]}
-    elif self.model in litellm.cohere_models:
-      chunk = next(self.completion_stream)
-      return {"choices": [{"delta": chunk.text}]}
-
 new_response = {
     "choices": [
       {
@@ -67,7 +44,7 @@ def completion(
     presence_penalty=0, frequency_penalty=0, logit_bias={}, user="", deployment_id=None,
     # Optional liteLLM function params
     *, return_async=False, api_key=None, force_timeout=60, azure=False, logger_fn=None, verbose=False,
-    hugging_face = False
+    hugging_face = False, replicate=False,
   ):
   try:
     global new_response
@@ -77,7 +54,8 @@ def completion(
       functions=functions, function_call=function_call,
       temperature=temperature, top_p=top_p, n=n, stream=stream, stop=stop, max_tokens=max_tokens,
       presence_penalty=presence_penalty, frequency_penalty=frequency_penalty, logit_bias=logit_bias, user=user, deployment_id=deployment_id,
-      model=model
+      # params to identify the model
+      model=model, replicate=replicate, hugging_face=hugging_face
     )
     if azure == True:
       # azure configs
@@ -172,7 +150,7 @@ def completion(
       model_response["model"] = model
       model_response["usage"] = response["usage"]
       response = model_response
-    elif "replicate" in model:
+    elif "replicate" in model or replicate == True:
       # import replicate/if it fails then pip install replicate
       install_and_import("replicate")
       import replicate
@@ -196,6 +174,11 @@ def completion(
       output = replicate.run(
         model,
         input=input)
+      if 'stream' in optional_params and optional_params['stream'] == True:
+        # don't try to access stream object,
+        # let the stream handler know this is replicate
+        response = CustomStreamWrapper(output, "replicate")
+        return response
       response = ""
       for item in output:
         response += item
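To make the intent of the branch above explicit, here is a standalone sketch. The helper name and the toy iterator are hypothetical; the real logic lives inline in `completion()`, and `CustomStreamWrapper` comes from `litellm.utils` per the first hunk:

```python
from litellm.utils import CustomStreamWrapper

def handle_replicate_output(output, optional_params):
    # Hypothetical helper for illustration only.
    if 'stream' in optional_params and optional_params['stream'] == True:
        # Return immediately without touching the iterator, so tokens are
        # surfaced lazily as the caller iterates the wrapper.
        return CustomStreamWrapper(output, "replicate")
    # Non-streaming: drain the iterator into a single string, as before.
    response = ""
    for item in output:
        response += item
    return response

print(handle_replicate_output(iter(["a", "b", "c"]), {}))           # abc
print(handle_replicate_output(iter(["a", "b"]), {"stream": True}))  # a CustomStreamWrapper object
```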
@@ -139,20 +139,36 @@ def test_completion_azure():
     except Exception as e:
         pytest.fail(f"Error occurred: {e}")
 
 
 
 # Replicate API endpoints are unstable -> throw random CUDA errors -> this means our tests can fail even if our tests weren't incorrect.
-# [TODO] improve our try-except block to handle for these
-# def test_completion_replicate_llama():
-#     model_name = "replicate/llama-2-70b-chat:2c1608e18606fad2812020dc541930f2d0495ce32eee50074220b87300bc16e1"
-#     try:
-#         response = completion(model=model_name, messages=messages, max_tokens=500)
-#         # Add any assertions here to check the response
-#         print(response)
-#     except Exception as e:
-#         print(f"in replicate llama, got error {e}")
-#         pass
-#         if e == "FunctionTimedOut":
-#             pass
-#         else:
-#             pytest.fail(f"Error occurred: {e}")
+def test_completion_replicate_llama_stream():
+    model_name = "replicate/llama-2-70b-chat:2c1608e18606fad2812020dc541930f2d0495ce32eee50074220b87300bc16e1"
+    try:
+        response = completion(model=model_name, messages=messages, stream=True)
+        # Add any assertions here to check the response
+        for result in response:
+            print(result)
+        print(response)
+    except Exception as e:
+        pytest.fail(f"Error occurred: {e}")
+
+def test_completion_replicate_stability_stream():
+    model_name = "stability-ai/stablelm-tuned-alpha-7b:c49dae362cbaecd2ceabb5bd34fdb68413c4ff775111fea065d259d577757beb"
+    try:
+        response = completion(model=model_name, messages=messages, stream=True, replicate=True)
+        # Add any assertions here to check the response
+        for result in response:
+            print(result)
+        print(response)
+    except Exception as e:
+        pytest.fail(f"Error occurred: {e}")
+
+def test_completion_replicate_stability():
+    model_name = "stability-ai/stablelm-tuned-alpha-7b:c49dae362cbaecd2ceabb5bd34fdb68413c4ff775111fea065d259d577757beb"
+    try:
+        response = completion(model=model_name, messages=messages, replicate=True)
+        # Add any assertions here to check the response
+        for result in response:
+            print(result)
+        print(response)
+    except Exception as e:
+        pytest.fail(f"Error occurred: {e}")
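The new tests only print what comes back. If assertions are added later, a minimal sketch of what each streamed chunk could be checked against, based on the chunk shape CustomStreamWrapper returns (the helper name is hypothetical):

```python
# Hypothetical assertion helper, not part of this commit's tests; `response`
# is whatever completion(..., stream=True) returned.
def check_stream_chunks(response):
    for result in response:
        assert "choices" in result
        assert isinstance(result["choices"][0]["delta"], str)
```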
@@ -148,6 +148,8 @@ def get_optional_params(
     user = "",
     deployment_id = None,
     model = None,
+    replicate = False,
+    hugging_face = False,
 ):
   optional_params = {}
   if model in litellm.anthropic_models:
@@ -170,7 +172,12 @@ def get_optional_params(
     if max_tokens != float('inf'):
       optional_params["max_tokens"] = max_tokens
     return optional_params
+  elif replicate == True:
+    # any replicate models
+    # TODO: handle translating remaining replicate params
+    if stream:
+      optional_params["stream"] = stream
+      return optional_params
   else:# assume passing in params for openai/azure openai
     if functions != []:
       optional_params["functions"] = functions
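A quick sketch of what the new `elif replicate == True` branch produces, assuming the remaining parameters keep their defaults; the call below is illustrative, not taken from the commit:

```python
from litellm import get_optional_params

# Only the stream flag is translated for replicate so far; the TODO above
# covers the remaining replicate-specific params.
params = get_optional_params(stream=True, replicate=True)
print(params)  # expected: {'stream': True}
```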
@@ -199,6 +206,7 @@ def get_optional_params(
     if deployment_id != None:
       optional_params["deployment_id"] = deployment_id
     return optional_params
+  return optional_params
 
 def set_callbacks(callback_list):
   global sentry_sdk_instance, capture_exception, add_breadcrumb, posthog, slack_app, alerts_channel, heliconeLogger, aispendLogger, berrispendLogger, supabaseClient
@@ -557,3 +565,30 @@ def get_secret(secret_name):
       return os.environ.get(secret_name)
   else:
     return os.environ.get(secret_name)
+
+######## Streaming Class ############################
+# wraps the completion stream to return the correct format for the model
+# replicate/anthropic/cohere
+class CustomStreamWrapper:
+  def __init__(self, completion_stream, model):
+    self.model = model
+    if model in litellm.cohere_models:
+      # cohere does not return an iterator, so we need to wrap it in one
+      self.completion_stream = iter(completion_stream)
+    else:
+      self.completion_stream = completion_stream
+
+  def __iter__(self):
+    return self
+
+  def __next__(self):
+    if self.model in litellm.anthropic_models:
+      chunk = next(self.completion_stream)
+      return {"choices": [{"delta": chunk.completion}]}
+    elif self.model == "replicate":
+      chunk = next(self.completion_stream)
+      return {"choices": [{"delta": chunk}]}
+    elif self.model in litellm.cohere_models:
+      chunk = next(self.completion_stream)
+      return {"choices": [{"delta": chunk.text}]}
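Finally, a minimal sketch of how the wrapper normalizes provider output; the fake iterator stands in for what `replicate.run` yields, and only the new `"replicate"` branch is exercised:

```python
from litellm.utils import CustomStreamWrapper

# Stand-in for the token iterator replicate.run(...) returns: plain strings,
# which is why the "replicate" branch passes chunks through unchanged.
fake_stream = iter(["stream", "ing ", "works"])

wrapped = CustomStreamWrapper(fake_stream, "replicate")
for chunk in wrapped:
    # Every provider branch returns the same OpenAI-style delta shape.
    print(chunk["choices"][0]["delta"], end="")
# prints: streaming works
```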