forked from phoenix/litellm-mirror
streaming replicate tests
This commit is contained in:
parent
c55adebcfd
commit
d66bda43d3
3 changed files with 78 additions and 44 deletions
|
@ -7,32 +7,9 @@ import litellm
|
|||
from litellm import client, logging, exception_type, timeout, get_optional_params
|
||||
import tiktoken
|
||||
encoding = tiktoken.get_encoding("cl100k_base")
|
||||
from litellm.utils import get_secret, install_and_import
|
||||
from litellm.utils import get_secret, install_and_import, CustomStreamWrapper
|
||||
####### ENVIRONMENT VARIABLES ###################
|
||||
dotenv.load_dotenv() # Loading env variables using dotenv
|
||||
|
||||
# TODO this will evolve to accepting models
|
||||
# replicate/anthropic/cohere
|
||||
class CustomStreamWrapper:
|
||||
def __init__(self, completion_stream, model):
|
||||
self.model = model
|
||||
if model in litellm.cohere_models:
|
||||
# cohere does not return an iterator, so we need to wrap it in one
|
||||
self.completion_stream = iter(completion_stream)
|
||||
else:
|
||||
self.completion_stream = completion_stream
|
||||
|
||||
def __iter__(self):
|
||||
return self
|
||||
|
||||
def __next__(self):
|
||||
if self.model in litellm.anthropic_models:
|
||||
chunk = next(self.completion_stream)
|
||||
return {"choices": [{"delta": chunk.completion}]}
|
||||
elif self.model in litellm.cohere_models:
|
||||
chunk = next(self.completion_stream)
|
||||
return {"choices": [{"delta": chunk.text}]}
|
||||
|
||||
new_response = {
|
||||
"choices": [
|
||||
{
|
||||
|
@ -67,7 +44,7 @@ def completion(
|
|||
presence_penalty=0, frequency_penalty=0, logit_bias={}, user="", deployment_id=None,
|
||||
# Optional liteLLM function params
|
||||
*, return_async=False, api_key=None, force_timeout=60, azure=False, logger_fn=None, verbose=False,
|
||||
hugging_face = False
|
||||
hugging_face = False, replicate=False,
|
||||
):
|
||||
try:
|
||||
global new_response
|
||||
|
@ -77,7 +54,8 @@ def completion(
|
|||
functions=functions, function_call=function_call,
|
||||
temperature=temperature, top_p=top_p, n=n, stream=stream, stop=stop, max_tokens=max_tokens,
|
||||
presence_penalty=presence_penalty, frequency_penalty=frequency_penalty, logit_bias=logit_bias, user=user, deployment_id=deployment_id,
|
||||
model=model
|
||||
# params to identify the model
|
||||
model=model, replicate=replicate, hugging_face=hugging_face
|
||||
)
|
||||
if azure == True:
|
||||
# azure configs
|
||||
|
@ -172,7 +150,7 @@ def completion(
|
|||
model_response["model"] = model
|
||||
model_response["usage"] = response["usage"]
|
||||
response = model_response
|
||||
elif "replicate" in model:
|
||||
elif "replicate" in model or replicate == True:
|
||||
# import replicate/if it fails then pip install replicate
|
||||
install_and_import("replicate")
|
||||
import replicate
|
||||
|
@ -196,6 +174,11 @@ def completion(
|
|||
output = replicate.run(
|
||||
model,
|
||||
input=input)
|
||||
if 'stream' in optional_params and optional_params['stream'] == True:
|
||||
# don't try to access stream object,
|
||||
# let the stream handler know this is replicate
|
||||
response = CustomStreamWrapper(output, "replicate")
|
||||
return response
|
||||
response = ""
|
||||
for item in output:
|
||||
response += item
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue