diff --git a/.circleci/config.yml b/.circleci/config.yml index b2f6b7edce..ecae22f872 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -1935,12 +1935,12 @@ jobs: pip install prisma pip install fastapi pip install jsonschema - pip install "httpx==0.24.1" + pip install "httpx==0.27.0" pip install "anyio==3.7.1" pip install "asyncio==3.4.3" pip install "PyGithub==1.59.1" pip install "google-cloud-aiplatform==1.59.0" - pip install "anthropic==0.21.3" + pip install "anthropic==0.49.0" # Run pytest and generate JUnit XML report - run: name: Build Docker image diff --git a/litellm/__init__.py b/litellm/__init__.py index 60b8cf81a0..9ca1517c92 100644 --- a/litellm/__init__.py +++ b/litellm/__init__.py @@ -800,9 +800,6 @@ from .llms.oobabooga.chat.transformation import OobaboogaConfig from .llms.maritalk import MaritalkConfig from .llms.openrouter.chat.transformation import OpenrouterConfig from .llms.anthropic.chat.transformation import AnthropicConfig -from .llms.anthropic.experimental_pass_through.transformation import ( - AnthropicExperimentalPassThroughConfig, -) from .llms.groq.stt.transformation import GroqSTTConfig from .llms.anthropic.completion.transformation import AnthropicTextConfig from .llms.triton.completion.transformation import TritonConfig @@ -821,6 +818,9 @@ from .llms.infinity.rerank.transformation import InfinityRerankConfig from .llms.jina_ai.rerank.transformation import JinaAIRerankConfig from .llms.clarifai.chat.transformation import ClarifaiConfig from .llms.ai21.chat.transformation import AI21ChatConfig, AI21ChatConfig as AI21Config +from .llms.anthropic.experimental_pass_through.messages.transformation import ( + AnthropicMessagesConfig, +) from .llms.together_ai.chat import TogetherAIConfig from .llms.together_ai.completion.transformation import TogetherAITextCompletionConfig from .llms.cloudflare.chat.transformation import CloudflareChatConfig @@ -1011,6 +1011,7 @@ from .assistants.main import * from .batches.main import * from .batch_completion.main import * # type: ignore from .rerank_api.main import * +from .llms.anthropic.experimental_pass_through.messages.handler import * from .realtime_api.main import _arealtime from .fine_tuning.main import * from .files.main import * diff --git a/litellm/adapters/anthropic_adapter.py b/litellm/adapters/anthropic_adapter.py deleted file mode 100644 index 961bc77527..0000000000 --- a/litellm/adapters/anthropic_adapter.py +++ /dev/null @@ -1,186 +0,0 @@ -# What is this? 
-## Translates OpenAI call to Anthropic `/v1/messages` format -import traceback -from typing import Any, Optional - -import litellm -from litellm import ChatCompletionRequest, verbose_logger -from litellm.integrations.custom_logger import CustomLogger -from litellm.types.llms.anthropic import AnthropicMessagesRequest, AnthropicResponse -from litellm.types.utils import AdapterCompletionStreamWrapper, ModelResponse - - -class AnthropicAdapter(CustomLogger): - def __init__(self) -> None: - super().__init__() - - def translate_completion_input_params( - self, kwargs - ) -> Optional[ChatCompletionRequest]: - """ - - translate params, where needed - - pass rest, as is - """ - request_body = AnthropicMessagesRequest(**kwargs) # type: ignore - - translated_body = litellm.AnthropicExperimentalPassThroughConfig().translate_anthropic_to_openai( - anthropic_message_request=request_body - ) - - return translated_body - - def translate_completion_output_params( - self, response: ModelResponse - ) -> Optional[AnthropicResponse]: - - return litellm.AnthropicExperimentalPassThroughConfig().translate_openai_response_to_anthropic( - response=response - ) - - def translate_completion_output_params_streaming( - self, completion_stream: Any - ) -> AdapterCompletionStreamWrapper | None: - return AnthropicStreamWrapper(completion_stream=completion_stream) - - -anthropic_adapter = AnthropicAdapter() - - -class AnthropicStreamWrapper(AdapterCompletionStreamWrapper): - """ - - first chunk return 'message_start' - - content block must be started and stopped - - finish_reason must map exactly to anthropic reason, else anthropic client won't be able to parse it. - """ - - sent_first_chunk: bool = False - sent_content_block_start: bool = False - sent_content_block_finish: bool = False - sent_last_message: bool = False - holding_chunk: Optional[Any] = None - - def __next__(self): - try: - if self.sent_first_chunk is False: - self.sent_first_chunk = True - return { - "type": "message_start", - "message": { - "id": "msg_1nZdL29xx5MUA1yADyHTEsnR8uuvGzszyY", - "type": "message", - "role": "assistant", - "content": [], - "model": "claude-3-5-sonnet-20240620", - "stop_reason": None, - "stop_sequence": None, - "usage": {"input_tokens": 25, "output_tokens": 1}, - }, - } - if self.sent_content_block_start is False: - self.sent_content_block_start = True - return { - "type": "content_block_start", - "index": 0, - "content_block": {"type": "text", "text": ""}, - } - - for chunk in self.completion_stream: - if chunk == "None" or chunk is None: - raise Exception - - processed_chunk = litellm.AnthropicExperimentalPassThroughConfig().translate_streaming_openai_response_to_anthropic( - response=chunk - ) - if ( - processed_chunk["type"] == "message_delta" - and self.sent_content_block_finish is False - ): - self.holding_chunk = processed_chunk - self.sent_content_block_finish = True - return { - "type": "content_block_stop", - "index": 0, - } - elif self.holding_chunk is not None: - return_chunk = self.holding_chunk - self.holding_chunk = processed_chunk - return return_chunk - else: - return processed_chunk - if self.holding_chunk is not None: - return_chunk = self.holding_chunk - self.holding_chunk = None - return return_chunk - if self.sent_last_message is False: - self.sent_last_message = True - return {"type": "message_stop"} - raise StopIteration - except StopIteration: - if self.sent_last_message is False: - self.sent_last_message = True - return {"type": "message_stop"} - raise StopIteration - except Exception as e: - 
verbose_logger.error( - "Anthropic Adapter - {}\n{}".format(e, traceback.format_exc()) - ) - - async def __anext__(self): - try: - if self.sent_first_chunk is False: - self.sent_first_chunk = True - return { - "type": "message_start", - "message": { - "id": "msg_1nZdL29xx5MUA1yADyHTEsnR8uuvGzszyY", - "type": "message", - "role": "assistant", - "content": [], - "model": "claude-3-5-sonnet-20240620", - "stop_reason": None, - "stop_sequence": None, - "usage": {"input_tokens": 25, "output_tokens": 1}, - }, - } - if self.sent_content_block_start is False: - self.sent_content_block_start = True - return { - "type": "content_block_start", - "index": 0, - "content_block": {"type": "text", "text": ""}, - } - async for chunk in self.completion_stream: - if chunk == "None" or chunk is None: - raise Exception - processed_chunk = litellm.AnthropicExperimentalPassThroughConfig().translate_streaming_openai_response_to_anthropic( - response=chunk - ) - if ( - processed_chunk["type"] == "message_delta" - and self.sent_content_block_finish is False - ): - self.holding_chunk = processed_chunk - self.sent_content_block_finish = True - return { - "type": "content_block_stop", - "index": 0, - } - elif self.holding_chunk is not None: - return_chunk = self.holding_chunk - self.holding_chunk = processed_chunk - return return_chunk - else: - return processed_chunk - if self.holding_chunk is not None: - return_chunk = self.holding_chunk - self.holding_chunk = None - return return_chunk - if self.sent_last_message is False: - self.sent_last_message = True - return {"type": "message_stop"} - raise StopIteration - except StopIteration: - if self.sent_last_message is False: - self.sent_last_message = True - return {"type": "message_stop"} - raise StopAsyncIteration diff --git a/litellm/litellm_core_utils/core_helpers.py b/litellm/litellm_core_utils/core_helpers.py index e571e3f6c6..2036b93692 100644 --- a/litellm/litellm_core_utils/core_helpers.py +++ b/litellm/litellm_core_utils/core_helpers.py @@ -73,6 +73,8 @@ def remove_index_from_tool_calls( def get_litellm_metadata_from_kwargs(kwargs: dict): """ Helper to get litellm metadata from all litellm request kwargs + + Return `litellm_metadata` if it exists, otherwise return `metadata` """ litellm_params = kwargs.get("litellm_params", {}) if litellm_params: diff --git a/litellm/litellm_core_utils/litellm_logging.py b/litellm/litellm_core_utils/litellm_logging.py index b92b0927d2..dab2b9a80b 100644 --- a/litellm/litellm_core_utils/litellm_logging.py +++ b/litellm/litellm_core_utils/litellm_logging.py @@ -932,6 +932,9 @@ class Logging(LiteLLMLoggingBaseClass): self.model_call_details["log_event_type"] = "successful_api_call" self.model_call_details["end_time"] = end_time self.model_call_details["cache_hit"] = cache_hit + + if self.call_type == CallTypes.anthropic_messages.value: + result = self._handle_anthropic_messages_response_logging(result=result) ## if model in model cost map - log the response cost ## else set cost to None if ( @@ -2304,6 +2307,37 @@ class Logging(LiteLLMLoggingBaseClass): return complete_streaming_response return None + def _handle_anthropic_messages_response_logging(self, result: Any) -> ModelResponse: + """ + Handles logging for Anthropic messages responses. + + Args: + result: The response object from the model call + + Returns: + The the response object from the model call + + - For Non-streaming responses, we need to transform the response to a ModelResponse object. 
+ - For streaming responses, anthropic_messages handler calls success_handler with a assembled ModelResponse. + """ + if self.stream and isinstance(result, ModelResponse): + return result + + result = litellm.AnthropicConfig().transform_response( + raw_response=self.model_call_details["httpx_response"], + model_response=litellm.ModelResponse(), + model=self.model, + messages=[], + logging_obj=self, + optional_params={}, + api_key="", + request_data={}, + encoding=litellm.encoding, + json_mode=False, + litellm_params={}, + ) + return result + def set_callbacks(callback_list, function_id=None): # noqa: PLR0915 """ diff --git a/litellm/llms/anthropic/experimental_pass_through/messages/handler.py b/litellm/llms/anthropic/experimental_pass_through/messages/handler.py new file mode 100644 index 0000000000..a7dfff74d9 --- /dev/null +++ b/litellm/llms/anthropic/experimental_pass_through/messages/handler.py @@ -0,0 +1,179 @@ +""" +- call /messages on Anthropic API +- Make streaming + non-streaming request - just pass it through direct to Anthropic. No need to do anything special here +- Ensure requests are logged in the DB - stream + non-stream + +""" + +import json +from typing import Any, AsyncIterator, Dict, Optional, Union, cast + +import httpx + +import litellm +from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj +from litellm.llms.base_llm.anthropic_messages.transformation import ( + BaseAnthropicMessagesConfig, +) +from litellm.llms.custom_httpx.http_handler import ( + AsyncHTTPHandler, + get_async_httpx_client, +) +from litellm.types.router import GenericLiteLLMParams +from litellm.types.utils import ProviderSpecificHeader +from litellm.utils import ProviderConfigManager, client + + +class AnthropicMessagesHandler: + + @staticmethod + async def _handle_anthropic_streaming( + response: httpx.Response, + request_body: dict, + litellm_logging_obj: LiteLLMLoggingObj, + ) -> AsyncIterator: + """Helper function to handle Anthropic streaming responses using the existing logging handlers""" + from datetime import datetime + + from litellm.proxy.pass_through_endpoints.streaming_handler import ( + PassThroughStreamingHandler, + ) + from litellm.proxy.pass_through_endpoints.success_handler import ( + PassThroughEndpointLogging, + ) + from litellm.proxy.pass_through_endpoints.types import EndpointType + + # Create success handler object + passthrough_success_handler_obj = PassThroughEndpointLogging() + + # Use the existing streaming handler for Anthropic + start_time = datetime.now() + return PassThroughStreamingHandler.chunk_processor( + response=response, + request_body=request_body, + litellm_logging_obj=litellm_logging_obj, + endpoint_type=EndpointType.ANTHROPIC, + start_time=start_time, + passthrough_success_handler_obj=passthrough_success_handler_obj, + url_route="/v1/messages", + ) + + +@client +async def anthropic_messages( + api_key: str, + model: str, + stream: bool = False, + api_base: Optional[str] = None, + client: Optional[AsyncHTTPHandler] = None, + custom_llm_provider: Optional[str] = None, + **kwargs, +) -> Union[Dict[str, Any], AsyncIterator]: + """ + Makes Anthropic `/v1/messages` API calls In the Anthropic API Spec + """ + # Use provided client or create a new one + optional_params = GenericLiteLLMParams(**kwargs) + model, _custom_llm_provider, dynamic_api_key, dynamic_api_base = ( + litellm.get_llm_provider( + model=model, + custom_llm_provider=custom_llm_provider, + api_base=optional_params.api_base, + api_key=optional_params.api_key, + ) + ) + 
anthropic_messages_provider_config: Optional[BaseAnthropicMessagesConfig] = ( + ProviderConfigManager.get_provider_anthropic_messages_config( + model=model, + provider=litellm.LlmProviders(_custom_llm_provider), + ) + ) + if anthropic_messages_provider_config is None: + raise ValueError( + f"Anthropic messages provider config not found for model: {model}" + ) + if client is None or not isinstance(client, AsyncHTTPHandler): + async_httpx_client = get_async_httpx_client( + llm_provider=litellm.LlmProviders.ANTHROPIC + ) + else: + async_httpx_client = client + + litellm_logging_obj: LiteLLMLoggingObj = kwargs.get("litellm_logging_obj", None) + + # Prepare headers + provider_specific_header = cast( + Optional[ProviderSpecificHeader], kwargs.get("provider_specific_header", None) + ) + extra_headers = ( + provider_specific_header.get("extra_headers", {}) + if provider_specific_header + else {} + ) + headers = anthropic_messages_provider_config.validate_environment( + headers=extra_headers or {}, + model=model, + api_key=api_key, + ) + + litellm_logging_obj.update_environment_variables( + model=model, + optional_params=dict(optional_params), + litellm_params={ + "metadata": kwargs.get("metadata", {}), + "preset_cache_key": None, + "stream_response": {}, + **optional_params.model_dump(exclude_unset=True), + }, + custom_llm_provider=_custom_llm_provider, + ) + litellm_logging_obj.model_call_details.update(kwargs) + + # Prepare request body + request_body = kwargs.copy() + request_body = { + k: v + for k, v in request_body.items() + if k + in anthropic_messages_provider_config.get_supported_anthropic_messages_params( + model=model + ) + } + request_body["stream"] = stream + request_body["model"] = model + litellm_logging_obj.stream = stream + + # Make the request + request_url = anthropic_messages_provider_config.get_complete_url( + api_base=api_base, model=model + ) + + litellm_logging_obj.pre_call( + input=[{"role": "user", "content": json.dumps(request_body)}], + api_key="", + additional_args={ + "complete_input_dict": request_body, + "api_base": str(request_url), + "headers": headers, + }, + ) + + response = await async_httpx_client.post( + url=request_url, + headers=headers, + data=json.dumps(request_body), + stream=stream, + ) + response.raise_for_status() + + # used for logging + cost tracking + litellm_logging_obj.model_call_details["httpx_response"] = response + + if stream: + return await AnthropicMessagesHandler._handle_anthropic_streaming( + response=response, + request_body=request_body, + litellm_logging_obj=litellm_logging_obj, + ) + else: + return response.json() diff --git a/litellm/llms/anthropic/experimental_pass_through/messages/transformation.py b/litellm/llms/anthropic/experimental_pass_through/messages/transformation.py new file mode 100644 index 0000000000..e9b598f18d --- /dev/null +++ b/litellm/llms/anthropic/experimental_pass_through/messages/transformation.py @@ -0,0 +1,47 @@ +from typing import Optional + +from litellm.llms.base_llm.anthropic_messages.transformation import ( + BaseAnthropicMessagesConfig, +) + +DEFAULT_ANTHROPIC_API_BASE = "https://api.anthropic.com" +DEFAULT_ANTHROPIC_API_VERSION = "2023-06-01" + + +class AnthropicMessagesConfig(BaseAnthropicMessagesConfig): + def get_supported_anthropic_messages_params(self, model: str) -> list: + return [ + "messages", + "model", + "system", + "max_tokens", + "stop_sequences", + "temperature", + "top_p", + "top_k", + "tools", + "tool_choice", + "thinking", + # TODO: Add Anthropic `metadata` support + # "metadata", + ] 
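# ---------------------------------------------------------------------------
# Editor's note: illustrative usage sketch, not part of this diff. It shows the
# new `litellm.anthropic_messages()` handler (exported via litellm/__init__.py
# above) being called directly. The model name and the ANTHROPIC_API_KEY env
# var are assumptions; only params in the supported-params list above are
# forwarded into the request body.
import asyncio
import os

import litellm


async def _example_anthropic_messages_call():
    # Non-streaming call: returns the raw Anthropic /v1/messages JSON as a dict.
    response = await litellm.anthropic_messages(
        api_key=os.environ["ANTHROPIC_API_KEY"],
        model="claude-3-5-sonnet-20241022",
        messages=[{"role": "user", "content": "Hello"}],
        max_tokens=256,
    )
    print(response)


if __name__ == "__main__":
    asyncio.run(_example_anthropic_messages_call())
# ---------------------------------------------------------------------------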
+ + def get_complete_url(self, api_base: Optional[str], model: str) -> str: + api_base = api_base or DEFAULT_ANTHROPIC_API_BASE + if not api_base.endswith("/v1/messages"): + api_base = f"{api_base}/v1/messages" + return api_base + + def validate_environment( + self, + headers: dict, + model: str, + api_key: Optional[str] = None, + ) -> dict: + if "x-api-key" not in headers: + headers["x-api-key"] = api_key + if "anthropic-version" not in headers: + headers["anthropic-version"] = DEFAULT_ANTHROPIC_API_VERSION + if "content-type" not in headers: + headers["content-type"] = "application/json" + return headers diff --git a/litellm/llms/anthropic/experimental_pass_through/transformation.py b/litellm/llms/anthropic/experimental_pass_through/transformation.py deleted file mode 100644 index b24cf47ad4..0000000000 --- a/litellm/llms/anthropic/experimental_pass_through/transformation.py +++ /dev/null @@ -1,412 +0,0 @@ -import json -from typing import List, Literal, Optional, Tuple, Union - -from openai.types.chat.chat_completion_chunk import Choice as OpenAIStreamingChoice - -from litellm.types.llms.anthropic import ( - AllAnthropicToolsValues, - AnthopicMessagesAssistantMessageParam, - AnthropicFinishReason, - AnthropicMessagesRequest, - AnthropicMessagesToolChoice, - AnthropicMessagesUserMessageParam, - AnthropicResponse, - AnthropicResponseContentBlockText, - AnthropicResponseContentBlockToolUse, - AnthropicResponseUsageBlock, - ContentBlockDelta, - ContentJsonBlockDelta, - ContentTextBlockDelta, - MessageBlockDelta, - MessageDelta, - UsageDelta, -) -from litellm.types.llms.openai import ( - AllMessageValues, - ChatCompletionAssistantMessage, - ChatCompletionAssistantToolCall, - ChatCompletionImageObject, - ChatCompletionImageUrlObject, - ChatCompletionRequest, - ChatCompletionSystemMessage, - ChatCompletionTextObject, - ChatCompletionToolCallFunctionChunk, - ChatCompletionToolChoiceFunctionParam, - ChatCompletionToolChoiceObjectParam, - ChatCompletionToolChoiceValues, - ChatCompletionToolMessage, - ChatCompletionToolParam, - ChatCompletionToolParamFunctionChunk, - ChatCompletionUserMessage, -) -from litellm.types.utils import Choices, ModelResponse, Usage - - -class AnthropicExperimentalPassThroughConfig: - def __init__(self): - pass - - ### FOR [BETA] `/v1/messages` endpoint support - - def translatable_anthropic_params(self) -> List: - """ - Which anthropic params, we need to translate to the openai format. 
- """ - return ["messages", "metadata", "system", "tool_choice", "tools"] - - def translate_anthropic_messages_to_openai( # noqa: PLR0915 - self, - messages: List[ - Union[ - AnthropicMessagesUserMessageParam, - AnthopicMessagesAssistantMessageParam, - ] - ], - ) -> List: - new_messages: List[AllMessageValues] = [] - for m in messages: - user_message: Optional[ChatCompletionUserMessage] = None - tool_message_list: List[ChatCompletionToolMessage] = [] - new_user_content_list: List[ - Union[ChatCompletionTextObject, ChatCompletionImageObject] - ] = [] - ## USER MESSAGE ## - if m["role"] == "user": - ## translate user message - message_content = m.get("content") - if message_content and isinstance(message_content, str): - user_message = ChatCompletionUserMessage( - role="user", content=message_content - ) - elif message_content and isinstance(message_content, list): - for content in message_content: - if content["type"] == "text": - text_obj = ChatCompletionTextObject( - type="text", text=content["text"] - ) - new_user_content_list.append(text_obj) - elif content["type"] == "image": - image_url = ChatCompletionImageUrlObject( - url=f"data:{content['type']};base64,{content['source']}" - ) - image_obj = ChatCompletionImageObject( - type="image_url", image_url=image_url - ) - - new_user_content_list.append(image_obj) - elif content["type"] == "tool_result": - if "content" not in content: - tool_result = ChatCompletionToolMessage( - role="tool", - tool_call_id=content["tool_use_id"], - content="", - ) - tool_message_list.append(tool_result) - elif isinstance(content["content"], str): - tool_result = ChatCompletionToolMessage( - role="tool", - tool_call_id=content["tool_use_id"], - content=content["content"], - ) - tool_message_list.append(tool_result) - elif isinstance(content["content"], list): - for c in content["content"]: - if c["type"] == "text": - tool_result = ChatCompletionToolMessage( - role="tool", - tool_call_id=content["tool_use_id"], - content=c["text"], - ) - tool_message_list.append(tool_result) - elif c["type"] == "image": - image_str = ( - f"data:{c['type']};base64,{c['source']}" - ) - tool_result = ChatCompletionToolMessage( - role="tool", - tool_call_id=content["tool_use_id"], - content=image_str, - ) - tool_message_list.append(tool_result) - - if user_message is not None: - new_messages.append(user_message) - - if len(new_user_content_list) > 0: - new_messages.append({"role": "user", "content": new_user_content_list}) # type: ignore - - if len(tool_message_list) > 0: - new_messages.extend(tool_message_list) - - ## ASSISTANT MESSAGE ## - assistant_message_str: Optional[str] = None - tool_calls: List[ChatCompletionAssistantToolCall] = [] - if m["role"] == "assistant": - if isinstance(m["content"], str): - assistant_message_str = m["content"] - elif isinstance(m["content"], list): - for content in m["content"]: - if content["type"] == "text": - if assistant_message_str is None: - assistant_message_str = content["text"] - else: - assistant_message_str += content["text"] - elif content["type"] == "tool_use": - function_chunk = ChatCompletionToolCallFunctionChunk( - name=content["name"], - arguments=json.dumps(content["input"]), - ) - - tool_calls.append( - ChatCompletionAssistantToolCall( - id=content["id"], - type="function", - function=function_chunk, - ) - ) - - if assistant_message_str is not None or len(tool_calls) > 0: - assistant_message = ChatCompletionAssistantMessage( - role="assistant", - content=assistant_message_str, - ) - if len(tool_calls) > 0: - 
assistant_message["tool_calls"] = tool_calls - new_messages.append(assistant_message) - - return new_messages - - def translate_anthropic_tool_choice_to_openai( - self, tool_choice: AnthropicMessagesToolChoice - ) -> ChatCompletionToolChoiceValues: - if tool_choice["type"] == "any": - return "required" - elif tool_choice["type"] == "auto": - return "auto" - elif tool_choice["type"] == "tool": - tc_function_param = ChatCompletionToolChoiceFunctionParam( - name=tool_choice.get("name", "") - ) - return ChatCompletionToolChoiceObjectParam( - type="function", function=tc_function_param - ) - else: - raise ValueError( - "Incompatible tool choice param submitted - {}".format(tool_choice) - ) - - def translate_anthropic_tools_to_openai( - self, tools: List[AllAnthropicToolsValues] - ) -> List[ChatCompletionToolParam]: - new_tools: List[ChatCompletionToolParam] = [] - mapped_tool_params = ["name", "input_schema", "description"] - for tool in tools: - function_chunk = ChatCompletionToolParamFunctionChunk( - name=tool["name"], - ) - if "input_schema" in tool: - function_chunk["parameters"] = tool["input_schema"] # type: ignore - if "description" in tool: - function_chunk["description"] = tool["description"] # type: ignore - - for k, v in tool.items(): - if k not in mapped_tool_params: # pass additional computer kwargs - function_chunk.setdefault("parameters", {}).update({k: v}) - new_tools.append( - ChatCompletionToolParam(type="function", function=function_chunk) - ) - - return new_tools - - def translate_anthropic_to_openai( - self, anthropic_message_request: AnthropicMessagesRequest - ) -> ChatCompletionRequest: - """ - This is used by the beta Anthropic Adapter, for translating anthropic `/v1/messages` requests to the openai format. - """ - new_messages: List[AllMessageValues] = [] - - ## CONVERT ANTHROPIC MESSAGES TO OPENAI - new_messages = self.translate_anthropic_messages_to_openai( - messages=anthropic_message_request["messages"] - ) - ## ADD SYSTEM MESSAGE TO MESSAGES - if "system" in anthropic_message_request: - new_messages.insert( - 0, - ChatCompletionSystemMessage( - role="system", content=anthropic_message_request["system"] - ), - ) - - new_kwargs: ChatCompletionRequest = { - "model": anthropic_message_request["model"], - "messages": new_messages, - } - ## CONVERT METADATA (user_id) - if "metadata" in anthropic_message_request: - if "user_id" in anthropic_message_request["metadata"]: - new_kwargs["user"] = anthropic_message_request["metadata"]["user_id"] - - # Pass litellm proxy specific metadata - if "litellm_metadata" in anthropic_message_request: - # metadata will be passed to litellm.acompletion(), it's a litellm_param - new_kwargs["metadata"] = anthropic_message_request.pop("litellm_metadata") - - ## CONVERT TOOL CHOICE - if "tool_choice" in anthropic_message_request: - new_kwargs["tool_choice"] = self.translate_anthropic_tool_choice_to_openai( - tool_choice=anthropic_message_request["tool_choice"] - ) - ## CONVERT TOOLS - if "tools" in anthropic_message_request: - new_kwargs["tools"] = self.translate_anthropic_tools_to_openai( - tools=anthropic_message_request["tools"] - ) - - translatable_params = self.translatable_anthropic_params() - for k, v in anthropic_message_request.items(): - if k not in translatable_params: # pass remaining params as is - new_kwargs[k] = v # type: ignore - - return new_kwargs - - def _translate_openai_content_to_anthropic( - self, choices: List[Choices] - ) -> List[ - Union[AnthropicResponseContentBlockText, AnthropicResponseContentBlockToolUse] - ]: - 
new_content: List[ - Union[ - AnthropicResponseContentBlockText, AnthropicResponseContentBlockToolUse - ] - ] = [] - for choice in choices: - if ( - choice.message.tool_calls is not None - and len(choice.message.tool_calls) > 0 - ): - for tool_call in choice.message.tool_calls: - new_content.append( - AnthropicResponseContentBlockToolUse( - type="tool_use", - id=tool_call.id, - name=tool_call.function.name or "", - input=json.loads(tool_call.function.arguments), - ) - ) - elif choice.message.content is not None: - new_content.append( - AnthropicResponseContentBlockText( - type="text", text=choice.message.content - ) - ) - - return new_content - - def _translate_openai_finish_reason_to_anthropic( - self, openai_finish_reason: str - ) -> AnthropicFinishReason: - if openai_finish_reason == "stop": - return "end_turn" - elif openai_finish_reason == "length": - return "max_tokens" - elif openai_finish_reason == "tool_calls": - return "tool_use" - return "end_turn" - - def translate_openai_response_to_anthropic( - self, response: ModelResponse - ) -> AnthropicResponse: - ## translate content block - anthropic_content = self._translate_openai_content_to_anthropic(choices=response.choices) # type: ignore - ## extract finish reason - anthropic_finish_reason = self._translate_openai_finish_reason_to_anthropic( - openai_finish_reason=response.choices[0].finish_reason # type: ignore - ) - # extract usage - usage: Usage = getattr(response, "usage") - anthropic_usage = AnthropicResponseUsageBlock( - input_tokens=usage.prompt_tokens or 0, - output_tokens=usage.completion_tokens or 0, - ) - translated_obj = AnthropicResponse( - id=response.id, - type="message", - role="assistant", - model=response.model or "unknown-model", - stop_sequence=None, - usage=anthropic_usage, - content=anthropic_content, - stop_reason=anthropic_finish_reason, - ) - - return translated_obj - - def _translate_streaming_openai_chunk_to_anthropic( - self, choices: List[OpenAIStreamingChoice] - ) -> Tuple[ - Literal["text_delta", "input_json_delta"], - Union[ContentTextBlockDelta, ContentJsonBlockDelta], - ]: - text: str = "" - partial_json: Optional[str] = None - for choice in choices: - if choice.delta.content is not None: - text += choice.delta.content - elif choice.delta.tool_calls is not None: - partial_json = "" - for tool in choice.delta.tool_calls: - if ( - tool.function is not None - and tool.function.arguments is not None - ): - partial_json += tool.function.arguments - - if partial_json is not None: - return "input_json_delta", ContentJsonBlockDelta( - type="input_json_delta", partial_json=partial_json - ) - else: - return "text_delta", ContentTextBlockDelta(type="text_delta", text=text) - - def translate_streaming_openai_response_to_anthropic( - self, response: ModelResponse - ) -> Union[ContentBlockDelta, MessageBlockDelta]: - ## base case - final chunk w/ finish reason - if response.choices[0].finish_reason is not None: - delta = MessageDelta( - stop_reason=self._translate_openai_finish_reason_to_anthropic( - response.choices[0].finish_reason - ), - ) - if getattr(response, "usage", None) is not None: - litellm_usage_chunk: Optional[Usage] = response.usage # type: ignore - elif ( - hasattr(response, "_hidden_params") - and "usage" in response._hidden_params - ): - litellm_usage_chunk = response._hidden_params["usage"] - else: - litellm_usage_chunk = None - if litellm_usage_chunk is not None: - usage_delta = UsageDelta( - input_tokens=litellm_usage_chunk.prompt_tokens or 0, - 
output_tokens=litellm_usage_chunk.completion_tokens or 0, - ) - else: - usage_delta = UsageDelta(input_tokens=0, output_tokens=0) - return MessageBlockDelta( - type="message_delta", delta=delta, usage=usage_delta - ) - ( - type_of_content, - content_block_delta, - ) = self._translate_streaming_openai_chunk_to_anthropic( - choices=response.choices # type: ignore - ) - return ContentBlockDelta( - type="content_block_delta", - index=response.choices[0].index, - delta=content_block_delta, - ) diff --git a/litellm/llms/base_llm/anthropic_messages/transformation.py b/litellm/llms/base_llm/anthropic_messages/transformation.py new file mode 100644 index 0000000000..7619ffbbf6 --- /dev/null +++ b/litellm/llms/base_llm/anthropic_messages/transformation.py @@ -0,0 +1,35 @@ +from abc import ABC, abstractmethod +from typing import TYPE_CHECKING, Any, Optional + +if TYPE_CHECKING: + from litellm.litellm_core_utils.litellm_logging import Logging as _LiteLLMLoggingObj + + LiteLLMLoggingObj = _LiteLLMLoggingObj +else: + LiteLLMLoggingObj = Any + + +class BaseAnthropicMessagesConfig(ABC): + @abstractmethod + def validate_environment( + self, + headers: dict, + model: str, + api_key: Optional[str] = None, + ) -> dict: + pass + + @abstractmethod + def get_complete_url(self, api_base: Optional[str], model: str) -> str: + """ + OPTIONAL + + Get the complete url for the request + + Some providers need `model` in `api_base` + """ + return api_base or "" + + @abstractmethod + def get_supported_anthropic_messages_params(self, model: str) -> list: + pass diff --git a/litellm/proxy/_types.py b/litellm/proxy/_types.py index ba27de78be..89f02f8228 100644 --- a/litellm/proxy/_types.py +++ b/litellm/proxy/_types.py @@ -1963,7 +1963,7 @@ class ProxyException(Exception): code: Optional[Union[int, str]] = None, headers: Optional[Dict[str, str]] = None, ): - self.message = message + self.message = str(message) self.type = type self.param = param diff --git a/litellm/proxy/anthropic_endpoints/endpoints.py b/litellm/proxy/anthropic_endpoints/endpoints.py new file mode 100644 index 0000000000..a3956ef274 --- /dev/null +++ b/litellm/proxy/anthropic_endpoints/endpoints.py @@ -0,0 +1,252 @@ +""" +Unified /v1/messages endpoint - (Anthropic Spec) +""" + +import asyncio +import json +import time +import traceback + +from fastapi import APIRouter, Depends, HTTPException, Request, Response, status +from fastapi.responses import StreamingResponse + +import litellm +from litellm._logging import verbose_proxy_logger +from litellm.proxy._types import * +from litellm.proxy.auth.user_api_key_auth import user_api_key_auth +from litellm.proxy.common_utils.http_parsing_utils import _read_request_body +from litellm.proxy.litellm_pre_call_utils import add_litellm_data_to_request +from litellm.proxy.utils import ProxyLogging + +router = APIRouter() + + +async def async_data_generator_anthropic( + response, + user_api_key_dict: UserAPIKeyAuth, + request_data: dict, + proxy_logging_obj: ProxyLogging, +): + verbose_proxy_logger.debug("inside generator") + try: + time.time() + async for chunk in response: + verbose_proxy_logger.debug( + "async_data_generator: received streaming chunk - {}".format(chunk) + ) + ### CALL HOOKS ### - modify outgoing data + chunk = await proxy_logging_obj.async_post_call_streaming_hook( + user_api_key_dict=user_api_key_dict, response=chunk + ) + + yield chunk + except Exception as e: + verbose_proxy_logger.exception( + "litellm.proxy.proxy_server.async_data_generator(): Exception occured - {}".format( + str(e) + ) + ) + 
await proxy_logging_obj.post_call_failure_hook( + user_api_key_dict=user_api_key_dict, + original_exception=e, + request_data=request_data, + ) + verbose_proxy_logger.debug( + f"\033[1;31mAn error occurred: {e}\n\n Debug this by setting `--debug`, e.g. `litellm --model gpt-3.5-turbo --debug`" + ) + + if isinstance(e, HTTPException): + raise e + else: + error_traceback = traceback.format_exc() + error_msg = f"{str(e)}\n\n{error_traceback}" + + proxy_exception = ProxyException( + message=getattr(e, "message", error_msg), + type=getattr(e, "type", "None"), + param=getattr(e, "param", "None"), + code=getattr(e, "status_code", 500), + ) + error_returned = json.dumps({"error": proxy_exception.to_dict()}) + yield f"data: {error_returned}\n\n" + + +@router.post( + "/v1/messages", + tags=["[beta] Anthropic `/v1/messages`"], + dependencies=[Depends(user_api_key_auth)], + include_in_schema=False, +) +async def anthropic_response( # noqa: PLR0915 + fastapi_response: Response, + request: Request, + user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth), +): + """ + Use `{PROXY_BASE_URL}/anthropic/v1/messages` instead - [Docs](https://docs.litellm.ai/docs/anthropic_completion). + + This was a BETA endpoint that calls 100+ LLMs in the anthropic format. + """ + from litellm.proxy.proxy_server import ( + general_settings, + get_custom_headers, + llm_router, + proxy_config, + proxy_logging_obj, + user_api_base, + user_max_tokens, + user_model, + user_request_timeout, + user_temperature, + version, + ) + + request_data = await _read_request_body(request=request) + data: dict = {**request_data} + try: + data["model"] = ( + general_settings.get("completion_model", None) # server default + or user_model # model name passed via cli args + or data.get("model", None) # default passed in http request + ) + if user_model: + data["model"] = user_model + + data = await add_litellm_data_to_request( + data=data, # type: ignore + request=request, + general_settings=general_settings, + user_api_key_dict=user_api_key_dict, + version=version, + proxy_config=proxy_config, + ) + + # override with user settings, these are params passed via cli + if user_temperature: + data["temperature"] = user_temperature + if user_request_timeout: + data["request_timeout"] = user_request_timeout + if user_max_tokens: + data["max_tokens"] = user_max_tokens + if user_api_base: + data["api_base"] = user_api_base + + ### MODEL ALIAS MAPPING ### + # check if model name in model alias map + # get the actual model name + if data["model"] in litellm.model_alias_map: + data["model"] = litellm.model_alias_map[data["model"]] + + ### CALL HOOKS ### - modify incoming data before calling the model + data = await proxy_logging_obj.pre_call_hook( # type: ignore + user_api_key_dict=user_api_key_dict, data=data, call_type="text_completion" + ) + + ### ROUTE THE REQUESTs ### + router_model_names = llm_router.model_names if llm_router is not None else [] + + # skip router if user passed their key + if ( + llm_router is not None and data["model"] in router_model_names + ): # model in router model list + llm_response = asyncio.create_task(llm_router.aanthropic_messages(**data)) + elif ( + llm_router is not None + and llm_router.model_group_alias is not None + and data["model"] in llm_router.model_group_alias + ): # model set in model_group_alias + llm_response = asyncio.create_task(llm_router.aanthropic_messages(**data)) + elif ( + llm_router is not None and data["model"] in llm_router.deployment_names + ): # model in router deployments, calling a 
specific deployment on the router + llm_response = asyncio.create_task( + llm_router.aanthropic_messages(**data, specific_deployment=True) + ) + elif ( + llm_router is not None and data["model"] in llm_router.get_model_ids() + ): # model in router model list + llm_response = asyncio.create_task(llm_router.aanthropic_messages(**data)) + elif ( + llm_router is not None + and data["model"] not in router_model_names + and ( + llm_router.default_deployment is not None + or len(llm_router.pattern_router.patterns) > 0 + ) + ): # model in router deployments, calling a specific deployment on the router + llm_response = asyncio.create_task(llm_router.aanthropic_messages(**data)) + elif user_model is not None: # `litellm --model ` + llm_response = asyncio.create_task(litellm.anthropic_messages(**data)) + else: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail={ + "error": "completion: Invalid model name passed in model=" + + data.get("model", "") + }, + ) + + # Await the llm_response task + response = await llm_response + + hidden_params = getattr(response, "_hidden_params", {}) or {} + model_id = hidden_params.get("model_id", None) or "" + cache_key = hidden_params.get("cache_key", None) or "" + api_base = hidden_params.get("api_base", None) or "" + response_cost = hidden_params.get("response_cost", None) or "" + + ### ALERTING ### + asyncio.create_task( + proxy_logging_obj.update_request_status( + litellm_call_id=data.get("litellm_call_id", ""), status="success" + ) + ) + + verbose_proxy_logger.debug("final response: %s", response) + + fastapi_response.headers.update( + get_custom_headers( + user_api_key_dict=user_api_key_dict, + model_id=model_id, + cache_key=cache_key, + api_base=api_base, + version=version, + response_cost=response_cost, + request_data=data, + hidden_params=hidden_params, + ) + ) + + if ( + "stream" in data and data["stream"] is True + ): # use generate_responses to stream responses + selected_data_generator = async_data_generator_anthropic( + response=response, + user_api_key_dict=user_api_key_dict, + request_data=data, + proxy_logging_obj=proxy_logging_obj, + ) + + return StreamingResponse( + selected_data_generator, # type: ignore + media_type="text/event-stream", + ) + + verbose_proxy_logger.info("\nResponse from Litellm:\n{}".format(response)) + return response + except Exception as e: + await proxy_logging_obj.post_call_failure_hook( + user_api_key_dict=user_api_key_dict, original_exception=e, request_data=data + ) + verbose_proxy_logger.exception( + "litellm.proxy.proxy_server.anthropic_response(): Exception occured - {}".format( + str(e) + ) + ) + error_msg = f"{str(e)}" + raise ProxyException( + message=getattr(e, "message", error_msg), + type=getattr(e, "type", "None"), + param=getattr(e, "param", "None"), + code=getattr(e, "status_code", 500), + ) diff --git a/litellm/proxy/example_config_yaml/pass_through_config.yaml b/litellm/proxy/example_config_yaml/pass_through_config.yaml index 41d581249f..ccc13f4d5a 100644 --- a/litellm/proxy/example_config_yaml/pass_through_config.yaml +++ b/litellm/proxy/example_config_yaml/pass_through_config.yaml @@ -1,9 +1,29 @@ model_list: - - model_name: fake-openai-endpoint - litellm_params: - model: openai/fake - api_key: fake-key - api_base: https://exampleopenaiendpoint-production.up.railway.app/ + - model_name: fake-openai-endpoint + litellm_params: + model: openai/fake + api_key: fake-key + api_base: https://exampleopenaiendpoint-production.up.railway.app/ + - model_name: claude-3-5-sonnet-20241022 + 
litellm_params: + model: anthropic/claude-3-5-sonnet-20241022 + api_key: os.environ/ANTHROPIC_API_KEY + - model_name: claude-special-alias + litellm_params: + model: anthropic/claude-3-haiku-20240307 + api_key: os.environ/ANTHROPIC_API_KEY + - model_name: claude-3-5-sonnet-20241022 + litellm_params: + model: anthropic/claude-3-5-sonnet-20241022 + api_key: os.environ/ANTHROPIC_API_KEY + - model_name: claude-3-7-sonnet-20250219 + litellm_params: + model: anthropic/claude-3-7-sonnet-20250219 + api_key: os.environ/ANTHROPIC_API_KEY + - model_name: anthropic/* + litellm_params: + model: anthropic/* + api_key: os.environ/ANTHROPIC_API_KEY general_settings: - master_key: sk-1234 - custom_auth: custom_auth_basic.user_api_key_auth \ No newline at end of file + master_key: sk-1234 + custom_auth: custom_auth_basic.user_api_key_auth \ No newline at end of file diff --git a/litellm/proxy/proxy_config.yaml b/litellm/proxy/proxy_config.yaml index 86c32b610a..eef4a55ed3 100644 --- a/litellm/proxy/proxy_config.yaml +++ b/litellm/proxy/proxy_config.yaml @@ -4,7 +4,22 @@ model_list: model: openai/my-fake-model api_key: my-fake-key api_base: https://exampleopenaiendpoint-production.up.railway.app/ - + - model_name: claude-special-alias + litellm_params: + model: anthropic/claude-3-haiku-20240307 + api_key: os.environ/ANTHROPIC_API_KEY + - model_name: claude-3-5-sonnet-20241022 + litellm_params: + model: anthropic/claude-3-5-sonnet-20241022 + api_key: os.environ/ANTHROPIC_API_KEY + - model_name: claude-3-7-sonnet-20250219 + litellm_params: + model: anthropic/claude-3-7-sonnet-20250219 + api_key: os.environ/ANTHROPIC_API_KEY + - model_name: anthropic/* + litellm_params: + model: anthropic/* + api_key: os.environ/ANTHROPIC_API_KEY general_settings: store_model_in_db: true diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py index cf5e7543d7..0fda92b878 100644 --- a/litellm/proxy/proxy_server.py +++ b/litellm/proxy/proxy_server.py @@ -120,6 +120,7 @@ from litellm.proxy._types import * from litellm.proxy.analytics_endpoints.analytics_endpoints import ( router as analytics_router, ) +from litellm.proxy.anthropic_endpoints.endpoints import router as anthropic_router from litellm.proxy.auth.auth_checks import log_db_metrics from litellm.proxy.auth.auth_utils import check_response_size_is_safe from litellm.proxy.auth.handle_jwt import JWTHandler @@ -3065,58 +3066,6 @@ async def async_data_generator( yield f"data: {error_returned}\n\n" -async def async_data_generator_anthropic( - response, user_api_key_dict: UserAPIKeyAuth, request_data: dict -): - verbose_proxy_logger.debug("inside generator") - try: - time.time() - async for chunk in response: - verbose_proxy_logger.debug( - "async_data_generator: received streaming chunk - {}".format(chunk) - ) - ### CALL HOOKS ### - modify outgoing data - chunk = await proxy_logging_obj.async_post_call_streaming_hook( - user_api_key_dict=user_api_key_dict, response=chunk - ) - - event_type = chunk.get("type") - - try: - yield f"event: {event_type}\ndata:{json.dumps(chunk)}\n\n" - except Exception as e: - yield f"event: {event_type}\ndata:{str(e)}\n\n" - except Exception as e: - verbose_proxy_logger.exception( - "litellm.proxy.proxy_server.async_data_generator(): Exception occured - {}".format( - str(e) - ) - ) - await proxy_logging_obj.post_call_failure_hook( - user_api_key_dict=user_api_key_dict, - original_exception=e, - request_data=request_data, - ) - verbose_proxy_logger.debug( - f"\033[1;31mAn error occurred: {e}\n\n Debug this by setting `--debug`, e.g. 
`litellm --model gpt-3.5-turbo --debug`" - ) - - if isinstance(e, HTTPException): - raise e - else: - error_traceback = traceback.format_exc() - error_msg = f"{str(e)}\n\n{error_traceback}" - - proxy_exception = ProxyException( - message=getattr(e, "message", error_msg), - type=getattr(e, "type", "None"), - param=getattr(e, "param", "None"), - code=getattr(e, "status_code", 500), - ) - error_returned = json.dumps({"error": proxy_exception.to_dict()}) - yield f"data: {error_returned}\n\n" - - def select_data_generator( response, user_api_key_dict: UserAPIKeyAuth, request_data: dict ): @@ -5524,224 +5473,6 @@ async def moderations( ) -#### ANTHROPIC ENDPOINTS #### - - -@router.post( - "/v1/messages", - tags=["[beta] Anthropic `/v1/messages`"], - dependencies=[Depends(user_api_key_auth)], - response_model=AnthropicResponse, - include_in_schema=False, -) -async def anthropic_response( # noqa: PLR0915 - anthropic_data: AnthropicMessagesRequest, - fastapi_response: Response, - request: Request, - user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth), -): - """ - 🚨 DEPRECATED ENDPOINT🚨 - - Use `{PROXY_BASE_URL}/anthropic/v1/messages` instead - [Docs](https://docs.litellm.ai/docs/anthropic_completion). - - This was a BETA endpoint that calls 100+ LLMs in the anthropic format. - """ - from litellm import adapter_completion - from litellm.adapters.anthropic_adapter import anthropic_adapter - - litellm.adapters = [{"id": "anthropic", "adapter": anthropic_adapter}] - - global user_temperature, user_request_timeout, user_max_tokens, user_api_base - request_data = await _read_request_body(request=request) - data: dict = {**request_data, "adapter_id": "anthropic"} - try: - data["model"] = ( - general_settings.get("completion_model", None) # server default - or user_model # model name passed via cli args - or data.get("model", None) # default passed in http request - ) - if user_model: - data["model"] = user_model - - data = await add_litellm_data_to_request( - data=data, # type: ignore - request=request, - general_settings=general_settings, - user_api_key_dict=user_api_key_dict, - version=version, - proxy_config=proxy_config, - ) - - # override with user settings, these are params passed via cli - if user_temperature: - data["temperature"] = user_temperature - if user_request_timeout: - data["request_timeout"] = user_request_timeout - if user_max_tokens: - data["max_tokens"] = user_max_tokens - if user_api_base: - data["api_base"] = user_api_base - - ### MODEL ALIAS MAPPING ### - # check if model name in model alias map - # get the actual model name - if data["model"] in litellm.model_alias_map: - data["model"] = litellm.model_alias_map[data["model"]] - - ### CALL HOOKS ### - modify incoming data before calling the model - data = await proxy_logging_obj.pre_call_hook( # type: ignore - user_api_key_dict=user_api_key_dict, data=data, call_type="text_completion" - ) - - ### ROUTE THE REQUESTs ### - router_model_names = llm_router.model_names if llm_router is not None else [] - # skip router if user passed their key - if "api_key" in data: - llm_response = asyncio.create_task(litellm.aadapter_completion(**data)) - elif ( - llm_router is not None and data["model"] in router_model_names - ): # model in router model list - llm_response = asyncio.create_task(llm_router.aadapter_completion(**data)) - elif ( - llm_router is not None - and llm_router.model_group_alias is not None - and data["model"] in llm_router.model_group_alias - ): # model set in model_group_alias - llm_response = 
asyncio.create_task(llm_router.aadapter_completion(**data)) - elif ( - llm_router is not None and data["model"] in llm_router.deployment_names - ): # model in router deployments, calling a specific deployment on the router - llm_response = asyncio.create_task( - llm_router.aadapter_completion(**data, specific_deployment=True) - ) - elif ( - llm_router is not None and data["model"] in llm_router.get_model_ids() - ): # model in router model list - llm_response = asyncio.create_task(llm_router.aadapter_completion(**data)) - elif ( - llm_router is not None - and data["model"] not in router_model_names - and ( - llm_router.default_deployment is not None - or len(llm_router.pattern_router.patterns) > 0 - ) - ): # model in router deployments, calling a specific deployment on the router - llm_response = asyncio.create_task(llm_router.aadapter_completion(**data)) - elif user_model is not None: # `litellm --model ` - llm_response = asyncio.create_task(litellm.aadapter_completion(**data)) - else: - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail={ - "error": "completion: Invalid model name passed in model=" - + data.get("model", "") - }, - ) - - # Await the llm_response task - response = await llm_response - - hidden_params = getattr(response, "_hidden_params", {}) or {} - model_id = hidden_params.get("model_id", None) or "" - cache_key = hidden_params.get("cache_key", None) or "" - api_base = hidden_params.get("api_base", None) or "" - response_cost = hidden_params.get("response_cost", None) or "" - - ### ALERTING ### - asyncio.create_task( - proxy_logging_obj.update_request_status( - litellm_call_id=data.get("litellm_call_id", ""), status="success" - ) - ) - - verbose_proxy_logger.debug("final response: %s", response) - - fastapi_response.headers.update( - get_custom_headers( - user_api_key_dict=user_api_key_dict, - model_id=model_id, - cache_key=cache_key, - api_base=api_base, - version=version, - response_cost=response_cost, - request_data=data, - hidden_params=hidden_params, - ) - ) - - if ( - "stream" in data and data["stream"] is True - ): # use generate_responses to stream responses - selected_data_generator = async_data_generator_anthropic( - response=response, - user_api_key_dict=user_api_key_dict, - request_data=data, - ) - return StreamingResponse( - selected_data_generator, - media_type="text/event-stream", - ) - - verbose_proxy_logger.info("\nResponse from Litellm:\n{}".format(response)) - return response - except RejectedRequestError as e: - _data = e.request_data - await proxy_logging_obj.post_call_failure_hook( - user_api_key_dict=user_api_key_dict, - original_exception=e, - request_data=_data, - ) - if _data.get("stream", None) is not None and _data["stream"] is True: - _chat_response = litellm.ModelResponse() - _usage = litellm.Usage( - prompt_tokens=0, - completion_tokens=0, - total_tokens=0, - ) - _chat_response.usage = _usage # type: ignore - _chat_response.choices[0].message.content = e.message # type: ignore - _iterator = litellm.utils.ModelResponseIterator( - model_response=_chat_response, convert_to_delta=True - ) - _streaming_response = litellm.TextCompletionStreamWrapper( - completion_stream=_iterator, - model=_data.get("model", ""), - ) - - selected_data_generator = select_data_generator( - response=_streaming_response, - user_api_key_dict=user_api_key_dict, - request_data=data, - ) - - return StreamingResponse( - selected_data_generator, - media_type="text/event-stream", - headers={}, - ) - else: - _response = litellm.TextCompletionResponse() - 
_response.choices[0].text = e.message - return _response - except Exception as e: - await proxy_logging_obj.post_call_failure_hook( - user_api_key_dict=user_api_key_dict, original_exception=e, request_data=data - ) - verbose_proxy_logger.exception( - "litellm.proxy.proxy_server.anthropic_response(): Exception occured - {}".format( - str(e) - ) - ) - error_msg = f"{str(e)}" - raise ProxyException( - message=getattr(e, "message", error_msg), - type=getattr(e, "type", "None"), - param=getattr(e, "param", "None"), - code=getattr(e, "status_code", 500), - ) - - #### DEV UTILS #### # @router.get( @@ -8840,6 +8571,7 @@ app.include_router(rerank_router) app.include_router(fine_tuning_router) app.include_router(vertex_router) app.include_router(llm_passthrough_router) +app.include_router(anthropic_router) app.include_router(langfuse_router) app.include_router(pass_through_router) app.include_router(health_router) diff --git a/litellm/proxy/spend_tracking/spend_tracking_utils.py b/litellm/proxy/spend_tracking/spend_tracking_utils.py index 78304fa6b0..122432c787 100644 --- a/litellm/proxy/spend_tracking/spend_tracking_utils.py +++ b/litellm/proxy/spend_tracking/spend_tracking_utils.py @@ -10,6 +10,7 @@ from pydantic import BaseModel import litellm from litellm._logging import verbose_proxy_logger +from litellm.litellm_core_utils.core_helpers import get_litellm_metadata_from_kwargs from litellm.proxy._types import SpendLogsMetadata, SpendLogsPayload from litellm.proxy.utils import PrismaClient, hash_token from litellm.types.utils import StandardLoggingPayload @@ -119,9 +120,7 @@ def get_logging_payload( # noqa: PLR0915 response_obj = {} # standardize this function to be used across, s3, dynamoDB, langfuse logging litellm_params = kwargs.get("litellm_params", {}) - metadata = ( - litellm_params.get("metadata", {}) or {} - ) # if litellm_params['metadata'] == None + metadata = get_litellm_metadata_from_kwargs(kwargs) metadata = _add_proxy_server_request_to_metadata( metadata=metadata, litellm_params=litellm_params ) diff --git a/litellm/router.py b/litellm/router.py index 8b0bbd6b9f..9ddd596006 100644 --- a/litellm/router.py +++ b/litellm/router.py @@ -580,6 +580,9 @@ class Router: self.amoderation = self.factory_function( litellm.amoderation, call_type="moderation" ) + self.aanthropic_messages = self.factory_function( + litellm.anthropic_messages, call_type="anthropic_messages" + ) def discard(self): """ @@ -2349,6 +2352,89 @@ class Router: self.fail_calls[model] += 1 raise e + async def _ageneric_api_call_with_fallbacks( + self, model: str, original_function: Callable, **kwargs + ): + """ + Make a generic LLM API call through the router, this allows you to use retries/fallbacks with litellm router + + Args: + model: The model to use + handler_function: The handler function to call (e.g., litellm.anthropic_messages) + **kwargs: Additional arguments to pass to the handler function + + Returns: + The response from the handler function + """ + handler_name = original_function.__name__ + try: + verbose_router_logger.debug( + f"Inside _ageneric_api_call() - handler: {handler_name}, model: {model}; kwargs: {kwargs}" + ) + parent_otel_span = _get_parent_otel_span_from_kwargs(kwargs) + deployment = await self.async_get_available_deployment( + model=model, + request_kwargs=kwargs, + messages=kwargs.get("messages", None), + specific_deployment=kwargs.pop("specific_deployment", None), + ) + self._update_kwargs_with_deployment(deployment=deployment, kwargs=kwargs) + + data = deployment["litellm_params"].copy() + 
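# ---------------------------------------------------------------------------
# Editor's note: illustrative sketch, not part of this diff. The new
# `Router.aanthropic_messages()` wiring above routes through this generic
# fallback helper, so deployment selection and pre-call checks apply as they
# do for chat completions. The model name and env-var based key below are
# assumptions.
import asyncio
import os

from litellm import Router


async def _example_router_anthropic_messages():
    router = Router(
        model_list=[
            {
                "model_name": "claude-3-5-sonnet-20241022",
                "litellm_params": {
                    "model": "anthropic/claude-3-5-sonnet-20241022",
                    "api_key": os.environ["ANTHROPIC_API_KEY"],
                },
            }
        ]
    )
    response = await router.aanthropic_messages(
        model="claude-3-5-sonnet-20241022",
        messages=[{"role": "user", "content": "Hello"}],
        max_tokens=256,
    )
    print(response)


if __name__ == "__main__":
    asyncio.run(_example_router_anthropic_messages())
# ---------------------------------------------------------------------------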
model_name = data["model"] + + model_client = self._get_async_openai_model_client( + deployment=deployment, + kwargs=kwargs, + ) + self.total_calls[model_name] += 1 + + response = original_function( + **{ + **data, + "caching": self.cache_responses, + "client": model_client, + **kwargs, + } + ) + + rpm_semaphore = self._get_client( + deployment=deployment, + kwargs=kwargs, + client_type="max_parallel_requests", + ) + + if rpm_semaphore is not None and isinstance( + rpm_semaphore, asyncio.Semaphore + ): + async with rpm_semaphore: + """ + - Check rpm limits before making the call + - If allowed, increment the rpm limit (allows global value to be updated, concurrency-safe) + """ + await self.async_routing_strategy_pre_call_checks( + deployment=deployment, parent_otel_span=parent_otel_span + ) + response = await response # type: ignore + else: + await self.async_routing_strategy_pre_call_checks( + deployment=deployment, parent_otel_span=parent_otel_span + ) + response = await response # type: ignore + + self.success_calls[model_name] += 1 + verbose_router_logger.info( + f"{handler_name}(model={model_name})\033[32m 200 OK\033[0m" + ) + return response + except Exception as e: + verbose_router_logger.info( + f"{handler_name}(model={model})\033[31m Exception {str(e)}\033[0m" + ) + if model is not None: + self.fail_calls[model] += 1 + raise e + def embedding( self, model: str, @@ -2869,10 +2955,14 @@ class Router: def factory_function( self, original_function: Callable, - call_type: Literal["assistants", "moderation"] = "assistants", + call_type: Literal[ + "assistants", "moderation", "anthropic_messages" + ] = "assistants", ): async def new_function( - custom_llm_provider: Optional[Literal["openai", "azure"]] = None, + custom_llm_provider: Optional[ + Literal["openai", "azure", "anthropic"] + ] = None, client: Optional["AsyncOpenAI"] = None, **kwargs, ): @@ -2889,13 +2979,18 @@ class Router: original_function=original_function, **kwargs, ) + elif call_type == "anthropic_messages": + return await self._ageneric_api_call_with_fallbacks( # type: ignore + original_function=original_function, + **kwargs, + ) return new_function async def _pass_through_assistants_endpoint_factory( self, original_function: Callable, - custom_llm_provider: Optional[Literal["openai", "azure"]] = None, + custom_llm_provider: Optional[Literal["openai", "azure", "anthropic"]] = None, client: Optional[AsyncOpenAI] = None, **kwargs, ): diff --git a/litellm/types/utils.py b/litellm/types/utils.py index 7e6b95ab15..894ef70933 100644 --- a/litellm/types/utils.py +++ b/litellm/types/utils.py @@ -186,6 +186,7 @@ class CallTypes(Enum): aretrieve_batch = "aretrieve_batch" retrieve_batch = "retrieve_batch" pass_through = "pass_through_endpoint" + anthropic_messages = "anthropic_messages" CallTypesLiteral = Literal[ @@ -209,6 +210,7 @@ CallTypesLiteral = Literal[ "create_batch", "acreate_batch", "pass_through_endpoint", + "anthropic_messages", ] diff --git a/litellm/utils.py b/litellm/utils.py index a6dd10ad9a..469ad50059 100644 --- a/litellm/utils.py +++ b/litellm/utils.py @@ -191,6 +191,9 @@ from typing import ( from openai import OpenAIError as OriginalError from litellm.litellm_core_utils.thread_pool_executor import executor +from litellm.llms.base_llm.anthropic_messages.transformation import ( + BaseAnthropicMessagesConfig, +) from litellm.llms.base_llm.audio_transcription.transformation import ( BaseAudioTranscriptionConfig, ) @@ -6245,6 +6248,15 @@ class ProviderConfigManager: return litellm.JinaAIRerankConfig() return 
litellm.CohereRerankConfig() + @staticmethod + def get_provider_anthropic_messages_config( + model: str, + provider: LlmProviders, + ) -> Optional[BaseAnthropicMessagesConfig]: + if litellm.LlmProviders.ANTHROPIC == provider: + return litellm.AnthropicMessagesConfig() + return None + @staticmethod def get_provider_audio_transcription_config( model: str, diff --git a/tests/local_testing/test_pass_through_endpoints.py b/tests/local_testing/test_pass_through_endpoints.py index 7e9dfcfc79..0215e295be 100644 --- a/tests/local_testing/test_pass_through_endpoints.py +++ b/tests/local_testing/test_pass_through_endpoints.py @@ -329,57 +329,3 @@ async def test_aaapass_through_endpoint_pass_through_keys_langfuse( setattr( litellm.proxy.proxy_server, "proxy_logging_obj", original_proxy_logging_obj ) - - -@pytest.mark.asyncio -async def test_pass_through_endpoint_anthropic(client): - import litellm - from litellm import Router - from litellm.adapters.anthropic_adapter import anthropic_adapter - - router = Router( - model_list=[ - { - "model_name": "gpt-3.5-turbo", - "litellm_params": { - "model": "gpt-3.5-turbo", - "api_key": os.getenv("OPENAI_API_KEY"), - "mock_response": "Hey, how's it going?", - }, - } - ] - ) - - setattr(litellm.proxy.proxy_server, "llm_router", router) - - # Define a pass-through endpoint - pass_through_endpoints = [ - { - "path": "/v1/test-messages", - "target": anthropic_adapter, - "headers": {"litellm_user_api_key": "my-test-header"}, - } - ] - - # Initialize the pass-through endpoint - await initialize_pass_through_endpoints(pass_through_endpoints) - general_settings: Optional[dict] = ( - getattr(litellm.proxy.proxy_server, "general_settings", {}) or {} - ) - general_settings.update({"pass_through_endpoints": pass_through_endpoints}) - setattr(litellm.proxy.proxy_server, "general_settings", general_settings) - - _json_data = { - "model": "gpt-3.5-turbo", - "messages": [{"role": "user", "content": "Who are you?"}], - } - - # Make a request to the pass-through endpoint - response = client.post( - "/v1/test-messages", json=_json_data, headers={"my-test-header": "my-test-key"} - ) - - print("JSON response: ", _json_data) - - # Assert the response - assert response.status_code == 200 diff --git a/tests/pass_through_tests/base_anthropic_messages_test.py b/tests/pass_through_tests/base_anthropic_messages_test.py new file mode 100644 index 0000000000..aed267ac8a --- /dev/null +++ b/tests/pass_through_tests/base_anthropic_messages_test.py @@ -0,0 +1,145 @@ +from abc import ABC, abstractmethod + +import anthropic +import pytest + + +class BaseAnthropicMessagesTest(ABC): + """ + Abstract base test class that enforces a common test across all test classes. 
+ """ + + @abstractmethod + def get_client(self): + return anthropic.Anthropic() + + def test_anthropic_basic_completion(self): + print("making basic completion request to anthropic passthrough") + client = self.get_client() + response = client.messages.create( + model="claude-3-5-sonnet-20241022", + max_tokens=1024, + messages=[{"role": "user", "content": "Say 'hello test' and nothing else"}], + extra_body={ + "litellm_metadata": { + "tags": ["test-tag-1", "test-tag-2"], + } + }, + ) + print(response) + + def test_anthropic_streaming(self): + print("making streaming request to anthropic passthrough") + collected_output = [] + client = self.get_client() + with client.messages.stream( + max_tokens=10, + messages=[ + {"role": "user", "content": "Say 'hello stream test' and nothing else"} + ], + model="claude-3-5-sonnet-20241022", + extra_body={ + "litellm_metadata": { + "tags": ["test-tag-stream-1", "test-tag-stream-2"], + } + }, + ) as stream: + for text in stream.text_stream: + collected_output.append(text) + + full_response = "".join(collected_output) + print(full_response) + + def test_anthropic_messages_with_thinking(self): + print("making request to anthropic passthrough with thinking") + client = self.get_client() + response = client.messages.create( + model="claude-3-7-sonnet-20250219", + max_tokens=20000, + thinking={"type": "enabled", "budget_tokens": 16000}, + messages=[ + {"role": "user", "content": "Just pinging with thinking enabled"} + ], + ) + + print(response) + + # Verify the first content block is a thinking block + response_thinking = response.content[0].thinking + assert response_thinking is not None + assert len(response_thinking) > 0 + + def test_anthropic_streaming_with_thinking(self): + print("making streaming request to anthropic passthrough with thinking enabled") + collected_thinking = [] + collected_response = [] + client = self.get_client() + with client.messages.stream( + model="claude-3-7-sonnet-20250219", + max_tokens=20000, + thinking={"type": "enabled", "budget_tokens": 16000}, + messages=[ + {"role": "user", "content": "Just pinging with thinking enabled"} + ], + ) as stream: + for event in stream: + if event.type == "content_block_delta": + if event.delta.type == "thinking_delta": + collected_thinking.append(event.delta.thinking) + elif event.delta.type == "text_delta": + collected_response.append(event.delta.text) + + full_thinking = "".join(collected_thinking) + full_response = "".join(collected_response) + + print( + f"Thinking Response: {full_thinking[:100]}..." 
+ ) # Print first 100 chars of thinking + print(f"Response: {full_response}") + + # Verify we received thinking content + assert len(collected_thinking) > 0 + assert len(full_thinking) > 0 + + # Verify we also received a response + assert len(collected_response) > 0 + assert len(full_response) > 0 + + def test_bad_request_error_handling_streaming(self): + print("making request to anthropic passthrough with bad request") + try: + client = self.get_client() + response = client.messages.create( + model="claude-3-5-sonnet-20241022", + max_tokens=10, + stream=True, + messages=["hi"], + ) + print(response) + assert pytest.fail("Expected BadRequestError") + except anthropic.BadRequestError as e: + print("Got BadRequestError from anthropic, e=", e) + print(e.__cause__) + print(e.status_code) + print(e.response) + except Exception as e: + pytest.fail(f"Got unexpected exception: {e}") + + def test_bad_request_error_handling_non_streaming(self): + print("making request to anthropic passthrough with bad request") + try: + client = self.get_client() + response = client.messages.create( + model="claude-3-5-sonnet-20241022", + max_tokens=10, + messages=["hi"], + ) + print(response) + assert pytest.fail("Expected BadRequestError") + except anthropic.BadRequestError as e: + print("Got BadRequestError from anthropic, e=", e) + print(e.__cause__) + print(e.status_code) + print(e.response) + except Exception as e: + pytest.fail(f"Got unexpected exception: {e}") diff --git a/tests/pass_through_tests/test_anthropic_passthrough.py b/tests/pass_through_tests/test_anthropic_passthrough.py index c9cb0e0e55..82fd2815ae 100644 --- a/tests/pass_through_tests/test_anthropic_passthrough.py +++ b/tests/pass_through_tests/test_anthropic_passthrough.py @@ -8,48 +8,6 @@ import aiohttp import asyncio import json -client = anthropic.Anthropic( - base_url="http://0.0.0.0:4000/anthropic", api_key="sk-1234" -) - - -def test_anthropic_basic_completion(): - print("making basic completion request to anthropic passthrough") - response = client.messages.create( - model="claude-3-5-sonnet-20241022", - max_tokens=1024, - messages=[{"role": "user", "content": "Say 'hello test' and nothing else"}], - extra_body={ - "litellm_metadata": { - "tags": ["test-tag-1", "test-tag-2"], - } - }, - ) - print(response) - - -def test_anthropic_streaming(): - print("making streaming request to anthropic passthrough") - collected_output = [] - - with client.messages.stream( - max_tokens=10, - messages=[ - {"role": "user", "content": "Say 'hello stream test' and nothing else"} - ], - model="claude-3-5-sonnet-20241022", - extra_body={ - "litellm_metadata": { - "tags": ["test-tag-stream-1", "test-tag-stream-2"], - } - }, - ) as stream: - for text in stream.text_stream: - collected_output.append(text) - - full_response = "".join(collected_output) - print(full_response) - @pytest.mark.asyncio async def test_anthropic_basic_completion_with_headers(): diff --git a/tests/pass_through_tests/test_anthropic_passthrough_basic.py b/tests/pass_through_tests/test_anthropic_passthrough_basic.py new file mode 100644 index 0000000000..86d9381824 --- /dev/null +++ b/tests/pass_through_tests/test_anthropic_passthrough_basic.py @@ -0,0 +1,28 @@ +from base_anthropic_messages_test import BaseAnthropicMessagesTest +import anthropic + + +class TestAnthropicPassthroughBasic(BaseAnthropicMessagesTest): + + def get_client(self): + return anthropic.Anthropic( + base_url="http://0.0.0.0:4000/anthropic", + api_key="sk-1234", + ) + + +class 
TestAnthropicMessagesEndpoint(BaseAnthropicMessagesTest): + def get_client(self): + return anthropic.Anthropic( + base_url="http://0.0.0.0:4000", + api_key="sk-1234", + ) + + def test_anthropic_messages_to_wildcard_model(self): + client = self.get_client() + response = client.messages.create( + model="anthropic/claude-3-opus-20240229", + messages=[{"role": "user", "content": "Hello, world!"}], + max_tokens=100, + ) + print(response) diff --git a/tests/pass_through_unit_tests/test_anthropic_messages_passthrough.py b/tests/pass_through_unit_tests/test_anthropic_messages_passthrough.py new file mode 100644 index 0000000000..b5b3302acc --- /dev/null +++ b/tests/pass_through_unit_tests/test_anthropic_messages_passthrough.py @@ -0,0 +1,487 @@ +import json +import os +import sys +from datetime import datetime +from typing import AsyncIterator, Dict, Any +import asyncio +import unittest.mock +from unittest.mock import AsyncMock, MagicMock + +sys.path.insert( + 0, os.path.abspath("../..") +) # Adds the parent directory to the system path +import litellm +import pytest +from dotenv import load_dotenv +from litellm.llms.anthropic.experimental_pass_through.messages.handler import ( + anthropic_messages, +) +from typing import Optional +from litellm.types.utils import StandardLoggingPayload +from litellm.integrations.custom_logger import CustomLogger +from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler +from litellm.router import Router +import importlib + +# Load environment variables +load_dotenv() + + +@pytest.fixture(scope="session") +def event_loop(): + """Create an instance of the default event loop for each test session.""" + loop = asyncio.get_event_loop_policy().new_event_loop() + yield loop + loop.close() + + +@pytest.fixture(scope="function", autouse=True) +def setup_and_teardown(event_loop): # Add event_loop as a dependency + curr_dir = os.getcwd() + sys.path.insert(0, os.path.abspath("../..")) + + import litellm + from litellm import Router + + importlib.reload(litellm) + + # Set the event loop from the fixture + asyncio.set_event_loop(event_loop) + + print(litellm) + yield + + # Clean up any pending tasks + pending = asyncio.all_tasks(event_loop) + for task in pending: + task.cancel() + + # Run the event loop until all tasks are cancelled + if pending: + event_loop.run_until_complete(asyncio.gather(*pending, return_exceptions=True)) + + +def _validate_anthropic_response(response: Dict[str, Any]): + assert "id" in response + assert "content" in response + assert "model" in response + assert response["role"] == "assistant" + + +@pytest.mark.asyncio +async def test_anthropic_messages_non_streaming(): + """ + Test the anthropic_messages with non-streaming request + """ + # Get API key from environment + api_key = os.getenv("ANTHROPIC_API_KEY") + if not api_key: + pytest.skip("ANTHROPIC_API_KEY not found in environment") + + # Set up test parameters + messages = [{"role": "user", "content": "Hello, can you tell me a short joke?"}] + + # Call the handler + response = await anthropic_messages( + messages=messages, + api_key=api_key, + model="claude-3-haiku-20240307", + max_tokens=100, + ) + + # Verify response + assert "id" in response + assert "content" in response + assert "model" in response + assert response["role"] == "assistant" + + print(f"Non-streaming response: {json.dumps(response, indent=2)}") + return response + + +@pytest.mark.asyncio +async def test_anthropic_messages_streaming(): + """ + Test the anthropic_messages with streaming request + """ + # Get API key from 
environment + api_key = os.getenv("ANTHROPIC_API_KEY") + if not api_key: + pytest.skip("ANTHROPIC_API_KEY not found in environment") + + # Set up test parameters + messages = [{"role": "user", "content": "Hello, can you tell me a short joke?"}] + + # Call the handler + async_httpx_client = AsyncHTTPHandler() + response = await anthropic_messages( + messages=messages, + api_key=api_key, + model="claude-3-haiku-20240307", + max_tokens=100, + stream=True, + client=async_httpx_client, + ) + + if isinstance(response, AsyncIterator): + async for chunk in response: + print("chunk=", chunk) + + +@pytest.mark.asyncio +async def test_anthropic_messages_streaming_with_bad_request(): + """ + Test the anthropic_messages with streaming request + """ + try: + response = await anthropic_messages( + messages=["hi"], + api_key=os.getenv("ANTHROPIC_API_KEY"), + model="claude-3-haiku-20240307", + max_tokens=100, + stream=True, + ) + print(response) + async for chunk in response: + print("chunk=", chunk) + except Exception as e: + print("got exception", e) + print("vars", vars(e)) + assert e.status_code == 400 + + +@pytest.mark.asyncio +async def test_anthropic_messages_router_streaming_with_bad_request(): + """ + Test the anthropic_messages with streaming request + """ + try: + router = Router( + model_list=[ + { + "model_name": "claude-special-alias", + "litellm_params": { + "model": "claude-3-haiku-20240307", + "api_key": os.getenv("ANTHROPIC_API_KEY"), + }, + } + ] + ) + + response = await router.aanthropic_messages( + messages=["hi"], + model="claude-special-alias", + max_tokens=100, + stream=True, + ) + print(response) + async for chunk in response: + print("chunk=", chunk) + except Exception as e: + print("got exception", e) + print("vars", vars(e)) + assert e.status_code == 400 + + +@pytest.mark.asyncio +async def test_anthropic_messages_litellm_router_non_streaming(): + """ + Test the anthropic_messages with non-streaming request + """ + litellm._turn_on_debug() + router = Router( + model_list=[ + { + "model_name": "claude-special-alias", + "litellm_params": { + "model": "claude-3-haiku-20240307", + "api_key": os.getenv("ANTHROPIC_API_KEY"), + }, + } + ] + ) + + # Set up test parameters + messages = [{"role": "user", "content": "Hello, can you tell me a short joke?"}] + + # Call the handler + response = await router.aanthropic_messages( + messages=messages, + model="claude-special-alias", + max_tokens=100, + ) + + # Verify response + assert "id" in response + assert "content" in response + assert "model" in response + assert response["role"] == "assistant" + + print(f"Non-streaming response: {json.dumps(response, indent=2)}") + return response + + +class TestCustomLogger(CustomLogger): + def __init__(self): + super().__init__() + self.logged_standard_logging_payload: Optional[StandardLoggingPayload] = None + + async def async_log_success_event(self, kwargs, response_obj, start_time, end_time): + print("inside async_log_success_event") + self.logged_standard_logging_payload = kwargs.get("standard_logging_object") + + pass + + +@pytest.mark.asyncio +async def test_anthropic_messages_litellm_router_non_streaming_with_logging(): + """ + Test the anthropic_messages with non-streaming request + + - Ensure Cost + Usage is tracked + """ + test_custom_logger = TestCustomLogger() + litellm.callbacks = [test_custom_logger] + litellm._turn_on_debug() + router = Router( + model_list=[ + { + "model_name": "claude-special-alias", + "litellm_params": { + "model": "claude-3-haiku-20240307", + "api_key": 
os.getenv("ANTHROPIC_API_KEY"), + }, + } + ] + ) + + # Set up test parameters + messages = [{"role": "user", "content": "Hello, can you tell me a short joke?"}] + + # Call the handler + response = await router.aanthropic_messages( + messages=messages, + model="claude-special-alias", + max_tokens=100, + ) + + # Verify response + _validate_anthropic_response(response) + + print(f"Non-streaming response: {json.dumps(response, indent=2)}") + + await asyncio.sleep(1) + assert test_custom_logger.logged_standard_logging_payload["messages"] == messages + assert test_custom_logger.logged_standard_logging_payload["response"] is not None + assert ( + test_custom_logger.logged_standard_logging_payload["model"] + == "claude-3-haiku-20240307" + ) + + # check logged usage + spend + assert test_custom_logger.logged_standard_logging_payload["response_cost"] > 0 + assert ( + test_custom_logger.logged_standard_logging_payload["prompt_tokens"] + == response["usage"]["input_tokens"] + ) + assert ( + test_custom_logger.logged_standard_logging_payload["completion_tokens"] + == response["usage"]["output_tokens"] + ) + + +@pytest.mark.asyncio +async def test_anthropic_messages_litellm_router_streaming_with_logging(): + """ + Test the anthropic_messages with streaming request + + - Ensure Cost + Usage is tracked + """ + test_custom_logger = TestCustomLogger() + litellm.callbacks = [test_custom_logger] + # litellm._turn_on_debug() + router = Router( + model_list=[ + { + "model_name": "claude-special-alias", + "litellm_params": { + "model": "claude-3-haiku-20240307", + "api_key": os.getenv("ANTHROPIC_API_KEY"), + }, + } + ] + ) + + # Set up test parameters + messages = [{"role": "user", "content": "Hello, can you tell me a short joke?"}] + + # Call the handler + response = await router.aanthropic_messages( + messages=messages, + model="claude-special-alias", + max_tokens=100, + stream=True, + ) + + response_prompt_tokens = 0 + response_completion_tokens = 0 + all_anthropic_usage_chunks = [] + + async for chunk in response: + # Decode chunk if it's bytes + print("chunk=", chunk) + + # Handle SSE format chunks + if isinstance(chunk, bytes): + chunk_str = chunk.decode("utf-8") + # Extract the JSON data part from SSE format + for line in chunk_str.split("\n"): + if line.startswith("data: "): + try: + json_data = json.loads(line[6:]) # Skip the 'data: ' prefix + print( + "\n\nJSON data:", + json.dumps(json_data, indent=4, default=str), + ) + + # Extract usage information + if ( + json_data.get("type") == "message_start" + and "message" in json_data + ): + if "usage" in json_data["message"]: + usage = json_data["message"]["usage"] + all_anthropic_usage_chunks.append(usage) + print( + "USAGE BLOCK", + json.dumps(usage, indent=4, default=str), + ) + elif "usage" in json_data: + usage = json_data["usage"] + all_anthropic_usage_chunks.append(usage) + print( + "USAGE BLOCK", json.dumps(usage, indent=4, default=str) + ) + except json.JSONDecodeError: + print(f"Failed to parse JSON from: {line[6:]}") + elif hasattr(chunk, "message"): + if chunk.message.usage: + print( + "USAGE BLOCK", + json.dumps(chunk.message.usage, indent=4, default=str), + ) + all_anthropic_usage_chunks.append(chunk.message.usage) + elif hasattr(chunk, "usage"): + print("USAGE BLOCK", json.dumps(chunk.usage, indent=4, default=str)) + all_anthropic_usage_chunks.append(chunk.usage) + + print( + "all_anthropic_usage_chunks", + json.dumps(all_anthropic_usage_chunks, indent=4, default=str), + ) + + # Extract token counts from usage data + if 
all_anthropic_usage_chunks: + response_prompt_tokens = max( + [usage.get("input_tokens", 0) for usage in all_anthropic_usage_chunks] + ) + response_completion_tokens = max( + [usage.get("output_tokens", 0) for usage in all_anthropic_usage_chunks] + ) + + print("input_tokens_anthropic_api", response_prompt_tokens) + print("output_tokens_anthropic_api", response_completion_tokens) + + await asyncio.sleep(4) + + print( + "logged_standard_logging_payload", + json.dumps( + test_custom_logger.logged_standard_logging_payload, indent=4, default=str + ), + ) + + assert test_custom_logger.logged_standard_logging_payload["messages"] == messages + assert test_custom_logger.logged_standard_logging_payload["response"] is not None + assert ( + test_custom_logger.logged_standard_logging_payload["model"] + == "claude-3-haiku-20240307" + ) + + # check logged usage + spend + assert test_custom_logger.logged_standard_logging_payload["response_cost"] > 0 + assert ( + test_custom_logger.logged_standard_logging_payload["prompt_tokens"] + == response_prompt_tokens + ) + assert ( + test_custom_logger.logged_standard_logging_payload["completion_tokens"] + == response_completion_tokens + ) + + +@pytest.mark.asyncio +async def test_anthropic_messages_with_extra_headers(): + """ + Test the anthropic_messages with extra headers + """ + # Get API key from environment + api_key = os.getenv("ANTHROPIC_API_KEY", "fake-api-key") + + # Set up test parameters + messages = [{"role": "user", "content": "Hello, can you tell me a short joke?"}] + extra_headers = { + "anthropic-beta": "very-custom-beta-value", + "anthropic-version": "custom-version-for-test", + } + + # Create a mock response + mock_response = MagicMock() + mock_response.raise_for_status = MagicMock() + mock_response.json.return_value = { + "id": "msg_123456", + "type": "message", + "role": "assistant", + "content": [ + { + "type": "text", + "text": "Why did the chicken cross the road? 
To get to the other side!", + } + ], + "model": "claude-3-haiku-20240307", + "stop_reason": "end_turn", + "usage": {"input_tokens": 10, "output_tokens": 20}, + } + + # Create a mock client with AsyncMock for the post method + mock_client = MagicMock(spec=AsyncHTTPHandler) + mock_client.post = AsyncMock(return_value=mock_response) + + # Call the handler with extra_headers and our mocked client + response = await anthropic_messages( + messages=messages, + api_key=api_key, + model="claude-3-haiku-20240307", + max_tokens=100, + client=mock_client, + provider_specific_header={ + "custom_llm_provider": "anthropic", + "extra_headers": extra_headers, + }, + ) + + # Verify the post method was called with the right parameters + mock_client.post.assert_called_once() + call_kwargs = mock_client.post.call_args.kwargs + + # Verify headers were passed correctly + headers = call_kwargs.get("headers", {}) + print("HEADERS IN REQUEST", headers) + for key, value in extra_headers.items(): + assert key in headers + assert headers[key] == value + + # Verify the response was processed correctly + assert response == mock_response.json.return_value + + return response diff --git a/tests/pass_through_unit_tests/test_unit_test_vertex_pass_through.py b/tests/pass_through_unit_tests/test_unit_test_vertex_pass_through.py index d82cba8a11..ba5dfa33a8 100644 --- a/tests/pass_through_unit_tests/test_unit_test_vertex_pass_through.py +++ b/tests/pass_through_unit_tests/test_unit_test_vertex_pass_through.py @@ -54,7 +54,7 @@ async def test_get_litellm_virtual_key(): @pytest.mark.asyncio -async def test_vertex_proxy_route_api_key_auth(): +async def test_async_vertex_proxy_route_api_key_auth(): """ Critical @@ -207,7 +207,7 @@ async def test_get_vertex_credentials_stored(): router.add_vertex_credentials( project_id="test-project", location="us-central1", - vertex_credentials="test-creds", + vertex_credentials='{"credentials": "test-creds"}', ) creds = router.get_vertex_credentials( @@ -215,7 +215,7 @@ async def test_get_vertex_credentials_stored(): ) assert creds.vertex_project == "test-project" assert creds.vertex_location == "us-central1" - assert creds.vertex_credentials == "test-creds" + assert creds.vertex_credentials == '{"credentials": "test-creds"}' @pytest.mark.asyncio @@ -227,18 +227,20 @@ async def test_add_vertex_credentials(): router.add_vertex_credentials( project_id="test-project", location="us-central1", - vertex_credentials="test-creds", + vertex_credentials='{"credentials": "test-creds"}', ) assert "test-project-us-central1" in router.deployment_key_to_vertex_credentials creds = router.deployment_key_to_vertex_credentials["test-project-us-central1"] assert creds.vertex_project == "test-project" assert creds.vertex_location == "us-central1" - assert creds.vertex_credentials == "test-creds" + assert creds.vertex_credentials == '{"credentials": "test-creds"}' # Test adding with None values router.add_vertex_credentials( - project_id=None, location=None, vertex_credentials="test-creds" + project_id=None, + location=None, + vertex_credentials='{"credentials": "test-creds"}', ) # Should not add None values assert len(router.deployment_key_to_vertex_credentials) == 1 diff --git a/tests/router_unit_tests/test_router_endpoints.py b/tests/router_unit_tests/test_router_endpoints.py index 99164827cc..e80b7dc3a8 100644 --- a/tests/router_unit_tests/test_router_endpoints.py +++ b/tests/router_unit_tests/test_router_endpoints.py @@ -6,6 +6,7 @@ from typing import Optional from dotenv import load_dotenv from fastapi import 
Request from datetime import datetime +from unittest.mock import AsyncMock, patch sys.path.insert( 0, os.path.abspath("../..") @@ -289,43 +290,6 @@ async def test_aaaaatext_completion_endpoint(model_list, sync_mode): assert response.choices[0].text == "I'm fine, thank you!" -@pytest.mark.asyncio -async def test_anthropic_router_completion_e2e(model_list): - from litellm.adapters.anthropic_adapter import anthropic_adapter - from litellm.types.llms.anthropic import AnthropicResponse - - litellm.set_verbose = True - - litellm.adapters = [{"id": "anthropic", "adapter": anthropic_adapter}] - - router = Router(model_list=model_list) - messages = [{"role": "user", "content": "Hey, how's it going?"}] - - ## Test 1: user facing function - response = await router.aadapter_completion( - model="claude-3-5-sonnet-20240620", - messages=messages, - adapter_id="anthropic", - mock_response="This is a fake call", - ) - - ## Test 2: underlying function - await router._aadapter_completion( - model="claude-3-5-sonnet-20240620", - messages=messages, - adapter_id="anthropic", - mock_response="This is a fake call", - ) - - print("Response: {}".format(response)) - - assert response is not None - - AnthropicResponse.model_validate(response) - - assert response.model == "gpt-3.5-turbo" - - @pytest.mark.asyncio async def test_router_with_empty_choices(model_list): """ @@ -349,3 +313,200 @@ async def test_router_with_empty_choices(model_list): mock_response=mock_response, ) assert response is not None + + +@pytest.mark.asyncio +async def test_ageneric_api_call_with_fallbacks_basic(): + """ + Test the _ageneric_api_call_with_fallbacks method with a basic successful call + """ + # Create a mock function that will be passed to _ageneric_api_call_with_fallbacks + mock_function = AsyncMock() + mock_function.__name__ = "test_function" + + # Create a mock response + mock_response = { + "id": "resp_123456", + "role": "assistant", + "content": "This is a test response", + "model": "test-model", + "usage": {"input_tokens": 10, "output_tokens": 20}, + } + mock_function.return_value = mock_response + + # Create a router with a test model + router = Router( + model_list=[ + { + "model_name": "test-model-alias", + "litellm_params": { + "model": "anthropic/test-model", + "api_key": "fake-api-key", + }, + } + ] + ) + + # Call the _ageneric_api_call_with_fallbacks method + response = await router._ageneric_api_call_with_fallbacks( + model="test-model-alias", + original_function=mock_function, + messages=[{"role": "user", "content": "Hello"}], + max_tokens=100, + ) + + # Verify the mock function was called + mock_function.assert_called_once() + + # Verify the response + assert response == mock_response + + +@pytest.mark.asyncio +async def test_aadapter_completion(): + """ + Test the aadapter_completion method which uses async_function_with_fallbacks + """ + # Create a mock for the _aadapter_completion method + mock_response = { + "id": "adapter_resp_123", + "object": "adapter.completion", + "created": 1677858242, + "model": "test-model-with-adapter", + "choices": [ + { + "text": "This is a test adapter response", + "index": 0, + "finish_reason": "stop", + } + ], + "usage": {"prompt_tokens": 10, "completion_tokens": 20, "total_tokens": 30}, + } + + # Create a router with a patched _aadapter_completion method + with patch.object( + Router, "_aadapter_completion", new_callable=AsyncMock + ) as mock_method: + mock_method.return_value = mock_response + + router = Router( + model_list=[ + { + "model_name": "test-adapter-model", + 
"litellm_params": { + "model": "anthropic/test-model", + "api_key": "fake-api-key", + }, + } + ] + ) + + # Replace the async_function_with_fallbacks with a mock + router.async_function_with_fallbacks = AsyncMock(return_value=mock_response) + + # Call the aadapter_completion method + response = await router.aadapter_completion( + adapter_id="test-adapter-id", + model="test-adapter-model", + prompt="This is a test prompt", + max_tokens=100, + ) + + # Verify the response + assert response == mock_response + + # Verify async_function_with_fallbacks was called with the right parameters + router.async_function_with_fallbacks.assert_called_once() + call_kwargs = router.async_function_with_fallbacks.call_args.kwargs + assert call_kwargs["adapter_id"] == "test-adapter-id" + assert call_kwargs["model"] == "test-adapter-model" + assert call_kwargs["prompt"] == "This is a test prompt" + assert call_kwargs["max_tokens"] == 100 + assert call_kwargs["original_function"] == router._aadapter_completion + assert "metadata" in call_kwargs + assert call_kwargs["metadata"]["model_group"] == "test-adapter-model" + + +@pytest.mark.asyncio +async def test__aadapter_completion(): + """ + Test the _aadapter_completion method directly + """ + # Create a mock response for litellm.aadapter_completion + mock_response = { + "id": "adapter_resp_123", + "object": "adapter.completion", + "created": 1677858242, + "model": "test-model-with-adapter", + "choices": [ + { + "text": "This is a test adapter response", + "index": 0, + "finish_reason": "stop", + } + ], + "usage": {"prompt_tokens": 10, "completion_tokens": 20, "total_tokens": 30}, + } + + # Create a router with a mocked litellm.aadapter_completion + with patch( + "litellm.aadapter_completion", new_callable=AsyncMock + ) as mock_adapter_completion: + mock_adapter_completion.return_value = mock_response + + router = Router( + model_list=[ + { + "model_name": "test-adapter-model", + "litellm_params": { + "model": "anthropic/test-model", + "api_key": "fake-api-key", + }, + } + ] + ) + + # Mock the async_get_available_deployment method + router.async_get_available_deployment = AsyncMock( + return_value={ + "model_name": "test-adapter-model", + "litellm_params": { + "model": "test-model", + "api_key": "fake-api-key", + }, + "model_info": { + "id": "test-unique-id", + }, + } + ) + + # Mock the async_routing_strategy_pre_call_checks method + router.async_routing_strategy_pre_call_checks = AsyncMock() + + # Call the _aadapter_completion method + response = await router._aadapter_completion( + adapter_id="test-adapter-id", + model="test-adapter-model", + prompt="This is a test prompt", + max_tokens=100, + ) + + # Verify the response + assert response == mock_response + + # Verify litellm.aadapter_completion was called with the right parameters + mock_adapter_completion.assert_called_once() + call_kwargs = mock_adapter_completion.call_args.kwargs + assert call_kwargs["adapter_id"] == "test-adapter-id" + assert call_kwargs["model"] == "test-model" + assert call_kwargs["prompt"] == "This is a test prompt" + assert call_kwargs["max_tokens"] == 100 + assert call_kwargs["api_key"] == "fake-api-key" + assert call_kwargs["caching"] == router.cache_responses + + # Verify the success call was recorded + assert router.success_calls["test-model"] == 1 + assert router.total_calls["test-model"] == 1 + + # Verify async_routing_strategy_pre_call_checks was called + router.async_routing_strategy_pre_call_checks.assert_called_once()