Mirror of https://github.com/BerriAI/litellm.git (synced 2025-04-25 10:44:24 +00:00)
Commit a166af50d0: add cohere streaming
Parent: 5b96906442
3 changed files with 45 additions and 6 deletions
@@ -14,15 +14,24 @@ dotenv.load_dotenv() # Loading env variables using dotenv
 # TODO this will evolve to accepting models
 # replicate/anthropic/cohere
 class CustomStreamWrapper:
-    def __init__(self, completion_stream):
-        self.completion_stream = completion_stream
+    def __init__(self, completion_stream, model):
+        self.model = model
+        if model in litellm.cohere_models:
+            # cohere does not return an iterator, so we need to wrap it in one
+            self.completion_stream = iter(completion_stream)
+        else:
+            self.completion_stream = completion_stream

     def __iter__(self):
         return self

     def __next__(self):
-        chunk = next(self.completion_stream)
-        return {"choices": [{"delta": chunk.completion}]}
+        if self.model in litellm.anthropic_models:
+            chunk = next(self.completion_stream)
+            return {"choices": [{"delta": chunk.completion}]}
+        elif self.model in litellm.cohere_models:
+            chunk = next(self.completion_stream)
+            return {"choices": [{"delta": chunk.text}]}

 new_response = {
     "choices": [
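The net effect of this hunk is that CustomStreamWrapper now normalizes both Anthropic and Cohere streams into OpenAI-style chunks, where choices[0]["delta"] carries the newly generated text fragment. A minimal usage sketch, assuming stream=True is forwarded into optional_params and that the model name ("command-nightly" is only an example) is one of litellm.cohere_models:

import litellm

# Hedged sketch: assumes completion() routes stream=True through the
# CustomStreamWrapper shown above, and that each chunk has the shape
# {"choices": [{"delta": <text fragment>}]}.
response = litellm.completion(
    model="command-nightly",  # example name; any model in litellm.cohere_models
    messages=[{"role": "user", "content": "Hey, how's it going?"}],
    stream=True,
)
for chunk in response:
    print(chunk["choices"][0]["delta"], end="", flush=True)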
@@ -241,7 +250,7 @@ def completion(
         )
         if 'stream' in optional_params and optional_params['stream'] == True:
             # don't try to access stream object,
-            response = CustomStreamWrapper(completion)
+            response = CustomStreamWrapper(completion, model)
             return response

         completion_response = completion.completion
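Because completion() returns the wrapper immediately once streaming is requested, assembling the full text is left to the caller. A small helper one might write on top of the chunk format above (collect_stream is hypothetical, not part of litellm):

def collect_stream(response):
    """Concatenate the delta fragments yielded by CustomStreamWrapper."""
    parts = []
    for chunk in response:
        parts.append(chunk["choices"][0]["delta"])
    return "".join(parts)

# e.g. full_text = collect_stream(litellm.completion(..., stream=True))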
@@ -277,8 +286,14 @@ def completion(
         ## COMPLETION CALL
         response = co.generate(
             model=model,
-            prompt = prompt
+            prompt = prompt,
+            **optional_params
         )
+        if 'stream' in optional_params and optional_params['stream'] == True:
+            # don't try to access stream object,
+            response = CustomStreamWrapper(response, model)
+            return response
+
         completion_response = response[0].text
         ## LOGGING
         logging(model=model, input=prompt, azure=azure, additional_args={"max_tokens": max_tokens, "original_response": completion_response}, logger_fn=logger_fn)
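As the comment in the diff notes, the Cohere response is iterable but is not itself an iterator, so next() cannot be called on it directly; CustomStreamWrapper.__init__ therefore wraps it with iter(). A self-contained illustration of that pattern using a dummy stand-in (no Cohere API call is made; FakeCohereStream and its .text chunks are invented for this sketch):

class FakeCohereStream:
    """Stand-in for a cohere streaming response: iterable, but not an iterator."""
    def __init__(self, texts):
        self._texts = texts

    def __iter__(self):
        # Each item mimics a chunk exposing a .text attribute, matching
        # the chunk.text access in __next__ above.
        return (type("Chunk", (), {"text": t})() for t in self._texts)


stream = FakeCohereStream(["Hel", "lo ", "world"])
# next(stream) would raise TypeError because the object defines no __next__.
wrapped = iter(stream)     # what CustomStreamWrapper does for cohere models
print(next(wrapped).text)  # -> "Hel"
print(next(wrapped).text)  # -> "lo "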