From 2ed593e05225d37a18bc82eb775673b60dd73498 Mon Sep 17 00:00:00 2001 From: Krish Dholakia Date: Mon, 14 Apr 2025 19:51:01 -0700 Subject: [PATCH] Updated cohere v2 passthrough (#9997) * Add cohere `/v2/chat` pass-through cost tracking support (#8235) * feat(cohere_passthrough_handler.py): initial working commit with cohere passthrough cost tracking * fix(v2_transformation.py): support cohere /v2/chat endpoint * fix: fix linting errors * fix: fix import * fix(v2_transformation.py): fix linting error * test: handle openai exception change --- .gitignore | 1 + litellm/llms/cohere/chat/v2_transformation.py | 356 ++++++++++++++++++ litellm/llms/cohere/common_utils.py | 35 +- ...odel_prices_and_context_window_backup.json | 2 +- .../base_passthrough_logging_handler.py | 219 +++++++++++ .../cohere_passthrough_logging_handler.py | 56 +++ .../pass_through_endpoints.py | 1 + .../pass_through_endpoints/success_handler.py | 32 ++ litellm/types/llms/cohere.py | 56 +++ tests/local_testing/test_exceptions.py | 4 +- 10 files changed, 742 insertions(+), 20 deletions(-) create mode 100644 litellm/llms/cohere/chat/v2_transformation.py create mode 100644 litellm/proxy/pass_through_endpoints/llm_provider_handlers/base_passthrough_logging_handler.py create mode 100644 litellm/proxy/pass_through_endpoints/llm_provider_handlers/cohere_passthrough_logging_handler.py diff --git a/.gitignore b/.gitignore index e599411328..81fff2d342 100644 --- a/.gitignore +++ b/.gitignore @@ -73,6 +73,7 @@ tests/local_testing/log.txt .codegpt litellm/proxy/_new_new_secret_config.yaml litellm/proxy/custom_guardrail.py +.mypy_cache/* litellm/proxy/_experimental/out/404.html litellm/proxy/_experimental/out/404.html litellm/proxy/_experimental/out/model_hub.html diff --git a/litellm/llms/cohere/chat/v2_transformation.py b/litellm/llms/cohere/chat/v2_transformation.py new file mode 100644 index 0000000000..76948e7f8b --- /dev/null +++ b/litellm/llms/cohere/chat/v2_transformation.py @@ -0,0 +1,356 @@ +import time +from typing import TYPE_CHECKING, Any, AsyncIterator, Iterator, List, Optional, Union + +import httpx + +import litellm +from litellm.litellm_core_utils.prompt_templates.factory import cohere_messages_pt_v2 +from litellm.llms.base_llm.chat.transformation import BaseConfig, BaseLLMException +from litellm.types.llms.cohere import CohereV2ChatResponse +from litellm.types.llms.openai import AllMessageValues, ChatCompletionToolCallChunk +from litellm.types.utils import ModelResponse, Usage + +from ..common_utils import CohereError +from ..common_utils import ModelResponseIterator as CohereModelResponseIterator +from ..common_utils import validate_environment as cohere_validate_environment + +if TYPE_CHECKING: + from litellm.litellm_core_utils.litellm_logging import Logging as _LiteLLMLoggingObj + + LiteLLMLoggingObj = _LiteLLMLoggingObj +else: + LiteLLMLoggingObj = Any + + +class CohereV2ChatConfig(BaseConfig): + """ + Configuration class for Cohere's API interface. + + Args: + preamble (str, optional): When specified, the default Cohere preamble will be replaced with the provided one. + chat_history (List[Dict[str, str]], optional): A list of previous messages between the user and the model. + generation_id (str, optional): Unique identifier for the generated reply. + response_id (str, optional): Unique identifier for the response. + conversation_id (str, optional): An alternative to chat_history, creates or resumes a persisted conversation. + prompt_truncation (str, optional): Dictates how the prompt will be constructed. 
Options: 'AUTO', 'AUTO_PRESERVE_ORDER', 'OFF'. + connectors (List[Dict[str, str]], optional): List of connectors (e.g., web-search) to enrich the model's reply. + search_queries_only (bool, optional): When true, the response will only contain a list of generated search queries. + documents (List[Dict[str, str]], optional): A list of relevant documents that the model can cite. + temperature (float, optional): A non-negative float that tunes the degree of randomness in generation. + max_tokens (int, optional): The maximum number of tokens the model will generate as part of the response. + k (int, optional): Ensures only the top k most likely tokens are considered for generation at each step. + p (float, optional): Ensures that only the most likely tokens, with total probability mass of p, are considered for generation. + frequency_penalty (float, optional): Used to reduce repetitiveness of generated tokens. + presence_penalty (float, optional): Used to reduce repetitiveness of generated tokens. + tools (List[Dict[str, str]], optional): A list of available tools (functions) that the model may suggest invoking. + tool_results (List[Dict[str, Any]], optional): A list of results from invoking tools. + seed (int, optional): A seed to assist reproducibility of the model's response. + """ + + preamble: Optional[str] = None + chat_history: Optional[list] = None + generation_id: Optional[str] = None + response_id: Optional[str] = None + conversation_id: Optional[str] = None + prompt_truncation: Optional[str] = None + connectors: Optional[list] = None + search_queries_only: Optional[bool] = None + documents: Optional[list] = None + temperature: Optional[int] = None + max_tokens: Optional[int] = None + k: Optional[int] = None + p: Optional[int] = None + frequency_penalty: Optional[int] = None + presence_penalty: Optional[int] = None + tools: Optional[list] = None + tool_results: Optional[list] = None + seed: Optional[int] = None + + def __init__( + self, + preamble: Optional[str] = None, + chat_history: Optional[list] = None, + generation_id: Optional[str] = None, + response_id: Optional[str] = None, + conversation_id: Optional[str] = None, + prompt_truncation: Optional[str] = None, + connectors: Optional[list] = None, + search_queries_only: Optional[bool] = None, + documents: Optional[list] = None, + temperature: Optional[int] = None, + max_tokens: Optional[int] = None, + k: Optional[int] = None, + p: Optional[int] = None, + frequency_penalty: Optional[int] = None, + presence_penalty: Optional[int] = None, + tools: Optional[list] = None, + tool_results: Optional[list] = None, + seed: Optional[int] = None, + ) -> None: + locals_ = locals() + for key, value in locals_.items(): + if key != "self" and value is not None: + setattr(self.__class__, key, value) + + def validate_environment( + self, + headers: dict, + model: str, + messages: List[AllMessageValues], + optional_params: dict, + litellm_params: dict, + api_key: Optional[str] = None, + api_base: Optional[str] = None, + ) -> dict: + return cohere_validate_environment( + headers=headers, + model=model, + messages=messages, + optional_params=optional_params, + api_key=api_key, + ) + + def get_supported_openai_params(self, model: str) -> List[str]: + return [ + "stream", + "temperature", + "max_tokens", + "top_p", + "frequency_penalty", + "presence_penalty", + "stop", + "n", + "tools", + "tool_choice", + "seed", + "extra_headers", + ] + + def map_openai_params( + self, + non_default_params: dict, + optional_params: dict, + model: str, + drop_params: 
bool, + ) -> dict: + for param, value in non_default_params.items(): + if param == "stream": + optional_params["stream"] = value + if param == "temperature": + optional_params["temperature"] = value + if param == "max_tokens": + optional_params["max_tokens"] = value + if param == "n": + optional_params["num_generations"] = value + if param == "top_p": + optional_params["p"] = value + if param == "frequency_penalty": + optional_params["frequency_penalty"] = value + if param == "presence_penalty": + optional_params["presence_penalty"] = value + if param == "stop": + optional_params["stop_sequences"] = value + if param == "tools": + optional_params["tools"] = value + if param == "seed": + optional_params["seed"] = value + return optional_params + + def transform_request( + self, + model: str, + messages: List[AllMessageValues], + optional_params: dict, + litellm_params: dict, + headers: dict, + ) -> dict: + ## Load Config + for k, v in litellm.CohereChatConfig.get_config().items(): + if ( + k not in optional_params + ): # completion(top_k=3) > cohere_config(top_k=3) <- allows for dynamic variables to be passed in + optional_params[k] = v + + most_recent_message, chat_history = cohere_messages_pt_v2( + messages=messages, model=model, llm_provider="cohere_chat" + ) + + ## Handle Tool Calling + if "tools" in optional_params: + _is_function_call = True + cohere_tools = self._construct_cohere_tool(tools=optional_params["tools"]) + optional_params["tools"] = cohere_tools + if isinstance(most_recent_message, dict): + optional_params["tool_results"] = [most_recent_message] + elif isinstance(most_recent_message, str): + optional_params["message"] = most_recent_message + + ## check if chat history message is 'user' and 'tool_results' is given -> force_single_step=True, else cohere api fails + if len(chat_history) > 0 and chat_history[-1]["role"] == "USER": + optional_params["force_single_step"] = True + + return optional_params + + def transform_response( + self, + model: str, + raw_response: httpx.Response, + model_response: ModelResponse, + logging_obj: LiteLLMLoggingObj, + request_data: dict, + messages: List[AllMessageValues], + optional_params: dict, + litellm_params: dict, + encoding: Any, + api_key: Optional[str] = None, + json_mode: Optional[bool] = None, + ) -> ModelResponse: + try: + raw_response_json = raw_response.json() + except Exception: + raise CohereError( + message=raw_response.text, status_code=raw_response.status_code + ) + + try: + cohere_v2_chat_response = CohereV2ChatResponse(**raw_response_json) # type: ignore + except Exception: + raise CohereError(message=raw_response.text, status_code=422) + + cohere_content = cohere_v2_chat_response["message"].get("content", None) + if cohere_content is not None: + model_response.choices[0].message.content = "".join( # type: ignore + [ + content.get("text", "") + for content in cohere_content + if content is not None + ] + ) + + ## ADD CITATIONS + if "citations" in cohere_v2_chat_response: + setattr(model_response, "citations", cohere_v2_chat_response["citations"]) + + ## Tool calling response + cohere_tools_response = cohere_v2_chat_response["message"].get("tool_calls", []) + if cohere_tools_response is not None and cohere_tools_response != []: + # convert cohere_tools_response to OpenAI response format + tool_calls: List[ChatCompletionToolCallChunk] = [] + for index, tool in enumerate(cohere_tools_response): + tool_call: ChatCompletionToolCallChunk = { + **tool, # type: ignore + "index": index, + } + tool_calls.append(tool_call) + _message 
= litellm.Message( + tool_calls=tool_calls, + content=None, + ) + model_response.choices[0].message = _message # type: ignore + + ## CALCULATING USAGE - use cohere `billed_units` for returning usage + token_usage = cohere_v2_chat_response["usage"].get("tokens", {}) + prompt_tokens = token_usage.get("input_tokens", 0) + completion_tokens = token_usage.get("output_tokens", 0) + + model_response.created = int(time.time()) + model_response.model = model + usage = Usage( + prompt_tokens=prompt_tokens, + completion_tokens=completion_tokens, + total_tokens=prompt_tokens + completion_tokens, + ) + setattr(model_response, "usage", usage) + return model_response + + def _construct_cohere_tool( + self, + tools: Optional[list] = None, + ): + if tools is None: + tools = [] + cohere_tools = [] + for tool in tools: + cohere_tool = self._translate_openai_tool_to_cohere(tool) + cohere_tools.append(cohere_tool) + return cohere_tools + + def _translate_openai_tool_to_cohere( + self, + openai_tool: dict, + ): + # cohere tools look like this + """ + { + "name": "query_daily_sales_report", + "description": "Connects to a database to retrieve overall sales volumes and sales information for a given day.", + "parameter_definitions": { + "day": { + "description": "Retrieves sales data for this day, formatted as YYYY-MM-DD.", + "type": "str", + "required": True + } + } + } + """ + + # OpenAI tools look like this + """ + { + "type": "function", + "function": { + "name": "get_current_weather", + "description": "Get the current weather in a given location", + "parameters": { + "type": "object", + "properties": { + "location": { + "type": "string", + "description": "The city and state, e.g. San Francisco, CA", + }, + "unit": {"type": "string", "enum": ["celsius", "fahrenheit"]}, + }, + "required": ["location"], + }, + }, + } + """ + cohere_tool = { + "name": openai_tool["function"]["name"], + "description": openai_tool["function"]["description"], + "parameter_definitions": {}, + } + + for param_name, param_def in openai_tool["function"]["parameters"][ + "properties" + ].items(): + required_params = ( + openai_tool.get("function", {}) + .get("parameters", {}) + .get("required", []) + ) + cohere_param_def = { + "description": param_def.get("description", ""), + "type": param_def.get("type", ""), + "required": param_name in required_params, + } + cohere_tool["parameter_definitions"][param_name] = cohere_param_def + + return cohere_tool + + def get_model_response_iterator( + self, + streaming_response: Union[Iterator[str], AsyncIterator[str], ModelResponse], + sync_stream: bool, + json_mode: Optional[bool] = False, + ): + return CohereModelResponseIterator( + streaming_response=streaming_response, + sync_stream=sync_stream, + json_mode=json_mode, + ) + + def get_error_class( + self, error_message: str, status_code: int, headers: Union[dict, httpx.Headers] + ) -> BaseLLMException: + return CohereError(status_code=status_code, message=error_message) diff --git a/litellm/llms/cohere/common_utils.py b/litellm/llms/cohere/common_utils.py index 11ff73efc2..6dbe52d575 100644 --- a/litellm/llms/cohere/common_utils.py +++ b/litellm/llms/cohere/common_utils.py @@ -104,19 +104,28 @@ class ModelResponseIterator: raise RuntimeError(f"Error receiving chunk from stream: {e}") try: - str_line = chunk - if isinstance(chunk, bytes): # Handle binary data - str_line = chunk.decode("utf-8") # Convert bytes to string - index = str_line.find("data:") - if index != -1: - str_line = str_line[index:] - data_json = json.loads(str_line) - return 
self.chunk_parser(chunk=data_json) + return self.convert_str_chunk_to_generic_chunk(chunk=chunk) except StopIteration: raise StopIteration except ValueError as e: raise RuntimeError(f"Error parsing chunk: {e},\nReceived chunk: {chunk}") + def convert_str_chunk_to_generic_chunk(self, chunk: str) -> GenericStreamingChunk: + """ + Convert a string chunk to a GenericStreamingChunk + + Note: This is used for Cohere pass through streaming logging + """ + str_line = chunk + if isinstance(chunk, bytes): # Handle binary data + str_line = chunk.decode("utf-8") # Convert bytes to string + index = str_line.find("data:") + if index != -1: + str_line = str_line[index:] + + data_json = json.loads(str_line) + return self.chunk_parser(chunk=data_json) + # Async iterator def __aiter__(self): self.async_response_iterator = self.streaming_response.__aiter__() @@ -131,15 +140,7 @@ class ModelResponseIterator: raise RuntimeError(f"Error receiving chunk from stream: {e}") try: - str_line = chunk - if isinstance(chunk, bytes): # Handle binary data - str_line = chunk.decode("utf-8") # Convert bytes to string - index = str_line.find("data:") - if index != -1: - str_line = str_line[index:] - - data_json = json.loads(str_line) - return self.chunk_parser(chunk=data_json) + return self.convert_str_chunk_to_generic_chunk(chunk=chunk) except StopAsyncIteration: raise StopAsyncIteration except ValueError as e: diff --git a/litellm/model_prices_and_context_window_backup.json b/litellm/model_prices_and_context_window_backup.json index 4480fbee90..d42762355c 100644 --- a/litellm/model_prices_and_context_window_backup.json +++ b/litellm/model_prices_and_context_window_backup.json @@ -2669,7 +2669,7 @@ "max_output_tokens": 4096, "input_cost_per_token": 0.00000008, "input_cost_per_audio_token": 0.000004, - "output_cost_per_token": 0.00032, + "output_cost_per_token": 0.00000032, "litellm_provider": "azure_ai", "mode": "chat", "supports_audio_input": true, diff --git a/litellm/proxy/pass_through_endpoints/llm_provider_handlers/base_passthrough_logging_handler.py b/litellm/proxy/pass_through_endpoints/llm_provider_handlers/base_passthrough_logging_handler.py new file mode 100644 index 0000000000..e73b652d83 --- /dev/null +++ b/litellm/proxy/pass_through_endpoints/llm_provider_handlers/base_passthrough_logging_handler.py @@ -0,0 +1,219 @@ +import json +from datetime import datetime +from typing import TYPE_CHECKING, Any, List, Optional, Union + +import httpx + +import litellm +from litellm._logging import verbose_proxy_logger +from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj +from litellm.litellm_core_utils.litellm_logging import ( + get_standard_logging_object_payload, +) +from litellm.llms.base_llm.chat.transformation import BaseConfig +from litellm.proxy._types import PassThroughEndpointLoggingTypedDict +from litellm.proxy.auth.auth_utils import get_end_user_id_from_request_body +from litellm.proxy.pass_through_endpoints.types import PassthroughStandardLoggingPayload +from litellm.types.utils import LlmProviders, ModelResponse, TextCompletionResponse + +if TYPE_CHECKING: + from ..success_handler import PassThroughEndpointLogging + from ..types import EndpointType +else: + PassThroughEndpointLogging = Any + EndpointType = Any + +from abc import ABC, abstractmethod + + +class BasePassthroughLoggingHandler(ABC): + @property + @abstractmethod + def llm_provider_name(self) -> LlmProviders: + pass + + @abstractmethod + def get_provider_config(self, model: str) -> BaseConfig: + pass + + def 
passthrough_chat_handler( + self, + httpx_response: httpx.Response, + response_body: dict, + logging_obj: LiteLLMLoggingObj, + url_route: str, + result: str, + start_time: datetime, + end_time: datetime, + cache_hit: bool, + request_body: dict, + **kwargs, + ) -> PassThroughEndpointLoggingTypedDict: + """ + Transforms LLM response to OpenAI response, generates a standard logging object so downstream logging can be handled + """ + model = request_body.get("model", response_body.get("model", "")) + provider_config = self.get_provider_config(model=model) + litellm_model_response: ModelResponse = provider_config.transform_response( + raw_response=httpx_response, + model_response=litellm.ModelResponse(), + model=model, + messages=[], + logging_obj=logging_obj, + optional_params={}, + api_key="", + request_data={}, + encoding=litellm.encoding, + json_mode=False, + litellm_params={}, + ) + + kwargs = self._create_response_logging_payload( + litellm_model_response=litellm_model_response, + model=model, + kwargs=kwargs, + start_time=start_time, + end_time=end_time, + logging_obj=logging_obj, + ) + + return { + "result": litellm_model_response, + "kwargs": kwargs, + } + + def _get_user_from_metadata( + self, + passthrough_logging_payload: PassthroughStandardLoggingPayload, + ) -> Optional[str]: + request_body = passthrough_logging_payload.get("request_body") + if request_body: + return get_end_user_id_from_request_body(request_body) + return None + + def _create_response_logging_payload( + self, + litellm_model_response: Union[ModelResponse, TextCompletionResponse], + model: str, + kwargs: dict, + start_time: datetime, + end_time: datetime, + logging_obj: LiteLLMLoggingObj, + ) -> dict: + """ + Create the standard logging object for Generic LLM passthrough + + handles streaming and non-streaming responses + """ + + try: + response_cost = litellm.completion_cost( + completion_response=litellm_model_response, + model=model, + ) + + kwargs["response_cost"] = response_cost + kwargs["model"] = model + passthrough_logging_payload: Optional[PassthroughStandardLoggingPayload] = ( # type: ignore + kwargs.get("passthrough_logging_payload") + ) + if passthrough_logging_payload: + user = self._get_user_from_metadata( + passthrough_logging_payload=passthrough_logging_payload, + ) + if user: + kwargs.setdefault("litellm_params", {}) + kwargs["litellm_params"].update( + {"proxy_server_request": {"body": {"user": user}}} + ) + + # Make standard logging object for Anthropic + standard_logging_object = get_standard_logging_object_payload( + kwargs=kwargs, + init_response_obj=litellm_model_response, + start_time=start_time, + end_time=end_time, + logging_obj=logging_obj, + status="success", + ) + + # pretty print standard logging object + verbose_proxy_logger.debug( + "standard_logging_object= %s", + json.dumps(standard_logging_object, indent=4), + ) + kwargs["standard_logging_object"] = standard_logging_object + + # set litellm_call_id to logging response object + litellm_model_response.id = logging_obj.litellm_call_id + litellm_model_response.model = model + logging_obj.model_call_details["model"] = model + return kwargs + except Exception as e: + verbose_proxy_logger.exception( + "Error creating LLM passthrough response logging payload: %s", e + ) + return kwargs + + @abstractmethod + def _build_complete_streaming_response( + self, + all_chunks: List[str], + litellm_logging_obj: LiteLLMLoggingObj, + model: str, + ) -> Optional[Union[ModelResponse, TextCompletionResponse]]: + """ + Builds complete response from raw 
chunks + + - Converts str chunks to generic chunks + - Converts generic chunks to litellm chunks (OpenAI format) + - Builds complete response from litellm chunks + """ + pass + + def _handle_logging_llm_collected_chunks( + self, + litellm_logging_obj: LiteLLMLoggingObj, + passthrough_success_handler_obj: PassThroughEndpointLogging, + url_route: str, + request_body: dict, + endpoint_type: EndpointType, + start_time: datetime, + all_chunks: List[str], + end_time: datetime, + ) -> PassThroughEndpointLoggingTypedDict: + """ + Takes raw chunks from Anthropic passthrough endpoint and logs them in litellm callbacks + + - Builds complete response from chunks + - Creates standard logging object + - Logs in litellm callbacks + """ + + model = request_body.get("model", "") + complete_streaming_response = self._build_complete_streaming_response( + all_chunks=all_chunks, + litellm_logging_obj=litellm_logging_obj, + model=model, + ) + if complete_streaming_response is None: + verbose_proxy_logger.error( + "Unable to build complete streaming response for Anthropic passthrough endpoint, not logging..." + ) + return { + "result": None, + "kwargs": {}, + } + kwargs = self._create_response_logging_payload( + litellm_model_response=complete_streaming_response, + model=model, + kwargs={}, + start_time=start_time, + end_time=end_time, + logging_obj=litellm_logging_obj, + ) + + return { + "result": complete_streaming_response, + "kwargs": kwargs, + } diff --git a/litellm/proxy/pass_through_endpoints/llm_provider_handlers/cohere_passthrough_logging_handler.py b/litellm/proxy/pass_through_endpoints/llm_provider_handlers/cohere_passthrough_logging_handler.py new file mode 100644 index 0000000000..a8228de6e0 --- /dev/null +++ b/litellm/proxy/pass_through_endpoints/llm_provider_handlers/cohere_passthrough_logging_handler.py @@ -0,0 +1,56 @@ +from typing import List, Optional, Union + +from litellm import stream_chunk_builder +from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj +from litellm.litellm_core_utils.streaming_handler import CustomStreamWrapper +from litellm.llms.base_llm.chat.transformation import BaseConfig +from litellm.llms.cohere.chat.v2_transformation import CohereV2ChatConfig +from litellm.llms.cohere.common_utils import ( + ModelResponseIterator as CohereModelResponseIterator, +) +from litellm.types.utils import LlmProviders, ModelResponse, TextCompletionResponse + +from .base_passthrough_logging_handler import BasePassthroughLoggingHandler + + +class CoherePassthroughLoggingHandler(BasePassthroughLoggingHandler): + @property + def llm_provider_name(self) -> LlmProviders: + return LlmProviders.COHERE + + def get_provider_config(self, model: str) -> BaseConfig: + return CohereV2ChatConfig() + + def _build_complete_streaming_response( + self, + all_chunks: List[str], + litellm_logging_obj: LiteLLMLoggingObj, + model: str, + ) -> Optional[Union[ModelResponse, TextCompletionResponse]]: + cohere_model_response_iterator = CohereModelResponseIterator( + streaming_response=None, + sync_stream=False, + ) + litellm_custom_stream_wrapper = CustomStreamWrapper( + completion_stream=cohere_model_response_iterator, + model=model, + logging_obj=litellm_logging_obj, + custom_llm_provider="cohere", + ) + all_openai_chunks = [] + for _chunk_str in all_chunks: + try: + generic_chunk = ( + cohere_model_response_iterator.convert_str_chunk_to_generic_chunk( + chunk=_chunk_str + ) + ) + litellm_chunk = litellm_custom_stream_wrapper.chunk_creator( + chunk=generic_chunk + ) + if litellm_chunk 
is not None: + all_openai_chunks.append(litellm_chunk) + except (StopIteration, StopAsyncIteration): + break + complete_streaming_response = stream_chunk_builder(chunks=all_openai_chunks) + return complete_streaming_response diff --git a/litellm/proxy/pass_through_endpoints/pass_through_endpoints.py b/litellm/proxy/pass_through_endpoints/pass_through_endpoints.py index 563d0cb543..737cf12001 100644 --- a/litellm/proxy/pass_through_endpoints/pass_through_endpoints.py +++ b/litellm/proxy/pass_through_endpoints/pass_through_endpoints.py @@ -683,6 +683,7 @@ async def pass_through_request( # noqa: PLR0915 end_time=end_time, logging_obj=logging_obj, cache_hit=False, + request_body=_parsed_body, **kwargs, ) ) diff --git a/litellm/proxy/pass_through_endpoints/success_handler.py b/litellm/proxy/pass_through_endpoints/success_handler.py index 02e81566e8..e8676f018f 100644 --- a/litellm/proxy/pass_through_endpoints/success_handler.py +++ b/litellm/proxy/pass_through_endpoints/success_handler.py @@ -16,10 +16,15 @@ from .llm_provider_handlers.anthropic_passthrough_logging_handler import ( from .llm_provider_handlers.assembly_passthrough_logging_handler import ( AssemblyAIPassthroughLoggingHandler, ) +from .llm_provider_handlers.cohere_passthrough_logging_handler import ( + CoherePassthroughLoggingHandler, +) from .llm_provider_handlers.vertex_passthrough_logging_handler import ( VertexPassthroughLoggingHandler, ) +cohere_passthrough_logging_handler = CoherePassthroughLoggingHandler() + class PassThroughEndpointLogging: def __init__(self): @@ -32,6 +37,8 @@ class PassThroughEndpointLogging: # Anthropic self.TRACKED_ANTHROPIC_ROUTES = ["/messages"] + # Cohere + self.TRACKED_COHERE_ROUTES = ["/v2/chat"] self.assemblyai_passthrough_logging_handler = ( AssemblyAIPassthroughLoggingHandler() ) @@ -84,6 +91,7 @@ class PassThroughEndpointLogging: start_time: datetime, end_time: datetime, cache_hit: bool, + request_body: dict, **kwargs, ): standard_logging_response_object: Optional[ @@ -125,6 +133,25 @@ class PassThroughEndpointLogging: anthropic_passthrough_logging_handler_result["result"] ) kwargs = anthropic_passthrough_logging_handler_result["kwargs"] + elif self.is_cohere_route(url_route): + cohere_passthrough_logging_handler_result = ( + cohere_passthrough_logging_handler.passthrough_chat_handler( + httpx_response=httpx_response, + response_body=response_body or {}, + logging_obj=logging_obj, + url_route=url_route, + result=result, + start_time=start_time, + end_time=end_time, + cache_hit=cache_hit, + request_body=request_body, + **kwargs, + ) + ) + standard_logging_response_object = ( + cohere_passthrough_logging_handler_result["result"] + ) + kwargs = cohere_passthrough_logging_handler_result["kwargs"] elif self.is_assemblyai_route(url_route): if ( AssemblyAIPassthroughLoggingHandler._should_log_request( @@ -173,6 +200,11 @@ class PassThroughEndpointLogging: return True return False + def is_cohere_route(self, url_route: str): + for route in self.TRACKED_COHERE_ROUTES: + if route in url_route: + return True + def is_assemblyai_route(self, url_route: str): parsed_url = urlparse(url_route) if parsed_url.hostname == "api.assemblyai.com": diff --git a/litellm/types/llms/cohere.py b/litellm/types/llms/cohere.py index 7112a242f9..ea41bacd96 100644 --- a/litellm/types/llms/cohere.py +++ b/litellm/types/llms/cohere.py @@ -44,3 +44,59 @@ class ChatHistoryChatBot(TypedDict, total=False): ChatHistory = List[ Union[ChatHistorySystem, ChatHistoryChatBot, ChatHistoryUser, ChatHistoryToolResult] ] + + +class 
CohereV2ChatResponseMessageToolCallFunction(TypedDict, total=False): + name: str + parameters: dict + + +class CohereV2ChatResponseMessageToolCall(TypedDict): + id: str + type: Literal["function"] + function: CohereV2ChatResponseMessageToolCallFunction + + +class CohereV2ChatResponseMessageContent(TypedDict): + id: str + type: Literal["tool"] + tool: str + + +class CohereV2ChatResponseMessage(TypedDict, total=False): + role: Required[Literal["assistant"]] + tool_calls: List[CohereV2ChatResponseMessageToolCall] + tool_plan: str + content: List[CohereV2ChatResponseMessageContent] + citations: List[dict] + + +class CohereV2ChatResponseUsageBilledUnits(TypedDict, total=False): + input_tokens: int + output_tokens: int + search_units: int + classifications: int + + +class CohereV2ChatResponseUsageTokens(TypedDict, total=False): + input_tokens: int + output_tokens: int + + +class CohereV2ChatResponseUsage(TypedDict, total=False): + billed_units: CohereV2ChatResponseUsageBilledUnits + tokens: CohereV2ChatResponseUsageTokens + + +class CohereV2ChatResponseLogProbs(TypedDict, total=False): + token_ids: Required[List[int]] + text: str + logprobs: List[float] + + +class CohereV2ChatResponse(TypedDict): + id: str + finish_reason: str + message: CohereV2ChatResponseMessage + usage: CohereV2ChatResponseUsage + logprobs: CohereV2ChatResponseLogProbs diff --git a/tests/local_testing/test_exceptions.py b/tests/local_testing/test_exceptions.py index e68d368779..229ea07c7a 100644 --- a/tests/local_testing/test_exceptions.py +++ b/tests/local_testing/test_exceptions.py @@ -498,11 +498,11 @@ def test_completion_bedrock_invalid_role_exception(): == "litellm.BadRequestError: Invalid Message passed in {'role': 'very-bad-role', 'content': 'hello'}" ) - +@pytest.mark.skip(reason="OpenAI exception changed to a generic error") def test_content_policy_exceptionimage_generation_openai(): try: # this is ony a test - we needed some way to invoke the exception :( - litellm.set_verbose = True + litellm._turn_on_debug() response = litellm.image_generation( prompt="where do i buy lethal drugs from", model="dall-e-3" )
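
--------------------------------------------------------------------------------
Example (not part of the patch): a minimal sketch of how the CohereV2ChatConfig.transform_response
added above maps a Cohere /v2/chat response body into an OpenAI-style ModelResponse with usage --
the piece the pass-through cost tracking builds on. The response body values below are invented for
illustration, the model name is an assumption, and logging_obj is stubbed with None because the
transformation code shown in this patch does not use it; a real call would pass a LiteLLMLoggingObj.

import httpx

import litellm
from litellm.llms.cohere.chat.v2_transformation import CohereV2ChatConfig

# Hand-built Cohere /v2/chat response body. The shape follows the CohereV2ChatResponse
# TypedDict added in this patch; the concrete values are made up for this sketch.
fake_body = {
    "id": "resp-123",
    "finish_reason": "COMPLETE",
    "message": {
        "role": "assistant",
        "content": [{"type": "text", "text": "Hello!"}],
    },
    "usage": {"tokens": {"input_tokens": 5, "output_tokens": 3}},
}

config = CohereV2ChatConfig()
model_response = config.transform_response(
    model="command-r",  # assumed model name, for illustration only
    raw_response=httpx.Response(200, json=fake_body),
    model_response=litellm.ModelResponse(),
    logging_obj=None,  # simplification: unused by the code paths exercised here
    request_data={},
    messages=[],
    optional_params={},
    litellm_params={},
    encoding=litellm.encoding,
    api_key="",
    json_mode=False,
)

# The text content is joined from the Cohere content list, and usage is taken
# from usage.tokens, which is what litellm.completion_cost() later prices.
print(model_response.choices[0].message.content)  # -> "Hello!"
print(model_response.usage)  # prompt_tokens=5, completion_tokens=3, total_tokens=8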