Mirror of https://github.com/BerriAI/litellm.git (synced 2025-04-24 18:24:20 +00:00)

(Refactor) /v1/messages to follow simpler logic for Anthropic API spec (#9013)

* anthropic_messages_handler v0
* fix /messages
* working messages with router methods
* test_anthropic_messages_handler_litellm_router_non_streaming
* test_anthropic_messages_litellm_router_non_streaming_with_logging
* AnthropicMessagesConfig
* _handle_anthropic_messages_response_logging
* working with /v1/messages endpoint
* working /v1/messages endpoint
* refactor to use router factory function
* use aanthropic_messages
* use BaseConfig for Anthropic /v1/messages
* track api key, team on /v1/messages endpoint
* fix get_logging_payload
* BaseAnthropicMessagesTest
* align test config
* test_anthropic_messages_with_thinking
* test_anthropic_streaming_with_thinking
* fix - display anthropic url for debugging
* test_bad_request_error_handling
* test_anthropic_messages_router_streaming_with_bad_request
* fix ProxyException
* test_bad_request_error_handling_streaming
* use provider_specific_header
* test_anthropic_messages_with_extra_headers
* test_anthropic_messages_to_wildcard_model
* fix gcs pub sub test
* standard_logging_payload
* fix unit testing for anthopic /v1/messages support
* fix pass through anthropic messages api
* delete dead code
* fix anthropic pass through response
* revert change to spend tracking utils
* fix get_litellm_metadata_from_kwargs
* fix spend logs payload json
* proxy_pass_through_endpoint_tests
* TestAnthropicPassthroughBasic
* fix pass through tests
* test_async_vertex_proxy_route_api_key_auth
* _handle_anthropic_messages_response_logging
* vertex_credentials
* test_set_default_vertex_config
* test_anthropic_messages_litellm_router_non_streaming_with_logging
* test_ageneric_api_call_with_fallbacks_basic
* test__aadapter_completion

This commit is contained in:
parent 31c5ea74ab
commit f47987e673

25 changed files with 1581 additions and 1027 deletions
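At a high level, the refactor replaces the old OpenAI-adapter translation path for /v1/messages with a direct pass-through handler (`litellm.anthropic_messages`), a matching `aanthropic_messages` router method, and a dedicated proxy endpoint. The sketch below shows roughly how the new handler can be called, based on the signature added in this commit; the model name and environment variable are illustrative placeholders, not part of the diff.

import asyncio
import os

import litellm


async def main():
    # Non-streaming call: the handler forwards the Anthropic-spec body to
    # /v1/messages and returns the provider's JSON response as a dict.
    response = await litellm.anthropic_messages(
        api_key=os.environ["ANTHROPIC_API_KEY"],
        model="anthropic/claude-3-5-sonnet-20241022",
        messages=[{"role": "user", "content": "Hello"}],
        max_tokens=256,
    )
    print(response)


asyncio.run(main())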
@@ -1935,12 +1935,12 @@ jobs:
pip install prisma
pip install fastapi
pip install jsonschema
pip install "httpx==0.24.1"
pip install "httpx==0.27.0"
pip install "anyio==3.7.1"
pip install "asyncio==3.4.3"
pip install "PyGithub==1.59.1"
pip install "google-cloud-aiplatform==1.59.0"
pip install "anthropic==0.21.3"
pip install "anthropic==0.49.0"
# Run pytest and generate JUnit XML report
- run:
    name: Build Docker image
@@ -800,9 +800,6 @@ from .llms.oobabooga.chat.transformation import OobaboogaConfig
from .llms.maritalk import MaritalkConfig
from .llms.openrouter.chat.transformation import OpenrouterConfig
from .llms.anthropic.chat.transformation import AnthropicConfig
from .llms.anthropic.experimental_pass_through.transformation import (
    AnthropicExperimentalPassThroughConfig,
)
from .llms.groq.stt.transformation import GroqSTTConfig
from .llms.anthropic.completion.transformation import AnthropicTextConfig
from .llms.triton.completion.transformation import TritonConfig
@@ -821,6 +818,9 @@ from .llms.infinity.rerank.transformation import InfinityRerankConfig
from .llms.jina_ai.rerank.transformation import JinaAIRerankConfig
from .llms.clarifai.chat.transformation import ClarifaiConfig
from .llms.ai21.chat.transformation import AI21ChatConfig, AI21ChatConfig as AI21Config
from .llms.anthropic.experimental_pass_through.messages.transformation import (
    AnthropicMessagesConfig,
)
from .llms.together_ai.chat import TogetherAIConfig
from .llms.together_ai.completion.transformation import TogetherAITextCompletionConfig
from .llms.cloudflare.chat.transformation import CloudflareChatConfig
@@ -1011,6 +1011,7 @@ from .assistants.main import *
from .batches.main import *
from .batch_completion.main import *  # type: ignore
from .rerank_api.main import *
from .llms.anthropic.experimental_pass_through.messages.handler import *
from .realtime_api.main import _arealtime
from .fine_tuning.main import *
from .files.main import *
@ -1,186 +0,0 @@
|
|||
# What is this?
|
||||
## Translates OpenAI call to Anthropic `/v1/messages` format
|
||||
import traceback
|
||||
from typing import Any, Optional
|
||||
|
||||
import litellm
|
||||
from litellm import ChatCompletionRequest, verbose_logger
|
||||
from litellm.integrations.custom_logger import CustomLogger
|
||||
from litellm.types.llms.anthropic import AnthropicMessagesRequest, AnthropicResponse
|
||||
from litellm.types.utils import AdapterCompletionStreamWrapper, ModelResponse
|
||||
|
||||
|
||||
class AnthropicAdapter(CustomLogger):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
|
||||
def translate_completion_input_params(
|
||||
self, kwargs
|
||||
) -> Optional[ChatCompletionRequest]:
|
||||
"""
|
||||
- translate params, where needed
|
||||
- pass rest, as is
|
||||
"""
|
||||
request_body = AnthropicMessagesRequest(**kwargs) # type: ignore
|
||||
|
||||
translated_body = litellm.AnthropicExperimentalPassThroughConfig().translate_anthropic_to_openai(
|
||||
anthropic_message_request=request_body
|
||||
)
|
||||
|
||||
return translated_body
|
||||
|
||||
def translate_completion_output_params(
|
||||
self, response: ModelResponse
|
||||
) -> Optional[AnthropicResponse]:
|
||||
|
||||
return litellm.AnthropicExperimentalPassThroughConfig().translate_openai_response_to_anthropic(
|
||||
response=response
|
||||
)
|
||||
|
||||
def translate_completion_output_params_streaming(
|
||||
self, completion_stream: Any
|
||||
) -> AdapterCompletionStreamWrapper | None:
|
||||
return AnthropicStreamWrapper(completion_stream=completion_stream)
|
||||
|
||||
|
||||
anthropic_adapter = AnthropicAdapter()
|
||||
|
||||
|
||||
class AnthropicStreamWrapper(AdapterCompletionStreamWrapper):
|
||||
"""
|
||||
- first chunk return 'message_start'
|
||||
- content block must be started and stopped
|
||||
- finish_reason must map exactly to anthropic reason, else anthropic client won't be able to parse it.
|
||||
"""
|
||||
|
||||
sent_first_chunk: bool = False
|
||||
sent_content_block_start: bool = False
|
||||
sent_content_block_finish: bool = False
|
||||
sent_last_message: bool = False
|
||||
holding_chunk: Optional[Any] = None
|
||||
|
||||
def __next__(self):
|
||||
try:
|
||||
if self.sent_first_chunk is False:
|
||||
self.sent_first_chunk = True
|
||||
return {
|
||||
"type": "message_start",
|
||||
"message": {
|
||||
"id": "msg_1nZdL29xx5MUA1yADyHTEsnR8uuvGzszyY",
|
||||
"type": "message",
|
||||
"role": "assistant",
|
||||
"content": [],
|
||||
"model": "claude-3-5-sonnet-20240620",
|
||||
"stop_reason": None,
|
||||
"stop_sequence": None,
|
||||
"usage": {"input_tokens": 25, "output_tokens": 1},
|
||||
},
|
||||
}
|
||||
if self.sent_content_block_start is False:
|
||||
self.sent_content_block_start = True
|
||||
return {
|
||||
"type": "content_block_start",
|
||||
"index": 0,
|
||||
"content_block": {"type": "text", "text": ""},
|
||||
}
|
||||
|
||||
for chunk in self.completion_stream:
|
||||
if chunk == "None" or chunk is None:
|
||||
raise Exception
|
||||
|
||||
processed_chunk = litellm.AnthropicExperimentalPassThroughConfig().translate_streaming_openai_response_to_anthropic(
|
||||
response=chunk
|
||||
)
|
||||
if (
|
||||
processed_chunk["type"] == "message_delta"
|
||||
and self.sent_content_block_finish is False
|
||||
):
|
||||
self.holding_chunk = processed_chunk
|
||||
self.sent_content_block_finish = True
|
||||
return {
|
||||
"type": "content_block_stop",
|
||||
"index": 0,
|
||||
}
|
||||
elif self.holding_chunk is not None:
|
||||
return_chunk = self.holding_chunk
|
||||
self.holding_chunk = processed_chunk
|
||||
return return_chunk
|
||||
else:
|
||||
return processed_chunk
|
||||
if self.holding_chunk is not None:
|
||||
return_chunk = self.holding_chunk
|
||||
self.holding_chunk = None
|
||||
return return_chunk
|
||||
if self.sent_last_message is False:
|
||||
self.sent_last_message = True
|
||||
return {"type": "message_stop"}
|
||||
raise StopIteration
|
||||
except StopIteration:
|
||||
if self.sent_last_message is False:
|
||||
self.sent_last_message = True
|
||||
return {"type": "message_stop"}
|
||||
raise StopIteration
|
||||
except Exception as e:
|
||||
verbose_logger.error(
|
||||
"Anthropic Adapter - {}\n{}".format(e, traceback.format_exc())
|
||||
)
|
||||
|
||||
async def __anext__(self):
|
||||
try:
|
||||
if self.sent_first_chunk is False:
|
||||
self.sent_first_chunk = True
|
||||
return {
|
||||
"type": "message_start",
|
||||
"message": {
|
||||
"id": "msg_1nZdL29xx5MUA1yADyHTEsnR8uuvGzszyY",
|
||||
"type": "message",
|
||||
"role": "assistant",
|
||||
"content": [],
|
||||
"model": "claude-3-5-sonnet-20240620",
|
||||
"stop_reason": None,
|
||||
"stop_sequence": None,
|
||||
"usage": {"input_tokens": 25, "output_tokens": 1},
|
||||
},
|
||||
}
|
||||
if self.sent_content_block_start is False:
|
||||
self.sent_content_block_start = True
|
||||
return {
|
||||
"type": "content_block_start",
|
||||
"index": 0,
|
||||
"content_block": {"type": "text", "text": ""},
|
||||
}
|
||||
async for chunk in self.completion_stream:
|
||||
if chunk == "None" or chunk is None:
|
||||
raise Exception
|
||||
processed_chunk = litellm.AnthropicExperimentalPassThroughConfig().translate_streaming_openai_response_to_anthropic(
|
||||
response=chunk
|
||||
)
|
||||
if (
|
||||
processed_chunk["type"] == "message_delta"
|
||||
and self.sent_content_block_finish is False
|
||||
):
|
||||
self.holding_chunk = processed_chunk
|
||||
self.sent_content_block_finish = True
|
||||
return {
|
||||
"type": "content_block_stop",
|
||||
"index": 0,
|
||||
}
|
||||
elif self.holding_chunk is not None:
|
||||
return_chunk = self.holding_chunk
|
||||
self.holding_chunk = processed_chunk
|
||||
return return_chunk
|
||||
else:
|
||||
return processed_chunk
|
||||
if self.holding_chunk is not None:
|
||||
return_chunk = self.holding_chunk
|
||||
self.holding_chunk = None
|
||||
return return_chunk
|
||||
if self.sent_last_message is False:
|
||||
self.sent_last_message = True
|
||||
return {"type": "message_stop"}
|
||||
raise StopIteration
|
||||
except StopIteration:
|
||||
if self.sent_last_message is False:
|
||||
self.sent_last_message = True
|
||||
return {"type": "message_stop"}
|
||||
raise StopAsyncIteration
|
|
@@ -73,6 +73,8 @@ def remove_index_from_tool_calls(
def get_litellm_metadata_from_kwargs(kwargs: dict):
    """
    Helper to get litellm metadata from all litellm request kwargs

    Return `litellm_metadata` if it exists, otherwise return `metadata`
    """
    litellm_params = kwargs.get("litellm_params", {})
    if litellm_params:
@@ -932,6 +932,9 @@ class Logging(LiteLLMLoggingBaseClass):
        self.model_call_details["log_event_type"] = "successful_api_call"
        self.model_call_details["end_time"] = end_time
        self.model_call_details["cache_hit"] = cache_hit

        if self.call_type == CallTypes.anthropic_messages.value:
            result = self._handle_anthropic_messages_response_logging(result=result)
        ## if model in model cost map - log the response cost
        ## else set cost to None
        if (
@@ -2304,6 +2307,37 @@ class Logging(LiteLLMLoggingBaseClass):
            return complete_streaming_response
        return None

    def _handle_anthropic_messages_response_logging(self, result: Any) -> ModelResponse:
        """
        Handles logging for Anthropic messages responses.

        Args:
            result: The response object from the model call

        Returns:
            The response object from the model call

        - For non-streaming responses, we need to transform the response to a ModelResponse object.
        - For streaming responses, the anthropic_messages handler calls success_handler with an assembled ModelResponse.
        """
        if self.stream and isinstance(result, ModelResponse):
            return result

        result = litellm.AnthropicConfig().transform_response(
            raw_response=self.model_call_details["httpx_response"],
            model_response=litellm.ModelResponse(),
            model=self.model,
            messages=[],
            logging_obj=self,
            optional_params={},
            api_key="",
            request_data={},
            encoding=litellm.encoding,
            json_mode=False,
            litellm_params={},
        )
        return result


def set_callbacks(callback_list, function_id=None):  # noqa: PLR0915
    """
@@ -0,0 +1,179 @@
"""
- Calls /messages on the Anthropic API
- Makes streaming + non-streaming requests - just passes them straight through to Anthropic; nothing special needs to happen here
- Ensures requests are logged in the DB - stream + non-stream

"""

import json
from typing import Any, AsyncIterator, Dict, Optional, Union, cast

import httpx

import litellm
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
from litellm.llms.base_llm.anthropic_messages.transformation import (
    BaseAnthropicMessagesConfig,
)
from litellm.llms.custom_httpx.http_handler import (
    AsyncHTTPHandler,
    get_async_httpx_client,
)
from litellm.types.router import GenericLiteLLMParams
from litellm.types.utils import ProviderSpecificHeader
from litellm.utils import ProviderConfigManager, client


class AnthropicMessagesHandler:

    @staticmethod
    async def _handle_anthropic_streaming(
        response: httpx.Response,
        request_body: dict,
        litellm_logging_obj: LiteLLMLoggingObj,
    ) -> AsyncIterator:
        """Helper function to handle Anthropic streaming responses using the existing logging handlers"""
        from datetime import datetime

        from litellm.proxy.pass_through_endpoints.streaming_handler import (
            PassThroughStreamingHandler,
        )
        from litellm.proxy.pass_through_endpoints.success_handler import (
            PassThroughEndpointLogging,
        )
        from litellm.proxy.pass_through_endpoints.types import EndpointType

        # Create success handler object
        passthrough_success_handler_obj = PassThroughEndpointLogging()

        # Use the existing streaming handler for Anthropic
        start_time = datetime.now()
        return PassThroughStreamingHandler.chunk_processor(
            response=response,
            request_body=request_body,
            litellm_logging_obj=litellm_logging_obj,
            endpoint_type=EndpointType.ANTHROPIC,
            start_time=start_time,
            passthrough_success_handler_obj=passthrough_success_handler_obj,
            url_route="/v1/messages",
        )


@client
async def anthropic_messages(
    api_key: str,
    model: str,
    stream: bool = False,
    api_base: Optional[str] = None,
    client: Optional[AsyncHTTPHandler] = None,
    custom_llm_provider: Optional[str] = None,
    **kwargs,
) -> Union[Dict[str, Any], AsyncIterator]:
    """
    Makes calls to Anthropic's `/v1/messages` API, following the Anthropic API spec
    """
    # Use provided client or create a new one
    optional_params = GenericLiteLLMParams(**kwargs)
    model, _custom_llm_provider, dynamic_api_key, dynamic_api_base = (
        litellm.get_llm_provider(
            model=model,
            custom_llm_provider=custom_llm_provider,
            api_base=optional_params.api_base,
            api_key=optional_params.api_key,
        )
    )
    anthropic_messages_provider_config: Optional[BaseAnthropicMessagesConfig] = (
        ProviderConfigManager.get_provider_anthropic_messages_config(
            model=model,
            provider=litellm.LlmProviders(_custom_llm_provider),
        )
    )
    if anthropic_messages_provider_config is None:
        raise ValueError(
            f"Anthropic messages provider config not found for model: {model}"
        )
    if client is None or not isinstance(client, AsyncHTTPHandler):
        async_httpx_client = get_async_httpx_client(
            llm_provider=litellm.LlmProviders.ANTHROPIC
        )
    else:
        async_httpx_client = client

    litellm_logging_obj: LiteLLMLoggingObj = kwargs.get("litellm_logging_obj", None)

    # Prepare headers
    provider_specific_header = cast(
        Optional[ProviderSpecificHeader], kwargs.get("provider_specific_header", None)
    )
    extra_headers = (
        provider_specific_header.get("extra_headers", {})
        if provider_specific_header
        else {}
    )
    headers = anthropic_messages_provider_config.validate_environment(
        headers=extra_headers or {},
        model=model,
        api_key=api_key,
    )

    litellm_logging_obj.update_environment_variables(
        model=model,
        optional_params=dict(optional_params),
        litellm_params={
            "metadata": kwargs.get("metadata", {}),
            "preset_cache_key": None,
            "stream_response": {},
            **optional_params.model_dump(exclude_unset=True),
        },
        custom_llm_provider=_custom_llm_provider,
    )
    litellm_logging_obj.model_call_details.update(kwargs)

    # Prepare request body
    request_body = kwargs.copy()
    request_body = {
        k: v
        for k, v in request_body.items()
        if k
        in anthropic_messages_provider_config.get_supported_anthropic_messages_params(
            model=model
        )
    }
    request_body["stream"] = stream
    request_body["model"] = model
    litellm_logging_obj.stream = stream

    # Make the request
    request_url = anthropic_messages_provider_config.get_complete_url(
        api_base=api_base, model=model
    )

    litellm_logging_obj.pre_call(
        input=[{"role": "user", "content": json.dumps(request_body)}],
        api_key="",
        additional_args={
            "complete_input_dict": request_body,
            "api_base": str(request_url),
            "headers": headers,
        },
    )

    response = await async_httpx_client.post(
        url=request_url,
        headers=headers,
        data=json.dumps(request_body),
        stream=stream,
    )
    response.raise_for_status()

    # used for logging + cost tracking
    litellm_logging_obj.model_call_details["httpx_response"] = response

    if stream:
        return await AnthropicMessagesHandler._handle_anthropic_streaming(
            response=response,
            request_body=request_body,
            litellm_logging_obj=litellm_logging_obj,
        )
    else:
        return response.json()
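A hedged streaming sketch for the handler above: with `stream=True` the function returns the async iterator produced by `PassThroughStreamingHandler.chunk_processor`, yielding Anthropic-format SSE chunks. The model name and API-key variable are placeholders.

import asyncio
import os

import litellm


async def main():
    # stream=True returns an async iterator of raw Anthropic SSE chunks;
    # the pass-through streaming handler also logs the assembled response.
    stream = await litellm.anthropic_messages(
        api_key=os.environ["ANTHROPIC_API_KEY"],
        model="anthropic/claude-3-5-sonnet-20241022",
        messages=[{"role": "user", "content": "Hello"}],
        max_tokens=256,
        stream=True,
    )
    async for chunk in stream:
        print(chunk)


asyncio.run(main())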
@@ -0,0 +1,47 @@
from typing import Optional

from litellm.llms.base_llm.anthropic_messages.transformation import (
    BaseAnthropicMessagesConfig,
)

DEFAULT_ANTHROPIC_API_BASE = "https://api.anthropic.com"
DEFAULT_ANTHROPIC_API_VERSION = "2023-06-01"


class AnthropicMessagesConfig(BaseAnthropicMessagesConfig):
    def get_supported_anthropic_messages_params(self, model: str) -> list:
        return [
            "messages",
            "model",
            "system",
            "max_tokens",
            "stop_sequences",
            "temperature",
            "top_p",
            "top_k",
            "tools",
            "tool_choice",
            "thinking",
            # TODO: Add Anthropic `metadata` support
            # "metadata",
        ]

    def get_complete_url(self, api_base: Optional[str], model: str) -> str:
        api_base = api_base or DEFAULT_ANTHROPIC_API_BASE
        if not api_base.endswith("/v1/messages"):
            api_base = f"{api_base}/v1/messages"
        return api_base

    def validate_environment(
        self,
        headers: dict,
        model: str,
        api_key: Optional[str] = None,
    ) -> dict:
        if "x-api-key" not in headers:
            headers["x-api-key"] = api_key
        if "anthropic-version" not in headers:
            headers["anthropic-version"] = DEFAULT_ANTHROPIC_API_VERSION
        if "content-type" not in headers:
            headers["content-type"] = "application/json"
        return headers
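A quick illustration (not part of the diff) of how the config above resolves the request URL and headers; the API key shown is a placeholder and the expected outputs in the comments follow from the code.

import litellm

config = litellm.AnthropicMessagesConfig()

print(config.get_complete_url(api_base=None, model="claude-3-5-sonnet-20241022"))
# expected: https://api.anthropic.com/v1/messages

print(
    config.validate_environment(
        headers={}, model="claude-3-5-sonnet-20241022", api_key="sk-ant-placeholder"
    )
)
# expected: {'x-api-key': 'sk-ant-placeholder', 'anthropic-version': '2023-06-01',
#            'content-type': 'application/json'}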
@ -1,412 +0,0 @@
|
|||
import json
|
||||
from typing import List, Literal, Optional, Tuple, Union
|
||||
|
||||
from openai.types.chat.chat_completion_chunk import Choice as OpenAIStreamingChoice
|
||||
|
||||
from litellm.types.llms.anthropic import (
|
||||
AllAnthropicToolsValues,
|
||||
AnthopicMessagesAssistantMessageParam,
|
||||
AnthropicFinishReason,
|
||||
AnthropicMessagesRequest,
|
||||
AnthropicMessagesToolChoice,
|
||||
AnthropicMessagesUserMessageParam,
|
||||
AnthropicResponse,
|
||||
AnthropicResponseContentBlockText,
|
||||
AnthropicResponseContentBlockToolUse,
|
||||
AnthropicResponseUsageBlock,
|
||||
ContentBlockDelta,
|
||||
ContentJsonBlockDelta,
|
||||
ContentTextBlockDelta,
|
||||
MessageBlockDelta,
|
||||
MessageDelta,
|
||||
UsageDelta,
|
||||
)
|
||||
from litellm.types.llms.openai import (
|
||||
AllMessageValues,
|
||||
ChatCompletionAssistantMessage,
|
||||
ChatCompletionAssistantToolCall,
|
||||
ChatCompletionImageObject,
|
||||
ChatCompletionImageUrlObject,
|
||||
ChatCompletionRequest,
|
||||
ChatCompletionSystemMessage,
|
||||
ChatCompletionTextObject,
|
||||
ChatCompletionToolCallFunctionChunk,
|
||||
ChatCompletionToolChoiceFunctionParam,
|
||||
ChatCompletionToolChoiceObjectParam,
|
||||
ChatCompletionToolChoiceValues,
|
||||
ChatCompletionToolMessage,
|
||||
ChatCompletionToolParam,
|
||||
ChatCompletionToolParamFunctionChunk,
|
||||
ChatCompletionUserMessage,
|
||||
)
|
||||
from litellm.types.utils import Choices, ModelResponse, Usage
|
||||
|
||||
|
||||
class AnthropicExperimentalPassThroughConfig:
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
### FOR [BETA] `/v1/messages` endpoint support
|
||||
|
||||
def translatable_anthropic_params(self) -> List:
|
||||
"""
|
||||
Which anthropic params, we need to translate to the openai format.
|
||||
"""
|
||||
return ["messages", "metadata", "system", "tool_choice", "tools"]
|
||||
|
||||
def translate_anthropic_messages_to_openai( # noqa: PLR0915
|
||||
self,
|
||||
messages: List[
|
||||
Union[
|
||||
AnthropicMessagesUserMessageParam,
|
||||
AnthopicMessagesAssistantMessageParam,
|
||||
]
|
||||
],
|
||||
) -> List:
|
||||
new_messages: List[AllMessageValues] = []
|
||||
for m in messages:
|
||||
user_message: Optional[ChatCompletionUserMessage] = None
|
||||
tool_message_list: List[ChatCompletionToolMessage] = []
|
||||
new_user_content_list: List[
|
||||
Union[ChatCompletionTextObject, ChatCompletionImageObject]
|
||||
] = []
|
||||
## USER MESSAGE ##
|
||||
if m["role"] == "user":
|
||||
## translate user message
|
||||
message_content = m.get("content")
|
||||
if message_content and isinstance(message_content, str):
|
||||
user_message = ChatCompletionUserMessage(
|
||||
role="user", content=message_content
|
||||
)
|
||||
elif message_content and isinstance(message_content, list):
|
||||
for content in message_content:
|
||||
if content["type"] == "text":
|
||||
text_obj = ChatCompletionTextObject(
|
||||
type="text", text=content["text"]
|
||||
)
|
||||
new_user_content_list.append(text_obj)
|
||||
elif content["type"] == "image":
|
||||
image_url = ChatCompletionImageUrlObject(
|
||||
url=f"data:{content['type']};base64,{content['source']}"
|
||||
)
|
||||
image_obj = ChatCompletionImageObject(
|
||||
type="image_url", image_url=image_url
|
||||
)
|
||||
|
||||
new_user_content_list.append(image_obj)
|
||||
elif content["type"] == "tool_result":
|
||||
if "content" not in content:
|
||||
tool_result = ChatCompletionToolMessage(
|
||||
role="tool",
|
||||
tool_call_id=content["tool_use_id"],
|
||||
content="",
|
||||
)
|
||||
tool_message_list.append(tool_result)
|
||||
elif isinstance(content["content"], str):
|
||||
tool_result = ChatCompletionToolMessage(
|
||||
role="tool",
|
||||
tool_call_id=content["tool_use_id"],
|
||||
content=content["content"],
|
||||
)
|
||||
tool_message_list.append(tool_result)
|
||||
elif isinstance(content["content"], list):
|
||||
for c in content["content"]:
|
||||
if c["type"] == "text":
|
||||
tool_result = ChatCompletionToolMessage(
|
||||
role="tool",
|
||||
tool_call_id=content["tool_use_id"],
|
||||
content=c["text"],
|
||||
)
|
||||
tool_message_list.append(tool_result)
|
||||
elif c["type"] == "image":
|
||||
image_str = (
|
||||
f"data:{c['type']};base64,{c['source']}"
|
||||
)
|
||||
tool_result = ChatCompletionToolMessage(
|
||||
role="tool",
|
||||
tool_call_id=content["tool_use_id"],
|
||||
content=image_str,
|
||||
)
|
||||
tool_message_list.append(tool_result)
|
||||
|
||||
if user_message is not None:
|
||||
new_messages.append(user_message)
|
||||
|
||||
if len(new_user_content_list) > 0:
|
||||
new_messages.append({"role": "user", "content": new_user_content_list}) # type: ignore
|
||||
|
||||
if len(tool_message_list) > 0:
|
||||
new_messages.extend(tool_message_list)
|
||||
|
||||
## ASSISTANT MESSAGE ##
|
||||
assistant_message_str: Optional[str] = None
|
||||
tool_calls: List[ChatCompletionAssistantToolCall] = []
|
||||
if m["role"] == "assistant":
|
||||
if isinstance(m["content"], str):
|
||||
assistant_message_str = m["content"]
|
||||
elif isinstance(m["content"], list):
|
||||
for content in m["content"]:
|
||||
if content["type"] == "text":
|
||||
if assistant_message_str is None:
|
||||
assistant_message_str = content["text"]
|
||||
else:
|
||||
assistant_message_str += content["text"]
|
||||
elif content["type"] == "tool_use":
|
||||
function_chunk = ChatCompletionToolCallFunctionChunk(
|
||||
name=content["name"],
|
||||
arguments=json.dumps(content["input"]),
|
||||
)
|
||||
|
||||
tool_calls.append(
|
||||
ChatCompletionAssistantToolCall(
|
||||
id=content["id"],
|
||||
type="function",
|
||||
function=function_chunk,
|
||||
)
|
||||
)
|
||||
|
||||
if assistant_message_str is not None or len(tool_calls) > 0:
|
||||
assistant_message = ChatCompletionAssistantMessage(
|
||||
role="assistant",
|
||||
content=assistant_message_str,
|
||||
)
|
||||
if len(tool_calls) > 0:
|
||||
assistant_message["tool_calls"] = tool_calls
|
||||
new_messages.append(assistant_message)
|
||||
|
||||
return new_messages
|
||||
|
||||
def translate_anthropic_tool_choice_to_openai(
|
||||
self, tool_choice: AnthropicMessagesToolChoice
|
||||
) -> ChatCompletionToolChoiceValues:
|
||||
if tool_choice["type"] == "any":
|
||||
return "required"
|
||||
elif tool_choice["type"] == "auto":
|
||||
return "auto"
|
||||
elif tool_choice["type"] == "tool":
|
||||
tc_function_param = ChatCompletionToolChoiceFunctionParam(
|
||||
name=tool_choice.get("name", "")
|
||||
)
|
||||
return ChatCompletionToolChoiceObjectParam(
|
||||
type="function", function=tc_function_param
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
"Incompatible tool choice param submitted - {}".format(tool_choice)
|
||||
)
|
||||
|
||||
def translate_anthropic_tools_to_openai(
|
||||
self, tools: List[AllAnthropicToolsValues]
|
||||
) -> List[ChatCompletionToolParam]:
|
||||
new_tools: List[ChatCompletionToolParam] = []
|
||||
mapped_tool_params = ["name", "input_schema", "description"]
|
||||
for tool in tools:
|
||||
function_chunk = ChatCompletionToolParamFunctionChunk(
|
||||
name=tool["name"],
|
||||
)
|
||||
if "input_schema" in tool:
|
||||
function_chunk["parameters"] = tool["input_schema"] # type: ignore
|
||||
if "description" in tool:
|
||||
function_chunk["description"] = tool["description"] # type: ignore
|
||||
|
||||
for k, v in tool.items():
|
||||
if k not in mapped_tool_params: # pass additional computer kwargs
|
||||
function_chunk.setdefault("parameters", {}).update({k: v})
|
||||
new_tools.append(
|
||||
ChatCompletionToolParam(type="function", function=function_chunk)
|
||||
)
|
||||
|
||||
return new_tools
|
||||
|
||||
def translate_anthropic_to_openai(
|
||||
self, anthropic_message_request: AnthropicMessagesRequest
|
||||
) -> ChatCompletionRequest:
|
||||
"""
|
||||
This is used by the beta Anthropic Adapter, for translating anthropic `/v1/messages` requests to the openai format.
|
||||
"""
|
||||
new_messages: List[AllMessageValues] = []
|
||||
|
||||
## CONVERT ANTHROPIC MESSAGES TO OPENAI
|
||||
new_messages = self.translate_anthropic_messages_to_openai(
|
||||
messages=anthropic_message_request["messages"]
|
||||
)
|
||||
## ADD SYSTEM MESSAGE TO MESSAGES
|
||||
if "system" in anthropic_message_request:
|
||||
new_messages.insert(
|
||||
0,
|
||||
ChatCompletionSystemMessage(
|
||||
role="system", content=anthropic_message_request["system"]
|
||||
),
|
||||
)
|
||||
|
||||
new_kwargs: ChatCompletionRequest = {
|
||||
"model": anthropic_message_request["model"],
|
||||
"messages": new_messages,
|
||||
}
|
||||
## CONVERT METADATA (user_id)
|
||||
if "metadata" in anthropic_message_request:
|
||||
if "user_id" in anthropic_message_request["metadata"]:
|
||||
new_kwargs["user"] = anthropic_message_request["metadata"]["user_id"]
|
||||
|
||||
# Pass litellm proxy specific metadata
|
||||
if "litellm_metadata" in anthropic_message_request:
|
||||
# metadata will be passed to litellm.acompletion(), it's a litellm_param
|
||||
new_kwargs["metadata"] = anthropic_message_request.pop("litellm_metadata")
|
||||
|
||||
## CONVERT TOOL CHOICE
|
||||
if "tool_choice" in anthropic_message_request:
|
||||
new_kwargs["tool_choice"] = self.translate_anthropic_tool_choice_to_openai(
|
||||
tool_choice=anthropic_message_request["tool_choice"]
|
||||
)
|
||||
## CONVERT TOOLS
|
||||
if "tools" in anthropic_message_request:
|
||||
new_kwargs["tools"] = self.translate_anthropic_tools_to_openai(
|
||||
tools=anthropic_message_request["tools"]
|
||||
)
|
||||
|
||||
translatable_params = self.translatable_anthropic_params()
|
||||
for k, v in anthropic_message_request.items():
|
||||
if k not in translatable_params: # pass remaining params as is
|
||||
new_kwargs[k] = v # type: ignore
|
||||
|
||||
return new_kwargs
|
||||
|
||||
def _translate_openai_content_to_anthropic(
|
||||
self, choices: List[Choices]
|
||||
) -> List[
|
||||
Union[AnthropicResponseContentBlockText, AnthropicResponseContentBlockToolUse]
|
||||
]:
|
||||
new_content: List[
|
||||
Union[
|
||||
AnthropicResponseContentBlockText, AnthropicResponseContentBlockToolUse
|
||||
]
|
||||
] = []
|
||||
for choice in choices:
|
||||
if (
|
||||
choice.message.tool_calls is not None
|
||||
and len(choice.message.tool_calls) > 0
|
||||
):
|
||||
for tool_call in choice.message.tool_calls:
|
||||
new_content.append(
|
||||
AnthropicResponseContentBlockToolUse(
|
||||
type="tool_use",
|
||||
id=tool_call.id,
|
||||
name=tool_call.function.name or "",
|
||||
input=json.loads(tool_call.function.arguments),
|
||||
)
|
||||
)
|
||||
elif choice.message.content is not None:
|
||||
new_content.append(
|
||||
AnthropicResponseContentBlockText(
|
||||
type="text", text=choice.message.content
|
||||
)
|
||||
)
|
||||
|
||||
return new_content
|
||||
|
||||
def _translate_openai_finish_reason_to_anthropic(
|
||||
self, openai_finish_reason: str
|
||||
) -> AnthropicFinishReason:
|
||||
if openai_finish_reason == "stop":
|
||||
return "end_turn"
|
||||
elif openai_finish_reason == "length":
|
||||
return "max_tokens"
|
||||
elif openai_finish_reason == "tool_calls":
|
||||
return "tool_use"
|
||||
return "end_turn"
|
||||
|
||||
def translate_openai_response_to_anthropic(
|
||||
self, response: ModelResponse
|
||||
) -> AnthropicResponse:
|
||||
## translate content block
|
||||
anthropic_content = self._translate_openai_content_to_anthropic(choices=response.choices) # type: ignore
|
||||
## extract finish reason
|
||||
anthropic_finish_reason = self._translate_openai_finish_reason_to_anthropic(
|
||||
openai_finish_reason=response.choices[0].finish_reason # type: ignore
|
||||
)
|
||||
# extract usage
|
||||
usage: Usage = getattr(response, "usage")
|
||||
anthropic_usage = AnthropicResponseUsageBlock(
|
||||
input_tokens=usage.prompt_tokens or 0,
|
||||
output_tokens=usage.completion_tokens or 0,
|
||||
)
|
||||
translated_obj = AnthropicResponse(
|
||||
id=response.id,
|
||||
type="message",
|
||||
role="assistant",
|
||||
model=response.model or "unknown-model",
|
||||
stop_sequence=None,
|
||||
usage=anthropic_usage,
|
||||
content=anthropic_content,
|
||||
stop_reason=anthropic_finish_reason,
|
||||
)
|
||||
|
||||
return translated_obj
|
||||
|
||||
def _translate_streaming_openai_chunk_to_anthropic(
|
||||
self, choices: List[OpenAIStreamingChoice]
|
||||
) -> Tuple[
|
||||
Literal["text_delta", "input_json_delta"],
|
||||
Union[ContentTextBlockDelta, ContentJsonBlockDelta],
|
||||
]:
|
||||
text: str = ""
|
||||
partial_json: Optional[str] = None
|
||||
for choice in choices:
|
||||
if choice.delta.content is not None:
|
||||
text += choice.delta.content
|
||||
elif choice.delta.tool_calls is not None:
|
||||
partial_json = ""
|
||||
for tool in choice.delta.tool_calls:
|
||||
if (
|
||||
tool.function is not None
|
||||
and tool.function.arguments is not None
|
||||
):
|
||||
partial_json += tool.function.arguments
|
||||
|
||||
if partial_json is not None:
|
||||
return "input_json_delta", ContentJsonBlockDelta(
|
||||
type="input_json_delta", partial_json=partial_json
|
||||
)
|
||||
else:
|
||||
return "text_delta", ContentTextBlockDelta(type="text_delta", text=text)
|
||||
|
||||
def translate_streaming_openai_response_to_anthropic(
|
||||
self, response: ModelResponse
|
||||
) -> Union[ContentBlockDelta, MessageBlockDelta]:
|
||||
## base case - final chunk w/ finish reason
|
||||
if response.choices[0].finish_reason is not None:
|
||||
delta = MessageDelta(
|
||||
stop_reason=self._translate_openai_finish_reason_to_anthropic(
|
||||
response.choices[0].finish_reason
|
||||
),
|
||||
)
|
||||
if getattr(response, "usage", None) is not None:
|
||||
litellm_usage_chunk: Optional[Usage] = response.usage # type: ignore
|
||||
elif (
|
||||
hasattr(response, "_hidden_params")
|
||||
and "usage" in response._hidden_params
|
||||
):
|
||||
litellm_usage_chunk = response._hidden_params["usage"]
|
||||
else:
|
||||
litellm_usage_chunk = None
|
||||
if litellm_usage_chunk is not None:
|
||||
usage_delta = UsageDelta(
|
||||
input_tokens=litellm_usage_chunk.prompt_tokens or 0,
|
||||
output_tokens=litellm_usage_chunk.completion_tokens or 0,
|
||||
)
|
||||
else:
|
||||
usage_delta = UsageDelta(input_tokens=0, output_tokens=0)
|
||||
return MessageBlockDelta(
|
||||
type="message_delta", delta=delta, usage=usage_delta
|
||||
)
|
||||
(
|
||||
type_of_content,
|
||||
content_block_delta,
|
||||
) = self._translate_streaming_openai_chunk_to_anthropic(
|
||||
choices=response.choices # type: ignore
|
||||
)
|
||||
return ContentBlockDelta(
|
||||
type="content_block_delta",
|
||||
index=response.choices[0].index,
|
||||
delta=content_block_delta,
|
||||
)
|
litellm/llms/base_llm/anthropic_messages/transformation.py (new file, 35 lines)

@@ -0,0 +1,35 @@
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Any, Optional

if TYPE_CHECKING:
    from litellm.litellm_core_utils.litellm_logging import Logging as _LiteLLMLoggingObj

    LiteLLMLoggingObj = _LiteLLMLoggingObj
else:
    LiteLLMLoggingObj = Any


class BaseAnthropicMessagesConfig(ABC):
    @abstractmethod
    def validate_environment(
        self,
        headers: dict,
        model: str,
        api_key: Optional[str] = None,
    ) -> dict:
        pass

    @abstractmethod
    def get_complete_url(self, api_base: Optional[str], model: str) -> str:
        """
        OPTIONAL

        Get the complete url for the request

        Some providers need `model` in `api_base`
        """
        return api_base or ""

    @abstractmethod
    def get_supported_anthropic_messages_params(self, model: str) -> list:
        pass
@@ -1963,7 +1963,7 @@ class ProxyException(Exception):
        code: Optional[Union[int, str]] = None,
        headers: Optional[Dict[str, str]] = None,
    ):
        self.message = message
        self.message = str(message)
        self.type = type
        self.param = param
litellm/proxy/anthropic_endpoints/endpoints.py (new file, 252 lines)
|
@ -0,0 +1,252 @@
|
|||
"""
|
||||
Unified /v1/messages endpoint - (Anthropic Spec)
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import time
|
||||
import traceback
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Request, Response, status
|
||||
from fastapi.responses import StreamingResponse
|
||||
|
||||
import litellm
|
||||
from litellm._logging import verbose_proxy_logger
|
||||
from litellm.proxy._types import *
|
||||
from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
|
||||
from litellm.proxy.common_utils.http_parsing_utils import _read_request_body
|
||||
from litellm.proxy.litellm_pre_call_utils import add_litellm_data_to_request
|
||||
from litellm.proxy.utils import ProxyLogging
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
async def async_data_generator_anthropic(
|
||||
response,
|
||||
user_api_key_dict: UserAPIKeyAuth,
|
||||
request_data: dict,
|
||||
proxy_logging_obj: ProxyLogging,
|
||||
):
|
||||
verbose_proxy_logger.debug("inside generator")
|
||||
try:
|
||||
time.time()
|
||||
async for chunk in response:
|
||||
verbose_proxy_logger.debug(
|
||||
"async_data_generator: received streaming chunk - {}".format(chunk)
|
||||
)
|
||||
### CALL HOOKS ### - modify outgoing data
|
||||
chunk = await proxy_logging_obj.async_post_call_streaming_hook(
|
||||
user_api_key_dict=user_api_key_dict, response=chunk
|
||||
)
|
||||
|
||||
yield chunk
|
||||
except Exception as e:
|
||||
verbose_proxy_logger.exception(
|
||||
"litellm.proxy.proxy_server.async_data_generator(): Exception occured - {}".format(
|
||||
str(e)
|
||||
)
|
||||
)
|
||||
await proxy_logging_obj.post_call_failure_hook(
|
||||
user_api_key_dict=user_api_key_dict,
|
||||
original_exception=e,
|
||||
request_data=request_data,
|
||||
)
|
||||
verbose_proxy_logger.debug(
|
||||
f"\033[1;31mAn error occurred: {e}\n\n Debug this by setting `--debug`, e.g. `litellm --model gpt-3.5-turbo --debug`"
|
||||
)
|
||||
|
||||
if isinstance(e, HTTPException):
|
||||
raise e
|
||||
else:
|
||||
error_traceback = traceback.format_exc()
|
||||
error_msg = f"{str(e)}\n\n{error_traceback}"
|
||||
|
||||
proxy_exception = ProxyException(
|
||||
message=getattr(e, "message", error_msg),
|
||||
type=getattr(e, "type", "None"),
|
||||
param=getattr(e, "param", "None"),
|
||||
code=getattr(e, "status_code", 500),
|
||||
)
|
||||
error_returned = json.dumps({"error": proxy_exception.to_dict()})
|
||||
yield f"data: {error_returned}\n\n"
|
||||
|
||||
|
||||
@router.post(
|
||||
"/v1/messages",
|
||||
tags=["[beta] Anthropic `/v1/messages`"],
|
||||
dependencies=[Depends(user_api_key_auth)],
|
||||
include_in_schema=False,
|
||||
)
|
||||
async def anthropic_response( # noqa: PLR0915
|
||||
fastapi_response: Response,
|
||||
request: Request,
|
||||
user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
|
||||
):
|
||||
"""
|
||||
Use `{PROXY_BASE_URL}/anthropic/v1/messages` instead - [Docs](https://docs.litellm.ai/docs/anthropic_completion).
|
||||
|
||||
This was a BETA endpoint that calls 100+ LLMs in the anthropic format.
|
||||
"""
|
||||
from litellm.proxy.proxy_server import (
|
||||
general_settings,
|
||||
get_custom_headers,
|
||||
llm_router,
|
||||
proxy_config,
|
||||
proxy_logging_obj,
|
||||
user_api_base,
|
||||
user_max_tokens,
|
||||
user_model,
|
||||
user_request_timeout,
|
||||
user_temperature,
|
||||
version,
|
||||
)
|
||||
|
||||
request_data = await _read_request_body(request=request)
|
||||
data: dict = {**request_data}
|
||||
try:
|
||||
data["model"] = (
|
||||
general_settings.get("completion_model", None) # server default
|
||||
or user_model # model name passed via cli args
|
||||
or data.get("model", None) # default passed in http request
|
||||
)
|
||||
if user_model:
|
||||
data["model"] = user_model
|
||||
|
||||
data = await add_litellm_data_to_request(
|
||||
data=data, # type: ignore
|
||||
request=request,
|
||||
general_settings=general_settings,
|
||||
user_api_key_dict=user_api_key_dict,
|
||||
version=version,
|
||||
proxy_config=proxy_config,
|
||||
)
|
||||
|
||||
# override with user settings, these are params passed via cli
|
||||
if user_temperature:
|
||||
data["temperature"] = user_temperature
|
||||
if user_request_timeout:
|
||||
data["request_timeout"] = user_request_timeout
|
||||
if user_max_tokens:
|
||||
data["max_tokens"] = user_max_tokens
|
||||
if user_api_base:
|
||||
data["api_base"] = user_api_base
|
||||
|
||||
### MODEL ALIAS MAPPING ###
|
||||
# check if model name in model alias map
|
||||
# get the actual model name
|
||||
if data["model"] in litellm.model_alias_map:
|
||||
data["model"] = litellm.model_alias_map[data["model"]]
|
||||
|
||||
### CALL HOOKS ### - modify incoming data before calling the model
|
||||
data = await proxy_logging_obj.pre_call_hook( # type: ignore
|
||||
user_api_key_dict=user_api_key_dict, data=data, call_type="text_completion"
|
||||
)
|
||||
|
||||
### ROUTE THE REQUESTs ###
|
||||
router_model_names = llm_router.model_names if llm_router is not None else []
|
||||
|
||||
# skip router if user passed their key
|
||||
if (
|
||||
llm_router is not None and data["model"] in router_model_names
|
||||
): # model in router model list
|
||||
llm_response = asyncio.create_task(llm_router.aanthropic_messages(**data))
|
||||
elif (
|
||||
llm_router is not None
|
||||
and llm_router.model_group_alias is not None
|
||||
and data["model"] in llm_router.model_group_alias
|
||||
): # model set in model_group_alias
|
||||
llm_response = asyncio.create_task(llm_router.aanthropic_messages(**data))
|
||||
elif (
|
||||
llm_router is not None and data["model"] in llm_router.deployment_names
|
||||
): # model in router deployments, calling a specific deployment on the router
|
||||
llm_response = asyncio.create_task(
|
||||
llm_router.aanthropic_messages(**data, specific_deployment=True)
|
||||
)
|
||||
elif (
|
||||
llm_router is not None and data["model"] in llm_router.get_model_ids()
|
||||
): # model in router model list
|
||||
llm_response = asyncio.create_task(llm_router.aanthropic_messages(**data))
|
||||
elif (
|
||||
llm_router is not None
|
||||
and data["model"] not in router_model_names
|
||||
and (
|
||||
llm_router.default_deployment is not None
|
||||
or len(llm_router.pattern_router.patterns) > 0
|
||||
)
|
||||
): # model in router deployments, calling a specific deployment on the router
|
||||
llm_response = asyncio.create_task(llm_router.aanthropic_messages(**data))
|
||||
elif user_model is not None: # `litellm --model <your-model-name>`
|
||||
llm_response = asyncio.create_task(litellm.anthropic_messages(**data))
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail={
|
||||
"error": "completion: Invalid model name passed in model="
|
||||
+ data.get("model", "")
|
||||
},
|
||||
)
|
||||
|
||||
# Await the llm_response task
|
||||
response = await llm_response
|
||||
|
||||
hidden_params = getattr(response, "_hidden_params", {}) or {}
|
||||
model_id = hidden_params.get("model_id", None) or ""
|
||||
cache_key = hidden_params.get("cache_key", None) or ""
|
||||
api_base = hidden_params.get("api_base", None) or ""
|
||||
response_cost = hidden_params.get("response_cost", None) or ""
|
||||
|
||||
### ALERTING ###
|
||||
asyncio.create_task(
|
||||
proxy_logging_obj.update_request_status(
|
||||
litellm_call_id=data.get("litellm_call_id", ""), status="success"
|
||||
)
|
||||
)
|
||||
|
||||
verbose_proxy_logger.debug("final response: %s", response)
|
||||
|
||||
fastapi_response.headers.update(
|
||||
get_custom_headers(
|
||||
user_api_key_dict=user_api_key_dict,
|
||||
model_id=model_id,
|
||||
cache_key=cache_key,
|
||||
api_base=api_base,
|
||||
version=version,
|
||||
response_cost=response_cost,
|
||||
request_data=data,
|
||||
hidden_params=hidden_params,
|
||||
)
|
||||
)
|
||||
|
||||
if (
|
||||
"stream" in data and data["stream"] is True
|
||||
): # use generate_responses to stream responses
|
||||
selected_data_generator = async_data_generator_anthropic(
|
||||
response=response,
|
||||
user_api_key_dict=user_api_key_dict,
|
||||
request_data=data,
|
||||
proxy_logging_obj=proxy_logging_obj,
|
||||
)
|
||||
|
||||
return StreamingResponse(
|
||||
selected_data_generator, # type: ignore
|
||||
media_type="text/event-stream",
|
||||
)
|
||||
|
||||
verbose_proxy_logger.info("\nResponse from Litellm:\n{}".format(response))
|
||||
return response
|
||||
except Exception as e:
|
||||
await proxy_logging_obj.post_call_failure_hook(
|
||||
user_api_key_dict=user_api_key_dict, original_exception=e, request_data=data
|
||||
)
|
||||
verbose_proxy_logger.exception(
|
||||
"litellm.proxy.proxy_server.anthropic_response(): Exception occured - {}".format(
|
||||
str(e)
|
||||
)
|
||||
)
|
||||
error_msg = f"{str(e)}"
|
||||
raise ProxyException(
|
||||
message=getattr(e, "message", error_msg),
|
||||
type=getattr(e, "type", "None"),
|
||||
param=getattr(e, "param", "None"),
|
||||
code=getattr(e, "status_code", 500),
|
||||
)
|
|
@@ -1,9 +1,29 @@
model_list:
  - model_name: fake-openai-endpoint
    litellm_params:
      model: openai/fake
      api_key: fake-key
      api_base: https://exampleopenaiendpoint-production.up.railway.app/
  - model_name: fake-openai-endpoint
    litellm_params:
      model: openai/fake
      api_key: fake-key
      api_base: https://exampleopenaiendpoint-production.up.railway.app/
  - model_name: claude-3-5-sonnet-20241022
    litellm_params:
      model: anthropic/claude-3-5-sonnet-20241022
      api_key: os.environ/ANTHROPIC_API_KEY
  - model_name: claude-special-alias
    litellm_params:
      model: anthropic/claude-3-haiku-20240307
      api_key: os.environ/ANTHROPIC_API_KEY
  - model_name: claude-3-5-sonnet-20241022
    litellm_params:
      model: anthropic/claude-3-5-sonnet-20241022
      api_key: os.environ/ANTHROPIC_API_KEY
  - model_name: claude-3-7-sonnet-20250219
    litellm_params:
      model: anthropic/claude-3-7-sonnet-20250219
      api_key: os.environ/ANTHROPIC_API_KEY
  - model_name: anthropic/*
    litellm_params:
      model: anthropic/*
      api_key: os.environ/ANTHROPIC_API_KEY
general_settings:
  master_key: sk-1234
  custom_auth: custom_auth_basic.user_api_key_auth
  master_key: sk-1234
  custom_auth: custom_auth_basic.user_api_key_auth
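With a test config like the one above, the proxy's unified `/v1/messages` endpoint can be exercised roughly as follows. This is a hedged sketch: the base URL assumes a locally running proxy on port 4000, and the bearer token is the test `master_key` from this config.

import httpx

response = httpx.post(
    "http://localhost:4000/v1/messages",  # assumed local proxy address
    headers={
        "Authorization": "Bearer sk-1234",  # master_key from the test config
        "content-type": "application/json",
    },
    json={
        "model": "claude-3-5-sonnet-20241022",  # routed via the model_list above
        "max_tokens": 256,
        "messages": [{"role": "user", "content": "Hello"}],
    },
)
print(response.json())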
@@ -4,7 +4,22 @@ model_list:
      model: openai/my-fake-model
      api_key: my-fake-key
      api_base: https://exampleopenaiendpoint-production.up.railway.app/

  - model_name: claude-special-alias
    litellm_params:
      model: anthropic/claude-3-haiku-20240307
      api_key: os.environ/ANTHROPIC_API_KEY
  - model_name: claude-3-5-sonnet-20241022
    litellm_params:
      model: anthropic/claude-3-5-sonnet-20241022
      api_key: os.environ/ANTHROPIC_API_KEY
  - model_name: claude-3-7-sonnet-20250219
    litellm_params:
      model: anthropic/claude-3-7-sonnet-20250219
      api_key: os.environ/ANTHROPIC_API_KEY
  - model_name: anthropic/*
    litellm_params:
      model: anthropic/*
      api_key: os.environ/ANTHROPIC_API_KEY

general_settings:
  store_model_in_db: true
@@ -120,6 +120,7 @@ from litellm.proxy._types import *
from litellm.proxy.analytics_endpoints.analytics_endpoints import (
    router as analytics_router,
)
from litellm.proxy.anthropic_endpoints.endpoints import router as anthropic_router
from litellm.proxy.auth.auth_checks import log_db_metrics
from litellm.proxy.auth.auth_utils import check_response_size_is_safe
from litellm.proxy.auth.handle_jwt import JWTHandler
@ -3065,58 +3066,6 @@ async def async_data_generator(
|
|||
yield f"data: {error_returned}\n\n"
|
||||
|
||||
|
||||
async def async_data_generator_anthropic(
|
||||
response, user_api_key_dict: UserAPIKeyAuth, request_data: dict
|
||||
):
|
||||
verbose_proxy_logger.debug("inside generator")
|
||||
try:
|
||||
time.time()
|
||||
async for chunk in response:
|
||||
verbose_proxy_logger.debug(
|
||||
"async_data_generator: received streaming chunk - {}".format(chunk)
|
||||
)
|
||||
### CALL HOOKS ### - modify outgoing data
|
||||
chunk = await proxy_logging_obj.async_post_call_streaming_hook(
|
||||
user_api_key_dict=user_api_key_dict, response=chunk
|
||||
)
|
||||
|
||||
event_type = chunk.get("type")
|
||||
|
||||
try:
|
||||
yield f"event: {event_type}\ndata:{json.dumps(chunk)}\n\n"
|
||||
except Exception as e:
|
||||
yield f"event: {event_type}\ndata:{str(e)}\n\n"
|
||||
except Exception as e:
|
||||
verbose_proxy_logger.exception(
|
||||
"litellm.proxy.proxy_server.async_data_generator(): Exception occured - {}".format(
|
||||
str(e)
|
||||
)
|
||||
)
|
||||
await proxy_logging_obj.post_call_failure_hook(
|
||||
user_api_key_dict=user_api_key_dict,
|
||||
original_exception=e,
|
||||
request_data=request_data,
|
||||
)
|
||||
verbose_proxy_logger.debug(
|
||||
f"\033[1;31mAn error occurred: {e}\n\n Debug this by setting `--debug`, e.g. `litellm --model gpt-3.5-turbo --debug`"
|
||||
)
|
||||
|
||||
if isinstance(e, HTTPException):
|
||||
raise e
|
||||
else:
|
||||
error_traceback = traceback.format_exc()
|
||||
error_msg = f"{str(e)}\n\n{error_traceback}"
|
||||
|
||||
proxy_exception = ProxyException(
|
||||
message=getattr(e, "message", error_msg),
|
||||
type=getattr(e, "type", "None"),
|
||||
param=getattr(e, "param", "None"),
|
||||
code=getattr(e, "status_code", 500),
|
||||
)
|
||||
error_returned = json.dumps({"error": proxy_exception.to_dict()})
|
||||
yield f"data: {error_returned}\n\n"
|
||||
|
||||
|
||||
def select_data_generator(
|
||||
response, user_api_key_dict: UserAPIKeyAuth, request_data: dict
|
||||
):
|
||||
|
@ -5524,224 +5473,6 @@ async def moderations(
|
|||
)
|
||||
|
||||
|
||||
#### ANTHROPIC ENDPOINTS ####
|
||||
|
||||
|
||||
@router.post(
|
||||
"/v1/messages",
|
||||
tags=["[beta] Anthropic `/v1/messages`"],
|
||||
dependencies=[Depends(user_api_key_auth)],
|
||||
response_model=AnthropicResponse,
|
||||
include_in_schema=False,
|
||||
)
|
||||
async def anthropic_response( # noqa: PLR0915
|
||||
anthropic_data: AnthropicMessagesRequest,
|
||||
fastapi_response: Response,
|
||||
request: Request,
|
||||
user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
|
||||
):
|
||||
"""
|
||||
🚨 DEPRECATED ENDPOINT🚨
|
||||
|
||||
Use `{PROXY_BASE_URL}/anthropic/v1/messages` instead - [Docs](https://docs.litellm.ai/docs/anthropic_completion).
|
||||
|
||||
This was a BETA endpoint that calls 100+ LLMs in the anthropic format.
|
||||
"""
|
||||
from litellm import adapter_completion
|
||||
from litellm.adapters.anthropic_adapter import anthropic_adapter
|
||||
|
||||
litellm.adapters = [{"id": "anthropic", "adapter": anthropic_adapter}]
|
||||
|
||||
global user_temperature, user_request_timeout, user_max_tokens, user_api_base
|
||||
request_data = await _read_request_body(request=request)
|
||||
data: dict = {**request_data, "adapter_id": "anthropic"}
|
||||
try:
|
||||
data["model"] = (
|
||||
general_settings.get("completion_model", None) # server default
|
||||
or user_model # model name passed via cli args
|
||||
or data.get("model", None) # default passed in http request
|
||||
)
|
||||
if user_model:
|
||||
data["model"] = user_model
|
||||
|
||||
data = await add_litellm_data_to_request(
|
||||
data=data, # type: ignore
|
||||
request=request,
|
||||
general_settings=general_settings,
|
||||
user_api_key_dict=user_api_key_dict,
|
||||
version=version,
|
||||
proxy_config=proxy_config,
|
||||
)
|
||||
|
||||
# override with user settings, these are params passed via cli
|
||||
if user_temperature:
|
||||
data["temperature"] = user_temperature
|
||||
if user_request_timeout:
|
||||
data["request_timeout"] = user_request_timeout
|
||||
if user_max_tokens:
|
||||
data["max_tokens"] = user_max_tokens
|
||||
if user_api_base:
|
||||
data["api_base"] = user_api_base
|
||||
|
||||
### MODEL ALIAS MAPPING ###
|
||||
# check if model name in model alias map
|
||||
# get the actual model name
|
||||
if data["model"] in litellm.model_alias_map:
|
||||
data["model"] = litellm.model_alias_map[data["model"]]
|
||||
|
||||
### CALL HOOKS ### - modify incoming data before calling the model
|
||||
data = await proxy_logging_obj.pre_call_hook( # type: ignore
|
||||
user_api_key_dict=user_api_key_dict, data=data, call_type="text_completion"
|
||||
)
|
||||
|
||||
### ROUTE THE REQUESTs ###
|
||||
router_model_names = llm_router.model_names if llm_router is not None else []
|
||||
# skip router if user passed their key
|
||||
if "api_key" in data:
|
||||
llm_response = asyncio.create_task(litellm.aadapter_completion(**data))
|
||||
elif (
|
||||
llm_router is not None and data["model"] in router_model_names
|
||||
): # model in router model list
|
||||
llm_response = asyncio.create_task(llm_router.aadapter_completion(**data))
|
||||
elif (
|
||||
llm_router is not None
|
||||
and llm_router.model_group_alias is not None
|
||||
and data["model"] in llm_router.model_group_alias
|
||||
): # model set in model_group_alias
|
||||
llm_response = asyncio.create_task(llm_router.aadapter_completion(**data))
|
||||
elif (
|
||||
llm_router is not None and data["model"] in llm_router.deployment_names
|
||||
): # model in router deployments, calling a specific deployment on the router
|
||||
llm_response = asyncio.create_task(
|
||||
llm_router.aadapter_completion(**data, specific_deployment=True)
|
||||
)
|
||||
elif (
|
||||
llm_router is not None and data["model"] in llm_router.get_model_ids()
|
||||
): # model in router model list
|
||||
llm_response = asyncio.create_task(llm_router.aadapter_completion(**data))
|
||||
elif (
|
||||
llm_router is not None
|
||||
and data["model"] not in router_model_names
|
||||
and (
|
||||
llm_router.default_deployment is not None
|
||||
or len(llm_router.pattern_router.patterns) > 0
|
||||
)
|
||||
): # model in router deployments, calling a specific deployment on the router
|
||||
llm_response = asyncio.create_task(llm_router.aadapter_completion(**data))
|
||||
elif user_model is not None: # `litellm --model <your-model-name>`
|
||||
llm_response = asyncio.create_task(litellm.aadapter_completion(**data))
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail={
|
||||
"error": "completion: Invalid model name passed in model="
|
||||
+ data.get("model", "")
|
||||
},
|
||||
)
|
||||
|
||||
# Await the llm_response task
|
||||
response = await llm_response
|
||||
|
||||
hidden_params = getattr(response, "_hidden_params", {}) or {}
|
||||
model_id = hidden_params.get("model_id", None) or ""
|
||||
cache_key = hidden_params.get("cache_key", None) or ""
|
||||
api_base = hidden_params.get("api_base", None) or ""
|
||||
response_cost = hidden_params.get("response_cost", None) or ""
|
||||
|
||||
### ALERTING ###
|
||||
asyncio.create_task(
|
||||
proxy_logging_obj.update_request_status(
|
||||
litellm_call_id=data.get("litellm_call_id", ""), status="success"
|
||||
)
|
||||
)
|
||||
|
||||
verbose_proxy_logger.debug("final response: %s", response)
|
||||
|
||||
fastapi_response.headers.update(
|
||||
get_custom_headers(
|
||||
user_api_key_dict=user_api_key_dict,
|
||||
model_id=model_id,
|
||||
cache_key=cache_key,
|
||||
api_base=api_base,
|
||||
version=version,
|
||||
response_cost=response_cost,
|
||||
request_data=data,
|
||||
hidden_params=hidden_params,
|
||||
)
|
||||
)
|
||||
|
||||
if (
|
||||
"stream" in data and data["stream"] is True
|
||||
): # use generate_responses to stream responses
|
||||
selected_data_generator = async_data_generator_anthropic(
|
||||
response=response,
|
||||
user_api_key_dict=user_api_key_dict,
|
||||
request_data=data,
|
||||
)
|
||||
return StreamingResponse(
|
||||
selected_data_generator,
|
||||
media_type="text/event-stream",
|
||||
)
|
||||
|
||||
verbose_proxy_logger.info("\nResponse from Litellm:\n{}".format(response))
|
||||
return response
|
||||
except RejectedRequestError as e:
|
||||
_data = e.request_data
|
||||
await proxy_logging_obj.post_call_failure_hook(
|
||||
user_api_key_dict=user_api_key_dict,
|
||||
original_exception=e,
|
||||
request_data=_data,
|
||||
)
|
||||
if _data.get("stream", None) is not None and _data["stream"] is True:
|
||||
_chat_response = litellm.ModelResponse()
|
||||
_usage = litellm.Usage(
|
||||
prompt_tokens=0,
|
||||
completion_tokens=0,
|
||||
total_tokens=0,
|
||||
)
|
||||
_chat_response.usage = _usage # type: ignore
|
||||
_chat_response.choices[0].message.content = e.message # type: ignore
|
||||
_iterator = litellm.utils.ModelResponseIterator(
|
||||
model_response=_chat_response, convert_to_delta=True
|
||||
)
|
||||
_streaming_response = litellm.TextCompletionStreamWrapper(
|
||||
completion_stream=_iterator,
|
||||
model=_data.get("model", ""),
|
||||
)
|
||||
|
||||
selected_data_generator = select_data_generator(
|
||||
response=_streaming_response,
|
||||
user_api_key_dict=user_api_key_dict,
|
||||
request_data=data,
|
||||
)
|
||||
|
||||
return StreamingResponse(
|
||||
selected_data_generator,
|
||||
media_type="text/event-stream",
|
||||
headers={},
|
||||
)
|
||||
else:
|
||||
_response = litellm.TextCompletionResponse()
|
||||
_response.choices[0].text = e.message
|
||||
return _response
|
||||
except Exception as e:
|
||||
await proxy_logging_obj.post_call_failure_hook(
|
||||
user_api_key_dict=user_api_key_dict, original_exception=e, request_data=data
|
||||
)
|
||||
verbose_proxy_logger.exception(
|
||||
"litellm.proxy.proxy_server.anthropic_response(): Exception occured - {}".format(
|
||||
str(e)
|
||||
)
|
||||
)
|
||||
error_msg = f"{str(e)}"
|
||||
raise ProxyException(
|
||||
message=getattr(e, "message", error_msg),
|
||||
type=getattr(e, "type", "None"),
|
||||
param=getattr(e, "param", "None"),
|
||||
code=getattr(e, "status_code", 500),
|
||||
)
|
||||
|
||||
|
||||
#### DEV UTILS ####
|
||||
|
||||
# @router.get(
|
||||
|
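For context, the adapter-based routing above is the code path this refactor retires; after the change, the proxy's /v1/messages endpoint speaks the Anthropic API spec directly, so a stock Anthropic SDK client can call it as-is. A minimal sketch, assuming a local proxy on port 4000 with the virtual key "sk-1234" (the same setup the pass-through tests later in this diff use):

import anthropic

# hypothetical local proxy; base_url/api_key mirror the test configuration in this PR
client = anthropic.Anthropic(base_url="http://0.0.0.0:4000", api_key="sk-1234")
response = client.messages.create(
    model="claude-3-5-sonnet-20241022",
    max_tokens=100,
    messages=[{"role": "user", "content": "Say 'hello test' and nothing else"}],
)
print(response)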
@@ -8840,6 +8571,7 @@ app.include_router(rerank_router)
app.include_router(fine_tuning_router)
app.include_router(vertex_router)
app.include_router(llm_passthrough_router)
app.include_router(anthropic_router)
app.include_router(langfuse_router)
app.include_router(pass_through_router)
app.include_router(health_router)
@@ -10,6 +10,7 @@ from pydantic import BaseModel

import litellm
from litellm._logging import verbose_proxy_logger
from litellm.litellm_core_utils.core_helpers import get_litellm_metadata_from_kwargs
from litellm.proxy._types import SpendLogsMetadata, SpendLogsPayload
from litellm.proxy.utils import PrismaClient, hash_token
from litellm.types.utils import StandardLoggingPayload

@@ -119,9 +120,7 @@ def get_logging_payload(  # noqa: PLR0915
    response_obj = {}
    # standardize this function to be used across, s3, dynamoDB, langfuse logging
    litellm_params = kwargs.get("litellm_params", {})
    metadata = (
        litellm_params.get("metadata", {}) or {}
    )  # if litellm_params['metadata'] == None
    metadata = get_litellm_metadata_from_kwargs(kwargs)
    metadata = _add_proxy_server_request_to_metadata(
        metadata=metadata, litellm_params=litellm_params
    )
@@ -580,6 +580,9 @@ class Router:
        self.amoderation = self.factory_function(
            litellm.amoderation, call_type="moderation"
        )
        self.aanthropic_messages = self.factory_function(
            litellm.anthropic_messages, call_type="anthropic_messages"
        )

    def discard(self):
        """
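Roughly, the new aanthropic_messages entrypoint is used like the other async router methods; a sketch assuming a "claude-special-alias" model group and ANTHROPIC_API_KEY in the environment (mirroring the router tests later in this diff):

import asyncio
import os

from litellm.router import Router


async def main():
    router = Router(
        model_list=[
            {
                "model_name": "claude-special-alias",
                "litellm_params": {
                    "model": "claude-3-haiku-20240307",
                    "api_key": os.getenv("ANTHROPIC_API_KEY"),
                },
            }
        ]
    )
    # routed through factory_function -> _ageneric_api_call_with_fallbacks
    response = await router.aanthropic_messages(
        model="claude-special-alias",
        messages=[{"role": "user", "content": "Hello"}],
        max_tokens=100,
    )
    print(response)


asyncio.run(main())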
@@ -2349,6 +2352,89 @@ class Router:
                self.fail_calls[model] += 1
            raise e

    async def _ageneric_api_call_with_fallbacks(
        self, model: str, original_function: Callable, **kwargs
    ):
        """
        Make a generic LLM API call through the router, this allows you to use retries/fallbacks with litellm router

        Args:
            model: The model to use
            handler_function: The handler function to call (e.g., litellm.anthropic_messages)
            **kwargs: Additional arguments to pass to the handler function

        Returns:
            The response from the handler function
        """
        handler_name = original_function.__name__
        try:
            verbose_router_logger.debug(
                f"Inside _ageneric_api_call() - handler: {handler_name}, model: {model}; kwargs: {kwargs}"
            )
            parent_otel_span = _get_parent_otel_span_from_kwargs(kwargs)
            deployment = await self.async_get_available_deployment(
                model=model,
                request_kwargs=kwargs,
                messages=kwargs.get("messages", None),
                specific_deployment=kwargs.pop("specific_deployment", None),
            )
            self._update_kwargs_with_deployment(deployment=deployment, kwargs=kwargs)

            data = deployment["litellm_params"].copy()
            model_name = data["model"]

            model_client = self._get_async_openai_model_client(
                deployment=deployment,
                kwargs=kwargs,
            )
            self.total_calls[model_name] += 1

            response = original_function(
                **{
                    **data,
                    "caching": self.cache_responses,
                    "client": model_client,
                    **kwargs,
                }
            )

            rpm_semaphore = self._get_client(
                deployment=deployment,
                kwargs=kwargs,
                client_type="max_parallel_requests",
            )

            if rpm_semaphore is not None and isinstance(
                rpm_semaphore, asyncio.Semaphore
            ):
                async with rpm_semaphore:
                    """
                    - Check rpm limits before making the call
                    - If allowed, increment the rpm limit (allows global value to be updated, concurrency-safe)
                    """
                    await self.async_routing_strategy_pre_call_checks(
                        deployment=deployment, parent_otel_span=parent_otel_span
                    )
                    response = await response  # type: ignore
            else:
                await self.async_routing_strategy_pre_call_checks(
                    deployment=deployment, parent_otel_span=parent_otel_span
                )
                response = await response  # type: ignore

            self.success_calls[model_name] += 1
            verbose_router_logger.info(
                f"{handler_name}(model={model_name})\033[32m 200 OK\033[0m"
            )
            return response
        except Exception as e:
            verbose_router_logger.info(
                f"{handler_name}(model={model})\033[31m Exception {str(e)}\033[0m"
            )
            if model is not None:
                self.fail_calls[model] += 1
            raise e

    def embedding(
        self,
        model: str,
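Because _ageneric_api_call_with_fallbacks only needs a model group plus an async handler, it is not tied to Anthropic. A sketch with a stand-in handler, run inside an async context and reusing the router from the sketch above (the unit test at the end of this diff does the same thing with an AsyncMock):

async def fake_handler(**kwargs):
    # stand-in for a provider handler such as litellm.anthropic_messages
    return {"id": "resp_123", "role": "assistant", "content": "ok"}

response = await router._ageneric_api_call_with_fallbacks(
    model="claude-special-alias",
    original_function=fake_handler,
    messages=[{"role": "user", "content": "Hello"}],
    max_tokens=100,
)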
@@ -2869,10 +2955,14 @@ class Router:
    def factory_function(
        self,
        original_function: Callable,
        call_type: Literal["assistants", "moderation"] = "assistants",
        call_type: Literal[
            "assistants", "moderation", "anthropic_messages"
        ] = "assistants",
    ):
        async def new_function(
            custom_llm_provider: Optional[Literal["openai", "azure"]] = None,
            custom_llm_provider: Optional[
                Literal["openai", "azure", "anthropic"]
            ] = None,
            client: Optional["AsyncOpenAI"] = None,
            **kwargs,
        ):

@@ -2889,13 +2979,18 @@ class Router:
                    original_function=original_function,
                    **kwargs,
                )
            elif call_type == "anthropic_messages":
                return await self._ageneric_api_call_with_fallbacks(  # type: ignore
                    original_function=original_function,
                    **kwargs,
                )

        return new_function

    async def _pass_through_assistants_endpoint_factory(
        self,
        original_function: Callable,
        custom_llm_provider: Optional[Literal["openai", "azure"]] = None,
        custom_llm_provider: Optional[Literal["openai", "azure", "anthropic"]] = None,
        client: Optional[AsyncOpenAI] = None,
        **kwargs,
    ):
@@ -186,6 +186,7 @@ class CallTypes(Enum):
    aretrieve_batch = "aretrieve_batch"
    retrieve_batch = "retrieve_batch"
    pass_through = "pass_through_endpoint"
    anthropic_messages = "anthropic_messages"


CallTypesLiteral = Literal[

@@ -209,6 +210,7 @@ CallTypesLiteral = Literal[
    "create_batch",
    "acreate_batch",
    "pass_through_endpoint",
    "anthropic_messages",
]
@@ -191,6 +191,9 @@ from typing import (
from openai import OpenAIError as OriginalError

from litellm.litellm_core_utils.thread_pool_executor import executor
from litellm.llms.base_llm.anthropic_messages.transformation import (
    BaseAnthropicMessagesConfig,
)
from litellm.llms.base_llm.audio_transcription.transformation import (
    BaseAudioTranscriptionConfig,
)

@@ -6245,6 +6248,15 @@ class ProviderConfigManager:
            return litellm.JinaAIRerankConfig()
        return litellm.CohereRerankConfig()

    @staticmethod
    def get_provider_anthropic_messages_config(
        model: str,
        provider: LlmProviders,
    ) -> Optional[BaseAnthropicMessagesConfig]:
        if litellm.LlmProviders.ANTHROPIC == provider:
            return litellm.AnthropicMessagesConfig()
        return None

    @staticmethod
    def get_provider_audio_transcription_config(
        model: str,
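Resolving the /v1/messages config for a provider is then a plain static lookup; a small sketch (import paths assumed from this diff):

import litellm
from litellm.utils import ProviderConfigManager

config = ProviderConfigManager.get_provider_anthropic_messages_config(
    model="claude-3-haiku-20240307",
    provider=litellm.LlmProviders.ANTHROPIC,
)
# Anthropic resolves to litellm.AnthropicMessagesConfig(); other providers return None
print(type(config))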
@@ -329,57 +329,3 @@ async def test_aaapass_through_endpoint_pass_through_keys_langfuse(
    setattr(
        litellm.proxy.proxy_server, "proxy_logging_obj", original_proxy_logging_obj
    )


@pytest.mark.asyncio
async def test_pass_through_endpoint_anthropic(client):
    import litellm
    from litellm import Router
    from litellm.adapters.anthropic_adapter import anthropic_adapter

    router = Router(
        model_list=[
            {
                "model_name": "gpt-3.5-turbo",
                "litellm_params": {
                    "model": "gpt-3.5-turbo",
                    "api_key": os.getenv("OPENAI_API_KEY"),
                    "mock_response": "Hey, how's it going?",
                },
            }
        ]
    )

    setattr(litellm.proxy.proxy_server, "llm_router", router)

    # Define a pass-through endpoint
    pass_through_endpoints = [
        {
            "path": "/v1/test-messages",
            "target": anthropic_adapter,
            "headers": {"litellm_user_api_key": "my-test-header"},
        }
    ]

    # Initialize the pass-through endpoint
    await initialize_pass_through_endpoints(pass_through_endpoints)
    general_settings: Optional[dict] = (
        getattr(litellm.proxy.proxy_server, "general_settings", {}) or {}
    )
    general_settings.update({"pass_through_endpoints": pass_through_endpoints})
    setattr(litellm.proxy.proxy_server, "general_settings", general_settings)

    _json_data = {
        "model": "gpt-3.5-turbo",
        "messages": [{"role": "user", "content": "Who are you?"}],
    }

    # Make a request to the pass-through endpoint
    response = client.post(
        "/v1/test-messages", json=_json_data, headers={"my-test-header": "my-test-key"}
    )

    print("JSON response: ", _json_data)

    # Assert the response
    assert response.status_code == 200
tests/pass_through_tests/base_anthropic_messages_test.py (new file, 145 lines)
@@ -0,0 +1,145 @@
from abc import ABC, abstractmethod

import anthropic
import pytest


class BaseAnthropicMessagesTest(ABC):
    """
    Abstract base test class that enforces a common test across all test classes.
    """

    @abstractmethod
    def get_client(self):
        return anthropic.Anthropic()

    def test_anthropic_basic_completion(self):
        print("making basic completion request to anthropic passthrough")
        client = self.get_client()
        response = client.messages.create(
            model="claude-3-5-sonnet-20241022",
            max_tokens=1024,
            messages=[{"role": "user", "content": "Say 'hello test' and nothing else"}],
            extra_body={
                "litellm_metadata": {
                    "tags": ["test-tag-1", "test-tag-2"],
                }
            },
        )
        print(response)

    def test_anthropic_streaming(self):
        print("making streaming request to anthropic passthrough")
        collected_output = []
        client = self.get_client()
        with client.messages.stream(
            max_tokens=10,
            messages=[
                {"role": "user", "content": "Say 'hello stream test' and nothing else"}
            ],
            model="claude-3-5-sonnet-20241022",
            extra_body={
                "litellm_metadata": {
                    "tags": ["test-tag-stream-1", "test-tag-stream-2"],
                }
            },
        ) as stream:
            for text in stream.text_stream:
                collected_output.append(text)

        full_response = "".join(collected_output)
        print(full_response)

    def test_anthropic_messages_with_thinking(self):
        print("making request to anthropic passthrough with thinking")
        client = self.get_client()
        response = client.messages.create(
            model="claude-3-7-sonnet-20250219",
            max_tokens=20000,
            thinking={"type": "enabled", "budget_tokens": 16000},
            messages=[
                {"role": "user", "content": "Just pinging with thinking enabled"}
            ],
        )

        print(response)

        # Verify the first content block is a thinking block
        response_thinking = response.content[0].thinking
        assert response_thinking is not None
        assert len(response_thinking) > 0

    def test_anthropic_streaming_with_thinking(self):
        print("making streaming request to anthropic passthrough with thinking enabled")
        collected_thinking = []
        collected_response = []
        client = self.get_client()
        with client.messages.stream(
            model="claude-3-7-sonnet-20250219",
            max_tokens=20000,
            thinking={"type": "enabled", "budget_tokens": 16000},
            messages=[
                {"role": "user", "content": "Just pinging with thinking enabled"}
            ],
        ) as stream:
            for event in stream:
                if event.type == "content_block_delta":
                    if event.delta.type == "thinking_delta":
                        collected_thinking.append(event.delta.thinking)
                    elif event.delta.type == "text_delta":
                        collected_response.append(event.delta.text)

        full_thinking = "".join(collected_thinking)
        full_response = "".join(collected_response)

        print(
            f"Thinking Response: {full_thinking[:100]}..."
        )  # Print first 100 chars of thinking
        print(f"Response: {full_response}")

        # Verify we received thinking content
        assert len(collected_thinking) > 0
        assert len(full_thinking) > 0

        # Verify we also received a response
        assert len(collected_response) > 0
        assert len(full_response) > 0

    def test_bad_request_error_handling_streaming(self):
        print("making request to anthropic passthrough with bad request")
        try:
            client = self.get_client()
            response = client.messages.create(
                model="claude-3-5-sonnet-20241022",
                max_tokens=10,
                stream=True,
                messages=["hi"],
            )
            print(response)
            assert pytest.fail("Expected BadRequestError")
        except anthropic.BadRequestError as e:
            print("Got BadRequestError from anthropic, e=", e)
            print(e.__cause__)
            print(e.status_code)
            print(e.response)
        except Exception as e:
            pytest.fail(f"Got unexpected exception: {e}")

    def test_bad_request_error_handling_non_streaming(self):
        print("making request to anthropic passthrough with bad request")
        try:
            client = self.get_client()
            response = client.messages.create(
                model="claude-3-5-sonnet-20241022",
                max_tokens=10,
                messages=["hi"],
            )
            print(response)
            assert pytest.fail("Expected BadRequestError")
        except anthropic.BadRequestError as e:
            print("Got BadRequestError from anthropic, e=", e)
            print(e.__cause__)
            print(e.status_code)
            print(e.response)
        except Exception as e:
            pytest.fail(f"Got unexpected exception: {e}")
@@ -8,48 +8,6 @@ import aiohttp
import asyncio
import json

client = anthropic.Anthropic(
    base_url="http://0.0.0.0:4000/anthropic", api_key="sk-1234"
)


def test_anthropic_basic_completion():
    print("making basic completion request to anthropic passthrough")
    response = client.messages.create(
        model="claude-3-5-sonnet-20241022",
        max_tokens=1024,
        messages=[{"role": "user", "content": "Say 'hello test' and nothing else"}],
        extra_body={
            "litellm_metadata": {
                "tags": ["test-tag-1", "test-tag-2"],
            }
        },
    )
    print(response)


def test_anthropic_streaming():
    print("making streaming request to anthropic passthrough")
    collected_output = []

    with client.messages.stream(
        max_tokens=10,
        messages=[
            {"role": "user", "content": "Say 'hello stream test' and nothing else"}
        ],
        model="claude-3-5-sonnet-20241022",
        extra_body={
            "litellm_metadata": {
                "tags": ["test-tag-stream-1", "test-tag-stream-2"],
            }
        },
    ) as stream:
        for text in stream.text_stream:
            collected_output.append(text)

    full_response = "".join(collected_output)
    print(full_response)


@pytest.mark.asyncio
async def test_anthropic_basic_completion_with_headers():
tests/pass_through_tests/test_anthropic_passthrough_basic.py (new file, 28 lines)
@@ -0,0 +1,28 @@
from base_anthropic_messages_test import BaseAnthropicMessagesTest
import anthropic


class TestAnthropicPassthroughBasic(BaseAnthropicMessagesTest):

    def get_client(self):
        return anthropic.Anthropic(
            base_url="http://0.0.0.0:4000/anthropic",
            api_key="sk-1234",
        )


class TestAnthropicMessagesEndpoint(BaseAnthropicMessagesTest):
    def get_client(self):
        return anthropic.Anthropic(
            base_url="http://0.0.0.0:4000",
            api_key="sk-1234",
        )

    def test_anthropic_messages_to_wildcard_model(self):
        client = self.get_client()
        response = client.messages.create(
            model="anthropic/claude-3-opus-20240229",
            messages=[{"role": "user", "content": "Hello, world!"}],
            max_tokens=100,
        )
        print(response)
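These suites expect a LiteLLM proxy already serving http://0.0.0.0:4000 with the virtual key "sk-1234" (per the base_url values above); one way to run them, sketched via pytest's Python entrypoint:

import pytest

# assumes the proxy from the test configuration is already running locally
pytest.main(["tests/pass_through_tests/test_anthropic_passthrough_basic.py", "-q"])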
@@ -0,0 +1,487 @@
import json
import os
import sys
from datetime import datetime
from typing import AsyncIterator, Dict, Any
import asyncio
import unittest.mock
from unittest.mock import AsyncMock, MagicMock

sys.path.insert(
    0, os.path.abspath("../..")
)  # Adds the parent directory to the system path
import litellm
import pytest
from dotenv import load_dotenv
from litellm.llms.anthropic.experimental_pass_through.messages.handler import (
    anthropic_messages,
)
from typing import Optional
from litellm.types.utils import StandardLoggingPayload
from litellm.integrations.custom_logger import CustomLogger
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler
from litellm.router import Router
import importlib

# Load environment variables
load_dotenv()


@pytest.fixture(scope="session")
def event_loop():
    """Create an instance of the default event loop for each test session."""
    loop = asyncio.get_event_loop_policy().new_event_loop()
    yield loop
    loop.close()


@pytest.fixture(scope="function", autouse=True)
def setup_and_teardown(event_loop):  # Add event_loop as a dependency
    curr_dir = os.getcwd()
    sys.path.insert(0, os.path.abspath("../.."))

    import litellm
    from litellm import Router

    importlib.reload(litellm)

    # Set the event loop from the fixture
    asyncio.set_event_loop(event_loop)

    print(litellm)
    yield

    # Clean up any pending tasks
    pending = asyncio.all_tasks(event_loop)
    for task in pending:
        task.cancel()

    # Run the event loop until all tasks are cancelled
    if pending:
        event_loop.run_until_complete(asyncio.gather(*pending, return_exceptions=True))


def _validate_anthropic_response(response: Dict[str, Any]):
    assert "id" in response
    assert "content" in response
    assert "model" in response
    assert response["role"] == "assistant"


@pytest.mark.asyncio
async def test_anthropic_messages_non_streaming():
    """
    Test the anthropic_messages with non-streaming request
    """
    # Get API key from environment
    api_key = os.getenv("ANTHROPIC_API_KEY")
    if not api_key:
        pytest.skip("ANTHROPIC_API_KEY not found in environment")

    # Set up test parameters
    messages = [{"role": "user", "content": "Hello, can you tell me a short joke?"}]

    # Call the handler
    response = await anthropic_messages(
        messages=messages,
        api_key=api_key,
        model="claude-3-haiku-20240307",
        max_tokens=100,
    )

    # Verify response
    assert "id" in response
    assert "content" in response
    assert "model" in response
    assert response["role"] == "assistant"

    print(f"Non-streaming response: {json.dumps(response, indent=2)}")
    return response


@pytest.mark.asyncio
async def test_anthropic_messages_streaming():
    """
    Test the anthropic_messages with streaming request
    """
    # Get API key from environment
    api_key = os.getenv("ANTHROPIC_API_KEY")
    if not api_key:
        pytest.skip("ANTHROPIC_API_KEY not found in environment")

    # Set up test parameters
    messages = [{"role": "user", "content": "Hello, can you tell me a short joke?"}]

    # Call the handler
    async_httpx_client = AsyncHTTPHandler()
    response = await anthropic_messages(
        messages=messages,
        api_key=api_key,
        model="claude-3-haiku-20240307",
        max_tokens=100,
        stream=True,
        client=async_httpx_client,
    )

    if isinstance(response, AsyncIterator):
        async for chunk in response:
            print("chunk=", chunk)


@pytest.mark.asyncio
async def test_anthropic_messages_streaming_with_bad_request():
    """
    Test the anthropic_messages with streaming request
    """
    try:
        response = await anthropic_messages(
            messages=["hi"],
            api_key=os.getenv("ANTHROPIC_API_KEY"),
            model="claude-3-haiku-20240307",
            max_tokens=100,
            stream=True,
        )
        print(response)
        async for chunk in response:
            print("chunk=", chunk)
    except Exception as e:
        print("got exception", e)
        print("vars", vars(e))
        assert e.status_code == 400


@pytest.mark.asyncio
async def test_anthropic_messages_router_streaming_with_bad_request():
    """
    Test the anthropic_messages with streaming request
    """
    try:
        router = Router(
            model_list=[
                {
                    "model_name": "claude-special-alias",
                    "litellm_params": {
                        "model": "claude-3-haiku-20240307",
                        "api_key": os.getenv("ANTHROPIC_API_KEY"),
                    },
                }
            ]
        )

        response = await router.aanthropic_messages(
            messages=["hi"],
            model="claude-special-alias",
            max_tokens=100,
            stream=True,
        )
        print(response)
        async for chunk in response:
            print("chunk=", chunk)
    except Exception as e:
        print("got exception", e)
        print("vars", vars(e))
        assert e.status_code == 400


@pytest.mark.asyncio
async def test_anthropic_messages_litellm_router_non_streaming():
    """
    Test the anthropic_messages with non-streaming request
    """
    litellm._turn_on_debug()
    router = Router(
        model_list=[
            {
                "model_name": "claude-special-alias",
                "litellm_params": {
                    "model": "claude-3-haiku-20240307",
                    "api_key": os.getenv("ANTHROPIC_API_KEY"),
                },
            }
        ]
    )

    # Set up test parameters
    messages = [{"role": "user", "content": "Hello, can you tell me a short joke?"}]

    # Call the handler
    response = await router.aanthropic_messages(
        messages=messages,
        model="claude-special-alias",
        max_tokens=100,
    )

    # Verify response
    assert "id" in response
    assert "content" in response
    assert "model" in response
    assert response["role"] == "assistant"

    print(f"Non-streaming response: {json.dumps(response, indent=2)}")
    return response


class TestCustomLogger(CustomLogger):
    def __init__(self):
        super().__init__()
        self.logged_standard_logging_payload: Optional[StandardLoggingPayload] = None

    async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
        print("inside async_log_success_event")
        self.logged_standard_logging_payload = kwargs.get("standard_logging_object")

        pass


@pytest.mark.asyncio
async def test_anthropic_messages_litellm_router_non_streaming_with_logging():
    """
    Test the anthropic_messages with non-streaming request

    - Ensure Cost + Usage is tracked
    """
    test_custom_logger = TestCustomLogger()
    litellm.callbacks = [test_custom_logger]
    litellm._turn_on_debug()
    router = Router(
        model_list=[
            {
                "model_name": "claude-special-alias",
                "litellm_params": {
                    "model": "claude-3-haiku-20240307",
                    "api_key": os.getenv("ANTHROPIC_API_KEY"),
                },
            }
        ]
    )

    # Set up test parameters
    messages = [{"role": "user", "content": "Hello, can you tell me a short joke?"}]

    # Call the handler
    response = await router.aanthropic_messages(
        messages=messages,
        model="claude-special-alias",
        max_tokens=100,
    )

    # Verify response
    _validate_anthropic_response(response)

    print(f"Non-streaming response: {json.dumps(response, indent=2)}")

    await asyncio.sleep(1)
    assert test_custom_logger.logged_standard_logging_payload["messages"] == messages
    assert test_custom_logger.logged_standard_logging_payload["response"] is not None
    assert (
        test_custom_logger.logged_standard_logging_payload["model"]
        == "claude-3-haiku-20240307"
    )

    # check logged usage + spend
    assert test_custom_logger.logged_standard_logging_payload["response_cost"] > 0
    assert (
        test_custom_logger.logged_standard_logging_payload["prompt_tokens"]
        == response["usage"]["input_tokens"]
    )
    assert (
        test_custom_logger.logged_standard_logging_payload["completion_tokens"]
        == response["usage"]["output_tokens"]
    )


@pytest.mark.asyncio
async def test_anthropic_messages_litellm_router_streaming_with_logging():
    """
    Test the anthropic_messages with streaming request

    - Ensure Cost + Usage is tracked
    """
    test_custom_logger = TestCustomLogger()
    litellm.callbacks = [test_custom_logger]
    # litellm._turn_on_debug()
    router = Router(
        model_list=[
            {
                "model_name": "claude-special-alias",
                "litellm_params": {
                    "model": "claude-3-haiku-20240307",
                    "api_key": os.getenv("ANTHROPIC_API_KEY"),
                },
            }
        ]
    )

    # Set up test parameters
    messages = [{"role": "user", "content": "Hello, can you tell me a short joke?"}]

    # Call the handler
    response = await router.aanthropic_messages(
        messages=messages,
        model="claude-special-alias",
        max_tokens=100,
        stream=True,
    )

    response_prompt_tokens = 0
    response_completion_tokens = 0
    all_anthropic_usage_chunks = []

    async for chunk in response:
        # Decode chunk if it's bytes
        print("chunk=", chunk)

        # Handle SSE format chunks
        if isinstance(chunk, bytes):
            chunk_str = chunk.decode("utf-8")
            # Extract the JSON data part from SSE format
            for line in chunk_str.split("\n"):
                if line.startswith("data: "):
                    try:
                        json_data = json.loads(line[6:])  # Skip the 'data: ' prefix
                        print(
                            "\n\nJSON data:",
                            json.dumps(json_data, indent=4, default=str),
                        )

                        # Extract usage information
                        if (
                            json_data.get("type") == "message_start"
                            and "message" in json_data
                        ):
                            if "usage" in json_data["message"]:
                                usage = json_data["message"]["usage"]
                                all_anthropic_usage_chunks.append(usage)
                                print(
                                    "USAGE BLOCK",
                                    json.dumps(usage, indent=4, default=str),
                                )
                        elif "usage" in json_data:
                            usage = json_data["usage"]
                            all_anthropic_usage_chunks.append(usage)
                            print(
                                "USAGE BLOCK", json.dumps(usage, indent=4, default=str)
                            )
                    except json.JSONDecodeError:
                        print(f"Failed to parse JSON from: {line[6:]}")
        elif hasattr(chunk, "message"):
            if chunk.message.usage:
                print(
                    "USAGE BLOCK",
                    json.dumps(chunk.message.usage, indent=4, default=str),
                )
                all_anthropic_usage_chunks.append(chunk.message.usage)
        elif hasattr(chunk, "usage"):
            print("USAGE BLOCK", json.dumps(chunk.usage, indent=4, default=str))
            all_anthropic_usage_chunks.append(chunk.usage)

    print(
        "all_anthropic_usage_chunks",
        json.dumps(all_anthropic_usage_chunks, indent=4, default=str),
    )

    # Extract token counts from usage data
    if all_anthropic_usage_chunks:
        response_prompt_tokens = max(
            [usage.get("input_tokens", 0) for usage in all_anthropic_usage_chunks]
        )
        response_completion_tokens = max(
            [usage.get("output_tokens", 0) for usage in all_anthropic_usage_chunks]
        )

    print("input_tokens_anthropic_api", response_prompt_tokens)
    print("output_tokens_anthropic_api", response_completion_tokens)

    await asyncio.sleep(4)

    print(
        "logged_standard_logging_payload",
        json.dumps(
            test_custom_logger.logged_standard_logging_payload, indent=4, default=str
        ),
    )

    assert test_custom_logger.logged_standard_logging_payload["messages"] == messages
    assert test_custom_logger.logged_standard_logging_payload["response"] is not None
    assert (
        test_custom_logger.logged_standard_logging_payload["model"]
        == "claude-3-haiku-20240307"
    )

    # check logged usage + spend
    assert test_custom_logger.logged_standard_logging_payload["response_cost"] > 0
    assert (
        test_custom_logger.logged_standard_logging_payload["prompt_tokens"]
        == response_prompt_tokens
    )
    assert (
        test_custom_logger.logged_standard_logging_payload["completion_tokens"]
        == response_completion_tokens
    )


@pytest.mark.asyncio
async def test_anthropic_messages_with_extra_headers():
    """
    Test the anthropic_messages with extra headers
    """
    # Get API key from environment
    api_key = os.getenv("ANTHROPIC_API_KEY", "fake-api-key")

    # Set up test parameters
    messages = [{"role": "user", "content": "Hello, can you tell me a short joke?"}]
    extra_headers = {
        "anthropic-beta": "very-custom-beta-value",
        "anthropic-version": "custom-version-for-test",
    }

    # Create a mock response
    mock_response = MagicMock()
    mock_response.raise_for_status = MagicMock()
    mock_response.json.return_value = {
        "id": "msg_123456",
        "type": "message",
        "role": "assistant",
        "content": [
            {
                "type": "text",
                "text": "Why did the chicken cross the road? To get to the other side!",
            }
        ],
        "model": "claude-3-haiku-20240307",
        "stop_reason": "end_turn",
        "usage": {"input_tokens": 10, "output_tokens": 20},
    }

    # Create a mock client with AsyncMock for the post method
    mock_client = MagicMock(spec=AsyncHTTPHandler)
    mock_client.post = AsyncMock(return_value=mock_response)

    # Call the handler with extra_headers and our mocked client
    response = await anthropic_messages(
        messages=messages,
        api_key=api_key,
        model="claude-3-haiku-20240307",
        max_tokens=100,
        client=mock_client,
        provider_specific_header={
            "custom_llm_provider": "anthropic",
            "extra_headers": extra_headers,
        },
    )

    # Verify the post method was called with the right parameters
    mock_client.post.assert_called_once()
    call_kwargs = mock_client.post.call_args.kwargs

    # Verify headers were passed correctly
    headers = call_kwargs.get("headers", {})
    print("HEADERS IN REQUEST", headers)
    for key, value in extra_headers.items():
        assert key in headers
        assert headers[key] == value

    # Verify the response was processed correctly
    assert response == mock_response.json.return_value

    return response
@@ -54,7 +54,7 @@ async def test_get_litellm_virtual_key():


@pytest.mark.asyncio
async def test_vertex_proxy_route_api_key_auth():
async def test_async_vertex_proxy_route_api_key_auth():
    """
    Critical

@@ -207,7 +207,7 @@ async def test_get_vertex_credentials_stored():
    router.add_vertex_credentials(
        project_id="test-project",
        location="us-central1",
        vertex_credentials="test-creds",
        vertex_credentials='{"credentials": "test-creds"}',
    )

    creds = router.get_vertex_credentials(

@@ -215,7 +215,7 @@ async def test_get_vertex_credentials_stored():
    )
    assert creds.vertex_project == "test-project"
    assert creds.vertex_location == "us-central1"
    assert creds.vertex_credentials == "test-creds"
    assert creds.vertex_credentials == '{"credentials": "test-creds"}'


@pytest.mark.asyncio

@@ -227,18 +227,20 @@ async def test_add_vertex_credentials():
    router.add_vertex_credentials(
        project_id="test-project",
        location="us-central1",
        vertex_credentials="test-creds",
        vertex_credentials='{"credentials": "test-creds"}',
    )

    assert "test-project-us-central1" in router.deployment_key_to_vertex_credentials
    creds = router.deployment_key_to_vertex_credentials["test-project-us-central1"]
    assert creds.vertex_project == "test-project"
    assert creds.vertex_location == "us-central1"
    assert creds.vertex_credentials == "test-creds"
    assert creds.vertex_credentials == '{"credentials": "test-creds"}'

    # Test adding with None values
    router.add_vertex_credentials(
        project_id=None, location=None, vertex_credentials="test-creds"
        project_id=None,
        location=None,
        vertex_credentials='{"credentials": "test-creds"}',
    )
    # Should not add None values
    assert len(router.deployment_key_to_vertex_credentials) == 1
@@ -6,6 +6,7 @@ from typing import Optional
from dotenv import load_dotenv
from fastapi import Request
from datetime import datetime
from unittest.mock import AsyncMock, patch

sys.path.insert(
    0, os.path.abspath("../..")

@@ -289,43 +290,6 @@ async def test_aaaaatext_completion_endpoint(model_list, sync_mode):
    assert response.choices[0].text == "I'm fine, thank you!"


@pytest.mark.asyncio
async def test_anthropic_router_completion_e2e(model_list):
    from litellm.adapters.anthropic_adapter import anthropic_adapter
    from litellm.types.llms.anthropic import AnthropicResponse

    litellm.set_verbose = True

    litellm.adapters = [{"id": "anthropic", "adapter": anthropic_adapter}]

    router = Router(model_list=model_list)
    messages = [{"role": "user", "content": "Hey, how's it going?"}]

    ## Test 1: user facing function
    response = await router.aadapter_completion(
        model="claude-3-5-sonnet-20240620",
        messages=messages,
        adapter_id="anthropic",
        mock_response="This is a fake call",
    )

    ## Test 2: underlying function
    await router._aadapter_completion(
        model="claude-3-5-sonnet-20240620",
        messages=messages,
        adapter_id="anthropic",
        mock_response="This is a fake call",
    )

    print("Response: {}".format(response))

    assert response is not None

    AnthropicResponse.model_validate(response)

    assert response.model == "gpt-3.5-turbo"


@pytest.mark.asyncio
async def test_router_with_empty_choices(model_list):
    """

@@ -349,3 +313,200 @@ async def test_router_with_empty_choices(model_list):
        mock_response=mock_response,
    )
    assert response is not None


@pytest.mark.asyncio
async def test_ageneric_api_call_with_fallbacks_basic():
    """
    Test the _ageneric_api_call_with_fallbacks method with a basic successful call
    """
    # Create a mock function that will be passed to _ageneric_api_call_with_fallbacks
    mock_function = AsyncMock()
    mock_function.__name__ = "test_function"

    # Create a mock response
    mock_response = {
        "id": "resp_123456",
        "role": "assistant",
        "content": "This is a test response",
        "model": "test-model",
        "usage": {"input_tokens": 10, "output_tokens": 20},
    }
    mock_function.return_value = mock_response

    # Create a router with a test model
    router = Router(
        model_list=[
            {
                "model_name": "test-model-alias",
                "litellm_params": {
                    "model": "anthropic/test-model",
                    "api_key": "fake-api-key",
                },
            }
        ]
    )

    # Call the _ageneric_api_call_with_fallbacks method
    response = await router._ageneric_api_call_with_fallbacks(
        model="test-model-alias",
        original_function=mock_function,
        messages=[{"role": "user", "content": "Hello"}],
        max_tokens=100,
    )

    # Verify the mock function was called
    mock_function.assert_called_once()

    # Verify the response
    assert response == mock_response


@pytest.mark.asyncio
async def test_aadapter_completion():
    """
    Test the aadapter_completion method which uses async_function_with_fallbacks
    """
    # Create a mock for the _aadapter_completion method
    mock_response = {
        "id": "adapter_resp_123",
        "object": "adapter.completion",
        "created": 1677858242,
        "model": "test-model-with-adapter",
        "choices": [
            {
                "text": "This is a test adapter response",
                "index": 0,
                "finish_reason": "stop",
            }
        ],
        "usage": {"prompt_tokens": 10, "completion_tokens": 20, "total_tokens": 30},
    }

    # Create a router with a patched _aadapter_completion method
    with patch.object(
        Router, "_aadapter_completion", new_callable=AsyncMock
    ) as mock_method:
        mock_method.return_value = mock_response

        router = Router(
            model_list=[
                {
                    "model_name": "test-adapter-model",
                    "litellm_params": {
                        "model": "anthropic/test-model",
                        "api_key": "fake-api-key",
                    },
                }
            ]
        )

        # Replace the async_function_with_fallbacks with a mock
        router.async_function_with_fallbacks = AsyncMock(return_value=mock_response)

        # Call the aadapter_completion method
        response = await router.aadapter_completion(
            adapter_id="test-adapter-id",
            model="test-adapter-model",
            prompt="This is a test prompt",
            max_tokens=100,
        )

        # Verify the response
        assert response == mock_response

        # Verify async_function_with_fallbacks was called with the right parameters
        router.async_function_with_fallbacks.assert_called_once()
        call_kwargs = router.async_function_with_fallbacks.call_args.kwargs
        assert call_kwargs["adapter_id"] == "test-adapter-id"
        assert call_kwargs["model"] == "test-adapter-model"
        assert call_kwargs["prompt"] == "This is a test prompt"
        assert call_kwargs["max_tokens"] == 100
        assert call_kwargs["original_function"] == router._aadapter_completion
        assert "metadata" in call_kwargs
        assert call_kwargs["metadata"]["model_group"] == "test-adapter-model"


@pytest.mark.asyncio
async def test__aadapter_completion():
    """
    Test the _aadapter_completion method directly
    """
    # Create a mock response for litellm.aadapter_completion
    mock_response = {
        "id": "adapter_resp_123",
        "object": "adapter.completion",
        "created": 1677858242,
        "model": "test-model-with-adapter",
        "choices": [
            {
                "text": "This is a test adapter response",
                "index": 0,
                "finish_reason": "stop",
            }
        ],
        "usage": {"prompt_tokens": 10, "completion_tokens": 20, "total_tokens": 30},
    }

    # Create a router with a mocked litellm.aadapter_completion
    with patch(
        "litellm.aadapter_completion", new_callable=AsyncMock
    ) as mock_adapter_completion:
        mock_adapter_completion.return_value = mock_response

        router = Router(
            model_list=[
                {
                    "model_name": "test-adapter-model",
                    "litellm_params": {
                        "model": "anthropic/test-model",
                        "api_key": "fake-api-key",
                    },
                }
            ]
        )

        # Mock the async_get_available_deployment method
        router.async_get_available_deployment = AsyncMock(
            return_value={
                "model_name": "test-adapter-model",
                "litellm_params": {
                    "model": "test-model",
                    "api_key": "fake-api-key",
                },
                "model_info": {
                    "id": "test-unique-id",
                },
            }
        )

        # Mock the async_routing_strategy_pre_call_checks method
        router.async_routing_strategy_pre_call_checks = AsyncMock()

        # Call the _aadapter_completion method
        response = await router._aadapter_completion(
            adapter_id="test-adapter-id",
            model="test-adapter-model",
            prompt="This is a test prompt",
            max_tokens=100,
        )

        # Verify the response
        assert response == mock_response

        # Verify litellm.aadapter_completion was called with the right parameters
        mock_adapter_completion.assert_called_once()
        call_kwargs = mock_adapter_completion.call_args.kwargs
        assert call_kwargs["adapter_id"] == "test-adapter-id"
        assert call_kwargs["model"] == "test-model"
        assert call_kwargs["prompt"] == "This is a test prompt"
        assert call_kwargs["max_tokens"] == 100
        assert call_kwargs["api_key"] == "fake-api-key"
        assert call_kwargs["caching"] == router.cache_responses

        # Verify the success call was recorded
        assert router.success_calls["test-model"] == 1
        assert router.total_calls["test-model"] == 1

        # Verify async_routing_strategy_pre_call_checks was called
        router.async_routing_strategy_pre_call_checks.assert_called_once()