From a166af50d0af51e0879e3a689ad7981014cc51c8 Mon Sep 17 00:00:00 2001
From: ishaan-jaff
Date: Tue, 8 Aug 2023 17:01:58 -0700
Subject: [PATCH] add cohere streaming

---
 litellm/main.py                  | 27 +++++++++++++++++++++------
 litellm/tests/test_completion.py | 14 ++++++++++++++
 litellm/utils.py                 | 10 ++++++++++
 3 files changed, 45 insertions(+), 6 deletions(-)

diff --git a/litellm/main.py b/litellm/main.py
index 0eac87724..17144a47f 100644
--- a/litellm/main.py
+++ b/litellm/main.py
@@ -14,15 +14,24 @@ dotenv.load_dotenv() # Loading env variables using dotenv
 # TODO this will evolve to accepting models
 # replicate/anthropic/cohere
 class CustomStreamWrapper:
-    def __init__(self, completion_stream):
-        self.completion_stream = completion_stream
+    def __init__(self, completion_stream, model):
+        self.model = model
+        if model in litellm.cohere_models:
+            # cohere does not return an iterator, so we need to wrap it in one
+            self.completion_stream = iter(completion_stream)
+        else:
+            self.completion_stream = completion_stream

     def __iter__(self):
         return self

     def __next__(self):
-        chunk = next(self.completion_stream)
-        return {"choices": [{"delta": chunk.completion}]}
+        if self.model in litellm.anthropic_models:
+            chunk = next(self.completion_stream)
+            return {"choices": [{"delta": chunk.completion}]}
+        elif self.model in litellm.cohere_models:
+            chunk = next(self.completion_stream)
+            return {"choices": [{"delta": chunk.text}]}

 new_response = {
     "choices": [
@@ -241,7 +250,7 @@ def completion(
             )
             if 'stream' in optional_params and optional_params['stream'] == True:
                 # don't try to access stream object,
-                response = CustomStreamWrapper(completion)
+                response = CustomStreamWrapper(completion, model)
                 return response
             completion_response = completion.completion

@@ -277,8 +286,14 @@ def completion(
             ## COMPLETION CALL
             response = co.generate(
                 model=model,
-                prompt = prompt
+                prompt = prompt,
+                **optional_params
             )
+            if 'stream' in optional_params and optional_params['stream'] == True:
+                # don't try to access stream object,
+                response = CustomStreamWrapper(response, model)
+                return response
+
             completion_response = response[0].text
             ## LOGGING
             logging(model=model, input=prompt, azure=azure, additional_args={"max_tokens": max_tokens, "original_response": completion_response}, logger_fn=logger_fn)
diff --git a/litellm/tests/test_completion.py b/litellm/tests/test_completion.py
index 35f7b631d..d5733e2fb 100644
--- a/litellm/tests/test_completion.py
+++ b/litellm/tests/test_completion.py
@@ -57,6 +57,20 @@ def test_completion_cohere():
     except Exception as e:
         pytest.fail(f"Error occurred: {e}")

+
+def test_completion_cohere_stream():
+    try:
+        messages = [
+            {"role": "system", "content": "You are a helpful assistant."},
+            {"role": "user", "content": "how does a court case get to the Supreme Court?"}
+        ]
+        response = completion(model="command-nightly", messages=messages, stream=True, max_tokens=50)
+        # Add any assertions here to check the response
+        for chunk in response:
+            print(chunk['choices'][0]['delta'])  # same as openai format
+    except Exception as e:
+        pytest.fail(f"Error occurred: {e}")
+
 def test_completion_openai():
     try:
         response = completion(model="gpt-3.5-turbo", messages=messages)
diff --git a/litellm/utils.py b/litellm/utils.py
index 599c61e24..04e92737a 100644
--- a/litellm/utils.py
+++ b/litellm/utils.py
@@ -161,6 +161,16 @@ def get_optional_params(
         if top_p != 1:
             optional_params["top_p"] = top_p
         return optional_params
+    elif model in litellm.cohere_models:
+        # handle cohere params
+        if stream:
+            optional_params["stream"] = stream
+        if temperature != 1:
+            optional_params["temperature"] = temperature
+        if max_tokens != float('inf'):
+            optional_params["max_tokens"] = max_tokens
+        return optional_params
+
     else: # assume passing in params for openai/azure openai
         if functions != []:
             optional_params["functions"] = functions