Merge pull request #4635 from BerriAI/litellm_anthropic_adapter

Anthropic `/v1/messages` endpoint support
Krish Dholakia 2024-07-10 22:41:53 -07:00 committed by GitHub
commit dacce3d78b
15 changed files with 1163 additions and 161 deletions


@@ -882,3 +882,8 @@ from .batches.main import *
from .files.main import *
from .scheduler import *
from .cost_calculator import response_cost_calculator, cost_per_token
### ADAPTERS ###
from .types.adapter import AdapterItem
adapters: List[AdapterItem] = []


@@ -0,0 +1,50 @@
# What is this?
## Translates OpenAI call to Anthropic `/v1/messages` format
import json
import os
import traceback
import uuid
from typing import Literal, Optional
import dotenv
import httpx
from pydantic import BaseModel
import litellm
from litellm import ChatCompletionRequest, verbose_logger
from litellm.integrations.custom_logger import CustomLogger
from litellm.types.llms.anthropic import AnthropicMessagesRequest, AnthropicResponse
class AnthropicAdapter(CustomLogger):
def __init__(self) -> None:
super().__init__()
def translate_completion_input_params(
self, kwargs
) -> Optional[ChatCompletionRequest]:
"""
- translate params, where needed
- pass rest, as is
"""
request_body = AnthropicMessagesRequest(**kwargs) # type: ignore
translated_body = litellm.AnthropicConfig().translate_anthropic_to_openai(
anthropic_message_request=request_body
)
return translated_body
def translate_completion_output_params(
self, response: litellm.ModelResponse
) -> Optional[AnthropicResponse]:
return litellm.AnthropicConfig().translate_openai_response_to_anthropic(
response=response
)
def translate_completion_output_params_streaming(self) -> Optional[BaseModel]:
return super().translate_completion_output_params_streaming()
anthropic_adapter = AnthropicAdapter()
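
For orientation, a minimal usage sketch of the new adapter (model name, max_tokens, and mock_response are illustrative; the registration shape matches the AdapterItem type added later in this PR, and the flow mirrors the unit tests below):

# Hypothetical usage sketch: register the adapter, then call it via adapter_completion().
import litellm
from litellm import adapter_completion
from litellm.adapters.anthropic_adapter import anthropic_adapter

# litellm.adapters holds AdapterItem dicts: {"id": <str>, "adapter": <CustomLogger>}
litellm.adapters = [{"id": "anthropic", "adapter": anthropic_adapter}]

# The kwargs use the Anthropic /v1/messages shape; the adapter translates them to
# litellm.completion() kwargs and translates the ModelResponse back on the way out.
response = adapter_completion(
    adapter_id="anthropic",
    model="gpt-3.5-turbo",  # any litellm-supported model
    messages=[{"role": "user", "content": "Hey, how's it going?"}],
    max_tokens=256,
    mock_response="This is a fake call",  # avoids a real provider call
)
print(response)  # AnthropicResponse pydantic object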


@@ -5,9 +5,12 @@ import traceback
from typing import Literal, Optional, Union
import dotenv
from pydantic import BaseModel
from litellm.caching import DualCache
from litellm.proxy._types import UserAPIKeyAuth
from litellm.types.llms.openai import ChatCompletionRequest
from litellm.types.utils import ModelResponse
class CustomLogger: # https://docs.litellm.ai/docs/observability/custom_callback#callback-class
@@ -55,6 +58,30 @@ class CustomLogger: # https://docs.litellm.ai/docs/observability/custom_callbac
def pre_call_check(self, deployment: dict) -> Optional[dict]:
pass
#### ADAPTERS #### Allow calling 100+ LLMs in custom format - https://github.com/BerriAI/litellm/pulls
def translate_completion_input_params(
self, kwargs
) -> Optional[ChatCompletionRequest]:
"""
Translates the input params, from the provider's native format to the litellm.completion() format.
"""
pass
def translate_completion_output_params(
self, response: ModelResponse
) -> Optional[BaseModel]:
"""
Translates the output params, from the OpenAI format to the custom format.
"""
pass
def translate_completion_output_params_streaming(self) -> Optional[BaseModel]:
"""
Translates the streaming chunk, from the OpenAI format to the custom format.
"""
pass
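
To illustrate the contract of these new hooks, a hypothetical pass-through adapter (the class name is made up and is not part of this PR):

# Hypothetical pass-through adapter showing the hook contract; not part of this PR.
from typing import Optional

from pydantic import BaseModel

from litellm.integrations.custom_logger import CustomLogger
from litellm.types.llms.openai import ChatCompletionRequest
from litellm.types.utils import ModelResponse


class PassthroughAdapter(CustomLogger):
    def translate_completion_input_params(
        self, kwargs
    ) -> Optional[ChatCompletionRequest]:
        # Incoming kwargs are already in the litellm.completion() format.
        return ChatCompletionRequest(**kwargs)  # type: ignore

    def translate_completion_output_params(
        self, response: ModelResponse
    ) -> Optional[BaseModel]:
        # ModelResponse is already a pydantic model, so return it unchanged.
        return response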
#### CALL HOOKS - proxy only ####
"""
Control / modify incoming / outgoing data before calling the model


@@ -20,19 +20,43 @@ from litellm.llms.custom_httpx.http_handler import (
_get_httpx_client,
)
from litellm.types.llms.anthropic import (
AnthopicMessagesAssistantMessageParam,
AnthropicFinishReason,
AnthropicMessagesRequest,
AnthropicMessagesTool,
AnthropicMessagesToolChoice,
AnthropicMessagesUserMessageParam,
AnthropicResponse,
AnthropicResponseContentBlockText,
AnthropicResponseContentBlockToolUse,
AnthropicResponseUsageBlock,
ContentBlockDelta,
ContentBlockStart,
MessageBlockDelta,
MessageStartBlock,
)
from litellm.types.llms.openai import (
AllMessageValues,
ChatCompletionAssistantMessage,
ChatCompletionAssistantToolCall,
ChatCompletionImageObject,
ChatCompletionImageUrlObject,
ChatCompletionRequest,
ChatCompletionResponseMessage,
ChatCompletionSystemMessage,
ChatCompletionTextObject,
ChatCompletionToolCallChunk,
ChatCompletionToolCallFunctionChunk,
ChatCompletionToolChoiceFunctionParam,
ChatCompletionToolChoiceObjectParam,
ChatCompletionToolChoiceValues,
ChatCompletionToolMessage,
ChatCompletionToolParam,
ChatCompletionToolParamFunctionChunk,
ChatCompletionUsageBlock,
ChatCompletionUserMessage,
)
from litellm.types.utils import Choices, GenericStreamingChunk
from litellm.utils import CustomStreamWrapper, ModelResponse, Usage
from .base import BaseLLM
@@ -168,6 +192,287 @@ class AnthropicConfig:
optional_params["top_p"] = value
return optional_params
### FOR [BETA] `/v1/messages` endpoint support
def translatable_anthropic_params(self) -> List:
"""
Which anthropic params, we need to translate to the openai format.
"""
return ["messages", "metadata", "system", "tool_choice", "tools"]
def translate_anthropic_messages_to_openai(
self,
messages: List[
Union[
AnthropicMessagesUserMessageParam,
AnthopicMessagesAssistantMessageParam,
]
],
) -> List:
new_messages: List[AllMessageValues] = []
for m in messages:
user_message: Optional[ChatCompletionUserMessage] = None
tool_message_list: List[ChatCompletionToolMessage] = []
## USER MESSAGE ##
if m["role"] == "user":
## translate user message
if isinstance(m["content"], str):
user_message = ChatCompletionUserMessage(
role="user", content=m["content"]
)
elif isinstance(m["content"], list):
new_user_content_list: List[
Union[ChatCompletionTextObject, ChatCompletionImageObject]
] = []
for content in m["content"]:
if content["type"] == "text":
text_obj = ChatCompletionTextObject(
type="text", text=content["text"]
)
new_user_content_list.append(text_obj)
elif content["type"] == "image":
image_url = ChatCompletionImageUrlObject(
url=f"data:{content['type']};base64,{content['source']}"
)
image_obj = ChatCompletionImageObject(
type="image_url", image_url=image_url
)
new_user_content_list.append(image_obj)
elif content["type"] == "tool_result":
if "content" not in content:
tool_result = ChatCompletionToolMessage(
role="tool",
tool_call_id=content["tool_use_id"],
content="",
)
tool_message_list.append(tool_result)
elif isinstance(content["content"], str):
tool_result = ChatCompletionToolMessage(
role="tool",
tool_call_id=content["tool_use_id"],
content=content["content"],
)
tool_message_list.append(tool_result)
elif isinstance(content["content"], list):
for c in content["content"]:
if c["type"] == "text":
tool_result = ChatCompletionToolMessage(
role="tool",
tool_call_id=content["tool_use_id"],
content=c["text"],
)
tool_message_list.append(tool_result)
elif c["type"] == "image":
image_str = (
f"data:{c['type']};base64,{c['source']}"
)
tool_result = ChatCompletionToolMessage(
role="tool",
tool_call_id=content["tool_use_id"],
content=image_str,
)
tool_message_list.append(tool_result)
if user_message is not None:
new_messages.append(user_message)
if len(tool_message_list) > 0:
new_messages.extend(tool_message_list)
## ASSISTANT MESSAGE ##
assistant_message_str: Optional[str] = None
tool_calls: List[ChatCompletionAssistantToolCall] = []
if m["role"] == "assistant":
if isinstance(m["content"], str):
assistant_message_str = m["content"]
elif isinstance(m["content"], list):
for content in m["content"]:
if content["type"] == "text":
if assistant_message_str is None:
assistant_message_str = content["text"]
else:
assistant_message_str += content["text"]
elif content["type"] == "tool_use":
function_chunk = ChatCompletionToolCallFunctionChunk(
name=content["name"],
arguments=json.dumps(content["input"]),
)
tool_calls.append(
ChatCompletionAssistantToolCall(
id=content["id"],
type="function",
function=function_chunk,
)
)
if assistant_message_str is not None or len(tool_calls) > 0:
assistant_message = ChatCompletionAssistantMessage(
role="assistant",
content=assistant_message_str,
)
if len(tool_calls) > 0:
assistant_message["tool_calls"] = tool_calls
new_messages.append(assistant_message)
return new_messages
def translate_anthropic_tool_choice_to_openai(
self, tool_choice: AnthropicMessagesToolChoice
) -> ChatCompletionToolChoiceValues:
if tool_choice["type"] == "any":
return "required"
elif tool_choice["type"] == "auto":
return "auto"
elif tool_choice["type"] == "tool":
tc_function_param = ChatCompletionToolChoiceFunctionParam(
name=tool_choice.get("name", "")
)
return ChatCompletionToolChoiceObjectParam(
type="function", function=tc_function_param
)
else:
raise ValueError(
"Incompatible tool choice param submitted - {}".format(tool_choice)
)
def translate_anthropic_tools_to_openai(
self, tools: List[AnthropicMessagesTool]
) -> List[ChatCompletionToolParam]:
new_tools: List[ChatCompletionToolParam] = []
for tool in tools:
function_chunk = ChatCompletionToolParamFunctionChunk(
name=tool["name"],
parameters=tool["input_schema"],
)
if "description" in tool:
function_chunk["description"] = tool["description"]
new_tools.append(
ChatCompletionToolParam(type="function", function=function_chunk)
)
return new_tools
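
A quick sketch of what this tool translation is expected to produce, assuming a hypothetical get_weather tool:

# Hypothetical illustration of translate_anthropic_tools_to_openai().
import litellm

anthropic_tools = [
    {
        "name": "get_weather",
        "description": "Get the current weather for a city",
        "input_schema": {"type": "object", "properties": {"city": {"type": "string"}}},
    }
]
openai_tools = litellm.AnthropicConfig().translate_anthropic_tools_to_openai(
    tools=anthropic_tools  # type: ignore
)
# Expected shape:
# [{"type": "function",
#   "function": {"name": "get_weather", "parameters": {...},
#                "description": "Get the current weather for a city"}}]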
def translate_anthropic_to_openai(
self, anthropic_message_request: AnthropicMessagesRequest
) -> ChatCompletionRequest:
"""
This is used by the beta Anthropic Adapter, for translating anthropic `/v1/messages` requests to the openai format.
"""
new_messages: List[AllMessageValues] = []
## CONVERT ANTHROPIC MESSAGES TO OPENAI
new_messages = self.translate_anthropic_messages_to_openai(
messages=anthropic_message_request["messages"]
)
## ADD SYSTEM MESSAGE TO MESSAGES
if "system" in anthropic_message_request:
new_messages.insert(
0,
ChatCompletionSystemMessage(
role="system", content=anthropic_message_request["system"]
),
)
new_kwargs: ChatCompletionRequest = {
"model": anthropic_message_request["model"],
"messages": new_messages,
}
## CONVERT METADATA (user_id)
if "metadata" in anthropic_message_request:
if "user_id" in anthropic_message_request["metadata"]:
new_kwargs["user"] = anthropic_message_request["metadata"]["user_id"]
## CONVERT TOOL CHOICE
if "tool_choice" in anthropic_message_request:
new_kwargs["tool_choice"] = self.translate_anthropic_tool_choice_to_openai(
tool_choice=anthropic_message_request["tool_choice"]
)
## CONVERT TOOLS
if "tools" in anthropic_message_request:
new_kwargs["tools"] = self.translate_anthropic_tools_to_openai(
tools=anthropic_message_request["tools"]
)
translatable_params = self.translatable_anthropic_params()
for k, v in anthropic_message_request.items():
if k not in translatable_params: # pass remaining params as is
new_kwargs[k] = v # type: ignore
return new_kwargs
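
For orientation, a sketch of translate_anthropic_to_openai() on a minimal request (model name and values are illustrative):

# Hypothetical illustration of translate_anthropic_to_openai() on a minimal request.
import litellm

anthropic_request = {
    "model": "claude-3-5-sonnet-20240620",
    "max_tokens": 256,
    "system": "You are a helpful assistant.",
    "metadata": {"user_id": "user-123"},
    "messages": [{"role": "user", "content": "Hey, how's it going?"}],
}
openai_request = litellm.AnthropicConfig().translate_anthropic_to_openai(
    anthropic_message_request=anthropic_request  # type: ignore
)
# Expected shape: the system prompt becomes the first message, metadata.user_id maps
# to `user`, and non-translatable params such as max_tokens pass through unchanged:
# {"model": "claude-3-5-sonnet-20240620",
#  "messages": [{"role": "system", "content": "You are a helpful assistant."},
#               {"role": "user", "content": "Hey, how's it going?"}],
#  "user": "user-123", "max_tokens": 256}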
def _translate_openai_content_to_anthropic(
self, choices: List[Choices]
) -> List[
Union[AnthropicResponseContentBlockText, AnthropicResponseContentBlockToolUse]
]:
new_content: List[
Union[
AnthropicResponseContentBlockText, AnthropicResponseContentBlockToolUse
]
] = []
for choice in choices:
if (
choice.message.tool_calls is not None
and len(choice.message.tool_calls) > 0
):
for tool_call in choice.message.tool_calls:
new_content.append(
AnthropicResponseContentBlockToolUse(
type="tool_use",
id=tool_call.id,
name=tool_call.function.name or "",
input=json.loads(tool_call.function.arguments),
)
)
elif choice.message.content is not None:
new_content.append(
AnthropicResponseContentBlockText(
type="text", text=choice.message.content
)
)
return new_content
def _translate_openai_finish_reason_to_anthropic(
self, openai_finish_reason: str
) -> AnthropicFinishReason:
if openai_finish_reason == "stop":
return "end_turn"
elif openai_finish_reason == "length":
return "max_tokens"
elif openai_finish_reason == "tool_calls":
return "tool_use"
return "end_turn"
def translate_openai_response_to_anthropic(
self, response: litellm.ModelResponse
) -> AnthropicResponse:
## translate content block
anthropic_content = self._translate_openai_content_to_anthropic(choices=response.choices) # type: ignore
## extract finish reason
anthropic_finish_reason = self._translate_openai_finish_reason_to_anthropic(
openai_finish_reason=response.choices[0].finish_reason # type: ignore
)
# extract usage
usage: litellm.Usage = getattr(response, "usage")
anthropic_usage = AnthropicResponseUsageBlock(
input_tokens=usage.prompt_tokens, output_tokens=usage.completion_tokens
)
translated_obj = AnthropicResponse(
id=response.id,
type="message",
role="assistant",
model=response.model or "unknown-model",
stop_sequence=None,
usage=anthropic_usage,
content=anthropic_content,
stop_reason=anthropic_finish_reason,
)
return translated_obj
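
And the reverse direction, sketched with a mocked completion so no provider is called (values are illustrative):

# Hypothetical illustration of the reverse direction: ModelResponse -> AnthropicResponse.
import litellm

openai_style_response = litellm.completion(
    model="gpt-3.5-turbo",
    messages=[{"role": "user", "content": "Hey, how's it going?"}],
    mock_response="Doing well, thanks!",  # mock avoids a real provider call
)
anthropic_style = litellm.AnthropicConfig().translate_openai_response_to_anthropic(
    response=openai_style_response
)
# anthropic_style.content     -> [AnthropicResponseContentBlockText(type="text", text="Doing well, thanks!")]
# anthropic_style.stop_reason -> "end_turn" (mapped from the OpenAI "stop" finish reason)
# anthropic_style.usage       -> AnthropicResponseUsageBlock(input_tokens=..., output_tokens=...)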
# makes headers for API call
def validate_environment(api_key, user_headers):
@@ -231,121 +536,6 @@ class AnthropicChatCompletion(BaseLLM):
def __init__(self) -> None:
super().__init__()
# def process_streaming_response(
# self,
# model: str,
# response: Union[requests.Response, httpx.Response],
# model_response: ModelResponse,
# stream: bool,
# logging_obj: litellm.litellm_core_utils.litellm_logging.Logging,
# optional_params: dict,
# api_key: str,
# data: Union[dict, str],
# messages: List,
# print_verbose,
# encoding,
# ) -> CustomStreamWrapper:
# """
# Return stream object for tool-calling + streaming
# """
# ## LOGGING
# logging_obj.post_call(
# input=messages,
# api_key=api_key,
# original_response=response.text,
# additional_args={"complete_input_dict": data},
# )
# print_verbose(f"raw model_response: {response.text}")
# ## RESPONSE OBJECT
# try:
# completion_response = response.json()
# except:
# raise AnthropicError(
# message=response.text, status_code=response.status_code
# )
# text_content = ""
# tool_calls = []
# for content in completion_response["content"]:
# if content["type"] == "text":
# text_content += content["text"]
# ## TOOL CALLING
# elif content["type"] == "tool_use":
# tool_calls.append(
# {
# "id": content["id"],
# "type": "function",
# "function": {
# "name": content["name"],
# "arguments": json.dumps(content["input"]),
# },
# }
# )
# if "error" in completion_response:
# raise AnthropicError(
# message=str(completion_response["error"]),
# status_code=response.status_code,
# )
# _message = litellm.Message(
# tool_calls=tool_calls,
# content=text_content or None,
# )
# model_response.choices[0].message = _message # type: ignore
# model_response._hidden_params["original_response"] = completion_response[
# "content"
# ] # allow user to access raw anthropic tool calling response
# model_response.choices[0].finish_reason = map_finish_reason(
# completion_response["stop_reason"]
# )
# print_verbose("INSIDE ANTHROPIC STREAMING TOOL CALLING CONDITION BLOCK")
# # return an iterator
# streaming_model_response = ModelResponse(stream=True)
# streaming_model_response.choices[0].finish_reason = model_response.choices[ # type: ignore
# 0
# ].finish_reason
# # streaming_model_response.choices = [litellm.utils.StreamingChoices()]
# streaming_choice = litellm.utils.StreamingChoices()
# streaming_choice.index = model_response.choices[0].index
# _tool_calls = []
# print_verbose(
# f"type of model_response.choices[0]: {type(model_response.choices[0])}"
# )
# print_verbose(f"type of streaming_choice: {type(streaming_choice)}")
# if isinstance(model_response.choices[0], litellm.Choices):
# if getattr(
# model_response.choices[0].message, "tool_calls", None
# ) is not None and isinstance(
# model_response.choices[0].message.tool_calls, list
# ):
# for tool_call in model_response.choices[0].message.tool_calls:
# _tool_call = {**tool_call.dict(), "index": 0}
# _tool_calls.append(_tool_call)
# delta_obj = litellm.utils.Delta(
# content=getattr(model_response.choices[0].message, "content", None),
# role=model_response.choices[0].message.role,
# tool_calls=_tool_calls,
# )
# streaming_choice.delta = delta_obj
# streaming_model_response.choices = [streaming_choice]
# completion_stream = ModelResponseIterator(
# model_response=streaming_model_response
# )
# print_verbose(
# "Returns anthropic CustomStreamWrapper with 'cached_response' streaming object"
# )
# return CustomStreamWrapper(
# completion_stream=completion_stream,
# model=model,
# custom_llm_provider="cached_response",
# logging_obj=logging_obj,
# )
# else:
# raise AnthropicError(
# status_code=422,
# message="Unprocessable response object - {}".format(response.text),
# )
def process_response(
self,
model: str,


@@ -38,6 +38,7 @@ import dotenv
import httpx
import openai
import tiktoken
from pydantic import BaseModel
from typing_extensions import overload
import litellm
@@ -48,6 +49,7 @@ from litellm import ( # type: ignore
get_litellm_params,
get_optional_params,
)
from litellm.integrations.custom_logger import CustomLogger
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
from litellm.utils import (
CustomStreamWrapper,
@@ -3943,6 +3945,63 @@ def text_completion(
return text_completion_response
###### Adapter Completion ################
async def aadapter_completion(*, adapter_id: str, **kwargs) -> Optional[BaseModel]:
"""
Implemented to handle async calls for adapter_completion()
"""
try:
translation_obj: Optional[CustomLogger] = None
for item in litellm.adapters:
if item["id"] == adapter_id:
translation_obj = item["adapter"]
if translation_obj is None:
raise ValueError(
"No matching adapter given. Received 'adapter_id'={}, litellm.adapters={}".format(
adapter_id, litellm.adapters
)
)
new_kwargs = translation_obj.translate_completion_input_params(kwargs=kwargs)
response: ModelResponse = await acompletion(**new_kwargs) # type: ignore
translated_response = translation_obj.translate_completion_output_params(
response=response
)
return translated_response
except Exception as e:
raise e
def adapter_completion(*, adapter_id: str, **kwargs) -> Optional[BaseModel]:
translation_obj: Optional[CustomLogger] = None
for item in litellm.adapters:
if item["id"] == adapter_id:
translation_obj = item["adapter"]
if translation_obj is None:
raise ValueError(
"No matching adapter given. Received 'adapter_id'={}, litellm.adapters={}".format(
adapter_id, litellm.adapters
)
)
new_kwargs = translation_obj.translate_completion_input_params(kwargs=kwargs)
response: ModelResponse = completion(**new_kwargs) # type: ignore
translated_response = translation_obj.translate_completion_output_params(
response=response
)
return translated_response
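
A brief sketch of the async entrypoint, assuming the same adapter registration shown earlier (model and adapter id are illustrative):

# Hypothetical async usage of aadapter_completion(); mirrors the sync helper above.
import asyncio

import litellm
from litellm.adapters.anthropic_adapter import anthropic_adapter

litellm.adapters = [{"id": "anthropic", "adapter": anthropic_adapter}]


async def main():
    response = await litellm.aadapter_completion(
        adapter_id="anthropic",
        model="gpt-3.5-turbo",
        messages=[{"role": "user", "content": "ping"}],
        max_tokens=32,
        mock_response="pong",  # avoids a real provider call
    )
    print(response)


asyncio.run(main())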
##### Moderation #######################


@@ -2,18 +2,9 @@ model_list:
- model_name: "*"
litellm_params:
model: "openai/*"
- model_name: claude-3-5-sonnet-20240620
litellm_params:
model: gpt-3.5-turbo
- model_name: whisper
litellm_params:
model: azure/azure-whisper
api_version: 2024-02-15-preview
api_base: os.environ/AZURE_EUROPE_API_BASE
api_key: os.environ/AZURE_EUROPE_API_KEY
model_info:
mode: audio_transcription
general_settings:


@@ -71,6 +71,11 @@ azure_api_key_header = APIKeyHeader(
auto_error=False,
description="Some older versions of the openai Python package will send an API-Key header with just the API key ",
)
anthropic_api_key_header = APIKeyHeader(
name="x-api-key",
auto_error=False,
description="If anthropic client used.",
)
def _get_bearer_token(
@@ -87,6 +92,9 @@ async def user_api_key_auth(
request: Request,
api_key: str = fastapi.Security(api_key_header),
azure_api_key_header: str = fastapi.Security(azure_api_key_header),
anthropic_api_key_header: Optional[str] = fastapi.Security(
anthropic_api_key_header
),
) -> UserAPIKeyAuth:
from litellm.proxy.proxy_server import (
@@ -114,6 +122,9 @@ async def user_api_key_auth(
elif isinstance(azure_api_key_header, str):
api_key = azure_api_key_header
elif isinstance(anthropic_api_key_header, str):
api_key = anthropic_api_key_header
parent_otel_span: Optional[Span] = None
if open_telemetry_logger is not None:
parent_otel_span = open_telemetry_logger.tracer.start_span(


@@ -215,6 +215,12 @@ from litellm.router import (
from litellm.router import ModelInfo as RouterModelInfo
from litellm.router import updateDeployment
from litellm.scheduler import DefaultPriorities, FlowItem, Scheduler
from litellm.types.llms.anthropic import (
AnthropicMessagesRequest,
AnthropicResponse,
AnthropicResponseContentBlockText,
AnthropicResponseUsageBlock,
)
from litellm.types.llms.openai import HttpxBinaryResponseContent
from litellm.types.router import RouterGeneralSettings
@@ -5041,6 +5047,198 @@ async def moderations(
)
#### ANTHROPIC ENDPOINTS ####
@router.post(
"/v1/messages",
tags=["[beta] Anthropic `/v1/messages`"],
dependencies=[Depends(user_api_key_auth)],
response_model=AnthropicResponse,
)
async def anthropic_response(
anthropic_data: AnthropicMessagesRequest,
fastapi_response: Response,
request: Request,
user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
from litellm import adapter_completion
from litellm.adapters.anthropic_adapter import anthropic_adapter
litellm.adapters = [{"id": "anthropic", "adapter": anthropic_adapter}]
global user_temperature, user_request_timeout, user_max_tokens, user_api_base
data: dict = {**anthropic_data, "adapter_id": "anthropic"}
try:
data["model"] = (
general_settings.get("completion_model", None) # server default
or user_model # model name passed via cli args
or data["model"] # default passed in http request
)
if user_model:
data["model"] = user_model
data = await add_litellm_data_to_request(
data=data, # type: ignore
request=request,
general_settings=general_settings,
user_api_key_dict=user_api_key_dict,
version=version,
proxy_config=proxy_config,
)
# override with user settings, these are params passed via cli
if user_temperature:
data["temperature"] = user_temperature
if user_request_timeout:
data["request_timeout"] = user_request_timeout
if user_max_tokens:
data["max_tokens"] = user_max_tokens
if user_api_base:
data["api_base"] = user_api_base
### MODEL ALIAS MAPPING ###
# check if model name in model alias map
# get the actual model name
if data["model"] in litellm.model_alias_map:
data["model"] = litellm.model_alias_map[data["model"]]
### CALL HOOKS ### - modify incoming data before calling the model
data = await proxy_logging_obj.pre_call_hook( # type: ignore
user_api_key_dict=user_api_key_dict, data=data, call_type="text_completion"
)
### ROUTE THE REQUESTs ###
router_model_names = llm_router.model_names if llm_router is not None else []
# skip router if user passed their key
if "api_key" in data:
llm_response = asyncio.create_task(litellm.aadapter_completion(**data))
elif (
llm_router is not None and data["model"] in router_model_names
): # model in router model list
llm_response = asyncio.create_task(llm_router.aadapter_completion(**data))
elif (
llm_router is not None
and llm_router.model_group_alias is not None
and data["model"] in llm_router.model_group_alias
): # model set in model_group_alias
llm_response = asyncio.create_task(llm_router.aadapter_completion(**data))
elif (
llm_router is not None and data["model"] in llm_router.deployment_names
): # model in router deployments, calling a specific deployment on the router
llm_response = asyncio.create_task(
llm_router.aadapter_completion(**data, specific_deployment=True)
)
elif (
llm_router is not None and data["model"] in llm_router.get_model_ids()
): # model in router model list
llm_response = asyncio.create_task(llm_router.aadapter_completion(**data))
elif (
llm_router is not None
and data["model"] not in router_model_names
and llm_router.default_deployment is not None
): # model in router deployments, calling a specific deployment on the router
llm_response = asyncio.create_task(llm_router.aadapter_completion(**data))
elif user_model is not None: # `litellm --model <your-model-name>`
llm_response = asyncio.create_task(litellm.aadapter_completion(**data))
else:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail={
"error": "completion: Invalid model name passed in model="
+ data.get("model", "")
},
)
# Await the llm_response task
response = await llm_response
hidden_params = getattr(response, "_hidden_params", {}) or {}
model_id = hidden_params.get("model_id", None) or ""
cache_key = hidden_params.get("cache_key", None) or ""
api_base = hidden_params.get("api_base", None) or ""
response_cost = hidden_params.get("response_cost", None) or ""
### ALERTING ###
asyncio.create_task(
proxy_logging_obj.update_request_status(
litellm_call_id=data.get("litellm_call_id", ""), status="success"
)
)
verbose_proxy_logger.debug("final response: %s", response)
fastapi_response.headers.update(
get_custom_headers(
user_api_key_dict=user_api_key_dict,
model_id=model_id,
cache_key=cache_key,
api_base=api_base,
version=version,
response_cost=response_cost,
)
)
verbose_proxy_logger.info("\nResponse from Litellm:\n{}".format(response))
return response
except RejectedRequestError as e:
_data = e.request_data
await proxy_logging_obj.post_call_failure_hook(
user_api_key_dict=user_api_key_dict,
original_exception=e,
request_data=_data,
)
if _data.get("stream", None) is not None and _data["stream"] == True:
_chat_response = litellm.ModelResponse()
_usage = litellm.Usage(
prompt_tokens=0,
completion_tokens=0,
total_tokens=0,
)
_chat_response.usage = _usage # type: ignore
_chat_response.choices[0].message.content = e.message # type: ignore
_iterator = litellm.utils.ModelResponseIterator(
model_response=_chat_response, convert_to_delta=True
)
_streaming_response = litellm.TextCompletionStreamWrapper(
completion_stream=_iterator,
model=_data.get("model", ""),
)
selected_data_generator = select_data_generator(
response=_streaming_response,
user_api_key_dict=user_api_key_dict,
request_data=data,
)
return StreamingResponse(
selected_data_generator,
media_type="text/event-stream",
headers={},
)
else:
_response = litellm.TextCompletionResponse()
_response.choices[0].text = e.message
return _response
except Exception as e:
await proxy_logging_obj.post_call_failure_hook(
user_api_key_dict=user_api_key_dict, original_exception=e, request_data=data
)
verbose_proxy_logger.error(
"litellm.proxy.proxy_server.completion(): Exception occured - {}".format(
str(e)
)
)
verbose_proxy_logger.debug(traceback.format_exc())
error_msg = f"{str(e)}"
raise ProxyException(
message=getattr(e, "message", error_msg),
type=getattr(e, "type", "None"),
param=getattr(e, "param", "None"),
code=getattr(e, "status_code", 500),
)
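
For reference, a hypothetical client call against this new route, using the Anthropic /v1/messages wire format (base URL, key, and model are placeholders):

# Hypothetical client call against the new proxy route, in the Anthropic /v1/messages shape.
import httpx

resp = httpx.post(
    "http://localhost:4000/v1/messages",  # proxy base URL is illustrative
    headers={"x-api-key": "sk-1234"},  # accepted via anthropic_api_key_header above
    json={
        "model": "claude-3-5-sonnet-20240620",
        "max_tokens": 256,
        "messages": [{"role": "user", "content": "Hey, how's it going?"}],
    },
)
print(resp.json())  # AnthropicResponse-shaped JSON: id, type="message", content, usage, ...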
#### DEV UTILS ####
# @router.get(


@@ -1765,6 +1765,125 @@ class Router:
self.fail_calls[model] += 1
raise e
async def aadapter_completion(
self,
adapter_id: str,
model: str,
is_retry: Optional[bool] = False,
is_fallback: Optional[bool] = False,
is_async: Optional[bool] = False,
**kwargs,
):
try:
kwargs["model"] = model
kwargs["adapter_id"] = adapter_id
kwargs["original_function"] = self._aadapter_completion
kwargs["num_retries"] = kwargs.get("num_retries", self.num_retries)
timeout = kwargs.get("request_timeout", self.timeout)
kwargs.setdefault("metadata", {}).update({"model_group": model})
response = await self.async_function_with_fallbacks(**kwargs)
return response
except Exception as e:
asyncio.create_task(
send_llm_exception_alert(
litellm_router_instance=self,
request_kwargs=kwargs,
error_traceback_str=traceback.format_exc(),
original_exception=e,
)
)
raise e
async def _aadapter_completion(self, adapter_id: str, model: str, **kwargs):
try:
verbose_router_logger.debug(
f"Inside _aadapter_completion()- model: {model}; kwargs: {kwargs}"
)
deployment = await self.async_get_available_deployment(
model=model,
messages=[{"role": "user", "content": "default text"}],
specific_deployment=kwargs.pop("specific_deployment", None),
)
kwargs.setdefault("metadata", {}).update(
{
"deployment": deployment["litellm_params"]["model"],
"model_info": deployment.get("model_info", {}),
"api_base": deployment.get("litellm_params", {}).get("api_base"),
}
)
kwargs["model_info"] = deployment.get("model_info", {})
data = deployment["litellm_params"].copy()
model_name = data["model"]
for k, v in self.default_litellm_params.items():
if (
k not in kwargs
): # prioritize model-specific params > default router params
kwargs[k] = v
elif k == "metadata":
kwargs[k].update(v)
potential_model_client = self._get_client(
deployment=deployment, kwargs=kwargs, client_type="async"
)
# check if provided keys == client keys #
dynamic_api_key = kwargs.get("api_key", None)
if (
dynamic_api_key is not None
and potential_model_client is not None
and dynamic_api_key != potential_model_client.api_key
):
model_client = None
else:
model_client = potential_model_client
self.total_calls[model_name] += 1
response = litellm.aadapter_completion(
**{
**data,
"adapter_id": adapter_id,
"caching": self.cache_responses,
"client": model_client,
"timeout": self.timeout,
**kwargs,
}
)
rpm_semaphore = self._get_client(
deployment=deployment,
kwargs=kwargs,
client_type="max_parallel_requests",
)
if rpm_semaphore is not None and isinstance(
rpm_semaphore, asyncio.Semaphore
):
async with rpm_semaphore:
"""
- Check rpm limits before making the call
- If allowed, increment the rpm limit (allows global value to be updated, concurrency-safe)
"""
await self.async_routing_strategy_pre_call_checks(
deployment=deployment
)
response = await response # type: ignore
else:
await self.async_routing_strategy_pre_call_checks(deployment=deployment)
response = await response # type: ignore
self.success_calls[model_name] += 1
verbose_router_logger.info(
f"litellm.aadapter_completion(model={model_name})\033[32m 200 OK\033[0m"
)
return response
except Exception as e:
verbose_router_logger.info(
f"litellm.aadapter_completion(model={model})\033[31m Exception {str(e)}\033[0m"
)
if model is not None:
self.fail_calls[model] += 1
raise e
def embedding(
self,
model: str,


@@ -0,0 +1,103 @@
# What is this?
## Unit tests for Anthropic Adapter
import asyncio
import os
import sys
import traceback
from dotenv import load_dotenv
load_dotenv()
import io
import os
sys.path.insert(
0, os.path.abspath("../..")
) # Adds the parent directory to the system path
from unittest.mock import MagicMock, patch
import pytest
import litellm
from litellm import AnthropicConfig, Router, adapter_completion
from litellm.adapters.anthropic_adapter import anthropic_adapter
from litellm.types.llms.anthropic import AnthropicResponse
def test_anthropic_completion_messages_translation():
messages = [{"role": "user", "content": "Hey, how's it going?"}]
translated_messages = AnthropicConfig().translate_anthropic_messages_to_openai(messages=messages) # type: ignore
assert translated_messages == [{"role": "user", "content": "Hey, how's it going?"}]
def test_anthropic_completion_input_translation():
data = {
"model": "gpt-3.5-turbo",
"messages": [{"role": "user", "content": "Hey, how's it going?"}],
}
translated_input = anthropic_adapter.translate_completion_input_params(kwargs=data)
assert translated_input is not None
assert translated_input["model"] == "gpt-3.5-turbo"
assert translated_input["messages"] == [
{"role": "user", "content": "Hey, how's it going?"}
]
def test_anthropic_completion_e2e():
litellm.set_verbose = True
litellm.adapters = [{"id": "anthropic", "adapter": anthropic_adapter}]
messages = [{"role": "user", "content": "Hey, how's it going?"}]
response = adapter_completion(
model="gpt-3.5-turbo",
messages=messages,
adapter_id="anthropic",
mock_response="This is a fake call",
)
print("Response: {}".format(response))
assert response is not None
assert isinstance(response, AnthropicResponse)
@pytest.mark.asyncio
async def test_anthropic_router_completion_e2e():
litellm.set_verbose = True
litellm.adapters = [{"id": "anthropic", "adapter": anthropic_adapter}]
router = Router(
model_list=[
{
"model_name": "claude-3-5-sonnet-20240620",
"litellm_params": {
"model": "gpt-3.5-turbo",
"mock_response": "hi this is macintosh.",
},
}
]
)
messages = [{"role": "user", "content": "Hey, how's it going?"}]
response = await router.aadapter_completion(
model="claude-3-5-sonnet-20240620",
messages=messages,
adapter_id="anthropic",
mock_response="This is a fake call",
)
print("Response: {}".format(response))
assert response is not None
assert isinstance(response, AnthropicResponse)
assert response.model == "gpt-3.5-turbo"

litellm/types/adapter.py (new file, +10 lines)

@@ -0,0 +1,10 @@
from typing import List
from typing_extensions import Dict, Required, TypedDict, override
from litellm.integrations.custom_logger import CustomLogger
class AdapterItem(TypedDict):
id: str
adapter: CustomLogger


@@ -9,25 +9,27 @@ class AnthropicMessagesToolChoice(TypedDict, total=False):
name: str
class AnthropicMessagesTool(TypedDict, total=False):
name: Required[str]
description: str
input_schema: Required[dict]
class AnthropicMessagesTextParam(TypedDict):
type: Literal["text"]
text: str
class AnthropicMessagesToolUseParam(TypedDict):
type: Literal["tool_use"]
id: str
name: str
input: dict
AnthropicMessagesAssistantMessageValues = Union[
AnthropicMessagesTextParam,
AnthropicMessagesToolUseParam,
]
@@ -46,6 +48,72 @@ class AnthopicMessagesAssistantMessageParam(TypedDict, total=False):
"""
class AnthropicImageParamSource(TypedDict):
type: Literal["base64"]
media_type: str
data: str
class AnthropicMessagesImageParam(TypedDict):
type: Literal["image"]
source: AnthropicImageParamSource
class AnthropicMessagesToolResultContent(TypedDict):
type: Literal["text"]
text: str
class AnthropicMessagesToolResultParam(TypedDict, total=False):
type: Required[Literal["tool_result"]]
tool_use_id: Required[str]
is_error: bool
content: Union[
str,
Iterable[
Union[AnthropicMessagesToolResultContent, AnthropicMessagesImageParam]
],
]
AnthropicMessagesUserMessageValues = Union[
AnthropicMessagesTextParam,
AnthropicMessagesImageParam,
AnthropicMessagesToolResultParam,
]
class AnthropicMessagesUserMessageParam(TypedDict, total=False):
role: Required[Literal["user"]]
content: Required[Union[str, Iterable[AnthropicMessagesUserMessageValues]]]
class AnthropicMetadata(TypedDict, total=False):
user_id: str
class AnthropicMessagesRequest(TypedDict, total=False):
model: Required[str]
messages: Required[
List[
Union[
AnthropicMessagesUserMessageParam,
AnthopicMessagesAssistantMessageParam,
]
]
]
max_tokens: Required[int]
metadata: AnthropicMetadata
stop_sequences: List[str]
stream: bool
system: str
temperature: float
tool_choice: AnthropicMessagesToolChoice
tools: List[AnthropicMessagesTool]
top_k: int
top_p: float
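
A small example of a request dict that satisfies this TypedDict (the tool definition is illustrative):

# Hypothetical request dict that satisfies AnthropicMessagesRequest, including a tool.
from litellm.types.llms.anthropic import AnthropicMessagesRequest

example_request: AnthropicMessagesRequest = {
    "model": "claude-3-5-sonnet-20240620",
    "max_tokens": 1024,
    "messages": [{"role": "user", "content": "What's the weather in SF?"}],
    "tools": [
        {
            "name": "get_weather",
            "description": "Get the current weather for a city",
            "input_schema": {
                "type": "object",
                "properties": {"city": {"type": "string"}},
                "required": ["city"],
            },
        }
    ],
    "tool_choice": {"type": "auto"},
}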
class ContentTextBlockDelta(TypedDict):
"""
'delta': {'type': 'text_delta', 'text': 'Hello'}
@@ -155,3 +223,51 @@ class MessageStartBlock(TypedDict):
type: Literal["message_start"]
message: MessageChunk
class AnthropicResponseContentBlockText(BaseModel):
type: Literal["text"]
text: str
class AnthropicResponseContentBlockToolUse(BaseModel):
type: Literal["tool_use"]
id: str
name: str
input: dict
class AnthropicResponseUsageBlock(BaseModel):
input_tokens: int
output_tokens: int
AnthropicFinishReason = Literal["end_turn", "max_tokens", "stop_sequence", "tool_use"]
class AnthropicResponse(BaseModel):
id: str
"""Unique object identifier."""
type: Literal["message"]
"""For Messages, this is always "message"."""
role: Literal["assistant"]
"""Conversational role of the generated message. This will always be "assistant"."""
content: List[
Union[AnthropicResponseContentBlockText, AnthropicResponseContentBlockToolUse]
]
"""Content generated by the model."""
model: str
"""The model that handled the request."""
stop_reason: Optional[AnthropicFinishReason]
"""The reason that we stopped."""
stop_sequence: Optional[str]
"""Which custom stop sequence was generated, if any."""
usage: AnthropicResponseUsageBlock
"""Billing and rate-limit usage."""


@@ -305,7 +305,13 @@ class ChatCompletionToolCallFunctionChunk(TypedDict, total=False):
arguments: str
class ChatCompletionAssistantToolCall(TypedDict):
id: Optional[str]
type: Literal["function"]
function: ChatCompletionToolCallFunctionChunk
class ChatCompletionToolCallChunk(TypedDict): # result of /chat/completions call
id: Optional[str]
type: Literal["function"]
function: ChatCompletionToolCallFunctionChunk
@@ -319,6 +325,107 @@ class ChatCompletionDeltaToolCallChunk(TypedDict, total=False):
index: int
class ChatCompletionTextObject(TypedDict):
type: Literal["text"]
text: str
class ChatCompletionImageUrlObject(TypedDict, total=False):
url: Required[str]
detail: str
class ChatCompletionImageObject(TypedDict):
type: Literal["image_url"]
image_url: ChatCompletionImageUrlObject
class ChatCompletionUserMessage(TypedDict):
role: Literal["user"]
content: Union[
str, Iterable[Union[ChatCompletionTextObject, ChatCompletionImageObject]]
]
class ChatCompletionAssistantMessage(TypedDict, total=False):
role: Required[Literal["assistant"]]
content: Optional[str]
name: str
tool_calls: List[ChatCompletionAssistantToolCall]
class ChatCompletionToolMessage(TypedDict):
role: Literal["tool"]
content: str
tool_call_id: str
class ChatCompletionSystemMessage(TypedDict, total=False):
role: Required[Literal["system"]]
content: Required[str]
name: str
AllMessageValues = Union[
ChatCompletionUserMessage,
ChatCompletionAssistantMessage,
ChatCompletionToolMessage,
ChatCompletionSystemMessage,
]
class ChatCompletionToolChoiceFunctionParam(TypedDict):
name: str
class ChatCompletionToolChoiceObjectParam(TypedDict):
type: Literal["function"]
function: ChatCompletionToolChoiceFunctionParam
ChatCompletionToolChoiceStringValues = Literal["none", "auto", "required"]
ChatCompletionToolChoiceValues = Union[
ChatCompletionToolChoiceStringValues, ChatCompletionToolChoiceObjectParam
]
class ChatCompletionToolParamFunctionChunk(TypedDict, total=False):
name: Required[str]
description: str
parameters: dict
class ChatCompletionToolParam(TypedDict):
type: Literal["function"]
function: ChatCompletionToolParamFunctionChunk
class ChatCompletionRequest(TypedDict, total=False):
model: Required[str]
messages: Required[List[AllMessageValues]]
frequency_penalty: float
logit_bias: dict
logprobs: bool
top_logprobs: int
max_tokens: int
n: int
presence_penalty: float
response_format: dict
seed: int
service_tier: str
stop: Union[str, List[str]]
stream_options: dict
temperature: float
top_p: float
tools: List[ChatCompletionToolParam]
tool_choice: ChatCompletionToolChoiceValues
parallel_tool_calls: bool
function_call: Union[str, dict]
functions: List
user: str
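
And a minimal dict satisfying the request type that adapter translations target (values are illustrative):

# Hypothetical minimal ChatCompletionRequest - the target format of the adapter translation.
from litellm.types.llms.openai import ChatCompletionRequest

req: ChatCompletionRequest = {
    "model": "gpt-3.5-turbo",
    "messages": [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "Hey, how's it going?"},
    ],
    "max_tokens": 256,
}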
class ChatCompletionDeltaChunk(TypedDict, total=False):
content: Optional[str]
tool_calls: List[ChatCompletionDeltaToolCallChunk]


@@ -166,7 +166,9 @@ class FunctionCall(OpenAIObject):
class Function(OpenAIObject):
arguments: str
name: Optional[
str
] # can be None - openai e.g.: ChoiceDeltaToolCallFunction(arguments='{"', name=None), type=None)
def __init__(
self,
@@ -280,29 +282,43 @@ class ChatCompletionMessageToolCall(OpenAIObject):
setattr(self, key, value)
"""
Reference:
ChatCompletionMessage(content='This is a test', role='assistant', function_call=None, tool_calls=None))
"""
class Message(OpenAIObject):
content: Optional[str]
role: Literal["assistant"]
tool_calls: Optional[List[ChatCompletionMessageToolCall]]
function_call: Optional[FunctionCall]
def __init__(
self,
content: Optional[str] = None,
role: Literal["assistant"] = "assistant",
function_call=None,
tool_calls=None,
**params,
):
init_values = {
"content": content,
"role": role,
"function_call": (
FunctionCall(**function_call) if function_call is not None else None
),
"tool_calls": (
[ChatCompletionMessageToolCall(**tool_call) for tool_call in tool_calls]
if tool_calls is not None
else None
),
}
super(Message, self).__init__(
**init_values,
**params,
)
def get(self, key, default=None):
# Custom .get() method to access attributes with a default value if the attribute doesn't exist


@@ -9483,8 +9483,8 @@ class CustomStreamWrapper:
model_response.choices[0].delta = Delta(**_json_delta)
except Exception as e:
verbose_logger.error(
"litellm.CustomStreamWrapper.chunk_creator(): Exception occured - {}\n{}".format(
str(e), traceback.format_exc()
)
)
verbose_logger.debug(traceback.format_exc())