add cohere streaming

ishaan-jaff 2023-08-08 17:01:58 -07:00
parent 5b96906442
commit a166af50d0
3 changed files with 45 additions and 6 deletions

File 1 of 3

@@ -14,15 +14,24 @@ dotenv.load_dotenv() # Loading env variables using dotenv
 # TODO this will evolve to accepting models
 # replicate/anthropic/cohere
 class CustomStreamWrapper:
-    def __init__(self, completion_stream):
-        self.completion_stream = completion_stream
+    def __init__(self, completion_stream, model):
+        self.model = model
+        if model in litellm.cohere_models:
+            # cohere does not return an iterator, so we need to wrap it in one
+            self.completion_stream = iter(completion_stream)
+        else:
+            self.completion_stream = completion_stream
 
     def __iter__(self):
         return self
 
     def __next__(self):
-        chunk = next(self.completion_stream)
-        return {"choices": [{"delta": chunk.completion}]}
+        if self.model in litellm.anthropic_models:
+            chunk = next(self.completion_stream)
+            return {"choices": [{"delta": chunk.completion}]}
+        elif self.model in litellm.cohere_models:
+            chunk = next(self.completion_stream)
+            return {"choices": [{"delta": chunk.text}]}
 
 new_response = {
     "choices": [
@@ -241,7 +250,7 @@ def completion(
         )
         if 'stream' in optional_params and optional_params['stream'] == True:
             # don't try to access stream object,
-            response = CustomStreamWrapper(completion)
+            response = CustomStreamWrapper(completion, model)
             return response
 
         completion_response = completion.completion
@@ -277,8 +286,14 @@ def completion(
         ## COMPLETION CALL
         response = co.generate(
             model=model,
-            prompt = prompt
+            prompt = prompt,
+            **optional_params
         )
+        if 'stream' in optional_params and optional_params['stream'] == True:
+            # don't try to access stream object,
+            response = CustomStreamWrapper(response, model)
+            return response
+
         completion_response = response[0].text
         ## LOGGING
         logging(model=model, input=prompt, azure=azure, additional_args={"max_tokens": max_tokens, "original_response": completion_response}, logger_fn=logger_fn)
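
For context, the underlying SDK call being wrapped looks roughly like this when streaming is requested (a hedged sketch; exact return types vary by cohere SDK version, and the API key is a placeholder). A for-loop over the response works because it calls iter() itself; CustomStreamWrapper calls next() by hand, which is why it needs the explicit iter() in its constructor:

import cohere

co = cohere.Client("YOUR_COHERE_API_KEY")  # placeholder key
response = co.generate(
    model="command-nightly",
    prompt="how does a court case get to the Supreme Court?",
    max_tokens=50,
    stream=True,
)
# each streamed token object exposes the generated text on .text
for token in response:
    print(token.text, end="")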

File 2 of 3

@@ -57,6 +57,20 @@ def test_completion_cohere():
     except Exception as e:
         pytest.fail(f"Error occurred: {e}")
+
+def test_completion_cohere_stream():
+    try:
+        messages = [
+            {"role": "system", "content": "You are a helpful assistant."},
+            {"role": "user", "content": "how does a court case get to the Supreme Court?"}
+        ]
+        response = completion(model="command-nightly", messages=messages, stream=True, max_tokens=50)
+        # Add any assertions here to check the response
+        for chunk in response:
+            print(chunk['choices'][0]['delta']) # same as openai format
+    except Exception as e:
+        pytest.fail(f"Error occurred: {e}")
+
 
 def test_completion_openai():
     try:
         response = completion(model="gpt-3.5-turbo", messages=messages)

File 3 of 3

@@ -161,6 +161,16 @@ def get_optional_params(
         if top_p != 1:
             optional_params["top_p"] = top_p
         return optional_params
+    elif model in litellm.cohere_models:
+        # handle cohere params
+        if stream:
+            optional_params["stream"] = stream
+        if temperature != 1:
+            optional_params["temperature"] = temperature
+        if max_tokens != float('inf'):
+            optional_params["max_tokens"] = max_tokens
+        return optional_params
+
     else: # assume passing in params for openai/azure openai
         if functions != []:
             optional_params["functions"] = functions