Mirror of https://github.com/BerriAI/litellm.git
add cohere streaming
commit a166af50d0, parent 5b96906442
3 changed files with 45 additions and 6 deletions
@@ -14,15 +14,24 @@ dotenv.load_dotenv() # Loading env variables using dotenv
 # TODO this will evolve to accepting models
 # replicate/anthropic/cohere
 class CustomStreamWrapper:
-    def __init__(self, completion_stream):
-        self.completion_stream = completion_stream
+    def __init__(self, completion_stream, model):
+        self.model = model
+        if model in litellm.cohere_models:
+            # cohere does not return an iterator, so we need to wrap it in one
+            self.completion_stream = iter(completion_stream)
+        else:
+            self.completion_stream = completion_stream
 
     def __iter__(self):
         return self
 
     def __next__(self):
-        chunk = next(self.completion_stream)
-        return {"choices": [{"delta": chunk.completion}]}
+        if self.model in litellm.anthropic_models:
+            chunk = next(self.completion_stream)
+            return {"choices": [{"delta": chunk.completion}]}
+        elif self.model in litellm.cohere_models:
+            chunk = next(self.completion_stream)
+            return {"choices": [{"delta": chunk.text}]}
 
 new_response = {
     "choices": [
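A note on the iter() call introduced above: per the diff's own comment, cohere's streamed response is iterable but not itself an iterator, so next() cannot be called on it directly, which is why the wrapper converts it first. Below is a minimal sketch of that distinction; FakeCohereChunk and fake_cohere_stream are illustrative stand-ins, not part of the cohere SDK or this commit (only the .text attribute mirrors the real chunk shape).

# Sketch only; the names here are stand-ins for the objects the cohere SDK returns.
from collections import namedtuple

FakeCohereChunk = namedtuple("FakeCohereChunk", ["text"])
fake_cohere_stream = [FakeCohereChunk("Hello"), FakeCohereChunk(" world")]

# A list is iterable but not an iterator: next(fake_cohere_stream) would raise TypeError.
stream_iter = iter(fake_cohere_stream)
print(next(stream_iter).text)  # Hello
print(next(stream_iter).text)  #  world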
@@ -241,7 +250,7 @@ def completion(
         )
         if 'stream' in optional_params and optional_params['stream'] == True:
             # don't try to access stream object,
-            response = CustomStreamWrapper(completion)
+            response = CustomStreamWrapper(completion, model)
             return response
 
         completion_response = completion.completion
@@ -277,8 +286,14 @@ def completion(
         ## COMPLETION CALL
         response = co.generate(
             model=model,
-            prompt = prompt
+            prompt = prompt,
+            **optional_params
         )
+        if 'stream' in optional_params and optional_params['stream'] == True:
+            # don't try to access stream object,
+            response = CustomStreamWrapper(response, model)
+            return response
+
         completion_response = response[0].text
         ## LOGGING
         logging(model=model, input=prompt, azure=azure, additional_args={"max_tokens": max_tokens, "original_response": completion_response}, logger_fn=logger_fn)
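Taken together with the wrapper changes above, a cohere call with stream=True now returns a CustomStreamWrapper that yields OpenAI-style delta chunks instead of raw cohere output. A minimal usage sketch follows, assuming cohere is installed and a valid COHERE_API_KEY is configured for litellm; the prompt and token limit are illustrative.

# Usage sketch; mirrors the new test added further down in this commit.
from litellm import completion

messages = [{"role": "user", "content": "Say hi in one short sentence."}]
# stream=True flows through get_optional_params() into co.generate(**optional_params),
# and the streamed result is wrapped so chunks come back in the OpenAI delta shape.
response = completion(model="command-nightly", messages=messages, stream=True, max_tokens=20)
for chunk in response:
    print(chunk['choices'][0]['delta'])  # plain text pieces, same format as openai streaming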
@@ -57,6 +57,20 @@ def test_completion_cohere():
     except Exception as e:
         pytest.fail(f"Error occurred: {e}")
 
+
+def test_completion_cohere_stream():
+    try:
+        messages = [
+            {"role": "system", "content": "You are a helpful assistant."},
+            {"role": "user", "content": "how does a court case get to the Supreme Court?"}
+        ]
+        response = completion(model="command-nightly", messages=messages, stream=True, max_tokens=50)
+        # Add any assertions here to check the response
+        for chunk in response:
+            print(chunk['choices'][0]['delta'])  # same as openai format
+    except Exception as e:
+        pytest.fail(f"Error occurred: {e}")
+
 def test_completion_openai():
     try:
         response = completion(model="gpt-3.5-turbo", messages=messages)
@@ -161,6 +161,16 @@ def get_optional_params(
         if top_p != 1:
             optional_params["top_p"] = top_p
         return optional_params
+    elif model in litellm.cohere_models:
+        # handle cohere params
+        if stream:
+            optional_params["stream"] = stream
+        if temperature != 1:
+            optional_params["temperature"] = temperature
+        if max_tokens != float('inf'):
+            optional_params["max_tokens"] = max_tokens
+        return optional_params
+
     else:# assume passing in params for openai/azure openai
         if functions != []:
             optional_params["functions"] = functions
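For reference, here is a short sketch of the dict this new cohere branch builds for the streaming test call above (stream=True, max_tokens=50, everything else left at its default). The defaults shown (temperature=1, max_tokens=float('inf')) are assumptions inferred from the != comparisons in the diff, not copied from litellm.

# Standalone sketch of the cohere branch in get_optional_params(); defaults are assumed.
def cohere_optional_params(stream=False, temperature=1, max_tokens=float('inf')):
    optional_params = {}
    if stream:
        optional_params["stream"] = stream
    if temperature != 1:
        optional_params["temperature"] = temperature
    if max_tokens != float('inf'):
        optional_params["max_tokens"] = max_tokens
    return optional_params

print(cohere_optional_params(stream=True, max_tokens=50))  # {'stream': True, 'max_tokens': 50}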