mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-26 03:04:13 +00:00
streaming replicate tests
This commit is contained in:
parent
c55adebcfd
commit
d66bda43d3
3 changed files with 78 additions and 44 deletions
|
@ -7,32 +7,9 @@ import litellm
|
|||
from litellm import client, logging, exception_type, timeout, get_optional_params
|
||||
import tiktoken
|
||||
encoding = tiktoken.get_encoding("cl100k_base")
|
||||
from litellm.utils import get_secret, install_and_import
|
||||
from litellm.utils import get_secret, install_and_import, CustomStreamWrapper
|
||||
####### ENVIRONMENT VARIABLES ###################
|
||||
dotenv.load_dotenv() # Loading env variables using dotenv
|
||||
|
||||
# TODO this will evolve to accepting models
|
||||
# replicate/anthropic/cohere
|
||||
class CustomStreamWrapper:
|
||||
def __init__(self, completion_stream, model):
|
||||
self.model = model
|
||||
if model in litellm.cohere_models:
|
||||
# cohere does not return an iterator, so we need to wrap it in one
|
||||
self.completion_stream = iter(completion_stream)
|
||||
else:
|
||||
self.completion_stream = completion_stream
|
||||
|
||||
def __iter__(self):
|
||||
return self
|
||||
|
||||
def __next__(self):
|
||||
if self.model in litellm.anthropic_models:
|
||||
chunk = next(self.completion_stream)
|
||||
return {"choices": [{"delta": chunk.completion}]}
|
||||
elif self.model in litellm.cohere_models:
|
||||
chunk = next(self.completion_stream)
|
||||
return {"choices": [{"delta": chunk.text}]}
|
||||
|
||||
new_response = {
|
||||
"choices": [
|
||||
{
|
||||
|
@ -67,7 +44,7 @@ def completion(
|
|||
presence_penalty=0, frequency_penalty=0, logit_bias={}, user="", deployment_id=None,
|
||||
# Optional liteLLM function params
|
||||
*, return_async=False, api_key=None, force_timeout=60, azure=False, logger_fn=None, verbose=False,
|
||||
hugging_face = False
|
||||
hugging_face = False, replicate=False,
|
||||
):
|
||||
try:
|
||||
global new_response
|
||||
|
@ -77,7 +54,8 @@ def completion(
|
|||
functions=functions, function_call=function_call,
|
||||
temperature=temperature, top_p=top_p, n=n, stream=stream, stop=stop, max_tokens=max_tokens,
|
||||
presence_penalty=presence_penalty, frequency_penalty=frequency_penalty, logit_bias=logit_bias, user=user, deployment_id=deployment_id,
|
||||
model=model
|
||||
# params to identify the model
|
||||
model=model, replicate=replicate, hugging_face=hugging_face
|
||||
)
|
||||
if azure == True:
|
||||
# azure configs
|
||||
|
@ -172,7 +150,7 @@ def completion(
|
|||
model_response["model"] = model
|
||||
model_response["usage"] = response["usage"]
|
||||
response = model_response
|
||||
elif "replicate" in model:
|
||||
elif "replicate" in model or replicate == True:
|
||||
# import replicate/if it fails then pip install replicate
|
||||
install_and_import("replicate")
|
||||
import replicate
|
||||
|
@ -196,6 +174,11 @@ def completion(
|
|||
output = replicate.run(
|
||||
model,
|
||||
input=input)
|
||||
if 'stream' in optional_params and optional_params['stream'] == True:
|
||||
# don't try to access stream object,
|
||||
# let the stream handler know this is replicate
|
||||
response = CustomStreamWrapper(output, "replicate")
|
||||
return response
|
||||
response = ""
|
||||
for item in output:
|
||||
response += item
|
||||
|
|
|
@ -139,20 +139,36 @@ def test_completion_azure():
|
|||
except Exception as e:
|
||||
pytest.fail(f"Error occurred: {e}")
|
||||
|
||||
|
||||
|
||||
# Replicate API endpoints are unstable -> throw random CUDA errors -> this means our tests can fail even if our tests weren't incorrect.
|
||||
# [TODO] improve our try-except block to handle for these
|
||||
# def test_completion_replicate_llama():
|
||||
# model_name = "replicate/llama-2-70b-chat:2c1608e18606fad2812020dc541930f2d0495ce32eee50074220b87300bc16e1"
|
||||
# try:
|
||||
# response = completion(model=model_name, messages=messages, max_tokens=500)
|
||||
# # Add any assertions here to check the response
|
||||
# print(response)
|
||||
# except Exception as e:
|
||||
# print(f"in replicate llama, got error {e}")
|
||||
# pass
|
||||
# if e == "FunctionTimedOut":
|
||||
# pass
|
||||
# else:
|
||||
# pytest.fail(f"Error occurred: {e}")
|
||||
def test_completion_replicate_llama_stream():
|
||||
model_name = "replicate/llama-2-70b-chat:2c1608e18606fad2812020dc541930f2d0495ce32eee50074220b87300bc16e1"
|
||||
try:
|
||||
response = completion(model=model_name, messages=messages, stream=True)
|
||||
# Add any assertions here to check the response
|
||||
for result in response:
|
||||
print(result)
|
||||
print(response)
|
||||
except Exception as e:
|
||||
pytest.fail(f"Error occurred: {e}")
|
||||
|
||||
def test_completion_replicate_stability_stream():
|
||||
model_name = "stability-ai/stablelm-tuned-alpha-7b:c49dae362cbaecd2ceabb5bd34fdb68413c4ff775111fea065d259d577757beb"
|
||||
try:
|
||||
response = completion(model=model_name, messages=messages, stream=True, replicate=True)
|
||||
# Add any assertions here to check the response
|
||||
for result in response:
|
||||
print(result)
|
||||
print(response)
|
||||
except Exception as e:
|
||||
pytest.fail(f"Error occurred: {e}")
|
||||
|
||||
def test_completion_replicate_stability():
|
||||
model_name = "stability-ai/stablelm-tuned-alpha-7b:c49dae362cbaecd2ceabb5bd34fdb68413c4ff775111fea065d259d577757beb"
|
||||
try:
|
||||
response = completion(model=model_name, messages=messages, replicate=True)
|
||||
# Add any assertions here to check the response
|
||||
for result in response:
|
||||
print(result)
|
||||
print(response)
|
||||
except Exception as e:
|
||||
pytest.fail(f"Error occurred: {e}")
|
|
@ -148,6 +148,8 @@ def get_optional_params(
|
|||
user = "",
|
||||
deployment_id = None,
|
||||
model = None,
|
||||
replicate = False,
|
||||
hugging_face = False,
|
||||
):
|
||||
optional_params = {}
|
||||
if model in litellm.anthropic_models:
|
||||
|
@ -170,7 +172,12 @@ def get_optional_params(
|
|||
if max_tokens != float('inf'):
|
||||
optional_params["max_tokens"] = max_tokens
|
||||
return optional_params
|
||||
|
||||
elif replicate == True:
|
||||
# any replicate models
|
||||
# TODO: handle translating remaining replicate params
|
||||
if stream:
|
||||
optional_params["stream"] = stream
|
||||
return optional_params
|
||||
else:# assume passing in params for openai/azure openai
|
||||
if functions != []:
|
||||
optional_params["functions"] = functions
|
||||
|
@ -199,6 +206,7 @@ def get_optional_params(
|
|||
if deployment_id != None:
|
||||
optional_params["deployment_id"] = deployment_id
|
||||
return optional_params
|
||||
return optional_params
|
||||
|
||||
def set_callbacks(callback_list):
|
||||
global sentry_sdk_instance, capture_exception, add_breadcrumb, posthog, slack_app, alerts_channel, heliconeLogger, aispendLogger, berrispendLogger, supabaseClient
|
||||
|
@ -557,3 +565,30 @@ def get_secret(secret_name):
|
|||
return os.environ.get(secret_name)
|
||||
else:
|
||||
return os.environ.get(secret_name)
|
||||
|
||||
######## Streaming Class ############################
|
||||
# wraps the completion stream to return the correct format for the model
|
||||
# replicate/anthropic/cohere
|
||||
class CustomStreamWrapper:
|
||||
def __init__(self, completion_stream, model):
|
||||
self.model = model
|
||||
if model in litellm.cohere_models:
|
||||
# cohere does not return an iterator, so we need to wrap it in one
|
||||
self.completion_stream = iter(completion_stream)
|
||||
else:
|
||||
self.completion_stream = completion_stream
|
||||
|
||||
def __iter__(self):
|
||||
return self
|
||||
|
||||
def __next__(self):
|
||||
if self.model in litellm.anthropic_models:
|
||||
chunk = next(self.completion_stream)
|
||||
return {"choices": [{"delta": chunk.completion}]}
|
||||
elif self.model == "replicate":
|
||||
chunk = next(self.completion_stream)
|
||||
return {"choices": [{"delta": chunk}]}
|
||||
elif self.model in litellm.cohere_models:
|
||||
chunk = next(self.completion_stream)
|
||||
return {"choices": [{"delta": chunk.text}]}
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue