Merge pull request #4635 from BerriAI/litellm_anthropic_adapter

Anthropic `/v1/messages` endpoint support
This commit is contained in:
Krish Dholakia 2024-07-10 22:41:53 -07:00 committed by GitHub
commit dacce3d78b
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
15 changed files with 1163 additions and 161 deletions

View file

@ -882,3 +882,8 @@ from .batches.main import *
from .files.main import *
from .scheduler import *
from .cost_calculator import response_cost_calculator, cost_per_token
### ADAPTERS ###
from .types.adapter import AdapterItem
adapters: List[AdapterItem] = []

View file

@ -0,0 +1,50 @@
# What is this?
## Translates OpenAI call to Anthropic `/v1/messages` format
import json
import os
import traceback
import uuid
from typing import Literal, Optional
import dotenv
import httpx
from pydantic import BaseModel
import litellm
from litellm import ChatCompletionRequest, verbose_logger
from litellm.integrations.custom_logger import CustomLogger
from litellm.types.llms.anthropic import AnthropicMessagesRequest, AnthropicResponse
class AnthropicAdapter(CustomLogger):
def __init__(self) -> None:
super().__init__()
def translate_completion_input_params(
self, kwargs
) -> Optional[ChatCompletionRequest]:
"""
- translate params, where needed
- pass rest, as is
"""
request_body = AnthropicMessagesRequest(**kwargs) # type: ignore
translated_body = litellm.AnthropicConfig().translate_anthropic_to_openai(
anthropic_message_request=request_body
)
return translated_body
def translate_completion_output_params(
self, response: litellm.ModelResponse
) -> Optional[AnthropicResponse]:
return litellm.AnthropicConfig().translate_openai_response_to_anthropic(
response=response
)
def translate_completion_output_params_streaming(self) -> Optional[BaseModel]:
return super().translate_completion_output_params_streaming()
anthropic_adapter = AnthropicAdapter()
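For a quick sense of how the adapter is wired up, here is a minimal usage sketch (it mirrors the unit test added later in this commit; the model name and mock_response are placeholders, so no real provider call is made):

import litellm
from litellm import adapter_completion
from litellm.adapters.anthropic_adapter import anthropic_adapter

# register the adapter so adapter_completion() can look it up by id
litellm.adapters = [{"id": "anthropic", "adapter": anthropic_adapter}]

# Anthropic-style kwargs in, AnthropicResponse out; the call itself routes through litellm.completion()
response = adapter_completion(
    model="gpt-3.5-turbo",
    messages=[{"role": "user", "content": "Hey, how's it going?"}],
    adapter_id="anthropic",
    mock_response="This is a fake call",  # mocked response, no API key needed
)
print(response)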

View file

@ -5,9 +5,12 @@ import traceback
from typing import Literal, Optional, Union
import dotenv
from pydantic import BaseModel
from litellm.caching import DualCache
from litellm.proxy._types import UserAPIKeyAuth
from litellm.types.llms.openai import ChatCompletionRequest
from litellm.types.utils import ModelResponse
class CustomLogger: # https://docs.litellm.ai/docs/observability/custom_callback#callback-class
@ -55,6 +58,30 @@ class CustomLogger: # https://docs.litellm.ai/docs/observability/custom_callbac
def pre_call_check(self, deployment: dict) -> Optional[dict]:
pass
#### ADAPTERS #### Allow calling 100+ LLMs in custom format - https://github.com/BerriAI/litellm/pulls
def translate_completion_input_params(
self, kwargs
) -> Optional[ChatCompletionRequest]:
"""
Translates the input params, from the provider's native format to the litellm.completion() format.
"""
pass
def translate_completion_output_params(
self, response: ModelResponse
) -> Optional[BaseModel]:
"""
Translates the output params, from the OpenAI format to the custom format.
"""
pass
def translate_completion_output_params_streaming(self) -> Optional[BaseModel]:
"""
Translates the streaming chunk, from the OpenAI format to the custom format.
"""
pass
#### CALL HOOKS - proxy only ####
"""
Control / modify incoming and outgoing data before calling the model
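As a sketch of how these hooks compose into a new adapter (the class name and the passthrough behavior are hypothetical; only the method signatures come from the hunk above):

from typing import Optional

from pydantic import BaseModel

from litellm.integrations.custom_logger import CustomLogger
from litellm.types.llms.openai import ChatCompletionRequest
from litellm.types.utils import ModelResponse


class MyFormatAdapter(CustomLogger):
    def translate_completion_input_params(self, kwargs) -> Optional[ChatCompletionRequest]:
        # map the provider-native payload onto litellm.completion() kwargs
        return ChatCompletionRequest(model=kwargs["model"], messages=kwargs["messages"])

    def translate_completion_output_params(self, response: ModelResponse) -> Optional[BaseModel]:
        # map the OpenAI-format response back into the custom format;
        # ModelResponse is itself a pydantic model, so returning it unchanged is valid
        return response

    def translate_completion_output_params_streaming(self) -> Optional[BaseModel]:
        # streaming translation is left out of this sketch
        return None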

View file

@ -20,19 +20,43 @@ from litellm.llms.custom_httpx.http_handler import (
_get_httpx_client,
)
from litellm.types.llms.anthropic import (
AnthopicMessagesAssistantMessageParam,
AnthropicFinishReason,
AnthropicMessagesRequest,
AnthropicMessagesTool,
AnthropicMessagesToolChoice,
AnthropicMessagesUserMessageParam,
AnthropicResponse,
AnthropicResponseContentBlockText,
AnthropicResponseContentBlockToolUse,
AnthropicResponseUsageBlock,
ContentBlockDelta,
ContentBlockStart,
MessageBlockDelta,
MessageStartBlock,
)
from litellm.types.llms.openai import (
AllMessageValues,
ChatCompletionAssistantMessage,
ChatCompletionAssistantToolCall,
ChatCompletionImageObject,
ChatCompletionImageUrlObject,
ChatCompletionRequest,
ChatCompletionResponseMessage,
ChatCompletionSystemMessage,
ChatCompletionTextObject,
ChatCompletionToolCallChunk,
ChatCompletionToolCallFunctionChunk,
ChatCompletionToolChoiceFunctionParam,
ChatCompletionToolChoiceObjectParam,
ChatCompletionToolChoiceValues,
ChatCompletionToolMessage,
ChatCompletionToolParam,
ChatCompletionToolParamFunctionChunk,
ChatCompletionUsageBlock,
ChatCompletionUserMessage,
)
from litellm.types.utils import GenericStreamingChunk
from litellm.types.utils import Choices, GenericStreamingChunk
from litellm.utils import CustomStreamWrapper, ModelResponse, Usage
from .base import BaseLLM
@ -168,6 +192,287 @@ class AnthropicConfig:
optional_params["top_p"] = value
return optional_params
### FOR [BETA] `/v1/messages` endpoint support
def translatable_anthropic_params(self) -> List:
"""
Which Anthropic params we need to translate to the OpenAI format.
"""
return ["messages", "metadata", "system", "tool_choice", "tools"]
def translate_anthropic_messages_to_openai(
self,
messages: List[
Union[
AnthropicMessagesUserMessageParam,
AnthopicMessagesAssistantMessageParam,
]
],
) -> List:
new_messages: List[AllMessageValues] = []
for m in messages:
user_message: Optional[ChatCompletionUserMessage] = None
tool_message_list: List[ChatCompletionToolMessage] = []
## USER MESSAGE ##
if m["role"] == "user":
## translate user message
if isinstance(m["content"], str):
user_message = ChatCompletionUserMessage(
role="user", content=m["content"]
)
elif isinstance(m["content"], list):
new_user_content_list: List[
Union[ChatCompletionTextObject, ChatCompletionImageObject]
] = []
for content in m["content"]:
if content["type"] == "text":
text_obj = ChatCompletionTextObject(
type="text", text=content["text"]
)
new_user_content_list.append(text_obj)
elif content["type"] == "image":
image_url = ChatCompletionImageUrlObject(
url=f"data:{content['type']};base64,{content['source']}"
)
image_obj = ChatCompletionImageObject(
type="image_url", image_url=image_url
)
new_user_content_list.append(image_obj)
elif content["type"] == "tool_result":
if "content" not in content:
tool_result = ChatCompletionToolMessage(
role="tool",
tool_call_id=content["tool_use_id"],
content="",
)
tool_message_list.append(tool_result)
elif isinstance(content["content"], str):
tool_result = ChatCompletionToolMessage(
role="tool",
tool_call_id=content["tool_use_id"],
content=content["content"],
)
tool_message_list.append(tool_result)
elif isinstance(content["content"], list):
for c in content["content"]:
if c["type"] == "text":
tool_result = ChatCompletionToolMessage(
role="tool",
tool_call_id=content["tool_use_id"],
content=c["text"],
)
tool_message_list.append(tool_result)
elif c["type"] == "image":
image_str = (
f"data:{c['type']};base64,{c['source']}"
)
tool_result = ChatCompletionToolMessage(
role="tool",
tool_call_id=content["tool_use_id"],
content=image_str,
)
tool_message_list.append(tool_result)
if user_message is not None:
new_messages.append(user_message)
if len(tool_message_list) > 0:
new_messages.extend(tool_message_list)
## ASSISTANT MESSAGE ##
assistant_message_str: Optional[str] = None
tool_calls: List[ChatCompletionAssistantToolCall] = []
if m["role"] == "assistant":
if isinstance(m["content"], str):
assistant_message_str = m["content"]
elif isinstance(m["content"], list):
for content in m["content"]:
if content["type"] == "text":
if assistant_message_str is None:
assistant_message_str = content["text"]
else:
assistant_message_str += content["text"]
elif content["type"] == "tool_use":
function_chunk = ChatCompletionToolCallFunctionChunk(
name=content["name"],
arguments=json.dumps(content["input"]),
)
tool_calls.append(
ChatCompletionAssistantToolCall(
id=content["id"],
type="function",
function=function_chunk,
)
)
if assistant_message_str is not None or len(tool_calls) > 0:
assistant_message = ChatCompletionAssistantMessage(
role="assistant",
content=assistant_message_str,
)
if len(tool_calls) > 0:
assistant_message["tool_calls"] = tool_calls
new_messages.append(assistant_message)
return new_messages
def translate_anthropic_tool_choice_to_openai(
self, tool_choice: AnthropicMessagesToolChoice
) -> ChatCompletionToolChoiceValues:
if tool_choice["type"] == "any":
return "required"
elif tool_choice["type"] == "auto":
return "auto"
elif tool_choice["type"] == "tool":
tc_function_param = ChatCompletionToolChoiceFunctionParam(
name=tool_choice.get("name", "")
)
return ChatCompletionToolChoiceObjectParam(
type="function", function=tc_function_param
)
else:
raise ValueError(
"Incompatible tool choice param submitted - {}".format(tool_choice)
)
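For reference, the mapping this method performs (assuming `import litellm`; the tool name is a placeholder):

cfg = litellm.AnthropicConfig()
cfg.translate_anthropic_tool_choice_to_openai(tool_choice={"type": "any"})   # -> "required"
cfg.translate_anthropic_tool_choice_to_openai(tool_choice={"type": "auto"})  # -> "auto"
cfg.translate_anthropic_tool_choice_to_openai(
    tool_choice={"type": "tool", "name": "get_weather"}
)  # -> {"type": "function", "function": {"name": "get_weather"}}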
def translate_anthropic_tools_to_openai(
self, tools: List[AnthropicMessagesTool]
) -> List[ChatCompletionToolParam]:
new_tools: List[ChatCompletionToolParam] = []
for tool in tools:
function_chunk = ChatCompletionToolParamFunctionChunk(
name=tool["name"],
parameters=tool["input_schema"],
)
if "description" in tool:
function_chunk["description"] = tool["description"]
new_tools.append(
ChatCompletionToolParam(type="function", function=function_chunk)
)
return new_tools
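And the analogous mapping for tool definitions (tool name and schema are placeholders):

anthropic_tool = {
    "name": "get_weather",
    "description": "Get the current weather for a city",
    "input_schema": {"type": "object", "properties": {"city": {"type": "string"}}},
}
openai_tools = litellm.AnthropicConfig().translate_anthropic_tools_to_openai(tools=[anthropic_tool])
# -> [{"type": "function", "function": {"name": "get_weather",
#       "parameters": {"type": "object", ...}, "description": "Get the current weather for a city"}}]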
def translate_anthropic_to_openai(
self, anthropic_message_request: AnthropicMessagesRequest
) -> ChatCompletionRequest:
"""
This is used by the beta Anthropic Adapter for translating Anthropic `/v1/messages` requests to the OpenAI format.
"""
new_messages: List[AllMessageValues] = []
## CONVERT ANTHROPIC MESSAGES TO OPENAI
new_messages = self.translate_anthropic_messages_to_openai(
messages=anthropic_message_request["messages"]
)
## ADD SYSTEM MESSAGE TO MESSAGES
if "system" in anthropic_message_request:
new_messages.insert(
0,
ChatCompletionSystemMessage(
role="system", content=anthropic_message_request["system"]
),
)
new_kwargs: ChatCompletionRequest = {
"model": anthropic_message_request["model"],
"messages": new_messages,
}
## CONVERT METADATA (user_id)
if "metadata" in anthropic_message_request:
if "user_id" in anthropic_message_request["metadata"]:
new_kwargs["user"] = anthropic_message_request["metadata"]["user_id"]
## CONVERT TOOL CHOICE
if "tool_choice" in anthropic_message_request:
new_kwargs["tool_choice"] = self.translate_anthropic_tool_choice_to_openai(
tool_choice=anthropic_message_request["tool_choice"]
)
## CONVERT TOOLS
if "tools" in anthropic_message_request:
new_kwargs["tools"] = self.translate_anthropic_tools_to_openai(
tools=anthropic_message_request["tools"]
)
translatable_params = self.translatable_anthropic_params()
for k, v in anthropic_message_request.items():
if k not in translatable_params: # pass remaining params as is
new_kwargs[k] = v # type: ignore
return new_kwargs
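Putting these together, a request in Anthropic `/v1/messages` shape translates roughly as follows (illustrative values; params outside translatable_anthropic_params(), such as max_tokens, are passed through unchanged):

anthropic_request = {
    "model": "claude-3-5-sonnet-20240620",
    "max_tokens": 256,
    "system": "You are a helpful assistant.",
    "messages": [{"role": "user", "content": "Hey, how's it going?"}],
    "metadata": {"user_id": "user-123"},
}
openai_request = litellm.AnthropicConfig().translate_anthropic_to_openai(
    anthropic_message_request=anthropic_request  # type: ignore
)
# openai_request == {
#     "model": "claude-3-5-sonnet-20240620",
#     "messages": [
#         {"role": "system", "content": "You are a helpful assistant."},
#         {"role": "user", "content": "Hey, how's it going?"},
#     ],
#     "user": "user-123",
#     "max_tokens": 256,
# }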
def _translate_openai_content_to_anthropic(
self, choices: List[Choices]
) -> List[
Union[AnthropicResponseContentBlockText, AnthropicResponseContentBlockToolUse]
]:
new_content: List[
Union[
AnthropicResponseContentBlockText, AnthropicResponseContentBlockToolUse
]
] = []
for choice in choices:
if (
choice.message.tool_calls is not None
and len(choice.message.tool_calls) > 0
):
for tool_call in choice.message.tool_calls:
new_content.append(
AnthropicResponseContentBlockToolUse(
type="tool_use",
id=tool_call.id,
name=tool_call.function.name or "",
input=json.loads(tool_call.function.arguments),
)
)
elif choice.message.content is not None:
new_content.append(
AnthropicResponseContentBlockText(
type="text", text=choice.message.content
)
)
return new_content
def _translate_openai_finish_reason_to_anthropic(
self, openai_finish_reason: str
) -> AnthropicFinishReason:
if openai_finish_reason == "stop":
return "end_turn"
elif openai_finish_reason == "length":
return "max_tokens"
elif openai_finish_reason == "tool_calls":
return "tool_use"
return "end_turn"
def translate_openai_response_to_anthropic(
self, response: litellm.ModelResponse
) -> AnthropicResponse:
## translate content block
anthropic_content = self._translate_openai_content_to_anthropic(choices=response.choices) # type: ignore
## extract finish reason
anthropic_finish_reason = self._translate_openai_finish_reason_to_anthropic(
openai_finish_reason=response.choices[0].finish_reason # type: ignore
)
# extract usage
usage: litellm.Usage = getattr(response, "usage")
anthropic_usage = AnthropicResponseUsageBlock(
input_tokens=usage.prompt_tokens, output_tokens=usage.completion_tokens
)
translated_obj = AnthropicResponse(
id=response.id,
type="message",
role="assistant",
model=response.model or "unknown-model",
stop_sequence=None,
usage=anthropic_usage,
content=anthropic_content,
stop_reason=anthropic_finish_reason,
)
return translated_obj
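In the other direction, an OpenAI-format response maps back onto the Anthropic response model; a quick sketch using a mocked completion:

openai_response = litellm.completion(
    model="gpt-3.5-turbo",
    messages=[{"role": "user", "content": "hi"}],
    mock_response="Hello!",  # mocked, no provider call
)
anthropic_response = litellm.AnthropicConfig().translate_openai_response_to_anthropic(
    response=openai_response
)
# anthropic_response.content     -> [AnthropicResponseContentBlockText(type="text", text="Hello!")]
# anthropic_response.stop_reason -> "end_turn"   (mapped from finish_reason="stop")
# anthropic_response.usage       -> AnthropicResponseUsageBlock(input_tokens=..., output_tokens=...)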
# makes headers for API call
def validate_environment(api_key, user_headers):
@ -231,121 +536,6 @@ class AnthropicChatCompletion(BaseLLM):
def __init__(self) -> None:
super().__init__()
# def process_streaming_response(
# self,
# model: str,
# response: Union[requests.Response, httpx.Response],
# model_response: ModelResponse,
# stream: bool,
# logging_obj: litellm.litellm_core_utils.litellm_logging.Logging,
# optional_params: dict,
# api_key: str,
# data: Union[dict, str],
# messages: List,
# print_verbose,
# encoding,
# ) -> CustomStreamWrapper:
# """
# Return stream object for tool-calling + streaming
# """
# ## LOGGING
# logging_obj.post_call(
# input=messages,
# api_key=api_key,
# original_response=response.text,
# additional_args={"complete_input_dict": data},
# )
# print_verbose(f"raw model_response: {response.text}")
# ## RESPONSE OBJECT
# try:
# completion_response = response.json()
# except:
# raise AnthropicError(
# message=response.text, status_code=response.status_code
# )
# text_content = ""
# tool_calls = []
# for content in completion_response["content"]:
# if content["type"] == "text":
# text_content += content["text"]
# ## TOOL CALLING
# elif content["type"] == "tool_use":
# tool_calls.append(
# {
# "id": content["id"],
# "type": "function",
# "function": {
# "name": content["name"],
# "arguments": json.dumps(content["input"]),
# },
# }
# )
# if "error" in completion_response:
# raise AnthropicError(
# message=str(completion_response["error"]),
# status_code=response.status_code,
# )
# _message = litellm.Message(
# tool_calls=tool_calls,
# content=text_content or None,
# )
# model_response.choices[0].message = _message # type: ignore
# model_response._hidden_params["original_response"] = completion_response[
# "content"
# ] # allow user to access raw anthropic tool calling response
# model_response.choices[0].finish_reason = map_finish_reason(
# completion_response["stop_reason"]
# )
# print_verbose("INSIDE ANTHROPIC STREAMING TOOL CALLING CONDITION BLOCK")
# # return an iterator
# streaming_model_response = ModelResponse(stream=True)
# streaming_model_response.choices[0].finish_reason = model_response.choices[ # type: ignore
# 0
# ].finish_reason
# # streaming_model_response.choices = [litellm.utils.StreamingChoices()]
# streaming_choice = litellm.utils.StreamingChoices()
# streaming_choice.index = model_response.choices[0].index
# _tool_calls = []
# print_verbose(
# f"type of model_response.choices[0]: {type(model_response.choices[0])}"
# )
# print_verbose(f"type of streaming_choice: {type(streaming_choice)}")
# if isinstance(model_response.choices[0], litellm.Choices):
# if getattr(
# model_response.choices[0].message, "tool_calls", None
# ) is not None and isinstance(
# model_response.choices[0].message.tool_calls, list
# ):
# for tool_call in model_response.choices[0].message.tool_calls:
# _tool_call = {**tool_call.dict(), "index": 0}
# _tool_calls.append(_tool_call)
# delta_obj = litellm.utils.Delta(
# content=getattr(model_response.choices[0].message, "content", None),
# role=model_response.choices[0].message.role,
# tool_calls=_tool_calls,
# )
# streaming_choice.delta = delta_obj
# streaming_model_response.choices = [streaming_choice]
# completion_stream = ModelResponseIterator(
# model_response=streaming_model_response
# )
# print_verbose(
# "Returns anthropic CustomStreamWrapper with 'cached_response' streaming object"
# )
# return CustomStreamWrapper(
# completion_stream=completion_stream,
# model=model,
# custom_llm_provider="cached_response",
# logging_obj=logging_obj,
# )
# else:
# raise AnthropicError(
# status_code=422,
# message="Unprocessable response object - {}".format(response.text),
# )
def process_response(
self,
model: str,

View file

@ -38,6 +38,7 @@ import dotenv
import httpx
import openai
import tiktoken
from pydantic import BaseModel
from typing_extensions import overload
import litellm
@ -48,6 +49,7 @@ from litellm import ( # type: ignore
get_litellm_params,
get_optional_params,
)
from litellm.integrations.custom_logger import CustomLogger
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
from litellm.utils import (
CustomStreamWrapper,
@ -3943,6 +3945,63 @@ def text_completion(
return text_completion_response
###### Adapter Completion ################
async def aadapter_completion(*, adapter_id: str, **kwargs) -> Optional[BaseModel]:
"""
Implemented to handle async calls for adapter_completion()
"""
try:
translation_obj: Optional[CustomLogger] = None
for item in litellm.adapters:
if item["id"] == adapter_id:
translation_obj = item["adapter"]
if translation_obj is None:
raise ValueError(
"No matching adapter given. Received 'adapter_id'={}, litellm.adapters={}".format(
adapter_id, litellm.adapters
)
)
new_kwargs = translation_obj.translate_completion_input_params(kwargs=kwargs)
response: ModelResponse = await acompletion(**new_kwargs) # type: ignore
translated_response = translation_obj.translate_completion_output_params(
response=response
)
return translated_response
except Exception as e:
raise e
def adapter_completion(*, adapter_id: str, **kwargs) -> Optional[BaseModel]:
translation_obj: Optional[CustomLogger] = None
for item in litellm.adapters:
if item["id"] == adapter_id:
translation_obj = item["adapter"]
if translation_obj is None:
raise ValueError(
"No matching adapter given. Received 'adapter_id'={}, litellm.adapters={}".format(
adapter_id, litellm.adapters
)
)
new_kwargs = translation_obj.translate_completion_input_params(kwargs=kwargs)
response: ModelResponse = completion(**new_kwargs) # type: ignore
translated_response = translation_obj.translate_completion_output_params(
response=response
)
return translated_response
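The async variant is used the same way; a minimal sketch (adapter registration as in the adapter module above, values are placeholders):

import asyncio

import litellm
from litellm.adapters.anthropic_adapter import anthropic_adapter

litellm.adapters = [{"id": "anthropic", "adapter": anthropic_adapter}]

async def main():
    response = await litellm.aadapter_completion(
        model="gpt-3.5-turbo",
        messages=[{"role": "user", "content": "Hey, how's it going?"}],
        adapter_id="anthropic",
        mock_response="This is a fake call",  # mocked, no API key needed
    )
    print(response)

asyncio.run(main())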
##### Moderation #######################

View file

@ -2,18 +2,9 @@ model_list:
- model_name: "*"
litellm_params:
model: "openai/*"
- model_name: gemini-1.5-flash
- model_name: claude-3-5-sonnet-20240620
litellm_params:
model: gemini/gemini-1.5-flash
- model_name: whisper
litellm_params:
model: azure/azure-whisper
api_version: 2024-02-15-preview
api_base: os.environ/AZURE_EUROPE_API_BASE
api_key: os.environ/AZURE_EUROPE_API_KEY
model_info:
mode: audio_transcription
model: gpt-3.5-turbo
general_settings:

View file

@ -71,6 +71,11 @@ azure_api_key_header = APIKeyHeader(
auto_error=False,
description="Some older versions of the openai Python package will send an API-Key header with just the API key ",
)
anthropic_api_key_header = APIKeyHeader(
name="x-api-key",
auto_error=False,
description="If anthropic client used.",
)
def _get_bearer_token(
@ -87,6 +92,9 @@ async def user_api_key_auth(
request: Request,
api_key: str = fastapi.Security(api_key_header),
azure_api_key_header: str = fastapi.Security(azure_api_key_header),
anthropic_api_key_header: Optional[str] = fastapi.Security(
anthropic_api_key_header
),
) -> UserAPIKeyAuth:
from litellm.proxy.proxy_server import (
@ -114,6 +122,9 @@ async def user_api_key_auth(
elif isinstance(azure_api_key_header, str):
api_key = azure_api_key_header
elif isinstance(anthropic_api_key_header, str):
api_key = anthropic_api_key_header
parent_otel_span: Optional[Span] = None
if open_telemetry_logger is not None:
parent_otel_span = open_telemetry_logger.tracer.start_span(

View file

@ -215,6 +215,12 @@ from litellm.router import (
from litellm.router import ModelInfo as RouterModelInfo
from litellm.router import updateDeployment
from litellm.scheduler import DefaultPriorities, FlowItem, Scheduler
from litellm.types.llms.anthropic import (
AnthropicMessagesRequest,
AnthropicResponse,
AnthropicResponseContentBlockText,
AnthropicResponseUsageBlock,
)
from litellm.types.llms.openai import HttpxBinaryResponseContent
from litellm.types.router import RouterGeneralSettings
@ -5041,6 +5047,198 @@ async def moderations(
)
#### ANTHROPIC ENDPOINTS ####
@router.post(
"/v1/messages",
tags=["[beta] Anthropic `/v1/messages`"],
dependencies=[Depends(user_api_key_auth)],
response_model=AnthropicResponse,
)
async def anthropic_response(
anthropic_data: AnthropicMessagesRequest,
fastapi_response: Response,
request: Request,
user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
from litellm import adapter_completion
from litellm.adapters.anthropic_adapter import anthropic_adapter
litellm.adapters = [{"id": "anthropic", "adapter": anthropic_adapter}]
global user_temperature, user_request_timeout, user_max_tokens, user_api_base
data: dict = {**anthropic_data, "adapter_id": "anthropic"}
try:
data["model"] = (
general_settings.get("completion_model", None) # server default
or user_model # model name passed via cli args
or data["model"] # default passed in http request
)
if user_model:
data["model"] = user_model
data = await add_litellm_data_to_request(
data=data, # type: ignore
request=request,
general_settings=general_settings,
user_api_key_dict=user_api_key_dict,
version=version,
proxy_config=proxy_config,
)
# override with user settings, these are params passed via cli
if user_temperature:
data["temperature"] = user_temperature
if user_request_timeout:
data["request_timeout"] = user_request_timeout
if user_max_tokens:
data["max_tokens"] = user_max_tokens
if user_api_base:
data["api_base"] = user_api_base
### MODEL ALIAS MAPPING ###
# check if model name in model alias map
# get the actual model name
if data["model"] in litellm.model_alias_map:
data["model"] = litellm.model_alias_map[data["model"]]
### CALL HOOKS ### - modify incoming data before calling the model
data = await proxy_logging_obj.pre_call_hook( # type: ignore
user_api_key_dict=user_api_key_dict, data=data, call_type="text_completion"
)
### ROUTE THE REQUESTs ###
router_model_names = llm_router.model_names if llm_router is not None else []
# skip router if user passed their key
if "api_key" in data:
llm_response = asyncio.create_task(litellm.aadapter_completion(**data))
elif (
llm_router is not None and data["model"] in router_model_names
): # model in router model list
llm_response = asyncio.create_task(llm_router.aadapter_completion(**data))
elif (
llm_router is not None
and llm_router.model_group_alias is not None
and data["model"] in llm_router.model_group_alias
): # model set in model_group_alias
llm_response = asyncio.create_task(llm_router.aadapter_completion(**data))
elif (
llm_router is not None and data["model"] in llm_router.deployment_names
): # model in router deployments, calling a specific deployment on the router
llm_response = asyncio.create_task(
llm_router.aadapter_completion(**data, specific_deployment=True)
)
elif (
llm_router is not None and data["model"] in llm_router.get_model_ids()
): # model in router model list
llm_response = asyncio.create_task(llm_router.aadapter_completion(**data))
elif (
llm_router is not None
and data["model"] not in router_model_names
and llm_router.default_deployment is not None
): # model in router deployments, calling a specific deployment on the router
llm_response = asyncio.create_task(llm_router.aadapter_completion(**data))
elif user_model is not None: # `litellm --model <your-model-name>`
llm_response = asyncio.create_task(litellm.aadapter_completion(**data))
else:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail={
"error": "completion: Invalid model name passed in model="
+ data.get("model", "")
},
)
# Await the llm_response task
response = await llm_response
hidden_params = getattr(response, "_hidden_params", {}) or {}
model_id = hidden_params.get("model_id", None) or ""
cache_key = hidden_params.get("cache_key", None) or ""
api_base = hidden_params.get("api_base", None) or ""
response_cost = hidden_params.get("response_cost", None) or ""
### ALERTING ###
asyncio.create_task(
proxy_logging_obj.update_request_status(
litellm_call_id=data.get("litellm_call_id", ""), status="success"
)
)
verbose_proxy_logger.debug("final response: %s", response)
fastapi_response.headers.update(
get_custom_headers(
user_api_key_dict=user_api_key_dict,
model_id=model_id,
cache_key=cache_key,
api_base=api_base,
version=version,
response_cost=response_cost,
)
)
verbose_proxy_logger.info("\nResponse from Litellm:\n{}".format(response))
return response
except RejectedRequestError as e:
_data = e.request_data
await proxy_logging_obj.post_call_failure_hook(
user_api_key_dict=user_api_key_dict,
original_exception=e,
request_data=_data,
)
if _data.get("stream", None) is not None and _data["stream"] == True:
_chat_response = litellm.ModelResponse()
_usage = litellm.Usage(
prompt_tokens=0,
completion_tokens=0,
total_tokens=0,
)
_chat_response.usage = _usage # type: ignore
_chat_response.choices[0].message.content = e.message # type: ignore
_iterator = litellm.utils.ModelResponseIterator(
model_response=_chat_response, convert_to_delta=True
)
_streaming_response = litellm.TextCompletionStreamWrapper(
completion_stream=_iterator,
model=_data.get("model", ""),
)
selected_data_generator = select_data_generator(
response=_streaming_response,
user_api_key_dict=user_api_key_dict,
request_data=data,
)
return StreamingResponse(
selected_data_generator,
media_type="text/event-stream",
headers={},
)
else:
_response = litellm.TextCompletionResponse()
_response.choices[0].text = e.message
return _response
except Exception as e:
await proxy_logging_obj.post_call_failure_hook(
user_api_key_dict=user_api_key_dict, original_exception=e, request_data=data
)
verbose_proxy_logger.error(
"litellm.proxy.proxy_server.completion(): Exception occured - {}".format(
str(e)
)
)
verbose_proxy_logger.debug(traceback.format_exc())
error_msg = f"{str(e)}"
raise ProxyException(
message=getattr(e, "message", error_msg),
type=getattr(e, "type", "None"),
param=getattr(e, "param", "None"),
code=getattr(e, "status_code", 500),
)
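With the proxy running, the new endpoint can be exercised directly; a sketch using httpx (the base URL and key are placeholders for a locally running proxy):

import httpx

resp = httpx.post(
    "http://0.0.0.0:4000/v1/messages",  # local proxy, placeholder URL
    headers={"x-api-key": "sk-1234"},   # accepted via the new anthropic_api_key_header
    json={
        "model": "claude-3-5-sonnet-20240620",
        "max_tokens": 1024,
        "messages": [{"role": "user", "content": "Hey, how's it going?"}],
    },
)
print(resp.json())  # AnthropicResponse-shaped JSON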
#### DEV UTILS ####
# @router.get(

View file

@ -1765,6 +1765,125 @@ class Router:
self.fail_calls[model] += 1
raise e
async def aadapter_completion(
self,
adapter_id: str,
model: str,
is_retry: Optional[bool] = False,
is_fallback: Optional[bool] = False,
is_async: Optional[bool] = False,
**kwargs,
):
try:
kwargs["model"] = model
kwargs["adapter_id"] = adapter_id
kwargs["original_function"] = self._aadapter_completion
kwargs["num_retries"] = kwargs.get("num_retries", self.num_retries)
timeout = kwargs.get("request_timeout", self.timeout)
kwargs.setdefault("metadata", {}).update({"model_group": model})
response = await self.async_function_with_fallbacks(**kwargs)
return response
except Exception as e:
asyncio.create_task(
send_llm_exception_alert(
litellm_router_instance=self,
request_kwargs=kwargs,
error_traceback_str=traceback.format_exc(),
original_exception=e,
)
)
raise e
async def _aadapter_completion(self, adapter_id: str, model: str, **kwargs):
try:
verbose_router_logger.debug(
f"Inside _aadapter_completion()- model: {model}; kwargs: {kwargs}"
)
deployment = await self.async_get_available_deployment(
model=model,
messages=[{"role": "user", "content": "default text"}],
specific_deployment=kwargs.pop("specific_deployment", None),
)
kwargs.setdefault("metadata", {}).update(
{
"deployment": deployment["litellm_params"]["model"],
"model_info": deployment.get("model_info", {}),
"api_base": deployment.get("litellm_params", {}).get("api_base"),
}
)
kwargs["model_info"] = deployment.get("model_info", {})
data = deployment["litellm_params"].copy()
model_name = data["model"]
for k, v in self.default_litellm_params.items():
if (
k not in kwargs
): # prioritize model-specific params > default router params
kwargs[k] = v
elif k == "metadata":
kwargs[k].update(v)
potential_model_client = self._get_client(
deployment=deployment, kwargs=kwargs, client_type="async"
)
# check if provided keys == client keys #
dynamic_api_key = kwargs.get("api_key", None)
if (
dynamic_api_key is not None
and potential_model_client is not None
and dynamic_api_key != potential_model_client.api_key
):
model_client = None
else:
model_client = potential_model_client
self.total_calls[model_name] += 1
response = litellm.aadapter_completion(
**{
**data,
"adapter_id": adapter_id,
"caching": self.cache_responses,
"client": model_client,
"timeout": self.timeout,
**kwargs,
}
)
rpm_semaphore = self._get_client(
deployment=deployment,
kwargs=kwargs,
client_type="max_parallel_requests",
)
if rpm_semaphore is not None and isinstance(
rpm_semaphore, asyncio.Semaphore
):
async with rpm_semaphore:
"""
- Check rpm limits before making the call
- If allowed, increment the rpm limit (allows global value to be updated, concurrency-safe)
"""
await self.async_routing_strategy_pre_call_checks(
deployment=deployment
)
response = await response # type: ignore
else:
await self.async_routing_strategy_pre_call_checks(deployment=deployment)
response = await response # type: ignore
self.success_calls[model_name] += 1
verbose_router_logger.info(
f"litellm.aadapter_completion(model={model_name})\033[32m 200 OK\033[0m"
)
return response
except Exception as e:
verbose_router_logger.info(
f"litellm.aadapter_completion(model={model})\033[31m Exception {str(e)}\033[0m"
)
if model is not None:
self.fail_calls[model] += 1
raise e
def embedding(
self,
model: str,

View file

@ -0,0 +1,103 @@
# What is this?
## Unit tests for Anthropic Adapter
import asyncio
import os
import sys
import traceback
from dotenv import load_dotenv
load_dotenv()
import io
import os
sys.path.insert(
0, os.path.abspath("../..")
) # Adds the parent directory to the system path
from unittest.mock import MagicMock, patch
import pytest
import litellm
from litellm import AnthropicConfig, Router, adapter_completion
from litellm.adapters.anthropic_adapter import anthropic_adapter
from litellm.types.llms.anthropic import AnthropicResponse
def test_anthropic_completion_messages_translation():
messages = [{"role": "user", "content": "Hey, how's it going?"}]
translated_messages = AnthropicConfig().translate_anthropic_messages_to_openai(messages=messages) # type: ignore
assert translated_messages == [{"role": "user", "content": "Hey, how's it going?"}]
def test_anthropic_completion_input_translation():
data = {
"model": "gpt-3.5-turbo",
"messages": [{"role": "user", "content": "Hey, how's it going?"}],
}
translated_input = anthropic_adapter.translate_completion_input_params(kwargs=data)
assert translated_input is not None
assert translated_input["model"] == "gpt-3.5-turbo"
assert translated_input["messages"] == [
{"role": "user", "content": "Hey, how's it going?"}
]
def test_anthropic_completion_e2e():
litellm.set_verbose = True
litellm.adapters = [{"id": "anthropic", "adapter": anthropic_adapter}]
messages = [{"role": "user", "content": "Hey, how's it going?"}]
response = adapter_completion(
model="gpt-3.5-turbo",
messages=messages,
adapter_id="anthropic",
mock_response="This is a fake call",
)
print("Response: {}".format(response))
assert response is not None
assert isinstance(response, AnthropicResponse)
@pytest.mark.asyncio
async def test_anthropic_router_completion_e2e():
litellm.set_verbose = True
litellm.adapters = [{"id": "anthropic", "adapter": anthropic_adapter}]
router = Router(
model_list=[
{
"model_name": "claude-3-5-sonnet-20240620",
"litellm_params": {
"model": "gpt-3.5-turbo",
"mock_response": "hi this is macintosh.",
},
}
]
)
messages = [{"role": "user", "content": "Hey, how's it going?"}]
response = await router.aadapter_completion(
model="claude-3-5-sonnet-20240620",
messages=messages,
adapter_id="anthropic",
mock_response="This is a fake call",
)
print("Response: {}".format(response))
assert response is not None
assert isinstance(response, AnthropicResponse)
assert response.model == "gpt-3.5-turbo"

litellm/types/adapter.py (new file)
View file

@ -0,0 +1,10 @@
from typing import List
from typing_extensions import Dict, Required, TypedDict, override
from litellm.integrations.custom_logger import CustomLogger
class AdapterItem(TypedDict):
id: str
adapter: CustomLogger

View file

@ -9,25 +9,27 @@ class AnthropicMessagesToolChoice(TypedDict, total=False):
name: str
class AnthopicMessagesAssistantMessageTextContentParam(TypedDict, total=False):
type: Required[Literal["text"]]
class AnthropicMessagesTool(TypedDict, total=False):
name: Required[str]
description: str
input_schema: Required[dict]
class AnthropicMessagesTextParam(TypedDict):
type: Literal["text"]
text: str
class AnthopicMessagesAssistantMessageToolCallParam(TypedDict, total=False):
type: Required[Literal["tool_use"]]
class AnthropicMessagesToolUseParam(TypedDict):
type: Literal["tool_use"]
id: str
name: str
input: dict
AnthropicMessagesAssistantMessageValues = Union[
AnthopicMessagesAssistantMessageTextContentParam,
AnthopicMessagesAssistantMessageToolCallParam,
AnthropicMessagesTextParam,
AnthropicMessagesToolUseParam,
]
@ -46,6 +48,72 @@ class AnthopicMessagesAssistantMessageParam(TypedDict, total=False):
"""
class AnthropicImageParamSource(TypedDict):
type: Literal["base64"]
media_type: str
data: str
class AnthropicMessagesImageParam(TypedDict):
type: Literal["image"]
source: AnthropicImageParamSource
class AnthropicMessagesToolResultContent(TypedDict):
type: Literal["text"]
text: str
class AnthropicMessagesToolResultParam(TypedDict, total=False):
type: Required[Literal["tool_result"]]
tool_use_id: Required[str]
is_error: bool
content: Union[
str,
Iterable[
Union[AnthropicMessagesToolResultContent, AnthropicMessagesImageParam]
],
]
AnthropicMessagesUserMessageValues = Union[
AnthropicMessagesTextParam,
AnthropicMessagesImageParam,
AnthropicMessagesToolResultParam,
]
class AnthropicMessagesUserMessageParam(TypedDict, total=False):
role: Required[Literal["user"]]
content: Required[Union[str, Iterable[AnthropicMessagesUserMessageValues]]]
class AnthropicMetadata(TypedDict, total=False):
user_id: str
class AnthropicMessagesRequest(TypedDict, total=False):
model: Required[str]
messages: Required[
List[
Union[
AnthropicMessagesUserMessageParam,
AnthopicMessagesAssistantMessageParam,
]
]
]
max_tokens: Required[int]
metadata: AnthropicMetadata
stop_sequences: List[str]
stream: bool
system: str
temperature: float
tool_choice: AnthropicMessagesToolChoice
tools: List[AnthropicMessagesTool]
top_k: int
top_p: float
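A request satisfying this TypedDict (illustrative values; only model, messages, and max_tokens are required):

example_request: AnthropicMessagesRequest = {
    "model": "claude-3-5-sonnet-20240620",
    "max_tokens": 256,
    "system": "You are a helpful assistant.",
    "messages": [
        {"role": "user", "content": [{"type": "text", "text": "What's the weather in SF?"}]}
    ],
    "tool_choice": {"type": "auto"},
    "tools": [
        {
            "name": "get_weather",
            "description": "Get the current weather for a city",
            "input_schema": {"type": "object", "properties": {"city": {"type": "string"}}},
        }
    ],
}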
class ContentTextBlockDelta(TypedDict):
"""
'delta': {'type': 'text_delta', 'text': 'Hello'}
@ -155,3 +223,51 @@ class MessageStartBlock(TypedDict):
type: Literal["message_start"]
message: MessageChunk
class AnthropicResponseContentBlockText(BaseModel):
type: Literal["text"]
text: str
class AnthropicResponseContentBlockToolUse(BaseModel):
type: Literal["tool_use"]
id: str
name: str
input: dict
class AnthropicResponseUsageBlock(BaseModel):
input_tokens: int
output_tokens: int
AnthropicFinishReason = Literal["end_turn", "max_tokens", "stop_sequence", "tool_use"]
class AnthropicResponse(BaseModel):
id: str
"""Unique object identifier."""
type: Literal["message"]
"""For Messages, this is always "message"."""
role: Literal["assistant"]
"""Conversational role of the generated message. This will always be "assistant"."""
content: List[
Union[AnthropicResponseContentBlockText, AnthropicResponseContentBlockToolUse]
]
"""Content generated by the model."""
model: str
"""The model that handled the request."""
stop_reason: Optional[AnthropicFinishReason]
"""The reason that we stopped."""
stop_sequence: Optional[str]
"""Which custom stop sequence was generated, if any."""
usage: AnthropicResponseUsageBlock
"""Billing and rate-limit usage."""

View file

@ -305,7 +305,13 @@ class ChatCompletionToolCallFunctionChunk(TypedDict, total=False):
arguments: str
class ChatCompletionToolCallChunk(TypedDict):
class ChatCompletionAssistantToolCall(TypedDict):
id: Optional[str]
type: Literal["function"]
function: ChatCompletionToolCallFunctionChunk
class ChatCompletionToolCallChunk(TypedDict): # result of /chat/completions call
id: Optional[str]
type: Literal["function"]
function: ChatCompletionToolCallFunctionChunk
@ -319,6 +325,107 @@ class ChatCompletionDeltaToolCallChunk(TypedDict, total=False):
index: int
class ChatCompletionTextObject(TypedDict):
type: Literal["text"]
text: str
class ChatCompletionImageUrlObject(TypedDict, total=False):
url: Required[str]
detail: str
class ChatCompletionImageObject(TypedDict):
type: Literal["image_url"]
image_url: ChatCompletionImageUrlObject
class ChatCompletionUserMessage(TypedDict):
role: Literal["user"]
content: Union[
str, Iterable[Union[ChatCompletionTextObject, ChatCompletionImageObject]]
]
class ChatCompletionAssistantMessage(TypedDict, total=False):
role: Required[Literal["assistant"]]
content: Optional[str]
name: str
tool_calls: List[ChatCompletionAssistantToolCall]
class ChatCompletionToolMessage(TypedDict):
role: Literal["tool"]
content: str
tool_call_id: str
class ChatCompletionSystemMessage(TypedDict, total=False):
role: Required[Literal["system"]]
content: Required[str]
name: str
AllMessageValues = Union[
ChatCompletionUserMessage,
ChatCompletionAssistantMessage,
ChatCompletionToolMessage,
ChatCompletionSystemMessage,
]
class ChatCompletionToolChoiceFunctionParam(TypedDict):
name: str
class ChatCompletionToolChoiceObjectParam(TypedDict):
type: Literal["function"]
function: ChatCompletionToolChoiceFunctionParam
ChatCompletionToolChoiceStringValues = Literal["none", "auto", "required"]
ChatCompletionToolChoiceValues = Union[
ChatCompletionToolChoiceStringValues, ChatCompletionToolChoiceObjectParam
]
class ChatCompletionToolParamFunctionChunk(TypedDict, total=False):
name: Required[str]
description: str
parameters: dict
class ChatCompletionToolParam(TypedDict):
type: Literal["function"]
function: ChatCompletionToolParamFunctionChunk
class ChatCompletionRequest(TypedDict, total=False):
model: Required[str]
messages: Required[List[AllMessageValues]]
frequency_penalty: float
logit_bias: dict
logprobs: bool
top_logprobs: int
max_tokens: int
n: int
presence_penalty: float
response_format: dict
seed: int
service_tier: str
stop: Union[str, List[str]]
stream_options: dict
temperature: float
top_p: float
tools: List[ChatCompletionToolParam]
tool_choice: ChatCompletionToolChoiceValues
parallel_tool_calls: bool
function_call: Union[str, dict]
functions: List
user: str
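For comparison, the OpenAI-format request that the Anthropic adapter translates into (same illustrative tool as above):

example_openai_request: ChatCompletionRequest = {
    "model": "gpt-3.5-turbo",
    "max_tokens": 256,
    "messages": [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "What's the weather in SF?"},
    ],
    "tool_choice": "auto",
    "tools": [
        {
            "type": "function",
            "function": {
                "name": "get_weather",
                "description": "Get the current weather for a city",
                "parameters": {"type": "object", "properties": {"city": {"type": "string"}}},
            },
        }
    ],
}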
class ChatCompletionDeltaChunk(TypedDict, total=False):
content: Optional[str]
tool_calls: List[ChatCompletionDeltaToolCallChunk]

View file

@ -166,7 +166,9 @@ class FunctionCall(OpenAIObject):
class Function(OpenAIObject):
arguments: str
name: Optional[str] = None
name: Optional[
str
] # can be None - openai e.g.: ChoiceDeltaToolCallFunction(arguments='{"', name=None), type=None)
def __init__(
self,
@ -280,29 +282,43 @@ class ChatCompletionMessageToolCall(OpenAIObject):
setattr(self, key, value)
"""
Reference:
ChatCompletionMessage(content='This is a test', role='assistant', function_call=None, tool_calls=None))
"""
class Message(OpenAIObject):
content: Optional[str]
role: Literal["assistant"]
tool_calls: Optional[List[ChatCompletionMessageToolCall]]
function_call: Optional[FunctionCall]
def __init__(
self,
content: Optional[str] = "default",
role="assistant",
logprobs=None,
content: Optional[str] = None,
role: Literal["assistant"] = "assistant",
function_call=None,
tool_calls=None,
**params,
):
super(Message, self).__init__(**params)
self.content = content
self.role = role
if function_call is not None:
self.function_call = FunctionCall(**function_call)
if tool_calls is not None:
self.tool_calls = []
for tool_call in tool_calls:
self.tool_calls.append(ChatCompletionMessageToolCall(**tool_call))
if logprobs is not None:
self._logprobs = ChoiceLogprobs(**logprobs)
init_values = {
"content": content,
"role": role,
"function_call": (
FunctionCall(**function_call) if function_call is not None else None
),
"tool_calls": (
[ChatCompletionMessageToolCall(**tool_call) for tool_call in tool_calls]
if tool_calls is not None
else None
),
}
super(Message, self).__init__(
**init_values,
**params,
)
def get(self, key, default=None):
# Custom .get() method to access attributes with a default value if the attribute doesn't exist

View file

@ -8126,7 +8126,7 @@ class CustomStreamWrapper:
if chunk.startswith(self.complete_response):
# Remove last_sent_chunk only if it appears at the start of the new chunk
chunk = chunk[len(self.complete_response):]
chunk = chunk[len(self.complete_response) :]
self.complete_response += chunk
return chunk
@ -9483,8 +9483,8 @@ class CustomStreamWrapper:
model_response.choices[0].delta = Delta(**_json_delta)
except Exception as e:
verbose_logger.error(
"litellm.CustomStreamWrapper.chunk_creator(): Exception occured - {}".format(
str(e)
"litellm.CustomStreamWrapper.chunk_creator(): Exception occured - {}\n{}".format(
str(e), traceback.format_exc()
)
)
verbose_logger.debug(traceback.format_exc())
@ -10124,7 +10124,7 @@ def mock_completion_streaming_obj(
model_response, mock_response, model, n: Optional[int] = None
):
for i in range(0, len(mock_response), 3):
completion_obj = Delta(role="assistant", content=mock_response[i: i + 3])
completion_obj = Delta(role="assistant", content=mock_response[i : i + 3])
if n is None:
model_response.choices[0].delta = completion_obj
else:
@ -10133,7 +10133,7 @@ def mock_completion_streaming_obj(
_streaming_choice = litellm.utils.StreamingChoices(
index=j,
delta=litellm.utils.Delta(
role="assistant", content=mock_response[i: i + 3]
role="assistant", content=mock_response[i : i + 3]
),
)
_all_choices.append(_streaming_choice)
@ -10145,7 +10145,7 @@ async def async_mock_completion_streaming_obj(
model_response, mock_response, model, n: Optional[int] = None
):
for i in range(0, len(mock_response), 3):
completion_obj = Delta(role="assistant", content=mock_response[i: i + 3])
completion_obj = Delta(role="assistant", content=mock_response[i : i + 3])
if n is None:
model_response.choices[0].delta = completion_obj
else:
@ -10154,7 +10154,7 @@ async def async_mock_completion_streaming_obj(
_streaming_choice = litellm.utils.StreamingChoices(
index=j,
delta=litellm.utils.Delta(
role="assistant", content=mock_response[i: i + 3]
role="assistant", content=mock_response[i : i + 3]
),
)
_all_choices.append(_streaming_choice)