Mirror of https://github.com/BerriAI/litellm.git (synced 2025-04-26 11:14:04 +00:00)
(Refactor) /v1/messages to follow simpler logic for Anthropic API spec (#9013)
* anthropic_messages_handler v0
* fix /messages
* working messages with router methods
* test_anthropic_messages_handler_litellm_router_non_streaming
* test_anthropic_messages_litellm_router_non_streaming_with_logging
* AnthropicMessagesConfig
* _handle_anthropic_messages_response_logging
* working with /v1/messages endpoint
* working /v1/messages endpoint
* refactor to use router factory function
* use aanthropic_messages
* use BaseConfig for Anthropic /v1/messages
* track api key, team on /v1/messages endpoint
* fix get_logging_payload
* BaseAnthropicMessagesTest
* align test config
* test_anthropic_messages_with_thinking
* test_anthropic_streaming_with_thinking
* fix - display anthropic url for debugging
* test_bad_request_error_handling
* test_anthropic_messages_router_streaming_with_bad_request
* fix ProxyException
* test_bad_request_error_handling_streaming
* use provider_specific_header
* test_anthropic_messages_with_extra_headers
* test_anthropic_messages_to_wildcard_model
* fix gcs pub sub test
* standard_logging_payload
* fix unit testing for anthopic /v1/messages support
* fix pass through anthropic messages api
* delete dead code
* fix anthropic pass through response
* revert change to spend tracking utils
* fix get_litellm_metadata_from_kwargs
* fix spend logs payload json
* proxy_pass_through_endpoint_tests
* TestAnthropicPassthroughBasic
* fix pass through tests
* test_async_vertex_proxy_route_api_key_auth
* _handle_anthropic_messages_response_logging
* vertex_credentials
* test_set_default_vertex_config
* test_anthropic_messages_litellm_router_non_streaming_with_logging
* test_ageneric_api_call_with_fallbacks_basic
* test__aadapter_completion
This commit is contained in:
parent
31c5ea74ab
commit
f47987e673
25 changed files with 1581 additions and 1027 deletions
@@ -1935,12 +1935,12 @@ jobs:
             pip install prisma
             pip install fastapi
             pip install jsonschema
-            pip install "httpx==0.24.1"
+            pip install "httpx==0.27.0"
             pip install "anyio==3.7.1"
             pip install "asyncio==3.4.3"
             pip install "PyGithub==1.59.1"
             pip install "google-cloud-aiplatform==1.59.0"
-            pip install "anthropic==0.21.3"
+            pip install "anthropic==0.49.0"
       # Run pytest and generate JUnit XML report
       - run:
           name: Build Docker image

@@ -800,9 +800,6 @@ from .llms.oobabooga.chat.transformation import OobaboogaConfig
 from .llms.maritalk import MaritalkConfig
 from .llms.openrouter.chat.transformation import OpenrouterConfig
 from .llms.anthropic.chat.transformation import AnthropicConfig
-from .llms.anthropic.experimental_pass_through.transformation import (
-    AnthropicExperimentalPassThroughConfig,
-)
 from .llms.groq.stt.transformation import GroqSTTConfig
 from .llms.anthropic.completion.transformation import AnthropicTextConfig
 from .llms.triton.completion.transformation import TritonConfig

@@ -821,6 +818,9 @@ from .llms.infinity.rerank.transformation import InfinityRerankConfig
 from .llms.jina_ai.rerank.transformation import JinaAIRerankConfig
 from .llms.clarifai.chat.transformation import ClarifaiConfig
 from .llms.ai21.chat.transformation import AI21ChatConfig, AI21ChatConfig as AI21Config
+from .llms.anthropic.experimental_pass_through.messages.transformation import (
+    AnthropicMessagesConfig,
+)
 from .llms.together_ai.chat import TogetherAIConfig
 from .llms.together_ai.completion.transformation import TogetherAITextCompletionConfig
 from .llms.cloudflare.chat.transformation import CloudflareChatConfig

@@ -1011,6 +1011,7 @@ from .assistants.main import *
 from .batches.main import *
 from .batch_completion.main import *  # type: ignore
 from .rerank_api.main import *
+from .llms.anthropic.experimental_pass_through.messages.handler import *
 from .realtime_api.main import _arealtime
 from .fine_tuning.main import *
 from .files.main import *

@@ -1,186 +0,0 @@ (entire file removed)

# What is this?
## Translates OpenAI call to Anthropic `/v1/messages` format
import traceback
from typing import Any, Optional

import litellm
from litellm import ChatCompletionRequest, verbose_logger
from litellm.integrations.custom_logger import CustomLogger
from litellm.types.llms.anthropic import AnthropicMessagesRequest, AnthropicResponse
from litellm.types.utils import AdapterCompletionStreamWrapper, ModelResponse


class AnthropicAdapter(CustomLogger):
    def __init__(self) -> None:
        super().__init__()

    def translate_completion_input_params(
        self, kwargs
    ) -> Optional[ChatCompletionRequest]:
        """
        - translate params, where needed
        - pass rest, as is
        """
        request_body = AnthropicMessagesRequest(**kwargs)  # type: ignore

        translated_body = litellm.AnthropicExperimentalPassThroughConfig().translate_anthropic_to_openai(
            anthropic_message_request=request_body
        )

        return translated_body

    def translate_completion_output_params(
        self, response: ModelResponse
    ) -> Optional[AnthropicResponse]:
        return litellm.AnthropicExperimentalPassThroughConfig().translate_openai_response_to_anthropic(
            response=response
        )

    def translate_completion_output_params_streaming(
        self, completion_stream: Any
    ) -> AdapterCompletionStreamWrapper | None:
        return AnthropicStreamWrapper(completion_stream=completion_stream)


anthropic_adapter = AnthropicAdapter()


class AnthropicStreamWrapper(AdapterCompletionStreamWrapper):
    """
    - first chunk return 'message_start'
    - content block must be started and stopped
    - finish_reason must map exactly to anthropic reason, else anthropic client won't be able to parse it.
    """

    sent_first_chunk: bool = False
    sent_content_block_start: bool = False
    sent_content_block_finish: bool = False
    sent_last_message: bool = False
    holding_chunk: Optional[Any] = None

    def __next__(self):
        try:
            if self.sent_first_chunk is False:
                self.sent_first_chunk = True
                return {
                    "type": "message_start",
                    "message": {
                        "id": "msg_1nZdL29xx5MUA1yADyHTEsnR8uuvGzszyY",
                        "type": "message",
                        "role": "assistant",
                        "content": [],
                        "model": "claude-3-5-sonnet-20240620",
                        "stop_reason": None,
                        "stop_sequence": None,
                        "usage": {"input_tokens": 25, "output_tokens": 1},
                    },
                }
            if self.sent_content_block_start is False:
                self.sent_content_block_start = True
                return {
                    "type": "content_block_start",
                    "index": 0,
                    "content_block": {"type": "text", "text": ""},
                }

            for chunk in self.completion_stream:
                if chunk == "None" or chunk is None:
                    raise Exception

                processed_chunk = litellm.AnthropicExperimentalPassThroughConfig().translate_streaming_openai_response_to_anthropic(
                    response=chunk
                )
                if (
                    processed_chunk["type"] == "message_delta"
                    and self.sent_content_block_finish is False
                ):
                    self.holding_chunk = processed_chunk
                    self.sent_content_block_finish = True
                    return {
                        "type": "content_block_stop",
                        "index": 0,
                    }
                elif self.holding_chunk is not None:
                    return_chunk = self.holding_chunk
                    self.holding_chunk = processed_chunk
                    return return_chunk
                else:
                    return processed_chunk
            if self.holding_chunk is not None:
                return_chunk = self.holding_chunk
                self.holding_chunk = None
                return return_chunk
            if self.sent_last_message is False:
                self.sent_last_message = True
                return {"type": "message_stop"}
            raise StopIteration
        except StopIteration:
            if self.sent_last_message is False:
                self.sent_last_message = True
                return {"type": "message_stop"}
            raise StopIteration
        except Exception as e:
            verbose_logger.error(
                "Anthropic Adapter - {}\n{}".format(e, traceback.format_exc())
            )

    async def __anext__(self):
        try:
            if self.sent_first_chunk is False:
                self.sent_first_chunk = True
                return {
                    "type": "message_start",
                    "message": {
                        "id": "msg_1nZdL29xx5MUA1yADyHTEsnR8uuvGzszyY",
                        "type": "message",
                        "role": "assistant",
                        "content": [],
                        "model": "claude-3-5-sonnet-20240620",
                        "stop_reason": None,
                        "stop_sequence": None,
                        "usage": {"input_tokens": 25, "output_tokens": 1},
                    },
                }
            if self.sent_content_block_start is False:
                self.sent_content_block_start = True
                return {
                    "type": "content_block_start",
                    "index": 0,
                    "content_block": {"type": "text", "text": ""},
                }
            async for chunk in self.completion_stream:
                if chunk == "None" or chunk is None:
                    raise Exception
                processed_chunk = litellm.AnthropicExperimentalPassThroughConfig().translate_streaming_openai_response_to_anthropic(
                    response=chunk
                )
                if (
                    processed_chunk["type"] == "message_delta"
                    and self.sent_content_block_finish is False
                ):
                    self.holding_chunk = processed_chunk
                    self.sent_content_block_finish = True
                    return {
                        "type": "content_block_stop",
                        "index": 0,
                    }
                elif self.holding_chunk is not None:
                    return_chunk = self.holding_chunk
                    self.holding_chunk = processed_chunk
                    return return_chunk
                else:
                    return processed_chunk
            if self.holding_chunk is not None:
                return_chunk = self.holding_chunk
                self.holding_chunk = None
                return return_chunk
            if self.sent_last_message is False:
                self.sent_last_message = True
                return {"type": "message_stop"}
            raise StopIteration
        except StopIteration:
            if self.sent_last_message is False:
                self.sent_last_message = True
                return {"type": "message_stop"}
            raise StopAsyncIteration

@@ -73,6 +73,8 @@ def remove_index_from_tool_calls(
 def get_litellm_metadata_from_kwargs(kwargs: dict):
     """
     Helper to get litellm metadata from all litellm request kwargs
+
+    Return `litellm_metadata` if it exists, otherwise return `metadata`
     """
     litellm_params = kwargs.get("litellm_params", {})
     if litellm_params:

@@ -932,6 +932,9 @@ class Logging(LiteLLMLoggingBaseClass):
         self.model_call_details["log_event_type"] = "successful_api_call"
         self.model_call_details["end_time"] = end_time
         self.model_call_details["cache_hit"] = cache_hit
+
+        if self.call_type == CallTypes.anthropic_messages.value:
+            result = self._handle_anthropic_messages_response_logging(result=result)
         ## if model in model cost map - log the response cost
         ## else set cost to None
         if (

@@ -2304,6 +2307,37 @@ class Logging(LiteLLMLoggingBaseClass):
             return complete_streaming_response
         return None

+    def _handle_anthropic_messages_response_logging(self, result: Any) -> ModelResponse:
+        """
+        Handles logging for Anthropic messages responses.
+
+        Args:
+            result: The response object from the model call
+
+        Returns:
+            The response object from the model call
+
+        - For Non-streaming responses, we need to transform the response to a ModelResponse object.
+        - For streaming responses, anthropic_messages handler calls success_handler with an assembled ModelResponse.
+        """
+        if self.stream and isinstance(result, ModelResponse):
+            return result
+
+        result = litellm.AnthropicConfig().transform_response(
+            raw_response=self.model_call_details["httpx_response"],
+            model_response=litellm.ModelResponse(),
+            model=self.model,
+            messages=[],
+            logging_obj=self,
+            optional_params={},
+            api_key="",
+            request_data={},
+            encoding=litellm.encoding,
+            json_mode=False,
+            litellm_params={},
+        )
+        return result
+
+
 def set_callbacks(callback_list, function_id=None):  # noqa: PLR0915
     """

@@ -0,0 +1,179 @@ (new file)

"""
- call /messages on Anthropic API
- Make streaming + non-streaming request - just pass it through direct to Anthropic. No need to do anything special here
- Ensure requests are logged in the DB - stream + non-stream

"""

import json
from typing import Any, AsyncIterator, Dict, Optional, Union, cast

import httpx

import litellm
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
from litellm.llms.base_llm.anthropic_messages.transformation import (
    BaseAnthropicMessagesConfig,
)
from litellm.llms.custom_httpx.http_handler import (
    AsyncHTTPHandler,
    get_async_httpx_client,
)
from litellm.types.router import GenericLiteLLMParams
from litellm.types.utils import ProviderSpecificHeader
from litellm.utils import ProviderConfigManager, client


class AnthropicMessagesHandler:

    @staticmethod
    async def _handle_anthropic_streaming(
        response: httpx.Response,
        request_body: dict,
        litellm_logging_obj: LiteLLMLoggingObj,
    ) -> AsyncIterator:
        """Helper function to handle Anthropic streaming responses using the existing logging handlers"""
        from datetime import datetime

        from litellm.proxy.pass_through_endpoints.streaming_handler import (
            PassThroughStreamingHandler,
        )
        from litellm.proxy.pass_through_endpoints.success_handler import (
            PassThroughEndpointLogging,
        )
        from litellm.proxy.pass_through_endpoints.types import EndpointType

        # Create success handler object
        passthrough_success_handler_obj = PassThroughEndpointLogging()

        # Use the existing streaming handler for Anthropic
        start_time = datetime.now()
        return PassThroughStreamingHandler.chunk_processor(
            response=response,
            request_body=request_body,
            litellm_logging_obj=litellm_logging_obj,
            endpoint_type=EndpointType.ANTHROPIC,
            start_time=start_time,
            passthrough_success_handler_obj=passthrough_success_handler_obj,
            url_route="/v1/messages",
        )


@client
async def anthropic_messages(
    api_key: str,
    model: str,
    stream: bool = False,
    api_base: Optional[str] = None,
    client: Optional[AsyncHTTPHandler] = None,
    custom_llm_provider: Optional[str] = None,
    **kwargs,
) -> Union[Dict[str, Any], AsyncIterator]:
    """
    Makes Anthropic `/v1/messages` API calls In the Anthropic API Spec
    """
    # Use provided client or create a new one
    optional_params = GenericLiteLLMParams(**kwargs)
    model, _custom_llm_provider, dynamic_api_key, dynamic_api_base = (
        litellm.get_llm_provider(
            model=model,
            custom_llm_provider=custom_llm_provider,
            api_base=optional_params.api_base,
            api_key=optional_params.api_key,
        )
    )
    anthropic_messages_provider_config: Optional[BaseAnthropicMessagesConfig] = (
        ProviderConfigManager.get_provider_anthropic_messages_config(
            model=model,
            provider=litellm.LlmProviders(_custom_llm_provider),
        )
    )
    if anthropic_messages_provider_config is None:
        raise ValueError(
            f"Anthropic messages provider config not found for model: {model}"
        )
    if client is None or not isinstance(client, AsyncHTTPHandler):
        async_httpx_client = get_async_httpx_client(
            llm_provider=litellm.LlmProviders.ANTHROPIC
        )
    else:
        async_httpx_client = client

    litellm_logging_obj: LiteLLMLoggingObj = kwargs.get("litellm_logging_obj", None)

    # Prepare headers
    provider_specific_header = cast(
        Optional[ProviderSpecificHeader], kwargs.get("provider_specific_header", None)
    )
    extra_headers = (
        provider_specific_header.get("extra_headers", {})
        if provider_specific_header
        else {}
    )
    headers = anthropic_messages_provider_config.validate_environment(
        headers=extra_headers or {},
        model=model,
        api_key=api_key,
    )

    litellm_logging_obj.update_environment_variables(
        model=model,
        optional_params=dict(optional_params),
        litellm_params={
            "metadata": kwargs.get("metadata", {}),
            "preset_cache_key": None,
            "stream_response": {},
            **optional_params.model_dump(exclude_unset=True),
        },
        custom_llm_provider=_custom_llm_provider,
    )
    litellm_logging_obj.model_call_details.update(kwargs)

    # Prepare request body
    request_body = kwargs.copy()
    request_body = {
        k: v
        for k, v in request_body.items()
        if k
        in anthropic_messages_provider_config.get_supported_anthropic_messages_params(
            model=model
        )
    }
    request_body["stream"] = stream
    request_body["model"] = model
    litellm_logging_obj.stream = stream

    # Make the request
    request_url = anthropic_messages_provider_config.get_complete_url(
        api_base=api_base, model=model
    )

    litellm_logging_obj.pre_call(
        input=[{"role": "user", "content": json.dumps(request_body)}],
        api_key="",
        additional_args={
            "complete_input_dict": request_body,
            "api_base": str(request_url),
            "headers": headers,
        },
    )

    response = await async_httpx_client.post(
        url=request_url,
        headers=headers,
        data=json.dumps(request_body),
        stream=stream,
    )
    response.raise_for_status()

    # used for logging + cost tracking
    litellm_logging_obj.model_call_details["httpx_response"] = response

    if stream:
        return await AnthropicMessagesHandler._handle_anthropic_streaming(
            response=response,
            request_body=request_body,
            litellm_logging_obj=litellm_logging_obj,
        )
    else:
        return response.json()

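For orientation, a minimal sketch of how this new handler could be exercised directly (assumptions: an ANTHROPIC_API_KEY environment variable is set and the model name below is available on the account; the router normally reaches this function via `aanthropic_messages`):

import asyncio
import os

import litellm


async def main():
    # kwargs are filtered against get_supported_anthropic_messages_params(),
    # so the payload below follows the raw Anthropic /v1/messages shape
    response = await litellm.anthropic_messages(
        api_key=os.environ["ANTHROPIC_API_KEY"],
        model="anthropic/claude-3-5-sonnet-20241022",
        messages=[{"role": "user", "content": "Hello!"}],
        max_tokens=256,
        stream=False,
    )
    print(response)  # plain dict in the Anthropic response format


asyncio.run(main())
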
@@ -0,0 +1,47 @@ (new file)

from typing import Optional

from litellm.llms.base_llm.anthropic_messages.transformation import (
    BaseAnthropicMessagesConfig,
)

DEFAULT_ANTHROPIC_API_BASE = "https://api.anthropic.com"
DEFAULT_ANTHROPIC_API_VERSION = "2023-06-01"


class AnthropicMessagesConfig(BaseAnthropicMessagesConfig):
    def get_supported_anthropic_messages_params(self, model: str) -> list:
        return [
            "messages",
            "model",
            "system",
            "max_tokens",
            "stop_sequences",
            "temperature",
            "top_p",
            "top_k",
            "tools",
            "tool_choice",
            "thinking",
            # TODO: Add Anthropic `metadata` support
            # "metadata",
        ]

    def get_complete_url(self, api_base: Optional[str], model: str) -> str:
        api_base = api_base or DEFAULT_ANTHROPIC_API_BASE
        if not api_base.endswith("/v1/messages"):
            api_base = f"{api_base}/v1/messages"
        return api_base

    def validate_environment(
        self,
        headers: dict,
        model: str,
        api_key: Optional[str] = None,
    ) -> dict:
        if "x-api-key" not in headers:
            headers["x-api-key"] = api_key
        if "anthropic-version" not in headers:
            headers["anthropic-version"] = DEFAULT_ANTHROPIC_API_VERSION
        if "content-type" not in headers:
            headers["content-type"] = "application/json"
        return headers

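A quick illustration of what this provider config resolves to at request time (a sketch; the printed values follow directly from the defaults above, and the API key is a placeholder):

from litellm.llms.anthropic.experimental_pass_through.messages.transformation import (
    AnthropicMessagesConfig,
)

config = AnthropicMessagesConfig()

# api_base=None falls back to the public Anthropic endpoint and /v1/messages is appended
print(config.get_complete_url(api_base=None, model="claude-3-5-sonnet-20241022"))
# -> https://api.anthropic.com/v1/messages

# missing x-api-key / anthropic-version / content-type headers are filled in
print(
    config.validate_environment(
        headers={}, model="claude-3-5-sonnet-20241022", api_key="sk-ant-placeholder"
    )
)
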
@@ -1,412 +0,0 @@ (entire file removed)

import json
from typing import List, Literal, Optional, Tuple, Union

from openai.types.chat.chat_completion_chunk import Choice as OpenAIStreamingChoice

from litellm.types.llms.anthropic import (
    AllAnthropicToolsValues,
    AnthopicMessagesAssistantMessageParam,
    AnthropicFinishReason,
    AnthropicMessagesRequest,
    AnthropicMessagesToolChoice,
    AnthropicMessagesUserMessageParam,
    AnthropicResponse,
    AnthropicResponseContentBlockText,
    AnthropicResponseContentBlockToolUse,
    AnthropicResponseUsageBlock,
    ContentBlockDelta,
    ContentJsonBlockDelta,
    ContentTextBlockDelta,
    MessageBlockDelta,
    MessageDelta,
    UsageDelta,
)
from litellm.types.llms.openai import (
    AllMessageValues,
    ChatCompletionAssistantMessage,
    ChatCompletionAssistantToolCall,
    ChatCompletionImageObject,
    ChatCompletionImageUrlObject,
    ChatCompletionRequest,
    ChatCompletionSystemMessage,
    ChatCompletionTextObject,
    ChatCompletionToolCallFunctionChunk,
    ChatCompletionToolChoiceFunctionParam,
    ChatCompletionToolChoiceObjectParam,
    ChatCompletionToolChoiceValues,
    ChatCompletionToolMessage,
    ChatCompletionToolParam,
    ChatCompletionToolParamFunctionChunk,
    ChatCompletionUserMessage,
)
from litellm.types.utils import Choices, ModelResponse, Usage


class AnthropicExperimentalPassThroughConfig:
    def __init__(self):
        pass

    ### FOR [BETA] `/v1/messages` endpoint support

    def translatable_anthropic_params(self) -> List:
        """
        Which anthropic params, we need to translate to the openai format.
        """
        return ["messages", "metadata", "system", "tool_choice", "tools"]

    def translate_anthropic_messages_to_openai(  # noqa: PLR0915
        self,
        messages: List[
            Union[
                AnthropicMessagesUserMessageParam,
                AnthopicMessagesAssistantMessageParam,
            ]
        ],
    ) -> List:
        new_messages: List[AllMessageValues] = []
        for m in messages:
            user_message: Optional[ChatCompletionUserMessage] = None
            tool_message_list: List[ChatCompletionToolMessage] = []
            new_user_content_list: List[
                Union[ChatCompletionTextObject, ChatCompletionImageObject]
            ] = []
            ## USER MESSAGE ##
            if m["role"] == "user":
                ## translate user message
                message_content = m.get("content")
                if message_content and isinstance(message_content, str):
                    user_message = ChatCompletionUserMessage(
                        role="user", content=message_content
                    )
                elif message_content and isinstance(message_content, list):
                    for content in message_content:
                        if content["type"] == "text":
                            text_obj = ChatCompletionTextObject(
                                type="text", text=content["text"]
                            )
                            new_user_content_list.append(text_obj)
                        elif content["type"] == "image":
                            image_url = ChatCompletionImageUrlObject(
                                url=f"data:{content['type']};base64,{content['source']}"
                            )
                            image_obj = ChatCompletionImageObject(
                                type="image_url", image_url=image_url
                            )

                            new_user_content_list.append(image_obj)
                        elif content["type"] == "tool_result":
                            if "content" not in content:
                                tool_result = ChatCompletionToolMessage(
                                    role="tool",
                                    tool_call_id=content["tool_use_id"],
                                    content="",
                                )
                                tool_message_list.append(tool_result)
                            elif isinstance(content["content"], str):
                                tool_result = ChatCompletionToolMessage(
                                    role="tool",
                                    tool_call_id=content["tool_use_id"],
                                    content=content["content"],
                                )
                                tool_message_list.append(tool_result)
                            elif isinstance(content["content"], list):
                                for c in content["content"]:
                                    if c["type"] == "text":
                                        tool_result = ChatCompletionToolMessage(
                                            role="tool",
                                            tool_call_id=content["tool_use_id"],
                                            content=c["text"],
                                        )
                                        tool_message_list.append(tool_result)
                                    elif c["type"] == "image":
                                        image_str = (
                                            f"data:{c['type']};base64,{c['source']}"
                                        )
                                        tool_result = ChatCompletionToolMessage(
                                            role="tool",
                                            tool_call_id=content["tool_use_id"],
                                            content=image_str,
                                        )
                                        tool_message_list.append(tool_result)

            if user_message is not None:
                new_messages.append(user_message)

            if len(new_user_content_list) > 0:
                new_messages.append({"role": "user", "content": new_user_content_list})  # type: ignore

            if len(tool_message_list) > 0:
                new_messages.extend(tool_message_list)

            ## ASSISTANT MESSAGE ##
            assistant_message_str: Optional[str] = None
            tool_calls: List[ChatCompletionAssistantToolCall] = []
            if m["role"] == "assistant":
                if isinstance(m["content"], str):
                    assistant_message_str = m["content"]
                elif isinstance(m["content"], list):
                    for content in m["content"]:
                        if content["type"] == "text":
                            if assistant_message_str is None:
                                assistant_message_str = content["text"]
                            else:
                                assistant_message_str += content["text"]
                        elif content["type"] == "tool_use":
                            function_chunk = ChatCompletionToolCallFunctionChunk(
                                name=content["name"],
                                arguments=json.dumps(content["input"]),
                            )

                            tool_calls.append(
                                ChatCompletionAssistantToolCall(
                                    id=content["id"],
                                    type="function",
                                    function=function_chunk,
                                )
                            )

            if assistant_message_str is not None or len(tool_calls) > 0:
                assistant_message = ChatCompletionAssistantMessage(
                    role="assistant",
                    content=assistant_message_str,
                )
                if len(tool_calls) > 0:
                    assistant_message["tool_calls"] = tool_calls
                new_messages.append(assistant_message)

        return new_messages

    def translate_anthropic_tool_choice_to_openai(
        self, tool_choice: AnthropicMessagesToolChoice
    ) -> ChatCompletionToolChoiceValues:
        if tool_choice["type"] == "any":
            return "required"
        elif tool_choice["type"] == "auto":
            return "auto"
        elif tool_choice["type"] == "tool":
            tc_function_param = ChatCompletionToolChoiceFunctionParam(
                name=tool_choice.get("name", "")
            )
            return ChatCompletionToolChoiceObjectParam(
                type="function", function=tc_function_param
            )
        else:
            raise ValueError(
                "Incompatible tool choice param submitted - {}".format(tool_choice)
            )

    def translate_anthropic_tools_to_openai(
        self, tools: List[AllAnthropicToolsValues]
    ) -> List[ChatCompletionToolParam]:
        new_tools: List[ChatCompletionToolParam] = []
        mapped_tool_params = ["name", "input_schema", "description"]
        for tool in tools:
            function_chunk = ChatCompletionToolParamFunctionChunk(
                name=tool["name"],
            )
            if "input_schema" in tool:
                function_chunk["parameters"] = tool["input_schema"]  # type: ignore
            if "description" in tool:
                function_chunk["description"] = tool["description"]  # type: ignore

            for k, v in tool.items():
                if k not in mapped_tool_params:  # pass additional computer kwargs
                    function_chunk.setdefault("parameters", {}).update({k: v})
            new_tools.append(
                ChatCompletionToolParam(type="function", function=function_chunk)
            )

        return new_tools

    def translate_anthropic_to_openai(
        self, anthropic_message_request: AnthropicMessagesRequest
    ) -> ChatCompletionRequest:
        """
        This is used by the beta Anthropic Adapter, for translating anthropic `/v1/messages` requests to the openai format.
        """
        new_messages: List[AllMessageValues] = []

        ## CONVERT ANTHROPIC MESSAGES TO OPENAI
        new_messages = self.translate_anthropic_messages_to_openai(
            messages=anthropic_message_request["messages"]
        )
        ## ADD SYSTEM MESSAGE TO MESSAGES
        if "system" in anthropic_message_request:
            new_messages.insert(
                0,
                ChatCompletionSystemMessage(
                    role="system", content=anthropic_message_request["system"]
                ),
            )

        new_kwargs: ChatCompletionRequest = {
            "model": anthropic_message_request["model"],
            "messages": new_messages,
        }
        ## CONVERT METADATA (user_id)
        if "metadata" in anthropic_message_request:
            if "user_id" in anthropic_message_request["metadata"]:
                new_kwargs["user"] = anthropic_message_request["metadata"]["user_id"]

        # Pass litellm proxy specific metadata
        if "litellm_metadata" in anthropic_message_request:
            # metadata will be passed to litellm.acompletion(), it's a litellm_param
            new_kwargs["metadata"] = anthropic_message_request.pop("litellm_metadata")

        ## CONVERT TOOL CHOICE
        if "tool_choice" in anthropic_message_request:
            new_kwargs["tool_choice"] = self.translate_anthropic_tool_choice_to_openai(
                tool_choice=anthropic_message_request["tool_choice"]
            )
        ## CONVERT TOOLS
        if "tools" in anthropic_message_request:
            new_kwargs["tools"] = self.translate_anthropic_tools_to_openai(
                tools=anthropic_message_request["tools"]
            )

        translatable_params = self.translatable_anthropic_params()
        for k, v in anthropic_message_request.items():
            if k not in translatable_params:  # pass remaining params as is
                new_kwargs[k] = v  # type: ignore

        return new_kwargs

    def _translate_openai_content_to_anthropic(
        self, choices: List[Choices]
    ) -> List[
        Union[AnthropicResponseContentBlockText, AnthropicResponseContentBlockToolUse]
    ]:
        new_content: List[
            Union[
                AnthropicResponseContentBlockText, AnthropicResponseContentBlockToolUse
            ]
        ] = []
        for choice in choices:
            if (
                choice.message.tool_calls is not None
                and len(choice.message.tool_calls) > 0
            ):
                for tool_call in choice.message.tool_calls:
                    new_content.append(
                        AnthropicResponseContentBlockToolUse(
                            type="tool_use",
                            id=tool_call.id,
                            name=tool_call.function.name or "",
                            input=json.loads(tool_call.function.arguments),
                        )
                    )
            elif choice.message.content is not None:
                new_content.append(
                    AnthropicResponseContentBlockText(
                        type="text", text=choice.message.content
                    )
                )

        return new_content

    def _translate_openai_finish_reason_to_anthropic(
        self, openai_finish_reason: str
    ) -> AnthropicFinishReason:
        if openai_finish_reason == "stop":
            return "end_turn"
        elif openai_finish_reason == "length":
            return "max_tokens"
        elif openai_finish_reason == "tool_calls":
            return "tool_use"
        return "end_turn"

    def translate_openai_response_to_anthropic(
        self, response: ModelResponse
    ) -> AnthropicResponse:
        ## translate content block
        anthropic_content = self._translate_openai_content_to_anthropic(choices=response.choices)  # type: ignore
        ## extract finish reason
        anthropic_finish_reason = self._translate_openai_finish_reason_to_anthropic(
            openai_finish_reason=response.choices[0].finish_reason  # type: ignore
        )
        # extract usage
        usage: Usage = getattr(response, "usage")
        anthropic_usage = AnthropicResponseUsageBlock(
            input_tokens=usage.prompt_tokens or 0,
            output_tokens=usage.completion_tokens or 0,
        )
        translated_obj = AnthropicResponse(
            id=response.id,
            type="message",
            role="assistant",
            model=response.model or "unknown-model",
            stop_sequence=None,
            usage=anthropic_usage,
            content=anthropic_content,
            stop_reason=anthropic_finish_reason,
        )

        return translated_obj

    def _translate_streaming_openai_chunk_to_anthropic(
        self, choices: List[OpenAIStreamingChoice]
    ) -> Tuple[
        Literal["text_delta", "input_json_delta"],
        Union[ContentTextBlockDelta, ContentJsonBlockDelta],
    ]:
        text: str = ""
        partial_json: Optional[str] = None
        for choice in choices:
            if choice.delta.content is not None:
                text += choice.delta.content
            elif choice.delta.tool_calls is not None:
                partial_json = ""
                for tool in choice.delta.tool_calls:
                    if (
                        tool.function is not None
                        and tool.function.arguments is not None
                    ):
                        partial_json += tool.function.arguments

        if partial_json is not None:
            return "input_json_delta", ContentJsonBlockDelta(
                type="input_json_delta", partial_json=partial_json
            )
        else:
            return "text_delta", ContentTextBlockDelta(type="text_delta", text=text)

    def translate_streaming_openai_response_to_anthropic(
        self, response: ModelResponse
    ) -> Union[ContentBlockDelta, MessageBlockDelta]:
        ## base case - final chunk w/ finish reason
        if response.choices[0].finish_reason is not None:
            delta = MessageDelta(
                stop_reason=self._translate_openai_finish_reason_to_anthropic(
                    response.choices[0].finish_reason
                ),
            )
            if getattr(response, "usage", None) is not None:
                litellm_usage_chunk: Optional[Usage] = response.usage  # type: ignore
            elif (
                hasattr(response, "_hidden_params")
                and "usage" in response._hidden_params
            ):
                litellm_usage_chunk = response._hidden_params["usage"]
            else:
                litellm_usage_chunk = None
            if litellm_usage_chunk is not None:
                usage_delta = UsageDelta(
                    input_tokens=litellm_usage_chunk.prompt_tokens or 0,
                    output_tokens=litellm_usage_chunk.completion_tokens or 0,
                )
            else:
                usage_delta = UsageDelta(input_tokens=0, output_tokens=0)
            return MessageBlockDelta(
                type="message_delta", delta=delta, usage=usage_delta
            )
        (
            type_of_content,
            content_block_delta,
        ) = self._translate_streaming_openai_chunk_to_anthropic(
            choices=response.choices  # type: ignore
        )
        return ContentBlockDelta(
            type="content_block_delta",
            index=response.choices[0].index,
            delta=content_block_delta,
        )

litellm/llms/base_llm/anthropic_messages/transformation.py (new file, 35 lines)
@@ -0,0 +1,35 @@

from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Any, Optional

if TYPE_CHECKING:
    from litellm.litellm_core_utils.litellm_logging import Logging as _LiteLLMLoggingObj

    LiteLLMLoggingObj = _LiteLLMLoggingObj
else:
    LiteLLMLoggingObj = Any


class BaseAnthropicMessagesConfig(ABC):
    @abstractmethod
    def validate_environment(
        self,
        headers: dict,
        model: str,
        api_key: Optional[str] = None,
    ) -> dict:
        pass

    @abstractmethod
    def get_complete_url(self, api_base: Optional[str], model: str) -> str:
        """
        OPTIONAL

        Get the complete url for the request

        Some providers need `model` in `api_base`
        """
        return api_base or ""

    @abstractmethod
    def get_supported_anthropic_messages_params(self, model: str) -> list:
        pass

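This abstract base is the extension point for wiring other providers into the `/v1/messages` route. A sketch of what a hypothetical provider config might look like (the provider name, URL, and parameter list here are illustrative, not part of this commit):

from typing import Optional

from litellm.llms.base_llm.anthropic_messages.transformation import (
    BaseAnthropicMessagesConfig,
)


class MyProviderMessagesConfig(BaseAnthropicMessagesConfig):
    def validate_environment(
        self, headers: dict, model: str, api_key: Optional[str] = None
    ) -> dict:
        # fill in any auth / content-type headers the provider expects
        headers.setdefault("authorization", f"Bearer {api_key}")
        headers.setdefault("content-type", "application/json")
        return headers

    def get_complete_url(self, api_base: Optional[str], model: str) -> str:
        # hypothetical endpoint shape
        return (api_base or "https://example-provider.invalid") + "/v1/messages"

    def get_supported_anthropic_messages_params(self, model: str) -> list:
        return ["messages", "model", "max_tokens", "temperature", "stream"]
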
@@ -1963,7 +1963,7 @@ class ProxyException(Exception):
         code: Optional[Union[int, str]] = None,
         headers: Optional[Dict[str, str]] = None,
     ):
-        self.message = message
+        self.message = str(message)
         self.type = type
         self.param = param

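A small illustration of why the constructor now coerces `message` to `str` (a sketch; the assumption here is that some callers pass non-string objects, e.g. dict-shaped provider error bodies):

from litellm.proxy._types import ProxyException

# A dict-like error body no longer breaks downstream string handling
exc = ProxyException(
    message={"error": "bad request"},  # non-string message
    type="invalid_request_error",
    param="messages",
    code=400,
)
assert isinstance(exc.message, str)
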
litellm/proxy/anthropic_endpoints/endpoints.py (new file, 252 lines)
@@ -0,0 +1,252 @@

"""
Unified /v1/messages endpoint - (Anthropic Spec)
"""

import asyncio
import json
import time
import traceback

from fastapi import APIRouter, Depends, HTTPException, Request, Response, status
from fastapi.responses import StreamingResponse

import litellm
from litellm._logging import verbose_proxy_logger
from litellm.proxy._types import *
from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
from litellm.proxy.common_utils.http_parsing_utils import _read_request_body
from litellm.proxy.litellm_pre_call_utils import add_litellm_data_to_request
from litellm.proxy.utils import ProxyLogging

router = APIRouter()


async def async_data_generator_anthropic(
    response,
    user_api_key_dict: UserAPIKeyAuth,
    request_data: dict,
    proxy_logging_obj: ProxyLogging,
):
    verbose_proxy_logger.debug("inside generator")
    try:
        time.time()
        async for chunk in response:
            verbose_proxy_logger.debug(
                "async_data_generator: received streaming chunk - {}".format(chunk)
            )
            ### CALL HOOKS ### - modify outgoing data
            chunk = await proxy_logging_obj.async_post_call_streaming_hook(
                user_api_key_dict=user_api_key_dict, response=chunk
            )

            yield chunk
    except Exception as e:
        verbose_proxy_logger.exception(
            "litellm.proxy.proxy_server.async_data_generator(): Exception occured - {}".format(
                str(e)
            )
        )
        await proxy_logging_obj.post_call_failure_hook(
            user_api_key_dict=user_api_key_dict,
            original_exception=e,
            request_data=request_data,
        )
        verbose_proxy_logger.debug(
            f"\033[1;31mAn error occurred: {e}\n\n Debug this by setting `--debug`, e.g. `litellm --model gpt-3.5-turbo --debug`"
        )

        if isinstance(e, HTTPException):
            raise e
        else:
            error_traceback = traceback.format_exc()
            error_msg = f"{str(e)}\n\n{error_traceback}"

            proxy_exception = ProxyException(
                message=getattr(e, "message", error_msg),
                type=getattr(e, "type", "None"),
                param=getattr(e, "param", "None"),
                code=getattr(e, "status_code", 500),
            )
            error_returned = json.dumps({"error": proxy_exception.to_dict()})
            yield f"data: {error_returned}\n\n"


@router.post(
    "/v1/messages",
    tags=["[beta] Anthropic `/v1/messages`"],
    dependencies=[Depends(user_api_key_auth)],
    include_in_schema=False,
)
async def anthropic_response(  # noqa: PLR0915
    fastapi_response: Response,
    request: Request,
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
    """
    Use `{PROXY_BASE_URL}/anthropic/v1/messages` instead - [Docs](https://docs.litellm.ai/docs/anthropic_completion).

    This was a BETA endpoint that calls 100+ LLMs in the anthropic format.
    """
    from litellm.proxy.proxy_server import (
        general_settings,
        get_custom_headers,
        llm_router,
        proxy_config,
        proxy_logging_obj,
        user_api_base,
        user_max_tokens,
        user_model,
        user_request_timeout,
        user_temperature,
        version,
    )

    request_data = await _read_request_body(request=request)
    data: dict = {**request_data}
    try:
        data["model"] = (
            general_settings.get("completion_model", None)  # server default
            or user_model  # model name passed via cli args
            or data.get("model", None)  # default passed in http request
        )
        if user_model:
            data["model"] = user_model

        data = await add_litellm_data_to_request(
            data=data,  # type: ignore
            request=request,
            general_settings=general_settings,
            user_api_key_dict=user_api_key_dict,
            version=version,
            proxy_config=proxy_config,
        )

        # override with user settings, these are params passed via cli
        if user_temperature:
            data["temperature"] = user_temperature
        if user_request_timeout:
            data["request_timeout"] = user_request_timeout
        if user_max_tokens:
            data["max_tokens"] = user_max_tokens
        if user_api_base:
            data["api_base"] = user_api_base

        ### MODEL ALIAS MAPPING ###
        # check if model name in model alias map
        # get the actual model name
        if data["model"] in litellm.model_alias_map:
            data["model"] = litellm.model_alias_map[data["model"]]

        ### CALL HOOKS ### - modify incoming data before calling the model
        data = await proxy_logging_obj.pre_call_hook(  # type: ignore
            user_api_key_dict=user_api_key_dict, data=data, call_type="text_completion"
        )

        ### ROUTE THE REQUESTs ###
        router_model_names = llm_router.model_names if llm_router is not None else []

        # skip router if user passed their key
        if (
            llm_router is not None and data["model"] in router_model_names
        ):  # model in router model list
            llm_response = asyncio.create_task(llm_router.aanthropic_messages(**data))
        elif (
            llm_router is not None
            and llm_router.model_group_alias is not None
            and data["model"] in llm_router.model_group_alias
        ):  # model set in model_group_alias
            llm_response = asyncio.create_task(llm_router.aanthropic_messages(**data))
        elif (
            llm_router is not None and data["model"] in llm_router.deployment_names
        ):  # model in router deployments, calling a specific deployment on the router
            llm_response = asyncio.create_task(
                llm_router.aanthropic_messages(**data, specific_deployment=True)
            )
        elif (
            llm_router is not None and data["model"] in llm_router.get_model_ids()
        ):  # model in router model list
            llm_response = asyncio.create_task(llm_router.aanthropic_messages(**data))
        elif (
            llm_router is not None
            and data["model"] not in router_model_names
            and (
                llm_router.default_deployment is not None
                or len(llm_router.pattern_router.patterns) > 0
            )
        ):  # model in router deployments, calling a specific deployment on the router
            llm_response = asyncio.create_task(llm_router.aanthropic_messages(**data))
        elif user_model is not None:  # `litellm --model <your-model-name>`
            llm_response = asyncio.create_task(litellm.anthropic_messages(**data))
        else:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail={
                    "error": "completion: Invalid model name passed in model="
                    + data.get("model", "")
                },
            )

        # Await the llm_response task
        response = await llm_response

        hidden_params = getattr(response, "_hidden_params", {}) or {}
        model_id = hidden_params.get("model_id", None) or ""
        cache_key = hidden_params.get("cache_key", None) or ""
        api_base = hidden_params.get("api_base", None) or ""
        response_cost = hidden_params.get("response_cost", None) or ""

        ### ALERTING ###
        asyncio.create_task(
            proxy_logging_obj.update_request_status(
                litellm_call_id=data.get("litellm_call_id", ""), status="success"
            )
        )

        verbose_proxy_logger.debug("final response: %s", response)

        fastapi_response.headers.update(
            get_custom_headers(
                user_api_key_dict=user_api_key_dict,
                model_id=model_id,
                cache_key=cache_key,
                api_base=api_base,
                version=version,
                response_cost=response_cost,
                request_data=data,
                hidden_params=hidden_params,
            )
        )

        if (
            "stream" in data and data["stream"] is True
        ):  # use generate_responses to stream responses
            selected_data_generator = async_data_generator_anthropic(
                response=response,
                user_api_key_dict=user_api_key_dict,
                request_data=data,
                proxy_logging_obj=proxy_logging_obj,
            )

            return StreamingResponse(
                selected_data_generator,  # type: ignore
                media_type="text/event-stream",
            )

        verbose_proxy_logger.info("\nResponse from Litellm:\n{}".format(response))
        return response
    except Exception as e:
        await proxy_logging_obj.post_call_failure_hook(
            user_api_key_dict=user_api_key_dict, original_exception=e, request_data=data
        )
        verbose_proxy_logger.exception(
            "litellm.proxy.proxy_server.anthropic_response(): Exception occured - {}".format(
                str(e)
            )
        )
        error_msg = f"{str(e)}"
        raise ProxyException(
            message=getattr(e, "message", error_msg),
            type=getattr(e, "type", "None"),
            param=getattr(e, "param", "None"),
            code=getattr(e, "status_code", 500),
        )

@@ -4,6 +4,26 @@ model_list:
       model: openai/fake
       api_key: fake-key
       api_base: https://exampleopenaiendpoint-production.up.railway.app/
+  - model_name: claude-3-5-sonnet-20241022
+    litellm_params:
+      model: anthropic/claude-3-5-sonnet-20241022
+      api_key: os.environ/ANTHROPIC_API_KEY
+  - model_name: claude-special-alias
+    litellm_params:
+      model: anthropic/claude-3-haiku-20240307
+      api_key: os.environ/ANTHROPIC_API_KEY
+  - model_name: claude-3-5-sonnet-20241022
+    litellm_params:
+      model: anthropic/claude-3-5-sonnet-20241022
+      api_key: os.environ/ANTHROPIC_API_KEY
+  - model_name: claude-3-7-sonnet-20250219
+    litellm_params:
+      model: anthropic/claude-3-7-sonnet-20250219
+      api_key: os.environ/ANTHROPIC_API_KEY
+  - model_name: anthropic/*
+    litellm_params:
+      model: anthropic/*
+      api_key: os.environ/ANTHROPIC_API_KEY
 general_settings:
   master_key: sk-1234
   custom_auth: custom_auth_basic.user_api_key_auth

@@ -4,7 +4,22 @@ model_list:
       model: openai/my-fake-model
       api_key: my-fake-key
       api_base: https://exampleopenaiendpoint-production.up.railway.app/
+  - model_name: claude-special-alias
+    litellm_params:
+      model: anthropic/claude-3-haiku-20240307
+      api_key: os.environ/ANTHROPIC_API_KEY
+  - model_name: claude-3-5-sonnet-20241022
+    litellm_params:
+      model: anthropic/claude-3-5-sonnet-20241022
+      api_key: os.environ/ANTHROPIC_API_KEY
+  - model_name: claude-3-7-sonnet-20250219
+    litellm_params:
+      model: anthropic/claude-3-7-sonnet-20250219
+      api_key: os.environ/ANTHROPIC_API_KEY
+  - model_name: anthropic/*
+    litellm_params:
+      model: anthropic/*
+      api_key: os.environ/ANTHROPIC_API_KEY
+
 general_settings:
   store_model_in_db: true

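With a config like the ones above loaded into the proxy, the unified endpoint can be exercised roughly like this (a sketch; it assumes the proxy is running locally on port 4000 and that the `sk-1234` master key from the first sample config is used as the bearer token):

import httpx

response = httpx.post(
    "http://localhost:4000/v1/messages",
    headers={
        "Authorization": "Bearer sk-1234",  # litellm proxy virtual key / master key
        "content-type": "application/json",
    },
    json={
        "model": "claude-3-5-sonnet-20241022",  # a model_name from the config
        "max_tokens": 256,
        "messages": [{"role": "user", "content": "Hello from the unified /v1/messages route"}],
    },
)
print(response.json())
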
@@ -120,6 +120,7 @@ from litellm.proxy._types import *
 from litellm.proxy.analytics_endpoints.analytics_endpoints import (
     router as analytics_router,
 )
+from litellm.proxy.anthropic_endpoints.endpoints import router as anthropic_router
 from litellm.proxy.auth.auth_checks import log_db_metrics
 from litellm.proxy.auth.auth_utils import check_response_size_is_safe
 from litellm.proxy.auth.handle_jwt import JWTHandler

@@ -3065,58 +3066,6 @@ async def async_data_generator(
         yield f"data: {error_returned}\n\n"


-async def async_data_generator_anthropic(
-    response, user_api_key_dict: UserAPIKeyAuth, request_data: dict
-):
-    verbose_proxy_logger.debug("inside generator")
-    try:
-        time.time()
-        async for chunk in response:
-            verbose_proxy_logger.debug(
-                "async_data_generator: received streaming chunk - {}".format(chunk)
-            )
-            ### CALL HOOKS ### - modify outgoing data
-            chunk = await proxy_logging_obj.async_post_call_streaming_hook(
-                user_api_key_dict=user_api_key_dict, response=chunk
-            )
-
-            event_type = chunk.get("type")
-
-            try:
-                yield f"event: {event_type}\ndata:{json.dumps(chunk)}\n\n"
-            except Exception as e:
-                yield f"event: {event_type}\ndata:{str(e)}\n\n"
-    except Exception as e:
-        verbose_proxy_logger.exception(
-            "litellm.proxy.proxy_server.async_data_generator(): Exception occured - {}".format(
-                str(e)
-            )
-        )
-        await proxy_logging_obj.post_call_failure_hook(
-            user_api_key_dict=user_api_key_dict,
-            original_exception=e,
-            request_data=request_data,
-        )
-        verbose_proxy_logger.debug(
-            f"\033[1;31mAn error occurred: {e}\n\n Debug this by setting `--debug`, e.g. `litellm --model gpt-3.5-turbo --debug`"
-        )
-
-        if isinstance(e, HTTPException):
-            raise e
-        else:
-            error_traceback = traceback.format_exc()
-            error_msg = f"{str(e)}\n\n{error_traceback}"
-
-            proxy_exception = ProxyException(
-                message=getattr(e, "message", error_msg),
-                type=getattr(e, "type", "None"),
-                param=getattr(e, "param", "None"),
-                code=getattr(e, "status_code", 500),
-            )
-            error_returned = json.dumps({"error": proxy_exception.to_dict()})
-            yield f"data: {error_returned}\n\n"
-
-
 def select_data_generator(
     response, user_api_key_dict: UserAPIKeyAuth, request_data: dict
 ):

@@ -5524,224 +5473,6 @@ async def moderations(
     )


-#### ANTHROPIC ENDPOINTS ####
-
-
-@router.post(
-    "/v1/messages",
-    tags=["[beta] Anthropic `/v1/messages`"],
-    dependencies=[Depends(user_api_key_auth)],
-    response_model=AnthropicResponse,
-    include_in_schema=False,
-)
-async def anthropic_response(  # noqa: PLR0915
-    anthropic_data: AnthropicMessagesRequest,
-    fastapi_response: Response,
-    request: Request,
-    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
-):
-    """
-    🚨 DEPRECATED ENDPOINT🚨
-
-    Use `{PROXY_BASE_URL}/anthropic/v1/messages` instead - [Docs](https://docs.litellm.ai/docs/anthropic_completion).
-
-    This was a BETA endpoint that calls 100+ LLMs in the anthropic format.
-    """
-    from litellm import adapter_completion
-    from litellm.adapters.anthropic_adapter import anthropic_adapter
-
-    litellm.adapters = [{"id": "anthropic", "adapter": anthropic_adapter}]
-
-    global user_temperature, user_request_timeout, user_max_tokens, user_api_base
-    request_data = await _read_request_body(request=request)
-    data: dict = {**request_data, "adapter_id": "anthropic"}
-    try:
-        data["model"] = (
-            general_settings.get("completion_model", None)  # server default
-            or user_model  # model name passed via cli args
-            or data.get("model", None)  # default passed in http request
-        )
-        if user_model:
-            data["model"] = user_model
-
-        data = await add_litellm_data_to_request(
-            data=data,  # type: ignore
-            request=request,
-            general_settings=general_settings,
-            user_api_key_dict=user_api_key_dict,
-            version=version,
-            proxy_config=proxy_config,
-        )
-
-        # override with user settings, these are params passed via cli
-        if user_temperature:
-            data["temperature"] = user_temperature
-        if user_request_timeout:
-            data["request_timeout"] = user_request_timeout
-        if user_max_tokens:
-            data["max_tokens"] = user_max_tokens
-        if user_api_base:
-            data["api_base"] = user_api_base
-
-        ### MODEL ALIAS MAPPING ###
-        # check if model name in model alias map
-        # get the actual model name
-        if data["model"] in litellm.model_alias_map:
-            data["model"] = litellm.model_alias_map[data["model"]]
-
-        ### CALL HOOKS ### - modify incoming data before calling the model
-        data = await proxy_logging_obj.pre_call_hook(  # type: ignore
-            user_api_key_dict=user_api_key_dict, data=data, call_type="text_completion"
-        )
-
-        ### ROUTE THE REQUESTs ###
-        router_model_names = llm_router.model_names if llm_router is not None else []
-        # skip router if user passed their key
-        if "api_key" in data:
-            llm_response = asyncio.create_task(litellm.aadapter_completion(**data))
-        elif (
-            llm_router is not None and data["model"] in router_model_names
-        ):  # model in router model list
-            llm_response = asyncio.create_task(llm_router.aadapter_completion(**data))
-        elif (
-            llm_router is not None
-            and llm_router.model_group_alias is not None
-            and data["model"] in llm_router.model_group_alias
-        ):  # model set in model_group_alias
-            llm_response = asyncio.create_task(llm_router.aadapter_completion(**data))
-        elif (
-            llm_router is not None and data["model"] in llm_router.deployment_names
-        ):  # model in router deployments, calling a specific deployment on the router
-            llm_response = asyncio.create_task(
-                llm_router.aadapter_completion(**data, specific_deployment=True)
-            )
-        elif (
-            llm_router is not None and data["model"] in llm_router.get_model_ids()
-        ):  # model in router model list
-            llm_response = asyncio.create_task(llm_router.aadapter_completion(**data))
-        elif (
-            llm_router is not None
-            and data["model"] not in router_model_names
-            and (
-                llm_router.default_deployment is not None
-                or len(llm_router.pattern_router.patterns) > 0
-            )
-        ):  # model in router deployments, calling a specific deployment on the router
-            llm_response = asyncio.create_task(llm_router.aadapter_completion(**data))
-        elif user_model is not None:  # `litellm --model <your-model-name>`
-            llm_response = asyncio.create_task(litellm.aadapter_completion(**data))
-        else:
-            raise HTTPException(
-                status_code=status.HTTP_400_BAD_REQUEST,
-                detail={
-                    "error": "completion: Invalid model name passed in model="
-                    + data.get("model", "")
-                },
-            )
-
-        # Await the llm_response task
-        response = await llm_response
-
-        hidden_params = getattr(response, "_hidden_params", {}) or {}
-        model_id = hidden_params.get("model_id", None) or ""
-        cache_key = hidden_params.get("cache_key", None) or ""
-        api_base = hidden_params.get("api_base", None) or ""
-        response_cost = hidden_params.get("response_cost", None) or ""
-
-        ### ALERTING ###
-        asyncio.create_task(
-            proxy_logging_obj.update_request_status(
-                litellm_call_id=data.get("litellm_call_id", ""), status="success"
-            )
-        )
-
-        verbose_proxy_logger.debug("final response: %s", response)
-
-        fastapi_response.headers.update(
-            get_custom_headers(
-                user_api_key_dict=user_api_key_dict,
-                model_id=model_id,
-                cache_key=cache_key,
-                api_base=api_base,
-                version=version,
-                response_cost=response_cost,
-                request_data=data,
-                hidden_params=hidden_params,
-            )
-        )
-
-        if (
-            "stream" in data and data["stream"] is True
-        ):  # use generate_responses to stream responses
-            selected_data_generator = async_data_generator_anthropic(
-                response=response,
-                user_api_key_dict=user_api_key_dict,
-                request_data=data,
-            )
-            return StreamingResponse(
-                selected_data_generator,
-                media_type="text/event-stream",
-            )
-
-        verbose_proxy_logger.info("\nResponse from Litellm:\n{}".format(response))
-        return response
-    except RejectedRequestError as e:
-        _data = e.request_data
-        await proxy_logging_obj.post_call_failure_hook(
-            user_api_key_dict=user_api_key_dict,
-            original_exception=e,
-            request_data=_data,
-        )
-        if _data.get("stream", None) is not None and _data["stream"] is True:
-            _chat_response = litellm.ModelResponse()
-            _usage = litellm.Usage(
-                prompt_tokens=0,
-                completion_tokens=0,
-                total_tokens=0,
-            )
-            _chat_response.usage = _usage  # type: ignore
-            _chat_response.choices[0].message.content = e.message  # type: ignore
-            _iterator = litellm.utils.ModelResponseIterator(
-                model_response=_chat_response, convert_to_delta=True
-            )
-            _streaming_response = litellm.TextCompletionStreamWrapper(
-                completion_stream=_iterator,
-                model=_data.get("model", ""),
-            )
-
-            selected_data_generator = select_data_generator(
-                response=_streaming_response,
-                user_api_key_dict=user_api_key_dict,
-                request_data=data,
-            )
-
-            return StreamingResponse(
-                selected_data_generator,
-                media_type="text/event-stream",
-                headers={},
-            )
-        else:
-            _response = litellm.TextCompletionResponse()
-            _response.choices[0].text = e.message
-            return _response
-    except Exception as e:
-        await proxy_logging_obj.post_call_failure_hook(
-            user_api_key_dict=user_api_key_dict, original_exception=e, request_data=data
-        )
-        verbose_proxy_logger.exception(
-            "litellm.proxy.proxy_server.anthropic_response(): Exception occured - {}".format(
-                str(e)
-            )
-        )
-        error_msg = f"{str(e)}"
-        raise ProxyException(
-            message=getattr(e, "message", error_msg),
-            type=getattr(e, "type", "None"),
-            param=getattr(e, "param", "None"),
-            code=getattr(e, "status_code", 500),
-        )
-
-
 #### DEV UTILS ####

 # @router.get(
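Note: the deprecation notice in the deleted endpoint points callers at the pass-through route instead. A minimal sketch of that path, mirroring the TestAnthropicPassthroughBasic setup added later in this commit (local test host and virtual key assumed):

import anthropic

# Same SDK call shape, routed through the /anthropic pass-through prefix on the proxy.
client = anthropic.Anthropic(base_url="http://0.0.0.0:4000/anthropic", api_key="sk-1234")

response = client.messages.create(
    model="claude-3-5-sonnet-20241022",
    max_tokens=100,
    messages=[{"role": "user", "content": "Say 'hello test' and nothing else"}],
)
print(response)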
@@ -8840,6 +8571,7 @@ app.include_router(rerank_router)
 app.include_router(fine_tuning_router)
 app.include_router(vertex_router)
 app.include_router(llm_passthrough_router)
+app.include_router(anthropic_router)
 app.include_router(langfuse_router)
 app.include_router(pass_through_router)
 app.include_router(health_router)
@@ -10,6 +10,7 @@ from pydantic import BaseModel

 import litellm
 from litellm._logging import verbose_proxy_logger
+from litellm.litellm_core_utils.core_helpers import get_litellm_metadata_from_kwargs
 from litellm.proxy._types import SpendLogsMetadata, SpendLogsPayload
 from litellm.proxy.utils import PrismaClient, hash_token
 from litellm.types.utils import StandardLoggingPayload
@@ -119,9 +120,7 @@ def get_logging_payload(  # noqa: PLR0915
     response_obj = {}
     # standardize this function to be used across, s3, dynamoDB, langfuse logging
     litellm_params = kwargs.get("litellm_params", {})
-    metadata = (
-        litellm_params.get("metadata", {}) or {}
-    )  # if litellm_params['metadata'] == None
+    metadata = get_litellm_metadata_from_kwargs(kwargs)
     metadata = _add_proxy_server_request_to_metadata(
         metadata=metadata, litellm_params=litellm_params
     )
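Note: get_litellm_metadata_from_kwargs centralizes the inline litellm_params.get("metadata", ...) lookup so spend logs also pick up metadata on the new /v1/messages calls. A rough sketch of what the helper is assumed to do (the real implementation lives in litellm/litellm_core_utils/core_helpers.py and may differ):

def get_litellm_metadata_from_kwargs(kwargs: dict) -> dict:
    # Assumed behavior: prefer litellm_metadata (used by /v1/messages-style requests),
    # otherwise fall back to the classic metadata field on litellm_params.
    litellm_params = kwargs.get("litellm_params", {}) or {}
    litellm_metadata = litellm_params.get("litellm_metadata")
    if litellm_metadata:
        return litellm_metadata
    return litellm_params.get("metadata", {}) or {}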
@@ -580,6 +580,9 @@ class Router:
         self.amoderation = self.factory_function(
             litellm.amoderation, call_type="moderation"
         )
+        self.aanthropic_messages = self.factory_function(
+            litellm.anthropic_messages, call_type="anthropic_messages"
+        )

     def discard(self):
         """
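Note: with aanthropic_messages registered through the factory, the router exposes the Anthropic /v1/messages call shape directly; the new router tests later in this commit call it like this (sketch, ANTHROPIC_API_KEY assumed to be set):

import asyncio
import os

from litellm.router import Router

router = Router(
    model_list=[
        {
            "model_name": "claude-special-alias",
            "litellm_params": {
                "model": "claude-3-haiku-20240307",
                "api_key": os.getenv("ANTHROPIC_API_KEY"),
            },
        }
    ]
)


async def main():
    # Goes through the router, so retries/fallbacks/usage tracking apply.
    response = await router.aanthropic_messages(
        model="claude-special-alias",
        messages=[{"role": "user", "content": "Hello, can you tell me a short joke?"}],
        max_tokens=100,
    )
    print(response)


asyncio.run(main())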
@@ -2349,6 +2352,89 @@ class Router:
                 self.fail_calls[model] += 1
             raise e

+    async def _ageneric_api_call_with_fallbacks(
+        self, model: str, original_function: Callable, **kwargs
+    ):
+        """
+        Make a generic LLM API call through the router, this allows you to use retries/fallbacks with litellm router
+
+        Args:
+            model: The model to use
+            handler_function: The handler function to call (e.g., litellm.anthropic_messages)
+            **kwargs: Additional arguments to pass to the handler function
+
+        Returns:
+            The response from the handler function
+        """
+        handler_name = original_function.__name__
+        try:
+            verbose_router_logger.debug(
+                f"Inside _ageneric_api_call() - handler: {handler_name}, model: {model}; kwargs: {kwargs}"
+            )
+            parent_otel_span = _get_parent_otel_span_from_kwargs(kwargs)
+            deployment = await self.async_get_available_deployment(
+                model=model,
+                request_kwargs=kwargs,
+                messages=kwargs.get("messages", None),
+                specific_deployment=kwargs.pop("specific_deployment", None),
+            )
+            self._update_kwargs_with_deployment(deployment=deployment, kwargs=kwargs)
+
+            data = deployment["litellm_params"].copy()
+            model_name = data["model"]
+
+            model_client = self._get_async_openai_model_client(
+                deployment=deployment,
+                kwargs=kwargs,
+            )
+            self.total_calls[model_name] += 1
+
+            response = original_function(
+                **{
+                    **data,
+                    "caching": self.cache_responses,
+                    "client": model_client,
+                    **kwargs,
+                }
+            )
+
+            rpm_semaphore = self._get_client(
+                deployment=deployment,
+                kwargs=kwargs,
+                client_type="max_parallel_requests",
+            )
+
+            if rpm_semaphore is not None and isinstance(
+                rpm_semaphore, asyncio.Semaphore
+            ):
+                async with rpm_semaphore:
+                    """
+                    - Check rpm limits before making the call
+                    - If allowed, increment the rpm limit (allows global value to be updated, concurrency-safe)
+                    """
+                    await self.async_routing_strategy_pre_call_checks(
+                        deployment=deployment, parent_otel_span=parent_otel_span
+                    )
+                    response = await response  # type: ignore
+            else:
+                await self.async_routing_strategy_pre_call_checks(
+                    deployment=deployment, parent_otel_span=parent_otel_span
+                )
+                response = await response  # type: ignore
+
+            self.success_calls[model_name] += 1
+            verbose_router_logger.info(
+                f"{handler_name}(model={model_name})\033[32m 200 OK\033[0m"
+            )
+            return response
+        except Exception as e:
+            verbose_router_logger.info(
+                f"{handler_name}(model={model})\033[31m Exception {str(e)}\033[0m"
+            )
+            if model is not None:
+                self.fail_calls[model] += 1
+            raise e
+
     def embedding(
         self,
         model: str,
@@ -2869,10 +2955,14 @@ class Router:
     def factory_function(
         self,
         original_function: Callable,
-        call_type: Literal["assistants", "moderation"] = "assistants",
+        call_type: Literal[
+            "assistants", "moderation", "anthropic_messages"
+        ] = "assistants",
     ):
         async def new_function(
-            custom_llm_provider: Optional[Literal["openai", "azure"]] = None,
+            custom_llm_provider: Optional[
+                Literal["openai", "azure", "anthropic"]
+            ] = None,
             client: Optional["AsyncOpenAI"] = None,
             **kwargs,
         ):
@@ -2889,13 +2979,18 @@ class Router:
                     original_function=original_function,
                     **kwargs,
                 )
+            elif call_type == "anthropic_messages":
+                return await self._ageneric_api_call_with_fallbacks(  # type: ignore
+                    original_function=original_function,
+                    **kwargs,
+                )

         return new_function

     async def _pass_through_assistants_endpoint_factory(
         self,
         original_function: Callable,
-        custom_llm_provider: Optional[Literal["openai", "azure"]] = None,
+        custom_llm_provider: Optional[Literal["openai", "azure", "anthropic"]] = None,
         client: Optional[AsyncOpenAI] = None,
         **kwargs,
     ):
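Note: for call_type="anthropic_messages" the factory-produced function hands the wrapped litellm.anthropic_messages handler to _ageneric_api_call_with_fallbacks. The handler can also be called directly, as the new unit tests in this commit do; a minimal sketch (ANTHROPIC_API_KEY assumed to be set):

import asyncio
import os

from litellm.llms.anthropic.experimental_pass_through.messages.handler import (
    anthropic_messages,
)


async def main():
    # Direct call to the /v1/messages handler, bypassing the router.
    response = await anthropic_messages(
        messages=[{"role": "user", "content": "Hello, can you tell me a short joke?"}],
        api_key=os.getenv("ANTHROPIC_API_KEY"),
        model="claude-3-haiku-20240307",
        max_tokens=100,
    )
    print(response)


asyncio.run(main())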
@@ -186,6 +186,7 @@ class CallTypes(Enum):
     aretrieve_batch = "aretrieve_batch"
     retrieve_batch = "retrieve_batch"
     pass_through = "pass_through_endpoint"
+    anthropic_messages = "anthropic_messages"


 CallTypesLiteral = Literal[
@@ -209,6 +210,7 @@ CallTypesLiteral = Literal[
     "create_batch",
     "acreate_batch",
     "pass_through_endpoint",
+    "anthropic_messages",
 ]

@@ -191,6 +191,9 @@ from typing import (
 from openai import OpenAIError as OriginalError

 from litellm.litellm_core_utils.thread_pool_executor import executor
+from litellm.llms.base_llm.anthropic_messages.transformation import (
+    BaseAnthropicMessagesConfig,
+)
 from litellm.llms.base_llm.audio_transcription.transformation import (
     BaseAudioTranscriptionConfig,
 )
@@ -6245,6 +6248,15 @@ class ProviderConfigManager:
             return litellm.JinaAIRerankConfig()
         return litellm.CohereRerankConfig()

+    @staticmethod
+    def get_provider_anthropic_messages_config(
+        model: str,
+        provider: LlmProviders,
+    ) -> Optional[BaseAnthropicMessagesConfig]:
+        if litellm.LlmProviders.ANTHROPIC == provider:
+            return litellm.AnthropicMessagesConfig()
+        return None
+
     @staticmethod
     def get_provider_audio_transcription_config(
         model: str,
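Note: the new lookup mirrors the other get_provider_*_config helpers on ProviderConfigManager; a short usage sketch (the litellm.utils import path is assumed from the hunk above):

import litellm
from litellm.utils import ProviderConfigManager

config = ProviderConfigManager.get_provider_anthropic_messages_config(
    model="claude-3-haiku-20240307",
    provider=litellm.LlmProviders.ANTHROPIC,
)
# Anthropic is the only provider wired up here, so any other provider returns None.
print(config)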
@@ -329,57 +329,3 @@ async def test_aaapass_through_endpoint_pass_through_keys_langfuse(
     setattr(
         litellm.proxy.proxy_server, "proxy_logging_obj", original_proxy_logging_obj
     )
-
-
-@pytest.mark.asyncio
-async def test_pass_through_endpoint_anthropic(client):
-    import litellm
-    from litellm import Router
-    from litellm.adapters.anthropic_adapter import anthropic_adapter
-
-    router = Router(
-        model_list=[
-            {
-                "model_name": "gpt-3.5-turbo",
-                "litellm_params": {
-                    "model": "gpt-3.5-turbo",
-                    "api_key": os.getenv("OPENAI_API_KEY"),
-                    "mock_response": "Hey, how's it going?",
-                },
-            }
-        ]
-    )
-
-    setattr(litellm.proxy.proxy_server, "llm_router", router)
-
-    # Define a pass-through endpoint
-    pass_through_endpoints = [
-        {
-            "path": "/v1/test-messages",
-            "target": anthropic_adapter,
-            "headers": {"litellm_user_api_key": "my-test-header"},
-        }
-    ]
-
-    # Initialize the pass-through endpoint
-    await initialize_pass_through_endpoints(pass_through_endpoints)
-    general_settings: Optional[dict] = (
-        getattr(litellm.proxy.proxy_server, "general_settings", {}) or {}
-    )
-    general_settings.update({"pass_through_endpoints": pass_through_endpoints})
-    setattr(litellm.proxy.proxy_server, "general_settings", general_settings)
-
-    _json_data = {
-        "model": "gpt-3.5-turbo",
-        "messages": [{"role": "user", "content": "Who are you?"}],
-    }
-
-    # Make a request to the pass-through endpoint
-    response = client.post(
-        "/v1/test-messages", json=_json_data, headers={"my-test-header": "my-test-key"}
-    )
-
-    print("JSON response: ", _json_data)
-
-    # Assert the response
-    assert response.status_code == 200
tests/pass_through_tests/base_anthropic_messages_test.py (new file, +145 lines)
@@ -0,0 +1,145 @@
+from abc import ABC, abstractmethod
+
+import anthropic
+import pytest
+
+
+class BaseAnthropicMessagesTest(ABC):
+    """
+    Abstract base test class that enforces a common test across all test classes.
+    """
+
+    @abstractmethod
+    def get_client(self):
+        return anthropic.Anthropic()
+
+    def test_anthropic_basic_completion(self):
+        print("making basic completion request to anthropic passthrough")
+        client = self.get_client()
+        response = client.messages.create(
+            model="claude-3-5-sonnet-20241022",
+            max_tokens=1024,
+            messages=[{"role": "user", "content": "Say 'hello test' and nothing else"}],
+            extra_body={
+                "litellm_metadata": {
+                    "tags": ["test-tag-1", "test-tag-2"],
+                }
+            },
+        )
+        print(response)
+
+    def test_anthropic_streaming(self):
+        print("making streaming request to anthropic passthrough")
+        collected_output = []
+        client = self.get_client()
+        with client.messages.stream(
+            max_tokens=10,
+            messages=[
+                {"role": "user", "content": "Say 'hello stream test' and nothing else"}
+            ],
+            model="claude-3-5-sonnet-20241022",
+            extra_body={
+                "litellm_metadata": {
+                    "tags": ["test-tag-stream-1", "test-tag-stream-2"],
+                }
+            },
+        ) as stream:
+            for text in stream.text_stream:
+                collected_output.append(text)
+
+        full_response = "".join(collected_output)
+        print(full_response)
+
+    def test_anthropic_messages_with_thinking(self):
+        print("making request to anthropic passthrough with thinking")
+        client = self.get_client()
+        response = client.messages.create(
+            model="claude-3-7-sonnet-20250219",
+            max_tokens=20000,
+            thinking={"type": "enabled", "budget_tokens": 16000},
+            messages=[
+                {"role": "user", "content": "Just pinging with thinking enabled"}
+            ],
+        )
+
+        print(response)
+
+        # Verify the first content block is a thinking block
+        response_thinking = response.content[0].thinking
+        assert response_thinking is not None
+        assert len(response_thinking) > 0
+
+    def test_anthropic_streaming_with_thinking(self):
+        print("making streaming request to anthropic passthrough with thinking enabled")
+        collected_thinking = []
+        collected_response = []
+        client = self.get_client()
+        with client.messages.stream(
+            model="claude-3-7-sonnet-20250219",
+            max_tokens=20000,
+            thinking={"type": "enabled", "budget_tokens": 16000},
+            messages=[
+                {"role": "user", "content": "Just pinging with thinking enabled"}
+            ],
+        ) as stream:
+            for event in stream:
+                if event.type == "content_block_delta":
+                    if event.delta.type == "thinking_delta":
+                        collected_thinking.append(event.delta.thinking)
+                    elif event.delta.type == "text_delta":
+                        collected_response.append(event.delta.text)
+
+        full_thinking = "".join(collected_thinking)
+        full_response = "".join(collected_response)
+
+        print(
+            f"Thinking Response: {full_thinking[:100]}..."
+        )  # Print first 100 chars of thinking
+        print(f"Response: {full_response}")
+
+        # Verify we received thinking content
+        assert len(collected_thinking) > 0
+        assert len(full_thinking) > 0
+
+        # Verify we also received a response
+        assert len(collected_response) > 0
+        assert len(full_response) > 0
+
+    def test_bad_request_error_handling_streaming(self):
+        print("making request to anthropic passthrough with bad request")
+        try:
+            client = self.get_client()
+            response = client.messages.create(
+                model="claude-3-5-sonnet-20241022",
+                max_tokens=10,
+                stream=True,
+                messages=["hi"],
+            )
+            print(response)
+            assert pytest.fail("Expected BadRequestError")
+        except anthropic.BadRequestError as e:
+            print("Got BadRequestError from anthropic, e=", e)
+            print(e.__cause__)
+            print(e.status_code)
+            print(e.response)
+        except Exception as e:
+            pytest.fail(f"Got unexpected exception: {e}")
+
+    def test_bad_request_error_handling_non_streaming(self):
+        print("making request to anthropic passthrough with bad request")
+        try:
+            client = self.get_client()
+            response = client.messages.create(
+                model="claude-3-5-sonnet-20241022",
+                max_tokens=10,
+                messages=["hi"],
+            )
+            print(response)
+            assert pytest.fail("Expected BadRequestError")
+        except anthropic.BadRequestError as e:
+            print("Got BadRequestError from anthropic, e=", e)
+            print(e.__cause__)
+            print(e.status_code)
+            print(e.response)
+        except Exception as e:
+            pytest.fail(f"Got unexpected exception: {e}")
@@ -8,48 +8,6 @@ import aiohttp
 import asyncio
 import json

-client = anthropic.Anthropic(
-    base_url="http://0.0.0.0:4000/anthropic", api_key="sk-1234"
-)
-
-
-def test_anthropic_basic_completion():
-    print("making basic completion request to anthropic passthrough")
-    response = client.messages.create(
-        model="claude-3-5-sonnet-20241022",
-        max_tokens=1024,
-        messages=[{"role": "user", "content": "Say 'hello test' and nothing else"}],
-        extra_body={
-            "litellm_metadata": {
-                "tags": ["test-tag-1", "test-tag-2"],
-            }
-        },
-    )
-    print(response)
-
-
-def test_anthropic_streaming():
-    print("making streaming request to anthropic passthrough")
-    collected_output = []
-
-    with client.messages.stream(
-        max_tokens=10,
-        messages=[
-            {"role": "user", "content": "Say 'hello stream test' and nothing else"}
-        ],
-        model="claude-3-5-sonnet-20241022",
-        extra_body={
-            "litellm_metadata": {
-                "tags": ["test-tag-stream-1", "test-tag-stream-2"],
-            }
-        },
-    ) as stream:
-        for text in stream.text_stream:
-            collected_output.append(text)
-
-    full_response = "".join(collected_output)
-    print(full_response)
-
-
 @pytest.mark.asyncio
 async def test_anthropic_basic_completion_with_headers():
tests/pass_through_tests/test_anthropic_passthrough_basic.py (new file, +28 lines)
@@ -0,0 +1,28 @@
+from base_anthropic_messages_test import BaseAnthropicMessagesTest
+import anthropic
+
+
+class TestAnthropicPassthroughBasic(BaseAnthropicMessagesTest):
+
+    def get_client(self):
+        return anthropic.Anthropic(
+            base_url="http://0.0.0.0:4000/anthropic",
+            api_key="sk-1234",
+        )
+
+
+class TestAnthropicMessagesEndpoint(BaseAnthropicMessagesTest):
+    def get_client(self):
+        return anthropic.Anthropic(
+            base_url="http://0.0.0.0:4000",
+            api_key="sk-1234",
+        )
+
+    def test_anthropic_messages_to_wildcard_model(self):
+        client = self.get_client()
+        response = client.messages.create(
+            model="anthropic/claude-3-opus-20240229",
+            messages=[{"role": "user", "content": "Hello, world!"}],
+            max_tokens=100,
+        )
+        print(response)
@@ -0,0 +1,487 @@
+import json
+import os
+import sys
+from datetime import datetime
+from typing import AsyncIterator, Dict, Any
+import asyncio
+import unittest.mock
+from unittest.mock import AsyncMock, MagicMock
+
+sys.path.insert(
+    0, os.path.abspath("../..")
+)  # Adds the parent directory to the system path
+import litellm
+import pytest
+from dotenv import load_dotenv
+from litellm.llms.anthropic.experimental_pass_through.messages.handler import (
+    anthropic_messages,
+)
+from typing import Optional
+from litellm.types.utils import StandardLoggingPayload
+from litellm.integrations.custom_logger import CustomLogger
+from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler
+from litellm.router import Router
+import importlib
+
+# Load environment variables
+load_dotenv()
+
+
+@pytest.fixture(scope="session")
+def event_loop():
+    """Create an instance of the default event loop for each test session."""
+    loop = asyncio.get_event_loop_policy().new_event_loop()
+    yield loop
+    loop.close()
+
+
+@pytest.fixture(scope="function", autouse=True)
+def setup_and_teardown(event_loop):  # Add event_loop as a dependency
+    curr_dir = os.getcwd()
+    sys.path.insert(0, os.path.abspath("../.."))
+
+    import litellm
+    from litellm import Router
+
+    importlib.reload(litellm)
+
+    # Set the event loop from the fixture
+    asyncio.set_event_loop(event_loop)
+
+    print(litellm)
+    yield
+
+    # Clean up any pending tasks
+    pending = asyncio.all_tasks(event_loop)
+    for task in pending:
+        task.cancel()
+
+    # Run the event loop until all tasks are cancelled
+    if pending:
+        event_loop.run_until_complete(asyncio.gather(*pending, return_exceptions=True))
+
+
+def _validate_anthropic_response(response: Dict[str, Any]):
+    assert "id" in response
+    assert "content" in response
+    assert "model" in response
+    assert response["role"] == "assistant"
+
+
+@pytest.mark.asyncio
+async def test_anthropic_messages_non_streaming():
+    """
+    Test the anthropic_messages with non-streaming request
+    """
+    # Get API key from environment
+    api_key = os.getenv("ANTHROPIC_API_KEY")
+    if not api_key:
+        pytest.skip("ANTHROPIC_API_KEY not found in environment")
+
+    # Set up test parameters
+    messages = [{"role": "user", "content": "Hello, can you tell me a short joke?"}]
+
+    # Call the handler
+    response = await anthropic_messages(
+        messages=messages,
+        api_key=api_key,
+        model="claude-3-haiku-20240307",
+        max_tokens=100,
+    )
+
+    # Verify response
+    assert "id" in response
+    assert "content" in response
+    assert "model" in response
+    assert response["role"] == "assistant"
+
+    print(f"Non-streaming response: {json.dumps(response, indent=2)}")
+    return response
+
+
+@pytest.mark.asyncio
+async def test_anthropic_messages_streaming():
+    """
+    Test the anthropic_messages with streaming request
+    """
+    # Get API key from environment
+    api_key = os.getenv("ANTHROPIC_API_KEY")
+    if not api_key:
+        pytest.skip("ANTHROPIC_API_KEY not found in environment")
+
+    # Set up test parameters
+    messages = [{"role": "user", "content": "Hello, can you tell me a short joke?"}]
+
+    # Call the handler
+    async_httpx_client = AsyncHTTPHandler()
+    response = await anthropic_messages(
+        messages=messages,
+        api_key=api_key,
+        model="claude-3-haiku-20240307",
+        max_tokens=100,
+        stream=True,
+        client=async_httpx_client,
+    )
+
+    if isinstance(response, AsyncIterator):
+        async for chunk in response:
+            print("chunk=", chunk)
+
+
+@pytest.mark.asyncio
+async def test_anthropic_messages_streaming_with_bad_request():
+    """
+    Test the anthropic_messages with streaming request
+    """
+    try:
+        response = await anthropic_messages(
+            messages=["hi"],
+            api_key=os.getenv("ANTHROPIC_API_KEY"),
+            model="claude-3-haiku-20240307",
+            max_tokens=100,
+            stream=True,
+        )
+        print(response)
+        async for chunk in response:
+            print("chunk=", chunk)
+    except Exception as e:
+        print("got exception", e)
+        print("vars", vars(e))
+        assert e.status_code == 400
+
+
+@pytest.mark.asyncio
+async def test_anthropic_messages_router_streaming_with_bad_request():
+    """
+    Test the anthropic_messages with streaming request
+    """
+    try:
+        router = Router(
+            model_list=[
+                {
+                    "model_name": "claude-special-alias",
+                    "litellm_params": {
+                        "model": "claude-3-haiku-20240307",
+                        "api_key": os.getenv("ANTHROPIC_API_KEY"),
+                    },
+                }
+            ]
+        )
+
+        response = await router.aanthropic_messages(
+            messages=["hi"],
+            model="claude-special-alias",
+            max_tokens=100,
+            stream=True,
+        )
+        print(response)
+        async for chunk in response:
+            print("chunk=", chunk)
+    except Exception as e:
+        print("got exception", e)
+        print("vars", vars(e))
+        assert e.status_code == 400
+
+
+@pytest.mark.asyncio
+async def test_anthropic_messages_litellm_router_non_streaming():
+    """
+    Test the anthropic_messages with non-streaming request
+    """
+    litellm._turn_on_debug()
+    router = Router(
+        model_list=[
+            {
+                "model_name": "claude-special-alias",
+                "litellm_params": {
+                    "model": "claude-3-haiku-20240307",
+                    "api_key": os.getenv("ANTHROPIC_API_KEY"),
+                },
+            }
+        ]
+    )
+
+    # Set up test parameters
+    messages = [{"role": "user", "content": "Hello, can you tell me a short joke?"}]
+
+    # Call the handler
+    response = await router.aanthropic_messages(
+        messages=messages,
+        model="claude-special-alias",
+        max_tokens=100,
+    )
+
+    # Verify response
+    assert "id" in response
+    assert "content" in response
+    assert "model" in response
+    assert response["role"] == "assistant"
+
+    print(f"Non-streaming response: {json.dumps(response, indent=2)}")
+    return response
+
+
+class TestCustomLogger(CustomLogger):
+    def __init__(self):
+        super().__init__()
+        self.logged_standard_logging_payload: Optional[StandardLoggingPayload] = None
+
+    async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
+        print("inside async_log_success_event")
+        self.logged_standard_logging_payload = kwargs.get("standard_logging_object")
+
+        pass
+
+
+@pytest.mark.asyncio
+async def test_anthropic_messages_litellm_router_non_streaming_with_logging():
+    """
+    Test the anthropic_messages with non-streaming request
+
+    - Ensure Cost + Usage is tracked
+    """
+    test_custom_logger = TestCustomLogger()
+    litellm.callbacks = [test_custom_logger]
+    litellm._turn_on_debug()
+    router = Router(
+        model_list=[
+            {
+                "model_name": "claude-special-alias",
+                "litellm_params": {
+                    "model": "claude-3-haiku-20240307",
+                    "api_key": os.getenv("ANTHROPIC_API_KEY"),
+                },
+            }
+        ]
+    )
+
+    # Set up test parameters
+    messages = [{"role": "user", "content": "Hello, can you tell me a short joke?"}]
+
+    # Call the handler
+    response = await router.aanthropic_messages(
+        messages=messages,
+        model="claude-special-alias",
+        max_tokens=100,
+    )
+
+    # Verify response
+    _validate_anthropic_response(response)
+
+    print(f"Non-streaming response: {json.dumps(response, indent=2)}")
+
+    await asyncio.sleep(1)
+    assert test_custom_logger.logged_standard_logging_payload["messages"] == messages
+    assert test_custom_logger.logged_standard_logging_payload["response"] is not None
+    assert (
+        test_custom_logger.logged_standard_logging_payload["model"]
+        == "claude-3-haiku-20240307"
+    )
+
+    # check logged usage + spend
+    assert test_custom_logger.logged_standard_logging_payload["response_cost"] > 0
+    assert (
+        test_custom_logger.logged_standard_logging_payload["prompt_tokens"]
+        == response["usage"]["input_tokens"]
+    )
+    assert (
+        test_custom_logger.logged_standard_logging_payload["completion_tokens"]
+        == response["usage"]["output_tokens"]
+    )
+
+
+@pytest.mark.asyncio
+async def test_anthropic_messages_litellm_router_streaming_with_logging():
+    """
+    Test the anthropic_messages with streaming request
+
+    - Ensure Cost + Usage is tracked
+    """
+    test_custom_logger = TestCustomLogger()
+    litellm.callbacks = [test_custom_logger]
+    # litellm._turn_on_debug()
+    router = Router(
+        model_list=[
+            {
+                "model_name": "claude-special-alias",
+                "litellm_params": {
+                    "model": "claude-3-haiku-20240307",
+                    "api_key": os.getenv("ANTHROPIC_API_KEY"),
+                },
+            }
+        ]
+    )
+
+    # Set up test parameters
+    messages = [{"role": "user", "content": "Hello, can you tell me a short joke?"}]
+
+    # Call the handler
+    response = await router.aanthropic_messages(
+        messages=messages,
+        model="claude-special-alias",
+        max_tokens=100,
+        stream=True,
+    )
+
+    response_prompt_tokens = 0
+    response_completion_tokens = 0
+    all_anthropic_usage_chunks = []
+
+    async for chunk in response:
+        # Decode chunk if it's bytes
+        print("chunk=", chunk)
+
+        # Handle SSE format chunks
+        if isinstance(chunk, bytes):
+            chunk_str = chunk.decode("utf-8")
+            # Extract the JSON data part from SSE format
+            for line in chunk_str.split("\n"):
+                if line.startswith("data: "):
+                    try:
+                        json_data = json.loads(line[6:])  # Skip the 'data: ' prefix
+                        print(
+                            "\n\nJSON data:",
+                            json.dumps(json_data, indent=4, default=str),
+                        )
+
+                        # Extract usage information
+                        if (
+                            json_data.get("type") == "message_start"
+                            and "message" in json_data
+                        ):
+                            if "usage" in json_data["message"]:
+                                usage = json_data["message"]["usage"]
+                                all_anthropic_usage_chunks.append(usage)
+                                print(
+                                    "USAGE BLOCK",
+                                    json.dumps(usage, indent=4, default=str),
+                                )
+                        elif "usage" in json_data:
+                            usage = json_data["usage"]
+                            all_anthropic_usage_chunks.append(usage)
+                            print(
+                                "USAGE BLOCK", json.dumps(usage, indent=4, default=str)
+                            )
+                    except json.JSONDecodeError:
+                        print(f"Failed to parse JSON from: {line[6:]}")
+        elif hasattr(chunk, "message"):
+            if chunk.message.usage:
+                print(
+                    "USAGE BLOCK",
+                    json.dumps(chunk.message.usage, indent=4, default=str),
+                )
+                all_anthropic_usage_chunks.append(chunk.message.usage)
+        elif hasattr(chunk, "usage"):
+            print("USAGE BLOCK", json.dumps(chunk.usage, indent=4, default=str))
+            all_anthropic_usage_chunks.append(chunk.usage)
+
+    print(
+        "all_anthropic_usage_chunks",
+        json.dumps(all_anthropic_usage_chunks, indent=4, default=str),
+    )
+
+    # Extract token counts from usage data
+    if all_anthropic_usage_chunks:
+        response_prompt_tokens = max(
+            [usage.get("input_tokens", 0) for usage in all_anthropic_usage_chunks]
+        )
+        response_completion_tokens = max(
+            [usage.get("output_tokens", 0) for usage in all_anthropic_usage_chunks]
+        )
+
+    print("input_tokens_anthropic_api", response_prompt_tokens)
+    print("output_tokens_anthropic_api", response_completion_tokens)
+
+    await asyncio.sleep(4)
+
+    print(
+        "logged_standard_logging_payload",
+        json.dumps(
+            test_custom_logger.logged_standard_logging_payload, indent=4, default=str
+        ),
+    )
+
+    assert test_custom_logger.logged_standard_logging_payload["messages"] == messages
+    assert test_custom_logger.logged_standard_logging_payload["response"] is not None
+    assert (
+        test_custom_logger.logged_standard_logging_payload["model"]
+        == "claude-3-haiku-20240307"
+    )
+
+    # check logged usage + spend
+    assert test_custom_logger.logged_standard_logging_payload["response_cost"] > 0
+    assert (
+        test_custom_logger.logged_standard_logging_payload["prompt_tokens"]
+        == response_prompt_tokens
+    )
+    assert (
+        test_custom_logger.logged_standard_logging_payload["completion_tokens"]
+        == response_completion_tokens
+    )
+
+
+@pytest.mark.asyncio
+async def test_anthropic_messages_with_extra_headers():
+    """
+    Test the anthropic_messages with extra headers
+    """
+    # Get API key from environment
+    api_key = os.getenv("ANTHROPIC_API_KEY", "fake-api-key")
+
+    # Set up test parameters
+    messages = [{"role": "user", "content": "Hello, can you tell me a short joke?"}]
+    extra_headers = {
+        "anthropic-beta": "very-custom-beta-value",
+        "anthropic-version": "custom-version-for-test",
+    }
+
+    # Create a mock response
+    mock_response = MagicMock()
+    mock_response.raise_for_status = MagicMock()
+    mock_response.json.return_value = {
+        "id": "msg_123456",
+        "type": "message",
+        "role": "assistant",
+        "content": [
+            {
+                "type": "text",
+                "text": "Why did the chicken cross the road? To get to the other side!",
+            }
+        ],
+        "model": "claude-3-haiku-20240307",
+        "stop_reason": "end_turn",
+        "usage": {"input_tokens": 10, "output_tokens": 20},
+    }
+
+    # Create a mock client with AsyncMock for the post method
+    mock_client = MagicMock(spec=AsyncHTTPHandler)
+    mock_client.post = AsyncMock(return_value=mock_response)
+
+    # Call the handler with extra_headers and our mocked client
+    response = await anthropic_messages(
+        messages=messages,
+        api_key=api_key,
+        model="claude-3-haiku-20240307",
+        max_tokens=100,
+        client=mock_client,
+        provider_specific_header={
+            "custom_llm_provider": "anthropic",
+            "extra_headers": extra_headers,
+        },
+    )
+
+    # Verify the post method was called with the right parameters
+    mock_client.post.assert_called_once()
+    call_kwargs = mock_client.post.call_args.kwargs
+
+    # Verify headers were passed correctly
+    headers = call_kwargs.get("headers", {})
+    print("HEADERS IN REQUEST", headers)
+    for key, value in extra_headers.items():
+        assert key in headers
+        assert headers[key] == value
+
+    # Verify the response was processed correctly
+    assert response == mock_response.json.return_value
+
+    return response
@@ -54,7 +54,7 @@ async def test_get_litellm_virtual_key():


 @pytest.mark.asyncio
-async def test_vertex_proxy_route_api_key_auth():
+async def test_async_vertex_proxy_route_api_key_auth():
     """
     Critical

@@ -207,7 +207,7 @@ async def test_get_vertex_credentials_stored():
     router.add_vertex_credentials(
         project_id="test-project",
         location="us-central1",
-        vertex_credentials="test-creds",
+        vertex_credentials='{"credentials": "test-creds"}',
     )

     creds = router.get_vertex_credentials(
@@ -215,7 +215,7 @@ async def test_get_vertex_credentials_stored():
     )
     assert creds.vertex_project == "test-project"
     assert creds.vertex_location == "us-central1"
-    assert creds.vertex_credentials == "test-creds"
+    assert creds.vertex_credentials == '{"credentials": "test-creds"}'


 @pytest.mark.asyncio
@@ -227,18 +227,20 @@ async def test_add_vertex_credentials():
     router.add_vertex_credentials(
         project_id="test-project",
         location="us-central1",
-        vertex_credentials="test-creds",
+        vertex_credentials='{"credentials": "test-creds"}',
     )

     assert "test-project-us-central1" in router.deployment_key_to_vertex_credentials
     creds = router.deployment_key_to_vertex_credentials["test-project-us-central1"]
     assert creds.vertex_project == "test-project"
     assert creds.vertex_location == "us-central1"
-    assert creds.vertex_credentials == "test-creds"
+    assert creds.vertex_credentials == '{"credentials": "test-creds"}'

     # Test adding with None values
     router.add_vertex_credentials(
-        project_id=None, location=None, vertex_credentials="test-creds"
+        project_id=None,
+        location=None,
+        vertex_credentials='{"credentials": "test-creds"}',
     )
     # Should not add None values
     assert len(router.deployment_key_to_vertex_credentials) == 1
@@ -6,6 +6,7 @@ from typing import Optional
 from dotenv import load_dotenv
 from fastapi import Request
 from datetime import datetime
+from unittest.mock import AsyncMock, patch

 sys.path.insert(
     0, os.path.abspath("../..")
@@ -289,43 +290,6 @@ async def test_aaaaatext_completion_endpoint(model_list, sync_mode):
     assert response.choices[0].text == "I'm fine, thank you!"


-@pytest.mark.asyncio
-async def test_anthropic_router_completion_e2e(model_list):
-    from litellm.adapters.anthropic_adapter import anthropic_adapter
-    from litellm.types.llms.anthropic import AnthropicResponse
-
-    litellm.set_verbose = True
-
-    litellm.adapters = [{"id": "anthropic", "adapter": anthropic_adapter}]
-
-    router = Router(model_list=model_list)
-    messages = [{"role": "user", "content": "Hey, how's it going?"}]
-
-    ## Test 1: user facing function
-    response = await router.aadapter_completion(
-        model="claude-3-5-sonnet-20240620",
-        messages=messages,
-        adapter_id="anthropic",
-        mock_response="This is a fake call",
-    )
-
-    ## Test 2: underlying function
-    await router._aadapter_completion(
-        model="claude-3-5-sonnet-20240620",
-        messages=messages,
-        adapter_id="anthropic",
-        mock_response="This is a fake call",
-    )
-
-    print("Response: {}".format(response))
-
-    assert response is not None
-
-    AnthropicResponse.model_validate(response)
-
-    assert response.model == "gpt-3.5-turbo"
-
-
 @pytest.mark.asyncio
 async def test_router_with_empty_choices(model_list):
     """
@ -349,3 +313,200 @@ async def test_router_with_empty_choices(model_list):
|
||||||
mock_response=mock_response,
|
mock_response=mock_response,
|
||||||
)
|
)
|
||||||
assert response is not None
|
assert response is not None
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_ageneric_api_call_with_fallbacks_basic():
|
||||||
|
"""
|
||||||
|
Test the _ageneric_api_call_with_fallbacks method with a basic successful call
|
||||||
|
"""
|
||||||
|
# Create a mock function that will be passed to _ageneric_api_call_with_fallbacks
|
||||||
|
mock_function = AsyncMock()
|
||||||
|
mock_function.__name__ = "test_function"
|
||||||
|
|
||||||
|
# Create a mock response
|
||||||
|
mock_response = {
|
||||||
|
"id": "resp_123456",
|
||||||
|
"role": "assistant",
|
||||||
|
"content": "This is a test response",
|
||||||
|
"model": "test-model",
|
||||||
|
"usage": {"input_tokens": 10, "output_tokens": 20},
|
||||||
|
}
|
||||||
|
mock_function.return_value = mock_response
|
||||||
|
|
||||||
|
# Create a router with a test model
|
||||||
|
router = Router(
|
||||||
|
model_list=[
|
||||||
|
{
|
||||||
|
"model_name": "test-model-alias",
|
||||||
|
"litellm_params": {
|
||||||
|
"model": "anthropic/test-model",
|
||||||
|
"api_key": "fake-api-key",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
# Call the _ageneric_api_call_with_fallbacks method
|
||||||
|
response = await router._ageneric_api_call_with_fallbacks(
|
||||||
|
model="test-model-alias",
|
||||||
|
original_function=mock_function,
|
||||||
|
messages=[{"role": "user", "content": "Hello"}],
|
||||||
|
max_tokens=100,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Verify the mock function was called
|
||||||
|
mock_function.assert_called_once()
|
||||||
|
|
||||||
|
# Verify the response
|
||||||
|
assert response == mock_response
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_aadapter_completion():
    """
    Test the aadapter_completion method which uses async_function_with_fallbacks
    """
    # Create a mock for the _aadapter_completion method
    mock_response = {
        "id": "adapter_resp_123",
        "object": "adapter.completion",
        "created": 1677858242,
        "model": "test-model-with-adapter",
        "choices": [
            {
                "text": "This is a test adapter response",
                "index": 0,
                "finish_reason": "stop",
            }
        ],
        "usage": {"prompt_tokens": 10, "completion_tokens": 20, "total_tokens": 30},
    }

    # Create a router with a patched _aadapter_completion method
    with patch.object(
        Router, "_aadapter_completion", new_callable=AsyncMock
    ) as mock_method:
        mock_method.return_value = mock_response

        router = Router(
            model_list=[
                {
                    "model_name": "test-adapter-model",
                    "litellm_params": {
                        "model": "anthropic/test-model",
                        "api_key": "fake-api-key",
                    },
                }
            ]
        )

        # Replace the async_function_with_fallbacks with a mock
        router.async_function_with_fallbacks = AsyncMock(return_value=mock_response)

        # Call the aadapter_completion method
        response = await router.aadapter_completion(
            adapter_id="test-adapter-id",
            model="test-adapter-model",
            prompt="This is a test prompt",
            max_tokens=100,
        )

        # Verify the response
        assert response == mock_response

        # Verify async_function_with_fallbacks was called with the right parameters
        router.async_function_with_fallbacks.assert_called_once()
        call_kwargs = router.async_function_with_fallbacks.call_args.kwargs
        assert call_kwargs["adapter_id"] == "test-adapter-id"
        assert call_kwargs["model"] == "test-adapter-model"
        assert call_kwargs["prompt"] == "This is a test prompt"
        assert call_kwargs["max_tokens"] == 100
        assert call_kwargs["original_function"] == router._aadapter_completion
        assert "metadata" in call_kwargs
        assert call_kwargs["metadata"]["model_group"] == "test-adapter-model"

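The keyword assertions above outline the shape of the public wrapper: `aadapter_completion` does no provider work itself and simply forwards its arguments, plus `original_function=self._aadapter_completion` and a `metadata["model_group"]` entry, into `async_function_with_fallbacks`. A rough sketch of that shape, inferred from the test rather than copied from the Router source:

# Inferred shape of the public wrapper (illustrative only, not the actual
# Router implementation): everything is handed to async_function_with_fallbacks,
# which applies the configured retry/fallback policy before invoking the
# underlying _aadapter_completion.
async def aadapter_completion_shape(router, adapter_id: str, model: str, **kwargs):
    return await router.async_function_with_fallbacks(
        adapter_id=adapter_id,
        model=model,
        original_function=router._aadapter_completion,
        metadata={"model_group": model},
        **kwargs,
    )
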
@pytest.mark.asyncio
async def test__aadapter_completion():
    """
    Test the _aadapter_completion method directly
    """
    # Create a mock response for litellm.aadapter_completion
    mock_response = {
        "id": "adapter_resp_123",
        "object": "adapter.completion",
        "created": 1677858242,
        "model": "test-model-with-adapter",
        "choices": [
            {
                "text": "This is a test adapter response",
                "index": 0,
                "finish_reason": "stop",
            }
        ],
        "usage": {"prompt_tokens": 10, "completion_tokens": 20, "total_tokens": 30},
    }

    # Create a router with a mocked litellm.aadapter_completion
    with patch(
        "litellm.aadapter_completion", new_callable=AsyncMock
    ) as mock_adapter_completion:
        mock_adapter_completion.return_value = mock_response

        router = Router(
            model_list=[
                {
                    "model_name": "test-adapter-model",
                    "litellm_params": {
                        "model": "anthropic/test-model",
                        "api_key": "fake-api-key",
                    },
                }
            ]
        )

        # Mock the async_get_available_deployment method
        router.async_get_available_deployment = AsyncMock(
            return_value={
                "model_name": "test-adapter-model",
                "litellm_params": {
                    "model": "test-model",
                    "api_key": "fake-api-key",
                },
                "model_info": {
                    "id": "test-unique-id",
                },
            }
        )

        # Mock the async_routing_strategy_pre_call_checks method
        router.async_routing_strategy_pre_call_checks = AsyncMock()

        # Call the _aadapter_completion method
        response = await router._aadapter_completion(
            adapter_id="test-adapter-id",
            model="test-adapter-model",
            prompt="This is a test prompt",
            max_tokens=100,
        )

        # Verify the response
        assert response == mock_response

        # Verify litellm.aadapter_completion was called with the right parameters
        mock_adapter_completion.assert_called_once()
        call_kwargs = mock_adapter_completion.call_args.kwargs
        assert call_kwargs["adapter_id"] == "test-adapter-id"
        assert call_kwargs["model"] == "test-model"
        assert call_kwargs["prompt"] == "This is a test prompt"
        assert call_kwargs["max_tokens"] == 100
        assert call_kwargs["api_key"] == "fake-api-key"
        assert call_kwargs["caching"] == router.cache_responses

        # Verify the success call was recorded
        assert router.success_calls["test-model"] == 1
        assert router.total_calls["test-model"] == 1

        # Verify async_routing_strategy_pre_call_checks was called
        router.async_routing_strategy_pre_call_checks.assert_called_once()

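Note that `_aadapter_completion` resolves the `test-adapter-model` alias to the deployment's underlying `test-model` before calling `litellm.aadapter_completion`, which is why the final assertions track `success_calls` and `total_calls` under the deployment name rather than the alias. All three new tests build their Router inline and mock provider traffic, so they can be run in isolation; a selective invocation sketch follows (the target test module is assumed, since the diff excerpt does not show the file header):

# Run only the new router tests via pytest's keyword filter (illustrative;
# point pytest at whichever test module this hunk belongs to).
import pytest

if __name__ == "__main__":
    raise SystemExit(
        pytest.main(["-q", "-k", "ageneric_api_call_with_fallbacks or aadapter_completion"])
    )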