streaming replicate tests

ishaan-jaff 2023-08-08 17:50:36 -07:00
parent c55adebcfd
commit d66bda43d3
3 changed files with 78 additions and 44 deletions

@@ -7,32 +7,9 @@ import litellm
 from litellm import client, logging, exception_type, timeout, get_optional_params
 import tiktoken
 encoding = tiktoken.get_encoding("cl100k_base")
-from litellm.utils import get_secret, install_and_import
+from litellm.utils import get_secret, install_and_import, CustomStreamWrapper
 ####### ENVIRONMENT VARIABLES ###################
 dotenv.load_dotenv() # Loading env variables using dotenv
 # TODO this will evolve to accepting models
 # replicate/anthropic/cohere
-class CustomStreamWrapper:
-    def __init__(self, completion_stream, model):
-        self.model = model
-        if model in litellm.cohere_models:
-            # cohere does not return an iterator, so we need to wrap it in one
-            self.completion_stream = iter(completion_stream)
-        else:
-            self.completion_stream = completion_stream
-
-    def __iter__(self):
-        return self
-
-    def __next__(self):
-        if self.model in litellm.anthropic_models:
-            chunk = next(self.completion_stream)
-            return {"choices": [{"delta": chunk.completion}]}
-        elif self.model in litellm.cohere_models:
-            chunk = next(self.completion_stream)
-            return {"choices": [{"delta": chunk.text}]}
 new_response = {
     "choices": [
         {
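This hunk deletes CustomStreamWrapper from main.py while the import change above pulls it from litellm.utils instead, so the class is relocated rather than dropped. A minimal sketch of what the relocated wrapper plausibly looks like after this commit, assuming it mirrors the removed code and gains a replicate branch to match the CustomStreamWrapper(output, "replicate") call added further down (the replicate stream yields plain strings, so no attribute access is needed):

import litellm

# Hypothetical reconstruction of the wrapper in litellm/utils.py; the
# anthropic/cohere branches mirror the code removed above, the replicate
# branch is an assumption based on how this commit uses the wrapper.
class CustomStreamWrapper:
    def __init__(self, completion_stream, model):
        self.model = model
        if model in litellm.cohere_models:
            # cohere does not return an iterator, so we need to wrap it in one
            self.completion_stream = iter(completion_stream)
        else:
            self.completion_stream = completion_stream

    def __iter__(self):
        return self

    def __next__(self):
        chunk = next(self.completion_stream)
        if self.model in litellm.anthropic_models:
            return {"choices": [{"delta": chunk.completion}]}
        elif self.model in litellm.cohere_models:
            return {"choices": [{"delta": chunk.text}]}
        elif self.model == "replicate":
            # replicate.run() already yields plain text fragments
            return {"choices": [{"delta": chunk}]}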
@@ -67,7 +44,7 @@ def completion(
     presence_penalty=0, frequency_penalty=0, logit_bias={}, user="", deployment_id=None,
     # Optional liteLLM function params
     *, return_async=False, api_key=None, force_timeout=60, azure=False, logger_fn=None, verbose=False,
-    hugging_face = False
+    hugging_face = False, replicate=False,
 ):
     try:
         global new_response
@@ -77,7 +54,8 @@
             functions=functions, function_call=function_call,
             temperature=temperature, top_p=top_p, n=n, stream=stream, stop=stop, max_tokens=max_tokens,
             presence_penalty=presence_penalty, frequency_penalty=frequency_penalty, logit_bias=logit_bias, user=user, deployment_id=deployment_id,
-            model=model
+            # params to identify the model
+            model=model, replicate=replicate, hugging_face=hugging_face
         )
         if azure == True:
             # azure configs
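Passing replicate and hugging_face into get_optional_params lets the param builder translate OpenAI-style arguments into provider-specific ones. A plausible sketch of that branching, not the actual litellm implementation; the max_new_tokens mapping is an assumption:

# Hypothetical sketch only -- not litellm's real get_optional_params.
def get_optional_params(*, stream=False, max_tokens=float("inf"),
                        model=None, replicate=False, hugging_face=False, **kwargs):
    optional_params = {}
    if stream:
        optional_params["stream"] = stream
    if max_tokens != float("inf"):
        if replicate or (model and "replicate" in model):
            # assumption: replicate-hosted models commonly take max_new_tokens
            optional_params["max_new_tokens"] = max_tokens
        else:
            optional_params["max_tokens"] = max_tokens
    return optional_params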
@@ -172,7 +150,7 @@
             model_response["model"] = model
             model_response["usage"] = response["usage"]
             response = model_response
-        elif "replicate" in model:
+        elif "replicate" in model or replicate == True:
             # import replicate/if it fails then pip install replicate
             install_and_import("replicate")
             import replicate
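With the new keyword, routing no longer depends on the string "replicate" appearing in the model name. A hedged usage sketch; the model slug is illustrative, not taken from the commit:

from litellm import completion

response = completion(
    model="a16z-infra/llama13b-v2-chat",  # hypothetical replicate-hosted slug
    messages=[{"role": "user", "content": "Hello"}],
    replicate=True,  # force the replicate branch even without "replicate" in the name
)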
@@ -196,6 +174,11 @@
             output = replicate.run(
                 model,
                 input=input)
+            if 'stream' in optional_params and optional_params['stream'] == True:
+                # don't try to access stream object,
+                # let the stream handler know this is replicate
+                response = CustomStreamWrapper(output, "replicate")
+                return response
             response = ""
             for item in output:
                 response += item
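With stream=True the replicate branch now hands back the wrapper immediately instead of draining the generator into a string. A short consumption sketch, assuming the chunk shape the wrapper produces:

response = completion(
    model="replicate/llama-2-70b-chat",  # hypothetical slug containing "replicate"
    messages=[{"role": "user", "content": "Hello"}],
    stream=True,
)
for chunk in response:
    # each chunk looks like {"choices": [{"delta": "<text fragment>"}]}
    print(chunk["choices"][0]["delta"], end="")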