diff --git a/litellm/litellm_core_utils/streaming_handler.py b/litellm/litellm_core_utils/streaming_handler.py
index e133c3e237..31550ae353 100644
--- a/litellm/litellm_core_utils/streaming_handler.py
+++ b/litellm/litellm_core_utils/streaming_handler.py
@@ -411,32 +411,6 @@ class CustomStreamWrapper:
         except Exception:
             raise ValueError(f"Unable to parse response. Original response: {chunk}")
 
-    def handle_cohere_chunk(self, chunk):
-        chunk = chunk.decode("utf-8")
-        data_json = json.loads(chunk)
-        try:
-            text = ""
-            is_finished = False
-            finish_reason = ""
-            index: Optional[int] = None
-            if "index" in data_json:
-                index = data_json.get("index")
-            if "text" in data_json:
-                text = data_json["text"]
-            elif "is_finished" in data_json:
-                is_finished = data_json["is_finished"]
-                finish_reason = data_json["finish_reason"]
-            else:
-                raise Exception(data_json)
-            return {
-                "index": index,
-                "text": text,
-                "is_finished": is_finished,
-                "finish_reason": finish_reason,
-            }
-        except Exception:
-            raise ValueError(f"Unable to parse response. Original response: {chunk}")
-
     def handle_azure_chunk(self, chunk):
         is_finished = False
         finish_reason = ""
@@ -1157,11 +1131,6 @@ class CustomStreamWrapper:
                             )
                 else:
                     completion_obj["content"] = str(chunk)
-            elif self.custom_llm_provider == "cohere":
-                response_obj = self.handle_cohere_chunk(chunk)
-                completion_obj["content"] = response_obj["text"]
-                if response_obj["is_finished"]:
-                    self.received_finish_reason = response_obj["finish_reason"]
             elif self.custom_llm_provider == "petals":
                 if len(self.completion_stream) == 0:
                     if self.received_finish_reason is not None:
@@ -1669,6 +1638,7 @@ class CustomStreamWrapper:
                 or self.custom_llm_provider == "text-completion-codestral"
                 or self.custom_llm_provider == "azure_text"
                 or self.custom_llm_provider == "cohere_chat"
+                or self.custom_llm_provider == "cohere"
                 or self.custom_llm_provider == "anthropic"
                 or self.custom_llm_provider == "anthropic_text"
                 or self.custom_llm_provider == "huggingface"
diff --git a/litellm/llms/cohere/chat/transformation.py b/litellm/llms/cohere/chat/transformation.py
index e5af5a0235..b28f37e6f0 100644
--- a/litellm/llms/cohere/chat/transformation.py
+++ b/litellm/llms/cohere/chat/transformation.py
@@ -8,13 +8,10 @@ import litellm
 from litellm.llms.base_llm.transformation import BaseConfig, BaseLLMException
 from litellm.llms.prompt_templates.factory import cohere_messages_pt_v2
 from litellm.types.llms.openai import AllMessageValues
-from litellm.types.utils import (
-    ChatCompletionToolCallChunk,
-    ChatCompletionUsageBlock,
-    GenericStreamingChunk,
-    ModelResponse,
-    Usage,
-)
+from litellm.types.utils import ModelResponse, Usage
+
+from ..common_utils import ModelResponseIterator as CohereModelResponseIterator
+from ..common_utils import validate_environment as cohere_validate_environment
 
 if TYPE_CHECKING:
     from litellm.litellm_core_utils.litellm_logging import Logging as _LiteLLMLoggingObj
@@ -120,28 +117,13 @@ class CohereChatConfig(BaseConfig):
         optional_params: dict,
         api_key: Optional[str] = None,
     ) -> dict:
-        """
-        Return headers to use for cohere chat completion request
-
-        Cohere API Ref: https://docs.cohere.com/reference/chat
-        Expected headers:
-        {
-            "Request-Source": "unspecified:litellm",
-            "accept": "application/json",
-            "content-type": "application/json",
-            "Authorization": "bearer $CO_API_KEY"
-        }
-        """
-        headers.update(
-            {
-                "Request-Source": "unspecified:litellm",
-                "accept": "application/json",
-                "content-type": "application/json",
-            }
+        return cohere_validate_environment(
+            headers=headers,
+            model=model,
+            messages=messages,
+            optional_params=optional_params,
+            api_key=api_key,
         )
-        if api_key:
-            headers["Authorization"] = f"bearer {api_key}"
-        return headers
 
     def get_supported_openai_params(self, model: str) -> List[str]:
         return [
@@ -372,7 +354,7 @@ class CohereChatConfig(BaseConfig):
         sync_stream: bool,
         json_mode: Optional[bool] = False,
     ):
-        return ModelResponseIterator(
+        return CohereModelResponseIterator(
             streaming_response=streaming_response,
             sync_stream=sync_stream,
             json_mode=json_mode,
@@ -387,103 +369,3 @@ class CohereChatConfig(BaseConfig):
         self, messages: List[AllMessageValues]
     ) -> List[AllMessageValues]:
         raise NotImplementedError
-
-
-class ModelResponseIterator:
-    def __init__(
-        self, streaming_response, sync_stream: bool, json_mode: Optional[bool] = False
-    ):
-        self.streaming_response = streaming_response
-        self.response_iterator = self.streaming_response
-        self.content_blocks: List = []
-        self.tool_index = -1
-        self.json_mode = json_mode
-
-    def chunk_parser(self, chunk: dict) -> GenericStreamingChunk:
-        try:
-            text = ""
-            tool_use: Optional[ChatCompletionToolCallChunk] = None
-            is_finished = False
-            finish_reason = ""
-            usage: Optional[ChatCompletionUsageBlock] = None
-            provider_specific_fields = None
-
-            index = int(chunk.get("index", 0))
-
-            if "text" in chunk:
-                text = chunk["text"]
-            elif "is_finished" in chunk and chunk["is_finished"] is True:
-                is_finished = chunk["is_finished"]
-                finish_reason = chunk["finish_reason"]
-
-            if "citations" in chunk:
-                provider_specific_fields = {"citations": chunk["citations"]}
-
-            returned_chunk = GenericStreamingChunk(
-                text=text,
-                tool_use=tool_use,
-                is_finished=is_finished,
-                finish_reason=finish_reason,
-                usage=usage,
-                index=index,
-                provider_specific_fields=provider_specific_fields,
-            )
-
-            return returned_chunk
-
-        except json.JSONDecodeError:
-            raise ValueError(f"Failed to decode JSON from chunk: {chunk}")
-
-    # Sync iterator
-    def __iter__(self):
-        return self
-
-    def __next__(self):
-        try:
-            chunk = self.response_iterator.__next__()
-        except StopIteration:
-            raise StopIteration
-        except ValueError as e:
-            raise RuntimeError(f"Error receiving chunk from stream: {e}")
-
-        try:
-            str_line = chunk
-            if isinstance(chunk, bytes):  # Handle binary data
-                str_line = chunk.decode("utf-8")  # Convert bytes to string
-            index = str_line.find("data:")
-            if index != -1:
-                str_line = str_line[index:]
-            data_json = json.loads(str_line)
-            return self.chunk_parser(chunk=data_json)
-        except StopIteration:
-            raise StopIteration
-        except ValueError as e:
-            raise RuntimeError(f"Error parsing chunk: {e},\nReceived chunk: {chunk}")
-
-    # Async iterator
-    def __aiter__(self):
-        self.async_response_iterator = self.streaming_response.__aiter__()
-        return self
-
-    async def __anext__(self):
-        try:
-            chunk = await self.async_response_iterator.__anext__()
-        except StopAsyncIteration:
-            raise StopAsyncIteration
-        except ValueError as e:
-            raise RuntimeError(f"Error receiving chunk from stream: {e}")
-
-        try:
-            str_line = chunk
-            if isinstance(chunk, bytes):  # Handle binary data
-                str_line = chunk.decode("utf-8")  # Convert bytes to string
-            index = str_line.find("data:")
-            if index != -1:
-                str_line = str_line[index:]
-
-            data_json = json.loads(str_line)
-            return self.chunk_parser(chunk=data_json)
-        except StopAsyncIteration:
-            raise StopAsyncIteration
-        except ValueError as e:
-            raise RuntimeError(f"Error parsing chunk: {e},\nReceived chunk: {chunk}")
diff --git a/litellm/llms/cohere/common_utils.py b/litellm/llms/cohere/common_utils.py
index 808624b891..6aaad2b706 100644
--- a/litellm/llms/cohere/common_utils.py
+++ b/litellm/llms/cohere/common_utils.py
@@ -1,6 +1,13 @@
-from typing import Optional
+import json
+from typing import List, Optional
 
 from litellm.llms.base_llm.transformation import BaseLLMException
+from litellm.types.llms.openai import AllMessageValues
+from litellm.types.utils import (
+    ChatCompletionToolCallChunk,
+    ChatCompletionUsageBlock,
+    GenericStreamingChunk,
+)
 
 
 class CohereError(BaseLLMException):
@@ -8,7 +15,25 @@ class CohereError(BaseLLMException):
         super().__init__(status_code=status_code, message=message)
 
 
-def validate_environment(*, api_key: Optional[str], headers: dict) -> dict:
+def validate_environment(
+    headers: dict,
+    model: str,
+    messages: List[AllMessageValues],
+    optional_params: dict,
+    api_key: Optional[str] = None,
+) -> dict:
+    """
+    Return headers to use for cohere chat completion request
+
+    Cohere API Ref: https://docs.cohere.com/reference/chat
+    Expected headers:
+    {
+        "Request-Source": "unspecified:litellm",
+        "accept": "application/json",
+        "content-type": "application/json",
+        "Authorization": "bearer $CO_API_KEY"
+    }
+    """
     headers.update(
         {
             "Request-Source": "unspecified:litellm",
@@ -17,5 +42,105 @@ def validate_environment(*, api_key: Optional[str], headers: dict) -> dict:
         }
     )
     if api_key:
-        headers["Authorization"] = f"Bearer {api_key}"
+        headers["Authorization"] = f"bearer {api_key}"
     return headers
+
+
+class ModelResponseIterator:
+    def __init__(
+        self, streaming_response, sync_stream: bool, json_mode: Optional[bool] = False
+    ):
+        self.streaming_response = streaming_response
+        self.response_iterator = self.streaming_response
+        self.content_blocks: List = []
+        self.tool_index = -1
+        self.json_mode = json_mode
+
+    def chunk_parser(self, chunk: dict) -> GenericStreamingChunk:
+        try:
+            text = ""
+            tool_use: Optional[ChatCompletionToolCallChunk] = None
+            is_finished = False
+            finish_reason = ""
+            usage: Optional[ChatCompletionUsageBlock] = None
+            provider_specific_fields = None
+
+            index = int(chunk.get("index", 0))
+
+            if "text" in chunk:
+                text = chunk["text"]
+            elif "is_finished" in chunk and chunk["is_finished"] is True:
+                is_finished = chunk["is_finished"]
+                finish_reason = chunk["finish_reason"]
+
+            if "citations" in chunk:
+                provider_specific_fields = {"citations": chunk["citations"]}
+
+            returned_chunk = GenericStreamingChunk(
+                text=text,
+                tool_use=tool_use,
+                is_finished=is_finished,
+                finish_reason=finish_reason,
+                usage=usage,
+                index=index,
+                provider_specific_fields=provider_specific_fields,
+            )
+
+            return returned_chunk
+
+        except json.JSONDecodeError:
+            raise ValueError(f"Failed to decode JSON from chunk: {chunk}")
+
+    # Sync iterator
+    def __iter__(self):
+        return self
+
+    def __next__(self):
+        try:
+            chunk = self.response_iterator.__next__()
+        except StopIteration:
+            raise StopIteration
+        except ValueError as e:
+            raise RuntimeError(f"Error receiving chunk from stream: {e}")
+
+        try:
+            str_line = chunk
+            if isinstance(chunk, bytes):  # Handle binary data
+                str_line = chunk.decode("utf-8")  # Convert bytes to string
+            index = str_line.find("data:")
+            if index != -1:
+                str_line = str_line[index:]
+            data_json = json.loads(str_line)
+            return self.chunk_parser(chunk=data_json)
+        except StopIteration:
+            raise StopIteration
+        except ValueError as e:
+            raise RuntimeError(f"Error parsing chunk: {e},\nReceived chunk: {chunk}")
+
+    # Async iterator
+    def __aiter__(self):
+        self.async_response_iterator = self.streaming_response.__aiter__()
+        return self
+
+    async def __anext__(self):
+        try:
+            chunk = await self.async_response_iterator.__anext__()
+        except StopAsyncIteration:
+            raise StopAsyncIteration
+        except ValueError as e:
+            raise RuntimeError(f"Error receiving chunk from stream: {e}")
+
+        try:
+            str_line = chunk
+            if isinstance(chunk, bytes):  # Handle binary data
+                str_line = chunk.decode("utf-8")  # Convert bytes to string
+            index = str_line.find("data:")
+            if index != -1:
+                str_line = str_line[index:]
+
+            data_json = json.loads(str_line)
+            return self.chunk_parser(chunk=data_json)
+        except StopAsyncIteration:
+            raise StopAsyncIteration
+        except ValueError as e:
+            raise RuntimeError(f"Error parsing chunk: {e},\nReceived chunk: {chunk}")
diff --git a/litellm/llms/cohere/completion/completion.py b/litellm/llms/cohere/completion/completion.py
deleted file mode 100644
index 77ed5cc83c..0000000000
--- a/litellm/llms/cohere/completion/completion.py
+++ /dev/null
@@ -1,155 +0,0 @@
-##### Calls /generate endpoint #######
-
-import json
-import os
-import time
-import traceback
-import types
-from enum import Enum
-from typing import Any, Callable, Optional, Union
-
-import httpx  # type: ignore
-import requests  # type: ignore
-
-import litellm
-from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
-from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
-from litellm.utils import Choices, Message, ModelResponse, Usage
-
-from ..common_utils import CohereError
-
-
-def construct_cohere_tool(tools=None):
-    if tools is None:
-        tools = []
-    return {"tools": tools}
-
-
-def validate_environment(api_key, headers: dict):
-    headers.update(
-        {
-            "Request-Source": "unspecified:litellm",
-            "accept": "application/json",
-            "content-type": "application/json",
-        }
-    )
-    if api_key:
-        headers["Authorization"] = f"Bearer {api_key}"
-    return headers
-
-
-def completion(
-    model: str,
-    messages: list,
-    api_base: str,
-    model_response: ModelResponse,
-    print_verbose: Callable,
-    encoding,
-    api_key,
-    logging_obj,
-    headers: dict,
-    optional_params: dict,
-    litellm_params=None,
-    logger_fn=None,
-):
-    headers = validate_environment(api_key, headers=headers)
-    completion_url = api_base
-    model = model
-    prompt = " ".join(message["content"] for message in messages)
-
-    ## Load Config
-    config = litellm.CohereConfig.get_config()
-    for k, v in config.items():
-        if (
-            k not in optional_params
-        ):  # completion(top_k=3) > cohere_config(top_k=3) <- allows for dynamic variables to be passed in
-            optional_params[k] = v
-
-    ## Handle Tool Calling
-    if "tools" in optional_params:
-        _is_function_call = True
-        tool_calling_system_prompt = construct_cohere_tool(
-            tools=optional_params["tools"]
-        )
-        optional_params["tools"] = tool_calling_system_prompt
-
-    data = {
-        "model": model,
-        "prompt": prompt,
-        **optional_params,
-    }
-
-    ## LOGGING
-    logging_obj.pre_call(
-        input=prompt,
-        api_key=api_key,
-        additional_args={
-            "complete_input_dict": data,
-            "headers": headers,
-            "api_base": completion_url,
-        },
-    )
-    ## COMPLETION CALL
-    response = requests.post(
-        completion_url,
-        headers=headers,
-        data=json.dumps(data),
-        stream=optional_params["stream"] if "stream" in optional_params else False,
-    )
-    ## error handling for cohere calls
-    if response.status_code != 200:
-        raise CohereError(message=response.text, status_code=response.status_code)
-
-    if "stream" in optional_params and optional_params["stream"] is True:
-        return response.iter_lines()
-    else:
-        ## LOGGING
-        logging_obj.post_call(
-            input=prompt,
-            api_key=api_key,
-            original_response=response.text,
-            additional_args={"complete_input_dict": data},
-        )
-        print_verbose(f"raw model_response: {response.text}")
-        ## RESPONSE OBJECT
-        completion_response = response.json()
-        if "error" in completion_response:
-            raise CohereError(
-                message=completion_response["error"],
-                status_code=response.status_code,
-            )
-        else:
-            try:
-                choices_list = []
-                for idx, item in enumerate(completion_response["generations"]):
-                    if len(item["text"]) > 0:
-                        message_obj = Message(content=item["text"])
-                    else:
-                        message_obj = Message(content=None)
-                    choice_obj = Choices(
-                        finish_reason=item["finish_reason"],
-                        index=idx + 1,
-                        message=message_obj,
-                    )
-                    choices_list.append(choice_obj)
-                model_response.choices = choices_list  # type: ignore
-            except Exception:
-                raise CohereError(
-                    message=response.text, status_code=response.status_code
-                )
-
-    ## CALCULATING USAGE
-    prompt_tokens = len(encoding.encode(prompt))
-    completion_tokens = len(
-        encoding.encode(model_response["choices"][0]["message"].get("content", ""))
-    )
-
-    model_response.created = int(time.time())
-    model_response.model = model
-    usage = Usage(
-        prompt_tokens=prompt_tokens,
-        completion_tokens=completion_tokens,
-        total_tokens=prompt_tokens + completion_tokens,
-    )
-    setattr(model_response, "usage", usage)
-    return model_response
diff --git a/litellm/llms/cohere/completion/transformation.py b/litellm/llms/cohere/completion/transformation.py
index c38c57d145..9414a88e58 100644
--- a/litellm/llms/cohere/completion/transformation.py
+++ b/litellm/llms/cohere/completion/transformation.py
@@ -1,13 +1,26 @@
+import json
+import time
 import types
-from typing import TYPE_CHECKING, Any, List, Optional, Union
+from typing import TYPE_CHECKING, Any, AsyncIterator, Iterator, List, Optional, Union
 
 import httpx
 
+import litellm
 from litellm.llms.base_llm.transformation import BaseConfig, BaseLLMException
+from litellm.llms.prompt_templates.common_utils import convert_content_list_to_str
 from litellm.types.llms.openai import AllMessageValues
-from litellm.types.utils import ModelResponse
+from litellm.types.utils import (
+    ChatCompletionToolCallChunk,
+    ChatCompletionUsageBlock,
+    Choices,
+    GenericStreamingChunk,
+    Message,
+    ModelResponse,
+    Usage,
+)
 
 from ..common_utils import CohereError
+from ..common_utils import ModelResponseIterator as CohereModelResponseIterator
 from ..common_utils import validate_environment as cohere_validate_environment
 
 if TYPE_CHECKING:
@@ -98,7 +111,13 @@ class CohereTextConfig(BaseConfig):
         optional_params: dict,
         api_key: Optional[str] = None,
     ) -> dict:
-        return cohere_validate_environment(api_key=api_key, headers=headers)
+        return cohere_validate_environment(
+            headers=headers,
+            model=model,
+            messages=messages,
+            optional_params=optional_params,
+            api_key=api_key,
+        )
 
     def _transform_messages(
         self,
@@ -161,7 +180,33 @@ class CohereTextConfig(BaseConfig):
         litellm_params: dict,
         headers: dict,
     ) -> dict:
-        raise NotImplementedError
+        prompt = " ".join(
+            convert_content_list_to_str(message=message) for message in messages
+        )
+
+        ## Load Config
+        config = litellm.CohereConfig.get_config()
+        for k, v in config.items():
+            if (
+                k not in optional_params
+            ):  # completion(top_k=3) > cohere_config(top_k=3) <- allows for dynamic variables to be passed in
+                optional_params[k] = v
+
+        ## Handle Tool Calling
+        if "tools" in optional_params:
+            _is_function_call = True
+            tool_calling_system_prompt = self._construct_cohere_tool_for_completion_api(
+                tools=optional_params["tools"]
+            )
+            optional_params["tools"] = tool_calling_system_prompt
+
+        data = {
+            "model": model,
+            "prompt": prompt,
+            **optional_params,
+        }
+
+        return data
 
     def transform_response(
         self,
@@ -176,4 +221,56 @@ class CohereTextConfig(BaseConfig):
         api_key: Optional[str] = None,
         json_mode: Optional[bool] = None,
     ) -> ModelResponse:
-        raise NotImplementedError
+        prompt = " ".join(
+            convert_content_list_to_str(message=message) for message in messages
+        )
+        completion_response = raw_response.json()
+        choices_list = []
+        for idx, item in enumerate(completion_response["generations"]):
+            if len(item["text"]) > 0:
+                message_obj = Message(content=item["text"])
+            else:
+                message_obj = Message(content=None)
+            choice_obj = Choices(
+                finish_reason=item["finish_reason"],
+                index=idx + 1,
+                message=message_obj,
+            )
+            choices_list.append(choice_obj)
+        model_response.choices = choices_list  # type: ignore
+
+        ## CALCULATING USAGE
+        prompt_tokens = len(encoding.encode(prompt))
+        completion_tokens = len(
+            encoding.encode(model_response["choices"][0]["message"].get("content", ""))
+        )
+
+        model_response.created = int(time.time())
+        model_response.model = model
+        usage = Usage(
+            prompt_tokens=prompt_tokens,
+            completion_tokens=completion_tokens,
+            total_tokens=prompt_tokens + completion_tokens,
+        )
+        setattr(model_response, "usage", usage)
+        return model_response
+
+    def _construct_cohere_tool_for_completion_api(
+        self,
+        tools: Optional[List] = None,
+    ) -> dict:
+        if tools is None:
+            tools = []
+        return {"tools": tools}
+
+    def get_model_response_iterator(
+        self,
+        streaming_response: Union[Iterator[str], AsyncIterator[str], ModelResponse],
+        sync_stream: bool,
+        json_mode: Optional[bool] = False,
+    ):
+        return CohereModelResponseIterator(
+            streaming_response=streaming_response,
+            sync_stream=sync_stream,
+            json_mode=json_mode,
+        )
diff --git a/litellm/llms/custom_httpx/llm_http_handler.py b/litellm/llms/custom_httpx/llm_http_handler.py
index 57f8c60b62..de42def31a 100644
--- a/litellm/llms/custom_httpx/llm_http_handler.py
+++ b/litellm/llms/custom_httpx/llm_http_handler.py
@@ -373,8 +373,12 @@ class BaseLLMHTTPHandler:
         error_headers = getattr(error_response, "headers", None)
         if error_response and hasattr(error_response, "text"):
             error_text = getattr(error_response, "text", error_text)
-        raise provider_config.error_class(  # type: ignore
-            message=error_text,
+        if error_headers:
+            error_headers = dict(error_headers)
+        else:
+            error_headers = {}
+        raise provider_config.get_error_class(
+            error_message=error_text,
             status_code=status_code,
             headers=error_headers,
         )
diff --git a/litellm/main.py b/litellm/main.py
index 25551711e4..b7fd631ec6 100644
--- a/litellm/main.py
+++ b/litellm/main.py
@@ -109,7 +109,6 @@ from .llms.azure_text import AzureTextCompletion
 from .llms.bedrock.chat import BedrockConverseLLM, BedrockLLM
 from .llms.bedrock.embed.embedding import BedrockEmbedding
 from .llms.bedrock.image.image_handler import BedrockImageGeneration
-from .llms.cohere.completion import completion as cohere_completion  # type: ignore
 from .llms.cohere.embed import handler as cohere_embed
 from .llms.custom_httpx.llm_http_handler import BaseLLMHTTPHandler
 from .llms.custom_llm import CustomLLM, custom_chat_llm_router
@@ -446,6 +445,7 @@ async def acompletion(
         or custom_llm_provider == "groq"
         or custom_llm_provider == "nvidia_nim"
         or custom_llm_provider == "cohere_chat"
+        or custom_llm_provider == "cohere"
         or custom_llm_provider == "cerebras"
         or custom_llm_provider == "sambanova"
         or custom_llm_provider == "ai21_chat"
@@ -1895,31 +1895,22 @@ def completion(  # type: ignore # noqa: PLR0915
            if extra_headers is not None:
                headers.update(extra_headers)
 
-            model_response = cohere_completion.completion(
+            response = base_llm_http_handler.completion(
                 model=model,
+                stream=stream,
                 messages=messages,
+                acompletion=acompletion,
                 api_base=api_base,
                 model_response=model_response,
-                print_verbose=print_verbose,
                 optional_params=optional_params,
                 litellm_params=litellm_params,
-                logger_fn=logger_fn,
-                encoding=encoding,
+                custom_llm_provider="cohere",
+                timeout=timeout,
                 headers=headers,
+                encoding=encoding,
                 api_key=cohere_key,
                 logging_obj=logging,  # model call logging done inside the class as we make need to modify I/O to fit aleph alpha's requirements
             )
-
-            if "stream" in optional_params and optional_params["stream"] is True:
-                # don't try to access stream object,
-                response = CustomStreamWrapper(
-                    model_response,
-                    model,
-                    custom_llm_provider="cohere",
-                    logging_obj=logging,
-                )
-                return response
-            response = model_response
         elif custom_llm_provider == "cohere_chat":
             cohere_key = (
                 api_key
diff --git a/tests/llm_translation/test_cohere_generate_api.py b/tests/llm_translation/test_cohere_generate_api.py
new file mode 100644
index 0000000000..9e0bb82846
--- /dev/null
+++ b/tests/llm_translation/test_cohere_generate_api.py
@@ -0,0 +1,184 @@
+import os
+import sys
+import traceback
+
+from dotenv import load_dotenv
+
+load_dotenv()
+import io
+import os
+
+sys.path.insert(
+    0, os.path.abspath("../..")
+)  # Adds the parent directory to the system path
+import json
+
+import pytest
+
+import litellm
+from litellm import completion
+from litellm.llms.cohere.completion.transformation import CohereTextConfig
+
+
+@pytest.mark.asyncio
+async def test_cohere_generate_api_completion():
+    try:
+        litellm.set_verbose = False
+        messages = [
+            {"role": "system", "content": "You're a good bot"},
+            {
+                "role": "user",
+                "content": "Hey",
+            },
+        ]
+        response = completion(
+            model="cohere/command-nightly",
+            messages=messages,
+            max_tokens=10,
+        )
+        print(response)
+    except Exception as e:
+        pytest.fail(f"Error occurred: {e}")
+
+
+@pytest.mark.asyncio
+async def test_cohere_generate_api_stream():
+    try:
+        litellm.set_verbose = True
+        messages = [
+            {"role": "system", "content": "You're a good bot"},
+            {
+                "role": "user",
+                "content": "Hey",
+            },
+        ]
+        response = await litellm.acompletion(
+            model="cohere/command-nightly",
+            messages=messages,
+            max_tokens=10,
+            stream=True,
+        )
+        print("async cohere stream response", response)
+        async for chunk in response:
+            print(chunk)
+    except Exception as e:
+        pytest.fail(f"Error occurred: {e}")
+
+
+def test_completion_cohere_stream_bad_key():
+    try:
+        api_key = "bad-key"
+        messages = [
+            {"role": "system", "content": "You are a helpful assistant."},
+            {
+                "role": "user",
+                "content": "how does a court case get to the Supreme Court?",
+            },
+        ]
+        completion(
+            model="command-nightly",
+            messages=messages,
+            stream=True,
+            max_tokens=50,
+            api_key=api_key,
+        )
+
+    except litellm.AuthenticationError as e:
+        pass
+    except Exception as e:
+        pytest.fail(f"Error occurred: {e}")
+
+
+def test_cohere_transform_request():
+    try:
+        config = CohereTextConfig()
+        messages = [
+            {"role": "system", "content": "You're a helpful bot"},
+            {"role": "user", "content": "Hello"},
+        ]
+        optional_params = {"max_tokens": 10, "temperature": 0.7}
+        headers = {}
+
+        transformed_request = config.transform_request(
+            model="command-nightly",
+            messages=messages,
+            optional_params=optional_params,
+            litellm_params={},
+            headers=headers,
+        )
+
+        print("transformed_request", json.dumps(transformed_request, indent=4))
+
+        assert transformed_request["model"] == "command-nightly"
+        assert transformed_request["prompt"] == "You're a helpful bot Hello"
+        assert transformed_request["max_tokens"] == 10
+        assert transformed_request["temperature"] == 0.7
+    except Exception as e:
+        pytest.fail(f"Error occurred: {e}")
+
+
+def test_cohere_transform_request_with_tools():
+    try:
+        config = CohereTextConfig()
+        messages = [{"role": "user", "content": "What's the weather?"}]
+        tools = [
+            {
+                "type": "function",
+                "function": {
+                    "name": "get_weather",
+                    "description": "Get weather information",
+                    "parameters": {
+                        "type": "object",
+                        "properties": {"location": {"type": "string"}},
+                    },
+                },
+            }
+        ]
+        optional_params = {"tools": tools}
+
+        transformed_request = config.transform_request(
+            model="command-nightly",
+            messages=messages,
+            optional_params=optional_params,
+            litellm_params={},
+            headers={},
+        )
+
+        print("transformed_request", json.dumps(transformed_request, indent=4))
+        assert "tools" in transformed_request
+        assert transformed_request["tools"] == {"tools": tools}
+    except Exception as e:
+        pytest.fail(f"Error occurred: {e}")
+
+
+def test_cohere_map_openai_params():
+    try:
+        config = CohereTextConfig()
+        openai_params = {
+            "temperature": 0.7,
+            "max_tokens": 100,
+            "n": 2,
+            "top_p": 0.9,
+            "frequency_penalty": 0.5,
+            "presence_penalty": 0.5,
+            "stop": ["END"],
+            "stream": True,
+        }
+
+        mapped_params = config.map_openai_params(
+            non_default_params=openai_params,
+            optional_params={},
+            model="command-nightly",
+            drop_params=False,
+        )
+
+        assert mapped_params["temperature"] == 0.7
+        assert mapped_params["max_tokens"] == 100
+        assert mapped_params["num_generations"] == 2
+        assert mapped_params["p"] == 0.9
+        assert mapped_params["frequency_penalty"] == 0.5
+        assert mapped_params["presence_penalty"] == 0.5
+        assert mapped_params["stop_sequences"] == ["END"]
+        assert mapped_params["stream"] == True
+    except Exception as e:
+        pytest.fail(f"Error occurred: {e}")
diff --git a/tests/local_testing/test_streaming.py b/tests/local_testing/test_streaming.py
index 757ff4d611..7227099792 100644
--- a/tests/local_testing/test_streaming.py
+++ b/tests/local_testing/test_streaming.py
@@ -436,47 +436,6 @@ def test_completion_azure_stream_content_filter_no_delta():
         pytest.fail(f"An exception occurred - {str(e)}")
 
 
-def test_completion_cohere_stream_bad_key():
-    try:
-        litellm.cache = None
-        api_key = "bad-key"
-        messages = [
-            {"role": "system", "content": "You are a helpful assistant."},
-            {
-                "role": "user",
-                "content": "how does a court case get to the Supreme Court?",
-            },
-        ]
-        response = completion(
-            model="command-nightly",
-            messages=messages,
-            stream=True,
-            max_tokens=50,
-            api_key=api_key,
-        )
-        complete_response = ""
-        # Add any assertions here to check the response
-        has_finish_reason = False
-        for idx, chunk in enumerate(response):
-            chunk, finished = streaming_format_tests(idx, chunk)
-            has_finish_reason = finished
-            if finished:
-                break
-            complete_response += chunk
-        if has_finish_reason is False:
-            raise Exception("Finish reason not in final chunk")
-        if complete_response.strip() == "":
-            raise Exception("Empty response received")
-        print(f"completion_response: {complete_response}")
-    except AuthenticationError as e:
-        pass
-    except Exception as e:
-        pytest.fail(f"Error occurred: {e}")
-
-
-# test_completion_cohere_stream_bad_key()
-
-
 @pytest.mark.flaky(retries=5, delay=1)
 def test_completion_azure_stream():
     try:
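
A minimal usage sketch (not part of the patch itself) of the refactored Cohere /generate route, mirroring the new tests in tests/llm_translation/test_cohere_generate_api.py; the model name, message, and max_tokens come from those tests, and a COHERE_API_KEY in the environment is an assumed prerequisite.

    import litellm

    # Non-streaming call over the cohere "/generate" route, now served by
    # base_llm_http_handler (see the litellm/main.py hunk above).
    response = litellm.completion(
        model="cohere/command-nightly",
        messages=[{"role": "user", "content": "Hey"}],
        max_tokens=10,
    )
    print(response.choices[0].message.content)

    # Streaming now flows through CustomStreamWrapper using the shared
    # ModelResponseIterator in litellm/llms/cohere/common_utils.py.
    stream = litellm.completion(
        model="cohere/command-nightly",
        messages=[{"role": "user", "content": "Hey"}],
        max_tokens=10,
        stream=True,
    )
    for chunk in stream:
        print(chunk)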