streaming replicate tests

ishaan-jaff 2023-08-08 17:50:36 -07:00
parent d87ae07574
commit f4048886ab
3 changed files with 78 additions and 44 deletions

Changed file 1 of 3

@@ -7,32 +7,9 @@ import litellm
 from litellm import client, logging, exception_type, timeout, get_optional_params
 import tiktoken
 encoding = tiktoken.get_encoding("cl100k_base")
-from litellm.utils import get_secret, install_and_import
+from litellm.utils import get_secret, install_and_import, CustomStreamWrapper
 ####### ENVIRONMENT VARIABLES ###################
 dotenv.load_dotenv() # Loading env variables using dotenv
-# TODO this will evolve to accepting models
-# replicate/anthropic/cohere
-class CustomStreamWrapper:
-    def __init__(self, completion_stream, model):
-        self.model = model
-        if model in litellm.cohere_models:
-            # cohere does not return an iterator, so we need to wrap it in one
-            self.completion_stream = iter(completion_stream)
-        else:
-            self.completion_stream = completion_stream
-
-    def __iter__(self):
-        return self
-
-    def __next__(self):
-        if self.model in litellm.anthropic_models:
-            chunk = next(self.completion_stream)
-            return {"choices": [{"delta": chunk.completion}]}
-        elif self.model in litellm.cohere_models:
-            chunk = next(self.completion_stream)
-            return {"choices": [{"delta": chunk.text}]}
 new_response = {
     "choices": [
         {
@@ -67,7 +44,7 @@ def completion(
     presence_penalty=0, frequency_penalty=0, logit_bias={}, user="", deployment_id=None,
     # Optional liteLLM function params
     *, return_async=False, api_key=None, force_timeout=60, azure=False, logger_fn=None, verbose=False,
-    hugging_face = False
+    hugging_face = False, replicate=False,
 ):
     try:
         global new_response
@@ -77,7 +54,8 @@ def completion(
         functions=functions, function_call=function_call,
         temperature=temperature, top_p=top_p, n=n, stream=stream, stop=stop, max_tokens=max_tokens,
         presence_penalty=presence_penalty, frequency_penalty=frequency_penalty, logit_bias=logit_bias, user=user, deployment_id=deployment_id,
-        model=model
+        # params to identify the model
+        model=model, replicate=replicate, hugging_face=hugging_face
     )
     if azure == True:
         # azure configs
@@ -172,7 +150,7 @@ def completion(
         model_response["model"] = model
         model_response["usage"] = response["usage"]
         response = model_response
-    elif "replicate" in model:
+    elif "replicate" in model or replicate == True:
         # import replicate/if it fails then pip install replicate
         install_and_import("replicate")
         import replicate
@@ -196,6 +174,11 @@ def completion(
         output = replicate.run(
             model,
             input=input)
+        if 'stream' in optional_params and optional_params['stream'] == True:
+            # don't try to access stream object,
+            # let the stream handler know this is replicate
+            response = CustomStreamWrapper(output, "replicate")
+            return response
         response = ""
         for item in output:
             response += item

Changed file 2 of 3

@@ -139,20 +139,36 @@ def test_completion_azure():
     except Exception as e:
         pytest.fail(f"Error occurred: {e}")

 # Replicate API endpoints are unstable -> throw random CUDA errors -> this means our tests can fail even if our tests weren't incorrect.
-# [TODO] improve our try-except block to handle for these
-# def test_completion_replicate_llama():
-#     model_name = "replicate/llama-2-70b-chat:2c1608e18606fad2812020dc541930f2d0495ce32eee50074220b87300bc16e1"
-#     try:
-#         response = completion(model=model_name, messages=messages, max_tokens=500)
-#         # Add any assertions here to check the response
-#         print(response)
-#     except Exception as e:
-#         print(f"in replicate llama, got error {e}")
-#         pass
-#         if e == "FunctionTimedOut":
-#             pass
-#         else:
-#             pytest.fail(f"Error occurred: {e}")
+def test_completion_replicate_llama_stream():
+    model_name = "replicate/llama-2-70b-chat:2c1608e18606fad2812020dc541930f2d0495ce32eee50074220b87300bc16e1"
+    try:
+        response = completion(model=model_name, messages=messages, stream=True)
+        # Add any assertions here to check the response
+        for result in response:
+            print(result)
+        print(response)
+    except Exception as e:
+        pytest.fail(f"Error occurred: {e}")
+
+def test_completion_replicate_stability_stream():
+    model_name = "stability-ai/stablelm-tuned-alpha-7b:c49dae362cbaecd2ceabb5bd34fdb68413c4ff775111fea065d259d577757beb"
+    try:
+        response = completion(model=model_name, messages=messages, stream=True, replicate=True)
+        # Add any assertions here to check the response
+        for result in response:
+            print(result)
+        print(response)
+    except Exception as e:
+        pytest.fail(f"Error occurred: {e}")
+
+def test_completion_replicate_stability():
+    model_name = "stability-ai/stablelm-tuned-alpha-7b:c49dae362cbaecd2ceabb5bd34fdb68413c4ff775111fea065d259d577757beb"
+    try:
+        response = completion(model=model_name, messages=messages, replicate=True)
+        # Add any assertions here to check the response
+        for result in response:
+            print(result)
+        print(response)
+    except Exception as e:
+        pytest.fail(f"Error occurred: {e}")
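
The new tests only print each chunk; if an assertion were wanted, a sketch (assuming the chunk format emitted by CustomStreamWrapper in the third changed file) could look like:

# Hypothetical assertion sketch for the streaming tests above.
for result in response:
    assert "choices" in result and "delta" in result["choices"][0]
    print(result["choices"][0]["delta"])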

Changed file 3 of 3

@@ -148,6 +148,8 @@ def get_optional_params(
     user = "",
     deployment_id = None,
     model = None,
+    replicate = False,
+    hugging_face = False,
 ):
     optional_params = {}
     if model in litellm.anthropic_models:
@@ -170,7 +172,12 @@ def get_optional_params(
         if max_tokens != float('inf'):
             optional_params["max_tokens"] = max_tokens
         return optional_params
+    elif replicate == True:
+        # any replicate models
+        # TODO: handle translating remaining replicate params
+        if stream:
+            optional_params["stream"] = stream
+        return optional_params
     else:# assume passing in params for openai/azure openai
         if functions != []:
             optional_params["functions"] = functions
@@ -199,6 +206,7 @@ def get_optional_params(
         if deployment_id != None:
             optional_params["deployment_id"] = deployment_id
         return optional_params
+    return optional_params

 def set_callbacks(callback_list):
     global sentry_sdk_instance, capture_exception, add_breadcrumb, posthog, slack_app, alerts_channel, heliconeLogger, aispendLogger, berrispendLogger, supabaseClient
@@ -557,3 +565,30 @@ def get_secret(secret_name):
         return os.environ.get(secret_name)
     else:
         return os.environ.get(secret_name)
+
+######## Streaming Class ############################
+# wraps the completion stream to return the correct format for the model
+# replicate/anthropic/cohere
+class CustomStreamWrapper:
+    def __init__(self, completion_stream, model):
+        self.model = model
+        if model in litellm.cohere_models:
+            # cohere does not return an iterator, so we need to wrap it in one
+            self.completion_stream = iter(completion_stream)
+        else:
+            self.completion_stream = completion_stream
+
+    def __iter__(self):
+        return self
+
+    def __next__(self):
+        if self.model in litellm.anthropic_models:
+            chunk = next(self.completion_stream)
+            return {"choices": [{"delta": chunk.completion}]}
+        elif self.model == "replicate":
+            chunk = next(self.completion_stream)
+            return {"choices": [{"delta": chunk}]}
+        elif self.model in litellm.cohere_models:
+            chunk = next(self.completion_stream)
+            return {"choices": [{"delta": chunk.text}]}