(Refactor) /v1/messages to follow simpler logic for Anthropic API spec (#9013)

* anthropic_messages_handler v0

* fix /messages

* working messages with router methods

* test_anthropic_messages_handler_litellm_router_non_streaming

* test_anthropic_messages_litellm_router_non_streaming_with_logging

* AnthropicMessagesConfig

* _handle_anthropic_messages_response_logging

* working with /v1/messages endpoint

* working /v1/messages endpoint

* refactor to use router factory function

* use aanthropic_messages

* use BaseConfig for Anthropic /v1/messages

* track api key, team on /v1/messages endpoint

* fix get_logging_payload

* BaseAnthropicMessagesTest

* align test config

* test_anthropic_messages_with_thinking

* test_anthropic_streaming_with_thinking

* fix - display anthropic url for debugging

* test_bad_request_error_handling

* test_anthropic_messages_router_streaming_with_bad_request

* fix ProxyException

* test_bad_request_error_handling_streaming

* use provider_specific_header

* test_anthropic_messages_with_extra_headers

* test_anthropic_messages_to_wildcard_model

* fix gcs pub sub test

* standard_logging_payload

* fix unit testing for anthropic /v1/messages support

* fix pass through anthropic messages api

* delete dead code

* fix anthropic pass through response

* revert change to spend tracking utils

* fix get_litellm_metadata_from_kwargs

* fix spend logs payload json

* proxy_pass_through_endpoint_tests

* TestAnthropicPassthroughBasic

* fix pass through tests

* test_async_vertex_proxy_route_api_key_auth

* _handle_anthropic_messages_response_logging

* vertex_credentials

* test_set_default_vertex_config

* test_anthropic_messages_litellm_router_non_streaming_with_logging

* test_ageneric_api_call_with_fallbacks_basic

* test__aadapter_completion
Ishaan Jaff 2025-03-06 00:43:08 -08:00 committed by GitHub
parent 31c5ea74ab
commit f47987e673
25 changed files with 1581 additions and 1027 deletions
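For orientation, a minimal sketch of what the refactored endpoint accepts once a proxy is running with the example config later in this diff (the base URL and sk-1234 key are illustrative assumptions): the stock Anthropic SDK can point straight at the proxy and hit the unified `/v1/messages` route.

import anthropic

# Assumes a locally running LiteLLM proxy; base_url and api_key are placeholders.
client = anthropic.Anthropic(base_url="http://0.0.0.0:4000", api_key="sk-1234")

response = client.messages.create(
    model="claude-3-5-sonnet-20241022",
    max_tokens=100,
    messages=[{"role": "user", "content": "Say 'hello test' and nothing else"}],
)
print(response.content)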

View file

@@ -1935,12 +1935,12 @@ jobs:
pip install prisma
pip install fastapi
pip install jsonschema
pip install "httpx==0.24.1"
pip install "httpx==0.27.0"
pip install "anyio==3.7.1"
pip install "asyncio==3.4.3"
pip install "PyGithub==1.59.1"
pip install "google-cloud-aiplatform==1.59.0"
pip install "anthropic==0.21.3"
pip install "anthropic==0.49.0"
# Run pytest and generate JUnit XML report
- run:
name: Build Docker image

View file

@@ -800,9 +800,6 @@ from .llms.oobabooga.chat.transformation import OobaboogaConfig
from .llms.maritalk import MaritalkConfig
from .llms.openrouter.chat.transformation import OpenrouterConfig
from .llms.anthropic.chat.transformation import AnthropicConfig
from .llms.anthropic.experimental_pass_through.transformation import (
AnthropicExperimentalPassThroughConfig,
)
from .llms.groq.stt.transformation import GroqSTTConfig
from .llms.anthropic.completion.transformation import AnthropicTextConfig
from .llms.triton.completion.transformation import TritonConfig
@@ -821,6 +818,9 @@ from .llms.infinity.rerank.transformation import InfinityRerankConfig
from .llms.jina_ai.rerank.transformation import JinaAIRerankConfig
from .llms.clarifai.chat.transformation import ClarifaiConfig
from .llms.ai21.chat.transformation import AI21ChatConfig, AI21ChatConfig as AI21Config
from .llms.anthropic.experimental_pass_through.messages.transformation import (
AnthropicMessagesConfig,
)
from .llms.together_ai.chat import TogetherAIConfig
from .llms.together_ai.completion.transformation import TogetherAITextCompletionConfig
from .llms.cloudflare.chat.transformation import CloudflareChatConfig
@@ -1011,6 +1011,7 @@ from .assistants.main import *
from .batches.main import *
from .batch_completion.main import * # type: ignore
from .rerank_api.main import *
from .llms.anthropic.experimental_pass_through.messages.handler import *
from .realtime_api.main import _arealtime
from .fine_tuning.main import *
from .files.main import *

View file

@@ -1,186 +0,0 @@
# What is this?
## Translates OpenAI call to Anthropic `/v1/messages` format
import traceback
from typing import Any, Optional
import litellm
from litellm import ChatCompletionRequest, verbose_logger
from litellm.integrations.custom_logger import CustomLogger
from litellm.types.llms.anthropic import AnthropicMessagesRequest, AnthropicResponse
from litellm.types.utils import AdapterCompletionStreamWrapper, ModelResponse
class AnthropicAdapter(CustomLogger):
def __init__(self) -> None:
super().__init__()
def translate_completion_input_params(
self, kwargs
) -> Optional[ChatCompletionRequest]:
"""
- translate params, where needed
- pass rest, as is
"""
request_body = AnthropicMessagesRequest(**kwargs) # type: ignore
translated_body = litellm.AnthropicExperimentalPassThroughConfig().translate_anthropic_to_openai(
anthropic_message_request=request_body
)
return translated_body
def translate_completion_output_params(
self, response: ModelResponse
) -> Optional[AnthropicResponse]:
return litellm.AnthropicExperimentalPassThroughConfig().translate_openai_response_to_anthropic(
response=response
)
def translate_completion_output_params_streaming(
self, completion_stream: Any
) -> AdapterCompletionStreamWrapper | None:
return AnthropicStreamWrapper(completion_stream=completion_stream)
anthropic_adapter = AnthropicAdapter()
class AnthropicStreamWrapper(AdapterCompletionStreamWrapper):
"""
- first chunk return 'message_start'
- content block must be started and stopped
- finish_reason must map exactly to anthropic reason, else anthropic client won't be able to parse it.
"""
sent_first_chunk: bool = False
sent_content_block_start: bool = False
sent_content_block_finish: bool = False
sent_last_message: bool = False
holding_chunk: Optional[Any] = None
def __next__(self):
try:
if self.sent_first_chunk is False:
self.sent_first_chunk = True
return {
"type": "message_start",
"message": {
"id": "msg_1nZdL29xx5MUA1yADyHTEsnR8uuvGzszyY",
"type": "message",
"role": "assistant",
"content": [],
"model": "claude-3-5-sonnet-20240620",
"stop_reason": None,
"stop_sequence": None,
"usage": {"input_tokens": 25, "output_tokens": 1},
},
}
if self.sent_content_block_start is False:
self.sent_content_block_start = True
return {
"type": "content_block_start",
"index": 0,
"content_block": {"type": "text", "text": ""},
}
for chunk in self.completion_stream:
if chunk == "None" or chunk is None:
raise Exception
processed_chunk = litellm.AnthropicExperimentalPassThroughConfig().translate_streaming_openai_response_to_anthropic(
response=chunk
)
if (
processed_chunk["type"] == "message_delta"
and self.sent_content_block_finish is False
):
self.holding_chunk = processed_chunk
self.sent_content_block_finish = True
return {
"type": "content_block_stop",
"index": 0,
}
elif self.holding_chunk is not None:
return_chunk = self.holding_chunk
self.holding_chunk = processed_chunk
return return_chunk
else:
return processed_chunk
if self.holding_chunk is not None:
return_chunk = self.holding_chunk
self.holding_chunk = None
return return_chunk
if self.sent_last_message is False:
self.sent_last_message = True
return {"type": "message_stop"}
raise StopIteration
except StopIteration:
if self.sent_last_message is False:
self.sent_last_message = True
return {"type": "message_stop"}
raise StopIteration
except Exception as e:
verbose_logger.error(
"Anthropic Adapter - {}\n{}".format(e, traceback.format_exc())
)
async def __anext__(self):
try:
if self.sent_first_chunk is False:
self.sent_first_chunk = True
return {
"type": "message_start",
"message": {
"id": "msg_1nZdL29xx5MUA1yADyHTEsnR8uuvGzszyY",
"type": "message",
"role": "assistant",
"content": [],
"model": "claude-3-5-sonnet-20240620",
"stop_reason": None,
"stop_sequence": None,
"usage": {"input_tokens": 25, "output_tokens": 1},
},
}
if self.sent_content_block_start is False:
self.sent_content_block_start = True
return {
"type": "content_block_start",
"index": 0,
"content_block": {"type": "text", "text": ""},
}
async for chunk in self.completion_stream:
if chunk == "None" or chunk is None:
raise Exception
processed_chunk = litellm.AnthropicExperimentalPassThroughConfig().translate_streaming_openai_response_to_anthropic(
response=chunk
)
if (
processed_chunk["type"] == "message_delta"
and self.sent_content_block_finish is False
):
self.holding_chunk = processed_chunk
self.sent_content_block_finish = True
return {
"type": "content_block_stop",
"index": 0,
}
elif self.holding_chunk is not None:
return_chunk = self.holding_chunk
self.holding_chunk = processed_chunk
return return_chunk
else:
return processed_chunk
if self.holding_chunk is not None:
return_chunk = self.holding_chunk
self.holding_chunk = None
return return_chunk
if self.sent_last_message is False:
self.sent_last_message = True
return {"type": "message_stop"}
raise StopIteration
except StopIteration:
if self.sent_last_message is False:
self.sent_last_message = True
return {"type": "message_stop"}
raise StopAsyncIteration

View file

@@ -73,6 +73,8 @@ def remove_index_from_tool_calls(
def get_litellm_metadata_from_kwargs(kwargs: dict):
"""
Helper to get litellm metadata from all litellm request kwargs
Return `litellm_metadata` if it exists, otherwise return `metadata`
"""
litellm_params = kwargs.get("litellm_params", {})
if litellm_params:

View file

@@ -932,6 +932,9 @@ class Logging(LiteLLMLoggingBaseClass):
self.model_call_details["log_event_type"] = "successful_api_call"
self.model_call_details["end_time"] = end_time
self.model_call_details["cache_hit"] = cache_hit
if self.call_type == CallTypes.anthropic_messages.value:
result = self._handle_anthropic_messages_response_logging(result=result)
## if model in model cost map - log the response cost
## else set cost to None
if (
@@ -2304,6 +2307,37 @@ class Logging(LiteLLMLoggingBaseClass):
return complete_streaming_response
return None
def _handle_anthropic_messages_response_logging(self, result: Any) -> ModelResponse:
"""
Handles logging for Anthropic messages responses.
Args:
result: The response object from the model call
Returns:
The response object from the model call
- For non-streaming responses, we need to transform the response to a ModelResponse object.
- For streaming responses, the anthropic_messages handler calls success_handler with an assembled ModelResponse.
"""
if self.stream and isinstance(result, ModelResponse):
return result
result = litellm.AnthropicConfig().transform_response(
raw_response=self.model_call_details["httpx_response"],
model_response=litellm.ModelResponse(),
model=self.model,
messages=[],
logging_obj=self,
optional_params={},
api_key="",
request_data={},
encoding=litellm.encoding,
json_mode=False,
litellm_params={},
)
return result
def set_callbacks(callback_list, function_id=None): # noqa: PLR0915
"""

View file

@@ -0,0 +1,179 @@
"""
- call /messages on Anthropic API
- Make streaming + non-streaming requests - just pass them through directly to Anthropic. No need to do anything special here
- Ensure requests are logged in the DB - stream + non-stream
"""
import json
from typing import Any, AsyncIterator, Dict, Optional, Union, cast
import httpx
import litellm
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
from litellm.llms.base_llm.anthropic_messages.transformation import (
BaseAnthropicMessagesConfig,
)
from litellm.llms.custom_httpx.http_handler import (
AsyncHTTPHandler,
get_async_httpx_client,
)
from litellm.types.router import GenericLiteLLMParams
from litellm.types.utils import ProviderSpecificHeader
from litellm.utils import ProviderConfigManager, client
class AnthropicMessagesHandler:
@staticmethod
async def _handle_anthropic_streaming(
response: httpx.Response,
request_body: dict,
litellm_logging_obj: LiteLLMLoggingObj,
) -> AsyncIterator:
"""Helper function to handle Anthropic streaming responses using the existing logging handlers"""
from datetime import datetime
from litellm.proxy.pass_through_endpoints.streaming_handler import (
PassThroughStreamingHandler,
)
from litellm.proxy.pass_through_endpoints.success_handler import (
PassThroughEndpointLogging,
)
from litellm.proxy.pass_through_endpoints.types import EndpointType
# Create success handler object
passthrough_success_handler_obj = PassThroughEndpointLogging()
# Use the existing streaming handler for Anthropic
start_time = datetime.now()
return PassThroughStreamingHandler.chunk_processor(
response=response,
request_body=request_body,
litellm_logging_obj=litellm_logging_obj,
endpoint_type=EndpointType.ANTHROPIC,
start_time=start_time,
passthrough_success_handler_obj=passthrough_success_handler_obj,
url_route="/v1/messages",
)
@client
async def anthropic_messages(
api_key: str,
model: str,
stream: bool = False,
api_base: Optional[str] = None,
client: Optional[AsyncHTTPHandler] = None,
custom_llm_provider: Optional[str] = None,
**kwargs,
) -> Union[Dict[str, Any], AsyncIterator]:
"""
Makes Anthropic `/v1/messages` API calls in the Anthropic API spec format
"""
# Use provided client or create a new one
optional_params = GenericLiteLLMParams(**kwargs)
model, _custom_llm_provider, dynamic_api_key, dynamic_api_base = (
litellm.get_llm_provider(
model=model,
custom_llm_provider=custom_llm_provider,
api_base=optional_params.api_base,
api_key=optional_params.api_key,
)
)
anthropic_messages_provider_config: Optional[BaseAnthropicMessagesConfig] = (
ProviderConfigManager.get_provider_anthropic_messages_config(
model=model,
provider=litellm.LlmProviders(_custom_llm_provider),
)
)
if anthropic_messages_provider_config is None:
raise ValueError(
f"Anthropic messages provider config not found for model: {model}"
)
if client is None or not isinstance(client, AsyncHTTPHandler):
async_httpx_client = get_async_httpx_client(
llm_provider=litellm.LlmProviders.ANTHROPIC
)
else:
async_httpx_client = client
litellm_logging_obj: LiteLLMLoggingObj = kwargs.get("litellm_logging_obj", None)
# Prepare headers
provider_specific_header = cast(
Optional[ProviderSpecificHeader], kwargs.get("provider_specific_header", None)
)
extra_headers = (
provider_specific_header.get("extra_headers", {})
if provider_specific_header
else {}
)
headers = anthropic_messages_provider_config.validate_environment(
headers=extra_headers or {},
model=model,
api_key=api_key,
)
litellm_logging_obj.update_environment_variables(
model=model,
optional_params=dict(optional_params),
litellm_params={
"metadata": kwargs.get("metadata", {}),
"preset_cache_key": None,
"stream_response": {},
**optional_params.model_dump(exclude_unset=True),
},
custom_llm_provider=_custom_llm_provider,
)
litellm_logging_obj.model_call_details.update(kwargs)
# Prepare request body
request_body = kwargs.copy()
request_body = {
k: v
for k, v in request_body.items()
if k
in anthropic_messages_provider_config.get_supported_anthropic_messages_params(
model=model
)
}
request_body["stream"] = stream
request_body["model"] = model
litellm_logging_obj.stream = stream
# Make the request
request_url = anthropic_messages_provider_config.get_complete_url(
api_base=api_base, model=model
)
litellm_logging_obj.pre_call(
input=[{"role": "user", "content": json.dumps(request_body)}],
api_key="",
additional_args={
"complete_input_dict": request_body,
"api_base": str(request_url),
"headers": headers,
},
)
response = await async_httpx_client.post(
url=request_url,
headers=headers,
data=json.dumps(request_body),
stream=stream,
)
response.raise_for_status()
# used for logging + cost tracking
litellm_logging_obj.model_call_details["httpx_response"] = response
if stream:
return await AnthropicMessagesHandler._handle_anthropic_streaming(
response=response,
request_body=request_body,
litellm_logging_obj=litellm_logging_obj,
)
else:
return response.json()
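A hedged usage sketch of this handler at the Python SDK level (the `anthropic/` model prefix and the environment variable are assumptions; extra kwargs such as `messages` and `max_tokens` are forwarded because they appear in the provider config's `get_supported_anthropic_messages_params` shown below):

import asyncio
import os

import litellm


async def main():
    # Non-streaming: returns the raw Anthropic response as a dict.
    response = await litellm.anthropic_messages(
        api_key=os.environ["ANTHROPIC_API_KEY"],
        model="anthropic/claude-3-5-sonnet-20241022",
        max_tokens=100,
        messages=[{"role": "user", "content": "Say 'hello test' and nothing else"}],
    )
    print(response)

    # Streaming: returns an async iterator of raw Anthropic stream chunks.
    stream = await litellm.anthropic_messages(
        api_key=os.environ["ANTHROPIC_API_KEY"],
        model="anthropic/claude-3-5-sonnet-20241022",
        max_tokens=100,
        stream=True,
        messages=[{"role": "user", "content": "Say 'hello stream test' and nothing else"}],
    )
    async for chunk in stream:
        print(chunk)


asyncio.run(main())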

View file

@@ -0,0 +1,47 @@
from typing import Optional
from litellm.llms.base_llm.anthropic_messages.transformation import (
BaseAnthropicMessagesConfig,
)
DEFAULT_ANTHROPIC_API_BASE = "https://api.anthropic.com"
DEFAULT_ANTHROPIC_API_VERSION = "2023-06-01"
class AnthropicMessagesConfig(BaseAnthropicMessagesConfig):
def get_supported_anthropic_messages_params(self, model: str) -> list:
return [
"messages",
"model",
"system",
"max_tokens",
"stop_sequences",
"temperature",
"top_p",
"top_k",
"tools",
"tool_choice",
"thinking",
# TODO: Add Anthropic `metadata` support
# "metadata",
]
def get_complete_url(self, api_base: Optional[str], model: str) -> str:
api_base = api_base or DEFAULT_ANTHROPIC_API_BASE
if not api_base.endswith("/v1/messages"):
api_base = f"{api_base}/v1/messages"
return api_base
def validate_environment(
self,
headers: dict,
model: str,
api_key: Optional[str] = None,
) -> dict:
if "x-api-key" not in headers:
headers["x-api-key"] = api_key
if "anthropic-version" not in headers:
headers["anthropic-version"] = DEFAULT_ANTHROPIC_API_VERSION
if "content-type" not in headers:
headers["content-type"] = "application/json"
return headers
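A small sketch of how these two hooks behave in isolation (the API key is a placeholder):

import litellm

config = litellm.AnthropicMessagesConfig()

# URL resolution: default base, with /v1/messages appended only when missing.
url = config.get_complete_url(api_base=None, model="claude-3-5-sonnet-20241022")
print(url)  # https://api.anthropic.com/v1/messages

# Header resolution: x-api-key / anthropic-version / content-type are only set when absent,
# so caller-provided extra headers win.
headers = config.validate_environment(
    headers={"anthropic-version": "2023-06-01"},
    model="claude-3-5-sonnet-20241022",
    api_key="sk-ant-placeholder",
)
print(headers)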

View file

@@ -1,412 +0,0 @@
import json
from typing import List, Literal, Optional, Tuple, Union
from openai.types.chat.chat_completion_chunk import Choice as OpenAIStreamingChoice
from litellm.types.llms.anthropic import (
AllAnthropicToolsValues,
AnthopicMessagesAssistantMessageParam,
AnthropicFinishReason,
AnthropicMessagesRequest,
AnthropicMessagesToolChoice,
AnthropicMessagesUserMessageParam,
AnthropicResponse,
AnthropicResponseContentBlockText,
AnthropicResponseContentBlockToolUse,
AnthropicResponseUsageBlock,
ContentBlockDelta,
ContentJsonBlockDelta,
ContentTextBlockDelta,
MessageBlockDelta,
MessageDelta,
UsageDelta,
)
from litellm.types.llms.openai import (
AllMessageValues,
ChatCompletionAssistantMessage,
ChatCompletionAssistantToolCall,
ChatCompletionImageObject,
ChatCompletionImageUrlObject,
ChatCompletionRequest,
ChatCompletionSystemMessage,
ChatCompletionTextObject,
ChatCompletionToolCallFunctionChunk,
ChatCompletionToolChoiceFunctionParam,
ChatCompletionToolChoiceObjectParam,
ChatCompletionToolChoiceValues,
ChatCompletionToolMessage,
ChatCompletionToolParam,
ChatCompletionToolParamFunctionChunk,
ChatCompletionUserMessage,
)
from litellm.types.utils import Choices, ModelResponse, Usage
class AnthropicExperimentalPassThroughConfig:
def __init__(self):
pass
### FOR [BETA] `/v1/messages` endpoint support
def translatable_anthropic_params(self) -> List:
"""
Which anthropic params, we need to translate to the openai format.
"""
return ["messages", "metadata", "system", "tool_choice", "tools"]
def translate_anthropic_messages_to_openai( # noqa: PLR0915
self,
messages: List[
Union[
AnthropicMessagesUserMessageParam,
AnthopicMessagesAssistantMessageParam,
]
],
) -> List:
new_messages: List[AllMessageValues] = []
for m in messages:
user_message: Optional[ChatCompletionUserMessage] = None
tool_message_list: List[ChatCompletionToolMessage] = []
new_user_content_list: List[
Union[ChatCompletionTextObject, ChatCompletionImageObject]
] = []
## USER MESSAGE ##
if m["role"] == "user":
## translate user message
message_content = m.get("content")
if message_content and isinstance(message_content, str):
user_message = ChatCompletionUserMessage(
role="user", content=message_content
)
elif message_content and isinstance(message_content, list):
for content in message_content:
if content["type"] == "text":
text_obj = ChatCompletionTextObject(
type="text", text=content["text"]
)
new_user_content_list.append(text_obj)
elif content["type"] == "image":
image_url = ChatCompletionImageUrlObject(
url=f"data:{content['type']};base64,{content['source']}"
)
image_obj = ChatCompletionImageObject(
type="image_url", image_url=image_url
)
new_user_content_list.append(image_obj)
elif content["type"] == "tool_result":
if "content" not in content:
tool_result = ChatCompletionToolMessage(
role="tool",
tool_call_id=content["tool_use_id"],
content="",
)
tool_message_list.append(tool_result)
elif isinstance(content["content"], str):
tool_result = ChatCompletionToolMessage(
role="tool",
tool_call_id=content["tool_use_id"],
content=content["content"],
)
tool_message_list.append(tool_result)
elif isinstance(content["content"], list):
for c in content["content"]:
if c["type"] == "text":
tool_result = ChatCompletionToolMessage(
role="tool",
tool_call_id=content["tool_use_id"],
content=c["text"],
)
tool_message_list.append(tool_result)
elif c["type"] == "image":
image_str = (
f"data:{c['type']};base64,{c['source']}"
)
tool_result = ChatCompletionToolMessage(
role="tool",
tool_call_id=content["tool_use_id"],
content=image_str,
)
tool_message_list.append(tool_result)
if user_message is not None:
new_messages.append(user_message)
if len(new_user_content_list) > 0:
new_messages.append({"role": "user", "content": new_user_content_list}) # type: ignore
if len(tool_message_list) > 0:
new_messages.extend(tool_message_list)
## ASSISTANT MESSAGE ##
assistant_message_str: Optional[str] = None
tool_calls: List[ChatCompletionAssistantToolCall] = []
if m["role"] == "assistant":
if isinstance(m["content"], str):
assistant_message_str = m["content"]
elif isinstance(m["content"], list):
for content in m["content"]:
if content["type"] == "text":
if assistant_message_str is None:
assistant_message_str = content["text"]
else:
assistant_message_str += content["text"]
elif content["type"] == "tool_use":
function_chunk = ChatCompletionToolCallFunctionChunk(
name=content["name"],
arguments=json.dumps(content["input"]),
)
tool_calls.append(
ChatCompletionAssistantToolCall(
id=content["id"],
type="function",
function=function_chunk,
)
)
if assistant_message_str is not None or len(tool_calls) > 0:
assistant_message = ChatCompletionAssistantMessage(
role="assistant",
content=assistant_message_str,
)
if len(tool_calls) > 0:
assistant_message["tool_calls"] = tool_calls
new_messages.append(assistant_message)
return new_messages
def translate_anthropic_tool_choice_to_openai(
self, tool_choice: AnthropicMessagesToolChoice
) -> ChatCompletionToolChoiceValues:
if tool_choice["type"] == "any":
return "required"
elif tool_choice["type"] == "auto":
return "auto"
elif tool_choice["type"] == "tool":
tc_function_param = ChatCompletionToolChoiceFunctionParam(
name=tool_choice.get("name", "")
)
return ChatCompletionToolChoiceObjectParam(
type="function", function=tc_function_param
)
else:
raise ValueError(
"Incompatible tool choice param submitted - {}".format(tool_choice)
)
def translate_anthropic_tools_to_openai(
self, tools: List[AllAnthropicToolsValues]
) -> List[ChatCompletionToolParam]:
new_tools: List[ChatCompletionToolParam] = []
mapped_tool_params = ["name", "input_schema", "description"]
for tool in tools:
function_chunk = ChatCompletionToolParamFunctionChunk(
name=tool["name"],
)
if "input_schema" in tool:
function_chunk["parameters"] = tool["input_schema"] # type: ignore
if "description" in tool:
function_chunk["description"] = tool["description"] # type: ignore
for k, v in tool.items():
if k not in mapped_tool_params: # pass additional computer kwargs
function_chunk.setdefault("parameters", {}).update({k: v})
new_tools.append(
ChatCompletionToolParam(type="function", function=function_chunk)
)
return new_tools
def translate_anthropic_to_openai(
self, anthropic_message_request: AnthropicMessagesRequest
) -> ChatCompletionRequest:
"""
This is used by the beta Anthropic Adapter, for translating anthropic `/v1/messages` requests to the openai format.
"""
new_messages: List[AllMessageValues] = []
## CONVERT ANTHROPIC MESSAGES TO OPENAI
new_messages = self.translate_anthropic_messages_to_openai(
messages=anthropic_message_request["messages"]
)
## ADD SYSTEM MESSAGE TO MESSAGES
if "system" in anthropic_message_request:
new_messages.insert(
0,
ChatCompletionSystemMessage(
role="system", content=anthropic_message_request["system"]
),
)
new_kwargs: ChatCompletionRequest = {
"model": anthropic_message_request["model"],
"messages": new_messages,
}
## CONVERT METADATA (user_id)
if "metadata" in anthropic_message_request:
if "user_id" in anthropic_message_request["metadata"]:
new_kwargs["user"] = anthropic_message_request["metadata"]["user_id"]
# Pass litellm proxy specific metadata
if "litellm_metadata" in anthropic_message_request:
# metadata will be passed to litellm.acompletion(), it's a litellm_param
new_kwargs["metadata"] = anthropic_message_request.pop("litellm_metadata")
## CONVERT TOOL CHOICE
if "tool_choice" in anthropic_message_request:
new_kwargs["tool_choice"] = self.translate_anthropic_tool_choice_to_openai(
tool_choice=anthropic_message_request["tool_choice"]
)
## CONVERT TOOLS
if "tools" in anthropic_message_request:
new_kwargs["tools"] = self.translate_anthropic_tools_to_openai(
tools=anthropic_message_request["tools"]
)
translatable_params = self.translatable_anthropic_params()
for k, v in anthropic_message_request.items():
if k not in translatable_params: # pass remaining params as is
new_kwargs[k] = v # type: ignore
return new_kwargs
def _translate_openai_content_to_anthropic(
self, choices: List[Choices]
) -> List[
Union[AnthropicResponseContentBlockText, AnthropicResponseContentBlockToolUse]
]:
new_content: List[
Union[
AnthropicResponseContentBlockText, AnthropicResponseContentBlockToolUse
]
] = []
for choice in choices:
if (
choice.message.tool_calls is not None
and len(choice.message.tool_calls) > 0
):
for tool_call in choice.message.tool_calls:
new_content.append(
AnthropicResponseContentBlockToolUse(
type="tool_use",
id=tool_call.id,
name=tool_call.function.name or "",
input=json.loads(tool_call.function.arguments),
)
)
elif choice.message.content is not None:
new_content.append(
AnthropicResponseContentBlockText(
type="text", text=choice.message.content
)
)
return new_content
def _translate_openai_finish_reason_to_anthropic(
self, openai_finish_reason: str
) -> AnthropicFinishReason:
if openai_finish_reason == "stop":
return "end_turn"
elif openai_finish_reason == "length":
return "max_tokens"
elif openai_finish_reason == "tool_calls":
return "tool_use"
return "end_turn"
def translate_openai_response_to_anthropic(
self, response: ModelResponse
) -> AnthropicResponse:
## translate content block
anthropic_content = self._translate_openai_content_to_anthropic(choices=response.choices) # type: ignore
## extract finish reason
anthropic_finish_reason = self._translate_openai_finish_reason_to_anthropic(
openai_finish_reason=response.choices[0].finish_reason # type: ignore
)
# extract usage
usage: Usage = getattr(response, "usage")
anthropic_usage = AnthropicResponseUsageBlock(
input_tokens=usage.prompt_tokens or 0,
output_tokens=usage.completion_tokens or 0,
)
translated_obj = AnthropicResponse(
id=response.id,
type="message",
role="assistant",
model=response.model or "unknown-model",
stop_sequence=None,
usage=anthropic_usage,
content=anthropic_content,
stop_reason=anthropic_finish_reason,
)
return translated_obj
def _translate_streaming_openai_chunk_to_anthropic(
self, choices: List[OpenAIStreamingChoice]
) -> Tuple[
Literal["text_delta", "input_json_delta"],
Union[ContentTextBlockDelta, ContentJsonBlockDelta],
]:
text: str = ""
partial_json: Optional[str] = None
for choice in choices:
if choice.delta.content is not None:
text += choice.delta.content
elif choice.delta.tool_calls is not None:
partial_json = ""
for tool in choice.delta.tool_calls:
if (
tool.function is not None
and tool.function.arguments is not None
):
partial_json += tool.function.arguments
if partial_json is not None:
return "input_json_delta", ContentJsonBlockDelta(
type="input_json_delta", partial_json=partial_json
)
else:
return "text_delta", ContentTextBlockDelta(type="text_delta", text=text)
def translate_streaming_openai_response_to_anthropic(
self, response: ModelResponse
) -> Union[ContentBlockDelta, MessageBlockDelta]:
## base case - final chunk w/ finish reason
if response.choices[0].finish_reason is not None:
delta = MessageDelta(
stop_reason=self._translate_openai_finish_reason_to_anthropic(
response.choices[0].finish_reason
),
)
if getattr(response, "usage", None) is not None:
litellm_usage_chunk: Optional[Usage] = response.usage # type: ignore
elif (
hasattr(response, "_hidden_params")
and "usage" in response._hidden_params
):
litellm_usage_chunk = response._hidden_params["usage"]
else:
litellm_usage_chunk = None
if litellm_usage_chunk is not None:
usage_delta = UsageDelta(
input_tokens=litellm_usage_chunk.prompt_tokens or 0,
output_tokens=litellm_usage_chunk.completion_tokens or 0,
)
else:
usage_delta = UsageDelta(input_tokens=0, output_tokens=0)
return MessageBlockDelta(
type="message_delta", delta=delta, usage=usage_delta
)
(
type_of_content,
content_block_delta,
) = self._translate_streaming_openai_chunk_to_anthropic(
choices=response.choices # type: ignore
)
return ContentBlockDelta(
type="content_block_delta",
index=response.choices[0].index,
delta=content_block_delta,
)

View file

@@ -0,0 +1,35 @@
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Any, Optional
if TYPE_CHECKING:
from litellm.litellm_core_utils.litellm_logging import Logging as _LiteLLMLoggingObj
LiteLLMLoggingObj = _LiteLLMLoggingObj
else:
LiteLLMLoggingObj = Any
class BaseAnthropicMessagesConfig(ABC):
@abstractmethod
def validate_environment(
self,
headers: dict,
model: str,
api_key: Optional[str] = None,
) -> dict:
pass
@abstractmethod
def get_complete_url(self, api_base: Optional[str], model: str) -> str:
"""
OPTIONAL
Get the complete url for the request
Some providers need `model` in `api_base`
"""
return api_base or ""
@abstractmethod
def get_supported_anthropic_messages_params(self, model: str) -> list:
pass
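For a provider other than Anthropic, a hypothetical sketch of what a subclass of this base config would need to implement (the gateway name, default URL, and header scheme are invented for illustration; wiring it in would also require a ProviderConfigManager entry like the Anthropic one later in this diff):

from typing import Optional

from litellm.llms.base_llm.anthropic_messages.transformation import (
    BaseAnthropicMessagesConfig,
)


class ExampleGatewayAnthropicMessagesConfig(BaseAnthropicMessagesConfig):
    """Hypothetical config that proxies /v1/messages through an internal gateway."""

    def validate_environment(
        self, headers: dict, model: str, api_key: Optional[str] = None
    ) -> dict:
        headers.setdefault("authorization", f"Bearer {api_key}")
        headers.setdefault("content-type", "application/json")
        return headers

    def get_complete_url(self, api_base: Optional[str], model: str) -> str:
        api_base = api_base or "https://gateway.example.internal"  # assumed default
        return f"{api_base}/v1/messages"

    def get_supported_anthropic_messages_params(self, model: str) -> list:
        return ["messages", "model", "system", "max_tokens", "temperature"]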

View file

@@ -1963,7 +1963,7 @@ class ProxyException(Exception):
code: Optional[Union[int, str]] = None,
headers: Optional[Dict[str, str]] = None,
):
self.message = message
self.message = str(message)
self.type = type
self.param = param

View file

@@ -0,0 +1,252 @@
"""
Unified /v1/messages endpoint - (Anthropic Spec)
"""
import asyncio
import json
import time
import traceback
from fastapi import APIRouter, Depends, HTTPException, Request, Response, status
from fastapi.responses import StreamingResponse
import litellm
from litellm._logging import verbose_proxy_logger
from litellm.proxy._types import *
from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
from litellm.proxy.common_utils.http_parsing_utils import _read_request_body
from litellm.proxy.litellm_pre_call_utils import add_litellm_data_to_request
from litellm.proxy.utils import ProxyLogging
router = APIRouter()
async def async_data_generator_anthropic(
response,
user_api_key_dict: UserAPIKeyAuth,
request_data: dict,
proxy_logging_obj: ProxyLogging,
):
verbose_proxy_logger.debug("inside generator")
try:
time.time()
async for chunk in response:
verbose_proxy_logger.debug(
"async_data_generator: received streaming chunk - {}".format(chunk)
)
### CALL HOOKS ### - modify outgoing data
chunk = await proxy_logging_obj.async_post_call_streaming_hook(
user_api_key_dict=user_api_key_dict, response=chunk
)
yield chunk
except Exception as e:
verbose_proxy_logger.exception(
"litellm.proxy.proxy_server.async_data_generator(): Exception occured - {}".format(
str(e)
)
)
await proxy_logging_obj.post_call_failure_hook(
user_api_key_dict=user_api_key_dict,
original_exception=e,
request_data=request_data,
)
verbose_proxy_logger.debug(
f"\033[1;31mAn error occurred: {e}\n\n Debug this by setting `--debug`, e.g. `litellm --model gpt-3.5-turbo --debug`"
)
if isinstance(e, HTTPException):
raise e
else:
error_traceback = traceback.format_exc()
error_msg = f"{str(e)}\n\n{error_traceback}"
proxy_exception = ProxyException(
message=getattr(e, "message", error_msg),
type=getattr(e, "type", "None"),
param=getattr(e, "param", "None"),
code=getattr(e, "status_code", 500),
)
error_returned = json.dumps({"error": proxy_exception.to_dict()})
yield f"data: {error_returned}\n\n"
@router.post(
"/v1/messages",
tags=["[beta] Anthropic `/v1/messages`"],
dependencies=[Depends(user_api_key_auth)],
include_in_schema=False,
)
async def anthropic_response( # noqa: PLR0915
fastapi_response: Response,
request: Request,
user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
"""
Use `{PROXY_BASE_URL}/anthropic/v1/messages` instead - [Docs](https://docs.litellm.ai/docs/anthropic_completion).
This was a BETA endpoint that calls 100+ LLMs in the anthropic format.
"""
from litellm.proxy.proxy_server import (
general_settings,
get_custom_headers,
llm_router,
proxy_config,
proxy_logging_obj,
user_api_base,
user_max_tokens,
user_model,
user_request_timeout,
user_temperature,
version,
)
request_data = await _read_request_body(request=request)
data: dict = {**request_data}
try:
data["model"] = (
general_settings.get("completion_model", None) # server default
or user_model # model name passed via cli args
or data.get("model", None) # default passed in http request
)
if user_model:
data["model"] = user_model
data = await add_litellm_data_to_request(
data=data, # type: ignore
request=request,
general_settings=general_settings,
user_api_key_dict=user_api_key_dict,
version=version,
proxy_config=proxy_config,
)
# override with user settings, these are params passed via cli
if user_temperature:
data["temperature"] = user_temperature
if user_request_timeout:
data["request_timeout"] = user_request_timeout
if user_max_tokens:
data["max_tokens"] = user_max_tokens
if user_api_base:
data["api_base"] = user_api_base
### MODEL ALIAS MAPPING ###
# check if model name in model alias map
# get the actual model name
if data["model"] in litellm.model_alias_map:
data["model"] = litellm.model_alias_map[data["model"]]
### CALL HOOKS ### - modify incoming data before calling the model
data = await proxy_logging_obj.pre_call_hook( # type: ignore
user_api_key_dict=user_api_key_dict, data=data, call_type="text_completion"
)
### ROUTE THE REQUESTs ###
router_model_names = llm_router.model_names if llm_router is not None else []
# skip router if user passed their key
if (
llm_router is not None and data["model"] in router_model_names
): # model in router model list
llm_response = asyncio.create_task(llm_router.aanthropic_messages(**data))
elif (
llm_router is not None
and llm_router.model_group_alias is not None
and data["model"] in llm_router.model_group_alias
): # model set in model_group_alias
llm_response = asyncio.create_task(llm_router.aanthropic_messages(**data))
elif (
llm_router is not None and data["model"] in llm_router.deployment_names
): # model in router deployments, calling a specific deployment on the router
llm_response = asyncio.create_task(
llm_router.aanthropic_messages(**data, specific_deployment=True)
)
elif (
llm_router is not None and data["model"] in llm_router.get_model_ids()
): # model in router model list
llm_response = asyncio.create_task(llm_router.aanthropic_messages(**data))
elif (
llm_router is not None
and data["model"] not in router_model_names
and (
llm_router.default_deployment is not None
or len(llm_router.pattern_router.patterns) > 0
)
): # model in router deployments, calling a specific deployment on the router
llm_response = asyncio.create_task(llm_router.aanthropic_messages(**data))
elif user_model is not None: # `litellm --model <your-model-name>`
llm_response = asyncio.create_task(litellm.anthropic_messages(**data))
else:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail={
"error": "completion: Invalid model name passed in model="
+ data.get("model", "")
},
)
# Await the llm_response task
response = await llm_response
hidden_params = getattr(response, "_hidden_params", {}) or {}
model_id = hidden_params.get("model_id", None) or ""
cache_key = hidden_params.get("cache_key", None) or ""
api_base = hidden_params.get("api_base", None) or ""
response_cost = hidden_params.get("response_cost", None) or ""
### ALERTING ###
asyncio.create_task(
proxy_logging_obj.update_request_status(
litellm_call_id=data.get("litellm_call_id", ""), status="success"
)
)
verbose_proxy_logger.debug("final response: %s", response)
fastapi_response.headers.update(
get_custom_headers(
user_api_key_dict=user_api_key_dict,
model_id=model_id,
cache_key=cache_key,
api_base=api_base,
version=version,
response_cost=response_cost,
request_data=data,
hidden_params=hidden_params,
)
)
if (
"stream" in data and data["stream"] is True
): # use generate_responses to stream responses
selected_data_generator = async_data_generator_anthropic(
response=response,
user_api_key_dict=user_api_key_dict,
request_data=data,
proxy_logging_obj=proxy_logging_obj,
)
return StreamingResponse(
selected_data_generator, # type: ignore
media_type="text/event-stream",
)
verbose_proxy_logger.info("\nResponse from Litellm:\n{}".format(response))
return response
except Exception as e:
await proxy_logging_obj.post_call_failure_hook(
user_api_key_dict=user_api_key_dict, original_exception=e, request_data=data
)
verbose_proxy_logger.exception(
"litellm.proxy.proxy_server.anthropic_response(): Exception occured - {}".format(
str(e)
)
)
error_msg = f"{str(e)}"
raise ProxyException(
message=getattr(e, "message", error_msg),
type=getattr(e, "type", "None"),
param=getattr(e, "param", "None"),
code=getattr(e, "status_code", 500),
)
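A hedged sketch of exercising this route with the Anthropic SDK, mirroring the tests added in this PR (the base URL and sk-1234 virtual key are assumptions about a locally running proxy):

import anthropic

client = anthropic.Anthropic(base_url="http://0.0.0.0:4000", api_key="sk-1234")

# Streaming goes through async_data_generator_anthropic above and is relayed as SSE events.
with client.messages.stream(
    model="claude-3-5-sonnet-20241022",
    max_tokens=10,
    messages=[{"role": "user", "content": "Say 'hello stream test' and nothing else"}],
) as stream:
    for text in stream.text_stream:
        print(text, end="")

# Malformed requests surface as anthropic.BadRequestError on the client side.
try:
    client.messages.create(
        model="claude-3-5-sonnet-20241022", max_tokens=10, messages=["hi"]
    )
except anthropic.BadRequestError as e:
    print(e.status_code)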

View file

@@ -1,9 +1,29 @@
model_list:
- model_name: fake-openai-endpoint
litellm_params:
model: openai/fake
api_key: fake-key
api_base: https://exampleopenaiendpoint-production.up.railway.app/
- model_name: fake-openai-endpoint
litellm_params:
model: openai/fake
api_key: fake-key
api_base: https://exampleopenaiendpoint-production.up.railway.app/
- model_name: claude-3-5-sonnet-20241022
litellm_params:
model: anthropic/claude-3-5-sonnet-20241022
api_key: os.environ/ANTHROPIC_API_KEY
- model_name: claude-special-alias
litellm_params:
model: anthropic/claude-3-haiku-20240307
api_key: os.environ/ANTHROPIC_API_KEY
- model_name: claude-3-5-sonnet-20241022
litellm_params:
model: anthropic/claude-3-5-sonnet-20241022
api_key: os.environ/ANTHROPIC_API_KEY
- model_name: claude-3-7-sonnet-20250219
litellm_params:
model: anthropic/claude-3-7-sonnet-20250219
api_key: os.environ/ANTHROPIC_API_KEY
- model_name: anthropic/*
litellm_params:
model: anthropic/*
api_key: os.environ/ANTHROPIC_API_KEY
general_settings:
master_key: sk-1234
custom_auth: custom_auth_basic.user_api_key_auth
master_key: sk-1234
custom_auth: custom_auth_basic.user_api_key_auth

View file

@@ -4,7 +4,22 @@ model_list:
model: openai/my-fake-model
api_key: my-fake-key
api_base: https://exampleopenaiendpoint-production.up.railway.app/
- model_name: claude-special-alias
litellm_params:
model: anthropic/claude-3-haiku-20240307
api_key: os.environ/ANTHROPIC_API_KEY
- model_name: claude-3-5-sonnet-20241022
litellm_params:
model: anthropic/claude-3-5-sonnet-20241022
api_key: os.environ/ANTHROPIC_API_KEY
- model_name: claude-3-7-sonnet-20250219
litellm_params:
model: anthropic/claude-3-7-sonnet-20250219
api_key: os.environ/ANTHROPIC_API_KEY
- model_name: anthropic/*
litellm_params:
model: anthropic/*
api_key: os.environ/ANTHROPIC_API_KEY
general_settings:
store_model_in_db: true

View file

@@ -120,6 +120,7 @@ from litellm.proxy._types import *
from litellm.proxy.analytics_endpoints.analytics_endpoints import (
router as analytics_router,
)
from litellm.proxy.anthropic_endpoints.endpoints import router as anthropic_router
from litellm.proxy.auth.auth_checks import log_db_metrics
from litellm.proxy.auth.auth_utils import check_response_size_is_safe
from litellm.proxy.auth.handle_jwt import JWTHandler
@@ -3065,58 +3066,6 @@ async def async_data_generator(
yield f"data: {error_returned}\n\n"
async def async_data_generator_anthropic(
response, user_api_key_dict: UserAPIKeyAuth, request_data: dict
):
verbose_proxy_logger.debug("inside generator")
try:
time.time()
async for chunk in response:
verbose_proxy_logger.debug(
"async_data_generator: received streaming chunk - {}".format(chunk)
)
### CALL HOOKS ### - modify outgoing data
chunk = await proxy_logging_obj.async_post_call_streaming_hook(
user_api_key_dict=user_api_key_dict, response=chunk
)
event_type = chunk.get("type")
try:
yield f"event: {event_type}\ndata:{json.dumps(chunk)}\n\n"
except Exception as e:
yield f"event: {event_type}\ndata:{str(e)}\n\n"
except Exception as e:
verbose_proxy_logger.exception(
"litellm.proxy.proxy_server.async_data_generator(): Exception occured - {}".format(
str(e)
)
)
await proxy_logging_obj.post_call_failure_hook(
user_api_key_dict=user_api_key_dict,
original_exception=e,
request_data=request_data,
)
verbose_proxy_logger.debug(
f"\033[1;31mAn error occurred: {e}\n\n Debug this by setting `--debug`, e.g. `litellm --model gpt-3.5-turbo --debug`"
)
if isinstance(e, HTTPException):
raise e
else:
error_traceback = traceback.format_exc()
error_msg = f"{str(e)}\n\n{error_traceback}"
proxy_exception = ProxyException(
message=getattr(e, "message", error_msg),
type=getattr(e, "type", "None"),
param=getattr(e, "param", "None"),
code=getattr(e, "status_code", 500),
)
error_returned = json.dumps({"error": proxy_exception.to_dict()})
yield f"data: {error_returned}\n\n"
def select_data_generator(
response, user_api_key_dict: UserAPIKeyAuth, request_data: dict
):
@@ -5524,224 +5473,6 @@ async def moderations(
)
#### ANTHROPIC ENDPOINTS ####
@router.post(
"/v1/messages",
tags=["[beta] Anthropic `/v1/messages`"],
dependencies=[Depends(user_api_key_auth)],
response_model=AnthropicResponse,
include_in_schema=False,
)
async def anthropic_response( # noqa: PLR0915
anthropic_data: AnthropicMessagesRequest,
fastapi_response: Response,
request: Request,
user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
"""
🚨 DEPRECATED ENDPOINT🚨
Use `{PROXY_BASE_URL}/anthropic/v1/messages` instead - [Docs](https://docs.litellm.ai/docs/anthropic_completion).
This was a BETA endpoint that calls 100+ LLMs in the anthropic format.
"""
from litellm import adapter_completion
from litellm.adapters.anthropic_adapter import anthropic_adapter
litellm.adapters = [{"id": "anthropic", "adapter": anthropic_adapter}]
global user_temperature, user_request_timeout, user_max_tokens, user_api_base
request_data = await _read_request_body(request=request)
data: dict = {**request_data, "adapter_id": "anthropic"}
try:
data["model"] = (
general_settings.get("completion_model", None) # server default
or user_model # model name passed via cli args
or data.get("model", None) # default passed in http request
)
if user_model:
data["model"] = user_model
data = await add_litellm_data_to_request(
data=data, # type: ignore
request=request,
general_settings=general_settings,
user_api_key_dict=user_api_key_dict,
version=version,
proxy_config=proxy_config,
)
# override with user settings, these are params passed via cli
if user_temperature:
data["temperature"] = user_temperature
if user_request_timeout:
data["request_timeout"] = user_request_timeout
if user_max_tokens:
data["max_tokens"] = user_max_tokens
if user_api_base:
data["api_base"] = user_api_base
### MODEL ALIAS MAPPING ###
# check if model name in model alias map
# get the actual model name
if data["model"] in litellm.model_alias_map:
data["model"] = litellm.model_alias_map[data["model"]]
### CALL HOOKS ### - modify incoming data before calling the model
data = await proxy_logging_obj.pre_call_hook( # type: ignore
user_api_key_dict=user_api_key_dict, data=data, call_type="text_completion"
)
### ROUTE THE REQUESTs ###
router_model_names = llm_router.model_names if llm_router is not None else []
# skip router if user passed their key
if "api_key" in data:
llm_response = asyncio.create_task(litellm.aadapter_completion(**data))
elif (
llm_router is not None and data["model"] in router_model_names
): # model in router model list
llm_response = asyncio.create_task(llm_router.aadapter_completion(**data))
elif (
llm_router is not None
and llm_router.model_group_alias is not None
and data["model"] in llm_router.model_group_alias
): # model set in model_group_alias
llm_response = asyncio.create_task(llm_router.aadapter_completion(**data))
elif (
llm_router is not None and data["model"] in llm_router.deployment_names
): # model in router deployments, calling a specific deployment on the router
llm_response = asyncio.create_task(
llm_router.aadapter_completion(**data, specific_deployment=True)
)
elif (
llm_router is not None and data["model"] in llm_router.get_model_ids()
): # model in router model list
llm_response = asyncio.create_task(llm_router.aadapter_completion(**data))
elif (
llm_router is not None
and data["model"] not in router_model_names
and (
llm_router.default_deployment is not None
or len(llm_router.pattern_router.patterns) > 0
)
): # model in router deployments, calling a specific deployment on the router
llm_response = asyncio.create_task(llm_router.aadapter_completion(**data))
elif user_model is not None: # `litellm --model <your-model-name>`
llm_response = asyncio.create_task(litellm.aadapter_completion(**data))
else:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail={
"error": "completion: Invalid model name passed in model="
+ data.get("model", "")
},
)
# Await the llm_response task
response = await llm_response
hidden_params = getattr(response, "_hidden_params", {}) or {}
model_id = hidden_params.get("model_id", None) or ""
cache_key = hidden_params.get("cache_key", None) or ""
api_base = hidden_params.get("api_base", None) or ""
response_cost = hidden_params.get("response_cost", None) or ""
### ALERTING ###
asyncio.create_task(
proxy_logging_obj.update_request_status(
litellm_call_id=data.get("litellm_call_id", ""), status="success"
)
)
verbose_proxy_logger.debug("final response: %s", response)
fastapi_response.headers.update(
get_custom_headers(
user_api_key_dict=user_api_key_dict,
model_id=model_id,
cache_key=cache_key,
api_base=api_base,
version=version,
response_cost=response_cost,
request_data=data,
hidden_params=hidden_params,
)
)
if (
"stream" in data and data["stream"] is True
): # use generate_responses to stream responses
selected_data_generator = async_data_generator_anthropic(
response=response,
user_api_key_dict=user_api_key_dict,
request_data=data,
)
return StreamingResponse(
selected_data_generator,
media_type="text/event-stream",
)
verbose_proxy_logger.info("\nResponse from Litellm:\n{}".format(response))
return response
except RejectedRequestError as e:
_data = e.request_data
await proxy_logging_obj.post_call_failure_hook(
user_api_key_dict=user_api_key_dict,
original_exception=e,
request_data=_data,
)
if _data.get("stream", None) is not None and _data["stream"] is True:
_chat_response = litellm.ModelResponse()
_usage = litellm.Usage(
prompt_tokens=0,
completion_tokens=0,
total_tokens=0,
)
_chat_response.usage = _usage # type: ignore
_chat_response.choices[0].message.content = e.message # type: ignore
_iterator = litellm.utils.ModelResponseIterator(
model_response=_chat_response, convert_to_delta=True
)
_streaming_response = litellm.TextCompletionStreamWrapper(
completion_stream=_iterator,
model=_data.get("model", ""),
)
selected_data_generator = select_data_generator(
response=_streaming_response,
user_api_key_dict=user_api_key_dict,
request_data=data,
)
return StreamingResponse(
selected_data_generator,
media_type="text/event-stream",
headers={},
)
else:
_response = litellm.TextCompletionResponse()
_response.choices[0].text = e.message
return _response
except Exception as e:
await proxy_logging_obj.post_call_failure_hook(
user_api_key_dict=user_api_key_dict, original_exception=e, request_data=data
)
verbose_proxy_logger.exception(
"litellm.proxy.proxy_server.anthropic_response(): Exception occured - {}".format(
str(e)
)
)
error_msg = f"{str(e)}"
raise ProxyException(
message=getattr(e, "message", error_msg),
type=getattr(e, "type", "None"),
param=getattr(e, "param", "None"),
code=getattr(e, "status_code", 500),
)
#### DEV UTILS ####
# @router.get(
@@ -8840,6 +8571,7 @@ app.include_router(rerank_router)
app.include_router(fine_tuning_router)
app.include_router(vertex_router)
app.include_router(llm_passthrough_router)
app.include_router(anthropic_router)
app.include_router(langfuse_router)
app.include_router(pass_through_router)
app.include_router(health_router)

View file

@@ -10,6 +10,7 @@ from pydantic import BaseModel
import litellm
from litellm._logging import verbose_proxy_logger
from litellm.litellm_core_utils.core_helpers import get_litellm_metadata_from_kwargs
from litellm.proxy._types import SpendLogsMetadata, SpendLogsPayload
from litellm.proxy.utils import PrismaClient, hash_token
from litellm.types.utils import StandardLoggingPayload
@@ -119,9 +120,7 @@ def get_logging_payload( # noqa: PLR0915
response_obj = {}
# standardize this function to be used across, s3, dynamoDB, langfuse logging
litellm_params = kwargs.get("litellm_params", {})
metadata = (
litellm_params.get("metadata", {}) or {}
) # if litellm_params['metadata'] == None
metadata = get_litellm_metadata_from_kwargs(kwargs)
metadata = _add_proxy_server_request_to_metadata(
metadata=metadata, litellm_params=litellm_params
)

View file

@@ -580,6 +580,9 @@ class Router:
self.amoderation = self.factory_function(
litellm.amoderation, call_type="moderation"
)
self.aanthropic_messages = self.factory_function(
litellm.anthropic_messages, call_type="anthropic_messages"
)
def discard(self):
"""
@@ -2349,6 +2352,89 @@ class Router:
self.fail_calls[model] += 1
raise e
async def _ageneric_api_call_with_fallbacks(
self, model: str, original_function: Callable, **kwargs
):
"""
Make a generic LLM API call through the router; this allows you to use retries/fallbacks with the litellm router
Args:
model: The model to use
original_function: The handler function to call (e.g., litellm.anthropic_messages)
**kwargs: Additional arguments to pass to the handler function
Returns:
The response from the handler function
"""
handler_name = original_function.__name__
try:
verbose_router_logger.debug(
f"Inside _ageneric_api_call() - handler: {handler_name}, model: {model}; kwargs: {kwargs}"
)
parent_otel_span = _get_parent_otel_span_from_kwargs(kwargs)
deployment = await self.async_get_available_deployment(
model=model,
request_kwargs=kwargs,
messages=kwargs.get("messages", None),
specific_deployment=kwargs.pop("specific_deployment", None),
)
self._update_kwargs_with_deployment(deployment=deployment, kwargs=kwargs)
data = deployment["litellm_params"].copy()
model_name = data["model"]
model_client = self._get_async_openai_model_client(
deployment=deployment,
kwargs=kwargs,
)
self.total_calls[model_name] += 1
response = original_function(
**{
**data,
"caching": self.cache_responses,
"client": model_client,
**kwargs,
}
)
rpm_semaphore = self._get_client(
deployment=deployment,
kwargs=kwargs,
client_type="max_parallel_requests",
)
if rpm_semaphore is not None and isinstance(
rpm_semaphore, asyncio.Semaphore
):
async with rpm_semaphore:
"""
- Check rpm limits before making the call
- If allowed, increment the rpm limit (allows global value to be updated, concurrency-safe)
"""
await self.async_routing_strategy_pre_call_checks(
deployment=deployment, parent_otel_span=parent_otel_span
)
response = await response # type: ignore
else:
await self.async_routing_strategy_pre_call_checks(
deployment=deployment, parent_otel_span=parent_otel_span
)
response = await response # type: ignore
self.success_calls[model_name] += 1
verbose_router_logger.info(
f"{handler_name}(model={model_name})\033[32m 200 OK\033[0m"
)
return response
except Exception as e:
verbose_router_logger.info(
f"{handler_name}(model={model})\033[31m Exception {str(e)}\033[0m"
)
if model is not None:
self.fail_calls[model] += 1
raise e
def embedding(
self,
model: str,
@@ -2869,10 +2955,14 @@ class Router:
def factory_function(
self,
original_function: Callable,
call_type: Literal["assistants", "moderation"] = "assistants",
call_type: Literal[
"assistants", "moderation", "anthropic_messages"
] = "assistants",
):
async def new_function(
custom_llm_provider: Optional[Literal["openai", "azure"]] = None,
custom_llm_provider: Optional[
Literal["openai", "azure", "anthropic"]
] = None,
client: Optional["AsyncOpenAI"] = None,
**kwargs,
):
@@ -2889,13 +2979,18 @@ class Router:
original_function=original_function,
**kwargs,
)
elif call_type == "anthropic_messages":
return await self._ageneric_api_call_with_fallbacks( # type: ignore
original_function=original_function,
**kwargs,
)
return new_function
async def _pass_through_assistants_endpoint_factory(
self,
original_function: Callable,
custom_llm_provider: Optional[Literal["openai", "azure"]] = None,
custom_llm_provider: Optional[Literal["openai", "azure", "anthropic"]] = None,
client: Optional[AsyncOpenAI] = None,
**kwargs,
):
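A hedged sketch of the new router method this factory wires up (the model_list entry and environment variable are illustrative):

import asyncio
import os

from litellm import Router

router = Router(
    model_list=[
        {
            "model_name": "claude-3-5-sonnet-20241022",
            "litellm_params": {
                "model": "anthropic/claude-3-5-sonnet-20241022",
                "api_key": os.environ["ANTHROPIC_API_KEY"],
            },
        }
    ]
)


async def main():
    # Routed through _ageneric_api_call_with_fallbacks -> litellm.anthropic_messages
    response = await router.aanthropic_messages(
        model="claude-3-5-sonnet-20241022",
        max_tokens=100,
        messages=[{"role": "user", "content": "Say 'hello test' and nothing else"}],
    )
    print(response)


asyncio.run(main())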

View file

@@ -186,6 +186,7 @@ class CallTypes(Enum):
aretrieve_batch = "aretrieve_batch"
retrieve_batch = "retrieve_batch"
pass_through = "pass_through_endpoint"
anthropic_messages = "anthropic_messages"
CallTypesLiteral = Literal[
@@ -209,6 +210,7 @@ CallTypesLiteral = Literal[
"create_batch",
"acreate_batch",
"pass_through_endpoint",
"anthropic_messages",
]

View file

@@ -191,6 +191,9 @@ from typing import (
from openai import OpenAIError as OriginalError
from litellm.litellm_core_utils.thread_pool_executor import executor
from litellm.llms.base_llm.anthropic_messages.transformation import (
BaseAnthropicMessagesConfig,
)
from litellm.llms.base_llm.audio_transcription.transformation import (
BaseAudioTranscriptionConfig,
)
@@ -6245,6 +6248,15 @@ class ProviderConfigManager:
return litellm.JinaAIRerankConfig()
return litellm.CohereRerankConfig()
@staticmethod
def get_provider_anthropic_messages_config(
model: str,
provider: LlmProviders,
) -> Optional[BaseAnthropicMessagesConfig]:
if litellm.LlmProviders.ANTHROPIC == provider:
return litellm.AnthropicMessagesConfig()
return None
@staticmethod
def get_provider_audio_transcription_config(
model: str,

View file

@@ -329,57 +329,3 @@ async def test_aaapass_through_endpoint_pass_through_keys_langfuse(
setattr(
litellm.proxy.proxy_server, "proxy_logging_obj", original_proxy_logging_obj
)
@pytest.mark.asyncio
async def test_pass_through_endpoint_anthropic(client):
import litellm
from litellm import Router
from litellm.adapters.anthropic_adapter import anthropic_adapter
router = Router(
model_list=[
{
"model_name": "gpt-3.5-turbo",
"litellm_params": {
"model": "gpt-3.5-turbo",
"api_key": os.getenv("OPENAI_API_KEY"),
"mock_response": "Hey, how's it going?",
},
}
]
)
setattr(litellm.proxy.proxy_server, "llm_router", router)
# Define a pass-through endpoint
pass_through_endpoints = [
{
"path": "/v1/test-messages",
"target": anthropic_adapter,
"headers": {"litellm_user_api_key": "my-test-header"},
}
]
# Initialize the pass-through endpoint
await initialize_pass_through_endpoints(pass_through_endpoints)
general_settings: Optional[dict] = (
getattr(litellm.proxy.proxy_server, "general_settings", {}) or {}
)
general_settings.update({"pass_through_endpoints": pass_through_endpoints})
setattr(litellm.proxy.proxy_server, "general_settings", general_settings)
_json_data = {
"model": "gpt-3.5-turbo",
"messages": [{"role": "user", "content": "Who are you?"}],
}
# Make a request to the pass-through endpoint
response = client.post(
"/v1/test-messages", json=_json_data, headers={"my-test-header": "my-test-key"}
)
print("JSON response: ", _json_data)
# Assert the response
assert response.status_code == 200

View file

@@ -0,0 +1,145 @@
from abc import ABC, abstractmethod
import anthropic
import pytest
class BaseAnthropicMessagesTest(ABC):
"""
Abstract base test class that enforces a common test across all test classes.
"""
@abstractmethod
def get_client(self):
return anthropic.Anthropic()
def test_anthropic_basic_completion(self):
print("making basic completion request to anthropic passthrough")
client = self.get_client()
response = client.messages.create(
model="claude-3-5-sonnet-20241022",
max_tokens=1024,
messages=[{"role": "user", "content": "Say 'hello test' and nothing else"}],
extra_body={
"litellm_metadata": {
"tags": ["test-tag-1", "test-tag-2"],
}
},
)
print(response)
def test_anthropic_streaming(self):
print("making streaming request to anthropic passthrough")
collected_output = []
client = self.get_client()
with client.messages.stream(
max_tokens=10,
messages=[
{"role": "user", "content": "Say 'hello stream test' and nothing else"}
],
model="claude-3-5-sonnet-20241022",
extra_body={
"litellm_metadata": {
"tags": ["test-tag-stream-1", "test-tag-stream-2"],
}
},
) as stream:
for text in stream.text_stream:
collected_output.append(text)
full_response = "".join(collected_output)
print(full_response)
def test_anthropic_messages_with_thinking(self):
print("making request to anthropic passthrough with thinking")
client = self.get_client()
response = client.messages.create(
model="claude-3-7-sonnet-20250219",
max_tokens=20000,
thinking={"type": "enabled", "budget_tokens": 16000},
messages=[
{"role": "user", "content": "Just pinging with thinking enabled"}
],
)
print(response)
# Verify the first content block is a thinking block
response_thinking = response.content[0].thinking
assert response_thinking is not None
assert len(response_thinking) > 0
def test_anthropic_streaming_with_thinking(self):
print("making streaming request to anthropic passthrough with thinking enabled")
collected_thinking = []
collected_response = []
client = self.get_client()
with client.messages.stream(
model="claude-3-7-sonnet-20250219",
max_tokens=20000,
thinking={"type": "enabled", "budget_tokens": 16000},
messages=[
{"role": "user", "content": "Just pinging with thinking enabled"}
],
) as stream:
for event in stream:
if event.type == "content_block_delta":
if event.delta.type == "thinking_delta":
collected_thinking.append(event.delta.thinking)
elif event.delta.type == "text_delta":
collected_response.append(event.delta.text)
full_thinking = "".join(collected_thinking)
full_response = "".join(collected_response)
print(
f"Thinking Response: {full_thinking[:100]}..."
) # Print first 100 chars of thinking
print(f"Response: {full_response}")
# Verify we received thinking content
assert len(collected_thinking) > 0
assert len(full_thinking) > 0
# Verify we also received a response
assert len(collected_response) > 0
assert len(full_response) > 0
def test_bad_request_error_handling_streaming(self):
print("making request to anthropic passthrough with bad request")
try:
client = self.get_client()
response = client.messages.create(
model="claude-3-5-sonnet-20241022",
max_tokens=10,
stream=True,
messages=["hi"],
)
print(response)
assert pytest.fail("Expected BadRequestError")
except anthropic.BadRequestError as e:
print("Got BadRequestError from anthropic, e=", e)
print(e.__cause__)
print(e.status_code)
print(e.response)
except Exception as e:
pytest.fail(f"Got unexpected exception: {e}")
def test_bad_request_error_handling_non_streaming(self):
print("making request to anthropic passthrough with bad request")
try:
client = self.get_client()
response = client.messages.create(
model="claude-3-5-sonnet-20241022",
max_tokens=10,
messages=["hi"],
)
print(response)
assert pytest.fail("Expected BadRequestError")
except anthropic.BadRequestError as e:
print("Got BadRequestError from anthropic, e=", e)
print(e.__cause__)
print(e.status_code)
print(e.response)
except Exception as e:
pytest.fail(f"Got unexpected exception: {e}")

View file

@@ -8,48 +8,6 @@ import aiohttp
import asyncio
import json
client = anthropic.Anthropic(
base_url="http://0.0.0.0:4000/anthropic", api_key="sk-1234"
)
def test_anthropic_basic_completion():
print("making basic completion request to anthropic passthrough")
response = client.messages.create(
model="claude-3-5-sonnet-20241022",
max_tokens=1024,
messages=[{"role": "user", "content": "Say 'hello test' and nothing else"}],
extra_body={
"litellm_metadata": {
"tags": ["test-tag-1", "test-tag-2"],
}
},
)
print(response)
def test_anthropic_streaming():
print("making streaming request to anthropic passthrough")
collected_output = []
with client.messages.stream(
max_tokens=10,
messages=[
{"role": "user", "content": "Say 'hello stream test' and nothing else"}
],
model="claude-3-5-sonnet-20241022",
extra_body={
"litellm_metadata": {
"tags": ["test-tag-stream-1", "test-tag-stream-2"],
}
},
) as stream:
for text in stream.text_stream:
collected_output.append(text)
full_response = "".join(collected_output)
print(full_response)
@pytest.mark.asyncio
async def test_anthropic_basic_completion_with_headers():

View file

@@ -0,0 +1,28 @@
from base_anthropic_messages_test import BaseAnthropicMessagesTest
import anthropic
class TestAnthropicPassthroughBasic(BaseAnthropicMessagesTest):
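"""Runs the shared tests against the proxy's Anthropic pass-through route."""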
def get_client(self):
return anthropic.Anthropic(
base_url="http://0.0.0.0:4000/anthropic",
api_key="sk-1234",
)
class TestAnthropicMessagesEndpoint(BaseAnthropicMessagesTest):
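"""Runs the shared tests against the proxy's /v1/messages endpoint (no /anthropic prefix in the base URL)."""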
def get_client(self):
return anthropic.Anthropic(
base_url="http://0.0.0.0:4000",
api_key="sk-1234",
)
def test_anthropic_messages_to_wildcard_model(self):
client = self.get_client()
response = client.messages.create(
model="anthropic/claude-3-opus-20240229",
messages=[{"role": "user", "content": "Hello, world!"}],
max_tokens=100,
)
print(response)

View file

@@ -0,0 +1,487 @@
import json
import os
import sys
from datetime import datetime
from typing import AsyncIterator, Dict, Any
import asyncio
import unittest.mock
from unittest.mock import AsyncMock, MagicMock
sys.path.insert(
0, os.path.abspath("../..")
) # Adds the parent directory to the system path
import litellm
import pytest
from dotenv import load_dotenv
from litellm.llms.anthropic.experimental_pass_through.messages.handler import (
anthropic_messages,
)
from typing import Optional
from litellm.types.utils import StandardLoggingPayload
from litellm.integrations.custom_logger import CustomLogger
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler
from litellm.router import Router
import importlib
# Load environment variables
load_dotenv()
@pytest.fixture(scope="session")
def event_loop():
"""Create an instance of the default event loop for each test session."""
loop = asyncio.get_event_loop_policy().new_event_loop()
yield loop
loop.close()
@pytest.fixture(scope="function", autouse=True)
def setup_and_teardown(event_loop): # Add event_loop as a dependency
curr_dir = os.getcwd()
sys.path.insert(0, os.path.abspath("../.."))
import litellm
from litellm import Router
importlib.reload(litellm)
# Set the event loop from the fixture
asyncio.set_event_loop(event_loop)
print(litellm)
yield
# Clean up any pending tasks
pending = asyncio.all_tasks(event_loop)
for task in pending:
task.cancel()
# Run the event loop until all tasks are cancelled
if pending:
event_loop.run_until_complete(asyncio.gather(*pending, return_exceptions=True))
def _validate_anthropic_response(response: Dict[str, Any]):
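"""Assert the response has the basic shape of an Anthropic /v1/messages response."""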
assert "id" in response
assert "content" in response
assert "model" in response
assert response["role"] == "assistant"
@pytest.mark.asyncio
async def test_anthropic_messages_non_streaming():
"""
Test the anthropic_messages with non-streaming request
"""
# Get API key from environment
api_key = os.getenv("ANTHROPIC_API_KEY")
if not api_key:
pytest.skip("ANTHROPIC_API_KEY not found in environment")
# Set up test parameters
messages = [{"role": "user", "content": "Hello, can you tell me a short joke?"}]
# Call the handler
response = await anthropic_messages(
messages=messages,
api_key=api_key,
model="claude-3-haiku-20240307",
max_tokens=100,
)
# Verify response
assert "id" in response
assert "content" in response
assert "model" in response
assert response["role"] == "assistant"
print(f"Non-streaming response: {json.dumps(response, indent=2)}")
return response
@pytest.mark.asyncio
async def test_anthropic_messages_streaming():
"""
Test the anthropic_messages with streaming request
"""
# Get API key from environment
api_key = os.getenv("ANTHROPIC_API_KEY")
if not api_key:
pytest.skip("ANTHROPIC_API_KEY not found in environment")
# Set up test parameters
messages = [{"role": "user", "content": "Hello, can you tell me a short joke?"}]
# Call the handler
async_httpx_client = AsyncHTTPHandler()
response = await anthropic_messages(
messages=messages,
api_key=api_key,
model="claude-3-haiku-20240307",
max_tokens=100,
stream=True,
client=async_httpx_client,
)
if isinstance(response, AsyncIterator):
async for chunk in response:
print("chunk=", chunk)
@pytest.mark.asyncio
async def test_anthropic_messages_streaming_with_bad_request():
"""
Test the anthropic_messages with streaming request
"""
try:
response = await anthropic_messages(
messages=["hi"],
api_key=os.getenv("ANTHROPIC_API_KEY"),
model="claude-3-haiku-20240307",
max_tokens=100,
stream=True,
)
print(response)
async for chunk in response:
print("chunk=", chunk)
except Exception as e:
print("got exception", e)
print("vars", vars(e))
assert e.status_code == 400
@pytest.mark.asyncio
async def test_anthropic_messages_router_streaming_with_bad_request():
"""
Test the anthropic_messages with streaming request
"""
try:
router = Router(
model_list=[
{
"model_name": "claude-special-alias",
"litellm_params": {
"model": "claude-3-haiku-20240307",
"api_key": os.getenv("ANTHROPIC_API_KEY"),
},
}
]
)
response = await router.aanthropic_messages(
messages=["hi"],
model="claude-special-alias",
max_tokens=100,
stream=True,
)
print(response)
async for chunk in response:
print("chunk=", chunk)
except Exception as e:
print("got exception", e)
print("vars", vars(e))
assert e.status_code == 400
@pytest.mark.asyncio
async def test_anthropic_messages_litellm_router_non_streaming():
"""
Test the anthropic_messages with non-streaming request
"""
litellm._turn_on_debug()
router = Router(
model_list=[
{
"model_name": "claude-special-alias",
"litellm_params": {
"model": "claude-3-haiku-20240307",
"api_key": os.getenv("ANTHROPIC_API_KEY"),
},
}
]
)
# Set up test parameters
messages = [{"role": "user", "content": "Hello, can you tell me a short joke?"}]
# Call the handler
response = await router.aanthropic_messages(
messages=messages,
model="claude-special-alias",
max_tokens=100,
)
# Verify response
assert "id" in response
assert "content" in response
assert "model" in response
assert response["role"] == "assistant"
print(f"Non-streaming response: {json.dumps(response, indent=2)}")
return response
class TestCustomLogger(CustomLogger):
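"""Captures the standard logging payload from the success callback so tests can assert cost and usage."""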
def __init__(self):
super().__init__()
self.logged_standard_logging_payload: Optional[StandardLoggingPayload] = None
async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
print("inside async_log_success_event")
self.logged_standard_logging_payload = kwargs.get("standard_logging_object")
pass
@pytest.mark.asyncio
async def test_anthropic_messages_litellm_router_non_streaming_with_logging():
"""
Test the anthropic_messages with non-streaming request
- Ensure Cost + Usage is tracked
"""
test_custom_logger = TestCustomLogger()
litellm.callbacks = [test_custom_logger]
litellm._turn_on_debug()
router = Router(
model_list=[
{
"model_name": "claude-special-alias",
"litellm_params": {
"model": "claude-3-haiku-20240307",
"api_key": os.getenv("ANTHROPIC_API_KEY"),
},
}
]
)
# Set up test parameters
messages = [{"role": "user", "content": "Hello, can you tell me a short joke?"}]
# Call the handler
response = await router.aanthropic_messages(
messages=messages,
model="claude-special-alias",
max_tokens=100,
)
# Verify response
_validate_anthropic_response(response)
print(f"Non-streaming response: {json.dumps(response, indent=2)}")
await asyncio.sleep(1)
assert test_custom_logger.logged_standard_logging_payload["messages"] == messages
assert test_custom_logger.logged_standard_logging_payload["response"] is not None
assert (
test_custom_logger.logged_standard_logging_payload["model"]
== "claude-3-haiku-20240307"
)
# check logged usage + spend
assert test_custom_logger.logged_standard_logging_payload["response_cost"] > 0
assert (
test_custom_logger.logged_standard_logging_payload["prompt_tokens"]
== response["usage"]["input_tokens"]
)
assert (
test_custom_logger.logged_standard_logging_payload["completion_tokens"]
== response["usage"]["output_tokens"]
)
@pytest.mark.asyncio
async def test_anthropic_messages_litellm_router_streaming_with_logging():
"""
Test the anthropic_messages with streaming request
- Ensure Cost + Usage is tracked
"""
test_custom_logger = TestCustomLogger()
litellm.callbacks = [test_custom_logger]
# litellm._turn_on_debug()
router = Router(
model_list=[
{
"model_name": "claude-special-alias",
"litellm_params": {
"model": "claude-3-haiku-20240307",
"api_key": os.getenv("ANTHROPIC_API_KEY"),
},
}
]
)
# Set up test parameters
messages = [{"role": "user", "content": "Hello, can you tell me a short joke?"}]
# Call the handler
response = await router.aanthropic_messages(
messages=messages,
model="claude-special-alias",
max_tokens=100,
stream=True,
)
response_prompt_tokens = 0
response_completion_tokens = 0
all_anthropic_usage_chunks = []
async for chunk in response:
# Decode chunk if it's bytes
print("chunk=", chunk)
# Handle SSE format chunks
if isinstance(chunk, bytes):
chunk_str = chunk.decode("utf-8")
# Extract the JSON data part from SSE format
for line in chunk_str.split("\n"):
if line.startswith("data: "):
try:
json_data = json.loads(line[6:]) # Skip the 'data: ' prefix
print(
"\n\nJSON data:",
json.dumps(json_data, indent=4, default=str),
)
# Extract usage information
if (
json_data.get("type") == "message_start"
and "message" in json_data
):
if "usage" in json_data["message"]:
usage = json_data["message"]["usage"]
all_anthropic_usage_chunks.append(usage)
print(
"USAGE BLOCK",
json.dumps(usage, indent=4, default=str),
)
elif "usage" in json_data:
usage = json_data["usage"]
all_anthropic_usage_chunks.append(usage)
print(
"USAGE BLOCK", json.dumps(usage, indent=4, default=str)
)
except json.JSONDecodeError:
print(f"Failed to parse JSON from: {line[6:]}")
elif hasattr(chunk, "message"):
if chunk.message.usage:
print(
"USAGE BLOCK",
json.dumps(chunk.message.usage, indent=4, default=str),
)
all_anthropic_usage_chunks.append(chunk.message.usage)
elif hasattr(chunk, "usage"):
print("USAGE BLOCK", json.dumps(chunk.usage, indent=4, default=str))
all_anthropic_usage_chunks.append(chunk.usage)
print(
"all_anthropic_usage_chunks",
json.dumps(all_anthropic_usage_chunks, indent=4, default=str),
)
# Extract token counts from usage data
if all_anthropic_usage_chunks:
response_prompt_tokens = max(
[usage.get("input_tokens", 0) for usage in all_anthropic_usage_chunks]
)
response_completion_tokens = max(
[usage.get("output_tokens", 0) for usage in all_anthropic_usage_chunks]
)
print("input_tokens_anthropic_api", response_prompt_tokens)
print("output_tokens_anthropic_api", response_completion_tokens)
await asyncio.sleep(4)
print(
"logged_standard_logging_payload",
json.dumps(
test_custom_logger.logged_standard_logging_payload, indent=4, default=str
),
)
assert test_custom_logger.logged_standard_logging_payload["messages"] == messages
assert test_custom_logger.logged_standard_logging_payload["response"] is not None
assert (
test_custom_logger.logged_standard_logging_payload["model"]
== "claude-3-haiku-20240307"
)
# check logged usage + spend
assert test_custom_logger.logged_standard_logging_payload["response_cost"] > 0
assert (
test_custom_logger.logged_standard_logging_payload["prompt_tokens"]
== response_prompt_tokens
)
assert (
test_custom_logger.logged_standard_logging_payload["completion_tokens"]
== response_completion_tokens
)
@pytest.mark.asyncio
async def test_anthropic_messages_with_extra_headers():
"""
Test the anthropic_messages with extra headers
"""
# Get API key from environment
api_key = os.getenv("ANTHROPIC_API_KEY", "fake-api-key")
# Set up test parameters
messages = [{"role": "user", "content": "Hello, can you tell me a short joke?"}]
extra_headers = {
"anthropic-beta": "very-custom-beta-value",
"anthropic-version": "custom-version-for-test",
}
# Create a mock response
mock_response = MagicMock()
mock_response.raise_for_status = MagicMock()
mock_response.json.return_value = {
"id": "msg_123456",
"type": "message",
"role": "assistant",
"content": [
{
"type": "text",
"text": "Why did the chicken cross the road? To get to the other side!",
}
],
"model": "claude-3-haiku-20240307",
"stop_reason": "end_turn",
"usage": {"input_tokens": 10, "output_tokens": 20},
}
# Create a mock client with AsyncMock for the post method
mock_client = MagicMock(spec=AsyncHTTPHandler)
mock_client.post = AsyncMock(return_value=mock_response)
# Call the handler with extra_headers and our mocked client
response = await anthropic_messages(
messages=messages,
api_key=api_key,
model="claude-3-haiku-20240307",
max_tokens=100,
client=mock_client,
provider_specific_header={
"custom_llm_provider": "anthropic",
"extra_headers": extra_headers,
},
)
# Verify the post method was called with the right parameters
mock_client.post.assert_called_once()
call_kwargs = mock_client.post.call_args.kwargs
# Verify headers were passed correctly
headers = call_kwargs.get("headers", {})
print("HEADERS IN REQUEST", headers)
for key, value in extra_headers.items():
assert key in headers
assert headers[key] == value
# Verify the response was processed correctly
assert response == mock_response.json.return_value
return response

View file

@@ -54,7 +54,7 @@ async def test_get_litellm_virtual_key():
@pytest.mark.asyncio
async def test_vertex_proxy_route_api_key_auth():
async def test_async_vertex_proxy_route_api_key_auth():
"""
Critical
@@ -207,7 +207,7 @@ async def test_get_vertex_credentials_stored():
router.add_vertex_credentials(
project_id="test-project",
location="us-central1",
vertex_credentials="test-creds",
vertex_credentials='{"credentials": "test-creds"}',
)
creds = router.get_vertex_credentials(
@@ -215,7 +215,7 @@ async def test_get_vertex_credentials_stored():
)
assert creds.vertex_project == "test-project"
assert creds.vertex_location == "us-central1"
assert creds.vertex_credentials == "test-creds"
assert creds.vertex_credentials == '{"credentials": "test-creds"}'
@pytest.mark.asyncio
@@ -227,18 +227,20 @@ async def test_add_vertex_credentials():
router.add_vertex_credentials(
project_id="test-project",
location="us-central1",
vertex_credentials="test-creds",
vertex_credentials='{"credentials": "test-creds"}',
)
assert "test-project-us-central1" in router.deployment_key_to_vertex_credentials
creds = router.deployment_key_to_vertex_credentials["test-project-us-central1"]
assert creds.vertex_project == "test-project"
assert creds.vertex_location == "us-central1"
assert creds.vertex_credentials == "test-creds"
assert creds.vertex_credentials == '{"credentials": "test-creds"}'
# Test adding with None values
router.add_vertex_credentials(
project_id=None, location=None, vertex_credentials="test-creds"
project_id=None,
location=None,
vertex_credentials='{"credentials": "test-creds"}',
)
# Should not add None values
assert len(router.deployment_key_to_vertex_credentials) == 1

View file

@@ -6,6 +6,7 @@ from typing import Optional
from dotenv import load_dotenv
from fastapi import Request
from datetime import datetime
from unittest.mock import AsyncMock, patch
sys.path.insert(
0, os.path.abspath("../..")
@@ -289,43 +290,6 @@ async def test_aaaaatext_completion_endpoint(model_list, sync_mode):
assert response.choices[0].text == "I'm fine, thank you!"
@pytest.mark.asyncio
async def test_anthropic_router_completion_e2e(model_list):
from litellm.adapters.anthropic_adapter import anthropic_adapter
from litellm.types.llms.anthropic import AnthropicResponse
litellm.set_verbose = True
litellm.adapters = [{"id": "anthropic", "adapter": anthropic_adapter}]
router = Router(model_list=model_list)
messages = [{"role": "user", "content": "Hey, how's it going?"}]
## Test 1: user facing function
response = await router.aadapter_completion(
model="claude-3-5-sonnet-20240620",
messages=messages,
adapter_id="anthropic",
mock_response="This is a fake call",
)
## Test 2: underlying function
await router._aadapter_completion(
model="claude-3-5-sonnet-20240620",
messages=messages,
adapter_id="anthropic",
mock_response="This is a fake call",
)
print("Response: {}".format(response))
assert response is not None
AnthropicResponse.model_validate(response)
assert response.model == "gpt-3.5-turbo"
@pytest.mark.asyncio
async def test_router_with_empty_choices(model_list):
"""
@@ -349,3 +313,200 @@ async def test_router_with_empty_choices(model_list):
mock_response=mock_response,
)
assert response is not None
@pytest.mark.asyncio
async def test_ageneric_api_call_with_fallbacks_basic():
"""
Test the _ageneric_api_call_with_fallbacks method with a basic successful call
"""
# Create a mock function that will be passed to _ageneric_api_call_with_fallbacks
mock_function = AsyncMock()
mock_function.__name__ = "test_function"
# Create a mock response
mock_response = {
"id": "resp_123456",
"role": "assistant",
"content": "This is a test response",
"model": "test-model",
"usage": {"input_tokens": 10, "output_tokens": 20},
}
mock_function.return_value = mock_response
# Create a router with a test model
router = Router(
model_list=[
{
"model_name": "test-model-alias",
"litellm_params": {
"model": "anthropic/test-model",
"api_key": "fake-api-key",
},
}
]
)
# Call the _ageneric_api_call_with_fallbacks method
response = await router._ageneric_api_call_with_fallbacks(
model="test-model-alias",
original_function=mock_function,
messages=[{"role": "user", "content": "Hello"}],
max_tokens=100,
)
# Verify the mock function was called
mock_function.assert_called_once()
# Verify the response
assert response == mock_response
@pytest.mark.asyncio
async def test_aadapter_completion():
"""
Test the aadapter_completion method which uses async_function_with_fallbacks
"""
# Create a mock for the _aadapter_completion method
mock_response = {
"id": "adapter_resp_123",
"object": "adapter.completion",
"created": 1677858242,
"model": "test-model-with-adapter",
"choices": [
{
"text": "This is a test adapter response",
"index": 0,
"finish_reason": "stop",
}
],
"usage": {"prompt_tokens": 10, "completion_tokens": 20, "total_tokens": 30},
}
# Create a router with a patched _aadapter_completion method
with patch.object(
Router, "_aadapter_completion", new_callable=AsyncMock
) as mock_method:
mock_method.return_value = mock_response
router = Router(
model_list=[
{
"model_name": "test-adapter-model",
"litellm_params": {
"model": "anthropic/test-model",
"api_key": "fake-api-key",
},
}
]
)
# Replace the async_function_with_fallbacks with a mock
router.async_function_with_fallbacks = AsyncMock(return_value=mock_response)
# Call the aadapter_completion method
response = await router.aadapter_completion(
adapter_id="test-adapter-id",
model="test-adapter-model",
prompt="This is a test prompt",
max_tokens=100,
)
# Verify the response
assert response == mock_response
# Verify async_function_with_fallbacks was called with the right parameters
router.async_function_with_fallbacks.assert_called_once()
call_kwargs = router.async_function_with_fallbacks.call_args.kwargs
assert call_kwargs["adapter_id"] == "test-adapter-id"
assert call_kwargs["model"] == "test-adapter-model"
assert call_kwargs["prompt"] == "This is a test prompt"
assert call_kwargs["max_tokens"] == 100
assert call_kwargs["original_function"] == router._aadapter_completion
assert "metadata" in call_kwargs
assert call_kwargs["metadata"]["model_group"] == "test-adapter-model"
@pytest.mark.asyncio
async def test__aadapter_completion():
"""
Test the _aadapter_completion method directly
"""
# Create a mock response for litellm.aadapter_completion
mock_response = {
"id": "adapter_resp_123",
"object": "adapter.completion",
"created": 1677858242,
"model": "test-model-with-adapter",
"choices": [
{
"text": "This is a test adapter response",
"index": 0,
"finish_reason": "stop",
}
],
"usage": {"prompt_tokens": 10, "completion_tokens": 20, "total_tokens": 30},
}
# Create a router with a mocked litellm.aadapter_completion
with patch(
"litellm.aadapter_completion", new_callable=AsyncMock
) as mock_adapter_completion:
mock_adapter_completion.return_value = mock_response
router = Router(
model_list=[
{
"model_name": "test-adapter-model",
"litellm_params": {
"model": "anthropic/test-model",
"api_key": "fake-api-key",
},
}
]
)
# Mock the async_get_available_deployment method
router.async_get_available_deployment = AsyncMock(
return_value={
"model_name": "test-adapter-model",
"litellm_params": {
"model": "test-model",
"api_key": "fake-api-key",
},
"model_info": {
"id": "test-unique-id",
},
}
)
# Mock the async_routing_strategy_pre_call_checks method
router.async_routing_strategy_pre_call_checks = AsyncMock()
# Call the _aadapter_completion method
response = await router._aadapter_completion(
adapter_id="test-adapter-id",
model="test-adapter-model",
prompt="This is a test prompt",
max_tokens=100,
)
# Verify the response
assert response == mock_response
# Verify litellm.aadapter_completion was called with the right parameters
mock_adapter_completion.assert_called_once()
call_kwargs = mock_adapter_completion.call_args.kwargs
assert call_kwargs["adapter_id"] == "test-adapter-id"
assert call_kwargs["model"] == "test-model"
assert call_kwargs["prompt"] == "This is a test prompt"
assert call_kwargs["max_tokens"] == 100
assert call_kwargs["api_key"] == "fake-api-key"
assert call_kwargs["caching"] == router.cache_responses
# Verify the success call was recorded
assert router.success_calls["test-model"] == 1
assert router.total_calls["test-model"] == 1
# Verify async_routing_strategy_pre_call_checks was called
router.async_routing_strategy_pre_call_checks.assert_called_once()