add cohere streaming

ishaan-jaff 2023-08-08 17:01:58 -07:00
parent 5b96906442
commit a166af50d0
3 changed files with 45 additions and 6 deletions

@@ -14,15 +14,24 @@ dotenv.load_dotenv() # Loading env variables using dotenv
 # TODO this will evolve to accepting models
 # replicate/anthropic/cohere
 class CustomStreamWrapper:
-    def __init__(self, completion_stream):
-        self.completion_stream = completion_stream
+    def __init__(self, completion_stream, model):
+        self.model = model
+        if model in litellm.cohere_models:
+            # cohere does not return an iterator, so we need to wrap it in one
+            self.completion_stream = iter(completion_stream)
+        else:
+            self.completion_stream = completion_stream
 
     def __iter__(self):
         return self
 
     def __next__(self):
-        chunk = next(self.completion_stream)
-        return {"choices": [{"delta": chunk.completion}]}
+        if self.model in litellm.anthropic_models:
+            chunk = next(self.completion_stream)
+            return {"choices": [{"delta": chunk.completion}]}
+        elif self.model in litellm.cohere_models:
+            chunk = next(self.completion_stream)
+            return {"choices": [{"delta": chunk.text}]}
 
 new_response = {
     "choices": [
@@ -241,7 +250,7 @@ def completion(
         )
         if 'stream' in optional_params and optional_params['stream'] == True:
             # don't try to access stream object,
-            response = CustomStreamWrapper(completion)
+            response = CustomStreamWrapper(completion, model)
             return response
         completion_response = completion.completion
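With the model threaded through, the wrapper can pick the right chunk attribute per provider (.completion for anthropic, .text for cohere) and normalize both into the same delta dict. A short sketch, assuming CustomStreamWrapper is importable from litellm.main and that "claude-instant-1" appears in litellm.anthropic_models (both are assumptions about this commit's layout):

from types import SimpleNamespace
from litellm.main import CustomStreamWrapper  # assumed import path

# Fake anthropic-style chunks: the real ones expose a .completion attribute.
fake_stream = iter([
    SimpleNamespace(completion="Hi"),
    SimpleNamespace(completion=" there"),
])
for chunk in CustomStreamWrapper(fake_stream, "claude-instant-1"):
    print(chunk)  # {'choices': [{'delta': 'Hi'}]}, then {'choices': [{'delta': ' there'}]}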
@@ -277,8 +286,14 @@ def completion(
         ## COMPLETION CALL
         response = co.generate(
             model=model,
-            prompt = prompt
+            prompt = prompt,
+            **optional_params
         )
+        if 'stream' in optional_params and optional_params['stream'] == True:
+            # don't try to access stream object,
+            response = CustomStreamWrapper(response, model)
+            return response
         completion_response = response[0].text
         ## LOGGING
         logging(model=model, input=prompt, azure=azure, additional_args={"max_tokens": max_tokens, "original_response": completion_response}, logger_fn=logger_fn)
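With this change, the cohere path mirrors the anthropic one: pass stream=True and completion() returns the wrapper instead of a finished response. A hedged usage sketch (the model name is illustrative and assumed to resolve into litellm.cohere_models at this commit; COHERE_API_KEY must be set in the environment):

from litellm import completion

response = completion(
    model="command-nightly",  # assumed cohere model alias
    messages=[{"role": "user", "content": "Hello, how's it going?"}],
    stream=True,
)
for chunk in response:  # CustomStreamWrapper yields OpenAI-style delta dicts
    print(chunk["choices"][0]["delta"], end="")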